// Copyright 2023 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package state

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"maps"
	"math"
	"math/rand"
	"reflect"
	"strings"
	"testing"
	"testing/quick"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/core/state/snapshot"
	"github.com/ethereum/go-ethereum/core/tracing"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/rlp"
	"github.com/ethereum/go-ethereum/trie"
	"github.com/ethereum/go-ethereum/triedb"
	"github.com/ethereum/go-ethereum/triedb/pathdb"
	"github.com/holiman/uint256"
)

// A stateTest checks that the state changes are correctly captured. Instances
// of this test with pseudorandom content are created by Generate.
//
// The test works as follows:
//
// A list of states is created by applying the actions block by block. The
// state changes between consecutive states are tracked, and each recorded
// diff is then verified against the tries of the old and the new state.
type stateTest struct {
	addrs   []common.Address // all account addresses
	actions [][]testAction   // modifications to the state, grouped by block
	chunk   int              // the number of actions per transaction within a block
	err     error            // failure details are reported through this field
}

// newStateTestAction creates a random action that changes state.
func newStateTestAction(addr common.Address, r *rand.Rand, index int) testAction { actions := []testAction{ { name: "SetBalance", fn: func(a testAction, s *StateDB) { s.SetBalance(addr, uint256.NewInt(uint64(a.args[0])), tracing.BalanceChangeUnspecified) }, args: make([]int64, 1), }, { name: "SetNonce", fn: func(a testAction, s *StateDB) { s.SetNonce(addr, uint64(a.args[0])) }, args: make([]int64, 1), }, { name: "SetStorage", fn: func(a testAction, s *StateDB) { var key, val common.Hash binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) s.SetState(addr, key, val) }, args: make([]int64, 2), }, { name: "SetCode", fn: func(a testAction, s *StateDB) { code := make([]byte, 16) binary.BigEndian.PutUint64(code, uint64(a.args[0])) binary.BigEndian.PutUint64(code[8:], uint64(a.args[1])) s.SetCode(addr, code) }, args: make([]int64, 2), }, { name: "CreateAccount", fn: func(a testAction, s *StateDB) { if !s.Exist(addr) { s.CreateAccount(addr) } }, }, { name: "Selfdestruct", fn: func(a testAction, s *StateDB) { s.SelfDestruct(addr) }, }, } var nonRandom = index != -1 if index == -1 { index = r.Intn(len(actions)) } action := actions[index] var names []string if !action.noAddr { names = append(names, addr.Hex()) } for i := range action.args { if nonRandom { action.args[i] = rand.Int63n(10000) + 1 // set balance to non-zero } else { action.args[i] = rand.Int63n(10000) } names = append(names, fmt.Sprint(action.args[i])) } action.name += " " + strings.Join(names, ", ") return action } // Generate returns a new snapshot test of the given size. All randomness is // derived from r. 
func (*stateTest) Generate(r *rand.Rand, size int) reflect.Value { addrs := make([]common.Address, 5) for i := range addrs { addrs[i][0] = byte(i) } actions := make([][]testAction, rand.Intn(5)+1) for i := 0; i < len(actions); i++ { actions[i] = make([]testAction, size) for j := range actions[i] { if j == 0 { // Always include a set balance action to make sure // the state changes are not empty. actions[i][j] = newStateTestAction(common.HexToAddress("0xdeadbeef"), r, 0) continue } actions[i][j] = newStateTestAction(addrs[r.Intn(len(addrs))], r, -1) } } chunk := int(math.Sqrt(float64(size))) if size > 0 && chunk == 0 { chunk = 1 } return reflect.ValueOf(&stateTest{ addrs: addrs, actions: actions, chunk: chunk, }) } func (test *stateTest) String() string { out := new(bytes.Buffer) for i, actions := range test.actions { fmt.Fprintf(out, "---- block %d ----\n", i) for j, action := range actions { if j%test.chunk == 0 { fmt.Fprintf(out, "---- transaction %d ----\n", j/test.chunk) } fmt.Fprintf(out, "%4d: %s\n", j%test.chunk, action.name) } } return out.String() } func (test *stateTest) run() bool { var ( roots []common.Hash accounts []map[common.Hash][]byte accountOrigin []map[common.Address][]byte storages []map[common.Hash]map[common.Hash][]byte storageOrigin []map[common.Address]map[common.Hash][]byte copyUpdate = func(update *stateUpdate) { accounts = append(accounts, maps.Clone(update.accounts)) accountOrigin = append(accountOrigin, maps.Clone(update.accountsOrigin)) storages = append(storages, maps.Clone(update.storages)) storageOrigin = append(storageOrigin, maps.Clone(update.storagesOrigin)) } disk = rawdb.NewMemoryDatabase() tdb = triedb.NewDatabase(disk, &triedb.Config{PathDB: pathdb.Defaults}) byzantium = rand.Intn(2) == 0 ) defer disk.Close() defer tdb.Close() var snaps *snapshot.Tree if rand.Intn(3) == 0 { snaps, _ = snapshot.New(snapshot.Config{ CacheSize: 1, Recovery: false, NoBuild: false, AsyncBuild: false, }, disk, tdb, types.EmptyRootHash) } for i, 
actions := range test.actions { root := types.EmptyRootHash if i != 0 { root = roots[len(roots)-1] } state, err := New(root, NewDatabase(tdb, snaps)) if err != nil { panic(err) } for i, action := range actions { if i%test.chunk == 0 && i != 0 { if byzantium { state.Finalise(true) // call finalise at the transaction boundary } else { state.IntermediateRoot(true) // call intermediateRoot at the transaction boundary } } action.fn(action, state) } if byzantium { state.Finalise(true) // call finalise at the transaction boundary } else { state.IntermediateRoot(true) // call intermediateRoot at the transaction boundary } ret, err := state.commitAndFlush(0, true) // call commit at the block boundary if err != nil { panic(err) } if ret.empty() { return true } copyUpdate(ret) roots = append(roots, ret.root) } for i := 0; i < len(test.actions); i++ { root := types.EmptyRootHash if i != 0 { root = roots[i-1] } test.err = test.verify(root, roots[i], tdb, accounts[i], accountOrigin[i], storages[i], storageOrigin[i]) if test.err != nil { return false } } return true } // verifyAccountCreation this function is called once the state diff says that // specific account was not present. 
A serial of checks will be performed to // ensure the state diff is correct, includes: // // - the account was indeed not present in trie // - the account is present in new trie, nil->nil is regarded as invalid // - the slots transition is correct func (test *stateTest) verifyAccountCreation(next common.Hash, db *triedb.Database, otr, ntr *trie.Trie, addr common.Address, account []byte, storages map[common.Hash][]byte, storagesOrigin map[common.Hash][]byte) error { // Verify account change addrHash := crypto.Keccak256Hash(addr.Bytes()) oBlob, err := otr.Get(addrHash.Bytes()) if err != nil { return err } nBlob, err := ntr.Get(addrHash.Bytes()) if err != nil { return err } if len(oBlob) != 0 { return fmt.Errorf("unexpected account in old trie, %x", addrHash) } if len(nBlob) == 0 { return fmt.Errorf("missing account in new trie, %x", addrHash) } full, err := types.FullAccountRLP(account) if err != nil { return err } if !bytes.Equal(nBlob, full) { return fmt.Errorf("unexpected account data, want: %v, got: %v", full, nBlob) } // Verify storage changes var nAcct types.StateAccount if err := rlp.DecodeBytes(nBlob, &nAcct); err != nil { return err } // Account has no slot, empty slot set is expected if nAcct.Root == types.EmptyRootHash { if len(storagesOrigin) != 0 { return fmt.Errorf("unexpected slot changes %x", addrHash) } if len(storages) != 0 { return fmt.Errorf("unexpected slot changes %x", addrHash) } return nil } // Account has slots, ensure all new slots are contained st, err := trie.New(trie.StorageTrieID(next, addrHash, nAcct.Root), db) if err != nil { return err } for key, val := range storagesOrigin { if _, exist := storages[key]; !exist { return errors.New("storage data is not found") } got, err := st.Get(key.Bytes()) if err != nil { return err } if !bytes.Equal(got, storages[key]) { return fmt.Errorf("unexpected storage data, want: %v, got: %v", storages[key], got) } st.Update(key.Bytes(), val) } if len(storagesOrigin) != len(storages) { return 
fmt.Errorf("extra storage found, want: %d, got: %d", len(storagesOrigin), len(storages)) } if st.Hash() != types.EmptyRootHash { return errors.New("invalid slot changes") } return nil } // verifyAccountUpdate this function is called once the state diff says that // specific account was present. A serial of checks will be performed to // ensure the state diff is correct, includes: // // - the account was indeed present in trie // - the account in old trie matches the provided value // - the slots transition is correct func (test *stateTest) verifyAccountUpdate(next common.Hash, db *triedb.Database, otr, ntr *trie.Trie, addr common.Address, account []byte, accountOrigin []byte, storages map[common.Hash][]byte, storageOrigin map[common.Hash][]byte) error { // Verify account change addrHash := crypto.Keccak256Hash(addr.Bytes()) oBlob, err := otr.Get(addrHash.Bytes()) if err != nil { return err } nBlob, err := ntr.Get(addrHash.Bytes()) if err != nil { return err } if len(oBlob) == 0 { return fmt.Errorf("missing account in old trie, %x", addrHash) } full, err := types.FullAccountRLP(accountOrigin) if err != nil { return err } if !bytes.Equal(full, oBlob) { return fmt.Errorf("account value is not matched, %x", addrHash) } if len(nBlob) == 0 { if len(account) != 0 { return errors.New("unexpected account data") } } else { full, _ = types.FullAccountRLP(account) if !bytes.Equal(full, nBlob) { return fmt.Errorf("unexpected account data, %x, want %v, got: %v", addrHash, full, nBlob) } } // Decode accounts var ( oAcct types.StateAccount nAcct types.StateAccount nRoot common.Hash ) if err := rlp.DecodeBytes(oBlob, &oAcct); err != nil { return err } if len(nBlob) == 0 { nRoot = types.EmptyRootHash } else { if err := rlp.DecodeBytes(nBlob, &nAcct); err != nil { return err } nRoot = nAcct.Root } // Verify storage st, err := trie.New(trie.StorageTrieID(next, addrHash, nRoot), db) if err != nil { return err } for key, val := range storageOrigin { if _, exist := storages[key]; !exist 
{ return errors.New("storage data is not found") } got, err := st.Get(key.Bytes()) if err != nil { return err } if !bytes.Equal(got, storages[key]) { return fmt.Errorf("unexpected storage data, want: %v, got: %v", storages[key], got) } st.Update(key.Bytes(), val) } if len(storageOrigin) != len(storages) { return fmt.Errorf("extra storage found, want: %d, got: %d", len(storageOrigin), len(storages)) } if st.Hash() != oAcct.Root { return errors.New("invalid slot changes") } return nil } func (test *stateTest) verify(root common.Hash, next common.Hash, db *triedb.Database, accounts map[common.Hash][]byte, accountsOrigin map[common.Address][]byte, storages map[common.Hash]map[common.Hash][]byte, storagesOrigin map[common.Address]map[common.Hash][]byte) error { otr, err := trie.New(trie.StateTrieID(root), db) if err != nil { return err } ntr, err := trie.New(trie.StateTrieID(next), db) if err != nil { return err } for addr, accountOrigin := range accountsOrigin { var ( err error addrHash = crypto.Keccak256Hash(addr.Bytes()) ) if len(accountOrigin) == 0 { err = test.verifyAccountCreation(next, db, otr, ntr, addr, accounts[addrHash], storages[addrHash], storagesOrigin[addr]) } else { err = test.verifyAccountUpdate(next, db, otr, ntr, addr, accounts[addrHash], accountsOrigin[addr], storages[addrHash], storagesOrigin[addr]) } if err != nil { return err } } return nil } func TestStateChanges(t *testing.T) { config := &quick.Config{MaxCount: 1000} err := quick.Check((*stateTest).run, config) if cerr, ok := err.(*quick.CheckError); ok { test := cerr.In[0].(*stateTest) t.Errorf("%v:\n%s", test.err, test) } else if err != nil { t.Error(err) } }