Compare commits
1 Commits
d078f784f8
...
a7f1c6e8c3
Author | SHA1 | Date |
---|---|---|
|
a7f1c6e8c3 |
66
conn.go
66
conn.go
|
@ -171,6 +171,24 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]
|
||||||
return reply, nil
|
return reply, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(reply) != 0 {
|
||||||
|
last := reply[len(reply)-1]
|
||||||
|
for re := last.Header.Type; (re&netlink.Overrun) == netlink.Overrun && (re&netlink.Done) != netlink.Done; re = last.Header.Type {
|
||||||
|
// we are not finished, the message is overrun
|
||||||
|
r, err := nlconn.Receive()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
reply = append(reply, r...)
|
||||||
|
last = reply[len(reply)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if last.Header.Type == netlink.Error && binaryutil.BigEndian.Uint32(last.Data[:4]) == 0 {
|
||||||
|
// we have already collected an ack
|
||||||
|
return reply, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Now we expect an ack
|
// Now we expect an ack
|
||||||
ack, err := nlconn.Receive()
|
ack, err := nlconn.Receive()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -178,7 +196,8 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ack) == 0 {
|
if len(ack) == 0 {
|
||||||
return nil, errors.New("received an empty ack")
|
// received an empty ack?
|
||||||
|
return reply, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := ack[0]
|
msg := ack[0]
|
||||||
|
@ -244,49 +263,15 @@ func (cc *Conn) Flush() error {
|
||||||
}
|
}
|
||||||
defer func() { _ = closer() }()
|
defer func() { _ = closer() }()
|
||||||
|
|
||||||
messages, err := conn.SendMessages(batch(cc.messages))
|
if _, err := conn.SendMessages(batch(cc.messages)); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("SendMessages: %w", err)
|
return fmt.Errorf("SendMessages: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var errs error
|
var errs error
|
||||||
|
|
||||||
// Fetch replies. Each message with the Echo flag triggers a reply of the same
|
|
||||||
// type. Additionally, if the first message of the batch has the Echo flag, we
|
|
||||||
// get a reply of type NFT_MSG_NEWGEN, which we ignore.
|
|
||||||
replyIndex := 0
|
|
||||||
for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 {
|
|
||||||
replyIndex++
|
|
||||||
}
|
|
||||||
replies, err := conn.Receive()
|
|
||||||
for err == nil && len(replies) != 0 {
|
|
||||||
reply := replies[0]
|
|
||||||
if reply.Header.Type == netlink.Error && reply.Header.Sequence == messages[1].Header.Sequence {
|
|
||||||
// The next message is the acknowledgement for the first message in the
|
|
||||||
// batch; stop looking for replies.
|
|
||||||
break
|
|
||||||
} else if replyIndex < len(cc.messages) {
|
|
||||||
msg := messages[replyIndex+1]
|
|
||||||
if msg.Header.Sequence == reply.Header.Sequence && msg.Header.Type == reply.Header.Type {
|
|
||||||
replyIndex++
|
|
||||||
for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 {
|
|
||||||
replyIndex++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
replies = replies[1:]
|
|
||||||
if len(replies) == 0 {
|
|
||||||
replies, err = conn.Receive()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fetch the requested acknowledgement for each message we sent.
|
// Fetch the requested acknowledgement for each message we sent.
|
||||||
for i := range cc.messages {
|
for _, msg := range cc.messages {
|
||||||
if i != 0 {
|
if _, err := receiveAckAware(conn, msg.Header.Flags); err != nil {
|
||||||
_, err = conn.Receive()
|
if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) {
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) || errors.Is(err, syscall.ENOMEM) {
|
|
||||||
// Kernel will only send one error to user space.
|
// Kernel will only send one error to user space.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -297,9 +282,6 @@ func (cc *Conn) Flush() error {
|
||||||
if errs != nil {
|
if errs != nil {
|
||||||
return fmt.Errorf("conn.Receive: %w", errs)
|
return fmt.Errorf("conn.Receive: %w", errs)
|
||||||
}
|
}
|
||||||
if replyIndex < len(cc.messages) {
|
|
||||||
return fmt.Errorf("missing reply for message %d in batch", replyIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,21 +25,10 @@ func (r *Recorder) Conn() (*nftables.Conn, error) {
|
||||||
func(req []netlink.Message) ([]netlink.Message, error) {
|
func(req []netlink.Message) ([]netlink.Message, error) {
|
||||||
r.requests = append(r.requests, req...)
|
r.requests = append(r.requests, req...)
|
||||||
|
|
||||||
replies := make([]netlink.Message, 0, len(req))
|
acks := make([]netlink.Message, 0, len(req))
|
||||||
// Generate replies.
|
|
||||||
for _, msg := range req {
|
|
||||||
if msg.Header.Flags&netlink.Echo != 0 {
|
|
||||||
data := append([]byte{}, msg.Data...)
|
|
||||||
replies = append(replies, netlink.Message{
|
|
||||||
Header: msg.Header,
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Generate acknowledgements.
|
|
||||||
for _, msg := range req {
|
for _, msg := range req {
|
||||||
if msg.Header.Flags&netlink.Acknowledge != 0 {
|
if msg.Header.Flags&netlink.Acknowledge != 0 {
|
||||||
replies = append(replies, netlink.Message{
|
acks = append(acks, netlink.Message{
|
||||||
Header: netlink.Header{
|
Header: netlink.Header{
|
||||||
Length: 4,
|
Length: 4,
|
||||||
Type: netlink.Error,
|
Type: netlink.Error,
|
||||||
|
@ -50,7 +39,7 @@ func (r *Recorder) Conn() (*nftables.Conn, error) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return replies, nil
|
return acks, nil
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
908
nftables_test.go
908
nftables_test.go
File diff suppressed because it is too large
Load Diff
10
rule.go
10
rule.go
|
@ -94,7 +94,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
|
||||||
message := netlink.Message{
|
message := netlink.Message{
|
||||||
Header: netlink.Header{
|
Header: netlink.Header{
|
||||||
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE),
|
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE),
|
||||||
Flags: netlink.Request | netlink.Acknowledge | netlink.Dump,
|
Flags: netlink.Request | netlink.Acknowledge | netlink.Dump | unix.NLM_F_ECHO,
|
||||||
},
|
},
|
||||||
Data: append(extraHeader(uint8(t.Family), 0), data...),
|
Data: append(extraHeader(uint8(t.Family), 0), data...),
|
||||||
}
|
}
|
||||||
|
@ -164,20 +164,20 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
|
||||||
msgData := []byte{}
|
msgData := []byte{}
|
||||||
|
|
||||||
msgData = append(msgData, data...)
|
msgData = append(msgData, data...)
|
||||||
|
var flags netlink.HeaderFlags
|
||||||
if r.UserData != nil {
|
if r.UserData != nil {
|
||||||
msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{
|
msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{
|
||||||
{Type: unix.NFTA_RULE_USERDATA, Data: r.UserData},
|
{Type: unix.NFTA_RULE_USERDATA, Data: r.UserData},
|
||||||
})...)
|
})...)
|
||||||
}
|
}
|
||||||
|
|
||||||
var flags netlink.HeaderFlags
|
|
||||||
switch op {
|
switch op {
|
||||||
case operationAdd:
|
case operationAdd:
|
||||||
flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo | netlink.Append
|
flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND
|
||||||
case operationInsert:
|
case operationInsert:
|
||||||
flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo
|
flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO
|
||||||
case operationReplace:
|
case operationReplace:
|
||||||
flags = netlink.Request | netlink.Acknowledge | netlink.Replace
|
flags = netlink.Request | netlink.Acknowledge | netlink.Replace | unix.NLM_F_ECHO | unix.NLM_F_REPLACE
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Position != 0 || (r.Flags&(1<<unix.NFTA_RULE_POSITION)) != 0 {
|
if r.Position != 0 || (r.Flags&(1<<unix.NFTA_RULE_POSITION)) != 0 {
|
||||||
|
|
Loading…
Reference in New Issue