signer/core/apitypes: support more input types for eip-712 encoding (#26074)
* apitypes: synchronize handling of types * signer/core/apitypes: improve array check * apitypes: add a test for big.Int -> int32 * signer/core/apitypes: Add a test for parsing addresses from [20]byte, []byte and string * signer/core/apitypes: add some testcases Co-authored-by: Felix Lange <fjl@twurst.com> Co-authored-by: Martin Holst Swende <martin@swende.se>
This commit is contained in:
parent
a51188a163
commit
6d55908347
|
@ -21,6 +21,7 @@ import (
|
|||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/common/hexutil"
|
||||
)
|
||||
|
||||
|
@ -84,6 +85,55 @@ func TestBytesPadding(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestParseAddress(t *testing.T) {
|
||||
tests := []struct {
|
||||
Input interface{}
|
||||
Output []byte // nil => error
|
||||
}{
|
||||
{
|
||||
Input: [20]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14},
|
||||
Output: common.FromHex("0x0000000000000000000000000102030405060708090A0B0C0D0E0F1011121314"),
|
||||
},
|
||||
{
|
||||
Input: "0x0102030405060708090A0B0C0D0E0F1011121314",
|
||||
Output: common.FromHex("0x0000000000000000000000000102030405060708090A0B0C0D0E0F1011121314"),
|
||||
},
|
||||
{
|
||||
Input: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14},
|
||||
Output: common.FromHex("0x0000000000000000000000000102030405060708090A0B0C0D0E0F1011121314"),
|
||||
},
|
||||
// Various error-cases:
|
||||
{Input: "0x000102030405060708090A0B0C0D0E0F1011121314"}, // too long string
|
||||
{Input: "0x01"}, // too short string
|
||||
{Input: ""},
|
||||
{Input: [32]byte{}}, // too long fixed-size array
|
||||
{Input: [21]byte{}}, // too long fixed-size array
|
||||
{Input: make([]byte, 19)}, // too short slice
|
||||
{Input: make([]byte, 21)}, // too long slice
|
||||
{Input: nil},
|
||||
}
|
||||
|
||||
d := TypedData{}
|
||||
for i, test := range tests {
|
||||
val, err := d.EncodePrimitiveValue("address", test.Input, 1)
|
||||
if test.Output == nil {
|
||||
if err == nil {
|
||||
t.Errorf("test %d: expected error, got no error (result %x)", i, val)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("test %d: expected no error, got %v", i, err)
|
||||
}
|
||||
if have, want := len(val), 32; have != want {
|
||||
t.Errorf("test %d: have len %d, want %d", i, have, want)
|
||||
}
|
||||
if !bytes.Equal(val, test.Output) {
|
||||
t.Errorf("test %d: want %x, have %x", i, test.Output, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseBytes(t *testing.T) {
|
||||
for i, tt := range []struct {
|
||||
v interface{}
|
||||
|
@ -98,6 +148,9 @@ func TestParseBytes(t *testing.T) {
|
|||
{"not a hex string", nil},
|
||||
{15, nil},
|
||||
{nil, nil},
|
||||
{[2]byte{12, 34}, []byte{12, 34}},
|
||||
{[8]byte{12, 34, 56, 78, 90, 12, 34, 56}, []byte{12, 34, 56, 78, 90, 12, 34, 56}},
|
||||
{[16]byte{12, 34, 56, 78, 90, 12, 34, 56, 12, 34, 56, 78, 90, 12, 34, 56}, []byte{12, 34, 56, 78, 90, 12, 34, 56, 12, 34, 56, 78, 90, 12, 34, 56}},
|
||||
} {
|
||||
out, ok := parseBytes(tt.v)
|
||||
if tt.exp == nil {
|
||||
|
@ -123,6 +176,7 @@ func TestParseInteger(t *testing.T) {
|
|||
}{
|
||||
{"uint32", "-123", nil},
|
||||
{"int32", "-123", big.NewInt(-123)},
|
||||
{"int32", big.NewInt(-124), big.NewInt(-124)},
|
||||
{"uint32", "0xff", big.NewInt(0xff)},
|
||||
{"int8", "0xffff", nil},
|
||||
} {
|
||||
|
|
|
@ -418,6 +418,14 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter
|
|||
|
||||
// Attempt to parse bytes in different formats: byte array, hex string, hexutil.Bytes.
|
||||
func parseBytes(encType interface{}) ([]byte, bool) {
|
||||
// Handle array types.
|
||||
val := reflect.ValueOf(encType)
|
||||
if val.Kind() == reflect.Array && val.Type().Elem().Kind() == reflect.Uint8 {
|
||||
v := reflect.MakeSlice(reflect.TypeOf([]byte{}), val.Len(), val.Len())
|
||||
reflect.Copy(v, val)
|
||||
return v.Bytes(), true
|
||||
}
|
||||
|
||||
switch v := encType.(type) {
|
||||
case []byte:
|
||||
return v, true
|
||||
|
@ -458,6 +466,8 @@ func parseInteger(encType string, encValue interface{}) (*big.Int, error) {
|
|||
switch v := encValue.(type) {
|
||||
case *math.HexOrDecimal256:
|
||||
b = (*big.Int)(v)
|
||||
case *big.Int:
|
||||
b = v
|
||||
case string:
|
||||
var hexIntValue math.HexOrDecimal256
|
||||
if err := hexIntValue.UnmarshalText([]byte(v)); err != nil {
|
||||
|
@ -490,13 +500,23 @@ func parseInteger(encType string, encValue interface{}) (*big.Int, error) {
|
|||
func (typedData *TypedData) EncodePrimitiveValue(encType string, encValue interface{}, depth int) ([]byte, error) {
|
||||
switch encType {
|
||||
case "address":
|
||||
stringValue, ok := encValue.(string)
|
||||
if !ok || !common.IsHexAddress(stringValue) {
|
||||
return nil, dataMismatchError(encType, encValue)
|
||||
}
|
||||
retval := make([]byte, 32)
|
||||
copy(retval[12:], common.HexToAddress(stringValue).Bytes())
|
||||
return retval, nil
|
||||
switch val := encValue.(type) {
|
||||
case string:
|
||||
if common.IsHexAddress(val) {
|
||||
copy(retval[12:], common.HexToAddress(val).Bytes())
|
||||
return retval, nil
|
||||
}
|
||||
case []byte:
|
||||
if len(val) == 20 {
|
||||
copy(retval[12:], val)
|
||||
return retval, nil
|
||||
}
|
||||
case [20]byte:
|
||||
copy(retval[12:], val[:])
|
||||
return retval, nil
|
||||
}
|
||||
return nil, dataMismatchError(encType, encValue)
|
||||
case "bool":
|
||||
boolValue, ok := encValue.(bool)
|
||||
if !ok {
|
||||
|
|
Loading…
Reference in New Issue