Address review suggestions
This commit is contained in:
parent
9857ffe35c
commit
710638aff4
29
conn.go
29
conn.go
|
@ -232,9 +232,20 @@ func (cc *Conn) CloseLasting() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush sends all buffered commands in a single batch to nftables.
|
// Flush sends all buffered commands in a single batch to nftables.
|
||||||
// If an optional gen ID is provided, it will be used in the batch begin message.
|
func (cc *Conn) Flush() error {
|
||||||
// If the gen ID is not matched by the kernel, it will return an ERESTART error.
|
return cc.flush(0)
|
||||||
func (cc *Conn) Flush(genID ...uint32) error {
|
}
|
||||||
|
|
||||||
|
// FlushWithGenID sends all buffered commands in a single batch to nftables
|
||||||
|
// along with the provided gen ID. If the ruleset has changed since the gen ID
|
||||||
|
// was retrieved, an ERESTART error will be returned.
|
||||||
|
func (cc *Conn) FlushWithGenID(genID uint32) error {
|
||||||
|
return cc.flush(genID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// flush sends all buffered commands in a single batch to nftables. If genID is
|
||||||
|
// non-zero, it will be included in the batch messages.
|
||||||
|
func (cc *Conn) flush(genID uint32) error {
|
||||||
cc.mu.Lock()
|
cc.mu.Lock()
|
||||||
defer func() {
|
defer func() {
|
||||||
cc.messages = nil
|
cc.messages = nil
|
||||||
|
@ -261,7 +272,7 @@ func (cc *Conn) Flush(genID ...uint32) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
batch, err := batch(cc.messages, genID...)
|
batch, err := batch(cc.messages, genID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -395,17 +406,17 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// Batch wraps the given messages in a batch begin and end message, and returns
|
// batch wraps the given messages in a batch begin and end message, and returns
|
||||||
// the resulting slice of netlink messages. If a genID is provided, it is
|
// the resulting slice of netlink messages. If the genID is non-zero, it will be
|
||||||
// included in both batch messages.
|
// included in both batch messages.
|
||||||
func batch(messages []netlinkMessage, genID ...uint32) ([]netlink.Message, error) {
|
func batch(messages []netlinkMessage, genID uint32) ([]netlink.Message, error) {
|
||||||
batch := make([]netlink.Message, len(messages)+2)
|
batch := make([]netlink.Message, len(messages)+2)
|
||||||
|
|
||||||
data := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES)
|
data := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES)
|
||||||
|
|
||||||
if len(genID) > 0 && genID[0] > 0 {
|
if genID > 0 {
|
||||||
attr, err := netlink.MarshalAttributes([]netlink.Attribute{
|
attr, err := netlink.MarshalAttributes([]netlink.Attribute{
|
||||||
{Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID[0])},
|
{Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID)},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
1
gen.go
1
gen.go
|
@ -14,6 +14,7 @@ type Gen struct {
|
||||||
ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN
|
ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deprecated: GenMsg is an inconsistent old name for Gen. Prefer using Gen.
|
||||||
type GenMsg = Gen
|
type GenMsg = Gen
|
||||||
|
|
||||||
const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
|
const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
|
||||||
|
|
|
@ -7493,7 +7493,7 @@ func TestFlushWithGenID(t *testing.T) {
|
||||||
Family: nftables.TableFamilyIPv4,
|
Family: nftables.TableFamilyIPv4,
|
||||||
})
|
})
|
||||||
|
|
||||||
err = conn.Flush(gen.ID)
|
err = conn.FlushWithGenID(gen.ID)
|
||||||
if err == nil || !errors.Is(err, syscall.ERESTART) {
|
if err == nil || !errors.Is(err, syscall.ERESTART) {
|
||||||
t.Errorf("expected error to be ERESTART, got: %v", err)
|
t.Errorf("expected error to be ERESTART, got: %v", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue