functional backend: require shift width == clog2(operand width)

This commit is contained in:
Emily Schmidt 2024-07-17 13:26:53 +01:00
parent 7f8f21b980
commit c0c90c2c31
3 changed files with 34 additions and 29 deletions

View File

@ -101,17 +101,6 @@ std::string FunctionalIR::Node::to_string(std::function<std::string(Node)> np)
class CellSimplifier {
using Node = FunctionalIR::Node;
FunctionalIR::Factory &factory;
Node reduce_shift_width(Node b, int y_width) {
log_assert(y_width > 0);
int new_width = ceil_log2(y_width + 1);
if (b.width() <= new_width) {
return b;
} else {
Node lower_b = factory.slice(b, 0, new_width);
Node overflow = factory.unsigned_greater_than(b, factory.constant(RTLIL::Const(y_width, b.width())));
return factory.mux(lower_b, factory.constant(RTLIL::Const(y_width, new_width)), overflow);
}
}
Node sign(Node a) {
return factory.slice(a, a.width() - 1, 1);
}
@ -121,19 +110,30 @@ class CellSimplifier {
Node abs(Node a) {
return neg_if(a, sign(a));
}
Node handle_shift(Node a, Node b, bool is_right, bool is_signed) {
// to prevent new_width == 0, we handle this case separately
if(a.width() == 1) {
if(!is_signed)
return factory.bitwise_and(a, factory.bitwise_not(factory.reduce_or(b)));
else
return a;
}
int new_width = ceil_log2(a.width());
Node b_truncated = factory.extend(b, new_width, false);
Node y =
!is_right ? factory.logical_shift_left(a, b_truncated) :
!is_signed ? factory.logical_shift_right(a, b_truncated) :
factory.arithmetic_shift_right(a, b_truncated);
if(b.width() <= new_width)
return y;
Node overflow = factory.unsigned_greater_equal(b, factory.constant(RTLIL::Const(a.width(), b.width())));
Node y_if_overflow = is_signed ? factory.extend(sign(a), a.width(), true) : factory.constant(RTLIL::Const(State::S0, a.width()));
return factory.mux(y, y_if_overflow, overflow);
}
public:
Node logical_shift_left(Node a, Node b) {
Node reduced_b = reduce_shift_width(b, a.width());
return factory.logical_shift_left(a, reduced_b);
}
Node logical_shift_right(Node a, Node b) {
Node reduced_b = reduce_shift_width(b, a.width());
return factory.logical_shift_right(a, reduced_b);
}
Node arithmetic_shift_right(Node a, Node b) {
Node reduced_b = reduce_shift_width(b, a.width());
return factory.arithmetic_shift_right(a, reduced_b);
}
Node logical_shift_left(Node a, Node b) { return handle_shift(a, b, false, false); }
Node logical_shift_right(Node a, Node b) { return handle_shift(a, b, true, false); }
Node arithmetic_shift_right(Node a, Node b) { return handle_shift(a, b, true, true); }
Node bitwise_mux(Node a, Node b, Node s) {
Node aa = factory.bitwise_and(a, factory.bitwise_not(s));
Node bb = factory.bitwise_and(b, s);
@ -348,7 +348,7 @@ public:
int width = parameters.at(ID(WIDTH)).as_int();
int s_width = parameters.at(ID(S_WIDTH)).as_int();
int y_width = width << s_width;
int b_width = ceil_log2(y_width + 1);
int b_width = ceil_log2(y_width);
Node a = factory.extend(inputs.at(ID(A)), y_width, false);
Node s = factory.extend(inputs.at(ID(S)), b_width, false);
Node b = factory.mul(s, factory.constant(Const(width, b_width)));

View File

@ -109,13 +109,13 @@ public:
// unsigned_greater_equal(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a >= b)
unsigned_greater_equal,
// logical_shift_left(a: bit[N], b: unsigned bit[M]): bit[N] = a << b
// required: M <= clog2(N + 1)
// required: M == clog2(N)
logical_shift_left,
// logical_shift_right(a: unsigned bit[N], b: unsigned bit[M]): unsigned bit[N] = a >> b
// required: M <= clog2(N + 1)
// required: M == clog2(N)
logical_shift_right,
// arithmetic_shift_right(a: signed bit[N], b: unsigned bit[M]): signed bit[N] = a >> b
// required: M <= clog2(N + 1)
// required: M == clog2(N)
arithmetic_shift_right,
// mux(a: bit[N], b: bit[N], s: bit[1]): bit[N] = s ? b : a
mux,
@ -373,6 +373,8 @@ public:
friend class FunctionalIR;
explicit Factory(FunctionalIR &ir) : _ir(ir) {}
Node add(NodeData &&fn, Sort &&sort, std::initializer_list<Node> args) {
log_assert(!sort.is_signal() || sort.width() > 0);
log_assert(!sort.is_memory() || sort.addr_width() > 0 && sort.data_width() > 0);
Graph::Ref ref = _ir._graph.add(std::move(fn), {std::move(sort)});
for (auto arg : args)
ref.append_arg(Graph::ConstRef(arg));
@ -382,7 +384,7 @@ public:
return _ir._graph[n._ref.index()];
}
void check_basic_binary(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && a.sort() == b.sort()); }
void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal()); }
void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal() && b.width() == ceil_log2(a.width())); }
void check_unary(Node const &a) { log_assert(a.sort().is_signal()); }
public:
Node slice(Node a, int offset, int out_width) {

View File

@ -206,7 +206,10 @@ shift_widths = [
(32, 32, 64, True, True),
(32, 32, 64, False, True),
# at least one test where the result is going to be truncated
(32, 6, 16, False, False)
(32, 6, 16, False, False),
# since 1-bit shifts are special cased
(1, 4, 1, False, False),
(1, 4, 1, True, False),
]
rtlil_cells = [