diff --git a/parse.go b/parse.go index 353b365..cb8be6d 100644 --- a/parse.go +++ b/parse.go @@ -13,10 +13,32 @@ import ( scalar "github.com/alexflint/go-scalar" ) +// path represents a sequence of steps to find the output location for an +// argument or subcommand in the final destination struct +type path struct { + root int // index of the destination struct + fields []string // sequence of struct field names to traverse +} + +// String gets a string representation of the given path +func (p path) String() string { + return "args." + strings.Join(p.fields, ".") +} + +// Child gets a new path representing a child of this path. +func (p path) Child(child string) path { + // copy the entire slice of fields to avoid possible slice overwrite + subfields := make([]string, len(p.fields)+1) + copy(subfields, append(p.fields, child)) + return path{ + root: p.root, + fields: subfields, + } +} + // spec represents a command line option type spec struct { - root int - path []string // sequence of field names + dest path typ reflect.Type long string short string @@ -32,6 +54,7 @@ type spec struct { // command represents a named subcommand, or the top-level command type command struct { name string + dest path specs []*spec subcommands []*command } @@ -153,11 +176,12 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t)) } - cmd, err := cmdFromStruct(name, t, nil, i) + cmd, err := cmdFromStruct(name, path{root: i}, t) if err != nil { return nil, err } p.cmd.specs = append(p.cmd.specs, cmd.specs...) + p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...) if dest, ok := dest.(Versioned); ok { p.version = dest.Version() @@ -170,20 +194,24 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { return &p, nil } -func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*command, error) { +func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { // commands can only be created from pointers to structs if t.Kind() != reflect.Ptr { - return nil, fmt.Errorf("subcommands must be pointers to structs but args.%s is a %s", - strings.Join(path, "."), t.Kind()) + return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s", + dest, t.Kind()) } t = t.Elem() if t.Kind() != reflect.Struct { - return nil, fmt.Errorf("subcommands must be pointers to structs but args.%s is a pointer to %s", - strings.Join(path, "."), t.Kind()) + return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a pointer to %s", + dest, t.Kind()) + } + + cmd := command{ + name: name, + dest: dest, } - var cmd command var errs []string walkFields(t, func(field reflect.StructField, t reflect.Type) bool { // Check for the ignore switch in the tag @@ -198,12 +226,9 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma } // duplicate the entire path to avoid slice overwrites - subpath := make([]string, len(path)+1) - copy(subpath, append(path, field.Name)) - + subdest := dest.Child(field.Name) spec := spec{ - root: root, - path: subpath, + dest: subdest, long: strings.ToLower(field.Name), typ: field.Type, } @@ -213,19 +238,8 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma spec.help = help } - // Check whether this field is supported. It's good to do this here rather than - // wait until ParseValue because it means that a program with invalid argument - // fields will always fail regardless of whether the arguments it received - // exercised those fields. - var parseable bool - parseable, spec.boolean, spec.multiple = canParse(field.Type) - if !parseable { - errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", - t.Name(), field.Name, field.Type.String())) - return false - } - // Look at the tag + var isSubcommand bool // tracks whether this field is a subcommand if tag != "" { for _, key := range strings.Split(tag, ",") { key = strings.TrimLeft(key, " ") @@ -269,20 +283,37 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma cmdname = strings.ToLower(field.Name) } - subcmd, err := cmdFromStruct(cmdname, field.Type, subpath, root) + subcmd, err := cmdFromStruct(cmdname, subdest, field.Type) if err != nil { errs = append(errs, err.Error()) return false } cmd.subcommands = append(cmd.subcommands, subcmd) + isSubcommand = true + fmt.Println("found a subcommand") default: errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) return false } } } - cmd.specs = append(cmd.specs, &spec) + + // Check whether this field is supported. It's good to do this here rather than + // wait until ParseValue because it means that a program with invalid argument + // fields will always fail regardless of whether the arguments it received + // exercised those fields. + if !isSubcommand { + cmd.specs = append(cmd.specs, &spec) + + var parseable bool + parseable, spec.boolean, spec.multiple = canParse(field.Type) + if !parseable { + errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", + t.Name(), field.Name, field.Type.String())) + return false + } + } // if this was an embedded field then we already returned true up above return false @@ -303,6 +334,8 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma return nil, fmt.Errorf("%T cannot have both subcommands and positional arguments", t) } + fmt.Printf("parsed a command with %d subcommands\n", len(cmd.subcommands)) + return &cmd, nil } @@ -349,7 +382,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error err, ) } - if err = setSlice(p.writable(spec), values, !spec.separate); err != nil { + if err = setSlice(p.writable(spec.dest), values, !spec.separate); err != nil { return fmt.Errorf( "error processing environment variable %s with multiple values: %v", spec.env, @@ -357,7 +390,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error ) } } else { - if err := scalar.ParseValue(p.writable(spec), value); err != nil { + if err := scalar.ParseValue(p.writable(spec.dest), value); err != nil { return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) } } @@ -400,6 +433,7 @@ func (p *Parser) process(args []string) error { if !isFlag(arg) || allpositional { // each subcommand can have either subcommands or positionals, but not both + fmt.Printf("processing %q, with %d subcommands", arg, len(curCmd.subcommands)) if len(curCmd.subcommands) == 0 { positionals = append(positionals, arg) continue @@ -454,7 +488,7 @@ func (p *Parser) process(args []string) error { } else { values = append(values, value) } - err := setSlice(p.writable(spec), values, !spec.separate) + err := setSlice(p.writable(spec.dest), values, !spec.separate) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -479,7 +513,7 @@ func (p *Parser) process(args []string) error { i++ } - err := scalar.ParseValue(p.writable(spec), value) + err := scalar.ParseValue(p.writable(spec.dest), value) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -495,13 +529,13 @@ func (p *Parser) process(args []string) error { } wasPresent[spec] = true if spec.multiple { - err := setSlice(p.writable(spec), positionals, true) + err := setSlice(p.writable(spec.dest), positionals, true) if err != nil { return fmt.Errorf("error processing %s: %v", spec.long, err) } positionals = nil } else { - err := scalar.ParseValue(p.writable(spec), positionals[0]) + err := scalar.ParseValue(p.writable(spec.dest), positionals[0]) if err != nil { return fmt.Errorf("error processing %s: %v", spec.long, err) } @@ -546,9 +580,9 @@ func isFlag(s string) bool { // readable returns a reflect.Value corresponding to the current value for the // given -func (p *Parser) readable(spec *spec) reflect.Value { - v := p.roots[spec.root] - for _, field := range spec.path { +func (p *Parser) readable(dest path) reflect.Value { + v := p.roots[dest.root] + for _, field := range dest.fields { if v.Kind() == reflect.Ptr { if v.IsNil() { return reflect.Value{} @@ -559,21 +593,21 @@ func (p *Parser) readable(spec *spec) reflect.Value { v = v.FieldByName(field) if !v.IsValid() { // it is appropriate to panic here because this can only happen due to - // an internal bug in this library (since we construct spec.path ourselves + // an internal bug in this library (since we construct the path ourselves // by reflecting on the same struct) panic(fmt.Errorf("error resolving path %v: %v has no field named %v", - spec.path, v.Type(), field)) + dest.fields, v.Type(), field)) } } return v } -// writable traverses the destination struct to find the destination to +// writable trav.patherses the destination struct to find the destination to // which the value of the given spec should be written. It fills in null // structs with pointers to the zero value for that struct. -func (p *Parser) writable(spec *spec) reflect.Value { - v := p.roots[spec.root] - for _, field := range spec.path { +func (p *Parser) writable(dest path) reflect.Value { + v := p.roots[dest.root] + for _, field := range dest.fields { if v.Kind() == reflect.Ptr { if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) @@ -584,10 +618,10 @@ func (p *Parser) writable(spec *spec) reflect.Value { v = v.FieldByName(field) if !v.IsValid() { // it is appropriate to panic here because this can only happen due to - // an internal bug in this library (since we construct spec.path ourselves + // an internal bug in this library (since we construct the path ourselves // by reflecting on the same struct) panic(fmt.Errorf("error resolving path %v: %v has no field named %v", - spec.path, v.Type(), field)) + dest.fields, v.Type(), field)) } } return v diff --git a/subcommand_test.go b/subcommand_test.go index d17c604..02c7b54 100644 --- a/subcommand_test.go +++ b/subcommand_test.go @@ -4,12 +4,13 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // This file contains tests for parse.go but I decided to put them here // since that file is getting large -func TestSubcommandNotAStruct(t *testing.T) { +func TestSubcommandNotAPointer(t *testing.T) { var args struct { A string `arg:"subcommand"` } @@ -17,6 +18,14 @@ func TestSubcommandNotAStruct(t *testing.T) { assert.Error(t, err) } +func TestSubcommandNotAPointerToStruct(t *testing.T) { + var args struct { + A struct{} `arg:"subcommand"` + } + _, err := NewParser(Config{}, &args) + assert.Error(t, err) +} + func TestPositionalAndSubcommandNotAllowed(t *testing.T) { var args struct { A string `arg:"positional"` @@ -25,3 +34,14 @@ func TestPositionalAndSubcommandNotAllowed(t *testing.T) { _, err := NewParser(Config{}, &args) assert.Error(t, err) } + +func TestMinimalSubcommand(t *testing.T) { + type listCmd struct { + } + var args struct { + List *listCmd `arg:"subcommand"` + } + err := parse("list", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) +} diff --git a/usage.go b/usage.go index 833046f..d73da71 100644 --- a/usage.go +++ b/usage.go @@ -115,14 +115,12 @@ func (p *Parser) WriteHelp(w io.Writer) { long: "help", short: "h", help: "display this help and exit", - root: -1, }) if p.version != "" { p.printOption(w, &spec{ boolean: true, long: "version", help: "display version and exit", - root: -1, }) } } @@ -143,8 +141,8 @@ func (p *Parser) printOption(w io.Writer, spec *spec) { } // If spec.dest is not the zero value then a default value has been added. var v reflect.Value - if spec.root >= 0 { - v = p.readable(spec) + if len(spec.dest.fields) > 0 { + v = p.readable(spec.dest) } if v.IsValid() { z := reflect.Zero(v.Type())