From 50047d25b3806fb614c983a950de85b09db8fcfc Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Tue, 6 Aug 2024 09:56:28 +0100 Subject: [PATCH] functional backend: add different types of input/output/state variables --- backends/functional/cxx.cc | 38 +++---- backends/functional/smtlib.cc | 36 +++--- backends/functional/test_generic.cc | 8 +- kernel/functional.cc | 112 +++++++++++++++---- kernel/functional.h | 163 ++++++++++++++++++---------- 5 files changed, 237 insertions(+), 120 deletions(-) diff --git a/backends/functional/cxx.cc b/backends/functional/cxx.cc index a4755e144..19740777f 100644 --- a/backends/functional/cxx.cc +++ b/backends/functional/cxx.cc @@ -150,8 +150,8 @@ template struct CxxPrintVisitor : public Functional::Abstract void arithmetic_shift_right(Node, Node a, Node b) override { print("{}.arithmetic_shift_right({})", a, b); } void mux(Node, Node a, Node b, Node s) override { print("{2}.any() ? {1} : {0}", a, b, s); } void constant(Node, RTLIL::Const const & value) override { print("{}", cxx_const(value)); } - void input(Node, IdString name) override { print("input.{}", input_struct[name]); } - void state(Node, IdString name) override { print("current_state.{}", state_struct[name]); } + void input(Node, IdString name, IdString type) override { log_assert(type == ID($input)); print("input.{}", input_struct[name]); } + void state(Node, IdString name, IdString type) override { log_assert(type == ID($state)); print("current_state.{}", state_struct[name]); } void memory_read(Node, Node mem, Node addr) override { print("{}.read({})", mem, addr); } void memory_write(Node, Node mem, Node addr, Node data) override { print("{}.write({}, {})", mem, addr, data); } }; @@ -175,12 +175,12 @@ struct CxxModule { output_struct("Outputs"), state_struct("State") { - for (auto [name, sort] : ir.inputs()) - input_struct.insert(name, sort); - for (auto [name, sort] : ir.outputs()) - output_struct.insert(name, sort); - for (auto [name, sort] : ir.state()) - state_struct.insert(name, sort); + for (auto input : ir.inputs()) + input_struct.insert(input->name, input->sort); + for (auto output : ir.outputs()) + output_struct.insert(output->name, output->sort); + for (auto state : ir.states()) + state_struct.insert(state->name, state->sort); module_name = CxxScope().unique_name(module->name); } void write_header(CxxWriter &f) { @@ -197,19 +197,19 @@ struct CxxModule { } void write_initial_def(CxxWriter &f) { f.print("void {0}::initialize({0}::State &state)\n{{\n", module_name); - for (auto [name, sort] : ir.state()) { - if (sort.is_signal()) - f.print("\tstate.{} = {};\n", state_struct[name], cxx_const(ir.get_initial_state_signal(name))); - else if (sort.is_memory()) { + for (auto state : ir.states()) { + if (state->sort.is_signal()) + f.print("\tstate.{} = {};\n", state_struct[state->name], cxx_const(state->initial_value_signal())); + else if (state->sort.is_memory()) { f.print("\t{{\n"); - f.print("\t\tstd::array, {}> mem;\n", sort.data_width(), 1<, {}> mem;\n", state->sort.data_width(), 1<sort.addr_width()); + const auto &contents = state->initial_value_memory(); f.print("\t\tmem.fill({});\n", cxx_const(contents.default_value())); for(auto range : contents) for(auto addr = range.base(); addr < range.limit(); addr++) if(!equal_def(range[addr], contents.default_value())) f.print("\t\tmem[{}] = {};\n", addr, cxx_const(range[addr])); - f.print("\t\tstate.{} = mem;\n", state_struct[name]); + f.print("\t\tstate.{} = mem;\n", state_struct[state->name]); f.print("\t}}\n"); } } @@ -229,10 +229,10 @@ struct CxxModule { node.visit(printVisitor); f.print(";\n"); } - for (auto [name, sort] : ir.state()) - f.print("\tnext_state.{} = {};\n", state_struct[name], node_name(ir.get_state_next_node(name))); - for (auto [name, sort] : ir.outputs()) - f.print("\toutput.{} = {};\n", output_struct[name], node_name(ir.get_output_node(name))); + for (auto state : ir.states()) + f.print("\tnext_state.{} = {};\n", state_struct[state->name], node_name(state->next_value())); + for (auto output : ir.outputs()) + f.print("\toutput.{} = {};\n", output_struct[output->name], node_name(output->value())); f.print("}}\n\n"); } }; diff --git a/backends/functional/smtlib.cc b/backends/functional/smtlib.cc index 7fd6fe564..ec9e6b242 100644 --- a/backends/functional/smtlib.cc +++ b/backends/functional/smtlib.cc @@ -178,8 +178,8 @@ struct SmtPrintVisitor : public Functional::AbstractVisitor { SExpr memory_read(Node, Node mem, Node addr) override { return list("select", n(mem), n(addr)); } SExpr memory_write(Node, Node mem, Node addr, Node data) override { return list("store", n(mem), n(addr), n(data)); } - SExpr input(Node, IdString name) override { return input_struct.access("inputs", name); } - SExpr state(Node, IdString name) override { return state_struct.access("state", name); } + SExpr input(Node, IdString name, IdString type) override { log_assert(type == ID($input)); return input_struct.access("inputs", name); } + SExpr state(Node, IdString name, IdString type) override { log_assert(type == ID($state)); return state_struct.access("state", name); } }; struct SmtModule { @@ -200,12 +200,12 @@ struct SmtModule { , state_struct(scope.unique_name(module->name.str() + "_State"), scope) { scope.reserve(name + "-initial"); - for (const auto &input : ir.inputs()) - input_struct.insert(input.first, input.second); - for (const auto &output : ir.outputs()) - output_struct.insert(output.first, output.second); - for (const auto &state : ir.state()) - state_struct.insert(state.first, state.second); + for (auto input : ir.inputs()) + input_struct.insert(input->name, input->sort); + for (auto output : ir.outputs()) + output_struct.insert(output->name, output->sort); + for (auto state : ir.states()) + state_struct.insert(state->name, state->sort); } void write_eval(SExprWriter &w) @@ -232,8 +232,8 @@ struct SmtModule { w.comment(SmtSort(n.sort()).to_sexpr().to_string(), true); } w.open(list("pair")); - output_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.get_output_node(name)); }); - state_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.get_state_next_node(name)); }); + output_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.output(name).value()); }); + state_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.state(name).next_value()); }); w.pop(); } @@ -241,14 +241,14 @@ struct SmtModule { { std::string initial = name + "-initial"; w << list("declare-const", initial, state_struct.name); - for (const auto &[name, sort] : ir.state()) { - if(sort.is_signal()) - w << list("assert", list("=", state_struct.access(initial, name), smt_const(ir.get_initial_state_signal(name)))); - else if(sort.is_memory()) { - auto contents = ir.get_initial_state_memory(name); - for(int i = 0; i < 1<sort.is_signal()) + w << list("assert", list("=", state_struct.access(initial, state->name), smt_const(state->initial_value_signal()))); + else if(state->sort.is_memory()) { + const auto &contents = state->initial_value_memory(); + for(int i = 0; i < 1<sort.addr_width(); i++) { + auto addr = smt_const(RTLIL::Const(i, state->sort.addr_width())); + w << list("assert", list("=", list("select", state_struct.access(initial, state->name), addr), smt_const(contents[i]))); } } } diff --git a/backends/functional/test_generic.cc b/backends/functional/test_generic.cc index 83ea09d8d..1617d54c0 100644 --- a/backends/functional/test_generic.cc +++ b/backends/functional/test_generic.cc @@ -142,10 +142,10 @@ struct FunctionalTestGeneric : public Pass auto fir = Functional::IR::from_module(module); for(auto node : fir) std::cout << RTLIL::unescape_id(node.name()) << " = " << node.to_string([](auto n) { return RTLIL::unescape_id(n.name()); }) << "\n"; - for(auto [name, sort] : fir.outputs()) - std::cout << "output " << RTLIL::unescape_id(name) << " = " << RTLIL::unescape_id(fir.get_output_node(name).name()) << "\n"; - for(auto [name, sort] : fir.state()) - std::cout << "state " << RTLIL::unescape_id(name) << " = " << RTLIL::unescape_id(fir.get_state_next_node(name).name()) << "\n"; + for(auto output : fir.all_outputs()) + std::cout << RTLIL::unescape_id(output->type) << " " << RTLIL::unescape_id(output->name) << " = " << RTLIL::unescape_id(output->value().name()) << "\n"; + for(auto state : fir.all_states()) + std::cout << RTLIL::unescape_id(state->type) << " " << RTLIL::unescape_id(state->name) << " = " << RTLIL::unescape_id(state->next_value().name()) << "\n"; } } } FunctionalCxxBackend; diff --git a/kernel/functional.cc b/kernel/functional.cc index ad507187d..202eaa502 100644 --- a/kernel/functional.cc +++ b/kernel/functional.cc @@ -65,6 +65,51 @@ const char *fn_to_string(Fn fn) { log_error("fn_to_string: unknown Functional::Fn value %d", (int)fn); } +vector IR::inputs(IdString type) const { + vector ret; + for (const auto &[name, input] : _inputs) + if(input.type == type) + ret.push_back(&input); + return ret; +} + +vector IR::outputs(IdString type) const { + vector ret; + for (const auto &[name, output] : _outputs) + if(output.type == type) + ret.push_back(&output); + return ret; +} + +vector IR::states(IdString type) const { + vector ret; + for (const auto &[name, state] : _states) + if(state.type == type) + ret.push_back(&state); + return ret; +} + +vector IR::all_inputs() const { + vector ret; + for (const auto &[name, input] : _inputs) + ret.push_back(&input); + return ret; +} + +vector IR::all_outputs() const { + vector ret; + for (const auto &[name, output] : _outputs) + ret.push_back(&output); + return ret; +} + +vector IR::all_states() const { + vector ret; + for (const auto &[name, state] : _states) + ret.push_back(&state); + return ret; +} + struct PrintVisitor : DefaultVisitor { std::function np; PrintVisitor(std::function np) : np(np) { } @@ -73,8 +118,8 @@ struct PrintVisitor : DefaultVisitor { std::string zero_extend(Node, Node a, int out_width) override { return "zero_extend(" + np(a) + ", " + std::to_string(out_width) + ")"; } std::string sign_extend(Node, Node a, int out_width) override { return "sign_extend(" + np(a) + ", " + std::to_string(out_width) + ")"; } std::string constant(Node, RTLIL::Const const& value) override { return "constant(" + value.as_string() + ")"; } - std::string input(Node, IdString name) override { return "input(" + name.str() + ")"; } - std::string state(Node, IdString name) override { return "state(" + name.str() + ")"; } + std::string input(Node, IdString name, IdString type) override { return "input(" + name.str() + ", " + type.str() + ")"; } + std::string state(Node, IdString name, IdString type) override { return "state(" + name.str() + ", " + type.str() + ")"; } std::string default_handler(Node self) override { std::string ret = fn_to_string(self.fn()); ret += "("; @@ -199,7 +244,7 @@ private: return handle_alu(g, factory.bitwise_or(p, g), g.width(), false, ci, factory.constant(Const(State::S0, 1))).at(ID(CO)); } public: - std::variant, Node> handle(IdString cellType, dict parameters, dict inputs) + std::variant, Node> handle(IdString cellName, IdString cellType, dict parameters, dict inputs) { int a_width = parameters.at(ID(A_WIDTH), Const(-1)).as_int(); int b_width = parameters.at(ID(B_WIDTH), Const(-1)).as_int(); @@ -392,6 +437,26 @@ public: return handle_lcu(inputs.at(ID(P)), inputs.at(ID(G)), inputs.at(ID(CI))); } else if(cellType == ID($alu)) { return handle_alu(inputs.at(ID(A)), inputs.at(ID(B)), y_width, a_signed && b_signed, inputs.at(ID(CI)), inputs.at(ID(BI))); + } else if(cellType.in({ID($assert), ID($assume), ID($live), ID($fair), ID($cover)})) { + Node a = factory.mux(factory.constant(Const(State::S1, 1)), inputs.at(ID(A)), inputs.at(ID(EN))); + auto &output = factory.add_output(cellName, cellType, Sort(1)); + output.set_value(a); + return {}; + } else if(cellType.in({ID($anyconst), ID($allconst), ID($anyseq), ID($allseq)})) { + int width = parameters.at(ID(WIDTH)).as_int(); + auto &input = factory.add_input(cellName, cellType, Sort(width)); + return factory.value(input); + } else if(cellType == ID($initstate)) { + if(factory.ir().has_state(ID($initstate), ID($state))) + return factory.value(factory.ir().state(ID($initstate))); + else { + auto &state = factory.add_state(ID($initstate), ID($state), Sort(1)); + state.set_initial_value(RTLIL::Const(State::S1, 1)); + state.set_next_value(factory.constant(RTLIL::Const(State::S0, 1))); + return factory.value(state); + } + } else if(cellType == ID($check)) { + log_error("The design contains a $check cell `%s'. This is not supported by the functional backend. Call `chformal -lower' to avoid this error.\n", cellName.c_str()); } else { log_error("`%s' cells are not supported by the functional backend\n", cellType.c_str()); } @@ -448,16 +513,15 @@ public: { driver_map.add(module); for (auto cell : module->cells()) { - if (cell->type.in(ID($assert), ID($assume), ID($cover), ID($check))) + if (cell->type.in(ID($assert), ID($assume), ID($live), ID($fair), ID($cover), ID($check))) queue.emplace_back(cell); } for (auto wire : module->wires()) { if (wire->port_input) - factory.add_input(wire->name, wire->width); + factory.add_input(wire->name, ID($input), Sort(wire->width)); if (wire->port_output) { - factory.add_output(wire->name, wire->width); - Node value = enqueue(DriveChunk(DriveChunkWire(wire, 0, wire->width))); - factory.set_output(wire->name, value); + auto &output = factory.add_output(wire->name, ID($output), Sort(wire->width)); + output.set_value(enqueue(DriveChunk(DriveChunkWire(wire, 0, wire->width)))); } } memories_vector = Mem::get_all_memories(module); @@ -495,9 +559,9 @@ public: // - Since wr port j can only have priority over wr port i if j > i, if we do writes in // ascending index order the result will obey the priorty relation. vector read_results; - factory.add_state(mem->cell->name, Sort(ceil_log2(mem->size), mem->width)); - factory.set_initial_state(mem->cell->name, MemContents(mem)); - Node node = factory.current_state(mem->cell->name); + auto &state = factory.add_state(mem->cell->name, ID($state), Sort(ceil_log2(mem->size), mem->width)); + state.set_initial_value(MemContents(mem)); + Node node = factory.value(state); for (size_t i = 0; i < mem->wr_ports.size(); i++) { const auto &wr = mem->wr_ports[i]; if (wr.clk_enable) @@ -521,7 +585,7 @@ public: Node addr = enqueue(driver_map(DriveSpec(rd.addr))); read_results.push_back(factory.memory_read(node, addr)); } - factory.set_next_state(mem->cell->name, node); + state.set_next_value(node); return concatenate_read_results(mem, read_results); } void process_cell(Cell *cell) @@ -540,25 +604,25 @@ public: if (!ff.has_gclk) log_error("The design contains a %s flip-flop at %s. This is not supported by the functional backend. " "Call async2sync or clk2fflogic to avoid this error.\n", log_id(cell->type), log_id(cell)); - factory.add_state(ff.name, Sort(ff.width)); - Node q_value = factory.current_state(ff.name); + auto &state = factory.add_state(ff.name, ID($state), Sort(ff.width)); + Node q_value = factory.value(state); factory.suggest_name(q_value, ff.name); factory.update_pending(cell_outputs.at({cell, ID(Q)}), q_value); - factory.set_next_state(ff.name, enqueue(ff.sig_d)); - factory.set_initial_state(ff.name, ff.val_init); + state.set_next_value(enqueue(ff.sig_d)); + state.set_initial_value(ff.val_init); } else { dict connections; IdString output_name; // for the single output case int n_outputs = 0; for(auto const &[name, sigspec] : cell->connections()) { - if(driver_map.celltypes.cell_input(cell->type, name)) + if(driver_map.celltypes.cell_input(cell->type, name) && sigspec.size() > 0) connections.insert({ name, enqueue(DriveChunkPort(cell, {name, sigspec})) }); if(driver_map.celltypes.cell_output(cell->type, name)) { output_name = name; n_outputs++; } } - std::variant, Node> outputs = simplifier.handle(cell->type, cell->parameters, connections); + std::variant, Node> outputs = simplifier.handle(cell->name, cell->type, cell->parameters, connections); if(auto *nodep = std::get_if(&outputs); nodep != nullptr) { log_assert(n_outputs == 1); factory.update_pending(cell_outputs.at({cell, output_name}), *nodep); @@ -591,7 +655,7 @@ public: DriveChunkWire wire_chunk = chunk.wire(); if (wire_chunk.is_whole()) { if (wire_chunk.wire->port_input) { - Node node = factory.input(wire_chunk.wire->name); + Node node = factory.value(factory.ir().input(wire_chunk.wire->name)); factory.suggest_name(node, wire_chunk.wire->name); factory.update_pending(pending, node); } else { @@ -668,10 +732,12 @@ void IR::topological_sort() { scc = true; } }); - for(const auto &[name, sort]: _state_sorts) - toposort.process(get_state_next_node(name).id()); - for(const auto &[name, sort]: _output_sorts) - toposort.process(get_output_node(name).id()); + for(const auto &[name, state]: _states) + if(state.has_next_value()) + toposort.process(state.next_value().id()); + for(const auto &[name, output]: _outputs) + if(output.has_value()) + toposort.process(output.value().id()); // any nodes untouched by this point are dead code and will be removed by permute _graph.permute(perm); if(scc) log_error("The design contains combinational loops. This is not supported by the functional backend. " diff --git a/kernel/functional.h b/kernel/functional.h index 08b7b99ca..f7ff08228 100644 --- a/kernel/functional.h +++ b/kernel/functional.h @@ -152,11 +152,60 @@ namespace Functional { bool operator==(Sort const& other) const { return _v == other._v; } unsigned int hash() const { return mkhash(_v); } }; + class IR; class Factory; class Node; + class IRInput { + friend class Factory; + public: + IdString name; + IdString type; + Sort sort; + private: + IRInput(IR &, IdString name, IdString type, Sort sort) + : name(name), type(type), sort(std::move(sort)) {} + }; + class IROutput { + friend class Factory; + IR &_ir; + public: + IdString name; + IdString type; + Sort sort; + private: + IROutput(IR &ir, IdString name, IdString type, Sort sort) + : _ir(ir), name(name), type(type), sort(std::move(sort)) {} + public: + Node value() const; + bool has_value() const; + void set_value(Node value); + }; + class IRState { + friend class Factory; + IR &_ir; + public: + IdString name; + IdString type; + Sort sort; + private: + std::variant _initial; + IRState(IR &ir, IdString name, IdString type, Sort sort) + : _ir(ir), name(name), type(type), sort(std::move(sort)) {} + public: + Node next_value() const; + bool has_next_value() const; + RTLIL::Const const& initial_value_signal() const { return std::get(_initial); } + MemContents const& initial_value_memory() const { return std::get(_initial); } + void set_next_value(Node value); + void set_initial_value(RTLIL::Const value) { value.extu(sort.width()); _initial = std::move(value); } + void set_initial_value(MemContents value) { log_assert(Sort(value.addr_width(), value.data_width()) == sort); _initial = std::move(value); } + }; class IR { friend class Factory; friend class Node; + friend class IRInput; + friend class IROutput; + friend class IRState; // one NodeData is stored per Node, containing the function and non-node arguments // note that NodeData is deduplicated by ComputeGraph class NodeData { @@ -164,7 +213,7 @@ namespace Functional { std::variant< std::monostate, RTLIL::Const, - IdString, + std::pair, int > _extra; public: @@ -173,7 +222,7 @@ namespace Functional { template NodeData(Fn fn, T &&extra) : _fn(fn), _extra(std::forward(extra)) {} Fn fn() const { return _fn; } const RTLIL::Const &as_const() const { return std::get(_extra); } - IdString as_idstring() const { return std::get(_extra); } + std::pair as_idstring_pair() const { return std::get>(_extra); } int as_int() const { return std::get(_extra); } int hash() const { return mkhash((unsigned int) _fn, mkhash(_extra)); @@ -190,13 +239,12 @@ namespace Functional { // the sparse_attr IdString stores a naming suggestion, retrieved with name() // the key is currently used to identify the nodes that represent output and next state values // the bool is true for next state values - using Graph = ComputeGraph>; + using Graph = ComputeGraph>; Graph _graph; - dict _input_sorts; - dict _output_sorts; - dict _state_sorts; - dict _initial_state_signal; - dict _initial_state_memory; + dict, IRInput> _inputs; + dict, IROutput> _outputs; + dict, IRState> _states; + IR::Graph::Ref mutate(Node n); public: static IR from_module(Module *module); Factory factory(); @@ -204,13 +252,24 @@ namespace Functional { Node operator[](int i); void topological_sort(); void forward_buf(); - dict inputs() const { return _input_sorts; } - dict outputs() const { return _output_sorts; } - dict state() const { return _state_sorts; } - RTLIL::Const const &get_initial_state_signal(IdString name) { return _initial_state_signal.at(name); } - MemContents const &get_initial_state_memory(IdString name) { return _initial_state_memory.at(name); } - Node get_output_node(IdString name); - Node get_state_next_node(IdString name); + IRInput const& input(IdString name, IdString type) const { return _inputs.at({name, type}); } + IRInput const& input(IdString name) const { return input(name, ID($input)); } + IROutput const& output(IdString name, IdString type) const { return _outputs.at({name, type}); } + IROutput const& output(IdString name) const { return output(name, ID($output)); } + IRState const& state(IdString name, IdString type) const { return _states.at({name, type}); } + IRState const& state(IdString name) const { return state(name, ID($state)); } + bool has_input(IdString name, IdString type) const { return _inputs.count({name, type}); } + bool has_output(IdString name, IdString type) const { return _outputs.count({name, type}); } + bool has_state(IdString name, IdString type) const { return _states.count({name, type}); } + vector inputs(IdString type) const; + vector inputs() const { return inputs(ID($input)); } + vector outputs(IdString type) const; + vector outputs() const { return outputs(ID($output)); } + vector states(IdString type) const; + vector states() const { return states(ID($state)); } + vector all_inputs() const; + vector all_outputs() const; + vector all_states() const; class iterator { friend class IR; IR *_ir; @@ -236,6 +295,9 @@ namespace Functional { class Node { friend class Factory; friend class IR; + friend class IRInput; + friend class IROutput; + friend class IRState; IR::Graph::ConstRef _ref; explicit Node(IR::Graph::ConstRef ref) : _ref(ref) { } explicit operator IR::Graph::ConstRef() { return _ref; } @@ -290,8 +352,8 @@ namespace Functional { case Fn::arithmetic_shift_right: return v.arithmetic_shift_right(*this, arg(0), arg(1)); break; case Fn::mux: return v.mux(*this, arg(0), arg(1), arg(2)); break; case Fn::constant: return v.constant(*this, _ref.function().as_const()); break; - case Fn::input: return v.input(*this, _ref.function().as_idstring()); break; - case Fn::state: return v.state(*this, _ref.function().as_idstring()); break; + case Fn::input: return v.input(*this, _ref.function().as_idstring_pair().first, _ref.function().as_idstring_pair().second); break; + case Fn::state: return v.state(*this, _ref.function().as_idstring_pair().first, _ref.function().as_idstring_pair().second); break; case Fn::memory_read: return v.memory_read(*this, arg(0), arg(1)); break; case Fn::memory_write: return v.memory_write(*this, arg(0), arg(1), arg(2)); break; } @@ -300,9 +362,14 @@ namespace Functional { std::string to_string(); std::string to_string(std::function); }; + inline IR::Graph::Ref IR::mutate(Node n) { return _graph[n._ref.index()]; } inline Node IR::operator[](int i) { return Node(_graph[i]); } - inline Node IR::get_output_node(IdString name) { return Node(_graph({name, false})); } - inline Node IR::get_state_next_node(IdString name) { return Node(_graph({name, true})); } + inline Node IROutput::value() const { return Node(_ir._graph({name, type, false})); } + inline bool IROutput::has_value() const { return _ir._graph.has_key({name, type, false}); } + inline void IROutput::set_value(Node value) { log_assert(sort == value.sort()); _ir.mutate(value).assign_key({name, type, false}); } + inline Node IRState::next_value() const { return Node(_ir._graph({name, type, true})); } + inline bool IRState::has_next_value() const { return _ir._graph.has_key({name, type, true}); } + inline void IRState::set_next_value(Node value) { log_assert(sort == value.sort()); _ir.mutate(value).assign_key({name, type, true}); } inline Node IR::iterator::operator*() { return Node(_ir->_graph[_index]); } inline arrow_proxy IR::iterator::operator->() { return arrow_proxy(**this); } // AbstractVisitor provides an abstract base class for visitors @@ -336,8 +403,8 @@ namespace Functional { virtual T arithmetic_shift_right(Node self, Node a, Node b) = 0; virtual T mux(Node self, Node a, Node b, Node s) = 0; virtual T constant(Node self, RTLIL::Const const & value) = 0; - virtual T input(Node self, IdString name) = 0; - virtual T state(Node self, IdString name) = 0; + virtual T input(Node self, IdString name, IdString type) = 0; + virtual T state(Node self, IdString name, IdString type) = 0; virtual T memory_read(Node self, Node mem, Node addr) = 0; virtual T memory_write(Node self, Node mem, Node addr, Node data) = 0; }; @@ -373,8 +440,8 @@ namespace Functional { T arithmetic_shift_right(Node self, Node, Node) override { return default_handler(self); } T mux(Node self, Node, Node, Node) override { return default_handler(self); } T constant(Node self, RTLIL::Const const &) override { return default_handler(self); } - T input(Node self, IdString) override { return default_handler(self); } - T state(Node self, IdString) override { return default_handler(self); } + T input(Node self, IdString, IdString) override { return default_handler(self); } + T state(Node self, IdString, IdString) override { return default_handler(self); } T memory_read(Node self, Node, Node) override { return default_handler(self); } T memory_write(Node self, Node, Node, Node) override { return default_handler(self); } }; @@ -383,7 +450,7 @@ namespace Functional { friend class IR; IR &_ir; explicit Factory(IR &ir) : _ir(ir) {} - Node add(IR::NodeData &&fn, Sort &&sort, std::initializer_list args) { + Node add(IR::NodeData &&fn, Sort const &sort, std::initializer_list args) { log_assert(!sort.is_signal() || sort.width() > 0); log_assert(!sort.is_memory() || (sort.addr_width() > 0 && sort.data_width() > 0)); IR::Graph::Ref ref = _ir._graph.add(std::move(fn), {std::move(sort)}); @@ -391,13 +458,11 @@ namespace Functional { ref.append_arg(IR::Graph::ConstRef(arg)); return Node(ref); } - IR::Graph::Ref mutate(Node n) { - return _ir._graph[n._ref.index()]; - } void check_basic_binary(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && a.sort() == b.sort()); } void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal() && b.width() == ceil_log2(a.width())); } void check_unary(Node const &a) { log_assert(a.sort().is_signal()); } public: + IR &ir() { return _ir; } Node slice(Node a, int offset, int out_width) { log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width()); if(offset == 0 && out_width == a.width()) @@ -480,45 +545,31 @@ namespace Functional { void update_pending(Node node, Node value) { log_assert(node._ref.function() == Fn::buf && node._ref.size() == 0); log_assert(node.sort() == value.sort()); - mutate(node).append_arg(value._ref); + _ir.mutate(node).append_arg(value._ref); } - void add_input(IdString name, int width) { - auto [it, inserted] = _ir._input_sorts.emplace(name, Sort(width)); + IRInput &add_input(IdString name, IdString type, Sort sort) { + auto [it, inserted] = _ir._inputs.emplace({name, type}, IRInput(_ir, name, type, std::move(sort))); if (!inserted) log_error("input `%s` was re-defined", name.c_str()); + return it->second; } - void add_output(IdString name, int width) { - auto [it, inserted] = _ir._output_sorts.emplace(name, Sort(width)); + IROutput &add_output(IdString name, IdString type, Sort sort) { + auto [it, inserted] = _ir._outputs.emplace({name, type}, IROutput(_ir, name, type, std::move(sort))); if (!inserted) log_error("output `%s` was re-defined", name.c_str()); + return it->second; } - void add_state(IdString name, Sort sort) { - auto [it, inserted] = _ir._state_sorts.emplace(name, sort); + IRState &add_state(IdString name, IdString type, Sort sort) { + auto [it, inserted] = _ir._states.emplace({name, type}, IRState(_ir, name, type, std::move(sort))); if (!inserted) log_error("state `%s` was re-defined", name.c_str()); + return it->second; } - Node input(IdString name) { - return add(IR::NodeData(Fn::input, name), Sort(_ir._input_sorts.at(name)), {}); + Node value(IRInput const& input) { + return add(IR::NodeData(Fn::input, std::pair(input.name, input.type)), input.sort, {}); } - Node current_state(IdString name) { - return add(IR::NodeData(Fn::state, name), Sort(_ir._state_sorts.at(name)), {}); - } - void set_output(IdString output, Node value) { - log_assert(_ir._output_sorts.at(output) == value.sort()); - mutate(value).assign_key({output, false}); - } - void set_initial_state(IdString state, RTLIL::Const value) { - Sort &sort = _ir._state_sorts.at(state); - value.extu(sort.width()); - _ir._initial_state_signal.emplace(state, std::move(value)); - } - void set_initial_state(IdString state, MemContents value) { - log_assert(Sort(value.addr_width(), value.data_width()) == _ir._state_sorts.at(state)); - _ir._initial_state_memory.emplace(state, std::move(value)); - } - void set_next_state(IdString state, Node value) { - log_assert(_ir._state_sorts.at(state) == value.sort()); - mutate(value).assign_key({state, true}); + Node value(IRState const& state) { + return add(IR::NodeData(Fn::state, std::pair(state.name, state.type)), state.sort, {}); } void suggest_name(Node node, IdString name) { - mutate(node).sparse_attr() = name; + _ir.mutate(node).sparse_attr() = name; } }; inline Factory IR::factory() { return Factory(*this); }