From 5af35a87425ffda680d2fb21b3b74b59c6ea761e Mon Sep 17 00:00:00 2001 From: nickgarlis Date: Wed, 26 Mar 2025 23:27:19 +0100 Subject: [PATCH] Add TestGetMessageSize --- nftables_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/nftables_test.go b/nftables_test.go index fd0fa1a..75f187a 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -7396,3 +7396,77 @@ func TestSetElementComment(t *testing.T) { } } } + +func TestGetMessageSize(t *testing.T) { + // Use generous socket buffer sizes + writeBufSize := 1 * 1024 * 1024 + readBufSize := 1 * 1024 * 1024 + _, newNS := nftest.OpenSystemConn(t, *enableSysTests) + conn, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.WithSockOptions(func(conn *netlink.Conn) error { + if err := conn.SetWriteBuffer(writeBufSize); err != nil { + return err + } + if err := conn.SetReadBuffer(readBufSize); err != nil { + return err + } + return nil + })) + if err != nil { + t.Fatalf("nftables.New() failed: %v", err) + } + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + want := uint32(20) + size := conn.GetMessageSize() + if size != want { + t.Fatalf("got message size %d, want %d", size, want) + } + + table := conn.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "test-table", + }) + want += 44 // Table adds 44 bytes + size = conn.GetMessageSize() + if size != want { + t.Fatalf("got message size %d, want %d", size, want) + } + + chain := conn.AddChain(&nftables.Chain{ + Name: "test-message", + Table: table, + }) + want += 56 // Chain adds 56 bytes + size = conn.GetMessageSize() + if size != want { + t.Fatalf("got message size %d, want %d", size, want) + } + + for i := 0; i < 2048; i++ { + conn.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + }) + want += 116 // Rule adds 116 bytes + size = conn.GetMessageSize() + if size != want { + t.Fatalf("got message size %d, want %d", size, want) + } + + // Do not use over half of the buffer before flushing + if size > 32768/2 { + want = 0 + err := conn.Flush() + if err != nil { + t.Fatalf("failed to flush: %v", err) + } + } + } +}