From 408290f7c2a968a0de255813e125a9ebb0a9dda6 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Sat, 31 Oct 2015 16:15:24 -0700 Subject: [PATCH] basic first version working --- parse.go | 268 ++++++++++++++++++++++++++++++++++++++++++++++++++ parse_test.go | 88 +++++++++++++++++ 2 files changed, 356 insertions(+) create mode 100644 parse.go create mode 100644 parse_test.go diff --git a/parse.go b/parse.go new file mode 100644 index 0000000..b58b51e --- /dev/null +++ b/parse.go @@ -0,0 +1,268 @@ +package arguments + +import ( + "fmt" + "os" + "reflect" + "strconv" + "strings" +) + +// MustParse processes command line arguments and exits upon failure. +func MustParse(dest interface{}) { + err := Parse(dest) + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +// Parse processes command line arguments and stores the result in args. +func Parse(dest interface{}) error { + return ParseFrom(dest, os.Args) +} + +// ParseFrom processes command line arguments and stores the result in args. +func ParseFrom(dest interface{}, args []string) error { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + panic(fmt.Sprintf("%s is not a pointer type", v.Type().Name())) + } + v = v.Elem() + + // Parse the spec + spec, err := extractSpec(v.Type()) + if err != nil { + return err + } + + // Process args + err = processArgs(v, spec, args) + if err != nil { + return err + } + + // Validate + return validate(spec) +} + +// spec represents information about an argument extracted from struct tags +type spec struct { + field reflect.StructField + index int + long string + short string + multiple bool + required bool + positional bool + help string + wasPresent bool +} + +// extractSpec gets specifications for each argument from the tags in a struct +func extractSpec(t reflect.Type) ([]*spec, error) { + if t.Kind() != reflect.Struct { + panic(fmt.Sprintf("%s is not a struct pointer", t.Name())) + } + + var specs []*spec + for i := 0; i < t.NumField(); i++ { + // Check for the ignore switch in the tag + field := t.Field(i) + tag := field.Tag.Get("arg") + if tag == "-" { + continue + } + + spec := spec{ + long: strings.ToLower(field.Name), + field: field, + index: i, + } + + // Get the scalar type for this field + scalarType := field.Type + if scalarType.Kind() == reflect.Slice { + spec.multiple = true + scalarType = scalarType.Elem() + if scalarType.Kind() == reflect.Ptr { + scalarType = scalarType.Elem() + } + } + + // Check for unsupported types + switch scalarType.Kind() { + case reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, + reflect.Map, reflect.Ptr, reflect.Struct, + reflect.Complex64, reflect.Complex128: + return nil, fmt.Errorf("%s.%s: %s fields are not supported", t.Name(), field.Name, scalarType.Kind()) + } + + // Look at the tag + if tag != "" { + for _, key := range strings.Split(tag, ",") { + var value string + if pos := strings.Index(key, ":"); pos != -1 { + value = key[pos+1:] + key = key[:pos] + } + + switch { + case strings.HasPrefix(key, "--"): + spec.long = key[2:] + case strings.HasPrefix(key, "-"): + if len(key) != 2 { + return nil, fmt.Errorf("%s.%s: short arguments must be one character only", t.Name(), field.Name) + } + spec.short = key[1:] + case key == "required": + spec.required = true + case key == "positional": + spec.positional = true + case key == "help": + spec.help = value + default: + return nil, fmt.Errorf("unrecognized tag '%s' on field %s", key, tag) + } + } + } + specs = append(specs, &spec) + } + return specs, nil +} + +// processArgs processes arguments using a pre-constructed spec +func processArgs(dest reflect.Value, specs []*spec, args []string) error { + // construct a map from arg name to spec + specByName := make(map[string]*spec) + for _, spec := range specs { + if spec.long != "" { + specByName[spec.long] = spec + } + if spec.short != "" { + specByName[spec.short] = spec + } + } + + // process each string from the command line + var allpositional bool + var positionals []string + + // must use explicit for loop, not range, because we manipulate i inside the loop + for i := 0; i < len(args); i++ { + arg := args[i] + if arg == "--" { + allpositional = true + continue + } + + if !strings.HasPrefix(arg, "-") || allpositional { + positionals = append(positionals, arg) + continue + } + + // check for an equals sign, as in "--foo=bar" + var value string + opt := strings.TrimLeft(arg, "-") + if pos := strings.Index(opt, "="); pos != -1 { + value = opt[pos+1:] + opt = opt[:pos] + } + + // lookup the spec for this option + spec, ok := specByName[opt] + if !ok { + return fmt.Errorf("unknown argument %s", arg) + } + spec.wasPresent = true + + // deal with the case of multiple values + if spec.multiple { + var values []string + if value == "" { + for i++; i < len(args) && !strings.HasPrefix(args[i], "-"); i++ { + values = append(values, args[i]) + } + } else { + values = append(values, value) + } + setSlice(dest, spec, values) + continue + } + + // if it's a flag and it has no value then set the value to true + if spec.field.Type.Kind() == reflect.Bool && value == "" { + value = "true" + } + + // if we have something like "--foo" then the value is the next argument + if value == "" { + if i+1 == len(args) || strings.HasPrefix(args[i+1], "-") { + return fmt.Errorf("missing value for %s", arg) + } + value = args[i+1] + i++ + } + + err := setScalar(dest.Field(spec.index), value) + if err != nil { + return fmt.Errorf("error processing %s: %v", arg, err) + } + } + return nil +} + +// validate an argument spec after arguments have been parse +func validate(spec []*spec) error { + for _, arg := range spec { + if arg.required && !arg.wasPresent { + return fmt.Errorf("--%s is required", strings.ToLower(arg.field.Name)) + } + } + return nil +} + +// parse a value as the apropriate type and store it in the struct +func setSlice(dest reflect.Value, spec *spec, values []string) error { + // TODO + return nil +} + +// set a value from a string +func setScalar(v reflect.Value, s string) error { + if !v.CanSet() { + return fmt.Errorf("field is not writable") + } + + switch v.Kind() { + case reflect.String: + v.Set(reflect.ValueOf(s)) + case reflect.Bool: + x, err := strconv.ParseBool(s) + if err != nil { + return err + } + v.Set(reflect.ValueOf(x)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + x, err := strconv.ParseInt(s, 10, v.Type().Bits()) + if err != nil { + return err + } + v.Set(reflect.ValueOf(x).Convert(v.Type())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + x, err := strconv.ParseUint(s, 10, v.Type().Bits()) + if err != nil { + return err + } + v.Set(reflect.ValueOf(x).Convert(v.Type())) + case reflect.Float32, reflect.Float64: + x, err := strconv.ParseFloat(s, v.Type().Bits()) + if err != nil { + return err + } + v.Set(reflect.ValueOf(x).Convert(v.Type())) + default: + return fmt.Errorf("not a scalar type: %s", v.Kind()) + } + return nil +} diff --git a/parse_test.go b/parse_test.go new file mode 100644 index 0000000..4864ebc --- /dev/null +++ b/parse_test.go @@ -0,0 +1,88 @@ +package arguments + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func split(s string) []string { + return strings.Split(s, " ") +} + +func TestStringSingle(t *testing.T) { + var args struct { + Foo string + } + err := ParseFrom(&args, split("--foo bar")) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + +func TestMixed(t *testing.T) { + var args struct { + Foo string `arg:"-f"` + Bar int + Ham bool + Spam float32 + } + args.Bar = 3 + err := ParseFrom(&args, split("-spam=1.2 -ham -f xyz")) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) + assert.Equal(t, 3, args.Bar) + assert.Equal(t, true, args.Ham) + assert.Equal(t, 1.2, args.Spam) +} + +func TestRequired(t *testing.T) { + var args struct { + Foo string `arg:"required"` + } + err := ParseFrom(&args, nil) + require.Error(t, err, "--foo is required") +} + +func TestShortFlag(t *testing.T) { + var args struct { + Foo string `arg:"-f"` + } + + err := ParseFrom(&args, split("-f xyz")) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) + + err = ParseFrom(&args, split("-foo xyz")) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) + + err = ParseFrom(&args, split("--foo xyz")) + require.NoError(t, err) + assert.Equal(t, "xyz", args.Foo) +} + +func TestCaseSensitive(t *testing.T) { + var args struct { + Lower bool `arg:"-v"` + Upper bool `arg:"-V"` + } + + err := ParseFrom(&args, split("-v")) + require.NoError(t, err) + assert.True(t, args.Lower) + assert.False(t, args.Upper) +} + +func TestCaseSensitive2(t *testing.T) { + var args struct { + Lower bool `arg:"-v"` + Upper bool `arg:"-V"` + } + + err := ParseFrom(&args, split("-V")) + require.NoError(t, err) + assert.False(t, args.Lower) + assert.True(t, args.Upper) +}