refactor validation
This commit is contained in:
parent
7b1d9ef23f
commit
b8678d4045
72
parse.go
72
parse.go
|
@ -24,7 +24,6 @@ type spec struct {
|
||||||
separate bool
|
separate bool
|
||||||
help string
|
help string
|
||||||
env string
|
env string
|
||||||
wasPresent bool
|
|
||||||
boolean bool
|
boolean bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,7 +79,7 @@ type Config struct {
|
||||||
|
|
||||||
// Parser represents a set of command line options with destination values
|
// Parser represents a set of command line options with destination values
|
||||||
type Parser struct {
|
type Parser struct {
|
||||||
spec []*spec
|
specs []*spec
|
||||||
config Config
|
config Config
|
||||||
version string
|
version string
|
||||||
description string
|
description string
|
||||||
|
@ -214,7 +213,7 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p.spec = append(p.spec, &spec)
|
p.specs = append(p.specs, &spec)
|
||||||
|
|
||||||
// if this was an embedded field then we already returned true up above
|
// if this was an embedded field then we already returned true up above
|
||||||
return false
|
return false
|
||||||
|
@ -250,21 +249,18 @@ func (p *Parser) Parse(args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process all command line arguments
|
// Process all command line arguments
|
||||||
err := process(p.spec, args)
|
return p.process(args)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate
|
|
||||||
return validate(p.spec)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// process goes through arguments one-by-one, parses them, and assigns the result to
|
// process goes through arguments one-by-one, parses them, and assigns the result to
|
||||||
// the underlying struct field
|
// the underlying struct field
|
||||||
func process(specs []*spec, args []string) error {
|
func (p *Parser) process(args []string) error {
|
||||||
|
// track the options we have seen
|
||||||
|
wasPresent := make(map[*spec]bool)
|
||||||
|
|
||||||
// construct a map from --option to spec
|
// construct a map from --option to spec
|
||||||
optionMap := make(map[string]*spec)
|
optionMap := make(map[string]*spec)
|
||||||
for _, spec := range specs {
|
for _, spec := range p.specs {
|
||||||
if spec.positional {
|
if spec.positional {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -274,8 +270,19 @@ func process(specs []*spec, args []string) error {
|
||||||
if spec.short != "" {
|
if spec.short != "" {
|
||||||
optionMap[spec.short] = spec
|
optionMap[spec.short] = spec
|
||||||
}
|
}
|
||||||
if spec.env != "" {
|
}
|
||||||
if value, found := os.LookupEnv(spec.env); found {
|
|
||||||
|
// deal with environment vars
|
||||||
|
for _, spec := range p.specs {
|
||||||
|
if spec.env == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
value, found := os.LookupEnv(spec.env)
|
||||||
|
if !found {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if spec.multiple {
|
if spec.multiple {
|
||||||
// expect a CSV string in an environment
|
// expect a CSV string in an environment
|
||||||
// variable in the case of multiple values
|
// variable in the case of multiple values
|
||||||
|
@ -299,9 +306,7 @@ func process(specs []*spec, args []string) error {
|
||||||
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
|
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
spec.wasPresent = true
|
wasPresent[spec] = true
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// process each string from the command line
|
// process each string from the command line
|
||||||
|
@ -334,7 +339,7 @@ func process(specs []*spec, args []string) error {
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("unknown argument %s", arg)
|
return fmt.Errorf("unknown argument %s", arg)
|
||||||
}
|
}
|
||||||
spec.wasPresent = true
|
wasPresent[spec] = true
|
||||||
|
|
||||||
// deal with the case of multiple values
|
// deal with the case of multiple values
|
||||||
if spec.multiple {
|
if spec.multiple {
|
||||||
|
@ -382,20 +387,21 @@ func process(specs []*spec, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// process positionals
|
// process positionals
|
||||||
for _, spec := range specs {
|
for _, spec := range p.specs {
|
||||||
if !spec.positional {
|
if !spec.positional {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if spec.required && len(positionals) == 0 {
|
if len(positionals) == 0 {
|
||||||
return fmt.Errorf("%s is required", spec.long)
|
break
|
||||||
}
|
}
|
||||||
|
wasPresent[spec] = true
|
||||||
if spec.multiple {
|
if spec.multiple {
|
||||||
err := setSlice(spec.dest, positionals, true)
|
err := setSlice(spec.dest, positionals, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error processing %s: %v", spec.long, err)
|
return fmt.Errorf("error processing %s: %v", spec.long, err)
|
||||||
}
|
}
|
||||||
positionals = nil
|
positionals = nil
|
||||||
} else if len(positionals) > 0 {
|
} else {
|
||||||
err := scalar.ParseValue(spec.dest, positionals[0])
|
err := scalar.ParseValue(spec.dest, positionals[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error processing %s: %v", spec.long, err)
|
return fmt.Errorf("error processing %s: %v", spec.long, err)
|
||||||
|
@ -406,6 +412,18 @@ func process(specs []*spec, args []string) error {
|
||||||
if len(positionals) > 0 {
|
if len(positionals) > 0 {
|
||||||
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
|
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// finally check that all the required args were provided
|
||||||
|
for _, spec := range p.specs {
|
||||||
|
if spec.required && !wasPresent[spec] {
|
||||||
|
name := spec.long
|
||||||
|
if !spec.positional {
|
||||||
|
name = "--" + spec.long
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%s is required", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -427,16 +445,6 @@ func isFlag(s string) bool {
|
||||||
return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != ""
|
return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate an argument spec after arguments have been parse
|
|
||||||
func validate(spec []*spec) error {
|
|
||||||
for _, arg := range spec {
|
|
||||||
if !arg.positional && arg.required && !arg.wasPresent {
|
|
||||||
return fmt.Errorf("--%s is required", arg.long)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse a value as the appropriate type and store it in the struct
|
// parse a value as the appropriate type and store it in the struct
|
||||||
func setSlice(dest reflect.Value, values []string, trunc bool) error {
|
func setSlice(dest reflect.Value, values []string, trunc bool) error {
|
||||||
if !dest.CanSet() {
|
if !dest.CanSet() {
|
||||||
|
|
|
@ -969,3 +969,18 @@ func TestSpacesAllowedInTags(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo)
|
assert.Equal(t, []string{"one", "two", "three", "four"}, args.Foo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReuseParser(t *testing.T) {
|
||||||
|
var args struct {
|
||||||
|
Foo string `arg:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := NewParser(Config{}, &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.Parse([]string{"--foo=abc"})
|
||||||
|
assert.Equal(t, args.Foo, "abc")
|
||||||
|
|
||||||
|
err = p.Parse([]string{})
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
6
usage.go
6
usage.go
|
@ -1,12 +1,12 @@
|
||||||
package arg
|
package arg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"encoding"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// the width of the left column
|
// the width of the left column
|
||||||
|
@ -22,7 +22,7 @@ func (p *Parser) Fail(msg string) {
|
||||||
// WriteUsage writes usage information to the given writer
|
// WriteUsage writes usage information to the given writer
|
||||||
func (p *Parser) WriteUsage(w io.Writer) {
|
func (p *Parser) WriteUsage(w io.Writer) {
|
||||||
var positionals, options []*spec
|
var positionals, options []*spec
|
||||||
for _, spec := range p.spec {
|
for _, spec := range p.specs {
|
||||||
if spec.positional {
|
if spec.positional {
|
||||||
positionals = append(positionals, spec)
|
positionals = append(positionals, spec)
|
||||||
} else {
|
} else {
|
||||||
|
@ -72,7 +72,7 @@ func (p *Parser) WriteUsage(w io.Writer) {
|
||||||
// WriteHelp writes the usage string followed by the full help string for each option
|
// WriteHelp writes the usage string followed by the full help string for each option
|
||||||
func (p *Parser) WriteHelp(w io.Writer) {
|
func (p *Parser) WriteHelp(w io.Writer) {
|
||||||
var positionals, options []*spec
|
var positionals, options []*spec
|
||||||
for _, spec := range p.spec {
|
for _, spec := range p.specs {
|
||||||
if spec.positional {
|
if spec.positional {
|
||||||
positionals = append(positionals, spec)
|
positionals = append(positionals, spec)
|
||||||
} else {
|
} else {
|
||||||
|
|
Loading…
Reference in New Issue