diff --git a/parse.go b/parse.go index 84a7ed1..37df734 100644 --- a/parse.go +++ b/parse.go @@ -377,7 +377,7 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { cmd.specs = append(cmd.specs, &spec) var parseable bool - parseable, spec.boolean, spec.multiple = canParse(field.Type) + //parseable, spec.boolean, spec.multiple = canParse(field.Type) if !parseable { errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String())) @@ -728,7 +728,7 @@ func findSubcommand(cmds []*command, name string) *command { // isZero returns true if v contains the zero value for its type func isZero(v reflect.Value) bool { t := v.Type() - if t.Kind() == reflect.Slice { + if t.Kind() == reflect.Slice || t.Kind() == reflect.Map { return v.IsNil() } if !t.Comparable() { diff --git a/parse_test.go b/parse_test.go index ce3068e..0decfc1 100644 --- a/parse_test.go +++ b/parse_test.go @@ -220,6 +220,16 @@ func TestLongFlag(t *testing.T) { assert.Equal(t, "xyz", args.Foo) } +func TestSliceOfBools(t *testing.T) { + var args struct { + B []bool + } + + err := parse("--b true false true", &args) + require.NoError(t, err) + assert.Equal(t, []bool{true, false, true}, args.B) +} + func TestPlaceholder(t *testing.T) { var args struct { Input string `arg:"positional" placeholder:"SRC"` diff --git a/reflect.go b/reflect.go index f1e8e8d..c4fc5d9 100644 --- a/reflect.go +++ b/reflect.go @@ -2,6 +2,7 @@ package arg import ( "encoding" + "fmt" "reflect" "unicode" "unicode/utf8" @@ -11,42 +12,71 @@ import ( var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() -// canParse returns true if the type can be parsed from a string -func canParse(t reflect.Type) (parseable, boolean, multiple bool) { - parseable = scalar.CanParse(t) - boolean = isBoolean(t) - if parseable { - return +// kind is used to track the various kinds of options: +// - regular is an ordinary option that will be parsed from a single token +// - binary is an option that will be true if present but does not expect an explicit value +// - sequence is an option that accepts multiple values and will end up in a slice +// - mapping is an option that acccepts multiple key=value strings and will end up in a map +type kind int + +const ( + regular kind = iota + binary + sequence + mapping + unsupported +) + +func (k kind) String() string { + switch k { + case regular: + return "regular" + case binary: + return "binary" + case sequence: + return "sequence" + case mapping: + return "mapping" + case unsupported: + return "unsupported" + default: + return fmt.Sprintf("unknown(%d)", int(k)) + } +} + +// kindOf returns true if the type can be parsed from a string +func kindOf(t reflect.Type) (kind, error) { + if scalar.CanParse(t) { + if isBoolean(t) { + return binary, nil + } else { + return regular, nil + } } - // Look inside pointer types - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - // Look inside slice types - if t.Kind() == reflect.Slice { - multiple = true - t = t.Elem() - } - - parseable = scalar.CanParse(t) - boolean = isBoolean(t) - if parseable { - return - } - - // Look inside pointer types (again, in case of []*Type) + // look inside pointer types if t.Kind() == reflect.Ptr { t = t.Elem() } - parseable = scalar.CanParse(t) - boolean = isBoolean(t) - if parseable { - return + // look inside slice and map types + switch t.Kind() { + case reflect.Slice: + if !scalar.CanParse(t.Elem()) { + return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into %v", t, t.Elem()) + } + return sequence, nil + case reflect.Map: + if !scalar.CanParse(t.Key()) { + return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into the key type %v", t, t.Elem()) + } + if !scalar.CanParse(t.Elem()) { + return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into the value type %v", t, t.Elem()) + } + return mapping, nil + default: + return unsupported, fmt.Errorf("cannot parse into %v", t) } - - return false, false, false } // isBoolean returns true if the type can be parsed from a single string diff --git a/reflect_test.go b/reflect_test.go index 07b459c..6a8af49 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -7,36 +7,51 @@ import ( "github.com/stretchr/testify/assert" ) -func assertCanParse(t *testing.T, typ reflect.Type, parseable, boolean, multiple bool) { - p, b, m := canParse(typ) - assert.Equal(t, parseable, p, "expected %v to have parseable=%v but was %v", typ, parseable, p) - assert.Equal(t, boolean, b, "expected %v to have boolean=%v but was %v", typ, boolean, b) - assert.Equal(t, multiple, m, "expected %v to have multiple=%v but was %v", typ, multiple, m) +func assertKind(t *testing.T, typ reflect.Type, expected kind) { + actual, err := kindOf(typ) + assert.Equal(t, expected, actual, "expected %v to have kind %v but got %v", typ, expected, actual) + if expected == unsupported { + assert.Error(t, err) + } } -func TestCanParse(t *testing.T) { +func TestKindOf(t *testing.T) { var b bool var i int var s string var f float64 var bs []bool var is []int + var m map[string]int + var unsupported1 struct{} + var unsupported2 []struct{} + var unsupported3 map[string]struct{} - assertCanParse(t, reflect.TypeOf(b), true, true, false) - assertCanParse(t, reflect.TypeOf(i), true, false, false) - assertCanParse(t, reflect.TypeOf(s), true, false, false) - assertCanParse(t, reflect.TypeOf(f), true, false, false) + assertKind(t, reflect.TypeOf(b), binary) + assertKind(t, reflect.TypeOf(i), regular) + assertKind(t, reflect.TypeOf(s), regular) + assertKind(t, reflect.TypeOf(f), regular) - assertCanParse(t, reflect.TypeOf(&b), true, true, false) - assertCanParse(t, reflect.TypeOf(&s), true, false, false) - assertCanParse(t, reflect.TypeOf(&i), true, false, false) - assertCanParse(t, reflect.TypeOf(&f), true, false, false) + assertKind(t, reflect.TypeOf(&b), binary) + assertKind(t, reflect.TypeOf(&s), regular) + assertKind(t, reflect.TypeOf(&i), regular) + assertKind(t, reflect.TypeOf(&f), regular) - assertCanParse(t, reflect.TypeOf(bs), true, true, true) - assertCanParse(t, reflect.TypeOf(&bs), true, true, true) + assertKind(t, reflect.TypeOf(bs), sequence) + assertKind(t, reflect.TypeOf(is), sequence) - assertCanParse(t, reflect.TypeOf(is), true, false, true) - assertCanParse(t, reflect.TypeOf(&is), true, false, true) + assertKind(t, reflect.TypeOf(&bs), sequence) + assertKind(t, reflect.TypeOf(&is), sequence) + + assertKind(t, reflect.TypeOf(m), mapping) + assertKind(t, reflect.TypeOf(&m), mapping) + + assertKind(t, reflect.TypeOf(unsupported1), unsupported) + assertKind(t, reflect.TypeOf(&unsupported1), unsupported) + assertKind(t, reflect.TypeOf(unsupported2), unsupported) + assertKind(t, reflect.TypeOf(&unsupported2), unsupported) + assertKind(t, reflect.TypeOf(unsupported3), unsupported) + assertKind(t, reflect.TypeOf(&unsupported3), unsupported) } type implementsTextUnmarshaler struct{} @@ -46,12 +61,15 @@ func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error { } func TestCanParseTextUnmarshaler(t *testing.T) { - var u implementsTextUnmarshaler - var su []implementsTextUnmarshaler - assertCanParse(t, reflect.TypeOf(u), true, false, false) - assertCanParse(t, reflect.TypeOf(&u), true, false, false) - assertCanParse(t, reflect.TypeOf(su), true, false, true) - assertCanParse(t, reflect.TypeOf(&su), true, false, true) + var x implementsTextUnmarshaler + var s []implementsTextUnmarshaler + var m []implementsTextUnmarshaler + assertKind(t, reflect.TypeOf(x), regular) + assertKind(t, reflect.TypeOf(&x), regular) + assertKind(t, reflect.TypeOf(s), sequence) + assertKind(t, reflect.TypeOf(&s), sequence) + assertKind(t, reflect.TypeOf(m), mapping) + assertKind(t, reflect.TypeOf(&m), mapping) } func TestIsExported(t *testing.T) { diff --git a/sequence_test.go b/sequence_test.go index 4646811..446bc42 100644 --- a/sequence_test.go +++ b/sequence_test.go @@ -79,3 +79,11 @@ func TestSetMapTextUnmarshaller(t *testing.T) { assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}]) assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}]) } + +func TestSetMapMalformed(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[string]string + entries := []string{"missing_equals_sign"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + assert.Error(t, err) +}