Netlink conn retry logic
Fixes https://github.com/google/nftables/issues/175 | Adds receiveWithRetry in conn.go | Adds tests for "no error" message use cases
This commit is contained in:
parent
cbeb0fb1ec
commit
0d4369aacb
49
conn.go
49
conn.go
|
@ -15,15 +15,22 @@
|
|||
package nftables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/mdlayher/netlink"
|
||||
"github.com/mdlayher/netlink/nltest"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// errMsgNoError is returned in cases when netlink returns an error
|
||||
// message with a status 0, which means that there is no error
|
||||
// https://github.com/mdlayher/netlink/blob/9a593f9dc1a92ae1c7c45a9fc663b7dd1e111eac/message.go#L285-L289
|
||||
var errMsgNoError = errors.New("no error")
|
||||
|
||||
// A Conn represents a netlink connection of the nftables family.
|
||||
//
|
||||
// All methods return their input, so that variables can be defined from string
|
||||
|
@ -130,6 +137,48 @@ func (cc *Conn) netlinkConnUnderLock() (*netlink.Conn, netlinkCloser, error) {
|
|||
return nlconn, func() error { return nlconn.Close() }, nil
|
||||
}
|
||||
|
||||
func receiveWithRetry(nlconn *netlink.Conn) ([]netlink.Message, error) {
|
||||
if nlconn == nil {
|
||||
return nil, fmt.Errorf("netlink conn is not initialized")
|
||||
}
|
||||
|
||||
receive := func() ([]netlink.Message, error) {
|
||||
var err error
|
||||
reply, err := nlconn.Receive()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(reply) == 0 {
|
||||
return reply, nil
|
||||
}
|
||||
msg := reply[0]
|
||||
if msg.Header.Type == netlink.Error {
|
||||
if binaryutil.BigEndian.Uint32(msg.Data[:4]) != 0 {
|
||||
return nil, fmt.Errorf("error delivered in message: %v", msg.Data)
|
||||
}
|
||||
return nil, errMsgNoError
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
msgs, err := receive()
|
||||
if err != nil {
|
||||
// when "overflowing" counter objects netlink will return multiple
|
||||
// "no error" messages (see TestCappedErrMsgOnObj in nftables_test.go)
|
||||
// the loop is here to filter them all out
|
||||
for err == errMsgNoError {
|
||||
msgs, err := receive()
|
||||
if err == nil {
|
||||
return msgs, nil
|
||||
}
|
||||
if err != errMsgNoError {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return msgs, err
|
||||
}
|
||||
|
||||
// CloseLasting closes the lasting netlink connection that has been opened using
|
||||
// AsLasting option when creating this connection. If either no lasting netlink
|
||||
// connection has been opened or the lasting connection is already in the
|
||||
|
|
145
nftables_test.go
145
nftables_test.go
|
@ -2457,6 +2457,151 @@ func TestCreateUseAnonymousSet(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCappedErrMsgOnObj(t *testing.T) {
|
||||
c, newNS := openSystemNFTConn(t)
|
||||
c, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting())
|
||||
if err != nil {
|
||||
t.Fatalf("nftables.New() failed: %v", err)
|
||||
}
|
||||
defer cleanupSystemNFTConn(t, newNS)
|
||||
c.FlushRuleset()
|
||||
defer c.FlushRuleset()
|
||||
|
||||
filter := c.AddTable(&nftables.Table{
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
Name: "filter",
|
||||
})
|
||||
if err := c.Flush(); err != nil {
|
||||
t.Errorf("failed adding table: %v", err)
|
||||
}
|
||||
|
||||
const counterLen = 500
|
||||
counters := make([]nftables.Obj, counterLen)
|
||||
for i := 0; i < counterLen; i++ {
|
||||
counters[i] = &nftables.CounterObj{
|
||||
Table: filter,
|
||||
Name: fmt.Sprintf("ctr%d", i),
|
||||
Bytes: uint64(-1 * i),
|
||||
Packets: uint64(-1 * i),
|
||||
}
|
||||
c.AddObj(counters[i])
|
||||
}
|
||||
|
||||
if err := c.Flush(); err != nil {
|
||||
// this test leverages the fact that read buffer of the netlink socket
|
||||
// is too small to receive all ack messages sent by the kernel
|
||||
//
|
||||
// the Flush method tries to receive all acks after sending the messages
|
||||
// https://github.com/google/nftables/blob/cbeb0fb1eccf9ef582c20982c72e73107d1898a5/conn.go#L185-L193
|
||||
//
|
||||
// buffer peeking implemented by the netlink go library fails with a "no buffer space available"
|
||||
// https://github.com/mdlayher/netlink/blob/7fa043dcb6f27ed7e084dc90ddf5b6d4092478a9/conn_linux.go#L123-L127
|
||||
// this results in part of "non-received" acknowledgments to end up being received
|
||||
// in the following c.GetObj(counters[0]) call
|
||||
//
|
||||
// this can be fixed with extending the read buffer with conn.SetReadBuffer(size int)
|
||||
// https://github.com/mdlayher/netlink/blob/7fa043dcb6f27ed7e084dc90ddf5b6d4092478a9/conn_linux.go#L204-L206
|
||||
// in which case this test would still succeed
|
||||
if got, want := fmt.Sprint(err), "no buffer space available"; !strings.Contains(got, want) {
|
||||
t.Errorf("expected error %s, got %v", want, err)
|
||||
}
|
||||
}
|
||||
|
||||
// this GetObj call will receive some of the previously "non-received"
|
||||
// acknowledgments which should get dropped by the receiveWithRetry
|
||||
// function and successfully parse 500 counters
|
||||
objs, err := c.GetObj(counters[0])
|
||||
if err != nil {
|
||||
t.Errorf("failed getting objects: %v", err)
|
||||
}
|
||||
|
||||
if got, want := len(objs), len(counters); got != want {
|
||||
t.Errorf("object list length not equal: got %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if got, want := objs, counters; !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("object list not equal: got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCappedErrMsgOnSets(t *testing.T) {
|
||||
c, newNS := openSystemNFTConn(t)
|
||||
c, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting())
|
||||
if err != nil {
|
||||
t.Fatalf("nftables.New() failed: %v", err)
|
||||
}
|
||||
defer cleanupSystemNFTConn(t, newNS)
|
||||
c.FlushRuleset()
|
||||
defer c.FlushRuleset()
|
||||
|
||||
filter := c.AddTable(&nftables.Table{
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
Name: "filter",
|
||||
})
|
||||
if err := c.Flush(); err != nil {
|
||||
t.Errorf("failed adding table: %v", err)
|
||||
}
|
||||
tables, err := c.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
t.Errorf("failed to list IPv4 tables: %v", err)
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == "filter" {
|
||||
filter = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ifSet := &nftables.Set{
|
||||
Table: filter,
|
||||
Name: "if_set",
|
||||
KeyType: nftables.TypeIFName,
|
||||
}
|
||||
if err := c.AddSet(ifSet, nil); err != nil {
|
||||
t.Errorf("c.AddSet(ifSet) failed: %v", err)
|
||||
}
|
||||
if err := c.Flush(); err != nil {
|
||||
t.Errorf("failed adding set ifSet: %v", err)
|
||||
}
|
||||
ifSet, err = c.GetSetByName(filter, "if_set")
|
||||
if err != nil {
|
||||
t.Errorf("failed getting set by name: %v", err)
|
||||
}
|
||||
|
||||
elems, err := c.GetSetElements(ifSet)
|
||||
if err != nil {
|
||||
t.Errorf("failed getting set elements (ifSet): %v", err)
|
||||
}
|
||||
|
||||
if got, want := len(elems), 0; got != want {
|
||||
t.Errorf("first GetSetElements(ifSet) call len not equal: got %d, want %d", got, want)
|
||||
}
|
||||
|
||||
elements := []nftables.SetElement{
|
||||
{Key: []byte("012345678912345\x00")},
|
||||
}
|
||||
if err := c.SetAddElements(ifSet, elements); err != nil {
|
||||
t.Errorf("adding SetElements(ifSet) failed: %v", err)
|
||||
}
|
||||
if err := c.Flush(); err != nil {
|
||||
t.Errorf("failed adding set elements ifSet: %v", err)
|
||||
}
|
||||
|
||||
elems, err = c.GetSetElements(ifSet)
|
||||
if err != nil {
|
||||
t.Fatalf("failed getting set elements (ifSet): %v", err)
|
||||
}
|
||||
|
||||
if got, want := len(elems), 1; got != want {
|
||||
t.Fatalf("second GetSetElements(ifSet) call len not equal: got %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if got, want := elems, elements; !reflect.DeepEqual(elems, elements) {
|
||||
t.Errorf("SetElements(ifSet) not equal: got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateUseNamedSet(t *testing.T) {
|
||||
// Create a new network namespace to test these operations,
|
||||
// and tear down the namespace at test completion.
|
||||
|
|
2
obj.go
2
obj.go
|
@ -207,7 +207,7 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) {
|
|||
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||
}
|
||||
|
||||
reply, err := conn.Receive()
|
||||
reply, err := receiveWithRetry(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Receive: %v", err)
|
||||
}
|
||||
|
|
2
rule.go
2
rule.go
|
@ -87,7 +87,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
|
|||
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||
}
|
||||
|
||||
reply, err := conn.Receive()
|
||||
reply, err := receiveWithRetry(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Receive: %v", err)
|
||||
}
|
||||
|
|
6
set.go
6
set.go
|
@ -783,7 +783,7 @@ func (cc *Conn) GetSets(t *Table) ([]*Set, error) {
|
|||
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||
}
|
||||
|
||||
reply, err := conn.Receive()
|
||||
reply, err := receiveWithRetry(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Receive: %v", err)
|
||||
}
|
||||
|
@ -828,7 +828,7 @@ func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) {
|
|||
return nil, fmt.Errorf("SendMessages: %w", err)
|
||||
}
|
||||
|
||||
reply, err := conn.Receive()
|
||||
reply, err := receiveWithRetry(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Receive: %w", err)
|
||||
}
|
||||
|
@ -873,7 +873,7 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) {
|
|||
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||
}
|
||||
|
||||
reply, err := conn.Receive()
|
||||
reply, err := receiveWithRetry(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Receive: %v", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue