1663 lines
52 KiB
Coq
1663 lines
52 KiB
Coq
|
////////////////////////////////////////////////
|
||
|
// Matrix multiplication design
|
||
|
// Multiplies 8x8 matrix (A) with another 8x8 matrix (B)
|
||
|
// to produce an 8x8 matrix (C).
|
||
|
// Data precision is IEEE floating point 16-bit (half precision)
|
||
|
// The architecture is systolic in nature (output stationary).
|
||
|
// Four 4x4 matmuls are composed to make a larger 8x8 matmul.
|
||
|
// There is a state machine for control and an APB
|
||
|
// interface for programming/configuring.
|
||
|
// Matrices are stored in RAM blocks.
|
||
|
///////////////////////////////////////////////
|
||
|
|
||
|
`timescale 1ns/1ns

// ---------------------------------------------------------------
// Datapath and memory geometry
// ---------------------------------------------------------------
// Width of one matrix element in bits (IEEE fp16).
`define DWIDTH 16
// BRAM address width; 2**AWIDTH == MEM_SIZE entries.
`define AWIDTH 10
`define MEM_SIZE 1024
// Full matrix dimension (8x8) and the dimension of each systolic tile (4x4).
`define DESIGN_SIZE 8
`define MAT_MUL_SIZE 4
// One write-enable/validity bit per element of a tile row.
`define MASK_WIDTH 4
`define LOG2_MAT_MUL_SIZE 2
// Pipeline depth of the MAC, used when computing when "done" is asserted.
`define NUM_CYCLES_IN_MAC 3
`define MEM_ACCESS_LATENCY 1

// ---------------------------------------------------------------
// APB register interface geometry
// ---------------------------------------------------------------
`define REG_DATAWIDTH 32
`define REG_ADDRWIDTH 8
`define ADDR_STRIDE_WIDTH 16

// ---------------------------------------------------------------
// APB register map (byte addresses decoded from PADDR)
// ---------------------------------------------------------------
// Control/status: bit 0 = start, bit 31 = clear done.
`define REG_STDN_TPU_ADDR 32'h04
// Base addresses of matrices A, B and C in their BRAMs.
`define REG_MATRIX_A_ADDR 32'h0e
`define REG_MATRIX_B_ADDR 32'h12
`define REG_MATRIX_C_ADDR 32'h16
// Row/column validity masks for A and B.
`define REG_VALID_MASK_A_ROWS_ADDR 32'h20
`define REG_VALID_MASK_A_COLS_ADDR 32'h54
`define REG_VALID_MASK_B_ROWS_ADDR 32'h5c
`define REG_VALID_MASK_B_COLS_ADDR 32'h58
// Address strides used when walking matrices A, B and C.
`define REG_MATRIX_A_STRIDE_ADDR 32'h28
`define REG_MATRIX_B_STRIDE_ADDR 32'h32
`define REG_MATRIX_C_STRIDE_ADDR 32'h36
|
||
|
|
||
|
module matrix_multiplication(
  input clk,
  input clk_mem,
  input resetn,
  input pe_resetn,
  input PRESETn,
  input [`REG_ADDRWIDTH-1:0] PADDR,
  input PWRITE,
  input PSEL,
  input PENABLE,
  input [`REG_DATAWIDTH-1:0] PWDATA,
  output reg [`REG_DATAWIDTH-1:0] PRDATA,
  output reg PREADY,
  input [7:0] bram_select,
  input [`AWIDTH-1:0] bram_addr_ext,
  output reg [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_ext,
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_ext,
  input [`MAT_MUL_SIZE-1:0] bram_we_ext
);

//////////////////////////////////////////////////////////////////
// Top level of the 8x8 matmul accelerator:
//  - six dual-port BRAM banks (A 0_0/1_0, B 0_0/0_1, C 0_1/1_1),
//    port 0 used by the compute array, port 1 by the external
//    bram_* interface selected through bram_select;
//  - a small control FSM that launches the matmul;
//  - an APB slave holding the configuration registers.
//////////////////////////////////////////////////////////////////

// The APB interface is clocked by the same clock as the control logic.
wire PCLK;
assign PCLK = clk;

// Software-controlled bits written through REG_STDN_TPU_ADDR.
reg start_reg;       // PWDATA[0]
reg clear_done_reg;  // PWDATA[31]

// Dummy register to sink all writes to (and source reads from)
// invalid/unimplemented addresses.
reg [`REG_DATAWIDTH-1:0] reg_dummy;

//////////////////////////////////////////////////////////////////
// External-port (port 1) signals of each BRAM bank
//////////////////////////////////////////////////////////////////
reg  [`AWIDTH-1:0]               bram_addr_a_0_0_ext;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_a_0_0_ext;
reg  [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_a_0_0_ext;
reg  [`MASK_WIDTH-1:0]           bram_we_a_0_0_ext;

reg  [`AWIDTH-1:0]               bram_addr_a_1_0_ext;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_a_1_0_ext;
reg  [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_a_1_0_ext;
reg  [`MASK_WIDTH-1:0]           bram_we_a_1_0_ext;

reg  [`AWIDTH-1:0]               bram_addr_b_0_0_ext;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_b_0_0_ext;
reg  [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_b_0_0_ext;
reg  [`MASK_WIDTH-1:0]           bram_we_b_0_0_ext;

reg  [`AWIDTH-1:0]               bram_addr_b_0_1_ext;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_b_0_1_ext;
reg  [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_b_0_1_ext;
reg  [`MASK_WIDTH-1:0]           bram_we_b_0_1_ext;

reg  [`AWIDTH-1:0]               bram_addr_c_0_1_ext;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_c_0_1_ext;
reg  [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_c_0_1_ext;
reg  [`MASK_WIDTH-1:0]           bram_we_c_0_1_ext;

reg  [`AWIDTH-1:0]               bram_addr_c_1_1_ext;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_c_1_1_ext;
reg  [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_c_1_1_ext;
reg  [`MASK_WIDTH-1:0]           bram_we_c_1_1_ext;

//////////////////////////////////////////////////////////////////
// Internal-port (port 0) signals of each BRAM bank
//////////////////////////////////////////////////////////////////
wire [`AWIDTH-1:0]               bram_addr_a_0_0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_a_0_0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_a_0_0;
wire [`MASK_WIDTH-1:0]           bram_we_a_0_0;
wire                             bram_en_a_0_0;

wire [`AWIDTH-1:0]               bram_addr_a_1_0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_a_1_0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_a_1_0;
wire [`MASK_WIDTH-1:0]           bram_we_a_1_0;
wire                             bram_en_a_1_0;

wire [`AWIDTH-1:0]               bram_addr_b_0_0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_b_0_0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_b_0_0;
wire [`MASK_WIDTH-1:0]           bram_we_b_0_0;
wire                             bram_en_b_0_0;

wire [`AWIDTH-1:0]               bram_addr_b_0_1;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_b_0_1;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_b_0_1;
wire [`MASK_WIDTH-1:0]           bram_we_b_0_1;
wire                             bram_en_b_0_1;

wire [`AWIDTH-1:0]               bram_addr_c_0_1;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_c_0_1;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_c_0_1;
wire [`MASK_WIDTH-1:0]           bram_we_c_0_1;
wire                             bram_en_c_0_1;

wire [`AWIDTH-1:0]               bram_addr_c_1_1;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_c_1_1;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_c_1_1;
wire [`MASK_WIDTH-1:0]           bram_we_c_1_1;
wire                             bram_en_c_1_1;

//////////////////////////////////////////////////////////////////
// External BRAM access mux.
// bram_select routes the external address/data/write-enable to
// exactly one bank's port 1 and returns that bank's read data.
// Every *_ext signal gets a default first: without it, the banks
// that are NOT selected would infer latches and keep their last
// values (including an asserted write enable) after bram_select
// changes.
//////////////////////////////////////////////////////////////////
always @* begin
  bram_addr_a_0_0_ext  = {`AWIDTH{1'b0}};
  bram_wdata_a_0_0_ext = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
  bram_we_a_0_0_ext    = {`MASK_WIDTH{1'b0}};

  bram_addr_a_1_0_ext  = {`AWIDTH{1'b0}};
  bram_wdata_a_1_0_ext = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
  bram_we_a_1_0_ext    = {`MASK_WIDTH{1'b0}};

  bram_addr_b_0_0_ext  = {`AWIDTH{1'b0}};
  bram_wdata_b_0_0_ext = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
  bram_we_b_0_0_ext    = {`MASK_WIDTH{1'b0}};

  bram_addr_b_0_1_ext  = {`AWIDTH{1'b0}};
  bram_wdata_b_0_1_ext = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
  bram_we_b_0_1_ext    = {`MASK_WIDTH{1'b0}};

  bram_addr_c_0_1_ext  = {`AWIDTH{1'b0}};
  bram_wdata_c_0_1_ext = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
  bram_we_c_0_1_ext    = {`MASK_WIDTH{1'b0}};

  bram_addr_c_1_1_ext  = {`AWIDTH{1'b0}};
  bram_wdata_c_1_1_ext = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
  bram_we_c_1_1_ext    = {`MASK_WIDTH{1'b0}};

  bram_rdata_ext       = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};

  case (bram_select)
    8'd0: begin
      bram_addr_a_0_0_ext  = bram_addr_ext;
      bram_wdata_a_0_0_ext = bram_wdata_ext;
      bram_we_a_0_0_ext    = bram_we_ext;
      bram_rdata_ext       = bram_rdata_a_0_0_ext;
    end
    8'd1: begin
      bram_addr_a_1_0_ext  = bram_addr_ext;
      bram_wdata_a_1_0_ext = bram_wdata_ext;
      bram_we_a_1_0_ext    = bram_we_ext;
      bram_rdata_ext       = bram_rdata_a_1_0_ext;
    end
    8'd2: begin
      bram_addr_b_0_0_ext  = bram_addr_ext;
      bram_wdata_b_0_0_ext = bram_wdata_ext;
      bram_we_b_0_0_ext    = bram_we_ext;
      bram_rdata_ext       = bram_rdata_b_0_0_ext;
    end
    8'd3: begin
      bram_addr_b_0_1_ext  = bram_addr_ext;
      bram_wdata_b_0_1_ext = bram_wdata_ext;
      bram_we_b_0_1_ext    = bram_we_ext;
      bram_rdata_ext       = bram_rdata_b_0_1_ext;
    end
    8'd4: begin
      bram_addr_c_0_1_ext  = bram_addr_ext;
      bram_wdata_c_0_1_ext = bram_wdata_ext;
      bram_we_c_0_1_ext    = bram_we_ext;
      bram_rdata_ext       = bram_rdata_c_0_1_ext;
    end
    8'd5: begin
      bram_addr_c_1_1_ext  = bram_addr_ext;
      bram_wdata_c_1_1_ext = bram_wdata_ext;
      bram_we_c_1_1_ext    = bram_we_ext;
      bram_rdata_ext       = bram_rdata_c_1_1_ext;
    end
    default: begin
      // No bank selected: all banks keep the defaults above and
      // bram_rdata_ext reads as 0.
    end
  endcase
end

/////////////////////////////////////////////////
// BRAMs to store matrix A
/////////////////////////////////////////////////

// BRAM matrix A 0_0 (bank0)
ram matrix_A_0_0(
  .addr0(bram_addr_a_0_0),
  .d0(bram_wdata_a_0_0),
  .we0(bram_we_a_0_0),
  .q0(bram_rdata_a_0_0),
  .addr1(bram_addr_a_0_0_ext),
  .d1(bram_wdata_a_0_0_ext),
  .we1(bram_we_a_0_0_ext),
  .q1(bram_rdata_a_0_0_ext),
  .clk(clk_mem));

// BRAM matrix A 1_0 (bank1)
ram matrix_A_1_0(
  .addr0(bram_addr_a_1_0),
  .d0(bram_wdata_a_1_0),
  .we0(bram_we_a_1_0),
  .q0(bram_rdata_a_1_0),
  .addr1(bram_addr_a_1_0_ext),
  .d1(bram_wdata_a_1_0_ext),
  .we1(bram_we_a_1_0_ext),
  .q1(bram_rdata_a_1_0_ext),
  .clk(clk_mem));

/////////////////////////////////////////////////
// BRAMs to store matrix B
/////////////////////////////////////////////////

// BRAM matrix B 0_0 (bank0)
ram matrix_B_0_0(
  .addr0(bram_addr_b_0_0),
  .d0(bram_wdata_b_0_0),
  .we0(bram_we_b_0_0),
  .q0(bram_rdata_b_0_0),
  .addr1(bram_addr_b_0_0_ext),
  .d1(bram_wdata_b_0_0_ext),
  .we1(bram_we_b_0_0_ext),
  .q1(bram_rdata_b_0_0_ext),
  .clk(clk_mem));

// BRAM matrix B 0_1 (bank1)
ram matrix_B_0_1(
  .addr0(bram_addr_b_0_1),
  .d0(bram_wdata_b_0_1),
  .we0(bram_we_b_0_1),
  .q0(bram_rdata_b_0_1),
  .addr1(bram_addr_b_0_1_ext),
  .d1(bram_wdata_b_0_1_ext),
  .we1(bram_we_b_0_1_ext),
  .q1(bram_rdata_b_0_1_ext),
  .clk(clk_mem));

/////////////////////////////////////////////////
// BRAMs to store matrix C
/////////////////////////////////////////////////

// BRAM matrix C 0_1 (bank0)
ram matrix_C_0_1(
  .addr0(bram_addr_c_0_1),
  .d0(bram_wdata_c_0_1),
  .we0(bram_we_c_0_1),
  .q0(bram_rdata_c_0_1),
  .addr1(bram_addr_c_0_1_ext),
  .d1(bram_wdata_c_0_1_ext),
  .we1(bram_we_c_0_1_ext),
  .q1(bram_rdata_c_0_1_ext),
  .clk(clk_mem));

// BRAM matrix C 1_1 (bank1)
ram matrix_C_1_1(
  .addr0(bram_addr_c_1_1),
  .d0(bram_wdata_c_1_1),
  .we0(bram_we_c_1_1),
  .q0(bram_rdata_c_1_1),
  .addr1(bram_addr_c_1_1_ext),
  .d1(bram_wdata_c_1_1_ext),
  .we1(bram_we_c_1_1_ext),
  .q1(bram_rdata_c_1_1_ext),
  .clk(clk_mem));

reg  start_mat_mul;
wire done_mat_mul;

reg [3:0] state;

////////////////////////////////////////////////////////////////
// Control logic
// 4'b0000 idle     : wait for start_reg
// 4'b0001 launch   : raise start_mat_mul
// 4'b1010 busy     : hold start_mat_mul until done_mat_mul
// 4'b1000 finished : wait for clear_done_reg, then back to idle
////////////////////////////////////////////////////////////////
always @(posedge clk) begin
  if (resetn == 1'b0) begin
    state         <= 4'b0000;
    start_mat_mul <= 1'b0;
  end
  else begin
    case (state)
      4'b0000: begin
        start_mat_mul <= 1'b0;
        if (start_reg == 1'b1) begin
          state <= 4'b0001;
        end else begin
          state <= 4'b0000;
        end
      end

      4'b0001: begin
        start_mat_mul <= 1'b1;
        state         <= 4'b1010;
      end

      4'b1010: begin
        if (done_mat_mul == 1'b1) begin
          start_mat_mul <= 1'b0;
          state         <= 4'b1000;
        end
        else begin
          state <= 4'b1010;
        end
      end

      4'b1000: begin
        if (clear_done_reg == 1'b1) begin
          state <= 4'b0000;
        end
        else begin
          state <= 4'b1000;
        end
      end

      // Unused encodings: recover to idle instead of locking up.
      default: begin
        start_mat_mul <= 1'b0;
        state         <= 4'b0000;
      end
    endcase
  end
end

reg [1:0] state_apb;
`define IDLE     2'b00
`define W_ENABLE 2'b01
`define R_ENABLE 2'b10

reg [`AWIDTH-1:0]            address_mat_a;
reg [`AWIDTH-1:0]            address_mat_b;
reg [`AWIDTH-1:0]            address_mat_c;
reg [`MASK_WIDTH-1:0]        validity_mask_a_rows;
reg [`MASK_WIDTH-1:0]        validity_mask_a_cols;
reg [`MASK_WIDTH-1:0]        validity_mask_b_rows;
reg [`MASK_WIDTH-1:0]        validity_mask_b_cols;
reg [`ADDR_STRIDE_WIDTH-1:0] address_stride_a;
reg [`ADDR_STRIDE_WIDTH-1:0] address_stride_b;
reg [`ADDR_STRIDE_WIDTH-1:0] address_stride_c;

////////////////////////////////////////////////////////////////
// Configuration logic (APB slave).
// IDLE latches the transfer direction when PSEL is seen; the
// W_ENABLE / R_ENABLE state performs the access when PENABLE is
// high, drives PREADY for one cycle, and returns to IDLE.
////////////////////////////////////////////////////////////////
always @(posedge PCLK) begin
  if (PRESETn == 0) begin
    state_apb            <= `IDLE;
    PRDATA               <= 0;
    PREADY               <= 0;
    // Control bits and the dummy sink are reset too, so the control
    // FSM never sees X/random start/clear values coming out of reset.
    start_reg            <= 1'b0;
    clear_done_reg       <= 1'b0;
    reg_dummy            <= 0;
    address_mat_a        <= 0;
    address_mat_b        <= 0;
    address_mat_c        <= 0;
    validity_mask_a_rows <= {`MASK_WIDTH{1'b1}};
    validity_mask_a_cols <= {`MASK_WIDTH{1'b1}};
    validity_mask_b_rows <= {`MASK_WIDTH{1'b1}};
    validity_mask_b_cols <= {`MASK_WIDTH{1'b1}};
    address_stride_a     <= `MAT_MUL_SIZE;
    address_stride_b     <= `MAT_MUL_SIZE;
    address_stride_c     <= `MAT_MUL_SIZE;
  end

  else begin
    case (state_apb)
      `IDLE : begin
        PRDATA <= 0;
        if (PSEL) begin
          if (PWRITE) begin
            state_apb <= `W_ENABLE;
          end
          else begin
            state_apb <= `R_ENABLE;
          end
        end
        PREADY <= 0;
      end

      `W_ENABLE : begin
        if (PSEL && PWRITE && PENABLE) begin
          case (PADDR)
            `REG_STDN_TPU_ADDR : begin
              start_reg      <= PWDATA[0];
              clear_done_reg <= PWDATA[31];
            end
            `REG_MATRIX_A_ADDR : address_mat_a <= PWDATA[`AWIDTH-1:0];
            `REG_MATRIX_B_ADDR : address_mat_b <= PWDATA[`AWIDTH-1:0];
            `REG_MATRIX_C_ADDR : address_mat_c <= PWDATA[`AWIDTH-1:0];
            `REG_VALID_MASK_A_ROWS_ADDR: begin
              validity_mask_a_rows <= PWDATA[`MASK_WIDTH-1:0];
            end
            `REG_VALID_MASK_A_COLS_ADDR: begin
              validity_mask_a_cols <= PWDATA[`MASK_WIDTH-1:0];
            end
            `REG_VALID_MASK_B_ROWS_ADDR: begin
              validity_mask_b_rows <= PWDATA[`MASK_WIDTH-1:0];
            end
            `REG_VALID_MASK_B_COLS_ADDR: begin
              validity_mask_b_cols <= PWDATA[`MASK_WIDTH-1:0];
            end
            `REG_MATRIX_A_STRIDE_ADDR : address_stride_a <= PWDATA[`ADDR_STRIDE_WIDTH-1:0];
            `REG_MATRIX_B_STRIDE_ADDR : address_stride_b <= PWDATA[`ADDR_STRIDE_WIDTH-1:0];
            `REG_MATRIX_C_STRIDE_ADDR : address_stride_c <= PWDATA[`ADDR_STRIDE_WIDTH-1:0];
            default : reg_dummy <= PWDATA; //sink writes to a dummy register
          endcase
          PREADY <= 1;
        end
        state_apb <= `IDLE;
      end

      `R_ENABLE : begin
        if (PSEL && !PWRITE && PENABLE) begin
          PREADY <= 1;
          case (PADDR)
            `REG_STDN_TPU_ADDR : PRDATA <= {done_mat_mul, 30'b0, start_mat_mul};
            `REG_MATRIX_A_ADDR : PRDATA <= address_mat_a;
            `REG_MATRIX_B_ADDR : PRDATA <= address_mat_b;
            `REG_MATRIX_C_ADDR : PRDATA <= address_mat_c;
            `REG_VALID_MASK_A_ROWS_ADDR: PRDATA <= validity_mask_a_rows;
            `REG_VALID_MASK_A_COLS_ADDR: PRDATA <= validity_mask_a_cols;
            `REG_VALID_MASK_B_ROWS_ADDR: PRDATA <= validity_mask_b_rows;
            `REG_VALID_MASK_B_COLS_ADDR: PRDATA <= validity_mask_b_cols;
            `REG_MATRIX_A_STRIDE_ADDR : PRDATA <= address_stride_a;
            `REG_MATRIX_B_STRIDE_ADDR : PRDATA <= address_stride_b;
            `REG_MATRIX_C_STRIDE_ADDR : PRDATA <= address_stride_c;
            default : PRDATA <= reg_dummy; //read the dummy register for undefined addresses
          endcase
        end
        state_apb <= `IDLE;
      end

      default: begin
        state_apb <= `IDLE;
      end
    endcase
  end
end

wire reset;
assign reset = ~resetn;
wire pe_reset;
assign pe_reset = ~pe_resetn;

// Matrix C banks: port 0 is written by the array whenever it reports
// that an output row is available.
wire c_data_0_1_available;
assign bram_en_c_0_1 = 1'b1;
assign bram_we_c_0_1 = (c_data_0_1_available) ? {`MASK_WIDTH{1'b1}} : {`MASK_WIDTH{1'b0}};

wire c_data_1_1_available;
assign bram_en_c_1_1 = 1'b1;
assign bram_we_c_1_1 = (c_data_1_1_available) ? {`MASK_WIDTH{1'b1}} : {`MASK_WIDTH{1'b0}};

// Matrix A/B banks: port 0 is read-only for the array.
assign bram_wdata_a_0_0 = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
assign bram_en_a_0_0    = 1'b1;
assign bram_we_a_0_0    = {`MASK_WIDTH{1'b0}};

assign bram_wdata_a_1_0 = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
assign bram_en_a_1_0    = 1'b1;
assign bram_we_a_1_0    = {`MASK_WIDTH{1'b0}};

assign bram_wdata_b_0_0 = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
assign bram_en_b_0_0    = 1'b1;
assign bram_we_b_0_0    = {`MASK_WIDTH{1'b0}};

assign bram_wdata_b_0_1 = {`MAT_MUL_SIZE*`DWIDTH{1'b0}};
assign bram_en_b_0_1    = 1'b1;
assign bram_we_b_0_1    = {`MASK_WIDTH{1'b0}};

/////////////////////////////////////////////////
// The 8x8 matmul instantiation
/////////////////////////////////////////////////
matmul_8x8_systolic u_matmul_8x8_systolic (
  .clk(clk),
  .reset(reset),
  .pe_reset(pe_reset),
  .start_mat_mul(start_mat_mul),
  .done_mat_mul(done_mat_mul),
  .address_mat_a(address_mat_a),
  .address_mat_b(address_mat_b),
  .address_mat_c(address_mat_c),
  .address_stride_a(address_stride_a),
  .address_stride_b(address_stride_b),
  .address_stride_c(address_stride_c),

  .a_data_0_0(bram_rdata_a_0_0),
  .b_data_0_0(bram_rdata_b_0_0),
  .a_addr_0_0(bram_addr_a_0_0),
  .b_addr_0_0(bram_addr_b_0_0),

  .a_data_1_0(bram_rdata_a_1_0),
  .b_data_0_1(bram_rdata_b_0_1),
  .a_addr_1_0(bram_addr_a_1_0),
  .b_addr_0_1(bram_addr_b_0_1),

  .c_data_0_1(bram_wdata_c_0_1),
  .c_addr_0_1(bram_addr_c_0_1),
  .c_data_0_1_available(c_data_0_1_available),

  .c_data_1_1(bram_wdata_c_1_1),
  .c_addr_1_1(bram_addr_c_1_1),
  .c_data_1_1_available(c_data_1_1_available),

  .validity_mask_a_rows(validity_mask_a_rows),
  .validity_mask_a_cols(validity_mask_a_cols),
  .validity_mask_b_rows(validity_mask_b_rows),
  .validity_mask_b_cols(validity_mask_b_cols)
);
endmodule
|
||
|
|
||
|
/////////////////////////////////////////////////
|
||
|
// The 8x8 matmul definition
|
||
|
/////////////////////////////////////////////////
|
||
|
|
||
|
module matmul_8x8_systolic(
  input clk,
  input reset,
  input pe_reset,
  input start_mat_mul,
  output done_mat_mul,

  input [`AWIDTH-1:0] address_mat_a,
  input [`AWIDTH-1:0] address_mat_b,
  input [`AWIDTH-1:0] address_mat_c,
  input [`ADDR_STRIDE_WIDTH-1:0] address_stride_a,
  input [`ADDR_STRIDE_WIDTH-1:0] address_stride_b,
  input [`ADDR_STRIDE_WIDTH-1:0] address_stride_c,

  input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_0_0,
  output [`AWIDTH-1:0] a_addr_0_0,
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_0_0,
  output [`AWIDTH-1:0] b_addr_0_0,

  input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_1_0,
  output [`AWIDTH-1:0] a_addr_1_0,
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_0_1,
  output [`AWIDTH-1:0] b_addr_0_1,

  output [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_0_1,
  output [`AWIDTH-1:0] c_addr_0_1,
  output c_data_0_1_available,

  output [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_1_1,
  output [`AWIDTH-1:0] c_addr_1_1,
  output c_data_1_1_available,

  input [`MASK_WIDTH-1:0] validity_mask_a_rows,
  input [`MASK_WIDTH-1:0] validity_mask_a_cols,
  input [`MASK_WIDTH-1:0] validity_mask_b_rows,
  input [`MASK_WIDTH-1:0] validity_mask_b_cols
);

/////////////////////////////////////////////////
// 2x2 grid of 4x4 systolic tiles (tile r_c = row r, column c):
//   - A enters the column-0 tiles (0_0, 1_0) from the A BRAMs and is
//     forwarded right through a_data_out -> a_data_in;
//   - B enters the row-0 tiles (0_0, 0_1) from the B BRAMs and is
//     forwarded down through b_data_out -> b_data_in;
//   - C is shifted from the column-0 tiles into the column-1 tiles
//     through c_data_out -> c_data_in, and the column-1 tiles drive
//     the C BRAMs.
/////////////////////////////////////////////////

/////////////////////////////////////////////////
// Done signal
// All four tiles receive the same start_mat_mul and the same
// final_mat_mul_size, so their done pulses coincide; tile 0_0's
// done is used as the done of the whole 8x8 matmul.
/////////////////////////////////////////////////
wire done_mat_mul_0_0;
wire done_mat_mul_0_1;
wire done_mat_mul_1_0;
wire done_mat_mul_1_1;

assign done_mat_mul = done_mat_mul_0_0;

/////////////////////////////////////////////////
// Matmul 0_0
/////////////////////////////////////////////////
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_0_0_to_0_1;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_0_0_to_1_0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_in_0_0_NC;
assign a_data_in_0_0_NC = 0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_in_0_0_NC;
assign c_data_in_0_0_NC = 0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_in_0_0_NC;
assign b_data_in_0_0_NC = 0;
// C results of tile 0_0, shifted into tile 0_1 on their way to C BRAM 0_1.
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_0_0_to_0_1;
wire [`AWIDTH-1:0] c_addr_0_0_NC;
wire c_data_0_0_available_NC;

matmul_4x4_systolic u_matmul_4x4_systolic_0_0(
  .clk(clk),
  .reset(reset),
  .pe_reset(pe_reset),
  .start_mat_mul(start_mat_mul),
  .done_mat_mul(done_mat_mul_0_0),
  .address_mat_a(address_mat_a),
  .address_mat_b(address_mat_b),
  .address_mat_c(address_mat_c),
  .address_stride_a(address_stride_a),
  .address_stride_b(address_stride_b),
  .address_stride_c(address_stride_c),
  .a_data(a_data_0_0),
  .b_data(b_data_0_0),
  .a_data_in(a_data_in_0_0_NC),
  .b_data_in(b_data_in_0_0_NC),
  .c_data_in(c_data_in_0_0_NC),
  // Must drive the chain wire consumed by tile 0_1's c_data_in;
  // previously this went to a dangling wire, leaving 0_1's input undriven.
  .c_data_out(c_data_0_0_to_0_1),
  .a_data_out(a_data_0_0_to_0_1),
  .b_data_out(b_data_0_0_to_1_0),
  .a_addr(a_addr_0_0),
  .b_addr(b_addr_0_0),
  .c_addr(c_addr_0_0_NC),
  .c_data_available(c_data_0_0_available_NC),
  .validity_mask_a_rows({4'b0,validity_mask_a_rows}),
  .validity_mask_a_cols({4'b0,validity_mask_a_cols}),
  .validity_mask_b_rows({4'b0,validity_mask_b_rows}),
  .validity_mask_b_cols({4'b0,validity_mask_b_cols}),
  .final_mat_mul_size(8'd8),
  .a_loc(8'd0),
  .b_loc(8'd0)
);

/////////////////////////////////////////////////
// Matmul 0_1
/////////////////////////////////////////////////
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_0_1_to_0_2;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_0_1_to_1_1;
wire [`AWIDTH-1:0] a_addr_0_1_NC;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_0_1_NC;
assign a_data_0_1_NC = 0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_in_0_1_NC;
assign b_data_in_0_1_NC = 0;

matmul_4x4_systolic u_matmul_4x4_systolic_0_1(
  .clk(clk),
  .reset(reset),
  .pe_reset(pe_reset),
  .start_mat_mul(start_mat_mul),
  .done_mat_mul(done_mat_mul_0_1),
  .address_mat_a(address_mat_a),
  .address_mat_b(address_mat_b),
  .address_mat_c(address_mat_c),
  .address_stride_a(address_stride_a),
  .address_stride_b(address_stride_b),
  .address_stride_c(address_stride_c),
  .a_data(a_data_0_1_NC),
  .b_data(b_data_0_1),
  .a_data_in(a_data_0_0_to_0_1),
  .b_data_in(b_data_in_0_1_NC),
  .c_data_in(c_data_0_0_to_0_1),
  .c_data_out(c_data_0_1),
  .a_data_out(a_data_0_1_to_0_2),
  .b_data_out(b_data_0_1_to_1_1),
  .a_addr(a_addr_0_1_NC),
  .b_addr(b_addr_0_1),
  .c_addr(c_addr_0_1),
  .c_data_available(c_data_0_1_available),
  .validity_mask_a_rows({4'b0,validity_mask_a_rows}),
  .validity_mask_a_cols({4'b0,validity_mask_a_cols}),
  .validity_mask_b_rows({4'b0,validity_mask_b_rows}),
  .validity_mask_b_cols({4'b0,validity_mask_b_cols}),
  .final_mat_mul_size(8'd8),
  .a_loc(8'd0),
  .b_loc(8'd1)
);

/////////////////////////////////////////////////
// Matmul 1_0
/////////////////////////////////////////////////
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_1_0_to_1_1;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_1_0_to_2_0;
wire [`AWIDTH-1:0] b_addr_1_0_NC;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_1_0_NC;
assign b_data_1_0_NC = 0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_in_1_0_NC;
assign a_data_in_1_0_NC = 0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_in_1_0_NC;
assign c_data_in_1_0_NC = 0;
// C results of tile 1_0, shifted into tile 1_1 on their way to C BRAM 1_1.
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_1_0_to_1_1;
wire [`AWIDTH-1:0] c_addr_1_0_NC;
wire c_data_1_0_available_NC;

matmul_4x4_systolic u_matmul_4x4_systolic_1_0(
  .clk(clk),
  .reset(reset),
  .pe_reset(pe_reset),
  .start_mat_mul(start_mat_mul),
  .done_mat_mul(done_mat_mul_1_0),
  .address_mat_a(address_mat_a),
  .address_mat_b(address_mat_b),
  .address_mat_c(address_mat_c),
  .address_stride_a(address_stride_a),
  .address_stride_b(address_stride_b),
  .address_stride_c(address_stride_c),
  .a_data(a_data_1_0),
  .b_data(b_data_1_0_NC),
  .a_data_in(a_data_in_1_0_NC),
  .b_data_in(b_data_0_0_to_1_0),
  .c_data_in(c_data_in_1_0_NC),
  // Must drive the chain wire consumed by tile 1_1's c_data_in.
  .c_data_out(c_data_1_0_to_1_1),
  .a_data_out(a_data_1_0_to_1_1),
  .b_data_out(b_data_1_0_to_2_0),
  .a_addr(a_addr_1_0),
  .b_addr(b_addr_1_0_NC),
  .c_addr(c_addr_1_0_NC),
  .c_data_available(c_data_1_0_available_NC),
  .validity_mask_a_rows({4'b0,validity_mask_a_rows}),
  .validity_mask_a_cols({4'b0,validity_mask_a_cols}),
  .validity_mask_b_rows({4'b0,validity_mask_b_rows}),
  .validity_mask_b_cols({4'b0,validity_mask_b_cols}),
  .final_mat_mul_size(8'd8),
  .a_loc(8'd1),
  .b_loc(8'd0)
);

/////////////////////////////////////////////////
// Matmul 1_1
/////////////////////////////////////////////////
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_1_1_to_1_2;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_1_1_to_2_1;
wire [`AWIDTH-1:0] a_addr_1_1_NC;
wire [`AWIDTH-1:0] b_addr_1_1_NC;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_1_1_NC;
assign a_data_1_1_NC = 0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_1_1_NC;
assign b_data_1_1_NC = 0;

matmul_4x4_systolic u_matmul_4x4_systolic_1_1(
  .clk(clk),
  .reset(reset),
  .pe_reset(pe_reset),
  .start_mat_mul(start_mat_mul),
  .done_mat_mul(done_mat_mul_1_1),
  .address_mat_a(address_mat_a),
  .address_mat_b(address_mat_b),
  .address_mat_c(address_mat_c),
  .address_stride_a(address_stride_a),
  .address_stride_b(address_stride_b),
  .address_stride_c(address_stride_c),
  .a_data(a_data_1_1_NC),
  .b_data(b_data_1_1_NC),
  .a_data_in(a_data_1_0_to_1_1),
  .b_data_in(b_data_0_1_to_1_1),
  .c_data_in(c_data_1_0_to_1_1),
  .c_data_out(c_data_1_1),
  .a_data_out(a_data_1_1_to_1_2),
  .b_data_out(b_data_1_1_to_2_1),
  .a_addr(a_addr_1_1_NC),
  .b_addr(b_addr_1_1_NC),
  .c_addr(c_addr_1_1),
  .c_data_available(c_data_1_1_available),
  .validity_mask_a_rows({4'b0,validity_mask_a_rows}),
  .validity_mask_a_cols({4'b0,validity_mask_a_cols}),
  .validity_mask_b_rows({4'b0,validity_mask_b_rows}),
  .validity_mask_b_cols({4'b0,validity_mask_b_cols}),
  .final_mat_mul_size(8'd8),
  .a_loc(8'd1),
  .b_loc(8'd1)
);

endmodule
|
||
|
|
||
|
|
||
|
//////////////////////////////////
|
||
|
//Dual port RAM
|
||
|
//////////////////////////////////
|
||
|
module ram (
  addr0,
  d0,
  we0,
  q0,
  addr1,
  d1,
  we1,
  q1,
  clk);

//////////////////////////////////
// Dual-port RAM used for every matrix bank.
// Each port reads (and optionally writes) MAT_MUL_SIZE consecutive
// DWIDTH-bit elements starting at its address; the read data is
// registered on the rising edge of clk.
//////////////////////////////////

input [`AWIDTH-1:0] addr0;
input [`AWIDTH-1:0] addr1;
input [`MAT_MUL_SIZE*`DWIDTH-1:0] d0;
input [`MAT_MUL_SIZE*`DWIDTH-1:0] d1;
input [`MAT_MUL_SIZE-1:0] we0;
input [`MAT_MUL_SIZE-1:0] we1;
output [`MAT_MUL_SIZE*`DWIDTH-1:0] q0;
output [`MAT_MUL_SIZE*`DWIDTH-1:0] q1;
input clk;

`ifdef VCS
// Behavioral simulation model.
reg [`MAT_MUL_SIZE*`DWIDTH-1:0] q0;
reg [`MAT_MUL_SIZE*`DWIDTH-1:0] q1;
// One DWIDTH-bit element per address. (This used to be [7:0], which
// silently truncated every DWIDTH-bit element on write.)
reg [`DWIDTH-1:0] ram[((1<<`AWIDTH)-1):0];
// Separate loop variables so the two port processes never share state.
integer i0;
integer i1;

// Port 0: per-element write enable, then registered read of the
// MAT_MUL_SIZE elements at addr0.
always @(posedge clk)
begin
  for (i0 = 0; i0 < `MAT_MUL_SIZE; i0 = i0 + 1) begin
    if (we0[i0]) ram[addr0+i0] <= d0[i0*`DWIDTH +: `DWIDTH];
  end
  for (i0 = 0; i0 < `MAT_MUL_SIZE; i0 = i0 + 1) begin
    q0[i0*`DWIDTH +: `DWIDTH] <= ram[addr0+i0];
  end
end

// Port 1: same as port 0 but addressed by addr1 (the write previously
// used addr0 by mistake).
always @(posedge clk)
begin
  for (i1 = 0; i1 < `MAT_MUL_SIZE; i1 = i1 + 1) begin
    if (we1[i1]) ram[addr1+i1] <= d1[i1*`DWIDTH +: `DWIDTH];
  end
  for (i1 = 0; i1 < `MAT_MUL_SIZE; i1 = i1 + 1) begin
    q1[i1*`DWIDTH +: `DWIDTH] <= ram[addr1+i1];
  end
end

`else
//BRAMs available in VTR FPGA architectures have one bit write-enables.
//So let's combine multiple bits into 1. We don't have a usecase of
//writing/not-writing only parts of the word anyway.
wire we0_coalesced;
assign we0_coalesced = |we0;
wire we1_coalesced;
assign we1_coalesced = |we1;

dual_port_ram u_dual_port_ram(
  .addr1(addr0),
  .we1(we0_coalesced),
  .data1(d0),
  .out1(q0),
  .addr2(addr1),
  .we2(we1_coalesced),
  .data2(d1),
  .out2(q1),
  .clk(clk)
);

`endif

endmodule
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////
|
||
|
//4x4 systolic matrix multiplier
|
||
|
//////////////////////////////////////////////
|
||
|
module matmul_4x4_systolic(
|
||
|
clk,
|
||
|
reset,
|
||
|
pe_reset,
|
||
|
start_mat_mul,
|
||
|
done_mat_mul,
|
||
|
address_mat_a,
|
||
|
address_mat_b,
|
||
|
address_mat_c,
|
||
|
address_stride_a,
|
||
|
address_stride_b,
|
||
|
address_stride_c,
|
||
|
a_data,
|
||
|
b_data,
|
||
|
a_data_in, //Data values coming in from previous matmul - systolic connections
|
||
|
b_data_in, //Data values coming in from previous matmul - systolic connections
|
||
|
c_data_in, //Data values coming in from previous matmul - systolic shifting
|
||
|
c_data_out,//Data values going out to next matmul - systolic shifting
|
||
|
a_data_out,
|
||
|
b_data_out,
|
||
|
a_addr,
|
||
|
b_addr,
|
||
|
c_addr,
|
||
|
c_data_available,
|
||
|
validity_mask_a_rows,
|
||
|
validity_mask_a_cols,
|
||
|
validity_mask_b_rows,
|
||
|
validity_mask_b_cols,
|
||
|
final_mat_mul_size,
|
||
|
a_loc,
|
||
|
b_loc
|
||
|
);
|
||
|
|
||
|
input clk;
|
||
|
input reset;
|
||
|
input pe_reset;
|
||
|
input start_mat_mul;
|
||
|
output done_mat_mul;
|
||
|
input [`AWIDTH-1:0] address_mat_a;
|
||
|
input [`AWIDTH-1:0] address_mat_b;
|
||
|
input [`AWIDTH-1:0] address_mat_c;
|
||
|
input [`ADDR_STRIDE_WIDTH-1:0] address_stride_a;
|
||
|
input [`ADDR_STRIDE_WIDTH-1:0] address_stride_b;
|
||
|
input [`ADDR_STRIDE_WIDTH-1:0] address_stride_c;
|
||
|
input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data;
|
||
|
input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data;
|
||
|
input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_in;
|
||
|
input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_in;
|
||
|
input [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_in;
|
||
|
output [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out;
|
||
|
output [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_out;
|
||
|
output [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_out;
|
||
|
output [`AWIDTH-1:0] a_addr;
|
||
|
output [`AWIDTH-1:0] b_addr;
|
||
|
output [`AWIDTH-1:0] c_addr;
|
||
|
output c_data_available;
|
||
|
input [`MASK_WIDTH-1:0] validity_mask_a_rows;
|
||
|
input [`MASK_WIDTH-1:0] validity_mask_a_cols;
|
||
|
input [`MASK_WIDTH-1:0] validity_mask_b_rows;
|
||
|
input [`MASK_WIDTH-1:0] validity_mask_b_cols;
|
||
|
//7:0 is okay here. We aren't going to make a matmul larger than 128x128
|
||
|
//In fact, these will get optimized out by the synthesis tool, because
|
||
|
//we hardcode them at the instantiation level.
|
||
|
input [7:0] final_mat_mul_size;
|
||
|
input [7:0] a_loc;
|
||
|
input [7:0] b_loc;
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// Logic for clock counting and when to assert done
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
|
||
|
reg done_mat_mul;
|
||
|
//This is 7 bits because the expectation is that clock count will be pretty
|
||
|
//small. For large matmuls, this will need to increased to have more bits.
|
||
|
//In general, a systolic multiplier takes 4*N-2+P cycles, where N is the size
|
||
|
//of the matmul and P is the number of pipleine stages in the MAC block.
|
||
|
reg [7:0] clk_cnt;
|
||
|
|
||
|
//Finding out number of cycles to assert matmul done.
|
||
|
//When we have to save the outputs to accumulators, then we don't need to
|
||
|
//shift out data. So, we can assert done_mat_mul early.
|
||
|
//In the normal case, we have to include the time to shift out the results.
|
||
|
//Note: the count expression used to contain "4*final_mat_mul_size", but
|
||
|
//to avoid multiplication, we now use "final_mat_mul_size<<2"
|
||
|
wire [7:0] clk_cnt_for_done;
|
||
|
assign clk_cnt_for_done = ((final_mat_mul_size<<2) - 2 + `NUM_CYCLES_IN_MAC) ;
|
||
|
|
||
|
always @(posedge clk) begin
|
||
|
if (reset || ~start_mat_mul) begin
|
||
|
clk_cnt <= 0;
|
||
|
done_mat_mul <= 0;
|
||
|
end
|
||
|
else if (clk_cnt == clk_cnt_for_done) begin
|
||
|
done_mat_mul <= 1;
|
||
|
clk_cnt <= clk_cnt + 1;
|
||
|
end
|
||
|
else if (done_mat_mul == 0) begin
|
||
|
clk_cnt <= clk_cnt + 1;
|
||
|
end
|
||
|
else begin
|
||
|
done_mat_mul <= 0;
|
||
|
clk_cnt <= clk_cnt + 1;
|
||
|
end
|
||
|
end
|
||
|
|
||
|
|
||
|
wire [`DWIDTH-1:0] a0_data;
|
||
|
wire [`DWIDTH-1:0] a1_data;
|
||
|
wire [`DWIDTH-1:0] a2_data;
|
||
|
wire [`DWIDTH-1:0] a3_data;
|
||
|
wire [`DWIDTH-1:0] b0_data;
|
||
|
wire [`DWIDTH-1:0] b1_data;
|
||
|
wire [`DWIDTH-1:0] b2_data;
|
||
|
wire [`DWIDTH-1:0] b3_data;
|
||
|
wire [`DWIDTH-1:0] a1_data_delayed_1;
|
||
|
wire [`DWIDTH-1:0] a2_data_delayed_1;
|
||
|
wire [`DWIDTH-1:0] a2_data_delayed_2;
|
||
|
wire [`DWIDTH-1:0] a3_data_delayed_1;
|
||
|
wire [`DWIDTH-1:0] a3_data_delayed_2;
|
||
|
wire [`DWIDTH-1:0] a3_data_delayed_3;
|
||
|
wire [`DWIDTH-1:0] b1_data_delayed_1;
|
||
|
wire [`DWIDTH-1:0] b2_data_delayed_1;
|
||
|
wire [`DWIDTH-1:0] b2_data_delayed_2;
|
||
|
wire [`DWIDTH-1:0] b3_data_delayed_1;
|
||
|
wire [`DWIDTH-1:0] b3_data_delayed_2;
|
||
|
wire [`DWIDTH-1:0] b3_data_delayed_3;
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// Instantiation of systolic data setup
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
systolic_data_setup u_systolic_data_setup(
|
||
|
.clk(clk),
|
||
|
.reset(reset),
|
||
|
.start_mat_mul(start_mat_mul),
|
||
|
.a_addr(a_addr),
|
||
|
.b_addr(b_addr),
|
||
|
.address_mat_a(address_mat_a),
|
||
|
.address_mat_b(address_mat_b),
|
||
|
.address_stride_a(address_stride_a),
|
||
|
.address_stride_b(address_stride_b),
|
||
|
.a_data(a_data),
|
||
|
.b_data(b_data),
|
||
|
.clk_cnt(clk_cnt),
|
||
|
.a0_data(a0_data),
|
||
|
.a1_data_delayed_1(a1_data_delayed_1),
|
||
|
.a2_data_delayed_2(a2_data_delayed_2),
|
||
|
.a3_data_delayed_3(a3_data_delayed_3),
|
||
|
.b0_data(b0_data),
|
||
|
.b1_data_delayed_1(b1_data_delayed_1),
|
||
|
.b2_data_delayed_2(b2_data_delayed_2),
|
||
|
.b3_data_delayed_3(b3_data_delayed_3),
|
||
|
.validity_mask_a_rows(validity_mask_a_rows),
|
||
|
.validity_mask_a_cols(validity_mask_a_cols),
|
||
|
.validity_mask_b_rows(validity_mask_b_rows),
|
||
|
.validity_mask_b_cols(validity_mask_b_cols),
|
||
|
.final_mat_mul_size(final_mat_mul_size),
|
||
|
.a_loc(a_loc),
|
||
|
.b_loc(b_loc)
|
||
|
);
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// Logic to mux data_in coming from neighboring matmuls
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
wire [`DWIDTH-1:0] a0;
|
||
|
wire [`DWIDTH-1:0] a1;
|
||
|
wire [`DWIDTH-1:0] a2;
|
||
|
wire [`DWIDTH-1:0] a3;
|
||
|
wire [`DWIDTH-1:0] b0;
|
||
|
wire [`DWIDTH-1:0] b1;
|
||
|
wire [`DWIDTH-1:0] b2;
|
||
|
wire [`DWIDTH-1:0] b3;
|
||
|
|
||
|
wire [`DWIDTH-1:0] a0_data_in;
|
||
|
wire [`DWIDTH-1:0] a1_data_in;
|
||
|
wire [`DWIDTH-1:0] a2_data_in;
|
||
|
wire [`DWIDTH-1:0] a3_data_in;
|
||
|
assign a0_data_in = a_data_in[`DWIDTH-1:0];
|
||
|
assign a1_data_in = a_data_in[2*`DWIDTH-1:`DWIDTH];
|
||
|
assign a2_data_in = a_data_in[3*`DWIDTH-1:2*`DWIDTH];
|
||
|
assign a3_data_in = a_data_in[4*`DWIDTH-1:3*`DWIDTH];
|
||
|
|
||
|
wire [`DWIDTH-1:0] b0_data_in;
|
||
|
wire [`DWIDTH-1:0] b1_data_in;
|
||
|
wire [`DWIDTH-1:0] b2_data_in;
|
||
|
wire [`DWIDTH-1:0] b3_data_in;
|
||
|
assign b0_data_in = b_data_in[`DWIDTH-1:0];
|
||
|
assign b1_data_in = b_data_in[2*`DWIDTH-1:`DWIDTH];
|
||
|
assign b2_data_in = b_data_in[3*`DWIDTH-1:2*`DWIDTH];
|
||
|
assign b3_data_in = b_data_in[4*`DWIDTH-1:3*`DWIDTH];
|
||
|
|
||
|
//If b_loc is 0, that means this matmul block is on the top-row of the
|
||
|
//final large matmul. In that case, b will take inputs from mem.
|
||
|
//If b_loc != 0, that means this matmul block is not on the top-row of the
|
||
|
//final large matmul. In that case, b will take inputs from the matmul on top
|
||
|
//of this one.
|
||
|
assign a0 = (b_loc==0) ? a0_data : a0_data_in;
|
||
|
assign a1 = (b_loc==0) ? a1_data_delayed_1 : a1_data_in;
|
||
|
assign a2 = (b_loc==0) ? a2_data_delayed_2 : a2_data_in;
|
||
|
assign a3 = (b_loc==0) ? a3_data_delayed_3 : a3_data_in;
|
||
|
|
||
|
//If a_loc is 0, that means this matmul block is on the left-col of the
|
||
|
//final large matmul. In that case, a will take inputs from mem.
|
||
|
//If a_loc != 0, that means this matmul block is not on the left-col of the
|
||
|
//final large matmul. In that case, a will take inputs from the matmul on left
|
||
|
//of this one.
|
||
|
assign b0 = (a_loc==0) ? b0_data : b0_data_in;
|
||
|
assign b1 = (a_loc==0) ? b1_data_delayed_1 : b1_data_in;
|
||
|
assign b2 = (a_loc==0) ? b2_data_delayed_2 : b2_data_in;
|
||
|
assign b3 = (a_loc==0) ? b3_data_delayed_3 : b3_data_in;
|
||
|
|
||
|
wire [`DWIDTH-1:0] matrixC00;
|
||
|
wire [`DWIDTH-1:0] matrixC01;
|
||
|
wire [`DWIDTH-1:0] matrixC02;
|
||
|
wire [`DWIDTH-1:0] matrixC03;
|
||
|
wire [`DWIDTH-1:0] matrixC10;
|
||
|
wire [`DWIDTH-1:0] matrixC11;
|
||
|
wire [`DWIDTH-1:0] matrixC12;
|
||
|
wire [`DWIDTH-1:0] matrixC13;
|
||
|
wire [`DWIDTH-1:0] matrixC20;
|
||
|
wire [`DWIDTH-1:0] matrixC21;
|
||
|
wire [`DWIDTH-1:0] matrixC22;
|
||
|
wire [`DWIDTH-1:0] matrixC23;
|
||
|
wire [`DWIDTH-1:0] matrixC30;
|
||
|
wire [`DWIDTH-1:0] matrixC31;
|
||
|
wire [`DWIDTH-1:0] matrixC32;
|
||
|
wire [`DWIDTH-1:0] matrixC33;
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// Instantiation of the output logic
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
output_logic u_output_logic(
|
||
|
.clk(clk),
|
||
|
.reset(reset),
|
||
|
.start_mat_mul(start_mat_mul),
|
||
|
.done_mat_mul(done_mat_mul),
|
||
|
.address_mat_c(address_mat_c),
|
||
|
.address_stride_c(address_stride_c),
|
||
|
.c_data_out(c_data_out),
|
||
|
.c_data_in(c_data_in),
|
||
|
.c_addr(c_addr),
|
||
|
.c_data_available(c_data_available),
|
||
|
.clk_cnt(clk_cnt),
|
||
|
.row_latch_en(row_latch_en),
|
||
|
.final_mat_mul_size(final_mat_mul_size),
|
||
|
.matrixC00(matrixC00),
|
||
|
.matrixC01(matrixC01),
|
||
|
.matrixC02(matrixC02),
|
||
|
.matrixC03(matrixC03),
|
||
|
.matrixC10(matrixC10),
|
||
|
.matrixC11(matrixC11),
|
||
|
.matrixC12(matrixC12),
|
||
|
.matrixC13(matrixC13),
|
||
|
.matrixC20(matrixC20),
|
||
|
.matrixC21(matrixC21),
|
||
|
.matrixC22(matrixC22),
|
||
|
.matrixC23(matrixC23),
|
||
|
.matrixC30(matrixC30),
|
||
|
.matrixC31(matrixC31),
|
||
|
.matrixC32(matrixC32),
|
||
|
.matrixC33(matrixC33)
|
||
|
);
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// Instantiations of the actual processing elements
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
systolic_pe_matrix u_systolic_pe_matrix(
|
||
|
.reset(reset),
|
||
|
.clk(clk),
|
||
|
.pe_reset(pe_reset),
|
||
|
.a0(a0),
|
||
|
.a1(a1),
|
||
|
.a2(a2),
|
||
|
.a3(a3),
|
||
|
.b0(b0),
|
||
|
.b1(b1),
|
||
|
.b2(b2),
|
||
|
.b3(b3),
|
||
|
.matrixC00(matrixC00),
|
||
|
.matrixC01(matrixC01),
|
||
|
.matrixC02(matrixC02),
|
||
|
.matrixC03(matrixC03),
|
||
|
.matrixC10(matrixC10),
|
||
|
.matrixC11(matrixC11),
|
||
|
.matrixC12(matrixC12),
|
||
|
.matrixC13(matrixC13),
|
||
|
.matrixC20(matrixC20),
|
||
|
.matrixC21(matrixC21),
|
||
|
.matrixC22(matrixC22),
|
||
|
.matrixC23(matrixC23),
|
||
|
.matrixC30(matrixC30),
|
||
|
.matrixC31(matrixC31),
|
||
|
.matrixC32(matrixC32),
|
||
|
.matrixC33(matrixC33),
|
||
|
.a_data_out(a_data_out),
|
||
|
.b_data_out(b_data_out)
|
||
|
);
|
||
|
|
||
|
endmodule
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// Output logic
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// output_logic
// Captures the 4x4 result tile held in the PEs (matrixCrc = PE row r, col c)
// and streams it out one column per cycle on c_data_out, together with the
// BRAM address c_addr. Every shift also pulls c_data_in into the back of the
// 4-entry column queue, so data arriving from a neighbouring matmul block is
// forwarded systolically behind this block's own columns.
//
// Sequence (clk_cnt is supplied by the parent):
//  1. row_latch_en pulses when
//       clk_cnt == 3*final_mat_mul_size - 1 + `NUM_CYCLES_IN_MAC.
//     The four PE columns are latched into c_data_out..c_data_out_3,
//     c_data_available is raised and c_addr steps from
//     (address_mat_c + address_stride_c) down to address_mat_c.
//  2. Every later cycle the queue shifts by one column and c_addr moves
//     down by address_stride_c; c_data_available stays high.
//  3. done_mat_mul (like reset or start_mat_mul low) returns the block to
//     its idle state.
module output_logic(
clk,
reset,
start_mat_mul,
done_mat_mul,
address_mat_c,
address_stride_c,
c_data_in,
c_data_out, //Data values going out to next matmul - systolic shifting
c_addr,
c_data_available,
clk_cnt,
row_latch_en,
final_mat_mul_size,
matrixC00,
matrixC01,
matrixC02,
matrixC03,
matrixC10,
matrixC11,
matrixC12,
matrixC13,
matrixC20,
matrixC21,
matrixC22,
matrixC23,
matrixC30,
matrixC31,
matrixC32,
matrixC33
);

input clk;
input reset;
input start_mat_mul;
input done_mat_mul;
input [`AWIDTH-1:0] address_mat_c;
input [`ADDR_STRIDE_WIDTH-1:0] address_stride_c;
input [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_in;
output [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out;
output [`AWIDTH-1:0] c_addr;
output c_data_available;
input [7:0] clk_cnt;
output row_latch_en;
input [7:0] final_mat_mul_size;
input [`DWIDTH-1:0] matrixC00;
input [`DWIDTH-1:0] matrixC01;
input [`DWIDTH-1:0] matrixC02;
input [`DWIDTH-1:0] matrixC03;
input [`DWIDTH-1:0] matrixC10;
input [`DWIDTH-1:0] matrixC11;
input [`DWIDTH-1:0] matrixC12;
input [`DWIDTH-1:0] matrixC13;
input [`DWIDTH-1:0] matrixC20;
input [`DWIDTH-1:0] matrixC21;
input [`DWIDTH-1:0] matrixC22;
input [`DWIDTH-1:0] matrixC23;
input [`DWIDTH-1:0] matrixC30;
input [`DWIDTH-1:0] matrixC31;
input [`DWIDTH-1:0] matrixC32;
input [`DWIDTH-1:0] matrixC33;

wire row_latch_en;

//////////////////////////////////////////////////////////////////////////
// Logic to capture matrix C data from the PEs and shift it out
//////////////////////////////////////////////////////////////////////////

// One-cycle pulse: clk_cnt == (4*N - N - 1 + `NUM_CYCLES_IN_MAC), N = final_mat_mul_size.
// The shift replaces a multiplication.
assign row_latch_en = ((clk_cnt == ((final_mat_mul_size<<2) - final_mat_mul_size -1 +`NUM_CYCLES_IN_MAC)));

reg c_data_available;
reg [`AWIDTH-1:0] c_addr;
reg start_capturing_c_data;   // set by the latch, cleared by reset/done
integer counter;              // latch + shift cycles seen, saturates at `MAT_MUL_SIZE
reg [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out;     // head of the column queue (driven out)
reg [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out_1;
reg [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out_2;
reg [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out_3;   // tail, refilled from c_data_in

// PE columns packed with row 0 in the least-significant DWIDTH bits.
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] col0;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] col1;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] col2;
wire [`MAT_MUL_SIZE*`DWIDTH-1:0] col3;
assign col0 = {matrixC30, matrixC20, matrixC10, matrixC00};
assign col1 = {matrixC31, matrixC21, matrixC11, matrixC01};
assign col2 = {matrixC32, matrixC22, matrixC12, matrixC02};
assign col3 = {matrixC33, matrixC23, matrixC13, matrixC03};

// Output shifting currently starts solely on the row-latch pulse.
wire condition_to_start_shifting_output;
assign condition_to_start_shifting_output = row_latch_en ;

// For larger matmuls, more column registers (c_data_out_N) and
// corresponding shift stages would be needed here.
always @(posedge clk) begin
  if (reset | ~start_mat_mul) begin
    start_capturing_c_data <= 1'b0;
    c_data_available <= 1'b0;
    c_addr <= address_mat_c+address_stride_c;
    c_data_out <= 0;
    counter <= 0;
    c_data_out_1 <= 0;
    c_data_out_2 <= 0;
    c_data_out_3 <= 0;
  end
  else if (condition_to_start_shifting_output) begin
    // Latch all four PE columns and present column 0 at address_mat_c.
    start_capturing_c_data <= 1'b1;
    c_data_available <= 1'b1;
    c_addr <= c_addr - address_stride_c;
    c_data_out <= col0;
    c_data_out_1 <= col1;
    c_data_out_2 <= col2;
    c_data_out_3 <= col3;
    counter <= counter + 1;
  end
  else if (done_mat_mul) begin
    // Return to exactly the idle state the reset branch establishes.
    // counter must be cleared too: otherwise it stays >= `MAT_MUL_SIZE and
    // the branch below keeps decrementing c_addr and shifting c_data_in
    // every cycle after done, even though c_data_available is low.
    start_capturing_c_data <= 1'b0;
    c_data_available <= 1'b0;
    c_addr <= address_mat_c+address_stride_c;
    c_data_out <= 0;
    counter <= 0;
    c_data_out_1 <= 0;
    c_data_out_2 <= 0;
    c_data_out_3 <= 0;
  end
  else if (counter >= `MAT_MUL_SIZE) begin
    // Own columns issued: keep forwarding data arriving on c_data_in.
    c_addr <= c_addr - address_stride_c;
    c_data_out <= c_data_out_1;
    c_data_out_1 <= c_data_out_2;
    c_data_out_2 <= c_data_out_3;
    c_data_out_3 <= c_data_in;
  end
  else if (start_capturing_c_data) begin
    // Still issuing this block's own latched columns.
    c_data_available <= 1'b1;
    c_addr <= c_addr - address_stride_c;
    counter <= counter + 1;
    c_data_out <= c_data_out_1;
    c_data_out_1 <= c_data_out_2;
    c_data_out_2 <= c_data_out_3;
    c_data_out_3 <= c_data_in;
  end
end

endmodule
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// Systolic data setup
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// systolic_data_setup
// Generates the BRAM A / BRAM B read addresses for this matmul block,
// qualifies the returned words with the validity masks and skews lanes
// 1..3 by 1..3 register stages (lane 0 passes straight through).
//
// A reads are issued while
//   a_loc*`MAT_MUL_SIZE <= clk_cnt < a_loc*`MAT_MUL_SIZE + final_mat_mul_size
// and B reads likewise with b_loc. The multiplications are written as shifts
// by `LOG2_MAT_MUL_SIZE, and the bounds are computed in 8 bits, matching
// clk_cnt.
module systolic_data_setup(
clk,
reset,
start_mat_mul,
a_addr,
b_addr,
address_mat_a,
address_mat_b,
address_stride_a,
address_stride_b,
a_data,
b_data,
clk_cnt,
a0_data,
a1_data_delayed_1,
a2_data_delayed_2,
a3_data_delayed_3,
b0_data,
b1_data_delayed_1,
b2_data_delayed_2,
b3_data_delayed_3,
validity_mask_a_rows,
validity_mask_a_cols,
validity_mask_b_rows,
validity_mask_b_cols,
final_mat_mul_size,
a_loc,
b_loc
);

input clk;
input reset;
input start_mat_mul;
output [`AWIDTH-1:0] a_addr, b_addr;
input  [`AWIDTH-1:0] address_mat_a, address_mat_b;
input  [`ADDR_STRIDE_WIDTH-1:0] address_stride_a, address_stride_b;
input  [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data, b_data;
input  [7:0] clk_cnt;
output [`DWIDTH-1:0] a0_data, a1_data_delayed_1, a2_data_delayed_2, a3_data_delayed_3;
output [`DWIDTH-1:0] b0_data, b1_data_delayed_1, b2_data_delayed_2, b3_data_delayed_3;
input  [`MASK_WIDTH-1:0] validity_mask_a_rows, validity_mask_a_cols;
input  [`MASK_WIDTH-1:0] validity_mask_b_rows, validity_mask_b_cols;
input  [7:0] final_mat_mul_size, a_loc, b_loc;

// Per-lane words after validity gating (lane k = bits [k*DWIDTH +: DWIDTH]).
wire [`DWIDTH-1:0] a0_data, a1_data, a2_data, a3_data;
wire [`DWIDTH-1:0] b0_data, b1_data, b2_data, b3_data;

// Nothing runs while reset is high or start_mat_mul is low.
wire setup_idle;
assign setup_idle = reset || ~start_mat_mul;

// Skew registers are also flushed on the first count of a run.
wire flush_skew;
assign flush_skew = setup_idle || (clk_cnt == 0);

// ---------------------------------------------------------------------
// BRAM A: address window, access-cycle counter, data-valid
// ---------------------------------------------------------------------
wire [7:0] a_window_start;
wire [7:0] a_window_end;
assign a_window_start = a_loc << `LOG2_MAT_MUL_SIZE;
assign a_window_end   = a_window_start + final_mat_mul_size;

reg [`AWIDTH-1:0] a_addr;
reg a_mem_access; // high while this block is issuing A reads

always @(posedge clk) begin
  if (setup_idle || (clk_cnt >= a_window_end)) begin
    // Park one stride below the base so the first increment lands on it.
    a_addr <= address_mat_a-address_stride_a;
    a_mem_access <= 0;
  end
  else if ((clk_cnt >= a_window_start) && (clk_cnt < a_window_end)) begin
    a_addr <= a_addr + address_stride_a;
    a_mem_access <= 1;
  end
end

// Counts consecutive cycles with a_mem_access high; cleared otherwise.
reg [7:0] a_mem_access_counter;
always @(posedge clk) begin
  if (setup_idle)
    a_mem_access_counter <= 0;
  else if (a_mem_access == 1)
    a_mem_access_counter <= a_mem_access_counter + 1;
  else
    a_mem_access_counter <= 0;
end

// Counter values 1..4 are suppressed when the matching A column mask bit is 0.
wire a_col_masked;
assign a_col_masked = (a_mem_access_counter==1 && validity_mask_a_cols[0]==1'b0) ||
                      (a_mem_access_counter==2 && validity_mask_a_cols[1]==1'b0) ||
                      (a_mem_access_counter==3 && validity_mask_a_cols[2]==1'b0) ||
                      (a_mem_access_counter==4 && validity_mask_a_cols[3]==1'b0);

wire a_data_valid; // data returned from BRAM A may be consumed
assign a_data_valid = a_col_masked ? 1'b0 : (a_mem_access_counter >= `MEM_ACCESS_LATENCY);

// Each lane is zeroed unless the word is valid AND its A row mask bit is set.
assign a0_data = a_data[`DWIDTH-1:0]          & {`DWIDTH{a_data_valid & validity_mask_a_rows[0]}};
assign a1_data = a_data[2*`DWIDTH-1:`DWIDTH]   & {`DWIDTH{a_data_valid & validity_mask_a_rows[1]}};
assign a2_data = a_data[3*`DWIDTH-1:2*`DWIDTH] & {`DWIDTH{a_data_valid & validity_mask_a_rows[2]}};
assign a3_data = a_data[4*`DWIDTH-1:3*`DWIDTH] & {`DWIDTH{a_data_valid & validity_mask_a_rows[3]}};

// A skew chains: lane k is delayed by k cycles (larger matmuls need more).
reg [`DWIDTH-1:0] a1_data_delayed_1;
reg [`DWIDTH-1:0] a2_data_delayed_1, a2_data_delayed_2;
reg [`DWIDTH-1:0] a3_data_delayed_1, a3_data_delayed_2, a3_data_delayed_3;

always @(posedge clk) begin
  if (flush_skew) begin
    a1_data_delayed_1 <= 0;
    a2_data_delayed_1 <= 0;
    a2_data_delayed_2 <= 0;
    a3_data_delayed_1 <= 0;
    a3_data_delayed_2 <= 0;
    a3_data_delayed_3 <= 0;
  end
  else begin
    a1_data_delayed_1 <= a1_data;

    a2_data_delayed_1 <= a2_data;
    a2_data_delayed_2 <= a2_data_delayed_1;

    a3_data_delayed_1 <= a3_data;
    a3_data_delayed_2 <= a3_data_delayed_1;
    a3_data_delayed_3 <= a3_data_delayed_2;
  end
end

// ---------------------------------------------------------------------
// BRAM B: address window, access-cycle counter, data-valid
// ---------------------------------------------------------------------
wire [7:0] b_window_start;
wire [7:0] b_window_end;
assign b_window_start = b_loc << `LOG2_MAT_MUL_SIZE;
assign b_window_end   = b_window_start + final_mat_mul_size;

reg [`AWIDTH-1:0] b_addr;
reg b_mem_access; // high while this block is issuing B reads

always @(posedge clk) begin
  if (setup_idle || (clk_cnt >= b_window_end)) begin
    // Park one stride below the base so the first increment lands on it.
    b_addr <= address_mat_b - address_stride_b;
    b_mem_access <= 0;
  end
  else if ((clk_cnt >= b_window_start) && (clk_cnt < b_window_end)) begin
    b_addr <= b_addr + address_stride_b;
    b_mem_access <= 1;
  end
end

// Counts consecutive cycles with b_mem_access high; cleared otherwise.
reg [7:0] b_mem_access_counter;
always @(posedge clk) begin
  if (setup_idle)
    b_mem_access_counter <= 0;
  else if (b_mem_access == 1)
    b_mem_access_counter <= b_mem_access_counter + 1;
  else
    b_mem_access_counter <= 0;
end

// Counter values 1..4 are suppressed when the matching B row mask bit is 0.
wire b_row_masked;
assign b_row_masked = (b_mem_access_counter==1 && validity_mask_b_rows[0]==1'b0) ||
                      (b_mem_access_counter==2 && validity_mask_b_rows[1]==1'b0) ||
                      (b_mem_access_counter==3 && validity_mask_b_rows[2]==1'b0) ||
                      (b_mem_access_counter==4 && validity_mask_b_rows[3]==1'b0);

wire b_data_valid; // data returned from BRAM B may be consumed
assign b_data_valid = b_row_masked ? 1'b0 : (b_mem_access_counter >= `MEM_ACCESS_LATENCY);

// Each lane is zeroed unless the word is valid AND its B column mask bit is set.
assign b0_data = b_data[`DWIDTH-1:0]          & {`DWIDTH{b_data_valid & validity_mask_b_cols[0]}};
assign b1_data = b_data[2*`DWIDTH-1:`DWIDTH]   & {`DWIDTH{b_data_valid & validity_mask_b_cols[1]}};
assign b2_data = b_data[3*`DWIDTH-1:2*`DWIDTH] & {`DWIDTH{b_data_valid & validity_mask_b_cols[2]}};
assign b3_data = b_data[4*`DWIDTH-1:3*`DWIDTH] & {`DWIDTH{b_data_valid & validity_mask_b_cols[3]}};

// B skew chains: lane k is delayed by k cycles (larger matmuls need more).
reg [`DWIDTH-1:0] b1_data_delayed_1;
reg [`DWIDTH-1:0] b2_data_delayed_1, b2_data_delayed_2;
reg [`DWIDTH-1:0] b3_data_delayed_1, b3_data_delayed_2, b3_data_delayed_3;

always @(posedge clk) begin
  if (flush_skew) begin
    b1_data_delayed_1 <= 0;
    b2_data_delayed_1 <= 0;
    b2_data_delayed_2 <= 0;
    b3_data_delayed_1 <= 0;
    b3_data_delayed_2 <= 0;
    b3_data_delayed_3 <= 0;
  end
  else begin
    b1_data_delayed_1 <= b1_data;

    b2_data_delayed_1 <= b2_data;
    b2_data_delayed_2 <= b2_data_delayed_1;

    b3_data_delayed_1 <= b3_data;
    b3_data_delayed_2 <= b3_data_delayed_1;
    b3_data_delayed_3 <= b3_data_delayed_2;
  end
end

endmodule
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// Systolically connected PEs
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// systolic_pe_matrix
// 4x4 grid of processing_element instances, pe<r><c> = row r, column c.
// A operands enter on the left (a<r> feeds pe<r>0) and move one column to
// the right per PE; B operands enter at the top (b<c> feeds pe0<c>) and move
// one row down per PE. Each PE's out_c is exported as matrixC<r><c>.
// The values leaving the right edge / bottom edge are exported on
// a_data_out / b_data_out (row/column 0 in the least-significant DWIDTH bits).
module systolic_pe_matrix(
reset,
clk,
pe_reset,
a0, a1, a2, a3,
b0, b1, b2, b3,
matrixC00,
matrixC01,
matrixC02,
matrixC03,
matrixC10,
matrixC11,
matrixC12,
matrixC13,
matrixC20,
matrixC21,
matrixC22,
matrixC23,
matrixC30,
matrixC31,
matrixC32,
matrixC33,
a_data_out,
b_data_out
);

input clk;
input reset;
input pe_reset;
input  [`DWIDTH-1:0] a0, a1, a2, a3;
input  [`DWIDTH-1:0] b0, b1, b2, b3;
output [`DWIDTH-1:0] matrixC00, matrixC01, matrixC02, matrixC03;
output [`DWIDTH-1:0] matrixC10, matrixC11, matrixC12, matrixC13;
output [`DWIDTH-1:0] matrixC20, matrixC21, matrixC22, matrixC23;
output [`DWIDTH-1:0] matrixC30, matrixC31, matrixC32, matrixC33;
output [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_out;
output [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_out;

// a_east_<r><c>  : out_a of pe<r><c> (A operand handed to the PE on its right)
// b_south_<r><c> : out_b of pe<r><c> (B operand handed to the PE below it)
wire [`DWIDTH-1:0] a_east_00, a_east_01, a_east_02, a_east_03;
wire [`DWIDTH-1:0] a_east_10, a_east_11, a_east_12, a_east_13;
wire [`DWIDTH-1:0] a_east_20, a_east_21, a_east_22, a_east_23;
wire [`DWIDTH-1:0] a_east_30, a_east_31, a_east_32, a_east_33;

wire [`DWIDTH-1:0] b_south_00, b_south_01, b_south_02, b_south_03;
wire [`DWIDTH-1:0] b_south_10, b_south_11, b_south_12, b_south_13;
wire [`DWIDTH-1:0] b_south_20, b_south_21, b_south_22, b_south_23;
wire [`DWIDTH-1:0] b_south_30, b_south_31, b_south_32, b_south_33;

// Either the global reset or the dedicated PE reset clears every PE.
wire pe_array_reset;
assign pe_array_reset = reset | pe_reset;

// Row 0: B comes straight from the b inputs.
processing_element pe00(.clk(clk), .reset(pe_array_reset), .in_a(a0),        .in_b(b0), .out_a(a_east_00), .out_b(b_south_00), .out_c(matrixC00));
processing_element pe01(.clk(clk), .reset(pe_array_reset), .in_a(a_east_00), .in_b(b1), .out_a(a_east_01), .out_b(b_south_01), .out_c(matrixC01));
processing_element pe02(.clk(clk), .reset(pe_array_reset), .in_a(a_east_01), .in_b(b2), .out_a(a_east_02), .out_b(b_south_02), .out_c(matrixC02));
processing_element pe03(.clk(clk), .reset(pe_array_reset), .in_a(a_east_02), .in_b(b3), .out_a(a_east_03), .out_b(b_south_03), .out_c(matrixC03));

// Row 1
processing_element pe10(.clk(clk), .reset(pe_array_reset), .in_a(a1),        .in_b(b_south_00), .out_a(a_east_10), .out_b(b_south_10), .out_c(matrixC10));
processing_element pe11(.clk(clk), .reset(pe_array_reset), .in_a(a_east_10), .in_b(b_south_01), .out_a(a_east_11), .out_b(b_south_11), .out_c(matrixC11));
processing_element pe12(.clk(clk), .reset(pe_array_reset), .in_a(a_east_11), .in_b(b_south_02), .out_a(a_east_12), .out_b(b_south_12), .out_c(matrixC12));
processing_element pe13(.clk(clk), .reset(pe_array_reset), .in_a(a_east_12), .in_b(b_south_03), .out_a(a_east_13), .out_b(b_south_13), .out_c(matrixC13));

// Row 2
processing_element pe20(.clk(clk), .reset(pe_array_reset), .in_a(a2),        .in_b(b_south_10), .out_a(a_east_20), .out_b(b_south_20), .out_c(matrixC20));
processing_element pe21(.clk(clk), .reset(pe_array_reset), .in_a(a_east_20), .in_b(b_south_11), .out_a(a_east_21), .out_b(b_south_21), .out_c(matrixC21));
processing_element pe22(.clk(clk), .reset(pe_array_reset), .in_a(a_east_21), .in_b(b_south_12), .out_a(a_east_22), .out_b(b_south_22), .out_c(matrixC22));
processing_element pe23(.clk(clk), .reset(pe_array_reset), .in_a(a_east_22), .in_b(b_south_13), .out_a(a_east_23), .out_b(b_south_23), .out_c(matrixC23));

// Row 3
processing_element pe30(.clk(clk), .reset(pe_array_reset), .in_a(a3),        .in_b(b_south_20), .out_a(a_east_30), .out_b(b_south_30), .out_c(matrixC30));
processing_element pe31(.clk(clk), .reset(pe_array_reset), .in_a(a_east_30), .in_b(b_south_21), .out_a(a_east_31), .out_b(b_south_31), .out_c(matrixC31));
processing_element pe32(.clk(clk), .reset(pe_array_reset), .in_a(a_east_31), .in_b(b_south_22), .out_a(a_east_32), .out_b(b_south_32), .out_c(matrixC32));
processing_element pe33(.clk(clk), .reset(pe_array_reset), .in_a(a_east_32), .in_b(b_south_23), .out_a(a_east_33), .out_b(b_south_33), .out_c(matrixC33));

// Right-edge A values (one per row) and bottom-edge B values (one per column).
assign a_data_out = {a_east_33, a_east_23, a_east_13, a_east_03};
assign b_data_out = {b_south_33, b_south_32, b_south_31, b_south_30};

endmodule
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// Processing element (PE)
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// processing_element
// One cell of the systolic array. The multiply-accumulate is done by mac_fp,
// a module defined in the arch file (a floating point 16-bit multiply and
// accumulate mode of the DSP slice); its result is the PE's out_c. The
// operands are forwarded to the neighbouring PEs through one register stage
// each (in_a -> out_a, in_b -> out_b), cleared while reset is high.
module processing_element(
reset,
clk,
in_a,
in_b,
out_a,
out_b,
out_c
);

input reset;
input clk;
input  [`DWIDTH-1:0] in_a, in_b;
output [`DWIDTH-1:0] out_a, out_b;
output [`DWIDTH-1:0] out_c; //reduced precision

reg [`DWIDTH-1:0] out_a;
reg [`DWIDTH-1:0] out_b;

// The MAC output drives out_c directly.
mac_fp u_mac(
  .clk(clk),
  .reset(reset),
  .a(in_a),
  .b(in_b),
  .out(out_c)
);

// Single-cycle operand forwarding stage.
always @(posedge clk) begin
  if (reset) begin
    out_a <= {`DWIDTH{1'b0}};
    out_b <= {`DWIDTH{1'b0}};
  end
  else begin
    out_a <= in_a;
    out_b <= in_b;
  end
end

endmodule
|
||
|
|