Compare commits

...

3 Commits

Author SHA1 Message Date
Paul Greenberg e219c2036b
Merge dae73eaa9c into 5e242ec578 2024-04-19 04:18:26 -04:00
turekt 5e242ec578
List table or chain by name (#258)
Adds functionality to list table or chain by specifying its name
2024-04-14 11:19:27 +02:00
Paul Greenberg dae73eaa9c rule: add String() method
Before this commit: the printing of a rule results in
a pointer address.

After this commit: the printing of a rules results in
a human-readable text.

Resolves: #104

Signed-off-by: Paul Greenberg <greenpau@outlook.com>
2020-08-03 10:59:40 -04:00
7 changed files with 295 additions and 2 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
nftables.test

View File

@ -21,4 +21,12 @@ the data types/API will be identified as more functionality is added.
Contributions are very welcome! Contributions are very welcome!
### Testing Changes
Run the following commands to test your changes:
```bash
go test ./...
go test -c github.com/google/nftables
sudo ./nftables.test -test.v -run_system_tests
```

View File

@ -193,6 +193,43 @@ func (cc *Conn) ListChains() ([]*Chain, error) {
return cc.ListChainsOfTableFamily(TableFamilyUnspecified) return cc.ListChainsOfTableFamily(TableFamilyUnspecified)
} }
// ListChain returns a single chain configured in the specified table
func (cc *Conn) ListChain(table *Table, chain string) (*Chain, error) {
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer func() { _ = closer() }()
attrs := []netlink.Attribute{
{Type: unix.NFTA_TABLE_NAME, Data: []byte(table.Name + "\x00")},
{Type: unix.NFTA_CHAIN_NAME, Data: []byte(chain + "\x00")},
}
msg := netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN),
Flags: netlink.Request,
},
Data: append(extraHeader(uint8(table.Family), 0), cc.marshalAttr(attrs)...),
}
response, err := conn.Execute(msg)
if err != nil {
return nil, fmt.Errorf("conn.Execute failed: %v", err)
}
if got, want := len(response), 1; got != want {
return nil, fmt.Errorf("expected %d response message for chain, got %d", want, got)
}
ch, err := chainFromMsg(response[0])
if err != nil {
return nil, err
}
return ch, nil
}
// ListChainsOfTableFamily returns currently configured chains for the specified // ListChainsOfTableFamily returns currently configured chains for the specified
// family in the kernel. It lists all chains ins all tables if family is // family in the kernel. It lists all chains ins all tables if family is
// TableFamilyUnspecified. // TableFamilyUnspecified.

View File

@ -24,6 +24,15 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
const (
NFT_DROP = 0
NFT_ACCEPT = 1
NFT_STOLEN = 2
NFT_QUEUE = 3
NFT_REPEAT = 4
NFT_STOP = 5
)
// This code assembles the verdict structure, as expected by the // This code assembles the verdict structure, as expected by the
// nftables netlink API. // nftables netlink API.
// For further information, consult: // For further information, consult:
@ -126,3 +135,37 @@ func (e *Verdict) unmarshal(fam byte, data []byte) error {
} }
return ad.Err() return ad.Err()
} }
func (e *Verdict) String() string {
var v string
switch e.Kind {
case unix.NFT_RETURN:
v = "return" // -0x5
case unix.NFT_GOTO:
v = "goto" // -0x4
case unix.NFT_JUMP:
v = "jump" // NFT_JUMP = -0x3
case unix.NFT_BREAK:
v = "break" // NFT_BREAK = -0x2
case unix.NFT_CONTINUE:
v = "continue" // NFT_CONTINUE = -0x1
case NFT_DROP:
v = "drop"
case NFT_ACCEPT:
v = "accept"
case NFT_STOLEN:
v = "stolen"
case NFT_QUEUE:
v = "queue"
case NFT_REPEAT:
v = "repeat"
case NFT_STOP:
v = "stop"
default:
v = fmt.Sprintf("verdict %v", e.Kind)
}
if e.Chain != "" {
return v + " " + e.Chain
}
return v
}

View File

@ -221,12 +221,27 @@ func TestRuleOperations(t *testing.T) {
expr.VerdictDrop, expr.VerdictDrop,
} }
wantStrings := []string{
"queue",
"accept",
"queue",
"accept",
"drop",
"drop",
}
for i, r := range rules { for i, r := range rules {
rr, _ := r.Exprs[0].(*expr.Verdict) rr, _ := r.Exprs[0].(*expr.Verdict)
if rr.Kind != want[i] { if rr.Kind != want[i] {
t.Fatalf("bad verdict kind at %d", i) t.Fatalf("bad verdict kind at %d", i)
} }
if rr.String() != wantStrings[i] {
t.Fatalf("bad verdict string at %d: %s (received) vs. %s (expected)", i, rr.String(), wantStrings[i])
}
t.Logf("%s", rr)
} }
} }
@ -1746,6 +1761,160 @@ func TestListChains(t *testing.T) {
} }
} }
func TestListChainByName(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
conn.FlushRuleset()
defer conn.FlushRuleset()
table := &nftables.Table{
Name: "chain_test",
Family: nftables.TableFamilyIPv4,
}
tr := conn.AddTable(table)
c := &nftables.Chain{
Name: "filter",
Table: table,
}
conn.AddChain(c)
if err := conn.Flush(); err != nil {
t.Errorf("conn.Flush() failed: %v", err)
}
cr, err := conn.ListChain(tr, c.Name)
if err != nil {
t.Fatalf("conn.ListChain() failed: %v", err)
}
if got, want := cr.Name, c.Name; got != want {
t.Fatalf("got chain %s, want chain %s", got, want)
}
if got, want := cr.Table.Name, table.Name; got != want {
t.Fatalf("got chain table %s, want chain table %s", got, want)
}
}
func TestListChainByNameUsingLasting(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
conn, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting())
if err != nil {
t.Fatalf("nftables.New() failed: %v", err)
}
defer nftest.CleanupSystemConn(t, newNS)
conn.FlushRuleset()
defer conn.FlushRuleset()
table := &nftables.Table{
Name: "chain_test_lasting",
Family: nftables.TableFamilyIPv4,
}
tr := conn.AddTable(table)
c := &nftables.Chain{
Name: "filter_lasting",
Table: table,
}
conn.AddChain(c)
if err := conn.Flush(); err != nil {
t.Errorf("conn.Flush() failed: %v", err)
}
cr, err := conn.ListChain(tr, c.Name)
if err != nil {
t.Fatalf("conn.ListChain() failed: %v", err)
}
if got, want := cr.Name, c.Name; got != want {
t.Fatalf("got chain %s, want chain %s", got, want)
}
if got, want := cr.Table.Name, table.Name; got != want {
t.Fatalf("got chain table %s, want chain table %s", got, want)
}
}
func TestListTableByName(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
conn.FlushRuleset()
defer conn.FlushRuleset()
table1 := &nftables.Table{
Name: "table_test",
Family: nftables.TableFamilyIPv4,
}
conn.AddTable(table1)
table2 := &nftables.Table{
Name: "table_test_inet",
Family: nftables.TableFamilyINet,
}
conn.AddTable(table2)
table3 := &nftables.Table{
Name: table1.Name,
Family: nftables.TableFamilyINet,
}
conn.AddTable(table3)
if err := conn.Flush(); err != nil {
t.Errorf("conn.Flush() failed: %v", err)
}
tr, err := conn.ListTable(table1.Name)
if err != nil {
t.Fatalf("conn.ListTable() failed: %v", err)
}
if got, want := tr.Name, table1.Name; got != want {
t.Fatalf("got table %s, want table %s", got, want)
}
// not specifying table family should return family ipv4
tr, err = conn.ListTable(table3.Name)
if err != nil {
t.Fatalf("conn.ListTable() failed: %v", err)
}
if got, want := tr.Name, table1.Name; got != want {
t.Fatalf("got table %s, want table %s", got, want)
}
if got, want := tr.Family, table1.Family; got != want {
t.Fatalf("got table family %v, want table family %v", got, want)
}
// specifying correct INet family
tr, err = conn.ListTableOfFamily(table3.Name, nftables.TableFamilyINet)
if err != nil {
t.Fatalf("conn.ListTable() failed: %v", err)
}
if got, want := tr.Name, table3.Name; got != want {
t.Fatalf("got table %s, want table %s", got, want)
}
if got, want := tr.Family, table3.Family; got != want {
t.Fatalf("got table family %v, want table family %v", got, want)
}
// not specifying correct family should return err since no table in ipv4
tr, err = conn.ListTable(table2.Name)
if err == nil {
t.Fatalf("conn.ListTable() should have failed")
}
// specifying correct INet family
tr, err = conn.ListTableOfFamily(table2.Name, nftables.TableFamilyINet)
if err != nil {
t.Fatalf("conn.ListTable() failed: %v", err)
}
if got, want := tr.Name, table2.Name; got != want {
t.Fatalf("got table %s, want table %s", got, want)
}
if got, want := tr.Family, table2.Family; got != want {
t.Fatalf("got table family %v, want table family %v", got, want)
}
}
func TestAddChain(t *testing.T) { func TestAddChain(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

3
nftables_test.sh Executable file
View File

@ -0,0 +1,3 @@
go test ./...
go test -c github.com/google/nftables
sudo ./nftables.test -test.v -run_system_tests

View File

@ -112,6 +112,25 @@ func (cc *Conn) FlushTable(t *Table) {
}) })
} }
// ListTable returns table found for the specified name. Searches for
// the table under IPv4 family. As per nft man page: "When no address
// family is specified, ip is used by default."
func (cc *Conn) ListTable(name string) (*Table, error) {
return cc.ListTableOfFamily(name, TableFamilyIPv4)
}
// ListTableOfFamily returns table found for the specified name and table family
func (cc *Conn) ListTableOfFamily(name string, family TableFamily) (*Table, error) {
t, err := cc.listTablesOfNameAndFamily(name, family)
if err != nil {
return nil, err
}
if got, want := len(t), 1; got != want {
return nil, fmt.Errorf("expected table count %d, got %d", want, got)
}
return t[0], nil
}
// ListTables returns currently configured tables in the kernel // ListTables returns currently configured tables in the kernel
func (cc *Conn) ListTables() ([]*Table, error) { func (cc *Conn) ListTables() ([]*Table, error) {
return cc.ListTablesOfFamily(TableFamilyUnspecified) return cc.ListTablesOfFamily(TableFamilyUnspecified)
@ -120,18 +139,31 @@ func (cc *Conn) ListTables() ([]*Table, error) {
// ListTablesOfFamily returns currently configured tables for the specified table family // ListTablesOfFamily returns currently configured tables for the specified table family
// in the kernel. It lists all tables if family is TableFamilyUnspecified. // in the kernel. It lists all tables if family is TableFamilyUnspecified.
func (cc *Conn) ListTablesOfFamily(family TableFamily) ([]*Table, error) { func (cc *Conn) ListTablesOfFamily(family TableFamily) ([]*Table, error) {
return cc.listTablesOfNameAndFamily("", family)
}
func (cc *Conn) listTablesOfNameAndFamily(name string, family TableFamily) ([]*Table, error) {
conn, closer, err := cc.netlinkConn() conn, closer, err := cc.netlinkConn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = closer() }() defer func() { _ = closer() }()
data := extraHeader(uint8(family), 0)
flags := netlink.Request | netlink.Dump
if name != "" {
data = append(data, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_TABLE_NAME, Data: []byte(name + "\x00")},
})...)
flags = netlink.Request
}
msg := netlink.Message{ msg := netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETTABLE), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETTABLE),
Flags: netlink.Request | netlink.Dump, Flags: flags,
}, },
Data: extraHeader(uint8(family), 0), Data: data,
} }
response, err := conn.Execute(msg) response, err := conn.Execute(msg)