diff --git a/backends/smt2/smtbmc.py b/backends/smt2/smtbmc.py index 448a14202..1bb9dd93e 100644 --- a/backends/smt2/smtbmc.py +++ b/backends/smt2/smtbmc.py @@ -91,7 +91,7 @@ for o, a in opts: if o == "-t": a = a.split(":") if len(a) == 1: - num_steps = int(a[1]) + num_steps = int(a[0]) elif len(a) == 2: skip_steps = int(a[0]) num_steps = int(a[1]) @@ -139,9 +139,12 @@ constr_assumes = defaultdict(list) for fn in inconstr: current_states = None + current_line = 0 with open(fn, "r") as f: for line in f: + current_line += 1 + if line.startswith("#"): continue @@ -203,7 +206,7 @@ for fn in inconstr: assert current_states is not None for state in current_states: - constr_asserts[state].append(" ".join(tokens[1:])) + constr_asserts[state].append(("%s:%d" % (fn, current_line), " ".join(tokens[1:]))) continue @@ -211,14 +214,14 @@ for fn in inconstr: assert current_states is not None for state in current_states: - constr_assumes[state].append(" ".join(tokens[1:])) + constr_assumes[state].append(("%s:%d" % (fn, current_line), " ".join(tokens[1:]))) continue assert 0 -def get_constr_expr(db, state, final=False): +def get_constr_expr(db, state, final=False, getvalues=False): if final: if ("final-%d" % state) not in db: return "true" @@ -243,9 +246,17 @@ def get_constr_expr(db, state, final=False): return match.group(1) + expr expr_list = list() - for expr in db[("final-%d" % state) if final else state]: - expr = netref_regex.sub(replace_netref, expr) - expr_list.append(expr) + for loc, expr in db[("final-%d" % state) if final else state]: + actual_expr = netref_regex.sub(replace_netref, expr) + if getvalues: + expr_list.append((loc, expr, actual_expr)) + else: + expr_list.append(actual_expr) + + if getvalues: + loc_list, expr_list, acual_expr_list = zip(*expr_list) + value_list = smt.get_list(acual_expr_list) + return loc_list, expr_list, value_list if len(expr_list) == 0: return "true" @@ -400,41 +411,42 @@ def write_constr_trace(steps): width = smt.modinfo[topmod].wsize[name] primary_inputs.append((name, width)) - for k in range(steps): - if k != 0: - print("", file=f) + print("initial", file=f) + + regnames = sorted(smt.hiernets(topmod, regs_only=True)) + regvals = smt.get_net_list(topmod, regnames, "s0") + + for name, val in zip(regnames, regvals): + print("assume (= [%s] %s)" % (".".join(name), val), file=f) + + mems = sorted(smt.hiermems(topmod)) + for mempath in mems: + abits, width, ports = smt.mem_info(topmod, "s0", mempath) + mem = smt.mem_expr(topmod, "s0", mempath) + + addr_expr_list = list() + for i in range(steps): + for j in range(ports): + addr_expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, j)) + + addr_list = set() + for val in smt.get_list(addr_expr_list): + addr_list.add(smt.bv2int(val)) + + expr_list = list() + for i in addr_list: + expr_list.append("(select %s #b%s)" % (mem, format(i, "0%db" % abits))) + + for i, val in zip(addr_list, smt.get_list(expr_list)): + print("assume (= (select [%s] #b%s) %s)" % (".".join(mempath), format(i, "0%db" % abits), val), file=f) + + + for k in range(steps): + print("", file=f) print("state %d" % k, file=f) - if k == 0: - regnames = sorted(smt.hiernets(topmod, regs_only=True)) - regvals = smt.get_net_list(topmod, regnames, "s0") - - for name, val in zip(regnames, regvals): - print("assume (= [%s] %s)" % (".".join(name), val), file=f) - - mems = sorted(smt.hiermems(topmod)) - for mempath in mems: - abits, width, ports = smt.mem_info(topmod, "s0", mempath) - mem = smt.mem_expr(topmod, "s0", mempath) - - addr_expr_list = list() - for i in range(steps): - for j in range(ports): - addr_expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, j)) - - addr_list = set() - for val in smt.get_list(addr_expr_list): - addr_list.add(smt.bv2int(val)) - - expr_list = list() - for i in addr_list: - expr_list.append("(select %s #b%s)" % (mem, format(i, "0%db" % abits))) - - for i, val in zip(addr_list, smt.get_list(expr_list)): - print("assume (= (select [%s] #b%s) %s)" % (".".join(mempath), format(i, "0%db" % abits), val), file=f) - - pi_names = [[name] for name, _ in primary_inputs] + pi_names = [[name] for name, _ in sorted(primary_inputs)] pi_values = smt.get_net_list(topmod, pi_names, "s%d" % k) for name, val in zip(pi_names, pi_values): @@ -452,20 +464,31 @@ def write_trace(steps): write_constr_trace(steps) -def print_failed_asserts(mod, state, path): +def print_failed_asserts_worker(mod, state, path): assert mod in smt.modinfo - if smt.get("(|%s_a| %s)" % (mod, state)) == "true": + if smt.get("(|%s_a| s%d)" % (mod, state)) == "true": return for cellname, celltype in smt.modinfo[mod].cells.items(): - print_failed_asserts(celltype, "(|%s_h %s| %s)" % (mod, cellname, state), path + "." + cellname) + print_failed_asserts_worker(celltype, "(|%s_h %s| s%d)" % (mod, cellname, state), path + "." + cellname) for assertfun, assertinfo in smt.modinfo[mod].asserts.items(): - if smt.get("(|%s| %s)" % (assertfun, state)) == "false": + if smt.get("(|%s| s%d)" % (assertfun, state)) == "false": print("%s Assert failed in %s: %s" % (smt.timestamp(), path, assertinfo)) +def print_failed_asserts(state, final=False): + loc_list, expr_list, value_list = get_constr_expr(constr_asserts, state, final=final, getvalues=True) + + for loc, expr, value in zip(loc_list, expr_list, value_list): + if smt.bv2int(value) == 0: + print("%s Assert %s failed: %s" % (smt.timestamp(), loc, expr)) + + if not final: + print_failed_asserts_worker(topmod, state, topmod) + + if tempind: retstatus = False skip_counter = step_size @@ -497,7 +520,7 @@ if tempind: if smt.check_sat() == "sat": if step == 0: print("%s Temporal induction failed!" % smt.timestamp()) - print_failed_asserts(topmod, "s%d" % step, topmod) + print_failed_asserts(num_steps) write_trace(num_steps+1) else: @@ -556,8 +579,9 @@ else: # not tempind if smt.check_sat() == "sat": print("%s BMC failed!" % smt.timestamp()) - print_failed_asserts(topmod, "s%d" % step, topmod) - write_trace(step+step_size) + for i in range(step, last_check_step+1): + print_failed_asserts(i) + write_trace(last_check_step+1) retstatus = False break @@ -580,8 +604,8 @@ else: # not tempind if smt.check_sat() == "sat": print("%s BMC failed!" % smt.timestamp()) - print_failed_asserts(topmod, "s%d" % i, topmod) - write_trace(i) + print_failed_asserts(i, final=True) + write_trace(i+1) retstatus = False break