Add Dynset expression and unit test (#97)

* Add dynset expression and unit test

Signed-off-by: Serguei Bezverkhi <sbezverk@cisco.com>
This commit is contained in:
Serguei Bezverkhi 2020-02-10 05:14:20 -05:00 committed by GitHub
parent 9cdc3d048a
commit 1c56a1906f
4 changed files with 191 additions and 6 deletions

90
expr/dynset.go Normal file
View File

@ -0,0 +1,90 @@
// Copyright 2020 Google LLC. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package expr
import (
"encoding/binary"
"time"
"github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
)
// Dynset represent a rule dynamically adding or updating a set or a map based on an incoming packet.
type Dynset struct {
SrcRegKey uint32
SrcRegData uint32
SetID uint32
SetName string
Operation uint32
Timeout time.Duration
Invert bool
}
func (e *Dynset) marshal() ([]byte, error) {
// See: https://git.netfilter.org/libnftnl/tree/src/expr/dynset.c
var opAttrs []netlink.Attribute
opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_KEY, Data: binaryutil.BigEndian.PutUint32(e.SrcRegKey)})
if e.SrcRegData != 0 {
opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_DATA, Data: binaryutil.BigEndian.PutUint32(e.SrcRegData)})
}
opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_OP, Data: binaryutil.BigEndian.PutUint32(e.Operation)})
if e.Timeout != 0 {
opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(e.Timeout.Milliseconds()))})
}
if e.Invert {
opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_FLAGS, Data: binaryutil.BigEndian.PutUint32(unix.NFT_DYNSET_F_INV)})
}
opAttrs = append(opAttrs,
netlink.Attribute{Type: unix.NFTA_DYNSET_SET_NAME, Data: []byte(e.SetName + "\x00")},
netlink.Attribute{Type: unix.NFTA_DYNSET_SET_ID, Data: binaryutil.BigEndian.PutUint32(e.SetID)})
opData, err := netlink.MarshalAttributes(opAttrs)
if err != nil {
return nil, err
}
return netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_EXPR_NAME, Data: []byte("dynset\x00")},
{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: opData},
})
}
func (e *Dynset) unmarshal(data []byte) error {
ad, err := netlink.NewAttributeDecoder(data)
if err != nil {
return err
}
ad.ByteOrder = binary.BigEndian
for ad.Next() {
switch ad.Type() {
case unix.NFTA_DYNSET_SET_NAME:
e.SetName = ad.String()
case unix.NFTA_DYNSET_SET_ID:
e.SetID = ad.Uint32()
case unix.NFTA_DYNSET_SREG_KEY:
e.SrcRegKey = ad.Uint32()
case unix.NFTA_DYNSET_SREG_DATA:
e.SrcRegData = ad.Uint32()
case unix.NFTA_DYNSET_OP:
e.Operation = ad.Uint32()
case unix.NFTA_DYNSET_TIMEOUT:
e.Timeout = time.Duration(time.Millisecond * time.Duration(ad.Uint64()))
case unix.NFTA_DYNSET_FLAGS:
e.Invert = (ad.Uint32() & unix.NFT_DYNSET_F_INV) != 0
}
}
return ad.Err()
}

View File

@ -24,6 +24,7 @@ import (
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
"time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
@ -2487,6 +2488,97 @@ func TestGetRuleLookupVerdictImmediate(t *testing.T) {
} }
} }
func TestDynset(t *testing.T) {
// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.
c, newNS := openSystemNFTConn(t)
defer cleanupSystemNFTConn(t, newNS)
// Clear all rules at the beginning + end of the test.
c.FlushRuleset()
defer c.FlushRuleset()
filter := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
})
forward := c.AddChain(&nftables.Chain{
Name: "forward",
Table: filter,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
})
set := &nftables.Set{
Table: filter,
Name: "dynamic-set",
KeyType: nftables.TypeIPAddr,
HasTimeout: true,
Timeout: time.Duration(600 * time.Second),
}
if err := c.AddSet(set, nil); err != nil {
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err)
}
c.AddRule(&nftables.Rule{
Table: filter,
Chain: forward,
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(12),
Len: uint32(4),
},
&expr.Dynset{
SrcRegKey: 1,
SetName: set.Name,
SetID: set.ID,
Operation: uint32(unix.NFT_DYNSET_OP_UPDATE),
},
},
})
if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err)
}
rules, err := c.GetRule(
&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
},
&nftables.Chain{
Name: "forward",
},
)
if err != nil {
t.Fatal(err)
}
if got, want := len(rules), 1; got != want {
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
}
if got, want := len(rules[0].Exprs), 2; got != want {
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
}
dynset, dynsetOk := rules[0].Exprs[1].(*expr.Dynset)
if !dynsetOk {
t.Fatalf("Exprs[0] is type %T, want *expr.Dynset", rules[0].Exprs[1])
}
if want := (&expr.Dynset{
SrcRegKey: 1,
SetName: set.Name,
Operation: uint32(unix.NFT_DYNSET_OP_UPDATE),
}); !reflect.DeepEqual(dynset, want) {
t.Errorf("dynset expr = %+v, wanted %+v", dynset, want)
}
}
func TestConfigureNATRedirect(t *testing.T) { func TestConfigureNATRedirect(t *testing.T) {
// The want byte sequences come from stracing nft(8), e.g.: // The want byte sequences come from stracing nft(8), e.g.:
// strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat // strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat

View File

@ -240,6 +240,8 @@ func exprsFromMsg(b []byte) ([]expr.Any, error) {
e = &expr.Redir{} e = &expr.Redir{}
case "nat": case "nat":
e = &expr.NAT{} e = &expr.NAT{}
case "dynset":
e = &expr.Dynset{}
} }
if e == nil { if e == nil {
// TODO: introduce an opaque expression type so that users know // TODO: introduce an opaque expression type so that users know

13
set.go
View File

@ -19,6 +19,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
@ -130,7 +131,7 @@ type Set struct {
Interval bool Interval bool
IsMap bool IsMap bool
HasTimeout bool HasTimeout bool
Timeout uint64 Timeout time.Duration
KeyType SetDatatype KeyType SetDatatype
DataType SetDatatype DataType SetDatatype
} }
@ -145,7 +146,7 @@ type SetElement struct {
// and VerdictData will be wrapped into Attribute data. // and VerdictData will be wrapped into Attribute data.
VerdictData *expr.Verdict VerdictData *expr.Verdict
// To support aging of set elements // To support aging of set elements
Timeout uint64 Timeout time.Duration
} }
func (s *SetElement) decode() func(b []byte) error { func (s *SetElement) decode() func(b []byte) error {
@ -172,7 +173,7 @@ func (s *SetElement) decode() func(b []byte) error {
flags := ad.Uint32() flags := ad.Uint32()
s.IntervalEnd = (flags & unix.NFT_SET_ELEM_INTERVAL_END) != 0 s.IntervalEnd = (flags & unix.NFT_SET_ELEM_INTERVAL_END) != 0
case unix.NFTA_SET_ELEM_TIMEOUT: case unix.NFTA_SET_ELEM_TIMEOUT:
s.Timeout = ad.Uint64() s.Timeout = time.Duration(time.Millisecond * time.Duration(ad.Uint64()))
} }
} }
return ad.Err() return ad.Err()
@ -241,7 +242,7 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey})
if s.HasTimeout && v.Timeout != 0 { if s.HasTimeout && v.Timeout != 0 {
// Set has Timeout flag set, which means an individual element can specify its own timeout. // Set has Timeout flag set, which means an individual element can specify its own timeout.
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(v.Timeout)}) item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(v.Timeout.Milliseconds()))})
} }
// The following switch statement deal with 3 different types of elements. // The following switch statement deal with 3 different types of elements.
// 1. v is an element of vmap // 1. v is an element of vmap
@ -365,7 +366,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
} }
if s.HasTimeout && s.Timeout != 0 { if s.HasTimeout && s.Timeout != 0 {
// If Set's global timeout is specified, add it to set's attributes // If Set's global timeout is specified, add it to set's attributes
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(s.Timeout)}) tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(s.Timeout.Milliseconds()))})
} }
if s.Constant { if s.Constant {
// nft cli tool adds the number of elements to set/map's descriptor // nft cli tool adds the number of elements to set/map's descriptor
@ -489,7 +490,7 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
case unix.NFTA_SET_ID: case unix.NFTA_SET_ID:
set.ID = binary.BigEndian.Uint32(ad.Bytes()) set.ID = binary.BigEndian.Uint32(ad.Bytes())
case unix.NFTA_SET_TIMEOUT: case unix.NFTA_SET_TIMEOUT:
set.Timeout = binary.BigEndian.Uint64(ad.Bytes()) set.Timeout = time.Duration(time.Millisecond * time.Duration(binary.BigEndian.Uint64(ad.Bytes())))
set.HasTimeout = true set.HasTimeout = true
case unix.NFTA_SET_FLAGS: case unix.NFTA_SET_FLAGS:
flags := ad.Uint32() flags := ad.Uint32()