add OverwriteWithOptions, OverwriteWithCommandLine

This commit is contained in:
Alex Flint 2022-10-04 12:34:53 -07:00
parent 5f0c48f092
commit 2775f58376
5 changed files with 346 additions and 176 deletions

View File

@ -176,7 +176,7 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*Command, error) {
return false
}
// duplicate the entire path to avoid slice overwrites
// create a new destination path for this field
subdest := dest.Child(field)
arg := Argument{
dest: subdest,

View File

@ -93,6 +93,250 @@ func (p *Parser) Parse(args, env []string) error {
return p.Validate()
}
// ProcessCommandLine scans arguments one-by-one, parses them and assigns
// the result to fields of the struct passed to NewParser. It returns
// an error if an argument is invalid or unknown, but not if a
// required argument is missing. To check that all required arguments
// are set, call Validate(). This function ignores the first element
// of args, which is assumed to be the program name itself. This function
// never overwrites arguments previously seen in a call to any Process*
// function.
func (p *Parser) ProcessCommandLine(args []string) error {
positionals, err := p.ProcessOptions(args)
if err != nil {
return err
}
return p.ProcessPositionals(positionals)
}
// OverwriteWithCommandLine is like ProcessCommandLine but it overwrites
// any previously seen values.
func (p *Parser) OverwriteWithCommandLine(args []string) error {
positionals, err := p.OverwriteWithOptions(args)
if err != nil {
return err
}
return p.OverwriteWithPositionals(positionals)
}
// ProcessOptions processes options but not positionals from the
// command line. Positionals are returned and can be passed to
// ProcessPositionals. This function ignores the first element of args,
// which is assumed to be the program name itself. Arguments seen
// in a previous call to any Process* or OverwriteWith* functions
// are ignored.
func (p *Parser) ProcessOptions(args []string) ([]string, error) {
return p.processOptions(args, false)
}
// OverwriteWithOptions is like ProcessOptions except previously seen
// arguments are overwritten
func (p *Parser) OverwriteWithOptions(args []string) ([]string, error) {
return p.processOptions(args, true)
}
func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) {
// union of args for the chain of subcommands encountered so far
p.leaf = p.cmd
// we will add to this list each time we expand a subcommand
p.accumulatedArgs = make([]*Argument, len(p.leaf.args))
copy(p.accumulatedArgs, p.leaf.args)
// process each string from the command line
var allpositional bool
var positionals []string
// must use explicit for loop, not range, because we manipulate i inside the loop
for i := 1; i < len(args); i++ {
token := args[i]
// the "--" token indicates that all further tokens should be treated as positionals
if token == "--" {
allpositional = true
continue
}
// check whether this is a positional argument
if !isFlag(token) || allpositional {
// each subcommand can have either subcommands or positionals, but not both
if len(p.leaf.subcommands) == 0 {
positionals = append(positionals, token)
continue
}
// if we have a subcommand then make sure it is valid for the current context
subcmd := findSubcommand(p.leaf.subcommands, token)
if subcmd == nil {
return nil, fmt.Errorf("invalid subcommand: %s", token)
}
// instantiate the field to point to a new struct
v := p.val(subcmd.dest)
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers
}
// add the new options to the set of allowed options
p.accumulatedArgs = append(p.accumulatedArgs, subcmd.args...)
p.leaf = subcmd
continue
}
// check for special --help and --version flags
switch token {
case "-h", "--help":
return nil, ErrHelp
case "--version":
return nil, ErrVersion
}
// check for an equals sign, as in "--foo=bar"
var value string
opt := strings.TrimLeft(token, "-")
if pos := strings.Index(opt, "="); pos != -1 {
value = opt[pos+1:]
opt = opt[:pos]
}
// look up the arg for this option (note that the "args" slice changes as
// we expand subcommands so it is better not to use a map)
arg := findOption(p.accumulatedArgs, opt)
if arg == nil {
return nil, fmt.Errorf("unknown argument %s", token)
}
// deal with the case of multiple values
if arg.cardinality == multiple {
// if arg.separate is true then just parse one value and append it
if arg.separate {
if value == "" {
if i+1 == len(args) {
return nil, fmt.Errorf("missing value for %s", token)
}
if isFlag(args[i+1]) {
return nil, fmt.Errorf("missing value for %s", token)
}
value = args[i+1]
i++
}
err := appendToSliceOrMap(p.val(arg.dest), value)
if err != nil {
return nil, fmt.Errorf("error processing %s: %v", token, err)
}
p.seen[arg] = true
continue
}
// if args.separate is not true then consume tokens until next --option
var values []string
if value == "" {
for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" {
values = append(values, args[i+1])
i++
}
} else {
values = append(values, value)
}
// this is the first time we can check p.seen because we need to correctly
// increment i above, even when we then ignore the value
if p.seen[arg] && !overwrite {
continue
}
// store the values into the slice or map
err := setSliceOrMap(p.val(arg.dest), values, !arg.separate)
if err != nil {
return nil, fmt.Errorf("error processing %s: %v", token, err)
}
continue
}
// if it's a flag and it has no value then set the value to true
// use boolean because this takes account of TextUnmarshaler
if arg.cardinality == zero && value == "" {
value = "true"
}
// if we have something like "--foo" then the value is the next argument
if value == "" {
if i+1 == len(args) {
return nil, fmt.Errorf("missing value for %s", token)
}
if isFlag(args[i+1]) {
return nil, fmt.Errorf("missing value for %s", token)
}
value = args[i+1]
i++
}
// this is the first time we can check p.seen because we need to correctly
// increment i above, even when we then ignore the value
if p.seen[arg] && !overwrite {
continue
}
err := scalar.ParseValue(p.val(arg.dest), value)
if err != nil {
return nil, fmt.Errorf("error processing %s: %v", token, err)
}
p.seen[arg] = true
}
return positionals, nil
}
// ProcessPositionals processes a list of positional arguments. If
// this list contains tokens that begin with a hyphen they will still be
// treated as positional arguments. Arguments seen in a previous call
// to any Process* or OverwriteWith* functions are ignored.
func (p *Parser) ProcessPositionals(positionals []string) error {
return p.processPositionals(positionals, false)
}
// OverwriteWithPositionals is like ProcessPositionals except previously
// seen arguments are overwritten.
func (p *Parser) OverwriteWithPositionals(positionals []string) error {
return p.processPositionals(positionals, true)
}
func (p *Parser) processPositionals(positionals []string, overwrite bool) error {
for _, arg := range p.accumulatedArgs {
if !arg.positional {
continue
}
if len(positionals) == 0 {
break
}
if arg.cardinality == multiple {
if !p.seen[arg] || overwrite {
err := setSliceOrMap(p.val(arg.dest), positionals, true)
if err != nil {
return fmt.Errorf("error processing %s: %v", arg.field.Name, err)
}
}
positionals = nil
} else {
if !p.seen[arg] || overwrite {
err := scalar.ParseValue(p.val(arg.dest), positionals[0])
if err != nil {
return fmt.Errorf("error processing %s: %v", arg.field.Name, err)
}
}
positionals = positionals[1:]
}
p.seen[arg] = true
}
if len(positionals) > 0 {
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
}
return nil
}
// ProcessEnvironment processes environment variables from a list of strings
// of the form KEY=VALUE. You can pass in os.Environ(). It
// does not overwrite any fields with values already populated.
@ -167,180 +411,6 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error {
return nil
}
// ProcessCommandLine goes through arguments one-by-one, parses them,
// and assigns the result to the underlying struct field. It returns
// an error if an argument is invalid or an option is unknown, not if a
// required argument is missing. To check that all required arguments
// are set, call CheckRequired(). This function ignores the first element
// of args, which is assumed to be the program name itself.
func (p *Parser) ProcessCommandLine(args []string) error {
positionals, err := p.ProcessOptions(args)
if err != nil {
return err
}
return p.ProcessPositionals(positionals)
}
// ProcessOptions process command line arguments but does not process
// positional arguments. Instead, it returns positionals. These can then
// be passed to ProcessPositionals. This function ignores the first element
// of args, which is assumed to be the program name itself.
func (p *Parser) ProcessOptions(args []string) ([]string, error) {
// union of args for the chain of subcommands encountered so far
curCmd := p.cmd
p.leaf = curCmd
// we will add to this list each time we expand a subcommand
p.accumulatedArgs = make([]*Argument, len(curCmd.args))
copy(p.accumulatedArgs, curCmd.args)
// process each string from the command line
var allpositional bool
var positionals []string
// must use explicit for loop, not range, because we manipulate i inside the loop
for i := 1; i < len(args); i++ {
token := args[i]
if token == "--" {
allpositional = true
continue
}
if !isFlag(token) || allpositional {
// each subcommand can have either subcommands or positionals, but not both
if len(curCmd.subcommands) == 0 {
positionals = append(positionals, token)
continue
}
// if we have a subcommand then make sure it is valid for the current context
subcmd := findSubcommand(curCmd.subcommands, token)
if subcmd == nil {
return nil, fmt.Errorf("invalid subcommand: %s", token)
}
// instantiate the field to point to a new struct
v := p.val(subcmd.dest)
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers
}
// add the new options to the set of allowed options
p.accumulatedArgs = append(p.accumulatedArgs, subcmd.args...)
curCmd = subcmd
p.leaf = curCmd
continue
}
// check for special --help and --version flags
switch token {
case "-h", "--help":
return nil, ErrHelp
case "--version":
return nil, ErrVersion
}
// check for an equals sign, as in "--foo=bar"
var value string
opt := strings.TrimLeft(token, "-")
if pos := strings.Index(opt, "="); pos != -1 {
value = opt[pos+1:]
opt = opt[:pos]
}
// look up the arg for this option (note that the "args" slice changes as
// we expand subcommands so it is better not to use a map)
arg := findOption(p.accumulatedArgs, opt)
if arg == nil {
return nil, fmt.Errorf("unknown argument %s", token)
}
p.seen[arg] = true
// deal with the case of multiple values
if arg.cardinality == multiple {
var values []string
if value == "" {
for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" {
values = append(values, args[i+1])
i++
if arg.separate {
break
}
}
} else {
values = append(values, value)
}
err := setSliceOrMap(p.val(arg.dest), values, !arg.separate)
if err != nil {
return nil, fmt.Errorf("error processing %s: %v", token, err)
}
continue
}
// if it's a flag and it has no value then set the value to true
// use boolean because this takes account of TextUnmarshaler
if arg.cardinality == zero && value == "" {
value = "true"
}
// if we have something like "--foo" then the value is the next argument
if value == "" {
if i+1 == len(args) {
return nil, fmt.Errorf("missing value for %s", token)
}
if isFlag(args[i+1]) {
return nil, fmt.Errorf("missing value for %s", token)
}
value = args[i+1]
i++
}
p.seen[arg] = true
err := scalar.ParseValue(p.val(arg.dest), value)
if err != nil {
return nil, fmt.Errorf("error processing %s: %v", token, err)
}
}
return positionals, nil
}
// ProcessPositionals processes a list of positional arguments. It is assumed
// that options such as --abc and --abc=123 have already been removed. If
// this list contains tokens that begin with a hyphen they will still be
// treated as positional arguments.
func (p *Parser) ProcessPositionals(positionals []string) error {
for _, arg := range p.accumulatedArgs {
if !arg.positional {
continue
}
if len(positionals) == 0 {
break
}
p.seen[arg] = true
if arg.cardinality == multiple {
err := setSliceOrMap(p.val(arg.dest), positionals, true)
if err != nil {
return fmt.Errorf("error processing %s: %v", arg.field.Name, err)
}
positionals = nil
} else {
err := scalar.ParseValue(p.val(arg.dest), positionals[0])
if err != nil {
return fmt.Errorf("error processing %s: %v", arg.field.Name, err)
}
positionals = positionals[1:]
}
}
if len(positionals) > 0 {
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
}
return nil
}
// ProcessDefaults assigns default values to all fields that have default values and
// are not already populated.
func (p *Parser) ProcessDefaults() error {

View File

@ -38,6 +38,7 @@ func TestString(t *testing.T) {
_, err := parse(&args, "--foo bar --ptr baz")
require.NoError(t, err)
assert.Equal(t, "bar", args.Foo)
require.NotNil(t, args.Ptr)
assert.Equal(t, "baz", *args.Ptr)
}

View File

@ -27,7 +27,7 @@ func setSliceOrMap(dest reflect.Value, values []string, clear bool) error {
case reflect.Map:
return setMap(dest, values, clear)
default:
return fmt.Errorf("setSliceOrMap cannot insert values into a %v", t)
return fmt.Errorf("cannot insert multiple values into a %v", t)
}
}
@ -121,3 +121,98 @@ func setMap(dest reflect.Value, values []string, clear bool) error {
}
return nil
}
// appendSliceOrMap parses a string and appends it to an existing slice or map.
func appendToSliceOrMap(dest reflect.Value, value string) error {
if !dest.CanSet() {
return fmt.Errorf("field is not writable")
}
t := dest.Type()
if t.Kind() == reflect.Ptr {
dest = dest.Elem()
t = t.Elem()
}
switch t.Kind() {
case reflect.Slice:
return appendToSlice(dest, value)
case reflect.Map:
return appendToMap(dest, value)
default:
return fmt.Errorf("cannot insert multiple values into a %v", t)
}
}
// appendSlice parses a string and appends the result into a slice.
func appendToSlice(dest reflect.Value, s string) error {
var ptr bool
elem := dest.Type().Elem()
if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
ptr = true
elem = elem.Elem()
}
// parse the value and append
v := reflect.New(elem)
if err := scalar.ParseValue(v.Elem(), s); err != nil {
return err
}
if !ptr {
v = v.Elem()
}
dest.Set(reflect.Append(dest, v))
return nil
}
// appendToMap parses a name=value string and inserts it into a map.
// If clear is true then any values already in the map are removed.
func appendToMap(dest reflect.Value, s string) error {
// determine the key and value type
var keyIsPtr bool
keyType := dest.Type().Key()
if keyType.Kind() == reflect.Ptr && !keyType.Implements(textUnmarshalerType) {
keyIsPtr = true
keyType = keyType.Elem()
}
var valIsPtr bool
valType := dest.Type().Elem()
if valType.Kind() == reflect.Ptr && !valType.Implements(textUnmarshalerType) {
valIsPtr = true
valType = valType.Elem()
}
// allocate the map if it is not allocated
if dest.IsNil() {
dest.Set(reflect.MakeMap(dest.Type()))
}
// split at the first equals sign
pos := strings.Index(s, "=")
if pos == -1 {
return fmt.Errorf("cannot parse %q into a map, expected format key=value", s)
}
// parse the key
k := reflect.New(keyType)
if err := scalar.ParseValue(k.Elem(), s[:pos]); err != nil {
return err
}
if !keyIsPtr {
k = k.Elem()
}
// parse the value
v := reflect.New(valType)
if err := scalar.ParseValue(v.Elem(), s[pos+1:]); err != nil {
return err
}
if !valIsPtr {
v = v.Elem()
}
// add it to the map
dest.SetMapIndex(k, v)
return nil
}

View File

@ -150,3 +150,7 @@ func TestSetSliceOrMapErrors(t *testing.T) {
err = setSliceOrMap(dest, nil, false)
assert.Error(t, err)
}
// check that we can accumulate "separate" args across env, cmdline, map, and defaults
// check what happens if we have a required arg with a default value