Compare commits

..

2 Commits

Author SHA1 Message Date
turekt 72b6fe192a
Merge b5406ff95a into 912dee68b1 2024-07-24 16:05:45 +00:00
turekt b5406ff95a Objects implementation refactor
Refactored obj.go to a more generic approach
Added object support for already implemented expressions
Added test for limit object
Fixes https://github.com/google/nftables/issues/253
2024-07-24 18:04:09 +02:00
2 changed files with 50 additions and 42 deletions

View File

@ -2105,7 +2105,7 @@ func TestGetObjReset(t *testing.T) {
} }
filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4} filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}
obj, err := c.ResetObject(&nftables.NamedObj{ obj, err := c.ResetObject(&nftables.ObjAttr{
Table: filter, Table: filter,
Name: "fwded", Name: "fwded",
Type: nftables.ObjTypeCounter, Type: nftables.ObjTypeCounter,
@ -2114,7 +2114,7 @@ func TestGetObjReset(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
co, ok := obj.(*nftables.NamedObj) co, ok := obj.(*nftables.ObjAttr)
if !ok { if !ok {
t.Fatalf("unexpected type: got %T, want *nftables.ObjAttr", obj) t.Fatalf("unexpected type: got %T, want *nftables.ObjAttr", obj)
} }
@ -2168,7 +2168,7 @@ func TestObjAPI(t *testing.T) {
Priority: nftables.ChainPriorityFilter, Priority: nftables.ChainPriorityFilter,
}) })
counter1 := c.AddObj(&nftables.NamedObj{ counter1 := c.AddObj(&nftables.ObjAttr{
Table: table, Table: table,
Name: "fwded1", Name: "fwded1",
Type: nftables.ObjTypeCounter, Type: nftables.ObjTypeCounter,
@ -2178,7 +2178,7 @@ func TestObjAPI(t *testing.T) {
}, },
}) })
counter2 := c.AddObj(&nftables.NamedObj{ counter2 := c.AddObj(&nftables.ObjAttr{
Table: table, Table: table,
Name: "fwded2", Name: "fwded2",
Type: nftables.ObjTypeCounter, Type: nftables.ObjTypeCounter,
@ -2188,7 +2188,7 @@ func TestObjAPI(t *testing.T) {
}, },
}) })
c.AddObj(&nftables.NamedObj{ c.AddObj(&nftables.ObjAttr{
Table: tableOther, Table: tableOther,
Name: "fwdedOther", Name: "fwdedOther",
Type: nftables.ObjTypeCounter, Type: nftables.ObjTypeCounter,
@ -2236,7 +2236,7 @@ func TestObjAPI(t *testing.T) {
t.Errorf("c.GetObject(counter1) failed: %v failed", err) t.Errorf("c.GetObject(counter1) failed: %v failed", err)
} }
rcounter1, ok := obj1.(*nftables.NamedObj) rcounter1, ok := obj1.(*nftables.ObjAttr)
if !ok { if !ok {
t.Fatalf("unexpected type: got %T, want *nftables.ObjAttr", obj1) t.Fatalf("unexpected type: got %T, want *nftables.ObjAttr", obj1)
} }
@ -2250,7 +2250,7 @@ func TestObjAPI(t *testing.T) {
t.Errorf("c.GetObject(counter2) failed: %v failed", err) t.Errorf("c.GetObject(counter2) failed: %v failed", err)
} }
rcounter2, ok := obj2.(*nftables.NamedObj) rcounter2, ok := obj2.(*nftables.ObjAttr)
if !ok { if !ok {
t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj2) t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj2)
} }
@ -2271,7 +2271,7 @@ func TestObjAPI(t *testing.T) {
t.Errorf("c.GetObject(counter1) failed: %v failed", err) t.Errorf("c.GetObject(counter1) failed: %v failed", err)
} }
if counter1 := obj1.(*nftables.NamedObj).Obj.(*expr.Counter); counter1.Packets > 0 { if counter1 := obj1.(*nftables.ObjAttr).Obj.(*expr.Counter); counter1.Packets > 0 {
t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0) t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0)
} }
@ -2281,7 +2281,7 @@ func TestObjAPI(t *testing.T) {
t.Errorf("c.GetObject(counter2) failed: %v failed", err) t.Errorf("c.GetObject(counter2) failed: %v failed", err)
} }
if counter2 := obj2.(*nftables.NamedObj).Obj.(*expr.Counter); counter2.Packets != 1 { if counter2 := obj2.(*nftables.ObjAttr).Obj.(*expr.Counter); counter2.Packets != 1 {
t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1) t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1)
} }
@ -6674,7 +6674,7 @@ func TestAddLimitObj(t *testing.T) {
Burst: 5, Burst: 5,
Over: false, Over: false,
} }
o := &nftables.NamedObj{ o := &nftables.ObjAttr{
Table: tr, Table: tr,
Name: "limit_test", Name: "limit_test",
Type: nftables.ObjTypeLimit, Type: nftables.ObjTypeLimit,
@ -6686,7 +6686,7 @@ func TestAddLimitObj(t *testing.T) {
t.Errorf("conn.Flush() failed: %v", err) t.Errorf("conn.Flush() failed: %v", err)
} }
obj, err := conn.GetObj(&nftables.NamedObj{ obj, err := conn.GetObj(&nftables.ObjAttr{
Table: table, Table: table,
Name: "limit_test", Name: "limit_test",
Type: nftables.ObjTypeLimit, Type: nftables.ObjTypeLimit,
@ -6699,7 +6699,7 @@ func TestAddLimitObj(t *testing.T) {
t.Fatalf("unexpected object list length: got %d, want %d", got, want) t.Fatalf("unexpected object list length: got %d, want %d", got, want)
} }
o1, ok := obj[0].(*nftables.NamedObj) o1, ok := obj[0].(*nftables.ObjAttr)
if !ok { if !ok {
t.Fatalf("unexpected type: got %T, want *ObjAttr", obj[0]) t.Fatalf("unexpected type: got %T, want *ObjAttr", obj[0])
} }
@ -6745,7 +6745,7 @@ func TestAddQuotaObj(t *testing.T) {
} }
conn.AddChain(c) conn.AddChain(c)
o := &nftables.NamedObj{ o := &nftables.ObjAttr{
Table: tr, Table: tr,
Name: "q_test", Name: "q_test",
Type: nftables.ObjTypeQuota, Type: nftables.ObjTypeQuota,
@ -6761,7 +6761,7 @@ func TestAddQuotaObj(t *testing.T) {
t.Fatalf("conn.Flush() failed: %v", err) t.Fatalf("conn.Flush() failed: %v", err)
} }
obj, err := conn.GetObj(&nftables.NamedObj{ obj, err := conn.GetObj(&nftables.ObjAttr{
Table: table, Table: table,
Name: "q_test", Name: "q_test",
Type: nftables.ObjTypeQuota, Type: nftables.ObjTypeQuota,
@ -6774,7 +6774,7 @@ func TestAddQuotaObj(t *testing.T) {
t.Fatalf("unexpected object list length: got %d, want %d", got, want) t.Fatalf("unexpected object list length: got %d, want %d", got, want)
} }
o1, ok := obj[0].(*nftables.NamedObj) o1, ok := obj[0].(*nftables.ObjAttr)
if !ok { if !ok {
t.Fatalf("unexpected type: got %T, want *ObjAttr", obj[0]) t.Fatalf("unexpected type: got %T, want *ObjAttr", obj[0])
} }
@ -6878,7 +6878,7 @@ func TestDeleteQuotaObjMixedTypes(t *testing.T) {
} }
conn.AddChain(c) conn.AddChain(c)
o := &nftables.NamedObj{ o := &nftables.ObjAttr{
Table: tr, Table: tr,
Name: "q_test", Name: "q_test",
Type: nftables.ObjTypeQuota, Type: nftables.ObjTypeQuota,
@ -6894,7 +6894,7 @@ func TestDeleteQuotaObjMixedTypes(t *testing.T) {
t.Fatalf("conn.Flush() failed: %v", err) t.Fatalf("conn.Flush() failed: %v", err)
} }
obj, err := conn.GetObj(&nftables.NamedObj{ obj, err := conn.GetObj(&nftables.ObjAttr{
Table: tr, Table: tr,
Name: "q_test", Name: "q_test",
Type: nftables.ObjTypeQuota, Type: nftables.ObjTypeQuota,
@ -6908,7 +6908,7 @@ func TestDeleteQuotaObjMixedTypes(t *testing.T) {
} }
o2, _ := o.Obj.(*expr.Quota) o2, _ := o.Obj.(*expr.Quota)
want := &nftables.NamedObj{ want := &nftables.ObjAttr{
Table: tr, Table: tr,
Name: "q_test", Name: "q_test",
Type: nftables.ObjTypeQuota, Type: nftables.ObjTypeQuota,

56
obj.go
View File

@ -69,33 +69,33 @@ type Obj interface {
objType() ObjType objType() ObjType
} }
// NamedObj represents nftables stateful object attributes // ObjAttr represents nftables stateful object attributes
// Corresponds to netfilter nft_object_attributes as per // Corresponds to netfilter nft_object_attributes as per
// https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=116e95aa7b6358c917de8c69f6f173874030b46b#n1626 // https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=116e95aa7b6358c917de8c69f6f173874030b46b#n1626
type NamedObj struct { type ObjAttr struct {
Table *Table Table *Table
Name string Name string
Type ObjType Type ObjType
Obj expr.Any Obj expr.Any
} }
func (o *NamedObj) table() *Table { func (o *ObjAttr) table() *Table {
return o.Table return o.Table
} }
func (o *NamedObj) family() TableFamily { func (o *ObjAttr) family() TableFamily {
return o.Table.Family return o.Table.Family
} }
func (o *NamedObj) data() expr.Any { func (o *ObjAttr) data() expr.Any {
return o.Obj return o.Obj
} }
func (o *NamedObj) name() string { func (o *ObjAttr) name() string {
return o.Name return o.Name
} }
func (o *NamedObj) objType() ObjType { func (o *ObjAttr) objType() ObjType {
return o.Type return o.Type
} }
@ -157,26 +157,32 @@ func (cc *Conn) DeleteObject(o Obj) {
// GetObj is a legacy method that return all Obj that belongs // GetObj is a legacy method that return all Obj that belongs
// to the same table as the given one // to the same table as the given one
// This function returns the same concrete type as passed, // This function will determine whether returned object will be
// e.g. QuotaObj, CounterObj or NamedObj. Prefer using the more // one of legacy types QuotaObj/CounterObj or the new ObjAttr
// generic NamedObj over the legacy QuotaObj and CounterObj types. // type struct based passed o parameter type
// If o is of type ObjAttr, the implementation will work with
// the new ObjAttr type, otherwise falls back to legacy QuotaObj/CounterObj
func (cc *Conn) GetObj(o Obj) ([]Obj, error) { func (cc *Conn) GetObj(o Obj) ([]Obj, error) {
return cc.getObjWithLegacyType(nil, o.table(), unix.NFT_MSG_GETOBJ, cc.useLegacyObjType(o)) return cc.getObjWithLegacyType(nil, o.table(), unix.NFT_MSG_GETOBJ, cc.useLegacyObjType(o))
} }
// GetObjReset is a legacy method that reset all Obj that belongs // GetObjReset is a legacy method that reset all Obj that belongs
// the same table as the given one // the same table as the given one
// This function returns the same concrete type as passed, // This function will determine whether returned object will be
// e.g. QuotaObj, CounterObj or NamedObj. Prefer using the more // one of legacy types QuotaObj/CounterObj or the new ObjAttr
// generic NamedObj over the legacy QuotaObj and CounterObj types. // type struct based passed o parameter type
// If o is of type ObjAttr, the implementation will work with
// the new ObjAttr type, otherwise falls back to legacy QuotaObj/CounterObj
func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) { func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) {
return cc.getObjWithLegacyType(nil, o.table(), unix.NFT_MSG_GETOBJ_RESET, cc.useLegacyObjType(o)) return cc.getObjWithLegacyType(nil, o.table(), unix.NFT_MSG_GETOBJ_RESET, cc.useLegacyObjType(o))
} }
// GetObject gets the specified Object // GetObject gets the specified Object
// This function returns the same concrete type as passed, // This function will determine whether returned object will be
// e.g. QuotaObj, CounterObj or NamedObj. Prefer using the more // one of legacy types QuotaObj/CounterObj or the new ObjAttr
// generic NamedObj over the legacy QuotaObj and CounterObj types. // type struct based passed o parameter type
// If o is of type ObjAttr, the implementation will work with
// the new ObjAttr type, otherwise falls back to legacy QuotaObj/CounterObj
func (cc *Conn) GetObject(o Obj) (Obj, error) { func (cc *Conn) GetObject(o Obj) (Obj, error) {
objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ) objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ)
@ -195,9 +201,11 @@ func (cc *Conn) GetObjects(t *Table) ([]Obj, error) {
} }
// ResetObject reset the given Obj // ResetObject reset the given Obj
// This function returns the same concrete type as passed, // This function will determine whether returned object will be
// e.g. QuotaObj, CounterObj or NamedObj. Prefer using the more // one of legacy types QuotaObj/CounterObj or the new ObjAttr
// generic NamedObj over the legacy QuotaObj and CounterObj types. // type struct based passed o parameter type
// If o is of type ObjAttr, the implementation will work with
// the new ObjAttr type, otherwise falls back to legacy QuotaObj/CounterObj
func (cc *Conn) ResetObject(o Obj) (Obj, error) { func (cc *Conn) ResetObject(o Obj) (Obj, error) {
objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ_RESET) objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ_RESET)
@ -242,7 +250,7 @@ func objFromMsg(msg netlink.Message, returnLegacyType bool) (Obj, error) {
return objDataFromMsgLegacy(ad, table, name, objectType) return objDataFromMsgLegacy(ad, table, name, objectType)
} }
o := NamedObj{ o := ObjAttr{
Table: table, Table: table,
Name: name, Name: name,
Type: ObjType(objectType), Type: ObjType(objectType),
@ -252,13 +260,13 @@ func objFromMsg(msg netlink.Message, returnLegacyType bool) (Obj, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(objs) == 0 {
return nil, fmt.Errorf("objFromMsg: objs is empty for obj %v", o)
}
exprs := make([]expr.Any, len(objs)) exprs := make([]expr.Any, len(objs))
for i := range exprs { for i := range exprs {
exprs[i] = objs[i].(expr.Any) exprs[i] = objs[i].(expr.Any)
} }
if len(exprs) == 0 {
return nil, fmt.Errorf("objFromMsg: exprs is empty for obj %v", o)
}
o.Obj = exprs[0] o.Obj = exprs[0]
return &o, ad.Err() return &o, ad.Err()
@ -372,7 +380,7 @@ func (cc *Conn) useLegacyObjType(o Obj) bool {
useLegacyType := true useLegacyType := true
if o != nil { if o != nil {
switch o.(type) { switch o.(type) {
case *NamedObj: case *ObjAttr:
useLegacyType = false useLegacyType = false
} }
} }