core/state: work on fuzztest for journals/state

This commit is contained in:
Martin Holst Swende 2024-11-23 19:48:00 +01:00
parent d07dbb669e
commit fe618e71f2
No known key found for this signature in database
GPG Key ID: 683B438C05A5DDF0
1 changed files with 120 additions and 72 deletions

View File

@ -24,6 +24,7 @@ import (
"io"
"testing"
"crypto/rand"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/holiman/uint256"
@ -157,32 +158,34 @@ func (f *fuzzReader) bytes(n int) []byte {
return r
}
func newEmptyState() *StateDB {
s, _ := New(types.EmptyRootHash, NewDatabaseForTesting())
return s
}
func fuzzJournals(t *testing.T, data []byte) {
var (
reader = fuzzReader{input: bytes.NewReader(data)}
statedbA, _ = New(types.EmptyRootHash, NewDatabaseForTesting())
statedbB, _ = New(types.EmptyRootHash, NewDatabaseForTesting())
journalA = journal(newLinearJournal())
journalB = journal(newSparseJournal())
reader = fuzzReader{input: bytes.NewReader(data)}
journals = []journal{
journal(newLinearJournal()),
journal(newSparseJournal()),
}
stateDbs = []*StateDB{
newEmptyState(),
newEmptyState(),
}
)
apply := func(action func(j journal, sdb *StateDB)) {
action(journalA, statedbA)
action(journalB, statedbB)
}
randAccount := func() *types.StateAccount {
return &types.StateAccount{
Nonce: binary.BigEndian.Uint64(reader.bytes(8)),
Balance: uint256.NewInt(binary.BigEndian.Uint64(reader.bytes(8))),
Root: types.EmptyRootHash,
CodeHash: types.EmptyCodeHash[:],
for i := range journals {
action(journals[i], stateDbs[i])
}
}
crossCheck := func() {
accs1 := journalA.dirtyAccounts()
accs2 := journalA.dirtyAccounts()
if len(accs1) != len(accs2) {
panic(fmt.Sprintf("mismatched accounts: %v %v", accs1, accs2))
crossCheck := func() {
accs1 := journals[0].dirtyAccounts()
accs2 := journals[1].dirtyAccounts()
if len(accs1) != len(accs2) {
panic(fmt.Sprintf("mismatched dirty-sets:\n%v\n%v", accs1, accs2))
}
for _, val := range accs1 {
found := false
@ -199,138 +202,171 @@ func fuzzJournals(t *testing.T, data []byte) {
}
}
}
apply(func(j journal, sdb *StateDB) {
sdb.journal = j
j.snapshot()
})
for !reader.exhausted {
switch reader.byte() % 19 {
op := reader.byte() % 19
switch op {
case 0: // Add account to access lists
addr := common.BytesToAddress(reader.bytes(1))
t.Logf("Op %d: Add to access list %#x", op, addr)
apply(func(j journal, sdb *StateDB) {
j.accessListAddAccount(addr)
sdb.accessList.AddAddress(addr)
})
case 1: // Add slot to access list
addr := common.BytesToAddress(reader.bytes(1))
slot := common.BytesToHash(reader.bytes(1))
t.Logf("Op %d: Add addr:slot to access list %#x : %#x", op, addr, slot)
apply(func(j journal, sdb *StateDB) {
j.accessListAddSlot(addr, slot)
sdb.accessList.AddSlot(addr, slot)
sdb.AddSlotToAccessList(addr, slot)
})
case 2:
var (
addr = common.BytesToAddress(reader.bytes(1))
account = randAccount()
destructed = reader.bool()
newContract = reader.bool()
addr = common.BytesToAddress(reader.bytes(1))
value = uint64(reader.byte())
)
t.Logf("Op %d: Add balance %#x %d", op, addr, value)
apply(func(j journal, sdb *StateDB) {
j.balanceChange(addr, account, destructed, newContract)
sdb.AddBalance(addr, uint256.NewInt(value), 0)
})
case 3:
journalA = journalA.copy()
t.Logf("Op %d: copy journals[0]", op)
journals[0] = journals[0].copy()
stateDbs[0].journal = journals[0]
case 4:
journalB = journalB.copy()
t.Logf("Op %d: copy journals[1]", op)
journals[1] = journals[1].copy()
stateDbs[1].journal = journals[1]
case 5:
addr := common.BytesToAddress(reader.bytes(1))
account := randAccount()
apply(func(j journal, sdb *StateDB) {
j.createContract(addr, account)
var (
addr = common.BytesToAddress(reader.bytes(1))
code = reader.bytes(2)
)
t.Logf("Op %d: (Create and) set code 0x%x", op, addr)
apply(func(j journal, s *StateDB) {
if !s.Exist(addr) {
s.CreateAccount(addr)
}
contractHash := s.GetCodeHash(addr)
emptyCode := contractHash == (common.Hash{}) || contractHash == types.EmptyCodeHash
storageRoot := s.GetStorageRoot(addr)
emptyStorage := storageRoot == (common.Hash{}) || storageRoot == types.EmptyRootHash
if obj := s.getStateObject(addr); obj != nil {
if obj.selfDestructed {
// If it's selfdestructed, we cannot create into it
return
}
}
if s.GetNonce(addr) == 0 && emptyCode && emptyStorage {
s.CreateContract(addr)
// We also set some code here, to prevent the
// CreateContract action from being performed twice in a row,
// which would cause a difference in state when unrolling
// the linearJournal. (CreateContact assumes created was false prior to
// invocation, and the linearJournal rollback sets it to false).
s.SetCode(addr, code)
}
})
case 6:
addr := common.BytesToAddress(reader.bytes(1))
t.Logf("Op %d: Create 0x%x", op, addr)
apply(func(j journal, sdb *StateDB) {
j.createObject(addr)
if !sdb.Exist(addr) {
sdb.CreateAccount(addr)
}
})
case 7:
addr := common.BytesToAddress(reader.bytes(1))
account := randAccount()
apply(func(j journal, sdb *StateDB) {
j.destruct(addr, account)
t.Logf("Op %d: (Create and) destruct 0x%x", op, addr)
apply(func(j journal, s *StateDB) {
if !s.Exist(addr) {
s.CreateAccount(addr)
}
s.SelfDestruct(addr)
})
case 8:
txHash := common.BytesToHash(reader.bytes(1))
t.Logf("Op %d: Add log %#x", op, txHash)
apply(func(j journal, sdb *StateDB) {
sdb.logs[txHash] = append(sdb.logs[txHash], new(types.Log))
sdb.logSize++
j.logChange(txHash)
})
case 9:
var (
addr = common.BytesToAddress(reader.bytes(1))
account = randAccount()
destructed = reader.bool()
newContract = reader.bool()
addr = common.BytesToAddress(reader.bytes(1))
nonce = binary.BigEndian.Uint64(reader.bytes(8))
)
t.Logf("Op %d: Set nonce %#x %d", op, addr, nonce)
apply(func(j journal, sdb *StateDB) {
j.nonceChange(addr, account, destructed, newContract)
sdb.SetNonce(addr, nonce)
})
case 10:
refund := uint64(reader.byte())
t.Logf("Op %d: Set refund %d", op, refund)
apply(func(j journal, sdb *StateDB) {
j.refundChange(uint64(refund))
})
case 11:
var (
addr = common.BytesToAddress(reader.bytes(1))
account = randAccount()
)
apply(func(j journal, sdb *StateDB) {
j.setCode(addr, account)
})
case 12:
var (
addr = common.BytesToAddress(reader.bytes(1))
key = common.BytesToHash(reader.bytes(1))
prev = common.BytesToHash(reader.bytes(1))
origin = common.BytesToHash(reader.bytes(1))
addr = common.BytesToAddress(reader.bytes(1))
key = common.BytesToHash(reader.bytes(1))
val = common.BytesToHash(reader.bytes(1))
)
t.Logf("Op %d: Set storage %#x [%#x]=%#x", op, addr, key, val)
apply(func(j journal, sdb *StateDB) {
j.storageChange(addr, key, prev, origin)
sdb.SetState(addr, key, val)
})
case 13:
var (
addr = common.BytesToAddress(reader.bytes(1))
account = randAccount()
destructed = reader.bool()
newContract = reader.bool()
addr = common.BytesToAddress(reader.bytes(1))
)
t.Logf("Op %d: Zero-balance transfer (touch) %#x", op, addr)
apply(func(j journal, sdb *StateDB) {
j.touchChange(addr, account, destructed, newContract)
sdb.AddBalance(addr, uint256.NewInt(0), 0)
})
case 14:
addr := common.BytesToAddress(reader.bytes(1))
key := common.BytesToHash(reader.bytes(1))
prev := common.BytesToHash(reader.bytes(1))
var (
addr = common.BytesToAddress(reader.bytes(1))
key = common.BytesToHash(reader.bytes(1))
value = common.BytesToHash(reader.bytes(1))
)
t.Logf("Op %d: Set t-storage %#x [%#x]=%#x", op, addr, key, value)
apply(func(j journal, sdb *StateDB) {
j.transientStateChange(addr, key, prev)
sdb.SetTransientState(addr, key, value)
})
case 15:
t.Logf("Op %d: reset journal", op)
apply(func(j journal, sdb *StateDB) {
j.reset()
})
case 16:
t.Logf("Op %d: snapshot", op)
apply(func(j journal, sdb *StateDB) {
j.snapshot()
sdb.Snapshot()
})
case 17:
t.Logf("Op %d: discard snapshot", op)
apply(func(j journal, sdb *StateDB) {
j.discardSnapshot()
sdb.DiscardSnapshot()
})
case 18:
t.Logf("Op %d: revert snapshot", op)
apply(func(j journal, sdb *StateDB) {
j.revertSnapshot(statedbA)
sdb.RevertSnapshot()
})
}
crossCheck()
}
h1, err1 := statedbA.Commit(0, false)
h2, err2 := statedbB.Commit(0, false)
h1, err1 := stateDbs[0].Commit(0, false)
h2, err2 := stateDbs[1].Commit(0, false)
if err1 != err2 {
panic(fmt.Sprintf("mismatched errors: %v %v", err1, err2))
}
@ -342,3 +378,15 @@ func fuzzJournals(t *testing.T, data []byte) {
func FuzzJournals(f *testing.F) {
f.Fuzz(fuzzJournals)
}
func TestFuzzJournals(t *testing.T) {
input := make([]byte, 200)
for i := 0; i < 200; i++ {
rand.Read(input)
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
t.Parallel()
t.Logf("input: %x", input)
fuzzJournals(t, input)
})
}
}