switch to new netlink.AttributeDecoder

fixes #2
This commit is contained in:
Michael Stapelberg 2018-08-10 18:59:05 +02:00
parent 121db0bb23
commit 409eade12e
5 changed files with 140 additions and 110 deletions

View File

@ -15,6 +15,8 @@
package expr package expr
import ( import (
"encoding/binary"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink" "github.com/mdlayher/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@ -41,18 +43,18 @@ func (e *Counter) marshal() ([]byte, error) {
} }
func (e *Counter) unmarshal(data []byte) error { func (e *Counter) unmarshal(data []byte) error {
attrs, err := netlink.UnmarshalAttributes(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err
} }
for _, attr := range attrs { ad.ByteOrder = binary.BigEndian
switch attr.Type { for ad.Next() {
switch ad.Type() {
case unix.NFTA_COUNTER_BYTES: case unix.NFTA_COUNTER_BYTES:
e.Bytes = binaryutil.BigEndian.Uint64(attr.Data) e.Bytes = ad.Uint64()
case unix.NFTA_COUNTER_PACKETS: case unix.NFTA_COUNTER_PACKETS:
e.Packets = binaryutil.BigEndian.Uint64(attr.Data) e.Packets = ad.Uint64()
} }
} }
return ad.Err()
return nil
} }

View File

@ -16,6 +16,7 @@
package expr package expr
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
@ -96,20 +97,20 @@ func (e *Meta) marshal() ([]byte, error) {
} }
func (e *Meta) unmarshal(data []byte) error { func (e *Meta) unmarshal(data []byte) error {
attrs, err := netlink.UnmarshalAttributes(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err
} }
for _, attr := range attrs { ad.ByteOrder = binary.BigEndian
switch attr.Type { for ad.Next() {
switch ad.Type() {
case unix.NFTA_META_DREG: case unix.NFTA_META_DREG:
e.Register = binaryutil.BigEndian.Uint32(attr.Data) e.Register = ad.Uint32()
case unix.NFTA_META_KEY: case unix.NFTA_META_KEY:
e.Key = MetaKey(binaryutil.BigEndian.Uint32(attr.Data)) e.Key = MetaKey(ad.Uint32())
} }
} }
return ad.Err()
return nil
} }
// Masq (Masquerade) is a special case of SNAT, where the source address is // Masq (Masquerade) is a special case of SNAT, where the source address is
@ -170,26 +171,33 @@ func (e *Cmp) marshal() ([]byte, error) {
} }
func (e *Cmp) unmarshal(data []byte) error { func (e *Cmp) unmarshal(data []byte) error {
attrs, err := netlink.UnmarshalAttributes(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err
} }
for _, attr := range attrs { ad.ByteOrder = binary.BigEndian
switch attr.Type { for ad.Next() {
switch ad.Type() {
case unix.NFTA_CMP_SREG: case unix.NFTA_CMP_SREG:
e.Register = binaryutil.BigEndian.Uint32(attr.Data) e.Register = ad.Uint32()
case unix.NFTA_CMP_OP: case unix.NFTA_CMP_OP:
e.Op = CmpOp(binaryutil.BigEndian.Uint32(attr.Data)) e.Op = CmpOp(ad.Uint32())
case unix.NFTA_CMP_DATA: case unix.NFTA_CMP_DATA:
attrs, err := netlink.UnmarshalAttributes(attr.Data) ad.Do(func(b []byte) error {
if err != nil { ad, err := netlink.NewAttributeDecoder(data)
return err if err != nil {
} return err
if len(attrs) == 1 && attrs[0].Type == unix.NFTA_DATA_VALUE { }
e.Data = attrs[0].Data ad.ByteOrder = binary.BigEndian
} if ad.Next() && ad.Type() == unix.NFTA_DATA_VALUE {
ad.Do(func(b []byte) error {
e.Data = b
return nil
})
}
return ad.Err()
})
} }
} }
return ad.Err()
return nil
} }

View File

@ -15,6 +15,8 @@
package expr package expr
import ( import (
"encoding/binary"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink" "github.com/mdlayher/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@ -41,18 +43,18 @@ func (e *Objref) marshal() ([]byte, error) {
} }
func (e *Objref) unmarshal(data []byte) error { func (e *Objref) unmarshal(data []byte) error {
attrs, err := netlink.UnmarshalAttributes(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err
} }
for _, attr := range attrs { ad.ByteOrder = binary.BigEndian
switch attr.Type { for ad.Next() {
switch ad.Type() {
case unix.NFTA_OBJREF_IMM_TYPE: case unix.NFTA_OBJREF_IMM_TYPE:
e.Type = int(binaryutil.BigEndian.Uint32(attr.Data)) e.Type = int(ad.Uint32())
case unix.NFTA_OBJREF_IMM_NAME: case unix.NFTA_OBJREF_IMM_NAME:
e.Name = string(attr.Data) e.Name = ad.String()
} }
} }
return ad.Err()
return nil
} }

View File

@ -15,6 +15,8 @@
package expr package expr
import ( import (
"encoding/binary"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink" "github.com/mdlayher/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@ -53,22 +55,22 @@ func (e *Payload) marshal() ([]byte, error) {
} }
func (e *Payload) unmarshal(data []byte) error { func (e *Payload) unmarshal(data []byte) error {
attrs, err := netlink.UnmarshalAttributes(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err
} }
for _, attr := range attrs { ad.ByteOrder = binary.BigEndian
switch attr.Type { for ad.Next() {
switch ad.Type() {
case unix.NFTA_PAYLOAD_DREG: case unix.NFTA_PAYLOAD_DREG:
e.DestRegister = binaryutil.BigEndian.Uint32(attr.Data) e.DestRegister = ad.Uint32()
case unix.NFTA_PAYLOAD_BASE: case unix.NFTA_PAYLOAD_BASE:
e.Base = PayloadBase(binaryutil.BigEndian.Uint32(attr.Data)) e.Base = PayloadBase(ad.Uint32())
case unix.NFTA_PAYLOAD_OFFSET: case unix.NFTA_PAYLOAD_OFFSET:
e.Offset = binaryutil.BigEndian.Uint32(attr.Data) e.Offset = ad.Uint32()
case unix.NFTA_PAYLOAD_LEN: case unix.NFTA_PAYLOAD_LEN:
e.Len = binaryutil.BigEndian.Uint32(attr.Data) e.Len = ad.Uint32()
} }
} }
return ad.Err()
return nil
} }

View File

@ -16,6 +16,7 @@
package nftables package nftables
import ( import (
"encoding/binary"
"fmt" "fmt"
"math" "math"
"strings" "strings"
@ -183,75 +184,81 @@ func stringFrom0(b []byte) string {
} }
func exprsFromMsg(b []byte) ([]expr.Any, error) { func exprsFromMsg(b []byte) ([]expr.Any, error) {
elems, err := netlink.UnmarshalAttributes(b) ad, err := netlink.NewAttributeDecoder(b)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ad.ByteOrder = binary.BigEndian
var exprs []expr.Any var exprs []expr.Any
for _, elem := range elems { for ad.Next() {
attrs, err := netlink.UnmarshalAttributes(elem.Data) ad.Do(func(b []byte) error {
if err != nil { ad, err := netlink.NewAttributeDecoder(b)
return nil, err if err != nil {
} return err
var (
name string
data []byte
)
for _, attr := range attrs {
switch attr.Type {
case unix.NFTA_EXPR_NAME:
name = stringFrom0(attr.Data)
case unix.NFTA_EXPR_DATA:
data = attr.Data
} }
} ad.ByteOrder = binary.BigEndian
var e expr.Any var name string
switch name { for ad.Next() {
case "meta": switch ad.Type() {
e = &expr.Meta{} case unix.NFTA_EXPR_NAME:
case "cmp": name = ad.String()
e = &expr.Cmp{} case unix.NFTA_EXPR_DATA:
case "counter": var e expr.Any
e = &expr.Counter{} switch name {
case "payload": case "meta":
e = &expr.Payload{} e = &expr.Meta{}
} case "cmp":
if e == nil { e = &expr.Cmp{}
// TODO: introduce an opaque expression type so that users know case "counter":
// something is here. e = &expr.Counter{}
continue // unsupported expression type case "payload":
} e = &expr.Payload{}
if err := expr.Unmarshal(data, e); err != nil { }
return nil, err if e == nil {
} // TODO: introduce an opaque expression type so that users know
exprs = append(exprs, e) // something is here.
continue // unsupported expression type
}
ad.Do(func(b []byte) error {
if err := expr.Unmarshal(b, e); err != nil {
return err
}
exprs = append(exprs, e)
return nil
})
}
}
return ad.Err()
})
} }
return exprs, nil return exprs, ad.Err()
} }
func ruleFromMsg(msg netlink.Message) (*Rule, error) { func ruleFromMsg(msg netlink.Message) (*Rule, error) {
if got, want := msg.Header.Type, ruleHeaderType; got != want { if got, want := msg.Header.Type, ruleHeaderType; got != want {
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
} }
attrs, err := netlink.UnmarshalAttributes(msg.Data[4:]) ad, err := netlink.NewAttributeDecoder(msg.Data[4:])
if err != nil { if err != nil {
return nil, err return nil, err
} }
ad.ByteOrder = binary.BigEndian
var r Rule var r Rule
for _, attr := range attrs { for ad.Next() {
switch attr.Type { switch ad.Type() {
case unix.NFTA_RULE_TABLE: case unix.NFTA_RULE_TABLE:
r.Table = &Table{Name: stringFrom0(attr.Data)} r.Table = &Table{Name: ad.String()}
case unix.NFTA_RULE_CHAIN: case unix.NFTA_RULE_CHAIN:
r.Chain = &Chain{Name: stringFrom0(attr.Data)} r.Chain = &Chain{Name: ad.String()}
case unix.NFTA_RULE_EXPRESSIONS: case unix.NFTA_RULE_EXPRESSIONS:
r.Exprs, err = exprsFromMsg(attr.Data) ad.Do(func(b []byte) error {
if err != nil { r.Exprs, err = exprsFromMsg(b)
return nil, err return err
} })
} }
} }
return &r, nil return &r, ad.Err()
} }
// AddRule adds the specified Rule. See also // AddRule adds the specified Rule. See also
@ -430,16 +437,16 @@ type CounterObj struct {
Packets uint64 Packets uint64
} }
func (c *CounterObj) unmarshal(attrs []netlink.Attribute) error { func (c *CounterObj) unmarshal(ad *netlink.AttributeDecoder) error {
for _, attr := range attrs { for ad.Next() {
switch attr.Type { switch ad.Type() {
case unix.NFTA_COUNTER_BYTES: case unix.NFTA_COUNTER_BYTES:
c.Bytes = binaryutil.BigEndian.Uint64(attr.Data) c.Bytes = ad.Uint64()
case unix.NFTA_COUNTER_PACKETS: case unix.NFTA_COUNTER_PACKETS:
c.Packets = binaryutil.BigEndian.Uint64(attr.Data) c.Packets = ad.Uint64()
} }
} }
return nil return ad.Err()
} }
func (c *CounterObj) family() TableFamily { func (c *CounterObj) family() TableFamily {
@ -470,7 +477,7 @@ func (c *CounterObj) marshal(data bool) ([]byte, error) {
// https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects // https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects
type Obj interface { type Obj interface {
family() TableFamily family() TableFamily
unmarshal([]netlink.Attribute) error unmarshal(*netlink.AttributeDecoder) error
marshal(data bool) ([]byte, error) marshal(data bool) ([]byte, error)
} }
@ -499,39 +506,48 @@ func objFromMsg(msg netlink.Message) (Obj, error) {
if got, want := msg.Header.Type, objHeaderType; got != want { if got, want := msg.Header.Type, objHeaderType; got != want {
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
} }
attrs, err := netlink.UnmarshalAttributes(msg.Data[4:]) ad, err := netlink.NewAttributeDecoder(msg.Data[4:])
if err != nil { if err != nil {
return nil, err return nil, err
} }
ad.ByteOrder = binary.BigEndian
var ( var (
table *Table table *Table
name string name string
objectType uint32 objectType uint32
) )
const NFT_OBJECT_COUNTER = 1 // TODO: get into x/sys/unix const NFT_OBJECT_COUNTER = 1 // TODO: get into x/sys/unix
for _, attr := range attrs { for ad.Next() {
switch attr.Type { switch ad.Type() {
case unix.NFTA_OBJ_TABLE: case unix.NFTA_OBJ_TABLE:
table = &Table{Name: stringFrom0(attr.Data)} table = &Table{Name: ad.String()}
case unix.NFTA_OBJ_NAME: case unix.NFTA_OBJ_NAME:
name = stringFrom0(attr.Data) name = ad.String()
case unix.NFTA_OBJ_TYPE: case unix.NFTA_OBJ_TYPE:
objectType = binaryutil.BigEndian.Uint32(attr.Data) objectType = ad.Uint32()
case unix.NFTA_OBJ_DATA: case unix.NFTA_OBJ_DATA:
switch objectType { switch objectType {
case NFT_OBJECT_COUNTER: case NFT_OBJECT_COUNTER:
attrs, err := netlink.UnmarshalAttributes(attr.Data)
if err != nil {
return nil, err
}
o := CounterObj{ o := CounterObj{
Table: table, Table: table,
Name: name, Name: name,
} }
return &o, o.unmarshal(attrs)
ad.Do(func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return err
}
ad.ByteOrder = binary.BigEndian
return o.unmarshal(ad)
})
return &o, ad.Err()
} }
} }
} }
if err := ad.Err(); err != nil {
return nil, err
}
return nil, fmt.Errorf("malformed stateful object") return nil, fmt.Errorf("malformed stateful object")
} }