diff --git a/nftables_test.go b/nftables_test.go index 574a230..9c079cb 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1252,7 +1252,7 @@ func TestGetObjReset(t *testing.T) { } filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4} - obj, err := c.ResetObj(&nftables.CounterObj{ + obj, err := c.ResetObject(&nftables.CounterObj{ Table: filter, Name: "fwded", }) @@ -1347,30 +1347,30 @@ func TestObjAPI(t *testing.T) { t.Fatalf(err.Error()) } - objs, err := c.GetObjs(table) + objs, err := c.GetObjects(table) if err != nil { - t.Errorf("c.GetObjs(table) failed: %v failed", err) + t.Errorf("c.GetObjects(table) failed: %v failed", err) } if got := len(objs); got != 2 { t.Fatalf("unexpected number of objects: got %d, want %d", got, 2) } - objsOther, err := c.GetObjs(tableOther) + objsOther, err := c.GetObjects(tableOther) if err != nil { - t.Errorf("c.GetObjs(tableOther) failed: %v failed", err) + t.Errorf("c.GetObjects(tableOther) failed: %v failed", err) } if got := len(objsOther); got != 1 { t.Fatalf("unexpected number of objects: got %d, want %d", got, 1) } - obj1, err := c.GetObj(counter1) + obj1, err := c.GetObject(counter1) if err != nil { - t.Errorf("c.GetObj(counter1) failed: %v failed", err) + t.Errorf("c.GetObject(counter1) failed: %v failed", err) } rcounter1, ok := obj1.(*nftables.CounterObj) @@ -1383,10 +1383,10 @@ func TestObjAPI(t *testing.T) { t.Fatalf("unexpected counter name: got %s, want %s", rcounter1.Name, "fwded1") } - obj2, err := c.GetObj(counter2) + obj2, err := c.GetObject(counter2) if err != nil { - t.Errorf("c.GetObj(counter2) failed: %v failed", err) + t.Errorf("c.GetObject(counter2) failed: %v failed", err) } rcounter2, ok := obj2.(*nftables.CounterObj) @@ -1399,32 +1399,52 @@ func TestObjAPI(t *testing.T) { t.Fatalf("unexpected counter name: got %s, want %s", rcounter2.Name, "fwded2") } - _, err = c.ResetObj(counter1) + _, err = c.ResetObject(counter1) if err != nil { - t.Errorf("c.ResetObjs(table) failed: %v failed", err) + t.Errorf("c.ResetObjects(table) failed: %v failed", err) } - obj1, err = c.GetObj(counter1) + obj1, err = c.GetObject(counter1) if err != nil { - t.Errorf("c.GetObj(counter1) failed: %v failed", err) + t.Errorf("c.GetObject(counter1) failed: %v failed", err) } if counter1 := obj1.(*nftables.CounterObj); counter1.Packets > 0 { t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0) } - obj2, err = c.GetObj(counter2) + obj2, err = c.GetObject(counter2) if err != nil { - t.Errorf("c.GetObj(counter2) failed: %v failed", err) + t.Errorf("c.GetObject(counter2) failed: %v failed", err) } if counter2 := obj2.(*nftables.CounterObj); counter2.Packets != 1 { t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1) } + legacy, err := c.GetObj(counter1) + + if err != nil { + t.Errorf("c.GetObj(counter1) failed: %v failed", err) + } + + if len(legacy) != 2 { + t.Errorf("unexpected number of objects: got %d, want %d", len(legacy), 2) + } + + legacyReset, err := c.GetObjReset(counter1) + + if err != nil { + t.Errorf("c.GetObjReset(counter1) failed: %v failed", err) + } + + if len(legacyReset) != 2 { + t.Errorf("unexpected number of objects: got %d, want %d", len(legacyReset), 2) + } + } func TestConfigureClamping(t *testing.T) { diff --git a/obj.go b/obj.go index 252d97c..4ee09ba 100644 --- a/obj.go +++ b/obj.go @@ -54,23 +54,38 @@ func (cc *Conn) AddObj(o Obj) Obj { return o } -// GetObj gets the specified Obj without resetting it. -func (cc *Conn) GetObj(o Obj) (Obj, error) { +// GetObj is a legacy method that return all Obj that belongs +// to the same table as the given one +func (cc *Conn) GetObj(o Obj) ([]Obj, error) { + return cc.getObj(nil, o.table(), unix.NFT_MSG_GETOBJ) +} + +// GetObjReset is a legacy method that reset all Obj that belongs +// the same table as the given one +func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) { + return cc.getObj(nil, o.table(), unix.NFT_MSG_GETOBJ_RESET) +} + + +// GetObject gets the specified Object +func (cc *Conn) GetObject(o Obj) (Obj, error) { objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ) return objs[0], err } -func (cc *Conn) GetObjs(t *Table) ([]Obj, error) { +// GetObjects get all the Obj that belongs to the given table +func (cc *Conn) GetObjects(t *Table) ([]Obj, error) { return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ) } -// GetObjReset gets the specified Obj and resets it. -func (cc *Conn) ResetObj(o Obj) (Obj, error) { +// ResetObject reset the given Obj +func (cc *Conn) ResetObject(o Obj) (Obj, error) { objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ_RESET) return objs[0], err } -func (cc *Conn) ResetObjs(t *Table) ([]Obj, error) { +// ResetObjects reset all the Obj that belongs to the given table +func (cc *Conn) ResetObjects(t *Table) ([]Obj, error) { return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ_RESET) } @@ -131,7 +146,6 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { defer conn.Close() var data []byte - var message netlink.Message var flags netlink.HeaderFlags if o != nil { @@ -146,7 +160,7 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { }) } - message = netlink.Message{ + message := netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | msgType), Flags: netlink.Request | netlink.Acknowledge | flags,