diff --git a/conn.go b/conn.go index 3768645..094c17a 100644 --- a/conn.go +++ b/conn.go @@ -35,6 +35,7 @@ type Conn struct { NetNS int // Network namespace netlink will interact with. sync.Mutex messages []netlink.Message + rules []*Rule err error } @@ -43,6 +44,7 @@ func (cc *Conn) Flush() error { cc.Lock() defer func() { cc.messages = nil + cc.rules = nil cc.Unlock() }() if len(cc.messages) == 0 { @@ -63,8 +65,25 @@ func (cc *Conn) Flush() error { return fmt.Errorf("SendMessages: %w", err) } - if _, err := conn.Receive(); err != nil { - return fmt.Errorf("Receive: %w", err) + echoedRules := 0 + + for len(cc.rules) > echoedRules { + rmsg, err := conn.Receive() + + if err != nil { + return fmt.Errorf("Receive: %w", err) + } + + for _, msg := range rmsg { + if msg.Header.Type == ruleHeaderType { + rule, err := ruleFromMsg(msg) + if err == nil { + cc.rules[echoedRules].Handle = rule.Handle + echoedRules++ + } + } + } + } return nil diff --git a/nftables_test.go b/nftables_test.go index e3337ef..19b0f9a 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -3980,3 +3980,77 @@ func TestStatelessNAT(t *testing.T) { t.Fatal(err) } } + +func TestHandleBack(t *testing.T) { + + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + prerouting := c.AddChain(&nftables.Chain{ + Name: "base-chain", + Table: filter, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityFilter, + }) + + var rulesCreated []*nftables.Rule + + rulesCreated = append(rulesCreated, c.AddRule(&nftables.Rule{ + Table: filter, + Chain: prerouting, + Exprs: []expr.Any{ + &expr.Verdict{ + // [ immediate reg 0 drop ] + Kind: expr.VerdictDrop, + }, + }, + })) + + rulesCreated = append(rulesCreated, c.AddRule(&nftables.Rule{ + Table: filter, + Chain: prerouting, + Exprs: []expr.Any{ + &expr.Verdict{ + // [ immediate reg 0 drop ] + Kind: expr.VerdictDrop, + }, + }, + })) + + for i, r := range rulesCreated { + if r.Handle != 0 { + t.Fatalf("unexpected handle value at %d", i) + } + } + + if err := c.Flush(); err != nil { + t.Fatal(err) + } + + rulesGetted, _ := c.GetRule(filter, prerouting) + + if len(rulesGetted) != len(rulesCreated) { + t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted), len(rulesCreated)) + } + + for i, r := range rulesGetted { + if r.Handle == 0 { + t.Fatalf("handle value is empty at %d", i) + } + + if r.Handle != rulesCreated[i].Handle { + t.Fatalf("mismatched handle at %d", i) + } + } +} diff --git a/rule.go b/rule.go index 48d79d1..561da54 100644 --- a/rule.go +++ b/rule.go @@ -130,6 +130,8 @@ func (cc *Conn) AddRule(r *Rule) *Rule { Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), }) + cc.rules = append(cc.rules, r) + return r }