Added dynset exprs support (#173)

fixes https://github.com/google/nftables/issues/172

- Rearranged `exprFromMsg` function
- Rearranged limit expr marshaling logic
- Added dynamic flag for sets
- Implemented connlimit
- Added missing constants 
- Added tests
This commit is contained in:
turekt 2022-07-29 16:32:59 +00:00 committed by GitHub
parent a346d51f53
commit ec1e802faf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 501 additions and 112 deletions

70
expr/connlimit.go Normal file
View File

@ -0,0 +1,70 @@
// Copyright 2019 Google LLC. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package expr
import (
"encoding/binary"
"github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
)
const (
// Per https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n1167
NFTA_CONNLIMIT_UNSPEC = iota
NFTA_CONNLIMIT_COUNT
NFTA_CONNLIMIT_FLAGS
NFT_CONNLIMIT_F_INV = 1
)
// Per https://git.netfilter.org/libnftnl/tree/src/expr/connlimit.c?id=84d12cfacf8ddd857a09435f3d982ab6250d250c
type Connlimit struct {
Count uint32
Flags uint32
}
func (e *Connlimit) marshal(fam byte) ([]byte, error) {
data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: NFTA_CONNLIMIT_COUNT, Data: binaryutil.BigEndian.PutUint32(e.Count)},
{Type: NFTA_CONNLIMIT_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)},
})
if err != nil {
return nil, err
}
return netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_EXPR_NAME, Data: []byte("connlimit\x00")},
{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data},
})
}
func (e *Connlimit) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data)
if err != nil {
return err
}
ad.ByteOrder = binary.BigEndian
for ad.Next() {
switch ad.Type() {
case NFTA_CONNLIMIT_COUNT:
e.Count = binaryutil.BigEndian.Uint32(ad.Bytes())
case NFTA_CONNLIMIT_FLAGS:
e.Flags = binaryutil.BigEndian.Uint32(ad.Bytes())
}
}
return ad.Err()
}

View File

@ -19,10 +19,18 @@ import (
"time"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/internal/parseexprfunc"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
)
// Not yet supported by unix package
// https://cs.opensource.google/go/x/sys/+/c6bc011c:unix/ztypes_linux.go;l=2027-2036
const (
NFTA_DYNSET_EXPRESSIONS = 0xa
NFT_DYNSET_F_EXPR = (1 << 1)
)
// Dynset represent a rule dynamically adding or updating a set or a map based on an incoming packet.
type Dynset struct {
SrcRegKey uint32
@ -32,6 +40,7 @@ type Dynset struct {
Operation uint32
Timeout time.Duration
Invert bool
Exprs []Any
}
func (e *Dynset) marshal(fam byte) ([]byte, error) {
@ -45,12 +54,43 @@ func (e *Dynset) marshal(fam byte) ([]byte, error) {
if e.Timeout != 0 {
opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(e.Timeout.Milliseconds()))})
}
var flags uint32
if e.Invert {
opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_FLAGS, Data: binaryutil.BigEndian.PutUint32(unix.NFT_DYNSET_F_INV)})
flags |= unix.NFT_DYNSET_F_INV
}
opAttrs = append(opAttrs,
netlink.Attribute{Type: unix.NFTA_DYNSET_SET_NAME, Data: []byte(e.SetName + "\x00")},
netlink.Attribute{Type: unix.NFTA_DYNSET_SET_ID, Data: binaryutil.BigEndian.PutUint32(e.SetID)})
// Per https://git.netfilter.org/libnftnl/tree/src/expr/dynset.c?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n170
if len(e.Exprs) > 0 {
flags |= NFT_DYNSET_F_EXPR
switch len(e.Exprs) {
case 1:
exprData, err := Marshal(fam, e.Exprs[0])
if err != nil {
return nil, err
}
opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_EXPR, Data: exprData})
default:
var elemAttrs []netlink.Attribute
for _, ex := range e.Exprs {
exprData, err := Marshal(fam, ex)
if err != nil {
return nil, err
}
elemAttrs = append(elemAttrs, netlink.Attribute{Type: unix.NFTA_LIST_ELEM, Data: exprData})
}
elemData, err := netlink.MarshalAttributes(elemAttrs)
if err != nil {
return nil, err
}
opAttrs = append(opAttrs, netlink.Attribute{Type: NFTA_DYNSET_EXPRESSIONS, Data: elemData})
}
}
opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)})
opData, err := netlink.MarshalAttributes(opAttrs)
if err != nil {
return nil, err
@ -84,7 +124,26 @@ func (e *Dynset) unmarshal(fam byte, data []byte) error {
e.Timeout = time.Duration(time.Millisecond * time.Duration(ad.Uint64()))
case unix.NFTA_DYNSET_FLAGS:
e.Invert = (ad.Uint32() & unix.NFT_DYNSET_F_INV) != 0
case unix.NFTA_DYNSET_EXPR:
exprs, err := parseexprfunc.ParseExprBytesFunc(fam, ad, ad.Bytes())
if err != nil {
return err
}
e.setInterfaceExprs(exprs)
case NFTA_DYNSET_EXPRESSIONS:
exprs, err := parseexprfunc.ParseExprMsgFunc(fam, ad.Bytes())
if err != nil {
return err
}
e.setInterfaceExprs(exprs)
}
}
return ad.Err()
}
func (e *Dynset) setInterfaceExprs(exprs []interface{}) {
e.Exprs = make([]Any, len(exprs))
for i := range exprs {
e.Exprs[i] = exprs[i].(Any)
}
}

View File

@ -19,10 +19,41 @@ import (
"encoding/binary"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/internal/parseexprfunc"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
)
func init() {
parseexprfunc.ParseExprBytesFunc = func(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]interface{}, error) {
exprs, err := exprsFromBytes(fam, ad, b)
if err != nil {
return nil, err
}
result := make([]interface{}, len(exprs))
for idx, expr := range exprs {
result[idx] = expr
}
return result, nil
}
parseexprfunc.ParseExprMsgFunc = func(fam byte, b []byte) ([]interface{}, error) {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return nil, err
}
ad.ByteOrder = binary.BigEndian
var exprs []interface{}
for ad.Next() {
e, err := parseexprfunc.ParseExprBytesFunc(fam, ad, b)
if err != nil {
return e, err
}
exprs = append(exprs, e...)
}
return exprs, ad.Err()
}
}
// Marshal serializes the specified expression into a byte slice.
func Marshal(fam byte, e Any) ([]byte, error) {
return e.marshal(fam)
@ -33,6 +64,96 @@ func Unmarshal(fam byte, data []byte, e Any) error {
return e.unmarshal(fam, data)
}
// exprsFromBytes parses nested raw expressions bytes
// to construct nftables expressions
func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]Any, error) {
var exprs []Any
ad.Do(func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return err
}
ad.ByteOrder = binary.BigEndian
var name string
for ad.Next() {
switch ad.Type() {
case unix.NFTA_EXPR_NAME:
name = ad.String()
if name == "notrack" {
e := &Notrack{}
exprs = append(exprs, e)
}
case unix.NFTA_EXPR_DATA:
var e Any
switch name {
case "ct":
e = &Ct{}
case "range":
e = &Range{}
case "meta":
e = &Meta{}
case "cmp":
e = &Cmp{}
case "counter":
e = &Counter{}
case "payload":
e = &Payload{}
case "lookup":
e = &Lookup{}
case "immediate":
e = &Immediate{}
case "bitwise":
e = &Bitwise{}
case "redir":
e = &Redir{}
case "nat":
e = &NAT{}
case "limit":
e = &Limit{}
case "quota":
e = &Quota{}
case "dynset":
e = &Dynset{}
case "log":
e = &Log{}
case "exthdr":
e = &Exthdr{}
case "match":
e = &Match{}
case "target":
e = &Target{}
case "connlimit":
e = &Connlimit{}
}
if e == nil {
// TODO: introduce an opaque expression type so that users know
// something is here.
continue // unsupported expression type
}
ad.Do(func(b []byte) error {
if err := Unmarshal(fam, b, e); err != nil {
return err
}
// Verdict expressions are a special-case of immediate expressions, so
// if the expression is an immediate writing nothing into the verdict
// register (invalid), re-parse it as a verdict expression.
if imm, isImmediate := e.(*Immediate); isImmediate && imm.Register == unix.NFT_REG_VERDICT && len(imm.Data) == 0 {
e = &Verdict{}
if err := Unmarshal(fam, b, e); err != nil {
return err
}
}
exprs = append(exprs, e)
return nil
})
}
}
return ad.Err()
})
return exprs, ad.Err()
}
// Any is an interface implemented by any expression type.
type Any interface {
marshal(fam byte) ([]byte, error)

View File

@ -72,24 +72,16 @@ type Limit struct {
}
func (l *Limit) marshal(fam byte) ([]byte, error) {
var flags uint32
if l.Over {
flags = unix.NFT_LIMIT_F_INV
}
attrs := []netlink.Attribute{
{Type: unix.NFTA_LIMIT_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(l.Type))},
{Type: unix.NFTA_LIMIT_RATE, Data: binaryutil.BigEndian.PutUint64(l.Rate)},
{Type: unix.NFTA_LIMIT_UNIT, Data: binaryutil.BigEndian.PutUint64(uint64(l.Unit))},
}
if l.Over {
attrs = append(attrs, netlink.Attribute{
Type: unix.NFTA_LIMIT_FLAGS,
Data: binaryutil.BigEndian.PutUint32(unix.NFT_LIMIT_F_INV),
})
}
if l.Burst != 0 {
attrs = append(attrs, netlink.Attribute{
Type: unix.NFTA_LIMIT_BURST,
Data: binaryutil.BigEndian.PutUint32(l.Burst),
})
{Type: unix.NFTA_LIMIT_BURST, Data: binaryutil.BigEndian.PutUint32(l.Burst)},
{Type: unix.NFTA_LIMIT_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(l.Type))},
{Type: unix.NFTA_LIMIT_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)},
}
data, err := netlink.MarshalAttributes(attrs)

View File

@ -0,0 +1,10 @@
package parseexprfunc
import (
"github.com/mdlayher/netlink"
)
var (
ParseExprBytesFunc func(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]interface{}, error)
ParseExprMsgFunc func(fam byte, b []byte) ([]interface{}, error)
)

View File

@ -3313,6 +3313,221 @@ func TestDynset(t *testing.T) {
}
}
func TestDynsetWithOneExpression(t *testing.T) {
// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.
c, newNS := openSystemNFTConn(t)
defer cleanupSystemNFTConn(t, newNS)
// Clear all rules at the beginning + end of the test.
c.FlushRuleset()
defer c.FlushRuleset()
table := &nftables.Table{
Name: "filter",
Family: nftables.TableFamilyIPv4,
}
chain := &nftables.Chain{
Name: "forward",
Hooknum: nftables.ChainHookForward,
Table: table,
Priority: 0,
Type: nftables.ChainTypeFilter,
}
set := &nftables.Set{
Table: table,
Name: "myMeter",
KeyType: nftables.TypeIPAddr,
Dynamic: true,
}
c.AddTable(table)
c.AddChain(chain)
if err := c.AddSet(set, nil); err != nil {
t.Errorf("c.AddSet(myMeter) failed: %v", err)
}
if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err)
}
rule := &nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(12),
Len: uint32(4),
},
&expr.Dynset{
SrcRegKey: 1,
SetName: set.Name,
Operation: uint32(unix.NFT_DYNSET_OP_ADD),
Exprs: []expr.Any{
&expr.Limit{
Type: expr.LimitTypePkts,
Rate: 200,
Unit: expr.LimitTimeSecond,
Burst: 5,
},
},
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
}
c.AddRule(rule)
if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err)
}
rules, err := c.GetRules(table, chain)
if err != nil {
t.Fatal(err)
}
if got, want := len(rules), 1; got != want {
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
}
if got, want := len(rules[0].Exprs), 3; got != want {
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
}
dynset, dynsetOk := rules[0].Exprs[1].(*expr.Dynset)
if !dynsetOk {
t.Fatalf("Exprs[0] is type %T, want *expr.Dynset", rules[0].Exprs[1])
}
if got, want := len(dynset.Exprs), 1; got != want {
t.Fatalf("unexpected number of dynset.Exprs: got %d, want %d", got, want)
}
if got, want := dynset.SetName, set.Name; got != want {
t.Fatalf("dynset.SetName is %s, want %s", got, want)
}
if want := (&expr.Limit{
Type: expr.LimitTypePkts,
Rate: 200,
Unit: expr.LimitTimeSecond,
Burst: 5,
}); !reflect.DeepEqual(dynset.Exprs[0], want) {
t.Errorf("dynset.Exprs[0] expr = %+v, wanted %+v", dynset.Exprs[0], want)
}
}
func TestDynsetWithMultipleExpressions(t *testing.T) {
// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.
c, newNS := openSystemNFTConn(t)
defer cleanupSystemNFTConn(t, newNS)
// Clear all rules at the beginning + end of the test.
c.FlushRuleset()
defer c.FlushRuleset()
table := &nftables.Table{
Name: "filter",
Family: nftables.TableFamilyIPv4,
}
chain := &nftables.Chain{
Name: "forward",
Hooknum: nftables.ChainHookForward,
Table: table,
Priority: 0,
Type: nftables.ChainTypeFilter,
}
set := &nftables.Set{
Table: table,
Name: "myMeter",
KeyType: nftables.TypeIPAddr,
Dynamic: true,
}
c.AddTable(table)
c.AddChain(chain)
if err := c.AddSet(set, nil); err != nil {
t.Errorf("c.AddSet(myMeter) failed: %v", err)
}
if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err)
}
rule := &nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(12),
Len: uint32(4),
},
&expr.Dynset{
SrcRegKey: 1,
SetName: set.Name,
Operation: uint32(unix.NFT_DYNSET_OP_ADD),
Exprs: []expr.Any{
&expr.Connlimit{
Count: 20,
Flags: 1,
},
&expr.Limit{
Type: expr.LimitTypePkts,
Rate: 10,
Unit: expr.LimitTimeSecond,
Burst: 2,
},
},
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
}
c.AddRule(rule)
if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err)
}
rules, err := c.GetRules(table, chain)
if err != nil {
t.Fatal(err)
}
if got, want := len(rules), 1; got != want {
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
}
if got, want := len(rules[0].Exprs), 3; got != want {
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
}
dynset, dynsetOk := rules[0].Exprs[1].(*expr.Dynset)
if !dynsetOk {
t.Fatalf("Exprs[0] is type %T, want *expr.Dynset", rules[0].Exprs[1])
}
if got, want := len(dynset.Exprs), 2; got != want {
t.Fatalf("unexpected number of dynset.Exprs: got %d, want %d", got, want)
}
if got, want := dynset.SetName, set.Name; got != want {
t.Fatalf("dynset.SetName is %s, want %s", got, want)
}
if want := (&expr.Connlimit{
Count: 20,
Flags: 1,
}); !reflect.DeepEqual(dynset.Exprs[0], want) {
t.Errorf("dynset.Exprs[0] expr = %+v, wanted %+v", dynset.Exprs[0], want)
}
if want := (&expr.Limit{
Type: expr.LimitTypePkts,
Rate: 10,
Unit: expr.LimitTimeSecond,
Burst: 2,
}); !reflect.DeepEqual(dynset.Exprs[1], want) {
t.Errorf("dynset.Exprs[1] expr = %+v, wanted %+v", dynset.Exprs[1], want)
}
}
func TestConfigureNATRedirect(t *testing.T) {
// The want byte sequences come from stracing nft(8), e.g.:
// strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat

105
rule.go
View File

@ -20,6 +20,7 @@ import (
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/google/nftables/internal/parseexprfunc"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
)
@ -215,99 +216,6 @@ func (cc *Conn) DelRule(r *Rule) error {
return nil
}
func exprsFromMsg(fam TableFamily, b []byte) ([]expr.Any, error) {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return nil, err
}
ad.ByteOrder = binary.BigEndian
var exprs []expr.Any
for ad.Next() {
ad.Do(func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return err
}
ad.ByteOrder = binary.BigEndian
var name string
for ad.Next() {
switch ad.Type() {
case unix.NFTA_EXPR_NAME:
name = ad.String()
if name == "notrack" {
e := &expr.Notrack{}
exprs = append(exprs, e)
}
case unix.NFTA_EXPR_DATA:
var e expr.Any
switch name {
case "ct":
e = &expr.Ct{}
case "range":
e = &expr.Range{}
case "meta":
e = &expr.Meta{}
case "cmp":
e = &expr.Cmp{}
case "counter":
e = &expr.Counter{}
case "payload":
e = &expr.Payload{}
case "lookup":
e = &expr.Lookup{}
case "immediate":
e = &expr.Immediate{}
case "bitwise":
e = &expr.Bitwise{}
case "redir":
e = &expr.Redir{}
case "nat":
e = &expr.NAT{}
case "limit":
e = &expr.Limit{}
case "quota":
e = &expr.Quota{}
case "dynset":
e = &expr.Dynset{}
case "log":
e = &expr.Log{}
case "exthdr":
e = &expr.Exthdr{}
case "match":
e = &expr.Match{}
case "target":
e = &expr.Target{}
}
if e == nil {
// TODO: introduce an opaque expression type so that users know
// something is here.
continue // unsupported expression type
}
ad.Do(func(b []byte) error {
if err := expr.Unmarshal(byte(fam), b, e); err != nil {
return err
}
// Verdict expressions are a special-case of immediate expressions, so
// if the expression is an immediate writing nothing into the verdict
// register (invalid), re-parse it as a verdict expression.
if imm, isImmediate := e.(*expr.Immediate); isImmediate && imm.Register == unix.NFT_REG_VERDICT && len(imm.Data) == 0 {
e = &expr.Verdict{}
if err := expr.Unmarshal(byte(fam), b, e); err != nil {
return err
}
}
exprs = append(exprs, e)
return nil
})
}
}
return ad.Err()
})
}
return exprs, ad.Err()
}
func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) {
if got, want := msg.Header.Type, ruleHeaderType; got != want {
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
@ -329,8 +237,15 @@ func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) {
r.Chain = &Chain{Name: ad.String()}
case unix.NFTA_RULE_EXPRESSIONS:
ad.Do(func(b []byte) error {
r.Exprs, err = exprsFromMsg(fam, b)
return err
exprs, err := parseexprfunc.ParseExprMsgFunc(byte(fam), b)
if err != nil {
return err
}
r.Exprs = make([]expr.Any, len(exprs))
for i := range exprs {
r.Exprs[i] = exprs[i].(expr.Any)
}
return nil
})
case unix.NFTA_RULE_POSITION:
r.Position = ad.Uint64()

7
set.go
View File

@ -235,6 +235,10 @@ type Set struct {
Interval bool
IsMap bool
HasTimeout bool
// Can be updated per evaluation path, per `nft list ruleset`
// indicates that set contains "flags dynamic"
// https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n298
Dynamic bool
// Indicates that the set contains a concatenation
// https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n306
Concatenation bool
@ -468,6 +472,9 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
if s.HasTimeout {
flags |= unix.NFT_SET_TIMEOUT
}
if s.Dynamic {
flags |= unix.NFT_SET_EVAL
}
if s.Concatenation {
flags |= NFT_SET_CONCAT
}

View File

@ -104,7 +104,7 @@ func (cc *Conn) ListTables() ([]*Table, error) {
return cc.ListTablesOfFamily(TableFamilyUnspecified)
}
// ListTables returns currently configured tables for the specified table family
// ListTablesOfFamily returns currently configured tables for the specified table family
// in the kernel. It lists all tables if family is TableFamilyUnspecified.
func (cc *Conn) ListTablesOfFamily(family TableFamily) ([]*Table, error) {
conn, closer, err := cc.netlinkConn()