diff --git a/conn.go b/conn.go index c6c85bb..67e0f72 100644 --- a/conn.go +++ b/conn.go @@ -233,6 +233,19 @@ func (cc *Conn) CloseLasting() error { // Flush sends all buffered commands in a single batch to nftables. func (cc *Conn) Flush() error { + return cc.flush(0) +} + +// FlushWithGenID sends all buffered commands in a single batch to nftables +// along with the provided gen ID. If the ruleset has changed since the gen ID +// was retrieved, an ERESTART error will be returned. +func (cc *Conn) FlushWithGenID(genID uint32) error { + return cc.flush(genID) +} + +// flush sends all buffered commands in a single batch to nftables. If genID is +// non-zero, it will be included in the batch messages. +func (cc *Conn) flush(genID uint32) error { cc.mu.Lock() defer func() { cc.messages = nil @@ -259,7 +272,12 @@ func (cc *Conn) Flush() error { return err } - messages, err := conn.SendMessages(batch(cc.messages)) + batch, err := batch(cc.messages, genID) + if err != nil { + return err + } + + messages, err := conn.SendMessages(batch) if err != nil { return fmt.Errorf("SendMessages: %w", err) } @@ -388,14 +406,30 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte { return b } -func batch(messages []netlinkMessage) []netlink.Message { +// batch wraps the given messages in a batch begin and end message, and returns +// the resulting slice of netlink messages. If the genID is non-zero, it will be +// included in both batch messages. +func batch(messages []netlinkMessage, genID uint32) ([]netlink.Message, error) { batch := make([]netlink.Message, len(messages)+2) + + data := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES) + + if genID > 0 { + attr, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID)}, + }) + if err != nil { + return nil, err + } + data = append(data, attr...) + } + batch[0] = netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), Flags: netlink.Request, }, - Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + Data: data, } for i, msg := range messages { @@ -410,10 +444,10 @@ func batch(messages []netlinkMessage) []netlink.Message { Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Flags: netlink.Request, }, - Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + Data: data, } - return batch + return batch, nil } // allocateTransactionID allocates an identifier which is only valid in the diff --git a/gen.go b/gen.go index 0d4580d..0d88e15 100644 --- a/gen.go +++ b/gen.go @@ -8,15 +8,18 @@ import ( "golang.org/x/sys/unix" ) -type GenMsg struct { +type Gen struct { ID uint32 ProcPID uint32 ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN } +// Deprecated: GenMsg is an inconsistent old name for Gen. Prefer using Gen. +type GenMsg = Gen + const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN) -func genFromMsg(msg netlink.Message) (*GenMsg, error) { +func genFromMsg(msg netlink.Message) (*Gen, error) { if got, want := msg.Header.Type, genHeaderType; got != want { return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) } @@ -26,7 +29,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) { } ad.ByteOrder = binary.BigEndian - msgOut := &GenMsg{} + msgOut := &Gen{} for ad.Next() { switch ad.Type() { case unix.NFTA_GEN_ID: @@ -36,7 +39,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) { case unix.NFTA_GEN_PROC_NAME: msgOut.ProcComm = ad.String() default: - return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes()) + return nil, fmt.Errorf("unknown attribute: %d, %v", ad.Type(), ad.Bytes()) } } if err := ad.Err(); err != nil { @@ -44,3 +47,42 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) { } return msgOut, nil } + +// GetGen retrieves the current nftables generation ID together with the name +// and ID of the process that last modified the ruleset. +// https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen +func (cc *Conn) GetGen() (*Gen, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_GEN_ID}, + }) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETGEN), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(0, 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("receiveAckAware: %v", err) + } + if len(reply) == 0 { + return nil, fmt.Errorf("receiveAckAware: no reply") + } + return genFromMsg(reply[0]) +} diff --git a/monitor_test.go b/monitor_test.go index 5640e13..8735961 100644 --- a/monitor_test.go +++ b/monitor_test.go @@ -71,7 +71,7 @@ func TestMonitor(t *testing.T) { return } - genMsg := event.GeneratedBy.Data.(*nftables.GenMsg) + genMsg := event.GeneratedBy.Data.(*nftables.Gen) fileName := filepath.Base(os.Args[0]) if genMsg.ProcComm != fileName { diff --git a/nftables_test.go b/nftables_test.go index fe12566..bc433b4 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -7429,3 +7429,77 @@ func TestAutoBufferSize(t *testing.T) { t.Fatalf("failed to flush: %v", err) } } + +func TestGetGen(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + defer conn.FlushRuleset() + + gen, err := conn.GetGen() + if err != nil { + t.Fatalf("failed to get gen: %v", err) + } + + conn.AddTable(&nftables.Table{ + Name: "test-table", + Family: nftables.TableFamilyIPv4, + }) + + // Flush to increment the generation ID. + if err := conn.Flush(); err != nil { + t.Fatalf("failed to flush: %v", err) + } + + newGen, err := conn.GetGen() + if err != nil { + t.Fatalf("failed to get gen: %v", err) + } + + if newGen.ID <= gen.ID { + t.Fatalf("gen ID did not increase, got %d, want > %d", newGen.ID, gen.ID) + } + + if newGen.ProcComm != gen.ProcComm { + t.Errorf("gen ProcComm changed, got %s, want %s", newGen.ProcComm, gen.ProcComm) + } + + if newGen.ProcPID != gen.ProcPID { + t.Errorf("gen ProcPID changed, got %d, want %d", newGen.ProcPID, gen.ProcPID) + } +} + +func TestFlushWithGenID(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + defer conn.FlushRuleset() + + gen, err := conn.GetGen() + if err != nil { + t.Fatalf("failed to get gen: %v", err) + } + + conn.AddTable(&nftables.Table{ + Name: "test-table", + Family: nftables.TableFamilyIPv4, + }) + + // Flush to increment the generation ID. + if err := conn.Flush(); err != nil { + t.Fatalf("failed to flush: %v", err) + } + + conn.AddTable(&nftables.Table{ + Name: "test-table-2", + Family: nftables.TableFamilyIPv4, + }) + + err = conn.FlushWithGenID(gen.ID) + if err == nil || !errors.Is(err, syscall.ERESTART) { + t.Errorf("expected error to be ERESTART, got: %v", err) + } + + table, err := conn.ListTable("test-table-2") + if table != nil && !errors.Is(err, syscall.ENOENT) { + t.Errorf("expected table to not exist, got: %v", table) + } +}