diff --git a/expr/dynset.go b/expr/dynset.go new file mode 100644 index 0000000..1e990ab --- /dev/null +++ b/expr/dynset.go @@ -0,0 +1,90 @@ +// Copyright 2020 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + "time" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// Dynset represent a rule dynamically adding or updating a set or a map based on an incoming packet. +type Dynset struct { + SrcRegKey uint32 + SrcRegData uint32 + SetID uint32 + SetName string + Operation uint32 + Timeout time.Duration + Invert bool +} + +func (e *Dynset) marshal() ([]byte, error) { + // See: https://git.netfilter.org/libnftnl/tree/src/expr/dynset.c + var opAttrs []netlink.Attribute + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_KEY, Data: binaryutil.BigEndian.PutUint32(e.SrcRegKey)}) + if e.SrcRegData != 0 { + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_DATA, Data: binaryutil.BigEndian.PutUint32(e.SrcRegData)}) + } + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_OP, Data: binaryutil.BigEndian.PutUint32(e.Operation)}) + if e.Timeout != 0 { + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(e.Timeout.Milliseconds()))}) + } + if e.Invert { + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_FLAGS, Data: binaryutil.BigEndian.PutUint32(unix.NFT_DYNSET_F_INV)}) + } + opAttrs = append(opAttrs, + netlink.Attribute{Type: unix.NFTA_DYNSET_SET_NAME, Data: []byte(e.SetName + "\x00")}, + netlink.Attribute{Type: unix.NFTA_DYNSET_SET_ID, Data: binaryutil.BigEndian.PutUint32(e.SetID)}) + opData, err := netlink.MarshalAttributes(opAttrs) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("dynset\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: opData}, + }) +} + +func (e *Dynset) unmarshal(data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_DYNSET_SET_NAME: + e.SetName = ad.String() + case unix.NFTA_DYNSET_SET_ID: + e.SetID = ad.Uint32() + case unix.NFTA_DYNSET_SREG_KEY: + e.SrcRegKey = ad.Uint32() + case unix.NFTA_DYNSET_SREG_DATA: + e.SrcRegData = ad.Uint32() + case unix.NFTA_DYNSET_OP: + e.Operation = ad.Uint32() + case unix.NFTA_DYNSET_TIMEOUT: + e.Timeout = time.Duration(time.Millisecond * time.Duration(ad.Uint64())) + case unix.NFTA_DYNSET_FLAGS: + e.Invert = (ad.Uint32() & unix.NFT_DYNSET_F_INV) != 0 + } + } + return ad.Err() +} diff --git a/nftables_test.go b/nftables_test.go index 1fe24f6..5db5672 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -24,6 +24,7 @@ import ( "runtime" "strings" "testing" + "time" "github.com/google/nftables" "github.com/google/nftables/binaryutil" @@ -2487,6 +2488,97 @@ func TestGetRuleLookupVerdictImmediate(t *testing.T) { } } +func TestDynset(t *testing.T) { + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + forward := c.AddChain(&nftables.Chain{ + Name: "forward", + Table: filter, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }) + + set := &nftables.Set{ + Table: filter, + Name: "dynamic-set", + KeyType: nftables.TypeIPAddr, + HasTimeout: true, + Timeout: time.Duration(600 * time.Second), + } + if err := c.AddSet(set, nil); err != nil { + t.Errorf("c.AddSet(portSet) failed: %v", err) + } + if err := c.Flush(); err != nil { + t.Errorf("c.Flush() failed: %v", err) + } + + c.AddRule(&nftables.Rule{ + Table: filter, + Chain: forward, + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: uint32(12), + Len: uint32(4), + }, + &expr.Dynset{ + SrcRegKey: 1, + SetName: set.Name, + SetID: set.ID, + Operation: uint32(unix.NFT_DYNSET_OP_UPDATE), + }, + }, + }) + + if err := c.Flush(); err != nil { + t.Errorf("c.Flush() failed: %v", err) + } + + rules, err := c.GetRule( + &nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }, + &nftables.Chain{ + Name: "forward", + }, + ) + if err != nil { + t.Fatal(err) + } + + if got, want := len(rules), 1; got != want { + t.Fatalf("unexpected number of rules: got %d, want %d", got, want) + } + if got, want := len(rules[0].Exprs), 2; got != want { + t.Fatalf("unexpected number of exprs: got %d, want %d", got, want) + } + + dynset, dynsetOk := rules[0].Exprs[1].(*expr.Dynset) + if !dynsetOk { + t.Fatalf("Exprs[0] is type %T, want *expr.Dynset", rules[0].Exprs[1]) + } + if want := (&expr.Dynset{ + SrcRegKey: 1, + SetName: set.Name, + Operation: uint32(unix.NFT_DYNSET_OP_UPDATE), + }); !reflect.DeepEqual(dynset, want) { + t.Errorf("dynset expr = %+v, wanted %+v", dynset, want) + } +} + func TestConfigureNATRedirect(t *testing.T) { // The want byte sequences come from stracing nft(8), e.g.: // strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat diff --git a/rule.go b/rule.go index a86ce97..6fda09b 100644 --- a/rule.go +++ b/rule.go @@ -240,6 +240,8 @@ func exprsFromMsg(b []byte) ([]expr.Any, error) { e = &expr.Redir{} case "nat": e = &expr.NAT{} + case "dynset": + e = &expr.Dynset{} } if e == nil { // TODO: introduce an opaque expression type so that users know diff --git a/set.go b/set.go index 0577fd8..979f53a 100644 --- a/set.go +++ b/set.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "strings" + "time" "github.com/google/nftables/expr" @@ -130,7 +131,7 @@ type Set struct { Interval bool IsMap bool HasTimeout bool - Timeout uint64 + Timeout time.Duration KeyType SetDatatype DataType SetDatatype } @@ -145,7 +146,7 @@ type SetElement struct { // and VerdictData will be wrapped into Attribute data. VerdictData *expr.Verdict // To support aging of set elements - Timeout uint64 + Timeout time.Duration } func (s *SetElement) decode() func(b []byte) error { @@ -172,7 +173,7 @@ func (s *SetElement) decode() func(b []byte) error { flags := ad.Uint32() s.IntervalEnd = (flags & unix.NFT_SET_ELEM_INTERVAL_END) != 0 case unix.NFTA_SET_ELEM_TIMEOUT: - s.Timeout = ad.Uint64() + s.Timeout = time.Duration(time.Millisecond * time.Duration(ad.Uint64())) } } return ad.Err() @@ -241,7 +242,7 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) if s.HasTimeout && v.Timeout != 0 { // Set has Timeout flag set, which means an individual element can specify its own timeout. - item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(v.Timeout)}) + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(v.Timeout.Milliseconds()))}) } // The following switch statement deal with 3 different types of elements. // 1. v is an element of vmap @@ -365,7 +366,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { } if s.HasTimeout && s.Timeout != 0 { // If Set's global timeout is specified, add it to set's attributes - tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(s.Timeout)}) + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(s.Timeout.Milliseconds()))}) } if s.Constant { // nft cli tool adds the number of elements to set/map's descriptor @@ -489,7 +490,7 @@ func setsFromMsg(msg netlink.Message) (*Set, error) { case unix.NFTA_SET_ID: set.ID = binary.BigEndian.Uint32(ad.Bytes()) case unix.NFTA_SET_TIMEOUT: - set.Timeout = binary.BigEndian.Uint64(ad.Bytes()) + set.Timeout = time.Duration(time.Millisecond * time.Duration(binary.BigEndian.Uint64(ad.Bytes()))) set.HasTimeout = true case unix.NFTA_SET_FLAGS: flags := ad.Uint32()