From 0aa65c0fdd5c27f6d3d0af2c1fd1e25f68195fb2 Mon Sep 17 00:00:00 2001 From: vsandonis <113995541+vsandonis@users.noreply.github.com> Date: Wed, 28 Sep 2022 18:33:16 +0200 Subject: [PATCH] Fix Objref expression parsing (#193) The Objref expression was not considered when parsing raw expressions bytes to construct nftables expressions. Add unit test to check that a rule with an Objref expression is properly obtained by GetRules(). Signed-off-by: Victor Sandonis Consuegra --- expr/expr.go | 2 ++ nftables_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/expr/expr.go b/expr/expr.go index 2f342be..257b4e4 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -96,6 +96,8 @@ func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]Any, er e = &Cmp{} case "counter": e = &Counter{} + case "objref": + e = &Objref{} case "payload": e = &Payload{} case "lookup": diff --git a/nftables_test.go b/nftables_test.go index 8324763..74b593f 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -5567,3 +5567,68 @@ func TestStatelessNAT(t *testing.T) { t.Fatal(err) } } + +func TestGetRulesObjref(t *testing.T) { + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + table := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + chain := c.AddChain(&nftables.Chain{ + Name: "forward", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }) + + counterName := "fwded1" + c.AddObj(&nftables.CounterObj{ + Table: table, + Name: counterName, + Bytes: 1, + Packets: 1, + }) + + counterRule := c.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Objref{ + Type: 1, + Name: counterName, + }, + }, + }) + + if err := c.Flush(); err != nil { + t.Errorf("c.Flush() failed: %v", err) + } + + rules, err := c.GetRules(table, chain) + if err != nil { + t.Fatal(err) + } + + if got, want := len(rules), 1; got != want { + t.Fatalf("unexpected number of rules: got %d, want %d", got, want) + } + if got, want := len(rules[0].Exprs), 1; got != want { + t.Fatalf("unexpected number of exprs: got %d, want %d", got, want) + } + objref, objrefOk := rules[0].Exprs[0].(*expr.Objref) + if !objrefOk { + t.Fatalf("Exprs[0] is type %T, want *expr.Objref", rules[0].Exprs[0]) + } + if want := counterRule.Exprs[0]; !reflect.DeepEqual(objref, want) { + t.Errorf("objref expr = %+v, wanted %+v", objref, want) + } +}