peepopt: Add left-shift 'shiftmul' variant

Add a separate shiftmul pattern to match on left shifts which implement
demuxing. This mirrors the right shift pattern matcher but is probably
best kept separate instead of merging the two into a single matcher.
In any case the diff of the two matchers should be easily readable.
This commit is contained in:
Martin Povišer 2023-08-02 18:04:54 +02:00
parent 038a5e1ed4
commit aa9b86aeec
4 changed files with 167 additions and 5 deletions

View File

@ -42,7 +42,8 @@ GENFILES += passes/pmgen/peepopt_pm.h
passes/pmgen/peepopt.o: passes/pmgen/peepopt_pm.h passes/pmgen/peepopt.o: passes/pmgen/peepopt_pm.h
$(eval $(call add_extra_objs,passes/pmgen/peepopt_pm.h)) $(eval $(call add_extra_objs,passes/pmgen/peepopt_pm.h))
PEEPOPT_PATTERN = passes/pmgen/peepopt_shiftmul.pmg PEEPOPT_PATTERN = passes/pmgen/peepopt_shiftmul_right.pmg
PEEPOPT_PATTERN += passes/pmgen/peepopt_shiftmul_left.pmg
PEEPOPT_PATTERN += passes/pmgen/peepopt_muldiv.pmg PEEPOPT_PATTERN += passes/pmgen/peepopt_muldiv.pmg
passes/pmgen/peepopt_pm.h: passes/pmgen/pmgen.py $(PEEPOPT_PATTERN) passes/pmgen/peepopt_pm.h: passes/pmgen/pmgen.py $(PEEPOPT_PATTERN)

View File

@ -59,7 +59,7 @@ struct PeepoptPass : public Pass {
if (!genmode.empty()) if (!genmode.empty())
{ {
if (genmode == "shiftmul") if (genmode == "shiftmul")
GENERATE_PATTERN(peepopt_pm, shiftmul); GENERATE_PATTERN(peepopt_pm, shiftmul_right);
else if (genmode == "muldiv") else if (genmode == "muldiv")
GENERATE_PATTERN(peepopt_pm, muldiv); GENERATE_PATTERN(peepopt_pm, muldiv);
else else
@ -79,7 +79,8 @@ struct PeepoptPass : public Pass {
pm.setup(module->selected_cells()); pm.setup(module->selected_cells());
pm.run_shiftmul(); pm.run_shiftmul_right();
pm.run_shiftmul_left();
pm.run_muldiv(); pm.run_muldiv();
} }
} }

View File

@ -0,0 +1,160 @@
pattern shiftmul_left
//
// Optimize mul+shift pairs that result from expressions such as foo[s*W+:W]
//
match shift
select shift->type.in($shift, $shiftx, $shl)
select shift->type.in($shl) || param(shift, \B_SIGNED).as_bool()
filter !port(shift, \B).empty()
endmatch
match neg
if shift->type.in($shift, $shiftx)
select neg->type == $neg
index <SigSpec> port(neg, \Y) === port(shift, \B)
filter !port(shift, \A).empty()
endmatch
// the left shift amount
state <SigSpec> shift_amount
// log2 scale factor in interpreting of shift_amount
// due to zero padding on the shift cell's B port
state <int> log2scale
code shift_amount log2scale
if (neg) {
// case of `$shift`, `$shiftx`
shift_amount = port(neg, \A);
if (!param(neg, \A_SIGNED).as_bool())
shift_amount.append(State::S0);
} else {
// case of `$shl`
shift_amount = port(shift, \B);
if (!param(shift, \B_SIGNED).as_bool())
shift_amount.append(State::S0);
}
// at this point shift_amount is signed, make
// sure we can never go negative
if (shift_amount.bits().back() != State::S0)
reject;
while (shift_amount.bits().back() == State::S0) {
shift_amount.remove(GetSize(shift_amount) - 1);
if (shift_amount.empty()) reject;
}
log2scale = 0;
while (shift_amount[0] == State::S0) {
shift_amount.remove(0);
if (shift_amount.empty()) reject;
log2scale++;
}
if (GetSize(shift_amount) > 20)
reject;
endcode
state <SigSpec> mul_din
state <Const> mul_const
match mul
select mul->type.in($mul)
index <SigSpec> port(mul, \Y) === shift_amount
filter !param(mul, \A_SIGNED).as_bool()
choice <IdString> constport {\A, \B}
filter port(mul, constport).is_fully_const()
define <IdString> varport (constport == \A ? \B : \A)
set mul_const SigSpec({port(mul, constport), SigSpec(State::S0, log2scale)}).as_const()
// get mul_din unmapped (so no `port()` shorthand)
// because we will be using it to set the \A port
// on the shift cell, and we want to stay close
// to the original design
set mul_din mul->getPort(varport)
endmatch
code
{
if (mul_const.empty() || GetSize(mul_const) > 20)
reject;
// make sure there's no overlap in the signal
// selections by the shiftmul pattern
if (GetSize(port(shift, \A)) > mul_const.as_int())
reject;
int factor_bits = ceil_log2(mul_const.as_int());
// make sure the multiplication never wraps around
if (GetSize(shift_amount) < factor_bits + GetSize(mul_din))
reject;
if (neg) {
// make sure the negation never wraps around
if (GetSize(port(shift, \B)) < factor_bits + GetSize(mul_din)
+ log2scale + 1)
reject;
}
did_something = true;
log("left shiftmul pattern in %s: shift=%s, mul=%s\n", log_id(module), log_id(shift), log_id(mul));
int const_factor = mul_const.as_int();
int new_const_factor = 1 << factor_bits;
SigSpec padding(State::Sm, new_const_factor-const_factor);
SigSpec old_y = port(shift, \Y), new_y;
int trunc = 0;
if (GetSize(old_y) % const_factor != 0) {
trunc = const_factor - GetSize(old_y) % const_factor;
old_y.append(SigSpec(State::Sm, trunc));
}
for (int i = 0; i*const_factor < GetSize(old_y); i++) {
SigSpec slice = old_y.extract(i*const_factor, const_factor);
new_y.append(slice);
new_y.append(padding);
}
if (trunc > 0)
new_y.remove(GetSize(new_y)-trunc, trunc);
{
// Now replace occurences of Sm in new_y with bits
// of a dummy wire
int padbits = 0;
for (auto bit : new_y)
if (bit == SigBit(State::Sm))
padbits++;
SigSpec padwire = module->addWire(NEW_ID, padbits);
for (int i = new_y.size() - 1; i >= 0; i--)
if (new_y[i] == SigBit(State::Sm)) {
new_y[i] = padwire.bits().back();
padwire.remove(padwire.size() - 1);
}
}
SigSpec new_b = {mul_din, SigSpec(State::S0, factor_bits)};
shift->setPort(\Y, new_y);
shift->setParam(\Y_WIDTH, GetSize(new_y));
if (shift->type == $shl) {
if (param(shift, \B_SIGNED).as_bool())
new_b.append(State::S0);
shift->setPort(\B, new_b);
shift->setParam(\B_WIDTH, GetSize(new_b));
} else {
SigSpec b_neg = module->addWire(NEW_ID, GetSize(new_b) + 1);
module->addNeg(NEW_ID, new_b, b_neg);
shift->setPort(\B, b_neg);
shift->setParam(\B_WIDTH, GetSize(b_neg));
}
blacklist(shift);
accept;
}
endcode

View File

@ -1,4 +1,4 @@
pattern shiftmul pattern shiftmul_right
// //
// Optimize mul+shift pairs that result from expressions such as foo[s*W+:W] // Optimize mul+shift pairs that result from expressions such as foo[s*W+:W]
// //
@ -76,7 +76,7 @@ code
reject; reject;
did_something = true; did_something = true;
log("shiftmul pattern in %s: shift=%s, mul=%s\n", log_id(module), log_id(shift), log_id(mul)); log("right shiftmul pattern in %s: shift=%s, mul=%s\n", log_id(module), log_id(shift), log_id(mul));
int const_factor = mul_const.as_int(); int const_factor = mul_const.as_int();
int new_const_factor = 1 << factor_bits; int new_const_factor = 1 << factor_bits;