2019-12-17 17:02:00 -06:00
|
|
|
package nftables
|
|
|
|
|
|
|
|
import (
|
2024-10-18 11:21:51 -05:00
|
|
|
"reflect"
|
2019-12-17 17:02:00 -06:00
|
|
|
"testing"
|
2024-10-18 11:21:51 -05:00
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/mdlayher/netlink"
|
2019-12-17 17:02:00 -06:00
|
|
|
)
|
|
|
|
|
2021-09-08 13:50:07 -05:00
|
|
|
// unknownNFTMagic is an nftMagic value that's unhandled by this
|
|
|
|
// library. We use two of them below.
|
|
|
|
const unknownNFTMagic uint32 = 1<<SetConcatTypeBits - 2
|
|
|
|
|
2019-12-17 17:02:00 -06:00
|
|
|
func genSetKeyType(types ...uint32) uint32 {
|
|
|
|
c := types[0]
|
|
|
|
for i := 1; i < len(types); i++ {
|
|
|
|
c = c<<SetConcatTypeBits | types[i]
|
|
|
|
}
|
|
|
|
return c
|
|
|
|
}
|
|
|
|
|
2023-04-02 03:11:12 -05:00
|
|
|
func TestParseSetDatatype(t *testing.T) {
|
2019-12-17 17:02:00 -06:00
|
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
|
|
name string
|
|
|
|
nftMagicPacked uint32
|
|
|
|
pass bool
|
2023-04-02 03:11:12 -05:00
|
|
|
typeName string
|
|
|
|
typeBytes uint32
|
2019-12-17 17:02:00 -06:00
|
|
|
}{
|
|
|
|
{
|
|
|
|
name: "Single valid nftMagic",
|
2023-04-02 03:11:12 -05:00
|
|
|
nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic),
|
2019-12-17 17:02:00 -06:00
|
|
|
pass: true,
|
2023-04-02 03:11:12 -05:00
|
|
|
typeName: "ipv4_addr",
|
|
|
|
typeBytes: 4,
|
2019-12-17 17:02:00 -06:00
|
|
|
},
|
|
|
|
{
|
2021-09-08 13:50:07 -05:00
|
|
|
name: "Single unknown nftMagic",
|
|
|
|
nftMagicPacked: genSetKeyType(unknownNFTMagic),
|
2019-12-17 17:02:00 -06:00
|
|
|
pass: false,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "Multiple valid nftMagic",
|
2023-04-02 03:11:12 -05:00
|
|
|
nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic, TypeInetService.nftMagic),
|
2019-12-17 17:02:00 -06:00
|
|
|
pass: true,
|
2023-04-02 03:11:12 -05:00
|
|
|
typeName: "ipv4_addr . inet_service",
|
|
|
|
typeBytes: 8,
|
2019-12-17 17:02:00 -06:00
|
|
|
},
|
|
|
|
{
|
2021-09-08 13:50:07 -05:00
|
|
|
name: "Multiple nftMagic with 1 unknown",
|
2023-04-02 03:11:12 -05:00
|
|
|
nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic, TypeInetService.nftMagic, unknownNFTMagic),
|
2019-12-17 17:02:00 -06:00
|
|
|
pass: false,
|
|
|
|
},
|
|
|
|
{
|
2021-09-08 13:50:07 -05:00
|
|
|
name: "Multiple nftMagic with 2 unknown",
|
2023-04-02 03:11:12 -05:00
|
|
|
nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic, TypeInetService.nftMagic, unknownNFTMagic, unknownNFTMagic+1),
|
2019-12-17 17:02:00 -06:00
|
|
|
pass: false,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
2023-04-02 03:11:12 -05:00
|
|
|
datatype, err := parseSetDatatype(tt.nftMagicPacked)
|
|
|
|
pass := err == nil
|
2019-12-17 17:02:00 -06:00
|
|
|
if pass && !tt.pass {
|
|
|
|
t.Fatalf("expected to fail but succeeded")
|
|
|
|
}
|
|
|
|
if !pass && tt.pass {
|
2023-04-02 03:11:12 -05:00
|
|
|
t.Fatalf("expected to succeed but failed: %s", err)
|
2019-12-17 17:02:00 -06:00
|
|
|
}
|
2023-04-02 03:11:12 -05:00
|
|
|
expected := SetDatatype{
|
|
|
|
Name: tt.typeName,
|
|
|
|
Bytes: tt.typeBytes,
|
|
|
|
nftMagic: tt.nftMagicPacked,
|
|
|
|
}
|
|
|
|
if pass && datatype != expected {
|
|
|
|
t.Fatalf("invalid datatype: expected %+v but got %+v", expected, datatype)
|
2019-12-17 17:02:00 -06:00
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
2020-01-22 15:37:16 -06:00
|
|
|
|
|
|
|
func TestConcatSetType(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
name string
|
|
|
|
types []SetDatatype
|
|
|
|
err error
|
|
|
|
concatName string
|
|
|
|
concatBytes uint32
|
|
|
|
concatMagic uint32
|
|
|
|
}{
|
|
|
|
{
|
|
|
|
name: "Concatenate six (too many) IPv4s",
|
|
|
|
types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr},
|
|
|
|
err: ErrTooManyTypes,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "Concatenate five IPv4s",
|
|
|
|
types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr},
|
|
|
|
err: nil,
|
|
|
|
concatName: "ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr",
|
|
|
|
concatBytes: 20,
|
|
|
|
concatMagic: 0x071c71c7,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "Concatenate IPv6 and port",
|
|
|
|
types: []SetDatatype{TypeIP6Addr, TypeInetService},
|
|
|
|
err: nil,
|
|
|
|
concatName: "ipv6_addr . inet_service",
|
|
|
|
concatBytes: 20,
|
|
|
|
concatMagic: 0x0000020d,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "Concatenate protocol and port",
|
|
|
|
types: []SetDatatype{TypeInetProto, TypeInetService},
|
|
|
|
err: nil,
|
|
|
|
concatName: "inet_proto . inet_service",
|
|
|
|
concatBytes: 8,
|
|
|
|
concatMagic: 0x0000030d,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
concat, err := ConcatSetType(tt.types...)
|
|
|
|
if tt.err != err {
|
|
|
|
t.Errorf("ConcatSetType() returned an incorrect error: expected %v but got %v", tt.err, err)
|
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if tt.concatName != concat.Name {
|
|
|
|
t.Errorf("invalid concatinated name: expceted %s but got %s", tt.concatName, concat.Name)
|
|
|
|
}
|
|
|
|
if tt.concatBytes != concat.Bytes {
|
|
|
|
t.Errorf("invalid concatinated number of bytes: expceted %d but got %d", tt.concatBytes, concat.Bytes)
|
|
|
|
}
|
|
|
|
if tt.concatMagic != concat.nftMagic {
|
|
|
|
t.Errorf("invalid concatinated magic: expceted %08x but got %08x", tt.concatMagic, concat.nftMagic)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
2022-04-22 10:12:20 -05:00
|
|
|
|
|
|
|
func TestConcatSetTypeElements(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
name string
|
|
|
|
types []SetDatatype
|
|
|
|
}{
|
|
|
|
{
|
|
|
|
name: "concat ip6 . inet_service",
|
|
|
|
types: []SetDatatype{TypeIP6Addr, TypeInetService},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "concat ip . inet_service . ip6",
|
|
|
|
types: []SetDatatype{TypeIPAddr, TypeInetService, TypeIP6Addr},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "concat inet_proto . inet_service",
|
|
|
|
types: []SetDatatype{TypeInetProto, TypeInetService},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "concat ip . ip . ip . ip",
|
|
|
|
types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
concat, err := ConcatSetType(tt.types...)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
elements := ConcatSetTypeElements(concat)
|
|
|
|
if got, want := len(elements), len(tt.types); got != want {
|
|
|
|
t.Errorf("invalid number of elements: expected %d, got %d", got, want)
|
|
|
|
}
|
|
|
|
for i, v := range tt.types {
|
|
|
|
if got, want := elements[i].GetNFTMagic(), v.GetNFTMagic(); got != want {
|
|
|
|
t.Errorf("invalid element on position %d: expected %d, got %d", i, got, want)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
2024-10-18 11:21:51 -05:00
|
|
|
|
|
|
|
func TestMarshalSet(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
|
|
|
|
tbl := &Table{
|
|
|
|
Name: "ipv4table",
|
|
|
|
Family: TableFamilyIPv4,
|
|
|
|
}
|
|
|
|
|
|
|
|
c, err := New(WithTestDial(
|
|
|
|
func(req []netlink.Message) ([]netlink.Message, error) {
|
|
|
|
return req, nil
|
|
|
|
}))
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
c.AddTable(tbl)
|
|
|
|
|
|
|
|
// Ensure the table is added.
|
|
|
|
const connMsgStart = 1
|
|
|
|
if len(c.messages) != connMsgStart {
|
|
|
|
t.Fatalf("AddSet() wrong start message count: %d, expected: %d", len(c.messages), connMsgStart)
|
|
|
|
}
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
name string
|
|
|
|
set Set
|
|
|
|
}{
|
|
|
|
{
|
|
|
|
name: "Set without flags",
|
|
|
|
set: Set{
|
|
|
|
Name: "test-set",
|
|
|
|
ID: uint32(1),
|
|
|
|
Table: tbl,
|
|
|
|
KeyType: TypeIPAddr,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "Set with size, timeout, dynamic flag specified",
|
|
|
|
set: Set{
|
|
|
|
Name: "test-set",
|
|
|
|
ID: uint32(2),
|
|
|
|
HasTimeout: true,
|
|
|
|
Dynamic: true,
|
|
|
|
Size: 10,
|
|
|
|
Table: tbl,
|
|
|
|
KeyType: TypeIPAddr,
|
|
|
|
Timeout: 30 * time.Second,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
for i, tt := range tests {
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
if err := c.AddSet(&tt.set, nil); err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
connMsgSetIdx := connMsgStart + i
|
|
|
|
if len(c.messages) != connMsgSetIdx+1 {
|
|
|
|
t.Fatalf("AddSet() wrong message count: %d, expected: %d", len(c.messages), connMsgSetIdx+1)
|
|
|
|
}
|
|
|
|
msg := c.messages[connMsgSetIdx]
|
|
|
|
|
Set rule handle during flush
This change makes it possible to delete rules after inserting them,
without needing to query the rules first. Rules can be deleted both
before and after they are flushed. Additionally, this allows positioning
a new rule next to an existing rule, both before and after the existing
rule is flushed.
There are two ways to refer to a rule: Either by ID or by handle. The ID
is assigned by userspace, and is only valid within a transaction, so it
can only be used before the flush. The handle is assigned by the kernel
when the transaction is committed, and can thus only be used after the
flush. We thus need to set an ID on each newly created rule, and
retrieve the handle of the rule during the flush.
There was an existing mechanism to allocate IDs for sets, but this was
using a global counter without any synchronization to prevent data
races. I replaced this by a new mechanism which uses a connection-scoped
counter.
I implemented a new mechanism for retrieving replies in Flush, and
handling these replies by adding a callback to netlink messages. There
was some existing code to handle "overrun", which I deleted, because it
was nonsensical and just worked by accident. NLMSG_OVERRUN is in fact
not a flag, but a complete message type, so the (re&netlink.Overrun)
masking makes no sense. Even better, NLMSG_OVERRUN is never actually
used by Linux. What this code was actually doing was skipping over the
NFT_MSG_NEWRULE replies, and possibly a NFT_MSG_NEWGEN reply.
I had to update all existing tests which compared generated netlink
messages against a reference, by inserting the newly added ID attribute.
We also need to generate replies for the NFT_MSG_NEWRULE messages with a
handle added.
2025-02-20 13:12:30 -06:00
|
|
|
nset, err := setsFromMsg(netlink.Message{
|
|
|
|
Header: msg.Header,
|
|
|
|
Data: msg.Data,
|
|
|
|
})
|
2024-10-18 11:21:51 -05:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("setsFromMsg() error: %+v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Table pointer is set after flush, which is not implemented in the test.
|
|
|
|
tt.set.Table = nil
|
|
|
|
|
|
|
|
if !reflect.DeepEqual(&tt.set, nset) {
|
|
|
|
t.Fatalf("original %+v and recovered %+v Set structs are different", tt.set, nset)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|