Merge pull request #31 from alexflint/parse_ip_mac_and_email

Parse IP addresses, MAC addresses, and email addresses
This commit is contained in:
Alex Flint 2016-01-23 21:07:42 -08:00
commit 77dd0df006
4 changed files with 186 additions and 38 deletions

View File

@ -1,7 +1,6 @@
package arg package arg
import ( import (
"encoding"
"errors" "errors"
"fmt" "fmt"
"os" "os"
@ -20,16 +19,13 @@ type spec struct {
help string help string
env string env string
wasPresent bool wasPresent bool
isBool bool boolean bool
fieldName string // for generating helpful errors fieldName string // for generating helpful errors
} }
// ErrHelp indicates that -h or --help were provided // ErrHelp indicates that -h or --help were provided
var ErrHelp = errors.New("help requested by user") var ErrHelp = errors.New("help requested by user")
// The TextUnmarshaler type in reflection form
var textUnsmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
// MustParse processes command line arguments and exits upon failure // MustParse processes command line arguments and exits upon failure
func MustParse(dest ...interface{}) *Parser { func MustParse(dest ...interface{}) *Parser {
p, err := NewParser(dest...) p, err := NewParser(dest...)
@ -94,33 +90,10 @@ func NewParser(dests ...interface{}) (*Parser, error) {
// wait until setScalar because it means that a program with invalid argument // wait until setScalar because it means that a program with invalid argument
// fields will always fail regardless of whether the arguments it recieved happend // fields will always fail regardless of whether the arguments it recieved happend
// to exercise those fields. // to exercise those fields.
if !field.Type.Implements(textUnsmarshalerType) { var parseable bool
scalarType := field.Type parseable, spec.boolean, spec.multiple = canParse(field.Type)
// Look inside pointer types if !parseable {
if scalarType.Kind() == reflect.Ptr { return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String())
scalarType = scalarType.Elem()
}
// Check for bool
if scalarType.Kind() == reflect.Bool {
spec.isBool = true
}
// Look inside slice types
if scalarType.Kind() == reflect.Slice {
spec.multiple = true
scalarType = scalarType.Elem()
}
// Look inside pointer types (again, in case of []*Type)
if scalarType.Kind() == reflect.Ptr {
scalarType = scalarType.Elem()
}
// Check for unsupported types
switch scalarType.Kind() {
case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
reflect.Map, reflect.Ptr, reflect.Struct,
reflect.Complex64, reflect.Complex128:
return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
}
} }
// Look at the tag // Look at the tag
@ -264,8 +237,8 @@ func process(specs []*spec, args []string) error {
} }
// if it's a flag and it has no value then set the value to true // if it's a flag and it has no value then set the value to true
// use isBool because this takes account of TextUnmarshaler // use boolean because this takes account of TextUnmarshaler
if spec.isBool && value == "" { if spec.boolean && value == "" {
value = "true" value = "true"
} }
@ -345,3 +318,38 @@ func setSlice(dest reflect.Value, values []string) error {
} }
return nil return nil
} }
// canParse returns true if the type can be parsed from a string
func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
parseable, boolean = isScalar(t)
if parseable {
return
}
// Look inside pointer types
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
// Look inside slice types
if t.Kind() == reflect.Slice {
multiple = true
t = t.Elem()
}
parseable, boolean = isScalar(t)
if parseable {
return
}
// Look inside pointer types (again, in case of []*Type)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
parseable, boolean = isScalar(t)
if parseable {
return
}
return false, false, false
}

View File

@ -1,6 +1,8 @@
package arg package arg
import ( import (
"net"
"net/mail"
"os" "os"
"strings" "strings"
"testing" "testing"
@ -541,3 +543,74 @@ func TestSliceUnmarhsaler(t *testing.T) {
assert.EqualValues(t, 5, (*args.Foo)[0]) assert.EqualValues(t, 5, (*args.Foo)[0])
assert.Equal(t, "xyz", args.Bar) assert.Equal(t, "xyz", args.Bar)
} }
func TestIP(t *testing.T) {
var args struct {
Host net.IP
}
err := parse("--host 192.168.0.1", &args)
require.NoError(t, err)
assert.Equal(t, "192.168.0.1", args.Host.String())
}
func TestPtrToIP(t *testing.T) {
var args struct {
Host *net.IP
}
err := parse("--host 192.168.0.1", &args)
require.NoError(t, err)
assert.Equal(t, "192.168.0.1", args.Host.String())
}
func TestIPSlice(t *testing.T) {
var args struct {
Host []net.IP
}
err := parse("--host 192.168.0.1 127.0.0.1", &args)
require.NoError(t, err)
require.Len(t, args.Host, 2)
assert.Equal(t, "192.168.0.1", args.Host[0].String())
assert.Equal(t, "127.0.0.1", args.Host[1].String())
}
func TestInvalidIPAddress(t *testing.T) {
var args struct {
Host net.IP
}
err := parse("--host xxx", &args)
assert.Error(t, err)
}
func TestMAC(t *testing.T) {
var args struct {
Host net.HardwareAddr
}
err := parse("--host 0123.4567.89ab", &args)
require.NoError(t, err)
assert.Equal(t, "01:23:45:67:89:ab", args.Host.String())
}
func TestInvalidMac(t *testing.T) {
var args struct {
Host net.HardwareAddr
}
err := parse("--host xxx", &args)
assert.Error(t, err)
}
func TestMailAddr(t *testing.T) {
var args struct {
Recipient mail.Address
}
err := parse("--recipient foo@example.com", &args)
require.NoError(t, err)
assert.Equal(t, "<foo@example.com>", args.Recipient.String())
}
func TestInvalidMailAddr(t *testing.T) {
var args struct {
Recipient mail.Address
}
err := parse("--recipient xxx", &args)
assert.Error(t, err)
}

View File

@ -3,11 +3,57 @@ package arg
import ( import (
"encoding" "encoding"
"fmt" "fmt"
"net"
"net/mail"
"reflect" "reflect"
"strconv" "strconv"
"time" "time"
) )
// The reflected form of some special types
var (
textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
durationType = reflect.TypeOf(time.Duration(0))
mailAddressType = reflect.TypeOf(mail.Address{})
ipType = reflect.TypeOf(net.IP{})
macType = reflect.TypeOf(net.HardwareAddr{})
)
// isScalar returns true if the type can be parsed from a single string
func isScalar(t reflect.Type) (scalar, boolean bool) {
// If it implements encoding.TextUnmarshaler then use that
if t.Implements(textUnmarshalerType) {
// scalar=YES, boolean=NO
return true, false
}
// If we have a pointer then dereference it
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
// Check for other special types
switch t {
case durationType, mailAddressType, ipType, macType:
// scalar=YES, boolean=NO
return true, false
}
// Fall back to checking the kind
switch t.Kind() {
case reflect.Bool:
// scalar=YES, boolean=YES
return true, true
case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
// scalar=YES, boolean=NO
return true, false
}
// scalar=NO, boolean=NO
return false, false
}
// set a value from a string // set a value from a string
func setScalar(v reflect.Value, s string) error { func setScalar(v reflect.Value, s string) error {
if !v.CanSet() { if !v.CanSet() {
@ -35,11 +81,32 @@ func setScalar(v reflect.Value, s string) error {
// Switch on concrete type // Switch on concrete type
switch scalar.(type) { switch scalar.(type) {
case time.Duration: case time.Duration:
x, err := time.ParseDuration(s) duration, err := time.ParseDuration(s)
if err != nil { if err != nil {
return err return err
} }
v.Set(reflect.ValueOf(x)) v.Set(reflect.ValueOf(duration))
return nil
case mail.Address:
addr, err := mail.ParseAddress(s)
if err != nil {
return err
}
v.Set(reflect.ValueOf(*addr))
return nil
case net.IP:
ip := net.ParseIP(s)
if ip == nil {
return fmt.Errorf(`invalid IP address: "%s"`, s)
}
v.Set(reflect.ValueOf(ip))
return nil
case net.HardwareAddr:
ip, err := net.ParseMAC(s)
if err != nil {
return err
}
v.Set(reflect.ValueOf(ip))
return nil return nil
} }

View File

@ -97,7 +97,7 @@ func (p *Parser) WriteHelp(w io.Writer) {
} }
// write the list of built in options // write the list of built in options
printOption(w, &spec{isBool: true, long: "help", short: "h", help: "display this help and exit"}) printOption(w, &spec{boolean: true, long: "help", short: "h", help: "display this help and exit"})
} }
func printOption(w io.Writer, spec *spec) { func printOption(w io.Writer, spec *spec) {
@ -127,7 +127,7 @@ func printOption(w io.Writer, spec *spec) {
} }
func synopsis(spec *spec, form string) string { func synopsis(spec *spec, form string) string {
if spec.isBool { if spec.boolean {
return form return form
} }
return form + " " + strings.ToUpper(spec.long) return form + " " + strings.ToUpper(spec.long)