diff --git a/set.go b/set.go index e9d51cb..563cacb 100644 --- a/set.go +++ b/set.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "errors" "fmt" + "strings" "github.com/google/nftables/expr" @@ -28,7 +29,10 @@ import ( // SetConcatTypeBits defines concatination bits, originally defined in // https://git.netfilter.org/iptables/tree/iptables/nft.c?id=26753888720d8e7eb422ae4311348347f5a05cb4#n1002 -const SetConcatTypeBits = 6 +const ( + SetConcatTypeBits = 6 + SetConcatTypeMask = (1 << SetConcatTypeBits) - 1 +) var allocSetID uint32 @@ -77,6 +81,44 @@ var ( } ) +// ErrTooManyTypes is the error returned by ConcatSetType, if nftMagic would overflow. +var ErrTooManyTypes = errors.New("too many types to concat") + +// MustConcatSetType does the same as ConcatSetType, but panics instead of an +// error. It simplifies safe initialization of global variables. +func MustConcatSetType(types ...SetDatatype) SetDatatype { + t, err := ConcatSetType(types...) + if err != nil { + panic(err) + } + return t +} + +// ConcatSetType constructs a new SetDatatype which consists of a concatenation +// of the passed types. It returns ErrTooManyTypes, if nftMagic would overflow +// (more than 5 types). +func ConcatSetType(types ...SetDatatype) (SetDatatype, error) { + if len(types) > 32/SetConcatTypeBits { + return SetDatatype{}, ErrTooManyTypes + } + + var magic, bytes uint32 + names := make([]string, len(types)) + for i, t := range types { + bytes += t.Bytes + // concatenated types pad the length to multiples of the register size (4 bytes) + // see https://git.netfilter.org/nftables/tree/src/datatype.c?id=488356b895024d0944b20feb1f930558726e0877#n1162 + if t.Bytes%4 != 0 { + bytes += 4 - (t.Bytes % 4) + } + names[i] = t.Name + + magic <<= SetConcatTypeBits + magic |= t.nftMagic & SetConcatTypeMask + } + return SetDatatype{Name: strings.Join(names, " . "), Bytes: bytes, nftMagic: magic}, nil +} + // Set represents an nftables set. Anonymous sets are only valid within the // context of a single batch. type Set struct { @@ -469,7 +511,7 @@ func validateKeyType(bits uint32) ([]uint32, bool) { found := false valid := true for bits != 0 { - unpackTypes = append(unpackTypes, bits&0x3f) + unpackTypes = append(unpackTypes, bits&SetConcatTypeMask) bits = bits >> SetConcatTypeBits } for _, t := range unpackTypes { diff --git a/set_test.go b/set_test.go index 88a55fb..5417326 100644 --- a/set_test.go +++ b/set_test.go @@ -69,3 +69,67 @@ func TestValidateNFTMagic(t *testing.T) { }) } } + +func TestConcatSetType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + types []SetDatatype + err error + concatName string + concatBytes uint32 + concatMagic uint32 + }{ + { + name: "Concatenate six (too many) IPv4s", + types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr}, + err: ErrTooManyTypes, + }, + { + name: "Concatenate five IPv4s", + types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr}, + err: nil, + concatName: "ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr", + concatBytes: 20, + concatMagic: 0x071c71c7, + }, + { + name: "Concatenate IPv6 and port", + types: []SetDatatype{TypeIP6Addr, TypeInetService}, + err: nil, + concatName: "ipv6_addr . inet_service", + concatBytes: 20, + concatMagic: 0x0000020d, + }, + { + name: "Concatenate protocol and port", + types: []SetDatatype{TypeInetProto, TypeInetService}, + err: nil, + concatName: "inet_proto . inet_service", + concatBytes: 8, + concatMagic: 0x0000030d, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + concat, err := ConcatSetType(tt.types...) + if tt.err != err { + t.Errorf("ConcatSetType() returned an incorrect error: expected %v but got %v", tt.err, err) + } + if err != nil { + return + } + if tt.concatName != concat.Name { + t.Errorf("invalid concatinated name: expceted %s but got %s", tt.concatName, concat.Name) + } + if tt.concatBytes != concat.Bytes { + t.Errorf("invalid concatinated number of bytes: expceted %d but got %d", tt.concatBytes, concat.Bytes) + } + if tt.concatMagic != concat.nftMagic { + t.Errorf("invalid concatinated magic: expceted %08x but got %08x", tt.concatMagic, concat.nftMagic) + } + }) + } +}