Compare commits

..

1 Commits

Author SHA1 Message Date
Mikhail Sennikovsky 0d0cf9d2ab
Merge a0423c9897 into e99829fb4f 2024-12-20 19:33:55 +08:00
16 changed files with 38 additions and 633 deletions

View File

@ -33,5 +33,3 @@ jobs:
go test ./...
go test -c github.com/google/nftables
sudo ./nftables.test -test.v -run_system_tests
go test -c github.com/google/nftables/integration
(cd integration && sudo ../integration.test -test.v -run_system_tests)

View File

@ -207,8 +207,6 @@ func exprFromName(name string) Any {
e = &SecMark{}
case "cttimeout":
e = &CtTimeout{}
case "fib":
e = &Fib{}
}
return e
}

View File

@ -118,22 +118,17 @@ func (e *Fib) unmarshal(fam byte, data []byte) error {
e.Register = ad.Uint32()
case unix.NFTA_FIB_RESULT:
result := ad.Uint32()
switch result {
case unix.NFT_FIB_RESULT_OIF:
e.ResultOIF = true
case unix.NFT_FIB_RESULT_OIFNAME:
e.ResultOIFNAME = true
case unix.NFT_FIB_RESULT_ADDRTYPE:
e.ResultADDRTYPE = true
}
e.ResultOIF = (result & unix.NFT_FIB_RESULT_OIF) == 1
e.ResultOIFNAME = (result & unix.NFT_FIB_RESULT_OIFNAME) == 1
e.ResultADDRTYPE = (result & unix.NFT_FIB_RESULT_ADDRTYPE) == 1
case unix.NFTA_FIB_FLAGS:
flags := ad.Uint32()
e.FlagSADDR = (flags & unix.NFTA_FIB_F_SADDR) != 0
e.FlagDADDR = (flags & unix.NFTA_FIB_F_DADDR) != 0
e.FlagMARK = (flags & unix.NFTA_FIB_F_MARK) != 0
e.FlagIIF = (flags & unix.NFTA_FIB_F_IIF) != 0
e.FlagOIF = (flags & unix.NFTA_FIB_F_OIF) != 0
e.FlagPRESENT = (flags & unix.NFTA_FIB_F_PRESENT) != 0
e.FlagSADDR = (flags & unix.NFTA_FIB_F_SADDR) == 1
e.FlagDADDR = (flags & unix.NFTA_FIB_F_DADDR) == 1
e.FlagMARK = (flags & unix.NFTA_FIB_F_MARK) == 1
e.FlagIIF = (flags & unix.NFTA_FIB_F_IIF) == 1
e.FlagOIF = (flags & unix.NFTA_FIB_F_OIF) == 1
e.FlagPRESENT = (flags & unix.NFTA_FIB_F_PRESENT) == 1
}
}
return ad.Err()

View File

@ -123,7 +123,7 @@ func (l *Limit) unmarshal(fam byte, data []byte) error {
return fmt.Errorf("expr: invalid limit type %d", l.Type)
}
case unix.NFTA_LIMIT_FLAGS:
l.Over = (ad.Uint32() & unix.NFT_LIMIT_F_INV) != 0
l.Over = (ad.Uint32() & unix.NFT_LIMIT_F_INV) == 1
default:
return errors.New("expr: unhandled limit netlink attribute")
}

View File

@ -73,7 +73,7 @@ func (q *Quota) unmarshal(fam byte, data []byte) error {
case unix.NFTA_QUOTA_CONSUMED:
q.Consumed = ad.Uint64()
case unix.NFTA_QUOTA_FLAGS:
q.Over = (ad.Uint32() & unix.NFT_QUOTA_F_INV) != 0
q.Over = (ad.Uint32() & unix.NFT_QUOTA_F_INV) == 1
}
}
return ad.Err()

12
go.mod
View File

@ -3,15 +3,15 @@ module github.com/google/nftables
go 1.21
require (
github.com/google/go-cmp v0.6.0
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42
github.com/vishvananda/netlink v1.3.0
github.com/vishvananda/netns v0.0.4
golang.org/x/sys v0.28.0
github.com/mdlayher/netlink v1.7.2
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc
golang.org/x/sys v0.18.0
)
require (
github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mdlayher/socket v0.5.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/net v0.23.0 // indirect
golang.org/x/sync v0.6.0 // indirect
)

22
go.sum
View File

@ -1,18 +1,16 @@
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg=
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI=
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4=
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

View File

@ -1,252 +0,0 @@
// Copyright 2025 Google LLC. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"flag"
"os/exec"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/google/nftables/internal/nftest"
"github.com/vishvananda/netlink"
)
var enableSysTests = flag.Bool("run_system_tests", false, "Run tests that operate against the live kernel")
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, []byte(n+"\x00"))
return b
}
func TestNFTables(t *testing.T) {
tests := []struct {
name string
scriptPath string
goCommands func(t *testing.T, c *nftables.Conn)
expectFailure bool
}{
{
name: "AddTable",
scriptPath: "testdata/add_table.nft",
goCommands: func(t *testing.T, c *nftables.Conn) {
c.FlushRuleset()
c.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyINet,
})
err := c.Flush()
if err != nil {
t.Fatalf("Error creating table: %v", err)
}
},
},
{
name: "AddChain",
scriptPath: "testdata/add_chain.nft",
goCommands: func(t *testing.T, c *nftables.Conn) {
c.FlushRuleset()
table := c.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyINet,
})
c.AddChain(&nftables.Chain{
Name: "test-chain",
Table: table,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
err := c.Flush()
if err != nil {
t.Fatalf("Error creating table: %v", err)
}
},
},
{
name: "AddFlowtables",
scriptPath: "testdata/add_flowtables.nft",
goCommands: func(t *testing.T, c *nftables.Conn) {
devices := []string{"dummy0"}
c.FlushRuleset()
// add + delete + add for flushing all the table
table := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyINet,
Name: "test-table",
})
devicesSet := &nftables.Set{
Table: table,
Name: "test-set",
KeyType: nftables.TypeIFName,
KeyByteOrder: binaryutil.NativeEndian,
}
elements := []nftables.SetElement{}
for _, dev := range devices {
elements = append(elements, nftables.SetElement{
Key: ifname(dev),
})
}
if err := c.AddSet(devicesSet, elements); err != nil {
t.Errorf("failed to add Set %s : %v", devicesSet.Name, err)
}
flowtable := &nftables.Flowtable{
Table: table,
Name: "test-flowtable",
Devices: devices,
Hooknum: nftables.FlowtableHookIngress,
Priority: nftables.FlowtablePriorityRef(5),
}
c.AddFlowtable(flowtable)
chain := c.AddChain(&nftables.Chain{
Name: "test-chain",
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityMangle,
})
c.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, SourceRegister: false, Register: 0x1},
&expr.Lookup{SourceRegister: 0x1, DestRegister: 0x0, IsDestRegSet: false, SetName: "test-set", Invert: true},
&expr.Verdict{Kind: expr.VerdictReturn},
},
})
c.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, SourceRegister: false, Register: 0x1},
&expr.Lookup{SourceRegister: 0x1, DestRegister: 0x0, IsDestRegSet: false, SetName: "test-set", Invert: true},
&expr.Verdict{Kind: expr.VerdictReturn},
},
})
c.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Ct{Register: 0x1, SourceRegister: false, Key: expr.CtKeySTATE, Direction: 0x0},
&expr.Bitwise{SourceRegister: 0x1, DestRegister: 0x1, Len: 0x4, Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED), Xor: binaryutil.NativeEndian.PutUint32(0)},
&expr.Cmp{Op: 0x1, Register: 0x1, Data: []uint8{0x0, 0x0, 0x0, 0x0}},
&expr.Ct{Register: 0x1, SourceRegister: false, Key: expr.CtKeyPKTS, Direction: 0x0},
&expr.Cmp{Op: expr.CmpOpGt, Register: 0x1, Data: binaryutil.NativeEndian.PutUint64(20)},
&expr.FlowOffload{Name: "test-flowtable"},
&expr.Counter{},
},
})
if err := c.Flush(); err != nil {
t.Fatal(err)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
// Real interface must exist otherwise some nftables will fail
la := netlink.NewLinkAttrs()
la.Name = "dummy0"
dummy := &netlink.Dummy{LinkAttrs: la}
if err := netlink.LinkAdd(dummy); err != nil {
t.Fatal(err)
}
scriptOutput, err := applyNFTRuleset(tt.scriptPath)
if err != nil {
t.Fatalf("Failed to apply nftables script: %v\noutput:%s", err, scriptOutput)
}
if len(scriptOutput) > 0 {
t.Logf("nft output:\n%s", scriptOutput)
}
// Retrieve nftables state using nft
expectedOutput, err := listNFTRuleset()
if err != nil {
t.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, expectedOutput)
}
t.Logf("Expected output:\n%s", expectedOutput)
// Program nftables using your Go code
if err := flushNFTRuleset(); err != nil {
t.Fatalf("Failed to flush nftables ruleset: %v", err)
}
tt.goCommands(t, c)
// Retrieve nftables state using nft
actualOutput, err := listNFTRuleset()
if err != nil {
t.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
}
t.Logf("Actual output:\n%s", actualOutput)
if expectedOutput != actualOutput {
t.Errorf("nftables ruleset mismatch:\n%s", cmp.Diff(expectedOutput, actualOutput))
}
if err := flushNFTRuleset(); err != nil {
t.Fatalf("Failed to flush nftables ruleset: %v", err)
}
})
}
}
func applyNFTRuleset(scriptPath string) (string, error) {
cmd := exec.Command("nft", "--debug=all", "-f", scriptPath)
out, err := cmd.CombinedOutput()
if err != nil {
return string(out), err
}
return strings.TrimSpace(string(out)), nil
}
func listNFTRuleset() (string, error) {
cmd := exec.Command("nft", "list", "ruleset")
out, err := cmd.CombinedOutput()
if err != nil {
return string(out), err
}
return strings.TrimSpace(string(out)), nil
}
func flushNFTRuleset() error {
cmd := exec.Command("nft", "flush", "ruleset")
return cmd.Run()
}

View File

@ -1,5 +0,0 @@
table inet test-table {
chain test-chain {
type nat hook output priority dstnat; policy accept;
}
}

View File

@ -1,18 +0,0 @@
table inet test-table {
set test-set {
type ifname
elements = { "dummy0" }
}
flowtable test-flowtable {
hook ingress priority filter + 5
devices = { dummy0 }
}
chain test-chain {
type filter hook forward priority mangle; policy accept;
iifname != @test-set return
oifname != @test-set return
ct state established ct packets > 20 flow add @test-flowtable counter packets 0 bytes 0
}
}

View File

@ -1,2 +0,0 @@
table inet test-table {
}

View File

@ -1442,7 +1442,7 @@ func TestSecMarkMarshaling(t *testing.T) {
conn.AddObj(sec)
if err := conn.Flush(); err != nil {
t.Fatal(err.Error())
t.Fatalf(err.Error())
}
}
@ -1492,7 +1492,7 @@ func TestSynProxyObject(t *testing.T) {
conn.AddObj(syn2)
conn.AddObj(syn3)
if err := conn.Flush(); err != nil {
t.Fatal(err)
t.Fatalf(err.Error())
}
objs, err := conn.GetNamedObjects(table)
@ -1637,7 +1637,7 @@ func TestCtTimeout(t *testing.T) {
})
if err := conn.Flush(); err != nil {
t.Fatal(err)
t.Fatalf(err.Error())
}
obj, err := conn.GetObject(ctt1)
@ -1693,7 +1693,7 @@ func TestCtExpect(t *testing.T) {
conn.AddObj(cte)
if err := conn.Flush(); err != nil {
t.Fatal(err)
t.Fatalf(err.Error())
}
objs, err := conn.GetNamedObjects(table)
@ -1758,7 +1758,7 @@ func TestCtHelper(t *testing.T) {
})
if err := conn.Flush(); err != nil {
t.Fatal(err)
t.Fatalf(err.Error())
}
obj1, err := conn.GetObject(cthelp1)
@ -2590,7 +2590,7 @@ func TestGetResetNamedObj(t *testing.T) {
})
if err := c.Flush(); err != nil {
t.Fatal(err)
t.Fatalf(err.Error())
}
objsNamed, err := c.GetNamedObjects(table)
@ -2698,7 +2698,7 @@ func TestObjAPI(t *testing.T) {
})
if err := c.Flush(); err != nil {
t.Fatal(err)
t.Fatalf(err.Error())
}
objs, err := c.GetObjects(table)
@ -3049,7 +3049,7 @@ func TestObjAPICounterLegacyType(t *testing.T) {
})
if err := c.Flush(); err != nil {
t.Fatal(err)
t.Fatalf(err.Error())
}
objs, err := c.GetObjects(table)
@ -4103,46 +4103,6 @@ func TestSetElementsInterval(t *testing.T) {
}
}
func TestSetSizeConcat(t *testing.T) {
// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
// Clear all rules at the beginning + end of the test.
c.FlushRuleset()
defer c.FlushRuleset()
filter := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv6,
Name: "filter",
})
set := &nftables.Set{
Name: "test-set",
Table: filter,
KeyType: nftables.MustConcatSetType(nftables.TypeIP6Addr, nftables.TypeInetService, nftables.TypeIP6Addr),
Dynamic: true,
Concatenation: true,
Size: 200,
}
if err := c.AddSet(set, nil); err != nil {
t.Errorf("c.AddSet(set) failed: %v", err)
}
if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err)
}
sets, err := c.GetSets(filter)
if err != nil {
t.Errorf("c.GetSets() failed: %v", err)
}
if len(sets) != 1 {
t.Fatalf("len(sets) = %d, want 1", len(sets))
}
}
func TestCreateListFlowtable(t *testing.T) {
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
@ -6427,61 +6387,6 @@ func TestFib(t *testing.T) {
}
}
func TestFibSystem(t *testing.T) {
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
c.FlushRuleset()
defer c.FlushRuleset()
filter := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
})
chain := c.AddChain(&nftables.Chain{
Name: "test-chain",
Table: filter,
})
expect := &expr.Fib{
Register: 1,
FlagDADDR: true,
ResultADDRTYPE: true,
}
c.AddRule(&nftables.Rule{
Table: filter,
Chain: chain,
Exprs: []expr.Any{expect},
})
if err := c.Flush(); err != nil {
t.Fatalf("c.Flush() failed with error %+v", err)
}
rules, err := c.GetRules(filter, chain)
if err != nil {
t.Fatalf("GetRules failed: %v", err)
}
if got, want := len(rules), 1; got != want {
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
}
if got, want := len(rules[0].Exprs), 1; got != want {
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
}
fib := rules[0].Exprs[0].(*expr.Fib)
if got, want := fib.FlagDADDR, expect.FlagDADDR; got != want {
t.Errorf("fib daddr not equal: got %+v, want %+v", got, want)
}
if got, want := fib.ResultADDRTYPE, expect.ResultADDRTYPE; got != want {
t.Errorf("fib addr type not equal: got %+v, want %+v", got, want)
}
}
func TestNumgen(t *testing.T) {
tests := []struct {
name string
@ -7911,74 +7816,3 @@ func TestNftablesDeadlock(t *testing.T) {
})
}
}
func TestSetElementComment(t *testing.T) {
// Create a new network namespace to test these operations
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
conn.FlushRuleset()
defer conn.FlushRuleset()
// Add a new table
table := &nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
}
conn.AddTable(table)
// Create a new set
set := &nftables.Set{
Name: "test-set",
Table: table,
KeyType: nftables.TypeIPAddr,
}
// Create set elements with comments
elements := []nftables.SetElement{
{
Key: net.ParseIP("192.0.2.1").To4(),
Comment: "First IP address",
},
{
Key: net.ParseIP("192.0.2.2").To4(),
Comment: "Second IP address",
},
}
// Add the set with elements
if err := conn.AddSet(set, elements); err != nil {
t.Fatalf("failed to add set: %v", err)
}
if err := conn.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
// Get the set elements back and verify comments
gotElements, err := conn.GetSetElements(set)
if err != nil {
t.Fatalf("failed to get set elements: %v", err)
}
if got, want := len(gotElements), len(elements); got != want {
t.Fatalf("got %d elements, want %d", got, want)
}
// Create maps to compare elements by their IP addresses
wantMap := make(map[string]string)
for _, elem := range elements {
wantMap[string(elem.Key)] = elem.Comment
}
gotMap := make(map[string]string)
for _, elem := range gotElements {
gotMap[string(elem.Key)] = elem.Comment
}
// Compare the comments for each IP
for ip, wantComment := range wantMap {
if gotComment, ok := gotMap[ip]; !ok {
t.Errorf("IP %s not found in retrieved elements", ip)
} else if gotComment != wantComment {
t.Errorf("for IP %s: got comment %q, want comment %q", ip, gotComment, wantComment)
}
}
}

View File

@ -36,7 +36,7 @@ func (q *QuotaObj) unmarshal(ad *netlink.AttributeDecoder) error {
case unix.NFTA_QUOTA_CONSUMED:
q.Consumed = ad.Uint64()
case unix.NFTA_QUOTA_FLAGS:
q.Over = (ad.Uint32() & unix.NFT_QUOTA_F_INV) != 0
q.Over = (ad.Uint32() & unix.NFT_QUOTA_F_INV) == 1
}
}
return nil

56
set.go
View File

@ -267,8 +267,6 @@ type Set struct {
// https://git.netfilter.org/nftables/tree/include/datatype.h?id=d486c9e626405e829221b82d7355558005b26d8a#n109
KeyByteOrder binaryutil.ByteOrder
Comment string
// Indicates that the set has "size" specifier
Size uint32
}
// SetElement represents a data point within a set.
@ -290,7 +288,6 @@ type SetElement struct {
Expires time.Duration
Counter *expr.Counter
Comment string
}
func (s *SetElement) decode(fam byte) func(b []byte) error {
@ -325,12 +322,6 @@ func (s *SetElement) decode(fam byte) func(b []byte) error {
s.Timeout = time.Millisecond * time.Duration(ad.Uint64())
case unix.NFTA_SET_ELEM_EXPIRATION:
s.Expires = time.Millisecond * time.Duration(ad.Uint64())
case unix.NFTA_SET_ELEM_USERDATA:
userData := ad.Bytes()
// Try to extract comment from userdata if present
if comment, ok := userdata.GetString(userData, userdata.NFTNL_UDATA_SET_ELEM_COMMENT); ok {
s.Comment = comment
}
case unix.NFTA_SET_ELEM_EXPR:
elems, err := parseexprfunc.ParseExprBytesFunc(fam, ad)
if err != nil {
@ -463,12 +454,6 @@ func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, e
// If niether of previous cases matche, it means 'e' is an element of a regular Set, no need to add to the attributes
}
// Add comment to userdata if present
if len(v.Comment) > 0 {
userData := userdata.AppendString(nil, userdata.NFTNL_UDATA_SET_ELEM_COMMENT, v.Comment)
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_USERDATA, Data: userData})
}
encodedItem, err := netlink.MarshalAttributes(item)
if err != nil {
return nil, fmt.Errorf("marshal item %d: %v", i, err)
@ -568,21 +553,6 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
}
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements})
}
var descBytes []byte
if s.Size > 0 {
// Marshal set size description
descSizeBytes, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(s.Size)},
})
if err != nil {
return fmt.Errorf("fail to marshal set size description: %w", err)
}
descBytes = append(descBytes, descSizeBytes...)
}
if s.Concatenation {
// Length of concatenated types is a must, otherwise segfaults when executing nft list ruleset
var concatDefinition []byte
@ -609,13 +579,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
if err != nil {
return fmt.Errorf("fail to marshal concat definition %v", err)
}
descBytes = append(descBytes, concatBytes...)
}
if len(descBytes) > 0 {
// Marshal set description
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: descBytes})
// Marshal concat size description as set description
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: concatBytes})
}
// https://git.netfilter.org/libnftnl/tree/include/udata.h#n17
@ -769,7 +734,6 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
set.Interval = (flags & unix.NFT_SET_INTERVAL) != 0
set.IsMap = (flags & unix.NFT_SET_MAP) != 0
set.HasTimeout = (flags & unix.NFT_SET_TIMEOUT) != 0
set.Dynamic = (flags & unix.NFT_SET_EVAL) != 0
set.Concatenation = (flags & NFT_SET_CONCAT) != 0
case unix.NFTA_SET_KEY_TYPE:
nftMagic := ad.Uint32()
@ -798,20 +762,6 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
data := ad.Bytes()
value, ok := userdata.GetUint32(data, userdata.NFTNL_UDATA_SET_MERGE_ELEMENTS)
set.AutoMerge = ok && value == 1
case unix.NFTA_SET_DESC:
nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes())
if err != nil {
return nil, fmt.Errorf("nested NewAttributeDecoder() failed: %w", err)
}
for nestedAD.Next() {
switch nestedAD.Type() {
case unix.NFTA_SET_DESC_SIZE:
set.Size = binary.BigEndian.Uint32(nestedAD.Bytes())
}
}
if nestedAD.Err() != nil {
return nil, fmt.Errorf("decoding set description: %w", nestedAD.Err())
}
}
}
return &set, nil
@ -857,7 +807,6 @@ func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) {
b := ad.Bytes()
if ad.Type() == unix.NFTA_SET_ELEM_LIST_ELEMENTS {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return nil, err
}
@ -869,7 +818,6 @@ func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) {
case unix.NFTA_LIST_ELEM:
ad.Do(elem.decode(fam))
}
elements = append(elements, elem)
}
}

View File

@ -1,11 +1,7 @@
package nftables
import (
"reflect"
"testing"
"time"
"github.com/mdlayher/netlink"
)
// unknownNFTMagic is an nftMagic value that's unhandled by this
@ -189,82 +185,3 @@ func TestConcatSetTypeElements(t *testing.T) {
})
}
}
func TestMarshalSet(t *testing.T) {
t.Parallel()
tbl := &Table{
Name: "ipv4table",
Family: TableFamilyIPv4,
}
c, err := New(WithTestDial(
func(req []netlink.Message) ([]netlink.Message, error) {
return req, nil
}))
if err != nil {
t.Fatal(err)
}
c.AddTable(tbl)
// Ensure the table is added.
const connMsgStart = 1
if len(c.messages) != connMsgStart {
t.Fatalf("AddSet() wrong start message count: %d, expected: %d", len(c.messages), connMsgStart)
}
tests := []struct {
name string
set Set
}{
{
name: "Set without flags",
set: Set{
Name: "test-set",
ID: uint32(1),
Table: tbl,
KeyType: TypeIPAddr,
},
},
{
name: "Set with size, timeout, dynamic flag specified",
set: Set{
Name: "test-set",
ID: uint32(2),
HasTimeout: true,
Dynamic: true,
Size: 10,
Table: tbl,
KeyType: TypeIPAddr,
Timeout: 30 * time.Second,
},
},
}
for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := c.AddSet(&tt.set, nil); err != nil {
t.Fatal(err)
}
connMsgSetIdx := connMsgStart + i
if len(c.messages) != connMsgSetIdx+1 {
t.Fatalf("AddSet() wrong message count: %d, expected: %d", len(c.messages), connMsgSetIdx+1)
}
msg := c.messages[connMsgSetIdx]
nset, err := setsFromMsg(msg)
if err != nil {
t.Fatalf("setsFromMsg() error: %+v", err)
}
// Table pointer is set after flush, which is not implemented in the test.
tt.set.Table = nil
if !reflect.DeepEqual(&tt.set, nset) {
t.Fatalf("original %+v and recovered %+v Set structs are different", tt.set, nset)
}
})
}
}

View File

@ -46,12 +46,6 @@ const (
NFTNL_UDATA_SET_MAX
)
// Set element userdata types
const (
NFTNL_UDATA_SET_ELEM_COMMENT Type = iota
NFTNL_UDATA_SET_ELEM_FLAGS
)
func Append(udata []byte, typ Type, data []byte) []byte {
udata = append(udata, byte(typ), byte(len(data)))
udata = append(udata, data...)