fix: ensure Parser.val cannot be nil
Used when adding values to subcommands as these are undefined pointers initially. This refactoring is in preparation of parsing group structs later.
This commit is contained in:
parent
00c1c8e7cd
commit
ca8dc31b84
74
parse.go
74
parse.go
|
@ -216,27 +216,19 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
|
|||
// is the reason that this method for setting default values was deprecated)
|
||||
for _, spec := range p.cmd.specs {
|
||||
// get the value
|
||||
v := p.val(spec.dest)
|
||||
defaultString, defaultValue, err := p.defaultVal(spec.dest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if the value is the "zero value" (e.g. nil pointer, empty struct) then ignore
|
||||
if isZero(v) {
|
||||
if defaultString == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// store as a default
|
||||
spec.defaultValue = v
|
||||
|
||||
// we need a string to display in help text
|
||||
// if MarshalText is implemented then use that
|
||||
if m, ok := v.Interface().(encoding.TextMarshaler); ok {
|
||||
s, err := m.MarshalText()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
|
||||
}
|
||||
spec.defaultString = string(s)
|
||||
} else {
|
||||
spec.defaultString = fmt.Sprintf("%v", v)
|
||||
}
|
||||
spec.defaultString = defaultString
|
||||
spec.defaultValue = defaultValue
|
||||
}
|
||||
|
||||
if dest, ok := dest.(Versioned); ok {
|
||||
|
@ -575,11 +567,8 @@ func (p *Parser) process(args []string) error {
|
|||
return fmt.Errorf("invalid subcommand: %s", arg)
|
||||
}
|
||||
|
||||
// instantiate the field to point to a new struct
|
||||
v := p.val(subcmd.dest)
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers
|
||||
}
|
||||
// ensure the command struct exists (is not a nil pointer)
|
||||
p.val(subcmd.dest)
|
||||
|
||||
// add the new options to the set of allowed options
|
||||
specs = append(specs, subcmd.specs...)
|
||||
|
@ -743,20 +732,57 @@ func isFlag(s string) bool {
|
|||
return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != ""
|
||||
}
|
||||
|
||||
// val returns a reflect.Value corresponding to the current value for the
|
||||
// given path
|
||||
func (p *Parser) val(dest path) reflect.Value {
|
||||
// defaultVal returns the string representation of the value at dest if it is
|
||||
// reachable without traversing nil pointers, but only if it does not represent
|
||||
// the default value for the type.
|
||||
func (p *Parser) defaultVal(dest path) (string, reflect.Value, error) {
|
||||
v := p.roots[dest.root]
|
||||
for _, field := range dest.fields {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
return reflect.Value{}
|
||||
return "", v, nil
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
v = v.FieldByIndex(field.Index)
|
||||
}
|
||||
|
||||
if !v.IsValid() || isZero(v) {
|
||||
return "", v, nil
|
||||
}
|
||||
|
||||
if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok {
|
||||
str, err := defaultVal.MarshalText()
|
||||
if err != nil {
|
||||
return "", v, fmt.Errorf("%v: error marshaling default value to string: %w", dest, err)
|
||||
}
|
||||
return string(str), v, nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", v), v, nil
|
||||
}
|
||||
|
||||
// val returns a reflect.Value corresponding to the current value for the
|
||||
// given path initiating nil pointers in the path
|
||||
func (p *Parser) val(dest path) reflect.Value {
|
||||
v := p.roots[dest.root]
|
||||
for _, field := range dest.fields {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
v = v.FieldByIndex(field.Index)
|
||||
}
|
||||
|
||||
// Don't return a nil-pointer
|
||||
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
|
|
|
@ -930,6 +930,17 @@ func TestParserMustParse(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNonPointerSubcommand(t *testing.T) {
|
||||
var args struct {
|
||||
Sub struct {
|
||||
Foo string `arg:"env"`
|
||||
} `arg:"subcommand"`
|
||||
}
|
||||
|
||||
_, err := NewParser(Config{IgnoreEnv: true}, &args)
|
||||
require.Error(t, err, "subcommands must be pointers to structs but args.Sub is a struct")
|
||||
}
|
||||
|
||||
type textUnmarshaler struct {
|
||||
val int
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package arg
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -402,12 +401,8 @@ func TestValForNilStruct(t *testing.T) {
|
|||
Sub *subcmd `arg:"subcommand"`
|
||||
}
|
||||
|
||||
p, err := NewParser(Config{}, &cmd)
|
||||
_, err := NewParser(Config{}, &cmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
typ := reflect.TypeOf(cmd)
|
||||
subField, _ := typ.FieldByName("Sub")
|
||||
|
||||
v := p.val(path{fields: []reflect.StructField{subField, subField}})
|
||||
assert.False(t, v.IsValid())
|
||||
require.Nil(t, cmd.Sub)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue