introduced path struct
This commit is contained in:
parent
6a796e2c41
commit
af12b7cfc2
124
parse.go
124
parse.go
|
@ -13,10 +13,32 @@ import (
|
||||||
scalar "github.com/alexflint/go-scalar"
|
scalar "github.com/alexflint/go-scalar"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// path represents a sequence of steps to find the output location for an
|
||||||
|
// argument or subcommand in the final destination struct
|
||||||
|
type path struct {
|
||||||
|
root int // index of the destination struct
|
||||||
|
fields []string // sequence of struct field names to traverse
|
||||||
|
}
|
||||||
|
|
||||||
|
// String gets a string representation of the given path
|
||||||
|
func (p path) String() string {
|
||||||
|
return "args." + strings.Join(p.fields, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Child gets a new path representing a child of this path.
|
||||||
|
func (p path) Child(child string) path {
|
||||||
|
// copy the entire slice of fields to avoid possible slice overwrite
|
||||||
|
subfields := make([]string, len(p.fields)+1)
|
||||||
|
copy(subfields, append(p.fields, child))
|
||||||
|
return path{
|
||||||
|
root: p.root,
|
||||||
|
fields: subfields,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// spec represents a command line option
|
// spec represents a command line option
|
||||||
type spec struct {
|
type spec struct {
|
||||||
root int
|
dest path
|
||||||
path []string // sequence of field names
|
|
||||||
typ reflect.Type
|
typ reflect.Type
|
||||||
long string
|
long string
|
||||||
short string
|
short string
|
||||||
|
@ -32,6 +54,7 @@ type spec struct {
|
||||||
// command represents a named subcommand, or the top-level command
|
// command represents a named subcommand, or the top-level command
|
||||||
type command struct {
|
type command struct {
|
||||||
name string
|
name string
|
||||||
|
dest path
|
||||||
specs []*spec
|
specs []*spec
|
||||||
subcommands []*command
|
subcommands []*command
|
||||||
}
|
}
|
||||||
|
@ -153,11 +176,12 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
|
||||||
panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t))
|
panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t))
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd, err := cmdFromStruct(name, t, nil, i)
|
cmd, err := cmdFromStruct(name, path{root: i}, t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
p.cmd.specs = append(p.cmd.specs, cmd.specs...)
|
p.cmd.specs = append(p.cmd.specs, cmd.specs...)
|
||||||
|
p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...)
|
||||||
|
|
||||||
if dest, ok := dest.(Versioned); ok {
|
if dest, ok := dest.(Versioned); ok {
|
||||||
p.version = dest.Version()
|
p.version = dest.Version()
|
||||||
|
@ -170,20 +194,24 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
|
||||||
return &p, nil
|
return &p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*command, error) {
|
func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
|
||||||
// commands can only be created from pointers to structs
|
// commands can only be created from pointers to structs
|
||||||
if t.Kind() != reflect.Ptr {
|
if t.Kind() != reflect.Ptr {
|
||||||
return nil, fmt.Errorf("subcommands must be pointers to structs but args.%s is a %s",
|
return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s",
|
||||||
strings.Join(path, "."), t.Kind())
|
dest, t.Kind())
|
||||||
}
|
}
|
||||||
|
|
||||||
t = t.Elem()
|
t = t.Elem()
|
||||||
if t.Kind() != reflect.Struct {
|
if t.Kind() != reflect.Struct {
|
||||||
return nil, fmt.Errorf("subcommands must be pointers to structs but args.%s is a pointer to %s",
|
return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a pointer to %s",
|
||||||
strings.Join(path, "."), t.Kind())
|
dest, t.Kind())
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := command{
|
||||||
|
name: name,
|
||||||
|
dest: dest,
|
||||||
}
|
}
|
||||||
|
|
||||||
var cmd command
|
|
||||||
var errs []string
|
var errs []string
|
||||||
walkFields(t, func(field reflect.StructField, t reflect.Type) bool {
|
walkFields(t, func(field reflect.StructField, t reflect.Type) bool {
|
||||||
// Check for the ignore switch in the tag
|
// Check for the ignore switch in the tag
|
||||||
|
@ -198,12 +226,9 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma
|
||||||
}
|
}
|
||||||
|
|
||||||
// duplicate the entire path to avoid slice overwrites
|
// duplicate the entire path to avoid slice overwrites
|
||||||
subpath := make([]string, len(path)+1)
|
subdest := dest.Child(field.Name)
|
||||||
copy(subpath, append(path, field.Name))
|
|
||||||
|
|
||||||
spec := spec{
|
spec := spec{
|
||||||
root: root,
|
dest: subdest,
|
||||||
path: subpath,
|
|
||||||
long: strings.ToLower(field.Name),
|
long: strings.ToLower(field.Name),
|
||||||
typ: field.Type,
|
typ: field.Type,
|
||||||
}
|
}
|
||||||
|
@ -213,19 +238,8 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma
|
||||||
spec.help = help
|
spec.help = help
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check whether this field is supported. It's good to do this here rather than
|
|
||||||
// wait until ParseValue because it means that a program with invalid argument
|
|
||||||
// fields will always fail regardless of whether the arguments it received
|
|
||||||
// exercised those fields.
|
|
||||||
var parseable bool
|
|
||||||
parseable, spec.boolean, spec.multiple = canParse(field.Type)
|
|
||||||
if !parseable {
|
|
||||||
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
|
|
||||||
t.Name(), field.Name, field.Type.String()))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look at the tag
|
// Look at the tag
|
||||||
|
var isSubcommand bool // tracks whether this field is a subcommand
|
||||||
if tag != "" {
|
if tag != "" {
|
||||||
for _, key := range strings.Split(tag, ",") {
|
for _, key := range strings.Split(tag, ",") {
|
||||||
key = strings.TrimLeft(key, " ")
|
key = strings.TrimLeft(key, " ")
|
||||||
|
@ -269,20 +283,37 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma
|
||||||
cmdname = strings.ToLower(field.Name)
|
cmdname = strings.ToLower(field.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
subcmd, err := cmdFromStruct(cmdname, field.Type, subpath, root)
|
subcmd, err := cmdFromStruct(cmdname, subdest, field.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs = append(errs, err.Error())
|
errs = append(errs, err.Error())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.subcommands = append(cmd.subcommands, subcmd)
|
cmd.subcommands = append(cmd.subcommands, subcmd)
|
||||||
|
isSubcommand = true
|
||||||
|
fmt.Println("found a subcommand")
|
||||||
default:
|
default:
|
||||||
errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag))
|
errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag))
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cmd.specs = append(cmd.specs, &spec)
|
|
||||||
|
// Check whether this field is supported. It's good to do this here rather than
|
||||||
|
// wait until ParseValue because it means that a program with invalid argument
|
||||||
|
// fields will always fail regardless of whether the arguments it received
|
||||||
|
// exercised those fields.
|
||||||
|
if !isSubcommand {
|
||||||
|
cmd.specs = append(cmd.specs, &spec)
|
||||||
|
|
||||||
|
var parseable bool
|
||||||
|
parseable, spec.boolean, spec.multiple = canParse(field.Type)
|
||||||
|
if !parseable {
|
||||||
|
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
|
||||||
|
t.Name(), field.Name, field.Type.String()))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// if this was an embedded field then we already returned true up above
|
// if this was an embedded field then we already returned true up above
|
||||||
return false
|
return false
|
||||||
|
@ -303,6 +334,8 @@ func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*comma
|
||||||
return nil, fmt.Errorf("%T cannot have both subcommands and positional arguments", t)
|
return nil, fmt.Errorf("%T cannot have both subcommands and positional arguments", t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Printf("parsed a command with %d subcommands\n", len(cmd.subcommands))
|
||||||
|
|
||||||
return &cmd, nil
|
return &cmd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -349,7 +382,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if err = setSlice(p.writable(spec), values, !spec.separate); err != nil {
|
if err = setSlice(p.writable(spec.dest), values, !spec.separate); err != nil {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"error processing environment variable %s with multiple values: %v",
|
"error processing environment variable %s with multiple values: %v",
|
||||||
spec.env,
|
spec.env,
|
||||||
|
@ -357,7 +390,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := scalar.ParseValue(p.writable(spec), value); err != nil {
|
if err := scalar.ParseValue(p.writable(spec.dest), value); err != nil {
|
||||||
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
|
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -400,6 +433,7 @@ func (p *Parser) process(args []string) error {
|
||||||
|
|
||||||
if !isFlag(arg) || allpositional {
|
if !isFlag(arg) || allpositional {
|
||||||
// each subcommand can have either subcommands or positionals, but not both
|
// each subcommand can have either subcommands or positionals, but not both
|
||||||
|
fmt.Printf("processing %q, with %d subcommands", arg, len(curCmd.subcommands))
|
||||||
if len(curCmd.subcommands) == 0 {
|
if len(curCmd.subcommands) == 0 {
|
||||||
positionals = append(positionals, arg)
|
positionals = append(positionals, arg)
|
||||||
continue
|
continue
|
||||||
|
@ -454,7 +488,7 @@ func (p *Parser) process(args []string) error {
|
||||||
} else {
|
} else {
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
err := setSlice(p.writable(spec), values, !spec.separate)
|
err := setSlice(p.writable(spec.dest), values, !spec.separate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error processing %s: %v", arg, err)
|
return fmt.Errorf("error processing %s: %v", arg, err)
|
||||||
}
|
}
|
||||||
|
@ -479,7 +513,7 @@ func (p *Parser) process(args []string) error {
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
|
||||||
err := scalar.ParseValue(p.writable(spec), value)
|
err := scalar.ParseValue(p.writable(spec.dest), value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error processing %s: %v", arg, err)
|
return fmt.Errorf("error processing %s: %v", arg, err)
|
||||||
}
|
}
|
||||||
|
@ -495,13 +529,13 @@ func (p *Parser) process(args []string) error {
|
||||||
}
|
}
|
||||||
wasPresent[spec] = true
|
wasPresent[spec] = true
|
||||||
if spec.multiple {
|
if spec.multiple {
|
||||||
err := setSlice(p.writable(spec), positionals, true)
|
err := setSlice(p.writable(spec.dest), positionals, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error processing %s: %v", spec.long, err)
|
return fmt.Errorf("error processing %s: %v", spec.long, err)
|
||||||
}
|
}
|
||||||
positionals = nil
|
positionals = nil
|
||||||
} else {
|
} else {
|
||||||
err := scalar.ParseValue(p.writable(spec), positionals[0])
|
err := scalar.ParseValue(p.writable(spec.dest), positionals[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error processing %s: %v", spec.long, err)
|
return fmt.Errorf("error processing %s: %v", spec.long, err)
|
||||||
}
|
}
|
||||||
|
@ -546,9 +580,9 @@ func isFlag(s string) bool {
|
||||||
|
|
||||||
// readable returns a reflect.Value corresponding to the current value for the
|
// readable returns a reflect.Value corresponding to the current value for the
|
||||||
// given
|
// given
|
||||||
func (p *Parser) readable(spec *spec) reflect.Value {
|
func (p *Parser) readable(dest path) reflect.Value {
|
||||||
v := p.roots[spec.root]
|
v := p.roots[dest.root]
|
||||||
for _, field := range spec.path {
|
for _, field := range dest.fields {
|
||||||
if v.Kind() == reflect.Ptr {
|
if v.Kind() == reflect.Ptr {
|
||||||
if v.IsNil() {
|
if v.IsNil() {
|
||||||
return reflect.Value{}
|
return reflect.Value{}
|
||||||
|
@ -559,21 +593,21 @@ func (p *Parser) readable(spec *spec) reflect.Value {
|
||||||
v = v.FieldByName(field)
|
v = v.FieldByName(field)
|
||||||
if !v.IsValid() {
|
if !v.IsValid() {
|
||||||
// it is appropriate to panic here because this can only happen due to
|
// it is appropriate to panic here because this can only happen due to
|
||||||
// an internal bug in this library (since we construct spec.path ourselves
|
// an internal bug in this library (since we construct the path ourselves
|
||||||
// by reflecting on the same struct)
|
// by reflecting on the same struct)
|
||||||
panic(fmt.Errorf("error resolving path %v: %v has no field named %v",
|
panic(fmt.Errorf("error resolving path %v: %v has no field named %v",
|
||||||
spec.path, v.Type(), field))
|
dest.fields, v.Type(), field))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
// writable traverses the destination struct to find the destination to
|
// writable trav.patherses the destination struct to find the destination to
|
||||||
// which the value of the given spec should be written. It fills in null
|
// which the value of the given spec should be written. It fills in null
|
||||||
// structs with pointers to the zero value for that struct.
|
// structs with pointers to the zero value for that struct.
|
||||||
func (p *Parser) writable(spec *spec) reflect.Value {
|
func (p *Parser) writable(dest path) reflect.Value {
|
||||||
v := p.roots[spec.root]
|
v := p.roots[dest.root]
|
||||||
for _, field := range spec.path {
|
for _, field := range dest.fields {
|
||||||
if v.Kind() == reflect.Ptr {
|
if v.Kind() == reflect.Ptr {
|
||||||
if v.IsNil() {
|
if v.IsNil() {
|
||||||
v.Set(reflect.New(v.Type().Elem()))
|
v.Set(reflect.New(v.Type().Elem()))
|
||||||
|
@ -584,10 +618,10 @@ func (p *Parser) writable(spec *spec) reflect.Value {
|
||||||
v = v.FieldByName(field)
|
v = v.FieldByName(field)
|
||||||
if !v.IsValid() {
|
if !v.IsValid() {
|
||||||
// it is appropriate to panic here because this can only happen due to
|
// it is appropriate to panic here because this can only happen due to
|
||||||
// an internal bug in this library (since we construct spec.path ourselves
|
// an internal bug in this library (since we construct the path ourselves
|
||||||
// by reflecting on the same struct)
|
// by reflecting on the same struct)
|
||||||
panic(fmt.Errorf("error resolving path %v: %v has no field named %v",
|
panic(fmt.Errorf("error resolving path %v: %v has no field named %v",
|
||||||
spec.path, v.Type(), field))
|
dest.fields, v.Type(), field))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return v
|
return v
|
||||||
|
|
|
@ -4,12 +4,13 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This file contains tests for parse.go but I decided to put them here
|
// This file contains tests for parse.go but I decided to put them here
|
||||||
// since that file is getting large
|
// since that file is getting large
|
||||||
|
|
||||||
func TestSubcommandNotAStruct(t *testing.T) {
|
func TestSubcommandNotAPointer(t *testing.T) {
|
||||||
var args struct {
|
var args struct {
|
||||||
A string `arg:"subcommand"`
|
A string `arg:"subcommand"`
|
||||||
}
|
}
|
||||||
|
@ -17,6 +18,14 @@ func TestSubcommandNotAStruct(t *testing.T) {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSubcommandNotAPointerToStruct(t *testing.T) {
|
||||||
|
var args struct {
|
||||||
|
A struct{} `arg:"subcommand"`
|
||||||
|
}
|
||||||
|
_, err := NewParser(Config{}, &args)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPositionalAndSubcommandNotAllowed(t *testing.T) {
|
func TestPositionalAndSubcommandNotAllowed(t *testing.T) {
|
||||||
var args struct {
|
var args struct {
|
||||||
A string `arg:"positional"`
|
A string `arg:"positional"`
|
||||||
|
@ -25,3 +34,14 @@ func TestPositionalAndSubcommandNotAllowed(t *testing.T) {
|
||||||
_, err := NewParser(Config{}, &args)
|
_, err := NewParser(Config{}, &args)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMinimalSubcommand(t *testing.T) {
|
||||||
|
type listCmd struct {
|
||||||
|
}
|
||||||
|
var args struct {
|
||||||
|
List *listCmd `arg:"subcommand"`
|
||||||
|
}
|
||||||
|
err := parse("list", &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, args.List)
|
||||||
|
}
|
||||||
|
|
6
usage.go
6
usage.go
|
@ -115,14 +115,12 @@ func (p *Parser) WriteHelp(w io.Writer) {
|
||||||
long: "help",
|
long: "help",
|
||||||
short: "h",
|
short: "h",
|
||||||
help: "display this help and exit",
|
help: "display this help and exit",
|
||||||
root: -1,
|
|
||||||
})
|
})
|
||||||
if p.version != "" {
|
if p.version != "" {
|
||||||
p.printOption(w, &spec{
|
p.printOption(w, &spec{
|
||||||
boolean: true,
|
boolean: true,
|
||||||
long: "version",
|
long: "version",
|
||||||
help: "display version and exit",
|
help: "display version and exit",
|
||||||
root: -1,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -143,8 +141,8 @@ func (p *Parser) printOption(w io.Writer, spec *spec) {
|
||||||
}
|
}
|
||||||
// If spec.dest is not the zero value then a default value has been added.
|
// If spec.dest is not the zero value then a default value has been added.
|
||||||
var v reflect.Value
|
var v reflect.Value
|
||||||
if spec.root >= 0 {
|
if len(spec.dest.fields) > 0 {
|
||||||
v = p.readable(spec)
|
v = p.readable(spec.dest)
|
||||||
}
|
}
|
||||||
if v.IsValid() {
|
if v.IsValid() {
|
||||||
z := reflect.Zero(v.Type())
|
z := reflect.Zero(v.Type())
|
||||||
|
|
Loading…
Reference in New Issue