positional arguments working

This commit is contained in:
Alex Flint 2015-10-31 17:05:14 -07:00
parent 408290f7c2
commit 8397a40f4c
2 changed files with 109 additions and 13 deletions

View File

@ -2,6 +2,7 @@ package arguments
import ( import (
"fmt" "fmt"
"log"
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
@ -82,6 +83,7 @@ func extractSpec(t reflect.Type) ([]*spec, error) {
// Get the scalar type for this field // Get the scalar type for this field
scalarType := field.Type scalarType := field.Type
log.Println(field.Name, field.Type, field.Type.Kind())
if scalarType.Kind() == reflect.Slice { if scalarType.Kind() == reflect.Slice {
spec.multiple = true spec.multiple = true
scalarType = scalarType.Elem() scalarType = scalarType.Elem()
@ -133,14 +135,17 @@ func extractSpec(t reflect.Type) ([]*spec, error) {
// processArgs processes arguments using a pre-constructed spec // processArgs processes arguments using a pre-constructed spec
func processArgs(dest reflect.Value, specs []*spec, args []string) error { func processArgs(dest reflect.Value, specs []*spec, args []string) error {
// construct a map from arg name to spec // construct a map from --option to spec
specByName := make(map[string]*spec) optionMap := make(map[string]*spec)
for _, spec := range specs { for _, spec := range specs {
if spec.positional {
continue
}
if spec.long != "" { if spec.long != "" {
specByName[spec.long] = spec optionMap[spec.long] = spec
} }
if spec.short != "" { if spec.short != "" {
specByName[spec.short] = spec optionMap[spec.short] = spec
} }
} }
@ -170,7 +175,7 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
} }
// lookup the spec for this option // lookup the spec for this option
spec, ok := specByName[opt] spec, ok := optionMap[opt]
if !ok { if !ok {
return fmt.Errorf("unknown argument %s", arg) return fmt.Errorf("unknown argument %s", arg)
} }
@ -180,13 +185,17 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
if spec.multiple { if spec.multiple {
var values []string var values []string
if value == "" { if value == "" {
for i++; i < len(args) && !strings.HasPrefix(args[i], "-"); i++ { for i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") {
values = append(values, args[i]) values = append(values, args[i+1])
i++
} }
} else { } else {
values = append(values, value) values = append(values, value)
} }
setSlice(dest, spec, values) err := setSlice(dest.Field(spec.index), values)
if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err)
}
continue continue
} }
@ -209,13 +218,38 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
return fmt.Errorf("error processing %s: %v", arg, err) return fmt.Errorf("error processing %s: %v", arg, err)
} }
} }
// process positionals
for _, spec := range specs {
label := strings.ToLower(spec.field.Name)
if spec.positional {
if spec.multiple {
err := setSlice(dest.Field(spec.index), positionals)
if err != nil {
return fmt.Errorf("error processing %s: %v", label, err)
}
positionals = nil
} else if len(positionals) > 0 {
err := setScalar(dest.Field(spec.index), positionals[0])
if err != nil {
return fmt.Errorf("error processing %s: %v", label, err)
}
positionals = positionals[1:]
} else if spec.required {
return fmt.Errorf("%s is required", label)
}
}
}
if len(positionals) > 0 {
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
}
return nil return nil
} }
// validate an argument spec after arguments have been parse // validate an argument spec after arguments have been parse
func validate(spec []*spec) error { func validate(spec []*spec) error {
for _, arg := range spec { for _, arg := range spec {
if arg.required && !arg.wasPresent { if !arg.positional && arg.required && !arg.wasPresent {
return fmt.Errorf("--%s is required", strings.ToLower(arg.field.Name)) return fmt.Errorf("--%s is required", strings.ToLower(arg.field.Name))
} }
} }
@ -223,15 +257,35 @@ func validate(spec []*spec) error {
} }
// parse a value as the apropriate type and store it in the struct // parse a value as the apropriate type and store it in the struct
func setSlice(dest reflect.Value, spec *spec, values []string) error { func setSlice(dest reflect.Value, values []string) error {
// TODO if !dest.CanSet() {
return fmt.Errorf("field is not writable")
}
var ptr bool
elem := dest.Type().Elem()
if elem.Kind() == reflect.Ptr {
ptr = true
elem = elem.Elem()
}
for _, s := range values {
v := reflect.New(elem)
if err := setScalar(v.Elem(), s); err != nil {
return err
}
if ptr {
v = v.Addr()
}
dest.Set(reflect.Append(dest, v.Elem()))
}
return nil return nil
} }
// set a value from a string // set a value from a string
func setScalar(v reflect.Value, s string) error { func setScalar(v reflect.Value, s string) error {
if !v.CanSet() { if !v.CanSet() {
return fmt.Errorf("field is not writable") return fmt.Errorf("field is not exported")
} }
switch v.Kind() { switch v.Kind() {

View File

@ -25,14 +25,16 @@ func TestMixed(t *testing.T) {
var args struct { var args struct {
Foo string `arg:"-f"` Foo string `arg:"-f"`
Bar int Bar int
Baz uint `arg:"positional"`
Ham bool Ham bool
Spam float32 Spam float32
} }
args.Bar = 3 args.Bar = 3
err := ParseFrom(&args, split("-spam=1.2 -ham -f xyz")) err := ParseFrom(&args, split("123 -spam=1.2 -ham -f xyz"))
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "xyz", args.Foo) assert.Equal(t, "xyz", args.Foo)
assert.Equal(t, 3, args.Bar) assert.Equal(t, 3, args.Bar)
assert.Equal(t, uint(123), args.Baz)
assert.Equal(t, true, args.Ham) assert.Equal(t, true, args.Ham)
assert.Equal(t, 1.2, args.Spam) assert.Equal(t, 1.2, args.Spam)
} }
@ -86,3 +88,43 @@ func TestCaseSensitive2(t *testing.T) {
assert.False(t, args.Lower) assert.False(t, args.Lower)
assert.True(t, args.Upper) assert.True(t, args.Upper)
} }
func TestPositional(t *testing.T) {
var args struct {
Input string `arg:"positional"`
Output string `arg:"positional"`
}
err := ParseFrom(&args, split("foo"))
require.NoError(t, err)
assert.Equal(t, "foo", args.Input)
assert.Equal(t, "", args.Output)
}
func TestRequiredPositional(t *testing.T) {
var args struct {
Input string `arg:"positional"`
Output string `arg:"positional,required"`
}
err := ParseFrom(&args, split("foo"))
assert.Error(t, err)
}
func TestTooManyPositional(t *testing.T) {
var args struct {
Input string `arg:"positional"`
Output string `arg:"positional"`
}
err := ParseFrom(&args, split("foo bar baz"))
assert.Error(t, err)
}
func TestMultiple(t *testing.T) {
var args struct {
Foo []int
Bar []string
}
err := ParseFrom(&args, split("--foo 1 2 3 --bar x y z"))
require.NoError(t, err)
assert.Equal(t, []int{1, 2, 3}, args.Foo)
assert.Equal(t, []string{"x", "y", "z"}, args.Bar)
}