Merge pull request #4268 from jix/smtbmc-track-assumes

smtbmc: Add --track-assumes and --minimize-assumes options
This commit is contained in:
N. Engelhardt 2024-03-11 16:34:30 +01:00 committed by GitHub
commit 0909c2ef5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 219 additions and 24 deletions

View File

@ -57,6 +57,8 @@ keep_going = False
check_witness = False check_witness = False
detect_loops = False detect_loops = False
incremental = None incremental = None
track_assumes = False
minimize_assumes = False
so = SmtOpts() so = SmtOpts()
@ -189,6 +191,15 @@ def help():
--incremental --incremental
run in incremental mode (experimental) run in incremental mode (experimental)
--track-assumes
track individual assumptions and report a subset of used
assumptions that are sufficient for the reported outcome. This
can be used to debug PREUNSAT failures as well as to find a
smaller set of sufficient assumptions.
--minimize-assumes
when using --track-assumes, solve for a minimal set of sufficient assumptions.
""" + so.helpmsg()) """ + so.helpmsg())
def usage(): def usage():
@ -200,7 +211,8 @@ try:
opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts + opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts +
["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat", ["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat",
"dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=", "dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=",
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental"]) "smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental",
"track-assumes", "minimize-assumes"])
except: except:
usage() usage()
@ -289,6 +301,10 @@ for o, a in opts:
elif o == "--incremental": elif o == "--incremental":
from smtbmc_incremental import Incremental from smtbmc_incremental import Incremental
incremental = Incremental() incremental = Incremental()
elif o == "--track-assumes":
track_assumes = True
elif o == "--minimize-assumes":
minimize_assumes = True
elif so.handle(o, a): elif so.handle(o, a):
pass pass
else: else:
@ -447,6 +463,9 @@ def get_constr_expr(db, state, final=False, getvalues=False, individual=False):
smt = SmtIo(opts=so) smt = SmtIo(opts=so)
if track_assumes:
smt.smt2_options[':produce-unsat-assumptions'] = 'true'
if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None: if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None:
smt.produce_models = False smt.produce_models = False
@ -1497,6 +1516,44 @@ def get_active_assert_map(step, active):
return assert_map return assert_map
assume_enables = {}
def declare_assume_enables():
def recurse(mod, path, key_base=()):
for expr, desc in smt.modinfo[mod].assumes.items():
enable = f"|assume_enable {len(assume_enables)}|"
smt.smt2_assumptions[(expr, key_base)] = enable
smt.write(f"(declare-const {enable} Bool)")
assume_enables[(expr, key_base)] = (enable, path, desc)
for cell, submod in smt.modinfo[mod].cells.items():
recurse(submod, f"{path}.{cell}", (mod, cell, key_base))
recurse(topmod, topmod)
if track_assumes:
declare_assume_enables()
def smt_assert_design_assumes(step):
if not track_assumes:
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
return
if not assume_enables:
return
def expr_for_assume(assume_key, base=None):
expr, key_base = assume_key
expr_prefix = f"(|{expr}| "
expr_suffix = ")"
while key_base:
mod, cell, key_base = key_base
expr_prefix += f"(|{mod}_h {cell}| "
expr_suffix += ")"
return f"{expr_prefix} s{step}{expr_suffix}"
for assume_key, (enable, path, desc) in assume_enables.items():
smt_assert_consequent(f"(=> {enable} {expr_for_assume(assume_key)})")
states = list() states = list()
asserts_antecedent_cache = [list()] asserts_antecedent_cache = [list()]
@ -1651,6 +1708,13 @@ def smt_check_sat(expected=["sat", "unsat"]):
smt_forall_assert() smt_forall_assert()
return smt.check_sat(expected=expected) return smt.check_sat(expected=expected)
def report_tracked_assumptions(msg):
if track_assumes:
print_msg(msg)
for key in smt.get_unsat_assumptions(minimize=minimize_assumes):
enable, path, descr = assume_enables[key]
print_msg(f" In {path}: {descr}")
if incremental: if incremental:
incremental.mainloop() incremental.mainloop()
@ -1664,7 +1728,7 @@ elif tempind:
break break
smt_state(step) smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step)) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step)) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step)) smt_assert_consequent(get_constr_expr(constr_assumes, step))
@ -1707,6 +1771,7 @@ elif tempind:
else: else:
print_msg("Temporal induction successful.") print_msg("Temporal induction successful.")
report_tracked_assumptions("Used assumptions:")
retstatus = "PASSED" retstatus = "PASSED"
break break
@ -1732,7 +1797,7 @@ elif covermode:
while step < num_steps: while step < num_steps:
smt_state(step) smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step)) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step)) smt_assert_consequent(get_constr_expr(constr_assumes, step))
@ -1753,6 +1818,7 @@ elif covermode:
smt_assert("(distinct (covers_%d s%d) #b%s)" % (coveridx, step, "0" * len(cover_desc))) smt_assert("(distinct (covers_%d s%d) #b%s)" % (coveridx, step, "0" * len(cover_desc)))
if smt_check_sat() == "unsat": if smt_check_sat() == "unsat":
report_tracked_assumptions("Used assumptions:")
smt_pop() smt_pop()
break break
@ -1761,13 +1827,14 @@ elif covermode:
print_msg("Appending additional step %d." % i) print_msg("Appending additional step %d." % i)
smt_state(i) smt_state(i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i)) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, i)) smt_assert_design_assumes(i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i)) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i)) smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
smt_assert_consequent(get_constr_expr(constr_assumes, i)) smt_assert_consequent(get_constr_expr(constr_assumes, i))
print_msg("Re-solving with appended steps..") print_msg("Re-solving with appended steps..")
if smt_check_sat() == "unsat": if smt_check_sat() == "unsat":
print("%s Cannot appended steps without violating assumptions!" % smt.timestamp()) print("%s Cannot appended steps without violating assumptions!" % smt.timestamp())
report_tracked_assumptions("Conflicting assumptions:")
found_failed_assert = True found_failed_assert = True
retstatus = "FAILED" retstatus = "FAILED"
break break
@ -1823,7 +1890,7 @@ else: # not tempind, covermode
retstatus = "PASSED" retstatus = "PASSED"
while step < num_steps: while step < num_steps:
smt_state(step) smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step)) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step)) smt_assert_consequent(get_constr_expr(constr_assumes, step))
@ -1853,7 +1920,7 @@ else: # not tempind, covermode
if step+i < num_steps: if step+i < num_steps:
smt_state(step+i) smt_state(step+i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step+i)) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step+i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step+i)) smt_assert_design_assumes(step + i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step+i)) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step+i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, step+i-1, step+i)) smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, step+i-1, step+i))
smt_assert_consequent(get_constr_expr(constr_assumes, step+i)) smt_assert_consequent(get_constr_expr(constr_assumes, step+i))
@ -1867,7 +1934,8 @@ else: # not tempind, covermode
print_msg("Checking assumptions in steps %d to %d.." % (step, last_check_step)) print_msg("Checking assumptions in steps %d to %d.." % (step, last_check_step))
if smt_check_sat() == "unsat": if smt_check_sat() == "unsat":
print("%s Assumptions are unsatisfiable!" % smt.timestamp()) print_msg("Assumptions are unsatisfiable!")
report_tracked_assumptions("Conficting assumptions:")
retstatus = "PREUNSAT" retstatus = "PREUNSAT"
break break
@ -1920,13 +1988,14 @@ else: # not tempind, covermode
print_msg("Appending additional step %d." % i) print_msg("Appending additional step %d." % i)
smt_state(i) smt_state(i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i)) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, i)) smt_assert_design_assumes(i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i)) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i)) smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
smt_assert_consequent(get_constr_expr(constr_assumes, i)) smt_assert_consequent(get_constr_expr(constr_assumes, i))
print_msg("Re-solving with appended steps..") print_msg("Re-solving with appended steps..")
if smt_check_sat() == "unsat": if smt_check_sat() == "unsat":
print("%s Cannot append steps without violating assumptions!" % smt.timestamp()) print_msg("Cannot append steps without violating assumptions!")
report_tracked_assumptions("Conflicting assumptions:")
retstatus = "FAILED" retstatus = "FAILED"
break break
print_anyconsts(step) print_anyconsts(step)

View File

@ -15,6 +15,14 @@ class InteractiveError(Exception):
pass pass
def mkkey(data):
if isinstance(data, list):
return tuple(map(mkkey, data))
elif isinstance(data, dict):
raise InteractiveError(f"JSON objects found in assumption key: {data!r}")
return data
class Incremental: class Incremental:
def __init__(self): def __init__(self):
self.traceidx = 0 self.traceidx = 0
@ -73,17 +81,17 @@ class Incremental:
if min_len is not None and arg_len < min_len: if min_len is not None and arg_len < min_len:
if min_len == max_len: if min_len == max_len:
raise ( raise InteractiveError(
f"{json.dumps(expr[0])} expression must have " f"{json.dumps(expr[0])} expression must have "
f"{min_len} argument{'s' if min_len != 1 else ''}" f"{min_len} argument{'s' if min_len != 1 else ''}"
) )
else: else:
raise ( raise InteractiveError(
f"{json.dumps(expr[0])} expression must have at least " f"{json.dumps(expr[0])} expression must have at least "
f"{min_len} argument{'s' if min_len != 1 else ''}" f"{min_len} argument{'s' if min_len != 1 else ''}"
) )
if max_len is not None and arg_len > max_len: if max_len is not None and arg_len > max_len:
raise ( raise InteractiveError(
f"{json.dumps(expr[0])} expression can have at most " f"{json.dumps(expr[0])} expression can have at most "
f"{min_len} argument{'s' if max_len != 1 else ''}" f"{min_len} argument{'s' if max_len != 1 else ''}"
) )
@ -96,14 +104,31 @@ class Incremental:
smt_out.append(f"s{step}") smt_out.append(f"s{step}")
return "module", smtbmc.topmod return "module", smtbmc.topmod
def expr_mod_constraint(self, expr, smt_out): def expr_cell(self, expr, smt_out):
self.expr_arg_len(expr, 1) self.expr_arg_len(expr, 2)
position = len(smt_out) position = len(smt_out)
smt_out.append(None) smt_out.append(None)
arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None]) arg_sort = self.expr(expr[2], smt_out, required_sort=["module", None])
smt_out.append(")")
module = arg_sort[1] module = arg_sort[1]
cell = expr[1]
submod = smtbmc.smt.modinfo[module].cells.get(cell)
if submod is None:
raise InteractiveError(f"module {module!r} has no cell {cell!r}")
smt_out[position] = f"(|{module}_h {cell}| "
return ("module", submod)
def expr_mod_constraint(self, expr, smt_out):
suffix = expr[0][3:] suffix = expr[0][3:]
smt_out[position] = f"(|{module}{suffix}| " self.expr_arg_len(expr, 1, 2 if suffix in ["_a", "_u", "_c"] else 1)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[-1], smt_out, required_sort=["module", None])
module = arg_sort[1]
if len(expr) == 3:
smt_out[position] = f"(|{module}{suffix} {expr[1]}| "
else:
smt_out[position] = f"(|{module}{suffix}| "
smt_out.append(")") smt_out.append(")")
return "Bool" return "Bool"
@ -223,20 +248,19 @@ class Incremental:
subexpr = expr[2] subexpr = expr[2]
if not isinstance(label, str): if not isinstance(label, str):
raise InteractiveError(f"expression label has to be a string") raise InteractiveError("expression label has to be a string")
smt_out.append("(! ") smt_out.append("(! ")
smt_out.appedd(label)
smt_out.append(" ")
sort = self.expr(subexpr, smt_out) sort = self.expr(subexpr, smt_out)
smt_out.append(" :named ")
smt_out.append(label)
smt_out.append(")") smt_out.append(")")
return sort return sort
expr_handlers = { expr_handlers = {
"step": expr_step, "step": expr_step,
"cell": expr_cell,
"mod_h": expr_mod_constraint, "mod_h": expr_mod_constraint,
"mod_is": expr_mod_constraint, "mod_is": expr_mod_constraint,
"mod_i": expr_mod_constraint, "mod_i": expr_mod_constraint,
@ -302,6 +326,30 @@ class Incremental:
assert_fn(self.expr_smt(cmd.get("expr"), "Bool")) assert_fn(self.expr_smt(cmd.get("expr"), "Bool"))
def cmd_assert_design_assumes(self, cmd):
step = self.arg_step(cmd)
smtbmc.smt_assert_design_assumes(step)
def cmd_get_design_assume(self, cmd):
key = mkkey(cmd.get("key"))
return smtbmc.assume_enables.get(key)
def cmd_update_assumptions(self, cmd):
expr = cmd.get("expr")
key = cmd.get("key")
key = mkkey(key)
result = smtbmc.smt.smt2_assumptions.pop(key, None)
if expr is not None:
expr = self.expr_smt(expr, "Bool")
smtbmc.smt.smt2_assumptions[key] = expr
return result
def cmd_get_unsat_assumptions(self, cmd):
return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get('minimize')))
def cmd_push(self, cmd): def cmd_push(self, cmd):
smtbmc.smt_push() smtbmc.smt_push()
@ -313,11 +361,14 @@ class Incremental:
def cmd_smtlib(self, cmd): def cmd_smtlib(self, cmd):
command = cmd.get("command") command = cmd.get("command")
response = cmd.get("response", False)
if not isinstance(command, str): if not isinstance(command, str):
raise InteractiveError( raise InteractiveError(
f"raw SMT-LIB command must be a string, found {json.dumps(command)}" f"raw SMT-LIB command must be a string, found {json.dumps(command)}"
) )
smtbmc.smt.write(command) smtbmc.smt.write(command)
if response:
return smtbmc.smt.read()
def cmd_design_hierwitness(self, cmd=None): def cmd_design_hierwitness(self, cmd=None):
allregs = (cmd is None) or bool(cmd.get("allreges", False)) allregs = (cmd is None) or bool(cmd.get("allreges", False))
@ -369,6 +420,21 @@ class Incremental:
return dict(last_step=last_step) return dict(last_step=last_step)
def cmd_modinfo(self, cmd):
fields = cmd.get("fields", [])
mod = cmd.get("mod")
if mod is None:
mod = smtbmc.topmod
modinfo = smtbmc.smt.modinfo.get(mod)
if modinfo is None:
return None
result = dict(name=mod)
for field in fields:
result[field] = getattr(modinfo, field, None)
return result
def cmd_ping(self, cmd): def cmd_ping(self, cmd):
return cmd return cmd
@ -377,6 +443,10 @@ class Incremental:
"assert": cmd_assert, "assert": cmd_assert,
"assert_antecedent": cmd_assert, "assert_antecedent": cmd_assert,
"assert_consequent": cmd_assert, "assert_consequent": cmd_assert,
"assert_design_assumes": cmd_assert_design_assumes,
"get_design_assume": cmd_get_design_assume,
"update_assumptions": cmd_update_assumptions,
"get_unsat_assumptions": cmd_get_unsat_assumptions,
"push": cmd_push, "push": cmd_push,
"pop": cmd_pop, "pop": cmd_pop,
"check": cmd_check, "check": cmd_check,
@ -384,6 +454,7 @@ class Incremental:
"design_hierwitness": cmd_design_hierwitness, "design_hierwitness": cmd_design_hierwitness,
"write_yw_trace": cmd_write_yw_trace, "write_yw_trace": cmd_write_yw_trace,
"read_yw_trace": cmd_read_yw_trace, "read_yw_trace": cmd_read_yw_trace,
"modinfo": cmd_modinfo,
"ping": cmd_ping, "ping": cmd_ping,
} }

View File

@ -114,6 +114,7 @@ class SmtModInfo:
self.clocks = dict() self.clocks = dict()
self.cells = dict() self.cells = dict()
self.asserts = dict() self.asserts = dict()
self.assumes = dict()
self.covers = dict() self.covers = dict()
self.maximize = set() self.maximize = set()
self.minimize = set() self.minimize = set()
@ -141,6 +142,7 @@ class SmtIo:
self.recheck = False self.recheck = False
self.smt2cache = [list()] self.smt2cache = [list()]
self.smt2_options = dict() self.smt2_options = dict()
self.smt2_assumptions = dict()
self.p = None self.p = None
self.p_index = solvers_index self.p_index = solvers_index
solvers_index += 1 solvers_index += 1
@ -602,6 +604,12 @@ class SmtIo:
else: else:
self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3] self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3]
if fields[1] == "yosys-smt2-assume":
if len(fields) > 4:
self.modinfo[self.curmod].assumes["%s_u %s" % (self.curmod, fields[2])] = f'{fields[4]} ({fields[3]})'
else:
self.modinfo[self.curmod].assumes["%s_u %s" % (self.curmod, fields[2])] = fields[3]
if fields[1] == "yosys-smt2-maximize": if fields[1] == "yosys-smt2-maximize":
self.modinfo[self.curmod].maximize.add(fields[2]) self.modinfo[self.curmod].maximize.add(fields[2])
@ -785,8 +793,13 @@ class SmtIo:
return stmt return stmt
def check_sat(self, expected=["sat", "unsat", "unknown", "timeout", "interrupted"]): def check_sat(self, expected=["sat", "unsat", "unknown", "timeout", "interrupted"]):
if self.smt2_assumptions:
assume_exprs = " ".join(self.smt2_assumptions.values())
check_stmt = f"(check-sat-assuming ({assume_exprs}))"
else:
check_stmt = "(check-sat)"
if self.debug_print: if self.debug_print:
print("> (check-sat)") print(f"> {check_stmt}")
if self.debug_file and not self.nocomments: if self.debug_file and not self.nocomments:
print("; running check-sat..", file=self.debug_file) print("; running check-sat..", file=self.debug_file)
self.debug_file.flush() self.debug_file.flush()
@ -800,7 +813,7 @@ class SmtIo:
for cache_stmt in cache_ctx: for cache_stmt in cache_ctx:
self.p_write(cache_stmt + "\n", False) self.p_write(cache_stmt + "\n", False)
self.p_write("(check-sat)\n", True) self.p_write(f"{check_stmt}\n", True)
if self.timeinfo: if self.timeinfo:
i = 0 i = 0
@ -868,7 +881,7 @@ class SmtIo:
if self.debug_file: if self.debug_file:
print("(set-info :status %s)" % result, file=self.debug_file) print("(set-info :status %s)" % result, file=self.debug_file)
print("(check-sat)", file=self.debug_file) print(check_stmt, file=self.debug_file)
self.debug_file.flush() self.debug_file.flush()
if result not in expected: if result not in expected:
@ -945,6 +958,48 @@ class SmtIo:
def bv2int(self, v): def bv2int(self, v):
return int(self.bv2bin(v), 2) return int(self.bv2bin(v), 2)
def get_raw_unsat_assumptions(self):
self.write("(get-unsat-assumptions)")
exprs = set(self.unparse(part) for part in self.parse(self.read()))
unsat_assumptions = []
for key, value in self.smt2_assumptions.items():
# normalize expression
value = self.unparse(self.parse(value))
if value in exprs:
exprs.remove(value)
unsat_assumptions.append(key)
return unsat_assumptions
def get_unsat_assumptions(self, minimize=False):
if not minimize:
return self.get_raw_unsat_assumptions()
required_assumptions = {}
while True:
candidate_assumptions = {}
for key in self.get_raw_unsat_assumptions():
if key not in required_assumptions:
candidate_assumptions[key] = self.smt2_assumptions[key]
while candidate_assumptions:
candidate_key, candidate_assume = candidate_assumptions.popitem()
self.smt2_assumptions = {}
for key, assume in candidate_assumptions.items():
self.smt2_assumptions[key] = assume
for key, assume in required_assumptions.items():
self.smt2_assumptions[key] = assume
result = self.check_sat()
if result == 'unsat':
candidate_assumptions = None
else:
required_assumptions[candidate_key] = candidate_assume
if candidate_assumptions is not None:
return list(required_assumptions)
def get(self, expr): def get(self, expr):
self.write("(get-value (%s))" % (expr)) self.write("(get-value (%s))" % (expr))
return self.parse(self.read())[0][1] return self.parse(self.read())[0][1]