refactor validation

This commit is contained in:
Alex Flint 2019-04-14 18:00:40 -07:00
parent 7b1d9ef23f
commit b8678d4045
3 changed files with 79 additions and 56 deletions

View File

@ -24,7 +24,6 @@ type spec struct {
separate bool separate bool
help string help string
env string env string
wasPresent bool
boolean bool boolean bool
} }
@ -80,7 +79,7 @@ type Config struct {
// Parser represents a set of command line options with destination values // Parser represents a set of command line options with destination values
type Parser struct { type Parser struct {
spec []*spec specs []*spec
config Config config Config
version string version string
description string description string
@ -214,7 +213,7 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
} }
} }
} }
p.spec = append(p.spec, &spec) p.specs = append(p.specs, &spec)
// if this was an embedded field then we already returned true up above // if this was an embedded field then we already returned true up above
return false return false
@ -250,21 +249,18 @@ func (p *Parser) Parse(args []string) error {
} }
// Process all command line arguments // Process all command line arguments
err := process(p.spec, args) return p.process(args)
if err != nil {
return err
}
// Validate
return validate(p.spec)
} }
// process goes through arguments one-by-one, parses them, and assigns the result to // process goes through arguments one-by-one, parses them, and assigns the result to
// the underlying struct field // the underlying struct field
func process(specs []*spec, args []string) error { func (p *Parser) process(args []string) error {
// track the options we have seen
wasPresent := make(map[*spec]bool)
// construct a map from --option to spec // construct a map from --option to spec
optionMap := make(map[string]*spec) optionMap := make(map[string]*spec)
for _, spec := range specs { for _, spec := range p.specs {
if spec.positional { if spec.positional {
continue continue
} }
@ -274,8 +270,19 @@ func process(specs []*spec, args []string) error {
if spec.short != "" { if spec.short != "" {
optionMap[spec.short] = spec optionMap[spec.short] = spec
} }
if spec.env != "" { }
if value, found := os.LookupEnv(spec.env); found {
// deal with environment vars
for _, spec := range p.specs {
if spec.env == "" {
continue
}
value, found := os.LookupEnv(spec.env)
if !found {
continue
}
if spec.multiple { if spec.multiple {
// expect a CSV string in an environment // expect a CSV string in an environment
// variable in the case of multiple values // variable in the case of multiple values
@ -299,9 +306,7 @@ func process(specs []*spec, args []string) error {
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
} }
} }
spec.wasPresent = true wasPresent[spec] = true
}
}
} }
// process each string from the command line // process each string from the command line
@ -334,7 +339,7 @@ func process(specs []*spec, args []string) error {
if !ok { if !ok {
return fmt.Errorf("unknown argument %s", arg) return fmt.Errorf("unknown argument %s", arg)
} }
spec.wasPresent = true wasPresent[spec] = true
// deal with the case of multiple values // deal with the case of multiple values
if spec.multiple { if spec.multiple {
@ -382,20 +387,21 @@ func process(specs []*spec, args []string) error {
} }
// process positionals // process positionals
for _, spec := range specs { for _, spec := range p.specs {
if !spec.positional { if !spec.positional {
continue continue
} }
if spec.required && len(positionals) == 0 { if len(positionals) == 0 {
return fmt.Errorf("%s is required", spec.long) break
} }
wasPresent[spec] = true
if spec.multiple { if spec.multiple {
err := setSlice(spec.dest, positionals, true) err := setSlice(spec.dest, positionals, true)
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err) return fmt.Errorf("error processing %s: %v", spec.long, err)
} }
positionals = nil positionals = nil
} else if len(positionals) > 0 { } else {
err := scalar.ParseValue(spec.dest, positionals[0]) err := scalar.ParseValue(spec.dest, positionals[0])
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err) return fmt.Errorf("error processing %s: %v", spec.long, err)
@ -406,6 +412,18 @@ func process(specs []*spec, args []string) error {
if len(positionals) > 0 { if len(positionals) > 0 {
return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
} }
// finally check that all the required args were provided
for _, spec := range p.specs {
if spec.required && !wasPresent[spec] {
name := spec.long
if !spec.positional {
name = "--" + spec.long
}
return fmt.Errorf("%s is required", name)
}
}
return nil return nil
} }
@ -427,16 +445,6 @@ func isFlag(s string) bool {
return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != "" return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != ""
} }
// validate an argument spec after arguments have been parse
func validate(spec []*spec) error {
for _, arg := range spec {
if !arg.positional && arg.required && !arg.wasPresent {
return fmt.Errorf("--%s is required", arg.long)
}
}
return nil
}
// parse a value as the appropriate type and store it in the struct // parse a value as the appropriate type and store it in the struct
func setSlice(dest reflect.Value, values []string, trunc bool) error { func setSlice(dest reflect.Value, values []string, trunc bool) error {
if !dest.CanSet() { if !dest.CanSet() {

View File

@ -969,3 +969,18 @@ func TestSpacesAllowedInTags(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo) assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo)
} }
func TestReuseParser(t *testing.T) {
var args struct {
Foo string `arg:"required"`
}
p, err := NewParser(Config{}, &args)
require.NoError(t, err)
err = p.Parse([]string{"--foo=abc"})
assert.Equal(t, args.Foo, "abc")
err = p.Parse([]string{})
assert.Error(t, err)
}

View File

@ -1,12 +1,12 @@
package arg package arg
import ( import (
"encoding"
"fmt" "fmt"
"io" "io"
"os" "os"
"reflect" "reflect"
"strings" "strings"
"encoding"
) )
// the width of the left column // the width of the left column
@ -22,7 +22,7 @@ func (p *Parser) Fail(msg string) {
// WriteUsage writes usage information to the given writer // WriteUsage writes usage information to the given writer
func (p *Parser) WriteUsage(w io.Writer) { func (p *Parser) WriteUsage(w io.Writer) {
var positionals, options []*spec var positionals, options []*spec
for _, spec := range p.spec { for _, spec := range p.specs {
if spec.positional { if spec.positional {
positionals = append(positionals, spec) positionals = append(positionals, spec)
} else { } else {
@ -72,7 +72,7 @@ func (p *Parser) WriteUsage(w io.Writer) {
// WriteHelp writes the usage string followed by the full help string for each option // WriteHelp writes the usage string followed by the full help string for each option
func (p *Parser) WriteHelp(w io.Writer) { func (p *Parser) WriteHelp(w io.Writer) {
var positionals, options []*spec var positionals, options []*spec
for _, spec := range p.spec { for _, spec := range p.specs {
if spec.positional { if spec.positional {
positionals = append(positionals, spec) positionals = append(positionals, spec)
} else { } else {