diff --git a/nftables_test.go b/nftables_test.go index 11bd900..237c54e 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -2866,3 +2866,121 @@ func TestMap(t *testing.T) { } } } + +func TestVmap(t *testing.T) { + tests := []struct { + name string + chain *nftables.Chain + want [][]byte + set nftables.Set + element []nftables.SetElement + }{ + { + name: "map inet_service: drop verdict", + chain: &nftables.Chain{ + Name: "base-chain", + }, + want: [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip filter + []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip filter base-chain + []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0f\x00\x03\x00\x62\x61\x73\x65\x2d\x63\x68\x61\x69\x6e\x00\x00"), + // nft add map ip filter test-vmap { type inet_service: verdict\; elements={ 22: drop } \; } + []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0e\x00\x02\x00\x74\x65\x73\x74\x2d\x76\x6d\x61\x70\x00\x00\x00\x08\x00\x03\x00\x00\x00\x00\x08\x08\x00\x04\x00\x00\x00\x00\x0d\x08\x00\x05\x00\x00\x00\x00\x02\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\xff\xff\xff\x00\x08\x00\x07\x00\x00\x00\x00\x00"), + []byte("\x02\x00\x00\x00\x0e\x00\x02\x00\x74\x65\x73\x74\x2d\x76\x6d\x61\x70\x00\x00\x00\x08\x00\x04\x00\x00\x00\x00\x01\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x24\x00\x03\x80\x20\x00\x01\x80\x0c\x00\x01\x80\x06\x00\x01\x00\x00\x16\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00"), + // batch end + []byte("\x00\x00\x00\x0a"), + }, + set: nftables.Set{ + Name: "test-vmap", + ID: uint32(1), + KeyType: nftables.TypeInetService, + DataType: nftables.TypeVerdict, + IsMap: true, + }, + element: []nftables.SetElement{ + { + Key: binaryutil.BigEndian.PutUint16(uint16(22)), + VerdictData: &expr.Verdict{ + Kind: expr.VerdictDrop, + }, + }, + }, + }, { + name: "map inet_service: jump to chain verdict", + chain: &nftables.Chain{ + Name: "base-chain", + }, + want: [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip filter + []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip filter base-chain + []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0f\x00\x03\x00\x62\x61\x73\x65\x2d\x63\x68\x61\x69\x6e\x00\x00"), + // nft add map ip filter test-vmap { type inet_service: verdict\; elements={ 22: jump fake-chain } \; } + []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0e\x00\x02\x00\x74\x65\x73\x74\x2d\x76\x6d\x61\x70\x00\x00\x00\x08\x00\x03\x00\x00\x00\x00\x08\x08\x00\x04\x00\x00\x00\x00\x0d\x08\x00\x05\x00\x00\x00\x00\x02\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\xff\xff\xff\x00\x08\x00\x07\x00\x00\x00\x00\x00"), + []byte("\x02\x00\x00\x00\x0e\x00\x02\x00\x74\x65\x73\x74\x2d\x76\x6d\x61\x70\x00\x00\x00\x08\x00\x04\x00\x00\x00\x00\x01\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x34\x00\x03\x80\x30\x00\x01\x80\x0c\x00\x01\x80\x06\x00\x01\x00\x00\x16\x00\x00\x20\x00\x02\x80\x1c\x00\x02\x80\x08\x00\x01\x00\xff\xff\xff\xfd\x0f\x00\x02\x00\x66\x61\x6b\x65\x2d\x63\x68\x61\x69\x6e\x00\x00"), + // batch end + []byte("\x00\x00\x00\x0a"), + }, + set: nftables.Set{ + Name: "test-vmap", + ID: uint32(1), + KeyType: nftables.TypeInetService, + DataType: nftables.TypeVerdict, + IsMap: true, + }, + element: []nftables.SetElement{ + { + Key: binaryutil.BigEndian.PutUint16(uint16(22)), + VerdictData: &expr.Verdict{ + Kind: unix.NFT_JUMP, + Chain: "fake-chain", + }, + }, + }, + }, + } + + for _, tt := range tests { + c := &nftables.Conn{ + TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(tt.want[idx]) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + got := b + if !bytes.Equal(got, tt.want[idx]) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx]))) + } + } + return req, nil + }, + } + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + tt.chain.Table = filter + c.AddChain(tt.chain) + tt.set.Table = filter + c.AddSet(&tt.set, tt.element) + if err := c.Flush(); err != nil { + t.Fatalf("Test \"%s\" failed with error: %+v", tt.name, err) + } + } +} diff --git a/set.go b/set.go index 6a338f7..6de4085 100644 --- a/set.go +++ b/set.go @@ -19,6 +19,8 @@ import ( "errors" "fmt" + "github.com/google/nftables/expr" + "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" @@ -40,6 +42,7 @@ type SetDatatype struct { // NFT datatypes. See: https://git.netfilter.org/nftables/tree/src/datatype.c var ( TypeInvalid = SetDatatype{Name: "invalid", nftMagic: 1} + TypeVerdict = SetDatatype{Name: "verdict", Bytes: 0, nftMagic: 1} TypeIPAddr = SetDatatype{Name: "ipv4_addr", Bytes: 4, nftMagic: 7} TypeIP6Addr = SetDatatype{Name: "ipv6_addr", Bytes: 16, nftMagic: 8} TypeEtherAddr = SetDatatype{Name: "ether_addr", Bytes: 6, nftMagic: 9} @@ -47,6 +50,7 @@ var ( TypeInetService = SetDatatype{Name: "inet_service", Bytes: 2, nftMagic: 13} nftDatatypes = []SetDatatype{ + TypeVerdict, TypeIPAddr, TypeIP6Addr, TypeEtherAddr, @@ -75,6 +79,10 @@ type SetElement struct { Key []byte Val []byte IntervalEnd bool + // To support vmap, a caller must be able to pass Verdict type of data. + // If IsMap is true and VerdictData is not nil, then Val of SetElement will be ignored + // and VerdictData will be wrapped into Attribute data. + VerdictData *expr.Verdict } func (s *SetElement) decode() func(b []byte) error { @@ -164,13 +172,46 @@ func (s *Set) makeElemList(vals []SetElement) ([]netlink.Attribute, error) { return nil, fmt.Errorf("marshal key %d: %v", i, err) } item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) - - if len(v.Val) > 0 { + // The following switch statement deal with 3 different types of elements. + // 1. v is an element of vmap + // 2. v is an element of a regular map + // 3. v is an element of a regular set (default) + switch { + case v.VerdictData != nil: + // Since VerdictData is not nil, v is vmap element, need to add to the attributes + encodedVal := []byte{} + encodedKind, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))}, + }) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + encodedVal = append(encodedVal, encodedKind...) + if len(v.VerdictData.Chain) != 0 { + encodedChain, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")}, + }) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + encodedVal = append(encodedVal, encodedChain...) + } + encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}}) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict}) + case len(v.Val) > 0: + // Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}}) if err != nil { return nil, fmt.Errorf("marshal item %d: %v", i, err) } + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}) + default: + // If niether of previous cases matche, it means 'e' is an element of a regular Set, no need to add to the attributes } encodedItem, err := netlink.MarshalAttributes(item) @@ -237,8 +278,15 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { {Type: unix.NFTA_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)}, } if s.IsMap { - tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(s.DataType.nftMagic)}, - netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(s.DataType.Bytes)}) + // Check if it is vmap case + if s.DataType.nftMagic == 1 { + // For Verdict data type, the expected magic is 0xfffff0 + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(unix.NFT_DATA_VERDICT))}, + netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(s.DataType.Bytes)}) + } else { + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(s.DataType.nftMagic)}, + netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(s.DataType.Bytes)}) + } } if s.Constant { // nft cli tool adds the number of elements to set/map's descriptor