From dc9df31dfaf32bd96579763ed05e7ceaba37d651 Mon Sep 17 00:00:00 2001
From: nickgarlis
Date: Sat, 5 Apr 2025 13:59:41 +0200
Subject: [PATCH] Automatically set socket read & write buffer sizes

This is an attempt to port the logic that nftables uses to automatically
adjust the recvmsg & sndmsg buffer sizes.

The implementation of setting sndmsg size is not the same as nftables
due to some limitations in the underlying netlink & socket libraries.

- https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n262
- https://git.netfilter.org/libmnl/tree/include/libmnl/libmnl.h?id=03da98bcd284d55212bc79e91dfb63da0ef7b937#n20
- https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n391
- https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n426
---
 conn.go          | 90 ++++++++++++++++++++++++++++++++++++++++++++++++
 nftables_test.go | 34 ++++++++++++++++++
 2 files changed, 124 insertions(+)

diff --git a/conn.go b/conn.go
index d974b80..27ad93d 100644
--- a/conn.go
+++ b/conn.go
@@ -250,6 +250,15 @@ func (cc *Conn) Flush() error {
 	}
 	defer func() { _ = closer() }()
 
+	err = cc.setWriteBuffer(conn)
+	if err != nil {
+		return err
+	}
+	err = cc.setReadBuffer(conn)
+	if err != nil {
+		return err
+	}
+
 	messages, err := conn.SendMessages(batch(cc.messages))
 	if err != nil {
 		return fmt.Errorf("SendMessages: %w", err)
@@ -423,3 +432,84 @@ func (cc *Conn) allocateTransactionID() uint32 {
 	}
 	return cc.lastID
 }
+
+// getMessageSize returns the total size of all messages in the buffer.
+func (cc *Conn) getMessageSize() int {
+	var total int
+	for _, msg := range cc.messages {
+		total += len(msg.Data) + 16 // 16 bytes for the header
+	}
+	return total
+}
+
+// setWriteBuffer automatically sets the write buffer of the given connection to
+// the accumulated message size. This is only done if the current write buffer
+// is smaller than the message size.
+//
+// nftables actually handles this differently, it multiplies the number of
+// iovec entries by 2MB. This is not possible to do here as our underlying
+// netlink and socket libraries will only add a single iovec entry and
+// won't expose the number of entries.
+// https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n262
+//
+// TODO: Update this function to mimic the behavior of nftables once those
+// limitations are no longer present.
+func (cc *Conn) setWriteBuffer(conn *netlink.Conn) error {
+	messageSize := cc.getMessageSize()
+	writeBuffer, err := conn.WriteBuffer()
+	if err != nil {
+		return err
+	}
+	if writeBuffer < messageSize {
+		return conn.SetWriteBuffer(messageSize)
+	}
+
+	return nil
+}
+
+// getDefaultEchoReadBuffer returns the minimum read buffer size for batches
+// with echo messages.
+//
+// See https://git.netfilter.org/libmnl/tree/include/libmnl/libmnl.h?id=03da98bcd284d55212bc79e91dfb63da0ef7b937#n20
+// and https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n391
+func (cc *Conn) getDefaultEchoReadBuffer() int {
+	pageSize := os.Getpagesize()
+	if pageSize < 8192 {
+		return pageSize * 1024
+	}
+
+	return 8192 * 1024
+}
+
+// setReadBuffer automatically sets the read buffer of the given connection
+// to the required size. This is only done if the current read buffer is smaller
+// than the required size.
+//
+// See https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n426
+func (cc *Conn) setReadBuffer(conn *netlink.Conn) error {
+	var bufferSize int
+
+	// If there are any messages with the Echo flag, we initialize the buffer size
+	// to the default echo read buffer size.
+	for _, msg := range cc.messages {
+		if msg.Header.Flags&netlink.Echo != 0 {
+			bufferSize = cc.getDefaultEchoReadBuffer()
+			break
+		}
+	}
+
+	// Just like nftables, we allocate 1024 bytes for each message in the batch.
+	requiredSize := len(cc.messages) * 1024
+	if bufferSize < requiredSize {
+		bufferSize = requiredSize
+	}
+
+	currSize, err := conn.ReadBuffer()
+	if err != nil {
+		return err
+	}
+	if currSize < bufferSize {
+		return conn.SetReadBuffer(bufferSize)
+	}
+	return nil
+}
diff --git a/nftables_test.go b/nftables_test.go
index fd0fa1a..263c4a3 100644
--- a/nftables_test.go
+++ b/nftables_test.go
@@ -7396,3 +7396,37 @@ func TestSetElementComment(t *testing.T) {
 		}
 	}
 }
+
+func TestAutoBufferSize(t *testing.T) {
+	conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
+	defer nftest.CleanupSystemConn(t, newNS)
+	conn.FlushRuleset()
+	defer conn.FlushRuleset()
+
+	table := conn.AddTable(&nftables.Table{
+		Family: nftables.TableFamilyIPv4,
+		Name:   "test-table",
+	})
+
+	chain := conn.AddChain(&nftables.Chain{
+		Name:  "test-chain",
+		Table: table,
+	})
+
+	for i := 0; i < 4096; i++ {
+		conn.AddRule(&nftables.Rule{
+			Table: table,
+			Chain: chain,
+			Exprs: []expr.Any{
+				&expr.Verdict{
+					Kind: expr.VerdictAccept,
+				},
+			},
+		})
+	}
+
+	err := conn.Flush()
+	if err != nil {
+		t.Fatalf("failed to flush: %v", err)
+	}
+}