From 2e6284635afce830433d47d9ed97a43fa841990a Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 08:56:31 -0700 Subject: [PATCH 01/19] drop support for multiple destination structs --- v2/doc.go | 39 ++ v2/example_test.go | 507 ++++++++++++++ v2/go.mod | 8 + v2/go.sum | 15 + v2/parse.go | 741 ++++++++++++++++++++ v2/parse_test.go | 1486 +++++++++++++++++++++++++++++++++++++++++ v2/reflect.go | 107 +++ v2/reflect_test.go | 112 ++++ v2/sequence.go | 123 ++++ v2/sequence_test.go | 152 +++++ v2/subcommand.go | 37 + v2/subcommand_test.go | 413 ++++++++++++ v2/usage.go | 339 ++++++++++ v2/usage_test.go | 635 ++++++++++++++++++ 14 files changed, 4714 insertions(+) create mode 100644 v2/doc.go create mode 100644 v2/example_test.go create mode 100644 v2/go.mod create mode 100644 v2/go.sum create mode 100644 v2/parse.go create mode 100644 v2/parse_test.go create mode 100644 v2/reflect.go create mode 100644 v2/reflect_test.go create mode 100644 v2/sequence.go create mode 100644 v2/sequence_test.go create mode 100644 v2/subcommand.go create mode 100644 v2/subcommand_test.go create mode 100644 v2/usage.go create mode 100644 v2/usage_test.go diff --git a/v2/doc.go b/v2/doc.go new file mode 100644 index 0000000..3b0bafd --- /dev/null +++ b/v2/doc.go @@ -0,0 +1,39 @@ +// Package arg parses command line arguments using the fields from a struct. +// +// For example, +// +// var args struct { +// Iter int +// Debug bool +// } +// arg.MustParse(&args) +// +// defines two command line arguments, which can be set using any of +// +// ./example --iter=1 --debug // debug is a boolean flag so its value is set to true +// ./example -iter 1 // debug defaults to its zero value (false) +// ./example --debug=true // iter defaults to its zero value (zero) +// +// The fastest way to see how to use go-arg is to read the examples below. +// +// Fields can be bool, string, any float type, or any signed or unsigned integer type. +// They can also be slices of any of the above, or slices of pointers to any of the above. +// +// Tags can be specified using the `arg` and `help` tag names: +// +// var args struct { +// Input string `arg:"positional"` +// Log string `arg:"positional,required"` +// Debug bool `arg:"-d" help:"turn on debug mode"` +// RealMode bool `arg:"--real" +// Wr io.Writer `arg:"-"` +// } +// +// Any tag string that starts with a single hyphen is the short form for an argument +// (e.g. `./example -d`), and any tag string that starts with two hyphens is the long +// form for the argument (instead of the field name). +// +// Other valid tag strings are `positional` and `required`. +// +// Fields can be excluded from processing with `arg:"-"`. +package arg diff --git a/v2/example_test.go b/v2/example_test.go new file mode 100644 index 0000000..fd64777 --- /dev/null +++ b/v2/example_test.go @@ -0,0 +1,507 @@ +package arg + +import ( + "fmt" + "net" + "net/mail" + "net/url" + "os" + "strings" + "time" +) + +func split(s string) []string { + return strings.Split(s, " ") +} + +// This example demonstrates basic usage +func Example() { + // These are the args you would pass in on the command line + os.Args = split("./example --foo=hello --bar") + + var args struct { + Foo string + Bar bool + } + MustParse(&args) + fmt.Println(args.Foo, args.Bar) + // output: hello true +} + +// This example demonstrates arguments that have default values +func Example_defaultValues() { + // These are the args you would pass in on the command line + os.Args = split("./example") + + var args struct { + Foo string `default:"abc"` + } + MustParse(&args) + fmt.Println(args.Foo) + // output: abc +} + +// This example demonstrates arguments that are required +func Example_requiredArguments() { + // These are the args you would pass in on the command line + os.Args = split("./example --foo=abc --bar") + + var args struct { + Foo string `arg:"required"` + Bar bool + } + MustParse(&args) + fmt.Println(args.Foo, args.Bar) + // output: abc true +} + +// This example demonstrates positional arguments +func Example_positionalArguments() { + // These are the args you would pass in on the command line + os.Args = split("./example in out1 out2 out3") + + var args struct { + Input string `arg:"positional"` + Output []string `arg:"positional"` + } + MustParse(&args) + fmt.Println("In:", args.Input) + fmt.Println("Out:", args.Output) + // output: + // In: in + // Out: [out1 out2 out3] +} + +// This example demonstrates arguments that have multiple values +func Example_multipleValues() { + // The args you would pass in on the command line + os.Args = split("./example --database localhost --ids 1 2 3") + + var args struct { + Database string + IDs []int64 + } + MustParse(&args) + fmt.Printf("Fetching the following IDs from %s: %v", args.Database, args.IDs) + // output: Fetching the following IDs from localhost: [1 2 3] +} + +// This example demonstrates arguments with keys and values +func Example_mappings() { + // The args you would pass in on the command line + os.Args = split("./example --userids john=123 mary=456") + + var args struct { + UserIDs map[string]int + } + MustParse(&args) + fmt.Println(args.UserIDs) + // output: map[john:123 mary:456] +} + +type commaSeparated struct { + M map[string]string +} + +func (c *commaSeparated) UnmarshalText(b []byte) error { + c.M = make(map[string]string) + for _, part := range strings.Split(string(b), ",") { + pos := strings.Index(part, "=") + if pos == -1 { + return fmt.Errorf("error parsing %q, expected format key=value", part) + } + c.M[part[:pos]] = part[pos+1:] + } + return nil +} + +// This example demonstrates arguments with keys and values separated by commas +func Example_mappingWithCommas() { + // The args you would pass in on the command line + os.Args = split("./example --values one=two,three=four") + + var args struct { + Values commaSeparated + } + MustParse(&args) + fmt.Println(args.Values.M) + // output: map[one:two three:four] +} + +// This eample demonstrates multiple value arguments that can be mixed with +// other arguments. +func Example_multipleMixed() { + os.Args = split("./example -c cmd1 db1 -f file1 db2 -c cmd2 -f file2 -f file3 db3 -c cmd3") + var args struct { + Commands []string `arg:"-c,separate"` + Files []string `arg:"-f,separate"` + Databases []string `arg:"positional"` + } + MustParse(&args) + fmt.Println("Commands:", args.Commands) + fmt.Println("Files:", args.Files) + fmt.Println("Databases:", args.Databases) + + // output: + // Commands: [cmd1 cmd2 cmd3] + // Files: [file1 file2 file3] + // Databases: [db1 db2 db3] +} + +// This example shows the usage string generated by go-arg +func Example_helpText() { + // These are the args you would pass in on the command line + os.Args = split("./example --help") + + var args struct { + Input string `arg:"positional,required"` + Output []string `arg:"positional"` + Verbose bool `arg:"-v" help:"verbosity level"` + Dataset string `help:"dataset to use"` + Optimize int `arg:"-O,--optim" help:"optimization level"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stdout = os.Stdout + + MustParse(&args) + + // output: + // Usage: example [--verbose] [--dataset DATASET] [--optim OPTIM] INPUT [OUTPUT [OUTPUT ...]] + // + // Positional arguments: + // INPUT + // OUTPUT + // + // Options: + // --verbose, -v verbosity level + // --dataset DATASET dataset to use + // --optim OPTIM, -O OPTIM + // optimization level + // --help, -h display this help and exit +} + +// This example shows the usage string generated by go-arg with customized placeholders +func Example_helpPlaceholder() { + // These are the args you would pass in on the command line + os.Args = split("./example --help") + + var args struct { + Input string `arg:"positional,required" placeholder:"SRC"` + Output []string `arg:"positional" placeholder:"DST"` + Optimize int `arg:"-O" help:"optimization level" placeholder:"LEVEL"` + MaxJobs int `arg:"-j" help:"maximum number of simultaneous jobs" placeholder:"N"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stdout = os.Stdout + + MustParse(&args) + + // output: + + // Usage: example [--optimize LEVEL] [--maxjobs N] SRC [DST [DST ...]] + + // Positional arguments: + // SRC + // DST + + // Options: + // --optimize LEVEL, -O LEVEL + // optimization level + // --maxjobs N, -j N maximum number of simultaneous jobs + // --help, -h display this help and exit +} + +// This example shows the usage string generated by go-arg when using subcommands +func Example_helpTextWithSubcommand() { + // These are the args you would pass in on the command line + os.Args = split("./example --help") + + type getCmd struct { + Item string `arg:"positional" help:"item to fetch"` + } + + type listCmd struct { + Format string `help:"output format"` + Limit int + } + + var args struct { + Verbose bool + Get *getCmd `arg:"subcommand" help:"fetch an item and print it"` + List *listCmd `arg:"subcommand" help:"list available items"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stdout = os.Stdout + + MustParse(&args) + + // output: + // Usage: example [--verbose] [] + // + // Options: + // --verbose + // --help, -h display this help and exit + // + // Commands: + // get fetch an item and print it + // list list available items +} + +// This example shows the usage string generated by go-arg when using subcommands +func Example_helpTextWhenUsingSubcommand() { + // These are the args you would pass in on the command line + os.Args = split("./example get --help") + + type getCmd struct { + Item string `arg:"positional,required" help:"item to fetch"` + } + + type listCmd struct { + Format string `help:"output format"` + Limit int + } + + var args struct { + Verbose bool + Get *getCmd `arg:"subcommand" help:"fetch an item and print it"` + List *listCmd `arg:"subcommand" help:"list available items"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stdout = os.Stdout + + MustParse(&args) + + // output: + // Usage: example get ITEM + // + // Positional arguments: + // ITEM item to fetch + // + // Global options: + // --verbose + // --help, -h display this help and exit +} + +// This example shows how to print help for an explicit subcommand +func Example_writeHelpForSubcommand() { + // These are the args you would pass in on the command line + os.Args = split("./example get --help") + + type getCmd struct { + Item string `arg:"positional" help:"item to fetch"` + } + + type listCmd struct { + Format string `help:"output format"` + Limit int + } + + var args struct { + Verbose bool + Get *getCmd `arg:"subcommand" help:"fetch an item and print it"` + List *listCmd `arg:"subcommand" help:"list available items"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stdout = os.Stdout + + p, err := NewParser(Config{}, &args) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + err = p.WriteHelpForSubcommand(os.Stdout, "list") + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + // output: + // Usage: example list [--format FORMAT] [--limit LIMIT] + // + // Options: + // --format FORMAT output format + // --limit LIMIT + // + // Global options: + // --verbose + // --help, -h display this help and exit +} + +// This example shows how to print help for a subcommand that is nested several levels deep +func Example_writeHelpForSubcommandNested() { + // These are the args you would pass in on the command line + os.Args = split("./example get --help") + + type mostNestedCmd struct { + Item string + } + + type nestedCmd struct { + MostNested *mostNestedCmd `arg:"subcommand"` + } + + type topLevelCmd struct { + Nested *nestedCmd `arg:"subcommand"` + } + + var args struct { + TopLevel *topLevelCmd `arg:"subcommand"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stdout = os.Stdout + + p, err := NewParser(Config{}, &args) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + err = p.WriteHelpForSubcommand(os.Stdout, "toplevel", "nested", "mostnested") + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + // output: + // Usage: example toplevel nested mostnested [--item ITEM] + // + // Options: + // --item ITEM + // --help, -h display this help and exit +} + +// This example shows the error string generated by go-arg when an invalid option is provided +func Example_errorText() { + // These are the args you would pass in on the command line + os.Args = split("./example --optimize INVALID") + + var args struct { + Input string `arg:"positional,required"` + Output []string `arg:"positional"` + Verbose bool `arg:"-v" help:"verbosity level"` + Dataset string `help:"dataset to use"` + Optimize int `arg:"-O,help:optimization level"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stderr = os.Stdout + + MustParse(&args) + + // output: + // Usage: example [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] INPUT [OUTPUT [OUTPUT ...]] + // error: error processing --optimize: strconv.ParseInt: parsing "INVALID": invalid syntax +} + +// This example shows the error string generated by go-arg when an invalid option is provided +func Example_errorTextForSubcommand() { + // These are the args you would pass in on the command line + os.Args = split("./example get --count INVALID") + + type getCmd struct { + Count int + } + + var args struct { + Get *getCmd `arg:"subcommand"` + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stderr = os.Stdout + + MustParse(&args) + + // output: + // Usage: example get [--count COUNT] + // error: error processing --count: strconv.ParseInt: parsing "INVALID": invalid syntax +} + +// This example demonstrates use of subcommands +func Example_subcommand() { + // These are the args you would pass in on the command line + os.Args = split("./example commit -a -m what-this-commit-is-about") + + type CheckoutCmd struct { + Branch string `arg:"positional"` + Track bool `arg:"-t"` + } + type CommitCmd struct { + All bool `arg:"-a"` + Message string `arg:"-m"` + } + type PushCmd struct { + Remote string `arg:"positional"` + Branch string `arg:"positional"` + SetUpstream bool `arg:"-u"` + } + var args struct { + Checkout *CheckoutCmd `arg:"subcommand:checkout"` + Commit *CommitCmd `arg:"subcommand:commit"` + Push *PushCmd `arg:"subcommand:push"` + Quiet bool `arg:"-q"` // this flag is global to all subcommands + } + + // This is only necessary when running inside golang's runnable example harness + osExit = func(int) {} + stderr = os.Stdout + + MustParse(&args) + + switch { + case args.Checkout != nil: + fmt.Printf("checkout requested for branch %s\n", args.Checkout.Branch) + case args.Commit != nil: + fmt.Printf("commit requested with message \"%s\"\n", args.Commit.Message) + case args.Push != nil: + fmt.Printf("push requested from %s to %s\n", args.Push.Branch, args.Push.Remote) + } + + // output: + // commit requested with message "what-this-commit-is-about" +} + +func Example_allSupportedTypes() { + // These are the args you would pass in on the command line + os.Args = []string{} + + var args struct { + Bool bool + Byte byte + Rune rune + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Float32 float32 + Float64 float64 + String string + Duration time.Duration + URL url.URL + Email mail.Address + MAC net.HardwareAddr + } + + // go-arg supports each of the types above, as well as pointers to any of + // the above and slices of any of the above. It also supports any types that + // implements encoding.TextUnmarshaler. + + MustParse(&args) + + // output: +} diff --git a/v2/go.mod b/v2/go.mod new file mode 100644 index 0000000..7e575a8 --- /dev/null +++ b/v2/go.mod @@ -0,0 +1,8 @@ +module github.com/alexflint/go-arg/v2 + +require ( + github.com/alexflint/go-scalar v1.2.0 + github.com/stretchr/testify v1.7.0 +) + +go 1.13 diff --git a/v2/go.sum b/v2/go.sum new file mode 100644 index 0000000..385ca8f --- /dev/null +++ b/v2/go.sum @@ -0,0 +1,15 @@ +github.com/alexflint/go-scalar v1.2.0 h1:WR7JPKkeNpnYIOfHRa7ivM21aWAdHD0gEWHCx+WQBRw= +github.com/alexflint/go-scalar v1.2.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/v2/parse.go b/v2/parse.go new file mode 100644 index 0000000..8e190f2 --- /dev/null +++ b/v2/parse.go @@ -0,0 +1,741 @@ +package arg + +import ( + "encoding" + "encoding/csv" + "errors" + "fmt" + "os" + "path/filepath" + "reflect" + "strings" + + scalar "github.com/alexflint/go-scalar" +) + +// path represents a sequence of steps to find the output location for an +// argument or subcommand in the final destination struct +type path struct { + fields []reflect.StructField // sequence of struct fields to traverse +} + +// String gets a string representation of the given path +func (p path) String() string { + s := "args" + for _, f := range p.fields { + s += "." + f.Name + } + return s +} + +// Child gets a new path representing a child of this path. +func (p path) Child(f reflect.StructField) path { + // copy the entire slice of fields to avoid possible slice overwrite + subfields := make([]reflect.StructField, len(p.fields)+1) + copy(subfields, p.fields) + subfields[len(subfields)-1] = f + return path{ + fields: subfields, + } +} + +// spec represents a command line option +type spec struct { + dest path + field reflect.StructField // the struct field from which this option was created + long string // the --long form for this option, or empty if none + short string // the -s short form for this option, or empty if none + cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple) + required bool // if true, this option must be present on the command line + positional bool // if true, this option will be looked for in the positional flags + separate bool // if true, each slice and map entry will have its own --flag + help string // the help text for this option + env string // the name of the environment variable for this option, or empty for none + defaultVal string // default value for this option + placeholder string // name of the data in help +} + +// command represents a named subcommand, or the top-level command +type command struct { + name string + help string + dest path + specs []*spec + subcommands []*command + parent *command +} + +// ErrHelp indicates that -h or --help were provided +var ErrHelp = errors.New("help requested by user") + +// ErrVersion indicates that --version was provided +var ErrVersion = errors.New("version requested by user") + +// MustParse processes command line arguments and exits upon failure +func MustParse(dest interface{}) *Parser { + p, err := NewParser(Config{}, dest) + if err != nil { + fmt.Fprintln(stdout, err) + osExit(-1) + return nil // just in case osExit was monkey-patched + } + + err = p.Parse(flags()) + switch { + case err == ErrHelp: + p.writeHelpForSubcommand(stdout, p.lastCmd) + osExit(0) + case err == ErrVersion: + fmt.Fprintln(stdout, p.version) + osExit(0) + case err != nil: + p.failWithSubcommand(err.Error(), p.lastCmd) + } + + return p +} + +// Parse processes command line arguments and stores them in dest +func Parse(dest interface{}) error { + p, err := NewParser(Config{}, dest) + if err != nil { + return err + } + return p.Parse(flags()) +} + +// flags gets all command line arguments other than the first (program name) +func flags() []string { + if len(os.Args) == 0 { // os.Args could be empty + return nil + } + return os.Args[1:] +} + +// Config represents configuration options for an argument parser +type Config struct { + // Program is the name of the program used in the help text + Program string + + // IgnoreEnv instructs the library not to read environment variables + IgnoreEnv bool + + // IgnoreDefault instructs the library not to reset the variables to the + // default values, including pointers to sub commands + IgnoreDefault bool +} + +// Parser represents a set of command line options with destination values +type Parser struct { + cmd *command + root reflect.Value // destination struct to fill will values + config Config + version string + description string + epilogue string + + // the following field changes during processing of command line arguments + lastCmd *command +} + +// Versioned is the interface that the destination struct should implement to +// make a version string appear at the top of the help message. +type Versioned interface { + // Version returns the version string that will be printed on a line by itself + // at the top of the help message. + Version() string +} + +// Described is the interface that the destination struct should implement to +// make a description string appear at the top of the help message. +type Described interface { + // Description returns the string that will be printed on a line by itself + // at the top of the help message. + Description() string +} + +// Epilogued is the interface that the destination struct should implement to +// add an epilogue string at the bottom of the help message. +type Epilogued interface { + // Epilogue returns the string that will be printed on a line by itself + // at the end of the help message. + Epilogue() string +} + +// walkFields calls a function for each field of a struct, recursively expanding struct fields. +func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) { + walkFieldsImpl(t, visit, nil) +} + +func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool, path []int) { + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + field.Index = make([]int, len(path)+1) + copy(field.Index, append(path, i)) + expand := visit(field, t) + if expand && field.Type.Kind() == reflect.Struct { + var subpath []int + if field.Anonymous { + subpath = append(path, i) + } + walkFieldsImpl(field.Type, visit, subpath) + } + } +} + +// NewParser constructs a parser from a list of destination structs +func NewParser(config Config, dest interface{}) (*Parser, error) { + // first pick a name for the command for use in the usage text + var name string + switch { + case config.Program != "": + name = config.Program + case len(os.Args) > 0: + name = filepath.Base(os.Args[0]) + default: + name = "program" + } + + // construct a parser + p := Parser{ + cmd: &command{name: name}, + config: config, + } + + // make a list of roots + p.root = reflect.ValueOf(dest) + + // process each of the destination values + t := reflect.TypeOf(dest) + if t.Kind() != reflect.Ptr { + panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t)) + } + + cmd, err := cmdFromStruct(name, path{}, t) + if err != nil { + return nil, err + } + + // add nonzero field values as defaults + for _, spec := range cmd.specs { + if v := p.val(spec.dest); v.IsValid() && !isZero(v) { + if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok { + str, err := defaultVal.MarshalText() + if err != nil { + return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err) + } + spec.defaultVal = string(str) + } else { + spec.defaultVal = fmt.Sprintf("%v", v) + } + } + } + + p.cmd.specs = append(p.cmd.specs, cmd.specs...) + p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...) + + if dest, ok := dest.(Versioned); ok { + p.version = dest.Version() + } + if dest, ok := dest.(Described); ok { + p.description = dest.Description() + } + if dest, ok := dest.(Epilogued); ok { + p.epilogue = dest.Epilogue() + } + + return &p, nil +} + +func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { + // commands can only be created from pointers to structs + if t.Kind() != reflect.Ptr { + return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s", + dest, t.Kind()) + } + + t = t.Elem() + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a pointer to %s", + dest, t.Kind()) + } + + cmd := command{ + name: name, + dest: dest, + } + + var errs []string + walkFields(t, func(field reflect.StructField, t reflect.Type) bool { + // check for the ignore switch in the tag + tag := field.Tag.Get("arg") + if tag == "-" { + return false + } + + // if this is an embedded struct then recurse into its fields, even if + // it is unexported, because exported fields on unexported embedded + // structs are still writable + if field.Anonymous && field.Type.Kind() == reflect.Struct { + return true + } + + // ignore any other unexported field + if !isExported(field.Name) { + return false + } + + // duplicate the entire path to avoid slice overwrites + subdest := dest.Child(field) + spec := spec{ + dest: subdest, + field: field, + long: strings.ToLower(field.Name), + } + + help, exists := field.Tag.Lookup("help") + if exists { + spec.help = help + } + + defaultVal, hasDefault := field.Tag.Lookup("default") + if hasDefault { + spec.defaultVal = defaultVal + } + + // Look at the tag + var isSubcommand bool // tracks whether this field is a subcommand + for _, key := range strings.Split(tag, ",") { + if key == "" { + continue + } + key = strings.TrimLeft(key, " ") + var value string + if pos := strings.Index(key, ":"); pos != -1 { + value = key[pos+1:] + key = key[:pos] + } + + switch { + case strings.HasPrefix(key, "---"): + errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name)) + case strings.HasPrefix(key, "--"): + spec.long = key[2:] + case strings.HasPrefix(key, "-"): + if len(key) != 2 { + errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", + t.Name(), field.Name)) + return false + } + spec.short = key[1:] + case key == "required": + if hasDefault { + errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified", + t.Name(), field.Name)) + return false + } + spec.required = true + case key == "positional": + spec.positional = true + case key == "separate": + spec.separate = true + case key == "help": // deprecated + spec.help = value + case key == "env": + // Use override name if provided + if value != "" { + spec.env = value + } else { + spec.env = strings.ToUpper(field.Name) + } + case key == "subcommand": + // decide on a name for the subcommand + cmdname := value + if cmdname == "" { + cmdname = strings.ToLower(field.Name) + } + + // parse the subcommand recursively + subcmd, err := cmdFromStruct(cmdname, subdest, field.Type) + if err != nil { + errs = append(errs, err.Error()) + return false + } + + subcmd.parent = &cmd + subcmd.help = field.Tag.Get("help") + + cmd.subcommands = append(cmd.subcommands, subcmd) + isSubcommand = true + default: + errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) + return false + } + } + + placeholder, hasPlaceholder := field.Tag.Lookup("placeholder") + if hasPlaceholder { + spec.placeholder = placeholder + } else if spec.long != "" { + spec.placeholder = strings.ToUpper(spec.long) + } else { + spec.placeholder = strings.ToUpper(spec.field.Name) + } + + // Check whether this field is supported. It's good to do this here rather than + // wait until ParseValue because it means that a program with invalid argument + // fields will always fail regardless of whether the arguments it received + // exercised those fields. + if !isSubcommand { + cmd.specs = append(cmd.specs, &spec) + + var err error + spec.cardinality, err = cardinalityOf(field.Type) + if err != nil { + errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", + t.Name(), field.Name, field.Type.String())) + return false + } + if spec.cardinality == multiple && hasDefault { + errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields", + t.Name(), field.Name)) + return false + } + } + + // if this was an embedded field then we already returned true up above + return false + }) + + if len(errs) > 0 { + return nil, errors.New(strings.Join(errs, "\n")) + } + + // check that we don't have both positionals and subcommands + var hasPositional bool + for _, spec := range cmd.specs { + if spec.positional { + hasPositional = true + } + } + if hasPositional && len(cmd.subcommands) > 0 { + return nil, fmt.Errorf("%s cannot have both subcommands and positional arguments", dest) + } + + return &cmd, nil +} + +// Parse processes the given command line option, storing the results in the field +// of the structs from which NewParser was constructed +func (p *Parser) Parse(args []string) error { + err := p.process(args) + if err != nil { + // If -h or --help were specified then make sure help text supercedes other errors + for _, arg := range args { + if arg == "-h" || arg == "--help" { + return ErrHelp + } + if arg == "--" { + break + } + } + } + return err +} + +// process environment vars for the given arguments +func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error { + for _, spec := range specs { + if spec.env == "" { + continue + } + + value, found := os.LookupEnv(spec.env) + if !found { + continue + } + + if spec.cardinality == multiple { + // expect a CSV string in an environment + // variable in the case of multiple values + var values []string + var err error + if len(strings.TrimSpace(value)) > 0 { + values, err = csv.NewReader(strings.NewReader(value)).Read() + if err != nil { + return fmt.Errorf( + "error reading a CSV string from environment variable %s with multiple values: %v", + spec.env, + err, + ) + } + } + if err = setSliceOrMap(p.val(spec.dest), values, !spec.separate); err != nil { + return fmt.Errorf( + "error processing environment variable %s with multiple values: %v", + spec.env, + err, + ) + } + } else { + if err := scalar.ParseValue(p.val(spec.dest), value); err != nil { + return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) + } + } + wasPresent[spec] = true + } + + return nil +} + +// process goes through arguments one-by-one, parses them, and assigns the result to +// the underlying struct field +func (p *Parser) process(args []string) error { + // track the options we have seen + wasPresent := make(map[*spec]bool) + + // union of specs for the chain of subcommands encountered so far + curCmd := p.cmd + p.lastCmd = curCmd + + // make a copy of the specs because we will add to this list each time we expand a subcommand + specs := make([]*spec, len(curCmd.specs)) + copy(specs, curCmd.specs) + + // deal with environment vars + if !p.config.IgnoreEnv { + err := p.captureEnvVars(specs, wasPresent) + if err != nil { + return err + } + } + + // process each string from the command line + var allpositional bool + var positionals []string + + // must use explicit for loop, not range, because we manipulate i inside the loop + for i := 0; i < len(args); i++ { + arg := args[i] + if arg == "--" { + allpositional = true + continue + } + + if !isFlag(arg) || allpositional { + // each subcommand can have either subcommands or positionals, but not both + if len(curCmd.subcommands) == 0 { + positionals = append(positionals, arg) + continue + } + + // if we have a subcommand then make sure it is valid for the current context + subcmd := findSubcommand(curCmd.subcommands, arg) + if subcmd == nil { + return fmt.Errorf("invalid subcommand: %s", arg) + } + + // instantiate the field to point to a new struct + v := p.val(subcmd.dest) + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers + } + + // add the new options to the set of allowed options + specs = append(specs, subcmd.specs...) + + // capture environment vars for these new options + if !p.config.IgnoreEnv { + err := p.captureEnvVars(subcmd.specs, wasPresent) + if err != nil { + return err + } + } + + curCmd = subcmd + p.lastCmd = curCmd + continue + } + + // check for special --help and --version flags + switch arg { + case "-h", "--help": + return ErrHelp + case "--version": + return ErrVersion + } + + // check for an equals sign, as in "--foo=bar" + var value string + opt := strings.TrimLeft(arg, "-") + if pos := strings.Index(opt, "="); pos != -1 { + value = opt[pos+1:] + opt = opt[:pos] + } + + // lookup the spec for this option (note that the "specs" slice changes as + // we expand subcommands so it is better not to use a map) + spec := findOption(specs, opt) + if spec == nil { + return fmt.Errorf("unknown argument %s", arg) + } + wasPresent[spec] = true + + // deal with the case of multiple values + if spec.cardinality == multiple { + var values []string + if value == "" { + for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { + values = append(values, args[i+1]) + i++ + if spec.separate { + break + } + } + } else { + values = append(values, value) + } + err := setSliceOrMap(p.val(spec.dest), values, !spec.separate) + if err != nil { + return fmt.Errorf("error processing %s: %v", arg, err) + } + continue + } + + // if it's a flag and it has no value then set the value to true + // use boolean because this takes account of TextUnmarshaler + if spec.cardinality == zero && value == "" { + value = "true" + } + + // if we have something like "--foo" then the value is the next argument + if value == "" { + if i+1 == len(args) { + return fmt.Errorf("missing value for %s", arg) + } + if !nextIsNumeric(spec.field.Type, args[i+1]) && isFlag(args[i+1]) { + return fmt.Errorf("missing value for %s", arg) + } + value = args[i+1] + i++ + } + + err := scalar.ParseValue(p.val(spec.dest), value) + if err != nil { + return fmt.Errorf("error processing %s: %v", arg, err) + } + } + + // process positionals + for _, spec := range specs { + if !spec.positional { + continue + } + if len(positionals) == 0 { + break + } + wasPresent[spec] = true + if spec.cardinality == multiple { + err := setSliceOrMap(p.val(spec.dest), positionals, true) + if err != nil { + return fmt.Errorf("error processing %s: %v", spec.field.Name, err) + } + positionals = nil + } else { + err := scalar.ParseValue(p.val(spec.dest), positionals[0]) + if err != nil { + return fmt.Errorf("error processing %s: %v", spec.field.Name, err) + } + positionals = positionals[1:] + } + } + if len(positionals) > 0 { + return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) + } + + // fill in defaults and check that all the required args were provided + for _, spec := range specs { + if wasPresent[spec] { + continue + } + + name := strings.ToLower(spec.field.Name) + if spec.long != "" && !spec.positional { + name = "--" + spec.long + } + + if spec.required { + msg := fmt.Sprintf("%s is required", name) + if spec.env != "" { + msg += " (or environment variable " + spec.env + ")" + } + return errors.New(msg) + } + if !p.config.IgnoreDefault && spec.defaultVal != "" { + err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal) + if err != nil { + return fmt.Errorf("error processing default value for %s: %v", name, err) + } + } + } + + return nil +} + +func nextIsNumeric(t reflect.Type, s string) bool { + switch t.Kind() { + case reflect.Ptr: + return nextIsNumeric(t.Elem(), s) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + v := reflect.New(t) + err := scalar.ParseValue(v, s) + return err == nil + default: + return false + } +} + +// isFlag returns true if a token is a flag such as "-v" or "--user" but not "-" or "--" +func isFlag(s string) bool { + return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != "" +} + +// val returns a reflect.Value corresponding to the current value for the +// given path +func (p *Parser) val(dest path) reflect.Value { + v := p.root + for _, field := range dest.fields { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return reflect.Value{} + } + v = v.Elem() + } + + v = v.FieldByIndex(field.Index) + } + return v +} + +// findOption finds an option from its name, or returns null if no spec is found +func findOption(specs []*spec, name string) *spec { + for _, spec := range specs { + if spec.positional { + continue + } + if spec.long == name || spec.short == name { + return spec + } + } + return nil +} + +// findSubcommand finds a subcommand using its name, or returns null if no subcommand is found +func findSubcommand(cmds []*command, name string) *command { + for _, cmd := range cmds { + if cmd.name == name { + return cmd + } + } + return nil +} diff --git a/v2/parse_test.go b/v2/parse_test.go new file mode 100644 index 0000000..4ea6bc4 --- /dev/null +++ b/v2/parse_test.go @@ -0,0 +1,1486 @@ +package arg + +import ( + "bytes" + "fmt" + "net" + "net/mail" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setenv(t *testing.T, name, val string) { + if err := os.Setenv(name, val); err != nil { + t.Error(err) + } +} + +func parse(cmdline string, dest interface{}) error { + _, err := pparse(cmdline, dest) + return err +} + +func pparse(cmdline string, dest interface{}) (*Parser, error) { + return parseWithEnv(cmdline, nil, dest) +} + +func parseWithEnv(cmdline string, env []string, dest interface{}) (*Parser, error) { + p, err := NewParser(Config{}, dest) + if err != nil { + return nil, err + } + + // split the command line + var parts []string + if len(cmdline) > 0 { + parts = strings.Split(cmdline, " ") + } + + // split the environment vars + for _, s := range env { + pos := strings.Index(s, "=") + if pos == -1 { + return nil, fmt.Errorf("missing equals sign in %q", s) + } + err := os.Setenv(s[:pos], s[pos+1:]) + if err != nil { + return nil, err + } + } + + // execute the parser + return p, p.Parse(parts) +} + +func TestString(t *testing.T) { + var args struct { + Foo string + Ptr *string + } + err := parse("--foo bar --ptr baz", &args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) + assert.Equal(t, "baz", *args.Ptr) +} + +func TestBool(t *testing.T) { + var args struct { + A bool + B bool + C *bool + D *bool + } + err := parse("--a --c", &args) + require.NoError(t, err) + assert.True(t, args.A) + assert.False(t, args.B) + assert.True(t, *args.C) + assert.Nil(t, args.D) +} + +func TestInt(t *testing.T) { + var args struct { + Foo int + Ptr *int + } + err := parse("--foo 7 --ptr 8", &args) + require.NoError(t, err) + assert.EqualValues(t, 7, args.Foo) + assert.EqualValues(t, 8, *args.Ptr) +} + +func TestHexOctBin(t *testing.T) { + var args struct { + Hex int + Oct int + Bin int + Underscored int + } + err := parse("--hex 0xA --oct 0o10 --bin 0b101 --underscored 123_456", &args) + require.NoError(t, err) + assert.EqualValues(t, 10, args.Hex) + assert.EqualValues(t, 8, args.Oct) + assert.EqualValues(t, 5, args.Bin) + assert.EqualValues(t, 123456, args.Underscored) +} + +func TestNegativeInt(t *testing.T) { + var args struct { + Foo int + } + err := parse("-foo -100", &args) + require.NoError(t, err) + assert.EqualValues(t, args.Foo, -100) +} + +func TestNegativeIntAndFloatAndTricks(t *testing.T) { + var args struct { + Foo int + Bar float64 + N int `arg:"--100"` + } + err := parse("-foo -100 -bar -60.14 -100 -100", &args) + require.NoError(t, err) + assert.EqualValues(t, args.Foo, -100) + assert.EqualValues(t, args.Bar, -60.14) + assert.EqualValues(t, args.N, -100) +} + +func TestUint(t *testing.T) { + var args struct { + Foo uint + Ptr *uint + } + err := parse("--foo 7 --ptr 8", &args) + require.NoError(t, err) + assert.EqualValues(t, 7, args.Foo) + assert.EqualValues(t, 8, *args.Ptr) +} + +func TestFloat(t *testing.T) { + var args struct { + Foo float32 + Ptr *float32 + } + err := parse("--foo 3.4 --ptr 3.5", &args) + require.NoError(t, err) + assert.EqualValues(t, 3.4, args.Foo) + assert.EqualValues(t, 3.5, *args.Ptr) +} + +func TestDuration(t *testing.T) { + var args struct { + Foo time.Duration + Ptr *time.Duration + } + err := parse("--foo 3ms --ptr 4ms", &args) + require.NoError(t, err) + assert.Equal(t, 3*time.Millisecond, args.Foo) + assert.Equal(t, 4*time.Millisecond, *args.Ptr) +} + +func TestInvalidDuration(t *testing.T) { + var args struct { + Foo time.Duration + } + err := parse("--foo xxx", &args) + require.Error(t, err) +} + +func TestIntPtr(t *testing.T) { + var args struct { + Foo *int + } + err := parse("--foo 123", &args) + require.NoError(t, err) + require.NotNil(t, args.Foo) + assert.Equal(t, 123, *args.Foo) +} + +func TestIntPtrNotPresent(t *testing.T) { + var args struct { + Foo *int + } + err := parse("", &args) + require.NoError(t, err) + assert.Nil(t, args.Foo) +} + +func TestMixed(t *testing.T) { + var args struct { + Foo string `arg:"-f"` + Bar int + Baz uint `arg:"positional"` + Ham bool + Spam float32 + } + args.Bar = 3 + err := parse("123 -spam=1.2 -ham -f xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) + assert.Equal(t, 3, args.Bar) + assert.Equal(t, uint(123), args.Baz) + assert.Equal(t, true, args.Ham) + assert.EqualValues(t, 1.2, args.Spam) +} + +func TestRequired(t *testing.T) { + var args struct { + Foo string `arg:"required"` + } + err := parse("", &args) + require.Error(t, err, "--foo is required") +} + +func TestRequiredWithEnv(t *testing.T) { + var args struct { + Foo string `arg:"required,env:FOO"` + } + err := parse("", &args) + require.Error(t, err, "--foo is required (or environment variable FOO)") +} + +func TestShortFlag(t *testing.T) { + var args struct { + Foo string `arg:"-f"` + } + + err := parse("-f xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) + + err = parse("-foo xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) + + err = parse("--foo xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) +} + +func TestInvalidShortFlag(t *testing.T) { + var args struct { + Foo string `arg:"-foo"` + } + err := parse("", &args) + assert.Error(t, err) +} + +func TestLongFlag(t *testing.T) { + var args struct { + Foo string `arg:"--abc"` + } + + err := parse("-abc xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) + + err = parse("--abc xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) +} + +func TestSlice(t *testing.T) { + var args struct { + Strings []string + } + err := parse("--strings a b c", &args) + require.NoError(t, err) + assert.Equal(t, []string{"a", "b", "c"}, args.Strings) +} +func TestSliceOfBools(t *testing.T) { + var args struct { + B []bool + } + + err := parse("--b true false true", &args) + require.NoError(t, err) + assert.Equal(t, []bool{true, false, true}, args.B) +} + +func TestMap(t *testing.T) { + var args struct { + Values map[string]int + } + err := parse("--values a=1 b=2 c=3", &args) + require.NoError(t, err) + assert.Len(t, args.Values, 3) + assert.Equal(t, 1, args.Values["a"]) + assert.Equal(t, 2, args.Values["b"]) + assert.Equal(t, 3, args.Values["c"]) +} + +func TestMapPositional(t *testing.T) { + var args struct { + Values map[string]int `arg:"positional"` + } + err := parse("a=1 b=2 c=3", &args) + require.NoError(t, err) + assert.Len(t, args.Values, 3) + assert.Equal(t, 1, args.Values["a"]) + assert.Equal(t, 2, args.Values["b"]) + assert.Equal(t, 3, args.Values["c"]) +} + +func TestMapWithSeparate(t *testing.T) { + var args struct { + Values map[string]int `arg:"separate"` + } + err := parse("--values a=1 --values b=2 --values c=3", &args) + require.NoError(t, err) + assert.Len(t, args.Values, 3) + assert.Equal(t, 1, args.Values["a"]) + assert.Equal(t, 2, args.Values["b"]) + assert.Equal(t, 3, args.Values["c"]) +} + +func TestPlaceholder(t *testing.T) { + var args struct { + Input string `arg:"positional" placeholder:"SRC"` + Output []string `arg:"positional" placeholder:"DST"` + Optimize int `arg:"-O" placeholder:"LEVEL"` + MaxJobs int `arg:"-j" placeholder:"N"` + } + err := parse("-O 5 --maxjobs 2 src dest1 dest2", &args) + assert.NoError(t, err) +} + +func TestNoLongName(t *testing.T) { + var args struct { + ShortOnly string `arg:"-s,--"` + EnvOnly string `arg:"--,env"` + } + setenv(t, "ENVONLY", "TestVal") + err := parse("-s TestVal2", &args) + assert.NoError(t, err) + assert.Equal(t, "TestVal", args.EnvOnly) + assert.Equal(t, "TestVal2", args.ShortOnly) +} + +func TestCaseSensitive(t *testing.T) { + var args struct { + Lower bool `arg:"-v"` + Upper bool `arg:"-V"` + } + + err := parse("-v", &args) + require.NoError(t, err) + assert.True(t, args.Lower) + assert.False(t, args.Upper) +} + +func TestCaseSensitive2(t *testing.T) { + var args struct { + Lower bool `arg:"-v"` + Upper bool `arg:"-V"` + } + + err := parse("-V", &args) + require.NoError(t, err) + assert.False(t, args.Lower) + assert.True(t, args.Upper) +} + +func TestPositional(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Output string `arg:"positional"` + } + err := parse("foo", &args) + require.NoError(t, err) + assert.Equal(t, "foo", args.Input) + assert.Equal(t, "", args.Output) +} + +func TestPositionalPointer(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Output []*string `arg:"positional"` + } + err := parse("foo bar baz", &args) + require.NoError(t, err) + assert.Equal(t, "foo", args.Input) + bar := "bar" + baz := "baz" + assert.Equal(t, []*string{&bar, &baz}, args.Output) +} + +func TestRequiredPositional(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Output string `arg:"positional,required"` + } + err := parse("foo", &args) + assert.Error(t, err) +} + +func TestRequiredPositionalMultiple(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Multiple []string `arg:"positional,required"` + } + err := parse("foo", &args) + assert.Error(t, err) +} + +func TestTooManyPositional(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Output string `arg:"positional"` + } + err := parse("foo bar baz", &args) + assert.Error(t, err) +} + +func TestMultiple(t *testing.T) { + var args struct { + Foo []int + Bar []string + } + err := parse("--foo 1 2 3 --bar x y z", &args) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, args.Foo) + assert.Equal(t, []string{"x", "y", "z"}, args.Bar) +} + +func TestMultiplePositionals(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Multiple []string `arg:"positional,required"` + } + err := parse("foo a b c", &args) + assert.NoError(t, err) + assert.Equal(t, "foo", args.Input) + assert.Equal(t, []string{"a", "b", "c"}, args.Multiple) +} + +func TestMultipleWithEq(t *testing.T) { + var args struct { + Foo []int + Bar []string + } + err := parse("--foo 1 2 3 --bar=x", &args) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, args.Foo) + assert.Equal(t, []string{"x"}, args.Bar) +} + +func TestMultipleWithDefault(t *testing.T) { + var args struct { + Foo []int + Bar []string + } + args.Foo = []int{42} + args.Bar = []string{"foo"} + err := parse("--foo 1 2 3 --bar x y z", &args) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, args.Foo) + assert.Equal(t, []string{"x", "y", "z"}, args.Bar) +} + +func TestExemptField(t *testing.T) { + var args struct { + Foo string + Bar interface{} `arg:"-"` + } + err := parse("--foo xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) +} + +func TestUnknownField(t *testing.T) { + var args struct { + Foo string + } + err := parse("--bar xyz", &args) + assert.Error(t, err) +} + +func TestMissingRequired(t *testing.T) { + var args struct { + Foo string `arg:"required"` + X []string `arg:"positional"` + } + err := parse("x", &args) + assert.Error(t, err) +} + +func TestNonsenseKey(t *testing.T) { + var args struct { + X []string `arg:"positional, nonsense"` + } + err := parse("x", &args) + assert.Error(t, err) +} + +func TestMissingValueAtEnd(t *testing.T) { + var args struct { + Foo string + } + err := parse("--foo", &args) + assert.Error(t, err) +} + +func TestMissingValueInMiddle(t *testing.T) { + var args struct { + Foo string + Bar string + } + err := parse("--foo --bar=abc", &args) + assert.Error(t, err) +} + +func TestNegativeValue(t *testing.T) { + var args struct { + Foo int + } + err := parse("--foo -123", &args) + require.NoError(t, err) + assert.Equal(t, -123, args.Foo) +} + +func TestInvalidInt(t *testing.T) { + var args struct { + Foo int + } + err := parse("--foo=xyz", &args) + assert.Error(t, err) +} + +func TestInvalidUint(t *testing.T) { + var args struct { + Foo uint + } + err := parse("--foo=xyz", &args) + assert.Error(t, err) +} + +func TestInvalidFloat(t *testing.T) { + var args struct { + Foo float64 + } + err := parse("--foo xyz", &args) + require.Error(t, err) +} + +func TestInvalidBool(t *testing.T) { + var args struct { + Foo bool + } + err := parse("--foo=xyz", &args) + require.Error(t, err) +} + +func TestInvalidIntSlice(t *testing.T) { + var args struct { + Foo []int + } + err := parse("--foo 1 2 xyz", &args) + require.Error(t, err) +} + +func TestInvalidPositional(t *testing.T) { + var args struct { + Foo int `arg:"positional"` + } + err := parse("xyz", &args) + require.Error(t, err) +} + +func TestInvalidPositionalSlice(t *testing.T) { + var args struct { + Foo []int `arg:"positional"` + } + err := parse("1 2 xyz", &args) + require.Error(t, err) +} + +func TestNoMoreOptions(t *testing.T) { + var args struct { + Foo string + Bar []string `arg:"positional"` + } + err := parse("abc -- --foo xyz", &args) + require.NoError(t, err) + assert.Equal(t, "", args.Foo) + assert.Equal(t, []string{"abc", "--foo", "xyz"}, args.Bar) +} + +func TestNoMoreOptionsBeforeHelp(t *testing.T) { + var args struct { + Foo int + } + err := parse("not_an_integer -- --help", &args) + assert.NotEqual(t, ErrHelp, err) +} + +func TestHelpFlag(t *testing.T) { + var args struct { + Foo string + Bar interface{} `arg:"-"` + } + err := parse("--help", &args) + assert.Equal(t, ErrHelp, err) +} + +func TestPanicOnNonPointer(t *testing.T) { + var args struct{} + assert.Panics(t, func() { + _ = parse("", args) + }) +} + +func TestErrorOnNonStruct(t *testing.T) { + var args string + err := parse("", &args) + assert.Error(t, err) +} + +func TestUnsupportedType(t *testing.T) { + var args struct { + Foo interface{} + } + err := parse("--foo", &args) + assert.Error(t, err) +} + +func TestUnsupportedSliceElement(t *testing.T) { + var args struct { + Foo []interface{} + } + err := parse("--foo 3", &args) + assert.Error(t, err) +} + +func TestUnsupportedSliceElementMissingValue(t *testing.T) { + var args struct { + Foo []interface{} + } + err := parse("--foo", &args) + assert.Error(t, err) +} + +func TestUnknownTag(t *testing.T) { + var args struct { + Foo string `arg:"this_is_not_valid"` + } + err := parse("--foo xyz", &args) + assert.Error(t, err) +} + +func TestParse(t *testing.T) { + var args struct { + Foo string + } + os.Args = []string{"example", "--foo", "bar"} + err := Parse(&args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + +func TestParseError(t *testing.T) { + var args struct { + Foo string `arg:"this_is_not_valid"` + } + os.Args = []string{"example", "--bar"} + err := Parse(&args) + assert.Error(t, err) +} + +func TestMustParse(t *testing.T) { + var args struct { + Foo string + } + os.Args = []string{"example", "--foo", "bar"} + parser := MustParse(&args) + assert.Equal(t, "bar", args.Foo) + assert.NotNil(t, parser) +} + +func TestEnvironmentVariable(t *testing.T) { + var args struct { + Foo string `arg:"env"` + } + _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + +func TestEnvironmentVariableNotPresent(t *testing.T) { + var args struct { + NotPresent string `arg:"env"` + } + _, err := parseWithEnv("", nil, &args) + require.NoError(t, err) + assert.Equal(t, "", args.NotPresent) +} + +func TestEnvironmentVariableOverrideName(t *testing.T) { + var args struct { + Foo string `arg:"env:BAZ"` + } + _, err := parseWithEnv("", []string{"BAZ=bar"}, &args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + +func TestEnvironmentVariableOverrideArgument(t *testing.T) { + var args struct { + Foo string `arg:"env"` + } + _, err := parseWithEnv("--foo zzz", []string{"FOO=bar"}, &args) + require.NoError(t, err) + assert.Equal(t, "zzz", args.Foo) +} + +func TestEnvironmentVariableError(t *testing.T) { + var args struct { + Foo int `arg:"env"` + } + _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + assert.Error(t, err) +} + +func TestEnvironmentVariableRequired(t *testing.T) { + var args struct { + Foo string `arg:"env,required"` + } + _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + +func TestEnvironmentVariableSliceArgumentString(t *testing.T) { + var args struct { + Foo []string `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=bar,"baz, qux"`}, &args) + require.NoError(t, err) + assert.Equal(t, []string{"bar", "baz, qux"}, args.Foo) +} + +func TestEnvironmentVariableSliceEmpty(t *testing.T) { + var args struct { + Foo []string `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=`}, &args) + require.NoError(t, err) + assert.Len(t, args.Foo, 0) +} + +func TestEnvironmentVariableSliceArgumentInteger(t *testing.T) { + var args struct { + Foo []int `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=1,99`}, &args) + require.NoError(t, err) + assert.Equal(t, []int{1, 99}, args.Foo) +} + +func TestEnvironmentVariableSliceArgumentFloat(t *testing.T) { + var args struct { + Foo []float32 `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=1.1,99.9`}, &args) + require.NoError(t, err) + assert.Equal(t, []float32{1.1, 99.9}, args.Foo) +} + +func TestEnvironmentVariableSliceArgumentBool(t *testing.T) { + var args struct { + Foo []bool `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=true,false,0,1`}, &args) + require.NoError(t, err) + assert.Equal(t, []bool{true, false, false, true}, args.Foo) +} + +func TestEnvironmentVariableSliceArgumentWrongCsv(t *testing.T) { + var args struct { + Foo []int `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=1,99\"`}, &args) + assert.Error(t, err) +} + +func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) { + var args struct { + Foo []bool `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=one,two`}, &args) + assert.Error(t, err) +} + +func TestEnvironmentVariableMap(t *testing.T) { + var args struct { + Foo map[int]string `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=1=one,99=ninetynine`}, &args) + require.NoError(t, err) + assert.Len(t, args.Foo, 2) + assert.Equal(t, "one", args.Foo[1]) + assert.Equal(t, "ninetynine", args.Foo[99]) +} + +func TestEnvironmentVariableEmptyMap(t *testing.T) { + var args struct { + Foo map[int]string `arg:"env"` + } + _, err := parseWithEnv("", []string{`FOO=`}, &args) + require.NoError(t, err) + assert.Len(t, args.Foo, 0) +} + +func TestEnvironmentVariableIgnored(t *testing.T) { + var args struct { + Foo string `arg:"env"` + } + setenv(t, "FOO", "abc") + + p, err := NewParser(Config{IgnoreEnv: true}, &args) + require.NoError(t, err) + + err = p.Parse(nil) + assert.NoError(t, err) + assert.Equal(t, "", args.Foo) +} + +func TestDefaultValuesIgnored(t *testing.T) { + var args struct { + Foo string `default:"bad"` + } + + p, err := NewParser(Config{IgnoreDefault: true}, &args) + require.NoError(t, err) + + err = p.Parse(nil) + assert.NoError(t, err) + assert.Equal(t, "", args.Foo) +} + +func TestEnvironmentVariableInSubcommandIgnored(t *testing.T) { + var args struct { + Sub *struct { + Foo string `arg:"env"` + } `arg:"subcommand"` + } + setenv(t, "FOO", "abc") + + p, err := NewParser(Config{IgnoreEnv: true}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"sub"}) + assert.NoError(t, err) + assert.Equal(t, "", args.Sub.Foo) +} + +type textUnmarshaler struct { + val int +} + +func (f *textUnmarshaler) UnmarshalText(b []byte) error { + f.val = len(b) + return nil +} + +func TestTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo textUnmarshaler + } + err := parse("--foo abc", &args) + require.NoError(t, err) + assert.Equal(t, 3, args.Foo.val) +} + +func TestPtrToTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo *textUnmarshaler + } + err := parse("--foo abc", &args) + require.NoError(t, err) + assert.Equal(t, 3, args.Foo.val) +} + +func TestRepeatedTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo []textUnmarshaler + } + err := parse("--foo abc d ef", &args) + require.NoError(t, err) + require.Len(t, args.Foo, 3) + assert.Equal(t, 3, args.Foo[0].val) + assert.Equal(t, 1, args.Foo[1].val) + assert.Equal(t, 2, args.Foo[2].val) +} + +func TestRepeatedPtrToTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo []*textUnmarshaler + } + err := parse("--foo abc d ef", &args) + require.NoError(t, err) + require.Len(t, args.Foo, 3) + assert.Equal(t, 3, args.Foo[0].val) + assert.Equal(t, 1, args.Foo[1].val) + assert.Equal(t, 2, args.Foo[2].val) +} + +func TestPositionalTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo []textUnmarshaler `arg:"positional"` + } + err := parse("abc d ef", &args) + require.NoError(t, err) + require.Len(t, args.Foo, 3) + assert.Equal(t, 3, args.Foo[0].val) + assert.Equal(t, 1, args.Foo[1].val) + assert.Equal(t, 2, args.Foo[2].val) +} + +func TestPositionalPtrToTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo []*textUnmarshaler `arg:"positional"` + } + err := parse("abc d ef", &args) + require.NoError(t, err) + require.Len(t, args.Foo, 3) + assert.Equal(t, 3, args.Foo[0].val) + assert.Equal(t, 1, args.Foo[1].val) + assert.Equal(t, 2, args.Foo[2].val) +} + +type boolUnmarshaler bool + +func (p *boolUnmarshaler) UnmarshalText(b []byte) error { + *p = len(b)%2 == 0 + return nil +} + +func TestBoolUnmarhsaler(t *testing.T) { + // test that a bool type that implements TextUnmarshaler is + // handled as a TextUnmarshaler not as a bool + var args struct { + Foo *boolUnmarshaler + } + err := parse("--foo ab", &args) + require.NoError(t, err) + assert.EqualValues(t, true, *args.Foo) +} + +type sliceUnmarshaler []int + +func (p *sliceUnmarshaler) UnmarshalText(b []byte) error { + *p = sliceUnmarshaler{len(b)} + return nil +} + +func TestSliceUnmarhsaler(t *testing.T) { + // test that a slice type that implements TextUnmarshaler is + // handled as a TextUnmarshaler not as a slice + var args struct { + Foo *sliceUnmarshaler + Bar string `arg:"positional"` + } + err := parse("--foo abcde xyz", &args) + require.NoError(t, err) + require.Len(t, *args.Foo, 1) + assert.EqualValues(t, 5, (*args.Foo)[0]) + assert.Equal(t, "xyz", args.Bar) +} + +func TestIP(t *testing.T) { + var args struct { + Host net.IP + } + err := parse("--host 192.168.0.1", &args) + require.NoError(t, err) + assert.Equal(t, "192.168.0.1", args.Host.String()) +} + +func TestPtrToIP(t *testing.T) { + var args struct { + Host *net.IP + } + err := parse("--host 192.168.0.1", &args) + require.NoError(t, err) + assert.Equal(t, "192.168.0.1", args.Host.String()) +} + +func TestURL(t *testing.T) { + var args struct { + URL url.URL + } + err := parse("--url https://example.com/get?item=xyz", &args) + require.NoError(t, err) + assert.Equal(t, "https://example.com/get?item=xyz", args.URL.String()) +} + +func TestPtrToURL(t *testing.T) { + var args struct { + URL *url.URL + } + err := parse("--url http://example.com/#xyz", &args) + require.NoError(t, err) + assert.Equal(t, "http://example.com/#xyz", args.URL.String()) +} + +func TestIPSlice(t *testing.T) { + var args struct { + Host []net.IP + } + err := parse("--host 192.168.0.1 127.0.0.1", &args) + require.NoError(t, err) + require.Len(t, args.Host, 2) + assert.Equal(t, "192.168.0.1", args.Host[0].String()) + assert.Equal(t, "127.0.0.1", args.Host[1].String()) +} + +func TestInvalidIPAddress(t *testing.T) { + var args struct { + Host net.IP + } + err := parse("--host xxx", &args) + assert.Error(t, err) +} + +func TestMAC(t *testing.T) { + var args struct { + Host net.HardwareAddr + } + err := parse("--host 0123.4567.89ab", &args) + require.NoError(t, err) + assert.Equal(t, "01:23:45:67:89:ab", args.Host.String()) +} + +func TestInvalidMac(t *testing.T) { + var args struct { + Host net.HardwareAddr + } + err := parse("--host xxx", &args) + assert.Error(t, err) +} + +func TestMailAddr(t *testing.T) { + var args struct { + Recipient mail.Address + } + err := parse("--recipient foo@example.com", &args) + require.NoError(t, err) + assert.Equal(t, "", args.Recipient.String()) +} + +func TestInvalidMailAddr(t *testing.T) { + var args struct { + Recipient mail.Address + } + err := parse("--recipient xxx", &args) + assert.Error(t, err) +} + +type A struct { + X string +} + +type B struct { + Y int +} + +func TestEmbedded(t *testing.T) { + var args struct { + A + B + Z bool + } + err := parse("--x=hello --y=321 --z", &args) + require.NoError(t, err) + assert.Equal(t, "hello", args.X) + assert.Equal(t, 321, args.Y) + assert.Equal(t, true, args.Z) +} + +func TestEmbeddedPtr(t *testing.T) { + // embedded pointer fields are not supported so this should return an error + var args struct { + *A + } + err := parse("--x=hello", &args) + require.Error(t, err) +} + +func TestEmbeddedPtrIgnored(t *testing.T) { + // embedded pointer fields are not normally supported but here + // we explicitly exclude it so the non-nil embedded structs + // should work as expected + var args struct { + *A `arg:"-"` + B + } + err := parse("--y=321", &args) + require.NoError(t, err) + assert.Equal(t, 321, args.Y) +} + +func TestEmbeddedWithDuplicateField(t *testing.T) { + // see https://github.com/alexflint/go-arg/issues/100 + type T struct { + A string `arg:"--cat"` + } + type U struct { + A string `arg:"--dog"` + } + var args struct { + T + U + } + + err := parse("--cat=cat --dog=dog", &args) + require.NoError(t, err) + assert.Equal(t, "cat", args.T.A) + assert.Equal(t, "dog", args.U.A) +} + +func TestEmbeddedWithDuplicateField2(t *testing.T) { + // see https://github.com/alexflint/go-arg/issues/100 + type T struct { + A string + } + type U struct { + A string + } + var args struct { + T + U + } + + err := parse("--a=xyz", &args) + require.NoError(t, err) + assert.Equal(t, "xyz", args.T.A) + assert.Equal(t, "", args.U.A) +} + +func TestUnexportedEmbedded(t *testing.T) { + type embeddedArgs struct { + Foo string + } + var args struct { + embeddedArgs + } + err := parse("--foo bar", &args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + +func TestIgnoredEmbedded(t *testing.T) { + type embeddedArgs struct { + Foo string + } + var args struct { + embeddedArgs `arg:"-"` + } + err := parse("--foo bar", &args) + require.Error(t, err) +} + +func TestEmptyArgs(t *testing.T) { + origArgs := os.Args + + // test what happens if somehow os.Args is empty + os.Args = nil + var args struct { + Foo string + } + MustParse(&args) + + // put the original arguments back + os.Args = origArgs +} + +func TestTooManyHyphens(t *testing.T) { + var args struct { + TooManyHyphens string `arg:"---x"` + } + err := parse("--foo -", &args) + assert.Error(t, err) +} + +func TestHyphenAsOption(t *testing.T) { + var args struct { + Foo string + } + err := parse("--foo -", &args) + require.NoError(t, err) + assert.Equal(t, "-", args.Foo) +} + +func TestHyphenAsPositional(t *testing.T) { + var args struct { + Foo string `arg:"positional"` + } + err := parse("-", &args) + require.NoError(t, err) + assert.Equal(t, "-", args.Foo) +} + +func TestHyphenInMultiOption(t *testing.T) { + var args struct { + Foo []string + Bar int + } + err := parse("--foo --- x - y --bar 3", &args) + require.NoError(t, err) + assert.Equal(t, []string{"---", "x", "-", "y"}, args.Foo) + assert.Equal(t, 3, args.Bar) +} + +func TestHyphenInMultiPositional(t *testing.T) { + var args struct { + Foo []string `arg:"positional"` + } + err := parse("--- x - y", &args) + require.NoError(t, err) + assert.Equal(t, []string{"---", "x", "-", "y"}, args.Foo) +} + +func TestSeparate(t *testing.T) { + for _, val := range []string{"-f one", "-f=one", "--foo one", "--foo=one"} { + var args struct { + Foo []string `arg:"--foo,-f,separate"` + } + + err := parse(val, &args) + require.NoError(t, err) + assert.Equal(t, []string{"one"}, args.Foo) + } +} + +func TestSeparateWithDefault(t *testing.T) { + args := struct { + Foo []string `arg:"--foo,-f,separate"` + }{ + Foo: []string{"default"}, + } + + err := parse("-f one -f=two", &args) + require.NoError(t, err) + assert.Equal(t, []string{"default", "one", "two"}, args.Foo) +} + +func TestSeparateWithPositional(t *testing.T) { + var args struct { + Foo []string `arg:"--foo,-f,separate"` + Bar string `arg:"positional"` + Moo string `arg:"positional"` + } + + err := parse("zzz --foo one -f=two --foo=three -f four aaa", &args) + require.NoError(t, err) + assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo) + assert.Equal(t, "zzz", args.Bar) + assert.Equal(t, "aaa", args.Moo) +} + +func TestSeparatePositionalInterweaved(t *testing.T) { + var args struct { + Foo []string `arg:"--foo,-f,separate"` + Bar []string `arg:"--bar,-b,separate"` + Pre string `arg:"positional"` + Post []string `arg:"positional"` + } + + err := parse("zzz -f foo1 -b=bar1 --foo=foo2 -b bar2 post1 -b bar3 post2 post3", &args) + require.NoError(t, err) + assert.Equal(t, []string{"foo1", "foo2"}, args.Foo) + assert.Equal(t, []string{"bar1", "bar2", "bar3"}, args.Bar) + assert.Equal(t, "zzz", args.Pre) + assert.Equal(t, []string{"post1", "post2", "post3"}, args.Post) +} + +func TestSpacesAllowedInTags(t *testing.T) { + var args struct { + Foo []string `arg:"--foo, -f, separate, required, help:quite nice really"` + } + + err := parse("--foo one -f=two --foo=three -f four", &args) + require.NoError(t, err) + assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo) +} + +func TestReuseParser(t *testing.T) { + var args struct { + Foo string `arg:"required"` + } + + p, err := NewParser(Config{}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"--foo=abc"}) + require.NoError(t, err) + assert.Equal(t, args.Foo, "abc") + + err = p.Parse([]string{}) + assert.Error(t, err) +} + +func TestVersion(t *testing.T) { + var args struct{} + err := parse("--version", &args) + assert.Equal(t, ErrVersion, err) + +} + +func TestMultipleTerminates(t *testing.T) { + var args struct { + X []string + Y string `arg:"positional"` + } + + err := parse("--x a b -- c", &args) + require.NoError(t, err) + assert.Equal(t, []string{"a", "b"}, args.X) + assert.Equal(t, "c", args.Y) +} + +func TestDefaultOptionValues(t *testing.T) { + var args struct { + A int `default:"123"` + B *int `default:"123"` + C string `default:"abc"` + D *string `default:"abc"` + E float64 `default:"1.23"` + F *float64 `default:"1.23"` + G bool `default:"true"` + H *bool `default:"true"` + } + + err := parse("--c=xyz --e=4.56", &args) + require.NoError(t, err) + + assert.Equal(t, 123, args.A) + assert.Equal(t, 123, *args.B) + assert.Equal(t, "xyz", args.C) + assert.Equal(t, "abc", *args.D) + assert.Equal(t, 4.56, args.E) + assert.Equal(t, 1.23, *args.F) + assert.True(t, args.G) + assert.True(t, args.G) +} + +func TestDefaultUnparseable(t *testing.T) { + var args struct { + A int `default:"x"` + } + + err := parse("", &args) + assert.EqualError(t, err, `error processing default value for --a: strconv.ParseInt: parsing "x": invalid syntax`) +} + +func TestDefaultPositionalValues(t *testing.T) { + var args struct { + A int `arg:"positional" default:"123"` + B *int `arg:"positional" default:"123"` + C string `arg:"positional" default:"abc"` + D *string `arg:"positional" default:"abc"` + E float64 `arg:"positional" default:"1.23"` + F *float64 `arg:"positional" default:"1.23"` + G bool `arg:"positional" default:"true"` + H *bool `arg:"positional" default:"true"` + } + + err := parse("456 789", &args) + require.NoError(t, err) + + assert.Equal(t, 456, args.A) + assert.Equal(t, 789, *args.B) + assert.Equal(t, "abc", args.C) + assert.Equal(t, "abc", *args.D) + assert.Equal(t, 1.23, args.E) + assert.Equal(t, 1.23, *args.F) + assert.True(t, args.G) + assert.True(t, args.G) +} + +func TestDefaultValuesNotAllowedWithRequired(t *testing.T) { + var args struct { + A int `arg:"required" default:"123"` // required not allowed with default! + } + + err := parse("", &args) + assert.EqualError(t, err, ".A: 'required' cannot be used when a default value is specified") +} + +func TestDefaultValuesNotAllowedWithSlice(t *testing.T) { + var args struct { + A []int `default:"123"` // required not allowed with default! + } + + err := parse("", &args) + assert.EqualError(t, err, ".A: default values are not supported for slice or map fields") +} + +func TestUnexportedFieldsSkipped(t *testing.T) { + var args struct { + unexported struct{} + } + + _, err := NewParser(Config{}, &args) + require.NoError(t, err) +} + +func TestMustParseInvalidParser(t *testing.T) { + originalExit := osExit + originalStdout := stdout + defer func() { + osExit = originalExit + stdout = originalStdout + }() + + var exitCode int + osExit = func(code int) { exitCode = code } + stdout = &bytes.Buffer{} + + var args struct { + CannotParse struct{} + } + parser := MustParse(&args) + assert.Nil(t, parser) + assert.Equal(t, -1, exitCode) +} + +func TestMustParsePrintsHelp(t *testing.T) { + originalExit := osExit + originalStdout := stdout + originalArgs := os.Args + defer func() { + osExit = originalExit + stdout = originalStdout + os.Args = originalArgs + }() + + var exitCode *int + osExit = func(code int) { exitCode = &code } + os.Args = []string{"someprogram", "--help"} + stdout = &bytes.Buffer{} + + var args struct{} + parser := MustParse(&args) + assert.NotNil(t, parser) + require.NotNil(t, exitCode) + assert.Equal(t, 0, *exitCode) +} + +func TestMustParsePrintsVersion(t *testing.T) { + originalExit := osExit + originalStdout := stdout + originalArgs := os.Args + defer func() { + osExit = originalExit + stdout = originalStdout + os.Args = originalArgs + }() + + var exitCode *int + osExit = func(code int) { exitCode = &code } + os.Args = []string{"someprogram", "--version"} + + var b bytes.Buffer + stdout = &b + + var args versioned + parser := MustParse(&args) + require.NotNil(t, parser) + require.NotNil(t, exitCode) + assert.Equal(t, 0, *exitCode) + assert.Equal(t, "example 3.2.1\n", b.String()) +} diff --git a/v2/reflect.go b/v2/reflect.go new file mode 100644 index 0000000..cd80be7 --- /dev/null +++ b/v2/reflect.go @@ -0,0 +1,107 @@ +package arg + +import ( + "encoding" + "fmt" + "reflect" + "unicode" + "unicode/utf8" + + scalar "github.com/alexflint/go-scalar" +) + +var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() + +// cardinality tracks how many tokens are expected for a given spec +// - zero is a boolean, which does to expect any value +// - one is an ordinary option that will be parsed from a single token +// - multiple is a slice or map that can accept zero or more tokens +type cardinality int + +const ( + zero cardinality = iota + one + multiple + unsupported +) + +func (k cardinality) String() string { + switch k { + case zero: + return "zero" + case one: + return "one" + case multiple: + return "multiple" + case unsupported: + return "unsupported" + default: + return fmt.Sprintf("unknown(%d)", int(k)) + } +} + +// cardinalityOf returns true if the type can be parsed from a string +func cardinalityOf(t reflect.Type) (cardinality, error) { + if scalar.CanParse(t) { + if isBoolean(t) { + return zero, nil + } + return one, nil + } + + // look inside pointer types + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // look inside slice and map types + switch t.Kind() { + case reflect.Slice: + if !scalar.CanParse(t.Elem()) { + return unsupported, fmt.Errorf("cannot parse into %v because %v not supported", t, t.Elem()) + } + return multiple, nil + case reflect.Map: + if !scalar.CanParse(t.Key()) { + return unsupported, fmt.Errorf("cannot parse into %v because key type %v not supported", t, t.Elem()) + } + if !scalar.CanParse(t.Elem()) { + return unsupported, fmt.Errorf("cannot parse into %v because value type %v not supported", t, t.Elem()) + } + return multiple, nil + default: + return unsupported, fmt.Errorf("cannot parse into %v", t) + } +} + +// isBoolean returns true if the type can be parsed from a single string +func isBoolean(t reflect.Type) bool { + switch { + case t.Implements(textUnmarshalerType): + return false + case t.Kind() == reflect.Bool: + return true + case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Bool: + return true + default: + return false + } +} + +// isExported returns true if the struct field name is exported +func isExported(field string) bool { + r, _ := utf8.DecodeRuneInString(field) // returns RuneError for empty string or invalid UTF8 + return unicode.IsLetter(r) && unicode.IsUpper(r) +} + +// isZero returns true if v contains the zero value for its type +func isZero(v reflect.Value) bool { + t := v.Type() + if t.Kind() == reflect.Slice || t.Kind() == reflect.Map { + return v.IsNil() + } + if !t.Comparable() { + return false + } + return v.Interface() == reflect.Zero(t).Interface() +} diff --git a/v2/reflect_test.go b/v2/reflect_test.go new file mode 100644 index 0000000..10909b3 --- /dev/null +++ b/v2/reflect_test.go @@ -0,0 +1,112 @@ +package arg + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func assertCardinality(t *testing.T, typ reflect.Type, expected cardinality) { + actual, err := cardinalityOf(typ) + assert.Equal(t, expected, actual, "expected %v to have cardinality %v but got %v", typ, expected, actual) + if expected == unsupported { + assert.Error(t, err) + } +} + +func TestCardinalityOf(t *testing.T) { + var b bool + var i int + var s string + var f float64 + var bs []bool + var is []int + var m map[string]int + var unsupported1 struct{} + var unsupported2 []struct{} + var unsupported3 map[string]struct{} + var unsupported4 map[struct{}]string + + assertCardinality(t, reflect.TypeOf(b), zero) + assertCardinality(t, reflect.TypeOf(i), one) + assertCardinality(t, reflect.TypeOf(s), one) + assertCardinality(t, reflect.TypeOf(f), one) + + assertCardinality(t, reflect.TypeOf(&b), zero) + assertCardinality(t, reflect.TypeOf(&s), one) + assertCardinality(t, reflect.TypeOf(&i), one) + assertCardinality(t, reflect.TypeOf(&f), one) + + assertCardinality(t, reflect.TypeOf(bs), multiple) + assertCardinality(t, reflect.TypeOf(is), multiple) + + assertCardinality(t, reflect.TypeOf(&bs), multiple) + assertCardinality(t, reflect.TypeOf(&is), multiple) + + assertCardinality(t, reflect.TypeOf(m), multiple) + assertCardinality(t, reflect.TypeOf(&m), multiple) + + assertCardinality(t, reflect.TypeOf(unsupported1), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported1), unsupported) + assertCardinality(t, reflect.TypeOf(unsupported2), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported2), unsupported) + assertCardinality(t, reflect.TypeOf(unsupported3), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported3), unsupported) + assertCardinality(t, reflect.TypeOf(unsupported4), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported4), unsupported) +} + +type implementsTextUnmarshaler struct{} + +func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error { + return nil +} + +func TestCardinalityTextUnmarshaler(t *testing.T) { + var x implementsTextUnmarshaler + var s []implementsTextUnmarshaler + var m []implementsTextUnmarshaler + assertCardinality(t, reflect.TypeOf(x), one) + assertCardinality(t, reflect.TypeOf(&x), one) + assertCardinality(t, reflect.TypeOf(s), multiple) + assertCardinality(t, reflect.TypeOf(&s), multiple) + assertCardinality(t, reflect.TypeOf(m), multiple) + assertCardinality(t, reflect.TypeOf(&m), multiple) +} + +func TestIsExported(t *testing.T) { + assert.True(t, isExported("Exported")) + assert.False(t, isExported("notExported")) + assert.False(t, isExported("")) + assert.False(t, isExported(string([]byte{255}))) +} + +func TestCardinalityString(t *testing.T) { + assert.Equal(t, "zero", zero.String()) + assert.Equal(t, "one", one.String()) + assert.Equal(t, "multiple", multiple.String()) + assert.Equal(t, "unsupported", unsupported.String()) + assert.Equal(t, "unknown(42)", cardinality(42).String()) +} + +func TestIsZero(t *testing.T) { + var zero int + var notZero = 3 + var nilSlice []int + var nonNilSlice = []int{1, 2, 3} + var nilMap map[string]string + var nonNilMap = map[string]string{"foo": "bar"} + var uncomparable = func() {} + + assert.True(t, isZero(reflect.ValueOf(zero))) + assert.False(t, isZero(reflect.ValueOf(notZero))) + + assert.True(t, isZero(reflect.ValueOf(nilSlice))) + assert.False(t, isZero(reflect.ValueOf(nonNilSlice))) + + assert.True(t, isZero(reflect.ValueOf(nilMap))) + assert.False(t, isZero(reflect.ValueOf(nonNilMap))) + + assert.False(t, isZero(reflect.ValueOf(uncomparable))) +} diff --git a/v2/sequence.go b/v2/sequence.go new file mode 100644 index 0000000..35a3614 --- /dev/null +++ b/v2/sequence.go @@ -0,0 +1,123 @@ +package arg + +import ( + "fmt" + "reflect" + "strings" + + scalar "github.com/alexflint/go-scalar" +) + +// setSliceOrMap parses a sequence of strings into a slice or map. If clear is +// true then any values already in the slice or map are first removed. +func setSliceOrMap(dest reflect.Value, values []string, clear bool) error { + if !dest.CanSet() { + return fmt.Errorf("field is not writable") + } + + t := dest.Type() + if t.Kind() == reflect.Ptr { + dest = dest.Elem() + t = t.Elem() + } + + switch t.Kind() { + case reflect.Slice: + return setSlice(dest, values, clear) + case reflect.Map: + return setMap(dest, values, clear) + default: + return fmt.Errorf("setSliceOrMap cannot insert values into a %v", t) + } +} + +// setSlice parses a sequence of strings and inserts them into a slice. If clear +// is true then any values already in the slice are removed. +func setSlice(dest reflect.Value, values []string, clear bool) error { + var ptr bool + elem := dest.Type().Elem() + if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { + ptr = true + elem = elem.Elem() + } + + // clear the slice in case default values exist + if clear && !dest.IsNil() { + dest.SetLen(0) + } + + // parse the values one-by-one + for _, s := range values { + v := reflect.New(elem) + if err := scalar.ParseValue(v.Elem(), s); err != nil { + return err + } + if !ptr { + v = v.Elem() + } + dest.Set(reflect.Append(dest, v)) + } + return nil +} + +// setMap parses a sequence of name=value strings and inserts them into a map. +// If clear is true then any values already in the map are removed. +func setMap(dest reflect.Value, values []string, clear bool) error { + // determine the key and value type + var keyIsPtr bool + keyType := dest.Type().Key() + if keyType.Kind() == reflect.Ptr && !keyType.Implements(textUnmarshalerType) { + keyIsPtr = true + keyType = keyType.Elem() + } + + var valIsPtr bool + valType := dest.Type().Elem() + if valType.Kind() == reflect.Ptr && !valType.Implements(textUnmarshalerType) { + valIsPtr = true + valType = valType.Elem() + } + + // clear the slice in case default values exist + if clear && !dest.IsNil() { + for _, k := range dest.MapKeys() { + dest.SetMapIndex(k, reflect.Value{}) + } + } + + // allocate the map if it is not allocated + if dest.IsNil() { + dest.Set(reflect.MakeMap(dest.Type())) + } + + // parse the values one-by-one + for _, s := range values { + // split at the first equals sign + pos := strings.Index(s, "=") + if pos == -1 { + return fmt.Errorf("cannot parse %q into a map, expected format key=value", s) + } + + // parse the key + k := reflect.New(keyType) + if err := scalar.ParseValue(k.Elem(), s[:pos]); err != nil { + return err + } + if !keyIsPtr { + k = k.Elem() + } + + // parse the value + v := reflect.New(valType) + if err := scalar.ParseValue(v.Elem(), s[pos+1:]); err != nil { + return err + } + if !valIsPtr { + v = v.Elem() + } + + // add it to the map + dest.SetMapIndex(k, v) + } + return nil +} diff --git a/v2/sequence_test.go b/v2/sequence_test.go new file mode 100644 index 0000000..fde3e3a --- /dev/null +++ b/v2/sequence_test.go @@ -0,0 +1,152 @@ +package arg + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetSliceWithoutClearing(t *testing.T) { + xs := []int{10} + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false) + require.NoError(t, err) + assert.Equal(t, []int{10, 1, 2, 3}, xs) +} + +func TestSetSliceAfterClearing(t *testing.T) { + xs := []int{100} + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, xs) +} + +func TestSetSliceInvalid(t *testing.T) { + xs := []int{100} + entries := []string{"invalid"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetSlicePtr(t *testing.T) { + var xs []*int + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, xs, 3) + assert.Equal(t, 1, *xs[0]) + assert.Equal(t, 2, *xs[1]) + assert.Equal(t, 3, *xs[2]) +} + +func TestSetSliceTextUnmarshaller(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var xs []*textUnmarshaler + entries := []string{"a", "aa", "aaa"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, xs, 3) + assert.Equal(t, 1, xs[0].val) + assert.Equal(t, 2, xs[1].val) + assert.Equal(t, 3, xs[2].val) +} + +func TestSetMapWithoutClearing(t *testing.T) { + m := map[string]int{"foo": 10} + entries := []string{"a=1", "b=2"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, false) + require.NoError(t, err) + require.Len(t, m, 3) + assert.Equal(t, 1, m["a"]) + assert.Equal(t, 2, m["b"]) + assert.Equal(t, 10, m["foo"]) +} + +func TestSetMapAfterClearing(t *testing.T) { + m := map[string]int{"foo": 10} + entries := []string{"a=1", "b=2"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 2) + assert.Equal(t, 1, m["a"]) + assert.Equal(t, 2, m["b"]) +} + +func TestSetMapWithKeyPointer(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[*string]int + entries := []string{"abc=123"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 1) +} + +func TestSetMapWithValuePointer(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[string]*int + entries := []string{"abc=123"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 1) + assert.Equal(t, 123, *m["abc"]) +} + +func TestSetMapTextUnmarshaller(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[textUnmarshaler]*textUnmarshaler + entries := []string{"a=123", "aa=12", "aaa=1"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 3) + assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}]) + assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}]) + assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}]) +} + +func TestSetMapInvalidKey(t *testing.T) { + var m map[int]int + entries := []string{"invalid=123"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetMapInvalidValue(t *testing.T) { + var m map[int]int + entries := []string{"123=invalid"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetMapMalformed(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[string]string + entries := []string{"missing_equals_sign"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetSliceOrMapErrors(t *testing.T) { + var err error + var dest reflect.Value + + // converting a slice to a reflect.Value in this way will make it read only + var cannotSet []int + dest = reflect.ValueOf(cannotSet) + err = setSliceOrMap(dest, nil, false) + assert.Error(t, err) + + // check what happens when we pass in something that is not a slice or a map + var notSliceOrMap string + dest = reflect.ValueOf(¬SliceOrMap).Elem() + err = setSliceOrMap(dest, nil, false) + assert.Error(t, err) + + // check what happens when we pass in a pointer to something that is not a slice or a map + var stringPtr *string + dest = reflect.ValueOf(&stringPtr).Elem() + err = setSliceOrMap(dest, nil, false) + assert.Error(t, err) +} diff --git a/v2/subcommand.go b/v2/subcommand.go new file mode 100644 index 0000000..dff732c --- /dev/null +++ b/v2/subcommand.go @@ -0,0 +1,37 @@ +package arg + +// Subcommand returns the user struct for the subcommand selected by +// the command line arguments most recently processed by the parser. +// The return value is always a pointer to a struct. If no subcommand +// was specified then it returns the top-level arguments struct. If +// no command line arguments have been processed by this parser then it +// returns nil. +func (p *Parser) Subcommand() interface{} { + if p.lastCmd == nil || p.lastCmd.parent == nil { + return nil + } + return p.val(p.lastCmd.dest).Interface() +} + +// SubcommandNames returns the sequence of subcommands specified by the +// user. If no subcommands were given then it returns an empty slice. +func (p *Parser) SubcommandNames() []string { + if p.lastCmd == nil { + return nil + } + + // make a list of ancestor commands + var ancestors []string + cur := p.lastCmd + for cur.parent != nil { // we want to exclude the root + ancestors = append(ancestors, cur.name) + cur = cur.parent + } + + // reverse the list + out := make([]string, len(ancestors)) + for i := 0; i < len(ancestors); i++ { + out[i] = ancestors[len(ancestors)-i-1] + } + return out +} diff --git a/v2/subcommand_test.go b/v2/subcommand_test.go new file mode 100644 index 0000000..2c61dd3 --- /dev/null +++ b/v2/subcommand_test.go @@ -0,0 +1,413 @@ +package arg + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// This file contains tests for parse.go but I decided to put them here +// since that file is getting large + +func TestSubcommandNotAPointer(t *testing.T) { + var args struct { + A string `arg:"subcommand"` + } + _, err := NewParser(Config{}, &args) + assert.Error(t, err) +} + +func TestSubcommandNotAPointerToStruct(t *testing.T) { + var args struct { + A struct{} `arg:"subcommand"` + } + _, err := NewParser(Config{}, &args) + assert.Error(t, err) +} + +func TestPositionalAndSubcommandNotAllowed(t *testing.T) { + var args struct { + A string `arg:"positional"` + B *struct{} `arg:"subcommand"` + } + _, err := NewParser(Config{}, &args) + assert.Error(t, err) +} + +func TestMinimalSubcommand(t *testing.T) { + type listCmd struct { + } + var args struct { + List *listCmd `arg:"subcommand"` + } + p, err := pparse("list", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, args.List, p.Subcommand()) + assert.Equal(t, []string{"list"}, p.SubcommandNames()) +} + +func TestSubcommandNamesBeforeParsing(t *testing.T) { + type listCmd struct{} + var args struct { + List *listCmd `arg:"subcommand"` + } + p, err := NewParser(Config{}, &args) + require.NoError(t, err) + assert.Nil(t, p.Subcommand()) + assert.Nil(t, p.SubcommandNames()) +} + +func TestNoSuchSubcommand(t *testing.T) { + type listCmd struct { + } + var args struct { + List *listCmd `arg:"subcommand"` + } + _, err := pparse("invalid", &args) + assert.Error(t, err) +} + +func TestNamedSubcommand(t *testing.T) { + type listCmd struct { + } + var args struct { + List *listCmd `arg:"subcommand:ls"` + } + p, err := pparse("ls", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, args.List, p.Subcommand()) + assert.Equal(t, []string{"ls"}, p.SubcommandNames()) +} + +func TestEmptySubcommand(t *testing.T) { + type listCmd struct { + } + var args struct { + List *listCmd `arg:"subcommand"` + } + p, err := pparse("", &args) + require.NoError(t, err) + assert.Nil(t, args.List) + assert.Nil(t, p.Subcommand()) + assert.Empty(t, p.SubcommandNames()) +} + +func TestTwoSubcommands(t *testing.T) { + type getCmd struct { + } + type listCmd struct { + } + var args struct { + Get *getCmd `arg:"subcommand"` + List *listCmd `arg:"subcommand"` + } + p, err := pparse("list", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List, p.Subcommand()) + assert.Equal(t, []string{"list"}, p.SubcommandNames()) +} + +func TestSubcommandsWithOptions(t *testing.T) { + type getCmd struct { + Name string + } + type listCmd struct { + Limit int + } + type cmd struct { + Verbose bool + Get *getCmd `arg:"subcommand"` + List *listCmd `arg:"subcommand"` + } + + { + var args cmd + err := parse("list", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + } + + { + var args cmd + err := parse("list --limit 3", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List.Limit, 3) + } + + { + var args cmd + err := parse("list --limit 3 --verbose", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List.Limit, 3) + assert.True(t, args.Verbose) + } + + { + var args cmd + err := parse("list --verbose --limit 3", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List.Limit, 3) + assert.True(t, args.Verbose) + } + + { + var args cmd + err := parse("--verbose list --limit 3", &args) + require.NoError(t, err) + assert.Nil(t, args.Get) + assert.NotNil(t, args.List) + assert.Equal(t, args.List.Limit, 3) + assert.True(t, args.Verbose) + } + + { + var args cmd + err := parse("get", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Nil(t, args.List) + } + + { + var args cmd + err := parse("get --name test", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Nil(t, args.List) + assert.Equal(t, args.Get.Name, "test") + } +} + +func TestSubcommandsWithEnvVars(t *testing.T) { + type getCmd struct { + Name string `arg:"env"` + } + type listCmd struct { + Limit int `arg:"env"` + } + type cmd struct { + Verbose bool + Get *getCmd `arg:"subcommand"` + List *listCmd `arg:"subcommand"` + } + + { + var args cmd + setenv(t, "LIMIT", "123") + err := parse("list", &args) + require.NoError(t, err) + require.NotNil(t, args.List) + assert.Equal(t, 123, args.List.Limit) + } + + { + var args cmd + setenv(t, "LIMIT", "not_an_integer") + err := parse("list", &args) + assert.Error(t, err) + } +} + +func TestNestedSubcommands(t *testing.T) { + type child struct{} + type parent struct { + Child *child `arg:"subcommand"` + } + type grandparent struct { + Parent *parent `arg:"subcommand"` + } + type root struct { + Grandparent *grandparent `arg:"subcommand"` + } + + { + var args root + p, err := pparse("grandparent parent child", &args) + require.NoError(t, err) + require.NotNil(t, args.Grandparent) + require.NotNil(t, args.Grandparent.Parent) + require.NotNil(t, args.Grandparent.Parent.Child) + assert.Equal(t, args.Grandparent.Parent.Child, p.Subcommand()) + assert.Equal(t, []string{"grandparent", "parent", "child"}, p.SubcommandNames()) + } + + { + var args root + p, err := pparse("grandparent parent", &args) + require.NoError(t, err) + require.NotNil(t, args.Grandparent) + require.NotNil(t, args.Grandparent.Parent) + require.Nil(t, args.Grandparent.Parent.Child) + assert.Equal(t, args.Grandparent.Parent, p.Subcommand()) + assert.Equal(t, []string{"grandparent", "parent"}, p.SubcommandNames()) + } + + { + var args root + p, err := pparse("grandparent", &args) + require.NoError(t, err) + require.NotNil(t, args.Grandparent) + require.Nil(t, args.Grandparent.Parent) + assert.Equal(t, args.Grandparent, p.Subcommand()) + assert.Equal(t, []string{"grandparent"}, p.SubcommandNames()) + } + + { + var args root + p, err := pparse("", &args) + require.NoError(t, err) + require.Nil(t, args.Grandparent) + assert.Nil(t, p.Subcommand()) + assert.Empty(t, p.SubcommandNames()) + } +} + +func TestSubcommandsWithPositionals(t *testing.T) { + type listCmd struct { + Pattern string `arg:"positional"` + } + type cmd struct { + Format string + List *listCmd `arg:"subcommand"` + } + + { + var args cmd + err := parse("list", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, "", args.List.Pattern) + } + + { + var args cmd + err := parse("list --format json", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, "", args.List.Pattern) + assert.Equal(t, "json", args.Format) + } + + { + var args cmd + err := parse("list somepattern", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, "somepattern", args.List.Pattern) + } + + { + var args cmd + err := parse("list somepattern --format json", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, "somepattern", args.List.Pattern) + assert.Equal(t, "json", args.Format) + } + + { + var args cmd + err := parse("list --format json somepattern", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, "somepattern", args.List.Pattern) + assert.Equal(t, "json", args.Format) + } + + { + var args cmd + err := parse("--format json list somepattern", &args) + require.NoError(t, err) + assert.NotNil(t, args.List) + assert.Equal(t, "somepattern", args.List.Pattern) + assert.Equal(t, "json", args.Format) + } + + { + var args cmd + err := parse("--format json", &args) + require.NoError(t, err) + assert.Nil(t, args.List) + assert.Equal(t, "json", args.Format) + } +} +func TestSubcommandsWithMultiplePositionals(t *testing.T) { + type getCmd struct { + Items []string `arg:"positional"` + } + type cmd struct { + Limit int + Get *getCmd `arg:"subcommand"` + } + + { + var args cmd + err := parse("get", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Empty(t, args.Get.Items) + } + + { + var args cmd + err := parse("get --limit 5", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Empty(t, args.Get.Items) + assert.Equal(t, 5, args.Limit) + } + + { + var args cmd + err := parse("get item1", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Equal(t, []string{"item1"}, args.Get.Items) + } + + { + var args cmd + err := parse("get item1 item2 item3", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Equal(t, []string{"item1", "item2", "item3"}, args.Get.Items) + } + + { + var args cmd + err := parse("get item1 --limit 5 item2", &args) + require.NoError(t, err) + assert.NotNil(t, args.Get) + assert.Equal(t, []string{"item1", "item2"}, args.Get.Items) + assert.Equal(t, 5, args.Limit) + } +} + +func TestValForNilStruct(t *testing.T) { + type subcmd struct{} + var cmd struct { + Sub *subcmd `arg:"subcommand"` + } + + p, err := NewParser(Config{}, &cmd) + require.NoError(t, err) + + typ := reflect.TypeOf(cmd) + subField, _ := typ.FieldByName("Sub") + + v := p.val(path{fields: []reflect.StructField{subField, subField}}) + assert.False(t, v.IsValid()) +} diff --git a/v2/usage.go b/v2/usage.go new file mode 100644 index 0000000..7ba06cc --- /dev/null +++ b/v2/usage.go @@ -0,0 +1,339 @@ +package arg + +import ( + "fmt" + "io" + "os" + "strings" +) + +// the width of the left column +const colWidth = 25 + +// to allow monkey patching in tests +var ( + stdout io.Writer = os.Stdout + stderr io.Writer = os.Stderr + osExit = os.Exit +) + +// Fail prints usage information to stderr and exits with non-zero status +func (p *Parser) Fail(msg string) { + p.failWithSubcommand(msg, p.cmd) +} + +// FailSubcommand prints usage information for a specified subcommand to stderr, +// then exits with non-zero status. To write usage information for a top-level +// subcommand, provide just the name of that subcommand. To write usage +// information for a subcommand that is nested under another subcommand, provide +// a sequence of subcommand names starting with the top-level subcommand and so +// on down the tree. +func (p *Parser) FailSubcommand(msg string, subcommand ...string) error { + cmd, err := p.lookupCommand(subcommand...) + if err != nil { + return err + } + p.failWithSubcommand(msg, cmd) + return nil +} + +// failWithSubcommand prints usage information for the given subcommand to stderr and exits with non-zero status +func (p *Parser) failWithSubcommand(msg string, cmd *command) { + p.writeUsageForSubcommand(stderr, cmd) + fmt.Fprintln(stderr, "error:", msg) + osExit(-1) +} + +// WriteUsage writes usage information to the given writer +func (p *Parser) WriteUsage(w io.Writer) { + cmd := p.cmd + if p.lastCmd != nil { + cmd = p.lastCmd + } + p.writeUsageForSubcommand(w, cmd) +} + +// WriteUsageForSubcommand writes the usage information for a specified +// subcommand. To write usage information for a top-level subcommand, provide +// just the name of that subcommand. To write usage information for a subcommand +// that is nested under another subcommand, provide a sequence of subcommand +// names starting with the top-level subcommand and so on down the tree. +func (p *Parser) WriteUsageForSubcommand(w io.Writer, subcommand ...string) error { + cmd, err := p.lookupCommand(subcommand...) + if err != nil { + return err + } + p.writeUsageForSubcommand(w, cmd) + return nil +} + +// writeUsageForSubcommand writes usage information for the given subcommand +func (p *Parser) writeUsageForSubcommand(w io.Writer, cmd *command) { + var positionals, longOptions, shortOptions []*spec + for _, spec := range cmd.specs { + switch { + case spec.positional: + positionals = append(positionals, spec) + case spec.long != "": + longOptions = append(longOptions, spec) + case spec.short != "": + shortOptions = append(shortOptions, spec) + } + } + + if p.version != "" { + fmt.Fprintln(w, p.version) + } + + // make a list of ancestor commands so that we print with full context + var ancestors []string + ancestor := cmd + for ancestor != nil { + ancestors = append(ancestors, ancestor.name) + ancestor = ancestor.parent + } + + // print the beginning of the usage string + fmt.Fprint(w, "Usage:") + for i := len(ancestors) - 1; i >= 0; i-- { + fmt.Fprint(w, " "+ancestors[i]) + } + + // write the option component of the usage message + for _, spec := range shortOptions { + // prefix with a space + fmt.Fprint(w, " ") + if !spec.required { + fmt.Fprint(w, "[") + } + fmt.Fprint(w, synopsis(spec, "-"+spec.short)) + if !spec.required { + fmt.Fprint(w, "]") + } + } + + for _, spec := range longOptions { + // prefix with a space + fmt.Fprint(w, " ") + if !spec.required { + fmt.Fprint(w, "[") + } + fmt.Fprint(w, synopsis(spec, "--"+spec.long)) + if !spec.required { + fmt.Fprint(w, "]") + } + } + + // When we parse positionals, we check that: + // 1. required positionals come before non-required positionals + // 2. there is at most one multiple-value positional + // 3. if there is a multiple-value positional then it comes after all other positionals + // Here we merely print the usage string, so we do not explicitly re-enforce those rules + + // write the positionals in following form: + // REQUIRED1 REQUIRED2 + // REQUIRED1 REQUIRED2 [OPTIONAL1 [OPTIONAL2]] + // REQUIRED1 REQUIRED2 REPEATED [REPEATED ...] + // REQUIRED1 REQUIRED2 [REPEATEDOPTIONAL [REPEATEDOPTIONAL ...]] + // REQUIRED1 REQUIRED2 [OPTIONAL1 [REPEATEDOPTIONAL [REPEATEDOPTIONAL ...]]] + var closeBrackets int + for _, spec := range positionals { + fmt.Fprint(w, " ") + if !spec.required { + fmt.Fprint(w, "[") + closeBrackets += 1 + } + if spec.cardinality == multiple { + fmt.Fprintf(w, "%s [%s ...]", spec.placeholder, spec.placeholder) + } else { + fmt.Fprint(w, spec.placeholder) + } + } + fmt.Fprint(w, strings.Repeat("]", closeBrackets)) + + // if the program supports subcommands, give a hint to the user about their existence + if len(cmd.subcommands) > 0 { + fmt.Fprint(w, " []") + } + + fmt.Fprint(w, "\n") +} + +func printTwoCols(w io.Writer, left, help string, defaultVal string, envVal string) { + lhs := " " + left + fmt.Fprint(w, lhs) + if help != "" { + if len(lhs)+2 < colWidth { + fmt.Fprint(w, strings.Repeat(" ", colWidth-len(lhs))) + } else { + fmt.Fprint(w, "\n"+strings.Repeat(" ", colWidth)) + } + fmt.Fprint(w, help) + } + + bracketsContent := []string{} + + if defaultVal != "" { + bracketsContent = append(bracketsContent, + fmt.Sprintf("default: %s", defaultVal), + ) + } + + if envVal != "" { + bracketsContent = append(bracketsContent, + fmt.Sprintf("env: %s", envVal), + ) + } + + if len(bracketsContent) > 0 { + fmt.Fprintf(w, " [%s]", strings.Join(bracketsContent, ", ")) + } + fmt.Fprint(w, "\n") +} + +// WriteHelp writes the usage string followed by the full help string for each option +func (p *Parser) WriteHelp(w io.Writer) { + cmd := p.cmd + if p.lastCmd != nil { + cmd = p.lastCmd + } + p.writeHelpForSubcommand(w, cmd) +} + +// WriteHelpForSubcommand writes the usage string followed by the full help +// string for a specified subcommand. To write help for a top-level subcommand, +// provide just the name of that subcommand. To write help for a subcommand that +// is nested under another subcommand, provide a sequence of subcommand names +// starting with the top-level subcommand and so on down the tree. +func (p *Parser) WriteHelpForSubcommand(w io.Writer, subcommand ...string) error { + cmd, err := p.lookupCommand(subcommand...) + if err != nil { + return err + } + p.writeHelpForSubcommand(w, cmd) + return nil +} + +// writeHelp writes the usage string for the given subcommand +func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *command) { + var positionals, longOptions, shortOptions []*spec + for _, spec := range cmd.specs { + switch { + case spec.positional: + positionals = append(positionals, spec) + case spec.long != "": + longOptions = append(longOptions, spec) + case spec.short != "": + shortOptions = append(shortOptions, spec) + } + } + + if p.description != "" { + fmt.Fprintln(w, p.description) + } + p.writeUsageForSubcommand(w, cmd) + + // write the list of positionals + if len(positionals) > 0 { + fmt.Fprint(w, "\nPositional arguments:\n") + for _, spec := range positionals { + printTwoCols(w, spec.placeholder, spec.help, "", "") + } + } + + // write the list of options with the short-only ones first to match the usage string + if len(shortOptions)+len(longOptions) > 0 || cmd.parent == nil { + fmt.Fprint(w, "\nOptions:\n") + for _, spec := range shortOptions { + p.printOption(w, spec) + } + for _, spec := range longOptions { + p.printOption(w, spec) + } + } + + // obtain a flattened list of options from all ancestors + var globals []*spec + ancestor := cmd.parent + for ancestor != nil { + globals = append(globals, ancestor.specs...) + ancestor = ancestor.parent + } + + // write the list of global options + if len(globals) > 0 { + fmt.Fprint(w, "\nGlobal options:\n") + for _, spec := range globals { + p.printOption(w, spec) + } + } + + // write the list of built in options + p.printOption(w, &spec{ + cardinality: zero, + long: "help", + short: "h", + help: "display this help and exit", + }) + if p.version != "" { + p.printOption(w, &spec{ + cardinality: zero, + long: "version", + help: "display version and exit", + }) + } + + // write the list of subcommands + if len(cmd.subcommands) > 0 { + fmt.Fprint(w, "\nCommands:\n") + for _, subcmd := range cmd.subcommands { + printTwoCols(w, subcmd.name, subcmd.help, "", "") + } + } + + if p.epilogue != "" { + fmt.Fprintln(w, "\n"+p.epilogue) + } +} + +func (p *Parser) printOption(w io.Writer, spec *spec) { + ways := make([]string, 0, 2) + if spec.long != "" { + ways = append(ways, synopsis(spec, "--"+spec.long)) + } + if spec.short != "" { + ways = append(ways, synopsis(spec, "-"+spec.short)) + } + if len(ways) > 0 { + printTwoCols(w, strings.Join(ways, ", "), spec.help, spec.defaultVal, spec.env) + } +} + +// lookupCommand finds a subcommand based on a sequence of subcommand names. The +// first string should be a top-level subcommand, the next should be a child +// subcommand of that subcommand, and so on. If no strings are given then the +// root command is returned. If no such subcommand exists then an error is +// returned. +func (p *Parser) lookupCommand(path ...string) (*command, error) { + cmd := p.cmd + for _, name := range path { + var found *command + for _, child := range cmd.subcommands { + if child.name == name { + found = child + } + } + if found == nil { + return nil, fmt.Errorf("%q is not a subcommand of %s", name, cmd.name) + } + cmd = found + } + return cmd, nil +} + +func synopsis(spec *spec, form string) string { + if spec.cardinality == zero { + return form + } + return form + " " + spec.placeholder +} diff --git a/v2/usage_test.go b/v2/usage_test.go new file mode 100644 index 0000000..fd67fc8 --- /dev/null +++ b/v2/usage_test.go @@ -0,0 +1,635 @@ +package arg + +import ( + "bytes" + "errors" + "fmt" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type NameDotName struct { + Head, Tail string +} + +func (n *NameDotName) UnmarshalText(b []byte) error { + s := string(b) + pos := strings.Index(s, ".") + if pos == -1 { + return fmt.Errorf("missing period in %s", s) + } + n.Head = s[:pos] + n.Tail = s[pos+1:] + return nil +} + +func (n *NameDotName) MarshalText() (text []byte, err error) { + text = []byte(fmt.Sprintf("%s.%s", n.Head, n.Tail)) + return +} + +func TestWriteUsage(t *testing.T) { + expectedUsage := "Usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] [--values VALUES] [--workers WORKERS] [--testenv TESTENV] [--file FILE] INPUT [OUTPUT [OUTPUT ...]]" + + expectedHelp := ` +Usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] [--values VALUES] [--workers WORKERS] [--testenv TESTENV] [--file FILE] INPUT [OUTPUT [OUTPUT ...]] + +Positional arguments: + INPUT + OUTPUT list of outputs + +Options: + --name NAME name to use [default: Foo Bar] + --value VALUE secret value [default: 42] + --verbose, -v verbosity level + --dataset DATASET dataset to use + --optimize OPTIMIZE, -O OPTIMIZE + optimization level + --ids IDS Ids + --values VALUES Values [default: [3.14 42 256]] + --workers WORKERS, -w WORKERS + number of workers to start [default: 10, env: WORKERS] + --testenv TESTENV, -a TESTENV [env: TEST_ENV] + --file FILE, -f FILE File with mandatory extension [default: scratch.txt] + --help, -h display this help and exit +` + + var args struct { + Input string `arg:"positional,required"` + Output []string `arg:"positional" help:"list of outputs"` + Name string `help:"name to use"` + Value int `help:"secret value"` + Verbose bool `arg:"-v" help:"verbosity level"` + Dataset string `help:"dataset to use"` + Optimize int `arg:"-O" help:"optimization level"` + Ids []int64 `help:"Ids"` + Values []float64 `help:"Values"` + Workers int `arg:"-w,env:WORKERS" help:"number of workers to start" default:"10"` + TestEnv string `arg:"-a,env:TEST_ENV"` + File *NameDotName `arg:"-f" help:"File with mandatory extension"` + } + args.Name = "Foo Bar" + args.Value = 42 + args.Values = []float64{3.14, 42, 256} + args.File = &NameDotName{"scratch", "txt"} + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + os.Args[0] = "example" + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +type MyEnum int + +func (n *MyEnum) UnmarshalText(b []byte) error { + return nil +} + +func (n *MyEnum) MarshalText() ([]byte, error) { + return nil, errors.New("There was a problem") +} + +func TestUsageWithDefaults(t *testing.T) { + expectedUsage := "Usage: example [--label LABEL] [--content CONTENT]" + + expectedHelp := ` +Usage: example [--label LABEL] [--content CONTENT] + +Options: + --label LABEL [default: cat] + --content CONTENT [default: dog] + --help, -h display this help and exit +` + var args struct { + Label string + Content string `default:"dog"` + } + args.Label = "cat" + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + args.Label = "should_ignore_this" + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestUsageCannotMarshalToString(t *testing.T) { + var args struct { + Name *MyEnum + } + v := MyEnum(42) + args.Name = &v + _, err := NewParser(Config{Program: "example"}, &args) + assert.EqualError(t, err, `args.Name: error marshaling default value to string: There was a problem`) +} + +func TestUsageLongPositionalWithHelp_legacyForm(t *testing.T) { + expectedUsage := "Usage: example [VERYLONGPOSITIONALWITHHELP]" + + expectedHelp := ` +Usage: example [VERYLONGPOSITIONALWITHHELP] + +Positional arguments: + VERYLONGPOSITIONALWITHHELP + this positional argument is very long but cannot include commas + +Options: + --help, -h display this help and exit +` + var args struct { + VeryLongPositionalWithHelp string `arg:"positional,help:this positional argument is very long but cannot include commas"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestUsageLongPositionalWithHelp_newForm(t *testing.T) { + expectedUsage := "Usage: example [VERYLONGPOSITIONALWITHHELP]" + + expectedHelp := ` +Usage: example [VERYLONGPOSITIONALWITHHELP] + +Positional arguments: + VERYLONGPOSITIONALWITHHELP + this positional argument is very long, and includes: commas, colons etc + +Options: + --help, -h display this help and exit +` + var args struct { + VeryLongPositionalWithHelp string `arg:"positional" help:"this positional argument is very long, and includes: commas, colons etc"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestUsageWithProgramName(t *testing.T) { + expectedUsage := "Usage: myprogram" + + expectedHelp := ` +Usage: myprogram + +Options: + --help, -h display this help and exit +` + config := Config{ + Program: "myprogram", + } + p, err := NewParser(config, &struct{}{}) + require.NoError(t, err) + + os.Args[0] = "example" + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +type versioned struct{} + +// Version returns the version for this program +func (versioned) Version() string { + return "example 3.2.1" +} + +func TestUsageWithVersion(t *testing.T) { + expectedUsage := "example 3.2.1\nUsage: example" + + expectedHelp := ` +example 3.2.1 +Usage: example + +Options: + --help, -h display this help and exit + --version display version and exit +` + os.Args[0] = "example" + p, err := NewParser(Config{}, &versioned{}) + require.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +type described struct{} + +// Described returns the description for this program +func (described) Description() string { + return "this program does this and that" +} + +func TestUsageWithDescription(t *testing.T) { + expectedUsage := "Usage: example" + + expectedHelp := ` +this program does this and that +Usage: example + +Options: + --help, -h display this help and exit +` + os.Args[0] = "example" + p, err := NewParser(Config{}, &described{}) + require.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +type epilogued struct{} + +// Epilogued returns the epilogue for this program +func (epilogued) Epilogue() string { + return "For more information visit github.com/alexflint/go-arg" +} + +func TestUsageWithEpilogue(t *testing.T) { + expectedUsage := "Usage: example" + + expectedHelp := ` +Usage: example + +Options: + --help, -h display this help and exit + +For more information visit github.com/alexflint/go-arg +` + os.Args[0] = "example" + p, err := NewParser(Config{}, &epilogued{}) + require.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestUsageForRequiredPositionals(t *testing.T) { + expectedUsage := "Usage: example REQUIRED1 REQUIRED2\n" + var args struct { + Required1 string `arg:"positional,required"` + Required2 string `arg:"positional,required"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, usage.String()) +} + +func TestUsageForMixedPositionals(t *testing.T) { + expectedUsage := "Usage: example REQUIRED1 REQUIRED2 [OPTIONAL1 [OPTIONAL2]]\n" + var args struct { + Required1 string `arg:"positional,required"` + Required2 string `arg:"positional,required"` + Optional1 string `arg:"positional"` + Optional2 string `arg:"positional"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, usage.String()) +} + +func TestUsageForRepeatedPositionals(t *testing.T) { + expectedUsage := "Usage: example REQUIRED1 REQUIRED2 REPEATED [REPEATED ...]\n" + var args struct { + Required1 string `arg:"positional,required"` + Required2 string `arg:"positional,required"` + Repeated []string `arg:"positional,required"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, usage.String()) +} + +func TestUsageForMixedAndRepeatedPositionals(t *testing.T) { + expectedUsage := "Usage: example REQUIRED1 REQUIRED2 [OPTIONAL1 [OPTIONAL2 [REPEATED [REPEATED ...]]]]\n" + var args struct { + Required1 string `arg:"positional,required"` + Required2 string `arg:"positional,required"` + Optional1 string `arg:"positional"` + Optional2 string `arg:"positional"` + Repeated []string `arg:"positional"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, usage.String()) +} + +func TestRequiredMultiplePositionals(t *testing.T) { + expectedUsage := "Usage: example REQUIREDMULTIPLE [REQUIREDMULTIPLE ...]\n" + + expectedHelp := ` +Usage: example REQUIREDMULTIPLE [REQUIREDMULTIPLE ...] + +Positional arguments: + REQUIREDMULTIPLE required multiple positional + +Options: + --help, -h display this help and exit +` + var args struct { + RequiredMultiple []string `arg:"positional,required" help:"required multiple positional"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, usage.String()) +} + +func TestUsageWithNestedSubcommands(t *testing.T) { + expectedUsage := "Usage: example child nested [--enable] OUTPUT" + + expectedHelp := ` +Usage: example child nested [--enable] OUTPUT + +Positional arguments: + OUTPUT + +Options: + --enable + +Global options: + --values VALUES Values + --verbose, -v verbosity level + --help, -h display this help and exit +` + + var args struct { + Verbose bool `arg:"-v" help:"verbosity level"` + Child *struct { + Values []float64 `help:"Values"` + Nested *struct { + Enable bool + Output string `arg:"positional,required"` + } `arg:"subcommand:nested"` + } `arg:"subcommand:child"` + } + + os.Args[0] = "example" + p, err := NewParser(Config{}, &args) + require.NoError(t, err) + + _ = p.Parse([]string{"child", "nested", "value"}) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var help2 bytes.Buffer + p.WriteHelpForSubcommand(&help2, "child", "nested") + assert.Equal(t, expectedHelp[1:], help2.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) + + var usage2 bytes.Buffer + p.WriteUsageForSubcommand(&usage2, "child", "nested") + assert.Equal(t, expectedUsage, strings.TrimSpace(usage2.String())) +} + +func TestNonexistentSubcommand(t *testing.T) { + var args struct { + sub *struct{} `arg:"subcommand"` + } + p, err := NewParser(Config{}, &args) + require.NoError(t, err) + + var b bytes.Buffer + + err = p.WriteUsageForSubcommand(&b, "does_not_exist") + assert.Error(t, err) + + err = p.WriteHelpForSubcommand(&b, "does_not_exist") + assert.Error(t, err) + + err = p.FailSubcommand("something went wrong", "does_not_exist") + assert.Error(t, err) + + err = p.WriteUsageForSubcommand(&b, "sub", "does_not_exist") + assert.Error(t, err) + + err = p.WriteHelpForSubcommand(&b, "sub", "does_not_exist") + assert.Error(t, err) + + err = p.FailSubcommand("something went wrong", "sub", "does_not_exist") + assert.Error(t, err) +} + +func TestUsageWithoutLongNames(t *testing.T) { + expectedUsage := "Usage: example [-a PLACEHOLDER] -b SHORTONLY2" + + expectedHelp := ` +Usage: example [-a PLACEHOLDER] -b SHORTONLY2 + +Options: + -a PLACEHOLDER some help [default: some val] + -b SHORTONLY2 some help2 + --help, -h display this help and exit +` + var args struct { + ShortOnly string `arg:"-a,--" help:"some help" default:"some val" placeholder:"PLACEHOLDER"` + ShortOnly2 string `arg:"-b,--,required" help:"some help2"` + } + p, err := NewParser(Config{Program: "example"}, &args) + assert.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestUsageWithShortFirst(t *testing.T) { + expectedUsage := "Usage: example [-c CAT] [--dog DOG]" + + expectedHelp := ` +Usage: example [-c CAT] [--dog DOG] + +Options: + -c CAT + --dog DOG + --help, -h display this help and exit +` + var args struct { + Dog string + Cat string `arg:"-c,--"` + } + p, err := NewParser(Config{Program: "example"}, &args) + assert.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestUsageWithEnvOptions(t *testing.T) { + expectedUsage := "Usage: example [-s SHORT]" + + expectedHelp := ` +Usage: example [-s SHORT] + +Options: + -s SHORT [env: SHORT] + --help, -h display this help and exit +` + var args struct { + Short string `arg:"--,-s,env"` + EnvOnly string `arg:"--,env"` + EnvOnlyOverriden string `arg:"--,env:CUSTOM"` + } + + p, err := NewParser(Config{Program: "example"}, &args) + assert.NoError(t, err) + + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestFail(t *testing.T) { + originalStderr := stderr + originalExit := osExit + defer func() { + stderr = originalStderr + osExit = originalExit + }() + + var b bytes.Buffer + stderr = &b + + var exitCode int + osExit = func(code int) { exitCode = code } + + expectedStdout := ` +Usage: example [--foo FOO] +error: something went wrong +` + + var args struct { + Foo int + } + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + p.Fail("something went wrong") + + assert.Equal(t, expectedStdout[1:], b.String()) + assert.Equal(t, -1, exitCode) +} + +func TestFailSubcommand(t *testing.T) { + originalStderr := stderr + originalExit := osExit + defer func() { + stderr = originalStderr + osExit = originalExit + }() + + var b bytes.Buffer + stderr = &b + + var exitCode int + osExit = func(code int) { exitCode = code } + + expectedStdout := ` +Usage: example sub +error: something went wrong +` + + var args struct { + Sub *struct{} `arg:"subcommand"` + } + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + + err = p.FailSubcommand("something went wrong", "sub") + require.NoError(t, err) + + assert.Equal(t, expectedStdout[1:], b.String()) + assert.Equal(t, -1, exitCode) +} From 09d28e1195519df88f2606137f227aac6186ed09 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 11:00:42 -0700 Subject: [PATCH 02/19] split the parsing logic into ProcessEnvironment, ProcessCommandLine, ProcessOptions, ProcessPositions, ProcessDefaults --- go.sum | 2 - v2/parse.go | 470 ++++++++++++++++++++++++------------------ v2/parse_test.go | 173 +++++++--------- v2/subcommand.go | 8 +- v2/subcommand_test.go | 6 +- v2/usage.go | 130 ++++++------ v2/usage_test.go | 10 +- 7 files changed, 420 insertions(+), 379 deletions(-) diff --git a/go.sum b/go.sum index 5b536f9..385ca8f 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/alexflint/go-scalar v1.1.0 h1:aaAouLLzI9TChcPXotr6gUhq+Scr8rl0P9P4PnltbhM= -github.com/alexflint/go-scalar v1.1.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= github.com/alexflint/go-scalar v1.2.0 h1:WR7JPKkeNpnYIOfHRa7ivM21aWAdHD0gEWHCx+WQBRw= github.com/alexflint/go-scalar v1.2.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/v2/parse.go b/v2/parse.go index 8e190f2..ce02bd4 100644 --- a/v2/parse.go +++ b/v2/parse.go @@ -39,8 +39,8 @@ func (p path) Child(f reflect.StructField) path { } } -// spec represents a command line option -type spec struct { +// Arg represents a command line argument +type Argument struct { dest path field reflect.StructField // the struct field from which this option was created long string // the --long form for this option, or empty if none @@ -55,14 +55,14 @@ type spec struct { placeholder string // name of the data in help } -// command represents a named subcommand, or the top-level command -type command struct { +// Command represents a named subcommand, or the top-level command +type Command struct { name string help string dest path - specs []*spec - subcommands []*command - parent *command + args []*Argument + subcommands []*Command + parent *Command } // ErrHelp indicates that -h or --help were provided @@ -80,16 +80,16 @@ func MustParse(dest interface{}) *Parser { return nil // just in case osExit was monkey-patched } - err = p.Parse(flags()) + err = p.Parse(os.Args, os.Environ()) switch { case err == ErrHelp: - p.writeHelpForSubcommand(stdout, p.lastCmd) + p.writeHelpForSubcommand(stdout, p.leaf) osExit(0) case err == ErrVersion: fmt.Fprintln(stdout, p.version) osExit(0) case err != nil: - p.failWithSubcommand(err.Error(), p.lastCmd) + p.failWithSubcommand(err.Error(), p.leaf) } return p @@ -101,41 +101,28 @@ func Parse(dest interface{}) error { if err != nil { return err } - return p.Parse(flags()) -} - -// flags gets all command line arguments other than the first (program name) -func flags() []string { - if len(os.Args) == 0 { // os.Args could be empty - return nil - } - return os.Args[1:] + return p.Parse(os.Args, os.Environ()) } // Config represents configuration options for an argument parser type Config struct { // Program is the name of the program used in the help text Program string - - // IgnoreEnv instructs the library not to read environment variables - IgnoreEnv bool - - // IgnoreDefault instructs the library not to reset the variables to the - // default values, including pointers to sub commands - IgnoreDefault bool } // Parser represents a set of command line options with destination values type Parser struct { - cmd *command - root reflect.Value // destination struct to fill will values - config Config - version string - description string - epilogue string + cmd *Command // the top-level command + root reflect.Value // destination struct to fill will values + config Config // configuration passed to NewParser + version string // version from the argument struct + prologue string // prologue for help text (from the argument struct) + epilogue string // epilogue for help text (from the argument struct) - // the following field changes during processing of command line arguments - lastCmd *command + // the following fields are updated during processing of command line arguments + leaf *Command // the subcommand we processed last + accumulatedArgs []*Argument // concatenation of the leaf subcommand's arguments plus all ancestors' arguments + seen map[*Argument]bool // the arguments we encountered while processing command line arguments } // Versioned is the interface that the destination struct should implement to @@ -198,8 +185,9 @@ func NewParser(config Config, dest interface{}) (*Parser, error) { // construct a parser p := Parser{ - cmd: &command{name: name}, + cmd: &Command{name: name}, config: config, + seen: make(map[*Argument]bool), } // make a list of roots @@ -217,28 +205,28 @@ func NewParser(config Config, dest interface{}) (*Parser, error) { } // add nonzero field values as defaults - for _, spec := range cmd.specs { - if v := p.val(spec.dest); v.IsValid() && !isZero(v) { + for _, arg := range cmd.args { + if v := p.val(arg.dest); v.IsValid() && !isZero(v) { if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok { str, err := defaultVal.MarshalText() if err != nil { - return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err) + return nil, fmt.Errorf("%v: error marshaling default value to string: %v", arg.dest, err) } - spec.defaultVal = string(str) + arg.defaultVal = string(str) } else { - spec.defaultVal = fmt.Sprintf("%v", v) + arg.defaultVal = fmt.Sprintf("%v", v) } } } - p.cmd.specs = append(p.cmd.specs, cmd.specs...) + p.cmd.args = append(p.cmd.args, cmd.args...) p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...) if dest, ok := dest.(Versioned); ok { p.version = dest.Version() } if dest, ok := dest.(Described); ok { - p.description = dest.Description() + p.prologue = dest.Description() } if dest, ok := dest.(Epilogued); ok { p.epilogue = dest.Epilogue() @@ -247,7 +235,7 @@ func NewParser(config Config, dest interface{}) (*Parser, error) { return &p, nil } -func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { +func cmdFromStruct(name string, dest path, t reflect.Type) (*Command, error) { // commands can only be created from pointers to structs if t.Kind() != reflect.Ptr { return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s", @@ -260,7 +248,7 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { dest, t.Kind()) } - cmd := command{ + cmd := Command{ name: name, dest: dest, } @@ -287,7 +275,7 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { // duplicate the entire path to avoid slice overwrites subdest := dest.Child(field) - spec := spec{ + arg := Argument{ dest: subdest, field: field, long: strings.ToLower(field.Name), @@ -295,12 +283,12 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { help, exists := field.Tag.Lookup("help") if exists { - spec.help = help + arg.help = help } defaultVal, hasDefault := field.Tag.Lookup("default") if hasDefault { - spec.defaultVal = defaultVal + arg.defaultVal = defaultVal } // Look at the tag @@ -320,33 +308,33 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { case strings.HasPrefix(key, "---"): errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name)) case strings.HasPrefix(key, "--"): - spec.long = key[2:] + arg.long = key[2:] case strings.HasPrefix(key, "-"): if len(key) != 2 { errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", t.Name(), field.Name)) return false } - spec.short = key[1:] + arg.short = key[1:] case key == "required": if hasDefault { errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified", t.Name(), field.Name)) return false } - spec.required = true + arg.required = true case key == "positional": - spec.positional = true + arg.positional = true case key == "separate": - spec.separate = true + arg.separate = true case key == "help": // deprecated - spec.help = value + arg.help = value case key == "env": // Use override name if provided if value != "" { - spec.env = value + arg.env = value } else { - spec.env = strings.ToUpper(field.Name) + arg.env = strings.ToUpper(field.Name) } case key == "subcommand": // decide on a name for the subcommand @@ -375,11 +363,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { placeholder, hasPlaceholder := field.Tag.Lookup("placeholder") if hasPlaceholder { - spec.placeholder = placeholder - } else if spec.long != "" { - spec.placeholder = strings.ToUpper(spec.long) + arg.placeholder = placeholder + } else if arg.long != "" { + arg.placeholder = strings.ToUpper(arg.long) } else { - spec.placeholder = strings.ToUpper(spec.field.Name) + arg.placeholder = strings.ToUpper(arg.field.Name) } // Check whether this field is supported. It's good to do this here rather than @@ -387,16 +375,16 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { // fields will always fail regardless of whether the arguments it received // exercised those fields. if !isSubcommand { - cmd.specs = append(cmd.specs, &spec) + cmd.args = append(cmd.args, &arg) var err error - spec.cardinality, err = cardinalityOf(field.Type) + arg.cardinality, err = cardinalityOf(field.Type) if err != nil { errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String())) return false } - if spec.cardinality == multiple && hasDefault { + if arg.cardinality == multiple && hasDefault { errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields", t.Name(), field.Name)) return false @@ -413,8 +401,8 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { // check that we don't have both positionals and subcommands var hasPositional bool - for _, spec := range cmd.specs { - if spec.positional { + for _, arg := range cmd.args { + if arg.positional { hasPositional = true } } @@ -427,35 +415,88 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { // Parse processes the given command line option, storing the results in the field // of the structs from which NewParser was constructed -func (p *Parser) Parse(args []string) error { - err := p.process(args) - if err != nil { - // If -h or --help were specified then make sure help text supercedes other errors - for _, arg := range args { - if arg == "-h" || arg == "--help" { - return ErrHelp - } - if arg == "--" { - break - } +func (p *Parser) Parse(args, env []string) error { + p.seen = make(map[*Argument]bool) + + // If -h or --help were specified then make sure help text supercedes other errors + var help bool + for _, arg := range args { + if arg == "-h" || arg == "--help" { + help = true + } + if arg == "--" { + break } } - return err + + err := p.ProcessCommandLine(args) + if err != nil { + if help { + return ErrHelp + } + return err + } + + err = p.ProcessEnvironment(env) + if err != nil { + if help { + return ErrHelp + } + return err + } + + err = p.ProcessDefaults() + if err != nil { + if help { + return ErrHelp + } + return err + } + + return p.Validate() } -// process environment vars for the given arguments -func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error { - for _, spec := range specs { - if spec.env == "" { +// ProcessEnvironment processes environment variables from a list of strings +// of the form KEY=VALUE. You can pass in os.Environ(). It +// does not overwrite any fields with values already populated. +func (p *Parser) ProcessEnvironment(environ []string) error { + return p.processEnvironment(environ, false) +} + +// OverwriteWithEnvironment processes environment variables from a list +// of strings of the form "KEY=VALUE". Any existing values are overwritten. +func (p *Parser) OverwriteWithEnvironment(environ []string) error { + return p.processEnvironment(environ, true) +} + +// ProcessEnvironment processes environment variables from a list of strings +// of the form KEY=VALUE. You can pass in os.Environ(). It +// overwrites already-populated fields only if overwrite is true. +func (p *Parser) processEnvironment(environ []string, overwrite bool) error { + // parse the list of KEY=VAL strings in environ + env := make(map[string]string) + for _, s := range environ { + if i := strings.Index(s, "="); i >= 0 { + env[s[:i]] = s[i+1:] + } + } + + // process arguments one-by-one + for _, arg := range p.accumulatedArgs { + if p.seen[arg] && !overwrite { continue } - value, found := os.LookupEnv(spec.env) + if arg.env == "" { + continue + } + + value, found := env[arg.env] if !found { continue } - if spec.cardinality == multiple { + if arg.cardinality == multiple { // expect a CSV string in an environment // variable in the case of multiple values var values []string @@ -465,74 +506,80 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error if err != nil { return fmt.Errorf( "error reading a CSV string from environment variable %s with multiple values: %v", - spec.env, + arg.env, err, ) } } - if err = setSliceOrMap(p.val(spec.dest), values, !spec.separate); err != nil { + if err = setSliceOrMap(p.val(arg.dest), values, !arg.separate); err != nil { return fmt.Errorf( "error processing environment variable %s with multiple values: %v", - spec.env, + arg.env, err, ) } } else { - if err := scalar.ParseValue(p.val(spec.dest), value); err != nil { - return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) + if err := scalar.ParseValue(p.val(arg.dest), value); err != nil { + return fmt.Errorf("error processing environment variable %s: %v", arg.env, err) } } - wasPresent[spec] = true + + p.seen[arg] = true } return nil } -// process goes through arguments one-by-one, parses them, and assigns the result to -// the underlying struct field -func (p *Parser) process(args []string) error { - // track the options we have seen - wasPresent := make(map[*spec]bool) - - // union of specs for the chain of subcommands encountered so far - curCmd := p.cmd - p.lastCmd = curCmd - - // make a copy of the specs because we will add to this list each time we expand a subcommand - specs := make([]*spec, len(curCmd.specs)) - copy(specs, curCmd.specs) - - // deal with environment vars - if !p.config.IgnoreEnv { - err := p.captureEnvVars(specs, wasPresent) - if err != nil { - return err - } +// ProcessCommandLine goes through arguments one-by-one, parses them, +// and assigns the result to the underlying struct field. It returns +// an error if an argument is invalid or an option is unknown, not if a +// required argument is missing. To check that all required arguments +// are set, call CheckRequired(). This function ignores the first element +// of args, which is assumed to be the program name itself. +func (p *Parser) ProcessCommandLine(args []string) error { + positionals, err := p.ProcessOptions(args) + if err != nil { + return err } + return p.ProcessPositionals(positionals) +} + +// ProcessOptions process command line arguments but does not process +// positional arguments. Instead, it returns positionals. These can then +// be passed to ProcessPositionals. This function ignores the first element +// of args, which is assumed to be the program name itself. +func (p *Parser) ProcessOptions(args []string) ([]string, error) { + // union of args for the chain of subcommands encountered so far + curCmd := p.cmd + p.leaf = curCmd + + // we will add to this list each time we expand a subcommand + p.accumulatedArgs = make([]*Argument, len(curCmd.args)) + copy(p.accumulatedArgs, curCmd.args) // process each string from the command line var allpositional bool var positionals []string // must use explicit for loop, not range, because we manipulate i inside the loop - for i := 0; i < len(args); i++ { - arg := args[i] - if arg == "--" { + for i := 1; i < len(args); i++ { + token := args[i] + if token == "--" { allpositional = true continue } - if !isFlag(arg) || allpositional { + if !isFlag(token) || allpositional { // each subcommand can have either subcommands or positionals, but not both if len(curCmd.subcommands) == 0 { - positionals = append(positionals, arg) + positionals = append(positionals, token) continue } // if we have a subcommand then make sure it is valid for the current context - subcmd := findSubcommand(curCmd.subcommands, arg) + subcmd := findSubcommand(curCmd.subcommands, token) if subcmd == nil { - return fmt.Errorf("invalid subcommand: %s", arg) + return nil, fmt.Errorf("invalid subcommand: %s", token) } // instantiate the field to point to a new struct @@ -542,157 +589,186 @@ func (p *Parser) process(args []string) error { } // add the new options to the set of allowed options - specs = append(specs, subcmd.specs...) - - // capture environment vars for these new options - if !p.config.IgnoreEnv { - err := p.captureEnvVars(subcmd.specs, wasPresent) - if err != nil { - return err - } - } + p.accumulatedArgs = append(p.accumulatedArgs, subcmd.args...) curCmd = subcmd - p.lastCmd = curCmd + p.leaf = curCmd continue } // check for special --help and --version flags - switch arg { + switch token { case "-h", "--help": - return ErrHelp + return nil, ErrHelp case "--version": - return ErrVersion + return nil, ErrVersion } // check for an equals sign, as in "--foo=bar" var value string - opt := strings.TrimLeft(arg, "-") + opt := strings.TrimLeft(token, "-") if pos := strings.Index(opt, "="); pos != -1 { value = opt[pos+1:] opt = opt[:pos] } - // lookup the spec for this option (note that the "specs" slice changes as + // look up the arg for this option (note that the "args" slice changes as // we expand subcommands so it is better not to use a map) - spec := findOption(specs, opt) - if spec == nil { - return fmt.Errorf("unknown argument %s", arg) + arg := findOption(p.accumulatedArgs, opt) + if arg == nil { + return nil, fmt.Errorf("unknown argument %s", token) } - wasPresent[spec] = true + p.seen[arg] = true // deal with the case of multiple values - if spec.cardinality == multiple { + if arg.cardinality == multiple { var values []string if value == "" { for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { values = append(values, args[i+1]) i++ - if spec.separate { + if arg.separate { break } } } else { values = append(values, value) } - err := setSliceOrMap(p.val(spec.dest), values, !spec.separate) + err := setSliceOrMap(p.val(arg.dest), values, !arg.separate) if err != nil { - return fmt.Errorf("error processing %s: %v", arg, err) + return nil, fmt.Errorf("error processing %s: %v", token, err) } continue } // if it's a flag and it has no value then set the value to true // use boolean because this takes account of TextUnmarshaler - if spec.cardinality == zero && value == "" { + if arg.cardinality == zero && value == "" { value = "true" } // if we have something like "--foo" then the value is the next argument if value == "" { if i+1 == len(args) { - return fmt.Errorf("missing value for %s", arg) + return nil, fmt.Errorf("missing value for %s", token) } - if !nextIsNumeric(spec.field.Type, args[i+1]) && isFlag(args[i+1]) { - return fmt.Errorf("missing value for %s", arg) + if isFlag(args[i+1]) { + return nil, fmt.Errorf("missing value for %s", token) } value = args[i+1] i++ } - err := scalar.ParseValue(p.val(spec.dest), value) + p.seen[arg] = true + err := scalar.ParseValue(p.val(arg.dest), value) if err != nil { - return fmt.Errorf("error processing %s: %v", arg, err) + return nil, fmt.Errorf("error processing %s: %v", token, err) } } - // process positionals - for _, spec := range specs { - if !spec.positional { + return positionals, nil +} + +// ProcessPositionals processes a list of positional arguments. It is assumed +// that options such as --abc and --abc=123 have already been removed. If +// this list contains tokens that begin with a hyphen they will still be +// treated as positional arguments. +func (p *Parser) ProcessPositionals(positionals []string) error { + for _, arg := range p.accumulatedArgs { + if !arg.positional { continue } if len(positionals) == 0 { break } - wasPresent[spec] = true - if spec.cardinality == multiple { - err := setSliceOrMap(p.val(spec.dest), positionals, true) + p.seen[arg] = true + if arg.cardinality == multiple { + err := setSliceOrMap(p.val(arg.dest), positionals, true) if err != nil { - return fmt.Errorf("error processing %s: %v", spec.field.Name, err) + return fmt.Errorf("error processing %s: %v", arg.field.Name, err) } positionals = nil } else { - err := scalar.ParseValue(p.val(spec.dest), positionals[0]) + err := scalar.ParseValue(p.val(arg.dest), positionals[0]) if err != nil { - return fmt.Errorf("error processing %s: %v", spec.field.Name, err) + return fmt.Errorf("error processing %s: %v", arg.field.Name, err) } positionals = positionals[1:] } } + if len(positionals) > 0 { return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) } - // fill in defaults and check that all the required args were provided - for _, spec := range specs { - if wasPresent[spec] { - continue - } - - name := strings.ToLower(spec.field.Name) - if spec.long != "" && !spec.positional { - name = "--" + spec.long - } - - if spec.required { - msg := fmt.Sprintf("%s is required", name) - if spec.env != "" { - msg += " (or environment variable " + spec.env + ")" - } - return errors.New(msg) - } - if !p.config.IgnoreDefault && spec.defaultVal != "" { - err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal) - if err != nil { - return fmt.Errorf("error processing default value for %s: %v", name, err) - } - } - } - return nil } -func nextIsNumeric(t reflect.Type, s string) bool { - switch t.Kind() { - case reflect.Ptr: - return nextIsNumeric(t.Elem(), s) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - v := reflect.New(t) - err := scalar.ParseValue(v, s) - return err == nil - default: - return false +// ProcessDefaults assigns default values to all fields that have default values and +// are not already populated. +func (p *Parser) ProcessDefaults() error { + return p.processDefaults(false) +} + +// OverwriteWithDefaults assigns default values to all fields that have default values, +// overwriting any previous value +func (p *Parser) OverwriteWithDefaults() error { + return p.processDefaults(true) +} + +// processDefaults assigns default values to all fields in all expanded subcommands. +// If overwrite is true then it overwrites existing values. +func (p *Parser) processDefaults(overwrite bool) error { + for _, arg := range p.accumulatedArgs { + if p.seen[arg] && !overwrite { + continue + } + + if arg.defaultVal == "" { + continue + } + + name := strings.ToLower(arg.field.Name) + if arg.long != "" && !arg.positional { + name = "--" + arg.long + } + + err := scalar.ParseValue(p.val(arg.dest), arg.defaultVal) + if err != nil { + return fmt.Errorf("error processing default value for %s: %v", name, err) + } + p.seen[arg] = true } + + return nil +} + +// Missing returns a list of required arguments that were not provided +func (p *Parser) Missing() []*Argument { + var missing []*Argument + for _, arg := range p.accumulatedArgs { + if arg.required && !p.seen[arg] { + missing = append(missing, arg) + } + } + return missing +} + +// Validate returns an error if any required arguments were missing +func (p *Parser) Validate() error { + if missing := p.Missing(); len(missing) > 0 { + name := strings.ToLower(missing[0].field.Name) + if missing[0].long != "" && !missing[0].positional { + name = "--" + missing[0].long + } + + if missing[0].env == "" { + return fmt.Errorf("%s is required", name) + } + return fmt.Errorf("%s is required (or environment variable %s)", name, missing[0].env) + } + + return nil } // isFlag returns true if a token is a flag such as "-v" or "--user" but not "-" or "--" @@ -717,21 +793,21 @@ func (p *Parser) val(dest path) reflect.Value { return v } -// findOption finds an option from its name, or returns null if no spec is found -func findOption(specs []*spec, name string) *spec { - for _, spec := range specs { - if spec.positional { +// findOption finds an option from its name, or returns nil if no arg is found +func findOption(args []*Argument, name string) *Argument { + for _, arg := range args { + if arg.positional { continue } - if spec.long == name || spec.short == name { - return spec + if arg.long == name || arg.short == name { + return arg } } return nil } -// findSubcommand finds a subcommand using its name, or returns null if no subcommand is found -func findSubcommand(cmds []*command, name string) *command { +// findSubcommand finds a subcommand using its name, or returns nil if no subcommand is found +func findSubcommand(cmds []*Command, name string) *Command { for _, cmd := range cmds { if cmd.name == name { return cmd diff --git a/v2/parse_test.go b/v2/parse_test.go index 4ea6bc4..b9d6948 100644 --- a/v2/parse_test.go +++ b/v2/parse_test.go @@ -2,7 +2,6 @@ package arg import ( "bytes" - "fmt" "net" "net/mail" "net/url" @@ -15,47 +14,29 @@ import ( "github.com/stretchr/testify/require" ) -func setenv(t *testing.T, name, val string) { - if err := os.Setenv(name, val); err != nil { - t.Error(err) - } -} - func parse(cmdline string, dest interface{}) error { _, err := pparse(cmdline, dest) return err } func pparse(cmdline string, dest interface{}) (*Parser, error) { - return parseWithEnv(cmdline, nil, dest) + return parseWithEnv(dest, cmdline) } -func parseWithEnv(cmdline string, env []string, dest interface{}) (*Parser, error) { +func parseWithEnv(dest interface{}, cmdline string, env ...string) (*Parser, error) { p, err := NewParser(Config{}, dest) if err != nil { return nil, err } // split the command line - var parts []string + tokens := []string{"program"} // first token is the program name if len(cmdline) > 0 { - parts = strings.Split(cmdline, " ") - } - - // split the environment vars - for _, s := range env { - pos := strings.Index(s, "=") - if pos == -1 { - return nil, fmt.Errorf("missing equals sign in %q", s) - } - err := os.Setenv(s[:pos], s[pos+1:]) - if err != nil { - return nil, err - } + tokens = append(tokens, strings.Split(cmdline, " ")...) } // execute the parser - return p, p.Parse(parts) + return p, p.Parse(tokens, env) } func TestString(t *testing.T) { @@ -97,9 +78,9 @@ func TestInt(t *testing.T) { func TestHexOctBin(t *testing.T) { var args struct { - Hex int - Oct int - Bin int + Hex int + Oct int + Bin int Underscored int } err := parse("--hex 0xA --oct 0o10 --bin 0b101 --underscored 123_456", &args) @@ -114,22 +95,18 @@ func TestNegativeInt(t *testing.T) { var args struct { Foo int } - err := parse("-foo -100", &args) + err := parse("-foo=-100", &args) require.NoError(t, err) assert.EqualValues(t, args.Foo, -100) } -func TestNegativeIntAndFloatAndTricks(t *testing.T) { +func TestNumericOptionName(t *testing.T) { var args struct { - Foo int - Bar float64 - N int `arg:"--100"` + N int `arg:"--100"` } - err := parse("-foo -100 -bar -60.14 -100 -100", &args) + err := parse("-100 6", &args) require.NoError(t, err) - assert.EqualValues(t, args.Foo, -100) - assert.EqualValues(t, args.Bar, -60.14) - assert.EqualValues(t, args.N, -100) + assert.EqualValues(t, args.N, 6) } func TestUint(t *testing.T) { @@ -211,6 +188,14 @@ func TestMixed(t *testing.T) { } func TestRequired(t *testing.T) { + var args struct { + Foo string `arg:"required"` + } + err := parse("--foo=abc", &args) + require.NoError(t, err) +} + +func TestMissingRequired(t *testing.T) { var args struct { Foo string `arg:"required"` } @@ -218,7 +203,7 @@ func TestRequired(t *testing.T) { require.Error(t, err, "--foo is required") } -func TestRequiredWithEnv(t *testing.T) { +func TestMissingRequiredWithEnv(t *testing.T) { var args struct { Foo string `arg:"required,env:FOO"` } @@ -336,8 +321,7 @@ func TestNoLongName(t *testing.T) { ShortOnly string `arg:"-s,--"` EnvOnly string `arg:"--,env"` } - setenv(t, "ENVONLY", "TestVal") - err := parse("-s TestVal2", &args) + _, err := parseWithEnv(&args, "-s TestVal2", "ENVONLY=TestVal") assert.NoError(t, err) assert.Equal(t, "TestVal", args.EnvOnly) assert.Equal(t, "TestVal2", args.ShortOnly) @@ -482,15 +466,6 @@ func TestUnknownField(t *testing.T) { assert.Error(t, err) } -func TestMissingRequired(t *testing.T) { - var args struct { - Foo string `arg:"required"` - X []string `arg:"positional"` - } - err := parse("x", &args) - assert.Error(t, err) -} - func TestNonsenseKey(t *testing.T) { var args struct { X []string `arg:"positional, nonsense"` @@ -520,7 +495,7 @@ func TestNegativeValue(t *testing.T) { var args struct { Foo int } - err := parse("--foo -123", &args) + err := parse("--foo=-123", &args) require.NoError(t, err) assert.Equal(t, -123, args.Foo) } @@ -687,7 +662,7 @@ func TestEnvironmentVariable(t *testing.T) { var args struct { Foo string `arg:"env"` } - _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + _, err := parseWithEnv(&args, "", "FOO=bar") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -696,7 +671,7 @@ func TestEnvironmentVariableNotPresent(t *testing.T) { var args struct { NotPresent string `arg:"env"` } - _, err := parseWithEnv("", nil, &args) + _, err := parseWithEnv(&args, "", "") require.NoError(t, err) assert.Equal(t, "", args.NotPresent) } @@ -705,16 +680,16 @@ func TestEnvironmentVariableOverrideName(t *testing.T) { var args struct { Foo string `arg:"env:BAZ"` } - _, err := parseWithEnv("", []string{"BAZ=bar"}, &args) + _, err := parseWithEnv(&args, "", "BAZ=bar") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } -func TestEnvironmentVariableOverrideArgument(t *testing.T) { +func TestCommandLineSupercedesEnv(t *testing.T) { var args struct { Foo string `arg:"env"` } - _, err := parseWithEnv("--foo zzz", []string{"FOO=bar"}, &args) + _, err := parseWithEnv(&args, "--foo zzz", "FOO=bar") require.NoError(t, err) assert.Equal(t, "zzz", args.Foo) } @@ -723,7 +698,7 @@ func TestEnvironmentVariableError(t *testing.T) { var args struct { Foo int `arg:"env"` } - _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + _, err := parseWithEnv(&args, "", "FOO=bar") assert.Error(t, err) } @@ -731,7 +706,7 @@ func TestEnvironmentVariableRequired(t *testing.T) { var args struct { Foo string `arg:"env,required"` } - _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + _, err := parseWithEnv(&args, "", "FOO=bar") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -740,7 +715,7 @@ func TestEnvironmentVariableSliceArgumentString(t *testing.T) { var args struct { Foo []string `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=bar,"baz, qux"`}, &args) + _, err := parseWithEnv(&args, "", `FOO=bar,"baz, qux"`) require.NoError(t, err) assert.Equal(t, []string{"bar", "baz, qux"}, args.Foo) } @@ -749,7 +724,7 @@ func TestEnvironmentVariableSliceEmpty(t *testing.T) { var args struct { Foo []string `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=`}, &args) + _, err := parseWithEnv(&args, "", `FOO=`) require.NoError(t, err) assert.Len(t, args.Foo, 0) } @@ -758,7 +733,7 @@ func TestEnvironmentVariableSliceArgumentInteger(t *testing.T) { var args struct { Foo []int `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=1,99`}, &args) + _, err := parseWithEnv(&args, "", `FOO=1,99`) require.NoError(t, err) assert.Equal(t, []int{1, 99}, args.Foo) } @@ -767,7 +742,7 @@ func TestEnvironmentVariableSliceArgumentFloat(t *testing.T) { var args struct { Foo []float32 `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=1.1,99.9`}, &args) + _, err := parseWithEnv(&args, "", `FOO=1.1,99.9`) require.NoError(t, err) assert.Equal(t, []float32{1.1, 99.9}, args.Foo) } @@ -776,7 +751,7 @@ func TestEnvironmentVariableSliceArgumentBool(t *testing.T) { var args struct { Foo []bool `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=true,false,0,1`}, &args) + _, err := parseWithEnv(&args, "", `FOO=true,false,0,1`) require.NoError(t, err) assert.Equal(t, []bool{true, false, false, true}, args.Foo) } @@ -785,7 +760,7 @@ func TestEnvironmentVariableSliceArgumentWrongCsv(t *testing.T) { var args struct { Foo []int `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=1,99\"`}, &args) + _, err := parseWithEnv(&args, "", `FOO=1,99\"`) assert.Error(t, err) } @@ -793,7 +768,7 @@ func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) { var args struct { Foo []bool `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=one,two`}, &args) + _, err := parseWithEnv(&args, "", `FOO=one,two`) assert.Error(t, err) } @@ -801,7 +776,7 @@ func TestEnvironmentVariableMap(t *testing.T) { var args struct { Foo map[int]string `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=1=one,99=ninetynine`}, &args) + _, err := parseWithEnv(&args, "", `FOO=1=one,99=ninetynine`) require.NoError(t, err) assert.Len(t, args.Foo, 2) assert.Equal(t, "one", args.Foo[1]) @@ -812,51 +787,61 @@ func TestEnvironmentVariableEmptyMap(t *testing.T) { var args struct { Foo map[int]string `arg:"env"` } - _, err := parseWithEnv("", []string{`FOO=`}, &args) + _, err := parseWithEnv(&args, "", `FOO=`) require.NoError(t, err) assert.Len(t, args.Foo, 0) } -func TestEnvironmentVariableIgnored(t *testing.T) { - var args struct { - Foo string `arg:"env"` - } - setenv(t, "FOO", "abc") +// func TestEnvironmentVariableIgnored(t *testing.T) { +// var args struct { +// Foo string `arg:"env"` +// } +// setenv(t, "FOO", "abc") - p, err := NewParser(Config{IgnoreEnv: true}, &args) - require.NoError(t, err) +// p, err := NewParser(Config{IgnoreEnv: true}, &args) +// require.NoError(t, err) - err = p.Parse(nil) - assert.NoError(t, err) - assert.Equal(t, "", args.Foo) -} +// err = p.Parse(nil) +// assert.NoError(t, err) +// assert.Equal(t, "", args.Foo) +// } -func TestDefaultValuesIgnored(t *testing.T) { - var args struct { - Foo string `default:"bad"` - } +// func TestDefaultValuesIgnored(t *testing.T) { +// var args struct { +// Foo string `default:"bad"` +// } - p, err := NewParser(Config{IgnoreDefault: true}, &args) - require.NoError(t, err) +// p, err := NewParser(Config{IgnoreDefault: true}, &args) +// require.NoError(t, err) - err = p.Parse(nil) - assert.NoError(t, err) - assert.Equal(t, "", args.Foo) -} +// err = p.Parse(nil) +// assert.NoError(t, err) +// assert.Equal(t, "", args.Foo) +// } -func TestEnvironmentVariableInSubcommandIgnored(t *testing.T) { +func TestEnvironmentVariableInSubcommand(t *testing.T) { var args struct { Sub *struct { - Foo string `arg:"env"` + Foo string `arg:"env:FOO"` } `arg:"subcommand"` } - setenv(t, "FOO", "abc") - p, err := NewParser(Config{IgnoreEnv: true}, &args) + _, err := parseWithEnv(&args, "sub", "FOO=abc") require.NoError(t, err) + require.NotNil(t, args.Sub) + assert.Equal(t, "abc", args.Sub.Foo) +} - err = p.Parse([]string{"sub"}) - assert.NoError(t, err) +func TestEnvironmentVariableInSubcommandEmpty(t *testing.T) { + var args struct { + Sub *struct { + Foo string `arg:"env:FOO"` + } `arg:"subcommand"` + } + + _, err := parseWithEnv(&args, "sub") + require.NoError(t, err) + require.NotNil(t, args.Sub) assert.Equal(t, "", args.Sub.Foo) } @@ -1305,11 +1290,11 @@ func TestReuseParser(t *testing.T) { p, err := NewParser(Config{}, &args) require.NoError(t, err) - err = p.Parse([]string{"--foo=abc"}) + err = p.Parse([]string{"program", "--foo=abc"}, nil) require.NoError(t, err) assert.Equal(t, args.Foo, "abc") - err = p.Parse([]string{}) + err = p.Parse([]string{}, nil) assert.Error(t, err) } diff --git a/v2/subcommand.go b/v2/subcommand.go index dff732c..03fe399 100644 --- a/v2/subcommand.go +++ b/v2/subcommand.go @@ -7,22 +7,22 @@ package arg // no command line arguments have been processed by this parser then it // returns nil. func (p *Parser) Subcommand() interface{} { - if p.lastCmd == nil || p.lastCmd.parent == nil { + if p.leaf == nil || p.leaf.parent == nil { return nil } - return p.val(p.lastCmd.dest).Interface() + return p.val(p.leaf.dest).Interface() } // SubcommandNames returns the sequence of subcommands specified by the // user. If no subcommands were given then it returns an empty slice. func (p *Parser) SubcommandNames() []string { - if p.lastCmd == nil { + if p.leaf == nil { return nil } // make a list of ancestor commands var ancestors []string - cur := p.lastCmd + cur := p.leaf for cur.parent != nil { // we want to exclude the root ancestors = append(ancestors, cur.name) cur = cur.parent diff --git a/v2/subcommand_test.go b/v2/subcommand_test.go index 2c61dd3..9f7c8c5 100644 --- a/v2/subcommand_test.go +++ b/v2/subcommand_test.go @@ -206,8 +206,7 @@ func TestSubcommandsWithEnvVars(t *testing.T) { { var args cmd - setenv(t, "LIMIT", "123") - err := parse("list", &args) + _, err := parseWithEnv(&args, "list", "LIMIT=123") require.NoError(t, err) require.NotNil(t, args.List) assert.Equal(t, 123, args.List.Limit) @@ -215,8 +214,7 @@ func TestSubcommandsWithEnvVars(t *testing.T) { { var args cmd - setenv(t, "LIMIT", "not_an_integer") - err := parse("list", &args) + _, err := parseWithEnv(&args, "list", "LIMIT=not_an_integer") assert.Error(t, err) } } diff --git a/v2/usage.go b/v2/usage.go index 7ba06cc..bdc118d 100644 --- a/v2/usage.go +++ b/v2/usage.go @@ -38,19 +38,15 @@ func (p *Parser) FailSubcommand(msg string, subcommand ...string) error { } // failWithSubcommand prints usage information for the given subcommand to stderr and exits with non-zero status -func (p *Parser) failWithSubcommand(msg string, cmd *command) { +func (p *Parser) failWithSubcommand(msg string, cmd *Command) { p.writeUsageForSubcommand(stderr, cmd) fmt.Fprintln(stderr, "error:", msg) osExit(-1) } -// WriteUsage writes usage information to the given writer +// WriteUsage writes usage information for the top-level command func (p *Parser) WriteUsage(w io.Writer) { - cmd := p.cmd - if p.lastCmd != nil { - cmd = p.lastCmd - } - p.writeUsageForSubcommand(w, cmd) + p.writeUsageForSubcommand(w, p.cmd) } // WriteUsageForSubcommand writes the usage information for a specified @@ -68,16 +64,16 @@ func (p *Parser) WriteUsageForSubcommand(w io.Writer, subcommand ...string) erro } // writeUsageForSubcommand writes usage information for the given subcommand -func (p *Parser) writeUsageForSubcommand(w io.Writer, cmd *command) { - var positionals, longOptions, shortOptions []*spec - for _, spec := range cmd.specs { +func (p *Parser) writeUsageForSubcommand(w io.Writer, cmd *Command) { + var positionals, longOptions, shortOptions []*Argument + for _, arg := range cmd.args { switch { - case spec.positional: - positionals = append(positionals, spec) - case spec.long != "": - longOptions = append(longOptions, spec) - case spec.short != "": - shortOptions = append(shortOptions, spec) + case arg.positional: + positionals = append(positionals, arg) + case arg.long != "": + longOptions = append(longOptions, arg) + case arg.short != "": + shortOptions = append(shortOptions, arg) } } @@ -100,26 +96,26 @@ func (p *Parser) writeUsageForSubcommand(w io.Writer, cmd *command) { } // write the option component of the usage message - for _, spec := range shortOptions { + for _, arg := range shortOptions { // prefix with a space fmt.Fprint(w, " ") - if !spec.required { + if !arg.required { fmt.Fprint(w, "[") } - fmt.Fprint(w, synopsis(spec, "-"+spec.short)) - if !spec.required { + fmt.Fprint(w, synopsis(arg, "-"+arg.short)) + if !arg.required { fmt.Fprint(w, "]") } } - for _, spec := range longOptions { + for _, arg := range longOptions { // prefix with a space fmt.Fprint(w, " ") - if !spec.required { + if !arg.required { fmt.Fprint(w, "[") } - fmt.Fprint(w, synopsis(spec, "--"+spec.long)) - if !spec.required { + fmt.Fprint(w, synopsis(arg, "--"+arg.long)) + if !arg.required { fmt.Fprint(w, "]") } } @@ -137,16 +133,16 @@ func (p *Parser) writeUsageForSubcommand(w io.Writer, cmd *command) { // REQUIRED1 REQUIRED2 [REPEATEDOPTIONAL [REPEATEDOPTIONAL ...]] // REQUIRED1 REQUIRED2 [OPTIONAL1 [REPEATEDOPTIONAL [REPEATEDOPTIONAL ...]]] var closeBrackets int - for _, spec := range positionals { + for _, arg := range positionals { fmt.Fprint(w, " ") - if !spec.required { + if !arg.required { fmt.Fprint(w, "[") closeBrackets += 1 } - if spec.cardinality == multiple { - fmt.Fprintf(w, "%s [%s ...]", spec.placeholder, spec.placeholder) + if arg.cardinality == multiple { + fmt.Fprintf(w, "%s [%s ...]", arg.placeholder, arg.placeholder) } else { - fmt.Fprint(w, spec.placeholder) + fmt.Fprint(w, arg.placeholder) } } fmt.Fprint(w, strings.Repeat("]", closeBrackets)) @@ -191,13 +187,9 @@ func printTwoCols(w io.Writer, left, help string, defaultVal string, envVal stri fmt.Fprint(w, "\n") } -// WriteHelp writes the usage string followed by the full help string for each option +// WriteHelp writes the usage string for the top-level command func (p *Parser) WriteHelp(w io.Writer) { - cmd := p.cmd - if p.lastCmd != nil { - cmd = p.lastCmd - } - p.writeHelpForSubcommand(w, cmd) + p.writeHelpForSubcommand(w, p.cmd) } // WriteHelpForSubcommand writes the usage string followed by the full help @@ -215,68 +207,68 @@ func (p *Parser) WriteHelpForSubcommand(w io.Writer, subcommand ...string) error } // writeHelp writes the usage string for the given subcommand -func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *command) { - var positionals, longOptions, shortOptions []*spec - for _, spec := range cmd.specs { +func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *Command) { + var positionals, longOptions, shortOptions []*Argument + for _, arg := range cmd.args { switch { - case spec.positional: - positionals = append(positionals, spec) - case spec.long != "": - longOptions = append(longOptions, spec) - case spec.short != "": - shortOptions = append(shortOptions, spec) + case arg.positional: + positionals = append(positionals, arg) + case arg.long != "": + longOptions = append(longOptions, arg) + case arg.short != "": + shortOptions = append(shortOptions, arg) } } - if p.description != "" { - fmt.Fprintln(w, p.description) + if p.prologue != "" { + fmt.Fprintln(w, p.prologue) } p.writeUsageForSubcommand(w, cmd) // write the list of positionals if len(positionals) > 0 { fmt.Fprint(w, "\nPositional arguments:\n") - for _, spec := range positionals { - printTwoCols(w, spec.placeholder, spec.help, "", "") + for _, arg := range positionals { + printTwoCols(w, arg.placeholder, arg.help, "", "") } } // write the list of options with the short-only ones first to match the usage string if len(shortOptions)+len(longOptions) > 0 || cmd.parent == nil { fmt.Fprint(w, "\nOptions:\n") - for _, spec := range shortOptions { - p.printOption(w, spec) + for _, arg := range shortOptions { + p.printOption(w, arg) } - for _, spec := range longOptions { - p.printOption(w, spec) + for _, arg := range longOptions { + p.printOption(w, arg) } } // obtain a flattened list of options from all ancestors - var globals []*spec + var globals []*Argument ancestor := cmd.parent for ancestor != nil { - globals = append(globals, ancestor.specs...) + globals = append(globals, ancestor.args...) ancestor = ancestor.parent } // write the list of global options if len(globals) > 0 { fmt.Fprint(w, "\nGlobal options:\n") - for _, spec := range globals { - p.printOption(w, spec) + for _, arg := range globals { + p.printOption(w, arg) } } // write the list of built in options - p.printOption(w, &spec{ + p.printOption(w, &Argument{ cardinality: zero, long: "help", short: "h", help: "display this help and exit", }) if p.version != "" { - p.printOption(w, &spec{ + p.printOption(w, &Argument{ cardinality: zero, long: "version", help: "display version and exit", @@ -296,16 +288,16 @@ func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *command) { } } -func (p *Parser) printOption(w io.Writer, spec *spec) { +func (p *Parser) printOption(w io.Writer, arg *Argument) { ways := make([]string, 0, 2) - if spec.long != "" { - ways = append(ways, synopsis(spec, "--"+spec.long)) + if arg.long != "" { + ways = append(ways, synopsis(arg, "--"+arg.long)) } - if spec.short != "" { - ways = append(ways, synopsis(spec, "-"+spec.short)) + if arg.short != "" { + ways = append(ways, synopsis(arg, "-"+arg.short)) } if len(ways) > 0 { - printTwoCols(w, strings.Join(ways, ", "), spec.help, spec.defaultVal, spec.env) + printTwoCols(w, strings.Join(ways, ", "), arg.help, arg.defaultVal, arg.env) } } @@ -314,10 +306,10 @@ func (p *Parser) printOption(w io.Writer, spec *spec) { // subcommand of that subcommand, and so on. If no strings are given then the // root command is returned. If no such subcommand exists then an error is // returned. -func (p *Parser) lookupCommand(path ...string) (*command, error) { +func (p *Parser) lookupCommand(path ...string) (*Command, error) { cmd := p.cmd for _, name := range path { - var found *command + var found *Command for _, child := range cmd.subcommands { if child.name == name { found = child @@ -331,9 +323,9 @@ func (p *Parser) lookupCommand(path ...string) (*command, error) { return cmd, nil } -func synopsis(spec *spec, form string) string { - if spec.cardinality == zero { +func synopsis(arg *Argument, form string) string { + if arg.cardinality == zero { return form } - return form + " " + spec.placeholder + return form + " " + arg.placeholder } diff --git a/v2/usage_test.go b/v2/usage_test.go index fd67fc8..b306506 100644 --- a/v2/usage_test.go +++ b/v2/usage_test.go @@ -443,20 +443,12 @@ Global options: p, err := NewParser(Config{}, &args) require.NoError(t, err) - _ = p.Parse([]string{"child", "nested", "value"}) - - var help bytes.Buffer - p.WriteHelp(&help) - assert.Equal(t, expectedHelp[1:], help.String()) + _ = p.Parse([]string{"child", "nested", "value"}, nil) var help2 bytes.Buffer p.WriteHelpForSubcommand(&help2, "child", "nested") assert.Equal(t, expectedHelp[1:], help2.String()) - var usage bytes.Buffer - p.WriteUsage(&usage) - assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) - var usage2 bytes.Buffer p.WriteUsageForSubcommand(&usage2, "child", "nested") assert.Equal(t, expectedUsage, strings.TrimSpace(usage2.String())) From 22f214d7eda0eaffb3dcc67106f405da4ff10293 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 11:02:46 -0700 Subject: [PATCH 03/19] added test that library does not directly access environment variables from OS --- v2/parse_test.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/v2/parse_test.go b/v2/parse_test.go index b9d6948..d323847 100644 --- a/v2/parse_test.go +++ b/v2/parse_test.go @@ -792,19 +792,18 @@ func TestEnvironmentVariableEmptyMap(t *testing.T) { assert.Len(t, args.Foo, 0) } -// func TestEnvironmentVariableIgnored(t *testing.T) { -// var args struct { -// Foo string `arg:"env"` -// } -// setenv(t, "FOO", "abc") +func TestEnvironmentVariableIgnored(t *testing.T) { + var args struct { + Foo string `arg:"env"` + } -// p, err := NewParser(Config{IgnoreEnv: true}, &args) -// require.NoError(t, err) + // the library should never read env vars direct from os + os.Setenv("FOO", "123") -// err = p.Parse(nil) -// assert.NoError(t, err) -// assert.Equal(t, "", args.Foo) -// } + _, err := parseWithEnv(&args, "") + require.NoError(t, err) + assert.Equal(t, "", args.Foo) +} // func TestDefaultValuesIgnored(t *testing.T) { // var args struct { From a1e2b672eacc00152b7047ea971474283b60bf1e Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 11:06:19 -0700 Subject: [PATCH 04/19] add a test to check that default values can be ignored if needed --- v2/parse_test.go | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/v2/parse_test.go b/v2/parse_test.go index d323847..148bd09 100644 --- a/v2/parse_test.go +++ b/v2/parse_test.go @@ -805,18 +805,25 @@ func TestEnvironmentVariableIgnored(t *testing.T) { assert.Equal(t, "", args.Foo) } -// func TestDefaultValuesIgnored(t *testing.T) { -// var args struct { -// Foo string `default:"bad"` -// } +func TestDefaultValuesIgnored(t *testing.T) { + var args struct { + Foo string `default:"bad"` + } -// p, err := NewParser(Config{IgnoreDefault: true}, &args) -// require.NoError(t, err) + // just checking that default values are not automatically applied + // in ProcessCommandLine or ProcessEnvironment -// err = p.Parse(nil) -// assert.NoError(t, err) -// assert.Equal(t, "", args.Foo) -// } + p, err := NewParser(Config{}, &args) + require.NoError(t, err) + + err = p.ProcessCommandLine(nil) + assert.NoError(t, err) + + err = p.ProcessEnvironment(nil) + assert.NoError(t, err) + + assert.Equal(t, "", args.Foo) +} func TestEnvironmentVariableInSubcommand(t *testing.T) { var args struct { From 4aea7830230f37e75ee4a1edceca6346213d3a7b Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 11:28:34 -0700 Subject: [PATCH 05/19] changed NewParser to take options at the end rather than config at the front --- v2/example_test.go | 4 +- v2/parse.go | 94 ++++++++++++++++++++----------------------- v2/parse_test.go | 8 ++-- v2/subcommand_test.go | 10 ++--- v2/usage_test.go | 70 ++++++++++++-------------------- 5 files changed, 79 insertions(+), 107 deletions(-) diff --git a/v2/example_test.go b/v2/example_test.go index fd64777..e769d60 100644 --- a/v2/example_test.go +++ b/v2/example_test.go @@ -314,7 +314,7 @@ func Example_writeHelpForSubcommand() { osExit = func(int) {} stdout = os.Stdout - p, err := NewParser(Config{}, &args) + p, err := NewParser(&args, WithProgramName("example")) if err != nil { fmt.Println(err) os.Exit(1) @@ -363,7 +363,7 @@ func Example_writeHelpForSubcommandNested() { osExit = func(int) {} stdout = os.Stdout - p, err := NewParser(Config{}, &args) + p, err := NewParser(&args, WithProgramName("example")) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/v2/parse.go b/v2/parse.go index ce02bd4..f5bcab4 100644 --- a/v2/parse.go +++ b/v2/parse.go @@ -1,7 +1,6 @@ package arg import ( - "encoding" "encoding/csv" "errors" "fmt" @@ -73,7 +72,7 @@ var ErrVersion = errors.New("version requested by user") // MustParse processes command line arguments and exits upon failure func MustParse(dest interface{}) *Parser { - p, err := NewParser(Config{}, dest) + p, err := NewParser(dest) if err != nil { fmt.Fprintln(stdout, err) osExit(-1) @@ -96,25 +95,18 @@ func MustParse(dest interface{}) *Parser { } // Parse processes command line arguments and stores them in dest -func Parse(dest interface{}) error { - p, err := NewParser(Config{}, dest) +func Parse(dest interface{}, options ...ParserOption) error { + p, err := NewParser(dest, options...) if err != nil { return err } return p.Parse(os.Args, os.Environ()) } -// Config represents configuration options for an argument parser -type Config struct { - // Program is the name of the program used in the help text - Program string -} - // Parser represents a set of command line options with destination values type Parser struct { cmd *Command // the top-level command root reflect.Value // destination struct to fill will values - config Config // configuration passed to NewParser version string // version from the argument struct prologue string // prologue for help text (from the argument struct) epilogue string // epilogue for help text (from the argument struct) @@ -170,58 +162,58 @@ func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner } } +// the ParserOption interface matches options for the parser constructor +type ParserOption interface { + parserOption() +} + +type programNameParserOption struct { + s string +} + +func (programNameParserOption) parserOption() {} + +// WithProgramName overrides the name of the program as displayed in help test +func WithProgramName(name string) ParserOption { + return programNameParserOption{s: name} +} + // NewParser constructs a parser from a list of destination structs -func NewParser(config Config, dest interface{}) (*Parser, error) { - // first pick a name for the command for use in the usage text - var name string - switch { - case config.Program != "": - name = config.Program - case len(os.Args) > 0: - name = filepath.Base(os.Args[0]) - default: - name = "program" - } - - // construct a parser - p := Parser{ - cmd: &Command{name: name}, - config: config, - seen: make(map[*Argument]bool), - } - - // make a list of roots - p.root = reflect.ValueOf(dest) - - // process each of the destination values +func NewParser(dest interface{}, options ...ParserOption) (*Parser, error) { + // check the destination type t := reflect.TypeOf(dest) if t.Kind() != reflect.Ptr { panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t)) } - cmd, err := cmdFromStruct(name, path{}, t) + // pick a program name for help text and usage output + program := "program" + if len(os.Args) > 0 { + program = filepath.Base(os.Args[0]) + } + + // apply the options + for _, opt := range options { + switch opt := opt.(type) { + case programNameParserOption: + program = opt.s + } + } + + // build the root command from the struct + cmd, err := cmdFromStruct(program, path{}, t) if err != nil { return nil, err } - // add nonzero field values as defaults - for _, arg := range cmd.args { - if v := p.val(arg.dest); v.IsValid() && !isZero(v) { - if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok { - str, err := defaultVal.MarshalText() - if err != nil { - return nil, fmt.Errorf("%v: error marshaling default value to string: %v", arg.dest, err) - } - arg.defaultVal = string(str) - } else { - arg.defaultVal = fmt.Sprintf("%v", v) - } - } + // construct the parser + p := Parser{ + seen: make(map[*Argument]bool), + root: reflect.ValueOf(dest), + cmd: cmd, } - p.cmd.args = append(p.cmd.args, cmd.args...) - p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...) - + // check for version, prologue, and epilogue if dest, ok := dest.(Versioned); ok { p.version = dest.Version() } diff --git a/v2/parse_test.go b/v2/parse_test.go index 148bd09..042712c 100644 --- a/v2/parse_test.go +++ b/v2/parse_test.go @@ -24,7 +24,7 @@ func pparse(cmdline string, dest interface{}) (*Parser, error) { } func parseWithEnv(dest interface{}, cmdline string, env ...string) (*Parser, error) { - p, err := NewParser(Config{}, dest) + p, err := NewParser(dest) if err != nil { return nil, err } @@ -813,7 +813,7 @@ func TestDefaultValuesIgnored(t *testing.T) { // just checking that default values are not automatically applied // in ProcessCommandLine or ProcessEnvironment - p, err := NewParser(Config{}, &args) + p, err := NewParser(&args) require.NoError(t, err) err = p.ProcessCommandLine(nil) @@ -1293,7 +1293,7 @@ func TestReuseParser(t *testing.T) { Foo string `arg:"required"` } - p, err := NewParser(Config{}, &args) + p, err := NewParser(&args) require.NoError(t, err) err = p.Parse([]string{"program", "--foo=abc"}, nil) @@ -1405,7 +1405,7 @@ func TestUnexportedFieldsSkipped(t *testing.T) { unexported struct{} } - _, err := NewParser(Config{}, &args) + _, err := NewParser(&args) require.NoError(t, err) } diff --git a/v2/subcommand_test.go b/v2/subcommand_test.go index 9f7c8c5..31dc2dd 100644 --- a/v2/subcommand_test.go +++ b/v2/subcommand_test.go @@ -15,7 +15,7 @@ func TestSubcommandNotAPointer(t *testing.T) { var args struct { A string `arg:"subcommand"` } - _, err := NewParser(Config{}, &args) + _, err := NewParser(&args) assert.Error(t, err) } @@ -23,7 +23,7 @@ func TestSubcommandNotAPointerToStruct(t *testing.T) { var args struct { A struct{} `arg:"subcommand"` } - _, err := NewParser(Config{}, &args) + _, err := NewParser(&args) assert.Error(t, err) } @@ -32,7 +32,7 @@ func TestPositionalAndSubcommandNotAllowed(t *testing.T) { A string `arg:"positional"` B *struct{} `arg:"subcommand"` } - _, err := NewParser(Config{}, &args) + _, err := NewParser(&args) assert.Error(t, err) } @@ -54,7 +54,7 @@ func TestSubcommandNamesBeforeParsing(t *testing.T) { var args struct { List *listCmd `arg:"subcommand"` } - p, err := NewParser(Config{}, &args) + p, err := NewParser(&args) require.NoError(t, err) assert.Nil(t, p.Subcommand()) assert.Nil(t, p.SubcommandNames()) @@ -400,7 +400,7 @@ func TestValForNilStruct(t *testing.T) { Sub *subcmd `arg:"subcommand"` } - p, err := NewParser(Config{}, &cmd) + p, err := NewParser(&cmd) require.NoError(t, err) typ := reflect.TypeOf(cmd) diff --git a/v2/usage_test.go b/v2/usage_test.go index b306506..7a5e11d 100644 --- a/v2/usage_test.go +++ b/v2/usage_test.go @@ -50,19 +50,19 @@ Options: --optimize OPTIMIZE, -O OPTIMIZE optimization level --ids IDS Ids - --values VALUES Values [default: [3.14 42 256]] + --values VALUES Values --workers WORKERS, -w WORKERS number of workers to start [default: 10, env: WORKERS] --testenv TESTENV, -a TESTENV [env: TEST_ENV] - --file FILE, -f FILE File with mandatory extension [default: scratch.txt] + --file FILE, -f FILE File with mandatory extension --help, -h display this help and exit ` var args struct { Input string `arg:"positional,required"` Output []string `arg:"positional" help:"list of outputs"` - Name string `help:"name to use"` - Value int `help:"secret value"` + Name string `help:"name to use" default:"Foo Bar"` + Value int `help:"secret value" default:"42"` Verbose bool `arg:"-v" help:"verbosity level"` Dataset string `help:"dataset to use"` Optimize int `arg:"-O" help:"optimization level"` @@ -72,11 +72,7 @@ Options: TestEnv string `arg:"-a,env:TEST_ENV"` File *NameDotName `arg:"-f" help:"File with mandatory extension"` } - args.Name = "Foo Bar" - args.Value = 42 - args.Values = []float64{3.14, 42, 256} - args.File = &NameDotName{"scratch", "txt"} - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) require.NoError(t, err) os.Args[0] = "example" @@ -112,11 +108,10 @@ Options: --help, -h display this help and exit ` var args struct { - Label string + Label string `default:"cat"` Content string `default:"dog"` } - args.Label = "cat" - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) require.NoError(t, err) args.Label = "should_ignore_this" @@ -130,16 +125,6 @@ Options: assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) } -func TestUsageCannotMarshalToString(t *testing.T) { - var args struct { - Name *MyEnum - } - v := MyEnum(42) - args.Name = &v - _, err := NewParser(Config{Program: "example"}, &args) - assert.EqualError(t, err, `args.Name: error marshaling default value to string: There was a problem`) -} - func TestUsageLongPositionalWithHelp_legacyForm(t *testing.T) { expectedUsage := "Usage: example [VERYLONGPOSITIONALWITHHELP]" @@ -157,7 +142,7 @@ Options: VeryLongPositionalWithHelp string `arg:"positional,help:this positional argument is very long but cannot include commas"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) require.NoError(t, err) var help bytes.Buffer @@ -186,7 +171,7 @@ Options: VeryLongPositionalWithHelp string `arg:"positional" help:"this positional argument is very long, and includes: commas, colons etc"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) require.NoError(t, err) var help bytes.Buffer @@ -207,10 +192,7 @@ Usage: myprogram Options: --help, -h display this help and exit ` - config := Config{ - Program: "myprogram", - } - p, err := NewParser(config, &struct{}{}) + p, err := NewParser(&struct{}{}, WithProgramName("myprogram")) require.NoError(t, err) os.Args[0] = "example" @@ -242,8 +224,7 @@ Options: --help, -h display this help and exit --version display version and exit ` - os.Args[0] = "example" - p, err := NewParser(Config{}, &versioned{}) + p, err := NewParser(&versioned{}, WithProgramName("example")) require.NoError(t, err) var help bytes.Buffer @@ -272,8 +253,7 @@ Usage: example Options: --help, -h display this help and exit ` - os.Args[0] = "example" - p, err := NewParser(Config{}, &described{}) + p, err := NewParser(&described{}, WithProgramName("example")) require.NoError(t, err) var help bytes.Buffer @@ -304,7 +284,7 @@ Options: For more information visit github.com/alexflint/go-arg ` os.Args[0] = "example" - p, err := NewParser(Config{}, &epilogued{}) + p, err := NewParser(&epilogued{}, WithProgramName("example")) require.NoError(t, err) var help bytes.Buffer @@ -323,7 +303,7 @@ func TestUsageForRequiredPositionals(t *testing.T) { Required2 string `arg:"positional,required"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) require.NoError(t, err) var usage bytes.Buffer @@ -340,7 +320,7 @@ func TestUsageForMixedPositionals(t *testing.T) { Optional2 string `arg:"positional"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) require.NoError(t, err) var usage bytes.Buffer @@ -356,7 +336,7 @@ func TestUsageForRepeatedPositionals(t *testing.T) { Repeated []string `arg:"positional,required"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) require.NoError(t, err) var usage bytes.Buffer @@ -374,7 +354,7 @@ func TestUsageForMixedAndRepeatedPositionals(t *testing.T) { Repeated []string `arg:"positional"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) require.NoError(t, err) var usage bytes.Buffer @@ -398,7 +378,7 @@ Options: RequiredMultiple []string `arg:"positional,required" help:"required multiple positional"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) require.NoError(t, err) var help bytes.Buffer @@ -440,7 +420,7 @@ Global options: } os.Args[0] = "example" - p, err := NewParser(Config{}, &args) + p, err := NewParser(&args) require.NoError(t, err) _ = p.Parse([]string{"child", "nested", "value"}, nil) @@ -458,7 +438,7 @@ func TestNonexistentSubcommand(t *testing.T) { var args struct { sub *struct{} `arg:"subcommand"` } - p, err := NewParser(Config{}, &args) + p, err := NewParser(&args) require.NoError(t, err) var b bytes.Buffer @@ -497,7 +477,7 @@ Options: ShortOnly string `arg:"-a,--" help:"some help" default:"some val" placeholder:"PLACEHOLDER"` ShortOnly2 string `arg:"-b,--,required" help:"some help2"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) assert.NoError(t, err) var help bytes.Buffer @@ -524,7 +504,7 @@ Options: Dog string Cat string `arg:"-c,--"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) assert.NoError(t, err) var help bytes.Buffer @@ -552,7 +532,7 @@ Options: EnvOnlyOverriden string `arg:"--,env:CUSTOM"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) assert.NoError(t, err) var help bytes.Buffer @@ -586,7 +566,7 @@ error: something went wrong var args struct { Foo int } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) require.NoError(t, err) p.Fail("something went wrong") @@ -616,7 +596,7 @@ error: something went wrong var args struct { Sub *struct{} `arg:"subcommand"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(&args, WithProgramName("example")) require.NoError(t, err) err = p.FailSubcommand("something went wrong", "sub") From 5ca19cd72d03c3216896c0dc14033e706bd5dae7 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 11:39:58 -0700 Subject: [PATCH 06/19] cleaned up the test helpers parse, pparse, and parseWithEnv: now all are just using "parse" --- v2/parse_test.go | 270 ++++++++++++++++++++---------------------- v2/subcommand_test.go | 60 +++++----- 2 files changed, 160 insertions(+), 170 deletions(-) diff --git a/v2/parse_test.go b/v2/parse_test.go index 042712c..c65ded8 100644 --- a/v2/parse_test.go +++ b/v2/parse_test.go @@ -14,16 +14,7 @@ import ( "github.com/stretchr/testify/require" ) -func parse(cmdline string, dest interface{}) error { - _, err := pparse(cmdline, dest) - return err -} - -func pparse(cmdline string, dest interface{}) (*Parser, error) { - return parseWithEnv(dest, cmdline) -} - -func parseWithEnv(dest interface{}, cmdline string, env ...string) (*Parser, error) { +func parse(dest interface{}, cmdline string, env ...string) (*Parser, error) { p, err := NewParser(dest) if err != nil { return nil, err @@ -44,7 +35,7 @@ func TestString(t *testing.T) { Foo string Ptr *string } - err := parse("--foo bar --ptr baz", &args) + _, err := parse(&args, "--foo bar --ptr baz") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) assert.Equal(t, "baz", *args.Ptr) @@ -57,7 +48,7 @@ func TestBool(t *testing.T) { C *bool D *bool } - err := parse("--a --c", &args) + _, err := parse(&args, "--a --c") require.NoError(t, err) assert.True(t, args.A) assert.False(t, args.B) @@ -70,7 +61,7 @@ func TestInt(t *testing.T) { Foo int Ptr *int } - err := parse("--foo 7 --ptr 8", &args) + _, err := parse(&args, "--foo 7 --ptr 8") require.NoError(t, err) assert.EqualValues(t, 7, args.Foo) assert.EqualValues(t, 8, *args.Ptr) @@ -83,7 +74,7 @@ func TestHexOctBin(t *testing.T) { Bin int Underscored int } - err := parse("--hex 0xA --oct 0o10 --bin 0b101 --underscored 123_456", &args) + _, err := parse(&args, "--hex 0xA --oct 0o10 --bin 0b101 --underscored 123_456") require.NoError(t, err) assert.EqualValues(t, 10, args.Hex) assert.EqualValues(t, 8, args.Oct) @@ -95,7 +86,7 @@ func TestNegativeInt(t *testing.T) { var args struct { Foo int } - err := parse("-foo=-100", &args) + _, err := parse(&args, "-foo=-100") require.NoError(t, err) assert.EqualValues(t, args.Foo, -100) } @@ -104,7 +95,7 @@ func TestNumericOptionName(t *testing.T) { var args struct { N int `arg:"--100"` } - err := parse("-100 6", &args) + _, err := parse(&args, "-100 6") require.NoError(t, err) assert.EqualValues(t, args.N, 6) } @@ -114,7 +105,7 @@ func TestUint(t *testing.T) { Foo uint Ptr *uint } - err := parse("--foo 7 --ptr 8", &args) + _, err := parse(&args, "--foo 7 --ptr 8") require.NoError(t, err) assert.EqualValues(t, 7, args.Foo) assert.EqualValues(t, 8, *args.Ptr) @@ -125,7 +116,7 @@ func TestFloat(t *testing.T) { Foo float32 Ptr *float32 } - err := parse("--foo 3.4 --ptr 3.5", &args) + _, err := parse(&args, "--foo 3.4 --ptr 3.5") require.NoError(t, err) assert.EqualValues(t, 3.4, args.Foo) assert.EqualValues(t, 3.5, *args.Ptr) @@ -136,7 +127,7 @@ func TestDuration(t *testing.T) { Foo time.Duration Ptr *time.Duration } - err := parse("--foo 3ms --ptr 4ms", &args) + _, err := parse(&args, "--foo 3ms --ptr 4ms") require.NoError(t, err) assert.Equal(t, 3*time.Millisecond, args.Foo) assert.Equal(t, 4*time.Millisecond, *args.Ptr) @@ -146,7 +137,7 @@ func TestInvalidDuration(t *testing.T) { var args struct { Foo time.Duration } - err := parse("--foo xxx", &args) + _, err := parse(&args, "--foo xxx") require.Error(t, err) } @@ -154,7 +145,7 @@ func TestIntPtr(t *testing.T) { var args struct { Foo *int } - err := parse("--foo 123", &args) + _, err := parse(&args, "--foo 123") require.NoError(t, err) require.NotNil(t, args.Foo) assert.Equal(t, 123, *args.Foo) @@ -164,7 +155,7 @@ func TestIntPtrNotPresent(t *testing.T) { var args struct { Foo *int } - err := parse("", &args) + _, err := parse(&args, "") require.NoError(t, err) assert.Nil(t, args.Foo) } @@ -178,7 +169,7 @@ func TestMixed(t *testing.T) { Spam float32 } args.Bar = 3 - err := parse("123 -spam=1.2 -ham -f xyz", &args) + _, err := parse(&args, "123 -spam=1.2 -ham -f xyz") require.NoError(t, err) assert.Equal(t, "xyz", args.Foo) assert.Equal(t, 3, args.Bar) @@ -191,7 +182,7 @@ func TestRequired(t *testing.T) { var args struct { Foo string `arg:"required"` } - err := parse("--foo=abc", &args) + _, err := parse(&args, "--foo=abc") require.NoError(t, err) } @@ -199,7 +190,7 @@ func TestMissingRequired(t *testing.T) { var args struct { Foo string `arg:"required"` } - err := parse("", &args) + _, err := parse(&args, "") require.Error(t, err, "--foo is required") } @@ -207,7 +198,7 @@ func TestMissingRequiredWithEnv(t *testing.T) { var args struct { Foo string `arg:"required,env:FOO"` } - err := parse("", &args) + _, err := parse(&args, "") require.Error(t, err, "--foo is required (or environment variable FOO)") } @@ -216,24 +207,24 @@ func TestShortFlag(t *testing.T) { Foo string `arg:"-f"` } - err := parse("-f xyz", &args) + _, err := parse(&args, "-f a") require.NoError(t, err) - assert.Equal(t, "xyz", args.Foo) + assert.Equal(t, "a", args.Foo) - err = parse("-foo xyz", &args) + _, err = parse(&args, "-foo b") require.NoError(t, err) - assert.Equal(t, "xyz", args.Foo) + assert.Equal(t, "b", args.Foo) - err = parse("--foo xyz", &args) + _, err = parse(&args, "--foo c") require.NoError(t, err) - assert.Equal(t, "xyz", args.Foo) + assert.Equal(t, "c", args.Foo) } func TestInvalidShortFlag(t *testing.T) { var args struct { Foo string `arg:"-foo"` } - err := parse("", &args) + _, err := parse(&args, "") assert.Error(t, err) } @@ -242,11 +233,11 @@ func TestLongFlag(t *testing.T) { Foo string `arg:"--abc"` } - err := parse("-abc xyz", &args) + _, err := parse(&args, "-abc xyz") require.NoError(t, err) assert.Equal(t, "xyz", args.Foo) - err = parse("--abc xyz", &args) + _, err = parse(&args, "--abc xyz") require.NoError(t, err) assert.Equal(t, "xyz", args.Foo) } @@ -255,7 +246,7 @@ func TestSlice(t *testing.T) { var args struct { Strings []string } - err := parse("--strings a b c", &args) + _, err := parse(&args, "--strings a b c") require.NoError(t, err) assert.Equal(t, []string{"a", "b", "c"}, args.Strings) } @@ -264,7 +255,7 @@ func TestSliceOfBools(t *testing.T) { B []bool } - err := parse("--b true false true", &args) + _, err := parse(&args, "--b true false true") require.NoError(t, err) assert.Equal(t, []bool{true, false, true}, args.B) } @@ -273,7 +264,7 @@ func TestMap(t *testing.T) { var args struct { Values map[string]int } - err := parse("--values a=1 b=2 c=3", &args) + _, err := parse(&args, "--values a=1 b=2 c=3") require.NoError(t, err) assert.Len(t, args.Values, 3) assert.Equal(t, 1, args.Values["a"]) @@ -285,7 +276,7 @@ func TestMapPositional(t *testing.T) { var args struct { Values map[string]int `arg:"positional"` } - err := parse("a=1 b=2 c=3", &args) + _, err := parse(&args, "a=1 b=2 c=3") require.NoError(t, err) assert.Len(t, args.Values, 3) assert.Equal(t, 1, args.Values["a"]) @@ -297,7 +288,7 @@ func TestMapWithSeparate(t *testing.T) { var args struct { Values map[string]int `arg:"separate"` } - err := parse("--values a=1 --values b=2 --values c=3", &args) + _, err := parse(&args, "--values a=1 --values b=2 --values c=3") require.NoError(t, err) assert.Len(t, args.Values, 3) assert.Equal(t, 1, args.Values["a"]) @@ -312,7 +303,7 @@ func TestPlaceholder(t *testing.T) { Optimize int `arg:"-O" placeholder:"LEVEL"` MaxJobs int `arg:"-j" placeholder:"N"` } - err := parse("-O 5 --maxjobs 2 src dest1 dest2", &args) + _, err := parse(&args, "-O 5 --maxjobs 2 src dest1 dest2") assert.NoError(t, err) } @@ -321,7 +312,7 @@ func TestNoLongName(t *testing.T) { ShortOnly string `arg:"-s,--"` EnvOnly string `arg:"--,env"` } - _, err := parseWithEnv(&args, "-s TestVal2", "ENVONLY=TestVal") + _, err := parse(&args, "-s TestVal2", "ENVONLY=TestVal") assert.NoError(t, err) assert.Equal(t, "TestVal", args.EnvOnly) assert.Equal(t, "TestVal2", args.ShortOnly) @@ -333,7 +324,7 @@ func TestCaseSensitive(t *testing.T) { Upper bool `arg:"-V"` } - err := parse("-v", &args) + _, err := parse(&args, "-v") require.NoError(t, err) assert.True(t, args.Lower) assert.False(t, args.Upper) @@ -345,7 +336,7 @@ func TestCaseSensitive2(t *testing.T) { Upper bool `arg:"-V"` } - err := parse("-V", &args) + _, err := parse(&args, "-V") require.NoError(t, err) assert.False(t, args.Lower) assert.True(t, args.Upper) @@ -356,7 +347,7 @@ func TestPositional(t *testing.T) { Input string `arg:"positional"` Output string `arg:"positional"` } - err := parse("foo", &args) + _, err := parse(&args, "foo") require.NoError(t, err) assert.Equal(t, "foo", args.Input) assert.Equal(t, "", args.Output) @@ -367,7 +358,7 @@ func TestPositionalPointer(t *testing.T) { Input string `arg:"positional"` Output []*string `arg:"positional"` } - err := parse("foo bar baz", &args) + _, err := parse(&args, "foo bar baz") require.NoError(t, err) assert.Equal(t, "foo", args.Input) bar := "bar" @@ -380,7 +371,7 @@ func TestRequiredPositional(t *testing.T) { Input string `arg:"positional"` Output string `arg:"positional,required"` } - err := parse("foo", &args) + _, err := parse(&args, "foo") assert.Error(t, err) } @@ -389,7 +380,7 @@ func TestRequiredPositionalMultiple(t *testing.T) { Input string `arg:"positional"` Multiple []string `arg:"positional,required"` } - err := parse("foo", &args) + _, err := parse(&args, "foo") assert.Error(t, err) } @@ -398,7 +389,7 @@ func TestTooManyPositional(t *testing.T) { Input string `arg:"positional"` Output string `arg:"positional"` } - err := parse("foo bar baz", &args) + _, err := parse(&args, "foo bar baz") assert.Error(t, err) } @@ -407,7 +398,7 @@ func TestMultiple(t *testing.T) { Foo []int Bar []string } - err := parse("--foo 1 2 3 --bar x y z", &args) + _, err := parse(&args, "--foo 1 2 3 --bar x y z") require.NoError(t, err) assert.Equal(t, []int{1, 2, 3}, args.Foo) assert.Equal(t, []string{"x", "y", "z"}, args.Bar) @@ -418,7 +409,7 @@ func TestMultiplePositionals(t *testing.T) { Input string `arg:"positional"` Multiple []string `arg:"positional,required"` } - err := parse("foo a b c", &args) + _, err := parse(&args, "foo a b c") assert.NoError(t, err) assert.Equal(t, "foo", args.Input) assert.Equal(t, []string{"a", "b", "c"}, args.Multiple) @@ -429,7 +420,7 @@ func TestMultipleWithEq(t *testing.T) { Foo []int Bar []string } - err := parse("--foo 1 2 3 --bar=x", &args) + _, err := parse(&args, "--foo 1 2 3 --bar=x") require.NoError(t, err) assert.Equal(t, []int{1, 2, 3}, args.Foo) assert.Equal(t, []string{"x"}, args.Bar) @@ -442,7 +433,7 @@ func TestMultipleWithDefault(t *testing.T) { } args.Foo = []int{42} args.Bar = []string{"foo"} - err := parse("--foo 1 2 3 --bar x y z", &args) + _, err := parse(&args, "--foo 1 2 3 --bar x y z") require.NoError(t, err) assert.Equal(t, []int{1, 2, 3}, args.Foo) assert.Equal(t, []string{"x", "y", "z"}, args.Bar) @@ -453,7 +444,7 @@ func TestExemptField(t *testing.T) { Foo string Bar interface{} `arg:"-"` } - err := parse("--foo xyz", &args) + _, err := parse(&args, "--foo xyz") require.NoError(t, err) assert.Equal(t, "xyz", args.Foo) } @@ -462,7 +453,7 @@ func TestUnknownField(t *testing.T) { var args struct { Foo string } - err := parse("--bar xyz", &args) + _, err := parse(&args, "--bar xyz") assert.Error(t, err) } @@ -470,7 +461,7 @@ func TestNonsenseKey(t *testing.T) { var args struct { X []string `arg:"positional, nonsense"` } - err := parse("x", &args) + _, err := parse(&args, "x") assert.Error(t, err) } @@ -478,7 +469,7 @@ func TestMissingValueAtEnd(t *testing.T) { var args struct { Foo string } - err := parse("--foo", &args) + _, err := parse(&args, "--foo") assert.Error(t, err) } @@ -487,7 +478,7 @@ func TestMissingValueInMiddle(t *testing.T) { Foo string Bar string } - err := parse("--foo --bar=abc", &args) + _, err := parse(&args, "--foo --bar=abc") assert.Error(t, err) } @@ -495,7 +486,7 @@ func TestNegativeValue(t *testing.T) { var args struct { Foo int } - err := parse("--foo=-123", &args) + _, err := parse(&args, "--foo=-123") require.NoError(t, err) assert.Equal(t, -123, args.Foo) } @@ -504,7 +495,7 @@ func TestInvalidInt(t *testing.T) { var args struct { Foo int } - err := parse("--foo=xyz", &args) + _, err := parse(&args, "--foo=xyz") assert.Error(t, err) } @@ -512,7 +503,7 @@ func TestInvalidUint(t *testing.T) { var args struct { Foo uint } - err := parse("--foo=xyz", &args) + _, err := parse(&args, "--foo=xyz") assert.Error(t, err) } @@ -520,7 +511,7 @@ func TestInvalidFloat(t *testing.T) { var args struct { Foo float64 } - err := parse("--foo xyz", &args) + _, err := parse(&args, "--foo xyz") require.Error(t, err) } @@ -528,7 +519,7 @@ func TestInvalidBool(t *testing.T) { var args struct { Foo bool } - err := parse("--foo=xyz", &args) + _, err := parse(&args, "--foo=xyz") require.Error(t, err) } @@ -536,7 +527,7 @@ func TestInvalidIntSlice(t *testing.T) { var args struct { Foo []int } - err := parse("--foo 1 2 xyz", &args) + _, err := parse(&args, "--foo 1 2 xyz") require.Error(t, err) } @@ -544,7 +535,7 @@ func TestInvalidPositional(t *testing.T) { var args struct { Foo int `arg:"positional"` } - err := parse("xyz", &args) + _, err := parse(&args, "xyz") require.Error(t, err) } @@ -552,7 +543,7 @@ func TestInvalidPositionalSlice(t *testing.T) { var args struct { Foo []int `arg:"positional"` } - err := parse("1 2 xyz", &args) + _, err := parse(&args, "1 2 xyz") require.Error(t, err) } @@ -561,7 +552,7 @@ func TestNoMoreOptions(t *testing.T) { Foo string Bar []string `arg:"positional"` } - err := parse("abc -- --foo xyz", &args) + _, err := parse(&args, "abc -- --foo xyz") require.NoError(t, err) assert.Equal(t, "", args.Foo) assert.Equal(t, []string{"abc", "--foo", "xyz"}, args.Bar) @@ -571,7 +562,7 @@ func TestNoMoreOptionsBeforeHelp(t *testing.T) { var args struct { Foo int } - err := parse("not_an_integer -- --help", &args) + _, err := parse(&args, "not_an_integer -- --help") assert.NotEqual(t, ErrHelp, err) } @@ -580,20 +571,20 @@ func TestHelpFlag(t *testing.T) { Foo string Bar interface{} `arg:"-"` } - err := parse("--help", &args) + _, err := parse(&args, "--help") assert.Equal(t, ErrHelp, err) } func TestPanicOnNonPointer(t *testing.T) { var args struct{} assert.Panics(t, func() { - _ = parse("", args) + _, _ = parse(args, "") }) } func TestErrorOnNonStruct(t *testing.T) { var args string - err := parse("", &args) + _, err := parse(&args, "") assert.Error(t, err) } @@ -601,7 +592,7 @@ func TestUnsupportedType(t *testing.T) { var args struct { Foo interface{} } - err := parse("--foo", &args) + _, err := parse(&args, "--foo") assert.Error(t, err) } @@ -609,7 +600,7 @@ func TestUnsupportedSliceElement(t *testing.T) { var args struct { Foo []interface{} } - err := parse("--foo 3", &args) + _, err := parse(&args, "--foo 3") assert.Error(t, err) } @@ -617,7 +608,7 @@ func TestUnsupportedSliceElementMissingValue(t *testing.T) { var args struct { Foo []interface{} } - err := parse("--foo", &args) + _, err := parse(&args, "--foo") assert.Error(t, err) } @@ -625,7 +616,7 @@ func TestUnknownTag(t *testing.T) { var args struct { Foo string `arg:"this_is_not_valid"` } - err := parse("--foo xyz", &args) + _, err := parse(&args, "--foo xyz") assert.Error(t, err) } @@ -633,8 +624,7 @@ func TestParse(t *testing.T) { var args struct { Foo string } - os.Args = []string{"example", "--foo", "bar"} - err := Parse(&args) + _, err := parse(&args, "--foo bar") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -643,8 +633,7 @@ func TestParseError(t *testing.T) { var args struct { Foo string `arg:"this_is_not_valid"` } - os.Args = []string{"example", "--bar"} - err := Parse(&args) + _, err := NewParser(&args) assert.Error(t, err) } @@ -662,7 +651,7 @@ func TestEnvironmentVariable(t *testing.T) { var args struct { Foo string `arg:"env"` } - _, err := parseWithEnv(&args, "", "FOO=bar") + _, err := parse(&args, "", "FOO=bar") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -671,7 +660,7 @@ func TestEnvironmentVariableNotPresent(t *testing.T) { var args struct { NotPresent string `arg:"env"` } - _, err := parseWithEnv(&args, "", "") + _, err := parse(&args, "", "") require.NoError(t, err) assert.Equal(t, "", args.NotPresent) } @@ -680,7 +669,7 @@ func TestEnvironmentVariableOverrideName(t *testing.T) { var args struct { Foo string `arg:"env:BAZ"` } - _, err := parseWithEnv(&args, "", "BAZ=bar") + _, err := parse(&args, "", "BAZ=bar") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -689,7 +678,7 @@ func TestCommandLineSupercedesEnv(t *testing.T) { var args struct { Foo string `arg:"env"` } - _, err := parseWithEnv(&args, "--foo zzz", "FOO=bar") + _, err := parse(&args, "--foo zzz", "FOO=bar") require.NoError(t, err) assert.Equal(t, "zzz", args.Foo) } @@ -698,7 +687,7 @@ func TestEnvironmentVariableError(t *testing.T) { var args struct { Foo int `arg:"env"` } - _, err := parseWithEnv(&args, "", "FOO=bar") + _, err := parse(&args, "", "FOO=bar") assert.Error(t, err) } @@ -706,7 +695,7 @@ func TestEnvironmentVariableRequired(t *testing.T) { var args struct { Foo string `arg:"env,required"` } - _, err := parseWithEnv(&args, "", "FOO=bar") + _, err := parse(&args, "", "FOO=bar") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -715,7 +704,7 @@ func TestEnvironmentVariableSliceArgumentString(t *testing.T) { var args struct { Foo []string `arg:"env"` } - _, err := parseWithEnv(&args, "", `FOO=bar,"baz, qux"`) + _, err := parse(&args, "", `FOO=bar,"baz, qux"`) require.NoError(t, err) assert.Equal(t, []string{"bar", "baz, qux"}, args.Foo) } @@ -724,7 +713,7 @@ func TestEnvironmentVariableSliceEmpty(t *testing.T) { var args struct { Foo []string `arg:"env"` } - _, err := parseWithEnv(&args, "", `FOO=`) + _, err := parse(&args, "", `FOO=`) require.NoError(t, err) assert.Len(t, args.Foo, 0) } @@ -733,7 +722,7 @@ func TestEnvironmentVariableSliceArgumentInteger(t *testing.T) { var args struct { Foo []int `arg:"env"` } - _, err := parseWithEnv(&args, "", `FOO=1,99`) + _, err := parse(&args, "", `FOO=1,99`) require.NoError(t, err) assert.Equal(t, []int{1, 99}, args.Foo) } @@ -742,7 +731,7 @@ func TestEnvironmentVariableSliceArgumentFloat(t *testing.T) { var args struct { Foo []float32 `arg:"env"` } - _, err := parseWithEnv(&args, "", `FOO=1.1,99.9`) + _, err := parse(&args, "", `FOO=1.1,99.9`) require.NoError(t, err) assert.Equal(t, []float32{1.1, 99.9}, args.Foo) } @@ -751,7 +740,7 @@ func TestEnvironmentVariableSliceArgumentBool(t *testing.T) { var args struct { Foo []bool `arg:"env"` } - _, err := parseWithEnv(&args, "", `FOO=true,false,0,1`) + _, err := parse(&args, "", `FOO=true,false,0,1`) require.NoError(t, err) assert.Equal(t, []bool{true, false, false, true}, args.Foo) } @@ -760,7 +749,7 @@ func TestEnvironmentVariableSliceArgumentWrongCsv(t *testing.T) { var args struct { Foo []int `arg:"env"` } - _, err := parseWithEnv(&args, "", `FOO=1,99\"`) + _, err := parse(&args, "", `FOO=1,99\"`) assert.Error(t, err) } @@ -768,7 +757,7 @@ func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) { var args struct { Foo []bool `arg:"env"` } - _, err := parseWithEnv(&args, "", `FOO=one,two`) + _, err := parse(&args, "", `FOO=one,two`) assert.Error(t, err) } @@ -776,7 +765,7 @@ func TestEnvironmentVariableMap(t *testing.T) { var args struct { Foo map[int]string `arg:"env"` } - _, err := parseWithEnv(&args, "", `FOO=1=one,99=ninetynine`) + _, err := parse(&args, "", `FOO=1=one,99=ninetynine`) require.NoError(t, err) assert.Len(t, args.Foo, 2) assert.Equal(t, "one", args.Foo[1]) @@ -787,7 +776,7 @@ func TestEnvironmentVariableEmptyMap(t *testing.T) { var args struct { Foo map[int]string `arg:"env"` } - _, err := parseWithEnv(&args, "", `FOO=`) + _, err := parse(&args, "", `FOO=`) require.NoError(t, err) assert.Len(t, args.Foo, 0) } @@ -800,7 +789,7 @@ func TestEnvironmentVariableIgnored(t *testing.T) { // the library should never read env vars direct from os os.Setenv("FOO", "123") - _, err := parseWithEnv(&args, "") + _, err := parse(&args, "") require.NoError(t, err) assert.Equal(t, "", args.Foo) } @@ -832,7 +821,7 @@ func TestEnvironmentVariableInSubcommand(t *testing.T) { } `arg:"subcommand"` } - _, err := parseWithEnv(&args, "sub", "FOO=abc") + _, err := parse(&args, "sub", "FOO=abc") require.NoError(t, err) require.NotNil(t, args.Sub) assert.Equal(t, "abc", args.Sub.Foo) @@ -845,7 +834,7 @@ func TestEnvironmentVariableInSubcommandEmpty(t *testing.T) { } `arg:"subcommand"` } - _, err := parseWithEnv(&args, "sub") + _, err := parse(&args, "sub") require.NoError(t, err) require.NotNil(t, args.Sub) assert.Equal(t, "", args.Sub.Foo) @@ -865,7 +854,7 @@ func TestTextUnmarshaler(t *testing.T) { var args struct { Foo textUnmarshaler } - err := parse("--foo abc", &args) + _, err := parse(&args, "--foo abc") require.NoError(t, err) assert.Equal(t, 3, args.Foo.val) } @@ -875,7 +864,7 @@ func TestPtrToTextUnmarshaler(t *testing.T) { var args struct { Foo *textUnmarshaler } - err := parse("--foo abc", &args) + _, err := parse(&args, "--foo abc") require.NoError(t, err) assert.Equal(t, 3, args.Foo.val) } @@ -885,7 +874,7 @@ func TestRepeatedTextUnmarshaler(t *testing.T) { var args struct { Foo []textUnmarshaler } - err := parse("--foo abc d ef", &args) + _, err := parse(&args, "--foo abc d ef") require.NoError(t, err) require.Len(t, args.Foo, 3) assert.Equal(t, 3, args.Foo[0].val) @@ -898,7 +887,7 @@ func TestRepeatedPtrToTextUnmarshaler(t *testing.T) { var args struct { Foo []*textUnmarshaler } - err := parse("--foo abc d ef", &args) + _, err := parse(&args, "--foo abc d ef") require.NoError(t, err) require.Len(t, args.Foo, 3) assert.Equal(t, 3, args.Foo[0].val) @@ -911,7 +900,7 @@ func TestPositionalTextUnmarshaler(t *testing.T) { var args struct { Foo []textUnmarshaler `arg:"positional"` } - err := parse("abc d ef", &args) + _, err := parse(&args, "abc d ef") require.NoError(t, err) require.Len(t, args.Foo, 3) assert.Equal(t, 3, args.Foo[0].val) @@ -924,7 +913,7 @@ func TestPositionalPtrToTextUnmarshaler(t *testing.T) { var args struct { Foo []*textUnmarshaler `arg:"positional"` } - err := parse("abc d ef", &args) + _, err := parse(&args, "abc d ef") require.NoError(t, err) require.Len(t, args.Foo, 3) assert.Equal(t, 3, args.Foo[0].val) @@ -945,7 +934,7 @@ func TestBoolUnmarhsaler(t *testing.T) { var args struct { Foo *boolUnmarshaler } - err := parse("--foo ab", &args) + _, err := parse(&args, "--foo ab") require.NoError(t, err) assert.EqualValues(t, true, *args.Foo) } @@ -964,7 +953,7 @@ func TestSliceUnmarhsaler(t *testing.T) { Foo *sliceUnmarshaler Bar string `arg:"positional"` } - err := parse("--foo abcde xyz", &args) + _, err := parse(&args, "--foo abcde xyz") require.NoError(t, err) require.Len(t, *args.Foo, 1) assert.EqualValues(t, 5, (*args.Foo)[0]) @@ -975,7 +964,7 @@ func TestIP(t *testing.T) { var args struct { Host net.IP } - err := parse("--host 192.168.0.1", &args) + _, err := parse(&args, "--host 192.168.0.1") require.NoError(t, err) assert.Equal(t, "192.168.0.1", args.Host.String()) } @@ -984,7 +973,7 @@ func TestPtrToIP(t *testing.T) { var args struct { Host *net.IP } - err := parse("--host 192.168.0.1", &args) + _, err := parse(&args, "--host 192.168.0.1") require.NoError(t, err) assert.Equal(t, "192.168.0.1", args.Host.String()) } @@ -993,7 +982,7 @@ func TestURL(t *testing.T) { var args struct { URL url.URL } - err := parse("--url https://example.com/get?item=xyz", &args) + _, err := parse(&args, "--url https://example.com/get?item=xyz") require.NoError(t, err) assert.Equal(t, "https://example.com/get?item=xyz", args.URL.String()) } @@ -1002,7 +991,7 @@ func TestPtrToURL(t *testing.T) { var args struct { URL *url.URL } - err := parse("--url http://example.com/#xyz", &args) + _, err := parse(&args, "--url http://example.com/#xyz") require.NoError(t, err) assert.Equal(t, "http://example.com/#xyz", args.URL.String()) } @@ -1011,7 +1000,7 @@ func TestIPSlice(t *testing.T) { var args struct { Host []net.IP } - err := parse("--host 192.168.0.1 127.0.0.1", &args) + _, err := parse(&args, "--host 192.168.0.1 127.0.0.1") require.NoError(t, err) require.Len(t, args.Host, 2) assert.Equal(t, "192.168.0.1", args.Host[0].String()) @@ -1022,7 +1011,7 @@ func TestInvalidIPAddress(t *testing.T) { var args struct { Host net.IP } - err := parse("--host xxx", &args) + _, err := parse(&args, "--host xxx") assert.Error(t, err) } @@ -1030,7 +1019,7 @@ func TestMAC(t *testing.T) { var args struct { Host net.HardwareAddr } - err := parse("--host 0123.4567.89ab", &args) + _, err := parse(&args, "--host 0123.4567.89ab") require.NoError(t, err) assert.Equal(t, "01:23:45:67:89:ab", args.Host.String()) } @@ -1039,7 +1028,7 @@ func TestInvalidMac(t *testing.T) { var args struct { Host net.HardwareAddr } - err := parse("--host xxx", &args) + _, err := parse(&args, "--host xxx") assert.Error(t, err) } @@ -1047,7 +1036,7 @@ func TestMailAddr(t *testing.T) { var args struct { Recipient mail.Address } - err := parse("--recipient foo@example.com", &args) + _, err := parse(&args, "--recipient foo@example.com") require.NoError(t, err) assert.Equal(t, "", args.Recipient.String()) } @@ -1056,7 +1045,7 @@ func TestInvalidMailAddr(t *testing.T) { var args struct { Recipient mail.Address } - err := parse("--recipient xxx", &args) + _, err := parse(&args, "--recipient xxx") assert.Error(t, err) } @@ -1074,7 +1063,7 @@ func TestEmbedded(t *testing.T) { B Z bool } - err := parse("--x=hello --y=321 --z", &args) + _, err := parse(&args, "--x=hello --y=321 --z") require.NoError(t, err) assert.Equal(t, "hello", args.X) assert.Equal(t, 321, args.Y) @@ -1086,7 +1075,7 @@ func TestEmbeddedPtr(t *testing.T) { var args struct { *A } - err := parse("--x=hello", &args) + _, err := parse(&args, "--x=hello") require.Error(t, err) } @@ -1098,7 +1087,7 @@ func TestEmbeddedPtrIgnored(t *testing.T) { *A `arg:"-"` B } - err := parse("--y=321", &args) + _, err := parse(&args, "--y=321") require.NoError(t, err) assert.Equal(t, 321, args.Y) } @@ -1116,7 +1105,7 @@ func TestEmbeddedWithDuplicateField(t *testing.T) { U } - err := parse("--cat=cat --dog=dog", &args) + _, err := parse(&args, "--cat=cat --dog=dog") require.NoError(t, err) assert.Equal(t, "cat", args.T.A) assert.Equal(t, "dog", args.U.A) @@ -1135,7 +1124,7 @@ func TestEmbeddedWithDuplicateField2(t *testing.T) { U } - err := parse("--a=xyz", &args) + _, err := parse(&args, "--a=xyz") require.NoError(t, err) assert.Equal(t, "xyz", args.T.A) assert.Equal(t, "", args.U.A) @@ -1148,7 +1137,7 @@ func TestUnexportedEmbedded(t *testing.T) { var args struct { embeddedArgs } - err := parse("--foo bar", &args) + _, err := parse(&args, "--foo bar") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -1160,7 +1149,7 @@ func TestIgnoredEmbedded(t *testing.T) { var args struct { embeddedArgs `arg:"-"` } - err := parse("--foo bar", &args) + _, err := parse(&args, "--foo bar") require.Error(t, err) } @@ -1172,7 +1161,8 @@ func TestEmptyArgs(t *testing.T) { var args struct { Foo string } - MustParse(&args) + err := Parse(&args) + require.NoError(t, err) // put the original arguments back os.Args = origArgs @@ -1182,7 +1172,7 @@ func TestTooManyHyphens(t *testing.T) { var args struct { TooManyHyphens string `arg:"---x"` } - err := parse("--foo -", &args) + _, err := parse(&args, "--foo -") assert.Error(t, err) } @@ -1190,7 +1180,7 @@ func TestHyphenAsOption(t *testing.T) { var args struct { Foo string } - err := parse("--foo -", &args) + _, err := parse(&args, "--foo -") require.NoError(t, err) assert.Equal(t, "-", args.Foo) } @@ -1199,7 +1189,7 @@ func TestHyphenAsPositional(t *testing.T) { var args struct { Foo string `arg:"positional"` } - err := parse("-", &args) + _, err := parse(&args, "-") require.NoError(t, err) assert.Equal(t, "-", args.Foo) } @@ -1209,7 +1199,7 @@ func TestHyphenInMultiOption(t *testing.T) { Foo []string Bar int } - err := parse("--foo --- x - y --bar 3", &args) + _, err := parse(&args, "--foo --- x - y --bar 3") require.NoError(t, err) assert.Equal(t, []string{"---", "x", "-", "y"}, args.Foo) assert.Equal(t, 3, args.Bar) @@ -1219,7 +1209,7 @@ func TestHyphenInMultiPositional(t *testing.T) { var args struct { Foo []string `arg:"positional"` } - err := parse("--- x - y", &args) + _, err := parse(&args, "--- x - y") require.NoError(t, err) assert.Equal(t, []string{"---", "x", "-", "y"}, args.Foo) } @@ -1230,7 +1220,7 @@ func TestSeparate(t *testing.T) { Foo []string `arg:"--foo,-f,separate"` } - err := parse(val, &args) + _, err := parse(&args, val) require.NoError(t, err) assert.Equal(t, []string{"one"}, args.Foo) } @@ -1243,7 +1233,7 @@ func TestSeparateWithDefault(t *testing.T) { Foo: []string{"default"}, } - err := parse("-f one -f=two", &args) + _, err := parse(&args, "-f one -f=two") require.NoError(t, err) assert.Equal(t, []string{"default", "one", "two"}, args.Foo) } @@ -1255,7 +1245,7 @@ func TestSeparateWithPositional(t *testing.T) { Moo string `arg:"positional"` } - err := parse("zzz --foo one -f=two --foo=three -f four aaa", &args) + _, err := parse(&args, "zzz --foo one -f=two --foo=three -f four aaa") require.NoError(t, err) assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo) assert.Equal(t, "zzz", args.Bar) @@ -1270,7 +1260,7 @@ func TestSeparatePositionalInterweaved(t *testing.T) { Post []string `arg:"positional"` } - err := parse("zzz -f foo1 -b=bar1 --foo=foo2 -b bar2 post1 -b bar3 post2 post3", &args) + _, err := parse(&args, "zzz -f foo1 -b=bar1 --foo=foo2 -b bar2 post1 -b bar3 post2 post3") require.NoError(t, err) assert.Equal(t, []string{"foo1", "foo2"}, args.Foo) assert.Equal(t, []string{"bar1", "bar2", "bar3"}, args.Bar) @@ -1283,7 +1273,7 @@ func TestSpacesAllowedInTags(t *testing.T) { Foo []string `arg:"--foo, -f, separate, required, help:quite nice really"` } - err := parse("--foo one -f=two --foo=three -f four", &args) + _, err := parse(&args, "--foo one -f=two --foo=three -f four") require.NoError(t, err) assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo) } @@ -1306,7 +1296,7 @@ func TestReuseParser(t *testing.T) { func TestVersion(t *testing.T) { var args struct{} - err := parse("--version", &args) + _, err := parse(&args, "--version") assert.Equal(t, ErrVersion, err) } @@ -1317,7 +1307,7 @@ func TestMultipleTerminates(t *testing.T) { Y string `arg:"positional"` } - err := parse("--x a b -- c", &args) + _, err := parse(&args, "--x a b -- c") require.NoError(t, err) assert.Equal(t, []string{"a", "b"}, args.X) assert.Equal(t, "c", args.Y) @@ -1335,7 +1325,7 @@ func TestDefaultOptionValues(t *testing.T) { H *bool `default:"true"` } - err := parse("--c=xyz --e=4.56", &args) + _, err := parse(&args, "--c=xyz --e=4.56") require.NoError(t, err) assert.Equal(t, 123, args.A) @@ -1353,7 +1343,7 @@ func TestDefaultUnparseable(t *testing.T) { A int `default:"x"` } - err := parse("", &args) + _, err := parse(&args, "") assert.EqualError(t, err, `error processing default value for --a: strconv.ParseInt: parsing "x": invalid syntax`) } @@ -1369,7 +1359,7 @@ func TestDefaultPositionalValues(t *testing.T) { H *bool `arg:"positional" default:"true"` } - err := parse("456 789", &args) + _, err := parse(&args, "456 789") require.NoError(t, err) assert.Equal(t, 456, args.A) @@ -1387,7 +1377,7 @@ func TestDefaultValuesNotAllowedWithRequired(t *testing.T) { A int `arg:"required" default:"123"` // required not allowed with default! } - err := parse("", &args) + _, err := parse(&args, "") assert.EqualError(t, err, ".A: 'required' cannot be used when a default value is specified") } @@ -1396,7 +1386,7 @@ func TestDefaultValuesNotAllowedWithSlice(t *testing.T) { A []int `default:"123"` // required not allowed with default! } - err := parse("", &args) + _, err := parse(&args, "") assert.EqualError(t, err, ".A: default values are not supported for slice or map fields") } diff --git a/v2/subcommand_test.go b/v2/subcommand_test.go index 31dc2dd..b9f6f2f 100644 --- a/v2/subcommand_test.go +++ b/v2/subcommand_test.go @@ -42,7 +42,7 @@ func TestMinimalSubcommand(t *testing.T) { var args struct { List *listCmd `arg:"subcommand"` } - p, err := pparse("list", &args) + p, err := parse(&args, "list") require.NoError(t, err) assert.NotNil(t, args.List) assert.Equal(t, args.List, p.Subcommand()) @@ -66,7 +66,7 @@ func TestNoSuchSubcommand(t *testing.T) { var args struct { List *listCmd `arg:"subcommand"` } - _, err := pparse("invalid", &args) + _, err := parse(&args, "invalid") assert.Error(t, err) } @@ -76,7 +76,7 @@ func TestNamedSubcommand(t *testing.T) { var args struct { List *listCmd `arg:"subcommand:ls"` } - p, err := pparse("ls", &args) + p, err := parse(&args, "ls") require.NoError(t, err) assert.NotNil(t, args.List) assert.Equal(t, args.List, p.Subcommand()) @@ -89,7 +89,7 @@ func TestEmptySubcommand(t *testing.T) { var args struct { List *listCmd `arg:"subcommand"` } - p, err := pparse("", &args) + p, err := parse(&args, "") require.NoError(t, err) assert.Nil(t, args.List) assert.Nil(t, p.Subcommand()) @@ -105,7 +105,7 @@ func TestTwoSubcommands(t *testing.T) { Get *getCmd `arg:"subcommand"` List *listCmd `arg:"subcommand"` } - p, err := pparse("list", &args) + p, err := parse(&args, "list") require.NoError(t, err) assert.Nil(t, args.Get) assert.NotNil(t, args.List) @@ -128,7 +128,7 @@ func TestSubcommandsWithOptions(t *testing.T) { { var args cmd - err := parse("list", &args) + _, err := parse(&args, "list") require.NoError(t, err) assert.Nil(t, args.Get) assert.NotNil(t, args.List) @@ -136,7 +136,7 @@ func TestSubcommandsWithOptions(t *testing.T) { { var args cmd - err := parse("list --limit 3", &args) + _, err := parse(&args, "list --limit 3") require.NoError(t, err) assert.Nil(t, args.Get) assert.NotNil(t, args.List) @@ -145,7 +145,7 @@ func TestSubcommandsWithOptions(t *testing.T) { { var args cmd - err := parse("list --limit 3 --verbose", &args) + _, err := parse(&args, "list --limit 3 --verbose") require.NoError(t, err) assert.Nil(t, args.Get) assert.NotNil(t, args.List) @@ -155,7 +155,7 @@ func TestSubcommandsWithOptions(t *testing.T) { { var args cmd - err := parse("list --verbose --limit 3", &args) + _, err := parse(&args, "list --verbose --limit 3") require.NoError(t, err) assert.Nil(t, args.Get) assert.NotNil(t, args.List) @@ -165,7 +165,7 @@ func TestSubcommandsWithOptions(t *testing.T) { { var args cmd - err := parse("--verbose list --limit 3", &args) + _, err := parse(&args, "--verbose list --limit 3") require.NoError(t, err) assert.Nil(t, args.Get) assert.NotNil(t, args.List) @@ -175,7 +175,7 @@ func TestSubcommandsWithOptions(t *testing.T) { { var args cmd - err := parse("get", &args) + _, err := parse(&args, "get") require.NoError(t, err) assert.NotNil(t, args.Get) assert.Nil(t, args.List) @@ -183,7 +183,7 @@ func TestSubcommandsWithOptions(t *testing.T) { { var args cmd - err := parse("get --name test", &args) + _, err := parse(&args, "get --name test") require.NoError(t, err) assert.NotNil(t, args.Get) assert.Nil(t, args.List) @@ -206,7 +206,7 @@ func TestSubcommandsWithEnvVars(t *testing.T) { { var args cmd - _, err := parseWithEnv(&args, "list", "LIMIT=123") + _, err := parse(&args, "list", "LIMIT=123") require.NoError(t, err) require.NotNil(t, args.List) assert.Equal(t, 123, args.List.Limit) @@ -214,7 +214,7 @@ func TestSubcommandsWithEnvVars(t *testing.T) { { var args cmd - _, err := parseWithEnv(&args, "list", "LIMIT=not_an_integer") + _, err := parse(&args, "list", "LIMIT=not_an_integer") assert.Error(t, err) } } @@ -233,7 +233,7 @@ func TestNestedSubcommands(t *testing.T) { { var args root - p, err := pparse("grandparent parent child", &args) + p, err := parse(&args, "grandparent parent child") require.NoError(t, err) require.NotNil(t, args.Grandparent) require.NotNil(t, args.Grandparent.Parent) @@ -244,7 +244,7 @@ func TestNestedSubcommands(t *testing.T) { { var args root - p, err := pparse("grandparent parent", &args) + p, err := parse(&args, "grandparent parent") require.NoError(t, err) require.NotNil(t, args.Grandparent) require.NotNil(t, args.Grandparent.Parent) @@ -255,7 +255,7 @@ func TestNestedSubcommands(t *testing.T) { { var args root - p, err := pparse("grandparent", &args) + p, err := parse(&args, "grandparent") require.NoError(t, err) require.NotNil(t, args.Grandparent) require.Nil(t, args.Grandparent.Parent) @@ -265,7 +265,7 @@ func TestNestedSubcommands(t *testing.T) { { var args root - p, err := pparse("", &args) + p, err := parse(&args, "") require.NoError(t, err) require.Nil(t, args.Grandparent) assert.Nil(t, p.Subcommand()) @@ -284,7 +284,7 @@ func TestSubcommandsWithPositionals(t *testing.T) { { var args cmd - err := parse("list", &args) + _, err := parse(&args, "list") require.NoError(t, err) assert.NotNil(t, args.List) assert.Equal(t, "", args.List.Pattern) @@ -292,7 +292,7 @@ func TestSubcommandsWithPositionals(t *testing.T) { { var args cmd - err := parse("list --format json", &args) + _, err := parse(&args, "list --format json") require.NoError(t, err) assert.NotNil(t, args.List) assert.Equal(t, "", args.List.Pattern) @@ -301,7 +301,7 @@ func TestSubcommandsWithPositionals(t *testing.T) { { var args cmd - err := parse("list somepattern", &args) + _, err := parse(&args, "list somepattern") require.NoError(t, err) assert.NotNil(t, args.List) assert.Equal(t, "somepattern", args.List.Pattern) @@ -309,7 +309,7 @@ func TestSubcommandsWithPositionals(t *testing.T) { { var args cmd - err := parse("list somepattern --format json", &args) + _, err := parse(&args, "list somepattern --format json") require.NoError(t, err) assert.NotNil(t, args.List) assert.Equal(t, "somepattern", args.List.Pattern) @@ -318,7 +318,7 @@ func TestSubcommandsWithPositionals(t *testing.T) { { var args cmd - err := parse("list --format json somepattern", &args) + _, err := parse(&args, "list --format json somepattern") require.NoError(t, err) assert.NotNil(t, args.List) assert.Equal(t, "somepattern", args.List.Pattern) @@ -327,7 +327,7 @@ func TestSubcommandsWithPositionals(t *testing.T) { { var args cmd - err := parse("--format json list somepattern", &args) + _, err := parse(&args, "--format json list somepattern") require.NoError(t, err) assert.NotNil(t, args.List) assert.Equal(t, "somepattern", args.List.Pattern) @@ -336,7 +336,7 @@ func TestSubcommandsWithPositionals(t *testing.T) { { var args cmd - err := parse("--format json", &args) + _, err := parse(&args, "--format json") require.NoError(t, err) assert.Nil(t, args.List) assert.Equal(t, "json", args.Format) @@ -353,7 +353,7 @@ func TestSubcommandsWithMultiplePositionals(t *testing.T) { { var args cmd - err := parse("get", &args) + _, err := parse(&args, "get") require.NoError(t, err) assert.NotNil(t, args.Get) assert.Empty(t, args.Get.Items) @@ -361,7 +361,7 @@ func TestSubcommandsWithMultiplePositionals(t *testing.T) { { var args cmd - err := parse("get --limit 5", &args) + _, err := parse(&args, "get --limit 5") require.NoError(t, err) assert.NotNil(t, args.Get) assert.Empty(t, args.Get.Items) @@ -370,7 +370,7 @@ func TestSubcommandsWithMultiplePositionals(t *testing.T) { { var args cmd - err := parse("get item1", &args) + _, err := parse(&args, "get item1") require.NoError(t, err) assert.NotNil(t, args.Get) assert.Equal(t, []string{"item1"}, args.Get.Items) @@ -378,7 +378,7 @@ func TestSubcommandsWithMultiplePositionals(t *testing.T) { { var args cmd - err := parse("get item1 item2 item3", &args) + _, err := parse(&args, "get item1 item2 item3") require.NoError(t, err) assert.NotNil(t, args.Get) assert.Equal(t, []string{"item1", "item2", "item3"}, args.Get.Items) @@ -386,7 +386,7 @@ func TestSubcommandsWithMultiplePositionals(t *testing.T) { { var args cmd - err := parse("get item1 --limit 5 item2", &args) + _, err := parse(&args, "get item1 --limit 5 item2") require.NoError(t, err) assert.NotNil(t, args.Get) assert.Equal(t, []string{"item1", "item2"}, args.Get.Items) From 5f0c48f092161d61d13a2ebeda134c51c2d5b2d9 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 11:54:22 -0700 Subject: [PATCH 07/19] move construction logic out of parse.go into construct.go --- v2/construct.go | 317 +++++++++++++++++++++++++++++++++ v2/construct_test.go | 25 +++ v2/parse.go | 412 ++++++------------------------------------- v2/parse_test.go | 27 +-- 4 files changed, 399 insertions(+), 382 deletions(-) create mode 100644 v2/construct.go create mode 100644 v2/construct_test.go diff --git a/v2/construct.go b/v2/construct.go new file mode 100644 index 0000000..bed64eb --- /dev/null +++ b/v2/construct.go @@ -0,0 +1,317 @@ +package arg + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "reflect" + "strings" +) + +// Argument represents a command line argument +type Argument struct { + dest path + field reflect.StructField // the struct field from which this option was created + long string // the --long form for this option, or empty if none + short string // the -s short form for this option, or empty if none + cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple) + required bool // if true, this option must be present on the command line + positional bool // if true, this option will be looked for in the positional flags + separate bool // if true, each slice and map entry will have its own --flag + help string // the help text for this option + env string // the name of the environment variable for this option, or empty for none + defaultVal string // default value for this option + placeholder string // name of the data in help +} + +// Command represents a named subcommand, or the top-level command +type Command struct { + name string + help string + dest path + args []*Argument + subcommands []*Command + parent *Command +} + +// Parser represents a set of command line options with destination values +type Parser struct { + cmd *Command // the top-level command + root reflect.Value // destination struct to fill will values + version string // version from the argument struct + prologue string // prologue for help text (from the argument struct) + epilogue string // epilogue for help text (from the argument struct) + + // the following fields are updated during processing of command line arguments + leaf *Command // the subcommand we processed last + accumulatedArgs []*Argument // concatenation of the leaf subcommand's arguments plus all ancestors' arguments + seen map[*Argument]bool // the arguments we encountered while processing command line arguments +} + +// Versioned is the interface that the destination struct should implement to +// make a version string appear at the top of the help message. +type Versioned interface { + // Version returns the version string that will be printed on a line by itself + // at the top of the help message. + Version() string +} + +// Described is the interface that the destination struct should implement to +// make a description string appear at the top of the help message. +type Described interface { + // Description returns the string that will be printed on a line by itself + // at the top of the help message. + Description() string +} + +// Epilogued is the interface that the destination struct should implement to +// add an epilogue string at the bottom of the help message. +type Epilogued interface { + // Epilogue returns the string that will be printed on a line by itself + // at the end of the help message. + Epilogue() string +} + +// the ParserOption interface matches options for the parser constructor +type ParserOption interface { + parserOption() +} + +type programNameParserOption struct { + s string +} + +func (programNameParserOption) parserOption() {} + +// WithProgramName overrides the name of the program as displayed in help test +func WithProgramName(name string) ParserOption { + return programNameParserOption{s: name} +} + +// NewParser constructs a parser from a list of destination structs +func NewParser(dest interface{}, options ...ParserOption) (*Parser, error) { + // check the destination type + t := reflect.TypeOf(dest) + if t.Kind() != reflect.Ptr { + panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t)) + } + + // pick a program name for help text and usage output + program := "program" + if len(os.Args) > 0 { + program = filepath.Base(os.Args[0]) + } + + // apply the options + for _, opt := range options { + switch opt := opt.(type) { + case programNameParserOption: + program = opt.s + } + } + + // build the root command from the struct + cmd, err := cmdFromStruct(program, path{}, t) + if err != nil { + return nil, err + } + + // construct the parser + p := Parser{ + seen: make(map[*Argument]bool), + root: reflect.ValueOf(dest), + cmd: cmd, + } + + // check for version, prologue, and epilogue + if dest, ok := dest.(Versioned); ok { + p.version = dest.Version() + } + if dest, ok := dest.(Described); ok { + p.prologue = dest.Description() + } + if dest, ok := dest.(Epilogued); ok { + p.epilogue = dest.Epilogue() + } + + return &p, nil +} + +func cmdFromStruct(name string, dest path, t reflect.Type) (*Command, error) { + // commands can only be created from pointers to structs + if t.Kind() != reflect.Ptr { + return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s", + dest, t.Kind()) + } + + t = t.Elem() + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a pointer to %s", + dest, t.Kind()) + } + + cmd := Command{ + name: name, + dest: dest, + } + + var errs []string + walkFields(t, func(field reflect.StructField, t reflect.Type) bool { + // check for the ignore switch in the tag + tag := field.Tag.Get("arg") + if tag == "-" { + return false + } + + // if this is an embedded struct then recurse into its fields, even if + // it is unexported, because exported fields on unexported embedded + // structs are still writable + if field.Anonymous && field.Type.Kind() == reflect.Struct { + return true + } + + // ignore any other unexported field + if !isExported(field.Name) { + return false + } + + // duplicate the entire path to avoid slice overwrites + subdest := dest.Child(field) + arg := Argument{ + dest: subdest, + field: field, + long: strings.ToLower(field.Name), + } + + help, exists := field.Tag.Lookup("help") + if exists { + arg.help = help + } + + defaultVal, hasDefault := field.Tag.Lookup("default") + if hasDefault { + arg.defaultVal = defaultVal + } + + // Look at the tag + var isSubcommand bool // tracks whether this field is a subcommand + for _, key := range strings.Split(tag, ",") { + if key == "" { + continue + } + key = strings.TrimLeft(key, " ") + var value string + if pos := strings.Index(key, ":"); pos != -1 { + value = key[pos+1:] + key = key[:pos] + } + + switch { + case strings.HasPrefix(key, "---"): + errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name)) + case strings.HasPrefix(key, "--"): + arg.long = key[2:] + case strings.HasPrefix(key, "-"): + if len(key) != 2 { + errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", + t.Name(), field.Name)) + return false + } + arg.short = key[1:] + case key == "required": + if hasDefault { + errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified", + t.Name(), field.Name)) + return false + } + arg.required = true + case key == "positional": + arg.positional = true + case key == "separate": + arg.separate = true + case key == "help": // deprecated + arg.help = value + case key == "env": + // Use override name if provided + if value != "" { + arg.env = value + } else { + arg.env = strings.ToUpper(field.Name) + } + case key == "subcommand": + // decide on a name for the subcommand + cmdname := value + if cmdname == "" { + cmdname = strings.ToLower(field.Name) + } + + // parse the subcommand recursively + subcmd, err := cmdFromStruct(cmdname, subdest, field.Type) + if err != nil { + errs = append(errs, err.Error()) + return false + } + + subcmd.parent = &cmd + subcmd.help = field.Tag.Get("help") + + cmd.subcommands = append(cmd.subcommands, subcmd) + isSubcommand = true + default: + errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) + return false + } + } + + placeholder, hasPlaceholder := field.Tag.Lookup("placeholder") + if hasPlaceholder { + arg.placeholder = placeholder + } else if arg.long != "" { + arg.placeholder = strings.ToUpper(arg.long) + } else { + arg.placeholder = strings.ToUpper(arg.field.Name) + } + + // Check whether this field is supported. It's good to do this here rather than + // wait until ParseValue because it means that a program with invalid argument + // fields will always fail regardless of whether the arguments it received + // exercised those fields. + if !isSubcommand { + cmd.args = append(cmd.args, &arg) + + var err error + arg.cardinality, err = cardinalityOf(field.Type) + if err != nil { + errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", + t.Name(), field.Name, field.Type.String())) + return false + } + if arg.cardinality == multiple && hasDefault { + errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields", + t.Name(), field.Name)) + return false + } + } + + // if this was an embedded field then we already returned true up above + return false + }) + + if len(errs) > 0 { + return nil, errors.New(strings.Join(errs, "\n")) + } + + // check that we don't have both positionals and subcommands + var hasPositional bool + for _, arg := range cmd.args { + if arg.positional { + hasPositional = true + } + } + if hasPositional && len(cmd.subcommands) > 0 { + return nil, fmt.Errorf("%s cannot have both subcommands and positional arguments", dest) + } + + return &cmd, nil +} diff --git a/v2/construct_test.go b/v2/construct_test.go new file mode 100644 index 0000000..66b2f80 --- /dev/null +++ b/v2/construct_test.go @@ -0,0 +1,25 @@ +package arg + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInvalidTag(t *testing.T) { + var args struct { + Foo string `arg:"this_is_not_valid"` + } + _, err := NewParser(&args) + assert.Error(t, err) +} + +func TestUnexportedFieldsSkipped(t *testing.T) { + var args struct { + unexported struct{} + } + + _, err := NewParser(&args) + require.NoError(t, err) +} diff --git a/v2/parse.go b/v2/parse.go index f5bcab4..1a29e35 100644 --- a/v2/parse.go +++ b/v2/parse.go @@ -5,65 +5,12 @@ import ( "errors" "fmt" "os" - "path/filepath" "reflect" "strings" scalar "github.com/alexflint/go-scalar" ) -// path represents a sequence of steps to find the output location for an -// argument or subcommand in the final destination struct -type path struct { - fields []reflect.StructField // sequence of struct fields to traverse -} - -// String gets a string representation of the given path -func (p path) String() string { - s := "args" - for _, f := range p.fields { - s += "." + f.Name - } - return s -} - -// Child gets a new path representing a child of this path. -func (p path) Child(f reflect.StructField) path { - // copy the entire slice of fields to avoid possible slice overwrite - subfields := make([]reflect.StructField, len(p.fields)+1) - copy(subfields, p.fields) - subfields[len(subfields)-1] = f - return path{ - fields: subfields, - } -} - -// Arg represents a command line argument -type Argument struct { - dest path - field reflect.StructField // the struct field from which this option was created - long string // the --long form for this option, or empty if none - short string // the -s short form for this option, or empty if none - cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple) - required bool // if true, this option must be present on the command line - positional bool // if true, this option will be looked for in the positional flags - separate bool // if true, each slice and map entry will have its own --flag - help string // the help text for this option - env string // the name of the environment variable for this option, or empty for none - defaultVal string // default value for this option - placeholder string // name of the data in help -} - -// Command represents a named subcommand, or the top-level command -type Command struct { - name string - help string - dest path - args []*Argument - subcommands []*Command - parent *Command -} - // ErrHelp indicates that -h or --help were provided var ErrHelp = errors.New("help requested by user") @@ -103,308 +50,6 @@ func Parse(dest interface{}, options ...ParserOption) error { return p.Parse(os.Args, os.Environ()) } -// Parser represents a set of command line options with destination values -type Parser struct { - cmd *Command // the top-level command - root reflect.Value // destination struct to fill will values - version string // version from the argument struct - prologue string // prologue for help text (from the argument struct) - epilogue string // epilogue for help text (from the argument struct) - - // the following fields are updated during processing of command line arguments - leaf *Command // the subcommand we processed last - accumulatedArgs []*Argument // concatenation of the leaf subcommand's arguments plus all ancestors' arguments - seen map[*Argument]bool // the arguments we encountered while processing command line arguments -} - -// Versioned is the interface that the destination struct should implement to -// make a version string appear at the top of the help message. -type Versioned interface { - // Version returns the version string that will be printed on a line by itself - // at the top of the help message. - Version() string -} - -// Described is the interface that the destination struct should implement to -// make a description string appear at the top of the help message. -type Described interface { - // Description returns the string that will be printed on a line by itself - // at the top of the help message. - Description() string -} - -// Epilogued is the interface that the destination struct should implement to -// add an epilogue string at the bottom of the help message. -type Epilogued interface { - // Epilogue returns the string that will be printed on a line by itself - // at the end of the help message. - Epilogue() string -} - -// walkFields calls a function for each field of a struct, recursively expanding struct fields. -func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) { - walkFieldsImpl(t, visit, nil) -} - -func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool, path []int) { - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - field.Index = make([]int, len(path)+1) - copy(field.Index, append(path, i)) - expand := visit(field, t) - if expand && field.Type.Kind() == reflect.Struct { - var subpath []int - if field.Anonymous { - subpath = append(path, i) - } - walkFieldsImpl(field.Type, visit, subpath) - } - } -} - -// the ParserOption interface matches options for the parser constructor -type ParserOption interface { - parserOption() -} - -type programNameParserOption struct { - s string -} - -func (programNameParserOption) parserOption() {} - -// WithProgramName overrides the name of the program as displayed in help test -func WithProgramName(name string) ParserOption { - return programNameParserOption{s: name} -} - -// NewParser constructs a parser from a list of destination structs -func NewParser(dest interface{}, options ...ParserOption) (*Parser, error) { - // check the destination type - t := reflect.TypeOf(dest) - if t.Kind() != reflect.Ptr { - panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t)) - } - - // pick a program name for help text and usage output - program := "program" - if len(os.Args) > 0 { - program = filepath.Base(os.Args[0]) - } - - // apply the options - for _, opt := range options { - switch opt := opt.(type) { - case programNameParserOption: - program = opt.s - } - } - - // build the root command from the struct - cmd, err := cmdFromStruct(program, path{}, t) - if err != nil { - return nil, err - } - - // construct the parser - p := Parser{ - seen: make(map[*Argument]bool), - root: reflect.ValueOf(dest), - cmd: cmd, - } - - // check for version, prologue, and epilogue - if dest, ok := dest.(Versioned); ok { - p.version = dest.Version() - } - if dest, ok := dest.(Described); ok { - p.prologue = dest.Description() - } - if dest, ok := dest.(Epilogued); ok { - p.epilogue = dest.Epilogue() - } - - return &p, nil -} - -func cmdFromStruct(name string, dest path, t reflect.Type) (*Command, error) { - // commands can only be created from pointers to structs - if t.Kind() != reflect.Ptr { - return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s", - dest, t.Kind()) - } - - t = t.Elem() - if t.Kind() != reflect.Struct { - return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a pointer to %s", - dest, t.Kind()) - } - - cmd := Command{ - name: name, - dest: dest, - } - - var errs []string - walkFields(t, func(field reflect.StructField, t reflect.Type) bool { - // check for the ignore switch in the tag - tag := field.Tag.Get("arg") - if tag == "-" { - return false - } - - // if this is an embedded struct then recurse into its fields, even if - // it is unexported, because exported fields on unexported embedded - // structs are still writable - if field.Anonymous && field.Type.Kind() == reflect.Struct { - return true - } - - // ignore any other unexported field - if !isExported(field.Name) { - return false - } - - // duplicate the entire path to avoid slice overwrites - subdest := dest.Child(field) - arg := Argument{ - dest: subdest, - field: field, - long: strings.ToLower(field.Name), - } - - help, exists := field.Tag.Lookup("help") - if exists { - arg.help = help - } - - defaultVal, hasDefault := field.Tag.Lookup("default") - if hasDefault { - arg.defaultVal = defaultVal - } - - // Look at the tag - var isSubcommand bool // tracks whether this field is a subcommand - for _, key := range strings.Split(tag, ",") { - if key == "" { - continue - } - key = strings.TrimLeft(key, " ") - var value string - if pos := strings.Index(key, ":"); pos != -1 { - value = key[pos+1:] - key = key[:pos] - } - - switch { - case strings.HasPrefix(key, "---"): - errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name)) - case strings.HasPrefix(key, "--"): - arg.long = key[2:] - case strings.HasPrefix(key, "-"): - if len(key) != 2 { - errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", - t.Name(), field.Name)) - return false - } - arg.short = key[1:] - case key == "required": - if hasDefault { - errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified", - t.Name(), field.Name)) - return false - } - arg.required = true - case key == "positional": - arg.positional = true - case key == "separate": - arg.separate = true - case key == "help": // deprecated - arg.help = value - case key == "env": - // Use override name if provided - if value != "" { - arg.env = value - } else { - arg.env = strings.ToUpper(field.Name) - } - case key == "subcommand": - // decide on a name for the subcommand - cmdname := value - if cmdname == "" { - cmdname = strings.ToLower(field.Name) - } - - // parse the subcommand recursively - subcmd, err := cmdFromStruct(cmdname, subdest, field.Type) - if err != nil { - errs = append(errs, err.Error()) - return false - } - - subcmd.parent = &cmd - subcmd.help = field.Tag.Get("help") - - cmd.subcommands = append(cmd.subcommands, subcmd) - isSubcommand = true - default: - errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) - return false - } - } - - placeholder, hasPlaceholder := field.Tag.Lookup("placeholder") - if hasPlaceholder { - arg.placeholder = placeholder - } else if arg.long != "" { - arg.placeholder = strings.ToUpper(arg.long) - } else { - arg.placeholder = strings.ToUpper(arg.field.Name) - } - - // Check whether this field is supported. It's good to do this here rather than - // wait until ParseValue because it means that a program with invalid argument - // fields will always fail regardless of whether the arguments it received - // exercised those fields. - if !isSubcommand { - cmd.args = append(cmd.args, &arg) - - var err error - arg.cardinality, err = cardinalityOf(field.Type) - if err != nil { - errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", - t.Name(), field.Name, field.Type.String())) - return false - } - if arg.cardinality == multiple && hasDefault { - errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields", - t.Name(), field.Name)) - return false - } - } - - // if this was an embedded field then we already returned true up above - return false - }) - - if len(errs) > 0 { - return nil, errors.New(strings.Join(errs, "\n")) - } - - // check that we don't have both positionals and subcommands - var hasPositional bool - for _, arg := range cmd.args { - if arg.positional { - hasPositional = true - } - } - if hasPositional && len(cmd.subcommands) > 0 { - return nil, fmt.Errorf("%s cannot have both subcommands and positional arguments", dest) - } - - return &cmd, nil -} - // Parse processes the given command line option, storing the results in the field // of the structs from which NewParser was constructed func (p *Parser) Parse(args, env []string) error { @@ -763,11 +408,6 @@ func (p *Parser) Validate() error { return nil } -// isFlag returns true if a token is a flag such as "-v" or "--user" but not "-" or "--" -func isFlag(s string) bool { - return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != "" -} - // val returns a reflect.Value corresponding to the current value for the // given path func (p *Parser) val(dest path) reflect.Value { @@ -785,6 +425,58 @@ func (p *Parser) val(dest path) reflect.Value { return v } +// path represents a sequence of steps to find the output location for an +// argument or subcommand in the final destination struct +type path struct { + fields []reflect.StructField // sequence of struct fields to traverse +} + +// String gets a string representation of the given path +func (p path) String() string { + s := "args" + for _, f := range p.fields { + s += "." + f.Name + } + return s +} + +// Child gets a new path representing a child of this path. +func (p path) Child(f reflect.StructField) path { + // copy the entire slice of fields to avoid possible slice overwrite + subfields := make([]reflect.StructField, len(p.fields)+1) + copy(subfields, p.fields) + subfields[len(subfields)-1] = f + return path{ + fields: subfields, + } +} + +// walkFields calls a function for each field of a struct, recursively expanding struct fields. +func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) { + walkFieldsImpl(t, visit, nil) +} + +func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool, path []int) { + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + field.Index = make([]int, len(path)+1) + copy(field.Index, append(path, i)) + expand := visit(field, t) + if expand && field.Type.Kind() == reflect.Struct { + var subpath []int + if field.Anonymous { + subpath = append(path, i) + } + walkFieldsImpl(field.Type, visit, subpath) + } + } +} + +// isFlag returns true if a token is a flag such as "-v" or "--user" but not "-" or "--" +func isFlag(s string) bool { + return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != "" +} + // findOption finds an option from its name, or returns nil if no arg is found func findOption(args []*Argument, name string) *Argument { for _, arg := range args { diff --git a/v2/parse_test.go b/v2/parse_test.go index c65ded8..ee46006 100644 --- a/v2/parse_test.go +++ b/v2/parse_test.go @@ -629,14 +629,6 @@ func TestParse(t *testing.T) { assert.Equal(t, "bar", args.Foo) } -func TestParseError(t *testing.T) { - var args struct { - Foo string `arg:"this_is_not_valid"` - } - _, err := NewParser(&args) - assert.Error(t, err) -} - func TestMustParse(t *testing.T) { var args struct { Foo string @@ -795,13 +787,13 @@ func TestEnvironmentVariableIgnored(t *testing.T) { } func TestDefaultValuesIgnored(t *testing.T) { - var args struct { - Foo string `default:"bad"` - } - - // just checking that default values are not automatically applied + // check that default values are not automatically applied // in ProcessCommandLine or ProcessEnvironment + var args struct { + Foo string `default:"hello"` + } + p, err := NewParser(&args) require.NoError(t, err) @@ -1390,15 +1382,6 @@ func TestDefaultValuesNotAllowedWithSlice(t *testing.T) { assert.EqualError(t, err, ".A: default values are not supported for slice or map fields") } -func TestUnexportedFieldsSkipped(t *testing.T) { - var args struct { - unexported struct{} - } - - _, err := NewParser(&args) - require.NoError(t, err) -} - func TestMustParseInvalidParser(t *testing.T) { originalExit := osExit originalStdout := stdout From 2775f58376287528adf36cdcf551a72093a0c40c Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 12:34:53 -0700 Subject: [PATCH 08/19] add OverwriteWithOptions, OverwriteWithCommandLine --- v2/construct.go | 2 +- v2/parse.go | 418 ++++++++++++++++++++++++++------------------ v2/parse_test.go | 1 + v2/sequence.go | 97 +++++++++- v2/sequence_test.go | 4 + 5 files changed, 346 insertions(+), 176 deletions(-) diff --git a/v2/construct.go b/v2/construct.go index bed64eb..bd2800e 100644 --- a/v2/construct.go +++ b/v2/construct.go @@ -176,7 +176,7 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*Command, error) { return false } - // duplicate the entire path to avoid slice overwrites + // create a new destination path for this field subdest := dest.Child(field) arg := Argument{ dest: subdest, diff --git a/v2/parse.go b/v2/parse.go index 1a29e35..251ddb6 100644 --- a/v2/parse.go +++ b/v2/parse.go @@ -93,6 +93,250 @@ func (p *Parser) Parse(args, env []string) error { return p.Validate() } +// ProcessCommandLine scans arguments one-by-one, parses them and assigns +// the result to fields of the struct passed to NewParser. It returns +// an error if an argument is invalid or unknown, but not if a +// required argument is missing. To check that all required arguments +// are set, call Validate(). This function ignores the first element +// of args, which is assumed to be the program name itself. This function +// never overwrites arguments previously seen in a call to any Process* +// function. +func (p *Parser) ProcessCommandLine(args []string) error { + positionals, err := p.ProcessOptions(args) + if err != nil { + return err + } + return p.ProcessPositionals(positionals) +} + +// OverwriteWithCommandLine is like ProcessCommandLine but it overwrites +// any previously seen values. +func (p *Parser) OverwriteWithCommandLine(args []string) error { + positionals, err := p.OverwriteWithOptions(args) + if err != nil { + return err + } + return p.OverwriteWithPositionals(positionals) +} + +// ProcessOptions processes options but not positionals from the +// command line. Positionals are returned and can be passed to +// ProcessPositionals. This function ignores the first element of args, +// which is assumed to be the program name itself. Arguments seen +// in a previous call to any Process* or OverwriteWith* functions +// are ignored. +func (p *Parser) ProcessOptions(args []string) ([]string, error) { + return p.processOptions(args, false) +} + +// OverwriteWithOptions is like ProcessOptions except previously seen +// arguments are overwritten +func (p *Parser) OverwriteWithOptions(args []string) ([]string, error) { + return p.processOptions(args, true) +} + +func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) { + // union of args for the chain of subcommands encountered so far + p.leaf = p.cmd + + // we will add to this list each time we expand a subcommand + p.accumulatedArgs = make([]*Argument, len(p.leaf.args)) + copy(p.accumulatedArgs, p.leaf.args) + + // process each string from the command line + var allpositional bool + var positionals []string + + // must use explicit for loop, not range, because we manipulate i inside the loop + for i := 1; i < len(args); i++ { + token := args[i] + + // the "--" token indicates that all further tokens should be treated as positionals + if token == "--" { + allpositional = true + continue + } + + // check whether this is a positional argument + if !isFlag(token) || allpositional { + // each subcommand can have either subcommands or positionals, but not both + if len(p.leaf.subcommands) == 0 { + positionals = append(positionals, token) + continue + } + + // if we have a subcommand then make sure it is valid for the current context + subcmd := findSubcommand(p.leaf.subcommands, token) + if subcmd == nil { + return nil, fmt.Errorf("invalid subcommand: %s", token) + } + + // instantiate the field to point to a new struct + v := p.val(subcmd.dest) + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers + } + + // add the new options to the set of allowed options + p.accumulatedArgs = append(p.accumulatedArgs, subcmd.args...) + p.leaf = subcmd + continue + } + + // check for special --help and --version flags + switch token { + case "-h", "--help": + return nil, ErrHelp + case "--version": + return nil, ErrVersion + } + + // check for an equals sign, as in "--foo=bar" + var value string + opt := strings.TrimLeft(token, "-") + if pos := strings.Index(opt, "="); pos != -1 { + value = opt[pos+1:] + opt = opt[:pos] + } + + // look up the arg for this option (note that the "args" slice changes as + // we expand subcommands so it is better not to use a map) + arg := findOption(p.accumulatedArgs, opt) + if arg == nil { + return nil, fmt.Errorf("unknown argument %s", token) + } + + // deal with the case of multiple values + if arg.cardinality == multiple { + // if arg.separate is true then just parse one value and append it + if arg.separate { + if value == "" { + if i+1 == len(args) { + return nil, fmt.Errorf("missing value for %s", token) + } + if isFlag(args[i+1]) { + return nil, fmt.Errorf("missing value for %s", token) + } + value = args[i+1] + i++ + } + + err := appendToSliceOrMap(p.val(arg.dest), value) + if err != nil { + return nil, fmt.Errorf("error processing %s: %v", token, err) + } + p.seen[arg] = true + continue + } + + // if args.separate is not true then consume tokens until next --option + var values []string + if value == "" { + for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { + values = append(values, args[i+1]) + i++ + } + } else { + values = append(values, value) + } + + // this is the first time we can check p.seen because we need to correctly + // increment i above, even when we then ignore the value + if p.seen[arg] && !overwrite { + continue + } + + // store the values into the slice or map + err := setSliceOrMap(p.val(arg.dest), values, !arg.separate) + if err != nil { + return nil, fmt.Errorf("error processing %s: %v", token, err) + } + continue + } + + // if it's a flag and it has no value then set the value to true + // use boolean because this takes account of TextUnmarshaler + if arg.cardinality == zero && value == "" { + value = "true" + } + + // if we have something like "--foo" then the value is the next argument + if value == "" { + if i+1 == len(args) { + return nil, fmt.Errorf("missing value for %s", token) + } + if isFlag(args[i+1]) { + return nil, fmt.Errorf("missing value for %s", token) + } + value = args[i+1] + i++ + } + + // this is the first time we can check p.seen because we need to correctly + // increment i above, even when we then ignore the value + if p.seen[arg] && !overwrite { + continue + } + + err := scalar.ParseValue(p.val(arg.dest), value) + if err != nil { + return nil, fmt.Errorf("error processing %s: %v", token, err) + } + p.seen[arg] = true + } + + return positionals, nil +} + +// ProcessPositionals processes a list of positional arguments. If +// this list contains tokens that begin with a hyphen they will still be +// treated as positional arguments. Arguments seen in a previous call +// to any Process* or OverwriteWith* functions are ignored. +func (p *Parser) ProcessPositionals(positionals []string) error { + return p.processPositionals(positionals, false) +} + +// OverwriteWithPositionals is like ProcessPositionals except previously +// seen arguments are overwritten. +func (p *Parser) OverwriteWithPositionals(positionals []string) error { + return p.processPositionals(positionals, true) +} + +func (p *Parser) processPositionals(positionals []string, overwrite bool) error { + for _, arg := range p.accumulatedArgs { + if !arg.positional { + continue + } + if len(positionals) == 0 { + break + } + if arg.cardinality == multiple { + if !p.seen[arg] || overwrite { + err := setSliceOrMap(p.val(arg.dest), positionals, true) + if err != nil { + return fmt.Errorf("error processing %s: %v", arg.field.Name, err) + } + } + positionals = nil + } else { + if !p.seen[arg] || overwrite { + err := scalar.ParseValue(p.val(arg.dest), positionals[0]) + if err != nil { + return fmt.Errorf("error processing %s: %v", arg.field.Name, err) + } + } + positionals = positionals[1:] + } + p.seen[arg] = true + } + + if len(positionals) > 0 { + return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) + } + + return nil +} + // ProcessEnvironment processes environment variables from a list of strings // of the form KEY=VALUE. You can pass in os.Environ(). It // does not overwrite any fields with values already populated. @@ -167,180 +411,6 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error { return nil } -// ProcessCommandLine goes through arguments one-by-one, parses them, -// and assigns the result to the underlying struct field. It returns -// an error if an argument is invalid or an option is unknown, not if a -// required argument is missing. To check that all required arguments -// are set, call CheckRequired(). This function ignores the first element -// of args, which is assumed to be the program name itself. -func (p *Parser) ProcessCommandLine(args []string) error { - positionals, err := p.ProcessOptions(args) - if err != nil { - return err - } - return p.ProcessPositionals(positionals) -} - -// ProcessOptions process command line arguments but does not process -// positional arguments. Instead, it returns positionals. These can then -// be passed to ProcessPositionals. This function ignores the first element -// of args, which is assumed to be the program name itself. -func (p *Parser) ProcessOptions(args []string) ([]string, error) { - // union of args for the chain of subcommands encountered so far - curCmd := p.cmd - p.leaf = curCmd - - // we will add to this list each time we expand a subcommand - p.accumulatedArgs = make([]*Argument, len(curCmd.args)) - copy(p.accumulatedArgs, curCmd.args) - - // process each string from the command line - var allpositional bool - var positionals []string - - // must use explicit for loop, not range, because we manipulate i inside the loop - for i := 1; i < len(args); i++ { - token := args[i] - if token == "--" { - allpositional = true - continue - } - - if !isFlag(token) || allpositional { - // each subcommand can have either subcommands or positionals, but not both - if len(curCmd.subcommands) == 0 { - positionals = append(positionals, token) - continue - } - - // if we have a subcommand then make sure it is valid for the current context - subcmd := findSubcommand(curCmd.subcommands, token) - if subcmd == nil { - return nil, fmt.Errorf("invalid subcommand: %s", token) - } - - // instantiate the field to point to a new struct - v := p.val(subcmd.dest) - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers - } - - // add the new options to the set of allowed options - p.accumulatedArgs = append(p.accumulatedArgs, subcmd.args...) - - curCmd = subcmd - p.leaf = curCmd - continue - } - - // check for special --help and --version flags - switch token { - case "-h", "--help": - return nil, ErrHelp - case "--version": - return nil, ErrVersion - } - - // check for an equals sign, as in "--foo=bar" - var value string - opt := strings.TrimLeft(token, "-") - if pos := strings.Index(opt, "="); pos != -1 { - value = opt[pos+1:] - opt = opt[:pos] - } - - // look up the arg for this option (note that the "args" slice changes as - // we expand subcommands so it is better not to use a map) - arg := findOption(p.accumulatedArgs, opt) - if arg == nil { - return nil, fmt.Errorf("unknown argument %s", token) - } - p.seen[arg] = true - - // deal with the case of multiple values - if arg.cardinality == multiple { - var values []string - if value == "" { - for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { - values = append(values, args[i+1]) - i++ - if arg.separate { - break - } - } - } else { - values = append(values, value) - } - err := setSliceOrMap(p.val(arg.dest), values, !arg.separate) - if err != nil { - return nil, fmt.Errorf("error processing %s: %v", token, err) - } - continue - } - - // if it's a flag and it has no value then set the value to true - // use boolean because this takes account of TextUnmarshaler - if arg.cardinality == zero && value == "" { - value = "true" - } - - // if we have something like "--foo" then the value is the next argument - if value == "" { - if i+1 == len(args) { - return nil, fmt.Errorf("missing value for %s", token) - } - if isFlag(args[i+1]) { - return nil, fmt.Errorf("missing value for %s", token) - } - value = args[i+1] - i++ - } - - p.seen[arg] = true - err := scalar.ParseValue(p.val(arg.dest), value) - if err != nil { - return nil, fmt.Errorf("error processing %s: %v", token, err) - } - } - - return positionals, nil -} - -// ProcessPositionals processes a list of positional arguments. It is assumed -// that options such as --abc and --abc=123 have already been removed. If -// this list contains tokens that begin with a hyphen they will still be -// treated as positional arguments. -func (p *Parser) ProcessPositionals(positionals []string) error { - for _, arg := range p.accumulatedArgs { - if !arg.positional { - continue - } - if len(positionals) == 0 { - break - } - p.seen[arg] = true - if arg.cardinality == multiple { - err := setSliceOrMap(p.val(arg.dest), positionals, true) - if err != nil { - return fmt.Errorf("error processing %s: %v", arg.field.Name, err) - } - positionals = nil - } else { - err := scalar.ParseValue(p.val(arg.dest), positionals[0]) - if err != nil { - return fmt.Errorf("error processing %s: %v", arg.field.Name, err) - } - positionals = positionals[1:] - } - } - - if len(positionals) > 0 { - return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) - } - - return nil -} - // ProcessDefaults assigns default values to all fields that have default values and // are not already populated. func (p *Parser) ProcessDefaults() error { diff --git a/v2/parse_test.go b/v2/parse_test.go index ee46006..72efa79 100644 --- a/v2/parse_test.go +++ b/v2/parse_test.go @@ -38,6 +38,7 @@ func TestString(t *testing.T) { _, err := parse(&args, "--foo bar --ptr baz") require.NoError(t, err) assert.Equal(t, "bar", args.Foo) + require.NotNil(t, args.Ptr) assert.Equal(t, "baz", *args.Ptr) } diff --git a/v2/sequence.go b/v2/sequence.go index 35a3614..f0fff46 100644 --- a/v2/sequence.go +++ b/v2/sequence.go @@ -27,7 +27,7 @@ func setSliceOrMap(dest reflect.Value, values []string, clear bool) error { case reflect.Map: return setMap(dest, values, clear) default: - return fmt.Errorf("setSliceOrMap cannot insert values into a %v", t) + return fmt.Errorf("cannot insert multiple values into a %v", t) } } @@ -121,3 +121,98 @@ func setMap(dest reflect.Value, values []string, clear bool) error { } return nil } + +// appendSliceOrMap parses a string and appends it to an existing slice or map. +func appendToSliceOrMap(dest reflect.Value, value string) error { + if !dest.CanSet() { + return fmt.Errorf("field is not writable") + } + + t := dest.Type() + if t.Kind() == reflect.Ptr { + dest = dest.Elem() + t = t.Elem() + } + + switch t.Kind() { + case reflect.Slice: + return appendToSlice(dest, value) + case reflect.Map: + return appendToMap(dest, value) + default: + return fmt.Errorf("cannot insert multiple values into a %v", t) + } +} + +// appendSlice parses a string and appends the result into a slice. +func appendToSlice(dest reflect.Value, s string) error { + var ptr bool + elem := dest.Type().Elem() + if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { + ptr = true + elem = elem.Elem() + } + + // parse the value and append + v := reflect.New(elem) + if err := scalar.ParseValue(v.Elem(), s); err != nil { + return err + } + if !ptr { + v = v.Elem() + } + dest.Set(reflect.Append(dest, v)) + return nil +} + +// appendToMap parses a name=value string and inserts it into a map. +// If clear is true then any values already in the map are removed. +func appendToMap(dest reflect.Value, s string) error { + // determine the key and value type + var keyIsPtr bool + keyType := dest.Type().Key() + if keyType.Kind() == reflect.Ptr && !keyType.Implements(textUnmarshalerType) { + keyIsPtr = true + keyType = keyType.Elem() + } + + var valIsPtr bool + valType := dest.Type().Elem() + if valType.Kind() == reflect.Ptr && !valType.Implements(textUnmarshalerType) { + valIsPtr = true + valType = valType.Elem() + } + + // allocate the map if it is not allocated + if dest.IsNil() { + dest.Set(reflect.MakeMap(dest.Type())) + } + + // split at the first equals sign + pos := strings.Index(s, "=") + if pos == -1 { + return fmt.Errorf("cannot parse %q into a map, expected format key=value", s) + } + + // parse the key + k := reflect.New(keyType) + if err := scalar.ParseValue(k.Elem(), s[:pos]); err != nil { + return err + } + if !keyIsPtr { + k = k.Elem() + } + + // parse the value + v := reflect.New(valType) + if err := scalar.ParseValue(v.Elem(), s[pos+1:]); err != nil { + return err + } + if !valIsPtr { + v = v.Elem() + } + + // add it to the map + dest.SetMapIndex(k, v) + return nil +} diff --git a/v2/sequence_test.go b/v2/sequence_test.go index fde3e3a..6383949 100644 --- a/v2/sequence_test.go +++ b/v2/sequence_test.go @@ -150,3 +150,7 @@ func TestSetSliceOrMapErrors(t *testing.T) { err = setSliceOrMap(dest, nil, false) assert.Error(t, err) } + +// check that we can accumulate "separate" args across env, cmdline, map, and defaults + +// check what happens if we have a required arg with a default value From 64288c5521c31228a7f06de14261c24b66626d8e Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 12:48:04 -0700 Subject: [PATCH 09/19] add appendToSlice, appendToMap, appendToSliceOrMap --- v2/parse.go | 27 +++++++++++++------------ v2/sequence.go | 24 +++++++++++----------- v2/sequence_test.go | 49 +++++++++++++++++++-------------------------- 3 files changed, 47 insertions(+), 53 deletions(-) diff --git a/v2/parse.go b/v2/parse.go index 251ddb6..dcdd353 100644 --- a/v2/parse.go +++ b/v2/parse.go @@ -247,7 +247,7 @@ func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) } // store the values into the slice or map - err := setSliceOrMap(p.val(arg.dest), values, !arg.separate) + err := setSliceOrMap(p.val(arg.dest), values) if err != nil { return nil, fmt.Errorf("error processing %s: %v", token, err) } @@ -312,7 +312,7 @@ func (p *Parser) processPositionals(positionals []string, overwrite bool) error } if arg.cardinality == multiple { if !p.seen[arg] || overwrite { - err := setSliceOrMap(p.val(arg.dest), positionals, true) + err := setSliceOrMap(p.val(arg.dest), positionals) if err != nil { return fmt.Errorf("error processing %s: %v", arg.field.Name, err) } @@ -385,19 +385,20 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error { if len(strings.TrimSpace(value)) > 0 { values, err = csv.NewReader(strings.NewReader(value)).Read() if err != nil { - return fmt.Errorf( - "error reading a CSV string from environment variable %s with multiple values: %v", - arg.env, - err, - ) + return fmt.Errorf("error reading a CSV string from environment variable %s : %v", arg.env, err) } } - if err = setSliceOrMap(p.val(arg.dest), values, !arg.separate); err != nil { - return fmt.Errorf( - "error processing environment variable %s with multiple values: %v", - arg.env, - err, - ) + + if arg.separate { + if err = setSliceOrMap(p.val(arg.dest), values); err != nil { + return fmt.Errorf("error processing environment variable %s: %v", arg.env, err) + } + } else { + for _, s := range values { + if err = appendToSliceOrMap(p.val(arg.dest), s); err != nil { + return fmt.Errorf("error processing environment variable %s: %v", arg.env, err) + } + } } } else { if err := scalar.ParseValue(p.val(arg.dest), value); err != nil { diff --git a/v2/sequence.go b/v2/sequence.go index f0fff46..566c8d2 100644 --- a/v2/sequence.go +++ b/v2/sequence.go @@ -8,9 +8,9 @@ import ( scalar "github.com/alexflint/go-scalar" ) -// setSliceOrMap parses a sequence of strings into a slice or map. If clear is -// true then any values already in the slice or map are first removed. -func setSliceOrMap(dest reflect.Value, values []string, clear bool) error { +// setSliceOrMap parses a sequence of strings into a slice or map. The slice or +// map is always cleared first. +func setSliceOrMap(dest reflect.Value, values []string) error { if !dest.CanSet() { return fmt.Errorf("field is not writable") } @@ -23,17 +23,17 @@ func setSliceOrMap(dest reflect.Value, values []string, clear bool) error { switch t.Kind() { case reflect.Slice: - return setSlice(dest, values, clear) + return setSlice(dest, values) case reflect.Map: - return setMap(dest, values, clear) + return setMap(dest, values) default: return fmt.Errorf("cannot insert multiple values into a %v", t) } } -// setSlice parses a sequence of strings and inserts them into a slice. If clear -// is true then any values already in the slice are removed. -func setSlice(dest reflect.Value, values []string, clear bool) error { +// setSlice parses a sequence of strings and inserts them into a slice. The +// slice is cleared first. +func setSlice(dest reflect.Value, values []string) error { var ptr bool elem := dest.Type().Elem() if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { @@ -42,7 +42,7 @@ func setSlice(dest reflect.Value, values []string, clear bool) error { } // clear the slice in case default values exist - if clear && !dest.IsNil() { + if !dest.IsNil() { dest.SetLen(0) } @@ -61,8 +61,8 @@ func setSlice(dest reflect.Value, values []string, clear bool) error { } // setMap parses a sequence of name=value strings and inserts them into a map. -// If clear is true then any values already in the map are removed. -func setMap(dest reflect.Value, values []string, clear bool) error { +// The map is always cleared first. +func setMap(dest reflect.Value, values []string) error { // determine the key and value type var keyIsPtr bool keyType := dest.Type().Key() @@ -79,7 +79,7 @@ func setMap(dest reflect.Value, values []string, clear bool) error { } // clear the slice in case default values exist - if clear && !dest.IsNil() { + if !dest.IsNil() { for _, k := range dest.MapKeys() { dest.SetMapIndex(k, reflect.Value{}) } diff --git a/v2/sequence_test.go b/v2/sequence_test.go index 6383949..519cdec 100644 --- a/v2/sequence_test.go +++ b/v2/sequence_test.go @@ -8,18 +8,17 @@ import ( "github.com/stretchr/testify/require" ) -func TestSetSliceWithoutClearing(t *testing.T) { +func TestAppendToSlice(t *testing.T) { xs := []int{10} - entries := []string{"1", "2", "3"} - err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false) + err := appendToSlice(reflect.ValueOf(&xs).Elem(), "3") require.NoError(t, err) - assert.Equal(t, []int{10, 1, 2, 3}, xs) + assert.Equal(t, []int{10, 3}, xs) } -func TestSetSliceAfterClearing(t *testing.T) { +func TestSetSlice(t *testing.T) { xs := []int{100} entries := []string{"1", "2", "3"} - err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + err := setSlice(reflect.ValueOf(&xs).Elem(), entries) require.NoError(t, err) assert.Equal(t, []int{1, 2, 3}, xs) } @@ -27,14 +26,14 @@ func TestSetSliceAfterClearing(t *testing.T) { func TestSetSliceInvalid(t *testing.T) { xs := []int{100} entries := []string{"invalid"} - err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + err := setSlice(reflect.ValueOf(&xs).Elem(), entries) assert.Error(t, err) } func TestSetSlicePtr(t *testing.T) { var xs []*int entries := []string{"1", "2", "3"} - err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + err := setSlice(reflect.ValueOf(&xs).Elem(), entries) require.NoError(t, err) require.Len(t, xs, 3) assert.Equal(t, 1, *xs[0]) @@ -46,7 +45,7 @@ func TestSetSliceTextUnmarshaller(t *testing.T) { // textUnmarshaler is a struct that captures the length of the string passed to it var xs []*textUnmarshaler entries := []string{"a", "aa", "aaa"} - err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + err := setSlice(reflect.ValueOf(&xs).Elem(), entries) require.NoError(t, err) require.Len(t, xs, 3) assert.Equal(t, 1, xs[0].val) @@ -54,21 +53,19 @@ func TestSetSliceTextUnmarshaller(t *testing.T) { assert.Equal(t, 3, xs[2].val) } -func TestSetMapWithoutClearing(t *testing.T) { +func TestAppendToMap(t *testing.T) { m := map[string]int{"foo": 10} - entries := []string{"a=1", "b=2"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, false) + err := appendToMap(reflect.ValueOf(&m).Elem(), "a=1") require.NoError(t, err) - require.Len(t, m, 3) + require.Len(t, m, 2) assert.Equal(t, 1, m["a"]) - assert.Equal(t, 2, m["b"]) assert.Equal(t, 10, m["foo"]) } func TestSetMapAfterClearing(t *testing.T) { m := map[string]int{"foo": 10} entries := []string{"a=1", "b=2"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) require.NoError(t, err) require.Len(t, m, 2) assert.Equal(t, 1, m["a"]) @@ -79,7 +76,7 @@ func TestSetMapWithKeyPointer(t *testing.T) { // textUnmarshaler is a struct that captures the length of the string passed to it var m map[*string]int entries := []string{"abc=123"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) require.NoError(t, err) require.Len(t, m, 1) } @@ -88,7 +85,7 @@ func TestSetMapWithValuePointer(t *testing.T) { // textUnmarshaler is a struct that captures the length of the string passed to it var m map[string]*int entries := []string{"abc=123"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) require.NoError(t, err) require.Len(t, m, 1) assert.Equal(t, 123, *m["abc"]) @@ -98,7 +95,7 @@ func TestSetMapTextUnmarshaller(t *testing.T) { // textUnmarshaler is a struct that captures the length of the string passed to it var m map[textUnmarshaler]*textUnmarshaler entries := []string{"a=123", "aa=12", "aaa=1"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) require.NoError(t, err) require.Len(t, m, 3) assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}]) @@ -109,14 +106,14 @@ func TestSetMapTextUnmarshaller(t *testing.T) { func TestSetMapInvalidKey(t *testing.T) { var m map[int]int entries := []string{"invalid=123"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) assert.Error(t, err) } func TestSetMapInvalidValue(t *testing.T) { var m map[int]int entries := []string{"123=invalid"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) assert.Error(t, err) } @@ -124,7 +121,7 @@ func TestSetMapMalformed(t *testing.T) { // textUnmarshaler is a struct that captures the length of the string passed to it var m map[string]string entries := []string{"missing_equals_sign"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) assert.Error(t, err) } @@ -135,22 +132,18 @@ func TestSetSliceOrMapErrors(t *testing.T) { // converting a slice to a reflect.Value in this way will make it read only var cannotSet []int dest = reflect.ValueOf(cannotSet) - err = setSliceOrMap(dest, nil, false) + err = setSliceOrMap(dest, nil) assert.Error(t, err) // check what happens when we pass in something that is not a slice or a map var notSliceOrMap string dest = reflect.ValueOf(¬SliceOrMap).Elem() - err = setSliceOrMap(dest, nil, false) + err = setSliceOrMap(dest, nil) assert.Error(t, err) // check what happens when we pass in a pointer to something that is not a slice or a map var stringPtr *string dest = reflect.ValueOf(&stringPtr).Elem() - err = setSliceOrMap(dest, nil, false) + err = setSliceOrMap(dest, nil) assert.Error(t, err) } - -// check that we can accumulate "separate" args across env, cmdline, map, and defaults - -// check what happens if we have a required arg with a default value From b365ec078197b11cb297f49acc09e0f94d2b10cf Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 13:12:38 -0700 Subject: [PATCH 10/19] add processSingle and make it responsible for checking whether an argument has been seen before --- v2/example_test.go | 4 +- v2/parse.go | 118 +++++++++++++++++++-------------------------- 2 files changed, 51 insertions(+), 71 deletions(-) diff --git a/v2/example_test.go b/v2/example_test.go index e769d60..7a62373 100644 --- a/v2/example_test.go +++ b/v2/example_test.go @@ -404,7 +404,7 @@ func Example_errorText() { // output: // Usage: example [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] INPUT [OUTPUT [OUTPUT ...]] - // error: error processing --optimize: strconv.ParseInt: parsing "INVALID": invalid syntax + // error: error processing default value for --optimize: strconv.ParseInt: parsing "INVALID": invalid syntax } // This example shows the error string generated by go-arg when an invalid option is provided @@ -428,7 +428,7 @@ func Example_errorTextForSubcommand() { // output: // Usage: example get [--count COUNT] - // error: error processing --count: strconv.ParseInt: parsing "INVALID": invalid syntax + // error: error processing default value for --count: strconv.ParseInt: parsing "INVALID": invalid syntax } // This example demonstrates use of subcommands diff --git a/v2/parse.go b/v2/parse.go index dcdd353..cc17d4f 100644 --- a/v2/parse.go +++ b/v2/parse.go @@ -207,29 +207,8 @@ func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) } // deal with the case of multiple values - if arg.cardinality == multiple { - // if arg.separate is true then just parse one value and append it - if arg.separate { - if value == "" { - if i+1 == len(args) { - return nil, fmt.Errorf("missing value for %s", token) - } - if isFlag(args[i+1]) { - return nil, fmt.Errorf("missing value for %s", token) - } - value = args[i+1] - i++ - } - - err := appendToSliceOrMap(p.val(arg.dest), value) - if err != nil { - return nil, fmt.Errorf("error processing %s: %v", token, err) - } - p.seen[arg] = true - continue - } - - // if args.separate is not true then consume tokens until next --option + if arg.cardinality == multiple && !arg.separate { + // consume tokens until next --option var values []string if value == "" { for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { @@ -240,6 +219,8 @@ func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) values = append(values, value) } + // TODO: call p.processVector + // this is the first time we can check p.seen because we need to correctly // increment i above, even when we then ignore the value if p.seen[arg] && !overwrite { @@ -251,6 +232,8 @@ func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) if err != nil { return nil, fmt.Errorf("error processing %s: %v", token, err) } + p.seen[arg] = true + continue } @@ -272,17 +255,10 @@ func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) i++ } - // this is the first time we can check p.seen because we need to correctly - // increment i above, even when we then ignore the value - if p.seen[arg] && !overwrite { - continue + // send the value to the argument + if err := p.processSingle(arg, value, overwrite); err != nil { + return nil, err } - - err := scalar.ParseValue(p.val(arg.dest), value) - if err != nil { - return nil, fmt.Errorf("error processing %s: %v", token, err) - } - p.seen[arg] = true } return positionals, nil @@ -311,23 +287,21 @@ func (p *Parser) processPositionals(positionals []string, overwrite bool) error break } if arg.cardinality == multiple { + // TODO: call p.processMultiple if !p.seen[arg] || overwrite { err := setSliceOrMap(p.val(arg.dest), positionals) if err != nil { return fmt.Errorf("error processing %s: %v", arg.field.Name, err) } + p.seen[arg] = true } positionals = nil } else { - if !p.seen[arg] || overwrite { - err := scalar.ParseValue(p.val(arg.dest), positionals[0]) - if err != nil { - return fmt.Errorf("error processing %s: %v", arg.field.Name, err) - } + if err := p.processSingle(arg, positionals[0], overwrite); err != nil { + return err } positionals = positionals[1:] } - p.seen[arg] = true } if len(positionals) > 0 { @@ -364,10 +338,6 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error { // process arguments one-by-one for _, arg := range p.accumulatedArgs { - if p.seen[arg] && !overwrite { - continue - } - if arg.env == "" { continue } @@ -377,7 +347,7 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error { continue } - if arg.cardinality == multiple { + if arg.cardinality == multiple && !arg.separate { // expect a CSV string in an environment // variable in the case of multiple values var values []string @@ -385,28 +355,20 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error { if len(strings.TrimSpace(value)) > 0 { values, err = csv.NewReader(strings.NewReader(value)).Read() if err != nil { - return fmt.Errorf("error reading a CSV string from environment variable %s : %v", arg.env, err) + return fmt.Errorf("error parsing CSV string from environment variable %s: %v", arg.env, err) } } - if arg.separate { - if err = setSliceOrMap(p.val(arg.dest), values); err != nil { - return fmt.Errorf("error processing environment variable %s: %v", arg.env, err) - } - } else { - for _, s := range values { - if err = appendToSliceOrMap(p.val(arg.dest), s); err != nil { - return fmt.Errorf("error processing environment variable %s: %v", arg.env, err) - } - } - } - } else { - if err := scalar.ParseValue(p.val(arg.dest), value); err != nil { + // TODO: call p.processMultiple, respect "overwrite" + if err = setSliceOrMap(p.val(arg.dest), values); err != nil { return fmt.Errorf("error processing environment variable %s: %v", arg.env, err) } + continue } - p.seen[arg] = true + if err := p.processSingle(arg, value, overwrite); err != nil { + return err + } } return nil @@ -428,26 +390,44 @@ func (p *Parser) OverwriteWithDefaults() error { // If overwrite is true then it overwrites existing values. func (p *Parser) processDefaults(overwrite bool) error { for _, arg := range p.accumulatedArgs { - if p.seen[arg] && !overwrite { - continue - } - if arg.defaultVal == "" { continue } - - name := strings.ToLower(arg.field.Name) - if arg.long != "" && !arg.positional { - name = "--" + arg.long + if err := p.processSingle(arg, arg.defaultVal, overwrite); err != nil { + return err } + } - err := scalar.ParseValue(p.val(arg.dest), arg.defaultVal) + return nil +} + +// processSingle parses a single argument, inserts it into the struct, +// and marks the argument as "seen" for the sake of required arguments +// and overwrite semantics. If the argument has been seen before and +// overwrite=false then the value is ignored. +func (p *Parser) processSingle(arg *Argument, value string, overwrite bool) error { + if p.seen[arg] && !overwrite && !arg.separate { + return nil + } + + name := strings.ToLower(arg.field.Name) + if arg.long != "" && !arg.positional { + name = "--" + arg.long + } + + if arg.cardinality == multiple { + err := appendToSliceOrMap(p.val(arg.dest), value) + if err != nil { + return fmt.Errorf("error processing default value for %s: %v", name, err) + } + } else { + err := scalar.ParseValue(p.val(arg.dest), value) if err != nil { return fmt.Errorf("error processing default value for %s: %v", name, err) } - p.seen[arg] = true } + p.seen[arg] = true return nil } From 1cc263f9f213c3384246a86c19c0067083bad352 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 13:23:57 -0700 Subject: [PATCH 11/19] add processSequence and make it responsible for respecting "overwrite" --- v2/parse.go | 86 +++++++++++++++++++++++++++++------------------------ 1 file changed, 47 insertions(+), 39 deletions(-) diff --git a/v2/parse.go b/v2/parse.go index cc17d4f..b21dbf9 100644 --- a/v2/parse.go +++ b/v2/parse.go @@ -206,9 +206,8 @@ func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) return nil, fmt.Errorf("unknown argument %s", token) } - // deal with the case of multiple values + // for the case of multiple values, consume tokens until next --option if arg.cardinality == multiple && !arg.separate { - // consume tokens until next --option var values []string if value == "" { for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { @@ -219,21 +218,9 @@ func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) values = append(values, value) } - // TODO: call p.processVector - - // this is the first time we can check p.seen because we need to correctly - // increment i above, even when we then ignore the value - if p.seen[arg] && !overwrite { - continue + if err := p.processSequence(arg, values, overwrite); err != nil { + return nil, err } - - // store the values into the slice or map - err := setSliceOrMap(p.val(arg.dest), values) - if err != nil { - return nil, fmt.Errorf("error processing %s: %v", token, err) - } - p.seen[arg] = true - continue } @@ -256,7 +243,7 @@ func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) } // send the value to the argument - if err := p.processSingle(arg, value, overwrite); err != nil { + if err := p.processScalar(arg, value, overwrite); err != nil { return nil, err } } @@ -287,21 +274,17 @@ func (p *Parser) processPositionals(positionals []string, overwrite bool) error break } if arg.cardinality == multiple { - // TODO: call p.processMultiple - if !p.seen[arg] || overwrite { - err := setSliceOrMap(p.val(arg.dest), positionals) - if err != nil { - return fmt.Errorf("error processing %s: %v", arg.field.Name, err) - } - p.seen[arg] = true - } - positionals = nil - } else { - if err := p.processSingle(arg, positionals[0], overwrite); err != nil { + if err := p.processSequence(arg, positionals, overwrite); err != nil { return err } - positionals = positionals[1:] + positionals = nil + break } + + if err := p.processScalar(arg, positionals[0], overwrite); err != nil { + return err + } + positionals = positionals[1:] } if len(positionals) > 0 { @@ -351,8 +334,8 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error { // expect a CSV string in an environment // variable in the case of multiple values var values []string - var err error if len(strings.TrimSpace(value)) > 0 { + var err error values, err = csv.NewReader(strings.NewReader(value)).Read() if err != nil { return fmt.Errorf("error parsing CSV string from environment variable %s: %v", arg.env, err) @@ -360,14 +343,14 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error { } // TODO: call p.processMultiple, respect "overwrite" - if err = setSliceOrMap(p.val(arg.dest), values); err != nil { + if err := p.processSequence(arg, values, overwrite); err != nil { return fmt.Errorf("error processing environment variable %s: %v", arg.env, err) } continue } - if err := p.processSingle(arg, value, overwrite); err != nil { - return err + if err := p.processScalar(arg, value, overwrite); err != nil { + return fmt.Errorf("error processing environment variable %s: %v", arg.env, err) } } @@ -393,7 +376,7 @@ func (p *Parser) processDefaults(overwrite bool) error { if arg.defaultVal == "" { continue } - if err := p.processSingle(arg, arg.defaultVal, overwrite); err != nil { + if err := p.processScalar(arg, arg.defaultVal, overwrite); err != nil { return err } } @@ -401,11 +384,10 @@ func (p *Parser) processDefaults(overwrite bool) error { return nil } -// processSingle parses a single argument, inserts it into the struct, -// and marks the argument as "seen" for the sake of required arguments -// and overwrite semantics. If the argument has been seen before and -// overwrite=false then the value is ignored. -func (p *Parser) processSingle(arg *Argument, value string, overwrite bool) error { +// processScalar parses a single argument, inserts it into the struct, +// and marks the argument as "seen" (unless the argument has been seen +// before and overwrite=false, in which case the value is ignored) +func (p *Parser) processScalar(arg *Argument, value string, overwrite bool) error { if p.seen[arg] && !overwrite && !arg.separate { return nil } @@ -431,6 +413,32 @@ func (p *Parser) processSingle(arg *Argument, value string, overwrite bool) erro return nil } +// processSequence parses a sequence argument, inserts it into the struct, +// and marks the argument as "seen" (unless the argument has been seen +// before and overwrite=false, in which case the value is ignored) +func (p *Parser) processSequence(arg *Argument, values []string, overwrite bool) error { + if p.seen[arg] && !overwrite && !arg.separate { + return nil + } + + name := strings.ToLower(arg.field.Name) + if arg.long != "" && !arg.positional { + name = "--" + arg.long + } + + if arg.cardinality != multiple { + panic(fmt.Sprintf("processSequence called for argument %s which has cardinality %v", arg.field.Name, arg.cardinality)) + } + + err := setSliceOrMap(p.val(arg.dest), values) + if err != nil { + return fmt.Errorf("error processing default value for %s: %v", name, err) + } + + p.seen[arg] = true + return nil +} + // Missing returns a list of required arguments that were not provided func (p *Parser) Missing() []*Argument { var missing []*Argument From 84b7154efcc5b9d056a33f9c5e8096d107a43a6b Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 13:25:01 -0700 Subject: [PATCH 12/19] add TestSliceWithEqualsSign --- v2/parse_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/v2/parse_test.go b/v2/parse_test.go index 72efa79..7fb3c2e 100644 --- a/v2/parse_test.go +++ b/v2/parse_test.go @@ -251,6 +251,16 @@ func TestSlice(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"a", "b", "c"}, args.Strings) } + +func TestSliceWithEqualsSign(t *testing.T) { + var args struct { + Strings []string + } + _, err := parse(&args, "--strings=test") + require.NoError(t, err) + assert.Equal(t, []string{"test"}, args.Strings) +} + func TestSliceOfBools(t *testing.T) { var args struct { B []bool From 0769dd58393c44a93e9c1f4b42a7249d1cc598cb Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 13:51:51 -0700 Subject: [PATCH 13/19] add tests for new Process* and OverwriteWith* functions --- v2/precedence_test.go | 271 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 271 insertions(+) create mode 100644 v2/precedence_test.go diff --git a/v2/precedence_test.go b/v2/precedence_test.go new file mode 100644 index 0000000..84646e4 --- /dev/null +++ b/v2/precedence_test.go @@ -0,0 +1,271 @@ +package arg + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// this file contains tests related to the precedence rules for: +// ProcessCommandLine +// ProcessOptions +// ProcessPositionals +// ProcessEnvironment +// ProcessMap +// ProcessSingle +// ProcessSequence +// OverwriteWithCommandLine +// OverwriteWithOptions +// OverwriteWithPositionals +// OverwriteWithEnvironment +// OverwriteWithMap +// +// The Process* functions should not overwrite fields that have +// been previously populated, whereas the OverwriteWith* functions +// should overwrite fields that have been previously populated. + +// check that we can accumulate "separate" args across env, cmdline, map, and defaults + +// check what happens if we have a required arg with a default value + +// add more tests for combinations of separate and cardinality + +// check what happens if we call ProcessCommandLine multiple times with different subcommands + +func TestProcessOptions(t *testing.T) { + var args struct { + Arg string + } + + p, err := NewParser(&args) + require.NoError(t, err) + + _, err = p.ProcessOptions([]string{"program", "--arg=hello"}) + require.NoError(t, err) + assert.Equal(t, "hello", args.Arg) +} + +func TestProcessOptionsDoesNotOverwrite(t *testing.T) { + var args struct { + Arg string `arg:"env"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + err = p.ProcessEnvironment([]string{"ARG=123"}) + require.NoError(t, err) + + _, err = p.ProcessOptions([]string{"--arg=hello"}) + require.NoError(t, err) + + assert.EqualValues(t, "123", args.Arg) +} + +func TestOverwriteWithOptions(t *testing.T) { + var args struct { + Arg string `arg:"env"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + err = p.ProcessEnvironment([]string{"ARG=123"}) + require.NoError(t, err) + + _, err = p.OverwriteWithOptions([]string{"--arg=hello"}) + require.NoError(t, err) + + assert.EqualValues(t, "hello", args.Arg) +} + +func TestProcessPositionals(t *testing.T) { + var args struct { + Arg string `arg:"positional"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + err = p.ProcessPositionals([]string{"hello"}) + require.NoError(t, err) + assert.Equal(t, "hello", args.Arg) +} + +func TestProcessPositionalsDoesNotOverwrite(t *testing.T) { + var args struct { + Arg string `arg:"env,positional"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + err = p.ProcessEnvironment([]string{"ARG=123"}) + require.NoError(t, err) + + err = p.ProcessPositionals([]string{"hello"}) + require.NoError(t, err) + + assert.EqualValues(t, "123", args.Arg) +} + +func TestOverwriteWithPositionals(t *testing.T) { + var args struct { + Arg string `arg:"env,positional"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + err = p.ProcessEnvironment([]string{"ARG=123"}) + require.NoError(t, err) + + err = p.OverwriteWithPositionals([]string{"hello"}) + require.NoError(t, err) + + assert.EqualValues(t, "hello", args.Arg) +} + +func TestProcessCommandLine(t *testing.T) { + var args struct { + Arg string + } + + p, err := NewParser(&args) + require.NoError(t, err) + + err = p.ProcessCommandLine([]string{"program", "--arg=hello"}) + require.NoError(t, err) + assert.Equal(t, "hello", args.Arg) +} + +func TestProcessCommandLineDoesNotOverwrite(t *testing.T) { + var args struct { + Arg string `arg:"env"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + err = p.ProcessEnvironment([]string{"ARG=123"}) + require.NoError(t, err) + + err = p.ProcessCommandLine([]string{"program", "--arg=hello"}) + require.NoError(t, err) + + assert.EqualValues(t, "123", args.Arg) +} + +func TestOverwriteWithCommandLine(t *testing.T) { + var args struct { + Arg string `arg:"env"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + err = p.ProcessEnvironment([]string{"ARG=123"}) + require.NoError(t, err) + + err = p.OverwriteWithCommandLine([]string{"program", "--arg=hello"}) + require.NoError(t, err) + + assert.EqualValues(t, "hello", args.Arg) +} + +func TestProcessEnvironment(t *testing.T) { + var args struct { + Arg string `arg:"env"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + err = p.ProcessEnvironment([]string{"ARG=hello"}) + require.NoError(t, err) + + assert.EqualValues(t, "hello", args.Arg) +} + +func TestProcessEnvironmentDoesNotOverwrite(t *testing.T) { + var args struct { + Arg string `arg:"env"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + _, err = p.ProcessOptions([]string{"--arg=123"}) + require.NoError(t, err) + + err = p.ProcessEnvironment([]string{"ARG=hello"}) + require.NoError(t, err) + + assert.EqualValues(t, "123", args.Arg) +} + +func TestOverwriteWithEnvironment(t *testing.T) { + var args struct { + Arg string `arg:"env"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + _, err = p.ProcessOptions([]string{"--arg=123"}) + require.NoError(t, err) + + err = p.OverwriteWithEnvironment([]string{"ARG=hello"}) + require.NoError(t, err) + + assert.EqualValues(t, "hello", args.Arg) +} + +func TestProcessDefaults(t *testing.T) { + var args struct { + Arg string `default:"hello"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + err = p.ProcessDefaults() + require.NoError(t, err) + + assert.EqualValues(t, "hello", args.Arg) +} + +func TestProcessDefaultsDoesNotOverwrite(t *testing.T) { + var args struct { + Arg string `default:"hello"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + _, err = p.ProcessOptions([]string{"--arg=123"}) + require.NoError(t, err) + + err = p.ProcessDefaults() + require.NoError(t, err) + + assert.EqualValues(t, "123", args.Arg) +} + +func TestOverwriteWithDefaults(t *testing.T) { + var args struct { + Arg string `default:"hello"` + } + + p, err := NewParser(&args) + require.NoError(t, err) + + _, err = p.ProcessOptions([]string{"--arg=123"}) + require.NoError(t, err) + + err = p.OverwriteWithDefaults() + require.NoError(t, err) + + assert.EqualValues(t, "hello", args.Arg) +} From 55d90253290aa07aa248a9e628f34cd6c4791722 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Tue, 4 Oct 2022 13:54:53 -0700 Subject: [PATCH 14/19] rename "accumulatedArgs" -> "accessible" --- v2/construct.go | 10 +++++++--- v2/parse.go | 41 +++++++++++++++++++++-------------------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/v2/construct.go b/v2/construct.go index bd2800e..64bc28d 100644 --- a/v2/construct.go +++ b/v2/construct.go @@ -44,9 +44,9 @@ type Parser struct { epilogue string // epilogue for help text (from the argument struct) // the following fields are updated during processing of command line arguments - leaf *Command // the subcommand we processed last - accumulatedArgs []*Argument // concatenation of the leaf subcommand's arguments plus all ancestors' arguments - seen map[*Argument]bool // the arguments we encountered while processing command line arguments + leaf *Command // the subcommand we processed last + accessible []*Argument // concatenation of the leaf subcommand's arguments plus all ancestors' arguments + seen map[*Argument]bool // the arguments we encountered while processing command line arguments } // Versioned is the interface that the destination struct should implement to @@ -123,6 +123,10 @@ func NewParser(dest interface{}, options ...ParserOption) (*Parser, error) { root: reflect.ValueOf(dest), cmd: cmd, } + // copy the args for the root command into "accessible", which will + // grow each time we open up a subcommand + p.accessible = make([]*Argument, len(p.cmd.args)) + copy(p.accessible, p.cmd.args) // check for version, prologue, and epilogue if dest, ok := dest.(Versioned); ok { diff --git a/v2/parse.go b/v2/parse.go index b21dbf9..2b05171 100644 --- a/v2/parse.go +++ b/v2/parse.go @@ -102,7 +102,11 @@ func (p *Parser) Parse(args, env []string) error { // never overwrites arguments previously seen in a call to any Process* // function. func (p *Parser) ProcessCommandLine(args []string) error { - positionals, err := p.ProcessOptions(args) + if len(args) == 0 { + return nil + } + + positionals, err := p.ProcessOptions(args[1:]) if err != nil { return err } @@ -112,7 +116,11 @@ func (p *Parser) ProcessCommandLine(args []string) error { // OverwriteWithCommandLine is like ProcessCommandLine but it overwrites // any previously seen values. func (p *Parser) OverwriteWithCommandLine(args []string) error { - positionals, err := p.OverwriteWithOptions(args) + if len(args) == 0 { + return nil + } + + positionals, err := p.OverwriteWithOptions(args[1:]) if err != nil { return err } @@ -121,10 +129,8 @@ func (p *Parser) OverwriteWithCommandLine(args []string) error { // ProcessOptions processes options but not positionals from the // command line. Positionals are returned and can be passed to -// ProcessPositionals. This function ignores the first element of args, -// which is assumed to be the program name itself. Arguments seen -// in a previous call to any Process* or OverwriteWith* functions -// are ignored. +// ProcessPositionals. Arguments seen in a previous call to any +// Process* or OverwriteWith* functions are ignored. func (p *Parser) ProcessOptions(args []string) ([]string, error) { return p.processOptions(args, false) } @@ -136,19 +142,15 @@ func (p *Parser) OverwriteWithOptions(args []string) ([]string, error) { } func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) { + // note that p.cmd.args has already been copied into p.accessible in NewParser + // union of args for the chain of subcommands encountered so far p.leaf = p.cmd - - // we will add to this list each time we expand a subcommand - p.accumulatedArgs = make([]*Argument, len(p.leaf.args)) - copy(p.accumulatedArgs, p.leaf.args) - - // process each string from the command line var allpositional bool var positionals []string // must use explicit for loop, not range, because we manipulate i inside the loop - for i := 1; i < len(args); i++ { + for i := 0; i < len(args); i++ { token := args[i] // the "--" token indicates that all further tokens should be treated as positionals @@ -178,7 +180,7 @@ func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) } // add the new options to the set of allowed options - p.accumulatedArgs = append(p.accumulatedArgs, subcmd.args...) + p.accessible = append(p.accessible, subcmd.args...) p.leaf = subcmd continue } @@ -201,7 +203,7 @@ func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) // look up the arg for this option (note that the "args" slice changes as // we expand subcommands so it is better not to use a map) - arg := findOption(p.accumulatedArgs, opt) + arg := findOption(p.accessible, opt) if arg == nil { return nil, fmt.Errorf("unknown argument %s", token) } @@ -266,7 +268,7 @@ func (p *Parser) OverwriteWithPositionals(positionals []string) error { } func (p *Parser) processPositionals(positionals []string, overwrite bool) error { - for _, arg := range p.accumulatedArgs { + for _, arg := range p.accessible { if !arg.positional { continue } @@ -320,7 +322,7 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error { } // process arguments one-by-one - for _, arg := range p.accumulatedArgs { + for _, arg := range p.accessible { if arg.env == "" { continue } @@ -342,7 +344,6 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error { } } - // TODO: call p.processMultiple, respect "overwrite" if err := p.processSequence(arg, values, overwrite); err != nil { return fmt.Errorf("error processing environment variable %s: %v", arg.env, err) } @@ -372,7 +373,7 @@ func (p *Parser) OverwriteWithDefaults() error { // processDefaults assigns default values to all fields in all expanded subcommands. // If overwrite is true then it overwrites existing values. func (p *Parser) processDefaults(overwrite bool) error { - for _, arg := range p.accumulatedArgs { + for _, arg := range p.accessible { if arg.defaultVal == "" { continue } @@ -442,7 +443,7 @@ func (p *Parser) processSequence(arg *Argument, values []string, overwrite bool) // Missing returns a list of required arguments that were not provided func (p *Parser) Missing() []*Argument { var missing []*Argument - for _, arg := range p.accumulatedArgs { + for _, arg := range p.accessible { if arg.required && !p.seen[arg] { missing = append(missing, arg) } From 60a0117880dc7db4c397dbf2e893a8cebe165ef2 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Fri, 7 Oct 2022 12:51:27 -0700 Subject: [PATCH 15/19] update readme for v2 (still has some TODOs) --- README.md | 325 ++++++++++++++++++++++++++---------------------------- 1 file changed, 159 insertions(+), 166 deletions(-) diff --git a/README.md b/README.md index f105b17..887eb9e 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,12 @@ Declare command line arguments for your program by defining a struct. +```go +import "github.com/go-arg/v2 +``` + +TODO + ```go var args struct { Foo string @@ -33,7 +39,7 @@ hello true ### Installation ```shell -go get github.com/alexflint/go-arg +go get github.com/alexflint/go-arg/v2 ``` ### Required arguments @@ -90,36 +96,6 @@ $ WORKERS=4 ./example --workers=6 Workers: 6 ``` -You can also override the name of the environment variable: - -```go -var args struct { - Workers int `arg:"env:NUM_WORKERS"` -} -arg.MustParse(&args) -fmt.Println("Workers:", args.Workers) -``` - -``` -$ NUM_WORKERS=4 ./example -Workers: 4 -``` - -You can provide multiple values using the CSV (RFC 4180) format: - -```go -var args struct { - Workers []int `arg:"env"` -} -arg.MustParse(&args) -fmt.Println("Workers:", args.Workers) -``` - -``` -$ WORKERS='1,99' ./example -Workers: [1 99] -``` - ### Usage strings ```go var args struct { @@ -158,47 +134,23 @@ var args struct { arg.MustParse(&args) ``` -### Default values (before v1.2) +### Overriding the name of an environment variable ```go var args struct { - Foo string - Bar bool -} -arg.Foo = "abc" -arg.MustParse(&args) -``` - -### Combining command line options, environment variables, and default values - -You can combine command line arguments, environment variables, and default values. Command line arguments take precedence over environment variables, which take precedence over default values. This means that we check whether a certain option was provided on the command line, then if not, we check for an environment variable (only if an `env` tag was provided), then if none is found, we check for a `default` tag containing a default value. - -```go -var args struct { - Test string `arg:"-t,env:TEST" default:"something"` + Workers int `arg:"env:NUM_WORKERS"` } arg.MustParse(&args) +fmt.Println("Workers:", args.Workers) ``` -#### Ignoring environment variables and/or default values - -The values in an existing structure can be kept in-tact by ignoring environment -variables and/or default values. - -```go -var args struct { - Test string `arg:"-t,env:TEST" default:"something"` -} - -p, err := arg.NewParser(arg.Config{ - IgnoreEnv: true, - IgnoreDefault: true, -}, &args) - -err = p.Parse(os.Args) +``` +$ NUM_WORKERS=4 ./example +Workers: 4 ``` ### Arguments with multiple values + ```go var args struct { Database string @@ -213,23 +165,6 @@ fmt.Printf("Fetching the following IDs from %s: %q", args.Database, args.IDs) Fetching the following IDs from foo: [1 2 3] ``` -### Arguments that can be specified multiple times, mixed with positionals -```go -var args struct { - Commands []string `arg:"-c,separate"` - Files []string `arg:"-f,separate"` - Databases []string `arg:"positional"` -} -arg.MustParse(&args) -``` - -```shell -./example -c cmd1 db1 -f file1 db2 -c cmd2 -f file2 -f file3 db3 -c cmd3 -Commands: [cmd1 cmd2 cmd3] -Files [file1 file2 file3] -Databases [db1 db2 db3] -``` - ### Arguments with keys and values ```go var args struct { @@ -266,7 +201,7 @@ error: you must provide either --foo or --bar ```go type args struct { - ... + // ... } func (args) Version() string { @@ -353,7 +288,7 @@ The following types may be used as arguments: - maps using any of the above as keys and values - any type that implements `encoding.TextUnmarshaler` -### Custom parsing +### Custom parsing Implement `encoding.TextUnmarshaler` to define your own parsing logic. @@ -391,73 +326,121 @@ Usage: example [--name NAME] error: error processing --name: missing period in "oops" ``` -### Custom parsing with default values +### Slice-valued environment variables -Implement `encoding.TextMarshaler` to define your own default value strings: - -```go -// Accepts command line arguments of the form "head.tail" -type NameDotName struct { - Head, Tail string -} - -func (n *NameDotName) UnmarshalText(b []byte) error { - // same as previous example -} - -// this is only needed if you want to display a default value in the usage string -func (n *NameDotName) MarshalText() ([]byte, error) { - return []byte(fmt.Sprintf("%s.%s", n.Head, n.Tail)), nil -} - -func main() { - var args struct { - Name NameDotName `default:"file.txt"` - } - arg.MustParse(&args) - fmt.Printf("%#v\n", args.Name) -} -``` -```shell -$ ./example --help -Usage: test [--name NAME] - -Options: - --name NAME [default: file.txt] - --help, -h display this help and exit - -$ ./example -main.NameDotName{Head:"file", Tail:"txt"} -``` - -### Custom placeholders - -*Introduced in version 1.3.0* - -Use the `placeholder` tag to control which placeholder text is used in the usage text. +You can provide multiple values using the CSV (RFC 4180) format: ```go var args struct { - Input string `arg:"positional" placeholder:"SRC"` - Output []string `arg:"positional" placeholder:"DST"` - Optimize int `arg:"-O" help:"optimization level" placeholder:"LEVEL"` - MaxJobs int `arg:"-j" help:"maximum number of simultaneous jobs" placeholder:"N"` + Workers []int `arg:"env"` +} +arg.MustParse(&args) +fmt.Println("Workers:", args.Workers) +``` + +``` +$ WORKERS='1,99' ./example +Workers: [1 99] +``` + +### Parsing command line tokens and environment variables from a slice + +You can override the command line tokens and environment variables processed by go-arg: + +```go +var args struct { + Samsara int + Nirvana float64 `arg:"env:NIRVANA"` +} +p, err := arg.NewParser(&args) +if err != nil { + log.Fatal(err) +} +cmdline := []string{"./thisprogram", "--samsara=123"} +environ := []string{"NIRVANA=45.6"} +err = p.Parse(cmdline, environ) +if err != nil { + log.Fatal(err) +} +``` +``` +./example +SAMSARA: 123 +NIRVANA: 45.6 +``` + +### Configuration files + +TODO + +### Combining command line options, environment variables, and default values + +By default, command line arguments take precedence over environment variables, which take precedence over default values. This means that we check whether a certain option was provided on the command line, then if not, we check for an environment variable (only if an `env` tag was provided), then, if none is found, we check for a `default` tag. + +```go +var args struct { + Test string `arg:"-t,env:TEST" default:"something"` } arg.MustParse(&args) ``` + +### Changing precedence of command line options, environment variables, and default values + +You can use the low-level functions `Process*` and `OverwriteWith*` to control which things override which other things. Here is an example in which environment variables take precedence over command line options, which is the opposite of the default behavior: + +```go +var args struct { + Test string `arg:"env:TEST"` +} + +p, err := arg.NewParser(&args) +if err != nil { + log.Fatal(err) +} + +err = p.ParseCommandLine(os.Args) +if err != nil { + p.Fail(err.Error()) +} + +err = p.OverwriteWithEnvironment(os.Environ()) +if err != nil { + p.Fail(err.Error()) +} + +err = p.Validate() +if err != nil { + p.Fail(err.Error()) +} + +fmt.Printf("test=%s\n", args.Test) +``` +``` +TEST=value_from_env ./example --test=value_from_option +test=value_from_env +``` + +### Ignoring environment variables + +TODO + +### Ignoring default values + +TODO + +### Arguments that can be specified multiple times +```go +var args struct { + Commands []string `arg:"-c,separate"` + Files []string `arg:"-f,separate"` +} +arg.MustParse(&args) +``` + ```shell -$ ./example -h -Usage: example [--optimize LEVEL] [--maxjobs N] SRC [DST [DST ...]] - -Positional arguments: - SRC - DST - -Options: - --optimize LEVEL, -O LEVEL - optimization level - --maxjobs N, -j N maximum number of simultaneous jobs - --help, -h display this help and exit +./example -c cmd1 -f file1 -c cmd2 -f file2 -f file3 -c cmd3 +Commands: [cmd1 cmd2 cmd3] +Files [file1 file2 file3] ``` ### Description strings @@ -521,18 +504,14 @@ For more information visit github.com/alexflint/go-arg ### Subcommands -*Introduced in version 1.1.0* - -Subcommands are commonly used in tools that wish to group multiple functions into a single program. An example is the `git` tool: +Subcommands are commonly used in tools that group multiple functions into a single program. An example is the `git` tool: ```shell $ git checkout [arguments specific to checking out code] -$ git commit [arguments specific to committing] -$ git push [arguments specific to pushing] +$ git commit [arguments specific to committing code] +$ git push [arguments specific to pushing code] ``` -The strings "checkout", "commit", and "push" are different from simple positional arguments because the options available to the user change depending on which subcommand they choose. - -This can be implemented with `go-arg` as follows: +This can be implemented with `go-arg` with the `arg:"subcommand"` tag: ```go type CheckoutCmd struct { @@ -567,14 +546,9 @@ case args.Push != nil: } ``` -Some additional rules apply when working with subcommands: -* The `subcommand` tag can only be used with fields that are pointers to structs -* Any struct that contains a subcommand must not contain any positionals +Note that the `subcommand` tag can only be used with fields that are pointers to structs, and that any struct that contains subcommands cannot also contain positionals. -This package allows to have a program that accepts subcommands, but also does something else -when no subcommands are specified. -If on the other hand you want the program to terminate when no subcommands are specified, -the recommended way is: +### Terminating when no subcommands are specified ```go p := arg.MustParse(&args) @@ -583,20 +557,39 @@ if p.Subcommand() == nil { } ``` +### Customizing placeholder strings + +Use the `placeholder` tag to control which placeholder text is used in the usage text. + +```go +var args struct { + Input string `arg:"positional" placeholder:"SRC"` + Output []string `arg:"positional" placeholder:"DST"` + Optimize int `arg:"-O" placeholder:"LEVEL"` + MaxJobs int `arg:"-j" placeholder:"N"` +} +arg.MustParse(&args) +``` +```shell +$ ./example -h +Usage: example [--optimize LEVEL] [--maxjobs N] SRC [DST [DST ...]] + +Positional arguments: + SRC + DST + +Options: + --optimize LEVEL, -O LEVEL + --maxjobs N, -j N + --help, -h display this help and exit +``` + ### API Documentation https://godoc.org/github.com/alexflint/go-arg -### Rationale +### Migrating from v1.x -There are many command line argument parsing libraries for Go, including one in the standard library, so why build another? +Migrating IgnoreEnv to passing a nil environ -The `flag` library that ships in the standard library seems awkward to me. Positional arguments must preceed options, so `./prog x --foo=1` does what you expect but `./prog --foo=1 x` does not. It also does not allow arguments to have both long (`--foo`) and short (`-f`) forms. - -Many third-party argument parsing libraries are great for writing sophisticated command line interfaces, but feel to me like overkill for a simple script with a few flags. - -The idea behind `go-arg` is that Go already has an excellent way to describe data structures using structs, so there is no need to develop additional levels of abstraction. Instead of one API to specify which arguments your program accepts, and then another API to get the values of those arguments, `go-arg` replaces both with a single struct. - -### Backward compatibility notes - -Earlier versions of this library required the help text to be part of the `arg` tag. This is still supported but is now deprecated. Instead, you should use a separate `help` tag, described above, which removes most of the limits on the text you can write. In particular, you will need to use the new `help` tag if your help text includes any commas. +Migrating from IgnoreDefault to calling ProcessCommandLine \ No newline at end of file From 47ff44303fbe55cf9c98030670bff65667639d3c Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Fri, 7 Oct 2022 12:51:55 -0700 Subject: [PATCH 16/19] drop support for help tag inside arg tag --- v2/construct.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/v2/construct.go b/v2/construct.go index 64bc28d..fe3ad4b 100644 --- a/v2/construct.go +++ b/v2/construct.go @@ -234,8 +234,6 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*Command, error) { arg.positional = true case key == "separate": arg.separate = true - case key == "help": // deprecated - arg.help = value case key == "env": // Use override name if provided if value != "" { From 2ffe24630bbdbd070d324b36ca415afb47a4a90f Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Fri, 7 Oct 2022 14:14:01 -0700 Subject: [PATCH 17/19] add mdtest command to generate and run tests from a markdown file --- mdtest/example1.go.tpl | 11 +++ mdtest/example2.go.tpl | 9 +++ mdtest/mdtest.go | 179 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 199 insertions(+) create mode 100644 mdtest/example1.go.tpl create mode 100644 mdtest/example2.go.tpl create mode 100644 mdtest/mdtest.go diff --git a/mdtest/example1.go.tpl b/mdtest/example1.go.tpl new file mode 100644 index 0000000..a6b12c6 --- /dev/null +++ b/mdtest/example1.go.tpl @@ -0,0 +1,11 @@ +package main + +import ( + "github.com/alexflint/go-arg/v2" + {{if contains .Code "fmt."}}"fmt"{{end}} + {{if contains .Code "strings."}}"strings"{{end}} +) + +func main() { + {{.Code}} +} diff --git a/mdtest/example2.go.tpl b/mdtest/example2.go.tpl new file mode 100644 index 0000000..5cbdd84 --- /dev/null +++ b/mdtest/example2.go.tpl @@ -0,0 +1,9 @@ +package main + +import ( + "github.com/alexflint/go-arg/v2" + {{if contains .Code "fmt."}}"fmt"{{end}} + {{if contains .Code "strings."}}"strings"{{end}} +) + +{{.Code}} diff --git a/mdtest/mdtest.go b/mdtest/mdtest.go new file mode 100644 index 0000000..ed22146 --- /dev/null +++ b/mdtest/mdtest.go @@ -0,0 +1,179 @@ +// mdtest executes code blocks in markdown and checks that they run as expected +package main + +import ( + "bytes" + "context" + _ "embed" + "fmt" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" + "text/template" + "time" + + "github.com/alexflint/go-arg/v2" +) + +// var pattern = "```go(.*)```\\s*```\\s*\\$(.*)\\n(.*)```" +var pattern = "(?s)```go([^`]*?)```\\s*```([^`]*?)```" //go(.*)```\\s*```\\s*\\$(.*)\\n(.*)```" + +var re = regexp.MustCompile(pattern) + +var funcs = map[string]any{ + "contains": strings.Contains, +} + +//go:embed example1.go.tpl +var templateSource1 string + +//go:embed example2.go.tpl +var templateSource2 string + +var t1 = template.Must(template.New("example1.go").Funcs(funcs).Parse(templateSource1)) +var t2 = template.Must(template.New("example2.go").Funcs(funcs).Parse(templateSource2)) + +type payload struct { + Code string +} + +func runCode(ctx context.Context, code []byte, cmd string) ([]byte, error) { + dir, err := os.MkdirTemp("", "") + if err != nil { + return nil, fmt.Errorf("error creating temp dir to build and run code: %w", err) + } + + fmt.Println(dir) + fmt.Println(strings.Repeat("-", 80)) + + srcpath := filepath.Join(dir, "src.go") + binpath := filepath.Join(dir, "example") + + // If the code contains a main function then use t2, otherwise use t1 + t := t1 + if strings.Contains(string(code), "func main") { + t = t2 + } + + var b bytes.Buffer + err = t.Execute(&b, payload{Code: string(code)}) + if err != nil { + return nil, fmt.Errorf("error executing template for source file: %w", err) + } + + fmt.Println(b.String()) + fmt.Println(strings.Repeat("-", 80)) + + err = os.WriteFile(srcpath, b.Bytes(), os.ModePerm) + if err != nil { + return nil, fmt.Errorf("error writing temporary source file: %w", err) + } + + compiler, err := exec.LookPath("go") + if err != nil { + return nil, fmt.Errorf("could not find path to go compiler: %w", err) + } + + buildCmd := exec.CommandContext(ctx, compiler, "build", "-o", binpath, srcpath) + out, err := buildCmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("error building source: %w. Compiler said:\n%s", err, string(out)) + } + + // replace "./example" with full path to compiled program + var env, args []string + var found bool + for _, part := range strings.Split(cmd, " ") { + if found { + args = append(args, part) + } else if part == "./example" { + found = true + } else { + env = append(env, part) + } + } + + runCmd := exec.CommandContext(ctx, binpath, args...) + runCmd.Env = env + output, err := runCmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("error runing example: %w. Program said:\n%s", err, string(output)) + } + + // Clean up the temp dir + if err := os.RemoveAll(dir); err != nil { + return nil, fmt.Errorf("error deleting temp dir: %w", err) + } + + return output, nil +} + +func Main() error { + ctx := context.Background() + + var args struct { + Input string `arg:"positional,required"` + } + arg.MustParse(&args) + + buf, err := os.ReadFile(args.Input) + if err != nil { + return err + } + + fmt.Println(strings.Repeat("=", 80)) + + matches := re.FindAllSubmatchIndex(buf, -1) + for k, match := range matches { + codebegin, codeend := match[2], match[3] + code := buf[codebegin:codeend] + + shellbegin, shellend := match[4], match[5] + shell := buf[shellbegin:shellend] + + lines := strings.Split(string(shell), "\n") + for i := 0; i < len(lines); i++ { + if strings.HasPrefix(lines[i], "$") && strings.Contains(lines[i], "./example") { + cmd := strings.TrimSpace(strings.TrimPrefix(lines[i], "$")) + + var output []string + i++ + for i < len(lines) && !strings.HasPrefix(lines[i], "$") { + output = append(output, lines[i]) + i++ + } + + expected := strings.TrimSpace(strings.Join(output, "\n")) + + fmt.Println(string(code)) + fmt.Println(strings.Repeat("-", 80)) + fmt.Println(string(cmd)) + fmt.Println(strings.Repeat("-", 80)) + fmt.Println(string(expected)) + fmt.Println(strings.Repeat("-", 80)) + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + actual, err := runCode(ctx, code, cmd) + if err != nil { + return fmt.Errorf("error running example %d: %w\nCode was:\n%s", k, err, string(code)) + } + + fmt.Println(string(actual)) + fmt.Println(strings.Repeat("=", 80)) + } + } + } + fmt.Printf("found %d matches\n", len(matches)) + return nil +} + +func main() { + if err := Main(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} From f2539d7ad233a95ec1e9680b34c1f70d1f5ff236 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Sat, 29 Oct 2022 12:28:46 -0400 Subject: [PATCH 18/19] add go.work -- maybe remove before merge? --- go.work | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 go.work diff --git a/go.work b/go.work new file mode 100644 index 0000000..3544a52 --- /dev/null +++ b/go.work @@ -0,0 +1,2 @@ +use . +use ./v2 From c046f49e125dfe0596bf48c6323a5f60a5371e1f Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Sat, 29 Oct 2022 15:28:22 -0400 Subject: [PATCH 19/19] drop go.work and add it to .gitignore --- .gitignore | 1 + go.work | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) delete mode 100644 go.work diff --git a/.gitignore b/.gitignore index daf913b..ab343f1 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ _testmain.go *.exe *.test *.prof +go.work diff --git a/go.work b/go.work deleted file mode 100644 index 3544a52..0000000 --- a/go.work +++ /dev/null @@ -1,2 +0,0 @@ -use . -use ./v2