This commit is contained in:
Alex Flint 2022-06-09 11:21:29 -04:00
parent f0f44b65d1
commit 23b2b67fe2
5 changed files with 123 additions and 16 deletions

8
go.mod
View File

@ -5,4 +5,10 @@ require (
github.com/stretchr/testify v1.7.0
)
go 1.13
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)
go 1.18

View File

@ -208,18 +208,41 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
return nil, err
}
// add nonzero field values as defaults
// for backwards compatibility, add nonzero field values as defaults
for _, spec := range cmd.specs {
if v := p.val(spec.dest); v.IsValid() && !isZero(v) {
if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok {
str, err := defaultVal.MarshalText()
if err != nil {
return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
}
spec.defaultVal = string(str)
} else {
spec.defaultVal = fmt.Sprintf("%v", v)
// do not read default when UnmarshalText is implemented but not MarshalText
if isTextUnmarshaler(spec.field.Type) && !isTextMarshaler(spec.field.Type) {
continue
}
// do not process types that require multiple values
cardinality, _ := cardinalityOf(spec.field.Type)
if cardinality != one {
continue
}
// get the value
v := p.val(spec.dest)
if !v.IsValid() {
continue
}
// if MarshalText is implemented then use that
if m, ok := v.Interface().(encoding.TextMarshaler); ok {
if v.IsNil() {
continue
}
s, err := m.MarshalText()
if err != nil {
return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
}
spec.defaultVal = string(s)
continue
}
// finally, use the value as a default if it is non-zero
if !isZero(v) {
spec.defaultVal = fmt.Sprintf("%v", v)
}
}

View File

@ -2,6 +2,7 @@ package arg
import (
"bytes"
"encoding/json"
"fmt"
"net"
"net/mail"
@ -1456,3 +1457,68 @@ func TestMustParsePrintsVersion(t *testing.T) {
assert.Equal(t, 0, *exitCode)
assert.Equal(t, "example 3.2.1\n", b.String())
}
type jsonMap struct {
val map[string]string
}
func (v *jsonMap) UnmarshalText(data []byte) error {
return json.Unmarshal(data, &v.val)
}
func TestTextUnmarshallerEmpty(t *testing.T) {
// based on https://github.com/alexflint/go-arg/issues/184
var args struct {
Config jsonMap `arg:"--config"`
}
err := parse("", &args)
require.NoError(t, err)
assert.Empty(t, args.Config)
}
func TestTextUnmarshallerEmptyPointer(t *testing.T) {
// a slight variant on https://github.com/alexflint/go-arg/issues/184
var args struct {
Config *jsonMap `arg:"--config"`
}
err := parse("", &args)
require.NoError(t, err)
assert.Nil(t, args.Config)
}
// similar to the above but also implements MarshalText
type jsonMap2[T any] struct {
val T
}
func (v *jsonMap2[T]) MarshalText(data []byte) error {
return json.Unmarshal(data, &v.val)
}
func (v *jsonMap2[T]) UnmarshalText(data []byte) error {
return json.Unmarshal(data, &v.val)
}
func TestTextMarshallerUnmarshallerEmpty(t *testing.T) {
// based on https://github.com/alexflint/go-arg/issues/184
var args struct {
Config jsonMap2[map[string]string] `arg:"--config"`
}
err := parse("", &args)
require.NoError(t, err)
assert.Empty(t, args.Config)
}
func TestTextMarshallerUnmarshallerEmptyPointer(t *testing.T) {
// a slight variant on https://github.com/alexflint/go-arg/issues/184
var args struct {
Config *jsonMap2[map[string]string] `arg:"--config"`
}
err := parse("", &args)
require.NoError(t, err)
assert.Nil(t, args.Config)
}

View File

@ -10,7 +10,10 @@ import (
scalar "github.com/alexflint/go-scalar"
)
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
var (
textMarshalerType = reflect.TypeOf([]encoding.TextMarshaler{}).Elem()
textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
)
// cardinality tracks how many tokens are expected for a given spec
// - zero is a boolean, which does to expect any value
@ -74,10 +77,10 @@ func cardinalityOf(t reflect.Type) (cardinality, error) {
}
}
// isBoolean returns true if the type can be parsed from a single string
// isBoolean returns true if the type is a boolean or a pointer to a boolean
func isBoolean(t reflect.Type) bool {
switch {
case t.Implements(textUnmarshalerType):
case isTextUnmarshaler(t):
return false
case t.Kind() == reflect.Bool:
return true
@ -88,6 +91,16 @@ func isBoolean(t reflect.Type) bool {
}
}
// isTextMarshaler returns true if the type or its pointer implements encoding.TextMarshaler
func isTextMarshaler(t reflect.Type) bool {
return t.Implements(textMarshalerType) || reflect.PtrTo(t).Implements(textMarshalerType)
}
// isTextUnmarshaler returns true if the type or its pointer implements encoding.TextUnmarshaler
func isTextUnmarshaler(t reflect.Type) bool {
return t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType)
}
// isExported returns true if the struct field name is exported
func isExported(field string) bool {
r, _ := utf8.DecodeRuneInString(field) // returns RuneError for empty string or invalid UTF8

View File

@ -50,7 +50,7 @@ Options:
--optimize OPTIMIZE, -O OPTIMIZE
optimization level
--ids IDS Ids
--values VALUES Values [default: [3.14 42 256]]
--values VALUES Values
--workers WORKERS, -w WORKERS
number of workers to start [default: 10, env: WORKERS]
--testenv TESTENV, -a TESTENV [env: TEST_ENV]
@ -74,7 +74,6 @@ Options:
}
args.Name = "Foo Bar"
args.Value = 42
args.Values = []float64{3.14, 42, 256}
args.File = &NameDotName{"scratch", "txt"}
p, err := NewParser(Config{Program: "example"}, &args)
require.NoError(t, err)