7945 lines
240 KiB
Go
7945 lines
240 KiB
Go
// Copyright 2018 Google LLC. All Rights Reserved.
|
||
//
|
||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
// you may not use this file except in compliance with the License.
|
||
// You may obtain a copy of the License at
|
||
//
|
||
// http://www.apache.org/licenses/LICENSE-2.0
|
||
//
|
||
// Unless required by applicable law or agreed to in writing, software
|
||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
// See the License for the specific language governing permissions and
|
||
// limitations under the License.
|
||
|
||
package nftables_test
|
||
|
||
import (
|
||
"bytes"
|
||
"errors"
|
||
"flag"
|
||
"fmt"
|
||
"net"
|
||
"os"
|
||
"reflect"
|
||
"strings"
|
||
"syscall"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/google/nftables"
|
||
"github.com/google/nftables/binaryutil"
|
||
"github.com/google/nftables/expr"
|
||
"github.com/google/nftables/internal/nftest"
|
||
"github.com/google/nftables/xt"
|
||
"github.com/mdlayher/netlink"
|
||
"golang.org/x/sys/unix"
|
||
)
|
||
|
||
var (
	// enableSysTests is set by the -run_system_tests flag. Its value is handed
	// to nftest.OpenSystemConn by the tests that talk to the live kernel,
	// which presumably skips them when the flag is false (the default).
	enableSysTests = flag.Bool("run_system_tests", false, "Run tests that operate against the live kernel")
)
|
||
|
||
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
// users to make sense of large byte literals more easily.
//
// Bytes that do not fill a complete 4-byte group are printed on a final line,
// each followed by a single space. Inputs whose length is not a multiple of 4
// are handled without panicking.
func nfdump(b []byte) string {
	var buf bytes.Buffer
	i := 0
	// Only consume complete groups here so that b[i+3] is always in range.
	for ; i+4 <= len(b); i += 4 {
		// TODO: show printable characters as ASCII
		fmt.Fprintf(&buf, "%02x %02x %02x %02x\n", b[i], b[i+1], b[i+2], b[i+3])
	}
	// Emit the remaining 0-3 bytes, if any.
	for ; i < len(b); i++ {
		fmt.Fprintf(&buf, "%02x ", b[i])
	}
	return buf.String()
}
|
||
|
||
// linediff returns a side-by-side diff of two nfdump() return values, flagging
// lines which are not equal with an exclamation point prefix.
//
// When one input has more lines than the other, the surplus lines are still
// listed (paired with an empty string) and flagged, so that extra or missing
// trailing bytes show up in the diff instead of being silently dropped.
func linediff(a, b string) string {
	var buf bytes.Buffer
	fmt.Fprintf(&buf, "got -- want\n")
	linesA := strings.Split(a, "\n")
	linesB := strings.Split(b, "\n")

	n := len(linesA)
	if len(linesB) > n {
		n = len(linesB)
	}
	for idx := 0; idx < n; idx++ {
		inA, inB := idx < len(linesA), idx < len(linesB)
		var lineA, lineB string
		if inA {
			lineA = linesA[idx]
		}
		if inB {
			lineB = linesB[idx]
		}
		prefix := "! "
		if inA && inB && lineA == lineB {
			// Same width as "! " so both columns stay aligned.
			prefix = "  "
		}
		fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB)
	}
	return buf.String()
}
|
||
|
||
// ifname encodes n into a fixed 16-byte, NUL-padded buffer, the form used as
// the comparison data for the iifname/oifname meta matches in these tests.
// At most 16 bytes of n are kept; a name of 16 or more bytes therefore leaves
// no terminating NUL in the buffer.
func ifname(n string) []byte {
	const size = 16
	out := make([]byte, size) // zero-filled, so the unused tail is already NUL
	copy(out, n)
	return out
}
|
||
|
||
func TestRuleOperations(t *testing.T) {
|
||
// Create a new network namespace to test these operations,
|
||
// and tear down the namespace at test completion.
|
||
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||
defer nftest.CleanupSystemConn(t, newNS)
|
||
// Clear all rules at the beginning + end of the test.
|
||
c.FlushRuleset()
|
||
defer c.FlushRuleset()
|
||
|
||
filter := c.AddTable(&nftables.Table{
|
||
Family: nftables.TableFamilyIPv4,
|
||
Name: "filter",
|
||
})
|
||
|
||
prerouting := c.AddChain(&nftables.Chain{
|
||
Name: "base-chain",
|
||
Table: filter,
|
||
Type: nftables.ChainTypeFilter,
|
||
Hooknum: nftables.ChainHookPrerouting,
|
||
Priority: nftables.ChainPriorityFilter,
|
||
})
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: filter,
|
||
Chain: prerouting,
|
||
Exprs: []expr.Any{
|
||
&expr.Verdict{
|
||
// [ immediate reg 0 drop ]
|
||
Kind: expr.VerdictDrop,
|
||
},
|
||
},
|
||
})
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: filter,
|
||
Chain: prerouting,
|
||
Exprs: []expr.Any{
|
||
&expr.Verdict{
|
||
// [ immediate reg 0 drop ]
|
||
Kind: expr.VerdictDrop,
|
||
},
|
||
},
|
||
})
|
||
|
||
c.InsertRule(&nftables.Rule{
|
||
Table: filter,
|
||
Chain: prerouting,
|
||
Exprs: []expr.Any{
|
||
&expr.Verdict{
|
||
// [ immediate reg 0 accept ]
|
||
Kind: expr.VerdictAccept,
|
||
},
|
||
},
|
||
})
|
||
|
||
c.InsertRule(&nftables.Rule{
|
||
Table: filter,
|
||
Chain: prerouting,
|
||
Exprs: []expr.Any{
|
||
&expr.Verdict{
|
||
// [ immediate reg 0 queue ]
|
||
Kind: expr.VerdictQueue,
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
rules, _ := c.GetRules(filter, prerouting)
|
||
|
||
want := []expr.VerdictKind{
|
||
expr.VerdictQueue,
|
||
expr.VerdictAccept,
|
||
expr.VerdictDrop,
|
||
expr.VerdictDrop,
|
||
}
|
||
|
||
for i, r := range rules {
|
||
rr, _ := r.Exprs[0].(*expr.Verdict)
|
||
|
||
if rr.Kind != want[i] {
|
||
t.Fatalf("bad verdict kind at %d", i)
|
||
}
|
||
}
|
||
|
||
c.ReplaceRule(&nftables.Rule{
|
||
Table: filter,
|
||
Chain: prerouting,
|
||
Handle: rules[2].Handle,
|
||
Exprs: []expr.Any{
|
||
&expr.Verdict{
|
||
// [ immediate reg 0 accept ]
|
||
Kind: expr.VerdictAccept,
|
||
},
|
||
},
|
||
})
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: filter,
|
||
Chain: prerouting,
|
||
Position: rules[2].Handle,
|
||
Exprs: []expr.Any{
|
||
&expr.Verdict{
|
||
// [ immediate reg 0 drop ]
|
||
Kind: expr.VerdictDrop,
|
||
},
|
||
},
|
||
})
|
||
|
||
c.InsertRule(&nftables.Rule{
|
||
Table: filter,
|
||
Chain: prerouting,
|
||
Position: rules[2].Handle,
|
||
Exprs: []expr.Any{
|
||
&expr.Verdict{
|
||
// [ immediate reg 0 queue ]
|
||
Kind: expr.VerdictQueue,
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
rules, _ = c.GetRules(filter, prerouting)
|
||
|
||
want = []expr.VerdictKind{
|
||
expr.VerdictQueue,
|
||
expr.VerdictAccept,
|
||
expr.VerdictQueue,
|
||
expr.VerdictAccept,
|
||
expr.VerdictDrop,
|
||
expr.VerdictDrop,
|
||
}
|
||
|
||
for i, r := range rules {
|
||
rr, _ := r.Exprs[0].(*expr.Verdict)
|
||
|
||
if rr.Kind != want[i] {
|
||
t.Fatalf("bad verdict kind at %d", i)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestConfigureNAT(t *testing.T) {
|
||
// The want byte sequences come from stracing nft(8), e.g.:
|
||
// strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat
|
||
//
|
||
// The nft(8) command sequence was taken from:
|
||
// https://wiki.nftables.org/wiki-nftables/index.php/Performing_Network_Address_Translation_(NAT)
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// nft flush ruleset
|
||
[]byte("\x00\x00\x00\x00"),
|
||
// nft add table ip nat
|
||
[]byte("\x02\x00\x00\x00\x08\x00\x01\x00\x6e\x61\x74\x00\x08\x00\x02\x00\x00\x00\x00\x00"),
|
||
// nft add chain nat prerouting '{' type nat hook prerouting priority 0 \; '}'
|
||
[]byte("\x02\x00\x00\x00\x08\x00\x01\x00\x6e\x61\x74\x00\x0f\x00\x03\x00\x70\x72\x65\x72\x6f\x75\x74\x69\x6e\x67\x00\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x07\x00\x6e\x61\x74\x00"),
|
||
// nft add chain nat postrouting '{' type nat hook postrouting priority 100 \; '}'
|
||
[]byte("\x02\x00\x00\x00\x08\x00\x01\x00\x6e\x61\x74\x00\x10\x00\x03\x00\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x04\x08\x00\x02\x00\x00\x00\x00\x64\x08\x00\x07\x00\x6e\x61\x74\x00"),
|
||
// nft add rule nat postrouting oifname uplink0 masquerade
|
||
[]byte("\x02\x00\x00\x00\x08\x00\x01\x00\x6e\x61\x74\x00\x10\x00\x02\x00\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x00\x74\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x07\x08\x00\x01\x00\x00\x00\x00\x01\x38\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x2c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x18\x00\x03\x80\x14\x00\x01\x00\x75\x70\x6c\x69\x6e\x6b\x30\x00\x00\x00\x00\x00\x00\x00\x00\x00\x14\x00\x01\x80\x09\x00\x01\x00\x6d\x61\x73\x71\x00\x00\x00\x00\x04\x00\x02\x80"),
|
||
// nft add rule nat prerouting iif uplink0 tcp dport 4070 dnat 192.168.23.2:4080
|
||
[]byte("\x02\x00\x00\x00\x08\x00\x01\x00\x6e\x61\x74\x00\x0f\x00\x02\x00\x70\x72\x65\x72\x6f\x75\x74\x69\x6e\x67\x00\x00\x98\x01\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x06\x08\x00\x01\x00\x00\x00\x00\x01\x38\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x2c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x18\x00\x03\x80\x14\x00\x01\x00\x75\x70\x6c\x69\x6e\x6b\x30\x00\x00\x00\x00\x00\x00\x00\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x10\x08\x00\x01\x00\x00\x00\x00\x01\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x05\x00\x01\x00\x06\x00\x00\x00\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x02\x08\x00\x03\x00\x00\x00\x00\x02\x08\x00\x04\x00\x00\x00\x00\x02\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x06\x00\x01\x00\x0f\xe6\x00\x00\x2c\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x0c\x00\x02\x80\x08\x00\x01\x00\xc0\xa8\x17\x02\x2c\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x02\x0c\x00\x02\x80\x06\x00\x01\x00\x0f\xf0\x00\x00\x30\x00\x01\x80\x08\x00\x01\x00\x6e\x61\x74\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x02\x08\x00\x03\x00\x00\x00\x00\x01\x08\x00\x05\x00\x00\x00\x00\x02"),
|
||
// nft add rule nat prerouting iifname uplink0 udp dport 4070-4090 dnat 192.168.23.2:4070-4090
|
||
[]byte("\x02\x00\x00\x00\x08\x00\x01\x00\x6e\x61\x74\x00\x0f\x00\x02\x00\x70\x72\x65\x72\x6f\x75\x74\x69\x6e\x67\x00\x00\xf8\x01\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x06\x08\x00\x01\x00\x00\x00\x00\x01\x38\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x2c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x18\x00\x03\x80\x14\x00\x01\x00\x75\x70\x6c\x69\x6e\x6b\x30\x00\x00\x00\x00\x00\x00\x00\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x10\x08\x00\x01\x00\x00\x00\x00\x01\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x05\x00\x01\x00\x11\x00\x00\x00\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x02\x08\x00\x03\x00\x00\x00\x00\x02\x08\x00\x04\x00\x00\x00\x00\x02\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x05\x0c\x00\x03\x80\x06\x00\x01\x00\x0f\xe6\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x03\x0c\x00\x03\x80\x06\x00\x01\x00\x0f\xfa\x00\x00\x2c\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x0c\x00\x02\x80\x08\x00\x01\x00\xc0\xa8\x17\x02\x2c\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x02\x0c\x00\x02\x80\x06\x00\x01\x00\x0f\xe6\x00\x00\x2c\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x03\x0c\x00\x02\x80\x06\x00\x01\x00\x0f\xfa\x00\x00\x38\x00\x01\x80\x08\x00\x01\x00\x6e\x61\x74\x00\x2c\x00\x02\x80\x08\x00\x01\x00\x00\x00
\x00\x01\x08\x00\x02\x00\x00\x00\x00\x02\x08\x00\x03\x00\x00\x00\x00\x01\x08\x00\x05\x00\x00\x00\x00\x02\x08\x00\x06\x00\x00\x00\x00\x03"),
|
||
// nft add rule nat prerouting ip daddr 10.0.0.0/24 dnat prefix to 20.0.0.0/24
|
||
[]byte("\x02\x00\x00\x00\x08\x00\x01\x00\x6e\x61\x74\x00\x0f\x00\x02\x00\x70\x72\x65\x72\x6f\x75\x74\x69\x6e\x67\x00\x00\x38\x01\x04\x80\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x10\x08\x00\x04\x00\x00\x00\x00\x04\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\xff\xff\xff\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x0a\x00\x00\x00\x2c\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x0c\x00\x02\x80\x08\x00\x01\x00\x14\x00\x00\x00\x2c\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x02\x0c\x00\x02\x80\x08\x00\x01\x00\x14\x00\x00\xff\x38\x00\x01\x80\x08\x00\x01\x00\x6e\x61\x74\x00\x2c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x02\x08\x00\x03\x00\x00\x00\x00\x01\x08\x00\x04\x00\x00\x00\x00\x02\x08\x00\x07\x00\x00\x00\x00\x40"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.FlushRuleset()
|
||
|
||
nat := c.AddTable(&nftables.Table{
|
||
Family: nftables.TableFamilyIPv4,
|
||
Name: "nat",
|
||
})
|
||
|
||
prerouting := c.AddChain(&nftables.Chain{
|
||
Name: "prerouting",
|
||
Hooknum: nftables.ChainHookPrerouting,
|
||
Priority: nftables.ChainPriorityFilter,
|
||
Table: nat,
|
||
Type: nftables.ChainTypeNAT,
|
||
})
|
||
|
||
postrouting := c.AddChain(&nftables.Chain{
|
||
Name: "postrouting",
|
||
Hooknum: nftables.ChainHookPostrouting,
|
||
Priority: nftables.ChainPriorityNATSource,
|
||
Table: nat,
|
||
Type: nftables.ChainTypeNAT,
|
||
})
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: nat,
|
||
Chain: postrouting,
|
||
Exprs: []expr.Any{
|
||
// meta load oifname => reg 1
|
||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||
// cmp eq reg 1 0x696c7075 0x00306b6e 0x00000000 0x00000000
|
||
&expr.Cmp{
|
||
Op: expr.CmpOpEq,
|
||
Register: 1,
|
||
Data: ifname("uplink0"),
|
||
},
|
||
// masq
|
||
&expr.Masq{},
|
||
},
|
||
})
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: nat,
|
||
Chain: prerouting,
|
||
Exprs: []expr.Any{
|
||
// [ meta load iifname => reg 1 ]
|
||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||
// [ cmp eq reg 1 0x696c7075 0x00306b6e 0x00000000 0x00000000 ]
|
||
&expr.Cmp{
|
||
Op: expr.CmpOpEq,
|
||
Register: 1,
|
||
Data: ifname("uplink0"),
|
||
},
|
||
|
||
// [ meta load l4proto => reg 1 ]
|
||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||
// [ cmp eq reg 1 0x00000006 ]
|
||
&expr.Cmp{
|
||
Op: expr.CmpOpEq,
|
||
Register: 1,
|
||
Data: []byte{unix.IPPROTO_TCP},
|
||
},
|
||
|
||
// [ payload load 2b @ transport header + 2 => reg 1 ]
|
||
&expr.Payload{
|
||
DestRegister: 1,
|
||
Base: expr.PayloadBaseTransportHeader,
|
||
Offset: 2, // TODO
|
||
Len: 2, // TODO
|
||
},
|
||
// [ cmp eq reg 1 0x0000e60f ]
|
||
&expr.Cmp{
|
||
Op: expr.CmpOpEq,
|
||
Register: 1,
|
||
Data: binaryutil.BigEndian.PutUint16(4070),
|
||
},
|
||
|
||
// [ immediate reg 1 0x0217a8c0 ]
|
||
&expr.Immediate{
|
||
Register: 1,
|
||
Data: net.ParseIP("192.168.23.2").To4(),
|
||
},
|
||
// [ immediate reg 2 0x0000f00f ]
|
||
&expr.Immediate{
|
||
Register: 2,
|
||
Data: binaryutil.BigEndian.PutUint16(4080),
|
||
},
|
||
// [ nat dnat ip addr_min reg 1 addr_max reg 0 proto_min reg 2 proto_max reg 0 ]
|
||
&expr.NAT{
|
||
Type: expr.NATTypeDestNAT,
|
||
Family: unix.NFPROTO_IPV4,
|
||
RegAddrMin: 1,
|
||
RegProtoMin: 2,
|
||
},
|
||
},
|
||
})
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: nat,
|
||
Chain: prerouting,
|
||
Exprs: []expr.Any{
|
||
// [ meta load iifname => reg 1 ]
|
||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||
// [ cmp eq reg 1 0x696c7075 0x00306b6e 0x00000000 0x00000000 ]
|
||
&expr.Cmp{
|
||
Op: expr.CmpOpEq,
|
||
Register: 1,
|
||
Data: ifname("uplink0"),
|
||
},
|
||
|
||
// [ meta load l4proto => reg 1 ]
|
||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||
// [ cmp eq reg 1 0x00000006 ]
|
||
&expr.Cmp{
|
||
Op: expr.CmpOpEq,
|
||
Register: 1,
|
||
Data: []byte{unix.IPPROTO_UDP},
|
||
},
|
||
|
||
// [ payload load 2b @ transport header + 2 => reg 1 ]
|
||
&expr.Payload{
|
||
DestRegister: 1,
|
||
Base: expr.PayloadBaseTransportHeader,
|
||
Offset: 2, // TODO
|
||
Len: 2, // TODO
|
||
},
|
||
// [ cmp gte reg 1 0x0000e60f ]
|
||
&expr.Cmp{
|
||
Op: expr.CmpOpGte,
|
||
Register: 1,
|
||
Data: binaryutil.BigEndian.PutUint16(4070),
|
||
},
|
||
// [ cmp lte reg 1 0x0000fa0f ]
|
||
&expr.Cmp{
|
||
Op: expr.CmpOpLte,
|
||
Register: 1,
|
||
Data: binaryutil.BigEndian.PutUint16(4090),
|
||
},
|
||
|
||
// [ immediate reg 1 0x0217a8c0 ]
|
||
&expr.Immediate{
|
||
Register: 1,
|
||
Data: net.ParseIP("192.168.23.2").To4(),
|
||
},
|
||
// [ immediate reg 2 0x0000f00f ]
|
||
&expr.Immediate{
|
||
Register: 2,
|
||
Data: binaryutil.BigEndian.PutUint16(4070),
|
||
},
|
||
// [ immediate reg 3 0x0000fa0f ]
|
||
&expr.Immediate{
|
||
Register: 3,
|
||
Data: binaryutil.BigEndian.PutUint16(4090),
|
||
},
|
||
// [ nat dnat ip addr_min reg 1 addr_max reg 0 proto_min reg 2 proto_max reg 3 ]
|
||
&expr.NAT{
|
||
Type: expr.NATTypeDestNAT,
|
||
Family: unix.NFPROTO_IPV4,
|
||
RegAddrMin: 1,
|
||
RegProtoMin: 2,
|
||
RegProtoMax: 3,
|
||
},
|
||
},
|
||
})
|
||
|
||
dstipmatch, dstcidrmatch, err := net.ParseCIDR("10.0.0.0/24")
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
dnatfirstip, dnatlastip, err := nftables.NetFirstAndLastIP("20.0.0.0/24")
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: nat,
|
||
Chain: prerouting,
|
||
Exprs: []expr.Any{
|
||
&expr.Payload{
|
||
DestRegister: 1,
|
||
Base: expr.PayloadBaseNetworkHeader,
|
||
Offset: 16, // destination addr offset
|
||
Len: 4,
|
||
},
|
||
&expr.Bitwise{
|
||
SourceRegister: 1,
|
||
DestRegister: 1,
|
||
Len: 4,
|
||
// By specifying Xor to 0x0,0x0,0x0,0x0 and Mask to the CIDR mask,
|
||
// the rule will match the CIDR of the IP (e.g in this case 10.0.0.0/24).
|
||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||
Mask: dstcidrmatch.Mask,
|
||
},
|
||
&expr.Cmp{
|
||
Op: expr.CmpOpEq,
|
||
Register: 1,
|
||
Data: dstipmatch.To4(),
|
||
},
|
||
&expr.Immediate{
|
||
Register: 1,
|
||
Data: dnatfirstip,
|
||
},
|
||
&expr.Immediate{
|
||
Register: 2,
|
||
Data: dnatlastip,
|
||
},
|
||
&expr.NAT{
|
||
Type: expr.NATTypeDestNAT,
|
||
RegAddrMin: 1,
|
||
RegAddrMax: 2,
|
||
Prefix: true,
|
||
Family: uint32(nftables.TableFamilyIPv4),
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
func TestConfigureNATSourceAddress(t *testing.T) {
|
||
// The want byte sequences come from stracing nft(8), e.g.:
|
||
// strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat
|
||
//
|
||
// The nft(8) command sequence was taken from:
|
||
// https://wiki.nftables.org/wiki-nftables/index.php/Performing_Network_Address_Translation_(NAT)
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// nft flush ruleset
|
||
[]byte("\x00\x00\x00\x00"),
|
||
// nft add table ip nat
|
||
[]byte("\x02\x00\x00\x00\x08\x00\x01\x00\x6e\x61\x74\x00\x08\x00\x02\x00\x00\x00\x00\x00"),
|
||
// nft add chain nat postrouting '{' type nat hook postrouting priority 100 \; '}'
|
||
[]byte("\x02\x00\x00\x00\x08\x00\x01\x00\x6e\x61\x74\x00\x10\x00\x03\x00\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x04\x08\x00\x02\x00\x00\x00\x00\x64\x08\x00\x07\x00\x6e\x61\x74\x00"),
|
||
// nft add rule nat postrouting ip saddr 192.168.69.2 masquerade
|
||
[]byte("\x02\x00\x00\x00\x08\x00\x01\x00\x6e\x61\x74\x00\x10\x00\x02\x00\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x00\x78\x00\x04\x80\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x0c\x08\x00\x04\x00\x00\x00\x00\x04\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\xc0\xa8\x45\x02\x14\x00\x01\x80\x09\x00\x01\x00\x6d\x61\x73\x71\x00\x00\x00\x00\x04\x00\x02\x80"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.FlushRuleset()
|
||
|
||
nat := c.AddTable(&nftables.Table{
|
||
Family: nftables.TableFamilyIPv4,
|
||
Name: "nat",
|
||
})
|
||
|
||
postrouting := c.AddChain(&nftables.Chain{
|
||
Name: "postrouting",
|
||
Hooknum: nftables.ChainHookPostrouting,
|
||
Priority: nftables.ChainPriorityNATSource,
|
||
Table: nat,
|
||
Type: nftables.ChainTypeNAT,
|
||
})
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: nat,
|
||
Chain: postrouting,
|
||
Exprs: []expr.Any{
|
||
// payload load 4b @ network header + 12 => reg 1
|
||
&expr.Payload{
|
||
DestRegister: 1,
|
||
Base: expr.PayloadBaseNetworkHeader,
|
||
Offset: 12,
|
||
Len: 4,
|
||
},
|
||
// cmp eq reg 1 0x0245a8c0
|
||
&expr.Cmp{
|
||
Op: expr.CmpOpEq,
|
||
Register: 1,
|
||
Data: net.ParseIP("192.168.69.2").To4(),
|
||
},
|
||
|
||
// masq
|
||
&expr.Masq{},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
func TestMasqMarshalUnmarshal(t *testing.T) {
|
||
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||
defer nftest.CleanupSystemConn(t, newNS)
|
||
|
||
c.FlushRuleset()
|
||
defer c.FlushRuleset()
|
||
|
||
filter := c.AddTable(&nftables.Table{
|
||
Family: nftables.TableFamilyINet,
|
||
Name: "filter",
|
||
})
|
||
postrouting := c.AddChain(&nftables.Chain{
|
||
Name: "postrouting",
|
||
Table: filter,
|
||
Type: nftables.ChainTypeNAT,
|
||
Hooknum: nftables.ChainHookPostrouting,
|
||
Priority: nftables.ChainPriorityFilter,
|
||
})
|
||
|
||
min := uint32(1)
|
||
max := uint32(3)
|
||
c.AddRule(&nftables.Rule{
|
||
Table: filter,
|
||
Chain: postrouting,
|
||
Exprs: []expr.Any{
|
||
&expr.Masq{
|
||
ToPorts: true,
|
||
RegProtoMin: min,
|
||
RegProtoMax: max,
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatalf("c.Flush() failed: %v", err)
|
||
}
|
||
|
||
rules, err := c.GetRules(
|
||
&nftables.Table{
|
||
Family: nftables.TableFamilyINet,
|
||
Name: "filter",
|
||
},
|
||
&nftables.Chain{
|
||
Name: "postrouting",
|
||
},
|
||
)
|
||
if err != nil {
|
||
t.Fatalf("c.GetRules() failed: %v", err)
|
||
}
|
||
|
||
if got, want := len(rules), 1; got != want {
|
||
t.Fatalf("unexpected rule count: got %d, want %d", got, want)
|
||
}
|
||
|
||
rule := rules[0]
|
||
if got, want := len(rule.Exprs), 1; got != want {
|
||
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
|
||
}
|
||
|
||
me, ok := rule.Exprs[0].(*expr.Masq)
|
||
if !ok {
|
||
t.Fatalf("unexpected expression type: got %T, want *expr.Masq", rule.Exprs[0])
|
||
}
|
||
|
||
if got, want := me.ToPorts, true; got != want {
|
||
t.Errorf("unexpected masq random flag: got %v, want %v", got, want)
|
||
}
|
||
|
||
if got, want := me.RegProtoMin, min; got != want {
|
||
t.Errorf("unexpected reg proto min: got %d, want %d", got, want)
|
||
}
|
||
|
||
if got, want := me.RegProtoMax, max; got != want {
|
||
t.Errorf("unexpected reg proto max: got %d, want %d", got, want)
|
||
}
|
||
}
|
||
|
||
func TestExprLogOptions(t *testing.T) {
|
||
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||
defer nftest.CleanupSystemConn(t, newNS)
|
||
|
||
c.FlushRuleset()
|
||
defer c.FlushRuleset()
|
||
|
||
filter := c.AddTable(&nftables.Table{
|
||
Family: nftables.TableFamilyIPv4,
|
||
Name: "filter",
|
||
})
|
||
input := c.AddChain(&nftables.Chain{
|
||
Name: "input",
|
||
Table: filter,
|
||
Type: nftables.ChainTypeFilter,
|
||
Hooknum: nftables.ChainHookInput,
|
||
Priority: nftables.ChainPriorityFilter,
|
||
})
|
||
forward := c.AddChain(&nftables.Chain{
|
||
Name: "forward",
|
||
Table: filter,
|
||
Type: nftables.ChainTypeFilter,
|
||
Hooknum: nftables.ChainHookInput,
|
||
Priority: nftables.ChainPriorityFilter,
|
||
})
|
||
|
||
keyGQ := uint32((1 << unix.NFTA_LOG_GROUP) | (1 << unix.NFTA_LOG_QTHRESHOLD) | (1 << unix.NFTA_LOG_SNAPLEN))
|
||
c.AddRule(&nftables.Rule{
|
||
Table: filter,
|
||
Chain: input,
|
||
Exprs: []expr.Any{
|
||
&expr.Log{
|
||
Key: keyGQ,
|
||
QThreshold: uint16(20),
|
||
Group: uint16(1),
|
||
Snaplen: uint32(132),
|
||
},
|
||
},
|
||
})
|
||
|
||
keyPL := uint32((1 << unix.NFTA_LOG_PREFIX) | (1 << unix.NFTA_LOG_LEVEL) | (1 << unix.NFTA_LOG_FLAGS))
|
||
c.AddRule(&nftables.Rule{
|
||
Table: filter,
|
||
Chain: forward,
|
||
Exprs: []expr.Any{
|
||
&expr.Log{
|
||
Key: keyPL,
|
||
Data: []byte("LOG FORWARD"),
|
||
Level: expr.LogLevelDebug,
|
||
Flags: expr.LogFlagsTCPOpt | expr.LogFlagsIPOpt,
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Errorf("c.Flush() failed: %v", err)
|
||
}
|
||
|
||
rules, err := c.GetRules(
|
||
&nftables.Table{
|
||
Family: nftables.TableFamilyIPv4,
|
||
Name: "filter",
|
||
},
|
||
&nftables.Chain{
|
||
Name: "input",
|
||
},
|
||
)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if got, want := len(rules), 1; got != want {
|
||
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
|
||
}
|
||
|
||
rule := rules[0]
|
||
if got, want := len(rule.Exprs), 1; got != want {
|
||
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
|
||
}
|
||
|
||
le, ok := rule.Exprs[0].(*expr.Log)
|
||
if !ok {
|
||
t.Fatalf("unexpected expression type: got %T, want *expr.Log", rule.Exprs[0])
|
||
}
|
||
|
||
if got, want := le.Key, keyGQ; got != want {
|
||
t.Fatalf("unexpected log key: got %d, want %d", got, want)
|
||
}
|
||
|
||
if got, want := le.Group, uint16(1); got != want {
|
||
t.Fatalf("unexpected group: got %d, want %d", got, want)
|
||
}
|
||
|
||
if got, want := le.QThreshold, uint16(20); got != want {
|
||
t.Fatalf("unexpected queue-threshold: got %d, want %d", got, want)
|
||
}
|
||
|
||
if got, want := le.Snaplen, uint32(132); got != want {
|
||
t.Fatalf("unexpected snaplen: got %d, want %d", got, want)
|
||
}
|
||
|
||
rules, err = c.GetRules(
|
||
&nftables.Table{
|
||
Family: nftables.TableFamilyIPv4,
|
||
Name: "filter",
|
||
},
|
||
&nftables.Chain{
|
||
Name: "forward",
|
||
},
|
||
)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if got, want := len(rules), 1; got != want {
|
||
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
|
||
}
|
||
|
||
rule = rules[0]
|
||
if got, want := len(rule.Exprs), 1; got != want {
|
||
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
|
||
}
|
||
|
||
le, ok = rule.Exprs[0].(*expr.Log)
|
||
if !ok {
|
||
t.Fatalf("unexpected expression type: got %T, want *expr.Log", rule.Exprs[0])
|
||
}
|
||
|
||
if got, want := le.Key, keyPL; got != want {
|
||
t.Fatalf("unexpected log key: got %d, want %d", got, want)
|
||
}
|
||
|
||
if got, want := string(le.Data), "LOG FORWARD"; got != want {
|
||
t.Fatalf("unexpected prefix data: got %s, want %s", got, want)
|
||
}
|
||
|
||
if got, want := le.Level, expr.LogLevelDebug; got != want {
|
||
t.Fatalf("unexpected log level: got %d, want %d", got, want)
|
||
}
|
||
|
||
if got, want := le.Flags, expr.LogFlagsTCPOpt|expr.LogFlagsIPOpt; got != want {
|
||
t.Fatalf("unexpected log flags: got %d, want %d", got, want)
|
||
}
|
||
}
|
||
|
||
func TestExprLogPrefix(t *testing.T) {
|
||
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||
defer nftest.CleanupSystemConn(t, newNS)
|
||
|
||
c.FlushRuleset()
|
||
defer c.FlushRuleset()
|
||
|
||
filter := c.AddTable(&nftables.Table{
|
||
Family: nftables.TableFamilyIPv4,
|
||
Name: "filter",
|
||
})
|
||
input := c.AddChain(&nftables.Chain{
|
||
Name: "input",
|
||
Table: filter,
|
||
Type: nftables.ChainTypeFilter,
|
||
Hooknum: nftables.ChainHookInput,
|
||
Priority: nftables.ChainPriorityFilter,
|
||
})
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: filter,
|
||
Chain: input,
|
||
Exprs: []expr.Any{
|
||
&expr.Log{
|
||
Key: 1 << unix.NFTA_LOG_PREFIX,
|
||
Data: []byte("LOG INPUT"),
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Errorf("c.Flush() failed: %v", err)
|
||
}
|
||
|
||
rules, err := c.GetRules(
|
||
&nftables.Table{
|
||
Family: nftables.TableFamilyIPv4,
|
||
Name: "filter",
|
||
},
|
||
&nftables.Chain{
|
||
Name: "input",
|
||
},
|
||
)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if got, want := len(rules), 1; got != want {
|
||
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
|
||
}
|
||
if got, want := len(rules[0].Exprs), 1; got != want {
|
||
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
|
||
}
|
||
|
||
logExpr, ok := rules[0].Exprs[0].(*expr.Log)
|
||
if !ok {
|
||
t.Fatalf("Exprs[0] is type %T, want *expr.Log", rules[0].Exprs[0])
|
||
}
|
||
|
||
// nftables defaults to warn log level when no level is specified and group is not defined
|
||
// see https://wiki.nftables.org/wiki-nftables/index.php/Logging_traffic
|
||
if got, want := logExpr.Key, uint32((1<<unix.NFTA_LOG_PREFIX)|(1<<unix.NFTA_LOG_LEVEL)); got != want {
|
||
t.Fatalf("unexpected *expr.Log key: got %d, want %d", got, want)
|
||
}
|
||
if got, want := string(logExpr.Data), "LOG INPUT"; got != want {
|
||
t.Fatalf("unexpected *expr.Log data: got %s, want %s", got, want)
|
||
}
|
||
if got, want := logExpr.Level, expr.LogLevelWarning; got != want {
|
||
t.Fatalf("unexpected *expr.Log level: got %d, want %d", got, want)
|
||
}
|
||
}
|
||
|
||
func TestGetRules(t *testing.T) {
|
||
// The want byte sequences come from stracing nft(8), e.g.:
|
||
// strace -f -v -x -s 2048 -eraw=sendto nft list chain ip filter forward
|
||
|
||
want := [][]byte{
|
||
{0x2, 0x0, 0x0, 0x0, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xa, 0x0, 0x2, 0x0, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x0, 0x0, 0x0},
|
||
}
|
||
|
||
// The reply messages come from adding log.Printf("msgs: %#v", msgs) to
|
||
// (*github.com/mdlayher/netlink/Conn).receive
|
||
reply := [][]netlink.Message{
|
||
nil,
|
||
{{Header: netlink.Header{Length: 0x68, Type: 0xa06, Flags: 0x802, Sequence: 0x9acb0443, PID: 0xba38ef3c}, Data: []uint8{0x2, 0x0, 0x0, 0xc, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xc, 0x0, 0x2, 0x0, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x0, 0xc, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x30, 0x0, 0x4, 0x0, 0x2c, 0x0, 0x1, 0x0, 0xc, 0x0, 0x1, 0x0, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x0, 0x1c, 0x0, 0x2, 0x0, 0xc, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6d, 0x92, 0x20, 0x20, 0xc, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x48, 0xd9}}},
|
||
{{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x9acb0443, PID: 0xba38ef3c}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}},
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
rep := reply[0]
|
||
reply = reply[1:]
|
||
return rep, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
rules, err := c.GetRules(
|
||
&nftables.Table{
|
||
Family: nftables.TableFamilyIPv4,
|
||
Name: "filter",
|
||
},
|
||
&nftables.Chain{
|
||
Name: "input",
|
||
},
|
||
)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if got, want := len(rules), 1; got != want {
|
||
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
|
||
}
|
||
|
||
rule := rules[0]
|
||
if got, want := len(rule.Exprs), 1; got != want {
|
||
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
|
||
}
|
||
|
||
ce, ok := rule.Exprs[0].(*expr.Counter)
|
||
if !ok {
|
||
t.Fatalf("unexpected expression type: got %T, want *expr.Counter", rule.Exprs[0])
|
||
}
|
||
|
||
if got, want := ce.Packets, uint64(674009); got != want {
|
||
t.Errorf("unexpected number of packets: got %d, want %d", got, want)
|
||
}
|
||
|
||
if got, want := ce.Bytes, uint64(1838293024); got != want {
|
||
t.Errorf("unexpected number of bytes: got %d, want %d", got, want)
|
||
}
|
||
}
|
||
|
||
func TestAddCounter(t *testing.T) {
|
||
// The want byte sequences come from stracing nft(8), e.g.:
|
||
// strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat
|
||
//
|
||
// The nft(8) command sequence was taken from:
|
||
// https://wiki.nftables.org/wiki-nftables/index.php/Performing_Network_Address_Translation_(NAT)
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// nft add counter ip filter fwded
|
||
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0a\x00\x02\x00\x66\x77\x64\x65\x64\x00\x00\x00\x08\x00\x03\x00\x00\x00\x00\x01\x1c\x00\x04\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
|
||
// nft add rule ip filter forward counter name fwded
|
||
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x66\x6f\x72\x77\x61\x72\x64\x00\x2c\x00\x04\x80\x28\x00\x01\x80\x0b\x00\x01\x00\x6f\x62\x6a\x72\x65\x66\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x09\x00\x02\x00\x66\x77\x64\x65\x64\x00\x00\x00"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.AddObj(&nftables.CounterObj{
|
||
Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4},
|
||
Name: "fwded",
|
||
Bytes: 0,
|
||
Packets: 0,
|
||
})
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4},
|
||
Chain: &nftables.Chain{Name: "forward", Type: nftables.ChainTypeFilter},
|
||
Exprs: []expr.Any{
|
||
&expr.Objref{
|
||
Type: 1,
|
||
Name: "fwded",
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
func TestDeleteCounter(t *testing.T) {
|
||
// The want byte sequences come from stracing nft(8), e.g.:
|
||
// strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat
|
||
//
|
||
// The nft(8) command sequence was taken from:
|
||
// https://wiki.nftables.org/wiki-nftables/index.php/Performing_Network_Address_Translation_(NAT)
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// nft add counter ip filter fwded
|
||
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0a\x00\x02\x00\x66\x77\x64\x65\x64\x00\x00\x00\x08\x00\x03\x00\x00\x00\x00\x01\x1c\x00\x04\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
|
||
// nft delete counter ip filter fwded
|
||
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0a\x00\x02\x00\x66\x77\x64\x65\x64\x00\x00\x00\x08\x00\x03\x00\x00\x00\x00\x01\x04\x00\x04\x80"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.AddObj(&nftables.CounterObj{
|
||
Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4},
|
||
Name: "fwded",
|
||
Bytes: 0,
|
||
Packets: 0,
|
||
})
|
||
|
||
c.DeleteObject(&nftables.CounterObj{
|
||
Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4},
|
||
Name: "fwded",
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
func TestDelRule(t *testing.T) {
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// nft delete rule ipv4table ipv4chain-1 handle 9
|
||
[]byte("\x02\x00\x00\x00\x0e\x00\x01\x00\x69\x70\x76\x34\x74\x61\x62\x6c\x65\x00\x00\x00\x10\x00\x02\x00\x69\x70\x76\x34\x63\x68\x61\x69\x6e\x2d\x31\x00\x0c\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x09"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.DelRule(&nftables.Rule{
|
||
Table: &nftables.Table{Name: "ipv4table", Family: nftables.TableFamilyIPv4},
|
||
Chain: &nftables.Chain{Name: "ipv4chain-1", Type: nftables.ChainTypeFilter},
|
||
Handle: uint64(9),
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
func TestLog(t *testing.T) {
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// nft add rule ipv4table ipv4chain-1 log prefix nftables
|
||
[]byte("\x02\x00\x00\x00\x0e\x00\x01\x00\x69\x70\x76\x34\x74\x61\x62\x6c\x65\x00\x00\x00\x10\x00\x02\x00\x69\x70\x76\x34\x63\x68\x61\x69\x6e\x2d\x31\x00\x24\x00\x04\x80\x20\x00\x01\x80\x08\x00\x01\x00\x6c\x6f\x67\x00\x14\x00\x02\x80\x0d\x00\x02\x00\x6e\x66\x74\x61\x62\x6c\x65\x73\x00\x00\x00\x00"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: &nftables.Table{Name: "ipv4table", Family: nftables.TableFamilyIPv4},
|
||
Chain: &nftables.Chain{Name: "ipv4chain-1", Type: nftables.ChainTypeFilter},
|
||
Exprs: []expr.Any{
|
||
&expr.Log{
|
||
Key: 1 << unix.NFTA_LOG_PREFIX,
|
||
Data: []byte("nftables"),
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
func TestTProxy(t *testing.T) {
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// nft add rule filter divert ip protocol tcp tproxy to :50080
|
||
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0b\x00\x02\x00\x64\x69\x76\x65\x72\x74\x00\x00\xb4\x00\x04\x80\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x09\x08\x00\x04\x00\x00\x00\x00\x01\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x05\x00\x01\x00\x06\x00\x00\x00\x2c\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x0c\x00\x02\x80\x06\x00\x01\x00\xc3\xa0\x00\x00\x24\x00\x01\x80\x0b\x00\x01\x00\x74\x70\x72\x6f\x78\x79\x00\x00\x14\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x02\x08\x00\x03\x00\x00\x00\x00\x01"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4},
|
||
Chain: &nftables.Chain{
|
||
Name: "divert",
|
||
Type: nftables.ChainTypeFilter,
|
||
Hooknum: nftables.ChainHookPrerouting,
|
||
Priority: nftables.ChainPriorityRef(-150),
|
||
},
|
||
Exprs: []expr.Any{
|
||
// [ payload load 1b @ network header + 9 => reg 1 ]
|
||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 9, Len: 1},
|
||
// [ cmp eq reg 1 0x00000006 ]
|
||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_TCP}},
|
||
// [ immediate reg 1 0x0000a0c3 ]
|
||
&expr.Immediate{Register: 1, Data: binaryutil.BigEndian.PutUint16(50080)},
|
||
// [ tproxy ip port reg 1 ]
|
||
&expr.TProxy{
|
||
Family: byte(nftables.TableFamilyIPv4),
|
||
TableFamily: byte(nftables.TableFamilyIPv4),
|
||
RegPort: 1,
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
func TestTProxyWithAddrField(t *testing.T) {
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// nft add rule filter divert ip protocol tcp tproxy to 10.10.72.1:50080
|
||
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0b\x00\x02\x00\x64\x69\x76\x65\x72\x74\x00\x00\xe8\x00\x04\x80\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x09\x08\x00\x04\x00\x00\x00\x00\x01\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x05\x00\x01\x00\x06\x00\x00\x00\x2c\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x0c\x00\x02\x80\x08\x00\x01\x00\x0a\x0a\x48\x01\x2c\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x02\x0c\x00\x02\x80\x06\x00\x01\x00\xc3\xa0\x00\x00\x2c\x00\x01\x80\x0b\x00\x01\x00\x74\x70\x72\x6f\x78\x79\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x02\x08\x00\x03\x00\x00\x00\x00\x02\x08\x00\x02\x00\x00\x00\x00\x01"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4},
|
||
Chain: &nftables.Chain{
|
||
Name: "divert",
|
||
Type: nftables.ChainTypeFilter,
|
||
Hooknum: nftables.ChainHookPrerouting,
|
||
Priority: nftables.ChainPriorityRef(-150),
|
||
},
|
||
Exprs: []expr.Any{
|
||
// [ payload load 1b @ network header + 9 => reg 1 ]
|
||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 9, Len: 1},
|
||
// [ cmp eq reg 1 0x00000006 ]
|
||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_TCP}},
|
||
// [ immediate reg 1 0x01480a0a ]
|
||
&expr.Immediate{Register: 1, Data: []byte("\x0a\x0a\x48\x01")},
|
||
// [ immediate reg 2 0x0000a0c3 ]
|
||
&expr.Immediate{Register: 2, Data: binaryutil.BigEndian.PutUint16(50080)},
|
||
// [ tproxy ip addr reg 1 port reg 2 ]
|
||
&expr.TProxy{
|
||
Family: byte(nftables.TableFamilyIPv4),
|
||
TableFamily: byte(nftables.TableFamilyIPv4),
|
||
RegAddr: 1,
|
||
RegPort: 2,
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
func TestCt(t *testing.T) {
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// sudo nft add rule ipv4table ipv4chain-5 ct mark 123 counter
|
||
[]byte("\x02\x00\x00\x00\x0e\x00\x01\x00\x69\x70\x76\x34\x74\x61\x62\x6c\x65\x00\x00\x00\x10\x00\x02\x00\x69\x70\x76\x34\x63\x68\x61\x69\x6e\x2d\x35\x00\x24\x00\x04\x80\x20\x00\x01\x80\x07\x00\x01\x00\x63\x74\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: &nftables.Table{Name: "ipv4table", Family: nftables.TableFamilyIPv4},
|
||
Chain: &nftables.Chain{
|
||
Name: "ipv4chain-5",
|
||
},
|
||
Exprs: []expr.Any{
|
||
// [ ct load mark => reg 1 ]
|
||
&expr.Ct{
|
||
Key: unix.NFT_CT_MARK,
|
||
Register: 1,
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
func TestSecMarkMarshaling(t *testing.T) {
|
||
// Testing marshaling since secmark requires live selinux tag otherwise
|
||
// errors with conn.Receive: netlink receive: no such file or directory.
|
||
// More information available here:
|
||
// https://git.netfilter.org/nftables/tree/files/examples/secmark.nft?id=26d9cbefb10e6bc3765df7e9e7a4fc3b951a80f3#n6
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// sudo nft add table inet filter
|
||
[]byte("\x01\x00\x00\x00\x0b\x00\x01\x00filter\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"),
|
||
// sudo nft add secmark inet filter sshtag '{ ctx "system_u:object_r:ssh_server_packet_t:s0" }'
|
||
[]byte("\x01\x00\x00\x00\x0b\x00\x01\x00filter\x00\x00\x0b\x00\x02\x00sshtag\x00\x00\x08\x00\x03\x00\x00\x00\x00\x080\x00\x04\x80,\x00\x01\x00system_u:object_r:ssh_server_packet_t:s0"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
|
||
conn, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
table := conn.AddTable(&nftables.Table{
|
||
Family: nftables.TableFamilyINet,
|
||
Name: "filter",
|
||
})
|
||
|
||
sec := &nftables.NamedObj{
|
||
Table: table,
|
||
Name: "sshtag",
|
||
Type: nftables.ObjTypeSecMark,
|
||
Obj: &expr.SecMark{
|
||
Ctx: "system_u:object_r:ssh_server_packet_t:s0",
|
||
},
|
||
}
|
||
conn.AddObj(sec)
|
||
|
||
if err := conn.Flush(); err != nil {
|
||
t.Fatal(err.Error())
|
||
}
|
||
}
|
||
|
||
func TestSynProxyObject(t *testing.T) {
|
||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||
defer nftest.CleanupSystemConn(t, newNS)
|
||
conn.FlushRuleset()
|
||
defer conn.FlushRuleset()
|
||
|
||
table := conn.AddTable(&nftables.Table{
|
||
Family: nftables.TableFamilyINet,
|
||
Name: "filter",
|
||
})
|
||
|
||
syn1 := &nftables.NamedObj{
|
||
Table: table,
|
||
Name: "https-synproxy",
|
||
Type: nftables.ObjTypeSynProxy,
|
||
Obj: &expr.SynProxy{
|
||
Mss: 1,
|
||
Wscale: 2,
|
||
Timestamp: true,
|
||
SackPerm: true,
|
||
// set for equals test below
|
||
MssValueSet: true,
|
||
WscaleValueSet: true,
|
||
},
|
||
}
|
||
syn2 := &nftables.NamedObj{
|
||
Table: table,
|
||
Name: "https-synproxy-empty",
|
||
Type: nftables.ObjTypeSynProxy,
|
||
Obj: &expr.SynProxy{},
|
||
}
|
||
syn3 := &nftables.NamedObj{
|
||
Table: table,
|
||
Name: "https-synproxy-zero",
|
||
Type: nftables.ObjTypeSynProxy,
|
||
Obj: &expr.SynProxy{
|
||
Mss: 0,
|
||
Wscale: 0,
|
||
MssValueSet: true,
|
||
WscaleValueSet: true,
|
||
},
|
||
}
|
||
conn.AddObj(syn1)
|
||
conn.AddObj(syn2)
|
||
conn.AddObj(syn3)
|
||
if err := conn.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
objs, err := conn.GetNamedObjects(table)
|
||
if err != nil {
|
||
t.Errorf("c.GetObjects(table) failed: %v", err)
|
||
}
|
||
if got, want := len(objs), 3; got != want {
|
||
t.Fatalf("received %d objects, expected %d", got, want)
|
||
}
|
||
|
||
synObjs := []*nftables.NamedObj{syn1, syn2, syn3}
|
||
for i := 0; i < len(objs); i++ {
|
||
obj := objs[i].(*nftables.NamedObj)
|
||
syn := synObjs[i]
|
||
if got, want := obj.Name, syn.Name; got != want {
|
||
t.Errorf("object %d names are not equal: got %s, want %s", i, got, want)
|
||
}
|
||
if got, want := obj.Type, syn.Type; got != want {
|
||
t.Errorf("object %d types are not equal: got %v, want %v", i, got, want)
|
||
}
|
||
if got, want := obj.Table.Name, syn.Table.Name; got != want {
|
||
t.Errorf("object %d tables are not equal: got %s, want %s", i, got, want)
|
||
}
|
||
sp1 := obj.Obj.(*expr.SynProxy)
|
||
sp2 := syn.Obj.(*expr.SynProxy)
|
||
if got, want := sp1.Mss, sp2.Mss; got != want {
|
||
t.Errorf("object %d mss' are not equal: got %d, want %d", i, got, want)
|
||
}
|
||
if got, want := sp1.Wscale, sp2.Wscale; got != want {
|
||
t.Errorf("object %d wscales are not equal: got %d, want %d", i, got, want)
|
||
}
|
||
if got, want := sp1.Timestamp, sp2.Timestamp; got != want {
|
||
t.Errorf("object %d timestamp flags are not equal: got %v, want %v", i, got, want)
|
||
}
|
||
if got, want := sp1.SackPerm, sp2.SackPerm; got != want {
|
||
t.Errorf("object %d sack-perm flags are not equal: got %v, want %v", i, got, want)
|
||
}
|
||
if got, want := sp1.MssValueSet, sp2.MssValueSet; got != want {
|
||
t.Errorf("object %d MssValueSet flags are not equal: got %v, want %v", i, got, want)
|
||
}
|
||
if got, want := sp1.WscaleValueSet, sp2.WscaleValueSet; got != want {
|
||
t.Errorf("object %d WscaleValueSet flags are not equal: got %v, want %v", i, got, want)
|
||
}
|
||
if got, want := sp1.Ecn, sp2.Ecn; got != want {
|
||
t.Errorf("object %d Ecn flags are not equal: got %v, want %v", i, got, want)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestCtTimeout(t *testing.T) {
|
||
t.Parallel()
|
||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||
defer nftest.CleanupSystemConn(t, newNS)
|
||
conn.FlushRuleset()
|
||
defer conn.FlushRuleset()
|
||
|
||
table := conn.AddTable(&nftables.Table{
|
||
Family: nftables.TableFamilyIPv4,
|
||
Name: "filter",
|
||
})
|
||
|
||
tests := [...]struct {
|
||
Name string
|
||
Input expr.CtTimeout
|
||
Expect expr.CtTimeout
|
||
}{
|
||
{
|
||
Name: "timeout-blank-tcp-policy",
|
||
Input: expr.CtTimeout{L4Proto: unix.IPPROTO_TCP},
|
||
Expect: expr.CtTimeout{
|
||
L4Proto: unix.IPPROTO_TCP,
|
||
L3Proto: unix.NFPROTO_UNSPEC,
|
||
Policy: expr.CtStateTCPTimeoutDefaults,
|
||
},
|
||
},
|
||
{
|
||
Name: "timeout-blank-udp-policy",
|
||
Input: expr.CtTimeout{L4Proto: unix.IPPROTO_UDP},
|
||
Expect: expr.CtTimeout{
|
||
L4Proto: unix.IPPROTO_UDP,
|
||
L3Proto: unix.NFPROTO_UNSPEC,
|
||
Policy: expr.CtStateUDPTimeoutDefaults,
|
||
},
|
||
},
|
||
{
|
||
Name: "timeout-partial-tcp-policy",
|
||
Input: expr.CtTimeout{
|
||
L4Proto: unix.IPPROTO_TCP,
|
||
L3Proto: unix.NFPROTO_IPV4,
|
||
Policy: expr.CtStatePolicyTimeout{
|
||
expr.CtStateTCPSYNSENT: 100,
|
||
expr.CtStateTCPESTABLISHED: 5,
|
||
expr.CtStateTCPCLOSEWAIT: 9,
|
||
},
|
||
},
|
||
Expect: expr.CtTimeout{
|
||
L4Proto: unix.IPPROTO_TCP,
|
||
L3Proto: unix.NFPROTO_IPV4,
|
||
Policy: expr.CtStatePolicyTimeout{
|
||
expr.CtStateTCPSYNSENT: 100,
|
||
expr.CtStateTCPSYNRECV: 60,
|
||
expr.CtStateTCPESTABLISHED: 5,
|
||
expr.CtStateTCPFINWAIT: 120,
|
||
expr.CtStateTCPCLOSEWAIT: 9,
|
||
expr.CtStateTCPLASTACK: 30,
|
||
expr.CtStateTCPTIMEWAIT: 120,
|
||
expr.CtStateTCPCLOSE: 10,
|
||
expr.CtStateTCPSYNSENT2: 120,
|
||
expr.CtStateTCPRETRANS: 300,
|
||
expr.CtStateTCPUNACK: 300,
|
||
},
|
||
},
|
||
},
|
||
{
|
||
Name: "timeout-complete-udp-policy",
|
||
Input: expr.CtTimeout{
|
||
L4Proto: unix.IPPROTO_UDP,
|
||
L3Proto: unix.NFPROTO_IPV6,
|
||
Policy: expr.CtStatePolicyTimeout{
|
||
expr.CtStateUDPUNREPLIED: 500,
|
||
expr.CtStateUDPREPLIED: 10000,
|
||
},
|
||
},
|
||
Expect: expr.CtTimeout{
|
||
L4Proto: unix.IPPROTO_UDP,
|
||
L3Proto: unix.NFPROTO_IPV6,
|
||
Policy: expr.CtStatePolicyTimeout{
|
||
expr.CtStateUDPUNREPLIED: 500,
|
||
expr.CtStateUDPREPLIED: 10000,
|
||
},
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.Name, func(t *testing.T) {
|
||
ctt1 := conn.AddObj(&nftables.NamedObj{
|
||
Table: table,
|
||
Name: tt.Name,
|
||
Type: nftables.ObjTypeCtTimeout,
|
||
Obj: &tt.Input,
|
||
})
|
||
|
||
if err := conn.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
obj, err := conn.GetObject(ctt1)
|
||
if err != nil {
|
||
t.Errorf("c.GetObject(ctt1) failed: %v failed", err)
|
||
}
|
||
|
||
ctt2, ok := obj.(*nftables.NamedObj)
|
||
if !ok {
|
||
t.Fatalf("unexpected type: got %T, want *nftables.NamedObj", ctt2)
|
||
}
|
||
|
||
o1 := ctt2.Obj.(*expr.CtTimeout)
|
||
o2 := &tt.Expect
|
||
if got, want := o1.L3Proto, o2.L3Proto; got != want {
|
||
t.Fatalf("unexpected l3proto: got %d, want %d", got, want)
|
||
}
|
||
|
||
if got, want := o1.L4Proto, o2.L4Proto; got != want {
|
||
t.Fatalf("unexpected l4proto: got %d, want %d", got, want)
|
||
}
|
||
|
||
if got, want := o1.Policy, o2.Policy; !reflect.DeepEqual(got, want) {
|
||
t.Fatalf("unexpected policy: got %v, want %v", got, want)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestCtExpect(t *testing.T) {
|
||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||
defer nftest.CleanupSystemConn(t, newNS)
|
||
conn.FlushRuleset()
|
||
defer conn.FlushRuleset()
|
||
|
||
table := conn.AddTable(&nftables.Table{
|
||
Family: nftables.TableFamilyIPv4,
|
||
Name: "filter",
|
||
})
|
||
|
||
cte := &nftables.NamedObj{
|
||
Table: table,
|
||
Name: "expect",
|
||
Type: nftables.ObjTypeCtExpect,
|
||
Obj: &expr.CtExpect{
|
||
L3Proto: unix.NFPROTO_IPV4,
|
||
L4Proto: unix.IPPROTO_TCP,
|
||
DPort: 53,
|
||
Timeout: 20,
|
||
Size: 100,
|
||
},
|
||
}
|
||
|
||
conn.AddObj(cte)
|
||
if err := conn.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
objs, err := conn.GetNamedObjects(table)
|
||
if err != nil {
|
||
t.Errorf("c.GetObjects(table) failed: %v", err)
|
||
}
|
||
|
||
if got, want := len(objs), 1; got != want {
|
||
t.Fatalf("received %d objects, expected %d", got, want)
|
||
}
|
||
|
||
obj := objs[0].(*nftables.NamedObj)
|
||
if got, want := obj.Name, cte.Name; got != want {
|
||
t.Errorf("object names are not equal: got %s, want %s", got, want)
|
||
}
|
||
if got, want := obj.Type, cte.Type; got != want {
|
||
t.Errorf("object types are not equal: got %v, want %v", got, want)
|
||
}
|
||
if got, want := obj.Table.Name, cte.Table.Name; got != want {
|
||
t.Errorf("object tables are not equal: got %s, want %s", got, want)
|
||
}
|
||
|
||
ce1 := obj.Obj.(*expr.CtExpect)
|
||
ce2 := cte.Obj.(*expr.CtExpect)
|
||
if got, want := ce1.L3Proto, ce2.L3Proto; got != want {
|
||
t.Errorf("object l3proto not equal: got %d, want %d", got, want)
|
||
}
|
||
if got, want := ce1.L4Proto, ce2.L4Proto; got != want {
|
||
t.Errorf("object l4proto not equal: got %d, want %d", got, want)
|
||
}
|
||
if got, want := ce1.DPort, ce2.DPort; got != want {
|
||
t.Errorf("object dport not equal: got %d, want %d", got, want)
|
||
}
|
||
if got, want := ce1.Size, ce2.Size; got != want {
|
||
t.Errorf("object Size not equal: got %d, want %d", got, want)
|
||
}
|
||
if got, want := ce1.Timeout, ce2.Timeout; got != want {
|
||
t.Errorf("object timeout not equal: got %d, want %d", got, want)
|
||
}
|
||
}
|
||
|
||
func TestCtHelper(t *testing.T) {
|
||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||
defer nftest.CleanupSystemConn(t, newNS)
|
||
conn.FlushRuleset()
|
||
defer conn.FlushRuleset()
|
||
|
||
table := conn.AddTable(&nftables.Table{
|
||
Family: nftables.TableFamilyIPv4,
|
||
Name: "filter",
|
||
})
|
||
|
||
cthelp1 := conn.AddObj(&nftables.NamedObj{
|
||
Table: table,
|
||
Name: "ftp-standard",
|
||
Type: nftables.ObjTypeCtHelper,
|
||
Obj: &expr.CtHelper{
|
||
Name: "ftp",
|
||
L4Proto: unix.IPPROTO_TCP,
|
||
L3Proto: unix.NFPROTO_IPV4,
|
||
},
|
||
})
|
||
|
||
if err := conn.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
obj1, err := conn.GetObject(cthelp1)
|
||
if err != nil {
|
||
t.Errorf("c.GetObject(cthelp1) failed: %v failed", err)
|
||
}
|
||
|
||
helper, ok := obj1.(*nftables.NamedObj)
|
||
if !ok {
|
||
t.Fatalf("unexpected type: got %T, want *nftables.NamedObj", obj1)
|
||
}
|
||
|
||
if got, want := helper.Name, "ftp-standard"; got != want {
|
||
t.Fatalf("unexpected counter name: got %s, want %s", got, want)
|
||
}
|
||
|
||
if _, err = conn.ResetObject(cthelp1); err != nil {
|
||
t.Errorf("c.ResetObjects(cthelp1) failed: %v failed", err)
|
||
}
|
||
|
||
obj1, err = conn.GetObject(cthelp1)
|
||
if err != nil {
|
||
t.Errorf("c.GetObject(cthelp1) failed: %v failed", err)
|
||
}
|
||
|
||
help := obj1.(*nftables.NamedObj).Obj.(*expr.CtHelper)
|
||
if got, want := help.L4Proto, uint8(unix.IPPROTO_TCP); got != want {
|
||
t.Errorf("unexpected l4proto number: got %d, want %d", got, want)
|
||
}
|
||
|
||
if got, want := help.L3Proto, uint16(unix.NFPROTO_IPV4); got != want {
|
||
t.Errorf("unexpected l3proto number: got %d, want %d", got, want)
|
||
}
|
||
}
|
||
|
||
func TestCtSet(t *testing.T) {
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// sudo nft add rule filter forward ct mark set 1
|
||
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x66\x6f\x72\x77\x61\x72\x64\x00\x50\x00\x04\x80\x2c\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x0c\x00\x02\x80\x08\x00\x01\x00\x01\x00\x00\x00\x20\x00\x01\x80\x07\x00\x01\x00\x63\x74\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x04\x00\x00\x00\x00\x01"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4},
|
||
Chain: &nftables.Chain{
|
||
Name: "forward",
|
||
},
|
||
Exprs: []expr.Any{
|
||
// [ immediate reg 1 0x00000001 ]
|
||
&expr.Immediate{
|
||
Register: 1,
|
||
Data: binaryutil.NativeEndian.PutUint32(1),
|
||
},
|
||
// [ ct set mark with reg 1 ]
|
||
&expr.Ct{
|
||
Key: expr.CtKeyMARK,
|
||
Register: 1,
|
||
SourceRegister: true,
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
func TestCtStat(t *testing.T) {
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// ct state established,related accept
|
||
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0b\x00\x02\x00\x6f\x75\x74\x70\x75\x74\x00\x00\xc4\x00\x04\x80\x20\x00\x01\x80\x07\x00\x01\x00\x63\x74\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x06\x00\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4},
|
||
Chain: &nftables.Chain{
|
||
Name: "output",
|
||
},
|
||
Exprs: []expr.Any{
|
||
&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
|
||
&expr.Bitwise{
|
||
SourceRegister: 1,
|
||
DestRegister: 1,
|
||
Len: 4,
|
||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||
},
|
||
&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: []byte{0, 0, 0, 0}},
|
||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
func TestAddRuleWithPosition(t *testing.T) {
|
||
want := [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// nft add rule ip ipv4table ipv4chain-1 position 2 ip version 6
|
||
[]byte("\x02\x00\x00\x00\x0e\x00\x01\x00\x69\x70\x76\x34\x74\x61\x62\x6c\x65\x00\x00\x00\x10\x00\x02\x00\x69\x70\x76\x34\x63\x68\x61\x69\x6e\x2d\x31\x00\xa8\x00\x04\x80\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x00\x08\x00\x04\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x01\x0c\x00\x04\x80\x05\x00\x01\x00\xf0\x00\x00\x00\x0c\x00\x05\x80\x05\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x05\x00\x01\x00\x60\x00\x00\x00\x0c\x00\x06\x00\x00\x00\x00\x00\x00\x00\x00\x02"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
for idx, msg := range req {
|
||
b, err := msg.MarshalBinary()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(b) < 16 {
|
||
continue
|
||
}
|
||
b = b[16:]
|
||
if len(want) == 0 {
|
||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||
continue
|
||
}
|
||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
}
|
||
want = want[1:]
|
||
}
|
||
return req, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c.AddRule(&nftables.Rule{
|
||
Position: 2,
|
||
Table: &nftables.Table{Name: "ipv4table", Family: nftables.TableFamilyIPv4},
|
||
Chain: &nftables.Chain{
|
||
Name: "ipv4chain-1",
|
||
Type: nftables.ChainTypeFilter,
|
||
Hooknum: nftables.ChainHookPrerouting,
|
||
Priority: nftables.ChainPriorityRef(0),
|
||
},
|
||
|
||
Exprs: []expr.Any{
|
||
// [ payload load 1b @ network header + 0 => reg 1 ]
|
||
&expr.Payload{
|
||
DestRegister: 1,
|
||
Base: expr.PayloadBaseNetworkHeader,
|
||
Offset: 0, // Offset for a transport protocol header
|
||
Len: 1, // 1 bytes for port
|
||
},
|
||
// [ bitwise reg 1 = (reg=1 & 0x000000f0 ) ^ 0x00000000 ]
|
||
&expr.Bitwise{
|
||
SourceRegister: 1,
|
||
DestRegister: 1,
|
||
Len: 1,
|
||
Mask: []byte{0xf0},
|
||
Xor: []byte{0x0},
|
||
},
|
||
// [ cmp eq reg 1 0x00000060 ]
|
||
&expr.Cmp{
|
||
Op: expr.CmpOpEq,
|
||
Register: 1,
|
||
Data: []byte{(0x6 << 4)},
|
||
},
|
||
},
|
||
})
|
||
|
||
if err := c.Flush(); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
func TestLastingConnection(t *testing.T) {
|
||
testdialerr := errors.New("test dial sentinel error")
|
||
dialCount := 0
|
||
c, err := nftables.New(
|
||
nftables.AsLasting(),
|
||
nftables.WithTestDial(func(req []netlink.Message) ([]netlink.Message, error) {
|
||
dialCount++
|
||
return nil, testdialerr
|
||
}))
|
||
if err != nil {
|
||
t.Errorf("creating lasting netlink connection failed %v", err)
|
||
return
|
||
}
|
||
defer func() {
|
||
if err := c.CloseLasting(); err != nil {
|
||
t.Errorf("closing lasting netlink connection failed %v", err)
|
||
}
|
||
}()
|
||
|
||
_, err = c.ListTables()
|
||
if !errors.Is(err, testdialerr) {
|
||
t.Errorf("non-testdialerr error returned from TestDial %v", err)
|
||
return
|
||
}
|
||
if dialCount != 1 {
|
||
t.Errorf("internal test error with TestDial invocations %v", dialCount)
|
||
return
|
||
}
|
||
|
||
// While a lasting netlink connection is open, replacing TestDial must be
|
||
// ineffective as there is no need to dial again and activating a new
|
||
// TestDial function. The newly set TestDial function must be getting
|
||
// ignored.
|
||
c.TestDial = func(req []netlink.Message) ([]netlink.Message, error) {
|
||
dialCount--
|
||
return nil, errors.New("transient netlink connection error")
|
||
}
|
||
_, err = c.ListTables()
|
||
if !errors.Is(err, testdialerr) {
|
||
t.Errorf("non-testdialerr error returned from TestDial %v", err)
|
||
return
|
||
}
|
||
if dialCount != 2 {
|
||
t.Errorf("internal test error with TestDial invocations %v", dialCount)
|
||
return
|
||
}
|
||
|
||
for i := 0; i < 2; i++ {
|
||
err = c.CloseLasting()
|
||
if err != nil {
|
||
t.Errorf("closing lasting netlink connection failed in attempt no. %d: %v", i, err)
|
||
return
|
||
}
|
||
}
|
||
_, err = c.ListTables()
|
||
if errors.Is(err, testdialerr) {
|
||
t.Error("testdialerr error returned from TestDial when expecting different error")
|
||
return
|
||
}
|
||
if dialCount != 1 {
|
||
t.Errorf("internal test error with TestDial invocations %v", dialCount)
|
||
return
|
||
}
|
||
|
||
// fall into defer'ed second CloseLasting which must not cause any errors.
|
||
}
|
||
|
||
func TestListChains(t *testing.T) {
|
||
polDrop := nftables.ChainPolicyDrop
|
||
polAcpt := nftables.ChainPolicyAccept
|
||
reply := [][]byte{
|
||
// chain input { type filter hook input priority filter; policy accept; }
|
||
[]byte("\x70\x00\x00\x00\x03\x0a\x02\x00\x00\x00\x00\x00\xb8\x76\x02\x00\x01\x00\x00\xc3\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x01\x0a\x00\x03\x00\x69\x6e\x70\x75\x74\x00\x00\x00\x14\x00\x04\x00\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x05\x00\x00\x00\x00\x01\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\x00\x00\x00\x00"),
|
||
// chain forward { type filter hook forward priority filter; policy drop; }
|
||
[]byte("\x70\x00\x00\x00\x03\x0a\x02\x00\x00\x00\x00\x01\xb8\x76\x02\x00\x01\x00\x00\xc3\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x02\x0c\x00\x03\x00\x66\x6f\x72\x77\x61\x72\x64\x00\x14\x00\x04\x00\x08\x00\x01\x00\x00\x00\x00\x02\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x05\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\x00\x00\x00\x00"),
|
||
// chain output { type filter hook output priority filter; policy accept; }
|
||
[]byte("\x70\x00\x00\x00\x03\x0a\x02\x00\x00\x00\x00\x02\xb8\x76\x02\x00\x01\x00\x00\xc3\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x03\x0b\x00\x03\x00\x6f\x75\x74\x70\x75\x74\x00\x00\x14\x00\x04\x00\x08\x00\x01\x00\x00\x00\x00\x03\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x05\x00\x00\x00\x00\x01\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\x00\x00\x00\x00"),
|
||
// chain undef { counter packets 56235 bytes 175436495 return }
|
||
[]byte("\x40\x00\x00\x00\x03\x0a\x02\x00\x00\x00\x00\x03\xb8\x76\x02\x00\x01\x00\x00\xc3\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x04\x0a\x00\x03\x00\x75\x6e\x64\x65\x66\x00\x00\x00\x08\x00\x06\x00\x00\x00\x00\x01"),
|
||
[]byte("\x14\x00\x00\x00\x03\x00\x02\x00\x00\x00\x00\x04\xb8\x76\x02\x00\x00\x00\x00\x00"),
|
||
}
|
||
|
||
want := []*nftables.Chain{
|
||
{
|
||
Name: "input",
|
||
Hooknum: nftables.ChainHookInput,
|
||
Priority: nftables.ChainPriorityFilter,
|
||
Type: nftables.ChainTypeFilter,
|
||
Policy: &polAcpt,
|
||
},
|
||
{
|
||
Name: "forward",
|
||
Hooknum: nftables.ChainHookForward,
|
||
Priority: nftables.ChainPriorityFilter,
|
||
Type: nftables.ChainTypeFilter,
|
||
Policy: &polDrop,
|
||
},
|
||
{
|
||
Name: "output",
|
||
Hooknum: nftables.ChainHookOutput,
|
||
Priority: nftables.ChainPriorityFilter,
|
||
Type: nftables.ChainTypeFilter,
|
||
Policy: &polAcpt,
|
||
},
|
||
{
|
||
Name: "undef",
|
||
Hooknum: nil,
|
||
Priority: nil,
|
||
Policy: nil,
|
||
},
|
||
}
|
||
|
||
c, err := nftables.New(nftables.WithTestDial(
|
||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||
msgReply := make([]netlink.Message, len(reply))
|
||
for i, r := range reply {
|
||
nm := &netlink.Message{}
|
||
nm.UnmarshalBinary(r)
|
||
nm.Header.Sequence = req[0].Header.Sequence
|
||
nm.Header.PID = req[0].Header.PID
|
||
msgReply[i] = *nm
|
||
}
|
||
return msgReply, nil
|
||
}))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
chains, err := c.ListChains()
|
||
if err != nil {
|
||
t.Errorf("error returned from TestDial %v", err)
|
||
return
|
||
}
|
||
|
||
if len(chains) != len(want) {
|
||
t.Errorf("number of chains %d != number of want %d", len(chains), len(want))
|
||
return
|
||
}
|
||
|
||
validate := func(got interface{}, want interface{}, name string, index int) {
|
||
if got != want {
|
||
t.Errorf("chain %d: chain %s mismatch, got %v want %v", index, name, got, want)
|
||
}
|
||
}
|
||
|
||
for i, chain := range chains {
|
||
validate(chain.Name, want[i].Name, "name", i)
|
||
if want[i].Hooknum != nil && chain.Hooknum != nil {
|
||
validate(*chain.Hooknum, *want[i].Hooknum, "hooknum value", i)
|
||
} else {
|
||
validate(chain.Hooknum, want[i].Hooknum, "hooknum pointer", i)
|
||
}
|
||
if want[i].Priority != nil && chain.Priority != nil {
|
||
validate(*chain.Priority, *want[i].Priority, "priority value", i)
|
||
} else {
|
||
validate(chain.Priority, want[i].Priority, "priority pointer", i)
|
||
}
|
||
validate(chain.Type, want[i].Type, "type", i)
|
||
|
||
if want[i].Policy != nil && chain.Policy != nil {
|
||
validate(*chain.Policy, *want[i].Policy, "policy value", i)
|
||
} else {
|
||
validate(chain.Policy, want[i].Policy, "policy pointer", i)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestListChainByName(t *testing.T) {
|
||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||
defer nftest.CleanupSystemConn(t, newNS)
|
||
conn.FlushRuleset()
|
||
defer conn.FlushRuleset()
|
||
|
||
table := &nftables.Table{
|
||
Name: "chain_test",
|
||
Family: nftables.TableFamilyIPv4,
|
||
}
|
||
tr := conn.AddTable(table)
|
||
|
||
c := &nftables.Chain{
|
||
Name: "filter",
|
||
Table: table,
|
||
}
|
||
conn.AddChain(c)
|
||
|
||
if err := conn.Flush(); err != nil {
|
||
t.Errorf("conn.Flush() failed: %v", err)
|
||
}
|
||
|
||
cr, err := conn.ListChain(tr, c.Name)
|
||
if err != nil {
|
||
t.Fatalf("conn.ListChain() failed: %v", err)
|
||
}
|
||
|
||
if got, want := cr.Name, c.Name; got != want {
|
||
t.Fatalf("got chain %s, want chain %s", got, want)
|
||
}
|
||
|
||
if got, want := cr.Table.Name, table.Name; got != want {
|
||
t.Fatalf("got chain table %s, want chain table %s", got, want)
|
||
}
|
||
}
|
||
|
||
func TestListChainByNameUsingLasting(t *testing.T) {
|
||
_, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||
conn, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting())
|
||
if err != nil {
|
||
t.Fatalf("nftables.New() failed: %v", err)
|
||
}
|
||
defer nftest.CleanupSystemConn(t, newNS)
|
||
conn.FlushRuleset()
|
||
defer conn.FlushRuleset()
|
||
|
||
table := &nftables.Table{
|
||
Name: "chain_test_lasting",
|
||
Family: nftables.TableFamilyIPv4,
|
||
}
|
||
tr := conn.AddTable(table)
|
||
|
||
c := &nftables.Chain{
|
||
Name: "filter_lasting",
|
||
Table: table,
|
||
}
|
||
conn.AddChain(c)
|
||
|
||
if err := conn.Flush(); err != nil {
|
||
t.Errorf("conn.Flush() failed: %v", err)
|
||
}
|
||
|
||
cr, err := conn.ListChain(tr, c.Name)
|
||
if err != nil {
|
||
t.Fatalf("conn.ListChain() failed: %v", err)
|
||
}
|
||
|
||
if got, want := cr.Name, c.Name; got != want {
|
||
t.Fatalf("got chain %s, want chain %s", got, want)
|
||
}
|
||
|
||
if got, want := cr.Table.Name, table.Name; got != want {
|
||
t.Fatalf("got chain table %s, want chain table %s", got, want)
|
||
}
|
||
}
|
||
|
||
func TestListTableByName(t *testing.T) {
|
||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||
defer nftest.CleanupSystemConn(t, newNS)
|
||
conn.FlushRuleset()
|
||
defer conn.FlushRuleset()
|
||
|
||
table1 := &nftables.Table{
|
||
Name: "table_test",
|
||
Family: nftables.TableFamilyIPv4,
|
||
}
|
||
conn.AddTable(table1)
|
||
table2 := &nftables.Table{
|
||
Name: "table_test_inet",
|
||
Family: nftables.TableFamilyINet,
|
||
}
|
||
conn.AddTable(table2)
|
||
table3 := &nftables.Table{
|
||
Name: table1.Name,
|
||
Family: nftables.TableFamilyINet,
|
||
}
|
||
conn.AddTable(table3)
|
||
|
||
if err := conn.Flush(); err != nil {
|
||
t.Errorf("conn.Flush() failed: %v", err)
|
||
}
|
||
|
||
tr, err := conn.ListTable(table1.Name)
|
||
if err != nil {
|
||
t.Fatalf("conn.ListTable() failed: %v", err)
|
||
}
|
||
|
||
if got, want := tr.Name, table1.Name; got != want {
|
||
t.Fatalf("got table %s, want table %s", got, want)
|
||
}
|
||
|
||
// not specifying table family should return family ipv4
|
||
tr, err = conn.ListTable(table3.Name)
|
||
if err != nil {
|
||
t.Fatalf("conn.ListTable() failed: %v", err)
|
||
}
|
||
if got, want := tr.Name, table1.Name; got != want {
|
||
t.Fatalf("got table %s, want table %s", got, want)
|
||
}
|
||
if got, want := tr.Family, table1.Family; got != want {
|
||
t.Fatalf("got table family %v, want table family %v", got, want)
|
||
}
|
||
|
||
// specifying correct INet family
|
||
tr, err = conn.ListTableOfFamily(table3.Name, nftables.TableFamilyINet)
|
||
if err != nil {
|
||
t.Fatalf("conn.ListTable() failed: %v", err)
|
||
}
|
||
if got, want := tr.Name, table3.Name; got != want {
|
||
t.Fatalf("got table %s, want table %s", got, want)
|
||
}
|
||
if got, want := tr.Family, table3.Family; got != want {
|
||
t.Fatalf("got table family %v, want table family %v", got, want)
|
||
}
|
||
|
||
// not specifying correct family should return err since no table in ipv4
|
||
if _, err = conn.ListTable(table2.Name); err == nil {
|
||
t.Fatalf("conn.ListTable() should have failed")
|
||
}
|
||
|
||
// specifying correct INet family
|
||
tr, err = conn.ListTableOfFamily(table2.Name, nftables.TableFamilyINet)
|
||
if err != nil {
|
||
t.Fatalf("conn.ListTable() failed: %v", err)
|
||
}
|
||
if got, want := tr.Name, table2.Name; got != want {
|
||
t.Fatalf("got table %s, want table %s", got, want)
|
||
}
|
||
if got, want := tr.Family, table2.Family; got != want {
|
||
t.Fatalf("got table family %v, want table family %v", got, want)
|
||
}
|
||
}
|
||
|
||
func TestAddChain(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
chain *nftables.Chain
|
||
want [][]byte
|
||
}{
|
||
{
|
||
name: "Base chain",
|
||
chain: &nftables.Chain{
|
||
Name: "base-chain",
|
||
Hooknum: nftables.ChainHookPrerouting,
|
||
Priority: nftables.ChainPriorityRef(0),
|
||
Type: nftables.ChainTypeFilter,
|
||
},
|
||
want: [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// nft add table ip filter
|
||
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"),
|
||
// nft add chain ip filter base-chain { type filter hook prerouting priority 0 \; }
|
||
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0f\x00\x03\x00\x62\x61\x73\x65\x2d\x63\x68\x61\x69\x6e\x00\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"),
|
||
// batch end
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
},
|
||
},
|
||
{
|
||
name: "Regular chain",
|
||
chain: &nftables.Chain{
|
||
Name: "regular-chain",
|
||
},
|
||
want: [][]byte{
|
||
// batch begin
|
||
[]byte("\x00\x00\x00\x0a"),
|
||
// nft add table ip filter
|
||
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"),
|
||
// nft add chain ip filter regular-chain
|
||
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x12\x00\x03\x00\x72\x65\x67\x75\x6c\x61\x72\x2d\x63\x68\x61\x69\x6e\x00\x00\x00"),
|
||
// batch end
|
||
|