Compare commits
1 Commits
3d33ebc054
...
c6a83a98ab
Author | SHA1 | Date |
---|---|---|
|
c6a83a98ab |
44
chain.go
44
chain.go
|
@ -37,7 +37,6 @@ var (
|
||||||
ChainHookOutput *ChainHook = ChainHookRef(unix.NF_INET_LOCAL_OUT)
|
ChainHookOutput *ChainHook = ChainHookRef(unix.NF_INET_LOCAL_OUT)
|
||||||
ChainHookPostrouting *ChainHook = ChainHookRef(unix.NF_INET_POST_ROUTING)
|
ChainHookPostrouting *ChainHook = ChainHookRef(unix.NF_INET_POST_ROUTING)
|
||||||
ChainHookIngress *ChainHook = ChainHookRef(unix.NF_NETDEV_INGRESS)
|
ChainHookIngress *ChainHook = ChainHookRef(unix.NF_NETDEV_INGRESS)
|
||||||
ChainHookEgress *ChainHook = ChainHookRef(unix.NF_NETDEV_EGRESS)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChainHookRef returns a pointer to a ChainHookRef value.
|
// ChainHookRef returns a pointer to a ChainHookRef value.
|
||||||
|
@ -102,7 +101,6 @@ type Chain struct {
|
||||||
Priority *ChainPriority
|
Priority *ChainPriority
|
||||||
Type ChainType
|
Type ChainType
|
||||||
Policy *ChainPolicy
|
Policy *ChainPolicy
|
||||||
Device string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddChain adds the specified Chain. See also
|
// AddChain adds the specified Chain. See also
|
||||||
|
@ -120,11 +118,6 @@ func (cc *Conn) AddChain(c *Chain) *Chain {
|
||||||
{Type: unix.NFTA_HOOK_HOOKNUM, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Hooknum))},
|
{Type: unix.NFTA_HOOK_HOOKNUM, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Hooknum))},
|
||||||
{Type: unix.NFTA_HOOK_PRIORITY, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Priority))},
|
{Type: unix.NFTA_HOOK_PRIORITY, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Priority))},
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Device != "" {
|
|
||||||
hookAttr = append(hookAttr, netlink.Attribute{Type: unix.NFTA_HOOK_DEV, Data: []byte(c.Device + "\x00")})
|
|
||||||
}
|
|
||||||
|
|
||||||
data = append(data, cc.marshalAttr([]netlink.Attribute{
|
data = append(data, cc.marshalAttr([]netlink.Attribute{
|
||||||
{Type: unix.NLA_F_NESTED | unix.NFTA_CHAIN_HOOK, Data: cc.marshalAttr(hookAttr)},
|
{Type: unix.NLA_F_NESTED | unix.NFTA_CHAIN_HOOK, Data: cc.marshalAttr(hookAttr)},
|
||||||
})...)
|
})...)
|
||||||
|
@ -193,43 +186,6 @@ func (cc *Conn) ListChains() ([]*Chain, error) {
|
||||||
return cc.ListChainsOfTableFamily(TableFamilyUnspecified)
|
return cc.ListChainsOfTableFamily(TableFamilyUnspecified)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListChain returns a single chain configured in the specified table
|
|
||||||
func (cc *Conn) ListChain(table *Table, chain string) (*Chain, error) {
|
|
||||||
conn, closer, err := cc.netlinkConn()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer func() { _ = closer() }()
|
|
||||||
|
|
||||||
attrs := []netlink.Attribute{
|
|
||||||
{Type: unix.NFTA_TABLE_NAME, Data: []byte(table.Name + "\x00")},
|
|
||||||
{Type: unix.NFTA_CHAIN_NAME, Data: []byte(chain + "\x00")},
|
|
||||||
}
|
|
||||||
msg := netlink.Message{
|
|
||||||
Header: netlink.Header{
|
|
||||||
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN),
|
|
||||||
Flags: netlink.Request,
|
|
||||||
},
|
|
||||||
Data: append(extraHeader(uint8(table.Family), 0), cc.marshalAttr(attrs)...),
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err := conn.Execute(msg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("conn.Execute failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := len(response), 1; got != want {
|
|
||||||
return nil, fmt.Errorf("expected %d response message for chain, got %d", want, got)
|
|
||||||
}
|
|
||||||
|
|
||||||
ch, err := chainFromMsg(response[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ch, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListChainsOfTableFamily returns currently configured chains for the specified
|
// ListChainsOfTableFamily returns currently configured chains for the specified
|
||||||
// family in the kernel. It lists all chains ins all tables if family is
|
// family in the kernel. It lists all chains ins all tables if family is
|
||||||
// TableFamilyUnspecified.
|
// TableFamilyUnspecified.
|
||||||
|
|
154
nftables_test.go
154
nftables_test.go
|
@ -1746,160 +1746,6 @@ func TestListChains(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestListChainByName(t *testing.T) {
|
|
||||||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
|
||||||
defer nftest.CleanupSystemConn(t, newNS)
|
|
||||||
conn.FlushRuleset()
|
|
||||||
defer conn.FlushRuleset()
|
|
||||||
|
|
||||||
table := &nftables.Table{
|
|
||||||
Name: "chain_test",
|
|
||||||
Family: nftables.TableFamilyIPv4,
|
|
||||||
}
|
|
||||||
tr := conn.AddTable(table)
|
|
||||||
|
|
||||||
c := &nftables.Chain{
|
|
||||||
Name: "filter",
|
|
||||||
Table: table,
|
|
||||||
}
|
|
||||||
conn.AddChain(c)
|
|
||||||
|
|
||||||
if err := conn.Flush(); err != nil {
|
|
||||||
t.Errorf("conn.Flush() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cr, err := conn.ListChain(tr, c.Name)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.ListChain() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := cr.Name, c.Name; got != want {
|
|
||||||
t.Fatalf("got chain %s, want chain %s", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := cr.Table.Name, table.Name; got != want {
|
|
||||||
t.Fatalf("got chain table %s, want chain table %s", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestListChainByNameUsingLasting(t *testing.T) {
|
|
||||||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
|
||||||
conn, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("nftables.New() failed: %v", err)
|
|
||||||
}
|
|
||||||
defer nftest.CleanupSystemConn(t, newNS)
|
|
||||||
conn.FlushRuleset()
|
|
||||||
defer conn.FlushRuleset()
|
|
||||||
|
|
||||||
table := &nftables.Table{
|
|
||||||
Name: "chain_test_lasting",
|
|
||||||
Family: nftables.TableFamilyIPv4,
|
|
||||||
}
|
|
||||||
tr := conn.AddTable(table)
|
|
||||||
|
|
||||||
c := &nftables.Chain{
|
|
||||||
Name: "filter_lasting",
|
|
||||||
Table: table,
|
|
||||||
}
|
|
||||||
conn.AddChain(c)
|
|
||||||
|
|
||||||
if err := conn.Flush(); err != nil {
|
|
||||||
t.Errorf("conn.Flush() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cr, err := conn.ListChain(tr, c.Name)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.ListChain() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := cr.Name, c.Name; got != want {
|
|
||||||
t.Fatalf("got chain %s, want chain %s", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := cr.Table.Name, table.Name; got != want {
|
|
||||||
t.Fatalf("got chain table %s, want chain table %s", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestListTableByName(t *testing.T) {
|
|
||||||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
|
||||||
defer nftest.CleanupSystemConn(t, newNS)
|
|
||||||
conn.FlushRuleset()
|
|
||||||
defer conn.FlushRuleset()
|
|
||||||
|
|
||||||
table1 := &nftables.Table{
|
|
||||||
Name: "table_test",
|
|
||||||
Family: nftables.TableFamilyIPv4,
|
|
||||||
}
|
|
||||||
conn.AddTable(table1)
|
|
||||||
table2 := &nftables.Table{
|
|
||||||
Name: "table_test_inet",
|
|
||||||
Family: nftables.TableFamilyINet,
|
|
||||||
}
|
|
||||||
conn.AddTable(table2)
|
|
||||||
table3 := &nftables.Table{
|
|
||||||
Name: table1.Name,
|
|
||||||
Family: nftables.TableFamilyINet,
|
|
||||||
}
|
|
||||||
conn.AddTable(table3)
|
|
||||||
|
|
||||||
if err := conn.Flush(); err != nil {
|
|
||||||
t.Errorf("conn.Flush() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tr, err := conn.ListTable(table1.Name)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.ListTable() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := tr.Name, table1.Name; got != want {
|
|
||||||
t.Fatalf("got table %s, want table %s", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
// not specifying table family should return family ipv4
|
|
||||||
tr, err = conn.ListTable(table3.Name)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.ListTable() failed: %v", err)
|
|
||||||
}
|
|
||||||
if got, want := tr.Name, table1.Name; got != want {
|
|
||||||
t.Fatalf("got table %s, want table %s", got, want)
|
|
||||||
}
|
|
||||||
if got, want := tr.Family, table1.Family; got != want {
|
|
||||||
t.Fatalf("got table family %v, want table family %v", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
// specifying correct INet family
|
|
||||||
tr, err = conn.ListTableOfFamily(table3.Name, nftables.TableFamilyINet)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.ListTable() failed: %v", err)
|
|
||||||
}
|
|
||||||
if got, want := tr.Name, table3.Name; got != want {
|
|
||||||
t.Fatalf("got table %s, want table %s", got, want)
|
|
||||||
}
|
|
||||||
if got, want := tr.Family, table3.Family; got != want {
|
|
||||||
t.Fatalf("got table family %v, want table family %v", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
// not specifying correct family should return err since no table in ipv4
|
|
||||||
tr, err = conn.ListTable(table2.Name)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("conn.ListTable() should have failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// specifying correct INet family
|
|
||||||
tr, err = conn.ListTableOfFamily(table2.Name, nftables.TableFamilyINet)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.ListTable() failed: %v", err)
|
|
||||||
}
|
|
||||||
if got, want := tr.Name, table2.Name; got != want {
|
|
||||||
t.Fatalf("got table %s, want table %s", got, want)
|
|
||||||
}
|
|
||||||
if got, want := tr.Family, table2.Family; got != want {
|
|
||||||
t.Fatalf("got table family %v, want table family %v", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddChain(t *testing.T) {
|
func TestAddChain(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
36
table.go
36
table.go
|
@ -112,25 +112,6 @@ func (cc *Conn) FlushTable(t *Table) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListTable returns table found for the specified name. Searches for
|
|
||||||
// the table under IPv4 family. As per nft man page: "When no address
|
|
||||||
// family is specified, ip is used by default."
|
|
||||||
func (cc *Conn) ListTable(name string) (*Table, error) {
|
|
||||||
return cc.ListTableOfFamily(name, TableFamilyIPv4)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListTableOfFamily returns table found for the specified name and table family
|
|
||||||
func (cc *Conn) ListTableOfFamily(name string, family TableFamily) (*Table, error) {
|
|
||||||
t, err := cc.listTablesOfNameAndFamily(name, family)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if got, want := len(t), 1; got != want {
|
|
||||||
return nil, fmt.Errorf("expected table count %d, got %d", want, got)
|
|
||||||
}
|
|
||||||
return t[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListTables returns currently configured tables in the kernel
|
// ListTables returns currently configured tables in the kernel
|
||||||
func (cc *Conn) ListTables() ([]*Table, error) {
|
func (cc *Conn) ListTables() ([]*Table, error) {
|
||||||
return cc.ListTablesOfFamily(TableFamilyUnspecified)
|
return cc.ListTablesOfFamily(TableFamilyUnspecified)
|
||||||
|
@ -139,31 +120,18 @@ func (cc *Conn) ListTables() ([]*Table, error) {
|
||||||
// ListTablesOfFamily returns currently configured tables for the specified table family
|
// ListTablesOfFamily returns currently configured tables for the specified table family
|
||||||
// in the kernel. It lists all tables if family is TableFamilyUnspecified.
|
// in the kernel. It lists all tables if family is TableFamilyUnspecified.
|
||||||
func (cc *Conn) ListTablesOfFamily(family TableFamily) ([]*Table, error) {
|
func (cc *Conn) ListTablesOfFamily(family TableFamily) ([]*Table, error) {
|
||||||
return cc.listTablesOfNameAndFamily("", family)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cc *Conn) listTablesOfNameAndFamily(name string, family TableFamily) ([]*Table, error) {
|
|
||||||
conn, closer, err := cc.netlinkConn()
|
conn, closer, err := cc.netlinkConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = closer() }()
|
defer func() { _ = closer() }()
|
||||||
|
|
||||||
data := extraHeader(uint8(family), 0)
|
|
||||||
flags := netlink.Request | netlink.Dump
|
|
||||||
if name != "" {
|
|
||||||
data = append(data, cc.marshalAttr([]netlink.Attribute{
|
|
||||||
{Type: unix.NFTA_TABLE_NAME, Data: []byte(name + "\x00")},
|
|
||||||
})...)
|
|
||||||
flags = netlink.Request
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := netlink.Message{
|
msg := netlink.Message{
|
||||||
Header: netlink.Header{
|
Header: netlink.Header{
|
||||||
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETTABLE),
|
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETTABLE),
|
||||||
Flags: flags,
|
Flags: netlink.Request | netlink.Dump,
|
||||||
},
|
},
|
||||||
Data: data,
|
Data: extraHeader(uint8(family), 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := conn.Execute(msg)
|
response, err := conn.Execute(msg)
|
||||||
|
|
Loading…
Reference in New Issue