Add support for maps (#55)

This commit is contained in:
Serguei Bezverkhi 2019-08-27 11:52:21 -04:00 committed by Michael Stapelberg
parent 85a78b5285
commit 1ad7112fd7
2 changed files with 118 additions and 15 deletions

View File

@ -2784,3 +2784,85 @@ func TestFib(t *testing.T) {
} }
} }
} }
func TestMap(t *testing.T) {
tests := []struct {
name string
chain *nftables.Chain
want [][]byte
set nftables.Set
element []nftables.SetElement
}{
{
name: "map inet_service: inet_service 1 element",
chain: &nftables.Chain{
Name: "base-chain",
},
want: [][]byte{
// batch begin
[]byte("\x00\x00\x00\x0a"),
// nft add table ip filter
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"),
// nft add chain ip filter base-chain
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0f\x00\x03\x00\x62\x61\x73\x65\x2d\x63\x68\x61\x69\x6e\x00\x00"),
// nft add map ip filter test-map { type inet_service: inet_service\; elements={ 22: 1024 } \; }
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0d\x00\x02\x00\x74\x65\x73\x74\x2d\x6d\x61\x70\x00\x00\x00\x00\x08\x00\x03\x00\x00\x00\x00\x08\x08\x00\x04\x00\x00\x00\x00\x0d\x08\x00\x05\x00\x00\x00\x00\x02\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\x00\x00\x00\x0d\x08\x00\x07\x00\x00\x00\x00\x02"),
[]byte("\x02\x00\x00\x00\x0d\x00\x02\x00\x74\x65\x73\x74\x2d\x6d\x61\x70\x00\x00\x00\x00\x08\x00\x04\x00\x00\x00\x00\x01\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x20\x00\x03\x80\x1c\x00\x01\x80\x0c\x00\x01\x80\x06\x00\x01\x00\x00\x16\x00\x00\x0c\x00\x02\x80\x06\x00\x01\x00\x04\x00\x00\x00"),
// batch end
[]byte("\x00\x00\x00\x0a"),
},
set: nftables.Set{
Name: "test-map",
ID: uint32(1),
KeyType: nftables.TypeInetService,
DataType: nftables.TypeInetService,
IsMap: true,
},
element: []nftables.SetElement{
{
Key: binaryutil.BigEndian.PutUint16(uint16(22)),
Val: binaryutil.BigEndian.PutUint16(uint16(1024)),
},
},
},
}
for _, tt := range tests {
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
t.Fatal(err)
}
if len(b) < 16 {
continue
}
b = b[16:]
if len(tt.want[idx]) == 0 {
t.Errorf("no want entry for message %d: %x", idx, b)
continue
}
got := b
if !bytes.Equal(got, tt.want[idx]) {
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx])))
}
}
return req, nil
},
}
filter := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
})
tt.chain.Table = filter
c.AddChain(tt.chain)
tt.set.Table = filter
c.AddSet(&tt.set, tt.element)
if err := c.Flush(); err != nil {
t.Fatalf("Test \"%s\" failed with error: %+v", tt.name, err)
}
}
}

49
set.go
View File

@ -64,9 +64,10 @@ type Set struct {
Anonymous bool Anonymous bool
Constant bool Constant bool
Interval bool Interval bool
IsMap bool
KeyType SetDatatype KeyType SetDatatype
DataLen int DataType SetDatatype
} }
// SetElement represents a data point within a set. // SetElement represents a data point within a set.
@ -158,14 +159,14 @@ func (s *Set) makeElemList(vals []SetElement) ([]netlink.Attribute, error) {
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_FLAGS | unix.NLA_F_NESTED, Data: binaryutil.BigEndian.PutUint32(flags)}) item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_FLAGS | unix.NLA_F_NESTED, Data: binaryutil.BigEndian.PutUint32(flags)})
} }
encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_SET_ELEM_KEY, Data: v.Key}}) encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}})
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal key %d: %v", i, err) return nil, fmt.Errorf("marshal key %d: %v", i, err)
} }
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey})
if len(v.Val) > 0 { if len(v.Val) > 0 {
encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_SET_ELEM_DATA, Data: v.Val}}) encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}})
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal item %d: %v", i, err) return nil, fmt.Errorf("marshal item %d: %v", i, err)
} }
@ -207,12 +208,11 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
s.ID = allocSetID s.ID = allocSetID
if s.Anonymous { if s.Anonymous {
s.Name = "__set%d" s.Name = "__set%d"
if s.IsMap {
s.Name = "__map%d"
}
} }
} }
setData := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))},
})
var flags uint32 var flags uint32
if s.Anonymous { if s.Anonymous {
@ -224,7 +224,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
if s.Interval { if s.Interval {
flags |= unix.NFT_SET_INTERVAL flags |= unix.NFT_SET_INTERVAL
} }
if s.DataLen > 0 { if s.IsMap {
flags |= unix.NFT_SET_MAP flags |= unix.NFT_SET_MAP
} }
@ -236,12 +236,23 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_SET_KEY_LEN, Data: binaryutil.BigEndian.PutUint32(s.KeyType.Bytes)}, {Type: unix.NFTA_SET_KEY_LEN, Data: binaryutil.BigEndian.PutUint32(s.KeyType.Bytes)},
{Type: unix.NFTA_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)}, {Type: unix.NFTA_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)},
} }
if s.DataLen > 0 { if s.IsMap {
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(unix.NFT_DATA_VALUE)}, tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(s.DataType.nftMagic)},
netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(uint32(s.DataLen))}) netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(s.DataType.Bytes)})
}
if s.Constant {
// nft cli tool adds the number of elements to set/map's descriptor
// It make sense to do only if a set or map are constant, otherwise skip NFTA_SET_DESC attribute
numberOfElements, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))},
})
if err != nil {
return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err)
}
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements})
} }
if s.Anonymous || s.Constant || s.Interval { if s.Anonymous || s.Constant || s.Interval {
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: setData}, tableInfo = append(tableInfo,
// Semantically useless - kept for binary compatability with nft // Semantically useless - kept for binary compatability with nft
netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x02\x00\x00\x00")}) netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x02\x00\x00\x00")})
} }
@ -332,13 +343,12 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
set.Name = ad.String() set.Name = ad.String()
case unix.NFTA_SET_ID: case unix.NFTA_SET_ID:
set.ID = binary.BigEndian.Uint32(ad.Bytes()) set.ID = binary.BigEndian.Uint32(ad.Bytes())
case unix.NFTA_SET_DATA_LEN:
set.DataLen = int(ad.Uint32())
case unix.NFTA_SET_FLAGS: case unix.NFTA_SET_FLAGS:
flags := ad.Uint32() flags := ad.Uint32()
set.Constant = (flags & unix.NFT_SET_CONSTANT) != 0 set.Constant = (flags & unix.NFT_SET_CONSTANT) != 0
set.Anonymous = (flags & unix.NFT_SET_ANONYMOUS) != 0 set.Anonymous = (flags & unix.NFT_SET_ANONYMOUS) != 0
set.Interval = (flags & unix.NFT_SET_INTERVAL) != 0 set.Interval = (flags & unix.NFT_SET_INTERVAL) != 0
set.IsMap = (flags & unix.NFTA_SET_TABLE) != 0
case unix.NFTA_SET_KEY_TYPE: case unix.NFTA_SET_KEY_TYPE:
nftMagic := ad.Uint32() nftMagic := ad.Uint32()
for _, dt := range nftDatatypes { for _, dt := range nftDatatypes {
@ -350,6 +360,17 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
if set.KeyType.nftMagic == 0 { if set.KeyType.nftMagic == 0 {
return nil, fmt.Errorf("could not determine datatype %x", nftMagic) return nil, fmt.Errorf("could not determine datatype %x", nftMagic)
} }
case unix.NFTA_SET_DATA_TYPE:
nftMagic := ad.Uint32()
for _, dt := range nftDatatypes {
if nftMagic == dt.nftMagic {
set.DataType = dt
break
}
}
if set.DataType.nftMagic == 0 {
return nil, fmt.Errorf("could not determine datatype %x", nftMagic)
}
} }
} }
return &set, nil return &set, nil