mv read hooks to statedb_hooked

This commit is contained in:
Sina Mahmoodi 2024-10-24 06:09:44 +02:00
parent fbd1d19cdb
commit 0f005af66a
3 changed files with 72 additions and 42 deletions

View File

@ -306,28 +306,20 @@ func (s *StateDB) Empty(addr common.Address) bool {
// GetBalance retrieves the balance from the given address or 0 if object not found // GetBalance retrieves the balance from the given address or 0 if object not found
func (s *StateDB) GetBalance(addr common.Address) *uint256.Int { func (s *StateDB) GetBalance(addr common.Address) *uint256.Int {
bal := common.U2560
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
bal = stateObject.Balance() return stateObject.Balance()
} }
if s.logger != nil && s.logger.OnBalanceRead != nil { return common.U2560
s.logger.OnBalanceRead(addr, bal.ToBig())
}
return bal
} }
// GetNonce retrieves the nonce from the given address or 0 if object not found // GetNonce retrieves the nonce from the given address or 0 if object not found
func (s *StateDB) GetNonce(addr common.Address) uint64 { func (s *StateDB) GetNonce(addr common.Address) uint64 {
var nonce uint64
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
nonce = stateObject.Nonce() return stateObject.Nonce()
} }
if s.logger != nil && s.logger.OnNonceRead != nil { return 0
s.logger.OnNonceRead(addr, nonce)
}
return nonce
} }
// GetStorageRoot retrieves the storage root from the given address or empty // GetStorageRoot retrieves the storage root from the given address or empty
@ -346,52 +338,36 @@ func (s *StateDB) TxIndex() int {
} }
func (s *StateDB) GetCode(addr common.Address) []byte { func (s *StateDB) GetCode(addr common.Address) []byte {
var code []byte
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
code = stateObject.Code() return stateObject.Code()
} }
if s.logger != nil && s.logger.OnCodeRead != nil { return nil
s.logger.OnCodeRead(addr, code)
}
return code
} }
func (s *StateDB) GetCodeSize(addr common.Address) int { func (s *StateDB) GetCodeSize(addr common.Address) int {
var size int
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
size = stateObject.CodeSize() return stateObject.CodeSize()
} }
if s.logger != nil && s.logger.OnCodeSizeRead != nil { return 0
s.logger.OnCodeSizeRead(addr, size)
}
return size
} }
func (s *StateDB) GetCodeHash(addr common.Address) common.Hash { func (s *StateDB) GetCodeHash(addr common.Address) common.Hash {
hash := common.Hash{}
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
hash = common.BytesToHash(stateObject.CodeHash()) return common.BytesToHash(stateObject.CodeHash())
} }
if s.logger != nil && s.logger.OnCodeHashRead != nil { return common.Hash{}
s.logger.OnCodeHashRead(addr, hash)
}
return hash
} }
// GetState retrieves the value associated with the specific key. // GetState retrieves the value associated with the specific key.
func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash { func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash {
val := common.Hash{}
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
val = stateObject.GetState(hash) return stateObject.GetState(hash)
} }
if s.logger != nil && s.logger.OnStorageRead != nil { return common.Hash{}
s.logger.OnStorageRead(addr, hash, val)
}
return val
} }
// GetCommittedState retrieves the value associated with the specific key // GetCommittedState retrieves the value associated with the specific key

View File

@ -54,23 +54,43 @@ func (s *hookedStateDB) CreateContract(addr common.Address) {
} }
func (s *hookedStateDB) GetBalance(addr common.Address) *uint256.Int { func (s *hookedStateDB) GetBalance(addr common.Address) *uint256.Int {
return s.inner.GetBalance(addr) bal := s.inner.GetBalance(addr)
if s.hooks.OnBalanceRead != nil {
s.hooks.OnBalanceRead(addr, bal.ToBig())
}
return bal
} }
func (s *hookedStateDB) GetNonce(addr common.Address) uint64 { func (s *hookedStateDB) GetNonce(addr common.Address) uint64 {
return s.inner.GetNonce(addr) nonce := s.inner.GetNonce(addr)
if s.hooks.OnNonceRead != nil {
s.hooks.OnNonceRead(addr, nonce)
}
return nonce
} }
func (s *hookedStateDB) GetCodeHash(addr common.Address) common.Hash { func (s *hookedStateDB) GetCodeHash(addr common.Address) common.Hash {
return s.inner.GetCodeHash(addr) codeHash := s.inner.GetCodeHash(addr)
if s.hooks.OnCodeHashRead != nil {
s.hooks.OnCodeHashRead(addr, codeHash)
}
return codeHash
} }
func (s *hookedStateDB) GetCode(addr common.Address) []byte { func (s *hookedStateDB) GetCode(addr common.Address) []byte {
return s.inner.GetCode(addr) code := s.inner.GetCode(addr)
if s.hooks.OnCodeRead != nil {
s.hooks.OnCodeRead(addr, code)
}
return code
} }
func (s *hookedStateDB) GetCodeSize(addr common.Address) int { func (s *hookedStateDB) GetCodeSize(addr common.Address) int {
return s.inner.GetCodeSize(addr) size := s.inner.GetCodeSize(addr)
if s.hooks.OnCodeSizeRead != nil {
s.hooks.OnCodeSizeRead(addr, size)
}
return size
} }
func (s *hookedStateDB) AddRefund(u uint64) { func (s *hookedStateDB) AddRefund(u uint64) {
@ -90,7 +110,11 @@ func (s *hookedStateDB) GetCommittedState(addr common.Address, hash common.Hash)
} }
func (s *hookedStateDB) GetState(addr common.Address, hash common.Hash) common.Hash { func (s *hookedStateDB) GetState(addr common.Address, hash common.Hash) common.Hash {
return s.inner.GetState(addr, hash) val := s.inner.GetState(addr, hash)
if s.hooks.OnStorageRead != nil {
s.hooks.OnStorageRead(addr, hash, val)
}
return val
} }
func (s *hookedStateDB) GetStorageRoot(addr common.Address) common.Hash { func (s *hookedStateDB) GetStorageRoot(addr common.Address) common.Hash {

View File

@ -90,6 +90,12 @@ func TestHooks(t *testing.T) {
"0xaa00000000000000000000000000000000000000.storage slot 0x0000000000000000000000000000000000000000000000000000000000000001: 0x0000000000000000000000000000000000000000000000000000000000000000 ->0x0000000000000000000000000000000000000000000000000000000000000011", "0xaa00000000000000000000000000000000000000.storage slot 0x0000000000000000000000000000000000000000000000000000000000000001: 0x0000000000000000000000000000000000000000000000000000000000000000 ->0x0000000000000000000000000000000000000000000000000000000000000011",
"0xaa00000000000000000000000000000000000000.storage slot 0x0000000000000000000000000000000000000000000000000000000000000001: 0x0000000000000000000000000000000000000000000000000000000000000011 ->0x0000000000000000000000000000000000000000000000000000000000000022", "0xaa00000000000000000000000000000000000000.storage slot 0x0000000000000000000000000000000000000000000000000000000000000001: 0x0000000000000000000000000000000000000000000000000000000000000011 ->0x0000000000000000000000000000000000000000000000000000000000000022",
"log 100", "log 100",
"0xaa00000000000000000000000000000000000000.balance read: 50",
"0xaa00000000000000000000000000000000000000.nonce read: 1337",
"0xaa00000000000000000000000000000000000000.code read: [19 37]",
"0xaa00000000000000000000000000000000000000.storage read 0x0000000000000000000000000000000000000000000000000000000000000001: 0x0000000000000000000000000000000000000000000000000000000000000022",
"0xaa00000000000000000000000000000000000000.code size read: 2",
"0xaa00000000000000000000000000000000000000.code hash read: 0xa12ae05590de0c93a00bc7ac773c2fdb621e44f814985e72194f921c0050f728",
} }
emitF := func(format string, a ...any) { emitF := func(format string, a ...any) {
result = append(result, fmt.Sprintf(format, a...)) result = append(result, fmt.Sprintf(format, a...))
@ -110,6 +116,24 @@ func TestHooks(t *testing.T) {
OnLog: func(log *types.Log) { OnLog: func(log *types.Log) {
emitF("log %v", log.TxIndex) emitF("log %v", log.TxIndex)
}, },
OnBalanceRead: func(addr common.Address, bal *big.Int) {
emitF("%v.balance read: %v", addr, bal)
},
OnNonceRead: func(addr common.Address, nonce uint64) {
emitF("%v.nonce read: %v", addr, nonce)
},
OnCodeRead: func(addr common.Address, code []byte) {
emitF("%v.code read: %v", addr, code)
},
OnStorageRead: func(addr common.Address, slot common.Hash, value common.Hash) {
emitF("%v.storage read %v: %v", addr, slot, value)
},
OnCodeSizeRead: func(addr common.Address, size int) {
emitF("%v.code size read: %v", addr, size)
},
OnCodeHashRead: func(addr common.Address, hash common.Hash) {
emitF("%v.code hash read: %v", addr, hash)
},
}) })
sdb.AddBalance(common.Address{0xaa}, uint256.NewInt(100), tracing.BalanceChangeUnspecified) sdb.AddBalance(common.Address{0xaa}, uint256.NewInt(100), tracing.BalanceChangeUnspecified)
sdb.SubBalance(common.Address{0xaa}, uint256.NewInt(50), tracing.BalanceChangeTransfer) sdb.SubBalance(common.Address{0xaa}, uint256.NewInt(50), tracing.BalanceChangeTransfer)
@ -122,6 +146,12 @@ func TestHooks(t *testing.T) {
sdb.AddLog(&types.Log{ sdb.AddLog(&types.Log{
Address: common.Address{0xbb}, Address: common.Address{0xbb},
}) })
sdb.GetBalance(common.Address{0xaa})
sdb.GetNonce(common.Address{0xaa})
sdb.GetCode(common.Address{0xaa})
sdb.GetState(common.Address{0xaa}, common.HexToHash("0x01"))
sdb.GetCodeSize(common.Address{0xaa})
sdb.GetCodeHash(common.Address{0xaa})
for i, want := range wants { for i, want := range wants {
if have := result[i]; have != want { if have := result[i]; have != want {
t.Fatalf("error event %d, have\n%v\nwant%v\n", i, have, want) t.Fatalf("error event %d, have\n%v\nwant%v\n", i, have, want)