diff --git a/core/state/statedb.go b/core/state/statedb.go index 754892493f..ef159a4c22 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -306,28 +306,20 @@ func (s *StateDB) Empty(addr common.Address) bool { // GetBalance retrieves the balance from the given address or 0 if object not found func (s *StateDB) GetBalance(addr common.Address) *uint256.Int { - bal := common.U2560 stateObject := s.getStateObject(addr) if stateObject != nil { - bal = stateObject.Balance() + return stateObject.Balance() } - if s.logger != nil && s.logger.OnBalanceRead != nil { - s.logger.OnBalanceRead(addr, bal.ToBig()) - } - return bal + return common.U2560 } // GetNonce retrieves the nonce from the given address or 0 if object not found func (s *StateDB) GetNonce(addr common.Address) uint64 { - var nonce uint64 stateObject := s.getStateObject(addr) if stateObject != nil { - nonce = stateObject.Nonce() + return stateObject.Nonce() } - if s.logger != nil && s.logger.OnNonceRead != nil { - s.logger.OnNonceRead(addr, nonce) - } - return nonce + return 0 } // GetStorageRoot retrieves the storage root from the given address or empty @@ -346,52 +338,36 @@ func (s *StateDB) TxIndex() int { } func (s *StateDB) GetCode(addr common.Address) []byte { - var code []byte stateObject := s.getStateObject(addr) if stateObject != nil { - code = stateObject.Code() + return stateObject.Code() } - if s.logger != nil && s.logger.OnCodeRead != nil { - s.logger.OnCodeRead(addr, code) - } - return code + return nil } func (s *StateDB) GetCodeSize(addr common.Address) int { - var size int stateObject := s.getStateObject(addr) if stateObject != nil { - size = stateObject.CodeSize() + return stateObject.CodeSize() } - if s.logger != nil && s.logger.OnCodeSizeRead != nil { - s.logger.OnCodeSizeRead(addr, size) - } - return size + return 0 } func (s *StateDB) GetCodeHash(addr common.Address) common.Hash { - hash := common.Hash{} stateObject := s.getStateObject(addr) if stateObject != nil { - hash = common.BytesToHash(stateObject.CodeHash()) + return common.BytesToHash(stateObject.CodeHash()) } - if s.logger != nil && s.logger.OnCodeHashRead != nil { - s.logger.OnCodeHashRead(addr, hash) - } - return hash + return common.Hash{} } // GetState retrieves the value associated with the specific key. func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash { - val := common.Hash{} stateObject := s.getStateObject(addr) if stateObject != nil { - val = stateObject.GetState(hash) + return stateObject.GetState(hash) } - if s.logger != nil && s.logger.OnStorageRead != nil { - s.logger.OnStorageRead(addr, hash, val) - } - return val + return common.Hash{} } // GetCommittedState retrieves the value associated with the specific key diff --git a/core/state/statedb_hooked.go b/core/state/statedb_hooked.go index 55b53ded40..e7ec0228f0 100644 --- a/core/state/statedb_hooked.go +++ b/core/state/statedb_hooked.go @@ -54,23 +54,43 @@ func (s *hookedStateDB) CreateContract(addr common.Address) { } func (s *hookedStateDB) GetBalance(addr common.Address) *uint256.Int { - return s.inner.GetBalance(addr) + bal := s.inner.GetBalance(addr) + if s.hooks.OnBalanceRead != nil { + s.hooks.OnBalanceRead(addr, bal.ToBig()) + } + return bal } func (s *hookedStateDB) GetNonce(addr common.Address) uint64 { - return s.inner.GetNonce(addr) + nonce := s.inner.GetNonce(addr) + if s.hooks.OnNonceRead != nil { + s.hooks.OnNonceRead(addr, nonce) + } + return nonce } func (s *hookedStateDB) GetCodeHash(addr common.Address) common.Hash { - return s.inner.GetCodeHash(addr) + codeHash := s.inner.GetCodeHash(addr) + if s.hooks.OnCodeHashRead != nil { + s.hooks.OnCodeHashRead(addr, codeHash) + } + return codeHash } func (s *hookedStateDB) GetCode(addr common.Address) []byte { - return s.inner.GetCode(addr) + code := s.inner.GetCode(addr) + if s.hooks.OnCodeRead != nil { + s.hooks.OnCodeRead(addr, code) + } + return code } func (s *hookedStateDB) GetCodeSize(addr common.Address) int { - return s.inner.GetCodeSize(addr) + size := s.inner.GetCodeSize(addr) + if s.hooks.OnCodeSizeRead != nil { + s.hooks.OnCodeSizeRead(addr, size) + } + return size } func (s *hookedStateDB) AddRefund(u uint64) { @@ -90,7 +110,11 @@ func (s *hookedStateDB) GetCommittedState(addr common.Address, hash common.Hash) } func (s *hookedStateDB) GetState(addr common.Address, hash common.Hash) common.Hash { - return s.inner.GetState(addr, hash) + val := s.inner.GetState(addr, hash) + if s.hooks.OnStorageRead != nil { + s.hooks.OnStorageRead(addr, hash, val) + } + return val } func (s *hookedStateDB) GetStorageRoot(addr common.Address) common.Hash { diff --git a/core/state/statedb_hooked_test.go b/core/state/statedb_hooked_test.go index 9abd76b02d..9fd3ebc95e 100644 --- a/core/state/statedb_hooked_test.go +++ b/core/state/statedb_hooked_test.go @@ -90,6 +90,12 @@ func TestHooks(t *testing.T) { "0xaa00000000000000000000000000000000000000.storage slot 0x0000000000000000000000000000000000000000000000000000000000000001: 0x0000000000000000000000000000000000000000000000000000000000000000 ->0x0000000000000000000000000000000000000000000000000000000000000011", "0xaa00000000000000000000000000000000000000.storage slot 0x0000000000000000000000000000000000000000000000000000000000000001: 0x0000000000000000000000000000000000000000000000000000000000000011 ->0x0000000000000000000000000000000000000000000000000000000000000022", "log 100", + "0xaa00000000000000000000000000000000000000.balance read: 50", + "0xaa00000000000000000000000000000000000000.nonce read: 1337", + "0xaa00000000000000000000000000000000000000.code read: [19 37]", + "0xaa00000000000000000000000000000000000000.storage read 0x0000000000000000000000000000000000000000000000000000000000000001: 0x0000000000000000000000000000000000000000000000000000000000000022", + "0xaa00000000000000000000000000000000000000.code size read: 2", + "0xaa00000000000000000000000000000000000000.code hash read: 0xa12ae05590de0c93a00bc7ac773c2fdb621e44f814985e72194f921c0050f728", } emitF := func(format string, a ...any) { result = append(result, fmt.Sprintf(format, a...)) @@ -110,6 +116,24 @@ func TestHooks(t *testing.T) { OnLog: func(log *types.Log) { emitF("log %v", log.TxIndex) }, + OnBalanceRead: func(addr common.Address, bal *big.Int) { + emitF("%v.balance read: %v", addr, bal) + }, + OnNonceRead: func(addr common.Address, nonce uint64) { + emitF("%v.nonce read: %v", addr, nonce) + }, + OnCodeRead: func(addr common.Address, code []byte) { + emitF("%v.code read: %v", addr, code) + }, + OnStorageRead: func(addr common.Address, slot common.Hash, value common.Hash) { + emitF("%v.storage read %v: %v", addr, slot, value) + }, + OnCodeSizeRead: func(addr common.Address, size int) { + emitF("%v.code size read: %v", addr, size) + }, + OnCodeHashRead: func(addr common.Address, hash common.Hash) { + emitF("%v.code hash read: %v", addr, hash) + }, }) sdb.AddBalance(common.Address{0xaa}, uint256.NewInt(100), tracing.BalanceChangeUnspecified) sdb.SubBalance(common.Address{0xaa}, uint256.NewInt(50), tracing.BalanceChangeTransfer) @@ -122,6 +146,12 @@ func TestHooks(t *testing.T) { sdb.AddLog(&types.Log{ Address: common.Address{0xbb}, }) + sdb.GetBalance(common.Address{0xaa}) + sdb.GetNonce(common.Address{0xaa}) + sdb.GetCode(common.Address{0xaa}) + sdb.GetState(common.Address{0xaa}, common.HexToHash("0x01")) + sdb.GetCodeSize(common.Address{0xaa}) + sdb.GetCodeHash(common.Address{0xaa}) for i, want := range wants { if have := result[i]; have != want { t.Fatalf("error event %d, have\n%v\nwant%v\n", i, have, want)