From b77f1a918e8ab3b82b90032576dc46b8f805138d Mon Sep 17 00:00:00 2001 From: turekt <32360115+turekt@users.noreply.github.com> Date: Wed, 17 Apr 2024 13:04:40 +0000 Subject: [PATCH] Objects implementation refactor Refactored obj.go to a more generic approach Added object support for already implemented expressions Added test for limit object Fixes https://github.com/google/nftables/issues/253 --- counter.go | 17 +- expr/bitwise.go | 20 ++- expr/byteorder.go | 18 ++- expr/connlimit.go | 12 +- expr/counter.go | 12 +- expr/ct.go | 20 ++- expr/dup.go | 23 +-- expr/dynset.go | 26 +-- expr/expr.go | 201 ++++++++++++++---------- expr/exthdr.go | 20 ++- expr/fib.go | 17 +- expr/flow_offload.go | 10 +- expr/hash.go | 20 ++- expr/immediate.go | 20 ++- expr/limit.go | 22 +-- expr/log.go | 22 +-- expr/lookup.go | 22 +-- expr/match.go | 21 +-- expr/nat.go | 20 ++- expr/notrack.go | 4 + expr/numgen.go | 22 +-- expr/objref.go | 12 +- expr/payload.go | 21 +-- expr/queue.go | 20 ++- expr/quota.go | 22 +-- expr/range.go | 18 ++- expr/redirect.go | 22 +-- expr/reject.go | 12 +- expr/rt.go | 12 +- expr/socket.go | 21 +-- expr/target.go | 21 +-- expr/tproxy.go | 20 ++- expr/verdict.go | 19 ++- internal/parseexprfunc/parseexprfunc.go | 4 +- nftables_test.go | 137 +++++++++++++--- obj.go | 144 ++++++++++++----- quota.go | 25 ++- set.go | 2 +- 38 files changed, 727 insertions(+), 374 deletions(-) diff --git a/counter.go b/counter.go index 25d37d8..34c36aa 100644 --- a/counter.go +++ b/counter.go @@ -16,11 +16,12 @@ package nftables import ( "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" ) -// CounterObj implements Obj. +// Deprecated: Use ObjAttr instead type CounterObj struct { Table *Table Name string // e.g. “fwded” @@ -41,6 +42,20 @@ func (c *CounterObj) unmarshal(ad *netlink.AttributeDecoder) error { return ad.Err() } +func (c *CounterObj) data() expr.Any { + return &expr.Counter{ + Bytes: c.Bytes, + Packets: c.Packets, + } +} + +func (c *CounterObj) name() string { + return c.Name +} +func (c *CounterObj) objType() ObjType { + return ObjTypeCounter +} + func (c *CounterObj) table() *Table { return c.Table } diff --git a/expr/bitwise.go b/expr/bitwise.go index 62f7f9b..5f3cdea 100644 --- a/expr/bitwise.go +++ b/expr/bitwise.go @@ -31,6 +31,17 @@ type Bitwise struct { } func (e *Bitwise) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("bitwise\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Bitwise) marshalData(fam byte) ([]byte, error) { mask, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_DATA_VALUE, Data: e.Mask}, }) @@ -44,20 +55,13 @@ func (e *Bitwise) marshal(fam byte) ([]byte, error) { return nil, err } - data, err := netlink.MarshalAttributes([]netlink.Attribute{ + return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_BITWISE_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, {Type: unix.NFTA_BITWISE_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, {Type: unix.NFTA_BITWISE_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)}, {Type: unix.NLA_F_NESTED | unix.NFTA_BITWISE_MASK, Data: mask}, {Type: unix.NLA_F_NESTED | unix.NFTA_BITWISE_XOR, Data: xor}, }) - if err != nil { - return nil, err - } - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("bitwise\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, - }) } func (e *Bitwise) unmarshal(fam byte, data []byte) error { diff --git a/expr/byteorder.go b/expr/byteorder.go index 2450e8f..cf9e2fe 100644 --- a/expr/byteorder.go +++ b/expr/byteorder.go @@ -38,13 +38,7 @@ type Byteorder struct { } func (e *Byteorder) marshal(fam byte) ([]byte, error) { - data, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_BYTEORDER_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, - {Type: unix.NFTA_BYTEORDER_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, - {Type: unix.NFTA_BYTEORDER_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}, - {Type: unix.NFTA_BYTEORDER_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)}, - {Type: unix.NFTA_BYTEORDER_SIZE, Data: binaryutil.BigEndian.PutUint32(e.Size)}, - }) + data, err := e.marshalData(fam) if err != nil { return nil, err } @@ -54,6 +48,16 @@ func (e *Byteorder) marshal(fam byte) ([]byte, error) { }) } +func (e *Byteorder) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_BYTEORDER_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, + {Type: unix.NFTA_BYTEORDER_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, + {Type: unix.NFTA_BYTEORDER_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}, + {Type: unix.NFTA_BYTEORDER_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)}, + {Type: unix.NFTA_BYTEORDER_SIZE, Data: binaryutil.BigEndian.PutUint32(e.Size)}, + }) +} + func (e *Byteorder) unmarshal(fam byte, data []byte) error { return fmt.Errorf("not yet implemented") } diff --git a/expr/connlimit.go b/expr/connlimit.go index b712967..11bd07b 100644 --- a/expr/connlimit.go +++ b/expr/connlimit.go @@ -37,10 +37,7 @@ type Connlimit struct { } func (e *Connlimit) marshal(fam byte) ([]byte, error) { - data, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: NFTA_CONNLIMIT_COUNT, Data: binaryutil.BigEndian.PutUint32(e.Count)}, - {Type: NFTA_CONNLIMIT_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}, - }) + data, err := e.marshalData(fam) if err != nil { return nil, err } @@ -51,6 +48,13 @@ func (e *Connlimit) marshal(fam byte) ([]byte, error) { }) } +func (e *Connlimit) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: NFTA_CONNLIMIT_COUNT, Data: binaryutil.BigEndian.PutUint32(e.Count)}, + {Type: NFTA_CONNLIMIT_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}, + }) +} + func (e *Connlimit) unmarshal(fam byte, data []byte) error { ad, err := netlink.NewAttributeDecoder(data) if err != nil { diff --git a/expr/counter.go b/expr/counter.go index dd6eab3..7483ee4 100644 --- a/expr/counter.go +++ b/expr/counter.go @@ -28,10 +28,7 @@ type Counter struct { } func (e *Counter) marshal(fam byte) ([]byte, error) { - data, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_COUNTER_BYTES, Data: binaryutil.BigEndian.PutUint64(e.Bytes)}, - {Type: unix.NFTA_COUNTER_PACKETS, Data: binaryutil.BigEndian.PutUint64(e.Packets)}, - }) + data, err := e.marshalData(fam) if err != nil { return nil, err } @@ -42,6 +39,13 @@ func (e *Counter) marshal(fam byte) ([]byte, error) { }) } +func (e *Counter) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_COUNTER_BYTES, Data: binaryutil.BigEndian.PutUint64(e.Bytes)}, + {Type: unix.NFTA_COUNTER_PACKETS, Data: binaryutil.BigEndian.PutUint64(e.Packets)}, + }) +} + func (e *Counter) unmarshal(fam byte, data []byte) error { ad, err := netlink.NewAttributeDecoder(data) if err != nil { diff --git a/expr/ct.go b/expr/ct.go index 1a0ee68..4efea02 100644 --- a/expr/ct.go +++ b/expr/ct.go @@ -64,7 +64,19 @@ type Ct struct { } func (e *Ct) marshal(fam byte) ([]byte, error) { - regData := []byte{} + exprData, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("ct\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (e *Ct) marshalData(fam byte) ([]byte, error) { + var regData []byte exprData, err := netlink.MarshalAttributes( []netlink.Attribute{ {Type: unix.NFTA_CT_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, @@ -90,11 +102,7 @@ func (e *Ct) marshal(fam byte) ([]byte, error) { return nil, err } exprData = append(exprData, regData...) - - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("ct\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, - }) + return exprData, nil } func (e *Ct) unmarshal(fam byte, data []byte) error { diff --git a/expr/dup.go b/expr/dup.go index 0114fa7..9012fda 100644 --- a/expr/dup.go +++ b/expr/dup.go @@ -29,16 +29,7 @@ type Dup struct { } func (e *Dup) marshal(fam byte) ([]byte, error) { - attrs := []netlink.Attribute{ - {Type: unix.NFTA_DUP_SREG_ADDR, Data: binaryutil.BigEndian.PutUint32(e.RegAddr)}, - } - - if e.IsRegDevSet { - attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_DUP_SREG_DEV, Data: binaryutil.BigEndian.PutUint32(e.RegDev)}) - } - - data, err := netlink.MarshalAttributes(attrs) - + data, err := e.marshalData(fam) if err != nil { return nil, err } @@ -49,6 +40,18 @@ func (e *Dup) marshal(fam byte) ([]byte, error) { }) } +func (e *Dup) marshalData(fam byte) ([]byte, error) { + attrs := []netlink.Attribute{ + {Type: unix.NFTA_DUP_SREG_ADDR, Data: binaryutil.BigEndian.PutUint32(e.RegAddr)}, + } + + if e.IsRegDevSet { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_DUP_SREG_DEV, Data: binaryutil.BigEndian.PutUint32(e.RegDev)}) + } + + return netlink.MarshalAttributes(attrs) +} + func (e *Dup) unmarshal(fam byte, data []byte) error { ad, err := netlink.NewAttributeDecoder(data) if err != nil { diff --git a/expr/dynset.go b/expr/dynset.go index e44f772..aa6bc79 100644 --- a/expr/dynset.go +++ b/expr/dynset.go @@ -44,6 +44,18 @@ type Dynset struct { } func (e *Dynset) marshal(fam byte) ([]byte, error) { + opData, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("dynset\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: opData}, + }) +} + +func (e *Dynset) marshalData(fam byte) ([]byte, error) { // See: https://git.netfilter.org/libnftnl/tree/src/expr/dynset.c var opAttrs []netlink.Attribute opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_KEY, Data: binaryutil.BigEndian.PutUint32(e.SrcRegKey)}) @@ -89,17 +101,9 @@ func (e *Dynset) marshal(fam byte) ([]byte, error) { opAttrs = append(opAttrs, netlink.Attribute{Type: NFTA_DYNSET_EXPRESSIONS, Data: elemData}) } } + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}) - - opData, err := netlink.MarshalAttributes(opAttrs) - if err != nil { - return nil, err - } - - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("dynset\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: opData}, - }) + return netlink.MarshalAttributes(opAttrs) } func (e *Dynset) unmarshal(fam byte, data []byte) error { @@ -125,7 +129,7 @@ func (e *Dynset) unmarshal(fam byte, data []byte) error { case unix.NFTA_DYNSET_FLAGS: e.Invert = (ad.Uint32() & unix.NFT_DYNSET_F_INV) != 0 case unix.NFTA_DYNSET_EXPR: - exprs, err := parseexprfunc.ParseExprBytesFunc(fam, ad, ad.Bytes()) + exprs, err := parseexprfunc.ParseExprBytesFunc(fam, ad) if err != nil { return err } diff --git a/expr/expr.go b/expr/expr.go index a4d970f..2840835 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -25,8 +25,8 @@ import ( ) func init() { - parseexprfunc.ParseExprBytesFunc = func(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]interface{}, error) { - exprs, err := exprsFromBytes(fam, ad, b) + parseexprfunc.ParseExprBytesFunc = func(fam byte, ad *netlink.AttributeDecoder, args ...string) ([]interface{}, error) { + exprs, err := exprsFromBytes(fam, ad, args...) if err != nil { return nil, err } @@ -36,7 +36,7 @@ func init() { } return result, nil } - parseexprfunc.ParseExprMsgFunc = func(fam byte, b []byte) ([]interface{}, error) { + parseexprfunc.ParseExprMsgFunc = func(fam byte, b []byte, args ...string) ([]interface{}, error) { ad, err := netlink.NewAttributeDecoder(b) if err != nil { return nil, err @@ -44,7 +44,7 @@ func init() { ad.ByteOrder = binary.BigEndian var exprs []interface{} for ad.Next() { - e, err := parseexprfunc.ParseExprBytesFunc(fam, ad, b) + e, err := parseexprfunc.ParseExprBytesFunc(fam, ad, args...) if err != nil { return e, err } @@ -59,6 +59,10 @@ func Marshal(fam byte, e Any) ([]byte, error) { return e.marshal(fam) } +func MarshalExprData(fam byte, e Any) ([]byte, error) { + return e.marshalData(fam) +} + // Unmarshal fills an expression from the specified byte slice. func Unmarshal(fam byte, data []byte, e Any) error { return e.unmarshal(fam, data) @@ -66,8 +70,20 @@ func Unmarshal(fam byte, data []byte, e Any) error { // exprsFromBytes parses nested raw expressions bytes // to construct nftables expressions -func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]Any, error) { +func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, args ...string) ([]Any, error) { var exprs []Any + if len(args) > 0 { + e := exprFromName(args[0]) + ad.Do(func(b []byte) error { + if err := Unmarshal(fam, b, e); err != nil { + return err + } + exprs = append(exprs, e) + return nil + }) + return exprs, ad.Err() + } + ad.Do(func(b []byte) error { ad, err := netlink.NewAttributeDecoder(b) if err != nil { @@ -84,65 +100,12 @@ func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]Any, er exprs = append(exprs, e) } case unix.NFTA_EXPR_DATA: - var e Any - switch name { - case "ct": - e = &Ct{} - case "range": - e = &Range{} - case "meta": - e = &Meta{} - case "cmp": - e = &Cmp{} - case "counter": - e = &Counter{} - case "objref": - e = &Objref{} - case "payload": - e = &Payload{} - case "lookup": - e = &Lookup{} - case "immediate": - e = &Immediate{} - case "bitwise": - e = &Bitwise{} - case "redir": - e = &Redir{} - case "nat": - e = &NAT{} - case "limit": - e = &Limit{} - case "quota": - e = &Quota{} - case "dynset": - e = &Dynset{} - case "log": - e = &Log{} - case "exthdr": - e = &Exthdr{} - case "match": - e = &Match{} - case "target": - e = &Target{} - case "connlimit": - e = &Connlimit{} - case "queue": - e = &Queue{} - case "flow_offload": - e = &FlowOffload{} - case "reject": - e = &Reject{} - case "masq": - e = &Masq{} - case "hash": - e = &Hash{} - } + e := exprFromName(name) if e == nil { // TODO: introduce an opaque expression type so that users know // something is here. continue // unsupported expression type } - ad.Do(func(b []byte) error { if err := Unmarshal(fam, b, e); err != nil { return err @@ -166,9 +129,67 @@ func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]Any, er return exprs, ad.Err() } +func exprFromName(name string) Any { + var e Any + switch name { + case "ct": + e = &Ct{} + case "range": + e = &Range{} + case "meta": + e = &Meta{} + case "cmp": + e = &Cmp{} + case "counter": + e = &Counter{} + case "objref": + e = &Objref{} + case "payload": + e = &Payload{} + case "lookup": + e = &Lookup{} + case "immediate": + e = &Immediate{} + case "bitwise": + e = &Bitwise{} + case "redir": + e = &Redir{} + case "nat": + e = &NAT{} + case "limit": + e = &Limit{} + case "quota": + e = &Quota{} + case "dynset": + e = &Dynset{} + case "log": + e = &Log{} + case "exthdr": + e = &Exthdr{} + case "match": + e = &Match{} + case "target": + e = &Target{} + case "connlimit": + e = &Connlimit{} + case "queue": + e = &Queue{} + case "flow_offload": + e = &FlowOffload{} + case "reject": + e = &Reject{} + case "masq": + e = &Masq{} + case "hash": + e = &Hash{} + } + return e +} + // Any is an interface implemented by any expression type. type Any interface { marshal(fam byte) ([]byte, error) + marshalData(fam byte) ([]byte, error) unmarshal(fam byte, data []byte) error } @@ -214,7 +235,19 @@ type Meta struct { } func (e *Meta) marshal(fam byte) ([]byte, error) { - regData := []byte{} + exprData, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("meta\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (e *Meta) marshalData(fam byte) ([]byte, error) { + var regData []byte exprData, err := netlink.MarshalAttributes( []netlink.Attribute{ {Type: unix.NFTA_META_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, @@ -240,11 +273,7 @@ func (e *Meta) marshal(fam byte) ([]byte, error) { return nil, err } exprData = append(exprData, regData...) - - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("meta\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, - }) + return exprData, nil } func (e *Meta) unmarshal(fam byte, data []byte) error { @@ -291,6 +320,17 @@ const ( ) func (e *Masq) marshal(fam byte) ([]byte, error) { + msgData, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("masq\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: msgData}, + }) +} + +func (e *Masq) marshalData(fam byte) ([]byte, error) { msgData := []byte{} if !e.ToPorts { flags := uint32(0) @@ -327,10 +367,7 @@ func (e *Masq) marshal(fam byte) ([]byte, error) { msgData = append(msgData, regsData...) } } - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("masq\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: msgData}, - }) + return msgData, nil } func (e *Masq) unmarshal(fam byte, data []byte) error { @@ -377,17 +414,7 @@ type Cmp struct { } func (e *Cmp) marshal(fam byte) ([]byte, error) { - cmpData, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_DATA_VALUE, Data: e.Data}, - }) - if err != nil { - return nil, err - } - exprData, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_CMP_SREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, - {Type: unix.NFTA_CMP_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}, - {Type: unix.NLA_F_NESTED | unix.NFTA_CMP_DATA, Data: cmpData}, - }) + exprData, err := e.marshalData(fam) if err != nil { return nil, err } @@ -397,6 +424,20 @@ func (e *Cmp) marshal(fam byte) ([]byte, error) { }) } +func (e *Cmp) marshalData(fam byte) ([]byte, error) { + cmpData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: e.Data}, + }) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_CMP_SREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + {Type: unix.NFTA_CMP_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}, + {Type: unix.NLA_F_NESTED | unix.NFTA_CMP_DATA, Data: cmpData}, + }) +} + func (e *Cmp) unmarshal(fam byte, data []byte) error { ad, err := netlink.NewAttributeDecoder(data) if err != nil { diff --git a/expr/exthdr.go b/expr/exthdr.go index df0c7db..0a9d9fc 100644 --- a/expr/exthdr.go +++ b/expr/exthdr.go @@ -40,6 +40,17 @@ type Exthdr struct { } func (e *Exthdr) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("exthdr\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Exthdr) marshalData(fam byte) ([]byte, error) { var attr []netlink.Attribute // Operations are differentiated by the Op and whether the SourceRegister @@ -64,14 +75,7 @@ func (e *Exthdr) marshal(fam byte) ([]byte, error) { netlink.Attribute{Type: unix.NFTA_EXTHDR_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}) } - data, err := netlink.MarshalAttributes(attr) - if err != nil { - return nil, err - } - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("exthdr\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, - }) + return netlink.MarshalAttributes(attr) } func (e *Exthdr) unmarshal(fam byte, data []byte) error { diff --git a/expr/fib.go b/expr/fib.go index f7ee704..62a717c 100644 --- a/expr/fib.go +++ b/expr/fib.go @@ -37,6 +37,17 @@ type Fib struct { } func (e *Fib) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("fib\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Fib) marshalData(fam byte) ([]byte, error) { data := []byte{} reg, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_FIB_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, @@ -92,11 +103,7 @@ func (e *Fib) marshal(fam byte) ([]byte, error) { } data = append(data, rslt...) } - - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("fib\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, - }) + return data, nil } func (e *Fib) unmarshal(fam byte, data []byte) error { diff --git a/expr/flow_offload.go b/expr/flow_offload.go index 54f956f..de4949a 100644 --- a/expr/flow_offload.go +++ b/expr/flow_offload.go @@ -28,9 +28,7 @@ type FlowOffload struct { } func (e *FlowOffload) marshal(fam byte) ([]byte, error) { - data, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: NFTNL_EXPR_FLOW_TABLE_NAME, Data: []byte(e.Name)}, - }) + data, err := e.marshalData(fam) if err != nil { return nil, err } @@ -41,6 +39,12 @@ func (e *FlowOffload) marshal(fam byte) ([]byte, error) { }) } +func (e *FlowOffload) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: NFTNL_EXPR_FLOW_TABLE_NAME, Data: []byte(e.Name)}, + }) +} + func (e *FlowOffload) unmarshal(fam byte, data []byte) error { ad, err := netlink.NewAttributeDecoder(data) if err != nil { diff --git a/expr/hash.go b/expr/hash.go index e8506b9..92b9eea 100644 --- a/expr/hash.go +++ b/expr/hash.go @@ -41,6 +41,17 @@ type Hash struct { } func (e *Hash) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("hash\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Hash) marshalData(fam byte) ([]byte, error) { hashAttrs := []netlink.Attribute{ {Type: unix.NFTA_HASH_SREG, Data: binaryutil.BigEndian.PutUint32(uint32(e.SourceRegister))}, {Type: unix.NFTA_HASH_DREG, Data: binaryutil.BigEndian.PutUint32(uint32(e.DestRegister))}, @@ -56,14 +67,7 @@ func (e *Hash) marshal(fam byte) ([]byte, error) { {Type: unix.NFTA_HASH_OFFSET, Data: binaryutil.BigEndian.PutUint32(uint32(e.Offset))}, {Type: unix.NFTA_HASH_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Type))}, }...) - data, err := netlink.MarshalAttributes(hashAttrs) - if err != nil { - return nil, err - } - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("hash\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, - }) + return netlink.MarshalAttributes(hashAttrs) } func (e *Hash) unmarshal(fam byte, data []byte) error { diff --git a/expr/immediate.go b/expr/immediate.go index 99531f8..19eea44 100644 --- a/expr/immediate.go +++ b/expr/immediate.go @@ -29,6 +29,17 @@ type Immediate struct { } func (e *Immediate) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("immediate\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Immediate) marshalData(fam byte) ([]byte, error) { immData, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_DATA_VALUE, Data: e.Data}, }) @@ -36,17 +47,10 @@ func (e *Immediate) marshal(fam byte) ([]byte, error) { return nil, err } - data, err := netlink.MarshalAttributes([]netlink.Attribute{ + return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_IMMEDIATE_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, {Type: unix.NLA_F_NESTED | unix.NFTA_IMMEDIATE_DATA, Data: immData}, }) - if err != nil { - return nil, err - } - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("immediate\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, - }) } func (e *Immediate) unmarshal(fam byte, data []byte) error { diff --git a/expr/limit.go b/expr/limit.go index 9ecb41f..9d2facc 100644 --- a/expr/limit.go +++ b/expr/limit.go @@ -72,6 +72,18 @@ type Limit struct { } func (l *Limit) marshal(fam byte) ([]byte, error) { + data, err := l.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("limit\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (l *Limit) marshalData(fam byte) ([]byte, error) { var flags uint32 if l.Over { flags = unix.NFT_LIMIT_F_INV @@ -84,15 +96,7 @@ func (l *Limit) marshal(fam byte) ([]byte, error) { {Type: unix.NFTA_LIMIT_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}, } - data, err := netlink.MarshalAttributes(attrs) - if err != nil { - return nil, err - } - - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("limit\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, - }) + return netlink.MarshalAttributes(attrs) } func (l *Limit) unmarshal(fam byte, data []byte) error { diff --git a/expr/log.go b/expr/log.go index a712b99..eaa057a 100644 --- a/expr/log.go +++ b/expr/log.go @@ -69,6 +69,18 @@ type Log struct { } func (e *Log) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("log\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Log) marshalData(fam byte) ([]byte, error) { // Per https://git.netfilter.org/libnftnl/tree/src/expr/log.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n129 attrs := make([]netlink.Attribute, 0) if e.Key&(1< 0 { attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_REDIR_REG_PROTO_MIN, Data: binaryutil.BigEndian.PutUint32(e.RegisterProtoMin)}) @@ -40,15 +52,7 @@ func (e *Redir) marshal(fam byte) ([]byte, error) { attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_REDIR_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}) } - data, err := netlink.MarshalAttributes(attrs) - if err != nil { - return nil, err - } - - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("redir\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, - }) + return netlink.MarshalAttributes(attrs) } func (e *Redir) unmarshal(fam byte, data []byte) error { diff --git a/expr/reject.go b/expr/reject.go index a742626..7fe216d 100644 --- a/expr/reject.go +++ b/expr/reject.go @@ -28,10 +28,7 @@ type Reject struct { } func (e *Reject) marshal(fam byte) ([]byte, error) { - data, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_REJECT_TYPE, Data: binaryutil.BigEndian.PutUint32(e.Type)}, - {Type: unix.NFTA_REJECT_ICMP_CODE, Data: []byte{e.Code}}, - }) + data, err := e.marshalData(fam) if err != nil { return nil, err } @@ -41,6 +38,13 @@ func (e *Reject) marshal(fam byte) ([]byte, error) { }) } +func (e *Reject) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_REJECT_TYPE, Data: binaryutil.BigEndian.PutUint32(e.Type)}, + {Type: unix.NFTA_REJECT_ICMP_CODE, Data: []byte{e.Code}}, + }) +} + func (e *Reject) unmarshal(fam byte, data []byte) error { ad, err := netlink.NewAttributeDecoder(data) if err != nil { diff --git a/expr/rt.go b/expr/rt.go index c3be7ff..21c3a63 100644 --- a/expr/rt.go +++ b/expr/rt.go @@ -37,10 +37,7 @@ type Rt struct { } func (e *Rt) marshal(fam byte) ([]byte, error) { - data, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_RT_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, - {Type: unix.NFTA_RT_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, - }) + data, err := e.marshalData(fam) if err != nil { return nil, err } @@ -50,6 +47,13 @@ func (e *Rt) marshal(fam byte) ([]byte, error) { }) } +func (e *Rt) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RT_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, + {Type: unix.NFTA_RT_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }) +} + func (e *Rt) unmarshal(fam byte, data []byte) error { return fmt.Errorf("not yet implemented") } diff --git a/expr/socket.go b/expr/socket.go index 1b6bc24..e3843cc 100644 --- a/expr/socket.go +++ b/expr/socket.go @@ -49,23 +49,26 @@ const ( ) func (e *Socket) marshal(fam byte) ([]byte, error) { + exprData, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("socket\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (e *Socket) marshalData(fam byte) ([]byte, error) { // NOTE: Socket.Level is only used when Socket.Key == SocketKeyCgroupv2. But `nft` always encoding it. Check link below: // http://git.netfilter.org/nftables/tree/src/netlink_linearize.c?id=0583bac241ea18c9d7f61cb20ca04faa1e043b78#n319 - exprData, err := netlink.MarshalAttributes( + return netlink.MarshalAttributes( []netlink.Attribute{ {Type: NFTA_SOCKET_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, {Type: NFTA_SOCKET_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, {Type: NFTA_SOCKET_LEVEL, Data: binaryutil.BigEndian.PutUint32(uint32(e.Level))}, }, ) - if err != nil { - return nil, err - } - - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("socket\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, - }) } func (e *Socket) unmarshal(fam byte, data []byte) error { diff --git a/expr/target.go b/expr/target.go index e531a9f..d1c800b 100644 --- a/expr/target.go +++ b/expr/target.go @@ -21,6 +21,17 @@ type Target struct { } func (e *Target) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("target\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Target) marshalData(fam byte) ([]byte, error) { // Per https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n38 name := e.Name // limit the extension name as (some) user-space tools do and leave room for @@ -40,15 +51,7 @@ func (e *Target) marshal(fam byte) ([]byte, error) { {Type: unix.NFTA_TARGET_INFO, Data: info}, } - data, err := netlink.MarshalAttributes(attrs) - if err != nil { - return nil, err - } - - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("target\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, - }) + return netlink.MarshalAttributes(attrs) } func (e *Target) unmarshal(fam byte, data []byte) error { diff --git a/expr/tproxy.go b/expr/tproxy.go index 2846aab..142740c 100644 --- a/expr/tproxy.go +++ b/expr/tproxy.go @@ -40,6 +40,17 @@ type TProxy struct { } func (e *TProxy) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("tproxy\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *TProxy) marshalData(fam byte) ([]byte, error) { attrs := []netlink.Attribute{ {Type: NFTA_TPROXY_FAMILY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Family))}, {Type: NFTA_TPROXY_REG_PORT, Data: binaryutil.BigEndian.PutUint32(e.RegPort)}, @@ -52,14 +63,7 @@ func (e *TProxy) marshal(fam byte) ([]byte, error) { }) } - data, err := netlink.MarshalAttributes(attrs) - if err != nil { - return nil, err - } - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("tproxy\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, - }) + return netlink.MarshalAttributes(attrs) } func (e *TProxy) unmarshal(fam byte, data []byte) error { diff --git a/expr/verdict.go b/expr/verdict.go index 421fa06..239b408 100644 --- a/expr/verdict.go +++ b/expr/verdict.go @@ -64,7 +64,17 @@ func (e *Verdict) marshal(fam byte) ([]byte, error) { // } // } // } + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("immediate\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} +func (e *Verdict) marshalData(fam byte) ([]byte, error) { attrs := []netlink.Attribute{ {Type: unix.NFTA_VERDICT_CODE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Kind))}, } @@ -83,17 +93,10 @@ func (e *Verdict) marshal(fam byte) ([]byte, error) { return nil, err } - data, err := netlink.MarshalAttributes([]netlink.Attribute{ + return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_IMMEDIATE_DREG, Data: binaryutil.BigEndian.PutUint32(unix.NFT_REG_VERDICT)}, {Type: unix.NLA_F_NESTED | unix.NFTA_IMMEDIATE_DATA, Data: immData}, }) - if err != nil { - return nil, err - } - return netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXPR_NAME, Data: []byte("immediate\x00")}, - {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, - }) } func (e *Verdict) unmarshal(fam byte, data []byte) error { diff --git a/internal/parseexprfunc/parseexprfunc.go b/internal/parseexprfunc/parseexprfunc.go index 523859d..ae840b4 100644 --- a/internal/parseexprfunc/parseexprfunc.go +++ b/internal/parseexprfunc/parseexprfunc.go @@ -5,6 +5,6 @@ import ( ) var ( - ParseExprBytesFunc func(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]interface{}, error) - ParseExprMsgFunc func(fam byte, b []byte) ([]interface{}, error) + ParseExprBytesFunc func(fam byte, ad *netlink.AttributeDecoder, args ...string) ([]interface{}, error) + ParseExprMsgFunc func(fam byte, b []byte, args ...string) ([]interface{}, error) ) diff --git a/nftables_test.go b/nftables_test.go index be8b83b..5ac8ab1 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1783,7 +1783,7 @@ func TestListChainByName(t *testing.T) { } func TestListChainByNameUsingLasting(t *testing.T) { - conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + _, newNS := nftest.OpenSystemConn(t, *enableSysTests) conn, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting()) if err != nil { t.Fatalf("nftables.New() failed: %v", err) @@ -1882,8 +1882,7 @@ func TestListTableByName(t *testing.T) { } // not specifying correct family should return err since no table in ipv4 - tr, err = conn.ListTable(table2.Name) - if err == nil { + if _, err = conn.ListTable(table2.Name); err == nil { t.Fatalf("conn.ListTable() should have failed") } @@ -2114,9 +2113,9 @@ func TestGetObjReset(t *testing.T) { t.Fatal(err) } - co, ok := obj.(*nftables.CounterObj) + co, ok := obj.(*nftables.ObjAttr) if !ok { - t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj) + t.Fatalf("unexpected type: got %T, want *nftables.ObjAttr", obj) } if got, want := co.Table.Name, filter.Name; got != want { t.Errorf("unexpected table name: got %q, want %q", got, want) @@ -2124,10 +2123,14 @@ func TestGetObjReset(t *testing.T) { if got, want := co.Table.Family, filter.Family; got != want { t.Errorf("unexpected table family: got %d, want %d", got, want) } - if got, want := co.Packets, uint64(9); got != want { + o, ok := co.Obj.(*expr.Counter) + if !ok { + t.Fatalf("unexpected type: got %T, want *expr.Counter", o) + } + if got, want := o.Packets, uint64(9); got != want { t.Errorf("unexpected number of packets: got %d, want %d", got, want) } - if got, want := co.Bytes, uint64(1121); got != want { + if got, want := o.Bytes, uint64(1121); got != want { t.Errorf("unexpected number of bytes: got %d, want %d", got, want) } } @@ -2223,10 +2226,9 @@ func TestObjAPI(t *testing.T) { t.Errorf("c.GetObject(counter1) failed: %v failed", err) } - rcounter1, ok := obj1.(*nftables.CounterObj) - + rcounter1, ok := obj1.(*nftables.ObjAttr) if !ok { - t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", rcounter1) + t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj1) } if rcounter1.Name != "fwded1" { @@ -2238,10 +2240,9 @@ func TestObjAPI(t *testing.T) { t.Errorf("c.GetObject(counter2) failed: %v failed", err) } - rcounter2, ok := obj2.(*nftables.CounterObj) - + rcounter2, ok := obj2.(*nftables.ObjAttr) if !ok { - t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", rcounter2) + t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj2) } if rcounter2.Name != "fwded2" { @@ -2260,7 +2261,7 @@ func TestObjAPI(t *testing.T) { t.Errorf("c.GetObject(counter1) failed: %v failed", err) } - if counter1 := obj1.(*nftables.CounterObj); counter1.Packets > 0 { + if counter1 := obj1.(*nftables.ObjAttr).Obj.(*expr.Counter); counter1.Packets > 0 { t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0) } @@ -2270,7 +2271,7 @@ func TestObjAPI(t *testing.T) { t.Errorf("c.GetObject(counter2) failed: %v failed", err) } - if counter2 := obj2.(*nftables.CounterObj); counter2.Packets != 1 { + if counter2 := obj2.(*nftables.ObjAttr).Obj.(*expr.Counter); counter2.Packets != 1 { t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1) } @@ -2767,7 +2768,7 @@ func TestCreateUseAnonymousSet(t *testing.T) { } func TestCappedErrMsgOnSets(t *testing.T) { - c, newNS := nftest.OpenSystemConn(t, *enableSysTests) + _, newNS := nftest.OpenSystemConn(t, *enableSysTests) c, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting()) if err != nil { t.Fatalf("nftables.New() failed: %v", err) @@ -6285,6 +6286,84 @@ func TestGetRulesObjref(t *testing.T) { } } +func TestAddLimitObj(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := &nftables.Table{ + Name: "limit_demo", + Family: nftables.TableFamilyIPv4, + } + tr := conn.AddTable(table) + + c := &nftables.Chain{ + Name: "filter", + Table: table, + } + conn.AddChain(c) + + l := &expr.Limit{ + Type: expr.LimitTypePkts, + Rate: 400, + Unit: expr.LimitTimeMinute, + Burst: 5, + Over: false, + } + o := &nftables.ObjAttr{ + Table: tr, + Name: "limit_test", + Type: nftables.ObjTypeLimit, + Obj: l, + } + conn.AddObj(o) + + if err := conn.Flush(); err != nil { + t.Errorf("conn.Flush() failed: %v", err) + } + + obj, err := conn.GetObj(&nftables.ObjAttr{ + Table: table, + Name: "limit_test", + Type: nftables.ObjTypeLimit, + }) + if err != nil { + t.Fatalf("conn.GetObj() failed: %v", err) + } + + if got, want := len(obj), 1; got != want { + t.Fatalf("unexpected object list length: got %d, want %d", got, want) + } + + o1, ok := obj[0].(*nftables.ObjAttr) + if !ok { + t.Fatalf("unexpected type: got %T, want *ObjAttr", obj[0]) + } + if got, want := o1.Name, o.Name; got != want { + t.Fatalf("limit name mismatch: got %s, want %s", got, want) + } + q, ok := o1.Obj.(*expr.Limit) + if !ok { + t.Fatalf("unexpected type: got %T, want *expr.Quota", o1.Obj) + } + if got, want := q.Burst, l.Burst; got != want { + t.Fatalf("limit burst mismatch: got %d, want %d", got, want) + } + if got, want := q.Unit, l.Unit; got != want { + t.Fatalf("limit unit mismatch: got %d, want %d", got, want) + } + if got, want := q.Rate, l.Rate; got != want { + t.Fatalf("limit rate mismatch: got %v, want %v", got, want) + } + if got, want := q.Over, l.Over; got != want { + t.Fatalf("limit over mismatch: got %v, want %v", got, want) + } + if got, want := q.Type, l.Type; got != want { + t.Fatalf("limit type mismatch: got %v, want %v", got, want) + } +} + func TestAddQuotaObj(t *testing.T) { conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) defer nftest.CleanupSystemConn(t, newNS) @@ -6328,20 +6407,24 @@ func TestAddQuotaObj(t *testing.T) { t.Fatalf("unexpected object list length: got %d, want %d", got, want) } - o1, ok := obj[0].(*nftables.QuotaObj) + o1, ok := obj[0].(*nftables.ObjAttr) if !ok { - t.Fatalf("unexpected type: got %T, want *QuotaObj", obj[0]) + t.Fatalf("unexpected type: got %T, want *ObjAttr", obj[0]) } if got, want := o1.Name, o.Name; got != want { t.Fatalf("quota name mismatch: got %s, want %s", got, want) } - if got, want := o1.Bytes, o.Bytes; got != want { + q, ok := o1.Obj.(*expr.Quota) + if !ok { + t.Fatalf("unexpected type: got %T, want *expr.Quota", o1.Obj) + } + if got, want := q.Bytes, o.Bytes; got != want { t.Fatalf("quota bytes mismatch: got %d, want %d", got, want) } - if got, want := o1.Consumed, o.Consumed; got != want { + if got, want := q.Consumed, o.Consumed; got != want { t.Fatalf("quota consumed mismatch: got %d, want %d", got, want) } - if got, want := o1.Over, o.Over; got != want { + if got, want := q.Over, o.Over; got != want { t.Fatalf("quota over mismatch: got %v, want %v", got, want) } } @@ -6452,7 +6535,17 @@ func TestDeleteQuotaObj(t *testing.T) { t.Fatalf("unexpected number of objects: got %d, want %d", got, want) } - if got, want := obj[0], o; !reflect.DeepEqual(got, want) { + want := &nftables.ObjAttr{ + Table: tr, + Name: "q_test", + Type: nftables.ObjTypeQuota, + Obj: &expr.Quota{ + Bytes: o.Bytes, + Consumed: o.Consumed, + Over: o.Over, + }, + } + if got, want := obj[0], want; !reflect.DeepEqual(got, want) { t.Errorf("got = %+v, want = %+v", got, want) } diff --git a/obj.go b/obj.go index c468a63..1527a55 100644 --- a/obj.go +++ b/obj.go @@ -18,6 +18,9 @@ import ( "encoding/binary" "fmt" + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + "github.com/google/nftables/internal/parseexprfunc" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" ) @@ -27,13 +30,70 @@ var ( delObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ) ) +type ObjType uint32 + +// https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=be0bae0ad31b0adb506f96de083f52a2bd0d4fbf#n1612 +const ( + ObjTypeCounter ObjType = unix.NFT_OBJECT_COUNTER + ObjTypeQuota ObjType = unix.NFT_OBJECT_QUOTA + ObjTypeCtHelper ObjType = unix.NFT_OBJECT_CT_HELPER + ObjTypeLimit ObjType = unix.NFT_OBJECT_LIMIT + ObjTypeConnLimit ObjType = unix.NFT_OBJECT_CONNLIMIT + ObjTypeTunnel ObjType = unix.NFT_OBJECT_TUNNEL + ObjTypeCtTimeout ObjType = unix.NFT_OBJECT_CT_TIMEOUT + ObjTypeSecMark ObjType = unix.NFT_OBJECT_SECMARK + ObjTypeCtExpect ObjType = unix.NFT_OBJECT_CT_EXPECT + ObjTypeSynProxy ObjType = unix.NFT_OBJECT_SYNPROXY +) + +var objByObjTypeMagic = map[ObjType]string{ + ObjTypeCounter: "counter", + ObjTypeQuota: "quota", + ObjTypeLimit: "limit", + ObjTypeConnLimit: "connlimit", + ObjTypeCtHelper: "cthelper", // not implemented in expr + ObjTypeTunnel: "tunnel", // not implemented in expr + ObjTypeCtTimeout: "cttimeout", // not implemented in expr + ObjTypeSecMark: "secmark", // not implemented in expr + ObjTypeCtExpect: "ctexpect", // not implemented in expr + ObjTypeSynProxy: "synproxy", // not implemented in expr +} + // Obj represents a netfilter stateful object. See also // https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects type Obj interface { table() *Table family() TableFamily - unmarshal(*netlink.AttributeDecoder) error - marshal(data bool) ([]byte, error) + data() expr.Any + name() string + objType() ObjType +} + +type ObjAttr struct { + Table *Table + Name string + Type ObjType + Obj expr.Any +} + +func (o *ObjAttr) table() *Table { + return o.Table +} + +func (o *ObjAttr) family() TableFamily { + return o.Table.Family +} + +func (o *ObjAttr) data() expr.Any { + return o.Obj +} + +func (o *ObjAttr) name() string { + return o.Name +} + +func (o *ObjAttr) objType() ObjType { + return o.Type } // AddObject adds the specified Obj. Alias of AddObj. @@ -46,18 +106,27 @@ func (cc *Conn) AddObject(o Obj) Obj { func (cc *Conn) AddObj(o Obj) Obj { cc.mu.Lock() defer cc.mu.Unlock() - data, err := o.marshal(true) + data, err := expr.MarshalExprData(byte(o.family()), o.data()) if err != nil { cc.setErr(err) return nil } + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(o.table().Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(o.name() + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(o.objType()))}, + } + if len(data) > 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: data}) + } + cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, }, - Data: append(extraHeader(uint8(o.family()), 0), data...), + Data: append(extraHeader(uint8(o.family()), 0), cc.marshalAttr(attrs)...), }) return o } @@ -66,12 +135,12 @@ func (cc *Conn) AddObj(o Obj) Obj { func (cc *Conn) DeleteObject(o Obj) { cc.mu.Lock() defer cc.mu.Unlock() - data, err := o.marshal(false) - if err != nil { - cc.setErr(err) - return + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(o.table().Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(o.name() + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(o.objType()))}, } - + data := cc.marshalAttr(attrs) data = append(data, cc.marshalAttr([]netlink.Attribute{{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA}})...) cc.messages = append(cc.messages, netlink.Message{ @@ -150,38 +219,26 @@ func objFromMsg(msg netlink.Message) (Obj, error) { case unix.NFTA_OBJ_TYPE: objectType = ad.Uint32() case unix.NFTA_OBJ_DATA: - switch objectType { - case unix.NFT_OBJECT_COUNTER: - o := CounterObj{ - Table: table, - Name: name, - } - - ad.Do(func(b []byte) error { - ad, err := netlink.NewAttributeDecoder(b) - if err != nil { - return err - } - ad.ByteOrder = binary.BigEndian - return o.unmarshal(ad) - }) - return &o, ad.Err() - case NFT_OBJECT_QUOTA: - o := QuotaObj{ - Table: table, - Name: name, - } - - ad.Do(func(b []byte) error { - ad, err := netlink.NewAttributeDecoder(b) - if err != nil { - return err - } - ad.ByteOrder = binary.BigEndian - return o.unmarshal(ad) - }) - return &o, ad.Err() + o := ObjAttr{ + Table: table, + Name: name, + Type: ObjType(objectType), } + + objs, err := parseexprfunc.ParseExprBytesFunc(byte(o.family()), ad, objByObjTypeMagic[o.Type]) + if err != nil { + return nil, err + } + exprs := make([]expr.Any, len(objs)) + for i := range exprs { + exprs[i] = objs[i].(expr.Any) + } + if len(exprs) == 0 { + return nil, fmt.Errorf("objFromMsg: exprs is empty for obj %v", o) + } + + o.Obj = exprs[0] + return &o, ad.Err() } } if err := ad.Err(); err != nil { @@ -201,7 +258,12 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { var flags netlink.HeaderFlags if o != nil { - data, err = o.marshal(false) + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(o.table().Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(o.name() + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(o.objType()))}, + } + data = cc.marshalAttr(attrs) } else { flags = netlink.Dump data, err = netlink.MarshalAttributes([]netlink.Attribute{ diff --git a/quota.go b/quota.go index 71cb9bb..e3c71b1 100644 --- a/quota.go +++ b/quota.go @@ -16,15 +16,12 @@ package nftables import ( "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" ) -const ( - NFTA_OBJ_USERDATA = 8 - NFT_OBJECT_QUOTA = 2 -) - +// Deprecated: Use ObjAttr instead type QuotaObj struct { Table *Table Name string @@ -63,7 +60,7 @@ func (q *QuotaObj) marshal(data bool) ([]byte, error) { attrs := []netlink.Attribute{ {Type: unix.NFTA_OBJ_TABLE, Data: []byte(q.Table.Name + "\x00")}, {Type: unix.NFTA_OBJ_NAME, Data: []byte(q.Name + "\x00")}, - {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(NFT_OBJECT_QUOTA)}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(unix.NFT_OBJECT_QUOTA)}, } if data { attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: obj}) @@ -78,3 +75,19 @@ func (q *QuotaObj) table() *Table { func (q *QuotaObj) family() TableFamily { return q.Table.Family } + +func (q *QuotaObj) data() expr.Any { + return &expr.Quota{ + Bytes: q.Bytes, + Consumed: q.Consumed, + Over: q.Over, + } +} + +func (q *QuotaObj) name() string { + return q.Name +} + +func (q *QuotaObj) objType() ObjType { + return ObjTypeQuota +} diff --git a/set.go b/set.go index 192c619..36163a9 100644 --- a/set.go +++ b/set.go @@ -321,7 +321,7 @@ func (s *SetElement) decode(fam byte) func(b []byte) error { case unix.NFTA_SET_ELEM_EXPIRATION: s.Expires = time.Millisecond * time.Duration(ad.Uint64()) case unix.NFTA_SET_ELEM_EXPR: - elems, err := parseexprfunc.ParseExprBytesFunc(fam, ad, ad.Bytes()) + elems, err := parseexprfunc.ParseExprBytesFunc(fam, ad) if err != nil { return err }