From 09d28e1195519df88f2606137f227aac6186ed09 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 11:00:42 -0700 Subject: [PATCH] split the parsing logic into ProcessEnvironment, ProcessCommandLine, ProcessOptions, ProcessPositions, ProcessDefaults --- go.sum | 2 - v2/parse.go | 470 ++++++++++++++++++++++++------------------ v2/parse_test.go | 173 +++++++--------- v2/subcommand.go | 8 +- v2/subcommand_test.go | 6 +- v2/usage.go | 130 ++++++------ v2/usage_test.go | 10 +- 7 files changed, 420 insertions(+), 379 deletions(-) diff --git a/go.sum b/go.sum index 5b536f9..385ca8f 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/alexflint/go-scalar v1.1.0 h1:aaAouLLzI9TChcPXotr6gUhq+Scr8rl0P9P4PnltbhM= -github.com/alexflint/go-scalar v1.1.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= github.com/alexflint/go-scalar v1.2.0 h1:WR7JPKkeNpnYIOfHRa7ivM21aWAdHD0gEWHCx+WQBRw= github.com/alexflint/go-scalar v1.2.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/v2/parse.go b/v2/parse.go index 8e190f2..ce02bd4 100644 --- a/v2/parse.go +++ b/v2/parse.go @@ -39,8 +39,8 @@ func (p path) Child(f reflect.StructField) path { } } -// spec represents a command line option -type spec struct { +// Arg represents a command line argument +type Argument struct { dest path field reflect.StructField // the struct field from which this option was created long string // the --long form for this option, or empty if none @@ -55,14 +55,14 @@ type spec struct { placeholder string // name of the data in help } -// command represents a named subcommand, or the top-level command -type command struct { +// Command represents a named subcommand, or the top-level command +type Command struct { name string help string dest path - specs []*spec - subcommands []*command - parent *command + args []*Argument + subcommands []*Command + parent *Command } // ErrHelp indicates that -h or --help were provided @@ -80,16 +80,16 @@ func MustParse(dest interface{}) *Parser { return nil // just in case osExit was monkey-patched } - err = p.Parse(flags()) + err = p.Parse(os.Args, os.Environ()) switch { case err == ErrHelp: - p.writeHelpForSubcommand(stdout, p.lastCmd) + p.writeHelpForSubcommand(stdout, p.leaf) osExit(0) case err == ErrVersion: fmt.Fprintln(stdout, p.version) osExit(0) case err != nil: - p.failWithSubcommand(err.Error(), p.lastCmd) + p.failWithSubcommand(err.Error(), p.leaf) } return p @@ -101,41 +101,28 @@ func Parse(dest interface{}) error { if err != nil { return err } - return p.Parse(flags()) -} - -// flags gets all command line arguments other than the first (program name) -func flags() []string { - if len(os.Args) == 0 { // os.Args could be empty - return nil - } - return os.Args[1:] + return p.Parse(os.Args, os.Environ()) } // Config represents configuration options for an argument parser type Config struct { // Program is the name of the program used in the help text Program string - - // IgnoreEnv instructs the library not to read environment variables - IgnoreEnv bool - - // IgnoreDefault instructs the library not to reset the variables to the - // default values, including pointers to sub commands - IgnoreDefault bool } // Parser represents a set of command line options with destination values type Parser struct { - cmd *command - root reflect.Value // destination struct to fill will values - config Config - version string - description string - epilogue string + cmd *Command // the top-level command + root reflect.Value // destination struct to fill will values + config Config // configuration passed to NewParser + version string // version from the argument struct + prologue string // prologue for help text (from the argument struct) + epilogue string // epilogue for help text (from the argument struct) - // the following field changes during processing of command line arguments - lastCmd *command + // the following fields are updated during processing of command line arguments + leaf *Command // the subcommand we processed last + accumulatedArgs []*Argument // concatenation of the leaf subcommand's arguments plus all ancestors' arguments + seen map[*Argument]bool // the arguments we encountered while processing command line arguments } // Versioned is the interface that the destination struct should implement to @@ -198,8 +185,9 @@ func NewParser(config Config, dest interface{}) (*Parser, error) { // construct a parser p := Parser{ - cmd: &command{name: name}, + cmd: &Command{name: name}, config: config, + seen: make(map[*Argument]bool), } // make a list of roots @@ -217,28 +205,28 @@ func NewParser(config Config, dest interface{}) (*Parser, error) { } // add nonzero field values as defaults - for _, spec := range cmd.specs { - if v := p.val(spec.dest); v.IsValid() && !isZero(v) { + for _, arg := range cmd.args { + if v := p.val(arg.dest); v.IsValid() && !isZero(v) { if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok { str, err := defaultVal.MarshalText() if err != nil { - return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err) + return nil, fmt.Errorf("%v: error marshaling default value to string: %v", arg.dest, err) } - spec.defaultVal = string(str) + arg.defaultVal = string(str) } else { - spec.defaultVal = fmt.Sprintf("%v", v) + arg.defaultVal = fmt.Sprintf("%v", v) } } } - p.cmd.specs = append(p.cmd.specs, cmd.specs...) + p.cmd.args = append(p.cmd.args, cmd.args...) p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...) if dest, ok := dest.(Versioned); ok { p.version = dest.Version() } if dest, ok := dest.(Described); ok { - p.description = dest.Description() + p.prologue = dest.Description() } if dest, ok := dest.(Epilogued); ok { p.epilogue = dest.Epilogue() @@ -247,7 +235,7 @@ func NewParser(config Config, dest interface{}) (*Parser, error) { return &p, nil } -func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { +func cmdFromStruct(name string, dest path, t reflect.Type) (*Command, error) { // commands can only be created from pointers to structs if t.Kind() != reflect.Ptr { return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s", @@ -260,7 +248,7 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { dest, t.Kind()) } - cmd := command{ + cmd := Command{ name: name, dest: dest, } @@ -287,7 +275,7 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { // duplicate the entire path to avoid slice overwrites subdest := dest.Child(field) - spec := spec{ + arg := Argument{ dest: subdest, field: field, long: strings.ToLower(field.Name), @@ -295,12 +283,12 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { help, exists := field.Tag.Lookup("help") if exists { - spec.help = help + arg.help = help } defaultVal, hasDefault := field.Tag.Lookup("default") if hasDefault { - spec.defaultVal = defaultVal + arg.defaultVal = defaultVal } // Look at the tag @@ -320,33 +308,33 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { case strings.HasPrefix(key, "---"): errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name)) case strings.HasPrefix(key, "--"): - spec.long = key[2:] + arg.long = key[2:] case strings.HasPrefix(key, "-"): if len(key) != 2 { errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", t.Name(), field.Name)) return false } - spec.short = key[1:] + arg.short = key[1:] case key == "required": if hasDefault { errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified", t.Name(), field.Name)) return false } - spec.required = true + arg.required = true case key == "positional": - spec.positional = true + arg.positional = true case key == "separate": - spec.separate = true + arg.separate = true case key == "help": // deprecated - spec.help = value + arg.help = value case key == "env": // Use override name if provided if value != "" { - spec.env = value + arg.env = value } else { - spec.env = strings.ToUpper(field.Name) + arg.env = strings.ToUpper(field.Name) } case key == "subcommand": // decide on a name for the subcommand @@ -375,11 +363,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { placeholder, hasPlaceholder := field.Tag.Lookup("placeholder") if hasPlaceholder { - spec.placeholder = placeholder - } else if spec.long != "" { - spec.placeholder = strings.ToUpper(spec.long) + arg.placeholder = placeholder + } else if arg.long != "" { + arg.placeholder = strings.ToUpper(arg.long) } else { - spec.placeholder = strings.ToUpper(spec.field.Name) + arg.placeholder = strings.ToUpper(arg.field.Name) } // Check whether this field is supported. It's good to do this here rather than @@ -387,16 +375,16 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { // fields will always fail regardless of whether the arguments it received // exercised those fields. if !isSubcommand { - cmd.specs = append(cmd.specs, &spec) + cmd.args = append(cmd.args, &arg) var err error - spec.cardinality, err = cardinalityOf(field.Type) + arg.cardinality, err = cardinalityOf(field.Type) if err != nil { errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String())) return false } - if spec.cardinality == multiple && hasDefault { + if arg.cardinality == multiple && hasDefault { errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields", t.Name(), field.Name)) return false @@ -413,8 +401,8 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { // check that we don't have both positionals and subcommands var hasPositional bool - for _, spec := range cmd.specs { - if spec.positional { + for _, arg := range cmd.args { + if arg.positional { hasPositional = true } } @@ -427,35 +415,88 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { // Parse processes the given command line option, storing the results in the field // of the structs from which NewParser was constructed -func (p *Parser) Parse(args []string) error { - err := p.process(args) - if err != nil { - // If -h or --help were specified then make sure help text supercedes other errors - for _, arg := range args { - if arg == "-h" || arg == "--help" { - return ErrHelp - } - if arg == "--" { - break - } +func (p *Parser) Parse(args, env []string) error { + p.seen = make(map[*Argument]bool) + + // If -h or --help were specified then make sure help text supercedes other errors + var help bool + for _, arg := range args { + if arg == "-h" || arg == "--help" { + help = true + } + if arg == "--" { + break } } - return err + + err := p.ProcessCommandLine(args) + if err != nil { + if help { + return ErrHelp + } + return err + } + + err = p.ProcessEnvironment(env) + if err != nil { + if help { + return ErrHelp + } + return err + } + + err = p.ProcessDefaults() + if err != nil { + if help { + return ErrHelp + } + return err + } + + return p.Validate() } -// process environment vars for the given arguments -func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error { - for _, spec := range specs { - if spec.env == "" { +// ProcessEnvironment processes environment variables from a list of strings +// of the form KEY=VALUE. You can pass in os.Environ(). It +// does not overwrite any fields with values already populated. +func (p *Parser) ProcessEnvironment(environ []string) error { + return p.processEnvironment(environ, false) +} + +// OverwriteWithEnvironment processes environment variables from a list +// of strings of the form "KEY=VALUE". Any existing values are overwritten. +func (p *Parser) OverwriteWithEnvironment(environ []string) error { + return p.processEnvironment(environ, true) +} + +// ProcessEnvironment processes environment variables from a list of strings +// of the form KEY=VALUE. You can pass in os.Environ(). It +// overwrites already-populated fields only if overwrite is true. +func (p *Parser) processEnvironment(environ []string, overwrite bool) error { + // parse the list of KEY=VAL strings in environ + env := make(map[string]string) + for _, s := range environ { + if i := strings.Index(s, "="); i >= 0 { + env[s[:i]] = s[i+1:] + } + } + + // process arguments one-by-one + for _, arg := range p.accumulatedArgs { + if p.seen[arg] && !overwrite { continue } - value, found := os.LookupEnv(spec.env) + if arg.env == "" { + continue + } + + value, found := env[arg.env] if !found { continue } - if spec.cardinality == multiple { + if arg.cardinality == multiple { // expect a CSV string in an environment // variable in the case of multiple values var values []string @@ -465,74 +506,80 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error if err != nil { return fmt.Errorf( "error reading a CSV string from environment variable %s with multiple values: %v", - spec.env, + arg.env, err, ) } } - if err = setSliceOrMap(p.val(spec.dest), values, !spec.separate); err != nil { + if err = setSliceOrMap(p.val(arg.dest), values, !arg.separate); err != nil { return fmt.Errorf( "error processing environment variable %s with multiple values: %v", - spec.env, + arg.env, err, ) } } else { - if err := scalar.ParseValue(p.val(spec.dest), value); err != nil { - return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) + if err := scalar.ParseValue(p.val(arg.dest), value); err != nil { + return fmt.Errorf("error processing environment variable %s: %v", arg.env, err) } } - wasPresent[spec] = true + + p.seen[arg] = true } return nil } -// process goes through arguments one-by-one, parses them, and assigns the result to -// the underlying struct field -func (p *Parser) process(args []string) error { - // track the options we have seen - wasPresent := make(map[*spec]bool) - - // union of specs for the chain of subcommands encountered so far - curCmd := p.cmd - p.lastCmd = curCmd - - // make a copy of the specs because we will add to this list each time we expand a subcommand - specs := make([]*spec, len(curCmd.specs)) - copy(specs, curCmd.specs) - - // deal with environment vars - if !p.config.IgnoreEnv { - err := p.captureEnvVars(specs, wasPresent) - if err != nil { - return err - } +// ProcessCommandLine goes through arguments one-by-one, parses them, +// and assigns the result to the underlying struct field. It returns +// an error if an argument is invalid or an option is unknown, not if a +// required argument is missing. To check that all required arguments +// are set, call CheckRequired(). This function ignores the first element +// of args, which is assumed to be the program name itself. +func (p *Parser) ProcessCommandLine(args []string) error { + positionals, err := p.ProcessOptions(args) + if err != nil { + return err } + return p.ProcessPositionals(positionals) +} + +// ProcessOptions process command line arguments but does not process +// positional arguments. Instead, it returns positionals. These can then +// be passed to ProcessPositionals. This function ignores the first element +// of args, which is assumed to be the program name itself. +func (p *Parser) ProcessOptions(args []string) ([]string, error) { + // union of args for the chain of subcommands encountered so far + curCmd := p.cmd + p.leaf = curCmd + + // we will add to this list each time we expand a subcommand + p.accumulatedArgs = make([]*Argument, len(curCmd.args)) + copy(p.accumulatedArgs, curCmd.args) // process each string from the command line var allpositional bool var positionals []string // must use explicit for loop, not range, because we manipulate i inside the loop - for i := 0; i < len(args); i++ { - arg := args[i] - if arg == "--" { + for i := 1; i < len(args); i++ { + token := args[i] + if token == "--" { allpositional = true continue } - if !isFlag(arg) || allpositional { + if !isFlag(token) || allpositional { // each subcommand can have either subcommands or positionals, but not both if len(curCmd.subcommands) == 0 { - positionals = append(positionals, arg) + positionals = append(positionals, token) continue } // if we have a subcommand then make sure it is valid for the current context - subcmd := findSubcommand(curCmd.subcommands, arg) + subcmd := findSubcommand(curCmd.subcommands, token) if subcmd == nil { - return fmt.Errorf("invalid subcommand: %s", arg) + return nil, fmt.Errorf("invalid subcommand: %s", token) } // instantiate the field to point to a new struct @@ -542,157 +589,186 @@ func (p *Parser) process(args []string) error { } // add the new options to the set of allowed options - specs = append(specs, subcmd.specs...) - - // capture environment vars for these new options - if !p.config.IgnoreEnv { - err := p.captureEnvVars(subcmd.specs, wasPresent) - if err != nil { - return err - } - } + p.accumulatedArgs = append(p.accumulatedArgs, subcmd.args...) curCmd = subcmd - p.lastCmd = curCmd + p.leaf = curCmd continue } // check for special --help and --version flags - switch arg { + switch token { case "-h", "--help": - return ErrHelp + return nil, ErrHelp case "--version": - return ErrVersion + return nil, ErrVersion } // check for an equals sign, as in "--foo=bar" var value string - opt := strings.TrimLeft(arg, "-") + opt := strings.TrimLeft(token, "-") if pos := strings.Index(opt, "="); pos != -1 { value = opt[pos+1:] opt = opt[:pos] } - // lookup the spec for this option (note that the "specs" slice changes as + // look up the arg for this option (note that the "args" slice changes as // we expand subcommands so it is better not to use a map) - spec := findOption(specs, opt) - if spec == nil { - return fmt.Errorf("unknown argument %s", arg) + arg := findOption(p.accumulatedArgs, opt) + if arg == nil { + return nil, fmt.Errorf("unknown argument %s", token) } - wasPresent[spec] = true + p.seen[arg] = true // deal with the case of multiple values - if spec.cardinality == multiple { + if arg.cardinality == multiple { var values []string if value == "" { for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { values = append(values, args[i+1]) i++ - if spec.separate { + if arg.separate { break } } } else { values = append(values, value) } - err := setSliceOrMap(p.val(spec.dest), values, !spec.separate) + err := setSliceOrMap(p.val(arg.dest), values, !arg.separate) if err != nil { - return fmt.Errorf("error processing %s: %v", arg, err) + return nil, fmt.Errorf("error processing %s: %v", token, err) } continue } // if it's a flag and it has no value then set the value to true // use boolean because this takes account of TextUnmarshaler - if spec.cardinality == zero && value == "" { + if arg.cardinality == zero && value == "" { value = "true" } // if we have something like "--foo" then the value is the next argument if value == "" { if i+1 == len(args) { - return fmt.Errorf("missing value for %s", arg) + return nil, fmt.Errorf("missing value for %s", token) } - if !nextIsNumeric(spec.field.Type, args[i+1]) && isFlag(args[i+1]) { - return fmt.Errorf("missing value for %s", arg) + if isFlag(args[i+1]) { + return nil, fmt.Errorf("missing value for %s", token) } value = args[i+1] i++ } - err := scalar.ParseValue(p.val(spec.dest), value) + p.seen[arg] = true + err := scalar.ParseValue(p.val(arg.dest), value) if err != nil { - return fmt.Errorf("error processing %s: %v", arg, err) + return nil, fmt.Errorf("error processing %s: %v", token, err) } } - // process positionals - for _, spec := range specs { - if !spec.positional { + return positionals, nil +} + +// ProcessPositionals processes a list of positional arguments. It is assumed +// that options such as --abc and --abc=123 have already been removed. If +// this list contains tokens that begin with a hyphen they will still be +// treated as positional arguments. +func (p *Parser) ProcessPositionals(positionals []string) error { + for _, arg := range p.accumulatedArgs { + if !arg.positional { continue } if len(positionals) == 0 { break } - wasPresent[spec] = true - if spec.cardinality == multiple { - err := setSliceOrMap(p.val(spec.dest), positionals, true) + p.seen[arg] = true + if arg.cardinality == multiple { + err := setSliceOrMap(p.val(arg.dest), positionals, true) if err != nil { - return fmt.Errorf("error processing %s: %v", spec.field.Name, err) + return fmt.Errorf("error processing %s: %v", arg.field.Name, err) } positionals = nil } else { - err := scalar.ParseValue(p.val(spec.dest), positionals[0]) + err := scalar.ParseValue(p.val(arg.dest), positionals[0]) if err != nil { - return fmt.Errorf("error processing %s: %v", spec.field.Name, err) + return fmt.Errorf("error processing %s: %v", arg.field.Name, err) } positionals = positionals[1:] } } + if len(positionals) > 0 { return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) } - // fill in defaults and check that all the required args were provided - for _, spec := range specs { - if wasPresent[spec] { - continue - } - - name := strings.ToLower(spec.field.Name) - if spec.long != "" && !spec.positional { - name = "--" + spec.long - } - - if spec.required { - msg := fmt.Sprintf("%s is required", name) - if spec.env != "" { - msg += " (or environment variable " + spec.env + ")" - } - return errors.New(msg) - } - if !p.config.IgnoreDefault && spec.defaultVal != "" { - err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal) - if err != nil { - return fmt.Errorf("error processing default value for %s: %v", name, err) - } - } - } - return nil } -func nextIsNumeric(t reflect.Type, s string) bool { - switch t.Kind() { - case reflect.Ptr: - return nextIsNumeric(t.Elem(), s) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - v := reflect.New(t) - err := scalar.ParseValue(v, s) - return err == nil - default: - return false +// ProcessDefaults assigns default values to all fields that have default values and +// are not already populated. +func (p *Parser) ProcessDefaults() error { + return p.processDefaults(false) +} + +// OverwriteWithDefaults assigns default values to all fields that have default values, +// overwriting any previous value +func (p *Parser) OverwriteWithDefaults() error { + return p.processDefaults(true) +} + +// processDefaults assigns default values to all fields in all expanded subcommands. +// If overwrite is true then it overwrites existing values. +func (p *Parser) processDefaults(overwrite bool) error { + for _, arg := range p.accumulatedArgs { + if p.seen[arg] && !overwrite { + continue + } + + if arg.defaultVal == "" { + continue + } + + name := strings.ToLower(arg.field.Name) + if arg.long != "" && !arg.positional { + name = "--" + arg.long + } + + err := scalar.ParseValue(p.val(arg.dest), arg.defaultVal) + if err != nil { + return fmt.Errorf("error processing default value for %s: %v", name, err) + } + p.seen[arg] = true } + + return nil +} + +// Missing returns a list of required arguments that were not provided +func (p *Parser) Missing() []*Argument { + var missing []*Argument + for _, arg := range p.accumulatedArgs { + if arg.required && !p.seen[arg] { + missing = append(missing, arg) + } + } + return missing +} + +// Validate returns an error if any required arguments were missing +func (p *Parser) Validate() error { + if missing := p.Missing(); len(missing) > 0 { + name := strings.ToLower(missing[0].field.Name) + if missing[0].long != "" && !missing[0].positional { + name = "--" + missing[0].long + } + + if missing[0].env == "" { + return fmt.Errorf("%s is required", name) + } + return fmt.Errorf("%s is required (or environment variable %s)", name, missing[0].env) + } + + return nil } // isFlag returns true if a token is a flag such as "-v" or "--user" but not "-" or "--" @@ -717,21 +793,21 @@ func (p *Parser) val(dest path) reflect.Value { return v } -// findOption finds an option from its name, or returns null if no spec is found -func findOption(specs []*spec, name string) *spec { - for _, spec := range specs { - if spec.positional { +// findOption finds an option from its name, or returns nil if no arg is found +func findOption(args []*Argument, name string) *Argument { + for _, arg := range args { + if arg.positional { continue } - if spec.long == name || spec.short == name { - return spec + if arg.long == name || arg.short == name { + return arg } } return nil } -// findSubcommand finds a subcommand using its name, or returns null if no subcommand is found -func findSubcommand(cmds []*command, name string) *command { +// findSubcommand finds a subcommand using its name, or returns nil if no subcommand is found +func findSubcommand(cmds []*Command, name string) *Command { for _, cmd := range cmds { if cmd.name == name { return cmd diff --git a/v2/parse_test.go b/v2/parse_test.go index 4ea6bc4..b9d6948 100644 --- a/v2/parse_test.go +++ b/v2/parse_test.go @@ -2,7 +2,6 @@ package arg import ( "bytes" - "fmt" "net" "net/mail" "net/url" @@ -15,47 +14,29 @@ import ( "github.com/stretchr/testify/require" ) -func setenv(t *testing.T, name, val string) { - if err := os.Setenv(name, val); err != nil { - t.Error(err) - } -} - func parse(cmdline string, dest interface{}) error { _, err := pparse(cmdline, dest) return err } func pparse(cmdline string, dest interface{}) (*Parser, error) { - return parseWithEnv(cmdline, nil, dest) + return parseWithEnv(dest, cmdline) } -func parseWithEnv(cmdline string, env []string, dest interface{}) (*Parser, error) { +func parseWithEnv(dest interface{}, cmdline string, env ...string) (*Parser, error) { p, err := NewParser(Config{}, dest) if err != nil { return nil, err } // split the command line - var parts []string + tokens := []string{"program"} // first token is the program name if len(cmdline) > 0 { - parts = strings.Split(cmdline, " ") - } - - // split the environment vars - for _, s := range env { - pos := strings.Index(s, "=") - if pos == -1 { - return nil, fmt.Errorf("missing equals sign in %q", s) - } - err := os.Setenv(s[:pos], s[pos+1:]) - if err != nil { - return nil, err - } + tokens = append(tokens, strings.Split(cmdline, " ")...) } // execute the parser - return p, p.Parse(parts) + return p, p.Parse(tokens, env) } func TestString(t *testing.T) { @@ -97,9 +78,9 @@ func TestInt(t *testing.T) { func TestHexOctBin(t *testing.T) { var args struct { - Hex int - Oct int - Bin int + Hex int + Oct int + Bin int Underscored int } err := parse("--hex 0xA --oct 0o10 --bin 0b101 --underscored 123_456", &args) @@ -114,22 +95,18 @@ func TestNegativeInt(t *testing.T) { var args struct { Foo int } - err := parse("-foo -100", &args) + err := parse("-foo=-100", &args) require.NoError(t, err) assert.EqualValues(t, args.Foo, -100) } -func TestNegativeIntAndFloatAndTricks(t *testing.T) { +func TestNumericOptionName(t *testing.T) { var args struct { - Foo int - Bar float64 - N int `arg:"--100"` + N int `arg:"--100"` } - err := parse("-foo -100 -bar -60.14 -100 -100", &args) + err := parse("-100 6", &args) require.NoError(t, err) - assert.EqualValues(t, args.Foo, -100) - assert.EqualValues(t, args.Bar, -60.14) - assert.EqualValues(t, args.N, -100) + assert.EqualValues(t, args.N, 6) } func TestUint(t *testing.T) { @@ -211,6 +188,14 @@ func TestMixed(t *testing.T) { } func TestRequired(t *testing.T) { + var args struct { + Foo string `arg:"required"` + } + err := parse("--foo=abc", &args) + require.NoError(t, err) +} + +func TestMissingRequired(t *testing.T) { var args struct { Foo string `arg:"required"` } @@ -218,7 +203,7 @@ func TestRequired(t *testing.T) { require.Error(t, err, "--foo is required") } -func TestRequiredWithEnv(t *testing.T) { +func TestMissingRequiredWithEnv(t *testing.T) { var args struct { Foo string `arg:"required,env:FOO"` } @@ -336,8 +321,7 @@ func TestNoLongName(t *testing.T) { ShortOnly string `arg:"-s,--"` EnvOnly string `arg:"--,env"` } - setenv(t, "ENVONLY", "TestVal") - err := parse("-s TestVal2", &args) + _, err := parseWithEnv(&args, "-s TestVal2", "ENVONLY=TestVal") assert.NoError(t, err) assert.Equal(t, "TestVal", args.EnvOnly) assert.Equal(t, "TestVal2", args.ShortOnly) @@ -482,15 +466,6 @@ func TestUnknownField(t *testing.T) { assert.Error(t, err) } -func TestMissingRequired(t *testing.T) { - var args struct { - Foo string `arg:"required"` - X []string `arg:"positional"` - } - err := parse("x", &args) - assert.Error(t, err) -} - func TestNonsenseKey(t *testing.T) { var args struct { X []string `arg:"positional, nonsense"` @@ -520,7 +495,7 @@ func TestNegativeValue(t *testing.T) { var args struct { Foo int } - err := parse("--foo -123", &args) + err := parse("--foo=-123", &args) require.NoError(t, err) assert.Equal(t, -123, args.Foo) } @@ -687,7 +662,7 @@ func TestEnvironmentVariable(t *testing.T) { var args struct { Foo string `arg:"env"` } - _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + _, err := parseWithEnv(&args, "", "FOO=bar") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -696,7 +671,7 @@ func TestEnvironmentVariableNotPresent(t *testing.T) { var args struct { NotPresent string `arg:"env"` } - _, err := parseWithEnv("", nil, &args) + _, err := parseWithEnv(&args, "", "") require.NoError(t, err) assert.Equal(t, "", args.NotPresent) } @@ -705,16 +680,16 @@ func TestEnvironmentVariableOverrideName(t *testing.T) { var args struct { Foo string `arg:"env:BAZ"` } - _, err := parseWithEnv("", []string{"BAZ=bar"}, &args) + _, err := parseWithEnv(&args, "", "BAZ=bar") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } -func TestEnvironmentVariableOverrideArgument(t *testing.T) { +func TestCommandLineSupercedesEnv(t *testing.T) { var args struct { Foo string `arg:"env"` } - _, err := parseWithEnv("--foo zzz", []string{"FOO=bar"}, &args) + _, err := parseWithEnv(&args, "--foo zzz", "FOO=bar") require.NoError(t, err) assert.Equal(t, "zzz", args.Foo) } @@ -723,7 +698,7 @@ func TestEnvironmentVariableError(t *testing.T) { var args struct { Foo int `arg:"env"` } - _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + _, err := parseWithEnv(&args, "", "FOO=bar") assert.Error(t, err) } @@ -731,7 +706,7 @@ func TestEnvironmentVariableRequired(t *testing.T) { var args struct { Foo string `arg:"env,required"` } - _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + _, err := parseWithEnv(&args, "", "FOO=bar") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -740,7 +715,7 @@ func TestEnvironmentVariableSliceArgumentString(t *testing.T) { var args struct { Foo []string `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=bar,"baz, qux"`}, &args) + _, err := parseWithEnv(&args, "", `FOO=bar,"baz, qux"`) require.NoError(t, err) assert.Equal(t, []string{"bar", "baz, qux"}, args.Foo) } @@ -749,7 +724,7 @@ func TestEnvironmentVariableSliceEmpty(t *testing.T) { var args struct { Foo []string `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=`}, &args) + _, err := parseWithEnv(&args, "", `FOO=`) require.NoError(t, err) assert.Len(t, args.Foo, 0) } @@ -758,7 +733,7 @@ func TestEnvironmentVariableSliceArgumentInteger(t *testing.T) { var args struct { Foo []int `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=1,99`}, &args) + _, err := parseWithEnv(&args, "", `FOO=1,99`) require.NoError(t, err) assert.Equal(t, []int{1, 99}, args.Foo) } @@ -767,7 +742,7 @@ func TestEnvironmentVariableSliceArgumentFloat(t *testing.T) { var args struct { Foo []float32 `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=1.1,99.9`}, &args) + _, err := parseWithEnv(&args, "", `FOO=1.1,99.9`) require.NoError(t, err) assert.Equal(t, []float32{1.1, 99.9}, args.Foo) } @@ -776,7 +751,7 @@ func TestEnvironmentVariableSliceArgumentBool(t *testing.T) { var args struct { Foo []bool `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=true,false,0,1`}, &args) + _, err := parseWithEnv(&args, "", `FOO=true,false,0,1`) require.NoError(t, err) assert.Equal(t, []bool{true, false, false, true}, args.Foo) } @@ -785,7 +760,7 @@ func TestEnvironmentVariableSliceArgumentWrongCsv(t *testing.T) { var args struct { Foo []int `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=1,99\"`}, &args) + _, err := parseWithEnv(&args, "", `FOO=1,99\"`) assert.Error(t, err) } @@ -793,7 +768,7 @@ func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) { var args struct { Foo []bool `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=one,two`}, &args) + _, err := parseWithEnv(&args, "", `FOO=one,two`) assert.Error(t, err) } @@ -801,7 +776,7 @@ func TestEnvironmentVariableMap(t *testing.T) { var args struct { Foo map[int]string `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=1=one,99=ninetynine`}, &args) + _, err := parseWithEnv(&args, "", `FOO=1=one,99=ninetynine`) require.NoError(t, err) assert.Len(t, args.Foo, 2) assert.Equal(t, "one", args.Foo[1]) @@ -812,51 +787,61 @@ func TestEnvironmentVariableEmptyMap(t *testing.T) { var args struct { Foo map[int]string `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=`}, &args) + _, err := parseWithEnv(&args, "", `FOO=`) require.NoError(t, err) assert.Len(t, args.Foo, 0) } -func TestEnvironmentVariableIgnored(t *testing.T) { - var args struct { - Foo string `arg:"env"` - } - setenv(t, "FOO", "abc") +// func TestEnvironmentVariableIgnored(t *testing.T) { +// var args struct { +// Foo string `arg:"env"` +// } +// setenv(t, "FOO", "abc") - p, err := NewParser(Config{IgnoreEnv: true}, &args) - require.NoError(t, err) +// p, err := NewParser(Config{IgnoreEnv: true}, &args) +// require.NoError(t, err) - err = p.Parse(nil) - assert.NoError(t, err) - assert.Equal(t, "", args.Foo) -} +// err = p.Parse(nil) +// assert.NoError(t, err) +// assert.Equal(t, "", args.Foo) +// } -func TestDefaultValuesIgnored(t *testing.T) { - var args struct { - Foo string `default:"bad"` - } +// func TestDefaultValuesIgnored(t *testing.T) { +// var args struct { +// Foo string `default:"bad"` +// } - p, err := NewParser(Config{IgnoreDefault: true}, &args) - require.NoError(t, err) +// p, err := NewParser(Config{IgnoreDefault: true}, &args) +// require.NoError(t, err) - err = p.Parse(nil) - assert.NoError(t, err) - assert.Equal(t, "", args.Foo) -} +// err = p.Parse(nil) +// assert.NoError(t, err) +// assert.Equal(t, "", args.Foo) +// } -func TestEnvironmentVariableInSubcommandIgnored(t *testing.T) { +func TestEnvironmentVariableInSubcommand(t *testing.T) { var args struct { Sub *struct { - Foo string `arg:"env"` + Foo string `arg:"env:FOO"` } `arg:"subcommand"` } - setenv(t, "FOO", "abc") - p, err := NewParser(Config{IgnoreEnv: true}, &args) + _, err := parseWithEnv(&args, "sub", "FOO=abc") require.NoError(t, err) + require.NotNil(t, args.Sub) + assert.Equal(t, "abc", args.Sub.Foo) +} - err = p.Parse([]string{"sub"}) - assert.NoError(t, err) +func TestEnvironmentVariableInSubcommandEmpty(t *testing.T) { + var args struct { + Sub *struct { + Foo string `arg:"env:FOO"` + } `arg:"subcommand"` + } + + _, err := parseWithEnv(&args, "sub") + require.NoError(t, err) + require.NotNil(t, args.Sub) assert.Equal(t, "", args.Sub.Foo) } @@ -1305,11 +1290,11 @@ func TestReuseParser(t *testing.T) { p, err := NewParser(Config{}, &args) require.NoError(t, err) - err = p.Parse([]string{"--foo=abc"}) + err = p.Parse([]string{"program", "--foo=abc"}, nil) require.NoError(t, err) assert.Equal(t, args.Foo, "abc") - err = p.Parse([]string{}) + err = p.Parse([]string{}, nil) assert.Error(t, err) } diff --git a/v2/subcommand.go b/v2/subcommand.go index dff732c..03fe399 100644 --- a/v2/subcommand.go +++ b/v2/subcommand.go @@ -7,22 +7,22 @@ package arg // no command line arguments have been processed by this parser then it // returns nil. func (p *Parser) Subcommand() interface{} { - if p.lastCmd == nil || p.lastCmd.parent == nil { + if p.leaf == nil || p.leaf.parent == nil { return nil } - return p.val(p.lastCmd.dest).Interface() + return p.val(p.leaf.dest).Interface() } // SubcommandNames returns the sequence of subcommands specified by the // user. If no subcommands were given then it returns an empty slice. func (p *Parser) SubcommandNames() []string { - if p.lastCmd == nil { + if p.leaf == nil { return nil } // make a list of ancestor commands var ancestors []string - cur := p.lastCmd + cur := p.leaf for cur.parent != nil { // we want to exclude the root ancestors = append(ancestors, cur.name) cur = cur.parent diff --git a/v2/subcommand_test.go b/v2/subcommand_test.go index 2c61dd3..9f7c8c5 100644 --- a/v2/subcommand_test.go +++ b/v2/subcommand_test.go @@ -206,8 +206,7 @@ func TestSubcommandsWithEnvVars(t *testing.T) { { var args cmd - setenv(t, "LIMIT", "123") - err := parse("list", &args) + _, err := parseWithEnv(&args, "list", "LIMIT=123") require.NoError(t, err) require.NotNil(t, args.List) assert.Equal(t, 123, args.List.Limit) @@ -215,8 +214,7 @@ func TestSubcommandsWithEnvVars(t *testing.T) { { var args cmd - setenv(t, "LIMIT", "not_an_integer") - err := parse("list", &args) + _, err := parseWithEnv(&args, "list", "LIMIT=not_an_integer") assert.Error(t, err) } } diff --git a/v2/usage.go b/v2/usage.go index 7ba06cc..bdc118d 100644 --- a/v2/usage.go +++ b/v2/usage.go @@ -38,19 +38,15 @@ func (p *Parser) FailSubcommand(msg string, subcommand ...string) error { } // failWithSubcommand prints usage information for the given subcommand to stderr and exits with non-zero status -func (p *Parser) failWithSubcommand(msg string, cmd *command) { +func (p *Parser) failWithSubcommand(msg string, cmd *Command) { p.writeUsageForSubcommand(stderr, cmd) fmt.Fprintln(stderr, "error:", msg) osExit(-1) } -// WriteUsage writes usage information to the given writer +// WriteUsage writes usage information for the top-level command func (p *Parser) WriteUsage(w io.Writer) { - cmd := p.cmd - if p.lastCmd != nil { - cmd = p.lastCmd - } - p.writeUsageForSubcommand(w, cmd) + p.writeUsageForSubcommand(w, p.cmd) } // WriteUsageForSubcommand writes the usage information for a specified @@ -68,16 +64,16 @@ func (p *Parser) WriteUsageForSubcommand(w io.Writer, subcommand ...string) erro } // writeUsageForSubcommand writes usage information for the given subcommand -func (p *Parser) writeUsageForSubcommand(w io.Writer, cmd *command) { - var positionals, longOptions, shortOptions []*spec - for _, spec := range cmd.specs { +func (p *Parser) writeUsageForSubcommand(w io.Writer, cmd *Command) { + var positionals, longOptions, shortOptions []*Argument + for _, arg := range cmd.args { switch { - case spec.positional: - positionals = append(positionals, spec) - case spec.long != "": - longOptions = append(longOptions, spec) - case spec.short != "": - shortOptions = append(shortOptions, spec) + case arg.positional: + positionals = append(positionals, arg) + case arg.long != "": + longOptions = append(longOptions, arg) + case arg.short != "": + shortOptions = append(shortOptions, arg) } } @@ -100,26 +96,26 @@ func (p *Parser) writeUsageForSubcommand(w io.Writer, cmd *command) { } // write the option component of the usage message - for _, spec := range shortOptions { + for _, arg := range shortOptions { // prefix with a space fmt.Fprint(w, " ") - if !spec.required { + if !arg.required { fmt.Fprint(w, "[") } - fmt.Fprint(w, synopsis(spec, "-"+spec.short)) - if !spec.required { + fmt.Fprint(w, synopsis(arg, "-"+arg.short)) + if !arg.required { fmt.Fprint(w, "]") } } - for _, spec := range longOptions { + for _, arg := range longOptions { // prefix with a space fmt.Fprint(w, " ") - if !spec.required { + if !arg.required { fmt.Fprint(w, "[") } - fmt.Fprint(w, synopsis(spec, "--"+spec.long)) - if !spec.required { + fmt.Fprint(w, synopsis(arg, "--"+arg.long)) + if !arg.required { fmt.Fprint(w, "]") } } @@ -137,16 +133,16 @@ func (p *Parser) writeUsageForSubcommand(w io.Writer, cmd *command) { // REQUIRED1 REQUIRED2 [REPEATEDOPTIONAL [REPEATEDOPTIONAL ...]] // REQUIRED1 REQUIRED2 [OPTIONAL1 [REPEATEDOPTIONAL [REPEATEDOPTIONAL ...]]] var closeBrackets int - for _, spec := range positionals { + for _, arg := range positionals { fmt.Fprint(w, " ") - if !spec.required { + if !arg.required { fmt.Fprint(w, "[") closeBrackets += 1 } - if spec.cardinality == multiple { - fmt.Fprintf(w, "%s [%s ...]", spec.placeholder, spec.placeholder) + if arg.cardinality == multiple { + fmt.Fprintf(w, "%s [%s ...]", arg.placeholder, arg.placeholder) } else { - fmt.Fprint(w, spec.placeholder) + fmt.Fprint(w, arg.placeholder) } } fmt.Fprint(w, strings.Repeat("]", closeBrackets)) @@ -191,13 +187,9 @@ func printTwoCols(w io.Writer, left, help string, defaultVal string, envVal stri fmt.Fprint(w, "\n") } -// WriteHelp writes the usage string followed by the full help string for each option +// WriteHelp writes the usage string for the top-level command func (p *Parser) WriteHelp(w io.Writer) { - cmd := p.cmd - if p.lastCmd != nil { - cmd = p.lastCmd - } - p.writeHelpForSubcommand(w, cmd) + p.writeHelpForSubcommand(w, p.cmd) } // WriteHelpForSubcommand writes the usage string followed by the full help @@ -215,68 +207,68 @@ func (p *Parser) WriteHelpForSubcommand(w io.Writer, subcommand ...string) error } // writeHelp writes the usage string for the given subcommand -func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *command) { - var positionals, longOptions, shortOptions []*spec - for _, spec := range cmd.specs { +func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *Command) { + var positionals, longOptions, shortOptions []*Argument + for _, arg := range cmd.args { switch { - case spec.positional: - positionals = append(positionals, spec) - case spec.long != "": - longOptions = append(longOptions, spec) - case spec.short != "": - shortOptions = append(shortOptions, spec) + case arg.positional: + positionals = append(positionals, arg) + case arg.long != "": + longOptions = append(longOptions, arg) + case arg.short != "": + shortOptions = append(shortOptions, arg) } } - if p.description != "" { - fmt.Fprintln(w, p.description) + if p.prologue != "" { + fmt.Fprintln(w, p.prologue) } p.writeUsageForSubcommand(w, cmd) // write the list of positionals if len(positionals) > 0 { fmt.Fprint(w, "\nPositional arguments:\n") - for _, spec := range positionals { - printTwoCols(w, spec.placeholder, spec.help, "", "") + for _, arg := range positionals { + printTwoCols(w, arg.placeholder, arg.help, "", "") } } // write the list of options with the short-only ones first to match the usage string if len(shortOptions)+len(longOptions) > 0 || cmd.parent == nil { fmt.Fprint(w, "\nOptions:\n") - for _, spec := range shortOptions { - p.printOption(w, spec) + for _, arg := range shortOptions { + p.printOption(w, arg) } - for _, spec := range longOptions { - p.printOption(w, spec) + for _, arg := range longOptions { + p.printOption(w, arg) } } // obtain a flattened list of options from all ancestors - var globals []*spec + var globals []*Argument ancestor := cmd.parent for ancestor != nil { - globals = append(globals, ancestor.specs...) + globals = append(globals, ancestor.args...) ancestor = ancestor.parent } // write the list of global options if len(globals) > 0 { fmt.Fprint(w, "\nGlobal options:\n") - for _, spec := range globals { - p.printOption(w, spec) + for _, arg := range globals { + p.printOption(w, arg) } } // write the list of built in options - p.printOption(w, &spec{ + p.printOption(w, &Argument{ cardinality: zero, long: "help", short: "h", help: "display this help and exit", }) if p.version != "" { - p.printOption(w, &spec{ + p.printOption(w, &Argument{ cardinality: zero, long: "version", help: "display version and exit", @@ -296,16 +288,16 @@ func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *command) { } } -func (p *Parser) printOption(w io.Writer, spec *spec) { +func (p *Parser) printOption(w io.Writer, arg *Argument) { ways := make([]string, 0, 2) - if spec.long != "" { - ways = append(ways, synopsis(spec, "--"+spec.long)) + if arg.long != "" { + ways = append(ways, synopsis(arg, "--"+arg.long)) } - if spec.short != "" { - ways = append(ways, synopsis(spec, "-"+spec.short)) + if arg.short != "" { + ways = append(ways, synopsis(arg, "-"+arg.short)) } if len(ways) > 0 { - printTwoCols(w, strings.Join(ways, ", "), spec.help, spec.defaultVal, spec.env) + printTwoCols(w, strings.Join(ways, ", "), arg.help, arg.defaultVal, arg.env) } } @@ -314,10 +306,10 @@ func (p *Parser) printOption(w io.Writer, spec *spec) { // subcommand of that subcommand, and so on. If no strings are given then the // root command is returned. If no such subcommand exists then an error is // returned. -func (p *Parser) lookupCommand(path ...string) (*command, error) { +func (p *Parser) lookupCommand(path ...string) (*Command, error) { cmd := p.cmd for _, name := range path { - var found *command + var found *Command for _, child := range cmd.subcommands { if child.name == name { found = child @@ -331,9 +323,9 @@ func (p *Parser) lookupCommand(path ...string) (*command, error) { return cmd, nil } -func synopsis(spec *spec, form string) string { - if spec.cardinality == zero { +func synopsis(arg *Argument, form string) string { + if arg.cardinality == zero { return form } - return form + " " + spec.placeholder + return form + " " + arg.placeholder } diff --git a/v2/usage_test.go b/v2/usage_test.go index fd67fc8..b306506 100644 --- a/v2/usage_test.go +++ b/v2/usage_test.go @@ -443,20 +443,12 @@ Global options: p, err := NewParser(Config{}, &args) require.NoError(t, err) - _ = p.Parse([]string{"child", "nested", "value"}) - - var help bytes.Buffer - p.WriteHelp(&help) - assert.Equal(t, expectedHelp[1:], help.String()) + _ = p.Parse([]string{"child", "nested", "value"}, nil) var help2 bytes.Buffer p.WriteHelpForSubcommand(&help2, "child", "nested") assert.Equal(t, expectedHelp[1:], help2.String()) - var usage bytes.Buffer - p.WriteUsage(&usage) - assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) - var usage2 bytes.Buffer p.WriteUsageForSubcommand(&usage2, "child", "nested") assert.Equal(t, expectedUsage, strings.TrimSpace(usage2.String()))