Merge pull request #2 from pborzenkov/text-unmarshaler-value

Allow to use values (not pointers) with TextUnmarshaler
This commit is contained in:
Alex Flint 2018-11-19 12:48:00 -08:00 committed by GitHub
commit 6ab8ad5e1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 10 deletions

View File

@ -18,7 +18,6 @@ var (
textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()
durationType = reflect.TypeOf(time.Duration(0)) durationType = reflect.TypeOf(time.Duration(0))
mailAddressType = reflect.TypeOf(mail.Address{}) mailAddressType = reflect.TypeOf(mail.Address{})
ipType = reflect.TypeOf(net.IP{})
macType = reflect.TypeOf(net.HardwareAddr{}) macType = reflect.TypeOf(net.HardwareAddr{})
) )
@ -47,6 +46,13 @@ func ParseValue(v reflect.Value, s string) error {
if scalar, ok := v.Interface().(encoding.TextUnmarshaler); ok { if scalar, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return scalar.UnmarshalText([]byte(s)) return scalar.UnmarshalText([]byte(s))
} }
// If it's a value instead of a pointer, check that we can unmarshal it
// via TextUnmarshaler as well
if v.CanAddr() {
if scalar, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok {
return scalar.UnmarshalText([]byte(s))
}
}
// If we have a pointer then dereference it // If we have a pointer then dereference it
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
@ -73,13 +79,6 @@ func ParseValue(v reflect.Value, s string) error {
} }
v.Set(reflect.ValueOf(*addr)) v.Set(reflect.ValueOf(*addr))
return nil return nil
case net.IP:
ip := net.ParseIP(s)
if ip == nil {
return fmt.Errorf(`invalid IP address: "%s"`, s)
}
v.Set(reflect.ValueOf(ip))
return nil
case net.HardwareAddr: case net.HardwareAddr:
ip, err := net.ParseMAC(s) ip, err := net.ParseMAC(s)
if err != nil { if err != nil {
@ -126,7 +125,7 @@ func ParseValue(v reflect.Value, s string) error {
// CanParse returns true if the type can be parsed from a string. // CanParse returns true if the type can be parsed from a string.
func CanParse(t reflect.Type) bool { func CanParse(t reflect.Type) bool {
// If it implements encoding.TextUnmarshaler then use that // If it implements encoding.TextUnmarshaler then use that
if t.Implements(textUnmarshalerType) { if t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType) {
return true return true
} }
@ -137,7 +136,7 @@ func CanParse(t reflect.Type) bool {
// Check for other special types // Check for other special types
switch t { switch t {
case durationType, mailAddressType, ipType, macType: case durationType, mailAddressType, macType:
return true return true
} }

View File

@ -10,6 +10,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type textUnmarshaler struct {
val int
}
func (f *textUnmarshaler) UnmarshalText(b []byte) error {
f.val = len(b)
return nil
}
func assertParse(t *testing.T, expected interface{}, str string) { func assertParse(t *testing.T, expected interface{}, str string) {
v := reflect.New(reflect.TypeOf(expected)).Elem() v := reflect.New(reflect.TypeOf(expected)).Elem()
err := ParseValue(v, str) err := ParseValue(v, str)
@ -67,6 +76,9 @@ func TestParseValue(t *testing.T) {
// MAC addresses // MAC addresses
assertParse(t, net.HardwareAddr("\x01\x23\x45\x67\x89\xab"), "01:23:45:67:89:ab") assertParse(t, net.HardwareAddr("\x01\x23\x45\x67\x89\xab"), "01:23:45:67:89:ab")
// custom text unmarshaler
assertParse(t, textUnmarshaler{3}, "abc")
} }
func TestParse(t *testing.T) { func TestParse(t *testing.T) {