signer/core: extended support for EIP-712 array types (#30620)
This change updates the EIP-712 implementation to resolve [#30619](https://github.com/ethereum/go-ethereum/issues/30619). The test cases have been repurposed from the ethers.js [repository](https://github.com/ethers-io/ethers.js/blob/main/testcases/typed-data.json.gz), but have been updated to remove tests that don't have a valid domain separator; EIP-712 messages without a domain separator are not supported by geth. --------- Co-authored-by: Martin Holst Swende <martin@swende.se>
This commit is contained in:
parent
f476702ea7
commit
b362c37e3f
|
@ -18,12 +18,18 @@ package apitypes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/common"
|
"github.com/ethereum/go-ethereum/common"
|
||||||
"github.com/ethereum/go-ethereum/common/hexutil"
|
"github.com/ethereum/go-ethereum/common/hexutil"
|
||||||
"github.com/ethereum/go-ethereum/common/math"
|
"github.com/ethereum/go-ethereum/common/math"
|
||||||
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBytesPadding(t *testing.T) {
|
func TestBytesPadding(t *testing.T) {
|
||||||
|
@ -244,45 +250,42 @@ func TestConvertAddressDataToSlice(t *testing.T) {
|
||||||
func TestTypedDataArrayValidate(t *testing.T) {
|
func TestTypedDataArrayValidate(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
typedData := TypedData{
|
type testDataInput struct {
|
||||||
Types: Types{
|
Name string `json:"name"`
|
||||||
"BulkOrder": []Type{
|
Domain TypedDataDomain `json:"domain"`
|
||||||
// Should be able to accept fixed size arrays
|
PrimaryType string `json:"primaryType"`
|
||||||
{Name: "tree", Type: "OrderComponents[2][2]"},
|
Types Types `json:"types"`
|
||||||
},
|
Message TypedDataMessage `json:"data"`
|
||||||
"OrderComponents": []Type{
|
Digest string `json:"digest"`
|
||||||
{Name: "offerer", Type: "address"},
|
|
||||||
{Name: "amount", Type: "uint8"},
|
|
||||||
},
|
|
||||||
"EIP712Domain": []Type{
|
|
||||||
{Name: "name", Type: "string"},
|
|
||||||
{Name: "version", Type: "string"},
|
|
||||||
{Name: "chainId", Type: "uint8"},
|
|
||||||
{Name: "verifyingContract", Type: "address"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
PrimaryType: "BulkOrder",
|
|
||||||
Domain: TypedDataDomain{
|
|
||||||
VerifyingContract: "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC",
|
|
||||||
},
|
|
||||||
Message: TypedDataMessage{},
|
|
||||||
}
|
}
|
||||||
|
fc, err := os.ReadFile("./testdata/typed-data.json")
|
||||||
|
require.NoError(t, err, "error reading test data file")
|
||||||
|
|
||||||
if err := typedData.validate(); err != nil {
|
var tests []testDataInput
|
||||||
t.Errorf("expected typed data to pass validation, got: %v", err)
|
err = json.Unmarshal(fc, &tests)
|
||||||
}
|
require.NoError(t, err, "error unmarshalling test data file contents")
|
||||||
|
|
||||||
// Should be able to accept dynamic arrays
|
for _, tc := range tests {
|
||||||
typedData.Types["BulkOrder"][0].Type = "OrderComponents[]"
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
if err := typedData.validate(); err != nil {
|
td := TypedData{
|
||||||
t.Errorf("expected typed data to pass validation, got: %v", err)
|
Types: tc.Types,
|
||||||
}
|
PrimaryType: tc.PrimaryType,
|
||||||
|
Domain: tc.Domain,
|
||||||
|
Message: tc.Message,
|
||||||
|
}
|
||||||
|
|
||||||
// Should be able to accept standard types
|
domainSeparator, tErr := td.HashStruct("EIP712Domain", td.Domain.Map())
|
||||||
typedData.Types["BulkOrder"][0].Type = "OrderComponents"
|
assert.NoError(t, tErr, "failed to hash domain separator: %v", tErr)
|
||||||
|
|
||||||
if err := typedData.validate(); err != nil {
|
messageHash, tErr := td.HashStruct(td.PrimaryType, td.Message)
|
||||||
t.Errorf("expected typed data to pass validation, got: %v", err)
|
assert.NoError(t, tErr, "failed to hash message: %v", tErr)
|
||||||
|
|
||||||
|
digest := crypto.Keccak256Hash([]byte(fmt.Sprintf("%s%s%s", "\x19\x01", string(domainSeparator), string(messageHash))))
|
||||||
|
assert.Equal(t, tc.Digest, digest.String(), "digest doesn't not match")
|
||||||
|
|
||||||
|
assert.NoError(t, td.validate(), "validation failed", tErr)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -325,18 +325,17 @@ type Type struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isArray returns true if the type is a fixed or variable sized array.
|
||||||
|
// This method may return false positives, in case the Type is not a valid
|
||||||
|
// expression, e.g. "fooo[[[[".
|
||||||
func (t *Type) isArray() bool {
|
func (t *Type) isArray() bool {
|
||||||
return strings.HasSuffix(t.Type, "[]")
|
return strings.IndexByte(t.Type, '[') > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// typeName returns the canonical name of the type. If the type is 'Person[]', then
|
// typeName returns the canonical name of the type. If the type is 'Person[]' or 'Person[2]', then
|
||||||
// this method returns 'Person'
|
// this method returns 'Person'
|
||||||
func (t *Type) typeName() string {
|
func (t *Type) typeName() string {
|
||||||
if strings.Contains(t.Type, "[") {
|
return strings.Split(t.Type, "[")[0]
|
||||||
re := regexp.MustCompile(`\[\d*\]`)
|
|
||||||
return re.ReplaceAllString(t.Type, "")
|
|
||||||
}
|
|
||||||
return t.Type
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Types map[string][]Type
|
type Types map[string][]Type
|
||||||
|
@ -387,7 +386,7 @@ func (typedData *TypedData) HashStruct(primaryType string, data TypedDataMessage
|
||||||
|
|
||||||
// Dependencies returns an array of custom types ordered by their hierarchical reference tree
|
// Dependencies returns an array of custom types ordered by their hierarchical reference tree
|
||||||
func (typedData *TypedData) Dependencies(primaryType string, found []string) []string {
|
func (typedData *TypedData) Dependencies(primaryType string, found []string) []string {
|
||||||
primaryType = strings.TrimSuffix(primaryType, "[]")
|
primaryType = strings.Split(primaryType, "[")[0]
|
||||||
|
|
||||||
if slices.Contains(found, primaryType) {
|
if slices.Contains(found, primaryType) {
|
||||||
return found
|
return found
|
||||||
|
@ -465,34 +464,11 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter
|
||||||
encType := field.Type
|
encType := field.Type
|
||||||
encValue := data[field.Name]
|
encValue := data[field.Name]
|
||||||
if encType[len(encType)-1:] == "]" {
|
if encType[len(encType)-1:] == "]" {
|
||||||
arrayValue, err := convertDataToSlice(encValue)
|
encodedData, err := typedData.encodeArrayValue(encValue, encType, depth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, dataMismatchError(encType, encValue)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
buffer.Write(encodedData)
|
||||||
arrayBuffer := bytes.Buffer{}
|
|
||||||
parsedType := strings.Split(encType, "[")[0]
|
|
||||||
for _, item := range arrayValue {
|
|
||||||
if typedData.Types[parsedType] != nil {
|
|
||||||
mapValue, ok := item.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return nil, dataMismatchError(parsedType, item)
|
|
||||||
}
|
|
||||||
encodedData, err := typedData.EncodeData(parsedType, mapValue, depth+1)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
arrayBuffer.Write(crypto.Keccak256(encodedData))
|
|
||||||
} else {
|
|
||||||
bytesValue, err := typedData.EncodePrimitiveValue(parsedType, item, depth)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
arrayBuffer.Write(bytesValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer.Write(crypto.Keccak256(arrayBuffer.Bytes()))
|
|
||||||
} else if typedData.Types[field.Type] != nil {
|
} else if typedData.Types[field.Type] != nil {
|
||||||
mapValue, ok := encValue.(map[string]interface{})
|
mapValue, ok := encValue.(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -514,6 +490,46 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter
|
||||||
return buffer.Bytes(), nil
|
return buffer.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (typedData *TypedData) encodeArrayValue(encValue interface{}, encType string, depth int) (hexutil.Bytes, error) {
|
||||||
|
arrayValue, err := convertDataToSlice(encValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, dataMismatchError(encType, encValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayBuffer := new(bytes.Buffer)
|
||||||
|
parsedType := strings.Split(encType, "[")[0]
|
||||||
|
for _, item := range arrayValue {
|
||||||
|
if reflect.TypeOf(item).Kind() == reflect.Slice ||
|
||||||
|
reflect.TypeOf(item).Kind() == reflect.Array {
|
||||||
|
encodedData, err := typedData.encodeArrayValue(item, parsedType, depth+1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
arrayBuffer.Write(encodedData)
|
||||||
|
} else {
|
||||||
|
if typedData.Types[parsedType] != nil {
|
||||||
|
mapValue, ok := item.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, dataMismatchError(parsedType, item)
|
||||||
|
}
|
||||||
|
encodedData, err := typedData.EncodeData(parsedType, mapValue, depth+1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
digest := crypto.Keccak256(encodedData)
|
||||||
|
arrayBuffer.Write(digest)
|
||||||
|
} else {
|
||||||
|
bytesValue, err := typedData.EncodePrimitiveValue(parsedType, item, depth)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
arrayBuffer.Write(bytesValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return crypto.Keccak256(arrayBuffer.Bytes()), nil
|
||||||
|
}
|
||||||
|
|
||||||
// Attempt to parse bytes in different formats: byte array, hex string, hexutil.Bytes.
|
// Attempt to parse bytes in different formats: byte array, hex string, hexutil.Bytes.
|
||||||
func parseBytes(encType interface{}) ([]byte, bool) {
|
func parseBytes(encType interface{}) ([]byte, bool) {
|
||||||
// Handle array types.
|
// Handle array types.
|
||||||
|
@ -871,7 +887,8 @@ func init() {
|
||||||
|
|
||||||
// Checks if the primitive value is valid
|
// Checks if the primitive value is valid
|
||||||
func isPrimitiveTypeValid(primitiveType string) bool {
|
func isPrimitiveTypeValid(primitiveType string) bool {
|
||||||
_, ok := validPrimitiveTypes[primitiveType]
|
input := strings.Split(primitiveType, "[")[0]
|
||||||
|
_, ok := validPrimitiveTypes[input]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,8 +31,9 @@ func TestIsPrimitive(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
// Expected positives
|
// Expected positives
|
||||||
for i, tc := range []string{
|
for i, tc := range []string{
|
||||||
"int24", "int24[]", "uint88", "uint88[]", "uint", "uint[]", "int256", "int256[]",
|
"int24", "int24[]", "int[]", "int[2]", "uint88", "uint88[]", "uint", "uint[]", "uint[2]", "int256", "int256[]",
|
||||||
"uint96", "uint96[]", "int96", "int96[]", "bytes17[]", "bytes17",
|
"uint96", "uint96[]", "int96", "int96[]", "bytes17[]", "bytes17", "address[2]", "bool[4]", "string[5]", "bytes[2]",
|
||||||
|
"bytes32", "bytes32[]", "bytes32[4]",
|
||||||
} {
|
} {
|
||||||
if !isPrimitiveTypeValid(tc) {
|
if !isPrimitiveTypeValid(tc) {
|
||||||
t.Errorf("test %d: expected '%v' to be a valid primitive", i, tc)
|
t.Errorf("test %d: expected '%v' to be a valid primitive", i, tc)
|
||||||
|
@ -141,3 +142,94 @@ func TestBlobTxs(t *testing.T) {
|
||||||
}
|
}
|
||||||
t.Logf("tx %v", string(data))
|
t.Logf("tx %v", string(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestType_IsArray(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Expected positives
|
||||||
|
for i, tc := range []Type{
|
||||||
|
{
|
||||||
|
Name: "type1",
|
||||||
|
Type: "int24[]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "type2",
|
||||||
|
Type: "int24[2]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "type3",
|
||||||
|
Type: "int24[2][2][2]",
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
if !tc.isArray() {
|
||||||
|
t.Errorf("test %d: expected '%v' to be an array", i, tc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Expected negatives
|
||||||
|
for i, tc := range []Type{
|
||||||
|
{
|
||||||
|
Name: "type1",
|
||||||
|
Type: "int24",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "type2",
|
||||||
|
Type: "uint88",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "type3",
|
||||||
|
Type: "bytes32",
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
if tc.isArray() {
|
||||||
|
t.Errorf("test %d: expected '%v' to not be an array", i, tc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestType_TypeName(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for i, tc := range []struct {
|
||||||
|
Input Type
|
||||||
|
Expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Input: Type{
|
||||||
|
Name: "type1",
|
||||||
|
Type: "int24[]",
|
||||||
|
},
|
||||||
|
Expected: "int24",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Input: Type{
|
||||||
|
Name: "type2",
|
||||||
|
Type: "int26[2][2][2]",
|
||||||
|
},
|
||||||
|
Expected: "int26",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Input: Type{
|
||||||
|
Name: "type3",
|
||||||
|
Type: "int24",
|
||||||
|
},
|
||||||
|
Expected: "int24",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Input: Type{
|
||||||
|
Name: "type4",
|
||||||
|
Type: "uint88",
|
||||||
|
},
|
||||||
|
Expected: "uint88",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Input: Type{
|
||||||
|
Name: "type5",
|
||||||
|
Type: "bytes32[2]",
|
||||||
|
},
|
||||||
|
Expected: "bytes32",
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
if tc.Input.typeName() != tc.Expected {
|
||||||
|
t.Errorf("test %d: expected typeName value of '%v' but got '%v'", i, tc.Expected, tc.Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue