diff --git a/nftables_test.go b/nftables_test.go index df57a2d..568e010 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -561,6 +561,51 @@ func TestAddCounter(t *testing.T) { } } +func TestDelRule(t *testing.T) { + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft delete rule ipv4table ipv4chain-1 handle 9 + []byte("\x02\x00\x00\x00\x0e\x00\x01\x00\x69\x70\x76\x34\x74\x61\x62\x6c\x65\x00\x00\x00\x10\x00\x02\x00\x69\x70\x76\x34\x63\x68\x61\x69\x6e\x2d\x31\x00\x0c\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x09"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + + c := &nftables.Conn{ + TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + want = want[1:] + } + return req, nil + }, + } + + c.DelRule(&nftables.Rule{ + Table: &nftables.Table{Name: "ipv4table", Family: nftables.TableFamilyIPv4}, + Chain: &nftables.Chain{Name: "ipv4chain-1", Type: nftables.ChainTypeFilter}, + Handle: uint64(9), + }) + + if err := c.Flush(); err != nil { + t.Fatal(err) + } +} + func TestTProxy(t *testing.T) { want := [][]byte{ // batch begin diff --git a/rule.go b/rule.go index 540d26f..297db4e 100644 --- a/rule.go +++ b/rule.go @@ -170,6 +170,31 @@ func (cc *Conn) AddRule(r *Rule) *Rule { return r } +// DelRule deletes the specified Rule, rule's handle cannot be 0 +func (cc *Conn) DelRule(r *Rule) error { + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(r.Table.Name + "\x00")}, + {Type: unix.NFTA_RULE_CHAIN, Data: []byte(r.Chain.Name + "\x00")}, + }) + if r.Handle == 0 { + return fmt.Errorf("rule's handle cannot be 0") + } + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(r.Handle)}, + })...) + flags := netlink.Request | netlink.Acknowledge + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), + Flags: flags, + }, + Data: append(extraHeader(uint8(r.Table.Family), 0), data...), + }) + + return nil +} + func exprsFromMsg(b []byte) ([]expr.Any, error) { ad, err := netlink.NewAttributeDecoder(b) if err != nil {