From 85d0f3a0db1e02e9bd5120ecd1a41df494fc1abe Mon Sep 17 00:00:00 2001 From: TheDiveO <6920158+thediveo@users.noreply.github.com> Date: Sun, 8 May 2022 20:39:12 +0200 Subject: [PATCH] add GetRules and deprecate GetRule, update tests (#160) --- nftables_test.go | 50 ++++++++++++++++++++++++------------------------ rule.go | 7 +++++++ 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/nftables_test.go b/nftables_test.go index ec87ca4..7ee23de 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -184,7 +184,7 @@ func TestRuleOperations(t *testing.T) { t.Fatal(err) } - rules, _ := c.GetRule(filter, prerouting) + rules, _ := c.GetRules(filter, prerouting) want := []expr.VerdictKind{ expr.VerdictQueue, @@ -241,7 +241,7 @@ func TestRuleOperations(t *testing.T) { t.Fatal(err) } - rules, _ = c.GetRule(filter, prerouting) + rules, _ = c.GetRules(filter, prerouting) want = []expr.VerdictKind{ expr.VerdictQueue, @@ -627,7 +627,7 @@ func TestExprLogOptions(t *testing.T) { t.Errorf("c.Flush() failed: %v", err) } - rules, err := c.GetRule( + rules, err := c.GetRules( &nftables.Table{ Family: nftables.TableFamilyIPv4, Name: "filter", @@ -670,7 +670,7 @@ func TestExprLogOptions(t *testing.T) { t.Fatalf("unexpected snaplen: got %d, want %d", got, want) } - rules, err = c.GetRule( + rules, err = c.GetRules( &nftables.Table{ Family: nftables.TableFamilyIPv4, Name: "filter", @@ -748,7 +748,7 @@ func TestExprLogPrefix(t *testing.T) { t.Errorf("c.Flush() failed: %v", err) } - rules, err := c.GetRule( + rules, err := c.GetRules( &nftables.Table{ Family: nftables.TableFamilyIPv4, Name: "filter", @@ -786,7 +786,7 @@ func TestExprLogPrefix(t *testing.T) { } } -func TestGetRule(t *testing.T) { +func TestGetRules(t *testing.T) { // The want byte sequences come from stracing nft(8), e.g.: // strace -f -v -x -s 2048 -eraw=sendto nft list chain ip filter forward @@ -828,7 +828,7 @@ func TestGetRule(t *testing.T) { }, } - rules, err := c.GetRule( + rules, err := c.GetRules( &nftables.Table{ Family: nftables.TableFamilyIPv4, Name: "filter", @@ -2729,9 +2729,9 @@ func TestFlushChain(t *testing.T) { if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } - rules, err := c.GetRule(filter, forward) + rules, err := c.GetRules(filter, forward) if err != nil { - t.Errorf("c.GetRule() failed: %v", err) + t.Errorf("c.GetRules() failed: %v", err) } if len(rules) != 2 { t.Fatalf("len(rules) = %d, want 2", len(rules)) @@ -2743,9 +2743,9 @@ func TestFlushChain(t *testing.T) { t.Errorf("Second c.Flush() failed: %v", err) } - rules, err = c.GetRule(filter, forward) + rules, err = c.GetRules(filter, forward) if err != nil { - t.Errorf("c.GetRule() failed: %v", err) + t.Errorf("c.GetRules() failed: %v", err) } if len(rules) != 0 { t.Fatalf("len(rules) = %d, want 0", len(rules)) @@ -2931,23 +2931,23 @@ func TestFlushTable(t *testing.T) { if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } - rules, err := c.GetRule(filter, forward) + rules, err := c.GetRules(filter, forward) if err != nil { - t.Errorf("c.GetRule() failed: %v", err) + t.Errorf("c.GetRules() failed: %v", err) } if len(rules) != 2 { t.Fatalf("len(rules) = %d, want 2", len(rules)) } - rules, err = c.GetRule(filter, input) + rules, err = c.GetRules(filter, input) if err != nil { - t.Errorf("c.GetRule() failed: %v", err) + t.Errorf("c.GetRules() failed: %v", err) } if len(rules) != 1 { t.Fatalf("len(rules) = %d, want 1", len(rules)) } - rules, err = c.GetRule(nat, prerouting) + rules, err = c.GetRules(nat, prerouting) if err != nil { - t.Errorf("c.GetRule() failed: %v", err) + t.Errorf("c.GetRules() failed: %v", err) } if len(rules) != 1 { t.Fatalf("len(rules) = %d, want 1", len(rules)) @@ -2959,23 +2959,23 @@ func TestFlushTable(t *testing.T) { t.Errorf("Second c.Flush() failed: %v", err) } - rules, err = c.GetRule(filter, forward) + rules, err = c.GetRules(filter, forward) if err != nil { - t.Errorf("c.GetRule() failed: %v", err) + t.Errorf("c.GetRules() failed: %v", err) } if len(rules) != 0 { t.Fatalf("len(rules) = %d, want 0", len(rules)) } - rules, err = c.GetRule(filter, input) + rules, err = c.GetRules(filter, input) if err != nil { - t.Errorf("c.GetRule() failed: %v", err) + t.Errorf("c.GetRules() failed: %v", err) } if len(rules) != 0 { t.Fatalf("len(rules) = %d, want 0", len(rules)) } - rules, err = c.GetRule(nat, prerouting) + rules, err = c.GetRules(nat, prerouting) if err != nil { - t.Errorf("c.GetRule() failed: %v", err) + t.Errorf("c.GetRules() failed: %v", err) } if len(rules) != 1 { t.Fatalf("len(rules) = %d, want 1", len(rules)) @@ -3056,7 +3056,7 @@ func TestGetRuleLookupVerdictImmediate(t *testing.T) { t.Errorf("c.Flush() failed: %v", err) } - rules, err := c.GetRule( + rules, err := c.GetRules( &nftables.Table{ Family: nftables.TableFamilyIPv4, Name: "filter", @@ -3167,7 +3167,7 @@ func TestDynset(t *testing.T) { t.Errorf("c.Flush() failed: %v", err) } - rules, err := c.GetRule( + rules, err := c.GetRules( &nftables.Table{ Family: nftables.TableFamilyIPv4, Name: "filter", diff --git a/rule.go b/rule.go index 5729ef3..be48792 100644 --- a/rule.go +++ b/rule.go @@ -52,7 +52,14 @@ type Rule struct { } // GetRule returns the rules in the specified table and chain. +// +// Deprecated: use GetRules instead. func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) { + return cc.GetRules(t, c) +} + +// GetRules returns the rules in the specified table and chain. +func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) { conn, err := cc.dialNetlink() if err != nil { return nil, err