signer/core: extended support for EIP-712 array types (#30620)
This change updates the EIP-712 implementation to resolve [#30619](https://github.com/ethereum/go-ethereum/issues/30619). The test cases have been repurposed from the ethers.js [repository](https://github.com/ethers-io/ethers.js/blob/main/testcases/typed-data.json.gz), but have been updated to remove tests that don't have a valid domain separator; EIP-712 messages without a domain separator are not supported by geth. --------- Co-authored-by: Martin Holst Swende <martin@swende.se>
This commit is contained in:
parent
f476702ea7
commit
b362c37e3f
|
@ -18,12 +18,18 @@ package apitypes
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/common/hexutil"
|
||||
"github.com/ethereum/go-ethereum/common/math"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBytesPadding(t *testing.T) {
|
||||
|
@ -244,45 +250,42 @@ func TestConvertAddressDataToSlice(t *testing.T) {
|
|||
func TestTypedDataArrayValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typedData := TypedData{
|
||||
Types: Types{
|
||||
"BulkOrder": []Type{
|
||||
// Should be able to accept fixed size arrays
|
||||
{Name: "tree", Type: "OrderComponents[2][2]"},
|
||||
},
|
||||
"OrderComponents": []Type{
|
||||
{Name: "offerer", Type: "address"},
|
||||
{Name: "amount", Type: "uint8"},
|
||||
},
|
||||
"EIP712Domain": []Type{
|
||||
{Name: "name", Type: "string"},
|
||||
{Name: "version", Type: "string"},
|
||||
{Name: "chainId", Type: "uint8"},
|
||||
{Name: "verifyingContract", Type: "address"},
|
||||
},
|
||||
},
|
||||
PrimaryType: "BulkOrder",
|
||||
Domain: TypedDataDomain{
|
||||
VerifyingContract: "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC",
|
||||
},
|
||||
Message: TypedDataMessage{},
|
||||
type testDataInput struct {
|
||||
Name string `json:"name"`
|
||||
Domain TypedDataDomain `json:"domain"`
|
||||
PrimaryType string `json:"primaryType"`
|
||||
Types Types `json:"types"`
|
||||
Message TypedDataMessage `json:"data"`
|
||||
Digest string `json:"digest"`
|
||||
}
|
||||
fc, err := os.ReadFile("./testdata/typed-data.json")
|
||||
require.NoError(t, err, "error reading test data file")
|
||||
|
||||
if err := typedData.validate(); err != nil {
|
||||
t.Errorf("expected typed data to pass validation, got: %v", err)
|
||||
}
|
||||
var tests []testDataInput
|
||||
err = json.Unmarshal(fc, &tests)
|
||||
require.NoError(t, err, "error unmarshalling test data file contents")
|
||||
|
||||
// Should be able to accept dynamic arrays
|
||||
typedData.Types["BulkOrder"][0].Type = "OrderComponents[]"
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if err := typedData.validate(); err != nil {
|
||||
t.Errorf("expected typed data to pass validation, got: %v", err)
|
||||
}
|
||||
td := TypedData{
|
||||
Types: tc.Types,
|
||||
PrimaryType: tc.PrimaryType,
|
||||
Domain: tc.Domain,
|
||||
Message: tc.Message,
|
||||
}
|
||||
|
||||
// Should be able to accept standard types
|
||||
typedData.Types["BulkOrder"][0].Type = "OrderComponents"
|
||||
domainSeparator, tErr := td.HashStruct("EIP712Domain", td.Domain.Map())
|
||||
assert.NoError(t, tErr, "failed to hash domain separator: %v", tErr)
|
||||
|
||||
if err := typedData.validate(); err != nil {
|
||||
t.Errorf("expected typed data to pass validation, got: %v", err)
|
||||
messageHash, tErr := td.HashStruct(td.PrimaryType, td.Message)
|
||||
assert.NoError(t, tErr, "failed to hash message: %v", tErr)
|
||||
|
||||
digest := crypto.Keccak256Hash([]byte(fmt.Sprintf("%s%s%s", "\x19\x01", string(domainSeparator), string(messageHash))))
|
||||
assert.Equal(t, tc.Digest, digest.String(), "digest doesn't not match")
|
||||
|
||||
assert.NoError(t, td.validate(), "validation failed", tErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -325,18 +325,17 @@ type Type struct {
|
|||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// isArray returns true if the type is a fixed or variable sized array.
|
||||
// This method may return false positives, in case the Type is not a valid
|
||||
// expression, e.g. "fooo[[[[".
|
||||
func (t *Type) isArray() bool {
|
||||
return strings.HasSuffix(t.Type, "[]")
|
||||
return strings.IndexByte(t.Type, '[') > 0
|
||||
}
|
||||
|
||||
// typeName returns the canonical name of the type. If the type is 'Person[]', then
|
||||
// typeName returns the canonical name of the type. If the type is 'Person[]' or 'Person[2]', then
|
||||
// this method returns 'Person'
|
||||
func (t *Type) typeName() string {
|
||||
if strings.Contains(t.Type, "[") {
|
||||
re := regexp.MustCompile(`\[\d*\]`)
|
||||
return re.ReplaceAllString(t.Type, "")
|
||||
}
|
||||
return t.Type
|
||||
return strings.Split(t.Type, "[")[0]
|
||||
}
|
||||
|
||||
type Types map[string][]Type
|
||||
|
@ -387,7 +386,7 @@ func (typedData *TypedData) HashStruct(primaryType string, data TypedDataMessage
|
|||
|
||||
// Dependencies returns an array of custom types ordered by their hierarchical reference tree
|
||||
func (typedData *TypedData) Dependencies(primaryType string, found []string) []string {
|
||||
primaryType = strings.TrimSuffix(primaryType, "[]")
|
||||
primaryType = strings.Split(primaryType, "[")[0]
|
||||
|
||||
if slices.Contains(found, primaryType) {
|
||||
return found
|
||||
|
@ -465,34 +464,11 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter
|
|||
encType := field.Type
|
||||
encValue := data[field.Name]
|
||||
if encType[len(encType)-1:] == "]" {
|
||||
arrayValue, err := convertDataToSlice(encValue)
|
||||
encodedData, err := typedData.encodeArrayValue(encValue, encType, depth)
|
||||
if err != nil {
|
||||
return nil, dataMismatchError(encType, encValue)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
arrayBuffer := bytes.Buffer{}
|
||||
parsedType := strings.Split(encType, "[")[0]
|
||||
for _, item := range arrayValue {
|
||||
if typedData.Types[parsedType] != nil {
|
||||
mapValue, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, dataMismatchError(parsedType, item)
|
||||
}
|
||||
encodedData, err := typedData.EncodeData(parsedType, mapValue, depth+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arrayBuffer.Write(crypto.Keccak256(encodedData))
|
||||
} else {
|
||||
bytesValue, err := typedData.EncodePrimitiveValue(parsedType, item, depth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arrayBuffer.Write(bytesValue)
|
||||
}
|
||||
}
|
||||
|
||||
buffer.Write(crypto.Keccak256(arrayBuffer.Bytes()))
|
||||
buffer.Write(encodedData)
|
||||
} else if typedData.Types[field.Type] != nil {
|
||||
mapValue, ok := encValue.(map[string]interface{})
|
||||
if !ok {
|
||||
|
@ -514,6 +490,46 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter
|
|||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func (typedData *TypedData) encodeArrayValue(encValue interface{}, encType string, depth int) (hexutil.Bytes, error) {
|
||||
arrayValue, err := convertDataToSlice(encValue)
|
||||
if err != nil {
|
||||
return nil, dataMismatchError(encType, encValue)
|
||||
}
|
||||
|
||||
arrayBuffer := new(bytes.Buffer)
|
||||
parsedType := strings.Split(encType, "[")[0]
|
||||
for _, item := range arrayValue {
|
||||
if reflect.TypeOf(item).Kind() == reflect.Slice ||
|
||||
reflect.TypeOf(item).Kind() == reflect.Array {
|
||||
encodedData, err := typedData.encodeArrayValue(item, parsedType, depth+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arrayBuffer.Write(encodedData)
|
||||
} else {
|
||||
if typedData.Types[parsedType] != nil {
|
||||
mapValue, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, dataMismatchError(parsedType, item)
|
||||
}
|
||||
encodedData, err := typedData.EncodeData(parsedType, mapValue, depth+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
digest := crypto.Keccak256(encodedData)
|
||||
arrayBuffer.Write(digest)
|
||||
} else {
|
||||
bytesValue, err := typedData.EncodePrimitiveValue(parsedType, item, depth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arrayBuffer.Write(bytesValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
return crypto.Keccak256(arrayBuffer.Bytes()), nil
|
||||
}
|
||||
|
||||
// Attempt to parse bytes in different formats: byte array, hex string, hexutil.Bytes.
|
||||
func parseBytes(encType interface{}) ([]byte, bool) {
|
||||
// Handle array types.
|
||||
|
@ -871,7 +887,8 @@ func init() {
|
|||
|
||||
// Checks if the primitive value is valid
|
||||
func isPrimitiveTypeValid(primitiveType string) bool {
|
||||
_, ok := validPrimitiveTypes[primitiveType]
|
||||
input := strings.Split(primitiveType, "[")[0]
|
||||
_, ok := validPrimitiveTypes[input]
|
||||
return ok
|
||||
}
|
||||
|
||||
|
|
|
@ -31,8 +31,9 @@ func TestIsPrimitive(t *testing.T) {
|
|||
t.Parallel()
|
||||
// Expected positives
|
||||
for i, tc := range []string{
|
||||
"int24", "int24[]", "uint88", "uint88[]", "uint", "uint[]", "int256", "int256[]",
|
||||
"uint96", "uint96[]", "int96", "int96[]", "bytes17[]", "bytes17",
|
||||
"int24", "int24[]", "int[]", "int[2]", "uint88", "uint88[]", "uint", "uint[]", "uint[2]", "int256", "int256[]",
|
||||
"uint96", "uint96[]", "int96", "int96[]", "bytes17[]", "bytes17", "address[2]", "bool[4]", "string[5]", "bytes[2]",
|
||||
"bytes32", "bytes32[]", "bytes32[4]",
|
||||
} {
|
||||
if !isPrimitiveTypeValid(tc) {
|
||||
t.Errorf("test %d: expected '%v' to be a valid primitive", i, tc)
|
||||
|
@ -141,3 +142,94 @@ func TestBlobTxs(t *testing.T) {
|
|||
}
|
||||
t.Logf("tx %v", string(data))
|
||||
}
|
||||
|
||||
func TestType_IsArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Expected positives
|
||||
for i, tc := range []Type{
|
||||
{
|
||||
Name: "type1",
|
||||
Type: "int24[]",
|
||||
},
|
||||
{
|
||||
Name: "type2",
|
||||
Type: "int24[2]",
|
||||
},
|
||||
{
|
||||
Name: "type3",
|
||||
Type: "int24[2][2][2]",
|
||||
},
|
||||
} {
|
||||
if !tc.isArray() {
|
||||
t.Errorf("test %d: expected '%v' to be an array", i, tc)
|
||||
}
|
||||
}
|
||||
// Expected negatives
|
||||
for i, tc := range []Type{
|
||||
{
|
||||
Name: "type1",
|
||||
Type: "int24",
|
||||
},
|
||||
{
|
||||
Name: "type2",
|
||||
Type: "uint88",
|
||||
},
|
||||
{
|
||||
Name: "type3",
|
||||
Type: "bytes32",
|
||||
},
|
||||
} {
|
||||
if tc.isArray() {
|
||||
t.Errorf("test %d: expected '%v' to not be an array", i, tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestType_TypeName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i, tc := range []struct {
|
||||
Input Type
|
||||
Expected string
|
||||
}{
|
||||
{
|
||||
Input: Type{
|
||||
Name: "type1",
|
||||
Type: "int24[]",
|
||||
},
|
||||
Expected: "int24",
|
||||
},
|
||||
{
|
||||
Input: Type{
|
||||
Name: "type2",
|
||||
Type: "int26[2][2][2]",
|
||||
},
|
||||
Expected: "int26",
|
||||
},
|
||||
{
|
||||
Input: Type{
|
||||
Name: "type3",
|
||||
Type: "int24",
|
||||
},
|
||||
Expected: "int24",
|
||||
},
|
||||
{
|
||||
Input: Type{
|
||||
Name: "type4",
|
||||
Type: "uint88",
|
||||
},
|
||||
Expected: "uint88",
|
||||
},
|
||||
{
|
||||
Input: Type{
|
||||
Name: "type5",
|
||||
Type: "bytes32[2]",
|
||||
},
|
||||
Expected: "bytes32",
|
||||
},
|
||||
} {
|
||||
if tc.Input.typeName() != tc.Expected {
|
||||
t.Errorf("test %d: expected typeName value of '%v' but got '%v'", i, tc.Expected, tc.Input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue