nftables/internal/nftest/nftest.go

134 lines
3.3 KiB
Go

// Package nftest contains utility functions for nftables testing.
package nftest
import (
"bytes"
"fmt"
"log"
"strings"
"testing"
"github.com/google/nftables"
"github.com/mdlayher/netlink"
)
// Recorder provides an nftables connection that does not send to the Linux
// kernel but instead records netlink messages into the recorder. The recorded
// requests can later be obtained using Requests and compared using Diff.
type Recorder struct {
requests []netlink.Message
}
// Conn opens an nftables connection that records netlink messages into the
// Recorder.
func (r *Recorder) Conn() (*nftables.Conn, error) {
return nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
r.requests = append(r.requests, req...)
// TODO: generate and return acknowledgements
for _, msg := range req {
log.Printf("msg: %+v", msg)
}
return req, nil
}))
}
// Requests returns the recorded netlink messages (typically nftables requests).
func (r *Recorder) Requests() []netlink.Message {
return r.requests
}
// NewRecorder returns a ready-to-use Recorder.
func NewRecorder() *Recorder {
return &Recorder{}
}
// Diff returns the first difference between the specified netlink messages and
// the expected netlink message payloads.
func Diff(got []netlink.Message, want [][]byte) string {
for idx, msg := range got {
b, err := msg.MarshalBinary()
if err != nil {
return fmt.Sprintf("msg.MarshalBinary: %v", err)
}
if len(b) < 16 {
continue
}
b = b[16:]
if len(want) == 0 {
return fmt.Sprintf("no want entry for message %d: %x", idx, b)
}
if got, want := b, want[0]; !bytes.Equal(got, want) {
return fmt.Sprintf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
}
want = want[1:]
}
return ""
}
// MatchRulesetBytes is a test helper that ensures the fillRuleset modifications
// correspond to the provided want netlink message payloads
func MatchRulesetBytes(t *testing.T, fillRuleset func(c *nftables.Conn), want [][]byte) {
t.Helper()
rec := NewRecorder()
c, err := rec.Conn()
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
fillRuleset(c)
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if diff := Diff(rec.Requests(), want); diff != "" {
t.Errorf("unexpected netlink messages: diff: %s", diff)
}
}
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
// users to make sense of large byte literals more easily.
func nfdump(b []byte) string {
var buf bytes.Buffer
i := 0
for ; i < len(b); i += 4 {
// TODO: show printable characters as ASCII
fmt.Fprintf(&buf, "%02x %02x %02x %02x\n",
b[i],
b[i+1],
b[i+2],
b[i+3])
}
for ; i < len(b); i++ {
fmt.Fprintf(&buf, "%02x ", b[i])
}
return buf.String()
}
// linediff returns a side-by-side diff of two nfdump() return values, flagging
// lines which are not equal with an exclamation point prefix.
func linediff(a, b string) string {
var buf bytes.Buffer
fmt.Fprintf(&buf, "got -- want\n")
linesA := strings.Split(a, "\n")
linesB := strings.Split(b, "\n")
for idx, lineA := range linesA {
if idx >= len(linesB) {
break
}
lineB := linesB[idx]
prefix := "! "
if lineA == lineB {
prefix = " "
}
fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB)
}
return buf.String()
}