Add "onehot" pass, improve "pmux2shiftx" onehot handling

Signed-off-by: Clifford Wolf <clifford@clifford.at>
This commit is contained in:
Clifford Wolf 2019-04-20 17:52:16 +02:00
parent b3a3e08e38
commit 97e9caa4fa
1 changed files with 404 additions and 13 deletions

View File

@ -23,6 +23,172 @@
USING_YOSYS_NAMESPACE
PRIVATE_NAMESPACE_BEGIN
struct OnehotDatabase
{
Module *module;
const SigMap &sigmap;
bool verbose = false;
pool<SigBit> init_ones;
dict<SigSpec, pool<SigSpec>> sig_sources_db;
dict<SigSpec, bool> sig_onehot_cache;
pool<SigSpec> recursion_guard;
OnehotDatabase(Module *module, const SigMap &sigmap) : module(module), sigmap(sigmap)
{
}
void initialize()
{
for (auto wire : module->wires())
{
auto it = wire->attributes.find("\\init");
if (it == wire->attributes.end())
continue;
auto &val = it->second;
int width = std::max(GetSize(wire), GetSize(val));
for (int i = 0; i < width; i++)
if (val[i] == State::S1)
init_ones.insert(sigmap(SigBit(wire, i)));
}
for (auto cell : module->cells())
{
vector<SigSpec> inputs;
SigSpec output;
if (cell->type.in("$adff", "$dff", "$dffe", "$dlatch", "$ff"))
{
output = cell->getPort("\\Q");
if (cell->type == "$adff")
inputs.push_back(cell->getParam("\\ARST_VALUE"));
inputs.push_back(cell->getPort("\\D"));
}
if (cell->type.in("$mux", "$pmux"))
{
output = cell->getPort("\\Y");
inputs.push_back(cell->getPort("\\A"));
SigSpec B = cell->getPort("\\B");
for (int i = 0; i < GetSize(B); i += GetSize(output))
inputs.push_back(B.extract(i, GetSize(output)));
}
if (!output.empty())
{
output = sigmap(output);
auto &srcs = sig_sources_db[output];
for (auto src : inputs) {
while (!src.empty() && src[GetSize(src)-1] == State::S0)
src.remove(GetSize(src)-1);
srcs.insert(sigmap(src));
}
}
}
}
void query_worker(const SigSpec &sig, bool &retval, bool &cache, int indent)
{
if (verbose)
log("%*s %s\n", indent, "", log_signal(sig));
log_assert(retval);
if (recursion_guard.count(sig)) {
if (verbose)
log("%*s - recursion\n", indent, "");
cache = false;
return;
}
auto it = sig_onehot_cache.find(sig);
if (it != sig_onehot_cache.end()) {
if (verbose)
log("%*s - cached (%s)\n", indent, "", it->second ? "true" : "false");
if (!it->second)
retval = false;
return;
}
bool found_init_ones = false;
for (auto bit : sig) {
if (init_ones.count(bit)) {
if (found_init_ones) {
if (verbose)
log("%*s - non-onehot init value\n", indent, "");
retval = false;
break;
}
found_init_ones = true;
}
}
if (retval)
{
if (sig.is_fully_const())
{
bool found_ones = false;
for (auto bit : sig) {
if (bit == State::S1) {
if (found_ones) {
if (verbose)
log("%*s - non-onehot constant\n", indent, "");
retval = false;
break;
}
found_ones = true;
}
}
}
else
{
auto srcs = sig_sources_db.find(sig);
if (srcs == sig_sources_db.end()) {
if (verbose)
log("%*s - no sources for non-const signal\n", indent, "");
retval = false;
} else {
for (auto &src : srcs->second) {
bool child_cache = true;
recursion_guard.insert(sig);
query_worker(src, retval, child_cache, indent+4);
recursion_guard.erase(sig);
if (!child_cache)
cache = false;
if (!retval)
break;
}
}
}
}
// it is always safe to cache a negative result
if (cache || !retval)
sig_onehot_cache[sig] = retval;
}
bool query(const SigSpec &sig)
{
bool retval = true;
bool cache = true;
if (verbose)
log("** ONEHOT QUERY START (%s)\n", log_signal(sig));
query_worker(sig, retval, cache, 3);
if (verbose)
log("** ONEHOT QUERY RESULT = %s\n", retval ? "true" : "false");
// it is always safe to cache the root result of a query
if (!cache)
sig_onehot_cache[sig] = retval;
return retval;
}
};
struct Pmux2ShiftxPass : public Pass {
Pmux2ShiftxPass() : Pass("pmux2shiftx", "transform $pmux cells to $shiftx cells") { }
void help() YS_OVERRIDE
@ -33,6 +199,9 @@ struct Pmux2ShiftxPass : public Pass {
log("\n");
log("This pass transforms $pmux cells to $shiftx cells.\n");
log("\n");
log(" -v, -vv\n");
log(" verbose output\n");
log("\n");
log(" -min_density <percentage>\n");
log(" specifies the minimum density for the shifter\n");
log(" default: 50\n");
@ -41,9 +210,9 @@ struct Pmux2ShiftxPass : public Pass {
log(" specified the minimum number of choices for a control signal\n");
log(" default: 3\n");
log("\n");
log(" -allow_onehot\n");
log(" by default, pmuxes with one-hot encoded control signals are not\n");
log(" converted. this option disables that check.\n");
log(" -onehot ignore|pmux|shiftx\n");
log(" select strategy for one-hot encoded control signals\n");
log(" default: pmux\n");
log("\n");
}
void execute(std::vector<std::string> args, RTLIL::Design *design) YS_OVERRIDE
@ -51,6 +220,9 @@ struct Pmux2ShiftxPass : public Pass {
int min_density = 50;
int min_choices = 3;
bool allow_onehot = false;
bool optimize_onehot = true;
bool verbose = false;
bool verbose_onehot = false;
log_header(design, "Executing PMUX2SHIFTX pass.\n");
@ -64,8 +236,31 @@ struct Pmux2ShiftxPass : public Pass {
min_choices = atoi(args[++argidx].c_str());
continue;
}
if (args[argidx] == "-allow_onehot") {
if (args[argidx] == "-onehot" && argidx+1 < args.size() && args[argidx+1] == "ignore") {
argidx++;
allow_onehot = false;
optimize_onehot = false;
continue;
}
if (args[argidx] == "-onehot" && argidx+1 < args.size() && args[argidx+1] == "pmux") {
argidx++;
allow_onehot = false;
optimize_onehot = true;
continue;
}
if (args[argidx] == "-onehot" && argidx+1 < args.size() && args[argidx+1] == "shiftx") {
argidx++;
allow_onehot = true;
optimize_onehot = false;
continue;
}
if (args[argidx] == "-v") {
verbose = true;
continue;
}
if (args[argidx] == "-vv") {
verbose = true;
verbose_onehot = true;
continue;
}
break;
@ -75,10 +270,15 @@ struct Pmux2ShiftxPass : public Pass {
for (auto module : design->selected_modules())
{
SigMap sigmap(module);
OnehotDatabase onehot_db(module, sigmap);
onehot_db.verbose = verbose_onehot;
if (optimize_onehot)
onehot_db.initialize();
dict<SigBit, pair<SigSpec, Const>> eqdb;
for (auto cell : module->selected_cells())
for (auto cell : module->cells())
{
if (cell->type == "$eq")
{
@ -181,6 +381,12 @@ struct Pmux2ShiftxPass : public Pass {
bool printed_pmux_header = false;
if (verbose) {
printed_pmux_header = true;
log("Inspecting $pmux cell %s/%s.\n", log_id(module), log_id(cell));
log(" data width: %d (next power-of-2 = %d, log2 = %d)\n", width, extwidth, width_bits);
}
SigSpec updated_S = cell->getPort("\\S");
SigSpec updated_B = cell->getPort("\\B");
@ -196,7 +402,7 @@ struct Pmux2ShiftxPass : public Pass {
}
// find the relevant choices
bool is_onehot = true;
bool is_onehot = GetSize(sig) > 2;
dict<Const, int> choices;
for (int i : seldb.at(sig)) {
Const val = eqdb.at(S[i]).second;
@ -211,14 +417,17 @@ struct Pmux2ShiftxPass : public Pass {
// TBD: also find choices that are using signals that are subsets of the bits in "sig"
if (is_onehot && !allow_onehot) {
seldb.erase(sig);
continue;
}
if (!verbose)
{
if (is_onehot && !allow_onehot && !optimize_onehot) {
seldb.erase(sig);
continue;
}
if (GetSize(choices) < min_choices) {
seldb.erase(sig);
continue;
if (GetSize(choices) < min_choices) {
seldb.erase(sig);
continue;
}
}
if (!printed_pmux_header) {
@ -229,6 +438,65 @@ struct Pmux2ShiftxPass : public Pass {
log(" checking ctrl signal %s\n", log_signal(sig));
auto print_choices = [&]() {
log(" table of choices:\n");
for (auto &it : choices)
log(" %3d: %s: %s\n", it.second, log_signal(it.first),
log_signal(B.extract(it.second*width, width)));
};
if (verbose)
{
if (is_onehot && !allow_onehot && !optimize_onehot) {
print_choices();
log(" ignoring one-hot encoding.\n");
seldb.erase(sig);
continue;
}
if (GetSize(choices) < min_choices) {
print_choices();
log(" insufficient choices.\n");
seldb.erase(sig);
continue;
}
}
if (is_onehot && optimize_onehot)
{
print_choices();
if (!onehot_db.query(sig))
{
log(" failed to detect onehot driver. do not optimize.\n");
}
else
{
log(" optimizing one-hot encoding.\n");
for (auto &it : choices)
{
const Const &val = it.first;
int index = -1;
for (int i = 0; i < GetSize(val); i++)
if (val[i] == State::S1) {
log_assert(index < 0);
index = i;
}
if (index < 0) {
log(" %3d: zero encoding.\n", it.second);
continue;
}
SigBit new_ctrl = sig[index];
log(" %3d: new crtl signal is %s.\n", it.second, log_signal(new_ctrl));
updated_S[it.second] = new_ctrl;
}
}
seldb.erase(sig);
continue;
}
// find the best permutation
vector<int> perm_new_from_old(GetSize(sig));
Const perm_xormask(State::S0, GetSize(sig));
@ -434,4 +702,127 @@ struct Pmux2ShiftxPass : public Pass {
}
} Pmux2ShiftxPass;
struct OnehotPass : public Pass {
OnehotPass() : Pass("onehot", "optimize $eq cells for onehot signals") { }
void help() YS_OVERRIDE
{
// |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|
log("\n");
log(" onehot [options] [selection]\n");
log("\n");
log("This pass optimizes $eq cells that compare one-hot signals against constants\n");
log("\n");
log(" -v, -vv\n");
log(" verbose output\n");
log("\n");
}
void execute(std::vector<std::string> args, RTLIL::Design *design) YS_OVERRIDE
{
bool verbose = false;
bool verbose_onehot = false;
log_header(design, "Executing ONEHOT pass.\n");
size_t argidx;
for (argidx = 1; argidx < args.size(); argidx++) {
if (args[argidx] == "-v") {
verbose = true;
continue;
}
if (args[argidx] == "-vv") {
verbose = true;
verbose_onehot = true;
continue;
}
break;
}
extra_args(args, argidx, design);
for (auto module : design->selected_modules())
{
SigMap sigmap(module);
OnehotDatabase onehot_db(module, sigmap);
onehot_db.verbose = verbose_onehot;
onehot_db.initialize();
for (auto cell : module->selected_cells())
{
if (cell->type != "$eq")
continue;
SigSpec A = sigmap(cell->getPort("\\A"));
SigSpec B = sigmap(cell->getPort("\\B"));
int a_width = cell->getParam("\\A_WIDTH").as_int();
int b_width = cell->getParam("\\B_WIDTH").as_int();
if (a_width < b_width) {
bool a_signed = cell->getParam("\\A_SIGNED").as_int();
A.extend_u0(b_width, a_signed);
}
if (b_width < a_width) {
bool b_signed = cell->getParam("\\B_SIGNED").as_int();
B.extend_u0(a_width, b_signed);
}
if (A.is_fully_const())
std::swap(A, B);
if (!B.is_fully_const())
continue;
if (verbose)
log("Checking $eq(%s, %s) cell %s/%s.\n", log_signal(A), log_signal(B), log_id(module), log_id(cell));
if (!onehot_db.query(A)) {
if (verbose)
log(" onehot driver test on %s failed.\n", log_signal(A));
continue;
}
int index = -1;
bool not_onehot = false;
for (int i = 0; i < GetSize(B); i++) {
if (B[i] != State::S1)
continue;
if (index >= 0)
not_onehot = true;
index = i;
}
if (index < 0) {
if (verbose)
log(" not optimizing the zero pattern.\n");
continue;
}
SigSpec Y = cell->getPort("\\Y");
if (not_onehot)
{
if (verbose)
log(" replacing with constant 0 driver.\n");
else
log("Replacing one-hot $eq(%s, %s) cell %s/%s with constant 0 driver.\n", log_signal(A), log_signal(B), log_id(module), log_id(cell));
module->connect(Y, SigSpec(1, GetSize(Y)));
}
else
{
SigSpec sig = A[index];
if (verbose)
log(" replacing with signal %s.\n", log_signal(sig));
else
log("Replacing one-hot $eq(%s, %s) cell %s/%s with signal %s.\n",log_signal(A), log_signal(B), log_id(module), log_id(cell), log_signal(sig));
sig.extend_u0(GetSize(Y));
module->connect(Y, sig);
}
module->remove(cell);
}
}
}
} OnehotPass;
PRIVATE_NAMESPACE_END