diff --git a/nftables_test.go b/nftables_test.go index 7489cd1..1e6cf84 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1677,6 +1677,56 @@ func TestCreateUseNamedSet(t *testing.T) { } } +func TestIP6SetAddElements(t *testing.T) { + + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv6, + Name: "filter", + }) + portSet := &nftables.Set{ + Table: filter, + Name: "ports", + KeyType: nftables.TypeInetService, + } + if err := c.AddSet(portSet, nil); err != nil { + t.Errorf("c.AddSet(portSet) failed: %v", err) + } + if err := c.SetAddElements(portSet, []nftables.SetElement{ + {Key: binaryutil.BigEndian.PutUint16(22)}, + {Key: binaryutil.BigEndian.PutUint16(80)}, + }); err != nil { + t.Errorf("c.SetVal(portSet) failed: %v", err) + } + + if err := c.Flush(); err != nil { + t.Errorf("c.Flush() failed: %v", err) + } + + sets, err := c.GetSets(filter) + if err != nil { + t.Errorf("c.GetSets() failed: %v", err) + } + if len(sets) != 1 { + t.Fatalf("len(sets) = %d, want 1", len(sets)) + } + + elements, err := c.GetSetElements(sets[0]) + if err != nil { + t.Errorf("c.GetSetElements(portSet) failed: %v", err) + } + if len(elements) != 2 { + t.Fatalf("len(portSetElements) = %d, want 2", len(sets)) + } +} + func TestCreateDeleteNamedSet(t *testing.T) { // Create a new network namespace to test these operations, // and tear down the namespace at test completion. diff --git a/set.go b/set.go index 6fa1645..c670ad3 100644 --- a/set.go +++ b/set.go @@ -170,7 +170,7 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, }, - Data: append(extraHeader(unix.NFTA_SET_NAME, 0), cc.marshalAttr(elements)...), + Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), }) return nil @@ -559,6 +559,7 @@ func (cc *Conn) GetSets(t *Table) ([]*Set, error) { if err != nil { return nil, err } + s.Table = &Table{Name: t.Name, Use: t.Use, Flags: t.Flags, Family: t.Family} sets = append(sets, s) }