From bc0ae9ae6fdff70e0f901b36608e6e144137a22e Mon Sep 17 00:00:00 2001 From: RandolphCYG Date: Mon, 22 Nov 2021 15:50:21 +0800 Subject: [PATCH] [FIX]fix GetRule method https://github.com/google/nftables/issues/114 && add GetTable GetChain method --- chain.go | 35 +++++++++++++++++++++++++++++++++++ nftables_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ rule.go | 1 + table.go | 36 ++++++++++++++++++++++++++++++++++++ 4 files changed, 120 insertions(+) diff --git a/chain.go b/chain.go index 74caca5..c01161e 100644 --- a/chain.go +++ b/chain.go @@ -205,6 +205,41 @@ func (cc *Conn) ListChains() ([]*Chain, error) { return chains, nil } +// GetChain gets a chain by name +func (cc *Conn) GetChain(name string) (*Chain, error) { + conn, err := cc.dialNetlink() + if err != nil { + return nil, err + } + defer conn.Close() + + msg := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN), + Flags: netlink.Request | netlink.Dump, + }, + Data: extraHeader(uint8(unix.AF_UNSPEC), 0), + } + + response, err := conn.Execute(msg) + if err != nil { + return nil, err + } + + for _, m := range response { + c, err := chainFromMsg(m) + if err != nil { + return nil, err + } + if c.Name == name { + return c, nil + } + + } + + return nil, nil +} + func chainFromMsg(msg netlink.Message) (*Chain, error) { chainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN) if got, want := msg.Header.Type, chainHeaderType; got != want { diff --git a/nftables_test.go b/nftables_test.go index 5986022..e1ba6d5 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -38,6 +38,54 @@ var ( enableSysTests = flag.Bool("run_system_tests", false, "Run tests that operate against the live kernel") ) +// Get table by net family and its name +func TestGetTable(t *testing.T) { + conn := nftables.Conn{} // start up a conn + + table, _ := conn.GetTable("nat", nftables.TableFamilyIPv4) + fmt.Println(table.Name, table.Family) + + table2, _ := conn.GetTable("filter", nftables.TableFamilyIPv4) + fmt.Println(table2.Name, table.Family) +} + +// Get chain by chain's name +func TestGetChain(t *testing.T) { + conn := nftables.Conn{} // start up a conn + chain, _ := conn.GetChain("POSTROUTING") // get chain + fmt.Println(chain.Name) +} + +// Get set and set's elements by table and set's name +func TestGetSet(t *testing.T) { + conn := nftables.Conn{} // start up a conn + + table, _ := conn.GetTable("nat", nftables.TableFamilyIPv4) // get table + + set, _ := conn.GetSetByName(table, "dest_addrs") // get set + fmt.Println(set.Name) + + eles, _ := conn.GetSetElements(set) + fmt.Println(eles) +} + +// Get rules by table and chain +func TestGetRules(t *testing.T) { + conn := nftables.Conn{} // start up a conn + + table, _ := conn.GetTable("nat", nftables.TableFamilyIPv4) // get table + chain, _ := conn.GetChain("POSTROUTING") // get chain + + rules, _ := conn.GetRule(table, chain) // get rules + for _, rule := range rules { + fmt.Println(rule.Table.Name, rule.Table.Family, rule.Chain.Name, rule.Handle) + // unpack exprs + //for _, expr := range rule.Exprs { + // fmt.Println(expr) + //} + } +} + // nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing // users to make sense of large byte literals more easily. func nfdump(b []byte) string { diff --git a/rule.go b/rule.go index ec4ce1f..3eef99f 100644 --- a/rule.go +++ b/rule.go @@ -293,6 +293,7 @@ func ruleFromMsg(msg netlink.Message) (*Rule, error) { switch ad.Type() { case unix.NFTA_RULE_TABLE: r.Table = &Table{Name: ad.String()} + r.Table.Family = TableFamily(msg.Data[0]) case unix.NFTA_RULE_CHAIN: r.Chain = &Chain{Name: ad.String()} case unix.NFTA_RULE_EXPRESSIONS: diff --git a/table.go b/table.go index da0126a..9f406de 100644 --- a/table.go +++ b/table.go @@ -45,6 +45,42 @@ type Table struct { Family TableFamily } +// GetTable gets a table by name and family +func (cc *Conn) GetTable(name string, family TableFamily) (*Table, error) { + conn, err := cc.dialNetlink() + if err != nil { + return nil, err + } + + defer conn.Close() + + msg := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETTABLE), + Flags: netlink.Request | netlink.Dump, + }, + Data: extraHeader(uint8(unix.AF_UNSPEC), 0), + } + + response, err := conn.Execute(msg) + if err != nil { + return nil, err + } + + for _, m := range response { + t, err := tableFromMsg(m) + if err != nil { + return nil, err + } + + if t.Name == name && t.Family == family { + return t, nil + } + } + + return nil, nil +} + // DelTable deletes a specific table, along with all chains/rules it contains. func (cc *Conn) DelTable(t *Table) { cc.Lock()