positional arguments working

This commit is contained in:
Alex Flint 2015-10-31 17:05:14 -07:00
parent 408290f7c2
commit 8397a40f4c
2 changed files with 109 additions and 13 deletions

View File

@ -2,6 +2,7 @@ package arguments
import (
"fmt"
"log"
"os"
"reflect"
"strconv"
@ -82,6 +83,7 @@ func extractSpec(t reflect.Type) ([]*spec, error) {
// Get the scalar type for this field
scalarType := field.Type
log.Println(field.Name, field.Type, field.Type.Kind())
if scalarType.Kind() == reflect.Slice {
spec.multiple = true
scalarType = scalarType.Elem()
@ -133,14 +135,17 @@ func extractSpec(t reflect.Type) ([]*spec, error) {
// processArgs processes arguments using a pre-constructed spec
func processArgs(dest reflect.Value, specs []*spec, args []string) error {
// construct a map from arg name to spec
specByName := make(map[string]*spec)
// construct a map from --option to spec
optionMap := make(map[string]*spec)
for _, spec := range specs {
if spec.positional {
continue
}
if spec.long != "" {
specByName[spec.long] = spec
optionMap[spec.long] = spec
}
if spec.short != "" {
specByName[spec.short] = spec
optionMap[spec.short] = spec
}
}
@ -170,7 +175,7 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
}
// lookup the spec for this option
spec, ok := specByName[opt]
spec, ok := optionMap[opt]
if !ok {
return fmt.Errorf("unknown argument %s", arg)
}
@ -180,13 +185,17 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
if spec.multiple {
var values []string
if value == "" {
for i++; i < len(args) && !strings.HasPrefix(args[i], "-"); i++ {
values = append(values, args[i])
for i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") {
values = append(values, args[i+1])
i++
}
} else {
values = append(values, value)
}
setSlice(dest, spec, values)
err := setSlice(dest.Field(spec.index), values)
if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err)
}
continue
}
@ -209,13 +218,38 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
return fmt.Errorf("error processing %s: %v", arg, err)
}
}
// process positionals
for _, spec := range specs {
label := strings.ToLower(spec.field.Name)
if spec.positional {
if spec.multiple {
err := setSlice(dest.Field(spec.index), positionals)
if err != nil {
return fmt.Errorf("error processing %s: %v", label, err)
}
positionals = nil
} else if len(positionals) > 0 {
err := setScalar(dest.Field(spec.index), positionals[0])
if err != nil {
return fmt.Errorf("error processing %s: %v", label, err)
}
positionals = positionals[1:]
} else if spec.required {
return fmt.Errorf("%s is required", label)
}
}
}
if len(positionals) > 0 {
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
}
return nil
}
// validate an argument spec after arguments have been parse
func validate(spec []*spec) error {
for _, arg := range spec {
if arg.required && !arg.wasPresent {
if !arg.positional && arg.required && !arg.wasPresent {
return fmt.Errorf("--%s is required", strings.ToLower(arg.field.Name))
}
}
@ -223,15 +257,35 @@ func validate(spec []*spec) error {
}
// parse a value as the apropriate type and store it in the struct
func setSlice(dest reflect.Value, spec *spec, values []string) error {
// TODO
func setSlice(dest reflect.Value, values []string) error {
if !dest.CanSet() {
return fmt.Errorf("field is not writable")
}
var ptr bool
elem := dest.Type().Elem()
if elem.Kind() == reflect.Ptr {
ptr = true
elem = elem.Elem()
}
for _, s := range values {
v := reflect.New(elem)
if err := setScalar(v.Elem(), s); err != nil {
return err
}
if ptr {
v = v.Addr()
}
dest.Set(reflect.Append(dest, v.Elem()))
}
return nil
}
// set a value from a string
func setScalar(v reflect.Value, s string) error {
if !v.CanSet() {
return fmt.Errorf("field is not writable")
return fmt.Errorf("field is not exported")
}
switch v.Kind() {

View File

@ -25,14 +25,16 @@ func TestMixed(t *testing.T) {
var args struct {
Foo string `arg:"-f"`
Bar int
Baz uint `arg:"positional"`
Ham bool
Spam float32
}
args.Bar = 3
err := ParseFrom(&args, split("-spam=1.2 -ham -f xyz"))
err := ParseFrom(&args, split("123 -spam=1.2 -ham -f xyz"))
require.NoError(t, err)
assert.Equal(t, "xyz", args.Foo)
assert.Equal(t, 3, args.Bar)
assert.Equal(t, uint(123), args.Baz)
assert.Equal(t, true, args.Ham)
assert.Equal(t, 1.2, args.Spam)
}
@ -86,3 +88,43 @@ func TestCaseSensitive2(t *testing.T) {
assert.False(t, args.Lower)
assert.True(t, args.Upper)
}
func TestPositional(t *testing.T) {
var args struct {
Input string `arg:"positional"`
Output string `arg:"positional"`
}
err := ParseFrom(&args, split("foo"))
require.NoError(t, err)
assert.Equal(t, "foo", args.Input)
assert.Equal(t, "", args.Output)
}
func TestRequiredPositional(t *testing.T) {
var args struct {
Input string `arg:"positional"`
Output string `arg:"positional,required"`
}
err := ParseFrom(&args, split("foo"))
assert.Error(t, err)
}
func TestTooManyPositional(t *testing.T) {
var args struct {
Input string `arg:"positional"`
Output string `arg:"positional"`
}
err := ParseFrom(&args, split("foo bar baz"))
assert.Error(t, err)
}
func TestMultiple(t *testing.T) {
var args struct {
Foo []int
Bar []string
}
err := ParseFrom(&args, split("--foo 1 2 3 --bar x y z"))
require.NoError(t, err)
assert.Equal(t, []int{1, 2, 3}, args.Foo)
assert.Equal(t, []string{"x", "y", "z"}, args.Bar)
}