diff --git a/expr/expr.go b/expr/expr.go index 2f342be..257b4e4 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -96,6 +96,8 @@ func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]Any, er e = &Cmp{} case "counter": e = &Counter{} + case "objref": + e = &Objref{} case "payload": e = &Payload{} case "lookup": diff --git a/nftables_test.go b/nftables_test.go index 8324763..74b593f 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -5567,3 +5567,68 @@ func TestStatelessNAT(t *testing.T) { t.Fatal(err) } } + +func TestGetRulesObjref(t *testing.T) { + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + table := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + chain := c.AddChain(&nftables.Chain{ + Name: "forward", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }) + + counterName := "fwded1" + c.AddObj(&nftables.CounterObj{ + Table: table, + Name: counterName, + Bytes: 1, + Packets: 1, + }) + + counterRule := c.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Objref{ + Type: 1, + Name: counterName, + }, + }, + }) + + if err := c.Flush(); err != nil { + t.Errorf("c.Flush() failed: %v", err) + } + + rules, err := c.GetRules(table, chain) + if err != nil { + t.Fatal(err) + } + + if got, want := len(rules), 1; got != want { + t.Fatalf("unexpected number of rules: got %d, want %d", got, want) + } + if got, want := len(rules[0].Exprs), 1; got != want { + t.Fatalf("unexpected number of exprs: got %d, want %d", got, want) + } + objref, objrefOk := rules[0].Exprs[0].(*expr.Objref) + if !objrefOk { + t.Fatalf("Exprs[0] is type %T, want *expr.Objref", rules[0].Exprs[0]) + } + if want := counterRule.Exprs[0]; !reflect.DeepEqual(objref, want) { + t.Errorf("objref expr = %+v, wanted %+v", objref, want) + } +}