diff --git a/nftables_test.go b/nftables_test.go index df5de51..fb42e96 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -624,6 +624,154 @@ func TestTProxy(t *testing.T) { } } +func TestAddRuleWithPosition(t *testing.T) { + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add rule ip ipv4table ipv4chain-1 position 2 ip version 6 + []byte("\x02\x00\x00\x00\x0e\x00\x01\x00\x69\x70\x76\x34\x74\x61\x62\x6c\x65\x00\x00\x00\x10\x00\x02\x00\x69\x70\x76\x34\x63\x68\x61\x69\x6e\x2d\x31\x00\xa8\x00\x04\x80\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x00\x08\x00\x04\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x01\x0c\x00\x04\x80\x05\x00\x01\x00\xf0\x00\x00\x00\x0c\x00\x05\x80\x05\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x05\x00\x01\x00\x60\x00\x00\x00\x0c\x00\x06\x00\x00\x00\x00\x00\x00\x00\x00\x02"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + + c := &nftables.Conn{ + TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + want = want[1:] + } + return req, nil + }, + } + + c.AddRule(&nftables.Rule{ + Position: 2, + Table: &nftables.Table{Name: "ipv4table", Family: nftables.TableFamilyIPv4}, + Chain: &nftables.Chain{ + Name: "ipv4chain-1", + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: 0, + }, + + Exprs: []expr.Any{ + // [ payload load 1b @ network header + 0 => reg 1 ] + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 0, // Offset for a transport protocol header + Len: 1, // 1 bytes for port + }, + // [ bitwise reg 1 = (reg=1 & 0x000000f0 ) ^ 0x00000000 ] + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 1, + Mask: []byte{0xf0}, + Xor: []byte{0x0}, + }, + // [ cmp eq reg 1 0x00000060 ] + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{(0x6 << 4)}, + }, + }, + }) + + if err := c.Flush(); err != nil { + t.Fatal(err) + } +} + +func TestAddRuleRuleID(t *testing.T) { + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add rule ip ipv4table ipv4chain-1 ip version 6 comment \0x01 + []byte("\x02\x00\x00\x00\x0e\x00\x01\x00\x69\x70\x76\x34\x74\x61\x62\x6c\x65\x00\x00\x00\x10\x00\x02\x00\x69\x70\x76\x34\x63\x68\x61\x69\x6e\x2d\x31\x00\xa8\x00\x04\x80\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x00\x08\x00\x04\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x01\x0c\x00\x04\x80\x05\x00\x01\x00\xf0\x00\x00\x00\x0c\x00\x05\x80\x05\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x05\x00\x01\x00\x60\x00\x00\x00\x08\x00\x07\x00\x00\x00\x00\x01"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + + c := &nftables.Conn{ + TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + want = want[1:] + } + return req, nil + }, + } + + c.AddRule(&nftables.Rule{ + RuleID: 1, + Table: &nftables.Table{Name: "ipv4table", Family: nftables.TableFamilyIPv4}, + Chain: &nftables.Chain{ + Name: "ipv4chain-1", + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: 0, + }, + + Exprs: []expr.Any{ + // [ payload load 1b @ network header + 0 => reg 1 ] + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 0, // Offset for a transport protocol header + Len: 1, // 1 bytes for port + }, + // [ bitwise reg 1 = (reg=1 & 0x000000f0 ) ^ 0x00000000 ] + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 1, + Mask: []byte{0xf0}, + Xor: []byte{0x0}, + }, + // [ cmp eq reg 1 0x00000060 ] + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{(0x6 << 4)}, + }, + }, + }) + + if err := c.Flush(); err != nil { + t.Fatal(err) + } +} + func TestAddChain(t *testing.T) { tests := []struct { name string diff --git a/rule.go b/rule.go index 8f61046..7775942 100644 --- a/rule.go +++ b/rule.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "fmt" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" @@ -28,9 +29,12 @@ var ruleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix. // A Rule does something with a packet. See also // https://wiki.nftables.org/wiki-nftables/index.php/Simple_rule_management type Rule struct { - Table *Table - Chain *Chain - Exprs []expr.Any + Table *Table + Chain *Chain + RuleID uint32 + Position uint64 + Handle uint64 + Exprs []expr.Any } // GetRule returns the rules in the specified table and chain. @@ -52,7 +56,7 @@ func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) { message := netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE), - Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump | unix.NLM_F_ECHO, }, Data: append(extraHeader(uint8(t.Family), 0), data...), } @@ -77,8 +81,51 @@ func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) { return rules, nil } -// AddRule adds the specified Rule. See also -// https://wiki.nftables.org/wiki-nftables/index.php/Simple_rule_management +// GetRuleHandle returns a specific rule's handle. Rule is identified by Table, Chain and RuleID. +func (cc *Conn) GetRuleHandle(t *Table, c *Chain, ruleID uint32) (uint64, error) { + conn, err := cc.dialNetlink() + if err != nil { + return 0, err + } + defer conn.Close() + if ruleID == 0 { + return 0, fmt.Errorf("rule's id cannot be 0") + } + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, + {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, + {Type: unix.NFTA_RULE_USERDATA, Data: binaryutil.BigEndian.PutUint32(ruleID)}, + }) + if err != nil { + return 0, err + } + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE), + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump | unix.NLM_F_ECHO, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + } + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return 0, fmt.Errorf("SendMessages: %v", err) + } + reply, err := conn.Receive() + if err != nil { + return 0, fmt.Errorf("Receive: %v", err) + } + if len(reply) != 1 { + return 0, fmt.Errorf("Receive: Expected 1 message but got %d", len(reply)) + } + rr, err := ruleFromMsg(reply[0]) + if err != nil { + return 0, err + } + + return rr.Handle, nil +} + +// AddRule adds the specified Rule func (cc *Conn) AddRule(r *Rule) *Rule { exprAttrs := make([]netlink.Attribute, len(r.Exprs)) for idx, expr := range r.Exprs { @@ -87,19 +134,35 @@ func (cc *Conn) AddRule(r *Rule) *Rule { Data: cc.marshalExpr(expr), } } - data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(r.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(r.Chain.Name + "\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_RULE_EXPRESSIONS, Data: cc.marshalAttr(exprAttrs)}, }) - + msgData := []byte{} + msgData = append(msgData, data...) + var flags netlink.HeaderFlags + if r.RuleID != 0 { + msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_USERDATA, Data: binaryutil.BigEndian.PutUint32(r.RuleID)}, + })...) + } + if r.Position != 0 { + msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_POSITION, Data: binaryutil.BigEndian.PutUint64(r.Position)}, + })...) + // when a rule's position is specified, it becomes nft insert rule operation + flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO + } else { + // unix.NLM_F_APPEND is added when nft add rule operation is executed. + flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND + } cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: ruleHeaderType, - Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + Flags: flags, }, - Data: append(extraHeader(uint8(r.Table.Family), 0), data...), + Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), }) return r @@ -191,6 +254,12 @@ func ruleFromMsg(msg netlink.Message) (*Rule, error) { r.Exprs, err = exprsFromMsg(b) return err }) + case unix.NFTA_RULE_POSITION: + r.Position = ad.Uint64() + case unix.NFTA_RULE_HANDLE: + r.Handle = ad.Uint64() + case unix.NFTA_RULE_USERDATA: + r.RuleID = ad.Uint32() } } return &r, ad.Err()