add New constructor (with options functions, such as lasting connection)

* Close receiver for lasting netlink connections while defaulting to existing temporary netlink connection usage
* add unit test for New lasting connection, Close and correct default connection handling behavior
* refactor tests to use New constructor
* make Conn mutex un-exported (#159)

fixes issue #157
This commit is contained in:
TheDiveO 2022-05-09 13:25:29 +02:00 committed by GitHub
parent 85d0f3a0db
commit eeaebcf552
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 440 additions and 172 deletions

View File

@ -96,8 +96,8 @@ type Chain struct {
// AddChain adds the specified Chain. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_base_chains
func (cc *Conn) AddChain(c *Chain) *Chain {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_CHAIN_TABLE, Data: []byte(c.Table.Name + "\x00")},
{Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")},
@ -137,8 +137,8 @@ func (cc *Conn) AddChain(c *Chain) *Chain {
// DelChain deletes the specified Chain. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Deleting_chains
func (cc *Conn) DelChain(c *Chain) {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_CHAIN_TABLE, Data: []byte(c.Table.Name + "\x00")},
{Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")},
@ -156,8 +156,8 @@ func (cc *Conn) DelChain(c *Chain) {
// FlushChain removes all rules within the specified Chain. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Flushing_chain
func (cc *Conn) FlushChain(c *Chain) {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")},
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")},
@ -173,11 +173,11 @@ func (cc *Conn) FlushChain(c *Chain) {
// ListChains returns currently configured chains in the kernel
func (cc *Conn) ListChains() ([]*Chain, error) {
conn, err := cc.dialNetlink()
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer conn.Close()
defer func() { _ = closer() }()
msg := netlink.Message{
Header: netlink.Header{

137
conn.go
View File

@ -32,18 +32,138 @@ import (
// Commands are buffered. Flush sends all buffered commands in a single batch.
type Conn struct {
TestDial nltest.Func // for testing only; passed to nltest.Dial
NetNS int // Network namespace netlink will interact with.
sync.Mutex
NetNS int // fd referencing the network namespace netlink will interact with.
lasting bool // establish a lasting connection to be used across multiple netlink operations.
mu sync.Mutex // protects the following state
messages []netlink.Message
err error
nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol.
}
// ConnOption is an option to change the behavior of the nftables Conn returned by Open.
type ConnOption func(*Conn)
// New returns a netlink connection for querying and modifying nftables. Some
// aspects of the new netlink connection can be configured using the options
// WithNetNSFd, WithTestDial, and AsLasting.
//
// A lasting netlink connection should be closed by calling CloseLasting() to
// close the underlying lasting netlink connection, cancelling all pending
// operations using this connection.
func New(opts ...ConnOption) (*Conn, error) {
cc := &Conn{}
for _, opt := range opts {
opt(cc)
}
if !cc.lasting {
return cc, nil
}
nlconn, err := cc.dialNetlink()
if err != nil {
return nil, err
}
cc.nlconn = nlconn
return cc, nil
}
// AsLasting creates the new netlink connection as a lasting connection that is
// reused across multiple netlink operations, instead of opening and closing the
// underlying netlink connection only for the duration of a single netlink
// operation.
func AsLasting() ConnOption {
return func(cc *Conn) {
// We cannot create the underlying connection yet, as we are called
// anywhere in the option processing chain and there might be later
// options still modifying connection behavior.
cc.lasting = true
}
}
// WithNetNSFd sets the network namespace to create a new netlink connection to:
// the fd must reference a network namespace.
func WithNetNSFd(fd int) ConnOption {
return func(cc *Conn) {
cc.NetNS = fd
}
}
// WithTestDial sets the specified nltest.Func when creating a new netlink
// connection.
func WithTestDial(f nltest.Func) ConnOption {
return func(cc *Conn) {
cc.TestDial = f
}
}
// netlinkCloser is returned by netlinkConn(UnderLock) and must be called after
// being done with the returned netlink connection in order to properly close
// this connection, if necessary.
type netlinkCloser func() error
// netlinkConn returns a netlink connection together with a netlinkCloser that
// later must be called by the caller when it doesn't need the returned netlink
// connection anymore. The netlinkCloser will close the netlink connection when
// necessary. If New has been told to create a lasting connection, then this
// lasting netlink connection will be returned, otherwise a new "transient"
// netlink connection will be opened and returned instead. netlinkConn must not
// be called while the Conn.mu lock is currently helt (this will cause a
// deadlock). Use netlinkConnUnderLock instead in such situations.
func (cc *Conn) netlinkConn() (*netlink.Conn, netlinkCloser, error) {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.netlinkConnUnderLock()
}
// netlinkConnUnderLock works like netlinkConn but must be called while holding
// the Conn.mu lock.
func (cc *Conn) netlinkConnUnderLock() (*netlink.Conn, netlinkCloser, error) {
if cc.nlconn != nil {
return cc.nlconn, func() error { return nil }, nil
}
nlconn, err := cc.dialNetlink()
if err != nil {
return nil, nil, err
}
return nlconn, func() error { return nlconn.Close() }, nil
}
// CloseLasting closes the lasting netlink connection that has been opened using
// AsLasting option when creating this connection. If either no lasting netlink
// connection has been opened or the lasting connection is already in the
// process of closing or has been closed, CloseLasting will immediately return
// without any error.
//
// CloseLasting will terminate all pending netlink operations using the lasting
// connection.
//
// After closing a lasting connection, the connection will revert to using
// on-demand transient netlink connections when calling further netlink
// operations (such as GetTables).
func (cc *Conn) CloseLasting() error {
// Don't acquire the lock for the whole duration of the CloseLasting
// operation, but instead only so long as to make sure to only run the
// netlink socket close on the first time with a lasting netlink socket. As
// there is only the New() constructor, but no Open() method, it's
// impossible to reopen a lasting connection.
cc.mu.Lock()
nlconn := cc.nlconn
cc.nlconn = nil
cc.mu.Unlock()
if nlconn != nil {
return nlconn.Close()
}
return nil
}
// Flush sends all buffered commands in a single batch to nftables.
func (cc *Conn) Flush() error {
cc.Lock()
cc.mu.Lock()
defer func() {
cc.messages = nil
cc.Unlock()
cc.mu.Unlock()
}()
if len(cc.messages) == 0 {
// Messages were already programmed, returning nil
@ -52,12 +172,11 @@ func (cc *Conn) Flush() error {
if cc.err != nil {
return cc.err // serialization error
}
conn, err := cc.dialNetlink()
conn, closer, err := cc.netlinkConnUnderLock()
if err != nil {
return err
}
defer conn.Close()
defer func() { _ = closer() }()
if _, err := conn.SendMessages(batch(cc.messages)); err != nil {
return fmt.Errorf("SendMessages: %w", err)
@ -73,8 +192,8 @@ func (cc *Conn) Flush() error {
// FlushRuleset flushes the entire ruleset. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level
func (cc *Conn) FlushRuleset() {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
cc.messages = append(cc.messages, netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),

View File

@ -16,6 +16,7 @@ package nftables_test
import (
"bytes"
"errors"
"flag"
"fmt"
"net"
@ -102,7 +103,11 @@ func openSystemNFTConn(t *testing.T) (*nftables.Conn, netns.NsHandle) {
if err != nil {
t.Fatalf("netns.New() failed: %v", err)
}
return &nftables.Conn{NetNS: int(ns)}, ns
c, err := nftables.New(nftables.WithNetNSFd(int(ns)))
if err != nil {
t.Fatalf("nftables.New() failed: %v", err)
}
return c, ns
}
func cleanupSystemNFTConn(t *testing.T, newNS netns.NsHandle) {
@ -288,8 +293,8 @@ func TestConfigureNAT(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -309,7 +314,9 @@ func TestConfigureNAT(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -502,8 +509,8 @@ func TestConfigureNATSourceAddress(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -523,7 +530,9 @@ func TestConfigureNATSourceAddress(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -802,8 +811,8 @@ func TestGetRules(t *testing.T) {
[]netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x9acb0443, PID: 0xba38ef3c}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}},
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -825,7 +834,9 @@ func TestGetRules(t *testing.T) {
rep := reply[0]
reply = reply[1:]
return rep, nil
},
}))
if err != nil {
t.Fatal(err)
}
rules, err := c.GetRules(
@ -881,8 +892,8 @@ func TestAddCounter(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -902,7 +913,9 @@ func TestAddCounter(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.AddObj(&nftables.CounterObj{
@ -945,8 +958,8 @@ func TestDeleteCounter(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -966,7 +979,9 @@ func TestDeleteCounter(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.AddObj(&nftables.CounterObj{
@ -996,8 +1011,8 @@ func TestDelRule(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -1017,7 +1032,9 @@ func TestDelRule(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.DelRule(&nftables.Rule{
@ -1041,8 +1058,8 @@ func TestLog(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -1062,7 +1079,9 @@ func TestLog(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.AddRule(&nftables.Rule{
@ -1091,8 +1110,8 @@ func TestTProxy(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -1112,7 +1131,9 @@ func TestTProxy(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.AddRule(&nftables.Rule{
@ -1154,8 +1175,8 @@ func TestCt(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -1175,7 +1196,9 @@ func TestCt(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.AddRule(&nftables.Rule{
@ -1207,8 +1230,8 @@ func TestCtSet(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -1228,7 +1251,9 @@ func TestCtSet(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.AddRule(&nftables.Rule{
@ -1265,8 +1290,8 @@ func TestCtStat(t *testing.T) {
// batch end
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -1286,7 +1311,9 @@ func TestCtStat(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.AddRule(&nftables.Rule{
@ -1323,8 +1350,8 @@ func TestAddRuleWithPosition(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -1344,7 +1371,9 @@ func TestAddRuleWithPosition(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.AddRule(&nftables.Rule{
@ -1387,6 +1416,73 @@ func TestAddRuleWithPosition(t *testing.T) {
}
}
func TestLastingConnection(t *testing.T) {
testdialerr := errors.New("test dial sentinel error")
dialCount := 0
c, err := nftables.New(
nftables.AsLasting(),
nftables.WithTestDial(func(req []netlink.Message) ([]netlink.Message, error) {
dialCount++
return nil, testdialerr
}))
if err != nil {
t.Errorf("creating lasting netlink connection failed %v", err)
return
}
defer func() {
if err := c.CloseLasting(); err != nil {
t.Errorf("closing lasting netlink connection failed %v", err)
}
}()
_, err = c.ListTables()
if !errors.Is(err, testdialerr) {
t.Errorf("non-testdialerr error returned from TestDial %v", err)
return
}
if dialCount != 1 {
t.Errorf("internal test error with TestDial invocations %v", dialCount)
return
}
// While a lasting netlink connection is open, replacing TestDial must be
// ineffective as there is no need to dial again and activating a new
// TestDial function. The newly set TestDial function must be getting
// ignored.
c.TestDial = func(req []netlink.Message) ([]netlink.Message, error) {
dialCount--
return nil, errors.New("transient netlink connection error")
}
_, err = c.ListTables()
if !errors.Is(err, testdialerr) {
t.Errorf("non-testdialerr error returned from TestDial %v", err)
return
}
if dialCount != 2 {
t.Errorf("internal test error with TestDial invocations %v", dialCount)
return
}
for i := 0; i < 2; i++ {
err = c.CloseLasting()
if err != nil {
t.Errorf("closing lasting netlink connection failed in attempt no. %d: %v", i, err)
return
}
}
_, err = c.ListTables()
if errors.Is(err, testdialerr) {
t.Error("testdialerr error returned from TestDial when expecting different error")
return
}
if dialCount != 1 {
t.Errorf("internal test error with TestDial invocations %v", dialCount)
return
}
// fall into defer'ed second CloseLasting which must not cause any errors.
}
func TestListChains(t *testing.T) {
polDrop := nftables.ChainPolicyDrop
polAcpt := nftables.ChainPolicyAccept
@ -1432,8 +1528,8 @@ func TestListChains(t *testing.T) {
},
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
msgReply := make([]netlink.Message, len(reply))
for i, r := range reply {
nm := &netlink.Message{}
@ -1443,7 +1539,9 @@ func TestListChains(t *testing.T) {
msgReply[i] = *nm
}
return msgReply, nil
},
}))
if err != nil {
t.Fatal(err)
}
chains, err := c.ListChains()
@ -1522,8 +1620,8 @@ func TestAddChain(t *testing.T) {
}
for _, tt := range tests {
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -1543,7 +1641,9 @@ func TestAddChain(t *testing.T) {
}
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
filter := c.AddTable(&nftables.Table{
@ -1599,8 +1699,8 @@ func TestDelChain(t *testing.T) {
}
for _, tt := range tests {
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -1620,7 +1720,9 @@ func TestDelChain(t *testing.T) {
}
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
tt.chain.Table = &nftables.Table{
@ -1649,8 +1751,8 @@ func TestGetObjReset(t *testing.T) {
[]netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x9acb0443, PID: 0xde9}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}},
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -1672,7 +1774,9 @@ func TestGetObjReset(t *testing.T) {
rep := reply[0]
reply = reply[1:]
return rep, nil
},
}))
if err != nil {
t.Fatal(err)
}
filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}
@ -1892,8 +1996,8 @@ func TestConfigureClamping(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -1913,7 +2017,9 @@ func TestConfigureClamping(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -2025,8 +2131,8 @@ func TestMatchPacketHeader(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -2046,7 +2152,9 @@ func TestMatchPacketHeader(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -2153,8 +2261,8 @@ func TestDropVerdict(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -2174,7 +2282,9 @@ func TestDropVerdict(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -2253,8 +2363,8 @@ func TestCreateUseAnonymousSet(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -2274,7 +2384,9 @@ func TestCreateUseAnonymousSet(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -3221,8 +3333,8 @@ func TestConfigureNATRedirect(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -3242,7 +3354,9 @@ func TestConfigureNATRedirect(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -3326,8 +3440,8 @@ func TestConfigureJumpVerdict(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -3347,7 +3461,9 @@ func TestConfigureJumpVerdict(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -3432,8 +3548,8 @@ func TestConfigureReturnVerdict(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -3453,7 +3569,9 @@ func TestConfigureReturnVerdict(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -3517,8 +3635,8 @@ func TestConfigureRangePort(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -3538,7 +3656,9 @@ func TestConfigureRangePort(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -3615,8 +3735,8 @@ func TestConfigureRangeIPv4(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -3636,7 +3756,9 @@ func TestConfigureRangeIPv4(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -3705,8 +3827,8 @@ func TestConfigureRangeIPv6(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -3726,7 +3848,9 @@ func TestConfigureRangeIPv6(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -3818,8 +3942,8 @@ func TestSet4(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -3839,7 +3963,9 @@ func TestSet4(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
tbl := &nftables.Table{
@ -4004,8 +4130,8 @@ func TestMasq(t *testing.T) {
}
for _, tt := range tests {
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -4025,7 +4151,9 @@ func TestMasq(t *testing.T) {
}
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
filter := c.AddTable(&nftables.Table{
@ -4135,8 +4263,8 @@ func TestReject(t *testing.T) {
}
for _, tt := range tests {
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -4156,7 +4284,9 @@ func TestReject(t *testing.T) {
}
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
filter := c.AddTable(&nftables.Table{
@ -4263,8 +4393,8 @@ func TestFib(t *testing.T) {
}
for _, tt := range tests {
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -4284,7 +4414,9 @@ func TestFib(t *testing.T) {
}
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
filter := c.AddTable(&nftables.Table{
@ -4367,8 +4499,8 @@ func TestNumgen(t *testing.T) {
}
for _, tt := range tests {
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -4388,7 +4520,9 @@ func TestNumgen(t *testing.T) {
}
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
filter := c.AddTable(&nftables.Table{
@ -4452,8 +4586,8 @@ func TestMap(t *testing.T) {
}
for _, tt := range tests {
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -4473,7 +4607,9 @@ func TestMap(t *testing.T) {
}
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
filter := c.AddTable(&nftables.Table{
@ -4570,8 +4706,8 @@ func TestVmap(t *testing.T) {
}
for _, tt := range tests {
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -4591,7 +4727,9 @@ func TestVmap(t *testing.T) {
}
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
filter := c.AddTable(&nftables.Table{
@ -4630,8 +4768,8 @@ func TestJHash(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -4651,7 +4789,9 @@ func TestJHash(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -4731,8 +4871,8 @@ func TestDup(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -4752,7 +4892,9 @@ func TestDup(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -4831,8 +4973,8 @@ func TestDupWoDev(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -4852,7 +4994,9 @@ func TestDupWoDev(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -4913,8 +5057,8 @@ func TestNotrack(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -4934,7 +5078,9 @@ func TestNotrack(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -4983,8 +5129,8 @@ func TestQuota(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -5004,7 +5150,9 @@ func TestQuota(t *testing.T) {
}
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()
@ -5058,8 +5206,8 @@ func TestStatelessNAT(t *testing.T) {
[]byte("\x00\x00\x00\x0a"),
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
c, err := nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
@ -5079,7 +5227,9 @@ func TestStatelessNAT(t *testing.T) {
want = want[1:]
}
return req, nil
},
}))
if err != nil {
t.Fatal(err)
}
c.FlushRuleset()

12
obj.go
View File

@ -41,8 +41,8 @@ func (cc *Conn) AddObject(o Obj) Obj {
// AddObj adds the specified Obj. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects
func (cc *Conn) AddObj(o Obj) Obj {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
data, err := o.marshal(true)
if err != nil {
cc.setErr(err)
@ -61,8 +61,8 @@ func (cc *Conn) AddObj(o Obj) Obj {
// DeleteObject deletes the specified Obj
func (cc *Conn) DeleteObject(o Obj) {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
data, err := o.marshal(false)
if err != nil {
cc.setErr(err)
@ -174,11 +174,11 @@ func objFromMsg(msg netlink.Message) (Obj, error) {
}
func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) {
conn, err := cc.dialNetlink()
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer conn.Close()
defer func() { _ = closer() }()
var data []byte
var flags netlink.HeaderFlags

12
rule.go
View File

@ -60,11 +60,11 @@ func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) {
// GetRules returns the rules in the specified table and chain.
func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
conn, err := cc.dialNetlink()
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer conn.Close()
defer func() { _ = closer() }()
data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
@ -107,8 +107,8 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
// AddRule adds the specified Rule
func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
exprAttrs := make([]netlink.Attribute, len(r.Exprs))
for idx, expr := range r.Exprs {
exprAttrs[idx] = netlink.Attribute{
@ -190,8 +190,8 @@ func (cc *Conn) InsertRule(r *Rule) *Rule {
// DelRule deletes the specified Rule, rule's handle cannot be 0
func (cc *Conn) DelRule(r *Rule) error {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_TABLE, Data: []byte(r.Table.Name + "\x00")},
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(r.Chain.Name + "\x00")},

32
set.go
View File

@ -318,8 +318,8 @@ func decodeElement(d []byte) ([]byte, error) {
// SetAddElements applies data points to an nftables set.
func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
if s.Anonymous {
return errors.New("anonymous sets cannot be updated")
}
@ -431,8 +431,8 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
// AddSet adds the specified Set.
func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
// Based on nft implementation & linux source.
// Link: https://github.com/torvalds/linux/blob/49a57857aeea06ca831043acbb0fa5e0f50602fd/net/netfilter/nf_tables_api.c#L3395
// Another reference: https://git.netfilter.org/nftables/tree/src
@ -567,8 +567,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
// DelSet deletes a specific set, along with all elements it contains.
func (cc *Conn) DelSet(s *Set) {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
@ -584,8 +584,8 @@ func (cc *Conn) DelSet(s *Set) {
// SetDeleteElements deletes data points from an nftables set.
func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
if s.Anonymous {
return errors.New("anonymous sets cannot be updated")
}
@ -607,8 +607,8 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
// FlushSet deletes all data points from an nftables set.
func (cc *Conn) FlushSet(s *Set) {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
@ -747,11 +747,11 @@ func elementsFromMsg(msg netlink.Message) ([]SetElement, error) {
// GetSets returns the sets in the specified table.
func (cc *Conn) GetSets(t *Table) ([]*Set, error) {
conn, err := cc.dialNetlink()
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer conn.Close()
defer func() { _ = closer() }()
data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_TABLE, Data: []byte(t.Name + "\x00")},
@ -791,11 +791,11 @@ func (cc *Conn) GetSets(t *Table) ([]*Set, error) {
// GetSetByName returns the set in the specified table if matching name is found.
func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) {
conn, err := cc.dialNetlink()
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer conn.Close()
defer func() { _ = closer() }()
data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_TABLE, Data: []byte(t.Name + "\x00")},
@ -836,11 +836,11 @@ func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) {
// GetSetElements returns the elements in the specified set.
func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) {
conn, err := cc.dialNetlink()
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer conn.Close()
defer func() { _ = closer() }()
data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},

View File

@ -47,8 +47,8 @@ type Table struct {
// DelTable deletes a specific table, along with all chains/rules it contains.
func (cc *Conn) DelTable(t *Table) {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
@ -65,8 +65,8 @@ func (cc *Conn) DelTable(t *Table) {
// AddTable adds the specified Table. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables
func (cc *Conn) AddTable(t *Table) *Table {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
@ -84,8 +84,8 @@ func (cc *Conn) AddTable(t *Table) *Table {
// FlushTable removes all rules in all chains within the specified Table. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables#Flushing_tables
func (cc *Conn) FlushTable(t *Table) {
cc.Lock()
defer cc.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
})
@ -100,12 +100,11 @@ func (cc *Conn) FlushTable(t *Table) {
// ListTables returns currently configured tables in the kernel
func (cc *Conn) ListTables() ([]*Table, error) {
conn, err := cc.dialNetlink()
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer conn.Close()
defer func() { _ = closer() }()
msg := netlink.Message{
Header: netlink.Header{