diff --git a/conn.go b/conn.go index d974b80..c6c85bb 100644 --- a/conn.go +++ b/conn.go @@ -116,7 +116,9 @@ func WithTestDial(f nltest.Func) ConnOption { } // WithSockOptions sets the specified socket options when creating a new netlink -// connection. +// connection. Note that when using WithSockOptions, you are responsible for +// providing a large-enough read and write buffer, whereas normally, the +// nftables package automatically enlarges the buffers as needed. func WithSockOptions(opts ...SockOption) ConnOption { return func(cc *Conn) { cc.sockOptions = append(cc.sockOptions, opts...) @@ -250,6 +252,13 @@ func (cc *Conn) Flush() error { } defer func() { _ = closer() }() + if err := cc.enlargeWriteBuffer(conn); err != nil { + return err + } + if err := cc.enlargeReadBuffer(conn); err != nil { + return err + } + messages, err := conn.SendMessages(batch(cc.messages)) if err != nil { return fmt.Errorf("SendMessages: %w", err) @@ -423,3 +432,104 @@ func (cc *Conn) allocateTransactionID() uint32 { } return cc.lastID } + +// getMessageSize returns the total size of all messages in the buffer. +func (cc *Conn) getMessageSize() int { + var total int + for _, msg := range cc.messages { + total += len(msg.Data) + unix.NLMSG_HDRLEN + } + return total +} + +// canEnlargeBuffers returns true if the connection can automatically enlarge +// the write and read buffers of the netlink connection. +func (cc *Conn) canEnlargeBuffers() bool { + // If there are sock options, we assume that the user has already set the + // buffers to a fixed size. + if len(cc.sockOptions) > 0 { + return false + } + + if cc.TestDial != nil { + return false + } + + return true +} + +// enlargeWriteBuffer automatically sets the write buffer of the given +// connection to the accumulated message size. This is only done if the current +// write buffer is smaller than the message size. +// +// nftables actually handles this differently, it multiplies the number of +// iovec entries by 2MB. This is not possible to do here as our underlying +// netlink and socket libraries will only add a single iovec entry and +// won't expose the number of entries. +// https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n262 +// +// TODO: Update this function to mimic the behavior of nftables once our +// socket library supports multiple iovec entries. +func (cc *Conn) enlargeWriteBuffer(conn *netlink.Conn) error { + if !cc.canEnlargeBuffers() { + return nil + } + + messageSize := cc.getMessageSize() + writeBuffer, err := conn.WriteBuffer() + if err != nil { + return err + } + if writeBuffer < messageSize { + return conn.SetWriteBuffer(messageSize) + } + + return nil +} + +// getDefaultEchoReadBuffer returns the minimum read buffer size for batches +// with echo messages. +// +// See https://git.netfilter.org/libmnl/tree/include/libmnl/libmnl.h?id=03da98bcd284d55212bc79e91dfb63da0ef7b937#n20 +// and https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n391 +func (cc *Conn) getDefaultEchoReadBuffer() int { + pageSize := os.Getpagesize() + return max(pageSize, 8192) * 1024 +} + +// enlargeReadBuffer automatically sets the read buffer of the given connection +// to the required size. This is only done if the current read buffer is smaller +// than the required size. +// +// See https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n426 +func (cc *Conn) enlargeReadBuffer(conn *netlink.Conn) error { + if !cc.canEnlargeBuffers() { + return nil + } + + var bufferSize int + + // If there are any messages with the Echo flag, we initialize the buffer size + // to the default echo read buffer size. + for _, msg := range cc.messages { + if msg.Header.Flags&netlink.Echo == 0 { + bufferSize = cc.getDefaultEchoReadBuffer() + break + } + } + + // Just like nftables, we allocate 1024 bytes for each message in the batch. + requiredSize := len(cc.messages) * 1024 + if bufferSize < requiredSize { + bufferSize = requiredSize + } + + currSize, err := conn.ReadBuffer() + if err != nil { + return err + } + if currSize < bufferSize { + return conn.SetReadBuffer(bufferSize) + } + return nil +} diff --git a/go.mod b/go.mod index 543f5d0..7f1707f 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,14 @@ go 1.23.0 require ( github.com/google/go-cmp v0.6.0 - github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 + github.com/mdlayher/netlink v1.7.3-0.20250702063131-0f7746f74615 github.com/vishvananda/netlink v1.3.0 github.com/vishvananda/netns v0.0.4 golang.org/x/sys v0.31.0 ) require ( - github.com/mdlayher/socket v0.5.0 // indirect + github.com/mdlayher/socket v0.5.1 // indirect golang.org/x/net v0.38.0 // indirect golang.org/x/sync v0.6.0 // indirect ) diff --git a/go.sum b/go.sum index 0c8b25d..ec8b934 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,9 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= -github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= -github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= -github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI= +github.com/mdlayher/netlink v1.7.3-0.20250702063131-0f7746f74615 h1:5T2ai+PpYFKe+tyNj/ZxePZGiYoG5xDOylT30nywJUU= +github.com/mdlayher/netlink v1.7.3-0.20250702063131-0f7746f74615/go.mod h1:ZlWrPUV9wyD64k5skWrIv4WDQmmiUbkNnCkBtEWYKwU= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= diff --git a/nftables_test.go b/nftables_test.go index fd0fa1a..fe12566 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -7396,3 +7396,36 @@ func TestSetElementComment(t *testing.T) { } } } + +func TestAutoBufferSize(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := conn.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "test-table", + }) + + chain := conn.AddChain(&nftables.Chain{ + Name: "test-chain", + Table: table, + }) + + for range 4096 { + conn.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + }) + } + + if err := conn.Flush(); err != nil { + t.Fatalf("failed to flush: %v", err) + } +}