/* * yosys -- Yosys Open SYnthesis Suite * * Copyright (C) 2024 Emily Schmidt * Copyright (C) 2024 National Technology and Engineering Solutions of Sandia, LLC * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above * copyright notice and this permission notice appear in all copies. * * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. * */ #ifndef FUNCTIONAL_H #define FUNCTIONAL_H #include "kernel/yosys.h" #include "kernel/compute_graph.h" #include "kernel/drivertools.h" #include "kernel/mem.h" #include "kernel/utils.h" USING_YOSYS_NAMESPACE YOSYS_NAMESPACE_BEGIN namespace Functional { // each function is documented with a short pseudocode declaration or definition // standard C/Verilog operators are used to describe the result // // the sorts used in this are: // - bit[N]: a bitvector of N bits // bit[N] can be indicated as signed or unsigned. this is not tracked by the functional backend // but is meant to indicate how the value is interpreted // if a bit[N] is marked as neither signed nor unsigned, this means the result should be valid with *either* interpretation // - memory[N, M]: a memory with N address and M data bits // - int: C++ int // - Const[N]: yosys RTLIL::Const (with size() == N) // - IdString: yosys IdString // - any: used in documentation to indicate that the sort is unconstrained // // nodes in the functional backend are either of sort bit[N] or memory[N,M] (for some N, M: int) // additionally, they can carry a constant of sort int, Const[N] or IdString // each node has a 'sort' field that stores the sort of the node // slice, zero_extend, sign_extend use the sort field to store out_width enum class Fn { // invalid() = known-invalid/shouldn't happen value // TODO: maybe remove this and use e.g. std::optional instead? invalid, // buf(a: any): any = a // no-op operation // when constructing the compute graph we generate invalid buf() nodes as a placeholder // and later insert the argument buf, // slice(a: bit[in_width], offset: int, out_width: int): bit[out_width] = a[offset +: out_width] // required: offset + out_width <= in_width slice, // zero_extend(a: unsigned bit[in_width], out_width: int): unsigned bit[out_width] = a (zero extended) // required: out_width > in_width zero_extend, // sign_extend(a: signed bit[in_width], out_width: int): signed bit[out_width] = a (sign extended) // required: out_width > in_width sign_extend, // concat(a: bit[N], b: bit[M]): bit[N+M] = {b, a} (verilog syntax) // concatenates two bitvectors, with a in the least significant position and b in the more significant position concat, // add(a: bit[N], b: bit[N]): bit[N] = a + b add, // sub(a: bit[N], b: bit[N]): bit[N] = a - b sub, // mul(a: bit[N], b: bit[N]): bit[N] = a * b mul, // unsigned_div(a: unsigned bit[N], b: unsigned bit[N]): bit[N] = a / b unsigned_div, // unsigned_mod(a: signed bit[N], b: signed bit[N]): bit[N] = a % b unsigned_mod, // bitwise_and(a: bit[N], b: bit[N]): bit[N] = a & b bitwise_and, // bitwise_or(a: bit[N], b: bit[N]): bit[N] = a | b bitwise_or, // bitwise_xor(a: bit[N], b: bit[N]): bit[N] = a ^ b bitwise_xor, // bitwise_not(a: bit[N]): bit[N] = ~a bitwise_not, // reduce_and(a: bit[N]): bit[1] = &a reduce_and, // reduce_or(a: bit[N]): bit[1] = |a reduce_or, // reduce_xor(a: bit[N]): bit[1] = ^a reduce_xor, // unary_minus(a: bit[N]): bit[N] = -a unary_minus, // equal(a: bit[N], b: bit[N]): bit[1] = (a == b) equal, // not_equal(a: bit[N], b: bit[N]): bit[1] = (a != b) not_equal, // signed_greater_than(a: signed bit[N], b: signed bit[N]): bit[1] = (a > b) signed_greater_than, // signed_greater_equal(a: signed bit[N], b: signed bit[N]): bit[1] = (a >= b) signed_greater_equal, // unsigned_greater_than(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a > b) unsigned_greater_than, // unsigned_greater_equal(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a >= b) unsigned_greater_equal, // logical_shift_left(a: bit[N], b: unsigned bit[M]): bit[N] = a << b // required: M == clog2(N) logical_shift_left, // logical_shift_right(a: unsigned bit[N], b: unsigned bit[M]): unsigned bit[N] = a >> b // required: M == clog2(N) logical_shift_right, // arithmetic_shift_right(a: signed bit[N], b: unsigned bit[M]): signed bit[N] = a >> b // required: M == clog2(N) arithmetic_shift_right, // mux(a: bit[N], b: bit[N], s: bit[1]): bit[N] = s ? b : a mux, // constant(a: Const[N]): bit[N] = a constant, // input(a: IdString): any // returns the current value of the input with the specified name input, // state(a: IdString): any // returns the current value of the state variable with the specified name state, // memory_read(memory: memory[addr_width, data_width], addr: bit[addr_width]): bit[data_width] = memory[addr] memory_read, // memory_write(memory: memory[addr_width, data_width], addr: bit[addr_width], data: bit[data_width]): memory[addr_width, data_width] // returns a copy of `memory` but with the value at `addr` changed to `data` memory_write }; // returns the name of a Fn value, as a string literal const char *fn_to_string(Fn); // Sort represents the sort or type of a node // currently the only two sorts are signal/bit and memory class Sort { std::variant> _v; public: explicit Sort(int width) : _v(width) { } Sort(int addr_width, int data_width) : _v(std::make_pair(addr_width, data_width)) { } bool is_signal() const { return _v.index() == 0; } bool is_memory() const { return _v.index() == 1; } // returns the width of a bitvector sort, errors out for other sorts int width() const { return std::get<0>(_v); } // returns the address width of a bitvector sort, errors out for other sorts int addr_width() const { return std::get<1>(_v).first; } // returns the data width of a bitvector sort, errors out for other sorts int data_width() const { return std::get<1>(_v).second; } bool operator==(Sort const& other) const { return _v == other._v; } Hasher hash_into(Hasher h) const { h.eat(_v); return h; } }; class IR; class Factory; class Node; class IRInput { friend class Factory; public: IdString name; IdString kind; Sort sort; private: IRInput(IR &, IdString name, IdString kind, Sort sort) : name(name), kind(kind), sort(std::move(sort)) {} }; class IROutput { friend class Factory; IR &_ir; public: IdString name; IdString kind; Sort sort; private: IROutput(IR &ir, IdString name, IdString kind, Sort sort) : _ir(ir), name(name), kind(kind), sort(std::move(sort)) {} public: Node value() const; bool has_value() const; void set_value(Node value); }; class IRState { friend class Factory; IR &_ir; public: IdString name; IdString kind; Sort sort; private: std::variant _initial; IRState(IR &ir, IdString name, IdString kind, Sort sort) : _ir(ir), name(name), kind(kind), sort(std::move(sort)) {} public: Node next_value() const; bool has_next_value() const; RTLIL::Const const& initial_value_signal() const { return std::get(_initial); } MemContents const& initial_value_memory() const { return std::get(_initial); } void set_next_value(Node value); void set_initial_value(RTLIL::Const value) { value.extu(sort.width()); _initial = std::move(value); } void set_initial_value(MemContents value) { log_assert(Sort(value.addr_width(), value.data_width()) == sort); _initial = std::move(value); } }; class IR { friend class Factory; friend class Node; friend class IRInput; friend class IROutput; friend class IRState; // one NodeData is stored per Node, containing the function and non-node arguments // note that NodeData is deduplicated by ComputeGraph class NodeData { Fn _fn; std::variant< std::monostate, RTLIL::Const, std::pair, int > _extra; public: NodeData() : _fn(Fn::invalid) {} NodeData(Fn fn) : _fn(fn) {} template NodeData(Fn fn, T &&extra) : _fn(fn), _extra(std::forward(extra)) {} Fn fn() const { return _fn; } const RTLIL::Const &as_const() const { return std::get(_extra); } std::pair as_idstring_pair() const { return std::get>(_extra); } int as_int() const { return std::get(_extra); } Hasher hash_into(Hasher h) const { h.eat((unsigned int) _fn); h.eat(_extra); return h; } bool operator==(NodeData const &other) const { return _fn == other._fn && _extra == other._extra; } }; // Attr contains all the information about a note that should not be deduplicated struct Attr { Sort sort; }; // our specialised version of ComputeGraph // the sparse_attr IdString stores a naming suggestion, retrieved with name() // the key is currently used to identify the nodes that represent output and next state values // the bool is true for next state values using Graph = ComputeGraph>; Graph _graph; dict, IRInput> _inputs; dict, IROutput> _outputs; dict, IRState> _states; IR::Graph::Ref mutate(Node n); public: static IR from_module(Module *module); Factory factory(); int size() const { return _graph.size(); } Node operator[](int i); void topological_sort(); void forward_buf(); IRInput const& input(IdString name, IdString kind) const { return _inputs.at({name, kind}); } IRInput const& input(IdString name) const { return input(name, ID($input)); } IROutput const& output(IdString name, IdString kind) const { return _outputs.at({name, kind}); } IROutput const& output(IdString name) const { return output(name, ID($output)); } IRState const& state(IdString name, IdString kind) const { return _states.at({name, kind}); } IRState const& state(IdString name) const { return state(name, ID($state)); } bool has_input(IdString name, IdString kind) const { return _inputs.count({name, kind}); } bool has_output(IdString name, IdString kind) const { return _outputs.count({name, kind}); } bool has_state(IdString name, IdString kind) const { return _states.count({name, kind}); } vector inputs(IdString kind) const; vector inputs() const { return inputs(ID($input)); } vector outputs(IdString kind) const; vector outputs() const { return outputs(ID($output)); } vector states(IdString kind) const; vector states() const { return states(ID($state)); } vector all_inputs() const; vector all_outputs() const; vector all_states() const; class iterator { friend class IR; IR *_ir; int _index; iterator(IR *ir, int index) : _ir(ir), _index(index) {} public: using iterator_category = std::input_iterator_tag; using value_type = Node; using pointer = arrow_proxy; using reference = Node; using difference_type = ptrdiff_t; Node operator*(); iterator &operator++() { _index++; return *this; } bool operator!=(iterator const &other) const { return _ir != other._ir || _index != other._index; } bool operator==(iterator const &other) const { return !(*this != other); } pointer operator->(); }; iterator begin() { return iterator(this, 0); } iterator end() { return iterator(this, _graph.size()); } }; // Node is an immutable reference to a FunctionalIR node class Node { friend class Factory; friend class IR; friend class IRInput; friend class IROutput; friend class IRState; IR::Graph::ConstRef _ref; explicit Node(IR::Graph::ConstRef ref) : _ref(ref) { } explicit operator IR::Graph::ConstRef() { return _ref; } public: // the node's index. may change if nodes are added or removed int id() const { return _ref.index(); } // a name suggestion for the node, which need not be unique IdString name() const { if(_ref.has_sparse_attr()) return _ref.sparse_attr(); else return std::string("\\n") + std::to_string(id()); } Fn fn() const { return _ref.function().fn(); } Sort sort() const { return _ref.attr().sort; } // returns the width of a bitvector node, errors out for other nodes int width() const { return sort().width(); } size_t arg_count() const { return _ref.size(); } Node arg(int n) const { return Node(_ref.arg(n)); } // visit calls the appropriate visitor method depending on the type of the node template auto visit(Visitor v) const { // currently templated but could be switched to AbstractVisitor & switch(_ref.function().fn()) { case Fn::invalid: log_error("invalid node in visit"); break; case Fn::buf: return v.buf(*this, arg(0)); break; case Fn::slice: return v.slice(*this, arg(0), _ref.function().as_int(), sort().width()); break; case Fn::zero_extend: return v.zero_extend(*this, arg(0), width()); break; case Fn::sign_extend: return v.sign_extend(*this, arg(0), width()); break; case Fn::concat: return v.concat(*this, arg(0), arg(1)); break; case Fn::add: return v.add(*this, arg(0), arg(1)); break; case Fn::sub: return v.sub(*this, arg(0), arg(1)); break; case Fn::mul: return v.mul(*this, arg(0), arg(1)); break; case Fn::unsigned_div: return v.unsigned_div(*this, arg(0), arg(1)); break; case Fn::unsigned_mod: return v.unsigned_mod(*this, arg(0), arg(1)); break; case Fn::bitwise_and: return v.bitwise_and(*this, arg(0), arg(1)); break; case Fn::bitwise_or: return v.bitwise_or(*this, arg(0), arg(1)); break; case Fn::bitwise_xor: return v.bitwise_xor(*this, arg(0), arg(1)); break; case Fn::bitwise_not: return v.bitwise_not(*this, arg(0)); break; case Fn::unary_minus: return v.unary_minus(*this, arg(0)); break; case Fn::reduce_and: return v.reduce_and(*this, arg(0)); break; case Fn::reduce_or: return v.reduce_or(*this, arg(0)); break; case Fn::reduce_xor: return v.reduce_xor(*this, arg(0)); break; case Fn::equal: return v.equal(*this, arg(0), arg(1)); break; case Fn::not_equal: return v.not_equal(*this, arg(0), arg(1)); break; case Fn::signed_greater_than: return v.signed_greater_than(*this, arg(0), arg(1)); break; case Fn::signed_greater_equal: return v.signed_greater_equal(*this, arg(0), arg(1)); break; case Fn::unsigned_greater_than: return v.unsigned_greater_than(*this, arg(0), arg(1)); break; case Fn::unsigned_greater_equal: return v.unsigned_greater_equal(*this, arg(0), arg(1)); break; case Fn::logical_shift_left: return v.logical_shift_left(*this, arg(0), arg(1)); break; case Fn::logical_shift_right: return v.logical_shift_right(*this, arg(0), arg(1)); break; case Fn::arithmetic_shift_right: return v.arithmetic_shift_right(*this, arg(0), arg(1)); break; case Fn::mux: return v.mux(*this, arg(0), arg(1), arg(2)); break; case Fn::constant: return v.constant(*this, _ref.function().as_const()); break; case Fn::input: return v.input(*this, _ref.function().as_idstring_pair().first, _ref.function().as_idstring_pair().second); break; case Fn::state: return v.state(*this, _ref.function().as_idstring_pair().first, _ref.function().as_idstring_pair().second); break; case Fn::memory_read: return v.memory_read(*this, arg(0), arg(1)); break; case Fn::memory_write: return v.memory_write(*this, arg(0), arg(1), arg(2)); break; } log_abort(); } std::string to_string(); std::string to_string(std::function); }; inline IR::Graph::Ref IR::mutate(Node n) { return _graph[n._ref.index()]; } inline Node IR::operator[](int i) { return Node(_graph[i]); } inline Node IROutput::value() const { return Node(_ir._graph({name, kind, false})); } inline bool IROutput::has_value() const { return _ir._graph.has_key({name, kind, false}); } inline void IROutput::set_value(Node value) { log_assert(sort == value.sort()); _ir.mutate(value).assign_key({name, kind, false}); } inline Node IRState::next_value() const { return Node(_ir._graph({name, kind, true})); } inline bool IRState::has_next_value() const { return _ir._graph.has_key({name, kind, true}); } inline void IRState::set_next_value(Node value) { log_assert(sort == value.sort()); _ir.mutate(value).assign_key({name, kind, true}); } inline Node IR::iterator::operator*() { return Node(_ir->_graph[_index]); } inline arrow_proxy IR::iterator::operator->() { return arrow_proxy(**this); } // AbstractVisitor provides an abstract base class for visitors template struct AbstractVisitor { virtual T buf(Node self, Node n) = 0; virtual T slice(Node self, Node a, int offset, int out_width) = 0; virtual T zero_extend(Node self, Node a, int out_width) = 0; virtual T sign_extend(Node self, Node a, int out_width) = 0; virtual T concat(Node self, Node a, Node b) = 0; virtual T add(Node self, Node a, Node b) = 0; virtual T sub(Node self, Node a, Node b) = 0; virtual T mul(Node self, Node a, Node b) = 0; virtual T unsigned_div(Node self, Node a, Node b) = 0; virtual T unsigned_mod(Node self, Node a, Node b) = 0; virtual T bitwise_and(Node self, Node a, Node b) = 0; virtual T bitwise_or(Node self, Node a, Node b) = 0; virtual T bitwise_xor(Node self, Node a, Node b) = 0; virtual T bitwise_not(Node self, Node a) = 0; virtual T unary_minus(Node self, Node a) = 0; virtual T reduce_and(Node self, Node a) = 0; virtual T reduce_or(Node self, Node a) = 0; virtual T reduce_xor(Node self, Node a) = 0; virtual T equal(Node self, Node a, Node b) = 0; virtual T not_equal(Node self, Node a, Node b) = 0; virtual T signed_greater_than(Node self, Node a, Node b) = 0; virtual T signed_greater_equal(Node self, Node a, Node b) = 0; virtual T unsigned_greater_than(Node self, Node a, Node b) = 0; virtual T unsigned_greater_equal(Node self, Node a, Node b) = 0; virtual T logical_shift_left(Node self, Node a, Node b) = 0; virtual T logical_shift_right(Node self, Node a, Node b) = 0; virtual T arithmetic_shift_right(Node self, Node a, Node b) = 0; virtual T mux(Node self, Node a, Node b, Node s) = 0; virtual T constant(Node self, RTLIL::Const const & value) = 0; virtual T input(Node self, IdString name, IdString kind) = 0; virtual T state(Node self, IdString name, IdString kind) = 0; virtual T memory_read(Node self, Node mem, Node addr) = 0; virtual T memory_write(Node self, Node mem, Node addr, Node data) = 0; }; // DefaultVisitor provides defaults for all visitor methods which just calls default_handler template struct DefaultVisitor : public AbstractVisitor { virtual T default_handler(Node self) = 0; T buf(Node self, Node) override { return default_handler(self); } T slice(Node self, Node, int, int) override { return default_handler(self); } T zero_extend(Node self, Node, int) override { return default_handler(self); } T sign_extend(Node self, Node, int) override { return default_handler(self); } T concat(Node self, Node, Node) override { return default_handler(self); } T add(Node self, Node, Node) override { return default_handler(self); } T sub(Node self, Node, Node) override { return default_handler(self); } T mul(Node self, Node, Node) override { return default_handler(self); } T unsigned_div(Node self, Node, Node) override { return default_handler(self); } T unsigned_mod(Node self, Node, Node) override { return default_handler(self); } T bitwise_and(Node self, Node, Node) override { return default_handler(self); } T bitwise_or(Node self, Node, Node) override { return default_handler(self); } T bitwise_xor(Node self, Node, Node) override { return default_handler(self); } T bitwise_not(Node self, Node) override { return default_handler(self); } T unary_minus(Node self, Node) override { return default_handler(self); } T reduce_and(Node self, Node) override { return default_handler(self); } T reduce_or(Node self, Node) override { return default_handler(self); } T reduce_xor(Node self, Node) override { return default_handler(self); } T equal(Node self, Node, Node) override { return default_handler(self); } T not_equal(Node self, Node, Node) override { return default_handler(self); } T signed_greater_than(Node self, Node, Node) override { return default_handler(self); } T signed_greater_equal(Node self, Node, Node) override { return default_handler(self); } T unsigned_greater_than(Node self, Node, Node) override { return default_handler(self); } T unsigned_greater_equal(Node self, Node, Node) override { return default_handler(self); } T logical_shift_left(Node self, Node, Node) override { return default_handler(self); } T logical_shift_right(Node self, Node, Node) override { return default_handler(self); } T arithmetic_shift_right(Node self, Node, Node) override { return default_handler(self); } T mux(Node self, Node, Node, Node) override { return default_handler(self); } T constant(Node self, RTLIL::Const const &) override { return default_handler(self); } T input(Node self, IdString, IdString) override { return default_handler(self); } T state(Node self, IdString, IdString) override { return default_handler(self); } T memory_read(Node self, Node, Node) override { return default_handler(self); } T memory_write(Node self, Node, Node, Node) override { return default_handler(self); } }; // a factory is used to modify a FunctionalIR. it creates new nodes and allows for some modification of existing nodes. class Factory { friend class IR; IR &_ir; explicit Factory(IR &ir) : _ir(ir) {} Node add(IR::NodeData &&fn, Sort const &sort, std::initializer_list args) { log_assert(!sort.is_signal() || sort.width() > 0); log_assert(!sort.is_memory() || (sort.addr_width() > 0 && sort.data_width() > 0)); IR::Graph::Ref ref = _ir._graph.add(std::move(fn), {std::move(sort)}); for (auto arg : args) ref.append_arg(IR::Graph::ConstRef(arg)); return Node(ref); } void check_basic_binary(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && a.sort() == b.sort()); } void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal() && b.width() == ceil_log2(a.width())); } void check_unary(Node const &a) { log_assert(a.sort().is_signal()); } public: IR &ir() { return _ir; } Node slice(Node a, int offset, int out_width) { log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width()); if(offset == 0 && out_width == a.width()) return a; return add(IR::NodeData(Fn::slice, offset), Sort(out_width), {a}); } // extend will either extend or truncate the provided value to reach the desired width Node extend(Node a, int out_width, bool is_signed) { int in_width = a.sort().width(); log_assert(a.sort().is_signal()); if(in_width == out_width) return a; if(in_width > out_width) return slice(a, 0, out_width); if(is_signed) return add(Fn::sign_extend, Sort(out_width), {a}); else return add(Fn::zero_extend, Sort(out_width), {a}); } Node concat(Node a, Node b) { log_assert(a.sort().is_signal() && b.sort().is_signal()); return add(Fn::concat, Sort(a.sort().width() + b.sort().width()), {a, b}); } Node add(Node a, Node b) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); } Node sub(Node a, Node b) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); } Node mul(Node a, Node b) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); } Node unsigned_div(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); } Node unsigned_mod(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); } Node bitwise_and(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); } Node bitwise_or(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); } Node bitwise_xor(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); } Node bitwise_not(Node a) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); } Node unary_minus(Node a) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); } Node reduce_and(Node a) { check_unary(a); if(a.width() == 1) return a; return add(Fn::reduce_and, Sort(1), {a}); } Node reduce_or(Node a) { check_unary(a); if(a.width() == 1) return a; return add(Fn::reduce_or, Sort(1), {a}); } Node reduce_xor(Node a) { check_unary(a); if(a.width() == 1) return a; return add(Fn::reduce_xor, Sort(1), {a}); } Node equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); } Node not_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); } Node signed_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); } Node signed_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); } Node unsigned_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); } Node unsigned_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); } Node logical_shift_left(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); } Node logical_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); } Node arithmetic_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); } Node mux(Node a, Node b, Node s) { log_assert(a.sort().is_signal() && a.sort() == b.sort() && s.sort() == Sort(1)); return add(Fn::mux, a.sort(), {a, b, s}); } Node memory_read(Node mem, Node addr) { log_assert(mem.sort().is_memory() && addr.sort().is_signal() && mem.sort().addr_width() == addr.sort().width()); return add(Fn::memory_read, Sort(mem.sort().data_width()), {mem, addr}); } Node memory_write(Node mem, Node addr, Node data) { log_assert(mem.sort().is_memory() && addr.sort().is_signal() && data.sort().is_signal() && mem.sort().addr_width() == addr.sort().width() && mem.sort().data_width() == data.sort().width()); return add(Fn::memory_write, mem.sort(), {mem, addr, data}); } Node constant(RTLIL::Const value) { int s = value.size(); return add(IR::NodeData(Fn::constant, std::move(value)), Sort(s), {}); } Node create_pending(int width) { return add(Fn::buf, Sort(width), {}); } void update_pending(Node node, Node value) { log_assert(node._ref.function() == Fn::buf && node._ref.size() == 0); log_assert(node.sort() == value.sort()); _ir.mutate(node).append_arg(value._ref); } IRInput &add_input(IdString name, IdString kind, Sort sort) { auto [it, inserted] = _ir._inputs.emplace({name, kind}, IRInput(_ir, name, kind, std::move(sort))); if (!inserted) log_error("input `%s` was re-defined", name.c_str()); return it->second; } IROutput &add_output(IdString name, IdString kind, Sort sort) { auto [it, inserted] = _ir._outputs.emplace({name, kind}, IROutput(_ir, name, kind, std::move(sort))); if (!inserted) log_error("output `%s` was re-defined", name.c_str()); return it->second; } IRState &add_state(IdString name, IdString kind, Sort sort) { auto [it, inserted] = _ir._states.emplace({name, kind}, IRState(_ir, name, kind, std::move(sort))); if (!inserted) log_error("state `%s` was re-defined", name.c_str()); return it->second; } Node value(IRInput const& input) { return add(IR::NodeData(Fn::input, std::pair(input.name, input.kind)), input.sort, {}); } Node value(IRState const& state) { return add(IR::NodeData(Fn::state, std::pair(state.name, state.kind)), state.sort, {}); } void suggest_name(Node node, IdString name) { _ir.mutate(node).sparse_attr() = name; } }; inline Factory IR::factory() { return Factory(*this); } template class Scope { protected: char substitution_character = '_'; virtual bool is_character_legal(char, int) = 0; private: pool _used_names; dict _by_id; public: void reserve(std::string name) { _used_names.insert(std::move(name)); } std::string unique_name(IdString suggestion) { std::string str = RTLIL::unescape_id(suggestion); for(size_t i = 0; i < str.size(); i++) if(!is_character_legal(str[i], i)) str[i] = substitution_character; if(_used_names.count(str) == 0) { _used_names.insert(str); return str; } for (int idx = 0 ; ; idx++){ std::string suffixed = str + "_" + std::to_string(idx); if(_used_names.count(suffixed) == 0) { _used_names.insert(suffixed); return suffixed; } } } std::string operator()(Id id, IdString suggestion) { auto it = _by_id.find(id); if(it != _by_id.end()) return it->second; std::string str = unique_name(suggestion); _by_id.insert({id, str}); return str; } }; class Writer { std::ostream *os; void print_impl(const char *fmt, vector>& fns); public: Writer(std::ostream &os) : os(&os) {} template Writer& operator <<(T&& arg) { *os << std::forward(arg); return *this; } template void print(const char *fmt, Args&&... args) { vector> fns { [&]() { *this << args; }... }; print_impl(fmt, fns); } template void print_with(Fn fn, const char *fmt, Args&&... args) { vector> fns { [&]() { if constexpr (std::is_invocable_v) *this << fn(args); else *this << args; }... }; print_impl(fmt, fns); } }; } YOSYS_NAMESPACE_END #endif