//////////////////////////////////////////////// // Matrix multiplication design // Multiplies 8x8 matrix (A) with another 8x8 matrix (B) // to produce an 8x8 matrix (C). // Data precision is IEEE floating point 16-bit (half precision) // The architecture is systolic in nature (output stationary). // 4 4x4 matmuls composed to make a larger 8x8 matmul. // There is state machine for control and an APB // interface for programming/configuring. // Matrices are stores in RAM blocks. /////////////////////////////////////////////// `timescale 1ns/1ns `define DWIDTH 16 `define AWIDTH 10 `define MEM_SIZE 1024 `define DESIGN_SIZE 8 `define MAT_MUL_SIZE 4 `define MASK_WIDTH 4 `define LOG2_MAT_MUL_SIZE 2 `define NUM_CYCLES_IN_MAC 3 `define MEM_ACCESS_LATENCY 1 `define REG_DATAWIDTH 32 `define REG_ADDRWIDTH 8 `define ADDR_STRIDE_WIDTH 16 `define REG_STDN_TPU_ADDR 32'h4 `define REG_MATRIX_A_ADDR 32'he `define REG_MATRIX_B_ADDR 32'h12 `define REG_MATRIX_C_ADDR 32'h16 `define REG_VALID_MASK_A_ROWS_ADDR 32'h20 `define REG_VALID_MASK_A_COLS_ADDR 32'h54 `define REG_VALID_MASK_B_ROWS_ADDR 32'h5c `define REG_VALID_MASK_B_COLS_ADDR 32'h58 `define REG_MATRIX_A_STRIDE_ADDR 32'h28 `define REG_MATRIX_B_STRIDE_ADDR 32'h32 `define REG_MATRIX_C_STRIDE_ADDR 32'h36 module matrix_multiplication( input clk, input clk_mem, input resetn, input pe_resetn, input PRESETn, input [`REG_ADDRWIDTH-1:0] PADDR, input PWRITE, input PSEL, input PENABLE, input [`REG_DATAWIDTH-1:0] PWDATA, output reg [`REG_DATAWIDTH-1:0] PRDATA, output reg PREADY, input [7:0] bram_select, input [`AWIDTH-1:0] bram_addr_ext, output reg [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_ext, input [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_ext, input [`MAT_MUL_SIZE-1:0] bram_we_ext ); wire PCLK; assign PCLK = clk; reg start_reg; reg clear_done_reg; //Dummy register to sync all other invalid/unimplemented addresses reg [`REG_DATAWIDTH-1:0] reg_dummy; reg [`AWIDTH-1:0] bram_addr_a_0_0_ext; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_a_0_0_ext; reg 
[`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_a_0_0_ext; reg [`MASK_WIDTH-1:0] bram_we_a_0_0_ext; reg [`AWIDTH-1:0] bram_addr_a_1_0_ext; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_a_1_0_ext; reg [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_a_1_0_ext; reg [`MASK_WIDTH-1:0] bram_we_a_1_0_ext; reg [`AWIDTH-1:0] bram_addr_b_0_0_ext; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_b_0_0_ext; reg [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_b_0_0_ext; reg [`MASK_WIDTH-1:0] bram_we_b_0_0_ext; reg [`AWIDTH-1:0] bram_addr_b_0_1_ext; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_b_0_1_ext; reg [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_b_0_1_ext; reg [`MASK_WIDTH-1:0] bram_we_b_0_1_ext; reg [`AWIDTH-1:0] bram_addr_c_0_1_ext; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_c_0_1_ext; reg [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_c_0_1_ext; reg [`MASK_WIDTH-1:0] bram_we_c_0_1_ext; reg [`AWIDTH-1:0] bram_addr_c_1_1_ext; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_c_1_1_ext; reg [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_c_1_1_ext; reg [`MASK_WIDTH-1:0] bram_we_c_1_1_ext; wire [`AWIDTH-1:0] bram_addr_a_0_0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_a_0_0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_a_0_0; wire [`MASK_WIDTH-1:0] bram_we_a_0_0; wire bram_en_a_0_0; wire [`AWIDTH-1:0] bram_addr_a_1_0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_a_1_0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_a_1_0; wire [`MASK_WIDTH-1:0] bram_we_a_1_0; wire bram_en_a_1_0; wire [`AWIDTH-1:0] bram_addr_b_0_0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_b_0_0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_b_0_0; wire [`MASK_WIDTH-1:0] bram_we_b_0_0; wire bram_en_b_0_0; wire [`AWIDTH-1:0] bram_addr_b_0_1; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_b_0_1; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_b_0_1; wire [`MASK_WIDTH-1:0] bram_we_b_0_1; wire bram_en_b_0_1; wire [`AWIDTH-1:0] bram_addr_c_0_1; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_c_0_1; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_c_0_1; wire [`MASK_WIDTH-1:0] bram_we_c_0_1; 
wire bram_en_c_0_1; wire [`AWIDTH-1:0] bram_addr_c_1_1; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_rdata_c_1_1; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] bram_wdata_c_1_1; wire [`MASK_WIDTH-1:0] bram_we_c_1_1; wire bram_en_c_1_1; always @* begin case (bram_select) 0: begin bram_addr_a_0_0_ext = bram_addr_ext; bram_wdata_a_0_0_ext = bram_wdata_ext; bram_we_a_0_0_ext = bram_we_ext; bram_rdata_ext = bram_rdata_a_0_0_ext; end 1: begin bram_addr_a_1_0_ext = bram_addr_ext; bram_wdata_a_1_0_ext = bram_wdata_ext; bram_we_a_1_0_ext = bram_we_ext; bram_rdata_ext = bram_rdata_a_1_0_ext; end 2: begin bram_addr_b_0_0_ext = bram_addr_ext; bram_wdata_b_0_0_ext = bram_wdata_ext; bram_we_b_0_0_ext = bram_we_ext; bram_rdata_ext = bram_rdata_b_0_0_ext; end 3: begin bram_addr_b_0_1_ext = bram_addr_ext; bram_wdata_b_0_1_ext = bram_wdata_ext; bram_we_b_0_1_ext = bram_we_ext; bram_rdata_ext = bram_rdata_b_0_1_ext; end 4: begin bram_addr_c_0_1_ext = bram_addr_ext; bram_wdata_c_0_1_ext = bram_wdata_ext; bram_we_c_0_1_ext = bram_we_ext; bram_rdata_ext = bram_rdata_c_0_1_ext; end 5: begin bram_addr_c_1_1_ext = bram_addr_ext; bram_wdata_c_1_1_ext = bram_wdata_ext; bram_we_c_1_1_ext = bram_we_ext; bram_rdata_ext = bram_rdata_c_1_1_ext; end default: begin bram_rdata_ext = 0; end endcase end ///////////////////////////////////////////////// // BRAMs to store matrix A ///////////////////////////////////////////////// // BRAM matrix A 0_0 (bank0) ram matrix_A_0_0( .addr0(bram_addr_a_0_0), .d0(bram_wdata_a_0_0), .we0(bram_we_a_0_0), .q0(bram_rdata_a_0_0), .addr1(bram_addr_a_0_0_ext), .d1(bram_wdata_a_0_0_ext), .we1(bram_we_a_0_0_ext), .q1(bram_rdata_a_0_0_ext), .clk(clk_mem)); // BRAM matrix A 1_0 (bank1) ram matrix_A_1_0( .addr0(bram_addr_a_1_0), .d0(bram_wdata_a_1_0), .we0(bram_we_a_1_0), .q0(bram_rdata_a_1_0), .addr1(bram_addr_a_1_0_ext), .d1(bram_wdata_a_1_0_ext), .we1(bram_we_a_1_0_ext), .q1(bram_rdata_a_1_0_ext), .clk(clk_mem)); ///////////////////////////////////////////////// // BRAMs to store matrix B 
///////////////////////////////////////////////// // BRAM matrix B 0_0 (bank0) ram matrix_B_0_0( .addr0(bram_addr_b_0_0), .d0(bram_wdata_b_0_0), .we0(bram_we_b_0_0), .q0(bram_rdata_b_0_0), .addr1(bram_addr_b_0_0_ext), .d1(bram_wdata_b_0_0_ext), .we1(bram_we_b_0_0_ext), .q1(bram_rdata_b_0_0_ext), .clk(clk_mem)); // BRAM matrix B 0_1 (bank1) ram matrix_B_0_1( .addr0(bram_addr_b_0_1), .d0(bram_wdata_b_0_1), .we0(bram_we_b_0_1), .q0(bram_rdata_b_0_1), .addr1(bram_addr_b_0_1_ext), .d1(bram_wdata_b_0_1_ext), .we1(bram_we_b_0_1_ext), .q1(bram_rdata_b_0_1_ext), .clk(clk_mem)); ///////////////////////////////////////////////// // BRAMs to store matrix C ///////////////////////////////////////////////// // BRAM matrix C 0_1 (bank0) ram matrix_C_0_1( .addr0(bram_addr_c_0_1), .d0(bram_wdata_c_0_1), .we0(bram_we_c_0_1), .q0(bram_rdata_c_0_1), .addr1(bram_addr_c_0_1_ext), .d1(bram_wdata_c_0_1_ext), .we1(bram_we_c_0_1_ext), .q1(bram_rdata_c_0_1_ext), .clk(clk_mem)); // BRAM matrix C 1_1 (bank1) ram matrix_C_1_1( .addr0(bram_addr_c_1_1), .d0(bram_wdata_c_1_1), .we0(bram_we_c_1_1), .q0(bram_rdata_c_1_1), .addr1(bram_addr_c_1_1_ext), .d1(bram_wdata_c_1_1_ext), .we1(bram_we_c_1_1_ext), .q1(bram_rdata_c_1_1_ext), .clk(clk_mem)); reg start_mat_mul; wire done_mat_mul; reg [3:0] state; //////////////////////////////////////////////////////////////// // Control logic //////////////////////////////////////////////////////////////// always @( posedge clk) begin if (resetn == 1'b0) begin state <= 4'b0000; start_mat_mul <= 1'b0; end else begin case (state) 4'b0000: begin start_mat_mul <= 1'b0; if (start_reg == 1'b1) begin state <= 4'b0001; end else begin state <= 4'b0000; end end 4'b0001: begin start_mat_mul <= 1'b1; state <= 4'b1010; end 4'b1010: begin if (done_mat_mul == 1'b1) begin start_mat_mul <= 1'b0; state <= 4'b1000; end else begin state <= 4'b1010; end end 4'b1000: begin if (clear_done_reg == 1'b1) begin state <= 4'b0000; end else begin state <= 4'b1000; end end endcase end end reg 
[1:0] state_apb; `define IDLE 2'b00 `define W_ENABLE 2'b01 `define R_ENABLE 2'b10 reg [`AWIDTH-1:0] address_mat_a; reg [`AWIDTH-1:0] address_mat_b; reg [`AWIDTH-1:0] address_mat_c; reg [`MASK_WIDTH-1:0] validity_mask_a_rows; reg [`MASK_WIDTH-1:0] validity_mask_a_cols; reg [`MASK_WIDTH-1:0] validity_mask_b_rows; reg [`MASK_WIDTH-1:0] validity_mask_b_cols; reg [`ADDR_STRIDE_WIDTH-1:0] address_stride_a; reg [`ADDR_STRIDE_WIDTH-1:0] address_stride_b; reg [`ADDR_STRIDE_WIDTH-1:0] address_stride_c; //////////////////////////////////////////////////////////////// // Configuration logic //////////////////////////////////////////////////////////////// always @(posedge PCLK) begin if (PRESETn == 0) begin state_apb <= `IDLE; PRDATA <= 0; PREADY <= 0; address_mat_a <= 0; address_mat_b <= 0; address_mat_c <= 0; validity_mask_a_rows <= {`MASK_WIDTH{1'b1}}; validity_mask_a_cols <= {`MASK_WIDTH{1'b1}}; validity_mask_b_rows <= {`MASK_WIDTH{1'b1}}; validity_mask_b_cols <= {`MASK_WIDTH{1'b1}}; address_stride_a <= `MAT_MUL_SIZE; address_stride_b <= `MAT_MUL_SIZE; address_stride_c <= `MAT_MUL_SIZE; end else begin case (state_apb) `IDLE : begin PRDATA <= 0; if (PSEL) begin if (PWRITE) begin state_apb <= `W_ENABLE; end else begin state_apb <= `R_ENABLE; end end PREADY <= 0; end `W_ENABLE : begin if (PSEL && PWRITE && PENABLE) begin case (PADDR) `REG_STDN_TPU_ADDR : begin start_reg <= PWDATA[0]; clear_done_reg <= PWDATA[31]; end `REG_MATRIX_A_ADDR : address_mat_a <= PWDATA[`AWIDTH-1:0]; `REG_MATRIX_B_ADDR : address_mat_b <= PWDATA[`AWIDTH-1:0]; `REG_MATRIX_C_ADDR : address_mat_c <= PWDATA[`AWIDTH-1:0]; `REG_VALID_MASK_A_ROWS_ADDR: begin validity_mask_a_rows <= PWDATA[`MASK_WIDTH-1:0]; end `REG_VALID_MASK_A_COLS_ADDR: begin validity_mask_a_cols <= PWDATA[`MASK_WIDTH-1:0]; end `REG_VALID_MASK_B_ROWS_ADDR: begin validity_mask_b_rows <= PWDATA[`MASK_WIDTH-1:0]; end `REG_VALID_MASK_B_COLS_ADDR: begin validity_mask_b_cols <= PWDATA[`MASK_WIDTH-1:0]; end `REG_MATRIX_A_STRIDE_ADDR : 
address_stride_a <= PWDATA[`ADDR_STRIDE_WIDTH-1:0]; `REG_MATRIX_B_STRIDE_ADDR : address_stride_b <= PWDATA[`ADDR_STRIDE_WIDTH-1:0]; `REG_MATRIX_C_STRIDE_ADDR : address_stride_c <= PWDATA[`ADDR_STRIDE_WIDTH-1:0]; default : reg_dummy <= PWDATA; //sink writes to a dummy register endcase PREADY <=1; end state_apb <= `IDLE; end `R_ENABLE : begin if (PSEL && !PWRITE && PENABLE) begin PREADY <= 1; case (PADDR) `REG_STDN_TPU_ADDR : PRDATA <= {done_mat_mul, 30'b0, start_mat_mul}; `REG_MATRIX_A_ADDR : PRDATA <= address_mat_a; `REG_MATRIX_B_ADDR : PRDATA <= address_mat_b; `REG_MATRIX_C_ADDR : PRDATA <= address_mat_c; `REG_VALID_MASK_A_ROWS_ADDR: PRDATA <= validity_mask_a_rows; `REG_VALID_MASK_A_COLS_ADDR: PRDATA <= validity_mask_a_cols; `REG_VALID_MASK_B_ROWS_ADDR: PRDATA <= validity_mask_b_rows; `REG_VALID_MASK_B_COLS_ADDR: PRDATA <= validity_mask_b_cols; `REG_MATRIX_A_STRIDE_ADDR : PRDATA <= address_stride_a; `REG_MATRIX_B_STRIDE_ADDR : PRDATA <= address_stride_b; `REG_MATRIX_C_STRIDE_ADDR : PRDATA <= address_stride_c; default : PRDATA <= reg_dummy; //read the dummy register for undefined addresses endcase end state_apb <= `IDLE; end default: begin state_apb <= `IDLE; end endcase end end wire reset; assign reset = ~resetn; wire pe_reset; assign pe_reset = ~pe_resetn; wire c_data_0_1_available; assign bram_en_c_0_1 = 1'b1; assign bram_we_c_0_1 = (c_data_0_1_available) ? {`MASK_WIDTH{1'b1}} : {`MASK_WIDTH{1'b0}}; wire c_data_1_1_available; assign bram_en_c_1_1 = 1'b1; assign bram_we_c_1_1 = (c_data_1_1_available) ? 
{`MASK_WIDTH{1'b1}} : {`MASK_WIDTH{1'b0}}; assign bram_wdata_a_0_0 = {`MAT_MUL_SIZE*`DWIDTH{1'b0}}; assign bram_en_a_0_0 = 1'b1; assign bram_we_a_0_0 = {`MASK_WIDTH{1'b0}}; assign bram_wdata_a_1_0 = {`MAT_MUL_SIZE*`DWIDTH{1'b0}}; assign bram_en_a_1_0 = 1'b1; assign bram_we_a_1_0 = {`MASK_WIDTH{1'b0}}; assign bram_wdata_b_0_0 = {`MAT_MUL_SIZE*`DWIDTH{1'b0}}; assign bram_en_b_0_0 = 1'b1; assign bram_we_b_0_0 = {`MASK_WIDTH{1'b0}}; assign bram_wdata_b_0_1 = {`MAT_MUL_SIZE*`DWIDTH{1'b0}}; assign bram_en_b_0_1 = 1'b1; assign bram_we_b_0_1 = {`MASK_WIDTH{1'b0}}; ///////////////////////////////////////////////// // The 8x8 matmul instantiation ///////////////////////////////////////////////// matmul_8x8_systolic u_matmul_8x8_systolic ( .clk(clk), .reset(reset), .pe_reset(pe_reset), .start_mat_mul(start_mat_mul), .done_mat_mul(done_mat_mul), .address_mat_a(address_mat_a), .address_mat_b(address_mat_b), .address_mat_c(address_mat_c), .address_stride_a(address_stride_a), .address_stride_b(address_stride_b), .address_stride_c(address_stride_c), .a_data_0_0(bram_rdata_a_0_0), .b_data_0_0(bram_rdata_b_0_0), .a_addr_0_0(bram_addr_a_0_0), .b_addr_0_0(bram_addr_b_0_0), .a_data_1_0(bram_rdata_a_1_0), .b_data_0_1(bram_rdata_b_0_1), .a_addr_1_0(bram_addr_a_1_0), .b_addr_0_1(bram_addr_b_0_1), .c_data_0_1(bram_wdata_c_0_1), .c_addr_0_1(bram_addr_c_0_1), .c_data_0_1_available(c_data_0_1_available), .c_data_1_1(bram_wdata_c_1_1), .c_addr_1_1(bram_addr_c_1_1), .c_data_1_1_available(c_data_1_1_available), .validity_mask_a_rows(validity_mask_a_rows), .validity_mask_a_cols(validity_mask_a_cols), .validity_mask_b_rows(validity_mask_b_rows), .validity_mask_b_cols(validity_mask_b_cols) ); endmodule ///////////////////////////////////////////////// // The 8x8 matmul definition ///////////////////////////////////////////////// module matmul_8x8_systolic( input clk, input reset, input pe_reset, input start_mat_mul, output done_mat_mul, input [`AWIDTH-1:0] address_mat_a, input [`AWIDTH-1:0] 
address_mat_b, input [`AWIDTH-1:0] address_mat_c, input [`ADDR_STRIDE_WIDTH-1:0] address_stride_a, input [`ADDR_STRIDE_WIDTH-1:0] address_stride_b, input [`ADDR_STRIDE_WIDTH-1:0] address_stride_c, input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_0_0, output [`AWIDTH-1:0] a_addr_0_0, input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_0_0, output [`AWIDTH-1:0] b_addr_0_0, input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_1_0, output [`AWIDTH-1:0] a_addr_1_0, input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_0_1, output [`AWIDTH-1:0] b_addr_0_1, output [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_0_1, output [`AWIDTH-1:0] c_addr_0_1, output c_data_0_1_available, output [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_1_1, output [`AWIDTH-1:0] c_addr_1_1, output c_data_1_1_available, input [`MASK_WIDTH-1:0] validity_mask_a_rows, input [`MASK_WIDTH-1:0] validity_mask_a_cols, input [`MASK_WIDTH-1:0] validity_mask_b_rows, input [`MASK_WIDTH-1:0] validity_mask_b_cols ); ///////////////////////////////////////////////// // ORing all done signals ///////////////////////////////////////////////// wire done_mat_mul_0_0; wire done_mat_mul_0_1; wire done_mat_mul_1_0; wire done_mat_mul_1_1; assign done_mat_mul = done_mat_mul_0_0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_0_1_NC; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_1_1_NC; ///////////////////////////////////////////////// // Matmul 0_0 ///////////////////////////////////////////////// wire [3:0] flags_NC_0_0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_0_0_to_0_1; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_0_0_to_1_0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_in_0_0_NC; assign a_data_in_0_0_NC = 0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_in_0_0_NC; assign c_data_in_0_0_NC = 0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_in_0_0_NC; assign b_data_in_0_0_NC = 0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_0_0_to_0_1; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_0_0_to_0_1_NC; wire [`AWIDTH-1:0] c_addr_0_0_NC; wire c_data_0_0_available_NC; matmul_4x4_systolic u_matmul_4x4_systolic_0_0( .clk(clk), .reset(reset), 
.pe_reset(pe_reset), .start_mat_mul(start_mat_mul), .done_mat_mul(done_mat_mul_0_0), .address_mat_a(address_mat_a), .address_mat_b(address_mat_b), .address_mat_c(address_mat_c), .address_stride_a(address_stride_a), .address_stride_b(address_stride_b), .address_stride_c(address_stride_c), .a_data(a_data_0_0), .b_data(b_data_0_0), .a_data_in(a_data_in_0_0_NC), .b_data_in(b_data_in_0_0_NC), .c_data_in(c_data_in_0_0_NC), .c_data_out(c_data_0_0_to_0_1_NC), .a_data_out(a_data_0_0_to_0_1), .b_data_out(b_data_0_0_to_1_0), .a_addr(a_addr_0_0), .b_addr(b_addr_0_0), .c_addr(c_addr_0_0_NC), .c_data_available(c_data_0_0_available_NC), .validity_mask_a_rows({4'b0,validity_mask_a_rows}), .validity_mask_a_cols({4'b0,validity_mask_a_cols}), .validity_mask_b_rows({4'b0,validity_mask_b_rows}), .validity_mask_b_cols({4'b0,validity_mask_b_cols}), .final_mat_mul_size(8'd8), .a_loc(8'd0), .b_loc(8'd0) ); ///////////////////////////////////////////////// // Matmul 0_1 ///////////////////////////////////////////////// wire [3:0] flags_NC_0_1; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_0_1_to_0_2; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_0_1_to_1_1; wire [`AWIDTH-1:0] a_addr_0_1_NC; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_0_1_NC; assign a_data_0_1_NC = 0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_in_0_1_NC; assign b_data_in_0_1_NC = 0; matmul_4x4_systolic u_matmul_4x4_systolic_0_1( .clk(clk), .reset(reset), .pe_reset(pe_reset), .start_mat_mul(start_mat_mul), .done_mat_mul(done_mat_mul_0_1), .address_mat_a(address_mat_a), .address_mat_b(address_mat_b), .address_mat_c(address_mat_c), .address_stride_a(address_stride_a), .address_stride_b(address_stride_b), .address_stride_c(address_stride_c), .a_data(a_data_0_1_NC), .b_data(b_data_0_1), .a_data_in(a_data_0_0_to_0_1), .b_data_in(b_data_in_0_1_NC), .c_data_in(c_data_0_0_to_0_1), .c_data_out(c_data_0_1), .a_data_out(a_data_0_1_to_0_2), .b_data_out(b_data_0_1_to_1_1), .a_addr(a_addr_0_1_NC), .b_addr(b_addr_0_1), .c_addr(c_addr_0_1), 
.c_data_available(c_data_0_1_available), .validity_mask_a_rows({4'b0,validity_mask_a_rows}), .validity_mask_a_cols({4'b0,validity_mask_a_cols}), .validity_mask_b_rows({4'b0,validity_mask_b_rows}), .validity_mask_b_cols({4'b0,validity_mask_b_cols}), .final_mat_mul_size(8'd8), .a_loc(8'd0), .b_loc(8'd1) ); ///////////////////////////////////////////////// // Matmul 1_0 ///////////////////////////////////////////////// wire [3:0] flags_NC_1_0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_1_0_to_1_1; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_1_0_to_2_0; wire [`AWIDTH-1:0] b_addr_1_0_NC; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_1_0_NC; assign b_data_1_0_NC = 0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_in_1_0_NC; assign a_data_in_1_0_NC = 0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_in_1_0_NC; assign c_data_in_1_0_NC = 0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_1_0_to_1_1; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_1_0_to_1_1_NC; wire [`AWIDTH-1:0] c_addr_1_0_NC; wire c_data_1_0_available_NC; matmul_4x4_systolic u_matmul_4x4_systolic_1_0( .clk(clk), .reset(reset), .pe_reset(pe_reset), .start_mat_mul(start_mat_mul), .done_mat_mul(done_mat_mul_1_0), .address_mat_a(address_mat_a), .address_mat_b(address_mat_b), .address_mat_c(address_mat_c), .address_stride_a(address_stride_a), .address_stride_b(address_stride_b), .address_stride_c(address_stride_c), .a_data(a_data_1_0), .b_data(b_data_1_0_NC), .a_data_in(a_data_in_1_0_NC), .b_data_in(b_data_0_0_to_1_0), .c_data_in(c_data_in_1_0_NC), .c_data_out(c_data_1_0_to_1_1_NC), .a_data_out(a_data_1_0_to_1_1), .b_data_out(b_data_1_0_to_2_0), .a_addr(a_addr_1_0), .b_addr(b_addr_1_0_NC), .c_addr(c_addr_1_0_NC), .c_data_available(c_data_1_0_available_NC), .validity_mask_a_rows({4'b0,validity_mask_a_rows}), .validity_mask_a_cols({4'b0,validity_mask_a_cols}), .validity_mask_b_rows({4'b0,validity_mask_b_rows}), .validity_mask_b_cols({4'b0,validity_mask_b_cols}), .final_mat_mul_size(8'd8), .a_loc(8'd1), .b_loc(8'd0) ); 
///////////////////////////////////////////////// // Matmul 1_1 ///////////////////////////////////////////////// wire [3:0] flags_NC_1_1; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_1_1_to_1_2; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_1_1_to_2_1; wire [`AWIDTH-1:0] a_addr_1_1_NC; wire [`AWIDTH-1:0] b_addr_1_1_NC; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_1_1_NC; assign a_data_1_1_NC = 0; wire [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_1_1_NC; assign b_data_1_1_NC = 0; matmul_4x4_systolic u_matmul_4x4_systolic_1_1( .clk(clk), .reset(reset), .pe_reset(pe_reset), .start_mat_mul(start_mat_mul), .done_mat_mul(done_mat_mul_1_1), .address_mat_a(address_mat_a), .address_mat_b(address_mat_b), .address_mat_c(address_mat_c), .address_stride_a(address_stride_a), .address_stride_b(address_stride_b), .address_stride_c(address_stride_c), .a_data(a_data_1_1_NC), .b_data(b_data_1_1_NC), .a_data_in(a_data_1_0_to_1_1), .b_data_in(b_data_0_1_to_1_1), .c_data_in(c_data_1_0_to_1_1), .c_data_out(c_data_1_1), .a_data_out(a_data_1_1_to_1_2), .b_data_out(b_data_1_1_to_2_1), .a_addr(a_addr_1_1_NC), .b_addr(b_addr_1_1_NC), .c_addr(c_addr_1_1), .c_data_available(c_data_1_1_available), .validity_mask_a_rows({4'b0,validity_mask_a_rows}), .validity_mask_a_cols({4'b0,validity_mask_a_cols}), .validity_mask_b_rows({4'b0,validity_mask_b_rows}), .validity_mask_b_cols({4'b0,validity_mask_b_cols}), .final_mat_mul_size(8'd8), .a_loc(8'd1), .b_loc(8'd1) ); endmodule ////////////////////////////////// //Dual port RAM ////////////////////////////////// module ram ( addr0, d0, we0, q0, addr1, d1, we1, q1, clk); input [`AWIDTH-1:0] addr0; input [`AWIDTH-1:0] addr1; input [`MAT_MUL_SIZE*`DWIDTH-1:0] d0; input [`MAT_MUL_SIZE*`DWIDTH-1:0] d1; input [`MAT_MUL_SIZE-1:0] we0; input [`MAT_MUL_SIZE-1:0] we1; output [`MAT_MUL_SIZE*`DWIDTH-1:0] q0; output [`MAT_MUL_SIZE*`DWIDTH-1:0] q1; input clk; `ifdef VCS reg [`MAT_MUL_SIZE*`DWIDTH-1:0] q0; reg [`MAT_MUL_SIZE*`DWIDTH-1:0] q1; reg [7:0] ram[((1<<`AWIDTH)-1):0]; integer i; always 
@(posedge clk) begin for (i = 0; i < `MAT_MUL_SIZE; i=i+1) begin if (we0[i]) ram[addr0+i] <= d0[i*`DWIDTH +: `DWIDTH]; end for (i = 0; i < `MAT_MUL_SIZE; i=i+1) begin q0[i*`DWIDTH +: `DWIDTH] <= ram[addr0+i]; end end always @(posedge clk) begin for (i = 0; i < `MAT_MUL_SIZE; i=i+1) begin if (we1[i]) ram[addr0+i] <= d1[i*`DWIDTH +: `DWIDTH]; end for (i = 0; i < `MAT_MUL_SIZE; i=i+1) begin q1[i*`DWIDTH +: `DWIDTH] <= ram[addr1+i]; end end `else //BRAMs available in VTR FPGA architectures have one bit write-enables. //So let's combine multiple bits into 1. We don't have a usecase of //writing/not-writing only parts of the word anyway. wire we0_coalesced; assign we0_coalesced = |we0; wire we1_coalesced; assign we1_coalesced = |we1; dual_port_ram u_dual_port_ram( .addr1(addr0), .we1(we0_coalesced), .data1(d0), .out1(q0), .addr2(addr1), .we2(we1_coalesced), .data2(d1), .out2(q1), .clk(clk) ); `endif endmodule ////////////////////////////////////////////// //4x4 systolic matrix multiplier ////////////////////////////////////////////// module matmul_4x4_systolic( clk, reset, pe_reset, start_mat_mul, done_mat_mul, address_mat_a, address_mat_b, address_mat_c, address_stride_a, address_stride_b, address_stride_c, a_data, b_data, a_data_in, //Data values coming in from previous matmul - systolic connections b_data_in, //Data values coming in from previous matmul - systolic connections c_data_in, //Data values coming in from previous matmul - systolic shifting c_data_out,//Data values going out to next matmul - systolic shifting a_data_out, b_data_out, a_addr, b_addr, c_addr, c_data_available, validity_mask_a_rows, validity_mask_a_cols, validity_mask_b_rows, validity_mask_b_cols, final_mat_mul_size, a_loc, b_loc ); input clk; input reset; input pe_reset; input start_mat_mul; output done_mat_mul; input [`AWIDTH-1:0] address_mat_a; input [`AWIDTH-1:0] address_mat_b; input [`AWIDTH-1:0] address_mat_c; input [`ADDR_STRIDE_WIDTH-1:0] address_stride_a; input [`ADDR_STRIDE_WIDTH-1:0] 
address_stride_b; input [`ADDR_STRIDE_WIDTH-1:0] address_stride_c; input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data; input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data; input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_in; input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_in; input [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_in; output [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out; output [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_out; output [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_out; output [`AWIDTH-1:0] a_addr; output [`AWIDTH-1:0] b_addr; output [`AWIDTH-1:0] c_addr; output c_data_available; input [`MASK_WIDTH-1:0] validity_mask_a_rows; input [`MASK_WIDTH-1:0] validity_mask_a_cols; input [`MASK_WIDTH-1:0] validity_mask_b_rows; input [`MASK_WIDTH-1:0] validity_mask_b_cols; //7:0 is okay here. We aren't going to make a matmul larger than 128x128 //In fact, these will get optimized out by the synthesis tool, because //we hardcode them at the instantiation level. input [7:0] final_mat_mul_size; input [7:0] a_loc; input [7:0] b_loc; ////////////////////////////////////////////////////////////////////////// // Logic for clock counting and when to assert done ////////////////////////////////////////////////////////////////////////// reg done_mat_mul; //This is 7 bits because the expectation is that clock count will be pretty //small. For large matmuls, this will need to increased to have more bits. //In general, a systolic multiplier takes 4*N-2+P cycles, where N is the size //of the matmul and P is the number of pipleine stages in the MAC block. reg [7:0] clk_cnt; //Finding out number of cycles to assert matmul done. //When we have to save the outputs to accumulators, then we don't need to //shift out data. So, we can assert done_mat_mul early. //In the normal case, we have to include the time to shift out the results. 
//Note: the count expression used to contain "4*final_mat_mul_size", but //to avoid multiplication, we now use "final_mat_mul_size<<2" wire [7:0] clk_cnt_for_done; assign clk_cnt_for_done = ((final_mat_mul_size<<2) - 2 + `NUM_CYCLES_IN_MAC) ; always @(posedge clk) begin if (reset || ~start_mat_mul) begin clk_cnt <= 0; done_mat_mul <= 0; end else if (clk_cnt == clk_cnt_for_done) begin done_mat_mul <= 1; clk_cnt <= clk_cnt + 1; end else if (done_mat_mul == 0) begin clk_cnt <= clk_cnt + 1; end else begin done_mat_mul <= 0; clk_cnt <= clk_cnt + 1; end end wire [`DWIDTH-1:0] a0_data; wire [`DWIDTH-1:0] a1_data; wire [`DWIDTH-1:0] a2_data; wire [`DWIDTH-1:0] a3_data; wire [`DWIDTH-1:0] b0_data; wire [`DWIDTH-1:0] b1_data; wire [`DWIDTH-1:0] b2_data; wire [`DWIDTH-1:0] b3_data; wire [`DWIDTH-1:0] a1_data_delayed_1; wire [`DWIDTH-1:0] a2_data_delayed_1; wire [`DWIDTH-1:0] a2_data_delayed_2; wire [`DWIDTH-1:0] a3_data_delayed_1; wire [`DWIDTH-1:0] a3_data_delayed_2; wire [`DWIDTH-1:0] a3_data_delayed_3; wire [`DWIDTH-1:0] b1_data_delayed_1; wire [`DWIDTH-1:0] b2_data_delayed_1; wire [`DWIDTH-1:0] b2_data_delayed_2; wire [`DWIDTH-1:0] b3_data_delayed_1; wire [`DWIDTH-1:0] b3_data_delayed_2; wire [`DWIDTH-1:0] b3_data_delayed_3; ////////////////////////////////////////////////////////////////////////// // Instantiation of systolic data setup ////////////////////////////////////////////////////////////////////////// systolic_data_setup u_systolic_data_setup( .clk(clk), .reset(reset), .start_mat_mul(start_mat_mul), .a_addr(a_addr), .b_addr(b_addr), .address_mat_a(address_mat_a), .address_mat_b(address_mat_b), .address_stride_a(address_stride_a), .address_stride_b(address_stride_b), .a_data(a_data), .b_data(b_data), .clk_cnt(clk_cnt), .a0_data(a0_data), .a1_data_delayed_1(a1_data_delayed_1), .a2_data_delayed_2(a2_data_delayed_2), .a3_data_delayed_3(a3_data_delayed_3), .b0_data(b0_data), .b1_data_delayed_1(b1_data_delayed_1), .b2_data_delayed_2(b2_data_delayed_2), 
.b3_data_delayed_3(b3_data_delayed_3), .validity_mask_a_rows(validity_mask_a_rows), .validity_mask_a_cols(validity_mask_a_cols), .validity_mask_b_rows(validity_mask_b_rows), .validity_mask_b_cols(validity_mask_b_cols), .final_mat_mul_size(final_mat_mul_size), .a_loc(a_loc), .b_loc(b_loc) ); ////////////////////////////////////////////////////////////////////////// // Logic to mux data_in coming from neighboring matmuls ////////////////////////////////////////////////////////////////////////// wire [`DWIDTH-1:0] a0; wire [`DWIDTH-1:0] a1; wire [`DWIDTH-1:0] a2; wire [`DWIDTH-1:0] a3; wire [`DWIDTH-1:0] b0; wire [`DWIDTH-1:0] b1; wire [`DWIDTH-1:0] b2; wire [`DWIDTH-1:0] b3; wire [`DWIDTH-1:0] a0_data_in; wire [`DWIDTH-1:0] a1_data_in; wire [`DWIDTH-1:0] a2_data_in; wire [`DWIDTH-1:0] a3_data_in; assign a0_data_in = a_data_in[`DWIDTH-1:0]; assign a1_data_in = a_data_in[2*`DWIDTH-1:`DWIDTH]; assign a2_data_in = a_data_in[3*`DWIDTH-1:2*`DWIDTH]; assign a3_data_in = a_data_in[4*`DWIDTH-1:3*`DWIDTH]; wire [`DWIDTH-1:0] b0_data_in; wire [`DWIDTH-1:0] b1_data_in; wire [`DWIDTH-1:0] b2_data_in; wire [`DWIDTH-1:0] b3_data_in; assign b0_data_in = b_data_in[`DWIDTH-1:0]; assign b1_data_in = b_data_in[2*`DWIDTH-1:`DWIDTH]; assign b2_data_in = b_data_in[3*`DWIDTH-1:2*`DWIDTH]; assign b3_data_in = b_data_in[4*`DWIDTH-1:3*`DWIDTH]; //If b_loc is 0, that means this matmul block is on the top-row of the //final large matmul. In that case, b will take inputs from mem. //If b_loc != 0, that means this matmul block is not on the top-row of the //final large matmul. In that case, b will take inputs from the matmul on top //of this one. assign a0 = (b_loc==0) ? a0_data : a0_data_in; assign a1 = (b_loc==0) ? a1_data_delayed_1 : a1_data_in; assign a2 = (b_loc==0) ? a2_data_delayed_2 : a2_data_in; assign a3 = (b_loc==0) ? a3_data_delayed_3 : a3_data_in; //If a_loc is 0, that means this matmul block is on the left-col of the //final large matmul. In that case, a will take inputs from mem. 
//If a_loc != 0, that means this matmul block is not on the left-col of the //final large matmul. In that case, a will take inputs from the matmul on left //of this one. assign b0 = (a_loc==0) ? b0_data : b0_data_in; assign b1 = (a_loc==0) ? b1_data_delayed_1 : b1_data_in; assign b2 = (a_loc==0) ? b2_data_delayed_2 : b2_data_in; assign b3 = (a_loc==0) ? b3_data_delayed_3 : b3_data_in; wire [`DWIDTH-1:0] matrixC00; wire [`DWIDTH-1:0] matrixC01; wire [`DWIDTH-1:0] matrixC02; wire [`DWIDTH-1:0] matrixC03; wire [`DWIDTH-1:0] matrixC10; wire [`DWIDTH-1:0] matrixC11; wire [`DWIDTH-1:0] matrixC12; wire [`DWIDTH-1:0] matrixC13; wire [`DWIDTH-1:0] matrixC20; wire [`DWIDTH-1:0] matrixC21; wire [`DWIDTH-1:0] matrixC22; wire [`DWIDTH-1:0] matrixC23; wire [`DWIDTH-1:0] matrixC30; wire [`DWIDTH-1:0] matrixC31; wire [`DWIDTH-1:0] matrixC32; wire [`DWIDTH-1:0] matrixC33; ////////////////////////////////////////////////////////////////////////// // Instantiation of the output logic ////////////////////////////////////////////////////////////////////////// output_logic u_output_logic( .clk(clk), .reset(reset), .start_mat_mul(start_mat_mul), .done_mat_mul(done_mat_mul), .address_mat_c(address_mat_c), .address_stride_c(address_stride_c), .c_data_out(c_data_out), .c_data_in(c_data_in), .c_addr(c_addr), .c_data_available(c_data_available), .clk_cnt(clk_cnt), .row_latch_en(row_latch_en), .final_mat_mul_size(final_mat_mul_size), .matrixC00(matrixC00), .matrixC01(matrixC01), .matrixC02(matrixC02), .matrixC03(matrixC03), .matrixC10(matrixC10), .matrixC11(matrixC11), .matrixC12(matrixC12), .matrixC13(matrixC13), .matrixC20(matrixC20), .matrixC21(matrixC21), .matrixC22(matrixC22), .matrixC23(matrixC23), .matrixC30(matrixC30), .matrixC31(matrixC31), .matrixC32(matrixC32), .matrixC33(matrixC33) ); ////////////////////////////////////////////////////////////////////////// // Instantiations of the actual processing elements 
//////////////////////////////////////////////////////////////////////////
  // 4x4 grid of systolically connected PEs (module systolic_pe_matrix below).
  systolic_pe_matrix u_systolic_pe_matrix(
    .reset(reset),
    .clk(clk),
    .pe_reset(pe_reset),
    .a0(a0),
    .a1(a1),
    .a2(a2),
    .a3(a3),
    .b0(b0),
    .b1(b1),
    .b2(b2),
    .b3(b3),
    .matrixC00(matrixC00),
    .matrixC01(matrixC01),
    .matrixC02(matrixC02),
    .matrixC03(matrixC03),
    .matrixC10(matrixC10),
    .matrixC11(matrixC11),
    .matrixC12(matrixC12),
    .matrixC13(matrixC13),
    .matrixC20(matrixC20),
    .matrixC21(matrixC21),
    .matrixC22(matrixC22),
    .matrixC23(matrixC23),
    .matrixC30(matrixC30),
    .matrixC31(matrixC31),
    .matrixC32(matrixC32),
    .matrixC33(matrixC33),
    .a_data_out(a_data_out),
    .b_data_out(b_data_out)
  );

endmodule

//////////////////////////////////////////////////////////////////////////
// Output logic
//////////////////////////////////////////////////////////////////////////
// Latches the 4x4 block of results from the PE array and streams it out one
// column per cycle on c_data_out, with a write address on c_addr and a
// c_data_available flag.
//
// Behaviour (synchronous to clk, in priority order):
//  1. reset high or start_mat_mul low: everything is cleared and c_addr is
//     parked at address_mat_c + address_stride_c.
//  2. row_latch_en (clk_cnt == 3*final_mat_mul_size - 1 + `NUM_CYCLES_IN_MAC):
//     col0..col3 are latched, col0 appears on c_data_out, c_addr steps down
//     to address_mat_c and c_data_available goes high.
//  3. done_mat_mul: outputs cleared and c_addr re-parked (counter is NOT
//     cleared here; only case 1 clears it).
//  4./5. Otherwise, once capturing has started, every cycle shifts
//     c_data_out <- c_data_out_1 <- c_data_out_2 <- c_data_out_3 <- c_data_in
//     and steps c_addr down by address_stride_c. counter increments until it
//     reaches `MAT_MUL_SIZE and then holds, while shifting continues.
//
// NOTE(review): c_addr decreases by one stride per column, so column k of this
// block is written at address_mat_c - k*address_stride_c, followed by whatever
// was shifted in on c_data_in (presumably results from a neighbouring matmul
// block). Confirm this layout against the top level.
module output_logic(
  clk,
  reset,
  start_mat_mul,
  done_mat_mul,
  address_mat_c,
  address_stride_c,
  c_data_in,
  c_data_out, //Data values going out to next matmul - systolic shifting
  c_addr,
  c_data_available,
  clk_cnt,
  row_latch_en,
  final_mat_mul_size,
  matrixC00,
  matrixC01,
  matrixC02,
  matrixC03,
  matrixC10,
  matrixC11,
  matrixC12,
  matrixC13,
  matrixC20,
  matrixC21,
  matrixC22,
  matrixC23,
  matrixC30,
  matrixC31,
  matrixC32,
  matrixC33
);

  input clk;
  input reset;
  input start_mat_mul;
  input done_mat_mul;
  input [`AWIDTH-1:0] address_mat_c;
  input [`ADDR_STRIDE_WIDTH-1:0] address_stride_c;
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_in;
  output [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out;
  output [`AWIDTH-1:0] c_addr;
  output c_data_available;
  input [7:0] clk_cnt;
  output row_latch_en;
  input [7:0] final_mat_mul_size;
  input [`DWIDTH-1:0] matrixC00;
  input [`DWIDTH-1:0] matrixC01;
  input [`DWIDTH-1:0] matrixC02;
  input [`DWIDTH-1:0] matrixC03;
  input [`DWIDTH-1:0] matrixC10;
  input [`DWIDTH-1:0] matrixC11;
  input [`DWIDTH-1:0] matrixC12;
  input [`DWIDTH-1:0] matrixC13;
  input [`DWIDTH-1:0] matrixC20;
  input [`DWIDTH-1:0] matrixC21;
  input [`DWIDTH-1:0] matrixC22;
  input [`DWIDTH-1:0] matrixC23;
  input [`DWIDTH-1:0] matrixC30;
  input [`DWIDTH-1:0] matrixC31;
  input [`DWIDTH-1:0] matrixC32;
  input [`DWIDTH-1:0] matrixC33;

  wire row_latch_en;

  //////////////////////////////////////////////////////////////////////////
  // Logic to capture matrix C data from the PEs and shift it out
  //////////////////////////////////////////////////////////////////////////
  // One-cycle strobe at clk_cnt == 3*final_mat_mul_size - 1 + `NUM_CYCLES_IN_MAC
  // ((x<<2) - x is 3*x, written without a multiplier).
  assign row_latch_en = ((clk_cnt == ((final_mat_mul_size<<2) - final_mat_mul_size -1 +`NUM_CYCLES_IN_MAC)));

  reg c_data_available;
  reg [`AWIDTH-1:0] c_addr;
  reg start_capturing_c_data;  // set on row_latch_en; enables the shift branch
  integer counter;             // columns emitted by this block; holds at `MAT_MUL_SIZE
  reg [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out;
  // Shift stages feeding c_data_out; c_data_out_3 is refilled from c_data_in.
  reg [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out_1;
  reg [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out_2;
  reg [`MAT_MUL_SIZE*`DWIDTH-1:0] c_data_out_3;

  // Column k of the PE results, row 0 in the least-significant DWIDTH bits.
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] col0;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] col1;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] col2;
  wire [`MAT_MUL_SIZE*`DWIDTH-1:0] col3;
  assign col0 = {matrixC30, matrixC20, matrixC10, matrixC00};
  assign col1 = {matrixC31, matrixC21, matrixC11, matrixC01};
  assign col2 = {matrixC32, matrixC22, matrixC12, matrixC02};
  assign col3 = {matrixC33, matrixC23, matrixC13, matrixC03};

  // Shifting out starts on the row_latch_en strobe; that is the only trigger
  // in this design.
  wire condition_to_start_shifting_output;
  assign condition_to_start_shifting_output = row_latch_en ;

  // Capture / shift-out sequencer (priority chain described in the module
  // header). The counter >= `MAT_MUL_SIZE branch shifts exactly like the
  // start_capturing_c_data branch but no longer advances counter.
  always @(posedge clk) begin
    if (reset | ~start_mat_mul) begin
      start_capturing_c_data <= 1'b0;
      c_data_available <= 1'b0;
      c_addr <= address_mat_c+address_stride_c;
      c_data_out <= 0;
      counter <= 0;
      c_data_out_1 <= 0;
      c_data_out_2 <= 0;
      c_data_out_3 <= 0;
    end else if (condition_to_start_shifting_output) begin
      // Latch all four columns; col0 goes out this cycle at address_mat_c.
      start_capturing_c_data <= 1'b1;
      c_data_available <= 1'b1;
      c_addr <= c_addr - address_stride_c;
      c_data_out <= col0;
      c_data_out_1 <= col1;
      c_data_out_2 <= col2;
      c_data_out_3 <= col3;
      counter <= counter + 1;
    end else if (done_mat_mul) begin
      start_capturing_c_data <= 1'b0;
      c_data_available <= 1'b0;
      c_addr <= address_mat_c+address_stride_c;
      c_data_out <= 0;
      c_data_out_1 <= 0;
      c_data_out_2 <= 0;
      c_data_out_3 <= 0;
    end else if (counter >= `MAT_MUL_SIZE) begin
      // Own columns already emitted: keep shifting data arriving on c_data_in.
      c_addr <= c_addr - address_stride_c;
      c_data_out <= c_data_out_1;
      c_data_out_1 <= c_data_out_2;
      c_data_out_2 <= c_data_out_3;
      c_data_out_3 <= c_data_in;
    end else if (start_capturing_c_data) begin
      c_data_available <= 1'b1;
      c_addr <= c_addr - address_stride_c;
      counter <= counter + 1;
      c_data_out <= c_data_out_1;
      c_data_out_1 <= c_data_out_2;
      c_data_out_2 <= c_data_out_3;
      c_data_out_3 <= c_data_in;
    end
  end

endmodule

//////////////////////////////////////////////////////////////////////////
// Systolic data setup
//////////////////////////////////////////////////////////////////////////
// Generates BRAM read addresses for the A and B operands, qualifies each
// returned word with the validity masks, and skews the four lanes so that
// lane k reaches the PE array k cycles after lane 0.
//
// A side: successive words read from BRAM A are masked by validity_mask_a_cols
// (per word) and each lane k by validity_mask_a_rows[k].
// B side: successive words read from BRAM B are masked by validity_mask_b_rows
// (per word) and each lane k by validity_mask_b_cols[k].
// The read window of A starts at clk_cnt == a_loc*`MAT_MUL_SIZE, that of B at
// clk_cnt == b_loc*`MAT_MUL_SIZE; each lasts final_mat_mul_size cycles.
module systolic_data_setup(
  clk,
  reset,
  start_mat_mul,
  a_addr,
  b_addr,
  address_mat_a,
  address_mat_b,
  address_stride_a,
  address_stride_b,
  a_data,
  b_data,
  clk_cnt,
  a0_data,
  a1_data_delayed_1,
  a2_data_delayed_2,
  a3_data_delayed_3,
  b0_data,
  b1_data_delayed_1,
  b2_data_delayed_2,
  b3_data_delayed_3,
  validity_mask_a_rows,
  validity_mask_a_cols,
  validity_mask_b_rows,
  validity_mask_b_cols,
  final_mat_mul_size,
  a_loc,
  b_loc
);

  input clk;
  input reset;
  input start_mat_mul;
  output [`AWIDTH-1:0] a_addr;
  output [`AWIDTH-1:0] b_addr;
  input [`AWIDTH-1:0] address_mat_a;
  input [`AWIDTH-1:0] address_mat_b;
  input [`ADDR_STRIDE_WIDTH-1:0] address_stride_a;
  input [`ADDR_STRIDE_WIDTH-1:0] address_stride_b;
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data;
  input [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data;
  input [7:0] clk_cnt;
  output [`DWIDTH-1:0] a0_data;
  output [`DWIDTH-1:0] a1_data_delayed_1;
  output [`DWIDTH-1:0] a2_data_delayed_2;
  output [`DWIDTH-1:0] a3_data_delayed_3;
  output [`DWIDTH-1:0] b0_data;
  output [`DWIDTH-1:0] b1_data_delayed_1;
  output [`DWIDTH-1:0] b2_data_delayed_2;
  output [`DWIDTH-1:0] b3_data_delayed_3;
  input [`MASK_WIDTH-1:0] validity_mask_a_rows;
  input [`MASK_WIDTH-1:0] validity_mask_a_cols;
  input [`MASK_WIDTH-1:0] validity_mask_b_rows;
  input [`MASK_WIDTH-1:0] validity_mask_b_cols;
  input [7:0] final_mat_mul_size;
  input [7:0] a_loc;
  input [7:0] b_loc;

  // Masked (but not yet skewed) lanes of the current A / B word.
  wire [`DWIDTH-1:0] a0_data;
  wire [`DWIDTH-1:0] a1_data;
  wire [`DWIDTH-1:0] a2_data;
  wire [`DWIDTH-1:0] a3_data;
  wire [`DWIDTH-1:0] b0_data;
  wire [`DWIDTH-1:0] b1_data;
  wire [`DWIDTH-1:0] b2_data;
  wire [`DWIDTH-1:0] b3_data;

  //////////////////////////////////////////////////////////////////////////
  // Logic to generate addresses to BRAM A
  //////////////////////////////////////////////////////////////////////////
  // Reads are issued while
  //   a_loc*`MAT_MUL_SIZE <= clk_cnt < a_loc*`MAT_MUL_SIZE + final_mat_mul_size
  // starting at address_mat_a and advancing by address_stride_a per cycle.
  // In reset, while start_mat_mul is low, and once the window has closed, the
  // address is parked one stride below address_mat_a (so the first increment
  // lands on address_mat_a) and a_mem_access is 0. Before the window opens both
  // registers simply hold.
  // NOTE(review): clk_cnt, a_loc and final_mat_mul_size are all 8 bits, so the
  // window bounds are evaluated at 8 bits; a_loc*4 + final_mat_mul_size must
  // stay below 256.
  reg [`AWIDTH-1:0] a_addr;
  reg a_mem_access; //flag that tells whether the matmul is trying to access memory or not
  always @(posedge clk) begin
    //Equivalent to: clk_cnt >= a_loc*`MAT_MUL_SIZE+final_mat_mul_size
    //(a shift by `LOG2_MAT_MUL_SIZE is used instead of a multiplication)
    if ((reset || ~start_mat_mul) || (clk_cnt >= (a_loc<<`LOG2_MAT_MUL_SIZE)+final_mat_mul_size)) begin
      a_addr <= address_mat_a-address_stride_a;
      a_mem_access <= 0;
    end
    //Equivalent to: (clk_cnt >= a_loc*`MAT_MUL_SIZE) && (clk_cnt < a_loc*`MAT_MUL_SIZE+final_mat_mul_size)
    //(shift used instead of multiplication)
    else if ((clk_cnt >= (a_loc<<`LOG2_MAT_MUL_SIZE)) && (clk_cnt < (a_loc<<`LOG2_MAT_MUL_SIZE)+final_mat_mul_size)) begin
      a_addr <= a_addr + address_stride_a;
      a_mem_access <= 1;
    end
  end

  //////////////////////////////////////////////////////////////////////////
  // Logic to generate valid signals for data coming from BRAM A
  //////////////////////////////////////////////////////////////////////////
  // Counts cycles of an ongoing read window, lagging a_mem_access by one clock
  // (1 in the cycle after the first read was issued, 2 after the second, ...),
  // and drops back to 0 as soon as a_mem_access is low.
  reg [7:0] a_mem_access_counter;
  always @(posedge clk) begin
    if (reset || ~start_mat_mul) begin
      a_mem_access_counter <= 0;
    end
    else if (a_mem_access == 1) begin
      a_mem_access_counter <= a_mem_access_counter + 1;
    end
    else begin
      a_mem_access_counter <= 0;
    end
  end

  // a_data_valid is high once the counter has reached `MEM_ACCESS_LATENCY,
  // except in the cycles where the counter equals 1..4 and the matching
  // validity_mask_a_cols bit 0..3 is clear. Counter values above 4 (possible
  // when final_mat_mul_size > 4) are never masked here.
  // NOTE(review): the literal values 1..4 line up with mask bits 0..3 only when
  // `MEM_ACCESS_LATENCY is 1.
  wire a_data_valid; //flag that tells whether the data from memory is valid
  assign a_data_valid =
       ((validity_mask_a_cols[0]==1'b0 && a_mem_access_counter==1) ||
        (validity_mask_a_cols[1]==1'b0 && a_mem_access_counter==2) ||
        (validity_mask_a_cols[2]==1'b0 && a_mem_access_counter==3) ||
        (validity_mask_a_cols[3]==1'b0 && a_mem_access_counter==4)) ?
        1'b0 : (a_mem_access_counter >= `MEM_ACCESS_LATENCY);

  //////////////////////////////////////////////////////////////////////////
  // Logic to delay certain parts of the data received from BRAM A (systolic data setup)
  //////////////////////////////////////////////////////////////////////////
  // Lane k of a_data is forced to zero unless the word is valid and
  // validity_mask_a_rows[k] is set. Lane 0 leaves the module undelayed.
  assign a0_data = a_data[`DWIDTH-1:0] & {`DWIDTH{a_data_valid}} & {`DWIDTH{validity_mask_a_rows[0]}};
  assign a1_data = a_data[2*`DWIDTH-1:`DWIDTH] & {`DWIDTH{a_data_valid}} & {`DWIDTH{validity_mask_a_rows[1]}};
  assign a2_data = a_data[3*`DWIDTH-1:2*`DWIDTH] & {`DWIDTH{a_data_valid}} & {`DWIDTH{validity_mask_a_rows[2]}};
  assign a3_data = a_data[4*`DWIDTH-1:3*`DWIDTH] & {`DWIDTH{a_data_valid}} & {`DWIDTH{validity_mask_a_rows[3]}};

  // Skew registers: lane k is delayed by k clocks. A `MAT_MUL_SIZE other than 4
  // would need a matching number of lanes and delay stages.
  reg [`DWIDTH-1:0] a1_data_delayed_1;
  reg [`DWIDTH-1:0] a2_data_delayed_1;
  reg [`DWIDTH-1:0] a2_data_delayed_2;
  reg [`DWIDTH-1:0] a3_data_delayed_1;
  reg [`DWIDTH-1:0] a3_data_delayed_2;
  reg [`DWIDTH-1:0] a3_data_delayed_3;
  // Cleared in reset, while start_mat_mul is low, and whenever clk_cnt == 0.
  always @(posedge clk) begin
    if (reset || ~start_mat_mul || clk_cnt==0) begin
      a1_data_delayed_1 <= 0;
      a2_data_delayed_1 <= 0;
      a2_data_delayed_2 <= 0;
      a3_data_delayed_1 <= 0;
      a3_data_delayed_2 <= 0;
      a3_data_delayed_3 <= 0;
    end
    else begin
      a1_data_delayed_1 <= a1_data;
      a2_data_delayed_1 <= a2_data;
      a2_data_delayed_2 <= a2_data_delayed_1;
      a3_data_delayed_1 <= a3_data;
      a3_data_delayed_2 <= a3_data_delayed_1;
      a3_data_delayed_3 <= a3_data_delayed_2;
    end
  end

  //////////////////////////////////////////////////////////////////////////
  // Logic to generate addresses to BRAM B
  //////////////////////////////////////////////////////////////////////////
  // Same scheme as BRAM A, keyed on b_loc, address_mat_b and address_stride_b.
  // NOTE(review): the window bounds are likewise evaluated at 8 bits.
  reg [`AWIDTH-1:0] b_addr;
  reg b_mem_access; //flag that tells whether the matmul is trying to access memory or not
  always @(posedge clk) begin
    //Equivalent to: clk_cnt >= b_loc*`MAT_MUL_SIZE+final_mat_mul_size
    //(a shift by `LOG2_MAT_MUL_SIZE is used instead of a multiplication)
    if ((reset || ~start_mat_mul) || (clk_cnt >= (b_loc<<`LOG2_MAT_MUL_SIZE)+final_mat_mul_size)) begin
      b_addr <= address_mat_b - address_stride_b;
      b_mem_access <= 0;
    end
    //Equivalent to: (clk_cnt >= b_loc*`MAT_MUL_SIZE) && (clk_cnt < b_loc*`MAT_MUL_SIZE+final_mat_mul_size)
    //(shift used instead of multiplication)
    else if ((clk_cnt >= (b_loc<<`LOG2_MAT_MUL_SIZE)) && (clk_cnt < (b_loc<<`LOG2_MAT_MUL_SIZE)+final_mat_mul_size)) begin
      b_addr <= b_addr + address_stride_b;
      b_mem_access <= 1;
    end
  end

  //////////////////////////////////////////////////////////////////////////
  // Logic to generate valid signals for data coming from BRAM B
  //////////////////////////////////////////////////////////////////////////
  // Same counter scheme as the A side, driven by b_mem_access.
  reg [7:0] b_mem_access_counter;
  always @(posedge clk) begin
    if (reset || ~start_mat_mul) begin
      b_mem_access_counter <= 0;
    end
    else if (b_mem_access == 1) begin
      b_mem_access_counter <= b_mem_access_counter + 1;
    end
    else begin
      b_mem_access_counter <= 0;
    end
  end

  // As for A, but successive B words are masked by validity_mask_b_rows.
  // NOTE(review): again assumes `MEM_ACCESS_LATENCY is 1.
  wire b_data_valid; //flag that tells whether the data from memory is valid
  assign b_data_valid =
       ((validity_mask_b_rows[0]==1'b0 && b_mem_access_counter==1) ||
        (validity_mask_b_rows[1]==1'b0 && b_mem_access_counter==2) ||
        (validity_mask_b_rows[2]==1'b0 && b_mem_access_counter==3) ||
        (validity_mask_b_rows[3]==1'b0 && b_mem_access_counter==4)) ?
        1'b0 : (b_mem_access_counter >= `MEM_ACCESS_LATENCY);

  //////////////////////////////////////////////////////////////////////////
  // Logic to delay certain parts of the data received from BRAM B (systolic data setup)
  //////////////////////////////////////////////////////////////////////////
  // Lane k of b_data is forced to zero unless the word is valid and
  // validity_mask_b_cols[k] is set. Lane 0 leaves the module undelayed.
  assign b0_data = b_data[`DWIDTH-1:0] & {`DWIDTH{b_data_valid}} & {`DWIDTH{validity_mask_b_cols[0]}};
  assign b1_data = b_data[2*`DWIDTH-1:`DWIDTH] & {`DWIDTH{b_data_valid}} & {`DWIDTH{validity_mask_b_cols[1]}};
  assign b2_data = b_data[3*`DWIDTH-1:2*`DWIDTH] & {`DWIDTH{b_data_valid}} & {`DWIDTH{validity_mask_b_cols[2]}};
  assign b3_data = b_data[4*`DWIDTH-1:3*`DWIDTH] & {`DWIDTH{b_data_valid}} & {`DWIDTH{validity_mask_b_cols[3]}};

  // Skew registers: lane k is delayed by k clocks. A `MAT_MUL_SIZE other than 4
  // would need a matching number of lanes and delay stages.
  reg [`DWIDTH-1:0] b1_data_delayed_1;
  reg [`DWIDTH-1:0] b2_data_delayed_1;
  reg [`DWIDTH-1:0] b2_data_delayed_2;
  reg [`DWIDTH-1:0] b3_data_delayed_1;
  reg [`DWIDTH-1:0] b3_data_delayed_2;
  reg [`DWIDTH-1:0] b3_data_delayed_3;
  // Cleared in reset, while start_mat_mul is low, and whenever clk_cnt == 0.
  always @(posedge clk) begin
    if (reset || ~start_mat_mul || clk_cnt==0) begin
      b1_data_delayed_1 <= 0;
      b2_data_delayed_1 <= 0;
      b2_data_delayed_2 <= 0;
      b3_data_delayed_1 <= 0;
      b3_data_delayed_2 <= 0;
      b3_data_delayed_3 <= 0;
    end
    else begin
      b1_data_delayed_1 <= b1_data;
      b2_data_delayed_1 <= b2_data;
      b2_data_delayed_2 <= b2_data_delayed_1;
      b3_data_delayed_1 <= b3_data;
      b3_data_delayed_2 <= b3_data_delayed_1;
      b3_data_delayed_3 <= b3_data_delayed_2;
    end
  end

endmodule

//////////////////////////////////////////////////////////////////////////
// Systolically connected PEs
//////////////////////////////////////////////////////////////////////////
// 4x4 grid of processing_element instances. PE(r,c) is instance pe<r><c>:
//  * its A operand comes from the left (input a<r> for column 0) and is
//    forwarded to PE(r,c+1) on net a<r><c>to<r><c+1>;
//  * its B operand comes from above (input b<c> for row 0) and is forwarded
//    to PE(r+1,c) on net b<r><c>to<r+1><c>;
//  * matrixC<r><c> is its out_c.
// a_data_out packs the A operands leaving column 3 (row 0 in the LSBs);
// b_data_out packs the B operands leaving row 3 (column 0 in the LSBs).
// Every PE is reset by reset | pe_reset.
module systolic_pe_matrix(
  reset,
  clk,
  pe_reset,
  a0,
  a1,
  a2,
  a3,
  b0,
  b1,
  b2,
  b3,
  matrixC00,
  matrixC01,
  matrixC02,
  matrixC03,
  matrixC10,
  matrixC11,
  matrixC12,
  matrixC13,
  matrixC20,
  matrixC21,
  matrixC22,
  matrixC23,
  matrixC30,
  matrixC31,
  matrixC32,
  matrixC33,
  a_data_out,
  b_data_out
);

  input clk;
  input reset;
  input pe_reset;
  input [`DWIDTH-1:0] a0;
  input [`DWIDTH-1:0] a1;
  input [`DWIDTH-1:0] a2;
  input [`DWIDTH-1:0] a3;
  input [`DWIDTH-1:0] b0;
  input [`DWIDTH-1:0] b1;
  input [`DWIDTH-1:0] b2;
  input [`DWIDTH-1:0] b3;
  output [`DWIDTH-1:0] matrixC00;
  output [`DWIDTH-1:0] matrixC01;
  output [`DWIDTH-1:0] matrixC02;
  output [`DWIDTH-1:0] matrixC03;
  output [`DWIDTH-1:0] matrixC10;
  output [`DWIDTH-1:0] matrixC11;
  output [`DWIDTH-1:0] matrixC12;
  output [`DWIDTH-1:0] matrixC13;
  output [`DWIDTH-1:0] matrixC20;
  output [`DWIDTH-1:0] matrixC21;
  output [`DWIDTH-1:0] matrixC22;
  output [`DWIDTH-1:0] matrixC23;
  output [`DWIDTH-1:0] matrixC30;
  output [`DWIDTH-1:0] matrixC31;
  output [`DWIDTH-1:0] matrixC32;
  output [`DWIDTH-1:0] matrixC33;
  output [`MAT_MUL_SIZE*`DWIDTH-1:0] a_data_out;
  output [`MAT_MUL_SIZE*`DWIDTH-1:0] b_data_out;

  // Horizontal (A) links; a<r>3to<r>4 leave the grid on the right.
  wire [`DWIDTH-1:0] a00to01, a01to02, a02to03, a03to04;
  wire [`DWIDTH-1:0] a10to11, a11to12, a12to13, a13to14;
  wire [`DWIDTH-1:0] a20to21, a21to22, a22to23, a23to24;
  wire [`DWIDTH-1:0] a30to31, a31to32, a32to33, a33to34;

  // Vertical (B) links; b3<c>to4<c> leave the grid at the bottom.
  wire [`DWIDTH-1:0] b00to10, b10to20, b20to30, b30to40;
  wire [`DWIDTH-1:0] b01to11, b11to21, b21to31, b31to41;
  wire [`DWIDTH-1:0] b02to12, b12to22, b22to32, b32to42;
  wire [`DWIDTH-1:0] b03to13, b13to23, b23to33, b33to43;

  // Either the global reset or the PE-only reset clears every PE.
  wire effective_rst;
  assign effective_rst = reset | pe_reset;

  // Row 0
  processing_element pe00(.reset(effective_rst), .clk(clk), .in_a(a0), .in_b(b0), .out_a(a00to01), .out_b(b00to10), .out_c(matrixC00));
  processing_element pe01(.reset(effective_rst), .clk(clk), .in_a(a00to01), .in_b(b1), .out_a(a01to02), .out_b(b01to11), .out_c(matrixC01));
  processing_element pe02(.reset(effective_rst), .clk(clk), .in_a(a01to02), .in_b(b2), .out_a(a02to03), .out_b(b02to12), .out_c(matrixC02));
  processing_element pe03(.reset(effective_rst), .clk(clk), .in_a(a02to03), .in_b(b3), .out_a(a03to04), .out_b(b03to13), .out_c(matrixC03));

  // Row 1
  processing_element pe10(.reset(effective_rst), .clk(clk), .in_a(a1), .in_b(b00to10), .out_a(a10to11), .out_b(b10to20), .out_c(matrixC10));
  processing_element pe11(.reset(effective_rst), .clk(clk), .in_a(a10to11), .in_b(b01to11), .out_a(a11to12), .out_b(b11to21), .out_c(matrixC11));
  processing_element pe12(.reset(effective_rst), .clk(clk), .in_a(a11to12), .in_b(b02to12), .out_a(a12to13), .out_b(b12to22), .out_c(matrixC12));
  processing_element pe13(.reset(effective_rst), .clk(clk), .in_a(a12to13), .in_b(b03to13), .out_a(a13to14), .out_b(b13to23), .out_c(matrixC13));

  // Row 2
  processing_element pe20(.reset(effective_rst), .clk(clk), .in_a(a2), .in_b(b10to20), .out_a(a20to21), .out_b(b20to30), .out_c(matrixC20));
  processing_element pe21(.reset(effective_rst), .clk(clk), .in_a(a20to21), .in_b(b11to21), .out_a(a21to22), .out_b(b21to31), .out_c(matrixC21));
  processing_element pe22(.reset(effective_rst), .clk(clk), .in_a(a21to22), .in_b(b12to22), .out_a(a22to23), .out_b(b22to32), .out_c(matrixC22));
  processing_element pe23(.reset(effective_rst), .clk(clk), .in_a(a22to23), .in_b(b13to23), .out_a(a23to24), .out_b(b23to33), .out_c(matrixC23));

  // Row 3
  processing_element pe30(.reset(effective_rst), .clk(clk), .in_a(a3), .in_b(b20to30), .out_a(a30to31), .out_b(b30to40), .out_c(matrixC30));
  processing_element pe31(.reset(effective_rst), .clk(clk), .in_a(a30to31), .in_b(b21to31), .out_a(a31to32), .out_b(b31to41), .out_c(matrixC31));
  processing_element pe32(.reset(effective_rst), .clk(clk), .in_a(a31to32), .in_b(b22to32), .out_a(a32to33), .out_b(b32to42), .out_c(matrixC32));
  processing_element pe33(.reset(effective_rst), .clk(clk), .in_a(a32to33), .in_b(b23to33), .out_a(a33to34), .out_b(b33to43), .out_c(matrixC33));

  // Operands leaving the right column / bottom row, lane 0 in the LSBs.
  assign a_data_out = {a33to34,a23to24,a13to14,a03to04};
  assign b_data_out = {b33to43,b32to42,b31to41,b30to40};

endmodule

//////////////////////////////////////////////////////////////////////////
// Processing element (PE)
//////////////////////////////////////////////////////////////////////////
// One cell of the systolic grid.
//  * out_c is driven directly by the mac_fp hard block. That module is defined
//    in the architecture file as a mode of the DSP slice (floating-point
//    16-bit multiply-accumulate); it sees this cycle's in_a/in_b and shares
//    the PE's clk and reset.
//  * out_a/out_b are in_a/in_b registered for one clock, so operands advance
//    one PE per cycle. Both registers clear to zero while reset is high
//    (synchronous reset).
module processing_element(
  reset,
  clk,
  in_a,
  in_b,
  out_a,
  out_b,
  out_c
);

  input                reset;
  input                clk;
  input  [`DWIDTH-1:0] in_a;    // operand from the left neighbour (or array edge)
  input  [`DWIDTH-1:0] in_b;    // operand from the neighbour above (or array edge)
  output [`DWIDTH-1:0] out_a;   // in_a delayed by one clock
  output [`DWIDTH-1:0] out_b;   // in_b delayed by one clock
  output [`DWIDTH-1:0] out_c;   // mac_fp result, same width as the operands

  reg    [`DWIDTH-1:0] out_a;
  reg    [`DWIDTH-1:0] out_b;

  // Floating-point MAC hard block; its result feeds out_c with no extra logic.
  mac_fp u_mac(
    .a    (in_a),
    .b    (in_b),
    .out  (out_c),
    .reset(reset),
    .clk  (clk)
  );

  // Operand pass-through stage.
  always @(posedge clk) begin
    if (reset) begin
      out_a <= {`DWIDTH{1'b0}};
      out_b <= {`DWIDTH{1'b0}};
    end
    else begin
      out_a <= in_a;
      out_b <= in_b;
    end
  end

endmodule