From 9a2862f48b7868545da43839a36af73ff59dc9d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Sch=C3=A4r?= Date: Tue, 25 Mar 2025 17:03:44 +0100 Subject: [PATCH] Receive replies in Flush (#309) Commit 0d9bfa4d18da added code to handle "overrun", but the commit is very misleading. NLMSG_OVERRUN is in fact not a flag, but a complete message type, so the (re&netlink.Overrun) masking makes no sense. Even better, NLMSG_OVERRUN is never actually used by Linux. The actual bug which the commit was attempting to fix is that Flush was not receiving replies which the kernel sent for messages with the echo flag. This change reverts that commit and instead adds code in Flush to receive the replies. I updated tests which simulate the kernel to generate replies. --- conn.go | 66 ++- internal/nftest/nftest.go | 17 +- nftables_test.go | 908 +++----------------------------------- rule.go | 10 +- 4 files changed, 133 insertions(+), 868 deletions(-) diff --git a/conn.go b/conn.go index d4759b1..99baa2d 100644 --- a/conn.go +++ b/conn.go @@ -171,24 +171,6 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([] return reply, nil } - if len(reply) != 0 { - last := reply[len(reply)-1] - for re := last.Header.Type; (re&netlink.Overrun) == netlink.Overrun && (re&netlink.Done) != netlink.Done; re = last.Header.Type { - // we are not finished, the message is overrun - r, err := nlconn.Receive() - if err != nil { - return nil, err - } - reply = append(reply, r...) - last = reply[len(reply)-1] - } - - if last.Header.Type == netlink.Error && binaryutil.BigEndian.Uint32(last.Data[:4]) == 0 { - // we have already collected an ack - return reply, nil - } - } - // Now we expect an ack ack, err := nlconn.Receive() if err != nil { @@ -196,8 +178,7 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([] } if len(ack) == 0 { - // received an empty ack? - return reply, nil + return nil, errors.New("received an empty ack") } msg := ack[0] @@ -263,15 +244,49 @@ func (cc *Conn) Flush() error { } defer func() { _ = closer() }() - if _, err := conn.SendMessages(batch(cc.messages)); err != nil { + messages, err := conn.SendMessages(batch(cc.messages)) + if err != nil { return fmt.Errorf("SendMessages: %w", err) } var errs error + + // Fetch replies. Each message with the Echo flag triggers a reply of the same + // type. Additionally, if the first message of the batch has the Echo flag, we + // get a reply of type NFT_MSG_NEWGEN, which we ignore. + replyIndex := 0 + for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 { + replyIndex++ + } + replies, err := conn.Receive() + for err == nil && len(replies) != 0 { + reply := replies[0] + if reply.Header.Type == netlink.Error && reply.Header.Sequence == messages[1].Header.Sequence { + // The next message is the acknowledgement for the first message in the + // batch; stop looking for replies. + break + } else if replyIndex < len(cc.messages) { + msg := messages[replyIndex+1] + if msg.Header.Sequence == reply.Header.Sequence && msg.Header.Type == reply.Header.Type { + replyIndex++ + for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 { + replyIndex++ + } + } + } + replies = replies[1:] + if len(replies) == 0 { + replies, err = conn.Receive() + } + } + // Fetch the requested acknowledgement for each message we sent. - for _, msg := range cc.messages { - if _, err := receiveAckAware(conn, msg.Header.Flags); err != nil { - if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) { + for i := range cc.messages { + if i != 0 { + _, err = conn.Receive() + } + if err != nil { + if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) || errors.Is(err, syscall.ENOMEM) { // Kernel will only send one error to user space. return err } @@ -282,6 +297,9 @@ func (cc *Conn) Flush() error { if errs != nil { return fmt.Errorf("conn.Receive: %w", errs) } + if replyIndex < len(cc.messages) { + return fmt.Errorf("missing reply for message %d in batch", replyIndex) + } return nil } diff --git a/internal/nftest/nftest.go b/internal/nftest/nftest.go index 2709fa7..8d5b496 100644 --- a/internal/nftest/nftest.go +++ b/internal/nftest/nftest.go @@ -25,10 +25,21 @@ func (r *Recorder) Conn() (*nftables.Conn, error) { func(req []netlink.Message) ([]netlink.Message, error) { r.requests = append(r.requests, req...) - acks := make([]netlink.Message, 0, len(req)) + replies := make([]netlink.Message, 0, len(req)) + // Generate replies. + for _, msg := range req { + if msg.Header.Flags&netlink.Echo != 0 { + data := append([]byte{}, msg.Data...) + replies = append(replies, netlink.Message{ + Header: msg.Header, + Data: data, + }) + } + } + // Generate acknowledgements. for _, msg := range req { if msg.Header.Flags&netlink.Acknowledge != 0 { - acks = append(acks, netlink.Message{ + replies = append(replies, netlink.Message{ Header: netlink.Header{ Length: 4, Type: netlink.Error, @@ -39,7 +50,7 @@ func (r *Recorder) Conn() (*nftables.Conn, error) { }) } } - return acks, nil + return replies, nil })) } diff --git a/nftables_test.go b/nftables_test.go index 69855b6..6e9ecfe 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -79,6 +79,40 @@ func linediff(a, b string) string { return buf.String() } +func expectMessages(t *testing.T, want [][]byte) nftables.ConnOption { + return nftables.WithTestDial(func(req []netlink.Message) ([]netlink.Message, error) { + var replies []netlink.Message + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + want = want[1:] + + // Generate replies. + if msg.Header.Flags&netlink.Echo != 0 { + data := append([]byte{}, msg.Data...) + replies = append(replies, netlink.Message{ + Header: msg.Header, + Data: data, + }) + } + } + return replies, nil + }) +} + func ifname(n string) []byte { b := make([]byte, 16) copy(b, []byte(n+"\x00")) @@ -370,28 +404,7 @@ func TestConfigureNAT(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -638,28 +651,7 @@ func TestConfigureNATSourceAddress(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1106,28 +1098,7 @@ func TestAddCounter(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1172,28 +1143,7 @@ func TestDeleteCounter(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1225,28 +1175,7 @@ func TestDelRule(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1272,28 +1201,7 @@ func TestLog(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1324,28 +1232,7 @@ func TestTProxy(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1389,28 +1276,7 @@ func TestTProxyWithAddrField(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1457,28 +1323,7 @@ func TestCt(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1518,28 +1363,7 @@ func TestSecMarkMarshaling(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - conn, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + conn, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1922,28 +1746,7 @@ func TestCtSet(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1982,28 +1785,7 @@ func TestCtStat(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -2042,28 +1824,7 @@ func TestAddRuleWithPosition(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -2472,28 +2233,7 @@ func TestAddChain(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -2551,28 +2291,7 @@ func TestDelChain(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -3284,28 +3003,7 @@ func TestConfigureClamping(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -3419,28 +3117,7 @@ func TestMatchPacketHeader(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -3549,28 +3226,7 @@ func TestDropVerdict(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -3651,28 +3307,7 @@ func TestCreateUseAnonymousSet(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5431,28 +5066,7 @@ func TestConfigureNATRedirect(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5538,28 +5152,7 @@ func TestConfigureJumpVerdict(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5646,28 +5239,7 @@ func TestConfigureReturnVerdict(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5733,28 +5305,7 @@ func TestConfigureRangePort(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5833,28 +5384,7 @@ func TestConfigureRangeIPv4(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5925,28 +5455,7 @@ func TestConfigureRangeIPv6(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -6039,28 +5548,7 @@ func TestSet4(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -6144,28 +5632,7 @@ func TestSetComment(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -6289,28 +5756,7 @@ func TestMasq(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6422,28 +5868,7 @@ func TestReject(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6552,28 +5977,7 @@ func TestFib(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6713,28 +6117,7 @@ func TestNumgen(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6800,28 +6183,7 @@ func TestMap(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6920,28 +6282,7 @@ func TestVmap(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6982,28 +6323,7 @@ func TestJHash(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -7084,28 +6404,7 @@ func TestDup(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -7184,28 +6483,7 @@ func TestDupWoDev(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -7311,28 +6589,7 @@ func TestQuota(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -7388,28 +6645,7 @@ func TestStatelessNAT(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } diff --git a/rule.go b/rule.go index 54146c1..68da81f 100644 --- a/rule.go +++ b/rule.go @@ -94,7 +94,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) { message := netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE), - Flags: netlink.Request | netlink.Acknowledge | netlink.Dump | unix.NLM_F_ECHO, + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, }, Data: append(extraHeader(uint8(t.Family), 0), data...), } @@ -164,20 +164,20 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { msgData := []byte{} msgData = append(msgData, data...) - var flags netlink.HeaderFlags if r.UserData != nil { msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_USERDATA, Data: r.UserData}, })...) } + var flags netlink.HeaderFlags switch op { case operationAdd: - flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND + flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo | netlink.Append case operationInsert: - flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO + flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo case operationReplace: - flags = netlink.Request | netlink.Acknowledge | netlink.Replace | unix.NLM_F_ECHO | unix.NLM_F_REPLACE + flags = netlink.Request | netlink.Acknowledge | netlink.Replace } if r.Position != 0 || (r.Flags&(1<