From 4e977796af5ef0863a674ef468c5036dcca20623 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 30 Apr 2019 12:54:28 -0700 Subject: [PATCH] add recursive expansion of subcommands --- parse.go | 130 ++++++++++++++++++++++++++++++++++++++------------ parse_test.go | 7 ++- 2 files changed, 103 insertions(+), 34 deletions(-) diff --git a/parse.go b/parse.go index b5b76b8..353b365 100644 --- a/parse.go +++ b/parse.go @@ -152,7 +152,6 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { if t.Kind() != reflect.Ptr { panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t)) } - t = t.Elem() cmd, err := cmdFromStruct(name, t, nil, i) if err != nil { @@ -172,8 +171,16 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { } func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*command, error) { + // commands can only be created from pointers to structs + if t.Kind() != reflect.Ptr { + return nil, fmt.Errorf("subcommands must be pointers to structs but args.%s is a %s", + strings.Join(path, "."), t.Kind()) + } + + t = t.Elem() if t.Kind() != reflect.Struct { - panic(fmt.Sprintf("%v is not a struct pointer", t)) + return nil, fmt.Errorf("subcommands must be pointers to structs but args.%s is a pointer to %s", + strings.Join(path, "."), t.Kind()) } var cmd command @@ -190,9 +197,13 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma return true } + // duplicate the entire path to avoid slice overwrites + subpath := make([]string, len(path)+1) + copy(subpath, append(path, field.Name)) + spec := spec{ root: root, - path: append(path, field.Name), + path: subpath, long: strings.ToLower(field.Name), typ: field.Type, } @@ -258,7 +269,7 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma cmdname = strings.ToLower(field.Name) } - subcmd, err := cmdFromStruct(cmdname, field.Type, append(path, field.Name), root) + subcmd, err := cmdFromStruct(cmdname, field.Type, subpath, root) if err != nil { errs = append(errs, err.Error()) return false @@ -281,6 +292,17 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma return nil, errors.New(strings.Join(errs, "\n")) } + // check that we don't have both positionals and subcommands + var hasPositional bool + for _, spec := range cmd.specs { + if spec.positional { + hasPositional = true + } + } + if hasPositional && len(cmd.subcommands) > 0 { + return nil, fmt.Errorf("%T cannot have both subcommands and positional arguments", t) + } + return &cmd, nil } @@ -301,30 +323,11 @@ func (p *Parser) Parse(args []string) error { } // Process all command line arguments - return p.process(p.cmd.specs, args) + return p.process(args) } -// process goes through arguments one-by-one, parses them, and assigns the result to -// the underlying struct field -func (p *Parser) process(specs []*spec, args []string) error { - // track the options we have seen - wasPresent := make(map[*spec]bool) - - // construct a map from --option to spec - optionMap := make(map[string]*spec) - for _, spec := range specs { - if spec.positional { - continue - } - if spec.long != "" { - optionMap[spec.long] = spec - } - if spec.short != "" { - optionMap[spec.short] = spec - } - } - - // deal with environment vars +// process environment vars for the given arguments +func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error { for _, spec := range specs { if spec.env == "" { continue @@ -361,6 +364,28 @@ func (p *Parser) process(specs []*spec, args []string) error { wasPresent[spec] = true } + return nil +} + +// process goes through arguments one-by-one, parses them, and assigns the result to +// the underlying struct field +func (p *Parser) process(args []string) error { + // track the options we have seen + wasPresent := make(map[*spec]bool) + + // union of specs for the chain of subcommands encountered so far + curCmd := p.cmd + + // make a copy of the specs because we will add to this list each time we expand a subcommand + specs := make([]*spec, len(curCmd.specs)) + copy(specs, curCmd.specs) + + // deal with environment vars + err := p.captureEnvVars(specs, wasPresent) + if err != nil { + return err + } + // process each string from the command line var allpositional bool var positionals []string @@ -374,7 +399,28 @@ func (p *Parser) process(specs []*spec, args []string) error { } if !isFlag(arg) || allpositional { - positionals = append(positionals, arg) + // each subcommand can have either subcommands or positionals, but not both + if len(curCmd.subcommands) == 0 { + positionals = append(positionals, arg) + continue + } + + // if we have a subcommand then make sure it is valid for the current context + subcmd := findSubcommand(curCmd.subcommands, arg) + if subcmd == nil { + return fmt.Errorf("invalid subcommand: %s", arg) + } + + // add the new options to the set of allowed options + specs = append(specs, subcmd.specs...) + + // capture environment vars for these new options + err := p.captureEnvVars(subcmd.specs, wasPresent) + if err != nil { + return err + } + + curCmd = subcmd continue } @@ -386,9 +432,10 @@ func (p *Parser) process(specs []*spec, args []string) error { opt = opt[:pos] } - // lookup the spec for this option - spec, ok := optionMap[opt] - if !ok { + // lookup the spec for this option (note that the "specs" slice changes as + // we expand subcommands so it is better not to use a map) + spec := findOption(specs, opt) + if spec == nil { return fmt.Errorf("unknown argument %s", arg) } wasPresent[spec] = true @@ -630,3 +677,26 @@ func isBoolean(t reflect.Type) bool { return false } } + +// findOption finds an option from its name, or returns null if no spec is found +func findOption(specs []*spec, name string) *spec { + for _, spec := range specs { + if spec.positional { + continue + } + if spec.long == name || spec.short == name { + return spec + } + } + return nil +} + +// findSubcommand finds a subcommand using its name, or returns null if no subcommand is found +func findSubcommand(cmds []*command, name string) *command { + for _, cmd := range cmds { + if cmd.name == name { + return cmd + } + } + return nil +} diff --git a/parse_test.go b/parse_test.go index 9aad2e3..94cf21a 100644 --- a/parse_test.go +++ b/parse_test.go @@ -462,11 +462,10 @@ func TestPanicOnNonPointer(t *testing.T) { }) } -func TestPanicOnNonStruct(t *testing.T) { +func TestErrorOnNonStruct(t *testing.T) { var args string - assert.Panics(t, func() { - _ = parse("", &args) - }) + err := parse("", &args) + assert.Error(t, err) } func TestUnsupportedType(t *testing.T) {