fix: ensure Parser.val cannot be nil

Used when adding values to subcommands as
these are undefined pointers initially.

This refactoring is in preparation of
parsing group structs later.
This commit is contained in:
Sebastiaan Pasterkamp 2022-12-10 14:59:23 +01:00
parent 00c1c8e7cd
commit ca8dc31b84
3 changed files with 63 additions and 31 deletions

View File

@ -216,27 +216,19 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
// is the reason that this method for setting default values was deprecated) // is the reason that this method for setting default values was deprecated)
for _, spec := range p.cmd.specs { for _, spec := range p.cmd.specs {
// get the value // get the value
v := p.val(spec.dest) defaultString, defaultValue, err := p.defaultVal(spec.dest)
if err != nil {
return nil, err
}
// if the value is the "zero value" (e.g. nil pointer, empty struct) then ignore // if the value is the "zero value" (e.g. nil pointer, empty struct) then ignore
if isZero(v) { if defaultString == "" {
continue continue
} }
// store as a default // store as a default
spec.defaultValue = v spec.defaultString = defaultString
spec.defaultValue = defaultValue
// we need a string to display in help text
// if MarshalText is implemented then use that
if m, ok := v.Interface().(encoding.TextMarshaler); ok {
s, err := m.MarshalText()
if err != nil {
return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
}
spec.defaultString = string(s)
} else {
spec.defaultString = fmt.Sprintf("%v", v)
}
} }
if dest, ok := dest.(Versioned); ok { if dest, ok := dest.(Versioned); ok {
@ -575,11 +567,8 @@ func (p *Parser) process(args []string) error {
return fmt.Errorf("invalid subcommand: %s", arg) return fmt.Errorf("invalid subcommand: %s", arg)
} }
// instantiate the field to point to a new struct // ensure the command struct exists (is not a nil pointer)
v := p.val(subcmd.dest) p.val(subcmd.dest)
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers
}
// add the new options to the set of allowed options // add the new options to the set of allowed options
specs = append(specs, subcmd.specs...) specs = append(specs, subcmd.specs...)
@ -743,20 +732,57 @@ func isFlag(s string) bool {
return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != "" return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != ""
} }
// val returns a reflect.Value corresponding to the current value for the // defaultVal returns the string representation of the value at dest if it is
// given path // reachable without traversing nil pointers, but only if it does not represent
func (p *Parser) val(dest path) reflect.Value { // the default value for the type.
func (p *Parser) defaultVal(dest path) (string, reflect.Value, error) {
v := p.roots[dest.root] v := p.roots[dest.root]
for _, field := range dest.fields { for _, field := range dest.fields {
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
if v.IsNil() { if v.IsNil() {
return reflect.Value{} return "", v, nil
} }
v = v.Elem() v = v.Elem()
} }
v = v.FieldByIndex(field.Index) v = v.FieldByIndex(field.Index)
} }
if !v.IsValid() || isZero(v) {
return "", v, nil
}
if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok {
str, err := defaultVal.MarshalText()
if err != nil {
return "", v, fmt.Errorf("%v: error marshaling default value to string: %w", dest, err)
}
return string(str), v, nil
}
return fmt.Sprintf("%v", v), v, nil
}
// val returns a reflect.Value corresponding to the current value for the
// given path initiating nil pointers in the path
func (p *Parser) val(dest path) reflect.Value {
v := p.roots[dest.root]
for _, field := range dest.fields {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
v = v.FieldByIndex(field.Index)
}
// Don't return a nil-pointer
if v.Kind() == reflect.Ptr && v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return v return v
} }

View File

@ -930,6 +930,17 @@ func TestParserMustParse(t *testing.T) {
} }
} }
func TestNonPointerSubcommand(t *testing.T) {
var args struct {
Sub struct {
Foo string `arg:"env"`
} `arg:"subcommand"`
}
_, err := NewParser(Config{IgnoreEnv: true}, &args)
require.Error(t, err, "subcommands must be pointers to structs but args.Sub is a struct")
}
type textUnmarshaler struct { type textUnmarshaler struct {
val int val int
} }

View File

@ -1,7 +1,6 @@
package arg package arg
import ( import (
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -402,12 +401,8 @@ func TestValForNilStruct(t *testing.T) {
Sub *subcmd `arg:"subcommand"` Sub *subcmd `arg:"subcommand"`
} }
p, err := NewParser(Config{}, &cmd) _, err := NewParser(Config{}, &cmd)
require.NoError(t, err) require.NoError(t, err)
typ := reflect.TypeOf(cmd) require.Nil(t, cmd.Sub)
subField, _ := typ.FieldByName("Sub")
v := p.val(path{fields: []reflect.StructField{subField, subField}})
assert.False(t, v.IsValid())
} }