Compare commits

..

7 Commits

Author SHA1 Message Date
Jan Schär c8100b04e6 Deprecate Rule.Flags field
The functionality added in a46119e5 never worked: If you set
NFTA_RULE_POSITION to 0, the kernel will just complain that a rule with
this handle does not exist. This removes the broken functionality,
leaving the field deprecated.

The right way to insert a rule at the beginning of a chain is to use
InsertRule and leave Position unset.

https://github.com/google/nftables/issues/126 mentions that the nft
command allows referring to rules by index. But here is a quote from the
nft manpage:

> The add and insert commands support an optional location specifier,
> which is either a handle or the index (starting at zero) of an
> existing rule. Internally, rule locations are always identified by
> handle and the translation from index happens in userspace.

In other words, identifiying rules by index is a feature of nft and is
not part of the kernel interface.
2025-03-26 08:53:15 +00:00
Jan Schär 207a46354c
Set rule handle during flush (#299)
This change makes it possible to delete rules after inserting them,
without needing to query the rules first. Additionally, this allows
positioning a new rule next to an existing rule.

There are two ways to refer to a rule: Either by ID or by handle. The ID
is assigned by userspace, and is only valid within a transaction, so it
can only be used before the flush. The handle is assigned by the kernel
when the transaction is committed, and can thus only be used after the
flush. We thus need to set an ID on each newly created rule, and
retrieve the handle of the rule during the flush.

I extended the message struct with a pointer to the Rule which the
message creates. This allows calling the reply handler callback which
sets the handle.

I updated tests to add a handle to generated replies for the
NFT_MSG_NEWRULE messages.
2025-03-26 09:24:33 +01:00
Jan Schär 9a2862f48b
Receive replies in Flush (#309)
Commit 0d9bfa4d18 added code to handle "overrun", but the commit is
very misleading. NLMSG_OVERRUN is in fact not a flag, but a complete
message type, so the (re&netlink.Overrun) masking makes no sense. Even
better, NLMSG_OVERRUN is never actually used by Linux.

The actual bug which the commit was attempting to fix is that Flush was
not receiving replies which the kernel sent for messages with the echo
flag. This change reverts that commit and instead adds code in Flush to
receive the replies.

I updated tests which simulate the kernel to generate replies.
2025-03-25 17:03:44 +01:00
Jan Schär d11ef81b6a
Add ID to rule (#308)
The ID allows referring to a rule before it is committed, as
demonstrated in the newly added test.

I had to update all existing tests which compared generated netlink
messages against a reference, by inserting the newly added ID attribute.
2025-03-18 09:44:35 +01:00
Jan Schär e2fedeb355
Improve safety of ID allocation (#307)
There was an existing mechanism to allocate IDs for sets, but this was
using a global counter without any synchronization to prevent data
races. I replaced this by a new mechanism which uses a connection-scoped
counter, protected by the Conn.mu Mutex. This can then also be used in
other places where IDs need to be allocated.

As an additional safeguard, it will panic instead of allocating the same
ID twice in a transaction. Most likely, your program will run out of
memory before reaching this point.
2025-03-13 10:38:46 +01:00
Michael Stapelberg a24f918d08 go.{mod,sum}: update to latest x/ packages 2025-03-13 09:42:41 +01:00
Michael Stapelberg 3163cd89a9 go.mod: bump language version to go1.23
Our dependencies like golang.org/x/net use go1.23 (the oldest still-supported
version, latest is go1.24), so it is time for us to upgrade, too.
2025-03-13 09:41:52 +01:00
12 changed files with 459 additions and 966 deletions

View File

@ -140,7 +140,7 @@ func (cc *Conn) AddChain(c *Chain) *Chain {
{Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")}, {Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")},
})...) })...)
} }
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create, Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -161,7 +161,7 @@ func (cc *Conn) DelChain(c *Chain) {
{Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")},
}) })
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN),
Flags: netlink.Request | netlink.Acknowledge, Flags: netlink.Request | netlink.Acknowledge,
@ -179,7 +179,7 @@ func (cc *Conn) FlushChain(c *Chain) {
{Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")},
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")},
}) })
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE),
Flags: netlink.Request | netlink.Acknowledge, Flags: netlink.Request | netlink.Acknowledge,

138
conn.go
View File

@ -17,6 +17,7 @@ package nftables
import ( import (
"errors" "errors"
"fmt" "fmt"
"math"
"os" "os"
"sync" "sync"
"syscall" "syscall"
@ -38,12 +39,20 @@ type Conn struct {
TestDial nltest.Func // for testing only; passed to nltest.Dial TestDial nltest.Func // for testing only; passed to nltest.Dial
NetNS int // fd referencing the network namespace netlink will interact with. NetNS int // fd referencing the network namespace netlink will interact with.
lasting bool // establish a lasting connection to be used across multiple netlink operations. lasting bool // establish a lasting connection to be used across multiple netlink operations.
mu sync.Mutex // protects the following state mu sync.Mutex // protects the following state
messages []netlink.Message messages []netlinkMessage
err error err error
nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol.
sockOptions []SockOption sockOptions []SockOption
lastID uint32
allocatedIDs uint32
}
type netlinkMessage struct {
Header netlink.Header
Data []byte
rule *Rule
} }
// ConnOption is an option to change the behavior of the nftables Conn returned by Open. // ConnOption is an option to change the behavior of the nftables Conn returned by Open.
@ -168,24 +177,6 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]
return reply, nil return reply, nil
} }
if len(reply) != 0 {
last := reply[len(reply)-1]
for re := last.Header.Type; (re&netlink.Overrun) == netlink.Overrun && (re&netlink.Done) != netlink.Done; re = last.Header.Type {
// we are not finished, the message is overrun
r, err := nlconn.Receive()
if err != nil {
return nil, err
}
reply = append(reply, r...)
last = reply[len(reply)-1]
}
if last.Header.Type == netlink.Error && binaryutil.BigEndian.Uint32(last.Data[:4]) == 0 {
// we have already collected an ack
return reply, nil
}
}
// Now we expect an ack // Now we expect an ack
ack, err := nlconn.Receive() ack, err := nlconn.Receive()
if err != nil { if err != nil {
@ -193,8 +184,7 @@ func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]
} }
if len(ack) == 0 { if len(ack) == 0 {
// received an empty ack? return nil, errors.New("received an empty ack")
return reply, nil
} }
msg := ack[0] msg := ack[0]
@ -244,6 +234,7 @@ func (cc *Conn) Flush() error {
cc.mu.Lock() cc.mu.Lock()
defer func() { defer func() {
cc.messages = nil cc.messages = nil
cc.allocatedIDs = 0
cc.mu.Unlock() cc.mu.Unlock()
}() }()
if len(cc.messages) == 0 { if len(cc.messages) == 0 {
@ -259,15 +250,54 @@ func (cc *Conn) Flush() error {
} }
defer func() { _ = closer() }() defer func() { _ = closer() }()
if _, err := conn.SendMessages(batch(cc.messages)); err != nil { messages, err := conn.SendMessages(batch(cc.messages))
if err != nil {
return fmt.Errorf("SendMessages: %w", err) return fmt.Errorf("SendMessages: %w", err)
} }
var errs error var errs error
// Fetch replies. Each message with the Echo flag triggers a reply of the same
// type. Additionally, if the first message of the batch has the Echo flag, we
// get a reply of type NFT_MSG_NEWGEN, which we ignore.
replyIndex := 0
for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 {
replyIndex++
}
replies, err := conn.Receive()
for err == nil && len(replies) != 0 {
reply := replies[0]
if reply.Header.Type == netlink.Error && reply.Header.Sequence == messages[1].Header.Sequence {
// The next message is the acknowledgement for the first message in the
// batch; stop looking for replies.
break
} else if replyIndex < len(cc.messages) {
msg := messages[replyIndex+1]
if msg.Header.Sequence == reply.Header.Sequence && msg.Header.Type == reply.Header.Type {
// The only messages which set the echo flag are rule create messages.
err := cc.messages[replyIndex].rule.handleCreateReply(reply)
if err != nil {
errs = errors.Join(errs, err)
}
replyIndex++
for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 {
replyIndex++
}
}
}
replies = replies[1:]
if len(replies) == 0 {
replies, err = conn.Receive()
}
}
// Fetch the requested acknowledgement for each message we sent. // Fetch the requested acknowledgement for each message we sent.
for _, msg := range cc.messages { for i := range cc.messages {
if _, err := receiveAckAware(conn, msg.Header.Flags); err != nil { if i != 0 {
if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) { _, err = conn.Receive()
}
if err != nil {
if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) || errors.Is(err, syscall.ENOMEM) {
// Kernel will only send one error to user space. // Kernel will only send one error to user space.
return err return err
} }
@ -278,6 +308,9 @@ func (cc *Conn) Flush() error {
if errs != nil { if errs != nil {
return fmt.Errorf("conn.Receive: %w", errs) return fmt.Errorf("conn.Receive: %w", errs)
} }
if replyIndex < len(cc.messages) {
return fmt.Errorf("missing reply for message %d in batch", replyIndex)
}
return nil return nil
} }
@ -287,7 +320,7 @@ func (cc *Conn) Flush() error {
func (cc *Conn) FlushRuleset() { func (cc *Conn) FlushRuleset() {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create, Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -346,26 +379,47 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
return b return b
} }
func batch(messages []netlink.Message) []netlink.Message { func batch(messages []netlinkMessage) []netlink.Message {
batch := []netlink.Message{ batch := make([]netlink.Message, len(messages)+2)
{ batch[0] = netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
Flags: netlink.Request, Flags: netlink.Request,
},
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
}, },
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
} }
batch = append(batch, messages...) for i, msg := range messages {
batch[i+1] = netlink.Message{
Header: msg.Header,
Data: msg.Data,
}
}
batch = append(batch, netlink.Message{ batch[len(messages)+1] = netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
Flags: netlink.Request, Flags: netlink.Request,
}, },
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
}) }
return batch return batch
} }
// allocateTransactionID allocates an identifier which is only valid in the
// current transaction.
func (cc *Conn) allocateTransactionID() uint32 {
if cc.allocatedIDs == math.MaxUint32 {
panic(fmt.Sprintf("trying to allocate more than %d IDs in a single nftables transaction", math.MaxUint32))
}
// To make it more likely to catch when a transaction ID is erroneously used
// in a later transaction, cc.lastID is not reset after each transaction;
// instead it is only reset once it rolls over from math.MaxUint32 to 0.
cc.allocatedIDs++
cc.lastID++
if cc.lastID == 0 {
cc.lastID = 1
}
return cc.lastID
}

View File

@ -142,7 +142,7 @@ func (cc *Conn) AddFlowtable(f *Flowtable) *Flowtable {
{Type: unix.NLA_F_NESTED | NFTA_FLOWTABLE_HOOK, Data: cc.marshalAttr(hookAttr)}, {Type: unix.NLA_F_NESTED | NFTA_FLOWTABLE_HOOK, Data: cc.marshalAttr(hookAttr)},
})...) })...)
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_NEWFLOWTABLE), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_NEWFLOWTABLE),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create, Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -162,7 +162,7 @@ func (cc *Conn) DelFlowtable(f *Flowtable) {
{Type: NFTA_FLOWTABLE_NAME, Data: []byte(f.Name)}, {Type: NFTA_FLOWTABLE_NAME, Data: []byte(f.Name)},
}) })
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_DELFLOWTABLE), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_DELFLOWTABLE),
Flags: netlink.Request | netlink.Acknowledge, Flags: netlink.Request | netlink.Acknowledge,

6
go.mod
View File

@ -1,17 +1,17 @@
module github.com/google/nftables module github.com/google/nftables
go 1.21 go 1.23.0
require ( require (
github.com/google/go-cmp v0.6.0 github.com/google/go-cmp v0.6.0
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42
github.com/vishvananda/netlink v1.3.0 github.com/vishvananda/netlink v1.3.0
github.com/vishvananda/netns v0.0.4 github.com/vishvananda/netns v0.0.4
golang.org/x/sys v0.28.0 golang.org/x/sys v0.31.0
) )
require ( require (
github.com/mdlayher/socket v0.5.0 // indirect github.com/mdlayher/socket v0.5.0 // indirect
golang.org/x/net v0.33.0 // indirect golang.org/x/net v0.37.0 // indirect
golang.org/x/sync v0.6.0 // indirect golang.org/x/sync v0.6.0 // indirect
) )

8
go.sum
View File

@ -8,11 +8,11 @@ github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQ
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=

View File

@ -8,7 +8,9 @@ import (
"testing" "testing"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink" "github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
) )
// Recorder provides an nftables connection that does not send to the Linux // Recorder provides an nftables connection that does not send to the Linux
@ -21,14 +23,34 @@ type Recorder struct {
// Conn opens an nftables connection that records netlink messages into the // Conn opens an nftables connection that records netlink messages into the
// Recorder. // Recorder.
func (r *Recorder) Conn() (*nftables.Conn, error) { func (r *Recorder) Conn() (*nftables.Conn, error) {
nextHandle := uint64(1)
return nftables.New(nftables.WithTestDial( return nftables.New(nftables.WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) { func(req []netlink.Message) ([]netlink.Message, error) {
r.requests = append(r.requests, req...) r.requests = append(r.requests, req...)
acks := make([]netlink.Message, 0, len(req)) replies := make([]netlink.Message, 0, len(req))
// Generate replies.
for _, msg := range req {
if msg.Header.Flags&netlink.Echo != 0 {
data := append([]byte{}, msg.Data...)
switch msg.Header.Type {
case netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE):
attrs, _ := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(nextHandle)},
})
nextHandle++
data = append(data, attrs...)
}
replies = append(replies, netlink.Message{
Header: msg.Header,
Data: data,
})
}
}
// Generate acknowledgements.
for _, msg := range req { for _, msg := range req {
if msg.Header.Flags&netlink.Acknowledge != 0 { if msg.Header.Flags&netlink.Acknowledge != 0 {
acks = append(acks, netlink.Message{ replies = append(replies, netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Length: 4, Length: 4,
Type: netlink.Error, Type: netlink.Error,
@ -39,7 +61,7 @@ func (r *Recorder) Conn() (*nftables.Conn, error) {
}) })
} }
} }
return acks, nil return replies, nil
})) }))
} }

File diff suppressed because it is too large Load Diff

4
obj.go
View File

@ -124,7 +124,7 @@ func (cc *Conn) AddObj(o Obj) Obj {
attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: data}) attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: data})
} }
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create, Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -146,7 +146,7 @@ func (cc *Conn) DeleteObject(o Obj) {
data := cc.marshalAttr(attrs) data := cc.marshalAttr(attrs)
data = append(data, cc.marshalAttr([]netlink.Attribute{{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA}})...) data = append(data, cc.marshalAttr([]netlink.Attribute{{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA}})...)
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ),
Flags: netlink.Request | netlink.Acknowledge, Flags: netlink.Request | netlink.Acknowledge,

96
rule.go
View File

@ -30,6 +30,10 @@ const (
delRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE) delRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE)
) )
// This constant is missing at unix.NFTA_RULE_POSITION_ID.
// TODO: Add the constant in unix and then remove it here.
const nfta_rule_position_id = 0xa
type ruleOperation uint32 type ruleOperation uint32
// Possible PayloadOperationType values. // Possible PayloadOperationType values.
@ -42,10 +46,22 @@ const (
// A Rule does something with a packet. See also // A Rule does something with a packet. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Simple_rule_management // https://wiki.nftables.org/wiki-nftables/index.php/Simple_rule_management
type Rule struct { type Rule struct {
Table *Table Table *Table
Chain *Chain Chain *Chain
// Handle identifies an existing Rule. For a new Rule, this field is set
// during the Flush() in which the rule is committed. Make sure to not access
// this field concurrently with this Flush() to avoid data races.
Handle uint64
// ID is an identifier for a new Rule, which is assigned by
// AddRule/InsertRule, and only valid before the rule is committed by Flush().
// The field is set to 0 during Flush().
ID uint32
// Position can be set to the Handle of another Rule to insert the new Rule
// before (InsertRule) or after (AddRule) the existing rule.
Position uint64 Position uint64
Handle uint64 // PositionID can be set to the ID of another Rule, same as Position, for when
// the existing rule is not yet committed.
PositionID uint32
// Deprecated: The feature for which this field was added never worked. // Deprecated: The feature for which this field was added never worked.
// The field may be removed in a later version. // The field may be removed in a later version.
Flags uint32 Flags uint32
@ -79,7 +95,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
message := netlink.Message{ message := netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE),
Flags: netlink.Request | netlink.Acknowledge | netlink.Dump | unix.NLM_F_ECHO, Flags: netlink.Request | netlink.Acknowledge | netlink.Dump,
}, },
Data: append(extraHeader(uint8(t.Family), 0), data...), Data: append(extraHeader(uint8(t.Family), 0), data...),
} }
@ -104,7 +120,6 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
return rules, nil return rules, nil
} }
// AddRule adds the specified Rule
func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
@ -125,6 +140,11 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
data = append(data, cc.marshalAttr([]netlink.Attribute{ data = append(data, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(r.Handle)}, {Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(r.Handle)},
})...) })...)
} else {
r.ID = cc.allocateTransactionID()
data = append(data, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_ID, Data: binaryutil.BigEndian.PutUint32(r.ID)},
})...)
} }
data = append(data, cc.marshalAttr([]netlink.Attribute{ data = append(data, cc.marshalAttr([]netlink.Attribute{
@ -145,43 +165,77 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
msgData := []byte{} msgData := []byte{}
msgData = append(msgData, data...) msgData = append(msgData, data...)
var flags netlink.HeaderFlags
if r.UserData != nil { if r.UserData != nil {
msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{ msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_USERDATA, Data: r.UserData}, {Type: unix.NFTA_RULE_USERDATA, Data: r.UserData},
})...) })...)
} }
var flags netlink.HeaderFlags
var ruleRef *Rule
switch op { switch op {
case operationAdd: case operationAdd:
flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo | netlink.Append
ruleRef = r
case operationInsert: case operationInsert:
flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo
ruleRef = r
case operationReplace: case operationReplace:
flags = netlink.Request | netlink.Acknowledge | netlink.Replace | unix.NLM_F_ECHO | unix.NLM_F_REPLACE flags = netlink.Request | netlink.Acknowledge | netlink.Replace
} }
if r.Position != 0 { if r.Position != 0 {
msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{ msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_POSITION, Data: binaryutil.BigEndian.PutUint64(r.Position)}, {Type: unix.NFTA_RULE_POSITION, Data: binaryutil.BigEndian.PutUint64(r.Position)},
})...) })...)
} else if r.PositionID != 0 {
msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{
{Type: nfta_rule_position_id, Data: binaryutil.BigEndian.PutUint32(r.PositionID)},
})...)
} }
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: newRuleHeaderType, Type: newRuleHeaderType,
Flags: flags, Flags: flags,
}, },
Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...),
rule: ruleRef,
}) })
return r return r
} }
func (r *Rule) handleCreateReply(reply netlink.Message) error {
ad, err := netlink.NewAttributeDecoder(reply.Data[4:])
if err != nil {
return err
}
ad.ByteOrder = binary.BigEndian
var handle uint64
for ad.Next() {
switch ad.Type() {
case unix.NFTA_RULE_HANDLE:
handle = ad.Uint64()
}
}
if ad.Err() != nil {
return ad.Err()
}
if handle == 0 {
return fmt.Errorf("missing rule handle in create reply")
}
r.Handle = handle
r.ID = 0
return nil
}
func (cc *Conn) ReplaceRule(r *Rule) *Rule { func (cc *Conn) ReplaceRule(r *Rule) *Rule {
return cc.newRule(r, operationReplace) return cc.newRule(r, operationReplace)
} }
// AddRule inserts the specified Rule after the existing Rule referenced by
// Position/PositionID if set, otherwise at the end of the chain.
func (cc *Conn) AddRule(r *Rule) *Rule { func (cc *Conn) AddRule(r *Rule) *Rule {
if r.Handle != 0 { if r.Handle != 0 {
return cc.newRule(r, operationReplace) return cc.newRule(r, operationReplace)
@ -190,6 +244,8 @@ func (cc *Conn) AddRule(r *Rule) *Rule {
return cc.newRule(r, operationAdd) return cc.newRule(r, operationAdd)
} }
// InsertRule inserts the specified Rule before the existing Rule referenced by
// Position/PositionID if set, otherwise at the beginning of the chain.
func (cc *Conn) InsertRule(r *Rule) *Rule { func (cc *Conn) InsertRule(r *Rule) *Rule {
if r.Handle != 0 { if r.Handle != 0 {
return cc.newRule(r, operationReplace) return cc.newRule(r, operationReplace)
@ -198,7 +254,8 @@ func (cc *Conn) InsertRule(r *Rule) *Rule {
return cc.newRule(r, operationInsert) return cc.newRule(r, operationInsert)
} }
// DelRule deletes the specified Rule, rule's handle cannot be 0 // DelRule deletes the specified Rule. Either the Handle or ID of the
// rule must be set.
func (cc *Conn) DelRule(r *Rule) error { func (cc *Conn) DelRule(r *Rule) error {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
@ -206,15 +263,20 @@ func (cc *Conn) DelRule(r *Rule) error {
{Type: unix.NFTA_RULE_TABLE, Data: []byte(r.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_TABLE, Data: []byte(r.Table.Name + "\x00")},
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(r.Chain.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(r.Chain.Name + "\x00")},
}) })
if r.Handle == 0 { if r.Handle != 0 {
return fmt.Errorf("rule's handle cannot be 0") data = append(data, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(r.Handle)},
})...)
} else if r.ID != 0 {
data = append(data, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_ID, Data: binaryutil.BigEndian.PutUint32(r.ID)},
})...)
} else {
return fmt.Errorf("rule must have a handle or ID")
} }
data = append(data, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(r.Handle)},
})...)
flags := netlink.Request | netlink.Acknowledge flags := netlink.Request | netlink.Acknowledge
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: delRuleHeaderType, Type: delRuleHeaderType,
Flags: flags, Flags: flags,

13
set.go
View File

@ -46,8 +46,6 @@ const (
NFTA_SET_ELEM_EXPRESSIONS = 0x11 NFTA_SET_ELEM_EXPRESSIONS = 0x11
) )
var allocSetID uint32
// SetDatatype represents a datatype declared by nft. // SetDatatype represents a datatype declared by nft.
type SetDatatype struct { type SetDatatype struct {
Name string Name string
@ -508,7 +506,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
{Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem}, {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem},
} }
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create, Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -532,8 +530,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
} }
if s.ID == 0 { if s.ID == 0 {
allocSetID++ s.ID = cc.allocateTransactionID()
s.ID = allocSetID
if s.Anonymous { if s.Anonymous {
s.Name = "__set%d" s.Name = "__set%d"
if s.IsMap { if s.IsMap {
@ -683,7 +680,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data}) tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data})
} }
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET),
Flags: netlink.Request | netlink.Acknowledge | netlink.Create, Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@ -703,7 +700,7 @@ func (cc *Conn) DelSet(s *Set) {
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
}) })
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET),
Flags: netlink.Request | netlink.Acknowledge, Flags: netlink.Request | netlink.Acknowledge,
@ -720,7 +717,7 @@ func (cc *Conn) FlushSet(s *Set) {
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
}) })
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM),
Flags: netlink.Request | netlink.Acknowledge, Flags: netlink.Request | netlink.Acknowledge,

View File

@ -254,7 +254,10 @@ func TestMarshalSet(t *testing.T) {
} }
msg := c.messages[connMsgSetIdx] msg := c.messages[connMsgSetIdx]
nset, err := setsFromMsg(msg) nset, err := setsFromMsg(netlink.Message{
Header: msg.Header,
Data: msg.Data,
})
if err != nil { if err != nil {
t.Fatalf("setsFromMsg() error: %+v", err) t.Fatalf("setsFromMsg() error: %+v", err)
} }

View File

@ -57,7 +57,7 @@ func (cc *Conn) DelTable(t *Table) {
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
}) })
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
Flags: netlink.Request | netlink.Acknowledge, Flags: netlink.Request | netlink.Acknowledge,
@ -73,7 +73,7 @@ func (cc *Conn) addTable(t *Table, flag netlink.HeaderFlags) *Table {
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
}) })
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE),
Flags: netlink.Request | netlink.Acknowledge | flag, Flags: netlink.Request | netlink.Acknowledge | flag,
@ -103,7 +103,7 @@ func (cc *Conn) FlushTable(t *Table) {
data := cc.marshalAttr([]netlink.Attribute{ data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
}) })
cc.messages = append(cc.messages, netlink.Message{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE),
Flags: netlink.Request | netlink.Acknowledge, Flags: netlink.Request | netlink.Acknowledge,