diff --git a/conn.go b/conn.go index fef9c2a..d4759b1 100644 --- a/conn.go +++ b/conn.go @@ -17,6 +17,7 @@ package nftables import ( "errors" "fmt" + "math" "os" "sync" "syscall" @@ -38,12 +39,14 @@ type Conn struct { TestDial nltest.Func // for testing only; passed to nltest.Dial NetNS int // fd referencing the network namespace netlink will interact with. - lasting bool // establish a lasting connection to be used across multiple netlink operations. - mu sync.Mutex // protects the following state - messages []netlink.Message - err error - nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. - sockOptions []SockOption + lasting bool // establish a lasting connection to be used across multiple netlink operations. + mu sync.Mutex // protects the following state + messages []netlink.Message + err error + nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. + sockOptions []SockOption + lastID uint32 + allocatedIDs uint32 } // ConnOption is an option to change the behavior of the nftables Conn returned by Open. @@ -244,6 +247,7 @@ func (cc *Conn) Flush() error { cc.mu.Lock() defer func() { cc.messages = nil + cc.allocatedIDs = 0 cc.mu.Unlock() }() if len(cc.messages) == 0 { @@ -369,3 +373,20 @@ func batch(messages []netlink.Message) []netlink.Message { return batch } + +// allocateTransactionID allocates an identifier which is only valid in the +// current transaction. +func (cc *Conn) allocateTransactionID() uint32 { + if cc.allocatedIDs == math.MaxUint32 { + panic(fmt.Sprintf("trying to allocate more than %d IDs in a single nftables transaction", math.MaxUint32)) + } + // To make it more likely to catch when a transaction ID is erroneously used + // in a later transaction, cc.lastID is not reset after each transaction; + // instead it is only reset once it rolls over from math.MaxUint32 to 0. + cc.allocatedIDs++ + cc.lastID++ + if cc.lastID == 0 { + cc.lastID = 1 + } + return cc.lastID +} diff --git a/set.go b/set.go index cccdcba..431191e 100644 --- a/set.go +++ b/set.go @@ -46,8 +46,6 @@ const ( NFTA_SET_ELEM_EXPRESSIONS = 0x11 ) -var allocSetID uint32 - // SetDatatype represents a datatype declared by nft. type SetDatatype struct { Name string @@ -532,8 +530,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { } if s.ID == 0 { - allocSetID++ - s.ID = allocSetID + s.ID = cc.allocateTransactionID() if s.Anonymous { s.Name = "__set%d" if s.IsMap {