// Source: go-arg/v2/parse.go (562 lines, 15 KiB, Go)

package arg
import (
"encoding/csv"
"errors"
"fmt"
"os"
"reflect"
"strings"
scalar "github.com/alexflint/go-scalar"
)
// Sentinel errors reported when the user asks for informational output
// rather than a normal parse.
var (
	// ErrHelp indicates that -h or --help was provided.
	ErrHelp = errors.New("help requested by user")

	// ErrVersion indicates that --version was provided.
	ErrVersion = errors.New("version requested by user")
)
// MustParse builds a parser for dest and runs it against os.Args and
// os.Environ(). A construction error is printed and followed by
// osExit(-1). A help or version request prints the corresponding text and
// calls osExit(0). Any other parse error is handed to failWithSubcommand.
// The parser is returned when parsing succeeds.
func MustParse(dest interface{}) *Parser {
	parser, err := NewParser(dest)
	if err != nil {
		fmt.Fprintln(stdout, err)
		osExit(-1)
		return nil // just in case osExit was monkey-patched
	}

	switch parseErr := parser.Parse(os.Args, os.Environ()); parseErr {
	case nil:
		// successful parse, nothing to report
	case ErrHelp:
		parser.writeHelpForSubcommand(stdout, parser.leaf)
		osExit(0)
	case ErrVersion:
		fmt.Fprintln(stdout, parser.version)
		osExit(0)
	default:
		parser.failWithSubcommand(parseErr.Error(), parser.leaf)
	}
	return parser
}
// Parse constructs a parser for dest with the given options and runs it
// against os.Args and os.Environ(), storing the results in dest. It
// returns the construction error, if any, and otherwise the result of
// parsing.
func Parse(dest interface{}, options ...ParserOption) error {
	parser, err := NewParser(dest, options...)
	if err == nil {
		err = parser.Parse(os.Args, os.Environ())
	}
	return err
}
// Parse processes the given command line and environment, storing the
// results in the fields of the struct from which NewParser was
// constructed. It runs ProcessCommandLine, ProcessEnvironment and
// ProcessDefaults in that order, and finally Validate.
//
// args follows the os.Args convention: its first element is the program
// name and is never interpreted. env is a list of KEY=VALUE strings such
// as os.Environ().
//
// If -h or --help appears among the arguments before any "--" terminator,
// an error from any of the three processing steps is reported as ErrHelp,
// so that help text supersedes other errors.
func (p *Parser) Parse(args, env []string) error {
	p.seen = make(map[*Argument]bool)

	// Look for a help request. The program name in args[0] is skipped, just
	// as ProcessCommandLine skips it, so that it can neither trigger help
	// nor (when it happens to be "--") end the scan early.
	var help bool
	if len(args) > 0 {
		for _, arg := range args[1:] {
			if arg == "--" {
				break
			}
			if arg == "-h" || arg == "--help" {
				help = true
				break
			}
		}
	}

	// supersede turns a processing error into ErrHelp when help was requested
	supersede := func(err error) error {
		if err != nil && help {
			return ErrHelp
		}
		return err
	}

	if err := supersede(p.ProcessCommandLine(args)); err != nil {
		return err
	}
	if err := supersede(p.ProcessEnvironment(env)); err != nil {
		return err
	}
	if err := supersede(p.ProcessDefaults()); err != nil {
		return err
	}
	return p.Validate()
}
// ProcessCommandLine walks the arguments one at a time, parses them and
// stores the results in the struct passed to NewParser. The first element
// of args is taken to be the program name and is skipped; an empty args
// is a no-op. Invalid or unknown arguments produce an error, but missing
// required arguments do not: call Validate() to check for those. Values
// for arguments already seen by an earlier Process* call are not
// overwritten.
func (p *Parser) ProcessCommandLine(args []string) error {
	if len(args) == 0 {
		return nil
	}
	rest, err := p.ProcessOptions(args[1:])
	if err == nil {
		err = p.ProcessPositionals(rest)
	}
	return err
}
// OverwriteWithCommandLine behaves like ProcessCommandLine, except that
// values seen earlier are replaced. The first element of args is skipped
// as the program name, and an empty args is a no-op.
func (p *Parser) OverwriteWithCommandLine(args []string) error {
	if len(args) == 0 {
		return nil
	}
	rest, err := p.OverwriteWithOptions(args[1:])
	if err == nil {
		err = p.OverwriteWithPositionals(rest)
	}
	return err
}
// ProcessOptions handles the options found in args and returns the
// positional tokens untouched, ready to be passed to ProcessPositionals.
// Every element of args is interpreted, so the program name must not be
// included. Arguments already seen in an earlier Process* or
// OverwriteWith* call are not overwritten.
func (p *Parser) ProcessOptions(args []string) ([]string, error) {
	const overwrite = false
	return p.processOptions(args, overwrite)
}
// OverwriteWithOptions does the same work as ProcessOptions, but values
// for arguments that were seen earlier are replaced.
func (p *Parser) OverwriteWithOptions(args []string) ([]string, error) {
	const overwrite = true
	return p.processOptions(args, overwrite)
}
// processOptions scans args (which must not include the program name),
// applies every option it finds, and returns the tokens that should be
// treated as positionals. overwrite is forwarded to processScalar and
// processSequence and controls whether previously seen arguments are
// replaced.
//
// Subcommands are expanded along the way: a non-flag token seen while the
// current leaf command has subcommands must name one of them; its
// destination struct is allocated if needed, its arguments are appended
// to p.accessible and it becomes the new p.leaf.
//
// It returns ErrHelp for -h/--help and ErrVersion for --version, and an
// error for unknown subcommands, unknown options, options with a missing
// value, or values that fail to parse.
func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) {
// note that p.cmd.args has already been copied into p.accessible in
// NewParser; p.accessible is the union of args for the chain of
// subcommands encountered so far. The leaf is reset to the root command
// on every call.
p.leaf = p.cmd
var allpositional bool
var positionals []string
// must use explicit for loop, not range, because we manipulate i inside the loop
for i := 0; i < len(args); i++ {
token := args[i]
// the "--" token indicates that all further tokens should be treated as positionals
// (the "--" token itself is dropped)
if token == "--" {
allpositional = true
continue
}
// check whether this is a positional argument
if !isFlag(token) || allpositional {
// each subcommand can have either subcommands or positionals, but not both
if len(p.leaf.subcommands) == 0 {
positionals = append(positionals, token)
continue
}
// if we have a subcommand then make sure it is valid for the current context
// (this also applies to tokens that follow "--" while the leaf still has subcommands)
subcmd := findSubcommand(p.leaf.subcommands, token)
if subcmd == nil {
return nil, fmt.Errorf("invalid subcommand: %s", token)
}
// instantiate the field to point to a new struct
v := p.val(subcmd.dest)
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers
}
// add the new options to the set of allowed options
p.accessible = append(p.accessible, subcmd.args...)
p.leaf = subcmd
continue
}
// check for special --help and --version flags
switch token {
case "-h", "--help":
return nil, ErrHelp
case "--version":
return nil, ErrVersion
}
// check for an equals sign, as in "--foo=bar"; the option name is what
// precedes the first "=" once all leading hyphens are stripped
var value string
opt := strings.TrimLeft(token, "-")
if pos := strings.Index(opt, "="); pos != -1 {
value = opt[pos+1:]
opt = opt[:pos]
}
// look up the arg for this option (note that the "args" slice changes as
// we expand subcommands so it is better not to use a map)
arg := findOption(p.accessible, opt)
if arg == nil {
return nil, fmt.Errorf("unknown argument %s", token)
}
// for the case of multiple values, consume tokens until next --option
if arg.cardinality == multiple && !arg.separate {
var values []string
if value == "" {
// no inline value: take every following token up to the next flag or "--"
for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" {
values = append(values, args[i+1])
i++
}
} else {
// "--foo=bar" supplies exactly one value
values = append(values, value)
}
if err := p.processSequence(arg, values, overwrite); err != nil {
return nil, err
}
continue
}
// if it's a flag and it has no value then set the value to true
// use boolean because this takes account of TextUnmarshaler
if arg.cardinality == zero && value == "" {
value = "true"
}
// if we have something like "--foo" then the value is the next argument
// NOTE: an empty string is the "no value yet" sentinel, so "--foo=" is
// handled exactly like "--foo" and also takes the next token
if value == "" {
if i+1 == len(args) {
return nil, fmt.Errorf("missing value for %s", token)
}
if isFlag(args[i+1]) {
return nil, fmt.Errorf("missing value for %s", token)
}
value = args[i+1]
i++
}
// send the value to the argument
if err := p.processScalar(arg, value, overwrite); err != nil {
return nil, err
}
}
return positionals, nil
}
// ProcessPositionals assigns the given tokens to the positional
// arguments. Every token is taken as a positional value, including tokens
// that begin with a hyphen. Arguments already seen in an earlier Process*
// or OverwriteWith* call are not overwritten.
func (p *Parser) ProcessPositionals(positionals []string) error {
	const overwrite = false
	return p.processPositionals(positionals, overwrite)
}
// OverwriteWithPositionals does the same work as ProcessPositionals, but
// values for arguments that were seen earlier are replaced.
func (p *Parser) OverwriteWithPositionals(positionals []string) error {
	const overwrite = true
	return p.processPositionals(positionals, overwrite)
}
// processPositionals hands out the given tokens to the positional
// arguments in p.accessible, in order. A scalar positional takes one
// token; a positional with multiple cardinality takes all tokens that are
// left. Tokens that remain after every positional has been served cause a
// "too many positional arguments" error. overwrite is forwarded to
// processScalar and processSequence.
func (p *Parser) processPositionals(positionals []string, overwrite bool) error {
	remaining := positionals
	for _, arg := range p.accessible {
		if len(remaining) == 0 {
			break
		}
		if !arg.positional {
			continue
		}

		// a sequence positional swallows everything that is left
		if arg.cardinality == multiple {
			if err := p.processSequence(arg, remaining, overwrite); err != nil {
				return err
			}
			remaining = nil
			break
		}

		head, tail := remaining[0], remaining[1:]
		if err := p.processScalar(arg, head, overwrite); err != nil {
			return err
		}
		remaining = tail
	}

	if len(remaining) == 0 {
		return nil
	}
	return fmt.Errorf("too many positional arguments at '%s'", remaining[0])
}
// ProcessEnvironment reads argument values from a list of KEY=VALUE
// strings, such as the one returned by os.Environ(). Fields that already
// hold a value are not overwritten.
func (p *Parser) ProcessEnvironment(environ []string) error {
	const overwrite = false
	return p.processEnvironment(environ, overwrite)
}
// OverwriteWithEnvironment reads argument values from a list of
// "KEY=VALUE" strings and replaces any values that are already present.
func (p *Parser) OverwriteWithEnvironment(environ []string) error {
	const overwrite = true
	return p.processEnvironment(environ, overwrite)
}
// processEnvironment processes environment variables from a list of
// strings of the form KEY=VALUE, such as os.Environ(). Entries without an
// "=" are ignored, and when a key occurs more than once the last
// occurrence wins.
//
// Arguments without an environment variable name, or whose variable is
// absent from environ, are left untouched. Arguments that are already
// populated are skipped unless overwrite is true. For arguments with
// multiple cardinality (and not marked separate) the value is parsed as a
// single CSV record; a blank value yields an empty sequence.
//
// Errors from CSV parsing or value conversion are returned wrapped with
// the name of the offending environment variable.
func (p *Parser) processEnvironment(environ []string, overwrite bool) error {
	// parse the list of KEY=VAL strings in environ
	env := make(map[string]string, len(environ))
	for _, s := range environ {
		if i := strings.Index(s, "="); i >= 0 {
			env[s[:i]] = s[i+1:]
		}
	}

	// process arguments one-by-one
	for _, arg := range p.accessible {
		if arg.env == "" {
			continue
		}
		value, found := env[arg.env]
		if !found {
			continue
		}

		// Leave already-populated arguments alone unless asked to overwrite.
		// processScalar and processSequence let "separate" arguments through
		// their own seen-check so that repeated options accumulate; without
		// this guard the environment value would be applied on top of values
		// given earlier. Checking here also avoids failing on a malformed
		// value that is going to be ignored anyway.
		if p.seen[arg] && !overwrite {
			continue
		}

		if arg.cardinality == multiple && !arg.separate {
			// expect a CSV string in an environment
			// variable in the case of multiple values
			var values []string
			if len(strings.TrimSpace(value)) > 0 {
				var err error
				values, err = csv.NewReader(strings.NewReader(value)).Read()
				if err != nil {
					return fmt.Errorf("error parsing CSV string from environment variable %s: %w", arg.env, err)
				}
			}
			if err := p.processSequence(arg, values, overwrite); err != nil {
				return fmt.Errorf("error processing environment variable %s: %w", arg.env, err)
			}
			continue
		}

		if err := p.processScalar(arg, value, overwrite); err != nil {
			return fmt.Errorf("error processing environment variable %s: %w", arg.env, err)
		}
	}
	return nil
}
// ProcessDefaults fills in the default value of every field that declares
// one and has not been populated yet.
func (p *Parser) ProcessDefaults() error {
	const overwrite = false
	return p.processDefaults(overwrite)
}
// OverwriteWithDefaults fills in the default value of every field that
// declares one, replacing whatever value the field held before.
func (p *Parser) OverwriteWithDefaults() error {
	const overwrite = true
	return p.processDefaults(overwrite)
}
// processDefaults assigns default values to all fields in all expanded
// subcommands. Arguments without a default are skipped. Arguments that
// are already populated are skipped too, unless overwrite is true, in
// which case the existing value is replaced. It returns the first error
// reported while storing a default.
func (p *Parser) processDefaults(overwrite bool) error {
	for _, arg := range p.accessible {
		if arg.defaultVal == "" {
			continue
		}

		// Leave already-populated arguments alone unless asked to overwrite.
		// processScalar lets "separate" arguments through its own seen-check
		// so that repeated options accumulate; without this guard a default
		// would be applied on top of a value the user supplied.
		if p.seen[arg] && !overwrite {
			continue
		}

		if err := p.processScalar(arg, arg.defaultVal, overwrite); err != nil {
			return err
		}
	}
	return nil
}
// processScalar parses a single value, stores it in the destination field
// of arg, and marks the argument as "seen". If the argument has been seen
// before, overwrite is false and the argument is not marked separate, the
// value is ignored and nil is returned.
//
// For an argument with multiple cardinality the value is added to the
// destination slice or map; otherwise it replaces the field's value.
//
// A conversion failure is returned wrapped in an error naming the
// argument: "--long" for options that have a long name, and the
// lower-cased field name otherwise. The message does not name a source,
// because this function is used for command line, environment and default
// values alike.
func (p *Parser) processScalar(arg *Argument, value string, overwrite bool) error {
	if p.seen[arg] && !overwrite && !arg.separate {
		return nil
	}

	// name used in error messages
	name := strings.ToLower(arg.field.Name)
	if arg.long != "" && !arg.positional {
		name = "--" + arg.long
	}

	var err error
	if arg.cardinality == multiple {
		// one value of a collection: add it rather than replace
		err = appendToSliceOrMap(p.val(arg.dest), value)
	} else {
		err = scalar.ParseValue(p.val(arg.dest), value)
	}
	if err != nil {
		return fmt.Errorf("error processing %s: %w", name, err)
	}

	p.seen[arg] = true
	return nil
}
// processSequence parses a sequence of values, stores them in the
// destination slice or map of arg, and marks the argument as "seen". If
// the argument has been seen before, overwrite is false and the argument
// is not marked separate, the values are ignored and nil is returned.
//
// It panics when called for an argument whose cardinality is not
// multiple, since that indicates a bug in the caller.
//
// A conversion failure is returned wrapped in an error naming the
// argument: "--long" for options that have a long name, and the
// lower-cased field name otherwise. The message does not name a source,
// because this function is used for command line and environment values
// alike.
func (p *Parser) processSequence(arg *Argument, values []string, overwrite bool) error {
	if p.seen[arg] && !overwrite && !arg.separate {
		return nil
	}

	if arg.cardinality != multiple {
		panic(fmt.Sprintf("processSequence called for argument %s which has cardinality %v", arg.field.Name, arg.cardinality))
	}

	// name used in error messages
	name := strings.ToLower(arg.field.Name)
	if arg.long != "" && !arg.positional {
		name = "--" + arg.long
	}

	if err := setSliceOrMap(p.val(arg.dest), values); err != nil {
		return fmt.Errorf("error processing %s: %w", name, err)
	}

	p.seen[arg] = true
	return nil
}
// Missing reports the required arguments, among those currently
// accessible, that have not been seen. The result is nil when nothing is
// missing.
func (p *Parser) Missing() []*Argument {
	var absent []*Argument
	for _, candidate := range p.accessible {
		if !candidate.required || p.seen[candidate] {
			continue
		}
		absent = append(absent, candidate)
	}
	return absent
}
// Validate returns an error describing the first required argument that
// is missing, or nil when every required argument was provided. The
// argument is named "--long" when it is an option with a long name and by
// its lower-cased field name otherwise; the environment variable is
// mentioned when the argument has one.
func (p *Parser) Validate() error {
	missing := p.Missing()
	if len(missing) == 0 {
		return nil
	}

	first := missing[0]
	label := "--" + first.long
	if first.long == "" || first.positional {
		label = strings.ToLower(first.field.Name)
	}

	if first.env != "" {
		return fmt.Errorf("%s is required (or environment variable %s)", label, first.env)
	}
	return fmt.Errorf("%s is required", label)
}
// val resolves dest against the parser's root value and returns the
// reflect.Value found at the end of the path. Pointers met on the way are
// dereferenced before the next field is selected; if one of them is nil
// the zero reflect.Value is returned.
func (p *Parser) val(dest path) reflect.Value {
	cur := p.root
	for i := range dest.fields {
		if cur.Kind() == reflect.Ptr && cur.IsNil() {
			return reflect.Value{}
		}
		// Indirect is a no-op for non-pointers and Elem() for pointers
		cur = reflect.Indirect(cur).FieldByIndex(dest.fields[i].Index)
	}
	return cur
}
// path describes how to reach the output location of an argument or
// subcommand inside the destination struct, as the ordered list of struct
// fields to step through.
type path struct {
	fields []reflect.StructField // sequence of struct fields to traverse
}

// String renders the path as "args" followed by ".Name" for each field.
func (p path) String() string {
	out := "args"
	for i := range p.fields {
		out = out + "." + p.fields[i].Name
	}
	return out
}

// Child returns a new path that extends this one by the field f. The
// receiver is left untouched.
func (p path) Child(f reflect.StructField) path {
	// build a fresh slice so that sibling paths never share a backing array
	extended := make([]reflect.StructField, 0, len(p.fields)+1)
	extended = append(extended, p.fields...)
	extended = append(extended, f)
	return path{fields: extended}
}
// walkFields invokes visit for every field of the struct type t. When
// visit returns true for a field whose type is itself a struct, the
// fields of that struct are walked as well.
func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) {
	walkFieldsImpl(t, visit, nil)
}

// walkFieldsImpl does the work for walkFields. path is the index prefix
// prepended to each field's Index before it is handed to visit. The
// prefix is extended when descending into an embedded (anonymous) struct
// and starts over empty when descending into a named struct field.
func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool, path []int) {
	for i, n := 0, t.NumField(); i < n; i++ {
		// give the field its own index slice: prefix followed by i
		index := make([]int, 0, len(path)+1)
		index = append(append(index, path...), i)

		field := t.Field(i)
		field.Index = index

		if !visit(field, t) || field.Type.Kind() != reflect.Struct {
			continue
		}

		var subpath []int
		if field.Anonymous {
			// the full-slice expression forces append to copy instead of
			// writing into spare capacity of path
			subpath = append(path[:len(path):len(path)], i)
		}
		walkFieldsImpl(field.Type, visit, subpath)
	}
}
// isFlag reports whether s looks like a flag such as "-v" or "--user":
// it must start with a hyphen and contain at least one character that is
// not a hyphen, so "-" and "--" are not flags.
func isFlag(s string) bool {
	if !strings.HasPrefix(s, "-") {
		return false
	}
	notHyphen := func(r rune) bool { return r != '-' }
	return strings.IndexFunc(s, notHyphen) >= 0
}
// findOption returns the first non-positional argument in args whose long
// or short name equals name, or nil if there is none.
//
// An empty name never matches. Tokens such as "-=x" or "--=x" produce an
// empty option name, and most arguments leave their short (or long) name
// empty, so without this check such a token would silently select the
// first option that lacks one of the two names.
func findOption(args []*Argument, name string) *Argument {
	if name == "" {
		return nil
	}
	for _, arg := range args {
		if arg.positional {
			continue
		}
		if arg.long == name || arg.short == name {
			return arg
		}
	}
	return nil
}
// findSubcommand finds a subcommand using its name, or returns nil if no subcommand is found
func findSubcommand(cmds []*Command, name string) *Command {
for _, cmd := range cmds {
if cmd.name == name {
return cmd
}
}
return nil
}