diff --git a/ethdb/memorydb/memorydb.go b/ethdb/memorydb/memorydb.go index f9f74322b5..2a939f9a18 100644 --- a/ethdb/memorydb/memorydb.go +++ b/ethdb/memorydb/memorydb.go @@ -207,7 +207,7 @@ func (db *Database) Len() int { // keyvalue is a key-value tuple tagged with a deletion field to allow creating // memory-database write batches. type keyvalue struct { - key []byte + key string value []byte delete bool } @@ -222,14 +222,14 @@ type batch struct { // Put inserts the given value into the batch for later committing. func (b *batch) Put(key, value []byte) error { - b.writes = append(b.writes, keyvalue{common.CopyBytes(key), common.CopyBytes(value), false}) + b.writes = append(b.writes, keyvalue{string(key), common.CopyBytes(value), false}) b.size += len(key) + len(value) return nil } // Delete inserts the a key removal into the batch for later committing. func (b *batch) Delete(key []byte) error { - b.writes = append(b.writes, keyvalue{common.CopyBytes(key), nil, true}) + b.writes = append(b.writes, keyvalue{string(key), nil, true}) b.size += len(key) return nil } @@ -249,10 +249,10 @@ func (b *batch) Write() error { } for _, keyvalue := range b.writes { if keyvalue.delete { - delete(b.db.db, string(keyvalue.key)) + delete(b.db.db, keyvalue.key) continue } - b.db.db[string(keyvalue.key)] = keyvalue.value + b.db.db[keyvalue.key] = keyvalue.value } return nil } @@ -267,12 +267,12 @@ func (b *batch) Reset() { func (b *batch) Replay(w ethdb.KeyValueWriter) error { for _, keyvalue := range b.writes { if keyvalue.delete { - if err := w.Delete(keyvalue.key); err != nil { + if err := w.Delete([]byte(keyvalue.key)); err != nil { return err } continue } - if err := w.Put(keyvalue.key, keyvalue.value); err != nil { + if err := w.Put([]byte(keyvalue.key), keyvalue.value); err != nil { return err } } diff --git a/ethdb/memorydb/memorydb_test.go b/ethdb/memorydb/memorydb_test.go index dba18ad306..51499c3b1f 100644 --- a/ethdb/memorydb/memorydb_test.go +++ b/ethdb/memorydb/memorydb_test.go @@ -17,6 +17,7 @@ package memorydb import ( + "encoding/binary" "testing" "github.com/ethereum/go-ethereum/ethdb" @@ -30,3 +31,20 @@ func TestMemoryDB(t *testing.T) { }) }) } + +// BenchmarkBatchAllocs measures the time/allocs for storing 120 kB of data +func BenchmarkBatchAllocs(b *testing.B) { + b.ReportAllocs() + var key = make([]byte, 20) + var val = make([]byte, 100) + // 120 * 1_000 -> 120_000 == 120kB + for i := 0; i < b.N; i++ { + batch := New().NewBatch() + for j := uint64(0); j < 1000; j++ { + binary.BigEndian.PutUint64(key, j) + binary.BigEndian.PutUint64(val, j) + batch.Put(key, val) + } + batch.Write() + } +} diff --git a/trie/iterator.go b/trie/iterator.go index 6f054a7245..83ccc0740f 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -144,7 +144,8 @@ type nodeIterator struct { path []byte // Path to the current node err error // Failure set in case of an internal error in the iterator - resolver NodeResolver // optional node resolver for avoiding disk hits + resolver NodeResolver // optional node resolver for avoiding disk hits + pool []*nodeIteratorState // local pool for iteratorstates } // errIteratorEnd is stored in nodeIterator.err when iteration is done. @@ -172,6 +173,24 @@ func newNodeIterator(trie *Trie, start []byte) NodeIterator { return it } +func (it *nodeIterator) putInPool(item *nodeIteratorState) { + if len(it.pool) < 40 { + item.node = nil + it.pool = append(it.pool, item) + } +} + +func (it *nodeIterator) getFromPool() *nodeIteratorState { + idx := len(it.pool) - 1 + if idx < 0 { + return new(nodeIteratorState) + } + el := it.pool[idx] + it.pool[idx] = nil + it.pool = it.pool[:idx] + return el +} + func (it *nodeIterator) AddResolver(resolver NodeResolver) { it.resolver = resolver } @@ -423,8 +442,9 @@ func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error { return nil } -func findChild(n *fullNode, index int, path []byte, ancestor common.Hash) (node, *nodeIteratorState, []byte, int) { +func (it *nodeIterator) findChild(n *fullNode, index int, ancestor common.Hash) (node, *nodeIteratorState, []byte, int) { var ( + path = it.path child node state *nodeIteratorState childPath []byte @@ -433,13 +453,12 @@ func findChild(n *fullNode, index int, path []byte, ancestor common.Hash) (node, if n.Children[index] != nil { child = n.Children[index] hash, _ := child.cache() - state = &nodeIteratorState{ - hash: common.BytesToHash(hash), - node: child, - parent: ancestor, - index: -1, - pathlen: len(path), - } + state = it.getFromPool() + state.hash = common.BytesToHash(hash) + state.node = child + state.parent = ancestor + state.index = -1 + state.pathlen = len(path) childPath = append(childPath, path...) childPath = append(childPath, byte(index)) return child, state, childPath, index @@ -452,7 +471,7 @@ func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Has switch node := parent.node.(type) { case *fullNode: // Full node, move to the first non-nil child. - if child, state, path, index := findChild(node, parent.index+1, it.path, ancestor); child != nil { + if child, state, path, index := it.findChild(node, parent.index+1, ancestor); child != nil { parent.index = index - 1 return state, path, true } @@ -460,13 +479,12 @@ func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Has // Short node, return the pointer singleton child if parent.index < 0 { hash, _ := node.Val.cache() - state := &nodeIteratorState{ - hash: common.BytesToHash(hash), - node: node.Val, - parent: ancestor, - index: -1, - pathlen: len(it.path), - } + state := it.getFromPool() + state.hash = common.BytesToHash(hash) + state.node = node.Val + state.parent = ancestor + state.index = -1 + state.pathlen = len(it.path) path := append(it.path, node.Key...) return state, path, true } @@ -480,7 +498,7 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H switch n := parent.node.(type) { case *fullNode: // Full node, move to the first non-nil child before the desired key position - child, state, path, index := findChild(n, parent.index+1, it.path, ancestor) + child, state, path, index := it.findChild(n, parent.index+1, ancestor) if child == nil { // No more children in this fullnode return parent, it.path, false @@ -492,7 +510,7 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H } // The child is before the seek position. Try advancing for { - nextChild, nextState, nextPath, nextIndex := findChild(n, index+1, it.path, ancestor) + nextChild, nextState, nextPath, nextIndex := it.findChild(n, index+1, ancestor) // If we run out of children, or skipped past the target, return the // previous one if nextChild == nil || bytes.Compare(nextPath, key) >= 0 { @@ -506,13 +524,12 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H // Short node, return the pointer singleton child if parent.index < 0 { hash, _ := n.Val.cache() - state := &nodeIteratorState{ - hash: common.BytesToHash(hash), - node: n.Val, - parent: ancestor, - index: -1, - pathlen: len(it.path), - } + state := it.getFromPool() + state.hash = common.BytesToHash(hash) + state.node = n.Val + state.parent = ancestor + state.index = -1 + state.pathlen = len(it.path) path := append(it.path, n.Key...) return state, path, true } @@ -533,6 +550,8 @@ func (it *nodeIterator) pop() { it.path = it.path[:last.pathlen] it.stack[len(it.stack)-1] = nil it.stack = it.stack[:len(it.stack)-1] + // last is now unused + it.putInPool(last) } func compareNodes(a, b NodeIterator) int { diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 57d1f06a16..9679b49ca7 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -616,3 +616,15 @@ func isTrieNode(scheme string, key, val []byte) (bool, []byte, common.Hash) { } return true, path, hash } + +func BenchmarkIterator(b *testing.B) { + diskDb, srcDb, tr, _ := makeTestTrie(rawdb.HashScheme) + root := tr.Hash() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := checkTrieConsistency(diskDb, srcDb.Scheme(), root, false); err != nil { + b.Fatal(err) + } + } +} diff --git a/trie/sync_test.go b/trie/sync_test.go index 3b7986ef67..7032c6d2f7 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -571,7 +571,7 @@ func testIncompleteSync(t *testing.T, scheme string) { hash := crypto.Keccak256Hash(result.Data) if hash != root { addedKeys = append(addedKeys, result.Path) - addedHashes = append(addedHashes, crypto.Keccak256Hash(result.Data)) + addedHashes = append(addedHashes, hash) } } // Fetch the next batch to retrieve