cleaner way to put msg

This commit is contained in:
Alexis PIRES 2020-01-11 21:29:28 +01:00
parent a400c5deff
commit 0d3f3ffbed
1 changed files with 13 additions and 31 deletions

44
conn.go
View File

@ -17,7 +17,6 @@ package nftables
import ( import (
"fmt" "fmt"
"sync" "sync"
"sync/atomic"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/mdlayher/netlink" "github.com/mdlayher/netlink"
@ -41,7 +40,7 @@ type Conn struct {
sync.Mutex sync.Mutex
put sync.Mutex put sync.Mutex
messages []netlink.Message messages []netlink.Message
entities map[int32]Entity entities map[int]Entity
it int32 it int32
err error err error
} }
@ -52,7 +51,6 @@ func (cc *Conn) Flush() error {
defer func() { defer func() {
cc.messages = nil cc.messages = nil
cc.entities = nil cc.entities = nil
cc.it = 0
cc.Unlock() cc.Unlock()
}() }()
if len(cc.messages) == 0 { if len(cc.messages) == 0 {
@ -71,7 +69,7 @@ func (cc *Conn) Flush() error {
cc.endBatch(cc.messages) cc.endBatch(cc.messages)
_, err = conn.SendMessages(cc.messages[:cc.it+1]) _, err = conn.SendMessages(cc.messages)
if err != nil { if err != nil {
return fmt.Errorf("SendMessages: %w", err) return fmt.Errorf("SendMessages: %w", err)
@ -104,36 +102,29 @@ func (cc *Conn) Flush() error {
} }
// PutMessage store netlink message to sent after // PutMessage store netlink message to sent after
func (cc *Conn) PutMessage(msg netlink.Message) int32 { func (cc *Conn) PutMessage(msg netlink.Message) int {
cc.put.Lock() cc.put.Lock()
defer cc.put.Unlock() defer cc.put.Unlock()
if cc.messages == nil { if cc.messages == nil {
cc.messages = make([]netlink.Message, 16) cc.messages = append(cc.messages, netlink.Message{
cc.messages[0] = netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
Flags: netlink.Request, Flags: netlink.Request,
}, },
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
} })
} }
i := atomic.AddInt32(&cc.it, 1) cc.messages = append(cc.messages, msg)
if len(cc.messages) <= int(i) { return len(cc.messages) - 1
cc.messages = resize(cc.messages)
}
cc.messages[i] = msg
return i
} }
// PutEntity store entity to relate to netlink response // PutEntity store entity to relate to netlink response
func (cc *Conn) PutEntity(i int32, e Entity) { func (cc *Conn) PutEntity(i int, e Entity) {
if cc.entities == nil { if cc.entities == nil {
cc.entities = make(map[int32]Entity) cc.entities = make(map[int]Entity)
} }
cc.entities[i] = e cc.entities[i] = e
} }
@ -214,23 +205,14 @@ func (cc *Conn) marshalExpr(e expr.Any) []byte {
func (cc *Conn) endBatch(messages []netlink.Message) { func (cc *Conn) endBatch(messages []netlink.Message) {
i := atomic.AddInt32(&cc.it, 1) cc.put.Lock()
defer cc.put.Unlock()
if len(cc.messages) <= int(i) { cc.messages = append(cc.messages, netlink.Message{
cc.messages = resize(cc.messages)
}
cc.messages[i] = netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
Flags: netlink.Request, Flags: netlink.Request,
}, },
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
} })
}
func resize(messages []netlink.Message) []netlink.Message {
new := make([]netlink.Message, cap(messages)*2)
copy(new, messages)
return new
} }