diff --git a/internal/nftest/nftest.go b/internal/nftest/nftest.go new file mode 100644 index 0000000..7651b4f --- /dev/null +++ b/internal/nftest/nftest.go @@ -0,0 +1,133 @@ +// Package nftest contains utility functions for nftables testing. +package nftest + +import ( + "bytes" + "fmt" + "log" + "strings" + "testing" + + "github.com/google/nftables" + "github.com/mdlayher/netlink" +) + +// Recorder provides an nftables connection that does not send to the Linux +// kernel but instead records netlink messages into the recorder. The recorded +// requests can later be obtained using Requests and compared using Diff. +type Recorder struct { + requests []netlink.Message +} + +// Conn opens an nftables connection that records netlink messages into the +// Recorder. +func (r *Recorder) Conn() (*nftables.Conn, error) { + return nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { + r.requests = append(r.requests, req...) + + // TODO: generate and return acknowledgements + for _, msg := range req { + log.Printf("msg: %+v", msg) + } + return req, nil + })) +} + +// Requests returns the recorded netlink messages (typically nftables requests). +func (r *Recorder) Requests() []netlink.Message { + return r.requests +} + +// NewRecorder returns a ready-to-use Recorder. +func NewRecorder() *Recorder { + return &Recorder{} +} + +// Diff returns the first difference between the specified netlink messages and +// the expected netlink message payloads. +func Diff(got []netlink.Message, want [][]byte) string { + for idx, msg := range got { + b, err := msg.MarshalBinary() + if err != nil { + return fmt.Sprintf("msg.MarshalBinary: %v", err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + return fmt.Sprintf("no want entry for message %d: %x", idx, b) + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + return fmt.Sprintf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + want = want[1:] + } + return "" +} + +// MatchRulesetBytes is a test helper that ensures the fillRuleset modifications +// correspond to the provided want netlink message payloads +func MatchRulesetBytes(t *testing.T, fillRuleset func(c *nftables.Conn), want [][]byte) { + t.Helper() + + rec := NewRecorder() + + c, err := rec.Conn() + if err != nil { + t.Fatal(err) + } + + c.FlushRuleset() + + fillRuleset(c) + + if err := c.Flush(); err != nil { + t.Fatal(err) + } + + if diff := Diff(rec.Requests(), want); diff != "" { + t.Errorf("unexpected netlink messages: diff: %s", diff) + } +} + +// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing +// users to make sense of large byte literals more easily. +func nfdump(b []byte) string { + var buf bytes.Buffer + i := 0 + for ; i < len(b); i += 4 { + // TODO: show printable characters as ASCII + fmt.Fprintf(&buf, "%02x %02x %02x %02x\n", + b[i], + b[i+1], + b[i+2], + b[i+3]) + } + for ; i < len(b); i++ { + fmt.Fprintf(&buf, "%02x ", b[i]) + } + return buf.String() +} + +// linediff returns a side-by-side diff of two nfdump() return values, flagging +// lines which are not equal with an exclamation point prefix. +func linediff(a, b string) string { + var buf bytes.Buffer + fmt.Fprintf(&buf, "got -- want\n") + linesA := strings.Split(a, "\n") + linesB := strings.Split(b, "\n") + for idx, lineA := range linesA { + if idx >= len(linesB) { + break + } + lineB := linesB[idx] + prefix := "! " + if lineA == lineB { + prefix = " " + } + fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB) + } + return buf.String() +} diff --git a/nftables_test.go b/nftables_test.go index a91c6fa..9609daf 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -30,6 +30,7 @@ import ( "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" + "github.com/google/nftables/internal/nftest" "github.com/mdlayher/netlink" "github.com/vishvananda/netns" "golang.org/x/sys/unix" @@ -5057,59 +5058,31 @@ func TestNotrack(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) - if err != nil { - t.Fatal(err) - } + nftest.MatchRulesetBytes(t, + func(c *nftables.Conn) { + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) - c.FlushRuleset() + prerouting := c.AddChain(&nftables.Chain{ + Name: "base-chain", + Table: filter, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityFilter, + }) - filter := c.AddTable(&nftables.Table{ - Family: nftables.TableFamilyIPv4, - Name: "filter", - }) - - prerouting := c.AddChain(&nftables.Chain{ - Name: "base-chain", - Table: filter, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookPrerouting, - Priority: nftables.ChainPriorityFilter, - }) - - c.AddRule(&nftables.Rule{ - Table: filter, - Chain: prerouting, - Exprs: []expr.Any{ - // [ notrack ] - &expr.Notrack{}, + c.AddRule(&nftables.Rule{ + Table: filter, + Chain: prerouting, + Exprs: []expr.Any{ + // [ notrack ] + &expr.Notrack{}, + }, + }) }, - }) - - if err := c.Flush(); err != nil { - t.Fatal(err) - } + want) } func TestQuota(t *testing.T) {