From 3efc75f4816c254dd84b15ec70ac7823c45ab213 Mon Sep 17 00:00:00 2001 From: Nick Garlis Date: Tue, 2 Sep 2025 14:05:05 +0200 Subject: [PATCH] Add GetGen method to retrieve current generation ID (#325) Add GetGen method to retrieve current generation ID nftables uses generation IDs (gen IDs) for optimistic concurrency control. This commit adds a GetGen method to expose current gen ID so that users can retrieve it explicitly. Typical usage: 1. Call GetGen to retrieve current gen ID. 2. Read the the current state. 3. Send the batch along with the gen ID by calling Flush. If the state changes before the flush, the kernel will reject the batch, preventing stale writes. - https://wiki.nftables.org/wiki-nftables/index.php/Portal:DeveloperDocs/nftables_internals#Batched_handlers - https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen - https://github.com/torvalds/linux/blob/3957a5720157264dcc41415fbec7c51c4000fc2d/net/netfilter/nfnetlink.c#L424 --- conn.go | 44 ++++++++++++++++++++++++---- gen.go | 50 +++++++++++++++++++++++++++++--- monitor_test.go | 2 +- nftables_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 160 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index c6c85bb..67e0f72 100644 --- a/conn.go +++ b/conn.go @@ -233,6 +233,19 @@ func (cc *Conn) CloseLasting() error { // Flush sends all buffered commands in a single batch to nftables. func (cc *Conn) Flush() error { + return cc.flush(0) +} + +// FlushWithGenID sends all buffered commands in a single batch to nftables +// along with the provided gen ID. If the ruleset has changed since the gen ID +// was retrieved, an ERESTART error will be returned. +func (cc *Conn) FlushWithGenID(genID uint32) error { + return cc.flush(genID) +} + +// flush sends all buffered commands in a single batch to nftables. If genID is +// non-zero, it will be included in the batch messages. +func (cc *Conn) flush(genID uint32) error { cc.mu.Lock() defer func() { cc.messages = nil @@ -259,7 +272,12 @@ func (cc *Conn) Flush() error { return err } - messages, err := conn.SendMessages(batch(cc.messages)) + batch, err := batch(cc.messages, genID) + if err != nil { + return err + } + + messages, err := conn.SendMessages(batch) if err != nil { return fmt.Errorf("SendMessages: %w", err) } @@ -388,14 +406,30 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte { return b } -func batch(messages []netlinkMessage) []netlink.Message { +// batch wraps the given messages in a batch begin and end message, and returns +// the resulting slice of netlink messages. If the genID is non-zero, it will be +// included in both batch messages. +func batch(messages []netlinkMessage, genID uint32) ([]netlink.Message, error) { batch := make([]netlink.Message, len(messages)+2) + + data := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES) + + if genID > 0 { + attr, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID)}, + }) + if err != nil { + return nil, err + } + data = append(data, attr...) + } + batch[0] = netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), Flags: netlink.Request, }, - Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + Data: data, } for i, msg := range messages { @@ -410,10 +444,10 @@ func batch(messages []netlinkMessage) []netlink.Message { Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Flags: netlink.Request, }, - Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + Data: data, } - return batch + return batch, nil } // allocateTransactionID allocates an identifier which is only valid in the diff --git a/gen.go b/gen.go index 0d4580d..0d88e15 100644 --- a/gen.go +++ b/gen.go @@ -8,15 +8,18 @@ import ( "golang.org/x/sys/unix" ) -type GenMsg struct { +type Gen struct { ID uint32 ProcPID uint32 ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN } +// Deprecated: GenMsg is an inconsistent old name for Gen. Prefer using Gen. +type GenMsg = Gen + const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN) -func genFromMsg(msg netlink.Message) (*GenMsg, error) { +func genFromMsg(msg netlink.Message) (*Gen, error) { if got, want := msg.Header.Type, genHeaderType; got != want { return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) } @@ -26,7 +29,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) { } ad.ByteOrder = binary.BigEndian - msgOut := &GenMsg{} + msgOut := &Gen{} for ad.Next() { switch ad.Type() { case unix.NFTA_GEN_ID: @@ -36,7 +39,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) { case unix.NFTA_GEN_PROC_NAME: msgOut.ProcComm = ad.String() default: - return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes()) + return nil, fmt.Errorf("unknown attribute: %d, %v", ad.Type(), ad.Bytes()) } } if err := ad.Err(); err != nil { @@ -44,3 +47,42 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) { } return msgOut, nil } + +// GetGen retrieves the current nftables generation ID together with the name +// and ID of the process that last modified the ruleset. +// https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen +func (cc *Conn) GetGen() (*Gen, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_GEN_ID}, + }) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETGEN), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(0, 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("receiveAckAware: %v", err) + } + if len(reply) == 0 { + return nil, fmt.Errorf("receiveAckAware: no reply") + } + return genFromMsg(reply[0]) +} diff --git a/monitor_test.go b/monitor_test.go index 5640e13..8735961 100644 --- a/monitor_test.go +++ b/monitor_test.go @@ -71,7 +71,7 @@ func TestMonitor(t *testing.T) { return } - genMsg := event.GeneratedBy.Data.(*nftables.GenMsg) + genMsg := event.GeneratedBy.Data.(*nftables.Gen) fileName := filepath.Base(os.Args[0]) if genMsg.ProcComm != fileName { diff --git a/nftables_test.go b/nftables_test.go index fe12566..bc433b4 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -7429,3 +7429,77 @@ func TestAutoBufferSize(t *testing.T) { t.Fatalf("failed to flush: %v", err) } } + +func TestGetGen(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + defer conn.FlushRuleset() + + gen, err := conn.GetGen() + if err != nil { + t.Fatalf("failed to get gen: %v", err) + } + + conn.AddTable(&nftables.Table{ + Name: "test-table", + Family: nftables.TableFamilyIPv4, + }) + + // Flush to increment the generation ID. + if err := conn.Flush(); err != nil { + t.Fatalf("failed to flush: %v", err) + } + + newGen, err := conn.GetGen() + if err != nil { + t.Fatalf("failed to get gen: %v", err) + } + + if newGen.ID <= gen.ID { + t.Fatalf("gen ID did not increase, got %d, want > %d", newGen.ID, gen.ID) + } + + if newGen.ProcComm != gen.ProcComm { + t.Errorf("gen ProcComm changed, got %s, want %s", newGen.ProcComm, gen.ProcComm) + } + + if newGen.ProcPID != gen.ProcPID { + t.Errorf("gen ProcPID changed, got %d, want %d", newGen.ProcPID, gen.ProcPID) + } +} + +func TestFlushWithGenID(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + defer conn.FlushRuleset() + + gen, err := conn.GetGen() + if err != nil { + t.Fatalf("failed to get gen: %v", err) + } + + conn.AddTable(&nftables.Table{ + Name: "test-table", + Family: nftables.TableFamilyIPv4, + }) + + // Flush to increment the generation ID. + if err := conn.Flush(); err != nil { + t.Fatalf("failed to flush: %v", err) + } + + conn.AddTable(&nftables.Table{ + Name: "test-table-2", + Family: nftables.TableFamilyIPv4, + }) + + err = conn.FlushWithGenID(gen.ID) + if err == nil || !errors.Is(err, syscall.ERESTART) { + t.Errorf("expected error to be ERESTART, got: %v", err) + } + + table, err := conn.ListTable("test-table-2") + if table != nil && !errors.Is(err, syscall.ENOENT) { + t.Errorf("expected table to not exist, got: %v", table) + } +}