add support for pointers and TextUnmarshaler
This commit is contained in:
parent
64a4bab550
commit
865cc5a973
35
parse.go
35
parse.go
|
@ -1,6 +1,7 @@
|
|||
package arg
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -20,11 +21,15 @@ type spec struct {
|
|||
env string
|
||||
wasPresent bool
|
||||
isBool bool
|
||||
fieldName string // for generating helpful errors
|
||||
}
|
||||
|
||||
// ErrHelp indicates that -h or --help were provided
|
||||
var ErrHelp = errors.New("help requested by user")
|
||||
|
||||
// The TextUnmarshaler type in reflection form
|
||||
var textUnsmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
|
||||
|
||||
// MustParse processes command line arguments and exits upon failure
|
||||
func MustParse(dest ...interface{}) *Parser {
|
||||
p, err := NewParser(dest...)
|
||||
|
@ -82,16 +87,31 @@ func NewParser(dests ...interface{}) (*Parser, error) {
|
|||
spec := spec{
|
||||
long: strings.ToLower(field.Name),
|
||||
dest: v.Field(i),
|
||||
fieldName: t.Name() + "." + field.Name,
|
||||
}
|
||||
|
||||
// Get the scalar type for this field
|
||||
// Check whether this field is supported. It's good to do this here rather than
|
||||
// wait until setScalar because it means that a program with invalid argument
|
||||
// fields will always fail regardless of whether the arguments it recieved happend
|
||||
// to exercise those fields.
|
||||
if !field.Type.Implements(textUnsmarshalerType) {
|
||||
scalarType := field.Type
|
||||
if scalarType.Kind() == reflect.Slice {
|
||||
spec.multiple = true
|
||||
scalarType = scalarType.Elem()
|
||||
// Look inside pointer types
|
||||
if scalarType.Kind() == reflect.Ptr {
|
||||
scalarType = scalarType.Elem()
|
||||
}
|
||||
// Check for bool
|
||||
if scalarType.Kind() == reflect.Bool {
|
||||
spec.isBool = true
|
||||
}
|
||||
// Look inside slice types
|
||||
if scalarType.Kind() == reflect.Slice {
|
||||
spec.multiple = true
|
||||
scalarType = scalarType.Elem()
|
||||
}
|
||||
// Look inside pointer types (again, in case of []*Type)
|
||||
if scalarType.Kind() == reflect.Ptr {
|
||||
scalarType = scalarType.Elem()
|
||||
}
|
||||
|
||||
// Check for unsupported types
|
||||
|
@ -101,10 +121,6 @@ func NewParser(dests ...interface{}) (*Parser, error) {
|
|||
reflect.Complex64, reflect.Complex128:
|
||||
return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
|
||||
}
|
||||
|
||||
// Specify that it is a bool for usage
|
||||
if scalarType.Kind() == reflect.Bool {
|
||||
spec.isBool = true
|
||||
}
|
||||
|
||||
// Look at the tag
|
||||
|
@ -248,7 +264,8 @@ func process(specs []*spec, args []string) error {
|
|||
}
|
||||
|
||||
// if it's a flag and it has no value then set the value to true
|
||||
if spec.dest.Kind() == reflect.Bool && value == "" {
|
||||
// use isBool because this takes account of TextUnmarshaler
|
||||
if spec.isBool && value == "" {
|
||||
value = "true"
|
||||
}
|
||||
|
||||
|
|
|
@ -15,7 +15,11 @@ func parse(cmdline string, dest interface{}) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.Parse(strings.Split(cmdline, " "))
|
||||
var parts []string
|
||||
if len(cmdline) > 0 {
|
||||
parts = strings.Split(cmdline, " ")
|
||||
}
|
||||
return p.Parse(parts)
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
|
@ -71,6 +75,25 @@ func TestInvalidDuration(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIntPtr(t *testing.T) {
|
||||
var args struct {
|
||||
Foo *int
|
||||
}
|
||||
err := parse("--foo 123", &args)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, args.Foo)
|
||||
assert.Equal(t, 123, *args.Foo)
|
||||
}
|
||||
|
||||
func TestIntPtrNotPresent(t *testing.T) {
|
||||
var args struct {
|
||||
Foo *int
|
||||
}
|
||||
err := parse("", &args)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, args.Foo)
|
||||
}
|
||||
|
||||
func TestMixed(t *testing.T) {
|
||||
var args struct {
|
||||
Foo string `arg:"-f"`
|
||||
|
@ -359,6 +382,14 @@ func TestUnsupportedType(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUnsupportedSliceElement(t *testing.T) {
|
||||
var args struct {
|
||||
Foo []interface{}
|
||||
}
|
||||
err := parse("--foo 3", &args)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUnsupportedSliceElementMissingValue(t *testing.T) {
|
||||
var args struct {
|
||||
Foo []interface{}
|
||||
}
|
||||
|
@ -452,3 +483,61 @@ func TestEnvironmentVariableRequired(t *testing.T) {
|
|||
MustParse(&args)
|
||||
assert.Equal(t, "bar", args.Foo)
|
||||
}
|
||||
|
||||
type textUnmarshaler struct {
|
||||
val int
|
||||
}
|
||||
|
||||
func (f *textUnmarshaler) UnmarshalText(b []byte) error {
|
||||
f.val = len(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestTextUnmarshaler(t *testing.T) {
|
||||
// fields that implement TextUnmarshaler should be parsed using that interface
|
||||
var args struct {
|
||||
Foo *textUnmarshaler
|
||||
}
|
||||
err := parse("--foo abc", &args)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, args.Foo.val)
|
||||
}
|
||||
|
||||
type boolUnmarshaler bool
|
||||
|
||||
func (p *boolUnmarshaler) UnmarshalText(b []byte) error {
|
||||
*p = len(b)%2 == 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBoolUnmarhsaler(t *testing.T) {
|
||||
// test that a bool type that implements TextUnmarshaler is
|
||||
// handled as a TextUnmarshaler not as a bool
|
||||
var args struct {
|
||||
Foo *boolUnmarshaler
|
||||
}
|
||||
err := parse("--foo ab", &args)
|
||||
require.NoError(t, err)
|
||||
assert.EqualValues(t, true, *args.Foo)
|
||||
}
|
||||
|
||||
type sliceUnmarshaler []int
|
||||
|
||||
func (p *sliceUnmarshaler) UnmarshalText(b []byte) error {
|
||||
*p = sliceUnmarshaler{len(b)}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSliceUnmarhsaler(t *testing.T) {
|
||||
// test that a slice type that implements TextUnmarshaler is
|
||||
// handled as a TextUnmarshaler not as a slice
|
||||
var args struct {
|
||||
Foo *sliceUnmarshaler
|
||||
Bar string `arg:"positional"`
|
||||
}
|
||||
err := parse("--foo abcde xyz", &args)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, *args.Foo, 1)
|
||||
assert.EqualValues(t, 5, (*args.Foo)[0])
|
||||
assert.Equal(t, "xyz", args.Bar)
|
||||
}
|
||||
|
|
31
scalar.go
31
scalar.go
|
@ -8,19 +8,33 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
durationType = reflect.TypeOf(time.Duration(0))
|
||||
textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
|
||||
)
|
||||
|
||||
// set a value from a string
|
||||
func setScalar(v reflect.Value, s string) error {
|
||||
if !v.CanSet() {
|
||||
return fmt.Errorf("field is not exported")
|
||||
}
|
||||
|
||||
// If we have a time.Duration then use time.ParseDuration
|
||||
if v.Type() == durationType {
|
||||
// If we have a nil pointer then allocate a new object
|
||||
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
|
||||
// Get the object as an interface
|
||||
scalar := v.Interface()
|
||||
|
||||
// If it implements encoding.TextUnmarshaler then use that
|
||||
if scalar, ok := scalar.(encoding.TextUnmarshaler); ok {
|
||||
return scalar.UnmarshalText([]byte(s))
|
||||
}
|
||||
|
||||
// If we have a pointer then dereference it
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
// Switch on concrete type
|
||||
switch scalar.(type) {
|
||||
case time.Duration:
|
||||
x, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -29,6 +43,7 @@ func setScalar(v reflect.Value, s string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Switch on kind so that we can handle derived types
|
||||
switch v.Kind() {
|
||||
case reflect.String:
|
||||
v.SetString(s)
|
||||
|
@ -57,7 +72,7 @@ func setScalar(v reflect.Value, s string) error {
|
|||
}
|
||||
v.SetFloat(x)
|
||||
default:
|
||||
return fmt.Errorf("not a scalar type: %s", v.Kind())
|
||||
return fmt.Errorf("cannot parse argument into %s", v.Type().String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue