improve reliability

This commit is contained in:
Alexis PIRES 2020-01-08 10:52:26 +01:00
parent 3aaad4cf4c
commit c578ee35d6
7 changed files with 64 additions and 43 deletions

View File

@ -123,7 +123,7 @@ func (cc *Conn) AddChain(c *Chain) *Chain {
{Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")},
})...)
}
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -144,7 +144,7 @@ func (cc *Conn) DelChain(c *Chain) {
{Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")},
})
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN),
Flags: netlink.Request | netlink.Acknowledge,
@ -162,7 +162,7 @@ func (cc *Conn) FlushChain(c *Chain) {
{Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")},
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")},
})
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE),
Flags: netlink.Request | netlink.Acknowledge,

62
conn.go
View File

@ -17,6 +17,7 @@ package nftables
import (
"fmt"
"sync"
"sync/atomic"
"github.com/google/nftables/expr"
"github.com/mdlayher/netlink"
@ -39,7 +40,8 @@ type Conn struct {
NetNS int // Network namespace netlink will interact with.
sync.Mutex
messages []netlink.Message
entities map[int]Entity
entities map[int32]Entity
it int32
err error
}
@ -49,6 +51,7 @@ func (cc *Conn) Flush() error {
defer func() {
cc.messages = nil
cc.entities = nil
cc.it = 0
cc.Unlock()
}()
if len(cc.messages) == 0 {
@ -65,7 +68,9 @@ func (cc *Conn) Flush() error {
defer conn.Close()
smsg, err := conn.SendMessages(batch(cc.messages))
cc.endBatch(cc.messages)
_, err = conn.SendMessages(cc.messages[:cc.it+1])
if err != nil {
return fmt.Errorf("SendMessages: %w", err)
@ -74,7 +79,7 @@ func (cc *Conn) Flush() error {
// Retrieving of seq number associated to entities
entitiesBySeq := make(map[uint32]Entity)
for i, e := range cc.entities {
entitiesBySeq[smsg[i].Header.Sequence] = e
entitiesBySeq[cc.messages[i].Header.Sequence] = e
}
// Trigger entities callback
@ -97,6 +102,36 @@ func (cc *Conn) Flush() error {
return err
}
// PutMessage store netlink message to sent after
func (cc *Conn) PutMessage(msg netlink.Message) int32 {
if cc.messages == nil {
cc.messages = make([]netlink.Message, 128)
cc.messages = append(cc.messages, netlink.Message{})
cc.messages[0] = netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
Flags: netlink.Request,
},
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
}
}
i := atomic.AddInt32(&cc.it, 1)
cc.messages = append(cc.messages, netlink.Message{})
cc.messages[i] = msg
return i
}
// PutEntity store entity to relate to netlink response
func (cc *Conn) PutEntity(i int32, e Entity) {
if cc.entities == nil {
cc.entities = make(map[int32]Entity)
}
cc.entities[i] = e
}
func (cc *Conn) checkReceive(c *netlink.Conn) (bool, error) {
if cc.TestDial != nil {
return false, nil
@ -130,7 +165,7 @@ func (cc *Conn) checkReceive(c *netlink.Conn) (bool, error) {
func (cc *Conn) FlushRuleset() {
cc.Lock()
defer cc.Unlock()
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -171,26 +206,15 @@ func (cc *Conn) marshalExpr(e expr.Any) []byte {
return b
}
func batch(messages []netlink.Message) []netlink.Message {
batch := []netlink.Message{
{
Header: netlink.Header{
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
Flags: netlink.Request,
},
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
},
}
func (cc *Conn) endBatch(messages []netlink.Message) {
batch = append(batch, messages...)
i := atomic.AddInt32(&cc.it, 1)
batch = append(batch, netlink.Message{
cc.messages[i] = netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
Flags: netlink.Request,
},
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
})
return batch
}
}

View File

@ -4033,13 +4033,13 @@ func TestIntegrationAddRule(t *testing.T) {
t.Fatal(err)
}
if r.Handle == 0 {
t.Fatalf("handle value is empty at %d", i)
}
rulesGetted, _ := c.GetRule(filter, chain)
for i, rg := range rulesGetted {
if r.Handle == 0 {
t.Fatalf("handle value is empty at %d", i)
}
if bytes.Equal(rg.UserData, r.UserData) && rg.Handle != r.Handle {
t.Fatalf("mismatched handle at %d-%d, got: %d, want: %d", w, i, r.Handle, rg.Handle)
}

2
obj.go
View File

@ -43,7 +43,7 @@ func (cc *Conn) AddObj(o Obj) Obj {
return nil
}
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,

11
rule.go
View File

@ -122,19 +122,16 @@ func (cc *Conn) AddRule(r *Rule) *Rule {
flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND
}
cc.messages = append(cc.messages, netlink.Message{
m := netlink.Message{
Header: netlink.Header{
Type: ruleHeaderType,
Flags: flags,
},
Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...),
})
if cc.entities == nil {
cc.entities = make(map[int]Entity)
}
cc.entities[len(cc.messages)] = r
i := cc.PutMessage(m)
cc.PutEntity(i, r)
return r
}
@ -155,7 +152,7 @@ func (cc *Conn) DelRule(r *Rule) error {
})...)
flags := netlink.Request | netlink.Acknowledge
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE),
Flags: flags,

12
set.go
View File

@ -165,7 +165,7 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error {
if err != nil {
return err
}
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -327,7 +327,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x02\x00\x00\x00")})
}
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -342,7 +342,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
if err != nil {
return err
}
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -362,7 +362,7 @@ func (cc *Conn) DelSet(s *Set) {
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
})
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET),
Flags: netlink.Request | netlink.Acknowledge,
@ -383,7 +383,7 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
if err != nil {
return err
}
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -402,7 +402,7 @@ func (cc *Conn) FlushSet(s *Set) {
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
})
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM),
Flags: netlink.Request | netlink.Acknowledge,

View File

@ -53,7 +53,7 @@ func (cc *Conn) DelTable(t *Table) {
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
})
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
Flags: netlink.Request | netlink.Acknowledge,
@ -71,7 +71,7 @@ func (cc *Conn) AddTable(t *Table) *Table {
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
})
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -89,7 +89,7 @@ func (cc *Conn) FlushTable(t *Table) {
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
})
cc.messages = append(cc.messages, netlink.Message{
cc.PutMessage(netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE),
Flags: netlink.Request | netlink.Acknowledge,