diff --git a/expr/ct.go b/expr/ct.go index 1980371..127b6fd 100644 --- a/expr/ct.go +++ b/expr/ct.go @@ -56,6 +56,58 @@ const ( CtStateBitUNTRACKED uint32 = 64 ) +// Missing ct timeout consts +// https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=be0bae0ad31b0adb506f96de083f52a2bd0d4fbf#n1592 +const ( + NFTA_CT_TIMEOUT_L3PROTO = 0x01 + NFTA_CT_TIMEOUT_L4PROTO = 0x02 + NFTA_CT_TIMEOUT_DATA = 0x03 +) + +type CtStatePolicyTimeout map[uint16]uint32 + +const ( + // https://git.netfilter.org/libnftnl/tree/src/obj/ct_timeout.c?id=116e95aa7b6358c917de8c69f6f173874030b46b#n24 + CtStateTCPSYNSENT = iota + CtStateTCPSYNRECV + CtStateTCPESTABLISHED + CtStateTCPFINWAIT + CtStateTCPCLOSEWAIT + CtStateTCPLASTACK + CtStateTCPTIMEWAIT + CtStateTCPCLOSE + CtStateTCPSYNSENT2 + CtStateTCPRETRANS + CtStateTCPUNACK +) + +// https://git.netfilter.org/libnftnl/tree/src/obj/ct_timeout.c?id=116e95aa7b6358c917de8c69f6f173874030b46b#n38 +var CtStateTCPTimeoutDefaults CtStatePolicyTimeout = map[uint16]uint32{ + CtStateTCPSYNSENT: 120, + CtStateTCPSYNRECV: 60, + CtStateTCPESTABLISHED: 43200, + CtStateTCPFINWAIT: 120, + CtStateTCPCLOSEWAIT: 60, + CtStateTCPLASTACK: 30, + CtStateTCPTIMEWAIT: 120, + CtStateTCPCLOSE: 10, + CtStateTCPSYNSENT2: 120, + CtStateTCPRETRANS: 300, + CtStateTCPUNACK: 300, +} + +const ( + // https://git.netfilter.org/libnftnl/tree/src/obj/ct_timeout.c?id=116e95aa7b6358c917de8c69f6f173874030b46b#n57 + CtStateUDPUNREPLIED = iota + CtStateUDPREPLIED +) + +// https://git.netfilter.org/libnftnl/tree/src/obj/ct_timeout.c?id=116e95aa7b6358c917de8c69f6f173874030b46b#n57 +var CtStateUDPTimeoutDefaults CtStatePolicyTimeout = map[uint16]uint32{ + CtStateUDPUNREPLIED: 30, + CtStateUDPREPLIED: 180, +} + // Ct defines type for NFT connection tracking type Ct struct { Register uint32 @@ -265,3 +317,84 @@ func (c *CtExpect) unmarshal(fam byte, data []byte) error { } return ad.Err() } + +type CtTimeout struct { + L3Proto uint16 + L4Proto uint8 + Policy CtStatePolicyTimeout +} + +func (c *CtTimeout) marshal(fam byte) ([]byte, error) { + exprData, err := c.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("cttimeout\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (c *CtTimeout) marshalData(fam byte) ([]byte, error) { + var policy CtStatePolicyTimeout + switch c.L4Proto { + case unix.IPPROTO_UDP: + policy = CtStateUDPTimeoutDefaults + default: + policy = CtStateTCPTimeoutDefaults + } + + for k, v := range c.Policy { + policy[k] = v + } + + var policyAttrs []netlink.Attribute + for k, v := range policy { + policyAttrs = append(policyAttrs, netlink.Attribute{Type: k + 1, Data: binaryutil.BigEndian.PutUint32(v)}) + } + policyData, err := netlink.MarshalAttributes(policyAttrs) + if err != nil { + return nil, err + } + + exprData := []netlink.Attribute{ + {Type: NFTA_CT_TIMEOUT_L3PROTO, Data: binaryutil.BigEndian.PutUint16(c.L3Proto)}, + {Type: NFTA_CT_TIMEOUT_L4PROTO, Data: []byte{c.L4Proto}}, + {Type: unix.NLA_F_NESTED | NFTA_CT_TIMEOUT_DATA, Data: policyData}, + } + + return netlink.MarshalAttributes(exprData) +} + +func (c *CtTimeout) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTA_CT_TIMEOUT_L3PROTO: + c.L3Proto = ad.Uint16() + case NFTA_CT_TIMEOUT_L4PROTO: + c.L4Proto = ad.Uint8() + case NFTA_CT_TIMEOUT_DATA: + decoder, err := netlink.NewAttributeDecoder(ad.Bytes()) + decoder.ByteOrder = binary.BigEndian + if err != nil { + return err + } + for decoder.Next() { + switch c.L4Proto { + case unix.IPPROTO_UDP: + c.Policy = CtStateUDPTimeoutDefaults + default: + c.Policy = CtStateTCPTimeoutDefaults + } + c.Policy[decoder.Type()-1] = decoder.Uint32() + } + } + } + return ad.Err() +} diff --git a/expr/expr.go b/expr/expr.go index 00b81c2..fc00a69 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -205,6 +205,8 @@ func exprFromName(name string) Any { e = &CtExpect{} case "secmark": e = &SecMark{} + case "cttimeout": + e = &CtTimeout{} } return e } diff --git a/nftables_test.go b/nftables_test.go index 72600f6..b241327 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1541,6 +1541,131 @@ func TestSynProxyObject(t *testing.T) { } } +func TestCtTimeout(t *testing.T) { + t.Parallel() + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := conn.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + tests := [...]struct { + Name string + Input expr.CtTimeout + Expect expr.CtTimeout + }{ + { + Name: "timeout-blank-tcp-policy", + Input: expr.CtTimeout{L4Proto: unix.IPPROTO_TCP}, + Expect: expr.CtTimeout{ + L4Proto: unix.IPPROTO_TCP, + L3Proto: unix.NFPROTO_UNSPEC, + Policy: expr.CtStateTCPTimeoutDefaults, + }, + }, + { + Name: "timeout-blank-udp-policy", + Input: expr.CtTimeout{L4Proto: unix.IPPROTO_UDP}, + Expect: expr.CtTimeout{ + L4Proto: unix.IPPROTO_UDP, + L3Proto: unix.NFPROTO_UNSPEC, + Policy: expr.CtStateUDPTimeoutDefaults, + }, + }, + { + Name: "timeout-partial-tcp-policy", + Input: expr.CtTimeout{ + L4Proto: unix.IPPROTO_TCP, + L3Proto: unix.NFPROTO_IPV4, + Policy: expr.CtStatePolicyTimeout{ + expr.CtStateTCPSYNSENT: 100, + expr.CtStateTCPESTABLISHED: 5, + expr.CtStateTCPCLOSEWAIT: 9, + }, + }, + Expect: expr.CtTimeout{ + L4Proto: unix.IPPROTO_TCP, + L3Proto: unix.NFPROTO_IPV4, + Policy: expr.CtStatePolicyTimeout{ + expr.CtStateTCPSYNSENT: 100, + expr.CtStateTCPSYNRECV: 60, + expr.CtStateTCPESTABLISHED: 5, + expr.CtStateTCPFINWAIT: 120, + expr.CtStateTCPCLOSEWAIT: 9, + expr.CtStateTCPLASTACK: 30, + expr.CtStateTCPTIMEWAIT: 120, + expr.CtStateTCPCLOSE: 10, + expr.CtStateTCPSYNSENT2: 120, + expr.CtStateTCPRETRANS: 300, + expr.CtStateTCPUNACK: 300, + }, + }, + }, + { + Name: "timeout-complete-udp-policy", + Input: expr.CtTimeout{ + L4Proto: unix.IPPROTO_UDP, + L3Proto: unix.NFPROTO_IPV6, + Policy: expr.CtStatePolicyTimeout{ + expr.CtStateUDPUNREPLIED: 500, + expr.CtStateUDPREPLIED: 10000, + }, + }, + Expect: expr.CtTimeout{ + L4Proto: unix.IPPROTO_UDP, + L3Proto: unix.NFPROTO_IPV6, + Policy: expr.CtStatePolicyTimeout{ + expr.CtStateUDPUNREPLIED: 500, + expr.CtStateUDPREPLIED: 10000, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + ctt1 := conn.AddObj(&nftables.NamedObj{ + Table: table, + Name: tt.Name, + Type: nftables.ObjTypeCtTimeout, + Obj: &tt.Input, + }) + + if err := conn.Flush(); err != nil { + t.Fatalf(err.Error()) + } + + obj, err := conn.GetObject(ctt1) + if err != nil { + t.Errorf("c.GetObject(ctt1) failed: %v failed", err) + } + + ctt2, ok := obj.(*nftables.NamedObj) + if !ok { + t.Fatalf("unexpected type: got %T, want *nftables.NamedObj", ctt2) + } + + o1 := ctt2.Obj.(*expr.CtTimeout) + o2 := &tt.Expect + if got, want := o1.L3Proto, o2.L3Proto; got != want { + t.Fatalf("unexpected l3proto: got %d, want %d", got, want) + } + + if got, want := o1.L4Proto, o2.L4Proto; got != want { + t.Fatalf("unexpected l4proto: got %d, want %d", got, want) + } + + if got, want := o1.Policy, o2.Policy; !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected policy: got %v, want %v", got, want) + } + }) + } +} + func TestCtExpect(t *testing.T) { conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) defer nftest.CleanupSystemConn(t, newNS) @@ -1642,7 +1767,7 @@ func TestCtHelper(t *testing.T) { helper, ok := obj1.(*nftables.NamedObj) if !ok { - t.Fatalf("unexpected type: got %T, want *nftables.ObjAttr", obj1) + t.Fatalf("unexpected type: got %T, want *nftables.NamedObj", obj1) } if got, want := helper.Name, "ftp-standard"; got != want { diff --git a/obj.go b/obj.go index 116c92f..3fcd6d7 100644 --- a/obj.go +++ b/obj.go @@ -52,8 +52,8 @@ var objByObjTypeMagic = map[ObjType]string{ ObjTypeLimit: "limit", ObjTypeConnLimit: "connlimit", ObjTypeCtHelper: "cthelper", - ObjTypeTunnel: "tunnel", // not implemented in expr - ObjTypeCtTimeout: "cttimeout", // not implemented in expr + ObjTypeTunnel: "tunnel", // not implemented in expr + ObjTypeCtTimeout: "cttimeout", ObjTypeSecMark: "secmark", ObjTypeCtExpect: "ctexpect", ObjTypeSynProxy: "synproxy",