[FIX]fix GetRule method https://github.com/google/nftables/issues/114 && add GetTable and GetChain method to get specific data by name && add test files
This commit is contained in:
parent
16a134723a
commit
031c75209e
|
@ -0,0 +1 @@
|
|||
.idea/
|
|
@ -1,6 +1,7 @@
|
|||
[](https://github.com/google/nftables/actions/workflows/push.yml)
|
||||
[](https://godoc.org/github.com/google/nftables)
|
||||
|
||||
## 1. Introduction
|
||||
**This is not the correct repository for issues with the Linux nftables
|
||||
project!** This repository contains a third-party Go package to programmatically
|
||||
interact with nftables. Find the official nftables website at
|
||||
|
@ -11,14 +12,70 @@ implemented in pure Go, i.e. does not wrap libnftnl.
|
|||
|
||||
This is not an official Google product.
|
||||
|
||||
## Breaking changes
|
||||
## 2. Breaking changes
|
||||
|
||||
This package is in very early stages, and only contains enough data types and
|
||||
functions to install very basic nftables rules. It is likely that mistakes with
|
||||
the data types/API will be identified as more functionality is added.
|
||||
|
||||
## Contributions
|
||||
## 3. Contributions
|
||||
|
||||
Contributions are very welcome!
|
||||
|
||||
|
||||
## 4. Examples
|
||||
|
||||
### 1. Get common data types of Nftables
|
||||
|
||||
#### 1.1. Get table by net family and its name
|
||||
|
||||
```go
|
||||
conn := nftables.Conn{} // start up a conn
|
||||
|
||||
table, _ := conn.GetTable("nat", nftables.TableFamilyIPv4)
|
||||
fmt.Println(table.Name)
|
||||
```
|
||||
#### 1.2. Get chain by chain's name
|
||||
|
||||
```go
|
||||
conn := nftables.Conn{} // start up a conn
|
||||
chain, _ := conn.GetChain("POSTROUTING") // get chain
|
||||
fmt.Println(chain.Name)
|
||||
```
|
||||
|
||||
#### 1.3. Get set and set's elements by table and set's name
|
||||
|
||||
```go
|
||||
conn := nftables.Conn{} // start up a conn
|
||||
|
||||
table, _ := conn.GetTable("nat", nftables.TableFamilyIPv4) // get table
|
||||
|
||||
|
||||
set, _ := conn.GetSetByName(table, "dest_addrs") // get set
|
||||
fmt.Println(set.Name)
|
||||
|
||||
eles, _ := conn.GetSetElements(set)
|
||||
fmt.Println(eles)
|
||||
```
|
||||
|
||||
#### 1.4. Get rules by table and chain
|
||||
|
||||
```go
|
||||
conn := nftables.Conn{} // start up a conn
|
||||
|
||||
table, _ := conn.GetTable("nat", nftables.TableFamilyIPv4) // get table
|
||||
chain, _ := conn.GetChain("POSTROUTING") // get chain
|
||||
|
||||
rules, _ := conn.GetRule(table, chain) // get rules
|
||||
for _, rule := range rules {
|
||||
fmt.Println(rule.Table.Name, rule.Table.Family, rule.Chain.Name, rule.Handle)
|
||||
// unpack exprs
|
||||
for _, expr := range rule.Exprs {
|
||||
fmt.Println(expr)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Insert common data types of Nftables
|
||||
|
||||
**wait for update**
|
|
@ -205,6 +205,20 @@ func (cc *Conn) ListChains() ([]*Chain, error) {
|
|||
return chains, nil
|
||||
}
|
||||
|
||||
// GetChain gets a chain by name
|
||||
func (cc *Conn) GetChain(name string) (*Chain, error) {
|
||||
chains, err := cc.ListChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, chain := range chains {
|
||||
if chain.Name == name {
|
||||
return chain, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func chainFromMsg(msg netlink.Message) (*Chain, error) {
|
||||
chainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN)
|
||||
if got, want := msg.Header.Type, chainHeaderType; got != want {
|
||||
|
|
|
@ -84,8 +84,11 @@ func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if r.Table.Name == t.Name && r.Table.Family == t.Family && r.Chain.Name == c.Name {
|
||||
rules = append(rules, r)
|
||||
}
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
@ -293,6 +296,7 @@ func ruleFromMsg(msg netlink.Message) (*Rule, error) {
|
|||
switch ad.Type() {
|
||||
case unix.NFTA_RULE_TABLE:
|
||||
r.Table = &Table{Name: ad.String()}
|
||||
r.Table.Family = TableFamily(msg.Data[0])
|
||||
case unix.NFTA_RULE_CHAIN:
|
||||
r.Chain = &Chain{Name: ad.String()}
|
||||
case unix.NFTA_RULE_EXPRESSIONS:
|
||||
|
|
|
@ -45,6 +45,23 @@ type Table struct {
|
|||
Family TableFamily
|
||||
}
|
||||
|
||||
// GetTable gets a table by name and family
|
||||
func (cc *Conn) GetTable(name string, family TableFamily) (*Table, error) {
|
||||
cc.Lock()
|
||||
defer cc.Unlock()
|
||||
|
||||
tables, err := cc.ListTables()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, table := range tables {
|
||||
if table.Name == name && table.Family == family {
|
||||
return table, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// DelTable deletes a specific table, along with all chains/rules it contains.
|
||||
func (cc *Conn) DelTable(t *Table) {
|
||||
cc.Lock()
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/google/nftables"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Get table by net family and its name
|
||||
func TestGetTable(t *testing.T) {
|
||||
conn := nftables.Conn{} // start up a conn
|
||||
|
||||
table, _ := conn.GetTable("nat", nftables.TableFamilyIPv4)
|
||||
fmt.Println(table.Name)
|
||||
}
|
||||
|
||||
// Get chain by chain's name
|
||||
func TestGetChain(t *testing.T) {
|
||||
conn := nftables.Conn{} // start up a conn
|
||||
chain, _ := conn.GetChain("POSTROUTING") // get chain
|
||||
fmt.Println(chain.Name)
|
||||
}
|
||||
|
||||
// Get set and set's elements by table and set's name
|
||||
func TestGetSet(t *testing.T) {
|
||||
conn := nftables.Conn{} // start up a conn
|
||||
|
||||
table, _ := conn.GetTable("nat", nftables.TableFamilyIPv4) // get table
|
||||
|
||||
set, _ := conn.GetSetByName(table, "dest_addrs") // get set
|
||||
fmt.Println(set.Name)
|
||||
|
||||
eles, _ := conn.GetSetElements(set)
|
||||
fmt.Println(eles)
|
||||
}
|
||||
|
||||
// Get rules by table and chain
|
||||
func TestGetRules(t *testing.T) {
|
||||
conn := nftables.Conn{} // start up a conn
|
||||
|
||||
table, _ := conn.GetTable("nat", nftables.TableFamilyIPv4) // get table
|
||||
chain, _ := conn.GetChain("POSTROUTING") // get chain
|
||||
|
||||
rules, _ := conn.GetRule(table, chain) // get rules
|
||||
for _, rule := range rules {
|
||||
fmt.Println(rule.Table.Name, rule.Table.Family, rule.Chain.Name, rule.Handle)
|
||||
// unpack exprs
|
||||
for _, expr := range rule.Exprs {
|
||||
fmt.Println(expr)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue