introduced path struct

This commit is contained in:
Alex Flint 2019-04-30 13:30:23 -07:00
parent 6a796e2c41
commit af12b7cfc2
3 changed files with 102 additions and 50 deletions

124
parse.go
View File

@ -13,10 +13,32 @@ import (
scalar "github.com/alexflint/go-scalar" scalar "github.com/alexflint/go-scalar"
) )
// path represents a sequence of steps to find the output location for an
// argument or subcommand in the final destination struct
type path struct {
root int // index of the destination struct
fields []string // sequence of struct field names to traverse
}
// String gets a string representation of the given path
func (p path) String() string {
return "args." + strings.Join(p.fields, ".")
}
// Child gets a new path representing a child of this path.
func (p path) Child(child string) path {
// copy the entire slice of fields to avoid possible slice overwrite
subfields := make([]string, len(p.fields)+1)
copy(subfields, append(p.fields, child))
return path{
root: p.root,
fields: subfields,
}
}
// spec represents a command line option // spec represents a command line option
type spec struct { type spec struct {
root int dest path
path []string // sequence of field names
typ reflect.Type typ reflect.Type
long string long string
short string short string
@ -32,6 +54,7 @@ type spec struct {
// command represents a named subcommand, or the top-level command // command represents a named subcommand, or the top-level command
type command struct { type command struct {
name string name string
dest path
specs []*spec specs []*spec
subcommands []*command subcommands []*command
} }
@ -153,11 +176,12 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t)) panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t))
} }
cmd, err := cmdFromStruct(name, t, nil, i) cmd, err := cmdFromStruct(name, path{root: i}, t)
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.cmd.specs = append(p.cmd.specs, cmd.specs...) p.cmd.specs = append(p.cmd.specs, cmd.specs...)
p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...)
if dest, ok := dest.(Versioned); ok { if dest, ok := dest.(Versioned); ok {
p.version = dest.Version() p.version = dest.Version()
@ -170,20 +194,24 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
return &p, nil return &p, nil
} }
func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*command, error) { func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
// commands can only be created from pointers to structs // commands can only be created from pointers to structs
if t.Kind() != reflect.Ptr { if t.Kind() != reflect.Ptr {
return nil, fmt.Errorf("subcommands must be pointers to structs but args.%s is a %s", return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s",
strings.Join(path, "."), t.Kind()) dest, t.Kind())
} }
t = t.Elem() t = t.Elem()
if t.Kind() != reflect.Struct { if t.Kind() != reflect.Struct {
return nil, fmt.Errorf("subcommands must be pointers to structs but args.%s is a pointer to %s", return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a pointer to %s",
strings.Join(path, "."), t.Kind()) dest, t.Kind())
}
cmd := command{
name: name,
dest: dest,
} }
var cmd command
var errs []string var errs []string
walkFields(t, func(field reflect.StructField, t reflect.Type) bool { walkFields(t, func(field reflect.StructField, t reflect.Type) bool {
// Check for the ignore switch in the tag // Check for the ignore switch in the tag
@ -198,12 +226,9 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma
} }
// duplicate the entire path to avoid slice overwrites // duplicate the entire path to avoid slice overwrites
subpath := make([]string, len(path)+1) subdest := dest.Child(field.Name)
copy(subpath, append(path, field.Name))
spec := spec{ spec := spec{
root: root, dest: subdest,
path: subpath,
long: strings.ToLower(field.Name), long: strings.ToLower(field.Name),
typ: field.Type, typ: field.Type,
} }
@ -213,19 +238,8 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma
spec.help = help spec.help = help
} }
// Check whether this field is supported. It's good to do this here rather than
// wait until ParseValue because it means that a program with invalid argument
// fields will always fail regardless of whether the arguments it received
// exercised those fields.
var parseable bool
parseable, spec.boolean, spec.multiple = canParse(field.Type)
if !parseable {
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
t.Name(), field.Name, field.Type.String()))
return false
}
// Look at the tag // Look at the tag
var isSubcommand bool // tracks whether this field is a subcommand
if tag != "" { if tag != "" {
for _, key := range strings.Split(tag, ",") { for _, key := range strings.Split(tag, ",") {
key = strings.TrimLeft(key, " ") key = strings.TrimLeft(key, " ")
@ -269,20 +283,37 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma
cmdname = strings.ToLower(field.Name) cmdname = strings.ToLower(field.Name)
} }
subcmd, err := cmdFromStruct(cmdname, field.Type, subpath, root) subcmd, err := cmdFromStruct(cmdname, subdest, field.Type)
if err != nil { if err != nil {
errs = append(errs, err.Error()) errs = append(errs, err.Error())
return false return false
} }
cmd.subcommands = append(cmd.subcommands, subcmd) cmd.subcommands = append(cmd.subcommands, subcmd)
isSubcommand = true
fmt.Println("found a subcommand")
default: default:
errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag))
return false return false
} }
} }
} }
cmd.specs = append(cmd.specs, &spec)
// Check whether this field is supported. It's good to do this here rather than
// wait until ParseValue because it means that a program with invalid argument
// fields will always fail regardless of whether the arguments it received
// exercised those fields.
if !isSubcommand {
cmd.specs = append(cmd.specs, &spec)
var parseable bool
parseable, spec.boolean, spec.multiple = canParse(field.Type)
if !parseable {
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
t.Name(), field.Name, field.Type.String()))
return false
}
}
// if this was an embedded field then we already returned true up above // if this was an embedded field then we already returned true up above
return false return false
@ -303,6 +334,8 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma
return nil, fmt.Errorf("%T cannot have both subcommands and positional arguments", t) return nil, fmt.Errorf("%T cannot have both subcommands and positional arguments", t)
} }
fmt.Printf("parsed a command with %d subcommands\n", len(cmd.subcommands))
return &cmd, nil return &cmd, nil
} }
@ -349,7 +382,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
err, err,
) )
} }
if err = setSlice(p.writable(spec), values, !spec.separate); err != nil { if err = setSlice(p.writable(spec.dest), values, !spec.separate); err != nil {
return fmt.Errorf( return fmt.Errorf(
"error processing environment variable %s with multiple values: %v", "error processing environment variable %s with multiple values: %v",
spec.env, spec.env,
@ -357,7 +390,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
) )
} }
} else { } else {
if err := scalar.ParseValue(p.writable(spec), value); err != nil { if err := scalar.ParseValue(p.writable(spec.dest), value); err != nil {
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
} }
} }
@ -400,6 +433,7 @@ func (p *Parser) process(args []string) error {
if !isFlag(arg) || allpositional { if !isFlag(arg) || allpositional {
// each subcommand can have either subcommands or positionals, but not both // each subcommand can have either subcommands or positionals, but not both
fmt.Printf("processing %q, with %d subcommands", arg, len(curCmd.subcommands))
if len(curCmd.subcommands) == 0 { if len(curCmd.subcommands) == 0 {
positionals = append(positionals, arg) positionals = append(positionals, arg)
continue continue
@ -454,7 +488,7 @@ func (p *Parser) process(args []string) error {
} else { } else {
values = append(values, value) values = append(values, value)
} }
err := setSlice(p.writable(spec), values, !spec.separate) err := setSlice(p.writable(spec.dest), values, !spec.separate)
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err) return fmt.Errorf("error processing %s: %v", arg, err)
} }
@ -479,7 +513,7 @@ func (p *Parser) process(args []string) error {
i++ i++
} }
err := scalar.ParseValue(p.writable(spec), value) err := scalar.ParseValue(p.writable(spec.dest), value)
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err) return fmt.Errorf("error processing %s: %v", arg, err)
} }
@ -495,13 +529,13 @@ func (p *Parser) process(args []string) error {
} }
wasPresent[spec] = true wasPresent[spec] = true
if spec.multiple { if spec.multiple {
err := setSlice(p.writable(spec), positionals, true) err := setSlice(p.writable(spec.dest), positionals, true)
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err) return fmt.Errorf("error processing %s: %v", spec.long, err)
} }
positionals = nil positionals = nil
} else { } else {
err := scalar.ParseValue(p.writable(spec), positionals[0]) err := scalar.ParseValue(p.writable(spec.dest), positionals[0])
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err) return fmt.Errorf("error processing %s: %v", spec.long, err)
} }
@ -546,9 +580,9 @@ func isFlag(s string) bool {
// readable returns a reflect.Value corresponding to the current value for the // readable returns a reflect.Value corresponding to the current value for the
// given // given
func (p *Parser) readable(spec *spec) reflect.Value { func (p *Parser) readable(dest path) reflect.Value {
v := p.roots[spec.root] v := p.roots[dest.root]
for _, field := range spec.path { for _, field := range dest.fields {
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
if v.IsNil() { if v.IsNil() {
return reflect.Value{} return reflect.Value{}
@ -559,21 +593,21 @@ func (p *Parser) readable(spec *spec) reflect.Value {
v = v.FieldByName(field) v = v.FieldByName(field)
if !v.IsValid() { if !v.IsValid() {
// it is appropriate to panic here because this can only happen due to // it is appropriate to panic here because this can only happen due to
// an internal bug in this library (since we construct spec.path ourselves // an internal bug in this library (since we construct the path ourselves
// by reflecting on the same struct) // by reflecting on the same struct)
panic(fmt.Errorf("error resolving path %v: %v has no field named %v", panic(fmt.Errorf("error resolving path %v: %v has no field named %v",
spec.path, v.Type(), field)) dest.fields, v.Type(), field))
} }
} }
return v return v
} }
// writable traverses the destination struct to find the destination to // writable trav.patherses the destination struct to find the destination to
// which the value of the given spec should be written. It fills in null // which the value of the given spec should be written. It fills in null
// structs with pointers to the zero value for that struct. // structs with pointers to the zero value for that struct.
func (p *Parser) writable(spec *spec) reflect.Value { func (p *Parser) writable(dest path) reflect.Value {
v := p.roots[spec.root] v := p.roots[dest.root]
for _, field := range spec.path { for _, field := range dest.fields {
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
if v.IsNil() { if v.IsNil() {
v.Set(reflect.New(v.Type().Elem())) v.Set(reflect.New(v.Type().Elem()))
@ -584,10 +618,10 @@ func (p *Parser) writable(spec *spec) reflect.Value {
v = v.FieldByName(field) v = v.FieldByName(field)
if !v.IsValid() { if !v.IsValid() {
// it is appropriate to panic here because this can only happen due to // it is appropriate to panic here because this can only happen due to
// an internal bug in this library (since we construct spec.path ourselves // an internal bug in this library (since we construct the path ourselves
// by reflecting on the same struct) // by reflecting on the same struct)
panic(fmt.Errorf("error resolving path %v: %v has no field named %v", panic(fmt.Errorf("error resolving path %v: %v has no field named %v",
spec.path, v.Type(), field)) dest.fields, v.Type(), field))
} }
} }
return v return v

View File

@ -4,12 +4,13 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
// This file contains tests for parse.go but I decided to put them here // This file contains tests for parse.go but I decided to put them here
// since that file is getting large // since that file is getting large
func TestSubcommandNotAStruct(t *testing.T) { func TestSubcommandNotAPointer(t *testing.T) {
var args struct { var args struct {
A string `arg:"subcommand"` A string `arg:"subcommand"`
} }
@ -17,6 +18,14 @@ func TestSubcommandNotAStruct(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
} }
func TestSubcommandNotAPointerToStruct(t *testing.T) {
var args struct {
A struct{} `arg:"subcommand"`
}
_, err := NewParser(Config{}, &args)
assert.Error(t, err)
}
func TestPositionalAndSubcommandNotAllowed(t *testing.T) { func TestPositionalAndSubcommandNotAllowed(t *testing.T) {
var args struct { var args struct {
A string `arg:"positional"` A string `arg:"positional"`
@ -25,3 +34,14 @@ func TestPositionalAndSubcommandNotAllowed(t *testing.T) {
_, err := NewParser(Config{}, &args) _, err := NewParser(Config{}, &args)
assert.Error(t, err) assert.Error(t, err)
} }
func TestMinimalSubcommand(t *testing.T) {
type listCmd struct {
}
var args struct {
List *listCmd `arg:"subcommand"`
}
err := parse("list", &args)
require.NoError(t, err)
assert.NotNil(t, args.List)
}

View File

@ -115,14 +115,12 @@ func (p *Parser) WriteHelp(w io.Writer) {
long: "help", long: "help",
short: "h", short: "h",
help: "display this help and exit", help: "display this help and exit",
root: -1,
}) })
if p.version != "" { if p.version != "" {
p.printOption(w, &spec{ p.printOption(w, &spec{
boolean: true, boolean: true,
long: "version", long: "version",
help: "display version and exit", help: "display version and exit",
root: -1,
}) })
} }
} }
@ -143,8 +141,8 @@ func (p *Parser) printOption(w io.Writer, spec *spec) {
} }
// If spec.dest is not the zero value then a default value has been added. // If spec.dest is not the zero value then a default value has been added.
var v reflect.Value var v reflect.Value
if spec.root >= 0 { if len(spec.dest.fields) > 0 {
v = p.readable(spec) v = p.readable(spec.dest)
} }
if v.IsValid() { if v.IsValid() {
z := reflect.Zero(v.Type()) z := reflect.Zero(v.Type())