From 9ac63cb2823f4bcabf7a1a615908314920d0591e Mon Sep 17 00:00:00 2001
From: Michael Stapelberg <stapelberg@google.com>
Date: Mon, 22 Oct 2018 09:22:02 +0200
Subject: [PATCH] add exprs and test for TCP MSS clamping

---
 expr/bitwise.go   |  65 ++++++++++++++++
 expr/byteorder.go |  59 +++++++++++++++
 expr/exthdr.go    |  64 ++++++++++++++++
 expr/rt.go        |  55 ++++++++++++++
 nftables_test.go  | 183 +++++++++++++++++++++++++++++++++++++++++++++-
 5 files changed, 422 insertions(+), 4 deletions(-)
 create mode 100644 expr/bitwise.go
 create mode 100644 expr/byteorder.go
 create mode 100644 expr/exthdr.go
 create mode 100644 expr/rt.go

diff --git a/expr/bitwise.go b/expr/bitwise.go
new file mode 100644
index 0000000..6196da6
--- /dev/null
+++ b/expr/bitwise.go
@@ -0,0 +1,65 @@
+// Copyright 2018 Google LLC. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package expr
+
+import (
+	"fmt"
+
+	"github.com/google/nftables/binaryutil"
+	"github.com/mdlayher/netlink"
+	"golang.org/x/sys/unix"
+)
+
+type Bitwise struct { // fields of an nftables "bitwise" expression; marshal defines the wire layout
+	SourceRegister uint32 // sent big-endian as NFTA_BITWISE_SREG
+	DestRegister   uint32 // sent big-endian as NFTA_BITWISE_DREG
+	Len            uint32 // sent big-endian as NFTA_BITWISE_LEN
+	Mask           []byte // sent as NFTA_DATA_VALUE nested inside NFTA_BITWISE_MASK
+	Xor            []byte // sent as NFTA_DATA_VALUE nested inside NFTA_BITWISE_XOR
+}
+
+func (e *Bitwise) marshal() ([]byte, error) { // builds the NFTA_EXPR_NAME + NFTA_EXPR_DATA attributes for "bitwise"
+	mask, err := netlink.MarshalAttributes([]netlink.Attribute{
+		{Type: unix.NFTA_DATA_VALUE, Data: e.Mask},
+	})
+	if err != nil {
+		return nil, err
+	}
+	xor, err := netlink.MarshalAttributes([]netlink.Attribute{
+		{Type: unix.NFTA_DATA_VALUE, Data: e.Xor},
+	})
+	if err != nil {
+		return nil, err
+	}
+	// mask and xor are already-encoded NFTA_DATA_VALUE attributes; they are nested as-is below.
+	data, err := netlink.MarshalAttributes([]netlink.Attribute{
+		{Type: unix.NFTA_BITWISE_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)},
+		{Type: unix.NFTA_BITWISE_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)},
+		{Type: unix.NFTA_BITWISE_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)},
+		{Type: unix.NLA_F_NESTED | unix.NFTA_BITWISE_MASK, Data: mask},
+		{Type: unix.NLA_F_NESTED | unix.NFTA_BITWISE_XOR, Data: xor},
+	})
+	if err != nil {
+		return nil, err
+	}
+	return netlink.MarshalAttributes([]netlink.Attribute{
+		{Type: unix.NFTA_EXPR_NAME, Data: []byte("bitwise\x00")}, // NUL-terminated expression name
+		{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data},
+	})
+}
+
+func (e *Bitwise) unmarshal(data []byte) error { // decoding is not implemented; data is ignored
+	return fmt.Errorf("not yet implemented")
+}
diff --git a/expr/byteorder.go b/expr/byteorder.go
new file mode 100644
index 0000000..a28996d
--- /dev/null
+++ b/expr/byteorder.go
@@ -0,0 +1,59 @@
+// Copyright 2018 Google LLC. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package expr
+
+import (
+	"fmt"
+
+	"github.com/google/nftables/binaryutil"
+	"github.com/mdlayher/netlink"
+	"golang.org/x/sys/unix"
+)
+
+type ByteorderOp uint32 // conversion selector; Byteorder.marshal sends it as NFTA_BYTEORDER_OP
+
+const (
+	ByteorderNtoh ByteorderOp = unix.NFT_BYTEORDER_NTOH // presumably network-to-host, going by the kernel constant's name
+	ByteorderHton ByteorderOp = unix.NFT_BYTEORDER_HTON // presumably host-to-network, going by the kernel constant's name
+)
+
+type Byteorder struct { // fields of an nftables "byteorder" expression; marshal defines the wire layout
+	SourceRegister uint32      // sent big-endian as NFTA_BYTEORDER_SREG
+	DestRegister   uint32      // sent big-endian as NFTA_BYTEORDER_DREG
+	Op             ByteorderOp // sent big-endian as NFTA_BYTEORDER_OP
+	Len            uint32      // sent big-endian as NFTA_BYTEORDER_LEN
+	Size           uint32      // sent big-endian as NFTA_BYTEORDER_SIZE
+}
+
+func (e *Byteorder) marshal() ([]byte, error) { // builds the NFTA_EXPR_NAME + NFTA_EXPR_DATA attributes for "byteorder"
+	data, err := netlink.MarshalAttributes([]netlink.Attribute{ // every field is encoded as a big-endian uint32
+		{Type: unix.NFTA_BYTEORDER_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)},
+		{Type: unix.NFTA_BYTEORDER_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)},
+		{Type: unix.NFTA_BYTEORDER_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))},
+		{Type: unix.NFTA_BYTEORDER_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)},
+		{Type: unix.NFTA_BYTEORDER_SIZE, Data: binaryutil.BigEndian.PutUint32(e.Size)},
+	})
+	if err != nil {
+		return nil, err
+	}
+	return netlink.MarshalAttributes([]netlink.Attribute{
+		{Type: unix.NFTA_EXPR_NAME, Data: []byte("byteorder\x00")}, // NUL-terminated expression name
+		{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data},
+	})
+}
+
+func (e *Byteorder) unmarshal(data []byte) error { // decoding is not implemented; data is ignored
+	return fmt.Errorf("not yet implemented")
+}
diff --git a/expr/exthdr.go b/expr/exthdr.go
new file mode 100644
index 0000000..e47f268
--- /dev/null
+++ b/expr/exthdr.go
@@ -0,0 +1,64 @@
+// Copyright 2018 Google LLC. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package expr
+
+import (
+	"fmt"
+
+	"github.com/google/nftables/binaryutil"
+	"github.com/mdlayher/netlink"
+	"golang.org/x/sys/unix"
+)
+
+type ExthdrOp uint32 // header kind selector; Exthdr.marshal sends it as NFTA_EXTHDR_OP
+
+const (
+	ExthdrOpIpv6   ExthdrOp = unix.NFT_EXTHDR_OP_IPV6   // presumably IPv6 extension headers, going by the kernel constant's name
+	ExthdrOpTcpopt ExthdrOp = unix.NFT_EXTHDR_OP_TCPOPT // presumably TCP options, going by the kernel constant's name
+)
+
+type Exthdr struct { // fields of an nftables "exthdr" expression; marshal defines the wire layout
+	DestRegister   uint32   // currently NOT marshaled: NFTA_EXTHDR_DREG is commented out in marshal
+	Type           uint8    // sent as the single byte of NFTA_EXTHDR_TYPE
+	Offset         uint32   // sent big-endian as NFTA_EXTHDR_OFFSET
+	Len            uint32   // sent big-endian as NFTA_EXTHDR_LEN
+	Flags          uint32   // currently NOT marshaled: NFTA_EXTHDR_FLAGS is commented out in marshal
+	Op             ExthdrOp // sent big-endian as NFTA_EXTHDR_OP
+	SourceRegister uint32   // sent big-endian as NFTA_EXTHDR_SREG
+}
+
+func (e *Exthdr) marshal() ([]byte, error) { // builds the NFTA_EXPR_NAME + NFTA_EXPR_DATA attributes for "exthdr"
+	data, err := netlink.MarshalAttributes([]netlink.Attribute{
+		{Type: unix.NFTA_EXTHDR_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, // emitted unconditionally, even when zero
+		{Type: unix.NFTA_EXTHDR_TYPE, Data: []byte{e.Type}},
+		{Type: unix.NFTA_EXTHDR_OFFSET, Data: binaryutil.BigEndian.PutUint32(e.Offset)},
+		{Type: unix.NFTA_EXTHDR_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)},
+		// TODO: these fields seem to be conditional? Until resolved, DestRegister and Flags are never sent.
+		//{Type: unix.NFTA_EXTHDR_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)},
+		//{Type: unix.NFTA_EXTHDR_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)},
+		{Type: unix.NFTA_EXTHDR_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))},
+	})
+	if err != nil {
+		return nil, err
+	}
+	return netlink.MarshalAttributes([]netlink.Attribute{
+		{Type: unix.NFTA_EXPR_NAME, Data: []byte("exthdr\x00")}, // NUL-terminated expression name
+		{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data},
+	})
+}
+
+func (e *Exthdr) unmarshal(data []byte) error { // decoding is not implemented; data is ignored
+	return fmt.Errorf("not yet implemented")
+}
diff --git a/expr/rt.go b/expr/rt.go
new file mode 100644
index 0000000..8fdbdb5
--- /dev/null
+++ b/expr/rt.go
@@ -0,0 +1,55 @@
+// Copyright 2018 Google LLC. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package expr
+
+import (
+	"fmt"
+
+	"github.com/google/nftables/binaryutil"
+	"github.com/mdlayher/netlink"
+	"golang.org/x/sys/unix"
+)
+
+type RtKey uint32 // routing key selector; Rt.marshal sends it as NFTA_RT_KEY
+
+const (
+	RtClassid  RtKey = unix.NFT_RT_CLASSID  // mirrors the kernel's NFT_RT_CLASSID
+	RtNexthop4 RtKey = unix.NFT_RT_NEXTHOP4 // mirrors the kernel's NFT_RT_NEXTHOP4
+	RtNexthop6 RtKey = unix.NFT_RT_NEXTHOP6 // mirrors the kernel's NFT_RT_NEXTHOP6
+	RtTCPMSS   RtKey = unix.NFT_RT_TCPMSS   // mirrors the kernel's NFT_RT_TCPMSS
+)
+
+type Rt struct { // fields of an nftables "rt" expression; marshal defines the wire layout
+	Register uint32 // sent big-endian as NFTA_RT_DREG
+	Key      RtKey  // sent big-endian as NFTA_RT_KEY
+}
+
+func (e *Rt) marshal() ([]byte, error) { // builds the NFTA_EXPR_NAME + NFTA_EXPR_DATA attributes for "rt"
+	data, err := netlink.MarshalAttributes([]netlink.Attribute{ // the key attribute is written before the register
+		{Type: unix.NFTA_RT_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))},
+		{Type: unix.NFTA_RT_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)},
+	})
+	if err != nil {
+		return nil, err
+	}
+	return netlink.MarshalAttributes([]netlink.Attribute{
+		{Type: unix.NFTA_EXPR_NAME, Data: []byte("rt\x00")}, // NUL-terminated expression name
+		{Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data},
+	})
+}
+
+func (e *Rt) unmarshal(data []byte) error { // decoding is not implemented; data is ignored
+	return fmt.Errorf("not yet implemented")
+}
diff --git a/nftables_test.go b/nftables_test.go
index 3af63a7..5290219 100644
--- a/nftables_test.go
+++ b/nftables_test.go
@@ -16,7 +16,9 @@ package nftables_test
 
 import (
 	"bytes"
+	"fmt"
 	"net"
+	"strings"
 	"testing"
 
 	"github.com/google/nftables"
@@ -26,6 +28,46 @@ import (
 	"golang.org/x/sys/unix"
 )
 
+// nfdump returns a hexdump of b with 4 bytes per line (like nft --debug=all),
+// allowing users to make sense of large byte literals more easily. A trailing
+// group of 1-3 bytes is written space-separated, without a final newline.
+func nfdump(b []byte) string {
+	var buf bytes.Buffer
+	i := 0
+	// Only consume complete 4-byte groups here: indexing b[i+3] on a shorter
+	// tail would panic with an out-of-range error.
+	for ; i+4 <= len(b); i += 4 {
+		// TODO: show printable characters as ASCII
+		fmt.Fprintf(&buf, "%02x %02x %02x %02x\n", b[i], b[i+1], b[i+2], b[i+3])
+	}
+	// Leftover bytes that do not fill a whole group.
+	for ; i < len(b); i++ {
+		fmt.Fprintf(&buf, "%02x ", b[i])
+	}
+	return buf.String()
+}
+
+// linediff returns a side-by-side diff of two nfdump() return values, flagging
+// lines which are not equal with an exclamation point prefix. The shorter
+// input is padded with empty lines so that no line is silently dropped.
+func linediff(a, b string) string {
+	buf := bytes.NewBufferString("got -- want\n")
+	linesA, linesB := strings.Split(a, "\n"), strings.Split(b, "\n")
+	if n := len(linesB) - len(linesA); n > 0 {
+		linesA = append(linesA, make([]string, n)...)
+	} else {
+		linesB = append(linesB, make([]string, -n)...)
+	}
+	for idx, lineA := range linesA {
+		prefix := "! "
+		if lineA == linesB[idx] {
+			prefix = "  "
+		}
+		fmt.Fprintf(buf, "%s%s -- %s\n", prefix, lineA, linesB[idx])
+	}
+	return buf.String()
+}
+
 func ifname(n string) []byte {
 	b := make([]byte, 16)
 	copy(b, []byte(n+"\x00"))
@@ -75,7 +117,7 @@ func TestConfigureNAT(t *testing.T) {
 					continue
 				}
 				if got, want := b, want[0]; !bytes.Equal(got, want) {
-					t.Errorf("message %d: got %x, want %x", idx, got, want)
+					t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
 				}
 				want = want[1:]
 			}
@@ -284,7 +326,7 @@ func TestGetRule(t *testing.T) {
 					continue
 				}
 				if got, want := b, want[0]; !bytes.Equal(got, want) {
-					t.Errorf("message %d: got %#v, want %#v", idx, got, want)
+					t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
 				}
 				want = want[1:]
 			}
@@ -363,7 +405,7 @@ func TestAddCounter(t *testing.T) {
 					continue
 				}
 				if got, want := b, want[0]; !bytes.Equal(got, want) {
-					t.Errorf("message %d: got %x, want %x", idx, got, want)
+					t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
 				}
 				want = want[1:]
 			}
@@ -426,7 +468,7 @@ func TestGetObjReset(t *testing.T) {
 					continue
 				}
 				if got, want := b, want[0]; !bytes.Equal(got, want) {
-					t.Errorf("message %d: got %#v, want %#v", idx, got, want)
+					t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
 				}
 				want = want[1:]
 			}
@@ -468,3 +510,136 @@ func TestGetObjReset(t *testing.T) {
 		t.Errorf("unexpected number of bytes: got %d, want %d", got, want)
 	}
 }
+
+func TestConfigureClamping(t *testing.T) {
+	// The want byte sequences come from stracing nft(8), e.g.:
+	// strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat
+	//
+	// The nft(8) command sequence was taken from:
+	// https://wiki.nftables.org/wiki-nftables/index.php/Mangle_TCP_options
+	want := [][]byte{
+		// batch begin
+		[]byte("\x00\x00\x00\x0a"),
+		// nft flush ruleset
+		[]byte("\x00\x00\x00\x00"),
+		// nft add table ip filter
+		[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"),
+		// nft add chain filter forward '{' type filter hook forward priority 0 \; '}'
+		[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x03\x00\x66\x6f\x72\x77\x61\x72\x64\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x02\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"),
+		// nft add rule ip filter forward oifname uplink0 tcp flags syn tcp option maxseg size set rt mtu
+		[]byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x66\x6f\x72\x77\x61\x72\x64\x00\xf0\x01\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x07\x08\x00\x01\x00\x00\x00\x00\x01\x38\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x2c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x18\x00\x03\x80\x14\x00\x01\x00\x75\x70\x6c\x69\x6e\x6b\x30\x00\x00\x00\x00\x00\x00\x00\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x10\x08\x00\x01\x00\x00\x00\x00\x01\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x05\x00\x01\x00\x06\x00\x00\x00\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x02\x08\x00\x03\x00\x00\x00\x00\x0d\x08\x00\x04\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x01\x0c\x00\x04\x80\x05\x00\x01\x00\x02\x00\x00\x00\x0c\x00\x05\x80\x05\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x0c\x00\x03\x80\x05\x00\x01\x00\x00\x00\x00\x00\x20\x00\x01\x80\x07\x00\x01\x00\x72\x74\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x40\x00\x01\x80\x0e\x00\x01\x00\x62\x79\x74\x65\x6f\x72\x64\x65\x72\x00\x00\x00\x2c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x01\x08\x00\x04\x00\x00\x00\x00\x02\x08\x00\x05\x00\x00\x00\x00\x02\x3c\x00\x01\x80\x0b\x00\x01\x00\x65\x78\x74\x68\x64\x72\x00\x00\x2c\x00\x02\x80\x08\x00\x07\x00\x00\x00\x00\x01\x05\x00\x02\x00\x02\x00\x00\x00\x08\x00\x03\x00\x00\x00\x00\x02\x08\x00\x04\x00\x00\x00\x00\x02\x08\x00\x06\x00\x00\x00\x00\x01"),
+		// batch end
+		[]byte("\x00\x00\x00\x0a"),
+	}
+
+	c := &nftables.Conn{
+		TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
+			for idx, msg := range req {
+				b, err := msg.MarshalBinary()
+				if err != nil {
+					t.Fatal(err)
+				}
+				if len(b) < 16 {
+					continue
+				}
+				b = b[16:]
+				if len(want) == 0 {
+					t.Errorf("no want entry for message %d: %x", idx, b)
+					continue
+				}
+				if got, want := b, want[0]; !bytes.Equal(got, want) {
+					t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
+				}
+				want = want[1:]
+			}
+			return req, nil
+		},
+	}
+
+	c.FlushRuleset()
+
+	filter := c.AddTable(&nftables.Table{
+		Family: nftables.TableFamilyIPv4,
+		Name:   "filter",
+	})
+
+	forward := c.AddChain(&nftables.Chain{
+		Name:     "forward",
+		Table:    filter,
+		Type:     nftables.ChainTypeFilter,
+		Hooknum:  nftables.ChainHookForward,
+		Priority: nftables.ChainPriorityFilter,
+	})
+
+	c.AddRule(&nftables.Rule{
+		Table: filter,
+		Chain: forward,
+		Exprs: []expr.Any{
+			// [ meta load oifname => reg 1 ]
+			&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
+			// [ cmp eq reg 1 0x696c7075 0x00306b6e 0x00000000 0x00000000 ]
+			&expr.Cmp{
+				Op:       expr.CmpOpEq,
+				Register: 1,
+				Data:     ifname("uplink0"),
+			},
+
+			// [ meta load l4proto => reg 1 ]
+			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
+			// [ cmp eq reg 1 0x00000006 ]
+			&expr.Cmp{
+				Op:       expr.CmpOpEq,
+				Register: 1,
+				Data:     []byte{unix.IPPROTO_TCP},
+			},
+
+			// [ payload load 1b @ transport header + 13 => reg 1 ]
+			&expr.Payload{
+				DestRegister: 1,
+				Base:         expr.PayloadBaseTransportHeader,
+				Offset:       13, // TCP flags byte ("tcp flags" in the nft rule); TODO: named constant
+				Len:          1,  // a single byte; TODO: named constant
+			},
+			// [ bitwise reg 1 = (reg=1 & 0x00000002 ) ^ 0x00000000 ]
+			&expr.Bitwise{
+				DestRegister:   1,
+				SourceRegister: 1,
+				Len:            1,
+				Mask:           []byte{0x02},
+				Xor:            []byte{0x00},
+			},
+			// [ cmp neq reg 1 0x00000000 ]
+			&expr.Cmp{
+				Op:       expr.CmpOpNeq,
+				Register: 1,
+				Data:     []byte{0x00},
+			},
+
+			// [ rt load tcpmss => reg 1 ]
+			&expr.Rt{
+				Register: 1,
+				Key:      expr.RtTCPMSS,
+			},
+			// [ byteorder reg 1 = hton(reg 1, 2, 2) ]
+			&expr.Byteorder{
+				DestRegister:   1,
+				SourceRegister: 1,
+				Op:             expr.ByteorderHton,
+				Len:            2,
+				Size:           2,
+			},
+			// [ exthdr write tcpopt reg 1 => 2b @ 2 + 2 ]
+			&expr.Exthdr{
+				SourceRegister: 1,
+				Type:           2, // TCP option kind ("maxseg" in the nft rule); TODO: named constant
+				Offset:         2,
+				Len:            2,
+				Op:             expr.ExthdrOpTcpopt,
+			},
+		},
+	})
+
+	if err := c.Flush(); err != nil {
+		t.Fatal(err)
+	}
+}