diff --git a/conn.go b/conn.go index 6a552d8..34cf0b9 100644 --- a/conn.go +++ b/conn.go @@ -17,7 +17,6 @@ package nftables import ( "fmt" "sync" - "sync/atomic" "github.com/google/nftables/expr" "github.com/mdlayher/netlink" @@ -41,7 +40,7 @@ type Conn struct { sync.Mutex put sync.Mutex messages []netlink.Message - entities map[int32]Entity + entities map[int]Entity it int32 err error } @@ -52,7 +51,6 @@ func (cc *Conn) Flush() error { defer func() { cc.messages = nil cc.entities = nil - cc.it = 0 cc.Unlock() }() if len(cc.messages) == 0 { @@ -71,7 +69,7 @@ func (cc *Conn) Flush() error { cc.endBatch(cc.messages) - _, err = conn.SendMessages(cc.messages[:cc.it+1]) + _, err = conn.SendMessages(cc.messages) if err != nil { return fmt.Errorf("SendMessages: %w", err) @@ -104,36 +102,29 @@ func (cc *Conn) Flush() error { } // PutMessage store netlink message to sent after -func (cc *Conn) PutMessage(msg netlink.Message) int32 { +func (cc *Conn) PutMessage(msg netlink.Message) int { cc.put.Lock() defer cc.put.Unlock() if cc.messages == nil { - cc.messages = make([]netlink.Message, 16) - cc.messages[0] = netlink.Message{ + cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), Flags: netlink.Request, }, Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), - } + }) } - i := atomic.AddInt32(&cc.it, 1) + cc.messages = append(cc.messages, msg) - if len(cc.messages) <= int(i) { - cc.messages = resize(cc.messages) - } - - cc.messages[i] = msg - - return i + return len(cc.messages) - 1 } // PutEntity store entity to relate to netlink response -func (cc *Conn) PutEntity(i int32, e Entity) { +func (cc *Conn) PutEntity(i int, e Entity) { if cc.entities == nil { - cc.entities = make(map[int32]Entity) + cc.entities = make(map[int]Entity) } cc.entities[i] = e } @@ -214,23 +205,14 @@ func (cc *Conn) marshalExpr(e expr.Any) []byte { func (cc *Conn) endBatch(messages []netlink.Message) { - i := atomic.AddInt32(&cc.it, 1) + cc.put.Lock() + defer cc.put.Unlock() - if len(cc.messages) <= int(i) { - cc.messages = resize(cc.messages) - } - - cc.messages[i] = netlink.Message{ + cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Flags: netlink.Request, }, Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), - } -} - -func resize(messages []netlink.Message) []netlink.Message { - new := make([]netlink.Message, cap(messages)*2) - copy(new, messages) - return new + }) }