diff --git a/v2/parse.go b/v2/parse.go index 251ddb6..dcdd353 100644 --- a/v2/parse.go +++ b/v2/parse.go @@ -247,7 +247,7 @@ func (p *Parser) processOptions(args []string, overwrite bool) ([]string, error) } // store the values into the slice or map - err := setSliceOrMap(p.val(arg.dest), values, !arg.separate) + err := setSliceOrMap(p.val(arg.dest), values) if err != nil { return nil, fmt.Errorf("error processing %s: %v", token, err) } @@ -312,7 +312,7 @@ func (p *Parser) processPositionals(positionals []string, overwrite bool) error } if arg.cardinality == multiple { if !p.seen[arg] || overwrite { - err := setSliceOrMap(p.val(arg.dest), positionals, true) + err := setSliceOrMap(p.val(arg.dest), positionals) if err != nil { return fmt.Errorf("error processing %s: %v", arg.field.Name, err) } @@ -385,19 +385,20 @@ func (p *Parser) processEnvironment(environ []string, overwrite bool) error { if len(strings.TrimSpace(value)) > 0 { values, err = csv.NewReader(strings.NewReader(value)).Read() if err != nil { - return fmt.Errorf( - "error reading a CSV string from environment variable %s with multiple values: %v", - arg.env, - err, - ) + return fmt.Errorf("error reading a CSV string from environment variable %s : %v", arg.env, err) } } - if err = setSliceOrMap(p.val(arg.dest), values, !arg.separate); err != nil { - return fmt.Errorf( - "error processing environment variable %s with multiple values: %v", - arg.env, - err, - ) + + if arg.separate { + if err = setSliceOrMap(p.val(arg.dest), values); err != nil { + return fmt.Errorf("error processing environment variable %s: %v", arg.env, err) + } + } else { + for _, s := range values { + if err = appendToSliceOrMap(p.val(arg.dest), s); err != nil { + return fmt.Errorf("error processing environment variable %s: %v", arg.env, err) + } + } } } else { if err := scalar.ParseValue(p.val(arg.dest), value); err != nil { diff --git a/v2/sequence.go b/v2/sequence.go index f0fff46..566c8d2 100644 --- a/v2/sequence.go +++ b/v2/sequence.go @@ -8,9 +8,9 @@ import ( scalar "github.com/alexflint/go-scalar" ) -// setSliceOrMap parses a sequence of strings into a slice or map. If clear is -// true then any values already in the slice or map are first removed. -func setSliceOrMap(dest reflect.Value, values []string, clear bool) error { +// setSliceOrMap parses a sequence of strings into a slice or map. The slice or +// map is always cleared first. +func setSliceOrMap(dest reflect.Value, values []string) error { if !dest.CanSet() { return fmt.Errorf("field is not writable") } @@ -23,17 +23,17 @@ func setSliceOrMap(dest reflect.Value, values []string, clear bool) error { switch t.Kind() { case reflect.Slice: - return setSlice(dest, values, clear) + return setSlice(dest, values) case reflect.Map: - return setMap(dest, values, clear) + return setMap(dest, values) default: return fmt.Errorf("cannot insert multiple values into a %v", t) } } -// setSlice parses a sequence of strings and inserts them into a slice. If clear -// is true then any values already in the slice are removed. -func setSlice(dest reflect.Value, values []string, clear bool) error { +// setSlice parses a sequence of strings and inserts them into a slice. The +// slice is cleared first. +func setSlice(dest reflect.Value, values []string) error { var ptr bool elem := dest.Type().Elem() if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { @@ -42,7 +42,7 @@ func setSlice(dest reflect.Value, values []string, clear bool) error { } // clear the slice in case default values exist - if clear && !dest.IsNil() { + if !dest.IsNil() { dest.SetLen(0) } @@ -61,8 +61,8 @@ func setSlice(dest reflect.Value, values []string, clear bool) error { } // setMap parses a sequence of name=value strings and inserts them into a map. -// If clear is true then any values already in the map are removed. -func setMap(dest reflect.Value, values []string, clear bool) error { +// The map is always cleared first. +func setMap(dest reflect.Value, values []string) error { // determine the key and value type var keyIsPtr bool keyType := dest.Type().Key() @@ -79,7 +79,7 @@ func setMap(dest reflect.Value, values []string, clear bool) error { } // clear the slice in case default values exist - if clear && !dest.IsNil() { + if !dest.IsNil() { for _, k := range dest.MapKeys() { dest.SetMapIndex(k, reflect.Value{}) } diff --git a/v2/sequence_test.go b/v2/sequence_test.go index 6383949..519cdec 100644 --- a/v2/sequence_test.go +++ b/v2/sequence_test.go @@ -8,18 +8,17 @@ import ( "github.com/stretchr/testify/require" ) -func TestSetSliceWithoutClearing(t *testing.T) { +func TestAppendToSlice(t *testing.T) { xs := []int{10} - entries := []string{"1", "2", "3"} - err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false) + err := appendToSlice(reflect.ValueOf(&xs).Elem(), "3") require.NoError(t, err) - assert.Equal(t, []int{10, 1, 2, 3}, xs) + assert.Equal(t, []int{10, 3}, xs) } -func TestSetSliceAfterClearing(t *testing.T) { +func TestSetSlice(t *testing.T) { xs := []int{100} entries := []string{"1", "2", "3"} - err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + err := setSlice(reflect.ValueOf(&xs).Elem(), entries) require.NoError(t, err) assert.Equal(t, []int{1, 2, 3}, xs) } @@ -27,14 +26,14 @@ func TestSetSliceAfterClearing(t *testing.T) { func TestSetSliceInvalid(t *testing.T) { xs := []int{100} entries := []string{"invalid"} - err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + err := setSlice(reflect.ValueOf(&xs).Elem(), entries) assert.Error(t, err) } func TestSetSlicePtr(t *testing.T) { var xs []*int entries := []string{"1", "2", "3"} - err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + err := setSlice(reflect.ValueOf(&xs).Elem(), entries) require.NoError(t, err) require.Len(t, xs, 3) assert.Equal(t, 1, *xs[0]) @@ -46,7 +45,7 @@ func TestSetSliceTextUnmarshaller(t *testing.T) { // textUnmarshaler is a struct that captures the length of the string passed to it var xs []*textUnmarshaler entries := []string{"a", "aa", "aaa"} - err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + err := setSlice(reflect.ValueOf(&xs).Elem(), entries) require.NoError(t, err) require.Len(t, xs, 3) assert.Equal(t, 1, xs[0].val) @@ -54,21 +53,19 @@ func TestSetSliceTextUnmarshaller(t *testing.T) { assert.Equal(t, 3, xs[2].val) } -func TestSetMapWithoutClearing(t *testing.T) { +func TestAppendToMap(t *testing.T) { m := map[string]int{"foo": 10} - entries := []string{"a=1", "b=2"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, false) + err := appendToMap(reflect.ValueOf(&m).Elem(), "a=1") require.NoError(t, err) - require.Len(t, m, 3) + require.Len(t, m, 2) assert.Equal(t, 1, m["a"]) - assert.Equal(t, 2, m["b"]) assert.Equal(t, 10, m["foo"]) } func TestSetMapAfterClearing(t *testing.T) { m := map[string]int{"foo": 10} entries := []string{"a=1", "b=2"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) require.NoError(t, err) require.Len(t, m, 2) assert.Equal(t, 1, m["a"]) @@ -79,7 +76,7 @@ func TestSetMapWithKeyPointer(t *testing.T) { // textUnmarshaler is a struct that captures the length of the string passed to it var m map[*string]int entries := []string{"abc=123"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) require.NoError(t, err) require.Len(t, m, 1) } @@ -88,7 +85,7 @@ func TestSetMapWithValuePointer(t *testing.T) { // textUnmarshaler is a struct that captures the length of the string passed to it var m map[string]*int entries := []string{"abc=123"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) require.NoError(t, err) require.Len(t, m, 1) assert.Equal(t, 123, *m["abc"]) @@ -98,7 +95,7 @@ func TestSetMapTextUnmarshaller(t *testing.T) { // textUnmarshaler is a struct that captures the length of the string passed to it var m map[textUnmarshaler]*textUnmarshaler entries := []string{"a=123", "aa=12", "aaa=1"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) require.NoError(t, err) require.Len(t, m, 3) assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}]) @@ -109,14 +106,14 @@ func TestSetMapTextUnmarshaller(t *testing.T) { func TestSetMapInvalidKey(t *testing.T) { var m map[int]int entries := []string{"invalid=123"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) assert.Error(t, err) } func TestSetMapInvalidValue(t *testing.T) { var m map[int]int entries := []string{"123=invalid"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) assert.Error(t, err) } @@ -124,7 +121,7 @@ func TestSetMapMalformed(t *testing.T) { // textUnmarshaler is a struct that captures the length of the string passed to it var m map[string]string entries := []string{"missing_equals_sign"} - err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + err := setMap(reflect.ValueOf(&m).Elem(), entries) assert.Error(t, err) } @@ -135,22 +132,18 @@ func TestSetSliceOrMapErrors(t *testing.T) { // converting a slice to a reflect.Value in this way will make it read only var cannotSet []int dest = reflect.ValueOf(cannotSet) - err = setSliceOrMap(dest, nil, false) + err = setSliceOrMap(dest, nil) assert.Error(t, err) // check what happens when we pass in something that is not a slice or a map var notSliceOrMap string dest = reflect.ValueOf(¬SliceOrMap).Elem() - err = setSliceOrMap(dest, nil, false) + err = setSliceOrMap(dest, nil) assert.Error(t, err) // check what happens when we pass in a pointer to something that is not a slice or a map var stringPtr *string dest = reflect.ValueOf(&stringPtr).Elem() - err = setSliceOrMap(dest, nil, false) + err = setSliceOrMap(dest, nil) assert.Error(t, err) } - -// check that we can accumulate "separate" args across env, cmdline, map, and defaults - -// check what happens if we have a required arg with a default value