change "kind" to "cardinality", add support for maps to parser

This commit is contained in:
Alex Flint 2021-04-19 13:21:04 -07:00
parent 23b96d7aac
commit 9949860eb3
6 changed files with 120 additions and 90 deletions

View File

@ -50,13 +50,12 @@ type spec struct {
field reflect.StructField // the struct field from which this option was created field reflect.StructField // the struct field from which this option was created
long string // the --long form for this option, or empty if none long string // the --long form for this option, or empty if none
short string // the -s short form for this option, or empty if none short string // the -s short form for this option, or empty if none
multiple bool cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple)
required bool required bool // if true, this option must be present on the command line
positional bool positional bool // if true, this option will be looked for in the positional flags
separate bool separate bool // if true,
help string help string
env string env string
boolean bool
defaultVal string // default value for this option defaultVal string // default value for this option
placeholder string // name of the data in help placeholder string // name of the data in help
} }
@ -376,15 +375,15 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
if !isSubcommand { if !isSubcommand {
cmd.specs = append(cmd.specs, &spec) cmd.specs = append(cmd.specs, &spec)
var parseable bool var err error
//parseable, spec.boolean, spec.multiple = canParse(field.Type) spec.cardinality, err = cardinalityOf(field.Type)
if !parseable { if err != nil {
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
t.Name(), field.Name, field.Type.String())) t.Name(), field.Name, field.Type.String()))
return false return false
} }
if spec.multiple && hasDefault { if spec.cardinality == multiple && hasDefault {
errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields", errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields",
t.Name(), field.Name)) t.Name(), field.Name))
return false return false
} }
@ -442,7 +441,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
continue continue
} }
if spec.multiple { if spec.cardinality == multiple {
// expect a CSV string in an environment // expect a CSV string in an environment
// variable in the case of multiple values // variable in the case of multiple values
values, err := csv.NewReader(strings.NewReader(value)).Read() values, err := csv.NewReader(strings.NewReader(value)).Read()
@ -453,7 +452,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
err, err,
) )
} }
if err = setSlice(p.val(spec.dest), values, !spec.separate); err != nil { if err = setSliceOrMap(p.val(spec.dest), values, !spec.separate); err != nil {
return fmt.Errorf( return fmt.Errorf(
"error processing environment variable %s with multiple values: %v", "error processing environment variable %s with multiple values: %v",
spec.env, spec.env,
@ -563,7 +562,7 @@ func (p *Parser) process(args []string) error {
wasPresent[spec] = true wasPresent[spec] = true
// deal with the case of multiple values // deal with the case of multiple values
if spec.multiple { if spec.cardinality == multiple {
var values []string var values []string
if value == "" { if value == "" {
for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" { for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" {
@ -576,7 +575,7 @@ func (p *Parser) process(args []string) error {
} else { } else {
values = append(values, value) values = append(values, value)
} }
err := setSlice(p.val(spec.dest), values, !spec.separate) err := setSliceOrMap(p.val(spec.dest), values, !spec.separate)
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err) return fmt.Errorf("error processing %s: %v", arg, err)
} }
@ -585,7 +584,7 @@ func (p *Parser) process(args []string) error {
// if it's a flag and it has no value then set the value to true // if it's a flag and it has no value then set the value to true
// use boolean because this takes account of TextUnmarshaler // use boolean because this takes account of TextUnmarshaler
if spec.boolean && value == "" { if spec.cardinality == zero && value == "" {
value = "true" value = "true"
} }
@ -616,8 +615,8 @@ func (p *Parser) process(args []string) error {
break break
} }
wasPresent[spec] = true wasPresent[spec] = true
if spec.multiple { if spec.cardinality == multiple {
err := setSlice(p.val(spec.dest), positionals, true) err := setSliceOrMap(p.val(spec.dest), positionals, true)
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", spec.field.Name, err) return fmt.Errorf("error processing %s: %v", spec.field.Name, err)
} }

View File

@ -220,6 +220,14 @@ func TestLongFlag(t *testing.T) {
assert.Equal(t, "xyz", args.Foo) assert.Equal(t, "xyz", args.Foo)
} }
func TestSlice(t *testing.T) {
var args struct {
Strings []string
}
err := parse("--strings a b c", &args)
require.NoError(t, err)
assert.Equal(t, []string{"a", "b", "c"}, args.Strings)
}
func TestSliceOfBools(t *testing.T) { func TestSliceOfBools(t *testing.T) {
var args struct { var args struct {
B []bool B []bool
@ -230,6 +238,18 @@ func TestSliceOfBools(t *testing.T) {
assert.Equal(t, []bool{true, false, true}, args.B) assert.Equal(t, []bool{true, false, true}, args.B)
} }
func TestMap(t *testing.T) {
var args struct {
Values map[string]int
}
err := parse("--values a=1 b=2 c=3", &args)
require.NoError(t, err)
assert.Len(t, args.Values, 3)
assert.Equal(t, 1, args.Values["a"])
assert.Equal(t, 2, args.Values["b"])
assert.Equal(t, 3, args.Values["c"])
}
func TestPlaceholder(t *testing.T) { func TestPlaceholder(t *testing.T) {
var args struct { var args struct {
Input string `arg:"positional" placeholder:"SRC"` Input string `arg:"positional" placeholder:"SRC"`
@ -1233,7 +1253,7 @@ func TestDefaultValuesNotAllowedWithSlice(t *testing.T) {
} }
err := parse("", &args) err := parse("", &args)
assert.EqualError(t, err, ".A: default values are not supported for slice fields") assert.EqualError(t, err, ".A: default values are not supported for slice or map fields")
} }
func TestUnexportedFieldsSkipped(t *testing.T) { func TestUnexportedFieldsSkipped(t *testing.T) {

View File

@ -12,31 +12,27 @@ import (
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
// kind is used to track the various kinds of options: // cardinality tracks how many tokens are expected for a given spec
// - regular is an ordinary option that will be parsed from a single token // - zero is a boolean, which does to expect any value
// - binary is an option that will be true if present but does not expect an explicit value // - one is an ordinary option that will be parsed from a single token
// - sequence is an option that accepts multiple values and will end up in a slice // - multiple is a slice or map that can accept zero or more tokens
// - mapping is an option that acccepts multiple key=value strings and will end up in a map type cardinality int
type kind int
const ( const (
regular kind = iota zero cardinality = iota
binary one
sequence multiple
mapping
unsupported unsupported
) )
func (k kind) String() string { func (k cardinality) String() string {
switch k { switch k {
case regular: case zero:
return "regular" return "zero"
case binary: case one:
return "binary" return "one"
case sequence: case multiple:
return "sequence" return "multiple"
case mapping:
return "mapping"
case unsupported: case unsupported:
return "unsupported" return "unsupported"
default: default:
@ -44,13 +40,13 @@ func (k kind) String() string {
} }
} }
// kindOf returns true if the type can be parsed from a string // cardinalityOf returns true if the type can be parsed from a string
func kindOf(t reflect.Type) (kind, error) { func cardinalityOf(t reflect.Type) (cardinality, error) {
if scalar.CanParse(t) { if scalar.CanParse(t) {
if isBoolean(t) { if isBoolean(t) {
return binary, nil return zero, nil
} else { } else {
return regular, nil return one, nil
} }
} }
@ -65,7 +61,7 @@ func kindOf(t reflect.Type) (kind, error) {
if !scalar.CanParse(t.Elem()) { if !scalar.CanParse(t.Elem()) {
return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into %v", t, t.Elem()) return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into %v", t, t.Elem())
} }
return sequence, nil return multiple, nil
case reflect.Map: case reflect.Map:
if !scalar.CanParse(t.Key()) { if !scalar.CanParse(t.Key()) {
return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into the key type %v", t, t.Elem()) return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into the key type %v", t, t.Elem())
@ -73,7 +69,7 @@ func kindOf(t reflect.Type) (kind, error) {
if !scalar.CanParse(t.Elem()) { if !scalar.CanParse(t.Elem()) {
return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into the value type %v", t, t.Elem()) return unsupported, fmt.Errorf("cannot parse into %v because we cannot parse into the value type %v", t, t.Elem())
} }
return mapping, nil return multiple, nil
default: default:
return unsupported, fmt.Errorf("cannot parse into %v", t) return unsupported, fmt.Errorf("cannot parse into %v", t)
} }

View File

@ -7,15 +7,15 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func assertKind(t *testing.T, typ reflect.Type, expected kind) { func assertCardinality(t *testing.T, typ reflect.Type, expected cardinality) {
actual, err := kindOf(typ) actual, err := cardinalityOf(typ)
assert.Equal(t, expected, actual, "expected %v to have kind %v but got %v", typ, expected, actual) assert.Equal(t, expected, actual, "expected %v to have cardinality %v but got %v", typ, expected, actual)
if expected == unsupported { if expected == unsupported {
assert.Error(t, err) assert.Error(t, err)
} }
} }
func TestKindOf(t *testing.T) { func TestCardinalityOf(t *testing.T) {
var b bool var b bool
var i int var i int
var s string var s string
@ -27,31 +27,31 @@ func TestKindOf(t *testing.T) {
var unsupported2 []struct{} var unsupported2 []struct{}
var unsupported3 map[string]struct{} var unsupported3 map[string]struct{}
assertKind(t, reflect.TypeOf(b), binary) assertCardinality(t, reflect.TypeOf(b), zero)
assertKind(t, reflect.TypeOf(i), regular) assertCardinality(t, reflect.TypeOf(i), one)
assertKind(t, reflect.TypeOf(s), regular) assertCardinality(t, reflect.TypeOf(s), one)
assertKind(t, reflect.TypeOf(f), regular) assertCardinality(t, reflect.TypeOf(f), one)
assertKind(t, reflect.TypeOf(&b), binary) assertCardinality(t, reflect.TypeOf(&b), zero)
assertKind(t, reflect.TypeOf(&s), regular) assertCardinality(t, reflect.TypeOf(&s), one)
assertKind(t, reflect.TypeOf(&i), regular) assertCardinality(t, reflect.TypeOf(&i), one)
assertKind(t, reflect.TypeOf(&f), regular) assertCardinality(t, reflect.TypeOf(&f), one)
assertKind(t, reflect.TypeOf(bs), sequence) assertCardinality(t, reflect.TypeOf(bs), multiple)
assertKind(t, reflect.TypeOf(is), sequence) assertCardinality(t, reflect.TypeOf(is), multiple)
assertKind(t, reflect.TypeOf(&bs), sequence) assertCardinality(t, reflect.TypeOf(&bs), multiple)
assertKind(t, reflect.TypeOf(&is), sequence) assertCardinality(t, reflect.TypeOf(&is), multiple)
assertKind(t, reflect.TypeOf(m), mapping) assertCardinality(t, reflect.TypeOf(m), multiple)
assertKind(t, reflect.TypeOf(&m), mapping) assertCardinality(t, reflect.TypeOf(&m), multiple)
assertKind(t, reflect.TypeOf(unsupported1), unsupported) assertCardinality(t, reflect.TypeOf(unsupported1), unsupported)
assertKind(t, reflect.TypeOf(&unsupported1), unsupported) assertCardinality(t, reflect.TypeOf(&unsupported1), unsupported)
assertKind(t, reflect.TypeOf(unsupported2), unsupported) assertCardinality(t, reflect.TypeOf(unsupported2), unsupported)
assertKind(t, reflect.TypeOf(&unsupported2), unsupported) assertCardinality(t, reflect.TypeOf(&unsupported2), unsupported)
assertKind(t, reflect.TypeOf(unsupported3), unsupported) assertCardinality(t, reflect.TypeOf(unsupported3), unsupported)
assertKind(t, reflect.TypeOf(&unsupported3), unsupported) assertCardinality(t, reflect.TypeOf(&unsupported3), unsupported)
} }
type implementsTextUnmarshaler struct{} type implementsTextUnmarshaler struct{}
@ -60,16 +60,16 @@ func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error {
return nil return nil
} }
func TestCanParseTextUnmarshaler(t *testing.T) { func TestCardinalityTextUnmarshaler(t *testing.T) {
var x implementsTextUnmarshaler var x implementsTextUnmarshaler
var s []implementsTextUnmarshaler var s []implementsTextUnmarshaler
var m []implementsTextUnmarshaler var m []implementsTextUnmarshaler
assertKind(t, reflect.TypeOf(x), regular) assertCardinality(t, reflect.TypeOf(x), one)
assertKind(t, reflect.TypeOf(&x), regular) assertCardinality(t, reflect.TypeOf(&x), one)
assertKind(t, reflect.TypeOf(s), sequence) assertCardinality(t, reflect.TypeOf(s), multiple)
assertKind(t, reflect.TypeOf(&s), sequence) assertCardinality(t, reflect.TypeOf(&s), multiple)
assertKind(t, reflect.TypeOf(m), mapping) assertCardinality(t, reflect.TypeOf(m), multiple)
assertKind(t, reflect.TypeOf(&m), mapping) assertCardinality(t, reflect.TypeOf(&m), multiple)
} }
func TestIsExported(t *testing.T) { func TestIsExported(t *testing.T) {

View File

@ -8,13 +8,32 @@ import (
scalar "github.com/alexflint/go-scalar" scalar "github.com/alexflint/go-scalar"
) )
// setSlice parses a sequence of strings and inserts them into a slice. If clear // setSliceOrMap parses a sequence of strings into a slice or map. If clear is
// is true then any values already in the slice are removed. // true then any values already in the slice or map are first removed.
func setSlice(dest reflect.Value, values []string, clear bool) error { func setSliceOrMap(dest reflect.Value, values []string, clear bool) error {
if !dest.CanSet() { if !dest.CanSet() {
return fmt.Errorf("field is not writable") return fmt.Errorf("field is not writable")
} }
t := dest.Type()
if t.Kind() == reflect.Ptr {
dest = dest.Elem()
t = t.Elem()
}
switch t.Kind() {
case reflect.Slice:
return setSlice(dest, values, clear)
case reflect.Map:
return setMap(dest, values, clear)
default:
return fmt.Errorf("setSliceOrMap cannot insert values into a %v", t)
}
}
// setSlice parses a sequence of strings and inserts them into a slice. If clear
// is true then any values already in the slice are removed.
func setSlice(dest reflect.Value, values []string, clear bool) error {
var ptr bool var ptr bool
elem := dest.Type().Elem() elem := dest.Type().Elem()
if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
@ -44,10 +63,6 @@ func setSlice(dest reflect.Value, values []string, clear bool) error {
// setMap parses a sequence of name=value strings and inserts them into a map. // setMap parses a sequence of name=value strings and inserts them into a map.
// If clear is true then any values already in the map are removed. // If clear is true then any values already in the map are removed.
func setMap(dest reflect.Value, values []string, clear bool) error { func setMap(dest reflect.Value, values []string, clear bool) error {
if !dest.CanSet() {
return fmt.Errorf("field is not writable")
}
// determine the key and value type // determine the key and value type
var keyIsPtr bool var keyIsPtr bool
keyType := dest.Type().Key() keyType := dest.Type().Key()

View File

@ -95,7 +95,7 @@ func (p *Parser) writeUsageForCommand(w io.Writer, cmd *command) {
for _, spec := range positionals { for _, spec := range positionals {
// prefix with a space // prefix with a space
fmt.Fprint(w, " ") fmt.Fprint(w, " ")
if spec.multiple { if spec.cardinality == multiple {
if !spec.required { if !spec.required {
fmt.Fprint(w, "[") fmt.Fprint(w, "[")
} }
@ -213,16 +213,16 @@ func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) {
// write the list of built in options // write the list of built in options
p.printOption(w, &spec{ p.printOption(w, &spec{
boolean: true, cardinality: zero,
long: "help", long: "help",
short: "h", short: "h",
help: "display this help and exit", help: "display this help and exit",
}) })
if p.version != "" { if p.version != "" {
p.printOption(w, &spec{ p.printOption(w, &spec{
boolean: true, cardinality: zero,
long: "version", long: "version",
help: "display version and exit", help: "display version and exit",
}) })
} }
@ -249,7 +249,7 @@ func (p *Parser) printOption(w io.Writer, spec *spec) {
} }
func synopsis(spec *spec, form string) string { func synopsis(spec *spec, form string) string {
if spec.boolean { if spec.cardinality == zero {
return form return form
} }
return form + " " + spec.placeholder return form + " " + spec.placeholder