diff --git a/trie/committer.go b/trie/committer.go index 6c4374ccfd..0939a07abb 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -57,32 +57,26 @@ func (c *committer) commit(path []byte, n node, parallel bool) node { // Commit children, then parent, and remove the dirty flag. switch cn := n.(type) { case *shortNode: - // Commit child - collapsed := cn.copy() - // If the child is fullNode, recursively commit, // otherwise it can only be hashNode or valueNode. if _, ok := cn.Val.(*fullNode); ok { - collapsed.Val = c.commit(append(path, cn.Key...), cn.Val, false) + cn.Val = c.commit(append(path, cn.Key...), cn.Val, false) } // The key needs to be copied, since we're adding it to the // modified nodeset. - collapsed.Key = hexToCompact(cn.Key) - hashedNode := c.store(path, collapsed) + cn.Key = hexToCompact(cn.Key) + hashedNode := c.store(path, cn) if hn, ok := hashedNode.(hashNode); ok { return hn } - return collapsed + return cn case *fullNode: - hashedKids := c.commitChildren(path, cn, parallel) - collapsed := cn.copy() - collapsed.Children = hashedKids - - hashedNode := c.store(path, collapsed) + c.commitChildren(path, cn, parallel) + hashedNode := c.store(path, cn) if hn, ok := hashedNode.(hashNode); ok { return hn } - return collapsed + return cn case hashNode: return cn default: @@ -92,11 +86,10 @@ func (c *committer) commit(path []byte, n node, parallel bool) node { } // commitChildren commits the children of the given fullnode -func (c *committer) commitChildren(path []byte, n *fullNode, parallel bool) [17]node { +func (c *committer) commitChildren(path []byte, n *fullNode, parallel bool) { var ( - wg sync.WaitGroup - nodesMu sync.Mutex - children [17]node + wg sync.WaitGroup + nodesMu sync.Mutex ) for i := 0; i < 16; i++ { child := n.Children[i] @@ -106,22 +99,21 @@ func (c *committer) commitChildren(path []byte, n *fullNode, parallel bool) [17] // If it's the hashed child, save the hash value directly. // Note: it's impossible that the child in range [0, 15] // is a valueNode. - if hn, ok := child.(hashNode); ok { - children[i] = hn + if _, ok := child.(hashNode); ok { continue } // Commit the child recursively and store the "hashed" value. // Note the returned node can be some embedded nodes, so it's // possible the type is not hashNode. if !parallel { - children[i] = c.commit(append(path, byte(i)), child, false) + n.Children[i] = c.commit(append(path, byte(i)), child, false) } else { wg.Add(1) go func(index int) { p := append(path, byte(index)) childSet := trienode.NewNodeSet(c.nodes.Owner) childCommitter := newCommitter(childSet, c.tracer, c.collectLeaf) - children[index] = childCommitter.commit(p, child, false) + n.Children[index] = childCommitter.commit(p, child, false) nodesMu.Lock() c.nodes.MergeSet(childSet) nodesMu.Unlock() @@ -132,11 +124,6 @@ func (c *committer) commitChildren(path []byte, n *fullNode, parallel bool) [17] if parallel { wg.Wait() } - // For the 17th child, it's possible the type is valuenode. - if n.Children[16] != nil { - children[16] = n.Children[16] - } - return children } // store hashes the node n and adds it to the modified nodeset. If leaf collection diff --git a/trie/hasher.go b/trie/hasher.go index 28f7f3d0c3..614640ae3a 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -53,62 +53,56 @@ func returnHasherToPool(h *hasher) { hasherPool.Put(h) } -// hash collapses a node down into a hash node, also returning a copy of the -// original node initialized with the computed hash to replace the original one. -func (h *hasher) hash(n node, force bool) (hashed node, cached node) { +// hash collapses a node down into a hash node. +func (h *hasher) hash(n node, force bool) node { // Return the cached hash if it's available if hash, _ := n.cache(); hash != nil { - return hash, n + return hash } // Trie not processed yet, walk the children switch n := n.(type) { case *shortNode: - collapsed, cached := h.hashShortNodeChildren(n) + collapsed := h.hashShortNodeChildren(n) hashed := h.shortnodeToHash(collapsed, force) - // We need to retain the possibly _not_ hashed node, in case it was too - // small to be hashed if hn, ok := hashed.(hashNode); ok { - cached.flags.hash = hn + n.flags.hash = hn } else { - cached.flags.hash = nil + n.flags.hash = nil } - return hashed, cached + return hashed case *fullNode: - collapsed, cached := h.hashFullNodeChildren(n) - hashed = h.fullnodeToHash(collapsed, force) + collapsed := h.hashFullNodeChildren(n) + hashed := h.fullnodeToHash(collapsed, force) if hn, ok := hashed.(hashNode); ok { - cached.flags.hash = hn + n.flags.hash = hn } else { - cached.flags.hash = nil + n.flags.hash = nil } - return hashed, cached + return hashed default: // Value and hash nodes don't have children, so they're left as were - return n, n + return n } } -// hashShortNodeChildren collapses the short node. The returned collapsed node -// holds a live reference to the Key, and must not be modified. -func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNode) { - // Hash the short node's child, caching the newly hashed subtree - collapsed, cached = n.copy(), n.copy() - // Previously, we did copy this one. We don't seem to need to actually - // do that, since we don't overwrite/reuse keys - // cached.Key = common.CopyBytes(n.Key) +// hashShortNodeChildren returns a copy of the supplied shortNode, with its child +// being replaced by either the hash or an embedded node if the child is small. +func (h *hasher) hashShortNodeChildren(n *shortNode) *shortNode { + var collapsed shortNode collapsed.Key = hexToCompact(n.Key) - // Unless the child is a valuenode or hashnode, hash it switch n.Val.(type) { case *fullNode, *shortNode: - collapsed.Val, cached.Val = h.hash(n.Val, false) + collapsed.Val = h.hash(n.Val, false) + default: + collapsed.Val = n.Val } - return collapsed, cached + return &collapsed } -func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached *fullNode) { - // Hash the full node's children, caching the newly hashed subtrees - cached = n.copy() - collapsed = n.copy() +// hashFullNodeChildren returns a copy of the supplied fullNode, with its child +// being replaced by either the hash or an embedded node if the child is small. +func (h *hasher) hashFullNodeChildren(n *fullNode) *fullNode { + var children [17]node if h.parallel { var wg sync.WaitGroup wg.Add(16) @@ -116,9 +110,9 @@ func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached go func(i int) { hasher := newHasher(false) if child := n.Children[i]; child != nil { - collapsed.Children[i], cached.Children[i] = hasher.hash(child, false) + children[i] = hasher.hash(child, false) } else { - collapsed.Children[i] = nilValueNode + children[i] = nilValueNode } returnHasherToPool(hasher) wg.Done() @@ -128,19 +122,21 @@ func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached } else { for i := 0; i < 16; i++ { if child := n.Children[i]; child != nil { - collapsed.Children[i], cached.Children[i] = h.hash(child, false) + children[i] = h.hash(child, false) } else { - collapsed.Children[i] = nilValueNode + children[i] = nilValueNode } } } - return collapsed, cached + if n.Children[16] != nil { + children[16] = n.Children[16] + } + return &fullNode{flags: nodeFlag{}, Children: children} } -// shortnodeToHash creates a hashNode from a shortNode. The supplied shortnode -// should have hex-type Key, which will be converted (without modification) -// into compact form for RLP encoding. -// If the rlp data is smaller than 32 bytes, `nil` is returned. +// shortNodeToHash computes the hash of the given shortNode. The shortNode must +// first be collapsed, with its key converted to compact form. If the RLP-encoded +// node data is smaller than 32 bytes, the node itself is returned. func (h *hasher) shortnodeToHash(n *shortNode, force bool) node { n.encode(h.encbuf) enc := h.encodedBytes() @@ -151,8 +147,8 @@ func (h *hasher) shortnodeToHash(n *shortNode, force bool) node { return h.hashData(enc) } -// fullnodeToHash is used to create a hashNode from a fullNode, (which -// may contain nil values) +// fullnodeToHash computes the hash of the given fullNode. If the RLP-encoded +// node data is smaller than 32 bytes, the node itself is returned. func (h *hasher) fullnodeToHash(n *fullNode, force bool) node { n.encode(h.encbuf) enc := h.encodedBytes() @@ -203,10 +199,10 @@ func (h *hasher) hashDataTo(dst, data []byte) { func (h *hasher) proofHash(original node) (collapsed, hashed node) { switch n := original.(type) { case *shortNode: - sn, _ := h.hashShortNodeChildren(n) + sn := h.hashShortNodeChildren(n) return sn, h.shortnodeToHash(sn, false) case *fullNode: - fn, _ := h.hashFullNodeChildren(n) + fn := h.hashFullNodeChildren(n) return fn, h.fullnodeToHash(fn, false) default: // Value and hash nodes don't have children, so they're left as were diff --git a/trie/node.go b/trie/node.go index ecc2de192d..96f077ebbb 100644 --- a/trie/node.go +++ b/trie/node.go @@ -79,15 +79,19 @@ func (n *fullNode) EncodeRLP(w io.Writer) error { return eb.Flush() } -func (n *fullNode) copy() *fullNode { copy := *n; return © } -func (n *shortNode) copy() *shortNode { copy := *n; return © } - // nodeFlag contains caching-related metadata about a node. type nodeFlag struct { hash hashNode // cached hash of the node (may be nil) dirty bool // whether the node has changes that must be written to the database } +func (n nodeFlag) copy() nodeFlag { + return nodeFlag{ + hash: common.CopyBytes(n.hash), + dirty: n.dirty, + } +} + func (n *fullNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty } func (n *shortNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty } func (n hashNode) cache() (hashNode, bool) { return nil, true } @@ -228,7 +232,9 @@ func decodeRef(buf []byte) (node, []byte, error) { err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen) return nil, buf, err } - n, err := decodeNode(nil, buf) + // The buffer content has already been copied or is safe to use; + // no additional copy is required. + n, err := decodeNodeUnsafe(nil, buf) return n, rest, err case kind == rlp.String && len(val) == 0: // empty node diff --git a/trie/trie.go b/trie/trie.go index ae2a7b21a2..fdb4da9be4 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -29,11 +29,11 @@ import ( "github.com/ethereum/go-ethereum/triedb/database" ) -// Trie is a Merkle Patricia Trie. Use New to create a trie that sits on -// top of a database. Whenever trie performs a commit operation, the generated -// nodes will be gathered and returned in a set. Once the trie is committed, -// it's not usable anymore. Callers have to re-create the trie with new root -// based on the updated trie database. +// Trie represents a Merkle Patricia Trie. Use New to create a trie that operates +// on top of a node database. During a commit operation, the trie collects all +// modified nodes into a set for return. After committing, the trie becomes +// unusable, and callers must recreate it with the new root based on the updated +// trie database. // // Trie is not safe for concurrent use. type Trie struct { @@ -67,13 +67,13 @@ func (t *Trie) newFlag() nodeFlag { // Copy returns a copy of Trie. func (t *Trie) Copy() *Trie { return &Trie{ - root: t.root, + root: copyNode(t.root), owner: t.owner, committed: t.committed, + unhashed: t.unhashed, + uncommitted: t.uncommitted, reader: t.reader, tracer: t.tracer.copy(), - uncommitted: t.uncommitted, - unhashed: t.unhashed, } } @@ -169,14 +169,12 @@ func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, newnode no } value, newnode, didResolve, err = t.get(n.Val, key, pos+len(n.Key)) if err == nil && didResolve { - n = n.copy() n.Val = newnode } return value, n, didResolve, err case *fullNode: value, newnode, didResolve, err = t.get(n.Children[key[pos]], key, pos+1) if err == nil && didResolve { - n = n.copy() n.Children[key[pos]] = newnode } return value, n, didResolve, err @@ -257,7 +255,6 @@ func (t *Trie) getNode(origNode node, path []byte, pos int) (item []byte, newnod } item, newnode, resolved, err = t.getNode(n.Val, path, pos+len(n.Key)) if err == nil && resolved > 0 { - n = n.copy() n.Val = newnode } return item, n, resolved, err @@ -265,7 +262,6 @@ func (t *Trie) getNode(origNode node, path []byte, pos int) (item []byte, newnod case *fullNode: item, newnode, resolved, err = t.getNode(n.Children[path[pos]], path, pos+1) if err == nil && resolved > 0 { - n = n.copy() n.Children[path[pos]] = newnode } return item, n, resolved, err @@ -375,7 +371,6 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error if !dirty || err != nil { return false, n, err } - n = n.copy() n.flags = t.newFlag() n.Children[key[0]] = nn return true, n, nil @@ -483,7 +478,6 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { if !dirty || err != nil { return false, n, err } - n = n.copy() n.flags = t.newFlag() n.Children[key[0]] = nn @@ -576,6 +570,36 @@ func concat(s1 []byte, s2 ...byte) []byte { return r } +// copyNode deep-copies the supplied node along with its children recursively. +func copyNode(n node) node { + switch n := (n).(type) { + case nil: + return nil + case valueNode: + return valueNode(common.CopyBytes(n)) + + case *shortNode: + return &shortNode{ + flags: n.flags.copy(), + Key: common.CopyBytes(n.Key), + Val: copyNode(n.Val), + } + case *fullNode: + var children [17]node + for i, cn := range n.Children { + children[i] = copyNode(cn) + } + return &fullNode{ + flags: n.flags.copy(), + Children: children, + } + case hashNode: + return n + default: + panic(fmt.Sprintf("%T: unknown node type", n)) + } +} + func (t *Trie) resolve(n node, prefix []byte) (node, error) { if n, ok := n.(hashNode); ok { return t.resolveAndTrack(n, prefix) @@ -593,15 +617,16 @@ func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) { return nil, err } t.tracer.onRead(prefix, blob) - return mustDecodeNode(n, blob), nil + + // The returned node blob won't be changed afterward. No need to + // deep-copy the slice. + return decodeNodeUnsafe(n, blob) } // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. func (t *Trie) Hash() common.Hash { - hash, cached := t.hashRoot() - t.root = cached - return common.BytesToHash(hash.(hashNode)) + return common.BytesToHash(t.hashRoot().(hashNode)) } // Commit collects all dirty nodes in the trie and replaces them with the @@ -652,9 +677,9 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) { } // hashRoot calculates the root hash of the given trie -func (t *Trie) hashRoot() (node, node) { +func (t *Trie) hashRoot() node { if t.root == nil { - return hashNode(types.EmptyRootHash.Bytes()), nil + return hashNode(types.EmptyRootHash.Bytes()) } // If the number of changes is below 100, we let one thread handle it h := newHasher(t.unhashed >= 100) @@ -662,8 +687,7 @@ func (t *Trie) hashRoot() (node, node) { returnHasherToPool(h) t.unhashed = 0 }() - hashed, cached := h.hash(t.root, true) - return hashed, cached + return h.hash(t.root, true) } // Witness returns a set containing all trie nodes that have been accessed. diff --git a/trie/trie_test.go b/trie/trie_test.go index 77234d9d9b..54d1b083d8 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -1330,3 +1330,171 @@ func printSet(set *trienode.NodeSet) string { } return out.String() } + +func TestTrieCopy(t *testing.T) { + testTrieCopy(t, []kv{ + {k: []byte("do"), v: []byte("verb")}, + {k: []byte("ether"), v: []byte("wookiedoo")}, + {k: []byte("horse"), v: []byte("stallion")}, + {k: []byte("shaman"), v: []byte("horse")}, + {k: []byte("doge"), v: []byte("coin")}, + {k: []byte("dog"), v: []byte("puppy")}, + }) + + var entries []kv + for i := 0; i < 256; i++ { + entries = append(entries, kv{k: testrand.Bytes(32), v: testrand.Bytes(32)}) + } + testTrieCopy(t, entries) +} + +func testTrieCopy(t *testing.T, entries []kv) { + tr := NewEmpty(nil) + for _, entry := range entries { + tr.Update(entry.k, entry.v) + } + trCpy := tr.Copy() + + if tr.Hash() != trCpy.Hash() { + t.Errorf("Hash mismatch: old %v, copy %v", tr.Hash(), trCpy.Hash()) + } + + // Check iterator + it, _ := tr.NodeIterator(nil) + itCpy, _ := trCpy.NodeIterator(nil) + + for it.Next(false) { + hasNext := itCpy.Next(false) + if !hasNext { + t.Fatal("Iterator is not matched") + } + if !bytes.Equal(it.Path(), itCpy.Path()) { + t.Fatal("Iterator is not matched") + } + if it.Leaf() != itCpy.Leaf() { + t.Fatal("Iterator is not matched") + } + if it.Leaf() && !bytes.Equal(it.LeafBlob(), itCpy.LeafBlob()) { + t.Fatal("Iterator is not matched") + } + } + + // Check commit + root, nodes := tr.Commit(false) + rootCpy, nodesCpy := trCpy.Commit(false) + if root != rootCpy { + t.Fatal("root mismatch") + } + if len(nodes.Nodes) != len(nodesCpy.Nodes) { + t.Fatal("commit node mismatch") + } + for p, n := range nodes.Nodes { + nn, exists := nodesCpy.Nodes[p] + if !exists { + t.Fatalf("node not exists: %v", p) + } + if !reflect.DeepEqual(n, nn) { + t.Fatalf("node mismatch: %v", p) + } + } +} + +func TestTrieCopyOldTrie(t *testing.T) { + testTrieCopyOldTrie(t, []kv{ + {k: []byte("do"), v: []byte("verb")}, + {k: []byte("ether"), v: []byte("wookiedoo")}, + {k: []byte("horse"), v: []byte("stallion")}, + {k: []byte("shaman"), v: []byte("horse")}, + {k: []byte("doge"), v: []byte("coin")}, + {k: []byte("dog"), v: []byte("puppy")}, + }) + + var entries []kv + for i := 0; i < 256; i++ { + entries = append(entries, kv{k: testrand.Bytes(32), v: testrand.Bytes(32)}) + } + testTrieCopyOldTrie(t, entries) +} + +func testTrieCopyOldTrie(t *testing.T, entries []kv) { + tr := NewEmpty(nil) + for _, entry := range entries { + tr.Update(entry.k, entry.v) + } + hash := tr.Hash() + + trCpy := tr.Copy() + for _, val := range entries { + if rand.Intn(2) == 0 { + trCpy.Delete(val.k) + } else { + trCpy.Update(val.k, testrand.Bytes(32)) + } + } + for i := 0; i < 10; i++ { + trCpy.Update(testrand.Bytes(32), testrand.Bytes(32)) + } + trCpy.Hash() + trCpy.Commit(false) + + // Traverse the original tree, the changes made on the copy one shouldn't + // affect the old one + for _, entry := range entries { + d, _ := tr.Get(entry.k) + if !bytes.Equal(d, entry.v) { + t.Errorf("Unexpected data, key: %v, want: %v, got: %v", entry.k, entry.v, d) + } + } + if tr.Hash() != hash { + t.Errorf("Hash mismatch: old %v, new %v", hash, tr.Hash()) + } +} + +func TestTrieCopyNewTrie(t *testing.T) { + testTrieCopyNewTrie(t, []kv{ + {k: []byte("do"), v: []byte("verb")}, + {k: []byte("ether"), v: []byte("wookiedoo")}, + {k: []byte("horse"), v: []byte("stallion")}, + {k: []byte("shaman"), v: []byte("horse")}, + {k: []byte("doge"), v: []byte("coin")}, + {k: []byte("dog"), v: []byte("puppy")}, + }) + + var entries []kv + for i := 0; i < 256; i++ { + entries = append(entries, kv{k: testrand.Bytes(32), v: testrand.Bytes(32)}) + } + testTrieCopyNewTrie(t, entries) +} + +func testTrieCopyNewTrie(t *testing.T, entries []kv) { + tr := NewEmpty(nil) + for _, entry := range entries { + tr.Update(entry.k, entry.v) + } + trCpy := tr.Copy() + hash := trCpy.Hash() + + for _, val := range entries { + if rand.Intn(2) == 0 { + tr.Delete(val.k) + } else { + tr.Update(val.k, testrand.Bytes(32)) + } + } + for i := 0; i < 10; i++ { + tr.Update(testrand.Bytes(32), testrand.Bytes(32)) + } + + // Traverse the original tree, the changes made on the copy one shouldn't + // affect the old one + for _, entry := range entries { + d, _ := trCpy.Get(entry.k) + if !bytes.Equal(d, entry.v) { + t.Errorf("Unexpected data, key: %v, want: %v, got: %v", entry.k, entry.v, d) + } + } + if trCpy.Hash() != hash { + t.Errorf("Hash mismatch: old %v, new %v", hash, tr.Hash()) + } +} diff --git a/triedb/database/database.go b/triedb/database/database.go index cd7ec1d931..8c61ea0293 100644 --- a/triedb/database/database.go +++ b/triedb/database/database.go @@ -27,6 +27,8 @@ type NodeReader interface { // node path and the corresponding node hash. No error will be returned // if the node is not found. // + // The returned node content won't be changed after the call. + // // Don't modify the returned byte slice since it's not deep-copied and // still be referenced by database. Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)