From 9949860eb3d60d374df3a47ebc0a22ca55bba399 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Mon, 19 Apr 2021 13:21:04 -0700 Subject: [PATCH] change "kind" to "cardinality", add support for maps to parser --- parse.go | 33 +++++++++++++------------- parse_test.go | 22 +++++++++++++++++- reflect.go | 46 +++++++++++++++++------------------- reflect_test.go | 62 ++++++++++++++++++++++++------------------------- sequence.go | 29 +++++++++++++++++------ usage.go | 18 +++++++------- 6 files changed, 120 insertions(+), 90 deletions(-) diff --git a/parse.go b/parse.go index 37df734..e05d1c3 100644 --- a/parse.go +++ b/parse.go @@ -50,13 +50,12 @@ type spec struct { field reflect.StructField // the struct field from which this option was created long string // the --long form for this option, or empty if none short string // the -s short form for this option, or empty if none - multiple bool - required bool - positional bool - separate bool + cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple) + required bool // if true, this option must be present on the command line + positional bool // if true, this option will be looked for in the positional flags + separate bool // if true, help string env string - boolean bool defaultVal string // default value for this option placeholder string // name of the data in help } @@ -376,15 +375,15 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { if !isSubcommand { cmd.specs = append(cmd.specs, &spec) - var parseable bool - //parseable, spec.boolean, spec.multiple = canParse(field.Type) - if !parseable { + var err error + spec.cardinality, err = cardinalityOf(field.Type) + if err != nil { errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", t.Name(), field.Name, field.Type.String())) return false } - if spec.multiple && hasDefault { - errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields", + if spec.cardinality == multiple && hasDefault { + errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields", t.Name(), field.Name)) return false } @@ -442,7 +441,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error continue } - if spec.multiple { + if spec.cardinality == multiple { // expect a CSV string in an environment // variable in the case of multiple values values, err := csv.NewReader(strings.NewReader(value)).Read() @@ -453,7 +452,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error err, ) } - if err = setSlice(p.val(spec.dest), values, !spec.separate); err != nil { + if err = setSliceOrMap(p.val(spec.dest), values, !spec.separate); err != nil { return fmt.Errorf( "error processing environment variable %s with multiple values: %v", spec.env, @@ -563,7 +562,7 @@ func (p *Parser) process(args []string) error { wasPresent[spec] = true // deal with the case of multiple values - if spec.multiple { + if spec.cardinality == multiple { var values []string if value == "" { for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { @@ -576,7 +575,7 @@ func (p *Parser) process(args []string) error { } else { values = append(values, value) } - err := setSlice(p.val(spec.dest), values, !spec.separate) + err := setSliceOrMap(p.val(spec.dest), values, !spec.separate) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -585,7 +584,7 @@ func (p *Parser) process(args []string) error { // if it's a flag and it has no value then set the value to true // use boolean because this takes account of TextUnmarshaler - if spec.boolean && value == "" { + if spec.cardinality == zero && value == "" { value = "true" } @@ -616,8 +615,8 @@ func (p *Parser) process(args []string) error { break } wasPresent[spec] = true - if spec.multiple { - err := setSlice(p.val(spec.dest), positionals, true) + if spec.cardinality == multiple { + err := setSliceOrMap(p.val(spec.dest), positionals, true) if err != nil { return fmt.Errorf("error processing %s: %v", spec.field.Name, err) } diff --git a/parse_test.go b/parse_test.go index 0decfc1..6ee3541 100644 --- a/parse_test.go +++ b/parse_test.go @@ -220,6 +220,14 @@ func TestLongFlag(t *testing.T) { assert.Equal(t, "xyz", args.Foo) } +func TestSlice(t *testing.T) { + var args struct { + Strings []string + } + err := parse("--strings a b c", &args) + require.NoError(t, err) + assert.Equal(t, []string{"a", "b", "c"}, args.Strings) +} func TestSliceOfBools(t *testing.T) { var args struct { B []bool @@ -230,6 +238,18 @@ func TestSliceOfBools(t *testing.T) { assert.Equal(t, []bool{true, false, true}, args.B) } +func TestMap(t *testing.T) { + var args struct { + Values map[string]int + } + err := parse("--values a=1 b=2 c=3", &args) + require.NoError(t, err) + assert.Len(t, args.Values, 3) + assert.Equal(t, 1, args.Values["a"]) + assert.Equal(t, 2, args.Values["b"]) + assert.Equal(t, 3, args.Values["c"]) +} + func TestPlaceholder(t *testing.T) { var args struct { Input string `arg:"positional" placeholder:"SRC"` @@ -1233,7 +1253,7 @@ func TestDefaultValuesNotAllowedWithSlice(t *testing.T) { } err := parse("", &args) - assert.EqualError(t, err, ".A: default values are not supported for slice fields") + assert.EqualError(t, err, ".A: default values are not supported for slice or map fields") } func TestUnexportedFieldsSkipped(t *testing.T) { diff --git a/reflect.go b/reflect.go index c4fc5d9..be202dc 100644 --- a/reflect.go +++ b/reflect.go @@ -12,31 +12,27 @@ import ( var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() -// kind is used to track the various kinds of options: -// - regular is an ordinary option that will be parsed from a single token -// - binary is an option that will be true if present but does not expect an explicit value -// - sequence is an option that accepts multiple values and will end up in a slice -// - mapping is an option that acccepts multiple key=value strings and will end up in a map -type kind int +// cardinality tracks how many tokens are expected for a given spec +// - zero is a boolean, which does to expect any value +// - one is an ordinary option that will be parsed from a single token +// - multiple is a slice or map that can accept zero or more tokens +type cardinality int const ( - regular kind = iota - binary - sequence - mapping + zero cardinality = iota + one + multiple unsupported ) -func (k kind) String() string { +func (k cardinality) String() string { switch k { - case regular: - return "regular" - case binary: - return "binary" - case sequence: - return "sequence" - case mapping: - return "mapping" + case zero: + return "zero" + case one: + return "one" + case multiple: + return "multiple" case unsupported: return "unsupported" default: @@ -44,13 +40,13 @@ func (k kind) String() string { } } -// kindOf returns true if the type can be parsed from a string -func kindOf(t reflect.Type) (kind, error) { +// cardinalityOf returns true if the type can be parsed from a string +func cardinalityOf(t reflect.Type) (cardinality, error) { if scalar.CanParse(t) { if isBoolean(t) { - return binary, nil + return zero, nil } else { - return regular, nil + return one, nil } } @@ -65,7 +61,7 @@ func kindOf(t reflect.Type) (kind, error) { if !scalar.CanParse(t.Elem()) { return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into %v", t, t.Elem()) } - return sequence, nil + return multiple, nil case reflect.Map: if !scalar.CanParse(t.Key()) { return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into the key type %v", t, t.Elem()) @@ -73,7 +69,7 @@ func kindOf(t reflect.Type) (kind, error) { if !scalar.CanParse(t.Elem()) { return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into the value type %v", t, t.Elem()) } - return mapping, nil + return multiple, nil default: return unsupported, fmt.Errorf("cannot parse into %v", t) } diff --git a/reflect_test.go b/reflect_test.go index 6a8af49..d7a5492 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -7,15 +7,15 @@ import ( "github.com/stretchr/testify/assert" ) -func assertKind(t *testing.T, typ reflect.Type, expected kind) { - actual, err := kindOf(typ) - assert.Equal(t, expected, actual, "expected %v to have kind %v but got %v", typ, expected, actual) +func assertCardinality(t *testing.T, typ reflect.Type, expected cardinality) { + actual, err := cardinalityOf(typ) + assert.Equal(t, expected, actual, "expected %v to have cardinality %v but got %v", typ, expected, actual) if expected == unsupported { assert.Error(t, err) } } -func TestKindOf(t *testing.T) { +func TestCardinalityOf(t *testing.T) { var b bool var i int var s string @@ -27,31 +27,31 @@ func TestKindOf(t *testing.T) { var unsupported2 []struct{} var unsupported3 map[string]struct{} - assertKind(t, reflect.TypeOf(b), binary) - assertKind(t, reflect.TypeOf(i), regular) - assertKind(t, reflect.TypeOf(s), regular) - assertKind(t, reflect.TypeOf(f), regular) + assertCardinality(t, reflect.TypeOf(b), zero) + assertCardinality(t, reflect.TypeOf(i), one) + assertCardinality(t, reflect.TypeOf(s), one) + assertCardinality(t, reflect.TypeOf(f), one) - assertKind(t, reflect.TypeOf(&b), binary) - assertKind(t, reflect.TypeOf(&s), regular) - assertKind(t, reflect.TypeOf(&i), regular) - assertKind(t, reflect.TypeOf(&f), regular) + assertCardinality(t, reflect.TypeOf(&b), zero) + assertCardinality(t, reflect.TypeOf(&s), one) + assertCardinality(t, reflect.TypeOf(&i), one) + assertCardinality(t, reflect.TypeOf(&f), one) - assertKind(t, reflect.TypeOf(bs), sequence) - assertKind(t, reflect.TypeOf(is), sequence) + assertCardinality(t, reflect.TypeOf(bs), multiple) + assertCardinality(t, reflect.TypeOf(is), multiple) - assertKind(t, reflect.TypeOf(&bs), sequence) - assertKind(t, reflect.TypeOf(&is), sequence) + assertCardinality(t, reflect.TypeOf(&bs), multiple) + assertCardinality(t, reflect.TypeOf(&is), multiple) - assertKind(t, reflect.TypeOf(m), mapping) - assertKind(t, reflect.TypeOf(&m), mapping) + assertCardinality(t, reflect.TypeOf(m), multiple) + assertCardinality(t, reflect.TypeOf(&m), multiple) - assertKind(t, reflect.TypeOf(unsupported1), unsupported) - assertKind(t, reflect.TypeOf(&unsupported1), unsupported) - assertKind(t, reflect.TypeOf(unsupported2), unsupported) - assertKind(t, reflect.TypeOf(&unsupported2), unsupported) - assertKind(t, reflect.TypeOf(unsupported3), unsupported) - assertKind(t, reflect.TypeOf(&unsupported3), unsupported) + assertCardinality(t, reflect.TypeOf(unsupported1), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported1), unsupported) + assertCardinality(t, reflect.TypeOf(unsupported2), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported2), unsupported) + assertCardinality(t, reflect.TypeOf(unsupported3), unsupported) + assertCardinality(t, reflect.TypeOf(&unsupported3), unsupported) } type implementsTextUnmarshaler struct{} @@ -60,16 +60,16 @@ func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error { return nil } -func TestCanParseTextUnmarshaler(t *testing.T) { +func TestCardinalityTextUnmarshaler(t *testing.T) { var x implementsTextUnmarshaler var s []implementsTextUnmarshaler var m []implementsTextUnmarshaler - assertKind(t, reflect.TypeOf(x), regular) - assertKind(t, reflect.TypeOf(&x), regular) - assertKind(t, reflect.TypeOf(s), sequence) - assertKind(t, reflect.TypeOf(&s), sequence) - assertKind(t, reflect.TypeOf(m), mapping) - assertKind(t, reflect.TypeOf(&m), mapping) + assertCardinality(t, reflect.TypeOf(x), one) + assertCardinality(t, reflect.TypeOf(&x), one) + assertCardinality(t, reflect.TypeOf(s), multiple) + assertCardinality(t, reflect.TypeOf(&s), multiple) + assertCardinality(t, reflect.TypeOf(m), multiple) + assertCardinality(t, reflect.TypeOf(&m), multiple) } func TestIsExported(t *testing.T) { diff --git a/sequence.go b/sequence.go index 8971341..35a3614 100644 --- a/sequence.go +++ b/sequence.go @@ -8,13 +8,32 @@ import ( scalar "github.com/alexflint/go-scalar" ) -// setSlice parses a sequence of strings and inserts them into a slice. If clear -// is true then any values already in the slice are removed. -func setSlice(dest reflect.Value, values []string, clear bool) error { +// setSliceOrMap parses a sequence of strings into a slice or map. If clear is +// true then any values already in the slice or map are first removed. +func setSliceOrMap(dest reflect.Value, values []string, clear bool) error { if !dest.CanSet() { return fmt.Errorf("field is not writable") } + t := dest.Type() + if t.Kind() == reflect.Ptr { + dest = dest.Elem() + t = t.Elem() + } + + switch t.Kind() { + case reflect.Slice: + return setSlice(dest, values, clear) + case reflect.Map: + return setMap(dest, values, clear) + default: + return fmt.Errorf("setSliceOrMap cannot insert values into a %v", t) + } +} + +// setSlice parses a sequence of strings and inserts them into a slice. If clear +// is true then any values already in the slice are removed. +func setSlice(dest reflect.Value, values []string, clear bool) error { var ptr bool elem := dest.Type().Elem() if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { @@ -44,10 +63,6 @@ func setSlice(dest reflect.Value, values []string, clear bool) error { // setMap parses a sequence of name=value strings and inserts them into a map. // If clear is true then any values already in the map are removed. func setMap(dest reflect.Value, values []string, clear bool) error { - if !dest.CanSet() { - return fmt.Errorf("field is not writable") - } - // determine the key and value type var keyIsPtr bool keyType := dest.Type().Key() diff --git a/usage.go b/usage.go index cbbb021..231476b 100644 --- a/usage.go +++ b/usage.go @@ -95,7 +95,7 @@ func (p *Parser) writeUsageForCommand(w io.Writer, cmd *command) { for _, spec := range positionals { // prefix with a space fmt.Fprint(w, " ") - if spec.multiple { + if spec.cardinality == multiple { if !spec.required { fmt.Fprint(w, "[") } @@ -213,16 +213,16 @@ func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) { // write the list of built in options p.printOption(w, &spec{ - boolean: true, - long: "help", - short: "h", - help: "display this help and exit", + cardinality: zero, + long: "help", + short: "h", + help: "display this help and exit", }) if p.version != "" { p.printOption(w, &spec{ - boolean: true, - long: "version", - help: "display version and exit", + cardinality: zero, + long: "version", + help: "display version and exit", }) } @@ -249,7 +249,7 @@ func (p *Parser) printOption(w io.Writer, spec *spec) { } func synopsis(spec *spec, form string) string { - if spec.boolean { + if spec.cardinality == zero { return form } return form + " " + spec.placeholder