From 5a645a16e0b978ab100de1b7b7f589431ccaff06 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Fri, 27 Dec 2019 23:18:57 +0100 Subject: [PATCH] improve reliability --- conn.go | 20 ++++++++++----- nftables_test.go | 67 ++++++++++++++++++++++++++++++++++++++---------- rule.go | 6 ++++- 3 files changed, 73 insertions(+), 20 deletions(-) diff --git a/conn.go b/conn.go index 094c17a..e7ba741 100644 --- a/conn.go +++ b/conn.go @@ -35,7 +35,7 @@ type Conn struct { NetNS int // Network namespace netlink will interact with. sync.Mutex messages []netlink.Message - rules []*Rule + rules map[int]*Rule err error } @@ -61,12 +61,20 @@ func (cc *Conn) Flush() error { defer conn.Close() - if _, err := conn.SendMessages(batch(cc.messages)); err != nil { + smsg, err := conn.SendMessages(batch(cc.messages)) + + if err != nil { return fmt.Errorf("SendMessages: %w", err) } - echoedRules := 0 + // Retrieving of seq number associated to rules + rulesBySeq := make(map[uint32]*Rule) + for i, rule := range cc.rules { + rulesBySeq[smsg[i].Header.Sequence] = rule + } + // Search handle in netlink messages based on requests seq + echoedRules := 0 for len(cc.rules) > echoedRules { rmsg, err := conn.Receive() @@ -75,10 +83,10 @@ func (cc *Conn) Flush() error { } for _, msg := range rmsg { - if msg.Header.Type == ruleHeaderType { - rule, err := ruleFromMsg(msg) + if srule, ok := rulesBySeq[msg.Header.Sequence]; ok { + rrule, err := ruleFromMsg(msg) if err == nil { - cc.rules[echoedRules].Handle = rule.Handle + srule.Handle = rrule.Handle echoedRules++ } } diff --git a/nftables_test.go b/nftables_test.go index 19b0f9a..02795b2 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -3996,19 +3996,28 @@ func TestHandleBack(t *testing.T) { Name: "filter", }) - prerouting := c.AddChain(&nftables.Chain{ - Name: "base-chain", + chain1 := c.AddChain(&nftables.Chain{ + Name: "chain1", Table: filter, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookPrerouting, Priority: nftables.ChainPriorityFilter, }) - var rulesCreated []*nftables.Rule + chain2 := c.AddChain(&nftables.Chain{ + Name: "chain2", + Table: filter, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityFilter, + }) - rulesCreated = append(rulesCreated, c.AddRule(&nftables.Rule{ + var rulesCreated1 []*nftables.Rule + var rulesCreated2 []*nftables.Rule + + rulesCreated1 = append(rulesCreated1, c.AddRule(&nftables.Rule{ Table: filter, - Chain: prerouting, + Chain: chain1, Exprs: []expr.Any{ &expr.Verdict{ // [ immediate reg 0 drop ] @@ -4017,9 +4026,9 @@ func TestHandleBack(t *testing.T) { }, })) - rulesCreated = append(rulesCreated, c.AddRule(&nftables.Rule{ + rulesCreated1 = append(rulesCreated1, c.AddRule(&nftables.Rule{ Table: filter, - Chain: prerouting, + Chain: chain1, Exprs: []expr.Any{ &expr.Verdict{ // [ immediate reg 0 drop ] @@ -4028,7 +4037,24 @@ func TestHandleBack(t *testing.T) { }, })) - for i, r := range rulesCreated { + rulesCreated2 = append(rulesCreated2, c.AddRule(&nftables.Rule{ + Table: filter, + Chain: chain2, + Exprs: []expr.Any{ + &expr.Verdict{ + // [ immediate reg 0 drop ] + Kind: expr.VerdictDrop, + }, + }, + })) + + for i, r := range rulesCreated1 { + if r.Handle != 0 { + t.Fatalf("unexpected handle value at %d", i) + } + } + + for i, r := range rulesCreated2 { if r.Handle != 0 { t.Fatalf("unexpected handle value at %d", i) } @@ -4038,18 +4064,33 @@ func TestHandleBack(t *testing.T) { t.Fatal(err) } - rulesGetted, _ := c.GetRule(filter, prerouting) + rulesGetted1, _ := c.GetRule(filter, chain1) + rulesGetted2, _ := c.GetRule(filter, chain2) - if len(rulesGetted) != len(rulesCreated) { - t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted), len(rulesCreated)) + if len(rulesGetted1) != len(rulesCreated1) { + t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted1), len(rulesCreated1)) } - for i, r := range rulesGetted { + if len(rulesGetted2) != len(rulesCreated2) { + t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted2), len(rulesCreated2)) + } + + for i, r := range rulesGetted1 { if r.Handle == 0 { t.Fatalf("handle value is empty at %d", i) } - if r.Handle != rulesCreated[i].Handle { + if r.Handle != rulesCreated1[i].Handle { + t.Fatalf("mismatched handle at %d", i) + } + } + + for i, r := range rulesGetted2 { + if r.Handle == 0 { + t.Fatalf("handle value is empty at %d", i) + } + + if r.Handle != rulesCreated2[i].Handle { t.Fatalf("mismatched handle at %d", i) } } diff --git a/rule.go b/rule.go index 561da54..6a403eb 100644 --- a/rule.go +++ b/rule.go @@ -130,7 +130,11 @@ func (cc *Conn) AddRule(r *Rule) *Rule { Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), }) - cc.rules = append(cc.rules, r) + if cc.rules == nil { + cc.rules = make(map[int]*Rule) + } + + cc.rules[len(cc.messages)] = r return r }