Compare commits

..

1 Commits

Author SHA1 Message Date
Antonio Ojea a7f1c6e8c3
Merge 1e48c1007e into d11ef81b6a 2025-03-18 10:04:17 +01:00
4 changed files with 868 additions and 133 deletions

66
conn.go
View File

@ -171,6 +171,24 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]
return reply, nil return reply, nil
} }
if len(reply) != 0 {
last := reply[len(reply)-1]
for re := last.Header.Type; (re&netlink.Overrun) == netlink.Overrun && (re&netlink.Done) != netlink.Done; re = last.Header.Type {
// we are not finished, the message is overrun
r, err := nlconn.Receive()
if err != nil {
return nil, err
}
reply = append(reply, r...)
last = reply[len(reply)-1]
}
if last.Header.Type == netlink.Error && binaryutil.BigEndian.Uint32(last.Data[:4]) == 0 {
// we have already collected an ack
return reply, nil
}
}
// Now we expect an ack // Now we expect an ack
ack, err := nlconn.Receive() ack, err := nlconn.Receive()
if err != nil { if err != nil {
@ -178,7 +196,8 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]
} }
if len(ack) == 0 { if len(ack) == 0 {
return nil, errors.New("received an empty ack") // received an empty ack?
return reply, nil
} }
msg := ack[0] msg := ack[0]
@ -244,49 +263,15 @@ func (cc *Conn) Flush() error {
} }
defer func() { _ = closer() }() defer func() { _ = closer() }()
messages, err := conn.SendMessages(batch(cc.messages)) if _, err := conn.SendMessages(batch(cc.messages)); err != nil {
if err != nil {
return fmt.Errorf("SendMessages: %w", err) return fmt.Errorf("SendMessages: %w", err)
} }
var errs error var errs error
// Fetch replies. Each message with the Echo flag triggers a reply of the same
// type. Additionally, if the first message of the batch has the Echo flag, we
// get a reply of type NFT_MSG_NEWGEN, which we ignore.
replyIndex := 0
for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 {
replyIndex++
}
replies, err := conn.Receive()
for err == nil && len(replies) != 0 {
reply := replies[0]
if reply.Header.Type == netlink.Error && reply.Header.Sequence == messages[1].Header.Sequence {
// The next message is the acknowledgement for the first message in the
// batch; stop looking for replies.
break
} else if replyIndex < len(cc.messages) {
msg := messages[replyIndex+1]
if msg.Header.Sequence == reply.Header.Sequence && msg.Header.Type == reply.Header.Type {
replyIndex++
for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 {
replyIndex++
}
}
}
replies = replies[1:]
if len(replies) == 0 {
replies, err = conn.Receive()
}
}
// Fetch the requested acknowledgement for each message we sent. // Fetch the requested acknowledgement for each message we sent.
for i := range cc.messages { for _, msg := range cc.messages {
if i != 0 { if _, err := receiveAckAware(conn, msg.Header.Flags); err != nil {
_, err = conn.Receive() if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) {
}
if err != nil {
if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) || errors.Is(err, syscall.ENOMEM) {
// Kernel will only send one error to user space. // Kernel will only send one error to user space.
return err return err
} }
@ -297,9 +282,6 @@ func (cc *Conn) Flush() error {
if errs != nil { if errs != nil {
return fmt.Errorf("conn.Receive: %w", errs) return fmt.Errorf("conn.Receive: %w", errs)
} }
if replyIndex < len(cc.messages) {
return fmt.Errorf("missing reply for message %d in batch", replyIndex)
}
return nil return nil
} }

View File

@ -25,21 +25,10 @@ func (r *Recorder) Conn() (*nftables.Conn, error) {
func(req []netlink.Message) ([]netlink.Message, error) { func(req []netlink.Message) ([]netlink.Message, error) {
r.requests = append(r.requests, req...) r.requests = append(r.requests, req...)
replies := make([]netlink.Message, 0, len(req)) acks := make([]netlink.Message, 0, len(req))
// Generate replies.
for _, msg := range req {
if msg.Header.Flags&netlink.Echo != 0 {
data := append([]byte{}, msg.Data...)
replies = append(replies, netlink.Message{
Header: msg.Header,
Data: data,
})
}
}
// Generate acknowledgements.
for _, msg := range req { for _, msg := range req {
if msg.Header.Flags&netlink.Acknowledge != 0 { if msg.Header.Flags&netlink.Acknowledge != 0 {
replies = append(replies, netlink.Message{ acks = append(acks, netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Length: 4, Length: 4,
Type: netlink.Error, Type: netlink.Error,
@ -50,7 +39,7 @@ func (r *Recorder) Conn() (*nftables.Conn, error) {
}) })
} }
} }
return replies, nil return acks, nil
})) }))
} }

File diff suppressed because it is too large Load Diff

10
rule.go
View File

@ -94,7 +94,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
message := netlink.Message{ message := netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE),
Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, Flags: netlink.Request | netlink.Acknowledge | netlink.Dump | unix.NLM_F_ECHO,
}, },
Data: append(extraHeader(uint8(t.Family), 0), data...), Data: append(extraHeader(uint8(t.Family), 0), data...),
} }
@ -164,20 +164,20 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
msgData := []byte{} msgData := []byte{}
msgData = append(msgData, data...) msgData = append(msgData, data...)
var flags netlink.HeaderFlags
if r.UserData != nil { if r.UserData != nil {
msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{ msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_USERDATA, Data: r.UserData}, {Type: unix.NFTA_RULE_USERDATA, Data: r.UserData},
})...) })...)
} }
var flags netlink.HeaderFlags
switch op { switch op {
case operationAdd: case operationAdd:
flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo | netlink.Append flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND
case operationInsert: case operationInsert:
flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO
case operationReplace: case operationReplace:
flags = netlink.Request | netlink.Acknowledge | netlink.Replace flags = netlink.Request | netlink.Acknowledge | netlink.Replace | unix.NLM_F_ECHO | unix.NLM_F_REPLACE
} }
if r.Position != 0 || (r.Flags&(1<<unix.NFTA_RULE_POSITION)) != 0 { if r.Position != 0 || (r.Flags&(1<<unix.NFTA_RULE_POSITION)) != 0 {