store both a default value and a string representation of that default value in the spec for each option

This commit is contained in:
Alex Flint 2022-10-29 14:47:13 -04:00
parent 197e226c77
commit 27c832b934
4 changed files with 102 additions and 74 deletions

130
parse.go
View File

@ -43,18 +43,19 @@ func (p path) Child(f reflect.StructField) path {
// spec represents a command line option // spec represents a command line option
type spec struct { type spec struct {
dest path dest path
field reflect.StructField // the struct field from which this option was created field reflect.StructField // the struct field from which this option was created
long string // the --long form for this option, or empty if none long string // the --long form for this option, or empty if none
short string // the -s short form for this option, or empty if none short string // the -s short form for this option, or empty if none
cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple) cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple)
required bool // if true, this option must be present on the command line required bool // if true, this option must be present on the command line
positional bool // if true, this option will be looked for in the positional flags positional bool // if true, this option will be looked for in the positional flags
separate bool // if true, each slice and map entry will have its own --flag separate bool // if true, each slice and map entry will have its own --flag
help string // the help text for this option help string // the help text for this option
env string // the name of the environment variable for this option, or empty for none env string // the name of the environment variable for this option, or empty for none
defaultVal string // default value for this option defaultValue reflect.Value // default value for this option
placeholder string // name of the data in help defaultString string // default value for this option, in string form to be displayed in help text
placeholder string // name of the data in help
} }
// command represents a named subcommand, or the top-level command // command represents a named subcommand, or the top-level command
@ -210,39 +211,30 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
// for backwards compatibility, add nonzero field values as defaults // for backwards compatibility, add nonzero field values as defaults
for _, spec := range cmd.specs { for _, spec := range cmd.specs {
// do not read default when UnmarshalText is implemented but not MarshalText
if isTextUnmarshaler(spec.field.Type) && !isTextMarshaler(spec.field.Type) {
continue
}
// do not process types that require multiple values
cardinality, _ := cardinalityOf(spec.field.Type)
if cardinality != one {
continue
}
// get the value // get the value
v := p.val(spec.dest) v := p.val(spec.dest)
if !v.IsValid() { if !v.IsValid() {
continue continue
} }
// if the value is the "zero value" (e.g. nil pointer, empty struct) then ignore
if isZero(v) {
continue
}
// store as a default
spec.defaultValue = v
// we need a string to display in help text
// if MarshalText is implemented then use that // if MarshalText is implemented then use that
if m, ok := v.Interface().(encoding.TextMarshaler); ok { if m, ok := v.Interface().(encoding.TextMarshaler); ok {
if v.IsNil() {
continue
}
s, err := m.MarshalText() s, err := m.MarshalText()
if err != nil { if err != nil {
return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err) return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
} }
spec.defaultVal = string(s) spec.defaultString = string(s)
continue } else {
} spec.defaultString = fmt.Sprintf("%v", v)
// finally, use the value as a default if it is non-zero
if !isZero(v) {
spec.defaultVal = fmt.Sprintf("%v", v)
} }
} }
@ -311,11 +303,6 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
spec.help = help spec.help = help
} }
defaultVal, hasDefault := field.Tag.Lookup("default")
if hasDefault {
spec.defaultVal = defaultVal
}
// Look at the tag // Look at the tag
var isSubcommand bool // tracks whether this field is a subcommand var isSubcommand bool // tracks whether this field is a subcommand
for _, key := range strings.Split(tag, ",") { for _, key := range strings.Split(tag, ",") {
@ -342,11 +329,6 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
} }
spec.short = key[1:] spec.short = key[1:]
case key == "required": case key == "required":
if hasDefault {
errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified",
t.Name(), field.Name))
return false
}
spec.required = true spec.required = true
case key == "positional": case key == "positional":
spec.positional = true spec.positional = true
@ -395,27 +377,60 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
spec.placeholder = strings.ToUpper(spec.field.Name) spec.placeholder = strings.ToUpper(spec.field.Name)
} }
// Check whether this field is supported. It's good to do this here rather than // if this is a subcommand then we've done everything we need to do
if isSubcommand {
return false
}
// check whether this field is supported. It's good to do this here rather than
// wait until ParseValue because it means that a program with invalid argument // wait until ParseValue because it means that a program with invalid argument
// fields will always fail regardless of whether the arguments it received // fields will always fail regardless of whether the arguments it received
// exercised those fields. // exercised those fields.
if !isSubcommand { var err error
cmd.specs = append(cmd.specs, &spec) spec.cardinality, err = cardinalityOf(field.Type)
if err != nil {
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
t.Name(), field.Name, field.Type.String()))
return false
}
var err error defaultString, hasDefault := field.Tag.Lookup("default")
spec.cardinality, err = cardinalityOf(field.Type) if hasDefault {
if err != nil { // we do not support default values for maps and slices
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported", if spec.cardinality == multiple {
t.Name(), field.Name, field.Type.String()))
return false
}
if spec.cardinality == multiple && hasDefault {
errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields", errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields",
t.Name(), field.Name)) t.Name(), field.Name))
return false return false
} }
// a required field cannot also have a default value
if spec.required {
errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified",
t.Name(), field.Name))
return false
}
// parse the default value
spec.defaultString = defaultString
if field.Type.Kind() == reflect.Pointer {
// here we have a field of type *T and we create a new T, no need to dereference
// in order for the value to be settable
spec.defaultValue = reflect.New(field.Type.Elem())
} else {
// here we have a field of type T and we create a new T and then dereference it
// so that the resulting value is settable
spec.defaultValue = reflect.New(field.Type).Elem()
}
err := scalar.ParseValue(spec.defaultValue, defaultString)
if err != nil {
errs = append(errs, fmt.Sprintf("%s.%s: error processing default value: %v", t.Name(), field.Name, err))
return false
}
} }
// add the spec to the list of specs
cmd.specs = append(cmd.specs, &spec)
// if this was an embedded field then we already returned true up above // if this was an embedded field then we already returned true up above
return false return false
}) })
@ -682,11 +697,8 @@ func (p *Parser) process(args []string) error {
} }
return errors.New(msg) return errors.New(msg)
} }
if spec.defaultVal != "" { if spec.defaultValue.IsValid() {
err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal) p.val(spec.dest).Set(spec.defaultValue)
if err != nil {
return fmt.Errorf("error processing default value for %s: %v", name, err)
}
} }
} }

View File

@ -1321,13 +1321,21 @@ func TestDefaultOptionValues(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 123, args.A) assert.Equal(t, 123, args.A)
assert.Equal(t, 123, *args.B) if assert.NotNil(t, args.B) {
assert.Equal(t, 123, *args.B)
}
assert.Equal(t, "xyz", args.C) assert.Equal(t, "xyz", args.C)
assert.Equal(t, "abc", *args.D) if assert.NotNil(t, args.D) {
assert.Equal(t, "abc", *args.D)
}
assert.Equal(t, 4.56, args.E) assert.Equal(t, 4.56, args.E)
assert.Equal(t, 1.23, *args.F) if assert.NotNil(t, args.F) {
assert.True(t, args.G) assert.Equal(t, 1.23, *args.F)
}
assert.True(t, args.G) assert.True(t, args.G)
if assert.NotNil(t, args.H) {
assert.True(t, *args.H)
}
} }
func TestDefaultUnparseable(t *testing.T) { func TestDefaultUnparseable(t *testing.T) {
@ -1336,7 +1344,7 @@ func TestDefaultUnparseable(t *testing.T) {
} }
err := parse("", &args) err := parse("", &args)
assert.EqualError(t, err, `error processing default value for --a: strconv.ParseInt: parsing "x": invalid syntax`) assert.EqualError(t, err, `.A: error processing default value: strconv.ParseInt: parsing "x": invalid syntax`)
} }
func TestDefaultPositionalValues(t *testing.T) { func TestDefaultPositionalValues(t *testing.T) {
@ -1355,13 +1363,21 @@ func TestDefaultPositionalValues(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 456, args.A) assert.Equal(t, 456, args.A)
assert.Equal(t, 789, *args.B) if assert.NotNil(t, args.B) {
assert.Equal(t, 789, *args.B)
}
assert.Equal(t, "abc", args.C) assert.Equal(t, "abc", args.C)
assert.Equal(t, "abc", *args.D) if assert.NotNil(t, args.D) {
assert.Equal(t, "abc", *args.D)
}
assert.Equal(t, 1.23, args.E) assert.Equal(t, 1.23, args.E)
assert.Equal(t, 1.23, *args.F) if assert.NotNil(t, args.F) {
assert.True(t, args.G) assert.Equal(t, 1.23, *args.F)
}
assert.True(t, args.G) assert.True(t, args.G)
if assert.NotNil(t, args.H) {
assert.True(t, *args.H)
}
} }
func TestDefaultValuesNotAllowedWithRequired(t *testing.T) { func TestDefaultValuesNotAllowedWithRequired(t *testing.T) {
@ -1375,7 +1391,7 @@ func TestDefaultValuesNotAllowedWithRequired(t *testing.T) {
func TestDefaultValuesNotAllowedWithSlice(t *testing.T) { func TestDefaultValuesNotAllowedWithSlice(t *testing.T) {
var args struct { var args struct {
A []int `default:"123"` // required not allowed with default! A []int `default:"invalid"` // default values not allowed with slices
} }
err := parse("", &args) err := parse("", &args)

View File

@ -16,9 +16,9 @@ var (
) )
// cardinality tracks how many tokens are expected for a given spec // cardinality tracks how many tokens are expected for a given spec
// - zero is a boolean, which does to expect any value // - zero is a boolean, which does to expect any value
// - one is an ordinary option that will be parsed from a single token // - one is an ordinary option that will be parsed from a single token
// - multiple is a slice or map that can accept zero or more tokens // - multiple is a slice or map that can accept zero or more tokens
type cardinality int type cardinality int
const ( const (
@ -110,7 +110,7 @@ func isExported(field string) bool {
// isZero returns true if v contains the zero value for its type // isZero returns true if v contains the zero value for its type
func isZero(v reflect.Value) bool { func isZero(v reflect.Value) bool {
t := v.Type() t := v.Type()
if t.Kind() == reflect.Slice || t.Kind() == reflect.Map { if t.Kind() == reflect.Pointer || t.Kind() == reflect.Slice || t.Kind() == reflect.Map || t.Kind() == reflect.Chan || t.Kind() == reflect.Interface {
return v.IsNil() return v.IsNil()
} }
if !t.Comparable() { if !t.Comparable() {

View File

@ -474,7 +474,7 @@ Options:
ShortOnly2 string `arg:"-b,--,required" help:"some help2"` ShortOnly2 string `arg:"-b,--,required" help:"some help2"`
} }
p, err := NewParser(Config{Program: "example"}, &args) p, err := NewParser(Config{Program: "example"}, &args)
assert.NoError(t, err) require.NoError(t, err)
var help bytes.Buffer var help bytes.Buffer
p.WriteHelp(&help) p.WriteHelp(&help)