Merge 710638aff4
into 508bb1ffd4
This commit is contained in:
commit
f4feaa8a24
44
conn.go
44
conn.go
|
@ -233,6 +233,19 @@ func (cc *Conn) CloseLasting() error {
|
||||||
|
|
||||||
// Flush sends all buffered commands in a single batch to nftables.
|
// Flush sends all buffered commands in a single batch to nftables.
|
||||||
func (cc *Conn) Flush() error {
|
func (cc *Conn) Flush() error {
|
||||||
|
return cc.flush(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlushWithGenID sends all buffered commands in a single batch to nftables
|
||||||
|
// along with the provided gen ID. If the ruleset has changed since the gen ID
|
||||||
|
// was retrieved, an ERESTART error will be returned.
|
||||||
|
func (cc *Conn) FlushWithGenID(genID uint32) error {
|
||||||
|
return cc.flush(genID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// flush sends all buffered commands in a single batch to nftables. If genID is
|
||||||
|
// non-zero, it will be included in the batch messages.
|
||||||
|
func (cc *Conn) flush(genID uint32) error {
|
||||||
cc.mu.Lock()
|
cc.mu.Lock()
|
||||||
defer func() {
|
defer func() {
|
||||||
cc.messages = nil
|
cc.messages = nil
|
||||||
|
@ -259,7 +272,12 @@ func (cc *Conn) Flush() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
messages, err := conn.SendMessages(batch(cc.messages))
|
batch, err := batch(cc.messages, genID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := conn.SendMessages(batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("SendMessages: %w", err)
|
return fmt.Errorf("SendMessages: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -388,14 +406,30 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func batch(messages []netlinkMessage) []netlink.Message {
|
// batch wraps the given messages in a batch begin and end message, and returns
|
||||||
|
// the resulting slice of netlink messages. If the genID is non-zero, it will be
|
||||||
|
// included in both batch messages.
|
||||||
|
func batch(messages []netlinkMessage, genID uint32) ([]netlink.Message, error) {
|
||||||
batch := make([]netlink.Message, len(messages)+2)
|
batch := make([]netlink.Message, len(messages)+2)
|
||||||
|
|
||||||
|
data := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES)
|
||||||
|
|
||||||
|
if genID > 0 {
|
||||||
|
attr, err := netlink.MarshalAttributes([]netlink.Attribute{
|
||||||
|
{Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID)},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
data = append(data, attr...)
|
||||||
|
}
|
||||||
|
|
||||||
batch[0] = netlink.Message{
|
batch[0] = netlink.Message{
|
||||||
Header: netlink.Header{
|
Header: netlink.Header{
|
||||||
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
|
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
|
||||||
Flags: netlink.Request,
|
Flags: netlink.Request,
|
||||||
},
|
},
|
||||||
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
|
Data: data,
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, msg := range messages {
|
for i, msg := range messages {
|
||||||
|
@ -410,10 +444,10 @@ func batch(messages []netlinkMessage) []netlink.Message {
|
||||||
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
|
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
|
||||||
Flags: netlink.Request,
|
Flags: netlink.Request,
|
||||||
},
|
},
|
||||||
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
|
Data: data,
|
||||||
}
|
}
|
||||||
|
|
||||||
return batch
|
return batch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// allocateTransactionID allocates an identifier which is only valid in the
|
// allocateTransactionID allocates an identifier which is only valid in the
|
||||||
|
|
50
gen.go
50
gen.go
|
@ -8,15 +8,18 @@ import (
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GenMsg struct {
|
type Gen struct {
|
||||||
ID uint32
|
ID uint32
|
||||||
ProcPID uint32
|
ProcPID uint32
|
||||||
ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN
|
ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deprecated: GenMsg is an inconsistent old name for Gen. Prefer using Gen.
|
||||||
|
type GenMsg = Gen
|
||||||
|
|
||||||
const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
|
const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
|
||||||
|
|
||||||
func genFromMsg(msg netlink.Message) (*GenMsg, error) {
|
func genFromMsg(msg netlink.Message) (*Gen, error) {
|
||||||
if got, want := msg.Header.Type, genHeaderType; got != want {
|
if got, want := msg.Header.Type, genHeaderType; got != want {
|
||||||
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
|
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
|
@ -26,7 +29,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
|
||||||
}
|
}
|
||||||
ad.ByteOrder = binary.BigEndian
|
ad.ByteOrder = binary.BigEndian
|
||||||
|
|
||||||
msgOut := &GenMsg{}
|
msgOut := &Gen{}
|
||||||
for ad.Next() {
|
for ad.Next() {
|
||||||
switch ad.Type() {
|
switch ad.Type() {
|
||||||
case unix.NFTA_GEN_ID:
|
case unix.NFTA_GEN_ID:
|
||||||
|
@ -36,7 +39,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
|
||||||
case unix.NFTA_GEN_PROC_NAME:
|
case unix.NFTA_GEN_PROC_NAME:
|
||||||
msgOut.ProcComm = ad.String()
|
msgOut.ProcComm = ad.String()
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes())
|
return nil, fmt.Errorf("unknown attribute: %d, %v", ad.Type(), ad.Bytes())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := ad.Err(); err != nil {
|
if err := ad.Err(); err != nil {
|
||||||
|
@ -44,3 +47,42 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
|
||||||
}
|
}
|
||||||
return msgOut, nil
|
return msgOut, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetGen retrieves the current nftables generation ID together with the name
|
||||||
|
// and ID of the process that last modified the ruleset.
|
||||||
|
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen
|
||||||
|
func (cc *Conn) GetGen() (*Gen, error) {
|
||||||
|
conn, closer, err := cc.netlinkConn()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = closer() }()
|
||||||
|
|
||||||
|
data, err := netlink.MarshalAttributes([]netlink.Attribute{
|
||||||
|
{Type: unix.NFTA_GEN_ID},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
message := netlink.Message{
|
||||||
|
Header: netlink.Header{
|
||||||
|
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETGEN),
|
||||||
|
Flags: netlink.Request | netlink.Acknowledge,
|
||||||
|
},
|
||||||
|
Data: append(extraHeader(0, 0), data...),
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
|
||||||
|
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reply, err := receiveAckAware(conn, message.Header.Flags)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("receiveAckAware: %v", err)
|
||||||
|
}
|
||||||
|
if len(reply) == 0 {
|
||||||
|
return nil, fmt.Errorf("receiveAckAware: no reply")
|
||||||
|
}
|
||||||
|
return genFromMsg(reply[0])
|
||||||
|
}
|
||||||
|
|
|
@ -71,7 +71,7 @@ func TestMonitor(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
genMsg := event.GeneratedBy.Data.(*nftables.GenMsg)
|
genMsg := event.GeneratedBy.Data.(*nftables.Gen)
|
||||||
fileName := filepath.Base(os.Args[0])
|
fileName := filepath.Base(os.Args[0])
|
||||||
|
|
||||||
if genMsg.ProcComm != fileName {
|
if genMsg.ProcComm != fileName {
|
||||||
|
|
|
@ -7429,3 +7429,77 @@ func TestAutoBufferSize(t *testing.T) {
|
||||||
t.Fatalf("failed to flush: %v", err)
|
t.Fatalf("failed to flush: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetGen(t *testing.T) {
|
||||||
|
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||||||
|
defer nftest.CleanupSystemConn(t, newNS)
|
||||||
|
defer conn.FlushRuleset()
|
||||||
|
|
||||||
|
gen, err := conn.GetGen()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get gen: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.AddTable(&nftables.Table{
|
||||||
|
Name: "test-table",
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Flush to increment the generation ID.
|
||||||
|
if err := conn.Flush(); err != nil {
|
||||||
|
t.Fatalf("failed to flush: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newGen, err := conn.GetGen()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get gen: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if newGen.ID <= gen.ID {
|
||||||
|
t.Fatalf("gen ID did not increase, got %d, want > %d", newGen.ID, gen.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if newGen.ProcComm != gen.ProcComm {
|
||||||
|
t.Errorf("gen ProcComm changed, got %s, want %s", newGen.ProcComm, gen.ProcComm)
|
||||||
|
}
|
||||||
|
|
||||||
|
if newGen.ProcPID != gen.ProcPID {
|
||||||
|
t.Errorf("gen ProcPID changed, got %d, want %d", newGen.ProcPID, gen.ProcPID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlushWithGenID(t *testing.T) {
|
||||||
|
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||||||
|
defer nftest.CleanupSystemConn(t, newNS)
|
||||||
|
defer conn.FlushRuleset()
|
||||||
|
|
||||||
|
gen, err := conn.GetGen()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get gen: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.AddTable(&nftables.Table{
|
||||||
|
Name: "test-table",
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Flush to increment the generation ID.
|
||||||
|
if err := conn.Flush(); err != nil {
|
||||||
|
t.Fatalf("failed to flush: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.AddTable(&nftables.Table{
|
||||||
|
Name: "test-table-2",
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
})
|
||||||
|
|
||||||
|
err = conn.FlushWithGenID(gen.ID)
|
||||||
|
if err == nil || !errors.Is(err, syscall.ERESTART) {
|
||||||
|
t.Errorf("expected error to be ERESTART, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := conn.ListTable("test-table-2")
|
||||||
|
if table != nil && !errors.Is(err, syscall.ENOENT) {
|
||||||
|
t.Errorf("expected table to not exist, got: %v", table)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue