diff --git a/Makefile b/Makefile index 68e6fda4a..e61948cb4 100644 --- a/Makefile +++ b/Makefile @@ -640,7 +640,7 @@ $(eval $(call add_include_file,backends/rtlil/rtlil_backend.h)) OBJS += kernel/driver.o kernel/register.o kernel/rtlil.o kernel/log.o kernel/calc.o kernel/yosys.o OBJS += kernel/binding.o OBJS += kernel/cellaigs.o kernel/celledges.o kernel/cost.o kernel/satgen.o kernel/scopeinfo.o kernel/qcsat.o kernel/mem.o kernel/ffmerge.o kernel/ff.o kernel/yw.o kernel/json.o kernel/fmt.o kernel/sexpr.o -OBJS += kernel/drivertools.o kernel/functionalir.o +OBJS += kernel/drivertools.o kernel/functional.o ifeq ($(ENABLE_ZLIB),1) OBJS += kernel/fstdata.o endif diff --git a/backends/functional/cxx.cc b/backends/functional/cxx.cc index 8d53c9e03..a4755e144 100644 --- a/backends/functional/cxx.cc +++ b/backends/functional/cxx.cc @@ -18,7 +18,7 @@ */ #include "kernel/yosys.h" -#include "kernel/functionalir.h" +#include "kernel/functional.h" #include USING_YOSYS_NAMESPACE @@ -42,7 +42,7 @@ const char *reserved_keywords[] = { nullptr }; -template struct CxxScope : public FunctionalTools::Scope { +template struct CxxScope : public Functional::Scope { CxxScope() { for(const char **p = reserved_keywords; *p != nullptr; p++) this->reserve(*p); @@ -53,8 +53,8 @@ template struct CxxScope : public FunctionalTools::Scope { }; struct CxxType { - FunctionalIR::Sort sort; - CxxType(FunctionalIR::Sort sort) : sort(sort) {} + Functional::Sort sort; + CxxType(Functional::Sort sort) : sort(sort) {} std::string to_string() const { if(sort.is_memory()) { return stringf("Memory<%d, %d>", sort.addr_width(), sort.data_width()); @@ -66,7 +66,7 @@ struct CxxType { } }; -using CxxWriter = FunctionalTools::Writer; +using CxxWriter = Functional::Writer; struct CxxStruct { std::string name; @@ -111,8 +111,8 @@ std::string cxx_const(RTLIL::Const const &value) { return ss.str(); } -template struct CxxPrintVisitor : public FunctionalIR::AbstractVisitor { - using Node = FunctionalIR::Node; +template struct CxxPrintVisitor : public Functional::AbstractVisitor { + using Node = Functional::Node; CxxWriter &f; NodePrinter np; CxxStruct &input_struct; @@ -165,12 +165,12 @@ bool equal_def(RTLIL::Const const &a, RTLIL::Const const &b) { } struct CxxModule { - FunctionalIR ir; + Functional::IR ir; CxxStruct input_struct, output_struct, state_struct; std::string module_name; explicit CxxModule(Module *module) : - ir(FunctionalIR::from_module(module)), + ir(Functional::IR::from_module(module)), input_struct("Inputs"), output_struct("Outputs"), state_struct("State") @@ -222,7 +222,7 @@ struct CxxModule { locals.reserve("output"); locals.reserve("current_state"); locals.reserve("next_state"); - auto node_name = [&](FunctionalIR::Node n) { return locals(n.id(), n.name()); }; + auto node_name = [&](Functional::Node n) { return locals(n.id(), n.name()); }; CxxPrintVisitor printVisitor(f, node_name, input_struct, state_struct); for (auto node : ir) { f.print("\t{} {} = ", CxxType(node.sort()).to_string(), node_name(node)); diff --git a/backends/functional/smtlib.cc b/backends/functional/smtlib.cc index 0d2763d32..7fd6fe564 100644 --- a/backends/functional/smtlib.cc +++ b/backends/functional/smtlib.cc @@ -17,7 +17,7 @@ * */ -#include "kernel/functionalir.h" +#include "kernel/functional.h" #include "kernel/yosys.h" #include "kernel/sexpr.h" #include @@ -42,7 +42,7 @@ const char *reserved_keywords[] = { nullptr }; -struct SmtScope : public FunctionalTools::Scope { +struct SmtScope : public Functional::Scope { SmtScope() { for(const char **p = reserved_keywords; *p != nullptr; p++) reserve(*p); @@ -53,8 +53,8 @@ struct SmtScope : public FunctionalTools::Scope { }; struct SmtSort { - FunctionalIR::Sort sort; - SmtSort(FunctionalIR::Sort sort) : sort(sort) {} + Functional::Sort sort; + SmtSort(Functional::Sort sort) : sort(sort) {} SExpr to_sexpr() const { if(sort.is_memory()) { return list("Array", list("_", "BitVec", sort.addr_width()), list("_", "BitVec", sort.data_width())); @@ -116,8 +116,8 @@ std::string smt_const(RTLIL::Const const &c) { return s; } -struct SmtPrintVisitor : public FunctionalIR::AbstractVisitor { - using Node = FunctionalIR::Node; +struct SmtPrintVisitor : public Functional::AbstractVisitor { + using Node = Functional::Node; std::function n; SmtStruct &input_struct; SmtStruct &state_struct; @@ -183,7 +183,7 @@ struct SmtPrintVisitor : public FunctionalIR::AbstractVisitor { }; struct SmtModule { - FunctionalIR ir; + Functional::IR ir; SmtScope scope; std::string name; @@ -192,7 +192,7 @@ struct SmtModule { SmtStruct state_struct; SmtModule(Module *module) - : ir(FunctionalIR::from_module(module)) + : ir(Functional::IR::from_module(module)) , scope() , name(scope.unique_name(module->name)) , input_struct(scope.unique_name(module->name.str() + "_Inputs"), scope) @@ -215,11 +215,11 @@ struct SmtModule { list(list("inputs", input_struct.name), list("state", state_struct.name)), list("Pair", output_struct.name, state_struct.name))); - auto inlined = [&](FunctionalIR::Node n) { - return n.fn() == FunctionalIR::Fn::constant; + auto inlined = [&](Functional::Node n) { + return n.fn() == Functional::Fn::constant; }; SmtPrintVisitor visitor(input_struct, state_struct); - auto node_to_sexpr = [&](FunctionalIR::Node n) -> SExpr { + auto node_to_sexpr = [&](Functional::Node n) -> SExpr { if(inlined(n)) return n.visit(visitor); else diff --git a/backends/functional/test_generic.cc b/backends/functional/test_generic.cc index 5d9349276..83ea09d8d 100644 --- a/backends/functional/test_generic.cc +++ b/backends/functional/test_generic.cc @@ -18,7 +18,7 @@ */ #include "kernel/yosys.h" -#include "kernel/functionalir.h" +#include "kernel/functional.h" #include USING_YOSYS_NAMESPACE @@ -139,7 +139,7 @@ struct FunctionalTestGeneric : public Pass for (auto module : design->selected_modules()) { log("Dumping module `%s'.\n", module->name.c_str()); - auto fir = FunctionalIR::from_module(module); + auto fir = Functional::IR::from_module(module); for(auto node : fir) std::cout << RTLIL::unescape_id(node.name()) << " = " << node.to_string([](auto n) { return RTLIL::unescape_id(n.name()); }) << "\n"; for(auto [name, sort] : fir.outputs()) diff --git a/kernel/compute_graph.h b/kernel/compute_graph.h new file mode 100644 index 000000000..aeba17f8c --- /dev/null +++ b/kernel/compute_graph.h @@ -0,0 +1,403 @@ +/* + * yosys -- Yosys Open SYnthesis Suite + * + * Copyright (C) 2024 Jannis Harder + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * + */ + +#ifndef COMPUTE_GRAPH_H +#define COMPUTE_GRAPH_H + +#include +#include "kernel/yosys.h" + +YOSYS_NAMESPACE_BEGIN + +template< + typename Fn, // Function type (deduplicated across whole graph) + typename Attr = std::tuple<>, // Call attributes (present in every node) + typename SparseAttr = std::tuple<>, // Sparse call attributes (optional per node) + typename Key = std::tuple<> // Stable keys to refer to nodes +> +struct ComputeGraph +{ + struct Ref; +private: + + // Functions are deduplicated by assigning unique ids + idict functions; + + struct Node { + int fn_index; + int arg_offset; + int arg_count; + Attr attr; + + Node(int fn_index, Attr &&attr, int arg_offset, int arg_count = 0) + : fn_index(fn_index), arg_offset(arg_offset), arg_count(arg_count), attr(std::move(attr)) {} + + Node(int fn_index, Attr const &attr, int arg_offset, int arg_count = 0) + : fn_index(fn_index), arg_offset(arg_offset), arg_count(arg_count), attr(attr) {} + }; + + + std::vector nodes; + std::vector args; + dict keys_; + dict sparse_attrs; + +public: + template + struct BaseRef + { + protected: + friend struct ComputeGraph; + Graph *graph_; + int index_; + BaseRef(Graph *graph, int index) : graph_(graph), index_(index) { + log_assert(index_ >= 0); + check(); + } + + void check() const { log_assert(index_ < graph_->size()); } + + Node const &deref() const { check(); return graph_->nodes[index_]; } + + public: + ComputeGraph const &graph() const { return graph_; } + int index() const { return index_; } + + int size() const { return deref().arg_count; } + + BaseRef arg(int n) const + { + Node const &node = deref(); + log_assert(n >= 0 && n < node.arg_count); + return BaseRef(graph_, graph_->args[node.arg_offset + n]); + } + + std::vector::const_iterator arg_indices_cbegin() const + { + Node const &node = deref(); + return graph_->args.cbegin() + node.arg_offset; + } + + std::vector::const_iterator arg_indices_cend() const + { + Node const &node = deref(); + return graph_->args.cbegin() + node.arg_offset + node.arg_count; + } + + Fn const &function() const { return graph_->functions[deref().fn_index]; } + Attr const &attr() const { return deref().attr; } + + bool has_sparse_attr() const { return graph_->sparse_attrs.count(index_); } + + SparseAttr const &sparse_attr() const + { + auto found = graph_->sparse_attrs.find(index_); + log_assert(found != graph_->sparse_attrs.end()); + return found->second; + } + }; + + using ConstRef = BaseRef; + + struct Ref : public BaseRef + { + private: + friend struct ComputeGraph; + Ref(ComputeGraph *graph, int index) : BaseRef(graph, index) {} + Node &deref() const { this->check(); return this->graph_->nodes[this->index_]; } + + public: + Ref(BaseRef ref) : Ref(ref.graph_, ref.index_) {} + + void set_function(Fn const &function) const + { + deref().fn_index = this->graph_->functions(function); + } + + Attr &attr() const { return deref().attr; } + + void append_arg(ConstRef arg) const + { + log_assert(arg.graph_ == this->graph_); + append_arg(arg.index()); + } + + void append_arg(int arg) const + { + log_assert(arg >= 0 && arg < this->graph_->size()); + Node &node = deref(); + if (node.arg_offset + node.arg_count != GetSize(this->graph_->args)) + move_args(node); + this->graph_->args.push_back(arg); + node.arg_count++; + } + + operator ConstRef() const + { + return ConstRef(this->graph_, this->index_); + } + + SparseAttr &sparse_attr() const + { + return this->graph_->sparse_attrs[this->index_]; + } + + void clear_sparse_attr() const + { + this->graph_->sparse_attrs.erase(this->index_); + } + + void assign_key(Key const &key) const + { + this->graph_->keys_.emplace(key, this->index_); + } + + private: + void move_args(Node &node) const + { + auto &args = this->graph_->args; + int old_offset = node.arg_offset; + node.arg_offset = GetSize(args); + for (int i = 0; i != node.arg_count; ++i) + args.push_back(args[old_offset + i]); + } + + }; + + bool has_key(Key const &key) const + { + return keys_.count(key); + } + + dict const &keys() const + { + return keys_; + } + + ConstRef operator()(Key const &key) const + { + auto it = keys_.find(key); + log_assert(it != keys_.end()); + return (*this)[it->second]; + } + + Ref operator()(Key const &key) + { + auto it = keys_.find(key); + log_assert(it != keys_.end()); + return (*this)[it->second]; + } + + int size() const { return GetSize(nodes); } + + ConstRef operator[](int index) const { return ConstRef(this, index); } + Ref operator[](int index) { return Ref(this, index); } + + Ref add(Fn const &function, Attr &&attr) + { + int index = GetSize(nodes); + int fn_index = functions(function); + nodes.emplace_back(fn_index, std::move(attr), GetSize(args)); + return Ref(this, index); + } + + Ref add(Fn const &function, Attr const &attr) + { + int index = GetSize(nodes); + int fn_index = functions(function); + nodes.emplace_back(fn_index, attr, GetSize(args)); + return Ref(this, index); + } + + template + Ref add(Fn const &function, Attr const &attr, T &&args) + { + Ref added = add(function, attr); + for (auto arg : args) + added.append_arg(arg); + return added; + } + + template + Ref add(Fn const &function, Attr &&attr, T &&args) + { + Ref added = add(function, std::move(attr)); + for (auto arg : args) + added.append_arg(arg); + return added; + } + + Ref add(Fn const &function, Attr const &attr, std::initializer_list args) + { + Ref added = add(function, attr); + for (auto arg : args) + added.append_arg(arg); + return added; + } + + Ref add(Fn const &function, Attr &&attr, std::initializer_list args) + { + Ref added = add(function, std::move(attr)); + for (auto arg : args) + added.append_arg(arg); + return added; + } + + template + Ref add(Fn const &function, Attr const &attr, T begin, T end) + { + Ref added = add(function, attr); + for (; begin != end; ++begin) + added.append_arg(*begin); + return added; + } + + void compact_args() + { + std::vector new_args; + for (auto &node : nodes) + { + int new_offset = GetSize(new_args); + for (int i = 0; i < node.arg_count; i++) + new_args.push_back(args[node.arg_offset + i]); + node.arg_offset = new_offset; + } + std::swap(args, new_args); + } + + void permute(std::vector const &perm) + { + log_assert(perm.size() <= nodes.size()); + std::vector inv_perm; + inv_perm.resize(nodes.size(), -1); + for (int i = 0; i < GetSize(perm); ++i) + { + int j = perm[i]; + log_assert(j >= 0 && j < GetSize(nodes)); + log_assert(inv_perm[j] == -1); + inv_perm[j] = i; + } + permute(perm, inv_perm); + } + + void permute(std::vector const &perm, std::vector const &inv_perm) + { + log_assert(inv_perm.size() == nodes.size()); + std::vector new_nodes; + new_nodes.reserve(perm.size()); + dict new_sparse_attrs; + for (int i : perm) + { + int j = GetSize(new_nodes); + new_nodes.emplace_back(std::move(nodes[i])); + auto found = sparse_attrs.find(i); + if (found != sparse_attrs.end()) + new_sparse_attrs.emplace(j, std::move(found->second)); + } + + std::swap(nodes, new_nodes); + std::swap(sparse_attrs, new_sparse_attrs); + + compact_args(); + for (int &arg : args) + { + log_assert(arg < GetSize(inv_perm)); + log_assert(inv_perm[arg] >= 0); + arg = inv_perm[arg]; + } + + for (auto &key : keys_) + { + log_assert(key.second < GetSize(inv_perm)); + log_assert(inv_perm[key.second] >= 0); + key.second = inv_perm[key.second]; + } + } + + struct SccAdaptor + { + private: + ComputeGraph const &graph_; + std::vector indices_; + public: + SccAdaptor(ComputeGraph const &graph) : graph_(graph) + { + indices_.resize(graph.size(), -1); + } + + + typedef int node_type; + + struct node_enumerator { + private: + friend struct SccAdaptor; + int current, end; + node_enumerator(int current, int end) : current(current), end(end) {} + + public: + + bool finished() const { return current == end; } + node_type next() { + log_assert(!finished()); + node_type result = current; + ++current; + return result; + } + }; + + node_enumerator enumerate_nodes() { + return node_enumerator(0, GetSize(indices_)); + } + + + struct successor_enumerator { + private: + friend struct SccAdaptor; + std::vector::const_iterator current, end; + successor_enumerator(std::vector::const_iterator current, std::vector::const_iterator end) : + current(current), end(end) {} + + public: + bool finished() const { return current == end; } + node_type next() { + log_assert(!finished()); + node_type result = *current; + ++current; + return result; + } + }; + + successor_enumerator enumerate_successors(int index) const { + auto const &ref = graph_[index]; + return successor_enumerator(ref.arg_indices_cbegin(), ref.arg_indices_cend()); + } + + int &dfs_index(node_type const &node) { return indices_[node]; } + + std::vector const &dfs_indices() { return indices_; } + }; + +}; + + + +YOSYS_NAMESPACE_END + + +#endif diff --git a/kernel/functionalir.cc b/kernel/functional.cc similarity index 90% rename from kernel/functionalir.cc rename to kernel/functional.cc index 223fdaa91..ad507187d 100644 --- a/kernel/functionalir.cc +++ b/kernel/functional.cc @@ -17,55 +17,55 @@ * */ -#include "kernel/functionalir.h" -#include +#include "kernel/functional.h" +#include "kernel/topo_scc.h" #include "ff.h" #include "ffinit.h" YOSYS_NAMESPACE_BEGIN +namespace Functional { -const char *FunctionalIR::fn_to_string(FunctionalIR::Fn fn) { +const char *fn_to_string(Fn fn) { switch(fn) { - case FunctionalIR::Fn::invalid: return "invalid"; - case FunctionalIR::Fn::buf: return "buf"; - case FunctionalIR::Fn::slice: return "slice"; - case FunctionalIR::Fn::zero_extend: return "zero_extend"; - case FunctionalIR::Fn::sign_extend: return "sign_extend"; - case FunctionalIR::Fn::concat: return "concat"; - case FunctionalIR::Fn::add: return "add"; - case FunctionalIR::Fn::sub: return "sub"; - case FunctionalIR::Fn::mul: return "mul"; - case FunctionalIR::Fn::unsigned_div: return "unsigned_div"; - case FunctionalIR::Fn::unsigned_mod: return "unsigned_mod"; - case FunctionalIR::Fn::bitwise_and: return "bitwise_and"; - case FunctionalIR::Fn::bitwise_or: return "bitwise_or"; - case FunctionalIR::Fn::bitwise_xor: return "bitwise_xor"; - case FunctionalIR::Fn::bitwise_not: return "bitwise_not"; - case FunctionalIR::Fn::reduce_and: return "reduce_and"; - case FunctionalIR::Fn::reduce_or: return "reduce_or"; - case FunctionalIR::Fn::reduce_xor: return "reduce_xor"; - case FunctionalIR::Fn::unary_minus: return "unary_minus"; - case FunctionalIR::Fn::equal: return "equal"; - case FunctionalIR::Fn::not_equal: return "not_equal"; - case FunctionalIR::Fn::signed_greater_than: return "signed_greater_than"; - case FunctionalIR::Fn::signed_greater_equal: return "signed_greater_equal"; - case FunctionalIR::Fn::unsigned_greater_than: return "unsigned_greater_than"; - case FunctionalIR::Fn::unsigned_greater_equal: return "unsigned_greater_equal"; - case FunctionalIR::Fn::logical_shift_left: return "logical_shift_left"; - case FunctionalIR::Fn::logical_shift_right: return "logical_shift_right"; - case FunctionalIR::Fn::arithmetic_shift_right: return "arithmetic_shift_right"; - case FunctionalIR::Fn::mux: return "mux"; - case FunctionalIR::Fn::constant: return "constant"; - case FunctionalIR::Fn::input: return "input"; - case FunctionalIR::Fn::state: return "state"; - case FunctionalIR::Fn::memory_read: return "memory_read"; - case FunctionalIR::Fn::memory_write: return "memory_write"; + case Fn::invalid: return "invalid"; + case Fn::buf: return "buf"; + case Fn::slice: return "slice"; + case Fn::zero_extend: return "zero_extend"; + case Fn::sign_extend: return "sign_extend"; + case Fn::concat: return "concat"; + case Fn::add: return "add"; + case Fn::sub: return "sub"; + case Fn::mul: return "mul"; + case Fn::unsigned_div: return "unsigned_div"; + case Fn::unsigned_mod: return "unsigned_mod"; + case Fn::bitwise_and: return "bitwise_and"; + case Fn::bitwise_or: return "bitwise_or"; + case Fn::bitwise_xor: return "bitwise_xor"; + case Fn::bitwise_not: return "bitwise_not"; + case Fn::reduce_and: return "reduce_and"; + case Fn::reduce_or: return "reduce_or"; + case Fn::reduce_xor: return "reduce_xor"; + case Fn::unary_minus: return "unary_minus"; + case Fn::equal: return "equal"; + case Fn::not_equal: return "not_equal"; + case Fn::signed_greater_than: return "signed_greater_than"; + case Fn::signed_greater_equal: return "signed_greater_equal"; + case Fn::unsigned_greater_than: return "unsigned_greater_than"; + case Fn::unsigned_greater_equal: return "unsigned_greater_equal"; + case Fn::logical_shift_left: return "logical_shift_left"; + case Fn::logical_shift_right: return "logical_shift_right"; + case Fn::arithmetic_shift_right: return "arithmetic_shift_right"; + case Fn::mux: return "mux"; + case Fn::constant: return "constant"; + case Fn::input: return "input"; + case Fn::state: return "state"; + case Fn::memory_read: return "memory_read"; + case Fn::memory_write: return "memory_write"; } - log_error("fn_to_string: unknown FunctionalIR::Fn value %d", (int)fn); + log_error("fn_to_string: unknown Functional::Fn value %d", (int)fn); } -struct PrintVisitor : FunctionalIR::DefaultVisitor { - using Node = FunctionalIR::Node; +struct PrintVisitor : DefaultVisitor { std::function np; PrintVisitor(std::function np) : np(np) { } // as a general rule the default handler is good enough iff the only arguments are of type Node @@ -76,7 +76,7 @@ struct PrintVisitor : FunctionalIR::DefaultVisitor { std::string input(Node, IdString name) override { return "input(" + name.str() + ")"; } std::string state(Node, IdString name) override { return "state(" + name.str() + ")"; } std::string default_handler(Node self) override { - std::string ret = FunctionalIR::fn_to_string(self.fn()); + std::string ret = fn_to_string(self.fn()); ret += "("; for(size_t i = 0; i < self.arg_count(); i++) { if(i > 0) ret += ", "; @@ -87,19 +87,18 @@ struct PrintVisitor : FunctionalIR::DefaultVisitor { } }; -std::string FunctionalIR::Node::to_string() +std::string Node::to_string() { return to_string([](Node n) { return RTLIL::unescape_id(n.name()); }); } -std::string FunctionalIR::Node::to_string(std::function np) +std::string Node::to_string(std::function np) { return visit(PrintVisitor(np)); } class CellSimplifier { - using Node = FunctionalIR::Node; - FunctionalIR::Factory &factory; + Factory &factory; Node sign(Node a) { return factory.slice(a, a.width() - 1, 1); } @@ -138,7 +137,7 @@ public: Node bb = factory.bitwise_and(b, s); return factory.bitwise_or(aa, bb); } - CellSimplifier(FunctionalIR::Factory &f) : factory(f) {} + CellSimplifier(Factory &f) : factory(f) {} private: Node handle_pow(Node a0, Node b, int y_width, bool is_signed) { Node a = factory.extend(a0, y_width, is_signed); @@ -400,12 +399,11 @@ public: }; class FunctionalIRConstruction { - using Node = FunctionalIR::Node; std::deque> queue; dict graph_nodes; dict, Node> cell_outputs; DriverMap driver_map; - FunctionalIR::Factory& factory; + Factory& factory; CellSimplifier simplifier; vector memories_vector; dict memories; @@ -442,7 +440,7 @@ class FunctionalIRConstruction { return it->second; } public: - FunctionalIRConstruction(Module *module, FunctionalIR::Factory &f) + FunctionalIRConstruction(Module *module, Factory &f) : factory(f) , simplifier(f) , sig_map(module) @@ -497,7 +495,7 @@ public: // - Since wr port j can only have priority over wr port i if j > i, if we do writes in // ascending index order the result will obey the priorty relation. vector read_results; - factory.add_state(mem->cell->name, FunctionalIR::Sort(ceil_log2(mem->size), mem->width)); + factory.add_state(mem->cell->name, Sort(ceil_log2(mem->size), mem->width)); factory.set_initial_state(mem->cell->name, MemContents(mem)); Node node = factory.current_state(mem->cell->name); for (size_t i = 0; i < mem->wr_ports.size(); i++) { @@ -542,7 +540,7 @@ public: if (!ff.has_gclk) log_error("The design contains a %s flip-flop at %s. This is not supported by the functional backend. " "Call async2sync or clk2fflogic to avoid this error.\n", log_id(cell->type), log_id(cell)); - factory.add_state(ff.name, FunctionalIR::Sort(ff.width)); + factory.add_state(ff.name, Sort(ff.width)); Node q_value = factory.current_state(ff.name); factory.suggest_name(q_value, ff.name); factory.update_pending(cell_outputs.at({cell, ID(Q)}), q_value); @@ -643,8 +641,8 @@ public: } }; -FunctionalIR FunctionalIR::from_module(Module *module) { - FunctionalIR ir; +IR IR::from_module(Module *module) { + IR ir; auto factory = ir.factory(); FunctionalIRConstruction ctor(module, factory); ctor.process_queue(); @@ -653,7 +651,7 @@ FunctionalIR FunctionalIR::from_module(Module *module) { return ir; } -void FunctionalIR::topological_sort() { +void IR::topological_sort() { Graph::SccAdaptor compute_graph_scc(_graph); bool scc = false; std::vector perm; @@ -687,7 +685,7 @@ static IdString merge_name(IdString a, IdString b) { return a; } -void FunctionalIR::forward_buf() { +void IR::forward_buf() { std::vector perm, alias; perm.clear(); @@ -734,7 +732,7 @@ static std::string quote_fmt(const char *fmt) return r; } -void FunctionalTools::Writer::print_impl(const char *fmt, vector> &fns) +void Writer::print_impl(const char *fmt, vector> &fns) { size_t next_index = 0; for(const char *p = fmt; *p != 0; p++) @@ -770,4 +768,5 @@ void FunctionalTools::Writer::print_impl(const char *fmt, vector + * Copyright (C) 2024 Emily Schmidt * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -20,384 +20,571 @@ #ifndef FUNCTIONAL_H #define FUNCTIONAL_H -#include #include "kernel/yosys.h" +#include "kernel/compute_graph.h" +#include "kernel/drivertools.h" +#include "kernel/mem.h" +#include "kernel/utils.h" +USING_YOSYS_NAMESPACE YOSYS_NAMESPACE_BEGIN -template< - typename Fn, // Function type (deduplicated across whole graph) - typename Attr = std::tuple<>, // Call attributes (present in every node) - typename SparseAttr = std::tuple<>, // Sparse call attributes (optional per node) - typename Key = std::tuple<> // Stable keys to refer to nodes -> -struct ComputeGraph -{ - struct Ref; -private: - - // Functions are deduplicated by assigning unique ids - idict functions; - - struct Node { - int fn_index; - int arg_offset; - int arg_count; - Attr attr; - - Node(int fn_index, Attr &&attr, int arg_offset, int arg_count = 0) - : fn_index(fn_index), arg_offset(arg_offset), arg_count(arg_count), attr(std::move(attr)) {} - - Node(int fn_index, Attr const &attr, int arg_offset, int arg_count = 0) - : fn_index(fn_index), arg_offset(arg_offset), arg_count(arg_count), attr(attr) {} - }; - - - std::vector nodes; - std::vector args; - dict keys_; - dict sparse_attrs; - -public: - template - struct BaseRef - { - protected: - friend struct ComputeGraph; - Graph *graph_; - int index_; - BaseRef(Graph *graph, int index) : graph_(graph), index_(index) { - log_assert(index_ >= 0); - check(); - } - - void check() const { log_assert(index_ < graph_->size()); } - - Node const &deref() const { check(); return graph_->nodes[index_]; } - - public: - ComputeGraph const &graph() const { return graph_; } - int index() const { return index_; } - - int size() const { return deref().arg_count; } - - BaseRef arg(int n) const - { - Node const &node = deref(); - log_assert(n >= 0 && n < node.arg_count); - return BaseRef(graph_, graph_->args[node.arg_offset + n]); - } - - std::vector::const_iterator arg_indices_cbegin() const - { - Node const &node = deref(); - return graph_->args.cbegin() + node.arg_offset; - } - - std::vector::const_iterator arg_indices_cend() const - { - Node const &node = deref(); - return graph_->args.cbegin() + node.arg_offset + node.arg_count; - } - - Fn const &function() const { return graph_->functions[deref().fn_index]; } - Attr const &attr() const { return deref().attr; } - - bool has_sparse_attr() const { return graph_->sparse_attrs.count(index_); } - - SparseAttr const &sparse_attr() const - { - auto found = graph_->sparse_attrs.find(index_); - log_assert(found != graph_->sparse_attrs.end()); - return found->second; - } - }; - - using ConstRef = BaseRef; - - struct Ref : public BaseRef - { - private: - friend struct ComputeGraph; - Ref(ComputeGraph *graph, int index) : BaseRef(graph, index) {} - Node &deref() const { this->check(); return this->graph_->nodes[this->index_]; } - - public: - Ref(BaseRef ref) : Ref(ref.graph_, ref.index_) {} - - void set_function(Fn const &function) const - { - deref().fn_index = this->graph_->functions(function); - } - - Attr &attr() const { return deref().attr; } - - void append_arg(ConstRef arg) const - { - log_assert(arg.graph_ == this->graph_); - append_arg(arg.index()); - } - - void append_arg(int arg) const - { - log_assert(arg >= 0 && arg < this->graph_->size()); - Node &node = deref(); - if (node.arg_offset + node.arg_count != GetSize(this->graph_->args)) - move_args(node); - this->graph_->args.push_back(arg); - node.arg_count++; - } - - operator ConstRef() const - { - return ConstRef(this->graph_, this->index_); - } - - SparseAttr &sparse_attr() const - { - return this->graph_->sparse_attrs[this->index_]; - } - - void clear_sparse_attr() const - { - this->graph_->sparse_attrs.erase(this->index_); - } - - void assign_key(Key const &key) const - { - this->graph_->keys_.emplace(key, this->index_); - } - - private: - void move_args(Node &node) const - { - auto &args = this->graph_->args; - int old_offset = node.arg_offset; - node.arg_offset = GetSize(args); - for (int i = 0; i != node.arg_count; ++i) - args.push_back(args[old_offset + i]); - } - - }; - - bool has_key(Key const &key) const - { - return keys_.count(key); - } - - dict const &keys() const - { - return keys_; - } - - ConstRef operator()(Key const &key) const - { - auto it = keys_.find(key); - log_assert(it != keys_.end()); - return (*this)[it->second]; - } - - Ref operator()(Key const &key) - { - auto it = keys_.find(key); - log_assert(it != keys_.end()); - return (*this)[it->second]; - } - - int size() const { return GetSize(nodes); } - - ConstRef operator[](int index) const { return ConstRef(this, index); } - Ref operator[](int index) { return Ref(this, index); } - - Ref add(Fn const &function, Attr &&attr) - { - int index = GetSize(nodes); - int fn_index = functions(function); - nodes.emplace_back(fn_index, std::move(attr), GetSize(args)); - return Ref(this, index); - } - - Ref add(Fn const &function, Attr const &attr) - { - int index = GetSize(nodes); - int fn_index = functions(function); - nodes.emplace_back(fn_index, attr, GetSize(args)); - return Ref(this, index); - } - - template - Ref add(Fn const &function, Attr const &attr, T &&args) - { - Ref added = add(function, attr); - for (auto arg : args) - added.append_arg(arg); - return added; - } - - template - Ref add(Fn const &function, Attr &&attr, T &&args) - { - Ref added = add(function, std::move(attr)); - for (auto arg : args) - added.append_arg(arg); - return added; - } - - Ref add(Fn const &function, Attr const &attr, std::initializer_list args) - { - Ref added = add(function, attr); - for (auto arg : args) - added.append_arg(arg); - return added; - } - - Ref add(Fn const &function, Attr &&attr, std::initializer_list args) - { - Ref added = add(function, std::move(attr)); - for (auto arg : args) - added.append_arg(arg); - return added; - } - - template - Ref add(Fn const &function, Attr const &attr, T begin, T end) - { - Ref added = add(function, attr); - for (; begin != end; ++begin) - added.append_arg(*begin); - return added; - } - - void compact_args() - { - std::vector new_args; - for (auto &node : nodes) - { - int new_offset = GetSize(new_args); - for (int i = 0; i < node.arg_count; i++) - new_args.push_back(args[node.arg_offset + i]); - node.arg_offset = new_offset; - } - std::swap(args, new_args); - } - - void permute(std::vector const &perm) - { - log_assert(perm.size() <= nodes.size()); - std::vector inv_perm; - inv_perm.resize(nodes.size(), -1); - for (int i = 0; i < GetSize(perm); ++i) - { - int j = perm[i]; - log_assert(j >= 0 && j < GetSize(nodes)); - log_assert(inv_perm[j] == -1); - inv_perm[j] = i; - } - permute(perm, inv_perm); - } - - void permute(std::vector const &perm, std::vector const &inv_perm) - { - log_assert(inv_perm.size() == nodes.size()); - std::vector new_nodes; - new_nodes.reserve(perm.size()); - dict new_sparse_attrs; - for (int i : perm) - { - int j = GetSize(new_nodes); - new_nodes.emplace_back(std::move(nodes[i])); - auto found = sparse_attrs.find(i); - if (found != sparse_attrs.end()) - new_sparse_attrs.emplace(j, std::move(found->second)); - } - - std::swap(nodes, new_nodes); - std::swap(sparse_attrs, new_sparse_attrs); - - compact_args(); - for (int &arg : args) - { - log_assert(arg < GetSize(inv_perm)); - log_assert(inv_perm[arg] >= 0); - arg = inv_perm[arg]; - } - - for (auto &key : keys_) - { - log_assert(key.second < GetSize(inv_perm)); - log_assert(inv_perm[key.second] >= 0); - key.second = inv_perm[key.second]; - } - } - - struct SccAdaptor - { - private: - ComputeGraph const &graph_; - std::vector indices_; - public: - SccAdaptor(ComputeGraph const &graph) : graph_(graph) - { - indices_.resize(graph.size(), -1); - } - - - typedef int node_type; - - struct node_enumerator { - private: - friend struct SccAdaptor; - int current, end; - node_enumerator(int current, int end) : current(current), end(end) {} - - public: - - bool finished() const { return current == end; } - node_type next() { - log_assert(!finished()); - node_type result = current; - ++current; - return result; - } - }; - - node_enumerator enumerate_nodes() { - return node_enumerator(0, GetSize(indices_)); - } - - - struct successor_enumerator { - private: - friend struct SccAdaptor; - std::vector::const_iterator current, end; - successor_enumerator(std::vector::const_iterator current, std::vector::const_iterator end) : - current(current), end(end) {} - - public: - bool finished() const { return current == end; } - node_type next() { - log_assert(!finished()); - node_type result = *current; - ++current; - return result; - } - }; - - successor_enumerator enumerate_successors(int index) const { - auto const &ref = graph_[index]; - return successor_enumerator(ref.arg_indices_cbegin(), ref.arg_indices_cend()); - } - - int &dfs_index(node_type const &node) { return indices_[node]; } - - std::vector const &dfs_indices() { return indices_; } - }; - -}; - - +namespace Functional { + // each function is documented with a short pseudocode declaration or definition + // standard C/Verilog operators are used to describe the result + // + // the types used in this are: + // - bit[N]: a bitvector of N bits + // bit[N] can be indicated as signed or unsigned. this is not tracked by the functional backend + // but is meant to indicate how the value is interpreted + // if a bit[N] is marked as neither signed nor unsigned, this means the result should be valid with *either* interpretation + // - memory[N, M]: a memory with N address and M data bits + // - int: C++ int + // - Const[N]: yosys RTLIL::Const (with size() == N) + // - IdString: yosys IdString + // - any: used in documentation to indicate that the type is unconstrained + // + // nodes in the functional backend are either of type bit[N] or memory[N,M] (for some N, M: int) + // additionally, they can carry a constant of type int, Const[N] or IdString + // each node has a 'sort' field that stores the type of the node + // slice, zero_extend, sign_extend use the type field to store out_width + enum class Fn { + // invalid() = known-invalid/shouldn't happen value + // TODO: maybe remove this and use e.g. std::optional instead? + invalid, + // buf(a: any): any = a + // no-op operation + // when constructing the compute graph we generate invalid buf() nodes as a placeholder + // and later insert the argument + buf, + // slice(a: bit[in_width], offset: int, out_width: int): bit[out_width] = a[offset +: out_width] + // required: offset + out_width <= in_width + slice, + // zero_extend(a: unsigned bit[in_width], out_width: int): unsigned bit[out_width] = a (zero extended) + // required: out_width > in_width + zero_extend, + // sign_extend(a: signed bit[in_width], out_width: int): signed bit[out_width] = a (sign extended) + // required: out_width > in_width + sign_extend, + // concat(a: bit[N], b: bit[M]): bit[N+M] = {b, a} (verilog syntax) + // concatenates two bitvectors, with a in the least significant position and b in the more significant position + concat, + // add(a: bit[N], b: bit[N]): bit[N] = a + b + add, + // sub(a: bit[N], b: bit[N]): bit[N] = a - b + sub, + // mul(a: bit[N], b: bit[N]): bit[N] = a * b + mul, + // unsigned_div(a: unsigned bit[N], b: unsigned bit[N]): bit[N] = a / b + unsigned_div, + // unsigned_mod(a: signed bit[N], b: signed bit[N]): bit[N] = a % b + unsigned_mod, + // bitwise_and(a: bit[N], b: bit[N]): bit[N] = a & b + bitwise_and, + // bitwise_or(a: bit[N], b: bit[N]): bit[N] = a | b + bitwise_or, + // bitwise_xor(a: bit[N], b: bit[N]): bit[N] = a ^ b + bitwise_xor, + // bitwise_not(a: bit[N]): bit[N] = ~a + bitwise_not, + // reduce_and(a: bit[N]): bit[1] = &a + reduce_and, + // reduce_or(a: bit[N]): bit[1] = |a + reduce_or, + // reduce_xor(a: bit[N]): bit[1] = ^a + reduce_xor, + // unary_minus(a: bit[N]): bit[N] = -a + unary_minus, + // equal(a: bit[N], b: bit[N]): bit[1] = (a == b) + equal, + // not_equal(a: bit[N], b: bit[N]): bit[1] = (a != b) + not_equal, + // signed_greater_than(a: signed bit[N], b: signed bit[N]): bit[1] = (a > b) + signed_greater_than, + // signed_greater_equal(a: signed bit[N], b: signed bit[N]): bit[1] = (a >= b) + signed_greater_equal, + // unsigned_greater_than(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a > b) + unsigned_greater_than, + // unsigned_greater_equal(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a >= b) + unsigned_greater_equal, + // logical_shift_left(a: bit[N], b: unsigned bit[M]): bit[N] = a << b + // required: M == clog2(N) + logical_shift_left, + // logical_shift_right(a: unsigned bit[N], b: unsigned bit[M]): unsigned bit[N] = a >> b + // required: M == clog2(N) + logical_shift_right, + // arithmetic_shift_right(a: signed bit[N], b: unsigned bit[M]): signed bit[N] = a >> b + // required: M == clog2(N) + arithmetic_shift_right, + // mux(a: bit[N], b: bit[N], s: bit[1]): bit[N] = s ? b : a + mux, + // constant(a: Const[N]): bit[N] = a + constant, + // input(a: IdString): any + // returns the current value of the input with the specified name + input, + // state(a: IdString): any + // returns the current value of the state variable with the specified name + state, + // memory_read(memory: memory[addr_width, data_width], addr: bit[addr_width]): bit[data_width] = memory[addr] + memory_read, + // memory_write(memory: memory[addr_width, data_width], addr: bit[addr_width], data: bit[data_width]): memory[addr_width, data_width] + // returns a copy of `memory` but with the value at `addr` changed to `data` + memory_write + }; + // returns the name of a Fn value, as a string literal + const char *fn_to_string(Fn); + // Sort represents the sort or type of a node + // currently the only two types are signal/bit and memory + class Sort { + std::variant> _v; + public: + explicit Sort(int width) : _v(width) { } + Sort(int addr_width, int data_width) : _v(std::make_pair(addr_width, data_width)) { } + bool is_signal() const { return _v.index() == 0; } + bool is_memory() const { return _v.index() == 1; } + // returns the width of a bitvector type, errors out for other types + int width() const { return std::get<0>(_v); } + // returns the address width of a bitvector type, errors out for other types + int addr_width() const { return std::get<1>(_v).first; } + // returns the data width of a bitvector type, errors out for other types + int data_width() const { return std::get<1>(_v).second; } + bool operator==(Sort const& other) const { return _v == other._v; } + unsigned int hash() const { return mkhash(_v); } + }; + class Factory; + class Node; + class IR { + friend class Factory; + friend class Node; + // one NodeData is stored per Node, containing the function and non-node arguments + // note that NodeData is deduplicated by ComputeGraph + class NodeData { + Fn _fn; + std::variant< + std::monostate, + RTLIL::Const, + IdString, + int + > _extra; + public: + NodeData() : _fn(Fn::invalid) {} + NodeData(Fn fn) : _fn(fn) {} + template NodeData(Fn fn, T &&extra) : _fn(fn), _extra(std::forward(extra)) {} + Fn fn() const { return _fn; } + const RTLIL::Const &as_const() const { return std::get(_extra); } + IdString as_idstring() const { return std::get(_extra); } + int as_int() const { return std::get(_extra); } + int hash() const { + return mkhash((unsigned int) _fn, mkhash(_extra)); + } + bool operator==(NodeData const &other) const { + return _fn == other._fn && _extra == other._extra; + } + }; + // Attr contains all the information about a note that should not be deduplicated + struct Attr { + Sort sort; + }; + // our specialised version of ComputeGraph + // the sparse_attr IdString stores a naming suggestion, retrieved with name() + // the key is currently used to identify the nodes that represent output and next state values + // the bool is true for next state values + using Graph = ComputeGraph>; + Graph _graph; + dict _input_sorts; + dict _output_sorts; + dict _state_sorts; + dict _initial_state_signal; + dict _initial_state_memory; + public: + static IR from_module(Module *module); + Factory factory(); + int size() const { return _graph.size(); } + Node operator[](int i); + void topological_sort(); + void forward_buf(); + dict inputs() const { return _input_sorts; } + dict outputs() const { return _output_sorts; } + dict state() const { return _state_sorts; } + RTLIL::Const const &get_initial_state_signal(IdString name) { return _initial_state_signal.at(name); } + MemContents const &get_initial_state_memory(IdString name) { return _initial_state_memory.at(name); } + Node get_output_node(IdString name); + Node get_state_next_node(IdString name); + class iterator { + friend class IR; + IR *_ir; + int _index; + iterator(IR *ir, int index) : _ir(ir), _index(index) {} + public: + using iterator_category = std::input_iterator_tag; + using value_type = Node; + using pointer = arrow_proxy; + using reference = Node; + using difference_type = ptrdiff_t; + Node operator*(); + iterator &operator++() { _index++; return *this; } + bool operator!=(iterator const &other) const { return _ir != other._ir || _index != other._index; } + bool operator==(iterator const &other) const { return !(*this != other); } + pointer operator->(); + // TODO: implement operator-> using the arrow_proxy class currently in mem.h + }; + iterator begin() { return iterator(this, 0); } + iterator end() { return iterator(this, _graph.size()); } + }; + // Node is an immutable reference to a FunctionalIR node + class Node { + friend class Factory; + friend class IR; + IR::Graph::ConstRef _ref; + explicit Node(IR::Graph::ConstRef ref) : _ref(ref) { } + explicit operator IR::Graph::ConstRef() { return _ref; } + public: + // the node's index. may change if nodes are added or removed + int id() const { return _ref.index(); } + // a name suggestion for the node, which need not be unique + IdString name() const { + if(_ref.has_sparse_attr()) + return _ref.sparse_attr(); + else + return std::string("\\n") + std::to_string(id()); + } + Fn fn() const { return _ref.function().fn(); } + Sort sort() const { return _ref.attr().sort; } + // returns the width of a bitvector node, errors out for other nodes + int width() const { return sort().width(); } + size_t arg_count() const { return _ref.size(); } + Node arg(int n) const { return Node(_ref.arg(n)); } + // visit calls the appropriate visitor method depending on the type of the node + template auto visit(Visitor v) const + { + // currently templated but could be switched to AbstractVisitor & + switch(_ref.function().fn()) { + case Fn::invalid: log_error("invalid node in visit"); break; + case Fn::buf: return v.buf(*this, arg(0)); break; + case Fn::slice: return v.slice(*this, arg(0), _ref.function().as_int(), sort().width()); break; + case Fn::zero_extend: return v.zero_extend(*this, arg(0), width()); break; + case Fn::sign_extend: return v.sign_extend(*this, arg(0), width()); break; + case Fn::concat: return v.concat(*this, arg(0), arg(1)); break; + case Fn::add: return v.add(*this, arg(0), arg(1)); break; + case Fn::sub: return v.sub(*this, arg(0), arg(1)); break; + case Fn::mul: return v.mul(*this, arg(0), arg(1)); break; + case Fn::unsigned_div: return v.unsigned_div(*this, arg(0), arg(1)); break; + case Fn::unsigned_mod: return v.unsigned_mod(*this, arg(0), arg(1)); break; + case Fn::bitwise_and: return v.bitwise_and(*this, arg(0), arg(1)); break; + case Fn::bitwise_or: return v.bitwise_or(*this, arg(0), arg(1)); break; + case Fn::bitwise_xor: return v.bitwise_xor(*this, arg(0), arg(1)); break; + case Fn::bitwise_not: return v.bitwise_not(*this, arg(0)); break; + case Fn::unary_minus: return v.unary_minus(*this, arg(0)); break; + case Fn::reduce_and: return v.reduce_and(*this, arg(0)); break; + case Fn::reduce_or: return v.reduce_or(*this, arg(0)); break; + case Fn::reduce_xor: return v.reduce_xor(*this, arg(0)); break; + case Fn::equal: return v.equal(*this, arg(0), arg(1)); break; + case Fn::not_equal: return v.not_equal(*this, arg(0), arg(1)); break; + case Fn::signed_greater_than: return v.signed_greater_than(*this, arg(0), arg(1)); break; + case Fn::signed_greater_equal: return v.signed_greater_equal(*this, arg(0), arg(1)); break; + case Fn::unsigned_greater_than: return v.unsigned_greater_than(*this, arg(0), arg(1)); break; + case Fn::unsigned_greater_equal: return v.unsigned_greater_equal(*this, arg(0), arg(1)); break; + case Fn::logical_shift_left: return v.logical_shift_left(*this, arg(0), arg(1)); break; + case Fn::logical_shift_right: return v.logical_shift_right(*this, arg(0), arg(1)); break; + case Fn::arithmetic_shift_right: return v.arithmetic_shift_right(*this, arg(0), arg(1)); break; + case Fn::mux: return v.mux(*this, arg(0), arg(1), arg(2)); break; + case Fn::constant: return v.constant(*this, _ref.function().as_const()); break; + case Fn::input: return v.input(*this, _ref.function().as_idstring()); break; + case Fn::state: return v.state(*this, _ref.function().as_idstring()); break; + case Fn::memory_read: return v.memory_read(*this, arg(0), arg(1)); break; + case Fn::memory_write: return v.memory_write(*this, arg(0), arg(1), arg(2)); break; + } + } + std::string to_string(); + std::string to_string(std::function); + }; + inline Node IR::operator[](int i) { return Node(_graph[i]); } + inline Node IR::get_output_node(IdString name) { return Node(_graph({name, false})); } + inline Node IR::get_state_next_node(IdString name) { return Node(_graph({name, true})); } + inline Node IR::iterator::operator*() { return Node(_ir->_graph[_index]); } + inline arrow_proxy IR::iterator::operator->() { return arrow_proxy(**this); } + // AbstractVisitor provides an abstract base class for visitors + template struct AbstractVisitor { + virtual T buf(Node self, Node n) = 0; + virtual T slice(Node self, Node a, int offset, int out_width) = 0; + virtual T zero_extend(Node self, Node a, int out_width) = 0; + virtual T sign_extend(Node self, Node a, int out_width) = 0; + virtual T concat(Node self, Node a, Node b) = 0; + virtual T add(Node self, Node a, Node b) = 0; + virtual T sub(Node self, Node a, Node b) = 0; + virtual T mul(Node self, Node a, Node b) = 0; + virtual T unsigned_div(Node self, Node a, Node b) = 0; + virtual T unsigned_mod(Node self, Node a, Node b) = 0; + virtual T bitwise_and(Node self, Node a, Node b) = 0; + virtual T bitwise_or(Node self, Node a, Node b) = 0; + virtual T bitwise_xor(Node self, Node a, Node b) = 0; + virtual T bitwise_not(Node self, Node a) = 0; + virtual T unary_minus(Node self, Node a) = 0; + virtual T reduce_and(Node self, Node a) = 0; + virtual T reduce_or(Node self, Node a) = 0; + virtual T reduce_xor(Node self, Node a) = 0; + virtual T equal(Node self, Node a, Node b) = 0; + virtual T not_equal(Node self, Node a, Node b) = 0; + virtual T signed_greater_than(Node self, Node a, Node b) = 0; + virtual T signed_greater_equal(Node self, Node a, Node b) = 0; + virtual T unsigned_greater_than(Node self, Node a, Node b) = 0; + virtual T unsigned_greater_equal(Node self, Node a, Node b) = 0; + virtual T logical_shift_left(Node self, Node a, Node b) = 0; + virtual T logical_shift_right(Node self, Node a, Node b) = 0; + virtual T arithmetic_shift_right(Node self, Node a, Node b) = 0; + virtual T mux(Node self, Node a, Node b, Node s) = 0; + virtual T constant(Node self, RTLIL::Const const & value) = 0; + virtual T input(Node self, IdString name) = 0; + virtual T state(Node self, IdString name) = 0; + virtual T memory_read(Node self, Node mem, Node addr) = 0; + virtual T memory_write(Node self, Node mem, Node addr, Node data) = 0; + }; + // DefaultVisitor provides defaults for all visitor methods which just calls default_handler + template struct DefaultVisitor : public AbstractVisitor { + virtual T default_handler(Node self) = 0; + T buf(Node self, Node) override { return default_handler(self); } + T slice(Node self, Node, int, int) override { return default_handler(self); } + T zero_extend(Node self, Node, int) override { return default_handler(self); } + T sign_extend(Node self, Node, int) override { return default_handler(self); } + T concat(Node self, Node, Node) override { return default_handler(self); } + T add(Node self, Node, Node) override { return default_handler(self); } + T sub(Node self, Node, Node) override { return default_handler(self); } + T mul(Node self, Node, Node) override { return default_handler(self); } + T unsigned_div(Node self, Node, Node) override { return default_handler(self); } + T unsigned_mod(Node self, Node, Node) override { return default_handler(self); } + T bitwise_and(Node self, Node, Node) override { return default_handler(self); } + T bitwise_or(Node self, Node, Node) override { return default_handler(self); } + T bitwise_xor(Node self, Node, Node) override { return default_handler(self); } + T bitwise_not(Node self, Node) override { return default_handler(self); } + T unary_minus(Node self, Node) override { return default_handler(self); } + T reduce_and(Node self, Node) override { return default_handler(self); } + T reduce_or(Node self, Node) override { return default_handler(self); } + T reduce_xor(Node self, Node) override { return default_handler(self); } + T equal(Node self, Node, Node) override { return default_handler(self); } + T not_equal(Node self, Node, Node) override { return default_handler(self); } + T signed_greater_than(Node self, Node, Node) override { return default_handler(self); } + T signed_greater_equal(Node self, Node, Node) override { return default_handler(self); } + T unsigned_greater_than(Node self, Node, Node) override { return default_handler(self); } + T unsigned_greater_equal(Node self, Node, Node) override { return default_handler(self); } + T logical_shift_left(Node self, Node, Node) override { return default_handler(self); } + T logical_shift_right(Node self, Node, Node) override { return default_handler(self); } + T arithmetic_shift_right(Node self, Node, Node) override { return default_handler(self); } + T mux(Node self, Node, Node, Node) override { return default_handler(self); } + T constant(Node self, RTLIL::Const const &) override { return default_handler(self); } + T input(Node self, IdString) override { return default_handler(self); } + T state(Node self, IdString) override { return default_handler(self); } + T memory_read(Node self, Node, Node) override { return default_handler(self); } + T memory_write(Node self, Node, Node, Node) override { return default_handler(self); } + }; + // a factory is used to modify a FunctionalIR. it creates new nodes and allows for some modification of existing nodes. + class Factory { + friend class IR; + IR &_ir; + explicit Factory(IR &ir) : _ir(ir) {} + Node add(IR::NodeData &&fn, Sort &&sort, std::initializer_list args) { + log_assert(!sort.is_signal() || sort.width() > 0); + log_assert(!sort.is_memory() || sort.addr_width() > 0 && sort.data_width() > 0); + IR::Graph::Ref ref = _ir._graph.add(std::move(fn), {std::move(sort)}); + for (auto arg : args) + ref.append_arg(IR::Graph::ConstRef(arg)); + return Node(ref); + } + IR::Graph::Ref mutate(Node n) { + return _ir._graph[n._ref.index()]; + } + void check_basic_binary(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && a.sort() == b.sort()); } + void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal() && b.width() == ceil_log2(a.width())); } + void check_unary(Node const &a) { log_assert(a.sort().is_signal()); } + public: + Node slice(Node a, int offset, int out_width) { + log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width()); + if(offset == 0 && out_width == a.width()) + return a; + return add(IR::NodeData(Fn::slice, offset), Sort(out_width), {a}); + } + // extend will either extend or truncate the provided value to reach the desired width + Node extend(Node a, int out_width, bool is_signed) { + int in_width = a.sort().width(); + log_assert(a.sort().is_signal()); + if(in_width == out_width) + return a; + if(in_width > out_width) + return slice(a, 0, out_width); + if(is_signed) + return add(Fn::sign_extend, Sort(out_width), {a}); + else + return add(Fn::zero_extend, Sort(out_width), {a}); + } + Node concat(Node a, Node b) { + log_assert(a.sort().is_signal() && b.sort().is_signal()); + return add(Fn::concat, Sort(a.sort().width() + b.sort().width()), {a, b}); + } + Node add(Node a, Node b) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); } + Node sub(Node a, Node b) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); } + Node mul(Node a, Node b) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); } + Node unsigned_div(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); } + Node unsigned_mod(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); } + Node bitwise_and(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); } + Node bitwise_or(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); } + Node bitwise_xor(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); } + Node bitwise_not(Node a) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); } + Node unary_minus(Node a) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); } + Node reduce_and(Node a) { + check_unary(a); + if(a.width() == 1) + return a; + return add(Fn::reduce_and, Sort(1), {a}); + } + Node reduce_or(Node a) { + check_unary(a); + if(a.width() == 1) + return a; + return add(Fn::reduce_or, Sort(1), {a}); + } + Node reduce_xor(Node a) { + check_unary(a); + if(a.width() == 1) + return a; + return add(Fn::reduce_xor, Sort(1), {a}); + } + Node equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); } + Node not_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); } + Node signed_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); } + Node signed_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); } + Node unsigned_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); } + Node unsigned_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); } + Node logical_shift_left(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); } + Node logical_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); } + Node arithmetic_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); } + Node mux(Node a, Node b, Node s) { + log_assert(a.sort().is_signal() && a.sort() == b.sort() && s.sort() == Sort(1)); + return add(Fn::mux, a.sort(), {a, b, s}); + } + Node memory_read(Node mem, Node addr) { + log_assert(mem.sort().is_memory() && addr.sort().is_signal() && mem.sort().addr_width() == addr.sort().width()); + return add(Fn::memory_read, Sort(mem.sort().data_width()), {mem, addr}); + } + Node memory_write(Node mem, Node addr, Node data) { + log_assert(mem.sort().is_memory() && addr.sort().is_signal() && data.sort().is_signal() && + mem.sort().addr_width() == addr.sort().width() && mem.sort().data_width() == data.sort().width()); + return add(Fn::memory_write, mem.sort(), {mem, addr, data}); + } + Node constant(RTLIL::Const value) { + return add(IR::NodeData(Fn::constant, std::move(value)), Sort(value.size()), {}); + } + Node create_pending(int width) { + return add(Fn::buf, Sort(width), {}); + } + void update_pending(Node node, Node value) { + log_assert(node._ref.function() == Fn::buf && node._ref.size() == 0); + log_assert(node.sort() == value.sort()); + mutate(node).append_arg(value._ref); + } + void add_input(IdString name, int width) { + auto [it, inserted] = _ir._input_sorts.emplace(name, Sort(width)); + if (!inserted) log_error("input `%s` was re-defined", name.c_str()); + } + void add_output(IdString name, int width) { + auto [it, inserted] = _ir._output_sorts.emplace(name, Sort(width)); + if (!inserted) log_error("output `%s` was re-defined", name.c_str()); + } + void add_state(IdString name, Sort sort) { + auto [it, inserted] = _ir._state_sorts.emplace(name, sort); + if (!inserted) log_error("state `%s` was re-defined", name.c_str()); + } + Node input(IdString name) { + return add(IR::NodeData(Fn::input, name), Sort(_ir._input_sorts.at(name)), {}); + } + Node current_state(IdString name) { + return add(IR::NodeData(Fn::state, name), Sort(_ir._state_sorts.at(name)), {}); + } + void set_output(IdString output, Node value) { + log_assert(_ir._output_sorts.at(output) == value.sort()); + mutate(value).assign_key({output, false}); + } + void set_initial_state(IdString state, RTLIL::Const value) { + Sort &sort = _ir._state_sorts.at(state); + value.extu(sort.width()); + _ir._initial_state_signal.emplace(state, std::move(value)); + } + void set_initial_state(IdString state, MemContents value) { + log_assert(Sort(value.addr_width(), value.data_width()) == _ir._state_sorts.at(state)); + _ir._initial_state_memory.emplace(state, std::move(value)); + } + void set_next_state(IdString state, Node value) { + log_assert(_ir._state_sorts.at(state) == value.sort()); + mutate(value).assign_key({state, true}); + } + void suggest_name(Node node, IdString name) { + mutate(node).sparse_attr() = name; + } + }; + inline Factory IR::factory() { return Factory(*this); } + template class Scope { + protected: + char substitution_character = '_'; + virtual bool is_character_legal(char) = 0; + private: + pool _used_names; + dict _by_id; + public: + void reserve(std::string name) { + _used_names.insert(std::move(name)); + } + std::string unique_name(IdString suggestion) { + std::string str = RTLIL::unescape_id(suggestion); + for(size_t i = 0; i < str.size(); i++) + if(!is_character_legal(str[i])) + str[i] = substitution_character; + if(_used_names.count(str) == 0) { + _used_names.insert(str); + return str; + } + for (int idx = 0 ; ; idx++){ + std::string suffixed = str + "_" + std::to_string(idx); + if(_used_names.count(suffixed) == 0) { + _used_names.insert(suffixed); + return suffixed; + } + } + } + std::string operator()(Id id, IdString suggestion) { + auto it = _by_id.find(id); + if(it != _by_id.end()) + return it->second; + std::string str = unique_name(suggestion); + _by_id.insert({id, str}); + return str; + } + }; + class Writer { + std::ostream *os; + void print_impl(const char *fmt, vector>& fns); + public: + Writer(std::ostream &os) : os(&os) {} + template Writer& operator <<(T&& arg) { *os << std::forward(arg); return *this; } + template + void print(const char *fmt, Args&&... args) + { + vector> fns { [&]() { *this << args; }... }; + print_impl(fmt, fns); + } + template + void print_with(Fn fn, const char *fmt, Args&&... args) + { + vector> fns { [&]() { + if constexpr (std::is_invocable_v) + *this << fn(args); + else + *this << args; }... + }; + print_impl(fmt, fns); + } + }; + +} YOSYS_NAMESPACE_END - #endif diff --git a/kernel/functionalir.h b/kernel/functionalir.h deleted file mode 100644 index fdbdcbde3..000000000 --- a/kernel/functionalir.h +++ /dev/null @@ -1,575 +0,0 @@ -/* - * yosys -- Yosys Open SYnthesis Suite - * - * Copyright (C) 2024 Emily Schmidt - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - * - */ - -#ifndef FUNCTIONALIR_H -#define FUNCTIONALIR_H - -#include "kernel/yosys.h" -#include "kernel/functional.h" -#include "kernel/drivertools.h" -#include "kernel/mem.h" -#include "kernel/topo_scc.h" - -USING_YOSYS_NAMESPACE -YOSYS_NAMESPACE_BEGIN - -class FunctionalIR { -public: - // each function is documented with a short pseudocode declaration or definition - // standard C/Verilog operators are used to describe the result - // - // the types used in this are: - // - bit[N]: a bitvector of N bits - // bit[N] can be indicated as signed or unsigned. this is not tracked by the functional backend - // but is meant to indicate how the value is interpreted - // if a bit[N] is marked as neither signed nor unsigned, this means the result should be valid with *either* interpretation - // - memory[N, M]: a memory with N address and M data bits - // - int: C++ int - // - Const[N]: yosys RTLIL::Const (with size() == N) - // - IdString: yosys IdString - // - any: used in documentation to indicate that the type is unconstrained - // - // nodes in the functional backend are either of type bit[N] or memory[N,M] (for some N, M: int) - // additionally, they can carry a constant of type int, Const[N] or IdString - // each node has a 'sort' field that stores the type of the node - // slice, zero_extend, sign_extend use the type field to store out_width - enum class Fn { - // invalid() = known-invalid/shouldn't happen value - // TODO: maybe remove this and use e.g. std::optional instead? - invalid, - // buf(a: any): any = a - // no-op operation - // when constructing the compute graph we generate invalid buf() nodes as a placeholder - // and later insert the argument - buf, - // slice(a: bit[in_width], offset: int, out_width: int): bit[out_width] = a[offset +: out_width] - // required: offset + out_width <= in_width - slice, - // zero_extend(a: unsigned bit[in_width], out_width: int): unsigned bit[out_width] = a (zero extended) - // required: out_width > in_width - zero_extend, - // sign_extend(a: signed bit[in_width], out_width: int): signed bit[out_width] = a (sign extended) - // required: out_width > in_width - sign_extend, - // concat(a: bit[N], b: bit[M]): bit[N+M] = {b, a} (verilog syntax) - // concatenates two bitvectors, with a in the least significant position and b in the more significant position - concat, - // add(a: bit[N], b: bit[N]): bit[N] = a + b - add, - // sub(a: bit[N], b: bit[N]): bit[N] = a - b - sub, - // mul(a: bit[N], b: bit[N]): bit[N] = a * b - mul, - // unsigned_div(a: unsigned bit[N], b: unsigned bit[N]): bit[N] = a / b - unsigned_div, - // unsigned_mod(a: signed bit[N], b: signed bit[N]): bit[N] = a % b - unsigned_mod, - // bitwise_and(a: bit[N], b: bit[N]): bit[N] = a & b - bitwise_and, - // bitwise_or(a: bit[N], b: bit[N]): bit[N] = a | b - bitwise_or, - // bitwise_xor(a: bit[N], b: bit[N]): bit[N] = a ^ b - bitwise_xor, - // bitwise_not(a: bit[N]): bit[N] = ~a - bitwise_not, - // reduce_and(a: bit[N]): bit[1] = &a - reduce_and, - // reduce_or(a: bit[N]): bit[1] = |a - reduce_or, - // reduce_xor(a: bit[N]): bit[1] = ^a - reduce_xor, - // unary_minus(a: bit[N]): bit[N] = -a - unary_minus, - // equal(a: bit[N], b: bit[N]): bit[1] = (a == b) - equal, - // not_equal(a: bit[N], b: bit[N]): bit[1] = (a != b) - not_equal, - // signed_greater_than(a: signed bit[N], b: signed bit[N]): bit[1] = (a > b) - signed_greater_than, - // signed_greater_equal(a: signed bit[N], b: signed bit[N]): bit[1] = (a >= b) - signed_greater_equal, - // unsigned_greater_than(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a > b) - unsigned_greater_than, - // unsigned_greater_equal(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a >= b) - unsigned_greater_equal, - // logical_shift_left(a: bit[N], b: unsigned bit[M]): bit[N] = a << b - // required: M == clog2(N) - logical_shift_left, - // logical_shift_right(a: unsigned bit[N], b: unsigned bit[M]): unsigned bit[N] = a >> b - // required: M == clog2(N) - logical_shift_right, - // arithmetic_shift_right(a: signed bit[N], b: unsigned bit[M]): signed bit[N] = a >> b - // required: M == clog2(N) - arithmetic_shift_right, - // mux(a: bit[N], b: bit[N], s: bit[1]): bit[N] = s ? b : a - mux, - // constant(a: Const[N]): bit[N] = a - constant, - // input(a: IdString): any - // returns the current value of the input with the specified name - input, - // state(a: IdString): any - // returns the current value of the state variable with the specified name - state, - // memory_read(memory: memory[addr_width, data_width], addr: bit[addr_width]): bit[data_width] = memory[addr] - memory_read, - // memory_write(memory: memory[addr_width, data_width], addr: bit[addr_width], data: bit[data_width]): memory[addr_width, data_width] - // returns a copy of `memory` but with the value at `addr` changed to `data` - memory_write - }; - // returns the name of a FunctionalIR::Fn value, as a string literal - static const char *fn_to_string(Fn); - // FunctionalIR::Sort represents the sort or type of a node - // currently the only two types are signal/bit and memory - class Sort { - std::variant> _v; - public: - explicit Sort(int width) : _v(width) { } - Sort(int addr_width, int data_width) : _v(std::make_pair(addr_width, data_width)) { } - bool is_signal() const { return _v.index() == 0; } - bool is_memory() const { return _v.index() == 1; } - // returns the width of a bitvector type, errors out for other types - int width() const { return std::get<0>(_v); } - // returns the address width of a bitvector type, errors out for other types - int addr_width() const { return std::get<1>(_v).first; } - // returns the data width of a bitvector type, errors out for other types - int data_width() const { return std::get<1>(_v).second; } - bool operator==(Sort const& other) const { return _v == other._v; } - unsigned int hash() const { return mkhash(_v); } - }; -private: - // one NodeData is stored per Node, containing the function and non-node arguments - // note that NodeData is deduplicated by ComputeGraph - class NodeData { - Fn _fn; - std::variant< - std::monostate, - RTLIL::Const, - IdString, - int - > _extra; - public: - NodeData() : _fn(Fn::invalid) {} - NodeData(Fn fn) : _fn(fn) {} - template NodeData(Fn fn, T &&extra) : _fn(fn), _extra(std::forward(extra)) {} - Fn fn() const { return _fn; } - const RTLIL::Const &as_const() const { return std::get(_extra); } - IdString as_idstring() const { return std::get(_extra); } - int as_int() const { return std::get(_extra); } - int hash() const { - return mkhash((unsigned int) _fn, mkhash(_extra)); - } - bool operator==(NodeData const &other) const { - return _fn == other._fn && _extra == other._extra; - } - }; - // Attr contains all the information about a note that should not be deduplicated - struct Attr { - Sort sort; - }; - // our specialised version of ComputeGraph - // the sparse_attr IdString stores a naming suggestion, retrieved with name() - // the key is currently used to identify the nodes that represent output and next state values - // the bool is true for next state values - using Graph = ComputeGraph>; - Graph _graph; - dict _input_sorts; - dict _output_sorts; - dict _state_sorts; - dict _initial_state_signal; - dict _initial_state_memory; -public: - class Factory; - // Node is an immutable reference to a FunctionalIR node - class Node { - friend class Factory; - friend class FunctionalIR; - Graph::ConstRef _ref; - explicit Node(Graph::ConstRef ref) : _ref(ref) { } - explicit operator Graph::ConstRef() { return _ref; } - public: - // the node's index. may change if nodes are added or removed - int id() const { return _ref.index(); } - // a name suggestion for the node, which need not be unique - IdString name() const { - if(_ref.has_sparse_attr()) - return _ref.sparse_attr(); - else - return std::string("\\n") + std::to_string(id()); - } - Fn fn() const { return _ref.function().fn(); } - Sort sort() const { return _ref.attr().sort; } - // returns the width of a bitvector node, errors out for other nodes - int width() const { return sort().width(); } - size_t arg_count() const { return _ref.size(); } - Node arg(int n) const { return Node(_ref.arg(n)); } - // visit calls the appropriate visitor method depending on the type of the node - template auto visit(Visitor v) const - { - // currently templated but could be switched to AbstractVisitor & - switch(_ref.function().fn()) { - case Fn::invalid: log_error("invalid node in visit"); break; - case Fn::buf: return v.buf(*this, arg(0)); break; - case Fn::slice: return v.slice(*this, arg(0), _ref.function().as_int(), sort().width()); break; - case Fn::zero_extend: return v.zero_extend(*this, arg(0), width()); break; - case Fn::sign_extend: return v.sign_extend(*this, arg(0), width()); break; - case Fn::concat: return v.concat(*this, arg(0), arg(1)); break; - case Fn::add: return v.add(*this, arg(0), arg(1)); break; - case Fn::sub: return v.sub(*this, arg(0), arg(1)); break; - case Fn::mul: return v.mul(*this, arg(0), arg(1)); break; - case Fn::unsigned_div: return v.unsigned_div(*this, arg(0), arg(1)); break; - case Fn::unsigned_mod: return v.unsigned_mod(*this, arg(0), arg(1)); break; - case Fn::bitwise_and: return v.bitwise_and(*this, arg(0), arg(1)); break; - case Fn::bitwise_or: return v.bitwise_or(*this, arg(0), arg(1)); break; - case Fn::bitwise_xor: return v.bitwise_xor(*this, arg(0), arg(1)); break; - case Fn::bitwise_not: return v.bitwise_not(*this, arg(0)); break; - case Fn::unary_minus: return v.unary_minus(*this, arg(0)); break; - case Fn::reduce_and: return v.reduce_and(*this, arg(0)); break; - case Fn::reduce_or: return v.reduce_or(*this, arg(0)); break; - case Fn::reduce_xor: return v.reduce_xor(*this, arg(0)); break; - case Fn::equal: return v.equal(*this, arg(0), arg(1)); break; - case Fn::not_equal: return v.not_equal(*this, arg(0), arg(1)); break; - case Fn::signed_greater_than: return v.signed_greater_than(*this, arg(0), arg(1)); break; - case Fn::signed_greater_equal: return v.signed_greater_equal(*this, arg(0), arg(1)); break; - case Fn::unsigned_greater_than: return v.unsigned_greater_than(*this, arg(0), arg(1)); break; - case Fn::unsigned_greater_equal: return v.unsigned_greater_equal(*this, arg(0), arg(1)); break; - case Fn::logical_shift_left: return v.logical_shift_left(*this, arg(0), arg(1)); break; - case Fn::logical_shift_right: return v.logical_shift_right(*this, arg(0), arg(1)); break; - case Fn::arithmetic_shift_right: return v.arithmetic_shift_right(*this, arg(0), arg(1)); break; - case Fn::mux: return v.mux(*this, arg(0), arg(1), arg(2)); break; - case Fn::constant: return v.constant(*this, _ref.function().as_const()); break; - case Fn::input: return v.input(*this, _ref.function().as_idstring()); break; - case Fn::state: return v.state(*this, _ref.function().as_idstring()); break; - case Fn::memory_read: return v.memory_read(*this, arg(0), arg(1)); break; - case Fn::memory_write: return v.memory_write(*this, arg(0), arg(1), arg(2)); break; - } - } - std::string to_string(); - std::string to_string(std::function); - }; - // AbstractVisitor provides an abstract base class for visitors - template struct AbstractVisitor { - virtual T buf(Node self, Node n) = 0; - virtual T slice(Node self, Node a, int offset, int out_width) = 0; - virtual T zero_extend(Node self, Node a, int out_width) = 0; - virtual T sign_extend(Node self, Node a, int out_width) = 0; - virtual T concat(Node self, Node a, Node b) = 0; - virtual T add(Node self, Node a, Node b) = 0; - virtual T sub(Node self, Node a, Node b) = 0; - virtual T mul(Node self, Node a, Node b) = 0; - virtual T unsigned_div(Node self, Node a, Node b) = 0; - virtual T unsigned_mod(Node self, Node a, Node b) = 0; - virtual T bitwise_and(Node self, Node a, Node b) = 0; - virtual T bitwise_or(Node self, Node a, Node b) = 0; - virtual T bitwise_xor(Node self, Node a, Node b) = 0; - virtual T bitwise_not(Node self, Node a) = 0; - virtual T unary_minus(Node self, Node a) = 0; - virtual T reduce_and(Node self, Node a) = 0; - virtual T reduce_or(Node self, Node a) = 0; - virtual T reduce_xor(Node self, Node a) = 0; - virtual T equal(Node self, Node a, Node b) = 0; - virtual T not_equal(Node self, Node a, Node b) = 0; - virtual T signed_greater_than(Node self, Node a, Node b) = 0; - virtual T signed_greater_equal(Node self, Node a, Node b) = 0; - virtual T unsigned_greater_than(Node self, Node a, Node b) = 0; - virtual T unsigned_greater_equal(Node self, Node a, Node b) = 0; - virtual T logical_shift_left(Node self, Node a, Node b) = 0; - virtual T logical_shift_right(Node self, Node a, Node b) = 0; - virtual T arithmetic_shift_right(Node self, Node a, Node b) = 0; - virtual T mux(Node self, Node a, Node b, Node s) = 0; - virtual T constant(Node self, RTLIL::Const const & value) = 0; - virtual T input(Node self, IdString name) = 0; - virtual T state(Node self, IdString name) = 0; - virtual T memory_read(Node self, Node mem, Node addr) = 0; - virtual T memory_write(Node self, Node mem, Node addr, Node data) = 0; - }; - // DefaultVisitor provides defaults for all visitor methods which just calls default_handler - template struct DefaultVisitor : public AbstractVisitor { - virtual T default_handler(Node self) = 0; - T buf(Node self, Node) override { return default_handler(self); } - T slice(Node self, Node, int, int) override { return default_handler(self); } - T zero_extend(Node self, Node, int) override { return default_handler(self); } - T sign_extend(Node self, Node, int) override { return default_handler(self); } - T concat(Node self, Node, Node) override { return default_handler(self); } - T add(Node self, Node, Node) override { return default_handler(self); } - T sub(Node self, Node, Node) override { return default_handler(self); } - T mul(Node self, Node, Node) override { return default_handler(self); } - T unsigned_div(Node self, Node, Node) override { return default_handler(self); } - T unsigned_mod(Node self, Node, Node) override { return default_handler(self); } - T bitwise_and(Node self, Node, Node) override { return default_handler(self); } - T bitwise_or(Node self, Node, Node) override { return default_handler(self); } - T bitwise_xor(Node self, Node, Node) override { return default_handler(self); } - T bitwise_not(Node self, Node) override { return default_handler(self); } - T unary_minus(Node self, Node) override { return default_handler(self); } - T reduce_and(Node self, Node) override { return default_handler(self); } - T reduce_or(Node self, Node) override { return default_handler(self); } - T reduce_xor(Node self, Node) override { return default_handler(self); } - T equal(Node self, Node, Node) override { return default_handler(self); } - T not_equal(Node self, Node, Node) override { return default_handler(self); } - T signed_greater_than(Node self, Node, Node) override { return default_handler(self); } - T signed_greater_equal(Node self, Node, Node) override { return default_handler(self); } - T unsigned_greater_than(Node self, Node, Node) override { return default_handler(self); } - T unsigned_greater_equal(Node self, Node, Node) override { return default_handler(self); } - T logical_shift_left(Node self, Node, Node) override { return default_handler(self); } - T logical_shift_right(Node self, Node, Node) override { return default_handler(self); } - T arithmetic_shift_right(Node self, Node, Node) override { return default_handler(self); } - T mux(Node self, Node, Node, Node) override { return default_handler(self); } - T constant(Node self, RTLIL::Const const &) override { return default_handler(self); } - T input(Node self, IdString) override { return default_handler(self); } - T state(Node self, IdString) override { return default_handler(self); } - T memory_read(Node self, Node, Node) override { return default_handler(self); } - T memory_write(Node self, Node, Node, Node) override { return default_handler(self); } - }; - // a factory is used to modify a FunctionalIR. it creates new nodes and allows for some modification of existing nodes. - class Factory { - FunctionalIR &_ir; - friend class FunctionalIR; - explicit Factory(FunctionalIR &ir) : _ir(ir) {} - Node add(NodeData &&fn, Sort &&sort, std::initializer_list args) { - log_assert(!sort.is_signal() || sort.width() > 0); - log_assert(!sort.is_memory() || sort.addr_width() > 0 && sort.data_width() > 0); - Graph::Ref ref = _ir._graph.add(std::move(fn), {std::move(sort)}); - for (auto arg : args) - ref.append_arg(Graph::ConstRef(arg)); - return Node(ref); - } - Graph::Ref mutate(Node n) { - return _ir._graph[n._ref.index()]; - } - void check_basic_binary(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && a.sort() == b.sort()); } - void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal() && b.width() == ceil_log2(a.width())); } - void check_unary(Node const &a) { log_assert(a.sort().is_signal()); } - public: - Node slice(Node a, int offset, int out_width) { - log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width()); - if(offset == 0 && out_width == a.width()) - return a; - return add(NodeData(Fn::slice, offset), Sort(out_width), {a}); - } - // extend will either extend or truncate the provided value to reach the desired width - Node extend(Node a, int out_width, bool is_signed) { - int in_width = a.sort().width(); - log_assert(a.sort().is_signal()); - if(in_width == out_width) - return a; - if(in_width > out_width) - return slice(a, 0, out_width); - if(is_signed) - return add(Fn::sign_extend, Sort(out_width), {a}); - else - return add(Fn::zero_extend, Sort(out_width), {a}); - } - Node concat(Node a, Node b) { - log_assert(a.sort().is_signal() && b.sort().is_signal()); - return add(Fn::concat, Sort(a.sort().width() + b.sort().width()), {a, b}); - } - Node add(Node a, Node b) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); } - Node sub(Node a, Node b) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); } - Node mul(Node a, Node b) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); } - Node unsigned_div(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); } - Node unsigned_mod(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); } - Node bitwise_and(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); } - Node bitwise_or(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); } - Node bitwise_xor(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); } - Node bitwise_not(Node a) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); } - Node unary_minus(Node a) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); } - Node reduce_and(Node a) { - check_unary(a); - if(a.width() == 1) - return a; - return add(Fn::reduce_and, Sort(1), {a}); - } - Node reduce_or(Node a) { - check_unary(a); - if(a.width() == 1) - return a; - return add(Fn::reduce_or, Sort(1), {a}); - } - Node reduce_xor(Node a) { - check_unary(a); - if(a.width() == 1) - return a; - return add(Fn::reduce_xor, Sort(1), {a}); - } - Node equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); } - Node not_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); } - Node signed_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); } - Node signed_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); } - Node unsigned_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); } - Node unsigned_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); } - Node logical_shift_left(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); } - Node logical_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); } - Node arithmetic_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); } - Node mux(Node a, Node b, Node s) { - log_assert(a.sort().is_signal() && a.sort() == b.sort() && s.sort() == Sort(1)); - return add(Fn::mux, a.sort(), {a, b, s}); - } - Node memory_read(Node mem, Node addr) { - log_assert(mem.sort().is_memory() && addr.sort().is_signal() && mem.sort().addr_width() == addr.sort().width()); - return add(Fn::memory_read, Sort(mem.sort().data_width()), {mem, addr}); - } - Node memory_write(Node mem, Node addr, Node data) { - log_assert(mem.sort().is_memory() && addr.sort().is_signal() && data.sort().is_signal() && - mem.sort().addr_width() == addr.sort().width() && mem.sort().data_width() == data.sort().width()); - return add(Fn::memory_write, mem.sort(), {mem, addr, data}); - } - Node constant(RTLIL::Const value) { - return add(NodeData(Fn::constant, std::move(value)), Sort(value.size()), {}); - } - Node create_pending(int width) { - return add(Fn::buf, Sort(width), {}); - } - void update_pending(Node node, Node value) { - log_assert(node._ref.function() == Fn::buf && node._ref.size() == 0); - log_assert(node.sort() == value.sort()); - mutate(node).append_arg(value._ref); - } - void add_input(IdString name, int width) { - auto [it, inserted] = _ir._input_sorts.emplace(name, Sort(width)); - if (!inserted) log_error("input `%s` was re-defined", name.c_str()); - } - void add_output(IdString name, int width) { - auto [it, inserted] = _ir._output_sorts.emplace(name, Sort(width)); - if (!inserted) log_error("output `%s` was re-defined", name.c_str()); - } - void add_state(IdString name, Sort sort) { - auto [it, inserted] = _ir._state_sorts.emplace(name, sort); - if (!inserted) log_error("state `%s` was re-defined", name.c_str()); - } - Node input(IdString name) { - return add(NodeData(Fn::input, name), Sort(_ir._input_sorts.at(name)), {}); - } - Node current_state(IdString name) { - return add(NodeData(Fn::state, name), Sort(_ir._state_sorts.at(name)), {}); - } - void set_output(IdString output, Node value) { - log_assert(_ir._output_sorts.at(output) == value.sort()); - mutate(value).assign_key({output, false}); - } - void set_initial_state(IdString state, RTLIL::Const value) { - Sort &sort = _ir._state_sorts.at(state); - value.extu(sort.width()); - _ir._initial_state_signal.emplace(state, std::move(value)); - } - void set_initial_state(IdString state, MemContents value) { - log_assert(Sort(value.addr_width(), value.data_width()) == _ir._state_sorts.at(state)); - _ir._initial_state_memory.emplace(state, std::move(value)); - } - void set_next_state(IdString state, Node value) { - log_assert(_ir._state_sorts.at(state) == value.sort()); - mutate(value).assign_key({state, true}); - } - void suggest_name(Node node, IdString name) { - mutate(node).sparse_attr() = name; - } - }; - static FunctionalIR from_module(Module *module); - Factory factory() { return Factory(*this); } - int size() const { return _graph.size(); } - Node operator[](int i) { return Node(_graph[i]); } - void topological_sort(); - void forward_buf(); - dict inputs() const { return _input_sorts; } - dict outputs() const { return _output_sorts; } - dict state() const { return _state_sorts; } - RTLIL::Const const &get_initial_state_signal(IdString name) { return _initial_state_signal.at(name); } - MemContents const &get_initial_state_memory(IdString name) { return _initial_state_memory.at(name); } - Node get_output_node(IdString name) { return Node(_graph({name, false})); } - Node get_state_next_node(IdString name) { return Node(_graph({name, true})); } - class Iterator { - friend class FunctionalIR; - FunctionalIR *_ir; - int _index; - Iterator(FunctionalIR *ir, int index) : _ir(ir), _index(index) {} - public: - Node operator*() { return Node(_ir->_graph[_index]); } - Iterator &operator++() { _index++; return *this; } - bool operator!=(Iterator const &other) const { return _index != other._index; } - }; - Iterator begin() { return Iterator(this, 0); } - Iterator end() { return Iterator(this, _graph.size()); } -}; - -namespace FunctionalTools { - template class Scope { - protected: - char substitution_character = '_'; - virtual bool is_character_legal(char) = 0; - private: - pool _used_names; - dict _by_id; - public: - void reserve(std::string name) { - _used_names.insert(std::move(name)); - } - std::string unique_name(IdString suggestion) { - std::string str = RTLIL::unescape_id(suggestion); - for(size_t i = 0; i < str.size(); i++) - if(!is_character_legal(str[i])) - str[i] = substitution_character; - if(_used_names.count(str) == 0) { - _used_names.insert(str); - return str; - } - for (int idx = 0 ; ; idx++){ - std::string suffixed = str + "_" + std::to_string(idx); - if(_used_names.count(suffixed) == 0) { - _used_names.insert(suffixed); - return suffixed; - } - } - } - std::string operator()(Id id, IdString suggestion) { - auto it = _by_id.find(id); - if(it != _by_id.end()) - return it->second; - std::string str = unique_name(suggestion); - _by_id.insert({id, str}); - return str; - } - }; - class Writer { - std::ostream *os; - void print_impl(const char *fmt, vector>& fns); - public: - Writer(std::ostream &os) : os(&os) {} - template Writer& operator <<(T&& arg) { *os << std::forward(arg); return *this; } - template - void print(const char *fmt, Args&&... args) - { - vector> fns { [&]() { *this << args; }... }; - print_impl(fmt, fns); - } - template - void print_with(Fn fn, const char *fmt, Args&&... args) - { - vector> fns { [&]() { - if constexpr (std::is_invocable_v) - *this << fn(args); - else - *this << args; }... - }; - print_impl(fmt, fns); - } - }; -} - -YOSYS_NAMESPACE_END - -#endif diff --git a/kernel/mem.h b/kernel/mem.h index 4be4b6864..8c935adc1 100644 --- a/kernel/mem.h +++ b/kernel/mem.h @@ -22,6 +22,7 @@ #include "kernel/yosys.h" #include "kernel/ffinit.h" +#include "kernel/utils.h" YOSYS_NAMESPACE_BEGIN @@ -224,15 +225,6 @@ struct Mem : RTLIL::AttrObject { Mem(Module *module, IdString memid, int width, int start_offset, int size) : module(module), memid(memid), packed(false), mem(nullptr), cell(nullptr), width(width), start_offset(start_offset), size(size) {} }; -// this class is used for implementing operator-> on iterators that return values rather than references -// it's necessary because in C++ operator-> is called recursively until a raw pointer is obtained -template -struct arrow_proxy { - T v; - explicit arrow_proxy(T const & v) : v(v) {} - T* operator->() { return &v; } -}; - // MemContents efficiently represents the contents of a potentially sparse memory by storing only those segments that are actually defined class MemContents { public: @@ -303,6 +295,7 @@ public: reference operator *() const { return range(_memory->_data_width, _addr, _memory->_values.at(_addr)); } pointer operator->() const { return arrow_proxy(**this); } bool operator !=(iterator const &other) const { return _memory != other._memory || _addr != other._addr; } + bool operator ==(iterator const &other) const { return !(*this != other); } iterator &operator++(); }; MemContents(int addr_width, int data_width, RTLIL::Const default_value) diff --git a/kernel/utils.h b/kernel/utils.h index 3216c5eb5..99f327db4 100644 --- a/kernel/utils.h +++ b/kernel/utils.h @@ -253,6 +253,15 @@ template , typename OPS = hash_ops> cla } }; +// this class is used for implementing operator-> on iterators that return values rather than references +// it's necessary because in C++ operator-> is called recursively until a raw pointer is obtained +template +struct arrow_proxy { + T v; + explicit arrow_proxy(T const & v) : v(v) {} + T* operator->() { return &v; } +}; + YOSYS_NAMESPACE_END #endif diff --git a/passes/cmds/example_dt.cc b/passes/cmds/example_dt.cc index 4b836d75b..aaf07dadd 100644 --- a/passes/cmds/example_dt.cc +++ b/passes/cmds/example_dt.cc @@ -1,7 +1,7 @@ #include "kernel/yosys.h" #include "kernel/drivertools.h" #include "kernel/topo_scc.h" -#include "kernel/functional.h" +#include "kernel/compute_graph.h" USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN