Compare commits
1 Commits
533a9343c8
...
b77f1a918e
Author | SHA1 | Date |
---|---|---|
|
b77f1a918e |
|
@ -21,6 +21,7 @@ import (
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Deprecated: Use ObjAttr instead
|
||||||
type CounterObj struct {
|
type CounterObj struct {
|
||||||
Table *Table
|
Table *Table
|
||||||
Name string // e.g. “fwded”
|
Name string // e.g. “fwded”
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -12,6 +12,6 @@ require (
|
||||||
github.com/google/go-cmp v0.6.0 // indirect
|
github.com/google/go-cmp v0.6.0 // indirect
|
||||||
github.com/josharian/native v1.1.0 // indirect
|
github.com/josharian/native v1.1.0 // indirect
|
||||||
github.com/mdlayher/socket v0.5.0 // indirect
|
github.com/mdlayher/socket v0.5.0 // indirect
|
||||||
golang.org/x/net v0.23.0 // indirect
|
golang.org/x/net v0.22.0 // indirect
|
||||||
golang.org/x/sync v0.6.0 // indirect
|
golang.org/x/sync v0.6.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -8,8 +8,8 @@ github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI
|
||||||
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
|
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
|
||||||
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4=
|
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4=
|
||||||
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
|
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
|
||||||
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
|
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
|
||||||
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||||
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
|
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
|
||||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
|
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
|
||||||
|
|
|
@ -259,7 +259,7 @@ func (monitor *Monitor) monitor() {
|
||||||
}
|
}
|
||||||
monitor.eventCh <- event
|
monitor.eventCh <- event
|
||||||
case unix.NFT_MSG_NEWOBJ, unix.NFT_MSG_DELOBJ:
|
case unix.NFT_MSG_NEWOBJ, unix.NFT_MSG_DELOBJ:
|
||||||
obj, err := objFromMsg(msg, true)
|
obj, err := objFromMsg(msg)
|
||||||
event := &MonitorEvent{
|
event := &MonitorEvent{
|
||||||
Type: MonitorEventType(msgType),
|
Type: MonitorEventType(msgType),
|
||||||
Data: obj,
|
Data: obj,
|
||||||
|
|
413
nftables_test.go
413
nftables_test.go
|
@ -2105,10 +2105,9 @@ func TestGetObjReset(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}
|
filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}
|
||||||
obj, err := c.ResetObject(&nftables.ObjAttr{
|
obj, err := c.ResetObject(&nftables.CounterObj{
|
||||||
Table: filter,
|
Table: filter,
|
||||||
Name: "fwded",
|
Name: "fwded",
|
||||||
Type: nftables.ObjTypeCounter,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -2168,366 +2167,6 @@ func TestObjAPI(t *testing.T) {
|
||||||
Priority: nftables.ChainPriorityFilter,
|
Priority: nftables.ChainPriorityFilter,
|
||||||
})
|
})
|
||||||
|
|
||||||
counter1 := c.AddObj(&nftables.ObjAttr{
|
|
||||||
Table: table,
|
|
||||||
Name: "fwded1",
|
|
||||||
Type: nftables.ObjTypeCounter,
|
|
||||||
Obj: &expr.Counter{
|
|
||||||
Bytes: 1,
|
|
||||||
Packets: 1,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
counter2 := c.AddObj(&nftables.ObjAttr{
|
|
||||||
Table: table,
|
|
||||||
Name: "fwded2",
|
|
||||||
Type: nftables.ObjTypeCounter,
|
|
||||||
Obj: &expr.Counter{
|
|
||||||
Bytes: 1,
|
|
||||||
Packets: 1,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
c.AddObj(&nftables.ObjAttr{
|
|
||||||
Table: tableOther,
|
|
||||||
Name: "fwdedOther",
|
|
||||||
Type: nftables.ObjTypeCounter,
|
|
||||||
Obj: &expr.Counter{
|
|
||||||
Bytes: 0,
|
|
||||||
Packets: 0,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
c.AddRule(&nftables.Rule{
|
|
||||||
Table: table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: []expr.Any{
|
|
||||||
&expr.Objref{
|
|
||||||
Type: 1,
|
|
||||||
Name: "fwded1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := c.Flush(); err != nil {
|
|
||||||
t.Fatalf(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
objs, err := c.GetObjects(table)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("c.GetObjects(table) failed: %v failed", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got := len(objs); got != 2 {
|
|
||||||
t.Fatalf("unexpected number of objects: got %d, want %d", got, 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
objsOther, err := c.GetObjects(tableOther)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("c.GetObjects(tableOther) failed: %v failed", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got := len(objsOther); got != 1 {
|
|
||||||
t.Fatalf("unexpected number of objects: got %d, want %d", got, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
obj1, err := c.GetObject(counter1)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("c.GetObject(counter1) failed: %v failed", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rcounter1, ok := obj1.(*nftables.ObjAttr)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("unexpected type: got %T, want *nftables.ObjAttr", obj1)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rcounter1.Name != "fwded1" {
|
|
||||||
t.Fatalf("unexpected counter name: got %s, want %s", rcounter1.Name, "fwded1")
|
|
||||||
}
|
|
||||||
|
|
||||||
obj2, err := c.GetObject(counter2)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("c.GetObject(counter2) failed: %v failed", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rcounter2, ok := obj2.(*nftables.ObjAttr)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj2)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rcounter2.Name != "fwded2" {
|
|
||||||
t.Fatalf("unexpected counter name: got %s, want %s", rcounter2.Name, "fwded2")
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = c.ResetObject(counter1)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("c.ResetObjects(table) failed: %v failed", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
obj1, err = c.GetObject(counter1)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("c.GetObject(counter1) failed: %v failed", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if counter1 := obj1.(*nftables.ObjAttr).Obj.(*expr.Counter); counter1.Packets > 0 {
|
|
||||||
t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
obj2, err = c.GetObject(counter2)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("c.GetObject(counter2) failed: %v failed", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if counter2 := obj2.(*nftables.ObjAttr).Obj.(*expr.Counter); counter2.Packets != 1 {
|
|
||||||
t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
legacy, err := c.GetObj(counter1)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("c.GetObj(counter1) failed: %v failed", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(legacy) != 2 {
|
|
||||||
t.Errorf("unexpected number of objects: got %d, want %d", len(legacy), 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
legacyReset, err := c.GetObjReset(counter1)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("c.GetObjReset(counter1) failed: %v failed", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(legacyReset) != 2 {
|
|
||||||
t.Errorf("unexpected number of objects: got %d, want %d", len(legacyReset), 2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDeleteLegacyQuotaObj(t *testing.T) {
|
|
||||||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
|
||||||
defer nftest.CleanupSystemConn(t, newNS)
|
|
||||||
conn.FlushRuleset()
|
|
||||||
defer conn.FlushRuleset()
|
|
||||||
|
|
||||||
table := &nftables.Table{
|
|
||||||
Name: "quota_demo",
|
|
||||||
Family: nftables.TableFamilyIPv4,
|
|
||||||
}
|
|
||||||
tr := conn.AddTable(table)
|
|
||||||
|
|
||||||
c := &nftables.Chain{
|
|
||||||
Name: "filter",
|
|
||||||
Table: table,
|
|
||||||
}
|
|
||||||
conn.AddChain(c)
|
|
||||||
|
|
||||||
o := &nftables.QuotaObj{
|
|
||||||
Table: tr,
|
|
||||||
Name: "q_test",
|
|
||||||
Bytes: 0x06400000,
|
|
||||||
Consumed: 0,
|
|
||||||
Over: true,
|
|
||||||
}
|
|
||||||
conn.AddObj(o)
|
|
||||||
|
|
||||||
if err := conn.Flush(); err != nil {
|
|
||||||
t.Fatalf("conn.Flush() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
obj, err := conn.GetObj(&nftables.QuotaObj{
|
|
||||||
Table: table,
|
|
||||||
Name: "q_test",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.GetObj() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := len(obj), 1; got != want {
|
|
||||||
t.Fatalf("unexpected number of objects: got %d, want %d", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := obj[0], o; !reflect.DeepEqual(got, want) {
|
|
||||||
t.Errorf("got = %+v, want = %+v", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.DeleteObject(&nftables.QuotaObj{
|
|
||||||
Table: tr,
|
|
||||||
Name: "q_test",
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := conn.Flush(); err != nil {
|
|
||||||
t.Fatalf("conn.Flush() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
obj, err = conn.GetObj(&nftables.QuotaObj{
|
|
||||||
Table: table,
|
|
||||||
Name: "q_test",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.GetObj() failed: %v", err)
|
|
||||||
}
|
|
||||||
if got, want := len(obj), 0; got != want {
|
|
||||||
t.Fatalf("unexpected object list length: got %d, want %d", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddLegacyQuotaObj(t *testing.T) {
|
|
||||||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
|
||||||
defer nftest.CleanupSystemConn(t, newNS)
|
|
||||||
conn.FlushRuleset()
|
|
||||||
defer conn.FlushRuleset()
|
|
||||||
|
|
||||||
table := &nftables.Table{
|
|
||||||
Name: "quota_demo",
|
|
||||||
Family: nftables.TableFamilyIPv4,
|
|
||||||
}
|
|
||||||
tr := conn.AddTable(table)
|
|
||||||
|
|
||||||
c := &nftables.Chain{
|
|
||||||
Name: "filter",
|
|
||||||
Table: table,
|
|
||||||
}
|
|
||||||
conn.AddChain(c)
|
|
||||||
|
|
||||||
o := &nftables.QuotaObj{
|
|
||||||
Table: tr,
|
|
||||||
Name: "q_test",
|
|
||||||
Bytes: 0x06400000,
|
|
||||||
Consumed: 0,
|
|
||||||
Over: true,
|
|
||||||
}
|
|
||||||
conn.AddObj(o)
|
|
||||||
|
|
||||||
if err := conn.Flush(); err != nil {
|
|
||||||
t.Errorf("conn.Flush() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
obj, err := conn.GetObj(&nftables.QuotaObj{
|
|
||||||
Table: table,
|
|
||||||
Name: "q_test",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn.GetObj() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := len(obj), 1; got != want {
|
|
||||||
t.Fatalf("unexpected object list length: got %d, want %d", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
o1, ok := obj[0].(*nftables.QuotaObj)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("unexpected type: got %T, want *QuotaObj", obj[0])
|
|
||||||
}
|
|
||||||
if got, want := o1.Name, o.Name; got != want {
|
|
||||||
t.Fatalf("quota name mismatch: got %s, want %s", got, want)
|
|
||||||
}
|
|
||||||
if got, want := o1.Bytes, o.Bytes; got != want {
|
|
||||||
t.Fatalf("quota bytes mismatch: got %d, want %d", got, want)
|
|
||||||
}
|
|
||||||
if got, want := o1.Consumed, o.Consumed; got != want {
|
|
||||||
t.Fatalf("quota consumed mismatch: got %d, want %d", got, want)
|
|
||||||
}
|
|
||||||
if got, want := o1.Over, o.Over; got != want {
|
|
||||||
t.Fatalf("quota over mismatch: got %v, want %v", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddLegacyQuotaObjRef(t *testing.T) {
|
|
||||||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
|
||||||
defer nftest.CleanupSystemConn(t, newNS)
|
|
||||||
conn.FlushRuleset()
|
|
||||||
defer conn.FlushRuleset()
|
|
||||||
|
|
||||||
table := &nftables.Table{
|
|
||||||
Name: "quota_demo",
|
|
||||||
Family: nftables.TableFamilyIPv4,
|
|
||||||
}
|
|
||||||
tr := conn.AddTable(table)
|
|
||||||
|
|
||||||
c := &nftables.Chain{
|
|
||||||
Name: "filter",
|
|
||||||
Table: table,
|
|
||||||
}
|
|
||||||
conn.AddChain(c)
|
|
||||||
|
|
||||||
o := &nftables.QuotaObj{
|
|
||||||
Table: tr,
|
|
||||||
Name: "q_test",
|
|
||||||
Bytes: 0x06400000,
|
|
||||||
Consumed: 0,
|
|
||||||
Over: true,
|
|
||||||
}
|
|
||||||
conn.AddObj(o)
|
|
||||||
|
|
||||||
r := &nftables.Rule{
|
|
||||||
Table: table,
|
|
||||||
Chain: c,
|
|
||||||
Exprs: []expr.Any{
|
|
||||||
&expr.Objref{
|
|
||||||
Type: 2,
|
|
||||||
Name: "q_test",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
conn.AddRule(r)
|
|
||||||
if err := conn.Flush(); err != nil {
|
|
||||||
t.Fatalf("failed to flush: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err := conn.GetRules(table, c)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := len(rules), 1; got != want {
|
|
||||||
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
|
|
||||||
}
|
|
||||||
if got, want := len(rules[0].Exprs), 1; got != want {
|
|
||||||
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
objref, ok := rules[0].Exprs[0].(*expr.Objref)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("Exprs[0] is type %T, want *expr.Objref", rules[0].Exprs[0])
|
|
||||||
}
|
|
||||||
if want := r.Exprs[0]; !reflect.DeepEqual(objref, want) {
|
|
||||||
t.Errorf("objref expr = %+v, wanted %+v", objref, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestObjAPICounterLegacyType(t *testing.T) {
|
|
||||||
if os.Getenv("TRAVIS") == "true" {
|
|
||||||
t.SkipNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new network namespace to test these operations,
|
|
||||||
// and tear down the namespace at test completion.
|
|
||||||
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
|
||||||
defer nftest.CleanupSystemConn(t, newNS)
|
|
||||||
|
|
||||||
// Clear all rules at the beginning + end of the test.
|
|
||||||
c.FlushRuleset()
|
|
||||||
defer c.FlushRuleset()
|
|
||||||
|
|
||||||
table := c.AddTable(&nftables.Table{
|
|
||||||
Family: nftables.TableFamilyIPv4,
|
|
||||||
Name: "filter",
|
|
||||||
})
|
|
||||||
|
|
||||||
tableOther := c.AddTable(&nftables.Table{
|
|
||||||
Family: nftables.TableFamilyIPv4,
|
|
||||||
Name: "foo",
|
|
||||||
})
|
|
||||||
|
|
||||||
chain := c.AddChain(&nftables.Chain{
|
|
||||||
Name: "chain",
|
|
||||||
Table: table,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookPostrouting,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
})
|
|
||||||
|
|
||||||
counter1 := c.AddObj(&nftables.CounterObj{
|
counter1 := c.AddObj(&nftables.CounterObj{
|
||||||
Table: table,
|
Table: table,
|
||||||
Name: "fwded1",
|
Name: "fwded1",
|
||||||
|
@ -2587,10 +2226,9 @@ func TestObjAPICounterLegacyType(t *testing.T) {
|
||||||
t.Errorf("c.GetObject(counter1) failed: %v failed", err)
|
t.Errorf("c.GetObject(counter1) failed: %v failed", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rcounter1, ok := obj1.(*nftables.CounterObj)
|
rcounter1, ok := obj1.(*nftables.ObjAttr)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", rcounter1)
|
t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rcounter1.Name != "fwded1" {
|
if rcounter1.Name != "fwded1" {
|
||||||
|
@ -2602,10 +2240,9 @@ func TestObjAPICounterLegacyType(t *testing.T) {
|
||||||
t.Errorf("c.GetObject(counter2) failed: %v failed", err)
|
t.Errorf("c.GetObject(counter2) failed: %v failed", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rcounter2, ok := obj2.(*nftables.CounterObj)
|
rcounter2, ok := obj2.(*nftables.ObjAttr)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", rcounter2)
|
t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj2)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rcounter2.Name != "fwded2" {
|
if rcounter2.Name != "fwded2" {
|
||||||
|
@ -2624,7 +2261,7 @@ func TestObjAPICounterLegacyType(t *testing.T) {
|
||||||
t.Errorf("c.GetObject(counter1) failed: %v failed", err)
|
t.Errorf("c.GetObject(counter1) failed: %v failed", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if counter1 := obj1.(*nftables.CounterObj); counter1.Packets > 0 {
|
if counter1 := obj1.(*nftables.ObjAttr).Obj.(*expr.Counter); counter1.Packets > 0 {
|
||||||
t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0)
|
t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2634,7 +2271,7 @@ func TestObjAPICounterLegacyType(t *testing.T) {
|
||||||
t.Errorf("c.GetObject(counter2) failed: %v failed", err)
|
t.Errorf("c.GetObject(counter2) failed: %v failed", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if counter2 := obj2.(*nftables.CounterObj); counter2.Packets != 1 {
|
if counter2 := obj2.(*nftables.ObjAttr).Obj.(*expr.Counter); counter2.Packets != 1 {
|
||||||
t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1)
|
t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6745,26 +6382,22 @@ func TestAddQuotaObj(t *testing.T) {
|
||||||
}
|
}
|
||||||
conn.AddChain(c)
|
conn.AddChain(c)
|
||||||
|
|
||||||
o := &nftables.ObjAttr{
|
o := &nftables.QuotaObj{
|
||||||
Table: tr,
|
Table: tr,
|
||||||
Name: "q_test",
|
Name: "q_test",
|
||||||
Type: nftables.ObjTypeQuota,
|
|
||||||
Obj: &expr.Quota{
|
|
||||||
Bytes: 0x06400000,
|
Bytes: 0x06400000,
|
||||||
Consumed: 0,
|
Consumed: 0,
|
||||||
Over: true,
|
Over: true,
|
||||||
},
|
|
||||||
}
|
}
|
||||||
conn.AddObj(o)
|
conn.AddObj(o)
|
||||||
|
|
||||||
if err := conn.Flush(); err != nil {
|
if err := conn.Flush(); err != nil {
|
||||||
t.Fatalf("conn.Flush() failed: %v", err)
|
t.Errorf("conn.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
obj, err := conn.GetObj(&nftables.ObjAttr{
|
obj, err := conn.GetObj(&nftables.QuotaObj{
|
||||||
Table: table,
|
Table: table,
|
||||||
Name: "q_test",
|
Name: "q_test",
|
||||||
Type: nftables.ObjTypeQuota,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("conn.GetObj() failed: %v", err)
|
t.Fatalf("conn.GetObj() failed: %v", err)
|
||||||
|
@ -6785,14 +6418,13 @@ func TestAddQuotaObj(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("unexpected type: got %T, want *expr.Quota", o1.Obj)
|
t.Fatalf("unexpected type: got %T, want *expr.Quota", o1.Obj)
|
||||||
}
|
}
|
||||||
o2, _ := o.Obj.(*expr.Quota)
|
if got, want := q.Bytes, o.Bytes; got != want {
|
||||||
if got, want := q.Bytes, o2.Bytes; got != want {
|
|
||||||
t.Fatalf("quota bytes mismatch: got %d, want %d", got, want)
|
t.Fatalf("quota bytes mismatch: got %d, want %d", got, want)
|
||||||
}
|
}
|
||||||
if got, want := q.Consumed, o2.Consumed; got != want {
|
if got, want := q.Consumed, o.Consumed; got != want {
|
||||||
t.Fatalf("quota consumed mismatch: got %d, want %d", got, want)
|
t.Fatalf("quota consumed mismatch: got %d, want %d", got, want)
|
||||||
}
|
}
|
||||||
if got, want := q.Over, o2.Over; got != want {
|
if got, want := q.Over, o.Over; got != want {
|
||||||
t.Fatalf("quota over mismatch: got %v, want %v", got, want)
|
t.Fatalf("quota over mismatch: got %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6860,7 +6492,7 @@ func TestAddQuotaObjRef(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteQuotaObjMixedTypes(t *testing.T) {
|
func TestDeleteQuotaObj(t *testing.T) {
|
||||||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||||||
defer nftest.CleanupSystemConn(t, newNS)
|
defer nftest.CleanupSystemConn(t, newNS)
|
||||||
conn.FlushRuleset()
|
conn.FlushRuleset()
|
||||||
|
@ -6878,15 +6510,12 @@ func TestDeleteQuotaObjMixedTypes(t *testing.T) {
|
||||||
}
|
}
|
||||||
conn.AddChain(c)
|
conn.AddChain(c)
|
||||||
|
|
||||||
o := &nftables.ObjAttr{
|
o := &nftables.QuotaObj{
|
||||||
Table: tr,
|
Table: tr,
|
||||||
Name: "q_test",
|
Name: "q_test",
|
||||||
Type: nftables.ObjTypeQuota,
|
|
||||||
Obj: &expr.Quota{
|
|
||||||
Bytes: 0x06400000,
|
Bytes: 0x06400000,
|
||||||
Consumed: 0,
|
Consumed: 0,
|
||||||
Over: true,
|
Over: true,
|
||||||
},
|
|
||||||
}
|
}
|
||||||
conn.AddObj(o)
|
conn.AddObj(o)
|
||||||
|
|
||||||
|
@ -6894,10 +6523,9 @@ func TestDeleteQuotaObjMixedTypes(t *testing.T) {
|
||||||
t.Fatalf("conn.Flush() failed: %v", err)
|
t.Fatalf("conn.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
obj, err := conn.GetObj(&nftables.ObjAttr{
|
obj, err := conn.GetObj(&nftables.QuotaObj{
|
||||||
Table: tr,
|
Table: table,
|
||||||
Name: "q_test",
|
Name: "q_test",
|
||||||
Type: nftables.ObjTypeQuota,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("conn.GetObj() failed: %v", err)
|
t.Fatalf("conn.GetObj() failed: %v", err)
|
||||||
|
@ -6907,15 +6535,14 @@ func TestDeleteQuotaObjMixedTypes(t *testing.T) {
|
||||||
t.Fatalf("unexpected number of objects: got %d, want %d", got, want)
|
t.Fatalf("unexpected number of objects: got %d, want %d", got, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
o2, _ := o.Obj.(*expr.Quota)
|
|
||||||
want := &nftables.ObjAttr{
|
want := &nftables.ObjAttr{
|
||||||
Table: tr,
|
Table: tr,
|
||||||
Name: "q_test",
|
Name: "q_test",
|
||||||
Type: nftables.ObjTypeQuota,
|
Type: nftables.ObjTypeQuota,
|
||||||
Obj: &expr.Quota{
|
Obj: &expr.Quota{
|
||||||
Bytes: o2.Bytes,
|
Bytes: o.Bytes,
|
||||||
Consumed: o2.Consumed,
|
Consumed: o.Consumed,
|
||||||
Over: o2.Over,
|
Over: o.Over,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if got, want := obj[0], want; !reflect.DeepEqual(got, want) {
|
if got, want := obj[0], want; !reflect.DeepEqual(got, want) {
|
||||||
|
|
58
obj.go
58
obj.go
|
@ -155,13 +155,13 @@ func (cc *Conn) DeleteObject(o Obj) {
|
||||||
// GetObj is a legacy method that return all Obj that belongs
|
// GetObj is a legacy method that return all Obj that belongs
|
||||||
// to the same table as the given one
|
// to the same table as the given one
|
||||||
func (cc *Conn) GetObj(o Obj) ([]Obj, error) {
|
func (cc *Conn) GetObj(o Obj) ([]Obj, error) {
|
||||||
return cc.getObjWithLegacyType(nil, o.table(), unix.NFT_MSG_GETOBJ, cc.useLegacyObjType(o))
|
return cc.getObj(nil, o.table(), unix.NFT_MSG_GETOBJ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetObjReset is a legacy method that reset all Obj that belongs
|
// GetObjReset is a legacy method that reset all Obj that belongs
|
||||||
// the same table as the given one
|
// the same table as the given one
|
||||||
func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) {
|
func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) {
|
||||||
return cc.getObjWithLegacyType(nil, o.table(), unix.NFT_MSG_GETOBJ_RESET, cc.useLegacyObjType(o))
|
return cc.getObj(nil, o.table(), unix.NFT_MSG_GETOBJ_RESET)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetObject gets the specified Object
|
// GetObject gets the specified Object
|
||||||
|
@ -196,7 +196,7 @@ func (cc *Conn) ResetObjects(t *Table) ([]Obj, error) {
|
||||||
return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ_RESET)
|
return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ_RESET)
|
||||||
}
|
}
|
||||||
|
|
||||||
func objFromMsg(msg netlink.Message, returnLegacyType bool) (Obj, error) {
|
func objFromMsg(msg netlink.Message) (Obj, error) {
|
||||||
if got, want1, want2 := msg.Header.Type, newObjHeaderType, delObjHeaderType; got != want1 && got != want2 {
|
if got, want1, want2 := msg.Header.Type, newObjHeaderType, delObjHeaderType; got != want1 && got != want2 {
|
||||||
return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2)
|
return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2)
|
||||||
}
|
}
|
||||||
|
@ -219,7 +219,6 @@ func objFromMsg(msg netlink.Message, returnLegacyType bool) (Obj, error) {
|
||||||
case unix.NFTA_OBJ_TYPE:
|
case unix.NFTA_OBJ_TYPE:
|
||||||
objectType = ad.Uint32()
|
objectType = ad.Uint32()
|
||||||
case unix.NFTA_OBJ_DATA:
|
case unix.NFTA_OBJ_DATA:
|
||||||
if !returnLegacyType {
|
|
||||||
o := ObjAttr{
|
o := ObjAttr{
|
||||||
Table: table,
|
Table: table,
|
||||||
Name: name,
|
Name: name,
|
||||||
|
@ -241,40 +240,6 @@ func objFromMsg(msg netlink.Message, returnLegacyType bool) (Obj, error) {
|
||||||
o.Obj = exprs[0]
|
o.Obj = exprs[0]
|
||||||
return &o, ad.Err()
|
return &o, ad.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
switch objectType {
|
|
||||||
case unix.NFT_OBJECT_COUNTER:
|
|
||||||
o := CounterObj{
|
|
||||||
Table: table,
|
|
||||||
Name: name,
|
|
||||||
}
|
|
||||||
|
|
||||||
ad.Do(func(b []byte) error {
|
|
||||||
ad, err := netlink.NewAttributeDecoder(b)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ad.ByteOrder = binary.BigEndian
|
|
||||||
return o.unmarshal(ad)
|
|
||||||
})
|
|
||||||
return &o, ad.Err()
|
|
||||||
case unix.NFT_OBJECT_QUOTA:
|
|
||||||
o := QuotaObj{
|
|
||||||
Table: table,
|
|
||||||
Name: name,
|
|
||||||
}
|
|
||||||
|
|
||||||
ad.Do(func(b []byte) error {
|
|
||||||
ad, err := netlink.NewAttributeDecoder(b)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ad.ByteOrder = binary.BigEndian
|
|
||||||
return o.unmarshal(ad)
|
|
||||||
})
|
|
||||||
return &o, ad.Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err := ad.Err(); err != nil {
|
if err := ad.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -283,10 +248,6 @@ func objFromMsg(msg netlink.Message, returnLegacyType bool) (Obj, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) {
|
func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) {
|
||||||
return cc.getObjWithLegacyType(o, t, msgType, cc.useLegacyObjType(o))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cc *Conn) getObjWithLegacyType(o Obj, t *Table, msgType uint16, returnLegacyObjType bool) ([]Obj, error) {
|
|
||||||
conn, closer, err := cc.netlinkConn()
|
conn, closer, err := cc.netlinkConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -331,7 +292,7 @@ func (cc *Conn) getObjWithLegacyType(o Obj, t *Table, msgType uint16, returnLega
|
||||||
}
|
}
|
||||||
var objs []Obj
|
var objs []Obj
|
||||||
for _, msg := range reply {
|
for _, msg := range reply {
|
||||||
o, err := objFromMsg(msg, returnLegacyObjType)
|
o, err := objFromMsg(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -340,14 +301,3 @@ func (cc *Conn) getObjWithLegacyType(o Obj, t *Table, msgType uint16, returnLega
|
||||||
|
|
||||||
return objs, nil
|
return objs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *Conn) useLegacyObjType(o Obj) bool {
|
|
||||||
useLegacyType := true
|
|
||||||
if o != nil {
|
|
||||||
switch o.(type) {
|
|
||||||
case *ObjAttr:
|
|
||||||
useLegacyType = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return useLegacyType
|
|
||||||
}
|
|
||||||
|
|
1
quota.go
1
quota.go
|
@ -21,6 +21,7 @@ import (
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Deprecated: Use ObjAttr instead
|
||||||
type QuotaObj struct {
|
type QuotaObj struct {
|
||||||
Table *Table
|
Table *Table
|
||||||
Name string
|
Name string
|
||||||
|
|
|
@ -1,34 +0,0 @@
|
||||||
package xt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CommentSize is the fixed size of a comment info xt blob, see:
|
|
||||||
// https://elixir.bootlin.com/linux/v6.8.7/source/include/uapi/linux/netfilter/xt_comment.h#L5
|
|
||||||
const CommentSize = 256
|
|
||||||
|
|
||||||
// Comment gets marshalled and unmarshalled as a fixed-sized char array, filled
|
|
||||||
// with zeros as necessary, see:
|
|
||||||
// https://elixir.bootlin.com/linux/v6.8.7/source/include/uapi/linux/netfilter/xt_comment.h#L7
|
|
||||||
type Comment string
|
|
||||||
|
|
||||||
func (c *Comment) marshal(fam TableFamily, rev uint32) ([]byte, error) {
|
|
||||||
if len(*c) >= CommentSize {
|
|
||||||
return nil, fmt.Errorf("comment must be less than %d bytes, got %d bytes",
|
|
||||||
CommentSize, len(*c))
|
|
||||||
}
|
|
||||||
data := make([]byte, CommentSize)
|
|
||||||
copy(data, []byte(*c))
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Comment) unmarshal(fam TableFamily, rev uint32, data []byte) error {
|
|
||||||
if len(data) != CommentSize {
|
|
||||||
return fmt.Errorf("malformed comment: got %d bytes, expected exactly %d bytes",
|
|
||||||
len(data), CommentSize)
|
|
||||||
}
|
|
||||||
*c = Comment(bytes.TrimRight(data, "\x00"))
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,62 +0,0 @@
|
||||||
package xt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestComment(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
payload := Comment("The quick brown fox jumps over the lazy dog.")
|
|
||||||
oversized := Comment(strings.Repeat("foobar", 100))
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
info InfoAny
|
|
||||||
errmsg string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "un/marshal Comment round-trip",
|
|
||||||
info: &payload,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "marshal oversized Comment",
|
|
||||||
info: &oversized,
|
|
||||||
errmsg: "comment must be less than 256 bytes, got 600 bytes",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
data, err := tt.info.marshal(0, 0)
|
|
||||||
if err != nil {
|
|
||||||
if tt.errmsg != "" && err.Error() == tt.errmsg {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Fatalf("marshal error: %+v", err)
|
|
||||||
|
|
||||||
}
|
|
||||||
if len(data) != CommentSize {
|
|
||||||
t.Fatalf("marshal error: invalid size %d", len(data))
|
|
||||||
}
|
|
||||||
if data[len(data)-1] != 0 {
|
|
||||||
t.Fatalf("marshal error: invalid termination")
|
|
||||||
}
|
|
||||||
var comment Comment
|
|
||||||
var recoveredInfo InfoAny = &comment
|
|
||||||
err = recoveredInfo.unmarshal(0, 0, data)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unmarshal error: %+v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(tt.info, recoveredInfo) {
|
|
||||||
t.Fatalf("original %+v and recovered %+v are different", tt.info, recoveredInfo)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
oversizeddata := []byte(oversized)
|
|
||||||
var comment Comment
|
|
||||||
if err := (&comment).unmarshal(0, 0, oversizeddata); err == nil {
|
|
||||||
t.Fatalf("unmarshal: expected error, but got nil")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -34,9 +34,6 @@ func Unmarshal(name string, fam TableFamily, rev uint32, data []byte) (InfoAny,
|
||||||
case 1:
|
case 1:
|
||||||
i = &AddrTypeV1{}
|
i = &AddrTypeV1{}
|
||||||
}
|
}
|
||||||
case "comment":
|
|
||||||
var c Comment
|
|
||||||
i = &c
|
|
||||||
case "conntrack":
|
case "conntrack":
|
||||||
switch rev {
|
switch rev {
|
||||||
case 1:
|
case 1:
|
||||||
|
|
Loading…
Reference in New Issue