From 487c56e8464c58e1cb9401b1590cddc5e5c96b2a Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Thu, 26 Dec 2019 12:51:35 +0100 Subject: [PATCH] test refactoring --- nftables_test.go | 785 ++++------------------------------------------- 1 file changed, 67 insertions(+), 718 deletions(-) diff --git a/nftables_test.go b/nftables_test.go index e3337ef..91c1cd0 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -29,6 +29,7 @@ import ( "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" "github.com/mdlayher/netlink" + "github.com/mdlayher/netlink/nltest" "github.com/vishvananda/netns" "golang.org/x/sys/unix" ) @@ -112,6 +113,38 @@ func cleanupSystemNFTConn(t *testing.T, newNS netns.NsHandle) { } } +func CheckNLReq(t *testing.T, wantMsg [][]byte, replies [][]netlink.Message) nltest.Func { + return func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + return req, err + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(wantMsg) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, wantMsg[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + + wantMsg = wantMsg[1:] + } + + if len(replies) > 0 { + rep := replies[0] + replies = replies[1:] + return rep, nil + } else { + return req, nil + } + } +} + func TestConfigureNAT(t *testing.T) { // The want byte sequences come from stracing nft(8), e.g.: // strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat @@ -140,27 +173,7 @@ func TestConfigureNAT(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -354,27 +367,7 @@ func TestConfigureNATSourceAddress(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -437,29 +430,7 @@ func TestGetRule(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - rep := reply[0] - reply = reply[1:] - return rep, nil - }, + TestDial: CheckNLReq(t, want, reply), } rules, err := c.GetRule( @@ -516,27 +487,7 @@ func TestAddCounter(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.AddObj(&nftables.CounterObj{ @@ -573,27 +524,7 @@ func TestDelRule(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.DelRule(&nftables.Rule{ @@ -618,27 +549,7 @@ func TestLog(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.AddRule(&nftables.Rule{ @@ -668,27 +579,7 @@ func TestTProxy(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.AddRule(&nftables.Rule{ @@ -731,27 +622,7 @@ func TestCt(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.AddRule(&nftables.Rule{ @@ -784,27 +655,7 @@ func TestCtSet(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.AddRule(&nftables.Rule{ @@ -843,27 +694,7 @@ func TestAddRuleWithPosition(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.AddRule(&nftables.Rule{ @@ -951,27 +782,7 @@ func TestAddChain(t *testing.T) { for _, tt := range tests { c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - }, + TestDial: CheckNLReq(t, tt.want, nil), } filter := c.AddTable(&nftables.Table{ @@ -1028,27 +839,7 @@ func TestDelChain(t *testing.T) { for _, tt := range tests { c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - }, + TestDial: CheckNLReq(t, tt.want, nil), } tt.chain.Table = &nftables.Table{ @@ -1078,29 +869,7 @@ func TestGetObjReset(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - rep := reply[0] - reply = reply[1:] - return rep, nil - }, + TestDial: CheckNLReq(t, want, reply), } filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4} @@ -1158,27 +927,7 @@ func TestConfigureClamping(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -1291,27 +1040,7 @@ func TestDropVerdict(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -1391,27 +1120,7 @@ func TestCreateUseAnonymousSet(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -2148,27 +1857,7 @@ func TestConfigureNATRedirect(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -2253,27 +1942,7 @@ func TestConfigureJumpVerdict(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -2359,27 +2028,7 @@ func TestConfigureReturnVerdict(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -2444,27 +2093,7 @@ func TestConfigureRangePort(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -2542,27 +2171,7 @@ func TestConfigureRangeIPv4(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -2632,27 +2241,7 @@ func TestConfigureRangeIPv6(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -2745,27 +2334,7 @@ func TestSet4(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } tbl := &nftables.Table{ @@ -2931,27 +2500,7 @@ func TestMasq(t *testing.T) { for _, tt := range tests { c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - }, + TestDial: CheckNLReq(t, tt.want, nil), } filter := c.AddTable(&nftables.Table{ @@ -3062,27 +2611,7 @@ func TestReject(t *testing.T) { for _, tt := range tests { c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - }, + TestDial: CheckNLReq(t, tt.want, nil), } filter := c.AddTable(&nftables.Table{ @@ -3190,27 +2719,7 @@ func TestFib(t *testing.T) { for _, tt := range tests { c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - }, + TestDial: CheckNLReq(t, tt.want, nil), } filter := c.AddTable(&nftables.Table{ @@ -3294,27 +2803,7 @@ func TestNumgen(t *testing.T) { for _, tt := range tests { c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - }, + TestDial: CheckNLReq(t, tt.want, nil), } filter := c.AddTable(&nftables.Table{ @@ -3379,27 +2868,7 @@ func TestMap(t *testing.T) { for _, tt := range tests { c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - }, + TestDial: CheckNLReq(t, tt.want, nil), } filter := c.AddTable(&nftables.Table{ @@ -3497,27 +2966,7 @@ func TestVmap(t *testing.T) { for _, tt := range tests { c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - }, + TestDial: CheckNLReq(t, tt.want, nil), } filter := c.AddTable(&nftables.Table{ @@ -3557,27 +3006,7 @@ func TestJHash(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -3658,27 +3087,7 @@ func TestDup(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -3758,27 +3167,7 @@ func TestDupWoDev(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -3840,27 +3229,7 @@ func TestNotrack(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset() @@ -3914,27 +3283,7 @@ func TestStatelessNAT(t *testing.T) { } c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - }, + TestDial: CheckNLReq(t, want, nil), } c.FlushRuleset()