add appendToSlice, appendToMap, appendToSliceOrMap

This commit is contained in:
Alex Flint 2022-10-04 12:48:04 -07:00
parent 2775f58376
commit 64288c5521
3 changed files with 47 additions and 53 deletions

View File

@ -247,7 +247,7 @@ func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error)
} }
// store the values into the slice or map // store the values into the slice or map
err := setSliceOrMap(p.val(arg.dest), values, !arg.separate) err := setSliceOrMap(p.val(arg.dest), values)
if err != nil { if err != nil {
return nil, fmt.Errorf("error processing %s: %v", token, err) return nil, fmt.Errorf("error processing %s: %v", token, err)
} }
@ -312,7 +312,7 @@ func (p *Parser) processPositionals(positionals []string, overwrite bool) error
} }
if arg.cardinality == multiple { if arg.cardinality == multiple {
if !p.seen[arg] || overwrite { if !p.seen[arg] || overwrite {
err := setSliceOrMap(p.val(arg.dest), positionals, true) err := setSliceOrMap(p.val(arg.dest), positionals)
if err != nil { if err != nil {
return fmt.Errorf("error processing %s: %v", arg.field.Name, err) return fmt.Errorf("error processing %s: %v", arg.field.Name, err)
} }
@ -385,19 +385,20 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error {
if len(strings.TrimSpace(value)) > 0 { if len(strings.TrimSpace(value)) > 0 {
values, err = csv.NewReader(strings.NewReader(value)).Read() values, err = csv.NewReader(strings.NewReader(value)).Read()
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf("error reading a CSV string from environment variable %s : %v", arg.env, err)
"error reading a CSV string from environment variable %s with multiple values: %v",
arg.env,
err,
)
} }
} }
if err = setSliceOrMap(p.val(arg.dest), values, !arg.separate); err != nil {
return fmt.Errorf( if arg.separate {
"error processing environment variable %s with multiple values: %v", if err = setSliceOrMap(p.val(arg.dest), values); err != nil {
arg.env, return fmt.Errorf("error processing environment variable %s: %v", arg.env, err)
err, }
) } else {
for _, s := range values {
if err = appendToSliceOrMap(p.val(arg.dest), s); err != nil {
return fmt.Errorf("error processing environment variable %s: %v", arg.env, err)
}
}
} }
} else { } else {
if err := scalar.ParseValue(p.val(arg.dest), value); err != nil { if err := scalar.ParseValue(p.val(arg.dest), value); err != nil {

View File

@ -8,9 +8,9 @@ import (
scalar "github.com/alexflint/go-scalar" scalar "github.com/alexflint/go-scalar"
) )
// setSliceOrMap parses a sequence of strings into a slice or map. If clear is // setSliceOrMap parses a sequence of strings into a slice or map. The slice or
// true then any values already in the slice or map are first removed. // map is always cleared first.
func setSliceOrMap(dest reflect.Value, values []string, clear bool) error { func setSliceOrMap(dest reflect.Value, values []string) error {
if !dest.CanSet() { if !dest.CanSet() {
return fmt.Errorf("field is not writable") return fmt.Errorf("field is not writable")
} }
@ -23,17 +23,17 @@ func setSliceOrMap(dest reflect.Value, values []string, clear bool) error {
switch t.Kind() { switch t.Kind() {
case reflect.Slice: case reflect.Slice:
return setSlice(dest, values, clear) return setSlice(dest, values)
case reflect.Map: case reflect.Map:
return setMap(dest, values, clear) return setMap(dest, values)
default: default:
return fmt.Errorf("cannot insert multiple values into a %v", t) return fmt.Errorf("cannot insert multiple values into a %v", t)
} }
} }
// setSlice parses a sequence of strings and inserts them into a slice. If clear // setSlice parses a sequence of strings and inserts them into a slice. The
// is true then any values already in the slice are removed. // slice is cleared first.
func setSlice(dest reflect.Value, values []string, clear bool) error { func setSlice(dest reflect.Value, values []string) error {
var ptr bool var ptr bool
elem := dest.Type().Elem() elem := dest.Type().Elem()
if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
@ -42,7 +42,7 @@ func setSlice(dest reflect.Value, values []string, clear bool) error {
} }
// clear the slice in case default values exist // clear the slice in case default values exist
if clear && !dest.IsNil() { if !dest.IsNil() {
dest.SetLen(0) dest.SetLen(0)
} }
@ -61,8 +61,8 @@ func setSlice(dest reflect.Value, values []string, clear bool) error {
} }
// setMap parses a sequence of name=value strings and inserts them into a map. // setMap parses a sequence of name=value strings and inserts them into a map.
// If clear is true then any values already in the map are removed. // The map is always cleared first.
func setMap(dest reflect.Value, values []string, clear bool) error { func setMap(dest reflect.Value, values []string) error {
// determine the key and value type // determine the key and value type
var keyIsPtr bool var keyIsPtr bool
keyType := dest.Type().Key() keyType := dest.Type().Key()
@ -79,7 +79,7 @@ func setMap(dest reflect.Value, values []string, clear bool) error {
} }
// clear the slice in case default values exist // clear the slice in case default values exist
if clear && !dest.IsNil() { if !dest.IsNil() {
for _, k := range dest.MapKeys() { for _, k := range dest.MapKeys() {
dest.SetMapIndex(k, reflect.Value{}) dest.SetMapIndex(k, reflect.Value{})
} }

View File

@ -8,18 +8,17 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestSetSliceWithoutClearing(t *testing.T) { func TestAppendToSlice(t *testing.T) {
xs := []int{10} xs := []int{10}
entries := []string{"1", "2", "3"} err := appendToSlice(reflect.ValueOf(&xs).Elem(), "3")
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []int{10, 1, 2, 3}, xs) assert.Equal(t, []int{10, 3}, xs)
} }
func TestSetSliceAfterClearing(t *testing.T) { func TestSetSlice(t *testing.T) {
xs := []int{100} xs := []int{100}
entries := []string{"1", "2", "3"} entries := []string{"1", "2", "3"}
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) err := setSlice(reflect.ValueOf(&xs).Elem(), entries)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []int{1, 2, 3}, xs) assert.Equal(t, []int{1, 2, 3}, xs)
} }
@ -27,14 +26,14 @@ func TestSetSliceAfterClearing(t *testing.T) {
func TestSetSliceInvalid(t *testing.T) { func TestSetSliceInvalid(t *testing.T) {
xs := []int{100} xs := []int{100}
entries := []string{"invalid"} entries := []string{"invalid"}
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) err := setSlice(reflect.ValueOf(&xs).Elem(), entries)
assert.Error(t, err) assert.Error(t, err)
} }
func TestSetSlicePtr(t *testing.T) { func TestSetSlicePtr(t *testing.T) {
var xs []*int var xs []*int
entries := []string{"1", "2", "3"} entries := []string{"1", "2", "3"}
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) err := setSlice(reflect.ValueOf(&xs).Elem(), entries)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, xs, 3) require.Len(t, xs, 3)
assert.Equal(t, 1, *xs[0]) assert.Equal(t, 1, *xs[0])
@ -46,7 +45,7 @@ func TestSetSliceTextUnmarshaller(t *testing.T) {
// textUnmarshaler is a struct that captures the length of the string passed to it // textUnmarshaler is a struct that captures the length of the string passed to it
var xs []*textUnmarshaler var xs []*textUnmarshaler
entries := []string{"a", "aa", "aaa"} entries := []string{"a", "aa", "aaa"}
err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) err := setSlice(reflect.ValueOf(&xs).Elem(), entries)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, xs, 3) require.Len(t, xs, 3)
assert.Equal(t, 1, xs[0].val) assert.Equal(t, 1, xs[0].val)
@ -54,21 +53,19 @@ func TestSetSliceTextUnmarshaller(t *testing.T) {
assert.Equal(t, 3, xs[2].val) assert.Equal(t, 3, xs[2].val)
} }
func TestSetMapWithoutClearing(t *testing.T) { func TestAppendToMap(t *testing.T) {
m := map[string]int{"foo": 10} m := map[string]int{"foo": 10}
entries := []string{"a=1", "b=2"} err := appendToMap(reflect.ValueOf(&m).Elem(), "a=1")
err := setMap(reflect.ValueOf(&m).Elem(), entries, false)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, m, 3) require.Len(t, m, 2)
assert.Equal(t, 1, m["a"]) assert.Equal(t, 1, m["a"])
assert.Equal(t, 2, m["b"])
assert.Equal(t, 10, m["foo"]) assert.Equal(t, 10, m["foo"])
} }
func TestSetMapAfterClearing(t *testing.T) { func TestSetMapAfterClearing(t *testing.T) {
m := map[string]int{"foo": 10} m := map[string]int{"foo": 10}
entries := []string{"a=1", "b=2"} entries := []string{"a=1", "b=2"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true) err := setMap(reflect.ValueOf(&m).Elem(), entries)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, m, 2) require.Len(t, m, 2)
assert.Equal(t, 1, m["a"]) assert.Equal(t, 1, m["a"])
@ -79,7 +76,7 @@ func TestSetMapWithKeyPointer(t *testing.T) {
// textUnmarshaler is a struct that captures the length of the string passed to it // textUnmarshaler is a struct that captures the length of the string passed to it
var m map[*string]int var m map[*string]int
entries := []string{"abc=123"} entries := []string{"abc=123"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true) err := setMap(reflect.ValueOf(&m).Elem(), entries)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, m, 1) require.Len(t, m, 1)
} }
@ -88,7 +85,7 @@ func TestSetMapWithValuePointer(t *testing.T) {
// textUnmarshaler is a struct that captures the length of the string passed to it // textUnmarshaler is a struct that captures the length of the string passed to it
var m map[string]*int var m map[string]*int
entries := []string{"abc=123"} entries := []string{"abc=123"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true) err := setMap(reflect.ValueOf(&m).Elem(), entries)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, m, 1) require.Len(t, m, 1)
assert.Equal(t, 123, *m["abc"]) assert.Equal(t, 123, *m["abc"])
@ -98,7 +95,7 @@ func TestSetMapTextUnmarshaller(t *testing.T) {
// textUnmarshaler is a struct that captures the length of the string passed to it // textUnmarshaler is a struct that captures the length of the string passed to it
var m map[textUnmarshaler]*textUnmarshaler var m map[textUnmarshaler]*textUnmarshaler
entries := []string{"a=123", "aa=12", "aaa=1"} entries := []string{"a=123", "aa=12", "aaa=1"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true) err := setMap(reflect.ValueOf(&m).Elem(), entries)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, m, 3) require.Len(t, m, 3)
assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}]) assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}])
@ -109,14 +106,14 @@ func TestSetMapTextUnmarshaller(t *testing.T) {
func TestSetMapInvalidKey(t *testing.T) { func TestSetMapInvalidKey(t *testing.T) {
var m map[int]int var m map[int]int
entries := []string{"invalid=123"} entries := []string{"invalid=123"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true) err := setMap(reflect.ValueOf(&m).Elem(), entries)
assert.Error(t, err) assert.Error(t, err)
} }
func TestSetMapInvalidValue(t *testing.T) { func TestSetMapInvalidValue(t *testing.T) {
var m map[int]int var m map[int]int
entries := []string{"123=invalid"} entries := []string{"123=invalid"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true) err := setMap(reflect.ValueOf(&m).Elem(), entries)
assert.Error(t, err) assert.Error(t, err)
} }
@ -124,7 +121,7 @@ func TestSetMapMalformed(t *testing.T) {
// textUnmarshaler is a struct that captures the length of the string passed to it // textUnmarshaler is a struct that captures the length of the string passed to it
var m map[string]string var m map[string]string
entries := []string{"missing_equals_sign"} entries := []string{"missing_equals_sign"}
err := setMap(reflect.ValueOf(&m).Elem(), entries, true) err := setMap(reflect.ValueOf(&m).Elem(), entries)
assert.Error(t, err) assert.Error(t, err)
} }
@ -135,22 +132,18 @@ func TestSetSliceOrMapErrors(t *testing.T) {
// converting a slice to a reflect.Value in this way will make it read only // converting a slice to a reflect.Value in this way will make it read only
var cannotSet []int var cannotSet []int
dest = reflect.ValueOf(cannotSet) dest = reflect.ValueOf(cannotSet)
err = setSliceOrMap(dest, nil, false) err = setSliceOrMap(dest, nil)
assert.Error(t, err) assert.Error(t, err)
// check what happens when we pass in something that is not a slice or a map // check what happens when we pass in something that is not a slice or a map
var notSliceOrMap string var notSliceOrMap string
dest = reflect.ValueOf(&notSliceOrMap).Elem() dest = reflect.ValueOf(&notSliceOrMap).Elem()
err = setSliceOrMap(dest, nil, false) err = setSliceOrMap(dest, nil)
assert.Error(t, err) assert.Error(t, err)
// check what happens when we pass in a pointer to something that is not a slice or a map // check what happens when we pass in a pointer to something that is not a slice or a map
var stringPtr *string var stringPtr *string
dest = reflect.ValueOf(&stringPtr).Elem() dest = reflect.ValueOf(&stringPtr).Elem()
err = setSliceOrMap(dest, nil, false) err = setSliceOrMap(dest, nil)
assert.Error(t, err) assert.Error(t, err)
} }
// check that we can accumulate "separate" args across env, cmdline, map, and defaults
// check what happens if we have a required arg with a default value