Fix getting concatenated data types for maps (#217)

This also implements parsing of concatenated data types.
This commit is contained in:
konradh 2023-04-02 10:11:12 +02:00 committed by GitHub
parent 2729c5a5ee
commit a93939a185
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 106 additions and 103 deletions

171
set.go
View File

@ -117,53 +117,53 @@ var (
TypeTimeDay = SetDatatype{Name: "day", Bytes: 1, nftMagic: 45}
TypeCGroupV2 = SetDatatype{Name: "cgroupsv2", Bytes: 8, nftMagic: 46}
nftDatatypes = map[string]SetDatatype{
TypeVerdict.Name: TypeVerdict,
TypeNFProto.Name: TypeNFProto,
TypeBitmask.Name: TypeBitmask,
TypeInteger.Name: TypeInteger,
TypeString.Name: TypeString,
TypeLLAddr.Name: TypeLLAddr,
TypeIPAddr.Name: TypeIPAddr,
TypeIP6Addr.Name: TypeIP6Addr,
TypeEtherAddr.Name: TypeEtherAddr,
TypeEtherType.Name: TypeEtherType,
TypeARPOp.Name: TypeARPOp,
TypeInetProto.Name: TypeInetProto,
TypeInetService.Name: TypeInetService,
TypeICMPType.Name: TypeICMPType,
TypeTCPFlag.Name: TypeTCPFlag,
TypeDCCPPktType.Name: TypeDCCPPktType,
TypeMHType.Name: TypeMHType,
TypeTime.Name: TypeTime,
TypeMark.Name: TypeMark,
TypeIFIndex.Name: TypeIFIndex,
TypeARPHRD.Name: TypeARPHRD,
TypeRealm.Name: TypeRealm,
TypeClassID.Name: TypeClassID,
TypeUID.Name: TypeUID,
TypeGID.Name: TypeGID,
TypeCTState.Name: TypeCTState,
TypeCTDir.Name: TypeCTDir,
TypeCTStatus.Name: TypeCTStatus,
TypeICMP6Type.Name: TypeICMP6Type,
TypeCTLabel.Name: TypeCTLabel,
TypePktType.Name: TypePktType,
TypeICMPCode.Name: TypeICMPCode,
TypeICMPV6Code.Name: TypeICMPV6Code,
TypeICMPXCode.Name: TypeICMPXCode,
TypeDevGroup.Name: TypeDevGroup,
TypeDSCP.Name: TypeDSCP,
TypeECN.Name: TypeECN,
TypeFIBAddr.Name: TypeFIBAddr,
TypeBoolean.Name: TypeBoolean,
TypeCTEventBit.Name: TypeCTEventBit,
TypeIFName.Name: TypeIFName,
TypeIGMPType.Name: TypeIGMPType,
TypeTimeDate.Name: TypeTimeDate,
TypeTimeHour.Name: TypeTimeHour,
TypeTimeDay.Name: TypeTimeDay,
TypeCGroupV2.Name: TypeCGroupV2,
nftDatatypes = []SetDatatype{
TypeVerdict,
TypeNFProto,
TypeBitmask,
TypeInteger,
TypeString,
TypeLLAddr,
TypeIPAddr,
TypeIP6Addr,
TypeEtherAddr,
TypeEtherType,
TypeARPOp,
TypeInetProto,
TypeInetService,
TypeICMPType,
TypeTCPFlag,
TypeDCCPPktType,
TypeMHType,
TypeTime,
TypeMark,
TypeIFIndex,
TypeARPHRD,
TypeRealm,
TypeClassID,
TypeUID,
TypeGID,
TypeCTState,
TypeCTDir,
TypeCTStatus,
TypeICMP6Type,
TypeCTLabel,
TypePktType,
TypeICMPCode,
TypeICMPV6Code,
TypeICMPXCode,
TypeDevGroup,
TypeDSCP,
TypeECN,
TypeFIBAddr,
TypeBoolean,
TypeCTEventBit,
TypeIFName,
TypeIGMPType,
TypeTimeDate,
TypeTimeHour,
TypeTimeDay,
TypeCGroupV2,
}
// ctLabelBitSize is defined in https://git.netfilter.org/nftables/tree/src/ct.c.
@ -177,6 +177,19 @@ var (
sizeOfGIDT uint32 = 4
)
var nftDatatypesByName map[string]SetDatatype
var nftDatatypesByMagic map[uint32]SetDatatype
// Create maps for efficient datatype lookup.
func init() {
nftDatatypesByName = make(map[string]SetDatatype, len(nftDatatypes))
nftDatatypesByMagic = make(map[uint32]SetDatatype, len(nftDatatypes))
for _, dt := range nftDatatypes {
nftDatatypesByName[dt.Name] = dt
nftDatatypesByMagic[dt.nftMagic] = dt
}
}
// ErrTooManyTypes is the error returned by ConcatSetType, if nftMagic would overflow.
var ErrTooManyTypes = errors.New("too many types to concat")
@ -221,7 +234,7 @@ func ConcatSetTypeElements(t SetDatatype) []SetDatatype {
names := strings.Split(t.Name, " . ")
types := make([]SetDatatype, len(names))
for i, n := range names {
types[i] = nftDatatypes[n]
types[i] = nftDatatypesByName[n]
}
return types
}
@ -678,17 +691,11 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
set.Concatenation = (flags & NFT_SET_CONCAT) != 0
case unix.NFTA_SET_KEY_TYPE:
nftMagic := ad.Uint32()
if invalidMagic, ok := validateKeyType(nftMagic); !ok {
return nil, fmt.Errorf("could not determine key type %+v", invalidMagic)
}
set.KeyType.nftMagic = nftMagic
for _, dt := range nftDatatypes {
// If this is a non-concatenated type, we can assign the descriptor.
if nftMagic == dt.nftMagic {
set.KeyType = dt
break
}
dt, err := parseSetDatatype(nftMagic)
if err != nil {
return nil, fmt.Errorf("could not determine data type: %w", err)
}
set.KeyType = dt
case unix.NFTA_SET_DATA_TYPE:
nftMagic := ad.Uint32()
// Special case for the data type verdict, in the message it is stored as 0xffffff00 but it is defined as 1
@ -696,42 +703,34 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
set.KeyType = TypeVerdict
break
}
for _, dt := range nftDatatypes {
if nftMagic == dt.nftMagic {
set.DataType = dt
break
}
}
if set.DataType.nftMagic == 0 {
return nil, fmt.Errorf("could not determine data type %x", nftMagic)
dt, err := parseSetDatatype(nftMagic)
if err != nil {
return nil, fmt.Errorf("could not determine data type: %w", err)
}
set.DataType = dt
}
}
return &set, nil
}
func validateKeyType(bits uint32) ([]uint32, bool) {
var unpackTypes []uint32
var invalidTypes []uint32
found := false
valid := true
for bits != 0 {
unpackTypes = append(unpackTypes, bits&SetConcatTypeMask)
bits = bits >> SetConcatTypeBits
}
for _, t := range unpackTypes {
for _, dt := range nftDatatypes {
if t == dt.nftMagic {
found = true
}
func parseSetDatatype(magic uint32) (SetDatatype, error) {
types := make([]SetDatatype, 0, 32/SetConcatTypeBits)
for magic != 0 {
t := magic & SetConcatTypeMask
magic = magic >> SetConcatTypeBits
dt, ok := nftDatatypesByMagic[t]
if !ok {
return TypeInvalid, fmt.Errorf("could not determine data type %+v", dt)
}
if !found {
invalidTypes = append(invalidTypes, t)
valid = false
}
found = false
// Because we start with the last type, we insert the later types at the front.
types = append([]SetDatatype{dt}, types...)
}
return invalidTypes, valid
dt, err := ConcatSetType(types...)
if err != nil {
return TypeInvalid, fmt.Errorf("could not create data type: %w", err)
}
return dt, nil
}
var elemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM)

View File

@ -1,7 +1,6 @@
package nftables
import (
"reflect"
"testing"
)
@ -17,58 +16,63 @@ func genSetKeyType(types ...uint32) uint32 {
return c
}
func TestValidateNFTMagic(t *testing.T) {
func TestParseSetDatatype(t *testing.T) {
t.Parallel()
tests := []struct {
name string
nftMagicPacked uint32
pass bool
invalid []uint32
typeName string
typeBytes uint32
}{
{
name: "Single valid nftMagic",
nftMagicPacked: genSetKeyType(7),
nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic),
pass: true,
invalid: nil,
typeName: "ipv4_addr",
typeBytes: 4,
},
{
name: "Single unknown nftMagic",
nftMagicPacked: genSetKeyType(unknownNFTMagic),
pass: false,
invalid: []uint32{unknownNFTMagic},
},
{
name: "Multiple valid nftMagic",
nftMagicPacked: genSetKeyType(7, 13),
nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic, TypeInetService.nftMagic),
pass: true,
invalid: nil,
typeName: "ipv4_addr . inet_service",
typeBytes: 8,
},
{
name: "Multiple nftMagic with 1 unknown",
nftMagicPacked: genSetKeyType(7, 13, unknownNFTMagic),
nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic, TypeInetService.nftMagic, unknownNFTMagic),
pass: false,
invalid: []uint32{unknownNFTMagic},
},
{
name: "Multiple nftMagic with 2 unknown",
nftMagicPacked: genSetKeyType(7, 13, unknownNFTMagic, unknownNFTMagic+1),
nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic, TypeInetService.nftMagic, unknownNFTMagic, unknownNFTMagic+1),
pass: false,
invalid: []uint32{unknownNFTMagic + 1, unknownNFTMagic},
// Invalid entries will appear in reverse order
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
invalid, pass := validateKeyType(tt.nftMagicPacked)
datatype, err := parseSetDatatype(tt.nftMagicPacked)
pass := err == nil
if pass && !tt.pass {
t.Fatalf("expected to fail but succeeded")
}
if !pass && tt.pass {
t.Fatalf("expected to succeed but failed with invalid nftMagic: %+v", invalid)
t.Fatalf("expected to succeed but failed: %s", err)
}
if !reflect.DeepEqual(tt.invalid, invalid) {
t.Fatalf("Expected invalid: %+v but got: %+v", tt.invalid, invalid)
expected := SetDatatype{
Name: tt.typeName,
Bytes: tt.typeBytes,
nftMagic: tt.nftMagicPacked,
}
if pass && datatype != expected {
t.Fatalf("invalid datatype: expected %+v but got %+v", expected, datatype)
}
})
}