functional backend: add different types of input/output/state variables

This commit is contained in:
Emily Schmidt 2024-08-06 09:56:28 +01:00
parent 79a1b691ea
commit 50047d25b3
5 changed files with 237 additions and 120 deletions

View File

@ -150,8 +150,8 @@ template<class NodePrinter> struct CxxPrintVisitor : public Functional::Abstract
void arithmetic_shift_right(Node, Node a, Node b) override { print("{}.arithmetic_shift_right({})", a, b); }
void mux(Node, Node a, Node b, Node s) override { print("{2}.any() ? {1} : {0}", a, b, s); }
void constant(Node, RTLIL::Const const & value) override { print("{}", cxx_const(value)); }
void input(Node, IdString name) override { print("input.{}", input_struct[name]); }
void state(Node, IdString name) override { print("current_state.{}", state_struct[name]); }
void input(Node, IdString name, IdString type) override { log_assert(type == ID($input)); print("input.{}", input_struct[name]); }
void state(Node, IdString name, IdString type) override { log_assert(type == ID($state)); print("current_state.{}", state_struct[name]); }
void memory_read(Node, Node mem, Node addr) override { print("{}.read({})", mem, addr); }
void memory_write(Node, Node mem, Node addr, Node data) override { print("{}.write({}, {})", mem, addr, data); }
};
@ -175,12 +175,12 @@ struct CxxModule {
output_struct("Outputs"),
state_struct("State")
{
for (auto [name, sort] : ir.inputs())
input_struct.insert(name, sort);
for (auto [name, sort] : ir.outputs())
output_struct.insert(name, sort);
for (auto [name, sort] : ir.state())
state_struct.insert(name, sort);
for (auto input : ir.inputs())
input_struct.insert(input->name, input->sort);
for (auto output : ir.outputs())
output_struct.insert(output->name, output->sort);
for (auto state : ir.states())
state_struct.insert(state->name, state->sort);
module_name = CxxScope<int>().unique_name(module->name);
}
void write_header(CxxWriter &f) {
@ -197,19 +197,19 @@ struct CxxModule {
}
void write_initial_def(CxxWriter &f) {
f.print("void {0}::initialize({0}::State &state)\n{{\n", module_name);
for (auto [name, sort] : ir.state()) {
if (sort.is_signal())
f.print("\tstate.{} = {};\n", state_struct[name], cxx_const(ir.get_initial_state_signal(name)));
else if (sort.is_memory()) {
for (auto state : ir.states()) {
if (state->sort.is_signal())
f.print("\tstate.{} = {};\n", state_struct[state->name], cxx_const(state->initial_value_signal()));
else if (state->sort.is_memory()) {
f.print("\t{{\n");
f.print("\t\tstd::array<Signal<{}>, {}> mem;\n", sort.data_width(), 1<<sort.addr_width());
const auto &contents = ir.get_initial_state_memory(name);
f.print("\t\tstd::array<Signal<{}>, {}> mem;\n", state->sort.data_width(), 1<<state->sort.addr_width());
const auto &contents = state->initial_value_memory();
f.print("\t\tmem.fill({});\n", cxx_const(contents.default_value()));
for(auto range : contents)
for(auto addr = range.base(); addr < range.limit(); addr++)
if(!equal_def(range[addr], contents.default_value()))
f.print("\t\tmem[{}] = {};\n", addr, cxx_const(range[addr]));
f.print("\t\tstate.{} = mem;\n", state_struct[name]);
f.print("\t\tstate.{} = mem;\n", state_struct[state->name]);
f.print("\t}}\n");
}
}
@ -229,10 +229,10 @@ struct CxxModule {
node.visit(printVisitor);
f.print(";\n");
}
for (auto [name, sort] : ir.state())
f.print("\tnext_state.{} = {};\n", state_struct[name], node_name(ir.get_state_next_node(name)));
for (auto [name, sort] : ir.outputs())
f.print("\toutput.{} = {};\n", output_struct[name], node_name(ir.get_output_node(name)));
for (auto state : ir.states())
f.print("\tnext_state.{} = {};\n", state_struct[state->name], node_name(state->next_value()));
for (auto output : ir.outputs())
f.print("\toutput.{} = {};\n", output_struct[output->name], node_name(output->value()));
f.print("}}\n\n");
}
};

View File

@ -178,8 +178,8 @@ struct SmtPrintVisitor : public Functional::AbstractVisitor<SExpr> {
SExpr memory_read(Node, Node mem, Node addr) override { return list("select", n(mem), n(addr)); }
SExpr memory_write(Node, Node mem, Node addr, Node data) override { return list("store", n(mem), n(addr), n(data)); }
SExpr input(Node, IdString name) override { return input_struct.access("inputs", name); }
SExpr state(Node, IdString name) override { return state_struct.access("state", name); }
SExpr input(Node, IdString name, IdString type) override { log_assert(type == ID($input)); return input_struct.access("inputs", name); }
SExpr state(Node, IdString name, IdString type) override { log_assert(type == ID($state)); return state_struct.access("state", name); }
};
struct SmtModule {
@ -200,12 +200,12 @@ struct SmtModule {
, state_struct(scope.unique_name(module->name.str() + "_State"), scope)
{
scope.reserve(name + "-initial");
for (const auto &input : ir.inputs())
input_struct.insert(input.first, input.second);
for (const auto &output : ir.outputs())
output_struct.insert(output.first, output.second);
for (const auto &state : ir.state())
state_struct.insert(state.first, state.second);
for (auto input : ir.inputs())
input_struct.insert(input->name, input->sort);
for (auto output : ir.outputs())
output_struct.insert(output->name, output->sort);
for (auto state : ir.states())
state_struct.insert(state->name, state->sort);
}
void write_eval(SExprWriter &w)
@ -232,8 +232,8 @@ struct SmtModule {
w.comment(SmtSort(n.sort()).to_sexpr().to_string(), true);
}
w.open(list("pair"));
output_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.get_output_node(name)); });
state_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.get_state_next_node(name)); });
output_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.output(name).value()); });
state_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.state(name).next_value()); });
w.pop();
}
@ -241,14 +241,14 @@ struct SmtModule {
{
std::string initial = name + "-initial";
w << list("declare-const", initial, state_struct.name);
for (const auto &[name, sort] : ir.state()) {
if(sort.is_signal())
w << list("assert", list("=", state_struct.access(initial, name), smt_const(ir.get_initial_state_signal(name))));
else if(sort.is_memory()) {
auto contents = ir.get_initial_state_memory(name);
for(int i = 0; i < 1<<sort.addr_width(); i++) {
auto addr = smt_const(RTLIL::Const(i, sort.addr_width()));
w << list("assert", list("=", list("select", state_struct.access(initial, name), addr), smt_const(contents[i])));
for (auto state : ir.states()) {
if(state->sort.is_signal())
w << list("assert", list("=", state_struct.access(initial, state->name), smt_const(state->initial_value_signal())));
else if(state->sort.is_memory()) {
const auto &contents = state->initial_value_memory();
for(int i = 0; i < 1<<state->sort.addr_width(); i++) {
auto addr = smt_const(RTLIL::Const(i, state->sort.addr_width()));
w << list("assert", list("=", list("select", state_struct.access(initial, state->name), addr), smt_const(contents[i])));
}
}
}

View File

@ -142,10 +142,10 @@ struct FunctionalTestGeneric : public Pass
auto fir = Functional::IR::from_module(module);
for(auto node : fir)
std::cout << RTLIL::unescape_id(node.name()) << " = " << node.to_string([](auto n) { return RTLIL::unescape_id(n.name()); }) << "\n";
for(auto [name, sort] : fir.outputs())
std::cout << "output " << RTLIL::unescape_id(name) << " = " << RTLIL::unescape_id(fir.get_output_node(name).name()) << "\n";
for(auto [name, sort] : fir.state())
std::cout << "state " << RTLIL::unescape_id(name) << " = " << RTLIL::unescape_id(fir.get_state_next_node(name).name()) << "\n";
for(auto output : fir.all_outputs())
std::cout << RTLIL::unescape_id(output->type) << " " << RTLIL::unescape_id(output->name) << " = " << RTLIL::unescape_id(output->value().name()) << "\n";
for(auto state : fir.all_states())
std::cout << RTLIL::unescape_id(state->type) << " " << RTLIL::unescape_id(state->name) << " = " << RTLIL::unescape_id(state->next_value().name()) << "\n";
}
}
} FunctionalCxxBackend;

View File

@ -65,6 +65,51 @@ const char *fn_to_string(Fn fn) {
log_error("fn_to_string: unknown Functional::Fn value %d", (int)fn);
}
vector<IRInput const*> IR::inputs(IdString type) const {
vector<IRInput const*> ret;
for (const auto &[name, input] : _inputs)
if(input.type == type)
ret.push_back(&input);
return ret;
}
vector<IROutput const*> IR::outputs(IdString type) const {
vector<IROutput const*> ret;
for (const auto &[name, output] : _outputs)
if(output.type == type)
ret.push_back(&output);
return ret;
}
vector<IRState const*> IR::states(IdString type) const {
vector<IRState const*> ret;
for (const auto &[name, state] : _states)
if(state.type == type)
ret.push_back(&state);
return ret;
}
vector<IRInput const*> IR::all_inputs() const {
vector<IRInput const*> ret;
for (const auto &[name, input] : _inputs)
ret.push_back(&input);
return ret;
}
vector<IROutput const*> IR::all_outputs() const {
vector<IROutput const*> ret;
for (const auto &[name, output] : _outputs)
ret.push_back(&output);
return ret;
}
vector<IRState const*> IR::all_states() const {
vector<IRState const*> ret;
for (const auto &[name, state] : _states)
ret.push_back(&state);
return ret;
}
struct PrintVisitor : DefaultVisitor<std::string> {
std::function<std::string(Node)> np;
PrintVisitor(std::function<std::string(Node)> np) : np(np) { }
@ -73,8 +118,8 @@ struct PrintVisitor : DefaultVisitor<std::string> {
std::string zero_extend(Node, Node a, int out_width) override { return "zero_extend(" + np(a) + ", " + std::to_string(out_width) + ")"; }
std::string sign_extend(Node, Node a, int out_width) override { return "sign_extend(" + np(a) + ", " + std::to_string(out_width) + ")"; }
std::string constant(Node, RTLIL::Const const& value) override { return "constant(" + value.as_string() + ")"; }
std::string input(Node, IdString name) override { return "input(" + name.str() + ")"; }
std::string state(Node, IdString name) override { return "state(" + name.str() + ")"; }
std::string input(Node, IdString name, IdString type) override { return "input(" + name.str() + ", " + type.str() + ")"; }
std::string state(Node, IdString name, IdString type) override { return "state(" + name.str() + ", " + type.str() + ")"; }
std::string default_handler(Node self) override {
std::string ret = fn_to_string(self.fn());
ret += "(";
@ -199,7 +244,7 @@ private:
return handle_alu(g, factory.bitwise_or(p, g), g.width(), false, ci, factory.constant(Const(State::S0, 1))).at(ID(CO));
}
public:
std::variant<dict<IdString, Node>, Node> handle(IdString cellType, dict<IdString, Const> parameters, dict<IdString, Node> inputs)
std::variant<dict<IdString, Node>, Node> handle(IdString cellName, IdString cellType, dict<IdString, Const> parameters, dict<IdString, Node> inputs)
{
int a_width = parameters.at(ID(A_WIDTH), Const(-1)).as_int();
int b_width = parameters.at(ID(B_WIDTH), Const(-1)).as_int();
@ -392,6 +437,26 @@ public:
return handle_lcu(inputs.at(ID(P)), inputs.at(ID(G)), inputs.at(ID(CI)));
} else if(cellType == ID($alu)) {
return handle_alu(inputs.at(ID(A)), inputs.at(ID(B)), y_width, a_signed && b_signed, inputs.at(ID(CI)), inputs.at(ID(BI)));
} else if(cellType.in({ID($assert), ID($assume), ID($live), ID($fair), ID($cover)})) {
Node a = factory.mux(factory.constant(Const(State::S1, 1)), inputs.at(ID(A)), inputs.at(ID(EN)));
auto &output = factory.add_output(cellName, cellType, Sort(1));
output.set_value(a);
return {};
} else if(cellType.in({ID($anyconst), ID($allconst), ID($anyseq), ID($allseq)})) {
int width = parameters.at(ID(WIDTH)).as_int();
auto &input = factory.add_input(cellName, cellType, Sort(width));
return factory.value(input);
} else if(cellType == ID($initstate)) {
if(factory.ir().has_state(ID($initstate), ID($state)))
return factory.value(factory.ir().state(ID($initstate)));
else {
auto &state = factory.add_state(ID($initstate), ID($state), Sort(1));
state.set_initial_value(RTLIL::Const(State::S1, 1));
state.set_next_value(factory.constant(RTLIL::Const(State::S0, 1)));
return factory.value(state);
}
} else if(cellType == ID($check)) {
log_error("The design contains a $check cell `%s'. This is not supported by the functional backend. Call `chformal -lower' to avoid this error.\n", cellName.c_str());
} else {
log_error("`%s' cells are not supported by the functional backend\n", cellType.c_str());
}
@ -448,16 +513,15 @@ public:
{
driver_map.add(module);
for (auto cell : module->cells()) {
if (cell->type.in(ID($assert), ID($assume), ID($cover), ID($check)))
if (cell->type.in(ID($assert), ID($assume), ID($live), ID($fair), ID($cover), ID($check)))
queue.emplace_back(cell);
}
for (auto wire : module->wires()) {
if (wire->port_input)
factory.add_input(wire->name, wire->width);
factory.add_input(wire->name, ID($input), Sort(wire->width));
if (wire->port_output) {
factory.add_output(wire->name, wire->width);
Node value = enqueue(DriveChunk(DriveChunkWire(wire, 0, wire->width)));
factory.set_output(wire->name, value);
auto &output = factory.add_output(wire->name, ID($output), Sort(wire->width));
output.set_value(enqueue(DriveChunk(DriveChunkWire(wire, 0, wire->width))));
}
}
memories_vector = Mem::get_all_memories(module);
@ -495,9 +559,9 @@ public:
// - Since wr port j can only have priority over wr port i if j > i, if we do writes in
// ascending index order the result will obey the priorty relation.
vector<Node> read_results;
factory.add_state(mem->cell->name, Sort(ceil_log2(mem->size), mem->width));
factory.set_initial_state(mem->cell->name, MemContents(mem));
Node node = factory.current_state(mem->cell->name);
auto &state = factory.add_state(mem->cell->name, ID($state), Sort(ceil_log2(mem->size), mem->width));
state.set_initial_value(MemContents(mem));
Node node = factory.value(state);
for (size_t i = 0; i < mem->wr_ports.size(); i++) {
const auto &wr = mem->wr_ports[i];
if (wr.clk_enable)
@ -521,7 +585,7 @@ public:
Node addr = enqueue(driver_map(DriveSpec(rd.addr)));
read_results.push_back(factory.memory_read(node, addr));
}
factory.set_next_state(mem->cell->name, node);
state.set_next_value(node);
return concatenate_read_results(mem, read_results);
}
void process_cell(Cell *cell)
@ -540,25 +604,25 @@ public:
if (!ff.has_gclk)
log_error("The design contains a %s flip-flop at %s. This is not supported by the functional backend. "
"Call async2sync or clk2fflogic to avoid this error.\n", log_id(cell->type), log_id(cell));
factory.add_state(ff.name, Sort(ff.width));
Node q_value = factory.current_state(ff.name);
auto &state = factory.add_state(ff.name, ID($state), Sort(ff.width));
Node q_value = factory.value(state);
factory.suggest_name(q_value, ff.name);
factory.update_pending(cell_outputs.at({cell, ID(Q)}), q_value);
factory.set_next_state(ff.name, enqueue(ff.sig_d));
factory.set_initial_state(ff.name, ff.val_init);
state.set_next_value(enqueue(ff.sig_d));
state.set_initial_value(ff.val_init);
} else {
dict<IdString, Node> connections;
IdString output_name; // for the single output case
int n_outputs = 0;
for(auto const &[name, sigspec] : cell->connections()) {
if(driver_map.celltypes.cell_input(cell->type, name))
if(driver_map.celltypes.cell_input(cell->type, name) && sigspec.size() > 0)
connections.insert({ name, enqueue(DriveChunkPort(cell, {name, sigspec})) });
if(driver_map.celltypes.cell_output(cell->type, name)) {
output_name = name;
n_outputs++;
}
}
std::variant<dict<IdString, Node>, Node> outputs = simplifier.handle(cell->type, cell->parameters, connections);
std::variant<dict<IdString, Node>, Node> outputs = simplifier.handle(cell->name, cell->type, cell->parameters, connections);
if(auto *nodep = std::get_if<Node>(&outputs); nodep != nullptr) {
log_assert(n_outputs == 1);
factory.update_pending(cell_outputs.at({cell, output_name}), *nodep);
@ -591,7 +655,7 @@ public:
DriveChunkWire wire_chunk = chunk.wire();
if (wire_chunk.is_whole()) {
if (wire_chunk.wire->port_input) {
Node node = factory.input(wire_chunk.wire->name);
Node node = factory.value(factory.ir().input(wire_chunk.wire->name));
factory.suggest_name(node, wire_chunk.wire->name);
factory.update_pending(pending, node);
} else {
@ -668,10 +732,12 @@ void IR::topological_sort() {
scc = true;
}
});
for(const auto &[name, sort]: _state_sorts)
toposort.process(get_state_next_node(name).id());
for(const auto &[name, sort]: _output_sorts)
toposort.process(get_output_node(name).id());
for(const auto &[name, state]: _states)
if(state.has_next_value())
toposort.process(state.next_value().id());
for(const auto &[name, output]: _outputs)
if(output.has_value())
toposort.process(output.value().id());
// any nodes untouched by this point are dead code and will be removed by permute
_graph.permute(perm);
if(scc) log_error("The design contains combinational loops. This is not supported by the functional backend. "

View File

@ -152,11 +152,60 @@ namespace Functional {
bool operator==(Sort const& other) const { return _v == other._v; }
unsigned int hash() const { return mkhash(_v); }
};
class IR;
class Factory;
class Node;
class IRInput {
friend class Factory;
public:
IdString name;
IdString type;
Sort sort;
private:
IRInput(IR &, IdString name, IdString type, Sort sort)
: name(name), type(type), sort(std::move(sort)) {}
};
class IROutput {
friend class Factory;
IR &_ir;
public:
IdString name;
IdString type;
Sort sort;
private:
IROutput(IR &ir, IdString name, IdString type, Sort sort)
: _ir(ir), name(name), type(type), sort(std::move(sort)) {}
public:
Node value() const;
bool has_value() const;
void set_value(Node value);
};
class IRState {
friend class Factory;
IR &_ir;
public:
IdString name;
IdString type;
Sort sort;
private:
std::variant<RTLIL::Const, MemContents> _initial;
IRState(IR &ir, IdString name, IdString type, Sort sort)
: _ir(ir), name(name), type(type), sort(std::move(sort)) {}
public:
Node next_value() const;
bool has_next_value() const;
RTLIL::Const const& initial_value_signal() const { return std::get<RTLIL::Const>(_initial); }
MemContents const& initial_value_memory() const { return std::get<MemContents>(_initial); }
void set_next_value(Node value);
void set_initial_value(RTLIL::Const value) { value.extu(sort.width()); _initial = std::move(value); }
void set_initial_value(MemContents value) { log_assert(Sort(value.addr_width(), value.data_width()) == sort); _initial = std::move(value); }
};
class IR {
friend class Factory;
friend class Node;
friend class IRInput;
friend class IROutput;
friend class IRState;
// one NodeData is stored per Node, containing the function and non-node arguments
// note that NodeData is deduplicated by ComputeGraph
class NodeData {
@ -164,7 +213,7 @@ namespace Functional {
std::variant<
std::monostate,
RTLIL::Const,
IdString,
std::pair<IdString, IdString>,
int
> _extra;
public:
@ -173,7 +222,7 @@ namespace Functional {
template<class T> NodeData(Fn fn, T &&extra) : _fn(fn), _extra(std::forward<T>(extra)) {}
Fn fn() const { return _fn; }
const RTLIL::Const &as_const() const { return std::get<RTLIL::Const>(_extra); }
IdString as_idstring() const { return std::get<IdString>(_extra); }
std::pair<IdString, IdString> as_idstring_pair() const { return std::get<std::pair<IdString, IdString>>(_extra); }
int as_int() const { return std::get<int>(_extra); }
int hash() const {
return mkhash((unsigned int) _fn, mkhash(_extra));
@ -190,13 +239,12 @@ namespace Functional {
// the sparse_attr IdString stores a naming suggestion, retrieved with name()
// the key is currently used to identify the nodes that represent output and next state values
// the bool is true for next state values
using Graph = ComputeGraph<NodeData, Attr, IdString, std::pair<IdString, bool>>;
using Graph = ComputeGraph<NodeData, Attr, IdString, std::tuple<IdString, IdString, bool>>;
Graph _graph;
dict<IdString, Sort> _input_sorts;
dict<IdString, Sort> _output_sorts;
dict<IdString, Sort> _state_sorts;
dict<IdString, RTLIL::Const> _initial_state_signal;
dict<IdString, MemContents> _initial_state_memory;
dict<std::pair<IdString, IdString>, IRInput> _inputs;
dict<std::pair<IdString, IdString>, IROutput> _outputs;
dict<std::pair<IdString, IdString>, IRState> _states;
IR::Graph::Ref mutate(Node n);
public:
static IR from_module(Module *module);
Factory factory();
@ -204,13 +252,24 @@ namespace Functional {
Node operator[](int i);
void topological_sort();
void forward_buf();
dict<IdString, Sort> inputs() const { return _input_sorts; }
dict<IdString, Sort> outputs() const { return _output_sorts; }
dict<IdString, Sort> state() const { return _state_sorts; }
RTLIL::Const const &get_initial_state_signal(IdString name) { return _initial_state_signal.at(name); }
MemContents const &get_initial_state_memory(IdString name) { return _initial_state_memory.at(name); }
Node get_output_node(IdString name);
Node get_state_next_node(IdString name);
IRInput const& input(IdString name, IdString type) const { return _inputs.at({name, type}); }
IRInput const& input(IdString name) const { return input(name, ID($input)); }
IROutput const& output(IdString name, IdString type) const { return _outputs.at({name, type}); }
IROutput const& output(IdString name) const { return output(name, ID($output)); }
IRState const& state(IdString name, IdString type) const { return _states.at({name, type}); }
IRState const& state(IdString name) const { return state(name, ID($state)); }
bool has_input(IdString name, IdString type) const { return _inputs.count({name, type}); }
bool has_output(IdString name, IdString type) const { return _outputs.count({name, type}); }
bool has_state(IdString name, IdString type) const { return _states.count({name, type}); }
vector<IRInput const*> inputs(IdString type) const;
vector<IRInput const*> inputs() const { return inputs(ID($input)); }
vector<IROutput const*> outputs(IdString type) const;
vector<IROutput const*> outputs() const { return outputs(ID($output)); }
vector<IRState const*> states(IdString type) const;
vector<IRState const*> states() const { return states(ID($state)); }
vector<IRInput const*> all_inputs() const;
vector<IROutput const*> all_outputs() const;
vector<IRState const*> all_states() const;
class iterator {
friend class IR;
IR *_ir;
@ -236,6 +295,9 @@ namespace Functional {
class Node {
friend class Factory;
friend class IR;
friend class IRInput;
friend class IROutput;
friend class IRState;
IR::Graph::ConstRef _ref;
explicit Node(IR::Graph::ConstRef ref) : _ref(ref) { }
explicit operator IR::Graph::ConstRef() { return _ref; }
@ -290,8 +352,8 @@ namespace Functional {
case Fn::arithmetic_shift_right: return v.arithmetic_shift_right(*this, arg(0), arg(1)); break;
case Fn::mux: return v.mux(*this, arg(0), arg(1), arg(2)); break;
case Fn::constant: return v.constant(*this, _ref.function().as_const()); break;
case Fn::input: return v.input(*this, _ref.function().as_idstring()); break;
case Fn::state: return v.state(*this, _ref.function().as_idstring()); break;
case Fn::input: return v.input(*this, _ref.function().as_idstring_pair().first, _ref.function().as_idstring_pair().second); break;
case Fn::state: return v.state(*this, _ref.function().as_idstring_pair().first, _ref.function().as_idstring_pair().second); break;
case Fn::memory_read: return v.memory_read(*this, arg(0), arg(1)); break;
case Fn::memory_write: return v.memory_write(*this, arg(0), arg(1), arg(2)); break;
}
@ -300,9 +362,14 @@ namespace Functional {
std::string to_string();
std::string to_string(std::function<std::string(Node)>);
};
inline IR::Graph::Ref IR::mutate(Node n) { return _graph[n._ref.index()]; }
inline Node IR::operator[](int i) { return Node(_graph[i]); }
inline Node IR::get_output_node(IdString name) { return Node(_graph({name, false})); }
inline Node IR::get_state_next_node(IdString name) { return Node(_graph({name, true})); }
inline Node IROutput::value() const { return Node(_ir._graph({name, type, false})); }
inline bool IROutput::has_value() const { return _ir._graph.has_key({name, type, false}); }
inline void IROutput::set_value(Node value) { log_assert(sort == value.sort()); _ir.mutate(value).assign_key({name, type, false}); }
inline Node IRState::next_value() const { return Node(_ir._graph({name, type, true})); }
inline bool IRState::has_next_value() const { return _ir._graph.has_key({name, type, true}); }
inline void IRState::set_next_value(Node value) { log_assert(sort == value.sort()); _ir.mutate(value).assign_key({name, type, true}); }
inline Node IR::iterator::operator*() { return Node(_ir->_graph[_index]); }
inline arrow_proxy<Node> IR::iterator::operator->() { return arrow_proxy<Node>(**this); }
// AbstractVisitor provides an abstract base class for visitors
@ -336,8 +403,8 @@ namespace Functional {
virtual T arithmetic_shift_right(Node self, Node a, Node b) = 0;
virtual T mux(Node self, Node a, Node b, Node s) = 0;
virtual T constant(Node self, RTLIL::Const const & value) = 0;
virtual T input(Node self, IdString name) = 0;
virtual T state(Node self, IdString name) = 0;
virtual T input(Node self, IdString name, IdString type) = 0;
virtual T state(Node self, IdString name, IdString type) = 0;
virtual T memory_read(Node self, Node mem, Node addr) = 0;
virtual T memory_write(Node self, Node mem, Node addr, Node data) = 0;
};
@ -373,8 +440,8 @@ namespace Functional {
T arithmetic_shift_right(Node self, Node, Node) override { return default_handler(self); }
T mux(Node self, Node, Node, Node) override { return default_handler(self); }
T constant(Node self, RTLIL::Const const &) override { return default_handler(self); }
T input(Node self, IdString) override { return default_handler(self); }
T state(Node self, IdString) override { return default_handler(self); }
T input(Node self, IdString, IdString) override { return default_handler(self); }
T state(Node self, IdString, IdString) override { return default_handler(self); }
T memory_read(Node self, Node, Node) override { return default_handler(self); }
T memory_write(Node self, Node, Node, Node) override { return default_handler(self); }
};
@ -383,7 +450,7 @@ namespace Functional {
friend class IR;
IR &_ir;
explicit Factory(IR &ir) : _ir(ir) {}
Node add(IR::NodeData &&fn, Sort &&sort, std::initializer_list<Node> args) {
Node add(IR::NodeData &&fn, Sort const &sort, std::initializer_list<Node> args) {
log_assert(!sort.is_signal() || sort.width() > 0);
log_assert(!sort.is_memory() || (sort.addr_width() > 0 && sort.data_width() > 0));
IR::Graph::Ref ref = _ir._graph.add(std::move(fn), {std::move(sort)});
@ -391,13 +458,11 @@ namespace Functional {
ref.append_arg(IR::Graph::ConstRef(arg));
return Node(ref);
}
IR::Graph::Ref mutate(Node n) {
return _ir._graph[n._ref.index()];
}
void check_basic_binary(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && a.sort() == b.sort()); }
void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal() && b.width() == ceil_log2(a.width())); }
void check_unary(Node const &a) { log_assert(a.sort().is_signal()); }
public:
IR &ir() { return _ir; }
Node slice(Node a, int offset, int out_width) {
log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width());
if(offset == 0 && out_width == a.width())
@ -480,45 +545,31 @@ namespace Functional {
void update_pending(Node node, Node value) {
log_assert(node._ref.function() == Fn::buf && node._ref.size() == 0);
log_assert(node.sort() == value.sort());
mutate(node).append_arg(value._ref);
_ir.mutate(node).append_arg(value._ref);
}
void add_input(IdString name, int width) {
auto [it, inserted] = _ir._input_sorts.emplace(name, Sort(width));
IRInput &add_input(IdString name, IdString type, Sort sort) {
auto [it, inserted] = _ir._inputs.emplace({name, type}, IRInput(_ir, name, type, std::move(sort)));
if (!inserted) log_error("input `%s` was re-defined", name.c_str());
return it->second;
}
void add_output(IdString name, int width) {
auto [it, inserted] = _ir._output_sorts.emplace(name, Sort(width));
IROutput &add_output(IdString name, IdString type, Sort sort) {
auto [it, inserted] = _ir._outputs.emplace({name, type}, IROutput(_ir, name, type, std::move(sort)));
if (!inserted) log_error("output `%s` was re-defined", name.c_str());
return it->second;
}
void add_state(IdString name, Sort sort) {
auto [it, inserted] = _ir._state_sorts.emplace(name, sort);
IRState &add_state(IdString name, IdString type, Sort sort) {
auto [it, inserted] = _ir._states.emplace({name, type}, IRState(_ir, name, type, std::move(sort)));
if (!inserted) log_error("state `%s` was re-defined", name.c_str());
return it->second;
}
Node input(IdString name) {
return add(IR::NodeData(Fn::input, name), Sort(_ir._input_sorts.at(name)), {});
Node value(IRInput const& input) {
return add(IR::NodeData(Fn::input, std::pair(input.name, input.type)), input.sort, {});
}
Node current_state(IdString name) {
return add(IR::NodeData(Fn::state, name), Sort(_ir._state_sorts.at(name)), {});
}
void set_output(IdString output, Node value) {
log_assert(_ir._output_sorts.at(output) == value.sort());
mutate(value).assign_key({output, false});
}
void set_initial_state(IdString state, RTLIL::Const value) {
Sort &sort = _ir._state_sorts.at(state);
value.extu(sort.width());
_ir._initial_state_signal.emplace(state, std::move(value));
}
void set_initial_state(IdString state, MemContents value) {
log_assert(Sort(value.addr_width(), value.data_width()) == _ir._state_sorts.at(state));
_ir._initial_state_memory.emplace(state, std::move(value));
}
void set_next_state(IdString state, Node value) {
log_assert(_ir._state_sorts.at(state) == value.sort());
mutate(value).assign_key({state, true});
Node value(IRState const& state) {
return add(IR::NodeData(Fn::state, std::pair(state.name, state.type)), state.sort, {});
}
void suggest_name(Node node, IdString name) {
mutate(node).sparse_attr() = name;
_ir.mutate(node).sparse_attr() = name;
}
};
inline Factory IR::factory() { return Factory(*this); }