// Copyright 2018 Google LLC. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package nftables import ( "encoding/binary" "fmt" "math" "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" ) // ChainHook specifies at which step in packet processing the Chain should be // executed. See also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_hooks type ChainHook uint32 // Possible ChainHook values. var ( ChainHookPrerouting *ChainHook = ChainHookRef(unix.NF_INET_PRE_ROUTING) ChainHookInput *ChainHook = ChainHookRef(unix.NF_INET_LOCAL_IN) ChainHookForward *ChainHook = ChainHookRef(unix.NF_INET_FORWARD) ChainHookOutput *ChainHook = ChainHookRef(unix.NF_INET_LOCAL_OUT) ChainHookPostrouting *ChainHook = ChainHookRef(unix.NF_INET_POST_ROUTING) ChainHookIngress *ChainHook = ChainHookRef(unix.NF_NETDEV_INGRESS) ChainHookEgress *ChainHook = ChainHookRef(unix.NF_NETDEV_EGRESS) ) // ChainHookRef returns a pointer to a ChainHookRef value. func ChainHookRef(h ChainHook) *ChainHook { return &h } // ChainPriority orders the chain relative to Netfilter internal operations. See // also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_priority type ChainPriority int32 // Possible ChainPriority values. var ( // from /usr/include/linux/netfilter_ipv4.h ChainPriorityFirst *ChainPriority = ChainPriorityRef(math.MinInt32) ChainPriorityConntrackDefrag *ChainPriority = ChainPriorityRef(-400) ChainPriorityRaw *ChainPriority = ChainPriorityRef(-300) ChainPrioritySELinuxFirst *ChainPriority = ChainPriorityRef(-225) ChainPriorityConntrack *ChainPriority = ChainPriorityRef(-200) ChainPriorityMangle *ChainPriority = ChainPriorityRef(-150) ChainPriorityNATDest *ChainPriority = ChainPriorityRef(-100) ChainPriorityFilter *ChainPriority = ChainPriorityRef(0) ChainPrioritySecurity *ChainPriority = ChainPriorityRef(50) ChainPriorityNATSource *ChainPriority = ChainPriorityRef(100) ChainPrioritySELinuxLast *ChainPriority = ChainPriorityRef(225) ChainPriorityConntrackHelper *ChainPriority = ChainPriorityRef(300) ChainPriorityConntrackConfirm *ChainPriority = ChainPriorityRef(math.MaxInt32) ChainPriorityLast *ChainPriority = ChainPriorityRef(math.MaxInt32) ) // ChainPriorityRef returns a pointer to a ChainPriority value. func ChainPriorityRef(p ChainPriority) *ChainPriority { return &p } // ChainType defines what this chain will be used for. See also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_types type ChainType string // Possible ChainType values. const ( ChainTypeFilter ChainType = "filter" ChainTypeRoute ChainType = "route" ChainTypeNAT ChainType = "nat" ) // ChainPolicy defines what this chain default policy will be. type ChainPolicy uint32 // Possible ChainPolicy values. const ( ChainPolicyDrop ChainPolicy = iota ChainPolicyAccept ) // A Chain contains Rules. See also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains type Chain struct { Name string Table *Table Hooknum *ChainHook Priority *ChainPriority Type ChainType Policy *ChainPolicy Device string } // AddChain adds the specified Chain. See also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_base_chains func (cc *Conn) AddChain(c *Chain) *Chain { cc.mu.Lock() defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_CHAIN_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, }) if c.Hooknum != nil && c.Priority != nil { hookAttr := []netlink.Attribute{ {Type: unix.NFTA_HOOK_HOOKNUM, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Hooknum))}, {Type: unix.NFTA_HOOK_PRIORITY, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Priority))}, } if c.Device != "" { hookAttr = append(hookAttr, netlink.Attribute{Type: unix.NFTA_HOOK_DEV, Data: []byte(c.Device + "\x00")}) } data = append(data, cc.marshalAttr([]netlink.Attribute{ {Type: unix.NLA_F_NESTED | unix.NFTA_CHAIN_HOOK, Data: cc.marshalAttr(hookAttr)}, })...) } if c.Policy != nil { data = append(data, cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_CHAIN_POLICY, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Policy))}, })...) } if c.Type != "" { data = append(data, cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")}, })...) } cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, }, Data: append(extraHeader(uint8(c.Table.Family), 0), data...), }) return c } // DelChain deletes the specified Chain. See also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Deleting_chains func (cc *Conn) DelChain(c *Chain) { cc.mu.Lock() defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_CHAIN_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, }) cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN), Flags: netlink.Request | netlink.Acknowledge, }, Data: append(extraHeader(uint8(c.Table.Family), 0), data...), }) } // FlushChain removes all rules within the specified Chain. See also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Flushing_chain func (cc *Conn) FlushChain(c *Chain) { cc.mu.Lock() defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, }) cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge, }, Data: append(extraHeader(uint8(c.Table.Family), 0), data...), }) } // ListChains returns currently configured chains in the kernel func (cc *Conn) ListChains() ([]*Chain, error) { return cc.ListChainsOfTableFamily(TableFamilyUnspecified) } // ListChain returns a single chain configured in the specified table func (cc *Conn) ListChain(table *Table, chain string) (*Chain, error) { conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } defer func() { _ = closer() }() attrs := []netlink.Attribute{ {Type: unix.NFTA_TABLE_NAME, Data: []byte(table.Name + "\x00")}, {Type: unix.NFTA_CHAIN_NAME, Data: []byte(chain + "\x00")}, } msg := netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN), Flags: netlink.Request, }, Data: append(extraHeader(uint8(table.Family), 0), cc.marshalAttr(attrs)...), } response, err := conn.Execute(msg) if err != nil { return nil, fmt.Errorf("conn.Execute failed: %v", err) } if got, want := len(response), 1; got != want { return nil, fmt.Errorf("expected %d response message for chain, got %d", want, got) } ch, err := chainFromMsg(response[0]) if err != nil { return nil, err } return ch, nil } // ListChainsOfTableFamily returns currently configured chains for the specified // family in the kernel. It lists all chains ins all tables if family is // TableFamilyUnspecified. func (cc *Conn) ListChainsOfTableFamily(family TableFamily) ([]*Chain, error) { conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } defer func() { _ = closer() }() msg := netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN), Flags: netlink.Request | netlink.Dump, }, Data: extraHeader(uint8(family), 0), } response, err := conn.Execute(msg) if err != nil { return nil, err } var chains []*Chain for _, m := range response { c, err := chainFromMsg(m) if err != nil { return nil, err } chains = append(chains, c) } return chains, nil } func chainFromMsg(msg netlink.Message) (*Chain, error) { newChainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN) delChainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN) if got, want1, want2 := msg.Header.Type, newChainHeaderType, delChainHeaderType; got != want1 && got != want2 { return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) } var c Chain ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { return nil, err } for ad.Next() { switch ad.Type() { case unix.NFTA_CHAIN_NAME: c.Name = ad.String() case unix.NFTA_TABLE_NAME: c.Table = &Table{Name: ad.String()} // msg[0] carries TableFamily byte indicating whether it is IPv4, IPv6 or something else c.Table.Family = TableFamily(msg.Data[0]) case unix.NFTA_CHAIN_TYPE: c.Type = ChainType(ad.String()) case unix.NFTA_CHAIN_POLICY: policy := ChainPolicy(binaryutil.BigEndian.Uint32(ad.Bytes())) c.Policy = &policy case unix.NFTA_CHAIN_HOOK: ad.Do(func(b []byte) error { c.Hooknum, c.Priority, err = hookFromMsg(b) return err }) } } return &c, nil } func hookFromMsg(b []byte) (*ChainHook, *ChainPriority, error) { ad, err := netlink.NewAttributeDecoder(b) if err != nil { return nil, nil, err } ad.ByteOrder = binary.BigEndian var hooknum ChainHook var prio ChainPriority for ad.Next() { switch ad.Type() { case unix.NFTA_HOOK_HOOKNUM: hooknum = ChainHook(ad.Uint32()) case unix.NFTA_HOOK_PRIORITY: prio = ChainPriority(ad.Uint32()) } } return &hooknum, &prio, nil }