Compare commits

...

5 Commits

Author SHA1 Message Date
nickgarlis 8e29fbaa6e Add note about different behavior when using WithSockOptions 2025-07-01 09:57:38 +02:00
nickgarlis f8c01f0bf3 Address review comments 2025-06-30 21:28:35 +02:00
nickgarlis 0a7196fb65 Prevent buffer enlargement in specific scenarios
We should not enlarge the socket buffers when:
 - We are using a test dial (there are no buffers to enlarge).
 - A connection has been initialized with socket options which means
   that the user could have specified fixed buffer sizes.
2025-06-30 21:07:25 +02:00
nickgarlis 7d83c94f64 Address review comments 2025-06-28 18:00:17 +02:00
nickgarlis dc9df31dfa Automatically set socket read & write buffer sizes
This is an attempt to port the logic that nftables uses to automatically
adjust the recvmsg & sendmsg buffer sizes. The implementation of setting
the sendmsg size is not the same as in nftables due to some limitations in
the underlying netlink & socket libraries.

- https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n262
- https://git.netfilter.org/libmnl/tree/include/libmnl/libmnl.h?id=03da98bcd284d55212bc79e91dfb63da0ef7b937#n20
- https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n391
- https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n426
2025-04-05 13:59:46 +02:00
2 changed files with 144 additions and 1 deletions

112
conn.go
View File

@ -116,7 +116,9 @@ func WithTestDial(f nltest.Func) ConnOption {
} }
// WithSockOptions sets the specified socket options when creating a new netlink // WithSockOptions sets the specified socket options when creating a new netlink
// connection. // connection. Note that when using WithSockOptions, you are responsible for
// providing a large-enough read and write buffer, whereas normally, the
// nftables package automatically enlarges the buffers as needed.
func WithSockOptions(opts ...SockOption) ConnOption { func WithSockOptions(opts ...SockOption) ConnOption {
return func(cc *Conn) { return func(cc *Conn) {
cc.sockOptions = append(cc.sockOptions, opts...) cc.sockOptions = append(cc.sockOptions, opts...)
@ -250,6 +252,13 @@ func (cc *Conn) Flush() error {
} }
defer func() { _ = closer() }() defer func() { _ = closer() }()
if err := cc.enlargeWriteBuffer(conn); err != nil {
return err
}
if err := cc.enlargeReadBuffer(conn); err != nil {
return err
}
messages, err := conn.SendMessages(batch(cc.messages)) messages, err := conn.SendMessages(batch(cc.messages))
if err != nil { if err != nil {
return fmt.Errorf("SendMessages: %w", err) return fmt.Errorf("SendMessages: %w", err)
@ -423,3 +432,104 @@ func (cc *Conn) allocateTransactionID() uint32 {
} }
return cc.lastID return cc.lastID
} }
// getMessageSize reports how many bytes the currently queued messages occupy:
// for every message in cc.messages, the length of its Data plus one netlink
// message header (unix.NLMSG_HDRLEN).
func (cc *Conn) getMessageSize() int {
	size := len(cc.messages) * unix.NLMSG_HDRLEN
	for i := range cc.messages {
		size += len(cc.messages[i].Data)
	}
	return size
}
// canEnlargeBuffers reports whether the connection is allowed to grow the
// read and write buffers of the netlink socket on its own.
//
// Automatic growth is disabled in two situations:
//   - socket options were supplied, in which case we assume the user has
//     already pinned the buffers to a fixed size;
//   - a test dial function is set.
func (cc *Conn) canEnlargeBuffers() bool {
	return len(cc.sockOptions) == 0 && cc.TestDial == nil
}
// enlargeWriteBuffer grows the write buffer of conn to the accumulated size of
// the queued messages, but only when the current write buffer is smaller than
// that. It does nothing when canEnlargeBuffers reports false.
//
// This differs from nftables, which multiplies the number of iovec entries by
// 2MB. We cannot do the same because our underlying netlink and socket
// libraries only ever add a single iovec entry and do not expose the number
// of entries.
// https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n262
//
// TODO: Update this function to mimic the behavior of nftables once our
// socket library supports multiple iovec entries.
func (cc *Conn) enlargeWriteBuffer(conn *netlink.Conn) error {
	if !cc.canEnlargeBuffers() {
		return nil
	}

	current, err := conn.WriteBuffer()
	if err != nil {
		return err
	}

	// Leave a buffer that is already big enough untouched.
	needed := cc.getMessageSize()
	if current >= needed {
		return nil
	}
	return conn.SetWriteBuffer(needed)
}
// getDefaultEchoReadBuffer returns the minimum read buffer size used for
// batches that contain echo messages: the larger of the system page size and
// 8192, multiplied by 1024.
//
// NOTE(review): the libmnl macro referenced below appears to pick the smaller
// of the page size and 8192, whereas this picks the larger; confirm that the
// difference is intentional.
//
// See https://git.netfilter.org/libmnl/tree/include/libmnl/libmnl.h?id=03da98bcd284d55212bc79e91dfb63da0ef7b937#n20
// and https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n391
func (cc *Conn) getDefaultEchoReadBuffer() int {
	const (
		floor      = 8192
		multiplier = 1024
	)

	unit := os.Getpagesize()
	if unit < floor {
		unit = floor
	}
	return unit * multiplier
}
// enlargeReadBuffer grows the read buffer of the given connection to the size
// required for the queued batch. The buffer is only changed when its current
// size is smaller than the required size, and nothing is done at all when
// canEnlargeBuffers reports false.
//
// The required size is the larger of:
//   - the default echo read buffer size, if at least one queued message has
//     the Echo flag set (otherwise 0), and
//   - 1024 bytes for each queued message.
//
// Errors from reading or setting the socket buffer size are returned as-is.
//
// See https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n426
func (cc *Conn) enlargeReadBuffer(conn *netlink.Conn) error {
	if !cc.canEnlargeBuffers() {
		return nil
	}

	var bufferSize int
	// If there are any messages with the Echo flag, we initialize the buffer
	// size to the default echo read buffer size. One such message is enough,
	// so stop scanning at the first hit.
	for _, msg := range cc.messages {
		if msg.Header.Flags&netlink.Echo != 0 {
			bufferSize = cc.getDefaultEchoReadBuffer()
			break
		}
	}

	// Just like nftables, we allocate 1024 bytes for each message in the batch.
	bufferSize = max(bufferSize, len(cc.messages)*1024)

	currSize, err := conn.ReadBuffer()
	if err != nil {
		return err
	}

	// Never shrink a buffer that is already large enough.
	if currSize < bufferSize {
		return conn.SetReadBuffer(bufferSize)
	}
	return nil
}

View File

@ -7396,3 +7396,36 @@ func TestSetElementComment(t *testing.T) {
} }
} }
} }
// TestAutoBufferSize queues a batch of 4096 rules on a system connection and
// checks that flushing it succeeds.
func TestAutoBufferSize(t *testing.T) {
	c, ns := nftest.OpenSystemConn(t, *enableSysTests)
	defer nftest.CleanupSystemConn(t, ns)

	// Start from, and leave behind, an empty ruleset.
	c.FlushRuleset()
	defer c.FlushRuleset()

	tbl := c.AddTable(&nftables.Table{
		Family: nftables.TableFamilyIPv4,
		Name:   "test-table",
	})
	ch := c.AddChain(&nftables.Chain{
		Name:  "test-chain",
		Table: tbl,
	})

	const ruleCount = 4096
	for range ruleCount {
		accept := &expr.Verdict{Kind: expr.VerdictAccept}
		c.AddRule(&nftables.Rule{
			Table: tbl,
			Chain: ch,
			Exprs: []expr.Any{accept},
		})
	}

	err := c.Flush()
	if err != nil {
		t.Fatalf("failed to flush: %v", err)
	}
}