diff --git a/ethdb/memory_database.go b/ethdb/memory_database.go index 459373eeae..48aa830e77 100644 --- a/ethdb/memory_database.go +++ b/ethdb/memory_database.go @@ -23,6 +23,10 @@ func (db *MemDatabase) Put(key []byte, value []byte) { db.db[string(key)] = value } +func (db *MemDatabase) Set(key []byte, value []byte) { + db.Put(key, value) +} + func (db *MemDatabase) Get(key []byte) ([]byte, error) { return db.db[string(key)], nil } diff --git a/ptrie/fullnode.go b/ptrie/fullnode.go index eaa4611b66..7a7f7d22d4 100644 --- a/ptrie/fullnode.go +++ b/ptrie/fullnode.go @@ -18,7 +18,14 @@ func (self *FullNode) Branches() []Node { return self.nodes[:16] } -func (self *FullNode) Copy() Node { return self } +func (self *FullNode) Copy() Node { + nnode := NewFullNode(self.trie) + for i, node := range self.nodes { + nnode.nodes[i] = node + } + + return nnode +} // Returns the length of non-nil nodes func (self *FullNode) Len() (amount int) { diff --git a/ptrie/shortnode.go b/ptrie/shortnode.go index 49319c555d..73ff2914bf 100644 --- a/ptrie/shortnode.go +++ b/ptrie/shortnode.go @@ -17,7 +17,7 @@ func (self *ShortNode) Value() Node { return self.value } func (self *ShortNode) Dirty() bool { return true } -func (self *ShortNode) Copy() Node { return self } +func (self *ShortNode) Copy() Node { return NewShortNode(self.trie, self.key, self.value) } func (self *ShortNode) RlpData() interface{} { return []interface{}{self.key, self.value.Hash()} diff --git a/ptrie/trie.go b/ptrie/trie.go index bb2b3845ad..687126aef8 100644 --- a/ptrie/trie.go +++ b/ptrie/trie.go @@ -2,6 +2,7 @@ package ptrie import ( "bytes" + "container/list" "sync" "github.com/ethereum/go-ethereum/crypto" @@ -14,33 +15,61 @@ type Backend interface { Set([]byte, []byte) } -type Cache map[string][]byte - -func (self Cache) Get(key []byte) []byte { - return self[string(key)] +type Cache struct { + store map[string][]byte + backend Backend } -func (self Cache) Set(key []byte, data []byte) { - self[string(key)] = data + +func NewCache(backend Backend) *Cache { + return &Cache{make(map[string][]byte), backend} +} + +func (self *Cache) Get(key []byte) []byte { + data := self.store[string(key)] + if data == nil { + data = self.backend.Get(key) + } + + return data +} + +func (self *Cache) Set(key []byte, data []byte) { + self.store[string(key)] = data +} + +func (self *Cache) Flush() { + for k, v := range self.store { + self.backend.Set([]byte(k), v) + } + + // This will eventually grow too large. We'd could + // do a make limit on storage and push out not-so-popular nodes. + //self.Reset() +} + +func (self *Cache) Reset() { + self.store = make(map[string][]byte) } type Trie struct { mu sync.Mutex root Node roothash []byte - backend Backend -} + cache *Cache -func NewEmpty() *Trie { - return &Trie{sync.Mutex{}, nil, nil, make(Cache)} + revisions *list.List } func New(root []byte, backend Backend) *Trie { trie := &Trie{} + trie.revisions = list.New() trie.roothash = root - trie.backend = backend + trie.cache = NewCache(backend) - value := ethutil.NewValueFromBytes(trie.backend.Get(root)) - trie.root = trie.mknode(value) + if root != nil { + value := ethutil.NewValueFromBytes(trie.cache.Get(root)) + trie.root = trie.mknode(value) + } return trie } @@ -64,10 +93,28 @@ func (self *Trie) Hash() []byte { hash = crypto.Sha3(ethutil.Encode(self.root)) } - self.roothash = hash + if !bytes.Equal(hash, self.roothash) { + self.revisions.PushBack(self.roothash) + self.roothash = hash + } return hash } +func (self *Trie) Commit() { + // Hash first + self.Hash() + + self.cache.Flush() +} + +func (self *Trie) Reset() { + self.cache.Reset() + + revision := self.revisions.Remove(self.revisions.Back()).([]byte) + self.roothash = revision + value := ethutil.NewValueFromBytes(self.cache.Get(self.roothash)) + self.root = self.mknode(value) +} func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } func (self *Trie) Update(key, value []byte) Node { @@ -272,7 +319,7 @@ func (self *Trie) mknode(value *ethutil.Value) Node { func (self *Trie) trans(node Node) Node { switch node := node.(type) { case *HashNode: - value := ethutil.NewValueFromBytes(self.backend.Get(node.key)) + value := ethutil.NewValueFromBytes(self.cache.Get(node.key)) return self.mknode(value) default: return node @@ -283,7 +330,7 @@ func (self *Trie) store(node Node) interface{} { data := ethutil.Encode(node) if len(data) >= 32 { key := crypto.Sha3(data) - self.backend.Set(key, data) + self.cache.Set(key, data) return key } diff --git a/ptrie/trie_test.go b/ptrie/trie_test.go index 6af6e1b406..478f59c602 100644 --- a/ptrie/trie_test.go +++ b/ptrie/trie_test.go @@ -8,6 +8,16 @@ import ( "github.com/ethereum/go-ethereum/ethutil" ) +type Db map[string][]byte + +func (self Db) Get(k []byte) []byte { return self[string(k)] } +func (self Db) Set(k, v []byte) { self[string(k)] = v } + +// Used for testing +func NewEmpty() *Trie { + return New(nil, make(Db)) +} + func TestInsert(t *testing.T) { trie := NewEmpty() @@ -91,7 +101,7 @@ func TestReplication(t *testing.T) { } trie.Hash() - trie2 := New(trie.roothash, trie.backend) + trie2 := New(trie.roothash, trie.cache) if string(trie2.GetString("horse")) != "stallion" { t.Error("expected to have harse => stallion") } @@ -104,6 +114,53 @@ func TestReplication(t *testing.T) { } +func TestReset(t *testing.T) { + trie := NewEmpty() + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + } + for _, val := range vals { + trie.UpdateString(val.k, val.v) + } + trie.Commit() + + before := ethutil.CopyBytes(trie.roothash) + trie.UpdateString("should", "revert") + trie.Hash() + // Should have no effect + trie.Hash() + trie.Hash() + // ### + + trie.Reset() + after := ethutil.CopyBytes(trie.roothash) + + if !bytes.Equal(before, after) { + t.Errorf("expected roots to be equal. %x - %x", before, after) + } +} + +// Not actual test +func TestOutput(t *testing.T) { + t.Skip() + + base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + trie := NewEmpty() + for i := 0; i < 50; i++ { + trie.UpdateString(fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee") + } + trie.Hash() + fmt.Println("############################## FULL ################################") + fmt.Println(trie.root) + + trie2 := New(trie.roothash, trie.cache) + trie2.GetString(base + "20") + fmt.Println("############################## SMALL ################################") + fmt.Println(trie2.root) +} + func BenchmarkGets(b *testing.B) { trie := NewEmpty() vals := []struct{ k, v string }{ @@ -136,22 +193,3 @@ func BenchmarkUpdate(b *testing.B) { } trie.Hash() } - -// Not actual test -func TestOutput(t *testing.T) { - t.Skip() - - base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - trie := NewEmpty() - for i := 0; i < 50; i++ { - trie.UpdateString(fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee") - } - trie.Hash() - fmt.Println("############################## FULL ################################") - fmt.Println(trie.root) - - trie2 := New(trie.roothash, trie.backend) - trie2.GetString(base + "20") - fmt.Println("############################## SMALL ################################") - fmt.Println(trie2.root) -} diff --git a/ptrie/valuenode.go b/ptrie/valuenode.go index c226621a70..c593eb6c60 100644 --- a/ptrie/valuenode.go +++ b/ptrie/valuenode.go @@ -8,6 +8,6 @@ type ValueNode struct { func (self *ValueNode) Value() Node { return self } // Best not to call :-) func (self *ValueNode) Val() []byte { return self.data } func (self *ValueNode) Dirty() bool { return true } -func (self *ValueNode) Copy() Node { return self } +func (self *ValueNode) Copy() Node { return &ValueNode{self.trie, self.data} } func (self *ValueNode) RlpData() interface{} { return self.data } func (self *ValueNode) Hash() interface{} { return self.data }