From ce5436e43c4814759489120c8c6e49757ddf9292 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Tue, 14 Jan 2020 17:34:05 +0100 Subject: [PATCH] Before GetObj/GetObjReset returned all objects instead of only the object, now it's return only the given one. New methods: GetObj(Obj) (Obj, error), GetObjs(Table) ([]Obj, error), ResetObj(Obj) (Obj, error), ResetObjs(Table) ([]Obj, error). Deleted Methods: GetObj(Obj) ([]Obj, error), GetObjReset(Obj) ([]Obj, error) --- counter.go | 4 ++ nftables_test.go | 155 +++++++++++++++++++++++++++++++++++++++++++++-- obj.go | 44 ++++++++++---- 3 files changed, 186 insertions(+), 17 deletions(-) diff --git a/counter.go b/counter.go index 58c008c..e428202 100644 --- a/counter.go +++ b/counter.go @@ -41,6 +41,10 @@ func (c *CounterObj) unmarshal(ad *netlink.AttributeDecoder) error { return ad.Err() } +func (c *CounterObj) table() *Table { + return c.Table +} + func (c *CounterObj) family() TableFamily { return c.Table.Family } diff --git a/nftables_test.go b/nftables_test.go index 7489cd1..574a230 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1252,7 +1252,7 @@ func TestGetObjReset(t *testing.T) { } filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4} - objs, err := c.GetObjReset(&nftables.CounterObj{ + obj, err := c.ResetObj(&nftables.CounterObj{ Table: filter, Name: "fwded", }) @@ -1261,11 +1261,6 @@ func TestGetObjReset(t *testing.T) { t.Fatal(err) } - if got, want := len(objs), 1; got != want { - t.Fatalf("unexpected number of rules: got %d, want %d", got, want) - } - - obj := objs[0] co, ok := obj.(*nftables.CounterObj) if !ok { t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj) @@ -1284,6 +1279,154 @@ func TestGetObjReset(t *testing.T) { } } +func TestObjAPI(t *testing.T) { + if os.Getenv("TRAVIS") == "true" { + t.SkipNow() + } + + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + table := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + tableOther := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "foo", + }) + + chain := c.AddChain(&nftables.Chain{ + Name: "chain", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityFilter, + }) + + counter1 := c.AddObj(&nftables.CounterObj{ + Table: table, + Name: "fwded1", + Bytes: 1, + Packets: 1, + }) + + counter2 := c.AddObj(&nftables.CounterObj{ + Table: table, + Name: "fwded2", + Bytes: 1, + Packets: 1, + }) + + c.AddObj(&nftables.CounterObj{ + Table: tableOther, + Name: "fwdedOther", + Bytes: 0, + Packets: 0, + }) + + c.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Objref{ + Type: 1, + Name: "fwded1", + }, + }, + }) + + if err := c.Flush(); err != nil { + t.Fatalf(err.Error()) + } + + objs, err := c.GetObjs(table) + + if err != nil { + t.Errorf("c.GetObjs(table) failed: %v failed", err) + } + + if got := len(objs); got != 2 { + t.Fatalf("unexpected number of objects: got %d, want %d", got, 2) + } + + objsOther, err := c.GetObjs(tableOther) + + if err != nil { + t.Errorf("c.GetObjs(tableOther) failed: %v failed", err) + } + + if got := len(objsOther); got != 1 { + t.Fatalf("unexpected number of objects: got %d, want %d", got, 1) + } + + obj1, err := c.GetObj(counter1) + + if err != nil { + t.Errorf("c.GetObj(counter1) failed: %v failed", err) + } + + rcounter1, ok := obj1.(*nftables.CounterObj) + + if !ok { + t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", rcounter1) + } + + if rcounter1.Name != "fwded1" { + t.Fatalf("unexpected counter name: got %s, want %s", rcounter1.Name, "fwded1") + } + + obj2, err := c.GetObj(counter2) + + if err != nil { + t.Errorf("c.GetObj(counter2) failed: %v failed", err) + } + + rcounter2, ok := obj2.(*nftables.CounterObj) + + if !ok { + t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", rcounter2) + } + + if rcounter2.Name != "fwded2" { + t.Fatalf("unexpected counter name: got %s, want %s", rcounter2.Name, "fwded2") + } + + _, err = c.ResetObj(counter1) + + if err != nil { + t.Errorf("c.ResetObjs(table) failed: %v failed", err) + } + + obj1, err = c.GetObj(counter1) + + if err != nil { + t.Errorf("c.GetObj(counter1) failed: %v failed", err) + } + + if counter1 := obj1.(*nftables.CounterObj); counter1.Packets > 0 { + t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0) + } + + obj2, err = c.GetObj(counter2) + + if err != nil { + t.Errorf("c.GetObj(counter2) failed: %v failed", err) + } + + if counter2 := obj2.(*nftables.CounterObj); counter2.Packets != 1 { + t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1) + } + +} + func TestConfigureClamping(t *testing.T) { // The want byte sequences come from stracing nft(8), e.g.: // strace -f -v -x -s 2048 -eraw=sendto nft add table ip nat diff --git a/obj.go b/obj.go index f3627df..252d97c 100644 --- a/obj.go +++ b/obj.go @@ -27,6 +27,7 @@ var objHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.N // Obj represents a netfilter stateful object. See also // https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects type Obj interface { + table() *Table family() TableFamily unmarshal(*netlink.AttributeDecoder) error marshal(data bool) ([]byte, error) @@ -54,13 +55,23 @@ func (cc *Conn) AddObj(o Obj) Obj { } // GetObj gets the specified Obj without resetting it. -func (cc *Conn) GetObj(o Obj) ([]Obj, error) { - return cc.getObj(o, unix.NFT_MSG_GETOBJ) +func (cc *Conn) GetObj(o Obj) (Obj, error) { + objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ) + return objs[0], err +} + +func (cc *Conn) GetObjs(t *Table) ([]Obj, error) { + return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ) } // GetObjReset gets the specified Obj and resets it. -func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) { - return cc.getObj(o, unix.NFT_MSG_GETOBJ_RESET) +func (cc *Conn) ResetObj(o Obj) (Obj, error) { + objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ_RESET) + return objs[0], err +} + +func (cc *Conn) ResetObjs(t *Table) ([]Obj, error) { + return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ_RESET) } func objFromMsg(msg netlink.Message) (Obj, error) { @@ -112,24 +123,35 @@ func objFromMsg(msg netlink.Message) (Obj, error) { return nil, fmt.Errorf("malformed stateful object") } -func (cc *Conn) getObj(o Obj, msgType uint16) ([]Obj, error) { +func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { conn, err := cc.dialNetlink() if err != nil { return nil, err } defer conn.Close() - data, err := o.marshal(false) - if err != nil { - return nil, err + var data []byte + var message netlink.Message + var flags netlink.HeaderFlags + + if o != nil { + data, err = o.marshal(false) + if err != nil { + return nil, err + } + } else { + flags = netlink.Dump + data, err = netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, + }) } - message := netlink.Message{ + message = netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | msgType), - Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, + Flags: netlink.Request | netlink.Acknowledge | flags, }, - Data: append(extraHeader(uint8(o.family()), 0), data...), + Data: append(extraHeader(uint8(t.Family), 0), data...), } if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {