diff --git a/README.md b/README.md index da69469..48fa2f0 100644 --- a/README.md +++ b/README.md @@ -191,6 +191,7 @@ var args struct { Files []string `arg:"-f,separate"` Databases []string `arg:"positional"` } +arg.MustParse(&args) ``` ```shell @@ -200,6 +201,20 @@ Files [file1 file2 file3] Databases [db1 db2 db3] ``` +### Arguments with keys and values +```go +var args struct { + UserIDs map[string]int +} +arg.MustParse(&args) +fmt.Println(args.UserIDs) +``` + +```shell +./example --userids john=123 mary=456 +map[john:123 mary:456] +``` + ### Custom validation ```go var args struct { diff --git a/example_test.go b/example_test.go index 9091151..5645156 100644 --- a/example_test.go +++ b/example_test.go @@ -82,6 +82,19 @@ func Example_multipleValues() { // output: Fetching the following IDs from localhost: [1 2 3] } +// This example demonstrates arguments with keys and values +func Example_mappings() { + // The args you would pass in on the command line + os.Args = split("./example --userids john=123 mary=456") + + var args struct { + UserIDs map[string]int + } + MustParse(&args) + fmt.Println(args.UserIDs) + // output: map[john:123 mary:456] +} + // This eample demonstrates multiple value arguments that can be mixed with // other arguments. func Example_multipleMixed() { diff --git a/parse.go b/parse.go index b7d159d..d357d5c 100644 --- a/parse.go +++ b/parse.go @@ -50,15 +50,14 @@ type spec struct { field reflect.StructField // the struct field from which this option was created long string // the --long form for this option, or empty if none short string // the -s short form for this option, or empty if none - multiple bool - required bool - positional bool - separate bool - help string - env string - boolean bool - defaultVal string // default value for this option - placeholder string // name of the data in help + cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple) + required bool // if true, this option must be present on the command line + positional bool // if true, this option will be looked for in the positional flags + separate bool // if true, each slice and map entry will have its own --flag + help string // the help text for this option + env string // the name of the environment variable for this option, or empty for none + defaultVal string // default value for this option + placeholder string // name of the data in help } // command represents a named subcommand, or the top-level command @@ -376,15 +375,15 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { if !isSubcommand { cmd.specs = append(cmd.specs, &spec) - var parseable bool - parseable, spec.boolean, spec.multiple = canParse(field.Type) - if !parseable { + var err error + spec.cardinality, err = cardinalityOf(field.Type) + if err != nil { errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String())) return false } - if spec.multiple && hasDefault { - errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields", + if spec.cardinality == multiple && hasDefault { + errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields", t.Name(), field.Name)) return false } @@ -442,7 +441,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error continue } - if spec.multiple { + if spec.cardinality == multiple { // expect a CSV string in an environment // variable in the case of multiple values values, err := csv.NewReader(strings.NewReader(value)).Read() @@ -453,7 +452,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error err, ) } - if err = setSlice(p.val(spec.dest), values, !spec.separate); err != nil { + if err = setSliceOrMap(p.val(spec.dest), values, !spec.separate); err != nil { return fmt.Errorf( "error processing environment variable %s with multiple values: %v", spec.env, @@ -563,7 +562,7 @@ func (p *Parser) process(args []string) error { wasPresent[spec] = true // deal with the case of multiple values - if spec.multiple { + if spec.cardinality == multiple { var values []string if value == "" { for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { @@ -576,7 +575,7 @@ func (p *Parser) process(args []string) error { } else { values = append(values, value) } - err := setSlice(p.val(spec.dest), values, !spec.separate) + err := setSliceOrMap(p.val(spec.dest), values, !spec.separate) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -585,7 +584,7 @@ func (p *Parser) process(args []string) error { // if it's a flag and it has no value then set the value to true // use boolean because this takes account of TextUnmarshaler - if spec.boolean && value == "" { + if spec.cardinality == zero && value == "" { value = "true" } @@ -616,8 +615,8 @@ func (p *Parser) process(args []string) error { break } wasPresent[spec] = true - if spec.multiple { - err := setSlice(p.val(spec.dest), positionals, true) + if spec.cardinality == multiple { + err := setSliceOrMap(p.val(spec.dest), positionals, true) if err != nil { return fmt.Errorf("error processing %s: %v", spec.field.Name, err) } @@ -702,37 +701,6 @@ func (p *Parser) val(dest path) reflect.Value { return v } -// parse a value as the appropriate type and store it in the struct -func setSlice(dest reflect.Value, values []string, trunc bool) error { - if !dest.CanSet() { - return fmt.Errorf("field is not writable") - } - - var ptr bool - elem := dest.Type().Elem() - if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { - ptr = true - elem = elem.Elem() - } - - // Truncate the dest slice in case default values exist - if trunc && !dest.IsNil() { - dest.SetLen(0) - } - - for _, s := range values { - v := reflect.New(elem) - if err := scalar.ParseValue(v.Elem(), s); err != nil { - return err - } - if !ptr { - v = v.Elem() - } - dest.Set(reflect.Append(dest, v)) - } - return nil -} - // findOption finds an option from its name, or returns null if no spec is found func findOption(specs []*spec, name string) *spec { for _, spec := range specs { @@ -759,7 +727,7 @@ func findSubcommand(cmds []*command, name string) *command { // isZero returns true if v contains the zero value for its type func isZero(v reflect.Value) bool { t := v.Type() - if t.Kind() == reflect.Slice { + if t.Kind() == reflect.Slice || t.Kind() == reflect.Map { return v.IsNil() } if !t.Comparable() { diff --git a/parse_test.go b/parse_test.go index ce3068e..d03cbfd 100644 --- a/parse_test.go +++ b/parse_test.go @@ -220,6 +220,60 @@ func TestLongFlag(t *testing.T) { assert.Equal(t, "xyz", args.Foo) } +func TestSlice(t *testing.T) { + var args struct { + Strings []string + } + err := parse("--strings a b c", &args) + require.NoError(t, err) + assert.Equal(t, []string{"a", "b", "c"}, args.Strings) +} +func TestSliceOfBools(t *testing.T) { + var args struct { + B []bool + } + + err := parse("--b true false true", &args) + require.NoError(t, err) + assert.Equal(t, []bool{true, false, true}, args.B) +} + +func TestMap(t *testing.T) { + var args struct { + Values map[string]int + } + err := parse("--values a=1 b=2 c=3", &args) + require.NoError(t, err) + assert.Len(t, args.Values, 3) + assert.Equal(t, 1, args.Values["a"]) + assert.Equal(t, 2, args.Values["b"]) + assert.Equal(t, 3, args.Values["c"]) +} + +func TestMapPositional(t *testing.T) { + var args struct { + Values map[string]int `arg:"positional"` + } + err := parse("a=1 b=2 c=3", &args) + require.NoError(t, err) + assert.Len(t, args.Values, 3) + assert.Equal(t, 1, args.Values["a"]) + assert.Equal(t, 2, args.Values["b"]) + assert.Equal(t, 3, args.Values["c"]) +} + +func TestMapWithSeparate(t *testing.T) { + var args struct { + Values map[string]int `arg:"separate"` + } + err := parse("--values a=1 --values b=2 --values c=3", &args) + require.NoError(t, err) + assert.Len(t, args.Values, 3) + assert.Equal(t, 1, args.Values["a"]) + assert.Equal(t, 2, args.Values["b"]) + assert.Equal(t, 3, args.Values["c"]) +} + func TestPlaceholder(t *testing.T) { var args struct { Input string `arg:"positional" placeholder:"SRC"` @@ -688,6 +742,17 @@ func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) { assert.Error(t, err) } +func TestEnvironmentVariableMap(t *testing.T) { + var args struct { + Foo map[int]string `arg:"env"` + } + setenv(t, "FOO", "1=one,99=ninetynine") + MustParse(&args) + assert.Len(t, args.Foo, 2) + assert.Equal(t, "one", args.Foo[1]) + assert.Equal(t, "ninetynine", args.Foo[99]) +} + func TestEnvironmentVariableIgnored(t *testing.T) { var args struct { Foo string `arg:"env"` @@ -1223,7 +1288,7 @@ func TestDefaultValuesNotAllowedWithSlice(t *testing.T) { } err := parse("", &args) - assert.EqualError(t, err, ".A: default values are not supported for slice fields") + assert.EqualError(t, err, ".A: default values are not supported for slice or map fields") } func TestUnexportedFieldsSkipped(t *testing.T) { diff --git a/reflect.go b/reflect.go index f1e8e8d..1806973 100644 --- a/reflect.go +++ b/reflect.go @@ -2,6 +2,7 @@ package arg import ( "encoding" + "fmt" "reflect" "unicode" "unicode/utf8" @@ -11,42 +12,67 @@ import ( var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() -// canParse returns true if the type can be parsed from a string -func canParse(t reflect.Type) (parseable, boolean, multiple bool) { - parseable = scalar.CanParse(t) - boolean = isBoolean(t) - if parseable { - return +// cardinality tracks how many tokens are expected for a given spec +// - zero is a boolean, which does to expect any value +// - one is an ordinary option that will be parsed from a single token +// - multiple is a slice or map that can accept zero or more tokens +type cardinality int + +const ( + zero cardinality = iota + one + multiple + unsupported +) + +func (k cardinality) String() string { + switch k { + case zero: + return "zero" + case one: + return "one" + case multiple: + return "multiple" + case unsupported: + return "unsupported" + default: + return fmt.Sprintf("unknown(%d)", int(k)) + } +} + +// cardinalityOf returns true if the type can be parsed from a string +func cardinalityOf(t reflect.Type) (cardinality, error) { + if scalar.CanParse(t) { + if isBoolean(t) { + return zero, nil + } else { + return one, nil + } } - // Look inside pointer types - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - // Look inside slice types - if t.Kind() == reflect.Slice { - multiple = true - t = t.Elem() - } - - parseable = scalar.CanParse(t) - boolean = isBoolean(t) - if parseable { - return - } - - // Look inside pointer types (again, in case of []*Type) + // look inside pointer types if t.Kind() == reflect.Ptr { t = t.Elem() } - parseable = scalar.CanParse(t) - boolean = isBoolean(t) - if parseable { - return + // look inside slice and map types + switch t.Kind() { + case reflect.Slice: + if !scalar.CanParse(t.Elem()) { + return unsupported, fmt.Errorf("cannot parse into %v because %v not supported", t, t.Elem()) + } + return multiple, nil + case reflect.Map: + if !scalar.CanParse(t.Key()) { + return unsupported, fmt.Errorf("cannot parse into %v because key type %v not supported", t, t.Elem()) + } + if !scalar.CanParse(t.Elem()) { + return unsupported, fmt.Errorf("cannot parse into %v because value type %v not supported", t, t.Elem()) + } + return multiple, nil + default: + return unsupported, fmt.Errorf("cannot parse into %v", t) } - - return false, false, false } // isBoolean returns true if the type can be parsed from a single string diff --git a/reflect_test.go b/reflect_test.go index 07b459c..8d65fd9 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -7,36 +7,54 @@ import ( "github.com/stretchr/testify/assert" ) -func assertCanParse(t *testing.T, typ reflect.Type, parseable, boolean, multiple bool) { - p, b, m := canParse(typ) - assert.Equal(t, parseable, p, "expected %v to have parseable=%v but was %v", typ, parseable, p) - assert.Equal(t, boolean, b, "expected %v to have boolean=%v but was %v", typ, boolean, b) - assert.Equal(t, multiple, m, "expected %v to have multiple=%v but was %v", typ, multiple, m) +func assertCardinality(t *testing.T, typ reflect.Type, expected cardinality) { + actual, err := cardinalityOf(typ) + assert.Equal(t, expected, actual, "expected %v to have cardinality %v but got %v", typ, expected, actual) + if expected == unsupported { + assert.Error(t, err) + } } -func TestCanParse(t *testing.T) { +func TestCardinalityOf(t *testing.T) { var b bool var i int var s string var f float64 var bs []bool var is []int + var m map[string]int + var unsupported1 struct{} + var unsupported2 []struct{} + var unsupported3 map[string]struct{} + var unsupported4 map[struct{}]string - assertCanParse(t, reflect.TypeOf(b), true, true, false) - assertCanParse(t, reflect.TypeOf(i), true, false, false) - assertCanParse(t, reflect.TypeOf(s), true, false, false) - assertCanParse(t, reflect.TypeOf(f), true, false, false) + assertCardinality(t, reflect.TypeOf(b), zero) + assertCardinality(t, reflect.TypeOf(i), one) + assertCardinality(t, reflect.TypeOf(s), one) + assertCardinality(t, reflect.TypeOf(f), one) - assertCanParse(t, reflect.TypeOf(&b), true, true, false) - assertCanParse(t, reflect.TypeOf(&s), true, false, false) - assertCanParse(t, reflect.TypeOf(&i), true, false, false) - assertCanParse(t, reflect.TypeOf(&f), true, false, false) + assertCardinality(t, reflect.TypeOf(&b), zero) + assertCardinality(t, reflect.TypeOf(&s), one) + assertCardinality(t, reflect.TypeOf(&i), one) + assertCardinality(t, reflect.TypeOf(&f), one) - assertCanParse(t, reflect.TypeOf(bs), true, true, true) - assertCanParse(t, reflect.TypeOf(&bs), true, true, true) + assertCardinality(t, reflect.TypeOf(bs), multiple) + assertCardinality(t, reflect.TypeOf(is), multiple) - assertCanParse(t, reflect.TypeOf(is), true, false, true) - assertCanParse(t, reflect.TypeOf(&is), true, false, true) + assertCardinality(t, reflect.TypeOf(&bs), multiple) + assertCardinality(t, reflect.TypeOf(&is), multiple) + + assertCardinality(t, reflect.TypeOf(m), multiple) + assertCardinality(t, reflect.TypeOf(&m), multiple) + + assertCardinality(t, reflect.TypeOf(unsupported1), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported1), unsupported) + assertCardinality(t, reflect.TypeOf(unsupported2), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported2), unsupported) + assertCardinality(t, reflect.TypeOf(unsupported3), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported3), unsupported) + assertCardinality(t, reflect.TypeOf(unsupported4), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported4), unsupported) } type implementsTextUnmarshaler struct{} @@ -45,13 +63,16 @@ func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error { return nil } -func TestCanParseTextUnmarshaler(t *testing.T) { - var u implementsTextUnmarshaler - var su []implementsTextUnmarshaler - assertCanParse(t, reflect.TypeOf(u), true, false, false) - assertCanParse(t, reflect.TypeOf(&u), true, false, false) - assertCanParse(t, reflect.TypeOf(su), true, false, true) - assertCanParse(t, reflect.TypeOf(&su), true, false, true) +func TestCardinalityTextUnmarshaler(t *testing.T) { + var x implementsTextUnmarshaler + var s []implementsTextUnmarshaler + var m []implementsTextUnmarshaler + assertCardinality(t, reflect.TypeOf(x), one) + assertCardinality(t, reflect.TypeOf(&x), one) + assertCardinality(t, reflect.TypeOf(s), multiple) + assertCardinality(t, reflect.TypeOf(&s), multiple) + assertCardinality(t, reflect.TypeOf(m), multiple) + assertCardinality(t, reflect.TypeOf(&m), multiple) } func TestIsExported(t *testing.T) { @@ -60,3 +81,11 @@ func TestIsExported(t *testing.T) { assert.False(t, isExported("")) assert.False(t, isExported(string([]byte{255}))) } + +func TestCardinalityString(t *testing.T) { + assert.Equal(t, "zero", zero.String()) + assert.Equal(t, "one", one.String()) + assert.Equal(t, "multiple", multiple.String()) + assert.Equal(t, "unsupported", unsupported.String()) + assert.Equal(t, "unknown(42)", cardinality(42).String()) +} diff --git a/sequence.go b/sequence.go new file mode 100644 index 0000000..35a3614 --- /dev/null +++ b/sequence.go @@ -0,0 +1,123 @@ +package arg + +import ( + "fmt" + "reflect" + "strings" + + scalar "github.com/alexflint/go-scalar" +) + +// setSliceOrMap parses a sequence of strings into a slice or map. If clear is +// true then any values already in the slice or map are first removed. +func setSliceOrMap(dest reflect.Value, values []string, clear bool) error { + if !dest.CanSet() { + return fmt.Errorf("field is not writable") + } + + t := dest.Type() + if t.Kind() == reflect.Ptr { + dest = dest.Elem() + t = t.Elem() + } + + switch t.Kind() { + case reflect.Slice: + return setSlice(dest, values, clear) + case reflect.Map: + return setMap(dest, values, clear) + default: + return fmt.Errorf("setSliceOrMap cannot insert values into a %v", t) + } +} + +// setSlice parses a sequence of strings and inserts them into a slice. If clear +// is true then any values already in the slice are removed. +func setSlice(dest reflect.Value, values []string, clear bool) error { + var ptr bool + elem := dest.Type().Elem() + if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { + ptr = true + elem = elem.Elem() + } + + // clear the slice in case default values exist + if clear && !dest.IsNil() { + dest.SetLen(0) + } + + // parse the values one-by-one + for _, s := range values { + v := reflect.New(elem) + if err := scalar.ParseValue(v.Elem(), s); err != nil { + return err + } + if !ptr { + v = v.Elem() + } + dest.Set(reflect.Append(dest, v)) + } + return nil +} + +// setMap parses a sequence of name=value strings and inserts them into a map. +// If clear is true then any values already in the map are removed. +func setMap(dest reflect.Value, values []string, clear bool) error { + // determine the key and value type + var keyIsPtr bool + keyType := dest.Type().Key() + if keyType.Kind() == reflect.Ptr && !keyType.Implements(textUnmarshalerType) { + keyIsPtr = true + keyType = keyType.Elem() + } + + var valIsPtr bool + valType := dest.Type().Elem() + if valType.Kind() == reflect.Ptr && !valType.Implements(textUnmarshalerType) { + valIsPtr = true + valType = valType.Elem() + } + + // clear the slice in case default values exist + if clear && !dest.IsNil() { + for _, k := range dest.MapKeys() { + dest.SetMapIndex(k, reflect.Value{}) + } + } + + // allocate the map if it is not allocated + if dest.IsNil() { + dest.Set(reflect.MakeMap(dest.Type())) + } + + // parse the values one-by-one + for _, s := range values { + // split at the first equals sign + pos := strings.Index(s, "=") + if pos == -1 { + return fmt.Errorf("cannot parse %q into a map, expected format key=value", s) + } + + // parse the key + k := reflect.New(keyType) + if err := scalar.ParseValue(k.Elem(), s[:pos]); err != nil { + return err + } + if !keyIsPtr { + k = k.Elem() + } + + // parse the value + v := reflect.New(valType) + if err := scalar.ParseValue(v.Elem(), s[pos+1:]); err != nil { + return err + } + if !valIsPtr { + v = v.Elem() + } + + // add it to the map + dest.SetMapIndex(k, v) + } + return nil +} diff --git a/sequence_test.go b/sequence_test.go new file mode 100644 index 0000000..fde3e3a --- /dev/null +++ b/sequence_test.go @@ -0,0 +1,152 @@ +package arg + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetSliceWithoutClearing(t *testing.T) { + xs := []int{10} + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false) + require.NoError(t, err) + assert.Equal(t, []int{10, 1, 2, 3}, xs) +} + +func TestSetSliceAfterClearing(t *testing.T) { + xs := []int{100} + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, xs) +} + +func TestSetSliceInvalid(t *testing.T) { + xs := []int{100} + entries := []string{"invalid"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetSlicePtr(t *testing.T) { + var xs []*int + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, xs, 3) + assert.Equal(t, 1, *xs[0]) + assert.Equal(t, 2, *xs[1]) + assert.Equal(t, 3, *xs[2]) +} + +func TestSetSliceTextUnmarshaller(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var xs []*textUnmarshaler + entries := []string{"a", "aa", "aaa"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, xs, 3) + assert.Equal(t, 1, xs[0].val) + assert.Equal(t, 2, xs[1].val) + assert.Equal(t, 3, xs[2].val) +} + +func TestSetMapWithoutClearing(t *testing.T) { + m := map[string]int{"foo": 10} + entries := []string{"a=1", "b=2"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, false) + require.NoError(t, err) + require.Len(t, m, 3) + assert.Equal(t, 1, m["a"]) + assert.Equal(t, 2, m["b"]) + assert.Equal(t, 10, m["foo"]) +} + +func TestSetMapAfterClearing(t *testing.T) { + m := map[string]int{"foo": 10} + entries := []string{"a=1", "b=2"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 2) + assert.Equal(t, 1, m["a"]) + assert.Equal(t, 2, m["b"]) +} + +func TestSetMapWithKeyPointer(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[*string]int + entries := []string{"abc=123"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 1) +} + +func TestSetMapWithValuePointer(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[string]*int + entries := []string{"abc=123"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 1) + assert.Equal(t, 123, *m["abc"]) +} + +func TestSetMapTextUnmarshaller(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[textUnmarshaler]*textUnmarshaler + entries := []string{"a=123", "aa=12", "aaa=1"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 3) + assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}]) + assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}]) + assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}]) +} + +func TestSetMapInvalidKey(t *testing.T) { + var m map[int]int + entries := []string{"invalid=123"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetMapInvalidValue(t *testing.T) { + var m map[int]int + entries := []string{"123=invalid"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetMapMalformed(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[string]string + entries := []string{"missing_equals_sign"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +} + +func TestSetSliceOrMapErrors(t *testing.T) { + var err error + var dest reflect.Value + + // converting a slice to a reflect.Value in this way will make it read only + var cannotSet []int + dest = reflect.ValueOf(cannotSet) + err = setSliceOrMap(dest, nil, false) + assert.Error(t, err) + + // check what happens when we pass in something that is not a slice or a map + var notSliceOrMap string + dest = reflect.ValueOf(¬SliceOrMap).Elem() + err = setSliceOrMap(dest, nil, false) + assert.Error(t, err) + + // check what happens when we pass in a pointer to something that is not a slice or a map + var stringPtr *string + dest = reflect.ValueOf(&stringPtr).Elem() + err = setSliceOrMap(dest, nil, false) + assert.Error(t, err) +} diff --git a/usage.go b/usage.go index cbbb021..231476b 100644 --- a/usage.go +++ b/usage.go @@ -95,7 +95,7 @@ func (p *Parser) writeUsageForCommand(w io.Writer, cmd *command) { for _, spec := range positionals { // prefix with a space fmt.Fprint(w, " ") - if spec.multiple { + if spec.cardinality == multiple { if !spec.required { fmt.Fprint(w, "[") } @@ -213,16 +213,16 @@ func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) { // write the list of built in options p.printOption(w, &spec{ - boolean: true, - long: "help", - short: "h", - help: "display this help and exit", + cardinality: zero, + long: "help", + short: "h", + help: "display this help and exit", }) if p.version != "" { p.printOption(w, &spec{ - boolean: true, - long: "version", - help: "display version and exit", + cardinality: zero, + long: "version", + help: "display version and exit", }) } @@ -249,7 +249,7 @@ func (p *Parser) printOption(w io.Writer, spec *spec) { } func synopsis(spec *spec, form string) string { - if spec.boolean { + if spec.cardinality == zero { return form } return form + " " + spec.placeholder