diff --git a/nftables_test.go b/nftables_test.go index fe12566..867e529 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -128,6 +128,48 @@ func ifname(n string) []byte { return b } +func TestTableCreateDestroy(t *testing.T) { + c, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + defer c.FlushRuleset() + + filter := &nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + } + c.DelTable(filter, true) + c.AddTable(filter) + err := c.Flush() + if err != nil { + t.Fatalf("on Flush: %q", err.Error()) + } + + LookupMyTable := func() bool { + ts, err := c.ListTables() + if err != nil { + t.Fatalf("on ListTables: %q", err.Error()) + } + return slices.ContainsFunc(ts, func(t *nftables.Table) bool { + return t.Name == filter.Name && t.Family == filter.Family + }) + } + if !LookupMyTable() { + t.Fatal("AddTable doesn't create my table!") + } + + c.DelTable(filter) + err = c.Flush() + if err != nil { + t.Fatalf("on Flush: %q", err.Error()) + } + + if LookupMyTable() { + t.Fatal("DelTable doesn't delete my table!") + } + + c.DelTable(filter, true) // just for test that 'force' ignore error 'not found' +} + func TestRuleOperations(t *testing.T) { // Create a new network namespace to test these operations, // and tear down the namespace at test completion. diff --git a/table.go b/table.go index 3686b7a..43390cd 100644 --- a/table.go +++ b/table.go @@ -16,6 +16,7 @@ package nftables import ( "fmt" + "slices" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" @@ -24,6 +25,10 @@ import ( const ( newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE) delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE) + + // FIXME: in sys@v0.34.0 no unix.NFT_MSG_DESTROYTABLE const yet. + // See nf_tables_msg_types enum in https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h + destroyTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | 0x1a) ) // TableFamily specifies the address family for this table. @@ -50,16 +55,23 @@ type Table struct { } // DelTable deletes a specific table, along with all chains/rules it contains. -func (cc *Conn) DelTable(t *Table) { +func (cc *Conn) DelTable(t *Table, force ...bool) { cc.mu.Lock() defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) + + var hdrType netlink.HeaderType + if slices.Contains(force, true) { + hdrType = destroyTableHeaderType + } else { + hdrType = delTableHeaderType + } cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), + Type: hdrType, Flags: netlink.Request | netlink.Acknowledge, }, Data: append(extraHeader(uint8(t.Family), 0), data...),