diff --git a/nftables.go b/nftables.go index d6f54a9..4f86dfc 100644 --- a/nftables.go +++ b/nftables.go @@ -420,3 +420,174 @@ func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) { return rules, nil } + +// CounterObj implements Obj. +type CounterObj struct { + Table *Table + Name string // e.g. “fwded” + + Bytes uint64 + Packets uint64 +} + +func (c *CounterObj) unmarshal(attrs []netlink.Attribute) error { + for _, attr := range attrs { + switch attr.Type { + case unix.NFTA_COUNTER_BYTES: + c.Bytes = binaryutil.BigEndian.Uint64(attr.Data) + case unix.NFTA_COUNTER_PACKETS: + c.Packets = binaryutil.BigEndian.Uint64(attr.Data) + } + } + return nil +} + +func (c *CounterObj) family() TableFamily { + return c.Table.Family +} + +func (c *CounterObj) marshal(data bool) ([]byte, error) { + obj, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_COUNTER_BYTES, Data: binaryutil.BigEndian.PutUint64(c.Bytes)}, + {Type: unix.NFTA_COUNTER_PACKETS, Data: binaryutil.BigEndian.PutUint64(c.Packets)}, + }) + if err != nil { + return nil, err + } + const NFT_OBJECT_COUNTER = 1 // TODO: get into x/sys/unix + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(c.Table.Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(c.Name + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(NFT_OBJECT_COUNTER)}, + } + if data { + attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: obj}) + } + return netlink.MarshalAttributes(attrs) +} + +// Obj represents a netfilter stateful object. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects +type Obj interface { + family() TableFamily + unmarshal([]netlink.Attribute) error + marshal(data bool) ([]byte, error) +} + +// AddObj adds the specified Obj. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects +func (cc *Conn) AddObj(o Obj) Obj { + data, err := o.marshal(true) + if err != nil { + cc.setErr(err) + return nil + } + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), + Flags: netlink.HeaderFlagsRequest | netlink.HeaderFlagsAcknowledge | netlink.HeaderFlagsCreate, + }, + Data: append(extraHeader(uint8(o.family()), 0), data...), + }) + return o +} + +var objHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ) + +func objFromMsg(msg netlink.Message) (Obj, error) { + if got, want := msg.Header.Type, objHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + attrs, err := netlink.UnmarshalAttributes(msg.Data[4:]) + if err != nil { + return nil, err + } + var ( + table *Table + name string + objectType uint32 + ) + const NFT_OBJECT_COUNTER = 1 // TODO: get into x/sys/unix + for _, attr := range attrs { + switch attr.Type { + case unix.NFTA_OBJ_TABLE: + table = &Table{Name: stringFrom0(attr.Data)} + case unix.NFTA_OBJ_NAME: + name = stringFrom0(attr.Data) + case unix.NFTA_OBJ_TYPE: + objectType = binaryutil.BigEndian.Uint32(attr.Data) + case unix.NFTA_OBJ_DATA: + switch objectType { + case NFT_OBJECT_COUNTER: + attrs, err := netlink.UnmarshalAttributes(attr.Data) + if err != nil { + return nil, err + } + o := CounterObj{ + Table: table, + Name: name, + } + return &o, o.unmarshal(attrs) + } + } + } + return nil, fmt.Errorf("malformed stateful object") +} + +func (cc *Conn) getObj(o Obj, msgType uint16) ([]Obj, error) { + var conn *netlink.Conn + var err error + if cc.TestDial == nil { + conn, err = netlink.Dial(unix.NETLINK_NETFILTER, nil) + } else { + conn = nltest.Dial(cc.TestDial) + } + if err != nil { + return nil, err + } + + defer conn.Close() + + data, err := o.marshal(false) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | msgType), + Flags: netlink.HeaderFlagsRequest | netlink.HeaderFlagsAcknowledge | netlink.HeaderFlagsDump, + }, + Data: append(extraHeader(uint8(o.family()), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := conn.Receive() + if err != nil { + return nil, fmt.Errorf("Receive: %v", err) + } + var objs []Obj + for _, msg := range reply { + o, err := objFromMsg(msg) + if err != nil { + return nil, err + } + objs = append(objs, o) + } + + return objs, nil +} + +// GetObj gets the specified Obj without resetting it. +func (cc *Conn) GetObj(o Obj) ([]Obj, error) { + return cc.getObj(o, unix.NFT_MSG_GETOBJ) +} + +// GetObjReset gets the specified Obj and resets it. +func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) { + return cc.getObj(o, unix.NFT_MSG_GETOBJ_RESET) +} diff --git a/nftables_test.go b/nftables_test.go index 9136430..c2571b7 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -329,3 +329,135 @@ func TestGetRule(t *testing.T) { t.Errorf("unexpected number of bytes: got %d, want %d", got, want) } } + +func TestAddCounter(t *testing.T) { + // The want byte sequences come from stracing nft(8), e.g.: + // strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat + // + // The nft(8) command sequence was taken from: + // https://wiki.nftables.org/wiki-nftables/index.php/Performing_Network_Address_Translation_(NAT) + want := [][]byte{ + // batch begin + []byte("\x00\x00\x0a\x00"), + // nft add counter ip filter fwded + []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0a\x00\x02\x00\x66\x77\x64\x65\x64\x00\x00\x00\x08\x00\x03\x00\x00\x00\x00\x01\x1c\x00\x04\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00"), + // nft add rule ip filter forward counter name fwded + []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x66\x6f\x72\x77\x61\x72\x64\x00\x2c\x00\x04\x80\x28\x00\x01\x80\x0b\x00\x01\x00\x6f\x62\x6a\x72\x65\x66\x00\x00\x18\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x09\x00\x02\x00\x66\x77\x64\x65\x64\x00\x00\x00"), + // batch end + []byte("\x00\x00\x0a\x00"), + } + + c := &nftables.Conn{ + TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: got %x, want %x", idx, got, want) + } + want = want[1:] + } + return req, nil + }, + } + + c.AddObj(&nftables.CounterObj{ + Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}, + Name: "fwded", + Bytes: 0, + Packets: 0, + }) + + c.AddRule(&nftables.Rule{ + Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}, + Chain: &nftables.Chain{Name: "forward", Type: nftables.ChainTypeFilter}, + Exprs: []expr.Any{ + &expr.Objref{ + Type: 1, + Name: "fwded", + }, + }, + }) + + if err := c.Flush(); err != nil { + t.Fatal(err) + } +} + +func TestGetObjReset(t *testing.T) { + // The want byte sequences come from stracing nft(8), e.g.: + // strace -f -v -x -s 2048 -eraw=sendto nft list chain ip filter forward + + want := [][]byte{ + []byte{0x2, 0x0, 0x0, 0x0, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xa, 0x0, 0x2, 0x0, 0x66, 0x77, 0x64, 0x65, 0x64, 0x0, 0x0, 0x0, 0x8, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x1}, + } + + // The reply messages come from adding log.Printf("msgs: %#v", msgs) to + // (*github.com/mdlayher/netlink/Conn).receive + reply := [][]netlink.Message{ + nil, + []netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x64, Type: 0xa12, Flags: 0x802, Sequence: 0x9acb0443, PID: 0xde9}, Data: []uint8{0x2, 0x0, 0x0, 0x10, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xa, 0x0, 0x2, 0x0, 0x66, 0x77, 0x64, 0x65, 0x64, 0x0, 0x0, 0x0, 0x8, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0x1, 0x1c, 0x0, 0x4, 0x0, 0xc, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x61, 0xc, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9, 0xc, 0x0, 0x6, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2}}}, + []netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x9acb0443, PID: 0xde9}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}}, + } + + c := &nftables.Conn{ + TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: got %#v, want %#v", idx, got, want) + } + want = want[1:] + } + rep := reply[0] + reply = reply[1:] + return rep, nil + }, + } + + objs, err := c.GetObjReset(&nftables.CounterObj{ + Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}, + Name: "fwded", + }) + + if err != nil { + t.Fatal(err) + } + + if got, want := len(objs), 1; got != want { + t.Fatalf("unexpected number of rules: got %d, want %d", got, want) + } + + obj := objs[0] + co, ok := obj.(*nftables.CounterObj) + if !ok { + t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj) + } + if got, want := co.Packets, uint64(9); got != want { + t.Errorf("unexpected number of packets: got %d, want %d", got, want) + } + if got, want := co.Bytes, uint64(1121); got != want { + t.Errorf("unexpected number of bytes: got %d, want %d", got, want) + } +}