diff --git a/README.md b/README.md index 4aa311e..bc761de 100644 --- a/README.md +++ b/README.md @@ -188,6 +188,34 @@ $ ./example --version someprogram 4.3.0 ``` +### Embedded structs + +The fields of embedded structs are treated just like regular fields: + +```go + +type DatabaseOptions struct { + Host string + Username string + Password string +} + +type LogOptions struct { + LogFile string + Verbose bool +} + +func main() { + var args struct { + DatabaseOptions + LogOptions + } + arg.MustParse(&args) +} +``` + +As usual, any field tagged with `arg:"-"` is ignored. + ### Custom parsing You can implement your own argument parser by implementing `encoding.TextUnmarshaler`: diff --git a/parse.go b/parse.go index f5fdd7f..26b530a 100644 --- a/parse.go +++ b/parse.go @@ -21,7 +21,6 @@ type spec struct { env string wasPresent bool boolean bool - fieldName string // for generating helpful errors } // ErrHelp indicates that -h or --help were provided @@ -81,6 +80,19 @@ type Versioned interface { Version() string } +// walkFields calls a function for each field of a struct, recursively expanding struct fields. +func walkFields(v reflect.Value, visit func(field reflect.StructField, val reflect.Value, owner reflect.Type) bool) { + t := v.Type() + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + val := v.Field(i) + expand := visit(field, val, t) + if expand && field.Type.Kind() == reflect.Struct { + walkFields(val, visit) + } + } +} + // NewParser constructs a parser from a list of destination structs func NewParser(config Config, dests ...interface{}) (*Parser, error) { p := Parser{ @@ -99,19 +111,22 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { panic(fmt.Sprintf("%T is not a struct pointer", dest)) } - t := v.Type() - for i := 0; i < t.NumField(); i++ { + var errs []string + walkFields(v, func(field reflect.StructField, val reflect.Value, t reflect.Type) bool { // Check for the ignore switch in the tag - field := t.Field(i) tag := field.Tag.Get("arg") if tag == "-" { - continue + return false + } + + // If this is an embedded struct then recurse into its fields + if field.Anonymous && field.Type.Kind() == reflect.Struct { + return true } spec := spec{ - long: strings.ToLower(field.Name), - dest: v.Field(i), - fieldName: t.Name() + "." + field.Name, + long: strings.ToLower(field.Name), + dest: val, } // Check whether this field is supported. It's good to do this here rather than @@ -121,7 +136,9 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { var parseable bool parseable, spec.boolean, spec.multiple = canParse(field.Type) if !parseable { - return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String()) + errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", + t.Name(), field.Name, field.Type.String())) + return false } // Look at the tag @@ -138,7 +155,9 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { spec.long = key[2:] case strings.HasPrefix(key, "-"): if len(key) != 2 { - return nil, fmt.Errorf("%s.%s: short arguments must be one character only", t.Name(), field.Name) + errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", + t.Name(), field.Name)) + return false } spec.short = key[1:] case key == "required": @@ -155,11 +174,19 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { spec.env = strings.ToUpper(field.Name) } default: - return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag) + errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) + return false } } } p.spec = append(p.spec, &spec) + + // if this was an embedded field then we already returned true up above + return false + }) + + if len(errs) > 0 { + return nil, errors.New(strings.Join(errs, "\n")) } } if p.config.Program == "" { diff --git a/parse_test.go b/parse_test.go index ab0cfd7..dffebf4 100644 --- a/parse_test.go +++ b/parse_test.go @@ -633,3 +633,24 @@ func TestInvalidMailAddr(t *testing.T) { err := parse("--recipient xxx", &args) assert.Error(t, err) } + +type A struct { + X string +} + +type B struct { + Y int +} + +func TestEmbedded(t *testing.T) { + var args struct { + A + B + Z bool + } + err := parse("--x=hello --y=321 --z", &args) + require.NoError(t, err) + assert.Equal(t, "hello", args.X) + assert.Equal(t, 321, args.Y) + assert.Equal(t, true, args.Z) +}