From 4eb137075435b9a50fb7b84e6280f3165a7dd9fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Sch=C3=A4r?= Date: Mon, 3 Mar 2025 12:55:40 +0100 Subject: [PATCH] Split set elements into batches if needed (#303) If the number of elements to be added to or removed from a set is large, they may not all fit into one message, because the size field of a netlink attribute is a uint16 and would overflow. To support this case, the elements need to be split into multiple batches. --- nftables_test.go | 53 +++++++++++++++++- set.go | 136 ++++++++++++++++++++++------------------------- 2 files changed, 115 insertions(+), 74 deletions(-) diff --git a/nftables_test.go b/nftables_test.go index d1603c1..0cd2d59 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -3865,7 +3865,58 @@ func TestIP6SetAddElements(t *testing.T) { t.Errorf("c.GetSetElements(portSet) failed: %v", err) } if len(elements) != 2 { - t.Fatalf("len(portSetElements) = %d, want 2", len(sets)) + t.Fatalf("len(portSetElements) = %d, want 2", len(elements)) + } +} + +func TestSetElementBatching(t *testing.T) { + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + portSet := &nftables.Set{ + Table: filter, + Name: "ports", + KeyType: nftables.TypeInetService, + } + // The 5000 elements will need to be split into 3 batches to make each batch + // fit into a message. + elements := make([]nftables.SetElement, 5000) + for i := range elements { + elements[i].Key = binaryutil.BigEndian.PutUint16(uint16(i)) + elements[i].Comment = "0123456789" + } + if err := c.AddSet(portSet, elements); err != nil { + t.Errorf("c.AddSet(portSet) failed: %v", err) + } + if err := c.Flush(); err != nil { + t.Errorf("c.Flush() failed: %v", err) + } + + gotElements, err := c.GetSetElements(portSet) + if err != nil { + t.Errorf("c.GetSetElements(portSet) failed: %v", err) + } + if len(gotElements) != len(elements) { + t.Errorf("len(gotElements) = %d, want %d", len(gotElements), len(elements)) + } + gotNumbers := make([]bool, len(elements)) + for _, element := range gotElements { + gotNumbers[binaryutil.BigEndian.Uint16(element.Key)] = true + } + for i := range gotNumbers { + if !gotNumbers[i] { + t.Errorf("Missing element %d", i) + break + } } } diff --git a/set.go b/set.go index 446132f..cccdcba 100644 --- a/set.go +++ b/set.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "errors" "fmt" + "math" "strings" "time" @@ -379,24 +380,31 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { if s.Anonymous { return errors.New("anonymous sets cannot be updated") } - - elements, err := s.makeElemList(vals, s.ID) - if err != nil { - return err - } - cc.messages = append(cc.messages, netlink.Message{ - Header: netlink.Header{ - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM), - Flags: netlink.Request | netlink.Acknowledge | netlink.Create, - }, - Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), - }) - - return nil + return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) } -func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, error) { +// SetDeleteElements deletes data points from an nftables set. +func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { + cc.mu.Lock() + defer cc.mu.Unlock() + if s.Anonymous { + return errors.New("anonymous sets cannot be updated") + } + return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM) +} + +// maxElemBatchSize is the maximum size in bytes of encoded set elements which +// are sent in one netlink message. The size field of a netlink attribute is a +// uint16, and 1024 bytes is more than enough for the per-message headers. +const maxElemBatchSize = math.MaxUint16 - 1024 + +func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error { + if len(vals) == 0 { + return nil + } var elements []netlink.Attribute + batchSize := 0 + var batches [][]netlink.Attribute for i, v := range vals { item := make([]netlink.Attribute, 0) @@ -408,14 +416,14 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}}) if err != nil { - return nil, fmt.Errorf("marshal key %d: %v", i, err) + return fmt.Errorf("marshal key %d: %v", i, err) } item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) if len(v.KeyEnd) > 0 { encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}}) if err != nil { - return nil, fmt.Errorf("marshal key end %d: %v", i, err) + return fmt.Errorf("marshal key end %d: %v", i, err) } item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd}) } @@ -435,7 +443,7 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))}, }) if err != nil { - return nil, fmt.Errorf("marshal item %d: %v", i, err) + return fmt.Errorf("marshal item %d: %v", i, err) } encodedVal = append(encodedVal, encodedKind...) if len(v.VerdictData.Chain) != 0 { @@ -443,21 +451,21 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e {Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")}, }) if err != nil { - return nil, fmt.Errorf("marshal item %d: %v", i, err) + return fmt.Errorf("marshal item %d: %v", i, err) } encodedVal = append(encodedVal, encodedChain...) } encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}}) if err != nil { - return nil, fmt.Errorf("marshal item %d: %v", i, err) + return fmt.Errorf("marshal item %d: %v", i, err) } item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict}) case len(v.Val) > 0: // Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}}) if err != nil { - return nil, fmt.Errorf("marshal item %d: %v", i, err) + return fmt.Errorf("marshal item %d: %v", i, err) } item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}) @@ -473,22 +481,42 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e encodedItem, err := netlink.MarshalAttributes(item) if err != nil { - return nil, fmt.Errorf("marshal item %d: %v", i, err) + return fmt.Errorf("marshal item %d: %v", i, err) + } + + itemSize := unix.NLA_HDRLEN + len(encodedItem) + if batchSize+itemSize > maxElemBatchSize { + batches = append(batches, elements) + elements = nil + batchSize = 0 } elements = append(elements, netlink.Attribute{Type: uint16(i+1) | unix.NLA_F_NESTED, Data: encodedItem}) + batchSize += itemSize } + batches = append(batches, elements) - encodedElem, err := netlink.MarshalAttributes(elements) - if err != nil { - return nil, fmt.Errorf("marshal elements: %v", err) + for _, batch := range batches { + encodedElem, err := netlink.MarshalAttributes(batch) + if err != nil { + return fmt.Errorf("marshal elements: %v", err) + } + + message := []netlink.Attribute{ + {Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")}, + {Type: unix.NFTA_SET_ELEM_LIST_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)}, + {Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")}, + {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem}, + } + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(message)...), + }) } - - return []netlink.Attribute{ - {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, - {Type: unix.NFTA_LOOKUP_SET_ID, Data: binaryutil.BigEndian.PutUint32(id)}, - {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, - {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem}, - }, nil + return nil } // AddSet adds the specified Set. @@ -664,22 +692,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { }) // Set the values of the set if initial values were provided. - if len(vals) > 0 { - hdrType := unix.NFT_MSG_NEWSETELEM - elements, err := s.makeElemList(vals, s.ID) - if err != nil { - return err - } - cc.messages = append(cc.messages, netlink.Message{ - Header: netlink.Header{ - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), - Flags: netlink.Request | netlink.Acknowledge | netlink.Create, - }, - Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), - }) - } - - return nil + return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) } // DelSet deletes a specific set, along with all elements it contains. @@ -699,29 +712,6 @@ func (cc *Conn) DelSet(s *Set) { }) } -// SetDeleteElements deletes data points from an nftables set. -func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { - cc.mu.Lock() - defer cc.mu.Unlock() - if s.Anonymous { - return errors.New("anonymous sets cannot be updated") - } - - elements, err := s.makeElemList(vals, s.ID) - if err != nil { - return err - } - cc.messages = append(cc.messages, netlink.Message{ - Header: netlink.Header{ - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), - Flags: netlink.Request | netlink.Acknowledge | netlink.Create, - }, - Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), - }) - - return nil -} - // FlushSet deletes all data points from an nftables set. func (cc *Conn) FlushSet(s *Set) { cc.mu.Lock() @@ -977,8 +967,8 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) { defer func() { _ = closer() }() data, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, - {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, + {Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")}, + {Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")}, }) if err != nil { return nil, err