diff --git a/trie/stacktrie.go b/trie/stacktrie.go index c34ac6c78c..f9ff10b62d 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -17,8 +17,12 @@ package trie import ( + "bufio" + "bytes" + "encoding/gob" "errors" "fmt" + "io" "sync" "github.com/ethereum/go-ethereum/common" @@ -66,6 +70,96 @@ func NewStackTrie(db ethdb.KeyValueWriter) *StackTrie { } } +// NewFromBinary initialises a serialized stacktrie with the given db. +func NewFromBinary(data []byte, db ethdb.KeyValueWriter) (*StackTrie, error) { + var st StackTrie + if err := st.UnmarshalBinary(data); err != nil { + return nil, err + } + // If a database is used, we need to recursively add it to every child + if db != nil { + st.setDb(db) + } + return &st, nil +} + +// MarshalBinary implements encoding.BinaryMarshaler +func (st *StackTrie) MarshalBinary() (data []byte, err error) { + var ( + b bytes.Buffer + w = bufio.NewWriter(&b) + ) + if err := gob.NewEncoder(w).Encode(struct { + Nodetype uint8 + Val []byte + Key []byte + KeyOffset uint8 + }{ + st.nodeType, + st.val, + st.key, + uint8(st.keyOffset), + }); err != nil { + return nil, err + } + for _, child := range st.children { + if child == nil { + w.WriteByte(0) + continue + } + w.WriteByte(1) + if childData, err := child.MarshalBinary(); err != nil { + return nil, err + } else { + w.Write(childData) + } + } + w.Flush() + return b.Bytes(), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (st *StackTrie) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + return st.unmarshalBinary(r) +} + +func (st *StackTrie) unmarshalBinary(r io.Reader) error { + var dec struct { + Nodetype uint8 + Val []byte + Key []byte + KeyOffset uint8 + } + gob.NewDecoder(r).Decode(&dec) + st.nodeType = dec.Nodetype + st.val = dec.Val + st.key = dec.Key + st.keyOffset = int(dec.KeyOffset) + + var hasChild = make([]byte, 1) + for i := range st.children { + if _, err := r.Read(hasChild); err != nil { + return err + } else if hasChild[0] == 0 { + continue + } + var child StackTrie + child.unmarshalBinary(r) + st.children[i] = &child + } + return nil +} + +func (st *StackTrie) setDb(db ethdb.KeyValueWriter) { + st.db = db + for _, child := range st.children { + if child != nil { + child.setDb(db) + } + } +} + func newLeaf(ko int, key, val []byte, db ethdb.KeyValueWriter) *StackTrie { st := stackTrieFromPool(db) st.nodeType = leafNode diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go index 50a4eda323..ccdf389d52 100644 --- a/trie/stacktrie_test.go +++ b/trie/stacktrie_test.go @@ -151,7 +151,6 @@ func TestStacktrieNotModifyValues(t *testing.T) { return big.NewInt(int64(i)).Bytes() } } - for i := 0; i < 1000; i++ { key := common.BigToHash(keyB) value := getValue(i) @@ -168,5 +167,51 @@ func TestStacktrieNotModifyValues(t *testing.T) { if !bytes.Equal(have, want) { t.Fatalf("item %d, have %#x want %#x", i, have, want) } + + } +} + +// TestStacktrieSerialization tests that the stacktrie works well if we +// serialize/unserialize it a lot +func TestStacktrieSerialization(t *testing.T) { + var ( + st = NewStackTrie(nil) + nt, _ = New(common.Hash{}, NewDatabase(memorydb.New())) + keyB = big.NewInt(1) + keyDelta = big.NewInt(1) + vals [][]byte + keys [][]byte + ) + getValue := func(i int) []byte { + if i%2 == 0 { // large + return crypto.Keccak256(big.NewInt(int64(i)).Bytes()) + } else { //small + return big.NewInt(int64(i)).Bytes() + } + } + for i := 0; i < 10; i++ { + vals = append(vals, getValue(i)) + keys = append(keys, common.BigToHash(keyB).Bytes()) + keyB = keyB.Add(keyB, keyDelta) + keyDelta.Add(keyDelta, common.Big1) + } + for i, k := range keys { + nt.TryUpdate(k, common.CopyBytes(vals[i])) + } + + for i, k := range keys { + blob, err := st.MarshalBinary() + if err != nil { + t.Fatal(err) + } + newSt, err := NewFromBinary(blob, nil) + if err != nil { + t.Fatal(err) + } + st = newSt + st.TryUpdate(k, common.CopyBytes(vals[i])) + } + if have, want := st.Hash(), nt.Hash(); have != want { + t.Fatalf("have %#x want %#x", have, want) } }