diff --git a/gen.go b/gen.go new file mode 100644 index 0000000..4ebcfbe --- /dev/null +++ b/gen.go @@ -0,0 +1,45 @@ +package nftables + +import ( + "encoding/binary" + "fmt" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type GenMsg struct { + ID uint32 + ProcPID uint32 + ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN +} + +var genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN) + +func genFromMsg(msg netlink.Message) (*GenMsg, error) { + if got, want := msg.Header.Type, genHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + + msgOut := &GenMsg{} + for ad.Next() { + switch ad.Type() { + case unix.NFTA_GEN_ID: + msgOut.ID = ad.Uint32() + case unix.NFTA_GEN_PROC_PID: + msgOut.ProcPID = ad.Uint32() + case unix.NFTA_GEN_PROC_NAME: + msgOut.ProcComm = ad.String() + default: + return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes()) + } + } + if err := ad.Err(); err != nil { + return nil, err + } + return msgOut, nil +} diff --git a/monitor.go b/monitor.go index 0400cc9..753eb11 100644 --- a/monitor.go +++ b/monitor.go @@ -116,9 +116,15 @@ const ( // nftables.MonitorEventTypeNewTable, you can access the corresponding table // details via Data.(*nftables.Table). type MonitorEvent struct { - Type MonitorEventType - Data any - Error error + Header netlink.Header + Type MonitorEventType + Data any + Error error +} + +type MonitorEvents struct { + GenerateBy *MonitorEvent + Changes []*MonitorEvent } const ( @@ -139,7 +145,7 @@ type Monitor struct { // mu covers eventCh and status mu sync.Mutex - eventCh chan *MonitorEvent + eventCh chan *MonitorEvents status int } @@ -147,7 +153,7 @@ type MonitorOption func(*Monitor) func WithMonitorEventBuffer(size int) MonitorOption { return func(monitor *Monitor) { - monitor.eventCh = make(chan *MonitorEvent, size) + monitor.eventCh = make(chan *MonitorEvents, size) } } @@ -177,7 +183,7 @@ func NewMonitor(opts ...MonitorOption) *Monitor { opt(monitor) } if monitor.eventCh == nil { - monitor.eventCh = make(chan *MonitorEvent) + monitor.eventCh = make(chan *MonitorEvents) } objects, ok := monitorFlags[monitor.action] if !ok { @@ -192,6 +198,8 @@ func NewMonitor(opts ...MonitorOption) *Monitor { } func (monitor *Monitor) monitor() { + changesEvents := make([]*MonitorEvent, 0, 2) + for { msgs, err := monitor.conn.Receive() if err != nil { @@ -199,13 +207,21 @@ func (monitor *Monitor) monitor() { // ignore the error that be closed break } else { - // any other errors will be send to user, and then to close eventCh + // any other errors will be sent to user, and then to close eventCh event := &MonitorEvent{ Type: MonitorEventTypeOOB, Data: nil, Error: err, } - monitor.eventCh <- event + + changesEvents = append(changesEvents, event) + + monitor.eventCh <- &MonitorEvents{ + GenerateBy: event, + Changes: changesEvents, + } + changesEvents = make([]*MonitorEvent, 0, 2) + break } } @@ -221,54 +237,76 @@ func (monitor *Monitor) monitor() { case unix.NFT_MSG_NEWTABLE, unix.NFT_MSG_DELTABLE: table, err := tableFromMsg(msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: table, - Error: err, + Type: MonitorEventType(msgType), + Data: table, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWCHAIN, unix.NFT_MSG_DELCHAIN: chain, err := chainFromMsg(msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: chain, - Error: err, + Type: MonitorEventType(msgType), + Data: chain, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWRULE, unix.NFT_MSG_DELRULE: rule, err := parseRuleFromMsg(msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: rule, - Error: err, + Type: MonitorEventType(msgType), + Data: rule, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWSET, unix.NFT_MSG_DELSET: set, err := setsFromMsg(msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: set, - Error: err, + Type: MonitorEventType(msgType), + Data: set, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWSETELEM, unix.NFT_MSG_DELSETELEM: elems, err := elementsFromMsg(uint8(TableFamilyUnspecified), msg) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: elems, - Error: err, + Type: MonitorEventType(msgType), + Data: elems, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) case unix.NFT_MSG_NEWOBJ, unix.NFT_MSG_DELOBJ: obj, err := objFromMsg(msg, true) event := &MonitorEvent{ - Type: MonitorEventType(msgType), - Data: obj, - Error: err, + Type: MonitorEventType(msgType), + Data: obj, + Error: err, + Header: msg.Header, } - monitor.eventCh <- event + changesEvents = append(changesEvents, event) + case unix.NFT_MSG_NEWGEN: + gen, err := genFromMsg(msg) + event := &MonitorEvent{ + Type: MonitorEventType(msgType), + Data: gen, + Error: err, + Header: msg.Header, + } + + monitor.eventCh <- &MonitorEvents{ + GenerateBy: event, + Changes: changesEvents, + } + + changesEvents = make([]*MonitorEvent, 0, 2) } } } + monitor.mu.Lock() defer monitor.mu.Unlock() @@ -293,7 +331,7 @@ func (monitor *Monitor) Close() error { // calling Close on Monitor or encountering a netlink conn error while Receive. // Caller may receive a MonitorEventTypeOOB event which contains an error we didn't // handle, for now. -func (cc *Conn) AddMonitor(monitor *Monitor) (chan *MonitorEvent, error) { +func (cc *Conn) AddMonitor(monitor *Monitor) (chan *MonitorEvents, error) { conn, closer, err := cc.netlinkConn() if err != nil { return nil, err diff --git a/monitor_test.go b/monitor_test.go index f47b5de..3bab034 100644 --- a/monitor_test.go +++ b/monitor_test.go @@ -26,12 +26,14 @@ func ExampleNewMonitor() { log.Fatal(err) } for ev := range events { - log.Printf("ev: %+v, data = %T", ev, ev.Data) - switch ev.Type { - case nftables.MonitorEventTypeNewTable: - log.Printf("data = %+v", ev.Data.(*nftables.Table)) + log.Printf("ev: %+v, data = %T", ev, ev.Changes) - // …more cases if needed… + for _, change := range ev.Changes { + switch change.Type { + case nftables.MonitorEventTypeNewTable: + log.Printf("data = %+v", change.Data.(*nftables.Table)) + // …more cases if needed… + } } } } @@ -66,23 +68,27 @@ func TestMonitor(t *testing.T) { if !ok { return } - if event.Error != nil { - err = fmt.Errorf("monitor err: %s", event.Error) - return - } - switch event.Type { - case nftables.MonitorEventTypeNewTable: - gotTable = event.Data.(*nftables.Table) - atomic.AddInt32(&count, 1) - case nftables.MonitorEventTypeNewChain: - gotChain = event.Data.(*nftables.Chain) - atomic.AddInt32(&count, 1) - case nftables.MonitorEventTypeNewRule: - gotRule = event.Data.(*nftables.Rule) - atomic.AddInt32(&count, 1) - } - if atomic.LoadInt32(&count) == 3 { - return + + for _, change := range event.Changes { + if change.Error != nil { + err = fmt.Errorf("monitor err: %s", change.Error) + return + } + + switch change.Type { + case nftables.MonitorEventTypeNewTable: + gotTable = change.Data.(*nftables.Table) + atomic.AddInt32(&count, 1) + case nftables.MonitorEventTypeNewChain: + gotChain = change.Data.(*nftables.Chain) + atomic.AddInt32(&count, 1) + case nftables.MonitorEventTypeNewRule: + gotRule = change.Data.(*nftables.Rule) + atomic.AddInt32(&count, 1) + } + if atomic.LoadInt32(&count) == 3 { + return + } } } }() diff --git a/set.go b/set.go index d5afff3..602f77f 100644 --- a/set.go +++ b/set.go @@ -268,6 +268,12 @@ type Set struct { KeyByteOrder binaryutil.ByteOrder } +type SetElementsInfo struct { + TableName string + SetName string + Elements []SetElement +} + // SetElement represents a data point within a set. type SetElement struct { Key []byte @@ -797,10 +803,16 @@ func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) { } ad.ByteOrder = binary.BigEndian + var info = &SetElementsInfo{} var elements []SetElement for ad.Next() { b := ad.Bytes() - if ad.Type() == unix.NFTA_SET_ELEM_LIST_ELEMENTS { + switch ad.Type() { + case unix.NFTA_SET_ELEM_LIST_TABLE: + info.TableName = ad.String() + case unix.NFTA_SET_ELEM_LIST_SET: + info.SetName = ad.String() + case unix.NFTA_SET_ELEM_LIST_ELEMENTS: ad, err := netlink.NewAttributeDecoder(b) if err != nil { return nil, err