Merge pull request #149 from alexflint/parse-into-map
Add support for parsing into a map
This commit is contained in:
commit
6a01a15f75
15
README.md
15
README.md
|
@ -191,6 +191,7 @@ var args struct {
|
||||||
Files []string `arg:"-f,separate"`
|
Files []string `arg:"-f,separate"`
|
||||||
Databases []string `arg:"positional"`
|
Databases []string `arg:"positional"`
|
||||||
}
|
}
|
||||||
|
arg.MustParse(&args)
|
||||||
```
|
```
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
@ -200,6 +201,20 @@ Files [file1 file2 file3]
|
||||||
Databases [db1 db2 db3]
|
Databases [db1 db2 db3]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Arguments with keys and values
|
||||||
|
```go
|
||||||
|
var args struct {
|
||||||
|
UserIDs map[string]int
|
||||||
|
}
|
||||||
|
arg.MustParse(&args)
|
||||||
|
fmt.Println(args.UserIDs)
|
||||||
|
```
|
||||||
|
|
||||||
|
```shell
|
||||||
|
./example --userids john=123 mary=456
|
||||||
|
map[john:123 mary:456]
|
||||||
|
```
|
||||||
|
|
||||||
### Custom validation
|
### Custom validation
|
||||||
```go
|
```go
|
||||||
var args struct {
|
var args struct {
|
||||||
|
|
|
@ -82,6 +82,19 @@ func Example_multipleValues() {
|
||||||
// output: Fetching the following IDs from localhost: [1 2 3]
|
// output: Fetching the following IDs from localhost: [1 2 3]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This example demonstrates arguments with keys and values
|
||||||
|
func Example_mappings() {
|
||||||
|
// The args you would pass in on the command line
|
||||||
|
os.Args = split("./example --userids john=123 mary=456")
|
||||||
|
|
||||||
|
var args struct {
|
||||||
|
UserIDs map[string]int
|
||||||
|
}
|
||||||
|
MustParse(&args)
|
||||||
|
fmt.Println(args.UserIDs)
|
||||||
|
// output: map[john:123 mary:456]
|
||||||
|
}
|
||||||
|
|
||||||
// This eample demonstrates multiple value arguments that can be mixed with
|
// This eample demonstrates multiple value arguments that can be mixed with
|
||||||
// other arguments.
|
// other arguments.
|
||||||
func Example_multipleMixed() {
|
func Example_multipleMixed() {
|
||||||
|
|
74
parse.go
74
parse.go
|
@ -50,15 +50,14 @@ type spec struct {
|
||||||
field reflect.StructField // the struct field from which this option was created
|
field reflect.StructField // the struct field from which this option was created
|
||||||
long string // the --long form for this option, or empty if none
|
long string // the --long form for this option, or empty if none
|
||||||
short string // the -s short form for this option, or empty if none
|
short string // the -s short form for this option, or empty if none
|
||||||
multiple bool
|
cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple)
|
||||||
required bool
|
required bool // if true, this option must be present on the command line
|
||||||
positional bool
|
positional bool // if true, this option will be looked for in the positional flags
|
||||||
separate bool
|
separate bool // if true, each slice and map entry will have its own --flag
|
||||||
help string
|
help string // the help text for this option
|
||||||
env string
|
env string // the name of the environment variable for this option, or empty for none
|
||||||
boolean bool
|
defaultVal string // default value for this option
|
||||||
defaultVal string // default value for this option
|
placeholder string // name of the data in help
|
||||||
placeholder string // name of the data in help
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// command represents a named subcommand, or the top-level command
|
// command represents a named subcommand, or the top-level command
|
||||||
|
@ -376,15 +375,15 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
|
||||||
if !isSubcommand {
|
if !isSubcommand {
|
||||||
cmd.specs = append(cmd.specs, &spec)
|
cmd.specs = append(cmd.specs, &spec)
|
||||||
|
|
||||||
var parseable bool
|
var err error
|
||||||
parseable, spec.boolean, spec.multiple = canParse(field.Type)
|
spec.cardinality, err = cardinalityOf(field.Type)
|
||||||
if !parseable {
|
if err != nil {
|
||||||
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
|
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
|
||||||
t.Name(), field.Name, field.Type.String()))
|
t.Name(), field.Name, field.Type.String()))
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if spec.multiple && hasDefault {
|
if spec.cardinality == multiple && hasDefault {
|
||||||
errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields",
|
errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields",
|
||||||
t.Name(), field.Name))
|
t.Name(), field.Name))
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -442,7 +441,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if spec.multiple {
|
if spec.cardinality == multiple {
|
||||||
// expect a CSV string in an environment
|
// expect a CSV string in an environment
|
||||||
// variable in the case of multiple values
|
// variable in the case of multiple values
|
||||||
values, err := csv.NewReader(strings.NewReader(value)).Read()
|
values, err := csv.NewReader(strings.NewReader(value)).Read()
|
||||||
|
@ -453,7 +452,7 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if err = setSlice(p.val(spec.dest), values, !spec.separate); err != nil {
|
if err = setSliceOrMap(p.val(spec.dest), values, !spec.separate); err != nil {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"error processing environment variable %s with multiple values: %v",
|
"error processing environment variable %s with multiple values: %v",
|
||||||
spec.env,
|
spec.env,
|
||||||
|
@ -563,7 +562,7 @@ func (p *Parser) process(args []string) error {
|
||||||
wasPresent[spec] = true
|
wasPresent[spec] = true
|
||||||
|
|
||||||
// deal with the case of multiple values
|
// deal with the case of multiple values
|
||||||
if spec.multiple {
|
if spec.cardinality == multiple {
|
||||||
var values []string
|
var values []string
|
||||||
if value == "" {
|
if value == "" {
|
||||||
for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" {
|
for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" {
|
||||||
|
@ -576,7 +575,7 @@ func (p *Parser) process(args []string) error {
|
||||||
} else {
|
} else {
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
err := setSlice(p.val(spec.dest), values, !spec.separate)
|
err := setSliceOrMap(p.val(spec.dest), values, !spec.separate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error processing %s: %v", arg, err)
|
return fmt.Errorf("error processing %s: %v", arg, err)
|
||||||
}
|
}
|
||||||
|
@ -585,7 +584,7 @@ func (p *Parser) process(args []string) error {
|
||||||
|
|
||||||
// if it's a flag and it has no value then set the value to true
|
// if it's a flag and it has no value then set the value to true
|
||||||
// use boolean because this takes account of TextUnmarshaler
|
// use boolean because this takes account of TextUnmarshaler
|
||||||
if spec.boolean && value == "" {
|
if spec.cardinality == zero && value == "" {
|
||||||
value = "true"
|
value = "true"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -616,8 +615,8 @@ func (p *Parser) process(args []string) error {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
wasPresent[spec] = true
|
wasPresent[spec] = true
|
||||||
if spec.multiple {
|
if spec.cardinality == multiple {
|
||||||
err := setSlice(p.val(spec.dest), positionals, true)
|
err := setSliceOrMap(p.val(spec.dest), positionals, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error processing %s: %v", spec.field.Name, err)
|
return fmt.Errorf("error processing %s: %v", spec.field.Name, err)
|
||||||
}
|
}
|
||||||
|
@ -702,37 +701,6 @@ func (p *Parser) val(dest path) reflect.Value {
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse a value as the appropriate type and store it in the struct
|
|
||||||
func setSlice(dest reflect.Value, values []string, trunc bool) error {
|
|
||||||
if !dest.CanSet() {
|
|
||||||
return fmt.Errorf("field is not writable")
|
|
||||||
}
|
|
||||||
|
|
||||||
var ptr bool
|
|
||||||
elem := dest.Type().Elem()
|
|
||||||
if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
|
|
||||||
ptr = true
|
|
||||||
elem = elem.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Truncate the dest slice in case default values exist
|
|
||||||
if trunc && !dest.IsNil() {
|
|
||||||
dest.SetLen(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, s := range values {
|
|
||||||
v := reflect.New(elem)
|
|
||||||
if err := scalar.ParseValue(v.Elem(), s); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !ptr {
|
|
||||||
v = v.Elem()
|
|
||||||
}
|
|
||||||
dest.Set(reflect.Append(dest, v))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// findOption finds an option from its name, or returns null if no spec is found
|
// findOption finds an option from its name, or returns null if no spec is found
|
||||||
func findOption(specs []*spec, name string) *spec {
|
func findOption(specs []*spec, name string) *spec {
|
||||||
for _, spec := range specs {
|
for _, spec := range specs {
|
||||||
|
@ -759,7 +727,7 @@ func findSubcommand(cmds []*command, name string) *command {
|
||||||
// isZero returns true if v contains the zero value for its type
|
// isZero returns true if v contains the zero value for its type
|
||||||
func isZero(v reflect.Value) bool {
|
func isZero(v reflect.Value) bool {
|
||||||
t := v.Type()
|
t := v.Type()
|
||||||
if t.Kind() == reflect.Slice {
|
if t.Kind() == reflect.Slice || t.Kind() == reflect.Map {
|
||||||
return v.IsNil()
|
return v.IsNil()
|
||||||
}
|
}
|
||||||
if !t.Comparable() {
|
if !t.Comparable() {
|
||||||
|
|
|
@ -220,6 +220,60 @@ func TestLongFlag(t *testing.T) {
|
||||||
assert.Equal(t, "xyz", args.Foo)
|
assert.Equal(t, "xyz", args.Foo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSlice(t *testing.T) {
|
||||||
|
var args struct {
|
||||||
|
Strings []string
|
||||||
|
}
|
||||||
|
err := parse("--strings a b c", &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"a", "b", "c"}, args.Strings)
|
||||||
|
}
|
||||||
|
func TestSliceOfBools(t *testing.T) {
|
||||||
|
var args struct {
|
||||||
|
B []bool
|
||||||
|
}
|
||||||
|
|
||||||
|
err := parse("--b true false true", &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []bool{true, false, true}, args.B)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap(t *testing.T) {
|
||||||
|
var args struct {
|
||||||
|
Values map[string]int
|
||||||
|
}
|
||||||
|
err := parse("--values a=1 b=2 c=3", &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, args.Values, 3)
|
||||||
|
assert.Equal(t, 1, args.Values["a"])
|
||||||
|
assert.Equal(t, 2, args.Values["b"])
|
||||||
|
assert.Equal(t, 3, args.Values["c"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapPositional(t *testing.T) {
|
||||||
|
var args struct {
|
||||||
|
Values map[string]int `arg:"positional"`
|
||||||
|
}
|
||||||
|
err := parse("a=1 b=2 c=3", &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, args.Values, 3)
|
||||||
|
assert.Equal(t, 1, args.Values["a"])
|
||||||
|
assert.Equal(t, 2, args.Values["b"])
|
||||||
|
assert.Equal(t, 3, args.Values["c"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapWithSeparate(t *testing.T) {
|
||||||
|
var args struct {
|
||||||
|
Values map[string]int `arg:"separate"`
|
||||||
|
}
|
||||||
|
err := parse("--values a=1 --values b=2 --values c=3", &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, args.Values, 3)
|
||||||
|
assert.Equal(t, 1, args.Values["a"])
|
||||||
|
assert.Equal(t, 2, args.Values["b"])
|
||||||
|
assert.Equal(t, 3, args.Values["c"])
|
||||||
|
}
|
||||||
|
|
||||||
func TestPlaceholder(t *testing.T) {
|
func TestPlaceholder(t *testing.T) {
|
||||||
var args struct {
|
var args struct {
|
||||||
Input string `arg:"positional" placeholder:"SRC"`
|
Input string `arg:"positional" placeholder:"SRC"`
|
||||||
|
@ -688,6 +742,17 @@ func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnvironmentVariableMap(t *testing.T) {
|
||||||
|
var args struct {
|
||||||
|
Foo map[int]string `arg:"env"`
|
||||||
|
}
|
||||||
|
setenv(t, "FOO", "1=one,99=ninetynine")
|
||||||
|
MustParse(&args)
|
||||||
|
assert.Len(t, args.Foo, 2)
|
||||||
|
assert.Equal(t, "one", args.Foo[1])
|
||||||
|
assert.Equal(t, "ninetynine", args.Foo[99])
|
||||||
|
}
|
||||||
|
|
||||||
func TestEnvironmentVariableIgnored(t *testing.T) {
|
func TestEnvironmentVariableIgnored(t *testing.T) {
|
||||||
var args struct {
|
var args struct {
|
||||||
Foo string `arg:"env"`
|
Foo string `arg:"env"`
|
||||||
|
@ -1223,7 +1288,7 @@ func TestDefaultValuesNotAllowedWithSlice(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err := parse("", &args)
|
err := parse("", &args)
|
||||||
assert.EqualError(t, err, ".A: default values are not supported for slice fields")
|
assert.EqualError(t, err, ".A: default values are not supported for slice or map fields")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnexportedFieldsSkipped(t *testing.T) {
|
func TestUnexportedFieldsSkipped(t *testing.T) {
|
||||||
|
|
84
reflect.go
84
reflect.go
|
@ -2,6 +2,7 @@ package arg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding"
|
"encoding"
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"unicode"
|
"unicode"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
@ -11,42 +12,67 @@ import (
|
||||||
|
|
||||||
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
|
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
|
||||||
|
|
||||||
// canParse returns true if the type can be parsed from a string
|
// cardinality tracks how many tokens are expected for a given spec
|
||||||
func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
|
// - zero is a boolean, which does to expect any value
|
||||||
parseable = scalar.CanParse(t)
|
// - one is an ordinary option that will be parsed from a single token
|
||||||
boolean = isBoolean(t)
|
// - multiple is a slice or map that can accept zero or more tokens
|
||||||
if parseable {
|
type cardinality int
|
||||||
return
|
|
||||||
|
const (
|
||||||
|
zero cardinality = iota
|
||||||
|
one
|
||||||
|
multiple
|
||||||
|
unsupported
|
||||||
|
)
|
||||||
|
|
||||||
|
func (k cardinality) String() string {
|
||||||
|
switch k {
|
||||||
|
case zero:
|
||||||
|
return "zero"
|
||||||
|
case one:
|
||||||
|
return "one"
|
||||||
|
case multiple:
|
||||||
|
return "multiple"
|
||||||
|
case unsupported:
|
||||||
|
return "unsupported"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unknown(%d)", int(k))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cardinalityOf returns true if the type can be parsed from a string
|
||||||
|
func cardinalityOf(t reflect.Type) (cardinality, error) {
|
||||||
|
if scalar.CanParse(t) {
|
||||||
|
if isBoolean(t) {
|
||||||
|
return zero, nil
|
||||||
|
} else {
|
||||||
|
return one, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look inside pointer types
|
// look inside pointer types
|
||||||
if t.Kind() == reflect.Ptr {
|
|
||||||
t = t.Elem()
|
|
||||||
}
|
|
||||||
// Look inside slice types
|
|
||||||
if t.Kind() == reflect.Slice {
|
|
||||||
multiple = true
|
|
||||||
t = t.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
parseable = scalar.CanParse(t)
|
|
||||||
boolean = isBoolean(t)
|
|
||||||
if parseable {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look inside pointer types (again, in case of []*Type)
|
|
||||||
if t.Kind() == reflect.Ptr {
|
if t.Kind() == reflect.Ptr {
|
||||||
t = t.Elem()
|
t = t.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
parseable = scalar.CanParse(t)
|
// look inside slice and map types
|
||||||
boolean = isBoolean(t)
|
switch t.Kind() {
|
||||||
if parseable {
|
case reflect.Slice:
|
||||||
return
|
if !scalar.CanParse(t.Elem()) {
|
||||||
|
return unsupported, fmt.Errorf("cannot parse into %v because %v not supported", t, t.Elem())
|
||||||
|
}
|
||||||
|
return multiple, nil
|
||||||
|
case reflect.Map:
|
||||||
|
if !scalar.CanParse(t.Key()) {
|
||||||
|
return unsupported, fmt.Errorf("cannot parse into %v because key type %v not supported", t, t.Elem())
|
||||||
|
}
|
||||||
|
if !scalar.CanParse(t.Elem()) {
|
||||||
|
return unsupported, fmt.Errorf("cannot parse into %v because value type %v not supported", t, t.Elem())
|
||||||
|
}
|
||||||
|
return multiple, nil
|
||||||
|
default:
|
||||||
|
return unsupported, fmt.Errorf("cannot parse into %v", t)
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, false, false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// isBoolean returns true if the type can be parsed from a single string
|
// isBoolean returns true if the type can be parsed from a single string
|
||||||
|
|
|
@ -7,36 +7,54 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func assertCanParse(t *testing.T, typ reflect.Type, parseable, boolean, multiple bool) {
|
func assertCardinality(t *testing.T, typ reflect.Type, expected cardinality) {
|
||||||
p, b, m := canParse(typ)
|
actual, err := cardinalityOf(typ)
|
||||||
assert.Equal(t, parseable, p, "expected %v to have parseable=%v but was %v", typ, parseable, p)
|
assert.Equal(t, expected, actual, "expected %v to have cardinality %v but got %v", typ, expected, actual)
|
||||||
assert.Equal(t, boolean, b, "expected %v to have boolean=%v but was %v", typ, boolean, b)
|
if expected == unsupported {
|
||||||
assert.Equal(t, multiple, m, "expected %v to have multiple=%v but was %v", typ, multiple, m)
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCanParse(t *testing.T) {
|
func TestCardinalityOf(t *testing.T) {
|
||||||
var b bool
|
var b bool
|
||||||
var i int
|
var i int
|
||||||
var s string
|
var s string
|
||||||
var f float64
|
var f float64
|
||||||
var bs []bool
|
var bs []bool
|
||||||
var is []int
|
var is []int
|
||||||
|
var m map[string]int
|
||||||
|
var unsupported1 struct{}
|
||||||
|
var unsupported2 []struct{}
|
||||||
|
var unsupported3 map[string]struct{}
|
||||||
|
var unsupported4 map[struct{}]string
|
||||||
|
|
||||||
assertCanParse(t, reflect.TypeOf(b), true, true, false)
|
assertCardinality(t, reflect.TypeOf(b), zero)
|
||||||
assertCanParse(t, reflect.TypeOf(i), true, false, false)
|
assertCardinality(t, reflect.TypeOf(i), one)
|
||||||
assertCanParse(t, reflect.TypeOf(s), true, false, false)
|
assertCardinality(t, reflect.TypeOf(s), one)
|
||||||
assertCanParse(t, reflect.TypeOf(f), true, false, false)
|
assertCardinality(t, reflect.TypeOf(f), one)
|
||||||
|
|
||||||
assertCanParse(t, reflect.TypeOf(&b), true, true, false)
|
assertCardinality(t, reflect.TypeOf(&b), zero)
|
||||||
assertCanParse(t, reflect.TypeOf(&s), true, false, false)
|
assertCardinality(t, reflect.TypeOf(&s), one)
|
||||||
assertCanParse(t, reflect.TypeOf(&i), true, false, false)
|
assertCardinality(t, reflect.TypeOf(&i), one)
|
||||||
assertCanParse(t, reflect.TypeOf(&f), true, false, false)
|
assertCardinality(t, reflect.TypeOf(&f), one)
|
||||||
|
|
||||||
assertCanParse(t, reflect.TypeOf(bs), true, true, true)
|
assertCardinality(t, reflect.TypeOf(bs), multiple)
|
||||||
assertCanParse(t, reflect.TypeOf(&bs), true, true, true)
|
assertCardinality(t, reflect.TypeOf(is), multiple)
|
||||||
|
|
||||||
assertCanParse(t, reflect.TypeOf(is), true, false, true)
|
assertCardinality(t, reflect.TypeOf(&bs), multiple)
|
||||||
assertCanParse(t, reflect.TypeOf(&is), true, false, true)
|
assertCardinality(t, reflect.TypeOf(&is), multiple)
|
||||||
|
|
||||||
|
assertCardinality(t, reflect.TypeOf(m), multiple)
|
||||||
|
assertCardinality(t, reflect.TypeOf(&m), multiple)
|
||||||
|
|
||||||
|
assertCardinality(t, reflect.TypeOf(unsupported1), unsupported)
|
||||||
|
assertCardinality(t, reflect.TypeOf(&unsupported1), unsupported)
|
||||||
|
assertCardinality(t, reflect.TypeOf(unsupported2), unsupported)
|
||||||
|
assertCardinality(t, reflect.TypeOf(&unsupported2), unsupported)
|
||||||
|
assertCardinality(t, reflect.TypeOf(unsupported3), unsupported)
|
||||||
|
assertCardinality(t, reflect.TypeOf(&unsupported3), unsupported)
|
||||||
|
assertCardinality(t, reflect.TypeOf(unsupported4), unsupported)
|
||||||
|
assertCardinality(t, reflect.TypeOf(&unsupported4), unsupported)
|
||||||
}
|
}
|
||||||
|
|
||||||
type implementsTextUnmarshaler struct{}
|
type implementsTextUnmarshaler struct{}
|
||||||
|
@ -45,13 +63,16 @@ func (*implementsTextUnmarshaler) UnmarshalText(text []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCanParseTextUnmarshaler(t *testing.T) {
|
func TestCardinalityTextUnmarshaler(t *testing.T) {
|
||||||
var u implementsTextUnmarshaler
|
var x implementsTextUnmarshaler
|
||||||
var su []implementsTextUnmarshaler
|
var s []implementsTextUnmarshaler
|
||||||
assertCanParse(t, reflect.TypeOf(u), true, false, false)
|
var m []implementsTextUnmarshaler
|
||||||
assertCanParse(t, reflect.TypeOf(&u), true, false, false)
|
assertCardinality(t, reflect.TypeOf(x), one)
|
||||||
assertCanParse(t, reflect.TypeOf(su), true, false, true)
|
assertCardinality(t, reflect.TypeOf(&x), one)
|
||||||
assertCanParse(t, reflect.TypeOf(&su), true, false, true)
|
assertCardinality(t, reflect.TypeOf(s), multiple)
|
||||||
|
assertCardinality(t, reflect.TypeOf(&s), multiple)
|
||||||
|
assertCardinality(t, reflect.TypeOf(m), multiple)
|
||||||
|
assertCardinality(t, reflect.TypeOf(&m), multiple)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsExported(t *testing.T) {
|
func TestIsExported(t *testing.T) {
|
||||||
|
@ -60,3 +81,11 @@ func TestIsExported(t *testing.T) {
|
||||||
assert.False(t, isExported(""))
|
assert.False(t, isExported(""))
|
||||||
assert.False(t, isExported(string([]byte{255})))
|
assert.False(t, isExported(string([]byte{255})))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCardinalityString(t *testing.T) {
|
||||||
|
assert.Equal(t, "zero", zero.String())
|
||||||
|
assert.Equal(t, "one", one.String())
|
||||||
|
assert.Equal(t, "multiple", multiple.String())
|
||||||
|
assert.Equal(t, "unsupported", unsupported.String())
|
||||||
|
assert.Equal(t, "unknown(42)", cardinality(42).String())
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,123 @@
|
||||||
|
package arg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
scalar "github.com/alexflint/go-scalar"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setSliceOrMap parses a sequence of strings into a slice or map. If clear is
|
||||||
|
// true then any values already in the slice or map are first removed.
|
||||||
|
func setSliceOrMap(dest reflect.Value, values []string, clear bool) error {
|
||||||
|
if !dest.CanSet() {
|
||||||
|
return fmt.Errorf("field is not writable")
|
||||||
|
}
|
||||||
|
|
||||||
|
t := dest.Type()
|
||||||
|
if t.Kind() == reflect.Ptr {
|
||||||
|
dest = dest.Elem()
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
return setSlice(dest, values, clear)
|
||||||
|
case reflect.Map:
|
||||||
|
return setMap(dest, values, clear)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("setSliceOrMap cannot insert values into a %v", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSlice parses a sequence of strings and inserts them into a slice. If clear
|
||||||
|
// is true then any values already in the slice are removed.
|
||||||
|
func setSlice(dest reflect.Value, values []string, clear bool) error {
|
||||||
|
var ptr bool
|
||||||
|
elem := dest.Type().Elem()
|
||||||
|
if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
|
||||||
|
ptr = true
|
||||||
|
elem = elem.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear the slice in case default values exist
|
||||||
|
if clear && !dest.IsNil() {
|
||||||
|
dest.SetLen(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse the values one-by-one
|
||||||
|
for _, s := range values {
|
||||||
|
v := reflect.New(elem)
|
||||||
|
if err := scalar.ParseValue(v.Elem(), s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ptr {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
dest.Set(reflect.Append(dest, v))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setMap parses a sequence of name=value strings and inserts them into a map.
|
||||||
|
// If clear is true then any values already in the map are removed.
|
||||||
|
func setMap(dest reflect.Value, values []string, clear bool) error {
|
||||||
|
// determine the key and value type
|
||||||
|
var keyIsPtr bool
|
||||||
|
keyType := dest.Type().Key()
|
||||||
|
if keyType.Kind() == reflect.Ptr && !keyType.Implements(textUnmarshalerType) {
|
||||||
|
keyIsPtr = true
|
||||||
|
keyType = keyType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
var valIsPtr bool
|
||||||
|
valType := dest.Type().Elem()
|
||||||
|
if valType.Kind() == reflect.Ptr && !valType.Implements(textUnmarshalerType) {
|
||||||
|
valIsPtr = true
|
||||||
|
valType = valType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear the slice in case default values exist
|
||||||
|
if clear && !dest.IsNil() {
|
||||||
|
for _, k := range dest.MapKeys() {
|
||||||
|
dest.SetMapIndex(k, reflect.Value{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// allocate the map if it is not allocated
|
||||||
|
if dest.IsNil() {
|
||||||
|
dest.Set(reflect.MakeMap(dest.Type()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse the values one-by-one
|
||||||
|
for _, s := range values {
|
||||||
|
// split at the first equals sign
|
||||||
|
pos := strings.Index(s, "=")
|
||||||
|
if pos == -1 {
|
||||||
|
return fmt.Errorf("cannot parse %q into a map, expected format key=value", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse the key
|
||||||
|
k := reflect.New(keyType)
|
||||||
|
if err := scalar.ParseValue(k.Elem(), s[:pos]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !keyIsPtr {
|
||||||
|
k = k.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse the value
|
||||||
|
v := reflect.New(valType)
|
||||||
|
if err := scalar.ParseValue(v.Elem(), s[pos+1:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !valIsPtr {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// add it to the map
|
||||||
|
dest.SetMapIndex(k, v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,152 @@
|
||||||
|
package arg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetSliceWithoutClearing(t *testing.T) {
|
||||||
|
xs := []int{10}
|
||||||
|
entries := []string{"1", "2", "3"}
|
||||||
|
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []int{10, 1, 2, 3}, xs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetSliceAfterClearing(t *testing.T) {
|
||||||
|
xs := []int{100}
|
||||||
|
entries := []string{"1", "2", "3"}
|
||||||
|
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []int{1, 2, 3}, xs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetSliceInvalid(t *testing.T) {
|
||||||
|
xs := []int{100}
|
||||||
|
entries := []string{"invalid"}
|
||||||
|
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetSlicePtr(t *testing.T) {
|
||||||
|
var xs []*int
|
||||||
|
entries := []string{"1", "2", "3"}
|
||||||
|
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, xs, 3)
|
||||||
|
assert.Equal(t, 1, *xs[0])
|
||||||
|
assert.Equal(t, 2, *xs[1])
|
||||||
|
assert.Equal(t, 3, *xs[2])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetSliceTextUnmarshaller(t *testing.T) {
|
||||||
|
// textUnmarshaler is a struct that captures the length of the string passed to it
|
||||||
|
var xs []*textUnmarshaler
|
||||||
|
entries := []string{"a", "aa", "aaa"}
|
||||||
|
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, xs, 3)
|
||||||
|
assert.Equal(t, 1, xs[0].val)
|
||||||
|
assert.Equal(t, 2, xs[1].val)
|
||||||
|
assert.Equal(t, 3, xs[2].val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMapWithoutClearing(t *testing.T) {
|
||||||
|
m := map[string]int{"foo": 10}
|
||||||
|
entries := []string{"a=1", "b=2"}
|
||||||
|
err := setMap(reflect.ValueOf(&m).Elem(), entries, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, m, 3)
|
||||||
|
assert.Equal(t, 1, m["a"])
|
||||||
|
assert.Equal(t, 2, m["b"])
|
||||||
|
assert.Equal(t, 10, m["foo"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMapAfterClearing(t *testing.T) {
|
||||||
|
m := map[string]int{"foo": 10}
|
||||||
|
entries := []string{"a=1", "b=2"}
|
||||||
|
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, m, 2)
|
||||||
|
assert.Equal(t, 1, m["a"])
|
||||||
|
assert.Equal(t, 2, m["b"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMapWithKeyPointer(t *testing.T) {
|
||||||
|
// textUnmarshaler is a struct that captures the length of the string passed to it
|
||||||
|
var m map[*string]int
|
||||||
|
entries := []string{"abc=123"}
|
||||||
|
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, m, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMapWithValuePointer(t *testing.T) {
|
||||||
|
// textUnmarshaler is a struct that captures the length of the string passed to it
|
||||||
|
var m map[string]*int
|
||||||
|
entries := []string{"abc=123"}
|
||||||
|
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, m, 1)
|
||||||
|
assert.Equal(t, 123, *m["abc"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMapTextUnmarshaller(t *testing.T) {
|
||||||
|
// textUnmarshaler is a struct that captures the length of the string passed to it
|
||||||
|
var m map[textUnmarshaler]*textUnmarshaler
|
||||||
|
entries := []string{"a=123", "aa=12", "aaa=1"}
|
||||||
|
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, m, 3)
|
||||||
|
assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}])
|
||||||
|
assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}])
|
||||||
|
assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMapInvalidKey(t *testing.T) {
|
||||||
|
var m map[int]int
|
||||||
|
entries := []string{"invalid=123"}
|
||||||
|
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMapInvalidValue(t *testing.T) {
|
||||||
|
var m map[int]int
|
||||||
|
entries := []string{"123=invalid"}
|
||||||
|
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMapMalformed(t *testing.T) {
|
||||||
|
// textUnmarshaler is a struct that captures the length of the string passed to it
|
||||||
|
var m map[string]string
|
||||||
|
entries := []string{"missing_equals_sign"}
|
||||||
|
err := setMap(reflect.ValueOf(&m).Elem(), entries, true)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetSliceOrMapErrors(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
var dest reflect.Value
|
||||||
|
|
||||||
|
// converting a slice to a reflect.Value in this way will make it read only
|
||||||
|
var cannotSet []int
|
||||||
|
dest = reflect.ValueOf(cannotSet)
|
||||||
|
err = setSliceOrMap(dest, nil, false)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// check what happens when we pass in something that is not a slice or a map
|
||||||
|
var notSliceOrMap string
|
||||||
|
dest = reflect.ValueOf(¬SliceOrMap).Elem()
|
||||||
|
err = setSliceOrMap(dest, nil, false)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// check what happens when we pass in a pointer to something that is not a slice or a map
|
||||||
|
var stringPtr *string
|
||||||
|
dest = reflect.ValueOf(&stringPtr).Elem()
|
||||||
|
err = setSliceOrMap(dest, nil, false)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
18
usage.go
18
usage.go
|
@ -95,7 +95,7 @@ func (p *Parser) writeUsageForCommand(w io.Writer, cmd *command) {
|
||||||
for _, spec := range positionals {
|
for _, spec := range positionals {
|
||||||
// prefix with a space
|
// prefix with a space
|
||||||
fmt.Fprint(w, " ")
|
fmt.Fprint(w, " ")
|
||||||
if spec.multiple {
|
if spec.cardinality == multiple {
|
||||||
if !spec.required {
|
if !spec.required {
|
||||||
fmt.Fprint(w, "[")
|
fmt.Fprint(w, "[")
|
||||||
}
|
}
|
||||||
|
@ -213,16 +213,16 @@ func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) {
|
||||||
|
|
||||||
// write the list of built in options
|
// write the list of built in options
|
||||||
p.printOption(w, &spec{
|
p.printOption(w, &spec{
|
||||||
boolean: true,
|
cardinality: zero,
|
||||||
long: "help",
|
long: "help",
|
||||||
short: "h",
|
short: "h",
|
||||||
help: "display this help and exit",
|
help: "display this help and exit",
|
||||||
})
|
})
|
||||||
if p.version != "" {
|
if p.version != "" {
|
||||||
p.printOption(w, &spec{
|
p.printOption(w, &spec{
|
||||||
boolean: true,
|
cardinality: zero,
|
||||||
long: "version",
|
long: "version",
|
||||||
help: "display version and exit",
|
help: "display version and exit",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -249,7 +249,7 @@ func (p *Parser) printOption(w io.Writer, spec *spec) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func synopsis(spec *spec, form string) string {
|
func synopsis(spec *spec, form string) string {
|
||||||
if spec.boolean {
|
if spec.cardinality == zero {
|
||||||
return form
|
return form
|
||||||
}
|
}
|
||||||
return form + " " + spec.placeholder
|
return form + " " + spec.placeholder
|
||||||
|
|
Loading…
Reference in New Issue