From 5e242ec5780646a4bf8b60d4651e25fff33a081c Mon Sep 17 00:00:00 2001 From: turekt <32360115+turekt@users.noreply.github.com> Date: Sun, 14 Apr 2024 09:19:27 +0000 Subject: [PATCH] List table or chain by name (#258) Adds functionality to list table or chain by specifying its name --- chain.go | 37 ++++++++++++ nftables_test.go | 154 +++++++++++++++++++++++++++++++++++++++++++++++ table.go | 36 ++++++++++- 3 files changed, 225 insertions(+), 2 deletions(-) diff --git a/chain.go b/chain.go index 8d797be..4f4c0a5 100644 --- a/chain.go +++ b/chain.go @@ -193,6 +193,43 @@ func (cc *Conn) ListChains() ([]*Chain, error) { return cc.ListChainsOfTableFamily(TableFamilyUnspecified) } +// ListChain returns a single chain configured in the specified table +func (cc *Conn) ListChain(table *Table, chain string) (*Chain, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + attrs := []netlink.Attribute{ + {Type: unix.NFTA_TABLE_NAME, Data: []byte(table.Name + "\x00")}, + {Type: unix.NFTA_CHAIN_NAME, Data: []byte(chain + "\x00")}, + } + msg := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN), + Flags: netlink.Request, + }, + Data: append(extraHeader(uint8(table.Family), 0), cc.marshalAttr(attrs)...), + } + + response, err := conn.Execute(msg) + if err != nil { + return nil, fmt.Errorf("conn.Execute failed: %v", err) + } + + if got, want := len(response), 1; got != want { + return nil, fmt.Errorf("expected %d response message for chain, got %d", want, got) + } + + ch, err := chainFromMsg(response[0]) + if err != nil { + return nil, err + } + + return ch, nil +} + // ListChainsOfTableFamily returns currently configured chains for the specified // family in the kernel. It lists all chains ins all tables if family is // TableFamilyUnspecified. diff --git a/nftables_test.go b/nftables_test.go index f972188..be8b83b 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1746,6 +1746,160 @@ func TestListChains(t *testing.T) { } } +func TestListChainByName(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := &nftables.Table{ + Name: "chain_test", + Family: nftables.TableFamilyIPv4, + } + tr := conn.AddTable(table) + + c := &nftables.Chain{ + Name: "filter", + Table: table, + } + conn.AddChain(c) + + if err := conn.Flush(); err != nil { + t.Errorf("conn.Flush() failed: %v", err) + } + + cr, err := conn.ListChain(tr, c.Name) + if err != nil { + t.Fatalf("conn.ListChain() failed: %v", err) + } + + if got, want := cr.Name, c.Name; got != want { + t.Fatalf("got chain %s, want chain %s", got, want) + } + + if got, want := cr.Table.Name, table.Name; got != want { + t.Fatalf("got chain table %s, want chain table %s", got, want) + } +} + +func TestListChainByNameUsingLasting(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + conn, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting()) + if err != nil { + t.Fatalf("nftables.New() failed: %v", err) + } + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := &nftables.Table{ + Name: "chain_test_lasting", + Family: nftables.TableFamilyIPv4, + } + tr := conn.AddTable(table) + + c := &nftables.Chain{ + Name: "filter_lasting", + Table: table, + } + conn.AddChain(c) + + if err := conn.Flush(); err != nil { + t.Errorf("conn.Flush() failed: %v", err) + } + + cr, err := conn.ListChain(tr, c.Name) + if err != nil { + t.Fatalf("conn.ListChain() failed: %v", err) + } + + if got, want := cr.Name, c.Name; got != want { + t.Fatalf("got chain %s, want chain %s", got, want) + } + + if got, want := cr.Table.Name, table.Name; got != want { + t.Fatalf("got chain table %s, want chain table %s", got, want) + } +} + +func TestListTableByName(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table1 := &nftables.Table{ + Name: "table_test", + Family: nftables.TableFamilyIPv4, + } + conn.AddTable(table1) + table2 := &nftables.Table{ + Name: "table_test_inet", + Family: nftables.TableFamilyINet, + } + conn.AddTable(table2) + table3 := &nftables.Table{ + Name: table1.Name, + Family: nftables.TableFamilyINet, + } + conn.AddTable(table3) + + if err := conn.Flush(); err != nil { + t.Errorf("conn.Flush() failed: %v", err) + } + + tr, err := conn.ListTable(table1.Name) + if err != nil { + t.Fatalf("conn.ListTable() failed: %v", err) + } + + if got, want := tr.Name, table1.Name; got != want { + t.Fatalf("got table %s, want table %s", got, want) + } + + // not specifying table family should return family ipv4 + tr, err = conn.ListTable(table3.Name) + if err != nil { + t.Fatalf("conn.ListTable() failed: %v", err) + } + if got, want := tr.Name, table1.Name; got != want { + t.Fatalf("got table %s, want table %s", got, want) + } + if got, want := tr.Family, table1.Family; got != want { + t.Fatalf("got table family %v, want table family %v", got, want) + } + + // specifying correct INet family + tr, err = conn.ListTableOfFamily(table3.Name, nftables.TableFamilyINet) + if err != nil { + t.Fatalf("conn.ListTable() failed: %v", err) + } + if got, want := tr.Name, table3.Name; got != want { + t.Fatalf("got table %s, want table %s", got, want) + } + if got, want := tr.Family, table3.Family; got != want { + t.Fatalf("got table family %v, want table family %v", got, want) + } + + // not specifying correct family should return err since no table in ipv4 + tr, err = conn.ListTable(table2.Name) + if err == nil { + t.Fatalf("conn.ListTable() should have failed") + } + + // specifying correct INet family + tr, err = conn.ListTableOfFamily(table2.Name, nftables.TableFamilyINet) + if err != nil { + t.Fatalf("conn.ListTable() failed: %v", err) + } + if got, want := tr.Name, table2.Name; got != want { + t.Fatalf("got table %s, want table %s", got, want) + } + if got, want := tr.Family, table2.Family; got != want { + t.Fatalf("got table family %v, want table family %v", got, want) + } +} + func TestAddChain(t *testing.T) { tests := []struct { name string diff --git a/table.go b/table.go index ff3b592..c391b7b 100644 --- a/table.go +++ b/table.go @@ -112,6 +112,25 @@ func (cc *Conn) FlushTable(t *Table) { }) } +// ListTable returns table found for the specified name. Searches for +// the table under IPv4 family. As per nft man page: "When no address +// family is specified, ip is used by default." +func (cc *Conn) ListTable(name string) (*Table, error) { + return cc.ListTableOfFamily(name, TableFamilyIPv4) +} + +// ListTableOfFamily returns table found for the specified name and table family +func (cc *Conn) ListTableOfFamily(name string, family TableFamily) (*Table, error) { + t, err := cc.listTablesOfNameAndFamily(name, family) + if err != nil { + return nil, err + } + if got, want := len(t), 1; got != want { + return nil, fmt.Errorf("expected table count %d, got %d", want, got) + } + return t[0], nil +} + // ListTables returns currently configured tables in the kernel func (cc *Conn) ListTables() ([]*Table, error) { return cc.ListTablesOfFamily(TableFamilyUnspecified) @@ -120,18 +139,31 @@ func (cc *Conn) ListTables() ([]*Table, error) { // ListTablesOfFamily returns currently configured tables for the specified table family // in the kernel. It lists all tables if family is TableFamilyUnspecified. func (cc *Conn) ListTablesOfFamily(family TableFamily) ([]*Table, error) { + return cc.listTablesOfNameAndFamily("", family) +} + +func (cc *Conn) listTablesOfNameAndFamily(name string, family TableFamily) ([]*Table, error) { conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } defer func() { _ = closer() }() + data := extraHeader(uint8(family), 0) + flags := netlink.Request | netlink.Dump + if name != "" { + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_TABLE_NAME, Data: []byte(name + "\x00")}, + })...) + flags = netlink.Request + } + msg := netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETTABLE), - Flags: netlink.Request | netlink.Dump, + Flags: flags, }, - Data: extraHeader(uint8(family), 0), + Data: data, } response, err := conn.Execute(msg)