From 1324f5d5a9f7df15e3e204b5a8d8eefe472678a1 Mon Sep 17 00:00:00 2001 From: Michael Stapelberg Date: Sat, 23 Jun 2018 21:12:14 +0200 Subject: [PATCH] add GetRule --- binaryutil/binaryutil.go | 31 ++++++++++ expr/expr.go | 48 +++++++++++++++ expr/immediate.go | 6 ++ expr/nat.go | 6 ++ expr/payload.go | 6 ++ nftables.go | 130 ++++++++++++++++++++++++++++++++++++++- nftables_test.go | 78 +++++++++++++++++++++++ 7 files changed, 304 insertions(+), 1 deletion(-) diff --git a/binaryutil/binaryutil.go b/binaryutil/binaryutil.go index 3e404de..47e7786 100644 --- a/binaryutil/binaryutil.go +++ b/binaryutil/binaryutil.go @@ -26,6 +26,9 @@ import ( type ByteOrder interface { PutUint16(v uint16) []byte PutUint32(v uint32) []byte + PutUint64(v uint64) []byte + Uint32(b []byte) uint32 + Uint64(b []byte) uint64 } // NativeEndian is either little endian or big endian, depending on the native @@ -46,6 +49,20 @@ func (nativeEndian) PutUint32(v uint32) []byte { return buf } +func (nativeEndian) PutUint64(v uint64) []byte { + buf := make([]byte, 8) + natend.NativeEndian.PutUint64(buf, v) + return buf +} + +func (nativeEndian) Uint32(b []byte) uint32 { + return natend.NativeEndian.Uint32(b) +} + +func (nativeEndian) Uint64(b []byte) uint64 { + return natend.NativeEndian.Uint64(b) +} + // BigEndian is like binary.BigEndian, but allocates memory and returns byte // slices, for convenience. var BigEndian ByteOrder = &bigEndian{} @@ -63,3 +80,17 @@ func (bigEndian) PutUint32(v uint32) []byte { binary.BigEndian.PutUint32(buf, v) return buf } + +func (bigEndian) PutUint64(v uint64) []byte { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, v) + return buf +} + +func (bigEndian) Uint32(b []byte) uint32 { + return binary.BigEndian.Uint32(b) +} + +func (bigEndian) Uint64(b []byte) uint64 { + return binary.BigEndian.Uint64(b) +} diff --git a/expr/expr.go b/expr/expr.go index 21997e2..e055145 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -16,6 +16,8 @@ package expr import ( + "fmt" + "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" @@ -26,9 +28,15 @@ func Marshal(e Any) ([]byte, error) { return e.marshal() } +// Unmarshal fills an expression from the specified byte slice. +func Unmarshal(data []byte, e Any) error { + return e.unmarshal(data) +} + // Any is an interface implemented by any expression type. type Any interface { marshal() ([]byte, error) + unmarshal([]byte) error } // MetaKey specifies which piece of meta information should be loaded. See also @@ -87,6 +95,23 @@ func (e *Meta) marshal() ([]byte, error) { }) } +func (e *Meta) unmarshal(data []byte) error { + attrs, err := netlink.UnmarshalAttributes(data) + if err != nil { + return err + } + for _, attr := range attrs { + switch attr.Type { + case unix.NFTA_META_DREG: + e.Register = binaryutil.BigEndian.Uint32(attr.Data) + case unix.NFTA_META_KEY: + e.Key = MetaKey(binaryutil.BigEndian.Uint32(attr.Data)) + } + } + + return nil +} + // Masq (Masquerade) is a special case of SNAT, where the source address is // automagically set to the address of the output interface. See also // https://wiki.nftables.org/wiki-nftables/index.php/Performing_Network_Address_Translation_(NAT)#Masquerading @@ -99,6 +124,10 @@ func (e *Masq) marshal() ([]byte, error) { }) } +func (e *Masq) unmarshal(data []byte) error { + return fmt.Errorf("not yet implemented") +} + // CmpOp specifies which type of comparison should be performed. type CmpOp uint32 @@ -139,3 +168,22 @@ func (e *Cmp) marshal() ([]byte, error) { {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, }) } + +func (e *Cmp) unmarshal(data []byte) error { + attrs, err := netlink.UnmarshalAttributes(data) + if err != nil { + return err + } + for _, attr := range attrs { + switch attr.Type { + case unix.NFTA_CMP_SREG: + e.Register = binaryutil.BigEndian.Uint32(attr.Data) + case unix.NFTA_CMP_OP: + e.Op = CmpOp(binaryutil.BigEndian.Uint32(attr.Data)) + case unix.NFTA_CMP_DATA: + e.Data = attr.Data + } + } + + return nil +} diff --git a/expr/immediate.go b/expr/immediate.go index 38f815d..f050ce5 100644 --- a/expr/immediate.go +++ b/expr/immediate.go @@ -15,6 +15,8 @@ package expr import ( + "fmt" + "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" @@ -45,3 +47,7 @@ func (e *Immediate) marshal() ([]byte, error) { {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, }) } + +func (e *Immediate) unmarshal(data []byte) error { + return fmt.Errorf("not yet implemented") +} diff --git a/expr/nat.go b/expr/nat.go index 3275425..6832be5 100644 --- a/expr/nat.go +++ b/expr/nat.go @@ -15,6 +15,8 @@ package expr import ( + "fmt" + "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" @@ -68,3 +70,7 @@ func (e *NAT) marshal() ([]byte, error) { {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, }) } + +func (e *NAT) unmarshal(data []byte) error { + return fmt.Errorf("not yet implemented") +} diff --git a/expr/payload.go b/expr/payload.go index 5d48d75..d04b9fa 100644 --- a/expr/payload.go +++ b/expr/payload.go @@ -15,6 +15,8 @@ package expr import ( + "fmt" + "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" @@ -51,3 +53,7 @@ func (e *Payload) marshal() ([]byte, error) { {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, }) } + +func (e *Payload) unmarshal(data []byte) error { + return fmt.Errorf("not yet implemented") +} diff --git a/nftables.go b/nftables.go index 8510809..b2d07a2 100644 --- a/nftables.go +++ b/nftables.go @@ -18,6 +18,7 @@ package nftables import ( "fmt" "math" + "strings" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" @@ -175,6 +176,82 @@ type Rule struct { Exprs []expr.Any } +var ruleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE) + +func stringFrom0(b []byte) string { + return strings.TrimSuffix(string(b), "\x00") +} + +func exprsFromMsg(b []byte) ([]expr.Any, error) { + elems, err := netlink.UnmarshalAttributes(b) + if err != nil { + return nil, err + } + var exprs []expr.Any + for _, elem := range elems { + attrs, err := netlink.UnmarshalAttributes(elem.Data) + if err != nil { + return nil, err + } + var ( + name string + data []byte + ) + for _, attr := range attrs { + switch attr.Type { + case unix.NFTA_EXPR_NAME: + name = stringFrom0(attr.Data) + case unix.NFTA_EXPR_DATA: + data = attr.Data + } + } + var e expr.Any + switch name { + case "meta": + e = &expr.Meta{} + case "cmp": + e = &expr.Cmp{} + case "counter": + e = &expr.Counter{} + } + if e == nil { + // TODO: introduce an opaque expression type so that users know + // something is here. + continue // unsupported expression type + } + if err := expr.Unmarshal(data, e); err != nil { + return nil, err + } + exprs = append(exprs, e) + } + return exprs, nil +} + +func ruleFromMsg(msg netlink.Message) (*Rule, error) { + if got, want := msg.Header.Type, ruleHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + attrs, err := netlink.UnmarshalAttributes(msg.Data[4:]) + if err != nil { + return nil, err + } + var r Rule + for _, attr := range attrs { + switch attr.Type { + case unix.NFTA_RULE_TABLE: + r.Table = &Table{Name: stringFrom0(attr.Data)} + case unix.NFTA_RULE_CHAIN: + r.Chain = &Chain{Name: stringFrom0(attr.Data)} + case unix.NFTA_RULE_EXPRESSIONS: + r.Exprs, err = exprsFromMsg(attr.Data) + if err != nil { + return nil, err + } + } + } + return &r, nil +} + // AddRule adds the specified Rule. See also // https://wiki.nftables.org/wiki-nftables/index.php/Simple_rule_management func (cc *Conn) AddRule(r *Rule) *Rule { @@ -194,7 +271,7 @@ func (cc *Conn) AddRule(r *Rule) *Rule { cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ - Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE), + Type: ruleHeaderType, Flags: netlink.HeaderFlagsRequest | netlink.HeaderFlagsAcknowledge | netlink.HeaderFlagsCreate, }, Data: append(extraHeader(uint8(r.Table.Family), 0), data...), @@ -290,3 +367,54 @@ func (cc *Conn) Flush() error { return nil } + +// GetRule returns the rules in the specified table and chain. +func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) { + var conn *netlink.Conn + var err error + if cc.TestDial == nil { + conn, err = netlink.Dial(unix.NETLINK_NETFILTER, nil) + } else { + conn = nltest.Dial(cc.TestDial) + } + if err != nil { + return nil, err + } + + defer conn.Close() + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, + {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, + }) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE), + Flags: netlink.HeaderFlagsRequest | netlink.HeaderFlagsAcknowledge | netlink.HeaderFlagsDump, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := conn.Receive() + if err != nil { + return nil, fmt.Errorf("Receive: %v", err) + } + var rules []*Rule + for _, msg := range reply { + r, err := ruleFromMsg(msg) + if err != nil { + return nil, err + } + rules = append(rules, r) + } + + return rules, nil +} diff --git a/nftables_test.go b/nftables_test.go index c3aa5d2..9136430 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -251,3 +251,81 @@ func TestConfigureNAT(t *testing.T) { t.Fatal(err) } } + +func TestGetRule(t *testing.T) { + // The want byte sequences come from stracing nft(8), e.g.: + // strace -f -v -x -s 2048 -eraw=sendto nft list chain ip filter forward + + want := [][]byte{ + []byte{0x2, 0x0, 0x0, 0x0, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xa, 0x0, 0x2, 0x0, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x0, 0x0, 0x0}, + } + + // The reply messages come from adding log.Printf("msgs: %#v", msgs) to + // (*github.com/mdlayher/netlink/Conn).receive + reply := [][]netlink.Message{ + nil, + []netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x68, Type: 0xa06, Flags: 0x802, Sequence: 0x9acb0443, PID: 0xba38ef3c}, Data: []uint8{0x2, 0x0, 0x0, 0xc, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xc, 0x0, 0x2, 0x0, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x0, 0xc, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x30, 0x0, 0x4, 0x0, 0x2c, 0x0, 0x1, 0x0, 0xc, 0x0, 0x1, 0x0, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x0, 0x1c, 0x0, 0x2, 0x0, 0xc, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6d, 0x92, 0x20, 0x20, 0xc, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x48, 0xd9}}}, + []netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x9acb0443, PID: 0xba38ef3c}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}}, + } + + c := &nftables.Conn{ + TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: got %#v, want %#v", idx, got, want) + } + want = want[1:] + } + rep := reply[0] + reply = reply[1:] + return rep, nil + }, + } + + rules, err := c.GetRule( + &nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }, + &nftables.Chain{ + Name: "input", + }, + ) + if err != nil { + t.Fatal(err) + } + + if got, want := len(rules), 1; got != want { + t.Fatalf("unexpected number of rules: got %d, want %d", got, want) + } + + rule := rules[0] + if got, want := len(rule.Exprs), 1; got != want { + t.Fatalf("unexpected number of exprs: got %d, want %d", got, want) + } + + ce, ok := rule.Exprs[0].(*expr.Counter) + if !ok { + t.Fatalf("unexpected expression type: got %T, want *expr.Counter", rule.Exprs[0]) + } + + if got, want := ce.Packets, uint64(674009); got != want { + t.Errorf("unexpected number of packets: got %d, want %d", got, want) + } + + if got, want := ce.Bytes, uint64(1838293024); got != want { + t.Errorf("unexpected number of bytes: got %d, want %d", got, want) + } +}