diff --git a/parse_test.go b/parse_test.go index a915910..5714ebf 100644 --- a/parse_test.go +++ b/parse_test.go @@ -1,6 +1,7 @@ package arg import ( + "net" "os" "strings" "testing" @@ -541,3 +542,40 @@ func TestSliceUnmarhsaler(t *testing.T) { assert.EqualValues(t, 5, (*args.Foo)[0]) assert.Equal(t, "xyz", args.Bar) } + +func TestIP(t *testing.T) { + var args struct { + Host net.IP + } + err := parse("--host 192.168.0.1", &args) + require.NoError(t, err) + assert.Equal(t, "192.168.0.1", args.Host.String()) +} + +func TestPtrToIP(t *testing.T) { + var args struct { + Host *net.IP + } + err := parse("--host 192.168.0.1", &args) + require.NoError(t, err) + assert.Equal(t, "192.168.0.1", args.Host.String()) +} + +func TestIPSlice(t *testing.T) { + var args struct { + Host []net.IP + } + err := parse("--host 192.168.0.1 127.0.0.1", &args) + require.NoError(t, err) + require.Len(t, args.Host, 2) + assert.Equal(t, "192.168.0.1", args.Host[0].String()) + assert.Equal(t, "127.0.0.1", args.Host[1].String()) +} + +func TestInvalidIPAddress(t *testing.T) { + var args struct { + Host net.IP + } + err := parse("--host xxx", &args) + assert.Error(t, err) +} diff --git a/scalar.go b/scalar.go index ac56978..e79b002 100644 --- a/scalar.go +++ b/scalar.go @@ -93,18 +93,21 @@ func setScalar(v reflect.Value, s string) error { return err } v.Set(reflect.ValueOf(*addr)) + return nil case net.IP: ip := net.ParseIP(s) if ip == nil { return fmt.Errorf(`invalid IP address: "%s"`, s) } v.Set(reflect.ValueOf(ip)) + return nil case net.HardwareAddr: ip, err := net.ParseMAC(s) if err != nil { return err } v.Set(reflect.ValueOf(ip)) + return nil } // Switch on kind so that we can handle derived types