refactor common test code into package nftest
Converting more test functions to use it (and then splitting out test functions into their own files) is left for a follow-up commit.
This commit is contained in:
parent
33143dee49
commit
2719b9add1
|
@ -0,0 +1,133 @@
|
|||
// Package nftest contains utility functions for nftables testing.
|
||||
package nftest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/mdlayher/netlink"
|
||||
)
|
||||
|
||||
// Recorder provides an nftables connection that does not send to the Linux
|
||||
// kernel but instead records netlink messages into the recorder. The recorded
|
||||
// requests can later be obtained using Requests and compared using Diff.
|
||||
type Recorder struct {
|
||||
requests []netlink.Message
|
||||
}
|
||||
|
||||
// Conn opens an nftables connection that records netlink messages into the
|
||||
// Recorder.
|
||||
func (r *Recorder) Conn() (*nftables.Conn, error) {
|
||||
return nftables.New(nftables.WithTestDial(
|
||||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||||
r.requests = append(r.requests, req...)
|
||||
|
||||
// TODO: generate and return acknowledgements
|
||||
for _, msg := range req {
|
||||
log.Printf("msg: %+v", msg)
|
||||
}
|
||||
return req, nil
|
||||
}))
|
||||
}
|
||||
|
||||
// Requests returns the recorded netlink messages (typically nftables requests).
|
||||
func (r *Recorder) Requests() []netlink.Message {
|
||||
return r.requests
|
||||
}
|
||||
|
||||
// NewRecorder returns a ready-to-use Recorder.
|
||||
func NewRecorder() *Recorder {
|
||||
return &Recorder{}
|
||||
}
|
||||
|
||||
// Diff returns the first difference between the specified netlink messages and
|
||||
// the expected netlink message payloads.
|
||||
func Diff(got []netlink.Message, want [][]byte) string {
|
||||
for idx, msg := range got {
|
||||
b, err := msg.MarshalBinary()
|
||||
if err != nil {
|
||||
return fmt.Sprintf("msg.MarshalBinary: %v", err)
|
||||
}
|
||||
if len(b) < 16 {
|
||||
continue
|
||||
}
|
||||
b = b[16:]
|
||||
if len(want) == 0 {
|
||||
return fmt.Sprintf("no want entry for message %d: %x", idx, b)
|
||||
}
|
||||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||||
return fmt.Sprintf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||||
}
|
||||
want = want[1:]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// MatchRulesetBytes is a test helper that ensures the fillRuleset modifications
|
||||
// correspond to the provided want netlink message payloads
|
||||
func MatchRulesetBytes(t *testing.T, fillRuleset func(c *nftables.Conn), want [][]byte) {
|
||||
t.Helper()
|
||||
|
||||
rec := NewRecorder()
|
||||
|
||||
c, err := rec.Conn()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c.FlushRuleset()
|
||||
|
||||
fillRuleset(c)
|
||||
|
||||
if err := c.Flush(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := Diff(rec.Requests(), want); diff != "" {
|
||||
t.Errorf("unexpected netlink messages: diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
|
||||
// users to make sense of large byte literals more easily.
|
||||
func nfdump(b []byte) string {
|
||||
var buf bytes.Buffer
|
||||
i := 0
|
||||
for ; i < len(b); i += 4 {
|
||||
// TODO: show printable characters as ASCII
|
||||
fmt.Fprintf(&buf, "%02x %02x %02x %02x\n",
|
||||
b[i],
|
||||
b[i+1],
|
||||
b[i+2],
|
||||
b[i+3])
|
||||
}
|
||||
for ; i < len(b); i++ {
|
||||
fmt.Fprintf(&buf, "%02x ", b[i])
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// linediff returns a side-by-side diff of two nfdump() return values, flagging
|
||||
// lines which are not equal with an exclamation point prefix.
|
||||
func linediff(a, b string) string {
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprintf(&buf, "got -- want\n")
|
||||
linesA := strings.Split(a, "\n")
|
||||
linesB := strings.Split(b, "\n")
|
||||
for idx, lineA := range linesA {
|
||||
if idx >= len(linesB) {
|
||||
break
|
||||
}
|
||||
lineB := linesB[idx]
|
||||
prefix := "! "
|
||||
if lineA == lineB {
|
||||
prefix = " "
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB)
|
||||
}
|
||||
return buf.String()
|
||||
}
|
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/google/nftables/internal/nftest"
|
||||
"github.com/mdlayher/netlink"
|
||||
"github.com/vishvananda/netns"
|
||||
"golang.org/x/sys/unix"
|
||||
|
@ -5057,34 +5058,8 @@ func TestNotrack(t *testing.T) {
|
|||
[]byte("\x00\x00\x00\x0a"),
|
||||
}
|
||||
|
||||
c, err := nftables.New(nftables.WithTestDial(
|
||||
func(req []netlink.Message) ([]netlink.Message, error) {
|
||||
for idx, msg := range req {
|
||||
b, err := msg.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(b) < 16 {
|
||||
continue
|
||||
}
|
||||
b = b[16:]
|
||||
if len(want) == 0 {
|
||||
t.Errorf("no want entry for message %d: %x", idx, b)
|
||||
continue
|
||||
}
|
||||
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||||
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||||
}
|
||||
want = want[1:]
|
||||
}
|
||||
return req, nil
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c.FlushRuleset()
|
||||
|
||||
nftest.MatchRulesetBytes(t,
|
||||
func(c *nftables.Conn) {
|
||||
filter := c.AddTable(&nftables.Table{
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
Name: "filter",
|
||||
|
@ -5106,10 +5081,8 @@ func TestNotrack(t *testing.T) {
|
|||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
|
||||
if err := c.Flush(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want)
|
||||
}
|
||||
|
||||
func TestQuota(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue