Compare commits

..

4 Commits

Author SHA1 Message Date
Mikhail Sennikovsky 6bf1193815
Merge a0423c9897 into 4eb1370754 2025-03-11 02:15:38 +08:00
Jan Schär 4eb1370754
Split set elements into batches if needed (#303)
If the number of elements to be added to or removed from a set is large,
they may not all fit into one message, because the size field of a
netlink attribute is a uint16 and would overflow. To support this case,
the elements need to be split into multiple batches.
2025-03-03 12:55:40 +01:00
Jan Schär 385f80f4ef Use const instead of var where possible 2025-02-26 15:11:55 +01:00
Jan Schär 594585af33 Initialize registers in test
Recent kernels disallow reads from uninitialized registers, which breaks
this test.

See 14fb07130c
2025-02-26 15:11:55 +01:00
6 changed files with 135 additions and 83 deletions

3
gen.go
View File

@ -3,6 +3,7 @@ package nftables
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"github.com/mdlayher/netlink" "github.com/mdlayher/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -13,7 +14,7 @@ type GenMsg struct {
ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN
} }
var genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN) const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
func genFromMsg(msg netlink.Message) (*GenMsg, error) { func genFromMsg(msg netlink.Message) (*GenMsg, error) {
if got, want := msg.Header.Type, genHeaderType; got != want { if got, want := msg.Header.Type, genHeaderType; got != want {

View File

@ -622,6 +622,14 @@ func TestMasqMarshalUnmarshal(t *testing.T) {
Table: filter, Table: filter,
Chain: postrouting, Chain: postrouting,
Exprs: []expr.Any{ Exprs: []expr.Any{
&expr.Immediate{
Register: min,
Data: binaryutil.BigEndian.PutUint16(4070),
},
&expr.Immediate{
Register: max,
Data: binaryutil.BigEndian.PutUint16(4090),
},
&expr.Masq{ &expr.Masq{
ToPorts: true, ToPorts: true,
RegProtoMin: min, RegProtoMin: min,
@ -652,13 +660,13 @@ func TestMasqMarshalUnmarshal(t *testing.T) {
} }
rule := rules[0] rule := rules[0]
if got, want := len(rule.Exprs), 1; got != want { if got, want := len(rule.Exprs), 3; got != want {
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want) t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
} }
me, ok := rule.Exprs[0].(*expr.Masq) me, ok := rule.Exprs[2].(*expr.Masq)
if !ok { if !ok {
t.Fatalf("unexpected expression type: got %T, want *expr.Masq", rule.Exprs[0]) t.Fatalf("unexpected expression type: got %T, want *expr.Masq", rule.Exprs[2])
} }
if got, want := me.ToPorts, true; got != want { if got, want := me.ToPorts, true; got != want {
@ -3857,7 +3865,58 @@ func TestIP6SetAddElements(t *testing.T) {
t.Errorf("c.GetSetElements(portSet) failed: %v", err) t.Errorf("c.GetSetElements(portSet) failed: %v", err)
} }
if len(elements) != 2 { if len(elements) != 2 {
t.Fatalf("len(portSetElements) = %d, want 2", len(sets)) t.Fatalf("len(portSetElements) = %d, want 2", len(elements))
}
}
func TestSetElementBatching(t *testing.T) {
// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
// Clear all rules at the beginning + end of the test.
c.FlushRuleset()
defer c.FlushRuleset()
filter := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
})
portSet := &nftables.Set{
Table: filter,
Name: "ports",
KeyType: nftables.TypeInetService,
}
// The 5000 elements will need to be split into 3 batches to make each batch
// fit into a message.
elements := make([]nftables.SetElement, 5000)
for i := range elements {
elements[i].Key = binaryutil.BigEndian.PutUint16(uint16(i))
elements[i].Comment = "0123456789"
}
if err := c.AddSet(portSet, elements); err != nil {
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err)
}
gotElements, err := c.GetSetElements(portSet)
if err != nil {
t.Errorf("c.GetSetElements(portSet) failed: %v", err)
}
if len(gotElements) != len(elements) {
t.Errorf("len(gotElements) = %d, want %d", len(gotElements), len(elements))
}
gotNumbers := make([]bool, len(elements))
for _, element := range gotElements {
gotNumbers[binaryutil.BigEndian.Uint16(element.Key)] = true
}
for i := range gotNumbers {
if !gotNumbers[i] {
t.Errorf("Missing element %d", i)
break
}
} }
} }

2
obj.go
View File

@ -25,7 +25,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
var ( const (
newObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ) newObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ)
delObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ) delObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ)
) )

View File

@ -25,7 +25,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
var ( const (
newRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE) newRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE)
delRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE) delRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE)
) )

142
set.go
View File

@ -18,6 +18,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"math"
"strings" "strings"
"time" "time"
@ -166,7 +167,9 @@ var (
TypeTimeDay, TypeTimeDay,
TypeCGroupV2, TypeCGroupV2,
} }
)
const (
// ctLabelBitSize is defined in https://git.netfilter.org/nftables/tree/src/ct.c. // ctLabelBitSize is defined in https://git.netfilter.org/nftables/tree/src/ct.c.
ctLabelBitSize uint32 = 128 ctLabelBitSize uint32 = 128
@ -377,24 +380,31 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error {
if s.Anonymous { if s.Anonymous {
return errors.New("anonymous sets cannot be updated") return errors.New("anonymous sets cannot be updated")
} }
return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
elements, err := s.makeElemList(vals, s.ID)
if err != nil {
return err
}
cc.messages = append(cc.messages, netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
},
Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...),
})
return nil
} }
func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, error) { // SetDeleteElements deletes data points from an nftables set.
func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
cc.mu.Lock()
defer cc.mu.Unlock()
if s.Anonymous {
return errors.New("anonymous sets cannot be updated")
}
return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
}
// maxElemBatchSize is the maximum size in bytes of encoded set elements which
// are sent in one netlink message. The size field of a netlink attribute is a
// uint16, and 1024 bytes is more than enough for the per-message headers.
const maxElemBatchSize = math.MaxUint16 - 1024
func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error {
if len(vals) == 0 {
return nil
}
var elements []netlink.Attribute var elements []netlink.Attribute
batchSize := 0
var batches [][]netlink.Attribute
for i, v := range vals { for i, v := range vals {
item := make([]netlink.Attribute, 0) item := make([]netlink.Attribute, 0)
@ -406,14 +416,14 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}}) encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}})
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal key %d: %v", i, err) return fmt.Errorf("marshal key %d: %v", i, err)
} }
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey})
if len(v.KeyEnd) > 0 { if len(v.KeyEnd) > 0 {
encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}}) encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}})
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal key end %d: %v", i, err) return fmt.Errorf("marshal key end %d: %v", i, err)
} }
item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd}) item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd})
} }
@ -433,7 +443,7 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))}, {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))},
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal item %d: %v", i, err) return fmt.Errorf("marshal item %d: %v", i, err)
} }
encodedVal = append(encodedVal, encodedKind...) encodedVal = append(encodedVal, encodedKind...)
if len(v.VerdictData.Chain) != 0 { if len(v.VerdictData.Chain) != 0 {
@ -441,21 +451,21 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
{Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")}, {Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")},
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal item %d: %v", i, err) return fmt.Errorf("marshal item %d: %v", i, err)
} }
encodedVal = append(encodedVal, encodedChain...) encodedVal = append(encodedVal, encodedChain...)
} }
encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{ encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}}) {Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}})
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal item %d: %v", i, err) return fmt.Errorf("marshal item %d: %v", i, err)
} }
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict}) item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict})
case len(v.Val) > 0: case len(v.Val) > 0:
// Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes // Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes
encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}}) encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}})
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal item %d: %v", i, err) return fmt.Errorf("marshal item %d: %v", i, err)
} }
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}) item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal})
@ -471,22 +481,42 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
encodedItem, err := netlink.MarshalAttributes(item) encodedItem, err := netlink.MarshalAttributes(item)
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal item %d: %v", i, err) return fmt.Errorf("marshal item %d: %v", i, err)
}
itemSize := unix.NLA_HDRLEN + len(encodedItem)
if batchSize+itemSize > maxElemBatchSize {
batches = append(batches, elements)
elements = nil
batchSize = 0
} }
elements = append(elements, netlink.Attribute{Type: uint16(i+1) | unix.NLA_F_NESTED, Data: encodedItem}) elements = append(elements, netlink.Attribute{Type: uint16(i+1) | unix.NLA_F_NESTED, Data: encodedItem})
batchSize += itemSize
} }
batches = append(batches, elements)
encodedElem, err := netlink.MarshalAttributes(elements) for _, batch := range batches {
if err != nil { encodedElem, err := netlink.MarshalAttributes(batch)
return nil, fmt.Errorf("marshal elements: %v", err) if err != nil {
return fmt.Errorf("marshal elements: %v", err)
}
message := []netlink.Attribute{
{Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")},
{Type: unix.NFTA_SET_ELEM_LIST_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)},
{Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem},
}
cc.messages = append(cc.messages, netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
},
Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(message)...),
})
} }
return nil
return []netlink.Attribute{
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
{Type: unix.NFTA_LOOKUP_SET_ID, Data: binaryutil.BigEndian.PutUint32(id)},
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem},
}, nil
} }
// AddSet adds the specified Set. // AddSet adds the specified Set.
@ -662,22 +692,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
}) })
// Set the values of the set if initial values were provided. // Set the values of the set if initial values were provided.
if len(vals) > 0 { return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
hdrType := unix.NFT_MSG_NEWSETELEM
elements, err := s.makeElemList(vals, s.ID)
if err != nil {
return err
}
cc.messages = append(cc.messages, netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
},
Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...),
})
}
return nil
} }
// DelSet deletes a specific set, along with all elements it contains. // DelSet deletes a specific set, along with all elements it contains.
@ -697,29 +712,6 @@ func (cc *Conn) DelSet(s *Set) {
}) })
} }
// SetDeleteElements deletes data points from an nftables set.
func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
cc.mu.Lock()
defer cc.mu.Unlock()
if s.Anonymous {
return errors.New("anonymous sets cannot be updated")
}
elements, err := s.makeElemList(vals, s.ID)
if err != nil {
return err
}
cc.messages = append(cc.messages, netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
},
Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...),
})
return nil
}
// FlushSet deletes all data points from an nftables set. // FlushSet deletes all data points from an nftables set.
func (cc *Conn) FlushSet(s *Set) { func (cc *Conn) FlushSet(s *Set) {
cc.mu.Lock() cc.mu.Lock()
@ -737,7 +729,7 @@ func (cc *Conn) FlushSet(s *Set) {
}) })
} }
var ( const (
newSetHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET) newSetHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET)
delSetHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET) delSetHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET)
) )
@ -837,7 +829,7 @@ func parseSetDatatype(magic uint32) (SetDatatype, error) {
return dt, nil return dt, nil
} }
var ( const (
newElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM) newElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM)
delElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM) delElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM)
) )
@ -975,8 +967,8 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) {
defer func() { _ = closer() }() defer func() { _ = closer() }()
data, err := netlink.MarshalAttributes([]netlink.Attribute{ data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, {Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")},
}) })
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -21,7 +21,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
var ( const (
newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE) newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE)
delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE) delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE)
) )