positional arguments working
This commit is contained in:
parent
408290f7c2
commit
8397a40f4c
78
parse.go
78
parse.go
|
@ -2,6 +2,7 @@ package arguments
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
@ -82,6 +83,7 @@ func extractSpec(t reflect.Type) ([]*spec, error) {
|
|||
|
||||
// Get the scalar type for this field
|
||||
scalarType := field.Type
|
||||
log.Println(field.Name, field.Type, field.Type.Kind())
|
||||
if scalarType.Kind() == reflect.Slice {
|
||||
spec.multiple = true
|
||||
scalarType = scalarType.Elem()
|
||||
|
@ -133,14 +135,17 @@ func extractSpec(t reflect.Type) ([]*spec, error) {
|
|||
|
||||
// processArgs processes arguments using a pre-constructed spec
|
||||
func processArgs(dest reflect.Value, specs []*spec, args []string) error {
|
||||
// construct a map from arg name to spec
|
||||
specByName := make(map[string]*spec)
|
||||
// construct a map from --option to spec
|
||||
optionMap := make(map[string]*spec)
|
||||
for _, spec := range specs {
|
||||
if spec.positional {
|
||||
continue
|
||||
}
|
||||
if spec.long != "" {
|
||||
specByName[spec.long] = spec
|
||||
optionMap[spec.long] = spec
|
||||
}
|
||||
if spec.short != "" {
|
||||
specByName[spec.short] = spec
|
||||
optionMap[spec.short] = spec
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -170,7 +175,7 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
|
|||
}
|
||||
|
||||
// lookup the spec for this option
|
||||
spec, ok := specByName[opt]
|
||||
spec, ok := optionMap[opt]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown argument %s", arg)
|
||||
}
|
||||
|
@ -180,13 +185,17 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
|
|||
if spec.multiple {
|
||||
var values []string
|
||||
if value == "" {
|
||||
for i++; i < len(args) && !strings.HasPrefix(args[i], "-"); i++ {
|
||||
values = append(values, args[i])
|
||||
for i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") {
|
||||
values = append(values, args[i+1])
|
||||
i++
|
||||
}
|
||||
} else {
|
||||
values = append(values, value)
|
||||
}
|
||||
setSlice(dest, spec, values)
|
||||
err := setSlice(dest.Field(spec.index), values)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error processing %s: %v", arg, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -209,13 +218,38 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error {
|
|||
return fmt.Errorf("error processing %s: %v", arg, err)
|
||||
}
|
||||
}
|
||||
|
||||
// process positionals
|
||||
for _, spec := range specs {
|
||||
label := strings.ToLower(spec.field.Name)
|
||||
if spec.positional {
|
||||
if spec.multiple {
|
||||
err := setSlice(dest.Field(spec.index), positionals)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error processing %s: %v", label, err)
|
||||
}
|
||||
positionals = nil
|
||||
} else if len(positionals) > 0 {
|
||||
err := setScalar(dest.Field(spec.index), positionals[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("error processing %s: %v", label, err)
|
||||
}
|
||||
positionals = positionals[1:]
|
||||
} else if spec.required {
|
||||
return fmt.Errorf("%s is required", label)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(positionals) > 0 {
|
||||
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate an argument spec after arguments have been parse
|
||||
func validate(spec []*spec) error {
|
||||
for _, arg := range spec {
|
||||
if arg.required && !arg.wasPresent {
|
||||
if !arg.positional && arg.required && !arg.wasPresent {
|
||||
return fmt.Errorf("--%s is required", strings.ToLower(arg.field.Name))
|
||||
}
|
||||
}
|
||||
|
@ -223,15 +257,35 @@ func validate(spec []*spec) error {
|
|||
}
|
||||
|
||||
// parse a value as the apropriate type and store it in the struct
|
||||
func setSlice(dest reflect.Value, spec *spec, values []string) error {
|
||||
// TODO
|
||||
func setSlice(dest reflect.Value, values []string) error {
|
||||
if !dest.CanSet() {
|
||||
return fmt.Errorf("field is not writable")
|
||||
}
|
||||
|
||||
var ptr bool
|
||||
elem := dest.Type().Elem()
|
||||
if elem.Kind() == reflect.Ptr {
|
||||
ptr = true
|
||||
elem = elem.Elem()
|
||||
}
|
||||
|
||||
for _, s := range values {
|
||||
v := reflect.New(elem)
|
||||
if err := setScalar(v.Elem(), s); err != nil {
|
||||
return err
|
||||
}
|
||||
if ptr {
|
||||
v = v.Addr()
|
||||
}
|
||||
dest.Set(reflect.Append(dest, v.Elem()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// set a value from a string
|
||||
func setScalar(v reflect.Value, s string) error {
|
||||
if !v.CanSet() {
|
||||
return fmt.Errorf("field is not writable")
|
||||
return fmt.Errorf("field is not exported")
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
|
|
|
@ -25,14 +25,16 @@ func TestMixed(t *testing.T) {
|
|||
var args struct {
|
||||
Foo string `arg:"-f"`
|
||||
Bar int
|
||||
Baz uint `arg:"positional"`
|
||||
Ham bool
|
||||
Spam float32
|
||||
}
|
||||
args.Bar = 3
|
||||
err := ParseFrom(&args, split("-spam=1.2 -ham -f xyz"))
|
||||
err := ParseFrom(&args, split("123 -spam=1.2 -ham -f xyz"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "xyz", args.Foo)
|
||||
assert.Equal(t, 3, args.Bar)
|
||||
assert.Equal(t, uint(123), args.Baz)
|
||||
assert.Equal(t, true, args.Ham)
|
||||
assert.Equal(t, 1.2, args.Spam)
|
||||
}
|
||||
|
@ -86,3 +88,43 @@ func TestCaseSensitive2(t *testing.T) {
|
|||
assert.False(t, args.Lower)
|
||||
assert.True(t, args.Upper)
|
||||
}
|
||||
|
||||
func TestPositional(t *testing.T) {
|
||||
var args struct {
|
||||
Input string `arg:"positional"`
|
||||
Output string `arg:"positional"`
|
||||
}
|
||||
err := ParseFrom(&args, split("foo"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "foo", args.Input)
|
||||
assert.Equal(t, "", args.Output)
|
||||
}
|
||||
|
||||
func TestRequiredPositional(t *testing.T) {
|
||||
var args struct {
|
||||
Input string `arg:"positional"`
|
||||
Output string `arg:"positional,required"`
|
||||
}
|
||||
err := ParseFrom(&args, split("foo"))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTooManyPositional(t *testing.T) {
|
||||
var args struct {
|
||||
Input string `arg:"positional"`
|
||||
Output string `arg:"positional"`
|
||||
}
|
||||
err := ParseFrom(&args, split("foo bar baz"))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMultiple(t *testing.T) {
|
||||
var args struct {
|
||||
Foo []int
|
||||
Bar []string
|
||||
}
|
||||
err := ParseFrom(&args, split("--foo 1 2 3 --bar x y z"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []int{1, 2, 3}, args.Foo)
|
||||
assert.Equal(t, []string{"x", "y", "z"}, args.Bar)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue