refactor common test code into package nftest

Converting more test functions to use it (and then splitting out test
functions into their own files) is left for a follow-up commit.
This commit is contained in:
Michael Stapelberg 2022-06-11 23:10:56 +02:00
parent 33143dee49
commit 2719b9add1
2 changed files with 156 additions and 50 deletions

133
internal/nftest/nftest.go Normal file
View File

@ -0,0 +1,133 @@
// Package nftest contains utility functions for nftables testing.
package nftest
import (
"bytes"
"fmt"
"log"
"strings"
"testing"
"github.com/google/nftables"
"github.com/mdlayher/netlink"
)
// Recorder provides an nftables connection that does not send to the Linux
// kernel but instead records netlink messages into the recorder. The recorded
// requests can later be obtained using Requests and compared using Diff.
type Recorder struct {
requests []netlink.Message
}
// Conn opens an nftables connection that records netlink messages into the
// Recorder.
func (r *Recorder) Conn() (*nftables.Conn, error) {
return nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
r.requests = append(r.requests, req...)
// TODO: generate and return acknowledgements
for _, msg := range req {
log.Printf("msg: %+v", msg)
}
return req, nil
}))
}
// Requests returns the recorded netlink messages (typically nftables requests).
func (r *Recorder) Requests() []netlink.Message {
return r.requests
}
// NewRecorder returns a ready-to-use Recorder.
func NewRecorder() *Recorder {
return &Recorder{}
}
// Diff returns the first difference between the specified netlink messages and
// the expected netlink message payloads.
func Diff(got []netlink.Message, want [][]byte) string {
for idx, msg := range got {
b, err := msg.MarshalBinary()
if err != nil {
return fmt.Sprintf("msg.MarshalBinary: %v", err)
}
if len(b) < 16 {
continue
}
b = b[16:]
if len(want) == 0 {
return fmt.Sprintf("no want entry for message %d: %x", idx, b)
}
if got, want := b, want[0]; !bytes.Equal(got, want) {
return fmt.Sprintf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
}
want = want[1:]
}
return ""
}
// MatchRulesetBytes is a test helper that ensures the fillRuleset modifications
// correspond to the provided want netlink message payloads
func MatchRulesetBytes(t *testing.T, fillRuleset func(c *nftables.Conn), want [][]byte) {
t.Helper()
rec := NewRecorder()
c, err := rec.Conn()
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
fillRuleset(c)
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if diff := Diff(rec.Requests(), want); diff != "" {
t.Errorf("unexpected netlink messages: diff: %s", diff)
}
}
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
// users to make sense of large byte literals more easily.
func nfdump(b []byte) string {
var buf bytes.Buffer
i := 0
for ; i < len(b); i += 4 {
// TODO: show printable characters as ASCII
fmt.Fprintf(&buf, "%02x %02x %02x %02x\n",
b[i],
b[i+1],
b[i+2],
b[i+3])
}
for ; i < len(b); i++ {
fmt.Fprintf(&buf, "%02x ", b[i])
}
return buf.String()
}
// linediff returns a side-by-side diff of two nfdump() return values, flagging
// lines which are not equal with an exclamation point prefix.
func linediff(a, b string) string {
var buf bytes.Buffer
fmt.Fprintf(&buf, "got -- want\n")
linesA := strings.Split(a, "\n")
linesB := strings.Split(b, "\n")
for idx, lineA := range linesA {
if idx >= len(linesB) {
break
}
lineB := linesB[idx]
prefix := "! "
if lineA == lineB {
prefix = " "
}
fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB)
}
return buf.String()
}

View File

@ -30,6 +30,7 @@ import (
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/google/nftables/internal/nftest"
"github.com/mdlayher/netlink"
"github.com/vishvananda/netns"
"golang.org/x/sys/unix"
@ -5057,59 +5058,31 @@ func TestNotrack(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
t.Fatal(err)
}
if len(b) < 16 {
continue
}
b = b[16:]
if len(want) == 0 {
t.Errorf("no want entry for message %d: %x", idx, b)
continue
}
if got, want := b, want[0]; !bytes.Equal(got, want) {
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
}
want = want[1:]
}
return req, nil
}))
if err != nil {
t.Fatal(err)
}
nftest.MatchRulesetBytes(t,
func(c *nftables.Conn) {
filter := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
})
c.FlushRuleset()
prerouting := c.AddChain(&nftables.Chain{
Name: "base-chain",
Table: filter,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityFilter,
})
filter := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
})
prerouting := c.AddChain(&nftables.Chain{
Name: "base-chain",
Table: filter,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityFilter,
})
c.AddRule(&nftables.Rule{
Table: filter,
Chain: prerouting,
Exprs: []expr.Any{
// [ notrack ]
&expr.Notrack{},
c.AddRule(&nftables.Rule{
Table: filter,
Chain: prerouting,
Exprs: []expr.Any{
// [ notrack ]
&expr.Notrack{},
},
})
},
})
if err := c.Flush(); err != nil {
t.Fatal(err)
}
want)
}
func TestQuota(t *testing.T) {