Fix getting concatenated data types for maps (#217)

This also implements parsing of concatenated data types.
This commit is contained in:
konradh 2023-04-02 10:11:12 +02:00 committed by GitHub
parent 2729c5a5ee
commit a93939a185
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 106 additions and 103 deletions

171
set.go
View File

@ -117,53 +117,53 @@ var (
TypeTimeDay = SetDatatype{Name: "day", Bytes: 1, nftMagic: 45} TypeTimeDay = SetDatatype{Name: "day", Bytes: 1, nftMagic: 45}
TypeCGroupV2 = SetDatatype{Name: "cgroupsv2", Bytes: 8, nftMagic: 46} TypeCGroupV2 = SetDatatype{Name: "cgroupsv2", Bytes: 8, nftMagic: 46}
nftDatatypes = map[string]SetDatatype{ nftDatatypes = []SetDatatype{
TypeVerdict.Name: TypeVerdict, TypeVerdict,
TypeNFProto.Name: TypeNFProto, TypeNFProto,
TypeBitmask.Name: TypeBitmask, TypeBitmask,
TypeInteger.Name: TypeInteger, TypeInteger,
TypeString.Name: TypeString, TypeString,
TypeLLAddr.Name: TypeLLAddr, TypeLLAddr,
TypeIPAddr.Name: TypeIPAddr, TypeIPAddr,
TypeIP6Addr.Name: TypeIP6Addr, TypeIP6Addr,
TypeEtherAddr.Name: TypeEtherAddr, TypeEtherAddr,
TypeEtherType.Name: TypeEtherType, TypeEtherType,
TypeARPOp.Name: TypeARPOp, TypeARPOp,
TypeInetProto.Name: TypeInetProto, TypeInetProto,
TypeInetService.Name: TypeInetService, TypeInetService,
TypeICMPType.Name: TypeICMPType, TypeICMPType,
TypeTCPFlag.Name: TypeTCPFlag, TypeTCPFlag,
TypeDCCPPktType.Name: TypeDCCPPktType, TypeDCCPPktType,
TypeMHType.Name: TypeMHType, TypeMHType,
TypeTime.Name: TypeTime, TypeTime,
TypeMark.Name: TypeMark, TypeMark,
TypeIFIndex.Name: TypeIFIndex, TypeIFIndex,
TypeARPHRD.Name: TypeARPHRD, TypeARPHRD,
TypeRealm.Name: TypeRealm, TypeRealm,
TypeClassID.Name: TypeClassID, TypeClassID,
TypeUID.Name: TypeUID, TypeUID,
TypeGID.Name: TypeGID, TypeGID,
TypeCTState.Name: TypeCTState, TypeCTState,
TypeCTDir.Name: TypeCTDir, TypeCTDir,
TypeCTStatus.Name: TypeCTStatus, TypeCTStatus,
TypeICMP6Type.Name: TypeICMP6Type, TypeICMP6Type,
TypeCTLabel.Name: TypeCTLabel, TypeCTLabel,
TypePktType.Name: TypePktType, TypePktType,
TypeICMPCode.Name: TypeICMPCode, TypeICMPCode,
TypeICMPV6Code.Name: TypeICMPV6Code, TypeICMPV6Code,
TypeICMPXCode.Name: TypeICMPXCode, TypeICMPXCode,
TypeDevGroup.Name: TypeDevGroup, TypeDevGroup,
TypeDSCP.Name: TypeDSCP, TypeDSCP,
TypeECN.Name: TypeECN, TypeECN,
TypeFIBAddr.Name: TypeFIBAddr, TypeFIBAddr,
TypeBoolean.Name: TypeBoolean, TypeBoolean,
TypeCTEventBit.Name: TypeCTEventBit, TypeCTEventBit,
TypeIFName.Name: TypeIFName, TypeIFName,
TypeIGMPType.Name: TypeIGMPType, TypeIGMPType,
TypeTimeDate.Name: TypeTimeDate, TypeTimeDate,
TypeTimeHour.Name: TypeTimeHour, TypeTimeHour,
TypeTimeDay.Name: TypeTimeDay, TypeTimeDay,
TypeCGroupV2.Name: TypeCGroupV2, TypeCGroupV2,
} }
// ctLabelBitSize is defined in https://git.netfilter.org/nftables/tree/src/ct.c. // ctLabelBitSize is defined in https://git.netfilter.org/nftables/tree/src/ct.c.
@ -177,6 +177,19 @@ var (
sizeOfGIDT uint32 = 4 sizeOfGIDT uint32 = 4
) )
var nftDatatypesByName map[string]SetDatatype
var nftDatatypesByMagic map[uint32]SetDatatype
// Create maps for efficient datatype lookup.
func init() {
nftDatatypesByName = make(map[string]SetDatatype, len(nftDatatypes))
nftDatatypesByMagic = make(map[uint32]SetDatatype, len(nftDatatypes))
for _, dt := range nftDatatypes {
nftDatatypesByName[dt.Name] = dt
nftDatatypesByMagic[dt.nftMagic] = dt
}
}
// ErrTooManyTypes is the error returned by ConcatSetType, if nftMagic would overflow. // ErrTooManyTypes is the error returned by ConcatSetType, if nftMagic would overflow.
var ErrTooManyTypes = errors.New("too many types to concat") var ErrTooManyTypes = errors.New("too many types to concat")
@ -221,7 +234,7 @@ func ConcatSetTypeElements(t SetDatatype) []SetDatatype {
names := strings.Split(t.Name, " . ") names := strings.Split(t.Name, " . ")
types := make([]SetDatatype, len(names)) types := make([]SetDatatype, len(names))
for i, n := range names { for i, n := range names {
types[i] = nftDatatypes[n] types[i] = nftDatatypesByName[n]
} }
return types return types
} }
@ -678,17 +691,11 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
set.Concatenation = (flags & NFT_SET_CONCAT) != 0 set.Concatenation = (flags & NFT_SET_CONCAT) != 0
case unix.NFTA_SET_KEY_TYPE: case unix.NFTA_SET_KEY_TYPE:
nftMagic := ad.Uint32() nftMagic := ad.Uint32()
if invalidMagic, ok := validateKeyType(nftMagic); !ok { dt, err := parseSetDatatype(nftMagic)
return nil, fmt.Errorf("could not determine key type %+v", invalidMagic) if err != nil {
} return nil, fmt.Errorf("could not determine data type: %w", err)
set.KeyType.nftMagic = nftMagic
for _, dt := range nftDatatypes {
// If this is a non-concatenated type, we can assign the descriptor.
if nftMagic == dt.nftMagic {
set.KeyType = dt
break
}
} }
set.KeyType = dt
case unix.NFTA_SET_DATA_TYPE: case unix.NFTA_SET_DATA_TYPE:
nftMagic := ad.Uint32() nftMagic := ad.Uint32()
// Special case for the data type verdict, in the message it is stored as 0xffffff00 but it is defined as 1 // Special case for the data type verdict, in the message it is stored as 0xffffff00 but it is defined as 1
@ -696,42 +703,34 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
set.KeyType = TypeVerdict set.KeyType = TypeVerdict
break break
} }
for _, dt := range nftDatatypes { dt, err := parseSetDatatype(nftMagic)
if nftMagic == dt.nftMagic { if err != nil {
set.DataType = dt return nil, fmt.Errorf("could not determine data type: %w", err)
break
}
}
if set.DataType.nftMagic == 0 {
return nil, fmt.Errorf("could not determine data type %x", nftMagic)
} }
set.DataType = dt
} }
} }
return &set, nil return &set, nil
} }
func validateKeyType(bits uint32) ([]uint32, bool) { func parseSetDatatype(magic uint32) (SetDatatype, error) {
var unpackTypes []uint32 types := make([]SetDatatype, 0, 32/SetConcatTypeBits)
var invalidTypes []uint32 for magic != 0 {
found := false t := magic & SetConcatTypeMask
valid := true magic = magic >> SetConcatTypeBits
for bits != 0 { dt, ok := nftDatatypesByMagic[t]
unpackTypes = append(unpackTypes, bits&SetConcatTypeMask) if !ok {
bits = bits >> SetConcatTypeBits return TypeInvalid, fmt.Errorf("could not determine data type %+v", dt)
}
for _, t := range unpackTypes {
for _, dt := range nftDatatypes {
if t == dt.nftMagic {
found = true
}
} }
if !found { // Because we start with the last type, we insert the later types at the front.
invalidTypes = append(invalidTypes, t) types = append([]SetDatatype{dt}, types...)
valid = false
}
found = false
} }
return invalidTypes, valid
dt, err := ConcatSetType(types...)
if err != nil {
return TypeInvalid, fmt.Errorf("could not create data type: %w", err)
}
return dt, nil
} }
var elemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM) var elemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM)

View File

@ -1,7 +1,6 @@
package nftables package nftables
import ( import (
"reflect"
"testing" "testing"
) )
@ -17,58 +16,63 @@ func genSetKeyType(types ...uint32) uint32 {
return c return c
} }
func TestValidateNFTMagic(t *testing.T) { func TestParseSetDatatype(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
nftMagicPacked uint32 nftMagicPacked uint32
pass bool pass bool
invalid []uint32 typeName string
typeBytes uint32
}{ }{
{ {
name: "Single valid nftMagic", name: "Single valid nftMagic",
nftMagicPacked: genSetKeyType(7), nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic),
pass: true, pass: true,
invalid: nil, typeName: "ipv4_addr",
typeBytes: 4,
}, },
{ {
name: "Single unknown nftMagic", name: "Single unknown nftMagic",
nftMagicPacked: genSetKeyType(unknownNFTMagic), nftMagicPacked: genSetKeyType(unknownNFTMagic),
pass: false, pass: false,
invalid: []uint32{unknownNFTMagic},
}, },
{ {
name: "Multiple valid nftMagic", name: "Multiple valid nftMagic",
nftMagicPacked: genSetKeyType(7, 13), nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic, TypeInetService.nftMagic),
pass: true, pass: true,
invalid: nil, typeName: "ipv4_addr . inet_service",
typeBytes: 8,
}, },
{ {
name: "Multiple nftMagic with 1 unknown", name: "Multiple nftMagic with 1 unknown",
nftMagicPacked: genSetKeyType(7, 13, unknownNFTMagic), nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic, TypeInetService.nftMagic, unknownNFTMagic),
pass: false, pass: false,
invalid: []uint32{unknownNFTMagic},
}, },
{ {
name: "Multiple nftMagic with 2 unknown", name: "Multiple nftMagic with 2 unknown",
nftMagicPacked: genSetKeyType(7, 13, unknownNFTMagic, unknownNFTMagic+1), nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic, TypeInetService.nftMagic, unknownNFTMagic, unknownNFTMagic+1),
pass: false, pass: false,
invalid: []uint32{unknownNFTMagic + 1, unknownNFTMagic},
// Invalid entries will appear in reverse order
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
invalid, pass := validateKeyType(tt.nftMagicPacked) datatype, err := parseSetDatatype(tt.nftMagicPacked)
pass := err == nil
if pass && !tt.pass { if pass && !tt.pass {
t.Fatalf("expected to fail but succeeded") t.Fatalf("expected to fail but succeeded")
} }
if !pass && tt.pass { if !pass && tt.pass {
t.Fatalf("expected to succeed but failed with invalid nftMagic: %+v", invalid) t.Fatalf("expected to succeed but failed: %s", err)
} }
if !reflect.DeepEqual(tt.invalid, invalid) { expected := SetDatatype{
t.Fatalf("Expected invalid: %+v but got: %+v", tt.invalid, invalid) Name: tt.typeName,
Bytes: tt.typeBytes,
nftMagic: tt.nftMagicPacked,
}
if pass && datatype != expected {
t.Fatalf("invalid datatype: expected %+v but got %+v", expected, datatype)
} }
}) })
} }