Merge pull request #4377 from jix/smtbmc-incremental-improvements

smtbmc: Improvements for --incremental and .yw fixes
This commit is contained in:
Miodrag Milanović 2024-05-07 21:35:10 +02:00 committed by GitHub
commit c9d87d5e7b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 284 additions and 97 deletions

View File

@ -199,7 +199,6 @@ def help():
--minimize-assumes
when using --track-assumes, solve for a minimal set of sufficient assumptions.
""" + so.helpmsg())
def usage():
@ -670,18 +669,12 @@ if aimfile is not None:
ywfile_hierwitness_cache = None
def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
def ywfile_hierwitness():
global ywfile_hierwitness_cache
if map_steps is None:
map_steps = {}
with open(inywfile, "r") as f:
inyw = ReadWitness(f)
if ywfile_hierwitness_cache is None:
ywfile_hierwitness_cache = smt.hierwitness(topmod, allregs=True, blackbox=True)
ywfile_hierwitness = smt.hierwitness(topmod, allregs=True, blackbox=True)
inits, seqs, clocks, mems = ywfile_hierwitness_cache
inits, seqs, clocks, mems = ywfile_hierwitness
smt_wires = defaultdict(list)
smt_mems = defaultdict(list)
@ -692,21 +685,54 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
for mem in mems:
smt_mems[mem["path"]].append(mem)
addr_re = re.compile(r'\\\[[0-9]+\]$')
bits_re = re.compile(r'[01?]*$')
ywfile_hierwitness_cache = inits, seqs, clocks, mems, smt_wires, smt_mems
max_t = -1
return ywfile_hierwitness_cache
for t, step in inyw.steps():
present_signals, missing = step.present_signals(inyw.sigmap)
for sig in present_signals:
bits = step[sig]
if skip_x:
bits = bits.replace('x', '?')
if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file")
def_bits_re = re.compile(r'([01]+)')
def smt_extract_mask(smt_expr, mask):
chunks = []
def_bits = ''
mask_index_order = mask[::-1]
for matched in def_bits_re.finditer(mask_index_order):
chunks.append(matched.span())
def_bits += matched[0]
if not chunks:
return
if len(chunks) == 1:
start, end = chunks[0]
if start == 0 and end == len(mask_index_order):
combined_chunks = smt_expr
else:
combined_chunks = '((_ extract %d %d) %s)' % (end - 1, start, smt_expr)
else:
combined_chunks = '(let ((x %s)) (concat %s))' % (smt_expr, ' '.join(
'((_ extract %d %d) x)' % (end - 1, start)
for start, end in reversed(chunks)
))
return combined_chunks, ''.join(mask_index_order[start:end] for start, end in chunks)[::-1]
def smt_concat(exprs):
if not exprs:
return ""
if len(exprs) == 1:
return exprs[1]
return "(concat %s)" % ' '.join(exprs)
def ywfile_signal(sig, step, mask=None):
assert sig.width > 0
inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness()
sig_end = sig.offset + sig.width
output = []
sig_end = sig.offset + len(bits)
if sig.path in smt_wires:
for wire in smt_wires[sig.path]:
width, offset = wire["width"], wire["offset"]
@ -721,37 +747,23 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
if common_end <= common_offset:
continue
smt_expr = smt.witness_net_expr(topmod, f"s{map_steps.get(t, t)}", wire)
smt_expr = smt.witness_net_expr(topmod, f"s{step}", wire)
if not smt_bool:
slice_high = common_end - offset - 1
slice_low = common_offset - offset
smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr)
bit_slice = bits[len(bits) - (common_end - sig.offset):len(bits) - (common_offset - sig.offset)]
if bit_slice.count("?") == len(bit_slice):
continue
if smt_bool:
assert width == 1
smt_constr = "(= %s %s)" % (smt_expr, "true" if bit_slice == "1" else "false")
else:
if "?" in bit_slice:
mask = bit_slice.replace("0", "1").replace("?", "0")
bit_slice = bit_slice.replace("?", "0")
smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
smt_expr = "(ite %s #b1 #b0)" % smt_expr
smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
constr_assumes[t].append((inywfile, smt_constr))
output.append(((common_offset - sig.offset), (common_end - sig.offset), smt_expr))
if sig.memory_path:
if sig.memory_path in smt_mems:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]
smt_expr = smt.net_expr(topmod, f"s{map_steps.get(t, t)}", mem["smtpath"])
smt_expr = smt.net_expr(topmod, f"s{step}", mem["smtpath"])
if bv:
word_low = sig.memory_addr * width
@ -762,21 +774,58 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
addr_bits = f"{sig.memory_addr:0{addr_width}b}"
smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits)
if len(bits) < width:
slice_high = sig.offset + len(bits) - 1
if sig.width < width:
slice_high = sig.offset + sig.width - 1
smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr)
bit_slice = bits
output.append((0, sig.width, smt_expr))
if "?" in bit_slice:
mask = bit_slice.replace("0", "1").replace("?", "0")
bit_slice = bit_slice.replace("?", "0")
smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
output.sort()
smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
output = [chunk for chunk in output if chunk[0] != chunk[1]]
pos = 0
for start, end, smt_expr in output:
assert start == pos
pos = end
assert pos == sig.width
if len(output) == 1:
return output[0][-1]
return smt_concat(smt_expr for start, end, smt_expr in reversed(output))
def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
global ywfile_hierwitness_cache
if map_steps is None:
map_steps = {}
with open(inywfile, "r") as f:
inyw = ReadWitness(f)
inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness()
bits_re = re.compile(r'[01?]*$')
max_t = -1
for t, step in inyw.steps():
present_signals, missing = step.present_signals(inyw.sigmap)
for sig in present_signals:
bits = step[sig]
if skip_x:
bits = bits.replace('x', '?')
if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file")
smt_expr = ywfile_signal(sig, map_steps.get(t, t))
smt_expr, bits = smt_extract_mask(smt_expr, bits)
smt_constr = "(= %s #b%s)" % (smt_expr, bits)
constr_assumes[t].append((inywfile, smt_constr))
max_t = t
max_t = t
return max_t
if inywfile is not None:
@ -1367,11 +1416,11 @@ def write_yw_trace(steps, index, allregs=False, filename=None):
exprs.extend(smt.witness_net_expr(topmod, f"s{k}", sig) for sig in sigs)
all_sigs.append(sigs)
all_sigs.append((step_values, sigs))
bvs = iter(smt.get_list(exprs))
for sigs in all_sigs:
for (step_values, sigs) in all_sigs:
for sig in sigs:
value = smt.bv2bin(next(bvs))
step_values[sig["sig"]] = value

View File

@ -1,7 +1,7 @@
from collections import defaultdict
import json
import typing
from functools import partial
import ywio
if typing.TYPE_CHECKING:
import smtbmc
@ -34,6 +34,7 @@ class Incremental:
self._witness_index = None
self._yw_constraints = {}
self._define_sorts = {}
def setup(self):
generic_assert_map = smtbmc.get_assert_map(
@ -175,11 +176,7 @@ class Incremental:
if len(expr) == 1:
smt_out.push({"and": "true", "or": "false"}[expr[0]])
elif len(expr) == 2:
arg_sort = self.expr(expr[1], smt_out)
if arg_sort != "Bool":
raise InteractiveError(
f"arguments of {json.dumps(expr[0])} must have sort Bool"
)
self.expr(expr[1], smt_out, required_sort="Bool")
else:
sep = f"({expr[0]} "
for arg in expr[1:]:
@ -189,7 +186,51 @@ class Incremental:
smt_out.append(")")
return "Bool"
def expr_bv_binop(self, expr, smt_out):
self.expr_arg_len(expr, 2)
smt_out.append(f"({expr[0]} ")
arg_sort = self.expr(expr[1], smt_out, required_sort=("BitVec", None))
smt_out.append(" ")
self.expr(expr[2], smt_out, required_sort=arg_sort)
smt_out.append(")")
return arg_sort
def expr_extract(self, expr, smt_out):
self.expr_arg_len(expr, 3)
hi = expr[1]
lo = expr[2]
smt_out.append(f"((_ extract {hi} {lo}) ")
arg_sort = self.expr(expr[3], smt_out, required_sort=("BitVec", None))
smt_out.append(")")
if not (isinstance(hi, int) and 0 <= hi < arg_sort[1]):
raise InteractiveError(
f"high bit index must be 0 <= index < {arg_sort[1]}, is {hi!r}"
)
if not (isinstance(lo, int) and 0 <= lo <= hi):
raise InteractiveError(
f"low bit index must be 0 <= index < {hi}, is {lo!r}"
)
return "BitVec", hi - lo + 1
def expr_bv(self, expr, smt_out):
self.expr_arg_len(expr, 1)
arg = expr[1]
if not isinstance(arg, str) or arg.count("0") + arg.count("1") != len(arg):
raise InteractiveError("bv argument must contain only 0 or 1 bits")
smt_out.append("#b" + arg)
return "BitVec", len(arg)
def expr_yw(self, expr, smt_out):
self.expr_arg_len(expr, 1, 2)
if len(expr) == 2:
name = None
step = expr[1]
@ -219,6 +260,40 @@ class Incremental:
return "Bool"
def expr_yw_sig(self, expr, smt_out):
self.expr_arg_len(expr, 3, 4)
step = expr[1]
path = expr[2]
offset = expr[3]
width = expr[4] if len(expr) == 5 else 1
if not isinstance(offset, int) or offset < 0:
raise InteractiveError(
f"offset must be a non-negative integer, got {json.dumps(offset)}"
)
if not isinstance(width, int) or width <= 0:
raise InteractiveError(
f"width must be a positive integer, got {json.dumps(width)}"
)
if not isinstance(path, list) or not all(isinstance(s, str) for s in path):
raise InteractiveError(
f"path must be a string list, got {json.dumps(path)}"
)
if step not in self.state_set:
raise InteractiveError(f"step {step} not declared")
smt_expr = smtbmc.ywfile_signal(
ywio.WitnessSig(path=path, offset=offset, width=width), step
)
smt_out.append(smt_expr)
return "BitVec", width
def expr_smtlib(self, expr, smt_out):
self.expr_arg_len(expr, 2)
@ -231,10 +306,15 @@ class Incremental:
f"got {json.dumps(smtlib_expr)}"
)
if not isinstance(sort, str):
raise InteractiveError(
f"raw SMT-LIB sort has to be a string, got {json.dumps(sort)}"
)
if (
isinstance(sort, list)
and len(sort) == 2
and sort[0] == "BitVec"
and (sort[1] is None or isinstance(sort[1], int))
):
sort = tuple(sort)
elif not isinstance(sort, str):
raise InteractiveError(f"unsupported raw SMT-LIB sort {json.dumps(sort)}")
smt_out.append(smtlib_expr)
return sort
@ -258,6 +338,14 @@ class Incremental:
return sort
def expr_def(self, expr, smt_out):
self.expr_arg_len(expr, 1)
sort = self._define_sorts.get(expr[1])
if sort is None:
raise InteractiveError(f"unknown definition {json.dumps(expr)}")
smt_out.append(expr[1])
return sort
expr_handlers = {
"step": expr_step,
"cell": expr_cell,
@ -270,8 +358,15 @@ class Incremental:
"not": expr_not,
"and": expr_andor,
"or": expr_andor,
"bv": expr_bv,
"bvand": expr_bv_binop,
"bvor": expr_bv_binop,
"bvxor": expr_bv_binop,
"extract": expr_extract,
"def": expr_def,
"=": expr_eq,
"yw": expr_yw,
"yw_sig": expr_yw_sig,
"smtlib": expr_smtlib,
"!": expr_label,
}
@ -305,10 +400,13 @@ class Incremental:
raise InteractiveError(f"unknown expression {json.dumps(expr[0])}")
def expr_smt(self, expr, required_sort):
return self.expr_smt_and_sort(expr, required_sort)[0]
def expr_smt_and_sort(self, expr, required_sort=None):
smt_out = []
self.expr(expr, smt_out, required_sort=required_sort)
output_sort = self.expr(expr, smt_out, required_sort=required_sort)
out = "".join(smt_out)
return out
return out, output_sort
def cmd_new_step(self, cmd):
step = self.arg_step(cmd, declare=True)
@ -338,7 +436,6 @@ class Incremental:
expr = cmd.get("expr")
key = cmd.get("key")
key = mkkey(key)
result = smtbmc.smt.smt2_assumptions.pop(key, None)
@ -348,7 +445,7 @@ class Incremental:
return result
def cmd_get_unsat_assumptions(self, cmd):
return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get('minimize')))
return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get("minimize")))
def cmd_push(self, cmd):
smtbmc.smt_push()
@ -370,6 +467,27 @@ class Incremental:
if response:
return smtbmc.smt.read()
def cmd_define(self, cmd):
expr = cmd.get("expr")
if expr is None:
raise InteractiveError("'define' copmmand requires 'expr' parameter")
expr, sort = self.expr_smt_and_sort(expr)
if isinstance(sort, tuple) and sort[0] == "module":
raise InteractiveError("'define' does not support module sorts")
define_name = f"|inc def {len(self._define_sorts)}|"
self._define_sorts[define_name] = sort
if isinstance(sort, tuple):
sort = f"(_ {' '.join(map(str, sort))})"
smtbmc.smt.write(f"(define-const {define_name} {sort} {expr})")
return {"name": define_name}
def cmd_design_hierwitness(self, cmd=None):
allregs = (cmd is None) or bool(cmd.get("allreges", False))
if self._cached_hierwitness[allregs] is not None:
@ -451,6 +569,7 @@ class Incremental:
"pop": cmd_pop,
"check": cmd_check,
"smtlib": cmd_smtlib,
"define": cmd_define,
"design_hierwitness": cmd_design_hierwitness,
"write_yw_trace": cmd_write_yw_trace,
"read_yw_trace": cmd_read_yw_trace,

View File

@ -160,6 +160,7 @@ class SmtIo:
self.noincr = opts.noincr
self.info_stmts = opts.info_stmts
self.nocomments = opts.nocomments
self.smt2_options.update(opts.smt2_options)
else:
self.solver = "yices"
@ -959,6 +960,8 @@ class SmtIo:
return int(self.bv2bin(v), 2)
def get_raw_unsat_assumptions(self):
if not self.smt2_assumptions:
return []
self.write("(get-unsat-assumptions)")
exprs = set(self.unparse(part) for part in self.parse(self.read()))
unsat_assumptions = []
@ -973,6 +976,10 @@ class SmtIo:
def get_unsat_assumptions(self, minimize=False):
if not minimize:
return self.get_raw_unsat_assumptions()
orig_assumptions = self.smt2_assumptions
self.smt2_assumptions = dict(orig_assumptions)
required_assumptions = {}
while True:
@ -998,6 +1005,7 @@ class SmtIo:
required_assumptions[candidate_key] = candidate_assume
if candidate_assumptions is not None:
self.smt2_assumptions = orig_assumptions
return list(required_assumptions)
def get(self, expr):
@ -1146,7 +1154,7 @@ class SmtIo:
class SmtOpts:
def __init__(self):
self.shortopts = "s:S:v"
self.longopts = ["unroll", "noincr", "noprogress", "timeout=", "dump-smt2=", "logic=", "dummy=", "info=", "nocomments"]
self.longopts = ["unroll", "noincr", "noprogress", "timeout=", "dump-smt2=", "logic=", "dummy=", "info=", "nocomments", "smt2-option="]
self.solver = "yices"
self.solver_opts = list()
self.debug_print = False
@ -1159,6 +1167,7 @@ class SmtOpts:
self.logic = None
self.info_stmts = list()
self.nocomments = False
self.smt2_options = {}
def handle(self, o, a):
if o == "-s":
@ -1185,6 +1194,13 @@ class SmtOpts:
self.info_stmts.append(a)
elif o == "--nocomments":
self.nocomments = True
elif o == "--smt2-option":
args = a.split('=', 1)
if len(args) != 2:
print("--smt2-option expects an <option>=<value> argument")
sys.exit(1)
option, value = args
self.smt2_options[option] = value
else:
return False
return True
@ -1208,6 +1224,9 @@ class SmtOpts:
if solver is "dummy", read solver output from that file
otherwise: write solver output to that file
--smt2-option <option>=<value>
enable an SMT-LIBv2 option.
-v
enable debug output