From 327d5c62cd3e128ff82ce2ddb78358ffdfa39d92 Mon Sep 17 00:00:00 2001
From: Leon Vack <LogicalOverflow@users.noreply.github.com>
Date: Wed, 22 Jan 2020 22:37:16 +0100
Subject: [PATCH] function to create concatenated SetDatatypes (#93)

added function to create concatenated SetDatatypes
---
 set.go      | 46 ++++++++++++++++++++++++++++++++++++--
 set_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 108 insertions(+), 2 deletions(-)

diff --git a/set.go b/set.go
index e9d51cb..563cacb 100644
--- a/set.go
+++ b/set.go
@@ -18,6 +18,7 @@ import (
 	"encoding/binary"
 	"errors"
 	"fmt"
+	"strings"
 
 	"github.com/google/nftables/expr"
 
@@ -28,7 +29,10 @@ import (
 
 // SetConcatTypeBits defines concatination bits, originally defined in
 // https://git.netfilter.org/iptables/tree/iptables/nft.c?id=26753888720d8e7eb422ae4311348347f5a05cb4#n1002
-const SetConcatTypeBits = 6
+const (
+	SetConcatTypeBits = 6
+	SetConcatTypeMask = (1 << SetConcatTypeBits) - 1
+)
 
 var allocSetID uint32
 
@@ -77,6 +81,44 @@ var (
 	}
 )
 
+// ErrTooManyTypes is the error returned by ConcatSetType, if nftMagic would overflow.
+var ErrTooManyTypes = errors.New("too many types to concat")
+
+// MustConcatSetType does the same as ConcatSetType, but panics instead of an
+// error. It simplifies safe initialization of global variables.
+func MustConcatSetType(types ...SetDatatype) SetDatatype {
+	t, err := ConcatSetType(types...)
+	if err != nil {
+		panic(err)
+	}
+	return t
+}
+
+// ConcatSetType constructs a new SetDatatype which consists of a concatenation
+// of the passed types. It returns ErrTooManyTypes, if nftMagic would overflow
+// (more than 5 types).
+func ConcatSetType(types ...SetDatatype) (SetDatatype, error) {
+	if len(types) > 32/SetConcatTypeBits {
+		return SetDatatype{}, ErrTooManyTypes
+	}
+
+	var magic, bytes uint32
+	names := make([]string, len(types))
+	for i, t := range types {
+		bytes += t.Bytes
+		// concatenated types pad the length to multiples of the register size (4 bytes)
+		// see https://git.netfilter.org/nftables/tree/src/datatype.c?id=488356b895024d0944b20feb1f930558726e0877#n1162
+		if t.Bytes%4 != 0 {
+			bytes += 4 - (t.Bytes % 4)
+		}
+		names[i] = t.Name
+
+		magic <<= SetConcatTypeBits
+		magic |= t.nftMagic & SetConcatTypeMask
+	}
+	return SetDatatype{Name: strings.Join(names, " . "), Bytes: bytes, nftMagic: magic}, nil
+}
+
 // Set represents an nftables set. Anonymous sets are only valid within the
 // context of a single batch.
 type Set struct {
@@ -469,7 +511,7 @@ func validateKeyType(bits uint32) ([]uint32, bool) {
 	found := false
 	valid := true
 	for bits != 0 {
-		unpackTypes = append(unpackTypes, bits&0x3f)
+		unpackTypes = append(unpackTypes, bits&SetConcatTypeMask)
 		bits = bits >> SetConcatTypeBits
 	}
 	for _, t := range unpackTypes {
diff --git a/set_test.go b/set_test.go
index 88a55fb..5417326 100644
--- a/set_test.go
+++ b/set_test.go
@@ -69,3 +69,67 @@ func TestValidateNFTMagic(t *testing.T) {
 		})
 	}
 }
+
+func TestConcatSetType(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name        string
+		types       []SetDatatype
+		err         error
+		concatName  string
+		concatBytes uint32
+		concatMagic uint32
+	}{
+		{
+			name:  "Concatenate six (too many) IPv4s",
+			types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr},
+			err:   ErrTooManyTypes,
+		},
+		{
+			name:        "Concatenate five IPv4s",
+			types:       []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr},
+			err:         nil,
+			concatName:  "ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr",
+			concatBytes: 20,
+			concatMagic: 0x071c71c7,
+		},
+		{
+			name:        "Concatenate IPv6 and port",
+			types:       []SetDatatype{TypeIP6Addr, TypeInetService},
+			err:         nil,
+			concatName:  "ipv6_addr . inet_service",
+			concatBytes: 20,
+			concatMagic: 0x0000020d,
+		},
+		{
+			name:        "Concatenate protocol and port",
+			types:       []SetDatatype{TypeInetProto, TypeInetService},
+			err:         nil,
+			concatName:  "inet_proto . inet_service",
+			concatBytes: 8,
+			concatMagic: 0x0000030d,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			concat, err := ConcatSetType(tt.types...)
+			if tt.err != err {
+				t.Errorf("ConcatSetType() returned an incorrect error: expected %v but got %v", tt.err, err)
+			}
+			if err != nil {
+				return
+			}
+			if tt.concatName != concat.Name {
+				t.Errorf("invalid concatinated name: expceted %s but got %s", tt.concatName, concat.Name)
+			}
+			if tt.concatBytes != concat.Bytes {
+				t.Errorf("invalid concatinated number of bytes: expceted %d but got %d", tt.concatBytes, concat.Bytes)
+			}
+			if tt.concatMagic != concat.nftMagic {
+				t.Errorf("invalid concatinated magic: expceted %08x but got %08x", tt.concatMagic, concat.nftMagic)
+			}
+		})
+	}
+}