Compare commits

...

4 Commits

Author SHA1 Message Date
corpix 77210037da
Merge dd13cb1d03 into 1148f1a84f 2025-09-02 22:04:14 +00:00
Nikita Vorontsov 1148f1a84f
add DestroyTable and SetDestroyElements (#322)
These methods are like their DeleteTable and SetDeleteElements counterparts, but they do not return an error if the specified table/set does not exist.
2025-09-02 14:08:18 +02:00
Nick Garlis 3efc75f481
Add GetGen method to retrieve current generation ID (#325)
Add GetGen method to retrieve current generation ID

nftables uses generation IDs (gen IDs) for optimistic concurrency
control. This commit adds a GetGen method to expose current gen ID so
that users can retrieve it explicitly.

Typical usage:
  1. Call GetGen to retrieve current gen ID.
  2. Read the the current state.
  3. Send the batch along with the gen ID by calling Flush.

If the state changes before the flush, the kernel will reject the
batch, preventing stale writes.

- https://wiki.nftables.org/wiki-nftables/index.php/Portal:DeveloperDocs/nftables_internals#Batched_handlers
- https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen
- 3957a57201/net/netfilter/nfnetlink.c (L424)
2025-09-02 14:05:05 +02:00
Dmitry Moskowski dd13cb1d03 Replace %v with %w to wrap underlying errors 2025-04-05 21:03:30 +00:00
12 changed files with 277 additions and 42 deletions

View File

@ -215,7 +215,7 @@ func (cc *Conn) ListChain(table *Table, chain string) (*Chain, error) {
response, err := conn.Execute(msg)
if err != nil {
return nil, fmt.Errorf("conn.Execute failed: %v", err)
return nil, fmt.Errorf("conn.Execute failed: %w", err)
}
if got, want := len(response), 1; got != want {

44
conn.go
View File

@ -233,6 +233,19 @@ func (cc *Conn) CloseLasting() error {
// Flush sends all buffered commands in a single batch to nftables.
func (cc *Conn) Flush() error {
return cc.flush(0)
}
// FlushWithGenID sends all buffered commands in a single batch to nftables
// along with the provided gen ID. If the ruleset has changed since the gen ID
// was retrieved, an ERESTART error will be returned.
func (cc *Conn) FlushWithGenID(genID uint32) error {
return cc.flush(genID)
}
// flush sends all buffered commands in a single batch to nftables. If genID is
// non-zero, it will be included in the batch messages.
func (cc *Conn) flush(genID uint32) error {
cc.mu.Lock()
defer func() {
cc.messages = nil
@ -259,7 +272,12 @@ func (cc *Conn) Flush() error {
return err
}
messages, err := conn.SendMessages(batch(cc.messages))
batch, err := batch(cc.messages, genID)
if err != nil {
return err
}
messages, err := conn.SendMessages(batch)
if err != nil {
return fmt.Errorf("SendMessages: %w", err)
}
@ -388,14 +406,30 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
return b
}
func batch(messages []netlinkMessage) []netlink.Message {
// batch wraps the given messages in a batch begin and end message, and returns
// the resulting slice of netlink messages. If the genID is non-zero, it will be
// included in both batch messages.
func batch(messages []netlinkMessage, genID uint32) ([]netlink.Message, error) {
batch := make([]netlink.Message, len(messages)+2)
data := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES)
if genID > 0 {
attr, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID)},
})
if err != nil {
return nil, err
}
data = append(data, attr...)
}
batch[0] = netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
Flags: netlink.Request,
},
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
Data: data,
}
for i, msg := range messages {
@ -410,10 +444,10 @@ func batch(messages []netlinkMessage) []netlink.Message {
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
Flags: netlink.Request,
},
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
Data: data,
}
return batch
return batch, nil
}
// allocateTransactionID allocates an identifier which is only valid in the

View File

@ -66,7 +66,7 @@ func (e *Immediate) unmarshal(fam byte, data []byte) error {
case unix.NFTA_IMMEDIATE_DATA:
nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes())
if err != nil {
return fmt.Errorf("nested NewAttributeDecoder() failed: %v", err)
return fmt.Errorf("nested NewAttributeDecoder() failed: %w", err)
}
for nestedAD.Next() {
switch nestedAD.Type() {
@ -75,7 +75,7 @@ func (e *Immediate) unmarshal(fam byte, data []byte) error {
}
}
if nestedAD.Err() != nil {
return fmt.Errorf("decoding immediate: %v", nestedAD.Err())
return fmt.Errorf("decoding immediate: %w", nestedAD.Err())
}
}
}

View File

@ -111,7 +111,7 @@ func (e *Verdict) unmarshal(fam byte, data []byte) error {
case unix.NFTA_IMMEDIATE_DATA:
nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes())
if err != nil {
return fmt.Errorf("nested NewAttributeDecoder() failed: %v", err)
return fmt.Errorf("nested NewAttributeDecoder() failed: %w", err)
}
for nestedAD.Next() {
switch nestedAD.Type() {
@ -123,7 +123,7 @@ func (e *Verdict) unmarshal(fam byte, data []byte) error {
}
}
if nestedAD.Err() != nil {
return fmt.Errorf("decoding immediate: %v", nestedAD.Err())
return fmt.Errorf("decoding immediate: %w", nestedAD.Err())
}
}
}

View File

@ -214,12 +214,12 @@ func (cc *Conn) getFlowtables(t *Table) ([]netlink.Message, error) {
}
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err)
return nil, fmt.Errorf("SendMessages: %w", err)
}
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err)
return nil, fmt.Errorf("receiveAckAware: %w", err)
}
return reply, nil

50
gen.go
View File

@ -8,15 +8,18 @@ import (
"golang.org/x/sys/unix"
)
type GenMsg struct {
type Gen struct {
ID uint32
ProcPID uint32
ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN
}
// Deprecated: GenMsg is an inconsistent old name for Gen. Prefer using Gen.
type GenMsg = Gen
const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
func genFromMsg(msg netlink.Message) (*GenMsg, error) {
func genFromMsg(msg netlink.Message) (*Gen, error) {
if got, want := msg.Header.Type, genHeaderType; got != want {
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
}
@ -26,7 +29,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
}
ad.ByteOrder = binary.BigEndian
msgOut := &GenMsg{}
msgOut := &Gen{}
for ad.Next() {
switch ad.Type() {
case unix.NFTA_GEN_ID:
@ -36,7 +39,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
case unix.NFTA_GEN_PROC_NAME:
msgOut.ProcComm = ad.String()
default:
return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes())
return nil, fmt.Errorf("unknown attribute: %d, %v", ad.Type(), ad.Bytes())
}
}
if err := ad.Err(); err != nil {
@ -44,3 +47,42 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
}
return msgOut, nil
}
// GetGen retrieves the current nftables generation ID together with the name
// and ID of the process that last modified the ruleset.
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen
func (cc *Conn) GetGen() (*Gen, error) {
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer func() { _ = closer() }()
data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_GEN_ID},
})
if err != nil {
return nil, err
}
message := netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETGEN),
Flags: netlink.Request | netlink.Acknowledge,
},
Data: append(extraHeader(0, 0), data...),
}
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err)
}
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err)
}
if len(reply) == 0 {
return nil, fmt.Errorf("receiveAckAware: no reply")
}
return genFromMsg(reply[0])
}

View File

@ -71,7 +71,7 @@ func TestMonitor(t *testing.T) {
return
}
genMsg := event.GeneratedBy.Data.(*nftables.GenMsg)
genMsg := event.GeneratedBy.Data.(*nftables.Gen)
fileName := filepath.Base(os.Args[0])
if genMsg.ProcComm != fileName {

View File

@ -128,6 +128,47 @@ func ifname(n string) []byte {
return b
}
func TestTableCreateDestroy(t *testing.T) {
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer c.FlushRuleset()
filter := &nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
}
c.DestroyTable(filter)
c.AddTable(filter)
err := c.Flush()
if err != nil {
t.Fatalf("on Flush: %q", err.Error())
}
lookupMyTable := func() bool {
ts, err := c.ListTables()
if err != nil {
t.Fatalf("on ListTables: %q", err.Error())
}
return slices.ContainsFunc(ts, func(t *nftables.Table) bool {
return t.Name == filter.Name && t.Family == filter.Family
})
}
if !lookupMyTable() {
t.Fatal("AddTable doesn't create my table!")
}
c.DestroyTable(filter)
if err = c.Flush(); err != nil {
t.Fatalf("on Flush: %q", err.Error())
}
if lookupMyTable() {
t.Fatal("DestroyTable doesn't delete my table!")
}
c.DestroyTable(filter) // just for test that 'destroy' ignore error 'not found'
}
func TestRuleOperations(t *testing.T) {
// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.
@ -3777,7 +3818,7 @@ func TestDeleteElementNamedSet(t *testing.T) {
Name: "test",
KeyType: nftables.TypeInetService,
}
if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil {
if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}, {Key: []byte{0, 24}}}); err != nil {
t.Errorf("c.AddSet(portSet) failed: %v", err)
}
if err := c.Flush(); err != nil {
@ -3794,6 +3835,22 @@ func TestDeleteElementNamedSet(t *testing.T) {
if err != nil {
t.Errorf("c.GetSets() failed: %v", err)
}
if len(elems) != 2 {
t.Fatalf("len(elems) = %d, want 2", len(elems))
}
c.SetDestroyElements(portSet, []nftables.SetElement{{Key: []byte{0, 24}}})
c.SetDestroyElements(portSet, []nftables.SetElement{{Key: []byte{0, 24}}})
c.SetDestroyElements(portSet, []nftables.SetElement{{Key: []byte{0, 99}}})
if err := c.Flush(); err != nil {
t.Errorf("Third c.Flush() failed: %v", err)
}
elems, err = c.GetSetElements(portSet)
if err != nil {
t.Errorf("c.GetSets() failed: %v", err)
}
if len(elems) != 1 {
t.Fatalf("len(elems) = %d, want 1", len(elems))
}
@ -7429,3 +7486,77 @@ func TestAutoBufferSize(t *testing.T) {
t.Fatalf("failed to flush: %v", err)
}
}
func TestGetGen(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer conn.FlushRuleset()
gen, err := conn.GetGen()
if err != nil {
t.Fatalf("failed to get gen: %v", err)
}
conn.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyIPv4,
})
// Flush to increment the generation ID.
if err := conn.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
newGen, err := conn.GetGen()
if err != nil {
t.Fatalf("failed to get gen: %v", err)
}
if newGen.ID <= gen.ID {
t.Fatalf("gen ID did not increase, got %d, want > %d", newGen.ID, gen.ID)
}
if newGen.ProcComm != gen.ProcComm {
t.Errorf("gen ProcComm changed, got %s, want %s", newGen.ProcComm, gen.ProcComm)
}
if newGen.ProcPID != gen.ProcPID {
t.Errorf("gen ProcPID changed, got %d, want %d", newGen.ProcPID, gen.ProcPID)
}
}
func TestFlushWithGenID(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer conn.FlushRuleset()
gen, err := conn.GetGen()
if err != nil {
t.Fatalf("failed to get gen: %v", err)
}
conn.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyIPv4,
})
// Flush to increment the generation ID.
if err := conn.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
conn.AddTable(&nftables.Table{
Name: "test-table-2",
Family: nftables.TableFamilyIPv4,
})
err = conn.FlushWithGenID(gen.ID)
if err == nil || !errors.Is(err, syscall.ERESTART) {
t.Errorf("expected error to be ERESTART, got: %v", err)
}
table, err := conn.ListTable("test-table-2")
if table != nil && !errors.Is(err, syscall.ENOENT) {
t.Errorf("expected table to not exist, got: %v", table)
}
}

4
obj.go
View File

@ -361,12 +361,12 @@ func (cc *Conn) getObjWithLegacyType(o Obj, t *Table, msgType uint16, returnLega
}
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err)
return nil, fmt.Errorf("SendMessages: %w", err)
}
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err)
return nil, fmt.Errorf("receiveAckAware: %w", err)
}
var objs []Obj
for _, msg := range reply {

View File

@ -101,12 +101,12 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
}
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err)
return nil, fmt.Errorf("SendMessages: %w", err)
}
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err)
return nil, fmt.Errorf("receiveAckAware: %w", err)
}
var rules []*Rule
for _, msg := range reply {

48
set.go
View File

@ -44,6 +44,9 @@ const (
NFTA_SET_ELEM_KEY_END = 10
// https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n429
NFTA_SET_ELEM_EXPRESSIONS = 0x11
// FIXME: in sys@v0.34.0 no unix.NFT_MSG_DESTROYSETELEM const yet.
// See nf_tables_msg_types enum in https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h
NFT_MSG_DESTROYSETELEM = 0x1e
)
// SetDatatype represents a datatype declared by nft.
@ -298,7 +301,7 @@ func (s *SetElement) decode(fam byte) func(b []byte) error {
return func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return fmt.Errorf("failed to create nested attribute decoder: %v", err)
return fmt.Errorf("failed to create nested attribute decoder: %w", err)
}
ad.ByteOrder = binary.BigEndian
@ -353,7 +356,7 @@ func (s *SetElement) decode(fam byte) func(b []byte) error {
func decodeElement(d []byte) ([]byte, error) {
ad, err := netlink.NewAttributeDecoder(d)
if err != nil {
return nil, fmt.Errorf("failed to create nested attribute decoder: %v", err)
return nil, fmt.Errorf("failed to create nested attribute decoder: %w", err)
}
ad.ByteOrder = binary.BigEndian
var b []byte
@ -391,6 +394,16 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
}
// SetDestroyElements like SetDeleteElements, but not an error if setelement doesn't exists
func (cc *Conn) SetDestroyElements(s *Set, vals []SetElement) error {
cc.mu.Lock()
defer cc.mu.Unlock()
if s.Anonymous {
return errors.New("anonymous sets cannot be updated")
}
return cc.appendElemList(s, vals, NFT_MSG_DESTROYSETELEM)
}
// maxElemBatchSize is the maximum size in bytes of encoded set elements which
// are sent in one netlink message. The size field of a netlink attribute is a
// uint16, and 1024 bytes is more than enough for the per-message headers.
@ -414,14 +427,14 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}})
if err != nil {
return fmt.Errorf("marshal key %d: %v", i, err)
return fmt.Errorf("marshal key %d: %w", i, err)
}
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey})
if len(v.KeyEnd) > 0 {
encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}})
if err != nil {
return fmt.Errorf("marshal key end %d: %v", i, err)
return fmt.Errorf("marshal key end %d: %w", i, err)
}
item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd})
}
@ -441,7 +454,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))},
})
if err != nil {
return fmt.Errorf("marshal item %d: %v", i, err)
return fmt.Errorf("marshal item %d: %w", i, err)
}
encodedVal = append(encodedVal, encodedKind...)
if len(v.VerdictData.Chain) != 0 {
@ -449,21 +462,21 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
{Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")},
})
if err != nil {
return fmt.Errorf("marshal item %d: %v", i, err)
return fmt.Errorf("marshal item %d: %w", i, err)
}
encodedVal = append(encodedVal, encodedChain...)
}
encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}})
if err != nil {
return fmt.Errorf("marshal item %d: %v", i, err)
return fmt.Errorf("marshal item %d: %w", i, err)
}
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict})
case len(v.Val) > 0:
// Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes
encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}})
if err != nil {
return fmt.Errorf("marshal item %d: %v", i, err)
return fmt.Errorf("marshal item %d: %w", i, err)
}
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal})
@ -479,7 +492,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
encodedItem, err := netlink.MarshalAttributes(item)
if err != nil {
return fmt.Errorf("marshal item %d: %v", i, err)
return fmt.Errorf("marshal item %d: %w", i, err)
}
itemSize := unix.NLA_HDRLEN + len(encodedItem)
@ -496,7 +509,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
for _, batch := range batches {
encodedElem, err := netlink.MarshalAttributes(batch)
if err != nil {
return fmt.Errorf("marshal elements: %v", err)
return fmt.Errorf("marshal elements: %w", err)
}
message := []netlink.Attribute{
@ -591,7 +604,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))},
})
if err != nil {
return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err)
return fmt.Errorf("fail to marshal number of elements %d: %w", len(vals), err)
}
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements})
}
@ -620,7 +633,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)},
})
if err != nil {
return fmt.Errorf("fail to marshal element key size %d: %v", i, err)
return fmt.Errorf("fail to marshal element key size %d: %w", i, err)
}
// Marshal base type size description
descSize, err := netlink.MarshalAttributes([]netlink.Attribute{
@ -634,7 +647,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
// Marshal all base type descriptions into concatenation size description
concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}})
if err != nil {
return fmt.Errorf("fail to marshal concat definition %v", err)
return fmt.Errorf("fail to marshal concat definition %w", err)
}
descBytes = append(descBytes, concatBytes...)
@ -828,6 +841,7 @@ func parseSetDatatype(magic uint32) (SetDatatype, error) {
const (
newElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM)
delElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM)
destroyElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_DESTROYSETELEM)
)
func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) {
@ -889,12 +903,12 @@ func (cc *Conn) GetSets(t *Table) ([]*Set, error) {
}
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err)
return nil, fmt.Errorf("SendMessages: %w", err)
}
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err)
return nil, fmt.Errorf("receiveAckAware: %w", err)
}
var sets []*Set
for _, msg := range reply {
@ -979,12 +993,12 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) {
}
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err)
return nil, fmt.Errorf("SendMessages: %w", err)
}
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err)
return nil, fmt.Errorf("receiveAckAware: %w", err)
}
var elems []SetElement
for _, msg := range reply {

View File

@ -24,6 +24,10 @@ import (
const (
newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE)
delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE)
// FIXME: in sys@v0.34.0 no unix.NFT_MSG_DESTROYTABLE const yet.
// See nf_tables_msg_types enum in https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h
destroyTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | 0x1a)
)
// TableFamily specifies the address family for this table.
@ -51,15 +55,25 @@ type Table struct {
// DelTable deletes a specific table, along with all chains/rules it contains.
func (cc *Conn) DelTable(t *Table) {
cc.delTable(t, delTableHeaderType)
}
// DestroyTable is like DelTable, but not an error if table doesn't exists
func (cc *Conn) DestroyTable(t *Table) {
cc.delTable(t, destroyTableHeaderType)
}
func (cc *Conn) delTable(t *Table, hdrType netlink.HeaderType) {
cc.mu.Lock()
defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
})
cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
Type: hdrType,
Flags: netlink.Request | netlink.Acknowledge,
},
Data: append(extraHeader(uint8(t.Family), 0), data...),