From 1c789726cfebc8afac9383f7005f381f05959b37 Mon Sep 17 00:00:00 2001 From: turekt <32360115+turekt@users.noreply.github.com> Date: Thu, 16 Jan 2025 09:15:33 +0100 Subject: [PATCH] Fix Fib parsing (#296) --- expr/expr.go | 2 ++ expr/fib.go | 23 ++++++++++++-------- expr/limit.go | 2 +- expr/quota.go | 2 +- nftables_test.go | 55 ++++++++++++++++++++++++++++++++++++++++++++++++ quota.go | 2 +- 6 files changed, 74 insertions(+), 12 deletions(-) diff --git a/expr/expr.go b/expr/expr.go index fc00a69..e7e7a45 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -207,6 +207,8 @@ func exprFromName(name string) Any { e = &SecMark{} case "cttimeout": e = &CtTimeout{} + case "fib": + e = &Fib{} } return e } diff --git a/expr/fib.go b/expr/fib.go index 62a717c..ea6c059 100644 --- a/expr/fib.go +++ b/expr/fib.go @@ -118,17 +118,22 @@ func (e *Fib) unmarshal(fam byte, data []byte) error { e.Register = ad.Uint32() case unix.NFTA_FIB_RESULT: result := ad.Uint32() - e.ResultOIF = (result & unix.NFT_FIB_RESULT_OIF) == 1 - e.ResultOIFNAME = (result & unix.NFT_FIB_RESULT_OIFNAME) == 1 - e.ResultADDRTYPE = (result & unix.NFT_FIB_RESULT_ADDRTYPE) == 1 + switch result { + case unix.NFT_FIB_RESULT_OIF: + e.ResultOIF = true + case unix.NFT_FIB_RESULT_OIFNAME: + e.ResultOIFNAME = true + case unix.NFT_FIB_RESULT_ADDRTYPE: + e.ResultADDRTYPE = true + } case unix.NFTA_FIB_FLAGS: flags := ad.Uint32() - e.FlagSADDR = (flags & unix.NFTA_FIB_F_SADDR) == 1 - e.FlagDADDR = (flags & unix.NFTA_FIB_F_DADDR) == 1 - e.FlagMARK = (flags & unix.NFTA_FIB_F_MARK) == 1 - e.FlagIIF = (flags & unix.NFTA_FIB_F_IIF) == 1 - e.FlagOIF = (flags & unix.NFTA_FIB_F_OIF) == 1 - e.FlagPRESENT = (flags & unix.NFTA_FIB_F_PRESENT) == 1 + e.FlagSADDR = (flags & unix.NFTA_FIB_F_SADDR) != 0 + e.FlagDADDR = (flags & unix.NFTA_FIB_F_DADDR) != 0 + e.FlagMARK = (flags & unix.NFTA_FIB_F_MARK) != 0 + e.FlagIIF = (flags & unix.NFTA_FIB_F_IIF) != 0 + e.FlagOIF = (flags & unix.NFTA_FIB_F_OIF) != 0 + e.FlagPRESENT = (flags & unix.NFTA_FIB_F_PRESENT) != 0 } } return ad.Err() diff --git a/expr/limit.go b/expr/limit.go index 9d2facc..1e170ac 100644 --- a/expr/limit.go +++ b/expr/limit.go @@ -123,7 +123,7 @@ func (l *Limit) unmarshal(fam byte, data []byte) error { return fmt.Errorf("expr: invalid limit type %d", l.Type) } case unix.NFTA_LIMIT_FLAGS: - l.Over = (ad.Uint32() & unix.NFT_LIMIT_F_INV) == 1 + l.Over = (ad.Uint32() & unix.NFT_LIMIT_F_INV) != 0 default: return errors.New("expr: unhandled limit netlink attribute") } diff --git a/expr/quota.go b/expr/quota.go index 87bcb9d..ca55f6c 100644 --- a/expr/quota.go +++ b/expr/quota.go @@ -73,7 +73,7 @@ func (q *Quota) unmarshal(fam byte, data []byte) error { case unix.NFTA_QUOTA_CONSUMED: q.Consumed = ad.Uint64() case unix.NFTA_QUOTA_FLAGS: - q.Over = (ad.Uint32() & unix.NFT_QUOTA_F_INV) == 1 + q.Over = (ad.Uint32() & unix.NFT_QUOTA_F_INV) != 0 } } return ad.Err() diff --git a/nftables_test.go b/nftables_test.go index 1c39b67..7655b83 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -6387,6 +6387,61 @@ func TestFib(t *testing.T) { } } +func TestFibSystem(t *testing.T) { + c, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + chain := c.AddChain(&nftables.Chain{ + Name: "test-chain", + Table: filter, + }) + + expect := &expr.Fib{ + Register: 1, + FlagDADDR: true, + ResultADDRTYPE: true, + } + + c.AddRule(&nftables.Rule{ + Table: filter, + Chain: chain, + Exprs: []expr.Any{expect}, + }) + + if err := c.Flush(); err != nil { + t.Fatalf("c.Flush() failed with error %+v", err) + } + + rules, err := c.GetRules(filter, chain) + if err != nil { + t.Fatalf("GetRules failed: %v", err) + } + + if got, want := len(rules), 1; got != want { + t.Fatalf("unexpected number of rules: got %d, want %d", got, want) + } + + if got, want := len(rules[0].Exprs), 1; got != want { + t.Fatalf("unexpected number of exprs: got %d, want %d", got, want) + } + + fib := rules[0].Exprs[0].(*expr.Fib) + if got, want := fib.FlagDADDR, expect.FlagDADDR; got != want { + t.Errorf("fib daddr not equal: got %+v, want %+v", got, want) + } + + if got, want := fib.ResultADDRTYPE, expect.ResultADDRTYPE; got != want { + t.Errorf("fib addr type not equal: got %+v, want %+v", got, want) + } +} + func TestNumgen(t *testing.T) { tests := []struct { name string diff --git a/quota.go b/quota.go index 123c9da..a8be634 100644 --- a/quota.go +++ b/quota.go @@ -36,7 +36,7 @@ func (q *QuotaObj) unmarshal(ad *netlink.AttributeDecoder) error { case unix.NFTA_QUOTA_CONSUMED: q.Consumed = ad.Uint64() case unix.NFTA_QUOTA_FLAGS: - q.Over = (ad.Uint32() & unix.NFT_QUOTA_F_INV) == 1 + q.Over = (ad.Uint32() & unix.NFT_QUOTA_F_INV) != 0 } } return nil