From 1f1ea18b5414bea22332bb4fce53cc95b5c6a07d Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 4 Oct 2016 12:36:02 +0200 Subject: [PATCH 1/3] core/state: implement reverts by journaling all changes This commit replaces the deep-copy based state revert mechanism with a linear complexity journal. This commit also hides several internal StateDB methods to limit the number of ways in which calling code can use the journal incorrectly. As usual consultation and bug fixes to the initial implementation were provided by @karalabe, @obscuren and @Arachnid. Thank you! --- accounts/abi/bind/backends/simulated.go | 6 +- cmd/evm/main.go | 32 +-- core/chain_makers.go | 2 +- core/execution.go | 8 +- core/state/dump.go | 2 +- core/state/journal.go | 117 +++++++++ core/state/managed_state_test.go | 7 +- core/state/state_object.go | 54 +++- core/state/state_test.go | 22 +- core/state/statedb.go | 212 ++++++++++------ core/state/statedb_test.go | 317 +++++++++++++++++++++--- core/state/sync_test.go | 2 +- core/tx_pool.go | 2 +- core/vm/environment.go | 4 +- core/vm/jit_test.go | 4 +- core/vm/runtime/env.go | 8 +- core/vm_env.go | 8 +- eth/api_backend.go | 6 +- internal/ethapi/tracer_test.go | 16 +- light/state_test.go | 14 +- miner/worker.go | 6 +- tests/state_test_util.go | 22 +- tests/util.go | 34 +-- tests/vm_test_util.go | 18 +- 24 files changed, 670 insertions(+), 253 deletions(-) create mode 100644 core/state/journal.go diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 7e09abb119..74203a468d 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -172,8 +172,9 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call ethereum.CallM func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) { b.mu.Lock() defer b.mu.Unlock() + defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot()) - rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy()) + rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) return rval, err } @@ -197,8 +198,9 @@ func (b *SimulatedBackend) SuggestGasPrice(ctx context.Context) (*big.Int, error func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMsg) (*big.Int, error) { b.mu.Lock() defer b.mu.Unlock() + defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot()) - _, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy()) + _, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) return gas, err } diff --git a/cmd/evm/main.go b/cmd/evm/main.go index 09ade1577d..22707c1cc3 100644 --- a/cmd/evm/main.go +++ b/cmd/evm/main.go @@ -227,22 +227,22 @@ type ruleSet struct{} func (ruleSet) IsHomestead(*big.Int) bool { return true } -func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} } -func (self *VMEnv) Vm() vm.Vm { return self.evm } -func (self *VMEnv) Db() vm.Database { return self.state } -func (self *VMEnv) MakeSnapshot() vm.Database { return self.state.Copy() } -func (self *VMEnv) SetSnapshot(db vm.Database) { self.state.Set(db.(*state.StateDB)) } -func (self *VMEnv) Origin() common.Address { return *self.transactor } -func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 } -func (self *VMEnv) Coinbase() common.Address { return *self.transactor } -func (self *VMEnv) Time() *big.Int { return self.time } -func (self *VMEnv) Difficulty() *big.Int { return common.Big1 } -func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) } -func (self *VMEnv) Value() *big.Int { return self.value } -func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) } -func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy } -func (self *VMEnv) Depth() int { return 0 } -func (self *VMEnv) SetDepth(i int) { self.depth = i } +func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} } +func (self *VMEnv) Vm() vm.Vm { return self.evm } +func (self *VMEnv) Db() vm.Database { return self.state } +func (self *VMEnv) SnapshotDatabase() int { return self.state.Snapshot() } +func (self *VMEnv) RevertToSnapshot(snap int) { self.state.RevertToSnapshot(snap) } +func (self *VMEnv) Origin() common.Address { return *self.transactor } +func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 } +func (self *VMEnv) Coinbase() common.Address { return *self.transactor } +func (self *VMEnv) Time() *big.Int { return self.time } +func (self *VMEnv) Difficulty() *big.Int { return common.Big1 } +func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) } +func (self *VMEnv) Value() *big.Int { return self.value } +func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) } +func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy } +func (self *VMEnv) Depth() int { return 0 } +func (self *VMEnv) SetDepth(i int) { self.depth = i } func (self *VMEnv) GetHash(n uint64) common.Hash { if self.block.Number().Cmp(big.NewInt(int64(n))) == 0 { return self.block.Hash() diff --git a/core/chain_makers.go b/core/chain_makers.go index 0b9a5f75d3..e3ad9cda08 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -131,7 +131,7 @@ func (b *BlockGen) AddUncheckedReceipt(receipt *types.Receipt) { // TxNonce returns the next valid transaction nonce for the // account at addr. It panics if the account does not exist. func (b *BlockGen) TxNonce(addr common.Address) uint64 { - if !b.statedb.HasAccount(addr) { + if !b.statedb.Exist(addr) { panic("account does not exist") } return b.statedb.GetNonce(addr) diff --git a/core/execution.go b/core/execution.go index 1bc02f7fb2..1cb507ee78 100644 --- a/core/execution.go +++ b/core/execution.go @@ -85,7 +85,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A createAccount = true } - snapshotPreTransfer := env.MakeSnapshot() + snapshotPreTransfer := env.SnapshotDatabase() var ( from = env.Db().GetAccount(caller.Address()) to vm.Account @@ -129,7 +129,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) { contract.UseGas(contract.Gas) - env.SetSnapshot(snapshotPreTransfer) + env.RevertToSnapshot(snapshotPreTransfer) } return ret, addr, err @@ -144,7 +144,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA return nil, common.Address{}, vm.DepthError } - snapshot := env.MakeSnapshot() + snapshot := env.SnapshotDatabase() var to vm.Account if !env.Db().Exist(*toAddr) { @@ -162,7 +162,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA if err != nil { contract.UseGas(contract.Gas) - env.SetSnapshot(snapshot) + env.RevertToSnapshot(snapshot) } return ret, addr, err diff --git a/core/state/dump.go b/core/state/dump.go index 58ecd852b7..8294d61b9f 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -52,7 +52,7 @@ func (self *StateDB) RawDump() Dump { panic(err) } - obj := NewObject(common.BytesToAddress(addr), data, nil) + obj := newObject(nil, common.BytesToAddress(addr), data, nil) account := DumpAccount{ Balance: data.Balance.String(), Nonce: data.Nonce, diff --git a/core/state/journal.go b/core/state/journal.go new file mode 100644 index 0000000000..540ade6fb7 --- /dev/null +++ b/core/state/journal.go @@ -0,0 +1,117 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/common" +) + +type journalEntry interface { + undo(*StateDB) +} + +type journal []journalEntry + +type ( + // Changes to the account trie. + createObjectChange struct { + account *common.Address + } + resetObjectChange struct { + prev *StateObject + } + deleteAccountChange struct { + account *common.Address + prev bool // whether account had already suicided + prevbalance *big.Int + } + + // Changes to individual accounts. + balanceChange struct { + account *common.Address + prev *big.Int + } + nonceChange struct { + account *common.Address + prev uint64 + } + storageChange struct { + account *common.Address + key, prevalue common.Hash + } + codeChange struct { + account *common.Address + prevcode, prevhash []byte + } + + // Changes to other state values. + refundChange struct { + prev *big.Int + } + addLogChange struct { + txhash common.Hash + } +) + +func (ch createObjectChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).deleted = true + delete(s.stateObjects, *ch.account) + delete(s.stateObjectsDirty, *ch.account) +} + +func (ch resetObjectChange) undo(s *StateDB) { + s.setStateObject(ch.prev) +} + +func (ch deleteAccountChange) undo(s *StateDB) { + obj := s.GetStateObject(*ch.account) + if obj != nil { + obj.remove = ch.prev + obj.setBalance(ch.prevbalance) + } +} + +func (ch balanceChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).setBalance(ch.prev) +} + +func (ch nonceChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).setNonce(ch.prev) +} + +func (ch codeChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode) +} + +func (ch storageChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).setState(ch.key, ch.prevalue) +} + +func (ch refundChange) undo(s *StateDB) { + s.refund = ch.prev +} + +func (ch addLogChange) undo(s *StateDB) { + logs := s.logs[ch.txhash] + if len(logs) == 1 { + delete(s.logs, ch.txhash) + } else { + s.logs[ch.txhash] = logs[:len(logs)-1] + } +} diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go index baa53428f9..3f7bc2aa85 100644 --- a/core/state/managed_state_test.go +++ b/core/state/managed_state_test.go @@ -29,11 +29,8 @@ func create() (*ManagedState, *account) { db, _ := ethdb.NewMemDatabase() statedb, _ := New(common.Hash{}, db) ms := ManageState(statedb) - so := &StateObject{address: addr} - so.SetNonce(100) - ms.StateDB.stateObjects[addr] = so - ms.accounts[addr] = newAccount(so) - + ms.StateDB.SetNonce(addr, 100) + ms.accounts[addr] = newAccount(ms.StateDB.GetStateObject(addr)) return ms, ms.accounts[addr] } diff --git a/core/state/state_object.go b/core/state/state_object.go index cbd50e2a36..31ff9bcd81 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -66,6 +66,7 @@ func (self Storage) Copy() Storage { type StateObject struct { address common.Address // Ethereum address of this account data Account + db *StateDB // DB error. // State objects are used by the consensus core and VM which are @@ -99,15 +100,15 @@ type Account struct { CodeHash []byte } -// NewObject creates a state object. -func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject { +// newObject creates a state object. +func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *StateObject { if data.Balance == nil { data.Balance = new(big.Int) } if data.CodeHash == nil { data.CodeHash = emptyCodeHash } - return &StateObject{address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} + return &StateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} } // EncodeRLP implements rlp.Encoder. @@ -122,7 +123,7 @@ func (self *StateObject) setError(err error) { } } -func (self *StateObject) MarkForDeletion() { +func (self *StateObject) markForDeletion() { self.remove = true if self.onDirty != nil { self.onDirty(self.Address()) @@ -163,7 +164,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash } // SetState updates a value in account storage. -func (self *StateObject) SetState(key, value common.Hash) { +func (self *StateObject) SetState(db trie.Database, key, value common.Hash) { + self.db.journal = append(self.db.journal, storageChange{ + account: &self.address, + key: key, + prevalue: self.GetState(db, key), + }) + self.setState(key, value) +} + +func (self *StateObject) setState(key, value common.Hash) { self.cachedStorage[key] = value self.dirtyStorage[key] = value @@ -189,7 +199,7 @@ func (self *StateObject) updateTrie(db trie.Database) { } // UpdateRoot sets the trie root to the current root hash of -func (self *StateObject) UpdateRoot(db trie.Database) { +func (self *StateObject) updateRoot(db trie.Database) { self.updateTrie(db) self.data.Root = self.trie.Hash() } @@ -232,6 +242,14 @@ func (c *StateObject) SubBalance(amount *big.Int) { } func (self *StateObject) SetBalance(amount *big.Int) { + self.db.journal = append(self.db.journal, balanceChange{ + account: &self.address, + prev: new(big.Int).Set(self.data.Balance), + }) + self.setBalance(amount) +} + +func (self *StateObject) setBalance(amount *big.Int) { self.data.Balance = amount if self.onDirty != nil { self.onDirty(self.Address()) @@ -242,8 +260,8 @@ func (self *StateObject) SetBalance(amount *big.Int) { // Return the gas back to the origin. Used by the Virtual machine or Closures func (c *StateObject) ReturnGas(gas, price *big.Int) {} -func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject { - stateObject := NewObject(self.address, self.data, onDirty) +func (self *StateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *StateObject { + stateObject := newObject(db, self.address, self.data, onDirty) stateObject.trie = self.trie stateObject.code = self.code stateObject.dirtyStorage = self.dirtyStorage.Copy() @@ -280,6 +298,16 @@ func (self *StateObject) Code(db trie.Database) []byte { } func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { + prevcode := self.Code(self.db.db) + self.db.journal = append(self.db.journal, codeChange{ + account: &self.address, + prevhash: self.CodeHash(), + prevcode: prevcode, + }) + self.setCode(codeHash, code) +} + +func (self *StateObject) setCode(codeHash common.Hash, code []byte) { self.code = code self.data.CodeHash = codeHash[:] self.dirtyCode = true @@ -290,6 +318,14 @@ func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { } func (self *StateObject) SetNonce(nonce uint64) { + self.db.journal = append(self.db.journal, nonceChange{ + account: &self.address, + prev: self.data.Nonce, + }) + self.setNonce(nonce) +} + +func (self *StateObject) setNonce(nonce uint64) { self.data.Nonce = nonce if self.onDirty != nil { self.onDirty(self.Address()) @@ -322,7 +358,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) { cb(h, value) } - it := self.trie.Iterator() + it := self.getTrie(self.db.db).Iterator() for it.Next() { // ignore cached values key := common.BytesToHash(self.trie.GetKey(it.Key)) diff --git a/core/state/state_test.go b/core/state/state_test.go index 7b9b39e065..b86d8b1401 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -46,8 +46,8 @@ func (s *StateSuite) TestDump(c *checker.C) { obj3.SetBalance(big.NewInt(44)) // write some of them to the trie - s.state.UpdateStateObject(obj1) - s.state.UpdateStateObject(obj2) + s.state.updateStateObject(obj1) + s.state.updateStateObject(obj2) s.state.Commit() // check that dump contains the state objects that are in trie @@ -116,12 +116,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) { // set initial state object value s.state.SetState(stateobjaddr, storageaddr, data1) // get snapshot of current state - snapshot := s.state.Copy() + snapshot := s.state.Snapshot() // set new state object value s.state.SetState(stateobjaddr, storageaddr, data2) // restore snapshot - s.state.Set(snapshot) + s.state.RevertToSnapshot(snapshot) // get state storage value res := s.state.GetState(stateobjaddr, storageaddr) @@ -129,6 +129,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) { c.Assert(data1, checker.DeepEquals, res) } +func TestSnapshotEmpty(t *testing.T) { + db, _ := ethdb.NewMemDatabase() + state, _ := New(common.Hash{}, db) + state.RevertToSnapshot(state.Snapshot()) +} + // use testing instead of checker because checker does not support // printing/logging in tests (-check.vv does not work) func TestSnapshot2(t *testing.T) { @@ -152,7 +158,7 @@ func TestSnapshot2(t *testing.T) { so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) so0.remove = false so0.deleted = false - state.SetStateObject(so0) + state.setStateObject(so0) root, _ := state.Commit() state.Reset(root) @@ -164,15 +170,15 @@ func TestSnapshot2(t *testing.T) { so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'}) so1.remove = true so1.deleted = true - state.SetStateObject(so1) + state.setStateObject(so1) so1 = state.GetStateObject(stateobjaddr1) if so1 != nil { t.Fatalf("deleted object not nil when getting") } - snapshot := state.Copy() - state.Set(snapshot) + snapshot := state.Snapshot() + state.RevertToSnapshot(snapshot) so0Restored := state.GetStateObject(stateobjaddr0) // Update lazily-loaded values before comparing. diff --git a/core/state/statedb.go b/core/state/statedb.go index 4204c456e0..4f74302c3a 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -20,6 +20,7 @@ package state import ( "fmt" "math/big" + "sort" "sync" "github.com/ethereum/go-ethereum/common" @@ -40,12 +41,17 @@ var StartingNonce uint64 const ( // Number of past tries to keep. The arbitrarily chosen value here // is max uncle depth + 1. - maxJournalLength = 8 + maxTrieCacheLength = 8 // Number of codehash->size associations to keep. codeSizeCacheSize = 100000 ) +type revision struct { + id int + journalIndex int +} + // StateDBs within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: @@ -69,6 +75,12 @@ type StateDB struct { logs map[common.Hash]vm.Logs logSize uint + // Journal of state modifications. This is the backbone of + // Snapshot and RevertToSnapshot. + journal journal + validRevisions []revision + nextRevisionId int + lock sync.Mutex } @@ -124,12 +136,12 @@ func (self *StateDB) Reset(root common.Hash) error { self.trie = tr self.stateObjects = make(map[common.Address]*StateObject) self.stateObjectsDirty = make(map[common.Address]struct{}) - self.refund = new(big.Int) self.thash = common.Hash{} self.bhash = common.Hash{} self.txIndex = 0 self.logs = make(map[common.Hash]vm.Logs) self.logSize = 0 + self.clearJournalAndRefund() return nil } @@ -150,7 +162,7 @@ func (self *StateDB) pushTrie(t *trie.SecureTrie) { self.lock.Lock() defer self.lock.Unlock() - if len(self.pastTries) >= maxJournalLength { + if len(self.pastTries) >= maxTrieCacheLength { copy(self.pastTries, self.pastTries[1:]) self.pastTries[len(self.pastTries)-1] = t } else { @@ -165,6 +177,8 @@ func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) { } func (self *StateDB) AddLog(log *vm.Log) { + self.journal = append(self.journal, addLogChange{txhash: self.thash}) + log.TxHash = self.thash log.BlockHash = self.bhash log.TxIndex = uint(self.txIndex) @@ -186,13 +200,12 @@ func (self *StateDB) Logs() vm.Logs { } func (self *StateDB) AddRefund(gas *big.Int) { + self.journal = append(self.journal, refundChange{prev: new(big.Int).Set(self.refund)}) self.refund.Add(self.refund, gas) } -func (self *StateDB) HasAccount(addr common.Address) bool { - return self.GetStateObject(addr) != nil -} - +// Exist reports whether the given account address exists in the state. +// Notably this also returns true for suicided accounts. func (self *StateDB) Exist(addr common.Address) bool { return self.GetStateObject(addr) != nil } @@ -207,7 +220,6 @@ func (self *StateDB) GetBalance(addr common.Address) *big.Int { if stateObject != nil { return stateObject.Balance() } - return common.Big0 } @@ -282,6 +294,13 @@ func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) { } } +func (self *StateDB) SetBalance(addr common.Address, amount *big.Int) { + stateObject := self.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetBalance(amount) + } +} + func (self *StateDB) SetNonce(addr common.Address, nonce uint64) { stateObject := self.GetOrNewStateObject(addr) if stateObject != nil { @@ -299,27 +318,36 @@ func (self *StateDB) SetCode(addr common.Address, code []byte) { func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) { stateObject := self.GetOrNewStateObject(addr) if stateObject != nil { - stateObject.SetState(key, value) + stateObject.SetState(self.db, key, value) } } +// Delete marks the given account as suicided. +// This clears the account balance. +// +// The account's state object is still available until the state is committed, +// GetStateObject will return a non-nil account after Delete. func (self *StateDB) Delete(addr common.Address) bool { stateObject := self.GetStateObject(addr) - if stateObject != nil { - stateObject.MarkForDeletion() - stateObject.data.Balance = new(big.Int) - return true + if stateObject == nil { + return false } - - return false + self.journal = append(self.journal, deleteAccountChange{ + account: &addr, + prev: stateObject.remove, + prevbalance: new(big.Int).Set(stateObject.Balance()), + }) + stateObject.markForDeletion() + stateObject.data.Balance = new(big.Int) + return true } // // Setting, updating & deleting state object methods // -// Update the given state object and apply it to state trie -func (self *StateDB) UpdateStateObject(stateObject *StateObject) { +// updateStateObject writes the given object to the trie. +func (self *StateDB) updateStateObject(stateObject *StateObject) { addr := stateObject.Address() data, err := rlp.EncodeToBytes(stateObject) if err != nil { @@ -328,10 +356,9 @@ func (self *StateDB) UpdateStateObject(stateObject *StateObject) { self.trie.Update(addr[:], data) } -// Delete the given state object and delete it from the state trie -func (self *StateDB) DeleteStateObject(stateObject *StateObject) { +// deleteStateObject removes the given object from the state trie. +func (self *StateDB) deleteStateObject(stateObject *StateObject) { stateObject.deleted = true - addr := stateObject.Address() self.trie.Delete(addr[:]) } @@ -357,12 +384,12 @@ func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObje return nil } // Insert into the live set. - obj := NewObject(addr, data, self.MarkStateObjectDirty) - self.SetStateObject(obj) + obj := newObject(self, addr, data, self.MarkStateObjectDirty) + self.setStateObject(obj) return obj } -func (self *StateDB) SetStateObject(object *StateObject) { +func (self *StateDB) setStateObject(object *StateObject) { self.stateObjects[object.Address()] = object } @@ -370,52 +397,55 @@ func (self *StateDB) SetStateObject(object *StateObject) { func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject { stateObject := self.GetStateObject(addr) if stateObject == nil || stateObject.deleted { - stateObject = self.CreateStateObject(addr) + stateObject, _ = self.createObject(addr) } - return stateObject } -// NewStateObject create a state object whether it exist in the trie or not -func (self *StateDB) newStateObject(addr common.Address) *StateObject { - if glog.V(logger.Core) { - glog.Infof("(+) %x\n", addr) - } - obj := NewObject(addr, Account{}, self.MarkStateObjectDirty) - obj.SetNonce(StartingNonce) // sets the object to dirty - self.stateObjects[addr] = obj - return obj -} - // MarkStateObjectDirty adds the specified object to the dirty map to avoid costly // state object cache iteration to find a handful of modified ones. func (self *StateDB) MarkStateObjectDirty(addr common.Address) { self.stateObjectsDirty[addr] = struct{}{} } -// Creates creates a new state object and takes ownership. -func (self *StateDB) CreateStateObject(addr common.Address) *StateObject { - // Get previous (if any) - so := self.GetStateObject(addr) - // Create a new one - newSo := self.newStateObject(addr) - - // If it existed set the balance to the new account - if so != nil { - newSo.data.Balance = so.data.Balance +// createObject creates a new state object. If there is an existing account with +// the given address, it is overwritten and returned as the second return value. +func (self *StateDB) createObject(addr common.Address) (newobj, prev *StateObject) { + prev = self.GetStateObject(addr) + newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty) + newobj.setNonce(StartingNonce) // sets the object to dirty + if prev == nil { + if glog.V(logger.Core) { + glog.Infof("(+) %x\n", addr) + } + self.journal = append(self.journal, createObjectChange{account: &addr}) + } else { + self.journal = append(self.journal, resetObjectChange{prev: prev}) } - - return newSo + self.setStateObject(newobj) + return newobj, prev } +// CreateAccount explicitly creates a state object. If a state object with the address +// already exists the balance is carried over to the new account. +// +// CreateAccount is called during the EVM CREATE operation. The situation might arise that +// a contract does the following: +// +// 1. sends funds to sha(account ++ (nonce + 1)) +// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) +// +// Carrying over the balance ensures that Ether doesn't disappear. func (self *StateDB) CreateAccount(addr common.Address) vm.Account { - return self.CreateStateObject(addr) + new, prev := self.createObject(addr) + if prev != nil { + new.setBalance(prev.data.Balance) + } + return new } -// -// Setting, copying of the state methods -// - +// Copy creates a deep, independent copy of the state. +// Snapshots of the copied state cannot be applied to the copy. func (self *StateDB) Copy() *StateDB { self.lock.Lock() defer self.lock.Unlock() @@ -434,7 +464,7 @@ func (self *StateDB) Copy() *StateDB { } // Copy the dirty states and logs for addr, _ := range self.stateObjectsDirty { - state.stateObjects[addr] = self.stateObjects[addr].Copy(self.db, state.MarkStateObjectDirty) + state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty) state.stateObjectsDirty[addr] = struct{}{} } for hash, logs := range self.logs { @@ -444,21 +474,38 @@ func (self *StateDB) Copy() *StateDB { return state } -func (self *StateDB) Set(state *StateDB) { - self.lock.Lock() - defer self.lock.Unlock() - - self.db = state.db - self.trie = state.trie - self.pastTries = state.pastTries - self.stateObjects = state.stateObjects - self.stateObjectsDirty = state.stateObjectsDirty - self.codeSizeCache = state.codeSizeCache - self.refund = state.refund - self.logs = state.logs - self.logSize = state.logSize +// Snapshot returns an identifier for the current revision of the state. +func (self *StateDB) Snapshot() int { + id := self.nextRevisionId + self.nextRevisionId++ + self.validRevisions = append(self.validRevisions, revision{id, len(self.journal)}) + return id } +// RevertToSnapshot reverts all state changes made since the given revision. +func (self *StateDB) RevertToSnapshot(revid int) { + // Find the snapshot in the stack of valid snapshots. + idx := sort.Search(len(self.validRevisions), func(i int) bool { + return self.validRevisions[i].id >= revid + }) + if idx == len(self.validRevisions) || self.validRevisions[idx].id != revid { + panic(fmt.Errorf("revision id %v cannot be reverted", revid)) + } + snapshot := self.validRevisions[idx].journalIndex + + // Replay the journal to undo changes. + for i := len(self.journal) - 1; i >= snapshot; i-- { + self.journal[i].undo(self) + } + self.journal = self.journal[:snapshot] + + // Remove invalidated snapshots from the stack. + self.validRevisions = self.validRevisions[:idx] +} + +// GetRefund returns the current value of the refund counter. +// The return value must not be modified by the caller and will become +// invalid at the next call to AddRefund. func (self *StateDB) GetRefund() *big.Int { return self.refund } @@ -467,16 +514,17 @@ func (self *StateDB) GetRefund() *big.Int { // It is called in between transactions to get the root hash that // goes into transaction receipts. func (s *StateDB) IntermediateRoot() common.Hash { - s.refund = new(big.Int) for addr, _ := range s.stateObjectsDirty { stateObject := s.stateObjects[addr] if stateObject.remove { - s.DeleteStateObject(stateObject) + s.deleteStateObject(stateObject) } else { - stateObject.UpdateRoot(s.db) - s.UpdateStateObject(stateObject) + stateObject.updateRoot(s.db) + s.updateStateObject(stateObject) } } + // Invalidate journal because reverting across transactions is not allowed. + s.clearJournalAndRefund() return s.trie.Hash() } @@ -486,9 +534,9 @@ func (s *StateDB) IntermediateRoot() common.Hash { // DeleteSuicides should not be used for consensus related updates // under any circumstances. func (s *StateDB) DeleteSuicides() { - // Reset refund so that any used-gas calculations can use - // this method. - s.refund = new(big.Int) + // Reset refund so that any used-gas calculations can use this method. + s.clearJournalAndRefund() + for addr, _ := range s.stateObjectsDirty { stateObject := s.stateObjects[addr] @@ -516,15 +564,21 @@ func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) { return root, batch } -func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { +func (s *StateDB) clearJournalAndRefund() { + s.journal = nil + s.validRevisions = s.validRevisions[:0] s.refund = new(big.Int) +} + +func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { + defer s.clearJournalAndRefund() // Commit objects to the trie. for addr, stateObject := range s.stateObjects { if stateObject.remove { // If the object has been removed, don't bother syncing it // and just mark it for deletion in the trie. - s.DeleteStateObject(stateObject) + s.deleteStateObject(stateObject) } else if _, ok := s.stateObjectsDirty[addr]; ok { // Write any contract code associated with the state object if stateObject.code != nil && stateObject.dirtyCode { @@ -538,7 +592,7 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) return common.Hash{}, err } // Update the object in the main account trie. - s.UpdateStateObject(stateObject) + s.updateStateObject(stateObject) } delete(s.stateObjectsDirty, addr) } @@ -549,7 +603,3 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) } return root, err } - -func (self *StateDB) Refunds() *big.Int { - return self.refund -} diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 7930b620d4..e236cb8f33 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -17,11 +17,19 @@ package state import ( + "bytes" + "encoding/binary" + "fmt" + "math" "math/big" + "math/rand" + "reflect" + "strings" "testing" + "testing/quick" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/ethdb" ) @@ -34,16 +42,16 @@ func TestUpdateLeaks(t *testing.T) { // Update it with some accounts for i := byte(0); i < 255; i++ { - obj := state.GetOrNewStateObject(common.BytesToAddress([]byte{i})) - obj.AddBalance(big.NewInt(int64(11 * i))) - obj.SetNonce(uint64(42 * i)) + addr := common.BytesToAddress([]byte{i}) + state.AddBalance(addr, big.NewInt(int64(11*i))) + state.SetNonce(addr, uint64(42*i)) if i%2 == 0 { - obj.SetState(common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i})) + state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i})) } if i%3 == 0 { - obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) + state.SetCode(addr, []byte{i, i, i, i, i}) } - state.UpdateStateObject(obj) + state.IntermediateRoot() } // Ensure that no data was leaked into the database for _, key := range db.Keys() { @@ -61,51 +69,38 @@ func TestIntermediateLeaks(t *testing.T) { transState, _ := New(common.Hash{}, transDb) finalState, _ := New(common.Hash{}, finalDb) - // Update the states with some objects - for i := byte(0); i < 255; i++ { - // Create a new state object with some data into the transition database - obj := transState.GetOrNewStateObject(common.BytesToAddress([]byte{i})) - obj.SetBalance(big.NewInt(int64(11 * i))) - obj.SetNonce(uint64(42 * i)) + modify := func(state *StateDB, addr common.Address, i, tweak byte) { + state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) + state.SetNonce(addr, uint64(42*i+tweak)) if i%2 == 0 { - obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.BytesToHash([]byte{i, i, i, i, 0})) + state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{}) + state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak}) } if i%3 == 0 { - obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 0}), []byte{i, i, i, i, i, 0}) + state.SetCode(addr, []byte{i, i, i, i, i, tweak}) } - transState.UpdateStateObject(obj) - - // Overwrite all the data with new values in the transition database - obj.SetBalance(big.NewInt(int64(11*i + 1))) - obj.SetNonce(uint64(42*i + 1)) - if i%2 == 0 { - obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.Hash{}) - obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1})) - } - if i%3 == 0 { - obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1}) - } - transState.UpdateStateObject(obj) - - // Create the final state object directly in the final database - obj = finalState.GetOrNewStateObject(common.BytesToAddress([]byte{i})) - obj.SetBalance(big.NewInt(int64(11*i + 1))) - obj.SetNonce(uint64(42*i + 1)) - if i%2 == 0 { - obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1})) - } - if i%3 == 0 { - obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1}) - } - finalState.UpdateStateObject(obj) } + + // Modify the transient state. + for i := byte(0); i < 255; i++ { + modify(transState, common.Address{byte(i)}, i, 0) + } + // Write modifications to trie. + transState.IntermediateRoot() + + // Overwrite all the data with new values in the transient database. + for i := byte(0); i < 255; i++ { + modify(transState, common.Address{byte(i)}, i, 99) + modify(finalState, common.Address{byte(i)}, i, 99) + } + + // Commit and cross check the databases. if _, err := transState.Commit(); err != nil { t.Fatalf("failed to commit transition state: %v", err) } if _, err := finalState.Commit(); err != nil { t.Fatalf("failed to commit final state: %v", err) } - // Cross check the databases to ensure they are the same for _, key := range finalDb.Keys() { if _, err := transDb.Get(key); err != nil { val, _ := finalDb.Get(key) @@ -119,3 +114,243 @@ func TestIntermediateLeaks(t *testing.T) { } } } + +func TestSnapshotRandom(t *testing.T) { + config := &quick.Config{MaxCount: 1000} + err := quick.Check((*snapshotTest).run, config) + if cerr, ok := err.(*quick.CheckError); ok { + test := cerr.In[0].(*snapshotTest) + t.Errorf("%v:\n%s", test.err, test) + } else if err != nil { + t.Error(err) + } +} + +// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes +// captured by the snapshot. Instances of this test with pseudorandom content are created +// by Generate. +// +// The test works as follows: +// +// A new state is created and all actions are applied to it. Several snapshots are taken +// in between actions. The test then reverts each snapshot. For each snapshot the actions +// leading up to it are replayed on a fresh, empty state. The behaviour of all public +// accessor methods on the reverted state must match the return value of the equivalent +// methods on the replayed state. +type snapshotTest struct { + addrs []common.Address // all account addresses + actions []testAction // modifications to the state + snapshots []int // actions indexes at which snapshot is taken + err error // failure details are reported through this field +} + +type testAction struct { + name string + fn func(testAction, *StateDB) + args []int64 + noAddr bool +} + +// newTestAction creates a random action that changes state. +func newTestAction(addr common.Address, r *rand.Rand) testAction { + actions := []testAction{ + { + name: "SetBalance", + fn: func(a testAction, s *StateDB) { + s.SetBalance(addr, big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "AddBalance", + fn: func(a testAction, s *StateDB) { + s.AddBalance(addr, big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "SetNonce", + fn: func(a testAction, s *StateDB) { + s.SetNonce(addr, uint64(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "SetState", + fn: func(a testAction, s *StateDB) { + var key, val common.Hash + binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) + binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) + s.SetState(addr, key, val) + }, + args: make([]int64, 2), + }, + { + name: "SetCode", + fn: func(a testAction, s *StateDB) { + code := make([]byte, 16) + binary.BigEndian.PutUint64(code, uint64(a.args[0])) + binary.BigEndian.PutUint64(code[8:], uint64(a.args[1])) + s.SetCode(addr, code) + }, + args: make([]int64, 2), + }, + { + name: "CreateAccount", + fn: func(a testAction, s *StateDB) { + s.CreateAccount(addr) + }, + }, + { + name: "Delete", + fn: func(a testAction, s *StateDB) { + s.Delete(addr) + }, + }, + { + name: "AddRefund", + fn: func(a testAction, s *StateDB) { + s.AddRefund(big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + noAddr: true, + }, + { + name: "AddLog", + fn: func(a testAction, s *StateDB) { + data := make([]byte, 2) + binary.BigEndian.PutUint16(data, uint16(a.args[0])) + s.AddLog(&vm.Log{Address: addr, Data: data}) + }, + args: make([]int64, 1), + }, + } + action := actions[r.Intn(len(actions))] + var nameargs []string + if !action.noAddr { + nameargs = append(nameargs, addr.Hex()) + } + for _, i := range action.args { + action.args[i] = rand.Int63n(100) + nameargs = append(nameargs, fmt.Sprint(action.args[i])) + } + action.name += strings.Join(nameargs, ", ") + return action +} + +// Generate returns a new snapshot test of the given size. All randomness is +// derived from r. +func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value { + // Generate random actions. + addrs := make([]common.Address, 50) + for i := range addrs { + addrs[i][0] = byte(i) + } + actions := make([]testAction, size) + for i := range actions { + addr := addrs[r.Intn(len(addrs))] + actions[i] = newTestAction(addr, r) + } + // Generate snapshot indexes. + nsnapshots := int(math.Sqrt(float64(size))) + if size > 0 && nsnapshots == 0 { + nsnapshots = 1 + } + snapshots := make([]int, nsnapshots) + snaplen := len(actions) / nsnapshots + for i := range snapshots { + // Try to place the snapshots some number of actions apart from each other. + snapshots[i] = (i * snaplen) + r.Intn(snaplen) + } + return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil}) +} + +func (test *snapshotTest) String() string { + out := new(bytes.Buffer) + sindex := 0 + for i, action := range test.actions { + if len(test.snapshots) > sindex && i == test.snapshots[sindex] { + fmt.Fprintf(out, "---- snapshot %d ----\n", sindex) + sindex++ + } + fmt.Fprintf(out, "%4d: %s\n", i, action.name) + } + return out.String() +} + +func (test *snapshotTest) run() bool { + // Run all actions and create snapshots. + var ( + db, _ = ethdb.NewMemDatabase() + state, _ = New(common.Hash{}, db) + snapshotRevs = make([]int, len(test.snapshots)) + sindex = 0 + ) + for i, action := range test.actions { + if len(test.snapshots) > sindex && i == test.snapshots[sindex] { + snapshotRevs[sindex] = state.Snapshot() + sindex++ + } + action.fn(action, state) + } + + // Revert all snapshots in reverse order. Each revert must yield a state + // that is equivalent to fresh state with all actions up the snapshot applied. + for sindex--; sindex >= 0; sindex-- { + checkstate, _ := New(common.Hash{}, db) + for _, action := range test.actions[:test.snapshots[sindex]] { + action.fn(action, checkstate) + } + state.RevertToSnapshot(snapshotRevs[sindex]) + if err := test.checkEqual(state, checkstate); err != nil { + test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err) + return false + } + } + return true +} + +// checkEqual checks that methods of state and checkstate return the same values. +func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { + for _, addr := range test.addrs { + var err error + checkeq := func(op string, a, b interface{}) bool { + if err == nil && !reflect.DeepEqual(a, b) { + err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b) + return false + } + return true + } + // Check basic accessor methods. + checkeq("Exist", state.Exist(addr), checkstate.Exist(addr)) + checkeq("IsDeleted", state.IsDeleted(addr), checkstate.IsDeleted(addr)) + checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr)) + checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr)) + checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) + checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr)) + checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) + // Check storage. + if obj := state.GetStateObject(addr); obj != nil { + obj.ForEachStorage(func(key, val common.Hash) bool { + return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key)) + }) + checkobj := checkstate.GetStateObject(addr) + checkobj.ForEachStorage(func(key, checkval common.Hash) bool { + return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval) + }) + } + if err != nil { + return err + } + } + + if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 { + return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d", + state.GetRefund(), checkstate.GetRefund()) + } + if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) { + return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", + state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) + } + return nil +} diff --git a/core/state/sync_test.go b/core/state/sync_test.go index 670e1fb1bf..949df7301f 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -57,7 +57,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) acc.code = []byte{i, i, i, i, i} } - state.UpdateStateObject(obj) + state.updateStateObject(obj) accounts = append(accounts, acc) } root, _ := state.Commit() diff --git a/core/tx_pool.go b/core/tx_pool.go index f8b11a7ce7..10a110e0bd 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -257,7 +257,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction) error { // Make sure the account exist. Non existent accounts // haven't got funds and well therefor never pass. - if !currentState.HasAccount(from) { + if !currentState.Exist(from) { return ErrNonExistentAccount } diff --git a/core/vm/environment.go b/core/vm/environment.go index daf6fb90d5..1038e69d57 100644 --- a/core/vm/environment.go +++ b/core/vm/environment.go @@ -36,9 +36,9 @@ type Environment interface { // The state database Db() Database // Creates a restorable snapshot - MakeSnapshot() Database + SnapshotDatabase() int // Set database to previous snapshot - SetSnapshot(Database) + RevertToSnapshot(int) // Address of the original invoker (first occurrence of the VM invoker) Origin() common.Address // The block number this VM is invoked on diff --git a/core/vm/jit_test.go b/core/vm/jit_test.go index e6922aeb74..a6de710e13 100644 --- a/core/vm/jit_test.go +++ b/core/vm/jit_test.go @@ -179,8 +179,8 @@ func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) } //func (self *Env) PrevHash() []byte { return self.parent } func (self *Env) Coinbase() common.Address { return common.Address{} } -func (self *Env) MakeSnapshot() Database { return nil } -func (self *Env) SetSnapshot(Database) {} +func (self *Env) SnapshotDatabase() int { return 0 } +func (self *Env) RevertToSnapshot(int) {} func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } func (self *Env) Db() Database { return nil } diff --git a/core/vm/runtime/env.go b/core/vm/runtime/env.go index a4793c98fc..59fbaa792a 100644 --- a/core/vm/runtime/env.go +++ b/core/vm/runtime/env.go @@ -86,11 +86,11 @@ func (self *Env) SetDepth(i int) { self.depth = i } func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { return self.state.GetBalance(from).Cmp(balance) >= 0 } -func (self *Env) MakeSnapshot() vm.Database { - return self.state.Copy() +func (self *Env) SnapshotDatabase() int { + return self.state.Snapshot() } -func (self *Env) SetSnapshot(copy vm.Database) { - self.state.Set(copy.(*state.StateDB)) +func (self *Env) RevertToSnapshot(snapshot int) { + self.state.RevertToSnapshot(snapshot) } func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { diff --git a/core/vm_env.go b/core/vm_env.go index e541eaef47..d62eebbd9e 100644 --- a/core/vm_env.go +++ b/core/vm_env.go @@ -89,12 +89,12 @@ func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool { return self.state.GetBalance(from).Cmp(balance) >= 0 } -func (self *VMEnv) MakeSnapshot() vm.Database { - return self.state.Copy() +func (self *VMEnv) SnapshotDatabase() int { + return self.state.Snapshot() } -func (self *VMEnv) SetSnapshot(copy vm.Database) { - self.state.Set(copy.(*state.StateDB)) +func (self *VMEnv) RevertToSnapshot(snapshot int) { + self.state.RevertToSnapshot(snapshot) } func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) { diff --git a/eth/api_backend.go b/eth/api_backend.go index 4adeb0aa08..42b84bf9ba 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -98,12 +98,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int { } func (b *EthApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) { - stateDb := state.(EthApiState).state.Copy() + statedb := state.(EthApiState).state addr, _ := msg.From() - from := stateDb.GetOrNewStateObject(addr) + from := statedb.GetOrNewStateObject(addr) from.SetBalance(common.MaxBig) vmError := func() error { return nil } - return core.NewEnv(stateDb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil + return core.NewEnv(statedb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil } func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { diff --git a/internal/ethapi/tracer_test.go b/internal/ethapi/tracer_test.go index 7c831d2990..127af32a84 100644 --- a/internal/ethapi/tracer_test.go +++ b/internal/ethapi/tracer_test.go @@ -50,14 +50,14 @@ func (self *Env) Origin() common.Address { return common.Address{} } func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) } //func (self *Env) PrevHash() []byte { return self.parent } -func (self *Env) Coinbase() common.Address { return common.Address{} } -func (self *Env) MakeSnapshot() vm.Database { return nil } -func (self *Env) SetSnapshot(vm.Database) {} -func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } -func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } -func (self *Env) Db() vm.Database { return nil } -func (self *Env) GasLimit() *big.Int { return self.gasLimit } -func (self *Env) VmType() vm.Type { return vm.StdVmTy } +func (self *Env) Coinbase() common.Address { return common.Address{} } +func (self *Env) SnapshotDatabase() int { return 0 } +func (self *Env) RevertToSnapshot(int) {} +func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } +func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } +func (self *Env) Db() vm.Database { return nil } +func (self *Env) GasLimit() *big.Int { return self.gasLimit } +func (self *Env) VmType() vm.Type { return vm.StdVmTy } func (self *Env) GetHash(n uint64) common.Hash { return common.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String()))) } diff --git a/light/state_test.go b/light/state_test.go index d4fe950229..a6b1157868 100644 --- a/light/state_test.go +++ b/light/state_test.go @@ -23,7 +23,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" "golang.org/x/net/context" @@ -54,16 +53,13 @@ func makeTestState() (common.Hash, ethdb.Database) { sdb, _ := ethdb.NewMemDatabase() st, _ := state.New(common.Hash{}, sdb) for i := byte(0); i < 100; i++ { - so := st.GetOrNewStateObject(common.Address{i}) + addr := common.Address{i} for j := byte(0); j < 100; j++ { - val := common.Hash{i, j} - so.SetState(common.Hash{j}, val) - so.SetNonce(100) + st.SetState(addr, common.Hash{j}, common.Hash{i, j}) } - so.AddBalance(big.NewInt(int64(i))) - so.SetCode(crypto.Keccak256Hash([]byte{i, i, i}), []byte{i, i, i}) - so.UpdateRoot(sdb) - st.UpdateStateObject(so) + st.SetNonce(addr, 100) + st.AddBalance(addr, big.NewInt(int64(i))) + st.SetCode(addr, []byte{i, i, i}) } root, _ := st.Commit() return root, sdb diff --git a/miner/worker.go b/miner/worker.go index ac1ef5ba3d..e5348cef42 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -171,7 +171,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) { self.current.receipts, ), self.current.state } - return self.current.Block, self.current.state + return self.current.Block, self.current.state.Copy() } func (self *worker) start() { @@ -618,7 +618,7 @@ func (env *Work) commitTransactions(mux *event.TypeMux, txs *types.TransactionsB } func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) { - snap := env.state.Copy() + snap := env.state.Snapshot() // this is a bit of a hack to force jit for the miners config := env.config.VmConfig @@ -629,7 +629,7 @@ func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, g receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config) if err != nil { - env.state.Set(snap) + env.state.RevertToSnapshot(snap) return err, nil } env.txs = append(env.txs, tx) diff --git a/tests/state_test_util.go b/tests/state_test_util.go index 67e4bf832e..3c4b42a184 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -95,14 +95,7 @@ func BenchStateTest(ruleSet RuleSet, p string, conf bconf, b *testing.B) error { func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) { b.StopTimer() db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) - for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) - statedb.SetStateObject(obj) - for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.HexToHash(v)) - } - } + statedb := makePreState(db, test.Pre) b.StartTimer() RunState(ruleSet, statedb, env, test.Exec) @@ -134,14 +127,7 @@ func runStateTests(ruleSet RuleSet, tests map[string]VmTest, skipTests []string) func runStateTest(ruleSet RuleSet, test VmTest) error { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) - for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) - statedb.SetStateObject(obj) - for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.HexToHash(v)) - } - } + statedb := makePreState(db, test.Pre) // XXX Yeah, yeah... env := make(map[string]string) @@ -227,7 +213,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string } // Set pre compiled contracts vm.Precompiled = vm.PrecompiledContracts() - snapshot := statedb.Copy() + snapshot := statedb.Snapshot() gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"])) key, _ := hex.DecodeString(tx["secretKey"]) @@ -237,7 +223,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string vmenv.origin = addr ret, _, err := core.ApplyMessage(vmenv, message, gaspool) if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) { - statedb.Set(snapshot) + statedb.RevertToSnapshot(snapshot) } statedb.Commit() diff --git a/tests/util.go b/tests/util.go index ffbcb9d56d..8a9d09213e 100644 --- a/tests/util.go +++ b/tests/util.go @@ -103,19 +103,25 @@ func (self Log) Topics() [][]byte { return t } -func StateObjectFromAccount(db ethdb.Database, addr string, account Account, onDirty func(common.Address)) *state.StateObject { +func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB { + statedb, _ := state.New(common.Hash{}, db) + for addr, account := range accounts { + insertAccount(statedb, addr, account) + } + return statedb +} + +func insertAccount(state *state.StateDB, saddr string, account Account) { if common.IsHex(account.Code) { account.Code = account.Code[2:] } - code := common.Hex2Bytes(account.Code) - codeHash := crypto.Keccak256Hash(code) - obj := state.NewObject(common.HexToAddress(addr), state.Account{ - Balance: common.Big(account.Balance), - CodeHash: codeHash[:], - Nonce: common.Big(account.Nonce).Uint64(), - }, onDirty) - obj.SetCode(codeHash, code) - return obj + addr := common.HexToAddress(saddr) + state.SetCode(addr, common.Hex2Bytes(account.Code)) + state.SetNonce(addr, common.Big(account.Nonce).Uint64()) + state.SetBalance(addr, common.Big(account.Balance)) + for a, v := range account.Storage { + state.SetState(addr, common.HexToHash(a), common.HexToHash(v)) + } } type VmEnv struct { @@ -229,11 +235,11 @@ func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { return self.state.GetBalance(from).Cmp(balance) >= 0 } -func (self *Env) MakeSnapshot() vm.Database { - return self.state.Copy() +func (self *Env) SnapshotDatabase() int { + return self.state.Snapshot() } -func (self *Env) SetSnapshot(copy vm.Database) { - self.state.Set(copy.(*state.StateDB)) +func (self *Env) RevertToSnapshot(snapshot int) { + self.state.RevertToSnapshot(snapshot) } func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { diff --git a/tests/vm_test_util.go b/tests/vm_test_util.go index 4ad72d91c1..c269f21e09 100644 --- a/tests/vm_test_util.go +++ b/tests/vm_test_util.go @@ -101,14 +101,7 @@ func BenchVmTest(p string, conf bconf, b *testing.B) error { func benchVmTest(test VmTest, env map[string]string, b *testing.B) { b.StopTimer() db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) - for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) - statedb.SetStateObject(obj) - for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.HexToHash(v)) - } - } + statedb := makePreState(db, test.Pre) b.StartTimer() RunVm(statedb, env, test.Exec) @@ -152,14 +145,7 @@ func runVmTests(tests map[string]VmTest, skipTests []string) error { func runVmTest(test VmTest) error { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) - for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) - statedb.SetStateObject(obj) - for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.HexToHash(v)) - } - } + statedb := makePreState(db, test.Pre) // XXX Yeah, yeah... env := make(map[string]string) From 90fce8bfa621f8c3be6663d62740783949111ff1 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Wed, 5 Oct 2016 22:22:31 +0200 Subject: [PATCH 2/3] core/state: rename Delete/IsDeleted to Suicide/HasSuicided The delete/remove naming has caused endless confusion in the past. --- core/state/journal.go | 6 +++--- core/state/state_object.go | 12 ++++++------ core/state/state_test.go | 8 ++++---- core/state/statedb.go | 22 +++++++++++----------- core/state/statedb_test.go | 6 +++--- core/vm/environment.go | 7 +++++-- core/vm/instructions.go | 2 +- core/vm/jit.go | 2 +- core/vm/vm.go | 2 +- 9 files changed, 35 insertions(+), 32 deletions(-) diff --git a/core/state/journal.go b/core/state/journal.go index 540ade6fb7..720c821b93 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -36,7 +36,7 @@ type ( resetObjectChange struct { prev *StateObject } - deleteAccountChange struct { + suicideChange struct { account *common.Address prev bool // whether account had already suicided prevbalance *big.Int @@ -79,10 +79,10 @@ func (ch resetObjectChange) undo(s *StateDB) { s.setStateObject(ch.prev) } -func (ch deleteAccountChange) undo(s *StateDB) { +func (ch suicideChange) undo(s *StateDB) { obj := s.GetStateObject(*ch.account) if obj != nil { - obj.remove = ch.prev + obj.suicided = ch.prev obj.setBalance(ch.prevbalance) } } diff --git a/core/state/state_object.go b/core/state/state_object.go index 31ff9bcd81..a6b6028bc5 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -83,10 +83,10 @@ type StateObject struct { dirtyStorage Storage // Storage entries that need to be flushed to disk // Cache flags. - // When an object is marked for deletion it will be delete from the trie - // during the "update" phase of the state transition + // When an object is marked suicided it will be delete from the trie + // during the "update" phase of the state transition. dirtyCode bool // true if the code was updated - remove bool + suicided bool deleted bool onDirty func(addr common.Address) // Callback method to mark a state object newly dirty } @@ -123,8 +123,8 @@ func (self *StateObject) setError(err error) { } } -func (self *StateObject) markForDeletion() { - self.remove = true +func (self *StateObject) markSuicided() { + self.suicided = true if self.onDirty != nil { self.onDirty(self.Address()) self.onDirty = nil @@ -266,7 +266,7 @@ func (self *StateObject) deepCopy(db *StateDB, onDirty func(addr common.Address) stateObject.code = self.code stateObject.dirtyStorage = self.dirtyStorage.Copy() stateObject.cachedStorage = self.dirtyStorage.Copy() - stateObject.remove = self.remove + stateObject.suicided = self.suicided stateObject.dirtyCode = self.dirtyCode stateObject.deleted = self.deleted return stateObject diff --git a/core/state/state_test.go b/core/state/state_test.go index b86d8b1401..f188bc2715 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -156,7 +156,7 @@ func TestSnapshot2(t *testing.T) { so0.SetBalance(big.NewInt(42)) so0.SetNonce(43) so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) - so0.remove = false + so0.suicided = false so0.deleted = false state.setStateObject(so0) @@ -168,7 +168,7 @@ func TestSnapshot2(t *testing.T) { so1.SetBalance(big.NewInt(52)) so1.SetNonce(53) so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'}) - so1.remove = true + so1.suicided = true so1.deleted = true state.setStateObject(so1) @@ -228,8 +228,8 @@ func compareStateObjects(so0, so1 *StateObject, t *testing.T) { } } - if so0.remove != so1.remove { - t.Fatalf("Remove mismatch: have %v, want %v", so0.remove, so1.remove) + if so0.suicided != so1.suicided { + t.Fatalf("suicided mismatch: have %v, want %v", so0.suicided, so1.suicided) } if so0.deleted != so1.deleted { t.Fatalf("Deleted mismatch: have %v, want %v", so0.deleted, so1.deleted) diff --git a/core/state/statedb.go b/core/state/statedb.go index 4f74302c3a..ec9e9392f0 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -275,10 +275,10 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash { return common.Hash{} } -func (self *StateDB) IsDeleted(addr common.Address) bool { +func (self *StateDB) HasSuicided(addr common.Address) bool { stateObject := self.GetStateObject(addr) if stateObject != nil { - return stateObject.remove + return stateObject.suicided } return false } @@ -322,22 +322,22 @@ func (self *StateDB) SetState(addr common.Address, key common.Hash, value common } } -// Delete marks the given account as suicided. +// Suicide marks the given account as suicided. // This clears the account balance. // // The account's state object is still available until the state is committed, -// GetStateObject will return a non-nil account after Delete. -func (self *StateDB) Delete(addr common.Address) bool { +// GetStateObject will return a non-nil account after Suicide. +func (self *StateDB) Suicide(addr common.Address) bool { stateObject := self.GetStateObject(addr) if stateObject == nil { return false } - self.journal = append(self.journal, deleteAccountChange{ + self.journal = append(self.journal, suicideChange{ account: &addr, - prev: stateObject.remove, + prev: stateObject.suicided, prevbalance: new(big.Int).Set(stateObject.Balance()), }) - stateObject.markForDeletion() + stateObject.markSuicided() stateObject.data.Balance = new(big.Int) return true } @@ -516,7 +516,7 @@ func (self *StateDB) GetRefund() *big.Int { func (s *StateDB) IntermediateRoot() common.Hash { for addr, _ := range s.stateObjectsDirty { stateObject := s.stateObjects[addr] - if stateObject.remove { + if stateObject.suicided { s.deleteStateObject(stateObject) } else { stateObject.updateRoot(s.db) @@ -542,7 +542,7 @@ func (s *StateDB) DeleteSuicides() { // If the object has been removed by a suicide // flag the object as deleted. - if stateObject.remove { + if stateObject.suicided { stateObject.deleted = true } delete(s.stateObjectsDirty, addr) @@ -575,7 +575,7 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) // Commit objects to the trie. for addr, stateObject := range s.stateObjects { - if stateObject.remove { + if stateObject.suicided { // If the object has been removed, don't bother syncing it // and just mark it for deletion in the trie. s.deleteStateObject(stateObject) diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index e236cb8f33..5d041c7400 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -202,9 +202,9 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { }, }, { - name: "Delete", + name: "Suicide", fn: func(a testAction, s *StateDB) { - s.Delete(addr) + s.Suicide(addr) }, }, { @@ -323,7 +323,7 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { } // Check basic accessor methods. checkeq("Exist", state.Exist(addr), checkstate.Exist(addr)) - checkeq("IsDeleted", state.IsDeleted(addr), checkstate.IsDeleted(addr)) + checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr)) checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr)) checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr)) checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) diff --git a/core/vm/environment.go b/core/vm/environment.go index 1038e69d57..a4b2ac1966 100644 --- a/core/vm/environment.go +++ b/core/vm/environment.go @@ -105,9 +105,12 @@ type Database interface { GetState(common.Address, common.Hash) common.Hash SetState(common.Address, common.Hash, common.Hash) - Delete(common.Address) bool + Suicide(common.Address) bool + HasSuicided(common.Address) bool + + // Exist reports whether the given account exists in state. + // Notably this should also return true for suicided accounts. Exist(common.Address) bool - IsDeleted(common.Address) bool } // Account represents a contract or basic ethereum account. diff --git a/core/vm/instructions.go b/core/vm/instructions.go index 849a8463cc..79aee60d28 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -614,7 +614,7 @@ func opSuicide(instr instruction, pc *uint64, env Environment, contract *Contrac balance := env.Db().GetBalance(contract.Address()) env.Db().AddBalance(common.BigToAddress(stack.pop()), balance) - env.Db().Delete(contract.Address()) + env.Db().Suicide(contract.Address()) } // following functions are used by the instruction jump table diff --git a/core/vm/jit.go b/core/vm/jit.go index 460a68dddd..55d2e04775 100644 --- a/core/vm/jit.go +++ b/core/vm/jit.go @@ -425,7 +425,7 @@ func jitCalculateGasAndSize(env Environment, contract *Contract, instr instructi } gas.Set(g) case SUICIDE: - if !statedb.IsDeleted(contract.Address()) { + if !statedb.HasSuicided(contract.Address()) { statedb.AddRefund(params.SuicideRefundGas) } case MLOAD: diff --git a/core/vm/vm.go b/core/vm/vm.go index 5d78b4a2ad..033ada21c5 100644 --- a/core/vm/vm.go +++ b/core/vm/vm.go @@ -303,7 +303,7 @@ func calculateGasAndSize(env Environment, contract *Contract, caller ContractRef } gas.Set(g) case SUICIDE: - if !statedb.IsDeleted(contract.Address()) { + if !statedb.HasSuicided(contract.Address()) { statedb.AddRefund(params.SuicideRefundGas) } case MLOAD: From 3c836dd71b192de24774b1848173a4eb0ca9a63b Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Wed, 5 Oct 2016 22:56:07 +0200 Subject: [PATCH 3/3] core/state: optimize GetState There is no need to use the reflection-based decoder to decode []byte. --- core/state/state_object.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/core/state/state_object.go b/core/state/state_object.go index a6b6028bc5..6eab27d9e7 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -153,10 +153,13 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash return value } // Load from DB in case it is missing. - tr := self.getTrie(db) - var ret []byte - rlp.DecodeBytes(tr.Get(key[:]), &ret) - value = common.BytesToHash(ret) + if enc := self.getTrie(db).Get(key[:]); len(enc) > 0 { + _, content, _, err := rlp.Split(enc) + if err != nil { + self.setError(err) + } + value.SetBytes(content) + } if (value != common.Hash{}) { self.cachedStorage[key] = value } @@ -209,7 +212,6 @@ func (self *StateObject) updateRoot(db trie.Database) { func (self *StateObject) CommitTrie(db trie.Database, dbw trie.DatabaseWriter) error { self.updateTrie(db) if self.dbErr != nil { - fmt.Println("dbErr:", self.dbErr) return self.dbErr } root, err := self.trie.CommitTo(dbw)