Compare commits

..

3 Commits

Author SHA1 Message Date
Paul Greenberg 96d06a061a
Merge dae73eaa9c into 1148f1a84f 2025-09-02 14:21:22 +02:00
Nikita Vorontsov 1148f1a84f
add DestroyTable and SetDestroyElements (#322)
These methods are like their DeleteTable and SetDeleteElements counterparts, but they do not return an error if the specified table/set does not exist.
2025-09-02 14:08:18 +02:00
Nick Garlis 3efc75f481
Add GetGen method to retrieve current generation ID (#325)
Add GetGen method to retrieve current generation ID

nftables uses generation IDs (gen IDs) for optimistic concurrency
control. This commit adds a GetGen method to expose current gen ID so
that users can retrieve it explicitly.

Typical usage:
  1. Call GetGen to retrieve current gen ID.
  2. Read the the current state.
  3. Send the batch along with the gen ID by calling Flush.

If the state changes before the flush, the kernel will reject the
batch, preventing stale writes.

- https://wiki.nftables.org/wiki-nftables/index.php/Portal:DeveloperDocs/nftables_internals#Batched_handlers
- https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen
- 3957a57201/net/netfilter/nfnetlink.c (L424)
2025-09-02 14:05:05 +02:00
6 changed files with 249 additions and 14 deletions

44
conn.go
View File

@ -233,6 +233,19 @@ func (cc *Conn) CloseLasting() error {
// Flush sends all buffered commands in a single batch to nftables. // Flush sends all buffered commands in a single batch to nftables.
func (cc *Conn) Flush() error { func (cc *Conn) Flush() error {
return cc.flush(0)
}
// FlushWithGenID sends all buffered commands in a single batch to nftables
// along with the provided gen ID. If the ruleset has changed since the gen ID
// was retrieved, an ERESTART error will be returned.
func (cc *Conn) FlushWithGenID(genID uint32) error {
return cc.flush(genID)
}
// flush sends all buffered commands in a single batch to nftables. If genID is
// non-zero, it will be included in the batch messages.
func (cc *Conn) flush(genID uint32) error {
cc.mu.Lock() cc.mu.Lock()
defer func() { defer func() {
cc.messages = nil cc.messages = nil
@ -259,7 +272,12 @@ func (cc *Conn) Flush() error {
return err return err
} }
messages, err := conn.SendMessages(batch(cc.messages)) batch, err := batch(cc.messages, genID)
if err != nil {
return err
}
messages, err := conn.SendMessages(batch)
if err != nil { if err != nil {
return fmt.Errorf("SendMessages: %w", err) return fmt.Errorf("SendMessages: %w", err)
} }
@ -388,14 +406,30 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
return b return b
} }
func batch(messages []netlinkMessage) []netlink.Message { // batch wraps the given messages in a batch begin and end message, and returns
// the resulting slice of netlink messages. If the genID is non-zero, it will be
// included in both batch messages.
func batch(messages []netlinkMessage, genID uint32) ([]netlink.Message, error) {
batch := make([]netlink.Message, len(messages)+2) batch := make([]netlink.Message, len(messages)+2)
data := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES)
if genID > 0 {
attr, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID)},
})
if err != nil {
return nil, err
}
data = append(data, attr...)
}
batch[0] = netlink.Message{ batch[0] = netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
Flags: netlink.Request, Flags: netlink.Request,
}, },
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), Data: data,
} }
for i, msg := range messages { for i, msg := range messages {
@ -410,10 +444,10 @@ func batch(messages []netlinkMessage) []netlink.Message {
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
Flags: netlink.Request, Flags: netlink.Request,
}, },
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), Data: data,
} }
return batch return batch, nil
} }
// allocateTransactionID allocates an identifier which is only valid in the // allocateTransactionID allocates an identifier which is only valid in the

50
gen.go
View File

@ -8,15 +8,18 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
type GenMsg struct { type Gen struct {
ID uint32 ID uint32
ProcPID uint32 ProcPID uint32
ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN
} }
// Deprecated: GenMsg is an inconsistent old name for Gen. Prefer using Gen.
type GenMsg = Gen
const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN) const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
func genFromMsg(msg netlink.Message) (*GenMsg, error) { func genFromMsg(msg netlink.Message) (*Gen, error) {
if got, want := msg.Header.Type, genHeaderType; got != want { if got, want := msg.Header.Type, genHeaderType; got != want {
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
} }
@ -26,7 +29,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
} }
ad.ByteOrder = binary.BigEndian ad.ByteOrder = binary.BigEndian
msgOut := &GenMsg{} msgOut := &Gen{}
for ad.Next() { for ad.Next() {
switch ad.Type() { switch ad.Type() {
case unix.NFTA_GEN_ID: case unix.NFTA_GEN_ID:
@ -36,7 +39,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
case unix.NFTA_GEN_PROC_NAME: case unix.NFTA_GEN_PROC_NAME:
msgOut.ProcComm = ad.String() msgOut.ProcComm = ad.String()
default: default:
return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes()) return nil, fmt.Errorf("unknown attribute: %d, %v", ad.Type(), ad.Bytes())
} }
} }
if err := ad.Err(); err != nil { if err := ad.Err(); err != nil {
@ -44,3 +47,42 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
} }
return msgOut, nil return msgOut, nil
} }
// GetGen retrieves the current nftables generation ID together with the name
// and ID of the process that last modified the ruleset.
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen
func (cc *Conn) GetGen() (*Gen, error) {
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer func() { _ = closer() }()
data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_GEN_ID},
})
if err != nil {
return nil, err
}
message := netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETGEN),
Flags: netlink.Request | netlink.Acknowledge,
},
Data: append(extraHeader(0, 0), data...),
}
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err)
}
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err)
}
if len(reply) == 0 {
return nil, fmt.Errorf("receiveAckAware: no reply")
}
return genFromMsg(reply[0])
}

View File

@ -71,7 +71,7 @@ func TestMonitor(t *testing.T) {
return return
} }
genMsg := event.GeneratedBy.Data.(*nftables.GenMsg) genMsg := event.GeneratedBy.Data.(*nftables.Gen)
fileName := filepath.Base(os.Args[0]) fileName := filepath.Base(os.Args[0])
if genMsg.ProcComm != fileName { if genMsg.ProcComm != fileName {

View File

@ -128,6 +128,47 @@ func ifname(n string) []byte {
return b return b
} }
func TestTableCreateDestroy(t *testing.T) {
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer c.FlushRuleset()
filter := &nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
}
c.DestroyTable(filter)
c.AddTable(filter)
err := c.Flush()
if err != nil {
t.Fatalf("on Flush: %q", err.Error())
}
lookupMyTable := func() bool {
ts, err := c.ListTables()
if err != nil {
t.Fatalf("on ListTables: %q", err.Error())
}
return slices.ContainsFunc(ts, func(t *nftables.Table) bool {
return t.Name == filter.Name && t.Family == filter.Family
})
}
if !lookupMyTable() {
t.Fatal("AddTable doesn't create my table!")
}
c.DestroyTable(filter)
if err = c.Flush(); err != nil {
t.Fatalf("on Flush: %q", err.Error())
}
if lookupMyTable() {
t.Fatal("DestroyTable doesn't delete my table!")
}
c.DestroyTable(filter) // just for test that 'destroy' ignore error 'not found'
}
func TestRuleOperations(t *testing.T) { func TestRuleOperations(t *testing.T) {
// Create a new network namespace to test these operations, // Create a new network namespace to test these operations,
// and tear down the namespace at test completion. // and tear down the namespace at test completion.
@ -3792,7 +3833,7 @@ func TestDeleteElementNamedSet(t *testing.T) {
Name: "test", Name: "test",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil { if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}, {Key: []byte{0, 24}}}); err != nil {
t.Errorf("c.AddSet(portSet) failed: %v", err) t.Errorf("c.AddSet(portSet) failed: %v", err)
} }
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
@ -3809,6 +3850,22 @@ func TestDeleteElementNamedSet(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("c.GetSets() failed: %v", err) t.Errorf("c.GetSets() failed: %v", err)
} }
if len(elems) != 2 {
t.Fatalf("len(elems) = %d, want 2", len(elems))
}
c.SetDestroyElements(portSet, []nftables.SetElement{{Key: []byte{0, 24}}})
c.SetDestroyElements(portSet, []nftables.SetElement{{Key: []byte{0, 24}}})
c.SetDestroyElements(portSet, []nftables.SetElement{{Key: []byte{0, 99}}})
if err := c.Flush(); err != nil {
t.Errorf("Third c.Flush() failed: %v", err)
}
elems, err = c.GetSetElements(portSet)
if err != nil {
t.Errorf("c.GetSets() failed: %v", err)
}
if len(elems) != 1 { if len(elems) != 1 {
t.Fatalf("len(elems) = %d, want 1", len(elems)) t.Fatalf("len(elems) = %d, want 1", len(elems))
} }
@ -7444,3 +7501,77 @@ func TestAutoBufferSize(t *testing.T) {
t.Fatalf("failed to flush: %v", err) t.Fatalf("failed to flush: %v", err)
} }
} }
func TestGetGen(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer conn.FlushRuleset()
gen, err := conn.GetGen()
if err != nil {
t.Fatalf("failed to get gen: %v", err)
}
conn.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyIPv4,
})
// Flush to increment the generation ID.
if err := conn.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
newGen, err := conn.GetGen()
if err != nil {
t.Fatalf("failed to get gen: %v", err)
}
if newGen.ID <= gen.ID {
t.Fatalf("gen ID did not increase, got %d, want > %d", newGen.ID, gen.ID)
}
if newGen.ProcComm != gen.ProcComm {
t.Errorf("gen ProcComm changed, got %s, want %s", newGen.ProcComm, gen.ProcComm)
}
if newGen.ProcPID != gen.ProcPID {
t.Errorf("gen ProcPID changed, got %d, want %d", newGen.ProcPID, gen.ProcPID)
}
}
func TestFlushWithGenID(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer conn.FlushRuleset()
gen, err := conn.GetGen()
if err != nil {
t.Fatalf("failed to get gen: %v", err)
}
conn.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyIPv4,
})
// Flush to increment the generation ID.
if err := conn.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
conn.AddTable(&nftables.Table{
Name: "test-table-2",
Family: nftables.TableFamilyIPv4,
})
err = conn.FlushWithGenID(gen.ID)
if err == nil || !errors.Is(err, syscall.ERESTART) {
t.Errorf("expected error to be ERESTART, got: %v", err)
}
table, err := conn.ListTable("test-table-2")
if table != nil && !errors.Is(err, syscall.ENOENT) {
t.Errorf("expected table to not exist, got: %v", table)
}
}

14
set.go
View File

@ -44,6 +44,9 @@ const (
NFTA_SET_ELEM_KEY_END = 10 NFTA_SET_ELEM_KEY_END = 10
// https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n429 // https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n429
NFTA_SET_ELEM_EXPRESSIONS = 0x11 NFTA_SET_ELEM_EXPRESSIONS = 0x11
// FIXME: in sys@v0.34.0 no unix.NFT_MSG_DESTROYSETELEM const yet.
// See nf_tables_msg_types enum in https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h
NFT_MSG_DESTROYSETELEM = 0x1e
) )
// SetDatatype represents a datatype declared by nft. // SetDatatype represents a datatype declared by nft.
@ -391,6 +394,16 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM) return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
} }
// SetDestroyElements like SetDeleteElements, but not an error if setelement doesn't exists
func (cc *Conn) SetDestroyElements(s *Set, vals []SetElement) error {
cc.mu.Lock()
defer cc.mu.Unlock()
if s.Anonymous {
return errors.New("anonymous sets cannot be updated")
}
return cc.appendElemList(s, vals, NFT_MSG_DESTROYSETELEM)
}
// maxElemBatchSize is the maximum size in bytes of encoded set elements which // maxElemBatchSize is the maximum size in bytes of encoded set elements which
// are sent in one netlink message. The size field of a netlink attribute is a // are sent in one netlink message. The size field of a netlink attribute is a
// uint16, and 1024 bytes is more than enough for the per-message headers. // uint16, and 1024 bytes is more than enough for the per-message headers.
@ -828,6 +841,7 @@ func parseSetDatatype(magic uint32) (SetDatatype, error) {
const ( const (
newElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM) newElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM)
delElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM) delElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM)
destroyElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_DESTROYSETELEM)
) )
func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) { func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) {

View File

@ -24,6 +24,10 @@ import (
const ( const (
newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE) newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE)
delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE) delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE)
// FIXME: in sys@v0.34.0 no unix.NFT_MSG_DESTROYTABLE const yet.
// See nf_tables_msg_types enum in https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h
destroyTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | 0x1a)
) )
// TableFamily specifies the address family for this table. // TableFamily specifies the address family for this table.
@ -51,15 +55,25 @@ type Table struct {
// DelTable deletes a specific table, along with all chains/rules it contains. // DelTable deletes a specific table, along with all chains/rules it contains.
func (cc *Conn) DelTable(t *Table) { func (cc *Conn) DelTable(t *Table) {
cc.delTable(t, delTableHeaderType)
}
// DestroyTable is like DelTable, but not an error if table doesn't exists
func (cc *Conn) DestroyTable(t *Table) {
cc.delTable(t, destroyTableHeaderType)
}
func (cc *Conn) delTable(t *Table, hdrType netlink.HeaderType) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{ data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
}) })
cc.messages = append(cc.messages, netlinkMessage{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Type: hdrType,
Flags: netlink.Request | netlink.Acknowledge, Flags: netlink.Request | netlink.Acknowledge,
}, },
Data: append(extraHeader(uint8(t.Family), 0), data...), Data: append(extraHeader(uint8(t.Family), 0), data...),