From 64aca752d17d2aafa446e76c35776e3736d53846 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Mon, 9 Mar 2020 08:43:47 +0100 Subject: [PATCH] Remove Object API (#100) Co-authored-by: Alexis PIRES --- nftables_test.go | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ obj.go | 36 ++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/nftables_test.go b/nftables_test.go index 5db5672..c12d7a4 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -711,6 +711,64 @@ func TestAddCounter(t *testing.T) { } } +func TestDeleteCounter(t *testing.T) { + // The want byte sequences come from stracing nft(8), e.g.: + // strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat + // + // The nft(8) command sequence was taken from: + // https://wiki.nftables.org/wiki-nftables/index.php/Performing_Network_Address_Translation_(NAT) + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add counter ip filter fwded + []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0a\x00\x02\x00\x66\x77\x64\x65\x64\x00\x00\x00\x08\x00\x03\x00\x00\x00\x00\x01\x1c\x00\x04\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00"), + // nft delete counter ip filter fwded + []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0a\x00\x02\x00\x66\x77\x64\x65\x64\x00\x00\x00\x08\x00\x03\x00\x00\x00\x00\x01\x04\x00\x04\x80"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + + c := &nftables.Conn{ + TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + want = want[1:] + } + return req, nil + }, + } + + c.AddObj(&nftables.CounterObj{ + Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}, + Name: "fwded", + Bytes: 0, + Packets: 0, + }) + + c.DeleteObject(&nftables.CounterObj{ + Table: &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}, + Name: "fwded", + }) + + if err := c.Flush(); err != nil { + t.Fatal(err) + } +} + func TestDelRule(t *testing.T) { want := [][]byte{ // batch begin diff --git a/obj.go b/obj.go index 223d910..72936bd 100644 --- a/obj.go +++ b/obj.go @@ -33,6 +33,11 @@ type Obj interface { marshal(data bool) ([]byte, error) } +// AddObject adds the specified Obj. Alias of AddObj. +func (cc *Conn) AddObject(o Obj) Obj { + return cc.AddObj(o) +} + // AddObj adds the specified Obj. See also // https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects func (cc *Conn) AddObj(o Obj) Obj { @@ -54,6 +59,27 @@ func (cc *Conn) AddObj(o Obj) Obj { return o } +// DeleteObject deletes the specified Obj +func (cc *Conn) DeleteObject(o Obj) { + cc.Lock() + defer cc.Unlock() + data, err := o.marshal(false) + if err != nil { + cc.setErr(err) + return + } + + data = append(data, cc.marshalAttr([]netlink.Attribute{{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA}})...) + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(o.family()), 0), data...), + }) +} + // GetObj is a legacy method that return all Obj that belongs // to the same table as the given one func (cc *Conn) GetObj(o Obj) ([]Obj, error) { @@ -69,6 +95,11 @@ func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) { // GetObject gets the specified Object func (cc *Conn) GetObject(o Obj) (Obj, error) { objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ) + + if len(objs) == 0 { + return nil, err + } + return objs[0], err } @@ -80,6 +111,11 @@ func (cc *Conn) GetObjects(t *Table) ([]Obj, error) { // ResetObject reset the given Obj func (cc *Conn) ResetObject(o Obj) (Obj, error) { objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ_RESET) + + if len(objs) == 0 { + return nil, err + } + return objs[0], err }