signer: EIP 712, parse `bytes` and `bytesX` as hex strings + correct padding (#21307)
* Handle hex strings for bytesX types * Add tests for parseBytes * Improve tests * Return nil bytes if error is non-nil * Right-pad instead of left-pad bytes * More tests
This commit is contained in:
parent
c0c01612e9
commit
90dedea40f
|
@ -481,6 +481,24 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter
|
|||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// Attempt to parse bytes in different formats: byte array, hex string, hexutil.Bytes.
|
||||
func parseBytes(encType interface{}) ([]byte, bool) {
|
||||
switch v := encType.(type) {
|
||||
case []byte:
|
||||
return v, true
|
||||
case hexutil.Bytes:
|
||||
return []byte(v), true
|
||||
case string:
|
||||
bytes, err := hexutil.Decode(v)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return bytes, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseInteger(encType string, encValue interface{}) (*big.Int, error) {
|
||||
var (
|
||||
length int
|
||||
|
@ -560,7 +578,7 @@ func (typedData *TypedData) EncodePrimitiveValue(encType string, encValue interf
|
|||
}
|
||||
return crypto.Keccak256([]byte(strVal)), nil
|
||||
case "bytes":
|
||||
bytesValue, ok := encValue.([]byte)
|
||||
bytesValue, ok := parseBytes(encValue)
|
||||
if !ok {
|
||||
return nil, dataMismatchError(encType, encValue)
|
||||
}
|
||||
|
@ -575,10 +593,13 @@ func (typedData *TypedData) EncodePrimitiveValue(encType string, encValue interf
|
|||
if length < 0 || length > 32 {
|
||||
return nil, fmt.Errorf("invalid size on bytes: %d", length)
|
||||
}
|
||||
if byteValue, ok := encValue.(hexutil.Bytes); !ok {
|
||||
if byteValue, ok := parseBytes(encValue); !ok || len(byteValue) != length {
|
||||
return nil, dataMismatchError(encType, encValue)
|
||||
} else {
|
||||
return math.PaddedBigBytes(new(big.Int).SetBytes(byteValue), 32), nil
|
||||
// Right-pad the bits
|
||||
dst := make([]byte, 32)
|
||||
copy(dst, byteValue)
|
||||
return dst, nil
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(encType, "int") || strings.HasPrefix(encType, "uint") {
|
||||
|
|
|
@ -17,10 +17,104 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common/hexutil"
|
||||
)
|
||||
|
||||
func TestBytesPadding(t *testing.T) {
|
||||
tests := []struct {
|
||||
Type string
|
||||
Input []byte
|
||||
Output []byte // nil => error
|
||||
}{
|
||||
{
|
||||
// Fail on wrong length
|
||||
Type: "bytes20",
|
||||
Input: []byte{},
|
||||
Output: nil,
|
||||
},
|
||||
{
|
||||
Type: "bytes1",
|
||||
Input: []byte{1},
|
||||
Output: []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
},
|
||||
{
|
||||
Type: "bytes1",
|
||||
Input: []byte{1, 2},
|
||||
Output: nil,
|
||||
},
|
||||
{
|
||||
Type: "bytes7",
|
||||
Input: []byte{1, 2, 3, 4, 5, 6, 7},
|
||||
Output: []byte{1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
},
|
||||
{
|
||||
Type: "bytes32",
|
||||
Input: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
|
||||
Output: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
|
||||
},
|
||||
{
|
||||
Type: "bytes32",
|
||||
Input: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33},
|
||||
Output: nil,
|
||||
},
|
||||
}
|
||||
|
||||
d := TypedData{}
|
||||
for i, test := range tests {
|
||||
val, err := d.EncodePrimitiveValue(test.Type, test.Input, 1)
|
||||
if test.Output == nil {
|
||||
if err == nil {
|
||||
t.Errorf("test %d: expected error, got no error (result %x)", i, val)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("test %d: expected no error, got %v", i, err)
|
||||
}
|
||||
if len(val) != 32 {
|
||||
t.Errorf("test %d: expected len 32, got %d", i, len(val))
|
||||
}
|
||||
if !bytes.Equal(val, test.Output) {
|
||||
t.Errorf("test %d: expected %x, got %x", i, test.Output, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseBytes(t *testing.T) {
|
||||
for i, tt := range []struct {
|
||||
v interface{}
|
||||
exp []byte
|
||||
}{
|
||||
{"0x", []byte{}},
|
||||
{"0x1234", []byte{0x12, 0x34}},
|
||||
{[]byte{12, 34}, []byte{12, 34}},
|
||||
{hexutil.Bytes([]byte{12, 34}), []byte{12, 34}},
|
||||
{"1234", nil}, // not a proper hex-string
|
||||
{"0x01233", nil}, // nibbles should be rejected
|
||||
{"not a hex string", nil},
|
||||
{15, nil},
|
||||
{nil, nil},
|
||||
} {
|
||||
out, ok := parseBytes(tt.v)
|
||||
if tt.exp == nil {
|
||||
if ok || out != nil {
|
||||
t.Errorf("test %d: expected !ok, got ok = %v with out = %x", i, ok, out)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !ok {
|
||||
t.Errorf("test %d: expected ok got !ok", i)
|
||||
}
|
||||
if !bytes.Equal(out, tt.exp) {
|
||||
t.Errorf("test %d: expected %x got %x", i, tt.exp, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseInteger(t *testing.T) {
|
||||
for i, tt := range []struct {
|
||||
t string
|
||||
|
|
Loading…
Reference in New Issue