562 lines
15 KiB
Go
562 lines
15 KiB
Go
package arg
|
|
|
|
import (
|
|
"encoding/csv"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"reflect"
|
|
"strings"
|
|
|
|
scalar "github.com/alexflint/go-scalar"
|
|
)
|
|
|
|
// Sentinel errors that the option parser returns when the user asks for
// informational output rather than a normal run.
var (
	// ErrHelp indicates that -h or --help were provided
	ErrHelp = errors.New("help requested by user")

	// ErrVersion indicates that --version was provided
	ErrVersion = errors.New("version requested by user")
)
|
|
|
|
// MustParse processes command line arguments and exits upon failure
|
|
func MustParse(dest interface{}) *Parser {
|
|
p, err := NewParser(dest)
|
|
if err != nil {
|
|
fmt.Fprintln(stdout, err)
|
|
osExit(-1)
|
|
return nil // just in case osExit was monkey-patched
|
|
}
|
|
|
|
err = p.Parse(os.Args, os.Environ())
|
|
switch {
|
|
case err == ErrHelp:
|
|
p.writeHelpForSubcommand(stdout, p.leaf)
|
|
osExit(0)
|
|
case err == ErrVersion:
|
|
fmt.Fprintln(stdout, p.version)
|
|
osExit(0)
|
|
case err != nil:
|
|
p.failWithSubcommand(err.Error(), p.leaf)
|
|
}
|
|
|
|
return p
|
|
}
|
|
|
|
// Parse processes command line arguments and stores them in dest
|
|
func Parse(dest interface{}, options ...ParserOption) error {
|
|
p, err := NewParser(dest, options...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return p.Parse(os.Args, os.Environ())
|
|
}
|
|
|
|
// Parse processes the given command line option, storing the results in the field
|
|
// of the structs from which NewParser was constructed
|
|
func (p *Parser) Parse(args, env []string) error {
|
|
p.seen = make(map[*Argument]bool)
|
|
|
|
// If -h or --help were specified then make sure help text supercedes other errors
|
|
var help bool
|
|
for _, arg := range args {
|
|
if arg == "-h" || arg == "--help" {
|
|
help = true
|
|
}
|
|
if arg == "--" {
|
|
break
|
|
}
|
|
}
|
|
|
|
err := p.ProcessCommandLine(args)
|
|
if err != nil {
|
|
if help {
|
|
return ErrHelp
|
|
}
|
|
return err
|
|
}
|
|
|
|
err = p.ProcessEnvironment(env)
|
|
if err != nil {
|
|
if help {
|
|
return ErrHelp
|
|
}
|
|
return err
|
|
}
|
|
|
|
err = p.ProcessDefaults()
|
|
if err != nil {
|
|
if help {
|
|
return ErrHelp
|
|
}
|
|
return err
|
|
}
|
|
|
|
return p.Validate()
|
|
}
|
|
|
|
// ProcessCommandLine scans arguments one-by-one, parses them and assigns
|
|
// the result to fields of the struct passed to NewParser. It returns
|
|
// an error if an argument is invalid or unknown, but not if a
|
|
// required argument is missing. To check that all required arguments
|
|
// are set, call Validate(). This function ignores the first element
|
|
// of args, which is assumed to be the program name itself. This function
|
|
// never overwrites arguments previously seen in a call to any Process*
|
|
// function.
|
|
func (p *Parser) ProcessCommandLine(args []string) error {
|
|
if len(args) == 0 {
|
|
return nil
|
|
}
|
|
|
|
positionals, err := p.ProcessOptions(args[1:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return p.ProcessPositionals(positionals)
|
|
}
|
|
|
|
// OverwriteWithCommandLine is like ProcessCommandLine but it overwrites
|
|
// any previously seen values.
|
|
func (p *Parser) OverwriteWithCommandLine(args []string) error {
|
|
if len(args) == 0 {
|
|
return nil
|
|
}
|
|
|
|
positionals, err := p.OverwriteWithOptions(args[1:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return p.OverwriteWithPositionals(positionals)
|
|
}
|
|
|
|
// ProcessOptions processes options but not positionals from the
|
|
// command line. Positionals are returned and can be passed to
|
|
// ProcessPositionals. Arguments seen in a previous call to any
|
|
// Process* or OverwriteWith* functions are ignored.
|
|
func (p *Parser) ProcessOptions(args []string) ([]string, error) {
|
|
return p.processOptions(args, false)
|
|
}
|
|
|
|
// OverwriteWithOptions is like ProcessOptions except previously seen
|
|
// arguments are overwritten
|
|
func (p *Parser) OverwriteWithOptions(args []string) ([]string, error) {
|
|
return p.processOptions(args, true)
|
|
}
|
|
|
|
func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) {
|
|
// note that p.cmd.args has already been copied into p.accessible in NewParser
|
|
|
|
// union of args for the chain of subcommands encountered so far
|
|
p.leaf = p.cmd
|
|
var allpositional bool
|
|
var positionals []string
|
|
|
|
// must use explicit for loop, not range, because we manipulate i inside the loop
|
|
for i := 0; i < len(args); i++ {
|
|
token := args[i]
|
|
|
|
// the "--" token indicates that all further tokens should be treated as positionals
|
|
if token == "--" {
|
|
allpositional = true
|
|
continue
|
|
}
|
|
|
|
// check whether this is a positional argument
|
|
if !isFlag(token) || allpositional {
|
|
// each subcommand can have either subcommands or positionals, but not both
|
|
if len(p.leaf.subcommands) == 0 {
|
|
positionals = append(positionals, token)
|
|
continue
|
|
}
|
|
|
|
// if we have a subcommand then make sure it is valid for the current context
|
|
subcmd := findSubcommand(p.leaf.subcommands, token)
|
|
if subcmd == nil {
|
|
return nil, fmt.Errorf("invalid subcommand: %s", token)
|
|
}
|
|
|
|
// instantiate the field to point to a new struct
|
|
v := p.val(subcmd.dest)
|
|
if v.IsNil() {
|
|
v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers
|
|
}
|
|
|
|
// add the new options to the set of allowed options
|
|
p.accessible = append(p.accessible, subcmd.args...)
|
|
p.leaf = subcmd
|
|
continue
|
|
}
|
|
|
|
// check for special --help and --version flags
|
|
switch token {
|
|
case "-h", "--help":
|
|
return nil, ErrHelp
|
|
case "--version":
|
|
return nil, ErrVersion
|
|
}
|
|
|
|
// check for an equals sign, as in "--foo=bar"
|
|
var value string
|
|
opt := strings.TrimLeft(token, "-")
|
|
if pos := strings.Index(opt, "="); pos != -1 {
|
|
value = opt[pos+1:]
|
|
opt = opt[:pos]
|
|
}
|
|
|
|
// look up the arg for this option (note that the "args" slice changes as
|
|
// we expand subcommands so it is better not to use a map)
|
|
arg := findOption(p.accessible, opt)
|
|
if arg == nil {
|
|
return nil, fmt.Errorf("unknown argument %s", token)
|
|
}
|
|
|
|
// for the case of multiple values, consume tokens until next --option
|
|
if arg.cardinality == multiple && !arg.separate {
|
|
var values []string
|
|
if value == "" {
|
|
for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" {
|
|
values = append(values, args[i+1])
|
|
i++
|
|
}
|
|
} else {
|
|
values = append(values, value)
|
|
}
|
|
|
|
if err := p.processSequence(arg, values, overwrite); err != nil {
|
|
return nil, err
|
|
}
|
|
continue
|
|
}
|
|
|
|
// if it's a flag and it has no value then set the value to true
|
|
// use boolean because this takes account of TextUnmarshaler
|
|
if arg.cardinality == zero && value == "" {
|
|
value = "true"
|
|
}
|
|
|
|
// if we have something like "--foo" then the value is the next argument
|
|
if value == "" {
|
|
if i+1 == len(args) {
|
|
return nil, fmt.Errorf("missing value for %s", token)
|
|
}
|
|
if isFlag(args[i+1]) {
|
|
return nil, fmt.Errorf("missing value for %s", token)
|
|
}
|
|
value = args[i+1]
|
|
i++
|
|
}
|
|
|
|
// send the value to the argument
|
|
if err := p.processScalar(arg, value, overwrite); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return positionals, nil
|
|
}
|
|
|
|
// ProcessPositionals processes a list of positional arguments. If
|
|
// this list contains tokens that begin with a hyphen they will still be
|
|
// treated as positional arguments. Arguments seen in a previous call
|
|
// to any Process* or OverwriteWith* functions are ignored.
|
|
func (p *Parser) ProcessPositionals(positionals []string) error {
|
|
return p.processPositionals(positionals, false)
|
|
}
|
|
|
|
// OverwriteWithPositionals is like ProcessPositionals except previously
|
|
// seen arguments are overwritten.
|
|
func (p *Parser) OverwriteWithPositionals(positionals []string) error {
|
|
return p.processPositionals(positionals, true)
|
|
}
|
|
|
|
func (p *Parser) processPositionals(positionals []string, overwrite bool) error {
|
|
for _, arg := range p.accessible {
|
|
if !arg.positional {
|
|
continue
|
|
}
|
|
if len(positionals) == 0 {
|
|
break
|
|
}
|
|
if arg.cardinality == multiple {
|
|
if err := p.processSequence(arg, positionals, overwrite); err != nil {
|
|
return err
|
|
}
|
|
positionals = nil
|
|
break
|
|
}
|
|
|
|
if err := p.processScalar(arg, positionals[0], overwrite); err != nil {
|
|
return err
|
|
}
|
|
positionals = positionals[1:]
|
|
}
|
|
|
|
if len(positionals) > 0 {
|
|
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ProcessEnvironment processes environment variables from a list of strings
|
|
// of the form KEY=VALUE. You can pass in os.Environ(). It
|
|
// does not overwrite any fields with values already populated.
|
|
func (p *Parser) ProcessEnvironment(environ []string) error {
|
|
return p.processEnvironment(environ, false)
|
|
}
|
|
|
|
// OverwriteWithEnvironment processes environment variables from a list
|
|
// of strings of the form "KEY=VALUE". Any existing values are overwritten.
|
|
func (p *Parser) OverwriteWithEnvironment(environ []string) error {
|
|
return p.processEnvironment(environ, true)
|
|
}
|
|
|
|
// ProcessEnvironment processes environment variables from a list of strings
|
|
// of the form KEY=VALUE. You can pass in os.Environ(). It
|
|
// overwrites already-populated fields only if overwrite is true.
|
|
func (p *Parser) processEnvironment(environ []string, overwrite bool) error {
|
|
// parse the list of KEY=VAL strings in environ
|
|
env := make(map[string]string)
|
|
for _, s := range environ {
|
|
if i := strings.Index(s, "="); i >= 0 {
|
|
env[s[:i]] = s[i+1:]
|
|
}
|
|
}
|
|
|
|
// process arguments one-by-one
|
|
for _, arg := range p.accessible {
|
|
if arg.env == "" {
|
|
continue
|
|
}
|
|
|
|
value, found := env[arg.env]
|
|
if !found {
|
|
continue
|
|
}
|
|
|
|
if arg.cardinality == multiple && !arg.separate {
|
|
// expect a CSV string in an environment
|
|
// variable in the case of multiple values
|
|
var values []string
|
|
if len(strings.TrimSpace(value)) > 0 {
|
|
var err error
|
|
values, err = csv.NewReader(strings.NewReader(value)).Read()
|
|
if err != nil {
|
|
return fmt.Errorf("error parsing CSV string from environment variable %s: %v", arg.env, err)
|
|
}
|
|
}
|
|
|
|
if err := p.processSequence(arg, values, overwrite); err != nil {
|
|
return fmt.Errorf("error processing environment variable %s: %v", arg.env, err)
|
|
}
|
|
continue
|
|
}
|
|
|
|
if err := p.processScalar(arg, value, overwrite); err != nil {
|
|
return fmt.Errorf("error processing environment variable %s: %v", arg.env, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ProcessDefaults assigns default values to all fields that have default values and
|
|
// are not already populated.
|
|
func (p *Parser) ProcessDefaults() error {
|
|
return p.processDefaults(false)
|
|
}
|
|
|
|
// OverwriteWithDefaults assigns default values to all fields that have default values,
|
|
// overwriting any previous value
|
|
func (p *Parser) OverwriteWithDefaults() error {
|
|
return p.processDefaults(true)
|
|
}
|
|
|
|
// processDefaults assigns default values to all fields in all expanded subcommands.
|
|
// If overwrite is true then it overwrites existing values.
|
|
func (p *Parser) processDefaults(overwrite bool) error {
|
|
for _, arg := range p.accessible {
|
|
if arg.defaultVal == "" {
|
|
continue
|
|
}
|
|
if err := p.processScalar(arg, arg.defaultVal, overwrite); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// processScalar parses a single argument, inserts it into the struct,
|
|
// and marks the argument as "seen" (unless the argument has been seen
|
|
// before and overwrite=false, in which case the value is ignored)
|
|
func (p *Parser) processScalar(arg *Argument, value string, overwrite bool) error {
|
|
if p.seen[arg] && !overwrite && !arg.separate {
|
|
return nil
|
|
}
|
|
|
|
name := strings.ToLower(arg.field.Name)
|
|
if arg.long != "" && !arg.positional {
|
|
name = "--" + arg.long
|
|
}
|
|
|
|
if arg.cardinality == multiple {
|
|
err := appendToSliceOrMap(p.val(arg.dest), value)
|
|
if err != nil {
|
|
return fmt.Errorf("error processing default value for %s: %v", name, err)
|
|
}
|
|
} else {
|
|
err := scalar.ParseValue(p.val(arg.dest), value)
|
|
if err != nil {
|
|
return fmt.Errorf("error processing default value for %s: %v", name, err)
|
|
}
|
|
}
|
|
|
|
p.seen[arg] = true
|
|
return nil
|
|
}
|
|
|
|
// processSequence parses a sequence argument, inserts it into the struct,
|
|
// and marks the argument as "seen" (unless the argument has been seen
|
|
// before and overwrite=false, in which case the value is ignored)
|
|
func (p *Parser) processSequence(arg *Argument, values []string, overwrite bool) error {
|
|
if p.seen[arg] && !overwrite && !arg.separate {
|
|
return nil
|
|
}
|
|
|
|
name := strings.ToLower(arg.field.Name)
|
|
if arg.long != "" && !arg.positional {
|
|
name = "--" + arg.long
|
|
}
|
|
|
|
if arg.cardinality != multiple {
|
|
panic(fmt.Sprintf("processSequence called for argument %s which has cardinality %v", arg.field.Name, arg.cardinality))
|
|
}
|
|
|
|
err := setSliceOrMap(p.val(arg.dest), values)
|
|
if err != nil {
|
|
return fmt.Errorf("error processing default value for %s: %v", name, err)
|
|
}
|
|
|
|
p.seen[arg] = true
|
|
return nil
|
|
}
|
|
|
|
// Missing returns a list of required arguments that were not provided
|
|
func (p *Parser) Missing() []*Argument {
|
|
var missing []*Argument
|
|
for _, arg := range p.accessible {
|
|
if arg.required && !p.seen[arg] {
|
|
missing = append(missing, arg)
|
|
}
|
|
}
|
|
return missing
|
|
}
|
|
|
|
// Validate returns an error if any required arguments were missing
|
|
func (p *Parser) Validate() error {
|
|
if missing := p.Missing(); len(missing) > 0 {
|
|
name := strings.ToLower(missing[0].field.Name)
|
|
if missing[0].long != "" && !missing[0].positional {
|
|
name = "--" + missing[0].long
|
|
}
|
|
|
|
if missing[0].env == "" {
|
|
return fmt.Errorf("%s is required", name)
|
|
}
|
|
return fmt.Errorf("%s is required (or environment variable %s)", name, missing[0].env)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// val returns a reflect.Value corresponding to the current value for the
|
|
// given path
|
|
func (p *Parser) val(dest path) reflect.Value {
|
|
v := p.root
|
|
for _, field := range dest.fields {
|
|
if v.Kind() == reflect.Ptr {
|
|
if v.IsNil() {
|
|
return reflect.Value{}
|
|
}
|
|
v = v.Elem()
|
|
}
|
|
|
|
v = v.FieldByIndex(field.Index)
|
|
}
|
|
return v
|
|
}
|
|
|
|
// path is the sequence of struct fields that leads from the destination
// struct to the output location of an argument or subcommand.
type path struct {
	fields []reflect.StructField // struct fields to traverse, outermost first
}

// String renders the path as "args" followed by ".Name" for every field.
func (p path) String() string {
	out := []byte("args")
	for i := range p.fields {
		out = append(out, '.')
		out = append(out, p.fields[i].Name...)
	}
	return string(out)
}

// Child returns a new path that extends this one by the field f.
func (p path) Child(f reflect.StructField) path {
	// build the extended list in fresh storage so that sibling paths
	// derived from p can never overwrite each other's last element
	extended := make([]reflect.StructField, 0, len(p.fields)+1)
	extended = append(extended, p.fields...)
	extended = append(extended, f)
	return path{fields: extended}
}
|
|
|
|
// walkFields calls visit for every field of the struct type t. When visit
// returns true for a field whose type is itself a struct, the fields of
// that struct are walked as well. owner is the struct type that declares
// the field being visited.
func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) {
	walkFieldsImpl(t, visit, nil)
}

// walkFieldsImpl does the recursive work for walkFields. path is the index
// prefix that is prepended to the Index of every field of t. Fields reached
// through an embedded (anonymous) struct carry the prefix of that struct;
// fields reached through a named struct field start again with an empty
// prefix.
func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool, path []int) {
	count := t.NumField()
	for idx := 0; idx < count; idx++ {
		sf := t.Field(idx)

		// give the field its own copy of the full index path
		full := make([]int, 0, len(path)+1)
		full = append(full, path...)
		full = append(full, idx)
		sf.Index = full

		if !visit(sf, t) || sf.Type.Kind() != reflect.Struct {
			continue
		}

		var prefix []int
		if sf.Anonymous {
			prefix = append(path, idx)
		}
		walkFieldsImpl(sf.Type, visit, prefix)
	}
}
|
|
|
|
// isFlag reports whether a token is a flag such as "-v" or "--user". A
// token made up of hyphens only ("-", "--", ...) is not a flag, and neither
// is anything that does not start with a hyphen.
func isFlag(s string) bool {
	if !strings.HasPrefix(s, "-") {
		return false
	}
	afterHyphens := strings.TrimLeft(s, "-")
	return len(afterHyphens) > 0
}
|
|
|
|
// findOption finds an option from its name, or returns nil if no arg is found
|
|
func findOption(args []*Argument, name string) *Argument {
|
|
for _, arg := range args {
|
|
if arg.positional {
|
|
continue
|
|
}
|
|
if arg.long == name || arg.short == name {
|
|
return arg
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// findSubcommand finds a subcommand using its name, or returns nil if no subcommand is found
|
|
func findSubcommand(cmds []*Command, name string) *Command {
|
|
for _, cmd := range cmds {
|
|
if cmd.name == name {
|
|
return cmd
|
|
}
|
|
}
|
|
return nil
|
|
}
|