Compare commits

...

3 Commits

Author SHA1 Message Date
corpix 2bdc32395a
Merge dd13cb1d03 into 8a8ad2be81 2025-06-06 11:26:46 +02:00
Gleb Zhizhchenko 8a8ad2be81
ct: Add optional direction fields (#317) 2025-06-06 11:18:25 +02:00
Dmitry Moskowski dd13cb1d03 Replace %v with %w to wrap underlying errors 2025-04-05 21:03:30 +00:00
9 changed files with 68 additions and 30 deletions

View File

@ -215,7 +215,7 @@ func (cc *Conn) ListChain(table *Table, chain string) (*Chain, error) {
response, err := conn.Execute(msg) response, err := conn.Execute(msg)
if err != nil { if err != nil {
return nil, fmt.Errorf("conn.Execute failed: %v", err) return nil, fmt.Errorf("conn.Execute failed: %w", err)
} }
if got, want := len(response), 1; got != want { if got, want := len(response), 1; got != want {

View File

@ -110,6 +110,12 @@ const (
CtStateUDPREPLIED CtStateUDPREPLIED
) )
const (
// https://git.netfilter.org/libnftnl/tree/src/expr/ct.c?id=116e95aa7b6358c917de8c69f6f173874030b46b#n31
CtDirOriginal = iota
CtDirReply
)
// https://git.netfilter.org/libnftnl/tree/src/obj/ct_timeout.c?id=116e95aa7b6358c917de8c69f6f173874030b46b#n57 // https://git.netfilter.org/libnftnl/tree/src/obj/ct_timeout.c?id=116e95aa7b6358c917de8c69f6f173874030b46b#n57
var CtStateUDPTimeoutDefaults CtStatePolicyTimeout = map[uint16]uint32{ var CtStateUDPTimeoutDefaults CtStatePolicyTimeout = map[uint16]uint32{
CtStateUDPUNREPLIED: 30, CtStateUDPUNREPLIED: 30,
@ -122,6 +128,7 @@ type Ct struct {
SourceRegister bool SourceRegister bool
Key CtKey Key CtKey
Direction uint32 Direction uint32
OptDirection bool
} }
func (e *Ct) marshal(fam byte) ([]byte, error) { func (e *Ct) marshal(fam byte) ([]byte, error) {
@ -165,10 +172,16 @@ func (e *Ct) marshalData(fam byte) ([]byte, error) {
exprData = append(exprData, regData...) exprData = append(exprData, regData...)
switch e.Key { switch e.Key {
case CtKeyPKTS, CtKeyBYTES, CtKeyAVGPKT, CtKeyL3PROTOCOL, CtKeyPROTOCOL:
if !e.OptDirection {
break
}
fallthrough
case CtKeySRC, CtKeyDST, CtKeyPROTOSRC, CtKeyPROTODST, CtKeySRCIP, CtKeyDSTIP, CtKeySRCIP6, CtKeyDSTIP6: case CtKeySRC, CtKeyDST, CtKeyPROTOSRC, CtKeyPROTODST, CtKeySRCIP, CtKeyDSTIP, CtKeySRCIP6, CtKeyDSTIP6:
regData, err = netlink.MarshalAttributes( regData, err = netlink.MarshalAttributes(
[]netlink.Attribute{ []netlink.Attribute{
{Type: unix.NFTA_CT_DIRECTION, Data: binaryutil.BigEndian.PutUint32(e.Direction)}, {Type: unix.NFTA_CT_DIRECTION, Data: []byte{uint8(e.Direction)}},
}, },
) )
if err != nil { if err != nil {
@ -186,6 +199,8 @@ func (e *Ct) unmarshal(fam byte, data []byte) error {
return err return err
} }
ad.ByteOrder = binary.BigEndian ad.ByteOrder = binary.BigEndian
var hasDirection bool
for ad.Next() { for ad.Next() {
switch ad.Type() { switch ad.Type() {
case unix.NFTA_CT_KEY: case unix.NFTA_CT_KEY:
@ -193,12 +208,19 @@ func (e *Ct) unmarshal(fam byte, data []byte) error {
case unix.NFTA_CT_DREG: case unix.NFTA_CT_DREG:
e.Register = ad.Uint32() e.Register = ad.Uint32()
case unix.NFTA_CT_DIRECTION: case unix.NFTA_CT_DIRECTION:
e.Direction = ad.Uint32() e.Direction = uint32(ad.Uint8())
hasDirection = true
case unix.NFTA_CT_SREG: case unix.NFTA_CT_SREG:
e.SourceRegister = true e.SourceRegister = true
e.Register = ad.Uint32() e.Register = ad.Uint32()
} }
} }
switch e.Key {
case CtKeyPKTS, CtKeyBYTES, CtKeyAVGPKT, CtKeyL3PROTOCOL, CtKeyPROTOCOL:
e.OptDirection = hasDirection
}
return ad.Err() return ad.Err()
} }

View File

@ -78,6 +78,22 @@ func TestCt(t *testing.T) {
Direction: 1, Direction: 1,
}, },
}, },
{
name: "Unmarshal Ct packets direction original case",
ct: Ct{
Register: 1,
Key: CtKeyPKTS,
Direction: CtDirOriginal,
OptDirection: true,
},
},
{
name: "Unmarshal Ct bytes without direction case",
ct: Ct{
Register: 1,
Key: CtKeyBYTES,
},
},
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -66,7 +66,7 @@ func (e *Immediate) unmarshal(fam byte, data []byte) error {
case unix.NFTA_IMMEDIATE_DATA: case unix.NFTA_IMMEDIATE_DATA:
nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes()) nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes())
if err != nil { if err != nil {
return fmt.Errorf("nested NewAttributeDecoder() failed: %v", err) return fmt.Errorf("nested NewAttributeDecoder() failed: %w", err)
} }
for nestedAD.Next() { for nestedAD.Next() {
switch nestedAD.Type() { switch nestedAD.Type() {
@ -75,7 +75,7 @@ func (e *Immediate) unmarshal(fam byte, data []byte) error {
} }
} }
if nestedAD.Err() != nil { if nestedAD.Err() != nil {
return fmt.Errorf("decoding immediate: %v", nestedAD.Err()) return fmt.Errorf("decoding immediate: %w", nestedAD.Err())
} }
} }
} }

View File

@ -111,7 +111,7 @@ func (e *Verdict) unmarshal(fam byte, data []byte) error {
case unix.NFTA_IMMEDIATE_DATA: case unix.NFTA_IMMEDIATE_DATA:
nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes()) nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes())
if err != nil { if err != nil {
return fmt.Errorf("nested NewAttributeDecoder() failed: %v", err) return fmt.Errorf("nested NewAttributeDecoder() failed: %w", err)
} }
for nestedAD.Next() { for nestedAD.Next() {
switch nestedAD.Type() { switch nestedAD.Type() {
@ -123,7 +123,7 @@ func (e *Verdict) unmarshal(fam byte, data []byte) error {
} }
} }
if nestedAD.Err() != nil { if nestedAD.Err() != nil {
return fmt.Errorf("decoding immediate: %v", nestedAD.Err()) return fmt.Errorf("decoding immediate: %w", nestedAD.Err())
} }
} }
} }

View File

@ -214,12 +214,12 @@ func (cc *Conn) getFlowtables(t *Table) ([]netlink.Message, error) {
} }
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err) return nil, fmt.Errorf("SendMessages: %w", err)
} }
reply, err := receiveAckAware(conn, message.Header.Flags) reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil { if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err) return nil, fmt.Errorf("receiveAckAware: %w", err)
} }
return reply, nil return reply, nil

4
obj.go
View File

@ -361,12 +361,12 @@ func (cc *Conn) getObjWithLegacyType(o Obj, t *Table, msgType uint16, returnLega
} }
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err) return nil, fmt.Errorf("SendMessages: %w", err)
} }
reply, err := receiveAckAware(conn, message.Header.Flags) reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil { if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err) return nil, fmt.Errorf("receiveAckAware: %w", err)
} }
var objs []Obj var objs []Obj
for _, msg := range reply { for _, msg := range reply {

View File

@ -101,12 +101,12 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
} }
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err) return nil, fmt.Errorf("SendMessages: %w", err)
} }
reply, err := receiveAckAware(conn, message.Header.Flags) reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil { if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err) return nil, fmt.Errorf("receiveAckAware: %w", err)
} }
var rules []*Rule var rules []*Rule
for _, msg := range reply { for _, msg := range reply {

34
set.go
View File

@ -298,7 +298,7 @@ func (s *SetElement) decode(fam byte) func(b []byte) error {
return func(b []byte) error { return func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b) ad, err := netlink.NewAttributeDecoder(b)
if err != nil { if err != nil {
return fmt.Errorf("failed to create nested attribute decoder: %v", err) return fmt.Errorf("failed to create nested attribute decoder: %w", err)
} }
ad.ByteOrder = binary.BigEndian ad.ByteOrder = binary.BigEndian
@ -353,7 +353,7 @@ func (s *SetElement) decode(fam byte) func(b []byte) error {
func decodeElement(d []byte) ([]byte, error) { func decodeElement(d []byte) ([]byte, error) {
ad, err := netlink.NewAttributeDecoder(d) ad, err := netlink.NewAttributeDecoder(d)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create nested attribute decoder: %v", err) return nil, fmt.Errorf("failed to create nested attribute decoder: %w", err)
} }
ad.ByteOrder = binary.BigEndian ad.ByteOrder = binary.BigEndian
var b []byte var b []byte
@ -414,14 +414,14 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}}) encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}})
if err != nil { if err != nil {
return fmt.Errorf("marshal key %d: %v", i, err) return fmt.Errorf("marshal key %d: %w", i, err)
} }
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey})
if len(v.KeyEnd) > 0 { if len(v.KeyEnd) > 0 {
encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}}) encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}})
if err != nil { if err != nil {
return fmt.Errorf("marshal key end %d: %v", i, err) return fmt.Errorf("marshal key end %d: %w", i, err)
} }
item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd}) item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd})
} }
@ -441,7 +441,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))}, {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))},
}) })
if err != nil { if err != nil {
return fmt.Errorf("marshal item %d: %v", i, err) return fmt.Errorf("marshal item %d: %w", i, err)
} }
encodedVal = append(encodedVal, encodedKind...) encodedVal = append(encodedVal, encodedKind...)
if len(v.VerdictData.Chain) != 0 { if len(v.VerdictData.Chain) != 0 {
@ -449,21 +449,21 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
{Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")}, {Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")},
}) })
if err != nil { if err != nil {
return fmt.Errorf("marshal item %d: %v", i, err) return fmt.Errorf("marshal item %d: %w", i, err)
} }
encodedVal = append(encodedVal, encodedChain...) encodedVal = append(encodedVal, encodedChain...)
} }
encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{ encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}}) {Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}})
if err != nil { if err != nil {
return fmt.Errorf("marshal item %d: %v", i, err) return fmt.Errorf("marshal item %d: %w", i, err)
} }
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict}) item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict})
case len(v.Val) > 0: case len(v.Val) > 0:
// Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes // Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes
encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}}) encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}})
if err != nil { if err != nil {
return fmt.Errorf("marshal item %d: %v", i, err) return fmt.Errorf("marshal item %d: %w", i, err)
} }
item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}) item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal})
@ -479,7 +479,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
encodedItem, err := netlink.MarshalAttributes(item) encodedItem, err := netlink.MarshalAttributes(item)
if err != nil { if err != nil {
return fmt.Errorf("marshal item %d: %v", i, err) return fmt.Errorf("marshal item %d: %w", i, err)
} }
itemSize := unix.NLA_HDRLEN + len(encodedItem) itemSize := unix.NLA_HDRLEN + len(encodedItem)
@ -496,7 +496,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
for _, batch := range batches { for _, batch := range batches {
encodedElem, err := netlink.MarshalAttributes(batch) encodedElem, err := netlink.MarshalAttributes(batch)
if err != nil { if err != nil {
return fmt.Errorf("marshal elements: %v", err) return fmt.Errorf("marshal elements: %w", err)
} }
message := []netlink.Attribute{ message := []netlink.Attribute{
@ -591,7 +591,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))}, {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))},
}) })
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err) return fmt.Errorf("fail to marshal number of elements %d: %w", len(vals), err)
} }
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements}) tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements})
} }
@ -620,7 +620,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
{Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)}, {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)},
}) })
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal element key size %d: %v", i, err) return fmt.Errorf("fail to marshal element key size %d: %w", i, err)
} }
// Marshal base type size description // Marshal base type size description
descSize, err := netlink.MarshalAttributes([]netlink.Attribute{ descSize, err := netlink.MarshalAttributes([]netlink.Attribute{
@ -634,7 +634,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
// Marshal all base type descriptions into concatenation size description // Marshal all base type descriptions into concatenation size description
concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}}) concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}})
if err != nil { if err != nil {
return fmt.Errorf("fail to marshal concat definition %v", err) return fmt.Errorf("fail to marshal concat definition %w", err)
} }
descBytes = append(descBytes, concatBytes...) descBytes = append(descBytes, concatBytes...)
@ -890,12 +890,12 @@ func (cc *Conn) GetSets(t *Table) ([]*Set, error) {
} }
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err) return nil, fmt.Errorf("SendMessages: %w", err)
} }
reply, err := receiveAckAware(conn, message.Header.Flags) reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil { if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err) return nil, fmt.Errorf("receiveAckAware: %w", err)
} }
var sets []*Set var sets []*Set
for _, msg := range reply { for _, msg := range reply {
@ -980,12 +980,12 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) {
} }
if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
return nil, fmt.Errorf("SendMessages: %v", err) return nil, fmt.Errorf("SendMessages: %w", err)
} }
reply, err := receiveAckAware(conn, message.Header.Flags) reply, err := receiveAckAware(conn, message.Header.Flags)
if err != nil { if err != nil {
return nil, fmt.Errorf("receiveAckAware: %v", err) return nil, fmt.Errorf("receiveAckAware: %w", err)
} }
var elems []SetElement var elems []SetElement
for _, msg := range reply { for _, msg := range reply {