From 4e4da6b88a0999b4950a3e1ab78d730a7f7eff9f Mon Sep 17 00:00:00 2001 From: ignatella Date: Wed, 16 Oct 2024 14:48:24 +0200 Subject: [PATCH 1/5] Update: process monitor events in batches --- gen.go | 45 +++++++++++++++++++++ monitor.go | 104 +++++++++++++++++++++++++++++++++--------------- monitor_test.go | 50 +++++++++++++---------- set.go | 14 ++++++- 4 files changed, 157 insertions(+), 56 deletions(-) create mode 100644 gen.go diff --git a/gen.go b/gen.go new file mode 100644 index 0000000..4ebcfbe --- /dev/null +++ b/gen.go @@ -0,0 +1,45 @@ +package nftables + +import ( + "encoding/binary" + "fmt" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type GenMsg struct { + ID uint32 + ProcPID uint32 + ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN +} + +var genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN) + +func genFromMsg(msg netlink.Message) (*GenMsg, error) { + if got, want := msg.Header.Type, genHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + + msgOut := &GenMsg{} + for ad.Next() { + switch ad.Type() { + case unix.NFTA_GEN_ID: + msgOut.ID = ad.Uint32() + case unix.NFTA_GEN_PROC_PID: + msgOut.ProcPID = ad.Uint32() + case unix.NFTA_GEN_PROC_NAME: + msgOut.ProcComm = ad.String() + default: + return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes()) + } + } + if err := ad.Err(); err != nil { + return nil, err + } + return msgOut, nil +} diff --git a/monitor.go b/monitor.go index 0400cc9..753eb11 100644 --- a/monitor.go +++ b/monitor.go @@ -116,9 +116,15 @@ const ( // nftables.MonitorEventTypeNewTable, you can access the corresponding table // details via Data.(*nftables.Table). type MonitorEvent struct { - Type MonitorEventType - Data any - Error error + Header netlink.Header + Type MonitorEventType + Data any + Error error +} + +type MonitorEvents struct { + GenerateBy *MonitorEvent + Changes []*MonitorEvent } const ( @@ -139,7 +145,7 @@ type Monitor struct { // mu covers eventCh and status mu sync.Mutex - eventCh chan *MonitorEvent + eventCh chan *MonitorEvents status int } @@ -147,7 +153,7 @@ type MonitorOption func(*Monitor) func WithMonitorEventBuffer(size int) MonitorOption { return func(monitor *Monitor) { - monitor.eventCh = make(chan *MonitorEvent, size) + monitor.eventCh = make(chan *MonitorEvents, size) } } @@ -177,7 +183,7 @@ func NewMonitor(opts ...MonitorOption) *Monitor { opt(monitor) } if monitor.eventCh == nil { - monitor.eventCh = make(chan *MonitorEvent) + monitor.eventCh = make(chan *MonitorEvents) } objects, ok := monitorFlags[monitor.action] if !ok { @@ -192,6 +198,8 @@ func NewMonitor(opts ...MonitorOption) *Monitor { } func (monitor *Monitor) monitor() { + changesEvents := make([]*MonitorEvent, 0, 2) + for { msgs, err := monitor.conn.Receive() if err != nil { @@ -199,13 +207,21 @@ func (monitor *Monitor) monitor() { // ignore the error that be closed break } else { - // any other errors will be send to user, and then to close eventCh + // any other errors will be sent to user, and then to close eventCh event := &MonitorEvent{ Type: MonitorEventTypeOOB, Data: nil, Error: err, } - monitor.eventCh <- event + + changesEvents = append(changesEvents, event) + + monitor.eventCh <- &MonitorEvents{ + GenerateBy: event, + Changes: changesEvents, + } + changesEvents = make([]*MonitorEvent, 0, 2) + break } } @@ -221,54 +237,76 @@ func (monitor *Monitor) monitor() { case unix.NFT_MSG_NEWTABLE, unix.NFT_MSG_DELTABLE: table, err := tableFromMsg(msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: table, - Error: err, + Type: MonitorEventType(msgType), + Data: table, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWCHAIN, unix.NFT_MSG_DELCHAIN: chain, err := chainFromMsg(msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: chain, - Error: err, + Type: MonitorEventType(msgType), + Data: chain, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWRULE, unix.NFT_MSG_DELRULE: rule, err := parseRuleFromMsg(msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: rule, - Error: err, + Type: MonitorEventType(msgType), + Data: rule, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWSET, unix.NFT_MSG_DELSET: set, err := setsFromMsg(msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: set, - Error: err, + Type: MonitorEventType(msgType), + Data: set, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWSETELEM, unix.NFT_MSG_DELSETELEM: elems, err := elementsFromMsg(uint8(TableFamilyUnspecified), msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: elems, - Error: err, + Type: MonitorEventType(msgType), + Data: elems, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWOBJ, unix.NFT_MSG_DELOBJ: obj, err := objFromMsg(msg, true) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: obj, - Error: err, + Type: MonitorEventType(msgType), + Data: obj, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) + case unix.NFT_MSG_NEWGEN: + gen, err := genFromMsg(msg) + event := &MonitorEvent{ + Type: MonitorEventType(msgType), + Data: gen, + Error: err, + Header: msg.Header, + } + + monitor.eventCh <- &MonitorEvents{ + GenerateBy: event, + Changes: changesEvents, + } + + changesEvents = make([]*MonitorEvent, 0, 2) } } } + monitor.mu.Lock() defer monitor.mu.Unlock() @@ -293,7 +331,7 @@ func (monitor *Monitor) Close() error { // calling Close on Monitor or encountering a netlink conn error while Receive. // Caller may receive a MonitorEventTypeOOB event which contains an error we didn't // handle, for now. -func (cc *Conn) AddMonitor(monitor *Monitor) (chan *MonitorEvent, error) { +func (cc *Conn) AddMonitor(monitor *Monitor) (chan *MonitorEvents, error) { conn, closer, err := cc.netlinkConn() if err != nil { return nil, err diff --git a/monitor_test.go b/monitor_test.go index f47b5de..3bab034 100644 --- a/monitor_test.go +++ b/monitor_test.go @@ -26,12 +26,14 @@ func ExampleNewMonitor() { log.Fatal(err) } for ev := range events { - log.Printf("ev: %+v, data = %T", ev, ev.Data) - switch ev.Type { - case nftables.MonitorEventTypeNewTable: - log.Printf("data = %+v", ev.Data.(*nftables.Table)) + log.Printf("ev: %+v, data = %T", ev, ev.Changes) - // …more cases if needed… + for _, change := range ev.Changes { + switch change.Type { + case nftables.MonitorEventTypeNewTable: + log.Printf("data = %+v", change.Data.(*nftables.Table)) + // …more cases if needed… + } } } } @@ -66,23 +68,27 @@ func TestMonitor(t *testing.T) { if !ok { return } - if event.Error != nil { - err = fmt.Errorf("monitor err: %s", event.Error) - return - } - switch event.Type { - case nftables.MonitorEventTypeNewTable: - gotTable = event.Data.(*nftables.Table) - atomic.AddInt32(&count, 1) - case nftables.MonitorEventTypeNewChain: - gotChain = event.Data.(*nftables.Chain) - atomic.AddInt32(&count, 1) - case nftables.MonitorEventTypeNewRule: - gotRule = event.Data.(*nftables.Rule) - atomic.AddInt32(&count, 1) - } - if atomic.LoadInt32(&count) == 3 { - return + + for _, change := range event.Changes { + if change.Error != nil { + err = fmt.Errorf("monitor err: %s", change.Error) + return + } + + switch change.Type { + case nftables.MonitorEventTypeNewTable: + gotTable = change.Data.(*nftables.Table) + atomic.AddInt32(&count, 1) + case nftables.MonitorEventTypeNewChain: + gotChain = change.Data.(*nftables.Chain) + atomic.AddInt32(&count, 1) + case nftables.MonitorEventTypeNewRule: + gotRule = change.Data.(*nftables.Rule) + atomic.AddInt32(&count, 1) + } + if atomic.LoadInt32(&count) == 3 { + return + } } } }() diff --git a/set.go b/set.go index d5afff3..602f77f 100644 --- a/set.go +++ b/set.go @@ -268,6 +268,12 @@ type Set struct { KeyByteOrder binaryutil.ByteOrder } +type SetElementsInfo struct { + TableName string + SetName string + Elements []SetElement +} + // SetElement represents a data point within a set. type SetElement struct { Key []byte @@ -797,10 +803,16 @@ func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) { } ad.ByteOrder = binary.BigEndian + var info = &SetElementsInfo{} var elements []SetElement for ad.Next() { b := ad.Bytes() - if ad.Type() == unix.NFTA_SET_ELEM_LIST_ELEMENTS { + switch ad.Type() { + case unix.NFTA_SET_ELEM_LIST_TABLE: + info.TableName = ad.String() + case unix.NFTA_SET_ELEM_LIST_SET: + info.SetName = ad.String() + case unix.NFTA_SET_ELEM_LIST_ELEMENTS: ad, err := netlink.NewAttributeDecoder(b) if err != nil { return nil, err From afa496e5e9aa605b04020844bedbb87f1e387d69 Mon Sep 17 00:00:00 2001 From: ignatella Date: Mon, 4 Nov 2024 17:23:43 +0100 Subject: [PATCH 2/5] Add: generational monitor --- monitor.go | 41 +++++++++++++++++++++++++++++++---------- monitor_test.go | 4 ++-- set.go | 12 +----------- 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/monitor.go b/monitor.go index 753eb11..bcc0f9d 100644 --- a/monitor.go +++ b/monitor.go @@ -123,8 +123,8 @@ type MonitorEvent struct { } type MonitorEvents struct { - GenerateBy *MonitorEvent - Changes []*MonitorEvent + GeneratedBy *MonitorEvent + Changes []*MonitorEvent } const ( @@ -198,7 +198,7 @@ func NewMonitor(opts ...MonitorOption) *Monitor { } func (monitor *Monitor) monitor() { - changesEvents := make([]*MonitorEvent, 0, 2) + changesEvents := make([]*MonitorEvent, 0) for { msgs, err := monitor.conn.Receive() @@ -217,10 +217,10 @@ func (monitor *Monitor) monitor() { changesEvents = append(changesEvents, event) monitor.eventCh <- &MonitorEvents{ - GenerateBy: event, - Changes: changesEvents, + GeneratedBy: event, + Changes: changesEvents, } - changesEvents = make([]*MonitorEvent, 0, 2) + changesEvents = make([]*MonitorEvent, 0) break } @@ -298,11 +298,11 @@ func (monitor *Monitor) monitor() { } monitor.eventCh <- &MonitorEvents{ - GenerateBy: event, - Changes: changesEvents, + GeneratedBy: event, + Changes: changesEvents, } - changesEvents = make([]*MonitorEvent, 0, 2) + changesEvents = make([]*MonitorEvent, 0) } } } @@ -331,7 +331,28 @@ func (monitor *Monitor) Close() error { // calling Close on Monitor or encountering a netlink conn error while Receive. // Caller may receive a MonitorEventTypeOOB event which contains an error we didn't // handle, for now. -func (cc *Conn) AddMonitor(monitor *Monitor) (chan *MonitorEvents, error) { +func (cc *Conn) AddMonitor(monitor *Monitor) (chan *MonitorEvent, error) { + generationalEventCh, err := cc.AddGenerationalMonitor(monitor) + + if err != nil { + return nil, err + } + + eventCh := make(chan *MonitorEvent) + + go func() { + defer close(eventCh) + for monitorEvents := range generationalEventCh { + for _, event := range monitorEvents.Changes { + eventCh <- event + } + } + }() + + return eventCh, nil +} + +func (cc *Conn) AddGenerationalMonitor(monitor *Monitor) (chan *MonitorEvents, error) { conn, closer, err := cc.netlinkConn() if err != nil { return nil, err diff --git a/monitor_test.go b/monitor_test.go index 3bab034..155e13c 100644 --- a/monitor_test.go +++ b/monitor_test.go @@ -21,7 +21,7 @@ func ExampleNewMonitor() { mon := nftables.NewMonitor() defer mon.Close() - events, err := conn.AddMonitor(mon) + events, err := conn.AddGenerationalMonitor(mon) if err != nil { log.Fatal(err) } @@ -49,7 +49,7 @@ func TestMonitor(t *testing.T) { // default to monitor all monitor := nftables.NewMonitor() - events, err := c.AddMonitor(monitor) + events, err := c.AddGenerationalMonitor(monitor) if err != nil { t.Fatal(err) } diff --git a/set.go b/set.go index 602f77f..45faf20 100644 --- a/set.go +++ b/set.go @@ -268,12 +268,6 @@ type Set struct { KeyByteOrder binaryutil.ByteOrder } -type SetElementsInfo struct { - TableName string - SetName string - Elements []SetElement -} - // SetElement represents a data point within a set. type SetElement struct { Key []byte @@ -803,15 +797,10 @@ func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) { } ad.ByteOrder = binary.BigEndian - var info = &SetElementsInfo{} var elements []SetElement for ad.Next() { b := ad.Bytes() switch ad.Type() { - case unix.NFTA_SET_ELEM_LIST_TABLE: - info.TableName = ad.String() - case unix.NFTA_SET_ELEM_LIST_SET: - info.SetName = ad.String() case unix.NFTA_SET_ELEM_LIST_ELEMENTS: ad, err := netlink.NewAttributeDecoder(b) if err != nil { @@ -829,6 +818,7 @@ func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) { } } } + return elements, nil } From a77a91fb4661b811dee6c0df79ce6b26ae160aaf Mon Sep 17 00:00:00 2001 From: ignatella Date: Mon, 4 Nov 2024 18:08:39 +0100 Subject: [PATCH 3/5] Fix: remove not pr-related changes --- set.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/set.go b/set.go index 45faf20..d5afff3 100644 --- a/set.go +++ b/set.go @@ -800,8 +800,7 @@ func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) { var elements []SetElement for ad.Next() { b := ad.Bytes() - switch ad.Type() { - case unix.NFTA_SET_ELEM_LIST_ELEMENTS: + if ad.Type() == unix.NFTA_SET_ELEM_LIST_ELEMENTS { ad, err := netlink.NewAttributeDecoder(b) if err != nil { return nil, err @@ -818,7 +817,6 @@ func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) { } } } - return elements, nil } From 70148431616bcc74c0c68428c3ae1dbbfaefe126 Mon Sep 17 00:00:00 2001 From: ignatella Date: Mon, 4 Nov 2024 18:49:25 +0100 Subject: [PATCH 4/5] Add: proc comm test --- monitor.go | 5 ++--- monitor_test.go | 11 ++++++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/monitor.go b/monitor.go index bcc0f9d..7bd88f3 100644 --- a/monitor.go +++ b/monitor.go @@ -198,7 +198,7 @@ func NewMonitor(opts ...MonitorOption) *Monitor { } func (monitor *Monitor) monitor() { - changesEvents := make([]*MonitorEvent, 0) + var changesEvents []*MonitorEvent for { msgs, err := monitor.conn.Receive() @@ -220,7 +220,7 @@ func (monitor *Monitor) monitor() { GeneratedBy: event, Changes: changesEvents, } - changesEvents = make([]*MonitorEvent, 0) + changesEvents = nil break } @@ -333,7 +333,6 @@ func (monitor *Monitor) Close() error { // handle, for now. func (cc *Conn) AddMonitor(monitor *Monitor) (chan *MonitorEvent, error) { generationalEventCh, err := cc.AddGenerationalMonitor(monitor) - if err != nil { return nil, err } diff --git a/monitor_test.go b/monitor_test.go index 155e13c..9bccff4 100644 --- a/monitor_test.go +++ b/monitor_test.go @@ -4,6 +4,8 @@ import ( "fmt" "log" "net" + "os" + "path/filepath" "sync" "sync/atomic" "testing" @@ -46,7 +48,6 @@ func TestMonitor(t *testing.T) { // Clear all rules at the beginning + end of the test. c.FlushRuleset() defer c.FlushRuleset() - // default to monitor all monitor := nftables.NewMonitor() events, err := c.AddGenerationalMonitor(monitor) @@ -69,6 +70,14 @@ func TestMonitor(t *testing.T) { return } + genMsg := event.GeneratedBy.Data.(*nftables.GenMsg) + fileName := filepath.Base(os.Args[0]) + + if genMsg.ProcComm != fileName { + err = fmt.Errorf("procComm: %s, want: %s", genMsg.ProcComm, fileName) + return + } + for _, change := range event.Changes { if change.Error != nil { err = fmt.Errorf("monitor err: %s", change.Error) From ffa09824d5d944e7274444b75703659238ee564a Mon Sep 17 00:00:00 2001 From: ignatella Date: Mon, 4 Nov 2024 19:54:56 +0100 Subject: [PATCH 5/5] Add: test monitor error --- monitor_test.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/monitor_test.go b/monitor_test.go index 9bccff4..d690432 100644 --- a/monitor_test.go +++ b/monitor_test.go @@ -61,6 +61,7 @@ func TestMonitor(t *testing.T) { var gotRule *nftables.Rule wg := sync.WaitGroup{} wg.Add(1) + var errMonitor error go func() { defer wg.Done() count := int32(0) @@ -74,13 +75,13 @@ func TestMonitor(t *testing.T) { fileName := filepath.Base(os.Args[0]) if genMsg.ProcComm != fileName { - err = fmt.Errorf("procComm: %s, want: %s", genMsg.ProcComm, fileName) + errMonitor = fmt.Errorf("procComm: %s, want: %s", genMsg.ProcComm, fileName) return } for _, change := range event.Changes { if change.Error != nil { - err = fmt.Errorf("monitor err: %s", change.Error) + errMonitor = fmt.Errorf("monitor err: %s", change.Error) return } @@ -141,7 +142,13 @@ func TestMonitor(t *testing.T) { if err := c.Flush(); err != nil { t.Fatal(err) } + wg.Wait() + + if errMonitor != nil { + t.Fatal("monitor err", errMonitor) + } + if gotTable.Family != nat.Family || gotTable.Name != nat.Name { t.Fatal("no want table", gotTable.Family, gotTable.Name) }