From 2fecffcfe11c49a3067ac8983d945717d9e1e7e1 Mon Sep 17 00:00:00 2001 From: turekt <32360115+turekt@users.noreply.github.com> Date: Mon, 9 Sep 2024 08:35:05 +0200 Subject: [PATCH] Add ct expect support (#272) --- expr/ct.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++++ expr/expr.go | 2 ++ nftables_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++-- obj.go | 4 +-- 4 files changed, 143 insertions(+), 4 deletions(-) diff --git a/expr/ct.go b/expr/ct.go index 7ba113a..1980371 100644 --- a/expr/ct.go +++ b/expr/ct.go @@ -194,3 +194,74 @@ func (c *CtHelper) unmarshal(fam byte, data []byte) error { } return ad.Err() } + +// From https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=be0bae0ad31b0adb506f96de083f52a2bd0d4fbf#n1601 +// Currently not available in sys/unix +const ( + NFTA_CT_EXPECT_L3PROTO = 0x01 + NFTA_CT_EXPECT_L4PROTO = 0x02 + NFTA_CT_EXPECT_DPORT = 0x03 + NFTA_CT_EXPECT_TIMEOUT = 0x04 + NFTA_CT_EXPECT_SIZE = 0x05 +) + +type CtExpect struct { + L3Proto uint16 + L4Proto uint8 + DPort uint16 + Timeout uint32 + Size uint8 +} + +func (c *CtExpect) marshal(fam byte) ([]byte, error) { + exprData, err := c.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("ctexpect\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (c *CtExpect) marshalData(fam byte) ([]byte, error) { + // all elements except l3proto must be defined + // per https://git.netfilter.org/nftables/tree/doc/stateful-objects.txt?id=db70959a5ccf2952b218f51c3d529e186a5a43bb#n119 + // from man page: l3proto is derived from the table family by default + exprData := []netlink.Attribute{ + {Type: NFTA_CT_EXPECT_L4PROTO, Data: []byte{c.L4Proto}}, + {Type: NFTA_CT_EXPECT_DPORT, Data: binaryutil.BigEndian.PutUint16(c.DPort)}, + {Type: NFTA_CT_EXPECT_TIMEOUT, Data: binaryutil.BigEndian.PutUint32(c.Timeout)}, + {Type: NFTA_CT_EXPECT_SIZE, Data: []byte{c.Size}}, + } + + if c.L3Proto != 0 { + attr := netlink.Attribute{Type: NFTA_CT_EXPECT_L3PROTO, Data: binaryutil.BigEndian.PutUint16(c.L3Proto)} + exprData = append(exprData, attr) + } + return netlink.MarshalAttributes(exprData) +} + +func (c *CtExpect) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTA_CT_EXPECT_L3PROTO: + c.L3Proto = ad.Uint16() + case NFTA_CT_EXPECT_L4PROTO: + c.L4Proto = ad.Uint8() + case NFTA_CT_EXPECT_DPORT: + c.DPort = ad.Uint16() + case NFTA_CT_EXPECT_TIMEOUT: + c.Timeout = ad.Uint32() + case NFTA_CT_EXPECT_SIZE: + c.Size = ad.Uint8() + } + } + return ad.Err() +} diff --git a/expr/expr.go b/expr/expr.go index 2ae57a9..66e26a9 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -201,6 +201,8 @@ func exprFromName(name string) Any { e = &CtHelper{} case "synproxy": e = &SynProxy{} + case "ctexpect": + e = &CtExpect{} } return e } diff --git a/nftables_test.go b/nftables_test.go index bb61d6d..020757f 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1428,7 +1428,6 @@ func TestSynProxyObject(t *testing.T) { conn.AddObj(syn1) conn.AddObj(syn2) conn.AddObj(syn3) - if err := conn.Flush(); err != nil { t.Fatalf(err.Error()) } @@ -1437,7 +1436,6 @@ func TestSynProxyObject(t *testing.T) { if err != nil { t.Errorf("c.GetObjects(table) failed: %v", err) } - if got, want := len(objs), 3; got != want { t.Fatalf("received %d objects, expected %d", got, want) } @@ -1481,6 +1479,74 @@ func TestSynProxyObject(t *testing.T) { } } +func TestCtExpect(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := conn.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + cte := &nftables.NamedObj{ + Table: table, + Name: "expect", + Type: nftables.ObjTypeCtExpect, + Obj: &expr.CtExpect{ + L3Proto: unix.NFPROTO_IPV4, + L4Proto: unix.IPPROTO_TCP, + DPort: 53, + Timeout: 20, + Size: 100, + }, + } + + conn.AddObj(cte) + if err := conn.Flush(); err != nil { + t.Fatalf(err.Error()) + } + + objs, err := conn.GetNamedObjects(table) + if err != nil { + t.Errorf("c.GetObjects(table) failed: %v", err) + } + + if got, want := len(objs), 1; got != want { + t.Fatalf("received %d objects, expected %d", got, want) + } + + obj := objs[0].(*nftables.NamedObj) + if got, want := obj.Name, cte.Name; got != want { + t.Errorf("object names are not equal: got %s, want %s", got, want) + } + if got, want := obj.Type, cte.Type; got != want { + t.Errorf("object types are not equal: got %v, want %v", got, want) + } + if got, want := obj.Table.Name, cte.Table.Name; got != want { + t.Errorf("object tables are not equal: got %s, want %s", got, want) + } + + ce1 := obj.Obj.(*expr.CtExpect) + ce2 := cte.Obj.(*expr.CtExpect) + if got, want := ce1.L3Proto, ce2.L3Proto; got != want { + t.Errorf("object l3proto not equal: got %d, want %d", got, want) + } + if got, want := ce1.L4Proto, ce2.L4Proto; got != want { + t.Errorf("object l4proto not equal: got %d, want %d", got, want) + } + if got, want := ce1.DPort, ce2.DPort; got != want { + t.Errorf("object dport not equal: got %d, want %d", got, want) + } + if got, want := ce1.Size, ce2.Size; got != want { + t.Errorf("object Size not equal: got %d, want %d", got, want) + } + if got, want := ce1.Timeout, ce2.Timeout; got != want { + t.Errorf("object timeout not equal: got %d, want %d", got, want) + } +} + func TestCtHelper(t *testing.T) { conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) defer nftest.CleanupSystemConn(t, newNS) diff --git a/obj.go b/obj.go index 421a87c..6e8be6d 100644 --- a/obj.go +++ b/obj.go @@ -55,8 +55,8 @@ var objByObjTypeMagic = map[ObjType]string{ ObjTypeTunnel: "tunnel", // not implemented in expr ObjTypeCtTimeout: "cttimeout", // not implemented in expr ObjTypeSecMark: "secmark", // not implemented in expr - ObjTypeCtExpect: "ctexpect", // not implemented in expr - ObjTypeSynProxy: "synproxy", // not implemented in expr + ObjTypeCtExpect: "ctexpect", + ObjTypeSynProxy: "synproxy", } // Obj represents a netfilter stateful object. See also