diff --git a/parse.go b/parse.go index 00e79b7..0175977 100644 --- a/parse.go +++ b/parse.go @@ -216,27 +216,19 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { // is the reason that this method for setting default values was deprecated) for _, spec := range p.cmd.specs { // get the value - v := p.val(spec.dest) + defaultString, defaultValue, err := p.defaultVal(spec.dest) + if err != nil { + return nil, err + } // if the value is the "zero value" (e.g. nil pointer, empty struct) then ignore - if isZero(v) { + if defaultString == "" { continue } // store as a default - spec.defaultValue = v - - // we need a string to display in help text - // if MarshalText is implemented then use that - if m, ok := v.Interface().(encoding.TextMarshaler); ok { - s, err := m.MarshalText() - if err != nil { - return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err) - } - spec.defaultString = string(s) - } else { - spec.defaultString = fmt.Sprintf("%v", v) - } + spec.defaultString = defaultString + spec.defaultValue = defaultValue } if dest, ok := dest.(Versioned); ok { @@ -575,11 +567,8 @@ func (p *Parser) process(args []string) error { return fmt.Errorf("invalid subcommand: %s", arg) } - // instantiate the field to point to a new struct - v := p.val(subcmd.dest) - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers - } + // ensure the command struct exists (is not a nil pointer) + p.val(subcmd.dest) // add the new options to the set of allowed options specs = append(specs, subcmd.specs...) @@ -743,20 +732,57 @@ func isFlag(s string) bool { return strings.HasPrefix(s, "-") && strings.TrimLeft(s, "-") != "" } -// val returns a reflect.Value corresponding to the current value for the -// given path -func (p *Parser) val(dest path) reflect.Value { +// defaultVal returns the string representation of the value at dest if it is +// reachable without traversing nil pointers, but only if it does not represent +// the default value for the type. +func (p *Parser) defaultVal(dest path) (string, reflect.Value, error) { v := p.roots[dest.root] for _, field := range dest.fields { if v.Kind() == reflect.Ptr { if v.IsNil() { - return reflect.Value{} + return "", v, nil } v = v.Elem() } v = v.FieldByIndex(field.Index) } + + if !v.IsValid() || isZero(v) { + return "", v, nil + } + + if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok { + str, err := defaultVal.MarshalText() + if err != nil { + return "", v, fmt.Errorf("%v: error marshaling default value to string: %w", dest, err) + } + return string(str), v, nil + } + + return fmt.Sprintf("%v", v), v, nil +} + +// val returns a reflect.Value corresponding to the current value for the +// given path initiating nil pointers in the path +func (p *Parser) val(dest path) reflect.Value { + v := p.roots[dest.root] + for _, field := range dest.fields { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + + v = v.FieldByIndex(field.Index) + } + + // Don't return a nil-pointer + if v.Kind() == reflect.Ptr && v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + return v } diff --git a/parse_test.go b/parse_test.go index 3029760..2fa3d92 100644 --- a/parse_test.go +++ b/parse_test.go @@ -930,6 +930,17 @@ func TestParserMustParse(t *testing.T) { } } +func TestNonPointerSubcommand(t *testing.T) { + var args struct { + Sub struct { + Foo string `arg:"env"` + } `arg:"subcommand"` + } + + _, err := NewParser(Config{IgnoreEnv: true}, &args) + require.Error(t, err, "subcommands must be pointers to structs but args.Sub is a struct") +} + type textUnmarshaler struct { val int } diff --git a/subcommand_test.go b/subcommand_test.go index 2c61dd3..c3909e6 100644 --- a/subcommand_test.go +++ b/subcommand_test.go @@ -1,7 +1,6 @@ package arg import ( - "reflect" "testing" "github.com/stretchr/testify/assert" @@ -402,12 +401,8 @@ func TestValForNilStruct(t *testing.T) { Sub *subcmd `arg:"subcommand"` } - p, err := NewParser(Config{}, &cmd) + _, err := NewParser(Config{}, &cmd) require.NoError(t, err) - typ := reflect.TypeOf(cmd) - subField, _ := typ.FieldByName("Sub") - - v := p.val(path{fields: []reflect.StructField{subField, subField}}) - assert.False(t, v.IsValid()) + require.Nil(t, cmd.Sub) }