diff --git a/nftables_test.go b/nftables_test.go index 8bfff70..89d4052 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -7560,3 +7560,229 @@ func TestFlushWithGenID(t *testing.T) { t.Errorf("expected table to not exist, got: %v", table) } } + +func TestGetRuleByHandle(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + defer conn.FlushRuleset() + + table := conn.AddTable(&nftables.Table{ + Name: "test-table", + Family: nftables.TableFamilyIPv4, + }) + + chain := conn.AddChain(&nftables.Chain{ + Name: "test-chain", + Table: table, + }) + + for i := range 3 { + conn.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + UserData: fmt.Appendf([]byte{}, "rule-%d", i+1), + Exprs: []expr.Any{ + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + }) + } + + if err := conn.Flush(); err != nil { + t.Fatalf("failed to flush: %v", err) + } + + rules, err := conn.GetRules(table, chain) + if err != nil { + t.Fatalf("GetRules failed: %v", err) + } + + want := rules[1] + + got, err := conn.GetRuleByHandle(table, chain, want.Handle) + if err != nil { + t.Fatalf("GetRuleByHandle failed: %v", err) + } + if !bytes.Equal(got.UserData, want.UserData) { + t.Fatalf("expected userdata %q, got %q", got.UserData, want.UserData) + } +} + +func TestResetRule(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + defer conn.FlushRuleset() + + table := conn.AddTable(&nftables.Table{ + Name: "test-table", + Family: nftables.TableFamilyIPv4, + }) + + chain := conn.AddChain(&nftables.Chain{ + Name: "test-chain", + Table: table, + }) + + tests := [...]struct { + Bytes uint64 + Packets uint64 + Reset bool + }{ + { + Bytes: 1024, + Packets: 1, + Reset: false, + }, + { + Bytes: 2048, + Packets: 2, + Reset: true, + }, + { + Bytes: 4096, + Packets: 4, + Reset: false, + }, + } + + for _, tt := range tests { + conn.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Counter{ + Bytes: tt.Bytes, + Packets: tt.Packets, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + }) + } + + if err := conn.Flush(); err != nil { + t.Fatalf("flush failed: %v", err) + } + + rules, err := conn.GetRules(table, chain) + if err != nil { + t.Fatalf("GetRules failed: %v", err) + } + + if len(rules) != len(tests) { + t.Fatalf("expected %d rules, got %d", len(tests), len(rules)) + } + + for i, r := range rules { + if !tests[i].Reset { + continue + } + _, err := conn.ResetRule(table, chain, r.Handle) + if err != nil { + t.Fatalf("ResetRule failed: %v", err) + } + } + + rules, err = conn.GetRules(table, chain) + if err != nil { + t.Fatalf("GetRules failed: %v", err) + } + + for i, r := range rules { + counter, ok := r.Exprs[0].(*expr.Counter) + if !ok { + t.Errorf("expected first expr to be Counter, got %T", r.Exprs[0]) + } + + if tests[i].Reset { + if counter.Bytes != 0 || counter.Packets != 0 { + t.Errorf( + "expected counter values to be reset to zero, got Bytes=%d, Packets=%d", + counter.Bytes, + counter.Packets, + ) + } + } else { + // Making sure that only the selected rules were reset + if counter.Bytes != tests[i].Bytes || counter.Packets != tests[i].Packets { + t.Errorf( + "unexpected counter values: got Bytes=%d, Packets=%d, want Bytes=%d, Packets=%d", + counter.Bytes, + counter.Packets, + tests[i].Bytes, + tests[i].Packets) + } + } + } +} + +func TestResetRules(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + defer conn.FlushRuleset() + + table := conn.AddTable(&nftables.Table{ + Name: "test-table", + Family: nftables.TableFamilyIPv4, + }) + + chain := conn.AddChain(&nftables.Chain{ + Name: "test-chain", + Table: table, + }) + + for range 3 { + conn.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Counter{ + Bytes: 1, + Packets: 1, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + }) + } + + if err := conn.Flush(); err != nil { + t.Fatalf("flush failed: %v", err) + } + + rules, err := conn.GetRules(table, chain) + if err != nil { + t.Fatalf("GetRules failed: %v", err) + } + + if len(rules) != 3 { + t.Fatalf("expected %d rules, got %d", 3, len(rules)) + } + + if _, err := conn.ResetRules(table, chain); err != nil { + t.Fatalf("ResetRules failed: %v", err) + } + + rules, err = conn.GetRules(table, chain) + if err != nil { + t.Fatalf("GetRules failed: %v", err) + } + + for _, r := range rules { + counter, ok := r.Exprs[0].(*expr.Counter) + if !ok { + t.Errorf("expected first expr to be Counter, got %T", r.Exprs[0]) + } + + if counter.Bytes != 0 || counter.Packets != 0 { + t.Errorf( + "expected counter values to be reset to zero, got Bytes=%d, Packets=%d", + counter.Bytes, + counter.Packets, + ) + } + } +} diff --git a/rule.go b/rule.go index 10958c6..caa56e8 100644 --- a/rule.go +++ b/rule.go @@ -71,31 +71,98 @@ type Rule struct { // GetRule returns the rules in the specified table and chain. // -// Deprecated: use GetRules instead. +// Deprecated: use GetRuleByHandle instead. func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) { return cc.GetRules(t, c) } +// GetRuleByHandle returns the rule in the specified table and chain by its +// handle. +// https://docs.kernel.org/networking/netlink_spec/nftables.html#getrule +func (cc *Conn) GetRuleByHandle(t *Table, c *Chain, handle uint64) (*Rule, error) { + rules, err := cc.getRules(t, c, unix.NFT_MSG_GETRULE, handle) + if err != nil { + return nil, err + } + + if got, want := len(rules), 1; got != want { + return nil, fmt.Errorf("expected rule count %d, got %d", want, got) + } + + return rules[0], nil +} + // GetRules returns the rules in the specified table and chain. func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) { + return cc.getRules(t, c, unix.NFT_MSG_GETRULE, 0) +} + +// ResetRule resets the stateful expressions (e.g., counters) of the given +// rule. The reset is applied immediately (no Flush is required). The returned +// rule reflects its state prior to the reset. The provided rule must have a +// valid Handle. +// https://docs.kernel.org/networking/netlink_spec/nftables.html#getrule-reset +func (cc *Conn) ResetRule(t *Table, c *Chain, handle uint64) (*Rule, error) { + if handle == 0 { + return nil, fmt.Errorf("rule must have a valid handle") + } + + rules, err := cc.getRules(t, c, unix.NFT_MSG_GETRULE_RESET, handle) + if err != nil { + return nil, err + } + + if got, want := len(rules), 1; got != want { + return nil, fmt.Errorf("expected rule count %d, got %d", want, got) + } + + return rules[0], nil +} + +// ResetRules resets the stateful expressions (e.g., counters) of all rules +// in the given table and chain. The reset is applied immediately (no Flush +// is required). The returned rules reflect their state prior to the reset. +// state. +// https://docs.kernel.org/networking/netlink_spec/nftables.html#getrule-reset +func (cc *Conn) ResetRules(t *Table, c *Chain) ([]*Rule, error) { + return cc.getRules(t, c, unix.NFT_MSG_GETRULE_RESET, 0) +} + +// getRules retrieves rules from the given table and chain, using the provided +// msgType (either unix.NFT_MSG_GETRULE or unix.NFT_MSG_GETRULE_RESET). If the +// handle is non-zero, the operation applies only to the rule with that handle. +func (cc *Conn) getRules(t *Table, c *Chain, msgType int, handle uint64) ([]*Rule, error) { conn, closer, err := cc.netlinkConn() if err != nil { return nil, err } defer func() { _ = closer() }() - data, err := netlink.MarshalAttributes([]netlink.Attribute{ + attrs := []netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, - }) + } + + var flags netlink.HeaderFlags = netlink.Request | netlink.Acknowledge | netlink.Dump + + if handle != 0 { + attrs = append(attrs, netlink.Attribute{ + Type: unix.NFTA_RULE_HANDLE, + Data: binaryutil.BigEndian.PutUint64(handle), + }) + + flags = netlink.Request | netlink.Acknowledge + } + + data, err := netlink.MarshalAttributes(attrs) if err != nil { return nil, err } message := netlink.Message{ Header: netlink.Header{ - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE), - Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | msgType), + Flags: flags, }, Data: append(extraHeader(uint8(t.Family), 0), data...), }