Receive replies in Flush (#309)

Commit 0d9bfa4d18 added code to handle "overrun", but the commit is
very misleading. NLMSG_OVERRUN is in fact not a flag, but a complete
message type, so the (re&netlink.Overrun) masking makes no sense. Even
better, NLMSG_OVERRUN is never actually used by Linux.

The actual bug which the commit was attempting to fix is that Flush was
not receiving replies which the kernel sent for messages with the echo
flag. This change reverts that commit and instead adds code in Flush to
receive the replies.

I updated tests which simulate the kernel to generate replies.
This commit is contained in:
Jan Schär 2025-03-25 17:03:44 +01:00 committed by GitHub
parent d11ef81b6a
commit 9a2862f48b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 133 additions and 868 deletions

66
conn.go
View File

@ -171,24 +171,6 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]
return reply, nil return reply, nil
} }
if len(reply) != 0 {
last := reply[len(reply)-1]
for re := last.Header.Type; (re&netlink.Overrun) == netlink.Overrun && (re&netlink.Done) != netlink.Done; re = last.Header.Type {
// we are not finished, the message is overrun
r, err := nlconn.Receive()
if err != nil {
return nil, err
}
reply = append(reply, r...)
last = reply[len(reply)-1]
}
if last.Header.Type == netlink.Error && binaryutil.BigEndian.Uint32(last.Data[:4]) == 0 {
// we have already collected an ack
return reply, nil
}
}
// Now we expect an ack // Now we expect an ack
ack, err := nlconn.Receive() ack, err := nlconn.Receive()
if err != nil { if err != nil {
@ -196,8 +178,7 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]
} }
if len(ack) == 0 { if len(ack) == 0 {
// received an empty ack? return nil, errors.New("received an empty ack")
return reply, nil
} }
msg := ack[0] msg := ack[0]
@ -263,15 +244,49 @@ func (cc *Conn) Flush() error {
} }
defer func() { _ = closer() }() defer func() { _ = closer() }()
if _, err := conn.SendMessages(batch(cc.messages)); err != nil { messages, err := conn.SendMessages(batch(cc.messages))
if err != nil {
return fmt.Errorf("SendMessages: %w", err) return fmt.Errorf("SendMessages: %w", err)
} }
var errs error var errs error
// Fetch replies. Each message with the Echo flag triggers a reply of the same
// type. Additionally, if the first message of the batch has the Echo flag, we
// get a reply of type NFT_MSG_NEWGEN, which we ignore.
replyIndex := 0
for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 {
replyIndex++
}
replies, err := conn.Receive()
for err == nil && len(replies) != 0 {
reply := replies[0]
if reply.Header.Type == netlink.Error && reply.Header.Sequence == messages[1].Header.Sequence {
// The next message is the acknowledgement for the first message in the
// batch; stop looking for replies.
break
} else if replyIndex < len(cc.messages) {
msg := messages[replyIndex+1]
if msg.Header.Sequence == reply.Header.Sequence && msg.Header.Type == reply.Header.Type {
replyIndex++
for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 {
replyIndex++
}
}
}
replies = replies[1:]
if len(replies) == 0 {
replies, err = conn.Receive()
}
}
// Fetch the requested acknowledgement for each message we sent. // Fetch the requested acknowledgement for each message we sent.
for _, msg := range cc.messages { for i := range cc.messages {
if _, err := receiveAckAware(conn, msg.Header.Flags); err != nil { if i != 0 {
if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) { _, err = conn.Receive()
}
if err != nil {
if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) || errors.Is(err, syscall.ENOMEM) {
// Kernel will only send one error to user space. // Kernel will only send one error to user space.
return err return err
} }
@ -282,6 +297,9 @@ func (cc *Conn) Flush() error {
if errs != nil { if errs != nil {
return fmt.Errorf("conn.Receive: %w", errs) return fmt.Errorf("conn.Receive: %w", errs)
} }
if replyIndex < len(cc.messages) {
return fmt.Errorf("missing reply for message %d in batch", replyIndex)
}
return nil return nil
} }

View File

@ -25,10 +25,21 @@ func (r *Recorder) Conn() (*nftables.Conn, error) {
func(req []netlink.Message) ([]netlink.Message, error) { func(req []netlink.Message) ([]netlink.Message, error) {
r.requests = append(r.requests, req...) r.requests = append(r.requests, req...)
acks := make([]netlink.Message, 0, len(req)) replies := make([]netlink.Message, 0, len(req))
// Generate replies.
for _, msg := range req {
if msg.Header.Flags&netlink.Echo != 0 {
data := append([]byte{}, msg.Data...)
replies = append(replies, netlink.Message{
Header: msg.Header,
Data: data,
})
}
}
// Generate acknowledgements.
for _, msg := range req { for _, msg := range req {
if msg.Header.Flags&netlink.Acknowledge != 0 { if msg.Header.Flags&netlink.Acknowledge != 0 {
acks = append(acks, netlink.Message{ replies = append(replies, netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Length: 4, Length: 4,
Type: netlink.Error, Type: netlink.Error,
@ -39,7 +50,7 @@ func (r *Recorder) Conn() (*nftables.Conn, error) {
}) })
} }
} }
return acks, nil return replies, nil
})) }))
} }

File diff suppressed because it is too large Load Diff

10
rule.go
View File

@ -94,7 +94,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
message := netlink.Message{ message := netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE),
Flags: netlink.Request | netlink.Acknowledge | netlink.Dump | unix.NLM_F_ECHO, Flags: netlink.Request | netlink.Acknowledge | netlink.Dump,
}, },
Data: append(extraHeader(uint8(t.Family), 0), data...), Data: append(extraHeader(uint8(t.Family), 0), data...),
} }
@ -164,20 +164,20 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
msgData := []byte{} msgData := []byte{}
msgData = append(msgData, data...) msgData = append(msgData, data...)
var flags netlink.HeaderFlags
if r.UserData != nil { if r.UserData != nil {
msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{ msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_USERDATA, Data: r.UserData}, {Type: unix.NFTA_RULE_USERDATA, Data: r.UserData},
})...) })...)
} }
var flags netlink.HeaderFlags
switch op { switch op {
case operationAdd: case operationAdd:
flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo | netlink.Append
case operationInsert: case operationInsert:
flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo
case operationReplace: case operationReplace:
flags = netlink.Request | netlink.Acknowledge | netlink.Replace | unix.NLM_F_ECHO | unix.NLM_F_REPLACE flags = netlink.Request | netlink.Acknowledge | netlink.Replace
} }
if r.Position != 0 || (r.Flags&(1<<unix.NFTA_RULE_POSITION)) != 0 { if r.Position != 0 || (r.Flags&(1<<unix.NFTA_RULE_POSITION)) != 0 {