Refactor decoding set elements (#47)

This commit is contained in:
Serguei Bezverkhi 2019-08-13 16:19:49 -04:00 committed by Michael Stapelberg
parent 1435f3a62c
commit 579fe47a77
1 changed files with 59 additions and 12 deletions

71
set.go
View File

@ -76,6 +76,56 @@ type SetElement struct {
IntervalEnd bool
}
func (s *SetElement) decode() func(b []byte) error {
return func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return fmt.Errorf("failed to create nested attribute decoder: %v", err)
}
ad.ByteOrder = binary.BigEndian
for ad.Next() {
switch ad.Type() {
case unix.NFTA_SET_ELEM_KEY:
s.Key, err = decodeElement(ad.Bytes())
if err != nil {
return err
}
case unix.NFTA_SET_ELEM_DATA:
s.Val, err = decodeElement(ad.Bytes())
if err != nil {
return err
}
case unix.NFTA_SET_ELEM_FLAGS:
flags := ad.Uint32()
s.IntervalEnd = (flags & unix.NFT_SET_ELEM_INTERVAL_END) != 0
}
}
return ad.Err()
}
}
func decodeElement(d []byte) ([]byte, error) {
ad, err := netlink.NewAttributeDecoder(d)
if err != nil {
return nil, fmt.Errorf("failed to create nested attribute decoder: %v", err)
}
ad.ByteOrder = binary.BigEndian
var b []byte
for ad.Next() {
switch ad.Type() {
case unix.NFTA_SET_ELEM_KEY:
fallthrough
case unix.NFTA_SET_ELEM_DATA:
b = ad.Bytes()
}
}
if err := ad.Err(); err != nil {
return nil, err
}
return b, nil
}
// SetAddElements applies data points to an nftables set.
func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error {
if s.Anonymous {
@ -105,7 +155,7 @@ func (s *Set) makeElemList(vals []SetElement) ([]netlink.Attribute, error) {
var flags uint32
if v.IntervalEnd {
flags |= unix.NFT_SET_ELEM_INTERVAL_END
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)})
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_FLAGS | unix.NLA_F_NESTED, Data: binaryutil.BigEndian.PutUint32(flags)})
}
encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_SET_ELEM_KEY, Data: v.Key}})
@ -282,6 +332,8 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
set.Name = ad.String()
case unix.NFTA_SET_ID:
set.ID = binary.BigEndian.Uint32(ad.Bytes())
case unix.NFTA_SET_DATA_LEN:
set.DataLen = int(ad.Uint32())
case unix.NFTA_SET_FLAGS:
flags := ad.Uint32()
set.Constant = (flags & unix.NFT_SET_CONSTANT) != 0
@ -318,26 +370,21 @@ func elementsFromMsg(msg netlink.Message) ([]SetElement, error) {
var elements []SetElement
for ad.Next() {
b := ad.Bytes()
if ad.Type() == unix.NFTA_SET_ELEM_LIST_ELEMENTS && len(b) > 8 {
ad, err := netlink.NewAttributeDecoder(b[8:])
if ad.Type() == unix.NFTA_SET_ELEM_LIST_ELEMENTS {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return nil, err
}
ad.ByteOrder = binary.BigEndian
var elem SetElement
for ad.Next() {
var elem SetElement
switch ad.Type() {
case unix.NFTA_SET_ELEM_KEY:
elem.Key = ad.Bytes()
case unix.NFTA_SET_ELEM_DATA:
elem.Val = ad.Bytes()
case unix.NFTA_SET_ELEM_FLAGS:
flags := ad.Uint32()
elem.IntervalEnd = (flags & unix.NFT_SET_ELEM_INTERVAL_END) != 0
case unix.NFTA_LIST_ELEM:
ad.Do(elem.decode())
}
elements = append(elements, elem)
}
elements = append(elements, elem)
}
}
return elements, nil