Merge pull request #30 from alexflint/scalar_pointers

add support for pointers and TextUnmarshaler
This commit is contained in:
Alex Flint 2016-01-23 21:03:39 -08:00
commit c0809e537f
4 changed files with 201 additions and 35 deletions

View File

@ -4,6 +4,10 @@
## Structured argument parsing for Go ## Structured argument parsing for Go
```shell
go get github.com/alexflint/go-arg
```
Declare the command line arguments your program accepts by defining a struct. Declare the command line arguments your program accepts by defining a struct.
```go ```go
@ -24,16 +28,16 @@ hello true
```go ```go
var args struct { var args struct {
Foo string `arg:"required"` ID int `arg:"required"`
Bar bool Timeout time.Duration
} }
arg.MustParse(&args) arg.MustParse(&args)
``` ```
```shell ```shell
$ ./example $ ./example
usage: example --foo FOO [--bar] usage: example --id ID [--timeout TIMEOUT]
error: --foo is required error: --id is required
``` ```
### Positional arguments ### Positional arguments
@ -161,10 +165,51 @@ usage: samples [--foo FOO] [--bar BAR]
error: you must provide one of --foo and --bar error: you must provide one of --foo and --bar
``` ```
### Installation ### Custom parsing
You can implement your own argument parser by implementing `encoding.TextUnmarshaler`:
```go
package main
import (
"fmt"
"strings"
"github.com/alexflint/go-arg"
)
// Accepts command line arguments of the form "head.tail"
type NameDotName struct {
Head, Tail string
}
func (n *NameDotName) UnmarshalText(b []byte) error {
s := string(b)
pos := strings.Index(s, ".")
if pos == -1 {
return fmt.Errorf("missing period in %s", s)
}
n.Head = s[:pos]
n.Tail = s[pos+1:]
return nil
}
func main() {
var args struct {
Name *NameDotName
}
arg.MustParse(&args)
fmt.Printf("%#v\n", args.Name)
}
```
```shell ```shell
go get github.com/alexflint/go-arg $ ./example --name=foo.bar
&main.NameDotName{Head:"foo", Tail:"bar"}
$ ./example --name=oops
usage: example [--name NAME]
error: error processing --name: missing period in "oops"
``` ```
### Documentation ### Documentation

View File

@ -1,6 +1,7 @@
package arg package arg
import ( import (
"encoding"
"errors" "errors"
"fmt" "fmt"
"os" "os"
@ -20,11 +21,15 @@ type spec struct {
env string env string
wasPresent bool wasPresent bool
isBool bool isBool bool
fieldName string // for generating helpful errors
} }
// ErrHelp indicates that -h or --help were provided // ErrHelp indicates that -h or --help were provided
var ErrHelp = errors.New("help requested by user") var ErrHelp = errors.New("help requested by user")
// The TextUnmarshaler type in reflection form
var textUnsmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
// MustParse processes command line arguments and exits upon failure // MustParse processes command line arguments and exits upon failure
func MustParse(dest ...interface{}) *Parser { func MustParse(dest ...interface{}) *Parser {
p, err := NewParser(dest...) p, err := NewParser(dest...)
@ -80,31 +85,42 @@ func NewParser(dests ...interface{}) (*Parser, error) {
} }
spec := spec{ spec := spec{
long: strings.ToLower(field.Name), long: strings.ToLower(field.Name),
dest: v.Field(i), dest: v.Field(i),
fieldName: t.Name() + "." + field.Name,
} }
// Get the scalar type for this field // Check whether this field is supported. It's good to do this here rather than
scalarType := field.Type // wait until setScalar because it means that a program with invalid argument
if scalarType.Kind() == reflect.Slice { // fields will always fail regardless of whether the arguments it recieved happend
spec.multiple = true // to exercise those fields.
scalarType = scalarType.Elem() if !field.Type.Implements(textUnsmarshalerType) {
scalarType := field.Type
// Look inside pointer types
if scalarType.Kind() == reflect.Ptr {
scalarType = scalarType.Elem()
}
// Check for bool
if scalarType.Kind() == reflect.Bool {
spec.isBool = true
}
// Look inside slice types
if scalarType.Kind() == reflect.Slice {
spec.multiple = true
scalarType = scalarType.Elem()
}
// Look inside pointer types (again, in case of []*Type)
if scalarType.Kind() == reflect.Ptr { if scalarType.Kind() == reflect.Ptr {
scalarType = scalarType.Elem() scalarType = scalarType.Elem()
} }
}
// Check for unsupported types // Check for unsupported types
switch scalarType.Kind() { switch scalarType.Kind() {
case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
reflect.Map, reflect.Ptr, reflect.Struct, reflect.Map, reflect.Ptr, reflect.Struct,
reflect.Complex64, reflect.Complex128: reflect.Complex64, reflect.Complex128:
return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind()) return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
} }
// Specify that it is a bool for usage
if scalarType.Kind() == reflect.Bool {
spec.isBool = true
} }
// Look at the tag // Look at the tag
@ -248,7 +264,8 @@ func process(specs []*spec, args []string) error {
} }
// if it's a flag and it has no value then set the value to true // if it's a flag and it has no value then set the value to true
if spec.dest.Kind() == reflect.Bool && value == "" { // use isBool because this takes account of TextUnmarshaler
if spec.isBool && value == "" {
value = "true" value = "true"
} }

View File

@ -15,7 +15,11 @@ func parse(cmdline string, dest interface{}) error {
if err != nil { if err != nil {
return err return err
} }
return p.Parse(strings.Split(cmdline, " ")) var parts []string
if len(cmdline) > 0 {
parts = strings.Split(cmdline, " ")
}
return p.Parse(parts)
} }
func TestString(t *testing.T) { func TestString(t *testing.T) {
@ -71,6 +75,25 @@ func TestInvalidDuration(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func TestIntPtr(t *testing.T) {
var args struct {
Foo *int
}
err := parse("--foo 123", &args)
require.NoError(t, err)
require.NotNil(t, args.Foo)
assert.Equal(t, 123, *args.Foo)
}
func TestIntPtrNotPresent(t *testing.T) {
var args struct {
Foo *int
}
err := parse("", &args)
require.NoError(t, err)
assert.Nil(t, args.Foo)
}
func TestMixed(t *testing.T) { func TestMixed(t *testing.T) {
var args struct { var args struct {
Foo string `arg:"-f"` Foo string `arg:"-f"`
@ -359,6 +382,14 @@ func TestUnsupportedType(t *testing.T) {
} }
func TestUnsupportedSliceElement(t *testing.T) { func TestUnsupportedSliceElement(t *testing.T) {
var args struct {
Foo []interface{}
}
err := parse("--foo 3", &args)
assert.Error(t, err)
}
func TestUnsupportedSliceElementMissingValue(t *testing.T) {
var args struct { var args struct {
Foo []interface{} Foo []interface{}
} }
@ -452,3 +483,61 @@ func TestEnvironmentVariableRequired(t *testing.T) {
MustParse(&args) MustParse(&args)
assert.Equal(t, "bar", args.Foo) assert.Equal(t, "bar", args.Foo)
} }
type textUnmarshaler struct {
val int
}
func (f *textUnmarshaler) UnmarshalText(b []byte) error {
f.val = len(b)
return nil
}
func TestTextUnmarshaler(t *testing.T) {
// fields that implement TextUnmarshaler should be parsed using that interface
var args struct {
Foo *textUnmarshaler
}
err := parse("--foo abc", &args)
require.NoError(t, err)
assert.Equal(t, 3, args.Foo.val)
}
type boolUnmarshaler bool
func (p *boolUnmarshaler) UnmarshalText(b []byte) error {
*p = len(b)%2 == 0
return nil
}
func TestBoolUnmarhsaler(t *testing.T) {
// test that a bool type that implements TextUnmarshaler is
// handled as a TextUnmarshaler not as a bool
var args struct {
Foo *boolUnmarshaler
}
err := parse("--foo ab", &args)
require.NoError(t, err)
assert.EqualValues(t, true, *args.Foo)
}
type sliceUnmarshaler []int
func (p *sliceUnmarshaler) UnmarshalText(b []byte) error {
*p = sliceUnmarshaler{len(b)}
return nil
}
func TestSliceUnmarhsaler(t *testing.T) {
// test that a slice type that implements TextUnmarshaler is
// handled as a TextUnmarshaler not as a slice
var args struct {
Foo *sliceUnmarshaler
Bar string `arg:"positional"`
}
err := parse("--foo abcde xyz", &args)
require.NoError(t, err)
require.Len(t, *args.Foo, 1)
assert.EqualValues(t, 5, (*args.Foo)[0])
assert.Equal(t, "xyz", args.Bar)
}

View File

@ -8,19 +8,33 @@ import (
"time" "time"
) )
var (
durationType = reflect.TypeOf(time.Duration(0))
textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
)
// set a value from a string // set a value from a string
func setScalar(v reflect.Value, s string) error { func setScalar(v reflect.Value, s string) error {
if !v.CanSet() { if !v.CanSet() {
return fmt.Errorf("field is not exported") return fmt.Errorf("field is not exported")
} }
// If we have a time.Duration then use time.ParseDuration // If we have a nil pointer then allocate a new object
if v.Type() == durationType { if v.Kind() == reflect.Ptr && v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
// Get the object as an interface
scalar := v.Interface()
// If it implements encoding.TextUnmarshaler then use that
if scalar, ok := scalar.(encoding.TextUnmarshaler); ok {
return scalar.UnmarshalText([]byte(s))
}
// If we have a pointer then dereference it
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
// Switch on concrete type
switch scalar.(type) {
case time.Duration:
x, err := time.ParseDuration(s) x, err := time.ParseDuration(s)
if err != nil { if err != nil {
return err return err
@ -29,6 +43,7 @@ func setScalar(v reflect.Value, s string) error {
return nil return nil
} }
// Switch on kind so that we can handle derived types
switch v.Kind() { switch v.Kind() {
case reflect.String: case reflect.String:
v.SetString(s) v.SetString(s)
@ -57,7 +72,7 @@ func setScalar(v reflect.Value, s string) error {
} }
v.SetFloat(x) v.SetFloat(x)
default: default:
return fmt.Errorf("not a scalar type: %s", v.Kind()) return fmt.Errorf("cannot parse argument into %s", v.Type().String())
} }
return nil return nil
} }