diff --git a/backends/functional/smtlib.cc b/backends/functional/smtlib_rosette.cc index 3eacf407c..c9e737d19 100644 --- a/backends/functional/smtlib.cc +++ b/backends/functional/smtlib_rosette.cc @@ -29,80 +29,86 @@ PRIVATE_NAMESPACE_BEGIN using SExprUtil::list; const char *reserved_keywords[] = { - // reserved keywords from the smtlib spec - "BINARY", "DECIMAL", "HEXADECIMAL", "NUMERAL", "STRING", "_", "!", "as", "let", "exists", "forall", "match", "par", - "assert", "check-sat", "check-sat-assuming", "declare-const", "declare-datatype", "declare-datatypes", - "declare-fun", "declare-sort", "define-fun", "define-fun-rec", "define-funs-rec", "define-sort", - "exit", "get-assertions", "symbol", "sort", "get-assignment", "get-info", "get-model", - "get-option", "get-proof", "get-unsat-assumptions", "get-unsat-core", "get-value", - "pop", "push", "reset", "reset-assertions", "set-info", "set-logic", "set-option", + // reserved keywords from the racket spec + "struct", "lambda", "values", "extract", "concat", "bv", "let", "define", "cons", "list", "read", "write", + "stream", "error", "raise", "exit", "for", "begin", "when", "unless", "module", "require", "provide", "apply", + "if", "cond", "even", "odd", "any", "and", "or", "match", "command-line", "ffi-lib", "thread", "kill", "sync", + "future", "touch", "subprocess", "make-custodian", "custodian-shutdown-all", "current-custodian", "make", "tcp", + "connect", "prepare", "malloc", "free", "_fun", "_cprocedure", "build", "path", "file", "peek", "bytes", + "flush", "with", "lexer", "parser", "syntax", "interface", "send", "make-object", "new", "instantiate", + "define-generics", "set", // reserved for our own purposes - "pair", "Pair", "first", "second", - "inputs", "state", + "inputs", "state", "name", nullptr }; -struct SmtScope : public Functional::Scope { - SmtScope() { +struct SmtrScope : public Functional::Scope { + SmtrScope() { for(const char **p = reserved_keywords; *p != nullptr; p++) reserve(*p); } bool is_character_legal(char c, int index) override { - return isascii(c) && (isalpha(c) || (isdigit(c) && index > 0) || strchr("~!@$%^&*_-+=<>.?/", c)); + return isascii(c) && (isalpha(c) || (isdigit(c) && index > 0) || strchr("@$%^&_+=.", c)); } }; -struct SmtSort { +struct SmtrSort { Functional::Sort sort; - SmtSort(Functional::Sort sort) : sort(sort) {} + SmtrSort(Functional::Sort sort) : sort(sort) {} SExpr to_sexpr() const { if(sort.is_memory()) { - return list("Array", list("_", "BitVec", sort.addr_width()), list("_", "BitVec", sort.data_width())); + return list("list", list("bitvector", sort.addr_width()), list("bitvector", sort.data_width())); } else if(sort.is_signal()) { - return list("_", "BitVec", sort.width()); + return list("bitvector", sort.width()); } else { log_error("unknown sort"); } } }; -class SmtStruct { +class SmtrStruct { struct Field { - SmtSort sort; + SmtrSort sort; std::string accessor; + std::string name; }; idict field_names; vector fields; - SmtScope &scope; + SmtrScope &global_scope; + SmtrScope local_scope; public: std::string name; - SmtStruct(std::string name, SmtScope &scope) : scope(scope), name(name) {} - void insert(IdString field_name, SmtSort sort) { + SmtrStruct(std::string name, SmtrScope &scope) : global_scope(scope), local_scope(), name(name) {} + void insert(IdString field_name, SmtrSort sort) { field_names(field_name); - auto accessor = scope.unique_name("\\" + name + "_" + RTLIL::unescape_id(field_name)); - fields.emplace_back(Field{sort, accessor}); + auto base_name = local_scope.unique_name(field_name); + auto accessor = name + "-" + base_name; + global_scope.reserve(accessor); + fields.emplace_back(Field{sort, accessor, base_name}); } void write_definition(SExprWriter &w) { - w.open(list("declare-datatype", name)); - w.open(list()); - w.open(list(name)); - for(const auto &field : fields) - w << list(field.accessor, field.sort.to_sexpr()); - w.close(3); + vector field_list; + for(const auto &field : fields) { + field_list.emplace_back(field.name); + } + w.push(); + w.open(list("struct", name, field_list, "#:transparent")); + if (field_names.size()) { + for (const auto &field : fields) { + auto bv_type = field.sort.to_sexpr(); + w.comment(field.name + " " + bv_type.to_string()); + } + } + w.pop(); } template void write_value(SExprWriter &w, Fn fn) { - if(field_names.empty()) { - // Zero-argument constructors in SMTLIB must not be called as functions. - w << name; - } else { - w.open(list(name)); - for(auto field_name : field_names) { - w << fn(field_name); - w.comment(RTLIL::unescape_id(field_name), true); - } - w.close(); + w.open(list(name)); + for(auto field_name : field_names) { + w << fn(field_name); + w.comment(RTLIL::unescape_id(field_name), true); } + w.close(); } SExpr access(SExpr record, IdString name) { size_t i = field_names.at(name); @@ -117,28 +123,28 @@ std::string smt_const(RTLIL::Const const &c) { return s; } -struct SmtPrintVisitor : public Functional::AbstractVisitor { +struct SmtrPrintVisitor : public Functional::AbstractVisitor { using Node = Functional::Node; std::function n; - SmtStruct &input_struct; - SmtStruct &state_struct; + SmtrStruct &input_struct; + SmtrStruct &state_struct; - SmtPrintVisitor(SmtStruct &input_struct, SmtStruct &state_struct) : input_struct(input_struct), state_struct(state_struct) {} + SmtrPrintVisitor(SmtrStruct &input_struct, SmtrStruct &state_struct) : input_struct(input_struct), state_struct(state_struct) {} SExpr from_bool(SExpr &&arg) { - return list("ite", std::move(arg), "#b1", "#b0"); + return list("bool->bitvector", std::move(arg)); } SExpr to_bool(SExpr &&arg) { - return list("=", std::move(arg), "#b1"); + return list("bitvector->bool", std::move(arg)); } - SExpr extract(SExpr &&arg, int offset, int out_width = 1) { - return list(list("_", "extract", offset + out_width - 1, offset), std::move(arg)); + SExpr to_list(SExpr &&arg) { + return list("bitvector->bits", std::move(arg)); } SExpr buf(Node, Node a) override { return n(a); } - SExpr slice(Node, Node a, int offset, int out_width) override { return extract(n(a), offset, out_width); } - SExpr zero_extend(Node, Node a, int out_width) override { return list(list("_", "zero_extend", out_width - a.width()), n(a)); } - SExpr sign_extend(Node, Node a, int out_width) override { return list(list("_", "sign_extend", out_width - a.width()), n(a)); } + SExpr slice(Node, Node a, int offset, int out_width) override { return list("extract", offset + out_width - 1, offset, n(a)); } + SExpr zero_extend(Node, Node a, int out_width) override { return list("zero-extend", n(a), list("bitvector", out_width)); } + SExpr sign_extend(Node, Node a, int out_width) override { return list("sign-extend", n(a), list("bitvector", out_width)); } SExpr concat(Node, Node a, Node b) override { return list("concat", n(b), n(a)); } SExpr add(Node, Node a, Node b) override { return list("bvadd", n(a), n(b)); } SExpr sub(Node, Node a, Node b) override { return list("bvsub", n(a), n(b)); } @@ -150,16 +156,11 @@ struct SmtPrintVisitor : public Functional::AbstractVisitor { SExpr bitwise_xor(Node, Node a, Node b) override { return list("bvxor", n(a), n(b)); } SExpr bitwise_not(Node, Node a) override { return list("bvnot", n(a)); } SExpr unary_minus(Node, Node a) override { return list("bvneg", n(a)); } - SExpr reduce_and(Node, Node a) override { return from_bool(list("=", n(a), smt_const(RTLIL::Const(State::S1, a.width())))); } - SExpr reduce_or(Node, Node a) override { return from_bool(list("distinct", n(a), smt_const(RTLIL::Const(State::S0, a.width())))); } - SExpr reduce_xor(Node, Node a) override { - vector s { "bvxor" }; - for(int i = 0; i < a.width(); i++) - s.push_back(extract(n(a), i)); - return s; - } - SExpr equal(Node, Node a, Node b) override { return from_bool(list("=", n(a), n(b))); } - SExpr not_equal(Node, Node a, Node b) override { return from_bool(list("distinct", n(a), n(b))); } + SExpr reduce_and(Node, Node a) override { return list("apply", "bvand", to_list(n(a))); } + SExpr reduce_or(Node, Node a) override { return list("apply", "bvor", to_list(n(a))); } + SExpr reduce_xor(Node, Node a) override { return list("apply", "bvxor", to_list(n(a))); } + SExpr equal(Node, Node a, Node b) override { return from_bool(list("bveq", n(a), n(b))); } + SExpr not_equal(Node, Node a, Node b) override { return from_bool(list("not", list("bveq", n(a), n(b)))); } SExpr signed_greater_than(Node, Node a, Node b) override { return from_bool(list("bvsgt", n(a), n(b))); } SExpr signed_greater_equal(Node, Node a, Node b) override { return from_bool(list("bvsge", n(a), n(b))); } SExpr unsigned_greater_than(Node, Node a, Node b) override { return from_bool(list("bvugt", n(a), n(b))); } @@ -167,32 +168,32 @@ struct SmtPrintVisitor : public Functional::AbstractVisitor { SExpr extend(SExpr &&a, int in_width, int out_width) { if(in_width < out_width) - return list(list("_", "zero_extend", out_width - in_width), std::move(a)); + return list("zero-extend", std::move(a), list("bitvector", out_width)); else return std::move(a); } SExpr logical_shift_left(Node, Node a, Node b) override { return list("bvshl", n(a), extend(n(b), b.width(), a.width())); } SExpr logical_shift_right(Node, Node a, Node b) override { return list("bvlshr", n(a), extend(n(b), b.width(), a.width())); } SExpr arithmetic_shift_right(Node, Node a, Node b) override { return list("bvashr", n(a), extend(n(b), b.width(), a.width())); } - SExpr mux(Node, Node a, Node b, Node s) override { return list("ite", to_bool(n(s)), n(b), n(a)); } - SExpr constant(Node, RTLIL::Const const &value) override { return smt_const(value); } - SExpr memory_read(Node, Node mem, Node addr) override { return list("select", n(mem), n(addr)); } - SExpr memory_write(Node, Node mem, Node addr, Node data) override { return list("store", n(mem), n(addr), n(data)); } + SExpr mux(Node, Node a, Node b, Node s) override { return list("if", to_bool(n(s)), n(b), n(a)); } + SExpr constant(Node, RTLIL::Const const& value) override { return list("bv", smt_const(value), value.size()); } + SExpr memory_read(Node, Node mem, Node addr) override { return list("list-ref-bv", n(mem), n(addr)); } + SExpr memory_write(Node, Node mem, Node addr, Node data) override { return list("list-set-bv", n(mem), n(addr), n(data)); } SExpr input(Node, IdString name, IdString kind) override { log_assert(kind == ID($input)); return input_struct.access("inputs", name); } SExpr state(Node, IdString name, IdString kind) override { log_assert(kind == ID($state)); return state_struct.access("state", name); } }; -struct SmtModule { +struct SmtrModule { Functional::IR ir; - SmtScope scope; + SmtrScope scope; std::string name; - SmtStruct input_struct; - SmtStruct output_struct; - SmtStruct state_struct; + SmtrStruct input_struct; + SmtrStruct output_struct; + SmtrStruct state_struct; - SmtModule(Module *module) + SmtrModule(Module *module) : ir(Functional::IR::from_module(module)) , scope() , name(scope.unique_name(module->name)) @@ -200,7 +201,7 @@ struct SmtModule { , output_struct(scope.unique_name(module->name.str() + "_Outputs"), scope) , state_struct(scope.unique_name(module->name.str() + "_State"), scope) { - scope.reserve(name + "-initial"); + scope.reserve(name + "_initial"); for (auto input : ir.inputs()) input_struct.insert(input->name, input->sort); for (auto output : ir.outputs()) @@ -212,14 +213,11 @@ struct SmtModule { void write_eval(SExprWriter &w) { w.push(); - w.open(list("define-fun", name, - list(list("inputs", input_struct.name), - list("state", state_struct.name)), - list("Pair", output_struct.name, state_struct.name))); + w.open(list("define", list(name, "inputs", "state"))); auto inlined = [&](Functional::Node n) { return n.fn() == Functional::Fn::constant; }; - SmtPrintVisitor visitor(input_struct, state_struct); + SmtrPrintVisitor visitor(input_struct, state_struct); auto node_to_sexpr = [&](Functional::Node n) -> SExpr { if(inlined(n)) return n.visit(visitor); @@ -230,9 +228,9 @@ struct SmtModule { for(auto n : ir) if(!inlined(n)) { w.open(list("let", list(list(node_to_sexpr(n), n.visit(visitor)))), false); - w.comment(SmtSort(n.sort()).to_sexpr().to_string(), true); + w.comment(SmtrSort(n.sort()).to_sexpr().to_string(), true); } - w.open(list("pair")); + w.open(list("cons")); output_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.output(name).value()); }); state_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.state(name).next_value()); }); w.pop(); @@ -240,19 +238,23 @@ struct SmtModule { void write_initial(SExprWriter &w) { - std::string initial = name + "-initial"; - w << list("declare-const", initial, state_struct.name); + w.push(); + auto initial = name + "_initial"; + w.open(list("define", initial)); + w.open(list(state_struct.name)); for (auto state : ir.states()) { - if(state->sort.is_signal()) - w << list("assert", list("=", state_struct.access(initial, state->name), smt_const(state->initial_value_signal()))); - else if(state->sort.is_memory()) { + if (state->sort.is_signal()) + w << list("bv", smt_const(state->initial_value_signal()), state->sort.width()); + else if (state->sort.is_memory()) { const auto &contents = state->initial_value_memory(); + w.open(list("list")); for(int i = 0; i < 1<sort.addr_width(); i++) { - auto addr = smt_const(RTLIL::Const(i, state->sort.addr_width())); - w << list("assert", list("=", list("select", state_struct.access(initial, state->name), addr), smt_const(contents[i]))); + w << list("bv", smt_const(contents[i]), state->sort.data_width()); } + w.close(); } } + w.pop(); } void write(std::ostream &out) @@ -263,33 +265,53 @@ struct SmtModule { output_struct.write_definition(w); state_struct.write_definition(w); - w << list("declare-datatypes", - list(list("Pair", 2)), - list(list("par", list("X", "Y"), list(list("pair", list("first", "X"), list("second", "Y")))))); - write_eval(w); write_initial(w); } }; -struct FunctionalSmtBackend : public Backend { - FunctionalSmtBackend() : Backend("functional_smt2", "Generate SMT-LIB from Functional IR") {} +struct FunctionalSmtrBackend : public Backend { + FunctionalSmtrBackend() : Backend("functional_rosette", "Generate Rosette compatible Racket from Functional IR") {} - void help() override { log("\nFunctional SMT Backend.\n\n"); } + void help() override { + // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| + log("\n"); + log(" write_functional_rosette [options] [filename]\n"); + log("\n"); + log("Functional Rosette Backend.\n"); + log("\n"); + log(" -provides\n"); + log(" include 'provide' statement(s) for loading output as a module\n"); + log("\n"); + } void execute(std::ostream *&f, std::string filename, std::vector args, RTLIL::Design *design) override { - log_header(design, "Executing Functional SMT Backend.\n"); + auto provides = false; + + log_header(design, "Executing Functional Rosette Backend.\n"); - size_t argidx = 1; - extra_args(f, filename, args, argidx, design); + size_t argidx; + for (argidx = 1; argidx < args.size(); argidx++) + { + if (args[argidx] == "-provides") + provides = true; + else + break; + } + extra_args(f, filename, args, argidx); + + *f << "#lang rosette/safe\n"; + if (provides) { + *f << "(provide (all-defined-out))\n"; + } for (auto module : design->selected_modules()) { log("Processing module `%s`.\n", module->name.c_str()); - SmtModule smt(module); - smt.write(*f); + SmtrModule smtr(module); + smtr.write(*f); } } -} FunctionalSmtBackend; +} FunctionalSmtrBackend; PRIVATE_NAMESPACE_END