diff --git a/core/vm/runtime/runtime.go b/core/vm/runtime/runtime.go index dd3aa1b0b5..1fa06e9801 100644 --- a/core/vm/runtime/runtime.go +++ b/core/vm/runtime/runtime.go @@ -41,6 +41,7 @@ type Config struct { DisableJit bool // "disable" so it's enabled by default Debug bool + State *state.StateDB GetHashFn func(n uint64) common.Hash } @@ -94,12 +95,14 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) { vm.EnableJit = !cfg.DisableJit vm.Debug = cfg.Debug + if cfg.State == nil { + db, _ := ethdb.NewMemDatabase() + cfg.State, _ = state.New(common.Hash{}, db) + } var ( - db, _ = ethdb.NewMemDatabase() - statedb, _ = state.New(common.Hash{}, db) - vmenv = NewEnv(cfg, statedb) - sender = statedb.CreateAccount(cfg.Origin) - receiver = statedb.CreateAccount(common.StringToAddress("contract")) + vmenv = NewEnv(cfg, cfg.State) + sender = cfg.State.CreateAccount(cfg.Origin) + receiver = cfg.State.CreateAccount(common.StringToAddress("contract")) ) // set the receiver's (the executing contract) code for execution. receiver.SetCode(code) @@ -117,5 +120,43 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) { if cfg.Debug { vm.StdErrFormat(vmenv.StructLogs()) } - return ret, statedb, err + return ret, cfg.State, err +} + +// Call executes the code given by the contract's address. It will return the +// EVM's return value or an error if it failed. +// +// Call, unlike Execute, requires a config and also requires the State field to +// be set. +func Call(address common.Address, input []byte, cfg *Config) ([]byte, error) { + setDefaults(cfg) + + // defer the call to setting back the original values + defer func(debug, forceJit, enableJit bool) { + vm.Debug = debug + vm.ForceJit = forceJit + vm.EnableJit = enableJit + }(vm.Debug, vm.ForceJit, vm.EnableJit) + + vm.ForceJit = !cfg.DisableJit + vm.EnableJit = !cfg.DisableJit + vm.Debug = cfg.Debug + + vmenv := NewEnv(cfg, cfg.State) + + sender := cfg.State.GetOrNewStateObject(cfg.Origin) + // Call the code with the given configuration. + ret, err := vmenv.Call( + sender, + address, + input, + cfg.GasLimit, + cfg.GasPrice, + cfg.Value, + ) + + if cfg.Debug { + vm.StdErrFormat(vmenv.StructLogs()) + } + return ret, err } diff --git a/core/vm/runtime/runtime_test.go b/core/vm/runtime/runtime_test.go index 773a0163e4..e5183052fa 100644 --- a/core/vm/runtime/runtime_test.go +++ b/core/vm/runtime/runtime_test.go @@ -17,12 +17,15 @@ package runtime import ( + "math/big" "strings" "testing" "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/ethdb" ) func TestDefaults(t *testing.T) { @@ -71,6 +74,49 @@ func TestEnvironment(t *testing.T) { }, nil, nil) } +func TestExecute(t *testing.T) { + ret, _, err := Execute([]byte{ + byte(vm.PUSH1), 10, + byte(vm.PUSH1), 0, + byte(vm.MSTORE), + byte(vm.PUSH1), 32, + byte(vm.PUSH1), 0, + byte(vm.RETURN), + }, nil, nil) + if err != nil { + t.Fatal("didn't expect error", err) + } + + num := common.BytesToBig(ret) + if num.Cmp(big.NewInt(10)) != 0 { + t.Error("Expected 10, got", num) + } +} + +func TestCall(t *testing.T) { + db, _ := ethdb.NewMemDatabase() + state, _ := state.New(common.Hash{}, db) + address := common.HexToAddress("0x0a") + state.SetCode(address, []byte{ + byte(vm.PUSH1), 10, + byte(vm.PUSH1), 0, + byte(vm.MSTORE), + byte(vm.PUSH1), 32, + byte(vm.PUSH1), 0, + byte(vm.RETURN), + }) + + ret, err := Call(address, nil, &Config{State: state}) + if err != nil { + t.Fatal("didn't expect error", err) + } + + num := common.BytesToBig(ret) + if num.Cmp(big.NewInt(10)) != 0 { + t.Error("Expected 10, got", num) + } +} + func TestRestoreDefaults(t *testing.T) { Execute(nil, nil, &Config{Debug: true}) if vm.ForceJit { diff --git a/core/vm_env.go b/core/vm_env.go index c8b50debc6..1c787e9824 100644 --- a/core/vm_env.go +++ b/core/vm_env.go @@ -25,6 +25,21 @@ import ( "github.com/ethereum/go-ethereum/core/vm" ) +// GetHashFn returns a function for which the VM env can query block hashes thru +// up to the limit defined by the Yellow Paper and uses the given block chain +// to query for information. +func GetHashFn(ref common.Hash, chain *BlockChain) func(n uint64) common.Hash { + return func(n uint64) common.Hash { + for block := chain.GetBlock(ref); block != nil; block = chain.GetBlock(block.ParentHash()) { + if block.NumberU64() == n { + return block.Hash() + } + } + + return common.Hash{} + } +} + type VMEnv struct { state *state.StateDB header *types.Header @@ -32,17 +47,20 @@ type VMEnv struct { depth int chain *BlockChain typ vm.Type + + getHashFn func(uint64) common.Hash // structured logging logs []vm.StructLog } func NewEnv(state *state.StateDB, chain *BlockChain, msg Message, header *types.Header) *VMEnv { return &VMEnv{ - chain: chain, - state: state, - header: header, - msg: msg, - typ: vm.StdVmTy, + chain: chain, + state: state, + header: header, + msg: msg, + typ: vm.StdVmTy, + getHashFn: GetHashFn(header.ParentHash, chain), } } @@ -59,13 +77,7 @@ func (self *VMEnv) SetDepth(i int) { self.depth = i } func (self *VMEnv) VmType() vm.Type { return self.typ } func (self *VMEnv) SetVmType(t vm.Type) { self.typ = t } func (self *VMEnv) GetHash(n uint64) common.Hash { - for block := self.chain.GetBlock(self.header.ParentHash); block != nil; block = self.chain.GetBlock(block.ParentHash()) { - if block.NumberU64() == n { - return block.Hash() - } - } - - return common.Hash{} + return self.getHashFn(n) } func (self *VMEnv) AddLog(log *vm.Log) {