diff --git a/gen.go b/gen.go new file mode 100644 index 0000000..4ebcfbe --- /dev/null +++ b/gen.go @@ -0,0 +1,45 @@ +package nftables + +import ( + "encoding/binary" + "fmt" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type GenMsg struct { + ID uint32 + ProcPID uint32 + ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN +} + +var genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN) + +func genFromMsg(msg netlink.Message) (*GenMsg, error) { + if got, want := msg.Header.Type, genHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + + msgOut := &GenMsg{} + for ad.Next() { + switch ad.Type() { + case unix.NFTA_GEN_ID: + msgOut.ID = ad.Uint32() + case unix.NFTA_GEN_PROC_PID: + msgOut.ProcPID = ad.Uint32() + case unix.NFTA_GEN_PROC_NAME: + msgOut.ProcComm = ad.String() + default: + return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes()) + } + } + if err := ad.Err(); err != nil { + return nil, err + } + return msgOut, nil +} diff --git a/monitor.go b/monitor.go index 0400cc9..7a25e3b 100644 --- a/monitor.go +++ b/monitor.go @@ -116,9 +116,15 @@ const ( // nftables.MonitorEventTypeNewTable, you can access the corresponding table // details via Data.(*nftables.Table). type MonitorEvent struct { - Type MonitorEventType - Data any - Error error + Header netlink.Header + Type MonitorEventType + Data any + Error error +} + +type MonitorEvents struct { + GeneratedBy *MonitorEvent + Changes []*MonitorEvent } const ( @@ -139,7 +145,7 @@ type Monitor struct { // mu covers eventCh and status mu sync.Mutex - eventCh chan *MonitorEvent + eventCh chan *MonitorEvents status int } @@ -147,7 +153,7 @@ type MonitorOption func(*Monitor) func WithMonitorEventBuffer(size int) MonitorOption { return func(monitor *Monitor) { - monitor.eventCh = make(chan *MonitorEvent, size) + monitor.eventCh = make(chan *MonitorEvents, size) } } @@ -177,7 +183,7 @@ func NewMonitor(opts ...MonitorOption) *Monitor { opt(monitor) } if monitor.eventCh == nil { - monitor.eventCh = make(chan *MonitorEvent) + monitor.eventCh = make(chan *MonitorEvents) } objects, ok := monitorFlags[monitor.action] if !ok { @@ -192,6 +198,8 @@ func NewMonitor(opts ...MonitorOption) *Monitor { } func (monitor *Monitor) monitor() { + var changesEvents []*MonitorEvent + for { msgs, err := monitor.conn.Receive() if err != nil { @@ -199,13 +207,21 @@ func (monitor *Monitor) monitor() { // ignore the error that be closed break } else { - // any other errors will be send to user, and then to close eventCh + // any other errors will be sent to user, and then to close eventCh event := &MonitorEvent{ Type: MonitorEventTypeOOB, Data: nil, Error: err, } - monitor.eventCh <- event + + changesEvents = append(changesEvents, event) + + monitor.eventCh <- &MonitorEvents{ + GeneratedBy: event, + Changes: changesEvents, + } + changesEvents = nil + break } } @@ -221,54 +237,76 @@ func (monitor *Monitor) monitor() { case unix.NFT_MSG_NEWTABLE, unix.NFT_MSG_DELTABLE: table, err := tableFromMsg(msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: table, - Error: err, + Type: MonitorEventType(msgType), + Data: table, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWCHAIN, unix.NFT_MSG_DELCHAIN: chain, err := chainFromMsg(msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: chain, - Error: err, + Type: MonitorEventType(msgType), + Data: chain, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWRULE, unix.NFT_MSG_DELRULE: rule, err := parseRuleFromMsg(msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: rule, - Error: err, + Type: MonitorEventType(msgType), + Data: rule, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWSET, unix.NFT_MSG_DELSET: set, err := setsFromMsg(msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: set, - Error: err, + Type: MonitorEventType(msgType), + Data: set, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWSETELEM, unix.NFT_MSG_DELSETELEM: elems, err := elementsFromMsg(uint8(TableFamilyUnspecified), msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: elems, - Error: err, + Type: MonitorEventType(msgType), + Data: elems, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWOBJ, unix.NFT_MSG_DELOBJ: obj, err := objFromMsg(msg, true) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: obj, - Error: err, + Type: MonitorEventType(msgType), + Data: obj, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) + case unix.NFT_MSG_NEWGEN: + gen, err := genFromMsg(msg) + event := &MonitorEvent{ + Type: MonitorEventType(msgType), + Data: gen, + Error: err, + Header: msg.Header, + } + + monitor.eventCh <- &MonitorEvents{ + GeneratedBy: event, + Changes: changesEvents, + } + + changesEvents = nil } } } + monitor.mu.Lock() defer monitor.mu.Unlock() @@ -294,6 +332,26 @@ func (monitor *Monitor) Close() error { // Caller may receive a MonitorEventTypeOOB event which contains an error we didn't // handle, for now. func (cc *Conn) AddMonitor(monitor *Monitor) (chan *MonitorEvent, error) { + generationalEventCh, err := cc.AddGenerationalMonitor(monitor) + if err != nil { + return nil, err + } + + eventCh := make(chan *MonitorEvent) + + go func() { + defer close(eventCh) + for monitorEvents := range generationalEventCh { + for _, event := range monitorEvents.Changes { + eventCh <- event + } + } + }() + + return eventCh, nil +} + +func (cc *Conn) AddGenerationalMonitor(monitor *Monitor) (chan *MonitorEvents, error) { conn, closer, err := cc.netlinkConn() if err != nil { return nil, err diff --git a/monitor_test.go b/monitor_test.go index f47b5de..d690432 100644 --- a/monitor_test.go +++ b/monitor_test.go @@ -4,6 +4,8 @@ import ( "fmt" "log" "net" + "os" + "path/filepath" "sync" "sync/atomic" "testing" @@ -21,17 +23,19 @@ func ExampleNewMonitor() { mon := nftables.NewMonitor() defer mon.Close() - events, err := conn.AddMonitor(mon) + events, err := conn.AddGenerationalMonitor(mon) if err != nil { log.Fatal(err) } for ev := range events { - log.Printf("ev: %+v, data = %T", ev, ev.Data) - switch ev.Type { - case nftables.MonitorEventTypeNewTable: - log.Printf("data = %+v", ev.Data.(*nftables.Table)) + log.Printf("ev: %+v, data = %T", ev, ev.Changes) - // …more cases if needed… + for _, change := range ev.Changes { + switch change.Type { + case nftables.MonitorEventTypeNewTable: + log.Printf("data = %+v", change.Data.(*nftables.Table)) + // …more cases if needed… + } } } } @@ -44,10 +48,9 @@ func TestMonitor(t *testing.T) { // Clear all rules at the beginning + end of the test. c.FlushRuleset() defer c.FlushRuleset() - // default to monitor all monitor := nftables.NewMonitor() - events, err := c.AddMonitor(monitor) + events, err := c.AddGenerationalMonitor(monitor) if err != nil { t.Fatal(err) } @@ -58,6 +61,7 @@ func TestMonitor(t *testing.T) { var gotRule *nftables.Rule wg := sync.WaitGroup{} wg.Add(1) + var errMonitor error go func() { defer wg.Done() count := int32(0) @@ -66,23 +70,35 @@ func TestMonitor(t *testing.T) { if !ok { return } - if event.Error != nil { - err = fmt.Errorf("monitor err: %s", event.Error) + + genMsg := event.GeneratedBy.Data.(*nftables.GenMsg) + fileName := filepath.Base(os.Args[0]) + + if genMsg.ProcComm != fileName { + errMonitor = fmt.Errorf("procComm: %s, want: %s", genMsg.ProcComm, fileName) return } - switch event.Type { - case nftables.MonitorEventTypeNewTable: - gotTable = event.Data.(*nftables.Table) - atomic.AddInt32(&count, 1) - case nftables.MonitorEventTypeNewChain: - gotChain = event.Data.(*nftables.Chain) - atomic.AddInt32(&count, 1) - case nftables.MonitorEventTypeNewRule: - gotRule = event.Data.(*nftables.Rule) - atomic.AddInt32(&count, 1) - } - if atomic.LoadInt32(&count) == 3 { - return + + for _, change := range event.Changes { + if change.Error != nil { + errMonitor = fmt.Errorf("monitor err: %s", change.Error) + return + } + + switch change.Type { + case nftables.MonitorEventTypeNewTable: + gotTable = change.Data.(*nftables.Table) + atomic.AddInt32(&count, 1) + case nftables.MonitorEventTypeNewChain: + gotChain = change.Data.(*nftables.Chain) + atomic.AddInt32(&count, 1) + case nftables.MonitorEventTypeNewRule: + gotRule = change.Data.(*nftables.Rule) + atomic.AddInt32(&count, 1) + } + if atomic.LoadInt32(&count) == 3 { + return + } } } }() @@ -126,7 +142,13 @@ func TestMonitor(t *testing.T) { if err := c.Flush(); err != nil { t.Fatal(err) } + wg.Wait() + + if errMonitor != nil { + t.Fatal("monitor err", errMonitor) + } + if gotTable.Family != nat.Family || gotTable.Name != nat.Name { t.Fatal("no want table", gotTable.Family, gotTable.Name) }