smtbmc: Add --track-assumes and --minimize-assumes options

The --track-assumes option makes smtbmc keep track of which assumptions
were used by the solver when reaching an unsat case and to output that
set of assumptions. This is particularly useful to debug PREUNSAT
failures.

The --minimize-assumes option can be used in addition to --track-assumes
which will cause smtbmc to spend additional solving effort to produce a
minimal set of assumptions that are sufficient to cause the unsat
result.
This commit is contained in:
Jannis Harder 2024-03-07 13:27:03 +01:00
parent e4f11eb0a0
commit 42122e240e
3 changed files with 219 additions and 24 deletions

View File

@ -57,6 +57,8 @@ keep_going = False
check_witness = False
detect_loops = False
incremental = None
track_assumes = False
minimize_assumes = False
so = SmtOpts()
@ -189,6 +191,15 @@ def help():
--incremental
run in incremental mode (experimental)
--track-assumes
track individual assumptions and report a subset of used
assumptions that are sufficient for the reported outcome. This
can be used to debug PREUNSAT failures as well as to find a
smaller set of sufficient assumptions.
--minimize-assumes
when using --track-assumes, solve for a minimal set of sufficient assumptions.
""" + so.helpmsg())
def usage():
@ -200,7 +211,8 @@ try:
opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts +
["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat",
"dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=",
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental"])
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental",
"track-assumes", "minimize-assumes"])
except:
usage()
@ -289,6 +301,10 @@ for o, a in opts:
elif o == "--incremental":
from smtbmc_incremental import Incremental
incremental = Incremental()
elif o == "--track-assumes":
track_assumes = True
elif o == "--minimize-assumes":
minimize_assumes = True
elif so.handle(o, a):
pass
else:
@ -447,6 +463,9 @@ def get_constr_expr(db, state, final=False, getvalues=False, individual=False):
smt = SmtIo(opts=so)
if track_assumes:
smt.smt2_options[':produce-unsat-assumptions'] = 'true'
if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None:
smt.produce_models = False
@ -1497,6 +1516,44 @@ def get_active_assert_map(step, active):
return assert_map
assume_enables = {}
def declare_assume_enables():
def recurse(mod, path, key_base=()):
for expr, desc in smt.modinfo[mod].assumes.items():
enable = f"|assume_enable {len(assume_enables)}|"
smt.smt2_assumptions[(expr, key_base)] = enable
smt.write(f"(declare-const {enable} Bool)")
assume_enables[(expr, key_base)] = (enable, path, desc)
for cell, submod in smt.modinfo[mod].cells.items():
recurse(submod, f"{path}.{cell}", (mod, cell, key_base))
recurse(topmod, topmod)
if track_assumes:
declare_assume_enables()
def smt_assert_design_assumes(step):
if not track_assumes:
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
return
if not assume_enables:
return
def expr_for_assume(assume_key, base=None):
expr, key_base = assume_key
expr_prefix = f"(|{expr}| "
expr_suffix = ")"
while key_base:
mod, cell, key_base = key_base
expr_prefix += f"(|{mod}_h {cell}| "
expr_suffix += ")"
return f"{expr_prefix} s{step}{expr_suffix}"
for assume_key, (enable, path, desc) in assume_enables.items():
smt_assert_consequent(f"(=> {enable} {expr_for_assume(assume_key)})")
states = list()
asserts_antecedent_cache = [list()]
@ -1651,6 +1708,13 @@ def smt_check_sat(expected=["sat", "unsat"]):
smt_forall_assert()
return smt.check_sat(expected=expected)
def report_tracked_assumptions(msg):
if track_assumes:
print_msg(msg)
for key in smt.get_unsat_assumptions(minimize=minimize_assumes):
enable, path, descr = assume_enables[key]
print_msg(f" In {path}: {descr}")
if incremental:
incremental.mainloop()
@ -1664,7 +1728,7 @@ elif tempind:
break
smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step))
@ -1707,6 +1771,7 @@ elif tempind:
else:
print_msg("Temporal induction successful.")
report_tracked_assumptions("Used assumptions:")
retstatus = "PASSED"
break
@ -1732,7 +1797,7 @@ elif covermode:
while step < num_steps:
smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step))
@ -1753,6 +1818,7 @@ elif covermode:
smt_assert("(distinct (covers_%d s%d) #b%s)" % (coveridx, step, "0" * len(cover_desc)))
if smt_check_sat() == "unsat":
report_tracked_assumptions("Used assumptions:")
smt_pop()
break
@ -1761,13 +1827,14 @@ elif covermode:
print_msg("Appending additional step %d." % i)
smt_state(i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, i))
smt_assert_design_assumes(i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
smt_assert_consequent(get_constr_expr(constr_assumes, i))
print_msg("Re-solving with appended steps..")
if smt_check_sat() == "unsat":
print("%s Cannot appended steps without violating assumptions!" % smt.timestamp())
report_tracked_assumptions("Conflicting assumptions:")
found_failed_assert = True
retstatus = "FAILED"
break
@ -1823,7 +1890,7 @@ else: # not tempind, covermode
retstatus = "PASSED"
while step < num_steps:
smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step))
@ -1853,7 +1920,7 @@ else: # not tempind, covermode
if step+i < num_steps:
smt_state(step+i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step+i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step+i))
smt_assert_design_assumes(step + i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step+i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, step+i-1, step+i))
smt_assert_consequent(get_constr_expr(constr_assumes, step+i))
@ -1867,7 +1934,8 @@ else: # not tempind, covermode
print_msg("Checking assumptions in steps %d to %d.." % (step, last_check_step))
if smt_check_sat() == "unsat":
print("%s Assumptions are unsatisfiable!" % smt.timestamp())
print_msg("Assumptions are unsatisfiable!")
report_tracked_assumptions("Conficting assumptions:")
retstatus = "PREUNSAT"
break
@ -1920,13 +1988,14 @@ else: # not tempind, covermode
print_msg("Appending additional step %d." % i)
smt_state(i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, i))
smt_assert_design_assumes(i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
smt_assert_consequent(get_constr_expr(constr_assumes, i))
print_msg("Re-solving with appended steps..")
if smt_check_sat() == "unsat":
print("%s Cannot append steps without violating assumptions!" % smt.timestamp())
print_msg("Cannot append steps without violating assumptions!")
report_tracked_assumptions("Conflicting assumptions:")
retstatus = "FAILED"
break
print_anyconsts(step)

View File

@ -15,6 +15,14 @@ class InteractiveError(Exception):
pass
def mkkey(data):
if isinstance(data, list):
return tuple(map(mkkey, data))
elif isinstance(data, dict):
raise InteractiveError(f"JSON objects found in assumption key: {data!r}")
return data
class Incremental:
def __init__(self):
self.traceidx = 0
@ -73,17 +81,17 @@ class Incremental:
if min_len is not None and arg_len < min_len:
if min_len == max_len:
raise (
raise InteractiveError(
f"{json.dumps(expr[0])} expression must have "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
else:
raise (
raise InteractiveError(
f"{json.dumps(expr[0])} expression must have at least "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
if max_len is not None and arg_len > max_len:
raise (
raise InteractiveError(
f"{json.dumps(expr[0])} expression can have at most "
f"{min_len} argument{'s' if max_len != 1 else ''}"
)
@ -96,14 +104,31 @@ class Incremental:
smt_out.append(f"s{step}")
return "module", smtbmc.topmod
def expr_mod_constraint(self, expr, smt_out):
self.expr_arg_len(expr, 1)
def expr_cell(self, expr, smt_out):
self.expr_arg_len(expr, 2)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None])
arg_sort = self.expr(expr[2], smt_out, required_sort=["module", None])
smt_out.append(")")
module = arg_sort[1]
cell = expr[1]
submod = smtbmc.smt.modinfo[module].cells.get(cell)
if submod is None:
raise InteractiveError(f"module {module!r} has no cell {cell!r}")
smt_out[position] = f"(|{module}_h {cell}| "
return ("module", submod)
def expr_mod_constraint(self, expr, smt_out):
suffix = expr[0][3:]
smt_out[position] = f"(|{module}{suffix}| "
self.expr_arg_len(expr, 1, 2 if suffix in ["_a", "_u", "_c"] else 1)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[-1], smt_out, required_sort=["module", None])
module = arg_sort[1]
if len(expr) == 3:
smt_out[position] = f"(|{module}{suffix} {expr[1]}| "
else:
smt_out[position] = f"(|{module}{suffix}| "
smt_out.append(")")
return "Bool"
@ -223,20 +248,19 @@ class Incremental:
subexpr = expr[2]
if not isinstance(label, str):
raise InteractiveError(f"expression label has to be a string")
raise InteractiveError("expression label has to be a string")
smt_out.append("(! ")
smt_out.appedd(label)
smt_out.append(" ")
sort = self.expr(subexpr, smt_out)
smt_out.append(" :named ")
smt_out.append(label)
smt_out.append(")")
return sort
expr_handlers = {
"step": expr_step,
"cell": expr_cell,
"mod_h": expr_mod_constraint,
"mod_is": expr_mod_constraint,
"mod_i": expr_mod_constraint,
@ -302,6 +326,30 @@ class Incremental:
assert_fn(self.expr_smt(cmd.get("expr"), "Bool"))
def cmd_assert_design_assumes(self, cmd):
step = self.arg_step(cmd)
smtbmc.smt_assert_design_assumes(step)
def cmd_get_design_assume(self, cmd):
key = mkkey(cmd.get("key"))
return smtbmc.assume_enables.get(key)
def cmd_update_assumptions(self, cmd):
expr = cmd.get("expr")
key = cmd.get("key")
key = mkkey(key)
result = smtbmc.smt.smt2_assumptions.pop(key, None)
if expr is not None:
expr = self.expr_smt(expr, "Bool")
smtbmc.smt.smt2_assumptions[key] = expr
return result
def cmd_get_unsat_assumptions(self, cmd):
return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get('minimize')))
def cmd_push(self, cmd):
smtbmc.smt_push()
@ -313,11 +361,14 @@ class Incremental:
def cmd_smtlib(self, cmd):
command = cmd.get("command")
response = cmd.get("response", False)
if not isinstance(command, str):
raise InteractiveError(
f"raw SMT-LIB command must be a string, found {json.dumps(command)}"
)
smtbmc.smt.write(command)
if response:
return smtbmc.smt.read()
def cmd_design_hierwitness(self, cmd=None):
allregs = (cmd is None) or bool(cmd.get("allreges", False))
@ -369,6 +420,21 @@ class Incremental:
return dict(last_step=last_step)
def cmd_modinfo(self, cmd):
fields = cmd.get("fields", [])
mod = cmd.get("mod")
if mod is None:
mod = smtbmc.topmod
modinfo = smtbmc.smt.modinfo.get(mod)
if modinfo is None:
return None
result = dict(name=mod)
for field in fields:
result[field] = getattr(modinfo, field, None)
return result
def cmd_ping(self, cmd):
return cmd
@ -377,6 +443,10 @@ class Incremental:
"assert": cmd_assert,
"assert_antecedent": cmd_assert,
"assert_consequent": cmd_assert,
"assert_design_assumes": cmd_assert_design_assumes,
"get_design_assume": cmd_get_design_assume,
"update_assumptions": cmd_update_assumptions,
"get_unsat_assumptions": cmd_get_unsat_assumptions,
"push": cmd_push,
"pop": cmd_pop,
"check": cmd_check,
@ -384,6 +454,7 @@ class Incremental:
"design_hierwitness": cmd_design_hierwitness,
"write_yw_trace": cmd_write_yw_trace,
"read_yw_trace": cmd_read_yw_trace,
"modinfo": cmd_modinfo,
"ping": cmd_ping,
}

View File

@ -114,6 +114,7 @@ class SmtModInfo:
self.clocks = dict()
self.cells = dict()
self.asserts = dict()
self.assumes = dict()
self.covers = dict()
self.maximize = set()
self.minimize = set()
@ -141,6 +142,7 @@ class SmtIo:
self.recheck = False
self.smt2cache = [list()]
self.smt2_options = dict()
self.smt2_assumptions = dict()
self.p = None
self.p_index = solvers_index
solvers_index += 1
@ -602,6 +604,12 @@ class SmtIo:
else:
self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3]
if fields[1] == "yosys-smt2-assume":
if len(fields) > 4:
self.modinfo[self.curmod].assumes["%s_u %s" % (self.curmod, fields[2])] = f'{fields[4]} ({fields[3]})'
else:
self.modinfo[self.curmod].assumes["%s_u %s" % (self.curmod, fields[2])] = fields[3]
if fields[1] == "yosys-smt2-maximize":
self.modinfo[self.curmod].maximize.add(fields[2])
@ -785,8 +793,13 @@ class SmtIo:
return stmt
def check_sat(self, expected=["sat", "unsat", "unknown", "timeout", "interrupted"]):
if self.smt2_assumptions:
assume_exprs = " ".join(self.smt2_assumptions.values())
check_stmt = f"(check-sat-assuming ({assume_exprs}))"
else:
check_stmt = "(check-sat)"
if self.debug_print:
print("> (check-sat)")
print(f"> {check_stmt}")
if self.debug_file and not self.nocomments:
print("; running check-sat..", file=self.debug_file)
self.debug_file.flush()
@ -800,7 +813,7 @@ class SmtIo:
for cache_stmt in cache_ctx:
self.p_write(cache_stmt + "\n", False)
self.p_write("(check-sat)\n", True)
self.p_write(f"{check_stmt}\n", True)
if self.timeinfo:
i = 0
@ -868,7 +881,7 @@ class SmtIo:
if self.debug_file:
print("(set-info :status %s)" % result, file=self.debug_file)
print("(check-sat)", file=self.debug_file)
print(check_stmt, file=self.debug_file)
self.debug_file.flush()
if result not in expected:
@ -945,6 +958,48 @@ class SmtIo:
def bv2int(self, v):
return int(self.bv2bin(v), 2)
def get_raw_unsat_assumptions(self):
self.write("(get-unsat-assumptions)")
exprs = set(self.unparse(part) for part in self.parse(self.read()))
unsat_assumptions = []
for key, value in self.smt2_assumptions.items():
# normalize expression
value = self.unparse(self.parse(value))
if value in exprs:
exprs.remove(value)
unsat_assumptions.append(key)
return unsat_assumptions
def get_unsat_assumptions(self, minimize=False):
if not minimize:
return self.get_raw_unsat_assumptions()
required_assumptions = {}
while True:
candidate_assumptions = {}
for key in self.get_raw_unsat_assumptions():
if key not in required_assumptions:
candidate_assumptions[key] = self.smt2_assumptions[key]
while candidate_assumptions:
candidate_key, candidate_assume = candidate_assumptions.popitem()
self.smt2_assumptions = {}
for key, assume in candidate_assumptions.items():
self.smt2_assumptions[key] = assume
for key, assume in required_assumptions.items():
self.smt2_assumptions[key] = assume
result = self.check_sat()
if result == 'unsat':
candidate_assumptions = None
else:
required_assumptions[candidate_key] = candidate_assume
if candidate_assumptions is not None:
return list(required_assumptions)
def get(self, expr):
self.write("(get-value (%s))" % (expr))
return self.parse(self.read())[0][1]