core/state: refactor journal-fuzzer

core/state: work on fuzztest for journals/state
This commit is contained in:
Martin Holst Swende 2024-11-22 14:53:45 +01:00
parent 08b5bb9182
commit 474e2bee48
No known key found for this signature in database
GPG Key ID: 683B438C05A5DDF0
1 changed files with 247 additions and 170 deletions

View File

@ -18,8 +18,12 @@
package state package state
import ( import (
"bytes"
"crypto/rand"
"encoding/binary"
"fmt" "fmt"
"math/rand/v2" "io"
"slices"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -135,173 +139,246 @@ func testJournalRefunds(t *testing.T, j journal) {
} }
} }
func FuzzJournals(f *testing.F) { type fuzzReader struct {
input io.Reader
randByte := func() byte { exhausted bool
return byte(rand.Int()) }
}
randBool := func() bool { func (f *fuzzReader) byte() byte {
return rand.Int()%2 == 0 return f.bytes(1)[0]
} }
randAccount := func() *types.StateAccount {
return &types.StateAccount{ func (f *fuzzReader) bytes(n int) []byte {
Nonce: uint64(randByte()), r := make([]byte, n)
Balance: uint256.NewInt(uint64(randByte())), if _, err := f.input.Read(r); err != nil {
Root: types.EmptyRootHash, f.exhausted = true
CodeHash: types.EmptyCodeHash[:], }
} return r
} }
f.Fuzz(func(t *testing.T, operations []byte) { func newEmptyState() *StateDB {
var ( s, _ := New(types.EmptyRootHash, NewDatabaseForTesting())
statedb1, _ = New(types.EmptyRootHash, NewDatabaseForTesting()) return s
statedb2, _ = New(types.EmptyRootHash, NewDatabaseForTesting()) }
linear = newLinearJournal()
sparse = newSparseJournal() // fuzzJournals is pretty similar to `TestSnapshotRandom`/ `newTestAction` in
) // statedb_test.go. They both execute a sequence of state-actions, however, they
statedb1.journal = linear // test for different aspects.
statedb2.journal = sparse // This test compares two differing journal-implementations.
linear.snapshot() // The other test compares every point in time, whether it is identical when going
sparse.snapshot() // forward as when going backwards through the journal entries.
func fuzzJournals(t *testing.T, data []byte) {
for _, o := range operations { var (
switch o { reader = fuzzReader{input: bytes.NewReader(data)}
case 0: stateDbs = []*StateDB{
addr := randByte() newEmptyState(),
linear.accessListAddAccount(common.Address{addr}) newEmptyState(),
sparse.accessListAddAccount(common.Address{addr}) }
statedb1.accessList.AddAddress(common.Address{addr}) )
statedb2.accessList.AddAddress(common.Address{addr}) apply := func(action func(stateDbs *StateDB)) {
case 1: for _, sdb := range stateDbs {
addr := randByte() action(sdb)
slot := randByte() }
linear.accessListAddSlot(common.Address{addr}, common.Hash{slot}) }
sparse.accessListAddSlot(common.Address{addr}, common.Hash{slot}) stateDbs[0].journal = newLinearJournal()
statedb1.accessList.AddSlot(common.Address{addr}, common.Hash{slot}) stateDbs[1].journal = newSparseJournal()
statedb2.accessList.AddSlot(common.Address{addr}, common.Hash{slot})
case 2: for !reader.exhausted {
addr := randByte() op := reader.byte() % 18
account := randAccount() switch op {
destructed := randBool() case 0: // Add account to access lists
newContract := randBool() addr := common.BytesToAddress(reader.bytes(1))
linear.balanceChange(common.Address{addr}, account, destructed, newContract) t.Logf("Op %d: Add to access list %#x", op, addr)
sparse.balanceChange(common.Address{addr}, account, destructed, newContract) apply(func(sdb *StateDB) {
case 3: sdb.accessList.AddAddress(addr)
linear = linear.copy().(*linearJournal) })
sparse = sparse.copy().(*sparseJournal) case 1: // Add slot to access list
case 4: addr := common.BytesToAddress(reader.bytes(1))
addr := randByte() slot := common.BytesToHash(reader.bytes(1))
account := randAccount() t.Logf("Op %d: Add addr:slot to access list %#x : %#x", op, addr, slot)
linear.createContract(common.Address{addr}, account) apply(func(sdb *StateDB) {
sparse.createContract(common.Address{addr}, account) sdb.AddSlotToAccessList(addr, slot)
case 5: })
addr := randByte() case 2:
linear.createObject(common.Address{addr}) var (
sparse.createObject(common.Address{addr}) addr = common.BytesToAddress(reader.bytes(1))
case 6: value = uint64(reader.byte())
addr := randByte() )
account := randAccount() t.Logf("Op %d: Add balance %#x %d", op, addr, value)
linear.destruct(common.Address{addr}, account) apply(func(sdb *StateDB) {
sparse.destruct(common.Address{addr}, account) sdb.AddBalance(addr, uint256.NewInt(value), 0)
case 7: })
txHash := randByte() case 3:
linear.logChange(common.Hash{txHash}) t.Logf("Op %d: Copy journals[0]", op)
sparse.logChange(common.Hash{txHash}) stateDbs[0].journal = stateDbs[0].journal.copy()
case 8: case 4:
addr := randByte() t.Logf("Op %d: Copy journals[1]", op)
account := randAccount() stateDbs[1].journal = stateDbs[1].journal.copy()
destructed := randBool() case 5:
newContract := randBool() var (
linear.nonceChange(common.Address{addr}, account, destructed, newContract) addr = common.BytesToAddress(reader.bytes(1))
sparse.nonceChange(common.Address{addr}, account, destructed, newContract) code = reader.bytes(2)
case 9: )
refund := randByte() t.Logf("Op %d: (Create and) set code 0x%x", op, addr)
linear.refundChange(uint64(refund)) apply(func(s *StateDB) {
sparse.refundChange(uint64(refund)) if !s.Exist(addr) {
case 10: s.CreateAccount(addr)
addr := randByte() }
account := randAccount() contractHash := s.GetCodeHash(addr)
linear.setCode(common.Address{addr}, account) emptyCode := contractHash == (common.Hash{}) || contractHash == types.EmptyCodeHash
sparse.setCode(common.Address{addr}, account) storageRoot := s.GetStorageRoot(addr)
case 11: emptyStorage := storageRoot == (common.Hash{}) || storageRoot == types.EmptyRootHash
addr := randByte()
key := randByte() if obj := s.getStateObject(addr); obj != nil {
prev := randByte() if obj.selfDestructed {
origin := randByte() // If it's selfdestructed, we cannot create into it
linear.storageChange(common.Address{addr}, common.Hash{key}, common.Hash{prev}, common.Hash{origin}) return
sparse.storageChange(common.Address{addr}, common.Hash{key}, common.Hash{prev}, common.Hash{origin}) }
case 12: }
addr := randByte() if s.GetNonce(addr) == 0 && emptyCode && emptyStorage {
account := randAccount() s.CreateContract(addr)
destructed := randBool() // We also set some code here, to prevent the
newContract := randBool() // CreateContract action from being performed twice in a row,
linear.touchChange(common.Address{addr}, account, destructed, newContract) // which would cause a difference in state when unrolling
sparse.touchChange(common.Address{addr}, account, destructed, newContract) // the linearJournal. (CreateContact assumes created was false prior to
case 13: // invocation, and the linearJournal rollback sets it to false).
addr := randByte() s.SetCode(addr, code)
key := randByte() }
prev := randByte() })
linear.transientStateChange(common.Address{addr}, common.Hash{key}, common.Hash{prev}) case 6:
sparse.transientStateChange(common.Address{addr}, common.Hash{key}, common.Hash{prev}) addr := common.BytesToAddress(reader.bytes(1))
case 14: t.Logf("Op %d: Create 0x%x", op, addr)
linear.reset() apply(func(sdb *StateDB) {
sparse.reset() if !sdb.Exist(addr) {
case 15: sdb.CreateAccount(addr)
linear.snapshot() }
sparse.snapshot() })
case 16: case 7:
linear.discardSnapshot() addr := common.BytesToAddress(reader.bytes(1))
sparse.discardSnapshot() t.Logf("Op %d: (Create and) destruct 0x%x", op, addr)
case 17: apply(func(s *StateDB) {
linear.revertSnapshot(statedb1) if !s.Exist(addr) {
sparse.revertSnapshot(statedb2) s.CreateAccount(addr)
case 18: }
accs1 := linear.dirtyAccounts() s.SelfDestruct(addr)
accs2 := linear.dirtyAccounts() })
if len(accs1) != len(accs2) { case 8:
panic(fmt.Sprintf("mismatched accounts: %v %v", accs1, accs2)) txHash := common.BytesToHash(reader.bytes(1))
t.Logf("Op %d: Add log %#x", op, txHash)
} apply(func(sdb *StateDB) {
for _, val := range accs1 { sdb.logs[txHash] = append(sdb.logs[txHash], new(types.Log))
found := false sdb.logSize++
for _, val2 := range accs2 { sdb.journal.logChange(txHash)
if val == val2 { })
if found { case 9:
panic(fmt.Sprintf("account found twice: %v %v account %v", accs1, accs2, val)) var (
} addr = common.BytesToAddress(reader.bytes(1))
found = true nonce = binary.BigEndian.Uint64(reader.bytes(8))
} )
} t.Logf("Op %d: Set nonce %#x %d", op, addr, nonce)
if !found { apply(func(sdb *StateDB) {
panic(fmt.Sprintf("missing account: %v %v account %v", accs1, accs2, val)) sdb.SetNonce(addr, nonce)
} })
} case 10:
} refund := uint64(reader.byte())
} t.Logf("Op %d: Set refund %d", op, refund)
// After all operations have been processed, verify equality apply(func(sdb *StateDB) {
accs1 := linear.dirtyAccounts() sdb.journal.refundChange(refund)
accs2 := linear.dirtyAccounts() })
for _, val := range accs1 { case 11:
found := false var (
for _, val2 := range accs2 { addr = common.BytesToAddress(reader.bytes(1))
if val == val2 { key = common.BytesToHash(reader.bytes(1))
if found { val = common.BytesToHash(reader.bytes(1))
panic(fmt.Sprintf("account found twice: %v %v account %v", accs1, accs2, val)) )
} t.Logf("Op %d: Set storage %#x [%#x]=%#x", op, addr, key, val)
found = true apply(func(sdb *StateDB) {
} sdb.SetState(addr, key, val)
} })
if !found { case 12:
panic(fmt.Sprintf("missing account: %v %v account %v", accs1, accs2, val)) var (
} addr = common.BytesToAddress(reader.bytes(1))
} )
h1, err1 := statedb1.Commit(0, false) t.Logf("Op %d: Zero-balance transfer (touch) %#x", op, addr)
h2, err2 := statedb2.Commit(0, false) apply(func(sdb *StateDB) {
if err1 != err2 { sdb.AddBalance(addr, uint256.NewInt(0), 0)
panic(fmt.Sprintf("mismatched errors: %v %v", err1, err2)) })
} case 13:
if h1 != h2 { var (
panic(fmt.Sprintf("mismatched roots: %v %v", h1, h2)) addr = common.BytesToAddress(reader.bytes(1))
} key = common.BytesToHash(reader.bytes(1))
}) value = common.BytesToHash(reader.bytes(1))
)
t.Logf("Op %d: Set t-storage %#x [%#x]=%#x", op, addr, key, value)
apply(func(sdb *StateDB) {
sdb.SetTransientState(addr, key, value)
})
case 14:
t.Logf("Op %d: Reset journal", op)
apply(func(sdb *StateDB) {
sdb.journal.reset()
})
case 15:
t.Logf("Op %d: Snapshot", op)
apply(func(sdb *StateDB) {
sdb.Snapshot()
})
case 16:
t.Logf("Op %d: Discard snapshot", op)
apply(func(sdb *StateDB) {
sdb.DiscardSnapshot()
})
case 17:
t.Logf("Op %d: Revert snapshot", op)
apply(func(sdb *StateDB) {
sdb.RevertSnapshot()
})
}
// Cross-check the dirty-sets
accs1 := stateDbs[0].journal.dirtyAccounts()
slices.SortFunc(accs1, func(a, b common.Address) int {
return bytes.Compare(a.Bytes(), b.Bytes())
})
accs2 := stateDbs[1].journal.dirtyAccounts()
slices.SortFunc(accs2, func(a, b common.Address) int {
return bytes.Compare(a.Bytes(), b.Bytes())
})
if !slices.Equal(accs1, accs2) {
t.Fatalf("mismatched dirty-sets:\n%v\n%v", accs1, accs2)
}
}
h1, err1 := stateDbs[0].Commit(0, false)
h2, err2 := stateDbs[1].Commit(0, false)
if err1 != err2 {
t.Fatalf("Mismatched errors: %v %v", err1, err2)
}
if h1 != h2 {
t.Fatalf("Mismatched roots: %v %v", h1, h2)
}
}
// FuzzJournals fuzzes the journals.
func FuzzJournals(f *testing.F) {
f.Fuzz(fuzzJournals)
}
// TestFuzzJournals runs 200 fuzz-tests
func TestFuzzJournals(t *testing.T) {
input := make([]byte, 200)
for i := 0; i < 200; i++ {
rand.Read(input)
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
t.Parallel()
t.Logf("input: %x", input)
fuzzJournals(t, input)
})
}
}
// TestFuzzJournalsSpecific can be used to test a specific input
func TestFuzzJournalsSpecific(t *testing.T) {
t.Skip("example")
input := common.FromHex("71d598d781f65eb7c047fed5d09b1e4e0c1ecad5c447a2149e7d1137fcb1b1d63f4ba6f761918a441a98eb61d69fe011cabfbce00d74bb78539ca9946a602e94d6eabc43c0924ba65ce3e171b476208059d81f33e81d90607e0b6e59d6016840b5c4e9b1a8e9798a5a40be909930658eea351d7a312dba0b1c7199c7e5f62a908a80f7faf29bc0108faae0cf0f497d0f4cd228b7600ef0d88532dfafa6349ea7782f28ad7426eeffc155282a9e58a606d25acd8a730dde61a6e5e887d1ba1fea813bb7f2c6caff25")
fuzzJournals(t, input)
} }