From 76ed01e300f2916ef6c921d84092ffc38fc0fa45 Mon Sep 17 00:00:00 2001 From: turekt <32360115+turekt@users.noreply.github.com> Date: Fri, 22 Apr 2022 15:12:20 +0000 Subject: [PATCH] Support for concat set intervals (#155) Fixes https://github.com/google/nftables/issues/154 Added support for intervals in concat sets Added missing constants, Concatenation flag and KeyEnd field to Set type with marshaling support Added ConcatSetTypeElements function to derive base types from concatenated types Changed nftDatatypes list to map Added tests --- nftables_test.go | 70 +++++++++++++++++++ set.go | 171 +++++++++++++++++++++++++++++++++-------------- set_test.go | 44 ++++++++++++ 3 files changed, 233 insertions(+), 52 deletions(-) diff --git a/nftables_test.go b/nftables_test.go index 18d98c4..ec87ca4 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -2571,6 +2571,76 @@ func TestFlushNamedSet(t *testing.T) { } } +func TestSetElementsInterval(t *testing.T) { + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv6, + Name: "filter", + }) + portSet := &nftables.Set{ + Table: filter, + Name: "ports", + KeyType: nftables.MustConcatSetType(nftables.TypeIP6Addr, nftables.TypeInetService, nftables.TypeIP6Addr), + Interval: true, + Concatenation: true, + } + if err := c.AddSet(portSet, nil); err != nil { + t.Errorf("c.AddSet(portSet) failed: %v", err) + } + + // { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } + keyBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 195, 80, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203} + // { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } + keyEndBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 234, 96, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203} + // elements = { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000-60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } + if err := c.SetAddElements(portSet, []nftables.SetElement{ + {Key: keyBytes, KeyEnd: keyEndBytes}, + }); err != nil { + t.Errorf("c.SetVal(portSet) failed: %v", err) + } + + if err := c.Flush(); err != nil { + t.Errorf("c.Flush() failed: %v", err) + } + + sets, err := c.GetSets(filter) + if err != nil { + t.Errorf("c.GetSets() failed: %v", err) + } + if len(sets) != 1 { + t.Fatalf("len(sets) = %d, want 1", len(sets)) + } + + elements, err := c.GetSetElements(sets[0]) + if err != nil { + t.Errorf("c.GetSetElements(portSet) failed: %v", err) + } + if len(elements) != 1 { + t.Fatalf("len(portSetElements) = %d, want 1", len(sets)) + } + + element := elements[0] + if len(element.Key) == 0 { + t.Fatal("len(portSetElements.Key) = 0") + } + if len(element.KeyEnd) == 0 { + t.Fatal("len(portSetElements.KeyEnd) = 0") + } + if !bytes.Equal(element.Key, keyBytes) { + t.Fatal("element.Key != keyBytes") + } + if !bytes.Equal(element.KeyEnd, keyEndBytes) { + t.Fatal("element.KeyEnd != keyEndBytes") + } +} + func TestFlushChain(t *testing.T) { // Create a new network namespace to test these operations, // and tear down the namespace at test completion. diff --git a/set.go b/set.go index ca956c8..58ac250 100644 --- a/set.go +++ b/set.go @@ -33,6 +33,13 @@ import ( const ( SetConcatTypeBits = 6 SetConcatTypeMask = (1 << SetConcatTypeBits) - 1 + // below consts added because not found in go unix package + // https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n306 + NFT_SET_CONCAT = 0x80 + // https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n330 + NFTA_SET_DESC_CONCAT = 2 + // https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n428 + NFTA_SET_ELEM_KEY_END = 10 ) var allocSetID uint32 @@ -108,53 +115,53 @@ var ( TypeTimeDay = SetDatatype{Name: "day", Bytes: 1, nftMagic: 45} TypeCGroupV2 = SetDatatype{Name: "cgroupsv2", Bytes: 8, nftMagic: 46} - nftDatatypes = []SetDatatype{ - TypeVerdict, - TypeNFProto, - TypeBitmask, - TypeInteger, - TypeString, - TypeLLAddr, - TypeIPAddr, - TypeIP6Addr, - TypeEtherAddr, - TypeEtherType, - TypeARPOp, - TypeInetProto, - TypeInetService, - TypeICMPType, - TypeTCPFlag, - TypeDCCPPktType, - TypeMHType, - TypeTime, - TypeMark, - TypeIFIndex, - TypeARPHRD, - TypeRealm, - TypeClassID, - TypeUID, - TypeGID, - TypeCTState, - TypeCTDir, - TypeCTStatus, - TypeICMP6Type, - TypeCTLabel, - TypePktType, - TypeICMPCode, - TypeICMPV6Code, - TypeICMPXCode, - TypeDevGroup, - TypeDSCP, - TypeECN, - TypeFIBAddr, - TypeBoolean, - TypeCTEventBit, - TypeIFName, - TypeIGMPType, - TypeTimeDate, - TypeTimeHour, - TypeTimeDay, - TypeCGroupV2, + nftDatatypes = map[string]SetDatatype{ + TypeVerdict.Name: TypeVerdict, + TypeNFProto.Name: TypeNFProto, + TypeBitmask.Name: TypeBitmask, + TypeInteger.Name: TypeInteger, + TypeString.Name: TypeString, + TypeLLAddr.Name: TypeLLAddr, + TypeIPAddr.Name: TypeIPAddr, + TypeIP6Addr.Name: TypeIP6Addr, + TypeEtherAddr.Name: TypeEtherAddr, + TypeEtherType.Name: TypeEtherType, + TypeARPOp.Name: TypeARPOp, + TypeInetProto.Name: TypeInetProto, + TypeInetService.Name: TypeInetService, + TypeICMPType.Name: TypeICMPType, + TypeTCPFlag.Name: TypeTCPFlag, + TypeDCCPPktType.Name: TypeDCCPPktType, + TypeMHType.Name: TypeMHType, + TypeTime.Name: TypeTime, + TypeMark.Name: TypeMark, + TypeIFIndex.Name: TypeIFIndex, + TypeARPHRD.Name: TypeARPHRD, + TypeRealm.Name: TypeRealm, + TypeClassID.Name: TypeClassID, + TypeUID.Name: TypeUID, + TypeGID.Name: TypeGID, + TypeCTState.Name: TypeCTState, + TypeCTDir.Name: TypeCTDir, + TypeCTStatus.Name: TypeCTStatus, + TypeICMP6Type.Name: TypeICMP6Type, + TypeCTLabel.Name: TypeCTLabel, + TypePktType.Name: TypePktType, + TypeICMPCode.Name: TypeICMPCode, + TypeICMPV6Code.Name: TypeICMPV6Code, + TypeICMPXCode.Name: TypeICMPXCode, + TypeDevGroup.Name: TypeDevGroup, + TypeDSCP.Name: TypeDSCP, + TypeECN.Name: TypeECN, + TypeFIBAddr.Name: TypeFIBAddr, + TypeBoolean.Name: TypeBoolean, + TypeCTEventBit.Name: TypeCTEventBit, + TypeIFName.Name: TypeIFName, + TypeIGMPType.Name: TypeIGMPType, + TypeTimeDate.Name: TypeTimeDate, + TypeTimeHour.Name: TypeTimeHour, + TypeTimeDay.Name: TypeTimeDay, + TypeCGroupV2.Name: TypeCGroupV2, } // ctLabelBitSize is defined in https://git.netfilter.org/nftables/tree/src/ct.c. @@ -206,6 +213,17 @@ func ConcatSetType(types ...SetDatatype) (SetDatatype, error) { return SetDatatype{Name: strings.Join(names, " . "), Bytes: bytes, nftMagic: magic}, nil } +// ConcatSetTypeElements uses the ConcatSetType name to calculate and return +// a list of base types which were used to construct the concatenated type +func ConcatSetTypeElements(t SetDatatype) []SetDatatype { + names := strings.Split(t.Name, " . ") + types := make([]SetDatatype, len(names)) + for i, n := range names { + types[i] = nftDatatypes[n] + } + return types +} + // Set represents an nftables set. Anonymous sets are only valid within the // context of a single batch. type Set struct { @@ -217,15 +235,21 @@ type Set struct { Interval bool IsMap bool HasTimeout bool - Timeout time.Duration - KeyType SetDatatype - DataType SetDatatype + // Indicates that the set contains a concatenation + // https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n306 + Concatenation bool + Timeout time.Duration + KeyType SetDatatype + DataType SetDatatype } // SetElement represents a data point within a set. type SetElement struct { - Key []byte - Val []byte + Key []byte + Val []byte + // Field used for definition of ending interval value in concatenated types + // https://git.netfilter.org/libnftnl/tree/include/set_elem.h?id=e2514c0eff4da7e8e0aabd410f7b7d0b7564c880#n11 + KeyEnd []byte IntervalEnd bool // To support vmap, a caller must be able to pass Verdict type of data. // If IsMap is true and VerdictData is not nil, then Val of SetElement will be ignored @@ -250,6 +274,11 @@ func (s *SetElement) decode() func(b []byte) error { if err != nil { return err } + case NFTA_SET_ELEM_KEY_END: + s.KeyEnd, err = decodeElement(ad.Bytes()) + if err != nil { + return err + } case unix.NFTA_SET_ELEM_DATA: s.Val, err = decodeElement(ad.Bytes()) if err != nil { @@ -325,7 +354,15 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e if err != nil { return nil, fmt.Errorf("marshal key %d: %v", i, err) } + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) + if len(v.KeyEnd) > 0 { + encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}}) + if err != nil { + return nil, fmt.Errorf("marshal key end %d: %v", i, err) + } + item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd}) + } if s.HasTimeout && v.Timeout != 0 { // Set has Timeout flag set, which means an individual element can specify its own timeout. item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(v.Timeout.Milliseconds()))}) @@ -431,6 +468,9 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { if s.HasTimeout { flags |= unix.NFT_SET_TIMEOUT } + if s.Concatenation { + flags |= NFT_SET_CONCAT + } tableInfo := []netlink.Attribute{ {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, @@ -465,7 +505,33 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { } tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements}) } + if s.Concatenation { + // Length of concatenated types is a must, otherwise segfaults when executing nft list ruleset + var concatDefinition []byte + elements := ConcatSetTypeElements(s.KeyType) + for i, v := range elements { + // Marshal base type size value + valData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)}, + }) + if err != nil { + return fmt.Errorf("fail to marshal element key size %d: %v", i, err) + } + // Marshal base type size description + descSize, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_DESC_SIZE, Data: valData}, + }) + concatDefinition = append(concatDefinition, descSize...) + } + // Marshal all base type descriptions into concatenation size description + concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}}) + if err != nil { + return fmt.Errorf("fail to marshal concat definition %v", err) + } + // Marshal concat size description as set description + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: concatBytes}) + } if s.Anonymous || s.Constant || s.Interval { tableInfo = append(tableInfo, // Semantically useless - kept for binary compatability with nft @@ -585,6 +651,7 @@ func setsFromMsg(msg netlink.Message) (*Set, error) { set.Interval = (flags & unix.NFT_SET_INTERVAL) != 0 set.IsMap = (flags & unix.NFT_SET_MAP) != 0 set.HasTimeout = (flags & unix.NFT_SET_TIMEOUT) != 0 + set.Concatenation = (flags & NFT_SET_CONCAT) != 0 case unix.NFTA_SET_KEY_TYPE: nftMagic := ad.Uint32() if invalidMagic, ok := validateKeyType(nftMagic); !ok { diff --git a/set_test.go b/set_test.go index 9e07834..db9092a 100644 --- a/set_test.go +++ b/set_test.go @@ -137,3 +137,47 @@ func TestConcatSetType(t *testing.T) { }) } } + +func TestConcatSetTypeElements(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + types []SetDatatype + }{ + { + name: "concat ip6 . inet_service", + types: []SetDatatype{TypeIP6Addr, TypeInetService}, + }, + { + name: "concat ip . inet_service . ip6", + types: []SetDatatype{TypeIPAddr, TypeInetService, TypeIP6Addr}, + }, + { + name: "concat inet_proto . inet_service", + types: []SetDatatype{TypeInetProto, TypeInetService}, + }, + { + name: "concat ip . ip . ip . ip", + types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + concat, err := ConcatSetType(tt.types...) + if err != nil { + return + } + elements := ConcatSetTypeElements(concat) + if got, want := len(elements), len(tt.types); got != want { + t.Errorf("invalid number of elements: expected %d, got %d", got, want) + } + for i, v := range tt.types { + if got, want := elements[i].GetNFTMagic(), v.GetNFTMagic(); got != want { + t.Errorf("invalid element on position %d: expected %d, got %d", i, got, want) + } + } + }) + } +}