diff --git a/v2/doc.go b/v2/doc.go new file mode 100644 index 0000000..3b0bafd --- /dev/null +++ b/v2/doc.go @@ -0,0 +1,39 @@ +// Package arg parses command line arguments using the fields from a struct. +// +// For example, +// +// var args struct { +// Iter int +// Debug bool +// } +// arg.MustParse(&args) +// +// defines two command line arguments, which can be set using any of +// +// ./example --iter=1 --debug // debug is a boolean flag so its value is set to true +// ./example -iter 1 // debug defaults to its zero value (false) +// ./example --debug=true // iter defaults to its zero value (zero) +// +// The fastest way to see how to use go-arg is to read the examples below. +// +// Fields can be bool, string, any float type, or any signed or unsigned integer type. +// They can also be slices of any of the above, or slices of pointers to any of the above. +// +// Tags can be specified using the `arg` and `help` tag names: +// +// var args struct { +// Input string `arg:"positional"` +// Log string `arg:"positional,required"` +// Debug bool `arg:"-d" help:"turn on debug mode"` +// RealMode bool `arg:"--real" +// Wr io.Writer `arg:"-"` +// } +// +// Any tag string that starts with a single hyphen is the short form for an argument +// (e.g. `./example -d`), and any tag string that starts with two hyphens is the long +// form for the argument (instead of the field name). +// +// Other valid tag strings are `positional` and `required`. +// +// Fields can be excluded from processing with `arg:"-"`. +package arg diff --git a/v2/example_test.go b/v2/example_test.go new file mode 100644 index 0000000..fd64777 --- /dev/null +++ b/v2/example_test.go @@ -0,0 +1,507 @@ +package arg + +import ( + "fmt" + "net" + "net/mail" + "net/url" + "os" + "strings" + "time" +) + +func split(s string) []string { + return strings.Split(s, " ") +} + +// This example demonstrates basic usage +func Example() { + // These are the args you would pass in on the command line + os.Args = split("./example --foo=hello --bar") + + var args struct { + Foo string + Bar bool + } + MustParse(&args) + fmt.Println(args.Foo, args.Bar) + // output: hello true +} + +// This example demonstrates arguments that have default values +func Example_defaultValues() { + // These are the args you would pass in on the command line + os.Args = split("./example") + + var args struct { + Foo string `default:"abc"` + } + MustParse(&args) + fmt.Println(args.Foo) + // output: abc +} + +// This example demonstrates arguments that are required +func Example_requiredArguments() { + // These are the args you would pass in on the command line + os.Args = split("./example --foo=abc --bar") + + var args struct { + Foo string `arg:"required"` + Bar bool + } + MustParse(&args) + fmt.Println(args.Foo, args.Bar) + // output: abc true +} + +// This example demonstrates positional arguments +func Example_positionalArguments() { + // These are the args you would pass in on the command line + os.Args = split("./example in out1 out2 out3") + + var args struct { + Input string `arg:"positional"` + Output []string `arg:"positional"` + } + MustParse(&args) + fmt.Println("In:", args.Input) + fmt.Println("Out:", args.Output) + // output: + // In: in + // Out: [out1 out2 out3] +} + +// This example demonstrates arguments that have multiple values +func Example_multipleValues() { + // The args you would pass in on the command line + os.Args = split("./example --database localhost --ids 1 2 3") + + var args struct { + Database string + IDs []int64 + } + MustParse(&args) + fmt.Printf("Fetching the following IDs from %s: %v", args.Database, args.IDs) + // output: Fetching the following IDs from localhost: [1 2 3] +} + +// This example demonstrates arguments with keys and values +func Example_mappings() { + // The args you would pass in on the command line + os.Args = split("./example --userids john=123 mary=456") + + var args struct { + UserIDs map[string]int + } + MustParse(&args) + fmt.Println(args.UserIDs) + // output: map[john:123 mary:456] +} + +type commaSeparated struct { + M map[string]string +} + +func (c *commaSeparated) UnmarshalText(b []byte) error { + c.M = make(map[string]string) + for _, part := range strings.Split(string(b), ",") { + pos := strings.Index(part, "=") + if pos == -1 { + return fmt.Errorf("error parsing %q, expected format key=value", part) + } + c.M[part[:pos]] = part[pos+1:] + } + return nil +} + +// This example demonstrates arguments with keys and values separated by commas +func Example_mappingWithCommas() { + // The args you would pass in on the command line + os.Args = split("./example --values one=two,three=four") + + var args struct { + Values commaSeparated + } + MustParse(&args) + fmt.Println(args.Values.M) + // output: map[one:two three:four] +} + +// This eample demonstrates multiple value arguments that can be mixed with +// other arguments. +func Example_multipleMixed() { + os.Args = split("./example -c cmd1 db1 -f file1 db2 -c cmd2 -f file2 -f file3 db3 -c cmd3") + var args struct { + Commands []string `arg:"-c,separate"` + Files []string `arg:"-f,separate"` + Databases []string `arg:"positional"` + } + MustParse(&args) + fmt.Println("Commands:", args.Commands) + fmt.Println("Files:", args.Files) + fmt.Println("Databases:", args.Databases) + + // output: + // Commands: [cmd1 cmd2 cmd3] + // Files: [file1 file2 file3] + // Databases: [db1 db2 db3] +} + +// This example shows the usage string generated by go-arg +func Example_helpText() { + // These are the args you would pass in on the command line + os.Args = split("./example --help") + + var args struct { + Input string `arg:"positional,required"` + Output []string `arg:"positional"` + Verbose bool `arg:"-v" help:"verbosity level"` + Dataset string `help:"dataset to use"` + Optimize int `arg:"-O,--optim" help:"optimization level"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stdout = os.Stdout + + MustParse(&args) + + // output: + // Usage: example [--verbose] [--dataset DATASET] [--optim OPTIM] INPUT [OUTPUT [OUTPUT ...]] + // + // Positional arguments: + // INPUT + // OUTPUT + // + // Options: + // --verbose, -v verbosity level + // --dataset DATASET dataset to use + // --optim OPTIM, -O OPTIM + // optimization level + // --help, -h display this help and exit +} + +// This example shows the usage string generated by go-arg with customized placeholders +func Example_helpPlaceholder() { + // These are the args you would pass in on the command line + os.Args = split("./example --help") + + var args struct { + Input string `arg:"positional,required" placeholder:"SRC"` + Output []string `arg:"positional" placeholder:"DST"` + Optimize int `arg:"-O" help:"optimization level" placeholder:"LEVEL"` + MaxJobs int `arg:"-j" help:"maximum number of simultaneous jobs" placeholder:"N"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stdout = os.Stdout + + MustParse(&args) + + // output: + + // Usage: example [--optimize LEVEL] [--maxjobs N] SRC [DST [DST ...]] + + // Positional arguments: + // SRC + // DST + + // Options: + // --optimize LEVEL, -O LEVEL + // optimization level + // --maxjobs N, -j N maximum number of simultaneous jobs + // --help, -h display this help and exit +} + +// This example shows the usage string generated by go-arg when using subcommands +func Example_helpTextWithSubcommand() { + // These are the args you would pass in on the command line + os.Args = split("./example --help") + + type getCmd struct { + Item string `arg:"positional" help:"item to fetch"` + } + + type listCmd struct { + Format string `help:"output format"` + Limit int + } + + var args struct { + Verbose bool + Get *getCmd `arg:"subcommand" help:"fetch an item and print it"` + List *listCmd `arg:"subcommand" help:"list available items"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stdout = os.Stdout + + MustParse(&args) + + // output: + // Usage: example [--verbose] [] + // + // Options: + // --verbose + // --help, -h display this help and exit + // + // Commands: + // get fetch an item and print it + // list list available items +} + +// This example shows the usage string generated by go-arg when using subcommands +func Example_helpTextWhenUsingSubcommand() { + // These are the args you would pass in on the command line + os.Args = split("./example get --help") + + type getCmd struct { + Item string `arg:"positional,required" help:"item to fetch"` + } + + type listCmd struct { + Format string `help:"output format"` + Limit int + } + + var args struct { + Verbose bool + Get *getCmd `arg:"subcommand" help:"fetch an item and print it"` + List *listCmd `arg:"subcommand" help:"list available items"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stdout = os.Stdout + + MustParse(&args) + + // output: + // Usage: example get ITEM + // + // Positional arguments: + // ITEM item to fetch + // + // Global options: + // --verbose + // --help, -h display this help and exit +} + +// This example shows how to print help for an explicit subcommand +func Example_writeHelpForSubcommand() { + // These are the args you would pass in on the command line + os.Args = split("./example get --help") + + type getCmd struct { + Item string `arg:"positional" help:"item to fetch"` + } + + type listCmd struct { + Format string `help:"output format"` + Limit int + } + + var args struct { + Verbose bool + Get *getCmd `arg:"subcommand" help:"fetch an item and print it"` + List *listCmd `arg:"subcommand" help:"list available items"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stdout = os.Stdout + + p, err := NewParser(Config{}, &args) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + err = p.WriteHelpForSubcommand(os.Stdout, "list") + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + // output: + // Usage: example list [--format FORMAT] [--limit LIMIT] + // + // Options: + // --format FORMAT output format + // --limit LIMIT + // + // Global options: + // --verbose + // --help, -h display this help and exit +} + +// This example shows how to print help for a subcommand that is nested several levels deep +func Example_writeHelpForSubcommandNested() { + // These are the args you would pass in on the command line + os.Args = split("./example get --help") + + type mostNestedCmd struct { + Item string + } + + type nestedCmd struct { + MostNested *mostNestedCmd `arg:"subcommand"` + } + + type topLevelCmd struct { + Nested *nestedCmd `arg:"subcommand"` + } + + var args struct { + TopLevel *topLevelCmd `arg:"subcommand"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stdout = os.Stdout + + p, err := NewParser(Config{}, &args) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + err = p.WriteHelpForSubcommand(os.Stdout, "toplevel", "nested", "mostnested") + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + // output: + // Usage: example toplevel nested mostnested [--item ITEM] + // + // Options: + // --item ITEM + // --help, -h display this help and exit +} + +// This example shows the error string generated by go-arg when an invalid option is provided +func Example_errorText() { + // These are the args you would pass in on the command line + os.Args = split("./example --optimize INVALID") + + var args struct { + Input string `arg:"positional,required"` + Output []string `arg:"positional"` + Verbose bool `arg:"-v" help:"verbosity level"` + Dataset string `help:"dataset to use"` + Optimize int `arg:"-O,help:optimization level"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stderr = os.Stdout + + MustParse(&args) + + // output: + // Usage: example [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] INPUT [OUTPUT [OUTPUT ...]] + // error: error processing --optimize: strconv.ParseInt: parsing "INVALID": invalid syntax +} + +// This example shows the error string generated by go-arg when an invalid option is provided +func Example_errorTextForSubcommand() { + // These are the args you would pass in on the command line + os.Args = split("./example get --count INVALID") + + type getCmd struct { + Count int + } + + var args struct { + Get *getCmd `arg:"subcommand"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stderr = os.Stdout + + MustParse(&args) + + // output: + // Usage: example get [--count COUNT] + // error: error processing --count: strconv.ParseInt: parsing "INVALID": invalid syntax +} + +// This example demonstrates use of subcommands +func Example_subcommand() { + // These are the args you would pass in on the command line + os.Args = split("./example commit -a -m what-this-commit-is-about") + + type CheckoutCmd struct { + Branch string `arg:"positional"` + Track bool `arg:"-t"` + } + type CommitCmd struct { + All bool `arg:"-a"` + Message string `arg:"-m"` + } + type PushCmd struct { + Remote string `arg:"positional"` + Branch string `arg:"positional"` + SetUpstream bool `arg:"-u"` + } + var args struct { + Checkout *CheckoutCmd `arg:"subcommand:checkout"` + Commit *CommitCmd `arg:"subcommand:commit"` + Push *PushCmd `arg:"subcommand:push"` + Quiet bool `arg:"-q"` // this flag is global to all subcommands + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stderr = os.Stdout + + MustParse(&args) + + switch { + case args.Checkout != nil: + fmt.Printf("checkout requested for branch %s\n", args.Checkout.Branch) + case args.Commit != nil: + fmt.Printf("commit requested with message \"%s\"\n", args.Commit.Message) + case args.Push != nil: + fmt.Printf("push requested from %s to %s\n", args.Push.Branch, args.Push.Remote) + } + + // output: + // commit requested with message "what-this-commit-is-about" +} + +func Example_allSupportedTypes() { + // These are the args you would pass in on the command line + os.Args = []string{} + + var args struct { + Bool bool + Byte byte + Rune rune + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Float32 float32 + Float64 float64 + String string + Duration time.Duration + URL url.URL + Email mail.Address + MAC net.HardwareAddr + } + + // go-arg supports each of the types above, as well as pointers to any of + // the above and slices of any of the above. It also supports any types that + // implements encoding.TextUnmarshaler. + + MustParse(&args) + + // output: +} diff --git a/v2/go.mod b/v2/go.mod new file mode 100644 index 0000000..7e575a8 --- /dev/null +++ b/v2/go.mod @@ -0,0 +1,8 @@ +module github.com/alexflint/go-arg/v2 + +require ( + github.com/alexflint/go-scalar v1.2.0 + github.com/stretchr/testify v1.7.0 +) + +go 1.13 diff --git a/v2/go.sum b/v2/go.sum new file mode 100644 index 0000000..385ca8f --- /dev/null +++ b/v2/go.sum @@ -0,0 +1,15 @@ +github.com/alexflint/go-scalar v1.2.0 h1:WR7JPKkeNpnYIOfHRa7ivM21aWAdHD0gEWHCx+WQBRw= +github.com/alexflint/go-scalar v1.2.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/v2/parse.go b/v2/parse.go new file mode 100644 index 0000000..8e190f2 --- /dev/null +++ b/v2/parse.go @@ -0,0 +1,741 @@ +package arg + +import ( + "encoding" + "encoding/csv" + "errors" + "fmt" + "os" + "path/filepath" + "reflect" + "strings" + + scalar "github.com/alexflint/go-scalar" +) + +// path represents a sequence of steps to find the output location for an +// argument or subcommand in the final destination struct +type path struct { + fields []reflect.StructField // sequence of struct fields to traverse +} + +// String gets a string representation of the given path +func (p path) String() string { + s := "args" + for _, f := range p.fields { + s += "." + f.Name + } + return s +} + +// Child gets a new path representing a child of this path. +func (p path) Child(f reflect.StructField) path { + // copy the entire slice of fields to avoid possible slice overwrite + subfields := make([]reflect.StructField, len(p.fields)+1) + copy(subfields, p.fields) + subfields[len(subfields)-1] = f + return path{ + fields: subfields, + } +} + +// spec represents a command line option +type spec struct { + dest path + field reflect.StructField // the struct field from which this option was created + long string // the --long form for this option, or empty if none + short string // the -s short form for this option, or empty if none + cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple) + required bool // if true, this option must be present on the command line + positional bool // if true, this option will be looked for in the positional flags + separate bool // if true, each slice and map entry will have its own --flag + help string // the help text for this option + env string // the name of the environment variable for this option, or empty for none + defaultVal string // default value for this option + placeholder string // name of the data in help +} + +// command represents a named subcommand, or the top-level command +type command struct { + name string + help string + dest path + specs []*spec + subcommands []*command + parent *command +} + +// ErrHelp indicates that -h or --help were provided +var ErrHelp = errors.New("help requested by user") + +// ErrVersion indicates that --version was provided +var ErrVersion = errors.New("version requested by user") + +// MustParse processes command line arguments and exits upon failure +func MustParse(dest interface{}) *Parser { + p, err := NewParser(Config{}, dest) + if err != nil { + fmt.Fprintln(stdout, err) + osExit(-1) + return nil // just in case osExit was monkey-patched + } + + err = p.Parse(flags()) + switch { + case err == ErrHelp: + p.writeHelpForSubcommand(stdout, p.lastCmd) + osExit(0) + case err == ErrVersion: + fmt.Fprintln(stdout, p.version) + osExit(0) + case err != nil: + p.failWithSubcommand(err.Error(), p.lastCmd) + } + + return p +} + +// Parse processes command line arguments and stores them in dest +func Parse(dest interface{}) error { + p, err := NewParser(Config{}, dest) + if err != nil { + return err + } + return p.Parse(flags()) +} + +// flags gets all command line arguments other than the first (program name) +func flags() []string { + if len(os.Args) == 0 { // os.Args could be empty + return nil + } + return os.Args[1:] +} + +// Config represents configuration options for an argument parser +type Config struct { + // Program is the name of the program used in the help text + Program string + + // IgnoreEnv instructs the library not to read environment variables + IgnoreEnv bool + + // IgnoreDefault instructs the library not to reset the variables to the + // default values, including pointers to sub commands + IgnoreDefault bool +} + +// Parser represents a set of command line options with destination values +type Parser struct { + cmd *command + root reflect.Value // destination struct to fill will values + config Config + version string + description string + epilogue string + + // the following field changes during processing of command line arguments + lastCmd *command +} + +// Versioned is the interface that the destination struct should implement to +// make a version string appear at the top of the help message. +type Versioned interface { + // Version returns the version string that will be printed on a line by itself + // at the top of the help message. + Version() string +} + +// Described is the interface that the destination struct should implement to +// make a description string appear at the top of the help message. +type Described interface { + // Description returns the string that will be printed on a line by itself + // at the top of the help message. + Description() string +} + +// Epilogued is the interface that the destination struct should implement to +// add an epilogue string at the bottom of the help message. +type Epilogued interface { + // Epilogue returns the string that will be printed on a line by itself + // at the end of the help message. + Epilogue() string +} + +// walkFields calls a function for each field of a struct, recursively expanding struct fields. +func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) { + walkFieldsImpl(t, visit, nil) +} + +func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool, path []int) { + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + field.Index = make([]int, len(path)+1) + copy(field.Index, append(path, i)) + expand := visit(field, t) + if expand && field.Type.Kind() == reflect.Struct { + var subpath []int + if field.Anonymous { + subpath = append(path, i) + } + walkFieldsImpl(field.Type, visit, subpath) + } + } +} + +// NewParser constructs a parser from a list of destination structs +func NewParser(config Config, dest interface{}) (*Parser, error) { + // first pick a name for the command for use in the usage text + var name string + switch { + case config.Program != "": + name = config.Program + case len(os.Args) > 0: + name = filepath.Base(os.Args[0]) + default: + name = "program" + } + + // construct a parser + p := Parser{ + cmd: &command{name: name}, + config: config, + } + + // make a list of roots + p.root = reflect.ValueOf(dest) + + // process each of the destination values + t := reflect.TypeOf(dest) + if t.Kind() != reflect.Ptr { + panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t)) + } + + cmd, err := cmdFromStruct(name, path{}, t) + if err != nil { + return nil, err + } + + // add nonzero field values as defaults + for _, spec := range cmd.specs { + if v := p.val(spec.dest); v.IsValid() && !isZero(v) { + if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok { + str, err := defaultVal.MarshalText() + if err != nil { + return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err) + } + spec.defaultVal = string(str) + } else { + spec.defaultVal = fmt.Sprintf("%v", v) + } + } + } + + p.cmd.specs = append(p.cmd.specs, cmd.specs...) + p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...) + + if dest, ok := dest.(Versioned); ok { + p.version = dest.Version() + } + if dest, ok := dest.(Described); ok { + p.description = dest.Description() + } + if dest, ok := dest.(Epilogued); ok { + p.epilogue = dest.Epilogue() + } + + return &p, nil +} + +func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { + // commands can only be created from pointers to structs + if t.Kind() != reflect.Ptr { + return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s", + dest, t.Kind()) + } + + t = t.Elem() + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a pointer to %s", + dest, t.Kind()) + } + + cmd := command{ + name: name, + dest: dest, + } + + var errs []string + walkFields(t, func(field reflect.StructField, t reflect.Type) bool { + // check for the ignore switch in the tag + tag := field.Tag.Get("arg") + if tag == "-" { + return false + } + + // if this is an embedded struct then recurse into its fields, even if + // it is unexported, because exported fields on unexported embedded + // structs are still writable + if field.Anonymous && field.Type.Kind() == reflect.Struct { + return true + } + + // ignore any other unexported field + if !isExported(field.Name) { + return false + } + + // duplicate the entire path to avoid slice overwrites + subdest := dest.Child(field) + spec := spec{ + dest: subdest, + field: field, + long: strings.ToLower(field.Name), + } + + help, exists := field.Tag.Lookup("help") + if exists { + spec.help = help + } + + defaultVal, hasDefault := field.Tag.Lookup("default") + if hasDefault { + spec.defaultVal = defaultVal + } + + // Look at the tag + var isSubcommand bool // tracks whether this field is a subcommand + for _, key := range strings.Split(tag, ",") { + if key == "" { + continue + } + key = strings.TrimLeft(key, " ") + var value string + if pos := strings.Index(key, ":"); pos != -1 { + value = key[pos+1:] + key = key[:pos] + } + + switch { + case strings.HasPrefix(key, "---"): + errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name)) + case strings.HasPrefix(key, "--"): + spec.long = key[2:] + case strings.HasPrefix(key, "-"): + if len(key) != 2 { + errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", + t.Name(), field.Name)) + return false + } + spec.short = key[1:] + case key == "required": + if hasDefault { + errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified", + t.Name(), field.Name)) + return false + } + spec.required = true + case key == "positional": + spec.positional = true + case key == "separate": + spec.separate = true + case key == "help": // deprecated + spec.help = value + case key == "env": + // Use override name if provided + if value != "" { + spec.env = value + } else { + spec.env = strings.ToUpper(field.Name) + } + case key == "subcommand": + // decide on a name for the subcommand + cmdname := value + if cmdname == "" { + cmdname = strings.ToLower(field.Name) + } + + // parse the subcommand recursively + subcmd, err := cmdFromStruct(cmdname, subdest, field.Type) + if err != nil { + errs = append(errs, err.Error()) + return false + } + + subcmd.parent = &cmd + subcmd.help = field.Tag.Get("help") + + cmd.subcommands = append(cmd.subcommands, subcmd) + isSubcommand = true + default: + errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) + return false + } + } + + placeholder, hasPlaceholder := field.Tag.Lookup("placeholder") + if hasPlaceholder { + spec.placeholder = placeholder + } else if spec.long != "" { + spec.placeholder = strings.ToUpper(spec.long) + } else { + spec.placeholder = strings.ToUpper(spec.field.Name) + } + + // Check whether this field is supported. It's good to do this here rather than + // wait until ParseValue because it means that a program with invalid argument + // fields will always fail regardless of whether the arguments it received + // exercised those fields. + if !isSubcommand { + cmd.specs = append(cmd.specs, &spec) + + var err error + spec.cardinality, err = cardinalityOf(field.Type) + if err != nil { + errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", + t.Name(), field.Name, field.Type.String())) + return false + } + if spec.cardinality == multiple && hasDefault { + errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields", + t.Name(), field.Name)) + return false + } + } + + // if this was an embedded field then we already returned true up above + return false + }) + + if len(errs) > 0 { + return nil, errors.New(strings.Join(errs, "\n")) + } + + // check that we don't have both positionals and subcommands + var hasPositional bool + for _, spec := range cmd.specs { + if spec.positional { + hasPositional = true + } + } + if hasPositional && len(cmd.subcommands) > 0 { + return nil, fmt.Errorf("%s cannot have both subcommands and positional arguments", dest) + } + + return &cmd, nil +} + +// Parse processes the given command line option, storing the results in the field +// of the structs from which NewParser was constructed +func (p *Parser) Parse(args []string) error { + err := p.process(args) + if err != nil { + // If -h or --help were specified then make sure help text supercedes other errors + for _, arg := range args { + if arg == "-h" || arg == "--help" { + return ErrHelp + } + if arg == "--" { + break + } + } + } + return err +} + +// process environment vars for the given arguments +func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error { + for _, spec := range specs { + if spec.env == "" { + continue + } + + value, found := os.LookupEnv(spec.env) + if !found { + continue + } + + if spec.cardinality == multiple { + // expect a CSV string in an environment + // variable in the case of multiple values + var values []string + var err error + if len(strings.TrimSpace(value)) > 0 { + values, err = csv.NewReader(strings.NewReader(value)).Read() + if err != nil { + return fmt.Errorf( + "error reading a CSV string from environment variable %s with multiple values: %v", + spec.env, + err, + ) + } + } + if err = setSliceOrMap(p.val(spec.dest), values, !spec.separate); err != nil { + return fmt.Errorf( + "error processing environment variable %s with multiple values: %v", + spec.env, + err, + ) + } + } else { + if err := scalar.ParseValue(p.val(spec.dest), value); err != nil { + return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) + } + } + wasPresent[spec] = true + } + + return nil +} + +// process goes through arguments one-by-one, parses them, and assigns the result to +// the underlying struct field +func (p *Parser) process(args []string) error { + // track the options we have seen + wasPresent := make(map[*spec]bool) + + // union of specs for the chain of subcommands encountered so far + curCmd := p.cmd + p.lastCmd = curCmd + + // make a copy of the specs because we will add to this list each time we expand a subcommand + specs := make([]*spec, len(curCmd.specs)) + copy(specs, curCmd.specs) + + // deal with environment vars + if !p.config.IgnoreEnv { + err := p.captureEnvVars(specs, wasPresent) + if err != nil { + return err + } + } + + // process each string from the command line + var allpositional bool + var positionals []string + + // must use explicit for loop, not range, because we manipulate i inside the loop + for i := 0; i < len(args); i++ { + arg := args[i] + if arg == "--" { + allpositional = true + continue + } + + if !isFlag(arg) || allpositional { + // each subcommand can have either subcommands or positionals, but not both + if len(curCmd.subcommands) == 0 { + positionals = append(positionals, arg) + continue + } + + // if we have a subcommand then make sure it is valid for the current context + subcmd := findSubcommand(curCmd.subcommands, arg) + if subcmd == nil { + return fmt.Errorf("invalid subcommand: %s", arg) + } + + // instantiate the field to point to a new struct + v := p.val(subcmd.dest) + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers + } + + // add the new options to the set of allowed options + specs = append(specs, subcmd.specs...) + + // capture environment vars for these new options + if !p.config.IgnoreEnv { + err := p.captureEnvVars(subcmd.specs, wasPresent) + if err != nil { + return err + } + } + + curCmd = subcmd + p.lastCmd = curCmd + continue + } + + // check for special --help and --version flags + switch arg { + case "-h", "--help": + return ErrHelp + case "--version": + return ErrVersion + } + + // check for an equals sign, as in "--foo=bar" + var value string + opt := strings.TrimLeft(arg, "-") + if pos := strings.Index(opt, "="); pos != -1 { + value = opt[pos+1:] + opt = opt[:pos] + } + + // lookup the spec for this option (note that the "specs" slice changes as + // we expand subcommands so it is better not to use a map) + spec := findOption(specs, opt) + if spec == nil { + return fmt.Errorf("unknown argument %s", arg) + } + wasPresent[spec] = true + + // deal with the case of multiple values + if spec.cardinality == multiple { + var values []string + if value == "" { + for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { + values = append(values, args[i+1]) + i++ + if spec.separate { + break + } + } + } else { + values = append(values, value) + } + err := setSliceOrMap(p.val(spec.dest), values, !spec.separate) + if err != nil { + return fmt.Errorf("error processing %s: %v", arg, err) + } + continue + } + + // if it's a flag and it has no value then set the value to true + // use boolean because this takes account of TextUnmarshaler + if spec.cardinality == zero && value == "" { + value = "true" + } + + // if we have something like "--foo" then the value is the next argument + if value == "" { + if i+1 == len(args) { + return fmt.Errorf("missing value for %s", arg) + } + if !nextIsNumeric(spec.field.Type, args[i+1]) && isFlag(args[i+1]) { + return fmt.Errorf("missing value for %s", arg) + } + value = args[i+1] + i++ + } + + err := scalar.ParseValue(p.val(spec.dest), value) + if err != nil { + return fmt.Errorf("error processing %s: %v", arg, err) + } + } + + // process positionals + for _, spec := range specs { + if !spec.positional { + continue + } + if len(positionals) == 0 { + break + } + wasPresent[spec] = true + if spec.cardinality == multiple { + err := setSliceOrMap(p.val(spec.dest), positionals, true) + if err != nil { + return fmt.Errorf("error processing %s: %v", spec.field.Name, err) + } + positionals = nil + } else { + err := scalar.ParseValue(p.val(spec.dest), positionals[0]) + if err != nil { + return fmt.Errorf("error processing %s: %v", spec.field.Name, err) + } + positionals = positionals[1:] + } + } + if len(positionals) > 0 { + return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) + } + + // fill in defaults and check that all the required args were provided + for _, spec := range specs { + if wasPresent[spec] { + continue + } + + name := strings.ToLower(spec.field.Name) + if spec.long != "" && !spec.positional { + name = "--" + spec.long + } + + if spec.required { + msg := fmt.Sprintf("%s is required", name) + if spec.env != "" { + msg += " (or environment variable " + spec.env + ")" + } + return errors.New(msg) + } + if !p.config.IgnoreDefault && spec.defaultVal != "" { + err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal) + if err != nil { + return fmt.Errorf("error processing default value for %s: %v", name, err) + } + } + } + + return nil +} + +func nextIsNumeric(t reflect.Type, s string) bool { + switch t.Kind() { + case reflect.Ptr: + return nextIsNumeric(t.Elem(), s) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + v := reflect.New(t) + err := scalar.ParseValue(v, s) + return err == nil + default: + return false + } +} + +// isFlag returns true if a token is a flag such as "-v" or "--user" but not "-" or "--" +func isFlag(s string) bool { + return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != "" +} + +// val returns a reflect.Value corresponding to the current value for the +// given path +func (p *Parser) val(dest path) reflect.Value { + v := p.root + for _, field := range dest.fields { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return reflect.Value{} + } + v = v.Elem() + } + + v = v.FieldByIndex(field.Index) + } + return v +} + +// findOption finds an option from its name, or returns null if no spec is found +func findOption(specs []*spec, name string) *spec { + for _, spec := range specs { + if spec.positional { + continue + } + if spec.long == name || spec.short == name { + return spec + } + } + return nil +} + +// findSubcommand finds a subcommand using its name, or returns null if no subcommand is found +func findSubcommand(cmds []*command, name string) *command { + for _, cmd := range cmds { + if cmd.name == name { + return cmd + } + } + return nil +} diff --git a/v2/parse_test.go b/v2/parse_test.go new file mode 100644 index 0000000..4ea6bc4 --- /dev/null +++ b/v2/parse_test.go @@ -0,0 +1,1486 @@ +package arg + +import ( + "bytes" + "fmt" + "net" + "net/mail" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setenv(t *testing.T, name, val string) { + if err := os.Setenv(name, val); err != nil { + t.Error(err) + } +} + +func parse(cmdline string, dest interface{}) error { + _, err := pparse(cmdline, dest) + return err +} + +func pparse(cmdline string, dest interface{}) (*Parser, error) { + return parseWithEnv(cmdline, nil, dest) +} + +func parseWithEnv(cmdline string, env []string, dest interface{}) (*Parser, error) { + p, err := NewParser(Config{}, dest) + if err != nil { + return nil, err + } + + // split the command line + var parts []string + if len(cmdline) > 0 { + parts = strings.Split(cmdline, " ") + } + + // split the environment vars + for _, s := range env { + pos := strings.Index(s, "=") + if pos == -1 { + return nil, fmt.Errorf("missing equals sign in %q", s) + } + err := os.Setenv(s[:pos], s[pos+1:]) + if err != nil { + return nil, err + } + } + + // execute the parser + return p, p.Parse(parts) +} + +func TestString(t *testing.T) { + var args struct { + Foo string + Ptr *string + } + err := parse("--foo bar --ptr baz", &args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) + assert.Equal(t, "baz", *args.Ptr) +} + +func TestBool(t *testing.T) { + var args struct { + A bool + B bool + C *bool + D *bool + } + err := parse("--a --c", &args) + require.NoError(t, err) + assert.True(t, args.A) + assert.False(t, args.B) + assert.True(t, *args.C) + assert.Nil(t, args.D) +} + +func TestInt(t *testing.T) { + var args struct { + Foo int + Ptr *int + } + err := parse("--foo 7 --ptr 8", &args) + require.NoError(t, err) + assert.EqualValues(t, 7, args.Foo) + assert.EqualValues(t, 8, *args.Ptr) +} + +func TestHexOctBin(t *testing.T) { + var args struct { + Hex int + Oct int + Bin int + Underscored int + } + err := parse("--hex 0xA --oct 0o10 --bin 0b101 --underscored 123_456", &args) + require.NoError(t, err) + assert.EqualValues(t, 10, args.Hex) + assert.EqualValues(t, 8, args.Oct) + assert.EqualValues(t, 5, args.Bin) + assert.EqualValues(t, 123456, args.Underscored) +} + +func TestNegativeInt(t *testing.T) { + var args struct { + Foo int + } + err := parse("-foo -100", &args) + require.NoError(t, err) + assert.EqualValues(t, args.Foo, -100) +} + +func TestNegativeIntAndFloatAndTricks(t *testing.T) { + var args struct { + Foo int + Bar float64 + N int `arg:"--100"` + } + err := parse("-foo -100 -bar -60.14 -100 -100", &args) + require.NoError(t, err) + assert.EqualValues(t, args.Foo, -100) + assert.EqualValues(t, args.Bar, -60.14) + assert.EqualValues(t, args.N, -100) +} + +func TestUint(t *testing.T) { + var args struct { + Foo uint + Ptr *uint + } + err := parse("--foo 7 --ptr 8", &args) + require.NoError(t, err) + assert.EqualValues(t, 7, args.Foo) + assert.EqualValues(t, 8, *args.Ptr) +} + +func TestFloat(t *testing.T) { + var args struct { + Foo float32 + Ptr *float32 + } + err := parse("--foo 3.4 --ptr 3.5", &args) + require.NoError(t, err) + assert.EqualValues(t, 3.4, args.Foo) + assert.EqualValues(t, 3.5, *args.Ptr) +} + +func TestDuration(t *testing.T) { + var args struct { + Foo time.Duration + Ptr *time.Duration + } + err := parse("--foo 3ms --ptr 4ms", &args) + require.NoError(t, err) + assert.Equal(t, 3*time.Millisecond, args.Foo) + assert.Equal(t, 4*time.Millisecond, *args.Ptr) +} + +func TestInvalidDuration(t *testing.T) { + var args struct { + Foo time.Duration + } + err := parse("--foo xxx", &args) + require.Error(t, err) +} + +func TestIntPtr(t *testing.T) { + var args struct { + Foo *int + } + err := parse("--foo 123", &args) + require.NoError(t, err) + require.NotNil(t, args.Foo) + assert.Equal(t, 123, *args.Foo) +} + +func TestIntPtrNotPresent(t *testing.T) { + var args struct { + Foo *int + } + err := parse("", &args) + require.NoError(t, err) + assert.Nil(t, args.Foo) +} + +func TestMixed(t *testing.T) { + var args struct { + Foo string `arg:"-f"` + Bar int + Baz uint `arg:"positional"` + Ham bool + Spam float32 + } + args.Bar = 3 + err := parse("123 -spam=1.2 -ham -f xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) + assert.Equal(t, 3, args.Bar) + assert.Equal(t, uint(123), args.Baz) + assert.Equal(t, true, args.Ham) + assert.EqualValues(t, 1.2, args.Spam) +} + +func TestRequired(t *testing.T) { + var args struct { + Foo string `arg:"required"` + } + err := parse("", &args) + require.Error(t, err, "--foo is required") +} + +func TestRequiredWithEnv(t *testing.T) { + var args struct { + Foo string `arg:"required,env:FOO"` + } + err := parse("", &args) + require.Error(t, err, "--foo is required (or environment variable FOO)") +} + +func TestShortFlag(t *testing.T) { + var args struct { + Foo string `arg:"-f"` + } + + err := parse("-f xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) + + err = parse("-foo xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) + + err = parse("--foo xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) +} + +func TestInvalidShortFlag(t *testing.T) { + var args struct { + Foo string `arg:"-foo"` + } + err := parse("", &args) + assert.Error(t, err) +} + +func TestLongFlag(t *testing.T) { + var args struct { + Foo string `arg:"--abc"` + } + + err := parse("-abc xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) + + err = parse("--abc xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) +} + +func TestSlice(t *testing.T) { + var args struct { + Strings []string + } + err := parse("--strings a b c", &args) + require.NoError(t, err) + assert.Equal(t, []string{"a", "b", "c"}, args.Strings) +} +func TestSliceOfBools(t *testing.T) { + var args struct { + B []bool + } + + err := parse("--b true false true", &args) + require.NoError(t, err) + assert.Equal(t, []bool{true, false, true}, args.B) +} + +func TestMap(t *testing.T) { + var args struct { + Values map[string]int + } + err := parse("--values a=1 b=2 c=3", &args) + require.NoError(t, err) + assert.Len(t, args.Values, 3) + assert.Equal(t, 1, args.Values["a"]) + assert.Equal(t, 2, args.Values["b"]) + assert.Equal(t, 3, args.Values["c"]) +} + +func TestMapPositional(t *testing.T) { + var args struct { + Values map[string]int `arg:"positional"` + } + err := parse("a=1 b=2 c=3", &args) + require.NoError(t, err) + assert.Len(t, args.Values, 3) + assert.Equal(t, 1, args.Values["a"]) + assert.Equal(t, 2, args.Values["b"]) + assert.Equal(t, 3, args.Values["c"]) +} + +func TestMapWithSeparate(t *testing.T) { + var args struct { + Values map[string]int `arg:"separate"` + } + err := parse("--values a=1 --values b=2 --values c=3", &args) + require.NoError(t, err) + assert.Len(t, args.Values, 3) + assert.Equal(t, 1, args.Values["a"]) + assert.Equal(t, 2, args.Values["b"]) + assert.Equal(t, 3, args.Values["c"]) +} + +func TestPlaceholder(t *testing.T) { + var args struct { + Input string `arg:"positional" placeholder:"SRC"` + Output []string `arg:"positional" placeholder:"DST"` + Optimize int `arg:"-O" placeholder:"LEVEL"` + MaxJobs int `arg:"-j" placeholder:"N"` + } + err := parse("-O 5 --maxjobs 2 src dest1 dest2", &args) + assert.NoError(t, err) +} + +func TestNoLongName(t *testing.T) { + var args struct { + ShortOnly string `arg:"-s,--"` + EnvOnly string `arg:"--,env"` + } + setenv(t, "ENVONLY", "TestVal") + err := parse("-s TestVal2", &args) + assert.NoError(t, err) + assert.Equal(t, "TestVal", args.EnvOnly) + assert.Equal(t, "TestVal2", args.ShortOnly) +} + +func TestCaseSensitive(t *testing.T) { + var args struct { + Lower bool `arg:"-v"` + Upper bool `arg:"-V"` + } + + err := parse("-v", &args) + require.NoError(t, err) + assert.True(t, args.Lower) + assert.False(t, args.Upper) +} + +func TestCaseSensitive2(t *testing.T) { + var args struct { + Lower bool `arg:"-v"` + Upper bool `arg:"-V"` + } + + err := parse("-V", &args) + require.NoError(t, err) + assert.False(t, args.Lower) + assert.True(t, args.Upper) +} + +func TestPositional(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Output string `arg:"positional"` + } + err := parse("foo", &args) + require.NoError(t, err) + assert.Equal(t, "foo", args.Input) + assert.Equal(t, "", args.Output) +} + +func TestPositionalPointer(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Output []*string `arg:"positional"` + } + err := parse("foo bar baz", &args) + require.NoError(t, err) + assert.Equal(t, "foo", args.Input) + bar := "bar" + baz := "baz" + assert.Equal(t, []*string{&bar, &baz}, args.Output) +} + +func TestRequiredPositional(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Output string `arg:"positional,required"` + } + err := parse("foo", &args) + assert.Error(t, err) +} + +func TestRequiredPositionalMultiple(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Multiple []string `arg:"positional,required"` + } + err := parse("foo", &args) + assert.Error(t, err) +} + +func TestTooManyPositional(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Output string `arg:"positional"` + } + err := parse("foo bar baz", &args) + assert.Error(t, err) +} + +func TestMultiple(t *testing.T) { + var args struct { + Foo []int + Bar []string + } + err := parse("--foo 1 2 3 --bar x y z", &args) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, args.Foo) + assert.Equal(t, []string{"x", "y", "z"}, args.Bar) +} + +func TestMultiplePositionals(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Multiple []string `arg:"positional,required"` + } + err := parse("foo a b c", &args) + assert.NoError(t, err) + assert.Equal(t, "foo", args.Input) + assert.Equal(t, []string{"a", "b", "c"}, args.Multiple) +} + +func TestMultipleWithEq(t *testing.T) { + var args struct { + Foo []int + Bar []string + } + err := parse("--foo 1 2 3 --bar=x", &args) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, args.Foo) + assert.Equal(t, []string{"x"}, args.Bar) +} + +func TestMultipleWithDefault(t *testing.T) { + var args struct { + Foo []int + Bar []string + } + args.Foo = []int{42} + args.Bar = []string{"foo"} + err := parse("--foo 1 2 3 --bar x y z", &args) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, args.Foo) + assert.Equal(t, []string{"x", "y", "z"}, args.Bar) +} + +func TestExemptField(t *testing.T) { + var args struct { + Foo string + Bar interface{} `arg:"-"` + } + err := parse("--foo xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) +} + +func TestUnknownField(t *testing.T) { + var args struct { + Foo string + } + err := parse("--bar xyz", &args) + assert.Error(t, err) +} + +func TestMissingRequired(t *testing.T) { + var args struct { + Foo string `arg:"required"` + X []string `arg:"positional"` + } + err := parse("x", &args) + assert.Error(t, err) +} + +func TestNonsenseKey(t *testing.T) { + var args struct { + X []string `arg:"positional, nonsense"` + } + err := parse("x", &args) + assert.Error(t, err) +} + +func TestMissingValueAtEnd(t *testing.T) { + var args struct { + Foo string + } + err := parse("--foo", &args) + assert.Error(t, err) +} + +func TestMissingValueInMiddle(t *testing.T) { + var args struct { + Foo string + Bar string + } + err := parse("--foo --bar=abc", &args) + assert.Error(t, err) +} + +func TestNegativeValue(t *testing.T) { + var args struct { + Foo int + } + err := parse("--foo -123", &args) + require.NoError(t, err) + assert.Equal(t, -123, args.Foo) +} + +func TestInvalidInt(t *testing.T) { + var args struct { + Foo int + } + err := parse("--foo=xyz", &args) + assert.Error(t, err) +} + +func TestInvalidUint(t *testing.T) { + var args struct { + Foo uint + } + err := parse("--foo=xyz", &args) + assert.Error(t, err) +} + +func TestInvalidFloat(t *testing.T) { + var args struct { + Foo float64 + } + err := parse("--foo xyz", &args) + require.Error(t, err) +} + +func TestInvalidBool(t *testing.T) { + var args struct { + Foo bool + } + err := parse("--foo=xyz", &args) + require.Error(t, err) +} + +func TestInvalidIntSlice(t *testing.T) { + var args struct { + Foo []int + } + err := parse("--foo 1 2 xyz", &args) + require.Error(t, err) +} + +func TestInvalidPositional(t *testing.T) { + var args struct { + Foo int `arg:"positional"` + } + err := parse("xyz", &args) + require.Error(t, err) +} + +func TestInvalidPositionalSlice(t *testing.T) { + var args struct { + Foo []int `arg:"positional"` + } + err := parse("1 2 xyz", &args) + require.Error(t, err) +} + +func TestNoMoreOptions(t *testing.T) { + var args struct { + Foo string + Bar []string `arg:"positional"` + } + err := parse("abc -- --foo xyz", &args) + require.NoError(t, err) + assert.Equal(t, "", args.Foo) + assert.Equal(t, []string{"abc", "--foo", "xyz"}, args.Bar) +} + +func TestNoMoreOptionsBeforeHelp(t *testing.T) { + var args struct { + Foo int + } + err := parse("not_an_integer -- --help", &args) + assert.NotEqual(t, ErrHelp, err) +} + +func TestHelpFlag(t *testing.T) { + var args struct { + Foo string + Bar interface{} `arg:"-"` + } + err := parse("--help", &args) + assert.Equal(t, ErrHelp, err) +} + +func TestPanicOnNonPointer(t *testing.T) { + var args struct{} + assert.Panics(t, func() { + _ = parse("", args) + }) +} + +func TestErrorOnNonStruct(t *testing.T) { + var args string + err := parse("", &args) + assert.Error(t, err) +} + +func TestUnsupportedType(t *testing.T) { + var args struct { + Foo interface{} + } + err := parse("--foo", &args) + assert.Error(t, err) +} + +func TestUnsupportedSliceElement(t *testing.T) { + var args struct { + Foo []interface{} + } + err := parse("--foo 3", &args) + assert.Error(t, err) +} + +func TestUnsupportedSliceElementMissingValue(t *testing.T) { + var args struct { + Foo []interface{} + } + err := parse("--foo", &args) + assert.Error(t, err) +} + +func TestUnknownTag(t *testing.T) { + var args struct { + Foo string `arg:"this_is_not_valid"` + } + err := parse("--foo xyz", &args) + assert.Error(t, err) +} + +func TestParse(t *testing.T) { + var args struct { + Foo string + } + os.Args = []string{"example", "--foo", "bar"} + err := Parse(&args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + +func TestParseError(t *testing.T) { + var args struct { + Foo string `arg:"this_is_not_valid"` + } + os.Args = []string{"example", "--bar"} + err := Parse(&args) + assert.Error(t, err) +} + +func TestMustParse(t *testing.T) { + var args struct { + Foo string + } + os.Args = []string{"example", "--foo", "bar"} + parser := MustParse(&args) + assert.Equal(t, "bar", args.Foo) + assert.NotNil(t, parser) +} + +func TestEnvironmentVariable(t *testing.T) { + var args struct { + Foo string `arg:"env"` + } + _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + +func TestEnvironmentVariableNotPresent(t *testing.T) { + var args struct { + NotPresent string `arg:"env"` + } + _, err := parseWithEnv("", nil, &args) + require.NoError(t, err) + assert.Equal(t, "", args.NotPresent) +} + +func TestEnvironmentVariableOverrideName(t *testing.T) { + var args struct { + Foo string `arg:"env:BAZ"` + } + _, err := parseWithEnv("", []string{"BAZ=bar"}, &args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + +func TestEnvironmentVariableOverrideArgument(t *testing.T) { + var args struct { + Foo string `arg:"env"` + } + _, err := parseWithEnv("--foo zzz", []string{"FOO=bar"}, &args) + require.NoError(t, err) + assert.Equal(t, "zzz", args.Foo) +} + +func TestEnvironmentVariableError(t *testing.T) { + var args struct { + Foo int `arg:"env"` + } + _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + assert.Error(t, err) +} + +func TestEnvironmentVariableRequired(t *testing.T) { + var args struct { + Foo string `arg:"env,required"` + } + _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + +func TestEnvironmentVariableSliceArgumentString(t *testing.T) { + var args struct { + Foo []string `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=bar,"baz, qux"`}, &args) + require.NoError(t, err) + assert.Equal(t, []string{"bar", "baz, qux"}, args.Foo) +} + +func TestEnvironmentVariableSliceEmpty(t *testing.T) { + var args struct { + Foo []string `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=`}, &args) + require.NoError(t, err) + assert.Len(t, args.Foo, 0) +} + +func TestEnvironmentVariableSliceArgumentInteger(t *testing.T) { + var args struct { + Foo []int `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=1,99`}, &args) + require.NoError(t, err) + assert.Equal(t, []int{1, 99}, args.Foo) +} + +func TestEnvironmentVariableSliceArgumentFloat(t *testing.T) { + var args struct { + Foo []float32 `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=1.1,99.9`}, &args) + require.NoError(t, err) + assert.Equal(t, []float32{1.1, 99.9}, args.Foo) +} + +func TestEnvironmentVariableSliceArgumentBool(t *testing.T) { + var args struct { + Foo []bool `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=true,false,0,1`}, &args) + require.NoError(t, err) + assert.Equal(t, []bool{true, false, false, true}, args.Foo) +} + +func TestEnvironmentVariableSliceArgumentWrongCsv(t *testing.T) { + var args struct { + Foo []int `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=1,99\"`}, &args) + assert.Error(t, err) +} + +func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) { + var args struct { + Foo []bool `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=one,two`}, &args) + assert.Error(t, err) +} + +func TestEnvironmentVariableMap(t *testing.T) { + var args struct { + Foo map[int]string `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=1=one,99=ninetynine`}, &args) + require.NoError(t, err) + assert.Len(t, args.Foo, 2) + assert.Equal(t, "one", args.Foo[1]) + assert.Equal(t, "ninetynine", args.Foo[99]) +} + +func TestEnvironmentVariableEmptyMap(t *testing.T) { + var args struct { + Foo map[int]string `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=`}, &args) + require.NoError(t, err) + assert.Len(t, args.Foo, 0) +} + +func TestEnvironmentVariableIgnored(t *testing.T) { + var args struct { + Foo string `arg:"env"` + } + setenv(t, "FOO", "abc") + + p, err := NewParser(Config{IgnoreEnv: true}, &args) + require.NoError(t, err) + + err = p.Parse(nil) + assert.NoError(t, err) + assert.Equal(t, "", args.Foo) +} + +func TestDefaultValuesIgnored(t *testing.T) { + var args struct { + Foo string `default:"bad"` + } + + p, err := NewParser(Config{IgnoreDefault: true}, &args) + require.NoError(t, err) + + err = p.Parse(nil) + assert.NoError(t, err) + assert.Equal(t, "", args.Foo) +} + +func TestEnvironmentVariableInSubcommandIgnored(t *testing.T) { + var args struct { + Sub *struct { + Foo string `arg:"env"` + } `arg:"subcommand"` + } + setenv(t, "FOO", "abc") + + p, err := NewParser(Config{IgnoreEnv: true}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"sub"}) + assert.NoError(t, err) + assert.Equal(t, "", args.Sub.Foo) +} + +type textUnmarshaler struct { + val int +} + +func (f *textUnmarshaler) UnmarshalText(b []byte) error { + f.val = len(b) + return nil +} + +func TestTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo textUnmarshaler + } + err := parse("--foo abc", &args) + require.NoError(t, err) + assert.Equal(t, 3, args.Foo.val) +} + +func TestPtrToTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo *textUnmarshaler + } + err := parse("--foo abc", &args) + require.NoError(t, err) + assert.Equal(t, 3, args.Foo.val) +} + +func TestRepeatedTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo []textUnmarshaler + } + err := parse("--foo abc d ef", &args) + require.NoError(t, err) + require.Len(t, args.Foo, 3) + assert.Equal(t, 3, args.Foo[0].val) + assert.Equal(t, 1, args.Foo[1].val) + assert.Equal(t, 2, args.Foo[2].val) +} + +func TestRepeatedPtrToTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo []*textUnmarshaler + } + err := parse("--foo abc d ef", &args) + require.NoError(t, err) + require.Len(t, args.Foo, 3) + assert.Equal(t, 3, args.Foo[0].val) + assert.Equal(t, 1, args.Foo[1].val) + assert.Equal(t, 2, args.Foo[2].val) +} + +func TestPositionalTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo []textUnmarshaler `arg:"positional"` + } + err := parse("abc d ef", &args) + require.NoError(t, err) + require.Len(t, args.Foo, 3) + assert.Equal(t, 3, args.Foo[0].val) + assert.Equal(t, 1, args.Foo[1].val) + assert.Equal(t, 2, args.Foo[2].val) +} + +func TestPositionalPtrToTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo []*textUnmarshaler `arg:"positional"` + } + err := parse("abc d ef", &args) + require.NoError(t, err) + require.Len(t, args.Foo, 3) + assert.Equal(t, 3, args.Foo[0].val) + assert.Equal(t, 1, args.Foo[1].val) + assert.Equal(t, 2, args.Foo[2].val) +} + +type boolUnmarshaler bool + +func (p *boolUnmarshaler) UnmarshalText(b []byte) error { + *p = len(b)%2 == 0 + return nil +} + +func TestBoolUnmarhsaler(t *testing.T) { + // test that a bool type that implements TextUnmarshaler is + // handled as a TextUnmarshaler not as a bool + var args struct { + Foo *boolUnmarshaler + } + err := parse("--foo ab", &args) + require.NoError(t, err) + assert.EqualValues(t, true, *args.Foo) +} + +type sliceUnmarshaler []int + +func (p *sliceUnmarshaler) UnmarshalText(b []byte) error { + *p = sliceUnmarshaler{len(b)} + return nil +} + +func TestSliceUnmarhsaler(t *testing.T) { + // test that a slice type that implements TextUnmarshaler is + // handled as a TextUnmarshaler not as a slice + var args struct { + Foo *sliceUnmarshaler + Bar string `arg:"positional"` + } + err := parse("--foo abcde xyz", &args) + require.NoError(t, err) + require.Len(t, *args.Foo, 1) + assert.EqualValues(t, 5, (*args.Foo)[0]) + assert.Equal(t, "xyz", args.Bar) +} + +func TestIP(t *testing.T) { + var args struct { + Host net.IP + } + err := parse("--host 192.168.0.1", &args) + require.NoError(t, err) + assert.Equal(t, "192.168.0.1", args.Host.String()) +} + +func TestPtrToIP(t *testing.T) { + var args struct { + Host *net.IP + } + err := parse("--host 192.168.0.1", &args) + require.NoError(t, err) + assert.Equal(t, "192.168.0.1", args.Host.String()) +} + +func TestURL(t *testing.T) { + var args struct { + URL url.URL + } + err := parse("--url https://example.com/get?item=xyz", &args) + require.NoError(t, err) + assert.Equal(t, "https://example.com/get?item=xyz", args.URL.String()) +} + +func TestPtrToURL(t *testing.T) { + var args struct { + URL *url.URL + } + err := parse("--url http://example.com/#xyz", &args) + require.NoError(t, err) + assert.Equal(t, "http://example.com/#xyz", args.URL.String()) +} + +func TestIPSlice(t *testing.T) { + var args struct { + Host []net.IP + } + err := parse("--host 192.168.0.1 127.0.0.1", &args) + require.NoError(t, err) + require.Len(t, args.Host, 2) + assert.Equal(t, "192.168.0.1", args.Host[0].String()) + assert.Equal(t, "127.0.0.1", args.Host[1].String()) +} + +func TestInvalidIPAddress(t *testing.T) { + var args struct { + Host net.IP + } + err := parse("--host xxx", &args) + assert.Error(t, err) +} + +func TestMAC(t *testing.T) { + var args struct { + Host net.HardwareAddr + } + err := parse("--host 0123.4567.89ab", &args) + require.NoError(t, err) + assert.Equal(t, "01:23:45:67:89:ab", args.Host.String()) +} + +func TestInvalidMac(t *testing.T) { + var args struct { + Host net.HardwareAddr + } + err := parse("--host xxx", &args) + assert.Error(t, err) +} + +func TestMailAddr(t *testing.T) { + var args struct { + Recipient mail.Address + } + err := parse("--recipient foo@example.com", &args) + require.NoError(t, err) + assert.Equal(t, "", args.Recipient.String()) +} + +func TestInvalidMailAddr(t *testing.T) { + var args struct { + Recipient mail.Address + } + err := parse("--recipient xxx", &args) + assert.Error(t, err) +} + +type A struct { + X string +} + +type B struct { + Y int +} + +func TestEmbedded(t *testing.T) { + var args struct { + A + B + Z bool + } + err := parse("--x=hello --y=321 --z", &args) + require.NoError(t, err) + assert.Equal(t, "hello", args.X) + assert.Equal(t, 321, args.Y) + assert.Equal(t, true, args.Z) +} + +func TestEmbeddedPtr(t *testing.T) { + // embedded pointer fields are not supported so this should return an error + var args struct { + *A + } + err := parse("--x=hello", &args) + require.Error(t, err) +} + +func TestEmbeddedPtrIgnored(t *testing.T) { + // embedded pointer fields are not normally supported but here + // we explicitly exclude it so the non-nil embedded structs + // should work as expected + var args struct { + *A `arg:"-"` + B + } + err := parse("--y=321", &args) + require.NoError(t, err) + assert.Equal(t, 321, args.Y) +} + +func TestEmbeddedWithDuplicateField(t *testing.T) { + // see https://github.com/alexflint/go-arg/issues/100 + type T struct { + A string `arg:"--cat"` + } + type U struct { + A string `arg:"--dog"` + } + var args struct { + T + U + } + + err := parse("--cat=cat --dog=dog", &args) + require.NoError(t, err) + assert.Equal(t, "cat", args.T.A) + assert.Equal(t, "dog", args.U.A) +} + +func TestEmbeddedWithDuplicateField2(t *testing.T) { + // see https://github.com/alexflint/go-arg/issues/100 + type T struct { + A string + } + type U struct { + A string + } + var args struct { + T + U + } + + err := parse("--a=xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.T.A) + assert.Equal(t, "", args.U.A) +} + +func TestUnexportedEmbedded(t *testing.T) { + type embeddedArgs struct { + Foo string + } + var args struct { + embeddedArgs + } + err := parse("--foo bar", &args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + +func TestIgnoredEmbedded(t *testing.T) { + type embeddedArgs struct { + Foo string + } + var args struct { + embeddedArgs `arg:"-"` + } + err := parse("--foo bar", &args) + require.Error(t, err) +} + +func TestEmptyArgs(t *testing.T) { + origArgs := os.Args + + // test what happens if somehow os.Args is empty + os.Args = nil + var args struct { + Foo string + } + MustParse(&args) + + // put the original arguments back + os.Args = origArgs +} + +func TestTooManyHyphens(t *testing.T) { + var args struct { + TooManyHyphens string `arg:"---x"` + } + err := parse("--foo -", &args) + assert.Error(t, err) +} + +func TestHyphenAsOption(t *testing.T) { + var args struct { + Foo string + } + err := parse("--foo -", &args) + require.NoError(t, err) + assert.Equal(t, "-", args.Foo) +} + +func TestHyphenAsPositional(t *testing.T) { + var args struct { + Foo string `arg:"positional"` + } + err := parse("-", &args) + require.NoError(t, err) + assert.Equal(t, "-", args.Foo) +} + +func TestHyphenInMultiOption(t *testing.T) { + var args struct { + Foo []string + Bar int + } + err := parse("--foo --- x - y --bar 3", &args) + require.NoError(t, err) + assert.Equal(t, []string{"---", "x", "-", "y"}, args.Foo) + assert.Equal(t, 3, args.Bar) +} + +func TestHyphenInMultiPositional(t *testing.T) { + var args struct { + Foo []string `arg:"positional"` + } + err := parse("--- x - y", &args) + require.NoError(t, err) + assert.Equal(t, []string{"---", "x", "-", "y"}, args.Foo) +} + +func TestSeparate(t *testing.T) { + for _, val := range []string{"-f one", "-f=one", "--foo one", "--foo=one"} { + var args struct { + Foo []string `arg:"--foo,-f,separate"` + } + + err := parse(val, &args) + require.NoError(t, err) + assert.Equal(t, []string{"one"}, args.Foo) + } +} + +func TestSeparateWithDefault(t *testing.T) { + args := struct { + Foo []string `arg:"--foo,-f,separate"` + }{ + Foo: []string{"default"}, + } + + err := parse("-f one -f=two", &args) + require.NoError(t, err) + assert.Equal(t, []string{"default", "one", "two"}, args.Foo) +} + +func TestSeparateWithPositional(t *testing.T) { + var args struct { + Foo []string `arg:"--foo,-f,separate"` + Bar string `arg:"positional"` + Moo string `arg:"positional"` + } + + err := parse("zzz --foo one -f=two --foo=three -f four aaa", &args) + require.NoError(t, err) + assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo) + assert.Equal(t, "zzz", args.Bar) + assert.Equal(t, "aaa", args.Moo) +} + +func TestSeparatePositionalInterweaved(t *testing.T) { + var args struct { + Foo []string `arg:"--foo,-f,separate"` + Bar []string `arg:"--bar,-b,separate"` + Pre string `arg:"positional"` + Post []string `arg:"positional"` + } + + err := parse("zzz -f foo1 -b=bar1 --foo=foo2 -b bar2 post1 -b bar3 post2 post3", &args) + require.NoError(t, err) + assert.Equal(t, []string{"foo1", "foo2"}, args.Foo) + assert.Equal(t, []string{"bar1", "bar2", "bar3"}, args.Bar) + assert.Equal(t, "zzz", args.Pre) + assert.Equal(t, []string{"post1", "post2", "post3"}, args.Post) +} + +func TestSpacesAllowedInTags(t *testing.T) { + var args struct { + Foo []string `arg:"--foo, -f, separate, required, help:quite nice really"` + } + + err := parse("--foo one -f=two --foo=three -f four", &args) + require.NoError(t, err) + assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo) +} + +func TestReuseParser(t *testing.T) { + var args struct { + Foo string `arg:"required"` + } + + p, err := NewParser(Config{}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"--foo=abc"}) + require.NoError(t, err) + assert.Equal(t, args.Foo, "abc") + + err = p.Parse([]string{}) + assert.Error(t, err) +} + +func TestVersion(t *testing.T) { + var args struct{} + err := parse("--version", &args) + assert.Equal(t, ErrVersion, err) + +} + +func TestMultipleTerminates(t *testing.T) { + var args struct { + X []string + Y string `arg:"positional"` + } + + err := parse("--x a b -- c", &args) + require.NoError(t, err) + assert.Equal(t, []string{"a", "b"}, args.X) + assert.Equal(t, "c", args.Y) +} + +func TestDefaultOptionValues(t *testing.T) { + var args struct { + A int `default:"123"` + B *int `default:"123"` + C string `default:"abc"` + D *string `default:"abc"` + E float64 `default:"1.23"` + F *float64 `default:"1.23"` + G bool `default:"true"` + H *bool `default:"true"` + } + + err := parse("--c=xyz --e=4.56", &args) + require.NoError(t, err) + + assert.Equal(t, 123, args.A) + assert.Equal(t, 123, *args.B) + assert.Equal(t, "xyz", args.C) + assert.Equal(t, "abc", *args.D) + assert.Equal(t, 4.56, args.E) + assert.Equal(t, 1.23, *args.F) + assert.True(t, args.G) + assert.True(t, args.G) +} + +func TestDefaultUnparseable(t *testing.T) { + var args struct { + A int `default:"x"` + } + + err := parse("", &args) + assert.EqualError(t, err, `error processing default value for --a: strconv.ParseInt: parsing "x": invalid syntax`) +} + +func TestDefaultPositionalValues(t *testing.T) { + var args struct { + A int `arg:"positional" default:"123"` + B *int `arg:"positional" default:"123"` + C string `arg:"positional" default:"abc"` + D *string `arg:"positional" default:"abc"` + E float64 `arg:"positional" default:"1.23"` + F *float64 `arg:"positional" default:"1.23"` + G bool `arg:"positional" default:"true"` + H *bool `arg:"positional" default:"true"` + } + + err := parse("456 789", &args) + require.NoError(t, err) + + assert.Equal(t, 456, args.A) + assert.Equal(t, 789, *args.B) + assert.Equal(t, "abc", args.C) + assert.Equal(t, "abc", *args.D) + assert.Equal(t, 1.23, args.E) + assert.Equal(t, 1.23, *args.F) + assert.True(t, args.G) + assert.True(t, args.G) +} + +func TestDefaultValuesNotAllowedWithRequired(t *testing.T) { + var args struct { + A int `arg:"required" default:"123"` // required not allowed with default! + } + + err := parse("", &args) + assert.EqualError(t, err, ".A: 'required' cannot be used when a default value is specified") +} + +func TestDefaultValuesNotAllowedWithSlice(t *testing.T) { + var args struct { + A []int `default:"123"` // required not allowed with default! + } + + err := parse("", &args) + assert.EqualError(t, err, ".A: default values are not supported for slice or map fields") +} + +func TestUnexportedFieldsSkipped(t *testing.T) { + var args struct { + unexported struct{} + } + + _, err := NewParser(Config{}, &args) + require.NoError(t, err) +} + +func TestMustParseInvalidParser(t *testing.T) { + originalExit := osExit + originalStdout := stdout + defer func() { + osExit = originalExit + stdout = originalStdout + }() + + var exitCode int + osExit = func(code int) { exitCode = code } + stdout = &bytes.Buffer{} + + var args struct { + CannotParse struct{} + } + parser := MustParse(&args) + assert.Nil(t, parser) + assert.Equal(t, -1, exitCode) +} + +func TestMustParsePrintsHelp(t *testing.T) { + originalExit := osExit + originalStdout := stdout + originalArgs := os.Args + defer func() { + osExit = originalExit + stdout = originalStdout + os.Args = originalArgs + }() + + var exitCode *int + osExit = func(code int) { exitCode = &code } + os.Args = []string{"someprogram", "--help"} + stdout = &bytes.Buffer{} + + var args struct{} + parser := MustParse(&args) + assert.NotNil(t, parser) + require.NotNil(t, exitCode) + assert.Equal(t, 0, *exitCode) +} + +func TestMustParsePrintsVersion(t *testing.T) { + originalExit := osExit + originalStdout := stdout + originalArgs := os.Args + defer func() { + osExit = originalExit + stdout = originalStdout + os.Args = originalArgs + }() + + var exitCode *int + osExit = func(code int) { exitCode = &code } + os.Args = []string{"someprogram", "--version"} + + var b bytes.Buffer + stdout = &b + + var args versioned + parser := MustParse(&args) + require.NotNil(t, parser) + require.NotNil(t, exitCode) + assert.Equal(t, 0, *exitCode) + assert.Equal(t, "example 3.2.1\n", b.String()) +} diff --git a/v2/reflect.go b/v2/reflect.go new file mode 100644 index 0000000..cd80be7 --- /dev/null +++ b/v2/reflect.go @@ -0,0 +1,107 @@ +package arg + +import ( + "encoding" + "fmt" + "reflect" + "unicode" + "unicode/utf8" + + scalar "github.com/alexflint/go-scalar" +) + +var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() + +// cardinality tracks how many tokens are expected for a given spec +// - zero is a boolean, which does to expect any value +// - one is an ordinary option that will be parsed from a single token +// - multiple is a slice or map that can accept zero or more tokens +type cardinality int + +const ( + zero cardinality = iota + one + multiple + unsupported +) + +func (k cardinality) String() string { + switch k { + case zero: + return "zero" + case one: + return "one" + case multiple: + return "multiple" + case unsupported: + return "unsupported" + default: + return fmt.Sprintf("unknown(%d)", int(k)) + } +} + +// cardinalityOf returns true if the type can be parsed from a string +func cardinalityOf(t reflect.Type) (cardinality, error) { + if scalar.CanParse(t) { + if isBoolean(t) { + return zero, nil + } + return one, nil + } + + // look inside pointer types + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // look inside slice and map types + switch t.Kind() { + case reflect.Slice: + if !scalar.CanParse(t.Elem()) { + return unsupported, fmt.Errorf("cannot parse into %v because %v not supported", t, t.Elem()) + } + return multiple, nil + case reflect.Map: + if !scalar.CanParse(t.Key()) { + return unsupported, fmt.Errorf("cannot parse into %v because key type %v not supported", t, t.Elem()) + } + if !scalar.CanParse(t.Elem()) { + return unsupported, fmt.Errorf("cannot parse into %v because value type %v not supported", t, t.Elem()) + } + return multiple, nil + default: + return unsupported, fmt.Errorf("cannot parse into %v", t) + } +} + +// isBoolean returns true if the type can be parsed from a single string +func isBoolean(t reflect.Type) bool { + switch { + case t.Implements(textUnmarshalerType): + return false + case t.Kind() == reflect.Bool: + return true + case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Bool: + return true + default: + return false + } +} + +// isExported returns true if the struct field name is exported +func isExported(field string) bool { + r, _ := utf8.DecodeRuneInString(field) // returns RuneError for empty string or invalid UTF8 + return unicode.IsLetter(r) && unicode.IsUpper(r) +} + +// isZero returns true if v contains the zero value for its type +func isZero(v reflect.Value) bool { + t := v.Type() + if t.Kind() == reflect.Slice || t.Kind() == reflect.Map { + return v.IsNil() + } + if !t.Comparable() { + return false + } + return v.Interface() == reflect.Zero(t).Interface() +} diff --git a/v2/reflect_test.go b/v2/reflect_test.go new file mode 100644 index 0000000..10909b3 --- /dev/null +++ b/v2/reflect_test.go @@ -0,0 +1,112 @@ +package arg + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func assertCardinality(t *testing.T, typ reflect.Type, expected cardinality) { + actual, err := cardinalityOf(typ) + assert.Equal(t, expected, actual, "expected %v to have cardinality %v but got %v", typ, expected, actual) + if expected == unsupported { + assert.Error(t, err) + } +} + +func TestCardinalityOf(t *testing.T) { + var b bool + var i int + var s string + var f float64 + var bs []bool + var is []int + var m map[string]int + var unsupported1 struct{} + var unsupported2 []struct{} + var unsupported3 map[string]struct{} + var unsupported4 map[struct{}]string + + assertCardinality(t, reflect.TypeOf(b), zero) + assertCardinality(t, reflect.TypeOf(i), one) + assertCardinality(t, reflect.TypeOf(s), one) + assertCardinality(t, reflect.TypeOf(f), one) + + assertCardinality(t, reflect.TypeOf(&b), zero) + assertCardinality(t, reflect.TypeOf(&s), one) + assertCardinality(t, reflect.TypeOf(&i), one) + assertCardinality(t, reflect.TypeOf(&f), one) + + assertCardinality(t, reflect.TypeOf(bs), multiple) + assertCardinality(t, reflect.TypeOf(is), multiple) + + assertCardinality(t, reflect.TypeOf(&bs), multiple) + assertCardinality(t, reflect.TypeOf(&is), multiple) + + assertCardinality(t, reflect.TypeOf(m), multiple) + assertCardinality(t, reflect.TypeOf(&m), multiple) + + assertCardinality(t, reflect.TypeOf(unsupported1), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported1), unsupported) + assertCardinality(t, reflect.TypeOf(unsupported2), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported2), unsupported) + assertCardinality(t, reflect.TypeOf(unsupported3), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported3), unsupported) + assertCardinality(t, reflect.TypeOf(unsupported4), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported4), unsupported) +} + +type implementsTextUnmarshaler struct{} + +func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error { + return nil +} + +func TestCardinalityTextUnmarshaler(t *testing.T) { + var x implementsTextUnmarshaler + var s []implementsTextUnmarshaler + var m []implementsTextUnmarshaler + assertCardinality(t, reflect.TypeOf(x), one) + assertCardinality(t, reflect.TypeOf(&x), one) + assertCardinality(t, reflect.TypeOf(s), multiple) + assertCardinality(t, reflect.TypeOf(&s), multiple) + assertCardinality(t, reflect.TypeOf(m), multiple) + assertCardinality(t, reflect.TypeOf(&m), multiple) +} + +func TestIsExported(t *testing.T) { + assert.True(t, isExported("Exported")) + assert.False(t, isExported("notExported")) + assert.False(t, isExported("")) + assert.False(t, isExported(string([]byte{255}))) +} + +func TestCardinalityString(t *testing.T) { + assert.Equal(t, "zero", zero.String()) + assert.Equal(t, "one", one.String()) + assert.Equal(t, "multiple", multiple.String()) + assert.Equal(t, "unsupported", unsupported.String()) + assert.Equal(t, "unknown(42)", cardinality(42).String()) +} + +func TestIsZero(t *testing.T) { + var zero int + var notZero = 3 + var nilSlice []int + var nonNilSlice = []int{1, 2, 3} + var nilMap map[string]string + var nonNilMap = map[string]string{"foo": "bar"} + var uncomparable = func() {} + + assert.True(t, isZero(reflect.ValueOf(zero))) + assert.False(t, isZero(reflect.ValueOf(notZero))) + + assert.True(t, isZero(reflect.ValueOf(nilSlice))) + assert.False(t, isZero(reflect.ValueOf(nonNilSlice))) + + assert.True(t, isZero(reflect.ValueOf(nilMap))) + assert.False(t, isZero(reflect.ValueOf(nonNilMap))) + + assert.False(t, isZero(reflect.ValueOf(uncomparable))) +} diff --git a/v2/sequence.go b/v2/sequence.go new file mode 100644 index 0000000..35a3614 --- /dev/null +++ b/v2/sequence.go @@ -0,0 +1,123 @@ +package arg + +import ( + "fmt" + "reflect" + "strings" + + scalar "github.com/alexflint/go-scalar" +) + +// setSliceOrMap parses a sequence of strings into a slice or map. If clear is +// true then any values already in the slice or map are first removed. +func setSliceOrMap(dest reflect.Value, values []string, clear bool) error { + if !dest.CanSet() { + return fmt.Errorf("field is not writable") + } + + t := dest.Type() + if t.Kind() == reflect.Ptr { + dest = dest.Elem() + t = t.Elem() + } + + switch t.Kind() { + case reflect.Slice: + return setSlice(dest, values, clear) + case reflect.Map: + return setMap(dest, values, clear) + default: + return fmt.Errorf("setSliceOrMap cannot insert values into a %v", t) + } +} + +// setSlice parses a sequence of strings and inserts them into a slice. If clear +// is true then any values already in the slice are removed. +func setSlice(dest reflect.Value, values []string, clear bool) error { + var ptr bool + elem := dest.Type().Elem() + if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { + ptr = true + elem = elem.Elem() + } + + // clear the slice in case default values exist + if clear && !dest.IsNil() { + dest.SetLen(0) + } + + // parse the values one-by-one + for _, s := range values { + v := reflect.New(elem) + if err := scalar.ParseValue(v.Elem(), s); err != nil { + return err + } + if !ptr { + v = v.Elem() + } + dest.Set(reflect.Append(dest, v)) + } + return nil +} + +// setMap parses a sequence of name=value strings and inserts them into a map. +// If clear is true then any values already in the map are removed. +func setMap(dest reflect.Value, values []string, clear bool) error { + // determine the key and value type + var keyIsPtr bool + keyType := dest.Type().Key() + if keyType.Kind() == reflect.Ptr && !keyType.Implements(textUnmarshalerType) { + keyIsPtr = true + keyType = keyType.Elem() + } + + var valIsPtr bool + valType := dest.Type().Elem() + if valType.Kind() == reflect.Ptr && !valType.Implements(textUnmarshalerType) { + valIsPtr = true + valType = valType.Elem() + } + + // clear the slice in case default values exist + if clear && !dest.IsNil() { + for _, k := range dest.MapKeys() { + dest.SetMapIndex(k, reflect.Value{}) + } + } + + // allocate the map if it is not allocated + if dest.IsNil() { + dest.Set(reflect.MakeMap(dest.Type())) + } + + // parse the values one-by-one + for _, s := range values { + // split at the first equals sign + pos := strings.Index(s, "=") + if pos == -1 { + return fmt.Errorf("cannot parse %q into a map, expected format key=value", s) + } + + // parse the key + k := reflect.New(keyType) + if err := scalar.ParseValue(k.Elem(), s[:pos]); err != nil { + return err + } + if !keyIsPtr { + k = k.Elem() + } + + // parse the value + v := reflect.New(valType) + if err := scalar.ParseValue(v.Elem(), s[pos+1:]); err != nil { + return err + } + if !valIsPtr { + v = v.Elem() + } + + // add it to the map + dest.SetMapIndex(k, v) + } + return nil +} diff --git a/v2/sequence_test.go b/v2/sequence_test.go new file mode 100644 index 0000000..fde3e3a --- /dev/null +++ b/v2/sequence_test.go @@ -0,0 +1,152 @@ +package arg + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetSliceWithoutClearing(t *testing.T) { + xs := []int{10} + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false) + require.NoError(t, err) + assert.Equal(t, []int{10, 1, 2, 3}, xs) +} + +func TestSetSliceAfterClearing(t *testing.T) { + xs := []int{100} + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, xs) +} + +func TestSetSliceInvalid(t *testing.T) { + xs := []int{100} + entries := []string{"invalid"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetSlicePtr(t *testing.T) { + var xs []*int + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, xs, 3) + assert.Equal(t, 1, *xs[0]) + assert.Equal(t, 2, *xs[1]) + assert.Equal(t, 3, *xs[2]) +} + +func TestSetSliceTextUnmarshaller(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var xs []*textUnmarshaler + entries := []string{"a", "aa", "aaa"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, xs, 3) + assert.Equal(t, 1, xs[0].val) + assert.Equal(t, 2, xs[1].val) + assert.Equal(t, 3, xs[2].val) +} + +func TestSetMapWithoutClearing(t *testing.T) { + m := map[string]int{"foo": 10} + entries := []string{"a=1", "b=2"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, false) + require.NoError(t, err) + require.Len(t, m, 3) + assert.Equal(t, 1, m["a"]) + assert.Equal(t, 2, m["b"]) + assert.Equal(t, 10, m["foo"]) +} + +func TestSetMapAfterClearing(t *testing.T) { + m := map[string]int{"foo": 10} + entries := []string{"a=1", "b=2"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 2) + assert.Equal(t, 1, m["a"]) + assert.Equal(t, 2, m["b"]) +} + +func TestSetMapWithKeyPointer(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[*string]int + entries := []string{"abc=123"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 1) +} + +func TestSetMapWithValuePointer(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[string]*int + entries := []string{"abc=123"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 1) + assert.Equal(t, 123, *m["abc"]) +} + +func TestSetMapTextUnmarshaller(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[textUnmarshaler]*textUnmarshaler + entries := []string{"a=123", "aa=12", "aaa=1"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 3) + assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}]) + assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}]) + assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}]) +} + +func TestSetMapInvalidKey(t *testing.T) { + var m map[int]int + entries := []string{"invalid=123"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetMapInvalidValue(t *testing.T) { + var m map[int]int + entries := []string{"123=invalid"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetMapMalformed(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[string]string + entries := []string{"missing_equals_sign"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetSliceOrMapErrors(t *testing.T) { + var err error + var dest reflect.Value + + // converting a slice to a reflect.Value in this way will make it read only + var cannotSet []int + dest = reflect.ValueOf(cannotSet) + err = setSliceOrMap(dest, nil, false) + assert.Error(t, err) + + // check what happens when we pass in something that is not a slice or a map + var notSliceOrMap string + dest = reflect.ValueOf(¬SliceOrMap).Elem() + err = setSliceOrMap(dest, nil, false) + assert.Error(t, err) + + // check what happens when we pass in a pointer to something that is not a slice or a map + var stringPtr *string + dest = reflect.ValueOf(&stringPtr).Elem() + err = setSliceOrMap(dest, nil, false) + assert.Error(t, err) +} diff --git a/v2/subcommand.go b/v2/subcommand.go new file mode 100644 index 0000000..dff732c --- /dev/null +++ b/v2/subcommand.go @@ -0,0 +1,37 @@ +package arg + +// Subcommand returns the user struct for the subcommand selected by +// the command line arguments most recently processed by the parser. +// The return value is always a pointer to a struct. If no subcommand +// was specified then it returns the top-level arguments struct. If +// no command line arguments have been processed by this parser then it +// returns nil. +func (p *Parser) Subcommand() interface{} { + if p.lastCmd == nil || p.lastCmd.parent == nil { + return nil + } + return p.val(p.lastCmd.dest).Interface() +} + +// SubcommandNames returns the sequence of subcommands specified by the +// user. If no subcommands were given then it returns an empty slice. +func (p *Parser) SubcommandNames() []string { + if p.lastCmd == nil { + return nil + } + + // make a list of ancestor commands + var ancestors []string + cur := p.lastCmd + for cur.parent != nil { // we want to exclude the root + ancestors = append(ancestors, cur.name) + cur = cur.parent + } + + // reverse the list + out := make([]string, len(ancestors)) + for i := 0; i < len(ancestors); i++ { + out[i] = ancestors[len(ancestors)-i-1] + } + return out +} diff --git a/v2/subcommand_test.go b/v2/subcommand_test.go new file mode 100644 index 0000000..2c61dd3 --- /dev/null +++ b/v2/subcommand_test.go @@ -0,0 +1,413 @@ +package arg + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// This file contains tests for parse.go but I decided to put them here +// since that file is getting large + +func TestSubcommandNotAPointer(t *testing.T) { + var args struct { + A string `arg:"subcommand"` + } + _, err := NewParser(Config{}, &args) + assert.Error(t, err) +} + +func TestSubcommandNotAPointerToStruct(t *testing.T) { + var args struct { + A struct{} `arg:"subcommand"` + } + _, err := NewParser(Config{}, &args) + assert.Error(t, err) +} + +func TestPositionalAndSubcommandNotAllowed(t *testing.T) { + var args struct { + A string `arg:"positional"` + B *struct{} `arg:"subcommand"` + } + _, err := NewParser(Config{}, &args) + assert.Error(t, err) +} + +func TestMinimalSubcommand(t *testing.T) { + type listCmd struct { + } + var args struct { + List *listCmd `arg:"subcommand"` + } + p, err := pparse("list", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, args.List, p.Subcommand()) + assert.Equal(t, []string{"list"}, p.SubcommandNames()) +} + +func TestSubcommandNamesBeforeParsing(t *testing.T) { + type listCmd struct{} + var args struct { + List *listCmd `arg:"subcommand"` + } + p, err := NewParser(Config{}, &args) + require.NoError(t, err) + assert.Nil(t, p.Subcommand()) + assert.Nil(t, p.SubcommandNames()) +} + +func TestNoSuchSubcommand(t *testing.T) { + type listCmd struct { + } + var args struct { + List *listCmd `arg:"subcommand"` + } + _, err := pparse("invalid", &args) + assert.Error(t, err) +} + +func TestNamedSubcommand(t *testing.T) { + type listCmd struct { + } + var args struct { + List *listCmd `arg:"subcommand:ls"` + } + p, err := pparse("ls", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, args.List, p.Subcommand()) + assert.Equal(t, []string{"ls"}, p.SubcommandNames()) +} + +func TestEmptySubcommand(t *testing.T) { + type listCmd struct { + } + var args struct { + List *listCmd `arg:"subcommand"` + } + p, err := pparse("", &args) + require.NoError(t, err) + assert.Nil(t, args.List) + assert.Nil(t, p.Subcommand()) + assert.Empty(t, p.SubcommandNames()) +} + +func TestTwoSubcommands(t *testing.T) { + type getCmd struct { + } + type listCmd struct { + } + var args struct { + Get *getCmd `arg:"subcommand"` + List *listCmd `arg:"subcommand"` + } + p, err := pparse("list", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List, p.Subcommand()) + assert.Equal(t, []string{"list"}, p.SubcommandNames()) +} + +func TestSubcommandsWithOptions(t *testing.T) { + type getCmd struct { + Name string + } + type listCmd struct { + Limit int + } + type cmd struct { + Verbose bool + Get *getCmd `arg:"subcommand"` + List *listCmd `arg:"subcommand"` + } + + { + var args cmd + err := parse("list", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + } + + { + var args cmd + err := parse("list --limit 3", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List.Limit, 3) + } + + { + var args cmd + err := parse("list --limit 3 --verbose", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List.Limit, 3) + assert.True(t, args.Verbose) + } + + { + var args cmd + err := parse("list --verbose --limit 3", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List.Limit, 3) + assert.True(t, args.Verbose) + } + + { + var args cmd + err := parse("--verbose list --limit 3", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List.Limit, 3) + assert.True(t, args.Verbose) + } + + { + var args cmd + err := parse("get", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Nil(t, args.List) + } + + { + var args cmd + err := parse("get --name test", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Nil(t, args.List) + assert.Equal(t, args.Get.Name, "test") + } +} + +func TestSubcommandsWithEnvVars(t *testing.T) { + type getCmd struct { + Name string `arg:"env"` + } + type listCmd struct { + Limit int `arg:"env"` + } + type cmd struct { + Verbose bool + Get *getCmd `arg:"subcommand"` + List *listCmd `arg:"subcommand"` + } + + { + var args cmd + setenv(t, "LIMIT", "123") + err := parse("list", &args) + require.NoError(t, err) + require.NotNil(t, args.List) + assert.Equal(t, 123, args.List.Limit) + } + + { + var args cmd + setenv(t, "LIMIT", "not_an_integer") + err := parse("list", &args) + assert.Error(t, err) + } +} + +func TestNestedSubcommands(t *testing.T) { + type child struct{} + type parent struct { + Child *child `arg:"subcommand"` + } + type grandparent struct { + Parent *parent `arg:"subcommand"` + } + type root struct { + Grandparent *grandparent `arg:"subcommand"` + } + + { + var args root + p, err := pparse("grandparent parent child", &args) + require.NoError(t, err) + require.NotNil(t, args.Grandparent) + require.NotNil(t, args.Grandparent.Parent) + require.NotNil(t, args.Grandparent.Parent.Child) + assert.Equal(t, args.Grandparent.Parent.Child, p.Subcommand()) + assert.Equal(t, []string{"grandparent", "parent", "child"}, p.SubcommandNames()) + } + + { + var args root + p, err := pparse("grandparent parent", &args) + require.NoError(t, err) + require.NotNil(t, args.Grandparent) + require.NotNil(t, args.Grandparent.Parent) + require.Nil(t, args.Grandparent.Parent.Child) + assert.Equal(t, args.Grandparent.Parent, p.Subcommand()) + assert.Equal(t, []string{"grandparent", "parent"}, p.SubcommandNames()) + } + + { + var args root + p, err := pparse("grandparent", &args) + require.NoError(t, err) + require.NotNil(t, args.Grandparent) + require.Nil(t, args.Grandparent.Parent) + assert.Equal(t, args.Grandparent, p.Subcommand()) + assert.Equal(t, []string{"grandparent"}, p.SubcommandNames()) + } + + { + var args root + p, err := pparse("", &args) + require.NoError(t, err) + require.Nil(t, args.Grandparent) + assert.Nil(t, p.Subcommand()) + assert.Empty(t, p.SubcommandNames()) + } +} + +func TestSubcommandsWithPositionals(t *testing.T) { + type listCmd struct { + Pattern string `arg:"positional"` + } + type cmd struct { + Format string + List *listCmd `arg:"subcommand"` + } + + { + var args cmd + err := parse("list", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, "", args.List.Pattern) + } + + { + var args cmd + err := parse("list --format json", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, "", args.List.Pattern) + assert.Equal(t, "json", args.Format) + } + + { + var args cmd + err := parse("list somepattern", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, "somepattern", args.List.Pattern) + } + + { + var args cmd + err := parse("list somepattern --format json", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, "somepattern", args.List.Pattern) + assert.Equal(t, "json", args.Format) + } + + { + var args cmd + err := parse("list --format json somepattern", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, "somepattern", args.List.Pattern) + assert.Equal(t, "json", args.Format) + } + + { + var args cmd + err := parse("--format json list somepattern", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, "somepattern", args.List.Pattern) + assert.Equal(t, "json", args.Format) + } + + { + var args cmd + err := parse("--format json", &args) + require.NoError(t, err) + assert.Nil(t, args.List) + assert.Equal(t, "json", args.Format) + } +} +func TestSubcommandsWithMultiplePositionals(t *testing.T) { + type getCmd struct { + Items []string `arg:"positional"` + } + type cmd struct { + Limit int + Get *getCmd `arg:"subcommand"` + } + + { + var args cmd + err := parse("get", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Empty(t, args.Get.Items) + } + + { + var args cmd + err := parse("get --limit 5", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Empty(t, args.Get.Items) + assert.Equal(t, 5, args.Limit) + } + + { + var args cmd + err := parse("get item1", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Equal(t, []string{"item1"}, args.Get.Items) + } + + { + var args cmd + err := parse("get item1 item2 item3", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Equal(t, []string{"item1", "item2", "item3"}, args.Get.Items) + } + + { + var args cmd + err := parse("get item1 --limit 5 item2", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Equal(t, []string{"item1", "item2"}, args.Get.Items) + assert.Equal(t, 5, args.Limit) + } +} + +func TestValForNilStruct(t *testing.T) { + type subcmd struct{} + var cmd struct { + Sub *subcmd `arg:"subcommand"` + } + + p, err := NewParser(Config{}, &cmd) + require.NoError(t, err) + + typ := reflect.TypeOf(cmd) + subField, _ := typ.FieldByName("Sub") + + v := p.val(path{fields: []reflect.StructField{subField, subField}}) + assert.False(t, v.IsValid()) +} diff --git a/v2/usage.go b/v2/usage.go new file mode 100644 index 0000000..7ba06cc --- /dev/null +++ b/v2/usage.go @@ -0,0 +1,339 @@ +package arg + +import ( + "fmt" + "io" + "os" + "strings" +) + +// the width of the left column +const colWidth = 25 + +// to allow monkey patching in tests +var ( + stdout io.Writer = os.Stdout + stderr io.Writer = os.Stderr + osExit = os.Exit +) + +// Fail prints usage information to stderr and exits with non-zero status +func (p *Parser) Fail(msg string) { + p.failWithSubcommand(msg, p.cmd) +} + +// FailSubcommand prints usage information for a specified subcommand to stderr, +// then exits with non-zero status. To write usage information for a top-level +// subcommand, provide just the name of that subcommand. To write usage +// information for a subcommand that is nested under another subcommand, provide +// a sequence of subcommand names starting with the top-level subcommand and so +// on down the tree. +func (p *Parser) FailSubcommand(msg string, subcommand ...string) error { + cmd, err := p.lookupCommand(subcommand...) + if err != nil { + return err + } + p.failWithSubcommand(msg, cmd) + return nil +} + +// failWithSubcommand prints usage information for the given subcommand to stderr and exits with non-zero status +func (p *Parser) failWithSubcommand(msg string, cmd *command) { + p.writeUsageForSubcommand(stderr, cmd) + fmt.Fprintln(stderr, "error:", msg) + osExit(-1) +} + +// WriteUsage writes usage information to the given writer +func (p *Parser) WriteUsage(w io.Writer) { + cmd := p.cmd + if p.lastCmd != nil { + cmd = p.lastCmd + } + p.writeUsageForSubcommand(w, cmd) +} + +// WriteUsageForSubcommand writes the usage information for a specified +// subcommand. To write usage information for a top-level subcommand, provide +// just the name of that subcommand. To write usage information for a subcommand +// that is nested under another subcommand, provide a sequence of subcommand +// names starting with the top-level subcommand and so on down the tree. +func (p *Parser) WriteUsageForSubcommand(w io.Writer, subcommand ...string) error { + cmd, err := p.lookupCommand(subcommand...) + if err != nil { + return err + } + p.writeUsageForSubcommand(w, cmd) + return nil +} + +// writeUsageForSubcommand writes usage information for the given subcommand +func (p *Parser) writeUsageForSubcommand(w io.Writer, cmd *command) { + var positionals, longOptions, shortOptions []*spec + for _, spec := range cmd.specs { + switch { + case spec.positional: + positionals = append(positionals, spec) + case spec.long != "": + longOptions = append(longOptions, spec) + case spec.short != "": + shortOptions = append(shortOptions, spec) + } + } + + if p.version != "" { + fmt.Fprintln(w, p.version) + } + + // make a list of ancestor commands so that we print with full context + var ancestors []string + ancestor := cmd + for ancestor != nil { + ancestors = append(ancestors, ancestor.name) + ancestor = ancestor.parent + } + + // print the beginning of the usage string + fmt.Fprint(w, "Usage:") + for i := len(ancestors) - 1; i >= 0; i-- { + fmt.Fprint(w, " "+ancestors[i]) + } + + // write the option component of the usage message + for _, spec := range shortOptions { + // prefix with a space + fmt.Fprint(w, " ") + if !spec.required { + fmt.Fprint(w, "[") + } + fmt.Fprint(w, synopsis(spec, "-"+spec.short)) + if !spec.required { + fmt.Fprint(w, "]") + } + } + + for _, spec := range longOptions { + // prefix with a space + fmt.Fprint(w, " ") + if !spec.required { + fmt.Fprint(w, "[") + } + fmt.Fprint(w, synopsis(spec, "--"+spec.long)) + if !spec.required { + fmt.Fprint(w, "]") + } + } + + // When we parse positionals, we check that: + // 1. required positionals come before non-required positionals + // 2. there is at most one multiple-value positional + // 3. if there is a multiple-value positional then it comes after all other positionals + // Here we merely print the usage string, so we do not explicitly re-enforce those rules + + // write the positionals in following form: + // REQUIRED1 REQUIRED2 + // REQUIRED1 REQUIRED2 [OPTIONAL1 [OPTIONAL2]] + // REQUIRED1 REQUIRED2 REPEATED [REPEATED ...] + // REQUIRED1 REQUIRED2 [REPEATEDOPTIONAL [REPEATEDOPTIONAL ...]] + // REQUIRED1 REQUIRED2 [OPTIONAL1 [REPEATEDOPTIONAL [REPEATEDOPTIONAL ...]]] + var closeBrackets int + for _, spec := range positionals { + fmt.Fprint(w, " ") + if !spec.required { + fmt.Fprint(w, "[") + closeBrackets += 1 + } + if spec.cardinality == multiple { + fmt.Fprintf(w, "%s [%s ...]", spec.placeholder, spec.placeholder) + } else { + fmt.Fprint(w, spec.placeholder) + } + } + fmt.Fprint(w, strings.Repeat("]", closeBrackets)) + + // if the program supports subcommands, give a hint to the user about their existence + if len(cmd.subcommands) > 0 { + fmt.Fprint(w, " []") + } + + fmt.Fprint(w, "\n") +} + +func printTwoCols(w io.Writer, left, help string, defaultVal string, envVal string) { + lhs := " " + left + fmt.Fprint(w, lhs) + if help != "" { + if len(lhs)+2 < colWidth { + fmt.Fprint(w, strings.Repeat(" ", colWidth-len(lhs))) + } else { + fmt.Fprint(w, "\n"+strings.Repeat(" ", colWidth)) + } + fmt.Fprint(w, help) + } + + bracketsContent := []string{} + + if defaultVal != "" { + bracketsContent = append(bracketsContent, + fmt.Sprintf("default: %s", defaultVal), + ) + } + + if envVal != "" { + bracketsContent = append(bracketsContent, + fmt.Sprintf("env: %s", envVal), + ) + } + + if len(bracketsContent) > 0 { + fmt.Fprintf(w, " [%s]", strings.Join(bracketsContent, ", ")) + } + fmt.Fprint(w, "\n") +} + +// WriteHelp writes the usage string followed by the full help string for each option +func (p *Parser) WriteHelp(w io.Writer) { + cmd := p.cmd + if p.lastCmd != nil { + cmd = p.lastCmd + } + p.writeHelpForSubcommand(w, cmd) +} + +// WriteHelpForSubcommand writes the usage string followed by the full help +// string for a specified subcommand. To write help for a top-level subcommand, +// provide just the name of that subcommand. To write help for a subcommand that +// is nested under another subcommand, provide a sequence of subcommand names +// starting with the top-level subcommand and so on down the tree. +func (p *Parser) WriteHelpForSubcommand(w io.Writer, subcommand ...string) error { + cmd, err := p.lookupCommand(subcommand...) + if err != nil { + return err + } + p.writeHelpForSubcommand(w, cmd) + return nil +} + +// writeHelp writes the usage string for the given subcommand +func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *command) { + var positionals, longOptions, shortOptions []*spec + for _, spec := range cmd.specs { + switch { + case spec.positional: + positionals = append(positionals, spec) + case spec.long != "": + longOptions = append(longOptions, spec) + case spec.short != "": + shortOptions = append(shortOptions, spec) + } + } + + if p.description != "" { + fmt.Fprintln(w, p.description) + } + p.writeUsageForSubcommand(w, cmd) + + // write the list of positionals + if len(positionals) > 0 { + fmt.Fprint(w, "\nPositional arguments:\n") + for _, spec := range positionals { + printTwoCols(w, spec.placeholder, spec.help, "", "") + } + } + + // write the list of options with the short-only ones first to match the usage string + if len(shortOptions)+len(longOptions) > 0 || cmd.parent == nil { + fmt.Fprint(w, "\nOptions:\n") + for _, spec := range shortOptions { + p.printOption(w, spec) + } + for _, spec := range longOptions { + p.printOption(w, spec) + } + } + + // obtain a flattened list of options from all ancestors + var globals []*spec + ancestor := cmd.parent + for ancestor != nil { + globals = append(globals, ancestor.specs...) + ancestor = ancestor.parent + } + + // write the list of global options + if len(globals) > 0 { + fmt.Fprint(w, "\nGlobal options:\n") + for _, spec := range globals { + p.printOption(w, spec) + } + } + + // write the list of built in options + p.printOption(w, &spec{ + cardinality: zero, + long: "help", + short: "h", + help: "display this help and exit", + }) + if p.version != "" { + p.printOption(w, &spec{ + cardinality: zero, + long: "version", + help: "display version and exit", + }) + } + + // write the list of subcommands + if len(cmd.subcommands) > 0 { + fmt.Fprint(w, "\nCommands:\n") + for _, subcmd := range cmd.subcommands { + printTwoCols(w, subcmd.name, subcmd.help, "", "") + } + } + + if p.epilogue != "" { + fmt.Fprintln(w, "\n"+p.epilogue) + } +} + +func (p *Parser) printOption(w io.Writer, spec *spec) { + ways := make([]string, 0, 2) + if spec.long != "" { + ways = append(ways, synopsis(spec, "--"+spec.long)) + } + if spec.short != "" { + ways = append(ways, synopsis(spec, "-"+spec.short)) + } + if len(ways) > 0 { + printTwoCols(w, strings.Join(ways, ", "), spec.help, spec.defaultVal, spec.env) + } +} + +// lookupCommand finds a subcommand based on a sequence of subcommand names. The +// first string should be a top-level subcommand, the next should be a child +// subcommand of that subcommand, and so on. If no strings are given then the +// root command is returned. If no such subcommand exists then an error is +// returned. +func (p *Parser) lookupCommand(path ...string) (*command, error) { + cmd := p.cmd + for _, name := range path { + var found *command + for _, child := range cmd.subcommands { + if child.name == name { + found = child + } + } + if found == nil { + return nil, fmt.Errorf("%q is not a subcommand of %s", name, cmd.name) + } + cmd = found + } + return cmd, nil +} + +func synopsis(spec *spec, form string) string { + if spec.cardinality == zero { + return form + } + return form + " " + spec.placeholder +} diff --git a/v2/usage_test.go b/v2/usage_test.go new file mode 100644 index 0000000..fd67fc8 --- /dev/null +++ b/v2/usage_test.go @@ -0,0 +1,635 @@ +package arg + +import ( + "bytes" + "errors" + "fmt" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type NameDotName struct { + Head, Tail string +} + +func (n *NameDotName) UnmarshalText(b []byte) error { + s := string(b) + pos := strings.Index(s, ".") + if pos == -1 { + return fmt.Errorf("missing period in %s", s) + } + n.Head = s[:pos] + n.Tail = s[pos+1:] + return nil +} + +func (n *NameDotName) MarshalText() (text []byte, err error) { + text = []byte(fmt.Sprintf("%s.%s", n.Head, n.Tail)) + return +} + +func TestWriteUsage(t *testing.T) { + expectedUsage := "Usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] [--values VALUES] [--workers WORKERS] [--testenv TESTENV] [--file FILE] INPUT [OUTPUT [OUTPUT ...]]" + + expectedHelp := ` +Usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] [--values VALUES] [--workers WORKERS] [--testenv TESTENV] [--file FILE] INPUT [OUTPUT [OUTPUT ...]] + +Positional arguments: + INPUT + OUTPUT list of outputs + +Options: + --name NAME name to use [default: Foo Bar] + --value VALUE secret value [default: 42] + --verbose, -v verbosity level + --dataset DATASET dataset to use + --optimize OPTIMIZE, -O OPTIMIZE + optimization level + --ids IDS Ids + --values VALUES Values [default: [3.14 42 256]] + --workers WORKERS, -w WORKERS + number of workers to start [default: 10, env: WORKERS] + --testenv TESTENV, -a TESTENV [env: TEST_ENV] + --file FILE, -f FILE File with mandatory extension [default: scratch.txt] + --help, -h display this help and exit +` + + var args struct { + Input string `arg:"positional,required"` + Output []string `arg:"positional" help:"list of outputs"` + Name string `help:"name to use"` + Value int `help:"secret value"` + Verbose bool `arg:"-v" help:"verbosity level"` + Dataset string `help:"dataset to use"` + Optimize int `arg:"-O" help:"optimization level"` + Ids []int64 `help:"Ids"` + Values []float64 `help:"Values"` + Workers int `arg:"-w,env:WORKERS" help:"number of workers to start" default:"10"` + TestEnv string `arg:"-a,env:TEST_ENV"` + File *NameDotName `arg:"-f" help:"File with mandatory extension"` + } + args.Name = "Foo Bar" + args.Value = 42 + args.Values = []float64{3.14, 42, 256} + args.File = &NameDotName{"scratch", "txt"} + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + os.Args[0] = "example" + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +type MyEnum int + +func (n *MyEnum) UnmarshalText(b []byte) error { + return nil +} + +func (n *MyEnum) MarshalText() ([]byte, error) { + return nil, errors.New("There was a problem") +} + +func TestUsageWithDefaults(t *testing.T) { + expectedUsage := "Usage: example [--label LABEL] [--content CONTENT]" + + expectedHelp := ` +Usage: example [--label LABEL] [--content CONTENT] + +Options: + --label LABEL [default: cat] + --content CONTENT [default: dog] + --help, -h display this help and exit +` + var args struct { + Label string + Content string `default:"dog"` + } + args.Label = "cat" + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + args.Label = "should_ignore_this" + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestUsageCannotMarshalToString(t *testing.T) { + var args struct { + Name *MyEnum + } + v := MyEnum(42) + args.Name = &v + _, err := NewParser(Config{Program: "example"}, &args) + assert.EqualError(t, err, `args.Name: error marshaling default value to string: There was a problem`) +} + +func TestUsageLongPositionalWithHelp_legacyForm(t *testing.T) { + expectedUsage := "Usage: example [VERYLONGPOSITIONALWITHHELP]" + + expectedHelp := ` +Usage: example [VERYLONGPOSITIONALWITHHELP] + +Positional arguments: + VERYLONGPOSITIONALWITHHELP + this positional argument is very long but cannot include commas + +Options: + --help, -h display this help and exit +` + var args struct { + VeryLongPositionalWithHelp string `arg:"positional,help:this positional argument is very long but cannot include commas"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestUsageLongPositionalWithHelp_newForm(t *testing.T) { + expectedUsage := "Usage: example [VERYLONGPOSITIONALWITHHELP]" + + expectedHelp := ` +Usage: example [VERYLONGPOSITIONALWITHHELP] + +Positional arguments: + VERYLONGPOSITIONALWITHHELP + this positional argument is very long, and includes: commas, colons etc + +Options: + --help, -h display this help and exit +` + var args struct { + VeryLongPositionalWithHelp string `arg:"positional" help:"this positional argument is very long, and includes: commas, colons etc"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestUsageWithProgramName(t *testing.T) { + expectedUsage := "Usage: myprogram" + + expectedHelp := ` +Usage: myprogram + +Options: + --help, -h display this help and exit +` + config := Config{ + Program: "myprogram", + } + p, err := NewParser(config, &struct{}{}) + require.NoError(t, err) + + os.Args[0] = "example" + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +type versioned struct{} + +// Version returns the version for this program +func (versioned) Version() string { + return "example 3.2.1" +} + +func TestUsageWithVersion(t *testing.T) { + expectedUsage := "example 3.2.1\nUsage: example" + + expectedHelp := ` +example 3.2.1 +Usage: example + +Options: + --help, -h display this help and exit + --version display version and exit +` + os.Args[0] = "example" + p, err := NewParser(Config{}, &versioned{}) + require.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +type described struct{} + +// Described returns the description for this program +func (described) Description() string { + return "this program does this and that" +} + +func TestUsageWithDescription(t *testing.T) { + expectedUsage := "Usage: example" + + expectedHelp := ` +this program does this and that +Usage: example + +Options: + --help, -h display this help and exit +` + os.Args[0] = "example" + p, err := NewParser(Config{}, &described{}) + require.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +type epilogued struct{} + +// Epilogued returns the epilogue for this program +func (epilogued) Epilogue() string { + return "For more information visit github.com/alexflint/go-arg" +} + +func TestUsageWithEpilogue(t *testing.T) { + expectedUsage := "Usage: example" + + expectedHelp := ` +Usage: example + +Options: + --help, -h display this help and exit + +For more information visit github.com/alexflint/go-arg +` + os.Args[0] = "example" + p, err := NewParser(Config{}, &epilogued{}) + require.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestUsageForRequiredPositionals(t *testing.T) { + expectedUsage := "Usage: example REQUIRED1 REQUIRED2\n" + var args struct { + Required1 string `arg:"positional,required"` + Required2 string `arg:"positional,required"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, usage.String()) +} + +func TestUsageForMixedPositionals(t *testing.T) { + expectedUsage := "Usage: example REQUIRED1 REQUIRED2 [OPTIONAL1 [OPTIONAL2]]\n" + var args struct { + Required1 string `arg:"positional,required"` + Required2 string `arg:"positional,required"` + Optional1 string `arg:"positional"` + Optional2 string `arg:"positional"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, usage.String()) +} + +func TestUsageForRepeatedPositionals(t *testing.T) { + expectedUsage := "Usage: example REQUIRED1 REQUIRED2 REPEATED [REPEATED ...]\n" + var args struct { + Required1 string `arg:"positional,required"` + Required2 string `arg:"positional,required"` + Repeated []string `arg:"positional,required"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, usage.String()) +} + +func TestUsageForMixedAndRepeatedPositionals(t *testing.T) { + expectedUsage := "Usage: example REQUIRED1 REQUIRED2 [OPTIONAL1 [OPTIONAL2 [REPEATED [REPEATED ...]]]]\n" + var args struct { + Required1 string `arg:"positional,required"` + Required2 string `arg:"positional,required"` + Optional1 string `arg:"positional"` + Optional2 string `arg:"positional"` + Repeated []string `arg:"positional"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, usage.String()) +} + +func TestRequiredMultiplePositionals(t *testing.T) { + expectedUsage := "Usage: example REQUIREDMULTIPLE [REQUIREDMULTIPLE ...]\n" + + expectedHelp := ` +Usage: example REQUIREDMULTIPLE [REQUIREDMULTIPLE ...] + +Positional arguments: + REQUIREDMULTIPLE required multiple positional + +Options: + --help, -h display this help and exit +` + var args struct { + RequiredMultiple []string `arg:"positional,required" help:"required multiple positional"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, usage.String()) +} + +func TestUsageWithNestedSubcommands(t *testing.T) { + expectedUsage := "Usage: example child nested [--enable] OUTPUT" + + expectedHelp := ` +Usage: example child nested [--enable] OUTPUT + +Positional arguments: + OUTPUT + +Options: + --enable + +Global options: + --values VALUES Values + --verbose, -v verbosity level + --help, -h display this help and exit +` + + var args struct { + Verbose bool `arg:"-v" help:"verbosity level"` + Child *struct { + Values []float64 `help:"Values"` + Nested *struct { + Enable bool + Output string `arg:"positional,required"` + } `arg:"subcommand:nested"` + } `arg:"subcommand:child"` + } + + os.Args[0] = "example" + p, err := NewParser(Config{}, &args) + require.NoError(t, err) + + _ = p.Parse([]string{"child", "nested", "value"}) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var help2 bytes.Buffer + p.WriteHelpForSubcommand(&help2, "child", "nested") + assert.Equal(t, expectedHelp[1:], help2.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) + + var usage2 bytes.Buffer + p.WriteUsageForSubcommand(&usage2, "child", "nested") + assert.Equal(t, expectedUsage, strings.TrimSpace(usage2.String())) +} + +func TestNonexistentSubcommand(t *testing.T) { + var args struct { + sub *struct{} `arg:"subcommand"` + } + p, err := NewParser(Config{}, &args) + require.NoError(t, err) + + var b bytes.Buffer + + err = p.WriteUsageForSubcommand(&b, "does_not_exist") + assert.Error(t, err) + + err = p.WriteHelpForSubcommand(&b, "does_not_exist") + assert.Error(t, err) + + err = p.FailSubcommand("something went wrong", "does_not_exist") + assert.Error(t, err) + + err = p.WriteUsageForSubcommand(&b, "sub", "does_not_exist") + assert.Error(t, err) + + err = p.WriteHelpForSubcommand(&b, "sub", "does_not_exist") + assert.Error(t, err) + + err = p.FailSubcommand("something went wrong", "sub", "does_not_exist") + assert.Error(t, err) +} + +func TestUsageWithoutLongNames(t *testing.T) { + expectedUsage := "Usage: example [-a PLACEHOLDER] -b SHORTONLY2" + + expectedHelp := ` +Usage: example [-a PLACEHOLDER] -b SHORTONLY2 + +Options: + -a PLACEHOLDER some help [default: some val] + -b SHORTONLY2 some help2 + --help, -h display this help and exit +` + var args struct { + ShortOnly string `arg:"-a,--" help:"some help" default:"some val" placeholder:"PLACEHOLDER"` + ShortOnly2 string `arg:"-b,--,required" help:"some help2"` + } + p, err := NewParser(Config{Program: "example"}, &args) + assert.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestUsageWithShortFirst(t *testing.T) { + expectedUsage := "Usage: example [-c CAT] [--dog DOG]" + + expectedHelp := ` +Usage: example [-c CAT] [--dog DOG] + +Options: + -c CAT + --dog DOG + --help, -h display this help and exit +` + var args struct { + Dog string + Cat string `arg:"-c,--"` + } + p, err := NewParser(Config{Program: "example"}, &args) + assert.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestUsageWithEnvOptions(t *testing.T) { + expectedUsage := "Usage: example [-s SHORT]" + + expectedHelp := ` +Usage: example [-s SHORT] + +Options: + -s SHORT [env: SHORT] + --help, -h display this help and exit +` + var args struct { + Short string `arg:"--,-s,env"` + EnvOnly string `arg:"--,env"` + EnvOnlyOverriden string `arg:"--,env:CUSTOM"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + assert.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestFail(t *testing.T) { + originalStderr := stderr + originalExit := osExit + defer func() { + stderr = originalStderr + osExit = originalExit + }() + + var b bytes.Buffer + stderr = &b + + var exitCode int + osExit = func(code int) { exitCode = code } + + expectedStdout := ` +Usage: example [--foo FOO] +error: something went wrong +` + + var args struct { + Foo int + } + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + p.Fail("something went wrong") + + assert.Equal(t, expectedStdout[1:], b.String()) + assert.Equal(t, -1, exitCode) +} + +func TestFailSubcommand(t *testing.T) { + originalStderr := stderr + originalExit := osExit + defer func() { + stderr = originalStderr + osExit = originalExit + }() + + var b bytes.Buffer + stderr = &b + + var exitCode int + osExit = func(code int) { exitCode = code } + + expectedStdout := ` +Usage: example sub +error: something went wrong +` + + var args struct { + Sub *struct{} `arg:"subcommand"` + } + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + err = p.FailSubcommand("something went wrong", "sub") + require.NoError(t, err) + + assert.Equal(t, expectedStdout[1:], b.String()) + assert.Equal(t, -1, exitCode) +}