deal with booleans correctly

This commit is contained in:
Alex Flint 2017-02-15 18:37:19 -08:00
parent 38c51f4cab
commit 44a8b85d82
2 changed files with 45 additions and 7 deletions

View File

@ -1,6 +1,7 @@
package arg package arg
import ( import (
"encoding"
"errors" "errors"
"fmt" "fmt"
"os" "os"
@ -445,9 +446,21 @@ func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
return false, false, false return false, false, false
} }
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
// isScalar returns true if the type can be parsed from a single string // isScalar returns true if the type can be parsed from a single string
func isScalar(t reflect.Type) (bool, bool) { func isScalar(t reflect.Type) (parseable, boolean bool) {
return scalar.CanParse(t), t.Kind() == reflect.Bool parseable = scalar.CanParse(t)
switch {
case t.Implements(textUnmarshalerType):
return parseable, false
case t.Kind() == reflect.Bool:
return parseable, true
case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Bool:
return parseable, true
default:
return parseable, false
}
} }
// set a value from a string // set a value from a string

View File

@ -33,46 +33,71 @@ func parse(cmdline string, dest interface{}) error {
func TestString(t *testing.T) { func TestString(t *testing.T) {
var args struct { var args struct {
Foo string Foo string
Ptr *string
} }
err := parse("--foo bar", &args) err := parse("--foo bar --ptr baz", &args)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "bar", args.Foo) assert.Equal(t, "bar", args.Foo)
assert.Equal(t, "baz", *args.Ptr)
}
func TestBool(t *testing.T) {
var args struct {
A bool
B bool
C *bool
D *bool
}
err := parse("--a --c", &args)
require.NoError(t, err)
assert.True(t, args.A)
assert.False(t, args.B)
assert.True(t, *args.C)
assert.Nil(t, args.D)
} }
func TestInt(t *testing.T) { func TestInt(t *testing.T) {
var args struct { var args struct {
Foo int Foo int
Ptr *int
} }
err := parse("--foo 7", &args) err := parse("--foo 7 --ptr 8", &args)
require.NoError(t, err) require.NoError(t, err)
assert.EqualValues(t, 7, args.Foo) assert.EqualValues(t, 7, args.Foo)
assert.EqualValues(t, 8, *args.Ptr)
} }
func TestUint(t *testing.T) { func TestUint(t *testing.T) {
var args struct { var args struct {
Foo uint Foo uint
Ptr *uint
} }
err := parse("--foo 7", &args) err := parse("--foo 7 --ptr 8", &args)
require.NoError(t, err) require.NoError(t, err)
assert.EqualValues(t, 7, args.Foo) assert.EqualValues(t, 7, args.Foo)
assert.EqualValues(t, 8, *args.Ptr)
} }
func TestFloat(t *testing.T) { func TestFloat(t *testing.T) {
var args struct { var args struct {
Foo float32 Foo float32
Ptr *float32
} }
err := parse("--foo 3.4", &args) err := parse("--foo 3.4 --ptr 3.5", &args)
require.NoError(t, err) require.NoError(t, err)
assert.EqualValues(t, 3.4, args.Foo) assert.EqualValues(t, 3.4, args.Foo)
assert.EqualValues(t, 3.5, *args.Ptr)
} }
func TestDuration(t *testing.T) { func TestDuration(t *testing.T) {
var args struct { var args struct {
Foo time.Duration Foo time.Duration
Ptr *time.Duration
} }
err := parse("--foo 3ms", &args) err := parse("--foo 3ms --ptr 4ms", &args)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 3*time.Millisecond, args.Foo) assert.Equal(t, 3*time.Millisecond, args.Foo)
assert.Equal(t, 4*time.Millisecond, *args.Ptr)
} }
func TestInvalidDuration(t *testing.T) { func TestInvalidDuration(t *testing.T) {