diff --git a/go.mod b/go.mod index 1a26321def..c467cef169 100644 --- a/go.mod +++ b/go.mod @@ -47,7 +47,6 @@ require ( github.com/jackpal/go-nat-pmp v1.0.2 github.com/jedisct1/go-minisign v0.0.0-20230811132847-661be99b8267 github.com/karalabe/hid v1.0.1-0.20240306101548-573246063e52 - github.com/kilic/bls12-381 v0.1.0 github.com/kylelemons/godebug v1.1.0 github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.20 @@ -116,6 +115,7 @@ require ( github.com/hashicorp/go-retryablehttp v0.7.4 // indirect github.com/influxdata/line-protocol v0.0.0-20200327222509-2487e7298839 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/kilic/bls12-381 v0.1.0 // indirect github.com/klauspost/compress v1.16.0 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/kr/pretty v0.3.1 // indirect diff --git a/oss-fuzz.sh b/oss-fuzz.sh index 50491b9155..5e4aa1c253 100644 --- a/oss-fuzz.sh +++ b/oss-fuzz.sh @@ -160,6 +160,10 @@ compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ FuzzG1Add fuzz_g1_add\ $repo/tests/fuzzers/bls12381/bls12381_test.go +compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ + FuzzCrossG1Mul fuzz_cross_g1_mul\ + $repo/tests/fuzzers/bls12381/bls12381_test.go + compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ FuzzG1Mul fuzz_g1_mul\ $repo/tests/fuzzers/bls12381/bls12381_test.go @@ -172,6 +176,10 @@ compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ FuzzG2Add fuzz_g2_add \ $repo/tests/fuzzers/bls12381/bls12381_test.go +compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ + FuzzCrossG2Mul fuzz_cross_g2_mul\ + $repo/tests/fuzzers/bls12381/bls12381_test.go + compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ FuzzG2Mul fuzz_g2_mul\ $repo/tests/fuzzers/bls12381/bls12381_test.go @@ -204,6 +212,10 @@ compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ FuzzCrossG2Add fuzz_cross_g2_add \ $repo/tests/fuzzers/bls12381/bls12381_test.go +compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ + FuzzCrossG2MultiExp fuzz_cross_g2_multiexp \ + $repo/tests/fuzzers/bls12381/bls12381_test.go + compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ FuzzCrossPairing fuzz_cross_pairing\ $repo/tests/fuzzers/bls12381/bls12381_test.go diff --git a/tests/fuzzers/bls12381/bls12381_fuzz.go b/tests/fuzzers/bls12381/bls12381_fuzz.go index 74ea6f52a7..a3e0e9f72b 100644 --- a/tests/fuzzers/bls12381/bls12381_fuzz.go +++ b/tests/fuzzers/bls12381/bls12381_fuzz.go @@ -31,42 +31,33 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/ethereum/go-ethereum/common" - bls12381 "github.com/kilic/bls12-381" blst "github.com/supranational/blst/bindings/go" ) func fuzzG1SubgroupChecks(data []byte) int { input := bytes.NewReader(data) - kpG1, cpG1, blG1, err := getG1Points(input) + cpG1, blG1, err := getG1Points(input) if err != nil { return 0 } - inSubGroupKilic := bls12381.NewG1().InCorrectSubgroup(kpG1) inSubGroupGnark := cpG1.IsInSubGroup() inSubGroupBLST := blG1.InG1() - if inSubGroupKilic != inSubGroupGnark { - panic(fmt.Sprintf("differing subgroup check, kilic %v, gnark %v", inSubGroupKilic, inSubGroupGnark)) - } - if inSubGroupKilic != inSubGroupBLST { - panic(fmt.Sprintf("differing subgroup check, kilic %v, blst %v", inSubGroupKilic, inSubGroupBLST)) + if inSubGroupGnark != inSubGroupBLST { + panic(fmt.Sprintf("differing subgroup check, gnark %v, blst %v", inSubGroupGnark, inSubGroupBLST)) } return 1 } func fuzzG2SubgroupChecks(data []byte) int { input := bytes.NewReader(data) - kpG2, cpG2, blG2, err := getG2Points(input) + gpG2, blG2, err := getG2Points(input) if err != nil { return 0 } - inSubGroupKilic := bls12381.NewG2().InCorrectSubgroup(kpG2) - inSubGroupGnark := cpG2.IsInSubGroup() + inSubGroupGnark := gpG2.IsInSubGroup() inSubGroupBLST := blG2.InG2() - if inSubGroupKilic != inSubGroupGnark { - panic(fmt.Sprintf("differing subgroup check, kilic %v, gnark %v", inSubGroupKilic, inSubGroupGnark)) - } - if inSubGroupKilic != inSubGroupBLST { - panic(fmt.Sprintf("differing subgroup check, kilic %v, blst %v", inSubGroupKilic, inSubGroupBLST)) + if inSubGroupGnark != inSubGroupBLST { + panic(fmt.Sprintf("differing subgroup check, gnark %v, blst %v", inSubGroupGnark, inSubGroupBLST)) } return 1 } @@ -75,38 +66,28 @@ func fuzzCrossPairing(data []byte) int { input := bytes.NewReader(data) // get random G1 points - kpG1, cpG1, blG1, err := getG1Points(input) + cpG1, blG1, err := getG1Points(input) if err != nil { return 0 } // get random G2 points - kpG2, cpG2, blG2, err := getG2Points(input) + cpG2, blG2, err := getG2Points(input) if err != nil { return 0 } - // compute pairing using geth - engine := bls12381.NewEngine() - engine.AddPair(kpG1, kpG2) - kResult := engine.Result() - // compute pairing using gnark cResult, err := gnark.Pair([]gnark.G1Affine{*cpG1}, []gnark.G2Affine{*cpG2}) if err != nil { panic(fmt.Sprintf("gnark/bls12381 encountered error: %v", err)) } - // compare result - if !(bytes.Equal(cResult.Marshal(), bls12381.NewGT().ToBytes(kResult))) { - panic("pairing mismatch gnark / geth ") - } - // compute pairing using blst blstResult := blst.Fp12MillerLoop(blG2, blG1) blstResult.FinalExp() res := massageBLST(blstResult.ToBendian()) - if !(bytes.Equal(res, bls12381.NewGT().ToBytes(kResult))) { + if !(bytes.Equal(res, cResult.Marshal())) { panic("pairing mismatch blst / geth") } @@ -141,32 +122,22 @@ func fuzzCrossG1Add(data []byte) int { input := bytes.NewReader(data) // get random G1 points - kp1, cp1, bl1, err := getG1Points(input) + cp1, bl1, err := getG1Points(input) if err != nil { return 0 } // get random G1 points - kp2, cp2, bl2, err := getG1Points(input) + cp2, bl2, err := getG1Points(input) if err != nil { return 0 } - // compute kp = kp1 + kp2 - g1 := bls12381.NewG1() - kp := bls12381.PointG1{} - g1.Add(&kp, kp1, kp2) - // compute cp = cp1 + cp2 _cp1 := new(gnark.G1Jac).FromAffine(cp1) _cp2 := new(gnark.G1Jac).FromAffine(cp2) cp := new(gnark.G1Affine).FromJacobian(_cp1.AddAssign(_cp2)) - // compare result - if !(bytes.Equal(cp.Marshal(), g1.ToBytes(&kp))) { - panic("G1 point addition mismatch gnark / geth ") - } - bl3 := blst.P1AffinesAdd([]*blst.P1Affine{bl1, bl2}) if !(bytes.Equal(cp.Marshal(), bl3.Serialize())) { panic("G1 point addition mismatch blst / geth ") @@ -179,34 +150,24 @@ func fuzzCrossG2Add(data []byte) int { input := bytes.NewReader(data) // get random G2 points - kp1, cp1, bl1, err := getG2Points(input) + gp1, bl1, err := getG2Points(input) if err != nil { return 0 } // get random G2 points - kp2, cp2, bl2, err := getG2Points(input) + gp2, bl2, err := getG2Points(input) if err != nil { return 0 } - // compute kp = kp1 + kp2 - g2 := bls12381.NewG2() - kp := bls12381.PointG2{} - g2.Add(&kp, kp1, kp2) - // compute cp = cp1 + cp2 - _cp1 := new(gnark.G2Jac).FromAffine(cp1) - _cp2 := new(gnark.G2Jac).FromAffine(cp2) - cp := new(gnark.G2Affine).FromJacobian(_cp1.AddAssign(_cp2)) - - // compare result - if !(bytes.Equal(cp.Marshal(), g2.ToBytes(&kp))) { - panic("G2 point addition mismatch gnark / geth ") - } + _gp1 := new(gnark.G2Jac).FromAffine(gp1) + _gp2 := new(gnark.G2Jac).FromAffine(gp2) + gp := new(gnark.G2Affine).FromJacobian(_gp1.AddAssign(_gp2)) bl3 := blst.P2AffinesAdd([]*blst.P2Affine{bl1, bl2}) - if !(bytes.Equal(cp.Marshal(), bl3.Serialize())) { + if !(bytes.Equal(gp.Marshal(), bl3.Serialize())) { panic("G1 point addition mismatch blst / geth ") } @@ -216,10 +177,10 @@ func fuzzCrossG2Add(data []byte) int { func fuzzCrossG1MultiExp(data []byte) int { var ( input = bytes.NewReader(data) - gethScalars []*bls12381.Fr gnarkScalars []fr.Element - gethPoints []*bls12381.PointG1 gnarkPoints []gnark.G1Affine + blstScalars []*blst.Scalar + blstPoints []*blst.P1Affine ) // n random scalars (max 17) for i := 0; i < 17; i++ { @@ -229,50 +190,147 @@ func fuzzCrossG1MultiExp(data []byte) int { break } // get a random G1 point as basis - kp1, cp1, _, err := getG1Points(input) + cp1, bl1, err := getG1Points(input) if err != nil { break } - gethScalars = append(gethScalars, bls12381.NewFr().FromBytes(s.Bytes())) - var gnarkScalar = &fr.Element{} - gnarkScalar = gnarkScalar.SetBigInt(s) - gnarkScalars = append(gnarkScalars, *gnarkScalar) - gethPoints = append(gethPoints, new(bls12381.PointG1).Set(kp1)) + gnarkScalar := new(fr.Element).SetBigInt(s) + gnarkScalars = append(gnarkScalars, *gnarkScalar) gnarkPoints = append(gnarkPoints, *cp1) + + blstScalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32)) + blstScalars = append(blstScalars, blstScalar) + blstPoints = append(blstPoints, bl1) } - if len(gethScalars) == 0 { + + if len(gnarkScalars) == 0 || len(gnarkScalars) != len(gnarkPoints) { return 0 } - // compute multi exponentiation - g1 := bls12381.NewG1() - kp := bls12381.PointG1{} - if _, err := g1.MultiExp(&kp, gethPoints, gethScalars); err != nil { - panic(fmt.Sprintf("G1 multi exponentiation errored (geth): %v", err)) - } - // note that geth/crypto/bls12381.MultiExp mutates the scalars slice (and sets all the scalars to zero) // gnark multi exp cp := new(gnark.G1Affine) cp.MultiExp(gnarkPoints, gnarkScalars, ecc.MultiExpConfig{}) - // compare result - gnarkRes := cp.Marshal() - gethRes := g1.ToBytes(&kp) - if !bytes.Equal(gnarkRes, gethRes) { - msg := fmt.Sprintf("G1 multi exponentiation mismatch gnark/geth.\ngnark: %x\ngeth: %x\ninput: %x\n ", - gnarkRes, gethRes, data) - panic(msg) + expectedGnark := multiExpG1Gnark(gnarkPoints, gnarkScalars) + if !bytes.Equal(cp.Marshal(), expectedGnark.Marshal()) { + panic("g1 multi exponentiation mismatch") } + // blst multi exp + expectedBlst := blst.P1AffinesMult(blstPoints, blstScalars, 256).ToAffine() + if !bytes.Equal(cp.Marshal(), expectedBlst.Serialize()) { + panic("g1 multi exponentiation mismatch, gnark/blst") + } return 1 } -func getG1Points(input io.Reader) (*bls12381.PointG1, *gnark.G1Affine, *blst.P1Affine, error) { +func fuzzCrossG1Mul(data []byte) int { + input := bytes.NewReader(data) + gp, blpAffine, err := getG1Points(input) + if err != nil { + return 0 + } + scalar, err := randomScalar(input, fp.Modulus()) + if err != nil { + return 0 + } + + blScalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(scalar.Bytes(), 32)) + + blp := new(blst.P1) + blp.FromAffine(blpAffine) + + resBl := blp.Mult(blScalar) + resGeth := (new(gnark.G1Affine)).ScalarMultiplication(gp, scalar) + + if !bytes.Equal(resGeth.Marshal(), resBl.Serialize()) { + panic("bytes(blst.G1) != bytes(geth.G1)") + } + return 1 +} + +func fuzzCrossG2Mul(data []byte) int { + input := bytes.NewReader(data) + gp, blpAffine, err := getG2Points(input) + if err != nil { + return 0 + } + scalar, err := randomScalar(input, fp.Modulus()) + if err != nil { + return 0 + } + + blScalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(scalar.Bytes(), 32)) + + blp := new(blst.P2) + blp.FromAffine(blpAffine) + + resBl := blp.Mult(blScalar) + resGeth := (new(gnark.G2Affine)).ScalarMultiplication(gp, scalar) + + if !bytes.Equal(resGeth.Marshal(), resBl.Serialize()) { + panic("bytes(blst.G1) != bytes(geth.G1)") + } + return 1 +} + +func fuzzCrossG2MultiExp(data []byte) int { + var ( + input = bytes.NewReader(data) + gnarkScalars []fr.Element + gnarkPoints []gnark.G2Affine + blstScalars []*blst.Scalar + blstPoints []*blst.P2Affine + ) + // n random scalars (max 17) + for i := 0; i < 17; i++ { + // note that geth/crypto/bls12381 works only with scalars <= 32bytes + s, err := randomScalar(input, fr.Modulus()) + if err != nil { + break + } + // get a random G1 point as basis + cp1, bl1, err := getG2Points(input) + if err != nil { + break + } + + gnarkScalar := new(fr.Element).SetBigInt(s) + gnarkScalars = append(gnarkScalars, *gnarkScalar) + gnarkPoints = append(gnarkPoints, *cp1) + + blstScalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32)) + blstScalars = append(blstScalars, blstScalar) + blstPoints = append(blstPoints, bl1) + } + + if len(gnarkScalars) == 0 || len(gnarkScalars) != len(gnarkPoints) { + return 0 + } + + // gnark multi exp + cp := new(gnark.G2Affine) + cp.MultiExp(gnarkPoints, gnarkScalars, ecc.MultiExpConfig{}) + + expectedGnark := multiExpG2Gnark(gnarkPoints, gnarkScalars) + if !bytes.Equal(cp.Marshal(), expectedGnark.Marshal()) { + panic("g1 multi exponentiation mismatch") + } + + // blst multi exp + expectedBlst := blst.P2AffinesMult(blstPoints, blstScalars, 256).ToAffine() + if !bytes.Equal(cp.Marshal(), expectedBlst.Serialize()) { + panic("g1 multi exponentiation mismatch, gnark/blst") + } + return 1 +} + +func getG1Points(input io.Reader) (*gnark.G1Affine, *blst.P1Affine, error) { // sample a random scalar s, err := randomScalar(input, fp.Modulus()) if err != nil { - return nil, nil, nil, err + return nil, nil, err } // compute a random point @@ -281,18 +339,6 @@ func getG1Points(input io.Reader) (*bls12381.PointG1, *gnark.G1Affine, *blst.P1A cp.ScalarMultiplication(&g1Gen, s) cpBytes := cp.Marshal() - // marshal gnark point -> geth point - g1 := bls12381.NewG1() - kp, err := g1.FromBytes(cpBytes) - if err != nil { - panic(fmt.Sprintf("Could not marshal gnark.G1 -> geth.G1: %v", err)) - } - - gnarkRes := g1.ToBytes(kp) - if !bytes.Equal(gnarkRes, cpBytes) { - panic(fmt.Sprintf("bytes(gnark.G1) != bytes(geth.G1)\ngnark.G1: %x\ngeth.G1: %x\n", gnarkRes, cpBytes)) - } - // marshal gnark point -> blst point scalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32)) p1 := new(blst.P1Affine).From(scalar) @@ -301,43 +347,31 @@ func getG1Points(input io.Reader) (*bls12381.PointG1, *gnark.G1Affine, *blst.P1A panic(fmt.Sprintf("bytes(blst.G1) != bytes(geth.G1)\nblst.G1: %x\ngeth.G1: %x\n", blstRes, cpBytes)) } - return kp, cp, p1, nil + return cp, p1, nil } -func getG2Points(input io.Reader) (*bls12381.PointG2, *gnark.G2Affine, *blst.P2Affine, error) { +func getG2Points(input io.Reader) (*gnark.G2Affine, *blst.P2Affine, error) { // sample a random scalar s, err := randomScalar(input, fp.Modulus()) if err != nil { - return nil, nil, nil, err + return nil, nil, err } // compute a random point - cp := new(gnark.G2Affine) + gp := new(gnark.G2Affine) _, _, _, g2Gen := gnark.Generators() - cp.ScalarMultiplication(&g2Gen, s) - cpBytes := cp.Marshal() - - // marshal gnark point -> geth point - g2 := bls12381.NewG2() - kp, err := g2.FromBytes(cpBytes) - if err != nil { - panic(fmt.Sprintf("Could not marshal gnark.G2 -> geth.G2: %v", err)) - } - - gnarkRes := g2.ToBytes(kp) - if !bytes.Equal(gnarkRes, cpBytes) { - panic(fmt.Sprintf("bytes(gnark.G2) != bytes(geth.G2)\ngnark.G2: %x\ngeth.G2: %x\n", gnarkRes, cpBytes)) - } + gp.ScalarMultiplication(&g2Gen, s) + cpBytes := gp.Marshal() // marshal gnark point -> blst point // Left pad the scalar to 32 bytes scalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32)) p2 := new(blst.P2Affine).From(scalar) if !bytes.Equal(p2.Serialize(), cpBytes) { - panic("bytes(blst.G2) != bytes(geth.G2)") + panic("bytes(blst.G2) != bytes(bls12381.G2)") } - return kp, cp, p2, nil + return gp, p2, nil } func randomScalar(r io.Reader, max *big.Int) (k *big.Int, err error) { @@ -348,3 +382,29 @@ func randomScalar(r io.Reader, max *big.Int) (k *big.Int, err error) { } } } + +// multiExpG1Gnark is a naive implementation of G1 multi-exponentiation +func multiExpG1Gnark(gs []gnark.G1Affine, scalars []fr.Element) gnark.G1Affine { + res := gnark.G1Affine{} + for i := 0; i < len(gs); i++ { + tmp := new(gnark.G1Affine) + sb := scalars[i].Bytes() + scalarBytes := new(big.Int).SetBytes(sb[:]) + tmp.ScalarMultiplication(&gs[i], scalarBytes) + res.Add(&res, tmp) + } + return res +} + +// multiExpG1Gnark is a naive implementation of G1 multi-exponentiation +func multiExpG2Gnark(gs []gnark.G2Affine, scalars []fr.Element) gnark.G2Affine { + res := gnark.G2Affine{} + for i := 0; i < len(gs); i++ { + tmp := new(gnark.G2Affine) + sb := scalars[i].Bytes() + scalarBytes := new(big.Int).SetBytes(sb[:]) + tmp.ScalarMultiplication(&gs[i], scalarBytes) + res.Add(&res, tmp) + } + return res +} diff --git a/tests/fuzzers/bls12381/bls12381_test.go b/tests/fuzzers/bls12381/bls12381_test.go index fd782f7813..d4e5e20e04 100644 --- a/tests/fuzzers/bls12381/bls12381_test.go +++ b/tests/fuzzers/bls12381/bls12381_test.go @@ -27,6 +27,12 @@ func FuzzCrossPairing(f *testing.F) { }) } +func FuzzCrossG2MultiExp(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + fuzzCrossG2MultiExp(data) + }) +} + func FuzzCrossG1Add(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { fuzzCrossG1Add(data) @@ -51,9 +57,9 @@ func FuzzG1Add(f *testing.F) { }) } -func FuzzG1Mul(f *testing.F) { +func FuzzCrossG1Mul(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { - fuzz(blsG1Mul, data) + fuzzCrossG1Mul(data) }) } @@ -69,9 +75,9 @@ func FuzzG2Add(f *testing.F) { }) } -func FuzzG2Mul(f *testing.F) { +func FuzzCrossG2Mul(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { - fuzz(blsG2Mul, data) + fuzzCrossG2Mul(data) }) } @@ -110,3 +116,15 @@ func FuzzG2SubgroupChecks(f *testing.F) { fuzzG2SubgroupChecks(data) }) } + +func FuzzG2Mul(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + fuzz(blsG2Mul, data) + }) +} + +func FuzzG1Mul(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + fuzz(blsG1Mul, data) + }) +}