Compare commits

...

5 Commits

Author SHA1 Message Date
TheDiveO 06687b6e34
use TableFamilyUnspecified (NFPROTO_UNSPEC) instead of AF_UNSPEC (#165) 2022-05-15 23:16:05 +02:00
Michael Stapelberg 58da7d8bf3 make links stable 2022-05-15 23:15:01 +02:00
thediveo 8ea944061f add typed xtables information un/marshalling
more tests and fixes

more info support; refactoring
2022-05-15 23:12:26 +02:00
thediveo 4b6f0f2b44 add un/marshalling with native endianess and alignment 2022-05-15 23:12:26 +02:00
thediveo 3e042f75d7 refactor: pass table family when un/marshalling expr 2022-05-15 23:12:26 +02:00
56 changed files with 2026 additions and 105 deletions

240
alignedbuff/alignedbuff.go Normal file
View File

@ -0,0 +1,240 @@
// Package alignedbuff implements encoding and decoding aligned data elements
// to/from buffers in native endianess.
package alignedbuff
import (
"bytes"
"errors"
"fmt"
"unsafe"
"github.com/google/nftables/binaryutil"
)
// ErrEOF signals trying to read beyond the available payload information.
var ErrEOF = errors.New("not enough data left")
// AlignedBuff implements marshalling and unmarshalling information in
// platform/architecture native endianess and data type alignment. It
// additionally covers some of the nftables-xtables translation-specific
// idiosyncracies to the extend needed in order to properly marshal and
// unmarshal Match and Target expressions, and their Info payload in particular.
type AlignedBuff struct {
data []byte
pos int
}
// New returns a new AlignedBuff for marshalling aligned data in native
// endianess.
func New() AlignedBuff {
return AlignedBuff{}
}
// NewWithData returns a new AlignedBuff for unmarshalling the passed data in
// native endianess.
func NewWithData(data []byte) AlignedBuff {
return AlignedBuff{data: data}
}
// Data returns the properly padded info payload data written before by calling
// the various Uint8, Uint16, ... marshalling functions.
func (a *AlignedBuff) Data() []byte {
// The Linux kernel expects payloads to be padded to the next uint64
// alignment.
a.alignWrite(uint64AlignMask)
return a.data
}
// BytesAligned32 unmarshals the given amount of bytes starting with the native
// alignment for uint32 data types. It returns ErrEOF when trying to read beyond
// the payload.
//
// BytesAligned32 is used to unmarshal IP addresses for different IP versions,
// which are always aligned the same way as the native alignment for uint32.
func (a *AlignedBuff) BytesAligned32(size int) ([]byte, error) {
if err := a.alignCheckedRead(uint32AlignMask); err != nil {
return nil, err
}
if a.pos > len(a.data)-size {
return nil, ErrEOF
}
data := a.data[a.pos : a.pos+size]
a.pos += size
return data, nil
}
// Uint8 unmarshals an uint8 in native endianess and alignment. It returns
// ErrEOF when trying to read beyond the payload.
func (a *AlignedBuff) Uint8() (uint8, error) {
if a.pos >= len(a.data) {
return 0, ErrEOF
}
v := a.data[a.pos]
a.pos++
return v, nil
}
// Uint16 unmarshals an uint16 in native endianess and alignment. It returns
// ErrEOF when trying to read beyond the payload.
func (a *AlignedBuff) Uint16() (uint16, error) {
if err := a.alignCheckedRead(uint16AlignMask); err != nil {
return 0, err
}
v := binaryutil.NativeEndian.Uint16(a.data[a.pos : a.pos+2])
a.pos += 2
return v, nil
}
// Uint16BE unmarshals an uint16 in "network" (=big endian) endianess and native
// uint16 alignment. It returns ErrEOF when trying to read beyond the payload.
func (a *AlignedBuff) Uint16BE() (uint16, error) {
if err := a.alignCheckedRead(uint16AlignMask); err != nil {
return 0, err
}
v := binaryutil.BigEndian.Uint16(a.data[a.pos : a.pos+2])
a.pos += 2
return v, nil
}
// Uint32 unmarshals an uint32 in native endianess and alignment. It returns
// ErrEOF when trying to read beyond the payload.
func (a *AlignedBuff) Uint32() (uint32, error) {
if err := a.alignCheckedRead(uint32AlignMask); err != nil {
return 0, err
}
v := binaryutil.NativeEndian.Uint32(a.data[a.pos : a.pos+4])
a.pos += 4
return v, nil
}
// Uint64 unmarshals an uint64 in native endianess and alignment. It returns
// ErrEOF when trying to read beyond the payload.
func (a *AlignedBuff) Uint64() (uint64, error) {
if err := a.alignCheckedRead(uint64AlignMask); err != nil {
return 0, err
}
v := binaryutil.NativeEndian.Uint64(a.data[a.pos : a.pos+8])
a.pos += 8
return v, nil
}
// Uint unmarshals an uint in native endianess and alignment for the C "unsigned
// int" type. It returns ErrEOF when trying to read beyond the payload. Please
// note that on 64bit platforms, the size and alignment of C's and Go's unsigned
// integer data types differ, so we encapsulate this difference here.
func (a *AlignedBuff) Uint() (uint, error) {
switch uintSize {
case 2:
v, err := a.Uint16()
return uint(v), err
case 4:
v, err := a.Uint32()
return uint(v), err
case 8:
v, err := a.Uint64()
return uint(v), err
default:
panic(fmt.Sprintf("unsupported uint size %d", uintSize))
}
}
// PutBytesAligned32 marshals the given bytes starting with the native alignment
// for uint32 data types. It additionaly adds padding to reach the specified
// size.
//
// PutBytesAligned32 is used to marshal IP addresses for different IP versions,
// which are always aligned the same way as the native alignment for uint32.
func (a *AlignedBuff) PutBytesAligned32(data []byte, size int) {
a.alignWrite(uint32AlignMask)
a.data = append(a.data, data...)
a.pos += len(data)
if len(data) < size {
padding := size - len(data)
a.data = append(a.data, bytes.Repeat([]byte{0}, padding)...)
a.pos += padding
}
}
// PutUint8 marshals an uint8 in native endianess and alignment.
func (a *AlignedBuff) PutUint8(v uint8) {
a.data = append(a.data, v)
a.pos++
}
// PutUint16 marshals an uint16 in native endianess and alignment.
func (a *AlignedBuff) PutUint16(v uint16) {
a.alignWrite(uint16AlignMask)
a.data = append(a.data, binaryutil.NativeEndian.PutUint16(v)...)
a.pos += 2
}
// PutUint16BE marshals an uint16 in "network" (=big endian) endianess and
// native uint16 alignment.
func (a *AlignedBuff) PutUint16BE(v uint16) {
a.alignWrite(uint16AlignMask)
a.data = append(a.data, binaryutil.BigEndian.PutUint16(v)...)
a.pos += 2
}
// PutUint32 marshals an uint32 in native endianess and alignment.
func (a *AlignedBuff) PutUint32(v uint32) {
a.alignWrite(uint32AlignMask)
a.data = append(a.data, binaryutil.NativeEndian.PutUint32(v)...)
a.pos += 4
}
// PutUint64 marshals an uint64 in native endianess and alignment.
func (a *AlignedBuff) PutUint64(v uint64) {
a.alignWrite(uint64AlignMask)
a.data = append(a.data, binaryutil.NativeEndian.PutUint64(v)...)
a.pos += 8
}
// PutUint marshals an uint in native endianess and alignment for the C
// "unsigned int" type. Please note that on 64bit platforms, the size and
// alignment of C's and Go's unsigned integer data types differ, so we
// encapsulate this difference here.
func (a *AlignedBuff) PutUint(v uint) {
switch uintSize {
case 2:
a.PutUint16(uint16(v))
case 4:
a.PutUint32(uint32(v))
case 8:
a.PutUint64(uint64(v))
default:
panic(fmt.Sprintf("unsupported uint size %d", uintSize))
}
}
// alignCheckedRead aligns the (read) position if necessary and suitable
// according to the specified alignment mask. alignCheckedRead returns an error
// if after any necessary alignment there isn't enough data left to be read into
// a value of the size corresponding to the specified alignment mask.
func (a *AlignedBuff) alignCheckedRead(m int) error {
a.pos = (a.pos + m) & ^m
if a.pos > len(a.data)-(m+1) {
return ErrEOF
}
return nil
}
// alignWrite aligns the (write) position if necessary and suitable according to
// the specified alignment mask. It doubles as final payload padding helpmate in
// order to keep the kernel happy.
func (a *AlignedBuff) alignWrite(m int) {
pos := (a.pos + m) & ^m
if pos != a.pos {
a.data = append(a.data, padding[:pos-a.pos]...)
a.pos = pos
}
}
// This is ... ugly.
var uint16AlignMask = int(unsafe.Alignof(uint16(0)) - 1)
var uint32AlignMask = int(unsafe.Alignof(uint32(0)) - 1)
var uint64AlignMask = int(unsafe.Alignof(uint64(0)) - 1)
var padding = bytes.Repeat([]byte{0}, uint64AlignMask)
// And this even worse.
var uintSize = unsafe.Sizeof(uint32(0))

View File

@ -0,0 +1,204 @@
package alignedbuff
import (
"testing"
)
func TestAlignmentData(t *testing.T) {
if uint16AlignMask == 0 {
t.Fatal("zero uint16 alignment mask")
}
if uint32AlignMask == 0 {
t.Fatal("zero uint32 alignment mask")
}
if uint64AlignMask == 0 {
t.Fatal("zero uint64 alignment mask")
}
if len(padding) == 0 {
t.Fatal("zero alignment padding sequence")
}
if uintSize == 0 {
t.Fatal("zero uint size")
}
}
func TestAlignedBuff8(t *testing.T) {
b := NewWithData([]byte{0x42})
tests := []struct {
name string
v uint8
err error
}{
{
name: "first read",
v: 0x42,
err: nil,
},
{
name: "end of buffer",
v: 0,
err: ErrEOF,
},
}
for _, tt := range tests {
v, err := b.Uint8()
if v != tt.v || err != tt.err {
t.Errorf("expected: %#v %#v, got: %#v, %#v",
tt.v, tt.err, v, err)
}
}
}
func TestAlignedBuff16(t *testing.T) {
b0 := New()
b0.PutUint8(0x42)
b0.PutUint16(0x1234)
b0.PutUint16(0x5678)
b := NewWithData(b0.data)
v, err := b.Uint8()
if v != 0x42 || err != nil {
t.Fatalf("unaligment read failed")
}
tests := []struct {
name string
v uint16
err error
}{
{
name: "first read",
v: 0x1234,
err: nil,
},
{
name: "second read",
v: 0x5678,
err: nil,
},
{
name: "end of buffer",
v: 0,
err: ErrEOF,
},
}
for _, tt := range tests {
v, err := b.Uint16()
if v != tt.v || err != tt.err {
t.Errorf("%s failed, expected: %#v %#v, got: %#v, %#v",
tt.name, tt.v, tt.err, v, err)
}
}
}
func TestAlignedBuff32(t *testing.T) {
b0 := New()
b0.PutUint8(0x42)
b0.PutUint32(0x12345678)
b0.PutUint32(0x01cecafe)
b := NewWithData(b0.data)
if len(b0.Data()) != 4*4 {
t.Fatalf("alignment padding failed")
}
v, err := b.Uint8()
if v != 0x42 || err != nil {
t.Fatalf("unaligment read failed")
}
tests := []struct {
name string
v uint32
err error
}{
{
name: "first read",
v: 0x12345678,
err: nil,
},
{
name: "second read",
v: 0x01cecafe,
err: nil,
},
{
name: "end of buffer",
v: 0,
err: ErrEOF,
},
}
for _, tt := range tests {
v, err := b.Uint32()
if v != tt.v || err != tt.err {
t.Errorf("expected: %#v %#v, got: %#v, %#v",
tt.v, tt.err, v, err)
}
}
}
func TestAlignedBuff64(t *testing.T) {
b0 := New()
b0.PutUint8(0x42)
b0.PutUint64(0x1234567823456789)
b0.PutUint64(0x01cecafec001beef)
b := NewWithData(b0.data)
v, err := b.Uint8()
if v != 0x42 || err != nil {
t.Fatalf("unaligment read failed")
}
tests := []struct {
name string
v uint64
err error
}{
{
name: "first read",
v: 0x1234567823456789,
err: nil,
},
{
name: "second read",
v: 0x01cecafec001beef,
err: nil,
},
{
name: "end of buffer",
v: 0,
err: ErrEOF,
},
}
for _, tt := range tests {
v, err := b.Uint64()
if v != tt.v || err != tt.err {
t.Errorf("expected: %#v %#v, got: %#v, %#v",
tt.v, tt.err, v, err)
}
}
}
func TestAlignedUint(t *testing.T) {
expectedv := uint(^uint32(0) - 1)
b0 := New()
b0.PutUint8(0x55)
b0.PutUint(expectedv)
b0.PutUint8(0xAA)
b := NewWithData(b0.data)
v, err := b.Uint8()
if v != 0x55 || err != nil {
t.Fatalf("sentinel read failed")
}
uiv, err := b.Uint()
if uiv != expectedv || err != nil {
t.Fatalf("uint read failed, expected: %d, got: %d", expectedv, uiv)
}
v, err = b.Uint8()
if v != 0xAA || err != nil {
t.Fatalf("sentinel read failed")
}
}

View File

@ -184,7 +184,7 @@ func (cc *Conn) ListChains() ([]*Chain, error) {
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN),
Flags: netlink.Request | netlink.Dump, Flags: netlink.Request | netlink.Dump,
}, },
Data: extraHeader(uint8(unix.AF_UNSPEC), 0), Data: extraHeader(uint8(TableFamilyUnspecified), 0),
} }
response, err := conn.Execute(msg) response, err := conn.Execute(msg)

View File

@ -227,8 +227,8 @@ func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
return b return b
} }
func (cc *Conn) marshalExpr(e expr.Any) []byte { func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
b, err := expr.Marshal(e) b, err := expr.Marshal(fam, e)
if err != nil { if err != nil {
cc.setErr(err) cc.setErr(err)
return nil return nil

View File

@ -30,7 +30,7 @@ type Bitwise struct {
Xor []byte Xor []byte
} }
func (e *Bitwise) marshal() ([]byte, error) { func (e *Bitwise) marshal(fam byte) ([]byte, error) {
mask, err := netlink.MarshalAttributes([]netlink.Attribute{ mask, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_DATA_VALUE, Data: e.Mask}, {Type: unix.NFTA_DATA_VALUE, Data: e.Mask},
}) })
@ -60,7 +60,7 @@ func (e *Bitwise) marshal() ([]byte, error) {
}) })
} }
func (e *Bitwise) unmarshal(data []byte) error { func (e *Bitwise) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -32,7 +32,7 @@ func TestBitwise(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
nbw := Bitwise{} nbw := Bitwise{}
data, err := tt.bw.marshal() data, err := tt.bw.marshal(0 /* don't care in this test */)
if err != nil { if err != nil {
t.Fatalf("marshal error: %+v", err) t.Fatalf("marshal error: %+v", err)
@ -44,7 +44,7 @@ func TestBitwise(t *testing.T) {
ad.ByteOrder = binary.BigEndian ad.ByteOrder = binary.BigEndian
for ad.Next() { for ad.Next() {
if ad.Type() == unix.NFTA_EXPR_DATA { if ad.Type() == unix.NFTA_EXPR_DATA {
if err := nbw.unmarshal(ad.Bytes()); err != nil { if err := nbw.unmarshal(0, ad.Bytes()); err != nil {
t.Errorf("unmarshal error: %+v", err) t.Errorf("unmarshal error: %+v", err)
break break
} }

View File

@ -37,7 +37,7 @@ type Byteorder struct {
Size uint32 Size uint32
} }
func (e *Byteorder) marshal() ([]byte, error) { func (e *Byteorder) marshal(fam byte) ([]byte, error) {
data, err := netlink.MarshalAttributes([]netlink.Attribute{ data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_BYTEORDER_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, {Type: unix.NFTA_BYTEORDER_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)},
{Type: unix.NFTA_BYTEORDER_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, {Type: unix.NFTA_BYTEORDER_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)},
@ -54,6 +54,6 @@ func (e *Byteorder) marshal() ([]byte, error) {
}) })
} }
func (e *Byteorder) unmarshal(data []byte) error { func (e *Byteorder) unmarshal(fam byte, data []byte) error {
return fmt.Errorf("not yet implemented") return fmt.Errorf("not yet implemented")
} }

View File

@ -27,7 +27,7 @@ type Counter struct {
Packets uint64 Packets uint64
} }
func (e *Counter) marshal() ([]byte, error) { func (e *Counter) marshal(fam byte) ([]byte, error) {
data, err := netlink.MarshalAttributes([]netlink.Attribute{ data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_COUNTER_BYTES, Data: binaryutil.BigEndian.PutUint64(e.Bytes)}, {Type: unix.NFTA_COUNTER_BYTES, Data: binaryutil.BigEndian.PutUint64(e.Bytes)},
{Type: unix.NFTA_COUNTER_PACKETS, Data: binaryutil.BigEndian.PutUint64(e.Packets)}, {Type: unix.NFTA_COUNTER_PACKETS, Data: binaryutil.BigEndian.PutUint64(e.Packets)},
@ -42,7 +42,7 @@ func (e *Counter) marshal() ([]byte, error) {
}) })
} }
func (e *Counter) unmarshal(data []byte) error { func (e *Counter) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -63,7 +63,7 @@ type Ct struct {
Key CtKey Key CtKey
} }
func (e *Ct) marshal() ([]byte, error) { func (e *Ct) marshal(fam byte) ([]byte, error) {
regData := []byte{} regData := []byte{}
exprData, err := netlink.MarshalAttributes( exprData, err := netlink.MarshalAttributes(
[]netlink.Attribute{ []netlink.Attribute{
@ -97,7 +97,7 @@ func (e *Ct) marshal() ([]byte, error) {
}) })
} }
func (e *Ct) unmarshal(data []byte) error { func (e *Ct) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -28,7 +28,7 @@ type Dup struct {
IsRegDevSet bool IsRegDevSet bool
} }
func (e *Dup) marshal() ([]byte, error) { func (e *Dup) marshal(fam byte) ([]byte, error) {
attrs := []netlink.Attribute{ attrs := []netlink.Attribute{
{Type: unix.NFTA_DUP_SREG_ADDR, Data: binaryutil.BigEndian.PutUint32(e.RegAddr)}, {Type: unix.NFTA_DUP_SREG_ADDR, Data: binaryutil.BigEndian.PutUint32(e.RegAddr)},
} }
@ -49,7 +49,7 @@ func (e *Dup) marshal() ([]byte, error) {
}) })
} }
func (e *Dup) unmarshal(data []byte) error { func (e *Dup) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -34,7 +34,7 @@ type Dynset struct {
Invert bool Invert bool
} }
func (e *Dynset) marshal() ([]byte, error) { func (e *Dynset) marshal(fam byte) ([]byte, error) {
// See: https://git.netfilter.org/libnftnl/tree/src/expr/dynset.c // See: https://git.netfilter.org/libnftnl/tree/src/expr/dynset.c
var opAttrs []netlink.Attribute var opAttrs []netlink.Attribute
opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_KEY, Data: binaryutil.BigEndian.PutUint32(e.SrcRegKey)}) opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_KEY, Data: binaryutil.BigEndian.PutUint32(e.SrcRegKey)})
@ -62,7 +62,7 @@ func (e *Dynset) marshal() ([]byte, error) {
}) })
} }
func (e *Dynset) unmarshal(data []byte) error { func (e *Dynset) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -24,19 +24,19 @@ import (
) )
// Marshal serializes the specified expression into a byte slice. // Marshal serializes the specified expression into a byte slice.
func Marshal(e Any) ([]byte, error) { func Marshal(fam byte, e Any) ([]byte, error) {
return e.marshal() return e.marshal(fam)
} }
// Unmarshal fills an expression from the specified byte slice. // Unmarshal fills an expression from the specified byte slice.
func Unmarshal(data []byte, e Any) error { func Unmarshal(fam byte, data []byte, e Any) error {
return e.unmarshal(data) return e.unmarshal(fam, data)
} }
// Any is an interface implemented by any expression type. // Any is an interface implemented by any expression type.
type Any interface { type Any interface {
marshal() ([]byte, error) marshal(fam byte) ([]byte, error)
unmarshal([]byte) error unmarshal(fam byte, data []byte) error
} }
// MetaKey specifies which piece of meta information should be loaded. See also // MetaKey specifies which piece of meta information should be loaded. See also
@ -80,7 +80,7 @@ type Meta struct {
Register uint32 Register uint32
} }
func (e *Meta) marshal() ([]byte, error) { func (e *Meta) marshal(fam byte) ([]byte, error) {
regData := []byte{} regData := []byte{}
exprData, err := netlink.MarshalAttributes( exprData, err := netlink.MarshalAttributes(
[]netlink.Attribute{ []netlink.Attribute{
@ -114,7 +114,7 @@ func (e *Meta) marshal() ([]byte, error) {
}) })
} }
func (e *Meta) unmarshal(data []byte) error { func (e *Meta) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err
@ -153,7 +153,7 @@ const (
NF_NAT_RANGE_PERSISTENT = 0x8 NF_NAT_RANGE_PERSISTENT = 0x8
) )
func (e *Masq) marshal() ([]byte, error) { func (e *Masq) marshal(fam byte) ([]byte, error) {
msgData := []byte{} msgData := []byte{}
if !e.ToPorts { if !e.ToPorts {
flags := uint32(0) flags := uint32(0)
@ -196,7 +196,7 @@ func (e *Masq) marshal() ([]byte, error) {
}) })
} }
func (e *Masq) unmarshal(data []byte) error { func (e *Masq) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err
@ -238,7 +238,7 @@ type Cmp struct {
Data []byte Data []byte
} }
func (e *Cmp) marshal() ([]byte, error) { func (e *Cmp) marshal(fam byte) ([]byte, error) {
cmpData, err := netlink.MarshalAttributes([]netlink.Attribute{ cmpData, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_DATA_VALUE, Data: e.Data}, {Type: unix.NFTA_DATA_VALUE, Data: e.Data},
}) })
@ -259,7 +259,7 @@ func (e *Cmp) marshal() ([]byte, error) {
}) })
} }
func (e *Cmp) unmarshal(data []byte) error { func (e *Cmp) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -39,7 +39,7 @@ type Exthdr struct {
SourceRegister uint32 SourceRegister uint32
} }
func (e *Exthdr) marshal() ([]byte, error) { func (e *Exthdr) marshal(fam byte) ([]byte, error) {
var attr []netlink.Attribute var attr []netlink.Attribute
// Operations are differentiated by the Op and whether the SourceRegister // Operations are differentiated by the Op and whether the SourceRegister
@ -49,7 +49,7 @@ func (e *Exthdr) marshal() ([]byte, error) {
{Type: unix.NFTA_EXTHDR_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}} {Type: unix.NFTA_EXTHDR_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}}
} else { } else {
attr = []netlink.Attribute{ attr = []netlink.Attribute{
netlink.Attribute{Type: unix.NFTA_EXTHDR_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}} {Type: unix.NFTA_EXTHDR_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}}
} }
attr = append(attr, attr = append(attr,
@ -74,7 +74,7 @@ func (e *Exthdr) marshal() ([]byte, error) {
}) })
} }
func (e *Exthdr) unmarshal(data []byte) error { func (e *Exthdr) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -44,7 +44,7 @@ func TestExthdr(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
neh := Exthdr{} neh := Exthdr{}
data, err := tt.eh.marshal() data, err := tt.eh.marshal(0 /* don't care in this test */)
if err != nil { if err != nil {
t.Fatalf("marshal error: %+v", err) t.Fatalf("marshal error: %+v", err)
@ -56,7 +56,7 @@ func TestExthdr(t *testing.T) {
ad.ByteOrder = binary.BigEndian ad.ByteOrder = binary.BigEndian
for ad.Next() { for ad.Next() {
if ad.Type() == unix.NFTA_EXPR_DATA { if ad.Type() == unix.NFTA_EXPR_DATA {
if err := neh.unmarshal(ad.Bytes()); err != nil { if err := neh.unmarshal(0, ad.Bytes()); err != nil {
t.Errorf("unmarshal error: %+v", err) t.Errorf("unmarshal error: %+v", err)
break break
} }

View File

@ -36,7 +36,7 @@ type Fib struct {
FlagPRESENT bool FlagPRESENT bool
} }
func (e *Fib) marshal() ([]byte, error) { func (e *Fib) marshal(fam byte) ([]byte, error) {
data := []byte{} data := []byte{}
reg, err := netlink.MarshalAttributes([]netlink.Attribute{ reg, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_FIB_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, {Type: unix.NFTA_FIB_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)},
@ -99,7 +99,7 @@ func (e *Fib) marshal() ([]byte, error) {
}) })
} }
func (e *Fib) unmarshal(data []byte) error { func (e *Fib) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -40,7 +40,7 @@ type Hash struct {
Type HashType Type HashType
} }
func (e *Hash) marshal() ([]byte, error) { func (e *Hash) marshal(fam byte) ([]byte, error) {
data, err := netlink.MarshalAttributes([]netlink.Attribute{ data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_HASH_SREG, Data: binaryutil.BigEndian.PutUint32(uint32(e.SourceRegister))}, {Type: unix.NFTA_HASH_SREG, Data: binaryutil.BigEndian.PutUint32(uint32(e.SourceRegister))},
{Type: unix.NFTA_HASH_DREG, Data: binaryutil.BigEndian.PutUint32(uint32(e.DestRegister))}, {Type: unix.NFTA_HASH_DREG, Data: binaryutil.BigEndian.PutUint32(uint32(e.DestRegister))},
@ -59,7 +59,7 @@ func (e *Hash) marshal() ([]byte, error) {
}) })
} }
func (e *Hash) unmarshal(data []byte) error { func (e *Hash) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -28,7 +28,7 @@ type Immediate struct {
Data []byte Data []byte
} }
func (e *Immediate) marshal() ([]byte, error) { func (e *Immediate) marshal(fam byte) ([]byte, error) {
immData, err := netlink.MarshalAttributes([]netlink.Attribute{ immData, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_DATA_VALUE, Data: e.Data}, {Type: unix.NFTA_DATA_VALUE, Data: e.Data},
}) })
@ -49,7 +49,7 @@ func (e *Immediate) marshal() ([]byte, error) {
}) })
} }
func (e *Immediate) unmarshal(data []byte) error { func (e *Immediate) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -71,7 +71,7 @@ type Limit struct {
Burst uint32 Burst uint32
} }
func (l *Limit) marshal() ([]byte, error) { func (l *Limit) marshal(fam byte) ([]byte, error) {
attrs := []netlink.Attribute{ attrs := []netlink.Attribute{
{Type: unix.NFTA_LIMIT_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(l.Type))}, {Type: unix.NFTA_LIMIT_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(l.Type))},
{Type: unix.NFTA_LIMIT_RATE, Data: binaryutil.BigEndian.PutUint64(l.Rate)}, {Type: unix.NFTA_LIMIT_RATE, Data: binaryutil.BigEndian.PutUint64(l.Rate)},
@ -103,7 +103,7 @@ func (l *Limit) marshal() ([]byte, error) {
}) })
} }
func (l *Limit) unmarshal(data []byte) error { func (l *Limit) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -68,7 +68,7 @@ type Log struct {
Data []byte Data []byte
} }
func (e *Log) marshal() ([]byte, error) { func (e *Log) marshal(fam byte) ([]byte, error) {
// Per https://git.netfilter.org/libnftnl/tree/src/expr/log.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n129 // Per https://git.netfilter.org/libnftnl/tree/src/expr/log.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n129
attrs := make([]netlink.Attribute, 0) attrs := make([]netlink.Attribute, 0)
if e.Key&(1<<unix.NFTA_LOG_GROUP) != 0 { if e.Key&(1<<unix.NFTA_LOG_GROUP) != 0 {
@ -120,7 +120,7 @@ func (e *Log) marshal() ([]byte, error) {
}) })
} }
func (e *Log) unmarshal(data []byte) error { func (e *Log) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -33,7 +33,7 @@ type Lookup struct {
Invert bool Invert bool
} }
func (e *Lookup) marshal() ([]byte, error) { func (e *Lookup) marshal(fam byte) ([]byte, error) {
// See: https://git.netfilter.org/libnftnl/tree/src/expr/lookup.c?id=6dc1c3d8bb64077da7f3f28c7368fb087d10a492#n115 // See: https://git.netfilter.org/libnftnl/tree/src/expr/lookup.c?id=6dc1c3d8bb64077da7f3f28c7368fb087d10a492#n115
var opAttrs []netlink.Attribute var opAttrs []netlink.Attribute
if e.SourceRegister != 0 { if e.SourceRegister != 0 {
@ -60,7 +60,7 @@ func (e *Lookup) marshal() ([]byte, error) {
}) })
} }
func (e *Lookup) unmarshal(data []byte) error { func (e *Lookup) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/xt"
"github.com/mdlayher/netlink" "github.com/mdlayher/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -13,10 +14,10 @@ import (
type Match struct { type Match struct {
Name string Name string
Rev uint32 Rev uint32
Info []byte Info xt.InfoAny
} }
func (e *Match) marshal() ([]byte, error) { func (e *Match) marshal(fam byte) ([]byte, error) {
// Per https://git.netfilter.org/libnftnl/tree/src/expr/match.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n38 // Per https://git.netfilter.org/libnftnl/tree/src/expr/match.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n38
name := e.Name name := e.Name
// limit the extension name as (some) user-space tools do and leave room for // limit the extension name as (some) user-space tools do and leave room for
@ -24,10 +25,16 @@ func (e *Match) marshal() ([]byte, error) {
if len(name) >= /* sic! */ XTablesExtensionNameMaxLen { if len(name) >= /* sic! */ XTablesExtensionNameMaxLen {
name = name[:XTablesExtensionNameMaxLen-1] // leave room for trailing \x00. name = name[:XTablesExtensionNameMaxLen-1] // leave room for trailing \x00.
} }
// Marshalling assumes that the correct Info type for the particular table
// family and Match revision has been set.
info, err := xt.Marshal(xt.TableFamily(fam), e.Rev, e.Info)
if err != nil {
return nil, err
}
attrs := []netlink.Attribute{ attrs := []netlink.Attribute{
{Type: unix.NFTA_MATCH_NAME, Data: []byte(name + "\x00")}, {Type: unix.NFTA_MATCH_NAME, Data: []byte(name + "\x00")},
{Type: unix.NFTA_MATCH_REV, Data: binaryutil.BigEndian.PutUint32(e.Rev)}, {Type: unix.NFTA_MATCH_REV, Data: binaryutil.BigEndian.PutUint32(e.Rev)},
{Type: unix.NFTA_MATCH_INFO, Data: e.Info}, {Type: unix.NFTA_MATCH_INFO, Data: info},
} }
data, err := netlink.MarshalAttributes(attrs) data, err := netlink.MarshalAttributes(attrs)
if err != nil { if err != nil {
@ -40,13 +47,14 @@ func (e *Match) marshal() ([]byte, error) {
}) })
} }
func (e *Match) unmarshal(data []byte) error { func (e *Match) unmarshal(fam byte, data []byte) error {
// Per https://git.netfilter.org/libnftnl/tree/src/expr/match.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n65 // Per https://git.netfilter.org/libnftnl/tree/src/expr/match.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n65
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err
} }
var info []byte
ad.ByteOrder = binary.BigEndian ad.ByteOrder = binary.BigEndian
for ad.Next() { for ad.Next() {
switch ad.Type() { switch ad.Type() {
@ -56,8 +64,12 @@ func (e *Match) unmarshal(data []byte) error {
case unix.NFTA_MATCH_REV: case unix.NFTA_MATCH_REV:
e.Rev = ad.Uint32() e.Rev = ad.Uint32()
case unix.NFTA_MATCH_INFO: case unix.NFTA_MATCH_INFO:
e.Info = ad.Bytes() info = ad.Bytes()
} }
} }
return ad.Err() if err = ad.Err(); err != nil {
return err
}
e.Info, err = xt.Unmarshal(e.Name, xt.TableFamily(fam), e.Rev, info)
return err
} }

View File

@ -5,12 +5,14 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/google/nftables/xt"
"github.com/mdlayher/netlink" "github.com/mdlayher/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func TestMatch(t *testing.T) { func TestMatch(t *testing.T) {
t.Parallel() t.Parallel()
payload := xt.Unknown([]byte{0xb0, 0x1d, 0xca, 0xfe, 0x00})
tests := []struct { tests := []struct {
name string name string
mtch Match mtch Match
@ -20,7 +22,7 @@ func TestMatch(t *testing.T) {
mtch: Match{ mtch: Match{
Name: "foobar", Name: "foobar",
Rev: 1234567890, Rev: 1234567890,
Info: []byte{0xb0, 0x1d, 0xca, 0xfe, 0x00}, Info: &payload,
}, },
}, },
} }
@ -28,7 +30,7 @@ func TestMatch(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ntgt := Match{} ntgt := Match{}
data, err := tt.mtch.marshal() data, err := tt.mtch.marshal(0 /* don't care in this test */)
if err != nil { if err != nil {
t.Fatalf("marshal error: %+v", err) t.Fatalf("marshal error: %+v", err)
@ -40,7 +42,7 @@ func TestMatch(t *testing.T) {
ad.ByteOrder = binary.BigEndian ad.ByteOrder = binary.BigEndian
for ad.Next() { for ad.Next() {
if ad.Type() == unix.NFTA_EXPR_DATA { if ad.Type() == unix.NFTA_EXPR_DATA {
if err := ntgt.unmarshal(ad.Bytes()); err != nil { if err := ntgt.unmarshal(0 /* don't care in this test */, ad.Bytes()); err != nil {
t.Errorf("unmarshal error: %+v", err) t.Errorf("unmarshal error: %+v", err)
break break
} }

View File

@ -55,7 +55,7 @@ type NAT struct {
// |00008|--|00005| |len |flags| type| NFTA_NAT_REG_PROTO_MIN // |00008|--|00005| |len |flags| type| NFTA_NAT_REG_PROTO_MIN
// | 00 00 00 02 | | data | reg 2 // | 00 00 00 02 | | data | reg 2
func (e *NAT) marshal() ([]byte, error) { func (e *NAT) marshal(fam byte) ([]byte, error) {
attrs := []netlink.Attribute{ attrs := []netlink.Attribute{
{Type: unix.NFTA_NAT_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Type))}, {Type: unix.NFTA_NAT_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Type))},
{Type: unix.NFTA_NAT_FAMILY, Data: binaryutil.BigEndian.PutUint32(e.Family)}, {Type: unix.NFTA_NAT_FAMILY, Data: binaryutil.BigEndian.PutUint32(e.Family)},
@ -96,7 +96,7 @@ func (e *NAT) marshal() ([]byte, error) {
}) })
} }
func (e *NAT) unmarshal(data []byte) error { func (e *NAT) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -21,13 +21,13 @@ import (
type Notrack struct{} type Notrack struct{}
func (e *Notrack) marshal() ([]byte, error) { func (e *Notrack) marshal(fam byte) ([]byte, error) {
return netlink.MarshalAttributes([]netlink.Attribute{ return netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_EXPR_NAME, Data: []byte("notrack\x00")}, {Type: unix.NFTA_EXPR_NAME, Data: []byte("notrack\x00")},
}) })
} }
func (e *Notrack) unmarshal(data []byte) error { func (e *Notrack) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {

View File

@ -31,7 +31,7 @@ type Numgen struct {
Offset uint32 Offset uint32
} }
func (e *Numgen) marshal() ([]byte, error) { func (e *Numgen) marshal(fam byte) ([]byte, error) {
// Currently only two types are supported, failing if Type is not of two known types // Currently only two types are supported, failing if Type is not of two known types
switch e.Type { switch e.Type {
case unix.NFT_NG_INCREMENTAL: case unix.NFT_NG_INCREMENTAL:
@ -56,7 +56,7 @@ func (e *Numgen) marshal() ([]byte, error) {
}) })
} }
func (e *Numgen) unmarshal(data []byte) error { func (e *Numgen) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -27,7 +27,7 @@ type Objref struct {
Name string Name string
} }
func (e *Objref) marshal() ([]byte, error) { func (e *Objref) marshal(fam byte) ([]byte, error) {
data, err := netlink.MarshalAttributes([]netlink.Attribute{ data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_OBJREF_IMM_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Type))}, {Type: unix.NFTA_OBJREF_IMM_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Type))},
{Type: unix.NFTA_OBJREF_IMM_NAME, Data: []byte(e.Name)}, // NOT \x00-terminated?! {Type: unix.NFTA_OBJREF_IMM_NAME, Data: []byte(e.Name)}, // NOT \x00-terminated?!
@ -42,7 +42,7 @@ func (e *Objref) marshal() ([]byte, error) {
}) })
} }
func (e *Objref) unmarshal(data []byte) error { func (e *Objref) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -57,7 +57,7 @@ type Payload struct {
CsumFlags uint32 CsumFlags uint32
} }
func (e *Payload) marshal() ([]byte, error) { func (e *Payload) marshal(fam byte) ([]byte, error) {
var attrs []netlink.Attribute var attrs []netlink.Attribute
@ -100,7 +100,7 @@ func (e *Payload) marshal() ([]byte, error) {
}) })
} }
func (e *Payload) unmarshal(data []byte) error { func (e *Payload) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -44,7 +44,7 @@ type Queue struct {
Flag QueueFlag Flag QueueFlag
} }
func (e *Queue) marshal() ([]byte, error) { func (e *Queue) marshal(fam byte) ([]byte, error) {
if e.Total == 0 { if e.Total == 0 {
e.Total = 1 // The total default value is 1 e.Total = 1 // The total default value is 1
} }
@ -62,7 +62,7 @@ func (e *Queue) marshal() ([]byte, error) {
}) })
} }
func (e *Queue) unmarshal(data []byte) error { func (e *Queue) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -29,7 +29,7 @@ type Quota struct {
Over bool Over bool
} }
func (q *Quota) marshal() ([]byte, error) { func (q *Quota) marshal(fam byte) ([]byte, error) {
attrs := []netlink.Attribute{ attrs := []netlink.Attribute{
{Type: unix.NFTA_QUOTA_BYTES, Data: binaryutil.BigEndian.PutUint64(q.Bytes)}, {Type: unix.NFTA_QUOTA_BYTES, Data: binaryutil.BigEndian.PutUint64(q.Bytes)},
{Type: unix.NFTA_QUOTA_CONSUMED, Data: binaryutil.BigEndian.PutUint64(q.Consumed)}, {Type: unix.NFTA_QUOTA_CONSUMED, Data: binaryutil.BigEndian.PutUint64(q.Consumed)},
@ -55,7 +55,7 @@ func (q *Quota) marshal() ([]byte, error) {
}) })
} }
func (q *Quota) unmarshal(data []byte) error { func (q *Quota) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -30,7 +30,7 @@ type Range struct {
ToData []byte ToData []byte
} }
func (e *Range) marshal() ([]byte, error) { func (e *Range) marshal(fam byte) ([]byte, error) {
var attrs []netlink.Attribute var attrs []netlink.Attribute
var err error var err error
var rangeFromData, rangeToData []byte var rangeFromData, rangeToData []byte
@ -64,7 +64,7 @@ func (e *Range) marshal() ([]byte, error) {
}) })
} }
func (e *Range) unmarshal(data []byte) error { func (e *Range) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -28,7 +28,7 @@ type Redir struct {
Flags uint32 Flags uint32
} }
func (e *Redir) marshal() ([]byte, error) { func (e *Redir) marshal(fam byte) ([]byte, error) {
var attrs []netlink.Attribute var attrs []netlink.Attribute
if e.RegisterProtoMin > 0 { if e.RegisterProtoMin > 0 {
attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_REDIR_REG_PROTO_MIN, Data: binaryutil.BigEndian.PutUint32(e.RegisterProtoMin)}) attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_REDIR_REG_PROTO_MIN, Data: binaryutil.BigEndian.PutUint32(e.RegisterProtoMin)})
@ -51,7 +51,7 @@ func (e *Redir) marshal() ([]byte, error) {
}) })
} }
func (e *Redir) unmarshal(data []byte) error { func (e *Redir) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -27,7 +27,7 @@ type Reject struct {
Code uint8 Code uint8
} }
func (e *Reject) marshal() ([]byte, error) { func (e *Reject) marshal(fam byte) ([]byte, error) {
data, err := netlink.MarshalAttributes([]netlink.Attribute{ data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_REJECT_TYPE, Data: binaryutil.BigEndian.PutUint32(e.Type)}, {Type: unix.NFTA_REJECT_TYPE, Data: binaryutil.BigEndian.PutUint32(e.Type)},
{Type: unix.NFTA_REJECT_ICMP_CODE, Data: []byte{e.Code}}, {Type: unix.NFTA_REJECT_ICMP_CODE, Data: []byte{e.Code}},
@ -41,7 +41,7 @@ func (e *Reject) marshal() ([]byte, error) {
}) })
} }
func (e *Reject) unmarshal(data []byte) error { func (e *Reject) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -36,7 +36,7 @@ type Rt struct {
Key RtKey Key RtKey
} }
func (e *Rt) marshal() ([]byte, error) { func (e *Rt) marshal(fam byte) ([]byte, error) {
data, err := netlink.MarshalAttributes([]netlink.Attribute{ data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: unix.NFTA_RT_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, {Type: unix.NFTA_RT_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))},
{Type: unix.NFTA_RT_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, {Type: unix.NFTA_RT_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)},
@ -50,6 +50,6 @@ func (e *Rt) marshal() ([]byte, error) {
}) })
} }
func (e *Rt) unmarshal(data []byte) error { func (e *Rt) unmarshal(fam byte, data []byte) error {
return fmt.Errorf("not yet implemented") return fmt.Errorf("not yet implemented")
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/xt"
"github.com/mdlayher/netlink" "github.com/mdlayher/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -16,10 +17,10 @@ const XTablesExtensionNameMaxLen = 29
type Target struct { type Target struct {
Name string Name string
Rev uint32 Rev uint32
Info []byte Info xt.InfoAny
} }
func (e *Target) marshal() ([]byte, error) { func (e *Target) marshal(fam byte) ([]byte, error) {
// Per https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n38 // Per https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n38
name := e.Name name := e.Name
// limit the extension name as (some) user-space tools do and leave room for // limit the extension name as (some) user-space tools do and leave room for
@ -27,10 +28,16 @@ func (e *Target) marshal() ([]byte, error) {
if len(name) >= /* sic! */ XTablesExtensionNameMaxLen { if len(name) >= /* sic! */ XTablesExtensionNameMaxLen {
name = name[:XTablesExtensionNameMaxLen-1] // leave room for trailing \x00. name = name[:XTablesExtensionNameMaxLen-1] // leave room for trailing \x00.
} }
// Marshalling assumes that the correct Info type for the particular table
// family and Match revision has been set.
info, err := xt.Marshal(xt.TableFamily(fam), e.Rev, e.Info)
if err != nil {
return nil, err
}
attrs := []netlink.Attribute{ attrs := []netlink.Attribute{
{Type: unix.NFTA_TARGET_NAME, Data: []byte(name + "\x00")}, {Type: unix.NFTA_TARGET_NAME, Data: []byte(name + "\x00")},
{Type: unix.NFTA_TARGET_REV, Data: binaryutil.BigEndian.PutUint32(e.Rev)}, {Type: unix.NFTA_TARGET_REV, Data: binaryutil.BigEndian.PutUint32(e.Rev)},
{Type: unix.NFTA_TARGET_INFO, Data: e.Info}, {Type: unix.NFTA_TARGET_INFO, Data: info},
} }
data, err := netlink.MarshalAttributes(attrs) data, err := netlink.MarshalAttributes(attrs)
@ -44,13 +51,14 @@ func (e *Target) marshal() ([]byte, error) {
}) })
} }
func (e *Target) unmarshal(data []byte) error { func (e *Target) unmarshal(fam byte, data []byte) error {
// Per https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n65 // Per https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n65
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err
} }
var info []byte
ad.ByteOrder = binary.BigEndian ad.ByteOrder = binary.BigEndian
for ad.Next() { for ad.Next() {
switch ad.Type() { switch ad.Type() {
@ -60,8 +68,12 @@ func (e *Target) unmarshal(data []byte) error {
case unix.NFTA_TARGET_REV: case unix.NFTA_TARGET_REV:
e.Rev = ad.Uint32() e.Rev = ad.Uint32()
case unix.NFTA_TARGET_INFO: case unix.NFTA_TARGET_INFO:
e.Info = ad.Bytes() info = ad.Bytes()
} }
} }
return ad.Err() if err = ad.Err(); err != nil {
return err
}
e.Info, err = xt.Unmarshal(e.Name, xt.TableFamily(fam), e.Rev, info)
return err
} }

View File

@ -5,12 +5,14 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/google/nftables/xt"
"github.com/mdlayher/netlink" "github.com/mdlayher/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func TestTarget(t *testing.T) { func TestTarget(t *testing.T) {
t.Parallel() t.Parallel()
payload := xt.Unknown([]byte{0xb0, 0x1d, 0xca, 0xfe, 0x00})
tests := []struct { tests := []struct {
name string name string
tgt Target tgt Target
@ -20,7 +22,7 @@ func TestTarget(t *testing.T) {
tgt: Target{ tgt: Target{
Name: "foobar", Name: "foobar",
Rev: 1234567890, Rev: 1234567890,
Info: []byte{0xb0, 0x1d, 0xca, 0xfe, 0x00}, Info: &payload,
}, },
}, },
} }
@ -28,7 +30,7 @@ func TestTarget(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ntgt := Target{} ntgt := Target{}
data, err := tt.tgt.marshal() data, err := tt.tgt.marshal(0 /* don't care in this test */)
if err != nil { if err != nil {
t.Fatalf("marshal error: %+v", err) t.Fatalf("marshal error: %+v", err)
@ -40,7 +42,7 @@ func TestTarget(t *testing.T) {
ad.ByteOrder = binary.BigEndian ad.ByteOrder = binary.BigEndian
for ad.Next() { for ad.Next() {
if ad.Type() == unix.NFTA_EXPR_DATA { if ad.Type() == unix.NFTA_EXPR_DATA {
if err := ntgt.unmarshal(ad.Bytes()); err != nil { if err := ntgt.unmarshal(0 /* don't care in this test */, ad.Bytes()); err != nil {
t.Errorf("unmarshal error: %+v", err) t.Errorf("unmarshal error: %+v", err)
break break
} }

View File

@ -36,7 +36,7 @@ type TProxy struct {
RegPort uint32 RegPort uint32
} }
func (e *TProxy) marshal() ([]byte, error) { func (e *TProxy) marshal(fam byte) ([]byte, error) {
data, err := netlink.MarshalAttributes([]netlink.Attribute{ data, err := netlink.MarshalAttributes([]netlink.Attribute{
{Type: NFTA_TPROXY_FAMILY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Family))}, {Type: NFTA_TPROXY_FAMILY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Family))},
{Type: NFTA_TPROXY_REG, Data: binaryutil.BigEndian.PutUint32(e.RegPort)}, {Type: NFTA_TPROXY_REG, Data: binaryutil.BigEndian.PutUint32(e.RegPort)},
@ -50,7 +50,7 @@ func (e *TProxy) marshal() ([]byte, error) {
}) })
} }
func (e *TProxy) unmarshal(data []byte) error { func (e *TProxy) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

View File

@ -53,7 +53,7 @@ const (
VerdictStop VerdictStop
) )
func (e *Verdict) marshal() ([]byte, error) { func (e *Verdict) marshal(fam byte) ([]byte, error) {
// A verdict is a tree of netlink attributes structured as follows: // A verdict is a tree of netlink attributes structured as follows:
// NFTA_LIST_ELEM | NLA_F_NESTED { // NFTA_LIST_ELEM | NLA_F_NESTED {
// NFTA_EXPR_NAME { "immediate\x00" } // NFTA_EXPR_NAME { "immediate\x00" }
@ -96,7 +96,7 @@ func (e *Verdict) marshal() ([]byte, error) {
}) })
} }
func (e *Verdict) unmarshal(data []byte) error { func (e *Verdict) unmarshal(fam byte, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data) ad, err := netlink.NewAttributeDecoder(data)
if err != nil { if err != nil {
return err return err

19
rule.go
View File

@ -92,7 +92,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
} }
var rules []*Rule var rules []*Rule
for _, msg := range reply { for _, msg := range reply {
r, err := ruleFromMsg(msg) r, err := ruleFromMsg(t.Family, msg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -113,7 +113,7 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
for idx, expr := range r.Exprs { for idx, expr := range r.Exprs {
exprAttrs[idx] = netlink.Attribute{ exprAttrs[idx] = netlink.Attribute{
Type: unix.NLA_F_NESTED | unix.NFTA_LIST_ELEM, Type: unix.NLA_F_NESTED | unix.NFTA_LIST_ELEM,
Data: cc.marshalExpr(expr), Data: cc.marshalExpr(byte(r.Table.Family), expr),
} }
} }
@ -215,7 +215,7 @@ func (cc *Conn) DelRule(r *Rule) error {
return nil return nil
} }
func exprsFromMsg(b []byte) ([]expr.Any, error) { func exprsFromMsg(fam TableFamily, b []byte) ([]expr.Any, error) {
ad, err := netlink.NewAttributeDecoder(b) ad, err := netlink.NewAttributeDecoder(b)
if err != nil { if err != nil {
return nil, err return nil, err
@ -285,7 +285,7 @@ func exprsFromMsg(b []byte) ([]expr.Any, error) {
} }
ad.Do(func(b []byte) error { ad.Do(func(b []byte) error {
if err := expr.Unmarshal(b, e); err != nil { if err := expr.Unmarshal(byte(fam), b, e); err != nil {
return err return err
} }
// Verdict expressions are a special-case of immediate expressions, so // Verdict expressions are a special-case of immediate expressions, so
@ -293,7 +293,7 @@ func exprsFromMsg(b []byte) ([]expr.Any, error) {
// register (invalid), re-parse it as a verdict expression. // register (invalid), re-parse it as a verdict expression.
if imm, isImmediate := e.(*expr.Immediate); isImmediate && imm.Register == unix.NFT_REG_VERDICT && len(imm.Data) == 0 { if imm, isImmediate := e.(*expr.Immediate); isImmediate && imm.Register == unix.NFT_REG_VERDICT && len(imm.Data) == 0 {
e = &expr.Verdict{} e = &expr.Verdict{}
if err := expr.Unmarshal(b, e); err != nil { if err := expr.Unmarshal(byte(fam), b, e); err != nil {
return err return err
} }
} }
@ -308,7 +308,7 @@ func exprsFromMsg(b []byte) ([]expr.Any, error) {
return exprs, ad.Err() return exprs, ad.Err()
} }
func ruleFromMsg(msg netlink.Message) (*Rule, error) { func ruleFromMsg(fam TableFamily, msg netlink.Message) (*Rule, error) {
if got, want := msg.Header.Type, ruleHeaderType; got != want { if got, want := msg.Header.Type, ruleHeaderType; got != want {
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
} }
@ -321,12 +321,15 @@ func ruleFromMsg(msg netlink.Message) (*Rule, error) {
for ad.Next() { for ad.Next() {
switch ad.Type() { switch ad.Type() {
case unix.NFTA_RULE_TABLE: case unix.NFTA_RULE_TABLE:
r.Table = &Table{Name: ad.String()} r.Table = &Table{
Name: ad.String(),
Family: fam,
}
case unix.NFTA_RULE_CHAIN: case unix.NFTA_RULE_CHAIN:
r.Chain = &Chain{Name: ad.String()} r.Chain = &Chain{Name: ad.String()}
case unix.NFTA_RULE_EXPRESSIONS: case unix.NFTA_RULE_EXPRESSIONS:
ad.Do(func(b []byte) error { ad.Do(func(b []byte) error {
r.Exprs, err = exprsFromMsg(b) r.Exprs, err = exprsFromMsg(fam, b)
return err return err
}) })
case unix.NFTA_RULE_POSITION: case unix.NFTA_RULE_POSITION:

View File

@ -28,12 +28,13 @@ type TableFamily byte
// Possible TableFamily values. // Possible TableFamily values.
const ( const (
TableFamilyINet TableFamily = unix.NFPROTO_INET TableFamilyUnspecified TableFamily = unix.NFPROTO_UNSPEC
TableFamilyIPv4 TableFamily = unix.NFPROTO_IPV4 TableFamilyINet TableFamily = unix.NFPROTO_INET
TableFamilyIPv6 TableFamily = unix.NFPROTO_IPV6 TableFamilyIPv4 TableFamily = unix.NFPROTO_IPV4
TableFamilyARP TableFamily = unix.NFPROTO_ARP TableFamilyIPv6 TableFamily = unix.NFPROTO_IPV6
TableFamilyNetdev TableFamily = unix.NFPROTO_NETDEV TableFamilyARP TableFamily = unix.NFPROTO_ARP
TableFamilyBridge TableFamily = unix.NFPROTO_BRIDGE TableFamilyNetdev TableFamily = unix.NFPROTO_NETDEV
TableFamilyBridge TableFamily = unix.NFPROTO_BRIDGE
) )
// A Table contains Chains. See also // A Table contains Chains. See also
@ -111,7 +112,7 @@ func (cc *Conn) ListTables() ([]*Table, error) {
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETTABLE), Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETTABLE),
Flags: netlink.Request | netlink.Dump, Flags: netlink.Request | netlink.Dump,
}, },
Data: extraHeader(uint8(unix.AF_UNSPEC), 0), Data: extraHeader(uint8(TableFamilyUnspecified), 0),
} }
response, err := conn.Execute(msg) response, err := conn.Execute(msg)

94
xt/info.go Normal file
View File

@ -0,0 +1,94 @@
package xt
import (
"golang.org/x/sys/unix"
)
// TableFamily specifies the address family of the table Match or Target Info
// data is contained in. On purpose, we don't import the expr package here in
// order to keep the option open to import this package instead into expr.
type TableFamily byte
// InfoAny is a (un)marshaling implemented by any info type.
type InfoAny interface {
marshal(fam TableFamily, rev uint32) ([]byte, error)
unmarshal(fam TableFamily, rev uint32, data []byte) error
}
// Marshal a Match or Target Info type into its binary representation.
func Marshal(fam TableFamily, rev uint32, info InfoAny) ([]byte, error) {
return info.marshal(fam, rev)
}
// Unmarshal Info binary payload into its corresponding dedicated type as
// indicated by the name argument. In several cases, unmarshalling depends on
// the specific table family the Target or Match expression with the info
// payload belongs to, as well as the specific info structure revision.
func Unmarshal(name string, fam TableFamily, rev uint32, data []byte) (InfoAny, error) {
var i InfoAny
switch name {
case "addrtype":
switch rev {
case 0:
i = &AddrType{}
case 1:
i = &AddrTypeV1{}
}
case "conntrack":
switch rev {
case 1:
i = &ConntrackMtinfo1{}
case 2:
i = &ConntrackMtinfo2{}
case 3:
i = &ConntrackMtinfo3{}
}
case "tcp":
i = &Tcp{}
case "udp":
i = &Udp{}
case "SNAT":
if fam == unix.NFPROTO_IPV4 {
i = &NatIPv4MultiRangeCompat{}
}
case "DNAT":
switch fam {
case unix.NFPROTO_IPV4:
if rev == 0 {
i = &NatIPv4MultiRangeCompat{}
break
}
fallthrough
case unix.NFPROTO_IPV6:
switch rev {
case 1:
i = &NatRange{}
case 2:
i = &NatRange2{}
}
}
case "MASQUERADE":
switch fam {
case unix.NFPROTO_IPV4:
i = &NatIPv4MultiRangeCompat{}
}
case "REDIRECT":
switch fam {
case unix.NFPROTO_IPV4:
if rev == 0 {
i = &NatIPv4MultiRangeCompat{}
break
}
fallthrough
case unix.NFPROTO_IPV6:
i = &NatRange{}
}
}
if i == nil {
i = &Unknown{}
}
if err := i.unmarshal(fam, rev, data); err != nil {
return nil, err
}
return i, nil
}

89
xt/match_addrtype.go Normal file
View File

@ -0,0 +1,89 @@
package xt
import (
"github.com/google/nftables/alignedbuff"
)
// Rev. 0, see https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_addrtype.h#L38
type AddrType struct {
Source uint16
Dest uint16
InvertSource bool
InvertDest bool
}
type AddrTypeFlags uint32
const (
AddrTypeUnspec AddrTypeFlags = 1 << iota
AddrTypeUnicast
AddrTypeLocal
AddrTypeBroadcast
AddrTypeAnycast
AddrTypeMulticast
AddrTypeBlackhole
AddrTypeUnreachable
AddrTypeProhibit
AddrTypeThrow
AddrTypeNat
AddrTypeXresolve
)
// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_addrtype.h#L31
type AddrTypeV1 struct {
Source uint16
Dest uint16
Flags AddrTypeFlags
}
func (x *AddrType) marshal(fam TableFamily, rev uint32) ([]byte, error) {
ab := alignedbuff.New()
ab.PutUint16(x.Source)
ab.PutUint16(x.Dest)
putBool32(&ab, x.InvertSource)
putBool32(&ab, x.InvertDest)
return ab.Data(), nil
}
func (x *AddrType) unmarshal(fam TableFamily, rev uint32, data []byte) error {
ab := alignedbuff.NewWithData(data)
var err error
if x.Source, err = ab.Uint16(); err != nil {
return nil
}
if x.Dest, err = ab.Uint16(); err != nil {
return nil
}
if x.InvertSource, err = bool32(&ab); err != nil {
return nil
}
if x.InvertDest, err = bool32(&ab); err != nil {
return nil
}
return nil
}
func (x *AddrTypeV1) marshal(fam TableFamily, rev uint32) ([]byte, error) {
ab := alignedbuff.New()
ab.PutUint16(x.Source)
ab.PutUint16(x.Dest)
ab.PutUint32(uint32(x.Flags))
return ab.Data(), nil
}
func (x *AddrTypeV1) unmarshal(fam TableFamily, rev uint32, data []byte) error {
ab := alignedbuff.NewWithData(data)
var err error
if x.Source, err = ab.Uint16(); err != nil {
return nil
}
if x.Dest, err = ab.Uint16(); err != nil {
return nil
}
var flags uint32
if flags, err = ab.Uint32(); err != nil {
return nil
}
x.Flags = AddrTypeFlags(flags)
return nil
}

71
xt/match_addrtype_test.go Normal file
View File

@ -0,0 +1,71 @@
package xt
import (
"reflect"
"testing"
)
func TestTargetAddrType(t *testing.T) {
t.Parallel()
tests := []struct {
name string
fam byte
rev uint32
info InfoAny
empty InfoAny
}{
{
name: "un/marshal AddrType Rev 0 round-trip",
fam: 0,
rev: 0,
info: &AddrType{
Source: 0x1234,
Dest: 0x5678,
InvertSource: true,
InvertDest: false,
},
empty: &AddrType{},
},
{
name: "un/marshal AddrType Rev 0 round-trip",
fam: 0,
rev: 0,
info: &AddrType{
Source: 0x1234,
Dest: 0x5678,
InvertSource: false,
InvertDest: true,
},
empty: &AddrType{},
},
{
name: "un/marshal AddrType Rev 1 round-trip",
fam: 0,
rev: 0,
info: &AddrTypeV1{
Source: 0x1234,
Dest: 0x5678,
Flags: 0xb00f,
},
empty: &AddrTypeV1{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := tt.info.marshal(TableFamily(tt.fam), tt.rev)
if err != nil {
t.Fatalf("marshal error: %+v", err)
}
var recoveredInfo InfoAny = tt.empty
err = recoveredInfo.unmarshal(TableFamily(tt.fam), tt.rev, data)
if err != nil {
t.Fatalf("unmarshal error: %+v", err)
}
if !reflect.DeepEqual(tt.info, recoveredInfo) {
t.Fatalf("original %+v and recovered %+v are different", tt.info, recoveredInfo)
}
})
}
}

250
xt/match_conntrack.go Normal file
View File

@ -0,0 +1,250 @@
package xt
import (
"net"
"github.com/google/nftables/alignedbuff"
)
type ConntrackFlags uint16
const (
ConntrackState ConntrackFlags = 1 << iota
ConntrackProto
ConntrackOrigSrc
ConntrackOrigDst
ConntrackReplSrc
ConntrackReplDst
ConntrackStatus
ConntrackExpires
ConntrackOrigSrcPort
ConntrackOrigDstPort
ConntrackReplSrcPort
ConntrackReplDstPrt
ConntrackDirection
ConntrackStateAlias
)
type ConntrackMtinfoBase struct {
OrigSrcAddr net.IP
OrigSrcMask net.IPMask
OrigDstAddr net.IP
OrigDstMask net.IPMask
ReplSrcAddr net.IP
ReplSrcMask net.IPMask
ReplDstAddr net.IP
ReplDstMask net.IPMask
ExpiresMin uint32
ExpiresMax uint32
L4Proto uint16
OrigSrcPort uint16
OrigDstPort uint16
ReplSrcPort uint16
ReplDstPort uint16
}
// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_conntrack.h#L38
type ConntrackMtinfo1 struct {
ConntrackMtinfoBase
StateMask uint8
StatusMask uint8
}
// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_conntrack.h#L51
type ConntrackMtinfo2 struct {
ConntrackMtinfoBase
StateMask uint16
StatusMask uint16
}
// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_conntrack.h#L64
type ConntrackMtinfo3 struct {
ConntrackMtinfo2
OrigSrcPortHigh uint16
OrigDstPortHigh uint16
ReplSrcPortHigh uint16
ReplDstPortHigh uint16
}
func (x *ConntrackMtinfoBase) marshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error {
if err := putIPv46(ab, fam, x.OrigSrcAddr); err != nil {
return err
}
if err := putIPv46Mask(ab, fam, x.OrigSrcMask); err != nil {
return err
}
if err := putIPv46(ab, fam, x.OrigDstAddr); err != nil {
return err
}
if err := putIPv46Mask(ab, fam, x.OrigDstMask); err != nil {
return err
}
if err := putIPv46(ab, fam, x.ReplSrcAddr); err != nil {
return err
}
if err := putIPv46Mask(ab, fam, x.ReplSrcMask); err != nil {
return err
}
if err := putIPv46(ab, fam, x.ReplDstAddr); err != nil {
return err
}
if err := putIPv46Mask(ab, fam, x.ReplDstMask); err != nil {
return err
}
ab.PutUint32(x.ExpiresMin)
ab.PutUint32(x.ExpiresMax)
ab.PutUint16(x.L4Proto)
ab.PutUint16(x.OrigSrcPort)
ab.PutUint16(x.OrigDstPort)
ab.PutUint16(x.ReplSrcPort)
ab.PutUint16(x.ReplDstPort)
return nil
}
func (x *ConntrackMtinfoBase) unmarshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error {
var err error
if x.OrigSrcAddr, err = iPv46(ab, fam); err != nil {
return err
}
if x.OrigSrcMask, err = iPv46Mask(ab, fam); err != nil {
return err
}
if x.OrigDstAddr, err = iPv46(ab, fam); err != nil {
return err
}
if x.OrigDstMask, err = iPv46Mask(ab, fam); err != nil {
return err
}
if x.ReplSrcAddr, err = iPv46(ab, fam); err != nil {
return err
}
if x.ReplSrcMask, err = iPv46Mask(ab, fam); err != nil {
return err
}
if x.ReplDstAddr, err = iPv46(ab, fam); err != nil {
return err
}
if x.ReplDstMask, err = iPv46Mask(ab, fam); err != nil {
return err
}
if x.ExpiresMin, err = ab.Uint32(); err != nil {
return err
}
if x.ExpiresMax, err = ab.Uint32(); err != nil {
return err
}
if x.L4Proto, err = ab.Uint16(); err != nil {
return err
}
if x.OrigSrcPort, err = ab.Uint16(); err != nil {
return err
}
if x.OrigDstPort, err = ab.Uint16(); err != nil {
return err
}
if x.ReplSrcPort, err = ab.Uint16(); err != nil {
return err
}
if x.ReplDstPort, err = ab.Uint16(); err != nil {
return err
}
return nil
}
func (x *ConntrackMtinfo1) marshal(fam TableFamily, rev uint32) ([]byte, error) {
ab := alignedbuff.New()
if err := x.ConntrackMtinfoBase.marshalAB(fam, rev, &ab); err != nil {
return nil, err
}
ab.PutUint8(x.StateMask)
ab.PutUint8(x.StatusMask)
return ab.Data(), nil
}
func (x *ConntrackMtinfo1) unmarshal(fam TableFamily, rev uint32, data []byte) error {
ab := alignedbuff.NewWithData(data)
var err error
if err = x.ConntrackMtinfoBase.unmarshalAB(fam, rev, &ab); err != nil {
return err
}
if x.StateMask, err = ab.Uint8(); err != nil {
return err
}
if x.StatusMask, err = ab.Uint8(); err != nil {
return err
}
return nil
}
func (x *ConntrackMtinfo2) marshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error {
if err := x.ConntrackMtinfoBase.marshalAB(fam, rev, ab); err != nil {
return err
}
ab.PutUint16(x.StateMask)
ab.PutUint16(x.StatusMask)
return nil
}
func (x *ConntrackMtinfo2) marshal(fam TableFamily, rev uint32) ([]byte, error) {
ab := alignedbuff.New()
if err := x.marshalAB(fam, rev, &ab); err != nil {
return nil, err
}
return ab.Data(), nil
}
func (x *ConntrackMtinfo2) unmarshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error {
var err error
if err = x.ConntrackMtinfoBase.unmarshalAB(fam, rev, ab); err != nil {
return err
}
if x.StateMask, err = ab.Uint16(); err != nil {
return err
}
if x.StatusMask, err = ab.Uint16(); err != nil {
return err
}
return nil
}
func (x *ConntrackMtinfo2) unmarshal(fam TableFamily, rev uint32, data []byte) error {
ab := alignedbuff.NewWithData(data)
var err error
if err = x.unmarshalAB(fam, rev, &ab); err != nil {
return err
}
return nil
}
func (x *ConntrackMtinfo3) marshal(fam TableFamily, rev uint32) ([]byte, error) {
ab := alignedbuff.New()
if err := x.ConntrackMtinfo2.marshalAB(fam, rev, &ab); err != nil {
return nil, err
}
ab.PutUint16(x.OrigSrcPortHigh)
ab.PutUint16(x.OrigDstPortHigh)
ab.PutUint16(x.ReplSrcPortHigh)
ab.PutUint16(x.ReplDstPortHigh)
return ab.Data(), nil
}
func (x *ConntrackMtinfo3) unmarshal(fam TableFamily, rev uint32, data []byte) error {
ab := alignedbuff.NewWithData(data)
var err error
if err = x.ConntrackMtinfo2.unmarshalAB(fam, rev, &ab); err != nil {
return err
}
if x.OrigSrcPortHigh, err = ab.Uint16(); err != nil {
return err
}
if x.OrigDstPortHigh, err = ab.Uint16(); err != nil {
return err
}
if x.ReplSrcPortHigh, err = ab.Uint16(); err != nil {
return err
}
if x.ReplDstPortHigh, err = ab.Uint16(); err != nil {
return err
}
return nil
}

213
xt/match_conntrack_test.go Normal file
View File

@ -0,0 +1,213 @@
package xt
import (
"net"
"reflect"
"testing"
"golang.org/x/sys/unix"
)
func TestMatchConntrack(t *testing.T) {
t.Parallel()
tests := []struct {
name string
fam byte
rev uint32
info InfoAny
empty InfoAny
}{
{
name: "un/marshal ConntrackMtinfo1 IPv4 round-trip",
fam: unix.NFPROTO_IPV4,
rev: 0,
info: &ConntrackMtinfo1{
ConntrackMtinfoBase: ConntrackMtinfoBase{
OrigSrcAddr: net.ParseIP("1.2.3.4").To4(),
OrigSrcMask: net.IPv4Mask(0x12, 0x23, 0x34, 0x45), // only for test ;)
OrigDstAddr: net.ParseIP("2.3.4.5").To4(),
OrigDstMask: net.IPv4Mask(0x23, 0x34, 0x45, 0x56), // only for test ;)
ReplSrcAddr: net.ParseIP("10.20.30.40").To4(),
ReplSrcMask: net.IPv4Mask(0xf2, 0xe3, 0xd4, 0xc5), // only for test ;)
ReplDstAddr: net.ParseIP("2.3.4.5").To4(),
ReplDstMask: net.IPv4Mask(0xe3, 0xd4, 0xc5, 0xb6), // only for test ;)
ExpiresMin: 0x1234,
ExpiresMax: 0x2345,
L4Proto: 0xaa55,
OrigSrcPort: 123,
OrigDstPort: 321,
ReplSrcPort: 789,
ReplDstPort: 987,
},
StateMask: 0x55,
StatusMask: 0xaa,
},
empty: &ConntrackMtinfo1{},
},
{
name: "un/marshal ConntrackMtinfo1 IPv6 round-trip",
fam: unix.NFPROTO_IPV6,
rev: 0,
info: &ConntrackMtinfo1{
ConntrackMtinfoBase: ConntrackMtinfoBase{
OrigSrcAddr: net.ParseIP("fe80::dead:f001"),
OrigSrcMask: net.IPMask{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, // only for test ;)
OrigDstAddr: net.ParseIP("fd00::dead:f001"),
OrigDstMask: net.IPMask{0x11, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, // only for test ;)
ReplSrcAddr: net.ParseIP("fe80::c01d:cafe"),
ReplSrcMask: net.IPMask{0x21, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, // only for test ;)
ReplDstAddr: net.ParseIP("fd00::c01d:cafe"),
ReplDstMask: net.IPMask{0x31, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, // only for test ;)
ExpiresMin: 0x1234,
ExpiresMax: 0x2345,
L4Proto: 0xaa55,
OrigSrcPort: 123,
OrigDstPort: 321,
ReplSrcPort: 789,
ReplDstPort: 987,
},
StateMask: 0x55,
StatusMask: 0xaa,
},
empty: &ConntrackMtinfo1{},
},
{
name: "un/marshal ConntrackMtinfo2 IPv4 round-trip",
fam: unix.NFPROTO_IPV4,
rev: 0,
info: &ConntrackMtinfo2{
ConntrackMtinfoBase: ConntrackMtinfoBase{
OrigSrcAddr: net.ParseIP("1.2.3.4").To4(),
OrigSrcMask: net.IPv4Mask(0x12, 0x23, 0x34, 0x45), // only for test ;)
OrigDstAddr: net.ParseIP("2.3.4.5").To4(),
OrigDstMask: net.IPv4Mask(0x23, 0x34, 0x45, 0x56), // only for test ;)
ReplSrcAddr: net.ParseIP("10.20.30.40").To4(),
ReplSrcMask: net.IPv4Mask(0xf2, 0xe3, 0xd4, 0xc5), // only for test ;)
ReplDstAddr: net.ParseIP("2.3.4.5").To4(),
ReplDstMask: net.IPv4Mask(0xe3, 0xd4, 0xc5, 0xb6), // only for test ;)
ExpiresMin: 0x1234,
ExpiresMax: 0x2345,
L4Proto: 0xaa55,
OrigSrcPort: 123,
OrigDstPort: 321,
ReplSrcPort: 789,
ReplDstPort: 987,
},
StateMask: 0x55aa,
StatusMask: 0xaa55,
},
empty: &ConntrackMtinfo2{},
},
{
name: "un/marshal ConntrackMtinfo1 IPv6 round-trip",
fam: unix.NFPROTO_IPV6,
rev: 0,
info: &ConntrackMtinfo2{
ConntrackMtinfoBase: ConntrackMtinfoBase{
OrigSrcAddr: net.ParseIP("fe80::dead:f001"),
OrigSrcMask: net.IPMask{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, // only for test ;)
OrigDstAddr: net.ParseIP("fd00::dead:f001"),
OrigDstMask: net.IPMask{0x11, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, // only for test ;)
ReplSrcAddr: net.ParseIP("fe80::c01d:cafe"),
ReplSrcMask: net.IPMask{0x21, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, // only for test ;)
ReplDstAddr: net.ParseIP("fd00::c01d:cafe"),
ReplDstMask: net.IPMask{0x31, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, // only for test ;)
ExpiresMin: 0x1234,
ExpiresMax: 0x2345,
L4Proto: 0xaa55,
OrigSrcPort: 123,
OrigDstPort: 321,
ReplSrcPort: 789,
ReplDstPort: 987,
},
StateMask: 0x55aa,
StatusMask: 0xaa55,
},
empty: &ConntrackMtinfo2{},
},
{
name: "un/marshal ConntrackMtinfo3 IPv4 round-trip",
fam: unix.NFPROTO_IPV4,
rev: 0,
info: &ConntrackMtinfo3{
ConntrackMtinfo2: ConntrackMtinfo2{
ConntrackMtinfoBase: ConntrackMtinfoBase{
OrigSrcAddr: net.ParseIP("1.2.3.4").To4(),
OrigSrcMask: net.IPv4Mask(0x12, 0x23, 0x34, 0x45), // only for test ;)
OrigDstAddr: net.ParseIP("2.3.4.5").To4(),
OrigDstMask: net.IPv4Mask(0x23, 0x34, 0x45, 0x56), // only for test ;)
ReplSrcAddr: net.ParseIP("10.20.30.40").To4(),
ReplSrcMask: net.IPv4Mask(0xf2, 0xe3, 0xd4, 0xc5), // only for test ;)
ReplDstAddr: net.ParseIP("2.3.4.5").To4(),
ReplDstMask: net.IPv4Mask(0xe3, 0xd4, 0xc5, 0xb6), // only for test ;)
ExpiresMin: 0x1234,
ExpiresMax: 0x2345,
L4Proto: 0xaa55,
OrigSrcPort: 123,
OrigDstPort: 321,
ReplSrcPort: 789,
ReplDstPort: 987,
},
StateMask: 0x55aa,
StatusMask: 0xaa55,
},
OrigSrcPortHigh: 0xabcd,
OrigDstPortHigh: 0xcdba,
ReplSrcPortHigh: 0x1234,
ReplDstPortHigh: 0x4321,
},
empty: &ConntrackMtinfo3{},
},
{
name: "un/marshal ConntrackMtinfo1 IPv6 round-trip",
fam: unix.NFPROTO_IPV6,
rev: 0,
info: &ConntrackMtinfo3{
ConntrackMtinfo2: ConntrackMtinfo2{
ConntrackMtinfoBase: ConntrackMtinfoBase{
OrigSrcAddr: net.ParseIP("fe80::dead:f001"),
OrigSrcMask: net.IPMask{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, // only for test ;)
OrigDstAddr: net.ParseIP("fd00::dead:f001"),
OrigDstMask: net.IPMask{0x11, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, // only for test ;)
ReplSrcAddr: net.ParseIP("fe80::c01d:cafe"),
ReplSrcMask: net.IPMask{0x21, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, // only for test ;)
ReplDstAddr: net.ParseIP("fd00::c01d:cafe"),
ReplDstMask: net.IPMask{0x31, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, // only for test ;)
ExpiresMin: 0x1234,
ExpiresMax: 0x2345,
L4Proto: 0xaa55,
OrigSrcPort: 123,
OrigDstPort: 321,
ReplSrcPort: 789,
ReplDstPort: 987,
},
StateMask: 0x55aa,
StatusMask: 0xaa55,
},
OrigSrcPortHigh: 0xabcd,
OrigDstPortHigh: 0xcdba,
ReplSrcPortHigh: 0x1234,
ReplDstPortHigh: 0x4321,
},
empty: &ConntrackMtinfo3{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := tt.info.marshal(TableFamily(tt.fam), tt.rev)
if err != nil {
t.Fatalf("marshal error: %+v", err)
}
var recoveredInfo InfoAny = tt.empty
err = recoveredInfo.unmarshal(TableFamily(tt.fam), tt.rev, data)
if err != nil {
t.Fatalf("unmarshal error: %+v", err)
}
if !reflect.DeepEqual(tt.info, recoveredInfo) {
t.Fatalf("original %+v and recovered %+v are different", tt.info, recoveredInfo)
}
})
}
}

74
xt/match_tcp.go Normal file
View File

@ -0,0 +1,74 @@
package xt
import (
"github.com/google/nftables/alignedbuff"
)
// Tcp is the Match.Info payload for the tcp xtables extension
// (https://wiki.nftables.org/wiki-nftables/index.php/Supported_features_compared_to_xtables#tcp).
//
// See
// https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_tcpudp.h#L8
type Tcp struct {
SrcPorts [2]uint16 // min, max source port range
DstPorts [2]uint16 // min, max destination port range
Option uint8 // TCP option if non-zero
FlagsMask uint8 // TCP flags mask
FlagsCmp uint8 // TCP flags compare
InvFlags TcpInvFlagset // Inverse flags
}
type TcpInvFlagset uint8
const (
TcpInvSrcPorts TcpInvFlagset = 1 << iota
TcpInvDestPorts
TcpInvFlags
TcpInvOption
TcpInvMask TcpInvFlagset = (1 << iota) - 1
)
func (x *Tcp) marshal(fam TableFamily, rev uint32) ([]byte, error) {
ab := alignedbuff.New()
ab.PutUint16(x.SrcPorts[0])
ab.PutUint16(x.SrcPorts[1])
ab.PutUint16(x.DstPorts[0])
ab.PutUint16(x.DstPorts[1])
ab.PutUint8(x.Option)
ab.PutUint8(x.FlagsMask)
ab.PutUint8(x.FlagsCmp)
ab.PutUint8(byte(x.InvFlags))
return ab.Data(), nil
}
func (x *Tcp) unmarshal(fam TableFamily, rev uint32, data []byte) error {
ab := alignedbuff.NewWithData(data)
var err error
if x.SrcPorts[0], err = ab.Uint16(); err != nil {
return err
}
if x.SrcPorts[1], err = ab.Uint16(); err != nil {
return err
}
if x.DstPorts[0], err = ab.Uint16(); err != nil {
return err
}
if x.DstPorts[1], err = ab.Uint16(); err != nil {
return err
}
if x.Option, err = ab.Uint8(); err != nil {
return err
}
if x.FlagsMask, err = ab.Uint8(); err != nil {
return err
}
if x.FlagsCmp, err = ab.Uint8(); err != nil {
return err
}
var invFlags uint8
if invFlags, err = ab.Uint8(); err != nil {
return err
}
x.InvFlags = TcpInvFlagset(invFlags)
return nil
}

44
xt/match_tcp_test.go Normal file
View File

@ -0,0 +1,44 @@
package xt
import (
"reflect"
"testing"
)
func TestMatchTcp(t *testing.T) {
t.Parallel()
tests := []struct {
name string
info InfoAny
}{
{
name: "un/marshal Tcp round-trip",
info: &Tcp{
SrcPorts: [2]uint16{0x1234, 0x5678},
DstPorts: [2]uint16{0x2345, 0x6789},
Option: 0x12,
FlagsMask: 0x34,
FlagsCmp: 0x56,
InvFlags: 0x78,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := tt.info.marshal(0, 0)
if err != nil {
t.Fatalf("marshal error: %+v", err)
}
var recoveredInfo InfoAny = &Tcp{}
err = recoveredInfo.unmarshal(0, 0, data)
if err != nil {
t.Fatalf("unmarshal error: %+v", err)
}
if !reflect.DeepEqual(tt.info, recoveredInfo) {
t.Fatalf("original %+v and recovered %+v are different", tt.info, recoveredInfo)
}
})
}
}

57
xt/match_udp.go Normal file
View File

@ -0,0 +1,57 @@
package xt
import (
"github.com/google/nftables/alignedbuff"
)
// Tcp is the Match.Info payload for the tcp xtables extension
// (https://wiki.nftables.org/wiki-nftables/index.php/Supported_features_compared_to_xtables#tcp).
//
// See
// https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_tcpudp.h#L25
type Udp struct {
SrcPorts [2]uint16 // min, max source port range
DstPorts [2]uint16 // min, max destination port range
InvFlags UdpInvFlagset // Inverse flags
}
type UdpInvFlagset uint8
const (
UdpInvSrcPorts UdpInvFlagset = 1 << iota
UdpInvDestPorts
UdpInvMask UdpInvFlagset = (1 << iota) - 1
)
func (x *Udp) marshal(fam TableFamily, rev uint32) ([]byte, error) {
ab := alignedbuff.New()
ab.PutUint16(x.SrcPorts[0])
ab.PutUint16(x.SrcPorts[1])
ab.PutUint16(x.DstPorts[0])
ab.PutUint16(x.DstPorts[1])
ab.PutUint8(byte(x.InvFlags))
return ab.Data(), nil
}
func (x *Udp) unmarshal(fam TableFamily, rev uint32, data []byte) error {
ab := alignedbuff.NewWithData(data)
var err error
if x.SrcPorts[0], err = ab.Uint16(); err != nil {
return err
}
if x.SrcPorts[1], err = ab.Uint16(); err != nil {
return err
}
if x.DstPorts[0], err = ab.Uint16(); err != nil {
return err
}
if x.DstPorts[1], err = ab.Uint16(); err != nil {
return err
}
var invFlags uint8
if invFlags, err = ab.Uint8(); err != nil {
return err
}
x.InvFlags = UdpInvFlagset(invFlags)
return nil
}

41
xt/match_udp_test.go Normal file
View File

@ -0,0 +1,41 @@
package xt
import (
"reflect"
"testing"
)
func TestMatchUdp(t *testing.T) {
t.Parallel()
tests := []struct {
name string
info InfoAny
}{
{
name: "un/marshal Udp round-trip",
info: &Udp{
SrcPorts: [2]uint16{0x1234, 0x5678},
DstPorts: [2]uint16{0x2345, 0x6789},
InvFlags: 0x78,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := tt.info.marshal(0, 0)
if err != nil {
t.Fatalf("marshal error: %+v", err)
}
var recoveredInfo InfoAny = &Udp{}
err = recoveredInfo.unmarshal(0, 0, data)
if err != nil {
t.Fatalf("unmarshal error: %+v", err)
}
if !reflect.DeepEqual(tt.info, recoveredInfo) {
t.Fatalf("original %+v and recovered %+v are different", tt.info, recoveredInfo)
}
})
}
}

106
xt/target_dnat.go Normal file
View File

@ -0,0 +1,106 @@
package xt
import (
"net"
"github.com/google/nftables/alignedbuff"
)
type NatRangeFlags uint
// See: https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L8
const (
NatRangeMapIPs NatRangeFlags = (1 << iota)
NatRangeProtoSpecified
NatRangeProtoRandom
NatRangePersistent
NatRangeProtoRandomFully
NatRangeProtoOffset
NatRangeNetmap
NatRangeMask NatRangeFlags = (1 << iota) - 1
NatRangeProtoRandomAll = NatRangeProtoRandom | NatRangeProtoRandomFully
)
// see: https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L38
type NatRange struct {
Flags uint // sic! platform/arch/compiler-dependent uint size
MinIP net.IP // always taking up space for an IPv6 address
MaxIP net.IP // dito
MinPort uint16
MaxPort uint16
}
// see: https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L46
type NatRange2 struct {
NatRange
BasePort uint16
}
func (x *NatRange) marshal(fam TableFamily, rev uint32) ([]byte, error) {
ab := alignedbuff.New()
if err := x.marshalAB(fam, rev, &ab); err != nil {
return nil, err
}
return ab.Data(), nil
}
func (x *NatRange) marshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error {
ab.PutUint(x.Flags)
if err := putIPv46(ab, fam, x.MinIP); err != nil {
return err
}
if err := putIPv46(ab, fam, x.MaxIP); err != nil {
return err
}
ab.PutUint16BE(x.MinPort)
ab.PutUint16BE(x.MaxPort)
return nil
}
func (x *NatRange) unmarshal(fam TableFamily, rev uint32, data []byte) error {
ab := alignedbuff.NewWithData(data)
return x.unmarshalAB(fam, rev, &ab)
}
func (x *NatRange) unmarshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error {
var err error
if x.Flags, err = ab.Uint(); err != nil {
return err
}
if x.MinIP, err = iPv46(ab, fam); err != nil {
return err
}
if x.MaxIP, err = iPv46(ab, fam); err != nil {
return err
}
if x.MinPort, err = ab.Uint16BE(); err != nil {
return err
}
if x.MaxPort, err = ab.Uint16BE(); err != nil {
return err
}
return nil
}
func (x *NatRange2) marshal(fam TableFamily, rev uint32) ([]byte, error) {
ab := alignedbuff.New()
if err := x.NatRange.marshalAB(fam, rev, &ab); err != nil {
return nil, err
}
ab.PutUint16BE(x.BasePort)
return ab.Data(), nil
}
func (x *NatRange2) unmarshal(fam TableFamily, rev uint32, data []byte) error {
ab := alignedbuff.NewWithData(data)
var err error
if err = x.NatRange.unmarshalAB(fam, rev, &ab); err != nil {
return err
}
if x.BasePort, err = ab.Uint16BE(); err != nil {
return err
}
return nil
}

97
xt/target_dnat_test.go Normal file
View File

@ -0,0 +1,97 @@
package xt
import (
"net"
"reflect"
"testing"
"golang.org/x/sys/unix"
)
func TestTargetDNAT(t *testing.T) {
t.Parallel()
tests := []struct {
name string
fam byte
rev uint32
info InfoAny
empty InfoAny
}{
{
name: "un/marshal NatRange IPv4 round-trip",
fam: unix.NFPROTO_IPV4,
rev: 0,
info: &NatRange{
Flags: 0x1234,
MinIP: net.ParseIP("12.23.34.45").To4(),
MaxIP: net.ParseIP("21.32.43.54").To4(),
MinPort: 0x5678,
MaxPort: 0xabcd,
},
empty: &NatRange{},
},
{
name: "un/marshal NatRange IPv6 round-trip",
fam: unix.NFPROTO_IPV6,
rev: 0,
info: &NatRange{
Flags: 0x1234,
MinIP: net.ParseIP("fe80::dead:beef"),
MaxIP: net.ParseIP("fe80::c001:cafe"),
MinPort: 0x5678,
MaxPort: 0xabcd,
},
empty: &NatRange{},
},
{
name: "un/marshal NatRange2 IPv4 round-trip",
fam: unix.NFPROTO_IPV4,
rev: 0,
info: &NatRange2{
NatRange: NatRange{
Flags: 0x1234,
MinIP: net.ParseIP("12.23.34.45").To4(),
MaxIP: net.ParseIP("21.32.43.54").To4(),
MinPort: 0x5678,
MaxPort: 0xabcd,
},
BasePort: 0xfedc,
},
empty: &NatRange2{},
},
{
name: "un/marshal NatRange2 IPv6 round-trip",
fam: unix.NFPROTO_IPV6,
rev: 0,
info: &NatRange2{
NatRange: NatRange{
Flags: 0x1234,
MinIP: net.ParseIP("fe80::dead:beef"),
MaxIP: net.ParseIP("fe80::c001:cafe"),
MinPort: 0x5678,
MaxPort: 0xabcd,
},
BasePort: 0xfedc,
},
empty: &NatRange2{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := tt.info.marshal(TableFamily(tt.fam), tt.rev)
if err != nil {
t.Fatalf("marshal error: %+v", err)
}
var recoveredInfo InfoAny = tt.empty
err = recoveredInfo.unmarshal(TableFamily(tt.fam), tt.rev, data)
if err != nil {
t.Fatalf("unmarshal error: %+v", err)
}
if !reflect.DeepEqual(tt.info, recoveredInfo) {
t.Fatalf("original %+v and recovered %+v are different", tt.info, recoveredInfo)
}
})
}
}

View File

@ -0,0 +1,86 @@
package xt
import (
"errors"
"net"
"github.com/google/nftables/alignedbuff"
)
// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L25
type NatIPv4Range struct {
Flags uint // sic!
MinIP net.IP
MaxIP net.IP
MinPort uint16
MaxPort uint16
}
// NatIPv4MultiRangeCompat despite being a slice of NAT IPv4 ranges is currently allowed to
// only hold exactly one element.
//
// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L33
type NatIPv4MultiRangeCompat []NatIPv4Range
func (x *NatIPv4MultiRangeCompat) marshal(fam TableFamily, rev uint32) ([]byte, error) {
ab := alignedbuff.New()
if len(*x) != 1 {
return nil, errors.New("MasqueradeIp must contain exactly one NatIPv4Range")
}
ab.PutUint(uint(len(*x)))
for _, nat := range *x {
if err := nat.marshalAB(fam, rev, &ab); err != nil {
return nil, err
}
}
return ab.Data(), nil
}
func (x *NatIPv4MultiRangeCompat) unmarshal(fam TableFamily, rev uint32, data []byte) error {
ab := alignedbuff.NewWithData(data)
l, err := ab.Uint()
if err != nil {
return err
}
nats := make(NatIPv4MultiRangeCompat, l)
for l > 0 {
l--
if err := nats[l].unmarshalAB(fam, rev, &ab); err != nil {
return err
}
}
*x = nats
return nil
}
func (x *NatIPv4Range) marshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error {
ab.PutUint(x.Flags)
ab.PutBytesAligned32(x.MinIP.To4(), 4)
ab.PutBytesAligned32(x.MaxIP.To4(), 4)
ab.PutUint16BE(x.MinPort)
ab.PutUint16BE(x.MaxPort)
return nil
}
func (x *NatIPv4Range) unmarshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error {
var err error
if x.Flags, err = ab.Uint(); err != nil {
return err
}
var ip []byte
if ip, err = ab.BytesAligned32(4); err != nil {
return err
}
x.MinIP = net.IP(ip)
if ip, err = ab.BytesAligned32(4); err != nil {
return err
}
x.MaxIP = net.IP(ip)
if x.MinPort, err = ab.Uint16BE(); err != nil {
return err
}
if x.MaxPort, err = ab.Uint16BE(); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,54 @@
package xt
import (
"net"
"reflect"
"testing"
"golang.org/x/sys/unix"
)
func TestTargetMasqueradeIP(t *testing.T) {
t.Parallel()
tests := []struct {
name string
fam byte
rev uint32
info InfoAny
empty InfoAny
}{
{
name: "un/marshal NatIPv4Range round-trip",
fam: unix.NFPROTO_IPV4,
rev: 0,
info: &NatIPv4MultiRangeCompat{
NatIPv4Range{
Flags: 0x1234,
MinIP: net.ParseIP("12.23.34.45").To4(),
MaxIP: net.ParseIP("21.32.43.54").To4(),
MinPort: 0x5678,
MaxPort: 0xabcd,
},
},
empty: new(NatIPv4MultiRangeCompat),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := tt.info.marshal(TableFamily(tt.fam), tt.rev)
if err != nil {
t.Fatalf("marshal error: %+v", err)
}
var recoveredInfo InfoAny = tt.empty
err = recoveredInfo.unmarshal(TableFamily(tt.fam), tt.rev, data)
if err != nil {
t.Fatalf("unmarshal error: %+v", err)
}
if !reflect.DeepEqual(tt.info, recoveredInfo) {
t.Fatalf("original %+v and recovered %+v are different", tt.info, recoveredInfo)
}
})
}
}

17
xt/unknown.go Normal file
View File

@ -0,0 +1,17 @@
package xt
// Unknown represents the bytes Info payload for unknown Info types where no
// dedicated match/target info type has (yet) been defined.
type Unknown []byte
func (x *Unknown) marshal(fam TableFamily, rev uint32) ([]byte, error) {
// In case of unknown payload we assume its creator knows what she/he does
// and thus we don't do any alignment padding. Just take the payload "as
// is".
return *x, nil
}
func (x *Unknown) unmarshal(fam TableFamily, rev uint32, data []byte) error {
*x = data
return nil
}

38
xt/unknown_test.go Normal file
View File

@ -0,0 +1,38 @@
package xt
import (
"reflect"
"testing"
)
func TestUnknown(t *testing.T) {
t.Parallel()
payload := Unknown([]byte{0xb0, 0x1d, 0xca, 0xfe, 0x00})
tests := []struct {
name string
info InfoAny
}{
{
name: "un/marshal Unknown round-trip",
info: &payload,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := tt.info.marshal(0, 0)
if err != nil {
t.Fatalf("marshal error: %+v", err)
}
var recoveredInfo InfoAny = &Unknown{}
err = recoveredInfo.unmarshal(0, 0, data)
if err != nil {
t.Fatalf("unmarshal error: %+v", err)
}
if !reflect.DeepEqual(tt.info, recoveredInfo) {
t.Fatalf("original %+v and recovered %+v are different", tt.info, recoveredInfo)
}
})
}
}

64
xt/util.go Normal file
View File

@ -0,0 +1,64 @@
package xt
import (
"fmt"
"net"
"github.com/google/nftables/alignedbuff"
"golang.org/x/sys/unix"
)
func bool32(ab *alignedbuff.AlignedBuff) (bool, error) {
v, err := ab.Uint32()
if err != nil {
return false, err
}
if v != 0 {
return true, nil
}
return false, nil
}
func putBool32(ab *alignedbuff.AlignedBuff, b bool) {
if b {
ab.PutUint32(1)
return
}
ab.PutUint32(0)
}
func iPv46(ab *alignedbuff.AlignedBuff, fam TableFamily) (net.IP, error) {
ip, err := ab.BytesAligned32(16)
if err != nil {
return nil, err
}
switch fam {
case unix.NFPROTO_IPV4:
return net.IP(ip[:4]), nil
case unix.NFPROTO_IPV6:
return net.IP(ip), nil
default:
return nil, fmt.Errorf("unmarshal IP: unsupported table family %d", fam)
}
}
func iPv46Mask(ab *alignedbuff.AlignedBuff, fam TableFamily) (net.IPMask, error) {
v, err := iPv46(ab, fam)
return net.IPMask(v), err
}
func putIPv46(ab *alignedbuff.AlignedBuff, fam TableFamily, ip net.IP) error {
switch fam {
case unix.NFPROTO_IPV4:
ab.PutBytesAligned32(ip.To4(), 16)
case unix.NFPROTO_IPV6:
ab.PutBytesAligned32(ip.To16(), 16)
default:
return fmt.Errorf("marshal IP: unsupported table family %d", fam)
}
return nil
}
func putIPv46Mask(ab *alignedbuff.AlignedBuff, fam TableFamily, mask net.IPMask) error {
return putIPv46(ab, fam, net.IP(mask))
}

50
xt/xt.go Normal file
View File

@ -0,0 +1,50 @@
/*
Package xt implements dedicated types for (some) of the "Info" payload in Match
and Target expressions that bridge between the nftables and xtables worlds.
Bridging between the more unified world of nftables and the slightly
heterogenous world of xtables comes with some caveats. Unmarshalling the
extension/translation information in Match and Target expressions requires
information about the table family the information belongs to, as well as type
and type revision information. In consequence, unmarshalling the Match and
Target Info field payloads often (but not necessarily always) require the table
family and revision information, so it gets passed to the type-specific
unmarshallers.
To complicate things more, even marshalling requires knowledge about the
enclosing table family. The NatRange/NatRange2 types are an example, where it is
necessary to differentiate between IPv4 and IPv6 address marshalling. Due to
Go's net.IP habit to normally store IPv4 addresses as IPv4-compatible IPv6
addresses (see also RFC 4291, section 2.5.5.1) marshalling must be handled
differently in the context of an IPv6 table compared to an IPv4 table. In an
IPv4 table, an IPv4-compatible IPv6 address must be marshalled as a 32bit
address, whereas in an IPv6 table the IPv4 address must be marshalled as an
128bit IPv4-compatible IPv6 address. Not relying on heuristics here we avoid
behavior unexpected and most probably unknown to our API users. The net.IP habit
of storing IPv4 addresses in two different storage formats is already a source
for trouble, especially when comparing net.IPs from different Go module sources.
We won't add to this confusion. (...or maybe we can, because of it?)
An important property of all types of Info extension/translation payloads is
that their marshalling and unmarshalling doesn't follow netlink's TLV
(tag-length-value) architecture. Instead, Info payloads a basically plain binary
blobs of their respective type-specific data structures, so host
platform/architecture alignment and data type sizes apply. The alignedbuff
package implements the different required data types alignments.
Please note that Info payloads are always padded at their end to the next uint64
alignment. Kernel code is checking for the padded payload size and will reject
payloads not correctly padded at their ends.
Most of the time, we find explifcitly sized (unsigned integer) data types.
However, there are notable exceptions where "unsigned int" is used: on 64bit
platforms this mostly translates into 32bit(!). This differs from Go mapping
uint to uint64 instead. This package currently clamps its mapping of C's
"unsigned int" to Go's uint32 for marshalling and unmarshalling. If in the
future 128bit platforms with a differently sized C unsigned int should come into
production, then the alignedbuff package will need to be adapted accordingly, as
it abstracts away this data type handling.
*/
package xt