From 4b39769321221adef3b4ad79e984b408533acf0e Mon Sep 17 00:00:00 2001 From: psondej Date: Mon, 25 Nov 2024 20:10:10 +0100 Subject: [PATCH 1/2] fix: resolve deadlock in `Flush` function when handling ENOBUFS error --- conn.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 25d88e0..fef9c2a 100644 --- a/conn.go +++ b/conn.go @@ -19,6 +19,7 @@ import ( "fmt" "os" "sync" + "syscall" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" @@ -266,8 +267,8 @@ func (cc *Conn) Flush() error { // Fetch the requested acknowledgement for each message we sent. for _, msg := range cc.messages { if _, err := receiveAckAware(conn, msg.Header.Flags); err != nil { - if errors.Is(err, os.ErrPermission) { - // Kernel will only send one permission error to user space. + if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) { + // Kernel will only send one error to user space. return err } errs = errors.Join(errs, err) From 198b2be13555e0ec868fab754d0c0eeaaa9eab15 Mon Sep 17 00:00:00 2001 From: psondej Date: Mon, 25 Nov 2024 20:11:05 +0100 Subject: [PATCH 2/2] Simulate deadlock issue using reduced read/write buffers to verify the fix and ensure no regressions --- nftables_test.go | 80 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/nftables_test.go b/nftables_test.go index b241327..c537cdb 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -23,6 +23,7 @@ import ( "os" "reflect" "strings" + "syscall" "testing" "time" @@ -7666,3 +7667,82 @@ func TestNftablesCompat(t *testing.T) { t.Fatalf("compat policy should conflict and err should not be err") } } + +func TestNftablesDeadlock(t *testing.T) { + helperConn := func(t *testing.T, readBufSize, writeBufSize, wantRules int) (error, int) { + _, newNS := nftest.OpenSystemConn(t, *enableSysTests) + conn, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.WithSockOptions(func(conn *netlink.Conn) error { + if err := conn.SetWriteBuffer(writeBufSize); err != nil { + return err + } + if err := conn.SetReadBuffer(readBufSize); err != nil { + return err + } + return nil + })) + if err != nil { + t.Fatalf("nftables.New() failed: %v", err) + } + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := conn.AddTable(&nftables.Table{ + Name: "test_deadlock", + Family: nftables.TableFamilyIPv4, + }) + + chain := conn.AddChain(&nftables.Chain{ + Name: "filter", + Table: table, + }) + + for i := 0; i < wantRules; i++ { + conn.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + }) + } + + flushErr := conn.Flush() + + rules, err := conn.GetRules(table, chain) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + + return flushErr, len(rules) + } + + t.Run("recv", func(t *testing.T) { + sendRules := 2048 + wantRules := 2048 + + flushErr, rulesLen := helperConn(t, 1024, 1*1024*1024, sendRules) + if !errors.Is(flushErr, syscall.ENOBUFS) { + t.Errorf("conn.Flush() failed: %v", flushErr) + } + + if got, want := rulesLen, wantRules; got != want { + t.Fatalf("got rules %d, want rules %d", got, want) + } + }) + t.Run("send", func(t *testing.T) { + sendRules := 2048 + wantRules := 0 + + flushErr, rulesLen := helperConn(t, 1*1024*1024, 1024, sendRules) + if !errors.Is(flushErr, syscall.EMSGSIZE) { + t.Errorf("conn.Flush() failed: %v", flushErr) + } + + if got, want := rulesLen, wantRules; got != want { + t.Fatalf("got rules %d, want rules %d", got, want) + } + }) +}