diff --git a/kernel/functionalir.cc b/kernel/functionalir.cc index 48baad716..ddd821ada 100644 --- a/kernel/functionalir.cc +++ b/kernel/functionalir.cc @@ -101,17 +101,6 @@ std::string FunctionalIR::Node::to_string(std::function np) class CellSimplifier { using Node = FunctionalIR::Node; FunctionalIR::Factory &factory; - Node reduce_shift_width(Node b, int y_width) { - log_assert(y_width > 0); - int new_width = ceil_log2(y_width + 1); - if (b.width() <= new_width) { - return b; - } else { - Node lower_b = factory.slice(b, 0, new_width); - Node overflow = factory.unsigned_greater_than(b, factory.constant(RTLIL::Const(y_width, b.width()))); - return factory.mux(lower_b, factory.constant(RTLIL::Const(y_width, new_width)), overflow); - } - } Node sign(Node a) { return factory.slice(a, a.width() - 1, 1); } @@ -121,19 +110,30 @@ class CellSimplifier { Node abs(Node a) { return neg_if(a, sign(a)); } + Node handle_shift(Node a, Node b, bool is_right, bool is_signed) { + // to prevent new_width == 0, we handle this case separately + if(a.width() == 1) { + if(!is_signed) + return factory.bitwise_and(a, factory.bitwise_not(factory.reduce_or(b))); + else + return a; + } + int new_width = ceil_log2(a.width()); + Node b_truncated = factory.extend(b, new_width, false); + Node y = + !is_right ? factory.logical_shift_left(a, b_truncated) : + !is_signed ? factory.logical_shift_right(a, b_truncated) : + factory.arithmetic_shift_right(a, b_truncated); + if(b.width() <= new_width) + return y; + Node overflow = factory.unsigned_greater_equal(b, factory.constant(RTLIL::Const(a.width(), b.width()))); + Node y_if_overflow = is_signed ? factory.extend(sign(a), a.width(), true) : factory.constant(RTLIL::Const(State::S0, a.width())); + return factory.mux(y, y_if_overflow, overflow); + } public: - Node logical_shift_left(Node a, Node b) { - Node reduced_b = reduce_shift_width(b, a.width()); - return factory.logical_shift_left(a, reduced_b); - } - Node logical_shift_right(Node a, Node b) { - Node reduced_b = reduce_shift_width(b, a.width()); - return factory.logical_shift_right(a, reduced_b); - } - Node arithmetic_shift_right(Node a, Node b) { - Node reduced_b = reduce_shift_width(b, a.width()); - return factory.arithmetic_shift_right(a, reduced_b); - } + Node logical_shift_left(Node a, Node b) { return handle_shift(a, b, false, false); } + Node logical_shift_right(Node a, Node b) { return handle_shift(a, b, true, false); } + Node arithmetic_shift_right(Node a, Node b) { return handle_shift(a, b, true, true); } Node bitwise_mux(Node a, Node b, Node s) { Node aa = factory.bitwise_and(a, factory.bitwise_not(s)); Node bb = factory.bitwise_and(b, s); @@ -348,7 +348,7 @@ public: int width = parameters.at(ID(WIDTH)).as_int(); int s_width = parameters.at(ID(S_WIDTH)).as_int(); int y_width = width << s_width; - int b_width = ceil_log2(y_width + 1); + int b_width = ceil_log2(y_width); Node a = factory.extend(inputs.at(ID(A)), y_width, false); Node s = factory.extend(inputs.at(ID(S)), b_width, false); Node b = factory.mul(s, factory.constant(Const(width, b_width))); diff --git a/kernel/functionalir.h b/kernel/functionalir.h index 44e3589db..0559eae89 100644 --- a/kernel/functionalir.h +++ b/kernel/functionalir.h @@ -109,13 +109,13 @@ public: // unsigned_greater_equal(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a >= b) unsigned_greater_equal, // logical_shift_left(a: bit[N], b: unsigned bit[M]): bit[N] = a << b - // required: M <= clog2(N + 1) + // required: M == clog2(N) logical_shift_left, // logical_shift_right(a: unsigned bit[N], b: unsigned bit[M]): unsigned bit[N] = a >> b - // required: M <= clog2(N + 1) + // required: M == clog2(N) logical_shift_right, // arithmetic_shift_right(a: signed bit[N], b: unsigned bit[M]): signed bit[N] = a >> b - // required: M <= clog2(N + 1) + // required: M == clog2(N) arithmetic_shift_right, // mux(a: bit[N], b: bit[N], s: bit[1]): bit[N] = s ? b : a mux, @@ -373,6 +373,8 @@ public: friend class FunctionalIR; explicit Factory(FunctionalIR &ir) : _ir(ir) {} Node add(NodeData &&fn, Sort &&sort, std::initializer_list args) { + log_assert(!sort.is_signal() || sort.width() > 0); + log_assert(!sort.is_memory() || sort.addr_width() > 0 && sort.data_width() > 0); Graph::Ref ref = _ir._graph.add(std::move(fn), {std::move(sort)}); for (auto arg : args) ref.append_arg(Graph::ConstRef(arg)); @@ -382,7 +384,7 @@ public: return _ir._graph[n._ref.index()]; } void check_basic_binary(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && a.sort() == b.sort()); } - void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal()); } + void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal() && b.width() == ceil_log2(a.width())); } void check_unary(Node const &a) { log_assert(a.sort().is_signal()); } public: Node slice(Node a, int offset, int out_width) { diff --git a/tests/functional/rtlil_cells.py b/tests/functional/rtlil_cells.py index b4c0ad1c1..7858e3781 100644 --- a/tests/functional/rtlil_cells.py +++ b/tests/functional/rtlil_cells.py @@ -206,7 +206,10 @@ shift_widths = [ (32, 32, 64, True, True), (32, 32, 64, False, True), # at least one test where the result is going to be truncated - (32, 6, 16, False, False) + (32, 6, 16, False, False), + # since 1-bit shifts are special cased + (1, 4, 1, False, False), + (1, 4, 1, True, False), ] rtlil_cells = [