Allow to use values (not pointers) with TextUnmarshaler

The patch makes sure that both values and pointer to values are checked
for custom TextUnmarshal implementation. This will allow to use go-arg
custom parsing as follows:

var args struct {
  Arg CustomType
}

instead of

var args struct {
  Arg *CustomType
}

Signed-off-by: Pavel Borzenkov <pavel.borzenkov@gmail.com>
This commit is contained in:
Pavel Borzenkov 2018-11-16 17:09:52 +03:00
parent e80c3b7ed2
commit 38f8eb7c6b
2 changed files with 20 additions and 1 deletions

View File

@ -47,6 +47,13 @@ func ParseValue(v reflect.Value, s string) error {
if scalar, ok := v.Interface().(encoding.TextUnmarshaler); ok { if scalar, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return scalar.UnmarshalText([]byte(s)) return scalar.UnmarshalText([]byte(s))
} }
// If it's a value instead of a pointer, check that we can unmarshal it
// via TextUnmarshaler as well
if v.CanAddr() {
if scalar, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok {
return scalar.UnmarshalText([]byte(s))
}
}
// If we have a pointer then dereference it // If we have a pointer then dereference it
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
@ -126,7 +133,7 @@ func ParseValue(v reflect.Value, s string) error {
// CanParse returns true if the type can be parsed from a string. // CanParse returns true if the type can be parsed from a string.
func CanParse(t reflect.Type) bool { func CanParse(t reflect.Type) bool {
// If it implements encoding.TextUnmarshaler then use that // If it implements encoding.TextUnmarshaler then use that
if t.Implements(textUnmarshalerType) { if t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType) {
return true return true
} }

View File

@ -10,6 +10,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type textUnmarshaler struct {
val int
}
func (f *textUnmarshaler) UnmarshalText(b []byte) error {
f.val = len(b)
return nil
}
func assertParse(t *testing.T, expected interface{}, str string) { func assertParse(t *testing.T, expected interface{}, str string) {
v := reflect.New(reflect.TypeOf(expected)).Elem() v := reflect.New(reflect.TypeOf(expected)).Elem()
err := ParseValue(v, str) err := ParseValue(v, str)
@ -67,6 +76,9 @@ func TestParseValue(t *testing.T) {
// MAC addresses // MAC addresses
assertParse(t, net.HardwareAddr("\x01\x23\x45\x67\x89\xab"), "01:23:45:67:89:ab") assertParse(t, net.HardwareAddr("\x01\x23\x45\x67\x89\xab"), "01:23:45:67:89:ab")
// custom text unmarshaler
assertParse(t, textUnmarshaler{3}, "abc")
} }
func TestParse(t *testing.T) { func TestParse(t *testing.T) {