From 409eade12e2cb8abb7c89d7e61ebbdf0b1d57353 Mon Sep 17 00:00:00 2001 From: Michael Stapelberg Date: Fri, 10 Aug 2018 18:59:05 +0200 Subject: [PATCH] switch to new netlink.AttributeDecoder fixes #2 --- expr/counter.go | 16 +++--- expr/expr.go | 50 +++++++++------- expr/objref.go | 16 +++--- expr/payload.go | 20 ++++--- nftables.go | 148 +++++++++++++++++++++++++++--------------------- 5 files changed, 140 insertions(+), 110 deletions(-) diff --git a/expr/counter.go b/expr/counter.go index 272c59a..d441cd8 100644 --- a/expr/counter.go +++ b/expr/counter.go @@ -15,6 +15,8 @@ package expr import ( + "encoding/binary" + "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" @@ -41,18 +43,18 @@ func (e *Counter) marshal() ([]byte, error) { } func (e *Counter) unmarshal(data []byte) error { - attrs, err := netlink.UnmarshalAttributes(data) + ad, err := netlink.NewAttributeDecoder(data) if err != nil { return err } - for _, attr := range attrs { - switch attr.Type { + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { case unix.NFTA_COUNTER_BYTES: - e.Bytes = binaryutil.BigEndian.Uint64(attr.Data) + e.Bytes = ad.Uint64() case unix.NFTA_COUNTER_PACKETS: - e.Packets = binaryutil.BigEndian.Uint64(attr.Data) + e.Packets = ad.Uint64() } } - - return nil + return ad.Err() } diff --git a/expr/expr.go b/expr/expr.go index 5ee854e..da50d3b 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -16,6 +16,7 @@ package expr import ( + "encoding/binary" "fmt" "github.com/google/nftables/binaryutil" @@ -96,20 +97,20 @@ func (e *Meta) marshal() ([]byte, error) { } func (e *Meta) unmarshal(data []byte) error { - attrs, err := netlink.UnmarshalAttributes(data) + ad, err := netlink.NewAttributeDecoder(data) if err != nil { return err } - for _, attr := range attrs { - switch attr.Type { + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { case unix.NFTA_META_DREG: - e.Register = binaryutil.BigEndian.Uint32(attr.Data) + e.Register = ad.Uint32() case unix.NFTA_META_KEY: - e.Key = MetaKey(binaryutil.BigEndian.Uint32(attr.Data)) + e.Key = MetaKey(ad.Uint32()) } } - - return nil + return ad.Err() } // Masq (Masquerade) is a special case of SNAT, where the source address is @@ -170,26 +171,33 @@ func (e *Cmp) marshal() ([]byte, error) { } func (e *Cmp) unmarshal(data []byte) error { - attrs, err := netlink.UnmarshalAttributes(data) + ad, err := netlink.NewAttributeDecoder(data) if err != nil { return err } - for _, attr := range attrs { - switch attr.Type { + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { case unix.NFTA_CMP_SREG: - e.Register = binaryutil.BigEndian.Uint32(attr.Data) + e.Register = ad.Uint32() case unix.NFTA_CMP_OP: - e.Op = CmpOp(binaryutil.BigEndian.Uint32(attr.Data)) + e.Op = CmpOp(ad.Uint32()) case unix.NFTA_CMP_DATA: - attrs, err := netlink.UnmarshalAttributes(attr.Data) - if err != nil { - return err - } - if len(attrs) == 1 && attrs[0].Type == unix.NFTA_DATA_VALUE { - e.Data = attrs[0].Data - } + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + if ad.Next() && ad.Type() == unix.NFTA_DATA_VALUE { + ad.Do(func(b []byte) error { + e.Data = b + return nil + }) + } + return ad.Err() + }) } } - - return nil + return ad.Err() } diff --git a/expr/objref.go b/expr/objref.go index 966d826..39de39e 100644 --- a/expr/objref.go +++ b/expr/objref.go @@ -15,6 +15,8 @@ package expr import ( + "encoding/binary" + "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" @@ -41,18 +43,18 @@ func (e *Objref) marshal() ([]byte, error) { } func (e *Objref) unmarshal(data []byte) error { - attrs, err := netlink.UnmarshalAttributes(data) + ad, err := netlink.NewAttributeDecoder(data) if err != nil { return err } - for _, attr := range attrs { - switch attr.Type { + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { case unix.NFTA_OBJREF_IMM_TYPE: - e.Type = int(binaryutil.BigEndian.Uint32(attr.Data)) + e.Type = int(ad.Uint32()) case unix.NFTA_OBJREF_IMM_NAME: - e.Name = string(attr.Data) + e.Name = ad.String() } } - - return nil + return ad.Err() } diff --git a/expr/payload.go b/expr/payload.go index 11727c3..d6a53d4 100644 --- a/expr/payload.go +++ b/expr/payload.go @@ -15,6 +15,8 @@ package expr import ( + "encoding/binary" + "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" @@ -53,22 +55,22 @@ func (e *Payload) marshal() ([]byte, error) { } func (e *Payload) unmarshal(data []byte) error { - attrs, err := netlink.UnmarshalAttributes(data) + ad, err := netlink.NewAttributeDecoder(data) if err != nil { return err } - for _, attr := range attrs { - switch attr.Type { + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { case unix.NFTA_PAYLOAD_DREG: - e.DestRegister = binaryutil.BigEndian.Uint32(attr.Data) + e.DestRegister = ad.Uint32() case unix.NFTA_PAYLOAD_BASE: - e.Base = PayloadBase(binaryutil.BigEndian.Uint32(attr.Data)) + e.Base = PayloadBase(ad.Uint32()) case unix.NFTA_PAYLOAD_OFFSET: - e.Offset = binaryutil.BigEndian.Uint32(attr.Data) + e.Offset = ad.Uint32() case unix.NFTA_PAYLOAD_LEN: - e.Len = binaryutil.BigEndian.Uint32(attr.Data) + e.Len = ad.Uint32() } } - - return nil + return ad.Err() } diff --git a/nftables.go b/nftables.go index 4f86dfc..06965f7 100644 --- a/nftables.go +++ b/nftables.go @@ -16,6 +16,7 @@ package nftables import ( + "encoding/binary" "fmt" "math" "strings" @@ -183,75 +184,81 @@ func stringFrom0(b []byte) string { } func exprsFromMsg(b []byte) ([]expr.Any, error) { - elems, err := netlink.UnmarshalAttributes(b) + ad, err := netlink.NewAttributeDecoder(b) if err != nil { return nil, err } + ad.ByteOrder = binary.BigEndian var exprs []expr.Any - for _, elem := range elems { - attrs, err := netlink.UnmarshalAttributes(elem.Data) - if err != nil { - return nil, err - } - var ( - name string - data []byte - ) - for _, attr := range attrs { - switch attr.Type { - case unix.NFTA_EXPR_NAME: - name = stringFrom0(attr.Data) - case unix.NFTA_EXPR_DATA: - data = attr.Data + for ad.Next() { + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err } - } - var e expr.Any - switch name { - case "meta": - e = &expr.Meta{} - case "cmp": - e = &expr.Cmp{} - case "counter": - e = &expr.Counter{} - case "payload": - e = &expr.Payload{} - } - if e == nil { - // TODO: introduce an opaque expression type so that users know - // something is here. - continue // unsupported expression type - } - if err := expr.Unmarshal(data, e); err != nil { - return nil, err - } - exprs = append(exprs, e) + ad.ByteOrder = binary.BigEndian + var name string + for ad.Next() { + switch ad.Type() { + case unix.NFTA_EXPR_NAME: + name = ad.String() + case unix.NFTA_EXPR_DATA: + var e expr.Any + switch name { + case "meta": + e = &expr.Meta{} + case "cmp": + e = &expr.Cmp{} + case "counter": + e = &expr.Counter{} + case "payload": + e = &expr.Payload{} + } + if e == nil { + // TODO: introduce an opaque expression type so that users know + // something is here. + continue // unsupported expression type + } + + ad.Do(func(b []byte) error { + if err := expr.Unmarshal(b, e); err != nil { + return err + } + exprs = append(exprs, e) + return nil + }) + } + } + return ad.Err() + }) } - return exprs, nil + return exprs, ad.Err() } func ruleFromMsg(msg netlink.Message) (*Rule, error) { if got, want := msg.Header.Type, ruleHeaderType; got != want { return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) } - attrs, err := netlink.UnmarshalAttributes(msg.Data[4:]) + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { return nil, err } + ad.ByteOrder = binary.BigEndian var r Rule - for _, attr := range attrs { - switch attr.Type { + for ad.Next() { + switch ad.Type() { case unix.NFTA_RULE_TABLE: - r.Table = &Table{Name: stringFrom0(attr.Data)} + r.Table = &Table{Name: ad.String()} case unix.NFTA_RULE_CHAIN: - r.Chain = &Chain{Name: stringFrom0(attr.Data)} + r.Chain = &Chain{Name: ad.String()} case unix.NFTA_RULE_EXPRESSIONS: - r.Exprs, err = exprsFromMsg(attr.Data) - if err != nil { - return nil, err - } + ad.Do(func(b []byte) error { + r.Exprs, err = exprsFromMsg(b) + return err + }) } } - return &r, nil + return &r, ad.Err() } // AddRule adds the specified Rule. See also @@ -430,16 +437,16 @@ type CounterObj struct { Packets uint64 } -func (c *CounterObj) unmarshal(attrs []netlink.Attribute) error { - for _, attr := range attrs { - switch attr.Type { +func (c *CounterObj) unmarshal(ad *netlink.AttributeDecoder) error { + for ad.Next() { + switch ad.Type() { case unix.NFTA_COUNTER_BYTES: - c.Bytes = binaryutil.BigEndian.Uint64(attr.Data) + c.Bytes = ad.Uint64() case unix.NFTA_COUNTER_PACKETS: - c.Packets = binaryutil.BigEndian.Uint64(attr.Data) + c.Packets = ad.Uint64() } } - return nil + return ad.Err() } func (c *CounterObj) family() TableFamily { @@ -470,7 +477,7 @@ func (c *CounterObj) marshal(data bool) ([]byte, error) { // https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects type Obj interface { family() TableFamily - unmarshal([]netlink.Attribute) error + unmarshal(*netlink.AttributeDecoder) error marshal(data bool) ([]byte, error) } @@ -499,39 +506,48 @@ func objFromMsg(msg netlink.Message) (Obj, error) { if got, want := msg.Header.Type, objHeaderType; got != want { return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) } - attrs, err := netlink.UnmarshalAttributes(msg.Data[4:]) + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) if err != nil { return nil, err } + ad.ByteOrder = binary.BigEndian var ( table *Table name string objectType uint32 ) const NFT_OBJECT_COUNTER = 1 // TODO: get into x/sys/unix - for _, attr := range attrs { - switch attr.Type { + for ad.Next() { + switch ad.Type() { case unix.NFTA_OBJ_TABLE: - table = &Table{Name: stringFrom0(attr.Data)} + table = &Table{Name: ad.String()} case unix.NFTA_OBJ_NAME: - name = stringFrom0(attr.Data) + name = ad.String() case unix.NFTA_OBJ_TYPE: - objectType = binaryutil.BigEndian.Uint32(attr.Data) + objectType = ad.Uint32() case unix.NFTA_OBJ_DATA: switch objectType { case NFT_OBJECT_COUNTER: - attrs, err := netlink.UnmarshalAttributes(attr.Data) - if err != nil { - return nil, err - } o := CounterObj{ Table: table, Name: name, } - return &o, o.unmarshal(attrs) + + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + return o.unmarshal(ad) + }) + return &o, ad.Err() } } } + if err := ad.Err(); err != nil { + return nil, err + } return nil, fmt.Errorf("malformed stateful object") }