all tests passing again

This commit is contained in:
Alex Flint 2019-04-14 19:50:17 -07:00
parent e86673b20a
commit e2dda40825
2 changed files with 208 additions and 115 deletions

158
parse.go
View File

@ -15,7 +15,9 @@ import (
// spec represents a command line option // spec represents a command line option
type spec struct { type spec struct {
dest reflect.Value root int
path []string // sequence of field names
typ reflect.Type
long string long string
short string short string
multiple bool multiple bool
@ -27,6 +29,13 @@ type spec struct {
boolean bool boolean bool
} }
// command represents a named subcommand, or the top-level command
type command struct {
name string
specs []*spec
subcommands []*command
}
// ErrHelp indicates that -h or --help were provided // ErrHelp indicates that -h or --help were provided
var ErrHelp = errors.New("help requested by user") var ErrHelp = errors.New("help requested by user")
@ -79,7 +88,8 @@ type Config struct {
// Parser represents a set of command line options with destination values // Parser represents a set of command line options with destination values
type Parser struct { type Parser struct {
specs []*spec cmd *command
roots []reflect.Value
config Config config Config
version string version string
description string description string
@ -102,41 +112,73 @@ type Described interface {
} }
// walkFields calls a function for each field of a struct, recursively expanding struct fields. // walkFields calls a function for each field of a struct, recursively expanding struct fields.
func walkFields(v reflect.Value, visit func(field reflect.StructField, val reflect.Value, owner reflect.Type) bool) { func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) {
t := v.Type()
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
field := t.Field(i) field := t.Field(i)
val := v.Field(i) expand := visit(field, t)
expand := visit(field, val, t)
if expand && field.Type.Kind() == reflect.Struct { if expand && field.Type.Kind() == reflect.Struct {
walkFields(val, visit) walkFields(field.Type, visit)
} }
} }
} }
// NewParser constructs a parser from a list of destination structs // NewParser constructs a parser from a list of destination structs
func NewParser(config Config, dests ...interface{}) (*Parser, error) { func NewParser(config Config, dests ...interface{}) (*Parser, error) {
// first pick a name for the command for use in the usage text
var name string
switch {
case config.Program != "":
name = config.Program
case len(os.Args) > 0:
name = filepath.Base(os.Args[0])
default:
name = "program"
}
// construct a parser
p := Parser{ p := Parser{
cmd: &command{name: name},
config: config, config: config,
} }
// make a list of roots
for _, dest := range dests { for _, dest := range dests {
p.roots = append(p.roots, reflect.ValueOf(dest))
}
// process each of the destination values
for i, dest := range dests {
t := reflect.TypeOf(dest)
if t.Kind() != reflect.Ptr {
panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t))
}
t = t.Elem()
cmd, err := cmdFromStruct(name, t, nil, i)
if err != nil {
return nil, err
}
p.cmd.specs = append(p.cmd.specs, cmd.specs...)
if dest, ok := dest.(Versioned); ok { if dest, ok := dest.(Versioned); ok {
p.version = dest.Version() p.version = dest.Version()
} }
if dest, ok := dest.(Described); ok { if dest, ok := dest.(Described); ok {
p.description = dest.Description() p.description = dest.Description()
} }
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", v.Type()))
}
v = v.Elem()
if v.Kind() != reflect.Struct {
panic(fmt.Sprintf("%T is not a struct pointer", dest))
} }
return &p, nil
}
func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*command, error) {
if t.Kind() != reflect.Struct {
panic(fmt.Sprintf("%v is not a struct pointer", t))
}
var cmd command
var errs []string var errs []string
walkFields(v, func(field reflect.StructField, val reflect.Value, t reflect.Type) bool { walkFields(t, func(field reflect.StructField, t reflect.Type) bool {
// Check for the ignore switch in the tag // Check for the ignore switch in the tag
tag := field.Tag.Get("arg") tag := field.Tag.Get("arg")
if tag == "-" { if tag == "-" {
@ -149,8 +191,10 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
} }
spec := spec{ spec := spec{
root: root,
path: append(path, field.Name),
long: strings.ToLower(field.Name), long: strings.ToLower(field.Name),
dest: val, typ: field.Type,
} }
help, exists := field.Tag.Lookup("help") help, exists := field.Tag.Lookup("help")
@ -207,13 +251,27 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
} else { } else {
spec.env = strings.ToUpper(field.Name) spec.env = strings.ToUpper(field.Name)
} }
case key == "subcommand":
// decide on a name for the subcommand
cmdname := value
if cmdname == "" {
cmdname = strings.ToLower(field.Name)
}
subcmd, err := cmdFromStruct(cmdname, field.Type, append(path, field.Name), root)
if err != nil {
errs = append(errs, err.Error())
return false
}
cmd.subcommands = append(cmd.subcommands, subcmd)
default: default:
errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag))
return false return false
} }
} }
} }
p.specs = append(p.specs, &spec) cmd.specs = append(cmd.specs, &spec)
// if this was an embedded field then we already returned true up above // if this was an embedded field then we already returned true up above
return false return false
@ -222,14 +280,8 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
if len(errs) > 0 { if len(errs) > 0 {
return nil, errors.New(strings.Join(errs, "\n")) return nil, errors.New(strings.Join(errs, "\n"))
} }
}
if p.config.Program == "" { return &cmd, nil
p.config.Program = "program"
if len(os.Args) > 0 {
p.config.Program = filepath.Base(os.Args[0])
}
}
return &p, nil
} }
// Parse processes the given command line option, storing the results in the field // Parse processes the given command line option, storing the results in the field
@ -249,12 +301,12 @@ func (p *Parser) Parse(args []string) error {
} }
// Process all command line arguments // Process all command line arguments
return process(p.specs, args) return p.process(p.cmd.specs, args)
} }
// process goes through arguments one-by-one, parses them, and assigns the result to // process goes through arguments one-by-one, parses them, and assigns the result to
// the underlying struct field // the underlying struct field
func process(specs []*spec, args []string) error { func (p *Parser) process(specs []*spec, args []string) error {
// track the options we have seen // track the options we have seen
wasPresent := make(map[*spec]bool) wasPresent := make(map[*spec]bool)
@ -294,7 +346,7 @@ func process(specs []*spec, args []string) error {
err, err,
) )
} }
if err = setSlice(spec.dest, values, !spec.separate); err != nil { if err = setSlice(p.settable(spec), values, !spec.separate); err != nil {
return fmt.Errorf( return fmt.Errorf(
"error processing environment variable %s with multiple values: %v", "error processing environment variable %s with multiple values: %v",
spec.env, spec.env,
@ -302,7 +354,7 @@ func process(specs []*spec, args []string) error {
) )
} }
} else { } else {
if err := scalar.ParseValue(spec.dest, value); err != nil { if err := scalar.ParseValue(p.settable(spec), value); err != nil {
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
} }
} }
@ -355,7 +407,7 @@ func process(specs []*spec, args []string) error {
} else { } else {
values = append(values, value) values = append(values, value)
} }
err := setSlice(spec.dest, values, !spec.separate) err := setSlice(p.settable(spec), values, !spec.separate)
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err) return fmt.Errorf("error processing %s: %v", arg, err)
} }
@ -373,14 +425,14 @@ func process(specs []*spec, args []string) error {
if i+1 == len(args) { if i+1 == len(args) {
return fmt.Errorf("missing value for %s", arg) return fmt.Errorf("missing value for %s", arg)
} }
if !nextIsNumeric(spec.dest.Type(), args[i+1]) && isFlag(args[i+1]) { if !nextIsNumeric(spec.typ, args[i+1]) && isFlag(args[i+1]) {
return fmt.Errorf("missing value for %s", arg) return fmt.Errorf("missing value for %s", arg)
} }
value = args[i+1] value = args[i+1]
i++ i++
} }
err := scalar.ParseValue(spec.dest, value) err := scalar.ParseValue(p.settable(spec), value)
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err) return fmt.Errorf("error processing %s: %v", arg, err)
} }
@ -396,13 +448,13 @@ func process(specs []*spec, args []string) error {
} }
wasPresent[spec] = true wasPresent[spec] = true
if spec.multiple { if spec.multiple {
err := setSlice(spec.dest, positionals, true) err := setSlice(p.settable(spec), positionals, true)
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err) return fmt.Errorf("error processing %s: %v", spec.long, err)
} }
positionals = nil positionals = nil
} else { } else {
err := scalar.ParseValue(spec.dest, positionals[0]) err := scalar.ParseValue(p.settable(spec), positionals[0])
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err) return fmt.Errorf("error processing %s: %v", spec.long, err)
} }
@ -445,6 +497,44 @@ func isFlag(s string) bool {
return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != "" return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != ""
} }
func (p *Parser) get(spec *spec) reflect.Value {
v := p.roots[spec.root]
for _, field := range spec.path {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return reflect.Value{}
}
v = v.Elem()
}
v = v.FieldByName(field)
if !v.IsValid() {
panic(fmt.Errorf("error resolving path %v: %v has no field named %v",
spec.path, v.Type(), field))
}
}
return v
}
func (p *Parser) settable(spec *spec) reflect.Value {
v := p.roots[spec.root]
for _, field := range spec.path {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
v = v.FieldByName(field)
if !v.IsValid() {
panic(fmt.Errorf("error resolving path %v: %v has no field named %v",
spec.path, v.Type(), field))
}
}
return v
}
// parse a value as the appropriate type and store it in the struct // parse a value as the appropriate type and store it in the struct
func setSlice(dest reflect.Value, values []string, trunc bool) error { func setSlice(dest reflect.Value, values []string, trunc bool) error {
if !dest.CanSet() { if !dest.CanSet() {

View File

@ -22,7 +22,7 @@ func (p *Parser) Fail(msg string) {
// WriteUsage writes usage information to the given writer // WriteUsage writes usage information to the given writer
func (p *Parser) WriteUsage(w io.Writer) { func (p *Parser) WriteUsage(w io.Writer) {
var positionals, options []*spec var positionals, options []*spec
for _, spec := range p.specs { for _, spec := range p.cmd.specs {
if spec.positional { if spec.positional {
positionals = append(positionals, spec) positionals = append(positionals, spec)
} else { } else {
@ -34,7 +34,7 @@ func (p *Parser) WriteUsage(w io.Writer) {
fmt.Fprintln(w, p.version) fmt.Fprintln(w, p.version)
} }
fmt.Fprintf(w, "Usage: %s", p.config.Program) fmt.Fprintf(w, "Usage: %s", p.cmd.name)
// write the option component of the usage message // write the option component of the usage message
for _, spec := range options { for _, spec := range options {
@ -72,7 +72,7 @@ func (p *Parser) WriteUsage(w io.Writer) {
// WriteHelp writes the usage string followed by the full help string for each option // WriteHelp writes the usage string followed by the full help string for each option
func (p *Parser) WriteHelp(w io.Writer) { func (p *Parser) WriteHelp(w io.Writer) {
var positionals, options []*spec var positionals, options []*spec
for _, spec := range p.specs { for _, spec := range p.cmd.specs {
if spec.positional { if spec.positional {
positionals = append(positionals, spec) positionals = append(positionals, spec)
} else { } else {
@ -106,17 +106,17 @@ func (p *Parser) WriteHelp(w io.Writer) {
// write the list of options // write the list of options
fmt.Fprint(w, "\nOptions:\n") fmt.Fprint(w, "\nOptions:\n")
for _, spec := range options { for _, spec := range options {
printOption(w, spec) p.printOption(w, spec)
} }
// write the list of built in options // write the list of built in options
printOption(w, &spec{boolean: true, long: "help", short: "h", help: "display this help and exit"}) p.printOption(w, &spec{boolean: true, long: "help", short: "h", help: "display this help and exit", root: -1})
if p.version != "" { if p.version != "" {
printOption(w, &spec{boolean: true, long: "version", help: "display version and exit"}) p.printOption(w, &spec{boolean: true, long: "version", help: "display version and exit", root: -1})
} }
} }
func printOption(w io.Writer, spec *spec) { func (p *Parser) printOption(w io.Writer, spec *spec) {
left := " " + synopsis(spec, "--"+spec.long) left := " " + synopsis(spec, "--"+spec.long)
if spec.short != "" { if spec.short != "" {
left += ", " + synopsis(spec, "-"+spec.short) left += ", " + synopsis(spec, "-"+spec.short)
@ -131,7 +131,10 @@ func printOption(w io.Writer, spec *spec) {
fmt.Fprint(w, spec.help) fmt.Fprint(w, spec.help)
} }
// If spec.dest is not the zero value then a default value has been added. // If spec.dest is not the zero value then a default value has been added.
v := spec.dest var v reflect.Value
if spec.root >= 0 {
v = p.get(spec)
}
if v.IsValid() { if v.IsValid() {
z := reflect.Zero(v.Type()) z := reflect.Zero(v.Type())
if (v.Type().Comparable() && z.Type().Comparable() && v.Interface() != z.Interface()) || v.Kind() == reflect.Slice && !v.IsNil() { if (v.Type().Comparable() && z.Type().Comparable() && v.Interface() != z.Interface()) || v.Kind() == reflect.Slice && !v.IsNil() {