diff --git a/chain.go b/chain.go index 4f4c0a5..f1853cf 100644 --- a/chain.go +++ b/chain.go @@ -140,7 +140,7 @@ func (cc *Conn) AddChain(c *Chain) *Chain { {Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")}, })...) } - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -161,7 +161,7 @@ func (cc *Conn) DelChain(c *Chain) { {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN), Flags: netlink.Request | netlink.Acknowledge, @@ -179,7 +179,7 @@ func (cc *Conn) FlushChain(c *Chain) { {Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge, diff --git a/conn.go b/conn.go index 99baa2d..d974b80 100644 --- a/conn.go +++ b/conn.go @@ -41,7 +41,7 @@ type Conn struct { lasting bool // establish a lasting connection to be used across multiple netlink operations. mu sync.Mutex // protects the following state - messages []netlink.Message + messages []netlinkMessage err error nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. sockOptions []SockOption @@ -49,6 +49,12 @@ type Conn struct { allocatedIDs uint32 } +type netlinkMessage struct { + Header netlink.Header + Data []byte + rule *Rule +} + // ConnOption is an option to change the behavior of the nftables Conn returned by Open. type ConnOption func(*Conn) @@ -268,6 +274,11 @@ func (cc *Conn) Flush() error { } else if replyIndex < len(cc.messages) { msg := messages[replyIndex+1] if msg.Header.Sequence == reply.Header.Sequence && msg.Header.Type == reply.Header.Type { + // The only messages which set the echo flag are rule create messages. + err := cc.messages[replyIndex].rule.handleCreateReply(reply) + if err != nil { + errs = errors.Join(errs, err) + } replyIndex++ for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 { replyIndex++ @@ -309,7 +320,7 @@ func (cc *Conn) Flush() error { func (cc *Conn) FlushRuleset() { cc.mu.Lock() defer cc.mu.Unlock() - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -368,26 +379,30 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte { return b } -func batch(messages []netlink.Message) []netlink.Message { - batch := []netlink.Message{ - { - Header: netlink.Header{ - Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), - Flags: netlink.Request, - }, - Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), +func batch(messages []netlinkMessage) []netlink.Message { + batch := make([]netlink.Message, len(messages)+2) + batch[0] = netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), + Flags: netlink.Request, }, + Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), } - batch = append(batch, messages...) + for i, msg := range messages { + batch[i+1] = netlink.Message{ + Header: msg.Header, + Data: msg.Data, + } + } - batch = append(batch, netlink.Message{ + batch[len(messages)+1] = netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Flags: netlink.Request, }, Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), - }) + } return batch } diff --git a/flowtable.go b/flowtable.go index 93dbcb5..a35712f 100644 --- a/flowtable.go +++ b/flowtable.go @@ -142,7 +142,7 @@ func (cc *Conn) AddFlowtable(f *Flowtable) *Flowtable { {Type: unix.NLA_F_NESTED | NFTA_FLOWTABLE_HOOK, Data: cc.marshalAttr(hookAttr)}, })...) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_NEWFLOWTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -162,7 +162,7 @@ func (cc *Conn) DelFlowtable(f *Flowtable) { {Type: NFTA_FLOWTABLE_NAME, Data: []byte(f.Name)}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_DELFLOWTABLE), Flags: netlink.Request | netlink.Acknowledge, diff --git a/internal/nftest/nftest.go b/internal/nftest/nftest.go index 8d5b496..509bac3 100644 --- a/internal/nftest/nftest.go +++ b/internal/nftest/nftest.go @@ -8,7 +8,9 @@ import ( "testing" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" ) // Recorder provides an nftables connection that does not send to the Linux @@ -21,6 +23,7 @@ type Recorder struct { // Conn opens an nftables connection that records netlink messages into the // Recorder. func (r *Recorder) Conn() (*nftables.Conn, error) { + nextHandle := uint64(1) return nftables.New(nftables.WithTestDial( func(req []netlink.Message) ([]netlink.Message, error) { r.requests = append(r.requests, req...) @@ -30,6 +33,14 @@ func (r *Recorder) Conn() (*nftables.Conn, error) { for _, msg := range req { if msg.Header.Flags&netlink.Echo != 0 { data := append([]byte{}, msg.Data...) + switch msg.Header.Type { + case netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE): + attrs, _ := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(nextHandle)}, + }) + nextHandle++ + data = append(data, attrs...) + } replies = append(replies, netlink.Message{ Header: msg.Header, Data: data, diff --git a/nftables_test.go b/nftables_test.go index 6e9ecfe..fd0fa1a 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -80,6 +80,7 @@ func linediff(a, b string) string { } func expectMessages(t *testing.T, want [][]byte) nftables.ConnOption { + nextHandle := uint64(1) return nftables.WithTestDial(func(req []netlink.Message) ([]netlink.Message, error) { var replies []netlink.Message for idx, msg := range req { @@ -103,6 +104,14 @@ func expectMessages(t *testing.T, want [][]byte) nftables.ConnOption { // Generate replies. if msg.Header.Flags&netlink.Echo != 0 { data := append([]byte{}, msg.Data...) + switch msg.Header.Type { + case netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE): + attrs, _ := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(nextHandle)}, + }) + nextHandle++ + data = append(data, attrs...) + } replies = append(replies, netlink.Message{ Header: msg.Header, Data: data, @@ -316,7 +325,7 @@ func TestRuleHandle(t *testing.T) { } for _, tt := range tests { - for _, afterFlush := range []bool{false} { + for _, afterFlush := range []bool{false, true} { flushName := map[bool]string{ false: "-before-flush", true: "-after-flush", diff --git a/obj.go b/obj.go index 634931b..65c4402 100644 --- a/obj.go +++ b/obj.go @@ -124,7 +124,7 @@ func (cc *Conn) AddObj(o Obj) Obj { attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: data}) } - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -146,7 +146,7 @@ func (cc *Conn) DeleteObject(o Obj) { data := cc.marshalAttr(attrs) data = append(data, cc.marshalAttr([]netlink.Attribute{{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA}})...) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ), Flags: netlink.Request | netlink.Acknowledge, diff --git a/rule.go b/rule.go index 68da81f..65c6562 100644 --- a/rule.go +++ b/rule.go @@ -48,10 +48,13 @@ const ( type Rule struct { Table *Table Chain *Chain - // Handle identifies an existing Rule. + // Handle identifies an existing Rule. For a new Rule, this field is set + // during the Flush() in which the rule is committed. Make sure to not access + // this field concurrently with this Flush() to avoid data races. Handle uint64 // ID is an identifier for a new Rule, which is assigned by // AddRule/InsertRule, and only valid before the rule is committed by Flush(). + // The field is set to 0 during Flush(). ID uint32 // Position can be set to the Handle of another Rule to insert the new Rule // before (InsertRule) or after (AddRule) the existing rule. @@ -171,11 +174,14 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { } var flags netlink.HeaderFlags + var ruleRef *Rule switch op { case operationAdd: flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo | netlink.Append + ruleRef = r case operationInsert: flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo + ruleRef = r case operationReplace: flags = netlink.Request | netlink.Acknowledge | netlink.Replace } @@ -190,17 +196,42 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { })...) } - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: newRuleHeaderType, Flags: flags, }, Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), + rule: ruleRef, }) return r } +func (r *Rule) handleCreateReply(reply netlink.Message) error { + ad, err := netlink.NewAttributeDecoder(reply.Data[4:]) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + var handle uint64 + for ad.Next() { + switch ad.Type() { + case unix.NFTA_RULE_HANDLE: + handle = ad.Uint64() + } + } + if ad.Err() != nil { + return ad.Err() + } + if handle == 0 { + return fmt.Errorf("missing rule handle in create reply") + } + r.Handle = handle + r.ID = 0 + return nil +} + func (cc *Conn) ReplaceRule(r *Rule) *Rule { return cc.newRule(r, operationReplace) } @@ -247,7 +278,7 @@ func (cc *Conn) DelRule(r *Rule) error { } flags := netlink.Request | netlink.Acknowledge - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: delRuleHeaderType, Flags: flags, diff --git a/set.go b/set.go index 431191e..5818ca7 100644 --- a/set.go +++ b/set.go @@ -506,7 +506,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem}, } - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -680,7 +680,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data}) } - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -700,7 +700,7 @@ func (cc *Conn) DelSet(s *Set) { {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET), Flags: netlink.Request | netlink.Acknowledge, @@ -717,7 +717,7 @@ func (cc *Conn) FlushSet(s *Set) { {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge, diff --git a/set_test.go b/set_test.go index 65a8e00..dd30f45 100644 --- a/set_test.go +++ b/set_test.go @@ -254,7 +254,10 @@ func TestMarshalSet(t *testing.T) { } msg := c.messages[connMsgSetIdx] - nset, err := setsFromMsg(msg) + nset, err := setsFromMsg(netlink.Message{ + Header: msg.Header, + Data: msg.Data, + }) if err != nil { t.Fatalf("setsFromMsg() error: %+v", err) } diff --git a/table.go b/table.go index 79e486b..3686b7a 100644 --- a/table.go +++ b/table.go @@ -57,7 +57,7 @@ func (cc *Conn) DelTable(t *Table) { {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge, @@ -73,7 +73,7 @@ func (cc *Conn) addTable(t *Table, flag netlink.HeaderFlags) *Table { {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE), Flags: netlink.Request | netlink.Acknowledge | flag, @@ -103,7 +103,7 @@ func (cc *Conn) FlushTable(t *Table) { data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.messages = append(cc.messages, netlinkMessage{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge,