Fix incorrect netlink acknowledgement handling (#194)

fixes https://github.com/google/nftables/issues/175
This commit is contained in:
turekt 2022-10-02 14:01:48 +00:00 committed by GitHub
parent 0aa65c0fdd
commit 535f5eb8da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 132 additions and 5 deletions

48
conn.go
View File

@ -15,9 +15,11 @@
package nftables
import (
"errors"
"fmt"
"sync"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/mdlayher/netlink"
"github.com/mdlayher/netlink/nltest"
@ -130,6 +132,52 @@ func (cc *Conn) netlinkConnUnderLock() (*netlink.Conn, netlinkCloser, error) {
return nlconn, func() error { return nlconn.Close() }, nil
}
func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]netlink.Message, error) {
if nlconn == nil {
return nil, errors.New("netlink conn is not initialized")
}
// first receive will be the message that we expect
reply, err := nlconn.Receive()
if err != nil {
return nil, err
}
if (sentMsgFlags & netlink.Acknowledge) == 0 {
// we did not request an ack
return reply, nil
}
if (sentMsgFlags & netlink.Dump) == netlink.Dump {
// sent message has Dump flag set, there will be no acks
// https://github.com/torvalds/linux/blob/7e062cda7d90543ac8c7700fc7c5527d0c0f22ad/net/netlink/af_netlink.c#L2387-L2390
return reply, nil
}
// Dump flag is not set, we expect an ack
ack, err := nlconn.Receive()
if err != nil {
return nil, err
}
if len(ack) == 0 {
return nil, errors.New("received an empty ack")
}
msg := ack[0]
if msg.Header.Type != netlink.Error {
// acks should be delivered as NLMSG_ERROR
return nil, fmt.Errorf("expected header %v, but got %v", netlink.Error, msg.Header.Type)
}
if binaryutil.BigEndian.Uint32(msg.Data[:4]) != 0 {
// if errno field is not set to 0 (success), this is an error
return nil, fmt.Errorf("error delivered in message: %v", msg.Data)
}
return reply, nil
}
// CloseLasting closes the lasting netlink connection that has been opened using
// AsLasting option when creating this connection. If either no lasting netlink
// connection has been opened or the lasting connection is already in the

View File

@ -1758,6 +1758,7 @@ func TestGetObjReset(t *testing.T) {
nil,
[]netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x64, Type: 0xa12, Flags: 0x802, Sequence: 0x9acb0443, PID: 0xde9}, Data: []uint8{0x2, 0x0, 0x0, 0x10, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xa, 0x0, 0x2, 0x0, 0x66, 0x77, 0x64, 0x65, 0x64, 0x0, 0x0, 0x0, 0x8, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0x1, 0x1c, 0x0, 0x4, 0x0, 0xc, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x61, 0xc, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9, 0xc, 0x0, 0x6, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2}}},
[]netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x9acb0443, PID: 0xde9}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}},
[]netlink.Message{netlink.Message{Header: netlink.Header{Length: 36, Type: netlink.Error, Flags: 0x100, Sequence: 0x9acb0443, PID: 0xde9}, Data: []uint8{0, 0, 0, 0, 88, 0, 0, 0, 12, 10, 5, 4, 143, 109, 199, 146, 236, 9, 0, 0}}},
}
c, err := nftables.New(nftables.WithTestDial(
@ -2457,6 +2458,84 @@ func TestCreateUseAnonymousSet(t *testing.T) {
}
}
func TestCappedErrMsgOnSets(t *testing.T) {
c, newNS := openSystemNFTConn(t)
c, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting())
if err != nil {
t.Fatalf("nftables.New() failed: %v", err)
}
defer cleanupSystemNFTConn(t, newNS)
c.FlushRuleset()
defer c.FlushRuleset()
filter := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
})
if err := c.Flush(); err != nil {
t.Errorf("failed adding table: %v", err)
}
tables, err := c.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
t.Errorf("failed to list IPv4 tables: %v", err)
}
for _, t := range tables {
if t.Name == "filter" {
filter = t
break
}
}
ifSet := &nftables.Set{
Table: filter,
Name: "if_set",
KeyType: nftables.TypeIFName,
}
if err := c.AddSet(ifSet, nil); err != nil {
t.Errorf("c.AddSet(ifSet) failed: %v", err)
}
if err := c.Flush(); err != nil {
t.Errorf("failed adding set ifSet: %v", err)
}
ifSet, err = c.GetSetByName(filter, "if_set")
if err != nil {
t.Fatalf("failed getting set by name: %v", err)
}
elems, err := c.GetSetElements(ifSet)
if err != nil {
t.Errorf("failed getting set elements (ifSet): %v", err)
}
if got, want := len(elems), 0; got != want {
t.Errorf("first GetSetElements(ifSet) call len not equal: got %d, want %d", got, want)
}
elements := []nftables.SetElement{
{Key: []byte("012345678912345\x00")},
}
if err := c.SetAddElements(ifSet, elements); err != nil {
t.Errorf("adding SetElements(ifSet) failed: %v", err)
}
if err := c.Flush(); err != nil {
t.Errorf("failed adding set elements ifSet: %v", err)
}
elems, err = c.GetSetElements(ifSet)
if err != nil {
t.Fatalf("failed getting set elements (ifSet): %v", err)
}
if got, want := len(elems), 1; got != want {
t.Fatalf("second GetSetElements(ifSet) call len not equal: got %d, want %d", got, want)
}
if got, want := elems, elements; !reflect.DeepEqual(elems, elements) {
t.Errorf("SetElements(ifSet) not equal: got %v, want %v", got, want)
}
}
func TestCreateUseNamedSet(t *testing.T) {
// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.

2
obj.go
View File

@ -207,7 +207,7 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) {
return nil, fmt.Errorf("SendMessages: %v", err)
}
reply, err := conn.Receive()
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("Receive: %v", err)
}

View File

@ -87,7 +87,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
return nil, fmt.Errorf("SendMessages: %v", err)
}
reply, err := conn.Receive()
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("Receive: %v", err)
}

6
set.go
View File

@ -783,7 +783,7 @@ func (cc *Conn) GetSets(t *Table) ([]*Set, error) {
return nil, fmt.Errorf("SendMessages: %v", err)
}
reply, err := conn.Receive()
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("Receive: %v", err)
}
@ -828,7 +828,7 @@ func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) {
return nil, fmt.Errorf("SendMessages: %w", err)
}
reply, err := conn.Receive()
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("Receive: %w", err)
}
@ -873,7 +873,7 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) {
return nil, fmt.Errorf("SendMessages: %v", err)
}
reply, err := conn.Receive()
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("Receive: %v", err)
}