Netlink conn retry logic

Fixes https://github.com/google/nftables/issues/175 | Adds receiveWithRetry in conn.go | Adds tests for "no error" message use cases
This commit is contained in:
turekt 2022-09-21 18:55:38 +00:00
parent cbeb0fb1ec
commit 0d4369aacb
5 changed files with 199 additions and 5 deletions

49
conn.go
View File

@ -15,15 +15,22 @@
package nftables
import (
"errors"
"fmt"
"sync"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/mdlayher/netlink"
"github.com/mdlayher/netlink/nltest"
"golang.org/x/sys/unix"
)
// errMsgNoError is returned when netlink delivers an error message whose
// status code is 0, which means that no error actually occurred.
// https://github.com/mdlayher/netlink/blob/9a593f9dc1a92ae1c7c45a9fc663b7dd1e111eac/message.go#L285-L289
var errMsgNoError = errors.New("no error")
// A Conn represents a netlink connection of the nftables family.
//
// All methods return their input, so that variables can be defined from string
@ -130,6 +137,48 @@ func (cc *Conn) netlinkConnUnderLock() (*netlink.Conn, netlinkCloser, error) {
return nlconn, func() error { return nlconn.Close() }, nil
}
// receiveWithRetry reads the next batch of messages from nlconn, silently
// dropping batches that start with a "no error" acknowledgement (a netlink
// error message whose status code is 0) and retrying until a batch with other
// content, an empty batch, or a genuine error is received.
//
// It returns an error if nlconn is nil, if Receive fails, or if the first
// message of a batch is a netlink error message that either carries a
// non-zero status code or is too short to contain a status code at all.
func receiveWithRetry(nlconn *netlink.Conn) ([]netlink.Message, error) {
	if nlconn == nil {
		return nil, fmt.Errorf("netlink conn is not initialized")
	}

	// receive reads a single batch and classifies it by its first message:
	// errMsgNoError marks a batch that should be skipped by the caller.
	receive := func() ([]netlink.Message, error) {
		reply, err := nlconn.Receive()
		if err != nil {
			return nil, err
		}
		if len(reply) == 0 {
			return reply, nil
		}

		msg := reply[0]
		if msg.Header.Type != netlink.Error {
			return reply, nil
		}

		// The status code occupies the first 4 bytes of the payload; guard
		// the slice expression so a truncated message cannot cause a panic.
		if len(msg.Data) < 4 {
			return nil, fmt.Errorf("error message too short: got %d bytes, want at least 4", len(msg.Data))
		}
		if binaryutil.BigEndian.Uint32(msg.Data[:4]) != 0 {
			return nil, fmt.Errorf("error delivered in message: %v", msg.Data)
		}
		return nil, errMsgNoError
	}

	// When "overflowing" counter objects netlink will return multiple
	// "no error" messages (see TestCappedErrMsgOnObj in nftables_test.go),
	// so keep receiving until something other than such a message shows up.
	for {
		msgs, err := receive()
		if errors.Is(err, errMsgNoError) {
			continue
		}
		return msgs, err
	}
}
// CloseLasting closes the lasting netlink connection that has been opened using
// AsLasting option when creating this connection. If either no lasting netlink
// connection has been opened or the lasting connection is already in the

View File

@ -2457,6 +2457,151 @@ func TestCreateUseAnonymousSet(t *testing.T) {
}
}
// TestCappedErrMsgOnObj adds 500 counter objects in a single flush and then
// lists them back, checking that acknowledgements left unread by the flush do
// not break the subsequent GetObj call.
func TestCappedErrMsgOnObj(t *testing.T) {
	// Only the namespace handle is needed from the helper; the connection
	// under test is a lasting one created explicitly below.
	_, newNS := openSystemNFTConn(t)
	c, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting())
	if err != nil {
		t.Fatalf("nftables.New() failed: %v", err)
	}
	defer cleanupSystemNFTConn(t, newNS)
	// Release the lasting netlink connection once the test is done.
	defer func() {
		if err := c.CloseLasting(); err != nil {
			t.Errorf("c.CloseLasting() failed: %v", err)
		}
	}()
	c.FlushRuleset()
	defer c.FlushRuleset()

	filter := c.AddTable(&nftables.Table{
		Family: nftables.TableFamilyIPv4,
		Name:   "filter",
	})
	if err := c.Flush(); err != nil {
		// Nothing below can work without the table.
		t.Fatalf("failed adding table: %v", err)
	}

	const counterLen = 500
	counters := make([]nftables.Obj, counterLen)
	for i := 0; i < counterLen; i++ {
		counters[i] = &nftables.CounterObj{
			Table: filter,
			Name:  fmt.Sprintf("ctr%d", i),
			// Converting a negative int to uint64 wraps around, so every
			// counter except the first starts close to the maximum value.
			Bytes:   uint64(-1 * i),
			Packets: uint64(-1 * i),
		}
		c.AddObj(counters[i])
	}

	if err := c.Flush(); err != nil {
		// this test leverages the fact that read buffer of the netlink socket
		// is too small to receive all ack messages sent by the kernel
		//
		// the Flush method tries to receive all acks after sending the messages
		// https://github.com/google/nftables/blob/cbeb0fb1eccf9ef582c20982c72e73107d1898a5/conn.go#L185-L193
		//
		// buffer peeking implemented by the netlink go library fails with a "no buffer space available"
		// https://github.com/mdlayher/netlink/blob/7fa043dcb6f27ed7e084dc90ddf5b6d4092478a9/conn_linux.go#L123-L127
		// this results in part of "non-received" acknowledgments to end up being received
		// in the following c.GetObj(counters[0]) call
		//
		// this can be fixed with extending the read buffer with conn.SetReadBuffer(size int)
		// https://github.com/mdlayher/netlink/blob/7fa043dcb6f27ed7e084dc90ddf5b6d4092478a9/conn_linux.go#L204-L206
		// in which case this test would still succeed
		if got, want := fmt.Sprint(err), "no buffer space available"; !strings.Contains(got, want) {
			t.Errorf("expected error %s, got %v", want, err)
		}
	}

	// this GetObj call will receive some of the previously "non-received"
	// acknowledgments which should get dropped by the receiveWithRetry
	// function and successfully parse 500 counters
	objs, err := c.GetObj(counters[0])
	if err != nil {
		t.Errorf("failed getting objects: %v", err)
	}

	if got, want := len(objs), len(counters); got != want {
		t.Errorf("object list length not equal: got %d, want %d", got, want)
	}

	if got, want := objs, counters; !reflect.DeepEqual(got, want) {
		t.Errorf("object list not equal: got %v, want %v", got, want)
	}
}
// TestCappedErrMsgOnSets creates an interface-name set over a lasting
// connection and checks that listing its elements works both before and after
// an element has been added.
func TestCappedErrMsgOnSets(t *testing.T) {
	// Only the namespace handle is needed from the helper; the connection
	// under test is a lasting one created explicitly below.
	_, newNS := openSystemNFTConn(t)
	c, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting())
	if err != nil {
		t.Fatalf("nftables.New() failed: %v", err)
	}
	defer cleanupSystemNFTConn(t, newNS)
	// Release the lasting netlink connection once the test is done.
	defer func() {
		if err := c.CloseLasting(); err != nil {
			t.Errorf("c.CloseLasting() failed: %v", err)
		}
	}()
	c.FlushRuleset()
	defer c.FlushRuleset()

	filter := c.AddTable(&nftables.Table{
		Family: nftables.TableFamilyIPv4,
		Name:   "filter",
	})
	if err := c.Flush(); err != nil {
		// Nothing below can work without the table.
		t.Fatalf("failed adding table: %v", err)
	}

	tables, err := c.ListTablesOfFamily(nftables.TableFamilyIPv4)
	if err != nil {
		t.Errorf("failed to list IPv4 tables: %v", err)
	}

	// Prefer the table as reported back by the listing over the local value.
	// The loop variable is deliberately not called t, to avoid shadowing the
	// *testing.T parameter.
	for _, tbl := range tables {
		if tbl.Name == "filter" {
			filter = tbl
			break
		}
	}

	ifSet := &nftables.Set{
		Table:   filter,
		Name:    "if_set",
		KeyType: nftables.TypeIFName,
	}
	if err := c.AddSet(ifSet, nil); err != nil {
		t.Errorf("c.AddSet(ifSet) failed: %v", err)
	}
	if err := c.Flush(); err != nil {
		t.Errorf("failed adding set ifSet: %v", err)
	}

	ifSet, err = c.GetSetByName(filter, "if_set")
	if err != nil {
		// ifSet is not usable after a failed lookup, so stop here instead of
		// handing it to the calls below.
		t.Fatalf("failed getting set by name: %v", err)
	}

	elems, err := c.GetSetElements(ifSet)
	if err != nil {
		t.Errorf("failed getting set elements (ifSet): %v", err)
	}

	if got, want := len(elems), 0; got != want {
		t.Errorf("first GetSetElements(ifSet) call len not equal: got %d, want %d", got, want)
	}

	elements := []nftables.SetElement{
		{Key: []byte("012345678912345\x00")},
	}
	if err := c.SetAddElements(ifSet, elements); err != nil {
		t.Errorf("adding SetElements(ifSet) failed: %v", err)
	}
	if err := c.Flush(); err != nil {
		t.Errorf("failed adding set elements ifSet: %v", err)
	}

	elems, err = c.GetSetElements(ifSet)
	if err != nil {
		t.Fatalf("failed getting set elements (ifSet): %v", err)
	}

	if got, want := len(elems), 1; got != want {
		t.Fatalf("second GetSetElements(ifSet) call len not equal: got %d, want %d", got, want)
	}

	if got, want := elems, elements; !reflect.DeepEqual(got, want) {
		t.Errorf("SetElements(ifSet) not equal: got %v, want %v", got, want)
	}
}
func TestCreateUseNamedSet(t *testing.T) {
// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.

2
obj.go
View File

@ -207,7 +207,7 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) {
return nil, fmt.Errorf("SendMessages: %v", err)
}
reply, err := conn.Receive()
reply, err := receiveWithRetry(conn)
if err != nil {
return nil, fmt.Errorf("Receive: %v", err)
}

View File

@ -87,7 +87,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
return nil, fmt.Errorf("SendMessages: %v", err)
}
reply, err := conn.Receive()
reply, err := receiveWithRetry(conn)
if err != nil {
return nil, fmt.Errorf("Receive: %v", err)
}

6
set.go
View File

@ -783,7 +783,7 @@ func (cc *Conn) GetSets(t *Table) ([]*Set, error) {
return nil, fmt.Errorf("SendMessages: %v", err)
}
reply, err := conn.Receive()
reply, err := receiveWithRetry(conn)
if err != nil {
return nil, fmt.Errorf("Receive: %v", err)
}
@ -828,7 +828,7 @@ func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) {
return nil, fmt.Errorf("SendMessages: %w", err)
}
reply, err := conn.Receive()
reply, err := receiveWithRetry(conn)
if err != nil {
return nil, fmt.Errorf("Receive: %w", err)
}
@ -873,7 +873,7 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) {
return nil, fmt.Errorf("SendMessages: %v", err)
}
reply, err := conn.Receive()
reply, err := receiveWithRetry(conn)
if err != nil {
return nil, fmt.Errorf("Receive: %v", err)
}