Compare commits
4 Commits
588db72d00
...
77210037da
Author | SHA1 | Date |
---|---|---|
|
77210037da | |
|
1148f1a84f | |
|
3efc75f481 | |
|
dd13cb1d03 |
2
chain.go
2
chain.go
|
@ -215,7 +215,7 @@ func (cc *Conn) ListChain(table *Table, chain string) (*Chain, error) {
|
|||
|
||||
response, err := conn.Execute(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("conn.Execute failed: %v", err)
|
||||
return nil, fmt.Errorf("conn.Execute failed: %w", err)
|
||||
}
|
||||
|
||||
if got, want := len(response), 1; got != want {
|
||||
|
|
44
conn.go
44
conn.go
|
@ -233,6 +233,19 @@ func (cc *Conn) CloseLasting() error {
|
|||
|
||||
// Flush sends all buffered commands in a single batch to nftables.
|
||||
func (cc *Conn) Flush() error {
|
||||
return cc.flush(0)
|
||||
}
|
||||
|
||||
// FlushWithGenID sends all buffered commands in a single batch to nftables
|
||||
// along with the provided gen ID. If the ruleset has changed since the gen ID
|
||||
// was retrieved, an ERESTART error will be returned.
|
||||
func (cc *Conn) FlushWithGenID(genID uint32) error {
|
||||
return cc.flush(genID)
|
||||
}
|
||||
|
||||
// flush sends all buffered commands in a single batch to nftables. If genID is
|
||||
// non-zero, it will be included in the batch messages.
|
||||
func (cc *Conn) flush(genID uint32) error {
|
||||
cc.mu.Lock()
|
||||
defer func() {
|
||||
cc.messages = nil
|
||||
|
@ -259,7 +272,12 @@ func (cc *Conn) Flush() error {
|
|||
return err
|
||||
}
|
||||
|
||||
messages, err := conn.SendMessages(batch(cc.messages))
|
||||
batch, err := batch(cc.messages, genID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
messages, err := conn.SendMessages(batch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SendMessages: %w", err)
|
||||
}
|
||||
|
@ -388,14 +406,30 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
|
|||
return b
|
||||
}
|
||||
|
||||
func batch(messages []netlinkMessage) []netlink.Message {
|
||||
// batch wraps the given messages in a batch begin and end message, and returns
|
||||
// the resulting slice of netlink messages. If the genID is non-zero, it will be
|
||||
// included in both batch messages.
|
||||
func batch(messages []netlinkMessage, genID uint32) ([]netlink.Message, error) {
|
||||
batch := make([]netlink.Message, len(messages)+2)
|
||||
|
||||
data := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES)
|
||||
|
||||
if genID > 0 {
|
||||
attr, err := netlink.MarshalAttributes([]netlink.Attribute{
|
||||
{Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID)},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data = append(data, attr...)
|
||||
}
|
||||
|
||||
batch[0] = netlink.Message{
|
||||
Header: netlink.Header{
|
||||
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
|
||||
Flags: netlink.Request,
|
||||
},
|
||||
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
|
||||
Data: data,
|
||||
}
|
||||
|
||||
for i, msg := range messages {
|
||||
|
@ -410,10 +444,10 @@ func batch(messages []netlinkMessage) []netlink.Message {
|
|||
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
|
||||
Flags: netlink.Request,
|
||||
},
|
||||
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
|
||||
Data: data,
|
||||
}
|
||||
|
||||
return batch
|
||||
return batch, nil
|
||||
}
|
||||
|
||||
// allocateTransactionID allocates an identifier which is only valid in the
|
||||
|
|
|
@ -66,7 +66,7 @@ func (e *Immediate) unmarshal(fam byte, data []byte) error {
|
|||
case unix.NFTA_IMMEDIATE_DATA:
|
||||
nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("nested NewAttributeDecoder() failed: %v", err)
|
||||
return fmt.Errorf("nested NewAttributeDecoder() failed: %w", err)
|
||||
}
|
||||
for nestedAD.Next() {
|
||||
switch nestedAD.Type() {
|
||||
|
@ -75,7 +75,7 @@ func (e *Immediate) unmarshal(fam byte, data []byte) error {
|
|||
}
|
||||
}
|
||||
if nestedAD.Err() != nil {
|
||||
return fmt.Errorf("decoding immediate: %v", nestedAD.Err())
|
||||
return fmt.Errorf("decoding immediate: %w", nestedAD.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -111,7 +111,7 @@ func (e *Verdict) unmarshal(fam byte, data []byte) error {
|
|||
case unix.NFTA_IMMEDIATE_DATA:
|
||||
nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("nested NewAttributeDecoder() failed: %v", err)
|
||||
return fmt.Errorf("nested NewAttributeDecoder() failed: %w", err)
|
||||
}
|
||||
for nestedAD.Next() {
|
||||
switch nestedAD.Type() {
|
||||
|
@ -123,7 +123,7 @@ func (e *Verdict) unmarshal(fam byte, data []byte) error {
|
|||
}
|
||||
}
|
||||
if nestedAD.Err() != nil {
|
||||
return fmt.Errorf("decoding immediate: %v", nestedAD.Err())
|
||||
return fmt.Errorf("decoding immediate: %w", nestedAD.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -214,12 +214,12 @@ func (cc *Conn) getFlowtables(t *Table) ([]netlink.Message, error) {
|
|||
}
|
||||
|
||||
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
|
||||
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||
return nil, fmt.Errorf("SendMessages: %w", err)
|
||||
}
|
||||
|
||||
reply, err := receiveAckAware(conn, message.Header.Flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("receiveAckAware: %v", err)
|
||||
return nil, fmt.Errorf("receiveAckAware: %w", err)
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
|
|
50
gen.go
50
gen.go
|
@ -8,15 +8,18 @@ import (
|
|||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type GenMsg struct {
|
||||
type Gen struct {
|
||||
ID uint32
|
||||
ProcPID uint32
|
||||
ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN
|
||||
}
|
||||
|
||||
// Deprecated: GenMsg is an inconsistent old name for Gen. Prefer using Gen.
|
||||
type GenMsg = Gen
|
||||
|
||||
const genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
|
||||
|
||||
func genFromMsg(msg netlink.Message) (*GenMsg, error) {
|
||||
func genFromMsg(msg netlink.Message) (*Gen, error) {
|
||||
if got, want := msg.Header.Type, genHeaderType; got != want {
|
||||
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
|
||||
}
|
||||
|
@ -26,7 +29,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
|
|||
}
|
||||
ad.ByteOrder = binary.BigEndian
|
||||
|
||||
msgOut := &GenMsg{}
|
||||
msgOut := &Gen{}
|
||||
for ad.Next() {
|
||||
switch ad.Type() {
|
||||
case unix.NFTA_GEN_ID:
|
||||
|
@ -36,7 +39,7 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
|
|||
case unix.NFTA_GEN_PROC_NAME:
|
||||
msgOut.ProcComm = ad.String()
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes())
|
||||
return nil, fmt.Errorf("unknown attribute: %d, %v", ad.Type(), ad.Bytes())
|
||||
}
|
||||
}
|
||||
if err := ad.Err(); err != nil {
|
||||
|
@ -44,3 +47,42 @@ func genFromMsg(msg netlink.Message) (*GenMsg, error) {
|
|||
}
|
||||
return msgOut, nil
|
||||
}
|
||||
|
||||
// GetGen retrieves the current nftables generation ID together with the name
|
||||
// and ID of the process that last modified the ruleset.
|
||||
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getgen
|
||||
func (cc *Conn) GetGen() (*Gen, error) {
|
||||
conn, closer, err := cc.netlinkConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = closer() }()
|
||||
|
||||
data, err := netlink.MarshalAttributes([]netlink.Attribute{
|
||||
{Type: unix.NFTA_GEN_ID},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
message := netlink.Message{
|
||||
Header: netlink.Header{
|
||||
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETGEN),
|
||||
Flags: netlink.Request | netlink.Acknowledge,
|
||||
},
|
||||
Data: append(extraHeader(0, 0), data...),
|
||||
}
|
||||
|
||||
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
|
||||
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||
}
|
||||
|
||||
reply, err := receiveAckAware(conn, message.Header.Flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("receiveAckAware: %v", err)
|
||||
}
|
||||
if len(reply) == 0 {
|
||||
return nil, fmt.Errorf("receiveAckAware: no reply")
|
||||
}
|
||||
return genFromMsg(reply[0])
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ func TestMonitor(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
genMsg := event.GeneratedBy.Data.(*nftables.GenMsg)
|
||||
genMsg := event.GeneratedBy.Data.(*nftables.Gen)
|
||||
fileName := filepath.Base(os.Args[0])
|
||||
|
||||
if genMsg.ProcComm != fileName {
|
||||
|
|
133
nftables_test.go
133
nftables_test.go
|
@ -128,6 +128,47 @@ func ifname(n string) []byte {
|
|||
return b
|
||||
}
|
||||
|
||||
func TestTableCreateDestroy(t *testing.T) {
|
||||
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||||
defer nftest.CleanupSystemConn(t, newNS)
|
||||
defer c.FlushRuleset()
|
||||
|
||||
filter := &nftables.Table{
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
Name: "filter",
|
||||
}
|
||||
c.DestroyTable(filter)
|
||||
c.AddTable(filter)
|
||||
err := c.Flush()
|
||||
if err != nil {
|
||||
t.Fatalf("on Flush: %q", err.Error())
|
||||
}
|
||||
|
||||
lookupMyTable := func() bool {
|
||||
ts, err := c.ListTables()
|
||||
if err != nil {
|
||||
t.Fatalf("on ListTables: %q", err.Error())
|
||||
}
|
||||
return slices.ContainsFunc(ts, func(t *nftables.Table) bool {
|
||||
return t.Name == filter.Name && t.Family == filter.Family
|
||||
})
|
||||
}
|
||||
if !lookupMyTable() {
|
||||
t.Fatal("AddTable doesn't create my table!")
|
||||
}
|
||||
|
||||
c.DestroyTable(filter)
|
||||
if err = c.Flush(); err != nil {
|
||||
t.Fatalf("on Flush: %q", err.Error())
|
||||
}
|
||||
|
||||
if lookupMyTable() {
|
||||
t.Fatal("DestroyTable doesn't delete my table!")
|
||||
}
|
||||
|
||||
c.DestroyTable(filter) // just for test that 'destroy' ignore error 'not found'
|
||||
}
|
||||
|
||||
func TestRuleOperations(t *testing.T) {
|
||||
// Create a new network namespace to test these operations,
|
||||
// and tear down the namespace at test completion.
|
||||
|
@ -3777,7 +3818,7 @@ func TestDeleteElementNamedSet(t *testing.T) {
|
|||
Name: "test",
|
||||
KeyType: nftables.TypeInetService,
|
||||
}
|
||||
if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}}); err != nil {
|
||||
if err := c.AddSet(portSet, []nftables.SetElement{{Key: []byte{0, 22}}, {Key: []byte{0, 23}}, {Key: []byte{0, 24}}}); err != nil {
|
||||
t.Errorf("c.AddSet(portSet) failed: %v", err)
|
||||
}
|
||||
if err := c.Flush(); err != nil {
|
||||
|
@ -3794,6 +3835,22 @@ func TestDeleteElementNamedSet(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("c.GetSets() failed: %v", err)
|
||||
}
|
||||
if len(elems) != 2 {
|
||||
t.Fatalf("len(elems) = %d, want 2", len(elems))
|
||||
}
|
||||
|
||||
c.SetDestroyElements(portSet, []nftables.SetElement{{Key: []byte{0, 24}}})
|
||||
c.SetDestroyElements(portSet, []nftables.SetElement{{Key: []byte{0, 24}}})
|
||||
c.SetDestroyElements(portSet, []nftables.SetElement{{Key: []byte{0, 99}}})
|
||||
|
||||
if err := c.Flush(); err != nil {
|
||||
t.Errorf("Third c.Flush() failed: %v", err)
|
||||
}
|
||||
|
||||
elems, err = c.GetSetElements(portSet)
|
||||
if err != nil {
|
||||
t.Errorf("c.GetSets() failed: %v", err)
|
||||
}
|
||||
if len(elems) != 1 {
|
||||
t.Fatalf("len(elems) = %d, want 1", len(elems))
|
||||
}
|
||||
|
@ -7429,3 +7486,77 @@ func TestAutoBufferSize(t *testing.T) {
|
|||
t.Fatalf("failed to flush: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGen(t *testing.T) {
|
||||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||||
defer nftest.CleanupSystemConn(t, newNS)
|
||||
defer conn.FlushRuleset()
|
||||
|
||||
gen, err := conn.GetGen()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get gen: %v", err)
|
||||
}
|
||||
|
||||
conn.AddTable(&nftables.Table{
|
||||
Name: "test-table",
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
})
|
||||
|
||||
// Flush to increment the generation ID.
|
||||
if err := conn.Flush(); err != nil {
|
||||
t.Fatalf("failed to flush: %v", err)
|
||||
}
|
||||
|
||||
newGen, err := conn.GetGen()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get gen: %v", err)
|
||||
}
|
||||
|
||||
if newGen.ID <= gen.ID {
|
||||
t.Fatalf("gen ID did not increase, got %d, want > %d", newGen.ID, gen.ID)
|
||||
}
|
||||
|
||||
if newGen.ProcComm != gen.ProcComm {
|
||||
t.Errorf("gen ProcComm changed, got %s, want %s", newGen.ProcComm, gen.ProcComm)
|
||||
}
|
||||
|
||||
if newGen.ProcPID != gen.ProcPID {
|
||||
t.Errorf("gen ProcPID changed, got %d, want %d", newGen.ProcPID, gen.ProcPID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushWithGenID(t *testing.T) {
|
||||
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
|
||||
defer nftest.CleanupSystemConn(t, newNS)
|
||||
defer conn.FlushRuleset()
|
||||
|
||||
gen, err := conn.GetGen()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get gen: %v", err)
|
||||
}
|
||||
|
||||
conn.AddTable(&nftables.Table{
|
||||
Name: "test-table",
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
})
|
||||
|
||||
// Flush to increment the generation ID.
|
||||
if err := conn.Flush(); err != nil {
|
||||
t.Fatalf("failed to flush: %v", err)
|
||||
}
|
||||
|
||||
conn.AddTable(&nftables.Table{
|
||||
Name: "test-table-2",
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
})
|
||||
|
||||
err = conn.FlushWithGenID(gen.ID)
|
||||
if err == nil || !errors.Is(err, syscall.ERESTART) {
|
||||
t.Errorf("expected error to be ERESTART, got: %v", err)
|
||||
}
|
||||
|
||||
table, err := conn.ListTable("test-table-2")
|
||||
if table != nil && !errors.Is(err, syscall.ENOENT) {
|
||||
t.Errorf("expected table to not exist, got: %v", table)
|
||||
}
|
||||
}
|
||||
|
|
4
obj.go
4
obj.go
|
@ -361,12 +361,12 @@ func (cc *Conn) getObjWithLegacyType(o Obj, t *Table, msgType uint16, returnLega
|
|||
}
|
||||
|
||||
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
|
||||
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||
return nil, fmt.Errorf("SendMessages: %w", err)
|
||||
}
|
||||
|
||||
reply, err := receiveAckAware(conn, message.Header.Flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("receiveAckAware: %v", err)
|
||||
return nil, fmt.Errorf("receiveAckAware: %w", err)
|
||||
}
|
||||
var objs []Obj
|
||||
for _, msg := range reply {
|
||||
|
|
4
rule.go
4
rule.go
|
@ -101,12 +101,12 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
|
|||
}
|
||||
|
||||
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
|
||||
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||
return nil, fmt.Errorf("SendMessages: %w", err)
|
||||
}
|
||||
|
||||
reply, err := receiveAckAware(conn, message.Header.Flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("receiveAckAware: %v", err)
|
||||
return nil, fmt.Errorf("receiveAckAware: %w", err)
|
||||
}
|
||||
var rules []*Rule
|
||||
for _, msg := range reply {
|
||||
|
|
48
set.go
48
set.go
|
@ -44,6 +44,9 @@ const (
|
|||
NFTA_SET_ELEM_KEY_END = 10
|
||||
// https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n429
|
||||
NFTA_SET_ELEM_EXPRESSIONS = 0x11
|
||||
// FIXME: in sys@v0.34.0 no unix.NFT_MSG_DESTROYSETELEM const yet.
|
||||
// See nf_tables_msg_types enum in https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h
|
||||
NFT_MSG_DESTROYSETELEM = 0x1e
|
||||
)
|
||||
|
||||
// SetDatatype represents a datatype declared by nft.
|
||||
|
@ -298,7 +301,7 @@ func (s *SetElement) decode(fam byte) func(b []byte) error {
|
|||
return func(b []byte) error {
|
||||
ad, err := netlink.NewAttributeDecoder(b)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create nested attribute decoder: %v", err)
|
||||
return fmt.Errorf("failed to create nested attribute decoder: %w", err)
|
||||
}
|
||||
ad.ByteOrder = binary.BigEndian
|
||||
|
||||
|
@ -353,7 +356,7 @@ func (s *SetElement) decode(fam byte) func(b []byte) error {
|
|||
func decodeElement(d []byte) ([]byte, error) {
|
||||
ad, err := netlink.NewAttributeDecoder(d)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create nested attribute decoder: %v", err)
|
||||
return nil, fmt.Errorf("failed to create nested attribute decoder: %w", err)
|
||||
}
|
||||
ad.ByteOrder = binary.BigEndian
|
||||
var b []byte
|
||||
|
@ -391,6 +394,16 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error {
|
|||
return cc.appendElemList(s, vals, unix.NFT_MSG_DELSETELEM)
|
||||
}
|
||||
|
||||
// SetDestroyElements like SetDeleteElements, but not an error if setelement doesn't exists
|
||||
func (cc *Conn) SetDestroyElements(s *Set, vals []SetElement) error {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
if s.Anonymous {
|
||||
return errors.New("anonymous sets cannot be updated")
|
||||
}
|
||||
return cc.appendElemList(s, vals, NFT_MSG_DESTROYSETELEM)
|
||||
}
|
||||
|
||||
// maxElemBatchSize is the maximum size in bytes of encoded set elements which
|
||||
// are sent in one netlink message. The size field of a netlink attribute is a
|
||||
// uint16, and 1024 bytes is more than enough for the per-message headers.
|
||||
|
@ -414,14 +427,14 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
|
|||
|
||||
encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal key %d: %v", i, err)
|
||||
return fmt.Errorf("marshal key %d: %w", i, err)
|
||||
}
|
||||
|
||||
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey})
|
||||
if len(v.KeyEnd) > 0 {
|
||||
encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal key end %d: %v", i, err)
|
||||
return fmt.Errorf("marshal key end %d: %w", i, err)
|
||||
}
|
||||
item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd})
|
||||
}
|
||||
|
@ -441,7 +454,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
|
|||
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal item %d: %v", i, err)
|
||||
return fmt.Errorf("marshal item %d: %w", i, err)
|
||||
}
|
||||
encodedVal = append(encodedVal, encodedKind...)
|
||||
if len(v.VerdictData.Chain) != 0 {
|
||||
|
@ -449,21 +462,21 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
|
|||
{Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal item %d: %v", i, err)
|
||||
return fmt.Errorf("marshal item %d: %w", i, err)
|
||||
}
|
||||
encodedVal = append(encodedVal, encodedChain...)
|
||||
}
|
||||
encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{
|
||||
{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal item %d: %v", i, err)
|
||||
return fmt.Errorf("marshal item %d: %w", i, err)
|
||||
}
|
||||
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict})
|
||||
case len(v.Val) > 0:
|
||||
// Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes
|
||||
encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal item %d: %v", i, err)
|
||||
return fmt.Errorf("marshal item %d: %w", i, err)
|
||||
}
|
||||
|
||||
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal})
|
||||
|
@ -479,7 +492,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
|
|||
|
||||
encodedItem, err := netlink.MarshalAttributes(item)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal item %d: %v", i, err)
|
||||
return fmt.Errorf("marshal item %d: %w", i, err)
|
||||
}
|
||||
|
||||
itemSize := unix.NLA_HDRLEN + len(encodedItem)
|
||||
|
@ -496,7 +509,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
|
|||
for _, batch := range batches {
|
||||
encodedElem, err := netlink.MarshalAttributes(batch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal elements: %v", err)
|
||||
return fmt.Errorf("marshal elements: %w", err)
|
||||
}
|
||||
|
||||
message := []netlink.Attribute{
|
||||
|
@ -591,7 +604,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
|||
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err)
|
||||
return fmt.Errorf("fail to marshal number of elements %d: %w", len(vals), err)
|
||||
}
|
||||
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements})
|
||||
}
|
||||
|
@ -620,7 +633,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
|||
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to marshal element key size %d: %v", i, err)
|
||||
return fmt.Errorf("fail to marshal element key size %d: %w", i, err)
|
||||
}
|
||||
// Marshal base type size description
|
||||
descSize, err := netlink.MarshalAttributes([]netlink.Attribute{
|
||||
|
@ -634,7 +647,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
|
|||
// Marshal all base type descriptions into concatenation size description
|
||||
concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}})
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to marshal concat definition %v", err)
|
||||
return fmt.Errorf("fail to marshal concat definition %w", err)
|
||||
}
|
||||
|
||||
descBytes = append(descBytes, concatBytes...)
|
||||
|
@ -828,6 +841,7 @@ func parseSetDatatype(magic uint32) (SetDatatype, error) {
|
|||
const (
|
||||
newElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM)
|
||||
delElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM)
|
||||
destroyElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_DESTROYSETELEM)
|
||||
)
|
||||
|
||||
func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) {
|
||||
|
@ -889,12 +903,12 @@ func (cc *Conn) GetSets(t *Table) ([]*Set, error) {
|
|||
}
|
||||
|
||||
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
|
||||
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||
return nil, fmt.Errorf("SendMessages: %w", err)
|
||||
}
|
||||
|
||||
reply, err := receiveAckAware(conn, message.Header.Flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("receiveAckAware: %v", err)
|
||||
return nil, fmt.Errorf("receiveAckAware: %w", err)
|
||||
}
|
||||
var sets []*Set
|
||||
for _, msg := range reply {
|
||||
|
@ -979,12 +993,12 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) {
|
|||
}
|
||||
|
||||
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
|
||||
return nil, fmt.Errorf("SendMessages: %v", err)
|
||||
return nil, fmt.Errorf("SendMessages: %w", err)
|
||||
}
|
||||
|
||||
reply, err := receiveAckAware(conn, message.Header.Flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("receiveAckAware: %v", err)
|
||||
return nil, fmt.Errorf("receiveAckAware: %w", err)
|
||||
}
|
||||
var elems []SetElement
|
||||
for _, msg := range reply {
|
||||
|
|
16
table.go
16
table.go
|
@ -24,6 +24,10 @@ import (
|
|||
const (
|
||||
newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE)
|
||||
delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE)
|
||||
|
||||
// FIXME: in sys@v0.34.0 no unix.NFT_MSG_DESTROYTABLE const yet.
|
||||
// See nf_tables_msg_types enum in https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h
|
||||
destroyTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | 0x1a)
|
||||
)
|
||||
|
||||
// TableFamily specifies the address family for this table.
|
||||
|
@ -51,15 +55,25 @@ type Table struct {
|
|||
|
||||
// DelTable deletes a specific table, along with all chains/rules it contains.
|
||||
func (cc *Conn) DelTable(t *Table) {
|
||||
cc.delTable(t, delTableHeaderType)
|
||||
}
|
||||
|
||||
// DestroyTable is like DelTable, but not an error if table doesn't exists
|
||||
func (cc *Conn) DestroyTable(t *Table) {
|
||||
cc.delTable(t, destroyTableHeaderType)
|
||||
}
|
||||
|
||||
func (cc *Conn) delTable(t *Table, hdrType netlink.HeaderType) {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
data := cc.marshalAttr([]netlink.Attribute{
|
||||
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
|
||||
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
|
||||
})
|
||||
|
||||
cc.messages = append(cc.messages, netlinkMessage{
|
||||
Header: netlink.Header{
|
||||
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
|
||||
Type: hdrType,
|
||||
Flags: netlink.Request | netlink.Acknowledge,
|
||||
},
|
||||
Data: append(extraHeader(uint8(t.Family), 0), data...),
|
||||
|
|
Loading…
Reference in New Issue