diff --git a/parse.go b/parse.go index 74b9175..d9382c9 100644 --- a/parse.go +++ b/parse.go @@ -1,56 +1,16 @@ -package arguments +package arg import ( "fmt" - "log" "os" "reflect" "strconv" "strings" ) -// MustParse processes command line arguments and exits upon failure. -func MustParse(dest interface{}) { - err := Parse(dest) - if err != nil { - fmt.Println(err) - os.Exit(1) - } -} - -// Parse processes command line arguments and stores the result in args. -func Parse(dest interface{}) error { - return ParseFrom(dest, os.Args) -} - -// ParseFrom processes command line arguments and stores the result in args. -func ParseFrom(dest interface{}, args []string) error { - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - panic(fmt.Sprintf("%s is not a pointer type", v.Type().Name())) - } - v = v.Elem() - - // Parse the spec - spec, err := extractSpec(v.Type()) - if err != nil { - return err - } - - // Process args - err = processArgs(v, spec, args) - if err != nil { - return err - } - - // Validate - return validate(spec) -} - -// spec represents information about an argument extracted from struct tags +// spec represents a command line option type spec struct { - field reflect.StructField - index int + dest reflect.Value long string short string multiple bool @@ -60,81 +20,131 @@ type spec struct { wasPresent bool } -// extractSpec gets specifications for each argument from the tags in a struct -func extractSpec(t reflect.Type) ([]*spec, error) { - if t.Kind() != reflect.Struct { - panic(fmt.Sprintf("%s is not a struct pointer", t.Name())) +// MustParse processes command line arguments and exits upon failure. +func MustParse(dest ...interface{}) { + err := Parse(dest...) + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +// Parse processes command line arguments and stores the result in args. +func Parse(dest ...interface{}) error { + return ParseFrom(os.Args[1:], dest...) +} + +// ParseFrom processes command line arguments and stores the result in args. +func ParseFrom(args []string, dest ...interface{}) error { + // Add the help option if one is not already defined + var internal struct { + Help bool `arg:"-h"` } + // Parse the spec + dest = append(dest, &internal) + spec, err := extractSpec(dest...) + if err != nil { + return err + } + + // Process args + err = processArgs(spec, args) + if err != nil { + return err + } + + // If -h or --help were specified then print help + if internal.Help { + writeUsage(os.Stdout, spec) + os.Exit(0) + } + + // Validate + return validate(spec) +} + +// extractSpec gets specifications for each argument from the tags in a struct +func extractSpec(dests ...interface{}) ([]*spec, error) { var specs []*spec - for i := 0; i < t.NumField(); i++ { - // Check for the ignore switch in the tag - field := t.Field(i) - tag := field.Tag.Get("arg") - if tag == "-" { - continue + for _, dest := range dests { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", v.Type())) + } + v = v.Elem() + if v.Kind() != reflect.Struct { + panic(fmt.Sprintf("%T is not a struct pointer", dest)) } - spec := spec{ - long: strings.ToLower(field.Name), - field: field, - index: i, - } + t := v.Type() + for i := 0; i < t.NumField(); i++ { + // Check for the ignore switch in the tag + field := t.Field(i) + tag := field.Tag.Get("arg") + if tag == "-" { + continue + } - // Get the scalar type for this field - scalarType := field.Type - log.Println(field.Name, field.Type, field.Type.Kind()) - if scalarType.Kind() == reflect.Slice { - spec.multiple = true - scalarType = scalarType.Elem() - if scalarType.Kind() == reflect.Ptr { + spec := spec{ + long: strings.ToLower(field.Name), + dest: v.Field(i), + } + + // Get the scalar type for this field + scalarType := field.Type + if scalarType.Kind() == reflect.Slice { + spec.multiple = true scalarType = scalarType.Elem() - } - } - - // Check for unsupported types - switch scalarType.Kind() { - case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, - reflect.Map, reflect.Ptr, reflect.Struct, - reflect.Complex64, reflect.Complex128: - return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind()) - } - - // Look at the tag - if tag != "" { - for _, key := range strings.Split(tag, ",") { - var value string - if pos := strings.Index(key, ":"); pos != -1 { - value = key[pos+1:] - key = key[:pos] + if scalarType.Kind() == reflect.Ptr { + scalarType = scalarType.Elem() } + } - switch { - case strings.HasPrefix(key, "--"): - spec.long = key[2:] - case strings.HasPrefix(key, "-"): - if len(key) != 2 { - return nil, fmt.Errorf("%s.%s: short arguments must be one character only", t.Name(), field.Name) + // Check for unsupported types + switch scalarType.Kind() { + case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, + reflect.Map, reflect.Ptr, reflect.Struct, + reflect.Complex64, reflect.Complex128: + return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind()) + } + + // Look at the tag + if tag != "" { + for _, key := range strings.Split(tag, ",") { + var value string + if pos := strings.Index(key, ":"); pos != -1 { + value = key[pos+1:] + key = key[:pos] + } + + switch { + case strings.HasPrefix(key, "--"): + spec.long = key[2:] + case strings.HasPrefix(key, "-"): + if len(key) != 2 { + return nil, fmt.Errorf("%s.%s: short arguments must be one character only", t.Name(), field.Name) + } + spec.short = key[1:] + case key == "required": + spec.required = true + case key == "positional": + spec.positional = true + case key == "help": + spec.help = value + default: + return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag) } - spec.short = key[1:] - case key == "required": - spec.required = true - case key == "positional": - spec.positional = true - case key == "help": - spec.help = value - default: - return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag) } } + specs = append(specs, &spec) } - specs = append(specs, &spec) } return specs, nil } // processArgs processes arguments using a pre-constructed spec -func processArgs(dest reflect.Value, specs []*spec, args []string) error { +func processArgs(specs []*spec, args []string) error { // construct a map from --option to spec optionMap := make(map[string]*spec) for _, spec := range specs { @@ -192,7 +202,7 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error { } else { values = append(values, value) } - err := setSlice(dest.Field(spec.index), values) + err := setSlice(spec.dest, values) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -200,7 +210,7 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error { } // if it's a flag and it has no value then set the value to true - if spec.field.Type.Kind() == reflect.Bool && value == "" { + if spec.dest.Kind() == reflect.Bool && value == "" { value = "true" } @@ -213,7 +223,7 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error { i++ } - err := setScalar(dest.Field(spec.index), value) + err := setScalar(spec.dest, value) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -221,22 +231,21 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error { // process positionals for _, spec := range specs { - label := strings.ToLower(spec.field.Name) if spec.positional { if spec.multiple { - err := setSlice(dest.Field(spec.index), positionals) + err := setSlice(spec.dest, positionals) if err != nil { - return fmt.Errorf("error processing %s: %v", label, err) + return fmt.Errorf("error processing %s: %v", spec.long, err) } positionals = nil } else if len(positionals) > 0 { - err := setScalar(dest.Field(spec.index), positionals[0]) + err := setScalar(spec.dest, positionals[0]) if err != nil { - return fmt.Errorf("error processing %s: %v", label, err) + return fmt.Errorf("error processing %s: %v", spec.long, err) } positionals = positionals[1:] } else if spec.required { - return fmt.Errorf("%s is required", label) + return fmt.Errorf("%s is required", spec.long) } } } @@ -250,7 +259,7 @@ func processArgs(dest reflect.Value, specs []*spec, args []string) error { func validate(spec []*spec) error { for _, arg := range spec { if !arg.positional && arg.required && !arg.wasPresent { - return fmt.Errorf("--%s is required", strings.ToLower(arg.field.Name)) + return fmt.Errorf("--%s is required", arg.long) } } return nil diff --git a/parse_test.go b/parse_test.go index b8e3fa7..9ad5944 100644 --- a/parse_test.go +++ b/parse_test.go @@ -1,4 +1,4 @@ -package arguments +package arg import ( "strings" @@ -16,7 +16,7 @@ func TestStringSingle(t *testing.T) { var args struct { Foo string } - err := ParseFrom(&args, split("--foo bar")) + err := ParseFrom(split("--foo bar"), &args) require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -30,7 +30,7 @@ func TestMixed(t *testing.T) { Spam float32 } args.Bar = 3 - err := ParseFrom(&args, split("123 -spam=1.2 -ham -f xyz")) + err := ParseFrom(split("123 -spam=1.2 -ham -f xyz"), &args) require.NoError(t, err) assert.Equal(t, "xyz", args.Foo) assert.Equal(t, 3, args.Bar) @@ -43,7 +43,7 @@ func TestRequired(t *testing.T) { var args struct { Foo string `arg:"required"` } - err := ParseFrom(&args, nil) + err := ParseFrom(nil, &args) require.Error(t, err, "--foo is required") } @@ -52,15 +52,15 @@ func TestShortFlag(t *testing.T) { Foo string `arg:"-f"` } - err := ParseFrom(&args, split("-f xyz")) + err := ParseFrom(split("-f xyz"), &args) require.NoError(t, err) assert.Equal(t, "xyz", args.Foo) - err = ParseFrom(&args, split("-foo xyz")) + err = ParseFrom(split("-foo xyz"), &args) require.NoError(t, err) assert.Equal(t, "xyz", args.Foo) - err = ParseFrom(&args, split("--foo xyz")) + err = ParseFrom(split("--foo xyz"), &args) require.NoError(t, err) assert.Equal(t, "xyz", args.Foo) } @@ -71,7 +71,7 @@ func TestCaseSensitive(t *testing.T) { Upper bool `arg:"-V"` } - err := ParseFrom(&args, split("-v")) + err := ParseFrom(split("-v"), &args) require.NoError(t, err) assert.True(t, args.Lower) assert.False(t, args.Upper) @@ -83,7 +83,7 @@ func TestCaseSensitive2(t *testing.T) { Upper bool `arg:"-V"` } - err := ParseFrom(&args, split("-V")) + err := ParseFrom(split("-V"), &args) require.NoError(t, err) assert.False(t, args.Lower) assert.True(t, args.Upper) @@ -94,7 +94,7 @@ func TestPositional(t *testing.T) { Input string `arg:"positional"` Output string `arg:"positional"` } - err := ParseFrom(&args, split("foo")) + err := ParseFrom(split("foo"), &args) require.NoError(t, err) assert.Equal(t, "foo", args.Input) assert.Equal(t, "", args.Output) @@ -105,7 +105,7 @@ func TestRequiredPositional(t *testing.T) { Input string `arg:"positional"` Output string `arg:"positional,required"` } - err := ParseFrom(&args, split("foo")) + err := ParseFrom(split("foo"), &args) assert.Error(t, err) } @@ -114,7 +114,7 @@ func TestTooManyPositional(t *testing.T) { Input string `arg:"positional"` Output string `arg:"positional"` } - err := ParseFrom(&args, split("foo bar baz")) + err := ParseFrom(split("foo bar baz"), &args) assert.Error(t, err) } @@ -123,7 +123,7 @@ func TestMultiple(t *testing.T) { Foo []int Bar []string } - err := ParseFrom(&args, split("--foo 1 2 3 --bar x y z")) + err := ParseFrom(split("--foo 1 2 3 --bar x y z"), &args) require.NoError(t, err) assert.Equal(t, []int{1, 2, 3}, args.Foo) assert.Equal(t, []string{"x", "y", "z"}, args.Bar)