Merge pull request #149 from alexflint/parse-into-map

Add support for parsing into a map
This commit is contained in:
Alex Flint 2021-04-19 19:27:31 -07:00 committed by GitHub
commit 6a01a15f75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 508 additions and 117 deletions

View File

@ -191,6 +191,7 @@ var args struct {
Files []string `arg:"-f,separate"` Files []string `arg:"-f,separate"`
Databases []string `arg:"positional"` Databases []string `arg:"positional"`
} }
arg.MustParse(&args)
``` ```
```shell ```shell
@ -200,6 +201,20 @@ Files [file1 file2 file3]
Databases [db1 db2 db3] Databases [db1 db2 db3]
``` ```
### Arguments with keys and values
```go
var args struct {
UserIDs map[string]int
}
arg.MustParse(&args)
fmt.Println(args.UserIDs)
```
```shell
./example --userids john=123 mary=456
map[john:123 mary:456]
```
### Custom validation ### Custom validation
```go ```go
var args struct { var args struct {

View File

@ -82,6 +82,19 @@ func Example_multipleValues() {
// output: Fetching the following IDs from localhost: [1 2 3] // output: Fetching the following IDs from localhost: [1 2 3]
} }
// This example demonstrates arguments with keys and values
func Example_mappings() {
// The args you would pass in on the command line
os.Args = split("./example --userids john=123 mary=456")
var args struct {
UserIDs map[string]int
}
MustParse(&args)
fmt.Println(args.UserIDs)
// output: map[john:123 mary:456]
}
// This eample demonstrates multiple value arguments that can be mixed with // This eample demonstrates multiple value arguments that can be mixed with
// other arguments. // other arguments.
func Example_multipleMixed() { func Example_multipleMixed() {

View File

@ -50,15 +50,14 @@ type spec struct {
field reflect.StructField // the struct field from which this option was created field reflect.StructField // the struct field from which this option was created
long string // the --long form for this option, or empty if none long string // the --long form for this option, or empty if none
short string // the -s short form for this option, or empty if none short string // the -s short form for this option, or empty if none
multiple bool cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple)
required bool required bool // if true, this option must be present on the command line
positional bool positional bool // if true, this option will be looked for in the positional flags
separate bool separate bool // if true, each slice and map entry will have its own --flag
help string help string // the help text for this option
env string env string // the name of the environment variable for this option, or empty for none
boolean bool defaultVal string // default value for this option
defaultVal string // default value for this option placeholder string // name of the data in help
placeholder string // name of the data in help
} }
// command represents a named subcommand, or the top-level command // command represents a named subcommand, or the top-level command
@ -376,15 +375,15 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
if !isSubcommand { if !isSubcommand {
cmd.specs = append(cmd.specs, &spec) cmd.specs = append(cmd.specs, &spec)
var parseable bool var err error
parseable, spec.boolean, spec.multiple = canParse(field.Type) spec.cardinality, err = cardinalityOf(field.Type)
if !parseable { if err != nil {
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
t.Name(), field.Name, field.Type.String())) t.Name(), field.Name, field.Type.String()))
return false return false
} }
if spec.multiple && hasDefault { if spec.cardinality == multiple && hasDefault {
errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields", errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields",
t.Name(), field.Name)) t.Name(), field.Name))
return false return false
} }
@ -442,7 +441,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
continue continue
} }
if spec.multiple { if spec.cardinality == multiple {
// expect a CSV string in an environment // expect a CSV string in an environment
// variable in the case of multiple values // variable in the case of multiple values
values, err := csv.NewReader(strings.NewReader(value)).Read() values, err := csv.NewReader(strings.NewReader(value)).Read()
@ -453,7 +452,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
err, err,
) )
} }
if err = setSlice(p.val(spec.dest), values, !spec.separate); err != nil { if err = setSliceOrMap(p.val(spec.dest), values, !spec.separate); err != nil {
return fmt.Errorf( return fmt.Errorf(
"error processing environment variable %s with multiple values: %v", "error processing environment variable %s with multiple values: %v",
spec.env, spec.env,
@ -563,7 +562,7 @@ func (p *Parser) process(args []string) error {
wasPresent[spec] = true wasPresent[spec] = true
// deal with the case of multiple values // deal with the case of multiple values
if spec.multiple { if spec.cardinality == multiple {
var values []string var values []string
if value == "" { if value == "" {
for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" {
@ -576,7 +575,7 @@ func (p *Parser) process(args []string) error {
} else { } else {
values = append(values, value) values = append(values, value)
} }
err := setSlice(p.val(spec.dest), values, !spec.separate) err := setSliceOrMap(p.val(spec.dest), values, !spec.separate)
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err) return fmt.Errorf("error processing %s: %v", arg, err)
} }
@ -585,7 +584,7 @@ func (p *Parser) process(args []string) error {
// if it's a flag and it has no value then set the value to true // if it's a flag and it has no value then set the value to true
// use boolean because this takes account of TextUnmarshaler // use boolean because this takes account of TextUnmarshaler
if spec.boolean && value == "" { if spec.cardinality == zero && value == "" {
value = "true" value = "true"
} }
@ -616,8 +615,8 @@ func (p *Parser) process(args []string) error {
break break
} }
wasPresent[spec] = true wasPresent[spec] = true
if spec.multiple { if spec.cardinality == multiple {
err := setSlice(p.val(spec.dest), positionals, true) err := setSliceOrMap(p.val(spec.dest), positionals, true)
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", spec.field.Name, err) return fmt.Errorf("error processing %s: %v", spec.field.Name, err)
} }
@ -702,37 +701,6 @@ func (p *Parser) val(dest path) reflect.Value {
return v return v
} }
// parse a value as the appropriate type and store it in the struct
func setSlice(dest reflect.Value, values []string, trunc bool) error {
if !dest.CanSet() {
return fmt.Errorf("field is not writable")
}
var ptr bool
elem := dest.Type().Elem()
if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
ptr = true
elem = elem.Elem()
}
// Truncate the dest slice in case default values exist
if trunc && !dest.IsNil() {
dest.SetLen(0)
}
for _, s := range values {
v := reflect.New(elem)
if err := scalar.ParseValue(v.Elem(), s); err != nil {
return err
}
if !ptr {
v = v.Elem()
}
dest.Set(reflect.Append(dest, v))
}
return nil
}
// findOption finds an option from its name, or returns null if no spec is found // findOption finds an option from its name, or returns null if no spec is found
func findOption(specs []*spec, name string) *spec { func findOption(specs []*spec, name string) *spec {
for _, spec := range specs { for _, spec := range specs {
@ -759,7 +727,7 @@ func findSubcommand(cmds []*command, name string) *command {
// isZero returns true if v contains the zero value for its type // isZero returns true if v contains the zero value for its type
func isZero(v reflect.Value) bool { func isZero(v reflect.Value) bool {
t := v.Type() t := v.Type()
if t.Kind() == reflect.Slice { if t.Kind() == reflect.Slice || t.Kind() == reflect.Map {
return v.IsNil() return v.IsNil()
} }
if !t.Comparable() { if !t.Comparable() {

View File

@ -220,6 +220,60 @@ func TestLongFlag(t *testing.T) {
assert.Equal(t, "xyz", args.Foo) assert.Equal(t, "xyz", args.Foo)
} }
func TestSlice(t *testing.T) {
var args struct {
Strings []string
}
err := parse("--strings a b c", &args)
require.NoError(t, err)
assert.Equal(t, []string{"a", "b", "c"}, args.Strings)
}
func TestSliceOfBools(t *testing.T) {
var args struct {
B []bool
}
err := parse("--b true false true", &args)
require.NoError(t, err)
assert.Equal(t, []bool{true, false, true}, args.B)
}
func TestMap(t *testing.T) {
var args struct {
Values map[string]int
}
err := parse("--values a=1 b=2 c=3", &args)
require.NoError(t, err)
assert.Len(t, args.Values, 3)
assert.Equal(t, 1, args.Values["a"])
assert.Equal(t, 2, args.Values["b"])
assert.Equal(t, 3, args.Values["c"])
}
func TestMapPositional(t *testing.T) {
var args struct {
Values map[string]int `arg:"positional"`
}
err := parse("a=1 b=2 c=3", &args)
require.NoError(t, err)
assert.Len(t, args.Values, 3)
assert.Equal(t, 1, args.Values["a"])
assert.Equal(t, 2, args.Values["b"])
assert.Equal(t, 3, args.Values["c"])
}
func TestMapWithSeparate(t *testing.T) {
var args struct {
Values map[string]int `arg:"separate"`
}
err := parse("--values a=1 --values b=2 --values c=3", &args)
require.NoError(t, err)
assert.Len(t, args.Values, 3)
assert.Equal(t, 1, args.Values["a"])
assert.Equal(t, 2, args.Values["b"])
assert.Equal(t, 3, args.Values["c"])
}
func TestPlaceholder(t *testing.T) { func TestPlaceholder(t *testing.T) {
var args struct { var args struct {
Input string `arg:"positional" placeholder:"SRC"` Input string `arg:"positional" placeholder:"SRC"`
@ -688,6 +742,17 @@ func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
} }
func TestEnvironmentVariableMap(t *testing.T) {
var args struct {
Foo map[int]string `arg:"env"`
}
setenv(t, "FOO", "1=one,99=ninetynine")
MustParse(&args)
assert.Len(t, args.Foo, 2)
assert.Equal(t, "one", args.Foo[1])
assert.Equal(t, "ninetynine", args.Foo[99])
}
func TestEnvironmentVariableIgnored(t *testing.T) { func TestEnvironmentVariableIgnored(t *testing.T) {
var args struct { var args struct {
Foo string `arg:"env"` Foo string `arg:"env"`
@ -1223,7 +1288,7 @@ func TestDefaultValuesNotAllowedWithSlice(t *testing.T) {
} }
err := parse("", &args) err := parse("", &args)
assert.EqualError(t, err, ".A: default values are not supported for slice fields") assert.EqualError(t, err, ".A: default values are not supported for slice or map fields")
} }
func TestUnexportedFieldsSkipped(t *testing.T) { func TestUnexportedFieldsSkipped(t *testing.T) {

View File

@ -2,6 +2,7 @@ package arg
import ( import (
"encoding" "encoding"
"fmt"
"reflect" "reflect"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
@ -11,42 +12,67 @@ import (
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
// canParse returns true if the type can be parsed from a string // cardinality tracks how many tokens are expected for a given spec
func canParse(t reflect.Type) (parseable, boolean, multiple bool) { // - zero is a boolean, which does to expect any value
parseable = scalar.CanParse(t) // - one is an ordinary option that will be parsed from a single token
boolean = isBoolean(t) // - multiple is a slice or map that can accept zero or more tokens
if parseable { type cardinality int
return
const (
zero cardinality = iota
one
multiple
unsupported
)
func (k cardinality) String() string {
switch k {
case zero:
return "zero"
case one:
return "one"
case multiple:
return "multiple"
case unsupported:
return "unsupported"
default:
return fmt.Sprintf("unknown(%d)", int(k))
}
}
// cardinalityOf returns true if the type can be parsed from a string
func cardinalityOf(t reflect.Type) (cardinality, error) {
if scalar.CanParse(t) {
if isBoolean(t) {
return zero, nil
} else {
return one, nil
}
} }
// Look inside pointer types // look inside pointer types
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
// Look inside slice types
if t.Kind() == reflect.Slice {
multiple = true
t = t.Elem()
}
parseable = scalar.CanParse(t)
boolean = isBoolean(t)
if parseable {
return
}
// Look inside pointer types (again, in case of []*Type)
if t.Kind() == reflect.Ptr { if t.Kind() == reflect.Ptr {
t = t.Elem() t = t.Elem()
} }
parseable = scalar.CanParse(t) // look inside slice and map types
boolean = isBoolean(t) switch t.Kind() {
if parseable { case reflect.Slice:
return if !scalar.CanParse(t.Elem()) {
return unsupported, fmt.Errorf("cannot parse into %v because %v not supported", t, t.Elem())
}
return multiple, nil
case reflect.Map:
if !scalar.CanParse(t.Key()) {
return unsupported, fmt.Errorf("cannot parse into %v because key type %v not supported", t, t.Elem())
}
if !scalar.CanParse(t.Elem()) {
return unsupported, fmt.Errorf("cannot parse into %v because value type %v not supported", t, t.Elem())
}
return multiple, nil
default:
return unsupported, fmt.Errorf("cannot parse into %v", t)
} }
return false, false, false
} }
// isBoolean returns true if the type can be parsed from a single string // isBoolean returns true if the type can be parsed from a single string

View File

@ -7,36 +7,54 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func assertCanParse(t *testing.T, typ reflect.Type, parseable, boolean, multiple bool) { func assertCardinality(t *testing.T, typ reflect.Type, expected cardinality) {
p, b, m := canParse(typ) actual, err := cardinalityOf(typ)
assert.Equal(t, parseable, p, "expected %v to have parseable=%v but was %v", typ, parseable, p) assert.Equal(t, expected, actual, "expected %v to have cardinality %v but got %v", typ, expected, actual)
assert.Equal(t, boolean, b, "expected %v to have boolean=%v but was %v", typ, boolean, b) if expected == unsupported {
assert.Equal(t, multiple, m, "expected %v to have multiple=%v but was %v", typ, multiple, m) assert.Error(t, err)
}
} }
func TestCanParse(t *testing.T) { func TestCardinalityOf(t *testing.T) {
var b bool var b bool
var i int var i int
var s string var s string
var f float64 var f float64
var bs []bool var bs []bool
var is []int var is []int
var m map[string]int
var unsupported1 struct{}
var unsupported2 []struct{}
var unsupported3 map[string]struct{}
var unsupported4 map[struct{}]string
assertCanParse(t, reflect.TypeOf(b), true, true, false) assertCardinality(t, reflect.TypeOf(b), zero)
assertCanParse(t, reflect.TypeOf(i), true, false, false) assertCardinality(t, reflect.TypeOf(i), one)
assertCanParse(t, reflect.TypeOf(s), true, false, false) assertCardinality(t, reflect.TypeOf(s), one)
assertCanParse(t, reflect.TypeOf(f), true, false, false) assertCardinality(t, reflect.TypeOf(f), one)
assertCanParse(t, reflect.TypeOf(&b), true, true, false) assertCardinality(t, reflect.TypeOf(&b), zero)
assertCanParse(t, reflect.TypeOf(&s), true, false, false) assertCardinality(t, reflect.TypeOf(&s), one)
assertCanParse(t, reflect.TypeOf(&i), true, false, false) assertCardinality(t, reflect.TypeOf(&i), one)
assertCanParse(t, reflect.TypeOf(&f), true, false, false) assertCardinality(t, reflect.TypeOf(&f), one)
assertCanParse(t, reflect.TypeOf(bs), true, true, true) assertCardinality(t, reflect.TypeOf(bs), multiple)
assertCanParse(t, reflect.TypeOf(&bs), true, true, true) assertCardinality(t, reflect.TypeOf(is), multiple)
assertCanParse(t, reflect.TypeOf(is), true, false, true) assertCardinality(t, reflect.TypeOf(&bs), multiple)
assertCanParse(t, reflect.TypeOf(&is), true, false, true) assertCardinality(t, reflect.TypeOf(&is), multiple)
assertCardinality(t, reflect.TypeOf(m), multiple)
assertCardinality(t, reflect.TypeOf(&m), multiple)
assertCardinality(t, reflect.TypeOf(unsupported1), unsupported)
assertCardinality(t, reflect.TypeOf(&unsupported1), unsupported)
assertCardinality(t, reflect.TypeOf(unsupported2), unsupported)
assertCardinality(t, reflect.TypeOf(&unsupported2), unsupported)
assertCardinality(t, reflect.TypeOf(unsupported3), unsupported)
assertCardinality(t, reflect.TypeOf(&unsupported3), unsupported)
assertCardinality(t, reflect.TypeOf(unsupported4), unsupported)
assertCardinality(t, reflect.TypeOf(&unsupported4), unsupported)
} }
type implementsTextUnmarshaler struct{} type implementsTextUnmarshaler struct{}
@ -45,13 +63,16 @@ func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error {
return nil return nil
} }
func TestCanParseTextUnmarshaler(t *testing.T) { func TestCardinalityTextUnmarshaler(t *testing.T) {
var u implementsTextUnmarshaler var x implementsTextUnmarshaler
var su []implementsTextUnmarshaler var s []implementsTextUnmarshaler
assertCanParse(t, reflect.TypeOf(u), true, false, false) var m []implementsTextUnmarshaler
assertCanParse(t, reflect.TypeOf(&u), true, false, false) assertCardinality(t, reflect.TypeOf(x), one)
assertCanParse(t, reflect.TypeOf(su), true, false, true) assertCardinality(t, reflect.TypeOf(&x), one)
assertCanParse(t, reflect.TypeOf(&su), true, false, true) assertCardinality(t, reflect.TypeOf(s), multiple)
assertCardinality(t, reflect.TypeOf(&s), multiple)
assertCardinality(t, reflect.TypeOf(m), multiple)
assertCardinality(t, reflect.TypeOf(&m), multiple)
} }
func TestIsExported(t *testing.T) { func TestIsExported(t *testing.T) {
@ -60,3 +81,11 @@ func TestIsExported(t *testing.T) {
assert.False(t, isExported("")) assert.False(t, isExported(""))
assert.False(t, isExported(string([]byte{255}))) assert.False(t, isExported(string([]byte{255})))
} }
func TestCardinalityString(t *testing.T) {
assert.Equal(t, "zero", zero.String())
assert.Equal(t, "one", one.String())
assert.Equal(t, "multiple", multiple.String())
assert.Equal(t, "unsupported", unsupported.String())
assert.Equal(t, "unknown(42)", cardinality(42).String())
}

123
sequence.go Normal file
View File

@ -0,0 +1,123 @@
package arg
import (
"fmt"
"reflect"
"strings"
scalar "github.com/alexflint/go-scalar"
)
// setSliceOrMap parses a sequence of strings into a slice or map. If clear is
// true then any values already in the slice or map are first removed.
func setSliceOrMap(dest reflect.Value, values []string, clear bool) error {
if !dest.CanSet() {
return fmt.Errorf("field is not writable")
}
t := dest.Type()
if t.Kind() == reflect.Ptr {
dest = dest.Elem()
t = t.Elem()
}
switch t.Kind() {
case reflect.Slice:
return setSlice(dest, values, clear)
case reflect.Map:
return setMap(dest, values, clear)
default:
return fmt.Errorf("setSliceOrMap cannot insert values into a %v", t)
}
}
// setSlice parses a sequence of strings and inserts them into a slice. If clear
// is true then any values already in the slice are removed.
func setSlice(dest reflect.Value, values []string, clear bool) error {
var ptr bool
elem := dest.Type().Elem()
if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
ptr = true
elem = elem.Elem()
}
// clear the slice in case default values exist
if clear && !dest.IsNil() {
dest.SetLen(0)
}
// parse the values one-by-one
for _, s := range values {
v := reflect.New(elem)
if err := scalar.ParseValue(v.Elem(), s); err != nil {
return err
}
if !ptr {
v = v.Elem()
}
dest.Set(reflect.Append(dest, v))
}
return nil
}
// setMap parses a sequence of name=value strings and inserts them into a map.
// If clear is true then any values already in the map are removed.
func setMap(dest reflect.Value, values []string, clear bool) error {
// determine the key and value type
var keyIsPtr bool
keyType := dest.Type().Key()
if keyType.Kind() == reflect.Ptr && !keyType.Implements(textUnmarshalerType) {
keyIsPtr = true
keyType = keyType.Elem()
}
var valIsPtr bool
valType := dest.Type().Elem()
if valType.Kind() == reflect.Ptr && !valType.Implements(textUnmarshalerType) {
valIsPtr = true
valType = valType.Elem()
}
// clear the slice in case default values exist
if clear && !dest.IsNil() {
for _, k := range dest.MapKeys() {
dest.SetMapIndex(k, reflect.Value{})
}
}
// allocate the map if it is not allocated
if dest.IsNil() {
dest.Set(reflect.MakeMap(dest.Type()))
}
// parse the values one-by-one
for _, s := range values {
// split at the first equals sign
pos := strings.Index(s, "=")
if pos == -1 {
return fmt.Errorf("cannot parse %q into a map, expected format key=value", s)
}
// parse the key
k := reflect.New(keyType)
if err := scalar.ParseValue(k.Elem(), s[:pos]); err != nil {
return err
}
if !keyIsPtr {
k = k.Elem()
}
// parse the value
v := reflect.New(valType)
if err := scalar.ParseValue(v.Elem(), s[pos+1:]); err != nil {
return err
}
if !valIsPtr {
v = v.Elem()
}
// add it to the map
dest.SetMapIndex(k, v)
}
return nil
}

152
sequence_test.go Normal file
View File

@ -0,0 +1,152 @@
package arg
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSetSliceWithoutClearing(t *testing.T) {
xs := []int{10}
entries := []string{"1", "2", "3"}
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false)
require.NoError(t, err)
assert.Equal(t, []int{10, 1, 2, 3}, xs)
}
func TestSetSliceAfterClearing(t *testing.T) {
xs := []int{100}
entries := []string{"1", "2", "3"}
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
require.NoError(t, err)
assert.Equal(t, []int{1, 2, 3}, xs)
}
func TestSetSliceInvalid(t *testing.T) {
xs := []int{100}
entries := []string{"invalid"}
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
assert.Error(t, err)
}
func TestSetSlicePtr(t *testing.T) {
var xs []*int
entries := []string{"1", "2", "3"}
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
require.NoError(t, err)
require.Len(t, xs, 3)
assert.Equal(t, 1, *xs[0])
assert.Equal(t, 2, *xs[1])
assert.Equal(t, 3, *xs[2])
}
func TestSetSliceTextUnmarshaller(t *testing.T) {
// textUnmarshaler is a struct that captures the length of the string passed to it
var xs []*textUnmarshaler
entries := []string{"a", "aa", "aaa"}
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
require.NoError(t, err)
require.Len(t, xs, 3)
assert.Equal(t, 1, xs[0].val)
assert.Equal(t, 2, xs[1].val)
assert.Equal(t, 3, xs[2].val)
}
func TestSetMapWithoutClearing(t *testing.T) {
m := map[string]int{"foo": 10}
entries := []string{"a=1", "b=2"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, false)
require.NoError(t, err)
require.Len(t, m, 3)
assert.Equal(t, 1, m["a"])
assert.Equal(t, 2, m["b"])
assert.Equal(t, 10, m["foo"])
}
func TestSetMapAfterClearing(t *testing.T) {
m := map[string]int{"foo": 10}
entries := []string{"a=1", "b=2"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
require.NoError(t, err)
require.Len(t, m, 2)
assert.Equal(t, 1, m["a"])
assert.Equal(t, 2, m["b"])
}
func TestSetMapWithKeyPointer(t *testing.T) {
// textUnmarshaler is a struct that captures the length of the string passed to it
var m map[*string]int
entries := []string{"abc=123"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
require.NoError(t, err)
require.Len(t, m, 1)
}
func TestSetMapWithValuePointer(t *testing.T) {
// textUnmarshaler is a struct that captures the length of the string passed to it
var m map[string]*int
entries := []string{"abc=123"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
require.NoError(t, err)
require.Len(t, m, 1)
assert.Equal(t, 123, *m["abc"])
}
func TestSetMapTextUnmarshaller(t *testing.T) {
// textUnmarshaler is a struct that captures the length of the string passed to it
var m map[textUnmarshaler]*textUnmarshaler
entries := []string{"a=123", "aa=12", "aaa=1"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
require.NoError(t, err)
require.Len(t, m, 3)
assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}])
assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}])
assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}])
}
func TestSetMapInvalidKey(t *testing.T) {
var m map[int]int
entries := []string{"invalid=123"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
assert.Error(t, err)
}
func TestSetMapInvalidValue(t *testing.T) {
var m map[int]int
entries := []string{"123=invalid"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
assert.Error(t, err)
}
func TestSetMapMalformed(t *testing.T) {
// textUnmarshaler is a struct that captures the length of the string passed to it
var m map[string]string
entries := []string{"missing_equals_sign"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
assert.Error(t, err)
}
func TestSetSliceOrMapErrors(t *testing.T) {
var err error
var dest reflect.Value
// converting a slice to a reflect.Value in this way will make it read only
var cannotSet []int
dest = reflect.ValueOf(cannotSet)
err = setSliceOrMap(dest, nil, false)
assert.Error(t, err)
// check what happens when we pass in something that is not a slice or a map
var notSliceOrMap string
dest = reflect.ValueOf(&notSliceOrMap).Elem()
err = setSliceOrMap(dest, nil, false)
assert.Error(t, err)
// check what happens when we pass in a pointer to something that is not a slice or a map
var stringPtr *string
dest = reflect.ValueOf(&stringPtr).Elem()
err = setSliceOrMap(dest, nil, false)
assert.Error(t, err)
}

View File

@ -95,7 +95,7 @@ func (p *Parser) writeUsageForCommand(w io.Writer, cmd *command) {
for _, spec := range positionals { for _, spec := range positionals {
// prefix with a space // prefix with a space
fmt.Fprint(w, " ") fmt.Fprint(w, " ")
if spec.multiple { if spec.cardinality == multiple {
if !spec.required { if !spec.required {
fmt.Fprint(w, "[") fmt.Fprint(w, "[")
} }
@ -213,16 +213,16 @@ func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) {
// write the list of built in options // write the list of built in options
p.printOption(w, &spec{ p.printOption(w, &spec{
boolean: true, cardinality: zero,
long: "help", long: "help",
short: "h", short: "h",
help: "display this help and exit", help: "display this help and exit",
}) })
if p.version != "" { if p.version != "" {
p.printOption(w, &spec{ p.printOption(w, &spec{
boolean: true, cardinality: zero,
long: "version", long: "version",
help: "display version and exit", help: "display version and exit",
}) })
} }
@ -249,7 +249,7 @@ func (p *Parser) printOption(w io.Writer, spec *spec) {
} }
func synopsis(spec *spec, form string) string { func synopsis(spec *spec, form string) string {
if spec.boolean { if spec.cardinality == zero {
return form return form
} }
return form + " " + spec.placeholder return form + " " + spec.placeholder