Merge pull request #82 from alexflint/subcommand-impl

Add support for subcommands
This commit is contained in:
Alex Flint 2019-08-06 16:58:46 -07:00 committed by GitHub
commit 8baf7040d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1222 additions and 235 deletions

View File

@ -1,7 +1,7 @@
language: go
go:
- "1.10"
- "1.11"
- "1.12"
- tip
env:
- GO111MODULE=on # will only be used in go 1.11

View File

@ -353,6 +353,59 @@ Options:
--help, -h display this help and exit
```
### Subcommands
*Introduced in `v1.1.0`*
Subcommands are commonly used in tools that wish to group multiple functions into a single program. An example is the `git` tool:
```shell
$ git checkout [arguments specific to checking out code]
$ git commit [arguments specific to committing]
$ git push [arguments specific to pushing]
```
The strings "checkout", "commit", and "push" are different from simple positional arguments because the options available to the user change depending on which subcommand they choose.
This can be implemented with `go-arg` as follows:
```go
type CheckoutCmd struct {
Branch string `arg:"positional"`
Track bool `arg:"-t"`
}
type CommitCmd struct {
All bool `arg:"-a"`
Message string `arg:"-m"`
}
type PushCmd struct {
Remote string `arg:"positional"`
Branch string `arg:"positional"`
SetUpstream bool `arg:"-u"`
}
var args struct {
Checkout *CheckoutCmd `arg:"subcommand:checkout"`
Commit *CommitCmd `arg:"subcommand:commit"`
Push *PushCmd `arg:"subcommand:push"`
Quiet bool `arg:"-q"` // this flag is global to all subcommands
}
arg.MustParse(&args)
switch {
case args.Checkout != nil:
fmt.Printf("checkout requested for branch %s\n", args.Checkout.Branch)
case args.Commit != nil:
fmt.Printf("commit requested with message \"%s\"\n", args.Commit.Message)
case args.Push != nil:
fmt.Printf("push requested from %s to %s\n", args.Push.Branch, args.Push.Remote)
}
```
Some additional rules apply when working with subcommands:
* The `subcommand` tag can only be used with fields that are pointers to structs
* Any struct that contains a subcommand must not contain any positionals
### API Documentation
https://godoc.org/github.com/alexflint/go-arg

View File

@ -104,7 +104,7 @@ func Example_multipleMixed() {
}
// This example shows the usage string generated by go-arg
func Example_usageString() {
func Example_helpText() {
// These are the args you would pass in on the command line
os.Args = split("./example --help")
@ -135,3 +135,167 @@ func Example_usageString() {
// optimization level
// --help, -h display this help and exit
}
// This example shows the usage string generated by go-arg when using subcommands
func Example_helpTextWithSubcommand() {
// These are the args you would pass in on the command line
os.Args = split("./example --help")
type getCmd struct {
Item string `arg:"positional" help:"item to fetch"`
}
type listCmd struct {
Format string `help:"output format"`
Limit int
}
var args struct {
Verbose bool
Get *getCmd `arg:"subcommand" help:"fetch an item and print it"`
List *listCmd `arg:"subcommand" help:"list available items"`
}
// This is only necessary when running inside golang's runnable example harness
osExit = func(int) {}
MustParse(&args)
// output:
// Usage: example [--verbose]
//
// Options:
// --verbose
// --help, -h display this help and exit
//
// Commands:
// get fetch an item and print it
// list list available items
}
// This example shows the usage string generated by go-arg when using subcommands
func Example_helpTextForSubcommand() {
// These are the args you would pass in on the command line
os.Args = split("./example get --help")
type getCmd struct {
Item string `arg:"positional" help:"item to fetch"`
}
type listCmd struct {
Format string `help:"output format"`
Limit int
}
var args struct {
Verbose bool
Get *getCmd `arg:"subcommand" help:"fetch an item and print it"`
List *listCmd `arg:"subcommand" help:"list available items"`
}
// This is only necessary when running inside golang's runnable example harness
osExit = func(int) {}
MustParse(&args)
// output:
// Usage: example get ITEM
//
// Positional arguments:
// ITEM item to fetch
//
// Options:
// --help, -h display this help and exit
}
// This example shows the error string generated by go-arg when an invalid option is provided
func Example_errorText() {
// These are the args you would pass in on the command line
os.Args = split("./example --optimize INVALID")
var args struct {
Input string `arg:"positional"`
Output []string `arg:"positional"`
Verbose bool `arg:"-v" help:"verbosity level"`
Dataset string `help:"dataset to use"`
Optimize int `arg:"-O,help:optimization level"`
}
// This is only necessary when running inside golang's runnable example harness
osExit = func(int) {}
stderr = os.Stdout
MustParse(&args)
// output:
// Usage: example [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] INPUT [OUTPUT [OUTPUT ...]]
// error: error processing --optimize: strconv.ParseInt: parsing "INVALID": invalid syntax
}
// This example shows the error string generated by go-arg when an invalid option is provided
func Example_errorTextForSubcommand() {
// These are the args you would pass in on the command line
os.Args = split("./example get --count INVALID")
type getCmd struct {
Count int
}
var args struct {
Get *getCmd `arg:"subcommand"`
}
// This is only necessary when running inside golang's runnable example harness
osExit = func(int) {}
stderr = os.Stdout
MustParse(&args)
// output:
// Usage: example get [--count COUNT]
// error: error processing --count: strconv.ParseInt: parsing "INVALID": invalid syntax
}
// This example demonstrates use of subcommands
func Example_subcommand() {
// These are the args you would pass in on the command line
os.Args = split("./example commit -a -m what-this-commit-is-about")
type CheckoutCmd struct {
Branch string `arg:"positional"`
Track bool `arg:"-t"`
}
type CommitCmd struct {
All bool `arg:"-a"`
Message string `arg:"-m"`
}
type PushCmd struct {
Remote string `arg:"positional"`
Branch string `arg:"positional"`
SetUpstream bool `arg:"-u"`
}
var args struct {
Checkout *CheckoutCmd `arg:"subcommand:checkout"`
Commit *CommitCmd `arg:"subcommand:commit"`
Push *PushCmd `arg:"subcommand:push"`
Quiet bool `arg:"-q"` // this flag is global to all subcommands
}
// This is only necessary when running inside golang's runnable example harness
osExit = func(int) {}
stderr = os.Stdout
MustParse(&args)
switch {
case args.Checkout != nil:
fmt.Printf("checkout requested for branch %s\n", args.Checkout.Branch)
case args.Commit != nil:
fmt.Printf("commit requested with message \"%s\"\n", args.Commit.Message)
case args.Push != nil:
fmt.Printf("push requested from %s to %s\n", args.Push.Branch, args.Push.Remote)
}
// output:
// commit requested with message "what-this-commit-is-about"
}

522
parse.go
View File

@ -1,7 +1,6 @@
package arg
import (
"encoding"
"encoding/csv"
"errors"
"fmt"
@ -16,9 +15,36 @@ import (
// to enable monkey-patching during tests
var osExit = os.Exit
// path represents a sequence of steps to find the output location for an
// argument or subcommand in the final destination struct
type path struct {
root int // index of the destination struct
fields []string // sequence of struct field names to traverse
}
// String gets a string representation of the given path
func (p path) String() string {
if len(p.fields) == 0 {
return "args"
}
return "args." + strings.Join(p.fields, ".")
}
// Child gets a new path representing a child of this path.
func (p path) Child(child string) path {
// copy the entire slice of fields to avoid possible slice overwrite
subfields := make([]string, len(p.fields)+1)
copy(subfields, append(p.fields, child))
return path{
root: p.root,
fields: subfields,
}
}
// spec represents a command line option
type spec struct {
dest reflect.Value
dest path
typ reflect.Type
long string
short string
multiple bool
@ -30,6 +56,16 @@ type spec struct {
boolean bool
}
// command represents a named subcommand, or the top-level command
type command struct {
name string
help string
dest path
specs []*spec
subcommands []*command
parent *command
}
// ErrHelp indicates that -h or --help were provided
var ErrHelp = errors.New("help requested by user")
@ -42,18 +78,19 @@ func MustParse(dest ...interface{}) *Parser {
if err != nil {
fmt.Println(err)
osExit(-1)
return nil // just in case osExit was monkey-patched
}
err = p.Parse(flags())
switch {
case err == ErrHelp:
p.WriteHelp(os.Stdout)
p.writeHelpForCommand(os.Stdout, p.lastCmd)
osExit(0)
case err == ErrVersion:
fmt.Println(p.version)
osExit(0)
case err != nil:
p.Fail(err.Error())
p.failWithCommand(err.Error(), p.lastCmd)
}
return p
@ -83,10 +120,14 @@ type Config struct {
// Parser represents a set of command line options with destination values
type Parser struct {
specs []*spec
cmd *command
roots []reflect.Value
config Config
version string
description string
// the following fields change curing processing of command line arguments
lastCmd *command
}
// Versioned is the interface that the destination struct should implement to
@ -106,66 +147,180 @@ type Described interface {
}
// walkFields calls a function for each field of a struct, recursively expanding struct fields.
func walkFields(v reflect.Value, visit func(field reflect.StructField, val reflect.Value, owner reflect.Type) bool) {
t := v.Type()
func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) {
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
val := v.Field(i)
expand := visit(field, val, t)
expand := visit(field, t)
if expand && field.Type.Kind() == reflect.Struct {
walkFields(val, visit)
walkFields(field.Type, visit)
}
}
}
// NewParser constructs a parser from a list of destination structs
func NewParser(config Config, dests ...interface{}) (*Parser, error) {
// first pick a name for the command for use in the usage text
var name string
switch {
case config.Program != "":
name = config.Program
case len(os.Args) > 0:
name = filepath.Base(os.Args[0])
default:
name = "program"
}
// construct a parser
p := Parser{
cmd: &command{name: name},
config: config,
}
// make a list of roots
for _, dest := range dests {
p.roots = append(p.roots, reflect.ValueOf(dest))
}
// process each of the destination values
for i, dest := range dests {
t := reflect.TypeOf(dest)
if t.Kind() != reflect.Ptr {
panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t))
}
cmd, err := cmdFromStruct(name, path{root: i}, t)
if err != nil {
return nil, err
}
p.cmd.specs = append(p.cmd.specs, cmd.specs...)
p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...)
if dest, ok := dest.(Versioned); ok {
p.version = dest.Version()
}
if dest, ok := dest.(Described); ok {
p.description = dest.Description()
}
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", v.Type()))
}
v = v.Elem()
if v.Kind() != reflect.Struct {
panic(fmt.Sprintf("%T is not a struct pointer", dest))
}
return &p, nil
}
func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
// commands can only be created from pointers to structs
if t.Kind() != reflect.Ptr {
return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s",
dest, t.Kind())
}
t = t.Elem()
if t.Kind() != reflect.Struct {
return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a pointer to %s",
dest, t.Kind())
}
cmd := command{
name: name,
dest: dest,
}
var errs []string
walkFields(t, func(field reflect.StructField, t reflect.Type) bool {
// Check for the ignore switch in the tag
tag := field.Tag.Get("arg")
if tag == "-" {
return false
}
var errs []string
walkFields(v, func(field reflect.StructField, val reflect.Value, t reflect.Type) bool {
// Check for the ignore switch in the tag
tag := field.Tag.Get("arg")
if tag == "-" {
return false
}
// If this is an embedded struct then recurse into its fields
if field.Anonymous && field.Type.Kind() == reflect.Struct {
return true
}
// If this is an embedded struct then recurse into its fields
if field.Anonymous && field.Type.Kind() == reflect.Struct {
return true
}
// duplicate the entire path to avoid slice overwrites
subdest := dest.Child(field.Name)
spec := spec{
dest: subdest,
long: strings.ToLower(field.Name),
typ: field.Type,
}
spec := spec{
long: strings.ToLower(field.Name),
dest: val,
}
help, exists := field.Tag.Lookup("help")
if exists {
spec.help = help
}
help, exists := field.Tag.Lookup("help")
if exists {
spec.help = help
}
// Look at the tag
var isSubcommand bool // tracks whether this field is a subcommand
if tag != "" {
for _, key := range strings.Split(tag, ",") {
key = strings.TrimLeft(key, " ")
var value string
if pos := strings.Index(key, ":"); pos != -1 {
value = key[pos+1:]
key = key[:pos]
}
switch {
case strings.HasPrefix(key, "---"):
errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name))
case strings.HasPrefix(key, "--"):
spec.long = key[2:]
case strings.HasPrefix(key, "-"):
if len(key) != 2 {
errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only",
t.Name(), field.Name))
return false
}
spec.short = key[1:]
case key == "required":
spec.required = true
case key == "positional":
spec.positional = true
case key == "separate":
spec.separate = true
case key == "help": // deprecated
spec.help = value
case key == "env":
// Use override name if provided
if value != "" {
spec.env = value
} else {
spec.env = strings.ToUpper(field.Name)
}
case key == "subcommand":
// decide on a name for the subcommand
cmdname := value
if cmdname == "" {
cmdname = strings.ToLower(field.Name)
}
// parse the subcommand recursively
subcmd, err := cmdFromStruct(cmdname, subdest, field.Type)
if err != nil {
errs = append(errs, err.Error())
return false
}
subcmd.parent = &cmd
subcmd.help = field.Tag.Get("help")
cmd.subcommands = append(cmd.subcommands, subcmd)
isSubcommand = true
default:
errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag))
return false
}
}
}
// Check whether this field is supported. It's good to do this here rather than
// wait until ParseValue because it means that a program with invalid argument
// fields will always fail regardless of whether the arguments it received
// exercised those fields.
if !isSubcommand {
cmd.specs = append(cmd.specs, &spec)
// Check whether this field is supported. It's good to do this here rather than
// wait until ParseValue because it means that a program with invalid argument
// fields will always fail regardless of whether the arguments it received
// exercised those fields.
var parseable bool
parseable, spec.boolean, spec.multiple = canParse(field.Type)
if !parseable {
@ -173,110 +328,50 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
t.Name(), field.Name, field.Type.String()))
return false
}
}
// Look at the tag
if tag != "" {
for _, key := range strings.Split(tag, ",") {
key = strings.TrimLeft(key, " ")
var value string
if pos := strings.Index(key, ":"); pos != -1 {
value = key[pos+1:]
key = key[:pos]
}
// if this was an embedded field then we already returned true up above
return false
})
switch {
case strings.HasPrefix(key, "---"):
errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name))
case strings.HasPrefix(key, "--"):
spec.long = key[2:]
case strings.HasPrefix(key, "-"):
if len(key) != 2 {
errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only",
t.Name(), field.Name))
return false
}
spec.short = key[1:]
case key == "required":
spec.required = true
case key == "positional":
spec.positional = true
case key == "separate":
spec.separate = true
case key == "help": // deprecated
spec.help = value
case key == "env":
// Use override name if provided
if value != "" {
spec.env = value
} else {
spec.env = strings.ToUpper(field.Name)
}
default:
errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag))
return false
}
}
}
p.specs = append(p.specs, &spec)
if len(errs) > 0 {
return nil, errors.New(strings.Join(errs, "\n"))
}
// if this was an embedded field then we already returned true up above
return false
})
if len(errs) > 0 {
return nil, errors.New(strings.Join(errs, "\n"))
// check that we don't have both positionals and subcommands
var hasPositional bool
for _, spec := range cmd.specs {
if spec.positional {
hasPositional = true
}
}
if p.config.Program == "" {
p.config.Program = "program"
if len(os.Args) > 0 {
p.config.Program = filepath.Base(os.Args[0])
}
if hasPositional && len(cmd.subcommands) > 0 {
return nil, fmt.Errorf("%s cannot have both subcommands and positional arguments", dest)
}
return &p, nil
return &cmd, nil
}
// Parse processes the given command line option, storing the results in the field
// of the structs from which NewParser was constructed
func (p *Parser) Parse(args []string) error {
// If -h or --help were specified then print usage
for _, arg := range args {
if arg == "-h" || arg == "--help" {
return ErrHelp
}
if arg == "--version" {
return ErrVersion
}
if arg == "--" {
break
err := p.process(args)
if err != nil {
// If -h or --help were specified then make sure help text supercedes other errors
for _, arg := range args {
if arg == "-h" || arg == "--help" {
return ErrHelp
}
if arg == "--" {
break
}
}
}
// Process all command line arguments
return process(p.specs, args)
return err
}
// process goes through arguments one-by-one, parses them, and assigns the result to
// the underlying struct field
func process(specs []*spec, args []string) error {
// track the options we have seen
wasPresent := make(map[*spec]bool)
// construct a map from --option to spec
optionMap := make(map[string]*spec)
for _, spec := range specs {
if spec.positional {
continue
}
if spec.long != "" {
optionMap[spec.long] = spec
}
if spec.short != "" {
optionMap[spec.short] = spec
}
}
// deal with environment vars
// process environment vars for the given arguments
func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error {
for _, spec := range specs {
if spec.env == "" {
continue
@ -298,7 +393,7 @@ func process(specs []*spec, args []string) error {
err,
)
}
if err = setSlice(spec.dest, values, !spec.separate); err != nil {
if err = setSlice(p.val(spec.dest), values, !spec.separate); err != nil {
return fmt.Errorf(
"error processing environment variable %s with multiple values: %v",
spec.env,
@ -306,13 +401,36 @@ func process(specs []*spec, args []string) error {
)
}
} else {
if err := scalar.ParseValue(spec.dest, value); err != nil {
if err := scalar.ParseValue(p.val(spec.dest), value); err != nil {
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
}
}
wasPresent[spec] = true
}
return nil
}
// process goes through arguments one-by-one, parses them, and assigns the result to
// the underlying struct field
func (p *Parser) process(args []string) error {
// track the options we have seen
wasPresent := make(map[*spec]bool)
// union of specs for the chain of subcommands encountered so far
curCmd := p.cmd
p.lastCmd = curCmd
// make a copy of the specs because we will add to this list each time we expand a subcommand
specs := make([]*spec, len(curCmd.specs))
copy(specs, curCmd.specs)
// deal with environment vars
err := p.captureEnvVars(specs, wasPresent)
if err != nil {
return err
}
// process each string from the command line
var allpositional bool
var positionals []string
@ -326,10 +444,44 @@ func process(specs []*spec, args []string) error {
}
if !isFlag(arg) || allpositional {
positionals = append(positionals, arg)
// each subcommand can have either subcommands or positionals, but not both
if len(curCmd.subcommands) == 0 {
positionals = append(positionals, arg)
continue
}
// if we have a subcommand then make sure it is valid for the current context
subcmd := findSubcommand(curCmd.subcommands, arg)
if subcmd == nil {
return fmt.Errorf("invalid subcommand: %s", arg)
}
// instantiate the field to point to a new struct
v := p.val(subcmd.dest)
v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers
// add the new options to the set of allowed options
specs = append(specs, subcmd.specs...)
// capture environment vars for these new options
err := p.captureEnvVars(subcmd.specs, wasPresent)
if err != nil {
return err
}
curCmd = subcmd
p.lastCmd = curCmd
continue
}
// check for special --help and --version flags
switch arg {
case "-h", "--help":
return ErrHelp
case "--version":
return ErrVersion
}
// check for an equals sign, as in "--foo=bar"
var value string
opt := strings.TrimLeft(arg, "-")
@ -338,9 +490,10 @@ func process(specs []*spec, args []string) error {
opt = opt[:pos]
}
// lookup the spec for this option
spec, ok := optionMap[opt]
if !ok {
// lookup the spec for this option (note that the "specs" slice changes as
// we expand subcommands so it is better not to use a map)
spec := findOption(specs, opt)
if spec == nil {
return fmt.Errorf("unknown argument %s", arg)
}
wasPresent[spec] = true
@ -359,7 +512,7 @@ func process(specs []*spec, args []string) error {
} else {
values = append(values, value)
}
err := setSlice(spec.dest, values, !spec.separate)
err := setSlice(p.val(spec.dest), values, !spec.separate)
if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err)
}
@ -377,14 +530,14 @@ func process(specs []*spec, args []string) error {
if i+1 == len(args) {
return fmt.Errorf("missing value for %s", arg)
}
if !nextIsNumeric(spec.dest.Type(), args[i+1]) && isFlag(args[i+1]) {
if !nextIsNumeric(spec.typ, args[i+1]) && isFlag(args[i+1]) {
return fmt.Errorf("missing value for %s", arg)
}
value = args[i+1]
i++
}
err := scalar.ParseValue(spec.dest, value)
err := scalar.ParseValue(p.val(spec.dest), value)
if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err)
}
@ -400,13 +553,13 @@ func process(specs []*spec, args []string) error {
}
wasPresent[spec] = true
if spec.multiple {
err := setSlice(spec.dest, positionals, true)
err := setSlice(p.val(spec.dest), positionals, true)
if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err)
}
positionals = nil
} else {
err := scalar.ParseValue(spec.dest, positionals[0])
err := scalar.ParseValue(p.val(spec.dest), positionals[0])
if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err)
}
@ -449,6 +602,30 @@ func isFlag(s string) bool {
return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != ""
}
// val returns a reflect.Value corresponding to the current value for the
// given path
func (p *Parser) val(dest path) reflect.Value {
v := p.roots[dest.root]
for _, field := range dest.fields {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return reflect.Value{}
}
v = v.Elem()
}
v = v.FieldByName(field)
if !v.IsValid() {
// it is appropriate to panic here because this can only happen due to
// an internal bug in this library (since we construct the path ourselves
// by reflecting on the same struct)
panic(fmt.Errorf("error resolving path %v: %v has no field named %v",
dest.fields, v.Type(), field))
}
}
return v
}
// parse a value as the appropriate type and store it in the struct
func setSlice(dest reflect.Value, values []string, trunc bool) error {
if !dest.CanSet() {
@ -480,56 +657,25 @@ func setSlice(dest reflect.Value, values []string, trunc bool) error {
return nil
}
// canParse returns true if the type can be parsed from a string
func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
parseable = scalar.CanParse(t)
boolean = isBoolean(t)
if parseable {
return
// findOption finds an option from its name, or returns null if no spec is found
func findOption(specs []*spec, name string) *spec {
for _, spec := range specs {
if spec.positional {
continue
}
if spec.long == name || spec.short == name {
return spec
}
}
// Look inside pointer types
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
// Look inside slice types
if t.Kind() == reflect.Slice {
multiple = true
t = t.Elem()
}
parseable = scalar.CanParse(t)
boolean = isBoolean(t)
if parseable {
return
}
// Look inside pointer types (again, in case of []*Type)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
parseable = scalar.CanParse(t)
boolean = isBoolean(t)
if parseable {
return
}
return false, false, false
return nil
}
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
// isBoolean returns true if the type can be parsed from a single string
func isBoolean(t reflect.Type) bool {
switch {
case t.Implements(textUnmarshalerType):
return false
case t.Kind() == reflect.Bool:
return true
case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Bool:
return true
default:
return false
// findSubcommand finds a subcommand using its name, or returns null if no subcommand is found
func findSubcommand(cmds []*command, name string) *command {
for _, cmd := range cmds {
if cmd.name == name {
return cmd
}
}
return nil
}

View File

@ -19,15 +19,20 @@ func setenv(t *testing.T, name, val string) {
}
func parse(cmdline string, dest interface{}) error {
_, err := pparse(cmdline, dest)
return err
}
func pparse(cmdline string, dest interface{}) (*Parser, error) {
p, err := NewParser(Config{}, dest)
if err != nil {
return err
return nil, err
}
var parts []string
if len(cmdline) > 0 {
parts = strings.Split(cmdline, " ")
}
return p.Parse(parts)
return p, p.Parse(parts)
}
func TestString(t *testing.T) {
@ -371,7 +376,7 @@ func TestNonsenseKey(t *testing.T) {
assert.Error(t, err)
}
func TestMissingValue(t *testing.T) {
func TestMissingValueAtEnd(t *testing.T) {
var args struct {
Foo string
}
@ -379,6 +384,24 @@ func TestMissingValue(t *testing.T) {
assert.Error(t, err)
}
func TestMissingValueInMIddle(t *testing.T) {
var args struct {
Foo string
Bar string
}
err := parse("--foo --bar=abc", &args)
assert.Error(t, err)
}
func TestNegativeValue(t *testing.T) {
var args struct {
Foo int
}
err := parse("--foo -123", &args)
require.NoError(t, err)
assert.Equal(t, -123, args.Foo)
}
func TestInvalidInt(t *testing.T) {
var args struct {
Foo int
@ -462,11 +485,10 @@ func TestPanicOnNonPointer(t *testing.T) {
})
}
func TestPanicOnNonStruct(t *testing.T) {
func TestErrorOnNonStruct(t *testing.T) {
var args string
assert.Panics(t, func() {
_ = parse("", &args)
})
err := parse("", &args)
assert.Error(t, err)
}
func TestUnsupportedType(t *testing.T) {
@ -540,6 +562,15 @@ func TestEnvironmentVariable(t *testing.T) {
assert.Equal(t, "bar", args.Foo)
}
func TestEnvironmentVariableNotPresent(t *testing.T) {
var args struct {
NotPresent string `arg:"env"`
}
os.Args = []string{"example"}
MustParse(&args)
assert.Equal(t, "", args.NotPresent)
}
func TestEnvironmentVariableOverrideName(t *testing.T) {
var args struct {
Foo string `arg:"env:BAZ"`
@ -584,7 +615,7 @@ func TestEnvironmentVariableSliceArgumentString(t *testing.T) {
var args struct {
Foo []string `arg:"env"`
}
setenv(t, "FOO", "bar,\"baz, qux\"")
setenv(t, "FOO", `bar,"baz, qux"`)
MustParse(&args)
assert.Equal(t, []string{"bar", "baz, qux"}, args.Foo)
}
@ -846,6 +877,28 @@ func TestEmbedded(t *testing.T) {
assert.Equal(t, true, args.Z)
}
func TestEmbeddedPtr(t *testing.T) {
// embedded pointer fields are not supported so this should return an error
var args struct {
*A
}
err := parse("--x=hello", &args)
require.Error(t, err)
}
func TestEmbeddedPtrIgnored(t *testing.T) {
// embedded pointer fields are not normally supported but here
// we explicitly exclude it so the non-nil embedded structs
// should work as expected
var args struct {
*A `arg:"-"`
B
}
err := parse("--y=321", &args)
require.NoError(t, err)
assert.Equal(t, 321, args.Y)
}
func TestEmptyArgs(t *testing.T) {
origArgs := os.Args
@ -985,3 +1038,10 @@ func TestReuseParser(t *testing.T) {
err = p.Parse([]string{})
assert.Error(t, err)
}
func TestVersion(t *testing.T) {
var args struct{}
err := parse("--version", &args)
assert.Equal(t, ErrVersion, err)
}

62
reflect.go Normal file
View File

@ -0,0 +1,62 @@
package arg
import (
"encoding"
"reflect"
scalar "github.com/alexflint/go-scalar"
)
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
// canParse returns true if the type can be parsed from a string
func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
parseable = scalar.CanParse(t)
boolean = isBoolean(t)
if parseable {
return
}
// Look inside pointer types
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
// Look inside slice types
if t.Kind() == reflect.Slice {
multiple = true
t = t.Elem()
}
parseable = scalar.CanParse(t)
boolean = isBoolean(t)
if parseable {
return
}
// Look inside pointer types (again, in case of []*Type)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
parseable = scalar.CanParse(t)
boolean = isBoolean(t)
if parseable {
return
}
return false, false, false
}
// isBoolean returns true if the type can be parsed from a single string
func isBoolean(t reflect.Type) bool {
switch {
case t.Implements(textUnmarshalerType):
return false
case t.Kind() == reflect.Bool:
return true
case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Bool:
return true
default:
return false
}
}

55
reflect_test.go Normal file
View File

@ -0,0 +1,55 @@
package arg
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func assertCanParse(t *testing.T, typ reflect.Type, parseable, boolean, multiple bool) {
p, b, m := canParse(typ)
assert.Equal(t, parseable, p, "expected %v to have parseable=%v but was %v", typ, parseable, p)
assert.Equal(t, boolean, b, "expected %v to have boolean=%v but was %v", typ, boolean, b)
assert.Equal(t, multiple, m, "expected %v to have multiple=%v but was %v", typ, multiple, m)
}
func TestCanParse(t *testing.T) {
var b bool
var i int
var s string
var f float64
var bs []bool
var is []int
assertCanParse(t, reflect.TypeOf(b), true, true, false)
assertCanParse(t, reflect.TypeOf(i), true, false, false)
assertCanParse(t, reflect.TypeOf(s), true, false, false)
assertCanParse(t, reflect.TypeOf(f), true, false, false)
assertCanParse(t, reflect.TypeOf(&b), true, true, false)
assertCanParse(t, reflect.TypeOf(&s), true, false, false)
assertCanParse(t, reflect.TypeOf(&i), true, false, false)
assertCanParse(t, reflect.TypeOf(&f), true, false, false)
assertCanParse(t, reflect.TypeOf(bs), true, true, true)
assertCanParse(t, reflect.TypeOf(&bs), true, true, true)
assertCanParse(t, reflect.TypeOf(is), true, false, true)
assertCanParse(t, reflect.TypeOf(&is), true, false, true)
}
type implementsTextUnmarshaler struct{}
func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error {
return nil
}
func TestCanParseTextUnmarshaler(t *testing.T) {
var u implementsTextUnmarshaler
var su []implementsTextUnmarshaler
assertCanParse(t, reflect.TypeOf(u), true, false, false)
assertCanParse(t, reflect.TypeOf(&u), true, false, false)
assertCanParse(t, reflect.TypeOf(su), true, false, true)
assertCanParse(t, reflect.TypeOf(&su), true, false, true)
}

37
subcommand.go Normal file
View File

@ -0,0 +1,37 @@
package arg
// Subcommand returns the user struct for the subcommand selected by
// the command line arguments most recently processed by the parser.
// The return value is always a pointer to a struct. If no subcommand
// was specified then it returns the top-level arguments struct. If
// no command line arguments have been processed by this parser then it
// returns nil.
func (p *Parser) Subcommand() interface{} {
if p.lastCmd == nil || p.lastCmd.parent == nil {
return nil
}
return p.val(p.lastCmd.dest).Interface()
}
// SubcommandNames returns the sequence of subcommands specified by the
// user. If no subcommands were given then it returns an empty slice.
func (p *Parser) SubcommandNames() []string {
if p.lastCmd == nil {
return nil
}
// make a list of ancestor commands
var ancestors []string
cur := p.lastCmd
for cur.parent != nil { // we want to exclude the root
ancestors = append(ancestors, cur.name)
cur = cur.parent
}
// reverse the list
out := make([]string, len(ancestors))
for i := 0; i < len(ancestors); i++ {
out[i] = ancestors[len(ancestors)-i-1]
}
return out
}

355
subcommand_test.go Normal file
View File

@ -0,0 +1,355 @@
package arg
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// This file contains tests for parse.go but I decided to put them here
// since that file is getting large
func TestSubcommandNotAPointer(t *testing.T) {
var args struct {
A string `arg:"subcommand"`
}
_, err := NewParser(Config{}, &args)
assert.Error(t, err)
}
func TestSubcommandNotAPointerToStruct(t *testing.T) {
var args struct {
A struct{} `arg:"subcommand"`
}
_, err := NewParser(Config{}, &args)
assert.Error(t, err)
}
func TestPositionalAndSubcommandNotAllowed(t *testing.T) {
var args struct {
A string `arg:"positional"`
B *struct{} `arg:"subcommand"`
}
_, err := NewParser(Config{}, &args)
assert.Error(t, err)
}
func TestMinimalSubcommand(t *testing.T) {
type listCmd struct {
}
var args struct {
List *listCmd `arg:"subcommand"`
}
p, err := pparse("list", &args)
require.NoError(t, err)
assert.NotNil(t, args.List)
assert.Equal(t, args.List, p.Subcommand())
assert.Equal(t, []string{"list"}, p.SubcommandNames())
}
func TestNoSuchSubcommand(t *testing.T) {
type listCmd struct {
}
var args struct {
List *listCmd `arg:"subcommand"`
}
_, err := pparse("invalid", &args)
assert.Error(t, err)
}
func TestNamedSubcommand(t *testing.T) {
type listCmd struct {
}
var args struct {
List *listCmd `arg:"subcommand:ls"`
}
p, err := pparse("ls", &args)
require.NoError(t, err)
assert.NotNil(t, args.List)
assert.Equal(t, args.List, p.Subcommand())
assert.Equal(t, []string{"ls"}, p.SubcommandNames())
}
func TestEmptySubcommand(t *testing.T) {
type listCmd struct {
}
var args struct {
List *listCmd `arg:"subcommand"`
}
p, err := pparse("", &args)
require.NoError(t, err)
assert.Nil(t, args.List)
assert.Nil(t, p.Subcommand())
assert.Empty(t, p.SubcommandNames())
}
func TestTwoSubcommands(t *testing.T) {
type getCmd struct {
}
type listCmd struct {
}
var args struct {
Get *getCmd `arg:"subcommand"`
List *listCmd `arg:"subcommand"`
}
p, err := pparse("list", &args)
require.NoError(t, err)
assert.Nil(t, args.Get)
assert.NotNil(t, args.List)
assert.Equal(t, args.List, p.Subcommand())
assert.Equal(t, []string{"list"}, p.SubcommandNames())
}
func TestSubcommandsWithOptions(t *testing.T) {
type getCmd struct {
Name string
}
type listCmd struct {
Limit int
}
type cmd struct {
Verbose bool
Get *getCmd `arg:"subcommand"`
List *listCmd `arg:"subcommand"`
}
{
var args cmd
err := parse("list", &args)
require.NoError(t, err)
assert.Nil(t, args.Get)
assert.NotNil(t, args.List)
}
{
var args cmd
err := parse("list --limit 3", &args)
require.NoError(t, err)
assert.Nil(t, args.Get)
assert.NotNil(t, args.List)
assert.Equal(t, args.List.Limit, 3)
}
{
var args cmd
err := parse("list --limit 3 --verbose", &args)
require.NoError(t, err)
assert.Nil(t, args.Get)
assert.NotNil(t, args.List)
assert.Equal(t, args.List.Limit, 3)
assert.True(t, args.Verbose)
}
{
var args cmd
err := parse("list --verbose --limit 3", &args)
require.NoError(t, err)
assert.Nil(t, args.Get)
assert.NotNil(t, args.List)
assert.Equal(t, args.List.Limit, 3)
assert.True(t, args.Verbose)
}
{
var args cmd
err := parse("--verbose list --limit 3", &args)
require.NoError(t, err)
assert.Nil(t, args.Get)
assert.NotNil(t, args.List)
assert.Equal(t, args.List.Limit, 3)
assert.True(t, args.Verbose)
}
{
var args cmd
err := parse("get", &args)
require.NoError(t, err)
assert.NotNil(t, args.Get)
assert.Nil(t, args.List)
}
{
var args cmd
err := parse("get --name test", &args)
require.NoError(t, err)
assert.NotNil(t, args.Get)
assert.Nil(t, args.List)
assert.Equal(t, args.Get.Name, "test")
}
}
func TestNestedSubcommands(t *testing.T) {
type child struct{}
type parent struct {
Child *child `arg:"subcommand"`
}
type grandparent struct {
Parent *parent `arg:"subcommand"`
}
type root struct {
Grandparent *grandparent `arg:"subcommand"`
}
{
var args root
p, err := pparse("grandparent parent child", &args)
require.NoError(t, err)
require.NotNil(t, args.Grandparent)
require.NotNil(t, args.Grandparent.Parent)
require.NotNil(t, args.Grandparent.Parent.Child)
assert.Equal(t, args.Grandparent.Parent.Child, p.Subcommand())
assert.Equal(t, []string{"grandparent", "parent", "child"}, p.SubcommandNames())
}
{
var args root
p, err := pparse("grandparent parent", &args)
require.NoError(t, err)
require.NotNil(t, args.Grandparent)
require.NotNil(t, args.Grandparent.Parent)
require.Nil(t, args.Grandparent.Parent.Child)
assert.Equal(t, args.Grandparent.Parent, p.Subcommand())
assert.Equal(t, []string{"grandparent", "parent"}, p.SubcommandNames())
}
{
var args root
p, err := pparse("grandparent", &args)
require.NoError(t, err)
require.NotNil(t, args.Grandparent)
require.Nil(t, args.Grandparent.Parent)
assert.Equal(t, args.Grandparent, p.Subcommand())
assert.Equal(t, []string{"grandparent"}, p.SubcommandNames())
}
{
var args root
p, err := pparse("", &args)
require.NoError(t, err)
require.Nil(t, args.Grandparent)
assert.Nil(t, p.Subcommand())
assert.Empty(t, p.SubcommandNames())
}
}
func TestSubcommandsWithPositionals(t *testing.T) {
type listCmd struct {
Pattern string `arg:"positional"`
}
type cmd struct {
Format string
List *listCmd `arg:"subcommand"`
}
{
var args cmd
err := parse("list", &args)
require.NoError(t, err)
assert.NotNil(t, args.List)
assert.Equal(t, "", args.List.Pattern)
}
{
var args cmd
err := parse("list --format json", &args)
require.NoError(t, err)
assert.NotNil(t, args.List)
assert.Equal(t, "", args.List.Pattern)
assert.Equal(t, "json", args.Format)
}
{
var args cmd
err := parse("list somepattern", &args)
require.NoError(t, err)
assert.NotNil(t, args.List)
assert.Equal(t, "somepattern", args.List.Pattern)
}
{
var args cmd
err := parse("list somepattern --format json", &args)
require.NoError(t, err)
assert.NotNil(t, args.List)
assert.Equal(t, "somepattern", args.List.Pattern)
assert.Equal(t, "json", args.Format)
}
{
var args cmd
err := parse("list --format json somepattern", &args)
require.NoError(t, err)
assert.NotNil(t, args.List)
assert.Equal(t, "somepattern", args.List.Pattern)
assert.Equal(t, "json", args.Format)
}
{
var args cmd
err := parse("--format json list somepattern", &args)
require.NoError(t, err)
assert.NotNil(t, args.List)
assert.Equal(t, "somepattern", args.List.Pattern)
assert.Equal(t, "json", args.Format)
}
{
var args cmd
err := parse("--format json", &args)
require.NoError(t, err)
assert.Nil(t, args.List)
assert.Equal(t, "json", args.Format)
}
}
func TestSubcommandsWithMultiplePositionals(t *testing.T) {
type getCmd struct {
Items []string `arg:"positional"`
}
type cmd struct {
Limit int
Get *getCmd `arg:"subcommand"`
}
{
var args cmd
err := parse("get", &args)
require.NoError(t, err)
assert.NotNil(t, args.Get)
assert.Empty(t, args.Get.Items)
}
{
var args cmd
err := parse("get --limit 5", &args)
require.NoError(t, err)
assert.NotNil(t, args.Get)
assert.Empty(t, args.Get.Items)
assert.Equal(t, 5, args.Limit)
}
{
var args cmd
err := parse("get item1", &args)
require.NoError(t, err)
assert.NotNil(t, args.Get)
assert.Equal(t, []string{"item1"}, args.Get.Items)
}
{
var args cmd
err := parse("get item1 item2 item3", &args)
require.NoError(t, err)
assert.NotNil(t, args.Get)
assert.Equal(t, []string{"item1", "item2", "item3"}, args.Get.Items)
}
{
var args cmd
err := parse("get item1 --limit 5 item2", &args)
require.NoError(t, err)
assert.NotNil(t, args.Get)
assert.Equal(t, []string{"item1", "item2"}, args.Get.Items)
assert.Equal(t, 5, args.Limit)
}
}

129
usage.go
View File

@ -12,17 +12,30 @@ import (
// the width of the left column
const colWidth = 25
// to allow monkey patching in tests
var stderr = os.Stderr
// Fail prints usage information to stderr and exits with non-zero status
func (p *Parser) Fail(msg string) {
p.WriteUsage(os.Stderr)
fmt.Fprintln(os.Stderr, "error:", msg)
os.Exit(-1)
p.failWithCommand(msg, p.cmd)
}
// failWithCommand prints usage information for the given subcommand to stderr and exits with non-zero status
func (p *Parser) failWithCommand(msg string, cmd *command) {
p.writeUsageForCommand(stderr, cmd)
fmt.Fprintln(stderr, "error:", msg)
osExit(-1)
}
// WriteUsage writes usage information to the given writer
func (p *Parser) WriteUsage(w io.Writer) {
p.writeUsageForCommand(w, p.cmd)
}
// writeUsageForCommand writes usage information for the given subcommand
func (p *Parser) writeUsageForCommand(w io.Writer, cmd *command) {
var positionals, options []*spec
for _, spec := range p.specs {
for _, spec := range cmd.specs {
if spec.positional {
positionals = append(positionals, spec)
} else {
@ -34,7 +47,19 @@ func (p *Parser) WriteUsage(w io.Writer) {
fmt.Fprintln(w, p.version)
}
fmt.Fprintf(w, "Usage: %s", p.config.Program)
// make a list of ancestor commands so that we print with full context
var ancestors []string
ancestor := cmd
for ancestor != nil {
ancestors = append(ancestors, ancestor.name)
ancestor = ancestor.parent
}
// print the beginning of the usage string
fmt.Fprint(w, "Usage:")
for i := len(ancestors) - 1; i >= 0; i-- {
fmt.Fprint(w, " "+ancestors[i])
}
// write the option component of the usage message
for _, spec := range options {
@ -69,10 +94,32 @@ func (p *Parser) WriteUsage(w io.Writer) {
fmt.Fprint(w, "\n")
}
func printTwoCols(w io.Writer, left, help string, defaultVal *string) {
lhs := " " + left
fmt.Fprint(w, lhs)
if help != "" {
if len(lhs)+2 < colWidth {
fmt.Fprint(w, strings.Repeat(" ", colWidth-len(lhs)))
} else {
fmt.Fprint(w, "\n"+strings.Repeat(" ", colWidth))
}
fmt.Fprint(w, help)
}
if defaultVal != nil {
fmt.Fprintf(w, " [default: %s]", *defaultVal)
}
fmt.Fprint(w, "\n")
}
// WriteHelp writes the usage string followed by the full help string for each option
func (p *Parser) WriteHelp(w io.Writer) {
p.writeHelpForCommand(w, p.cmd)
}
// writeHelp writes the usage string for the given subcommand
func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) {
var positionals, options []*spec
for _, spec := range p.specs {
for _, spec := range cmd.specs {
if spec.positional {
positionals = append(positionals, spec)
} else {
@ -83,70 +130,74 @@ func (p *Parser) WriteHelp(w io.Writer) {
if p.description != "" {
fmt.Fprintln(w, p.description)
}
p.WriteUsage(w)
p.writeUsageForCommand(w, cmd)
// write the list of positionals
if len(positionals) > 0 {
fmt.Fprint(w, "\nPositional arguments:\n")
for _, spec := range positionals {
left := " " + strings.ToUpper(spec.long)
fmt.Fprint(w, left)
if spec.help != "" {
if len(left)+2 < colWidth {
fmt.Fprint(w, strings.Repeat(" ", colWidth-len(left)))
} else {
fmt.Fprint(w, "\n"+strings.Repeat(" ", colWidth))
}
fmt.Fprint(w, spec.help)
}
fmt.Fprint(w, "\n")
printTwoCols(w, strings.ToUpper(spec.long), spec.help, nil)
}
}
// write the list of options
fmt.Fprint(w, "\nOptions:\n")
for _, spec := range options {
printOption(w, spec)
p.printOption(w, spec)
}
// write the list of built in options
printOption(w, &spec{boolean: true, long: "help", short: "h", help: "display this help and exit"})
p.printOption(w, &spec{
boolean: true,
long: "help",
short: "h",
help: "display this help and exit",
})
if p.version != "" {
printOption(w, &spec{boolean: true, long: "version", help: "display version and exit"})
p.printOption(w, &spec{
boolean: true,
long: "version",
help: "display version and exit",
})
}
// write the list of subcommands
if len(cmd.subcommands) > 0 {
fmt.Fprint(w, "\nCommands:\n")
for _, subcmd := range cmd.subcommands {
printTwoCols(w, subcmd.name, subcmd.help, nil)
}
}
}
func printOption(w io.Writer, spec *spec) {
left := " " + synopsis(spec, "--"+spec.long)
func (p *Parser) printOption(w io.Writer, spec *spec) {
left := synopsis(spec, "--"+spec.long)
if spec.short != "" {
left += ", " + synopsis(spec, "-"+spec.short)
}
fmt.Fprint(w, left)
if spec.help != "" {
if len(left)+2 < colWidth {
fmt.Fprint(w, strings.Repeat(" ", colWidth-len(left)))
} else {
fmt.Fprint(w, "\n"+strings.Repeat(" ", colWidth))
}
fmt.Fprint(w, spec.help)
}
// If spec.dest is not the zero value then a default value has been added.
v := spec.dest
var v reflect.Value
if len(spec.dest.fields) > 0 {
v = p.val(spec.dest)
}
var defaultVal *string
if v.IsValid() {
z := reflect.Zero(v.Type())
if (v.Type().Comparable() && z.Type().Comparable() && v.Interface() != z.Interface()) || v.Kind() == reflect.Slice && !v.IsNil() {
if scalar, ok := v.Interface().(encoding.TextMarshaler); ok {
if value, err := scalar.MarshalText(); err != nil {
fmt.Fprintf(w, " [default: error: %v]", err)
defaultVal = ptrTo(fmt.Sprintf("error: %v", err))
} else {
fmt.Fprintf(w, " [default: %v]", string(value))
defaultVal = ptrTo(fmt.Sprintf("%v", string(value)))
}
} else {
fmt.Fprintf(w, " [default: %v]", v)
defaultVal = ptrTo(fmt.Sprintf("%v", v))
}
}
}
fmt.Fprint(w, "\n")
printTwoCols(w, left, spec.help, defaultVal)
}
func synopsis(spec *spec, form string) string {
@ -155,3 +206,7 @@ func synopsis(spec *spec, form string) string {
}
return form + " " + strings.ToUpper(spec.long)
}
func ptrTo(s string) *string {
return &s
}