Merge pull request #39 from alexflint/embedded
add support for embedded structs
This commit is contained in:
commit
7c77c70f85
28
README.md
28
README.md
|
@ -188,6 +188,34 @@ $ ./example --version
|
||||||
someprogram 4.3.0
|
someprogram 4.3.0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Embedded structs
|
||||||
|
|
||||||
|
The fields of embedded structs are treated just like regular fields:
|
||||||
|
|
||||||
|
```go
|
||||||
|
|
||||||
|
type DatabaseOptions struct {
|
||||||
|
Host string
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
type LogOptions struct {
|
||||||
|
LogFile string
|
||||||
|
Verbose bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var args struct {
|
||||||
|
DatabaseOptions
|
||||||
|
LogOptions
|
||||||
|
}
|
||||||
|
arg.MustParse(&args)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
As usual, any field tagged with `arg:"-"` is ignored.
|
||||||
|
|
||||||
### Custom parsing
|
### Custom parsing
|
||||||
|
|
||||||
You can implement your own argument parser by implementing `encoding.TextUnmarshaler`:
|
You can implement your own argument parser by implementing `encoding.TextUnmarshaler`:
|
||||||
|
|
49
parse.go
49
parse.go
|
@ -21,7 +21,6 @@ type spec struct {
|
||||||
env string
|
env string
|
||||||
wasPresent bool
|
wasPresent bool
|
||||||
boolean bool
|
boolean bool
|
||||||
fieldName string // for generating helpful errors
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrHelp indicates that -h or --help were provided
|
// ErrHelp indicates that -h or --help were provided
|
||||||
|
@ -81,6 +80,19 @@ type Versioned interface {
|
||||||
Version() string
|
Version() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// walkFields calls a function for each field of a struct, recursively expanding struct fields.
|
||||||
|
func walkFields(v reflect.Value, visit func(field reflect.StructField, val reflect.Value, owner reflect.Type) bool) {
|
||||||
|
t := v.Type()
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
field := t.Field(i)
|
||||||
|
val := v.Field(i)
|
||||||
|
expand := visit(field, val, t)
|
||||||
|
if expand && field.Type.Kind() == reflect.Struct {
|
||||||
|
walkFields(val, visit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewParser constructs a parser from a list of destination structs
|
// NewParser constructs a parser from a list of destination structs
|
||||||
func NewParser(config Config, dests ...interface{}) (*Parser, error) {
|
func NewParser(config Config, dests ...interface{}) (*Parser, error) {
|
||||||
p := Parser{
|
p := Parser{
|
||||||
|
@ -99,19 +111,22 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
|
||||||
panic(fmt.Sprintf("%T is not a struct pointer", dest))
|
panic(fmt.Sprintf("%T is not a struct pointer", dest))
|
||||||
}
|
}
|
||||||
|
|
||||||
t := v.Type()
|
var errs []string
|
||||||
for i := 0; i < t.NumField(); i++ {
|
walkFields(v, func(field reflect.StructField, val reflect.Value, t reflect.Type) bool {
|
||||||
// Check for the ignore switch in the tag
|
// Check for the ignore switch in the tag
|
||||||
field := t.Field(i)
|
|
||||||
tag := field.Tag.Get("arg")
|
tag := field.Tag.Get("arg")
|
||||||
if tag == "-" {
|
if tag == "-" {
|
||||||
continue
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is an embedded struct then recurse into its fields
|
||||||
|
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
spec := spec{
|
spec := spec{
|
||||||
long: strings.ToLower(field.Name),
|
long: strings.ToLower(field.Name),
|
||||||
dest: v.Field(i),
|
dest: val,
|
||||||
fieldName: t.Name() + "." + field.Name,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check whether this field is supported. It's good to do this here rather than
|
// Check whether this field is supported. It's good to do this here rather than
|
||||||
|
@ -121,7 +136,9 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
|
||||||
var parseable bool
|
var parseable bool
|
||||||
parseable, spec.boolean, spec.multiple = canParse(field.Type)
|
parseable, spec.boolean, spec.multiple = canParse(field.Type)
|
||||||
if !parseable {
|
if !parseable {
|
||||||
return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String())
|
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
|
||||||
|
t.Name(), field.Name, field.Type.String()))
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look at the tag
|
// Look at the tag
|
||||||
|
@ -138,7 +155,9 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
|
||||||
spec.long = key[2:]
|
spec.long = key[2:]
|
||||||
case strings.HasPrefix(key, "-"):
|
case strings.HasPrefix(key, "-"):
|
||||||
if len(key) != 2 {
|
if len(key) != 2 {
|
||||||
return nil, fmt.Errorf("%s.%s: short arguments must be one character only", t.Name(), field.Name)
|
errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only",
|
||||||
|
t.Name(), field.Name))
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
spec.short = key[1:]
|
spec.short = key[1:]
|
||||||
case key == "required":
|
case key == "required":
|
||||||
|
@ -155,11 +174,19 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
|
||||||
spec.env = strings.ToUpper(field.Name)
|
spec.env = strings.ToUpper(field.Name)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag)
|
errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag))
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p.spec = append(p.spec, &spec)
|
p.spec = append(p.spec, &spec)
|
||||||
|
|
||||||
|
// if this was an embedded field then we already returned true up above
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return nil, errors.New(strings.Join(errs, "\n"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if p.config.Program == "" {
|
if p.config.Program == "" {
|
||||||
|
|
|
@ -633,3 +633,24 @@ func TestInvalidMailAddr(t *testing.T) {
|
||||||
err := parse("--recipient xxx", &args)
|
err := parse("--recipient xxx", &args)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type A struct {
|
||||||
|
X string
|
||||||
|
}
|
||||||
|
|
||||||
|
type B struct {
|
||||||
|
Y int
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbedded(t *testing.T) {
|
||||||
|
var args struct {
|
||||||
|
A
|
||||||
|
B
|
||||||
|
Z bool
|
||||||
|
}
|
||||||
|
err := parse("--x=hello --y=321 --z", &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "hello", args.X)
|
||||||
|
assert.Equal(t, 321, args.Y)
|
||||||
|
assert.Equal(t, true, args.Z)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue