refactor common test code into package nftest
Converting more test functions to use it (and then splitting out test functions into their own files) is left for a follow-up commit.
This commit is contained in:
parent
33143dee49
commit
2719b9add1
|
@ -0,0 +1,133 @@
|
||||||
|
// Package nftest contains utility functions for nftables testing.
|
||||||
|
package nftest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/mdlayher/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Recorder provides an nftables connection that does not send to the Linux
|
||||||
|
// kernel but instead records netlink messages into the recorder. The recorded
|
||||||
|
// requests can later be obtained using Requests and compared using Diff.
|
||||||
|
type Recorder struct {
|
||||||
|
requests []netlink.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conn opens an nftables connection that records netlink messages into the
|
||||||
|
// Recorder.
|
||||||
|
func (r *Recorder) Conn() (*nftables.Conn, error) {
|
||||||
|
return nftables.New(nftables.WithTestDial(
|
||||||
|
func(req []netlink.Message) ([]netlink.Message, error) {
|
||||||
|
r.requests = append(r.requests, req...)
|
||||||
|
|
||||||
|
// TODO: generate and return acknowledgements
|
||||||
|
for _, msg := range req {
|
||||||
|
log.Printf("msg: %+v", msg)
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Requests returns the recorded netlink messages (typically nftables requests).
|
||||||
|
func (r *Recorder) Requests() []netlink.Message {
|
||||||
|
return r.requests
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRecorder returns a ready-to-use Recorder.
|
||||||
|
func NewRecorder() *Recorder {
|
||||||
|
return &Recorder{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Diff returns the first difference between the specified netlink messages and
|
||||||
|
// the expected netlink message payloads.
|
||||||
|
func Diff(got []netlink.Message, want [][]byte) string {
|
||||||
|
for idx, msg := range got {
|
||||||
|
b, err := msg.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("msg.MarshalBinary: %v", err)
|
||||||
|
}
|
||||||
|
if len(b) < 16 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b = b[16:]
|
||||||
|
if len(want) == 0 {
|
||||||
|
return fmt.Sprintf("no want entry for message %d: %x", idx, b)
|
||||||
|
}
|
||||||
|
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||||||
|
return fmt.Sprintf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||||||
|
}
|
||||||
|
want = want[1:]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatchRulesetBytes is a test helper that ensures the fillRuleset modifications
|
||||||
|
// correspond to the provided want netlink message payloads
|
||||||
|
func MatchRulesetBytes(t *testing.T, fillRuleset func(c *nftables.Conn), want [][]byte) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
rec := NewRecorder()
|
||||||
|
|
||||||
|
c, err := rec.Conn()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.FlushRuleset()
|
||||||
|
|
||||||
|
fillRuleset(c)
|
||||||
|
|
||||||
|
if err := c.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := Diff(rec.Requests(), want); diff != "" {
|
||||||
|
t.Errorf("unexpected netlink messages: diff: %s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
|
||||||
|
// users to make sense of large byte literals more easily.
|
||||||
|
func nfdump(b []byte) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
i := 0
|
||||||
|
for ; i < len(b); i += 4 {
|
||||||
|
// TODO: show printable characters as ASCII
|
||||||
|
fmt.Fprintf(&buf, "%02x %02x %02x %02x\n",
|
||||||
|
b[i],
|
||||||
|
b[i+1],
|
||||||
|
b[i+2],
|
||||||
|
b[i+3])
|
||||||
|
}
|
||||||
|
for ; i < len(b); i++ {
|
||||||
|
fmt.Fprintf(&buf, "%02x ", b[i])
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// linediff returns a side-by-side diff of two nfdump() return values, flagging
|
||||||
|
// lines which are not equal with an exclamation point prefix.
|
||||||
|
func linediff(a, b string) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
fmt.Fprintf(&buf, "got -- want\n")
|
||||||
|
linesA := strings.Split(a, "\n")
|
||||||
|
linesB := strings.Split(b, "\n")
|
||||||
|
for idx, lineA := range linesA {
|
||||||
|
if idx >= len(linesB) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
lineB := linesB[idx]
|
||||||
|
prefix := "! "
|
||||||
|
if lineA == lineB {
|
||||||
|
prefix = " "
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB)
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
|
@ -30,6 +30,7 @@ import (
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
|
"github.com/google/nftables/internal/nftest"
|
||||||
"github.com/mdlayher/netlink"
|
"github.com/mdlayher/netlink"
|
||||||
"github.com/vishvananda/netns"
|
"github.com/vishvananda/netns"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
@ -5057,34 +5058,8 @@ func TestNotrack(t *testing.T) {
|
||||||
[]byte("\x00\x00\x00\x0a"),
|
[]byte("\x00\x00\x00\x0a"),
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := nftables.New(nftables.WithTestDial(
|
nftest.MatchRulesetBytes(t,
|
||||||
func(req []netlink.Message) ([]netlink.Message, error) {
|
func(c *nftables.Conn) {
|
||||||
for idx, msg := range req {
|
|
||||||
b, err := msg.MarshalBinary()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if len(b) < 16 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
b = b[16:]
|
|
||||||
if len(want) == 0 {
|
|
||||||
t.Errorf("no want entry for message %d: %x", idx, b)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
|
||||||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
|
||||||
}
|
|
||||||
want = want[1:]
|
|
||||||
}
|
|
||||||
return req, nil
|
|
||||||
}))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.FlushRuleset()
|
|
||||||
|
|
||||||
filter := c.AddTable(&nftables.Table{
|
filter := c.AddTable(&nftables.Table{
|
||||||
Family: nftables.TableFamilyIPv4,
|
Family: nftables.TableFamilyIPv4,
|
||||||
Name: "filter",
|
Name: "filter",
|
||||||
|
@ -5106,10 +5081,8 @@ func TestNotrack(t *testing.T) {
|
||||||
&expr.Notrack{},
|
&expr.Notrack{},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
},
|
||||||
if err := c.Flush(); err != nil {
|
want)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQuota(t *testing.T) {
|
func TestQuota(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue