diff --git a/parse.go b/parse.go index db15e5a..a82e377 100644 --- a/parse.go +++ b/parse.go @@ -1,6 +1,7 @@ package arg import ( + "encoding" "errors" "fmt" "os" @@ -445,9 +446,21 @@ func canParse(t reflect.Type) (parseable, boolean, multiple bool) { return false, false, false } +var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() + // isScalar returns true if the type can be parsed from a single string -func isScalar(t reflect.Type) (bool, bool) { - return scalar.CanParse(t), t.Kind() == reflect.Bool +func isScalar(t reflect.Type) (parseable, boolean bool) { + parseable = scalar.CanParse(t) + switch { + case t.Implements(textUnmarshalerType): + return parseable, false + case t.Kind() == reflect.Bool: + return parseable, true + case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Bool: + return parseable, true + default: + return parseable, false + } } // set a value from a string diff --git a/parse_test.go b/parse_test.go index 8779b6f..5e88700 100644 --- a/parse_test.go +++ b/parse_test.go @@ -33,46 +33,71 @@ func parse(cmdline string, dest interface{}) error { func TestString(t *testing.T) { var args struct { Foo string + Ptr *string } - err := parse("--foo bar", &args) + err := parse("--foo bar --ptr baz", &args) require.NoError(t, err) assert.Equal(t, "bar", args.Foo) + assert.Equal(t, "baz", *args.Ptr) +} + +func TestBool(t *testing.T) { + var args struct { + A bool + B bool + C *bool + D *bool + } + err := parse("--a --c", &args) + require.NoError(t, err) + assert.True(t, args.A) + assert.False(t, args.B) + assert.True(t, *args.C) + assert.Nil(t, args.D) } func TestInt(t *testing.T) { var args struct { Foo int + Ptr *int } - err := parse("--foo 7", &args) + err := parse("--foo 7 --ptr 8", &args) require.NoError(t, err) assert.EqualValues(t, 7, args.Foo) + assert.EqualValues(t, 8, *args.Ptr) } func TestUint(t *testing.T) { var args struct { Foo uint + Ptr *uint } - err := parse("--foo 7", &args) + err := parse("--foo 7 --ptr 8", &args) require.NoError(t, err) assert.EqualValues(t, 7, args.Foo) + assert.EqualValues(t, 8, *args.Ptr) } func TestFloat(t *testing.T) { var args struct { Foo float32 + Ptr *float32 } - err := parse("--foo 3.4", &args) + err := parse("--foo 3.4 --ptr 3.5", &args) require.NoError(t, err) assert.EqualValues(t, 3.4, args.Foo) + assert.EqualValues(t, 3.5, *args.Ptr) } func TestDuration(t *testing.T) { var args struct { Foo time.Duration + Ptr *time.Duration } - err := parse("--foo 3ms", &args) + err := parse("--foo 3ms --ptr 4ms", &args) require.NoError(t, err) assert.Equal(t, 3*time.Millisecond, args.Foo) + assert.Equal(t, 4*time.Millisecond, *args.Ptr) } func TestInvalidDuration(t *testing.T) {