diff --git a/core/types/derive_sha.go b/core/types/derive_sha.go index 0e286bd8b2..b2c4422103 100644 --- a/core/types/derive_sha.go +++ b/core/types/derive_sha.go @@ -3,7 +3,7 @@ package types import ( "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethutil" - "github.com/ethereum/go-ethereum/ptrie" + "github.com/ethereum/go-ethereum/trie" ) type DerivableList interface { @@ -13,7 +13,7 @@ type DerivableList interface { func DeriveSha(list DerivableList) []byte { db, _ := ethdb.NewMemDatabase() - trie := ptrie.New(nil, db) + trie := trie.New(nil, db) for i := 0; i < list.Len(); i++ { trie.Update(ethutil.Encode(i), list.GetRlp(i)) } diff --git a/ptrie/iterator.go b/ptrie/iterator.go deleted file mode 100644 index 787ba09c02..0000000000 --- a/ptrie/iterator.go +++ /dev/null @@ -1,115 +0,0 @@ -package ptrie - -import ( - "bytes" - - "github.com/ethereum/go-ethereum/trie" -) - -type Iterator struct { - trie *Trie - - Key []byte - Value []byte -} - -func NewIterator(trie *Trie) *Iterator { - return &Iterator{trie: trie, Key: make([]byte, 32)} -} - -func (self *Iterator) Next() bool { - self.trie.mu.Lock() - defer self.trie.mu.Unlock() - - key := trie.RemTerm(trie.CompactHexDecode(string(self.Key))) - k := self.next(self.trie.root, key) - - self.Key = []byte(trie.DecodeCompact(k)) - - return len(k) > 0 - -} - -func (self *Iterator) next(node Node, key []byte) []byte { - if node == nil { - return nil - } - - switch node := node.(type) { - case *FullNode: - if len(key) > 0 { - k := self.next(node.branch(key[0]), key[1:]) - if k != nil { - return append([]byte{key[0]}, k...) - } - } - - var r byte - if len(key) > 0 { - r = key[0] + 1 - } - - for i := r; i < 16; i++ { - k := self.key(node.branch(byte(i))) - if k != nil { - return append([]byte{i}, k...) - } - } - - case *ShortNode: - k := trie.RemTerm(node.Key()) - if vnode, ok := node.Value().(*ValueNode); ok { - if bytes.Compare([]byte(k), key) > 0 { - self.Value = vnode.Val() - return k - } - } else { - cnode := node.Value() - - var ret []byte - skey := key[len(k):] - if trie.BeginsWith(key, k) { - ret = self.next(cnode, skey) - } else if bytes.Compare(k, key[:len(k)]) > 0 { - ret = self.key(node) - } - - if ret != nil { - return append(k, ret...) - } - } - } - - return nil -} - -func (self *Iterator) key(node Node) []byte { - switch node := node.(type) { - case *ShortNode: - // Leaf node - if vnode, ok := node.Value().(*ValueNode); ok { - k := trie.RemTerm(node.Key()) - self.Value = vnode.Val() - - return k - } else { - k := trie.RemTerm(node.Key()) - return append(k, self.key(node.Value())...) - } - case *FullNode: - if node.Value() != nil { - self.Value = node.Value().(*ValueNode).Val() - - return []byte{16} - } - - for i := 0; i < 16; i++ { - k := self.key(node.branch(byte(i))) - if k != nil { - return append([]byte{byte(i)}, k...) - } - } - } - - return nil -} diff --git a/ptrie/trie.go b/ptrie/trie.go deleted file mode 100644 index 5c83b57d05..0000000000 --- a/ptrie/trie.go +++ /dev/null @@ -1,335 +0,0 @@ -package ptrie - -import ( - "bytes" - "container/list" - "fmt" - "sync" - - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethutil" - "github.com/ethereum/go-ethereum/trie" -) - -func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) { - t2 := New(nil, backend) - - it := t1.Iterator() - for it.Next() { - t2.Update(it.Key, it.Value) - } - - return bytes.Equal(t2.Hash(), t1.Hash()), t2 -} - -type Trie struct { - mu sync.Mutex - root Node - roothash []byte - cache *Cache - - revisions *list.List -} - -func New(root []byte, backend Backend) *Trie { - trie := &Trie{} - trie.revisions = list.New() - trie.roothash = root - trie.cache = NewCache(backend) - - if root != nil { - value := ethutil.NewValueFromBytes(trie.cache.Get(root)) - trie.root = trie.mknode(value) - } - - return trie -} - -func (self *Trie) Iterator() *Iterator { - return NewIterator(self) -} - -func (self *Trie) Copy() *Trie { - return New(self.roothash, self.cache.backend) -} - -// Legacy support -func (self *Trie) Root() []byte { return self.Hash() } -func (self *Trie) Hash() []byte { - var hash []byte - if self.root != nil { - t := self.root.Hash() - if byts, ok := t.([]byte); ok && len(byts) > 0 { - hash = byts - } else { - hash = crypto.Sha3(ethutil.Encode(self.root.RlpData())) - } - } else { - hash = crypto.Sha3(ethutil.Encode("")) - } - - if !bytes.Equal(hash, self.roothash) { - self.revisions.PushBack(self.roothash) - self.roothash = hash - } - - return hash -} -func (self *Trie) Commit() { - self.mu.Lock() - defer self.mu.Unlock() - - // Hash first - self.Hash() - - self.cache.Flush() -} - -// Reset should only be called if the trie has been hashed -func (self *Trie) Reset() { - self.mu.Lock() - defer self.mu.Unlock() - - self.cache.Reset() - - if self.revisions.Len() > 0 { - revision := self.revisions.Remove(self.revisions.Back()).([]byte) - self.roothash = revision - } - value := ethutil.NewValueFromBytes(self.cache.Get(self.roothash)) - self.root = self.mknode(value) -} - -func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } -func (self *Trie) Update(key, value []byte) Node { - self.mu.Lock() - defer self.mu.Unlock() - - k := trie.CompactHexDecode(string(key)) - - if len(value) != 0 { - self.root = self.insert(self.root, k, &ValueNode{self, value}) - } else { - self.root = self.delete(self.root, k) - } - - return self.root -} - -func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) } -func (self *Trie) Get(key []byte) []byte { - self.mu.Lock() - defer self.mu.Unlock() - - k := trie.CompactHexDecode(string(key)) - - n := self.get(self.root, k) - if n != nil { - return n.(*ValueNode).Val() - } - - return nil -} - -func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) } -func (self *Trie) Delete(key []byte) Node { - self.mu.Lock() - defer self.mu.Unlock() - - k := trie.CompactHexDecode(string(key)) - self.root = self.delete(self.root, k) - - return self.root -} - -func (self *Trie) insert(node Node, key []byte, value Node) Node { - if len(key) == 0 { - return value - } - - if node == nil { - return NewShortNode(self, key, value) - } - - switch node := node.(type) { - case *ShortNode: - k := node.Key() - cnode := node.Value() - if bytes.Equal(k, key) { - return NewShortNode(self, key, value) - } - - var n Node - matchlength := trie.MatchingNibbleLength(key, k) - if matchlength == len(k) { - n = self.insert(cnode, key[matchlength:], value) - } else { - pnode := self.insert(nil, k[matchlength+1:], cnode) - nnode := self.insert(nil, key[matchlength+1:], value) - fulln := NewFullNode(self) - fulln.set(k[matchlength], pnode) - fulln.set(key[matchlength], nnode) - n = fulln - } - if matchlength == 0 { - return n - } - - return NewShortNode(self, key[:matchlength], n) - - case *FullNode: - cpy := node.Copy().(*FullNode) - cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value)) - - return cpy - - default: - panic(fmt.Sprintf("%T: invalid node: %v", node, node)) - } -} - -func (self *Trie) get(node Node, key []byte) Node { - if len(key) == 0 { - return node - } - - if node == nil { - return nil - } - - switch node := node.(type) { - case *ShortNode: - k := node.Key() - cnode := node.Value() - - if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) { - return self.get(cnode, key[len(k):]) - } - - return nil - case *FullNode: - return self.get(node.branch(key[0]), key[1:]) - default: - panic(fmt.Sprintf("%T: invalid node: %v", node, node)) - } -} - -func (self *Trie) delete(node Node, key []byte) Node { - if len(key) == 0 && node == nil { - return nil - } - - switch node := node.(type) { - case *ShortNode: - k := node.Key() - cnode := node.Value() - if bytes.Equal(key, k) { - return nil - } else if bytes.Equal(key[:len(k)], k) { - child := self.delete(cnode, key[len(k):]) - - var n Node - switch child := child.(type) { - case *ShortNode: - nkey := append(k, child.Key()...) - n = NewShortNode(self, nkey, child.Value()) - case *FullNode: - sn := NewShortNode(self, node.Key(), child) - sn.key = node.key - n = sn - } - - return n - } else { - return node - } - - case *FullNode: - n := node.Copy().(*FullNode) - n.set(key[0], self.delete(n.branch(key[0]), key[1:])) - - pos := -1 - for i := 0; i < 17; i++ { - if n.branch(byte(i)) != nil { - if pos == -1 { - pos = i - } else { - pos = -2 - } - } - } - - var nnode Node - if pos == 16 { - nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos))) - } else if pos >= 0 { - cnode := n.branch(byte(pos)) - switch cnode := cnode.(type) { - case *ShortNode: - // Stitch keys - k := append([]byte{byte(pos)}, cnode.Key()...) - nnode = NewShortNode(self, k, cnode.Value()) - case *FullNode: - nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos))) - } - } else { - nnode = n - } - - return nnode - case nil: - return nil - default: - panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key)) - } -} - -// casting functions and cache storing -func (self *Trie) mknode(value *ethutil.Value) Node { - l := value.Len() - switch l { - case 0: - return nil - case 2: - // A value node may consists of 2 bytes. - if value.Get(0).Len() != 0 { - return NewShortNode(self, trie.CompactDecode(string(value.Get(0).Bytes())), self.mknode(value.Get(1))) - } - case 17: - fnode := NewFullNode(self) - for i := 0; i < l; i++ { - fnode.set(byte(i), self.mknode(value.Get(i))) - } - return fnode - case 32: - return &HashNode{value.Bytes()} - } - - return &ValueNode{self, value.Bytes()} -} - -func (self *Trie) trans(node Node) Node { - switch node := node.(type) { - case *HashNode: - value := ethutil.NewValueFromBytes(self.cache.Get(node.key)) - return self.mknode(value) - default: - return node - } -} - -func (self *Trie) store(node Node) interface{} { - data := ethutil.Encode(node) - if len(data) >= 32 { - key := crypto.Sha3(data) - self.cache.Put(key, data) - - return key - } - - return node.RlpData() -} - -func (self *Trie) PrintRoot() { - fmt.Println(self.root) -} diff --git a/ptrie/trie_test.go b/ptrie/trie_test.go deleted file mode 100644 index 63a8ed36e6..0000000000 --- a/ptrie/trie_test.go +++ /dev/null @@ -1,259 +0,0 @@ -package ptrie - -import ( - "bytes" - "fmt" - "testing" - - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethutil" -) - -type Db map[string][]byte - -func (self Db) Get(k []byte) ([]byte, error) { return self[string(k)], nil } -func (self Db) Put(k, v []byte) { self[string(k)] = v } - -// Used for testing -func NewEmpty() *Trie { - return New(nil, make(Db)) -} - -func TestEmptyTrie(t *testing.T) { - trie := NewEmpty() - res := trie.Hash() - exp := crypto.Sha3(ethutil.Encode("")) - if !bytes.Equal(res, exp) { - t.Errorf("expected %x got %x", exp, res) - } -} - -func TestInsert(t *testing.T) { - trie := NewEmpty() - - trie.UpdateString("doe", "reindeer") - trie.UpdateString("dog", "puppy") - trie.UpdateString("dogglesworth", "cat") - - exp := ethutil.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3") - root := trie.Hash() - if !bytes.Equal(root, exp) { - t.Errorf("exp %x got %x", exp, root) - } - - trie = NewEmpty() - trie.UpdateString("A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") - - exp = ethutil.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") - root = trie.Hash() - if !bytes.Equal(root, exp) { - t.Errorf("exp %x got %x", exp, root) - } -} - -func TestGet(t *testing.T) { - trie := NewEmpty() - - trie.UpdateString("doe", "reindeer") - trie.UpdateString("dog", "puppy") - trie.UpdateString("dogglesworth", "cat") - - res := trie.GetString("dog") - if !bytes.Equal(res, []byte("puppy")) { - t.Errorf("expected puppy got %x", res) - } - - unknown := trie.GetString("unknown") - if unknown != nil { - t.Errorf("expected nil got %x", unknown) - } -} - -func TestDelete(t *testing.T) { - trie := NewEmpty() - - vals := []struct{ k, v string }{ - {"do", "verb"}, - {"ether", "wookiedoo"}, - {"horse", "stallion"}, - {"shaman", "horse"}, - {"doge", "coin"}, - {"ether", ""}, - {"dog", "puppy"}, - {"shaman", ""}, - } - for _, val := range vals { - if val.v != "" { - trie.UpdateString(val.k, val.v) - } else { - trie.DeleteString(val.k) - } - } - - hash := trie.Hash() - exp := ethutil.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") - if !bytes.Equal(hash, exp) { - t.Errorf("expected %x got %x", exp, hash) - } -} - -func TestEmptyValues(t *testing.T) { - trie := NewEmpty() - - vals := []struct{ k, v string }{ - {"do", "verb"}, - {"ether", "wookiedoo"}, - {"horse", "stallion"}, - {"shaman", "horse"}, - {"doge", "coin"}, - {"ether", ""}, - {"dog", "puppy"}, - {"shaman", ""}, - } - for _, val := range vals { - trie.UpdateString(val.k, val.v) - } - - hash := trie.Hash() - exp := ethutil.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") - if !bytes.Equal(hash, exp) { - t.Errorf("expected %x got %x", exp, hash) - } -} - -func TestReplication(t *testing.T) { - trie := NewEmpty() - vals := []struct{ k, v string }{ - {"do", "verb"}, - {"ether", "wookiedoo"}, - {"horse", "stallion"}, - {"shaman", "horse"}, - {"doge", "coin"}, - {"ether", ""}, - {"dog", "puppy"}, - {"shaman", ""}, - {"somethingveryoddindeedthis is", "myothernodedata"}, - } - for _, val := range vals { - trie.UpdateString(val.k, val.v) - } - trie.Commit() - - trie2 := New(trie.roothash, trie.cache.backend) - if string(trie2.GetString("horse")) != "stallion" { - t.Error("expected to have horse => stallion") - } - - hash := trie2.Hash() - exp := trie.Hash() - if !bytes.Equal(hash, exp) { - t.Errorf("root failure. expected %x got %x", exp, hash) - } - -} - -func TestReset(t *testing.T) { - trie := NewEmpty() - vals := []struct{ k, v string }{ - {"do", "verb"}, - {"ether", "wookiedoo"}, - {"horse", "stallion"}, - } - for _, val := range vals { - trie.UpdateString(val.k, val.v) - } - trie.Commit() - - before := ethutil.CopyBytes(trie.roothash) - trie.UpdateString("should", "revert") - trie.Hash() - // Should have no effect - trie.Hash() - trie.Hash() - // ### - - trie.Reset() - after := ethutil.CopyBytes(trie.roothash) - - if !bytes.Equal(before, after) { - t.Errorf("expected roots to be equal. %x - %x", before, after) - } -} - -func TestParanoia(t *testing.T) { - t.Skip() - trie := NewEmpty() - - vals := []struct{ k, v string }{ - {"do", "verb"}, - {"ether", "wookiedoo"}, - {"horse", "stallion"}, - {"shaman", "horse"}, - {"doge", "coin"}, - {"ether", ""}, - {"dog", "puppy"}, - {"shaman", ""}, - {"somethingveryoddindeedthis is", "myothernodedata"}, - } - for _, val := range vals { - trie.UpdateString(val.k, val.v) - } - trie.Commit() - - ok, t2 := ParanoiaCheck(trie, trie.cache.backend) - if !ok { - t.Errorf("trie paranoia check failed %x %x", trie.roothash, t2.roothash) - } -} - -// Not an actual test -func TestOutput(t *testing.T) { - t.Skip() - - base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - trie := NewEmpty() - for i := 0; i < 50; i++ { - trie.UpdateString(fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee") - } - fmt.Println("############################## FULL ################################") - fmt.Println(trie.root) - - trie.Commit() - fmt.Println("############################## SMALL ################################") - trie2 := New(trie.roothash, trie.cache.backend) - trie2.GetString(base + "20") - fmt.Println(trie2.root) -} - -func BenchmarkGets(b *testing.B) { - trie := NewEmpty() - vals := []struct{ k, v string }{ - {"do", "verb"}, - {"ether", "wookiedoo"}, - {"horse", "stallion"}, - {"shaman", "horse"}, - {"doge", "coin"}, - {"ether", ""}, - {"dog", "puppy"}, - {"shaman", ""}, - {"somethingveryoddindeedthis is", "myothernodedata"}, - } - for _, val := range vals { - trie.UpdateString(val.k, val.v) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - trie.Get([]byte("horse")) - } -} - -func BenchmarkUpdate(b *testing.B) { - trie := NewEmpty() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - trie.UpdateString(fmt.Sprintf("aaaaaaaaa%d", i), "value") - } - trie.Hash() -} diff --git a/state/state_object.go b/state/state_object.go index c1c78bee02..913c57a316 100644 --- a/state/state_object.go +++ b/state/state_object.go @@ -6,7 +6,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethutil" - "github.com/ethereum/go-ethereum/ptrie" + "github.com/ethereum/go-ethereum/trie" ) type Code []byte @@ -152,7 +152,7 @@ func (self *StateObject) Sync() { } /* - valid, t2 := ptrie.ParanoiaCheck(self.State.trie, ethutil.Config.Db) + valid, t2 := trie.ParanoiaCheck(self.State.trie, ethutil.Config.Db) if !valid { statelogger.Infof("Warn: PARANOIA: Different state storage root during copy %x vs %x\n", self.State.Root(), t2.Root()) @@ -273,7 +273,7 @@ func (c *StateObject) Init() Code { return c.InitCode } -func (self *StateObject) Trie() *ptrie.Trie { +func (self *StateObject) Trie() *trie.Trie { return self.State.trie } diff --git a/state/statedb.go b/state/statedb.go index de73147905..3176ab7555 100644 --- a/state/statedb.go +++ b/state/statedb.go @@ -6,7 +6,7 @@ import ( "github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/logger" - "github.com/ethereum/go-ethereum/ptrie" + "github.com/ethereum/go-ethereum/trie" ) var statelogger = logger.NewLogger("STATE") @@ -18,7 +18,7 @@ var statelogger = logger.NewLogger("STATE") // * Accounts type StateDB struct { db ethutil.Database - trie *ptrie.Trie + trie *trie.Trie stateObjects map[string]*StateObject @@ -30,9 +30,8 @@ type StateDB struct { } // Create a new state from a given trie -//func New(trie *ptrie.Trie) *StateDB { func New(root []byte, db ethutil.Database) *StateDB { - trie := ptrie.New(root, db) + trie := trie.New(root, db) return &StateDB{db: db, trie: trie, stateObjects: make(map[string]*StateObject), manifest: NewManifest(), refund: make(map[string]*big.Int)} } @@ -308,7 +307,7 @@ func (self *StateDB) Update(gasUsed *big.Int) { // FIXME trie delete is broken if deleted { - valid, t2 := ptrie.ParanoiaCheck(self.trie, self.db) + valid, t2 := trie.ParanoiaCheck(self.trie, self.db) if !valid { statelogger.Infof("Warn: PARANOIA: Different state root during copy %x vs %x\n", self.trie.Root(), t2.Root()) diff --git a/tests/helper/trie.go b/tests/helper/trie.go index 3cfb0bbe5e..9e666d333a 100644 --- a/tests/helper/trie.go +++ b/tests/helper/trie.go @@ -1,6 +1,6 @@ package helper -import "github.com/ethereum/go-ethereum/ptrie" +import "github.com/ethereum/go-ethereum/trie" type MemDatabase struct { db map[string][]byte @@ -24,8 +24,8 @@ func (db *MemDatabase) Print() {} func (db *MemDatabase) Close() {} func (db *MemDatabase) LastKnownTD() []byte { return nil } -func NewTrie() *ptrie.Trie { +func NewTrie() *trie.Trie { db, _ := NewMemDatabase() - return ptrie.New(nil, db) + return trie.New(nil, db) } diff --git a/ptrie/cache.go b/trie/cache.go similarity index 98% rename from ptrie/cache.go rename to trie/cache.go index 721dc4cf64..e03702b255 100644 --- a/ptrie/cache.go +++ b/trie/cache.go @@ -1,4 +1,4 @@ -package ptrie +package trie type Backend interface { Get([]byte) ([]byte, error) diff --git a/ptrie/fullnode.go b/trie/fullnode.go similarity index 98% rename from ptrie/fullnode.go rename to trie/fullnode.go index 4dd98049d5..ebbe7f3844 100644 --- a/ptrie/fullnode.go +++ b/trie/fullnode.go @@ -1,4 +1,4 @@ -package ptrie +package trie import "fmt" diff --git a/ptrie/hashnode.go b/trie/hashnode.go similarity index 97% rename from ptrie/hashnode.go rename to trie/hashnode.go index 4c17569d78..40ccd54c31 100644 --- a/ptrie/hashnode.go +++ b/trie/hashnode.go @@ -1,4 +1,4 @@ -package ptrie +package trie type HashNode struct { key []byte diff --git a/trie/iterator.go b/trie/iterator.go index 1114715a66..f0dae28bb0 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -1,124 +1,73 @@ package trie -/* -import ( - "bytes" - - "github.com/ethereum/go-ethereum/ethutil" -) - -type NodeType byte - -const ( - EmptyNode NodeType = iota - BranchNode - LeafNode - ExtNode -) - -func getType(node *ethutil.Value) NodeType { - if node.Len() == 0 { - return EmptyNode - } - - if node.Len() == 2 { - k := CompactDecode(node.Get(0).Str()) - if HasTerm(k) { - return LeafNode - } - - return ExtNode - } - - return BranchNode -} +import "bytes" type Iterator struct { - Path [][]byte trie *Trie Key []byte - Value *ethutil.Value + Value []byte } func NewIterator(trie *Trie) *Iterator { - return &Iterator{trie: trie} + return &Iterator{trie: trie, Key: make([]byte, 32)} } -func (self *Iterator) key(node *ethutil.Value, path [][]byte) []byte { - switch getType(node) { - case LeafNode: - k := RemTerm(CompactDecode(node.Get(0).Str())) +func (self *Iterator) Next() bool { + self.trie.mu.Lock() + defer self.trie.mu.Unlock() - self.Path = append(path, k) - self.Value = node.Get(1) + key := RemTerm(CompactHexDecode(string(self.Key))) + k := self.next(self.trie.root, key) - return k - case BranchNode: - if node.Get(16).Len() > 0 { - return []byte{16} - } + self.Key = []byte(DecodeCompact(k)) - for i := byte(0); i < 16; i++ { - o := self.key(self.trie.getNode(node.Get(int(i)).Raw()), append(path, []byte{i})) - if o != nil { - return append([]byte{i}, o...) - } - } - case ExtNode: - currKey := node.Get(0).Bytes() + return len(k) > 0 - return self.key(self.trie.getNode(node.Get(1).Raw()), append(path, currKey)) +} + +func (self *Iterator) next(node Node, key []byte) []byte { + if node == nil { + return nil } - return nil -} - -func (self *Iterator) next(node *ethutil.Value, key []byte, path [][]byte) []byte { - switch typ := getType(node); typ { - case EmptyNode: - return nil - case BranchNode: + switch node := node.(type) { + case *FullNode: if len(key) > 0 { - subNode := self.trie.getNode(node.Get(int(key[0])).Raw()) - - o := self.next(subNode, key[1:], append(path, key[:1])) - if o != nil { - return append([]byte{key[0]}, o...) + k := self.next(node.branch(key[0]), key[1:]) + if k != nil { + return append([]byte{key[0]}, k...) } } - var r byte = 0 + var r byte if len(key) > 0 { r = key[0] + 1 } for i := r; i < 16; i++ { - subNode := self.trie.getNode(node.Get(int(i)).Raw()) - o := self.key(subNode, append(path, []byte{i})) - if o != nil { - return append([]byte{i}, o...) + k := self.key(node.branch(byte(i))) + if k != nil { + return append([]byte{i}, k...) } } - case LeafNode, ExtNode: - k := RemTerm(CompactDecode(node.Get(0).Str())) - if typ == LeafNode { - if bytes.Compare([]byte(k), []byte(key)) > 0 { - self.Value = node.Get(1) - self.Path = append(path, k) + case *ShortNode: + k := RemTerm(node.Key()) + if vnode, ok := node.Value().(*ValueNode); ok { + if bytes.Compare([]byte(k), key) > 0 { + self.Value = vnode.Val() return k } } else { - subNode := self.trie.getNode(node.Get(1).Raw()) - subKey := key[len(k):] + cnode := node.Value() + var ret []byte + skey := key[len(k):] if BeginsWith(key, k) { - ret = self.next(subNode, subKey, append(path, k)) + ret = self.next(cnode, skey) } else if bytes.Compare(k, key[:len(k)]) > 0 { - ret = self.key(node, append(path, k)) - } else { - ret = nil + ret = self.key(node) } if ret != nil { @@ -130,16 +79,33 @@ func (self *Iterator) next(node *ethutil.Value, key []byte, path [][]byte) []byt return nil } -// Get the next in keys -func (self *Iterator) Next(key string) []byte { - self.trie.mut.Lock() - defer self.trie.mut.Unlock() +func (self *Iterator) key(node Node) []byte { + switch node := node.(type) { + case *ShortNode: + // Leaf node + if vnode, ok := node.Value().(*ValueNode); ok { + k := RemTerm(node.Key()) + self.Value = vnode.Val() - k := RemTerm(CompactHexDecode(key)) - n := self.next(self.trie.getNode(self.trie.Root), k, nil) + return k + } else { + k := RemTerm(node.Key()) + return append(k, self.key(node.Value())...) + } + case *FullNode: + if node.Value() != nil { + self.Value = node.Value().(*ValueNode).Val() - self.Key = []byte(DecodeCompact(n)) + return []byte{16} + } - return self.Key + for i := 0; i < 16; i++ { + k := self.key(node.branch(byte(i))) + if k != nil { + return append([]byte{byte(i)}, k...) + } + } + } + + return nil } -*/ diff --git a/ptrie/iterator_test.go b/trie/iterator_test.go similarity index 97% rename from ptrie/iterator_test.go rename to trie/iterator_test.go index acfc03d633..74d9e903cd 100644 --- a/ptrie/iterator_test.go +++ b/trie/iterator_test.go @@ -1,4 +1,4 @@ -package ptrie +package trie import "testing" diff --git a/trie/main_test.go b/trie/main_test.go deleted file mode 100644 index f6f64c06f7..0000000000 --- a/trie/main_test.go +++ /dev/null @@ -1,9 +0,0 @@ -package trie - -import ( - "testing" - - checker "gopkg.in/check.v1" -) - -func Test(t *testing.T) { checker.TestingT(t) } diff --git a/ptrie/node.go b/trie/node.go similarity index 98% rename from ptrie/node.go rename to trie/node.go index ab90a1a021..a1f68480f2 100644 --- a/ptrie/node.go +++ b/trie/node.go @@ -1,4 +1,4 @@ -package ptrie +package trie import "fmt" diff --git a/ptrie/shortnode.go b/trie/shortnode.go similarity index 78% rename from ptrie/shortnode.go rename to trie/shortnode.go index 73ff2914bf..f132b56d96 100644 --- a/ptrie/shortnode.go +++ b/trie/shortnode.go @@ -1,6 +1,4 @@ -package ptrie - -import "github.com/ethereum/go-ethereum/trie" +package trie type ShortNode struct { trie *Trie @@ -9,7 +7,7 @@ type ShortNode struct { } func NewShortNode(t *Trie, key []byte, value Node) *ShortNode { - return &ShortNode{t, []byte(trie.CompactEncode(key)), value} + return &ShortNode{t, []byte(CompactEncode(key)), value} } func (self *ShortNode) Value() Node { self.value = self.trie.trans(self.value) @@ -27,5 +25,5 @@ func (self *ShortNode) Hash() interface{} { } func (self *ShortNode) Key() []byte { - return trie.CompactDecode(string(self.key)) + return CompactDecode(string(self.key)) } diff --git a/trie/trie.go b/trie/trie.go index c9fd18e009..36f2af5d22 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -1,8 +1,8 @@ package trie -/* import ( "bytes" + "container/list" "fmt" "sync" @@ -10,618 +10,325 @@ import ( "github.com/ethereum/go-ethereum/ethutil" ) -func ParanoiaCheck(t1 *Trie) (bool, *Trie) { - t2 := New(ethutil.Config.Db, "") +func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) { + t2 := New(nil, backend) - t1.NewIterator().Each(func(key string, v *ethutil.Value) { - t2.Update(key, v.Str()) - }) - - return bytes.Compare(t2.GetRoot(), t1.GetRoot()) == 0, t2 -} - -func (s *Cache) Len() int { - return len(s.nodes) -} - -// TODO -// A StateObject is an object that has a state root -// This is goig to be the object for the second level caching (the caching of object which have a state such as contracts) -type StateObject interface { - State() *Trie - Sync() - Undo() -} - -type Node struct { - Key []byte - Value *ethutil.Value - Dirty bool -} - -func NewNode(key []byte, val *ethutil.Value, dirty bool) *Node { - return &Node{Key: key, Value: val, Dirty: dirty} -} - -func (n *Node) Copy() *Node { - return NewNode(n.Key, n.Value, n.Dirty) -} - -type Cache struct { - nodes map[string]*Node - db ethutil.Database - IsDirty bool -} - -func NewCache(db ethutil.Database) *Cache { - return &Cache{db: db, nodes: make(map[string]*Node)} -} - -func (cache *Cache) PutValue(v interface{}, force bool) interface{} { - value := ethutil.NewValue(v) - - enc := value.Encode() - if len(enc) >= 32 || force { - sha := crypto.Sha3(enc) - - cache.nodes[string(sha)] = NewNode(sha, value, true) - cache.IsDirty = true - - return sha + it := t1.Iterator() + for it.Next() { + t2.Update(it.Key, it.Value) } - return v + return bytes.Equal(t2.Hash(), t1.Hash()), t2 } -func (cache *Cache) Put(v interface{}) interface{} { - return cache.PutValue(v, false) -} - -func (cache *Cache) Get(key []byte) *ethutil.Value { - // First check if the key is the cache - if cache.nodes[string(key)] != nil { - return cache.nodes[string(key)].Value - } - - // Get the key of the database instead and cache it - data, _ := cache.db.Get(key) - // Create the cached value - value := ethutil.NewValueFromBytes(data) - - defer func() { - if r := recover(); r != nil { - fmt.Println("RECOVER GET", cache, cache.nodes) - panic("bye") - } - }() - // Create caching node - cache.nodes[string(key)] = NewNode(key, value, true) - - return value -} - -func (cache *Cache) Delete(key []byte) { - delete(cache.nodes, string(key)) - - cache.db.Delete(key) -} - -func (cache *Cache) Commit() { - // Don't try to commit if it isn't dirty - if !cache.IsDirty { - return - } - - for key, node := range cache.nodes { - if node.Dirty { - cache.db.Put([]byte(key), node.Value.Encode()) - node.Dirty = false - } - } - cache.IsDirty = false - - // If the nodes grows beyond the 200 entries we simple empty it - // FIXME come up with something better - if len(cache.nodes) > 200 { - cache.nodes = make(map[string]*Node) - } -} - -func (cache *Cache) Undo() { - for key, node := range cache.nodes { - if node.Dirty { - delete(cache.nodes, key) - } - } - cache.IsDirty = false -} - -// A (modified) Radix Trie implementation. The Trie implements -// a caching mechanism and will used cached values if they are -// present. If a node is not present in the cache it will try to -// fetch it from the database and store the cached value. -// Please note that the data isn't persisted unless `Sync` is -// explicitly called. type Trie struct { - mut sync.RWMutex - prevRoot interface{} - Root interface{} - //db Database - cache *Cache + mu sync.Mutex + root Node + roothash []byte + cache *Cache + + revisions *list.List } -func copyRoot(root interface{}) interface{} { - var prevRootCopy interface{} - if b, ok := root.([]byte); ok { - prevRootCopy = ethutil.CopyBytes(b) - } else { - prevRootCopy = root - } +func New(root []byte, backend Backend) *Trie { + trie := &Trie{} + trie.revisions = list.New() + trie.roothash = root + trie.cache = NewCache(backend) - return prevRootCopy -} - -func New(db ethutil.Database, Root interface{}) *Trie { - // Make absolute sure the root is copied - r := copyRoot(Root) - p := copyRoot(Root) - - trie := &Trie{cache: NewCache(db), Root: r, prevRoot: p} - trie.setRoot(Root) - - return trie -} - -func (self *Trie) setRoot(root interface{}) { - switch t := root.(type) { - case string: - //if t == "" { - // root = crypto.Sha3(ethutil.Encode("")) - //} - self.Root = []byte(t) - case []byte: - self.Root = root - default: - self.Root = self.cache.PutValue(root, true) - } -} - -func (t *Trie) Update(key, value string) { - t.mut.Lock() - defer t.mut.Unlock() - - k := CompactHexDecode(key) - - var root interface{} - if value != "" { - root = t.UpdateState(t.Root, k, value) - } else { - root = t.deleteState(t.Root, k) - } - t.setRoot(root) -} - -func (t *Trie) Get(key string) string { - t.mut.Lock() - defer t.mut.Unlock() - - k := CompactHexDecode(key) - c := ethutil.NewValue(t.getState(t.Root, k)) - - return c.Str() -} - -func (t *Trie) Delete(key string) { - t.mut.Lock() - defer t.mut.Unlock() - - k := CompactHexDecode(key) - - root := t.deleteState(t.Root, k) - t.setRoot(root) -} - -func (self *Trie) GetRoot() []byte { - switch t := self.Root.(type) { - case string: - if t == "" { - return crypto.Sha3(ethutil.Encode("")) - } - return []byte(t) - case []byte: - if len(t) == 0 { - return crypto.Sha3(ethutil.Encode("")) - } - - return t - default: - panic(fmt.Sprintf("invalid root type %T (%v)", self.Root, self.Root)) - } -} - -// Simple compare function which creates a rlp value out of the evaluated objects -func (t *Trie) Cmp(trie *Trie) bool { - return ethutil.NewValue(t.Root).Cmp(ethutil.NewValue(trie.Root)) -} - -// Returns a copy of this trie -func (t *Trie) Copy() *Trie { - trie := New(t.cache.db, t.Root) - for key, node := range t.cache.nodes { - trie.cache.nodes[key] = node.Copy() + if root != nil { + value := ethutil.NewValueFromBytes(trie.cache.Get(root)) + trie.root = trie.mknode(value) } return trie } -// Save the cached value to the database. -func (t *Trie) Sync() { - t.cache.Commit() - t.prevRoot = copyRoot(t.Root) -} - -func (t *Trie) Undo() { - t.cache.Undo() - t.Root = t.prevRoot -} - -func (t *Trie) Cache() *Cache { - return t.cache -} - -func (t *Trie) getState(node interface{}, key []byte) interface{} { - n := ethutil.NewValue(node) - // Return the node if key is empty (= found) - if len(key) == 0 || n.IsNil() || n.Len() == 0 { - return node - } - - currentNode := t.getNode(node) - length := currentNode.Len() - - if length == 0 { - return "" - } else if length == 2 { - // Decode the key - k := CompactDecode(currentNode.Get(0).Str()) - v := currentNode.Get(1).Raw() - - if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) { //CompareIntSlice(k, key[:len(k)]) { - return t.getState(v, key[len(k):]) - } else { - return "" - } - } else if length == 17 { - return t.getState(currentNode.Get(int(key[0])).Raw(), key[1:]) - } - - // It shouldn't come this far - panic("unexpected return") -} - -func (t *Trie) getNode(node interface{}) *ethutil.Value { - n := ethutil.NewValue(node) - - if !n.Get(0).IsNil() { - return n - } - - str := n.Str() - if len(str) == 0 { - return n - } else if len(str) < 32 { - return ethutil.NewValueFromBytes([]byte(str)) - } - - data := t.cache.Get(n.Bytes()) - - return data -} - -func (t *Trie) UpdateState(node interface{}, key []byte, value string) interface{} { - return t.InsertState(node, key, value) -} - -func (t *Trie) Put(node interface{}) interface{} { - return t.cache.Put(node) - -} - -func EmptyStringSlice(l int) []interface{} { - slice := make([]interface{}, l) - for i := 0; i < l; i++ { - slice[i] = "" - } - return slice -} - -func (t *Trie) InsertState(node interface{}, key []byte, value interface{}) interface{} { - if len(key) == 0 { - return value - } - - // New node - n := ethutil.NewValue(node) - if node == nil || n.Len() == 0 { - newNode := []interface{}{CompactEncode(key), value} - - return t.Put(newNode) - } - - currentNode := t.getNode(node) - // Check for "special" 2 slice type node - if currentNode.Len() == 2 { - // Decode the key - - k := CompactDecode(currentNode.Get(0).Str()) - v := currentNode.Get(1).Raw() - - // Matching key pair (ie. there's already an object with this key) - if bytes.Equal(k, key) { //CompareIntSlice(k, key) { - newNode := []interface{}{CompactEncode(key), value} - return t.Put(newNode) - } - - var newHash interface{} - matchingLength := MatchingNibbleLength(key, k) - if matchingLength == len(k) { - // Insert the hash, creating a new node - newHash = t.InsertState(v, key[matchingLength:], value) - } else { - // Expand the 2 length slice to a 17 length slice - oldNode := t.InsertState("", k[matchingLength+1:], v) - newNode := t.InsertState("", key[matchingLength+1:], value) - // Create an expanded slice - scaledSlice := EmptyStringSlice(17) - // Set the copied and new node - scaledSlice[k[matchingLength]] = oldNode - scaledSlice[key[matchingLength]] = newNode - - newHash = t.Put(scaledSlice) - } - - if matchingLength == 0 { - // End of the chain, return - return newHash - } else { - newNode := []interface{}{CompactEncode(key[:matchingLength]), newHash} - return t.Put(newNode) - } - } else { - - // Copy the current node over to the new node and replace the first nibble in the key - newNode := EmptyStringSlice(17) - - for i := 0; i < 17; i++ { - cpy := currentNode.Get(i).Raw() - if cpy != nil { - newNode[i] = cpy - } - } - - newNode[key[0]] = t.InsertState(currentNode.Get(int(key[0])).Raw(), key[1:], value) - - return t.Put(newNode) - } - - panic("unexpected end") -} - -func (t *Trie) deleteState(node interface{}, key []byte) interface{} { - if len(key) == 0 { - return "" - } - - // New node - n := ethutil.NewValue(node) - //if node == nil || (n.Type() == reflect.String && (n.Str() == "" || n.Get(0).IsNil())) || n.Len() == 0 { - if node == nil || n.Len() == 0 { - //return nil - //fmt.Printf(" %x %d\n", n, len(n.Bytes())) - - return "" - } - - currentNode := t.getNode(node) - // Check for "special" 2 slice type node - if currentNode.Len() == 2 { - // Decode the key - k := CompactDecode(currentNode.Get(0).Str()) - v := currentNode.Get(1).Raw() - - // Matching key pair (ie. there's already an object with this key) - if bytes.Equal(k, key) { //CompareIntSlice(k, key) { - //fmt.Printf(" %x\n", v) - - return "" - } else if bytes.Equal(key[:len(k)], k) { //CompareIntSlice(key[:len(k)], k) { - hash := t.deleteState(v, key[len(k):]) - child := t.getNode(hash) - - var newNode []interface{} - if child.Len() == 2 { - newKey := append(k, CompactDecode(child.Get(0).Str())...) - newNode = []interface{}{CompactEncode(newKey), child.Get(1).Raw()} - } else { - newNode = []interface{}{currentNode.Get(0).Str(), hash} - } - - //fmt.Printf("%x\n", newNode) - - return t.Put(newNode) - } else { - return node - } - } else { - // Copy the current node over to the new node and replace the first nibble in the key - n := EmptyStringSlice(17) - var newNode []interface{} - - for i := 0; i < 17; i++ { - cpy := currentNode.Get(i).Raw() - if cpy != nil { - n[i] = cpy - } - } - - n[key[0]] = t.deleteState(n[key[0]], key[1:]) - amount := -1 - for i := 0; i < 17; i++ { - if n[i] != "" { - if amount == -1 { - amount = i - } else { - amount = -2 - } - } - } - if amount == 16 { - newNode = []interface{}{CompactEncode([]byte{16}), n[amount]} - } else if amount >= 0 { - child := t.getNode(n[amount]) - if child.Len() == 17 { - newNode = []interface{}{CompactEncode([]byte{byte(amount)}), n[amount]} - } else if child.Len() == 2 { - key := append([]byte{byte(amount)}, CompactDecode(child.Get(0).Str())...) - newNode = []interface{}{CompactEncode(key), child.Get(1).Str()} - } - - } else { - newNode = n - } - - //fmt.Printf("%x\n", newNode) - return t.Put(newNode) - } - - panic("unexpected return") -} - -type TrieIterator struct { - trie *Trie - key string - value string - - shas [][]byte - values []string - - lastNode []byte -} - -func (t *Trie) NewIterator() *TrieIterator { - return &TrieIterator{trie: t} -} - func (self *Trie) Iterator() *Iterator { return NewIterator(self) } -// Some time in the near future this will need refactoring :-) -// XXX Note to self, IsSlice == inline node. Str == sha3 to node -func (it *TrieIterator) workNode(currentNode *ethutil.Value) { - if currentNode.Len() == 2 { - k := CompactDecode(currentNode.Get(0).Str()) +func (self *Trie) Copy() *Trie { + return New(self.roothash, self.cache.backend) +} - if currentNode.Get(1).Str() == "" { - it.workNode(currentNode.Get(1)) +// Legacy support +func (self *Trie) Root() []byte { return self.Hash() } +func (self *Trie) Hash() []byte { + var hash []byte + if self.root != nil { + t := self.root.Hash() + if byts, ok := t.([]byte); ok && len(byts) > 0 { + hash = byts } else { - if k[len(k)-1] == 16 { - it.values = append(it.values, currentNode.Get(1).Str()) - } else { - it.shas = append(it.shas, currentNode.Get(1).Bytes()) - it.getNode(currentNode.Get(1).Bytes()) - } + hash = crypto.Sha3(ethutil.Encode(self.root.RlpData())) } } else { - for i := 0; i < currentNode.Len(); i++ { - if i == 16 && currentNode.Get(i).Len() != 0 { - it.values = append(it.values, currentNode.Get(i).Str()) - } else { - if currentNode.Get(i).Str() == "" { - it.workNode(currentNode.Get(i)) - } else { - val := currentNode.Get(i).Str() - if val != "" { - it.shas = append(it.shas, currentNode.Get(1).Bytes()) - it.getNode([]byte(val)) - } - } - } + hash = crypto.Sha3(ethutil.Encode("")) + } + + if !bytes.Equal(hash, self.roothash) { + self.revisions.PushBack(self.roothash) + self.roothash = hash + } + + return hash +} +func (self *Trie) Commit() { + self.mu.Lock() + defer self.mu.Unlock() + + // Hash first + self.Hash() + + self.cache.Flush() +} + +// Reset should only be called if the trie has been hashed +func (self *Trie) Reset() { + self.mu.Lock() + defer self.mu.Unlock() + + self.cache.Reset() + + if self.revisions.Len() > 0 { + revision := self.revisions.Remove(self.revisions.Back()).([]byte) + self.roothash = revision + } + value := ethutil.NewValueFromBytes(self.cache.Get(self.roothash)) + self.root = self.mknode(value) +} + +func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } +func (self *Trie) Update(key, value []byte) Node { + self.mu.Lock() + defer self.mu.Unlock() + + k := CompactHexDecode(string(key)) + + if len(value) != 0 { + self.root = self.insert(self.root, k, &ValueNode{self, value}) + } else { + self.root = self.delete(self.root, k) + } + + return self.root +} + +func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) } +func (self *Trie) Get(key []byte) []byte { + self.mu.Lock() + defer self.mu.Unlock() + + k := CompactHexDecode(string(key)) + + n := self.get(self.root, k) + if n != nil { + return n.(*ValueNode).Val() + } + + return nil +} + +func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) } +func (self *Trie) Delete(key []byte) Node { + self.mu.Lock() + defer self.mu.Unlock() + + k := CompactHexDecode(string(key)) + self.root = self.delete(self.root, k) + + return self.root +} + +func (self *Trie) insert(node Node, key []byte, value Node) Node { + if len(key) == 0 { + return value + } + + if node == nil { + return NewShortNode(self, key, value) + } + + switch node := node.(type) { + case *ShortNode: + k := node.Key() + cnode := node.Value() + if bytes.Equal(k, key) { + return NewShortNode(self, key, value) } + + var n Node + matchlength := MatchingNibbleLength(key, k) + if matchlength == len(k) { + n = self.insert(cnode, key[matchlength:], value) + } else { + pnode := self.insert(nil, k[matchlength+1:], cnode) + nnode := self.insert(nil, key[matchlength+1:], value) + fulln := NewFullNode(self) + fulln.set(k[matchlength], pnode) + fulln.set(key[matchlength], nnode) + n = fulln + } + if matchlength == 0 { + return n + } + + return NewShortNode(self, key[:matchlength], n) + + case *FullNode: + cpy := node.Copy().(*FullNode) + cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value)) + + return cpy + + default: + panic(fmt.Sprintf("%T: invalid node: %v", node, node)) } } -func (it *TrieIterator) getNode(node []byte) { - currentNode := it.trie.cache.Get(node) - it.workNode(currentNode) -} +func (self *Trie) get(node Node, key []byte) Node { + if len(key) == 0 { + return node + } -func (it *TrieIterator) Collect() [][]byte { - if it.trie.Root == "" { + if node == nil { return nil } - it.getNode(ethutil.NewValue(it.trie.Root).Bytes()) + switch node := node.(type) { + case *ShortNode: + k := node.Key() + cnode := node.Value() - return it.shas -} - -func (it *TrieIterator) Purge() int { - shas := it.Collect() - for _, sha := range shas { - it.trie.cache.Delete(sha) - } - return len(it.values) -} - -func (it *TrieIterator) Key() string { - return "" -} - -func (it *TrieIterator) Value() string { - return "" -} - -type EachCallback func(key string, node *ethutil.Value) - -func (it *TrieIterator) Each(cb EachCallback) { - it.fetchNode(nil, ethutil.NewValue(it.trie.Root).Bytes(), cb) -} - -func (it *TrieIterator) fetchNode(key []byte, node []byte, cb EachCallback) { - it.iterateNode(key, it.trie.cache.Get(node), cb) -} - -func (it *TrieIterator) iterateNode(key []byte, currentNode *ethutil.Value, cb EachCallback) { - if currentNode.Len() == 2 { - k := CompactDecode(currentNode.Get(0).Str()) - - pk := append(key, k...) - if currentNode.Get(1).Len() != 0 && currentNode.Get(1).Str() == "" { - it.iterateNode(pk, currentNode.Get(1), cb) - } else { - if k[len(k)-1] == 16 { - cb(DecodeCompact(pk), currentNode.Get(1)) - } else { - it.fetchNode(pk, currentNode.Get(1).Bytes(), cb) - } + if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) { + return self.get(cnode, key[len(k):]) } - } else { - for i := 0; i < currentNode.Len(); i++ { - pk := append(key, byte(i)) - if i == 16 && currentNode.Get(i).Len() != 0 { - cb(DecodeCompact(pk), currentNode.Get(i)) - } else { - if currentNode.Get(i).Len() != 0 && currentNode.Get(i).Str() == "" { - it.iterateNode(pk, currentNode.Get(i), cb) + + return nil + case *FullNode: + return self.get(node.branch(key[0]), key[1:]) + default: + panic(fmt.Sprintf("%T: invalid node: %v", node, node)) + } +} + +func (self *Trie) delete(node Node, key []byte) Node { + if len(key) == 0 && node == nil { + return nil + } + + switch node := node.(type) { + case *ShortNode: + k := node.Key() + cnode := node.Value() + if bytes.Equal(key, k) { + return nil + } else if bytes.Equal(key[:len(k)], k) { + child := self.delete(cnode, key[len(k):]) + + var n Node + switch child := child.(type) { + case *ShortNode: + nkey := append(k, child.Key()...) + n = NewShortNode(self, nkey, child.Value()) + case *FullNode: + sn := NewShortNode(self, node.Key(), child) + sn.key = node.key + n = sn + } + + return n + } else { + return node + } + + case *FullNode: + n := node.Copy().(*FullNode) + n.set(key[0], self.delete(n.branch(key[0]), key[1:])) + + pos := -1 + for i := 0; i < 17; i++ { + if n.branch(byte(i)) != nil { + if pos == -1 { + pos = i } else { - val := currentNode.Get(i).Str() - if val != "" { - it.fetchNode(pk, []byte(val), cb) - } + pos = -2 } } } + + var nnode Node + if pos == 16 { + nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos))) + } else if pos >= 0 { + cnode := n.branch(byte(pos)) + switch cnode := cnode.(type) { + case *ShortNode: + // Stitch keys + k := append([]byte{byte(pos)}, cnode.Key()...) + nnode = NewShortNode(self, k, cnode.Value()) + case *FullNode: + nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos))) + } + } else { + nnode = n + } + + return nnode + case nil: + return nil + default: + panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key)) } } -*/ + +// casting functions and cache storing +func (self *Trie) mknode(value *ethutil.Value) Node { + l := value.Len() + switch l { + case 0: + return nil + case 2: + // A value node may consists of 2 bytes. + if value.Get(0).Len() != 0 { + return NewShortNode(self, CompactDecode(string(value.Get(0).Bytes())), self.mknode(value.Get(1))) + } + case 17: + fnode := NewFullNode(self) + for i := 0; i < l; i++ { + fnode.set(byte(i), self.mknode(value.Get(i))) + } + return fnode + case 32: + return &HashNode{value.Bytes()} + } + + return &ValueNode{self, value.Bytes()} +} + +func (self *Trie) trans(node Node) Node { + switch node := node.(type) { + case *HashNode: + value := ethutil.NewValueFromBytes(self.cache.Get(node.key)) + return self.mknode(value) + default: + return node + } +} + +func (self *Trie) store(node Node) interface{} { + data := ethutil.Encode(node) + if len(data) >= 32 { + key := crypto.Sha3(data) + self.cache.Put(key, data) + + return key + } + + return node.RlpData() +} + +func (self *Trie) PrintRoot() { + fmt.Println(self.root) +} diff --git a/trie/trie_test.go b/trie/trie_test.go index 3abe56040f..ffb78d4f2b 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -1,345 +1,76 @@ package trie -/* import ( "bytes" - "encoding/hex" - "encoding/json" "fmt" - "io/ioutil" - "math/rand" - "net/http" "testing" - "time" - - checker "gopkg.in/check.v1" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethutil" ) -const LONG_WORD = "1234567890abcdefghijklmnopqrstuvwxxzABCEFGHIJKLMNOPQRSTUVWXYZ" +type Db map[string][]byte -type TrieSuite struct { - db *MemDatabase - trie *Trie +func (self Db) Get(k []byte) ([]byte, error) { return self[string(k)], nil } +func (self Db) Put(k, v []byte) { self[string(k)] = v } + +// Used for testing +func NewEmpty() *Trie { + return New(nil, make(Db)) } -type MemDatabase struct { - db map[string][]byte -} - -func NewMemDatabase() (*MemDatabase, error) { - db := &MemDatabase{db: make(map[string][]byte)} - return db, nil -} -func (db *MemDatabase) Put(key []byte, value []byte) { - db.db[string(key)] = value -} -func (db *MemDatabase) Get(key []byte) ([]byte, error) { - return db.db[string(key)], nil -} -func (db *MemDatabase) Delete(key []byte) error { - delete(db.db, string(key)) - return nil -} -func (db *MemDatabase) Print() {} -func (db *MemDatabase) Close() {} -func (db *MemDatabase) LastKnownTD() []byte { return nil } - -func NewTrie() (*MemDatabase, *Trie) { - db, _ := NewMemDatabase() - return db, New(db, "") -} - -func (s *TrieSuite) SetUpTest(c *checker.C) { - s.db, s.trie = NewTrie() -} - -func (s *TrieSuite) TestTrieSync(c *checker.C) { - s.trie.Update("dog", LONG_WORD) - c.Assert(s.db.db, checker.HasLen, 0, checker.Commentf("Expected no data in database")) - s.trie.Sync() - c.Assert(s.db.db, checker.HasLen, 3) -} - -func (s *TrieSuite) TestTrieDirtyTracking(c *checker.C) { - s.trie.Update("dog", LONG_WORD) - c.Assert(s.trie.cache.IsDirty, checker.Equals, true, checker.Commentf("Expected no data in database")) - - s.trie.Sync() - c.Assert(s.trie.cache.IsDirty, checker.Equals, false, checker.Commentf("Expected trie to be dirty")) - - s.trie.Update("test", LONG_WORD) - s.trie.cache.Undo() - c.Assert(s.trie.cache.IsDirty, checker.Equals, false) -} - -func (s *TrieSuite) TestTrieReset(c *checker.C) { - s.trie.Update("cat", LONG_WORD) - c.Assert(s.trie.cache.nodes, checker.HasLen, 1, checker.Commentf("Expected cached nodes")) - - s.trie.cache.Undo() - c.Assert(s.trie.cache.nodes, checker.HasLen, 0, checker.Commentf("Expected no nodes after undo")) -} - -func (s *TrieSuite) TestTrieGet(c *checker.C) { - s.trie.Update("cat", LONG_WORD) - x := s.trie.Get("cat") - c.Assert(x, checker.DeepEquals, LONG_WORD) -} - -func (s *TrieSuite) TestTrieUpdating(c *checker.C) { - s.trie.Update("cat", LONG_WORD) - s.trie.Update("cat", LONG_WORD+"1") - x := s.trie.Get("cat") - c.Assert(x, checker.DeepEquals, LONG_WORD+"1") -} - -func (s *TrieSuite) TestTrieCmp(c *checker.C) { - _, trie1 := NewTrie() - _, trie2 := NewTrie() - - trie1.Update("doge", LONG_WORD) - trie2.Update("doge", LONG_WORD) - c.Assert(trie1, checker.DeepEquals, trie2) - - trie1.Update("dog", LONG_WORD) - trie2.Update("cat", LONG_WORD) - c.Assert(trie1, checker.Not(checker.DeepEquals), trie2) -} - -func (s *TrieSuite) TestTrieDelete(c *checker.C) { - s.trie.Update("cat", LONG_WORD) - exp := s.trie.Root - s.trie.Update("dog", LONG_WORD) - s.trie.Delete("dog") - c.Assert(s.trie.Root, checker.DeepEquals, exp) - - s.trie.Update("dog", LONG_WORD) - exp = s.trie.Root - s.trie.Update("dude", LONG_WORD) - s.trie.Delete("dude") - c.Assert(s.trie.Root, checker.DeepEquals, exp) -} - -func (s *TrieSuite) TestTrieDeleteWithValue(c *checker.C) { - s.trie.Update("c", LONG_WORD) - exp := s.trie.Root - s.trie.Update("ca", LONG_WORD) - s.trie.Update("cat", LONG_WORD) - s.trie.Delete("ca") - s.trie.Delete("cat") - c.Assert(s.trie.Root, checker.DeepEquals, exp) -} - -func (s *TrieSuite) TestTriePurge(c *checker.C) { - s.trie.Update("c", LONG_WORD) - s.trie.Update("ca", LONG_WORD) - s.trie.Update("cat", LONG_WORD) - - lenBefore := len(s.trie.cache.nodes) - it := s.trie.NewIterator() - num := it.Purge() - c.Assert(num, checker.Equals, 3) - c.Assert(len(s.trie.cache.nodes), checker.Equals, lenBefore) -} - -func h(str string) string { - d, err := hex.DecodeString(str) - if err != nil { - panic(err) - } - - return string(d) -} - -func get(in string) (out string) { - if len(in) > 2 && in[:2] == "0x" { - out = h(in[2:]) - } else { - out = in - } - - return -} - -type TrieTest struct { - Name string - In map[string]string - Root string -} - -func CreateTest(name string, data []byte) (TrieTest, error) { - t := TrieTest{Name: name} - err := json.Unmarshal(data, &t) - if err != nil { - return TrieTest{}, fmt.Errorf("%v", err) - } - - return t, nil -} - -func CreateTests(uri string, cb func(TrieTest)) map[string]TrieTest { - resp, err := http.Get(uri) - if err != nil { - panic(err) - } - defer resp.Body.Close() - - data, err := ioutil.ReadAll(resp.Body) - - var objmap map[string]*json.RawMessage - err = json.Unmarshal(data, &objmap) - if err != nil { - panic(err) - } - - tests := make(map[string]TrieTest) - for name, testData := range objmap { - test, err := CreateTest(name, *testData) - if err != nil { - panic(err) - } - - if cb != nil { - cb(test) - } - tests[name] = test - } - - return tests -} - -func RandomData() [][]string { - data := [][]string{ - {"0x000000000000000000000000ec4f34c97e43fbb2816cfd95e388353c7181dab1", "0x4e616d6552656700000000000000000000000000000000000000000000000000"}, - {"0x0000000000000000000000000000000000000000000000000000000000000045", "0x22b224a1420a802ab51d326e29fa98e34c4f24ea"}, - {"0x0000000000000000000000000000000000000000000000000000000000000046", "0x67706c2076330000000000000000000000000000000000000000000000000000"}, - {"0x000000000000000000000000697c7b8c961b56f675d570498424ac8de1a918f6", "0x6f6f6f6820736f2067726561742c207265616c6c6c793f000000000000000000"}, - {"0x0000000000000000000000007ef9e639e2733cb34e4dfc576d4b23f72db776b2", "0x4655474156000000000000000000000000000000000000000000000000000000"}, - {"0x6f6f6f6820736f2067726561742c207265616c6c6c793f000000000000000000", "0x697c7b8c961b56f675d570498424ac8de1a918f6"}, - {"0x4655474156000000000000000000000000000000000000000000000000000000", "0x7ef9e639e2733cb34e4dfc576d4b23f72db776b2"}, - {"0x4e616d6552656700000000000000000000000000000000000000000000000000", "0xec4f34c97e43fbb2816cfd95e388353c7181dab1"}, - } - - var c [][]string - for len(data) != 0 { - e := rand.Intn(len(data)) - c = append(c, data[e]) - - copy(data[e:], data[e+1:]) - data[len(data)-1] = nil - data = data[:len(data)-1] - } - - return c -} - -const MaxTest = 1000 - -// This test insert data in random order and seeks to find indifferences between the different tries -func (s *TrieSuite) TestRegression(c *checker.C) { - rand.Seed(time.Now().Unix()) - - roots := make(map[string]int) - for i := 0; i < MaxTest; i++ { - _, trie := NewTrie() - data := RandomData() - - for _, test := range data { - trie.Update(test[0], test[1]) - } - trie.Delete("0x4e616d6552656700000000000000000000000000000000000000000000000000") - - roots[string(trie.Root.([]byte))] += 1 - } - - c.Assert(len(roots) <= 1, checker.Equals, true) - // if len(roots) > 1 { - // for root, num := range roots { - // t.Errorf("%x => %d\n", root, num) - // } - // } -} - -func (s *TrieSuite) TestDelete(c *checker.C) { - s.trie.Update("a", "jeffreytestlongstring") - s.trie.Update("aa", "otherstring") - s.trie.Update("aaa", "othermorestring") - s.trie.Update("aabbbbccc", "hithere") - s.trie.Update("abbcccdd", "hstanoehutnaheoustnh") - s.trie.Update("rnthaoeuabbcccdd", "hstanoehutnaheoustnh") - s.trie.Update("rneuabbcccdd", "hstanoehutnaheoustnh") - s.trie.Update("rneuabboeusntahoeucccdd", "hstanoehutnaheoustnh") - s.trie.Update("rnxabboeusntahoeucccdd", "hstanoehutnaheoustnh") - s.trie.Delete("aaboaestnuhbccc") - s.trie.Delete("a") - s.trie.Update("a", "nthaonethaosentuh") - s.trie.Update("c", "shtaosntehua") - s.trie.Delete("a") - s.trie.Update("aaaa", "testmegood") - - _, t2 := NewTrie() - s.trie.NewIterator().Each(func(key string, v *ethutil.Value) { - if key == "aaaa" { - t2.Update(key, v.Str()) - } else { - t2.Update(key, v.Str()) - } - }) - - a := ethutil.NewValue(s.trie.Root).Bytes() - b := ethutil.NewValue(t2.Root).Bytes() - - c.Assert(a, checker.DeepEquals, b) -} - -func (s *TrieSuite) TestTerminator(c *checker.C) { - key := CompactDecode("hello") - c.Assert(HasTerm(key), checker.Equals, true, checker.Commentf("Expected %v to have a terminator", key)) -} - -func (s *TrieSuite) TestIt(c *checker.C) { - s.trie.Update("cat", "cat") - s.trie.Update("doge", "doge") - s.trie.Update("wallace", "wallace") - it := s.trie.Iterator() - - inputs := []struct { - In, Out string - }{ - {"", "cat"}, - {"bobo", "cat"}, - {"c", "cat"}, - {"car", "cat"}, - {"catering", "doge"}, - {"w", "wallace"}, - {"wallace123", ""}, - } - - for _, test := range inputs { - res := string(it.Next(test.In)) - c.Assert(res, checker.Equals, test.Out) +func TestEmptyTrie(t *testing.T) { + trie := NewEmpty() + res := trie.Hash() + exp := crypto.Sha3(ethutil.Encode("")) + if !bytes.Equal(res, exp) { + t.Errorf("expected %x got %x", exp, res) } } -func (s *TrieSuite) TestBeginsWith(c *checker.C) { - a := CompactDecode("hello") - b := CompactDecode("hel") +func TestInsert(t *testing.T) { + trie := NewEmpty() - c.Assert(BeginsWith(a, b), checker.Equals, false) - c.Assert(BeginsWith(b, a), checker.Equals, true) + trie.UpdateString("doe", "reindeer") + trie.UpdateString("dog", "puppy") + trie.UpdateString("dogglesworth", "cat") + + exp := ethutil.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3") + root := trie.Hash() + if !bytes.Equal(root, exp) { + t.Errorf("exp %x got %x", exp, root) + } + + trie = NewEmpty() + trie.UpdateString("A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + + exp = ethutil.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") + root = trie.Hash() + if !bytes.Equal(root, exp) { + t.Errorf("exp %x got %x", exp, root) + } } -func (s *TrieSuite) TestItems(c *checker.C) { - s.trie.Update("A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") - exp := "d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab" +func TestGet(t *testing.T) { + trie := NewEmpty() - c.Assert(s.trie.GetRoot(), checker.DeepEquals, ethutil.Hex2Bytes(exp)) + trie.UpdateString("doe", "reindeer") + trie.UpdateString("dog", "puppy") + trie.UpdateString("dogglesworth", "cat") + + res := trie.GetString("dog") + if !bytes.Equal(res, []byte("puppy")) { + t.Errorf("expected puppy got %x", res) + } + + unknown := trie.GetString("unknown") + if unknown != nil { + t.Errorf("expected nil got %x", unknown) + } } -func TestOtherSomething(t *testing.T) { - _, trie := NewTrie() +func TestDelete(t *testing.T) { + trie := NewEmpty() vals := []struct{ k, v string }{ {"do", "verb"}, @@ -352,18 +83,46 @@ func TestOtherSomething(t *testing.T) { {"shaman", ""}, } for _, val := range vals { - trie.Update(val.k, val.v) + if val.v != "" { + trie.UpdateString(val.k, val.v) + } else { + trie.DeleteString(val.k) + } } + hash := trie.Hash() exp := ethutil.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") - hash := trie.Root.([]byte) if !bytes.Equal(hash, exp) { t.Errorf("expected %x got %x", exp, hash) } } -func BenchmarkGets(b *testing.B) { - _, trie := NewTrie() +func TestEmptyValues(t *testing.T) { + trie := NewEmpty() + + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, + } + for _, val := range vals { + trie.UpdateString(val.k, val.v) + } + + hash := trie.Hash() + exp := ethutil.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") + if !bytes.Equal(hash, exp) { + t.Errorf("expected %x got %x", exp, hash) + } +} + +func TestReplication(t *testing.T) { + trie := NewEmpty() vals := []struct{ k, v string }{ {"do", "verb"}, {"ether", "wookiedoo"}, @@ -376,21 +135,125 @@ func BenchmarkGets(b *testing.B) { {"somethingveryoddindeedthis is", "myothernodedata"}, } for _, val := range vals { - trie.Update(val.k, val.v) + trie.UpdateString(val.k, val.v) + } + trie.Commit() + + trie2 := New(trie.roothash, trie.cache.backend) + if string(trie2.GetString("horse")) != "stallion" { + t.Error("expected to have horse => stallion") + } + + hash := trie2.Hash() + exp := trie.Hash() + if !bytes.Equal(hash, exp) { + t.Errorf("root failure. expected %x got %x", exp, hash) + } + +} + +func TestReset(t *testing.T) { + trie := NewEmpty() + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + } + for _, val := range vals { + trie.UpdateString(val.k, val.v) + } + trie.Commit() + + before := ethutil.CopyBytes(trie.roothash) + trie.UpdateString("should", "revert") + trie.Hash() + // Should have no effect + trie.Hash() + trie.Hash() + // ### + + trie.Reset() + after := ethutil.CopyBytes(trie.roothash) + + if !bytes.Equal(before, after) { + t.Errorf("expected roots to be equal. %x - %x", before, after) + } +} + +func TestParanoia(t *testing.T) { + t.Skip() + trie := NewEmpty() + + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, + {"somethingveryoddindeedthis is", "myothernodedata"}, + } + for _, val := range vals { + trie.UpdateString(val.k, val.v) + } + trie.Commit() + + ok, t2 := ParanoiaCheck(trie, trie.cache.backend) + if !ok { + t.Errorf("trie paranoia check failed %x %x", trie.roothash, t2.roothash) + } +} + +// Not an actual test +func TestOutput(t *testing.T) { + t.Skip() + + base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + trie := NewEmpty() + for i := 0; i < 50; i++ { + trie.UpdateString(fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee") + } + fmt.Println("############################## FULL ################################") + fmt.Println(trie.root) + + trie.Commit() + fmt.Println("############################## SMALL ################################") + trie2 := New(trie.roothash, trie.cache.backend) + trie2.GetString(base + "20") + fmt.Println(trie2.root) +} + +func BenchmarkGets(b *testing.B) { + trie := NewEmpty() + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, + {"somethingveryoddindeedthis is", "myothernodedata"}, + } + for _, val := range vals { + trie.UpdateString(val.k, val.v) } b.ResetTimer() for i := 0; i < b.N; i++ { - trie.Get("horse") + trie.Get([]byte("horse")) } } func BenchmarkUpdate(b *testing.B) { - _, trie := NewTrie() + trie := NewEmpty() b.ResetTimer() for i := 0; i < b.N; i++ { - trie.Update(fmt.Sprintf("aaaaaaaaaaaaaaa%d", i), "value") + trie.UpdateString(fmt.Sprintf("aaaaaaaaa%d", i), "value") } + trie.Hash() } -*/ diff --git a/ptrie/valuenode.go b/trie/valuenode.go similarity index 97% rename from ptrie/valuenode.go rename to trie/valuenode.go index c593eb6c60..689befb2ad 100644 --- a/ptrie/valuenode.go +++ b/trie/valuenode.go @@ -1,4 +1,4 @@ -package ptrie +package trie type ValueNode struct { trie *Trie diff --git a/types/ethereum.go b/types/ethereum.go deleted file mode 100644 index ab1254f4c2..0000000000 --- a/types/ethereum.go +++ /dev/null @@ -1 +0,0 @@ -package types