Add GetRuleByHandle, ResetRule & ResetRules methods (#326)
This commit introduces the following methods: - GetRuleByHandle - ResetRule - ResetRules It also refactors GetRules and the deprecated GetRule methods to share a common getRules implementation.
This commit is contained in:
parent
1148f1a84f
commit
ba5b671e14
226
nftables_test.go
226
nftables_test.go
|
@ -7560,3 +7560,229 @@ func TestFlushWithGenID(t *testing.T) {
|
||||||
t.Errorf("expected table to not exist, got: %v", table)
|
t.Errorf("expected table to not exist, got: %v", table)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetRuleByHandle(t *testing.T) {
|
||||||
|
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||||||
|
defer nftest.CleanupSystemConn(t, newNS)
|
||||||
|
defer conn.FlushRuleset()
|
||||||
|
|
||||||
|
table := conn.AddTable(&nftables.Table{
|
||||||
|
Name: "test-table",
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
})
|
||||||
|
|
||||||
|
chain := conn.AddChain(&nftables.Chain{
|
||||||
|
Name: "test-chain",
|
||||||
|
Table: table,
|
||||||
|
})
|
||||||
|
|
||||||
|
for i := range 3 {
|
||||||
|
conn.AddRule(&nftables.Rule{
|
||||||
|
Table: table,
|
||||||
|
Chain: chain,
|
||||||
|
UserData: fmt.Appendf([]byte{}, "rule-%d", i+1),
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.Flush(); err != nil {
|
||||||
|
t.Fatalf("failed to flush: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRules failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := rules[1]
|
||||||
|
|
||||||
|
got, err := conn.GetRuleByHandle(table, chain, want.Handle)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRuleByHandle failed: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got.UserData, want.UserData) {
|
||||||
|
t.Fatalf("expected userdata %q, got %q", got.UserData, want.UserData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResetRule(t *testing.T) {
|
||||||
|
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||||||
|
defer nftest.CleanupSystemConn(t, newNS)
|
||||||
|
defer conn.FlushRuleset()
|
||||||
|
|
||||||
|
table := conn.AddTable(&nftables.Table{
|
||||||
|
Name: "test-table",
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
})
|
||||||
|
|
||||||
|
chain := conn.AddChain(&nftables.Chain{
|
||||||
|
Name: "test-chain",
|
||||||
|
Table: table,
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := [...]struct {
|
||||||
|
Bytes uint64
|
||||||
|
Packets uint64
|
||||||
|
Reset bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Bytes: 1024,
|
||||||
|
Packets: 1,
|
||||||
|
Reset: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Bytes: 2048,
|
||||||
|
Packets: 2,
|
||||||
|
Reset: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Bytes: 4096,
|
||||||
|
Packets: 4,
|
||||||
|
Reset: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
conn.AddRule(&nftables.Rule{
|
||||||
|
Table: table,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Counter{
|
||||||
|
Bytes: tt.Bytes,
|
||||||
|
Packets: tt.Packets,
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.Flush(); err != nil {
|
||||||
|
t.Fatalf("flush failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRules failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rules) != len(tests) {
|
||||||
|
t.Fatalf("expected %d rules, got %d", len(tests), len(rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, r := range rules {
|
||||||
|
if !tests[i].Reset {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, err := conn.ResetRule(table, chain, r.Handle)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResetRule failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err = conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRules failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, r := range rules {
|
||||||
|
counter, ok := r.Exprs[0].(*expr.Counter)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("expected first expr to be Counter, got %T", r.Exprs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if tests[i].Reset {
|
||||||
|
if counter.Bytes != 0 || counter.Packets != 0 {
|
||||||
|
t.Errorf(
|
||||||
|
"expected counter values to be reset to zero, got Bytes=%d, Packets=%d",
|
||||||
|
counter.Bytes,
|
||||||
|
counter.Packets,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Making sure that only the selected rules were reset
|
||||||
|
if counter.Bytes != tests[i].Bytes || counter.Packets != tests[i].Packets {
|
||||||
|
t.Errorf(
|
||||||
|
"unexpected counter values: got Bytes=%d, Packets=%d, want Bytes=%d, Packets=%d",
|
||||||
|
counter.Bytes,
|
||||||
|
counter.Packets,
|
||||||
|
tests[i].Bytes,
|
||||||
|
tests[i].Packets)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResetRules(t *testing.T) {
|
||||||
|
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||||||
|
defer nftest.CleanupSystemConn(t, newNS)
|
||||||
|
defer conn.FlushRuleset()
|
||||||
|
|
||||||
|
table := conn.AddTable(&nftables.Table{
|
||||||
|
Name: "test-table",
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
})
|
||||||
|
|
||||||
|
chain := conn.AddChain(&nftables.Chain{
|
||||||
|
Name: "test-chain",
|
||||||
|
Table: table,
|
||||||
|
})
|
||||||
|
|
||||||
|
for range 3 {
|
||||||
|
conn.AddRule(&nftables.Rule{
|
||||||
|
Table: table,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Counter{
|
||||||
|
Bytes: 1,
|
||||||
|
Packets: 1,
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.Flush(); err != nil {
|
||||||
|
t.Fatalf("flush failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRules failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rules) != 3 {
|
||||||
|
t.Fatalf("expected %d rules, got %d", 3, len(rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.ResetRules(table, chain); err != nil {
|
||||||
|
t.Fatalf("ResetRules failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err = conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRules failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range rules {
|
||||||
|
counter, ok := r.Exprs[0].(*expr.Counter)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("expected first expr to be Counter, got %T", r.Exprs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if counter.Bytes != 0 || counter.Packets != 0 {
|
||||||
|
t.Errorf(
|
||||||
|
"expected counter values to be reset to zero, got Bytes=%d, Packets=%d",
|
||||||
|
counter.Bytes,
|
||||||
|
counter.Packets,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
75
rule.go
75
rule.go
|
@ -71,31 +71,98 @@ type Rule struct {
|
||||||
|
|
||||||
// GetRule returns the rules in the specified table and chain.
|
// GetRule returns the rules in the specified table and chain.
|
||||||
//
|
//
|
||||||
// Deprecated: use GetRules instead.
|
// Deprecated: use GetRuleByHandle instead.
|
||||||
func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) {
|
func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) {
|
||||||
return cc.GetRules(t, c)
|
return cc.GetRules(t, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRuleByHandle returns the rule in the specified table and chain by its
|
||||||
|
// handle.
|
||||||
|
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getrule
|
||||||
|
func (cc *Conn) GetRuleByHandle(t *Table, c *Chain, handle uint64) (*Rule, error) {
|
||||||
|
rules, err := cc.getRules(t, c, unix.NFT_MSG_GETRULE, handle)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := len(rules), 1; got != want {
|
||||||
|
return nil, fmt.Errorf("expected rule count %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetRules returns the rules in the specified table and chain.
|
// GetRules returns the rules in the specified table and chain.
|
||||||
func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
|
func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
|
||||||
|
return cc.getRules(t, c, unix.NFT_MSG_GETRULE, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetRule resets the stateful expressions (e.g., counters) of the given
|
||||||
|
// rule. The reset is applied immediately (no Flush is required). The returned
|
||||||
|
// rule reflects its state prior to the reset. The provided rule must have a
|
||||||
|
// valid Handle.
|
||||||
|
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getrule-reset
|
||||||
|
func (cc *Conn) ResetRule(t *Table, c *Chain, handle uint64) (*Rule, error) {
|
||||||
|
if handle == 0 {
|
||||||
|
return nil, fmt.Errorf("rule must have a valid handle")
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := cc.getRules(t, c, unix.NFT_MSG_GETRULE_RESET, handle)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := len(rules), 1; got != want {
|
||||||
|
return nil, fmt.Errorf("expected rule count %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetRules resets the stateful expressions (e.g., counters) of all rules
|
||||||
|
// in the given table and chain. The reset is applied immediately (no Flush
|
||||||
|
// is required). The returned rules reflect their state prior to the reset.
|
||||||
|
// state.
|
||||||
|
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getrule-reset
|
||||||
|
func (cc *Conn) ResetRules(t *Table, c *Chain) ([]*Rule, error) {
|
||||||
|
return cc.getRules(t, c, unix.NFT_MSG_GETRULE_RESET, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRules retrieves rules from the given table and chain, using the provided
|
||||||
|
// msgType (either unix.NFT_MSG_GETRULE or unix.NFT_MSG_GETRULE_RESET). If the
|
||||||
|
// handle is non-zero, the operation applies only to the rule with that handle.
|
||||||
|
func (cc *Conn) getRules(t *Table, c *Chain, msgType int, handle uint64) ([]*Rule, error) {
|
||||||
conn, closer, err := cc.netlinkConn()
|
conn, closer, err := cc.netlinkConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = closer() }()
|
defer func() { _ = closer() }()
|
||||||
|
|
||||||
data, err := netlink.MarshalAttributes([]netlink.Attribute{
|
attrs := []netlink.Attribute{
|
||||||
{Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
|
{Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
|
||||||
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")},
|
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")},
|
||||||
|
}
|
||||||
|
|
||||||
|
var flags netlink.HeaderFlags = netlink.Request | netlink.Acknowledge | netlink.Dump
|
||||||
|
|
||||||
|
if handle != 0 {
|
||||||
|
attrs = append(attrs, netlink.Attribute{
|
||||||
|
Type: unix.NFTA_RULE_HANDLE,
|
||||||
|
Data: binaryutil.BigEndian.PutUint64(handle),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
flags = netlink.Request | netlink.Acknowledge
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := netlink.MarshalAttributes(attrs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
message := netlink.Message{
|
message := netlink.Message{
|
||||||
Header: netlink.Header{
|
Header: netlink.Header{
|
||||||
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE),
|
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | msgType),
|
||||||
Flags: netlink.Request | netlink.Acknowledge | netlink.Dump,
|
Flags: flags,
|
||||||
},
|
},
|
||||||
Data: append(extraHeader(uint8(t.Family), 0), data...),
|
Data: append(extraHeader(uint8(t.Family), 0), data...),
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue