268 lines
8.0 KiB
Go
268 lines
8.0 KiB
Go
// Copyright 2018 Google LLC. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package nftables
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/google/nftables/expr"
|
|
"github.com/mdlayher/netlink"
|
|
"github.com/mdlayher/netlink/nltest"
|
|
"golang.org/x/sys/unix"
|
|
)
|
|
|
|
// A Conn represents a netlink connection of the nftables family.
|
|
//
|
|
// All methods return their input, so that variables can be defined from string
|
|
// literals when desired.
|
|
//
|
|
// Commands are buffered. Flush sends all buffered commands in a single batch.
|
|
type Conn struct {
|
|
TestDial nltest.Func // for testing only; passed to nltest.Dial
|
|
NetNS int // fd referencing the network namespace netlink will interact with.
|
|
|
|
lasting bool // establish a lasting connection to be used across multiple netlink operations.
|
|
mu sync.Mutex // protects the following state
|
|
messages []netlink.Message
|
|
err error
|
|
nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol.
|
|
}
|
|
|
|
// ConnOption is an option to change the behavior of the nftables Conn returned by Open.
|
|
type ConnOption func(*Conn)
|
|
|
|
// New returns a netlink connection for querying and modifying nftables. Some
|
|
// aspects of the new netlink connection can be configured using the options
|
|
// WithNetNSFd, WithTestDial, and AsLasting.
|
|
//
|
|
// A lasting netlink connection should be closed by calling CloseLasting() to
|
|
// close the underlying lasting netlink connection, cancelling all pending
|
|
// operations using this connection.
|
|
func New(opts ...ConnOption) (*Conn, error) {
|
|
cc := &Conn{}
|
|
for _, opt := range opts {
|
|
opt(cc)
|
|
}
|
|
|
|
if !cc.lasting {
|
|
return cc, nil
|
|
}
|
|
|
|
nlconn, err := cc.dialNetlink()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cc.nlconn = nlconn
|
|
return cc, nil
|
|
}
|
|
|
|
// AsLasting creates the new netlink connection as a lasting connection that is
|
|
// reused across multiple netlink operations, instead of opening and closing the
|
|
// underlying netlink connection only for the duration of a single netlink
|
|
// operation.
|
|
func AsLasting() ConnOption {
|
|
return func(cc *Conn) {
|
|
// We cannot create the underlying connection yet, as we are called
|
|
// anywhere in the option processing chain and there might be later
|
|
// options still modifying connection behavior.
|
|
cc.lasting = true
|
|
}
|
|
}
|
|
|
|
// WithNetNSFd sets the network namespace to create a new netlink connection to:
|
|
// the fd must reference a network namespace.
|
|
func WithNetNSFd(fd int) ConnOption {
|
|
return func(cc *Conn) {
|
|
cc.NetNS = fd
|
|
}
|
|
}
|
|
|
|
// WithTestDial sets the specified nltest.Func when creating a new netlink
|
|
// connection.
|
|
func WithTestDial(f nltest.Func) ConnOption {
|
|
return func(cc *Conn) {
|
|
cc.TestDial = f
|
|
}
|
|
}
|
|
|
|
// netlinkCloser is the cleanup function handed out by netlinkConn and
// netlinkConnUnderLock alongside a netlink connection. The caller must invoke
// it once it is done with that connection, so that the connection gets closed
// if that is necessary.
type netlinkCloser func() error
|
|
|
|
// netlinkConn returns a netlink connection together with a netlinkCloser that
|
|
// later must be called by the caller when it doesn't need the returned netlink
|
|
// connection anymore. The netlinkCloser will close the netlink connection when
|
|
// necessary. If New has been told to create a lasting connection, then this
|
|
// lasting netlink connection will be returned, otherwise a new "transient"
|
|
// netlink connection will be opened and returned instead. netlinkConn must not
|
|
// be called while the Conn.mu lock is currently helt (this will cause a
|
|
// deadlock). Use netlinkConnUnderLock instead in such situations.
|
|
func (cc *Conn) netlinkConn() (*netlink.Conn, netlinkCloser, error) {
|
|
cc.mu.Lock()
|
|
defer cc.mu.Unlock()
|
|
return cc.netlinkConnUnderLock()
|
|
}
|
|
|
|
// netlinkConnUnderLock works like netlinkConn but must be called while holding
|
|
// the Conn.mu lock.
|
|
func (cc *Conn) netlinkConnUnderLock() (*netlink.Conn, netlinkCloser, error) {
|
|
if cc.nlconn != nil {
|
|
return cc.nlconn, func() error { return nil }, nil
|
|
}
|
|
nlconn, err := cc.dialNetlink()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return nlconn, func() error { return nlconn.Close() }, nil
|
|
}
|
|
|
|
// CloseLasting closes the lasting netlink connection that has been opened using
|
|
// AsLasting option when creating this connection. If either no lasting netlink
|
|
// connection has been opened or the lasting connection is already in the
|
|
// process of closing or has been closed, CloseLasting will immediately return
|
|
// without any error.
|
|
//
|
|
// CloseLasting will terminate all pending netlink operations using the lasting
|
|
// connection.
|
|
//
|
|
// After closing a lasting connection, the connection will revert to using
|
|
// on-demand transient netlink connections when calling further netlink
|
|
// operations (such as GetTables).
|
|
func (cc *Conn) CloseLasting() error {
|
|
// Don't acquire the lock for the whole duration of the CloseLasting
|
|
// operation, but instead only so long as to make sure to only run the
|
|
// netlink socket close on the first time with a lasting netlink socket. As
|
|
// there is only the New() constructor, but no Open() method, it's
|
|
// impossible to reopen a lasting connection.
|
|
cc.mu.Lock()
|
|
nlconn := cc.nlconn
|
|
cc.nlconn = nil
|
|
cc.mu.Unlock()
|
|
if nlconn != nil {
|
|
return nlconn.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Flush sends all buffered commands in a single batch to nftables.
|
|
func (cc *Conn) Flush() error {
|
|
cc.mu.Lock()
|
|
defer func() {
|
|
cc.messages = nil
|
|
cc.mu.Unlock()
|
|
}()
|
|
if len(cc.messages) == 0 {
|
|
// Messages were already programmed, returning nil
|
|
return nil
|
|
}
|
|
if cc.err != nil {
|
|
return cc.err // serialization error
|
|
}
|
|
conn, closer, err := cc.netlinkConnUnderLock()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = closer() }()
|
|
|
|
if _, err := conn.SendMessages(batch(cc.messages)); err != nil {
|
|
return fmt.Errorf("SendMessages: %w", err)
|
|
}
|
|
|
|
// Fetch the requested acknowledgement for each message we sent.
|
|
for _, msg := range cc.messages {
|
|
if msg.Header.Flags&netlink.Acknowledge == 0 {
|
|
continue // message did not request an acknowledgement
|
|
}
|
|
if _, err := conn.Receive(); err != nil {
|
|
return fmt.Errorf("conn.Receive: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// FlushRuleset flushes the entire ruleset. See also
|
|
// https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level
|
|
func (cc *Conn) FlushRuleset() {
|
|
cc.mu.Lock()
|
|
defer cc.mu.Unlock()
|
|
cc.messages = append(cc.messages, netlink.Message{
|
|
Header: netlink.Header{
|
|
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
|
|
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
|
|
},
|
|
Data: extraHeader(0, 0),
|
|
})
|
|
}
|
|
|
|
func (cc *Conn) dialNetlink() (*netlink.Conn, error) {
|
|
if cc.TestDial != nil {
|
|
return nltest.Dial(cc.TestDial), nil
|
|
}
|
|
|
|
return netlink.Dial(unix.NETLINK_NETFILTER, &netlink.Config{NetNS: cc.NetNS})
|
|
}
|
|
|
|
func (cc *Conn) setErr(err error) {
|
|
if cc.err != nil {
|
|
return
|
|
}
|
|
cc.err = err
|
|
}
|
|
|
|
func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
|
|
b, err := netlink.MarshalAttributes(attrs)
|
|
if err != nil {
|
|
cc.setErr(err)
|
|
return nil
|
|
}
|
|
return b
|
|
}
|
|
|
|
func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
|
|
b, err := expr.Marshal(fam, e)
|
|
if err != nil {
|
|
cc.setErr(err)
|
|
return nil
|
|
}
|
|
return b
|
|
}
|
|
|
|
func batch(messages []netlink.Message) []netlink.Message {
|
|
batch := []netlink.Message{
|
|
{
|
|
Header: netlink.Header{
|
|
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
|
|
Flags: netlink.Request,
|
|
},
|
|
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
|
|
},
|
|
}
|
|
|
|
batch = append(batch, messages...)
|
|
|
|
batch = append(batch, netlink.Message{
|
|
Header: netlink.Header{
|
|
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
|
|
Flags: netlink.Request,
|
|
},
|
|
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
|
|
})
|
|
|
|
return batch
|
|
}
|