Compare commits

..

3 Commits

Author SHA1 Message Date
turekt 533a9343c8 Objects implementation refactor
Refactored obj.go to a more generic approach
Added object support for already implemented expressions
Added test for limit object
Fixes https://github.com/google/nftables/issues/253
2024-06-25 21:07:40 +00:00
TheDiveO aa8348f790
feat: add xt.Comment (#260)
Signed-off-by: thediveo <thediveo@gmx.eu>
2024-04-22 08:53:34 +02:00
dependabot[bot] 20edd38e22
Bump golang.org/x/net from 0.22.0 to 0.23.0 (#261)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.22.0 to 0.23.0.
- [Commits](https://github.com/golang/net/compare/v0.22.0...v0.23.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-19 19:12:50 +02:00
10 changed files with 577 additions and 57 deletions

View File

@ -21,7 +21,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// Deprecated: Use ObjAttr instead
type CounterObj struct { type CounterObj struct {
Table *Table Table *Table
Name string // e.g. “fwded” Name string // e.g. “fwded”

2
go.mod
View File

@ -12,6 +12,6 @@ require (
github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect
github.com/mdlayher/socket v0.5.0 // indirect github.com/mdlayher/socket v0.5.0 // indirect
golang.org/x/net v0.22.0 // indirect golang.org/x/net v0.23.0 // indirect
golang.org/x/sync v0.6.0 // indirect golang.org/x/sync v0.6.0 // indirect
) )

4
go.sum
View File

@ -8,8 +8,8 @@ github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI= github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4= github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4=
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=

View File

@ -259,7 +259,7 @@ func (monitor *Monitor) monitor() {
} }
monitor.eventCh <- event monitor.eventCh <- event
case unix.NFT_MSG_NEWOBJ, unix.NFT_MSG_DELOBJ: case unix.NFT_MSG_NEWOBJ, unix.NFT_MSG_DELOBJ:
obj, err := objFromMsg(msg) obj, err := objFromMsg(msg, true)
event := &MonitorEvent{ event := &MonitorEvent{
Type: MonitorEventType(msgType), Type: MonitorEventType(msgType),
Data: obj, Data: obj,

View File

@ -2105,9 +2105,10 @@ func TestGetObjReset(t *testing.T) {
} }
filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4} filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}
obj, err := c.ResetObject(&nftables.CounterObj{ obj, err := c.ResetObject(&nftables.ObjAttr{
Table: filter, Table: filter,
Name: "fwded", Name: "fwded",
Type: nftables.ObjTypeCounter,
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -2167,6 +2168,366 @@ func TestObjAPI(t *testing.T) {
Priority: nftables.ChainPriorityFilter, Priority: nftables.ChainPriorityFilter,
}) })
counter1 := c.AddObj(&nftables.ObjAttr{
Table: table,
Name: "fwded1",
Type: nftables.ObjTypeCounter,
Obj: &expr.Counter{
Bytes: 1,
Packets: 1,
},
})
counter2 := c.AddObj(&nftables.ObjAttr{
Table: table,
Name: "fwded2",
Type: nftables.ObjTypeCounter,
Obj: &expr.Counter{
Bytes: 1,
Packets: 1,
},
})
c.AddObj(&nftables.ObjAttr{
Table: tableOther,
Name: "fwdedOther",
Type: nftables.ObjTypeCounter,
Obj: &expr.Counter{
Bytes: 0,
Packets: 0,
},
})
c.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Objref{
Type: 1,
Name: "fwded1",
},
},
})
if err := c.Flush(); err != nil {
t.Fatalf(err.Error())
}
objs, err := c.GetObjects(table)
if err != nil {
t.Errorf("c.GetObjects(table) failed: %v failed", err)
}
if got := len(objs); got != 2 {
t.Fatalf("unexpected number of objects: got %d, want %d", got, 2)
}
objsOther, err := c.GetObjects(tableOther)
if err != nil {
t.Errorf("c.GetObjects(tableOther) failed: %v failed", err)
}
if got := len(objsOther); got != 1 {
t.Fatalf("unexpected number of objects: got %d, want %d", got, 1)
}
obj1, err := c.GetObject(counter1)
if err != nil {
t.Errorf("c.GetObject(counter1) failed: %v failed", err)
}
rcounter1, ok := obj1.(*nftables.ObjAttr)
if !ok {
t.Fatalf("unexpected type: got %T, want *nftables.ObjAttr", obj1)
}
if rcounter1.Name != "fwded1" {
t.Fatalf("unexpected counter name: got %s, want %s", rcounter1.Name, "fwded1")
}
obj2, err := c.GetObject(counter2)
if err != nil {
t.Errorf("c.GetObject(counter2) failed: %v failed", err)
}
rcounter2, ok := obj2.(*nftables.ObjAttr)
if !ok {
t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj2)
}
if rcounter2.Name != "fwded2" {
t.Fatalf("unexpected counter name: got %s, want %s", rcounter2.Name, "fwded2")
}
_, err = c.ResetObject(counter1)
if err != nil {
t.Errorf("c.ResetObjects(table) failed: %v failed", err)
}
obj1, err = c.GetObject(counter1)
if err != nil {
t.Errorf("c.GetObject(counter1) failed: %v failed", err)
}
if counter1 := obj1.(*nftables.ObjAttr).Obj.(*expr.Counter); counter1.Packets > 0 {
t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0)
}
obj2, err = c.GetObject(counter2)
if err != nil {
t.Errorf("c.GetObject(counter2) failed: %v failed", err)
}
if counter2 := obj2.(*nftables.ObjAttr).Obj.(*expr.Counter); counter2.Packets != 1 {
t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1)
}
legacy, err := c.GetObj(counter1)
if err != nil {
t.Errorf("c.GetObj(counter1) failed: %v failed", err)
}
if len(legacy) != 2 {
t.Errorf("unexpected number of objects: got %d, want %d", len(legacy), 2)
}
legacyReset, err := c.GetObjReset(counter1)
if err != nil {
t.Errorf("c.GetObjReset(counter1) failed: %v failed", err)
}
if len(legacyReset) != 2 {
t.Errorf("unexpected number of objects: got %d, want %d", len(legacyReset), 2)
}
}
func TestDeleteLegacyQuotaObj(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
conn.FlushRuleset()
defer conn.FlushRuleset()
table := &nftables.Table{
Name: "quota_demo",
Family: nftables.TableFamilyIPv4,
}
tr := conn.AddTable(table)
c := &nftables.Chain{
Name: "filter",
Table: table,
}
conn.AddChain(c)
o := &nftables.QuotaObj{
Table: tr,
Name: "q_test",
Bytes: 0x06400000,
Consumed: 0,
Over: true,
}
conn.AddObj(o)
if err := conn.Flush(); err != nil {
t.Fatalf("conn.Flush() failed: %v", err)
}
obj, err := conn.GetObj(&nftables.QuotaObj{
Table: table,
Name: "q_test",
})
if err != nil {
t.Fatalf("conn.GetObj() failed: %v", err)
}
if got, want := len(obj), 1; got != want {
t.Fatalf("unexpected number of objects: got %d, want %d", got, want)
}
if got, want := obj[0], o; !reflect.DeepEqual(got, want) {
t.Errorf("got = %+v, want = %+v", got, want)
}
conn.DeleteObject(&nftables.QuotaObj{
Table: tr,
Name: "q_test",
})
if err := conn.Flush(); err != nil {
t.Fatalf("conn.Flush() failed: %v", err)
}
obj, err = conn.GetObj(&nftables.QuotaObj{
Table: table,
Name: "q_test",
})
if err != nil {
t.Fatalf("conn.GetObj() failed: %v", err)
}
if got, want := len(obj), 0; got != want {
t.Fatalf("unexpected object list length: got %d, want %d", got, want)
}
}
func TestAddLegacyQuotaObj(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
conn.FlushRuleset()
defer conn.FlushRuleset()
table := &nftables.Table{
Name: "quota_demo",
Family: nftables.TableFamilyIPv4,
}
tr := conn.AddTable(table)
c := &nftables.Chain{
Name: "filter",
Table: table,
}
conn.AddChain(c)
o := &nftables.QuotaObj{
Table: tr,
Name: "q_test",
Bytes: 0x06400000,
Consumed: 0,
Over: true,
}
conn.AddObj(o)
if err := conn.Flush(); err != nil {
t.Errorf("conn.Flush() failed: %v", err)
}
obj, err := conn.GetObj(&nftables.QuotaObj{
Table: table,
Name: "q_test",
})
if err != nil {
t.Fatalf("conn.GetObj() failed: %v", err)
}
if got, want := len(obj), 1; got != want {
t.Fatalf("unexpected object list length: got %d, want %d", got, want)
}
o1, ok := obj[0].(*nftables.QuotaObj)
if !ok {
t.Fatalf("unexpected type: got %T, want *QuotaObj", obj[0])
}
if got, want := o1.Name, o.Name; got != want {
t.Fatalf("quota name mismatch: got %s, want %s", got, want)
}
if got, want := o1.Bytes, o.Bytes; got != want {
t.Fatalf("quota bytes mismatch: got %d, want %d", got, want)
}
if got, want := o1.Consumed, o.Consumed; got != want {
t.Fatalf("quota consumed mismatch: got %d, want %d", got, want)
}
if got, want := o1.Over, o.Over; got != want {
t.Fatalf("quota over mismatch: got %v, want %v", got, want)
}
}
func TestAddLegacyQuotaObjRef(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
conn.FlushRuleset()
defer conn.FlushRuleset()
table := &nftables.Table{
Name: "quota_demo",
Family: nftables.TableFamilyIPv4,
}
tr := conn.AddTable(table)
c := &nftables.Chain{
Name: "filter",
Table: table,
}
conn.AddChain(c)
o := &nftables.QuotaObj{
Table: tr,
Name: "q_test",
Bytes: 0x06400000,
Consumed: 0,
Over: true,
}
conn.AddObj(o)
r := &nftables.Rule{
Table: table,
Chain: c,
Exprs: []expr.Any{
&expr.Objref{
Type: 2,
Name: "q_test",
},
},
}
conn.AddRule(r)
if err := conn.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
rules, err := conn.GetRules(table, c)
if err != nil {
t.Fatalf("failed to get rules: %v", err)
}
if got, want := len(rules), 1; got != want {
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
}
if got, want := len(rules[0].Exprs), 1; got != want {
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
}
objref, ok := rules[0].Exprs[0].(*expr.Objref)
if !ok {
t.Fatalf("Exprs[0] is type %T, want *expr.Objref", rules[0].Exprs[0])
}
if want := r.Exprs[0]; !reflect.DeepEqual(objref, want) {
t.Errorf("objref expr = %+v, wanted %+v", objref, want)
}
}
func TestObjAPICounterLegacyType(t *testing.T) {
if os.Getenv("TRAVIS") == "true" {
t.SkipNow()
}
// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
// Clear all rules at the beginning + end of the test.
c.FlushRuleset()
defer c.FlushRuleset()
table := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
})
tableOther := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "foo",
})
chain := c.AddChain(&nftables.Chain{
Name: "chain",
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityFilter,
})
counter1 := c.AddObj(&nftables.CounterObj{ counter1 := c.AddObj(&nftables.CounterObj{
Table: table, Table: table,
Name: "fwded1", Name: "fwded1",
@ -2226,9 +2587,10 @@ func TestObjAPI(t *testing.T) {
t.Errorf("c.GetObject(counter1) failed: %v failed", err) t.Errorf("c.GetObject(counter1) failed: %v failed", err)
} }
rcounter1, ok := obj1.(*nftables.ObjAttr) rcounter1, ok := obj1.(*nftables.CounterObj)
if !ok { if !ok {
t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj1) t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", rcounter1)
} }
if rcounter1.Name != "fwded1" { if rcounter1.Name != "fwded1" {
@ -2240,9 +2602,10 @@ func TestObjAPI(t *testing.T) {
t.Errorf("c.GetObject(counter2) failed: %v failed", err) t.Errorf("c.GetObject(counter2) failed: %v failed", err)
} }
rcounter2, ok := obj2.(*nftables.ObjAttr) rcounter2, ok := obj2.(*nftables.CounterObj)
if !ok { if !ok {
t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj2) t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", rcounter2)
} }
if rcounter2.Name != "fwded2" { if rcounter2.Name != "fwded2" {
@ -2261,7 +2624,7 @@ func TestObjAPI(t *testing.T) {
t.Errorf("c.GetObject(counter1) failed: %v failed", err) t.Errorf("c.GetObject(counter1) failed: %v failed", err)
} }
if counter1 := obj1.(*nftables.ObjAttr).Obj.(*expr.Counter); counter1.Packets > 0 { if counter1 := obj1.(*nftables.CounterObj); counter1.Packets > 0 {
t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0) t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0)
} }
@ -2271,7 +2634,7 @@ func TestObjAPI(t *testing.T) {
t.Errorf("c.GetObject(counter2) failed: %v failed", err) t.Errorf("c.GetObject(counter2) failed: %v failed", err)
} }
if counter2 := obj2.(*nftables.ObjAttr).Obj.(*expr.Counter); counter2.Packets != 1 { if counter2 := obj2.(*nftables.CounterObj); counter2.Packets != 1 {
t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1) t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1)
} }
@ -6382,22 +6745,26 @@ func TestAddQuotaObj(t *testing.T) {
} }
conn.AddChain(c) conn.AddChain(c)
o := &nftables.QuotaObj{ o := &nftables.ObjAttr{
Table: tr, Table: tr,
Name: "q_test", Name: "q_test",
Type: nftables.ObjTypeQuota,
Obj: &expr.Quota{
Bytes: 0x06400000, Bytes: 0x06400000,
Consumed: 0, Consumed: 0,
Over: true, Over: true,
},
} }
conn.AddObj(o) conn.AddObj(o)
if err := conn.Flush(); err != nil { if err := conn.Flush(); err != nil {
t.Errorf("conn.Flush() failed: %v", err) t.Fatalf("conn.Flush() failed: %v", err)
} }
obj, err := conn.GetObj(&nftables.QuotaObj{ obj, err := conn.GetObj(&nftables.ObjAttr{
Table: table, Table: table,
Name: "q_test", Name: "q_test",
Type: nftables.ObjTypeQuota,
}) })
if err != nil { if err != nil {
t.Fatalf("conn.GetObj() failed: %v", err) t.Fatalf("conn.GetObj() failed: %v", err)
@ -6418,13 +6785,14 @@ func TestAddQuotaObj(t *testing.T) {
if !ok { if !ok {
t.Fatalf("unexpected type: got %T, want *expr.Quota", o1.Obj) t.Fatalf("unexpected type: got %T, want *expr.Quota", o1.Obj)
} }
if got, want := q.Bytes, o.Bytes; got != want { o2, _ := o.Obj.(*expr.Quota)
if got, want := q.Bytes, o2.Bytes; got != want {
t.Fatalf("quota bytes mismatch: got %d, want %d", got, want) t.Fatalf("quota bytes mismatch: got %d, want %d", got, want)
} }
if got, want := q.Consumed, o.Consumed; got != want { if got, want := q.Consumed, o2.Consumed; got != want {
t.Fatalf("quota consumed mismatch: got %d, want %d", got, want) t.Fatalf("quota consumed mismatch: got %d, want %d", got, want)
} }
if got, want := q.Over, o.Over; got != want { if got, want := q.Over, o2.Over; got != want {
t.Fatalf("quota over mismatch: got %v, want %v", got, want) t.Fatalf("quota over mismatch: got %v, want %v", got, want)
} }
} }
@ -6492,7 +6860,7 @@ func TestAddQuotaObjRef(t *testing.T) {
} }
} }
func TestDeleteQuotaObj(t *testing.T) { func TestDeleteQuotaObjMixedTypes(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
conn.FlushRuleset() conn.FlushRuleset()
@ -6510,12 +6878,15 @@ func TestDeleteQuotaObj(t *testing.T) {
} }
conn.AddChain(c) conn.AddChain(c)
o := &nftables.QuotaObj{ o := &nftables.ObjAttr{
Table: tr, Table: tr,
Name: "q_test", Name: "q_test",
Type: nftables.ObjTypeQuota,
Obj: &expr.Quota{
Bytes: 0x06400000, Bytes: 0x06400000,
Consumed: 0, Consumed: 0,
Over: true, Over: true,
},
} }
conn.AddObj(o) conn.AddObj(o)
@ -6523,9 +6894,10 @@ func TestDeleteQuotaObj(t *testing.T) {
t.Fatalf("conn.Flush() failed: %v", err) t.Fatalf("conn.Flush() failed: %v", err)
} }
obj, err := conn.GetObj(&nftables.QuotaObj{ obj, err := conn.GetObj(&nftables.ObjAttr{
Table: table, Table: tr,
Name: "q_test", Name: "q_test",
Type: nftables.ObjTypeQuota,
}) })
if err != nil { if err != nil {
t.Fatalf("conn.GetObj() failed: %v", err) t.Fatalf("conn.GetObj() failed: %v", err)
@ -6535,14 +6907,15 @@ func TestDeleteQuotaObj(t *testing.T) {
t.Fatalf("unexpected number of objects: got %d, want %d", got, want) t.Fatalf("unexpected number of objects: got %d, want %d", got, want)
} }
o2, _ := o.Obj.(*expr.Quota)
want := &nftables.ObjAttr{ want := &nftables.ObjAttr{
Table: tr, Table: tr,
Name: "q_test", Name: "q_test",
Type: nftables.ObjTypeQuota, Type: nftables.ObjTypeQuota,
Obj: &expr.Quota{ Obj: &expr.Quota{
Bytes: o.Bytes, Bytes: o2.Bytes,
Consumed: o.Consumed, Consumed: o2.Consumed,
Over: o.Over, Over: o2.Over,
}, },
} }
if got, want := obj[0], want; !reflect.DeepEqual(got, want) { if got, want := obj[0], want; !reflect.DeepEqual(got, want) {

58
obj.go
View File

@ -155,13 +155,13 @@ func (cc *Conn) DeleteObject(o Obj) {
// GetObj is a legacy method that return all Obj that belongs // GetObj is a legacy method that return all Obj that belongs
// to the same table as the given one // to the same table as the given one
func (cc *Conn) GetObj(o Obj) ([]Obj, error) { func (cc *Conn) GetObj(o Obj) ([]Obj, error) {
return cc.getObj(nil, o.table(), unix.NFT_MSG_GETOBJ) return cc.getObjWithLegacyType(nil, o.table(), unix.NFT_MSG_GETOBJ, cc.useLegacyObjType(o))
} }
// GetObjReset is a legacy method that reset all Obj that belongs // GetObjReset is a legacy method that reset all Obj that belongs
// the same table as the given one // the same table as the given one
func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) { func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) {
return cc.getObj(nil, o.table(), unix.NFT_MSG_GETOBJ_RESET) return cc.getObjWithLegacyType(nil, o.table(), unix.NFT_MSG_GETOBJ_RESET, cc.useLegacyObjType(o))
} }
// GetObject gets the specified Object // GetObject gets the specified Object
@ -196,7 +196,7 @@ func (cc *Conn) ResetObjects(t *Table) ([]Obj, error) {
return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ_RESET) return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ_RESET)
} }
func objFromMsg(msg netlink.Message) (Obj, error) { func objFromMsg(msg netlink.Message, returnLegacyType bool) (Obj, error) {
if got, want1, want2 := msg.Header.Type, newObjHeaderType, delObjHeaderType; got != want1 && got != want2 { if got, want1, want2 := msg.Header.Type, newObjHeaderType, delObjHeaderType; got != want1 && got != want2 {
return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2)
} }
@ -219,6 +219,7 @@ func objFromMsg(msg netlink.Message) (Obj, error) {
case unix.NFTA_OBJ_TYPE: case unix.NFTA_OBJ_TYPE:
objectType = ad.Uint32() objectType = ad.Uint32()
case unix.NFTA_OBJ_DATA: case unix.NFTA_OBJ_DATA:
if !returnLegacyType {
o := ObjAttr{ o := ObjAttr{
Table: table, Table: table,
Name: name, Name: name,
@ -240,6 +241,40 @@ func objFromMsg(msg netlink.Message) (Obj, error) {
o.Obj = exprs[0] o.Obj = exprs[0]
return &o, ad.Err() return &o, ad.Err()
} }
switch objectType {
case unix.NFT_OBJECT_COUNTER:
o := CounterObj{
Table: table,
Name: name,
}
ad.Do(func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return err
}
ad.ByteOrder = binary.BigEndian
return o.unmarshal(ad)
})
return &o, ad.Err()
case unix.NFT_OBJECT_QUOTA:
o := QuotaObj{
Table: table,
Name: name,
}
ad.Do(func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return err
}
ad.ByteOrder = binary.BigEndian
return o.unmarshal(ad)
})
return &o, ad.Err()
}
}
} }
if err := ad.Err(); err != nil { if err := ad.Err(); err != nil {
return nil, err return nil, err
@ -248,6 +283,10 @@ func objFromMsg(msg netlink.Message) (Obj, error) {
} }
func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) {
return cc.getObjWithLegacyType(o, t, msgType, cc.useLegacyObjType(o))
}
func (cc *Conn) getObjWithLegacyType(o Obj, t *Table, msgType uint16, returnLegacyObjType bool) ([]Obj, error) {
conn, closer, err := cc.netlinkConn() conn, closer, err := cc.netlinkConn()
if err != nil { if err != nil {
return nil, err return nil, err
@ -292,7 +331,7 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) {
} }
var objs []Obj var objs []Obj
for _, msg := range reply { for _, msg := range reply {
o, err := objFromMsg(msg) o, err := objFromMsg(msg, returnLegacyObjType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -301,3 +340,14 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) {
return objs, nil return objs, nil
} }
func (cc *Conn) useLegacyObjType(o Obj) bool {
useLegacyType := true
if o != nil {
switch o.(type) {
case *ObjAttr:
useLegacyType = false
}
}
return useLegacyType
}

View File

@ -21,7 +21,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// Deprecated: Use ObjAttr instead
type QuotaObj struct { type QuotaObj struct {
Table *Table Table *Table
Name string Name string

34
xt/comment.go Normal file
View File

@ -0,0 +1,34 @@
package xt
import (
"bytes"
"fmt"
)
// CommentSize is the fixed size of a comment info xt blob, see:
// https://elixir.bootlin.com/linux/v6.8.7/source/include/uapi/linux/netfilter/xt_comment.h#L5
const CommentSize = 256
// Comment gets marshalled and unmarshalled as a fixed-sized char array, filled
// with zeros as necessary, see:
// https://elixir.bootlin.com/linux/v6.8.7/source/include/uapi/linux/netfilter/xt_comment.h#L7
type Comment string
func (c *Comment) marshal(fam TableFamily, rev uint32) ([]byte, error) {
if len(*c) >= CommentSize {
return nil, fmt.Errorf("comment must be less than %d bytes, got %d bytes",
CommentSize, len(*c))
}
data := make([]byte, CommentSize)
copy(data, []byte(*c))
return data, nil
}
func (c *Comment) unmarshal(fam TableFamily, rev uint32, data []byte) error {
if len(data) != CommentSize {
return fmt.Errorf("malformed comment: got %d bytes, expected exactly %d bytes",
len(data), CommentSize)
}
*c = Comment(bytes.TrimRight(data, "\x00"))
return nil
}

62
xt/comment_test.go Normal file
View File

@ -0,0 +1,62 @@
package xt
import (
"reflect"
"strings"
"testing"
)
func TestComment(t *testing.T) {
t.Parallel()
payload := Comment("The quick brown fox jumps over the lazy dog.")
oversized := Comment(strings.Repeat("foobar", 100))
tests := []struct {
name string
info InfoAny
errmsg string
}{
{
name: "un/marshal Comment round-trip",
info: &payload,
},
{
name: "marshal oversized Comment",
info: &oversized,
errmsg: "comment must be less than 256 bytes, got 600 bytes",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := tt.info.marshal(0, 0)
if err != nil {
if tt.errmsg != "" && err.Error() == tt.errmsg {
return
}
t.Fatalf("marshal error: %+v", err)
}
if len(data) != CommentSize {
t.Fatalf("marshal error: invalid size %d", len(data))
}
if data[len(data)-1] != 0 {
t.Fatalf("marshal error: invalid termination")
}
var comment Comment
var recoveredInfo InfoAny = &comment
err = recoveredInfo.unmarshal(0, 0, data)
if err != nil {
t.Fatalf("unmarshal error: %+v", err)
}
if !reflect.DeepEqual(tt.info, recoveredInfo) {
t.Fatalf("original %+v and recovered %+v are different", tt.info, recoveredInfo)
}
})
}
oversizeddata := []byte(oversized)
var comment Comment
if err := (&comment).unmarshal(0, 0, oversizeddata); err == nil {
t.Fatalf("unmarshal: expected error, but got nil")
}
}

View File

@ -34,6 +34,9 @@ func Unmarshal(name string, fam TableFamily, rev uint32, data []byte) (InfoAny,
case 1: case 1:
i = &AddrTypeV1{} i = &AddrTypeV1{}
} }
case "comment":
var c Comment
i = &c
case "conntrack": case "conntrack":
switch rev { switch rev {
case 1: case 1: