pattern shiftmul_right
//
// Optimize mul+shift pairs that result from expressions such as foo[s*W+:W]
//

match shift
	// candidate cells: the shift cell types handled by this pattern
	select shift->type.in($shr, $shift, $shiftx)
	// the cell must actually have a shift amount on its B port
	filter GetSize(port(shift, \B)) > 0
endmatch

// the right shift amount: starts out as the shift cell's B port and is
// then reduced by stripping constant-zero bits from both ends
state <SigSpec> shift_amount
// log2 of the scale factor to apply when interpreting shift_amount,
// i.e. the number of constant-zero bits that were stripped from the
// index-0 end of the shift cell's B port
state <int> log2scale

code shift_amount log2scale
	// start from the B port of the shift cell
	shift_amount = port(shift, \B);

	// an unsigned amount is made signed by giving it
	// an explicit constant-zero sign bit
	if (shift->type.in($shr))
		shift_amount.append(State::S0);
	else if (!param(shift, \B_SIGNED).as_bool())
		shift_amount.append(State::S0);

	// the amount is signed from here on; unless its top bit
	// is constant zero it could become negative
	if (shift_amount.bits().back() != State::S0)
		reject;

	// strip the constant-zero bits at the top (the check above
	// guarantees there is at least one); give up if nothing remains
	do {
		shift_amount.remove(GetSize(shift_amount) - 1);
		if (shift_amount.empty()) reject;
	} while (shift_amount.bits().back() == State::S0);

	// strip the constant-zero bits at the bottom, counting
	// how many were removed as the log2 scale factor
	for (log2scale = 0; shift_amount[0] == State::S0; log2scale++) {
		shift_amount.remove(0);
		if (shift_amount.empty()) reject;
	}

	// only shift amounts of at most 20 bits are handled
	if (GetSize(shift_amount) > 20)
		reject;
endcode

// the non-constant operand of the matched multiplier, read unmapped
state <SigSpec> mul_din
// the constant operand of the matched multiplier, combined with
// log2scale constant-zero bits
state <Const> mul_const

match mul
	// only multiplier cells are candidates
	select mul->type.in($mul)
	// the multiplier output must be exactly the stripped shift amount
	index <SigSpec> port(mul, \Y) === shift_amount
	// skip multipliers whose A_SIGNED parameter is set
	filter !param(mul, \A_SIGNED).as_bool()

	// try either input as the constant operand; it must be fully constant
	choice <IdString> constport {\A, \B}
	filter port(mul, constport).is_fully_const()

	// the other input is the variable operand
	define <IdString> varport (constport == \A ? \B : \A)
	// constant operand concatenated with log2scale zero bits
	// NOTE(review): presumably the zeros end up as the low-order bits,
	// i.e. the constant is scaled by 2^log2scale -- confirm against the
	// SigSpec initializer-list ordering
	set mul_const SigSpec({port(mul, constport), SigSpec(State::S0, log2scale)}).as_const()
	// get mul_din unmapped (so no `port()` shorthand)
	// because we will be using it to set the \A port
	// on the shift cell, and we want to stay close
	// to the original design
	set mul_din mul->getPort(varport)
endmatch

code
{
	// Rewrite the matched shift so that it no longer depends on the
	// multiplier: the A input is re-laid-out in slices padded to a
	// power-of-two stride and the B input becomes mul_din followed by
	// constant zeros.

	// only constants of at most 20 bits are handled
	if (mul_const.empty() || GetSize(mul_const) > 20)
		reject;

	int const_factor = mul_const.as_int();

	// a zero factor cannot be rewritten: the slicing loop below
	// advances in steps of const_factor and would never terminate
	if (const_factor <= 0)
		reject;

	// make sure there's no overlap in the signal
	// selections by the shiftmul pattern
	if (GetSize(port(shift, \Y)) > const_factor)
		reject;

	int factor_bits = ceil_log2(const_factor);
	// make sure the multiplication never wraps around
	if (GetSize(shift_amount) + log2scale < factor_bits + GetSize(mul_din))
		reject;

	did_something = true;
	log("right shiftmul pattern in %s: shift=%s, mul=%s\n", log_id(module), log_id(shift), log_id(mul));

	// stride of the new layout: const_factor rounded up to a power of two
	int new_const_factor = 1 << factor_bits;
	// undefined bits filling each slice up to the new stride
	SigSpec padding(State::Sx, new_const_factor-const_factor);
	SigSpec old_a = port(shift, \A), new_a;

	// copy the old A input slice by slice, padding every complete
	// slice that is followed by more data; the final (possibly
	// partial) slice is appended as is
	for (int i = 0; i*const_factor < GetSize(old_a); i++) {
		if ((i+1)*const_factor < GetSize(old_a)) {
			SigSpec slice = old_a.extract(i*const_factor, const_factor);
			new_a.append(slice);
			new_a.append(padding);
		} else {
			new_a.append(old_a.extract_end(i*const_factor));
		}
	}

	// new shift amount: mul_din combined with factor_bits constant zeros
	SigSpec new_b = {mul_din, SigSpec(State::S0, factor_bits)};
	// a signed B port gets an explicit constant-zero top bit
	if (param(shift, \B_SIGNED).as_bool())
		new_b.append(State::S0);

	shift->setPort(\A, new_a);
	shift->setParam(\A_WIDTH, GetSize(new_a));
	shift->setPort(\B, new_b);
	shift->setParam(\B_WIDTH, GetSize(new_b));

	// do not match this shift cell again
	blacklist(shift);
	accept;
}
endcode