Compare commits

...

5 Commits

Author SHA1 Message Date
Mikhail Sennikovsky 6bf1193815
Merge a0423c9897 into 4eb1370754 2025-03-11 02:15:38 +08:00
Jan Schär 4eb1370754
Split set elements into batches if needed (#303)
If the number of elements to be added to or removed from a set is large,
they may not all fit into one message, because the size field of a
netlink attribute is a uint16 and would overflow. To support this case,
the elements need to be split into multiple batches.
2025-03-03 12:55:40 +01:00
Jan Schär 385f80f4ef Use const instead of var where possible 2025-02-26 15:11:55 +01:00
Jan Schär 594585af33 Initialize registers in test
Recent kernels disallow reads from uninitialized registers, which breaks
this test.

See 14fb07130c
2025-02-26 15:11:55 +01:00
Mikhail Sennikovsky a0423c9897 Fix set verdict data type unmarshalling
Currently, unmarshalling a set with the "verdict" data type results in
the "verdict" type being set as the key type, while the data type
remains zero.

Set the verdict type in the Set's DataType field instead of in
the KeyType field.

Signed-off-by: Mikhail Sennikovsky <mikhail.sennikovskii@ionos.com>
2024-11-18 16:23:28 +01:00
6 changed files with 136 additions and 84 deletions

3
gen.go
View File

@ -3,6 +3,7 @@ package nftables
import (
"encoding/binary"
"fmt"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
)
@ -13,7 +14,7 @@ type GenMsg struct {
ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN
}
var genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
func genFromMsg(msg netlink.Message) (*GenMsg, error) {
if got, want := msg.Header.Type, genHeaderType; got != want {

View File

@ -622,6 +622,14 @@ func TestMasqMarshalUnmarshal(t *testing.T) {
Table: filter,
Chain: postrouting,
Exprs: []expr.Any{
&expr.Immediate{
Register: min,
Data: binaryutil.BigEndian.PutUint16(4070),
},
&expr.Immediate{
Register: max,
Data: binaryutil.BigEndian.PutUint16(4090),
},
&expr.Masq{
ToPorts: true,
RegProtoMin: min,
@ -652,13 +660,13 @@ func TestMasqMarshalUnmarshal(t *testing.T) {
}
rule := rules[0]
if got, want := len(rule.Exprs), 1; got != want {
if got, want := len(rule.Exprs), 3; got != want {
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
}
me, ok := rule.Exprs[0].(*expr.Masq)
me, ok := rule.Exprs[2].(*expr.Masq)
if !ok {
t.Fatalf("unexpected expression type: got %T, want *expr.Masq", rule.Exprs[0])
t.Fatalf("unexpected expression type: got %T, want *expr.Masq", rule.Exprs[2])
}
if got, want := me.ToPorts, true; got != want {
@ -3857,7 +3865,58 @@ func TestIP6SetAddElements(t *testing.T) {
t.Errorf("c.GetSetElements(portSet) failed: %v", err)
}
if len(elements) != 2 {
t.Fatalf("len(portSetElements) = %d, want 2", len(sets))
t.Fatalf("len(portSetElements) = %d, want 2", len(elements))
}
}
// TestSetElementBatching creates a set with more initial elements than fit
// into a single netlink message and verifies that every element can be read
// back afterwards.
func TestSetElementBatching(t *testing.T) {
	// Create a new network namespace to test these operations,
	// and tear down the namespace at test completion.
	c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
	defer nftest.CleanupSystemConn(t, newNS)
	// Clear all rules at the beginning + end of the test.
	c.FlushRuleset()
	defer c.FlushRuleset()

	filter := c.AddTable(&nftables.Table{
		Family: nftables.TableFamilyIPv4,
		Name:   "filter",
	})

	portSet := &nftables.Set{
		Table:   filter,
		Name:    "ports",
		KeyType: nftables.TypeInetService,
	}

	// The 5000 elements will need to be split into 3 batches to make each batch
	// fit into a message.
	elements := make([]nftables.SetElement, 5000)
	for i := range elements {
		elements[i].Key = binaryutil.BigEndian.PutUint16(uint16(i))
		elements[i].Comment = "0123456789"
	}

	// Everything below depends on the set having been created and flushed, so
	// a failure in any of these steps aborts the test instead of producing a
	// cascade of follow-up errors.
	if err := c.AddSet(portSet, elements); err != nil {
		t.Fatalf("c.AddSet(portSet) failed: %v", err)
	}
	if err := c.Flush(); err != nil {
		t.Fatalf("c.Flush() failed: %v", err)
	}

	gotElements, err := c.GetSetElements(portSet)
	if err != nil {
		t.Fatalf("c.GetSetElements(portSet) failed: %v", err)
	}
	if len(gotElements) != len(elements) {
		t.Errorf("len(gotElements) = %d, want %d", len(gotElements), len(elements))
	}

	// Mark every key that came back. The keys are decoded from returned data,
	// so validate them before using them as an index; an unexpected key must
	// fail the test rather than panic it.
	gotNumbers := make([]bool, len(elements))
	for _, element := range gotElements {
		if len(element.Key) < 2 {
			t.Errorf("element key %x is too short, want at least 2 bytes", element.Key)
			continue
		}
		n := int(binaryutil.BigEndian.Uint16(element.Key))
		if n >= len(gotNumbers) {
			t.Errorf("unexpected element %d", n)
			continue
		}
		gotNumbers[n] = true
	}
	for i := range gotNumbers {
		if !gotNumbers[i] {
			// Report only the first gap; a lost batch would otherwise
			// produce thousands of identical messages.
			t.Errorf("Missing element %d", i)
			break
		}
	}
}

2
obj.go
View File

@ -25,7 +25,7 @@ import (
"golang.org/x/sys/unix"
)
var (
const (
newObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ)
delObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ)
)

View File

@ -25,7 +25,7 @@ import (
"golang.org/x/sys/unix"
)
var (
const (
newRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE)
delRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE)
)

144
set.go
View File

@ -18,6 +18,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"math"
"strings"
"time"
@ -166,7 +167,9 @@ var (
TypeTimeDay,
TypeCGroupV2,
}
)
const (
// ctLabelBitSize is defined in https://git.netfilter.org/nftables/tree/src/ct.c.
ctLabelBitSize uint32 = 128
@ -377,24 +380,31 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error {
if s.Anonymous {
return errors.New("anonymous sets cannot be updated")
}
elements, err := s.makeElemList(vals, s.ID)
if err != nil {
return err
}
cc.messages = append(cc.messages, netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
},
Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...),
})
return nil
return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
}
func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, error) {
// SetDeleteElements deletes data points from an nftables set.
//
// It holds cc.mu for the duration of the call. Anonymous sets are rejected
// with an error; for every other set the elements are handed to
// appendElemList with the NFT_MSG_DELSETELEM message type.
func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
	cc.mu.Lock()
	defer cc.mu.Unlock()

	if !s.Anonymous {
		return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
	}
	return errors.New("anonymous sets cannot be updated")
}
// maxElemBatchSize is the maximum size in bytes of encoded set elements which
// are sent in one netlink message. The size field of a netlink attribute is a
// uint16, so the limit is derived from math.MaxUint16; 1024 bytes are held
// back, which is more than enough for the per-message headers.
const maxElemBatchSize = math.MaxUint16 - 1024
func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error {
if len(vals) == 0 {
return nil
}
var elements []netlink.Attribute
batchSize := 0
var batches [][]netlink.Attribute
for i, v := range vals {
item := make([]netlink.Attribute, 0)
@ -406,14 +416,14 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}})
if err != nil {
return nil, fmt.Errorf("marshal key %d: %v", i, err)
return fmt.Errorf("marshal key %d: %v", i, err)
}
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey})
if len(v.KeyEnd) > 0 {
encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}})
if err != nil {
return nil, fmt.Errorf("marshal key end %d: %v", i, err)
return fmt.Errorf("marshal key end %d: %v", i, err)
}
item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd})
}
@ -433,7 +443,7 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))},
})
if err != nil {
return nil, fmt.Errorf("marshal item %d: %v", i, err)
return fmt.Errorf("marshal item %d: %v", i, err)
}
encodedVal = append(encodedVal, encodedKind...)
if len(v.VerdictData.Chain) != 0 {
@ -441,21 +451,21 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
{Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")},
})
if err != nil {
return nil, fmt.Errorf("marshal item %d: %v", i, err)
return fmt.Errorf("marshal item %d: %v", i, err)
}
encodedVal = append(encodedVal, encodedChain...)
}
encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}})
if err != nil {
return nil, fmt.Errorf("marshal item %d: %v", i, err)
return fmt.Errorf("marshal item %d: %v", i, err)
}
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict})
case len(v.Val) > 0:
// Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes
encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}})
if err != nil {
return nil, fmt.Errorf("marshal item %d: %v", i, err)
return fmt.Errorf("marshal item %d: %v", i, err)
}
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal})
@ -471,22 +481,42 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
encodedItem, err := netlink.MarshalAttributes(item)
if err != nil {
return nil, fmt.Errorf("marshal item %d: %v", i, err)
return fmt.Errorf("marshal item %d: %v", i, err)
}
itemSize := unix.NLA_HDRLEN + len(encodedItem)
if batchSize+itemSize > maxElemBatchSize {
batches = append(batches, elements)
elements = nil
batchSize = 0
}
elements = append(elements, netlink.Attribute{Type: uint16(i+1) | unix.NLA_F_NESTED, Data: encodedItem})
batchSize += itemSize
}
batches = append(batches, elements)
encodedElem, err := netlink.MarshalAttributes(elements)
if err != nil {
return nil, fmt.Errorf("marshal elements: %v", err)
for _, batch := range batches {
encodedElem, err := netlink.MarshalAttributes(batch)
if err != nil {
return fmt.Errorf("marshal elements: %v", err)
}
message := []netlink.Attribute{
{Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")},
{Type: unix.NFTA_SET_ELEM_LIST_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)},
{Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem},
}
cc.messages = append(cc.messages, netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
},
Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(message)...),
})
}
return []netlink.Attribute{
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
{Type: unix.NFTA_LOOKUP_SET_ID, Data: binaryutil.BigEndian.PutUint32(id)},
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem},
}, nil
return nil
}
// AddSet adds the specified Set.
@ -662,22 +692,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
})
// Set the values of the set if initial values were provided.
if len(vals) > 0 {
hdrType := unix.NFT_MSG_NEWSETELEM
elements, err := s.makeElemList(vals, s.ID)
if err != nil {
return err
}
cc.messages = append(cc.messages, netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
},
Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...),
})
}
return nil
return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
}
// DelSet deletes a specific set, along with all elements it contains.
@ -697,29 +712,6 @@ func (cc *Conn) DelSet(s *Set) {
})
}
// SetDeleteElements deletes data points from an nftables set.
//
// It holds cc.mu for the duration of the call. Anonymous sets are rejected
// with an error. Otherwise the elements are encoded with makeElemList and a
// single NFT_MSG_DELSETELEM message carrying them is appended to cc.messages.
func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
	cc.mu.Lock()
	defer cc.mu.Unlock()

	if s.Anonymous {
		return errors.New("anonymous sets cannot be updated")
	}

	attrs, err := s.makeElemList(vals, s.ID)
	if err != nil {
		return err
	}

	header := netlink.Header{
		Type:  netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM),
		Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
	}
	// The payload is the nfgenmsg-style extra header for the table's family
	// followed by the marshalled element list attributes.
	payload := append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(attrs)...)

	cc.messages = append(cc.messages, netlink.Message{Header: header, Data: payload})
	return nil
}
// FlushSet deletes all data points from an nftables set.
func (cc *Conn) FlushSet(s *Set) {
cc.mu.Lock()
@ -737,7 +729,7 @@ func (cc *Conn) FlushSet(s *Set) {
})
}
var (
const (
newSetHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET)
delSetHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET)
)
@ -784,7 +776,7 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
nftMagic := ad.Uint32()
// Special case for the data type verdict, in the message it is stored as 0xffffff00 but it is defined as 1
if nftMagic == 0xffffff00 {
set.KeyType = TypeVerdict
set.DataType = TypeVerdict
break
}
dt, err := parseSetDatatype(nftMagic)
@ -837,7 +829,7 @@ func parseSetDatatype(magic uint32) (SetDatatype, error) {
return dt, nil
}
var (
const (
newElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM)
delElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM)
)
@ -975,8 +967,8 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) {
defer func() { _ = closer() }()
data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
{Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")},
})
if err != nil {
return nil, err

View File

@ -21,7 +21,7 @@ import (
"golang.org/x/sys/unix"
)
var (
const (
newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE)
delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE)
)