diff --git a/chain.go b/chain.go index 9b77640..48ebadf 100644 --- a/chain.go +++ b/chain.go @@ -123,7 +123,7 @@ func (cc *Conn) AddChain(c *Chain) *Chain { {Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")}, })...) } - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -144,7 +144,7 @@ func (cc *Conn) DelChain(c *Chain) { {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN), Flags: netlink.Request | netlink.Acknowledge, @@ -162,7 +162,7 @@ func (cc *Conn) FlushChain(c *Chain) { {Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge, diff --git a/conn.go b/conn.go index 34cf0b9..533336d 100644 --- a/conn.go +++ b/conn.go @@ -35,14 +35,13 @@ type Entity interface { // // Commands are buffered. Flush sends all buffered commands in a single batch. type Conn struct { - TestDial nltest.Func // for testing only; passed to nltest.Dial - NetNS int // Network namespace netlink will interact with. sync.Mutex - put sync.Mutex - messages []netlink.Message - entities map[int]Entity - it int32 - err error + TestDial nltest.Func // for testing only; passed to nltest.Dial + NetNS int // Network namespace netlink will interact with. + entities map[int]Entity + messagesMu sync.Mutex + messages []netlink.Message + err error } // Flush sends all buffered commands in a single batch to nftables. @@ -69,9 +68,7 @@ func (cc *Conn) Flush() error { cc.endBatch(cc.messages) - _, err = conn.SendMessages(cc.messages) - - if err != nil { + if _, err = conn.SendMessages(cc.messages); err != nil { return fmt.Errorf("SendMessages: %w", err) } @@ -83,9 +80,12 @@ func (cc *Conn) Flush() error { // Trigger entities callback msg, err := cc.checkReceive(conn) + if err != nil { + return err + } + for msg { rmsg, err := conn.Receive() - if err != nil { return fmt.Errorf("Receive: %w", err) } @@ -93,18 +93,22 @@ func (cc *Conn) Flush() error { for _, msg := range rmsg { if e, ok := entitiesBySeq[msg.Header.Sequence]; ok { e.HandleResponse(msg) + } } msg, err = cc.checkReceive(conn) + if err != nil { + return err + } } return err } -// PutMessage store netlink message to sent after -func (cc *Conn) PutMessage(msg netlink.Message) int { - cc.put.Lock() - defer cc.put.Unlock() +// putMessage store netlink message to sent after +func (cc *Conn) putMessage(msg netlink.Message) int { + cc.messagesMu.Lock() + defer cc.messagesMu.Unlock() if cc.messages == nil { cc.messages = append(cc.messages, netlink.Message{ @@ -162,7 +166,7 @@ func (cc *Conn) checkReceive(c *netlink.Conn) (bool, error) { func (cc *Conn) FlushRuleset() { cc.Lock() defer cc.Unlock() - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -205,8 +209,8 @@ func (cc *Conn) marshalExpr(e expr.Any) []byte { func (cc *Conn) endBatch(messages []netlink.Message) { - cc.put.Lock() - defer cc.put.Unlock() + cc.messagesMu.Lock() + defer cc.messagesMu.Unlock() cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ diff --git a/nftables_test.go b/nftables_test.go index c90d157..bc966d6 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -4008,7 +4008,6 @@ func TestIntegrationAddRule(t *testing.T) { c.Flush() execN := func(w int, n int) { - c := &nftables.Conn{NetNS: int(newNS)} for i := 0; i < n; i++ { diff --git a/obj.go b/obj.go index d3528f8..99d51e0 100644 --- a/obj.go +++ b/obj.go @@ -43,7 +43,7 @@ func (cc *Conn) AddObj(o Obj) Obj { return nil } - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, diff --git a/rule.go b/rule.go index 9ca4168..d878b5e 100644 --- a/rule.go +++ b/rule.go @@ -130,7 +130,7 @@ func (cc *Conn) AddRule(r *Rule) *Rule { Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), } - i := cc.PutMessage(m) + i := cc.putMessage(m) cc.PutEntity(i, r) return r @@ -152,7 +152,7 @@ func (cc *Conn) DelRule(r *Rule) error { })...) flags := netlink.Request | netlink.Acknowledge - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: flags, @@ -166,10 +166,11 @@ func (cc *Conn) DelRule(r *Rule) error { // HandleResponse retrieves Handle in netlink response func (r *Rule) HandleResponse(msg netlink.Message) { rule, err := ruleFromMsg(msg) - - if err == nil { - r.Handle = rule.Handle + if err != nil { + return } + + r.Handle = rule.Handle } func exprsFromMsg(b []byte) ([]expr.Any, error) { diff --git a/set.go b/set.go index 4fa283d..f45e0be 100644 --- a/set.go +++ b/set.go @@ -165,7 +165,7 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { if err != nil { return err } - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -327,7 +327,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x02\x00\x00\x00")}) } - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -342,7 +342,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { if err != nil { return err } - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -362,7 +362,7 @@ func (cc *Conn) DelSet(s *Set) { {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET), Flags: netlink.Request | netlink.Acknowledge, @@ -383,7 +383,7 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { if err != nil { return err } - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -402,7 +402,7 @@ func (cc *Conn) FlushSet(s *Set) { {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge, diff --git a/table.go b/table.go index 9b47f1f..08c83f7 100644 --- a/table.go +++ b/table.go @@ -53,7 +53,7 @@ func (cc *Conn) DelTable(t *Table) { {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge, @@ -71,7 +71,7 @@ func (cc *Conn) AddTable(t *Table) *Table { {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -89,7 +89,7 @@ func (cc *Conn) FlushTable(t *Table) { data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge,