From 23b2b67fe299b63a072a3541f34d57757d0b8df0 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Thu, 9 Jun 2022 11:21:29 -0400 Subject: [PATCH] fix issue #184 --- go.mod | 8 ++++++- parse.go | 43 +++++++++++++++++++++++++-------- parse_test.go | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++ reflect.go | 19 ++++++++++++--- usage_test.go | 3 +-- 5 files changed, 123 insertions(+), 16 deletions(-) diff --git a/go.mod b/go.mod index 67ac880..0823012 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,10 @@ require ( github.com/stretchr/testify v1.7.0 ) -go 1.13 +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) + +go 1.18 diff --git a/parse.go b/parse.go index 7588dfb..28ed014 100644 --- a/parse.go +++ b/parse.go @@ -208,18 +208,41 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { return nil, err } - // add nonzero field values as defaults + // for backwards compatibility, add nonzero field values as defaults for _, spec := range cmd.specs { - if v := p.val(spec.dest); v.IsValid() && !isZero(v) { - if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok { - str, err := defaultVal.MarshalText() - if err != nil { - return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err) - } - spec.defaultVal = string(str) - } else { - spec.defaultVal = fmt.Sprintf("%v", v) + // do not read default when UnmarshalText is implemented but not MarshalText + if isTextUnmarshaler(spec.field.Type) && !isTextMarshaler(spec.field.Type) { + continue + } + + // do not process types that require multiple values + cardinality, _ := cardinalityOf(spec.field.Type) + if cardinality != one { + continue + } + + // get the value + v := p.val(spec.dest) + if !v.IsValid() { + continue + } + + // if MarshalText is implemented then use that + if m, ok := v.Interface().(encoding.TextMarshaler); ok { + if v.IsNil() { + continue } + s, err := m.MarshalText() + if err != nil { + return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err) + } + spec.defaultVal = string(s) + continue + } + + // finally, use the value as a default if it is non-zero + if !isZero(v) { + spec.defaultVal = fmt.Sprintf("%v", v) } } diff --git a/parse_test.go b/parse_test.go index 2d0ef7a..0d58598 100644 --- a/parse_test.go +++ b/parse_test.go @@ -2,6 +2,7 @@ package arg import ( "bytes" + "encoding/json" "fmt" "net" "net/mail" @@ -1456,3 +1457,68 @@ func TestMustParsePrintsVersion(t *testing.T) { assert.Equal(t, 0, *exitCode) assert.Equal(t, "example 3.2.1\n", b.String()) } + +type jsonMap struct { + val map[string]string +} + +func (v *jsonMap) UnmarshalText(data []byte) error { + return json.Unmarshal(data, &v.val) +} + +func TestTextUnmarshallerEmpty(t *testing.T) { + // based on https://github.com/alexflint/go-arg/issues/184 + var args struct { + Config jsonMap `arg:"--config"` + } + + err := parse("", &args) + require.NoError(t, err) + assert.Empty(t, args.Config) +} + +func TestTextUnmarshallerEmptyPointer(t *testing.T) { + // a slight variant on https://github.com/alexflint/go-arg/issues/184 + var args struct { + Config *jsonMap `arg:"--config"` + } + + err := parse("", &args) + require.NoError(t, err) + assert.Nil(t, args.Config) +} + +// similar to the above but also implements MarshalText +type jsonMap2[T any] struct { + val T +} + +func (v *jsonMap2[T]) MarshalText(data []byte) error { + return json.Unmarshal(data, &v.val) +} + +func (v *jsonMap2[T]) UnmarshalText(data []byte) error { + return json.Unmarshal(data, &v.val) +} + +func TestTextMarshallerUnmarshallerEmpty(t *testing.T) { + // based on https://github.com/alexflint/go-arg/issues/184 + var args struct { + Config jsonMap2[map[string]string] `arg:"--config"` + } + + err := parse("", &args) + require.NoError(t, err) + assert.Empty(t, args.Config) +} + +func TestTextMarshallerUnmarshallerEmptyPointer(t *testing.T) { + // a slight variant on https://github.com/alexflint/go-arg/issues/184 + var args struct { + Config *jsonMap2[map[string]string] `arg:"--config"` + } + + err := parse("", &args) + require.NoError(t, err) + assert.Nil(t, args.Config) +} diff --git a/reflect.go b/reflect.go index cd80be7..b87db2a 100644 --- a/reflect.go +++ b/reflect.go @@ -10,7 +10,10 @@ import ( scalar "github.com/alexflint/go-scalar" ) -var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() +var ( + textMarshalerType = reflect.TypeOf([]encoding.TextMarshaler{}).Elem() + textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() +) // cardinality tracks how many tokens are expected for a given spec // - zero is a boolean, which does to expect any value @@ -74,10 +77,10 @@ func cardinalityOf(t reflect.Type) (cardinality, error) { } } -// isBoolean returns true if the type can be parsed from a single string +// isBoolean returns true if the type is a boolean or a pointer to a boolean func isBoolean(t reflect.Type) bool { switch { - case t.Implements(textUnmarshalerType): + case isTextUnmarshaler(t): return false case t.Kind() == reflect.Bool: return true @@ -88,6 +91,16 @@ func isBoolean(t reflect.Type) bool { } } +// isTextMarshaler returns true if the type or its pointer implements encoding.TextMarshaler +func isTextMarshaler(t reflect.Type) bool { + return t.Implements(textMarshalerType) || reflect.PtrTo(t).Implements(textMarshalerType) +} + +// isTextUnmarshaler returns true if the type or its pointer implements encoding.TextUnmarshaler +func isTextUnmarshaler(t reflect.Type) bool { + return t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType) +} + // isExported returns true if the struct field name is exported func isExported(field string) bool { r, _ := utf8.DecodeRuneInString(field) // returns RuneError for empty string or invalid UTF8 diff --git a/usage_test.go b/usage_test.go index 1744536..0a7ddd8 100644 --- a/usage_test.go +++ b/usage_test.go @@ -50,7 +50,7 @@ Options: --optimize OPTIMIZE, -O OPTIMIZE optimization level --ids IDS Ids - --values VALUES Values [default: [3.14 42 256]] + --values VALUES Values --workers WORKERS, -w WORKERS number of workers to start [default: 10, env: WORKERS] --testenv TESTENV, -a TESTENV [env: TEST_ENV] @@ -74,7 +74,6 @@ Options: } args.Name = "Foo Bar" args.Value = 42 - args.Values = []float64{3.14, 42, 256} args.File = &NameDotName{"scratch", "txt"} p, err := NewParser(Config{Program: "example"}, &args) require.NoError(t, err)