Merge pull request #39 from alexflint/embedded

add support for embedded structs
This commit is contained in:
Alex Flint 2016-10-11 09:09:17 +10:30 committed by GitHub
commit 7c77c70f85
3 changed files with 87 additions and 11 deletions

View File

@ -188,6 +188,34 @@ $ ./example --version
someprogram 4.3.0 someprogram 4.3.0
``` ```
### Embedded structs
The fields of embedded structs are treated just like regular fields:
```go
type DatabaseOptions struct {
Host string
Username string
Password string
}
type LogOptions struct {
LogFile string
Verbose bool
}
func main() {
var args struct {
DatabaseOptions
LogOptions
}
arg.MustParse(&args)
}
```
As usual, any field tagged with `arg:"-"` is ignored.
### Custom parsing ### Custom parsing
You can implement your own argument parser by implementing `encoding.TextUnmarshaler`: You can implement your own argument parser by implementing `encoding.TextUnmarshaler`:

View File

@ -21,7 +21,6 @@ type spec struct {
env string env string
wasPresent bool wasPresent bool
boolean bool boolean bool
fieldName string // for generating helpful errors
} }
// ErrHelp indicates that -h or --help were provided // ErrHelp indicates that -h or --help were provided
@ -81,6 +80,19 @@ type Versioned interface {
Version() string Version() string
} }
// walkFields calls a function for each field of a struct, recursively expanding struct fields.
func walkFields(v reflect.Value, visit func(field reflect.StructField, val reflect.Value, owner reflect.Type) bool) {
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
val := v.Field(i)
expand := visit(field, val, t)
if expand && field.Type.Kind() == reflect.Struct {
walkFields(val, visit)
}
}
}
// NewParser constructs a parser from a list of destination structs // NewParser constructs a parser from a list of destination structs
func NewParser(config Config, dests ...interface{}) (*Parser, error) { func NewParser(config Config, dests ...interface{}) (*Parser, error) {
p := Parser{ p := Parser{
@ -99,19 +111,22 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
panic(fmt.Sprintf("%T is not a struct pointer", dest)) panic(fmt.Sprintf("%T is not a struct pointer", dest))
} }
t := v.Type() var errs []string
for i := 0; i < t.NumField(); i++ { walkFields(v, func(field reflect.StructField, val reflect.Value, t reflect.Type) bool {
// Check for the ignore switch in the tag // Check for the ignore switch in the tag
field := t.Field(i)
tag := field.Tag.Get("arg") tag := field.Tag.Get("arg")
if tag == "-" { if tag == "-" {
continue return false
}
// If this is an embedded struct then recurse into its fields
if field.Anonymous && field.Type.Kind() == reflect.Struct {
return true
} }
spec := spec{ spec := spec{
long: strings.ToLower(field.Name), long: strings.ToLower(field.Name),
dest: v.Field(i), dest: val,
fieldName: t.Name() + "." + field.Name,
} }
// Check whether this field is supported. It's good to do this here rather than // Check whether this field is supported. It's good to do this here rather than
@ -121,7 +136,9 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
var parseable bool var parseable bool
parseable, spec.boolean, spec.multiple = canParse(field.Type) parseable, spec.boolean, spec.multiple = canParse(field.Type)
if !parseable { if !parseable {
return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String()) errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
t.Name(), field.Name, field.Type.String()))
return false
} }
// Look at the tag // Look at the tag
@ -138,7 +155,9 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
spec.long = key[2:] spec.long = key[2:]
case strings.HasPrefix(key, "-"): case strings.HasPrefix(key, "-"):
if len(key) != 2 { if len(key) != 2 {
return nil, fmt.Errorf("%s.%s: short arguments must be one character only", t.Name(), field.Name) errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only",
t.Name(), field.Name))
return false
} }
spec.short = key[1:] spec.short = key[1:]
case key == "required": case key == "required":
@ -155,11 +174,19 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
spec.env = strings.ToUpper(field.Name) spec.env = strings.ToUpper(field.Name)
} }
default: default:
return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag) errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag))
return false
} }
} }
} }
p.spec = append(p.spec, &spec) p.spec = append(p.spec, &spec)
// if this was an embedded field then we already returned true up above
return false
})
if len(errs) > 0 {
return nil, errors.New(strings.Join(errs, "\n"))
} }
} }
if p.config.Program == "" { if p.config.Program == "" {

View File

@ -633,3 +633,24 @@ func TestInvalidMailAddr(t *testing.T) {
err := parse("--recipient xxx", &args) err := parse("--recipient xxx", &args)
assert.Error(t, err) assert.Error(t, err)
} }
type A struct {
X string
}
type B struct {
Y int
}
func TestEmbedded(t *testing.T) {
var args struct {
A
B
Z bool
}
err := parse("--x=hello --y=321 --z", &args)
require.NoError(t, err)
assert.Equal(t, "hello", args.X)
assert.Equal(t, 321, args.Y)
assert.Equal(t, true, args.Z)
}