diff --git a/set.go b/set.go index bb207d4..e16406f 100644 --- a/set.go +++ b/set.go @@ -76,6 +76,56 @@ type SetElement struct { IntervalEnd bool } +func (s *SetElement) decode() func(b []byte) error { + return func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return fmt.Errorf("failed to create nested attribute decoder: %v", err) + } + ad.ByteOrder = binary.BigEndian + + for ad.Next() { + switch ad.Type() { + case unix.NFTA_SET_ELEM_KEY: + s.Key, err = decodeElement(ad.Bytes()) + if err != nil { + return err + } + case unix.NFTA_SET_ELEM_DATA: + s.Val, err = decodeElement(ad.Bytes()) + if err != nil { + return err + } + case unix.NFTA_SET_ELEM_FLAGS: + flags := ad.Uint32() + s.IntervalEnd = (flags & unix.NFT_SET_ELEM_INTERVAL_END) != 0 + } + } + return ad.Err() + } +} + +func decodeElement(d []byte) ([]byte, error) { + ad, err := netlink.NewAttributeDecoder(d) + if err != nil { + return nil, fmt.Errorf("failed to create nested attribute decoder: %v", err) + } + ad.ByteOrder = binary.BigEndian + var b []byte + for ad.Next() { + switch ad.Type() { + case unix.NFTA_SET_ELEM_KEY: + fallthrough + case unix.NFTA_SET_ELEM_DATA: + b = ad.Bytes() + } + } + if err := ad.Err(); err != nil { + return nil, err + } + return b, nil +} + // SetAddElements applies data points to an nftables set. func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { if s.Anonymous { @@ -105,7 +155,7 @@ func (s *Set) makeElemList(vals []SetElement) ([]netlink.Attribute, error) { var flags uint32 if v.IntervalEnd { flags |= unix.NFT_SET_ELEM_INTERVAL_END - item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}) + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_FLAGS | unix.NLA_F_NESTED, Data: binaryutil.BigEndian.PutUint32(flags)}) } encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_SET_ELEM_KEY, Data: v.Key}}) @@ -282,6 +332,8 @@ func setsFromMsg(msg netlink.Message) (*Set, error) { set.Name = ad.String() case unix.NFTA_SET_ID: set.ID = binary.BigEndian.Uint32(ad.Bytes()) + case unix.NFTA_SET_DATA_LEN: + set.DataLen = int(ad.Uint32()) case unix.NFTA_SET_FLAGS: flags := ad.Uint32() set.Constant = (flags & unix.NFT_SET_CONSTANT) != 0 @@ -318,26 +370,21 @@ func elementsFromMsg(msg netlink.Message) ([]SetElement, error) { var elements []SetElement for ad.Next() { b := ad.Bytes() - if ad.Type() == unix.NFTA_SET_ELEM_LIST_ELEMENTS && len(b) > 8 { - ad, err := netlink.NewAttributeDecoder(b[8:]) + if ad.Type() == unix.NFTA_SET_ELEM_LIST_ELEMENTS { + ad, err := netlink.NewAttributeDecoder(b) if err != nil { return nil, err } ad.ByteOrder = binary.BigEndian - var elem SetElement for ad.Next() { + var elem SetElement switch ad.Type() { - case unix.NFTA_SET_ELEM_KEY: - elem.Key = ad.Bytes() - case unix.NFTA_SET_ELEM_DATA: - elem.Val = ad.Bytes() - case unix.NFTA_SET_ELEM_FLAGS: - flags := ad.Uint32() - elem.IntervalEnd = (flags & unix.NFT_SET_ELEM_INTERVAL_END) != 0 + case unix.NFTA_LIST_ELEM: + ad.Do(elem.decode()) } + elements = append(elements, elem) } - elements = append(elements, elem) } } return elements, nil