221 lines
4.9 KiB
Go
221 lines
4.9 KiB
Go
// Copyright 2018 Google LLC. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package nftables
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/google/nftables/expr"
|
|
"github.com/mdlayher/netlink"
|
|
"github.com/mdlayher/netlink/nltest"
|
|
"golang.org/x/sys/unix"
|
|
)
|
|
|
|
type Entity interface {
|
|
HandleResponse(netlink.Message)
|
|
}
|
|
|
|
// A Conn represents a netlink connection of the nftables family.
|
|
//
|
|
// All methods return their input, so that variables can be defined from string
|
|
// literals when desired.
|
|
//
|
|
// Commands are buffered. Flush sends all buffered commands in a single batch.
|
|
type Conn struct {
|
|
TestDial nltest.Func // for testing only; passed to nltest.Dial
|
|
NetNS int // Network namespace netlink will interact with.
|
|
sync.Mutex
|
|
messages []netlink.Message
|
|
entities map[int32]Entity
|
|
it int32
|
|
err error
|
|
}
|
|
|
|
// Flush sends all buffered commands in a single batch to nftables.
|
|
func (cc *Conn) Flush() error {
|
|
cc.Lock()
|
|
defer func() {
|
|
cc.messages = nil
|
|
cc.entities = nil
|
|
cc.it = 0
|
|
cc.Unlock()
|
|
}()
|
|
if len(cc.messages) == 0 {
|
|
// Messages were already programmed, returning nil
|
|
return nil
|
|
}
|
|
if cc.err != nil {
|
|
return cc.err // serialization error
|
|
}
|
|
conn, err := cc.dialNetlink()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
defer conn.Close()
|
|
|
|
cc.endBatch(cc.messages)
|
|
|
|
_, err = conn.SendMessages(cc.messages[:cc.it+1])
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("SendMessages: %w", err)
|
|
}
|
|
|
|
// Retrieving of seq number associated to entities
|
|
entitiesBySeq := make(map[uint32]Entity)
|
|
for i, e := range cc.entities {
|
|
entitiesBySeq[cc.messages[i].Header.Sequence] = e
|
|
}
|
|
|
|
// Trigger entities callback
|
|
msg, err := cc.checkReceive(conn)
|
|
for msg {
|
|
rmsg, err := conn.Receive()
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("Receive: %w", err)
|
|
}
|
|
|
|
for _, msg := range rmsg {
|
|
if e, ok := entitiesBySeq[msg.Header.Sequence]; ok {
|
|
e.HandleResponse(msg)
|
|
}
|
|
}
|
|
msg, err = cc.checkReceive(conn)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// PutMessage store netlink message to sent after
|
|
func (cc *Conn) PutMessage(msg netlink.Message) int32 {
|
|
if cc.messages == nil {
|
|
cc.messages = make([]netlink.Message, 128)
|
|
cc.messages = append(cc.messages, netlink.Message{})
|
|
cc.messages[0] = netlink.Message{
|
|
Header: netlink.Header{
|
|
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
|
|
Flags: netlink.Request,
|
|
},
|
|
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
|
|
}
|
|
}
|
|
|
|
i := atomic.AddInt32(&cc.it, 1)
|
|
|
|
cc.messages = append(cc.messages, netlink.Message{})
|
|
cc.messages[i] = msg
|
|
|
|
return i
|
|
}
|
|
|
|
// PutEntity store entity to relate to netlink response
|
|
func (cc *Conn) PutEntity(i int32, e Entity) {
|
|
if cc.entities == nil {
|
|
cc.entities = make(map[int32]Entity)
|
|
}
|
|
cc.entities[i] = e
|
|
}
|
|
|
|
func (cc *Conn) checkReceive(c *netlink.Conn) (bool, error) {
|
|
if cc.TestDial != nil {
|
|
return false, nil
|
|
}
|
|
|
|
sc, err := c.SyscallConn()
|
|
|
|
if err != nil {
|
|
return false, fmt.Errorf("SyscallConn error: %w", err)
|
|
}
|
|
|
|
var n int
|
|
|
|
sc.Control(func(fd uintptr) {
|
|
var fdSet unix.FdSet
|
|
fdSet.Zero()
|
|
fdSet.Set(int(fd))
|
|
|
|
n, err = unix.Select(int(fd)+1, &fdSet, nil, nil, &unix.Timeval{})
|
|
})
|
|
|
|
if err == nil && n > 0 {
|
|
return true, nil
|
|
}
|
|
|
|
return false, err
|
|
}
|
|
|
|
// FlushRuleset flushes the entire ruleset. See also
|
|
// https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level
|
|
func (cc *Conn) FlushRuleset() {
|
|
cc.Lock()
|
|
defer cc.Unlock()
|
|
cc.PutMessage(netlink.Message{
|
|
Header: netlink.Header{
|
|
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
|
|
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
|
|
},
|
|
Data: extraHeader(0, 0),
|
|
})
|
|
}
|
|
|
|
func (cc *Conn) dialNetlink() (*netlink.Conn, error) {
|
|
if cc.TestDial != nil {
|
|
return nltest.Dial(cc.TestDial), nil
|
|
}
|
|
return netlink.Dial(unix.NETLINK_NETFILTER, &netlink.Config{NetNS: cc.NetNS})
|
|
}
|
|
|
|
func (cc *Conn) setErr(err error) {
|
|
if cc.err != nil {
|
|
return
|
|
}
|
|
cc.err = err
|
|
}
|
|
|
|
func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
|
|
b, err := netlink.MarshalAttributes(attrs)
|
|
if err != nil {
|
|
cc.setErr(err)
|
|
return nil
|
|
}
|
|
return b
|
|
}
|
|
|
|
func (cc *Conn) marshalExpr(e expr.Any) []byte {
|
|
b, err := expr.Marshal(e)
|
|
if err != nil {
|
|
cc.setErr(err)
|
|
return nil
|
|
}
|
|
return b
|
|
}
|
|
|
|
func (cc *Conn) endBatch(messages []netlink.Message) {
|
|
|
|
i := atomic.AddInt32(&cc.it, 1)
|
|
|
|
cc.messages[i] = netlink.Message{
|
|
Header: netlink.Header{
|
|
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
|
|
Flags: netlink.Request,
|
|
},
|
|
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
|
|
}
|
|
}
|