signer/core: extended support for EIP-712 array types (#30620)

This change updates the EIP-712 implementation to resolve [#30619](https://github.com/ethereum/go-ethereum/issues/30619).

The test cases have been repurposed from the ethers.js [repository](https://github.com/ethers-io/ethers.js/blob/main/testcases/typed-data.json.gz), but have been updated to remove tests that don't have a valid domain separator; EIP-712 messages without a domain separator are not supported by geth.

---------

Co-authored-by: Martin Holst Swende <martin@swende.se>
This commit is contained in:
Naveen 2024-11-09 01:04:17 +11:00 committed by Martin HS
parent f476702ea7
commit b362c37e3f
4 changed files with 6272 additions and 71 deletions

View File

@ -18,12 +18,18 @@ package apitypes
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt"
"math/big" "math/big"
"os"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestBytesPadding(t *testing.T) { func TestBytesPadding(t *testing.T) {
@ -244,45 +250,42 @@ func TestConvertAddressDataToSlice(t *testing.T) {
func TestTypedDataArrayValidate(t *testing.T) { func TestTypedDataArrayValidate(t *testing.T) {
t.Parallel() t.Parallel()
typedData := TypedData{ type testDataInput struct {
Types: Types{ Name string `json:"name"`
"BulkOrder": []Type{ Domain TypedDataDomain `json:"domain"`
// Should be able to accept fixed size arrays PrimaryType string `json:"primaryType"`
{Name: "tree", Type: "OrderComponents[2][2]"}, Types Types `json:"types"`
}, Message TypedDataMessage `json:"data"`
"OrderComponents": []Type{ Digest string `json:"digest"`
{Name: "offerer", Type: "address"},
{Name: "amount", Type: "uint8"},
},
"EIP712Domain": []Type{
{Name: "name", Type: "string"},
{Name: "version", Type: "string"},
{Name: "chainId", Type: "uint8"},
{Name: "verifyingContract", Type: "address"},
},
},
PrimaryType: "BulkOrder",
Domain: TypedDataDomain{
VerifyingContract: "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC",
},
Message: TypedDataMessage{},
} }
fc, err := os.ReadFile("./testdata/typed-data.json")
require.NoError(t, err, "error reading test data file")
if err := typedData.validate(); err != nil { var tests []testDataInput
t.Errorf("expected typed data to pass validation, got: %v", err) err = json.Unmarshal(fc, &tests)
} require.NoError(t, err, "error unmarshalling test data file contents")
// Should be able to accept dynamic arrays for _, tc := range tests {
typedData.Types["BulkOrder"][0].Type = "OrderComponents[]" t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
if err := typedData.validate(); err != nil { td := TypedData{
t.Errorf("expected typed data to pass validation, got: %v", err) Types: tc.Types,
} PrimaryType: tc.PrimaryType,
Domain: tc.Domain,
Message: tc.Message,
}
// Should be able to accept standard types domainSeparator, tErr := td.HashStruct("EIP712Domain", td.Domain.Map())
typedData.Types["BulkOrder"][0].Type = "OrderComponents" assert.NoError(t, tErr, "failed to hash domain separator: %v", tErr)
if err := typedData.validate(); err != nil { messageHash, tErr := td.HashStruct(td.PrimaryType, td.Message)
t.Errorf("expected typed data to pass validation, got: %v", err) assert.NoError(t, tErr, "failed to hash message: %v", tErr)
digest := crypto.Keccak256Hash([]byte(fmt.Sprintf("%s%s%s", "\x19\x01", string(domainSeparator), string(messageHash))))
assert.Equal(t, tc.Digest, digest.String(), "digest doesn't not match")
assert.NoError(t, td.validate(), "validation failed", tErr)
})
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -325,18 +325,17 @@ type Type struct {
Type string `json:"type"` Type string `json:"type"`
} }
// isArray returns true if the type is a fixed or variable sized array.
// This method may return false positives, in case the Type is not a valid
// expression, e.g. "fooo[[[[".
func (t *Type) isArray() bool { func (t *Type) isArray() bool {
return strings.HasSuffix(t.Type, "[]") return strings.IndexByte(t.Type, '[') > 0
} }
// typeName returns the canonical name of the type. If the type is 'Person[]', then // typeName returns the canonical name of the type. If the type is 'Person[]' or 'Person[2]', then
// this method returns 'Person' // this method returns 'Person'
func (t *Type) typeName() string { func (t *Type) typeName() string {
if strings.Contains(t.Type, "[") { return strings.Split(t.Type, "[")[0]
re := regexp.MustCompile(`\[\d*\]`)
return re.ReplaceAllString(t.Type, "")
}
return t.Type
} }
type Types map[string][]Type type Types map[string][]Type
@ -387,7 +386,7 @@ func (typedData *TypedData) HashStruct(primaryType string, data TypedDataMessage
// Dependencies returns an array of custom types ordered by their hierarchical reference tree // Dependencies returns an array of custom types ordered by their hierarchical reference tree
func (typedData *TypedData) Dependencies(primaryType string, found []string) []string { func (typedData *TypedData) Dependencies(primaryType string, found []string) []string {
primaryType = strings.TrimSuffix(primaryType, "[]") primaryType = strings.Split(primaryType, "[")[0]
if slices.Contains(found, primaryType) { if slices.Contains(found, primaryType) {
return found return found
@ -465,34 +464,11 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter
encType := field.Type encType := field.Type
encValue := data[field.Name] encValue := data[field.Name]
if encType[len(encType)-1:] == "]" { if encType[len(encType)-1:] == "]" {
arrayValue, err := convertDataToSlice(encValue) encodedData, err := typedData.encodeArrayValue(encValue, encType, depth)
if err != nil { if err != nil {
return nil, dataMismatchError(encType, encValue) return nil, err
} }
buffer.Write(encodedData)
arrayBuffer := bytes.Buffer{}
parsedType := strings.Split(encType, "[")[0]
for _, item := range arrayValue {
if typedData.Types[parsedType] != nil {
mapValue, ok := item.(map[string]interface{})
if !ok {
return nil, dataMismatchError(parsedType, item)
}
encodedData, err := typedData.EncodeData(parsedType, mapValue, depth+1)
if err != nil {
return nil, err
}
arrayBuffer.Write(crypto.Keccak256(encodedData))
} else {
bytesValue, err := typedData.EncodePrimitiveValue(parsedType, item, depth)
if err != nil {
return nil, err
}
arrayBuffer.Write(bytesValue)
}
}
buffer.Write(crypto.Keccak256(arrayBuffer.Bytes()))
} else if typedData.Types[field.Type] != nil { } else if typedData.Types[field.Type] != nil {
mapValue, ok := encValue.(map[string]interface{}) mapValue, ok := encValue.(map[string]interface{})
if !ok { if !ok {
@ -514,6 +490,46 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter
return buffer.Bytes(), nil return buffer.Bytes(), nil
} }
func (typedData *TypedData) encodeArrayValue(encValue interface{}, encType string, depth int) (hexutil.Bytes, error) {
arrayValue, err := convertDataToSlice(encValue)
if err != nil {
return nil, dataMismatchError(encType, encValue)
}
arrayBuffer := new(bytes.Buffer)
parsedType := strings.Split(encType, "[")[0]
for _, item := range arrayValue {
if reflect.TypeOf(item).Kind() == reflect.Slice ||
reflect.TypeOf(item).Kind() == reflect.Array {
encodedData, err := typedData.encodeArrayValue(item, parsedType, depth+1)
if err != nil {
return nil, err
}
arrayBuffer.Write(encodedData)
} else {
if typedData.Types[parsedType] != nil {
mapValue, ok := item.(map[string]interface{})
if !ok {
return nil, dataMismatchError(parsedType, item)
}
encodedData, err := typedData.EncodeData(parsedType, mapValue, depth+1)
if err != nil {
return nil, err
}
digest := crypto.Keccak256(encodedData)
arrayBuffer.Write(digest)
} else {
bytesValue, err := typedData.EncodePrimitiveValue(parsedType, item, depth)
if err != nil {
return nil, err
}
arrayBuffer.Write(bytesValue)
}
}
}
return crypto.Keccak256(arrayBuffer.Bytes()), nil
}
// Attempt to parse bytes in different formats: byte array, hex string, hexutil.Bytes. // Attempt to parse bytes in different formats: byte array, hex string, hexutil.Bytes.
func parseBytes(encType interface{}) ([]byte, bool) { func parseBytes(encType interface{}) ([]byte, bool) {
// Handle array types. // Handle array types.
@ -871,7 +887,8 @@ func init() {
// Checks if the primitive value is valid // Checks if the primitive value is valid
func isPrimitiveTypeValid(primitiveType string) bool { func isPrimitiveTypeValid(primitiveType string) bool {
_, ok := validPrimitiveTypes[primitiveType] input := strings.Split(primitiveType, "[")[0]
_, ok := validPrimitiveTypes[input]
return ok return ok
} }

View File

@ -31,8 +31,9 @@ func TestIsPrimitive(t *testing.T) {
t.Parallel() t.Parallel()
// Expected positives // Expected positives
for i, tc := range []string{ for i, tc := range []string{
"int24", "int24[]", "uint88", "uint88[]", "uint", "uint[]", "int256", "int256[]", "int24", "int24[]", "int[]", "int[2]", "uint88", "uint88[]", "uint", "uint[]", "uint[2]", "int256", "int256[]",
"uint96", "uint96[]", "int96", "int96[]", "bytes17[]", "bytes17", "uint96", "uint96[]", "int96", "int96[]", "bytes17[]", "bytes17", "address[2]", "bool[4]", "string[5]", "bytes[2]",
"bytes32", "bytes32[]", "bytes32[4]",
} { } {
if !isPrimitiveTypeValid(tc) { if !isPrimitiveTypeValid(tc) {
t.Errorf("test %d: expected '%v' to be a valid primitive", i, tc) t.Errorf("test %d: expected '%v' to be a valid primitive", i, tc)
@ -141,3 +142,94 @@ func TestBlobTxs(t *testing.T) {
} }
t.Logf("tx %v", string(data)) t.Logf("tx %v", string(data))
} }
func TestType_IsArray(t *testing.T) {
t.Parallel()
// Expected positives
for i, tc := range []Type{
{
Name: "type1",
Type: "int24[]",
},
{
Name: "type2",
Type: "int24[2]",
},
{
Name: "type3",
Type: "int24[2][2][2]",
},
} {
if !tc.isArray() {
t.Errorf("test %d: expected '%v' to be an array", i, tc)
}
}
// Expected negatives
for i, tc := range []Type{
{
Name: "type1",
Type: "int24",
},
{
Name: "type2",
Type: "uint88",
},
{
Name: "type3",
Type: "bytes32",
},
} {
if tc.isArray() {
t.Errorf("test %d: expected '%v' to not be an array", i, tc)
}
}
}
func TestType_TypeName(t *testing.T) {
t.Parallel()
for i, tc := range []struct {
Input Type
Expected string
}{
{
Input: Type{
Name: "type1",
Type: "int24[]",
},
Expected: "int24",
},
{
Input: Type{
Name: "type2",
Type: "int26[2][2][2]",
},
Expected: "int26",
},
{
Input: Type{
Name: "type3",
Type: "int24",
},
Expected: "int24",
},
{
Input: Type{
Name: "type4",
Type: "uint88",
},
Expected: "uint88",
},
{
Input: Type{
Name: "type5",
Type: "bytes32[2]",
},
Expected: "bytes32",
},
} {
if tc.Input.typeName() != tc.Expected {
t.Errorf("test %d: expected typeName value of '%v' but got '%v'", i, tc.Expected, tc.Input)
}
}
}