smtbmc: Improvements for --incremental and .yw fixes

This extends the experimental incremental JSON API to allow arbitrary
smtlib subexpressions, defining smtlib constants and to allow access of
signals by their .yw path.

It also fixes a bug during .yw writing where values would be re-emitted
in later cycles if they have no newer defined value and a potential
crash when using --track-assumes.
This commit is contained in:
Jannis Harder 2024-05-07 17:57:37 +02:00
parent 71f2540cd8
commit a52088b6af
3 changed files with 284 additions and 97 deletions

View File

@ -199,7 +199,6 @@ def help():
--minimize-assumes
when using --track-assumes, solve for a minimal set of sufficient assumptions.
""" + so.helpmsg())
def usage():
@ -670,18 +669,12 @@ if aimfile is not None:
ywfile_hierwitness_cache = None
def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
def ywfile_hierwitness():
global ywfile_hierwitness_cache
if map_steps is None:
map_steps = {}
if ywfile_hierwitness_cache is None:
ywfile_hierwitness = smt.hierwitness(topmod, allregs=True, blackbox=True)
with open(inywfile, "r") as f:
inyw = ReadWitness(f)
if ywfile_hierwitness_cache is None:
ywfile_hierwitness_cache = smt.hierwitness(topmod, allregs=True, blackbox=True)
inits, seqs, clocks, mems = ywfile_hierwitness_cache
inits, seqs, clocks, mems = ywfile_hierwitness
smt_wires = defaultdict(list)
smt_mems = defaultdict(list)
@ -692,9 +685,128 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
for mem in mems:
smt_mems[mem["path"]].append(mem)
addr_re = re.compile(r'\\\[[0-9]+\]$')
bits_re = re.compile(r'[01?]*$')
ywfile_hierwitness_cache = inits, seqs, clocks, mems, smt_wires, smt_mems
return ywfile_hierwitness_cache
def_bits_re = re.compile(r'([01]+)')
def smt_extract_mask(smt_expr, mask):
chunks = []
def_bits = ''
mask_index_order = mask[::-1]
for matched in def_bits_re.finditer(mask_index_order):
chunks.append(matched.span())
def_bits += matched[0]
if not chunks:
return
if len(chunks) == 1:
start, end = chunks[0]
if start == 0 and end == len(mask_index_order):
combined_chunks = smt_expr
else:
combined_chunks = '((_ extract %d %d) %s)' % (end - 1, start, smt_expr)
else:
combined_chunks = '(let ((x %s)) (concat %s))' % (smt_expr, ' '.join(
'((_ extract %d %d) x)' % (end - 1, start)
for start, end in reversed(chunks)
))
return combined_chunks, ''.join(mask_index_order[start:end] for start, end in chunks)[::-1]
def smt_concat(exprs):
if not exprs:
return ""
if len(exprs) == 1:
return exprs[1]
return "(concat %s)" % ' '.join(exprs)
def ywfile_signal(sig, step, mask=None):
assert sig.width > 0
inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness()
sig_end = sig.offset + sig.width
output = []
if sig.path in smt_wires:
for wire in smt_wires[sig.path]:
width, offset = wire["width"], wire["offset"]
smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1
offset = max(offset, 0)
end = width + offset
common_offset = max(sig.offset, offset)
common_end = min(sig_end, end)
if common_end <= common_offset:
continue
smt_expr = smt.witness_net_expr(topmod, f"s{step}", wire)
if not smt_bool:
slice_high = common_end - offset - 1
slice_low = common_offset - offset
smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr)
else:
smt_expr = "(ite %s #b1 #b0)" % smt_expr
output.append(((common_offset - sig.offset), (common_end - sig.offset), smt_expr))
if sig.memory_path:
if sig.memory_path in smt_mems:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]
smt_expr = smt.net_expr(topmod, f"s{step}", mem["smtpath"])
if bv:
word_low = sig.memory_addr * width
word_high = word_low + width - 1
smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr)
else:
addr_width = (size - 1).bit_length()
addr_bits = f"{sig.memory_addr:0{addr_width}b}"
smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits)
if sig.width < width:
slice_high = sig.offset + sig.width - 1
smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr)
output.append((0, sig.width, smt_expr))
output.sort()
output = [chunk for chunk in output if chunk[0] != chunk[1]]
pos = 0
for start, end, smt_expr in output:
assert start == pos
pos = end
assert pos == sig.width
if len(output) == 1:
return output[0][-1]
return smt_concat(smt_expr for start, end, smt_expr in reversed(output))
def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
global ywfile_hierwitness_cache
if map_steps is None:
map_steps = {}
with open(inywfile, "r") as f:
inyw = ReadWitness(f)
inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness()
bits_re = re.compile(r'[01?]*$')
max_t = -1
for t, step in inyw.steps():
@ -706,77 +818,14 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file")
sig_end = sig.offset + len(bits)
if sig.path in smt_wires:
for wire in smt_wires[sig.path]:
width, offset = wire["width"], wire["offset"]
smt_expr = ywfile_signal(sig, map_steps.get(t, t))
smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1
smt_expr, bits = smt_extract_mask(smt_expr, bits)
offset = max(offset, 0)
smt_constr = "(= %s #b%s)" % (smt_expr, bits)
constr_assumes[t].append((inywfile, smt_constr))
end = width + offset
common_offset = max(sig.offset, offset)
common_end = min(sig_end, end)
if common_end <= common_offset:
continue
smt_expr = smt.witness_net_expr(topmod, f"s{map_steps.get(t, t)}", wire)
if not smt_bool:
slice_high = common_end - offset - 1
slice_low = common_offset - offset
smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr)
bit_slice = bits[len(bits) - (common_end - sig.offset):len(bits) - (common_offset - sig.offset)]
if bit_slice.count("?") == len(bit_slice):
continue
if smt_bool:
assert width == 1
smt_constr = "(= %s %s)" % (smt_expr, "true" if bit_slice == "1" else "false")
else:
if "?" in bit_slice:
mask = bit_slice.replace("0", "1").replace("?", "0")
bit_slice = bit_slice.replace("?", "0")
smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
constr_assumes[t].append((inywfile, smt_constr))
if sig.memory_path:
if sig.memory_path in smt_mems:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]
smt_expr = smt.net_expr(topmod, f"s{map_steps.get(t, t)}", mem["smtpath"])
if bv:
word_low = sig.memory_addr * width
word_high = word_low + width - 1
smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr)
else:
addr_width = (size - 1).bit_length()
addr_bits = f"{sig.memory_addr:0{addr_width}b}"
smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits)
if len(bits) < width:
slice_high = sig.offset + len(bits) - 1
smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr)
bit_slice = bits
if "?" in bit_slice:
mask = bit_slice.replace("0", "1").replace("?", "0")
bit_slice = bit_slice.replace("?", "0")
smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
constr_assumes[t].append((inywfile, smt_constr))
max_t = t
return max_t
if inywfile is not None:
@ -1367,11 +1416,11 @@ def write_yw_trace(steps, index, allregs=False, filename=None):
exprs.extend(smt.witness_net_expr(topmod, f"s{k}", sig) for sig in sigs)
all_sigs.append(sigs)
all_sigs.append((step_values, sigs))
bvs = iter(smt.get_list(exprs))
for sigs in all_sigs:
for (step_values, sigs) in all_sigs:
for sig in sigs:
value = smt.bv2bin(next(bvs))
step_values[sig["sig"]] = value

View File

@ -1,7 +1,7 @@
from collections import defaultdict
import json
import typing
from functools import partial
import ywio
if typing.TYPE_CHECKING:
import smtbmc
@ -34,6 +34,7 @@ class Incremental:
self._witness_index = None
self._yw_constraints = {}
self._define_sorts = {}
def setup(self):
generic_assert_map = smtbmc.get_assert_map(
@ -175,11 +176,7 @@ class Incremental:
if len(expr) == 1:
smt_out.push({"and": "true", "or": "false"}[expr[0]])
elif len(expr) == 2:
arg_sort = self.expr(expr[1], smt_out)
if arg_sort != "Bool":
raise InteractiveError(
f"arguments of {json.dumps(expr[0])} must have sort Bool"
)
self.expr(expr[1], smt_out, required_sort="Bool")
else:
sep = f"({expr[0]} "
for arg in expr[1:]:
@ -189,7 +186,51 @@ class Incremental:
smt_out.append(")")
return "Bool"
def expr_bv_binop(self, expr, smt_out):
self.expr_arg_len(expr, 2)
smt_out.append(f"({expr[0]} ")
arg_sort = self.expr(expr[1], smt_out, required_sort=("BitVec", None))
smt_out.append(" ")
self.expr(expr[2], smt_out, required_sort=arg_sort)
smt_out.append(")")
return arg_sort
def expr_extract(self, expr, smt_out):
self.expr_arg_len(expr, 3)
hi = expr[1]
lo = expr[2]
smt_out.append(f"((_ extract {hi} {lo}) ")
arg_sort = self.expr(expr[3], smt_out, required_sort=("BitVec", None))
smt_out.append(")")
if not (isinstance(hi, int) and 0 <= hi < arg_sort[1]):
raise InteractiveError(
f"high bit index must be 0 <= index < {arg_sort[1]}, is {hi!r}"
)
if not (isinstance(lo, int) and 0 <= lo <= hi):
raise InteractiveError(
f"low bit index must be 0 <= index < {hi}, is {lo!r}"
)
return "BitVec", hi - lo + 1
def expr_bv(self, expr, smt_out):
self.expr_arg_len(expr, 1)
arg = expr[1]
if not isinstance(arg, str) or arg.count("0") + arg.count("1") != len(arg):
raise InteractiveError("bv argument must contain only 0 or 1 bits")
smt_out.append("#b" + arg)
return "BitVec", len(arg)
def expr_yw(self, expr, smt_out):
self.expr_arg_len(expr, 1, 2)
if len(expr) == 2:
name = None
step = expr[1]
@ -219,6 +260,40 @@ class Incremental:
return "Bool"
def expr_yw_sig(self, expr, smt_out):
self.expr_arg_len(expr, 3, 4)
step = expr[1]
path = expr[2]
offset = expr[3]
width = expr[4] if len(expr) == 5 else 1
if not isinstance(offset, int) or offset < 0:
raise InteractiveError(
f"offset must be a non-negative integer, got {json.dumps(offset)}"
)
if not isinstance(width, int) or width <= 0:
raise InteractiveError(
f"width must be a positive integer, got {json.dumps(width)}"
)
if not isinstance(path, list) or not all(isinstance(s, str) for s in path):
raise InteractiveError(
f"path must be a string list, got {json.dumps(path)}"
)
if step not in self.state_set:
raise InteractiveError(f"step {step} not declared")
smt_expr = smtbmc.ywfile_signal(
ywio.WitnessSig(path=path, offset=offset, width=width), step
)
smt_out.append(smt_expr)
return "BitVec", width
def expr_smtlib(self, expr, smt_out):
self.expr_arg_len(expr, 2)
@ -231,10 +306,15 @@ class Incremental:
f"got {json.dumps(smtlib_expr)}"
)
if not isinstance(sort, str):
raise InteractiveError(
f"raw SMT-LIB sort has to be a string, got {json.dumps(sort)}"
)
if (
isinstance(sort, list)
and len(sort) == 2
and sort[0] == "BitVec"
and (sort[1] is None or isinstance(sort[1], int))
):
sort = tuple(sort)
elif not isinstance(sort, str):
raise InteractiveError(f"unsupported raw SMT-LIB sort {json.dumps(sort)}")
smt_out.append(smtlib_expr)
return sort
@ -258,6 +338,14 @@ class Incremental:
return sort
def expr_def(self, expr, smt_out):
self.expr_arg_len(expr, 1)
sort = self._define_sorts.get(expr[1])
if sort is None:
raise InteractiveError(f"unknown definition {json.dumps(expr)}")
smt_out.append(expr[1])
return sort
expr_handlers = {
"step": expr_step,
"cell": expr_cell,
@ -270,8 +358,15 @@ class Incremental:
"not": expr_not,
"and": expr_andor,
"or": expr_andor,
"bv": expr_bv,
"bvand": expr_bv_binop,
"bvor": expr_bv_binop,
"bvxor": expr_bv_binop,
"extract": expr_extract,
"def": expr_def,
"=": expr_eq,
"yw": expr_yw,
"yw_sig": expr_yw_sig,
"smtlib": expr_smtlib,
"!": expr_label,
}
@ -305,10 +400,13 @@ class Incremental:
raise InteractiveError(f"unknown expression {json.dumps(expr[0])}")
def expr_smt(self, expr, required_sort):
return self.expr_smt_and_sort(expr, required_sort)[0]
def expr_smt_and_sort(self, expr, required_sort=None):
smt_out = []
self.expr(expr, smt_out, required_sort=required_sort)
output_sort = self.expr(expr, smt_out, required_sort=required_sort)
out = "".join(smt_out)
return out
return out, output_sort
def cmd_new_step(self, cmd):
step = self.arg_step(cmd, declare=True)
@ -338,7 +436,6 @@ class Incremental:
expr = cmd.get("expr")
key = cmd.get("key")
key = mkkey(key)
result = smtbmc.smt.smt2_assumptions.pop(key, None)
@ -348,7 +445,7 @@ class Incremental:
return result
def cmd_get_unsat_assumptions(self, cmd):
return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get('minimize')))
return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get("minimize")))
def cmd_push(self, cmd):
smtbmc.smt_push()
@ -370,6 +467,27 @@ class Incremental:
if response:
return smtbmc.smt.read()
def cmd_define(self, cmd):
expr = cmd.get("expr")
if expr is None:
raise InteractiveError("'define' copmmand requires 'expr' parameter")
expr, sort = self.expr_smt_and_sort(expr)
if isinstance(sort, tuple) and sort[0] == "module":
raise InteractiveError("'define' does not support module sorts")
define_name = f"|inc def {len(self._define_sorts)}|"
self._define_sorts[define_name] = sort
if isinstance(sort, tuple):
sort = f"(_ {' '.join(map(str, sort))})"
smtbmc.smt.write(f"(define-const {define_name} {sort} {expr})")
return {"name": define_name}
def cmd_design_hierwitness(self, cmd=None):
allregs = (cmd is None) or bool(cmd.get("allreges", False))
if self._cached_hierwitness[allregs] is not None:
@ -451,6 +569,7 @@ class Incremental:
"pop": cmd_pop,
"check": cmd_check,
"smtlib": cmd_smtlib,
"define": cmd_define,
"design_hierwitness": cmd_design_hierwitness,
"write_yw_trace": cmd_write_yw_trace,
"read_yw_trace": cmd_read_yw_trace,

View File

@ -160,6 +160,7 @@ class SmtIo:
self.noincr = opts.noincr
self.info_stmts = opts.info_stmts
self.nocomments = opts.nocomments
self.smt2_options.update(opts.smt2_options)
else:
self.solver = "yices"
@ -959,6 +960,8 @@ class SmtIo:
return int(self.bv2bin(v), 2)
def get_raw_unsat_assumptions(self):
if not self.smt2_assumptions:
return []
self.write("(get-unsat-assumptions)")
exprs = set(self.unparse(part) for part in self.parse(self.read()))
unsat_assumptions = []
@ -973,6 +976,10 @@ class SmtIo:
def get_unsat_assumptions(self, minimize=False):
if not minimize:
return self.get_raw_unsat_assumptions()
orig_assumptions = self.smt2_assumptions
self.smt2_assumptions = dict(orig_assumptions)
required_assumptions = {}
while True:
@ -998,6 +1005,7 @@ class SmtIo:
required_assumptions[candidate_key] = candidate_assume
if candidate_assumptions is not None:
self.smt2_assumptions = orig_assumptions
return list(required_assumptions)
def get(self, expr):
@ -1146,7 +1154,7 @@ class SmtIo:
class SmtOpts:
def __init__(self):
self.shortopts = "s:S:v"
self.longopts = ["unroll", "noincr", "noprogress", "timeout=", "dump-smt2=", "logic=", "dummy=", "info=", "nocomments"]
self.longopts = ["unroll", "noincr", "noprogress", "timeout=", "dump-smt2=", "logic=", "dummy=", "info=", "nocomments", "smt2-option="]
self.solver = "yices"
self.solver_opts = list()
self.debug_print = False
@ -1159,6 +1167,7 @@ class SmtOpts:
self.logic = None
self.info_stmts = list()
self.nocomments = False
self.smt2_options = {}
def handle(self, o, a):
if o == "-s":
@ -1185,6 +1194,13 @@ class SmtOpts:
self.info_stmts.append(a)
elif o == "--nocomments":
self.nocomments = True
elif o == "--smt2-option":
args = a.split('=', 1)
if len(args) != 2:
print("--smt2-option expects an <option>=<value> argument")
sys.exit(1)
option, value = args
self.smt2_options[option] = value
else:
return False
return True
@ -1208,6 +1224,9 @@ class SmtOpts:
if solver is "dummy", read solver output from that file
otherwise: write solver output to that file
--smt2-option <option>=<value>
enable an SMT-LIBv2 option.
-v
enable debug output