From 50b37861c0d74d6ca2ccdd7c6dfe59f1611f2aa4 Mon Sep 17 00:00:00 2001 From: Auztin Zhai Date: Thu, 7 Dec 2023 09:55:00 -0500 Subject: [PATCH] feat: add monitor on table chain rule set setelem and obj events --- chain.go | 7 +- go.sum | 4 + monitor.go | 302 ++++++++++++++++++++++++++++++++++++++++++++++++ monitor_test.go | 117 +++++++++++++++++++ obj.go | 9 +- rule.go | 13 ++- set.go | 18 ++- table.go | 9 +- util.go | 19 +++ 9 files changed, 478 insertions(+), 20 deletions(-) create mode 100644 monitor.go create mode 100644 monitor_test.go diff --git a/chain.go b/chain.go index 9928d63..e1bda29 100644 --- a/chain.go +++ b/chain.go @@ -223,9 +223,10 @@ func (cc *Conn) ListChainsOfTableFamily(family TableFamily) ([]*Chain, error) { } func chainFromMsg(msg netlink.Message) (*Chain, error) { - chainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN) - if got, want := msg.Header.Type, chainHeaderType; got != want { - return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + newChainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN) + delChainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN) + if got, want1, want2 := msg.Header.Type, newChainHeaderType, delChainHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) } var c Chain diff --git a/go.sum b/go.sum index 3c98fad..7a07830 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/mdlayher/netlink v1.7.1 h1:FdUaT/e33HjEXagwELR8R3/KL1Fq5x3G5jgHLp/BTm github.com/mdlayher/netlink v1.7.1/go.mod h1:nKO5CSjE/DJjVhk/TNp6vCE1ktVxEA8VEh8drhZzxsQ= github.com/mdlayher/socket v0.4.0 h1:280wsy40IC9M9q1uPGcLBwXpcTQDtoGwVt+BNoITxIw= github.com/mdlayher/socket v0.4.0/go.mod h1:xxFqz5GRCUN3UEOm9CZqEJsAbe1C8OwSK46NlmWuVoc= +github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= +github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI= github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4= github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -26,6 +28,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/monitor.go b/monitor.go new file mode 100644 index 0000000..a49958d --- /dev/null +++ b/monitor.go @@ -0,0 +1,302 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "sync" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type MonitorAction uint8 + +// Possible MonitorAction values. +const ( + MonitorActionNew MonitorAction = 1 << iota + MonitorActionDel + MonitorActionMask MonitorAction = (1 << iota) - 1 + MonitorActionAny MonitorAction = MonitorActionMask +) + +type MonitorObject uint32 + +// Possible MonitorObject values. +const ( + MonitorObjectTables MonitorObject = 1 << iota + MonitorObjectChains + MonitorObjectSets + MonitorObjectRules + MonitorObjectElements + MonitorObjectRuleset + MonitorObjectMask MonitorObject = (1 << iota) - 1 + MonitorObjectAny MonitorObject = MonitorObjectMask +) + +var ( + monitorFlags map[MonitorAction]map[MonitorObject]uint32 + monitorFlagsInitOnce sync.Once +) + +// A lazy init function to define flags. +func lazyInit() { + monitorFlagsInitOnce.Do(func() { + monitorFlags = map[MonitorAction]map[MonitorObject]uint32{ + MonitorActionAny: { + MonitorObjectAny: 0xffffffff, + MonitorObjectTables: 1<>8, netlink.HeaderType(unix.NFNL_SUBSYS_NFTABLES); got != want { + continue + } + msgType := msg.Header.Type & 0x00ff + if monitor.monitorFlags&1< reg 1 + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + // cmp eq reg 1 0x0245a8c0 + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: net.ParseIP("192.168.69.2").To4(), + }, + + // masq + &expr.Masq{}, + }, + }) + + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // It takes time for the kernel to take effect + time.Sleep(time.Second) + monitor.Close() + wg.Wait() + if err != nil { + t.Fatal(err) + } + if gotTable.Family != nat.Family || gotTable.Name != nat.Name { + t.Fatal("no want table", gotTable.Family, gotTable.Name) + } + if gotChain.Type != postrouting.Type || gotChain.Name != postrouting.Name || + *gotChain.Hooknum != *postrouting.Hooknum { + t.Fatal("no want chain", gotChain.Type, gotChain.Name, gotChain.Hooknum) + } + if len(gotRule.Exprs) != len(rule.Exprs) { + t.Fatal("no want rule") + } +} diff --git a/obj.go b/obj.go index 08d43f4..6201e7f 100644 --- a/obj.go +++ b/obj.go @@ -22,7 +22,10 @@ import ( "golang.org/x/sys/unix" ) -var objHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ) +var ( + newObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ) + delObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ) +) // Obj represents a netfilter stateful object. See also // https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects @@ -125,8 +128,8 @@ func (cc *Conn) ResetObjects(t *Table) ([]Obj, error) { } func objFromMsg(msg netlink.Message) (Obj, error) { - if got, want := msg.Header.Type, objHeaderType; got != want { - return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + if got, want1, want2 := msg.Header.Type, newObjHeaderType, delObjHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) } ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { diff --git a/rule.go b/rule.go index 0ae9a53..8bcfda1 100644 --- a/rule.go +++ b/rule.go @@ -25,7 +25,10 @@ import ( "golang.org/x/sys/unix" ) -var ruleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE) +var ( + newRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE) + delRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE) +) type ruleOperation uint32 @@ -168,7 +171,7 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ - Type: ruleHeaderType, + Type: newRuleHeaderType, Flags: flags, }, Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), @@ -215,7 +218,7 @@ func (cc *Conn) DelRule(r *Rule) error { cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), + Type: delRuleHeaderType, Flags: flags, }, Data: append(extraHeader(uint8(r.Table.Family), 0), data...), @@ -225,8 +228,8 @@ func (cc *Conn) DelRule(r *Rule) error { } func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) { - if got, want := msg.Header.Type, ruleHeaderType; got != want { - return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + if got, want1, want2 := msg.Header.Type, newRuleHeaderType, delRuleHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", msg.Header.Type, want1, want2) } ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { diff --git a/set.go b/set.go index c267659..192c619 100644 --- a/set.go +++ b/set.go @@ -684,11 +684,14 @@ func (cc *Conn) FlushSet(s *Set) { }) } -var setHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET) +var ( + newSetHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET) + delSetHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET) +) func setsFromMsg(msg netlink.Message) (*Set, error) { - if got, want := msg.Header.Type, setHeaderType; got != want { - return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + if got, want1, want2 := msg.Header.Type, newSetHeaderType, delSetHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) } ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { @@ -762,11 +765,14 @@ func parseSetDatatype(magic uint32) (SetDatatype, error) { return dt, nil } -var elemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM) +var ( + newElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM) + delElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM) +) func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) { - if got, want := msg.Header.Type, elemHeaderType; got != want { - return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + if got, want1, want2 := msg.Header.Type, newElemHeaderType, delElemHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) } ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { diff --git a/table.go b/table.go index 24782ed..ff3b592 100644 --- a/table.go +++ b/table.go @@ -21,7 +21,10 @@ import ( "golang.org/x/sys/unix" ) -var tableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE) +var ( + newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE) + delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE) +) // TableFamily specifies the address family for this table. type TableFamily byte @@ -150,8 +153,8 @@ func (cc *Conn) ListTablesOfFamily(family TableFamily) ([]*Table, error) { } func tableFromMsg(msg netlink.Message) (*Table, error) { - if got, want := msg.Header.Type, tableHeaderType; got != want { - return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + if got, want1, want2 := msg.Header.Type, newTableHeaderType, delTableHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) } var t Table diff --git a/util.go b/util.go index de88807..b0576e7 100644 --- a/util.go +++ b/util.go @@ -15,6 +15,8 @@ package nftables import ( + "encoding/binary" + "github.com/google/nftables/binaryutil" "golang.org/x/sys/unix" ) @@ -25,3 +27,20 @@ func extraHeader(family uint8, resID uint16) []byte { unix.NFNETLINK_V0, }, binaryutil.BigEndian.PutUint16(resID)...) } + +// General form of address family dependent message, see +// https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nfnetlink.h#29 +type NFGenMsg struct { + NFGenFamily uint8 + Version uint8 + ResourceID uint16 +} + +func (genmsg *NFGenMsg) Decode(b []byte) { + if len(b) < 16 { + return + } + genmsg.NFGenFamily = b[0] + genmsg.Version = b[1] + genmsg.ResourceID = binary.BigEndian.Uint16(b[2:]) +}