From 8397a40f4cafd39c553df848854e022d33149fa5 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Sat, 31 Oct 2015 17:05:14 -0700 Subject: [PATCH] positional arguments working --- parse.go | 78 +++++++++++++++++++++++++++++++++++++++++++-------- parse_test.go | 44 ++++++++++++++++++++++++++++- 2 files changed, 109 insertions(+), 13 deletions(-) diff --git a/parse.go b/parse.go index b58b51e..74b9175 100644 --- a/parse.go +++ b/parse.go @@ -2,6 +2,7 @@ package arguments import ( "fmt" + "log" "os" "reflect" "strconv" @@ -82,6 +83,7 @@ func extractSpec(t reflect.Type) ([]*spec, error) { // Get the scalar type for this field scalarType := field.Type + log.Println(field.Name, field.Type, field.Type.Kind()) if scalarType.Kind() == reflect.Slice { spec.multiple = true scalarType = scalarType.Elem() @@ -133,14 +135,17 @@ func extractSpec(t reflect.Type) ([]*spec, error) { // processArgs processes arguments using a pre-constructed spec func processArgs(dest reflect.Value, specs []*spec, args []string) error { - // construct a map from arg name to spec - specByName := make(map[string]*spec) + // construct a map from --option to spec + optionMap := make(map[string]*spec) for _, spec := range specs { + if spec.positional { + continue + } if spec.long != "" { - specByName[spec.long] = spec + optionMap[spec.long] = spec } if spec.short != "" { - specByName[spec.short] = spec + optionMap[spec.short] = spec } } @@ -170,7 +175,7 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error { } // lookup the spec for this option - spec, ok := specByName[opt] + spec, ok := optionMap[opt] if !ok { return fmt.Errorf("unknown argument %s", arg) } @@ -180,13 +185,17 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error { if spec.multiple { var values []string if value == "" { - for i++; i < len(args) && !strings.HasPrefix(args[i], "-"); i++ { - values = append(values, args[i]) + for i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") { + values = append(values, args[i+1]) + i++ } } else { values = append(values, value) } - setSlice(dest, spec, values) + err := setSlice(dest.Field(spec.index), values) + if err != nil { + return fmt.Errorf("error processing %s: %v", arg, err) + } continue } @@ -209,13 +218,38 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error { return fmt.Errorf("error processing %s: %v", arg, err) } } + + // process positionals + for _, spec := range specs { + label := strings.ToLower(spec.field.Name) + if spec.positional { + if spec.multiple { + err := setSlice(dest.Field(spec.index), positionals) + if err != nil { + return fmt.Errorf("error processing %s: %v", label, err) + } + positionals = nil + } else if len(positionals) > 0 { + err := setScalar(dest.Field(spec.index), positionals[0]) + if err != nil { + return fmt.Errorf("error processing %s: %v", label, err) + } + positionals = positionals[1:] + } else if spec.required { + return fmt.Errorf("%s is required", label) + } + } + } + if len(positionals) > 0 { + return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) + } return nil } // validate an argument spec after arguments have been parse func validate(spec []*spec) error { for _, arg := range spec { - if arg.required && !arg.wasPresent { + if !arg.positional && arg.required && !arg.wasPresent { return fmt.Errorf("--%s is required", strings.ToLower(arg.field.Name)) } } @@ -223,15 +257,35 @@ func validate(spec []*spec) error { } // parse a value as the apropriate type and store it in the struct -func setSlice(dest reflect.Value, spec *spec, values []string) error { - // TODO +func setSlice(dest reflect.Value, values []string) error { + if !dest.CanSet() { + return fmt.Errorf("field is not writable") + } + + var ptr bool + elem := dest.Type().Elem() + if elem.Kind() == reflect.Ptr { + ptr = true + elem = elem.Elem() + } + + for _, s := range values { + v := reflect.New(elem) + if err := setScalar(v.Elem(), s); err != nil { + return err + } + if ptr { + v = v.Addr() + } + dest.Set(reflect.Append(dest, v.Elem())) + } return nil } // set a value from a string func setScalar(v reflect.Value, s string) error { if !v.CanSet() { - return fmt.Errorf("field is not writable") + return fmt.Errorf("field is not exported") } switch v.Kind() { diff --git a/parse_test.go b/parse_test.go index 4864ebc..b8e3fa7 100644 --- a/parse_test.go +++ b/parse_test.go @@ -25,14 +25,16 @@ func TestMixed(t *testing.T) { var args struct { Foo string `arg:"-f"` Bar int + Baz uint `arg:"positional"` Ham bool Spam float32 } args.Bar = 3 - err := ParseFrom(&args, split("-spam=1.2 -ham -f xyz")) + err := ParseFrom(&args, split("123 -spam=1.2 -ham -f xyz")) require.NoError(t, err) assert.Equal(t, "xyz", args.Foo) assert.Equal(t, 3, args.Bar) + assert.Equal(t, uint(123), args.Baz) assert.Equal(t, true, args.Ham) assert.Equal(t, 1.2, args.Spam) } @@ -86,3 +88,43 @@ func TestCaseSensitive2(t *testing.T) { assert.False(t, args.Lower) assert.True(t, args.Upper) } + +func TestPositional(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Output string `arg:"positional"` + } + err := ParseFrom(&args, split("foo")) + require.NoError(t, err) + assert.Equal(t, "foo", args.Input) + assert.Equal(t, "", args.Output) +} + +func TestRequiredPositional(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Output string `arg:"positional,required"` + } + err := ParseFrom(&args, split("foo")) + assert.Error(t, err) +} + +func TestTooManyPositional(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Output string `arg:"positional"` + } + err := ParseFrom(&args, split("foo bar baz")) + assert.Error(t, err) +} + +func TestMultiple(t *testing.T) { + var args struct { + Foo []int + Bar []string + } + err := ParseFrom(&args, split("--foo 1 2 3 --bar x y z")) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, args.Foo) + assert.Equal(t, []string{"x", "y", "z"}, args.Bar) +}