diff --git a/backends/smt2/smtbmc.py b/backends/smt2/smtbmc.py index cc47bc376..e6b4088db 100644 --- a/backends/smt2/smtbmc.py +++ b/backends/smt2/smtbmc.py @@ -57,6 +57,8 @@ keep_going = False check_witness = False detect_loops = False incremental = None +track_assumes = False +minimize_assumes = False so = SmtOpts() @@ -189,6 +191,15 @@ def help(): --incremental run in incremental mode (experimental) + --track-assumes + track individual assumptions and report a subset of used + assumptions that are sufficient for the reported outcome. This + can be used to debug PREUNSAT failures as well as to find a + smaller set of sufficient assumptions. + + --minimize-assumes + when using --track-assumes, solve for a minimal set of sufficient assumptions. + """ + so.helpmsg()) def usage(): @@ -200,7 +211,8 @@ try: opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts + ["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat", "dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=", - "smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental"]) + "smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental", + "track-assumes", "minimize-assumes"]) except: usage() @@ -289,6 +301,10 @@ for o, a in opts: elif o == "--incremental": from smtbmc_incremental import Incremental incremental = Incremental() + elif o == "--track-assumes": + track_assumes = True + elif o == "--minimize-assumes": + minimize_assumes = True elif so.handle(o, a): pass else: @@ -447,6 +463,9 @@ def get_constr_expr(db, state, final=False, getvalues=False, individual=False): smt = SmtIo(opts=so) +if track_assumes: + smt.smt2_options[':produce-unsat-assumptions'] = 'true' + if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None: smt.produce_models = False @@ -1497,6 +1516,44 @@ def get_active_assert_map(step, active): return assert_map +assume_enables = {} + +def declare_assume_enables(): + def recurse(mod, path, key_base=()): + for expr, desc in smt.modinfo[mod].assumes.items(): + enable = f"|assume_enable {len(assume_enables)}|" + smt.smt2_assumptions[(expr, key_base)] = enable + smt.write(f"(declare-const {enable} Bool)") + assume_enables[(expr, key_base)] = (enable, path, desc) + + for cell, submod in smt.modinfo[mod].cells.items(): + recurse(submod, f"{path}.{cell}", (mod, cell, key_base)) + + recurse(topmod, topmod) + +if track_assumes: + declare_assume_enables() + +def smt_assert_design_assumes(step): + if not track_assumes: + smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) + return + + if not assume_enables: + return + + def expr_for_assume(assume_key, base=None): + expr, key_base = assume_key + expr_prefix = f"(|{expr}| " + expr_suffix = ")" + while key_base: + mod, cell, key_base = key_base + expr_prefix += f"(|{mod}_h {cell}| " + expr_suffix += ")" + return f"{expr_prefix} s{step}{expr_suffix}" + + for assume_key, (enable, path, desc) in assume_enables.items(): + smt_assert_consequent(f"(=> {enable} {expr_for_assume(assume_key)})") states = list() asserts_antecedent_cache = [list()] @@ -1651,6 +1708,13 @@ def smt_check_sat(expected=["sat", "unsat"]): smt_forall_assert() return smt.check_sat(expected=expected) +def report_tracked_assumptions(msg): + if track_assumes: + print_msg(msg) + for key in smt.get_unsat_assumptions(minimize=minimize_assumes): + enable, path, descr = assume_enables[key] + print_msg(f" In {path}: {descr}") + if incremental: incremental.mainloop() @@ -1664,7 +1728,7 @@ elif tempind: break smt_state(step) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) + smt_assert_design_assumes(step) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step)) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step)) smt_assert_consequent(get_constr_expr(constr_assumes, step)) @@ -1707,6 +1771,7 @@ elif tempind: else: print_msg("Temporal induction successful.") + report_tracked_assumptions("Used assumptions:") retstatus = "PASSED" break @@ -1732,7 +1797,7 @@ elif covermode: while step < num_steps: smt_state(step) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) + smt_assert_design_assumes(step) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step)) smt_assert_consequent(get_constr_expr(constr_assumes, step)) @@ -1753,6 +1818,7 @@ elif covermode: smt_assert("(distinct (covers_%d s%d) #b%s)" % (coveridx, step, "0" * len(cover_desc))) if smt_check_sat() == "unsat": + report_tracked_assumptions("Used assumptions:") smt_pop() break @@ -1761,13 +1827,14 @@ elif covermode: print_msg("Appending additional step %d." % i) smt_state(i) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i)) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, i)) + smt_assert_design_assumes(i) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i)) smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i)) smt_assert_consequent(get_constr_expr(constr_assumes, i)) print_msg("Re-solving with appended steps..") if smt_check_sat() == "unsat": print("%s Cannot appended steps without violating assumptions!" % smt.timestamp()) + report_tracked_assumptions("Conflicting assumptions:") found_failed_assert = True retstatus = "FAILED" break @@ -1823,7 +1890,7 @@ else: # not tempind, covermode retstatus = "PASSED" while step < num_steps: smt_state(step) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) + smt_assert_design_assumes(step) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step)) smt_assert_consequent(get_constr_expr(constr_assumes, step)) @@ -1853,7 +1920,7 @@ else: # not tempind, covermode if step+i < num_steps: smt_state(step+i) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step+i)) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, step+i)) + smt_assert_design_assumes(step + i) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step+i)) smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, step+i-1, step+i)) smt_assert_consequent(get_constr_expr(constr_assumes, step+i)) @@ -1867,7 +1934,8 @@ else: # not tempind, covermode print_msg("Checking assumptions in steps %d to %d.." % (step, last_check_step)) if smt_check_sat() == "unsat": - print("%s Assumptions are unsatisfiable!" % smt.timestamp()) + print_msg("Assumptions are unsatisfiable!") + report_tracked_assumptions("Conficting assumptions:") retstatus = "PREUNSAT" break @@ -1920,13 +1988,14 @@ else: # not tempind, covermode print_msg("Appending additional step %d." % i) smt_state(i) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i)) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, i)) + smt_assert_design_assumes(i) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i)) smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i)) smt_assert_consequent(get_constr_expr(constr_assumes, i)) print_msg("Re-solving with appended steps..") if smt_check_sat() == "unsat": - print("%s Cannot append steps without violating assumptions!" % smt.timestamp()) + print_msg("Cannot append steps without violating assumptions!") + report_tracked_assumptions("Conflicting assumptions:") retstatus = "FAILED" break print_anyconsts(step) diff --git a/backends/smt2/smtbmc_incremental.py b/backends/smt2/smtbmc_incremental.py index 1a2a45703..f43e878f3 100644 --- a/backends/smt2/smtbmc_incremental.py +++ b/backends/smt2/smtbmc_incremental.py @@ -15,6 +15,14 @@ class InteractiveError(Exception): pass +def mkkey(data): + if isinstance(data, list): + return tuple(map(mkkey, data)) + elif isinstance(data, dict): + raise InteractiveError(f"JSON objects found in assumption key: {data!r}") + return data + + class Incremental: def __init__(self): self.traceidx = 0 @@ -73,17 +81,17 @@ class Incremental: if min_len is not None and arg_len < min_len: if min_len == max_len: - raise ( + raise InteractiveError( f"{json.dumps(expr[0])} expression must have " f"{min_len} argument{'s' if min_len != 1 else ''}" ) else: - raise ( + raise InteractiveError( f"{json.dumps(expr[0])} expression must have at least " f"{min_len} argument{'s' if min_len != 1 else ''}" ) if max_len is not None and arg_len > max_len: - raise ( + raise InteractiveError( f"{json.dumps(expr[0])} expression can have at most " f"{min_len} argument{'s' if max_len != 1 else ''}" ) @@ -96,14 +104,31 @@ class Incremental: smt_out.append(f"s{step}") return "module", smtbmc.topmod - def expr_mod_constraint(self, expr, smt_out): - self.expr_arg_len(expr, 1) + def expr_cell(self, expr, smt_out): + self.expr_arg_len(expr, 2) position = len(smt_out) smt_out.append(None) - arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None]) + arg_sort = self.expr(expr[2], smt_out, required_sort=["module", None]) + smt_out.append(")") module = arg_sort[1] + cell = expr[1] + submod = smtbmc.smt.modinfo[module].cells.get(cell) + if submod is None: + raise InteractiveError(f"module {module!r} has no cell {cell!r}") + smt_out[position] = f"(|{module}_h {cell}| " + return ("module", submod) + + def expr_mod_constraint(self, expr, smt_out): suffix = expr[0][3:] - smt_out[position] = f"(|{module}{suffix}| " + self.expr_arg_len(expr, 1, 2 if suffix in ["_a", "_u", "_c"] else 1) + position = len(smt_out) + smt_out.append(None) + arg_sort = self.expr(expr[-1], smt_out, required_sort=["module", None]) + module = arg_sort[1] + if len(expr) == 3: + smt_out[position] = f"(|{module}{suffix} {expr[1]}| " + else: + smt_out[position] = f"(|{module}{suffix}| " smt_out.append(")") return "Bool" @@ -223,20 +248,19 @@ class Incremental: subexpr = expr[2] if not isinstance(label, str): - raise InteractiveError(f"expression label has to be a string") + raise InteractiveError("expression label has to be a string") smt_out.append("(! ") - smt_out.appedd(label) - smt_out.append(" ") - sort = self.expr(subexpr, smt_out) - + smt_out.append(" :named ") + smt_out.append(label) smt_out.append(")") return sort expr_handlers = { "step": expr_step, + "cell": expr_cell, "mod_h": expr_mod_constraint, "mod_is": expr_mod_constraint, "mod_i": expr_mod_constraint, @@ -302,6 +326,30 @@ class Incremental: assert_fn(self.expr_smt(cmd.get("expr"), "Bool")) + def cmd_assert_design_assumes(self, cmd): + step = self.arg_step(cmd) + smtbmc.smt_assert_design_assumes(step) + + def cmd_get_design_assume(self, cmd): + key = mkkey(cmd.get("key")) + return smtbmc.assume_enables.get(key) + + def cmd_update_assumptions(self, cmd): + expr = cmd.get("expr") + key = cmd.get("key") + + + key = mkkey(key) + + result = smtbmc.smt.smt2_assumptions.pop(key, None) + if expr is not None: + expr = self.expr_smt(expr, "Bool") + smtbmc.smt.smt2_assumptions[key] = expr + return result + + def cmd_get_unsat_assumptions(self, cmd): + return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get('minimize'))) + def cmd_push(self, cmd): smtbmc.smt_push() @@ -313,11 +361,14 @@ class Incremental: def cmd_smtlib(self, cmd): command = cmd.get("command") + response = cmd.get("response", False) if not isinstance(command, str): raise InteractiveError( f"raw SMT-LIB command must be a string, found {json.dumps(command)}" ) smtbmc.smt.write(command) + if response: + return smtbmc.smt.read() def cmd_design_hierwitness(self, cmd=None): allregs = (cmd is None) or bool(cmd.get("allreges", False)) @@ -369,6 +420,21 @@ class Incremental: return dict(last_step=last_step) + def cmd_modinfo(self, cmd): + fields = cmd.get("fields", []) + + mod = cmd.get("mod") + if mod is None: + mod = smtbmc.topmod + modinfo = smtbmc.smt.modinfo.get(mod) + if modinfo is None: + return None + + result = dict(name=mod) + for field in fields: + result[field] = getattr(modinfo, field, None) + return result + def cmd_ping(self, cmd): return cmd @@ -377,6 +443,10 @@ class Incremental: "assert": cmd_assert, "assert_antecedent": cmd_assert, "assert_consequent": cmd_assert, + "assert_design_assumes": cmd_assert_design_assumes, + "get_design_assume": cmd_get_design_assume, + "update_assumptions": cmd_update_assumptions, + "get_unsat_assumptions": cmd_get_unsat_assumptions, "push": cmd_push, "pop": cmd_pop, "check": cmd_check, @@ -384,6 +454,7 @@ class Incremental: "design_hierwitness": cmd_design_hierwitness, "write_yw_trace": cmd_write_yw_trace, "read_yw_trace": cmd_read_yw_trace, + "modinfo": cmd_modinfo, "ping": cmd_ping, } diff --git a/backends/smt2/smtio.py b/backends/smt2/smtio.py index c904aea95..e32f43c60 100644 --- a/backends/smt2/smtio.py +++ b/backends/smt2/smtio.py @@ -114,6 +114,7 @@ class SmtModInfo: self.clocks = dict() self.cells = dict() self.asserts = dict() + self.assumes = dict() self.covers = dict() self.maximize = set() self.minimize = set() @@ -141,6 +142,7 @@ class SmtIo: self.recheck = False self.smt2cache = [list()] self.smt2_options = dict() + self.smt2_assumptions = dict() self.p = None self.p_index = solvers_index solvers_index += 1 @@ -602,6 +604,12 @@ class SmtIo: else: self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3] + if fields[1] == "yosys-smt2-assume": + if len(fields) > 4: + self.modinfo[self.curmod].assumes["%s_u %s" % (self.curmod, fields[2])] = f'{fields[4]} ({fields[3]})' + else: + self.modinfo[self.curmod].assumes["%s_u %s" % (self.curmod, fields[2])] = fields[3] + if fields[1] == "yosys-smt2-maximize": self.modinfo[self.curmod].maximize.add(fields[2]) @@ -785,8 +793,13 @@ class SmtIo: return stmt def check_sat(self, expected=["sat", "unsat", "unknown", "timeout", "interrupted"]): + if self.smt2_assumptions: + assume_exprs = " ".join(self.smt2_assumptions.values()) + check_stmt = f"(check-sat-assuming ({assume_exprs}))" + else: + check_stmt = "(check-sat)" if self.debug_print: - print("> (check-sat)") + print(f"> {check_stmt}") if self.debug_file and not self.nocomments: print("; running check-sat..", file=self.debug_file) self.debug_file.flush() @@ -800,7 +813,7 @@ class SmtIo: for cache_stmt in cache_ctx: self.p_write(cache_stmt + "\n", False) - self.p_write("(check-sat)\n", True) + self.p_write(f"{check_stmt}\n", True) if self.timeinfo: i = 0 @@ -868,7 +881,7 @@ class SmtIo: if self.debug_file: print("(set-info :status %s)" % result, file=self.debug_file) - print("(check-sat)", file=self.debug_file) + print(check_stmt, file=self.debug_file) self.debug_file.flush() if result not in expected: @@ -945,6 +958,48 @@ class SmtIo: def bv2int(self, v): return int(self.bv2bin(v), 2) + def get_raw_unsat_assumptions(self): + self.write("(get-unsat-assumptions)") + exprs = set(self.unparse(part) for part in self.parse(self.read())) + unsat_assumptions = [] + for key, value in self.smt2_assumptions.items(): + # normalize expression + value = self.unparse(self.parse(value)) + if value in exprs: + exprs.remove(value) + unsat_assumptions.append(key) + return unsat_assumptions + + def get_unsat_assumptions(self, minimize=False): + if not minimize: + return self.get_raw_unsat_assumptions() + required_assumptions = {} + + while True: + candidate_assumptions = {} + for key in self.get_raw_unsat_assumptions(): + if key not in required_assumptions: + candidate_assumptions[key] = self.smt2_assumptions[key] + + while candidate_assumptions: + + candidate_key, candidate_assume = candidate_assumptions.popitem() + + self.smt2_assumptions = {} + for key, assume in candidate_assumptions.items(): + self.smt2_assumptions[key] = assume + for key, assume in required_assumptions.items(): + self.smt2_assumptions[key] = assume + result = self.check_sat() + + if result == 'unsat': + candidate_assumptions = None + else: + required_assumptions[candidate_key] = candidate_assume + + if candidate_assumptions is not None: + return list(required_assumptions) + def get(self, expr): self.write("(get-value (%s))" % (expr)) return self.parse(self.read())[0][1]