add support for pointers and TextUnmarshaler

This commit is contained in:
Alex Flint 2016-01-23 19:40:15 -08:00
parent 64a4bab550
commit 865cc5a973
3 changed files with 150 additions and 29 deletions

View File

@ -1,6 +1,7 @@
package arg
import (
"encoding"
"errors"
"fmt"
"os"
@ -20,11 +21,15 @@ type spec struct {
env string
wasPresent bool
isBool bool
fieldName string // for generating helpful errors
}
// ErrHelp indicates that -h or --help were provided
var ErrHelp = errors.New("help requested by user")
// The TextUnmarshaler type in reflection form
var textUnsmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
// MustParse processes command line arguments and exits upon failure
func MustParse(dest ...interface{}) *Parser {
p, err := NewParser(dest...)
@ -80,31 +85,42 @@ func NewParser(dests ...interface{}) (*Parser, error) {
}
spec := spec{
long: strings.ToLower(field.Name),
dest: v.Field(i),
long: strings.ToLower(field.Name),
dest: v.Field(i),
fieldName: t.Name() + "." + field.Name,
}
// Get the scalar type for this field
scalarType := field.Type
if scalarType.Kind() == reflect.Slice {
spec.multiple = true
scalarType = scalarType.Elem()
// Check whether this field is supported. It's good to do this here rather than
// wait until setScalar because it means that a program with invalid argument
// fields will always fail regardless of whether the arguments it recieved happend
// to exercise those fields.
if !field.Type.Implements(textUnsmarshalerType) {
scalarType := field.Type
// Look inside pointer types
if scalarType.Kind() == reflect.Ptr {
scalarType = scalarType.Elem()
}
// Check for bool
if scalarType.Kind() == reflect.Bool {
spec.isBool = true
}
// Look inside slice types
if scalarType.Kind() == reflect.Slice {
spec.multiple = true
scalarType = scalarType.Elem()
}
// Look inside pointer types (again, in case of []*Type)
if scalarType.Kind() == reflect.Ptr {
scalarType = scalarType.Elem()
}
}
// Check for unsupported types
switch scalarType.Kind() {
case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
reflect.Map, reflect.Ptr, reflect.Struct,
reflect.Complex64, reflect.Complex128:
return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
}
// Specify that it is a bool for usage
if scalarType.Kind() == reflect.Bool {
spec.isBool = true
// Check for unsupported types
switch scalarType.Kind() {
case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
reflect.Map, reflect.Ptr, reflect.Struct,
reflect.Complex64, reflect.Complex128:
return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
}
}
// Look at the tag
@ -248,7 +264,8 @@ func process(specs []*spec, args []string) error {
}
// if it's a flag and it has no value then set the value to true
if spec.dest.Kind() == reflect.Bool && value == "" {
// use isBool because this takes account of TextUnmarshaler
if spec.isBool && value == "" {
value = "true"
}

View File

@ -15,7 +15,11 @@ func parse(cmdline string, dest interface{}) error {
if err != nil {
return err
}
return p.Parse(strings.Split(cmdline, " "))
var parts []string
if len(cmdline) > 0 {
parts = strings.Split(cmdline, " ")
}
return p.Parse(parts)
}
func TestString(t *testing.T) {
@ -71,6 +75,25 @@ func TestInvalidDuration(t *testing.T) {
require.Error(t, err)
}
func TestIntPtr(t *testing.T) {
var args struct {
Foo *int
}
err := parse("--foo 123", &args)
require.NoError(t, err)
require.NotNil(t, args.Foo)
assert.Equal(t, 123, *args.Foo)
}
func TestIntPtrNotPresent(t *testing.T) {
var args struct {
Foo *int
}
err := parse("", &args)
require.NoError(t, err)
assert.Nil(t, args.Foo)
}
func TestMixed(t *testing.T) {
var args struct {
Foo string `arg:"-f"`
@ -359,6 +382,14 @@ func TestUnsupportedType(t *testing.T) {
}
func TestUnsupportedSliceElement(t *testing.T) {
var args struct {
Foo []interface{}
}
err := parse("--foo 3", &args)
assert.Error(t, err)
}
func TestUnsupportedSliceElementMissingValue(t *testing.T) {
var args struct {
Foo []interface{}
}
@ -452,3 +483,61 @@ func TestEnvironmentVariableRequired(t *testing.T) {
MustParse(&args)
assert.Equal(t, "bar", args.Foo)
}
type textUnmarshaler struct {
val int
}
func (f *textUnmarshaler) UnmarshalText(b []byte) error {
f.val = len(b)
return nil
}
func TestTextUnmarshaler(t *testing.T) {
// fields that implement TextUnmarshaler should be parsed using that interface
var args struct {
Foo *textUnmarshaler
}
err := parse("--foo abc", &args)
require.NoError(t, err)
assert.Equal(t, 3, args.Foo.val)
}
type boolUnmarshaler bool
func (p *boolUnmarshaler) UnmarshalText(b []byte) error {
*p = len(b)%2 == 0
return nil
}
func TestBoolUnmarhsaler(t *testing.T) {
// test that a bool type that implements TextUnmarshaler is
// handled as a TextUnmarshaler not as a bool
var args struct {
Foo *boolUnmarshaler
}
err := parse("--foo ab", &args)
require.NoError(t, err)
assert.EqualValues(t, true, *args.Foo)
}
type sliceUnmarshaler []int
func (p *sliceUnmarshaler) UnmarshalText(b []byte) error {
*p = sliceUnmarshaler{len(b)}
return nil
}
func TestSliceUnmarhsaler(t *testing.T) {
// test that a slice type that implements TextUnmarshaler is
// handled as a TextUnmarshaler not as a slice
var args struct {
Foo *sliceUnmarshaler
Bar string `arg:"positional"`
}
err := parse("--foo abcde xyz", &args)
require.NoError(t, err)
require.Len(t, *args.Foo, 1)
assert.EqualValues(t, 5, (*args.Foo)[0])
assert.Equal(t, "xyz", args.Bar)
}

View File

@ -8,19 +8,33 @@ import (
"time"
)
var (
durationType = reflect.TypeOf(time.Duration(0))
textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
)
// set a value from a string
func setScalar(v reflect.Value, s string) error {
if !v.CanSet() {
return fmt.Errorf("field is not exported")
}
// If we have a time.Duration then use time.ParseDuration
if v.Type() == durationType {
// If we have a nil pointer then allocate a new object
if v.Kind() == reflect.Ptr && v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
// Get the object as an interface
scalar := v.Interface()
// If it implements encoding.TextUnmarshaler then use that
if scalar, ok := scalar.(encoding.TextUnmarshaler); ok {
return scalar.UnmarshalText([]byte(s))
}
// If we have a pointer then dereference it
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
// Switch on concrete type
switch scalar.(type) {
case time.Duration:
x, err := time.ParseDuration(s)
if err != nil {
return err
@ -29,6 +43,7 @@ func setScalar(v reflect.Value, s string) error {
return nil
}
// Switch on kind so that we can handle derived types
switch v.Kind() {
case reflect.String:
v.SetString(s)
@ -57,7 +72,7 @@ func setScalar(v reflect.Value, s string) error {
}
v.SetFloat(x)
default:
return fmt.Errorf("not a scalar type: %s", v.Kind())
return fmt.Errorf("cannot parse argument into %s", v.Type().String())
}
return nil
}