From 319e79247e2dac5e006488954afed51832a112db Mon Sep 17 00:00:00 2001 From: nickgarlis Date: Sat, 28 Jun 2025 17:56:55 +0200 Subject: [PATCH] Address review comments --- conn.go | 30 ++++++++++++------------------ nftables_test.go | 5 ++--- 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/conn.go b/conn.go index 27ad93d..73eeeec 100644 --- a/conn.go +++ b/conn.go @@ -250,12 +250,10 @@ func (cc *Conn) Flush() error { } defer func() { _ = closer() }() - err = cc.setWriteBuffer(conn) - if err != nil { + if err = cc.enlargeWriteBuffer(conn); err != nil { return err } - err = cc.setReadBuffer(conn) - if err != nil { + if err = cc.enlargeReadBuffer(conn); err != nil { return err } @@ -437,14 +435,14 @@ func (cc *Conn) allocateTransactionID() uint32 { func (cc *Conn) getMessageSize() int { var total int for _, msg := range cc.messages { - total += len(msg.Data) + 16 // 16 bytes for the header + total += len(msg.Data) + unix.NLMSG_HDRLEN } return total } -// setWriteBuffer automatically sets the write buffer of the given connection to -// the accumulated message size. This is only done if the current write buffer -// is smaller than the message size. +// enlargeWriteBuffer automatically sets the write buffer of the given +// connection to the accumulated message size. This is only done if the current +// write buffer is smaller than the message size. // // nftables actually handles this differently, it multiplies the number of // iovec entries by 2MB. This is not possible to do here as our underlying @@ -452,9 +450,9 @@ func (cc *Conn) getMessageSize() int { // won't expose the number of entries. // https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n262 // -// TODO: Update this function to mimic the behavior of nftables once those -// limitations are no longer present. -func (cc *Conn) setWriteBuffer(conn *netlink.Conn) error { +// TODO: Update this function to mimic the behavior of nftables once our +// socket library supports multiple iovec entries. +func (cc *Conn) enlargeWriteBuffer(conn *netlink.Conn) error { messageSize := cc.getMessageSize() writeBuffer, err := conn.WriteBuffer() if err != nil { @@ -474,19 +472,15 @@ func (cc *Conn) setWriteBuffer(conn *netlink.Conn) error { // and https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n391 func (cc *Conn) getDefaultEchoReadBuffer() int { pageSize := os.Getpagesize() - if pageSize < 8192 { - return pageSize * 1024 - } - - return 8192 * 1024 + return max(pageSize, 8192) * 1024 } -// setReadBuffer automatically sets the read buffer of the given connection +// enlargeReadBuffer automatically sets the read buffer of the given connection // to the required size. This is only done if the current read buffer is smaller // than the required size. // // See https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n426 -func (cc *Conn) setReadBuffer(conn *netlink.Conn) error { +func (cc *Conn) enlargeReadBuffer(conn *netlink.Conn) error { var bufferSize int // If there are any messages with the Echo flag, we initialize the buffer size diff --git a/nftables_test.go b/nftables_test.go index 263c4a3..fe12566 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -7413,7 +7413,7 @@ func TestAutoBufferSize(t *testing.T) { Table: table, }) - for i := 0; i < 4096; i++ { + for range 4096 { conn.AddRule(&nftables.Rule{ Table: table, Chain: chain, @@ -7425,8 +7425,7 @@ func TestAutoBufferSize(t *testing.T) { }) } - err := conn.Flush() - if err != nil { + if err := conn.Flush(); err != nil { t.Fatalf("failed to flush: %v", err) } }