diff --git a/rpc/args.go b/rpc/args.go index d3449928bd..b23216c987 100644 --- a/rpc/args.go +++ b/rpc/args.go @@ -388,8 +388,12 @@ func (args *Sha3Args) UnmarshalJSON(b []byte) (err error) { if len(obj) < 1 { return NewInsufficientParamsError(len(obj), 1) } - args.Data = obj[0].(string) + argstr, ok := obj[0].(string) + if !ok { + return NewInvalidTypeError("data", "is not a string") + } + args.Data = argstr return nil } diff --git a/rpc/args_test.go b/rpc/args_test.go index fa46f65150..0c7360c53b 100644 --- a/rpc/args_test.go +++ b/rpc/args_test.go @@ -10,18 +10,6 @@ import ( "github.com/ethereum/go-ethereum/common" ) -func TestSha3(t *testing.T) { - input := `["0x68656c6c6f20776f726c64"]` - expected := "0x68656c6c6f20776f726c64" - - args := new(Sha3Args) - json.Unmarshal([]byte(input), &args) - - if args.Data != expected { - t.Error("got %s expected %s", input, expected) - } -} - func ExpectValidationError(err error) string { var str string switch err.(type) { @@ -74,6 +62,47 @@ func ExpectDecodeParamError(err error) string { return str } +func TestSha3(t *testing.T) { + input := `["0x68656c6c6f20776f726c64"]` + expected := "0x68656c6c6f20776f726c64" + + args := new(Sha3Args) + json.Unmarshal([]byte(input), &args) + + if args.Data != expected { + t.Error("got %s expected %s", input, expected) + } +} + +func TestSha3ArgsInvalid(t *testing.T) { + input := `{}` + + args := new(Sha3Args) + str := ExpectDecodeParamError(json.Unmarshal([]byte(input), &args)) + if len(str) > 0 { + t.Error(str) + } +} + +func TestSha3ArgsEmpty(t *testing.T) { + input := `[]` + + args := new(Sha3Args) + str := ExpectInsufficientParamsError(json.Unmarshal([]byte(input), &args)) + if len(str) > 0 { + t.Error(str) + } +} +func TestSha3ArgsDataInvalid(t *testing.T) { + input := `[4]` + + args := new(Sha3Args) + str := ExpectInvalidTypeError(json.Unmarshal([]byte(input), &args)) + if len(str) > 0 { + t.Error(str) + } +} + func TestGetBalanceArgs(t *testing.T) { input := `["0x407d73d8a49eeb85d32cf465507dd71d507100c1", "0x1f"]` expected := new(GetBalanceArgs) @@ -119,6 +148,8 @@ func TestGetBalanceArgsEmpty(t *testing.T) { args := new(GetBalanceArgs) str := ExpectInsufficientParamsError(json.Unmarshal([]byte(input), &args)) + if len(str) > 0 { + t.Error(str) } }