Various fixes and improvements in yosys-smtbmc

This commit is contained in:
Clifford Wolf 2016-08-29 13:53:12 +02:00
parent eae390ae17
commit a2e2fc5980
1 changed files with 72 additions and 48 deletions

View File

@ -91,7 +91,7 @@ for o, a in opts:
if o == "-t": if o == "-t":
a = a.split(":") a = a.split(":")
if len(a) == 1: if len(a) == 1:
num_steps = int(a[1]) num_steps = int(a[0])
elif len(a) == 2: elif len(a) == 2:
skip_steps = int(a[0]) skip_steps = int(a[0])
num_steps = int(a[1]) num_steps = int(a[1])
@ -139,9 +139,12 @@ constr_assumes = defaultdict(list)
for fn in inconstr: for fn in inconstr:
current_states = None current_states = None
current_line = 0
with open(fn, "r") as f: with open(fn, "r") as f:
for line in f: for line in f:
current_line += 1
if line.startswith("#"): if line.startswith("#"):
continue continue
@ -203,7 +206,7 @@ for fn in inconstr:
assert current_states is not None assert current_states is not None
for state in current_states: for state in current_states:
constr_asserts[state].append(" ".join(tokens[1:])) constr_asserts[state].append(("%s:%d" % (fn, current_line), " ".join(tokens[1:])))
continue continue
@ -211,14 +214,14 @@ for fn in inconstr:
assert current_states is not None assert current_states is not None
for state in current_states: for state in current_states:
constr_assumes[state].append(" ".join(tokens[1:])) constr_assumes[state].append(("%s:%d" % (fn, current_line), " ".join(tokens[1:])))
continue continue
assert 0 assert 0
def get_constr_expr(db, state, final=False): def get_constr_expr(db, state, final=False, getvalues=False):
if final: if final:
if ("final-%d" % state) not in db: if ("final-%d" % state) not in db:
return "true" return "true"
@ -243,9 +246,17 @@ def get_constr_expr(db, state, final=False):
return match.group(1) + expr return match.group(1) + expr
expr_list = list() expr_list = list()
for expr in db[("final-%d" % state) if final else state]: for loc, expr in db[("final-%d" % state) if final else state]:
expr = netref_regex.sub(replace_netref, expr) actual_expr = netref_regex.sub(replace_netref, expr)
expr_list.append(expr) if getvalues:
expr_list.append((loc, expr, actual_expr))
else:
expr_list.append(actual_expr)
if getvalues:
loc_list, expr_list, acual_expr_list = zip(*expr_list)
value_list = smt.get_list(acual_expr_list)
return loc_list, expr_list, value_list
if len(expr_list) == 0: if len(expr_list) == 0:
return "true" return "true"
@ -400,41 +411,42 @@ def write_constr_trace(steps):
width = smt.modinfo[topmod].wsize[name] width = smt.modinfo[topmod].wsize[name]
primary_inputs.append((name, width)) primary_inputs.append((name, width))
for k in range(steps):
if k != 0:
print("", file=f)
print("initial", file=f)
regnames = sorted(smt.hiernets(topmod, regs_only=True))
regvals = smt.get_net_list(topmod, regnames, "s0")
for name, val in zip(regnames, regvals):
print("assume (= [%s] %s)" % (".".join(name), val), file=f)
mems = sorted(smt.hiermems(topmod))
for mempath in mems:
abits, width, ports = smt.mem_info(topmod, "s0", mempath)
mem = smt.mem_expr(topmod, "s0", mempath)
addr_expr_list = list()
for i in range(steps):
for j in range(ports):
addr_expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, j))
addr_list = set()
for val in smt.get_list(addr_expr_list):
addr_list.add(smt.bv2int(val))
expr_list = list()
for i in addr_list:
expr_list.append("(select %s #b%s)" % (mem, format(i, "0%db" % abits)))
for i, val in zip(addr_list, smt.get_list(expr_list)):
print("assume (= (select [%s] #b%s) %s)" % (".".join(mempath), format(i, "0%db" % abits), val), file=f)
for k in range(steps):
print("", file=f)
print("state %d" % k, file=f) print("state %d" % k, file=f)
if k == 0: pi_names = [[name] for name, _ in sorted(primary_inputs)]
regnames = sorted(smt.hiernets(topmod, regs_only=True))
regvals = smt.get_net_list(topmod, regnames, "s0")
for name, val in zip(regnames, regvals):
print("assume (= [%s] %s)" % (".".join(name), val), file=f)
mems = sorted(smt.hiermems(topmod))
for mempath in mems:
abits, width, ports = smt.mem_info(topmod, "s0", mempath)
mem = smt.mem_expr(topmod, "s0", mempath)
addr_expr_list = list()
for i in range(steps):
for j in range(ports):
addr_expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, j))
addr_list = set()
for val in smt.get_list(addr_expr_list):
addr_list.add(smt.bv2int(val))
expr_list = list()
for i in addr_list:
expr_list.append("(select %s #b%s)" % (mem, format(i, "0%db" % abits)))
for i, val in zip(addr_list, smt.get_list(expr_list)):
print("assume (= (select [%s] #b%s) %s)" % (".".join(mempath), format(i, "0%db" % abits), val), file=f)
pi_names = [[name] for name, _ in primary_inputs]
pi_values = smt.get_net_list(topmod, pi_names, "s%d" % k) pi_values = smt.get_net_list(topmod, pi_names, "s%d" % k)
for name, val in zip(pi_names, pi_values): for name, val in zip(pi_names, pi_values):
@ -452,20 +464,31 @@ def write_trace(steps):
write_constr_trace(steps) write_constr_trace(steps)
def print_failed_asserts(mod, state, path): def print_failed_asserts_worker(mod, state, path):
assert mod in smt.modinfo assert mod in smt.modinfo
if smt.get("(|%s_a| %s)" % (mod, state)) == "true": if smt.get("(|%s_a| s%d)" % (mod, state)) == "true":
return return
for cellname, celltype in smt.modinfo[mod].cells.items(): for cellname, celltype in smt.modinfo[mod].cells.items():
print_failed_asserts(celltype, "(|%s_h %s| %s)" % (mod, cellname, state), path + "." + cellname) print_failed_asserts_worker(celltype, "(|%s_h %s| s%d)" % (mod, cellname, state), path + "." + cellname)
for assertfun, assertinfo in smt.modinfo[mod].asserts.items(): for assertfun, assertinfo in smt.modinfo[mod].asserts.items():
if smt.get("(|%s| %s)" % (assertfun, state)) == "false": if smt.get("(|%s| s%d)" % (assertfun, state)) == "false":
print("%s Assert failed in %s: %s" % (smt.timestamp(), path, assertinfo)) print("%s Assert failed in %s: %s" % (smt.timestamp(), path, assertinfo))
def print_failed_asserts(state, final=False):
loc_list, expr_list, value_list = get_constr_expr(constr_asserts, state, final=final, getvalues=True)
for loc, expr, value in zip(loc_list, expr_list, value_list):
if smt.bv2int(value) == 0:
print("%s Assert %s failed: %s" % (smt.timestamp(), loc, expr))
if not final:
print_failed_asserts_worker(topmod, state, topmod)
if tempind: if tempind:
retstatus = False retstatus = False
skip_counter = step_size skip_counter = step_size
@ -497,7 +520,7 @@ if tempind:
if smt.check_sat() == "sat": if smt.check_sat() == "sat":
if step == 0: if step == 0:
print("%s Temporal induction failed!" % smt.timestamp()) print("%s Temporal induction failed!" % smt.timestamp())
print_failed_asserts(topmod, "s%d" % step, topmod) print_failed_asserts(num_steps)
write_trace(num_steps+1) write_trace(num_steps+1)
else: else:
@ -556,8 +579,9 @@ else: # not tempind
if smt.check_sat() == "sat": if smt.check_sat() == "sat":
print("%s BMC failed!" % smt.timestamp()) print("%s BMC failed!" % smt.timestamp())
print_failed_asserts(topmod, "s%d" % step, topmod) for i in range(step, last_check_step+1):
write_trace(step+step_size) print_failed_asserts(i)
write_trace(last_check_step+1)
retstatus = False retstatus = False
break break
@ -580,8 +604,8 @@ else: # not tempind
if smt.check_sat() == "sat": if smt.check_sat() == "sat":
print("%s BMC failed!" % smt.timestamp()) print("%s BMC failed!" % smt.timestamp())
print_failed_asserts(topmod, "s%d" % i, topmod) print_failed_asserts(i, final=True)
write_trace(i) write_trace(i+1)
retstatus = False retstatus = False
break break