refactor validation

This commit is contained in:
Alex Flint 2019-04-14 18:00:40 -07:00
parent 7b1d9ef23f
commit b8678d4045
3 changed files with 79 additions and 56 deletions

View File

@ -24,7 +24,6 @@ type spec struct {
separate bool
help string
env string
wasPresent bool
boolean bool
}
@ -80,7 +79,7 @@ type Config struct {
// Parser represents a set of command line options with destination values
type Parser struct {
spec []*spec
specs []*spec
config Config
version string
description string
@ -214,7 +213,7 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
}
}
}
p.spec = append(p.spec, &spec)
p.specs = append(p.specs, &spec)
// if this was an embedded field then we already returned true up above
return false
@ -250,21 +249,18 @@ func (p *Parser) Parse(args []string) error {
}
// Process all command line arguments
err := process(p.spec, args)
if err != nil {
return err
}
// Validate
return validate(p.spec)
return p.process(args)
}
// process goes through arguments one-by-one, parses them, and assigns the result to
// the underlying struct field
func process(specs []*spec, args []string) error {
func (p *Parser) process(args []string) error {
// track the options we have seen
wasPresent := make(map[*spec]bool)
// construct a map from --option to spec
optionMap := make(map[string]*spec)
for _, spec := range specs {
for _, spec := range p.specs {
if spec.positional {
continue
}
@ -274,8 +270,19 @@ func process(specs []*spec, args []string) error {
if spec.short != "" {
optionMap[spec.short] = spec
}
if spec.env != "" {
if value, found := os.LookupEnv(spec.env); found {
}
// deal with environment vars
for _, spec := range p.specs {
if spec.env == "" {
continue
}
value, found := os.LookupEnv(spec.env)
if !found {
continue
}
if spec.multiple {
// expect a CSV string in an environment
// variable in the case of multiple values
@ -299,9 +306,7 @@ func process(specs []*spec, args []string) error {
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
}
}
spec.wasPresent = true
}
}
wasPresent[spec] = true
}
// process each string from the command line
@ -334,7 +339,7 @@ func process(specs []*spec, args []string) error {
if !ok {
return fmt.Errorf("unknown argument %s", arg)
}
spec.wasPresent = true
wasPresent[spec] = true
// deal with the case of multiple values
if spec.multiple {
@ -382,20 +387,21 @@ func process(specs []*spec, args []string) error {
}
// process positionals
for _, spec := range specs {
for _, spec := range p.specs {
if !spec.positional {
continue
}
if spec.required && len(positionals) == 0 {
return fmt.Errorf("%s is required", spec.long)
if len(positionals) == 0 {
break
}
wasPresent[spec] = true
if spec.multiple {
err := setSlice(spec.dest, positionals, true)
if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err)
}
positionals = nil
} else if len(positionals) > 0 {
} else {
err := scalar.ParseValue(spec.dest, positionals[0])
if err != nil {
return fmt.Errorf("error processing %s: %v", spec.long, err)
@ -406,6 +412,18 @@ func process(specs []*spec, args []string) error {
if len(positionals) > 0 {
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
}
// finally check that all the required args were provided
for _, spec := range p.specs {
if spec.required && !wasPresent[spec] {
name := spec.long
if !spec.positional {
name = "--" + spec.long
}
return fmt.Errorf("%s is required", name)
}
}
return nil
}
@ -427,16 +445,6 @@ func isFlag(s string) bool {
return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != ""
}
// validate an argument spec after arguments have been parse
func validate(spec []*spec) error {
for _, arg := range spec {
if !arg.positional && arg.required && !arg.wasPresent {
return fmt.Errorf("--%s is required", arg.long)
}
}
return nil
}
// parse a value as the appropriate type and store it in the struct
func setSlice(dest reflect.Value, values []string, trunc bool) error {
if !dest.CanSet() {

View File

@ -969,3 +969,18 @@ func TestSpacesAllowedInTags(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo)
}
func TestReuseParser(t *testing.T) {
var args struct {
Foo string `arg:"required"`
}
p, err := NewParser(Config{}, &args)
require.NoError(t, err)
err = p.Parse([]string{"--foo=abc"})
assert.Equal(t, args.Foo, "abc")
err = p.Parse([]string{})
assert.Error(t, err)
}

View File

@ -1,12 +1,12 @@
package arg
import (
"encoding"
"fmt"
"io"
"os"
"reflect"
"strings"
"encoding"
)
// the width of the left column
@ -22,7 +22,7 @@ func (p *Parser) Fail(msg string) {
// WriteUsage writes usage information to the given writer
func (p *Parser) WriteUsage(w io.Writer) {
var positionals, options []*spec
for _, spec := range p.spec {
for _, spec := range p.specs {
if spec.positional {
positionals = append(positionals, spec)
} else {
@ -72,7 +72,7 @@ func (p *Parser) WriteUsage(w io.Writer) {
// WriteHelp writes the usage string followed by the full help string for each option
func (p *Parser) WriteHelp(w io.Writer) {
var positionals, options []*spec
for _, spec := range p.spec {
for _, spec := range p.specs {
if spec.positional {
positionals = append(positionals, spec)
} else {