From d0b38630ac6ca8de78d50fd5c5387523fc3cb80d Mon Sep 17 00:00:00 2001 From: nickgarlis Date: Mon, 18 Aug 2025 16:50:14 +0200 Subject: [PATCH] Defer errors until Flush to avoid incomplete batches Any create/update/delete operation that returns a validation or marshalling error can leave the message batch in an incomplete state due to short-circuiting. This can result in either: - Non-atomic transactions if Flush is called (incomplete batch) - Users being unable to clear the incomplete batch (no API exposed) This change ensures that errors are collected and deferred until Flush. Instead of returning immediately, the following methods now append errors to a slice checked at Flush: - AddSet - DelRule - SetAddElements See: https://github.com/google/nftables/issues/323 --- conn.go | 17 ++-- integration/nft_test.go | 4 +- nftables_test.go | 199 ++++++++++++++++++++-------------------- obj.go | 2 +- rule.go | 9 +- set.go | 47 +++++++--- set_test.go | 4 +- 7 files changed, 150 insertions(+), 132 deletions(-) diff --git a/conn.go b/conn.go index c6c85bb..a67edf6 100644 --- a/conn.go +++ b/conn.go @@ -42,7 +42,7 @@ type Conn struct { lasting bool // establish a lasting connection to be used across multiple netlink operations. mu sync.Mutex // protects the following state messages []netlinkMessage - err error + errs []error nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. sockOptions []SockOption lastID uint32 @@ -237,14 +237,15 @@ func (cc *Conn) Flush() error { defer func() { cc.messages = nil cc.allocatedIDs = 0 + cc.errs = nil cc.mu.Unlock() }() if len(cc.messages) == 0 { // Messages were already programmed, returning nil return nil } - if cc.err != nil { - return cc.err // serialization error + if len(cc.errs) > 0 { + return errors.Join(cc.errs...) } conn, closer, err := cc.netlinkConnUnderLock() if err != nil { @@ -363,17 +364,17 @@ func (cc *Conn) dialNetlink() (*netlink.Conn, error) { return conn, nil } -func (cc *Conn) setErr(err error) { - if cc.err != nil { +func (cc *Conn) appendErr(err error) { + if err == nil { return } - cc.err = err + cc.errs = append(cc.errs, err) } func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte { b, err := netlink.MarshalAttributes(attrs) if err != nil { - cc.setErr(err) + cc.appendErr(err) return nil } return b @@ -382,7 +383,7 @@ func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte { func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte { b, err := expr.Marshal(fam, e) if err != nil { - cc.setErr(err) + cc.appendErr(err) return nil } return b diff --git a/integration/nft_test.go b/integration/nft_test.go index 14c7d43..a9b0a84 100644 --- a/integration/nft_test.go +++ b/integration/nft_test.go @@ -111,9 +111,7 @@ func TestNFTables(t *testing.T) { }) } - if err := c.AddSet(devicesSet, elements); err != nil { - t.Errorf("failed to add Set %s : %v", devicesSet.Name, err) - } + c.AddSet(devicesSet, elements) flowtable := &nftables.Flowtable{ Table: table, diff --git a/nftables_test.go b/nftables_test.go index fe12566..e0cb329 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -290,9 +290,7 @@ func TestRuleHandle(t *testing.T) { { Name: "delete-rule", Func: func() { - if err := c.DelRule(rule2); err != nil { - t.Errorf("DelRule failed: %v", err) - } + c.DelRule(rule2) }, Expect: []string{"1", "3"}, }, @@ -3335,12 +3333,10 @@ func TestCreateUseAnonymousSet(t *testing.T) { KeyType: nftables.TypeInetService, } - if err := c.AddSet(set, []nftables.SetElement{ + c.AddSet(set, []nftables.SetElement{ {Key: binaryutil.BigEndian.PutUint16(69)}, {Key: binaryutil.BigEndian.PutUint16(1163)}, - }); err != nil { - t.Errorf("c.AddSet() failed: %v", err) - } + }) c.AddRule(&nftables.Rule{ Table: filter, @@ -3414,9 +3410,7 @@ func TestCappedErrMsgOnSets(t *testing.T) { Name: "if_set", KeyType: nftables.TypeIFName, } - if err := c.AddSet(ifSet, nil); err != nil { - t.Errorf("c.AddSet(ifSet) failed: %v", err) - } + c.AddSet(ifSet, nil) if err := c.Flush(); err != nil { t.Errorf("failed adding set ifSet: %v", err) } @@ -3437,9 +3431,7 @@ func TestCappedErrMsgOnSets(t *testing.T) { elements := []nftables.SetElement{ {Key: []byte("012345678912345\x00")}, } - if err := c.SetAddElements(ifSet, elements); err != nil { - t.Errorf("adding SetElements(ifSet) failed: %v", err) - } + c.SetAddElements(ifSet, elements) if err := c.Flush(); err != nil { t.Errorf("failed adding set elements ifSet: %v", err) } @@ -3477,24 +3469,16 @@ func TestCreateUseNamedSet(t *testing.T) { Name: "test", KeyType: nftables.TypeInetService, } - if err := c.AddSet(portSet, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } - if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil { - t.Errorf("c.SetVal(portSet) failed: %v", err) - } + c.AddSet(portSet, nil) + c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}) ipSet := &nftables.Set{ Table: filter, Name: "IPs_4_dayz", KeyType: nftables.TypeIPAddr, } - if err := c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}); err != nil { - t.Errorf("c.AddSet(ipSet) failed: %v", err) - } - if err := c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}); err != nil { - t.Errorf("c.SetVal(ipSet) failed: %v", err) - } + c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}) + c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -3535,12 +3519,8 @@ func TestCreateAutoMergeSet(t *testing.T) { Interval: true, AutoMerge: true, } - if err := c.AddSet(portSet, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } - if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil { - t.Errorf("c.SetVal(portSet) failed: %v", err) - } + c.AddSet(portSet, nil) + c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}) ipSet := &nftables.Set{ Table: filter, @@ -3549,12 +3529,8 @@ func TestCreateAutoMergeSet(t *testing.T) { Interval: true, AutoMerge: true, } - if err := c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}); err != nil { - t.Errorf("c.AddSet(ipSet) failed: %v", err) - } - if err := c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}); err != nil { - t.Errorf("c.SetVal(ipSet) failed: %v", err) - } + c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}) + c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -3592,15 +3568,11 @@ func TestIP6SetAddElements(t *testing.T) { Name: "ports", KeyType: nftables.TypeInetService, } - if err := c.AddSet(portSet, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } - if err := c.SetAddElements(portSet, []nftables.SetElement{ + c.AddSet(portSet, nil) + c.SetAddElements(portSet, []nftables.SetElement{ {Key: binaryutil.BigEndian.PutUint16(22)}, {Key: binaryutil.BigEndian.PutUint16(80)}, - }); err != nil { - t.Errorf("c.SetVal(portSet) failed: %v", err) - } + }) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) @@ -3648,9 +3620,7 @@ func TestSetElementBatching(t *testing.T) { elements[i].Key = binaryutil.BigEndian.PutUint16(uint16(i)) elements[i].Comment = "0123456789" } - if err := c.AddSet(portSet, elements); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(portSet, elements) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -3694,12 +3664,8 @@ func TestCreateUseCounterSet(t *testing.T) { KeyType: nftables.TypeInetService, Counter: true, } - if err := c.AddSet(portSet, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } - if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil { - t.Errorf("c.SetVal(portSet) failed: %v", err) - } + c.AddSet(portSet, nil) + c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) @@ -3736,9 +3702,7 @@ func TestCreateDeleteNamedSet(t *testing.T) { Name: "test", KeyType: nftables.TypeInetService, } - if err := c.AddSet(portSet, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(portSet, nil) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -3777,9 +3741,7 @@ func TestDeleteElementNamedSet(t *testing.T) { Name: "test", KeyType: nftables.TypeInetService, } - if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -3824,9 +3786,7 @@ func TestFlushNamedSet(t *testing.T) { Name: "test", KeyType: nftables.TypeInetService, } - if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -3866,20 +3826,16 @@ func TestSetElementsInterval(t *testing.T) { Interval: true, Concatenation: true, } - if err := c.AddSet(portSet, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(portSet, nil) // { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } keyBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 195, 80, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203} // { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } keyEndBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 234, 96, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203} // elements = { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000-60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } - if err := c.SetAddElements(portSet, []nftables.SetElement{ + c.SetAddElements(portSet, []nftables.SetElement{ {Key: keyBytes, KeyEnd: keyEndBytes}, - }); err != nil { - t.Errorf("c.SetVal(portSet) failed: %v", err) - } + }) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) @@ -3939,9 +3895,7 @@ func TestSetSizeConcat(t *testing.T) { Size: 200, } - if err := c.AddSet(set, nil); err != nil { - t.Errorf("c.AddSet(set) failed: %v", err) - } + c.AddSet(set, nil) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) @@ -4550,9 +4504,7 @@ func TestGetLookupExprDestSet(t *testing.T) { KeyType: nftables.TypeInetService, DataType: nftables.TypeVerdict, } - if err := c.AddSet(set, nil); err != nil { - t.Errorf("c.AddSet(set) failed: %v", err) - } + c.AddSet(set, nil) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -4647,9 +4599,7 @@ func TestGetRuleLookupVerdictImmediate(t *testing.T) { Name: "test", KeyType: nftables.TypeInetService, } - if err := c.AddSet(set, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(set, nil) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -4776,9 +4726,8 @@ func TestDynset(t *testing.T) { HasTimeout: true, Timeout: time.Duration(600 * time.Second), } - if err := c.AddSet(set, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(set, nil) + if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -4866,9 +4815,7 @@ func TestDynsetWithOneExpression(t *testing.T) { } c.AddTable(table) c.AddChain(chain) - if err := c.AddSet(set, nil); err != nil { - t.Errorf("c.AddSet(myMeter) failed: %v", err) - } + c.AddSet(set, nil) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -4968,9 +4915,7 @@ func TestDynsetWithMultipleExpressions(t *testing.T) { } c.AddTable(table) c.AddChain(chain) - if err := c.AddSet(set, nil); err != nil { - t.Errorf("c.AddSet(myMeter) failed: %v", err) - } + c.AddSet(set, nil) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -5612,9 +5557,7 @@ func TestSet4(t *testing.T) { setElements[i].Key = binaryutil.BigEndian.PutUint16(ports[i]) } - if err := c.AddSet(&set, setElements); err != nil { - t.Fatal(err) - } + c.AddSet(&set, setElements) c.AddRule(&nftables.Rule{ Table: tbl, @@ -5653,15 +5596,13 @@ func TestSetComment(t *testing.T) { Name: "filter", }) - if err := c.AddSet(&nftables.Set{ + c.AddSet(&nftables.Set{ ID: 2, Table: filter, Name: "setname", KeyType: nftables.TypeIPAddr, Comment: "test comment", - }, nil); err != nil { - t.Fatal(err) - } + }, nil) if err := c.Flush(); err != nil { t.Fatal(err) @@ -7359,9 +7300,8 @@ func TestSetElementComment(t *testing.T) { } // Add the set with elements - if err := conn.AddSet(set, elements); err != nil { - t.Fatalf("failed to add set: %v", err) - } + conn.AddSet(set, elements) + if err := conn.Flush(); err != nil { t.Fatalf("failed to flush: %v", err) } @@ -7429,3 +7369,68 @@ func TestAutoBufferSize(t *testing.T) { t.Fatalf("failed to flush: %v", err) } } + +func TestDeferredErrors(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + defer conn.FlushRuleset() + + table := conn.AddTable(&nftables.Table{ + Name: "test-table", + Family: nftables.TableFamilyIPv4, + }) + + // Anonymous sets have to be constant. Adding this set should queue an error. + set := &nftables.Set{ + KeyType: nftables.TypeInetService, + Table: table, + Anonymous: true, + Constant: false, + } + conn.AddSet(set, []nftables.SetElement{ + {Key: binaryutil.BigEndian.PutUint16(80)}, + {Key: binaryutil.BigEndian.PutUint16(443)}, + }) + + chain := conn.AddChain(&nftables.Chain{ + Name: "test-chain", + Table: table, + }) + + conn.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Lookup{ + SourceRegister: 1, + SetID: set.ID, + }, + }, + }) + + if err := conn.Flush(); err == nil { + t.Error("expected error when adding a non-constant anonymous set, got nil") + } else { + var errno syscall.Errno + if errors.As(err, &errno) { + t.Errorf("expected error to be not syscall.Errno, got %v", errno) + } + } + + table, err := conn.ListTable("test-table") + if table != nil || !errors.Is(err, syscall.ENOENT) { + t.Error("expected table to not exist") + } +} diff --git a/obj.go b/obj.go index 65c4402..0e2bbc9 100644 --- a/obj.go +++ b/obj.go @@ -111,7 +111,7 @@ func (cc *Conn) AddObj(o Obj) Obj { defer cc.mu.Unlock() data, err := expr.MarshalExprData(byte(o.family()), o.data()) if err != nil { - cc.setErr(err) + cc.appendErr(err) return nil } diff --git a/rule.go b/rule.go index 10958c6..00181d6 100644 --- a/rule.go +++ b/rule.go @@ -152,7 +152,7 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { })...) if compatPolicy, err := getCompatPolicy(r.Exprs); err != nil { - cc.setErr(err) + cc.appendErr(err) } else if compatPolicy != nil { data = append(data, cc.marshalAttr([]netlink.Attribute{ {Type: unix.NLA_F_NESTED | unix.NFTA_RULE_COMPAT, Data: cc.marshalAttr([]netlink.Attribute{ @@ -256,7 +256,7 @@ func (cc *Conn) InsertRule(r *Rule) *Rule { // DelRule deletes the specified Rule. Either the Handle or ID of the // rule must be set. -func (cc *Conn) DelRule(r *Rule) error { +func (cc *Conn) DelRule(r *Rule) { cc.mu.Lock() defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ @@ -272,7 +272,8 @@ func (cc *Conn) DelRule(r *Rule) error { {Type: unix.NFTA_RULE_ID, Data: binaryutil.BigEndian.PutUint32(r.ID)}, })...) } else { - return fmt.Errorf("rule must have a handle or ID") + cc.appendErr(fmt.Errorf("rule must have a handle or ID")) + return } flags := netlink.Request | netlink.Acknowledge @@ -283,8 +284,6 @@ func (cc *Conn) DelRule(r *Rule) error { }, Data: append(extraHeader(uint8(r.Table.Family), 0), data...), }) - - return nil } func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) { diff --git a/set.go b/set.go index e320117..5e37f9a 100644 --- a/set.go +++ b/set.go @@ -372,23 +372,32 @@ func decodeElement(d []byte) ([]byte, error) { } // SetAddElements applies data points to an nftables set. -func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { +func (cc *Conn) SetAddElements(s *Set, vals []SetElement) { cc.mu.Lock() defer cc.mu.Unlock() + if s.Anonymous { - return errors.New("anonymous sets cannot be updated") + cc.appendErr(errors.New("anonymous sets cannot be updated")) + return + } + err := cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) + if err != nil { + cc.appendErr(err) } - return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) } // SetDeleteElements deletes data points from an nftables set. -func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { +func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) { cc.mu.Lock() defer cc.mu.Unlock() if s.Anonymous { - return errors.New("anonymous sets cannot be updated") + cc.appendErr(errors.New("anonymous sets cannot be updated")) + return + } + err := cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM) + if err != nil { + cc.appendErr(err) } - return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM) } // maxElemBatchSize is the maximum size in bytes of encoded set elements which @@ -518,7 +527,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error } // AddSet adds the specified Set. -func (cc *Conn) AddSet(s *Set, vals []SetElement) error { +func (cc *Conn) AddSet(s *Set, vals []SetElement) { cc.mu.Lock() defer cc.mu.Unlock() // Based on nft implementation & linux source. @@ -526,7 +535,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { // Another reference: https://git.netfilter.org/nftables/tree/src if s.Anonymous && !s.Constant { - return errors.New("anonymous structs must be constant") + cc.appendErr(errors.New("anonymous structs must be constant")) + return } if s.ID == 0 { @@ -591,7 +601,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))}, }) if err != nil { - return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err) + cc.appendErr(fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err)) + return } tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements}) } @@ -604,7 +615,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { {Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(s.Size)}, }) if err != nil { - return fmt.Errorf("fail to marshal set size description: %w", err) + cc.appendErr(fmt.Errorf("fail to marshal set size description: %w", err)) + return } descBytes = append(descBytes, descSizeBytes...) @@ -620,21 +632,24 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)}, }) if err != nil { - return fmt.Errorf("fail to marshal element key size %d: %v", i, err) + cc.appendErr(fmt.Errorf("fail to marshal element key size %d: %v", i, err)) + return } // Marshal base type size description descSize, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_SET_DESC_SIZE, Data: valData}, }) if err != nil { - return fmt.Errorf("fail to marshal base type size description: %w", err) + cc.appendErr(fmt.Errorf("fail to marshal base type size description: %w", err)) + return } concatDefinition = append(concatDefinition, descSize...) } // Marshal all base type descriptions into concatenation size description concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}}) if err != nil { - return fmt.Errorf("fail to marshal concat definition %v", err) + cc.appendErr(fmt.Errorf("fail to marshal concat definition %v", err)) + return } descBytes = append(descBytes, concatBytes...) @@ -675,7 +690,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { {Type: unix.NFTA_SET_ELEM_PAD | unix.NFTA_SET_ELEM_DATA, Data: []byte{}}, }) if err != nil { - return err + cc.appendErr(err) + return } tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data}) } @@ -689,7 +705,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { }) // Set the values of the set if initial values were provided. - return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) + err := cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) + cc.appendErr(err) } // DelSet deletes a specific set, along with all elements it contains. diff --git a/set_test.go b/set_test.go index df40e57..090fa12 100644 --- a/set_test.go +++ b/set_test.go @@ -268,9 +268,7 @@ func TestMarshalSet(t *testing.T) { for i, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := c.AddSet(&tt.set, nil); err != nil { - t.Fatal(err) - } + c.AddSet(&tt.set, nil) connMsgSetIdx := connMsgStart + i if len(c.messages) != connMsgSetIdx+1 {