diff --git a/README.md b/README.md index 3d1d12f..28ff388 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,10 @@ ## Structured argument parsing for Go +```shell +go get github.com/alexflint/go-arg +``` + Declare the command line arguments your program accepts by defining a struct. ```go @@ -24,16 +28,16 @@ hello true ```go var args struct { - Foo string `arg:"required"` - Bar bool + ID int `arg:"required"` + Timeout time.Duration } arg.MustParse(&args) ``` ```shell $ ./example -usage: example --foo FOO [--bar] -error: --foo is required +usage: example --id ID [--timeout TIMEOUT] +error: --id is required ``` ### Positional arguments @@ -161,10 +165,51 @@ usage: samples [--foo FOO] [--bar BAR] error: you must provide one of --foo and --bar ``` -### Installation +### Custom parsing +You can implement your own argument parser by implementing `encoding.TextUnmarshaler`: + +```go +package main + +import ( + "fmt" + "strings" + + "github.com/alexflint/go-arg" +) + +// Accepts command line arguments of the form "head.tail" +type NameDotName struct { + Head, Tail string +} + +func (n *NameDotName) UnmarshalText(b []byte) error { + s := string(b) + pos := strings.Index(s, ".") + if pos == -1 { + return fmt.Errorf("missing period in %s", s) + } + n.Head = s[:pos] + n.Tail = s[pos+1:] + return nil +} + +func main() { + var args struct { + Name *NameDotName + } + arg.MustParse(&args) + fmt.Printf("%#v\n", args.Name) +} +``` ```shell -go get github.com/alexflint/go-arg +$ ./example --name=foo.bar +&main.NameDotName{Head:"foo", Tail:"bar"} + +$ ./example --name=oops +usage: example [--name NAME] +error: error processing --name: missing period in "oops" ``` ### Documentation diff --git a/parse.go b/parse.go index 39eb52c..3895ce9 100644 --- a/parse.go +++ b/parse.go @@ -1,6 +1,7 @@ package arg import ( + "encoding" "errors" "fmt" "os" @@ -20,11 +21,15 @@ type spec struct { env string wasPresent bool isBool bool + fieldName string // for generating helpful errors } // ErrHelp indicates that -h or --help were provided var ErrHelp = errors.New("help requested by user") +// The TextUnmarshaler type in reflection form +var textUnsmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() + // MustParse processes command line arguments and exits upon failure func MustParse(dest ...interface{}) *Parser { p, err := NewParser(dest...) @@ -80,31 +85,42 @@ func NewParser(dests ...interface{}) (*Parser, error) { } spec := spec{ - long: strings.ToLower(field.Name), - dest: v.Field(i), + long: strings.ToLower(field.Name), + dest: v.Field(i), + fieldName: t.Name() + "." + field.Name, } - // Get the scalar type for this field - scalarType := field.Type - if scalarType.Kind() == reflect.Slice { - spec.multiple = true - scalarType = scalarType.Elem() + // Check whether this field is supported. It's good to do this here rather than + // wait until setScalar because it means that a program with invalid argument + // fields will always fail regardless of whether the arguments it recieved happend + // to exercise those fields. + if !field.Type.Implements(textUnsmarshalerType) { + scalarType := field.Type + // Look inside pointer types + if scalarType.Kind() == reflect.Ptr { + scalarType = scalarType.Elem() + } + // Check for bool + if scalarType.Kind() == reflect.Bool { + spec.isBool = true + } + // Look inside slice types + if scalarType.Kind() == reflect.Slice { + spec.multiple = true + scalarType = scalarType.Elem() + } + // Look inside pointer types (again, in case of []*Type) if scalarType.Kind() == reflect.Ptr { scalarType = scalarType.Elem() } - } - // Check for unsupported types - switch scalarType.Kind() { - case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, - reflect.Map, reflect.Ptr, reflect.Struct, - reflect.Complex64, reflect.Complex128: - return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind()) - } - - // Specify that it is a bool for usage - if scalarType.Kind() == reflect.Bool { - spec.isBool = true + // Check for unsupported types + switch scalarType.Kind() { + case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, + reflect.Map, reflect.Ptr, reflect.Struct, + reflect.Complex64, reflect.Complex128: + return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind()) + } } // Look at the tag @@ -248,7 +264,8 @@ func process(specs []*spec, args []string) error { } // if it's a flag and it has no value then set the value to true - if spec.dest.Kind() == reflect.Bool && value == "" { + // use isBool because this takes account of TextUnmarshaler + if spec.isBool && value == "" { value = "true" } diff --git a/parse_test.go b/parse_test.go index c30809d..a915910 100644 --- a/parse_test.go +++ b/parse_test.go @@ -15,7 +15,11 @@ func parse(cmdline string, dest interface{}) error { if err != nil { return err } - return p.Parse(strings.Split(cmdline, " ")) + var parts []string + if len(cmdline) > 0 { + parts = strings.Split(cmdline, " ") + } + return p.Parse(parts) } func TestString(t *testing.T) { @@ -71,6 +75,25 @@ func TestInvalidDuration(t *testing.T) { require.Error(t, err) } +func TestIntPtr(t *testing.T) { + var args struct { + Foo *int + } + err := parse("--foo 123", &args) + require.NoError(t, err) + require.NotNil(t, args.Foo) + assert.Equal(t, 123, *args.Foo) +} + +func TestIntPtrNotPresent(t *testing.T) { + var args struct { + Foo *int + } + err := parse("", &args) + require.NoError(t, err) + assert.Nil(t, args.Foo) +} + func TestMixed(t *testing.T) { var args struct { Foo string `arg:"-f"` @@ -359,6 +382,14 @@ func TestUnsupportedType(t *testing.T) { } func TestUnsupportedSliceElement(t *testing.T) { + var args struct { + Foo []interface{} + } + err := parse("--foo 3", &args) + assert.Error(t, err) +} + +func TestUnsupportedSliceElementMissingValue(t *testing.T) { var args struct { Foo []interface{} } @@ -452,3 +483,61 @@ func TestEnvironmentVariableRequired(t *testing.T) { MustParse(&args) assert.Equal(t, "bar", args.Foo) } + +type textUnmarshaler struct { + val int +} + +func (f *textUnmarshaler) UnmarshalText(b []byte) error { + f.val = len(b) + return nil +} + +func TestTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo *textUnmarshaler + } + err := parse("--foo abc", &args) + require.NoError(t, err) + assert.Equal(t, 3, args.Foo.val) +} + +type boolUnmarshaler bool + +func (p *boolUnmarshaler) UnmarshalText(b []byte) error { + *p = len(b)%2 == 0 + return nil +} + +func TestBoolUnmarhsaler(t *testing.T) { + // test that a bool type that implements TextUnmarshaler is + // handled as a TextUnmarshaler not as a bool + var args struct { + Foo *boolUnmarshaler + } + err := parse("--foo ab", &args) + require.NoError(t, err) + assert.EqualValues(t, true, *args.Foo) +} + +type sliceUnmarshaler []int + +func (p *sliceUnmarshaler) UnmarshalText(b []byte) error { + *p = sliceUnmarshaler{len(b)} + return nil +} + +func TestSliceUnmarhsaler(t *testing.T) { + // test that a slice type that implements TextUnmarshaler is + // handled as a TextUnmarshaler not as a slice + var args struct { + Foo *sliceUnmarshaler + Bar string `arg:"positional"` + } + err := parse("--foo abcde xyz", &args) + require.NoError(t, err) + require.Len(t, *args.Foo, 1) + assert.EqualValues(t, 5, (*args.Foo)[0]) + assert.Equal(t, "xyz", args.Bar) +} diff --git a/scalar.go b/scalar.go index a3bafe4..67b4540 100644 --- a/scalar.go +++ b/scalar.go @@ -8,19 +8,33 @@ import ( "time" ) -var ( - durationType = reflect.TypeOf(time.Duration(0)) - textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() -) - // set a value from a string func setScalar(v reflect.Value, s string) error { if !v.CanSet() { return fmt.Errorf("field is not exported") } - // If we have a time.Duration then use time.ParseDuration - if v.Type() == durationType { + // If we have a nil pointer then allocate a new object + if v.Kind() == reflect.Ptr && v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + // Get the object as an interface + scalar := v.Interface() + + // If it implements encoding.TextUnmarshaler then use that + if scalar, ok := scalar.(encoding.TextUnmarshaler); ok { + return scalar.UnmarshalText([]byte(s)) + } + + // If we have a pointer then dereference it + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + // Switch on concrete type + switch scalar.(type) { + case time.Duration: x, err := time.ParseDuration(s) if err != nil { return err @@ -29,6 +43,7 @@ func setScalar(v reflect.Value, s string) error { return nil } + // Switch on kind so that we can handle derived types switch v.Kind() { case reflect.String: v.SetString(s) @@ -57,7 +72,7 @@ func setScalar(v reflect.Value, s string) error { } v.SetFloat(x) default: - return fmt.Errorf("not a scalar type: %s", v.Kind()) + return fmt.Errorf("cannot parse argument into %s", v.Type().String()) } return nil }