diff --git a/chain.go b/chain.go index 9928d63..e1bda29 100644 --- a/chain.go +++ b/chain.go @@ -223,9 +223,10 @@ func (cc *Conn) ListChainsOfTableFamily(family TableFamily) ([]*Chain, error) { } func chainFromMsg(msg netlink.Message) (*Chain, error) { - chainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN) - if got, want := msg.Header.Type, chainHeaderType; got != want { - return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + newChainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN) + delChainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN) + if got, want1, want2 := msg.Header.Type, newChainHeaderType, delChainHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) } var c Chain diff --git a/go.mod b/go.mod index 34c332a..a247513 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/google/nftables -go 1.17 +go 1.18 require ( github.com/mdlayher/netlink v1.7.1 diff --git a/monitor.go b/monitor.go new file mode 100644 index 0000000..cb3ac1c --- /dev/null +++ b/monitor.go @@ -0,0 +1,309 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "math" + "strings" + "sync" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type MonitorAction uint8 + +// Possible MonitorAction values. +const ( + MonitorActionNew MonitorAction = 1 << iota + MonitorActionDel + MonitorActionMask MonitorAction = (1 << iota) - 1 + MonitorActionAny MonitorAction = MonitorActionMask +) + +type MonitorObject uint32 + +// Possible MonitorObject values. +const ( + MonitorObjectTables MonitorObject = 1 << iota + MonitorObjectChains + MonitorObjectSets + MonitorObjectRules + MonitorObjectElements + MonitorObjectRuleset + MonitorObjectMask MonitorObject = (1 << iota) - 1 + MonitorObjectAny MonitorObject = MonitorObjectMask +) + +var ( + monitorFlags = map[MonitorAction]map[MonitorObject]uint32{ + MonitorActionAny: { + MonitorObjectAny: 0xffffffff, + MonitorObjectTables: 1<>8 != netlink.HeaderType(unix.NFNL_SUBSYS_NFTABLES) { + continue + } + msgType := msg.Header.Type & 0x00ff + if monitor.monitorFlags&1< reg 1 + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + // cmp eq reg 1 0x0245a8c0 + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: net.ParseIP("192.168.69.2").To4(), + }, + + // masq + &expr.Masq{}, + }, + }) + + if err := c.Flush(); err != nil { + t.Fatal(err) + } + wg.Wait() + if gotTable.Family != nat.Family || gotTable.Name != nat.Name { + t.Fatal("no want table", gotTable.Family, gotTable.Name) + } + if gotChain.Type != postrouting.Type || gotChain.Name != postrouting.Name || + *gotChain.Hooknum != *postrouting.Hooknum { + t.Fatal("no want chain", gotChain.Type, gotChain.Name, gotChain.Hooknum) + } + if len(gotRule.Exprs) != len(rule.Exprs) { + t.Fatal("no want rule") + } +} diff --git a/obj.go b/obj.go index e975bb8..50f83f4 100644 --- a/obj.go +++ b/obj.go @@ -22,7 +22,10 @@ import ( "golang.org/x/sys/unix" ) -var objHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ) +var ( + newObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ) + delObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ) +) // Obj represents a netfilter stateful object. See also // https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects @@ -125,8 +128,8 @@ func (cc *Conn) ResetObjects(t *Table) ([]Obj, error) { } func objFromMsg(msg netlink.Message) (Obj, error) { - if got, want := msg.Header.Type, objHeaderType; got != want { - return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + if got, want1, want2 := msg.Header.Type, newObjHeaderType, delObjHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) } ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { diff --git a/rule.go b/rule.go index 0ae9a53..8bcfda1 100644 --- a/rule.go +++ b/rule.go @@ -25,7 +25,10 @@ import ( "golang.org/x/sys/unix" ) -var ruleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE) +var ( + newRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE) + delRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE) +) type ruleOperation uint32 @@ -168,7 +171,7 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ - Type: ruleHeaderType, + Type: newRuleHeaderType, Flags: flags, }, Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), @@ -215,7 +218,7 @@ func (cc *Conn) DelRule(r *Rule) error { cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), + Type: delRuleHeaderType, Flags: flags, }, Data: append(extraHeader(uint8(r.Table.Family), 0), data...), @@ -225,8 +228,8 @@ func (cc *Conn) DelRule(r *Rule) error { } func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) { - if got, want := msg.Header.Type, ruleHeaderType; got != want { - return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + if got, want1, want2 := msg.Header.Type, newRuleHeaderType, delRuleHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", msg.Header.Type, want1, want2) } ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { diff --git a/set.go b/set.go index c267659..192c619 100644 --- a/set.go +++ b/set.go @@ -684,11 +684,14 @@ func (cc *Conn) FlushSet(s *Set) { }) } -var setHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET) +var ( + newSetHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET) + delSetHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET) +) func setsFromMsg(msg netlink.Message) (*Set, error) { - if got, want := msg.Header.Type, setHeaderType; got != want { - return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + if got, want1, want2 := msg.Header.Type, newSetHeaderType, delSetHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) } ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { @@ -762,11 +765,14 @@ func parseSetDatatype(magic uint32) (SetDatatype, error) { return dt, nil } -var elemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM) +var ( + newElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM) + delElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM) +) func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) { - if got, want := msg.Header.Type, elemHeaderType; got != want { - return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + if got, want1, want2 := msg.Header.Type, newElemHeaderType, delElemHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) } ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { diff --git a/table.go b/table.go index 24782ed..ff3b592 100644 --- a/table.go +++ b/table.go @@ -21,7 +21,10 @@ import ( "golang.org/x/sys/unix" ) -var tableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE) +var ( + newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE) + delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE) +) // TableFamily specifies the address family for this table. type TableFamily byte @@ -150,8 +153,8 @@ func (cc *Conn) ListTablesOfFamily(family TableFamily) ([]*Table, error) { } func tableFromMsg(msg netlink.Message) (*Table, error) { - if got, want := msg.Header.Type, tableHeaderType; got != want { - return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + if got, want1, want2 := msg.Header.Type, newTableHeaderType, delTableHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) } var t Table diff --git a/util.go b/util.go index de88807..b0576e7 100644 --- a/util.go +++ b/util.go @@ -15,6 +15,8 @@ package nftables import ( + "encoding/binary" + "github.com/google/nftables/binaryutil" "golang.org/x/sys/unix" ) @@ -25,3 +27,20 @@ func extraHeader(family uint8, resID uint16) []byte { unix.NFNETLINK_V0, }, binaryutil.BigEndian.PutUint16(resID)...) } + +// General form of address family dependent message, see +// https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nfnetlink.h#29 +type NFGenMsg struct { + NFGenFamily uint8 + Version uint8 + ResourceID uint16 +} + +func (genmsg *NFGenMsg) Decode(b []byte) { + if len(b) < 16 { + return + } + genmsg.NFGenFamily = b[0] + genmsg.Version = b[1] + genmsg.ResourceID = binary.BigEndian.Uint16(b[2:]) +}