switch to new netlink.AttributeDecoder

fixes #2
This commit is contained in:
Michael Stapelberg 2018-08-10 18:59:05 +02:00
parent 121db0bb23
commit 409eade12e
5 changed files with 140 additions and 110 deletions

View File

@ -15,6 +15,8 @@
package expr
import (
"encoding/binary"
"github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
@ -41,18 +43,18 @@ func (e *Counter) marshal() ([]byte, error) {
}
func (e *Counter) unmarshal(data []byte) error {
attrs, err := netlink.UnmarshalAttributes(data)
ad, err := netlink.NewAttributeDecoder(data)
if err != nil {
return err
}
for _, attr := range attrs {
switch attr.Type {
ad.ByteOrder = binary.BigEndian
for ad.Next() {
switch ad.Type() {
case unix.NFTA_COUNTER_BYTES:
e.Bytes = binaryutil.BigEndian.Uint64(attr.Data)
e.Bytes = ad.Uint64()
case unix.NFTA_COUNTER_PACKETS:
e.Packets = binaryutil.BigEndian.Uint64(attr.Data)
e.Packets = ad.Uint64()
}
}
return nil
return ad.Err()
}

View File

@ -16,6 +16,7 @@
package expr
import (
"encoding/binary"
"fmt"
"github.com/google/nftables/binaryutil"
@ -96,20 +97,20 @@ func (e *Meta) marshal() ([]byte, error) {
}
func (e *Meta) unmarshal(data []byte) error {
attrs, err := netlink.UnmarshalAttributes(data)
ad, err := netlink.NewAttributeDecoder(data)
if err != nil {
return err
}
for _, attr := range attrs {
switch attr.Type {
ad.ByteOrder = binary.BigEndian
for ad.Next() {
switch ad.Type() {
case unix.NFTA_META_DREG:
e.Register = binaryutil.BigEndian.Uint32(attr.Data)
e.Register = ad.Uint32()
case unix.NFTA_META_KEY:
e.Key = MetaKey(binaryutil.BigEndian.Uint32(attr.Data))
e.Key = MetaKey(ad.Uint32())
}
}
return nil
return ad.Err()
}
// Masq (Masquerade) is a special case of SNAT, where the source address is
@ -170,26 +171,33 @@ func (e *Cmp) marshal() ([]byte, error) {
}
func (e *Cmp) unmarshal(data []byte) error {
attrs, err := netlink.UnmarshalAttributes(data)
ad, err := netlink.NewAttributeDecoder(data)
if err != nil {
return err
}
for _, attr := range attrs {
switch attr.Type {
ad.ByteOrder = binary.BigEndian
for ad.Next() {
switch ad.Type() {
case unix.NFTA_CMP_SREG:
e.Register = binaryutil.BigEndian.Uint32(attr.Data)
e.Register = ad.Uint32()
case unix.NFTA_CMP_OP:
e.Op = CmpOp(binaryutil.BigEndian.Uint32(attr.Data))
e.Op = CmpOp(ad.Uint32())
case unix.NFTA_CMP_DATA:
attrs, err := netlink.UnmarshalAttributes(attr.Data)
if err != nil {
return err
}
if len(attrs) == 1 && attrs[0].Type == unix.NFTA_DATA_VALUE {
e.Data = attrs[0].Data
}
ad.Do(func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(data)
if err != nil {
return err
}
ad.ByteOrder = binary.BigEndian
if ad.Next() && ad.Type() == unix.NFTA_DATA_VALUE {
ad.Do(func(b []byte) error {
e.Data = b
return nil
})
}
return ad.Err()
})
}
}
return nil
return ad.Err()
}

View File

@ -15,6 +15,8 @@
package expr
import (
"encoding/binary"
"github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
@ -41,18 +43,18 @@ func (e *Objref) marshal() ([]byte, error) {
}
func (e *Objref) unmarshal(data []byte) error {
attrs, err := netlink.UnmarshalAttributes(data)
ad, err := netlink.NewAttributeDecoder(data)
if err != nil {
return err
}
for _, attr := range attrs {
switch attr.Type {
ad.ByteOrder = binary.BigEndian
for ad.Next() {
switch ad.Type() {
case unix.NFTA_OBJREF_IMM_TYPE:
e.Type = int(binaryutil.BigEndian.Uint32(attr.Data))
e.Type = int(ad.Uint32())
case unix.NFTA_OBJREF_IMM_NAME:
e.Name = string(attr.Data)
e.Name = ad.String()
}
}
return nil
return ad.Err()
}

View File

@ -15,6 +15,8 @@
package expr
import (
"encoding/binary"
"github.com/google/nftables/binaryutil"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
@ -53,22 +55,22 @@ func (e *Payload) marshal() ([]byte, error) {
}
func (e *Payload) unmarshal(data []byte) error {
attrs, err := netlink.UnmarshalAttributes(data)
ad, err := netlink.NewAttributeDecoder(data)
if err != nil {
return err
}
for _, attr := range attrs {
switch attr.Type {
ad.ByteOrder = binary.BigEndian
for ad.Next() {
switch ad.Type() {
case unix.NFTA_PAYLOAD_DREG:
e.DestRegister = binaryutil.BigEndian.Uint32(attr.Data)
e.DestRegister = ad.Uint32()
case unix.NFTA_PAYLOAD_BASE:
e.Base = PayloadBase(binaryutil.BigEndian.Uint32(attr.Data))
e.Base = PayloadBase(ad.Uint32())
case unix.NFTA_PAYLOAD_OFFSET:
e.Offset = binaryutil.BigEndian.Uint32(attr.Data)
e.Offset = ad.Uint32()
case unix.NFTA_PAYLOAD_LEN:
e.Len = binaryutil.BigEndian.Uint32(attr.Data)
e.Len = ad.Uint32()
}
}
return nil
return ad.Err()
}

View File

@ -16,6 +16,7 @@
package nftables
import (
"encoding/binary"
"fmt"
"math"
"strings"
@ -183,75 +184,81 @@ func stringFrom0(b []byte) string {
}
func exprsFromMsg(b []byte) ([]expr.Any, error) {
elems, err := netlink.UnmarshalAttributes(b)
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return nil, err
}
ad.ByteOrder = binary.BigEndian
var exprs []expr.Any
for _, elem := range elems {
attrs, err := netlink.UnmarshalAttributes(elem.Data)
if err != nil {
return nil, err
}
var (
name string
data []byte
)
for _, attr := range attrs {
switch attr.Type {
case unix.NFTA_EXPR_NAME:
name = stringFrom0(attr.Data)
case unix.NFTA_EXPR_DATA:
data = attr.Data
for ad.Next() {
ad.Do(func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return err
}
}
var e expr.Any
switch name {
case "meta":
e = &expr.Meta{}
case "cmp":
e = &expr.Cmp{}
case "counter":
e = &expr.Counter{}
case "payload":
e = &expr.Payload{}
}
if e == nil {
// TODO: introduce an opaque expression type so that users know
// something is here.
continue // unsupported expression type
}
if err := expr.Unmarshal(data, e); err != nil {
return nil, err
}
exprs = append(exprs, e)
ad.ByteOrder = binary.BigEndian
var name string
for ad.Next() {
switch ad.Type() {
case unix.NFTA_EXPR_NAME:
name = ad.String()
case unix.NFTA_EXPR_DATA:
var e expr.Any
switch name {
case "meta":
e = &expr.Meta{}
case "cmp":
e = &expr.Cmp{}
case "counter":
e = &expr.Counter{}
case "payload":
e = &expr.Payload{}
}
if e == nil {
// TODO: introduce an opaque expression type so that users know
// something is here.
continue // unsupported expression type
}
ad.Do(func(b []byte) error {
if err := expr.Unmarshal(b, e); err != nil {
return err
}
exprs = append(exprs, e)
return nil
})
}
}
return ad.Err()
})
}
return exprs, nil
return exprs, ad.Err()
}
func ruleFromMsg(msg netlink.Message) (*Rule, error) {
if got, want := msg.Header.Type, ruleHeaderType; got != want {
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
}
attrs, err := netlink.UnmarshalAttributes(msg.Data[4:])
ad, err := netlink.NewAttributeDecoder(msg.Data[4:])
if err != nil {
return nil, err
}
ad.ByteOrder = binary.BigEndian
var r Rule
for _, attr := range attrs {
switch attr.Type {
for ad.Next() {
switch ad.Type() {
case unix.NFTA_RULE_TABLE:
r.Table = &Table{Name: stringFrom0(attr.Data)}
r.Table = &Table{Name: ad.String()}
case unix.NFTA_RULE_CHAIN:
r.Chain = &Chain{Name: stringFrom0(attr.Data)}
r.Chain = &Chain{Name: ad.String()}
case unix.NFTA_RULE_EXPRESSIONS:
r.Exprs, err = exprsFromMsg(attr.Data)
if err != nil {
return nil, err
}
ad.Do(func(b []byte) error {
r.Exprs, err = exprsFromMsg(b)
return err
})
}
}
return &r, nil
return &r, ad.Err()
}
// AddRule adds the specified Rule. See also
@ -430,16 +437,16 @@ type CounterObj struct {
Packets uint64
}
func (c *CounterObj) unmarshal(attrs []netlink.Attribute) error {
for _, attr := range attrs {
switch attr.Type {
func (c *CounterObj) unmarshal(ad *netlink.AttributeDecoder) error {
for ad.Next() {
switch ad.Type() {
case unix.NFTA_COUNTER_BYTES:
c.Bytes = binaryutil.BigEndian.Uint64(attr.Data)
c.Bytes = ad.Uint64()
case unix.NFTA_COUNTER_PACKETS:
c.Packets = binaryutil.BigEndian.Uint64(attr.Data)
c.Packets = ad.Uint64()
}
}
return nil
return ad.Err()
}
func (c *CounterObj) family() TableFamily {
@ -470,7 +477,7 @@ func (c *CounterObj) marshal(data bool) ([]byte, error) {
// https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects
type Obj interface {
family() TableFamily
unmarshal([]netlink.Attribute) error
unmarshal(*netlink.AttributeDecoder) error
marshal(data bool) ([]byte, error)
}
@ -499,39 +506,48 @@ func objFromMsg(msg netlink.Message) (Obj, error) {
if got, want := msg.Header.Type, objHeaderType; got != want {
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
}
attrs, err := netlink.UnmarshalAttributes(msg.Data[4:])
ad, err := netlink.NewAttributeDecoder(msg.Data[4:])
if err != nil {
return nil, err
}
ad.ByteOrder = binary.BigEndian
var (
table *Table
name string
objectType uint32
)
const NFT_OBJECT_COUNTER = 1 // TODO: get into x/sys/unix
for _, attr := range attrs {
switch attr.Type {
for ad.Next() {
switch ad.Type() {
case unix.NFTA_OBJ_TABLE:
table = &Table{Name: stringFrom0(attr.Data)}
table = &Table{Name: ad.String()}
case unix.NFTA_OBJ_NAME:
name = stringFrom0(attr.Data)
name = ad.String()
case unix.NFTA_OBJ_TYPE:
objectType = binaryutil.BigEndian.Uint32(attr.Data)
objectType = ad.Uint32()
case unix.NFTA_OBJ_DATA:
switch objectType {
case NFT_OBJECT_COUNTER:
attrs, err := netlink.UnmarshalAttributes(attr.Data)
if err != nil {
return nil, err
}
o := CounterObj{
Table: table,
Name: name,
}
return &o, o.unmarshal(attrs)
ad.Do(func(b []byte) error {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return err
}
ad.ByteOrder = binary.BigEndian
return o.unmarshal(ad)
})
return &o, ad.Err()
}
}
}
if err := ad.Err(); err != nil {
return nil, err
}
return nil, fmt.Errorf("malformed stateful object")
}