diff --git a/conn.go b/conn.go index c6c85bb..a67edf6 100644 --- a/conn.go +++ b/conn.go @@ -42,7 +42,7 @@ type Conn struct { lasting bool // establish a lasting connection to be used across multiple netlink operations. mu sync.Mutex // protects the following state messages []netlinkMessage - err error + errs []error nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. sockOptions []SockOption lastID uint32 @@ -237,14 +237,15 @@ func (cc *Conn) Flush() error { defer func() { cc.messages = nil cc.allocatedIDs = 0 + cc.errs = nil cc.mu.Unlock() }() if len(cc.messages) == 0 { // Messages were already programmed, returning nil return nil } - if cc.err != nil { - return cc.err // serialization error + if len(cc.errs) > 0 { + return errors.Join(cc.errs...) } conn, closer, err := cc.netlinkConnUnderLock() if err != nil { @@ -363,17 +364,17 @@ func (cc *Conn) dialNetlink() (*netlink.Conn, error) { return conn, nil } -func (cc *Conn) setErr(err error) { - if cc.err != nil { +func (cc *Conn) appendErr(err error) { + if err == nil { return } - cc.err = err + cc.errs = append(cc.errs, err) } func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte { b, err := netlink.MarshalAttributes(attrs) if err != nil { - cc.setErr(err) + cc.appendErr(err) return nil } return b @@ -382,7 +383,7 @@ func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte { func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte { b, err := expr.Marshal(fam, e) if err != nil { - cc.setErr(err) + cc.appendErr(err) return nil } return b diff --git a/integration/nft_test.go b/integration/nft_test.go index 14c7d43..a9b0a84 100644 --- a/integration/nft_test.go +++ b/integration/nft_test.go @@ -111,9 +111,7 @@ func TestNFTables(t *testing.T) { }) } - if err := c.AddSet(devicesSet, elements); err != nil { - t.Errorf("failed to add Set %s : %v", devicesSet.Name, err) - } + c.AddSet(devicesSet, elements) flowtable := &nftables.Flowtable{ Table: table, diff --git a/nftables_test.go b/nftables_test.go index fe12566..e0cb329 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -290,9 +290,7 @@ func TestRuleHandle(t *testing.T) { { Name: "delete-rule", Func: func() { - if err := c.DelRule(rule2); err != nil { - t.Errorf("DelRule failed: %v", err) - } + c.DelRule(rule2) }, Expect: []string{"1", "3"}, }, @@ -3335,12 +3333,10 @@ func TestCreateUseAnonymousSet(t *testing.T) { KeyType: nftables.TypeInetService, } - if err := c.AddSet(set, []nftables.SetElement{ + c.AddSet(set, []nftables.SetElement{ {Key: binaryutil.BigEndian.PutUint16(69)}, {Key: binaryutil.BigEndian.PutUint16(1163)}, - }); err != nil { - t.Errorf("c.AddSet() failed: %v", err) - } + }) c.AddRule(&nftables.Rule{ Table: filter, @@ -3414,9 +3410,7 @@ func TestCappedErrMsgOnSets(t *testing.T) { Name: "if_set", KeyType: nftables.TypeIFName, } - if err := c.AddSet(ifSet, nil); err != nil { - t.Errorf("c.AddSet(ifSet) failed: %v", err) - } + c.AddSet(ifSet, nil) if err := c.Flush(); err != nil { t.Errorf("failed adding set ifSet: %v", err) } @@ -3437,9 +3431,7 @@ func TestCappedErrMsgOnSets(t *testing.T) { elements := []nftables.SetElement{ {Key: []byte("012345678912345\x00")}, } - if err := c.SetAddElements(ifSet, elements); err != nil { - t.Errorf("adding SetElements(ifSet) failed: %v", err) - } + c.SetAddElements(ifSet, elements) if err := c.Flush(); err != nil { t.Errorf("failed adding set elements ifSet: %v", err) } @@ -3477,24 +3469,16 @@ func TestCreateUseNamedSet(t *testing.T) { Name: "test", KeyType: nftables.TypeInetService, } - if err := c.AddSet(portSet, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } - if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil { - t.Errorf("c.SetVal(portSet) failed: %v", err) - } + c.AddSet(portSet, nil) + c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}) ipSet := &nftables.Set{ Table: filter, Name: "IPs_4_dayz", KeyType: nftables.TypeIPAddr, } - if err := c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}); err != nil { - t.Errorf("c.AddSet(ipSet) failed: %v", err) - } - if err := c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}); err != nil { - t.Errorf("c.SetVal(ipSet) failed: %v", err) - } + c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}) + c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -3535,12 +3519,8 @@ func TestCreateAutoMergeSet(t *testing.T) { Interval: true, AutoMerge: true, } - if err := c.AddSet(portSet, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } - if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil { - t.Errorf("c.SetVal(portSet) failed: %v", err) - } + c.AddSet(portSet, nil) + c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}) ipSet := &nftables.Set{ Table: filter, @@ -3549,12 +3529,8 @@ func TestCreateAutoMergeSet(t *testing.T) { Interval: true, AutoMerge: true, } - if err := c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}); err != nil { - t.Errorf("c.AddSet(ipSet) failed: %v", err) - } - if err := c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}); err != nil { - t.Errorf("c.SetVal(ipSet) failed: %v", err) - } + c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}) + c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -3592,15 +3568,11 @@ func TestIP6SetAddElements(t *testing.T) { Name: "ports", KeyType: nftables.TypeInetService, } - if err := c.AddSet(portSet, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } - if err := c.SetAddElements(portSet, []nftables.SetElement{ + c.AddSet(portSet, nil) + c.SetAddElements(portSet, []nftables.SetElement{ {Key: binaryutil.BigEndian.PutUint16(22)}, {Key: binaryutil.BigEndian.PutUint16(80)}, - }); err != nil { - t.Errorf("c.SetVal(portSet) failed: %v", err) - } + }) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) @@ -3648,9 +3620,7 @@ func TestSetElementBatching(t *testing.T) { elements[i].Key = binaryutil.BigEndian.PutUint16(uint16(i)) elements[i].Comment = "0123456789" } - if err := c.AddSet(portSet, elements); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(portSet, elements) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -3694,12 +3664,8 @@ func TestCreateUseCounterSet(t *testing.T) { KeyType: nftables.TypeInetService, Counter: true, } - if err := c.AddSet(portSet, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } - if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil { - t.Errorf("c.SetVal(portSet) failed: %v", err) - } + c.AddSet(portSet, nil) + c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) @@ -3736,9 +3702,7 @@ func TestCreateDeleteNamedSet(t *testing.T) { Name: "test", KeyType: nftables.TypeInetService, } - if err := c.AddSet(portSet, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(portSet, nil) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -3777,9 +3741,7 @@ func TestDeleteElementNamedSet(t *testing.T) { Name: "test", KeyType: nftables.TypeInetService, } - if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -3824,9 +3786,7 @@ func TestFlushNamedSet(t *testing.T) { Name: "test", KeyType: nftables.TypeInetService, } - if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -3866,20 +3826,16 @@ func TestSetElementsInterval(t *testing.T) { Interval: true, Concatenation: true, } - if err := c.AddSet(portSet, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(portSet, nil) // { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } keyBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 195, 80, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203} // { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } keyEndBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 234, 96, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203} // elements = { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000-60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb } - if err := c.SetAddElements(portSet, []nftables.SetElement{ + c.SetAddElements(portSet, []nftables.SetElement{ {Key: keyBytes, KeyEnd: keyEndBytes}, - }); err != nil { - t.Errorf("c.SetVal(portSet) failed: %v", err) - } + }) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) @@ -3939,9 +3895,7 @@ func TestSetSizeConcat(t *testing.T) { Size: 200, } - if err := c.AddSet(set, nil); err != nil { - t.Errorf("c.AddSet(set) failed: %v", err) - } + c.AddSet(set, nil) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) @@ -4550,9 +4504,7 @@ func TestGetLookupExprDestSet(t *testing.T) { KeyType: nftables.TypeInetService, DataType: nftables.TypeVerdict, } - if err := c.AddSet(set, nil); err != nil { - t.Errorf("c.AddSet(set) failed: %v", err) - } + c.AddSet(set, nil) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -4647,9 +4599,7 @@ func TestGetRuleLookupVerdictImmediate(t *testing.T) { Name: "test", KeyType: nftables.TypeInetService, } - if err := c.AddSet(set, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(set, nil) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -4776,9 +4726,8 @@ func TestDynset(t *testing.T) { HasTimeout: true, Timeout: time.Duration(600 * time.Second), } - if err := c.AddSet(set, nil); err != nil { - t.Errorf("c.AddSet(portSet) failed: %v", err) - } + c.AddSet(set, nil) + if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -4866,9 +4815,7 @@ func TestDynsetWithOneExpression(t *testing.T) { } c.AddTable(table) c.AddChain(chain) - if err := c.AddSet(set, nil); err != nil { - t.Errorf("c.AddSet(myMeter) failed: %v", err) - } + c.AddSet(set, nil) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -4968,9 +4915,7 @@ func TestDynsetWithMultipleExpressions(t *testing.T) { } c.AddTable(table) c.AddChain(chain) - if err := c.AddSet(set, nil); err != nil { - t.Errorf("c.AddSet(myMeter) failed: %v", err) - } + c.AddSet(set, nil) if err := c.Flush(); err != nil { t.Errorf("c.Flush() failed: %v", err) } @@ -5612,9 +5557,7 @@ func TestSet4(t *testing.T) { setElements[i].Key = binaryutil.BigEndian.PutUint16(ports[i]) } - if err := c.AddSet(&set, setElements); err != nil { - t.Fatal(err) - } + c.AddSet(&set, setElements) c.AddRule(&nftables.Rule{ Table: tbl, @@ -5653,15 +5596,13 @@ func TestSetComment(t *testing.T) { Name: "filter", }) - if err := c.AddSet(&nftables.Set{ + c.AddSet(&nftables.Set{ ID: 2, Table: filter, Name: "setname", KeyType: nftables.TypeIPAddr, Comment: "test comment", - }, nil); err != nil { - t.Fatal(err) - } + }, nil) if err := c.Flush(); err != nil { t.Fatal(err) @@ -7359,9 +7300,8 @@ func TestSetElementComment(t *testing.T) { } // Add the set with elements - if err := conn.AddSet(set, elements); err != nil { - t.Fatalf("failed to add set: %v", err) - } + conn.AddSet(set, elements) + if err := conn.Flush(); err != nil { t.Fatalf("failed to flush: %v", err) } @@ -7429,3 +7369,68 @@ func TestAutoBufferSize(t *testing.T) { t.Fatalf("failed to flush: %v", err) } } + +func TestDeferredErrors(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + defer conn.FlushRuleset() + + table := conn.AddTable(&nftables.Table{ + Name: "test-table", + Family: nftables.TableFamilyIPv4, + }) + + // Anonymous sets have to be constant. Adding this set should queue an error. + set := &nftables.Set{ + KeyType: nftables.TypeInetService, + Table: table, + Anonymous: true, + Constant: false, + } + conn.AddSet(set, []nftables.SetElement{ + {Key: binaryutil.BigEndian.PutUint16(80)}, + {Key: binaryutil.BigEndian.PutUint16(443)}, + }) + + chain := conn.AddChain(&nftables.Chain{ + Name: "test-chain", + Table: table, + }) + + conn.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Lookup{ + SourceRegister: 1, + SetID: set.ID, + }, + }, + }) + + if err := conn.Flush(); err == nil { + t.Error("expected error when adding a non-constant anonymous set, got nil") + } else { + var errno syscall.Errno + if errors.As(err, &errno) { + t.Errorf("expected error to be not syscall.Errno, got %v", errno) + } + } + + table, err := conn.ListTable("test-table") + if table != nil || !errors.Is(err, syscall.ENOENT) { + t.Error("expected table to not exist") + } +} diff --git a/obj.go b/obj.go index 65c4402..0e2bbc9 100644 --- a/obj.go +++ b/obj.go @@ -111,7 +111,7 @@ func (cc *Conn) AddObj(o Obj) Obj { defer cc.mu.Unlock() data, err := expr.MarshalExprData(byte(o.family()), o.data()) if err != nil { - cc.setErr(err) + cc.appendErr(err) return nil } diff --git a/rule.go b/rule.go index 10958c6..00181d6 100644 --- a/rule.go +++ b/rule.go @@ -152,7 +152,7 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { })...) if compatPolicy, err := getCompatPolicy(r.Exprs); err != nil { - cc.setErr(err) + cc.appendErr(err) } else if compatPolicy != nil { data = append(data, cc.marshalAttr([]netlink.Attribute{ {Type: unix.NLA_F_NESTED | unix.NFTA_RULE_COMPAT, Data: cc.marshalAttr([]netlink.Attribute{ @@ -256,7 +256,7 @@ func (cc *Conn) InsertRule(r *Rule) *Rule { // DelRule deletes the specified Rule. Either the Handle or ID of the // rule must be set. -func (cc *Conn) DelRule(r *Rule) error { +func (cc *Conn) DelRule(r *Rule) { cc.mu.Lock() defer cc.mu.Unlock() data := cc.marshalAttr([]netlink.Attribute{ @@ -272,7 +272,8 @@ func (cc *Conn) DelRule(r *Rule) error { {Type: unix.NFTA_RULE_ID, Data: binaryutil.BigEndian.PutUint32(r.ID)}, })...) } else { - return fmt.Errorf("rule must have a handle or ID") + cc.appendErr(fmt.Errorf("rule must have a handle or ID")) + return } flags := netlink.Request | netlink.Acknowledge @@ -283,8 +284,6 @@ func (cc *Conn) DelRule(r *Rule) error { }, Data: append(extraHeader(uint8(r.Table.Family), 0), data...), }) - - return nil } func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) { diff --git a/set.go b/set.go index e320117..5e37f9a 100644 --- a/set.go +++ b/set.go @@ -372,23 +372,32 @@ func decodeElement(d []byte) ([]byte, error) { } // SetAddElements applies data points to an nftables set. -func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { +func (cc *Conn) SetAddElements(s *Set, vals []SetElement) { cc.mu.Lock() defer cc.mu.Unlock() + if s.Anonymous { - return errors.New("anonymous sets cannot be updated") + cc.appendErr(errors.New("anonymous sets cannot be updated")) + return + } + err := cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) + if err != nil { + cc.appendErr(err) } - return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) } // SetDeleteElements deletes data points from an nftables set. -func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { +func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) { cc.mu.Lock() defer cc.mu.Unlock() if s.Anonymous { - return errors.New("anonymous sets cannot be updated") + cc.appendErr(errors.New("anonymous sets cannot be updated")) + return + } + err := cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM) + if err != nil { + cc.appendErr(err) } - return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM) } // maxElemBatchSize is the maximum size in bytes of encoded set elements which @@ -518,7 +527,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error } // AddSet adds the specified Set. -func (cc *Conn) AddSet(s *Set, vals []SetElement) error { +func (cc *Conn) AddSet(s *Set, vals []SetElement) { cc.mu.Lock() defer cc.mu.Unlock() // Based on nft implementation & linux source. @@ -526,7 +535,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { // Another reference: https://git.netfilter.org/nftables/tree/src if s.Anonymous && !s.Constant { - return errors.New("anonymous structs must be constant") + cc.appendErr(errors.New("anonymous structs must be constant")) + return } if s.ID == 0 { @@ -591,7 +601,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))}, }) if err != nil { - return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err) + cc.appendErr(fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err)) + return } tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements}) } @@ -604,7 +615,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { {Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(s.Size)}, }) if err != nil { - return fmt.Errorf("fail to marshal set size description: %w", err) + cc.appendErr(fmt.Errorf("fail to marshal set size description: %w", err)) + return } descBytes = append(descBytes, descSizeBytes...) @@ -620,21 +632,24 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)}, }) if err != nil { - return fmt.Errorf("fail to marshal element key size %d: %v", i, err) + cc.appendErr(fmt.Errorf("fail to marshal element key size %d: %v", i, err)) + return } // Marshal base type size description descSize, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_SET_DESC_SIZE, Data: valData}, }) if err != nil { - return fmt.Errorf("fail to marshal base type size description: %w", err) + cc.appendErr(fmt.Errorf("fail to marshal base type size description: %w", err)) + return } concatDefinition = append(concatDefinition, descSize...) } // Marshal all base type descriptions into concatenation size description concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}}) if err != nil { - return fmt.Errorf("fail to marshal concat definition %v", err) + cc.appendErr(fmt.Errorf("fail to marshal concat definition %v", err)) + return } descBytes = append(descBytes, concatBytes...) @@ -675,7 +690,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { {Type: unix.NFTA_SET_ELEM_PAD | unix.NFTA_SET_ELEM_DATA, Data: []byte{}}, }) if err != nil { - return err + cc.appendErr(err) + return } tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data}) } @@ -689,7 +705,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { }) // Set the values of the set if initial values were provided. - return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) + err := cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM) + cc.appendErr(err) } // DelSet deletes a specific set, along with all elements it contains. diff --git a/set_test.go b/set_test.go index df40e57..090fa12 100644 --- a/set_test.go +++ b/set_test.go @@ -268,9 +268,7 @@ func TestMarshalSet(t *testing.T) { for i, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := c.AddSet(&tt.set, nil); err != nil { - t.Fatal(err) - } + c.AddSet(&tt.set, nil) connMsgSetIdx := connMsgStart + i if len(c.messages) != connMsgSetIdx+1 {