Defer errors until Flush to avoid incomplete batches

Any create/update/delete operation that returns a validation or
marshalling error can leave the message batch in an incomplete state
due to short-circuiting. This can result in either:

  - Non-atomic transactions if Flush is called (incomplete batch)
  - Users being unable to clear the incomplete batch (no API exposed)

This change ensures that errors are collected and deferred until Flush.
Instead of returning immediately, the following methods now append
errors to a slice checked at Flush:

  - AddSet
  - DelRule
  - SetAddElements

See: https://github.com/google/nftables/issues/323
This commit is contained in:
nickgarlis 2025-08-18 16:50:14 +02:00
parent 508bb1ffd4
commit d0b38630ac
7 changed files with 150 additions and 132 deletions

17
conn.go
View File

@ -42,7 +42,7 @@ type Conn struct {
lasting bool // establish a lasting connection to be used across multiple netlink operations. lasting bool // establish a lasting connection to be used across multiple netlink operations.
mu sync.Mutex // protects the following state mu sync.Mutex // protects the following state
messages []netlinkMessage messages []netlinkMessage
err error errs []error
nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol.
sockOptions []SockOption sockOptions []SockOption
lastID uint32 lastID uint32
@ -237,14 +237,15 @@ func (cc *Conn) Flush() error {
defer func() { defer func() {
cc.messages = nil cc.messages = nil
cc.allocatedIDs = 0 cc.allocatedIDs = 0
cc.errs = nil
cc.mu.Unlock() cc.mu.Unlock()
}() }()
if len(cc.messages) == 0 { if len(cc.messages) == 0 {
// Messages were already programmed, returning nil // Messages were already programmed, returning nil
return nil return nil
} }
if cc.err != nil { if len(cc.errs) > 0 {
return cc.err // serialization error return errors.Join(cc.errs...)
} }
conn, closer, err := cc.netlinkConnUnderLock() conn, closer, err := cc.netlinkConnUnderLock()
if err != nil { if err != nil {
@ -363,17 +364,17 @@ func (cc *Conn) dialNetlink() (*netlink.Conn, error) {
return conn, nil return conn, nil
} }
func (cc *Conn) setErr(err error) { func (cc *Conn) appendErr(err error) {
if cc.err != nil { if err == nil {
return return
} }
cc.err = err cc.errs = append(cc.errs, err)
} }
func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte { func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
b, err := netlink.MarshalAttributes(attrs) b, err := netlink.MarshalAttributes(attrs)
if err != nil { if err != nil {
cc.setErr(err) cc.appendErr(err)
return nil return nil
} }
return b return b
@ -382,7 +383,7 @@ func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte { func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
b, err := expr.Marshal(fam, e) b, err := expr.Marshal(fam, e)
if err != nil { if err != nil {
cc.setErr(err) cc.appendErr(err)
return nil return nil
} }
return b return b

View File

@ -111,9 +111,7 @@ func TestNFTables(t *testing.T) {
}) })
} }
if err := c.AddSet(devicesSet, elements); err != nil { c.AddSet(devicesSet, elements)
t.Errorf("failed to add Set %s : %v", devicesSet.Name, err)
}
flowtable := &nftables.Flowtable{ flowtable := &nftables.Flowtable{
Table: table, Table: table,

View File

@ -290,9 +290,7 @@ func TestRuleHandle(t *testing.T) {
{ {
Name: "delete-rule", Name: "delete-rule",
Func: func() { Func: func() {
if err := c.DelRule(rule2); err != nil { c.DelRule(rule2)
t.Errorf("DelRule failed: %v", err)
}
}, },
Expect: []string{"1", "3"}, Expect: []string{"1", "3"},
}, },
@ -3335,12 +3333,10 @@ func TestCreateUseAnonymousSet(t *testing.T) {
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(set, []nftables.SetElement{ c.AddSet(set, []nftables.SetElement{
{Key: binaryutil.BigEndian.PutUint16(69)}, {Key: binaryutil.BigEndian.PutUint16(69)},
{Key: binaryutil.BigEndian.PutUint16(1163)}, {Key: binaryutil.BigEndian.PutUint16(1163)},
}); err != nil { })
t.Errorf("c.AddSet() failed: %v", err)
}
c.AddRule(&nftables.Rule{ c.AddRule(&nftables.Rule{
Table: filter, Table: filter,
@ -3414,9 +3410,7 @@ func TestCappedErrMsgOnSets(t *testing.T) {
Name: "if_set", Name: "if_set",
KeyType: nftables.TypeIFName, KeyType: nftables.TypeIFName,
} }
if err := c.AddSet(ifSet, nil); err != nil { c.AddSet(ifSet, nil)
t.Errorf("c.AddSet(ifSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("failed adding set ifSet: %v", err) t.Errorf("failed adding set ifSet: %v", err)
} }
@ -3437,9 +3431,7 @@ func TestCappedErrMsgOnSets(t *testing.T) {
elements := []nftables.SetElement{ elements := []nftables.SetElement{
{Key: []byte("012345678912345\x00")}, {Key: []byte("012345678912345\x00")},
} }
if err := c.SetAddElements(ifSet, elements); err != nil { c.SetAddElements(ifSet, elements)
t.Errorf("adding SetElements(ifSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("failed adding set elements ifSet: %v", err) t.Errorf("failed adding set elements ifSet: %v", err)
} }
@ -3477,24 +3469,16 @@ func TestCreateUseNamedSet(t *testing.T) {
Name: "test", Name: "test",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(portSet, nil); err != nil { c.AddSet(portSet, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err) c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}})
}
if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil {
t.Errorf("c.SetVal(portSet) failed: %v", err)
}
ipSet := &nftables.Set{ ipSet := &nftables.Set{
Table: filter, Table: filter,
Name: "IPs_4_dayz", Name: "IPs_4_dayz",
KeyType: nftables.TypeIPAddr, KeyType: nftables.TypeIPAddr,
} }
if err := c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}); err != nil { c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}})
t.Errorf("c.AddSet(ipSet) failed: %v", err) c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}})
}
if err := c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}); err != nil {
t.Errorf("c.SetVal(ipSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -3535,12 +3519,8 @@ func TestCreateAutoMergeSet(t *testing.T) {
Interval: true, Interval: true,
AutoMerge: true, AutoMerge: true,
} }
if err := c.AddSet(portSet, nil); err != nil { c.AddSet(portSet, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err) c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}})
}
if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil {
t.Errorf("c.SetVal(portSet) failed: %v", err)
}
ipSet := &nftables.Set{ ipSet := &nftables.Set{
Table: filter, Table: filter,
@ -3549,12 +3529,8 @@ func TestCreateAutoMergeSet(t *testing.T) {
Interval: true, Interval: true,
AutoMerge: true, AutoMerge: true,
} }
if err := c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}); err != nil { c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}})
t.Errorf("c.AddSet(ipSet) failed: %v", err) c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}})
}
if err := c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}); err != nil {
t.Errorf("c.SetVal(ipSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -3592,15 +3568,11 @@ func TestIP6SetAddElements(t *testing.T) {
Name: "ports", Name: "ports",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(portSet, nil); err != nil { c.AddSet(portSet, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err) c.SetAddElements(portSet, []nftables.SetElement{
}
if err := c.SetAddElements(portSet, []nftables.SetElement{
{Key: binaryutil.BigEndian.PutUint16(22)}, {Key: binaryutil.BigEndian.PutUint16(22)},
{Key: binaryutil.BigEndian.PutUint16(80)}, {Key: binaryutil.BigEndian.PutUint16(80)},
}); err != nil { })
t.Errorf("c.SetVal(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
@ -3648,9 +3620,7 @@ func TestSetElementBatching(t *testing.T) {
elements[i].Key = binaryutil.BigEndian.PutUint16(uint16(i)) elements[i].Key = binaryutil.BigEndian.PutUint16(uint16(i))
elements[i].Comment = "0123456789" elements[i].Comment = "0123456789"
} }
if err := c.AddSet(portSet, elements); err != nil { c.AddSet(portSet, elements)
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -3694,12 +3664,8 @@ func TestCreateUseCounterSet(t *testing.T) {
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
Counter: true, Counter: true,
} }
if err := c.AddSet(portSet, nil); err != nil { c.AddSet(portSet, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err) c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}})
}
if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil {
t.Errorf("c.SetVal(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
@ -3736,9 +3702,7 @@ func TestCreateDeleteNamedSet(t *testing.T) {
Name: "test", Name: "test",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(portSet, nil); err != nil { c.AddSet(portSet, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -3777,9 +3741,7 @@ func TestDeleteElementNamedSet(t *testing.T) {
Name: "test", Name: "test",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil { c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}})
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -3824,9 +3786,7 @@ func TestFlushNamedSet(t *testing.T) {
Name: "test", Name: "test",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil { c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}})
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -3866,20 +3826,16 @@ func TestSetElementsInterval(t *testing.T) {
Interval: true, Interval: true,
Concatenation: true, Concatenation: true,
} }
if err := c.AddSet(portSet, nil); err != nil { c.AddSet(portSet, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
// { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } // { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb }
keyBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 195, 80, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203} keyBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 195, 80, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203}
// { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } // { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb }
keyEndBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 234, 96, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203} keyEndBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 234, 96, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203}
// elements = { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000-60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } // elements = { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000-60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb }
if err := c.SetAddElements(portSet, []nftables.SetElement{ c.SetAddElements(portSet, []nftables.SetElement{
{Key: keyBytes, KeyEnd: keyEndBytes}, {Key: keyBytes, KeyEnd: keyEndBytes},
}); err != nil { })
t.Errorf("c.SetVal(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
@ -3939,9 +3895,7 @@ func TestSetSizeConcat(t *testing.T) {
Size: 200, Size: 200,
} }
if err := c.AddSet(set, nil); err != nil { c.AddSet(set, nil)
t.Errorf("c.AddSet(set) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
@ -4550,9 +4504,7 @@ func TestGetLookupExprDestSet(t *testing.T) {
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
DataType: nftables.TypeVerdict, DataType: nftables.TypeVerdict,
} }
if err := c.AddSet(set, nil); err != nil { c.AddSet(set, nil)
t.Errorf("c.AddSet(set) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -4647,9 +4599,7 @@ func TestGetRuleLookupVerdictImmediate(t *testing.T) {
Name: "test", Name: "test",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(set, nil); err != nil { c.AddSet(set, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -4776,9 +4726,8 @@ func TestDynset(t *testing.T) {
HasTimeout: true, HasTimeout: true,
Timeout: time.Duration(600 * time.Second), Timeout: time.Duration(600 * time.Second),
} }
if err := c.AddSet(set, nil); err != nil { c.AddSet(set, nil)
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -4866,9 +4815,7 @@ func TestDynsetWithOneExpression(t *testing.T) {
} }
c.AddTable(table) c.AddTable(table)
c.AddChain(chain) c.AddChain(chain)
if err := c.AddSet(set, nil); err != nil { c.AddSet(set, nil)
t.Errorf("c.AddSet(myMeter) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -4968,9 +4915,7 @@ func TestDynsetWithMultipleExpressions(t *testing.T) {
} }
c.AddTable(table) c.AddTable(table)
c.AddChain(chain) c.AddChain(chain)
if err := c.AddSet(set, nil); err != nil { c.AddSet(set, nil)
t.Errorf("c.AddSet(myMeter) failed: %v", err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err) t.Errorf("c.Flush() failed: %v", err)
} }
@ -5612,9 +5557,7 @@ func TestSet4(t *testing.T) {
setElements[i].Key = binaryutil.BigEndian.PutUint16(ports[i]) setElements[i].Key = binaryutil.BigEndian.PutUint16(ports[i])
} }
if err := c.AddSet(&set, setElements); err != nil { c.AddSet(&set, setElements)
t.Fatal(err)
}
c.AddRule(&nftables.Rule{ c.AddRule(&nftables.Rule{
Table: tbl, Table: tbl,
@ -5653,15 +5596,13 @@ func TestSetComment(t *testing.T) {
Name: "filter", Name: "filter",
}) })
if err := c.AddSet(&nftables.Set{ c.AddSet(&nftables.Set{
ID: 2, ID: 2,
Table: filter, Table: filter,
Name: "setname", Name: "setname",
KeyType: nftables.TypeIPAddr, KeyType: nftables.TypeIPAddr,
Comment: "test comment", Comment: "test comment",
}, nil); err != nil { }, nil)
t.Fatal(err)
}
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
t.Fatal(err) t.Fatal(err)
@ -7359,9 +7300,8 @@ func TestSetElementComment(t *testing.T) {
} }
// Add the set with elements // Add the set with elements
if err := conn.AddSet(set, elements); err != nil { conn.AddSet(set, elements)
t.Fatalf("failed to add set: %v", err)
}
if err := conn.Flush(); err != nil { if err := conn.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err) t.Fatalf("failed to flush: %v", err)
} }
@ -7429,3 +7369,68 @@ func TestAutoBufferSize(t *testing.T) {
t.Fatalf("failed to flush: %v", err) t.Fatalf("failed to flush: %v", err)
} }
} }
func TestDeferredErrors(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer conn.FlushRuleset()
table := conn.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyIPv4,
})
// Anonymous sets have to be constant. Adding this set should queue an error.
set := &nftables.Set{
KeyType: nftables.TypeInetService,
Table: table,
Anonymous: true,
Constant: false,
}
conn.AddSet(set, []nftables.SetElement{
{Key: binaryutil.BigEndian.PutUint16(80)},
{Key: binaryutil.BigEndian.PutUint16(443)},
})
chain := conn.AddChain(&nftables.Chain{
Name: "test-chain",
Table: table,
})
conn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Lookup{
SourceRegister: 1,
SetID: set.ID,
},
},
})
if err := conn.Flush(); err == nil {
t.Error("expected error when adding a non-constant anonymous set, got nil")
} else {
var errno syscall.Errno
if errors.As(err, &errno) {
t.Errorf("expected error to be not syscall.Errno, got %v", errno)
}
}
table, err := conn.ListTable("test-table")
if table != nil || !errors.Is(err, syscall.ENOENT) {
t.Error("expected table to not exist")
}
}

2
obj.go
View File

@ -111,7 +111,7 @@ func (cc *Conn) AddObj(o Obj) Obj {
defer cc.mu.Unlock() defer cc.mu.Unlock()
data, err := expr.MarshalExprData(byte(o.family()), o.data()) data, err := expr.MarshalExprData(byte(o.family()), o.data())
if err != nil { if err != nil {
cc.setErr(err) cc.appendErr(err)
return nil return nil
} }

View File

@ -152,7 +152,7 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
})...) })...)
if compatPolicy, err := getCompatPolicy(r.Exprs); err != nil { if compatPolicy, err := getCompatPolicy(r.Exprs); err != nil {
cc.setErr(err) cc.appendErr(err)
} else if compatPolicy != nil { } else if compatPolicy != nil {
data = append(data, cc.marshalAttr([]netlink.Attribute{ data = append(data, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NLA_F_NESTED | unix.NFTA_RULE_COMPAT, Data: cc.marshalAttr([]netlink.Attribute{ {Type: unix.NLA_F_NESTED | unix.NFTA_RULE_COMPAT, Data: cc.marshalAttr([]netlink.Attribute{
@ -256,7 +256,7 @@ func (cc *Conn) InsertRule(r *Rule) *Rule {
// DelRule deletes the specified Rule. Either the Handle or ID of the // DelRule deletes the specified Rule. Either the Handle or ID of the
// rule must be set. // rule must be set.
func (cc *Conn) DelRule(r *Rule) error { func (cc *Conn) DelRule(r *Rule) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{ data := cc.marshalAttr([]netlink.Attribute{
@ -272,7 +272,8 @@ func (cc *Conn) DelRule(r *Rule) error {
{Type: unix.NFTA_RULE_ID, Data: binaryutil.BigEndian.PutUint32(r.ID)}, {Type: unix.NFTA_RULE_ID, Data: binaryutil.BigEndian.PutUint32(r.ID)},
})...) })...)
} else { } else {
return fmt.Errorf("rule must have a handle or ID") cc.appendErr(fmt.Errorf("rule must have a handle or ID"))
return
} }
flags := netlink.Request | netlink.Acknowledge flags := netlink.Request | netlink.Acknowledge
@ -283,8 +284,6 @@ func (cc *Conn) DelRule(r *Rule) error {
}, },
Data: append(extraHeader(uint8(r.Table.Family), 0), data...), Data: append(extraHeader(uint8(r.Table.Family), 0), data...),
}) })
return nil
} }
func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) { func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) {

47
set.go
View File

@ -372,23 +372,32 @@ func decodeElement(d []byte) ([]byte, error) {
} }
// SetAddElements applies data points to an nftables set. // SetAddElements applies data points to an nftables set.
func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { func (cc *Conn) SetAddElements(s *Set, vals []SetElement) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
if s.Anonymous { if s.Anonymous {
return errors.New("anonymous sets cannot be updated") cc.appendErr(errors.New("anonymous sets cannot be updated"))
return
}
err := cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
if err != nil {
cc.appendErr(err)
} }
return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
} }
// SetDeleteElements deletes data points from an nftables set. // SetDeleteElements deletes data points from an nftables set.
func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
if s.Anonymous { if s.Anonymous {
return errors.New("anonymous sets cannot be updated") cc.appendErr(errors.New("anonymous sets cannot be updated"))
return
}
err := cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
if err != nil {
cc.appendErr(err)
} }
return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
} }
// maxElemBatchSize is the maximum size in bytes of encoded set elements which // maxElemBatchSize is the maximum size in bytes of encoded set elements which
@ -518,7 +527,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
} }
// AddSet adds the specified Set. // AddSet adds the specified Set.
func (cc *Conn) AddSet(s *Set, vals []SetElement) error { func (cc *Conn) AddSet(s *Set, vals []SetElement) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
// Based on nft implementation & linux source. // Based on nft implementation & linux source.
@ -526,7 +535,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
// Another reference: https://git.netfilter.org/nftables/tree/src // Another reference: https://git.netfilter.org/nftables/tree/src
if s.Anonymous && !s.Constant { if s.Anonymous && !s.Constant {
return errors.New("anonymous structs must be constant") cc.appendErr(errors.New("anonymous structs must be constant"))
return
} }
if s.ID == 0 { if s.ID == 0 {
@ -591,7 +601,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))}, {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))},
}) })
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err) cc.appendErr(fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err))
return
} }
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements}) tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements})
} }
@ -604,7 +615,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(s.Size)}, {Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(s.Size)},
}) })
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal set size description: %w", err) cc.appendErr(fmt.Errorf("fail to marshal set size description: %w", err))
return
} }
descBytes = append(descBytes, descSizeBytes...) descBytes = append(descBytes, descSizeBytes...)
@ -620,21 +632,24 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)}, {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)},
}) })
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal element key size %d: %v", i, err) cc.appendErr(fmt.Errorf("fail to marshal element key size %d: %v", i, err))
return
} }
// Marshal base type size description // Marshal base type size description
descSize, err := netlink.MarshalAttributes([]netlink.Attribute{ descSize, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_DESC_SIZE, Data: valData}, {Type: unix.NFTA_SET_DESC_SIZE, Data: valData},
}) })
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal base type size description: %w", err) cc.appendErr(fmt.Errorf("fail to marshal base type size description: %w", err))
return
} }
concatDefinition = append(concatDefinition, descSize...) concatDefinition = append(concatDefinition, descSize...)
} }
// Marshal all base type descriptions into concatenation size description // Marshal all base type descriptions into concatenation size description
concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}}) concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}})
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal concat definition %v", err) cc.appendErr(fmt.Errorf("fail to marshal concat definition %v", err))
return
} }
descBytes = append(descBytes, concatBytes...) descBytes = append(descBytes, concatBytes...)
@ -675,7 +690,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_SET_ELEM_PAD | unix.NFTA_SET_ELEM_DATA, Data: []byte{}}, {Type: unix.NFTA_SET_ELEM_PAD | unix.NFTA_SET_ELEM_DATA, Data: []byte{}},
}) })
if err != nil { if err != nil {
return err cc.appendErr(err)
return
} }
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data}) tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data})
} }
@ -689,7 +705,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
}) })
// Set the values of the set if initial values were provided. // Set the values of the set if initial values were provided.
return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) err := cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
cc.appendErr(err)
} }
// DelSet deletes a specific set, along with all elements it contains. // DelSet deletes a specific set, along with all elements it contains.

View File

@ -268,9 +268,7 @@ func TestMarshalSet(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := c.AddSet(&tt.set, nil); err != nil { c.AddSet(&tt.set, nil)
t.Fatal(err)
}
connMsgSetIdx := connMsgStart + i connMsgSetIdx := connMsgStart + i
if len(c.messages) != connMsgSetIdx+1 { if len(c.messages) != connMsgSetIdx+1 {