improve reliability

This commit is contained in:
Alexis PIRES 2019-12-27 23:18:57 +01:00
parent 8cb78f7432
commit 5a645a16e0
3 changed files with 73 additions and 20 deletions

20
conn.go
View File

@ -35,7 +35,7 @@ type Conn struct {
NetNS int // Network namespace netlink will interact with. NetNS int // Network namespace netlink will interact with.
sync.Mutex sync.Mutex
messages []netlink.Message messages []netlink.Message
rules []*Rule rules map[int]*Rule
err error err error
} }
@ -61,12 +61,20 @@ func (cc *Conn) Flush() error {
defer conn.Close() defer conn.Close()
if _, err := conn.SendMessages(batch(cc.messages)); err != nil { smsg, err := conn.SendMessages(batch(cc.messages))
if err != nil {
return fmt.Errorf("SendMessages: %w", err) return fmt.Errorf("SendMessages: %w", err)
} }
echoedRules := 0 // Retrieving of seq number associated to rules
rulesBySeq := make(map[uint32]*Rule)
for i, rule := range cc.rules {
rulesBySeq[smsg[i].Header.Sequence] = rule
}
// Search handle in netlink messages based on requests seq
echoedRules := 0
for len(cc.rules) > echoedRules { for len(cc.rules) > echoedRules {
rmsg, err := conn.Receive() rmsg, err := conn.Receive()
@ -75,10 +83,10 @@ func (cc *Conn) Flush() error {
} }
for _, msg := range rmsg { for _, msg := range rmsg {
if msg.Header.Type == ruleHeaderType { if srule, ok := rulesBySeq[msg.Header.Sequence]; ok {
rule, err := ruleFromMsg(msg) rrule, err := ruleFromMsg(msg)
if err == nil { if err == nil {
cc.rules[echoedRules].Handle = rule.Handle srule.Handle = rrule.Handle
echoedRules++ echoedRules++
} }
} }

View File

@ -3996,19 +3996,28 @@ func TestHandleBack(t *testing.T) {
Name: "filter", Name: "filter",
}) })
prerouting := c.AddChain(&nftables.Chain{ chain1 := c.AddChain(&nftables.Chain{
Name: "base-chain", Name: "chain1",
Table: filter, Table: filter,
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting, Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityFilter, Priority: nftables.ChainPriorityFilter,
}) })
var rulesCreated []*nftables.Rule chain2 := c.AddChain(&nftables.Chain{
Name: "chain2",
rulesCreated = append(rulesCreated, c.AddRule(&nftables.Rule{
Table: filter, Table: filter,
Chain: prerouting, Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityFilter,
})
var rulesCreated1 []*nftables.Rule
var rulesCreated2 []*nftables.Rule
rulesCreated1 = append(rulesCreated1, c.AddRule(&nftables.Rule{
Table: filter,
Chain: chain1,
Exprs: []expr.Any{ Exprs: []expr.Any{
&expr.Verdict{ &expr.Verdict{
// [ immediate reg 0 drop ] // [ immediate reg 0 drop ]
@ -4017,9 +4026,9 @@ func TestHandleBack(t *testing.T) {
}, },
})) }))
rulesCreated = append(rulesCreated, c.AddRule(&nftables.Rule{ rulesCreated1 = append(rulesCreated1, c.AddRule(&nftables.Rule{
Table: filter, Table: filter,
Chain: prerouting, Chain: chain1,
Exprs: []expr.Any{ Exprs: []expr.Any{
&expr.Verdict{ &expr.Verdict{
// [ immediate reg 0 drop ] // [ immediate reg 0 drop ]
@ -4028,7 +4037,24 @@ func TestHandleBack(t *testing.T) {
}, },
})) }))
for i, r := range rulesCreated { rulesCreated2 = append(rulesCreated2, c.AddRule(&nftables.Rule{
Table: filter,
Chain: chain2,
Exprs: []expr.Any{
&expr.Verdict{
// [ immediate reg 0 drop ]
Kind: expr.VerdictDrop,
},
},
}))
for i, r := range rulesCreated1 {
if r.Handle != 0 {
t.Fatalf("unexpected handle value at %d", i)
}
}
for i, r := range rulesCreated2 {
if r.Handle != 0 { if r.Handle != 0 {
t.Fatalf("unexpected handle value at %d", i) t.Fatalf("unexpected handle value at %d", i)
} }
@ -4038,18 +4064,33 @@ func TestHandleBack(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
rulesGetted, _ := c.GetRule(filter, prerouting) rulesGetted1, _ := c.GetRule(filter, chain1)
rulesGetted2, _ := c.GetRule(filter, chain2)
if len(rulesGetted) != len(rulesCreated) { if len(rulesGetted1) != len(rulesCreated1) {
t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted), len(rulesCreated)) t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted1), len(rulesCreated1))
} }
for i, r := range rulesGetted { if len(rulesGetted2) != len(rulesCreated2) {
t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted2), len(rulesCreated2))
}
for i, r := range rulesGetted1 {
if r.Handle == 0 { if r.Handle == 0 {
t.Fatalf("handle value is empty at %d", i) t.Fatalf("handle value is empty at %d", i)
} }
if r.Handle != rulesCreated[i].Handle { if r.Handle != rulesCreated1[i].Handle {
t.Fatalf("mismatched handle at %d", i)
}
}
for i, r := range rulesGetted2 {
if r.Handle == 0 {
t.Fatalf("handle value is empty at %d", i)
}
if r.Handle != rulesCreated2[i].Handle {
t.Fatalf("mismatched handle at %d", i) t.Fatalf("mismatched handle at %d", i)
} }
} }

View File

@ -130,7 +130,11 @@ func (cc *Conn) AddRule(r *Rule) *Rule {
Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...),
}) })
cc.rules = append(cc.rules, r) if cc.rules == nil {
cc.rules = make(map[int]*Rule)
}
cc.rules[len(cc.messages)] = r
return r return r
} }