diff --git a/conn.go b/conn.go index 90382b7..67e0f72 100644 --- a/conn.go +++ b/conn.go @@ -232,9 +232,20 @@ func (cc *Conn) CloseLasting() error { } // Flush sends all buffered commands in a single batch to nftables. -// If an optional gen ID is provided, it will be used in the batch begin message. -// If the gen ID is not matched by the kernel, it will return an ERESTART error. -func (cc *Conn) Flush(genID ...uint32) error { +func (cc *Conn) Flush() error { + return cc.flush(0) +} + +// FlushWithGenID sends all buffered commands in a single batch to nftables +// along with the provided gen ID. If the ruleset has changed since the gen ID +// was retrieved, an ERESTART error will be returned. +func (cc *Conn) FlushWithGenID(genID uint32) error { + return cc.flush(genID) +} + +// flush sends all buffered commands in a single batch to nftables. If genID is +// non-zero, it will be included in the batch messages. +func (cc *Conn) flush(genID uint32) error { cc.mu.Lock() defer func() { cc.messages = nil @@ -261,7 +272,7 @@ func (cc *Conn) Flush(genID ...uint32) error { return err } - batch, err := batch(cc.messages, genID...) + batch, err := batch(cc.messages, genID) if err != nil { return err } @@ -395,17 +406,17 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte { return b } -// Batch wraps the given messages in a batch begin and end message, and returns -// the resulting slice of netlink messages. If a genID is provided, it is +// batch wraps the given messages in a batch begin and end message, and returns +// the resulting slice of netlink messages. If the genID is non-zero, it will be // included in both batch messages. -func batch(messages []netlinkMessage, genID ...uint32) ([]netlink.Message, error) { +func batch(messages []netlinkMessage, genID uint32) ([]netlink.Message, error) { batch := make([]netlink.Message, len(messages)+2) data := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES) - if len(genID) > 0 && genID[0] > 0 { + if genID > 0 { attr, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID[0])}, + {Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID)}, }) if err != nil { return nil, err diff --git a/gen.go b/gen.go index 52f1cd3..0d88e15 100644 --- a/gen.go +++ b/gen.go @@ -14,6 +14,7 @@ type Gen struct { ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN } +// Deprecated: GenMsg is an inconsistent old name for Gen. Prefer using Gen. type GenMsg = Gen const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN) diff --git a/nftables_test.go b/nftables_test.go index 1280f7c..bc433b4 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -7493,7 +7493,7 @@ func TestFlushWithGenID(t *testing.T) { Family: nftables.TableFamilyIPv4, }) - err = conn.Flush(gen.ID) + err = conn.FlushWithGenID(gen.ID) if err == nil || !errors.Is(err, syscall.ERESTART) { t.Errorf("expected error to be ERESTART, got: %v", err) }