diff --git a/set.go b/set.go index 8020de8..10e15f7 100644 --- a/set.go +++ b/set.go @@ -117,53 +117,53 @@ var ( TypeTimeDay = SetDatatype{Name: "day", Bytes: 1, nftMagic: 45} TypeCGroupV2 = SetDatatype{Name: "cgroupsv2", Bytes: 8, nftMagic: 46} - nftDatatypes = map[string]SetDatatype{ - TypeVerdict.Name: TypeVerdict, - TypeNFProto.Name: TypeNFProto, - TypeBitmask.Name: TypeBitmask, - TypeInteger.Name: TypeInteger, - TypeString.Name: TypeString, - TypeLLAddr.Name: TypeLLAddr, - TypeIPAddr.Name: TypeIPAddr, - TypeIP6Addr.Name: TypeIP6Addr, - TypeEtherAddr.Name: TypeEtherAddr, - TypeEtherType.Name: TypeEtherType, - TypeARPOp.Name: TypeARPOp, - TypeInetProto.Name: TypeInetProto, - TypeInetService.Name: TypeInetService, - TypeICMPType.Name: TypeICMPType, - TypeTCPFlag.Name: TypeTCPFlag, - TypeDCCPPktType.Name: TypeDCCPPktType, - TypeMHType.Name: TypeMHType, - TypeTime.Name: TypeTime, - TypeMark.Name: TypeMark, - TypeIFIndex.Name: TypeIFIndex, - TypeARPHRD.Name: TypeARPHRD, - TypeRealm.Name: TypeRealm, - TypeClassID.Name: TypeClassID, - TypeUID.Name: TypeUID, - TypeGID.Name: TypeGID, - TypeCTState.Name: TypeCTState, - TypeCTDir.Name: TypeCTDir, - TypeCTStatus.Name: TypeCTStatus, - TypeICMP6Type.Name: TypeICMP6Type, - TypeCTLabel.Name: TypeCTLabel, - TypePktType.Name: TypePktType, - TypeICMPCode.Name: TypeICMPCode, - TypeICMPV6Code.Name: TypeICMPV6Code, - TypeICMPXCode.Name: TypeICMPXCode, - TypeDevGroup.Name: TypeDevGroup, - TypeDSCP.Name: TypeDSCP, - TypeECN.Name: TypeECN, - TypeFIBAddr.Name: TypeFIBAddr, - TypeBoolean.Name: TypeBoolean, - TypeCTEventBit.Name: TypeCTEventBit, - TypeIFName.Name: TypeIFName, - TypeIGMPType.Name: TypeIGMPType, - TypeTimeDate.Name: TypeTimeDate, - TypeTimeHour.Name: TypeTimeHour, - TypeTimeDay.Name: TypeTimeDay, - TypeCGroupV2.Name: TypeCGroupV2, + nftDatatypes = []SetDatatype{ + TypeVerdict, + TypeNFProto, + TypeBitmask, + TypeInteger, + TypeString, + TypeLLAddr, + TypeIPAddr, + TypeIP6Addr, + TypeEtherAddr, + TypeEtherType, + TypeARPOp, + TypeInetProto, + TypeInetService, + TypeICMPType, + TypeTCPFlag, + TypeDCCPPktType, + TypeMHType, + TypeTime, + TypeMark, + TypeIFIndex, + TypeARPHRD, + TypeRealm, + TypeClassID, + TypeUID, + TypeGID, + TypeCTState, + TypeCTDir, + TypeCTStatus, + TypeICMP6Type, + TypeCTLabel, + TypePktType, + TypeICMPCode, + TypeICMPV6Code, + TypeICMPXCode, + TypeDevGroup, + TypeDSCP, + TypeECN, + TypeFIBAddr, + TypeBoolean, + TypeCTEventBit, + TypeIFName, + TypeIGMPType, + TypeTimeDate, + TypeTimeHour, + TypeTimeDay, + TypeCGroupV2, } // ctLabelBitSize is defined in https://git.netfilter.org/nftables/tree/src/ct.c. @@ -177,6 +177,19 @@ var ( sizeOfGIDT uint32 = 4 ) +var nftDatatypesByName map[string]SetDatatype +var nftDatatypesByMagic map[uint32]SetDatatype + +// Create maps for efficient datatype lookup. +func init() { + nftDatatypesByName = make(map[string]SetDatatype, len(nftDatatypes)) + nftDatatypesByMagic = make(map[uint32]SetDatatype, len(nftDatatypes)) + for _, dt := range nftDatatypes { + nftDatatypesByName[dt.Name] = dt + nftDatatypesByMagic[dt.nftMagic] = dt + } +} + // ErrTooManyTypes is the error returned by ConcatSetType, if nftMagic would overflow. var ErrTooManyTypes = errors.New("too many types to concat") @@ -221,7 +234,7 @@ func ConcatSetTypeElements(t SetDatatype) []SetDatatype { names := strings.Split(t.Name, " . ") types := make([]SetDatatype, len(names)) for i, n := range names { - types[i] = nftDatatypes[n] + types[i] = nftDatatypesByName[n] } return types } @@ -678,17 +691,11 @@ func setsFromMsg(msg netlink.Message) (*Set, error) { set.Concatenation = (flags & NFT_SET_CONCAT) != 0 case unix.NFTA_SET_KEY_TYPE: nftMagic := ad.Uint32() - if invalidMagic, ok := validateKeyType(nftMagic); !ok { - return nil, fmt.Errorf("could not determine key type %+v", invalidMagic) - } - set.KeyType.nftMagic = nftMagic - for _, dt := range nftDatatypes { - // If this is a non-concatenated type, we can assign the descriptor. - if nftMagic == dt.nftMagic { - set.KeyType = dt - break - } + dt, err := parseSetDatatype(nftMagic) + if err != nil { + return nil, fmt.Errorf("could not determine data type: %w", err) } + set.KeyType = dt case unix.NFTA_SET_DATA_TYPE: nftMagic := ad.Uint32() // Special case for the data type verdict, in the message it is stored as 0xffffff00 but it is defined as 1 @@ -696,42 +703,34 @@ func setsFromMsg(msg netlink.Message) (*Set, error) { set.KeyType = TypeVerdict break } - for _, dt := range nftDatatypes { - if nftMagic == dt.nftMagic { - set.DataType = dt - break - } - } - if set.DataType.nftMagic == 0 { - return nil, fmt.Errorf("could not determine data type %x", nftMagic) + dt, err := parseSetDatatype(nftMagic) + if err != nil { + return nil, fmt.Errorf("could not determine data type: %w", err) } + set.DataType = dt } } return &set, nil } -func validateKeyType(bits uint32) ([]uint32, bool) { - var unpackTypes []uint32 - var invalidTypes []uint32 - found := false - valid := true - for bits != 0 { - unpackTypes = append(unpackTypes, bits&SetConcatTypeMask) - bits = bits >> SetConcatTypeBits - } - for _, t := range unpackTypes { - for _, dt := range nftDatatypes { - if t == dt.nftMagic { - found = true - } +func parseSetDatatype(magic uint32) (SetDatatype, error) { + types := make([]SetDatatype, 0, 32/SetConcatTypeBits) + for magic != 0 { + t := magic & SetConcatTypeMask + magic = magic >> SetConcatTypeBits + dt, ok := nftDatatypesByMagic[t] + if !ok { + return TypeInvalid, fmt.Errorf("could not determine data type %+v", dt) } - if !found { - invalidTypes = append(invalidTypes, t) - valid = false - } - found = false + // Because we start with the last type, we insert the later types at the front. + types = append([]SetDatatype{dt}, types...) } - return invalidTypes, valid + + dt, err := ConcatSetType(types...) + if err != nil { + return TypeInvalid, fmt.Errorf("could not create data type: %w", err) + } + return dt, nil } var elemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM) diff --git a/set_test.go b/set_test.go index db9092a..dda0a56 100644 --- a/set_test.go +++ b/set_test.go @@ -1,7 +1,6 @@ package nftables import ( - "reflect" "testing" ) @@ -17,58 +16,63 @@ func genSetKeyType(types ...uint32) uint32 { return c } -func TestValidateNFTMagic(t *testing.T) { +func TestParseSetDatatype(t *testing.T) { t.Parallel() tests := []struct { name string nftMagicPacked uint32 pass bool - invalid []uint32 + typeName string + typeBytes uint32 }{ { name: "Single valid nftMagic", - nftMagicPacked: genSetKeyType(7), + nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic), pass: true, - invalid: nil, + typeName: "ipv4_addr", + typeBytes: 4, }, { name: "Single unknown nftMagic", nftMagicPacked: genSetKeyType(unknownNFTMagic), pass: false, - invalid: []uint32{unknownNFTMagic}, }, { name: "Multiple valid nftMagic", - nftMagicPacked: genSetKeyType(7, 13), + nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic, TypeInetService.nftMagic), pass: true, - invalid: nil, + typeName: "ipv4_addr . inet_service", + typeBytes: 8, }, { name: "Multiple nftMagic with 1 unknown", - nftMagicPacked: genSetKeyType(7, 13, unknownNFTMagic), + nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic, TypeInetService.nftMagic, unknownNFTMagic), pass: false, - invalid: []uint32{unknownNFTMagic}, }, { name: "Multiple nftMagic with 2 unknown", - nftMagicPacked: genSetKeyType(7, 13, unknownNFTMagic, unknownNFTMagic+1), + nftMagicPacked: genSetKeyType(TypeIPAddr.nftMagic, TypeInetService.nftMagic, unknownNFTMagic, unknownNFTMagic+1), pass: false, - invalid: []uint32{unknownNFTMagic + 1, unknownNFTMagic}, - // Invalid entries will appear in reverse order }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - invalid, pass := validateKeyType(tt.nftMagicPacked) + datatype, err := parseSetDatatype(tt.nftMagicPacked) + pass := err == nil if pass && !tt.pass { t.Fatalf("expected to fail but succeeded") } if !pass && tt.pass { - t.Fatalf("expected to succeed but failed with invalid nftMagic: %+v", invalid) + t.Fatalf("expected to succeed but failed: %s", err) } - if !reflect.DeepEqual(tt.invalid, invalid) { - t.Fatalf("Expected invalid: %+v but got: %+v", tt.invalid, invalid) + expected := SetDatatype{ + Name: tt.typeName, + Bytes: tt.typeBytes, + nftMagic: tt.nftMagicPacked, + } + if pass && datatype != expected { + t.Fatalf("invalid datatype: expected %+v but got %+v", expected, datatype) } }) }