Merge pull request #149 from alexflint/parse-into-map
Add support for parsing into a map
This commit is contained in:
commit
6a01a15f75
15
README.md
15
README.md
|
@ -191,6 +191,7 @@ var args struct {
|
|||
Files []string `arg:"-f,separate"`
|
||||
Databases []string `arg:"positional"`
|
||||
}
|
||||
arg.MustParse(&args)
|
||||
```
|
||||
|
||||
```shell
|
||||
|
@ -200,6 +201,20 @@ Files [file1 file2 file3]
|
|||
Databases [db1 db2 db3]
|
||||
```
|
||||
|
||||
### Arguments with keys and values
|
||||
```go
|
||||
var args struct {
|
||||
UserIDs map[string]int
|
||||
}
|
||||
arg.MustParse(&args)
|
||||
fmt.Println(args.UserIDs)
|
||||
```
|
||||
|
||||
```shell
|
||||
./example --userids john=123 mary=456
|
||||
map[john:123 mary:456]
|
||||
```
|
||||
|
||||
### Custom validation
|
||||
```go
|
||||
var args struct {
|
||||
|
|
|
@ -82,6 +82,19 @@ func Example_multipleValues() {
|
|||
// output: Fetching the following IDs from localhost: [1 2 3]
|
||||
}
|
||||
|
||||
// This example demonstrates arguments with keys and values
|
||||
func Example_mappings() {
|
||||
// The args you would pass in on the command line
|
||||
os.Args = split("./example --userids john=123 mary=456")
|
||||
|
||||
var args struct {
|
||||
UserIDs map[string]int
|
||||
}
|
||||
MustParse(&args)
|
||||
fmt.Println(args.UserIDs)
|
||||
// output: map[john:123 mary:456]
|
||||
}
|
||||
|
||||
// This eample demonstrates multiple value arguments that can be mixed with
|
||||
// other arguments.
|
||||
func Example_multipleMixed() {
|
||||
|
|
74
parse.go
74
parse.go
|
@ -50,15 +50,14 @@ type spec struct {
|
|||
field reflect.StructField // the struct field from which this option was created
|
||||
long string // the --long form for this option, or empty if none
|
||||
short string // the -s short form for this option, or empty if none
|
||||
multiple bool
|
||||
required bool
|
||||
positional bool
|
||||
separate bool
|
||||
help string
|
||||
env string
|
||||
boolean bool
|
||||
defaultVal string // default value for this option
|
||||
placeholder string // name of the data in help
|
||||
cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple)
|
||||
required bool // if true, this option must be present on the command line
|
||||
positional bool // if true, this option will be looked for in the positional flags
|
||||
separate bool // if true, each slice and map entry will have its own --flag
|
||||
help string // the help text for this option
|
||||
env string // the name of the environment variable for this option, or empty for none
|
||||
defaultVal string // default value for this option
|
||||
placeholder string // name of the data in help
|
||||
}
|
||||
|
||||
// command represents a named subcommand, or the top-level command
|
||||
|
@ -376,15 +375,15 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
|
|||
if !isSubcommand {
|
||||
cmd.specs = append(cmd.specs, &spec)
|
||||
|
||||
var parseable bool
|
||||
parseable, spec.boolean, spec.multiple = canParse(field.Type)
|
||||
if !parseable {
|
||||
var err error
|
||||
spec.cardinality, err = cardinalityOf(field.Type)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
|
||||
t.Name(), field.Name, field.Type.String()))
|
||||
return false
|
||||
}
|
||||
if spec.multiple && hasDefault {
|
||||
errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields",
|
||||
if spec.cardinality == multiple && hasDefault {
|
||||
errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields",
|
||||
t.Name(), field.Name))
|
||||
return false
|
||||
}
|
||||
|
@ -442,7 +441,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
|
|||
continue
|
||||
}
|
||||
|
||||
if spec.multiple {
|
||||
if spec.cardinality == multiple {
|
||||
// expect a CSV string in an environment
|
||||
// variable in the case of multiple values
|
||||
values, err := csv.NewReader(strings.NewReader(value)).Read()
|
||||
|
@ -453,7 +452,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
|
|||
err,
|
||||
)
|
||||
}
|
||||
if err = setSlice(p.val(spec.dest), values, !spec.separate); err != nil {
|
||||
if err = setSliceOrMap(p.val(spec.dest), values, !spec.separate); err != nil {
|
||||
return fmt.Errorf(
|
||||
"error processing environment variable %s with multiple values: %v",
|
||||
spec.env,
|
||||
|
@ -563,7 +562,7 @@ func (p *Parser) process(args []string) error {
|
|||
wasPresent[spec] = true
|
||||
|
||||
// deal with the case of multiple values
|
||||
if spec.multiple {
|
||||
if spec.cardinality == multiple {
|
||||
var values []string
|
||||
if value == "" {
|
||||
for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" {
|
||||
|
@ -576,7 +575,7 @@ func (p *Parser) process(args []string) error {
|
|||
} else {
|
||||
values = append(values, value)
|
||||
}
|
||||
err := setSlice(p.val(spec.dest), values, !spec.separate)
|
||||
err := setSliceOrMap(p.val(spec.dest), values, !spec.separate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error processing %s: %v", arg, err)
|
||||
}
|
||||
|
@ -585,7 +584,7 @@ func (p *Parser) process(args []string) error {
|
|||
|
||||
// if it's a flag and it has no value then set the value to true
|
||||
// use boolean because this takes account of TextUnmarshaler
|
||||
if spec.boolean && value == "" {
|
||||
if spec.cardinality == zero && value == "" {
|
||||
value = "true"
|
||||
}
|
||||
|
||||
|
@ -616,8 +615,8 @@ func (p *Parser) process(args []string) error {
|
|||
break
|
||||
}
|
||||
wasPresent[spec] = true
|
||||
if spec.multiple {
|
||||
err := setSlice(p.val(spec.dest), positionals, true)
|
||||
if spec.cardinality == multiple {
|
||||
err := setSliceOrMap(p.val(spec.dest), positionals, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error processing %s: %v", spec.field.Name, err)
|
||||
}
|
||||
|
@ -702,37 +701,6 @@ func (p *Parser) val(dest path) reflect.Value {
|
|||
return v
|
||||
}
|
||||
|
||||
// parse a value as the appropriate type and store it in the struct
|
||||
func setSlice(dest reflect.Value, values []string, trunc bool) error {
|
||||
if !dest.CanSet() {
|
||||
return fmt.Errorf("field is not writable")
|
||||
}
|
||||
|
||||
var ptr bool
|
||||
elem := dest.Type().Elem()
|
||||
if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
|
||||
ptr = true
|
||||
elem = elem.Elem()
|
||||
}
|
||||
|
||||
// Truncate the dest slice in case default values exist
|
||||
if trunc && !dest.IsNil() {
|
||||
dest.SetLen(0)
|
||||
}
|
||||
|
||||
for _, s := range values {
|
||||
v := reflect.New(elem)
|
||||
if err := scalar.ParseValue(v.Elem(), s); err != nil {
|
||||
return err
|
||||
}
|
||||
if !ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
dest.Set(reflect.Append(dest, v))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// findOption finds an option from its name, or returns null if no spec is found
|
||||
func findOption(specs []*spec, name string) *spec {
|
||||
for _, spec := range specs {
|
||||
|
@ -759,7 +727,7 @@ func findSubcommand(cmds []*command, name string) *command {
|
|||
// isZero returns true if v contains the zero value for its type
|
||||
func isZero(v reflect.Value) bool {
|
||||
t := v.Type()
|
||||
if t.Kind() == reflect.Slice {
|
||||
if t.Kind() == reflect.Slice || t.Kind() == reflect.Map {
|
||||
return v.IsNil()
|
||||
}
|
||||
if !t.Comparable() {
|
||||
|
|
|
@ -220,6 +220,60 @@ func TestLongFlag(t *testing.T) {
|
|||
assert.Equal(t, "xyz", args.Foo)
|
||||
}
|
||||
|
||||
func TestSlice(t *testing.T) {
|
||||
var args struct {
|
||||
Strings []string
|
||||
}
|
||||
err := parse("--strings a b c", &args)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"a", "b", "c"}, args.Strings)
|
||||
}
|
||||
func TestSliceOfBools(t *testing.T) {
|
||||
var args struct {
|
||||
B []bool
|
||||
}
|
||||
|
||||
err := parse("--b true false true", &args)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []bool{true, false, true}, args.B)
|
||||
}
|
||||
|
||||
func TestMap(t *testing.T) {
|
||||
var args struct {
|
||||
Values map[string]int
|
||||
}
|
||||
err := parse("--values a=1 b=2 c=3", &args)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, args.Values, 3)
|
||||
assert.Equal(t, 1, args.Values["a"])
|
||||
assert.Equal(t, 2, args.Values["b"])
|
||||
assert.Equal(t, 3, args.Values["c"])
|
||||
}
|
||||
|
||||
func TestMapPositional(t *testing.T) {
|
||||
var args struct {
|
||||
Values map[string]int `arg:"positional"`
|
||||
}
|
||||
err := parse("a=1 b=2 c=3", &args)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, args.Values, 3)
|
||||
assert.Equal(t, 1, args.Values["a"])
|
||||
assert.Equal(t, 2, args.Values["b"])
|
||||
assert.Equal(t, 3, args.Values["c"])
|
||||
}
|
||||
|
||||
func TestMapWithSeparate(t *testing.T) {
|
||||
var args struct {
|
||||
Values map[string]int `arg:"separate"`
|
||||
}
|
||||
err := parse("--values a=1 --values b=2 --values c=3", &args)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, args.Values, 3)
|
||||
assert.Equal(t, 1, args.Values["a"])
|
||||
assert.Equal(t, 2, args.Values["b"])
|
||||
assert.Equal(t, 3, args.Values["c"])
|
||||
}
|
||||
|
||||
func TestPlaceholder(t *testing.T) {
|
||||
var args struct {
|
||||
Input string `arg:"positional" placeholder:"SRC"`
|
||||
|
@ -688,6 +742,17 @@ func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestEnvironmentVariableMap(t *testing.T) {
|
||||
var args struct {
|
||||
Foo map[int]string `arg:"env"`
|
||||
}
|
||||
setenv(t, "FOO", "1=one,99=ninetynine")
|
||||
MustParse(&args)
|
||||
assert.Len(t, args.Foo, 2)
|
||||
assert.Equal(t, "one", args.Foo[1])
|
||||
assert.Equal(t, "ninetynine", args.Foo[99])
|
||||
}
|
||||
|
||||
func TestEnvironmentVariableIgnored(t *testing.T) {
|
||||
var args struct {
|
||||
Foo string `arg:"env"`
|
||||
|
@ -1223,7 +1288,7 @@ func TestDefaultValuesNotAllowedWithSlice(t *testing.T) {
|
|||
}
|
||||
|
||||
err := parse("", &args)
|
||||
assert.EqualError(t, err, ".A: default values are not supported for slice fields")
|
||||
assert.EqualError(t, err, ".A: default values are not supported for slice or map fields")
|
||||
}
|
||||
|
||||
func TestUnexportedFieldsSkipped(t *testing.T) {
|
||||
|
|
84
reflect.go
84
reflect.go
|
@ -2,6 +2,7 @@ package arg
|
|||
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
@ -11,42 +12,67 @@ import (
|
|||
|
||||
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
|
||||
|
||||
// canParse returns true if the type can be parsed from a string
|
||||
func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
|
||||
parseable = scalar.CanParse(t)
|
||||
boolean = isBoolean(t)
|
||||
if parseable {
|
||||
return
|
||||
// cardinality tracks how many tokens are expected for a given spec
|
||||
// - zero is a boolean, which does to expect any value
|
||||
// - one is an ordinary option that will be parsed from a single token
|
||||
// - multiple is a slice or map that can accept zero or more tokens
|
||||
type cardinality int
|
||||
|
||||
const (
|
||||
zero cardinality = iota
|
||||
one
|
||||
multiple
|
||||
unsupported
|
||||
)
|
||||
|
||||
func (k cardinality) String() string {
|
||||
switch k {
|
||||
case zero:
|
||||
return "zero"
|
||||
case one:
|
||||
return "one"
|
||||
case multiple:
|
||||
return "multiple"
|
||||
case unsupported:
|
||||
return "unsupported"
|
||||
default:
|
||||
return fmt.Sprintf("unknown(%d)", int(k))
|
||||
}
|
||||
}
|
||||
|
||||
// cardinalityOf returns true if the type can be parsed from a string
|
||||
func cardinalityOf(t reflect.Type) (cardinality, error) {
|
||||
if scalar.CanParse(t) {
|
||||
if isBoolean(t) {
|
||||
return zero, nil
|
||||
} else {
|
||||
return one, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Look inside pointer types
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
// Look inside slice types
|
||||
if t.Kind() == reflect.Slice {
|
||||
multiple = true
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
parseable = scalar.CanParse(t)
|
||||
boolean = isBoolean(t)
|
||||
if parseable {
|
||||
return
|
||||
}
|
||||
|
||||
// Look inside pointer types (again, in case of []*Type)
|
||||
// look inside pointer types
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
parseable = scalar.CanParse(t)
|
||||
boolean = isBoolean(t)
|
||||
if parseable {
|
||||
return
|
||||
// look inside slice and map types
|
||||
switch t.Kind() {
|
||||
case reflect.Slice:
|
||||
if !scalar.CanParse(t.Elem()) {
|
||||
return unsupported, fmt.Errorf("cannot parse into %v because %v not supported", t, t.Elem())
|
||||
}
|
||||
return multiple, nil
|
||||
case reflect.Map:
|
||||
if !scalar.CanParse(t.Key()) {
|
||||
return unsupported, fmt.Errorf("cannot parse into %v because key type %v not supported", t, t.Elem())
|
||||
}
|
||||
if !scalar.CanParse(t.Elem()) {
|
||||
return unsupported, fmt.Errorf("cannot parse into %v because value type %v not supported", t, t.Elem())
|
||||
}
|
||||
return multiple, nil
|
||||
default:
|
||||
return unsupported, fmt.Errorf("cannot parse into %v", t)
|
||||
}
|
||||
|
||||
return false, false, false
|
||||
}
|
||||
|
||||
// isBoolean returns true if the type can be parsed from a single string
|
||||
|
|
|
@ -7,36 +7,54 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func assertCanParse(t *testing.T, typ reflect.Type, parseable, boolean, multiple bool) {
|
||||
p, b, m := canParse(typ)
|
||||
assert.Equal(t, parseable, p, "expected %v to have parseable=%v but was %v", typ, parseable, p)
|
||||
assert.Equal(t, boolean, b, "expected %v to have boolean=%v but was %v", typ, boolean, b)
|
||||
assert.Equal(t, multiple, m, "expected %v to have multiple=%v but was %v", typ, multiple, m)
|
||||
func assertCardinality(t *testing.T, typ reflect.Type, expected cardinality) {
|
||||
actual, err := cardinalityOf(typ)
|
||||
assert.Equal(t, expected, actual, "expected %v to have cardinality %v but got %v", typ, expected, actual)
|
||||
if expected == unsupported {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanParse(t *testing.T) {
|
||||
func TestCardinalityOf(t *testing.T) {
|
||||
var b bool
|
||||
var i int
|
||||
var s string
|
||||
var f float64
|
||||
var bs []bool
|
||||
var is []int
|
||||
var m map[string]int
|
||||
var unsupported1 struct{}
|
||||
var unsupported2 []struct{}
|
||||
var unsupported3 map[string]struct{}
|
||||
var unsupported4 map[struct{}]string
|
||||
|
||||
assertCanParse(t, reflect.TypeOf(b), true, true, false)
|
||||
assertCanParse(t, reflect.TypeOf(i), true, false, false)
|
||||
assertCanParse(t, reflect.TypeOf(s), true, false, false)
|
||||
assertCanParse(t, reflect.TypeOf(f), true, false, false)
|
||||
assertCardinality(t, reflect.TypeOf(b), zero)
|
||||
assertCardinality(t, reflect.TypeOf(i), one)
|
||||
assertCardinality(t, reflect.TypeOf(s), one)
|
||||
assertCardinality(t, reflect.TypeOf(f), one)
|
||||
|
||||
assertCanParse(t, reflect.TypeOf(&b), true, true, false)
|
||||
assertCanParse(t, reflect.TypeOf(&s), true, false, false)
|
||||
assertCanParse(t, reflect.TypeOf(&i), true, false, false)
|
||||
assertCanParse(t, reflect.TypeOf(&f), true, false, false)
|
||||
assertCardinality(t, reflect.TypeOf(&b), zero)
|
||||
assertCardinality(t, reflect.TypeOf(&s), one)
|
||||
assertCardinality(t, reflect.TypeOf(&i), one)
|
||||
assertCardinality(t, reflect.TypeOf(&f), one)
|
||||
|
||||
assertCanParse(t, reflect.TypeOf(bs), true, true, true)
|
||||
assertCanParse(t, reflect.TypeOf(&bs), true, true, true)
|
||||
assertCardinality(t, reflect.TypeOf(bs), multiple)
|
||||
assertCardinality(t, reflect.TypeOf(is), multiple)
|
||||
|
||||
assertCanParse(t, reflect.TypeOf(is), true, false, true)
|
||||
assertCanParse(t, reflect.TypeOf(&is), true, false, true)
|
||||
assertCardinality(t, reflect.TypeOf(&bs), multiple)
|
||||
assertCardinality(t, reflect.TypeOf(&is), multiple)
|
||||
|
||||
assertCardinality(t, reflect.TypeOf(m), multiple)
|
||||
assertCardinality(t, reflect.TypeOf(&m), multiple)
|
||||
|
||||
assertCardinality(t, reflect.TypeOf(unsupported1), unsupported)
|
||||
assertCardinality(t, reflect.TypeOf(&unsupported1), unsupported)
|
||||
assertCardinality(t, reflect.TypeOf(unsupported2), unsupported)
|
||||
assertCardinality(t, reflect.TypeOf(&unsupported2), unsupported)
|
||||
assertCardinality(t, reflect.TypeOf(unsupported3), unsupported)
|
||||
assertCardinality(t, reflect.TypeOf(&unsupported3), unsupported)
|
||||
assertCardinality(t, reflect.TypeOf(unsupported4), unsupported)
|
||||
assertCardinality(t, reflect.TypeOf(&unsupported4), unsupported)
|
||||
}
|
||||
|
||||
type implementsTextUnmarshaler struct{}
|
||||
|
@ -45,13 +63,16 @@ func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func TestCanParseTextUnmarshaler(t *testing.T) {
|
||||
var u implementsTextUnmarshaler
|
||||
var su []implementsTextUnmarshaler
|
||||
assertCanParse(t, reflect.TypeOf(u), true, false, false)
|
||||
assertCanParse(t, reflect.TypeOf(&u), true, false, false)
|
||||
assertCanParse(t, reflect.TypeOf(su), true, false, true)
|
||||
assertCanParse(t, reflect.TypeOf(&su), true, false, true)
|
||||
func TestCardinalityTextUnmarshaler(t *testing.T) {
|
||||
var x implementsTextUnmarshaler
|
||||
var s []implementsTextUnmarshaler
|
||||
var m []implementsTextUnmarshaler
|
||||
assertCardinality(t, reflect.TypeOf(x), one)
|
||||
assertCardinality(t, reflect.TypeOf(&x), one)
|
||||
assertCardinality(t, reflect.TypeOf(s), multiple)
|
||||
assertCardinality(t, reflect.TypeOf(&s), multiple)
|
||||
assertCardinality(t, reflect.TypeOf(m), multiple)
|
||||
assertCardinality(t, reflect.TypeOf(&m), multiple)
|
||||
}
|
||||
|
||||
func TestIsExported(t *testing.T) {
|
||||
|
@ -60,3 +81,11 @@ func TestIsExported(t *testing.T) {
|
|||
assert.False(t, isExported(""))
|
||||
assert.False(t, isExported(string([]byte{255})))
|
||||
}
|
||||
|
||||
func TestCardinalityString(t *testing.T) {
|
||||
assert.Equal(t, "zero", zero.String())
|
||||
assert.Equal(t, "one", one.String())
|
||||
assert.Equal(t, "multiple", multiple.String())
|
||||
assert.Equal(t, "unsupported", unsupported.String())
|
||||
assert.Equal(t, "unknown(42)", cardinality(42).String())
|
||||
}
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
package arg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
scalar "github.com/alexflint/go-scalar"
|
||||
)
|
||||
|
||||
// setSliceOrMap parses a sequence of strings into a slice or map. If clear is
|
||||
// true then any values already in the slice or map are first removed.
|
||||
func setSliceOrMap(dest reflect.Value, values []string, clear bool) error {
|
||||
if !dest.CanSet() {
|
||||
return fmt.Errorf("field is not writable")
|
||||
}
|
||||
|
||||
t := dest.Type()
|
||||
if t.Kind() == reflect.Ptr {
|
||||
dest = dest.Elem()
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Slice:
|
||||
return setSlice(dest, values, clear)
|
||||
case reflect.Map:
|
||||
return setMap(dest, values, clear)
|
||||
default:
|
||||
return fmt.Errorf("setSliceOrMap cannot insert values into a %v", t)
|
||||
}
|
||||
}
|
||||
|
||||
// setSlice parses a sequence of strings and inserts them into a slice. If clear
|
||||
// is true then any values already in the slice are removed.
|
||||
func setSlice(dest reflect.Value, values []string, clear bool) error {
|
||||
var ptr bool
|
||||
elem := dest.Type().Elem()
|
||||
if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
|
||||
ptr = true
|
||||
elem = elem.Elem()
|
||||
}
|
||||
|
||||
// clear the slice in case default values exist
|
||||
if clear && !dest.IsNil() {
|
||||
dest.SetLen(0)
|
||||
}
|
||||
|
||||
// parse the values one-by-one
|
||||
for _, s := range values {
|
||||
v := reflect.New(elem)
|
||||
if err := scalar.ParseValue(v.Elem(), s); err != nil {
|
||||
return err
|
||||
}
|
||||
if !ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
dest.Set(reflect.Append(dest, v))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setMap parses a sequence of name=value strings and inserts them into a map.
|
||||
// If clear is true then any values already in the map are removed.
|
||||
func setMap(dest reflect.Value, values []string, clear bool) error {
|
||||
// determine the key and value type
|
||||
var keyIsPtr bool
|
||||
keyType := dest.Type().Key()
|
||||
if keyType.Kind() == reflect.Ptr && !keyType.Implements(textUnmarshalerType) {
|
||||
keyIsPtr = true
|
||||
keyType = keyType.Elem()
|
||||
}
|
||||
|
||||
var valIsPtr bool
|
||||
valType := dest.Type().Elem()
|
||||
if valType.Kind() == reflect.Ptr && !valType.Implements(textUnmarshalerType) {
|
||||
valIsPtr = true
|
||||
valType = valType.Elem()
|
||||
}
|
||||
|
||||
// clear the slice in case default values exist
|
||||
if clear && !dest.IsNil() {
|
||||
for _, k := range dest.MapKeys() {
|
||||
dest.SetMapIndex(k, reflect.Value{})
|
||||
}
|
||||
}
|
||||
|
||||
// allocate the map if it is not allocated
|
||||
if dest.IsNil() {
|
||||
dest.Set(reflect.MakeMap(dest.Type()))
|
||||
}
|
||||
|
||||
// parse the values one-by-one
|
||||
for _, s := range values {
|
||||
// split at the first equals sign
|
||||
pos := strings.Index(s, "=")
|
||||
if pos == -1 {
|
||||
return fmt.Errorf("cannot parse %q into a map, expected format key=value", s)
|
||||
}
|
||||
|
||||
// parse the key
|
||||
k := reflect.New(keyType)
|
||||
if err := scalar.ParseValue(k.Elem(), s[:pos]); err != nil {
|
||||
return err
|
||||
}
|
||||
if !keyIsPtr {
|
||||
k = k.Elem()
|
||||
}
|
||||
|
||||
// parse the value
|
||||
v := reflect.New(valType)
|
||||
if err := scalar.ParseValue(v.Elem(), s[pos+1:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if !valIsPtr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
// add it to the map
|
||||
dest.SetMapIndex(k, v)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
package arg
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSetSliceWithoutClearing(t *testing.T) {
|
||||
xs := []int{10}
|
||||
entries := []string{"1", "2", "3"}
|
||||
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []int{10, 1, 2, 3}, xs)
|
||||
}
|
||||
|
||||
func TestSetSliceAfterClearing(t *testing.T) {
|
||||
xs := []int{100}
|
||||
entries := []string{"1", "2", "3"}
|
||||
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []int{1, 2, 3}, xs)
|
||||
}
|
||||
|
||||
func TestSetSliceInvalid(t *testing.T) {
|
||||
xs := []int{100}
|
||||
entries := []string{"invalid"}
|
||||
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSetSlicePtr(t *testing.T) {
|
||||
var xs []*int
|
||||
entries := []string{"1", "2", "3"}
|
||||
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, xs, 3)
|
||||
assert.Equal(t, 1, *xs[0])
|
||||
assert.Equal(t, 2, *xs[1])
|
||||
assert.Equal(t, 3, *xs[2])
|
||||
}
|
||||
|
||||
func TestSetSliceTextUnmarshaller(t *testing.T) {
|
||||
// textUnmarshaler is a struct that captures the length of the string passed to it
|
||||
var xs []*textUnmarshaler
|
||||
entries := []string{"a", "aa", "aaa"}
|
||||
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, xs, 3)
|
||||
assert.Equal(t, 1, xs[0].val)
|
||||
assert.Equal(t, 2, xs[1].val)
|
||||
assert.Equal(t, 3, xs[2].val)
|
||||
}
|
||||
|
||||
func TestSetMapWithoutClearing(t *testing.T) {
|
||||
m := map[string]int{"foo": 10}
|
||||
entries := []string{"a=1", "b=2"}
|
||||
err := setMap(reflect.ValueOf(&m).Elem(), entries, false)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m, 3)
|
||||
assert.Equal(t, 1, m["a"])
|
||||
assert.Equal(t, 2, m["b"])
|
||||
assert.Equal(t, 10, m["foo"])
|
||||
}
|
||||
|
||||
func TestSetMapAfterClearing(t *testing.T) {
|
||||
m := map[string]int{"foo": 10}
|
||||
entries := []string{"a=1", "b=2"}
|
||||
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m, 2)
|
||||
assert.Equal(t, 1, m["a"])
|
||||
assert.Equal(t, 2, m["b"])
|
||||
}
|
||||
|
||||
func TestSetMapWithKeyPointer(t *testing.T) {
|
||||
// textUnmarshaler is a struct that captures the length of the string passed to it
|
||||
var m map[*string]int
|
||||
entries := []string{"abc=123"}
|
||||
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m, 1)
|
||||
}
|
||||
|
||||
func TestSetMapWithValuePointer(t *testing.T) {
|
||||
// textUnmarshaler is a struct that captures the length of the string passed to it
|
||||
var m map[string]*int
|
||||
entries := []string{"abc=123"}
|
||||
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m, 1)
|
||||
assert.Equal(t, 123, *m["abc"])
|
||||
}
|
||||
|
||||
func TestSetMapTextUnmarshaller(t *testing.T) {
|
||||
// textUnmarshaler is a struct that captures the length of the string passed to it
|
||||
var m map[textUnmarshaler]*textUnmarshaler
|
||||
entries := []string{"a=123", "aa=12", "aaa=1"}
|
||||
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m, 3)
|
||||
assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}])
|
||||
assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}])
|
||||
assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}])
|
||||
}
|
||||
|
||||
func TestSetMapInvalidKey(t *testing.T) {
|
||||
var m map[int]int
|
||||
entries := []string{"invalid=123"}
|
||||
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSetMapInvalidValue(t *testing.T) {
|
||||
var m map[int]int
|
||||
entries := []string{"123=invalid"}
|
||||
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSetMapMalformed(t *testing.T) {
|
||||
// textUnmarshaler is a struct that captures the length of the string passed to it
|
||||
var m map[string]string
|
||||
entries := []string{"missing_equals_sign"}
|
||||
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSetSliceOrMapErrors(t *testing.T) {
|
||||
var err error
|
||||
var dest reflect.Value
|
||||
|
||||
// converting a slice to a reflect.Value in this way will make it read only
|
||||
var cannotSet []int
|
||||
dest = reflect.ValueOf(cannotSet)
|
||||
err = setSliceOrMap(dest, nil, false)
|
||||
assert.Error(t, err)
|
||||
|
||||
// check what happens when we pass in something that is not a slice or a map
|
||||
var notSliceOrMap string
|
||||
dest = reflect.ValueOf(¬SliceOrMap).Elem()
|
||||
err = setSliceOrMap(dest, nil, false)
|
||||
assert.Error(t, err)
|
||||
|
||||
// check what happens when we pass in a pointer to something that is not a slice or a map
|
||||
var stringPtr *string
|
||||
dest = reflect.ValueOf(&stringPtr).Elem()
|
||||
err = setSliceOrMap(dest, nil, false)
|
||||
assert.Error(t, err)
|
||||
}
|
18
usage.go
18
usage.go
|
@ -95,7 +95,7 @@ func (p *Parser) writeUsageForCommand(w io.Writer, cmd *command) {
|
|||
for _, spec := range positionals {
|
||||
// prefix with a space
|
||||
fmt.Fprint(w, " ")
|
||||
if spec.multiple {
|
||||
if spec.cardinality == multiple {
|
||||
if !spec.required {
|
||||
fmt.Fprint(w, "[")
|
||||
}
|
||||
|
@ -213,16 +213,16 @@ func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) {
|
|||
|
||||
// write the list of built in options
|
||||
p.printOption(w, &spec{
|
||||
boolean: true,
|
||||
long: "help",
|
||||
short: "h",
|
||||
help: "display this help and exit",
|
||||
cardinality: zero,
|
||||
long: "help",
|
||||
short: "h",
|
||||
help: "display this help and exit",
|
||||
})
|
||||
if p.version != "" {
|
||||
p.printOption(w, &spec{
|
||||
boolean: true,
|
||||
long: "version",
|
||||
help: "display version and exit",
|
||||
cardinality: zero,
|
||||
long: "version",
|
||||
help: "display version and exit",
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -249,7 +249,7 @@ func (p *Parser) printOption(w io.Writer, spec *spec) {
|
|||
}
|
||||
|
||||
func synopsis(spec *spec, form string) string {
|
||||
if spec.boolean {
|
||||
if spec.cardinality == zero {
|
||||
return form
|
||||
}
|
||||
return form + " " + spec.placeholder
|
||||
|
|
Loading…
Reference in New Issue