736 lines
20 KiB
Go
736 lines
20 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"go/format"
|
|
"go/types"
|
|
"sort"
|
|
|
|
"github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
|
|
)
|
|
|
|
// buildContext keeps the data needed for make*Op.
|
|
type buildContext struct {
|
|
topType *types.Named // the type we're creating methods for
|
|
|
|
encoderIface *types.Interface
|
|
decoderIface *types.Interface
|
|
rawValueType *types.Named
|
|
|
|
typeToStructCache map[types.Type]*rlpstruct.Type
|
|
}
|
|
|
|
func newBuildContext(packageRLP *types.Package) *buildContext {
|
|
enc := packageRLP.Scope().Lookup("Encoder").Type().Underlying()
|
|
dec := packageRLP.Scope().Lookup("Decoder").Type().Underlying()
|
|
rawv := packageRLP.Scope().Lookup("RawValue").Type()
|
|
return &buildContext{
|
|
typeToStructCache: make(map[types.Type]*rlpstruct.Type),
|
|
encoderIface: enc.(*types.Interface),
|
|
decoderIface: dec.(*types.Interface),
|
|
rawValueType: rawv.(*types.Named),
|
|
}
|
|
}
|
|
|
|
// isEncoder reports whether typ satisfies the rlp.Encoder interface.
func (bctx *buildContext) isEncoder(typ types.Type) bool {
	iface := bctx.encoderIface
	return types.Implements(typ, iface)
}
|
|
|
|
// isDecoder reports whether typ satisfies the rlp.Decoder interface.
func (bctx *buildContext) isDecoder(typ types.Type) bool {
	iface := bctx.decoderIface
	return types.Implements(typ, iface)
}
|
|
|
|
// typeToStructType converts typ to rlpstruct.Type. Element types of arrays,
// slices and pointers are converted recursively.
//
// Results are cached under the type that was passed in, not under its
// resolved underlying type. Because of that, the lookup at the top actually
// hits for named types. That hit is what terminates the recursion for
// self-referential definitions such as `type L []L`. It also keeps an unnamed
// type from reusing the Name of a named type with the same underlying type.
func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
	key := typ
	if prev := bctx.typeToStructCache[key]; prev != nil {
		return prev // short-circuit for recursive types.
	}

	// Resolve named types to their underlying type, but keep the name.
	name := types.TypeString(typ, nil)
	for {
		utype := typ.Underlying()
		if utype == typ {
			break
		}
		typ = utype
	}

	// Create the type and store it in the cache before resolving the element
	// type, so that recursive references find it.
	// NOTE(review): IsEncoder/IsDecoder are computed on the underlying type,
	// so methods declared on a named type are not seen here. Confirm this is
	// intended.
	t := &rlpstruct.Type{
		Name:      name,
		Kind:      typeReflectKind(typ),
		IsEncoder: bctx.isEncoder(typ),
		IsDecoder: bctx.isDecoder(typ),
	}
	bctx.typeToStructCache[key] = t

	// Assign element type.
	switch typ.(type) {
	case *types.Array, *types.Slice, *types.Pointer:
		etype := typ.(interface{ Elem() types.Type }).Elem()
		t.Elem = bctx.typeToStructType(etype)
	}
	return t
}
|
|
|
|
// genContext is passed to the gen* methods of op when generating
|
|
// the output code. It tracks packages to be imported by the output
|
|
// file and assigns unique names of temporary variables.
|
|
type genContext struct {
|
|
inPackage *types.Package
|
|
imports map[string]struct{}
|
|
tempCounter int
|
|
}
|
|
|
|
func newGenContext(inPackage *types.Package) *genContext {
|
|
return &genContext{
|
|
inPackage: inPackage,
|
|
imports: make(map[string]struct{}),
|
|
}
|
|
}
|
|
|
|
func (ctx *genContext) temp() string {
|
|
v := fmt.Sprintf("_tmp%d", ctx.tempCounter)
|
|
ctx.tempCounter++
|
|
return v
|
|
}
|
|
|
|
func (ctx *genContext) resetTemp() {
|
|
ctx.tempCounter = 0
|
|
}
|
|
|
|
func (ctx *genContext) addImport(path string) {
|
|
if path == ctx.inPackage.Path() {
|
|
return // avoid importing the package that we're generating in.
|
|
}
|
|
// TODO: renaming?
|
|
ctx.imports[path] = struct{}{}
|
|
}
|
|
|
|
// importsList returns all packages that need to be imported.
|
|
func (ctx *genContext) importsList() []string {
|
|
imp := make([]string, 0, len(ctx.imports))
|
|
for k := range ctx.imports {
|
|
imp = append(imp, k)
|
|
}
|
|
sort.Strings(imp)
|
|
return imp
|
|
}
|
|
|
|
// qualify is the types.Qualifier used for printing types.
|
|
func (ctx *genContext) qualify(pkg *types.Package) string {
|
|
if pkg.Path() == ctx.inPackage.Path() {
|
|
return ""
|
|
}
|
|
ctx.addImport(pkg.Path())
|
|
// TODO: renaming?
|
|
return pkg.Name()
|
|
}
|
|
|
|
type op interface {
|
|
// genWrite creates the encoder. The generated code should write v,
|
|
// which is any Go expression, to the rlp.EncoderBuffer 'w'.
|
|
genWrite(ctx *genContext, v string) string
|
|
|
|
// genDecode creates the decoder. The generated code should read
|
|
// a value from the rlp.Stream 'dec' and store it to dst.
|
|
genDecode(ctx *genContext) (string, string)
|
|
}
|
|
|
|
// basicOp handles basic types bool, uint*, string.
|
|
type basicOp struct {
|
|
typ types.Type
|
|
writeMethod string // calle write the value
|
|
writeArgType types.Type // parameter type of writeMethod
|
|
decMethod string
|
|
decResultType types.Type // return type of decMethod
|
|
decUseBitSize bool // if true, result bit size is appended to decMethod
|
|
}
|
|
|
|
// makeBasicOp creates the op for a basic type. It supports bool, string and
// the fixed-size unsigned integers uint8 through uint64. Any other kind,
// including the platform-dependent uint and uintptr, yields an error.
func (*buildContext) makeBasicOp(typ *types.Basic) (op, error) {
	result := basicOp{typ: typ}
	switch typ.Kind() {
	case types.Bool:
		result.writeMethod, result.writeArgType = "WriteBool", types.Typ[types.Bool]
		result.decMethod, result.decResultType = "Bool", types.Typ[types.Bool]
	case types.Uint8, types.Uint16, types.Uint32, types.Uint64:
		// Every size is written as uint64. Decoding uses the sized Stream
		// method (Uint8, Uint16, ...), selected through decUseBitSize.
		result.writeMethod, result.writeArgType = "WriteUint64", types.Typ[types.Uint64]
		result.decMethod, result.decResultType = "Uint", typ
		result.decUseBitSize = true
	case types.String:
		result.writeMethod, result.writeArgType = "WriteString", types.Typ[types.String]
		result.decMethod, result.decResultType = "String", types.Typ[types.String]
	default:
		return nil, fmt.Errorf("unhandled basic type: %v", typ)
	}
	return result, nil
}
|
|
|
|
func (*buildContext) makeByteSliceOp(typ *types.Slice) op {
|
|
if !isByte(typ.Elem()) {
|
|
panic("non-byte slice type in makeByteSliceOp")
|
|
}
|
|
bslice := types.NewSlice(types.Typ[types.Uint8])
|
|
return basicOp{
|
|
typ: typ,
|
|
writeMethod: "WriteBytes",
|
|
writeArgType: bslice,
|
|
decMethod: "Bytes",
|
|
decResultType: bslice,
|
|
}
|
|
}
|
|
|
|
// makeRawValueOp creates the op for rlp.RawValue. The value is written with
// EncoderBuffer.Write and read with Stream.Raw.
func (bctx *buildContext) makeRawValueOp() op {
	raw := types.NewSlice(types.Typ[types.Uint8])
	return basicOp{
		typ:           bctx.rawValueType,
		writeMethod:   "Write",
		decMethod:     "Raw",
		writeArgType:  raw,
		decResultType: raw,
	}
}
|
|
|
|
// writeNeedsConversion reports whether a value of op.typ has to be converted
// explicitly before it can be passed to writeMethod.
func (op basicOp) writeNeedsConversion() bool {
	assignable := types.AssignableTo(op.typ, op.writeArgType)
	return !assignable
}

// decodeNeedsConversion reports whether the result of decMethod has to be
// converted explicitly before it can be stored as op.typ.
func (op basicOp) decodeNeedsConversion() bool {
	if types.AssignableTo(op.decResultType, op.typ) {
		return false
	}
	return true
}
|
|
|
|
// genWrite emits a call of op.writeMethod on 'w'. v is converted to
// writeArgType first when it is not directly assignable.
func (op basicOp) genWrite(ctx *genContext, v string) string {
	arg := v
	if op.writeNeedsConversion() {
		arg = fmt.Sprintf("%s(%s)", op.writeArgType, v)
	}
	return "w." + op.writeMethod + "(" + arg + ")\n"
}
|
|
|
|
// genDecode emits a call of the Stream decode method into a temporary. When
// the method's result is not assignable to op.typ, it adds a conversion into
// a second temporary.
func (op basicOp) genDecode(ctx *genContext) (string, string) {
	tmp := ctx.temp()

	method := op.decMethod
	if op.decUseBitSize {
		// Note: For now, this only works for platform-independent integer
		// sizes. makeBasicOp forbids the platform-dependent types.
		var sizes types.StdSizes
		bits := sizes.Sizeof(op.typ) * 8
		method = fmt.Sprintf("%s%d", method, bits)
	}

	// Call the decoder method.
	var code bytes.Buffer
	code.WriteString(tmp + ", err := dec." + method + "()\n")
	code.WriteString("if err != nil { return err }\n")
	if !op.decodeNeedsConversion() {
		return tmp, code.String()
	}

	converted := ctx.temp()
	typeName := types.TypeString(op.typ, ctx.qualify)
	code.WriteString(converted + " := " + typeName + "(" + tmp + ")\n")
	return converted, code.String()
}
|
|
|
|
// byteArrayOp handles [...]byte.
|
|
type byteArrayOp struct {
|
|
typ types.Type
|
|
name types.Type // name != typ for named byte array types (e.g. common.Address)
|
|
}
|
|
|
|
func (bctx *buildContext) makeByteArrayOp(name *types.Named, typ *types.Array) byteArrayOp {
|
|
nt := types.Type(name)
|
|
if name == nil {
|
|
nt = typ
|
|
}
|
|
return byteArrayOp{typ, nt}
|
|
}
|
|
|
|
func (op byteArrayOp) genWrite(ctx *genContext, v string) string {
|
|
return fmt.Sprintf("w.WriteBytes(%s[:])\n", v)
|
|
}
|
|
|
|
func (op byteArrayOp) genDecode(ctx *genContext) (string, string) {
|
|
var resultV = ctx.temp()
|
|
|
|
var b bytes.Buffer
|
|
fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(op.name, ctx.qualify))
|
|
fmt.Fprintf(&b, "if err := dec.ReadBytes(%s[:]); err != nil { return err }\n", resultV)
|
|
return resultV, b.String()
|
|
}
|
|
|
|
// bigIntOp handles big.Int (pointer == false) and *big.Int (pointer == true).
// It exists because big.Int has its own decoder operation on rlp.Stream
// (dec.BigInt). That method returns *big.Int, so the result needs to be
// dereferenced in the non-pointer case.
type bigIntOp struct {
	pointer bool // true for *big.Int, false for big.Int
}
|
|
|
|
// genWrite emits code that writes the integer v with w.WriteBigInt. Negative
// values make the generated method return rlp.ErrNegativeBigInt. For
// *big.Int, a nil pointer is written as rlp.EmptyString.
func (op bigIntOp) genWrite(ctx *genContext, v string) string {
	var b bytes.Buffer

	fmt.Fprintf(&b, "if %s.Sign() == -1 {\n", v)
	fmt.Fprintf(&b, " return rlp.ErrNegativeBigInt\n")
	fmt.Fprintf(&b, "}\n")
	dst := v
	if !op.pointer {
		// WriteBigInt takes a pointer, so pass the address of the value.
		dst = "&" + v
	}
	fmt.Fprintf(&b, "w.WriteBigInt(%s)\n", dst)

	// Wrap with nil check.
	if op.pointer {
		code := b.String()
		b.Reset()
		fmt.Fprintf(&b, "if %s == nil {\n", v)
		// The trailing newline keeps the following "} else {" on its own line.
		fmt.Fprintf(&b, " w.Write(rlp.EmptyString)\n")
		fmt.Fprintf(&b, "} else {\n")
		fmt.Fprint(&b, code)
		fmt.Fprintf(&b, "}\n")
	}

	return b.String()
}
|
|
|
|
func (op bigIntOp) genDecode(ctx *genContext) (string, string) {
|
|
var resultV = ctx.temp()
|
|
|
|
var b bytes.Buffer
|
|
fmt.Fprintf(&b, "%s, err := dec.BigInt()\n", resultV)
|
|
fmt.Fprintf(&b, "if err != nil { return err }\n")
|
|
|
|
result := resultV
|
|
if !op.pointer {
|
|
result = "(*" + resultV + ")"
|
|
}
|
|
return result, b.String()
|
|
}
|
|
|
|
// encoderDecoderOp handles rlp.Encoder and rlp.Decoder.
|
|
// In order to be used with this, the type must implement both interfaces.
|
|
// This restriction may be lifted in the future by creating separate ops for
|
|
// encoding and decoding.
|
|
type encoderDecoderOp struct {
|
|
typ types.Type
|
|
}
|
|
|
|
func (op encoderDecoderOp) genWrite(ctx *genContext, v string) string {
|
|
return fmt.Sprintf("if err := %s.EncodeRLP(w); err != nil { return err }\n", v)
|
|
}
|
|
|
|
// genDecode allocates a new element and lets its DecodeRLP method read from
// 'dec'. The returned expression is the pointer to that element.
func (op encoderDecoderOp) genDecode(ctx *genContext) (string, string) {
	// DecodeRLP must have pointer receiver, and this is verified in makeOp.
	ptr := op.typ.(*types.Pointer)
	tmp := ctx.temp()
	elemName := types.TypeString(ptr.Elem(), ctx.qualify)

	var b bytes.Buffer
	b.WriteString(tmp + " := new(" + elemName + ")\n")
	b.WriteString("if err := " + tmp + ".DecodeRLP(dec); err != nil { return err }\n")
	return tmp, b.String()
}
|
|
|
|
// ptrOp handles pointer types.
|
|
type ptrOp struct {
|
|
elemTyp types.Type
|
|
elem op
|
|
nilOK bool
|
|
nilValue rlpstruct.NilKind
|
|
}
|
|
|
|
// makePtrOp creates the op for a pointer to elemTyp. When tags.NilOK is
// set, the tagged nil kind is used and empty values decode as nil.
// Otherwise the element type's default nil value is used to encode nil.
func (bctx *buildContext) makePtrOp(elemTyp types.Type, tags rlpstruct.Tags) (op, error) {
	elemOp, err := bctx.makeOp(nil, elemTyp, rlpstruct.Tags{})
	if err != nil {
		return nil, err
	}
	result := ptrOp{elemTyp: elemTyp, elem: elemOp}

	// Determine nil value.
	if !tags.NilOK {
		result.nilValue = bctx.typeToStructType(elemTyp).DefaultNilValue()
		return result, nil
	}
	result.nilOK = true
	result.nilValue = tags.NilKind
	return result, nil
}
|
|
|
|
// genWrite writes the nil value for a nil pointer and delegates to the
// element op otherwise.
func (op ptrOp) genWrite(ctx *genContext, v string) string {
	// Note: in writer functions, accesses to v are read-only, i.e. v is any Go
	// expression. To make all accesses work through the pointer, we substitute
	// v with (*v). This is required for most accesses including `v`, `call(v)`,
	// and `v[index]` on slices.
	//
	// For `v.field` and `v[:]` on arrays, the dereference operation is not required.
	deref := fmt.Sprintf("(*%s)", v)
	switch op.elem.(type) {
	case structOp, byteArrayOp:
		deref = v
	}
	elemCode := op.elem.genWrite(ctx, deref)

	var b bytes.Buffer
	fmt.Fprintf(&b, "if %s == nil {\n", v)
	fmt.Fprintf(&b, " w.Write([]byte{0x%X})\n", op.nilValue)
	b.WriteString("} else {\n")
	b.WriteString(" " + elemCode)
	b.WriteString("}\n")
	return b.String()
}
|
|
|
|
// genDecode emits decoding code for a pointer.
//
// Without nilOK, the element is decoded and its address is taken. With
// nilOK, the generated code first calls dec.Kind. An empty value of the
// expected nil kind leaves the pointer nil; anything else is decoded into a
// new element.
//
// dec.Kind does not consume the value: rlp.Stream keeps the kind cached
// until the value is read. In the nil case the generated code therefore
// reads the empty value explicitly. Otherwise the next decode call would
// see the same empty value again and every later field would be shifted.
func (op ptrOp) genDecode(ctx *genContext) (string, string) {
	result, code := op.elem.genDecode(ctx)
	if !op.nilOK {
		// If nil pointers are not allowed, we can just decode the element.
		return "&" + result, code
	}

	// nil is allowed, so check the kind and size first.
	// If size is zero and kind matches the nilKind of the type,
	// the value decodes as a nil pointer.
	var (
		resultV  = ctx.temp()
		kindV    = ctx.temp()
		sizeV    = ctx.temp()
		wantKind string
		skipCode string // consumes the empty value in the nil case
	)
	if op.nilValue == rlpstruct.NilKindList {
		wantKind = "rlp.List"
		skipCode = " if _, err := dec.List(); err != nil { return err }\n" +
			" if err := dec.ListEnd(); err != nil { return err }\n"
	} else {
		wantKind = "rlp.String"
		skipCode = " if _, err := dec.Bytes(); err != nil { return err }\n"
	}
	var b bytes.Buffer
	fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(types.NewPointer(op.elemTyp), ctx.qualify))
	fmt.Fprintf(&b, "if %s, %s, err := dec.Kind(); err != nil {\n", kindV, sizeV)
	fmt.Fprintf(&b, " return err\n")
	fmt.Fprintf(&b, "} else if %s != 0 || %s != %s {\n", sizeV, kindV, wantKind)
	fmt.Fprint(&b, code)
	fmt.Fprintf(&b, " %s = &%s\n", resultV, result)
	fmt.Fprintf(&b, "} else {\n")
	fmt.Fprint(&b, skipCode)
	fmt.Fprintf(&b, "}\n")
	return resultV, b.String()
}
|
|
|
|
// structOp handles struct types.
|
|
type structOp struct {
|
|
named *types.Named
|
|
typ *types.Struct
|
|
fields []*structField
|
|
optionalFields []*structField
|
|
}
|
|
|
|
type structField struct {
|
|
name string
|
|
typ types.Type
|
|
elem op
|
|
}
|
|
|
|
// makeStructOp creates the op for a struct type. named is the named type
// the struct was reached through, or nil. rlpstruct.ProcessFields filters
// the fields and parses their tags. Tags the generator cannot handle yet
// are rejected by checkUnsupportedTags.
func (bctx *buildContext) makeStructOp(named *types.Named, typ *types.Struct) (op, error) {
	// Describe every field for rlpstruct.
	var allStructFields []rlpstruct.Field
	for i, n := 0, typ.NumFields(); i < n; i++ {
		f := typ.Field(i)
		desc := rlpstruct.Field{
			Name:     f.Name(),
			Exported: f.Exported(),
			Index:    i,
			Tag:      typ.Tag(i),
			Type:     *bctx.typeToStructType(f.Type()),
		}
		allStructFields = append(allStructFields, desc)
	}

	// Filter/validate fields.
	fields, tags, err := rlpstruct.ProcessFields(allStructFields)
	if err != nil {
		return nil, err
	}

	// Create field ops.
	result := structOp{named: named, typ: typ}
	for i := range fields {
		field, tag := fields[i], tags[i]
		// Advanced struct tags are not supported yet.
		if err := checkUnsupportedTags(field.Name, tag); err != nil {
			return nil, err
		}
		fieldType := typ.Field(field.Index).Type()
		elem, err := bctx.makeOp(nil, fieldType, tag)
		if err != nil {
			return nil, fmt.Errorf("field %s: %v", field.Name, err)
		}
		sf := &structField{name: field.Name, typ: fieldType, elem: elem}
		if tag.Optional {
			result.optionalFields = append(result.optionalFields, sf)
			continue
		}
		result.fields = append(result.fields, sf)
	}
	return result, nil
}
|
|
|
|
// checkUnsupportedTags returns an error when tag uses a struct tag feature
// the generator cannot handle yet. Currently that is only "tail".
func checkUnsupportedTags(field string, tag rlpstruct.Tags) error {
	if !tag.Tail {
		return nil
	}
	return fmt.Errorf(`field %s has unsupported struct tag "tail"`, field)
}
|
|
|
|
// genWrite encodes the struct as a list: required fields first, then the
// optional ones, between w.List and w.ListEnd.
func (op structOp) genWrite(ctx *genContext, v string) string {
	marker := ctx.temp()

	var b bytes.Buffer
	b.WriteString(marker + " := w.List()\n")
	for _, f := range op.fields {
		b.WriteString(f.elem.genWrite(ctx, v+"."+f.name))
	}
	op.writeOptionalFields(&b, ctx, v)
	b.WriteString("w.ListEnd(" + marker + ")\n")
	return b.String()
}
|
|
|
|
func (op structOp) writeOptionalFields(b *bytes.Buffer, ctx *genContext, v string) {
|
|
if len(op.optionalFields) == 0 {
|
|
return
|
|
}
|
|
// First check zero-ness of all optional fields.
|
|
var zeroV = make([]string, len(op.optionalFields))
|
|
for i, field := range op.optionalFields {
|
|
selector := v + "." + field.name
|
|
zeroV[i] = ctx.temp()
|
|
fmt.Fprintf(b, "%s := %s\n", zeroV[i], nonZeroCheck(selector, field.typ, ctx.qualify))
|
|
}
|
|
// Now write the fields.
|
|
for i, field := range op.optionalFields {
|
|
selector := v + "." + field.name
|
|
cond := ""
|
|
for j := i; j < len(op.optionalFields); j++ {
|
|
if j > i {
|
|
cond += " || "
|
|
}
|
|
cond += zeroV[j]
|
|
}
|
|
fmt.Fprintf(b, "if %s {\n", cond)
|
|
fmt.Fprint(b, field.elem.genWrite(ctx, selector))
|
|
fmt.Fprintf(b, "}\n")
|
|
}
|
|
}
|
|
|
|
// genDecode declares a struct temporary and fills it from an RLP list:
// required fields first, then the optional ones.
func (op structOp) genDecode(ctx *genContext) (string, string) {
	// Named types are printed by name. Printing op.typ would produce a copy
	// of the struct definition instead.
	var printed types.Type = op.typ
	if op.named != nil {
		printed = op.named
	}
	typeName := types.TypeString(printed, ctx.qualify)

	// Create struct object.
	obj := ctx.temp()
	var b bytes.Buffer
	fmt.Fprintf(&b, "var %s %s\n", obj, typeName)

	// Decode fields.
	b.WriteString("{\n")
	b.WriteString("if _, err := dec.List(); err != nil { return err }\n")
	for _, f := range op.fields {
		value, code := f.elem.genDecode(ctx)
		fmt.Fprintf(&b, "// %s:\n", f.name)
		b.WriteString(code)
		fmt.Fprintf(&b, "%s.%s = %s\n", obj, f.name, value)
	}
	op.decodeOptionalFields(&b, ctx, obj)
	b.WriteString("if err := dec.ListEnd(); err != nil { return err }\n")
	b.WriteString("}\n")
	return obj, b.String()
}
|
|
|
|
func (op structOp) decodeOptionalFields(b *bytes.Buffer, ctx *genContext, resultV string) {
|
|
var suffix bytes.Buffer
|
|
for _, field := range op.optionalFields {
|
|
result, code := field.elem.genDecode(ctx)
|
|
fmt.Fprintf(b, "// %s:\n", field.name)
|
|
fmt.Fprintf(b, "if dec.MoreDataInList() {\n")
|
|
fmt.Fprint(b, code)
|
|
fmt.Fprintf(b, "%s.%s = %s\n", resultV, field.name, result)
|
|
fmt.Fprintf(&suffix, "}\n")
|
|
}
|
|
suffix.WriteTo(b)
|
|
}
|
|
|
|
// sliceOp handles slice types.
|
|
type sliceOp struct {
|
|
typ *types.Slice
|
|
elemOp op
|
|
}
|
|
|
|
func (bctx *buildContext) makeSliceOp(typ *types.Slice) (op, error) {
|
|
elemOp, err := bctx.makeOp(nil, typ.Elem(), rlpstruct.Tags{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return sliceOp{typ: typ, elemOp: elemOp}, nil
|
|
}
|
|
|
|
func (op sliceOp) genWrite(ctx *genContext, v string) string {
|
|
var (
|
|
listMarker = ctx.temp() // holds return value of w.List()
|
|
iterElemV = ctx.temp() // iteration variable
|
|
elemCode = op.elemOp.genWrite(ctx, iterElemV)
|
|
)
|
|
|
|
var b bytes.Buffer
|
|
fmt.Fprintf(&b, "%s := w.List()\n", listMarker)
|
|
fmt.Fprintf(&b, "for _, %s := range %s {\n", iterElemV, v)
|
|
fmt.Fprint(&b, elemCode)
|
|
fmt.Fprintf(&b, "}\n")
|
|
fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker)
|
|
return b.String()
|
|
}
|
|
|
|
// genDecode decodes list elements until the list is exhausted, appending
// each one to a slice temporary.
func (op sliceOp) genDecode(ctx *genContext) (string, string) {
	out := ctx.temp() // holds the output slice
	elemValue, elemCode := op.elemOp.genDecode(ctx)
	sliceType := types.TypeString(op.typ, ctx.qualify)

	parts := []string{
		fmt.Sprintf("var %s %s\n", out, sliceType),
		"if _, err := dec.List(); err != nil { return err }\n",
		"for dec.MoreDataInList() {\n",
		" " + elemCode,
		fmt.Sprintf(" %s = append(%s, %s)\n", out, out, elemValue),
		"}\n",
		"if err := dec.ListEnd(); err != nil { return err }\n",
	}
	var b bytes.Buffer
	for _, part := range parts {
		b.WriteString(part)
	}
	return out, b.String()
}
|
|
|
|
// makeOp creates the op that encodes and decodes values of typ.
//
// name is the named type through which typ was reached when typ is the
// underlying type of a named type, and nil otherwise. Struct, byte array and
// basic ops use it to refer to the type by name in generated code. tags
// holds the struct tags of the field being processed (the zero value when
// there is no field).
func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstruct.Tags) (op, error) {
	switch typ := typ.(type) {
	case *types.Named:
		if isBigInt(typ) {
			return bigIntOp{}, nil
		}
		if typ == bctx.rawValueType {
			return bctx.makeRawValueOp(), nil
		}
		if bctx.isDecoder(typ) {
			return nil, fmt.Errorf("type %v implements rlp.Decoder with non-pointer receiver", typ)
		}
		// TODO: same check for encoder?
		return bctx.makeOp(typ, typ.Underlying(), tags)
	case *types.Pointer:
		if isBigInt(typ.Elem()) {
			return bigIntOp{pointer: true}, nil
		}
		// Encoder/Decoder interfaces.
		if bctx.isEncoder(typ) {
			if bctx.isDecoder(typ) {
				return encoderDecoderOp{typ}, nil
			}
			return nil, fmt.Errorf("type %v implements rlp.Encoder but not rlp.Decoder", typ)
		}
		if bctx.isDecoder(typ) {
			return nil, fmt.Errorf("type %v implements rlp.Decoder but not rlp.Encoder", typ)
		}
		// Default pointer handling.
		return bctx.makePtrOp(typ.Elem(), tags)
	case *types.Basic:
		basic, err := bctx.makeBasicOp(typ)
		if err != nil {
			return nil, err
		}
		if name != nil {
			// A named basic type (e.g. `type Nonce uint64`) is not assignable
			// to/from the basic types used by the EncoderBuffer and Stream
			// methods. Recording the named type makes basicOp emit the
			// explicit conversions the generated code needs to compile.
			named := basic.(basicOp)
			named.typ = name
			return named, nil
		}
		return basic, nil
	case *types.Struct:
		return bctx.makeStructOp(name, typ)
	case *types.Slice:
		etyp := typ.Elem()
		if isByte(etyp) && !bctx.isEncoder(etyp) {
			return bctx.makeByteSliceOp(typ), nil
		}
		return bctx.makeSliceOp(typ)
	case *types.Array:
		etyp := typ.Elem()
		if isByte(etyp) && !bctx.isEncoder(etyp) {
			return bctx.makeByteArrayOp(name, typ), nil
		}
		return nil, fmt.Errorf("unhandled array type: %v", typ)
	default:
		return nil, fmt.Errorf("unhandled type: %v", typ)
	}
}
|
|
|
|
// generateDecoder generates the DecodeRLP method on 'typ'. Temporary
// numbering restarts at zero for each generated method.
func generateDecoder(ctx *genContext, typ string, op op) []byte {
	ctx.resetTemp()
	ctx.addImport(pathOfPackageRLP)

	value, body := op.genDecode(ctx)
	var b bytes.Buffer
	b.WriteString("func (obj *" + typ + ") DecodeRLP(dec *rlp.Stream) error {\n")
	b.WriteString(body)
	b.WriteString(" *obj = " + value + "\n")
	b.WriteString(" return nil\n")
	b.WriteString("}\n")
	return b.Bytes()
}
|
|
|
|
// generateEncoder generates the EncodeRLP method on 'typ'. The method
// writes through an rlp.EncoderBuffer wrapping the io.Writer and returns
// the result of w.Flush.
func generateEncoder(ctx *genContext, typ string, op op) []byte {
	ctx.resetTemp()
	for _, path := range []string{"io", pathOfPackageRLP} {
		ctx.addImport(path)
	}

	var b bytes.Buffer
	fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ)
	b.WriteString(" w := rlp.NewEncoderBuffer(_w)\n")
	b.WriteString(op.genWrite(ctx, "obj"))
	b.WriteString(" return w.Flush()\n")
	b.WriteString("}\n")
	return b.Bytes()
}
|
|
|
|
// generate returns the source of a file, formatted with format.Source,
// containing the requested EncodeRLP and/or DecodeRLP methods for typ.
func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]byte, error) {
	bctx.topType = typ

	pkg := typ.Obj().Pkg()
	rootOp, err := bctx.makeOp(nil, typ, rlpstruct.Tags{})
	if err != nil {
		return nil, err
	}

	// The methods are generated before the file header is written, because
	// generating them is what fills the import list.
	ctx := newGenContext(pkg)
	typeName := typ.Obj().Name()
	var methods [][]byte
	if encoder {
		methods = append(methods, generateEncoder(ctx, typeName, rootOp))
	}
	if decoder {
		methods = append(methods, generateDecoder(ctx, typeName, rootOp))
	}

	var b bytes.Buffer
	fmt.Fprintf(&b, "package %s\n\n", pkg.Name())
	for _, imp := range ctx.importsList() {
		fmt.Fprintf(&b, "import %q\n", imp)
	}
	for _, m := range methods {
		b.WriteByte('\n')
		b.Write(m)
	}
	return format.Source(b.Bytes())
}
|