diff --git a/expr/expr.go b/expr/expr.go index 66e26a9..00b81c2 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -203,6 +203,8 @@ func exprFromName(name string) Any { e = &SynProxy{} case "ctexpect": e = &CtExpect{} + case "secmark": + e = &SecMark{} } return e } diff --git a/expr/secmark.go b/expr/secmark.go new file mode 100644 index 0000000..3faf87f --- /dev/null +++ b/expr/secmark.go @@ -0,0 +1,64 @@ +// Copyright 2024 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// From https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=be0bae0ad31b0adb506f96de083f52a2bd0d4fbf#n1338 +const ( + NFTA_SECMARK_CTX = 0x01 +) + +type SecMark struct { + Ctx string +} + +func (e *SecMark) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("secmark\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *SecMark) marshalData(fam byte) ([]byte, error) { + attrs := []netlink.Attribute{ + {Type: NFTA_SECMARK_CTX, Data: []byte(e.Ctx)}, + } + return netlink.MarshalAttributes(attrs) +} + +func (e *SecMark) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTA_SECMARK_CTX: + e.Ctx = ad.String() + } + } + return ad.Err() +} diff --git a/nftables_test.go b/nftables_test.go index 020757f..72600f6 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1383,6 +1383,68 @@ func TestCt(t *testing.T) { } } +func TestSecMarkMarshaling(t *testing.T) { + // Testing marshaling since secmark requires live selinux tag otherwise + // errors with conn.Receive: netlink receive: no such file or directory. + // More information available here: + // https://git.netfilter.org/nftables/tree/files/examples/secmark.nft?id=26d9cbefb10e6bc3765df7e9e7a4fc3b951a80f3#n6 + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // sudo nft add table inet filter + []byte("\x01\x00\x00\x00\x0b\x00\x01\x00filter\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // sudo nft add secmark inet filter sshtag '{ ctx "system_u:object_r:ssh_server_packet_t:s0" }' + []byte("\x01\x00\x00\x00\x0b\x00\x01\x00filter\x00\x00\x0b\x00\x02\x00sshtag\x00\x00\x08\x00\x03\x00\x00\x00\x00\x080\x00\x04\x80,\x00\x01\x00system_u:object_r:ssh_server_packet_t:s0"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + + conn, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + want = want[1:] + } + return req, nil + })) + if err != nil { + t.Fatal(err) + } + + table := conn.AddTable(&nftables.Table{ + Family: nftables.TableFamilyINet, + Name: "filter", + }) + + sec := &nftables.NamedObj{ + Table: table, + Name: "sshtag", + Type: nftables.ObjTypeSecMark, + Obj: &expr.SecMark{ + Ctx: "system_u:object_r:ssh_server_packet_t:s0", + }, + } + conn.AddObj(sec) + + if err := conn.Flush(); err != nil { + t.Fatalf(err.Error()) + } +} + func TestSynProxyObject(t *testing.T) { conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) defer nftest.CleanupSystemConn(t, newNS) diff --git a/obj.go b/obj.go index 6e8be6d..116c92f 100644 --- a/obj.go +++ b/obj.go @@ -54,7 +54,7 @@ var objByObjTypeMagic = map[ObjType]string{ ObjTypeCtHelper: "cthelper", ObjTypeTunnel: "tunnel", // not implemented in expr ObjTypeCtTimeout: "cttimeout", // not implemented in expr - ObjTypeSecMark: "secmark", // not implemented in expr + ObjTypeSecMark: "secmark", ObjTypeCtExpect: "ctexpect", ObjTypeSynProxy: "synproxy", }