function to create concatenated SetDatatypes (#93)
added function to create concatenated SetDatatypes
This commit is contained in:
parent
88b35b63a9
commit
327d5c62cd
46
set.go
46
set.go
|
@ -18,6 +18,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
|
|
||||||
|
@ -28,7 +29,10 @@ import (
|
||||||
|
|
||||||
// SetConcatTypeBits defines concatination bits, originally defined in
|
// SetConcatTypeBits defines concatination bits, originally defined in
|
||||||
// https://git.netfilter.org/iptables/tree/iptables/nft.c?id=26753888720d8e7eb422ae4311348347f5a05cb4#n1002
|
// https://git.netfilter.org/iptables/tree/iptables/nft.c?id=26753888720d8e7eb422ae4311348347f5a05cb4#n1002
|
||||||
const SetConcatTypeBits = 6
|
const (
|
||||||
|
SetConcatTypeBits = 6
|
||||||
|
SetConcatTypeMask = (1 << SetConcatTypeBits) - 1
|
||||||
|
)
|
||||||
|
|
||||||
var allocSetID uint32
|
var allocSetID uint32
|
||||||
|
|
||||||
|
@ -77,6 +81,44 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrTooManyTypes is the error returned by ConcatSetType, if nftMagic would overflow.
|
||||||
|
var ErrTooManyTypes = errors.New("too many types to concat")
|
||||||
|
|
||||||
|
// MustConcatSetType does the same as ConcatSetType, but panics instead of an
|
||||||
|
// error. It simplifies safe initialization of global variables.
|
||||||
|
func MustConcatSetType(types ...SetDatatype) SetDatatype {
|
||||||
|
t, err := ConcatSetType(types...)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConcatSetType constructs a new SetDatatype which consists of a concatenation
|
||||||
|
// of the passed types. It returns ErrTooManyTypes, if nftMagic would overflow
|
||||||
|
// (more than 5 types).
|
||||||
|
func ConcatSetType(types ...SetDatatype) (SetDatatype, error) {
|
||||||
|
if len(types) > 32/SetConcatTypeBits {
|
||||||
|
return SetDatatype{}, ErrTooManyTypes
|
||||||
|
}
|
||||||
|
|
||||||
|
var magic, bytes uint32
|
||||||
|
names := make([]string, len(types))
|
||||||
|
for i, t := range types {
|
||||||
|
bytes += t.Bytes
|
||||||
|
// concatenated types pad the length to multiples of the register size (4 bytes)
|
||||||
|
// see https://git.netfilter.org/nftables/tree/src/datatype.c?id=488356b895024d0944b20feb1f930558726e0877#n1162
|
||||||
|
if t.Bytes%4 != 0 {
|
||||||
|
bytes += 4 - (t.Bytes % 4)
|
||||||
|
}
|
||||||
|
names[i] = t.Name
|
||||||
|
|
||||||
|
magic <<= SetConcatTypeBits
|
||||||
|
magic |= t.nftMagic & SetConcatTypeMask
|
||||||
|
}
|
||||||
|
return SetDatatype{Name: strings.Join(names, " . "), Bytes: bytes, nftMagic: magic}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Set represents an nftables set. Anonymous sets are only valid within the
|
// Set represents an nftables set. Anonymous sets are only valid within the
|
||||||
// context of a single batch.
|
// context of a single batch.
|
||||||
type Set struct {
|
type Set struct {
|
||||||
|
@ -469,7 +511,7 @@ func validateKeyType(bits uint32) ([]uint32, bool) {
|
||||||
found := false
|
found := false
|
||||||
valid := true
|
valid := true
|
||||||
for bits != 0 {
|
for bits != 0 {
|
||||||
unpackTypes = append(unpackTypes, bits&0x3f)
|
unpackTypes = append(unpackTypes, bits&SetConcatTypeMask)
|
||||||
bits = bits >> SetConcatTypeBits
|
bits = bits >> SetConcatTypeBits
|
||||||
}
|
}
|
||||||
for _, t := range unpackTypes {
|
for _, t := range unpackTypes {
|
||||||
|
|
64
set_test.go
64
set_test.go
|
@ -69,3 +69,67 @@ func TestValidateNFTMagic(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConcatSetType(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
types []SetDatatype
|
||||||
|
err error
|
||||||
|
concatName string
|
||||||
|
concatBytes uint32
|
||||||
|
concatMagic uint32
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Concatenate six (too many) IPv4s",
|
||||||
|
types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr},
|
||||||
|
err: ErrTooManyTypes,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Concatenate five IPv4s",
|
||||||
|
types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr},
|
||||||
|
err: nil,
|
||||||
|
concatName: "ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr",
|
||||||
|
concatBytes: 20,
|
||||||
|
concatMagic: 0x071c71c7,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Concatenate IPv6 and port",
|
||||||
|
types: []SetDatatype{TypeIP6Addr, TypeInetService},
|
||||||
|
err: nil,
|
||||||
|
concatName: "ipv6_addr . inet_service",
|
||||||
|
concatBytes: 20,
|
||||||
|
concatMagic: 0x0000020d,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Concatenate protocol and port",
|
||||||
|
types: []SetDatatype{TypeInetProto, TypeInetService},
|
||||||
|
err: nil,
|
||||||
|
concatName: "inet_proto . inet_service",
|
||||||
|
concatBytes: 8,
|
||||||
|
concatMagic: 0x0000030d,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
concat, err := ConcatSetType(tt.types...)
|
||||||
|
if tt.err != err {
|
||||||
|
t.Errorf("ConcatSetType() returned an incorrect error: expected %v but got %v", tt.err, err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.concatName != concat.Name {
|
||||||
|
t.Errorf("invalid concatinated name: expceted %s but got %s", tt.concatName, concat.Name)
|
||||||
|
}
|
||||||
|
if tt.concatBytes != concat.Bytes {
|
||||||
|
t.Errorf("invalid concatinated number of bytes: expceted %d but got %d", tt.concatBytes, concat.Bytes)
|
||||||
|
}
|
||||||
|
if tt.concatMagic != concat.nftMagic {
|
||||||
|
t.Errorf("invalid concatinated magic: expceted %08x but got %08x", tt.concatMagic, concat.nftMagic)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue