diff --git a/chain.go b/chain.go index 4f4c0a5..f1853cf 100644 --- a/chain.go +++ b/chain.go @@ -140,7 +140,7 @@ func (cc *Conn) AddChain(c *Chain) *Chain { {Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")}, })...) } - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -161,7 +161,7 @@ func (cc *Conn) DelChain(c *Chain) { {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN), Flags: netlink.Request | netlink.Acknowledge, @@ -179,7 +179,7 @@ func (cc *Conn) FlushChain(c *Chain) { {Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge, diff --git a/conn.go b/conn.go index d4759b1..d974b80 100644 --- a/conn.go +++ b/conn.go @@ -41,7 +41,7 @@ type Conn struct { lasting bool // establish a lasting connection to be used across multiple netlink operations. mu sync.Mutex // protects the following state - messages []netlink.Message + messages []netlinkMessage err error nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. sockOptions []SockOption @@ -49,6 +49,12 @@ type Conn struct { allocatedIDs uint32 } +type netlinkMessage struct { + Header netlink.Header + Data []byte + rule *Rule +} + // ConnOption is an option to change the behavior of the nftables Conn returned by Open. type ConnOption func(*Conn) @@ -171,24 +177,6 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([] return reply, nil } - if len(reply) != 0 { - last := reply[len(reply)-1] - for re := last.Header.Type; (re&netlink.Overrun) == netlink.Overrun && (re&netlink.Done) != netlink.Done; re = last.Header.Type { - // we are not finished, the message is overrun - r, err := nlconn.Receive() - if err != nil { - return nil, err - } - reply = append(reply, r...) - last = reply[len(reply)-1] - } - - if last.Header.Type == netlink.Error && binaryutil.BigEndian.Uint32(last.Data[:4]) == 0 { - // we have already collected an ack - return reply, nil - } - } - // Now we expect an ack ack, err := nlconn.Receive() if err != nil { @@ -196,8 +184,7 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([] } if len(ack) == 0 { - // received an empty ack? - return reply, nil + return nil, errors.New("received an empty ack") } msg := ack[0] @@ -263,15 +250,54 @@ func (cc *Conn) Flush() error { } defer func() { _ = closer() }() - if _, err := conn.SendMessages(batch(cc.messages)); err != nil { + messages, err := conn.SendMessages(batch(cc.messages)) + if err != nil { return fmt.Errorf("SendMessages: %w", err) } var errs error + + // Fetch replies. Each message with the Echo flag triggers a reply of the same + // type. Additionally, if the first message of the batch has the Echo flag, we + // get a reply of type NFT_MSG_NEWGEN, which we ignore. + replyIndex := 0 + for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 { + replyIndex++ + } + replies, err := conn.Receive() + for err == nil && len(replies) != 0 { + reply := replies[0] + if reply.Header.Type == netlink.Error && reply.Header.Sequence == messages[1].Header.Sequence { + // The next message is the acknowledgement for the first message in the + // batch; stop looking for replies. + break + } else if replyIndex < len(cc.messages) { + msg := messages[replyIndex+1] + if msg.Header.Sequence == reply.Header.Sequence && msg.Header.Type == reply.Header.Type { + // The only messages which set the echo flag are rule create messages. + err := cc.messages[replyIndex].rule.handleCreateReply(reply) + if err != nil { + errs = errors.Join(errs, err) + } + replyIndex++ + for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 { + replyIndex++ + } + } + } + replies = replies[1:] + if len(replies) == 0 { + replies, err = conn.Receive() + } + } + // Fetch the requested acknowledgement for each message we sent. - for _, msg := range cc.messages { - if _, err := receiveAckAware(conn, msg.Header.Flags); err != nil { - if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) { + for i := range cc.messages { + if i != 0 { + _, err = conn.Receive() + } + if err != nil { + if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) || errors.Is(err, syscall.ENOMEM) { // Kernel will only send one error to user space. return err } @@ -282,6 +308,9 @@ func (cc *Conn) Flush() error { if errs != nil { return fmt.Errorf("conn.Receive: %w", errs) } + if replyIndex < len(cc.messages) { + return fmt.Errorf("missing reply for message %d in batch", replyIndex) + } return nil } @@ -291,7 +320,7 @@ func (cc *Conn) Flush() error { func (cc *Conn) FlushRuleset() { cc.mu.Lock() defer cc.mu.Unlock() - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -350,26 +379,30 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte { return b } -func batch(messages []netlink.Message) []netlink.Message { - batch := []netlink.Message{ - { - Header: netlink.Header{ - Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), - Flags: netlink.Request, - }, - Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), +func batch(messages []netlinkMessage) []netlink.Message { + batch := make([]netlink.Message, len(messages)+2) + batch[0] = netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), + Flags: netlink.Request, }, + Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), } - batch = append(batch, messages...) + for i, msg := range messages { + batch[i+1] = netlink.Message{ + Header: msg.Header, + Data: msg.Data, + } + } - batch = append(batch, netlink.Message{ + batch[len(messages)+1] = netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Flags: netlink.Request, }, Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), - }) + } return batch } diff --git a/flowtable.go b/flowtable.go index 93dbcb5..a35712f 100644 --- a/flowtable.go +++ b/flowtable.go @@ -142,7 +142,7 @@ func (cc *Conn) AddFlowtable(f *Flowtable) *Flowtable { {Type: unix.NLA_F_NESTED | NFTA_FLOWTABLE_HOOK, Data: cc.marshalAttr(hookAttr)}, })...) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_NEWFLOWTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -162,7 +162,7 @@ func (cc *Conn) DelFlowtable(f *Flowtable) { {Type: NFTA_FLOWTABLE_NAME, Data: []byte(f.Name)}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_DELFLOWTABLE), Flags: netlink.Request | netlink.Acknowledge, diff --git a/internal/nftest/nftest.go b/internal/nftest/nftest.go index 2709fa7..509bac3 100644 --- a/internal/nftest/nftest.go +++ b/internal/nftest/nftest.go @@ -8,7 +8,9 @@ import ( "testing" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" ) // Recorder provides an nftables connection that does not send to the Linux @@ -21,14 +23,34 @@ type Recorder struct { // Conn opens an nftables connection that records netlink messages into the // Recorder. func (r *Recorder) Conn() (*nftables.Conn, error) { + nextHandle := uint64(1) return nftables.New(nftables.WithTestDial( func(req []netlink.Message) ([]netlink.Message, error) { r.requests = append(r.requests, req...) - acks := make([]netlink.Message, 0, len(req)) + replies := make([]netlink.Message, 0, len(req)) + // Generate replies. + for _, msg := range req { + if msg.Header.Flags&netlink.Echo != 0 { + data := append([]byte{}, msg.Data...) + switch msg.Header.Type { + case netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE): + attrs, _ := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(nextHandle)}, + }) + nextHandle++ + data = append(data, attrs...) + } + replies = append(replies, netlink.Message{ + Header: msg.Header, + Data: data, + }) + } + } + // Generate acknowledgements. for _, msg := range req { if msg.Header.Flags&netlink.Acknowledge != 0 { - acks = append(acks, netlink.Message{ + replies = append(replies, netlink.Message{ Header: netlink.Header{ Length: 4, Type: netlink.Error, @@ -39,7 +61,7 @@ func (r *Recorder) Conn() (*nftables.Conn, error) { }) } } - return acks, nil + return replies, nil })) } diff --git a/nftables_test.go b/nftables_test.go index 69855b6..fd0fa1a 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -79,6 +79,49 @@ func linediff(a, b string) string { return buf.String() } +func expectMessages(t *testing.T, want [][]byte) nftables.ConnOption { + nextHandle := uint64(1) + return nftables.WithTestDial(func(req []netlink.Message) ([]netlink.Message, error) { + var replies []netlink.Message + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + want = want[1:] + + // Generate replies. + if msg.Header.Flags&netlink.Echo != 0 { + data := append([]byte{}, msg.Data...) + switch msg.Header.Type { + case netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE): + attrs, _ := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(nextHandle)}, + }) + nextHandle++ + data = append(data, attrs...) + } + replies = append(replies, netlink.Message{ + Header: msg.Header, + Data: data, + }) + } + } + return replies, nil + }) +} + func ifname(n string) []byte { b := make([]byte, 16) copy(b, []byte(n+"\x00")) @@ -282,7 +325,7 @@ func TestRuleHandle(t *testing.T) { } for _, tt := range tests { - for _, afterFlush := range []bool{false} { + for _, afterFlush := range []bool{false, true} { flushName := map[bool]string{ false: "-before-flush", true: "-after-flush", @@ -370,28 +413,7 @@ func TestConfigureNAT(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -638,28 +660,7 @@ func TestConfigureNATSourceAddress(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1106,28 +1107,7 @@ func TestAddCounter(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1172,28 +1152,7 @@ func TestDeleteCounter(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1225,28 +1184,7 @@ func TestDelRule(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1272,28 +1210,7 @@ func TestLog(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1324,28 +1241,7 @@ func TestTProxy(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1389,28 +1285,7 @@ func TestTProxyWithAddrField(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1457,28 +1332,7 @@ func TestCt(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1518,28 +1372,7 @@ func TestSecMarkMarshaling(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - conn, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + conn, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1922,28 +1755,7 @@ func TestCtSet(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -1982,28 +1794,7 @@ func TestCtStat(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -2042,28 +1833,7 @@ func TestAddRuleWithPosition(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -2472,28 +2242,7 @@ func TestAddChain(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -2551,28 +2300,7 @@ func TestDelChain(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -3284,28 +3012,7 @@ func TestConfigureClamping(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -3419,28 +3126,7 @@ func TestMatchPacketHeader(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -3549,28 +3235,7 @@ func TestDropVerdict(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -3651,28 +3316,7 @@ func TestCreateUseAnonymousSet(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5431,28 +5075,7 @@ func TestConfigureNATRedirect(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5538,28 +5161,7 @@ func TestConfigureJumpVerdict(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5646,28 +5248,7 @@ func TestConfigureReturnVerdict(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5733,28 +5314,7 @@ func TestConfigureRangePort(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5833,28 +5393,7 @@ func TestConfigureRangeIPv4(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -5925,28 +5464,7 @@ func TestConfigureRangeIPv6(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -6039,28 +5557,7 @@ func TestSet4(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -6144,28 +5641,7 @@ func TestSetComment(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -6289,28 +5765,7 @@ func TestMasq(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6422,28 +5877,7 @@ func TestReject(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6552,28 +5986,7 @@ func TestFib(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6713,28 +6126,7 @@ func TestNumgen(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6800,28 +6192,7 @@ func TestMap(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6920,28 +6291,7 @@ func TestVmap(t *testing.T) { } for _, tt := range tests { - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(tt.want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, tt.want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, tt.want)) if err != nil { t.Fatal(err) } @@ -6982,28 +6332,7 @@ func TestJHash(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -7084,28 +6413,7 @@ func TestDup(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -7184,28 +6492,7 @@ func TestDupWoDev(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -7311,28 +6598,7 @@ func TestQuota(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want[idx]) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - got := b - if !bytes.Equal(got, want[idx]) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want[idx]))) - } - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } @@ -7388,28 +6654,7 @@ func TestStatelessNAT(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c, err := nftables.New(nftables.WithTestDial( - func(req []netlink.Message) ([]netlink.Message, error) { - for idx, msg := range req { - b, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } - if len(b) < 16 { - continue - } - b = b[16:] - if len(want) == 0 { - t.Errorf("no want entry for message %d: %x", idx, b) - continue - } - if got, want := b, want[0]; !bytes.Equal(got, want) { - t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) - } - want = want[1:] - } - return req, nil - })) + c, err := nftables.New(expectMessages(t, want)) if err != nil { t.Fatal(err) } diff --git a/obj.go b/obj.go index 634931b..65c4402 100644 --- a/obj.go +++ b/obj.go @@ -124,7 +124,7 @@ func (cc *Conn) AddObj(o Obj) Obj { attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: data}) } - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -146,7 +146,7 @@ func (cc *Conn) DeleteObject(o Obj) { data := cc.marshalAttr(attrs) data = append(data, cc.marshalAttr([]netlink.Attribute{{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA}})...) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ), Flags: netlink.Request | netlink.Acknowledge, diff --git a/rule.go b/rule.go index 54146c1..65c6562 100644 --- a/rule.go +++ b/rule.go @@ -48,10 +48,13 @@ const ( type Rule struct { Table *Table Chain *Chain - // Handle identifies an existing Rule. + // Handle identifies an existing Rule. For a new Rule, this field is set + // during the Flush() in which the rule is committed. Make sure to not access + // this field concurrently with this Flush() to avoid data races. Handle uint64 // ID is an identifier for a new Rule, which is assigned by // AddRule/InsertRule, and only valid before the rule is committed by Flush(). + // The field is set to 0 during Flush(). ID uint32 // Position can be set to the Handle of another Rule to insert the new Rule // before (InsertRule) or after (AddRule) the existing rule. @@ -94,7 +97,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) { message := netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE), - Flags: netlink.Request | netlink.Acknowledge | netlink.Dump | unix.NLM_F_ECHO, + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, }, Data: append(extraHeader(uint8(t.Family), 0), data...), } @@ -164,20 +167,23 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { msgData := []byte{} msgData = append(msgData, data...) - var flags netlink.HeaderFlags if r.UserData != nil { msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_USERDATA, Data: r.UserData}, })...) } + var flags netlink.HeaderFlags + var ruleRef *Rule switch op { case operationAdd: - flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND + flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo | netlink.Append + ruleRef = r case operationInsert: - flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO + flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo + ruleRef = r case operationReplace: - flags = netlink.Request | netlink.Acknowledge | netlink.Replace | unix.NLM_F_ECHO | unix.NLM_F_REPLACE + flags = netlink.Request | netlink.Acknowledge | netlink.Replace } if r.Position != 0 || (r.Flags&(1<