Merge pull request #4035 from jix/smtbmc-incremental

smtbmc: Add --incremental mode
This commit is contained in:
Jannis Harder 2023-11-20 17:00:29 +01:00 committed by GitHub
commit b23a607421
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 512 additions and 64 deletions

View File

@ -17,7 +17,7 @@
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#
import os, sys, getopt, re, bisect
import os, sys, getopt, re, bisect, json
##yosys-sys-path##
from smtio import SmtIo, SmtOpts, MkVcd
from ywio import ReadWitness, WriteWitness, WitnessValues
@ -56,6 +56,7 @@ binarymode = False
keep_going = False
check_witness = False
detect_loops = False
incremental = None
so = SmtOpts()
@ -185,6 +186,9 @@ def help():
check if states are unique in temporal induction counter examples
(this feature is experimental and incomplete)
--incremental
run in incremental mode (experimental)
""" + so.helpmsg())
def usage():
@ -196,7 +200,7 @@ try:
opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts +
["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat",
"dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=",
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops"])
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental"])
except:
usage()
@ -282,6 +286,9 @@ for o, a in opts:
check_witness = True
elif o == "--detect-loops":
detect_loops = True
elif o == "--incremental":
from smtbmc_incremental import Incremental
incremental = Incremental()
elif so.handle(o, a):
pass
else:
@ -290,7 +297,7 @@ for o, a in opts:
if len(args) != 1:
usage()
if sum([tempind, gentrace, covermode]) > 1:
if sum([tempind, gentrace, covermode, incremental is not None]) > 1:
usage()
constr_final_start = None
@ -444,8 +451,10 @@ if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None:
smt.produce_models = False
def print_msg(msg):
print("%s %s" % (smt.timestamp(), msg))
sys.stdout.flush()
if incremental:
incremental.print_msg(msg)
else:
print("%s %s" % (smt.timestamp(), msg), flush=True)
print_msg("Solver: %s" % (so.solver))
@ -640,10 +649,9 @@ if aimfile is not None:
num_steps = max(num_steps, step+2)
step += 1
if inywfile is not None:
if not got_topt:
skip_steps = 0
num_steps = 0
def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
if map_steps is None:
map_steps = {}
with open(inywfile, "r") as f:
inyw = ReadWitness(f)
@ -662,10 +670,14 @@ if inywfile is not None:
addr_re = re.compile(r'\\\[[0-9]+\]$')
bits_re = re.compile(r'[01?]*$')
max_t = -1
for t, step in inyw.steps():
present_signals, missing = step.present_signals(inyw.sigmap)
for sig in present_signals:
bits = step[sig]
if skip_x:
bits = bits.replace('x', '?')
if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file")
@ -684,7 +696,7 @@ if inywfile is not None:
if common_end <= common_offset:
continue
smt_expr = smt.witness_net_expr(topmod, f"s{t}", wire)
smt_expr = smt.witness_net_expr(topmod, f"s{map_steps.get(t, t)}", wire)
if not smt_bool:
slice_high = common_end - offset - 1
@ -714,7 +726,7 @@ if inywfile is not None:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]
smt_expr = smt.net_expr(topmod, f"s{t}", mem["smtpath"])
smt_expr = smt.net_expr(topmod, f"s{map_steps.get(t, t)}", mem["smtpath"])
if bv:
word_low = sig.memory_addr * width
@ -738,11 +750,21 @@ if inywfile is not None:
smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
constr_assumes[t].append((inywfile, smt_constr))
max_t = t
if not got_topt:
if not check_witness:
skip_steps = max(skip_steps, t)
num_steps = max(num_steps, t+1)
return max_t
if inywfile is not None:
if not got_topt:
skip_steps = 0
num_steps = 0
max_t = ywfile_constraints(inywfile, constr_assumes)
if not got_topt:
if not check_witness:
skip_steps = max(skip_steps, max_t)
num_steps = max(num_steps, max_t+1)
if btorwitfile is not None:
with open(btorwitfile, "r") as f:
@ -841,7 +863,7 @@ if btorwitfile is not None:
skip_steps = step
num_steps = step+1
def collect_mem_trace_data(steps_start, steps_stop, vcd=None):
def collect_mem_trace_data(steps, vcd=None):
mem_trace_data = dict()
for mempath in sorted(smt.hiermems(topmod)):
@ -849,16 +871,16 @@ def collect_mem_trace_data(steps_start, steps_stop, vcd=None):
expr_id = list()
expr_list = list()
for i in range(steps_start, steps_stop):
for seq, i in enumerate(steps):
for j in range(rports):
expr_id.append(('R', i-steps_start, j, 'A'))
expr_id.append(('R', i-steps_start, j, 'D'))
expr_id.append(('R', seq, j, 'A'))
expr_id.append(('R', seq, j, 'D'))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dA" % j))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dD" % j))
for j in range(wports):
expr_id.append(('W', i-steps_start, j, 'A'))
expr_id.append(('W', i-steps_start, j, 'D'))
expr_id.append(('W', i-steps_start, j, 'M'))
expr_id.append(('W', seq, j, 'A'))
expr_id.append(('W', seq, j, 'D'))
expr_id.append(('W', seq, j, 'M'))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dA" % j))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dD" % j))
expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dM" % j))
@ -943,14 +965,14 @@ def collect_mem_trace_data(steps_start, steps_stop, vcd=None):
netpath[-1] += "<%0*x>" % ((len(addr)+3) // 4, int_addr)
vcd.add_net([topmod] + netpath, width)
for i in range(steps_start, steps_stop):
for seq, i in enumerate(steps):
if i not in mem_trace_data:
mem_trace_data[i] = list()
mem_trace_data[i].append((netpath, int_addr, "".join(tdata[i-steps_start])))
mem_trace_data[i].append((netpath, int_addr, "".join(tdata[seq])))
return mem_trace_data
def write_vcd_trace(steps_start, steps_stop, index):
def write_vcd_trace(steps, index, seq_time=False):
filename = vcdfile.replace("%", index)
print_msg("Writing trace to VCD file: %s" % (filename))
@ -971,10 +993,10 @@ def write_vcd_trace(steps_start, steps_stop, index):
vcd.add_clock([topmod] + netpath, edge)
path_list.append(netpath)
mem_trace_data = collect_mem_trace_data(steps_start, steps_stop, vcd)
mem_trace_data = collect_mem_trace_data(steps, vcd)
for i in range(steps_start, steps_stop):
vcd.set_time(i)
for seq, i in enumerate(steps):
vcd.set_time(seq if seq_time else i)
value_list = smt.get_net_bin_list(topmod, path_list, "s%d" % i)
for path, value in zip(path_list, value_list):
vcd.set_net([topmod] + path, value)
@ -982,7 +1004,14 @@ def write_vcd_trace(steps_start, steps_stop, index):
for path, addr, value in mem_trace_data[i]:
vcd.set_net([topmod] + path, value)
vcd.set_time(steps_stop)
if seq_time:
end_time = len(steps)
elif steps:
end_time = steps[-1] + 1
else:
end_time = 0
vcd.set_time(end_time)
def detect_state_loop(steps_start, steps_stop):
print_msg(f"Checking for loops in found induction counter example")
@ -1027,7 +1056,7 @@ def escape_identifier(identifier):
def write_vlogtb_trace(steps_start, steps_stop, index):
def write_vlogtb_trace(steps, index):
filename = vlogtbfile.replace("%", index)
print_msg("Writing trace to Verilog testbench: %s" % (filename))
@ -1092,7 +1121,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
print(" initial begin", file=f)
regs = sorted(smt.hiernets(vlogtb_topmod, regs_only=True))
regvals = smt.get_net_bin_list(vlogtb_topmod, regs, vlogtb_state.replace("@@step_idx@@", str(steps_start)))
regvals = smt.get_net_bin_list(vlogtb_topmod, regs, vlogtb_state.replace("@@step_idx@@", str(steps[0])))
print("`ifndef VERILATOR", file=f)
print(" #1;", file=f)
@ -1107,7 +1136,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
anyconsts = sorted(smt.hieranyconsts(vlogtb_topmod))
for info in anyconsts:
if info[3] is not None:
modstate = smt.net_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(steps_start)), info[0])
modstate = smt.net_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(steps[0])), info[0])
value = smt.bv2bin(smt.get("(|%s| %s)" % (info[1], modstate)))
print(" UUT.%s = %d'b%s;" % (".".join(escape_identifier(info[0] + [info[3]])), len(value), value), file=f);
@ -1117,7 +1146,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
addr_expr_list = list()
data_expr_list = list()
for i in range(steps_start, steps_stop):
for i in steps:
for j in range(rports):
addr_expr_list.append(smt.mem_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(i)), mempath, "R%dA" % j))
data_expr_list.append(smt.mem_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(i)), mempath, "R%dD" % j))
@ -1138,7 +1167,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
print("", file=f)
anyseqs = sorted(smt.hieranyseqs(vlogtb_topmod))
for i in range(steps_start, steps_stop):
for i in steps:
pi_names = [[name] for name, _ in primary_inputs if name not in clock_inputs]
pi_values = smt.get_net_bin_list(vlogtb_topmod, pi_names, vlogtb_state.replace("@@step_idx@@", str(i)))
@ -1170,14 +1199,14 @@ def write_vlogtb_trace(steps_start, steps_stop, index):
print(" end", file=f)
print(" always @(posedge clock) begin", file=f)
print(" genclock <= cycle < %d;" % (steps_stop-1), file=f)
print(" genclock <= cycle < %d;" % (steps[-1]), file=f)
print(" cycle <= cycle + 1;", file=f)
print(" end", file=f)
print("endmodule", file=f)
def write_constr_trace(steps_start, steps_stop, index):
def write_constr_trace(steps, index):
filename = outconstr.replace("%", index)
print_msg("Writing trace to constraints file: %s" % (filename))
@ -1194,7 +1223,7 @@ def write_constr_trace(steps_start, steps_stop, index):
constr_prefix = smtctop[1] + "."
if smtcinit:
steps_start = steps_stop - 1
steps = [steps[-1]]
with open(filename, "w") as f:
primary_inputs = list()
@ -1203,13 +1232,13 @@ def write_constr_trace(steps_start, steps_stop, index):
width = smt.modinfo[constr_topmod].wsize[name]
primary_inputs.append((name, width))
if steps_start == 0 or smtcinit:
if steps[0] == 0 or smtcinit:
print("initial", file=f)
else:
print("state %d" % steps_start, file=f)
print("state %d" % steps[0], file=f)
regnames = sorted(smt.hiernets(constr_topmod, regs_only=True))
regvals = smt.get_net_list(constr_topmod, regnames, constr_state.replace("@@step_idx@@", str(steps_start)))
regvals = smt.get_net_list(constr_topmod, regnames, constr_state.replace("@@step_idx@@", str(steps[0])))
for name, val in zip(regnames, regvals):
print("assume (= [%s%s] %s)" % (constr_prefix, ".".join(name), val), file=f)
@ -1220,7 +1249,7 @@ def write_constr_trace(steps_start, steps_stop, index):
addr_expr_list = list()
data_expr_list = list()
for i in range(steps_start, steps_stop):
for i in steps:
for j in range(rports):
addr_expr_list.append(smt.mem_expr(constr_topmod, constr_state.replace("@@step_idx@@", str(i)), mempath, "R%dA" % j))
data_expr_list.append(smt.mem_expr(constr_topmod, constr_state.replace("@@step_idx@@", str(i)), mempath, "R%dD" % j))
@ -1236,7 +1265,7 @@ def write_constr_trace(steps_start, steps_stop, index):
for addr, data in addr_data.items():
print("assume (= (select [%s%s] %s) %s)" % (constr_prefix, ".".join(mempath), addr, data), file=f)
for k in range(steps_start, steps_stop):
for k in steps:
if not smtcinit:
print("", file=f)
print("state %d" % k, file=f)
@ -1247,11 +1276,14 @@ def write_constr_trace(steps_start, steps_stop, index):
for name, val in zip(pi_names, pi_values):
print("assume (= [%s%s] %s)" % (constr_prefix, ".".join(name), val), file=f)
def write_yw_trace(steps_start, steps_stop, index, allregs=False):
filename = outywfile.replace("%", index)
print_msg("Writing trace to Yosys witness file: %s" % (filename))
def write_yw_trace(steps, index, allregs=False, filename=None):
if filename is None:
if outywfile is None:
return
filename = outywfile.replace("%", index)
print_msg("Writing trace to Yosys witness file: %s" % (filename))
mem_trace_data = collect_mem_trace_data(steps_start, steps_stop)
mem_trace_data = collect_mem_trace_data(steps)
with open(filename, "w") as f:
inits, seqs, clocks, mems = smt.hierwitness(topmod, allregs)
@ -1295,10 +1327,10 @@ def write_yw_trace(steps_start, steps_stop, index, allregs=False):
sig = yw.add_sig(word_path, overlap_start, overlap_end - overlap_start, True)
mem_init_values.append((sig, overlap_bits.replace("x", "?")))
for k in range(steps_start, steps_stop):
for i, k in enumerate(steps):
step_values = WitnessValues()
if k == steps_start:
if not i:
for sig, value in mem_init_values:
step_values[sig] = value
sigs = inits + seqs
@ -1314,17 +1346,24 @@ def write_yw_trace(steps_start, steps_stop, index, allregs=False):
def write_trace(steps_start, steps_stop, index, allregs=False):
if steps_stop is None:
steps = steps_start
seq_time = True
else:
steps = list(range(steps_start, steps_stop))
seq_time = False
if vcdfile is not None:
write_vcd_trace(steps_start, steps_stop, index)
write_vcd_trace(steps, index, seq_time=seq_time)
if vlogtbfile is not None:
write_vlogtb_trace(steps_start, steps_stop, index)
write_vlogtb_trace(steps, index)
if outconstr is not None:
write_constr_trace(steps_start, steps_stop, index)
write_constr_trace(steps, index)
if outywfile is not None:
write_yw_trace(steps_start, steps_stop, index, allregs)
write_yw_trace(steps, index, allregs)
def print_failed_asserts_worker(mod, state, path, extrainfo, infomap, infokey=()):
@ -1596,7 +1635,11 @@ def smt_check_sat(expected=["sat", "unsat"]):
smt_forall_assert()
return smt.check_sat(expected=expected)
if tempind:
if incremental:
incremental.mainloop()
elif tempind:
retstatus = "FAILED"
skip_counter = step_size
for step in range(num_steps, -1, -1):
@ -1954,5 +1997,6 @@ else: # not tempind, covermode
smt.write("(exit)")
smt.wait()
print_msg("Status: %s" % retstatus)
sys.exit(0 if retstatus == "PASSED" else 1)
if not incremental:
print_msg("Status: %s" % retstatus)
sys.exit(0 if retstatus == "PASSED" else 1)

View File

@ -0,0 +1,389 @@
from collections import defaultdict
import json
import typing
from functools import partial
if typing.TYPE_CHECKING:
import smtbmc
else:
import sys
smtbmc = sys.modules["__main__"]
class InteractiveError(Exception):
pass
class Incremental:
def __init__(self):
self.traceidx = 0
self.state_set = set()
self.map_cache = {}
self._cached_hierwitness = {}
self._witness_index = None
self._yw_constraints = {}
def setup(self):
generic_assert_map = smtbmc.get_assert_map(
smtbmc.topmod, "state", smtbmc.topmod
)
self.inv_generic_assert_map = {
tuple(data[1:]): key for key, data in generic_assert_map.items()
}
assert len(self.inv_generic_assert_map) == len(generic_assert_map)
def print_json(self, **kwargs):
print(json.dumps(kwargs), flush=True)
def print_msg(self, msg):
self.print_json(msg=msg)
def get_cached_assert(self, step, name):
try:
assert_map = self.map_cache[step]
except KeyError:
assert_map = self.map_cache[step] = smtbmc.get_assert_map(
smtbmc.topmod, f"s{step}", smtbmc.topmod
)
return assert_map[self.inv_generic_assert_map[name]][0]
def arg_step(self, cmd, declare=False, name="step", optional=False):
step = cmd.get(name, None)
if step is None and optional:
return None
if not isinstance(step, int):
if optional:
raise InteractiveError(f"{name} must be an integer")
else:
raise InteractiveError(f"integer {name} argument required")
if declare and step in self.state_set:
raise InteractiveError(f"step {step} already declared")
if not declare and step not in self.state_set:
raise InteractiveError(f"step {step} not declared")
return step
def expr_arg_len(self, expr, min_len, max_len=-1):
if max_len == -1:
max_len = min_len
arg_len = len(expr) - 1
if min_len is not None and arg_len < min_len:
if min_len == max_len:
raise (
f"{json.dumps(expr[0])} expression must have "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
else:
raise (
f"{json.dumps(expr[0])} expression must have at least "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
if max_len is not None and arg_len > max_len:
raise (
f"{json.dumps(expr[0])} expression can have at most "
f"{min_len} argument{'s' if max_len != 1 else ''}"
)
def expr_step(self, expr, smt_out):
self.expr_arg_len(expr, 1)
step = expr[1]
if step not in self.state_set:
raise InteractiveError(f"step {step} not declared")
smt_out.append(f"s{step}")
return "module", smtbmc.topmod
def expr_mod_constraint(self, expr, smt_out):
self.expr_arg_len(expr, 1)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None])
module = arg_sort[1]
suffix = expr[0][3:]
smt_out[position] = f"(|{module}{suffix}| "
smt_out.append(")")
return "Bool"
def expr_mod_constraint2(self, expr, smt_out):
self.expr_arg_len(expr, 2)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None])
smt_out.append(" ")
self.expr(expr[2], smt_out, required_sort=arg_sort)
module = arg_sort[1]
suffix = expr[0][3:]
smt_out[position] = f"(|{module}{suffix}| "
smt_out.append(")")
return "Bool"
def expr_not(self, expr, smt_out):
self.expr_arg_len(expr, 1)
smt_out.append("(not ")
self.expr(expr[1], smt_out, required_sort="Bool")
smt_out.append(")")
return "Bool"
def expr_eq(self, expr, smt_out):
self.expr_arg_len(expr, 2)
smt_out.append("(= ")
arg_sort = self.expr(expr[1], smt_out)
if (
smtbmc.smt.unroll
and isinstance(arg_sort, (list, tuple))
and arg_sort[0] == "module"
):
raise InteractiveError("state equality not supported in unroll mode")
smt_out.append(" ")
self.expr(expr[2], smt_out, required_sort=arg_sort)
smt_out.append(")")
return "Bool"
def expr_andor(self, expr, smt_out):
if len(expr) == 1:
smt_out.push({"and": "true", "or": "false"}[expr[0]])
elif len(expr) == 2:
arg_sort = self.expr(expr[1], smt_out)
if arg_sort != "Bool":
raise InteractiveError(
f"arguments of {json.dumps(expr[0])} must have sort Bool"
)
else:
sep = f"({expr[0]} "
for arg in expr[1:]:
smt_out.append(sep)
sep = " "
self.expr(arg, smt_out, required_sort="Bool")
smt_out.append(")")
return "Bool"
def expr_yw(self, expr, smt_out):
if len(expr) == 2:
name = None
step = expr[1]
elif len(expr) == 3:
name = expr[1]
step = expr[2]
if step not in self.state_set:
raise InteractiveError(f"step {step} not declared")
if name not in self._yw_constraints:
raise InteractiveError(f"no yw file loaded as name {name!r}")
constraints = self._yw_constraints[name].get(step, [])
if len(constraints) == 0:
smt_out.append("true")
elif len(constraints) == 1:
smt_out.append(constraints[0])
else:
sep = "(and "
for constraint in constraints:
smt_out.append(sep)
sep = " "
smt_out.append(constraint)
smt_out.append(")")
return "Bool"
def expr_label(self, expr, smt_out):
if len(expr) != 3:
raise InteractiveError(f'expected ["!", label, sub_expr], got {expr!r}')
label = expr[1]
subexpr = expr[2]
if not isinstance(label, str):
raise InteractiveError(f"expression label has to be a string")
smt_out.append("(! ")
smt_out.appedd(label)
smt_out.append(" ")
sort = self.expr(subexpr, smt_out)
smt_out.append(")")
return sort
expr_handlers = {
"step": expr_step,
"mod_h": expr_mod_constraint,
"mod_is": expr_mod_constraint,
"mod_i": expr_mod_constraint,
"mod_a": expr_mod_constraint,
"mod_u": expr_mod_constraint,
"mod_t": expr_mod_constraint2,
"not": expr_not,
"and": expr_andor,
"or": expr_andor,
"=": expr_eq,
"yw": expr_yw,
"!": expr_label,
}
def expr(self, expr, smt_out, required_sort=None):
if not isinstance(expr, (list, tuple)) or not expr:
raise InteractiveError(
f"expression must be a non-empty JSON array, found: {json.dumps(expr)}"
)
name = expr[0]
handler = self.expr_handlers.get(name)
if handler:
sort = handler(self, expr, smt_out)
if required_sort is not None:
if isinstance(required_sort, (list, tuple)):
if (
not isinstance(sort, (list, tuple))
or len(sort) != len(required_sort)
or any(
r is not None and r != s
for r, s in zip(required_sort, sort)
)
):
raise InteractiveError(
f"required sort {json.dumps(required_sort)} found sort {json.dumps(sort)}"
)
return sort
raise InteractiveError(f"unknown expression {json.dumps(expr[0])}")
def expr_smt(self, expr, required_sort):
smt_out = []
self.expr(expr, smt_out, required_sort=required_sort)
out = "".join(smt_out)
return out
def cmd_new_step(self, cmd):
step = self.arg_step(cmd, declare=True)
self.state_set.add(step)
smtbmc.smt_state(step)
def cmd_assert(self, cmd):
name = cmd.get("cmd")
assert_fn = {
"assert_antecedent": smtbmc.smt_assert_antecedent,
"assert_consequent": smtbmc.smt_assert_consequent,
"assert": smtbmc.smt_assert,
}[name]
assert_fn(self.expr_smt(cmd.get("expr"), "Bool"))
def cmd_push(self, cmd):
smtbmc.smt_push()
def cmd_pop(self, cmd):
smtbmc.smt_pop()
def cmd_check(self, cmd):
return smtbmc.smt_check_sat()
def cmd_design_hierwitness(self, cmd=None):
allregs = (cmd is None) or bool(cmd.get("allreges", False))
if self._cached_hierwitness[allregs] is not None:
return self._cached_hierwitness[allregs]
inits, seqs, clocks, mems = smtbmc.smt.hierwitness(smtbmc.topmod, allregs)
self._cached_hierwitness[allregs] = result = dict(
inits=inits, seqs=seqs, clocks=clocks, mems=mems
)
return result
def cmd_write_yw_trace(self, cmd):
steps = cmd.get("steps")
allregs = bool(cmd.get("allregs", False))
if steps is None:
steps = sorted(self.state_set)
path = cmd.get("path")
smtbmc.write_yw_trace(steps, self.traceidx, allregs=allregs, filename=path)
if path is None:
self.traceidx += 1
def cmd_read_yw_trace(self, cmd):
steps = cmd.get("steps")
path = cmd.get("path")
name = cmd.get("name")
skip_x = cmd.get("skip_x", False)
if path is None:
raise InteractiveError("path required")
constraints = defaultdict(list)
if steps is None:
steps = sorted(self.state_set)
map_steps = {i: int(j) for i, j in enumerate(steps)}
smtbmc.ywfile_constraints(path, constraints, map_steps=map_steps, skip_x=skip_x)
self._yw_constraints[name] = {
map_steps.get(i, i): [smtexpr for cexfile, smtexpr in constraint_list]
for i, constraint_list in constraints.items()
}
def cmd_ping(self, cmd):
return cmd
cmd_handlers = {
"new_step": cmd_new_step,
"assert": cmd_assert,
"assert_antecedent": cmd_assert,
"assert_consequent": cmd_assert,
"push": cmd_push,
"pop": cmd_pop,
"check": cmd_check,
"design_hierwitness": cmd_design_hierwitness,
"write_yw_trace": cmd_write_yw_trace,
"read_yw_trace": cmd_read_yw_trace,
"ping": cmd_ping,
}
def handle_command(self, cmd):
if not isinstance(cmd, dict) or "cmd" not in cmd:
raise InteractiveError('object with "cmd" key required')
name = cmd.get("cmd", None)
handler = self.cmd_handlers.get(name)
if handler:
return handler(self, cmd)
else:
raise InteractiveError(f"unknown command: {name}")
def mainloop(self):
self.setup()
while True:
try:
cmd = input().strip()
if not cmd or cmd.startswith("#") or cmd.startswith("//"):
continue
try:
cmd = json.loads(cmd)
except json.decoder.JSONDecodeError as e:
self.print_json(err=f"invalid JSON: {e}")
continue
except EOFError:
break
try:
result = self.handle_command(cmd)
except InteractiveError as e:
self.print_json(err=str(e))
continue
except Exception as e:
self.print_json(err=f"internal error: {e}")
raise
else:
self.print_json(ok=result)

View File

@ -33,10 +33,14 @@ def cli():
Display a Yosys witness trace in a human readable format.
""")
@click.argument("input", type=click.File("r"))
def display(input):
@click.option("--skip-x", help="Treat x bits as unassigned.", is_flag=True)
def display(input, skip_x):
click.echo(f"Reading Yosys witness trace {input.name!r}...")
inyw = ReadWitness(input)
if skip_x:
inyw.skip_x()
def output():
yield click.style("*** RTLIL bit-order below may differ from source level declarations ***", fg="red")
@ -91,7 +95,11 @@ If two or more inputs are provided they will be concatenated together into the o
@click.option("--append", "-p", type=int, multiple=True,
help="Number of steps (+ve or -ve) to append to end of input trace. "
+"Can be defined multiple times, following the same order as input traces. ")
def yw2yw(inputs, output, append):
@click.option("--skip-x", help="Leave input x bits unassigned.", is_flag=True)
def yw2yw(inputs, output, append, skip_x):
if len(inputs) == 0:
raise click.ClickException(f"no inputs specified")
outyw = WriteWitness(output, "yosys-witness yw2yw")
join_inputs = len(inputs) > 1
inyws = {}
@ -129,12 +137,12 @@ def yw2yw(inputs, output, append):
click.echo(f"Copying yosys witness trace from {input.name!r} to {output.name!r}...")
if first_witness:
outyw.step(init_values)
outyw.step(init_values, skip_x=skip_x)
else:
outyw.step(inyw.first_step())
outyw.step(inyw.first_step(), skip_x=skip_x)
for t, values in inyw.steps(1):
outyw.step(values)
outyw.step(values, skip_x=skip_x)
click.echo(f" copied {t + 1} time steps.")
first_witness = False
@ -174,7 +182,8 @@ This requires a Yosys witness AIGER map file as generated by 'write_aiger -ywmap
@click.argument("input", type=click.File("r"))
@click.argument("mapfile", type=click.File("r"))
@click.argument("output", type=click.File("w"))
def aiw2yw(input, mapfile, output):
@click.option("--skip-x", help="Leave input x bits unassigned.", is_flag=True)
def aiw2yw(input, mapfile, output, skip_x):
input_name = input.name
click.echo(f"Converting AIGER witness trace {input_name!r} to Yosys witness trace {output.name!r}...")
click.echo(f"Using Yosys witness AIGER map file {mapfile.name!r}")
@ -245,7 +254,7 @@ def aiw2yw(input, mapfile, output):
values[bit] = v
outyw.step(values)
outyw.step(values, skip_x=skip_x)
outyw.end_trace()

View File

@ -351,11 +351,14 @@ class WriteWitness:
self.out.name("steps")
self.out.begin_array()
def step(self, values):
def step(self, values, skip_x=False):
if not self.header_written:
self.write_header()
self.out.value({"bits": values.pack(self.sigmap)})
packed = values.pack(self.sigmap)
if skip_x:
packed = packed.replace('x', '?')
self.out.value({"bits": packed})
self.t += 1
@ -390,6 +393,9 @@ class ReadWitness:
self.bits = [step["bits"] for step in data["steps"]]
def skip_x(self):
self.bits = [step.replace('x', '?') for step in self.bits]
def init_step(self):
return self.step(0)