Merge pull request #30 from alexflint/scalar_pointers
add support for pointers and TextUnmarshaler
This commit is contained in:
commit
c0809e537f
57
README.md
57
README.md
|
@ -4,6 +4,10 @@
|
||||||
|
|
||||||
## Structured argument parsing for Go
|
## Structured argument parsing for Go
|
||||||
|
|
||||||
|
```shell
|
||||||
|
go get github.com/alexflint/go-arg
|
||||||
|
```
|
||||||
|
|
||||||
Declare the command line arguments your program accepts by defining a struct.
|
Declare the command line arguments your program accepts by defining a struct.
|
||||||
|
|
||||||
```go
|
```go
|
||||||
|
@ -24,16 +28,16 @@ hello true
|
||||||
|
|
||||||
```go
|
```go
|
||||||
var args struct {
|
var args struct {
|
||||||
Foo string `arg:"required"`
|
ID int `arg:"required"`
|
||||||
Bar bool
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
arg.MustParse(&args)
|
arg.MustParse(&args)
|
||||||
```
|
```
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
$ ./example
|
$ ./example
|
||||||
usage: example --foo FOO [--bar]
|
usage: example --id ID [--timeout TIMEOUT]
|
||||||
error: --foo is required
|
error: --id is required
|
||||||
```
|
```
|
||||||
|
|
||||||
### Positional arguments
|
### Positional arguments
|
||||||
|
@ -161,10 +165,51 @@ usage: samples [--foo FOO] [--bar BAR]
|
||||||
error: you must provide one of --foo and --bar
|
error: you must provide one of --foo and --bar
|
||||||
```
|
```
|
||||||
|
|
||||||
### Installation
|
### Custom parsing
|
||||||
|
|
||||||
|
You can implement your own argument parser by implementing `encoding.TextUnmarshaler`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/alexflint/go-arg"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Accepts command line arguments of the form "head.tail"
|
||||||
|
type NameDotName struct {
|
||||||
|
Head, Tail string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NameDotName) UnmarshalText(b []byte) error {
|
||||||
|
s := string(b)
|
||||||
|
pos := strings.Index(s, ".")
|
||||||
|
if pos == -1 {
|
||||||
|
return fmt.Errorf("missing period in %s", s)
|
||||||
|
}
|
||||||
|
n.Head = s[:pos]
|
||||||
|
n.Tail = s[pos+1:]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var args struct {
|
||||||
|
Name *NameDotName
|
||||||
|
}
|
||||||
|
arg.MustParse(&args)
|
||||||
|
fmt.Printf("%#v\n", args.Name)
|
||||||
|
}
|
||||||
|
```
|
||||||
```shell
|
```shell
|
||||||
go get github.com/alexflint/go-arg
|
$ ./example --name=foo.bar
|
||||||
|
&main.NameDotName{Head:"foo", Tail:"bar"}
|
||||||
|
|
||||||
|
$ ./example --name=oops
|
||||||
|
usage: example [--name NAME]
|
||||||
|
error: error processing --name: missing period in "oops"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Documentation
|
### Documentation
|
||||||
|
|
57
parse.go
57
parse.go
|
@ -1,6 +1,7 @@
|
||||||
package arg
|
package arg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
@ -20,11 +21,15 @@ type spec struct {
|
||||||
env string
|
env string
|
||||||
wasPresent bool
|
wasPresent bool
|
||||||
isBool bool
|
isBool bool
|
||||||
|
fieldName string // for generating helpful errors
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrHelp indicates that -h or --help were provided
|
// ErrHelp indicates that -h or --help were provided
|
||||||
var ErrHelp = errors.New("help requested by user")
|
var ErrHelp = errors.New("help requested by user")
|
||||||
|
|
||||||
|
// The TextUnmarshaler type in reflection form
|
||||||
|
var textUnsmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
|
||||||
|
|
||||||
// MustParse processes command line arguments and exits upon failure
|
// MustParse processes command line arguments and exits upon failure
|
||||||
func MustParse(dest ...interface{}) *Parser {
|
func MustParse(dest ...interface{}) *Parser {
|
||||||
p, err := NewParser(dest...)
|
p, err := NewParser(dest...)
|
||||||
|
@ -80,31 +85,42 @@ func NewParser(dests ...interface{}) (*Parser, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
spec := spec{
|
spec := spec{
|
||||||
long: strings.ToLower(field.Name),
|
long: strings.ToLower(field.Name),
|
||||||
dest: v.Field(i),
|
dest: v.Field(i),
|
||||||
|
fieldName: t.Name() + "." + field.Name,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the scalar type for this field
|
// Check whether this field is supported. It's good to do this here rather than
|
||||||
scalarType := field.Type
|
// wait until setScalar because it means that a program with invalid argument
|
||||||
if scalarType.Kind() == reflect.Slice {
|
// fields will always fail regardless of whether the arguments it recieved happend
|
||||||
spec.multiple = true
|
// to exercise those fields.
|
||||||
scalarType = scalarType.Elem()
|
if !field.Type.Implements(textUnsmarshalerType) {
|
||||||
|
scalarType := field.Type
|
||||||
|
// Look inside pointer types
|
||||||
|
if scalarType.Kind() == reflect.Ptr {
|
||||||
|
scalarType = scalarType.Elem()
|
||||||
|
}
|
||||||
|
// Check for bool
|
||||||
|
if scalarType.Kind() == reflect.Bool {
|
||||||
|
spec.isBool = true
|
||||||
|
}
|
||||||
|
// Look inside slice types
|
||||||
|
if scalarType.Kind() == reflect.Slice {
|
||||||
|
spec.multiple = true
|
||||||
|
scalarType = scalarType.Elem()
|
||||||
|
}
|
||||||
|
// Look inside pointer types (again, in case of []*Type)
|
||||||
if scalarType.Kind() == reflect.Ptr {
|
if scalarType.Kind() == reflect.Ptr {
|
||||||
scalarType = scalarType.Elem()
|
scalarType = scalarType.Elem()
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Check for unsupported types
|
// Check for unsupported types
|
||||||
switch scalarType.Kind() {
|
switch scalarType.Kind() {
|
||||||
case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
|
case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
|
||||||
reflect.Map, reflect.Ptr, reflect.Struct,
|
reflect.Map, reflect.Ptr, reflect.Struct,
|
||||||
reflect.Complex64, reflect.Complex128:
|
reflect.Complex64, reflect.Complex128:
|
||||||
return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
|
return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Specify that it is a bool for usage
|
|
||||||
if scalarType.Kind() == reflect.Bool {
|
|
||||||
spec.isBool = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look at the tag
|
// Look at the tag
|
||||||
|
@ -248,7 +264,8 @@ func process(specs []*spec, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// if it's a flag and it has no value then set the value to true
|
// if it's a flag and it has no value then set the value to true
|
||||||
if spec.dest.Kind() == reflect.Bool && value == "" {
|
// use isBool because this takes account of TextUnmarshaler
|
||||||
|
if spec.isBool && value == "" {
|
||||||
value = "true"
|
value = "true"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,11 @@ func parse(cmdline string, dest interface{}) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return p.Parse(strings.Split(cmdline, " "))
|
var parts []string
|
||||||
|
if len(cmdline) > 0 {
|
||||||
|
parts = strings.Split(cmdline, " ")
|
||||||
|
}
|
||||||
|
return p.Parse(parts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestString(t *testing.T) {
|
func TestString(t *testing.T) {
|
||||||
|
@ -71,6 +75,25 @@ func TestInvalidDuration(t *testing.T) {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIntPtr(t *testing.T) {
|
||||||
|
var args struct {
|
||||||
|
Foo *int
|
||||||
|
}
|
||||||
|
err := parse("--foo 123", &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, args.Foo)
|
||||||
|
assert.Equal(t, 123, *args.Foo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntPtrNotPresent(t *testing.T) {
|
||||||
|
var args struct {
|
||||||
|
Foo *int
|
||||||
|
}
|
||||||
|
err := parse("", &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, args.Foo)
|
||||||
|
}
|
||||||
|
|
||||||
func TestMixed(t *testing.T) {
|
func TestMixed(t *testing.T) {
|
||||||
var args struct {
|
var args struct {
|
||||||
Foo string `arg:"-f"`
|
Foo string `arg:"-f"`
|
||||||
|
@ -359,6 +382,14 @@ func TestUnsupportedType(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnsupportedSliceElement(t *testing.T) {
|
func TestUnsupportedSliceElement(t *testing.T) {
|
||||||
|
var args struct {
|
||||||
|
Foo []interface{}
|
||||||
|
}
|
||||||
|
err := parse("--foo 3", &args)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnsupportedSliceElementMissingValue(t *testing.T) {
|
||||||
var args struct {
|
var args struct {
|
||||||
Foo []interface{}
|
Foo []interface{}
|
||||||
}
|
}
|
||||||
|
@ -452,3 +483,61 @@ func TestEnvironmentVariableRequired(t *testing.T) {
|
||||||
MustParse(&args)
|
MustParse(&args)
|
||||||
assert.Equal(t, "bar", args.Foo)
|
assert.Equal(t, "bar", args.Foo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type textUnmarshaler struct {
|
||||||
|
val int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *textUnmarshaler) UnmarshalText(b []byte) error {
|
||||||
|
f.val = len(b)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTextUnmarshaler(t *testing.T) {
|
||||||
|
// fields that implement TextUnmarshaler should be parsed using that interface
|
||||||
|
var args struct {
|
||||||
|
Foo *textUnmarshaler
|
||||||
|
}
|
||||||
|
err := parse("--foo abc", &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 3, args.Foo.val)
|
||||||
|
}
|
||||||
|
|
||||||
|
type boolUnmarshaler bool
|
||||||
|
|
||||||
|
func (p *boolUnmarshaler) UnmarshalText(b []byte) error {
|
||||||
|
*p = len(b)%2 == 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBoolUnmarhsaler(t *testing.T) {
|
||||||
|
// test that a bool type that implements TextUnmarshaler is
|
||||||
|
// handled as a TextUnmarshaler not as a bool
|
||||||
|
var args struct {
|
||||||
|
Foo *boolUnmarshaler
|
||||||
|
}
|
||||||
|
err := parse("--foo ab", &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.EqualValues(t, true, *args.Foo)
|
||||||
|
}
|
||||||
|
|
||||||
|
type sliceUnmarshaler []int
|
||||||
|
|
||||||
|
func (p *sliceUnmarshaler) UnmarshalText(b []byte) error {
|
||||||
|
*p = sliceUnmarshaler{len(b)}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSliceUnmarhsaler(t *testing.T) {
|
||||||
|
// test that a slice type that implements TextUnmarshaler is
|
||||||
|
// handled as a TextUnmarshaler not as a slice
|
||||||
|
var args struct {
|
||||||
|
Foo *sliceUnmarshaler
|
||||||
|
Bar string `arg:"positional"`
|
||||||
|
}
|
||||||
|
err := parse("--foo abcde xyz", &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, *args.Foo, 1)
|
||||||
|
assert.EqualValues(t, 5, (*args.Foo)[0])
|
||||||
|
assert.Equal(t, "xyz", args.Bar)
|
||||||
|
}
|
||||||
|
|
31
scalar.go
31
scalar.go
|
@ -8,19 +8,33 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
durationType = reflect.TypeOf(time.Duration(0))
|
|
||||||
textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
|
|
||||||
)
|
|
||||||
|
|
||||||
// set a value from a string
|
// set a value from a string
|
||||||
func setScalar(v reflect.Value, s string) error {
|
func setScalar(v reflect.Value, s string) error {
|
||||||
if !v.CanSet() {
|
if !v.CanSet() {
|
||||||
return fmt.Errorf("field is not exported")
|
return fmt.Errorf("field is not exported")
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have a time.Duration then use time.ParseDuration
|
// If we have a nil pointer then allocate a new object
|
||||||
if v.Type() == durationType {
|
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||||
|
v.Set(reflect.New(v.Type().Elem()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the object as an interface
|
||||||
|
scalar := v.Interface()
|
||||||
|
|
||||||
|
// If it implements encoding.TextUnmarshaler then use that
|
||||||
|
if scalar, ok := scalar.(encoding.TextUnmarshaler); ok {
|
||||||
|
return scalar.UnmarshalText([]byte(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have a pointer then dereference it
|
||||||
|
if v.Kind() == reflect.Ptr {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch on concrete type
|
||||||
|
switch scalar.(type) {
|
||||||
|
case time.Duration:
|
||||||
x, err := time.ParseDuration(s)
|
x, err := time.ParseDuration(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -29,6 +43,7 @@ func setScalar(v reflect.Value, s string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Switch on kind so that we can handle derived types
|
||||||
switch v.Kind() {
|
switch v.Kind() {
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
v.SetString(s)
|
v.SetString(s)
|
||||||
|
@ -57,7 +72,7 @@ func setScalar(v reflect.Value, s string) error {
|
||||||
}
|
}
|
||||||
v.SetFloat(x)
|
v.SetFloat(x)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("not a scalar type: %s", v.Kind())
|
return fmt.Errorf("cannot parse argument into %s", v.Type().String())
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue