From 535f5eb8da79067164607508180c53e95724ea19 Mon Sep 17 00:00:00 2001 From: turekt <32360115+turekt@users.noreply.github.com> Date: Sun, 2 Oct 2022 14:01:48 +0000 Subject: [PATCH] Fix incorrect netlink acknowledgement handling (#194) fixes https://github.com/google/nftables/issues/175 --- conn.go | 48 +++++++++++++++++++++++++++++ nftables_test.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++ obj.go | 2 +- rule.go | 2 +- set.go | 6 ++-- 5 files changed, 132 insertions(+), 5 deletions(-) diff --git a/conn.go b/conn.go index 56544dd..711d7f6 100644 --- a/conn.go +++ b/conn.go @@ -15,9 +15,11 @@ package nftables import ( + "errors" "fmt" "sync" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" "github.com/mdlayher/netlink" "github.com/mdlayher/netlink/nltest" @@ -130,6 +132,52 @@ func (cc *Conn) netlinkConnUnderLock() (*netlink.Conn, netlinkCloser, error) { return nlconn, func() error { return nlconn.Close() }, nil } +func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]netlink.Message, error) { + if nlconn == nil { + return nil, errors.New("netlink conn is not initialized") + } + + // first receive will be the message that we expect + reply, err := nlconn.Receive() + if err != nil { + return nil, err + } + + if (sentMsgFlags & netlink.Acknowledge) == 0 { + // we did not request an ack + return reply, nil + } + + if (sentMsgFlags & netlink.Dump) == netlink.Dump { + // sent message has Dump flag set, there will be no acks + // https://github.com/torvalds/linux/blob/7e062cda7d90543ac8c7700fc7c5527d0c0f22ad/net/netlink/af_netlink.c#L2387-L2390 + return reply, nil + } + + // Dump flag is not set, we expect an ack + ack, err := nlconn.Receive() + if err != nil { + return nil, err + } + + if len(ack) == 0 { + return nil, errors.New("received an empty ack") + } + + msg := ack[0] + if msg.Header.Type != netlink.Error { + // acks should be delivered as NLMSG_ERROR + return nil, fmt.Errorf("expected header %v, but got %v", netlink.Error, msg.Header.Type) + } + + if binaryutil.BigEndian.Uint32(msg.Data[:4]) != 0 { + // if errno field is not set to 0 (success), this is an error + return nil, fmt.Errorf("error delivered in message: %v", msg.Data) + } + + return reply, nil +} + // CloseLasting closes the lasting netlink connection that has been opened using // AsLasting option when creating this connection. If either no lasting netlink // connection has been opened or the lasting connection is already in the diff --git a/nftables_test.go b/nftables_test.go index 74b593f..be6cda3 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1758,6 +1758,7 @@ func TestGetObjReset(t *testing.T) { nil, []netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x64, Type: 0xa12, Flags: 0x802, Sequence: 0x9acb0443, PID: 0xde9}, Data: []uint8{0x2, 0x0, 0x0, 0x10, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xa, 0x0, 0x2, 0x0, 0x66, 0x77, 0x64, 0x65, 0x64, 0x0, 0x0, 0x0, 0x8, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0x1, 0x1c, 0x0, 0x4, 0x0, 0xc, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x61, 0xc, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9, 0xc, 0x0, 0x6, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2}}}, []netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x9acb0443, PID: 0xde9}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}}, + []netlink.Message{netlink.Message{Header: netlink.Header{Length: 36, Type: netlink.Error, Flags: 0x100, Sequence: 0x9acb0443, PID: 0xde9}, Data: []uint8{0, 0, 0, 0, 88, 0, 0, 0, 12, 10, 5, 4, 143, 109, 199, 146, 236, 9, 0, 0}}}, } c, err := nftables.New(nftables.WithTestDial( @@ -2457,6 +2458,84 @@ func TestCreateUseAnonymousSet(t *testing.T) { } } +func TestCappedErrMsgOnSets(t *testing.T) { + c, newNS := openSystemNFTConn(t) + c, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting()) + if err != nil { + t.Fatalf("nftables.New() failed: %v", err) + } + defer cleanupSystemNFTConn(t, newNS) + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + if err := c.Flush(); err != nil { + t.Errorf("failed adding table: %v", err) + } + tables, err := c.ListTablesOfFamily(nftables.TableFamilyIPv4) + if err != nil { + t.Errorf("failed to list IPv4 tables: %v", err) + } + + for _, t := range tables { + if t.Name == "filter" { + filter = t + break + } + } + + ifSet := &nftables.Set{ + Table: filter, + Name: "if_set", + KeyType: nftables.TypeIFName, + } + if err := c.AddSet(ifSet, nil); err != nil { + t.Errorf("c.AddSet(ifSet) failed: %v", err) + } + if err := c.Flush(); err != nil { + t.Errorf("failed adding set ifSet: %v", err) + } + ifSet, err = c.GetSetByName(filter, "if_set") + if err != nil { + t.Fatalf("failed getting set by name: %v", err) + } + + elems, err := c.GetSetElements(ifSet) + if err != nil { + t.Errorf("failed getting set elements (ifSet): %v", err) + } + + if got, want := len(elems), 0; got != want { + t.Errorf("first GetSetElements(ifSet) call len not equal: got %d, want %d", got, want) + } + + elements := []nftables.SetElement{ + {Key: []byte("012345678912345\x00")}, + } + if err := c.SetAddElements(ifSet, elements); err != nil { + t.Errorf("adding SetElements(ifSet) failed: %v", err) + } + if err := c.Flush(); err != nil { + t.Errorf("failed adding set elements ifSet: %v", err) + } + + elems, err = c.GetSetElements(ifSet) + if err != nil { + t.Fatalf("failed getting set elements (ifSet): %v", err) + } + + if got, want := len(elems), 1; got != want { + t.Fatalf("second GetSetElements(ifSet) call len not equal: got %d, want %d", got, want) + } + + if got, want := elems, elements; !reflect.DeepEqual(elems, elements) { + t.Errorf("SetElements(ifSet) not equal: got %v, want %v", got, want) + } +} + func TestCreateUseNamedSet(t *testing.T) { // Create a new network namespace to test these operations, // and tear down the namespace at test completion. diff --git a/obj.go b/obj.go index 3fd01e2..08d43f4 100644 --- a/obj.go +++ b/obj.go @@ -207,7 +207,7 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { return nil, fmt.Errorf("SendMessages: %v", err) } - reply, err := conn.Receive() + reply, err := receiveAckAware(conn, message.Header.Flags) if err != nil { return nil, fmt.Errorf("Receive: %v", err) } diff --git a/rule.go b/rule.go index bc83b6c..95bfdff 100644 --- a/rule.go +++ b/rule.go @@ -87,7 +87,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) { return nil, fmt.Errorf("SendMessages: %v", err) } - reply, err := conn.Receive() + reply, err := receiveAckAware(conn, message.Header.Flags) if err != nil { return nil, fmt.Errorf("Receive: %v", err) } diff --git a/set.go b/set.go index 907ea77..0240ed0 100644 --- a/set.go +++ b/set.go @@ -783,7 +783,7 @@ func (cc *Conn) GetSets(t *Table) ([]*Set, error) { return nil, fmt.Errorf("SendMessages: %v", err) } - reply, err := conn.Receive() + reply, err := receiveAckAware(conn, message.Header.Flags) if err != nil { return nil, fmt.Errorf("Receive: %v", err) } @@ -828,7 +828,7 @@ func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) { return nil, fmt.Errorf("SendMessages: %w", err) } - reply, err := conn.Receive() + reply, err := receiveAckAware(conn, message.Header.Flags) if err != nil { return nil, fmt.Errorf("Receive: %w", err) } @@ -873,7 +873,7 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) { return nil, fmt.Errorf("SendMessages: %v", err) } - reply, err := conn.Receive() + reply, err := receiveAckAware(conn, message.Header.Flags) if err != nil { return nil, fmt.Errorf("Receive: %v", err) }