Defer errors until Flush to avoid incomplete batches
Any create/update/delete operation that returns a validation or marshalling error can leave the message batch in an incomplete state due to short-circuiting. This can result in either: - Non-atomic transactions if Flush is called (incomplete batch) - Users being unable to clear the incomplete batch (no API exposed) This change ensures that errors are collected and deferred until Flush. Instead of returning immediately, the following methods now append errors to a slice checked at Flush: - AddSet - DelRule - SetAddElements See: https://github.com/google/nftables/issues/323
This commit is contained in:
parent
508bb1ffd4
commit
d0b38630ac
17
conn.go
17
conn.go
|
@ -42,7 +42,7 @@ type Conn struct {
|
||||||
lasting bool // establish a lasting connection to be used across multiple netlink operations.
|
lasting bool // establish a lasting connection to be used across multiple netlink operations.
|
||||||
mu sync.Mutex // protects the following state
|
mu sync.Mutex // protects the following state
|
||||||
messages []netlinkMessage
|
messages []netlinkMessage
|
||||||
err error
|
errs []error
|
||||||
nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol.
|
nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol.
|
||||||
sockOptions []SockOption
|
sockOptions []SockOption
|
||||||
lastID uint32
|
lastID uint32
|
||||||
|
@ -237,14 +237,15 @@ func (cc *Conn) Flush() error {
|
||||||
defer func() {
|
defer func() {
|
||||||
cc.messages = nil
|
cc.messages = nil
|
||||||
cc.allocatedIDs = 0
|
cc.allocatedIDs = 0
|
||||||
|
cc.errs = nil
|
||||||
cc.mu.Unlock()
|
cc.mu.Unlock()
|
||||||
}()
|
}()
|
||||||
if len(cc.messages) == 0 {
|
if len(cc.messages) == 0 {
|
||||||
// Messages were already programmed, returning nil
|
// Messages were already programmed, returning nil
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if cc.err != nil {
|
if len(cc.errs) > 0 {
|
||||||
return cc.err // serialization error
|
return errors.Join(cc.errs...)
|
||||||
}
|
}
|
||||||
conn, closer, err := cc.netlinkConnUnderLock()
|
conn, closer, err := cc.netlinkConnUnderLock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -363,17 +364,17 @@ func (cc *Conn) dialNetlink() (*netlink.Conn, error) {
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *Conn) setErr(err error) {
|
func (cc *Conn) appendErr(err error) {
|
||||||
if cc.err != nil {
|
if err == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cc.err = err
|
cc.errs = append(cc.errs, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
|
func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
|
||||||
b, err := netlink.MarshalAttributes(attrs)
|
b, err := netlink.MarshalAttributes(attrs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cc.setErr(err)
|
cc.appendErr(err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
|
@ -382,7 +383,7 @@ func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
|
||||||
func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
|
func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
|
||||||
b, err := expr.Marshal(fam, e)
|
b, err := expr.Marshal(fam, e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cc.setErr(err)
|
cc.appendErr(err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
|
|
|
@ -111,9 +111,7 @@ func TestNFTables(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.AddSet(devicesSet, elements); err != nil {
|
c.AddSet(devicesSet, elements)
|
||||||
t.Errorf("failed to add Set %s : %v", devicesSet.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
flowtable := &nftables.Flowtable{
|
flowtable := &nftables.Flowtable{
|
||||||
Table: table,
|
Table: table,
|
||||||
|
|
199
nftables_test.go
199
nftables_test.go
|
@ -290,9 +290,7 @@ func TestRuleHandle(t *testing.T) {
|
||||||
{
|
{
|
||||||
Name: "delete-rule",
|
Name: "delete-rule",
|
||||||
Func: func() {
|
Func: func() {
|
||||||
if err := c.DelRule(rule2); err != nil {
|
c.DelRule(rule2)
|
||||||
t.Errorf("DelRule failed: %v", err)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
Expect: []string{"1", "3"},
|
Expect: []string{"1", "3"},
|
||||||
},
|
},
|
||||||
|
@ -3335,12 +3333,10 @@ func TestCreateUseAnonymousSet(t *testing.T) {
|
||||||
KeyType: nftables.TypeInetService,
|
KeyType: nftables.TypeInetService,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.AddSet(set, []nftables.SetElement{
|
c.AddSet(set, []nftables.SetElement{
|
||||||
{Key: binaryutil.BigEndian.PutUint16(69)},
|
{Key: binaryutil.BigEndian.PutUint16(69)},
|
||||||
{Key: binaryutil.BigEndian.PutUint16(1163)},
|
{Key: binaryutil.BigEndian.PutUint16(1163)},
|
||||||
}); err != nil {
|
})
|
||||||
t.Errorf("c.AddSet() failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.AddRule(&nftables.Rule{
|
c.AddRule(&nftables.Rule{
|
||||||
Table: filter,
|
Table: filter,
|
||||||
|
@ -3414,9 +3410,7 @@ func TestCappedErrMsgOnSets(t *testing.T) {
|
||||||
Name: "if_set",
|
Name: "if_set",
|
||||||
KeyType: nftables.TypeIFName,
|
KeyType: nftables.TypeIFName,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(ifSet, nil); err != nil {
|
c.AddSet(ifSet, nil)
|
||||||
t.Errorf("c.AddSet(ifSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("failed adding set ifSet: %v", err)
|
t.Errorf("failed adding set ifSet: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -3437,9 +3431,7 @@ func TestCappedErrMsgOnSets(t *testing.T) {
|
||||||
elements := []nftables.SetElement{
|
elements := []nftables.SetElement{
|
||||||
{Key: []byte("012345678912345\x00")},
|
{Key: []byte("012345678912345\x00")},
|
||||||
}
|
}
|
||||||
if err := c.SetAddElements(ifSet, elements); err != nil {
|
c.SetAddElements(ifSet, elements)
|
||||||
t.Errorf("adding SetElements(ifSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("failed adding set elements ifSet: %v", err)
|
t.Errorf("failed adding set elements ifSet: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -3477,24 +3469,16 @@ func TestCreateUseNamedSet(t *testing.T) {
|
||||||
Name: "test",
|
Name: "test",
|
||||||
KeyType: nftables.TypeInetService,
|
KeyType: nftables.TypeInetService,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(portSet, nil); err != nil {
|
c.AddSet(portSet, nil)
|
||||||
t.Errorf("c.AddSet(portSet) failed: %v", err)
|
c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}})
|
||||||
}
|
|
||||||
if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil {
|
|
||||||
t.Errorf("c.SetVal(portSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ipSet := &nftables.Set{
|
ipSet := &nftables.Set{
|
||||||
Table: filter,
|
Table: filter,
|
||||||
Name: "IPs_4_dayz",
|
Name: "IPs_4_dayz",
|
||||||
KeyType: nftables.TypeIPAddr,
|
KeyType: nftables.TypeIPAddr,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}); err != nil {
|
c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}})
|
||||||
t.Errorf("c.AddSet(ipSet) failed: %v", err)
|
c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}})
|
||||||
}
|
|
||||||
if err := c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}); err != nil {
|
|
||||||
t.Errorf("c.SetVal(ipSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -3535,12 +3519,8 @@ func TestCreateAutoMergeSet(t *testing.T) {
|
||||||
Interval: true,
|
Interval: true,
|
||||||
AutoMerge: true,
|
AutoMerge: true,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(portSet, nil); err != nil {
|
c.AddSet(portSet, nil)
|
||||||
t.Errorf("c.AddSet(portSet) failed: %v", err)
|
c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}})
|
||||||
}
|
|
||||||
if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil {
|
|
||||||
t.Errorf("c.SetVal(portSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ipSet := &nftables.Set{
|
ipSet := &nftables.Set{
|
||||||
Table: filter,
|
Table: filter,
|
||||||
|
@ -3549,12 +3529,8 @@ func TestCreateAutoMergeSet(t *testing.T) {
|
||||||
Interval: true,
|
Interval: true,
|
||||||
AutoMerge: true,
|
AutoMerge: true,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}}); err != nil {
|
c.AddSet(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.64").To4())}})
|
||||||
t.Errorf("c.AddSet(ipSet) failed: %v", err)
|
c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}})
|
||||||
}
|
|
||||||
if err := c.SetAddElements(ipSet, []nftables.SetElement{{Key: []byte(net.ParseIP("192.168.1.42").To4())}}); err != nil {
|
|
||||||
t.Errorf("c.SetVal(ipSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -3592,15 +3568,11 @@ func TestIP6SetAddElements(t *testing.T) {
|
||||||
Name: "ports",
|
Name: "ports",
|
||||||
KeyType: nftables.TypeInetService,
|
KeyType: nftables.TypeInetService,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(portSet, nil); err != nil {
|
c.AddSet(portSet, nil)
|
||||||
t.Errorf("c.AddSet(portSet) failed: %v", err)
|
c.SetAddElements(portSet, []nftables.SetElement{
|
||||||
}
|
|
||||||
if err := c.SetAddElements(portSet, []nftables.SetElement{
|
|
||||||
{Key: binaryutil.BigEndian.PutUint16(22)},
|
{Key: binaryutil.BigEndian.PutUint16(22)},
|
||||||
{Key: binaryutil.BigEndian.PutUint16(80)},
|
{Key: binaryutil.BigEndian.PutUint16(80)},
|
||||||
}); err != nil {
|
})
|
||||||
t.Errorf("c.SetVal(portSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
|
@ -3648,9 +3620,7 @@ func TestSetElementBatching(t *testing.T) {
|
||||||
elements[i].Key = binaryutil.BigEndian.PutUint16(uint16(i))
|
elements[i].Key = binaryutil.BigEndian.PutUint16(uint16(i))
|
||||||
elements[i].Comment = "0123456789"
|
elements[i].Comment = "0123456789"
|
||||||
}
|
}
|
||||||
if err := c.AddSet(portSet, elements); err != nil {
|
c.AddSet(portSet, elements)
|
||||||
t.Errorf("c.AddSet(portSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -3694,12 +3664,8 @@ func TestCreateUseCounterSet(t *testing.T) {
|
||||||
KeyType: nftables.TypeInetService,
|
KeyType: nftables.TypeInetService,
|
||||||
Counter: true,
|
Counter: true,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(portSet, nil); err != nil {
|
c.AddSet(portSet, nil)
|
||||||
t.Errorf("c.AddSet(portSet) failed: %v", err)
|
c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}})
|
||||||
}
|
|
||||||
if err := c.SetAddElements(portSet, []nftables.SetElement{{Key: binaryutil.BigEndian.PutUint16(22)}}); err != nil {
|
|
||||||
t.Errorf("c.SetVal(portSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
|
@ -3736,9 +3702,7 @@ func TestCreateDeleteNamedSet(t *testing.T) {
|
||||||
Name: "test",
|
Name: "test",
|
||||||
KeyType: nftables.TypeInetService,
|
KeyType: nftables.TypeInetService,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(portSet, nil); err != nil {
|
c.AddSet(portSet, nil)
|
||||||
t.Errorf("c.AddSet(portSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -3777,9 +3741,7 @@ func TestDeleteElementNamedSet(t *testing.T) {
|
||||||
Name: "test",
|
Name: "test",
|
||||||
KeyType: nftables.TypeInetService,
|
KeyType: nftables.TypeInetService,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil {
|
c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}})
|
||||||
t.Errorf("c.AddSet(portSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -3824,9 +3786,7 @@ func TestFlushNamedSet(t *testing.T) {
|
||||||
Name: "test",
|
Name: "test",
|
||||||
KeyType: nftables.TypeInetService,
|
KeyType: nftables.TypeInetService,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil {
|
c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}})
|
||||||
t.Errorf("c.AddSet(portSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -3866,20 +3826,16 @@ func TestSetElementsInterval(t *testing.T) {
|
||||||
Interval: true,
|
Interval: true,
|
||||||
Concatenation: true,
|
Concatenation: true,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(portSet, nil); err != nil {
|
c.AddSet(portSet, nil)
|
||||||
t.Errorf("c.AddSet(portSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb }
|
// { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb }
|
||||||
keyBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 195, 80, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203}
|
keyBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 195, 80, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203}
|
||||||
// { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb }
|
// { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb }
|
||||||
keyEndBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 234, 96, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203}
|
keyEndBytes := []byte{119, 124, 171, 75, 133, 240, 22, 20, 73, 229, 210, 155, 170, 123, 204, 144, 234, 96, 0, 0, 135, 9, 28, 185, 22, 62, 155, 85, 53, 127, 239, 100, 112, 138, 237, 203}
|
||||||
// elements = { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000-60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb }
|
// elements = { 777c:ab4b:85f0:1614:49e5:d29b:aa7b:cc90 . 50000-60000 . 8709:1cb9:163e:9b55:357f:ef64:708a:edcb }
|
||||||
if err := c.SetAddElements(portSet, []nftables.SetElement{
|
c.SetAddElements(portSet, []nftables.SetElement{
|
||||||
{Key: keyBytes, KeyEnd: keyEndBytes},
|
{Key: keyBytes, KeyEnd: keyEndBytes},
|
||||||
}); err != nil {
|
})
|
||||||
t.Errorf("c.SetVal(portSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
|
@ -3939,9 +3895,7 @@ func TestSetSizeConcat(t *testing.T) {
|
||||||
Size: 200,
|
Size: 200,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.AddSet(set, nil); err != nil {
|
c.AddSet(set, nil)
|
||||||
t.Errorf("c.AddSet(set) failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
|
@ -4550,9 +4504,7 @@ func TestGetLookupExprDestSet(t *testing.T) {
|
||||||
KeyType: nftables.TypeInetService,
|
KeyType: nftables.TypeInetService,
|
||||||
DataType: nftables.TypeVerdict,
|
DataType: nftables.TypeVerdict,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(set, nil); err != nil {
|
c.AddSet(set, nil)
|
||||||
t.Errorf("c.AddSet(set) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -4647,9 +4599,7 @@ func TestGetRuleLookupVerdictImmediate(t *testing.T) {
|
||||||
Name: "test",
|
Name: "test",
|
||||||
KeyType: nftables.TypeInetService,
|
KeyType: nftables.TypeInetService,
|
||||||
}
|
}
|
||||||
if err := c.AddSet(set, nil); err != nil {
|
c.AddSet(set, nil)
|
||||||
t.Errorf("c.AddSet(portSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -4776,9 +4726,8 @@ func TestDynset(t *testing.T) {
|
||||||
HasTimeout: true,
|
HasTimeout: true,
|
||||||
Timeout: time.Duration(600 * time.Second),
|
Timeout: time.Duration(600 * time.Second),
|
||||||
}
|
}
|
||||||
if err := c.AddSet(set, nil); err != nil {
|
c.AddSet(set, nil)
|
||||||
t.Errorf("c.AddSet(portSet) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -4866,9 +4815,7 @@ func TestDynsetWithOneExpression(t *testing.T) {
|
||||||
}
|
}
|
||||||
c.AddTable(table)
|
c.AddTable(table)
|
||||||
c.AddChain(chain)
|
c.AddChain(chain)
|
||||||
if err := c.AddSet(set, nil); err != nil {
|
c.AddSet(set, nil)
|
||||||
t.Errorf("c.AddSet(myMeter) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -4968,9 +4915,7 @@ func TestDynsetWithMultipleExpressions(t *testing.T) {
|
||||||
}
|
}
|
||||||
c.AddTable(table)
|
c.AddTable(table)
|
||||||
c.AddChain(chain)
|
c.AddChain(chain)
|
||||||
if err := c.AddSet(set, nil); err != nil {
|
c.AddSet(set, nil)
|
||||||
t.Errorf("c.AddSet(myMeter) failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Errorf("c.Flush() failed: %v", err)
|
t.Errorf("c.Flush() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -5612,9 +5557,7 @@ func TestSet4(t *testing.T) {
|
||||||
setElements[i].Key = binaryutil.BigEndian.PutUint16(ports[i])
|
setElements[i].Key = binaryutil.BigEndian.PutUint16(ports[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.AddSet(&set, setElements); err != nil {
|
c.AddSet(&set, setElements)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.AddRule(&nftables.Rule{
|
c.AddRule(&nftables.Rule{
|
||||||
Table: tbl,
|
Table: tbl,
|
||||||
|
@ -5653,15 +5596,13 @@ func TestSetComment(t *testing.T) {
|
||||||
Name: "filter",
|
Name: "filter",
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := c.AddSet(&nftables.Set{
|
c.AddSet(&nftables.Set{
|
||||||
ID: 2,
|
ID: 2,
|
||||||
Table: filter,
|
Table: filter,
|
||||||
Name: "setname",
|
Name: "setname",
|
||||||
KeyType: nftables.TypeIPAddr,
|
KeyType: nftables.TypeIPAddr,
|
||||||
Comment: "test comment",
|
Comment: "test comment",
|
||||||
}, nil); err != nil {
|
}, nil)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -7359,9 +7300,8 @@ func TestSetElementComment(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the set with elements
|
// Add the set with elements
|
||||||
if err := conn.AddSet(set, elements); err != nil {
|
conn.AddSet(set, elements)
|
||||||
t.Fatalf("failed to add set: %v", err)
|
|
||||||
}
|
|
||||||
if err := conn.Flush(); err != nil {
|
if err := conn.Flush(); err != nil {
|
||||||
t.Fatalf("failed to flush: %v", err)
|
t.Fatalf("failed to flush: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -7429,3 +7369,68 @@ func TestAutoBufferSize(t *testing.T) {
|
||||||
t.Fatalf("failed to flush: %v", err)
|
t.Fatalf("failed to flush: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeferredErrors(t *testing.T) {
|
||||||
|
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||||||
|
defer nftest.CleanupSystemConn(t, newNS)
|
||||||
|
defer conn.FlushRuleset()
|
||||||
|
|
||||||
|
table := conn.AddTable(&nftables.Table{
|
||||||
|
Name: "test-table",
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Anonymous sets have to be constant. Adding this set should queue an error.
|
||||||
|
set := &nftables.Set{
|
||||||
|
KeyType: nftables.TypeInetService,
|
||||||
|
Table: table,
|
||||||
|
Anonymous: true,
|
||||||
|
Constant: false,
|
||||||
|
}
|
||||||
|
conn.AddSet(set, []nftables.SetElement{
|
||||||
|
{Key: binaryutil.BigEndian.PutUint16(80)},
|
||||||
|
{Key: binaryutil.BigEndian.PutUint16(443)},
|
||||||
|
})
|
||||||
|
|
||||||
|
chain := conn.AddChain(&nftables.Chain{
|
||||||
|
Name: "test-chain",
|
||||||
|
Table: table,
|
||||||
|
})
|
||||||
|
|
||||||
|
conn.AddRule(&nftables.Rule{
|
||||||
|
Table: table,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{unix.IPPROTO_TCP},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseTransportHeader,
|
||||||
|
Offset: 2,
|
||||||
|
Len: 2,
|
||||||
|
},
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetID: set.ID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := conn.Flush(); err == nil {
|
||||||
|
t.Error("expected error when adding a non-constant anonymous set, got nil")
|
||||||
|
} else {
|
||||||
|
var errno syscall.Errno
|
||||||
|
if errors.As(err, &errno) {
|
||||||
|
t.Errorf("expected error to be not syscall.Errno, got %v", errno)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := conn.ListTable("test-table")
|
||||||
|
if table != nil || !errors.Is(err, syscall.ENOENT) {
|
||||||
|
t.Error("expected table to not exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
2
obj.go
2
obj.go
|
@ -111,7 +111,7 @@ func (cc *Conn) AddObj(o Obj) Obj {
|
||||||
defer cc.mu.Unlock()
|
defer cc.mu.Unlock()
|
||||||
data, err := expr.MarshalExprData(byte(o.family()), o.data())
|
data, err := expr.MarshalExprData(byte(o.family()), o.data())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cc.setErr(err)
|
cc.appendErr(err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
9
rule.go
9
rule.go
|
@ -152,7 +152,7 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
|
||||||
})...)
|
})...)
|
||||||
|
|
||||||
if compatPolicy, err := getCompatPolicy(r.Exprs); err != nil {
|
if compatPolicy, err := getCompatPolicy(r.Exprs); err != nil {
|
||||||
cc.setErr(err)
|
cc.appendErr(err)
|
||||||
} else if compatPolicy != nil {
|
} else if compatPolicy != nil {
|
||||||
data = append(data, cc.marshalAttr([]netlink.Attribute{
|
data = append(data, cc.marshalAttr([]netlink.Attribute{
|
||||||
{Type: unix.NLA_F_NESTED | unix.NFTA_RULE_COMPAT, Data: cc.marshalAttr([]netlink.Attribute{
|
{Type: unix.NLA_F_NESTED | unix.NFTA_RULE_COMPAT, Data: cc.marshalAttr([]netlink.Attribute{
|
||||||
|
@ -256,7 +256,7 @@ func (cc *Conn) InsertRule(r *Rule) *Rule {
|
||||||
|
|
||||||
// DelRule deletes the specified Rule. Either the Handle or ID of the
|
// DelRule deletes the specified Rule. Either the Handle or ID of the
|
||||||
// rule must be set.
|
// rule must be set.
|
||||||
func (cc *Conn) DelRule(r *Rule) error {
|
func (cc *Conn) DelRule(r *Rule) {
|
||||||
cc.mu.Lock()
|
cc.mu.Lock()
|
||||||
defer cc.mu.Unlock()
|
defer cc.mu.Unlock()
|
||||||
data := cc.marshalAttr([]netlink.Attribute{
|
data := cc.marshalAttr([]netlink.Attribute{
|
||||||
|
@ -272,7 +272,8 @@ func (cc *Conn) DelRule(r *Rule) error {
|
||||||
{Type: unix.NFTA_RULE_ID, Data: binaryutil.BigEndian.PutUint32(r.ID)},
|
{Type: unix.NFTA_RULE_ID, Data: binaryutil.BigEndian.PutUint32(r.ID)},
|
||||||
})...)
|
})...)
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("rule must have a handle or ID")
|
cc.appendErr(fmt.Errorf("rule must have a handle or ID"))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
flags := netlink.Request | netlink.Acknowledge
|
flags := netlink.Request | netlink.Acknowledge
|
||||||
|
|
||||||
|
@ -283,8 +284,6 @@ func (cc *Conn) DelRule(r *Rule) error {
|
||||||
},
|
},
|
||||||
Data: append(extraHeader(uint8(r.Table.Family), 0), data...),
|
Data: append(extraHeader(uint8(r.Table.Family), 0), data...),
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) {
|
func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) {
|
||||||
|
|
47
set.go
47
set.go
|
@ -372,23 +372,32 @@ func decodeElement(d []byte) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAddElements applies data points to an nftables set.
|
// SetAddElements applies data points to an nftables set.
|
||||||
func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error {
|
func (cc *Conn) SetAddElements(s *Set, vals []SetElement) {
|
||||||
cc.mu.Lock()
|
cc.mu.Lock()
|
||||||
defer cc.mu.Unlock()
|
defer cc.mu.Unlock()
|
||||||
|
|
||||||
if s.Anonymous {
|
if s.Anonymous {
|
||||||
return errors.New("anonymous sets cannot be updated")
|
cc.appendErr(errors.New("anonymous sets cannot be updated"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
|
||||||
|
if err != nil {
|
||||||
|
cc.appendErr(err)
|
||||||
}
|
}
|
||||||
return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetDeleteElements deletes data points from an nftables set.
|
// SetDeleteElements deletes data points from an nftables set.
|
||||||
func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
|
func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) {
|
||||||
cc.mu.Lock()
|
cc.mu.Lock()
|
||||||
defer cc.mu.Unlock()
|
defer cc.mu.Unlock()
|
||||||
if s.Anonymous {
|
if s.Anonymous {
|
||||||
return errors.New("anonymous sets cannot be updated")
|
cc.appendErr(errors.New("anonymous sets cannot be updated"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
|
||||||
|
if err != nil {
|
||||||
|
cc.appendErr(err)
|
||||||
}
|
}
|
||||||
return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// maxElemBatchSize is the maximum size in bytes of encoded set elements which
|
// maxElemBatchSize is the maximum size in bytes of encoded set elements which
|
||||||
|
@ -518,7 +527,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddSet adds the specified Set.
|
// AddSet adds the specified Set.
|
||||||
func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
func (cc *Conn) AddSet(s *Set, vals []SetElement) {
|
||||||
cc.mu.Lock()
|
cc.mu.Lock()
|
||||||
defer cc.mu.Unlock()
|
defer cc.mu.Unlock()
|
||||||
// Based on nft implementation & linux source.
|
// Based on nft implementation & linux source.
|
||||||
|
@ -526,7 +535,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
||||||
// Another reference: https://git.netfilter.org/nftables/tree/src
|
// Another reference: https://git.netfilter.org/nftables/tree/src
|
||||||
|
|
||||||
if s.Anonymous && !s.Constant {
|
if s.Anonymous && !s.Constant {
|
||||||
return errors.New("anonymous structs must be constant")
|
cc.appendErr(errors.New("anonymous structs must be constant"))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.ID == 0 {
|
if s.ID == 0 {
|
||||||
|
@ -591,7 +601,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
||||||
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))},
|
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err)
|
cc.appendErr(fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements})
|
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements})
|
||||||
}
|
}
|
||||||
|
@ -604,7 +615,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
||||||
{Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(s.Size)},
|
{Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(s.Size)},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("fail to marshal set size description: %w", err)
|
cc.appendErr(fmt.Errorf("fail to marshal set size description: %w", err))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
descBytes = append(descBytes, descSizeBytes...)
|
descBytes = append(descBytes, descSizeBytes...)
|
||||||
|
@ -620,21 +632,24 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
||||||
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)},
|
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("fail to marshal element key size %d: %v", i, err)
|
cc.appendErr(fmt.Errorf("fail to marshal element key size %d: %v", i, err))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
// Marshal base type size description
|
// Marshal base type size description
|
||||||
descSize, err := netlink.MarshalAttributes([]netlink.Attribute{
|
descSize, err := netlink.MarshalAttributes([]netlink.Attribute{
|
||||||
{Type: unix.NFTA_SET_DESC_SIZE, Data: valData},
|
{Type: unix.NFTA_SET_DESC_SIZE, Data: valData},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("fail to marshal base type size description: %w", err)
|
cc.appendErr(fmt.Errorf("fail to marshal base type size description: %w", err))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
concatDefinition = append(concatDefinition, descSize...)
|
concatDefinition = append(concatDefinition, descSize...)
|
||||||
}
|
}
|
||||||
// Marshal all base type descriptions into concatenation size description
|
// Marshal all base type descriptions into concatenation size description
|
||||||
concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}})
|
concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("fail to marshal concat definition %v", err)
|
cc.appendErr(fmt.Errorf("fail to marshal concat definition %v", err))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
descBytes = append(descBytes, concatBytes...)
|
descBytes = append(descBytes, concatBytes...)
|
||||||
|
@ -675,7 +690,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
||||||
{Type: unix.NFTA_SET_ELEM_PAD | unix.NFTA_SET_ELEM_DATA, Data: []byte{}},
|
{Type: unix.NFTA_SET_ELEM_PAD | unix.NFTA_SET_ELEM_DATA, Data: []byte{}},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
cc.appendErr(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data})
|
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data})
|
||||||
}
|
}
|
||||||
|
@ -689,7 +705,8 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Set the values of the set if initial values were provided.
|
// Set the values of the set if initial values were provided.
|
||||||
return cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
|
err := cc.appendElemList(s, vals, unix.NFT_MSG_NEWSETELEM)
|
||||||
|
cc.appendErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DelSet deletes a specific set, along with all elements it contains.
|
// DelSet deletes a specific set, along with all elements it contains.
|
||||||
|
|
|
@ -268,9 +268,7 @@ func TestMarshalSet(t *testing.T) {
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := c.AddSet(&tt.set, nil); err != nil {
|
c.AddSet(&tt.set, nil)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
connMsgSetIdx := connMsgStart + i
|
connMsgSetIdx := connMsgStart + i
|
||||||
if len(c.messages) != connMsgSetIdx+1 {
|
if len(c.messages) != connMsgSetIdx+1 {
|
||||||
|
|
Loading…
Reference in New Issue