From 7b29d177ac511cd68a4c51001e6ad3522268b08e Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Wed, 12 Jun 2024 11:27:10 +0100 Subject: [PATCH] add support for memories to c++ and smtlib functional backends --- backends/functional/cxx.cc | 110 ++++++++++++++++++-------- backends/functional/cxx_runtime/sim.h | 19 +++++ backends/functional/smtlib.cc | 72 ++++++++++++++--- kernel/drivertools.h | 7 ++ kernel/graphtools.h | 68 ++++++++++++++-- 5 files changed, 226 insertions(+), 50 deletions(-) diff --git a/backends/functional/cxx.cc b/backends/functional/cxx.cc index b42e20ca7..08d9ba791 100644 --- a/backends/functional/cxx.cc +++ b/backends/functional/cxx.cc @@ -83,6 +83,41 @@ struct CxxScope { } }; +struct CxxType { + bool _is_memory; + int _width; + int _addr_width; +public: + CxxType() : _is_memory(false), _width(0), _addr_width(0) { } + CxxType(int width) : _is_memory(false), _width(width), _addr_width(0) { } + CxxType(int addr_width, int data_width) : _is_memory(true), _width(data_width), _addr_width(addr_width) { } + static CxxType signal(int width) { return CxxType(width); } + static CxxType memory(int addr_width, int data_width) { return CxxType(addr_width, data_width); } + bool is_signal() const { return !_is_memory; } + bool is_memory() const { return _is_memory; } + int width() const { log_assert(is_signal()); return _width; } + int addr_width() const { log_assert(is_memory()); return _addr_width; } + int data_width() const { log_assert(is_memory()); return _width; } + std::string to_string() const { + if(_is_memory) { + return stringf("Memory<%d, %d>", addr_width(), data_width()); + } else { + return stringf("Signal<%d>", width()); + } + } + bool operator ==(CxxType const& other) const { + if(_is_memory != other._is_memory) return false; + if(_is_memory && _addr_width != other._addr_width) return false; + return _width == other._width; + } + unsigned int hash() const { + if(_is_memory) + return mkhash(1, mkhash(_width, _addr_width)); + else + return mkhash(0, _width); + } +}; + struct CxxWriter { std::ostream &f; CxxWriter(std::ostream &out) : f(out) {} @@ -97,7 +132,7 @@ struct CxxWriter { struct CxxStruct { std::string name; - dict types; + dict types; CxxScope scope; bool generate_methods; int count; @@ -106,14 +141,14 @@ struct CxxStruct { scope.reserve("out"); scope.reserve("dump"); } - void insert(IdString name, std::string type) { + void insert(IdString name, CxxType type) { scope.insert(name); types.insert({name, type}); } void print(CxxWriter &f) { f.printf("struct %s {\n", name.c_str()); for (auto p : types) { - f.printf("\t%s %s;\n", p.second.c_str(), scope[p.first].c_str()); + f.printf("\t%s %s;\n", p.second.to_string().c_str(), scope[p.first].c_str()); } f.printf("\n\ttemplate void dump(T &out) const {\n"); for (auto p : types) { @@ -149,7 +184,7 @@ struct CxxStruct { std::string generate_variant_types() const { std::set unique_types; for (const auto& p : types) { - unique_types.insert("std::reference_wrapper<" + p.second + ">"); + unique_types.insert("std::reference_wrapper<" + p.second.to_string() + ">"); } std::ostringstream oss; for (auto it = unique_types.begin(); it != unique_types.end(); ++it) { @@ -164,18 +199,18 @@ struct CxxStruct { struct CxxFunction { IdString name; - int width; + CxxType type; dict parameters; - CxxFunction(IdString name, int width) : name(name), width(width) {} - CxxFunction(IdString name, int width, dict parameters) : name(name), width(width), parameters(parameters) {} + CxxFunction(IdString name, CxxType type) : name(name), type(type) {} + CxxFunction(IdString name, CxxType type, dict parameters) : name(name), type(type), parameters(parameters) {} bool operator==(CxxFunction const &other) const { - return name == other.name && parameters == other.parameters && width == other.width; + return name == other.name && parameters == other.parameters && type == other.type; } unsigned int hash() const { - return mkhash(name.hash(), parameters.hash()); + return mkhash(name.hash(), mkhash(type.hash(), parameters.hash())); } }; @@ -232,6 +267,9 @@ public: } T input(IdString name, int width) { return graph.add(CxxFunction(ID($$input), width, {{name, {}}}), 0); } T state(IdString name, int width) { return graph.add(CxxFunction(ID($$state), width, {{name, {}}}), 0); } + T state_memory(IdString name, int addr_width, int data_width) { + return graph.add(CxxFunction(ID($$state), CxxType::memory(addr_width, data_width), {{name, {}}}), 0); + } T cell_output(T cell, IdString type, IdString name, int width) { if (is_single_output(type)) return cell; @@ -245,12 +283,19 @@ public: return graph.add(CxxFunction(ID($$undriven), width), 0); } + T memory_read(T mem, T addr, int addr_width, int data_width) { + return graph.add(CxxFunction(ID($memory_read), data_width), 0, std::array{mem, addr}); + } + T memory_write(T mem, T addr, T data, int addr_width, int data_width) { + return graph.add(CxxFunction(ID($memory_write), CxxType::memory(addr_width, data_width)), 0, std::array{mem, addr, data}); + } + T create_pending(int width) { return graph.add(CxxFunction(ID($$pending), width), 0); } void update_pending(T pending, T node) { log_assert(pending.function().name == ID($$pending)); - pending.set_function(CxxFunction(ID($$buf), pending.function().width)); + pending.set_function(CxxFunction(ID($$buf), pending.function().type)); pending.append_arg(node); } void declare_output(T node, IdString name, int) { @@ -259,6 +304,9 @@ public: void declare_state(T node, IdString name, int) { node.assign_key(name); } + void declare_state_memory(T node, IdString name, int, int) { + node.assign_key(name); + } void suggest_name(T node, IdString name) { node.sparse_attr() = name; } @@ -312,8 +360,10 @@ struct FunctionalCxxBackend : public Backend { int target_index = alias[node.arg(0).index()]; auto target_node = compute_graph[perm[target_index]]; - if(!target_node.has_sparse_attr() && node.has_sparse_attr()) - target_node.sparse_attr() = node.sparse_attr(); + if(!target_node.has_sparse_attr() && node.has_sparse_attr()){ + IdString id = node.sparse_attr(); + target_node.sparse_attr() = id; + } alias.push_back(target_index); } else @@ -328,7 +378,7 @@ struct FunctionalCxxBackend : public Backend void printCxx(std::ostream &stream, std::string, std::string const & name, CxxComputeGraph &compute_graph) { - dict inputs, state; + dict inputs, state; CxxWriter f(stream); // Dump the compute graph @@ -336,22 +386,22 @@ struct FunctionalCxxBackend : public Backend { auto ref = compute_graph[i]; if(ref.function().name == ID($$input)) - inputs[ref.function().parameters.begin()->first] = ref.function().width; + inputs[ref.function().parameters.begin()->first] = ref.function().type; if(ref.function().name == ID($$state)) - state[ref.function().parameters.begin()->first] = ref.function().width; + state[ref.function().parameters.begin()->first] = ref.function().type; } f.printf("#include \"sim.h\"\n"); f.printf("#include \n"); CxxStruct input_struct(name + "_Inputs", true, inputs.size()); for (auto const &input : inputs) - input_struct.insert(input.first, "Signal<" + std::to_string(input.second) + ">"); + input_struct.insert(input.first, input.second); CxxStruct output_struct(name + "_Outputs"); for (auto const &key : compute_graph.keys()) if(state.count(key.first) == 0) - output_struct.insert(key.first, "Signal<" + std::to_string(compute_graph[key.second].function().width) + ">"); + output_struct.insert(key.first, compute_graph[key.second].function().type); CxxStruct state_struct(name + "_State"); for (auto const &state_var : state) - state_struct.insert(state_var.first, "Signal<" + std::to_string(state_var.second) + ">"); + state_struct.insert(state_var.first, state_var.second); idict node_names; CxxScope locals; @@ -368,7 +418,7 @@ struct FunctionalCxxBackend : public Backend for (int i = 0; i < compute_graph.size(); ++i) { auto ref = compute_graph[i]; - int width = ref.function().width; + auto type = ref.function().type; std::string name; if(ref.has_sparse_attr()) name = locals.insert(ref.sparse_attr()); @@ -376,19 +426,19 @@ struct FunctionalCxxBackend : public Backend name = locals.insert("\\n" + std::to_string(i)); node_names(name); if(ref.function().name == ID($$input)) - f.printf("\tSignal<%d> %s = input.%s;\n", width, name.c_str(), input_struct[ref.function().parameters.begin()->first].c_str()); + f.printf("\t%s %s = input.%s;\n", type.to_string().c_str(), name.c_str(), input_struct[ref.function().parameters.begin()->first].c_str()); else if(ref.function().name == ID($$state)) - f.printf("\tSignal<%d> %s = current_state.%s;\n", width, name.c_str(), state_struct[ref.function().parameters.begin()->first].c_str()); + f.printf("\t%s %s = current_state.%s;\n", type.to_string().c_str(), name.c_str(), state_struct[ref.function().parameters.begin()->first].c_str()); else if(ref.function().name == ID($$buf)) - f.printf("\tSignal<%d> %s = %s;\n", width, name.c_str(), node_names[ref.arg(0).index()].c_str()); + f.printf("\t%s %s = %s;\n", type.to_string().c_str(), name.c_str(), node_names[ref.arg(0).index()].c_str()); else if(ref.function().name == ID($$cell_output)) - f.printf("\tSignal<%d> %s = %s.%s;\n", width, name.c_str(), node_names[ref.arg(0).index()].c_str(), RTLIL::unescape_id(ref.function().parameters.begin()->first).c_str()); + f.printf("\t%s %s = %s.%s;\n", type.to_string().c_str(), name.c_str(), node_names[ref.arg(0).index()].c_str(), RTLIL::unescape_id(ref.function().parameters.begin()->first).c_str()); else if(ref.function().name == ID($$const)){ auto c = ref.function().parameters.begin()->second; if(c.size() <= 32){ - f.printf("\tSignal<%d> %s = $const<%d>(%#x);\n", width, name.c_str(), width, (uint32_t) c.as_int()); + f.printf("\t%s %s = $const<%d>(%#x);\n", type.to_string().c_str(), name.c_str(), type.width(), (uint32_t) c.as_int()); }else{ - f.printf("\tSignal<%d> %s = $const<%d>({%#x", width, name.c_str(), width, (uint32_t) c.as_int()); + f.printf("\t%s %s = $const<%d>({%#x", type.to_string().c_str(), name.c_str(), type.width(), (uint32_t) c.as_int()); while(c.size() > 32){ c = c.extract(32, c.size() - 32); f.printf(", %#x", c.as_int()); @@ -396,9 +446,9 @@ struct FunctionalCxxBackend : public Backend f.printf("});\n"); } }else if(ref.function().name == ID($$undriven)) - f.printf("\tSignal<%d> %s; //undriven\n", width, name.c_str()); + f.printf("\t%s %s; //undriven\n", type.to_string().c_str(), name.c_str()); else if(ref.function().name == ID($$slice)) - f.printf("\tSignal<%d> %s = slice<%d>(%s, %d);\n", width, name.c_str(), width, node_names[ref.arg(0).index()].c_str(), ref.function().parameters.at(ID(offset)).as_int()); + f.printf("\t%s %s = slice<%d>(%s, %d);\n", type.to_string().c_str(), name.c_str(), type.width(), node_names[ref.arg(0).index()].c_str(), ref.function().parameters.at(ID(offset)).as_int()); else if(ref.function().name == ID($$concat)){ f.printf("\tauto %s = concat(", name.c_str()); for (int i = 0, end = ref.size(); i != end; ++i){ @@ -409,11 +459,7 @@ struct FunctionalCxxBackend : public Backend f.printf(");\n"); }else{ f.printf("\t"); - if(ref.function().width > 0) - f.printf("Signal<%d>", ref.function().width); - else - f.printf("%s_Outputs", log_id(ref.function().name)); - f.printf(" %s = %s", name.c_str(), log_id(ref.function().name)); + f.printf("%s %s = %s", type.to_string().c_str(), name.c_str(), log_id(ref.function().name)); if(ref.function().parameters.count(ID(WIDTH))){ f.printf("<%d>", ref.function().parameters.at(ID(WIDTH)).as_int()); } diff --git a/backends/functional/cxx_runtime/sim.h b/backends/functional/cxx_runtime/sim.h index 75d045ccf..310927f5b 100644 --- a/backends/functional/cxx_runtime/sim.h +++ b/backends/functional/cxx_runtime/sim.h @@ -363,4 +363,23 @@ Signal $sign_extend(Signal const& a) return ret; } +template +struct Memory { + std::array, 1< contents; +}; + +template +Signal $memory_read(Memory memory, Signal addr) +{ + return memory.contents[as_int(addr)]; +} + +template +Memory $memory_write(Memory memory, Signal addr, Signal data) +{ + Memory ret = memory; + ret.contents[as_int(addr)] = data; + return ret; +} + #endif diff --git a/backends/functional/smtlib.cc b/backends/functional/smtlib.cc index 7a283cccf..a90927ab6 100644 --- a/backends/functional/smtlib.cc +++ b/backends/functional/smtlib.cc @@ -219,19 +219,54 @@ std::ostream& operator << (std::ostream &os, SExpr const &s) { return os; } +struct SmtlibType { + bool _is_memory; + int _width; + int _addr_width; +public: + SmtlibType() : _is_memory(false), _width(0), _addr_width(0) { } + SmtlibType(int width) : _is_memory(false), _width(width), _addr_width(0) { } + SmtlibType(int addr_width, int data_width) : _is_memory(true), _width(data_width), _addr_width(addr_width) { } + static SmtlibType signal(int width) { return SmtlibType(width); } + static SmtlibType memory(int addr_width, int data_width) { return SmtlibType(addr_width, data_width); } + bool is_signal() const { return !_is_memory; } + bool is_memory() const { return _is_memory; } + int width() const { log_assert(is_signal()); return _width; } + int addr_width() const { log_assert(is_memory()); return _addr_width; } + int data_width() const { log_assert(is_memory()); return _width; } + SExpr to_sexpr() const { + if(_is_memory) { + return SExpr{ "Array", SExpr{ "_", "BitVec", addr_width() }, SExpr{ "_", "BitVec", data_width() }}; + } else { + return SExpr{ "_", "BitVec", width() }; + } + } + bool operator ==(SmtlibType const& other) const { + if(_is_memory != other._is_memory) return false; + if(_is_memory && _addr_width != other._addr_width) return false; + return _width == other._width; + } + unsigned int hash() const { + if(_is_memory) + return mkhash(1, mkhash(_width, _addr_width)); + else + return mkhash(0, _width); + } +}; + struct SmtlibStruct { SmtlibScope &scope; std::string name; idict members; - vector widths; + vector types; vector accessors; SmtlibStruct(std::string name, SmtlibScope &scope) : scope(scope), name(name) { } - std::string insert(IdString field, int width) { + std::string insert(IdString field, SmtlibType type) { if(members.at(field, -1) == -1){ members(field); scope.insert(field); - widths.push_back(width); + types.push_back(type); accessors.push_back(scope.insert(std::string("\\") + name + "_" + RTLIL::unescape_id(field))); } return scope[field]; @@ -239,7 +274,7 @@ struct SmtlibStruct { void print(SmtlibWriter &f) { f.printf("(declare-datatype %s ((%s\n", name.c_str(), name.c_str()); for (size_t i = 0; i < members.size(); i++) - f.printf(" (%s (_ BitVec %d))\n", accessors[i].c_str(), widths[i]); + f << " " << SExpr{accessors[i], types[i].to_sexpr()} << "\n"; f.printf(")))\n"); } void print_value(SmtlibWriter &f, dict values, int indentation) { @@ -260,16 +295,16 @@ struct SmtlibStruct { struct Node { SExpr expr; - int width; + SmtlibType type; - Node(SExpr &&expr, int width) : expr(std::move(expr)), width(width) {} + Node(SExpr &&expr, SmtlibType type) : expr(std::move(expr)), type(type) {} bool operator==(Node const &other) const { - return expr == other.expr && width == other.width; + return expr == other.expr && type == other.type; } unsigned int hash() const { - return mkhash(expr.hash(), width); + return mkhash(expr.hash(), type.hash()); } }; @@ -302,8 +337,8 @@ class SmtlibComputeGraphFactory { auto it = yosys_celltypes.cell_types.find(type); return it != yosys_celltypes.cell_types.end() && it->second.outputs.size() <= 1; } - T node(SExpr &&expr, int width, std::initializer_list args) { - return graph.add(Node(std::move(expr), width), 0, args); + T node(SExpr &&expr, SmtlibType type, std::initializer_list args) { + return graph.add(Node(std::move(expr), type), 0, args); } T shift(const char *name, T a, T b, int y_width, int b_width, bool a_signed = false) { int width = max(y_width, b_width); @@ -367,6 +402,13 @@ public: T logical_shift_right(T a, T b, int y_width, int b_width) { return shift("bvlshl", a, b, y_width, b_width); } T arithmetic_shift_right(T a, T b, int y_width, int b_width) { return shift("bvashr", a, b, y_width, b_width, true); } + T memory_read(T mem, T addr, int addr_width, int data_width) { + return node(SExpr {"select", Arg(1), Arg(2)}, data_width, {mem, addr}); + } + T memory_write(T mem, T addr, T data, int addr_width, int data_width) { + return node(SExpr {"store", Arg(1), Arg(2), Arg(3)}, SmtlibType::memory(addr_width, data_width), {mem, addr, data}); + } + T constant(RTLIL::Const value) { return node(SExpr(value), value.size(), {}); } T input(IdString name, int width) { module.input_struct.insert(name, width); @@ -376,6 +418,10 @@ public: module.state_struct.insert(name, width); return node(module.state_struct.access("current_state", name), width, {}); } + T state_memory(IdString name, int addr_width, int data_width) { + module.state_struct.insert(name, SmtlibType::memory(addr_width, data_width)); + return node(module.state_struct.access("current_state", name), SmtlibType::memory(addr_width, data_width), {}); + } T cell_output(T cell, IdString type, IdString name, int width) { if (is_single_output(type)) return cell; @@ -399,7 +445,7 @@ public: } void update_pending(T pending, T node) { log_assert(pending.function().expr.is_none()); - pending.set_function(Node(Arg(1), pending.function().width)); + pending.set_function(Node(Arg(1), pending.function().type)); pending.append_arg(node); } void declare_output(T node, IdString name, int width) { @@ -410,6 +456,10 @@ public: module.state_struct.insert(name, width); node.assign_key(name); } + void declare_state_memory(T node, IdString name, int addr_width, int data_width) { + module.state_struct.insert(name, SmtlibType::memory(addr_width, data_width)); + node.assign_key(name); + } void suggest_name(T node, IdString name) { node.sparse_attr() = name; } diff --git a/kernel/drivertools.h b/kernel/drivertools.h index 1cb835df2..48c846b5f 100644 --- a/kernel/drivertools.h +++ b/kernel/drivertools.h @@ -1064,6 +1064,13 @@ public: append(bit); } + DriveSpec(SigSpec const &sig) + { + // TODO: converting one chunk at a time would be faster + for (auto const &bit : sig.bits()) + append(bit); + } + std::vector const &chunks() const { pack(); return chunks_; } std::vector const &bits() const { unpack(); return bits_; } diff --git a/kernel/graphtools.h b/kernel/graphtools.h index adf4764c2..98046c6be 100644 --- a/kernel/graphtools.h +++ b/kernel/graphtools.h @@ -23,6 +23,7 @@ #include "kernel/yosys.h" #include "kernel/drivertools.h" #include "kernel/functional.h" +#include "kernel/mem.h" USING_YOSYS_NAMESPACE YOSYS_NAMESPACE_BEGIN @@ -71,6 +72,11 @@ public: T reduced_b = reduce_shift_width(b, b_width, y_width, reduced_b_width); return factory.arithmetic_shift_right(a, reduced_b, y_width, reduced_b_width); } + T bitwise_mux(T a, T b, T s, int width) { + T aa = factory.bitwise_and(a, factory.bitwise_not(s, width), width); + T bb = factory.bitwise_and(b, s, width); + return factory.bitwise_or(aa, bb, width); + } CellSimplifier(Factory &f) : factory(f) {} T handle(IdString cellType, dict parameters, dict inputs) { @@ -104,7 +110,7 @@ public: T b = extend(inputs.at(ID(B)), b_width, width, is_signed); if(cellType.in({ID($eq), ID($eqx)})) return extend(factory.eq(a, b, width), 1, y_width, false); - if(cellType.in({ID($ne), ID($nex)})) + else if(cellType.in({ID($ne), ID($nex)})) return extend(factory.ne(a, b, width), 1, y_width, false); else if(cellType == ID($lt)) return extend(is_signed ? factory.gt(b, a, width) : factory.ugt(b, a, width), 1, y_width, false); @@ -197,6 +203,8 @@ class ComputeGraphConstruction { DriverMap driver_map; Factory& factory; CellSimplifier simplifier; + vector memories_vector; + dict memories; T enqueue(DriveSpec const &spec) { @@ -224,6 +232,45 @@ public: factory.declare_output(node, wire->name, wire->width); } } + memories_vector = Mem::get_all_memories(module); + for (auto &mem : memories_vector) { + if (mem.cell != nullptr) + memories[mem.cell] = &mem; + } + } + T concatenate_read_results(Mem *mem, vector results) + { + if(results.size() == 0) + return factory.undriven(0); + T node = results[0]; + int size = results[0].size(); + for(size_t i = 1; i < results.size(); i++) { + node = factory.concat(node, size, results[i], results[i].size()); + size += results[i].size(); + } + return node; + } + T handle_memory(Mem *mem) + { + vector read_results; + int addr_width = ceil_log2(mem->size); + int data_width = mem->width; + T node = factory.state_memory(mem->cell->name, addr_width, data_width); + for (auto &rd : mem->rd_ports) { + log_assert(!rd.clk_enable); + T addr = enqueue(driver_map(DriveSpec(rd.addr))); + read_results.push_back(factory.memory_read(node, addr, addr_width, data_width)); + } + for (auto &wr : mem->wr_ports) { + T en = enqueue(driver_map(DriveSpec(wr.en))); + T addr = enqueue(driver_map(DriveSpec(wr.addr))); + T new_data = enqueue(driver_map(DriveSpec(wr.data))); + T old_data = factory.memory_read(node, addr, addr_width, data_width); + T wr_data = simplifier.bitwise_mux(old_data, new_data, en, data_width); + node = factory.memory_write(node, addr, wr_data, addr_width, data_width); + } + factory.declare_state_memory(node, mem->cell->name, addr_width, data_width); + return concatenate_read_results(mem, read_results); } void process_queue() { @@ -306,13 +353,20 @@ public: factory.update_pending(pending, node); } else if (chunk.is_marker()) { Cell *cell = cells[chunk.marker().marker]; - dict connections; - for(auto const &conn : cell->connections()) { - if(driver_map.celltypes.cell_input(cell->type, conn.first)) - connections.insert({ conn.first, enqueue(DriveChunkPort(cell, conn)) }); + if (cell->is_mem_cell()) { + Mem *mem = memories.at(cell, nullptr); + log_assert(mem != nullptr); + T node = handle_memory(mem); + factory.update_pending(pending, node); + } else { + dict connections; + for(auto const &conn : cell->connections()) { + if(driver_map.celltypes.cell_input(cell->type, conn.first)) + connections.insert({ conn.first, enqueue(DriveChunkPort(cell, conn)) }); + } + T node = simplifier.handle(cell->type, cell->parameters, connections); + factory.update_pending(pending, node); } - T node = simplifier.handle(cell->type, cell->parameters, connections); - factory.update_pending(pending, node); } else if (chunk.is_none()) { T node = factory.undriven(chunk.size()); factory.update_pending(pending, node);