Compare commits

..

1 Commits

Author SHA1 Message Date
Paul Greenberg 3900fe312c
Merge dae73eaa9c into 508bb1ffd4 2025-07-19 22:34:56 +02:00
6 changed files with 14 additions and 249 deletions

44
conn.go
View File

@ -233,19 +233,6 @@ func (cc *Conn) CloseLasting() error {
// Flush sends all buffered commands in a single batch to nftables. // Flush sends all buffered commands in a single batch to nftables.
func (cc *Conn) Flush() error { func (cc *Conn) Flush() error {
return cc.flush(0)
}
// FlushWithGenID sends all buffered commands in a single batch to nftables
// along with the provided gen ID. If the ruleset has changed since the gen ID
// was retrieved, an ERESTART error will be returned.
func (cc *Conn) FlushWithGenID(genID uint32) error {
return cc.flush(genID)
}
// flush sends all buffered commands in a single batch to nftables. If genID is
// non-zero, it will be included in the batch messages.
func (cc *Conn) flush(genID uint32) error {
cc.mu.Lock() cc.mu.Lock()
defer func() { defer func() {
cc.messages = nil cc.messages = nil
@ -272,12 +259,7 @@ func (cc *Conn) flush(genID uint32) error {
return err return err
} }
batch, err := batch(cc.messages, genID) messages, err := conn.SendMessages(batch(cc.messages))
if err != nil {
return err
}
messages, err := conn.SendMessages(batch)
if err != nil { if err != nil {
return fmt.Errorf("SendMessages: %w", err) return fmt.Errorf("SendMessages: %w", err)
} }
@ -406,30 +388,14 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
return b return b
} }
// batch wraps the given messages in a batch begin and end message, and returns func batch(messages []netlinkMessage) []netlink.Message {
// the resulting slice of netlink messages. If the genID is non-zero, it will be
// included in both batch messages.
func batch(messages []netlinkMessage, genID uint32) ([]netlink.Message, error) {
batch := make([]netlink.Message, len(messages)+2) batch := make([]netlink.Message, len(messages)+2)
data := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES)
if genID > 0 {
attr, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID)},
})
if err != nil {
return nil, err
}
data = append(data, attr...)
}
batch[0] = netlink.Message{ batch[0] = netlink.Message{
Header: netlink.Header{ Header: netlink.Header{
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
Flags: netlink.Request, Flags: netlink.Request,
}, },
Data: data, Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
} }
for i, msg := range messages { for i, msg := range messages {
@ -444,10 +410,10 @@ func batch(messages []netlinkMessage, genID uint32) ([]netlink.Message, error) {
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
Flags: netlink.Request, Flags: netlink.Request,
}, },
Data: data, Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
} }
return batch, nil return batch
} }
// allocateTransactionID allocates an identifier which is only valid in the // allocateTransactionID allocates an identifier which is only valid in the

50
gen.go
View File

@ -8,18 +8,15 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
type Gen struct { type GenMsg struct {
ID uint32 ID uint32
ProcPID uint32 ProcPID uint32
ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN
} }
// Deprecated: GenMsg is an inconsistent old name for Gen. Prefer using Gen.
type GenMsg = Gen
const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN) const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
func genFromMsg(msg netlink.Message) (*Gen, error) { func genFromMsg(msg netlink.Message) (*GenMsg, error) {
if got, want := msg.Header.Type, genHeaderType; got != want { if got, want := msg.Header.Type, genHeaderType; got != want {
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
} }
@ -29,7 +26,7 @@ func genFromMsg(msg netlink.Message) (*Gen, error) {
} }
ad.ByteOrder = binary.BigEndian ad.ByteOrder = binary.BigEndian
msgOut := &Gen{} msgOut := &GenMsg{}
for ad.Next() { for ad.Next() {
switch ad.Type() { switch ad.Type() {
case unix.NFTA_GEN_ID: case unix.NFTA_GEN_ID:
@ -39,7 +36,7 @@ func genFromMsg(msg netlink.Message) (*Gen, error) {
case unix.NFTA_GEN_PROC_NAME: case unix.NFTA_GEN_PROC_NAME:
msgOut.ProcComm = ad.String() msgOut.ProcComm = ad.String()
default: default:
return nil, fmt.Errorf("unknown attribute: %d, %v", ad.Type(), ad.Bytes()) return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes())
} }
} }
if err := ad.Err(); err != nil { if err := ad.Err(); err != nil {
@ -47,42 +44,3 @@ func genFromMsg(msg netlink.Message) (*Gen, error) {
} }
return msgOut, nil return msgOut, nil
} }
// GetGen retrieves the current nftables generation ID together with the name
// and ID of the process that last modified the ruleset.
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen
func (cc *Conn) GetGen() (*Gen, error) {
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer func() { _ = closer() }()
data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_GEN_ID},
})
if err != nil {
return nil, err
}
message := netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETGEN),
Flags: netlink.Request | netlink.Acknowledge,
},
Data: append(extraHeader(0, 0), data...),
}
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err)
}
reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err)
}
if len(reply) == 0 {
return nil, fmt.Errorf("receiveAckAware: no reply")
}
return genFromMsg(reply[0])
}

View File

@ -71,7 +71,7 @@ func TestMonitor(t *testing.T) {
return return
} }
genMsg := event.GeneratedBy.Data.(*nftables.Gen) genMsg := event.GeneratedBy.Data.(*nftables.GenMsg)
fileName := filepath.Base(os.Args[0]) fileName := filepath.Base(os.Args[0])
if genMsg.ProcComm != fileName { if genMsg.ProcComm != fileName {

View File

@ -128,47 +128,6 @@ func ifname(n string) []byte {
return b return b
} }
func TestTableCreateDestroy(t *testing.T) {
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer c.FlushRuleset()
filter := &nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
}
c.DestroyTable(filter)
c.AddTable(filter)
err := c.Flush()
if err != nil {
t.Fatalf("on Flush: %q", err.Error())
}
lookupMyTable := func() bool {
ts, err := c.ListTables()
if err != nil {
t.Fatalf("on ListTables: %q", err.Error())
}
return slices.ContainsFunc(ts, func(t *nftables.Table) bool {
return t.Name == filter.Name && t.Family == filter.Family
})
}
if !lookupMyTable() {
t.Fatal("AddTable doesn't create my table!")
}
c.DestroyTable(filter)
if err = c.Flush(); err != nil {
t.Fatalf("on Flush: %q", err.Error())
}
if lookupMyTable() {
t.Fatal("DestroyTable doesn't delete my table!")
}
c.DestroyTable(filter) // just for test that 'destroy' ignore error 'not found'
}
func TestRuleOperations(t *testing.T) { func TestRuleOperations(t *testing.T) {
// Create a new network namespace to test these operations, // Create a new network namespace to test these operations,
// and tear down the namespace at test completion. // and tear down the namespace at test completion.
@ -3833,7 +3792,7 @@ func TestDeleteElementNamedSet(t *testing.T) {
Name: "test", Name: "test",
KeyType: nftables.TypeInetService, KeyType: nftables.TypeInetService,
} }
if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}, {Key: []byte{0, 24}}}); err != nil { if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil {
t.Errorf("c.AddSet(portSet) failed: %v", err) t.Errorf("c.AddSet(portSet) failed: %v", err)
} }
if err := c.Flush(); err != nil { if err := c.Flush(); err != nil {
@ -3850,22 +3809,6 @@ func TestDeleteElementNamedSet(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("c.GetSets() failed: %v", err) t.Errorf("c.GetSets() failed: %v", err)
} }
if len(elems) != 2 {
t.Fatalf("len(elems) = %d, want 2", len(elems))
}
c.SetDestroyElements(portSet, []nftables.SetElement{{Key: []byte{0, 24}}})
c.SetDestroyElements(portSet, []nftables.SetElement{{Key: []byte{0, 24}}})
c.SetDestroyElements(portSet, []nftables.SetElement{{Key: []byte{0, 99}}})
if err := c.Flush(); err != nil {
t.Errorf("Third c.Flush() failed: %v", err)
}
elems, err = c.GetSetElements(portSet)
if err != nil {
t.Errorf("c.GetSets() failed: %v", err)
}
if len(elems) != 1 { if len(elems) != 1 {
t.Fatalf("len(elems) = %d, want 1", len(elems)) t.Fatalf("len(elems) = %d, want 1", len(elems))
} }
@ -7501,77 +7444,3 @@ func TestAutoBufferSize(t *testing.T) {
t.Fatalf("failed to flush: %v", err) t.Fatalf("failed to flush: %v", err)
} }
} }
func TestGetGen(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer conn.FlushRuleset()
gen, err := conn.GetGen()
if err != nil {
t.Fatalf("failed to get gen: %v", err)
}
conn.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyIPv4,
})
// Flush to increment the generation ID.
if err := conn.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
newGen, err := conn.GetGen()
if err != nil {
t.Fatalf("failed to get gen: %v", err)
}
if newGen.ID <= gen.ID {
t.Fatalf("gen ID did not increase, got %d, want > %d", newGen.ID, gen.ID)
}
if newGen.ProcComm != gen.ProcComm {
t.Errorf("gen ProcComm changed, got %s, want %s", newGen.ProcComm, gen.ProcComm)
}
if newGen.ProcPID != gen.ProcPID {
t.Errorf("gen ProcPID changed, got %d, want %d", newGen.ProcPID, gen.ProcPID)
}
}
func TestFlushWithGenID(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
defer conn.FlushRuleset()
gen, err := conn.GetGen()
if err != nil {
t.Fatalf("failed to get gen: %v", err)
}
conn.AddTable(&nftables.Table{
Name: "test-table",
Family: nftables.TableFamilyIPv4,
})
// Flush to increment the generation ID.
if err := conn.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
conn.AddTable(&nftables.Table{
Name: "test-table-2",
Family: nftables.TableFamilyIPv4,
})
err = conn.FlushWithGenID(gen.ID)
if err == nil || !errors.Is(err, syscall.ERESTART) {
t.Errorf("expected error to be ERESTART, got: %v", err)
}
table, err := conn.ListTable("test-table-2")
if table != nil && !errors.Is(err, syscall.ENOENT) {
t.Errorf("expected table to not exist, got: %v", table)
}
}

18
set.go
View File

@ -44,9 +44,6 @@ const (
NFTA_SET_ELEM_KEY_END = 10 NFTA_SET_ELEM_KEY_END = 10
// https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n429 // https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n429
NFTA_SET_ELEM_EXPRESSIONS = 0x11 NFTA_SET_ELEM_EXPRESSIONS = 0x11
// FIXME: in sys@v0.34.0 no unix.NFT_MSG_DESTROYSETELEM const yet.
// See nf_tables_msg_types enum in https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h
NFT_MSG_DESTROYSETELEM = 0x1e
) )
// SetDatatype represents a datatype declared by nft. // SetDatatype represents a datatype declared by nft.
@ -394,16 +391,6 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM) return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
} }
// SetDestroyElements like SetDeleteElements, but not an error if setelement doesn't exists
func (cc *Conn) SetDestroyElements(s *Set, vals []SetElement) error {
cc.mu.Lock()
defer cc.mu.Unlock()
if s.Anonymous {
return errors.New("anonymous sets cannot be updated")
}
return cc.appendElemList(s, vals, NFT_MSG_DESTROYSETELEM)
}
// maxElemBatchSize is the maximum size in bytes of encoded set elements which // maxElemBatchSize is the maximum size in bytes of encoded set elements which
// are sent in one netlink message. The size field of a netlink attribute is a // are sent in one netlink message. The size field of a netlink attribute is a
// uint16, and 1024 bytes is more than enough for the per-message headers. // uint16, and 1024 bytes is more than enough for the per-message headers.
@ -839,9 +826,8 @@ func parseSetDatatype(magic uint32) (SetDatatype, error) {
} }
const ( const (
newElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM) newElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM)
delElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM) delElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM)
destroyElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_DESTROYSETELEM)
) )
func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) { func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) {

View File

@ -24,10 +24,6 @@ import (
const ( const (
newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE) newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE)
delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE) delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE)
// FIXME: in sys@v0.34.0 no unix.NFT_MSG_DESTROYTABLE const yet.
// See nf_tables_msg_types enum in https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h
destroyTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | 0x1a)
) )
// TableFamily specifies the address family for this table. // TableFamily specifies the address family for this table.
@ -55,25 +51,15 @@ type Table struct {
// DelTable deletes a specific table, along with all chains/rules it contains. // DelTable deletes a specific table, along with all chains/rules it contains.
func (cc *Conn) DelTable(t *Table) { func (cc *Conn) DelTable(t *Table) {
cc.delTable(t, delTableHeaderType)
}
// DestroyTable is like DelTable, but not an error if table doesn't exists
func (cc *Conn) DestroyTable(t *Table) {
cc.delTable(t, destroyTableHeaderType)
}
func (cc *Conn) delTable(t *Table, hdrType netlink.HeaderType) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
data := cc.marshalAttr([]netlink.Attribute{ data := cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
}) })
cc.messages = append(cc.messages, netlinkMessage{ cc.messages = append(cc.messages, netlinkMessage{
Header: netlink.Header{ Header: netlink.Header{
Type: hdrType, Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
Flags: netlink.Request | netlink.Acknowledge, Flags: netlink.Request | netlink.Acknowledge,
}, },
Data: append(extraHeader(uint8(t.Family), 0), data...), Data: append(extraHeader(uint8(t.Family), 0), data...),