added function to create concatenated SetDatatypes

This commit is contained in:
Leon Vack 2020-01-14 14:24:22 +01:00
parent 45c777dde0
commit dcf82b4d03
No known key found for this signature in database
GPG Key ID: B66DAB934BCECCB7
2 changed files with 95 additions and 2 deletions

32
set.go
View File

@ -18,6 +18,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"strings"
"github.com/google/nftables/expr"
@ -28,7 +29,10 @@ import (
// SetConcatTypeBits defines concatination bits, originally defined in
// https://git.netfilter.org/iptables/tree/iptables/nft.c?id=26753888720d8e7eb422ae4311348347f5a05cb4#n1002
const SetConcatTypeBits = 6
const (
SetConcatTypeBits = 6
SetConcatTypeMask = (1 << SetConcatTypeBits) - 1
)
var allocSetID uint32
@ -77,6 +81,30 @@ var (
}
)
// ConcatSetType constructs a new SetDatatype which consists of a concatenation of the passed types. It panics, if the
// nftMagic would overflow (more than 5 types)
func ConcatSetType(types ...SetDatatype) SetDatatype {
if len(types) > 32/SetConcatTypeBits {
panic("too many type to concat")
}
var magic, bytes uint32
names := make([]string, len(types))
for i, t := range types {
bytes += t.Bytes
// concatenated types pad the length to multiples of the register size (4 bytes)
// see https://git.netfilter.org/nftables/tree/src/datatype.c?id=488356b895024d0944b20feb1f930558726e0877#n1162
if t.Bytes%4 != 0 {
bytes += 4 - (t.Bytes % 4)
}
names[i] = t.Name
magic <<= SetConcatTypeBits
magic |= t.nftMagic & SetConcatTypeMask
}
return SetDatatype{Name: strings.Join(names, " . "), Bytes: bytes, nftMagic: magic}
}
// Set represents an nftables set. Anonymous sets are only valid within the
// context of a single batch.
type Set struct {
@ -469,7 +497,7 @@ func validateKeyType(bits uint32) ([]uint32, bool) {
found := false
valid := true
for bits != 0 {
unpackTypes = append(unpackTypes, bits&0x3f)
unpackTypes = append(unpackTypes, bits&SetConcatTypeMask)
bits = bits >> SetConcatTypeBits
}
for _, t := range unpackTypes {

View File

@ -69,3 +69,68 @@ func TestValidateNFTMagic(t *testing.T) {
})
}
}
func TestConcatSetType(t *testing.T) {
t.Parallel()
tests := []struct {
name string
types []SetDatatype
pass bool
concatName string
concatBytes uint32
concatMagic uint32
}{
{
name: "Concatenate six (too many) IPv4s",
types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr},
pass: false,
},
{
name: "Concatenate five IPv4s",
types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr},
pass: true,
concatName: "ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr",
concatBytes: 20,
concatMagic: 0x071c71c7,
},
{
name: "Concatenate IPv6 and port",
types: []SetDatatype{TypeIP6Addr, TypeInetService},
pass: true,
concatName: "ipv6_addr . inet_service",
concatBytes: 20,
concatMagic: 0x0000020d,
},
{
name: "Concatenate protocol and port",
types: []SetDatatype{TypeInetProto, TypeInetService},
pass: true,
concatName: "inet_proto . inet_service",
concatBytes: 8,
concatMagic: 0x0000030d,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if !tt.pass {
defer func() {
if recover() == nil {
t.Fatalf("ConcatSetType() should have paniced but did not")
}
}()
}
concat := ConcatSetType(tt.types...)
if tt.concatName != concat.Name {
t.Errorf("invalid concatinated name: expceted %s but got %s", tt.concatName, concat.Name)
}
if tt.concatBytes != concat.Bytes {
t.Errorf("invalid concatinated number of bytes: expceted %d but got %d", tt.concatBytes, concat.Bytes)
}
if tt.concatMagic != concat.nftMagic {
t.Errorf("invalid concatinated magic: expceted %08x but got %08x", tt.concatMagic, concat.nftMagic)
}
})
}
}