diff --git a/parse.go b/parse.go index 3895ce9..6768699 100644 --- a/parse.go +++ b/parse.go @@ -1,7 +1,6 @@ package arg import ( - "encoding" "errors" "fmt" "os" @@ -20,16 +19,13 @@ type spec struct { help string env string wasPresent bool - isBool bool + boolean bool fieldName string // for generating helpful errors } // ErrHelp indicates that -h or --help were provided var ErrHelp = errors.New("help requested by user") -// The TextUnmarshaler type in reflection form -var textUnsmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() - // MustParse processes command line arguments and exits upon failure func MustParse(dest ...interface{}) *Parser { p, err := NewParser(dest...) @@ -94,33 +90,10 @@ func NewParser(dests ...interface{}) (*Parser, error) { // wait until setScalar because it means that a program with invalid argument // fields will always fail regardless of whether the arguments it recieved happend // to exercise those fields. - if !field.Type.Implements(textUnsmarshalerType) { - scalarType := field.Type - // Look inside pointer types - if scalarType.Kind() == reflect.Ptr { - scalarType = scalarType.Elem() - } - // Check for bool - if scalarType.Kind() == reflect.Bool { - spec.isBool = true - } - // Look inside slice types - if scalarType.Kind() == reflect.Slice { - spec.multiple = true - scalarType = scalarType.Elem() - } - // Look inside pointer types (again, in case of []*Type) - if scalarType.Kind() == reflect.Ptr { - scalarType = scalarType.Elem() - } - - // Check for unsupported types - switch scalarType.Kind() { - case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, - reflect.Map, reflect.Ptr, reflect.Struct, - reflect.Complex64, reflect.Complex128: - return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind()) - } + var parseable bool + parseable, spec.boolean, spec.multiple = canParse(field.Type) + if !parseable { + return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String()) } // Look at the tag @@ -264,8 +237,8 @@ func process(specs []*spec, args []string) error { } // if it's a flag and it has no value then set the value to true - // use isBool because this takes account of TextUnmarshaler - if spec.isBool && value == "" { + // use boolean because this takes account of TextUnmarshaler + if spec.boolean && value == "" { value = "true" } @@ -345,3 +318,38 @@ func setSlice(dest reflect.Value, values []string) error { } return nil } + +// canParse returns true if the type can be parsed from a string +func canParse(t reflect.Type) (parseable, boolean, multiple bool) { + parseable, boolean = isScalar(t) + if parseable { + return + } + + // Look inside pointer types + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + // Look inside slice types + if t.Kind() == reflect.Slice { + multiple = true + t = t.Elem() + } + + parseable, boolean = isScalar(t) + if parseable { + return + } + + // Look inside pointer types (again, in case of []*Type) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + parseable, boolean = isScalar(t) + if parseable { + return + } + + return false, false, false +} diff --git a/parse_test.go b/parse_test.go index a915910..e33fe76 100644 --- a/parse_test.go +++ b/parse_test.go @@ -1,6 +1,8 @@ package arg import ( + "net" + "net/mail" "os" "strings" "testing" @@ -541,3 +543,74 @@ func TestSliceUnmarhsaler(t *testing.T) { assert.EqualValues(t, 5, (*args.Foo)[0]) assert.Equal(t, "xyz", args.Bar) } + +func TestIP(t *testing.T) { + var args struct { + Host net.IP + } + err := parse("--host 192.168.0.1", &args) + require.NoError(t, err) + assert.Equal(t, "192.168.0.1", args.Host.String()) +} + +func TestPtrToIP(t *testing.T) { + var args struct { + Host *net.IP + } + err := parse("--host 192.168.0.1", &args) + require.NoError(t, err) + assert.Equal(t, "192.168.0.1", args.Host.String()) +} + +func TestIPSlice(t *testing.T) { + var args struct { + Host []net.IP + } + err := parse("--host 192.168.0.1 127.0.0.1", &args) + require.NoError(t, err) + require.Len(t, args.Host, 2) + assert.Equal(t, "192.168.0.1", args.Host[0].String()) + assert.Equal(t, "127.0.0.1", args.Host[1].String()) +} + +func TestInvalidIPAddress(t *testing.T) { + var args struct { + Host net.IP + } + err := parse("--host xxx", &args) + assert.Error(t, err) +} + +func TestMAC(t *testing.T) { + var args struct { + Host net.HardwareAddr + } + err := parse("--host 0123.4567.89ab", &args) + require.NoError(t, err) + assert.Equal(t, "01:23:45:67:89:ab", args.Host.String()) +} + +func TestInvalidMac(t *testing.T) { + var args struct { + Host net.HardwareAddr + } + err := parse("--host xxx", &args) + assert.Error(t, err) +} + +func TestMailAddr(t *testing.T) { + var args struct { + Recipient mail.Address + } + err := parse("--recipient foo@example.com", &args) + require.NoError(t, err) + assert.Equal(t, "", args.Recipient.String()) +} + +func TestInvalidMailAddr(t *testing.T) { + var args struct { + Recipient mail.Address + } + err := parse("--recipient xxx", &args) + assert.Error(t, err) +} diff --git a/scalar.go b/scalar.go index 67b4540..e79b002 100644 --- a/scalar.go +++ b/scalar.go @@ -3,11 +3,57 @@ package arg import ( "encoding" "fmt" + "net" + "net/mail" "reflect" "strconv" "time" ) +// The reflected form of some special types +var ( + textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() + durationType = reflect.TypeOf(time.Duration(0)) + mailAddressType = reflect.TypeOf(mail.Address{}) + ipType = reflect.TypeOf(net.IP{}) + macType = reflect.TypeOf(net.HardwareAddr{}) +) + +// isScalar returns true if the type can be parsed from a single string +func isScalar(t reflect.Type) (scalar, boolean bool) { + // If it implements encoding.TextUnmarshaler then use that + if t.Implements(textUnmarshalerType) { + // scalar=YES, boolean=NO + return true, false + } + + // If we have a pointer then dereference it + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // Check for other special types + switch t { + case durationType, mailAddressType, ipType, macType: + // scalar=YES, boolean=NO + return true, false + } + + // Fall back to checking the kind + switch t.Kind() { + case reflect.Bool: + // scalar=YES, boolean=YES + return true, true + case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + // scalar=YES, boolean=NO + return true, false + } + // scalar=NO, boolean=NO + return false, false +} + // set a value from a string func setScalar(v reflect.Value, s string) error { if !v.CanSet() { @@ -35,11 +81,32 @@ func setScalar(v reflect.Value, s string) error { // Switch on concrete type switch scalar.(type) { case time.Duration: - x, err := time.ParseDuration(s) + duration, err := time.ParseDuration(s) if err != nil { return err } - v.Set(reflect.ValueOf(x)) + v.Set(reflect.ValueOf(duration)) + return nil + case mail.Address: + addr, err := mail.ParseAddress(s) + if err != nil { + return err + } + v.Set(reflect.ValueOf(*addr)) + return nil + case net.IP: + ip := net.ParseIP(s) + if ip == nil { + return fmt.Errorf(`invalid IP address: "%s"`, s) + } + v.Set(reflect.ValueOf(ip)) + return nil + case net.HardwareAddr: + ip, err := net.ParseMAC(s) + if err != nil { + return err + } + v.Set(reflect.ValueOf(ip)) return nil } diff --git a/usage.go b/usage.go index 61f0ad6..eea6d29 100644 --- a/usage.go +++ b/usage.go @@ -97,7 +97,7 @@ func (p *Parser) WriteHelp(w io.Writer) { } // write the list of built in options - printOption(w, &spec{isBool: true, long: "help", short: "h", help: "display this help and exit"}) + printOption(w, &spec{boolean: true, long: "help", short: "h", help: "display this help and exit"}) } func printOption(w io.Writer, spec *spec) { @@ -127,7 +127,7 @@ func printOption(w io.Writer, spec *spec) { } func synopsis(spec *spec, form string) string { - if spec.isBool { + if spec.boolean { return form } return form + " " + strings.ToUpper(spec.long)