Compare commits

...

3 Commits

Author SHA1 Message Date
Paul Greenberg 8c564b0097
Merge dae73eaa9c into ba5b671e14 2025-09-12 12:40:01 +02:00
Nick Garlis ba5b671e14
Add GetRuleByHandle, ResetRule & ResetRules methods (#326)
This commit introduces the following methods:

  - GetRuleByHandle
  - ResetRule
  - ResetRules

It also refactors GetRules and the deprecated GetRule methods to share
a common getRules implementation.
2025-09-12 10:09:30 +02:00
Paul Greenberg dae73eaa9c rule: add String() method
Before this commit: the printing of a rule results in
a pointer address.

After this commit: the printing of a rules results in
a human-readable text.

Resolves: #104

Signed-off-by: Paul Greenberg <greenpau@outlook.com>
2020-08-03 10:59:40 -04:00
6 changed files with 368 additions and 5 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
nftables.test

View File

@ -21,4 +21,12 @@ the data types/API will be identified as more functionality is added.
Contributions are very welcome!
### Testing Changes
Run the following commands to test your changes:
```bash
go test ./...
go test -c github.com/google/nftables
sudo ./nftables.test -test.v -run_system_tests
```

View File

@ -24,6 +24,15 @@ import (
"golang.org/x/sys/unix"
)
const (
NFT_DROP = 0
NFT_ACCEPT = 1
NFT_STOLEN = 2
NFT_QUEUE = 3
NFT_REPEAT = 4
NFT_STOP = 5
)
// This code assembles the verdict structure, as expected by the
// nftables netlink API.
// For further information, consult:
@ -129,3 +138,37 @@ func (e *Verdict) unmarshal(fam byte, data []byte) error {
}
return ad.Err()
}
func (e *Verdict) String() string {
var v string
switch e.Kind {
case unix.NFT_RETURN:
v = "return" // -0x5
case unix.NFT_GOTO:
v = "goto" // -0x4
case unix.NFT_JUMP:
v = "jump" // NFT_JUMP = -0x3
case unix.NFT_BREAK:
v = "break" // NFT_BREAK = -0x2
case unix.NFT_CONTINUE:
v = "continue" // NFT_CONTINUE = -0x1
case NFT_DROP:
v = "drop"
case NFT_ACCEPT:
v = "accept"
case NFT_STOLEN:
v = "stolen"
case NFT_QUEUE:
v = "queue"
case NFT_REPEAT:
v = "repeat"
case NFT_STOP:
v = "stop"
default:
v = fmt.Sprintf("verdict %v", e.Kind)
}
if e.Chain != "" {
return v + " " + e.Chain
}
return v
}

View File

@ -307,12 +307,27 @@ func TestRuleOperations(t *testing.T) {
expr.VerdictDrop,
}
wantStrings := []string{
"queue",
"accept",
"queue",
"accept",
"drop",
"drop",
}
for i, r := range rules {
rr, _ := r.Exprs[0].(*expr.Verdict)
if rr.Kind != want[i] {
t.Fatalf("bad verdict kind at %d", i)
}
if rr.String() != wantStrings[i] {
t.Fatalf("bad verdict string at %d: %s (received) vs. %s (expected)", i, rr.String(), wantStrings[i])
}
t.Logf("%s", rr)
}
}
@ -7560,3 +7575,229 @@ func TestFlushWithGenID(t *testing.T) {
t.Errorf("expected table to not exist, got: %v", table)
}
}
func TestGetRuleByHandle(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer conn.FlushRuleset()
table := conn.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyIPv4,
})
chain := conn.AddChain(&nftables.Chain{
Name: "test-chain",
Table: table,
})
for i := range 3 {
conn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
UserData: fmt.Appendf([]byte{}, "rule-%d", i+1),
Exprs: []expr.Any{
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
})
}
if err := conn.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
rules, err := conn.GetRules(table, chain)
if err != nil {
t.Fatalf("GetRules failed: %v", err)
}
want := rules[1]
got, err := conn.GetRuleByHandle(table, chain, want.Handle)
if err != nil {
t.Fatalf("GetRuleByHandle failed: %v", err)
}
if !bytes.Equal(got.UserData, want.UserData) {
t.Fatalf("expected userdata %q, got %q", got.UserData, want.UserData)
}
}
func TestResetRule(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer conn.FlushRuleset()
table := conn.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyIPv4,
})
chain := conn.AddChain(&nftables.Chain{
Name: "test-chain",
Table: table,
})
tests := [...]struct {
Bytes uint64
Packets uint64
Reset bool
}{
{
Bytes: 1024,
Packets: 1,
Reset: false,
},
{
Bytes: 2048,
Packets: 2,
Reset: true,
},
{
Bytes: 4096,
Packets: 4,
Reset: false,
},
}
for _, tt := range tests {
conn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Counter{
Bytes: tt.Bytes,
Packets: tt.Packets,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
})
}
if err := conn.Flush(); err != nil {
t.Fatalf("flush failed: %v", err)
}
rules, err := conn.GetRules(table, chain)
if err != nil {
t.Fatalf("GetRules failed: %v", err)
}
if len(rules) != len(tests) {
t.Fatalf("expected %d rules, got %d", len(tests), len(rules))
}
for i, r := range rules {
if !tests[i].Reset {
continue
}
_, err := conn.ResetRule(table, chain, r.Handle)
if err != nil {
t.Fatalf("ResetRule failed: %v", err)
}
}
rules, err = conn.GetRules(table, chain)
if err != nil {
t.Fatalf("GetRules failed: %v", err)
}
for i, r := range rules {
counter, ok := r.Exprs[0].(*expr.Counter)
if !ok {
t.Errorf("expected first expr to be Counter, got %T", r.Exprs[0])
}
if tests[i].Reset {
if counter.Bytes != 0 || counter.Packets != 0 {
t.Errorf(
"expected counter values to be reset to zero, got Bytes=%d, Packets=%d",
counter.Bytes,
counter.Packets,
)
}
} else {
// Making sure that only the selected rules were reset
if counter.Bytes != tests[i].Bytes || counter.Packets != tests[i].Packets {
t.Errorf(
"unexpected counter values: got Bytes=%d, Packets=%d, want Bytes=%d, Packets=%d",
counter.Bytes,
counter.Packets,
tests[i].Bytes,
tests[i].Packets)
}
}
}
}
func TestResetRules(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer conn.FlushRuleset()
table := conn.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyIPv4,
})
chain := conn.AddChain(&nftables.Chain{
Name: "test-chain",
Table: table,
})
for range 3 {
conn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Counter{
Bytes: 1,
Packets: 1,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
})
}
if err := conn.Flush(); err != nil {
t.Fatalf("flush failed: %v", err)
}
rules, err := conn.GetRules(table, chain)
if err != nil {
t.Fatalf("GetRules failed: %v", err)
}
if len(rules) != 3 {
t.Fatalf("expected %d rules, got %d", 3, len(rules))
}
if _, err := conn.ResetRules(table, chain); err != nil {
t.Fatalf("ResetRules failed: %v", err)
}
rules, err = conn.GetRules(table, chain)
if err != nil {
t.Fatalf("GetRules failed: %v", err)
}
for _, r := range rules {
counter, ok := r.Exprs[0].(*expr.Counter)
if !ok {
t.Errorf("expected first expr to be Counter, got %T", r.Exprs[0])
}
if counter.Bytes != 0 || counter.Packets != 0 {
t.Errorf(
"expected counter values to be reset to zero, got Bytes=%d, Packets=%d",
counter.Bytes,
counter.Packets,
)
}
}
}

3
nftables_test.sh Executable file
View File

@ -0,0 +1,3 @@
go test ./...
go test -c github.com/google/nftables
sudo ./nftables.test -test.v -run_system_tests

75
rule.go
View File

@ -71,31 +71,98 @@ type Rule struct {
// GetRule returns the rules in the specified table and chain.
//
// Deprecated: use GetRules instead.
// Deprecated: use GetRuleByHandle instead.
func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) {
return cc.GetRules(t, c)
}
// GetRuleByHandle returns the rule in the specified table and chain by its
// handle.
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getrule
func (cc *Conn) GetRuleByHandle(t *Table, c *Chain, handle uint64) (*Rule, error) {
rules, err := cc.getRules(t, c, unix.NFT_MSG_GETRULE, handle)
if err != nil {
return nil, err
}
if got, want := len(rules), 1; got != want {
return nil, fmt.Errorf("expected rule count %d, got %d", want, got)
}
return rules[0], nil
}
// GetRules returns the rules in the specified table and chain.
func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
return cc.getRules(t, c, unix.NFT_MSG_GETRULE, 0)
}
// ResetRule resets the stateful expressions (e.g., counters) of the given
// rule. The reset is applied immediately (no Flush is required). The returned
// rule reflects its state prior to the reset. The provided rule must have a
// valid Handle.
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getrule-reset
func (cc *Conn) ResetRule(t *Table, c *Chain, handle uint64) (*Rule, error) {
if handle == 0 {
return nil, fmt.Errorf("rule must have a valid handle")
}
rules, err := cc.getRules(t, c, unix.NFT_MSG_GETRULE_RESET, handle)
if err != nil {
return nil, err
}
if got, want := len(rules), 1; got != want {
return nil, fmt.Errorf("expected rule count %d, got %d", want, got)
}
return rules[0], nil
}
// ResetRules resets the stateful expressions (e.g., counters) of all rules
// in the given table and chain. The reset is applied immediately (no Flush
// is required). The returned rules reflect their state prior to the reset.
// state.
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getrule-reset
func (cc *Conn) ResetRules(t *Table, c *Chain) ([]*Rule, error) {
return cc.getRules(t, c, unix.NFT_MSG_GETRULE_RESET, 0)
}
// getRules retrieves rules from the given table and chain, using the provided
// msgType (either unix.NFT_MSG_GETRULE or unix.NFT_MSG_GETRULE_RESET). If the
// handle is non-zero, the operation applies only to the rule with that handle.
func (cc *Conn) getRules(t *Table, c *Chain, msgType int, handle uint64) ([]*Rule, error) {
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer func() { _ = closer() }()
data, err := netlink.MarshalAttributes([]netlink.Attribute{
attrs := []netlink.Attribute{
{Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")},
}
var flags netlink.HeaderFlags = netlink.Request | netlink.Acknowledge | netlink.Dump
if handle != 0 {
attrs = append(attrs, netlink.Attribute{
Type: unix.NFTA_RULE_HANDLE,
Data: binaryutil.BigEndian.PutUint64(handle),
})
flags = netlink.Request | netlink.Acknowledge
}
data, err := netlink.MarshalAttributes(attrs)
if err != nil {
return nil, err
}
message := netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE),
Flags: netlink.Request | netlink.Acknowledge | netlink.Dump,
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | msgType),
Flags: flags,
},
Data: append(extraHeader(uint8(t.Family), 0), data...),
}