nftables/nftables_integration_test.go

699 lines
17 KiB
Go

package nftables_test
import (
"bytes"
"net"
"os"
"reflect"
"runtime"
"testing"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/vishvananda/netns"
"golang.org/x/sys/unix"
)
// TestFlushTable populates two tables ("filter" and "nat") with rules, flushes
// only the "filter" table, and checks that its chains end up empty while the
// rule in "nat" is still present.
func TestFlushTable(t *testing.T) {
	if os.Getenv("TRAVIS") == "true" {
		t.SkipNow()
	}

	// All operations run against a fresh network namespace which is torn
	// down again once the test returns.
	c, newNS := openSystemNFTConn(t)
	defer cleanupSystemNFTConn(t, newNS)

	// Clear all rules at the beginning + end of the test.
	c.FlushRuleset()
	defer c.FlushRuleset()

	filter := c.AddTable(&nftables.Table{Family: nftables.TableFamilyIPv4, Name: "filter"})
	nat := c.AddTable(&nftables.Table{Family: nftables.TableFamilyIPv4, Name: "nat"})

	forward := c.AddChain(&nftables.Chain{Table: filter, Name: "forward"})
	input := c.AddChain(&nftables.Chain{Table: filter, Name: "input"})
	prerouting := c.AddChain(&nftables.Chain{Table: nat, Name: "prerouting"})

	// tcpPortMatch returns the expressions that select TCP packets whose two
	// bytes at transport-header offset 2 equal port.
	tcpPortMatch := func(port []byte) []expr.Any {
		return []expr.Any{
			// [ meta load l4proto => reg 1 ]
			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
			// [ cmp eq reg 1 0x00000006 ]
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_TCP}},
			// [ payload load 2b @ transport header + 2 => reg 1 ]
			&expr.Payload{
				DestRegister: 1,
				Base:         expr.PayloadBaseTransportHeader,
				Offset:       2,
				Len:          2,
			},
			// [ cmp eq reg 1 <port> ]
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: port},
		}
	}

	// addDrop appends a rule to ch that drops packets matched by tcpPortMatch.
	addDrop := func(tbl *nftables.Table, ch *nftables.Chain, port []byte) {
		c.AddRule(&nftables.Rule{
			Table: tbl,
			Chain: ch,
			Exprs: append(tcpPortMatch(port),
				// [ immediate reg 0 drop ]
				&expr.Verdict{Kind: expr.VerdictDrop},
			),
		})
	}

	addDrop(filter, forward, []byte{0x04, 0xd2}) // cmp value 0x0000d204
	addDrop(filter, forward, []byte{0xe1, 0x10}) // cmp value 0x000010e1
	addDrop(filter, input, []byte{0x2e, 0x16})   // cmp value 0x0000162e

	// The nat rule redirects instead of dropping.
	c.AddRule(&nftables.Rule{
		Table: nat,
		Chain: prerouting,
		Exprs: append(tcpPortMatch([]byte{0x00, 0x16}), // cmp value 0x00001600
			// [ immediate reg 1 0x0000ae08 ]
			&expr.Immediate{
				Register: 1,
				Data:     binaryutil.BigEndian.PutUint16(2222),
			},
			// [ redir proto_min reg 1 ]
			&expr.Redir{RegisterProtoMin: 1},
		),
	})

	if err := c.Flush(); err != nil {
		t.Errorf("c.Flush() failed: %v", err)
	}

	// Before the table flush: 2 + 1 rules in filter, 1 rule in nat.
	rules, err := c.GetRule(filter, forward)
	if err != nil {
		t.Errorf("c.GetRule() failed: %v", err)
	}
	if got := len(rules); got != 2 {
		t.Fatalf("len(rules) = %d, want 2", got)
	}

	rules, err = c.GetRule(filter, input)
	if err != nil {
		t.Errorf("c.GetRule() failed: %v", err)
	}
	if got := len(rules); got != 1 {
		t.Fatalf("len(rules) = %d, want 1", got)
	}

	rules, err = c.GetRule(nat, prerouting)
	if err != nil {
		t.Errorf("c.GetRule() failed: %v", err)
	}
	if got := len(rules); got != 1 {
		t.Fatalf("len(rules) = %d, want 1", got)
	}

	c.FlushTable(filter)
	if err := c.Flush(); err != nil {
		t.Errorf("Second c.Flush() failed: %v", err)
	}

	// After the table flush: both filter chains are empty, nat is untouched.
	rules, err = c.GetRule(filter, forward)
	if err != nil {
		t.Errorf("c.GetRule() failed: %v", err)
	}
	if got := len(rules); got != 0 {
		t.Fatalf("len(rules) = %d, want 0", got)
	}

	rules, err = c.GetRule(filter, input)
	if err != nil {
		t.Errorf("c.GetRule() failed: %v", err)
	}
	if got := len(rules); got != 0 {
		t.Fatalf("len(rules) = %d, want 0", got)
	}

	rules, err = c.GetRule(nat, prerouting)
	if err != nil {
		t.Errorf("c.GetRule() failed: %v", err)
	}
	if got := len(rules); got != 1 {
		t.Fatalf("len(rules) = %d, want 1", got)
	}
}
// TestFlushChain adds two rules to a single chain, flushes that chain, and
// checks that no rules remain in it afterwards.
func TestFlushChain(t *testing.T) {
	// All operations run against a fresh network namespace which is torn
	// down again once the test returns.
	c, newNS := openSystemNFTConn(t)
	defer cleanupSystemNFTConn(t, newNS)

	// Clear all rules at the beginning + end of the test.
	c.FlushRuleset()
	defer c.FlushRuleset()

	filter := c.AddTable(&nftables.Table{Family: nftables.TableFamilyIPv4, Name: "filter"})
	forward := c.AddChain(&nftables.Chain{Table: filter, Name: "forward"})

	// addDrop appends a rule to the forward chain that drops TCP packets
	// whose two bytes at transport-header offset 2 equal port.
	addDrop := func(port []byte) {
		c.AddRule(&nftables.Rule{
			Table: filter,
			Chain: forward,
			Exprs: []expr.Any{
				// [ meta load l4proto => reg 1 ]
				&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
				// [ cmp eq reg 1 0x00000006 ]
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_TCP}},
				// [ payload load 2b @ transport header + 2 => reg 1 ]
				&expr.Payload{
					DestRegister: 1,
					Base:         expr.PayloadBaseTransportHeader,
					Offset:       2,
					Len:          2,
				},
				// [ cmp eq reg 1 <port> ]
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: port},
				// [ immediate reg 0 drop ]
				&expr.Verdict{Kind: expr.VerdictDrop},
			},
		})
	}

	addDrop([]byte{0x04, 0xd2}) // cmp value 0x0000d204
	addDrop([]byte{0xe1, 0x10}) // cmp value 0x000010e1

	if err := c.Flush(); err != nil {
		t.Errorf("c.Flush() failed: %v", err)
	}

	rules, err := c.GetRule(filter, forward)
	if err != nil {
		t.Errorf("c.GetRule() failed: %v", err)
	}
	if got := len(rules); got != 2 {
		t.Fatalf("len(rules) = %d, want 2", got)
	}

	c.FlushChain(forward)
	if err := c.Flush(); err != nil {
		t.Errorf("Second c.Flush() failed: %v", err)
	}

	rules, err = c.GetRule(filter, forward)
	if err != nil {
		t.Errorf("c.GetRule() failed: %v", err)
	}
	if got := len(rules); got != 0 {
		t.Fatalf("len(rules) = %d, want 0", got)
	}
}
// TestGetRuleLookupVerdictImmediate installs a rule containing a set lookup,
// an accept verdict and an immediate expression, reads the rule back with
// GetRule, and checks that those three expressions come back with the
// expected contents at positions 3, 4 and 5.
func TestGetRuleLookupVerdictImmediate(t *testing.T) {
	// Create a new network namespace to test these operations,
	// and tear down the namespace at test completion.
	c, newNS := openSystemNFTConn(t)
	defer cleanupSystemNFTConn(t, newNS)

	// Clear all rules at the beginning + end of the test.
	c.FlushRuleset()
	defer c.FlushRuleset()

	filter := c.AddTable(&nftables.Table{
		Family: nftables.TableFamilyIPv4,
		Name:   "filter",
	})
	forward := c.AddChain(&nftables.Chain{
		Name:     "forward",
		Table:    filter,
		Type:     nftables.ChainTypeFilter,
		Hooknum:  nftables.ChainHookForward,
		Priority: nftables.ChainPriorityFilter,
	})

	set := &nftables.Set{
		Table:   filter,
		Name:    "kek",
		KeyType: nftables.TypeInetService,
	}
	if err := c.AddSet(set, nil); err != nil {
		t.Errorf("c.AddSet(set) failed: %v", err)
	}
	// The set is flushed on its own before the rule that references it is
	// added.
	if err := c.Flush(); err != nil {
		t.Errorf("c.Flush() failed: %v", err)
	}

	c.AddRule(&nftables.Rule{
		Table: filter,
		Chain: forward,
		Exprs: []expr.Any{
			// [ meta load l4proto => reg 1 ]
			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
			// [ cmp eq reg 1 0x00000006 ]
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     []byte{unix.IPPROTO_TCP},
			},
			// [ payload load 2b @ transport header + 2 => reg 1 ]
			&expr.Payload{
				DestRegister: 1,
				Base:         expr.PayloadBaseTransportHeader,
				Offset:       2,
				Len:          2,
			},
			// [ lookup reg 1 set kek ]
			&expr.Lookup{
				SourceRegister: 1,
				SetName:        set.Name,
				SetID:          set.ID,
			},
			// [ immediate reg 0 accept ]
			&expr.Verdict{
				Kind: expr.VerdictAccept,
			},
			// [ immediate reg 2 kek ]
			&expr.Immediate{
				Register: 2,
				Data:     []byte("kek"),
			},
		},
	})
	if err := c.Flush(); err != nil {
		t.Errorf("c.Flush() failed: %v", err)
	}

	// Look the rule up through freshly built Table/Chain values rather than
	// the ones returned by AddTable/AddChain.
	rules, err := c.GetRule(
		&nftables.Table{
			Family: nftables.TableFamilyIPv4,
			Name:   "filter",
		},
		&nftables.Chain{
			Name: "forward",
		},
	)
	if err != nil {
		t.Fatal(err)
	}

	if got, want := len(rules), 1; got != want {
		t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
	}
	if got, want := len(rules[0].Exprs), 6; got != want {
		t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
	}

	lookup, lookupOk := rules[0].Exprs[3].(*expr.Lookup)
	if !lookupOk {
		t.Fatalf("Exprs[3] is type %T, want *expr.Lookup", rules[0].Exprs[3])
	}
	// The expected lookup deliberately leaves SetID unset.
	if want := (&expr.Lookup{
		SourceRegister: 1,
		SetName:        set.Name,
	}); !reflect.DeepEqual(lookup, want) {
		t.Errorf("lookup expr = %+v, wanted %+v", lookup, want)
	}

	verdict, verdictOk := rules[0].Exprs[4].(*expr.Verdict)
	if !verdictOk {
		t.Fatalf("Exprs[4] is type %T, want *expr.Verdict", rules[0].Exprs[4])
	}
	if want := (&expr.Verdict{
		Kind: expr.VerdictAccept,
	}); !reflect.DeepEqual(verdict, want) {
		t.Errorf("verdict expr = %+v, wanted %+v", verdict, want)
	}

	imm, immOk := rules[0].Exprs[5].(*expr.Immediate)
	if !immOk {
		t.Fatalf("Exprs[5] is type %T, want *expr.Immediate", rules[0].Exprs[5])
	}
	if want := (&expr.Immediate{
		Register: 2,
		Data:     []byte("kek"),
	}); !reflect.DeepEqual(imm, want) {
		t.Errorf("immediate expr = %+v, wanted %+v", imm, want)
	}
}
// TestCreateUseNamedSet creates two named sets in the "filter" table, adds
// elements to each of them, and checks that GetSets reports both sets by
// name.
func TestCreateUseNamedSet(t *testing.T) {
	// Create a new network namespace to test these operations,
	// and tear down the namespace at test completion.
	c, newNS := openSystemNFTConn(t)
	defer cleanupSystemNFTConn(t, newNS)

	// Clear all rules at the beginning + end of the test.
	c.FlushRuleset()
	defer c.FlushRuleset()

	filter := c.AddTable(&nftables.Table{
		Family: nftables.TableFamilyIPv4,
		Name:   "filter",
	})

	// portSet starts out empty and receives its single element afterwards.
	portSet := &nftables.Set{
		Table:   filter,
		Name:    "kek",
		KeyType: nftables.TypeInetService,
	}
	if err := c.AddSet(portSet, nil); err != nil {
		t.Errorf("c.AddSet(portSet) failed: %v", err)
	}
	if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil {
		t.Errorf("c.SetAddElements(portSet) failed: %v", err)
	}

	// ipSet is created with one element and then extended by a second one.
	ipSet := &nftables.Set{
		Table:   filter,
		Name:    "IPs_4_dayz",
		KeyType: nftables.TypeIPAddr,
	}
	if err := c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}); err != nil {
		t.Errorf("c.AddSet(ipSet) failed: %v", err)
	}
	if err := c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}); err != nil {
		t.Errorf("c.SetAddElements(ipSet) failed: %v", err)
	}
	if err := c.Flush(); err != nil {
		t.Errorf("c.Flush() failed: %v", err)
	}

	sets, err := c.GetSets(filter)
	if err != nil {
		t.Errorf("c.GetSets() failed: %v", err)
	}
	if len(sets) != 2 {
		t.Fatalf("len(sets) = %d, want 2", len(sets))
	}
	// NOTE(review): this expects the sets to be returned in creation order.
	if sets[0].Name != "kek" {
		t.Errorf("set[0].Name = %q, want kek", sets[0].Name)
	}
	if sets[1].Name != "IPs_4_dayz" {
		t.Errorf("set[1].Name = %q, want IPs_4_dayz", sets[1].Name)
	}
}
// TestCreateDeleteNamedSet creates a named set, deletes it again, and checks
// that GetSets reports no sets for the table afterwards.
func TestCreateDeleteNamedSet(t *testing.T) {
	// All operations run against a fresh network namespace which is torn
	// down again once the test returns.
	c, newNS := openSystemNFTConn(t)
	defer cleanupSystemNFTConn(t, newNS)

	// Clear all rules at the beginning + end of the test.
	c.FlushRuleset()
	defer c.FlushRuleset()

	filter := c.AddTable(&nftables.Table{Family: nftables.TableFamilyIPv4, Name: "filter"})
	ports := &nftables.Set{Table: filter, Name: "kek", KeyType: nftables.TypeInetService}

	// Create the (empty) set and flush it.
	err := c.AddSet(ports, nil)
	if err != nil {
		t.Errorf("c.AddSet(portSet) failed: %v", err)
	}
	err = c.Flush()
	if err != nil {
		t.Errorf("c.Flush() failed: %v", err)
	}

	// Remove the set again in a second batch.
	c.DelSet(ports)
	err = c.Flush()
	if err != nil {
		t.Errorf("Second c.Flush() failed: %v", err)
	}

	remaining, err := c.GetSets(filter)
	if err != nil {
		t.Errorf("c.GetSets() failed: %v", err)
	}
	if n := len(remaining); n != 0 {
		t.Fatalf("len(sets) = %d, want 0", n)
	}
}
// TestDeleteElementNamedSet creates a named set holding two elements,
// deletes one of them, and checks that only the other element is left.
func TestDeleteElementNamedSet(t *testing.T) {
	// Create a new network namespace to test these operations,
	// and tear down the namespace at test completion.
	c, newNS := openSystemNFTConn(t)
	defer cleanupSystemNFTConn(t, newNS)

	// Clear all rules at the beginning + end of the test.
	c.FlushRuleset()
	defer c.FlushRuleset()

	filter := c.AddTable(&nftables.Table{
		Family: nftables.TableFamilyIPv4,
		Name:   "filter",
	})
	portSet := &nftables.Set{
		Table:   filter,
		Name:    "kek",
		KeyType: nftables.TypeInetService,
	}
	if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil {
		t.Errorf("c.AddSet(portSet) failed: %v", err)
	}
	if err := c.Flush(); err != nil {
		t.Errorf("c.Flush() failed: %v", err)
	}

	// Report a failed delete request instead of silently dropping its error.
	if err := c.SetDeleteElements(portSet, []nftables.SetElement{{Key: []byte{0, 23}}}); err != nil {
		t.Errorf("c.SetDeleteElements(portSet) failed: %v", err)
	}
	if err := c.Flush(); err != nil {
		t.Errorf("Second c.Flush() failed: %v", err)
	}

	elems, err := c.GetSetElements(portSet)
	if err != nil {
		t.Errorf("c.GetSetElements() failed: %v", err)
	}
	if len(elems) != 1 {
		t.Fatalf("len(elems) = %d, want 1", len(elems))
	}
	if want := []byte{0, 22}; !bytes.Equal(elems[0].Key, want) {
		t.Errorf("elems[0].Key = %v, want %v", elems[0].Key, want)
	}
}
// TestFlushNamedSet creates a named set holding two elements, flushes the
// set, and checks that it contains no elements afterwards. The test is
// skipped when the TRAVIS environment variable is "true".
func TestFlushNamedSet(t *testing.T) {
	if os.Getenv("TRAVIS") == "true" {
		t.SkipNow()
	}

	// Create a new network namespace to test these operations,
	// and tear down the namespace at test completion.
	c, newNS := openSystemNFTConn(t)
	defer cleanupSystemNFTConn(t, newNS)

	// Clear all rules at the beginning + end of the test.
	c.FlushRuleset()
	defer c.FlushRuleset()

	filter := c.AddTable(&nftables.Table{
		Family: nftables.TableFamilyIPv4,
		Name:   "filter",
	})
	portSet := &nftables.Set{
		Table:   filter,
		Name:    "kek",
		KeyType: nftables.TypeInetService,
	}
	if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil {
		t.Errorf("c.AddSet(portSet) failed: %v", err)
	}
	if err := c.Flush(); err != nil {
		t.Errorf("c.Flush() failed: %v", err)
	}

	c.FlushSet(portSet)
	if err := c.Flush(); err != nil {
		t.Errorf("Second c.Flush() failed: %v", err)
	}

	elems, err := c.GetSetElements(portSet)
	if err != nil {
		t.Errorf("c.GetSetElements() failed: %v", err)
	}
	if len(elems) != 0 {
		t.Fatalf("len(elems) = %d, want 0", len(elems))
	}
}
// openSystemNFTConn creates a new network namespace and returns an nftables
// connection bound to it, together with the namespace handle. The test is
// skipped unless system tests are enabled via enableSysTests.
//
// Callers must defer cleanupSystemNFTConn() with the returned handle so the
// namespace is closed again.
func openSystemNFTConn(t *testing.T) (*nftables.Conn, netns.NsHandle) {
	t.Helper()
	if enabled := *enableSysTests; !enabled {
		t.SkipNow()
	}

	// Namespace operations such as those invoked by `netns.New()` are
	// thread-local, so pin this goroutine to its current OS thread first.
	// cleanupSystemNFTConn() releases the pin again.
	runtime.LockOSThread()

	handle, err := netns.New()
	if err != nil {
		t.Fatalf("netns.New() failed: %v", err)
	}

	conn := &nftables.Conn{NetNS: int(handle)}
	return conn, handle
}
// cleanupSystemNFTConn closes the network namespace handle newNS and unlocks
// the goroutine from its OS thread, undoing the runtime.LockOSThread() call
// made in openSystemNFTConn(). It fails the test if closing the handle
// returns an error.
func cleanupSystemNFTConn(t *testing.T, newNS netns.NsHandle) {
	// Mark as a helper, like openSystemNFTConn, so failures are attributed
	// to the calling test.
	t.Helper()
	// Deferred so the thread is unlocked even when t.Fatalf ends the
	// goroutine below.
	defer runtime.UnlockOSThread()

	if err := newNS.Close(); err != nil {
		t.Fatalf("newNS.Close() failed: %v", err)
	}
}