From 9aa6fdf5a28cdfec5695ebbf7d7add61eccdf61a Mon Sep 17 00:00:00 2001 From: turekt <32360115+turekt@users.noreply.github.com> Date: Sun, 15 Jan 2023 20:51:35 +0000 Subject: [PATCH] Masq marshal fix (#214) Fixes https://github.com/google/nftables/issues/213 --- expr/expr.go | 3 ++ nftables_test.go | 77 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/expr/expr.go b/expr/expr.go index aa9e9d6..9a9ea76 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -132,6 +132,8 @@ func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]Any, er e = &FlowOffload{} case "reject": e = &Reject{} + case "masq": + e = &Masq{} } if e == nil { // TODO: introduce an opaque expression type so that users know @@ -337,6 +339,7 @@ func (e *Masq) unmarshal(fam byte, data []byte) error { for ad.Next() { switch ad.Type() { case unix.NFTA_MASQ_REG_PROTO_MIN: + e.ToPorts = true e.RegProtoMin = ad.Uint32() case unix.NFTA_MASQ_REG_PROTO_MAX: e.RegProtoMax = ad.Uint32() diff --git a/nftables_test.go b/nftables_test.go index bad4ccd..2f5f1e8 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -580,6 +580,83 @@ func TestConfigureNATSourceAddress(t *testing.T) { } } +func TestMasqMarshalUnmarshal(t *testing.T) { + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyINet, + Name: "filter", + }) + postrouting := c.AddChain(&nftables.Chain{ + Name: "postrouting", + Table: filter, + Type: nftables.ChainTypeNAT, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityFilter, + }) + + min := uint32(1) + max := uint32(3) + c.AddRule(&nftables.Rule{ + Table: filter, + Chain: postrouting, + Exprs: []expr.Any{ + &expr.Masq{ + ToPorts: true, + RegProtoMin: min, + RegProtoMax: max, + }, + }, + }) + + if err := c.Flush(); err != nil { + t.Fatalf("c.Flush() failed: %v", err) + } + + rules, err := c.GetRules( + &nftables.Table{ + Family: nftables.TableFamilyINet, + Name: "filter", + }, + &nftables.Chain{ + Name: "postrouting", + }, + ) + if err != nil { + t.Fatalf("c.GetRules() failed: %v", err) + } + + if got, want := len(rules), 1; got != want { + t.Fatalf("unexpected rule count: got %d, want %d", got, want) + } + + rule := rules[0] + if got, want := len(rule.Exprs), 1; got != want { + t.Fatalf("unexpected number of exprs: got %d, want %d", got, want) + } + + me, ok := rule.Exprs[0].(*expr.Masq) + if !ok { + t.Fatalf("unexpected expression type: got %T, want *expr.Masq", rule.Exprs[0]) + } + + if got, want := me.ToPorts, true; got != want { + t.Errorf("unexpected masq random flag: got %v, want %v", got, want) + } + + if got, want := me.RegProtoMin, min; got != want { + t.Errorf("unexpected reg proto min: got %d, want %d", got, want) + } + + if got, want := me.RegProtoMax, max; got != want { + t.Errorf("unexpected reg proto max: got %d, want %d", got, want) + } +} + func TestExprLogOptions(t *testing.T) { c, newNS := openSystemNFTConn(t) defer cleanupSystemNFTConn(t, newNS)