From eeaebcf552951a01af8c359bec279367468aca28 Mon Sep 17 00:00:00 2001 From: TheDiveO <6920158+thediveo@users.noreply.github.com> Date: Mon, 9 May 2022 13:25:29 +0200 Subject: [PATCH] add New constructor (with options functions, such as lasting connection) * Close receiver for lasting netlink connections while defaulting to existing temporary netlink connection usage * add unit test for New lasting connection, Close and correct default connection handling behavior * refactor tests to use New constructor * make Conn mutex un-exported (#159) fixes issue #157 --- chain.go | 16 +- conn.go | 137 +++++++++++++++-- nftables_test.go | 386 ++++++++++++++++++++++++++++++++--------------- obj.go | 12 +- rule.go | 12 +- set.go | 32 ++-- table.go | 17 +-- 7 files changed, 440 insertions(+), 172 deletions(-) diff --git a/chain.go b/chain.go index bcc35de..f2bef40 100644 --- a/chain.go +++ b/chain.go @@ -96,8 +96,8 @@ type Chain struct { // AddChain adds the specified Chain. See also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_base_chains func (cc *Conn) AddChain(c *Chain) *Chain { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_CHAIN_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, @@ -137,8 +137,8 @@ func (cc *Conn) AddChain(c *Chain) *Chain { // DelChain deletes the specified Chain. See also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Deleting_chains func (cc *Conn) DelChain(c *Chain) { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_CHAIN_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, @@ -156,8 +156,8 @@ func (cc *Conn) DelChain(c *Chain) { // FlushChain removes all rules within the specified Chain. See also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Flushing_chain func (cc *Conn) FlushChain(c *Chain) { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, @@ -173,11 +173,11 @@ func (cc *Conn) FlushChain(c *Chain) { // ListChains returns currently configured chains in the kernel func (cc *Conn) ListChains() ([]*Chain, error) { - conn, err := cc.dialNetlink() + conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } - defer conn.Close() + defer func() { _ = closer() }() msg := netlink.Message{ Header: netlink.Header{ diff --git a/conn.go b/conn.go index d7659d7..2b20ea4 100644 --- a/conn.go +++ b/conn.go @@ -32,18 +32,138 @@ import ( // Commands are buffered. Flush sends all buffered commands in a single batch. type Conn struct { TestDial nltest.Func // for testing only; passed to nltest.Dial - NetNS int // Network namespace netlink will interact with. - sync.Mutex + NetNS int // fd referencing the network namespace netlink will interact with. + + lasting bool // establish a lasting connection to be used across multiple netlink operations. + mu sync.Mutex // protects the following state messages []netlink.Message err error + nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. +} + +// ConnOption is an option to change the behavior of the nftables Conn returned by Open. +type ConnOption func(*Conn) + +// New returns a netlink connection for querying and modifying nftables. Some +// aspects of the new netlink connection can be configured using the options +// WithNetNSFd, WithTestDial, and AsLasting. +// +// A lasting netlink connection should be closed by calling CloseLasting() to +// close the underlying lasting netlink connection, cancelling all pending +// operations using this connection. +func New(opts ...ConnOption) (*Conn, error) { + cc := &Conn{} + for _, opt := range opts { + opt(cc) + } + + if !cc.lasting { + return cc, nil + } + + nlconn, err := cc.dialNetlink() + if err != nil { + return nil, err + } + cc.nlconn = nlconn + return cc, nil +} + +// AsLasting creates the new netlink connection as a lasting connection that is +// reused across multiple netlink operations, instead of opening and closing the +// underlying netlink connection only for the duration of a single netlink +// operation. +func AsLasting() ConnOption { + return func(cc *Conn) { + // We cannot create the underlying connection yet, as we are called + // anywhere in the option processing chain and there might be later + // options still modifying connection behavior. + cc.lasting = true + } +} + +// WithNetNSFd sets the network namespace to create a new netlink connection to: +// the fd must reference a network namespace. +func WithNetNSFd(fd int) ConnOption { + return func(cc *Conn) { + cc.NetNS = fd + } +} + +// WithTestDial sets the specified nltest.Func when creating a new netlink +// connection. +func WithTestDial(f nltest.Func) ConnOption { + return func(cc *Conn) { + cc.TestDial = f + } +} + +// netlinkCloser is returned by netlinkConn(UnderLock) and must be called after +// being done with the returned netlink connection in order to properly close +// this connection, if necessary. +type netlinkCloser func() error + +// netlinkConn returns a netlink connection together with a netlinkCloser that +// later must be called by the caller when it doesn't need the returned netlink +// connection anymore. The netlinkCloser will close the netlink connection when +// necessary. If New has been told to create a lasting connection, then this +// lasting netlink connection will be returned, otherwise a new "transient" +// netlink connection will be opened and returned instead. netlinkConn must not +// be called while the Conn.mu lock is currently helt (this will cause a +// deadlock). Use netlinkConnUnderLock instead in such situations. +func (cc *Conn) netlinkConn() (*netlink.Conn, netlinkCloser, error) { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.netlinkConnUnderLock() +} + +// netlinkConnUnderLock works like netlinkConn but must be called while holding +// the Conn.mu lock. +func (cc *Conn) netlinkConnUnderLock() (*netlink.Conn, netlinkCloser, error) { + if cc.nlconn != nil { + return cc.nlconn, func() error { return nil }, nil + } + nlconn, err := cc.dialNetlink() + if err != nil { + return nil, nil, err + } + return nlconn, func() error { return nlconn.Close() }, nil +} + +// CloseLasting closes the lasting netlink connection that has been opened using +// AsLasting option when creating this connection. If either no lasting netlink +// connection has been opened or the lasting connection is already in the +// process of closing or has been closed, CloseLasting will immediately return +// without any error. +// +// CloseLasting will terminate all pending netlink operations using the lasting +// connection. +// +// After closing a lasting connection, the connection will revert to using +// on-demand transient netlink connections when calling further netlink +// operations (such as GetTables). +func (cc *Conn) CloseLasting() error { + // Don't acquire the lock for the whole duration of the CloseLasting + // operation, but instead only so long as to make sure to only run the + // netlink socket close on the first time with a lasting netlink socket. As + // there is only the New() constructor, but no Open() method, it's + // impossible to reopen a lasting connection. + cc.mu.Lock() + nlconn := cc.nlconn + cc.nlconn = nil + cc.mu.Unlock() + if nlconn != nil { + return nlconn.Close() + } + return nil } // Flush sends all buffered commands in a single batch to nftables. func (cc *Conn) Flush() error { - cc.Lock() + cc.mu.Lock() defer func() { cc.messages = nil - cc.Unlock() + cc.mu.Unlock() }() if len(cc.messages) == 0 { // Messages were already programmed, returning nil @@ -52,12 +172,11 @@ func (cc *Conn) Flush() error { if cc.err != nil { return cc.err // serialization error } - conn, err := cc.dialNetlink() + conn, closer, err := cc.netlinkConnUnderLock() if err != nil { return err } - - defer conn.Close() + defer func() { _ = closer() }() if _, err := conn.SendMessages(batch(cc.messages)); err != nil { return fmt.Errorf("SendMessages: %w", err) @@ -73,8 +192,8 @@ func (cc *Conn) Flush() error { // FlushRuleset flushes the entire ruleset. See also // https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level func (cc *Conn) FlushRuleset() { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), diff --git a/nftables_test.go b/nftables_test.go index 7ee23de..a91c6fa 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -16,6 +16,7 @@ package nftables_test import ( "bytes" + "errors" "flag" "fmt" "net" @@ -102,7 +103,11 @@ func openSystemNFTConn(t *testing.T) (*nftables.Conn, netns.NsHandle) { if err != nil { t.Fatalf("netns.New() failed: %v", err) } - return &nftables.Conn{NetNS: int(ns)}, ns + c, err := nftables.New(nftables.WithNetNSFd(int(ns))) + if err != nil { + t.Fatalf("nftables.New() failed: %v", err) + } + return c, ns } func cleanupSystemNFTConn(t *testing.T, newNS netns.NsHandle) { @@ -288,8 +293,8 @@ func TestConfigureNAT(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -309,7 +314,9 @@ func TestConfigureNAT(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -502,8 +509,8 @@ func TestConfigureNATSourceAddress(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -523,7 +530,9 @@ func TestConfigureNATSourceAddress(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -802,8 +811,8 @@ func TestGetRules(t *testing.T) { []netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x9acb0443, PID: 0xba38ef3c}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}}, } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -825,7 +834,9 @@ func TestGetRules(t *testing.T) { rep := reply[0] reply = reply[1:] return rep, nil - }, + })) + if err != nil { + t.Fatal(err) } rules, err := c.GetRules( @@ -881,8 +892,8 @@ func TestAddCounter(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -902,7 +913,9 @@ func TestAddCounter(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.AddObj(&nftables.CounterObj{ @@ -945,8 +958,8 @@ func TestDeleteCounter(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -966,7 +979,9 @@ func TestDeleteCounter(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.AddObj(&nftables.CounterObj{ @@ -996,8 +1011,8 @@ func TestDelRule(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -1017,7 +1032,9 @@ func TestDelRule(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.DelRule(&nftables.Rule{ @@ -1041,8 +1058,8 @@ func TestLog(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -1062,7 +1079,9 @@ func TestLog(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.AddRule(&nftables.Rule{ @@ -1091,8 +1110,8 @@ func TestTProxy(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -1112,7 +1131,9 @@ func TestTProxy(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.AddRule(&nftables.Rule{ @@ -1154,8 +1175,8 @@ func TestCt(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -1175,7 +1196,9 @@ func TestCt(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.AddRule(&nftables.Rule{ @@ -1207,8 +1230,8 @@ func TestCtSet(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -1228,7 +1251,9 @@ func TestCtSet(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.AddRule(&nftables.Rule{ @@ -1265,8 +1290,8 @@ func TestCtStat(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -1286,7 +1311,9 @@ func TestCtStat(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.AddRule(&nftables.Rule{ @@ -1323,8 +1350,8 @@ func TestAddRuleWithPosition(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -1344,7 +1371,9 @@ func TestAddRuleWithPosition(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.AddRule(&nftables.Rule{ @@ -1387,6 +1416,73 @@ func TestAddRuleWithPosition(t *testing.T) { } } +func TestLastingConnection(t *testing.T) { + testdialerr := errors.New("test dial sentinel error") + dialCount := 0 + c, err := nftables.New( + nftables.AsLasting(), + nftables.WithTestDial(func(req []netlink.Message) ([]netlink.Message, error) { + dialCount++ + return nil, testdialerr + })) + if err != nil { + t.Errorf("creating lasting netlink connection failed %v", err) + return + } + defer func() { + if err := c.CloseLasting(); err != nil { + t.Errorf("closing lasting netlink connection failed %v", err) + } + }() + + _, err = c.ListTables() + if !errors.Is(err, testdialerr) { + t.Errorf("non-testdialerr error returned from TestDial %v", err) + return + } + if dialCount != 1 { + t.Errorf("internal test error with TestDial invocations %v", dialCount) + return + } + + // While a lasting netlink connection is open, replacing TestDial must be + // ineffective as there is no need to dial again and activating a new + // TestDial function. The newly set TestDial function must be getting + // ignored. + c.TestDial = func(req []netlink.Message) ([]netlink.Message, error) { + dialCount-- + return nil, errors.New("transient netlink connection error") + } + _, err = c.ListTables() + if !errors.Is(err, testdialerr) { + t.Errorf("non-testdialerr error returned from TestDial %v", err) + return + } + if dialCount != 2 { + t.Errorf("internal test error with TestDial invocations %v", dialCount) + return + } + + for i := 0; i < 2; i++ { + err = c.CloseLasting() + if err != nil { + t.Errorf("closing lasting netlink connection failed in attempt no. %d: %v", i, err) + return + } + } + _, err = c.ListTables() + if errors.Is(err, testdialerr) { + t.Error("testdialerr error returned from TestDial when expecting different error") + return + } + if dialCount != 1 { + t.Errorf("internal test error with TestDial invocations %v", dialCount) + return + } + + // fall into defer'ed second CloseLasting which must not cause any errors. +} + func TestListChains(t *testing.T) { polDrop := nftables.ChainPolicyDrop polAcpt := nftables.ChainPolicyAccept @@ -1432,8 +1528,8 @@ func TestListChains(t *testing.T) { }, } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { msgReply := make([]netlink.Message, len(reply)) for i, r := range reply { nm := &netlink.Message{} @@ -1443,7 +1539,9 @@ func TestListChains(t *testing.T) { msgReply[i] = *nm } return msgReply, nil - }, + })) + if err != nil { + t.Fatal(err) } chains, err := c.ListChains() @@ -1522,8 +1620,8 @@ func TestAddChain(t *testing.T) { } for _, tt := range tests { - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -1543,7 +1641,9 @@ func TestAddChain(t *testing.T) { } } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } filter := c.AddTable(&nftables.Table{ @@ -1599,8 +1699,8 @@ func TestDelChain(t *testing.T) { } for _, tt := range tests { - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -1620,7 +1720,9 @@ func TestDelChain(t *testing.T) { } } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } tt.chain.Table = &nftables.Table{ @@ -1649,8 +1751,8 @@ func TestGetObjReset(t *testing.T) { []netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x9acb0443, PID: 0xde9}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}}, } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -1672,7 +1774,9 @@ func TestGetObjReset(t *testing.T) { rep := reply[0] reply = reply[1:] return rep, nil - }, + })) + if err != nil { + t.Fatal(err) } filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4} @@ -1892,8 +1996,8 @@ func TestConfigureClamping(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -1913,7 +2017,9 @@ func TestConfigureClamping(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -2025,8 +2131,8 @@ func TestMatchPacketHeader(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -2046,7 +2152,9 @@ func TestMatchPacketHeader(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -2153,8 +2261,8 @@ func TestDropVerdict(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -2174,7 +2282,9 @@ func TestDropVerdict(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -2253,8 +2363,8 @@ func TestCreateUseAnonymousSet(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -2274,7 +2384,9 @@ func TestCreateUseAnonymousSet(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -3221,8 +3333,8 @@ func TestConfigureNATRedirect(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -3242,7 +3354,9 @@ func TestConfigureNATRedirect(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -3326,8 +3440,8 @@ func TestConfigureJumpVerdict(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -3347,7 +3461,9 @@ func TestConfigureJumpVerdict(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -3432,8 +3548,8 @@ func TestConfigureReturnVerdict(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -3453,7 +3569,9 @@ func TestConfigureReturnVerdict(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -3517,8 +3635,8 @@ func TestConfigureRangePort(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -3538,7 +3656,9 @@ func TestConfigureRangePort(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -3615,8 +3735,8 @@ func TestConfigureRangeIPv4(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -3636,7 +3756,9 @@ func TestConfigureRangeIPv4(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -3705,8 +3827,8 @@ func TestConfigureRangeIPv6(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -3726,7 +3848,9 @@ func TestConfigureRangeIPv6(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -3818,8 +3942,8 @@ func TestSet4(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -3839,7 +3963,9 @@ func TestSet4(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } tbl := &nftables.Table{ @@ -4004,8 +4130,8 @@ func TestMasq(t *testing.T) { } for _, tt := range tests { - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -4025,7 +4151,9 @@ func TestMasq(t *testing.T) { } } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } filter := c.AddTable(&nftables.Table{ @@ -4135,8 +4263,8 @@ func TestReject(t *testing.T) { } for _, tt := range tests { - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -4156,7 +4284,9 @@ func TestReject(t *testing.T) { } } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } filter := c.AddTable(&nftables.Table{ @@ -4263,8 +4393,8 @@ func TestFib(t *testing.T) { } for _, tt := range tests { - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -4284,7 +4414,9 @@ func TestFib(t *testing.T) { } } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } filter := c.AddTable(&nftables.Table{ @@ -4367,8 +4499,8 @@ func TestNumgen(t *testing.T) { } for _, tt := range tests { - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -4388,7 +4520,9 @@ func TestNumgen(t *testing.T) { } } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } filter := c.AddTable(&nftables.Table{ @@ -4452,8 +4586,8 @@ func TestMap(t *testing.T) { } for _, tt := range tests { - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -4473,7 +4607,9 @@ func TestMap(t *testing.T) { } } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } filter := c.AddTable(&nftables.Table{ @@ -4570,8 +4706,8 @@ func TestVmap(t *testing.T) { } for _, tt := range tests { - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -4591,7 +4727,9 @@ func TestVmap(t *testing.T) { } } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } filter := c.AddTable(&nftables.Table{ @@ -4630,8 +4768,8 @@ func TestJHash(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -4651,7 +4789,9 @@ func TestJHash(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -4731,8 +4871,8 @@ func TestDup(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -4752,7 +4892,9 @@ func TestDup(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -4831,8 +4973,8 @@ func TestDupWoDev(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -4852,7 +4994,9 @@ func TestDupWoDev(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -4913,8 +5057,8 @@ func TestNotrack(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -4934,7 +5078,9 @@ func TestNotrack(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -4983,8 +5129,8 @@ func TestQuota(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -5004,7 +5150,9 @@ func TestQuota(t *testing.T) { } } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() @@ -5058,8 +5206,8 @@ func TestStatelessNAT(t *testing.T) { []byte("\x00\x00\x00\x0a"), } - c := &nftables.Conn{ - TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { b, err := msg.MarshalBinary() if err != nil { @@ -5079,7 +5227,9 @@ func TestStatelessNAT(t *testing.T) { want = want[1:] } return req, nil - }, + })) + if err != nil { + t.Fatal(err) } c.FlushRuleset() diff --git a/obj.go b/obj.go index a88aed4..3fd01e2 100644 --- a/obj.go +++ b/obj.go @@ -41,8 +41,8 @@ func (cc *Conn) AddObject(o Obj) Obj { // AddObj adds the specified Obj. See also // https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects func (cc *Conn) AddObj(o Obj) Obj { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() data, err := o.marshal(true) if err != nil { cc.setErr(err) @@ -61,8 +61,8 @@ func (cc *Conn) AddObj(o Obj) Obj { // DeleteObject deletes the specified Obj func (cc *Conn) DeleteObject(o Obj) { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() data, err := o.marshal(false) if err != nil { cc.setErr(err) @@ -174,11 +174,11 @@ func objFromMsg(msg netlink.Message) (Obj, error) { } func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { - conn, err := cc.dialNetlink() + conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } - defer conn.Close() + defer func() { _ = closer() }() var data []byte var flags netlink.HeaderFlags diff --git a/rule.go b/rule.go index be48792..f2057ee 100644 --- a/rule.go +++ b/rule.go @@ -60,11 +60,11 @@ func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) { // GetRules returns the rules in the specified table and chain. func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) { - conn, err := cc.dialNetlink() + conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } - defer conn.Close() + defer func() { _ = closer() }() data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, @@ -107,8 +107,8 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) { // AddRule adds the specified Rule func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() exprAttrs := make([]netlink.Attribute, len(r.Exprs)) for idx, expr := range r.Exprs { exprAttrs[idx] = netlink.Attribute{ @@ -190,8 +190,8 @@ func (cc *Conn) InsertRule(r *Rule) *Rule { // DelRule deletes the specified Rule, rule's handle cannot be 0 func (cc *Conn) DelRule(r *Rule) error { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(r.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(r.Chain.Name + "\x00")}, diff --git a/set.go b/set.go index 58ac250..1f1f777 100644 --- a/set.go +++ b/set.go @@ -318,8 +318,8 @@ func decodeElement(d []byte) ([]byte, error) { // SetAddElements applies data points to an nftables set. func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() if s.Anonymous { return errors.New("anonymous sets cannot be updated") } @@ -431,8 +431,8 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e // AddSet adds the specified Set. func (cc *Conn) AddSet(s *Set, vals []SetElement) error { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() // Based on nft implementation & linux source. // Link: https://github.com/torvalds/linux/blob/49a57857aeea06ca831043acbb0fa5e0f50602fd/net/netfilter/nf_tables_api.c#L3395 // Another reference: https://git.netfilter.org/nftables/tree/src @@ -567,8 +567,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { // DelSet deletes a specific set, along with all elements it contains. func (cc *Conn) DelSet(s *Set) { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, @@ -584,8 +584,8 @@ func (cc *Conn) DelSet(s *Set) { // SetDeleteElements deletes data points from an nftables set. func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() if s.Anonymous { return errors.New("anonymous sets cannot be updated") } @@ -607,8 +607,8 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { // FlushSet deletes all data points from an nftables set. func (cc *Conn) FlushSet(s *Set) { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, @@ -747,11 +747,11 @@ func elementsFromMsg(msg netlink.Message) ([]SetElement, error) { // GetSets returns the sets in the specified table. func (cc *Conn) GetSets(t *Table) ([]*Set, error) { - conn, err := cc.dialNetlink() + conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } - defer conn.Close() + defer func() { _ = closer() }() data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_SET_TABLE, Data: []byte(t.Name + "\x00")}, @@ -791,11 +791,11 @@ func (cc *Conn) GetSets(t *Table) ([]*Set, error) { // GetSetByName returns the set in the specified table if matching name is found. func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) { - conn, err := cc.dialNetlink() + conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } - defer conn.Close() + defer func() { _ = closer() }() data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_SET_TABLE, Data: []byte(t.Name + "\x00")}, @@ -836,11 +836,11 @@ func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) { // GetSetElements returns the elements in the specified set. func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) { - conn, err := cc.dialNetlink() + conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } - defer conn.Close() + defer func() { _ = closer() }() data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, diff --git a/table.go b/table.go index da0126a..d7ff330 100644 --- a/table.go +++ b/table.go @@ -47,8 +47,8 @@ type Table struct { // DelTable deletes a specific table, along with all chains/rules it contains. func (cc *Conn) DelTable(t *Table) { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, @@ -65,8 +65,8 @@ func (cc *Conn) DelTable(t *Table) { // AddTable adds the specified Table. See also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables func (cc *Conn) AddTable(t *Table) *Table { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, @@ -84,8 +84,8 @@ func (cc *Conn) AddTable(t *Table) *Table { // FlushTable removes all rules in all chains within the specified Table. See also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables#Flushing_tables func (cc *Conn) FlushTable(t *Table) { - cc.Lock() - defer cc.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, }) @@ -100,12 +100,11 @@ func (cc *Conn) FlushTable(t *Table) { // ListTables returns currently configured tables in the kernel func (cc *Conn) ListTables() ([]*Table, error) { - conn, err := cc.dialNetlink() + conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } - - defer conn.Close() + defer func() { _ = closer() }() msg := netlink.Message{ Header: netlink.Header{