From 81fa3e6b76600a1482d7c8a50bb2bc4b7b5b5294 Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Wed, 6 Nov 2024 13:48:36 +0100 Subject: [PATCH] core/state: simplify journal api, remove externally visible snapshot ids cmd/evm: post-rebase fixup --- cmd/evm/internal/t8ntool/execution.go | 8 +- cmd/evm/runner.go | 2 - core/state/journal.go | 12 +-- core/state/journal_linear.go | 62 +++++------ core/state/journal_set.go | 49 +++------ core/state/journal_test.go | 44 ++++---- core/state/state_test.go | 56 +++++----- core/state/statedb.go | 24 ++--- core/state/statedb_hooked.go | 12 +-- core/state/statedb_test.go | 101 +++++++++--------- core/vm/evm.go | 34 +++--- core/vm/interface.go | 17 ++- .../internal/tracetest/calltrace_test.go | 4 +- eth/tracers/tracers_test.go | 4 +- miner/worker.go | 8 +- tests/state_test.go | 4 +- tests/state_test_util.go | 6 +- 17 files changed, 206 insertions(+), 241 deletions(-) diff --git a/cmd/evm/internal/t8ntool/execution.go b/cmd/evm/internal/t8ntool/execution.go index 0f0fe79841..c6bb25e0a2 100644 --- a/cmd/evm/internal/t8ntool/execution.go +++ b/cmd/evm/internal/t8ntool/execution.go @@ -253,16 +253,16 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, statedb.SetTxContext(tx.Hash(), txIndex) var ( - snapshot = statedb.Snapshot() - prevGas = gaspool.Gas() + prevGas = gaspool.Gas() ) + statedb.Snapshot() if tracer != nil && tracer.OnTxStart != nil { tracer.OnTxStart(evm.GetVMContext(), tx, msg.From) } // (ret []byte, usedGas uint64, failed bool, err error) msgResult, err := core.ApplyMessage(evm, msg, gaspool) if err != nil { - statedb.RevertToSnapshot(snapshot) + statedb.RevertSnapshot() log.Info("rejected tx", "index", i, "hash", tx.Hash(), "from", msg.From, "error", err) rejectedTxs = append(rejectedTxs, &rejectedTx{i, err.Error()}) gaspool.SetGas(prevGas) @@ -276,7 +276,7 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, } continue } - statedb.DiscardSnapshot(snapshot) + statedb.DiscardSnapshot() includedTxs = append(includedTxs, tx) if hashError != nil { return nil, nil, nil, NewError(ErrorMissingBlockhash, hashError) diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index 80fab19f1d..3515743c01 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -230,8 +230,6 @@ func runCmd(ctx *cli.Context) error { sdb := state.NewDatabase(triedb, nil) prestate, _ = state.New(genesis.Root(), sdb) chainConfig = genesisConfig.Config - id := statedb.Snapshot() - defer statedb.DiscardSnapshot(id) if ctx.String(SenderFlag.Name) != "" { sender = common.HexToAddress(ctx.String(SenderFlag.Name)) } diff --git a/core/state/journal.go b/core/state/journal.go index c1425c3b39..1f5f056aec 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -22,7 +22,7 @@ import ( ) type journal interface { - // snapshot returns an identifier for the current revision of the state. + // snapshot starts a new journal scope which can be reverted or discarded. // The lifeycle of journalling is as follows: // - snapshot() starts a 'scope'. // - The method snapshot() may be called any number of times. @@ -30,15 +30,15 @@ type journal interface { // the scope via either of: // - revertToSnapshot, which undoes the changes in the scope, or // - discardSnapshot, which discards the ability to revert the changes in the scope. - snapshot() int + snapshot() - // revertToSnapshot reverts all state changes made since the given revision. - revertToSnapshot(revid int, s *StateDB) + // revertSnapshot reverts all state changes made since the last call to snapshot(). + revertSnapshot(s *StateDB) - // discardSnapshot removes the snapshot with the given id; after calling this + // discardSnapshot removes the latest snapshot; after calling this // method, it is no longer possible to revert to that particular snapshot, the // changes are considered part of the parent scope. - discardSnapshot(revid int) + discardSnapshot() // reset clears the journal so it can be reused. reset() diff --git a/core/state/journal_linear.go b/core/state/journal_linear.go index 61f73ad8aa..b259d9e841 100644 --- a/core/state/journal_linear.go +++ b/core/state/journal_linear.go @@ -17,10 +17,8 @@ package state import ( - "fmt" "maps" "slices" - "sort" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -28,11 +26,6 @@ import ( "github.com/holiman/uint256" ) -type revision struct { - id int - journalIndex int -} - // journalEntry is a modification entry in the state change linear journal that can be // reverted on demand. type journalEntry interface { @@ -53,8 +46,7 @@ type linearJournal struct { entries []journalEntry // Current changes tracked by the linearJournal dirties map[common.Address]int // Dirty accounts and the number of changes - validRevisions []revision - nextRevisionId int + revisions []int // sequence of indexes to points in time designating snapshots } // compile-time interface check @@ -72,9 +64,8 @@ func newLinearJournal() *linearJournal { // can be reused. func (j *linearJournal) reset() { j.entries = j.entries[:0] - j.validRevisions = j.validRevisions[:0] + j.revisions = j.revisions[:0] clear(j.dirties) - j.nextRevisionId = 0 } func (j linearJournal) dirtyAccounts() []common.Address { @@ -86,33 +77,33 @@ func (j linearJournal) dirtyAccounts() []common.Address { return dirty } -// snapshot returns an identifier for the current revision of the state. -func (j *linearJournal) snapshot() int { - id := j.nextRevisionId - j.nextRevisionId++ - j.validRevisions = append(j.validRevisions, revision{id, j.length()}) - return id +// snapshot starts a new journal scope which can be reverted or discarded. +func (j *linearJournal) snapshot() { + j.revisions = append(j.revisions, len(j.entries)) } -func (j *linearJournal) revertToSnapshot(revid int, s *StateDB) { - // Find the snapshot in the stack of valid snapshots. - idx := sort.Search(len(j.validRevisions), func(i int) bool { - return j.validRevisions[i].id >= revid - }) - if idx == len(j.validRevisions) || j.validRevisions[idx].id != revid { - panic(fmt.Errorf("revision id %v cannot be reverted (valid revisions: %d)", revid, len(j.validRevisions))) - } - snapshot := j.validRevisions[idx].journalIndex - +// revertSnapshot reverts all state changes made since the last call to snapshot(). +func (j *linearJournal) revertSnapshot(s *StateDB) { + id := len(j.revisions) - 1 + revision := j.revisions[id] // Replay the linearJournal to undo changes and remove invalidated snapshots - j.revert(s, snapshot) - j.validRevisions = j.validRevisions[:idx] + j.revertTo(s, revision) + j.revisions = j.revisions[:id] } -// discardSnapshot removes the snapshot with the given id; after calling this +// discardSnapshot removes the latest snapshot; after calling this // method, it is no longer possible to revert to that particular snapshot, the // changes are considered part of the parent scope. -func (j *linearJournal) discardSnapshot(id int) { +func (j *linearJournal) discardSnapshot() { + id := len(j.revisions) - 1 + if id == 0 { + // If a transaction is applied successfully, the statedb.Finalize will + // end by clearing and resetting the journal. Invoking a discardSnapshot + // afterwards will land here: calling discard on an empty journal. + // This is fine + return + } + j.revisions = j.revisions[:id] } // append inserts a new modification entry to the end of the change linearJournal. @@ -125,7 +116,7 @@ func (j *linearJournal) append(entry journalEntry) { // revert undoes a batch of journalled modifications along with any reverted // dirty handling too. -func (j *linearJournal) revert(statedb *StateDB, snapshot int) { +func (j *linearJournal) revertTo(statedb *StateDB, snapshot int) { for i := len(j.entries) - 1; i >= snapshot; i-- { // Undo the changes made by the operation j.entries[i].revert(statedb) @@ -159,10 +150,9 @@ func (j *linearJournal) copy() journal { entries = append(entries, j.entries[i].copy()) } return &linearJournal{ - entries: entries, - dirties: maps.Clone(j.dirties), - validRevisions: slices.Clone(j.validRevisions), - nextRevisionId: j.nextRevisionId, + entries: entries, + dirties: maps.Clone(j.dirties), + revisions: slices.Clone(j.revisions), } } diff --git a/core/state/journal_set.go b/core/state/journal_set.go index b09896092c..5b0cf6ad85 100644 --- a/core/state/journal_set.go +++ b/core/state/journal_set.go @@ -18,13 +18,11 @@ package state import ( "bytes" - "fmt" "maps" "slices" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/log" "github.com/holiman/uint256" ) @@ -349,50 +347,33 @@ func (j *sparseJournal) copy() journal { return cp } -// snapshot returns an identifier for the current revision of the state. +// snapshot starts a new journal scope which can be reverted or discarded. // OBS: A call to Snapshot is _required_ in order to initialize the journalling, // invoking the journal-methods without having invoked Snapshot will lead to // panic. -func (j *sparseJournal) snapshot() int { - id := len(j.entries) +func (j *sparseJournal) snapshot() { j.entries = append(j.entries, newScopedJournal()) - return id } -// revertToSnapshot reverts all state changes made since the given revision. -func (j *sparseJournal) revertToSnapshot(id int, s *StateDB) { - if id >= len(j.entries) { - panic(fmt.Errorf("revision id %v cannot be reverted", id)) - } - // Revert the entries sequentially - for i := len(j.entries) - 1; i >= id; i-- { - entry := j.entries[i] - entry.revert(s) - } +// revertSnapshot reverts all state changes made since the last call to snapshot(). +func (j *sparseJournal) revertSnapshot(s *StateDB) { + id := len(j.entries) - 1 + j.entries[id].revert(s) j.entries = j.entries[:id] } -// discardSnapshot removes the snapshot with the given id; after calling this +// discardSnapshot removes the latest snapshot; after calling this // method, it is no longer possible to revert to that particular snapshot, the // changes are considered part of the parent scope. -func (j *sparseJournal) discardSnapshot(id int) { - if id == 0 { - return - } +func (j *sparseJournal) discardSnapshot() { + id := len(j.entries) - 1 // here we must merge the 'id' with it's parent. - want := len(j.entries) - 1 - have := id - if want != have { - if want == 0 && id == 1 { - // If a transcation is applied successfully, the statedb.Finalize will - // end by clearing and resetting the journal. Invoking a discardSnapshot - // afterwards will lead us here. - // Let's not panic, but it's ok to complain a bit - log.Error("Extraneous invocation to discard snapshot") - return - } else { - panic(fmt.Sprintf("journalling error, want discard(%d), have discard(%d)", want, have)) - } + if id == 0 { + // If a transaction is applied successfully, the statedb.Finalize will + // end by clearing and resetting the journal. Invoking a discardSnapshot + // afterwards will land here: calling discard on an empty journal. + // This is fine + return } entry := j.entries[id] parent := j.entries[id-1] diff --git a/core/state/journal_test.go b/core/state/journal_test.go index 803f3d8b0a..5f272c9bf4 100644 --- a/core/state/journal_test.go +++ b/core/state/journal_test.go @@ -75,23 +75,23 @@ func testJournalAccessList(t *testing.T, j journal) { statedb.accessList = newAccessList() statedb.journal = j + j.snapshot() { // If the journal performs the rollback in the wrong order, this // will cause a panic. - id := j.snapshot() statedb.AddSlotToAccessList(common.Address{0x1}, common.Hash{0x4}) statedb.AddSlotToAccessList(common.Address{0x3}, common.Hash{0x4}) - statedb.RevertToSnapshot(id) } + statedb.RevertSnapshot() + j.snapshot() { - id := j.snapshot() statedb.AddAddressToAccessList(common.Address{0x2}) statedb.AddAddressToAccessList(common.Address{0x3}) statedb.AddAddressToAccessList(common.Address{0x4}) - statedb.RevertToSnapshot(id) - if statedb.accessList.ContainsAddress(common.Address{0x2}) { - t.Fatal("should be missing") - } + } + statedb.RevertSnapshot() + if statedb.accessList.ContainsAddress(common.Address{0x2}) { + t.Fatal("should be missing") } } @@ -107,25 +107,27 @@ func testJournalRefunds(t *testing.T, j journal) { var statedb = &StateDB{} statedb.accessList = newAccessList() statedb.journal = j - zero := j.snapshot() - j.refundChange(0) - j.refundChange(1) + j.snapshot() { - id := j.snapshot() - j.refundChange(2) - j.refundChange(3) - j.revertToSnapshot(id, statedb) + j.refundChange(0) + j.refundChange(1) + j.snapshot() + { + j.refundChange(2) + j.refundChange(3) + } + j.revertSnapshot(statedb) if have, want := statedb.refund, uint64(2); have != want { t.Fatalf("have %d want %d", have, want) } + j.snapshot() + { + j.refundChange(2) + j.refundChange(3) + } + j.discardSnapshot() } - { - id := j.snapshot() - j.refundChange(2) - j.refundChange(3) - j.discardSnapshot(id) - } - j.revertToSnapshot(zero, statedb) + j.revertSnapshot(statedb) if have, want := statedb.refund, uint64(0); have != want { t.Fatalf("have %d want %d", have, want) } diff --git a/core/state/state_test.go b/core/state/state_test.go index 6f54300c37..510ebbd628 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -160,25 +160,26 @@ func TestSnapshot(t *testing.T) { s := newStateEnv() // snapshot the genesis state - genesis := s.state.Snapshot() + s.state.Snapshot() + { + // set initial state object value + s.state.SetState(stateobjaddr, storageaddr, data1) + s.state.Snapshot() + { + // set a new state object value, revert it and ensure correct content + s.state.SetState(stateobjaddr, storageaddr, data2) + } + s.state.RevertSnapshot() - // set initial state object value - s.state.SetState(stateobjaddr, storageaddr, data1) - snapshot := s.state.Snapshot() - - // set a new state object value, revert it and ensure correct content - s.state.SetState(stateobjaddr, storageaddr, data2) - s.state.RevertToSnapshot(snapshot) - - if v := s.state.GetState(stateobjaddr, storageaddr); v != data1 { - t.Errorf("wrong storage value %v, want %v", v, data1) + if v := s.state.GetState(stateobjaddr, storageaddr); v != data1 { + t.Errorf("wrong storage value %v, want %v", v, data1) + } + if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { + t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{}) + } } - if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { - t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{}) - } - // revert up to the genesis state and ensure correct content - s.state.RevertToSnapshot(genesis) + s.state.RevertSnapshot() if v := s.state.GetState(stateobjaddr, storageaddr); v != (common.Hash{}) { t.Errorf("wrong storage value %v, want %v", v, common.Hash{}) } @@ -189,22 +190,23 @@ func TestSnapshot(t *testing.T) { func TestSnapshotEmpty(t *testing.T) { s := newStateEnv() - s.state.RevertToSnapshot(s.state.Snapshot()) + s.state.Snapshot() + s.state.RevertSnapshot() } func TestCreateObjectRevert(t *testing.T) { state, _ := New(types.EmptyRootHash, NewDatabaseForTesting()) addr := common.BytesToAddress([]byte("so0")) - snap := state.Snapshot() - - state.CreateAccount(addr) - so0 := state.getStateObject(addr) - so0.SetBalance(uint256.NewInt(42)) - so0.SetNonce(43) - so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) - state.setStateObject(so0) - - state.RevertToSnapshot(snap) + state.Snapshot() + { + state.CreateAccount(addr) + so0 := state.getStateObject(addr) + so0.SetBalance(uint256.NewInt(42)) + so0.SetNonce(43) + so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) + state.setStateObject(so0) + } + state.RevertSnapshot() if state.Exist(addr) { t.Error("Unexpected account after revert") } diff --git a/core/state/statedb.go b/core/state/statedb.go index 9f73135721..b1cf6beca8 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -132,7 +132,7 @@ type StateDB struct { transientStorage transientStorage // Journal of state modifications. This is the backbone of - // Snapshot and RevertToSnapshot. + // Snapshot and RevertSnapshot. journal journal // State witness if cross validation is needed @@ -708,21 +708,21 @@ func (s *StateDB) Copy() *StateDB { return state } -// Snapshot returns an identifier for the current revision of the state. -func (s *StateDB) Snapshot() int { - return s.journal.snapshot() +// Snapshot starts a new journalled scope. +func (s *StateDB) Snapshot() { + s.journal.snapshot() } -// DiscardSnapshot removes the snapshot with the given id; after calling this -// method, it is no longer possible to revert to that particular snapshot, the -// changes are considered part of the parent scope. -func (s *StateDB) DiscardSnapshot(id int) { - s.journal.discardSnapshot(id) +// DiscardSnapshot removes the ability to roll back the changes in the most +// recent journalled scope. After calling this method, the changes are considered +// part of the parent scope. +func (s *StateDB) DiscardSnapshot() { + s.journal.discardSnapshot() } -// RevertToSnapshot reverts all state changes made since the given revision. -func (s *StateDB) RevertToSnapshot(revid int) { - s.journal.revertToSnapshot(revid, s) +// RevertSnapshot reverts all state changes made in the most recent journalled scope. +func (s *StateDB) RevertSnapshot() { + s.journal.revertSnapshot(s) } // GetRefund returns the current value of the refund counter. diff --git a/core/state/statedb_hooked.go b/core/state/statedb_hooked.go index 8b84f96014..0db120c4b6 100644 --- a/core/state/statedb_hooked.go +++ b/core/state/statedb_hooked.go @@ -141,16 +141,16 @@ func (s *hookedStateDB) Prepare(rules params.Rules, sender, coinbase common.Addr s.inner.Prepare(rules, sender, coinbase, dest, precompiles, txAccesses) } -func (s *hookedStateDB) DiscardSnapshot(id int) { - s.inner.DiscardSnapshot(id) +func (s *hookedStateDB) DiscardSnapshot() { + s.inner.DiscardSnapshot() } -func (s *hookedStateDB) RevertToSnapshot(i int) { - s.inner.RevertToSnapshot(i) +func (s *hookedStateDB) RevertSnapshot() { + s.inner.RevertSnapshot() } -func (s *hookedStateDB) Snapshot() int { - return s.inner.Snapshot() +func (s *hookedStateDB) Snapshot() { + s.inner.Snapshot() } func (s *hookedStateDB) AddPreimage(hash common.Hash, bytes []byte) { diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index e20a332e78..a48cb63daa 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -563,14 +563,13 @@ func (test *snapshotTest) String() string { func (test *snapshotTest) run() bool { // Run all actions and create snapshots. var ( - state, _ = New(types.EmptyRootHash, NewDatabaseForTesting()) - snapshotRevs = make([]int, len(test.snapshots)) - sindex = 0 - checkstates = make([]*StateDB, len(test.snapshots)) + state, _ = New(types.EmptyRootHash, NewDatabaseForTesting()) + sindex = 0 + checkstates = make([]*StateDB, len(test.snapshots)) ) for i, action := range test.actions { if len(test.snapshots) > sindex && i == test.snapshots[sindex] { - snapshotRevs[sindex] = state.Snapshot() + state.Snapshot() checkstates[sindex] = state.Copy() sindex++ } @@ -579,7 +578,7 @@ func (test *snapshotTest) run() bool { // Revert all snapshots in reverse order. Each revert must yield a state // that is equivalent to fresh state with all actions up the snapshot applied. for sindex--; sindex >= 0; sindex-- { - state.RevertToSnapshot(snapshotRevs[sindex]) + state.RevertSnapshot() if err := test.checkEqual(state, checkstates[sindex]); err != nil { test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err) return false @@ -734,13 +733,13 @@ func TestTouchDelete(t *testing.T) { root, _ := s.state.Commit(0, false) s.state, _ = New(root, s.state.db) - snapshot := s.state.Snapshot() + s.state.Snapshot() s.state.AddBalance(common.Address{}, new(uint256.Int), tracing.BalanceChangeUnspecified) if len(s.state.journal.dirtyAccounts()) != 1 { t.Fatal("expected one dirty state object") } - s.state.RevertToSnapshot(snapshot) + s.state.RevertSnapshot() if len(s.state.journal.dirtyAccounts()) != 0 { t.Fatal("expected no dirty state object") } @@ -985,9 +984,11 @@ func TestDeleteCreateRevert(t *testing.T) { state.SelfDestruct(addr) state.Finalise(true) - id := state.Snapshot() - state.SetBalance(addr, uint256.NewInt(2), tracing.BalanceChangeUnspecified) - state.RevertToSnapshot(id) + state.Snapshot() + { + state.SetBalance(addr, uint256.NewInt(2), tracing.BalanceChangeUnspecified) + } + state.RevertSnapshot() // Commit the entire state and make sure we don't crash and have the correct state root, _ = state.Commit(0, true) @@ -1126,25 +1127,15 @@ func TestStateDBAccessList(t *testing.T) { } } - var ids []int - push := func(id int) { - ids = append(ids, id) - } - pop := func() int { - id := ids[len(ids)-1] - ids = ids[:len(ids)-1] - return id - } - - push(state.journal.snapshot()) // journal id 0 + state.journal.snapshot() // journal id 0 state.AddAddressToAccessList(addr("aa")) // 1 - push(state.journal.snapshot()) // journal id 1 + state.journal.snapshot() // journal id 1 state.AddAddressToAccessList(addr("bb")) // 2 - push(state.journal.snapshot()) // journal id 2 + state.journal.snapshot() // journal id 2 state.AddSlotToAccessList(addr("bb"), slot("01")) // 3 - push(state.journal.snapshot()) // journal id 3 + state.journal.snapshot() // journal id 3 state.AddSlotToAccessList(addr("bb"), slot("02")) // 4 - push(state.journal.snapshot()) // journal id 4 + state.journal.snapshot() // journal id 4 verifyAddrs("aa", "bb") verifySlots("bb", "01", "02") @@ -1158,11 +1149,11 @@ func TestStateDBAccessList(t *testing.T) { // some new ones state.AddSlotToAccessList(addr("bb"), slot("03")) // 5 - push(state.journal.snapshot()) // journal id 5 + state.journal.snapshot() // journal id 5 state.AddSlotToAccessList(addr("aa"), slot("01")) // 6 - push(state.journal.snapshot()) // journal id 6 + state.journal.snapshot() // journal id 6 state.AddAddressToAccessList(addr("cc")) // 7 - push(state.journal.snapshot()) // journal id 7 + state.journal.snapshot() // journal id 7 state.AddSlotToAccessList(addr("cc"), slot("01")) // 8 verifyAddrs("aa", "bb", "cc") @@ -1171,7 +1162,7 @@ func TestStateDBAccessList(t *testing.T) { verifySlots("cc", "01") // now start rolling back changes - state.journal.revertToSnapshot(pop(), state) // revert to 6 + state.journal.revertSnapshot(state) // revert to 6 if _, ok := state.SlotInAccessList(addr("cc"), slot("01")); ok { t.Fatalf("slot present, expected missing") } @@ -1179,7 +1170,7 @@ func TestStateDBAccessList(t *testing.T) { verifySlots("aa", "01") verifySlots("bb", "01", "02", "03") - state.journal.revertToSnapshot(pop(), state) // revert to 5 + state.journal.revertSnapshot(state) // revert to 5 if state.AddressInAccessList(addr("cc")) { t.Fatalf("addr present, expected missing") } @@ -1187,40 +1178,40 @@ func TestStateDBAccessList(t *testing.T) { verifySlots("aa", "01") verifySlots("bb", "01", "02", "03") - state.journal.revertToSnapshot(pop(), state) // revert to 4 + state.journal.revertSnapshot(state) // revert to 4 if _, ok := state.SlotInAccessList(addr("aa"), slot("01")); ok { t.Fatalf("slot present, expected missing") } verifyAddrs("aa", "bb") verifySlots("bb", "01", "02", "03") - state.journal.revertToSnapshot(pop(), state) // revert to 3 + state.journal.revertSnapshot(state) // revert to 3 if _, ok := state.SlotInAccessList(addr("bb"), slot("03")); ok { t.Fatalf("slot present, expected missing") } verifyAddrs("aa", "bb") verifySlots("bb", "01", "02") - state.journal.revertToSnapshot(pop(), state) // revert to 2 + state.journal.revertSnapshot(state) // revert to 2 if _, ok := state.SlotInAccessList(addr("bb"), slot("02")); ok { t.Fatalf("slot present, expected missing") } verifyAddrs("aa", "bb") verifySlots("bb", "01") - state.journal.revertToSnapshot(pop(), state) // revert to 1 + state.journal.revertSnapshot(state) // revert to 1 if _, ok := state.SlotInAccessList(addr("bb"), slot("01")); ok { t.Fatalf("slot present, expected missing") } verifyAddrs("aa", "bb") - state.journal.revertToSnapshot(pop(), state) // revert to 0 + state.journal.revertSnapshot(state) // revert to 0 if state.AddressInAccessList(addr("bb")) { t.Fatalf("addr present, expected missing") } verifyAddrs("aa") - state.journal.revertToSnapshot(0, state) + state.journal.revertSnapshot(state) if state.AddressInAccessList(addr("aa")) { t.Fatalf("addr present, expected missing") } @@ -1291,7 +1282,7 @@ func TestStateDBTransientStorage(t *testing.T) { key := common.Hash{0x01} value := common.Hash{0x02} addr := common.Address{} - revision := state.journal.snapshot() + state.journal.snapshot() state.SetTransientState(addr, key, value) // the retrieved value should equal what was set @@ -1301,7 +1292,7 @@ func TestStateDBTransientStorage(t *testing.T) { // revert the transient state being set and then check that the // value is now the empty hash - state.journal.revertToSnapshot(revision, state) + state.journal.revertSnapshot(state) if got, exp := state.GetTransientState(addr, key), (common.Hash{}); exp != got { t.Fatalf("transient storage mismatch: have %x, want %x", got, exp) } @@ -1395,22 +1386,24 @@ func TestStorageDirtiness(t *testing.T) { checkDirty(common.Hash{0x1}, common.Hash{}, false) // the storage change is valid, dirty marker is expected - snap := state.Snapshot() - state.SetState(addr, common.Hash{0x1}, common.Hash{0x1}) - checkDirty(common.Hash{0x1}, common.Hash{0x1}, true) - + state.Snapshot() + { + state.SetState(addr, common.Hash{0x1}, common.Hash{0x1}) + checkDirty(common.Hash{0x1}, common.Hash{0x1}, true) + } // the storage change is reverted, dirtiness should be revoked - state.RevertToSnapshot(snap) + state.RevertSnapshot() checkDirty(common.Hash{0x1}, common.Hash{}, false) // the storage is reset back to its original value, dirtiness should be revoked state.SetState(addr, common.Hash{0x1}, common.Hash{0x1}) - snap = state.Snapshot() - state.SetState(addr, common.Hash{0x1}, common.Hash{}) - checkDirty(common.Hash{0x1}, common.Hash{}, false) - + state.Snapshot() + { + state.SetState(addr, common.Hash{0x1}, common.Hash{}) + checkDirty(common.Hash{0x1}, common.Hash{}, false) + } // the storage change is reverted, dirty value should be set back - state.RevertToSnapshot(snap) + state.RevertSnapshot() checkDirty(common.Hash{0x1}, common.Hash{0x1}, true) } @@ -1455,10 +1448,12 @@ func TestStorageDirtiness2(t *testing.T) { checkDirty(common.Hash{0x1}, common.Hash{0xa}, false) // Enter new scope - snap := state.Snapshot() - state.SetState(addr, common.Hash{0x1}, common.Hash{0xb}) // SLOT(1) = 0xB - checkDirty(common.Hash{0x1}, common.Hash{0xb}, true) // Should be flagged dirty - state.RevertToSnapshot(snap) // Revert scope + state.Snapshot() + { + state.SetState(addr, common.Hash{0x1}, common.Hash{0xb}) // SLOT(1) = 0xB + checkDirty(common.Hash{0x1}, common.Hash{0xb}, true) // Should be flagged dirty + } + state.RevertSnapshot() // Revert scope // the storage change has been set back to original, dirtiness should be revoked checkDirty(common.Hash{0x1}, common.Hash{0x1}, false) diff --git a/core/vm/evm.go b/core/vm/evm.go index 0f16be4add..38ec2101f1 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -190,7 +190,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas if !value.IsZero() && !evm.Context.CanTransfer(evm.StateDB, caller.Address(), value) { return nil, gas, ErrInsufficientBalance } - snapshot := evm.StateDB.Snapshot() + evm.StateDB.Snapshot() p, isPrecompile := evm.precompile(addr) if !evm.StateDB.Exist(addr) { @@ -198,7 +198,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas // add proof of absence to witness wgas := evm.AccessEvents.AddAccount(addr, false) if gas < wgas { - evm.StateDB.RevertToSnapshot(snapshot) + evm.StateDB.RevertSnapshot() return nil, 0, ErrOutOfGas } gas -= wgas @@ -206,7 +206,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas if !isPrecompile && evm.chainRules.IsEIP158 && value.IsZero() { // Calling a non-existing account, don't do anything. - evm.StateDB.DiscardSnapshot(snapshot) + evm.StateDB.DiscardSnapshot() return nil, gas, nil } evm.StateDB.CreateAccount(addr) @@ -235,7 +235,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas // above we revert to the snapshot and consume any gas remaining. Additionally, // when we're in homestead this also counts for code storage gas errors. if err != nil { - evm.StateDB.RevertToSnapshot(snapshot) + evm.StateDB.RevertSnapshot() if err != ErrExecutionReverted { if evm.Config.Tracer != nil && evm.Config.Tracer.OnGasChange != nil { evm.Config.Tracer.OnGasChange(gas, 0, tracing.GasChangeCallFailedExecution) @@ -244,7 +244,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas gas = 0 } } else { - evm.StateDB.DiscardSnapshot(snapshot) + evm.StateDB.DiscardSnapshot() } return ret, gas, err } @@ -275,7 +275,7 @@ func (evm *EVM) CallCode(caller ContractRef, addr common.Address, input []byte, if !evm.Context.CanTransfer(evm.StateDB, caller.Address(), value) { return nil, gas, ErrInsufficientBalance } - var snapshot = evm.StateDB.Snapshot() + evm.StateDB.Snapshot() // It is allowed to call precompiles, even via delegatecall if p, isPrecompile := evm.precompile(addr); isPrecompile { @@ -290,7 +290,7 @@ func (evm *EVM) CallCode(caller ContractRef, addr common.Address, input []byte, gas = contract.Gas } if err != nil { - evm.StateDB.RevertToSnapshot(snapshot) + evm.StateDB.RevertSnapshot() if err != ErrExecutionReverted { if evm.Config.Tracer != nil && evm.Config.Tracer.OnGasChange != nil { evm.Config.Tracer.OnGasChange(gas, 0, tracing.GasChangeCallFailedExecution) @@ -299,7 +299,7 @@ func (evm *EVM) CallCode(caller ContractRef, addr common.Address, input []byte, gas = 0 } } else { - evm.StateDB.DiscardSnapshot(snapshot) + evm.StateDB.DiscardSnapshot() } return ret, gas, err } @@ -325,7 +325,7 @@ func (evm *EVM) DelegateCall(caller ContractRef, addr common.Address, input []by if evm.depth > int(params.CallCreateDepth) { return nil, gas, ErrDepth } - var snapshot = evm.StateDB.Snapshot() + evm.StateDB.Snapshot() // It is allowed to call precompiles, even via delegatecall if p, isPrecompile := evm.precompile(addr); isPrecompile { @@ -339,7 +339,7 @@ func (evm *EVM) DelegateCall(caller ContractRef, addr common.Address, input []by gas = contract.Gas } if err != nil { - evm.StateDB.RevertToSnapshot(snapshot) + evm.StateDB.RevertSnapshot() if err != ErrExecutionReverted { if evm.Config.Tracer != nil && evm.Config.Tracer.OnGasChange != nil { evm.Config.Tracer.OnGasChange(gas, 0, tracing.GasChangeCallFailedExecution) @@ -347,7 +347,7 @@ func (evm *EVM) DelegateCall(caller ContractRef, addr common.Address, input []by gas = 0 } } else { - evm.StateDB.DiscardSnapshot(snapshot) + evm.StateDB.DiscardSnapshot() } return ret, gas, err } @@ -373,7 +373,7 @@ func (evm *EVM) StaticCall(caller ContractRef, addr common.Address, input []byte // after all empty accounts were deleted, so this is not required. However, if we omit this, // then certain tests start failing; stRevertTest/RevertPrecompiledTouchExactOOG.json. // We could change this, but for now it's left for legacy reasons - var snapshot = evm.StateDB.Snapshot() + evm.StateDB.Snapshot() // We do an AddBalance of zero here, just in order to trigger a touch. // This doesn't matter on Mainnet, where all empties are gone at the time of Byzantium, @@ -399,7 +399,7 @@ func (evm *EVM) StaticCall(caller ContractRef, addr common.Address, input []byte gas = contract.Gas } if err != nil { - evm.StateDB.RevertToSnapshot(snapshot) + evm.StateDB.RevertSnapshot() if err != ErrExecutionReverted { if evm.Config.Tracer != nil && evm.Config.Tracer.OnGasChange != nil { evm.Config.Tracer.OnGasChange(gas, 0, tracing.GasChangeCallFailedExecution) @@ -408,7 +408,7 @@ func (evm *EVM) StaticCall(caller ContractRef, addr common.Address, input []byte gas = 0 } } else { - evm.StateDB.DiscardSnapshot(snapshot) + evm.StateDB.DiscardSnapshot() } return ret, gas, err } @@ -482,7 +482,7 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64, // Create a new account on the state only if the object was not present. // It might be possible the contract code is deployed to a pre-existent // account with non-zero balance. - snapshot := evm.StateDB.Snapshot() + evm.StateDB.Snapshot() if !evm.StateDB.Exist(address) { evm.StateDB.CreateAccount(address) } @@ -516,12 +516,12 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64, ret, err = evm.initNewContract(contract, address, value) if err != nil && (evm.chainRules.IsHomestead || err != ErrCodeStoreOutOfGas) { - evm.StateDB.RevertToSnapshot(snapshot) + evm.StateDB.RevertSnapshot() if err != ErrExecutionReverted { contract.UseGas(contract.Gas, evm.Config.Tracer, tracing.GasChangeCallFailedExecution) } } else { - evm.StateDB.DiscardSnapshot(snapshot) + evm.StateDB.DiscardSnapshot() } return ret, address, contract.Gas, err } diff --git a/core/vm/interface.go b/core/vm/interface.go index 4f5ada46f2..968dc8450b 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -90,15 +90,14 @@ type StateDB interface { Prepare(rules params.Rules, sender, coinbase common.Address, dest *common.Address, precompiles []common.Address, txAccesses types.AccessList) - // RevertToSnapshot reverts all state changes made since the given revision. - RevertToSnapshot(int) - - // DiscardSnapshot removes the snapshot with the given id; after calling this - // method, it is no longer possible to revert to that particular snapshot, the - // changes are considered part of the parent scope. - DiscardSnapshot(int) - // Snapshot returns an identifier for the current scope of the state. - Snapshot() int + // Snapshot starts a new journalled scope. + Snapshot() + // RevertSnapshot reverts all state changes made in the most recent journalled scope. + RevertSnapshot() + // DiscardSnapshot removes the ability to roll back the changes in the most + // recent journalled scope. After calling this method, the changes are considered + // part of the parent scope. + DiscardSnapshot() AddLog(*types.Log) AddPreimage(common.Hash, []byte) diff --git a/eth/tracers/internal/tracetest/calltrace_test.go b/eth/tracers/internal/tracetest/calltrace_test.go index 999ab211c0..20bd1f4721 100644 --- a/eth/tracers/internal/tracetest/calltrace_test.go +++ b/eth/tracers/internal/tracetest/calltrace_test.go @@ -219,7 +219,7 @@ func benchTracer(tracerName string, test *callTracerTest, b *testing.B) { evm := vm.NewEVM(context, state.StateDB, test.Genesis.Config, vm.Config{}) for i := 0; i < b.N; i++ { - snap := state.StateDB.Snapshot() + state.StateDB.Snapshot() tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), nil, test.Genesis.Config) if err != nil { b.Fatalf("failed to create call tracer: %v", err) @@ -238,7 +238,7 @@ func benchTracer(tracerName string, test *callTracerTest, b *testing.B) { if _, err = tracer.GetResult(); err != nil { b.Fatal(err) } - state.StateDB.RevertToSnapshot(snap) + state.StateDB.RevertSnapshot() } } diff --git a/eth/tracers/tracers_test.go b/eth/tracers/tracers_test.go index a72dbf6ee6..31b01ebf16 100644 --- a/eth/tracers/tracers_test.go +++ b/eth/tracers/tracers_test.go @@ -92,11 +92,11 @@ func BenchmarkTransactionTraceV2(b *testing.B) { tracer.OnTxStart(evm.GetVMContext(), tx, msg.From) evm.Config.Tracer = tracer - snap := state.StateDB.Snapshot() + state.StateDB.Snapshot() _, err := core.ApplyMessage(evm, msg, new(core.GasPool).AddGas(tx.Gas())) if err != nil { b.Fatal(err) } - state.StateDB.RevertToSnapshot(snap) + state.StateDB.RevertSnapshot() } } diff --git a/miner/worker.go b/miner/worker.go index b5aa080025..f51789da56 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -302,13 +302,11 @@ func (miner *Miner) commitBlobTransaction(env *environment, tx *types.Transactio // applyTransaction runs the transaction. If execution fails, state and gas pool are reverted. func (miner *Miner) applyTransaction(env *environment, tx *types.Transaction) (*types.Receipt, error) { - var ( - snap = env.state.Snapshot() - gp = env.gasPool.Gas() - ) + gp := env.gasPool.Gas() + env.state.Snapshot() receipt, err := core.ApplyTransaction(env.evm, env.gasPool, env.state, env.header, tx, &env.header.GasUsed) if err != nil { - env.state.RevertToSnapshot(snap) + env.state.RevertSnapshot() env.gasPool.SetGas(gp) } return receipt, err diff --git a/tests/state_test.go b/tests/state_test.go index 7b82b05e58..5e86ca081f 100644 --- a/tests/state_test.go +++ b/tests/state_test.go @@ -316,7 +316,7 @@ func runBenchmark(b *testing.B, t *StateTest) { ) b.ResetTimer() for n := 0; n < b.N; n++ { - snapshot := state.StateDB.Snapshot() + state.StateDB.Snapshot() state.StateDB.Prepare(rules, msg.From, context.Coinbase, msg.To, vm.ActivePrecompiles(rules), msg.AccessList) b.StartTimer() start := time.Now() @@ -333,7 +333,7 @@ func runBenchmark(b *testing.B, t *StateTest) { refund += state.StateDB.GetRefund() gasUsed += msg.GasLimit - leftOverGas - state.StateDB.RevertToSnapshot(snapshot) + state.StateDB.RevertSnapshot() } if elapsed < 1 { elapsed = 1 diff --git a/tests/state_test_util.go b/tests/state_test_util.go index 3dade82b0b..892eb09ca8 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -320,18 +320,18 @@ func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapsh tracer.OnTxStart(evm.GetVMContext(), nil, msg.From) } // Execute the message. - snapshot := st.StateDB.Snapshot() + st.StateDB.Snapshot() gaspool := new(core.GasPool) gaspool.AddGas(block.GasLimit()) vmRet, err := core.ApplyMessage(evm, msg, gaspool) if err != nil { - st.StateDB.RevertToSnapshot(snapshot) + st.StateDB.RevertSnapshot() if tracer := evm.Config.Tracer; tracer != nil && tracer.OnTxEnd != nil { evm.Config.Tracer.OnTxEnd(nil, err) } return st, common.Hash{}, 0, err } - st.StateDB.DiscardSnapshot(snapshot) + st.StateDB.DiscardSnapshot() // Add 0-value mining reward. This only makes a difference in the cases // where // - the coinbase self-destructed, or