diff --git a/nftables_test.go b/nftables_test.go index 606911b..584d6f9 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -5921,6 +5921,68 @@ func TestSet4(t *testing.T) { } } +func TestSetComment(t *testing.T) { + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft flush ruleset + []byte("\x00\x00\x00\x00"), + // nft add table inet filter + []byte("\x01\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add set inet filter setname { type ipv4_addr\; comment \"test comment\" \; } + []byte("\x01\x00\x00\x00\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x73\x65\x74\x6e\x61\x6d\x65\x00\x08\x00\x03\x00\x00\x00\x00\x00\x08\x00\x04\x00\x00\x00\x00\x07\x08\x00\x05\x00\x00\x00\x00\x04\x08\x00\x0a\x00\x00\x00\x00\x02\x13\x00\x0d\x00\x07\x0d\x74\x65\x73\x74\x20\x63\x6f\x6d\x6d\x65\x6e\x74\x00\x00"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + + c, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + want = want[1:] + } + return req, nil + })) + if err != nil { + t.Fatal(err) + } + + c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyINet, + Name: "filter", + }) + + if err := c.AddSet(&nftables.Set{ + ID: 2, + Table: filter, + Name: "setname", + KeyType: nftables.TypeIPAddr, + Comment: "test comment", + }, nil); err != nil { + t.Fatal(err) + } + + if err := c.Flush(); err != nil { + t.Fatal(err) + } +} + func TestMasq(t *testing.T) { tests := []struct { name string diff --git a/set.go b/set.go index d5afff3..e2f58fe 100644 --- a/set.go +++ b/set.go @@ -266,6 +266,7 @@ type Set struct { // Either host (binaryutil.NativeEndian) or big (binaryutil.BigEndian) endian as per // https://git.netfilter.org/nftables/tree/include/datatype.h?id=d486c9e626405e829221b82d7355558005b26d8a#n109 KeyByteOrder binaryutil.ByteOrder + Comment string } // SetElement represents a data point within a set. @@ -598,6 +599,10 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { userData = userdata.AppendUint32(userData, userdata.NFTNL_UDATA_SET_MERGE_ELEMENTS, 1) } + if len(s.Comment) != 0 { + userData = userdata.AppendString(userData, userdata.NFTNL_UDATA_SET_COMMENT, s.Comment) + } + if len(userData) > 0 { tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: userData}) }