Add support for maps (#55)
This commit is contained in:
parent
85a78b5285
commit
1ad7112fd7
|
@ -2784,3 +2784,85 @@ func TestFib(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMap(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
chain *nftables.Chain
|
||||||
|
want [][]byte
|
||||||
|
set nftables.Set
|
||||||
|
element []nftables.SetElement
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "map inet_service: inet_service 1 element",
|
||||||
|
chain: &nftables.Chain{
|
||||||
|
Name: "base-chain",
|
||||||
|
},
|
||||||
|
want: [][]byte{
|
||||||
|
// batch begin
|
||||||
|
[]byte("\x00\x00\x00\x0a"),
|
||||||
|
// nft add table ip filter
|
||||||
|
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"),
|
||||||
|
// nft add chain ip filter base-chain
|
||||||
|
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0f\x00\x03\x00\x62\x61\x73\x65\x2d\x63\x68\x61\x69\x6e\x00\x00"),
|
||||||
|
// nft add map ip filter test-map { type inet_service: inet_service\; elements={ 22: 1024 } \; }
|
||||||
|
[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0d\x00\x02\x00\x74\x65\x73\x74\x2d\x6d\x61\x70\x00\x00\x00\x00\x08\x00\x03\x00\x00\x00\x00\x08\x08\x00\x04\x00\x00\x00\x00\x0d\x08\x00\x05\x00\x00\x00\x00\x02\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\x00\x00\x00\x0d\x08\x00\x07\x00\x00\x00\x00\x02"),
|
||||||
|
[]byte("\x02\x00\x00\x00\x0d\x00\x02\x00\x74\x65\x73\x74\x2d\x6d\x61\x70\x00\x00\x00\x00\x08\x00\x04\x00\x00\x00\x00\x01\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x20\x00\x03\x80\x1c\x00\x01\x80\x0c\x00\x01\x80\x06\x00\x01\x00\x00\x16\x00\x00\x0c\x00\x02\x80\x06\x00\x01\x00\x04\x00\x00\x00"),
|
||||||
|
// batch end
|
||||||
|
[]byte("\x00\x00\x00\x0a"),
|
||||||
|
},
|
||||||
|
set: nftables.Set{
|
||||||
|
Name: "test-map",
|
||||||
|
ID: uint32(1),
|
||||||
|
KeyType: nftables.TypeInetService,
|
||||||
|
DataType: nftables.TypeInetService,
|
||||||
|
IsMap: true,
|
||||||
|
},
|
||||||
|
element: []nftables.SetElement{
|
||||||
|
{
|
||||||
|
Key: binaryutil.BigEndian.PutUint16(uint16(22)),
|
||||||
|
Val: binaryutil.BigEndian.PutUint16(uint16(1024)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
c := &nftables.Conn{
|
||||||
|
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
|
||||||
|
for idx, msg := range req {
|
||||||
|
b, err := msg.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(b) < 16 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b = b[16:]
|
||||||
|
if len(tt.want[idx]) == 0 {
|
||||||
|
t.Errorf("no want entry for message %d: %x", idx, b)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
got := b
|
||||||
|
if !bytes.Equal(got, tt.want[idx]) {
|
||||||
|
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(tt.want[idx])))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := c.AddTable(&nftables.Table{
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
Name: "filter",
|
||||||
|
})
|
||||||
|
|
||||||
|
tt.chain.Table = filter
|
||||||
|
c.AddChain(tt.chain)
|
||||||
|
tt.set.Table = filter
|
||||||
|
c.AddSet(&tt.set, tt.element)
|
||||||
|
if err := c.Flush(); err != nil {
|
||||||
|
t.Fatalf("Test \"%s\" failed with error: %+v", tt.name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
49
set.go
49
set.go
|
@ -64,9 +64,10 @@ type Set struct {
|
||||||
Anonymous bool
|
Anonymous bool
|
||||||
Constant bool
|
Constant bool
|
||||||
Interval bool
|
Interval bool
|
||||||
|
IsMap bool
|
||||||
|
|
||||||
KeyType SetDatatype
|
KeyType SetDatatype
|
||||||
DataLen int
|
DataType SetDatatype
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetElement represents a data point within a set.
|
// SetElement represents a data point within a set.
|
||||||
|
@ -158,14 +159,14 @@ func (s *Set) makeElemList(vals []SetElement) ([]netlink.Attribute, error) {
|
||||||
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_FLAGS | unix.NLA_F_NESTED, Data: binaryutil.BigEndian.PutUint32(flags)})
|
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_FLAGS | unix.NLA_F_NESTED, Data: binaryutil.BigEndian.PutUint32(flags)})
|
||||||
}
|
}
|
||||||
|
|
||||||
encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_SET_ELEM_KEY, Data: v.Key}})
|
encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("marshal key %d: %v", i, err)
|
return nil, fmt.Errorf("marshal key %d: %v", i, err)
|
||||||
}
|
}
|
||||||
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey})
|
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey})
|
||||||
|
|
||||||
if len(v.Val) > 0 {
|
if len(v.Val) > 0 {
|
||||||
encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_SET_ELEM_DATA, Data: v.Val}})
|
encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("marshal item %d: %v", i, err)
|
return nil, fmt.Errorf("marshal item %d: %v", i, err)
|
||||||
}
|
}
|
||||||
|
@ -207,12 +208,11 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
||||||
s.ID = allocSetID
|
s.ID = allocSetID
|
||||||
if s.Anonymous {
|
if s.Anonymous {
|
||||||
s.Name = "__set%d"
|
s.Name = "__set%d"
|
||||||
|
if s.IsMap {
|
||||||
|
s.Name = "__map%d"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
setData := cc.marshalAttr([]netlink.Attribute{
|
|
||||||
{Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))},
|
|
||||||
})
|
|
||||||
|
|
||||||
var flags uint32
|
var flags uint32
|
||||||
if s.Anonymous {
|
if s.Anonymous {
|
||||||
|
@ -224,7 +224,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
||||||
if s.Interval {
|
if s.Interval {
|
||||||
flags |= unix.NFT_SET_INTERVAL
|
flags |= unix.NFT_SET_INTERVAL
|
||||||
}
|
}
|
||||||
if s.DataLen > 0 {
|
if s.IsMap {
|
||||||
flags |= unix.NFT_SET_MAP
|
flags |= unix.NFT_SET_MAP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,12 +236,23 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
||||||
{Type: unix.NFTA_SET_KEY_LEN, Data: binaryutil.BigEndian.PutUint32(s.KeyType.Bytes)},
|
{Type: unix.NFTA_SET_KEY_LEN, Data: binaryutil.BigEndian.PutUint32(s.KeyType.Bytes)},
|
||||||
{Type: unix.NFTA_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)},
|
{Type: unix.NFTA_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)},
|
||||||
}
|
}
|
||||||
if s.DataLen > 0 {
|
if s.IsMap {
|
||||||
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(unix.NFT_DATA_VALUE)},
|
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(s.DataType.nftMagic)},
|
||||||
netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(uint32(s.DataLen))})
|
netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(s.DataType.Bytes)})
|
||||||
|
}
|
||||||
|
if s.Constant {
|
||||||
|
// nft cli tool adds the number of elements to set/map's descriptor
|
||||||
|
// It make sense to do only if a set or map are constant, otherwise skip NFTA_SET_DESC attribute
|
||||||
|
numberOfElements, err := netlink.MarshalAttributes([]netlink.Attribute{
|
||||||
|
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err)
|
||||||
|
}
|
||||||
|
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements})
|
||||||
}
|
}
|
||||||
if s.Anonymous || s.Constant || s.Interval {
|
if s.Anonymous || s.Constant || s.Interval {
|
||||||
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: setData},
|
tableInfo = append(tableInfo,
|
||||||
// Semantically useless - kept for binary compatability with nft
|
// Semantically useless - kept for binary compatability with nft
|
||||||
netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x02\x00\x00\x00")})
|
netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x02\x00\x00\x00")})
|
||||||
}
|
}
|
||||||
|
@ -332,13 +343,12 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
|
||||||
set.Name = ad.String()
|
set.Name = ad.String()
|
||||||
case unix.NFTA_SET_ID:
|
case unix.NFTA_SET_ID:
|
||||||
set.ID = binary.BigEndian.Uint32(ad.Bytes())
|
set.ID = binary.BigEndian.Uint32(ad.Bytes())
|
||||||
case unix.NFTA_SET_DATA_LEN:
|
|
||||||
set.DataLen = int(ad.Uint32())
|
|
||||||
case unix.NFTA_SET_FLAGS:
|
case unix.NFTA_SET_FLAGS:
|
||||||
flags := ad.Uint32()
|
flags := ad.Uint32()
|
||||||
set.Constant = (flags & unix.NFT_SET_CONSTANT) != 0
|
set.Constant = (flags & unix.NFT_SET_CONSTANT) != 0
|
||||||
set.Anonymous = (flags & unix.NFT_SET_ANONYMOUS) != 0
|
set.Anonymous = (flags & unix.NFT_SET_ANONYMOUS) != 0
|
||||||
set.Interval = (flags & unix.NFT_SET_INTERVAL) != 0
|
set.Interval = (flags & unix.NFT_SET_INTERVAL) != 0
|
||||||
|
set.IsMap = (flags & unix.NFTA_SET_TABLE) != 0
|
||||||
case unix.NFTA_SET_KEY_TYPE:
|
case unix.NFTA_SET_KEY_TYPE:
|
||||||
nftMagic := ad.Uint32()
|
nftMagic := ad.Uint32()
|
||||||
for _, dt := range nftDatatypes {
|
for _, dt := range nftDatatypes {
|
||||||
|
@ -350,6 +360,17 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
|
||||||
if set.KeyType.nftMagic == 0 {
|
if set.KeyType.nftMagic == 0 {
|
||||||
return nil, fmt.Errorf("could not determine datatype %x", nftMagic)
|
return nil, fmt.Errorf("could not determine datatype %x", nftMagic)
|
||||||
}
|
}
|
||||||
|
case unix.NFTA_SET_DATA_TYPE:
|
||||||
|
nftMagic := ad.Uint32()
|
||||||
|
for _, dt := range nftDatatypes {
|
||||||
|
if nftMagic == dt.nftMagic {
|
||||||
|
set.DataType = dt
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if set.DataType.nftMagic == 0 {
|
||||||
|
return nil, fmt.Errorf("could not determine datatype %x", nftMagic)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &set, nil
|
return &set, nil
|
||||||
|
|
Loading…
Reference in New Issue