diff --git a/compat_policy.go b/compat_policy.go new file mode 100644 index 0000000..c1f3908 --- /dev/null +++ b/compat_policy.go @@ -0,0 +1,89 @@ +package nftables + +import ( + "fmt" + + "github.com/google/nftables/expr" + "golang.org/x/sys/unix" +) + +const nft_RULE_COMPAT_F_INV uint32 = (1 << 1) +const nft_RULE_COMPAT_F_MASK uint32 = nft_RULE_COMPAT_F_INV + +// Used by xt match or target like xt_tcpudp to set compat policy between xtables and nftables +// https://elixir.bootlin.com/linux/v5.12/source/net/netfilter/nft_compat.c#L187 +type compatPolicy struct { + Proto uint32 + Flag uint32 +} + +var xtMatchCompatMap map[string]*compatPolicy = map[string]*compatPolicy{ + "tcp": { + Proto: unix.IPPROTO_TCP, + }, + "udp": { + Proto: unix.IPPROTO_UDP, + }, + "udplite": { + Proto: unix.IPPROTO_UDPLITE, + }, + "tcpmss": { + Proto: unix.IPPROTO_TCP, + }, + "sctp": { + Proto: unix.IPPROTO_SCTP, + }, + "osf": { + Proto: unix.IPPROTO_TCP, + }, + "ipcomp": { + Proto: unix.IPPROTO_COMP, + }, + "esp": { + Proto: unix.IPPROTO_ESP, + }, +} + +var xtTargetCompatMap map[string]*compatPolicy = map[string]*compatPolicy{ + "TCPOPTSTRIP": { + Proto: unix.IPPROTO_TCP, + }, + "TCPMSS": { + Proto: unix.IPPROTO_TCP, + }, +} + +func getCompatPolicy(exprs []expr.Any) (*compatPolicy, error) { + var exprItem expr.Any + var compat *compatPolicy + + for _, iter := range exprs { + var tmpExprItem expr.Any + var tmpCompat *compatPolicy + switch item := iter.(type) { + case *expr.Match: + if compat, ok := xtMatchCompatMap[item.Name]; ok { + tmpCompat = compat + tmpExprItem = item + } else { + continue + } + case *expr.Target: + if compat, ok := xtTargetCompatMap[item.Name]; ok { + tmpCompat = compat + tmpExprItem = item + } else { + continue + } + default: + continue + } + if compat == nil { + compat = tmpCompat + exprItem = tmpExprItem + } else if *compat != *tmpCompat { + return nil, fmt.Errorf("%#v and %#v has conflict compat policy %#v vs %#v", exprItem, tmpExprItem, compat, tmpCompat) + } + } + return compat, nil +} diff --git a/compat_policy_test.go b/compat_policy_test.go new file mode 100644 index 0000000..7565de0 --- /dev/null +++ b/compat_policy_test.go @@ -0,0 +1,77 @@ +package nftables + +import ( + "testing" + + "github.com/google/nftables/expr" + "github.com/google/nftables/xt" + "golang.org/x/sys/unix" +) + +func TestGetCompatPolicy(t *testing.T) { + // -tcp --dport 0:65534 --sport 0:65534 + tcpMatch := &expr.Match{ + Name: "tcp", + Info: &xt.Tcp{ + SrcPorts: [2]uint16{0, 65534}, + DstPorts: [2]uint16{0, 65534}, + }, + } + + // -udp --dport 0:65534 --sport 0:65534 + udpMatch := &expr.Match{ + Name: "udp", + Info: &xt.Udp{ + SrcPorts: [2]uint16{0, 65534}, + DstPorts: [2]uint16{0, 65534}, + }, + } + + // -j TCPMSS --set-mss 1460 + mess := xt.Unknown([]byte{1460 & 0xff, (1460 >> 8) & 0xff}) + tcpMessTarget := &expr.Target{ + Name: "TCPMESS", + Info: &mess, + } + + // -m state --state ESTABLISHED + ctMatch := &expr.Match{ + Name: "conntrack", + Rev: 1, + Info: &xt.ConntrackMtinfo1{ + ConntrackMtinfoBase: xt.ConntrackMtinfoBase{ + MatchFlags: 0x2001, + }, + StateMask: 0x02, + }, + } + + // compatPolicy.Proto should be tcp + if compatPolicy, err := getCompatPolicy([]expr.Any{ + tcpMatch, + tcpMessTarget, + ctMatch, + }); err != nil { + t.Fatalf("getCompatPolicy fail %#v", err) + } else if compatPolicy.Proto != unix.IPPROTO_TCP { + t.Fatalf("getCompatPolicy wrong %#v", compatPolicy) + } + + // should conflict + if _, err := getCompatPolicy([]expr.Any{ + udpMatch, + tcpMatch, + }, + ); err == nil { + t.Fatalf("getCompatPolicy fail err should not be nil") + } + + // compatPolicy should be nil + if compatPolicy, err := getCompatPolicy([]expr.Any{ + ctMatch, + }); err != nil { + t.Fatalf("getCompatPolicy fail %#v", err) + } else if compatPolicy != nil { + t.Fatalf("getCompatPolicy fail compat policy of conntrack match should be nil") + } +} diff --git a/nftables_test.go b/nftables_test.go index 34f6831..bad4ccd 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -31,6 +31,7 @@ import ( "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" "github.com/google/nftables/internal/nftest" + "github.com/google/nftables/xt" "github.com/mdlayher/netlink" "github.com/vishvananda/netns" "golang.org/x/sys/unix" @@ -6039,3 +6040,123 @@ func TestGetRulesQueue(t *testing.T) { t.Errorf("queue expr = %+v, wanted %+v", queueExpr, want) } } + +func TestNftablesCompat(t *testing.T) { + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + input := c.AddChain(&nftables.Chain{ + Name: "input", + Table: filter, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }) + + // -tcp --dport 0:65534 --sport 0:65534 + tcpMatch := &expr.Match{ + Name: "tcp", + Info: &xt.Tcp{ + SrcPorts: [2]uint16{0, 65534}, + DstPorts: [2]uint16{0, 65534}, + }, + } + + // -udp --dport 0:65534 --sport 0:65534 + udpMatch := &expr.Match{ + Name: "udp", + Info: &xt.Udp{ + SrcPorts: [2]uint16{0, 65534}, + DstPorts: [2]uint16{0, 65534}, + }, + } + + // - j TCPMSS --set-mss 1460 + mess := xt.Unknown([]byte{1460 & 0xff, (1460 >> 8) & 0xff}) + tcpMessTarget := &expr.Target{ + Name: "TCPMSS", + Info: &mess, + } + + // -m state --state ESTABLISHED + ctMatch := &expr.Match{ + Name: "conntrack", + Rev: 1, + Info: &xt.ConntrackMtinfo1{ + ConntrackMtinfoBase: xt.ConntrackMtinfoBase{ + MatchFlags: 0x2001, + }, + StateMask: 0x02, + }, + } + + // -p tcp --dport --dport 0:65534 --sport 0:65534 -m state --state ESTABLISHED -j TCPMSS --set-mss 1460 + c.AddRule(&nftables.Rule{ + Table: filter, + Chain: input, + Exprs: []expr.Any{ + tcpMatch, + ctMatch, + tcpMessTarget, + }, + }) + if err := c.Flush(); err != nil { + t.Fatalf("add rule fail %#v", err) + } + + c.AddRule(&nftables.Rule{ + Table: filter, + Chain: input, + Exprs: []expr.Any{ + udpMatch, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + }) + if err := c.Flush(); err != nil { + t.Fatalf("add rule %#v fail", err) + } + + // -m state --state ESTABLISHED -j ACCEPT + c.AddRule(&nftables.Rule{ + Table: filter, + Chain: input, + Exprs: []expr.Any{ + ctMatch, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + }) + if err := c.Flush(); err != nil { + t.Fatalf("add rule %#v fail", err) + } + + // -p udp --dport --dport 0:65534 --sport 0:65534 -m state --state ESTABLISHED -j ACCEPT + c.AddRule(&nftables.Rule{ + Table: filter, + Chain: input, + Exprs: []expr.Any{ + tcpMatch, + udpMatch, + ctMatch, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + }) + if err := c.Flush(); err == nil { + t.Fatalf("compat policy should conflict and err should not be err") + } +} diff --git a/rule.go b/rule.go index 95bfdff..f004e45 100644 --- a/rule.go +++ b/rule.go @@ -133,6 +133,17 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { {Type: unix.NLA_F_NESTED | unix.NFTA_RULE_EXPRESSIONS, Data: cc.marshalAttr(exprAttrs)}, })...) + if compatPolicy, err := getCompatPolicy(r.Exprs); err != nil { + cc.setErr(err) + } else if compatPolicy != nil { + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NLA_F_NESTED | unix.NFTA_RULE_COMPAT, Data: cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_COMPAT_PROTO, Data: binaryutil.BigEndian.PutUint32(compatPolicy.Proto)}, + {Type: unix.NFTA_RULE_COMPAT_FLAGS, Data: binaryutil.BigEndian.PutUint32(compatPolicy.Flag & nft_RULE_COMPAT_F_MASK)}, + })}, + })...) + } + msgData := []byte{} msgData = append(msgData, data...)