From c578ee35d6bef766dc023c00cc30e409630f896c Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Wed, 8 Jan 2020 10:52:26 +0100 Subject: [PATCH] improve reliability --- chain.go | 6 ++--- conn.go | 62 +++++++++++++++++++++++++++++++++--------------- nftables_test.go | 8 +++---- obj.go | 2 +- rule.go | 11 ++++----- set.go | 12 +++++----- table.go | 6 ++--- 7 files changed, 64 insertions(+), 43 deletions(-) diff --git a/chain.go b/chain.go index 74caca5..9b77640 100644 --- a/chain.go +++ b/chain.go @@ -123,7 +123,7 @@ func (cc *Conn) AddChain(c *Chain) *Chain { {Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")}, })...) } - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -144,7 +144,7 @@ func (cc *Conn) DelChain(c *Chain) { {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN), Flags: netlink.Request | netlink.Acknowledge, @@ -162,7 +162,7 @@ func (cc *Conn) FlushChain(c *Chain) { {Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge, diff --git a/conn.go b/conn.go index 6011341..67c97a6 100644 --- a/conn.go +++ b/conn.go @@ -17,6 +17,7 @@ package nftables import ( "fmt" "sync" + "sync/atomic" "github.com/google/nftables/expr" "github.com/mdlayher/netlink" @@ -39,7 +40,8 @@ type Conn struct { NetNS int // Network namespace netlink will interact with. sync.Mutex messages []netlink.Message - entities map[int]Entity + entities map[int32]Entity + it int32 err error } @@ -49,6 +51,7 @@ func (cc *Conn) Flush() error { defer func() { cc.messages = nil cc.entities = nil + cc.it = 0 cc.Unlock() }() if len(cc.messages) == 0 { @@ -65,7 +68,9 @@ func (cc *Conn) Flush() error { defer conn.Close() - smsg, err := conn.SendMessages(batch(cc.messages)) + cc.endBatch(cc.messages) + + _, err = conn.SendMessages(cc.messages[:cc.it+1]) if err != nil { return fmt.Errorf("SendMessages: %w", err) @@ -74,7 +79,7 @@ func (cc *Conn) Flush() error { // Retrieving of seq number associated to entities entitiesBySeq := make(map[uint32]Entity) for i, e := range cc.entities { - entitiesBySeq[smsg[i].Header.Sequence] = e + entitiesBySeq[cc.messages[i].Header.Sequence] = e } // Trigger entities callback @@ -97,6 +102,36 @@ func (cc *Conn) Flush() error { return err } +// PutMessage store netlink message to sent after +func (cc *Conn) PutMessage(msg netlink.Message) int32 { + if cc.messages == nil { + cc.messages = make([]netlink.Message, 128) + cc.messages = append(cc.messages, netlink.Message{}) + cc.messages[0] = netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), + Flags: netlink.Request, + }, + Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + } + } + + i := atomic.AddInt32(&cc.it, 1) + + cc.messages = append(cc.messages, netlink.Message{}) + cc.messages[i] = msg + + return i +} + +// PutEntity store entity to relate to netlink response +func (cc *Conn) PutEntity(i int32, e Entity) { + if cc.entities == nil { + cc.entities = make(map[int32]Entity) + } + cc.entities[i] = e +} + func (cc *Conn) checkReceive(c *netlink.Conn) (bool, error) { if cc.TestDial != nil { return false, nil @@ -130,7 +165,7 @@ func (cc *Conn) checkReceive(c *netlink.Conn) (bool, error) { func (cc *Conn) FlushRuleset() { cc.Lock() defer cc.Unlock() - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -171,26 +206,15 @@ func (cc *Conn) marshalExpr(e expr.Any) []byte { return b } -func batch(messages []netlink.Message) []netlink.Message { - batch := []netlink.Message{ - { - Header: netlink.Header{ - Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), - Flags: netlink.Request, - }, - Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), - }, - } +func (cc *Conn) endBatch(messages []netlink.Message) { - batch = append(batch, messages...) + i := atomic.AddInt32(&cc.it, 1) - batch = append(batch, netlink.Message{ + cc.messages[i] = netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Flags: netlink.Request, }, Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), - }) - - return batch + } } diff --git a/nftables_test.go b/nftables_test.go index 8d0e5c4..c90d157 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -4033,13 +4033,13 @@ func TestIntegrationAddRule(t *testing.T) { t.Fatal(err) } + if r.Handle == 0 { + t.Fatalf("handle value is empty at %d", i) + } + rulesGetted, _ := c.GetRule(filter, chain) for i, rg := range rulesGetted { - if r.Handle == 0 { - t.Fatalf("handle value is empty at %d", i) - } - if bytes.Equal(rg.UserData, r.UserData) && rg.Handle != r.Handle { t.Fatalf("mismatched handle at %d-%d, got: %d, want: %d", w, i, r.Handle, rg.Handle) } diff --git a/obj.go b/obj.go index f3627df..d3528f8 100644 --- a/obj.go +++ b/obj.go @@ -43,7 +43,7 @@ func (cc *Conn) AddObj(o Obj) Obj { return nil } - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, diff --git a/rule.go b/rule.go index 2f3deae..9ca4168 100644 --- a/rule.go +++ b/rule.go @@ -122,19 +122,16 @@ func (cc *Conn) AddRule(r *Rule) *Rule { flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND } - cc.messages = append(cc.messages, netlink.Message{ + m := netlink.Message{ Header: netlink.Header{ Type: ruleHeaderType, Flags: flags, }, Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), - }) - - if cc.entities == nil { - cc.entities = make(map[int]Entity) } - cc.entities[len(cc.messages)] = r + i := cc.PutMessage(m) + cc.PutEntity(i, r) return r } @@ -155,7 +152,7 @@ func (cc *Conn) DelRule(r *Rule) error { })...) flags := netlink.Request | netlink.Acknowledge - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: flags, diff --git a/set.go b/set.go index 2b9ee7e..4fa283d 100644 --- a/set.go +++ b/set.go @@ -165,7 +165,7 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { if err != nil { return err } - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -327,7 +327,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x02\x00\x00\x00")}) } - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -342,7 +342,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { if err != nil { return err } - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -362,7 +362,7 @@ func (cc *Conn) DelSet(s *Set) { {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET), Flags: netlink.Request | netlink.Acknowledge, @@ -383,7 +383,7 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { if err != nil { return err } - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -402,7 +402,7 @@ func (cc *Conn) FlushSet(s *Set) { {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge, diff --git a/table.go b/table.go index da0126a..9b47f1f 100644 --- a/table.go +++ b/table.go @@ -53,7 +53,7 @@ func (cc *Conn) DelTable(t *Table) { {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge, @@ -71,7 +71,7 @@ func (cc *Conn) AddTable(t *Table) *Table { {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -89,7 +89,7 @@ func (cc *Conn) FlushTable(t *Table) { data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge,