diff --git a/core/state/statedb.go b/core/state/statedb.go index 002fa62496..de9fb367d9 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -453,7 +453,7 @@ func (self *StateDB) Copy() *StateDB { // Copy all the basic fields, initialize the memory ones state := &StateDB{ db: self.db, - trie: self.trie, + trie: self.db.CopyTrie(self.trie), stateObjects: make(map[common.Address]*stateObject, len(self.stateObjectsDirty)), stateObjectsDirty: make(map[common.Address]struct{}, len(self.stateObjectsDirty)), refund: new(big.Int).Set(self.refund), diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index b2bd18e65e..e9944cd745 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -117,6 +117,57 @@ func TestIntermediateLeaks(t *testing.T) { } } +// TestCopy tests that copying a statedb object indeed makes the original and +// the copy independent of each other. This test is a regression test against +// https://github.com/ethereum/go-ethereum/pull/15549. +func TestCopy(t *testing.T) { + // Create a random state test to copy and modify "independently" + mem, _ := ethdb.NewMemDatabase() + orig, _ := New(common.Hash{}, NewDatabase(mem)) + + for i := byte(0); i < 255; i++ { + obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + obj.AddBalance(big.NewInt(int64(i))) + orig.updateStateObject(obj) + } + orig.Finalise(false) + + // Copy the state, modify both in-memory + copy := orig.Copy() + + for i := byte(0); i < 255; i++ { + origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + + origObj.AddBalance(big.NewInt(2 * int64(i))) + copyObj.AddBalance(big.NewInt(3 * int64(i))) + + orig.updateStateObject(origObj) + copy.updateStateObject(copyObj) + } + // Finalise the changes on both concurrently + done := make(chan struct{}) + go func() { + orig.Finalise(true) + close(done) + }() + copy.Finalise(true) + <-done + + // Verify that the two states have been updated independently + for i := byte(0); i < 255; i++ { + origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + + if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 { + t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want) + } + if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 { + t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want) + } + } +} + func TestSnapshotRandom(t *testing.T) { config := &quick.Config{MaxCount: 1000} err := quick.Check((*snapshotTest).run, config)