From 9857ffe35c89cd56481d60f722f1c5f5f0abb7db Mon Sep 17 00:00:00 2001 From: nickgarlis Date: Fri, 22 Aug 2025 23:28:11 +0200 Subject: [PATCH] Add GetGen method to retrieve current generation ID nftables uses generation IDs (gen IDs) for optimistic concurrency control. This commit adds a GetGen method to expose current gen ID so that users can retrieve it explicitly. Typical usage: 1. Call GetGen to retrieve current gen ID. 2. Read the the current state. 3. Send the batch along with the gen ID by calling Flush. If the state changes before the flush, the kernel will reject the batch, preventing stale writes. - https://wiki.nftables.org/wiki-nftables/index.php/Portal:DeveloperDocs/nftables_internals#Batched_handlers - https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen - https://github.com/torvalds/linux/blob/3957a5720157264dcc41415fbec7c51c4000fc2d/net/netfilter/nfnetlink.c#L424 --- conn.go | 35 +++++++++++++++++++---- gen.go | 49 +++++++++++++++++++++++++++++--- monitor_test.go | 2 +- nftables_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 149 insertions(+), 11 deletions(-) diff --git a/conn.go b/conn.go index c6c85bb..90382b7 100644 --- a/conn.go +++ b/conn.go @@ -232,7 +232,9 @@ func (cc *Conn) CloseLasting() error { } // Flush sends all buffered commands in a single batch to nftables. -func (cc *Conn) Flush() error { +// If an optional gen ID is provided, it will be used in the batch begin message. +// If the gen ID is not matched by the kernel, it will return an ERESTART error. +func (cc *Conn) Flush(genID ...uint32) error { cc.mu.Lock() defer func() { cc.messages = nil @@ -259,7 +261,12 @@ func (cc *Conn) Flush() error { return err } - messages, err := conn.SendMessages(batch(cc.messages)) + batch, err := batch(cc.messages, genID...) + if err != nil { + return err + } + + messages, err := conn.SendMessages(batch) if err != nil { return fmt.Errorf("SendMessages: %w", err) } @@ -388,14 +395,30 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte { return b } -func batch(messages []netlinkMessage) []netlink.Message { +// Batch wraps the given messages in a batch begin and end message, and returns +// the resulting slice of netlink messages. If a genID is provided, it is +// included in both batch messages. +func batch(messages []netlinkMessage, genID ...uint32) ([]netlink.Message, error) { batch := make([]netlink.Message, len(messages)+2) + + data := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES) + + if len(genID) > 0 && genID[0] > 0 { + attr, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID[0])}, + }) + if err != nil { + return nil, err + } + data = append(data, attr...) + } + batch[0] = netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), Flags: netlink.Request, }, - Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + Data: data, } for i, msg := range messages { @@ -410,10 +433,10 @@ func batch(messages []netlinkMessage) []netlink.Message { Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Flags: netlink.Request, }, - Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + Data: data, } - return batch + return batch, nil } // allocateTransactionID allocates an identifier which is only valid in the diff --git a/gen.go b/gen.go index 0d4580d..52f1cd3 100644 --- a/gen.go +++ b/gen.go @@ -8,15 +8,17 @@ import ( "golang.org/x/sys/unix" ) -type GenMsg struct { +type Gen struct { ID uint32 ProcPID uint32 ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN } +type GenMsg = Gen + const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN) -func genFromMsg(msg netlink.Message) (*GenMsg, error) { +func genFromMsg(msg netlink.Message) (*Gen, error) { if got, want := msg.Header.Type, genHeaderType; got != want { return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) } @@ -26,7 +28,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) { } ad.ByteOrder = binary.BigEndian - msgOut := &GenMsg{} + msgOut := &Gen{} for ad.Next() { switch ad.Type() { case unix.NFTA_GEN_ID: @@ -36,7 +38,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) { case unix.NFTA_GEN_PROC_NAME: msgOut.ProcComm = ad.String() default: - return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes()) + return nil, fmt.Errorf("unknown attribute: %d, %v", ad.Type(), ad.Bytes()) } } if err := ad.Err(); err != nil { @@ -44,3 +46,42 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) { } return msgOut, nil } + +// GetGen retrieves the current nftables generation ID together with the name +// and ID of the process that last modified the ruleset. +// https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen +func (cc *Conn) GetGen() (*Gen, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_GEN_ID}, + }) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETGEN), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(0, 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("receiveAckAware: %v", err) + } + if len(reply) == 0 { + return nil, fmt.Errorf("receiveAckAware: no reply") + } + return genFromMsg(reply[0]) +} diff --git a/monitor_test.go b/monitor_test.go index 5640e13..8735961 100644 --- a/monitor_test.go +++ b/monitor_test.go @@ -71,7 +71,7 @@ func TestMonitor(t *testing.T) { return } - genMsg := event.GeneratedBy.Data.(*nftables.GenMsg) + genMsg := event.GeneratedBy.Data.(*nftables.Gen) fileName := filepath.Base(os.Args[0]) if genMsg.ProcComm != fileName { diff --git a/nftables_test.go b/nftables_test.go index fe12566..1280f7c 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -7429,3 +7429,77 @@ func TestAutoBufferSize(t *testing.T) { t.Fatalf("failed to flush: %v", err) } } + +func TestGetGen(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + defer conn.FlushRuleset() + + gen, err := conn.GetGen() + if err != nil { + t.Fatalf("failed to get gen: %v", err) + } + + conn.AddTable(&nftables.Table{ + Name: "test-table", + Family: nftables.TableFamilyIPv4, + }) + + // Flush to increment the generation ID. + if err := conn.Flush(); err != nil { + t.Fatalf("failed to flush: %v", err) + } + + newGen, err := conn.GetGen() + if err != nil { + t.Fatalf("failed to get gen: %v", err) + } + + if newGen.ID <= gen.ID { + t.Fatalf("gen ID did not increase, got %d, want > %d", newGen.ID, gen.ID) + } + + if newGen.ProcComm != gen.ProcComm { + t.Errorf("gen ProcComm changed, got %s, want %s", newGen.ProcComm, gen.ProcComm) + } + + if newGen.ProcPID != gen.ProcPID { + t.Errorf("gen ProcPID changed, got %d, want %d", newGen.ProcPID, gen.ProcPID) + } +} + +func TestFlushWithGenID(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + defer conn.FlushRuleset() + + gen, err := conn.GetGen() + if err != nil { + t.Fatalf("failed to get gen: %v", err) + } + + conn.AddTable(&nftables.Table{ + Name: "test-table", + Family: nftables.TableFamilyIPv4, + }) + + // Flush to increment the generation ID. + if err := conn.Flush(); err != nil { + t.Fatalf("failed to flush: %v", err) + } + + conn.AddTable(&nftables.Table{ + Name: "test-table-2", + Family: nftables.TableFamilyIPv4, + }) + + err = conn.Flush(gen.ID) + if err == nil || !errors.Is(err, syscall.ERESTART) { + t.Errorf("expected error to be ERESTART, got: %v", err) + } + + table, err := conn.ListTable("test-table-2") + if table != nil && !errors.Is(err, syscall.ENOENT) { + t.Errorf("expected table to not exist, got: %v", table) + } +}