From e2fedeb355b3299704cebbdd191cabed0b7579fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Sch=C3=A4r?= Date: Thu, 13 Mar 2025 10:38:46 +0100 Subject: [PATCH] Improve safety of ID allocation (#307) There was an existing mechanism to allocate IDs for sets, but this was using a global counter without any synchronization to prevent data races. I replaced this by a new mechanism which uses a connection-scoped counter, protected by the Conn.mu Mutex. This can then also be used in other places where IDs need to be allocated. As an additional safeguard, it will panic instead of allocating the same ID twice in a transaction. Most likely, your program will run out of memory before reaching this point. --- conn.go | 33 +++++++++++++++++++++++++++------ set.go | 5 +---- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index fef9c2a..d4759b1 100644 --- a/conn.go +++ b/conn.go @@ -17,6 +17,7 @@ package nftables import ( "errors" "fmt" + "math" "os" "sync" "syscall" @@ -38,12 +39,14 @@ type Conn struct { TestDial nltest.Func // for testing only; passed to nltest.Dial NetNS int // fd referencing the network namespace netlink will interact with. - lasting bool // establish a lasting connection to be used across multiple netlink operations. - mu sync.Mutex // protects the following state - messages []netlink.Message - err error - nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. - sockOptions []SockOption + lasting bool // establish a lasting connection to be used across multiple netlink operations. + mu sync.Mutex // protects the following state + messages []netlink.Message + err error + nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. + sockOptions []SockOption + lastID uint32 + allocatedIDs uint32 } // ConnOption is an option to change the behavior of the nftables Conn returned by Open. @@ -244,6 +247,7 @@ func (cc *Conn) Flush() error { cc.mu.Lock() defer func() { cc.messages = nil + cc.allocatedIDs = 0 cc.mu.Unlock() }() if len(cc.messages) == 0 { @@ -369,3 +373,20 @@ func batch(messages []netlink.Message) []netlink.Message { return batch } + +// allocateTransactionID allocates an identifier which is only valid in the +// current transaction. +func (cc *Conn) allocateTransactionID() uint32 { + if cc.allocatedIDs == math.MaxUint32 { + panic(fmt.Sprintf("trying to allocate more than %d IDs in a single nftables transaction", math.MaxUint32)) + } + // To make it more likely to catch when a transaction ID is erroneously used + // in a later transaction, cc.lastID is not reset after each transaction; + // instead it is only reset once it rolls over from math.MaxUint32 to 0. + cc.allocatedIDs++ + cc.lastID++ + if cc.lastID == 0 { + cc.lastID = 1 + } + return cc.lastID +} diff --git a/set.go b/set.go index cccdcba..431191e 100644 --- a/set.go +++ b/set.go @@ -46,8 +46,6 @@ const ( NFTA_SET_ELEM_EXPRESSIONS = 0x11 ) -var allocSetID uint32 - // SetDatatype represents a datatype declared by nft. type SetDatatype struct { Name string @@ -532,8 +530,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { } if s.ID == 0 { - allocSetID++ - s.ID = allocSetID + s.ID = cc.allocateTransactionID() if s.Anonymous { s.Name = "__set%d" if s.IsMap {