From b8678d404568d6df96cf390eab226a2ebc04d208 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Sun, 14 Apr 2019 18:00:40 -0700 Subject: [PATCH] refactor validation --- parse.go | 114 +++++++++++++++++++++++++++----------------------- parse_test.go | 15 +++++++ usage.go | 6 +-- 3 files changed, 79 insertions(+), 56 deletions(-) diff --git a/parse.go b/parse.go index c4afda2..e1d1b29 100644 --- a/parse.go +++ b/parse.go @@ -24,7 +24,6 @@ type spec struct { separate bool help string env string - wasPresent bool boolean bool } @@ -80,7 +79,7 @@ type Config struct { // Parser represents a set of command line options with destination values type Parser struct { - spec []*spec + specs []*spec config Config version string description string @@ -214,7 +213,7 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { } } } - p.spec = append(p.spec, &spec) + p.specs = append(p.specs, &spec) // if this was an embedded field then we already returned true up above return false @@ -250,21 +249,18 @@ func (p *Parser) Parse(args []string) error { } // Process all command line arguments - err := process(p.spec, args) - if err != nil { - return err - } - - // Validate - return validate(p.spec) + return p.process(args) } // process goes through arguments one-by-one, parses them, and assigns the result to // the underlying struct field -func process(specs []*spec, args []string) error { +func (p *Parser) process(args []string) error { + // track the options we have seen + wasPresent := make(map[*spec]bool) + // construct a map from --option to spec optionMap := make(map[string]*spec) - for _, spec := range specs { + for _, spec := range p.specs { if spec.positional { continue } @@ -274,34 +270,43 @@ func process(specs []*spec, args []string) error { if spec.short != "" { optionMap[spec.short] = spec } - if spec.env != "" { - if value, found := os.LookupEnv(spec.env); found { - if spec.multiple { - // expect a CSV string in an environment - // variable in the case of multiple values - values, err := csv.NewReader(strings.NewReader(value)).Read() - if err != nil { - return fmt.Errorf( - "error reading a CSV string from environment variable %s with multiple values: %v", - spec.env, - err, - ) - } - if err = setSlice(spec.dest, values, !spec.separate); err != nil { - return fmt.Errorf( - "error processing environment variable %s with multiple values: %v", - spec.env, - err, - ) - } - } else { - if err := scalar.ParseValue(spec.dest, value); err != nil { - return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) - } - } - spec.wasPresent = true + } + + // deal with environment vars + for _, spec := range p.specs { + if spec.env == "" { + continue + } + + value, found := os.LookupEnv(spec.env) + if !found { + continue + } + + if spec.multiple { + // expect a CSV string in an environment + // variable in the case of multiple values + values, err := csv.NewReader(strings.NewReader(value)).Read() + if err != nil { + return fmt.Errorf( + "error reading a CSV string from environment variable %s with multiple values: %v", + spec.env, + err, + ) + } + if err = setSlice(spec.dest, values, !spec.separate); err != nil { + return fmt.Errorf( + "error processing environment variable %s with multiple values: %v", + spec.env, + err, + ) + } + } else { + if err := scalar.ParseValue(spec.dest, value); err != nil { + return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) } } + wasPresent[spec] = true } // process each string from the command line @@ -334,7 +339,7 @@ func process(specs []*spec, args []string) error { if !ok { return fmt.Errorf("unknown argument %s", arg) } - spec.wasPresent = true + wasPresent[spec] = true // deal with the case of multiple values if spec.multiple { @@ -382,20 +387,21 @@ func process(specs []*spec, args []string) error { } // process positionals - for _, spec := range specs { + for _, spec := range p.specs { if !spec.positional { continue } - if spec.required && len(positionals) == 0 { - return fmt.Errorf("%s is required", spec.long) + if len(positionals) == 0 { + break } + wasPresent[spec] = true if spec.multiple { err := setSlice(spec.dest, positionals, true) if err != nil { return fmt.Errorf("error processing %s: %v", spec.long, err) } positionals = nil - } else if len(positionals) > 0 { + } else { err := scalar.ParseValue(spec.dest, positionals[0]) if err != nil { return fmt.Errorf("error processing %s: %v", spec.long, err) @@ -406,6 +412,18 @@ func process(specs []*spec, args []string) error { if len(positionals) > 0 { return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) } + + // finally check that all the required args were provided + for _, spec := range p.specs { + if spec.required && !wasPresent[spec] { + name := spec.long + if !spec.positional { + name = "--" + spec.long + } + return fmt.Errorf("%s is required", name) + } + } + return nil } @@ -427,16 +445,6 @@ func isFlag(s string) bool { return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != "" } -// validate an argument spec after arguments have been parse -func validate(spec []*spec) error { - for _, arg := range spec { - if !arg.positional && arg.required && !arg.wasPresent { - return fmt.Errorf("--%s is required", arg.long) - } - } - return nil -} - // parse a value as the appropriate type and store it in the struct func setSlice(dest reflect.Value, values []string, trunc bool) error { if !dest.CanSet() { diff --git a/parse_test.go b/parse_test.go index 2e438aa..81cd2c3 100644 --- a/parse_test.go +++ b/parse_test.go @@ -969,3 +969,18 @@ func TestSpacesAllowedInTags(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo) } + +func TestReuseParser(t *testing.T) { + var args struct { + Foo string `arg:"required"` + } + + p, err := NewParser(Config{}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"--foo=abc"}) + assert.Equal(t, args.Foo, "abc") + + err = p.Parse([]string{}) + assert.Error(t, err) +} diff --git a/usage.go b/usage.go index 656ee9a..cfac563 100644 --- a/usage.go +++ b/usage.go @@ -1,12 +1,12 @@ package arg import ( + "encoding" "fmt" "io" "os" "reflect" "strings" - "encoding" ) // the width of the left column @@ -22,7 +22,7 @@ func (p *Parser) Fail(msg string) { // WriteUsage writes usage information to the given writer func (p *Parser) WriteUsage(w io.Writer) { var positionals, options []*spec - for _, spec := range p.spec { + for _, spec := range p.specs { if spec.positional { positionals = append(positionals, spec) } else { @@ -72,7 +72,7 @@ func (p *Parser) WriteUsage(w io.Writer) { // WriteHelp writes the usage string followed by the full help string for each option func (p *Parser) WriteHelp(w io.Writer) { var positionals, options []*spec - for _, spec := range p.spec { + for _, spec := range p.specs { if spec.positional { positionals = append(positionals, spec) } else {