[FIX]fix GetRule method https://github.com/google/nftables/issues/114 && add GetTable GetChain method

This commit is contained in:
RandolphCYG 2021-11-22 15:50:21 +08:00
parent 16a134723a
commit bc0ae9ae6f
4 changed files with 120 additions and 0 deletions

View File

@ -205,6 +205,41 @@ func (cc *Conn) ListChains() ([]*Chain, error) {
return chains, nil
}
// GetChain gets a chain by name
func (cc *Conn) GetChain(name string) (*Chain, error) {
conn, err := cc.dialNetlink()
if err != nil {
return nil, err
}
defer conn.Close()
msg := netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN),
Flags: netlink.Request | netlink.Dump,
},
Data: extraHeader(uint8(unix.AF_UNSPEC), 0),
}
response, err := conn.Execute(msg)
if err != nil {
return nil, err
}
for _, m := range response {
c, err := chainFromMsg(m)
if err != nil {
return nil, err
}
if c.Name == name {
return c, nil
}
}
return nil, nil
}
func chainFromMsg(msg netlink.Message) (*Chain, error) {
chainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN)
if got, want := msg.Header.Type, chainHeaderType; got != want {

View File

@ -38,6 +38,54 @@ var (
enableSysTests = flag.Bool("run_system_tests", false, "Run tests that operate against the live kernel")
)
// Get table by net family and its name
func TestGetTable(t *testing.T) {
conn := nftables.Conn{} // start up a conn
table, _ := conn.GetTable("nat", nftables.TableFamilyIPv4)
fmt.Println(table.Name, table.Family)
table2, _ := conn.GetTable("filter", nftables.TableFamilyIPv4)
fmt.Println(table2.Name, table.Family)
}
// Get chain by chain's name
func TestGetChain(t *testing.T) {
conn := nftables.Conn{} // start up a conn
chain, _ := conn.GetChain("POSTROUTING") // get chain
fmt.Println(chain.Name)
}
// Get set and set's elements by table and set's name
func TestGetSet(t *testing.T) {
conn := nftables.Conn{} // start up a conn
table, _ := conn.GetTable("nat", nftables.TableFamilyIPv4) // get table
set, _ := conn.GetSetByName(table, "dest_addrs") // get set
fmt.Println(set.Name)
eles, _ := conn.GetSetElements(set)
fmt.Println(eles)
}
// Get rules by table and chain
func TestGetRules(t *testing.T) {
conn := nftables.Conn{} // start up a conn
table, _ := conn.GetTable("nat", nftables.TableFamilyIPv4) // get table
chain, _ := conn.GetChain("POSTROUTING") // get chain
rules, _ := conn.GetRule(table, chain) // get rules
for _, rule := range rules {
fmt.Println(rule.Table.Name, rule.Table.Family, rule.Chain.Name, rule.Handle)
// unpack exprs
//for _, expr := range rule.Exprs {
// fmt.Println(expr)
//}
}
}
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
// users to make sense of large byte literals more easily.
func nfdump(b []byte) string {

View File

@ -293,6 +293,7 @@ func ruleFromMsg(msg netlink.Message) (*Rule, error) {
switch ad.Type() {
case unix.NFTA_RULE_TABLE:
r.Table = &Table{Name: ad.String()}
r.Table.Family = TableFamily(msg.Data[0])
case unix.NFTA_RULE_CHAIN:
r.Chain = &Chain{Name: ad.String()}
case unix.NFTA_RULE_EXPRESSIONS:

View File

@ -45,6 +45,42 @@ type Table struct {
Family TableFamily
}
// GetTable gets a table by name and family
func (cc *Conn) GetTable(name string, family TableFamily) (*Table, error) {
conn, err := cc.dialNetlink()
if err != nil {
return nil, err
}
defer conn.Close()
msg := netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETTABLE),
Flags: netlink.Request | netlink.Dump,
},
Data: extraHeader(uint8(unix.AF_UNSPEC), 0),
}
response, err := conn.Execute(msg)
if err != nil {
return nil, err
}
for _, m := range response {
t, err := tableFromMsg(m)
if err != nil {
return nil, err
}
if t.Name == name && t.Family == family {
return t, nil
}
}
return nil, nil
}
// DelTable deletes a specific table, along with all chains/rules it contains.
func (cc *Conn) DelTable(t *Table) {
cc.Lock()