refactor validation
This commit is contained in:
parent
7b1d9ef23f
commit
b8678d4045
114
parse.go
114
parse.go
|
@ -24,7 +24,6 @@ type spec struct {
|
|||
separate bool
|
||||
help string
|
||||
env string
|
||||
wasPresent bool
|
||||
boolean bool
|
||||
}
|
||||
|
||||
|
@ -80,7 +79,7 @@ type Config struct {
|
|||
|
||||
// Parser represents a set of command line options with destination values
|
||||
type Parser struct {
|
||||
spec []*spec
|
||||
specs []*spec
|
||||
config Config
|
||||
version string
|
||||
description string
|
||||
|
@ -214,7 +213,7 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
|
|||
}
|
||||
}
|
||||
}
|
||||
p.spec = append(p.spec, &spec)
|
||||
p.specs = append(p.specs, &spec)
|
||||
|
||||
// if this was an embedded field then we already returned true up above
|
||||
return false
|
||||
|
@ -250,21 +249,18 @@ func (p *Parser) Parse(args []string) error {
|
|||
}
|
||||
|
||||
// Process all command line arguments
|
||||
err := process(p.spec, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate
|
||||
return validate(p.spec)
|
||||
return p.process(args)
|
||||
}
|
||||
|
||||
// process goes through arguments one-by-one, parses them, and assigns the result to
|
||||
// the underlying struct field
|
||||
func process(specs []*spec, args []string) error {
|
||||
func (p *Parser) process(args []string) error {
|
||||
// track the options we have seen
|
||||
wasPresent := make(map[*spec]bool)
|
||||
|
||||
// construct a map from --option to spec
|
||||
optionMap := make(map[string]*spec)
|
||||
for _, spec := range specs {
|
||||
for _, spec := range p.specs {
|
||||
if spec.positional {
|
||||
continue
|
||||
}
|
||||
|
@ -274,34 +270,43 @@ func process(specs []*spec, args []string) error {
|
|||
if spec.short != "" {
|
||||
optionMap[spec.short] = spec
|
||||
}
|
||||
if spec.env != "" {
|
||||
if value, found := os.LookupEnv(spec.env); found {
|
||||
if spec.multiple {
|
||||
// expect a CSV string in an environment
|
||||
// variable in the case of multiple values
|
||||
values, err := csv.NewReader(strings.NewReader(value)).Read()
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"error reading a CSV string from environment variable %s with multiple values: %v",
|
||||
spec.env,
|
||||
err,
|
||||
)
|
||||
}
|
||||
if err = setSlice(spec.dest, values, !spec.separate); err != nil {
|
||||
return fmt.Errorf(
|
||||
"error processing environment variable %s with multiple values: %v",
|
||||
spec.env,
|
||||
err,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if err := scalar.ParseValue(spec.dest, value); err != nil {
|
||||
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
|
||||
}
|
||||
}
|
||||
spec.wasPresent = true
|
||||
}
|
||||
|
||||
// deal with environment vars
|
||||
for _, spec := range p.specs {
|
||||
if spec.env == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
value, found := os.LookupEnv(spec.env)
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
|
||||
if spec.multiple {
|
||||
// expect a CSV string in an environment
|
||||
// variable in the case of multiple values
|
||||
values, err := csv.NewReader(strings.NewReader(value)).Read()
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"error reading a CSV string from environment variable %s with multiple values: %v",
|
||||
spec.env,
|
||||
err,
|
||||
)
|
||||
}
|
||||
if err = setSlice(spec.dest, values, !spec.separate); err != nil {
|
||||
return fmt.Errorf(
|
||||
"error processing environment variable %s with multiple values: %v",
|
||||
spec.env,
|
||||
err,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if err := scalar.ParseValue(spec.dest, value); err != nil {
|
||||
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
|
||||
}
|
||||
}
|
||||
wasPresent[spec] = true
|
||||
}
|
||||
|
||||
// process each string from the command line
|
||||
|
@ -334,7 +339,7 @@ func process(specs []*spec, args []string) error {
|
|||
if !ok {
|
||||
return fmt.Errorf("unknown argument %s", arg)
|
||||
}
|
||||
spec.wasPresent = true
|
||||
wasPresent[spec] = true
|
||||
|
||||
// deal with the case of multiple values
|
||||
if spec.multiple {
|
||||
|
@ -382,20 +387,21 @@ func process(specs []*spec, args []string) error {
|
|||
}
|
||||
|
||||
// process positionals
|
||||
for _, spec := range specs {
|
||||
for _, spec := range p.specs {
|
||||
if !spec.positional {
|
||||
continue
|
||||
}
|
||||
if spec.required && len(positionals) == 0 {
|
||||
return fmt.Errorf("%s is required", spec.long)
|
||||
if len(positionals) == 0 {
|
||||
break
|
||||
}
|
||||
wasPresent[spec] = true
|
||||
if spec.multiple {
|
||||
err := setSlice(spec.dest, positionals, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error processing %s: %v", spec.long, err)
|
||||
}
|
||||
positionals = nil
|
||||
} else if len(positionals) > 0 {
|
||||
} else {
|
||||
err := scalar.ParseValue(spec.dest, positionals[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("error processing %s: %v", spec.long, err)
|
||||
|
@ -406,6 +412,18 @@ func process(specs []*spec, args []string) error {
|
|||
if len(positionals) > 0 {
|
||||
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
|
||||
}
|
||||
|
||||
// finally check that all the required args were provided
|
||||
for _, spec := range p.specs {
|
||||
if spec.required && !wasPresent[spec] {
|
||||
name := spec.long
|
||||
if !spec.positional {
|
||||
name = "--" + spec.long
|
||||
}
|
||||
return fmt.Errorf("%s is required", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -427,16 +445,6 @@ func isFlag(s string) bool {
|
|||
return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != ""
|
||||
}
|
||||
|
||||
// validate an argument spec after arguments have been parse
|
||||
func validate(spec []*spec) error {
|
||||
for _, arg := range spec {
|
||||
if !arg.positional && arg.required && !arg.wasPresent {
|
||||
return fmt.Errorf("--%s is required", arg.long)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parse a value as the appropriate type and store it in the struct
|
||||
func setSlice(dest reflect.Value, values []string, trunc bool) error {
|
||||
if !dest.CanSet() {
|
||||
|
|
|
@ -969,3 +969,18 @@ func TestSpacesAllowedInTags(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo)
|
||||
}
|
||||
|
||||
func TestReuseParser(t *testing.T) {
|
||||
var args struct {
|
||||
Foo string `arg:"required"`
|
||||
}
|
||||
|
||||
p, err := NewParser(Config{}, &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = p.Parse([]string{"--foo=abc"})
|
||||
assert.Equal(t, args.Foo, "abc")
|
||||
|
||||
err = p.Parse([]string{})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
|
6
usage.go
6
usage.go
|
@ -1,12 +1,12 @@
|
|||
package arg
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"encoding"
|
||||
)
|
||||
|
||||
// the width of the left column
|
||||
|
@ -22,7 +22,7 @@ func (p *Parser) Fail(msg string) {
|
|||
// WriteUsage writes usage information to the given writer
|
||||
func (p *Parser) WriteUsage(w io.Writer) {
|
||||
var positionals, options []*spec
|
||||
for _, spec := range p.spec {
|
||||
for _, spec := range p.specs {
|
||||
if spec.positional {
|
||||
positionals = append(positionals, spec)
|
||||
} else {
|
||||
|
@ -72,7 +72,7 @@ func (p *Parser) WriteUsage(w io.Writer) {
|
|||
// WriteHelp writes the usage string followed by the full help string for each option
|
||||
func (p *Parser) WriteHelp(w io.Writer) {
|
||||
var positionals, options []*spec
|
||||
for _, spec := range p.spec {
|
||||
for _, spec := range p.specs {
|
||||
if spec.positional {
|
||||
positionals = append(positionals, spec)
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue