From e2dda40825e8f3671cb207f6cc2f6e319404d57f Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Sun, 14 Apr 2019 19:50:17 -0700 Subject: [PATCH] all tests passing again --- parse.go | 304 +++++++++++++++++++++++++++++++++++-------------------- usage.go | 19 ++-- 2 files changed, 208 insertions(+), 115 deletions(-) diff --git a/parse.go b/parse.go index 32fc619..f4978ec 100644 --- a/parse.go +++ b/parse.go @@ -15,7 +15,9 @@ import ( // spec represents a command line option type spec struct { - dest reflect.Value + root int + path []string // sequence of field names + typ reflect.Type long string short string multiple bool @@ -27,6 +29,13 @@ type spec struct { boolean bool } +// command represents a named subcommand, or the top-level command +type command struct { + name string + specs []*spec + subcommands []*command +} + // ErrHelp indicates that -h or --help were provided var ErrHelp = errors.New("help requested by user") @@ -79,7 +88,8 @@ type Config struct { // Parser represents a set of command line options with destination values type Parser struct { - specs []*spec + cmd *command + roots []reflect.Value config Config version string description string @@ -102,134 +112,176 @@ type Described interface { } // walkFields calls a function for each field of a struct, recursively expanding struct fields. -func walkFields(v reflect.Value, visit func(field reflect.StructField, val reflect.Value, owner reflect.Type) bool) { - t := v.Type() +func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) { for i := 0; i < t.NumField(); i++ { field := t.Field(i) - val := v.Field(i) - expand := visit(field, val, t) + expand := visit(field, t) if expand && field.Type.Kind() == reflect.Struct { - walkFields(val, visit) + walkFields(field.Type, visit) } } } // NewParser constructs a parser from a list of destination structs func NewParser(config Config, dests ...interface{}) (*Parser, error) { + // first pick a name for the command for use in the usage text + var name string + switch { + case config.Program != "": + name = config.Program + case len(os.Args) > 0: + name = filepath.Base(os.Args[0]) + default: + name = "program" + } + + // construct a parser p := Parser{ + cmd: &command{name: name}, config: config, } + + // make a list of roots for _, dest := range dests { + p.roots = append(p.roots, reflect.ValueOf(dest)) + } + + // process each of the destination values + for i, dest := range dests { + t := reflect.TypeOf(dest) + if t.Kind() != reflect.Ptr { + panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t)) + } + t = t.Elem() + + cmd, err := cmdFromStruct(name, t, nil, i) + if err != nil { + return nil, err + } + p.cmd.specs = append(p.cmd.specs, cmd.specs...) + if dest, ok := dest.(Versioned); ok { p.version = dest.Version() } if dest, ok := dest.(Described); ok { p.description = dest.Description() } - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", v.Type())) - } - v = v.Elem() - if v.Kind() != reflect.Struct { - panic(fmt.Sprintf("%T is not a struct pointer", dest)) + } + + return &p, nil +} + +func cmdFromStruct(name string, t reflect.Type, path []string, root int) (*command, error) { + if t.Kind() != reflect.Struct { + panic(fmt.Sprintf("%v is not a struct pointer", t)) + } + + var cmd command + var errs []string + walkFields(t, func(field reflect.StructField, t reflect.Type) bool { + // Check for the ignore switch in the tag + tag := field.Tag.Get("arg") + if tag == "-" { + return false } - var errs []string - walkFields(v, func(field reflect.StructField, val reflect.Value, t reflect.Type) bool { - // Check for the ignore switch in the tag - tag := field.Tag.Get("arg") - if tag == "-" { - return false - } + // If this is an embedded struct then recurse into its fields + if field.Anonymous && field.Type.Kind() == reflect.Struct { + return true + } - // If this is an embedded struct then recurse into its fields - if field.Anonymous && field.Type.Kind() == reflect.Struct { - return true - } + spec := spec{ + root: root, + path: append(path, field.Name), + long: strings.ToLower(field.Name), + typ: field.Type, + } - spec := spec{ - long: strings.ToLower(field.Name), - dest: val, - } + help, exists := field.Tag.Lookup("help") + if exists { + spec.help = help + } - help, exists := field.Tag.Lookup("help") - if exists { - spec.help = help - } + // Check whether this field is supported. It's good to do this here rather than + // wait until ParseValue because it means that a program with invalid argument + // fields will always fail regardless of whether the arguments it received + // exercised those fields. + var parseable bool + parseable, spec.boolean, spec.multiple = canParse(field.Type) + if !parseable { + errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", + t.Name(), field.Name, field.Type.String())) + return false + } - // Check whether this field is supported. It's good to do this here rather than - // wait until ParseValue because it means that a program with invalid argument - // fields will always fail regardless of whether the arguments it received - // exercised those fields. - var parseable bool - parseable, spec.boolean, spec.multiple = canParse(field.Type) - if !parseable { - errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", - t.Name(), field.Name, field.Type.String())) - return false - } + // Look at the tag + if tag != "" { + for _, key := range strings.Split(tag, ",") { + key = strings.TrimLeft(key, " ") + var value string + if pos := strings.Index(key, ":"); pos != -1 { + value = key[pos+1:] + key = key[:pos] + } - // Look at the tag - if tag != "" { - for _, key := range strings.Split(tag, ",") { - key = strings.TrimLeft(key, " ") - var value string - if pos := strings.Index(key, ":"); pos != -1 { - value = key[pos+1:] - key = key[:pos] - } - - switch { - case strings.HasPrefix(key, "---"): - errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name)) - case strings.HasPrefix(key, "--"): - spec.long = key[2:] - case strings.HasPrefix(key, "-"): - if len(key) != 2 { - errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", - t.Name(), field.Name)) - return false - } - spec.short = key[1:] - case key == "required": - spec.required = true - case key == "positional": - spec.positional = true - case key == "separate": - spec.separate = true - case key == "help": // deprecated - spec.help = value - case key == "env": - // Use override name if provided - if value != "" { - spec.env = value - } else { - spec.env = strings.ToUpper(field.Name) - } - default: - errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) + switch { + case strings.HasPrefix(key, "---"): + errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name)) + case strings.HasPrefix(key, "--"): + spec.long = key[2:] + case strings.HasPrefix(key, "-"): + if len(key) != 2 { + errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", + t.Name(), field.Name)) return false } + spec.short = key[1:] + case key == "required": + spec.required = true + case key == "positional": + spec.positional = true + case key == "separate": + spec.separate = true + case key == "help": // deprecated + spec.help = value + case key == "env": + // Use override name if provided + if value != "" { + spec.env = value + } else { + spec.env = strings.ToUpper(field.Name) + } + case key == "subcommand": + // decide on a name for the subcommand + cmdname := value + if cmdname == "" { + cmdname = strings.ToLower(field.Name) + } + + subcmd, err := cmdFromStruct(cmdname, field.Type, append(path, field.Name), root) + if err != nil { + errs = append(errs, err.Error()) + return false + } + + cmd.subcommands = append(cmd.subcommands, subcmd) + default: + errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) + return false } } - p.specs = append(p.specs, &spec) - - // if this was an embedded field then we already returned true up above - return false - }) - - if len(errs) > 0 { - return nil, errors.New(strings.Join(errs, "\n")) } + cmd.specs = append(cmd.specs, &spec) + + // if this was an embedded field then we already returned true up above + return false + }) + + if len(errs) > 0 { + return nil, errors.New(strings.Join(errs, "\n")) } - if p.config.Program == "" { - p.config.Program = "program" - if len(os.Args) > 0 { - p.config.Program = filepath.Base(os.Args[0]) - } - } - return &p, nil + + return &cmd, nil } // Parse processes the given command line option, storing the results in the field @@ -249,12 +301,12 @@ func (p *Parser) Parse(args []string) error { } // Process all command line arguments - return process(p.specs, args) + return p.process(p.cmd.specs, args) } // process goes through arguments one-by-one, parses them, and assigns the result to // the underlying struct field -func process(specs []*spec, args []string) error { +func (p *Parser) process(specs []*spec, args []string) error { // track the options we have seen wasPresent := make(map[*spec]bool) @@ -294,7 +346,7 @@ func process(specs []*spec, args []string) error { err, ) } - if err = setSlice(spec.dest, values, !spec.separate); err != nil { + if err = setSlice(p.settable(spec), values, !spec.separate); err != nil { return fmt.Errorf( "error processing environment variable %s with multiple values: %v", spec.env, @@ -302,7 +354,7 @@ func process(specs []*spec, args []string) error { ) } } else { - if err := scalar.ParseValue(spec.dest, value); err != nil { + if err := scalar.ParseValue(p.settable(spec), value); err != nil { return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) } } @@ -355,7 +407,7 @@ func process(specs []*spec, args []string) error { } else { values = append(values, value) } - err := setSlice(spec.dest, values, !spec.separate) + err := setSlice(p.settable(spec), values, !spec.separate) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -373,14 +425,14 @@ func process(specs []*spec, args []string) error { if i+1 == len(args) { return fmt.Errorf("missing value for %s", arg) } - if !nextIsNumeric(spec.dest.Type(), args[i+1]) && isFlag(args[i+1]) { + if !nextIsNumeric(spec.typ, args[i+1]) && isFlag(args[i+1]) { return fmt.Errorf("missing value for %s", arg) } value = args[i+1] i++ } - err := scalar.ParseValue(spec.dest, value) + err := scalar.ParseValue(p.settable(spec), value) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -396,13 +448,13 @@ func process(specs []*spec, args []string) error { } wasPresent[spec] = true if spec.multiple { - err := setSlice(spec.dest, positionals, true) + err := setSlice(p.settable(spec), positionals, true) if err != nil { return fmt.Errorf("error processing %s: %v", spec.long, err) } positionals = nil } else { - err := scalar.ParseValue(spec.dest, positionals[0]) + err := scalar.ParseValue(p.settable(spec), positionals[0]) if err != nil { return fmt.Errorf("error processing %s: %v", spec.long, err) } @@ -445,6 +497,44 @@ func isFlag(s string) bool { return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != "" } +func (p *Parser) get(spec *spec) reflect.Value { + v := p.roots[spec.root] + for _, field := range spec.path { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return reflect.Value{} + } + v = v.Elem() + } + + v = v.FieldByName(field) + if !v.IsValid() { + panic(fmt.Errorf("error resolving path %v: %v has no field named %v", + spec.path, v.Type(), field)) + } + } + return v +} + +func (p *Parser) settable(spec *spec) reflect.Value { + v := p.roots[spec.root] + for _, field := range spec.path { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + + v = v.FieldByName(field) + if !v.IsValid() { + panic(fmt.Errorf("error resolving path %v: %v has no field named %v", + spec.path, v.Type(), field)) + } + } + return v +} + // parse a value as the appropriate type and store it in the struct func setSlice(dest reflect.Value, values []string, trunc bool) error { if !dest.CanSet() { diff --git a/usage.go b/usage.go index cfac563..f9c1a76 100644 --- a/usage.go +++ b/usage.go @@ -22,7 +22,7 @@ func (p *Parser) Fail(msg string) { // WriteUsage writes usage information to the given writer func (p *Parser) WriteUsage(w io.Writer) { var positionals, options []*spec - for _, spec := range p.specs { + for _, spec := range p.cmd.specs { if spec.positional { positionals = append(positionals, spec) } else { @@ -34,7 +34,7 @@ func (p *Parser) WriteUsage(w io.Writer) { fmt.Fprintln(w, p.version) } - fmt.Fprintf(w, "Usage: %s", p.config.Program) + fmt.Fprintf(w, "Usage: %s", p.cmd.name) // write the option component of the usage message for _, spec := range options { @@ -72,7 +72,7 @@ func (p *Parser) WriteUsage(w io.Writer) { // WriteHelp writes the usage string followed by the full help string for each option func (p *Parser) WriteHelp(w io.Writer) { var positionals, options []*spec - for _, spec := range p.specs { + for _, spec := range p.cmd.specs { if spec.positional { positionals = append(positionals, spec) } else { @@ -106,17 +106,17 @@ func (p *Parser) WriteHelp(w io.Writer) { // write the list of options fmt.Fprint(w, "\nOptions:\n") for _, spec := range options { - printOption(w, spec) + p.printOption(w, spec) } // write the list of built in options - printOption(w, &spec{boolean: true, long: "help", short: "h", help: "display this help and exit"}) + p.printOption(w, &spec{boolean: true, long: "help", short: "h", help: "display this help and exit", root: -1}) if p.version != "" { - printOption(w, &spec{boolean: true, long: "version", help: "display version and exit"}) + p.printOption(w, &spec{boolean: true, long: "version", help: "display version and exit", root: -1}) } } -func printOption(w io.Writer, spec *spec) { +func (p *Parser) printOption(w io.Writer, spec *spec) { left := " " + synopsis(spec, "--"+spec.long) if spec.short != "" { left += ", " + synopsis(spec, "-"+spec.short) @@ -131,7 +131,10 @@ func printOption(w io.Writer, spec *spec) { fmt.Fprint(w, spec.help) } // If spec.dest is not the zero value then a default value has been added. - v := spec.dest + var v reflect.Value + if spec.root >= 0 { + v = p.get(spec) + } if v.IsValid() { z := reflect.Zero(v.Type()) if (v.Type().Comparable() && z.Type().Comparable() && v.Interface() != z.Interface()) || v.Kind() == reflect.Slice && !v.IsNil() {