fix: ensure Parser.val cannot be nil

Used when adding values to subcommands as
these are undefined pointers initially.

This refactoring is in preparation of
parsing group structs later.
This commit is contained in:
Sebastiaan Pasterkamp 2022-12-10 14:59:23 +01:00
parent 00c1c8e7cd
commit ca8dc31b84
3 changed files with 63 additions and 31 deletions

View File

@ -216,27 +216,19 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
// is the reason that this method for setting default values was deprecated)
for _, spec := range p.cmd.specs {
// get the value
v := p.val(spec.dest)
defaultString, defaultValue, err := p.defaultVal(spec.dest)
if err != nil {
return nil, err
}
// if the value is the "zero value" (e.g. nil pointer, empty struct) then ignore
if isZero(v) {
if defaultString == "" {
continue
}
// store as a default
spec.defaultValue = v
// we need a string to display in help text
// if MarshalText is implemented then use that
if m, ok := v.Interface().(encoding.TextMarshaler); ok {
s, err := m.MarshalText()
if err != nil {
return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
}
spec.defaultString = string(s)
} else {
spec.defaultString = fmt.Sprintf("%v", v)
}
spec.defaultString = defaultString
spec.defaultValue = defaultValue
}
if dest, ok := dest.(Versioned); ok {
@ -575,11 +567,8 @@ func (p *Parser) process(args []string) error {
return fmt.Errorf("invalid subcommand: %s", arg)
}
// instantiate the field to point to a new struct
v := p.val(subcmd.dest)
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers
}
// ensure the command struct exists (is not a nil pointer)
p.val(subcmd.dest)
// add the new options to the set of allowed options
specs = append(specs, subcmd.specs...)
@ -743,20 +732,57 @@ func isFlag(s string) bool {
return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != ""
}
// val returns a reflect.Value corresponding to the current value for the
// given path
func (p *Parser) val(dest path) reflect.Value {
// defaultVal returns the string representation of the value at dest if it is
// reachable without traversing nil pointers, but only if it does not represent
// the default value for the type.
func (p *Parser) defaultVal(dest path) (string, reflect.Value, error) {
v := p.roots[dest.root]
for _, field := range dest.fields {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return reflect.Value{}
return "", v, nil
}
v = v.Elem()
}
v = v.FieldByIndex(field.Index)
}
if !v.IsValid() || isZero(v) {
return "", v, nil
}
if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok {
str, err := defaultVal.MarshalText()
if err != nil {
return "", v, fmt.Errorf("%v: error marshaling default value to string: %w", dest, err)
}
return string(str), v, nil
}
return fmt.Sprintf("%v", v), v, nil
}
// val returns a reflect.Value corresponding to the current value for the
// given path initiating nil pointers in the path
func (p *Parser) val(dest path) reflect.Value {
v := p.roots[dest.root]
for _, field := range dest.fields {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
v = v.FieldByIndex(field.Index)
}
// Don't return a nil-pointer
if v.Kind() == reflect.Ptr && v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return v
}

View File

@ -930,6 +930,17 @@ func TestParserMustParse(t *testing.T) {
}
}
func TestNonPointerSubcommand(t *testing.T) {
var args struct {
Sub struct {
Foo string `arg:"env"`
} `arg:"subcommand"`
}
_, err := NewParser(Config{IgnoreEnv: true}, &args)
require.Error(t, err, "subcommands must be pointers to structs but args.Sub is a struct")
}
type textUnmarshaler struct {
val int
}

View File

@ -1,7 +1,6 @@
package arg
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
@ -402,12 +401,8 @@ func TestValForNilStruct(t *testing.T) {
Sub *subcmd `arg:"subcommand"`
}
p, err := NewParser(Config{}, &cmd)
_, err := NewParser(Config{}, &cmd)
require.NoError(t, err)
typ := reflect.TypeOf(cmd)
subField, _ := typ.FieldByName("Sub")
v := p.val(path{fields: []reflect.StructField{subField, subField}})
assert.False(t, v.IsValid())
require.Nil(t, cmd.Sub)
}