diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index e8e7f3a..a6cfe20 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -1,14 +1,14 @@ { "ImportPath": "github.com/alexflint/go-arg", "GoVersion": "go1.7", - "GodepVersion": "v79", + "GodepVersion": "v80", "Packages": [ "." ], "Deps": [ { "ImportPath": "github.com/alexflint/go-scalar", - "Rev": "45e5d6cd8605faef82fda2bacc59e734a0b6f1f0" + "Rev": "6ab8ad5e1c5b25ca2783fe83f493c3ab471407e2" }, { "ImportPath": "github.com/stretchr/testify/assert", diff --git a/README.md b/README.md index 8980ba1..8e68ca7 100644 --- a/README.md +++ b/README.md @@ -288,10 +288,10 @@ func (n *NameDotName) MarshalText() (text []byte, err error) { func main() { var args struct { - Name *NameDotName + Name NameDotName } // set default - args.Name = &NameDotName{"file", "txt"} + args.Name = NameDotName{"file", "txt"} arg.MustParse(&args) fmt.Printf("%#v\n", args.Name) } @@ -305,10 +305,10 @@ Options: --help, -h display this help and exit $ ./example -&main.NameDotName{Head:"file", Tail:"txt"} +main.NameDotName{Head:"file", Tail:"txt"} $ ./example --name=foo.bar -&main.NameDotName{Head:"foo", Tail:"bar"} +main.NameDotName{Head:"foo", Tail:"bar"} $ ./example --name=oops Usage: example [--name NAME] diff --git a/parse_test.go b/parse_test.go index 0bc97e3..b72563c 100644 --- a/parse_test.go +++ b/parse_test.go @@ -580,7 +580,7 @@ func TestEnvironmentVariableRequired(t *testing.T) { assert.Equal(t, "bar", args.Foo) } -func TestEnvironmentVariableSliceArgumentString(t *testing.T) { +func TestEnvironmentVariableSliceArgumentString(t *testing.T) { var args struct { Foo []string `arg:"env"` } @@ -589,7 +589,7 @@ func TestEnvironmentVariableSliceArgumentString(t *testing.T) { assert.Equal(t, []string{"bar", "baz, qux"}, args.Foo) } -func TestEnvironmentVariableSliceArgumentInteger(t *testing.T) { +func TestEnvironmentVariableSliceArgumentInteger(t *testing.T) { var args struct { Foo []int `arg:"env"` } @@ -598,7 +598,7 @@ func TestEnvironmentVariableSliceArgumentInteger(t *testing.T) { assert.Equal(t, []int{1, 99}, args.Foo) } -func TestEnvironmentVariableSliceArgumentFloat(t *testing.T) { +func TestEnvironmentVariableSliceArgumentFloat(t *testing.T) { var args struct { Foo []float32 `arg:"env"` } @@ -607,7 +607,7 @@ func TestEnvironmentVariableSliceArgumentFloat(t *testing.T) { assert.Equal(t, []float32{1.1, 99.9}, args.Foo) } -func TestEnvironmentVariableSliceArgumentBool(t *testing.T) { +func TestEnvironmentVariableSliceArgumentBool(t *testing.T) { var args struct { Foo []bool `arg:"env"` } @@ -616,7 +616,7 @@ func TestEnvironmentVariableSliceArgumentBool(t *testing.T) { assert.Equal(t, []bool{true, false, false, true}, args.Foo) } -func TestEnvironmentVariableSliceArgumentWrongCsv(t *testing.T) { +func TestEnvironmentVariableSliceArgumentWrongCsv(t *testing.T) { var args struct { Foo []int `arg:"env"` } @@ -625,7 +625,7 @@ func TestEnvironmentVariableSliceArgumentWrongCsv(t *testing.T) { assert.Error(t, err) } -func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) { +func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) { var args struct { Foo []bool `arg:"env"` } @@ -644,6 +644,16 @@ func (f *textUnmarshaler) UnmarshalText(b []byte) error { } func TestTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo textUnmarshaler + } + err := parse("--foo abc", &args) + require.NoError(t, err) + assert.Equal(t, 3, args.Foo.val) +} + +func TestPtrToTextUnmarshaler(t *testing.T) { // fields that implement TextUnmarshaler should be parsed using that interface var args struct { Foo *textUnmarshaler @@ -654,6 +664,19 @@ func TestTextUnmarshaler(t *testing.T) { } func TestRepeatedTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo []textUnmarshaler + } + err := parse("--foo abc d ef", &args) + require.NoError(t, err) + require.Len(t, args.Foo, 3) + assert.Equal(t, 3, args.Foo[0].val) + assert.Equal(t, 1, args.Foo[1].val) + assert.Equal(t, 2, args.Foo[2].val) +} + +func TestRepeatedPtrToTextUnmarshaler(t *testing.T) { // fields that implement TextUnmarshaler should be parsed using that interface var args struct { Foo []*textUnmarshaler @@ -667,6 +690,19 @@ func TestRepeatedTextUnmarshaler(t *testing.T) { } func TestPositionalTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo []textUnmarshaler `arg:"positional"` + } + err := parse("abc d ef", &args) + require.NoError(t, err) + require.Len(t, args.Foo, 3) + assert.Equal(t, 3, args.Foo[0].val) + assert.Equal(t, 1, args.Foo[1].val) + assert.Equal(t, 2, args.Foo[2].val) +} + +func TestPositionalPtrToTextUnmarshaler(t *testing.T) { // fields that implement TextUnmarshaler should be parsed using that interface var args struct { Foo []*textUnmarshaler `arg:"positional"` diff --git a/vendor/github.com/alexflint/go-scalar/scalar.go b/vendor/github.com/alexflint/go-scalar/scalar.go index 663f143..073392c 100644 --- a/vendor/github.com/alexflint/go-scalar/scalar.go +++ b/vendor/github.com/alexflint/go-scalar/scalar.go @@ -18,7 +18,6 @@ var ( textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() durationType = reflect.TypeOf(time.Duration(0)) mailAddressType = reflect.TypeOf(mail.Address{}) - ipType = reflect.TypeOf(net.IP{}) macType = reflect.TypeOf(net.HardwareAddr{}) ) @@ -47,6 +46,13 @@ func ParseValue(v reflect.Value, s string) error { if scalar, ok := v.Interface().(encoding.TextUnmarshaler); ok { return scalar.UnmarshalText([]byte(s)) } + // If it's a value instead of a pointer, check that we can unmarshal it + // via TextUnmarshaler as well + if v.CanAddr() { + if scalar, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok { + return scalar.UnmarshalText([]byte(s)) + } + } // If we have a pointer then dereference it if v.Kind() == reflect.Ptr { @@ -73,13 +79,6 @@ func ParseValue(v reflect.Value, s string) error { } v.Set(reflect.ValueOf(*addr)) return nil - case net.IP: - ip := net.ParseIP(s) - if ip == nil { - return fmt.Errorf(`invalid IP address: "%s"`, s) - } - v.Set(reflect.ValueOf(ip)) - return nil case net.HardwareAddr: ip, err := net.ParseMAC(s) if err != nil { @@ -126,7 +125,7 @@ func ParseValue(v reflect.Value, s string) error { // CanParse returns true if the type can be parsed from a string. func CanParse(t reflect.Type) bool { // If it implements encoding.TextUnmarshaler then use that - if t.Implements(textUnmarshalerType) { + if t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType) { return true } @@ -137,7 +136,7 @@ func CanParse(t reflect.Type) bool { // Check for other special types switch t { - case durationType, mailAddressType, ipType, macType: + case durationType, mailAddressType, macType: return true }