OpenFPGA/openfpga_flow/benchmarks/vtr_benchmark/matmul_8x8_fp16.v

1663 lines
52 KiB
Verilog

////////////////////////////////////////////////
// Matrix multiplication design
// Multiplies 8x8 matrix (A) with another 8x8 matrix (B)
// to produce an 8x8 matrix (C).
// Data precision is IEEE floating point 16-bit (half precision)
// The architecture is systolic in nature (output stationary).
// 4 4x4 matmuls composed to make a larger 8x8 matmul.
// There is a state machine for control and an APB
// interface for programming/configuring.
// Matrices are stored in RAM blocks.
///////////////////////////////////////////////
// Simulation time unit / precision.
`timescale 1ns/1ns
// Width in bits of one matrix element (the file header states IEEE half precision).
`define DWIDTH 16
// Width in bits of every BRAM address bus.
`define AWIDTH 10
// NOTE(review): not referenced in the visible part of the file; presumably the
// BRAM depth (2^AWIDTH) - confirm.
`define MEM_SIZE 1024
// NOTE(review): not referenced in the visible part of the file; presumably the
// dimension of the composed (8x8) multiplier - confirm.
`define DESIGN_SIZE 8
// Dimension of one systolic building block; a BRAM word packs this many elements.
`define MAT_MUL_SIZE 4
// Width of the validity masks and of the per-word BRAM write-enable buses.
`define MASK_WIDTH 4
// NOTE(review): not referenced in the visible part of the file - confirm usage.
`define LOG2_MAT_MUL_SIZE 2
// Added to the cycle counts that decide when results are ready / done is asserted.
`define NUM_CYCLES_IN_MAC 3
// NOTE(review): not referenced in the visible part of the file - confirm usage.
`define MEM_ACCESS_LATENCY 1
// Width of the APB data buses (PWDATA / PRDATA).
`define REG_DATAWIDTH 32
// Width of the APB address bus (PADDR).
`define REG_ADDRWIDTH 8
// Width of the address-stride configuration registers.
`define ADDR_STRIDE_WIDTH 16
// APB register map. These literals are 32 bits wide while PADDR is
// REG_ADDRWIDTH bits wide; all values below fit in 8 bits.
// Control/status: write bit 0 = start, write bit 31 = clear done;
// read returns {done, 30'b0, start_mat_mul}.
`define REG_STDN_TPU_ADDR 32'h4
// Base addresses of matrices A, B and C.
`define REG_MATRIX_A_ADDR 32'he
`define REG_MATRIX_B_ADDR 32'h12
`define REG_MATRIX_C_ADDR 32'h16
// Row/column validity masks for A and B.
`define REG_VALID_MASK_A_ROWS_ADDR 32'h20
`define REG_VALID_MASK_A_COLS_ADDR 32'h54
`define REG_VALID_MASK_B_ROWS_ADDR 32'h5c
`define REG_VALID_MASK_B_COLS_ADDR 32'h58
// Address strides for A, B and C.
`define REG_MATRIX_A_STRIDE_ADDR 32'h28
`define REG_MATRIX_B_STRIDE_ADDR 32'h32
`define REG_MATRIX_C_STRIDE_ADDR 32'h36
//////////////////////////////////////////////////////////////////////////
// Top level of the 8x8 matrix multiplier.
//
// Contains:
//   * six dual-port `ram` instances (A 0_0, A 1_0, B 0_0, B 0_1, C 0_1,
//     C 1_1). Port 0 of each RAM is driven by the matmul core, port 1 by
//     the external bram_*_ext interface, routed with bram_select (0..5).
//   * a control FSM that turns the APB "start" bit into start_mat_mul and
//     waits for done_mat_mul / the APB "clear done" bit.
//   * an APB slave holding the configuration registers.
//   * the matmul_8x8_systolic core.
//
// Ports:
//   clk, clk_mem       - core/APB clock and BRAM clock
//   resetn, pe_resetn  - active-low resets (inverted before use in the core)
//   PRESETn, PADDR, PWRITE, PSEL, PENABLE, PWDATA, PRDATA, PREADY
//                      - APB interface (PCLK is tied to clk)
//   bram_select        - selects which RAM the external port talks to
//   bram_addr_ext, bram_wdata_ext, bram_we_ext, bram_rdata_ext
//                      - external RAM access port
//////////////////////////////////////////////////////////////////////////
module matrix_multiplication(
  input clk,
  input clk_mem,
  input resetn,
  input pe_resetn,
  input PRESETn,
  input [`REG_ADDRWIDTH-1:0] PADDR,
  input PWRITE,
  input PSEL,
  input PENABLE,
  input [`REG_DATAWIDTH-1:0] PWDATA,
  output reg [`REG_DATAWIDTH-1:0] PRDATA,
  output reg PREADY,
  input [7:0] bram_select,
  input [`AWIDTH-1:0] bram_addr_ext,
  output reg [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_ext,
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_ext,
  input [`MAT_MUL_SIZE-1:0] bram_we_ext
);

  wire PCLK;
  assign PCLK = clk;

  // Written through the APB control register, read by the control FSM.
  reg start_reg;
  reg clear_done_reg;

  //Dummy register to sync all other invalid/unimplemented addresses
  reg [`REG_DATAWIDTH-1:0] reg_dummy;

  // External-side (port 1) signals of each RAM.
  reg  [`AWIDTH-1:0] bram_addr_a_0_0_ext;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_a_0_0_ext;
  reg  [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_a_0_0_ext;
  reg  [`MASK_WIDTH-1:0] bram_we_a_0_0_ext;

  reg  [`AWIDTH-1:0] bram_addr_a_1_0_ext;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_a_1_0_ext;
  reg  [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_a_1_0_ext;
  reg  [`MASK_WIDTH-1:0] bram_we_a_1_0_ext;

  reg  [`AWIDTH-1:0] bram_addr_b_0_0_ext;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_b_0_0_ext;
  reg  [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_b_0_0_ext;
  reg  [`MASK_WIDTH-1:0] bram_we_b_0_0_ext;

  reg  [`AWIDTH-1:0] bram_addr_b_0_1_ext;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_b_0_1_ext;
  reg  [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_b_0_1_ext;
  reg  [`MASK_WIDTH-1:0] bram_we_b_0_1_ext;

  reg  [`AWIDTH-1:0] bram_addr_c_0_1_ext;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_c_0_1_ext;
  reg  [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_c_0_1_ext;
  reg  [`MASK_WIDTH-1:0] bram_we_c_0_1_ext;

  reg  [`AWIDTH-1:0] bram_addr_c_1_1_ext;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_c_1_1_ext;
  reg  [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_c_1_1_ext;
  reg  [`MASK_WIDTH-1:0] bram_we_c_1_1_ext;

  // Core-side (port 0) signals of each RAM.
  // The bram_en_* wires are tied to 1 below but are not connected to any
  // RAM port (the `ram` module has no enable input).
  wire [`AWIDTH-1:0] bram_addr_a_0_0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_a_0_0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_a_0_0;
  wire [`MASK_WIDTH-1:0] bram_we_a_0_0;
  wire bram_en_a_0_0;

  wire [`AWIDTH-1:0] bram_addr_a_1_0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_a_1_0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_a_1_0;
  wire [`MASK_WIDTH-1:0] bram_we_a_1_0;
  wire bram_en_a_1_0;

  wire [`AWIDTH-1:0] bram_addr_b_0_0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_b_0_0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_b_0_0;
  wire [`MASK_WIDTH-1:0] bram_we_b_0_0;
  wire bram_en_b_0_0;

  wire [`AWIDTH-1:0] bram_addr_b_0_1;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_b_0_1;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_b_0_1;
  wire [`MASK_WIDTH-1:0] bram_we_b_0_1;
  wire bram_en_b_0_1;

  wire [`AWIDTH-1:0] bram_addr_c_0_1;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_c_0_1;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_c_0_1;
  wire [`MASK_WIDTH-1:0] bram_we_c_0_1;
  wire bram_en_c_0_1;

  wire [`AWIDTH-1:0] bram_addr_c_1_1;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_c_1_1;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_c_1_1;
  wire [`MASK_WIDTH-1:0] bram_we_c_1_1;
  wire bram_en_c_1_1;

  ////////////////////////////////////////////////////////////////
  // External port routing
  ////////////////////////////////////////////////////////////////
  // Every signal driven here gets a default first, so the block is purely
  // combinational (no inferred latches) and RAMs that are not selected see
  // write-enable = 0 instead of a stale or unknown value.
  always @* begin
    bram_addr_a_0_0_ext  = {`AWIDTH{1'b0}};
    bram_wdata_a_0_0_ext = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
    bram_we_a_0_0_ext    = {`MASK_WIDTH{1'b0}};

    bram_addr_a_1_0_ext  = {`AWIDTH{1'b0}};
    bram_wdata_a_1_0_ext = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
    bram_we_a_1_0_ext    = {`MASK_WIDTH{1'b0}};

    bram_addr_b_0_0_ext  = {`AWIDTH{1'b0}};
    bram_wdata_b_0_0_ext = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
    bram_we_b_0_0_ext    = {`MASK_WIDTH{1'b0}};

    bram_addr_b_0_1_ext  = {`AWIDTH{1'b0}};
    bram_wdata_b_0_1_ext = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
    bram_we_b_0_1_ext    = {`MASK_WIDTH{1'b0}};

    bram_addr_c_0_1_ext  = {`AWIDTH{1'b0}};
    bram_wdata_c_0_1_ext = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
    bram_we_c_0_1_ext    = {`MASK_WIDTH{1'b0}};

    bram_addr_c_1_1_ext  = {`AWIDTH{1'b0}};
    bram_wdata_c_1_1_ext = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
    bram_we_c_1_1_ext    = {`MASK_WIDTH{1'b0}};

    bram_rdata_ext       = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};

    case (bram_select)
      0: begin
        bram_addr_a_0_0_ext  = bram_addr_ext;
        bram_wdata_a_0_0_ext = bram_wdata_ext;
        bram_we_a_0_0_ext    = bram_we_ext;
        bram_rdata_ext       = bram_rdata_a_0_0_ext;
      end
      1: begin
        bram_addr_a_1_0_ext  = bram_addr_ext;
        bram_wdata_a_1_0_ext = bram_wdata_ext;
        bram_we_a_1_0_ext    = bram_we_ext;
        bram_rdata_ext       = bram_rdata_a_1_0_ext;
      end
      2: begin
        bram_addr_b_0_0_ext  = bram_addr_ext;
        bram_wdata_b_0_0_ext = bram_wdata_ext;
        bram_we_b_0_0_ext    = bram_we_ext;
        bram_rdata_ext       = bram_rdata_b_0_0_ext;
      end
      3: begin
        bram_addr_b_0_1_ext  = bram_addr_ext;
        bram_wdata_b_0_1_ext = bram_wdata_ext;
        bram_we_b_0_1_ext    = bram_we_ext;
        bram_rdata_ext       = bram_rdata_b_0_1_ext;
      end
      4: begin
        bram_addr_c_0_1_ext  = bram_addr_ext;
        bram_wdata_c_0_1_ext = bram_wdata_ext;
        bram_we_c_0_1_ext    = bram_we_ext;
        bram_rdata_ext       = bram_rdata_c_0_1_ext;
      end
      5: begin
        bram_addr_c_1_1_ext  = bram_addr_ext;
        bram_wdata_c_1_1_ext = bram_wdata_ext;
        bram_we_c_1_1_ext    = bram_we_ext;
        bram_rdata_ext       = bram_rdata_c_1_1_ext;
      end
      default: begin
        // No RAM selected: defaults above apply (read data = 0, no writes).
      end
    endcase
  end

  /////////////////////////////////////////////////
  // BRAMs to store matrix A
  /////////////////////////////////////////////////
  // BRAM matrix A 0_0 (bank0)
  ram matrix_A_0_0(
    .addr0(bram_addr_a_0_0),
    .d0(bram_wdata_a_0_0),
    .we0(bram_we_a_0_0),
    .q0(bram_rdata_a_0_0),
    .addr1(bram_addr_a_0_0_ext),
    .d1(bram_wdata_a_0_0_ext),
    .we1(bram_we_a_0_0_ext),
    .q1(bram_rdata_a_0_0_ext),
    .clk(clk_mem));

  // BRAM matrix A 1_0 (bank1)
  ram matrix_A_1_0(
    .addr0(bram_addr_a_1_0),
    .d0(bram_wdata_a_1_0),
    .we0(bram_we_a_1_0),
    .q0(bram_rdata_a_1_0),
    .addr1(bram_addr_a_1_0_ext),
    .d1(bram_wdata_a_1_0_ext),
    .we1(bram_we_a_1_0_ext),
    .q1(bram_rdata_a_1_0_ext),
    .clk(clk_mem));

  /////////////////////////////////////////////////
  // BRAMs to store matrix B
  /////////////////////////////////////////////////
  // BRAM matrix B 0_0 (bank0)
  ram matrix_B_0_0(
    .addr0(bram_addr_b_0_0),
    .d0(bram_wdata_b_0_0),
    .we0(bram_we_b_0_0),
    .q0(bram_rdata_b_0_0),
    .addr1(bram_addr_b_0_0_ext),
    .d1(bram_wdata_b_0_0_ext),
    .we1(bram_we_b_0_0_ext),
    .q1(bram_rdata_b_0_0_ext),
    .clk(clk_mem));

  // BRAM matrix B 0_1 (bank1)
  ram matrix_B_0_1(
    .addr0(bram_addr_b_0_1),
    .d0(bram_wdata_b_0_1),
    .we0(bram_we_b_0_1),
    .q0(bram_rdata_b_0_1),
    .addr1(bram_addr_b_0_1_ext),
    .d1(bram_wdata_b_0_1_ext),
    .we1(bram_we_b_0_1_ext),
    .q1(bram_rdata_b_0_1_ext),
    .clk(clk_mem));

  /////////////////////////////////////////////////
  // BRAMs to store matrix C
  /////////////////////////////////////////////////
  // BRAM matrix C 0_1 (bank0)
  ram matrix_C_0_1(
    .addr0(bram_addr_c_0_1),
    .d0(bram_wdata_c_0_1),
    .we0(bram_we_c_0_1),
    .q0(bram_rdata_c_0_1),
    .addr1(bram_addr_c_0_1_ext),
    .d1(bram_wdata_c_0_1_ext),
    .we1(bram_we_c_0_1_ext),
    .q1(bram_rdata_c_0_1_ext),
    .clk(clk_mem));

  // BRAM matrix C 1_1 (bank1)
  ram matrix_C_1_1(
    .addr0(bram_addr_c_1_1),
    .d0(bram_wdata_c_1_1),
    .we0(bram_we_c_1_1),
    .q0(bram_rdata_c_1_1),
    .addr1(bram_addr_c_1_1_ext),
    .d1(bram_wdata_c_1_1_ext),
    .we1(bram_we_c_1_1_ext),
    .q1(bram_rdata_c_1_1_ext),
    .clk(clk_mem));

  reg start_mat_mul;
  wire done_mat_mul;
  reg [3:0] state;

  ////////////////////////////////////////////////////////////////
  // Control logic
  ////////////////////////////////////////////////////////////////
  // 4'b0000 : idle, waits for start_reg
  // 4'b0001 : asserts start_mat_mul
  // 4'b1010 : waits for done_mat_mul, then drops start_mat_mul
  // 4'b1000 : waits for clear_done_reg before returning to idle
  always @( posedge clk) begin
    if (resetn == 1'b0) begin
      state <= 4'b0000;
      start_mat_mul <= 1'b0;
    end
    else begin
      case (state)
        4'b0000: begin
          start_mat_mul <= 1'b0;
          if (start_reg == 1'b1) begin
            state <= 4'b0001;
          end else begin
            state <= 4'b0000;
          end
        end
        4'b0001: begin
          start_mat_mul <= 1'b1;
          state <= 4'b1010;
        end
        4'b1010: begin
          if (done_mat_mul == 1'b1) begin
            start_mat_mul <= 1'b0;
            state <= 4'b1000;
          end
          else begin
            state <= 4'b1010;
          end
        end
        4'b1000: begin
          if (clear_done_reg == 1'b1) begin
            state <= 4'b0000;
          end
          else begin
            state <= 4'b1000;
          end
        end
        default: begin
          // Recover from any unused encoding instead of hanging there.
          start_mat_mul <= 1'b0;
          state <= 4'b0000;
        end
      endcase
    end
  end

  reg [1:0] state_apb;
  `define IDLE 2'b00
  `define W_ENABLE 2'b01
  `define R_ENABLE 2'b10

  reg [`AWIDTH-1:0] address_mat_a;
  reg [`AWIDTH-1:0] address_mat_b;
  reg [`AWIDTH-1:0] address_mat_c;
  reg [`MASK_WIDTH-1:0] validity_mask_a_rows;
  reg [`MASK_WIDTH-1:0] validity_mask_a_cols;
  reg [`MASK_WIDTH-1:0] validity_mask_b_rows;
  reg [`MASK_WIDTH-1:0] validity_mask_b_cols;
  reg [`ADDR_STRIDE_WIDTH-1:0] address_stride_a;
  reg [`ADDR_STRIDE_WIDTH-1:0] address_stride_b;
  reg [`ADDR_STRIDE_WIDTH-1:0] address_stride_c;

  ////////////////////////////////////////////////////////////////
  // Configuration logic
  ////////////////////////////////////////////////////////////////
  // APB slave: IDLE -> W_ENABLE / R_ENABLE on PSEL, the transfer is
  // performed when PENABLE is seen in that state, then back to IDLE.
  always @(posedge PCLK) begin
    if (PRESETn == 0) begin
      state_apb <= `IDLE;
      PRDATA <= 0;
      PREADY <= 0;
      // The control bits and the dummy register are cleared as well, so
      // the control FSM never sees an unknown start / clear-done request
      // after reset.
      start_reg <= 1'b0;
      clear_done_reg <= 1'b0;
      reg_dummy <= 0;
      address_mat_a <= 0;
      address_mat_b <= 0;
      address_mat_c <= 0;
      validity_mask_a_rows <= {`MASK_WIDTH{1'b1}};
      validity_mask_a_cols <= {`MASK_WIDTH{1'b1}};
      validity_mask_b_rows <= {`MASK_WIDTH{1'b1}};
      validity_mask_b_cols <= {`MASK_WIDTH{1'b1}};
      address_stride_a <= `MAT_MUL_SIZE;
      address_stride_b <= `MAT_MUL_SIZE;
      address_stride_c <= `MAT_MUL_SIZE;
    end
    else begin
      case (state_apb)
        `IDLE : begin
          PRDATA <= 0;
          if (PSEL) begin
            if (PWRITE) begin
              state_apb <= `W_ENABLE;
            end
            else begin
              state_apb <= `R_ENABLE;
            end
          end
          PREADY <= 0;
        end
        `W_ENABLE : begin
          if (PSEL && PWRITE && PENABLE) begin
            case (PADDR)
              `REG_STDN_TPU_ADDR : begin
                start_reg <= PWDATA[0];
                clear_done_reg <= PWDATA[31];
              end
              `REG_MATRIX_A_ADDR : address_mat_a <= PWDATA[`AWIDTH-1:0];
              `REG_MATRIX_B_ADDR : address_mat_b <= PWDATA[`AWIDTH-1:0];
              `REG_MATRIX_C_ADDR : address_mat_c <= PWDATA[`AWIDTH-1:0];
              `REG_VALID_MASK_A_ROWS_ADDR: begin
                validity_mask_a_rows <= PWDATA[`MASK_WIDTH-1:0];
              end
              `REG_VALID_MASK_A_COLS_ADDR: begin
                validity_mask_a_cols <= PWDATA[`MASK_WIDTH-1:0];
              end
              `REG_VALID_MASK_B_ROWS_ADDR: begin
                validity_mask_b_rows <= PWDATA[`MASK_WIDTH-1:0];
              end
              `REG_VALID_MASK_B_COLS_ADDR: begin
                validity_mask_b_cols <= PWDATA[`MASK_WIDTH-1:0];
              end
              `REG_MATRIX_A_STRIDE_ADDR : address_stride_a <= PWDATA[`ADDR_STRIDE_WIDTH-1:0];
              `REG_MATRIX_B_STRIDE_ADDR : address_stride_b <= PWDATA[`ADDR_STRIDE_WIDTH-1:0];
              `REG_MATRIX_C_STRIDE_ADDR : address_stride_c <= PWDATA[`ADDR_STRIDE_WIDTH-1:0];
              default : reg_dummy <= PWDATA; //sink writes to a dummy register
            endcase
            PREADY <=1;
          end
          state_apb <= `IDLE;
        end
        `R_ENABLE : begin
          if (PSEL && !PWRITE && PENABLE) begin
            PREADY <= 1;
            case (PADDR)
              // Status read: bit 31 = done_mat_mul, bit 0 = start_mat_mul.
              `REG_STDN_TPU_ADDR : PRDATA <= {done_mat_mul, 30'b0, start_mat_mul};
              `REG_MATRIX_A_ADDR : PRDATA <= address_mat_a;
              `REG_MATRIX_B_ADDR : PRDATA <= address_mat_b;
              `REG_MATRIX_C_ADDR : PRDATA <= address_mat_c;
              `REG_VALID_MASK_A_ROWS_ADDR: PRDATA <= validity_mask_a_rows;
              `REG_VALID_MASK_A_COLS_ADDR: PRDATA <= validity_mask_a_cols;
              `REG_VALID_MASK_B_ROWS_ADDR: PRDATA <= validity_mask_b_rows;
              `REG_VALID_MASK_B_COLS_ADDR: PRDATA <= validity_mask_b_cols;
              `REG_MATRIX_A_STRIDE_ADDR : PRDATA <= address_stride_a;
              `REG_MATRIX_B_STRIDE_ADDR : PRDATA <= address_stride_b;
              `REG_MATRIX_C_STRIDE_ADDR : PRDATA <= address_stride_c;
              default : PRDATA <= reg_dummy; //read the dummy register for undefined addresses
            endcase
          end
          state_apb <= `IDLE;
        end
        default: begin
          state_apb <= `IDLE;
        end
      endcase
    end
  end

  // The core uses active-high resets.
  wire reset;
  assign reset = ~resetn;
  wire pe_reset;
  assign pe_reset = ~pe_resetn;

  // C RAMs: the core writes a full word whenever it flags data available.
  wire c_data_0_1_available;
  assign bram_en_c_0_1 = 1'b1;
  assign bram_we_c_0_1 = (c_data_0_1_available) ? {`MASK_WIDTH{1'b1}} : {`MASK_WIDTH{1'b0}};
  wire c_data_1_1_available;
  assign bram_en_c_1_1 = 1'b1;
  assign bram_we_c_1_1 = (c_data_1_1_available) ? {`MASK_WIDTH{1'b1}} : {`MASK_WIDTH{1'b0}};

  // A and B RAMs are read-only from the core side.
  assign bram_wdata_a_0_0 = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
  assign bram_en_a_0_0 = 1'b1;
  assign bram_we_a_0_0 = {`MASK_WIDTH{1'b0}};
  assign bram_wdata_a_1_0 = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
  assign bram_en_a_1_0 = 1'b1;
  assign bram_we_a_1_0 = {`MASK_WIDTH{1'b0}};
  assign bram_wdata_b_0_0 = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
  assign bram_en_b_0_0 = 1'b1;
  assign bram_we_b_0_0 = {`MASK_WIDTH{1'b0}};
  assign bram_wdata_b_0_1 = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
  assign bram_en_b_0_1 = 1'b1;
  assign bram_we_b_0_1 = {`MASK_WIDTH{1'b0}};

  /////////////////////////////////////////////////
  // The 8x8 matmul instantiation
  /////////////////////////////////////////////////
  matmul_8x8_systolic u_matmul_8x8_systolic (
    .clk(clk),
    .reset(reset),
    .pe_reset(pe_reset),
    .start_mat_mul(start_mat_mul),
    .done_mat_mul(done_mat_mul),
    .address_mat_a(address_mat_a),
    .address_mat_b(address_mat_b),
    .address_mat_c(address_mat_c),
    .address_stride_a(address_stride_a),
    .address_stride_b(address_stride_b),
    .address_stride_c(address_stride_c),
    .a_data_0_0(bram_rdata_a_0_0),
    .b_data_0_0(bram_rdata_b_0_0),
    .a_addr_0_0(bram_addr_a_0_0),
    .b_addr_0_0(bram_addr_b_0_0),
    .a_data_1_0(bram_rdata_a_1_0),
    .b_data_0_1(bram_rdata_b_0_1),
    .a_addr_1_0(bram_addr_a_1_0),
    .b_addr_0_1(bram_addr_b_0_1),
    .c_data_0_1(bram_wdata_c_0_1),
    .c_addr_0_1(bram_addr_c_0_1),
    .c_data_0_1_available(c_data_0_1_available),
    .c_data_1_1(bram_wdata_c_1_1),
    .c_addr_1_1(bram_addr_c_1_1),
    .c_data_1_1_available(c_data_1_1_available),
    .validity_mask_a_rows(validity_mask_a_rows),
    .validity_mask_a_cols(validity_mask_a_cols),
    .validity_mask_b_rows(validity_mask_b_rows),
    .validity_mask_b_cols(validity_mask_b_cols)
  );
endmodule
/////////////////////////////////////////////////
// The 8x8 matmul definition
/////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
// 8x8 multiplier composed of four matmul_4x4_systolic blocks (0_0, 0_1,
// 1_0, 1_1).
//
// Data movement between blocks, as wired below:
//   * A operands enter blocks 0_0 and 1_0 from memory (a_data_0_0 /
//     a_data_1_0) and are forwarded to 0_1 and 1_1 through a_data_out.
//   * B operands enter blocks 0_0 and 0_1 from memory (b_data_0_0 /
//     b_data_0_1) and are forwarded to 1_0 and 1_1 through b_data_out.
//   * C results of 0_0 and 1_0 are shifted through 0_1 and 1_1
//     respectively (c_data_out -> c_data_in); only 0_1 and 1_1 drive the
//     C memory ports of this module.
//////////////////////////////////////////////////////////////////////////
module matmul_8x8_systolic(
  input clk,
  input reset,
  input pe_reset,
  input start_mat_mul,
  output done_mat_mul,
  input [`AWIDTH-1:0] address_mat_a,
  input [`AWIDTH-1:0] address_mat_b,
  input [`AWIDTH-1:0] address_mat_c,
  input [`ADDR_STRIDE_WIDTH-1:0] address_stride_a,
  input [`ADDR_STRIDE_WIDTH-1:0] address_stride_b,
  input [`ADDR_STRIDE_WIDTH-1:0] address_stride_c,
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_0_0,
  output [`AWIDTH-1:0] a_addr_0_0,
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_0_0,
  output [`AWIDTH-1:0] b_addr_0_0,
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_1_0,
  output [`AWIDTH-1:0] a_addr_1_0,
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_0_1,
  output [`AWIDTH-1:0] b_addr_0_1,
  output [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_0_1,
  output [`AWIDTH-1:0] c_addr_0_1,
  output c_data_0_1_available,
  output [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_1_1,
  output [`AWIDTH-1:0] c_addr_1_1,
  output c_data_1_1_available,
  input [`MASK_WIDTH-1:0] validity_mask_a_rows,
  input [`MASK_WIDTH-1:0] validity_mask_a_cols,
  input [`MASK_WIDTH-1:0] validity_mask_b_rows,
  input [`MASK_WIDTH-1:0] validity_mask_b_cols
);

  /////////////////////////////////////////////////
  // Done signal
  /////////////////////////////////////////////////
  // Every block gets the same clk, reset, start_mat_mul and
  // final_mat_mul_size, which are the only inputs of the done logic in
  // matmul_4x4_systolic, so all four done outputs are identical. Block
  // 0_0's output is used; the others are left unconnected at this level.
  wire done_mat_mul_0_0;
  wire done_mat_mul_0_1;
  wire done_mat_mul_1_0;
  wire done_mat_mul_1_1;
  assign done_mat_mul = done_mat_mul_0_0;

  /////////////////////////////////////////////////
  // Matmul 0_0
  /////////////////////////////////////////////////
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_0_0_to_0_1;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_0_0_to_1_0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_in_0_0_NC;
  assign a_data_in_0_0_NC = 0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_in_0_0_NC;
  assign c_data_in_0_0_NC = 0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_in_0_0_NC;
  assign b_data_in_0_0_NC = 0;
  // Results of 0_0 are shifted into 0_1. This net must be driven by 0_0's
  // c_data_out, because 0_1 reads it on c_data_in.
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_0_0_to_0_1;
  wire [`AWIDTH-1:0] c_addr_0_0_NC;
  wire c_data_0_0_available_NC;

  matmul_4x4_systolic u_matmul_4x4_systolic_0_0(
    .clk(clk),
    .reset(reset),
    .pe_reset(pe_reset),
    .start_mat_mul(start_mat_mul),
    .done_mat_mul(done_mat_mul_0_0),
    .address_mat_a(address_mat_a),
    .address_mat_b(address_mat_b),
    .address_mat_c(address_mat_c),
    .address_stride_a(address_stride_a),
    .address_stride_b(address_stride_b),
    .address_stride_c(address_stride_c),
    .a_data(a_data_0_0),
    .b_data(b_data_0_0),
    .a_data_in(a_data_in_0_0_NC),
    .b_data_in(b_data_in_0_0_NC),
    .c_data_in(c_data_in_0_0_NC),
    .c_data_out(c_data_0_0_to_0_1),
    .a_data_out(a_data_0_0_to_0_1),
    .b_data_out(b_data_0_0_to_1_0),
    .a_addr(a_addr_0_0),
    .b_addr(b_addr_0_0),
    .c_addr(c_addr_0_0_NC),
    .c_data_available(c_data_0_0_available_NC),
    // The mask ports are `MASK_WIDTH bits wide, so the masks are passed
    // straight through (no zero padding that would only be truncated).
    .validity_mask_a_rows(validity_mask_a_rows),
    .validity_mask_a_cols(validity_mask_a_cols),
    .validity_mask_b_rows(validity_mask_b_rows),
    .validity_mask_b_cols(validity_mask_b_cols),
    .final_mat_mul_size(8'd8),
    .a_loc(8'd0),
    .b_loc(8'd0)
  );

  /////////////////////////////////////////////////
  // Matmul 0_1
  /////////////////////////////////////////////////
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_0_1_to_0_2;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_0_1_to_1_1;
  wire [`AWIDTH-1:0] a_addr_0_1_NC;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_0_1_NC;
  assign a_data_0_1_NC = 0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_in_0_1_NC;
  assign b_data_in_0_1_NC = 0;

  matmul_4x4_systolic u_matmul_4x4_systolic_0_1(
    .clk(clk),
    .reset(reset),
    .pe_reset(pe_reset),
    .start_mat_mul(start_mat_mul),
    .done_mat_mul(done_mat_mul_0_1),
    .address_mat_a(address_mat_a),
    .address_mat_b(address_mat_b),
    .address_mat_c(address_mat_c),
    .address_stride_a(address_stride_a),
    .address_stride_b(address_stride_b),
    .address_stride_c(address_stride_c),
    .a_data(a_data_0_1_NC),
    .b_data(b_data_0_1),
    .a_data_in(a_data_0_0_to_0_1),
    .b_data_in(b_data_in_0_1_NC),
    .c_data_in(c_data_0_0_to_0_1),
    .c_data_out(c_data_0_1),
    .a_data_out(a_data_0_1_to_0_2),
    .b_data_out(b_data_0_1_to_1_1),
    .a_addr(a_addr_0_1_NC),
    .b_addr(b_addr_0_1),
    .c_addr(c_addr_0_1),
    .c_data_available(c_data_0_1_available),
    .validity_mask_a_rows(validity_mask_a_rows),
    .validity_mask_a_cols(validity_mask_a_cols),
    .validity_mask_b_rows(validity_mask_b_rows),
    .validity_mask_b_cols(validity_mask_b_cols),
    .final_mat_mul_size(8'd8),
    .a_loc(8'd0),
    .b_loc(8'd1)
  );

  /////////////////////////////////////////////////
  // Matmul 1_0
  /////////////////////////////////////////////////
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_1_0_to_1_1;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_1_0_to_2_0;
  wire [`AWIDTH-1:0] b_addr_1_0_NC;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_1_0_NC;
  assign b_data_1_0_NC = 0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_in_1_0_NC;
  assign a_data_in_1_0_NC = 0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_in_1_0_NC;
  assign c_data_in_1_0_NC = 0;
  // Results of 1_0 are shifted into 1_1 (read there on c_data_in).
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_1_0_to_1_1;
  wire [`AWIDTH-1:0] c_addr_1_0_NC;
  wire c_data_1_0_available_NC;

  matmul_4x4_systolic u_matmul_4x4_systolic_1_0(
    .clk(clk),
    .reset(reset),
    .pe_reset(pe_reset),
    .start_mat_mul(start_mat_mul),
    .done_mat_mul(done_mat_mul_1_0),
    .address_mat_a(address_mat_a),
    .address_mat_b(address_mat_b),
    .address_mat_c(address_mat_c),
    .address_stride_a(address_stride_a),
    .address_stride_b(address_stride_b),
    .address_stride_c(address_stride_c),
    .a_data(a_data_1_0),
    .b_data(b_data_1_0_NC),
    .a_data_in(a_data_in_1_0_NC),
    .b_data_in(b_data_0_0_to_1_0),
    .c_data_in(c_data_in_1_0_NC),
    .c_data_out(c_data_1_0_to_1_1),
    .a_data_out(a_data_1_0_to_1_1),
    .b_data_out(b_data_1_0_to_2_0),
    .a_addr(a_addr_1_0),
    .b_addr(b_addr_1_0_NC),
    .c_addr(c_addr_1_0_NC),
    .c_data_available(c_data_1_0_available_NC),
    .validity_mask_a_rows(validity_mask_a_rows),
    .validity_mask_a_cols(validity_mask_a_cols),
    .validity_mask_b_rows(validity_mask_b_rows),
    .validity_mask_b_cols(validity_mask_b_cols),
    .final_mat_mul_size(8'd8),
    .a_loc(8'd1),
    .b_loc(8'd0)
  );

  /////////////////////////////////////////////////
  // Matmul 1_1
  /////////////////////////////////////////////////
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_1_1_to_1_2;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_1_1_to_2_1;
  wire [`AWIDTH-1:0] a_addr_1_1_NC;
  wire [`AWIDTH-1:0] b_addr_1_1_NC;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_1_1_NC;
  assign a_data_1_1_NC = 0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_1_1_NC;
  assign b_data_1_1_NC = 0;

  matmul_4x4_systolic u_matmul_4x4_systolic_1_1(
    .clk(clk),
    .reset(reset),
    .pe_reset(pe_reset),
    .start_mat_mul(start_mat_mul),
    .done_mat_mul(done_mat_mul_1_1),
    .address_mat_a(address_mat_a),
    .address_mat_b(address_mat_b),
    .address_mat_c(address_mat_c),
    .address_stride_a(address_stride_a),
    .address_stride_b(address_stride_b),
    .address_stride_c(address_stride_c),
    .a_data(a_data_1_1_NC),
    .b_data(b_data_1_1_NC),
    .a_data_in(a_data_1_0_to_1_1),
    .b_data_in(b_data_0_1_to_1_1),
    .c_data_in(c_data_1_0_to_1_1),
    .c_data_out(c_data_1_1),
    .a_data_out(a_data_1_1_to_1_2),
    .b_data_out(b_data_1_1_to_2_1),
    .a_addr(a_addr_1_1_NC),
    .b_addr(b_addr_1_1_NC),
    .c_addr(c_addr_1_1),
    .c_data_available(c_data_1_1_available),
    .validity_mask_a_rows(validity_mask_a_rows),
    .validity_mask_a_cols(validity_mask_a_cols),
    .validity_mask_b_rows(validity_mask_b_rows),
    .validity_mask_b_cols(validity_mask_b_cols),
    .final_mat_mul_size(8'd8),
    .a_loc(8'd1),
    .b_loc(8'd1)
  );
endmodule
//////////////////////////////////
//Dual port RAM
//////////////////////////////////
// Dual-port RAM wrapper.
//
// Ports (both ports share one clock):
//   addr0/d0/we0/q0 - port 0 address, write data, write enables, read data
//   addr1/d1/we1/q1 - port 1 address, write data, write enables, read data
//   clk             - clock for both ports
//
// With VCS defined, a behavioural model is used: storage is an array of
// `DWIDTH-bit elements, and each access touches `MAT_MUL_SIZE consecutive
// elements starting at the port's address, with one write-enable bit per
// element. Read data is registered. Accesses with addr+i beyond the last
// element fall outside the array (writes are dropped, reads return X).
//
// Otherwise the dual_port_ram primitive is instantiated with a single
// write enable per port (OR of the per-element enables).
module ram (
  addr0,
  d0,
  we0,
  q0,
  addr1,
  d1,
  we1,
  q1,
  clk);

  input [`AWIDTH-1:0] addr0;
  input [`AWIDTH-1:0] addr1;
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] d0;
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] d1;
  input [`MAT_MUL_SIZE-1:0] we0;
  input [`MAT_MUL_SIZE-1:0] we1;
  output [`MAT_MUL_SIZE*`DWIDTH-1:0] q0;
  output [`MAT_MUL_SIZE*`DWIDTH-1:0] q1;
  input clk;

`ifdef VCS
  reg [`MAT_MUL_SIZE*`DWIDTH-1:0] q0;
  reg [`MAT_MUL_SIZE*`DWIDTH-1:0] q1;
  // One array entry per matrix element. It must be `DWIDTH bits wide so
  // that stored elements are not truncated.
  reg [`DWIDTH-1:0] ram[((1<<`AWIDTH)-1):0];
  // Separate loop indices so the two always blocks do not share a variable.
  integer i;
  integer j;

  // Port 0
  always @(posedge clk)
  begin
    for (i = 0; i < `MAT_MUL_SIZE; i=i+1) begin
      if (we0[i]) ram[addr0+i] <= d0[i*`DWIDTH +: `DWIDTH];
    end
    for (i = 0; i < `MAT_MUL_SIZE; i=i+1) begin
      q0[i*`DWIDTH +: `DWIDTH] <= ram[addr0+i];
    end
  end

  // Port 1: writes are indexed by addr1, this port's own address.
  always @(posedge clk)
  begin
    for (j = 0; j < `MAT_MUL_SIZE; j=j+1) begin
      if (we1[j]) ram[addr1+j] <= d1[j*`DWIDTH +: `DWIDTH];
    end
    for (j = 0; j < `MAT_MUL_SIZE; j=j+1) begin
      q1[j*`DWIDTH +: `DWIDTH] <= ram[addr1+j];
    end
  end
`else
  //BRAMs available in VTR FPGA architectures have one bit write-enables.
  //So let's combine multiple bits into 1. We don't have a usecase of
  //writing/not-writing only parts of the word anyway.
  wire we0_coalesced;
  assign we0_coalesced = |we0;
  wire we1_coalesced;
  assign we1_coalesced = |we1;

  dual_port_ram u_dual_port_ram(
    .addr1(addr0),
    .we1(we0_coalesced),
    .data1(d0),
    .out1(q0),
    .addr2(addr1),
    .we2(we1_coalesced),
    .data2(d1),
    .out2(q1),
    .clk(clk)
  );
`endif
endmodule
//////////////////////////////////////////////
//4x4 systolic matrix multiplier
//////////////////////////////////////////////
// One 4x4 systolic building block.
//
// Wires together:
//   * a cycle counter that produces done_mat_mul,
//   * systolic_data_setup (memory addressing and operand staging),
//   * muxes that choose between locally staged operands and operands
//     arriving from a neighbouring block,
//   * systolic_pe_matrix (the processing elements),
//   * output_logic (captures the matrixCxy results and shifts them out).
//
// a_loc / b_loc only steer the operand muxes:
//   b_loc == 0 -> A operands come from this block's own a_data path,
//                 otherwise from a_data_in;
//   a_loc == 0 -> B operands come from this block's own b_data path,
//                 otherwise from b_data_in.
module matmul_4x4_systolic(
  clk,
  reset,
  pe_reset,
  start_mat_mul,
  done_mat_mul,
  address_mat_a,
  address_mat_b,
  address_mat_c,
  address_stride_a,
  address_stride_b,
  address_stride_c,
  a_data,
  b_data,
  a_data_in, //Data values coming in from previous matmul - systolic connections
  b_data_in, //Data values coming in from previous matmul - systolic connections
  c_data_in, //Data values coming in from previous matmul - systolic shifting
  c_data_out,//Data values going out to next matmul - systolic shifting
  a_data_out,
  b_data_out,
  a_addr,
  b_addr,
  c_addr,
  c_data_available,
  validity_mask_a_rows,
  validity_mask_a_cols,
  validity_mask_b_rows,
  validity_mask_b_cols,
  final_mat_mul_size,
  a_loc,
  b_loc
);

  input clk;
  input reset;
  input pe_reset;
  input start_mat_mul;
  output done_mat_mul;
  input [`AWIDTH-1:0] address_mat_a;
  input [`AWIDTH-1:0] address_mat_b;
  input [`AWIDTH-1:0] address_mat_c;
  input [`ADDR_STRIDE_WIDTH-1:0] address_stride_a;
  input [`ADDR_STRIDE_WIDTH-1:0] address_stride_b;
  input [`ADDR_STRIDE_WIDTH-1:0] address_stride_c;
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data;
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data;
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_in;
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_in;
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_in;
  output [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out;
  output [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_out;
  output [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_out;
  output [`AWIDTH-1:0] a_addr;
  output [`AWIDTH-1:0] b_addr;
  output [`AWIDTH-1:0] c_addr;
  output c_data_available;
  input [`MASK_WIDTH-1:0] validity_mask_a_rows;
  input [`MASK_WIDTH-1:0] validity_mask_a_cols;
  input [`MASK_WIDTH-1:0] validity_mask_b_rows;
  input [`MASK_WIDTH-1:0] validity_mask_b_cols;
  //7:0 is okay here. We aren't going to make a matmul larger than 128x128
  //In fact, these will get optimized out by the synthesis tool, because
  //we hardcode them at the instantiation level.
  input [7:0] final_mat_mul_size;
  input [7:0] a_loc;
  input [7:0] b_loc;

  //////////////////////////////////////////////////////////////////////////
  // Logic for clock counting and when to assert done
  //////////////////////////////////////////////////////////////////////////
  reg done_mat_mul;
  //This is 8 bits because the expectation is that clock count will be pretty
  //small. For large matmuls, this will need to be increased to have more bits.
  //In general, a systolic multiplier takes 4*N-2+P cycles, where N is the size
  //of the matmul and P is the number of pipeline stages in the MAC block.
  reg [7:0] clk_cnt;

  //Finding out number of cycles to assert matmul done.
  //When we have to save the outputs to accumulators, then we don't need to
  //shift out data. So, we can assert done_mat_mul early.
  //In the normal case, we have to include the time to shift out the results.
  //Note: the count expression used to contain "4*final_mat_mul_size", but
  //to avoid multiplication, we now use "final_mat_mul_size<<2"
  wire [7:0] clk_cnt_for_done;
  assign clk_cnt_for_done = ((final_mat_mul_size<<2) - 2 + `NUM_CYCLES_IN_MAC) ;

  // The counter runs while start_mat_mul is high. done_mat_mul goes high
  // for one cycle after clk_cnt reaches clk_cnt_for_done and is cleared
  // again on the following cycle.
  always @(posedge clk) begin
    if (reset || ~start_mat_mul) begin
      clk_cnt <= 0;
      done_mat_mul <= 0;
    end
    else if (clk_cnt == clk_cnt_for_done) begin
      done_mat_mul <= 1;
      clk_cnt <= clk_cnt + 1;
    end
    else if (done_mat_mul == 0) begin
      clk_cnt <= clk_cnt + 1;
    end
    else begin
      done_mat_mul <= 0;
      clk_cnt <= clk_cnt + 1;
    end
  end

  // Operands staged by systolic_data_setup: element k of A/B is delayed by
  // k cycles according to the output names.
  wire [`DWIDTH-1:0] a0_data;
  wire [`DWIDTH-1:0] a1_data_delayed_1;
  wire [`DWIDTH-1:0] a2_data_delayed_2;
  wire [`DWIDTH-1:0] a3_data_delayed_3;
  wire [`DWIDTH-1:0] b0_data;
  wire [`DWIDTH-1:0] b1_data_delayed_1;
  wire [`DWIDTH-1:0] b2_data_delayed_2;
  wire [`DWIDTH-1:0] b3_data_delayed_3;

  //////////////////////////////////////////////////////////////////////////
  // Instantiation of systolic data setup
  //////////////////////////////////////////////////////////////////////////
  systolic_data_setup u_systolic_data_setup(
    .clk(clk),
    .reset(reset),
    .start_mat_mul(start_mat_mul),
    .a_addr(a_addr),
    .b_addr(b_addr),
    .address_mat_a(address_mat_a),
    .address_mat_b(address_mat_b),
    .address_stride_a(address_stride_a),
    .address_stride_b(address_stride_b),
    .a_data(a_data),
    .b_data(b_data),
    .clk_cnt(clk_cnt),
    .a0_data(a0_data),
    .a1_data_delayed_1(a1_data_delayed_1),
    .a2_data_delayed_2(a2_data_delayed_2),
    .a3_data_delayed_3(a3_data_delayed_3),
    .b0_data(b0_data),
    .b1_data_delayed_1(b1_data_delayed_1),
    .b2_data_delayed_2(b2_data_delayed_2),
    .b3_data_delayed_3(b3_data_delayed_3),
    .validity_mask_a_rows(validity_mask_a_rows),
    .validity_mask_a_cols(validity_mask_a_cols),
    .validity_mask_b_rows(validity_mask_b_rows),
    .validity_mask_b_cols(validity_mask_b_cols),
    .final_mat_mul_size(final_mat_mul_size),
    .a_loc(a_loc),
    .b_loc(b_loc)
  );

  //////////////////////////////////////////////////////////////////////////
  // Logic to mux data_in coming from neighboring matmuls
  //////////////////////////////////////////////////////////////////////////
  wire [`DWIDTH-1:0] a0;
  wire [`DWIDTH-1:0] a1;
  wire [`DWIDTH-1:0] a2;
  wire [`DWIDTH-1:0] a3;
  wire [`DWIDTH-1:0] b0;
  wire [`DWIDTH-1:0] b1;
  wire [`DWIDTH-1:0] b2;
  wire [`DWIDTH-1:0] b3;

  // Unpack the neighbour buses into `DWIDTH-bit elements (element 0 in the
  // least significant bits).
  wire [`DWIDTH-1:0] a0_data_in;
  wire [`DWIDTH-1:0] a1_data_in;
  wire [`DWIDTH-1:0] a2_data_in;
  wire [`DWIDTH-1:0] a3_data_in;
  assign a0_data_in = a_data_in[`DWIDTH-1:0];
  assign a1_data_in = a_data_in[2*`DWIDTH-1:`DWIDTH];
  assign a2_data_in = a_data_in[3*`DWIDTH-1:2*`DWIDTH];
  assign a3_data_in = a_data_in[4*`DWIDTH-1:3*`DWIDTH];

  wire [`DWIDTH-1:0] b0_data_in;
  wire [`DWIDTH-1:0] b1_data_in;
  wire [`DWIDTH-1:0] b2_data_in;
  wire [`DWIDTH-1:0] b3_data_in;
  assign b0_data_in = b_data_in[`DWIDTH-1:0];
  assign b1_data_in = b_data_in[2*`DWIDTH-1:`DWIDTH];
  assign b2_data_in = b_data_in[3*`DWIDTH-1:2*`DWIDTH];
  assign b3_data_in = b_data_in[4*`DWIDTH-1:3*`DWIDTH];

  //If b_loc is 0, the A operands are taken from this block's own staging
  //logic, i.e. from memory via a_data.
  //If b_loc != 0, the A operands are taken from a_data_in, i.e. from the
  //a_data_out of the neighbouring matmul block.
  assign a0 = (b_loc==0) ? a0_data           : a0_data_in;
  assign a1 = (b_loc==0) ? a1_data_delayed_1 : a1_data_in;
  assign a2 = (b_loc==0) ? a2_data_delayed_2 : a2_data_in;
  assign a3 = (b_loc==0) ? a3_data_delayed_3 : a3_data_in;

  //If a_loc is 0, the B operands are taken from this block's own staging
  //logic, i.e. from memory via b_data.
  //If a_loc != 0, the B operands are taken from b_data_in, i.e. from the
  //b_data_out of the neighbouring matmul block.
  assign b0 = (a_loc==0) ? b0_data           : b0_data_in;
  assign b1 = (a_loc==0) ? b1_data_delayed_1 : b1_data_in;
  assign b2 = (a_loc==0) ? b2_data_delayed_2 : b2_data_in;
  assign b3 = (a_loc==0) ? b3_data_delayed_3 : b3_data_in;

  // Results of the 16 processing elements (matrixC<row><col> by name).
  wire [`DWIDTH-1:0] matrixC00;
  wire [`DWIDTH-1:0] matrixC01;
  wire [`DWIDTH-1:0] matrixC02;
  wire [`DWIDTH-1:0] matrixC03;
  wire [`DWIDTH-1:0] matrixC10;
  wire [`DWIDTH-1:0] matrixC11;
  wire [`DWIDTH-1:0] matrixC12;
  wire [`DWIDTH-1:0] matrixC13;
  wire [`DWIDTH-1:0] matrixC20;
  wire [`DWIDTH-1:0] matrixC21;
  wire [`DWIDTH-1:0] matrixC22;
  wire [`DWIDTH-1:0] matrixC23;
  wire [`DWIDTH-1:0] matrixC30;
  wire [`DWIDTH-1:0] matrixC31;
  wire [`DWIDTH-1:0] matrixC32;
  wire [`DWIDTH-1:0] matrixC33;

  // Output of output_logic that is not used inside this module. Declared
  // explicitly rather than relying on an implicit net.
  wire row_latch_en;

  //////////////////////////////////////////////////////////////////////////
  // Instantiation of the output logic
  //////////////////////////////////////////////////////////////////////////
  output_logic u_output_logic(
    .clk(clk),
    .reset(reset),
    .start_mat_mul(start_mat_mul),
    .done_mat_mul(done_mat_mul),
    .address_mat_c(address_mat_c),
    .address_stride_c(address_stride_c),
    .c_data_out(c_data_out),
    .c_data_in(c_data_in),
    .c_addr(c_addr),
    .c_data_available(c_data_available),
    .clk_cnt(clk_cnt),
    .row_latch_en(row_latch_en),
    .final_mat_mul_size(final_mat_mul_size),
    .matrixC00(matrixC00),
    .matrixC01(matrixC01),
    .matrixC02(matrixC02),
    .matrixC03(matrixC03),
    .matrixC10(matrixC10),
    .matrixC11(matrixC11),
    .matrixC12(matrixC12),
    .matrixC13(matrixC13),
    .matrixC20(matrixC20),
    .matrixC21(matrixC21),
    .matrixC22(matrixC22),
    .matrixC23(matrixC23),
    .matrixC30(matrixC30),
    .matrixC31(matrixC31),
    .matrixC32(matrixC32),
    .matrixC33(matrixC33)
  );

  //////////////////////////////////////////////////////////////////////////
  // Instantiations of the actual processing elements
  //////////////////////////////////////////////////////////////////////////
  systolic_pe_matrix u_systolic_pe_matrix(
    .reset(reset),
    .clk(clk),
    .pe_reset(pe_reset),
    .a0(a0),
    .a1(a1),
    .a2(a2),
    .a3(a3),
    .b0(b0),
    .b1(b1),
    .b2(b2),
    .b3(b3),
    .matrixC00(matrixC00),
    .matrixC01(matrixC01),
    .matrixC02(matrixC02),
    .matrixC03(matrixC03),
    .matrixC10(matrixC10),
    .matrixC11(matrixC11),
    .matrixC12(matrixC12),
    .matrixC13(matrixC13),
    .matrixC20(matrixC20),
    .matrixC21(matrixC21),
    .matrixC22(matrixC22),
    .matrixC23(matrixC23),
    .matrixC30(matrixC30),
    .matrixC31(matrixC31),
    .matrixC32(matrixC32),
    .matrixC33(matrixC33),
    .a_data_out(a_data_out),
    .b_data_out(b_data_out)
  );
endmodule
//////////////////////////////////////////////////////////////////////////
// Output logic
//////////////////////////////////////////////////////////////////////////
//----------------------------------------------------------------------
// output_logic
//
// Latches the 16 results coming from the PE array (matrixC00..matrixC33)
// and shifts them out on c_data_out, one packed column
// (MAT_MUL_SIZE x DWIDTH bits) per clock, while generating the
// accompanying address c_addr and the qualifier c_data_available.
//
// Inputs
//   clk, reset          - clock and synchronous, active-high reset
//   start_mat_mul       - while low, the block is held in its idle state
//   done_mat_mul        - clears the outputs (lower priority than the
//                         capture cycle, see the always block below)
//   address_mat_c       - base address used to derive c_addr
//   address_stride_c    - amount subtracted from c_addr per shifted column
//   c_data_in           - data appended behind the locally captured columns
//                         (presumably from a neighbouring matmul - confirm
//                         against the instantiating module)
//   clk_cnt             - cycle count compared against the latch threshold
//   final_mat_mul_size  - size term used to compute the latch threshold
//   matrixCrc           - PE result for row r, column c
// Outputs
//   c_data_out          - column currently being shifted out
//   c_addr              - address accompanying c_data_out
//   c_data_available    - set when capture starts, cleared on idle/done
//   row_latch_en        - high when clk_cnt equals the latch threshold
//----------------------------------------------------------------------
module output_logic(
clk,
reset,
start_mat_mul,
done_mat_mul,
address_mat_c,
address_stride_c,
c_data_in,
c_data_out, //Data values going out to next matmul - systolic shifting
c_addr,
c_data_available,
clk_cnt,
row_latch_en,
final_mat_mul_size,
matrixC00,
matrixC01,
matrixC02,
matrixC03,
matrixC10,
matrixC11,
matrixC12,
matrixC13,
matrixC20,
matrixC21,
matrixC22,
matrixC23,
matrixC30,
matrixC31,
matrixC32,
matrixC33
);
input clk;
input reset;
input start_mat_mul;
input done_mat_mul;
input [`AWIDTH-1:0] address_mat_c;
input [`ADDR_STRIDE_WIDTH-1:0] address_stride_c;
input [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_in;
output [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out;
output [`AWIDTH-1:0] c_addr;
output c_data_available;
input [7:0] clk_cnt;
output row_latch_en;
input [7:0] final_mat_mul_size;
input [`DWIDTH-1:0] matrixC00;
input [`DWIDTH-1:0] matrixC01;
input [`DWIDTH-1:0] matrixC02;
input [`DWIDTH-1:0] matrixC03;
input [`DWIDTH-1:0] matrixC10;
input [`DWIDTH-1:0] matrixC11;
input [`DWIDTH-1:0] matrixC12;
input [`DWIDTH-1:0] matrixC13;
input [`DWIDTH-1:0] matrixC20;
input [`DWIDTH-1:0] matrixC21;
input [`DWIDTH-1:0] matrixC22;
input [`DWIDTH-1:0] matrixC23;
input [`DWIDTH-1:0] matrixC30;
input [`DWIDTH-1:0] matrixC31;
input [`DWIDTH-1:0] matrixC32;
input [`DWIDTH-1:0] matrixC33;
wire row_latch_en;
//////////////////////////////////////////////////////////////////////////
// Logic to capture matrix C data from the PEs and shift it out
//////////////////////////////////////////////////////////////////////////
//Latch threshold: (size<<2) - size is 3*size, so row_latch_en is high in the
//cycle where clk_cnt == 3*final_mat_mul_size - 1 + NUM_CYCLES_IN_MAC.
//(Presumably the cycle in which the last PE result becomes valid - confirm
//against the clk_cnt generator in the parent module.)
assign row_latch_en = ((clk_cnt == ((final_mat_mul_size<<2) - final_mat_mul_size -1 +`NUM_CYCLES_IN_MAC)));
reg c_data_available;
reg [`AWIDTH-1:0] c_addr;
//Set once the PE results have been captured; enables the shifting branch.
reg start_capturing_c_data;
//Number of columns captured/shifted so far. Incremented only while it is
//below MAT_MUL_SIZE (see the always block), cleared in the idle state.
integer counter;
//c_data_out is the head of a 4-deep shift pipeline; _1.._3 are the stages
//behind it. New data enters at c_data_out_3.
reg [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out;
reg [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out_1;
reg [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out_2;
reg [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out_3;
//Each colN packs column N of the result block, with row 0 in the least
//significant DWIDTH bits and row 3 in the most significant.
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] col0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] col1;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] col2;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] col3;
assign col0 = {matrixC30, matrixC20, matrixC10, matrixC00};
assign col1 = {matrixC31, matrixC21, matrixC11, matrixC01};
assign col2 = {matrixC32, matrixC22, matrixC12, matrixC02};
assign col3 = {matrixC33, matrixC23, matrixC13, matrixC03};
//Shifting starts in the cycle row_latch_en is high. (An earlier version of
//this comment referred to a save_output_to_accum signal; no such signal
//exists in this module, so row_latch_en is the only start condition.)
wire condition_to_start_shifting_output;
assign condition_to_start_shifting_output = row_latch_en ;
//For larger matmuls, this logic will have more entries in the case statement
//
//Priority of the branches below (first match wins):
//  1. reset or start_mat_mul low : idle, everything cleared
//  2. row_latch_en               : capture all four columns from the PEs
//  3. done_mat_mul               : clear outputs (counter is left untouched)
//  4. counter >= MAT_MUL_SIZE    : keep shifting, pulling in c_data_in
//  5. start_capturing_c_data     : shift and count
//Registers not assigned in the taken branch hold their value.
always @(posedge clk) begin
if (reset | ~start_mat_mul) begin
start_capturing_c_data <= 1'b0;
c_data_available <= 1'b0;
//Preloaded one stride above the base so that the first decrement
//(in the capture branch) yields address_mat_c.
c_addr <= address_mat_c+address_stride_c;
c_data_out <= 0;
counter <= 0;
c_data_out_1 <= 0;
c_data_out_2 <= 0;
c_data_out_3 <= 0;
end
else if (condition_to_start_shifting_output) begin
//Capture cycle: column 0 goes straight to the output, columns 1..3
//queue up behind it.
start_capturing_c_data <= 1'b1;
c_data_available <= 1'b1;
c_addr <= c_addr - address_stride_c;
c_data_out <= col0;
c_data_out_1 <= col1;
c_data_out_2 <= col2;
c_data_out_3 <= col3;
counter <= counter + 1;
end
else if (done_mat_mul) begin
//Same values as the idle branch except that counter is not cleared here.
start_capturing_c_data <= 1'b0;
c_data_available <= 1'b0;
c_addr <= address_mat_c+address_stride_c;
c_data_out <= 0;
c_data_out_1 <= 0;
c_data_out_2 <= 0;
c_data_out_3 <= 0;
end
else if (counter >= `MAT_MUL_SIZE) begin
//counter has reached MAT_MUL_SIZE: the pipeline keeps advancing and
//c_data_in is fed in at the tail. counter and c_data_available are not
//assigned here and therefore hold.
c_addr <= c_addr - address_stride_c;
c_data_out <= c_data_out_1;
c_data_out_1 <= c_data_out_2;
c_data_out_2 <= c_data_out_3;
c_data_out_3 <= c_data_in;
end
else if (start_capturing_c_data) begin
//Advance the pipeline by one stage and count the shift; c_data_in
//fills the vacated tail stage.
c_data_available <= 1'b1;
c_addr <= c_addr - address_stride_c;
counter <= counter + 1;
c_data_out <= c_data_out_1;
c_data_out_1 <= c_data_out_2;
c_data_out_2 <= c_data_out_3;
c_data_out_3 <= c_data_in;
end
end
endmodule
//////////////////////////////////////////////////////////////////////////
// Systolic data setup
//////////////////////////////////////////////////////////////////////////
//----------------------------------------------------------------------
// systolic_data_setup
//
// Generates the read addresses for the A and B operand memories and
// turns each returned word (MAT_MUL_SIZE elements of DWIDTH bits) into
// four element streams. Lane 0 is passed through combinationally,
// lane k (k = 1..3) is delayed by k clocks.
//
// For each operand X (a or b):
//   * X_addr is parked at address_mat_X - address_stride_X while idle and
//     advances by address_stride_X on every clock inside the fetch window
//     [X_loc << LOG2_MAT_MUL_SIZE, (X_loc << LOG2_MAT_MUL_SIZE) + final_mat_mul_size).
//   * A counter tracks how many consecutive clocks memory has been
//     accessed. Data is treated as valid once that counter has reached
//     MEM_ACCESS_LATENCY, unless the validity mask bit belonging to the
//     current counter value (1..4 -> mask bits 0..3) is zero.
//   * Every lane is additionally gated by its own per-lane mask bit.
//
// The mask pairing is asymmetric: A uses validity_mask_a_cols against the
// counter and validity_mask_a_rows per lane, B uses validity_mask_b_rows
// against the counter and validity_mask_b_cols per lane.
//----------------------------------------------------------------------
module systolic_data_setup(
    input                                  clk,
    input                                  reset,
    input                                  start_mat_mul,
    output reg [`AWIDTH-1:0]               a_addr,
    output reg [`AWIDTH-1:0]               b_addr,
    input      [`AWIDTH-1:0]               address_mat_a,
    input      [`AWIDTH-1:0]               address_mat_b,
    input      [`ADDR_STRIDE_WIDTH-1:0]    address_stride_a,
    input      [`ADDR_STRIDE_WIDTH-1:0]    address_stride_b,
    input      [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data,
    input      [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data,
    input      [7:0]                       clk_cnt,
    output     [`DWIDTH-1:0]               a0_data,
    output reg [`DWIDTH-1:0]               a1_data_delayed_1,
    output reg [`DWIDTH-1:0]               a2_data_delayed_2,
    output reg [`DWIDTH-1:0]               a3_data_delayed_3,
    output     [`DWIDTH-1:0]               b0_data,
    output reg [`DWIDTH-1:0]               b1_data_delayed_1,
    output reg [`DWIDTH-1:0]               b2_data_delayed_2,
    output reg [`DWIDTH-1:0]               b3_data_delayed_3,
    input      [`MASK_WIDTH-1:0]           validity_mask_a_rows,
    input      [`MASK_WIDTH-1:0]           validity_mask_a_cols,
    input      [`MASK_WIDTH-1:0]           validity_mask_b_rows,
    input      [`MASK_WIDTH-1:0]           validity_mask_b_cols,
    input      [7:0]                       final_mat_mul_size,
    input      [7:0]                       a_loc,
    input      [7:0]                       b_loc
);

    //------------------------------------------------------------------
    // Terms shared by the A and B halves
    //------------------------------------------------------------------
    // Nothing runs while in reset or before start_mat_mul is raised.
    wire idle;
    assign idle = reset || ~start_mat_mul;

    // The delay flops are additionally cleared while clk_cnt is zero.
    wire flush_delay_line;
    assign flush_delay_line = idle || clk_cnt==0;

    //==================================================================
    // Operand A
    //==================================================================

    // Fetch window bounds. The shift replaces a multiplication by
    // MAT_MUL_SIZE; both bounds are 8 bits wide, like clk_cnt.
    wire [7:0] a_fetch_first;
    wire [7:0] a_fetch_limit;
    assign a_fetch_first = a_loc << `LOG2_MAT_MUL_SIZE;
    assign a_fetch_limit = a_fetch_first + final_mat_mul_size;

    // Address generator and "memory is being read" flag.
    reg a_fetching;
    always @(posedge clk) begin
        if (idle || (clk_cnt >= a_fetch_limit)) begin
            // Parked one stride below the base so the first step lands on it.
            a_addr     <= address_mat_a-address_stride_a;
            a_fetching <= 1'b0;
        end
        else if ((clk_cnt >= a_fetch_first) && (clk_cnt < a_fetch_limit)) begin
            a_addr     <= a_addr + address_stride_a;
            a_fetching <= 1'b1;
        end
    end

    // Counts consecutive clocks with a_fetching high; zero otherwise.
    reg [7:0] a_fetch_count;
    always @(posedge clk) begin
        if (idle) begin
            a_fetch_count <= 8'd0;
        end
        else if (a_fetching == 1'b1) begin
            a_fetch_count <= a_fetch_count + 8'd1;
        end
        else begin
            a_fetch_count <= 8'd0;
        end
    end

    // High when the word arriving for the current count is masked out.
    wire a_word_masked;
    assign a_word_masked =
        ((a_fetch_count==1) && (validity_mask_a_cols[0]==1'b0)) ||
        ((a_fetch_count==2) && (validity_mask_a_cols[1]==1'b0)) ||
        ((a_fetch_count==3) && (validity_mask_a_cols[2]==1'b0)) ||
        ((a_fetch_count==4) && (validity_mask_a_cols[3]==1'b0));

    wire a_word_valid;
    assign a_word_valid = a_word_masked ? 1'b0 : (a_fetch_count >= `MEM_ACCESS_LATENCY);

    // Split the memory word into lanes; zero a lane when the word is not
    // valid or the lane's mask bit is clear.
    wire [`DWIDTH-1:0] a_valid_fill;
    assign a_valid_fill = {`DWIDTH{a_word_valid}};

    wire [`DWIDTH-1:0] a1_lane;
    wire [`DWIDTH-1:0] a2_lane;
    wire [`DWIDTH-1:0] a3_lane;
    assign a0_data = a_data[`DWIDTH-1:0]            & a_valid_fill & {`DWIDTH{validity_mask_a_rows[0]}};
    assign a1_lane = a_data[2*`DWIDTH-1:`DWIDTH]    & a_valid_fill & {`DWIDTH{validity_mask_a_rows[1]}};
    assign a2_lane = a_data[3*`DWIDTH-1:2*`DWIDTH]  & a_valid_fill & {`DWIDTH{validity_mask_a_rows[2]}};
    assign a3_lane = a_data[4*`DWIDTH-1:3*`DWIDTH]  & a_valid_fill & {`DWIDTH{validity_mask_a_rows[3]}};

    // Skew registers: lane 1 -> 1 flop, lane 2 -> 2 flops, lane 3 -> 3 flops.
    // For larger matmuls, more such delaying flops will be needed.
    reg [`DWIDTH-1:0] a2_stage1;
    reg [`DWIDTH-1:0] a3_stage1;
    reg [`DWIDTH-1:0] a3_stage2;
    always @(posedge clk) begin
        if (flush_delay_line) begin
            a1_data_delayed_1 <= 0;
            a2_stage1         <= 0;
            a2_data_delayed_2 <= 0;
            a3_stage1         <= 0;
            a3_stage2         <= 0;
            a3_data_delayed_3 <= 0;
        end
        else begin
            // lane 1
            a1_data_delayed_1 <= a1_lane;
            // lane 2
            a2_stage1         <= a2_lane;
            a2_data_delayed_2 <= a2_stage1;
            // lane 3
            a3_stage1         <= a3_lane;
            a3_stage2         <= a3_stage1;
            a3_data_delayed_3 <= a3_stage2;
        end
    end

    //==================================================================
    // Operand B (same structure as operand A)
    //==================================================================

    wire [7:0] b_fetch_first;
    wire [7:0] b_fetch_limit;
    assign b_fetch_first = b_loc << `LOG2_MAT_MUL_SIZE;
    assign b_fetch_limit = b_fetch_first + final_mat_mul_size;

    reg b_fetching;
    always @(posedge clk) begin
        if (idle || (clk_cnt >= b_fetch_limit)) begin
            b_addr     <= address_mat_b - address_stride_b;
            b_fetching <= 1'b0;
        end
        else if ((clk_cnt >= b_fetch_first) && (clk_cnt < b_fetch_limit)) begin
            b_addr     <= b_addr + address_stride_b;
            b_fetching <= 1'b1;
        end
    end

    reg [7:0] b_fetch_count;
    always @(posedge clk) begin
        if (idle) begin
            b_fetch_count <= 8'd0;
        end
        else if (b_fetching == 1'b1) begin
            b_fetch_count <= b_fetch_count + 8'd1;
        end
        else begin
            b_fetch_count <= 8'd0;
        end
    end

    wire b_word_masked;
    assign b_word_masked =
        ((b_fetch_count==1) && (validity_mask_b_rows[0]==1'b0)) ||
        ((b_fetch_count==2) && (validity_mask_b_rows[1]==1'b0)) ||
        ((b_fetch_count==3) && (validity_mask_b_rows[2]==1'b0)) ||
        ((b_fetch_count==4) && (validity_mask_b_rows[3]==1'b0));

    wire b_word_valid;
    assign b_word_valid = b_word_masked ? 1'b0 : (b_fetch_count >= `MEM_ACCESS_LATENCY);

    wire [`DWIDTH-1:0] b_valid_fill;
    assign b_valid_fill = {`DWIDTH{b_word_valid}};

    wire [`DWIDTH-1:0] b1_lane;
    wire [`DWIDTH-1:0] b2_lane;
    wire [`DWIDTH-1:0] b3_lane;
    assign b0_data = b_data[`DWIDTH-1:0]            & b_valid_fill & {`DWIDTH{validity_mask_b_cols[0]}};
    assign b1_lane = b_data[2*`DWIDTH-1:`DWIDTH]    & b_valid_fill & {`DWIDTH{validity_mask_b_cols[1]}};
    assign b2_lane = b_data[3*`DWIDTH-1:2*`DWIDTH]  & b_valid_fill & {`DWIDTH{validity_mask_b_cols[2]}};
    assign b3_lane = b_data[4*`DWIDTH-1:3*`DWIDTH]  & b_valid_fill & {`DWIDTH{validity_mask_b_cols[3]}};

    // For larger matmuls, more such delaying flops will be needed.
    reg [`DWIDTH-1:0] b2_stage1;
    reg [`DWIDTH-1:0] b3_stage1;
    reg [`DWIDTH-1:0] b3_stage2;
    always @(posedge clk) begin
        if (flush_delay_line) begin
            b1_data_delayed_1 <= 0;
            b2_stage1         <= 0;
            b2_data_delayed_2 <= 0;
            b3_stage1         <= 0;
            b3_stage2         <= 0;
            b3_data_delayed_3 <= 0;
        end
        else begin
            // lane 1
            b1_data_delayed_1 <= b1_lane;
            // lane 2
            b2_stage1         <= b2_lane;
            b2_data_delayed_2 <= b2_stage1;
            // lane 3
            b3_stage1         <= b3_lane;
            b3_stage2         <= b3_stage1;
            b3_data_delayed_3 <= b3_stage2;
        end
    end

endmodule
//////////////////////////////////////////////////////////////////////////
// Systolically connected PEs
//////////////////////////////////////////////////////////////////////////
//----------------------------------------------------------------------
// systolic_pe_matrix
//
// 4x4 grid of processing_element instances named pe<row><col>.
//   * a0..a3 enter at column 0 of rows 0..3; each PE forwards its A
//     operand (out_a) to the in_a of the PE one column to the right.
//   * b0..b3 enter at row 0 of columns 0..3; each PE forwards its B
//     operand (out_b) to the in_b of the PE one row below.
//   * matrixC<row><col> is the out_c of pe<row><col>.
//   * a_data_out packs the out_a values leaving column 3 (row 0 in the
//     low DWIDTH bits); b_data_out packs the out_b values leaving row 3
//     (column 0 in the low DWIDTH bits).
// Every PE is reset by (reset | pe_reset).
//----------------------------------------------------------------------
module systolic_pe_matrix(
    input                                  reset,
    input                                  clk,
    input                                  pe_reset,
    input  [`DWIDTH-1:0]                   a0,
    input  [`DWIDTH-1:0]                   a1,
    input  [`DWIDTH-1:0]                   a2,
    input  [`DWIDTH-1:0]                   a3,
    input  [`DWIDTH-1:0]                   b0,
    input  [`DWIDTH-1:0]                   b1,
    input  [`DWIDTH-1:0]                   b2,
    input  [`DWIDTH-1:0]                   b3,
    output [`DWIDTH-1:0]                   matrixC00,
    output [`DWIDTH-1:0]                   matrixC01,
    output [`DWIDTH-1:0]                   matrixC02,
    output [`DWIDTH-1:0]                   matrixC03,
    output [`DWIDTH-1:0]                   matrixC10,
    output [`DWIDTH-1:0]                   matrixC11,
    output [`DWIDTH-1:0]                   matrixC12,
    output [`DWIDTH-1:0]                   matrixC13,
    output [`DWIDTH-1:0]                   matrixC20,
    output [`DWIDTH-1:0]                   matrixC21,
    output [`DWIDTH-1:0]                   matrixC22,
    output [`DWIDTH-1:0]                   matrixC23,
    output [`DWIDTH-1:0]                   matrixC30,
    output [`DWIDTH-1:0]                   matrixC31,
    output [`DWIDTH-1:0]                   matrixC32,
    output [`DWIDTH-1:0]                   matrixC33,
    output [`MAT_MUL_SIZE*`DWIDTH-1:0]     a_data_out,
    output [`MAT_MUL_SIZE*`DWIDTH-1:0]     b_data_out
);

    // a_east_RC : out_a of peRC (drives in_a of the PE in column C+1)
    wire [`DWIDTH-1:0] a_east_00, a_east_01, a_east_02, a_east_03;
    wire [`DWIDTH-1:0] a_east_10, a_east_11, a_east_12, a_east_13;
    wire [`DWIDTH-1:0] a_east_20, a_east_21, a_east_22, a_east_23;
    wire [`DWIDTH-1:0] a_east_30, a_east_31, a_east_32, a_east_33;

    // b_south_RC : out_b of peRC (drives in_b of the PE in row R+1)
    wire [`DWIDTH-1:0] b_south_00, b_south_01, b_south_02, b_south_03;
    wire [`DWIDTH-1:0] b_south_10, b_south_11, b_south_12, b_south_13;
    wire [`DWIDTH-1:0] b_south_20, b_south_21, b_south_22, b_south_23;
    wire [`DWIDTH-1:0] b_south_30, b_south_31, b_south_32, b_south_33;

    // Either reset input clears the whole array.
    wire pe_array_rst;
    assign pe_array_rst = reset | pe_reset;

    //------------------------------ row 0 ------------------------------
    processing_element pe00(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a0),        .in_b(b0),
        .out_a(a_east_00), .out_b(b_south_00),
        .out_c(matrixC00)
    );
    processing_element pe01(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a_east_00), .in_b(b1),
        .out_a(a_east_01), .out_b(b_south_01),
        .out_c(matrixC01)
    );
    processing_element pe02(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a_east_01), .in_b(b2),
        .out_a(a_east_02), .out_b(b_south_02),
        .out_c(matrixC02)
    );
    processing_element pe03(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a_east_02), .in_b(b3),
        .out_a(a_east_03), .out_b(b_south_03),
        .out_c(matrixC03)
    );

    //------------------------------ row 1 ------------------------------
    processing_element pe10(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a1),        .in_b(b_south_00),
        .out_a(a_east_10), .out_b(b_south_10),
        .out_c(matrixC10)
    );
    processing_element pe11(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a_east_10), .in_b(b_south_01),
        .out_a(a_east_11), .out_b(b_south_11),
        .out_c(matrixC11)
    );
    processing_element pe12(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a_east_11), .in_b(b_south_02),
        .out_a(a_east_12), .out_b(b_south_12),
        .out_c(matrixC12)
    );
    processing_element pe13(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a_east_12), .in_b(b_south_03),
        .out_a(a_east_13), .out_b(b_south_13),
        .out_c(matrixC13)
    );

    //------------------------------ row 2 ------------------------------
    processing_element pe20(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a2),        .in_b(b_south_10),
        .out_a(a_east_20), .out_b(b_south_20),
        .out_c(matrixC20)
    );
    processing_element pe21(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a_east_20), .in_b(b_south_11),
        .out_a(a_east_21), .out_b(b_south_21),
        .out_c(matrixC21)
    );
    processing_element pe22(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a_east_21), .in_b(b_south_12),
        .out_a(a_east_22), .out_b(b_south_22),
        .out_c(matrixC22)
    );
    processing_element pe23(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a_east_22), .in_b(b_south_13),
        .out_a(a_east_23), .out_b(b_south_23),
        .out_c(matrixC23)
    );

    //------------------------------ row 3 ------------------------------
    processing_element pe30(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a3),        .in_b(b_south_20),
        .out_a(a_east_30), .out_b(b_south_30),
        .out_c(matrixC30)
    );
    processing_element pe31(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a_east_30), .in_b(b_south_21),
        .out_a(a_east_31), .out_b(b_south_31),
        .out_c(matrixC31)
    );
    processing_element pe32(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a_east_31), .in_b(b_south_22),
        .out_a(a_east_32), .out_b(b_south_32),
        .out_c(matrixC32)
    );
    processing_element pe33(
        .clk(clk), .reset(pe_array_rst),
        .in_a(a_east_32), .in_b(b_south_23),
        .out_a(a_east_33), .out_b(b_south_33),
        .out_c(matrixC33)
    );

    // Operands leaving the array on its right and bottom edges.
    assign a_data_out = {a_east_33, a_east_23, a_east_13, a_east_03};
    assign b_data_out = {b_south_33, b_south_32, b_south_31, b_south_30};

endmodule
//////////////////////////////////////////////////////////////////////////
// Processing element (PE)
//////////////////////////////////////////////////////////////////////////
//----------------------------------------------------------------------
// processing_element
//
// One cell of the systolic array.
//   * in_a / in_b are registered and re-emitted on out_a / out_b one
//     clock later, so neighbouring cells see the operands with a
//     one-cycle skew. Both registers clear to zero on reset.
//   * out_c is driven directly by the mac_fp instance, which receives
//     the unregistered in_a / in_b.
//
// mac_fp is not defined in this file. Per the original note it is
// provided by the architecture description as a mode of the DSP slice
// (16-bit floating point multiply and accumulate).
//----------------------------------------------------------------------
module processing_element(
    input                     reset,
    input                     clk,
    input      [`DWIDTH-1:0]  in_a,
    input      [`DWIDTH-1:0]  in_b,
    output reg [`DWIDTH-1:0]  out_a,
    output reg [`DWIDTH-1:0]  out_b,
    output     [`DWIDTH-1:0]  out_c  //reduced precision
);

    // Multiply-accumulate primitive; its result is the cell's C output.
    mac_fp u_mac(
        .clk(clk),
        .reset(reset),
        .a(in_a),
        .b(in_b),
        .out(out_c)
    );

    // Operand forwarding registers.
    always @(posedge clk) begin
        if (reset) begin
            out_a <= {`DWIDTH{1'b0}};
            out_b <= {`DWIDTH{1'b0}};
        end
        else begin
            out_a <= in_a;
            out_b <= in_b;
        end
    end

endmodule