trie: ensure resolved nodes stay loaded
Commit 40cdcf1183
broke the optimisation which kept nodes resolved
during Get in the trie. The decoder assigned cache generation 0
unconditionally, causing resolved nodes to get flushed on Commit.
This commit fixes it and adds two tests.
This commit is contained in:
parent
187d6a66a5
commit
177cab5fe7
|
@ -58,7 +58,7 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
|
||||||
return hash, n, nil
|
return hash, n, nil
|
||||||
}
|
}
|
||||||
if n.canUnload(h.cachegen, h.cachelimit) {
|
if n.canUnload(h.cachegen, h.cachelimit) {
|
||||||
// Evict the node from cache. All of its subnodes will have a lower or equal
|
// Unload the node from cache. All of its subnodes will have a lower or equal
|
||||||
// cache generation number.
|
// cache generation number.
|
||||||
return hash, hash, nil
|
return hash, hash, nil
|
||||||
}
|
}
|
||||||
|
|
26
trie/node.go
26
trie/node.go
|
@ -104,8 +104,8 @@ func (n valueNode) fstring(ind string) string {
|
||||||
return fmt.Sprintf("%x ", []byte(n))
|
return fmt.Sprintf("%x ", []byte(n))
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustDecodeNode(hash, buf []byte) node {
|
func mustDecodeNode(hash, buf []byte, cachegen uint16) node {
|
||||||
n, err := decodeNode(hash, buf)
|
n, err := decodeNode(hash, buf, cachegen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("node %x: %v", hash, err))
|
panic(fmt.Sprintf("node %x: %v", hash, err))
|
||||||
}
|
}
|
||||||
|
@ -113,7 +113,7 @@ func mustDecodeNode(hash, buf []byte) node {
|
||||||
}
|
}
|
||||||
|
|
||||||
// decodeNode parses the RLP encoding of a trie node.
|
// decodeNode parses the RLP encoding of a trie node.
|
||||||
func decodeNode(hash, buf []byte) (node, error) {
|
func decodeNode(hash, buf []byte, cachegen uint16) (node, error) {
|
||||||
if len(buf) == 0 {
|
if len(buf) == 0 {
|
||||||
return nil, io.ErrUnexpectedEOF
|
return nil, io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
@ -123,22 +123,22 @@ func decodeNode(hash, buf []byte) (node, error) {
|
||||||
}
|
}
|
||||||
switch c, _ := rlp.CountValues(elems); c {
|
switch c, _ := rlp.CountValues(elems); c {
|
||||||
case 2:
|
case 2:
|
||||||
n, err := decodeShort(hash, buf, elems)
|
n, err := decodeShort(hash, buf, elems, cachegen)
|
||||||
return n, wrapError(err, "short")
|
return n, wrapError(err, "short")
|
||||||
case 17:
|
case 17:
|
||||||
n, err := decodeFull(hash, buf, elems)
|
n, err := decodeFull(hash, buf, elems, cachegen)
|
||||||
return n, wrapError(err, "full")
|
return n, wrapError(err, "full")
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid number of list elements: %v", c)
|
return nil, fmt.Errorf("invalid number of list elements: %v", c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeShort(hash, buf, elems []byte) (node, error) {
|
func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) {
|
||||||
kbuf, rest, err := rlp.SplitString(elems)
|
kbuf, rest, err := rlp.SplitString(elems)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
flag := nodeFlag{hash: hash}
|
flag := nodeFlag{hash: hash, gen: cachegen}
|
||||||
key := compactDecode(kbuf)
|
key := compactDecode(kbuf)
|
||||||
if key[len(key)-1] == 16 {
|
if key[len(key)-1] == 16 {
|
||||||
// value node
|
// value node
|
||||||
|
@ -148,17 +148,17 @@ func decodeShort(hash, buf, elems []byte) (node, error) {
|
||||||
}
|
}
|
||||||
return &shortNode{key, append(valueNode{}, val...), flag}, nil
|
return &shortNode{key, append(valueNode{}, val...), flag}, nil
|
||||||
}
|
}
|
||||||
r, _, err := decodeRef(rest)
|
r, _, err := decodeRef(rest, cachegen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, wrapError(err, "val")
|
return nil, wrapError(err, "val")
|
||||||
}
|
}
|
||||||
return &shortNode{key, r, flag}, nil
|
return &shortNode{key, r, flag}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeFull(hash, buf, elems []byte) (*fullNode, error) {
|
func decodeFull(hash, buf, elems []byte, cachegen uint16) (*fullNode, error) {
|
||||||
n := &fullNode{flags: nodeFlag{hash: hash}}
|
n := &fullNode{flags: nodeFlag{hash: hash, gen: cachegen}}
|
||||||
for i := 0; i < 16; i++ {
|
for i := 0; i < 16; i++ {
|
||||||
cld, rest, err := decodeRef(elems)
|
cld, rest, err := decodeRef(elems, cachegen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, wrapError(err, fmt.Sprintf("[%d]", i))
|
return n, wrapError(err, fmt.Sprintf("[%d]", i))
|
||||||
}
|
}
|
||||||
|
@ -176,7 +176,7 @@ func decodeFull(hash, buf, elems []byte) (*fullNode, error) {
|
||||||
|
|
||||||
const hashLen = len(common.Hash{})
|
const hashLen = len(common.Hash{})
|
||||||
|
|
||||||
func decodeRef(buf []byte) (node, []byte, error) {
|
func decodeRef(buf []byte, cachegen uint16) (node, []byte, error) {
|
||||||
kind, val, rest, err := rlp.Split(buf)
|
kind, val, rest, err := rlp.Split(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, buf, err
|
return nil, buf, err
|
||||||
|
@ -189,7 +189,7 @@ func decodeRef(buf []byte) (node, []byte, error) {
|
||||||
err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
|
err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
|
||||||
return nil, buf, err
|
return nil, buf, err
|
||||||
}
|
}
|
||||||
n, err := decodeNode(nil, buf)
|
n, err := decodeNode(nil, buf, cachegen)
|
||||||
return n, rest, err
|
return n, rest, err
|
||||||
case kind == rlp.String && len(val) == 0:
|
case kind == rlp.String && len(val) == 0:
|
||||||
// empty node
|
// empty node
|
||||||
|
|
|
@ -101,7 +101,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value
|
||||||
if !bytes.Equal(sha.Sum(nil), wantHash) {
|
if !bytes.Equal(sha.Sum(nil), wantHash) {
|
||||||
return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
|
return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
|
||||||
}
|
}
|
||||||
n, err := decodeNode(wantHash, buf)
|
n, err := decodeNode(wantHash, buf, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("bad proof node %d: %v", i, err)
|
return nil, fmt.Errorf("bad proof node %d: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,7 +82,7 @@ func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, c
|
||||||
}
|
}
|
||||||
key := root.Bytes()
|
key := root.Bytes()
|
||||||
blob, _ := s.database.Get(key)
|
blob, _ := s.database.Get(key)
|
||||||
if local, err := decodeNode(key, blob); local != nil && err == nil {
|
if local, err := decodeNode(key, blob, 0); local != nil && err == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Assemble the new sub-trie sync request
|
// Assemble the new sub-trie sync request
|
||||||
|
@ -158,7 +158,7 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Decode the node data content and update the request
|
// Decode the node data content and update the request
|
||||||
node, err := decodeNode(item.Hash[:], item.Data)
|
node, err := decodeNode(item.Hash[:], item.Data, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
@ -246,7 +246,7 @@ func (s *TrieSync) children(req *request) ([]*request, error) {
|
||||||
if node, ok := (*child.node).(hashNode); ok {
|
if node, ok := (*child.node).(hashNode); ok {
|
||||||
// Try to resolve the node from the local database
|
// Try to resolve the node from the local database
|
||||||
blob, _ := s.database.Get(node)
|
blob, _ := s.database.Get(node)
|
||||||
if local, err := decodeNode(node[:], blob); local != nil && err == nil {
|
if local, err := decodeNode(node[:], blob, 0); local != nil && err == nil {
|
||||||
*child.node = local
|
*child.node = local
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
11
trie/trie.go
11
trie/trie.go
|
@ -144,14 +144,15 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
|
||||||
if err == nil && didResolve {
|
if err == nil && didResolve {
|
||||||
n = n.copy()
|
n = n.copy()
|
||||||
n.Val = newnode
|
n.Val = newnode
|
||||||
|
n.flags.gen = t.cachegen
|
||||||
}
|
}
|
||||||
return value, n, didResolve, err
|
return value, n, didResolve, err
|
||||||
case *fullNode:
|
case *fullNode:
|
||||||
value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
|
value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
|
||||||
if err == nil && didResolve {
|
if err == nil && didResolve {
|
||||||
n = n.copy()
|
n = n.copy()
|
||||||
|
n.flags.gen = t.cachegen
|
||||||
n.Children[key[pos]] = newnode
|
n.Children[key[pos]] = newnode
|
||||||
|
|
||||||
}
|
}
|
||||||
return value, n, didResolve, err
|
return value, n, didResolve, err
|
||||||
case hashNode:
|
case hashNode:
|
||||||
|
@ -247,7 +248,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
|
||||||
return false, n, err
|
return false, n, err
|
||||||
}
|
}
|
||||||
n = n.copy()
|
n = n.copy()
|
||||||
n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
|
n.flags = t.newFlag()
|
||||||
|
n.Children[key[0]] = nn
|
||||||
return true, n, nil
|
return true, n, nil
|
||||||
|
|
||||||
case nil:
|
case nil:
|
||||||
|
@ -331,7 +333,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
|
||||||
return false, n, err
|
return false, n, err
|
||||||
}
|
}
|
||||||
n = n.copy()
|
n = n.copy()
|
||||||
n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
|
n.flags = t.newFlag()
|
||||||
|
n.Children[key[0]] = nn
|
||||||
|
|
||||||
// Check how many non-nil entries are left after deleting and
|
// Check how many non-nil entries are left after deleting and
|
||||||
// reduce the full node to a short node if only one entry is
|
// reduce the full node to a short node if only one entry is
|
||||||
|
@ -427,7 +430,7 @@ func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) {
|
||||||
SuffixLen: len(suffix),
|
SuffixLen: len(suffix),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dec := mustDecodeNode(n, enc)
|
dec := mustDecodeNode(n, enc, t.cachegen)
|
||||||
return dec, nil
|
return dec, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -300,25 +300,6 @@ func TestReplication(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not an actual test
|
|
||||||
func TestOutput(t *testing.T) {
|
|
||||||
t.Skip()
|
|
||||||
|
|
||||||
base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
|
||||||
trie := newEmpty()
|
|
||||||
for i := 0; i < 50; i++ {
|
|
||||||
updateString(trie, fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
|
|
||||||
}
|
|
||||||
fmt.Println("############################## FULL ################################")
|
|
||||||
fmt.Println(trie.root)
|
|
||||||
|
|
||||||
trie.Commit()
|
|
||||||
fmt.Println("############################## SMALL ################################")
|
|
||||||
trie2, _ := New(trie.Hash(), trie.db)
|
|
||||||
getString(trie2, base+"20")
|
|
||||||
fmt.Println(trie2.root)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLargeValue(t *testing.T) {
|
func TestLargeValue(t *testing.T) {
|
||||||
trie := newEmpty()
|
trie := newEmpty()
|
||||||
trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
|
trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
|
||||||
|
@ -326,14 +307,56 @@ func TestLargeValue(t *testing.T) {
|
||||||
trie.Hash()
|
trie.Hash()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type countingDB struct {
|
||||||
|
Database
|
||||||
|
gets map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *countingDB) Get(key []byte) ([]byte, error) {
|
||||||
|
db.gets[string(key)]++
|
||||||
|
return db.Database.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCacheUnload checks that decoded nodes are unloaded after a
|
||||||
|
// certain number of commit operations.
|
||||||
|
func TestCacheUnload(t *testing.T) {
|
||||||
|
// Create test trie with two branches.
|
||||||
|
trie := newEmpty()
|
||||||
|
key1 := "---------------------------------"
|
||||||
|
key2 := "---some other branch"
|
||||||
|
updateString(trie, key1, "this is the branch of key1.")
|
||||||
|
updateString(trie, key2, "this is the branch of key2.")
|
||||||
|
root, _ := trie.Commit()
|
||||||
|
|
||||||
|
// Commit the trie repeatedly and access key1.
|
||||||
|
// The branch containing it is loaded from DB exactly two times:
|
||||||
|
// in the 0th and 6th iteration.
|
||||||
|
db := &countingDB{Database: trie.db, gets: make(map[string]int)}
|
||||||
|
trie, _ = New(root, db)
|
||||||
|
trie.SetCacheLimit(5)
|
||||||
|
for i := 0; i < 12; i++ {
|
||||||
|
getString(trie, key1)
|
||||||
|
trie.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that it got loaded two times.
|
||||||
|
for dbkey, count := range db.gets {
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("db key %x loaded %d times, want %d times", []byte(dbkey), count, 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// randTest performs random trie operations.
|
||||||
|
// Instances of this test are created by Generate.
|
||||||
|
type randTest []randTestStep
|
||||||
|
|
||||||
type randTestStep struct {
|
type randTestStep struct {
|
||||||
op int
|
op int
|
||||||
key []byte // for opUpdate, opDelete, opGet
|
key []byte // for opUpdate, opDelete, opGet
|
||||||
value []byte // for opUpdate
|
value []byte // for opUpdate
|
||||||
}
|
}
|
||||||
|
|
||||||
type randTest []randTestStep
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
opUpdate = iota
|
opUpdate = iota
|
||||||
opDelete
|
opDelete
|
||||||
|
@ -342,6 +365,7 @@ const (
|
||||||
opHash
|
opHash
|
||||||
opReset
|
opReset
|
||||||
opItercheckhash
|
opItercheckhash
|
||||||
|
opCheckCacheInvariant
|
||||||
opMax // boundary value, not an actual op
|
opMax // boundary value, not an actual op
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -437,11 +461,36 @@ func runRandTest(rt randTest) bool {
|
||||||
fmt.Println("hashes not equal")
|
fmt.Println("hashes not equal")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
case opCheckCacheInvariant:
|
||||||
|
return checkCacheInvariant(tr.root, tr.cachegen, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkCacheInvariant(n node, parentCachegen uint16, depth int) bool {
|
||||||
|
switch n := n.(type) {
|
||||||
|
case *shortNode:
|
||||||
|
if n.flags.gen > parentCachegen {
|
||||||
|
fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return checkCacheInvariant(n.Val, n.flags.gen, depth+1)
|
||||||
|
case *fullNode:
|
||||||
|
if n.flags.gen > parentCachegen {
|
||||||
|
fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, child := range n.Children {
|
||||||
|
if !checkCacheInvariant(child, n.flags.gen, depth+1) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func TestRandom(t *testing.T) {
|
func TestRandom(t *testing.T) {
|
||||||
if err := quick.Check(runRandTest, nil); err != nil {
|
if err := quick.Check(runRandTest, nil); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
Loading…
Reference in New Issue