From 1dfefdc43e8a9a06b532b5c29f876eb38f86a928 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Mon, 19 Apr 2021 12:10:53 -0700 Subject: [PATCH] factor setSlice into its own file, add setMap, and add tests for both --- parse.go | 31 -------------- sequence.go | 108 +++++++++++++++++++++++++++++++++++++++++++++++ sequence_test.go | 81 +++++++++++++++++++++++++++++++++++ 3 files changed, 189 insertions(+), 31 deletions(-) create mode 100644 sequence.go create mode 100644 sequence_test.go diff --git a/parse.go b/parse.go index b7d159d..84a7ed1 100644 --- a/parse.go +++ b/parse.go @@ -702,37 +702,6 @@ func (p *Parser) val(dest path) reflect.Value { return v } -// parse a value as the appropriate type and store it in the struct -func setSlice(dest reflect.Value, values []string, trunc bool) error { - if !dest.CanSet() { - return fmt.Errorf("field is not writable") - } - - var ptr bool - elem := dest.Type().Elem() - if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { - ptr = true - elem = elem.Elem() - } - - // Truncate the dest slice in case default values exist - if trunc && !dest.IsNil() { - dest.SetLen(0) - } - - for _, s := range values { - v := reflect.New(elem) - if err := scalar.ParseValue(v.Elem(), s); err != nil { - return err - } - if !ptr { - v = v.Elem() - } - dest.Set(reflect.Append(dest, v)) - } - return nil -} - // findOption finds an option from its name, or returns null if no spec is found func findOption(specs []*spec, name string) *spec { for _, spec := range specs { diff --git a/sequence.go b/sequence.go new file mode 100644 index 0000000..8971341 --- /dev/null +++ b/sequence.go @@ -0,0 +1,108 @@ +package arg + +import ( + "fmt" + "reflect" + "strings" + + scalar "github.com/alexflint/go-scalar" +) + +// setSlice parses a sequence of strings and inserts them into a slice. If clear +// is true then any values already in the slice are removed. +func setSlice(dest reflect.Value, values []string, clear bool) error { + if !dest.CanSet() { + return fmt.Errorf("field is not writable") + } + + var ptr bool + elem := dest.Type().Elem() + if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { + ptr = true + elem = elem.Elem() + } + + // clear the slice in case default values exist + if clear && !dest.IsNil() { + dest.SetLen(0) + } + + // parse the values one-by-one + for _, s := range values { + v := reflect.New(elem) + if err := scalar.ParseValue(v.Elem(), s); err != nil { + return err + } + if !ptr { + v = v.Elem() + } + dest.Set(reflect.Append(dest, v)) + } + return nil +} + +// setMap parses a sequence of name=value strings and inserts them into a map. +// If clear is true then any values already in the map are removed. +func setMap(dest reflect.Value, values []string, clear bool) error { + if !dest.CanSet() { + return fmt.Errorf("field is not writable") + } + + // determine the key and value type + var keyIsPtr bool + keyType := dest.Type().Key() + if keyType.Kind() == reflect.Ptr && !keyType.Implements(textUnmarshalerType) { + keyIsPtr = true + keyType = keyType.Elem() + } + + var valIsPtr bool + valType := dest.Type().Elem() + if valType.Kind() == reflect.Ptr && !valType.Implements(textUnmarshalerType) { + valIsPtr = true + valType = valType.Elem() + } + + // clear the slice in case default values exist + if clear && !dest.IsNil() { + for _, k := range dest.MapKeys() { + dest.SetMapIndex(k, reflect.Value{}) + } + } + + // allocate the map if it is not allocated + if dest.IsNil() { + dest.Set(reflect.MakeMap(dest.Type())) + } + + // parse the values one-by-one + for _, s := range values { + // split at the first equals sign + pos := strings.Index(s, "=") + if pos == -1 { + return fmt.Errorf("cannot parse %q into a map, expected format key=value", s) + } + + // parse the key + k := reflect.New(keyType) + if err := scalar.ParseValue(k.Elem(), s[:pos]); err != nil { + return err + } + if !keyIsPtr { + k = k.Elem() + } + + // parse the value + v := reflect.New(valType) + if err := scalar.ParseValue(v.Elem(), s[pos+1:]); err != nil { + return err + } + if !valIsPtr { + v = v.Elem() + } + + // add it to the map + dest.SetMapIndex(k, v) + } + return nil +} diff --git a/sequence_test.go b/sequence_test.go new file mode 100644 index 0000000..4646811 --- /dev/null +++ b/sequence_test.go @@ -0,0 +1,81 @@ +package arg + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetSliceWithoutClearing(t *testing.T) { + xs := []int{10} + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, false) + require.NoError(t, err) + assert.Equal(t, []int{10, 1, 2, 3}, xs) +} + +func TestSetSliceWithClear(t *testing.T) { + xs := []int{100} + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, xs) +} + +func TestSetSlicePtr(t *testing.T) { + var xs []*int + entries := []string{"1", "2", "3"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, xs, 3) + assert.Equal(t, 1, *xs[0]) + assert.Equal(t, 2, *xs[1]) + assert.Equal(t, 3, *xs[2]) +} + +func TestSetSliceTextUnmarshaller(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var xs []*textUnmarshaler + entries := []string{"a", "aa", "aaa"} + err := setSlice(reflect.ValueOf(&xs).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, xs, 3) + assert.Equal(t, 1, xs[0].val) + assert.Equal(t, 2, xs[1].val) + assert.Equal(t, 3, xs[2].val) +} + +func TestSetMapWithoutClearing(t *testing.T) { + m := map[string]int{"foo": 10} + entries := []string{"a=1", "b=2"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, false) + require.NoError(t, err) + require.Len(t, m, 3) + assert.Equal(t, 1, m["a"]) + assert.Equal(t, 2, m["b"]) + assert.Equal(t, 10, m["foo"]) +} + +func TestSetMapWithClear(t *testing.T) { + m := map[string]int{"foo": 10} + entries := []string{"a=1", "b=2"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 2) + assert.Equal(t, 1, m["a"]) + assert.Equal(t, 2, m["b"]) +} + +func TestSetMapTextUnmarshaller(t *testing.T) { + // textUnmarshaler is a struct that captures the length of the string passed to it + var m map[textUnmarshaler]*textUnmarshaler + entries := []string{"a=123", "aa=12", "aaa=1"} + err := setMap(reflect.ValueOf(&m).Elem(), entries, true) + require.NoError(t, err) + require.Len(t, m, 3) + assert.Equal(t, &textUnmarshaler{3}, m[textUnmarshaler{1}]) + assert.Equal(t, &textUnmarshaler{2}, m[textUnmarshaler{2}]) + assert.Equal(t, &textUnmarshaler{1}, m[textUnmarshaler{3}]) +}