diff --git a/v2/construct.go b/v2/construct.go index bed64eb..bd2800e 100644 --- a/v2/construct.go +++ b/v2/construct.go @@ -176,7 +176,7 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*Command, error) { return false } - // duplicate the entire path to avoid slice overwrites + // create a new destination path for this field subdest := dest.Child(field) arg := Argument{ dest: subdest, diff --git a/v2/parse.go b/v2/parse.go index 1a29e35..251ddb6 100644 --- a/v2/parse.go +++ b/v2/parse.go @@ -93,6 +93,250 @@ func (p *Parser) Parse(args, env []string) error { return p.Validate() } +// ProcessCommandLine scans arguments one-by-one, parses them and assigns +// the result to fields of the struct passed to NewParser. It returns +// an error if an argument is invalid or unknown, but not if a +// required argument is missing. To check that all required arguments +// are set, call Validate(). This function ignores the first element +// of args, which is assumed to be the program name itself. This function +// never overwrites arguments previously seen in a call to any Process* +// function. +func (p *Parser) ProcessCommandLine(args []string) error { + positionals, err := p.ProcessOptions(args) + if err != nil { + return err + } + return p.ProcessPositionals(positionals) +} + +// OverwriteWithCommandLine is like ProcessCommandLine but it overwrites +// any previously seen values. +func (p *Parser) OverwriteWithCommandLine(args []string) error { + positionals, err := p.OverwriteWithOptions(args) + if err != nil { + return err + } + return p.OverwriteWithPositionals(positionals) +} + +// ProcessOptions processes options but not positionals from the +// command line. Positionals are returned and can be passed to +// ProcessPositionals. This function ignores the first element of args, +// which is assumed to be the program name itself. Arguments seen +// in a previous call to any Process* or OverwriteWith* functions +// are ignored. +func (p *Parser) ProcessOptions(args []string) ([]string, error) { + return p.processOptions(args, false) +} + +// OverwriteWithOptions is like ProcessOptions except previously seen +// arguments are overwritten +func (p *Parser) OverwriteWithOptions(args []string) ([]string, error) { + return p.processOptions(args, true) +} + +func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) { + // union of args for the chain of subcommands encountered so far + p.leaf = p.cmd + + // we will add to this list each time we expand a subcommand + p.accumulatedArgs = make([]*Argument, len(p.leaf.args)) + copy(p.accumulatedArgs, p.leaf.args) + + // process each string from the command line + var allpositional bool + var positionals []string + + // must use explicit for loop, not range, because we manipulate i inside the loop + for i := 1; i < len(args); i++ { + token := args[i] + + // the "--" token indicates that all further tokens should be treated as positionals + if token == "--" { + allpositional = true + continue + } + + // check whether this is a positional argument + if !isFlag(token) || allpositional { + // each subcommand can have either subcommands or positionals, but not both + if len(p.leaf.subcommands) == 0 { + positionals = append(positionals, token) + continue + } + + // if we have a subcommand then make sure it is valid for the current context + subcmd := findSubcommand(p.leaf.subcommands, token) + if subcmd == nil { + return nil, fmt.Errorf("invalid subcommand: %s", token) + } + + // instantiate the field to point to a new struct + v := p.val(subcmd.dest) + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers + } + + // add the new options to the set of allowed options + p.accumulatedArgs = append(p.accumulatedArgs, subcmd.args...) + p.leaf = subcmd + continue + } + + // check for special --help and --version flags + switch token { + case "-h", "--help": + return nil, ErrHelp + case "--version": + return nil, ErrVersion + } + + // check for an equals sign, as in "--foo=bar" + var value string + opt := strings.TrimLeft(token, "-") + if pos := strings.Index(opt, "="); pos != -1 { + value = opt[pos+1:] + opt = opt[:pos] + } + + // look up the arg for this option (note that the "args" slice changes as + // we expand subcommands so it is better not to use a map) + arg := findOption(p.accumulatedArgs, opt) + if arg == nil { + return nil, fmt.Errorf("unknown argument %s", token) + } + + // deal with the case of multiple values + if arg.cardinality == multiple { + // if arg.separate is true then just parse one value and append it + if arg.separate { + if value == "" { + if i+1 == len(args) { + return nil, fmt.Errorf("missing value for %s", token) + } + if isFlag(args[i+1]) { + return nil, fmt.Errorf("missing value for %s", token) + } + value = args[i+1] + i++ + } + + err := appendToSliceOrMap(p.val(arg.dest), value) + if err != nil { + return nil, fmt.Errorf("error processing %s: %v", token, err) + } + p.seen[arg] = true + continue + } + + // if args.separate is not true then consume tokens until next --option + var values []string + if value == "" { + for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { + values = append(values, args[i+1]) + i++ + } + } else { + values = append(values, value) + } + + // this is the first time we can check p.seen because we need to correctly + // increment i above, even when we then ignore the value + if p.seen[arg] && !overwrite { + continue + } + + // store the values into the slice or map + err := setSliceOrMap(p.val(arg.dest), values, !arg.separate) + if err != nil { + return nil, fmt.Errorf("error processing %s: %v", token, err) + } + continue + } + + // if it's a flag and it has no value then set the value to true + // use boolean because this takes account of TextUnmarshaler + if arg.cardinality == zero && value == "" { + value = "true" + } + + // if we have something like "--foo" then the value is the next argument + if value == "" { + if i+1 == len(args) { + return nil, fmt.Errorf("missing value for %s", token) + } + if isFlag(args[i+1]) { + return nil, fmt.Errorf("missing value for %s", token) + } + value = args[i+1] + i++ + } + + // this is the first time we can check p.seen because we need to correctly + // increment i above, even when we then ignore the value + if p.seen[arg] && !overwrite { + continue + } + + err := scalar.ParseValue(p.val(arg.dest), value) + if err != nil { + return nil, fmt.Errorf("error processing %s: %v", token, err) + } + p.seen[arg] = true + } + + return positionals, nil +} + +// ProcessPositionals processes a list of positional arguments. If +// this list contains tokens that begin with a hyphen they will still be +// treated as positional arguments. Arguments seen in a previous call +// to any Process* or OverwriteWith* functions are ignored. +func (p *Parser) ProcessPositionals(positionals []string) error { + return p.processPositionals(positionals, false) +} + +// OverwriteWithPositionals is like ProcessPositionals except previously +// seen arguments are overwritten. +func (p *Parser) OverwriteWithPositionals(positionals []string) error { + return p.processPositionals(positionals, true) +} + +func (p *Parser) processPositionals(positionals []string, overwrite bool) error { + for _, arg := range p.accumulatedArgs { + if !arg.positional { + continue + } + if len(positionals) == 0 { + break + } + if arg.cardinality == multiple { + if !p.seen[arg] || overwrite { + err := setSliceOrMap(p.val(arg.dest), positionals, true) + if err != nil { + return fmt.Errorf("error processing %s: %v", arg.field.Name, err) + } + } + positionals = nil + } else { + if !p.seen[arg] || overwrite { + err := scalar.ParseValue(p.val(arg.dest), positionals[0]) + if err != nil { + return fmt.Errorf("error processing %s: %v", arg.field.Name, err) + } + } + positionals = positionals[1:] + } + p.seen[arg] = true + } + + if len(positionals) > 0 { + return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) + } + + return nil +} + // ProcessEnvironment processes environment variables from a list of strings // of the form KEY=VALUE. You can pass in os.Environ(). It // does not overwrite any fields with values already populated. @@ -167,180 +411,6 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error { return nil } -// ProcessCommandLine goes through arguments one-by-one, parses them, -// and assigns the result to the underlying struct field. It returns -// an error if an argument is invalid or an option is unknown, not if a -// required argument is missing. To check that all required arguments -// are set, call CheckRequired(). This function ignores the first element -// of args, which is assumed to be the program name itself. -func (p *Parser) ProcessCommandLine(args []string) error { - positionals, err := p.ProcessOptions(args) - if err != nil { - return err - } - return p.ProcessPositionals(positionals) -} - -// ProcessOptions process command line arguments but does not process -// positional arguments. Instead, it returns positionals. These can then -// be passed to ProcessPositionals. This function ignores the first element -// of args, which is assumed to be the program name itself. -func (p *Parser) ProcessOptions(args []string) ([]string, error) { - // union of args for the chain of subcommands encountered so far - curCmd := p.cmd - p.leaf = curCmd - - // we will add to this list each time we expand a subcommand - p.accumulatedArgs = make([]*Argument, len(curCmd.args)) - copy(p.accumulatedArgs, curCmd.args) - - // process each string from the command line - var allpositional bool - var positionals []string - - // must use explicit for loop, not range, because we manipulate i inside the loop - for i := 1; i < len(args); i++ { - token := args[i] - if token == "--" { - allpositional = true - continue - } - - if !isFlag(token) || allpositional { - // each subcommand can have either subcommands or positionals, but not both - if len(curCmd.subcommands) == 0 { - positionals = append(positionals, token) - continue - } - - // if we have a subcommand then make sure it is valid for the current context - subcmd := findSubcommand(curCmd.subcommands, token) - if subcmd == nil { - return nil, fmt.Errorf("invalid subcommand: %s", token) - } - - // instantiate the field to point to a new struct - v := p.val(subcmd.dest) - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers - } - - // add the new options to the set of allowed options - p.accumulatedArgs = append(p.accumulatedArgs, subcmd.args...) - - curCmd = subcmd - p.leaf = curCmd - continue - } - - // check for special --help and --version flags - switch token { - case "-h", "--help": - return nil, ErrHelp - case "--version": - return nil, ErrVersion - } - - // check for an equals sign, as in "--foo=bar" - var value string - opt := strings.TrimLeft(token, "-") - if pos := strings.Index(opt, "="); pos != -1 { - value = opt[pos+1:] - opt = opt[:pos] - } - - // look up the arg for this option (note that the "args" slice changes as - // we expand subcommands so it is better not to use a map) - arg := findOption(p.accumulatedArgs, opt) - if arg == nil { - return nil, fmt.Errorf("unknown argument %s", token) - } - p.seen[arg] = true - - // deal with the case of multiple values - if arg.cardinality == multiple { - var values []string - if value == "" { - for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { - values = append(values, args[i+1]) - i++ - if arg.separate { - break - } - } - } else { - values = append(values, value) - } - err := setSliceOrMap(p.val(arg.dest), values, !arg.separate) - if err != nil { - return nil, fmt.Errorf("error processing %s: %v", token, err) - } - continue - } - - // if it's a flag and it has no value then set the value to true - // use boolean because this takes account of TextUnmarshaler - if arg.cardinality == zero && value == "" { - value = "true" - } - - // if we have something like "--foo" then the value is the next argument - if value == "" { - if i+1 == len(args) { - return nil, fmt.Errorf("missing value for %s", token) - } - if isFlag(args[i+1]) { - return nil, fmt.Errorf("missing value for %s", token) - } - value = args[i+1] - i++ - } - - p.seen[arg] = true - err := scalar.ParseValue(p.val(arg.dest), value) - if err != nil { - return nil, fmt.Errorf("error processing %s: %v", token, err) - } - } - - return positionals, nil -} - -// ProcessPositionals processes a list of positional arguments. It is assumed -// that options such as --abc and --abc=123 have already been removed. If -// this list contains tokens that begin with a hyphen they will still be -// treated as positional arguments. -func (p *Parser) ProcessPositionals(positionals []string) error { - for _, arg := range p.accumulatedArgs { - if !arg.positional { - continue - } - if len(positionals) == 0 { - break - } - p.seen[arg] = true - if arg.cardinality == multiple { - err := setSliceOrMap(p.val(arg.dest), positionals, true) - if err != nil { - return fmt.Errorf("error processing %s: %v", arg.field.Name, err) - } - positionals = nil - } else { - err := scalar.ParseValue(p.val(arg.dest), positionals[0]) - if err != nil { - return fmt.Errorf("error processing %s: %v", arg.field.Name, err) - } - positionals = positionals[1:] - } - } - - if len(positionals) > 0 { - return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) - } - - return nil -} - // ProcessDefaults assigns default values to all fields that have default values and // are not already populated. func (p *Parser) ProcessDefaults() error { diff --git a/v2/parse_test.go b/v2/parse_test.go index ee46006..72efa79 100644 --- a/v2/parse_test.go +++ b/v2/parse_test.go @@ -38,6 +38,7 @@ func TestString(t *testing.T) { _, err := parse(&args, "--foo bar --ptr baz") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) + require.NotNil(t, args.Ptr) assert.Equal(t, "baz", *args.Ptr) } diff --git a/v2/sequence.go b/v2/sequence.go index 35a3614..f0fff46 100644 --- a/v2/sequence.go +++ b/v2/sequence.go @@ -27,7 +27,7 @@ func setSliceOrMap(dest reflect.Value, values []string, clear bool) error { case reflect.Map: return setMap(dest, values, clear) default: - return fmt.Errorf("setSliceOrMap cannot insert values into a %v", t) + return fmt.Errorf("cannot insert multiple values into a %v", t) } } @@ -121,3 +121,98 @@ func setMap(dest reflect.Value, values []string, clear bool) error { } return nil } + +// appendSliceOrMap parses a string and appends it to an existing slice or map. +func appendToSliceOrMap(dest reflect.Value, value string) error { + if !dest.CanSet() { + return fmt.Errorf("field is not writable") + } + + t := dest.Type() + if t.Kind() == reflect.Ptr { + dest = dest.Elem() + t = t.Elem() + } + + switch t.Kind() { + case reflect.Slice: + return appendToSlice(dest, value) + case reflect.Map: + return appendToMap(dest, value) + default: + return fmt.Errorf("cannot insert multiple values into a %v", t) + } +} + +// appendSlice parses a string and appends the result into a slice. +func appendToSlice(dest reflect.Value, s string) error { + var ptr bool + elem := dest.Type().Elem() + if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { + ptr = true + elem = elem.Elem() + } + + // parse the value and append + v := reflect.New(elem) + if err := scalar.ParseValue(v.Elem(), s); err != nil { + return err + } + if !ptr { + v = v.Elem() + } + dest.Set(reflect.Append(dest, v)) + return nil +} + +// appendToMap parses a name=value string and inserts it into a map. +// If clear is true then any values already in the map are removed. +func appendToMap(dest reflect.Value, s string) error { + // determine the key and value type + var keyIsPtr bool + keyType := dest.Type().Key() + if keyType.Kind() == reflect.Ptr && !keyType.Implements(textUnmarshalerType) { + keyIsPtr = true + keyType = keyType.Elem() + } + + var valIsPtr bool + valType := dest.Type().Elem() + if valType.Kind() == reflect.Ptr && !valType.Implements(textUnmarshalerType) { + valIsPtr = true + valType = valType.Elem() + } + + // allocate the map if it is not allocated + if dest.IsNil() { + dest.Set(reflect.MakeMap(dest.Type())) + } + + // split at the first equals sign + pos := strings.Index(s, "=") + if pos == -1 { + return fmt.Errorf("cannot parse %q into a map, expected format key=value", s) + } + + // parse the key + k := reflect.New(keyType) + if err := scalar.ParseValue(k.Elem(), s[:pos]); err != nil { + return err + } + if !keyIsPtr { + k = k.Elem() + } + + // parse the value + v := reflect.New(valType) + if err := scalar.ParseValue(v.Elem(), s[pos+1:]); err != nil { + return err + } + if !valIsPtr { + v = v.Elem() + } + + // add it to the map + dest.SetMapIndex(k, v) + return nil +} diff --git a/v2/sequence_test.go b/v2/sequence_test.go index fde3e3a..6383949 100644 --- a/v2/sequence_test.go +++ b/v2/sequence_test.go @@ -150,3 +150,7 @@ func TestSetSliceOrMapErrors(t *testing.T) { err = setSliceOrMap(dest, nil, false) assert.Error(t, err) } + +// check that we can accumulate "separate" args across env, cmdline, map, and defaults + +// check what happens if we have a required arg with a default value