package nftables_test import ( "testing" "github.com/google/nftables" "github.com/mdlayher/netlink" ) func TestGetObjReset(t *testing.T) { // The want byte sequences come from stracing nft(8), e.g.: // strace -f -v -x -s 2048 -eraw=sendto nft list chain ip filter forward want := [][]byte{ []byte{0x2, 0x0, 0x0, 0x0, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xa, 0x0, 0x2, 0x0, 0x66, 0x77, 0x64, 0x65, 0x64, 0x0, 0x0, 0x0, 0x8, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x1}, } // The reply messages come from adding log.Printf("msgs: %#v", msgs) to // (*github.com/mdlayher/netlink/Conn).receive reply := [][]netlink.Message{ nil, []netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x64, Type: 0xa12, Flags: 0x802, Sequence: 0x9acb0443, PID: 0xde9}, Data: []uint8{0x2, 0x0, 0x0, 0x10, 0xb, 0x0, 0x1, 0x0, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x0, 0x0, 0xa, 0x0, 0x2, 0x0, 0x66, 0x77, 0x64, 0x65, 0x64, 0x0, 0x0, 0x0, 0x8, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0x1, 0x1c, 0x0, 0x4, 0x0, 0xc, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x61, 0xc, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9, 0xc, 0x0, 0x6, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2}}}, []netlink.Message{netlink.Message{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x9acb0443, PID: 0xde9}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}}, } c := &nftables.Conn{ TestDial: CheckNLReq(t, want, reply), } filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4} objs, err := c.GetObjReset(&nftables.CounterObj{ Table: filter, Name: "fwded", }) if err != nil { t.Fatal(err) } if got, want := len(objs), 1; got != want { t.Fatalf("unexpected number of rules: got %d, want %d", got, want) } obj := objs[0] co, ok := obj.(*nftables.CounterObj) if !ok { t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj) } if got, want := co.Table.Name, filter.Name; got != want { t.Errorf("unexpected table name: got %q, want %q", got, want) } if got, want := co.Table.Family, filter.Family; got != want { t.Errorf("unexpected table family: got %d, want %d", got, want) } if got, want := co.Packets, uint64(9); got != want { t.Errorf("unexpected number of packets: got %d, want %d", got, want) } if got, want := co.Bytes, uint64(1121); got != want { t.Errorf("unexpected number of bytes: got %d, want %d", got, want) } }