Compare commits

..

2 Commits

Author SHA1 Message Date
Joe Williams 4f5cd5826f
add int32 and string types to alignedbuff (#195) 2022-10-15 21:04:45 +02:00
Andrew LeFevre d007ae63f1
fix queue expression getting skipped when unmarshaling rules (#197) 2022-10-15 19:08:15 +02:00
6 changed files with 275 additions and 0 deletions

View File

@ -118,6 +118,39 @@ func (a *AlignedBuff) Uint64() (uint64, error) {
return v, nil
}
// Int32 unmarshals an int32 in native endianess and alignment. It returns
// ErrEOF when trying to read beyond the payload.
func (a *AlignedBuff) Int32() (int32, error) {
if err := a.alignCheckedRead(int32AlignMask); err != nil {
return 0, err
}
v := binaryutil.Int32(a.data[a.pos : a.pos+4])
a.pos += 4
return v, nil
}
// String unmarshals a null terminated string
func (a *AlignedBuff) String() (string, error) {
len := 0
for {
if a.data[a.pos+len] == 0x00 {
break
}
len++
}
v := binaryutil.String(a.data[a.pos : a.pos+len])
a.pos += len
return v, nil
}
// Unmarshals a string of a given length (for non-null terminated strings)
func (a *AlignedBuff) StringWithLength(len int) (string, error) {
v := binaryutil.String(a.data[a.pos : a.pos+len])
a.pos += len
return v, nil
}
// Uint unmarshals an uint in native endianess and alignment for the C "unsigned
// int" type. It returns ErrEOF when trying to read beyond the payload. Please
// note that on 64bit platforms, the size and alignment of C's and Go's unsigned
@ -190,6 +223,19 @@ func (a *AlignedBuff) PutUint64(v uint64) {
a.pos += 8
}
// PutInt32 marshals an int32 in native endianess and alignment.
func (a *AlignedBuff) PutInt32(v int32) {
a.alignWrite(int32AlignMask)
a.data = append(a.data, binaryutil.PutInt32(v)...)
a.pos += 4
}
// PutString marshals a string.
func (a *AlignedBuff) PutString(v string) {
a.data = append(a.data, binaryutil.PutString(v)...)
a.pos += len(v)
}
// PutUint marshals an uint in native endianess and alignment for the C
// "unsigned int" type. Please note that on 64bit platforms, the size and
// alignment of C's and Go's unsigned integer data types differ, so we
@ -236,5 +282,7 @@ var uint32AlignMask = int(unsafe.Alignof(uint32(0)) - 1)
var uint64AlignMask = int(unsafe.Alignof(uint64(0)) - 1)
var padding = bytes.Repeat([]byte{0}, uint64AlignMask)
var int32AlignMask = int(unsafe.Alignof(int32(0)) - 1)
// And this even worse.
var uintSize = unsafe.Sizeof(uint32(0))

View File

@ -20,6 +20,9 @@ func TestAlignmentData(t *testing.T) {
if uintSize == 0 {
t.Fatal("zero uint size")
}
if int32AlignMask == 0 {
t.Fatal("zero uint32 alignment mask")
}
}
func TestAlignedBuff8(t *testing.T) {
@ -202,3 +205,114 @@ func TestAlignedUint(t *testing.T) {
t.Fatalf("sentinel read failed")
}
}
func TestAlignedBuffInt32(t *testing.T) {
b0 := New()
b0.PutUint8(0x42)
b0.PutInt32(0x12345678)
b0.PutInt32(0x01cecafe)
b := NewWithData(b0.data)
if len(b0.Data()) != 4*4 {
t.Fatalf("alignment padding failed")
}
v, err := b.Uint8()
if v != 0x42 || err != nil {
t.Fatalf("unaligment read failed")
}
tests := []struct {
name string
v int32
err error
}{
{
name: "first read",
v: 0x12345678,
err: nil,
},
{
name: "second read",
v: 0x01cecafe,
err: nil,
},
{
name: "end of buffer",
v: 0,
err: ErrEOF,
},
}
for _, tt := range tests {
v, err := b.Int32()
if v != tt.v || err != tt.err {
t.Errorf("expected: %#v %#v, got: %#v, %#v",
tt.v, tt.err, v, err)
}
}
}
func TestAlignedBuffPutNullTerminatedString(t *testing.T) {
b0 := New()
b0.PutUint8(0x42)
b0.PutString("test" + "\x00")
b := NewWithData(b0.data)
v, err := b.Uint8()
if v != 0x42 || err != nil {
t.Fatalf("unaligment read failed")
}
tests := []struct {
name string
v string
err error
}{
{
name: "first read",
v: "test",
err: nil,
},
}
for _, tt := range tests {
v, err := b.String()
if v != tt.v || err != tt.err {
t.Errorf("expected: %#v %#v, got: %#v, %#v",
tt.v, tt.err, v, err)
}
}
}
func TestAlignedBuffPutString(t *testing.T) {
b0 := New()
b0.PutUint8(0x42)
b0.PutString("test")
b := NewWithData(b0.data)
v, err := b.Uint8()
if v != 0x42 || err != nil {
t.Fatalf("unaligment read failed")
}
tests := []struct {
name string
v string
err error
}{
{
name: "first read",
v: "test",
err: nil,
},
}
for _, tt := range tests {
v, err := b.StringWithLength(len("test"))
if v != tt.v || err != tt.err {
t.Errorf("expected: %#v %#v, got: %#v, %#v",
tt.v, tt.err, v, err)
}
}
}

View File

@ -16,6 +16,7 @@
package binaryutil
import (
"bytes"
"encoding/binary"
"unsafe"
)
@ -102,3 +103,23 @@ func (bigEndian) Uint32(b []byte) uint32 {
func (bigEndian) Uint64(b []byte) uint64 {
return binary.BigEndian.Uint64(b)
}
// For dealing with types not supported by the encoding/binary interface
func PutInt32(v int32) []byte {
buf := make([]byte, 4)
*(*int32)(unsafe.Pointer(&buf[0])) = v
return buf
}
func Int32(b []byte) int32 {
return *(*int32)(unsafe.Pointer(&b[0]))
}
func PutString(s string) []byte {
return []byte(s)
}
func String(b []byte) string {
return string(bytes.TrimRight(b, "\x00"))
}

View File

@ -107,3 +107,36 @@ func TestBigEndian(t *testing.T) {
}
}
}
func TestOtherTypes(t *testing.T) {
tests := []struct {
name string
expected []byte
expectedv interface{}
actual []byte
unmarshal func(b []byte) interface{}
}{
{
name: "Int32",
expected: []byte{0x78, 0x56, 0x34, 0x12},
expectedv: int32(0x12345678),
actual: PutInt32(0x12345678),
unmarshal: func(b []byte) interface{} { return Int32(b) },
},
{
name: "String",
expected: []byte{0x74, 0x65, 0x73, 0x74},
expectedv: "test",
actual: PutString("test"),
unmarshal: func(b []byte) interface{} { return String(b) },
},
}
for _, tt := range tests {
if bytes.Compare(tt.actual, tt.expected) != 0 {
t.Errorf("Put%s failure, expected: %#v, got: %#v", tt.name, tt.expected, tt.actual)
}
if actual := tt.unmarshal(tt.actual); !reflect.DeepEqual(actual, tt.expectedv) {
t.Errorf("%s failure, expected: %#v, got: %#v", tt.name, tt.expectedv, actual)
}
}
}

View File

@ -126,6 +126,8 @@ func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]Any, er
e = &Target{}
case "connlimit":
e = &Connlimit{}
case "queue":
e = &Queue{}
}
if e == nil {
// TODO: introduce an opaque expression type so that users know

View File

@ -5711,3 +5711,60 @@ func TestGetRulesObjref(t *testing.T) {
t.Errorf("objref expr = %+v, wanted %+v", objref, want)
}
}
func TestGetRulesQueue(t *testing.T) {
// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.
c, newNS := openSystemNFTConn(t)
defer cleanupSystemNFTConn(t, newNS)
// Clear all rules at the beginning + end of the test.
c.FlushRuleset()
defer c.FlushRuleset()
table := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
})
chain := c.AddChain(&nftables.Chain{
Name: "forward",
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
})
queueRule := c.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Queue{
Num: 1000,
Flag: expr.QueueFlagBypass,
},
},
})
if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err)
}
rules, err := c.GetRules(table, chain)
if err != nil {
t.Fatal(err)
}
if got, want := len(rules), 1; got != want {
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
}
if got, want := len(rules[0].Exprs), 1; got != want {
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
}
queueExpr, ok := rules[0].Exprs[0].(*expr.Queue)
if !ok {
t.Fatalf("Exprs[0] is type %T, want *expr.Queue", rules[0].Exprs[0])
}
if want := queueRule.Exprs[0]; !reflect.DeepEqual(queueExpr, want) {
t.Errorf("queue expr = %+v, wanted %+v", queueExpr, want)
}
}