diff --git a/frontends/ast/simplify.cc b/frontends/ast/simplify.cc index 6970135d0..c9da8ba48 100644 --- a/frontends/ast/simplify.cc +++ b/frontends/ast/simplify.cc @@ -380,31 +380,100 @@ static int size_packed_struct(AstNode *snode, int base_offset) static AstNode *node_int(int ival) { - // maybe mkconst_int should have default values for the common integer case - return AstNode::mkconst_int(ival, true, 32); + return AstNode::mkconst_int(ival, true); } -static AstNode *offset_indexed_range(int offset_right, int stride, AstNode *left_expr, AstNode *right_expr) +static AstNode *node_uint(uint ival) +{ + return AstNode::mkconst_int(ival, false); +} + +static unsigned int power_of_two(int n) +{ + // iff n is a power of two then return the power, else return 0 + // caller must ensure n > 1 + log_assert(n > 1); + if (n & (n - 1)) { + // not a power of 2 + return 0; + } + // brute force the shift + for (unsigned int i = 1; i < 32; i++) { + n >>= 1; + if (n & 1) { + return i; + } + } + return 0; +} + +static AstNode *multiply_by_const(AstNode *expr_node, int stride) +{ + // the stride is very likely a power of 2, e.g. 8 for bytes + // and so could be optimised with a shift + AstNode *node; + unsigned int shift; + if ((shift = power_of_two(stride)) > 0) { + node = new AstNode(AST_SHIFT_LEFT, expr_node, node_uint(shift)); + } + else { + node = new AstNode(AST_MUL, expr_node, node_int(stride)); + } + return node; +} + +static AstNode *offset_indexed_range(int offset, int stride, AstNode *left_expr, AstNode *right_expr) { // adjust the range expressions to add an offset into the struct // and maybe index using an array stride auto left = left_expr->clone(); auto right = right_expr->clone(); - if (stride == 1) { - // just add the offset - left = new AstNode(AST_ADD, node_int(offset_right), left); - right = new AstNode(AST_ADD, node_int(offset_right), right); + if (stride > 1) { + // newleft = (left + 1) * stride - 1 + left = new AstNode(AST_SUB, multiply_by_const(new AstNode(AST_ADD, left, node_int(1)), stride), node_int(1)); + // newright = right * stride + right = multiply_by_const(right, stride); } - else { - // newleft = offset_right - 1 + (left + 1) * stride - left = new AstNode(AST_ADD, new AstNode(AST_SUB, node_int(offset_right), node_int(1)), - new AstNode(AST_MUL, node_int(stride), new AstNode(AST_ADD, left, node_int(1)))); - // newright = offset_right + right * stride - right = new AstNode(AST_ADD, node_int(offset_right), new AstNode(AST_MUL, right, node_int(stride))); + // add the offset + if (offset) { + left = new AstNode(AST_ADD, node_int(offset), left); + right = new AstNode(AST_ADD, node_int(offset), right); } return new AstNode(AST_RANGE, left, right); } +static AstNode *make_struct_index_range(AstNode *node, AstNode *rnode, int stride, int offset) +{ + // generate a range node to perform either bit or array indexing + if (rnode->children.size() == 1) { + // index e.g. s.a[i] + return offset_indexed_range(offset, stride, rnode->children[0], rnode->children[0]); + } + else if (rnode->children.size() == 2) { + // slice e.g. s.a[i:j] + return offset_indexed_range(offset, stride, rnode->children[0], rnode->children[1]); + } + else { + struct_op_error(node); + } +} + +static AstNode *slice_range(AstNode *rnode, AstNode *snode) +{ + // apply the bit slice indicated by snode to the range rnode + log_assert(rnode->type==AST_RANGE); + auto left = rnode->children[0]; + auto right = rnode->children[1]; + log_assert(snode->type==AST_RANGE); + auto slice_left = snode->children[0]; + auto slice_right = snode->children[1]; + auto width = new AstNode(AST_SUB, slice_left->clone(), slice_right->clone()); + right = new AstNode(AST_ADD, right->clone(), slice_right->clone()); + left = new AstNode(AST_ADD, right->clone(), width); + return new AstNode(AST_RANGE, left, right); +} + + static AstNode *make_struct_member_range(AstNode *node, AstNode *member_node) { // Work out the range in the packed array that corresponds to a struct member @@ -414,27 +483,26 @@ static AstNode *make_struct_member_range(AstNode *node, AstNode *member_node) int range_right = member_node->range_right; if (node->children.empty()) { // no range operations apply, return the whole width + return make_range(range_left, range_right); } - else if (node->children.size() == 1 && node->children[0]->type == AST_RANGE) { - auto rnode = node->children[0]; - int stride = get_struct_array_width(member_node); - if (rnode->children.size() == 1) { - // index e.g. s.a[i] - return offset_indexed_range(range_right, stride, rnode->children[0], rnode->children[0]); - } - else if (rnode->children.size() == 2) { - // slice e.g. s.a[i:j] - return offset_indexed_range(range_right, stride, rnode->children[0], rnode->children[1]); - } - else { - struct_op_error(node); - } + int stride = get_struct_array_width(member_node); + if (node->children.size() == 1 && node->children[0]->type == AST_RANGE) { + // bit or array indexing e.g. s.a[2] or s.a[1:0] + return make_struct_index_range(node, node->children[0], stride, range_right); + } + else if (node->children.size() == 1 && node->children[0]->type == AST_MULTIRANGE) { + // multirange, i.e. bit slice after array index, e.g. s.a[i][p:q] + log_assert(stride > 1); + auto mrnode = node->children[0]; + auto element_range = make_struct_index_range(node, mrnode->children[0], stride, range_right); + // then apply bit slice range + auto range = slice_range(element_range, mrnode->children[1]); + delete element_range; + return range; } else { - // TODO multirange, i.e. bit slice after array index s.a[i][p:q] struct_op_error(node); } - return make_range(range_left, range_right); } static void add_members_to_scope(AstNode *snode, std::string name) diff --git a/tests/svtypes/struct_array.sv b/tests/svtypes/struct_array.sv index 022ad56c6..9c90375ee 100644 --- a/tests/svtypes/struct_array.sv +++ b/tests/svtypes/struct_array.sv @@ -3,7 +3,7 @@ module top; struct packed { - bit [5:0] [7:0] a; // 6 element packed array of bytes + bit [7:0] [7:0] a; // 8 element packed array of bytes bit [15:0] b; // filler for non-zero offset } s; @@ -13,10 +13,13 @@ module top; s.a[2:1] = 16'h1234; s.a[5] = 8'h42; + s.a[7] = '1; + s.a[7][1:0] = '0; + s.b = '1; s.b[1:0] = '0; end - always_comb assert(s==64'h4200_0012_3400_FFFC); + always_comb assert(s==80'hFC00_4200_0012_3400_FFFC); endmodule