// Copyright 2018 Google LLC. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package nftables

import (
	"encoding/binary"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/google/nftables/expr"

	"github.com/google/nftables/binaryutil"
	"github.com/mdlayher/netlink"
	"golang.org/x/sys/unix"
)

// SetConcatTypeBits is the number of bits each component type's magic
// occupies inside the magic of a concatenated type, and SetConcatTypeMask
// selects one such component. Originally defined in
// https://git.netfilter.org/iptables/tree/iptables/nft.c?id=26753888720d8e7eb422ae4311348347f5a05cb4#n1002
const (
	SetConcatTypeBits = 6
	SetConcatTypeMask = (1 << SetConcatTypeBits) - 1
	// The constants below are declared here because they were not found in
	// the golang.org/x/sys/unix package.

	// NFT_SET_CONCAT is the set flag AddSet sets for Set.Concatenation.
	// https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n306
	NFT_SET_CONCAT = 0x80
	// NFTA_SET_DESC_CONCAT is the set-description attribute that carries the
	// length of each field of a concatenated key.
	// https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n330
	NFTA_SET_DESC_CONCAT = 2
	// NFTA_SET_ELEM_KEY_END is the element attribute carrying
	// SetElement.KeyEnd, the end of an interval in a concatenated set.
	// https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n428
	NFTA_SET_ELEM_KEY_END = 10
)

// allocSetID is the last ID handed out by AddSet to a Set whose ID was zero.
var allocSetID uint32

// SetDatatype represents a datatype declared by nft.
type SetDatatype struct {
	// Name is nft's name for the type. Concatenated types join their
	// component names with " . " (see ConcatSetType).
	Name string
	// Bytes is the size of a value of this type. AddSet sends it as the
	// set's key or data length.
	Bytes uint32

	// nftMagic represents the magic value that nft uses for
	// certain types (e.g. IP addresses). We populate SET_KEY_TYPE
	// identically, so `nft list ...` commands produce correct output.
	nftMagic uint32
}

// GetNFTMagic returns the nft magic value of the datatype.
func (s *SetDatatype) GetNFTMagic() uint32 {
	return s.nftMagic
}

// SetNFTMagic sets the nft magic value of the datatype, e.g. to describe a
// custom type built from user-supplied parameters.
func (s *SetDatatype) SetNFTMagic(nftMagic uint32) {
	s.nftMagic = nftMagic
}

// NFT datatypes. See: https://git.netfilter.org/nftables/tree/include/datatype.h
var (
	TypeInvalid     = SetDatatype{Name: "invalid", nftMagic: 0}
	TypeVerdict     = SetDatatype{Name: "verdict", Bytes: 0, nftMagic: 1}
	TypeNFProto     = SetDatatype{Name: "nf_proto", Bytes: 1, nftMagic: 2}
	TypeBitmask     = SetDatatype{Name: "bitmask", Bytes: 0, nftMagic: 3}
	TypeInteger     = SetDatatype{Name: "integer", Bytes: 4, nftMagic: 4}
	TypeString      = SetDatatype{Name: "string", Bytes: 0, nftMagic: 5}
	TypeLLAddr      = SetDatatype{Name: "ll_addr", Bytes: 0, nftMagic: 6}
	TypeIPAddr      = SetDatatype{Name: "ipv4_addr", Bytes: 4, nftMagic: 7}
	TypeIP6Addr     = SetDatatype{Name: "ipv6_addr", Bytes: 16, nftMagic: 8}
	TypeEtherAddr   = SetDatatype{Name: "ether_addr", Bytes: 6, nftMagic: 9}
	TypeEtherType   = SetDatatype{Name: "ether_type", Bytes: 2, nftMagic: 10}
	TypeARPOp       = SetDatatype{Name: "arp_op", Bytes: 2, nftMagic: 11}
	TypeInetProto   = SetDatatype{Name: "inet_proto", Bytes: 1, nftMagic: 12}
	TypeInetService = SetDatatype{Name: "inet_service", Bytes: 2, nftMagic: 13}
	TypeICMPType    = SetDatatype{Name: "icmp_type", Bytes: 1, nftMagic: 14}
	TypeTCPFlag     = SetDatatype{Name: "tcp_flag", Bytes: 1, nftMagic: 15}
	TypeDCCPPktType = SetDatatype{Name: "dccp_pkttype", Bytes: 1, nftMagic: 16}
	TypeMHType      = SetDatatype{Name: "mh_type", Bytes: 1, nftMagic: 17}
	TypeTime        = SetDatatype{Name: "time", Bytes: 8, nftMagic: 18}
	TypeMark        = SetDatatype{Name: "mark", Bytes: 4, nftMagic: 19}
	TypeIFIndex     = SetDatatype{Name: "iface_index", Bytes: 4, nftMagic: 20}
	TypeARPHRD      = SetDatatype{Name: "iface_type", Bytes: 2, nftMagic: 21}
	TypeRealm       = SetDatatype{Name: "realm", Bytes: 4, nftMagic: 22}
	TypeClassID     = SetDatatype{Name: "classid", Bytes: 4, nftMagic: 23}
	TypeUID         = SetDatatype{Name: "uid", Bytes: sizeOfUIDT, nftMagic: 24}
	TypeGID         = SetDatatype{Name: "gid", Bytes: sizeOfGIDT, nftMagic: 25}
	TypeCTState     = SetDatatype{Name: "ct_state", Bytes: 4, nftMagic: 26}
	TypeCTDir       = SetDatatype{Name: "ct_dir", Bytes: 1, nftMagic: 27}
	TypeCTStatus    = SetDatatype{Name: "ct_status", Bytes: 4, nftMagic: 28}
	TypeICMP6Type   = SetDatatype{Name: "icmpv6_type", Bytes: 1, nftMagic: 29}
	TypeCTLabel     = SetDatatype{Name: "ct_label", Bytes: ctLabelBitSize / 8, nftMagic: 30}
	TypePktType     = SetDatatype{Name: "pkt_type", Bytes: 1, nftMagic: 31}
	TypeICMPCode    = SetDatatype{Name: "icmp_code", Bytes: 1, nftMagic: 32}
	TypeICMPV6Code  = SetDatatype{Name: "icmpv6_code", Bytes: 1, nftMagic: 33}
	TypeICMPXCode   = SetDatatype{Name: "icmpx_code", Bytes: 1, nftMagic: 34}
	TypeDevGroup    = SetDatatype{Name: "devgroup", Bytes: 4, nftMagic: 35}
	TypeDSCP        = SetDatatype{Name: "dscp", Bytes: 1, nftMagic: 36}
	TypeECN         = SetDatatype{Name: "ecn", Bytes: 1, nftMagic: 37}
	TypeFIBAddr     = SetDatatype{Name: "fib_addrtype", Bytes: 4, nftMagic: 38}
	TypeBoolean     = SetDatatype{Name: "boolean", Bytes: 1, nftMagic: 39}
	TypeCTEventBit  = SetDatatype{Name: "ct_event", Bytes: 4, nftMagic: 40}
	TypeIFName      = SetDatatype{Name: "ifname", Bytes: ifNameSize, nftMagic: 41}
	TypeIGMPType    = SetDatatype{Name: "igmp_type", Bytes: 1, nftMagic: 42}
	TypeTimeDate    = SetDatatype{Name: "time", Bytes: 8, nftMagic: 43}
	TypeTimeHour    = SetDatatype{Name: "hour", Bytes: 8, nftMagic: 44}
	TypeTimeDay     = SetDatatype{Name: "day", Bytes: 1, nftMagic: 45}
	TypeCGroupV2    = SetDatatype{Name: "cgroupsv2", Bytes: 8, nftMagic: 46}

	// nftDatatypes indexes the known datatypes by Name. ConcatSetTypeElements
	// looks component names up here. setsFromMsg and validateKeyType iterate
	// its values to resolve magic numbers.
	//
	// NOTE(review): TypeTime and TypeTimeDate share the Name "time", so only
	// one of them survives as a map entry (presumably the later one,
	// TypeTimeDate). Magic 18 (TypeTime) then cannot be resolved through this
	// map. Confirm, and consider distinct keys.
	nftDatatypes = map[string]SetDatatype{
		TypeVerdict.Name:     TypeVerdict,
		TypeNFProto.Name:     TypeNFProto,
		TypeBitmask.Name:     TypeBitmask,
		TypeInteger.Name:     TypeInteger,
		TypeString.Name:      TypeString,
		TypeLLAddr.Name:      TypeLLAddr,
		TypeIPAddr.Name:      TypeIPAddr,
		TypeIP6Addr.Name:     TypeIP6Addr,
		TypeEtherAddr.Name:   TypeEtherAddr,
		TypeEtherType.Name:   TypeEtherType,
		TypeARPOp.Name:       TypeARPOp,
		TypeInetProto.Name:   TypeInetProto,
		TypeInetService.Name: TypeInetService,
		TypeICMPType.Name:    TypeICMPType,
		TypeTCPFlag.Name:     TypeTCPFlag,
		TypeDCCPPktType.Name: TypeDCCPPktType,
		TypeMHType.Name:      TypeMHType,
		TypeTime.Name:        TypeTime,
		TypeMark.Name:        TypeMark,
		TypeIFIndex.Name:     TypeIFIndex,
		TypeARPHRD.Name:      TypeARPHRD,
		TypeRealm.Name:       TypeRealm,
		TypeClassID.Name:     TypeClassID,
		TypeUID.Name:         TypeUID,
		TypeGID.Name:         TypeGID,
		TypeCTState.Name:     TypeCTState,
		TypeCTDir.Name:       TypeCTDir,
		TypeCTStatus.Name:    TypeCTStatus,
		TypeICMP6Type.Name:   TypeICMP6Type,
		TypeCTLabel.Name:     TypeCTLabel,
		TypePktType.Name:     TypePktType,
		TypeICMPCode.Name:    TypeICMPCode,
		TypeICMPV6Code.Name:  TypeICMPV6Code,
		TypeICMPXCode.Name:   TypeICMPXCode,
		TypeDevGroup.Name:    TypeDevGroup,
		TypeDSCP.Name:        TypeDSCP,
		TypeECN.Name:         TypeECN,
		TypeFIBAddr.Name:     TypeFIBAddr,
		TypeBoolean.Name:     TypeBoolean,
		TypeCTEventBit.Name:  TypeCTEventBit,
		TypeIFName.Name:      TypeIFName,
		TypeIGMPType.Name:    TypeIGMPType,
		TypeTimeDate.Name:    TypeTimeDate,
		TypeTimeHour.Name:    TypeTimeHour,
		TypeTimeDay.Name:     TypeTimeDay,
		TypeCGroupV2.Name:    TypeCGroupV2,
	}

	// ctLabelBitSize is defined in https://git.netfilter.org/nftables/tree/src/ct.c.
	ctLabelBitSize uint32 = 128

	// ifNameSize is called IFNAMSIZ in linux/if.h.
	ifNameSize uint32 = 16

	// Sizes of uid_t and gid_t, see bits/typesizes.h.
	sizeOfUIDT uint32 = 4
	sizeOfGIDT uint32 = 4
)

// ErrTooManyTypes is the error returned by ConcatSetType if the combined
// nftMagic would overflow a uint32.
var ErrTooManyTypes = errors.New("too many types to concat")

// MustConcatSetType does the same as ConcatSetType, but panics instead of
// returning an error. It simplifies safe initialization of global variables.
func MustConcatSetType(types ...SetDatatype) SetDatatype {
	t, err := ConcatSetType(types...)
	if err != nil {
		panic(err)
	}
	return t
}

// ConcatSetType constructs a new SetDatatype which consists of a concatenation
// of the passed types. The name joins the component names with " . ". The
// size is the sum of the component sizes, each rounded up to a multiple of 4
// bytes. The magic packs each component's magic into SetConcatTypeBits bits.
// It returns ErrTooManyTypes if nftMagic would overflow (more than 5 types).
func ConcatSetType(types ...SetDatatype) (SetDatatype, error) { if len(types) > 32/SetConcatTypeBits { return SetDatatype{}, ErrTooManyTypes } var magic, bytes uint32 names := make([]string, len(types)) for i, t := range types { bytes += t.Bytes // concatenated types pad the length to multiples of the register size (4 bytes) // see https://git.netfilter.org/nftables/tree/src/datatype.c?id=488356b895024d0944b20feb1f930558726e0877#n1162 if t.Bytes%4 != 0 { bytes += 4 - (t.Bytes % 4) } names[i] = t.Name magic <<= SetConcatTypeBits magic |= t.nftMagic & SetConcatTypeMask } return SetDatatype{Name: strings.Join(names, " . "), Bytes: bytes, nftMagic: magic}, nil } // ConcatSetTypeElements uses the ConcatSetType name to calculate and return // a list of base types which were used to construct the concatenated type func ConcatSetTypeElements(t SetDatatype) []SetDatatype { names := strings.Split(t.Name, " . ") types := make([]SetDatatype, len(names)) for i, n := range names { types[i] = nftDatatypes[n] } return types } // Set represents an nftables set. Anonymous sets are only valid within the // context of a single batch. type Set struct { Table *Table ID uint32 Name string Anonymous bool Constant bool Interval bool IsMap bool HasTimeout bool // Indicates that the set contains a concatenation // https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n306 Concatenation bool Timeout time.Duration KeyType SetDatatype DataType SetDatatype } // SetElement represents a data point within a set. type SetElement struct { Key []byte Val []byte // Field used for definition of ending interval value in concatenated types // https://git.netfilter.org/libnftnl/tree/include/set_elem.h?id=e2514c0eff4da7e8e0aabd410f7b7d0b7564c880#n11 KeyEnd []byte IntervalEnd bool // To support vmap, a caller must be able to pass Verdict type of data. 
// If IsMap is true and VerdictData is not nil, then Val of SetElement will be ignored // and VerdictData will be wrapped into Attribute data. VerdictData *expr.Verdict // To support aging of set elements Timeout time.Duration } func (s *SetElement) decode() func(b []byte) error { return func(b []byte) error { ad, err := netlink.NewAttributeDecoder(b) if err != nil { return fmt.Errorf("failed to create nested attribute decoder: %v", err) } ad.ByteOrder = binary.BigEndian for ad.Next() { switch ad.Type() { case unix.NFTA_SET_ELEM_KEY: s.Key, err = decodeElement(ad.Bytes()) if err != nil { return err } case NFTA_SET_ELEM_KEY_END: s.KeyEnd, err = decodeElement(ad.Bytes()) if err != nil { return err } case unix.NFTA_SET_ELEM_DATA: s.Val, err = decodeElement(ad.Bytes()) if err != nil { return err } case unix.NFTA_SET_ELEM_FLAGS: flags := ad.Uint32() s.IntervalEnd = (flags & unix.NFT_SET_ELEM_INTERVAL_END) != 0 case unix.NFTA_SET_ELEM_TIMEOUT: s.Timeout = time.Duration(time.Millisecond * time.Duration(ad.Uint64())) } } return ad.Err() } } func decodeElement(d []byte) ([]byte, error) { ad, err := netlink.NewAttributeDecoder(d) if err != nil { return nil, fmt.Errorf("failed to create nested attribute decoder: %v", err) } ad.ByteOrder = binary.BigEndian var b []byte for ad.Next() { switch ad.Type() { case unix.NFTA_SET_ELEM_KEY: fallthrough case unix.NFTA_SET_ELEM_DATA: b = ad.Bytes() } } if err := ad.Err(); err != nil { return nil, err } return b, nil } // SetAddElements applies data points to an nftables set. 
func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { cc.mu.Lock() defer cc.mu.Unlock() if s.Anonymous { return errors.New("anonymous sets cannot be updated") } elements, err := s.makeElemList(vals, s.ID) if err != nil { return err } cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, }, Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), }) return nil } func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, error) { var elements []netlink.Attribute for i, v := range vals { item := make([]netlink.Attribute, 0) var flags uint32 if v.IntervalEnd { flags |= unix.NFT_SET_ELEM_INTERVAL_END item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_FLAGS | unix.NLA_F_NESTED, Data: binaryutil.BigEndian.PutUint32(flags)}) } encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}}) if err != nil { return nil, fmt.Errorf("marshal key %d: %v", i, err) } item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) if len(v.KeyEnd) > 0 { encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}}) if err != nil { return nil, fmt.Errorf("marshal key end %d: %v", i, err) } item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd}) } if s.HasTimeout && v.Timeout != 0 { // Set has Timeout flag set, which means an individual element can specify its own timeout. item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(v.Timeout.Milliseconds()))}) } // The following switch statement deal with 3 different types of elements. // 1. v is an element of vmap // 2. 
v is an element of a regular map // 3. v is an element of a regular set (default) switch { case v.VerdictData != nil: // Since VerdictData is not nil, v is vmap element, need to add to the attributes encodedVal := []byte{} encodedKind, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))}, }) if err != nil { return nil, fmt.Errorf("marshal item %d: %v", i, err) } encodedVal = append(encodedVal, encodedKind...) if len(v.VerdictData.Chain) != 0 { encodedChain, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")}, }) if err != nil { return nil, fmt.Errorf("marshal item %d: %v", i, err) } encodedVal = append(encodedVal, encodedChain...) } encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}}) if err != nil { return nil, fmt.Errorf("marshal item %d: %v", i, err) } item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict}) case len(v.Val) > 0: // Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}}) if err != nil { return nil, fmt.Errorf("marshal item %d: %v", i, err) } item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}) default: // If niether of previous cases matche, it means 'e' is an element of a regular Set, no need to add to the attributes } encodedItem, err := netlink.MarshalAttributes(item) if err != nil { return nil, fmt.Errorf("marshal item %d: %v", i, err) } elements = append(elements, netlink.Attribute{Type: uint16(i+1) | unix.NLA_F_NESTED, Data: encodedItem}) } encodedElem, err := netlink.MarshalAttributes(elements) if err != nil { return nil, 
fmt.Errorf("marshal elements: %v", err) } return []netlink.Attribute{ {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, {Type: unix.NFTA_LOOKUP_SET_ID, Data: binaryutil.BigEndian.PutUint32(id)}, {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem}, }, nil } // AddSet adds the specified Set. func (cc *Conn) AddSet(s *Set, vals []SetElement) error { cc.mu.Lock() defer cc.mu.Unlock() // Based on nft implementation & linux source. // Link: https://github.com/torvalds/linux/blob/49a57857aeea06ca831043acbb0fa5e0f50602fd/net/netfilter/nf_tables_api.c#L3395 // Another reference: https://git.netfilter.org/nftables/tree/src if s.Anonymous && !s.Constant { return errors.New("anonymous structs must be constant") } if s.ID == 0 { allocSetID++ s.ID = allocSetID if s.Anonymous { s.Name = "__set%d" if s.IsMap { s.Name = "__map%d" } } } var flags uint32 if s.Anonymous { flags |= unix.NFT_SET_ANONYMOUS } if s.Constant { flags |= unix.NFT_SET_CONSTANT } if s.Interval { flags |= unix.NFT_SET_INTERVAL } if s.IsMap { flags |= unix.NFT_SET_MAP } if s.HasTimeout { flags |= unix.NFT_SET_TIMEOUT } if s.Concatenation { flags |= NFT_SET_CONCAT } tableInfo := []netlink.Attribute{ {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, {Type: unix.NFTA_SET_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}, {Type: unix.NFTA_SET_KEY_TYPE, Data: binaryutil.BigEndian.PutUint32(s.KeyType.nftMagic)}, {Type: unix.NFTA_SET_KEY_LEN, Data: binaryutil.BigEndian.PutUint32(s.KeyType.Bytes)}, {Type: unix.NFTA_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)}, } if s.IsMap { // Check if it is vmap case if s.DataType.nftMagic == 1 { // For Verdict data type, the expected magic is 0xfffff0 tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(unix.NFT_DATA_VERDICT))}, 
netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(s.DataType.Bytes)}) } else { tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(s.DataType.nftMagic)}, netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(s.DataType.Bytes)}) } } if s.HasTimeout && s.Timeout != 0 { // If Set's global timeout is specified, add it to set's attributes tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(s.Timeout.Milliseconds()))}) } if s.Constant { // nft cli tool adds the number of elements to set/map's descriptor // It make sense to do only if a set or map are constant, otherwise skip NFTA_SET_DESC attribute numberOfElements, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))}, }) if err != nil { return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err) } tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements}) } if s.Concatenation { // Length of concatenated types is a must, otherwise segfaults when executing nft list ruleset var concatDefinition []byte elements := ConcatSetTypeElements(s.KeyType) for i, v := range elements { // Marshal base type size value valData, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)}, }) if err != nil { return fmt.Errorf("fail to marshal element key size %d: %v", i, err) } // Marshal base type size description descSize, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_SET_DESC_SIZE, Data: valData}, }) concatDefinition = append(concatDefinition, descSize...) 
} // Marshal all base type descriptions into concatenation size description concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}}) if err != nil { return fmt.Errorf("fail to marshal concat definition %v", err) } // Marshal concat size description as set description tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: concatBytes}) } if s.Anonymous || s.Constant || s.Interval { tableInfo = append(tableInfo, // Semantically useless - kept for binary compatability with nft netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x02\x00\x00\x00")}) } cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, }, Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(tableInfo)...), }) // Set the values of the set if initial values were provided. if len(vals) > 0 { hdrType := unix.NFT_MSG_NEWSETELEM elements, err := s.makeElemList(vals, s.ID) if err != nil { return err } cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, }, Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), }) } return nil } // DelSet deletes a specific set, along with all elements it contains. 
func (cc *Conn) DelSet(s *Set) { cc.mu.Lock() defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET), Flags: netlink.Request | netlink.Acknowledge, }, Data: append(extraHeader(uint8(s.Table.Family), 0), data...), }) } // SetDeleteElements deletes data points from an nftables set. func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { cc.mu.Lock() defer cc.mu.Unlock() if s.Anonymous { return errors.New("anonymous sets cannot be updated") } elements, err := s.makeElemList(vals, s.ID) if err != nil { return err } cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, }, Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), }) return nil } // FlushSet deletes all data points from an nftables set. 
func (cc *Conn) FlushSet(s *Set) { cc.mu.Lock() defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge, }, Data: append(extraHeader(uint8(s.Table.Family), 0), data...), }) } var setHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET) func setsFromMsg(msg netlink.Message) (*Set, error) { if got, want := msg.Header.Type, setHeaderType; got != want { return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) } ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { return nil, err } ad.ByteOrder = binary.BigEndian var set Set for ad.Next() { switch ad.Type() { case unix.NFTA_SET_NAME: set.Name = ad.String() case unix.NFTA_SET_ID: set.ID = binary.BigEndian.Uint32(ad.Bytes()) case unix.NFTA_SET_TIMEOUT: set.Timeout = time.Duration(time.Millisecond * time.Duration(binary.BigEndian.Uint64(ad.Bytes()))) set.HasTimeout = true case unix.NFTA_SET_FLAGS: flags := ad.Uint32() set.Constant = (flags & unix.NFT_SET_CONSTANT) != 0 set.Anonymous = (flags & unix.NFT_SET_ANONYMOUS) != 0 set.Interval = (flags & unix.NFT_SET_INTERVAL) != 0 set.IsMap = (flags & unix.NFT_SET_MAP) != 0 set.HasTimeout = (flags & unix.NFT_SET_TIMEOUT) != 0 set.Concatenation = (flags & NFT_SET_CONCAT) != 0 case unix.NFTA_SET_KEY_TYPE: nftMagic := ad.Uint32() if invalidMagic, ok := validateKeyType(nftMagic); !ok { return nil, fmt.Errorf("could not determine key type %+v", invalidMagic) } set.KeyType.nftMagic = nftMagic for _, dt := range nftDatatypes { // If this is a non-concatenated type, we can assign the descriptor. 
if nftMagic == dt.nftMagic { set.KeyType = dt break } } case unix.NFTA_SET_DATA_TYPE: nftMagic := ad.Uint32() // Special case for the data type verdict, in the message it is stored as 0xffffff00 but it is defined as 1 if nftMagic == 0xffffff00 { set.KeyType = TypeVerdict break } for _, dt := range nftDatatypes { if nftMagic == dt.nftMagic { set.DataType = dt break } } if set.DataType.nftMagic == 0 { return nil, fmt.Errorf("could not determine data type %x", nftMagic) } } } return &set, nil } func validateKeyType(bits uint32) ([]uint32, bool) { var unpackTypes []uint32 var invalidTypes []uint32 found := false valid := true for bits != 0 { unpackTypes = append(unpackTypes, bits&SetConcatTypeMask) bits = bits >> SetConcatTypeBits } for _, t := range unpackTypes { for _, dt := range nftDatatypes { if t == dt.nftMagic { found = true } } if !found { invalidTypes = append(invalidTypes, t) valid = false } found = false } return invalidTypes, valid } var elemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM) func elementsFromMsg(msg netlink.Message) ([]SetElement, error) { if got, want := msg.Header.Type, elemHeaderType; got != want { return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) } ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { return nil, err } ad.ByteOrder = binary.BigEndian var elements []SetElement for ad.Next() { b := ad.Bytes() if ad.Type() == unix.NFTA_SET_ELEM_LIST_ELEMENTS { ad, err := netlink.NewAttributeDecoder(b) if err != nil { return nil, err } ad.ByteOrder = binary.BigEndian for ad.Next() { var elem SetElement switch ad.Type() { case unix.NFTA_LIST_ELEM: ad.Do(elem.decode()) } elements = append(elements, elem) } } } return elements, nil } // GetSets returns the sets in the specified table. 
func (cc *Conn) GetSets(t *Table) ([]*Set, error) { conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } defer func() { _ = closer() }() data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_SET_TABLE, Data: []byte(t.Name + "\x00")}, }) if err != nil { return nil, err } message := netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETSET), Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, }, Data: append(extraHeader(uint8(t.Family), 0), data...), } if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { return nil, fmt.Errorf("SendMessages: %v", err) } reply, err := conn.Receive() if err != nil { return nil, fmt.Errorf("Receive: %v", err) } var sets []*Set for _, msg := range reply { s, err := setsFromMsg(msg) if err != nil { return nil, err } s.Table = &Table{Name: t.Name, Use: t.Use, Flags: t.Flags, Family: t.Family} sets = append(sets, s) } return sets, nil } // GetSetByName returns the set in the specified table if matching name is found. 
func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) { conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } defer func() { _ = closer() }() data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_SET_TABLE, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(name + "\x00")}, }) if err != nil { return nil, err } message := netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETSET), Flags: netlink.Request | netlink.Acknowledge, }, Data: append(extraHeader(uint8(t.Family), 0), data...), } if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { return nil, fmt.Errorf("SendMessages: %w", err) } reply, err := conn.Receive() if err != nil { return nil, fmt.Errorf("Receive: %w", err) } if len(reply) != 1 { return nil, fmt.Errorf("Receive: expected to receive 1 message but got %d", len(reply)) } rs, err := setsFromMsg(reply[0]) if err != nil { return nil, err } rs.Table = &Table{Name: t.Name, Use: t.Use, Flags: t.Flags, Family: t.Family} return rs, nil } // GetSetElements returns the elements in the specified set. 
func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) { conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } defer func() { _ = closer() }() data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) if err != nil { return nil, err } message := netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, }, Data: append(extraHeader(uint8(s.Table.Family), 0), data...), } if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { return nil, fmt.Errorf("SendMessages: %v", err) } reply, err := conn.Receive() if err != nil { return nil, fmt.Errorf("Receive: %v", err) } var elems []SetElement for _, msg := range reply { s, err := elementsFromMsg(msg) if err != nil { return nil, err } elems = append(elems, s...) } return elems, nil }