Add GetGen method to retrieve current generation ID (#325)
Add GetGen method to retrieve current generation ID
nftables uses generation IDs (gen IDs) for optimistic concurrency
control. This commit adds a GetGen method to expose current gen ID so
that users can retrieve it explicitly.
Typical usage:
1. Call GetGen to retrieve current gen ID.
2. Read the the current state.
3. Send the batch along with the gen ID by calling Flush.
If the state changes before the flush, the kernel will reject the
batch, preventing stale writes.
- https://wiki.nftables.org/wiki-nftables/index.php/Portal:DeveloperDocs/nftables_internals#Batched_handlers
- https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen
- 3957a57201/net/netfilter/nfnetlink.c (L424)
This commit is contained in:
parent
508bb1ffd4
commit
3efc75f481
44
conn.go
44
conn.go
|
@ -233,6 +233,19 @@ func (cc *Conn) CloseLasting() error {
|
|||
|
||||
// Flush sends all buffered commands in a single batch to nftables.
|
||||
func (cc *Conn) Flush() error {
|
||||
return cc.flush(0)
|
||||
}
|
||||
|
||||
// FlushWithGenID sends all buffered commands in a single batch to nftables
|
||||
// along with the provided gen ID. If the ruleset has changed since the gen ID
|
||||
// was retrieved, an ERESTART error will be returned.
|
||||
func (cc *Conn) FlushWithGenID(genID uint32) error {
|
||||
return cc.flush(genID)
|
||||
}
|
||||
|
||||
// flush sends all buffered commands in a single batch to nftables. If genID is
|
||||
// non-zero, it will be included in the batch messages.
|
||||
func (cc *Conn) flush(genID uint32) error {
|
||||
cc.mu.Lock()
|
||||
defer func() {
|
||||
cc.messages = nil
|
||||
|
@ -259,7 +272,12 @@ func (cc *Conn) Flush() error {
|
|||
return err
|
||||
}
|
||||
|
||||
messages, err := conn.SendMessages(batch(cc.messages))
|
||||
batch, err := batch(cc.messages, genID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
messages, err := conn.SendMessages(batch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SendMessages: %w", err)
|
||||
}
|
||||
|
@ -388,14 +406,30 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
|
|||
return b
|
||||
}
|
||||
|
||||
func batch(messages []netlinkMessage) []netlink.Message {
|
||||
// batch wraps the given messages in a batch begin and end message, and returns
|
||||
// the resulting slice of netlink messages. If the genID is non-zero, it will be
|
||||
// included in both batch messages.
|
||||
func batch(messages []netlinkMessage, genID uint32) ([]netlink.Message, error) {
|
||||
batch := make([]netlink.Message, len(messages)+2)
|
||||
|
||||
data := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES)
|
||||
|
||||
if genID > 0 {
|
||||
attr, err := netlink.MarshalAttributes([]netlink.Attribute{
|
||||
{Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID)},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data = append(data, attr...)
|
||||
}
|
||||
|
||||
batch[0] = netlink.Message{
|
||||
Header: netlink.Header{
|
||||
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
|
||||
Flags: netlink.Request,
|
||||
},
|
||||
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
|
||||
Data: data,
|
||||
}
|
||||
|
||||
for i, msg := range messages {
|
||||
|
@ -410,10 +444,10 @@ func batch(messages []netlinkMessage) []netlink.Message {
|
|||
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
|
||||
Flags: netlink.Request,
|
||||
},
|
||||
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
|
||||
Data: data,
|
||||
}
|
||||
|
||||
return batch
|
||||
return batch, nil
|
||||
}
|
||||
|
||||
// allocateTransactionID allocates an identifier which is only valid in the
|
||||
|
|
50
gen.go
50
gen.go
|
@ -8,15 +8,18 @@ import (
|
|||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type GenMsg struct {
|
||||
type Gen struct {
|
||||
ID uint32
|
||||
ProcPID uint32
|
||||
ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN
|
||||
}
|
||||
|
||||
// Deprecated: GenMsg is an inconsistent old name for Gen. Prefer using Gen.
|
||||
type GenMsg = Gen
|
||||
|
||||
const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
|
||||
|
||||
func genFromMsg(msg netlink.Message) (*GenMsg, error) {
|
||||
func genFromMsg(msg netlink.Message) (*Gen, error) {
|
||||
if got, want := msg.Header.Type, genHeaderType; got != want {
|
||||
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
|
||||
}
|
||||
|
@ -26,7 +29,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
|
|||
}
|
||||
ad.ByteOrder = binary.BigEndian
|
||||
|
||||
msgOut := &GenMsg{}
|
||||
msgOut := &Gen{}
|
||||
for ad.Next() {
|
||||
switch ad.Type() {
|
||||
case unix.NFTA_GEN_ID:
|
||||
|
@ -36,7 +39,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
|
|||
case unix.NFTA_GEN_PROC_NAME:
|
||||
msgOut.ProcComm = ad.String()
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes())
|
||||
return nil, fmt.Errorf("unknown attribute: %d, %v", ad.Type(), ad.Bytes())
|
||||
}
|
||||
}
|
||||
if err := ad.Err(); err != nil {
|
||||
|
@ -44,3 +47,42 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
|
|||
}
|
||||
return msgOut, nil
|
||||
}
|
||||
|
||||
// GetGen retrieves the current nftables generation ID together with the name
|
||||
// and ID of the process that last modified the ruleset.
|
||||
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen
|
||||
func (cc *Conn) GetGen() (*Gen, error) {
|
||||
conn, closer, err := cc.netlinkConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = closer() }()
|
||||
|
||||
data, err := netlink.MarshalAttributes([]netlink.Attribute{
|
||||
{Type: unix.NFTA_GEN_ID},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
message := netlink.Message{
|
||||
Header: netlink.Header{
|
||||
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETGEN),
|
||||
Flags: netlink.Request | netlink.Acknowledge,
|
||||
},
|
||||
Data: append(extraHeader(0, 0), data...),
|
||||
}
|
||||
|
||||
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
|
||||
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||
}
|
||||
|
||||
reply, err := receiveAckAware(conn, message.Header.Flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("receiveAckAware: %v", err)
|
||||
}
|
||||
if len(reply) == 0 {
|
||||
return nil, fmt.Errorf("receiveAckAware: no reply")
|
||||
}
|
||||
return genFromMsg(reply[0])
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ func TestMonitor(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
genMsg := event.GeneratedBy.Data.(*nftables.GenMsg)
|
||||
genMsg := event.GeneratedBy.Data.(*nftables.Gen)
|
||||
fileName := filepath.Base(os.Args[0])
|
||||
|
||||
if genMsg.ProcComm != fileName {
|
||||
|
|
|
@ -7429,3 +7429,77 @@ func TestAutoBufferSize(t *testing.T) {
|
|||
t.Fatalf("failed to flush: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGen(t *testing.T) {
|
||||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||||
defer nftest.CleanupSystemConn(t, newNS)
|
||||
defer conn.FlushRuleset()
|
||||
|
||||
gen, err := conn.GetGen()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get gen: %v", err)
|
||||
}
|
||||
|
||||
conn.AddTable(&nftables.Table{
|
||||
Name: "test-table",
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
})
|
||||
|
||||
// Flush to increment the generation ID.
|
||||
if err := conn.Flush(); err != nil {
|
||||
t.Fatalf("failed to flush: %v", err)
|
||||
}
|
||||
|
||||
newGen, err := conn.GetGen()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get gen: %v", err)
|
||||
}
|
||||
|
||||
if newGen.ID <= gen.ID {
|
||||
t.Fatalf("gen ID did not increase, got %d, want > %d", newGen.ID, gen.ID)
|
||||
}
|
||||
|
||||
if newGen.ProcComm != gen.ProcComm {
|
||||
t.Errorf("gen ProcComm changed, got %s, want %s", newGen.ProcComm, gen.ProcComm)
|
||||
}
|
||||
|
||||
if newGen.ProcPID != gen.ProcPID {
|
||||
t.Errorf("gen ProcPID changed, got %d, want %d", newGen.ProcPID, gen.ProcPID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushWithGenID(t *testing.T) {
|
||||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||||
defer nftest.CleanupSystemConn(t, newNS)
|
||||
defer conn.FlushRuleset()
|
||||
|
||||
gen, err := conn.GetGen()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get gen: %v", err)
|
||||
}
|
||||
|
||||
conn.AddTable(&nftables.Table{
|
||||
Name: "test-table",
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
})
|
||||
|
||||
// Flush to increment the generation ID.
|
||||
if err := conn.Flush(); err != nil {
|
||||
t.Fatalf("failed to flush: %v", err)
|
||||
}
|
||||
|
||||
conn.AddTable(&nftables.Table{
|
||||
Name: "test-table-2",
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
})
|
||||
|
||||
err = conn.FlushWithGenID(gen.ID)
|
||||
if err == nil || !errors.Is(err, syscall.ERESTART) {
|
||||
t.Errorf("expected error to be ERESTART, got: %v", err)
|
||||
}
|
||||
|
||||
table, err := conn.ListTable("test-table-2")
|
||||
if table != nil && !errors.Is(err, syscall.ENOENT) {
|
||||
t.Errorf("expected table to not exist, got: %v", table)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue