diff --git a/parse.go b/parse.go index ce3892f..39eb52c 100644 --- a/parse.go +++ b/parse.go @@ -5,7 +5,6 @@ import ( "fmt" "os" "reflect" - "strconv" "strings" ) @@ -329,42 +328,3 @@ func setSlice(dest reflect.Value, values []string) error { } return nil } - -// set a value from a string -func setScalar(v reflect.Value, s string) error { - if !v.CanSet() { - return fmt.Errorf("field is not exported") - } - - switch v.Kind() { - case reflect.String: - v.Set(reflect.ValueOf(s)) - case reflect.Bool: - x, err := strconv.ParseBool(s) - if err != nil { - return err - } - v.Set(reflect.ValueOf(x)) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - x, err := strconv.ParseInt(s, 10, v.Type().Bits()) - if err != nil { - return err - } - v.Set(reflect.ValueOf(x).Convert(v.Type())) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - x, err := strconv.ParseUint(s, 10, v.Type().Bits()) - if err != nil { - return err - } - v.Set(reflect.ValueOf(x).Convert(v.Type())) - case reflect.Float32, reflect.Float64: - x, err := strconv.ParseFloat(s, v.Type().Bits()) - if err != nil { - return err - } - v.Set(reflect.ValueOf(x).Convert(v.Type())) - default: - return fmt.Errorf("not a scalar type: %s", v.Kind()) - } - return nil -} diff --git a/parse_test.go b/parse_test.go index f3e7350..c30809d 100644 --- a/parse_test.go +++ b/parse_test.go @@ -4,6 +4,7 @@ import ( "os" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,7 +18,7 @@ func parse(cmdline string, dest interface{}) error { return p.Parse(strings.Split(cmdline, " ")) } -func TestStringSingle(t *testing.T) { +func TestString(t *testing.T) { var args struct { Foo string } @@ -26,6 +27,50 @@ func TestStringSingle(t *testing.T) { assert.Equal(t, "bar", args.Foo) } +func TestInt(t *testing.T) { + var args struct { + Foo int + } + err := parse("--foo 7", &args) + require.NoError(t, err) + assert.EqualValues(t, 7, args.Foo) +} + +func TestUint(t *testing.T) { + var args struct { + Foo uint + } + err := parse("--foo 7", &args) + require.NoError(t, err) + assert.EqualValues(t, 7, args.Foo) +} + +func TestFloat(t *testing.T) { + var args struct { + Foo float32 + } + err := parse("--foo 3.4", &args) + require.NoError(t, err) + assert.EqualValues(t, 3.4, args.Foo) +} + +func TestDuration(t *testing.T) { + var args struct { + Foo time.Duration + } + err := parse("--foo 3ms", &args) + require.NoError(t, err) + assert.Equal(t, 3*time.Millisecond, args.Foo) +} + +func TestInvalidDuration(t *testing.T) { + var args struct { + Foo time.Duration + } + err := parse("--foo xxx", &args) + require.Error(t, err) +} + func TestMixed(t *testing.T) { var args struct { Foo string `arg:"-f"` diff --git a/scalar.go b/scalar.go new file mode 100644 index 0000000..a3bafe4 --- /dev/null +++ b/scalar.go @@ -0,0 +1,63 @@ +package arg + +import ( + "encoding" + "fmt" + "reflect" + "strconv" + "time" +) + +var ( + durationType = reflect.TypeOf(time.Duration(0)) + textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() +) + +// set a value from a string +func setScalar(v reflect.Value, s string) error { + if !v.CanSet() { + return fmt.Errorf("field is not exported") + } + + // If we have a time.Duration then use time.ParseDuration + if v.Type() == durationType { + x, err := time.ParseDuration(s) + if err != nil { + return err + } + v.Set(reflect.ValueOf(x)) + return nil + } + + switch v.Kind() { + case reflect.String: + v.SetString(s) + case reflect.Bool: + x, err := strconv.ParseBool(s) + if err != nil { + return err + } + v.SetBool(x) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + x, err := strconv.ParseInt(s, 10, v.Type().Bits()) + if err != nil { + return err + } + v.SetInt(x) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + x, err := strconv.ParseUint(s, 10, v.Type().Bits()) + if err != nil { + return err + } + v.SetUint(x) + case reflect.Float32, reflect.Float64: + x, err := strconv.ParseFloat(s, v.Type().Bits()) + if err != nil { + return err + } + v.SetFloat(x) + default: + return fmt.Errorf("not a scalar type: %s", v.Kind()) + } + return nil +}