diff --git a/conn.go b/conn.go index d4759b1..99baa2d 100644 --- a/conn.go +++ b/conn.go @@ -171,24 +171,6 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([] return reply, nil } - if len(reply) != 0 { - last := reply[len(reply)-1] - for re := last.Header.Type; (re&netlink.Overrun) == netlink.Overrun && (re&netlink.Done) != netlink.Done; re = last.Header.Type { - // we are not finished, the message is overrun - r, err := nlconn.Receive() - if err != nil { - return nil, err - } - reply = append(reply, r...) - last = reply[len(reply)-1] - } - - if last.Header.Type == netlink.Error && binaryutil.BigEndian.Uint32(last.Data[:4]) == 0 { - // we have already collected an ack - return reply, nil - } - } - // Now we expect an ack ack, err := nlconn.Receive() if err != nil { @@ -196,8 +178,7 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([] } if len(ack) == 0 { - // received an empty ack? - return reply, nil + return nil, errors.New("received an empty ack") } msg := ack[0] @@ -263,15 +244,49 @@ func (cc *Conn) Flush() error { } defer func() { _ = closer() }() - if _, err := conn.SendMessages(batch(cc.messages)); err != nil { + messages, err := conn.SendMessages(batch(cc.messages)) + if err != nil { return fmt.Errorf("SendMessages: %w", err) } var errs error + + // Fetch replies. Each message with the Echo flag triggers a reply of the same + // type. Additionally, if the first message of the batch has the Echo flag, we + // get a reply of type NFT_MSG_NEWGEN, which we ignore. + replyIndex := 0 + for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 { + replyIndex++ + } + replies, err := conn.Receive() + for err == nil && len(replies) != 0 { + reply := replies[0] + if reply.Header.Type == netlink.Error && reply.Header.Sequence == messages[1].Header.Sequence { + // The next message is the acknowledgement for the first message in the + // batch; stop looking for replies. + break + } else if replyIndex < len(cc.messages) { + msg := messages[replyIndex+1] + if msg.Header.Sequence == reply.Header.Sequence && msg.Header.Type == reply.Header.Type { + replyIndex++ + for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 { + replyIndex++ + } + } + } + replies = replies[1:] + if len(replies) == 0 { + replies, err = conn.Receive() + } + } + // Fetch the requested acknowledgement for each message we sent. - for _, msg := range cc.messages { - if _, err := receiveAckAware(conn, msg.Header.Flags); err != nil { - if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) { + for i := range cc.messages { + if i != 0 { + _, err = conn.Receive() + } + if err != nil { + if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) || errors.Is(err, syscall.ENOMEM) { // Kernel will only send one error to user space. return err } @@ -282,6 +297,9 @@ func (cc *Conn) Flush() error { if errs != nil { return fmt.Errorf("conn.Receive: %w", errs) } + if replyIndex < len(cc.messages) { + return fmt.Errorf("missing reply for message %d in batch", replyIndex) + } return nil } diff --git a/internal/nftest/nftest.go b/internal/nftest/nftest.go index 2709fa7..8d5b496 100644 --- a/internal/nftest/nftest.go +++ b/internal/nftest/nftest.go @@ -25,10 +25,21 @@ func (r *Recorder) Conn() (*nftables.Conn, error) { func(req []netlink.Message) ([]netlink.Message, error) { r.requests = append(r.requests, req...) - acks := make([]netlink.Message, 0, len(req)) + replies := make([]netlink.Message, 0, len(req)) + // Generate replies. + for _, msg := range req { + if msg.Header.Flags&netlink.Echo != 0 { + data := append([]byte{}, msg.Data...) + replies = append(replies, netlink.Message{ + Header: msg.Header, + Data: data, + }) + } + } + // Generate acknowledgements. for _, msg := range req { if msg.Header.Flags&netlink.Acknowledge != 0 { - acks = append(acks, netlink.Message{ + replies = append(replies, netlink.Message{ Header: netlink.Header{ Length: 4, Type: netlink.Error, @@ -39,7 +50,7 @@ func (r *Recorder) Conn() (*nftables.Conn, error) { }) } } - return acks, nil + return replies, nil })) } diff --git a/nftables_test.go b/nftables_test.go index 69855b6..6e9ecfe 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -79,6 +79,40 @@ func linediff(a, b string) string { return buf.String() } +func expectMessages(t *testing.T, want [][]byte) nftables.ConnOption { + return nftables.WithTestDial(func(req []netlink.Message) ([]netlink.Message, error) { + var replies []netlink.Message + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + want = want[1:] + + // Generate replies. + if msg.Header.Flags&netlink.Echo != 0 { + data := append([]byte{}, msg.Data...) + replies = append(replies, netlink.Message{ + Header: msg.Header, + Data: data, + }) + } + } + return replies, nil + }) +} + func ifname(n string) []byte { b := make([]byte, 16) copy(b, []byte(n+"\x00")) @@ -370,28 +404,7 @@ func TestConfigureNAT(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -638,28 +651,7 @@ func TestConfigureNATSourceAddress(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1106,28 +1098,7 @@ func TestAddCounter(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1172,28 +1143,7 @@ func TestDeleteCounter(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1225,28 +1175,7 @@ func TestDelRule(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1272,28 +1201,7 @@ func TestLog(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1324,28 +1232,7 @@ func TestTProxy(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1389,28 +1276,7 @@ func TestTProxyWithAddrField(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1457,28 +1323,7 @@ func TestCt(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1518,28 +1363,7 @@ func TestSecMarkMarshaling(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - conn, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + conn, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1922,28 +1746,7 @@ func TestCtSet(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1982,28 +1785,7 @@ func TestCtStat(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -2042,28 +1824,7 @@ func TestAddRuleWithPosition(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -2472,28 +2233,7 @@ func TestAddChain(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -2551,28 +2291,7 @@ func TestDelChain(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -3284,28 +3003,7 @@ func TestConfigureClamping(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -3419,28 +3117,7 @@ func TestMatchPacketHeader(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -3549,28 +3226,7 @@ func TestDropVerdict(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -3651,28 +3307,7 @@ func TestCreateUseAnonymousSet(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5431,28 +5066,7 @@ func TestConfigureNATRedirect(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5538,28 +5152,7 @@ func TestConfigureJumpVerdict(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5646,28 +5239,7 @@ func TestConfigureReturnVerdict(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5733,28 +5305,7 @@ func TestConfigureRangePort(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5833,28 +5384,7 @@ func TestConfigureRangeIPv4(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5925,28 +5455,7 @@ func TestConfigureRangeIPv6(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -6039,28 +5548,7 @@ func TestSet4(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -6144,28 +5632,7 @@ func TestSetComment(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -6289,28 +5756,7 @@ func TestMasq(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6422,28 +5868,7 @@ func TestReject(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6552,28 +5977,7 @@ func TestFib(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6713,28 +6117,7 @@ func TestNumgen(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6800,28 +6183,7 @@ func TestMap(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6920,28 +6282,7 @@ func TestVmap(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6982,28 +6323,7 @@ func TestJHash(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -7084,28 +6404,7 @@ func TestDup(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -7184,28 +6483,7 @@ func TestDupWoDev(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -7311,28 +6589,7 @@ func TestQuota(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -7388,28 +6645,7 @@ func TestStatelessNAT(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } diff --git a/rule.go b/rule.go index 54146c1..68da81f 100644 --- a/rule.go +++ b/rule.go @@ -94,7 +94,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) { message := netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE), - Flags: netlink.Request | netlink.Acknowledge | netlink.Dump | unix.NLM_F_ECHO, + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, }, Data: append(extraHeader(uint8(t.Family), 0), data...), } @@ -164,20 +164,20 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { msgData := []byte{} msgData = append(msgData, data...) - var flags netlink.HeaderFlags if r.UserData != nil { msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_USERDATA, Data: r.UserData}, })...) } + var flags netlink.HeaderFlags switch op { case operationAdd: - flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND + flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo | netlink.Append case operationInsert: - flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO + flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo case operationReplace: - flags = netlink.Request | netlink.Acknowledge | netlink.Replace | unix.NLM_F_ECHO | unix.NLM_F_REPLACE + flags = netlink.Request | netlink.Acknowledge | netlink.Replace } if r.Position != 0 || (r.Flags&(1<