From 2c9afab25476d488e7419a06ed4828b53e8f7b8d Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Fri, 19 Apr 2024 09:08:05 +0200 Subject: [PATCH] core/state: separate journal-implementation behind interface, implement createaccount --- core/state/journal.go | 88 +++++++++++++++++------------- core/state/journal_api.go | 69 ++++++++++++++++++++++++ core/state/statedb.go | 9 ++-- core/state/statedb_hooked.go | 2 +- core/state/statedb_test.go | 102 +++++++++++++++++++++-------------- 5 files changed, 186 insertions(+), 84 deletions(-) create mode 100644 core/state/journal_api.go diff --git a/core/state/journal.go b/core/state/journal.go index 13332dbd79..860d742776 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -32,33 +32,36 @@ type revision struct { journalIndex int } -// journalEntry is a modification entry in the state change journal that can be +// journalEntry is a modification entry in the state change linear journal that can be // reverted on demand. type journalEntry interface { - // revert undoes the changes introduced by this journal entry. + // revert undoes the changes introduced by this entry. revert(*StateDB) - // dirtied returns the Ethereum address modified by this journal entry. + // dirtied returns the Ethereum address modified by this entry. dirtied() *common.Address - // copy returns a deep-copied journal entry. + // copy returns a deep-copied entry. copy() journalEntry } -// journal contains the list of state modifications applied since the last state +// linearJournal contains the list of state modifications applied since the last state // commit. These are tracked to be able to be reverted in the case of an execution // exception or request for reversal. -type journal struct { - entries []journalEntry // Current changes tracked by the journal +type linearJournal struct { + entries []journalEntry // Current changes tracked by the linearJournal dirties map[common.Address]int // Dirty accounts and the number of changes validRevisions []revision nextRevisionId int } -// newJournal creates a new initialized journal. -func newJournal() *journal { - return &journal{ +// compile-time interface check +var _ journal = (*linearJournal)(nil) + +// newLinearJournal creates a new initialized linearJournal. +func newLinearJournal() *linearJournal { + return &linearJournal{ dirties: make(map[common.Address]int), } } @@ -66,15 +69,24 @@ func newJournal() *journal { // reset clears the journal, after this operation the journal can be used anew. // It is semantically similar to calling 'newJournal', but the underlying slices // can be reused. -func (j *journal) reset() { +func (j *linearJournal) reset() { j.entries = j.entries[:0] j.validRevisions = j.validRevisions[:0] clear(j.dirties) j.nextRevisionId = 0 } +func (j linearJournal) dirtyAccounts() []common.Address { + dirty := make([]common.Address, 0, len(j.dirties)) + // flatten into list + for addr := range j.dirties { + dirty = append(dirty, addr) + } + return dirty +} + // snapshot returns an identifier for the current revision of the state. -func (j *journal) snapshot() int { +func (j *linearJournal) snapshot() int { id := j.nextRevisionId j.nextRevisionId++ j.validRevisions = append(j.validRevisions, revision{id, j.length()}) @@ -82,23 +94,23 @@ func (j *journal) snapshot() int { } // revertToSnapshot reverts all state changes made since the given revision. -func (j *journal) revertToSnapshot(revid int, s *StateDB) { +func (j *linearJournal) revertToSnapshot(revid int, s *StateDB) { // Find the snapshot in the stack of valid snapshots. idx := sort.Search(len(j.validRevisions), func(i int) bool { return j.validRevisions[i].id >= revid }) if idx == len(j.validRevisions) || j.validRevisions[idx].id != revid { - panic(fmt.Errorf("revision id %v cannot be reverted", revid)) + panic(fmt.Errorf("revision id %v cannot be reverted (valid revisions: %d)", revid, len(j.validRevisions))) } snapshot := j.validRevisions[idx].journalIndex - // Replay the journal to undo changes and remove invalidated snapshots + // Replay the linearJournal to undo changes and remove invalidated snapshots j.revert(s, snapshot) j.validRevisions = j.validRevisions[:idx] } -// append inserts a new modification entry to the end of the change journal. -func (j *journal) append(entry journalEntry) { +// append inserts a new modification entry to the end of the change linearJournal. +func (j *linearJournal) append(entry journalEntry) { j.entries = append(j.entries, entry) if addr := entry.dirtied(); addr != nil { j.dirties[*addr]++ @@ -107,7 +119,7 @@ func (j *journal) append(entry journalEntry) { // revert undoes a batch of journalled modifications along with any reverted // dirty handling too. -func (j *journal) revert(statedb *StateDB, snapshot int) { +func (j *linearJournal) revert(statedb *StateDB, snapshot int) { for i := len(j.entries) - 1; i >= snapshot; i-- { // Undo the changes made by the operation j.entries[i].revert(statedb) @@ -125,22 +137,22 @@ func (j *journal) revert(statedb *StateDB, snapshot int) { // dirty explicitly sets an address to dirty, even if the change entries would // otherwise suggest it as clean. This method is an ugly hack to handle the RIPEMD // precompile consensus exception. -func (j *journal) dirty(addr common.Address) { +func (j *linearJournal) dirty(addr common.Address) { j.dirties[addr]++ } -// length returns the current number of entries in the journal. -func (j *journal) length() int { +// length returns the current number of entries in the linearJournal. +func (j *linearJournal) length() int { return len(j.entries) } // copy returns a deep-copied journal. -func (j *journal) copy() *journal { +func (j *linearJournal) copy() journal { entries := make([]journalEntry, 0, j.length()) for i := 0; i < j.length(); i++ { entries = append(entries, j.entries[i].copy()) } - return &journal{ + return &linearJournal{ entries: entries, dirties: maps.Clone(j.dirties), validRevisions: slices.Clone(j.validRevisions), @@ -148,23 +160,23 @@ func (j *journal) copy() *journal { } } -func (j *journal) logChange(txHash common.Hash) { +func (j *linearJournal) logChange(txHash common.Hash) { j.append(addLogChange{txhash: txHash}) } -func (j *journal) createObject(addr common.Address) { +func (j *linearJournal) createObject(addr common.Address) { j.append(createObjectChange{account: addr}) } -func (j *journal) createContract(addr common.Address) { +func (j *linearJournal) createContract(addr common.Address) { j.append(createContractChange{account: addr}) } -func (j *journal) destruct(addr common.Address) { +func (j *linearJournal) destruct(addr common.Address) { j.append(selfDestructChange{account: addr}) } -func (j *journal) storageChange(addr common.Address, key, prev, origin common.Hash) { +func (j *linearJournal) storageChange(addr common.Address, key, prev, origin common.Hash) { j.append(storageChange{ account: addr, key: key, @@ -173,7 +185,7 @@ func (j *journal) storageChange(addr common.Address, key, prev, origin common.Ha }) } -func (j *journal) transientStateChange(addr common.Address, key, prev common.Hash) { +func (j *linearJournal) transientStateChange(addr common.Address, key, prev common.Hash) { j.append(transientStorageChange{ account: addr, key: key, @@ -181,32 +193,32 @@ func (j *journal) transientStateChange(addr common.Address, key, prev common.Has }) } -func (j *journal) refundChange(previous uint64) { +func (j *linearJournal) refundChange(previous uint64) { j.append(refundChange{prev: previous}) } -func (j *journal) balanceChange(addr common.Address, previous *uint256.Int) { +func (j *linearJournal) balanceChange(addr common.Address, previous *uint256.Int) { j.append(balanceChange{ account: addr, prev: previous.Clone(), }) } -func (j *journal) setCode(address common.Address, prevCode []byte) { +func (j *linearJournal) setCode(address common.Address, prevCode []byte) { j.append(codeChange{ account: address, prevCode: prevCode, }) } -func (j *journal) nonceChange(address common.Address, prev uint64) { +func (j *linearJournal) nonceChange(address common.Address, prev uint64) { j.append(nonceChange{ account: address, prev: prev, }) } -func (j *journal) touchChange(address common.Address) { +func (j *linearJournal) touchChange(address common.Address) { j.append(touchChange{ account: address, }) @@ -217,11 +229,11 @@ func (j *journal) touchChange(address common.Address) { } } -func (j *journal) accessListAddAccount(addr common.Address) { +func (j *linearJournal) accessListAddAccount(addr common.Address) { j.append(accessListAddAccountChange{addr}) } -func (j *journal) accessListAddSlot(addr common.Address, slot common.Hash) { +func (j *linearJournal) accessListAddSlot(addr common.Address, slot common.Hash) { j.append(accessListAddSlotChange{ address: addr, slot: slot, @@ -234,7 +246,7 @@ type ( account common.Address } // createContractChange represents an account becoming a contract-account. - // This event happens prior to executing initcode. The journal-event simply + // This event happens prior to executing initcode. The linearJournal-event simply // manages the created-flag, in order to allow same-tx destruction. createContractChange struct { account common.Address @@ -464,7 +476,7 @@ func (ch addLogChange) copy() journalEntry { func (ch accessListAddAccountChange) revert(s *StateDB) { /* One important invariant here, is that whenever a (addr, slot) is added, if the - addr is not already present, the add causes two journal entries: + addr is not already present, the add causes two linearJournal entries: - one for the address, - one for the (address,slot) Therefore, when unrolling the change, we can always blindly delete the diff --git a/core/state/journal_api.go b/core/state/journal_api.go new file mode 100644 index 0000000000..96f0167500 --- /dev/null +++ b/core/state/journal_api.go @@ -0,0 +1,69 @@ +package state + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/holiman/uint256" +) + +type journal interface { + + // snapshot returns an identifier for the current revision of the state. + snapshot() int + + // revertToSnapshot reverts all state changes made since the given revision. + revertToSnapshot(revid int, s *StateDB) + + // reset clears the journal so it can be reused. + reset() + + // dirtyAccounts returns a list of all accounts modified in this journal + dirtyAccounts() []common.Address + + // accessListAddAccount journals the adding of addr to the access list + accessListAddAccount(addr common.Address) + + // accessListAddSlot journals the adding of addr/slot to the access list + accessListAddSlot(addr common.Address, slot common.Hash) + + // logChange journals the adding of a log related to the txHash + logChange(txHash common.Hash) + + // createObject journals the event of a new account created in the trie. + createObject(addr common.Address) + + // createContract journals the creation of a new contract at addr. + // OBS: This method must not be applied twice, it assumes that the pre-state + // (i.e the rollback-state) is non-created. + createContract(addr common.Address) + + // destruct journals the destruction of an account in the trie. + // OBS: This method must not be applied twice -- it always assumes that the + // pre-state (i.e the rollback-state) is non-destructed. + destruct(addr common.Address) + + // storageChange journals a change in the storage data related to addr. + // It records the key and previous value of the slot. + storageChange(addr common.Address, key, prev, origin common.Hash) + + // transientStateChange journals a change in the t-storage data related to addr. + // It records the key and previous value of the slot. + transientStateChange(addr common.Address, key, prev common.Hash) + + // refundChange journals that the refund has been changed, recording the previous value. + refundChange(previous uint64) + + // balanceChange journals tha the balance of addr has been changed, recording the previous value + balanceChange(addr common.Address, previous *uint256.Int) + + // JournalSetCode journals that the code of addr has been set. + setCode(addr common.Address, prev []byte) + + // nonceChange journals that the nonce of addr was changed, recording the previous value. + nonceChange(addr common.Address, prev uint64) + + // touchChange journals that the account at addr was touched during execution. + touchChange(addr common.Address) + + // copy returns a deep-copied journal. + copy() journal +} diff --git a/core/state/statedb.go b/core/state/statedb.go index d279ccfdfe..9b03a8cf7b 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -133,7 +133,7 @@ type StateDB struct { // Journal of state modifications. This is the backbone of // Snapshot and RevertToSnapshot. - journal *journal + journal journal // State witness if cross validation is needed witness *stateless.Witness @@ -177,7 +177,7 @@ func New(root common.Hash, db Database) (*StateDB, error) { mutations: make(map[common.Address]*mutation), logs: make(map[common.Hash][]*types.Log), preimages: make(map[common.Hash][]byte), - journal: newJournal(), + journal: newLinearJournal(), accessList: newAccessList(), transientStorage: newTransientStorage(), } @@ -727,8 +727,9 @@ func (s *StateDB) GetRefund() uint64 { // the journal as well as the refunds. Finalise, however, will not push any updates // into the tries just yet. Only IntermediateRoot or Commit will do that. func (s *StateDB) Finalise(deleteEmptyObjects bool) { - addressesToPrefetch := make([]common.Address, 0, len(s.journal.dirties)) - for addr := range s.journal.dirties { + dirties := s.journal.dirtyAccounts() + addressesToPrefetch := make([]common.Address, 0, len(dirties)) + for _, addr := range dirties { obj, exist := s.stateObjects[addr] if !exist { // ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2 diff --git a/core/state/statedb_hooked.go b/core/state/statedb_hooked.go index 31bdd06b46..4dafc991cc 100644 --- a/core/state/statedb_hooked.go +++ b/core/state/statedb_hooked.go @@ -259,7 +259,7 @@ func (s *hookedStateDB) Finalise(deleteEmptyObjects bool) { if s.hooks.OnBalanceChange == nil { return } - for addr := range s.inner.journal.dirties { + for _, addr := range s.inner.journal.dirtyAccounts() { obj := s.inner.stateObjects[addr] if obj != nil && obj.selfDestructed { // If ether was sent to account post-selfdestruct it is burnt. diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 37141e90b0..7467b50f8b 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -228,7 +228,7 @@ func TestCopy(t *testing.T) { } // TestCopyWithDirtyJournal tests if Copy can correct create a equal copied -// stateDB with dirty journal present. +// stateDB with dirty linearJournal present. func TestCopyWithDirtyJournal(t *testing.T) { db := NewDatabaseForTesting() orig, _ := New(types.EmptyRootHash, db) @@ -410,8 +410,8 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { // We also set some code here, to prevent the // CreateContract action from being performed twice in a row, // which would cause a difference in state when unrolling - // the journal. (CreateContact assumes created was false prior to - // invocation, and the journal rollback sets it to false). + // the linearJournal. (CreateContact assumes created was false prior to + // invocation, and the linearJournal rollback sets it to false). s.SetCode(addr, []byte{1}) } }, @@ -673,22 +673,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{})) } - if !maps.Equal(state.journal.dirties, checkstate.journal.dirties) { - getKeys := func(dirty map[common.Address]int) string { - var keys []common.Address - out := new(strings.Builder) - for key := range dirty { - keys = append(keys, key) + { // Check the dirty-accounts + have := state.journal.dirtyAccounts() + want := checkstate.journal.dirtyAccounts() + slices.SortFunc(have, common.Address.Cmp) + slices.SortFunc(want, common.Address.Cmp) + if !slices.Equal(have, want) { + getKeys := func(keys []common.Address) string { + out := new(strings.Builder) + for i, key := range keys { + fmt.Fprintf(out, " %d. %v\n", i, key) + } + return out.String() } - slices.SortFunc(keys, common.Address.Cmp) - for i, key := range keys { - fmt.Fprintf(out, " %d. %v\n", i, key) - } - return out.String() + haveK := getKeys(state.journal.dirtyAccounts()) + wantK := getKeys(checkstate.journal.dirtyAccounts()) + return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", haveK, wantK) } - have := getKeys(state.journal.dirties) - want := getKeys(checkstate.journal.dirties) - return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", have, want) } return nil } @@ -702,11 +703,11 @@ func TestTouchDelete(t *testing.T) { snapshot := s.state.Snapshot() s.state.AddBalance(common.Address{}, new(uint256.Int), tracing.BalanceChangeUnspecified) - if len(s.state.journal.dirties) != 1 { + if len(s.state.journal.dirtyAccounts()) != 1 { t.Fatal("expected one dirty state object") } s.state.RevertToSnapshot(snapshot) - if len(s.state.journal.dirties) != 0 { + if len(s.state.journal.dirtyAccounts()) != 0 { t.Fatal("expected no dirty state object") } } @@ -1091,32 +1092,51 @@ func TestStateDBAccessList(t *testing.T) { } } + var ids []int + push := func(id int) { + ids = append(ids, id) + } + pop := func() int { + id := ids[len(ids)-1] + ids = ids[:len(ids)-1] + return id + } + + push(state.journal.snapshot()) // journal id 0 state.AddAddressToAccessList(addr("aa")) // 1 - state.AddSlotToAccessList(addr("bb"), slot("01")) // 2,3 + push(state.journal.snapshot()) // journal id 1 + state.AddAddressToAccessList(addr("bb")) // 2 + push(state.journal.snapshot()) // journal id 2 + state.AddSlotToAccessList(addr("bb"), slot("01")) // 3 + push(state.journal.snapshot()) // journal id 3 state.AddSlotToAccessList(addr("bb"), slot("02")) // 4 + push(state.journal.snapshot()) // journal id 4 verifyAddrs("aa", "bb") verifySlots("bb", "01", "02") // Make a copy stateCopy1 := state.Copy() - if exp, got := 4, state.journal.length(); exp != got { - t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + if exp, got := 4, state.journal.(*linearJournal).length(); exp != got { + t.Fatalf("linearJournal length mismatch: have %d, want %d", got, exp) } - // same again, should cause no journal entries + // same again, should cause no linearJournal entries state.AddSlotToAccessList(addr("bb"), slot("01")) state.AddSlotToAccessList(addr("bb"), slot("02")) state.AddAddressToAccessList(addr("aa")) - if exp, got := 4, state.journal.length(); exp != got { - t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + if exp, got := 4, state.journal.(*linearJournal).length(); exp != got { + t.Fatalf("linearJournal length mismatch: have %d, want %d", got, exp) } // some new ones state.AddSlotToAccessList(addr("bb"), slot("03")) // 5 + push(state.journal.snapshot()) // journal id 5 state.AddSlotToAccessList(addr("aa"), slot("01")) // 6 - state.AddSlotToAccessList(addr("cc"), slot("01")) // 7,8 - state.AddAddressToAccessList(addr("cc")) - if exp, got := 8, state.journal.length(); exp != got { - t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + push(state.journal.snapshot()) // journal id 6 + state.AddAddressToAccessList(addr("cc")) // 7 + push(state.journal.snapshot()) // journal id 7 + state.AddSlotToAccessList(addr("cc"), slot("01")) // 8 + if exp, got := 8, state.journal.(*linearJournal).length(); exp != got { + t.Fatalf("linearJournal length mismatch: have %d, want %d", got, exp) } verifyAddrs("aa", "bb", "cc") @@ -1125,7 +1145,7 @@ func TestStateDBAccessList(t *testing.T) { verifySlots("cc", "01") // now start rolling back changes - state.journal.revert(state, 7) + state.journal.revertToSnapshot(pop(), state) // revert to 6 if _, ok := state.SlotInAccessList(addr("cc"), slot("01")); ok { t.Fatalf("slot present, expected missing") } @@ -1133,7 +1153,7 @@ func TestStateDBAccessList(t *testing.T) { verifySlots("aa", "01") verifySlots("bb", "01", "02", "03") - state.journal.revert(state, 6) + state.journal.revertToSnapshot(pop(), state) // revert to 5 if state.AddressInAccessList(addr("cc")) { t.Fatalf("addr present, expected missing") } @@ -1141,40 +1161,40 @@ func TestStateDBAccessList(t *testing.T) { verifySlots("aa", "01") verifySlots("bb", "01", "02", "03") - state.journal.revert(state, 5) + state.journal.revertToSnapshot(pop(), state) // revert to 4 if _, ok := state.SlotInAccessList(addr("aa"), slot("01")); ok { t.Fatalf("slot present, expected missing") } verifyAddrs("aa", "bb") verifySlots("bb", "01", "02", "03") - state.journal.revert(state, 4) + state.journal.revertToSnapshot(pop(), state) // revert to 3 if _, ok := state.SlotInAccessList(addr("bb"), slot("03")); ok { t.Fatalf("slot present, expected missing") } verifyAddrs("aa", "bb") verifySlots("bb", "01", "02") - state.journal.revert(state, 3) + state.journal.revertToSnapshot(pop(), state) // revert to 2 if _, ok := state.SlotInAccessList(addr("bb"), slot("02")); ok { t.Fatalf("slot present, expected missing") } verifyAddrs("aa", "bb") verifySlots("bb", "01") - state.journal.revert(state, 2) + state.journal.revertToSnapshot(pop(), state) // revert to 1 if _, ok := state.SlotInAccessList(addr("bb"), slot("01")); ok { t.Fatalf("slot present, expected missing") } verifyAddrs("aa", "bb") - state.journal.revert(state, 1) + state.journal.revertToSnapshot(pop(), state) // revert to 0 if state.AddressInAccessList(addr("bb")) { t.Fatalf("addr present, expected missing") } verifyAddrs("aa") - state.journal.revert(state, 0) + state.journal.revertToSnapshot(0, state) if state.AddressInAccessList(addr("aa")) { t.Fatalf("addr present, expected missing") } @@ -1245,10 +1265,10 @@ func TestStateDBTransientStorage(t *testing.T) { key := common.Hash{0x01} value := common.Hash{0x02} addr := common.Address{} - + revision := state.journal.snapshot() state.SetTransientState(addr, key, value) - if exp, got := 1, state.journal.length(); exp != got { - t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + if exp, got := 1, state.journal.(*linearJournal).length(); exp != got { + t.Fatalf("linearJournal length mismatch: have %d, want %d", got, exp) } // the retrieved value should equal what was set if got := state.GetTransientState(addr, key); got != value { @@ -1257,7 +1277,7 @@ func TestStateDBTransientStorage(t *testing.T) { // revert the transient state being set and then check that the // value is now the empty hash - state.journal.revert(state, 0) + state.journal.revertToSnapshot(revision, state) if got, exp := state.GetTransientState(addr, key), (common.Hash{}); exp != got { t.Fatalf("transient storage mismatch: have %x, want %x", got, exp) }