This commit is contained in:
Nick Garlis 2025-08-18 15:27:43 +00:00 committed by GitHub
commit f577990180
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 150 additions and 132 deletions

17
conn.go
View File

@ -42,7 +42,7 @@ type Conn struct {
lasting bool // establish a lasting connection to be used across multiple netlink operations. lasting bool // establish a lasting connection to be used across multiple netlink operations.
mu sync.Mutex // protects the following state mu sync.Mutex // protects the following state
messages []netlinkMessage messages []netlinkMessage
err error errs []error
nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol.
sockOptions []SockOption sockOptions []SockOption
lastID uint32 lastID uint32
@ -237,14 +237,15 @@ func (cc *Conn) Flush() error {
defer func() { defer func() {
cc.messages = nil cc.messages = nil
cc.allocatedIDs = 0 cc.allocatedIDs = 0
cc.errs = nil
cc.mu.Unlock() cc.mu.Unlock()
}() }()
if len(cc.messages) == 0 { if len(cc.messages) == 0 {
// Messages were already programmed, returning nil // Messages were already programmed, returning nil
return nil return nil
} }
if cc.err != nil { if len(cc.errs) > 0 {
return cc.err // serialization error return errors.Join(cc.errs...)
} }
conn, closer, err := cc.netlinkConnUnderLock() conn, closer, err := cc.netlinkConnUnderLock()
if err != nil { if err != nil {
@ -363,17 +364,17 @@ func (cc *Conn) dialNetlink() (*netlink.Conn, error) {
return conn, nil return conn, nil
} }
func (cc *Conn) setErr(err error) { func (cc *Conn) appendErr(err error) {
if cc.err != nil { if err == nil {
return return
} }
cc.err = err cc.errs = append(cc.errs, err)
} }
func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte { func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
b, err := netlink.MarshalAttributes(attrs) b, err := netlink.MarshalAttributes(attrs)
if err != nil { if err != nil {
cc.setErr(err) cc.appendErr(err)
return nil return nil
} }
return b return b
@ -382,7 +383,7 @@ func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte { func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
b, err := expr.Marshal(fam, e) b, err := expr.Marshal(fam, e)
if err != nil { if err != nil {
cc.setErr(err) cc.appendErr(err)
return nil return nil
} }
return b return b

View File

@ -111,9 +111,7 @@ func TestNFTables(t *testing.T) {
}) })
} }
if err := c.AddSet(devicesSet, elements); err != nil { c.AddSet(devicesSet, elements)
t.Errorf("failed to add Set %s : %v", devicesSet.Name, err)
}
flowtable := &nftables.Flowtable{ flowtable := &nftables.Flowtable{
Table: table, Table: table,

View File

@ -290,9 +290,7 @@ func TestRuleHandle(t *testing.T) {
{ {
Name: "delete-rule", Name: "delete-rule",
Func: func() { Func: func() {
if err := c.DelRule(rule2); err != nil { c.DelRule(rule2)
t.Errorf("DelRule failed: %v", err)
}
}, },
Expect: []string{"1", "3"}, Expect: []string{"1", "3"},
}, },
@ -3335,12 +3333,10 @@ func TestCreateUseAnonymousSet(t *testing.T) {
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(set, []nftables.SetElement{ c.AddSet(set, []nftables.SetElement{
{Key: binaryutil.BigEndian.PutUint16(69)}, {Key: binaryutil.BigEndian.PutUint16(69)},
{Key: binaryutil.BigEndian.PutUint16(1163)}, {Key: binaryutil.BigEndian.PutUint16(1163)},
}); err != nil { })
t.Errorf("c.AddSet() failed: %v", err)
}
c.AddRule(&nftables.Rule{ c.AddRule(&nftables.Rule{
Table: filter, Table: filter,
@ -3414,9 +3410,7 @@ func TestCappedErrMsgOnSets(t *testing.T) {
Name: "if_set", Name: "if_set",
KeyType: nftables.TypeIFName, KeyType: nftables.TypeIFName,
} }
if err := c.AddSet(ifSet, nil); err != nil { c.AddSet(ifSet, nil)
t.Errorf("c.AddSet(ifSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("failed adding set ifSet: %v", err) t.Errorf("failed adding set ifSet: %v", err)
} }
@ -3437,9 +3431,7 @@ func TestCappedErrMsgOnSets(t *testing.T) {
elements := []nftables.SetElement{ elements := []nftables.SetElement{
{Key: []byte("012345678912345\x00")}, {Key: []byte("012345678912345\x00")},
} }
if err := c.SetAddElements(ifSet, elements); err != nil { c.SetAddElements(ifSet, elements)
t.Errorf("adding SetElements(ifSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("failed adding set elements ifSet: %v", err) t.Errorf("failed adding set elements ifSet: %v", err)
} }
@ -3477,24 +3469,16 @@ func TestCreateUseNamedSet(t *testing.T) {
Name: "test", Name: "test",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(portSet, nil); err != nil { c.AddSet(portSet, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err) c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}})
}
if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil {
t.Errorf("c.SetVal(portSet) failed: %v", err)
}
ipSet := &nftables.Set{ ipSet := &nftables.Set{
Table: filter, Table: filter,
Name: "IPs_4_dayz", Name: "IPs_4_dayz",
KeyType: nftables.TypeIPAddr, KeyType: nftables.TypeIPAddr,
} }
if err := c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}); err != nil { c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}})
t.Errorf("c.AddSet(ipSet) failed: %v", err) c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}})
}
if err := c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}); err != nil {
t.Errorf("c.SetVal(ipSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -3535,12 +3519,8 @@ func TestCreateAutoMergeSet(t *testing.T) {
Interval: true, Interval: true,
AutoMerge: true, AutoMerge: true,
} }
if err := c.AddSet(portSet, nil); err != nil { c.AddSet(portSet, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err) c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}})
}
if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil {
t.Errorf("c.SetVal(portSet) failed: %v", err)
}
ipSet := &nftables.Set{ ipSet := &nftables.Set{
Table: filter, Table: filter,
@ -3549,12 +3529,8 @@ func TestCreateAutoMergeSet(t *testing.T) {
Interval: true, Interval: true,
AutoMerge: true, AutoMerge: true,
} }
if err := c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}); err != nil { c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}})
t.Errorf("c.AddSet(ipSet) failed: %v", err) c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}})
}
if err := c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}); err != nil {
t.Errorf("c.SetVal(ipSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -3592,15 +3568,11 @@ func TestIP6SetAddElements(t *testing.T) {
Name: "ports", Name: "ports",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(portSet, nil); err != nil { c.AddSet(portSet, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err) c.SetAddElements(portSet, []nftables.SetElement{
}
if err := c.SetAddElements(portSet, []nftables.SetElement{
{Key: binaryutil.BigEndian.PutUint16(22)}, {Key: binaryutil.BigEndian.PutUint16(22)},
{Key: binaryutil.BigEndian.PutUint16(80)}, {Key: binaryutil.BigEndian.PutUint16(80)},
}); err != nil { })
t.Errorf("c.SetVal(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
@ -3648,9 +3620,7 @@ func TestSetElementBatching(t *testing.T) {
elements[i].Key = binaryutil.BigEndian.PutUint16(uint16(i)) elements[i].Key = binaryutil.BigEndian.PutUint16(uint16(i))
elements[i].Comment = "0123456789" elements[i].Comment = "0123456789"
} }
if err := c.AddSet(portSet, elements); err != nil { c.AddSet(portSet, elements)
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -3694,12 +3664,8 @@ func TestCreateUseCounterSet(t *testing.T) {
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
Counter: true, Counter: true,
} }
if err := c.AddSet(portSet, nil); err != nil { c.AddSet(portSet, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err) c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}})
}
if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil {
t.Errorf("c.SetVal(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
@ -3736,9 +3702,7 @@ func TestCreateDeleteNamedSet(t *testing.T) {
Name: "test", Name: "test",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(portSet, nil); err != nil { c.AddSet(portSet, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -3777,9 +3741,7 @@ func TestDeleteElementNamedSet(t *testing.T) {
Name: "test", Name: "test",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil { c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}})
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -3824,9 +3786,7 @@ func TestFlushNamedSet(t *testing.T) {
Name: "test", Name: "test",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil { c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}})
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -3866,20 +3826,16 @@ func TestSetElementsInterval(t *testing.T) {
Interval: true, Interval: true,
Concatenation: true, Concatenation: true,
} }
if err := c.AddSet(portSet, nil); err != nil { c.AddSet(portSet, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
// { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } // { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb }
keyBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 195, 80, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203} keyBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 195, 80, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203}
// { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } // { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb }
keyEndBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 234, 96, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203} keyEndBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 234, 96, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203}
// elements = { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000-60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } // elements = { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000-60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb }
if err := c.SetAddElements(portSet, []nftables.SetElement{ c.SetAddElements(portSet, []nftables.SetElement{
{Key: keyBytes, KeyEnd: keyEndBytes}, {Key: keyBytes, KeyEnd: keyEndBytes},
}); err != nil { })
t.Errorf("c.SetVal(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
@ -3939,9 +3895,7 @@ func TestSetSizeConcat(t *testing.T) {
Size: 200, Size: 200,
} }
if err := c.AddSet(set, nil); err != nil { c.AddSet(set, nil)
t.Errorf("c.AddSet(set) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
@ -4550,9 +4504,7 @@ func TestGetLookupExprDestSet(t *testing.T) {
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
DataType: nftables.TypeVerdict, DataType: nftables.TypeVerdict,
} }
if err := c.AddSet(set, nil); err != nil { c.AddSet(set, nil)
t.Errorf("c.AddSet(set) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -4647,9 +4599,7 @@ func TestGetRuleLookupVerdictImmediate(t *testing.T) {
Name: "test", Name: "test",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(set, nil); err != nil { c.AddSet(set, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -4776,9 +4726,8 @@ func TestDynset(t *testing.T) {
HasTimeout: true, HasTimeout: true,
Timeout: time.Duration(600 * time.Second), Timeout: time.Duration(600 * time.Second),
} }
if err := c.AddSet(set, nil); err != nil { c.AddSet(set, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -4866,9 +4815,7 @@ func TestDynsetWithOneExpression(t *testing.T) {
} }
c.AddTable(table) c.AddTable(table)
c.AddChain(chain) c.AddChain(chain)
if err := c.AddSet(set, nil); err != nil { c.AddSet(set, nil)
t.Errorf("c.AddSet(myMeter) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -4968,9 +4915,7 @@ func TestDynsetWithMultipleExpressions(t *testing.T) {
} }
c.AddTable(table) c.AddTable(table)
c.AddChain(chain) c.AddChain(chain)
if err := c.AddSet(set, nil); err != nil { c.AddSet(set, nil)
t.Errorf("c.AddSet(myMeter) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -5612,9 +5557,7 @@ func TestSet4(t *testing.T) {
setElements[i].Key = binaryutil.BigEndian.PutUint16(ports[i]) setElements[i].Key = binaryutil.BigEndian.PutUint16(ports[i])
} }
if err := c.AddSet(&set, setElements); err != nil { c.AddSet(&set, setElements)
t.Fatal(err)
}
c.AddRule(&nftables.Rule{ c.AddRule(&nftables.Rule{
Table: tbl, Table: tbl,
@ -5653,15 +5596,13 @@ func TestSetComment(t *testing.T) {
Name: "filter", Name: "filter",
}) })
if err := c.AddSet(&nftables.Set{ c.AddSet(&nftables.Set{
ID: 2, ID: 2,
Table: filter, Table: filter,
Name: "setname", Name: "setname",
KeyType: nftables.TypeIPAddr, KeyType: nftables.TypeIPAddr,
Comment: "test comment", Comment: "test comment",
}, nil); err != nil { }, nil)
t.Fatal(err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Fatal(err) t.Fatal(err)
@ -7359,9 +7300,8 @@ func TestSetElementComment(t *testing.T) {
} }
// Add the set with elements // Add the set with elements
if err := conn.AddSet(set, elements); err != nil { conn.AddSet(set, elements)
t.Fatalf("failed to add set: %v", err)
}
if err := conn.Flush(); err != nil { if err := conn.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err) t.Fatalf("failed to flush: %v", err)
} }
@ -7429,3 +7369,68 @@ func TestAutoBufferSize(t *testing.T) {
t.Fatalf("failed to flush: %v", err) t.Fatalf("failed to flush: %v", err)
} }
} }
func TestDeferredErrors(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer conn.FlushRuleset()
table := conn.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyIPv4,
})
// Anonymous sets have to be constant. Adding this set should queue an error.
set := &nftables.Set{
KeyType: nftables.TypeInetService,
Table: table,
Anonymous: true,
Constant: false,
}
conn.AddSet(set, []nftables.SetElement{
{Key: binaryutil.BigEndian.PutUint16(80)},
{Key: binaryutil.BigEndian.PutUint16(443)},
})
chain := conn.AddChain(&nftables.Chain{
Name: "test-chain",
Table: table,
})
conn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Lookup{
SourceRegister: 1,
SetID: set.ID,
},
},
})
if err := conn.Flush(); err == nil {
t.Error("expected error when adding a non-constant anonymous set, got nil")
} else {
var errno syscall.Errno
if errors.As(err, &errno) {
t.Errorf("expected error to be not syscall.Errno, got %v", errno)
}
}
table, err := conn.ListTable("test-table")
if table != nil || !errors.Is(err, syscall.ENOENT) {
t.Error("expected table to not exist")
}
}

2
obj.go
View File

@ -111,7 +111,7 @@ func (cc *Conn) AddObj(o Obj) Obj {
defer cc.mu.Unlock() defer cc.mu.Unlock()
data, err := expr.MarshalExprData(byte(o.family()), o.data()) data, err := expr.MarshalExprData(byte(o.family()), o.data())
if err != nil { if err != nil {
cc.setErr(err) cc.appendErr(err)
return nil return nil
} }

View File

@ -152,7 +152,7 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
})...) })...)
if compatPolicy, err := getCompatPolicy(r.Exprs); err != nil { if compatPolicy, err := getCompatPolicy(r.Exprs); err != nil {
cc.setErr(err) cc.appendErr(err)
} else if compatPolicy != nil { } else if compatPolicy != nil {
data = append(data, cc.marshalAttr([]netlink.Attribute{ data = append(data, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NLA_F_NESTED | unix.NFTA_RULE_COMPAT, Data: cc.marshalAttr([]netlink.Attribute{ {Type: unix.NLA_F_NESTED | unix.NFTA_RULE_COMPAT, Data: cc.marshalAttr([]netlink.Attribute{
@ -256,7 +256,7 @@ func (cc *Conn) InsertRule(r *Rule) *Rule {
// DelRule deletes the specified Rule. Either the Handle or ID of the // DelRule deletes the specified Rule. Either the Handle or ID of the
// rule must be set. // rule must be set.
func (cc *Conn) DelRule(r *Rule) error { func (cc *Conn) DelRule(r *Rule) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{ data := cc.marshalAttr([]netlink.Attribute{
@ -272,7 +272,8 @@ func (cc *Conn) DelRule(r *Rule) error {
{Type: unix.NFTA_RULE_ID, Data: binaryutil.BigEndian.PutUint32(r.ID)}, {Type: unix.NFTA_RULE_ID, Data: binaryutil.BigEndian.PutUint32(r.ID)},
})...) })...)
} else { } else {
return fmt.Errorf("rule must have a handle or ID") cc.appendErr(fmt.Errorf("rule must have a handle or ID"))
return
} }
flags := netlink.Request | netlink.Acknowledge flags := netlink.Request | netlink.Acknowledge
@ -283,8 +284,6 @@ func (cc *Conn) DelRule(r *Rule) error {
}, },
Data: append(extraHeader(uint8(r.Table.Family), 0), data...), Data: append(extraHeader(uint8(r.Table.Family), 0), data...),
}) })
return nil
} }
func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) { func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) {

47
set.go
View File

@ -372,23 +372,32 @@ func decodeElement(d []byte) ([]byte, error) {
} }
// SetAddElements applies data points to an nftables set. // SetAddElements applies data points to an nftables set.
func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { func (cc *Conn) SetAddElements(s *Set, vals []SetElement) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
if s.Anonymous { if s.Anonymous {
return errors.New("anonymous sets cannot be updated") cc.appendErr(errors.New("anonymous sets cannot be updated"))
return
}
err := cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
if err != nil {
cc.appendErr(err)
} }
return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
} }
// SetDeleteElements deletes data points from an nftables set. // SetDeleteElements deletes data points from an nftables set.
func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
if s.Anonymous { if s.Anonymous {
return errors.New("anonymous sets cannot be updated") cc.appendErr(errors.New("anonymous sets cannot be updated"))
return
}
err := cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
if err != nil {
cc.appendErr(err)
} }
return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
} }
// maxElemBatchSize is the maximum size in bytes of encoded set elements which // maxElemBatchSize is the maximum size in bytes of encoded set elements which
@ -518,7 +527,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
} }
// AddSet adds the specified Set. // AddSet adds the specified Set.
func (cc *Conn) AddSet(s *Set, vals []SetElement) error { func (cc *Conn) AddSet(s *Set, vals []SetElement) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
// Based on nft implementation & linux source. // Based on nft implementation & linux source.
@ -526,7 +535,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
// Another reference: https://git.netfilter.org/nftables/tree/src // Another reference: https://git.netfilter.org/nftables/tree/src
if s.Anonymous && !s.Constant { if s.Anonymous && !s.Constant {
return errors.New("anonymous structs must be constant") cc.appendErr(errors.New("anonymous structs must be constant"))
return
} }
if s.ID == 0 { if s.ID == 0 {
@ -591,7 +601,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))}, {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))},
}) })
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err) cc.appendErr(fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err))
return
} }
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements}) tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements})
} }
@ -604,7 +615,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(s.Size)}, {Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(s.Size)},
}) })
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal set size description: %w", err) cc.appendErr(fmt.Errorf("fail to marshal set size description: %w", err))
return
} }
descBytes = append(descBytes, descSizeBytes...) descBytes = append(descBytes, descSizeBytes...)
@ -620,21 +632,24 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)}, {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)},
}) })
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal element key size %d: %v", i, err) cc.appendErr(fmt.Errorf("fail to marshal element key size %d: %v", i, err))
return
} }
// Marshal base type size description // Marshal base type size description
descSize, err := netlink.MarshalAttributes([]netlink.Attribute{ descSize, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_DESC_SIZE, Data: valData}, {Type: unix.NFTA_SET_DESC_SIZE, Data: valData},
}) })
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal base type size description: %w", err) cc.appendErr(fmt.Errorf("fail to marshal base type size description: %w", err))
return
} }
concatDefinition = append(concatDefinition, descSize...) concatDefinition = append(concatDefinition, descSize...)
} }
// Marshal all base type descriptions into concatenation size description // Marshal all base type descriptions into concatenation size description
concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}}) concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}})
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal concat definition %v", err) cc.appendErr(fmt.Errorf("fail to marshal concat definition %v", err))
return
} }
descBytes = append(descBytes, concatBytes...) descBytes = append(descBytes, concatBytes...)
@ -675,7 +690,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_SET_ELEM_PAD | unix.NFTA_SET_ELEM_DATA, Data: []byte{}}, {Type: unix.NFTA_SET_ELEM_PAD | unix.NFTA_SET_ELEM_DATA, Data: []byte{}},
}) })
if err != nil { if err != nil {
return err cc.appendErr(err)
return
} }
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data}) tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data})
} }
@ -689,7 +705,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
}) })
// Set the values of the set if initial values were provided. // Set the values of the set if initial values were provided.
return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) err := cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
cc.appendErr(err)
} }
// DelSet deletes a specific set, along with all elements it contains. // DelSet deletes a specific set, along with all elements it contains.

View File

@ -268,9 +268,7 @@ func TestMarshalSet(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := c.AddSet(&tt.set, nil); err != nil { c.AddSet(&tt.set, nil)
t.Fatal(err)
}
connMsgSetIdx := connMsgStart + i connMsgSetIdx := connMsgStart + i
if len(c.messages) != connMsgSetIdx+1 { if len(c.messages) != connMsgSetIdx+1 {