134 lines
3.3 KiB
Go
134 lines
3.3 KiB
Go
|
// Package nftest contains utility functions for nftables testing.
|
||
|
package nftest
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"fmt"
|
||
|
"log"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/google/nftables"
|
||
|
"github.com/mdlayher/netlink"
|
||
|
)
|
||
|
|
||
|
// Recorder provides an nftables connection that does not send to the Linux
|
||
|
// kernel but instead records netlink messages into the recorder. The recorded
|
||
|
// requests can later be obtained using Requests and compared using Diff.
|
||
|
type Recorder struct {
|
||
|
requests []netlink.Message
|
||
|
}
|
||
|
|
||
|
// Conn returns an nftables connection wired to a test dialer instead of a
// real netlink socket. Each batch the connection would send is appended to
// the Recorder, each message is logged, and the same messages are returned
// to the connection as its replies.
func (r *Recorder) Conn() (*nftables.Conn, error) {
	record := func(req []netlink.Message) ([]netlink.Message, error) {
		r.requests = append(r.requests, req...)

		// TODO: generate and return acknowledgements
		for i := range req {
			log.Printf("msg: %+v", req[i])
		}
		// Echo the requests back in place of real kernel replies.
		return req, nil
	}
	return nftables.New(nftables.WithTestDial(record))
}
|
||
|
|
||
|
// Requests returns the recorded netlink messages (typically nftables requests).
|
||
|
func (r *Recorder) Requests() []netlink.Message {
|
||
|
return r.requests
|
||
|
}
|
||
|
|
||
|
// NewRecorder returns a ready-to-use Recorder.
|
||
|
func NewRecorder() *Recorder {
|
||
|
return &Recorder{}
|
||
|
}
|
||
|
|
||
|
// Diff returns the first difference between the specified netlink messages and
|
||
|
// the expected netlink message payloads.
|
||
|
func Diff(got []netlink.Message, want [][]byte) string {
|
||
|
for idx, msg := range got {
|
||
|
b, err := msg.MarshalBinary()
|
||
|
if err != nil {
|
||
|
return fmt.Sprintf("msg.MarshalBinary: %v", err)
|
||
|
}
|
||
|
if len(b) < 16 {
|
||
|
continue
|
||
|
}
|
||
|
b = b[16:]
|
||
|
if len(want) == 0 {
|
||
|
return fmt.Sprintf("no want entry for message %d: %x", idx, b)
|
||
|
}
|
||
|
if got, want := b, want[0]; !bytes.Equal(got, want) {
|
||
|
return fmt.Sprintf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
|
||
|
}
|
||
|
want = want[1:]
|
||
|
}
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
// MatchRulesetBytes is a test helper. On a recording connection it calls
// FlushRuleset, then fillRuleset, then Flush. It then reports a test error if
// the recorded netlink payloads differ from want. Setup failures abort the
// test via t.Fatal.
func MatchRulesetBytes(t *testing.T, fillRuleset func(c *nftables.Conn), want [][]byte) {
	t.Helper()

	rec := NewRecorder()
	conn, err := rec.Conn()
	if err != nil {
		t.Fatal(err)
	}

	conn.FlushRuleset()
	fillRuleset(conn)
	if err = conn.Flush(); err != nil {
		t.Fatal(err)
	}

	diff := Diff(rec.Requests(), want)
	if diff == "" {
		return
	}
	t.Errorf("unexpected netlink messages: diff: %s", diff)
}
|
||
|
|
||
|
// nfdump returns a hexdump of b with 4 bytes per line (like nft --debug=all),
// allowing users to make sense of large byte literals more easily. Any
// trailing bytes that do not fill a complete 4-byte group are written on a
// final line, space-separated and without a terminating newline.
func nfdump(b []byte) string {
	var buf bytes.Buffer
	i := 0
	// Only consume complete 4-byte groups here: indexing b[i+3] on a short
	// tail would panic for lengths that are not a multiple of 4.
	for ; i+4 <= len(b); i += 4 {
		// TODO: show printable characters as ASCII
		fmt.Fprintf(&buf, "%02x %02x %02x %02x\n", b[i], b[i+1], b[i+2], b[i+3])
	}
	for ; i < len(b); i++ {
		fmt.Fprintf(&buf, "%02x ", b[i])
	}
	return buf.String()
}
|
||
|
|
||
|
// linediff returns a side-by-side diff of two nfdump() return values, one
// "got -- want" pair per line. Lines that differ, or exist on only one side,
// are flagged with a "! " prefix; matching lines get a two-space prefix so
// the columns stay aligned. A line missing from one side is shown as empty.
func linediff(a, b string) string {
	var buf bytes.Buffer
	fmt.Fprintf(&buf, "got -- want\n")
	linesA := strings.Split(a, "\n")
	linesB := strings.Split(b, "\n")

	// Walk to the longer of the two so extra or missing trailing lines are
	// reported instead of silently dropped.
	n := len(linesA)
	if len(linesB) > n {
		n = len(linesB)
	}
	for idx := 0; idx < n; idx++ {
		var lineA, lineB string
		hasA := idx < len(linesA)
		hasB := idx < len(linesB)
		if hasA {
			lineA = linesA[idx]
		}
		if hasB {
			lineB = linesB[idx]
		}
		prefix := "! "
		if hasA && hasB && lineA == lineB {
			prefix = "  "
		}
		fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB)
	}
	return buf.String()
}
|