diff --git a/backends/smt2/smtbmc.py b/backends/smt2/smtbmc.py index e6b4088db..995a714c9 100644 --- a/backends/smt2/smtbmc.py +++ b/backends/smt2/smtbmc.py @@ -199,7 +199,6 @@ def help(): --minimize-assumes when using --track-assumes, solve for a minimal set of sufficient assumptions. - """ + so.helpmsg()) def usage(): @@ -670,18 +669,12 @@ if aimfile is not None: ywfile_hierwitness_cache = None -def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False): +def ywfile_hierwitness(): global ywfile_hierwitness_cache - if map_steps is None: - map_steps = {} + if ywfile_hierwitness_cache is None: + ywfile_hierwitness = smt.hierwitness(topmod, allregs=True, blackbox=True) - with open(inywfile, "r") as f: - inyw = ReadWitness(f) - - if ywfile_hierwitness_cache is None: - ywfile_hierwitness_cache = smt.hierwitness(topmod, allregs=True, blackbox=True) - - inits, seqs, clocks, mems = ywfile_hierwitness_cache + inits, seqs, clocks, mems = ywfile_hierwitness smt_wires = defaultdict(list) smt_mems = defaultdict(list) @@ -692,9 +685,128 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False): for mem in mems: smt_mems[mem["path"]].append(mem) - addr_re = re.compile(r'\\\[[0-9]+\]$') - bits_re = re.compile(r'[01?]*$') + ywfile_hierwitness_cache = inits, seqs, clocks, mems, smt_wires, smt_mems + return ywfile_hierwitness_cache + +def_bits_re = re.compile(r'([01]+)') + +def smt_extract_mask(smt_expr, mask): + chunks = [] + def_bits = '' + + mask_index_order = mask[::-1] + + for matched in def_bits_re.finditer(mask_index_order): + chunks.append(matched.span()) + def_bits += matched[0] + + if not chunks: + return + + if len(chunks) == 1: + start, end = chunks[0] + if start == 0 and end == len(mask_index_order): + combined_chunks = smt_expr + else: + combined_chunks = '((_ extract %d %d) %s)' % (end - 1, start, smt_expr) + else: + combined_chunks = '(let ((x %s)) (concat %s))' % (smt_expr, ' '.join( + '((_ extract %d %d) x)' % (end - 1, start) + for start, end in reversed(chunks) + )) + + return combined_chunks, ''.join(mask_index_order[start:end] for start, end in chunks)[::-1] + +def smt_concat(exprs): + if not exprs: + return "" + if len(exprs) == 1: + return exprs[1] + return "(concat %s)" % ' '.join(exprs) + +def ywfile_signal(sig, step, mask=None): + assert sig.width > 0 + + inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness() + sig_end = sig.offset + sig.width + + output = [] + + if sig.path in smt_wires: + for wire in smt_wires[sig.path]: + width, offset = wire["width"], wire["offset"] + + smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1 + + offset = max(offset, 0) + + end = width + offset + common_offset = max(sig.offset, offset) + common_end = min(sig_end, end) + if common_end <= common_offset: + continue + + smt_expr = smt.witness_net_expr(topmod, f"s{step}", wire) + + if not smt_bool: + slice_high = common_end - offset - 1 + slice_low = common_offset - offset + smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr) + else: + smt_expr = "(ite %s #b1 #b0)" % smt_expr + + output.append(((common_offset - sig.offset), (common_end - sig.offset), smt_expr)) + + if sig.memory_path: + if sig.memory_path in smt_mems: + for mem in smt_mems[sig.memory_path]: + width, size, bv = mem["width"], mem["size"], mem["statebv"] + + smt_expr = smt.net_expr(topmod, f"s{step}", mem["smtpath"]) + + if bv: + word_low = sig.memory_addr * width + word_high = word_low + width - 1 + smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr) + else: + addr_width = (size - 1).bit_length() + addr_bits = f"{sig.memory_addr:0{addr_width}b}" + smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits) + + if sig.width < width: + slice_high = sig.offset + sig.width - 1 + smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr) + + output.append((0, sig.width, smt_expr)) + + output.sort() + + output = [chunk for chunk in output if chunk[0] != chunk[1]] + + pos = 0 + + for start, end, smt_expr in output: + assert start == pos + pos = end + + assert pos == sig.width + + if len(output) == 1: + return output[0][-1] + return smt_concat(smt_expr for start, end, smt_expr in reversed(output)) + +def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False): + global ywfile_hierwitness_cache + if map_steps is None: + map_steps = {} + + with open(inywfile, "r") as f: + inyw = ReadWitness(f) + + inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness() + + bits_re = re.compile(r'[01?]*$') max_t = -1 for t, step in inyw.steps(): @@ -706,77 +818,14 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False): if not bits_re.match(bits): raise ValueError("unsupported bit value in Yosys witness file") - sig_end = sig.offset + len(bits) - if sig.path in smt_wires: - for wire in smt_wires[sig.path]: - width, offset = wire["width"], wire["offset"] + smt_expr = ywfile_signal(sig, map_steps.get(t, t)) - smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1 + smt_expr, bits = smt_extract_mask(smt_expr, bits) - offset = max(offset, 0) + smt_constr = "(= %s #b%s)" % (smt_expr, bits) + constr_assumes[t].append((inywfile, smt_constr)) - end = width + offset - common_offset = max(sig.offset, offset) - common_end = min(sig_end, end) - if common_end <= common_offset: - continue - - smt_expr = smt.witness_net_expr(topmod, f"s{map_steps.get(t, t)}", wire) - - if not smt_bool: - slice_high = common_end - offset - 1 - slice_low = common_offset - offset - smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr) - - bit_slice = bits[len(bits) - (common_end - sig.offset):len(bits) - (common_offset - sig.offset)] - - if bit_slice.count("?") == len(bit_slice): - continue - - if smt_bool: - assert width == 1 - smt_constr = "(= %s %s)" % (smt_expr, "true" if bit_slice == "1" else "false") - else: - if "?" in bit_slice: - mask = bit_slice.replace("0", "1").replace("?", "0") - bit_slice = bit_slice.replace("?", "0") - smt_expr = "(bvand %s #b%s)" % (smt_expr, mask) - - smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice) - - constr_assumes[t].append((inywfile, smt_constr)) - - if sig.memory_path: - if sig.memory_path in smt_mems: - for mem in smt_mems[sig.memory_path]: - width, size, bv = mem["width"], mem["size"], mem["statebv"] - - smt_expr = smt.net_expr(topmod, f"s{map_steps.get(t, t)}", mem["smtpath"]) - - if bv: - word_low = sig.memory_addr * width - word_high = word_low + width - 1 - smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr) - else: - addr_width = (size - 1).bit_length() - addr_bits = f"{sig.memory_addr:0{addr_width}b}" - smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits) - - if len(bits) < width: - slice_high = sig.offset + len(bits) - 1 - smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr) - - bit_slice = bits - - if "?" in bit_slice: - mask = bit_slice.replace("0", "1").replace("?", "0") - bit_slice = bit_slice.replace("?", "0") - smt_expr = "(bvand %s #b%s)" % (smt_expr, mask) - - smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice) - constr_assumes[t].append((inywfile, smt_constr)) max_t = t - return max_t if inywfile is not None: @@ -1367,11 +1416,11 @@ def write_yw_trace(steps, index, allregs=False, filename=None): exprs.extend(smt.witness_net_expr(topmod, f"s{k}", sig) for sig in sigs) - all_sigs.append(sigs) + all_sigs.append((step_values, sigs)) bvs = iter(smt.get_list(exprs)) - for sigs in all_sigs: + for (step_values, sigs) in all_sigs: for sig in sigs: value = smt.bv2bin(next(bvs)) step_values[sig["sig"]] = value diff --git a/backends/smt2/smtbmc_incremental.py b/backends/smt2/smtbmc_incremental.py index f43e878f3..0bd280b4a 100644 --- a/backends/smt2/smtbmc_incremental.py +++ b/backends/smt2/smtbmc_incremental.py @@ -1,7 +1,7 @@ from collections import defaultdict import json import typing -from functools import partial +import ywio if typing.TYPE_CHECKING: import smtbmc @@ -34,6 +34,7 @@ class Incremental: self._witness_index = None self._yw_constraints = {} + self._define_sorts = {} def setup(self): generic_assert_map = smtbmc.get_assert_map( @@ -175,11 +176,7 @@ class Incremental: if len(expr) == 1: smt_out.push({"and": "true", "or": "false"}[expr[0]]) elif len(expr) == 2: - arg_sort = self.expr(expr[1], smt_out) - if arg_sort != "Bool": - raise InteractiveError( - f"arguments of {json.dumps(expr[0])} must have sort Bool" - ) + self.expr(expr[1], smt_out, required_sort="Bool") else: sep = f"({expr[0]} " for arg in expr[1:]: @@ -189,7 +186,51 @@ class Incremental: smt_out.append(")") return "Bool" + def expr_bv_binop(self, expr, smt_out): + self.expr_arg_len(expr, 2) + + smt_out.append(f"({expr[0]} ") + arg_sort = self.expr(expr[1], smt_out, required_sort=("BitVec", None)) + smt_out.append(" ") + self.expr(expr[2], smt_out, required_sort=arg_sort) + smt_out.append(")") + return arg_sort + + def expr_extract(self, expr, smt_out): + self.expr_arg_len(expr, 3) + + hi = expr[1] + lo = expr[2] + + smt_out.append(f"((_ extract {hi} {lo}) ") + + arg_sort = self.expr(expr[3], smt_out, required_sort=("BitVec", None)) + smt_out.append(")") + + if not (isinstance(hi, int) and 0 <= hi < arg_sort[1]): + raise InteractiveError( + f"high bit index must be 0 <= index < {arg_sort[1]}, is {hi!r}" + ) + if not (isinstance(lo, int) and 0 <= lo <= hi): + raise InteractiveError( + f"low bit index must be 0 <= index < {hi}, is {lo!r}" + ) + + return "BitVec", hi - lo + 1 + + def expr_bv(self, expr, smt_out): + self.expr_arg_len(expr, 1) + + arg = expr[1] + if not isinstance(arg, str) or arg.count("0") + arg.count("1") != len(arg): + raise InteractiveError("bv argument must contain only 0 or 1 bits") + + smt_out.append("#b" + arg) + + return "BitVec", len(arg) + def expr_yw(self, expr, smt_out): + self.expr_arg_len(expr, 1, 2) if len(expr) == 2: name = None step = expr[1] @@ -219,6 +260,40 @@ class Incremental: return "Bool" + def expr_yw_sig(self, expr, smt_out): + self.expr_arg_len(expr, 3, 4) + + step = expr[1] + path = expr[2] + offset = expr[3] + width = expr[4] if len(expr) == 5 else 1 + + if not isinstance(offset, int) or offset < 0: + raise InteractiveError( + f"offset must be a non-negative integer, got {json.dumps(offset)}" + ) + + if not isinstance(width, int) or width <= 0: + raise InteractiveError( + f"width must be a positive integer, got {json.dumps(width)}" + ) + + if not isinstance(path, list) or not all(isinstance(s, str) for s in path): + raise InteractiveError( + f"path must be a string list, got {json.dumps(path)}" + ) + + if step not in self.state_set: + raise InteractiveError(f"step {step} not declared") + + smt_expr = smtbmc.ywfile_signal( + ywio.WitnessSig(path=path, offset=offset, width=width), step + ) + + smt_out.append(smt_expr) + + return "BitVec", width + def expr_smtlib(self, expr, smt_out): self.expr_arg_len(expr, 2) @@ -231,10 +306,15 @@ class Incremental: f"got {json.dumps(smtlib_expr)}" ) - if not isinstance(sort, str): - raise InteractiveError( - f"raw SMT-LIB sort has to be a string, got {json.dumps(sort)}" - ) + if ( + isinstance(sort, list) + and len(sort) == 2 + and sort[0] == "BitVec" + and (sort[1] is None or isinstance(sort[1], int)) + ): + sort = tuple(sort) + elif not isinstance(sort, str): + raise InteractiveError(f"unsupported raw SMT-LIB sort {json.dumps(sort)}") smt_out.append(smtlib_expr) return sort @@ -258,6 +338,14 @@ class Incremental: return sort + def expr_def(self, expr, smt_out): + self.expr_arg_len(expr, 1) + sort = self._define_sorts.get(expr[1]) + if sort is None: + raise InteractiveError(f"unknown definition {json.dumps(expr)}") + smt_out.append(expr[1]) + return sort + expr_handlers = { "step": expr_step, "cell": expr_cell, @@ -270,8 +358,15 @@ class Incremental: "not": expr_not, "and": expr_andor, "or": expr_andor, + "bv": expr_bv, + "bvand": expr_bv_binop, + "bvor": expr_bv_binop, + "bvxor": expr_bv_binop, + "extract": expr_extract, + "def": expr_def, "=": expr_eq, "yw": expr_yw, + "yw_sig": expr_yw_sig, "smtlib": expr_smtlib, "!": expr_label, } @@ -305,10 +400,13 @@ class Incremental: raise InteractiveError(f"unknown expression {json.dumps(expr[0])}") def expr_smt(self, expr, required_sort): + return self.expr_smt_and_sort(expr, required_sort)[0] + + def expr_smt_and_sort(self, expr, required_sort=None): smt_out = [] - self.expr(expr, smt_out, required_sort=required_sort) + output_sort = self.expr(expr, smt_out, required_sort=required_sort) out = "".join(smt_out) - return out + return out, output_sort def cmd_new_step(self, cmd): step = self.arg_step(cmd, declare=True) @@ -338,7 +436,6 @@ class Incremental: expr = cmd.get("expr") key = cmd.get("key") - key = mkkey(key) result = smtbmc.smt.smt2_assumptions.pop(key, None) @@ -348,7 +445,7 @@ class Incremental: return result def cmd_get_unsat_assumptions(self, cmd): - return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get('minimize'))) + return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get("minimize"))) def cmd_push(self, cmd): smtbmc.smt_push() @@ -370,6 +467,27 @@ class Incremental: if response: return smtbmc.smt.read() + def cmd_define(self, cmd): + expr = cmd.get("expr") + if expr is None: + raise InteractiveError("'define' copmmand requires 'expr' parameter") + + expr, sort = self.expr_smt_and_sort(expr) + + if isinstance(sort, tuple) and sort[0] == "module": + raise InteractiveError("'define' does not support module sorts") + + define_name = f"|inc def {len(self._define_sorts)}|" + + self._define_sorts[define_name] = sort + + if isinstance(sort, tuple): + sort = f"(_ {' '.join(map(str, sort))})" + + smtbmc.smt.write(f"(define-const {define_name} {sort} {expr})") + + return {"name": define_name} + def cmd_design_hierwitness(self, cmd=None): allregs = (cmd is None) or bool(cmd.get("allreges", False)) if self._cached_hierwitness[allregs] is not None: @@ -451,6 +569,7 @@ class Incremental: "pop": cmd_pop, "check": cmd_check, "smtlib": cmd_smtlib, + "define": cmd_define, "design_hierwitness": cmd_design_hierwitness, "write_yw_trace": cmd_write_yw_trace, "read_yw_trace": cmd_read_yw_trace, diff --git a/backends/smt2/smtio.py b/backends/smt2/smtio.py index e32f43c60..5fc3ab5a4 100644 --- a/backends/smt2/smtio.py +++ b/backends/smt2/smtio.py @@ -160,6 +160,7 @@ class SmtIo: self.noincr = opts.noincr self.info_stmts = opts.info_stmts self.nocomments = opts.nocomments + self.smt2_options.update(opts.smt2_options) else: self.solver = "yices" @@ -959,6 +960,8 @@ class SmtIo: return int(self.bv2bin(v), 2) def get_raw_unsat_assumptions(self): + if not self.smt2_assumptions: + return [] self.write("(get-unsat-assumptions)") exprs = set(self.unparse(part) for part in self.parse(self.read())) unsat_assumptions = [] @@ -973,6 +976,10 @@ class SmtIo: def get_unsat_assumptions(self, minimize=False): if not minimize: return self.get_raw_unsat_assumptions() + orig_assumptions = self.smt2_assumptions + + self.smt2_assumptions = dict(orig_assumptions) + required_assumptions = {} while True: @@ -998,6 +1005,7 @@ class SmtIo: required_assumptions[candidate_key] = candidate_assume if candidate_assumptions is not None: + self.smt2_assumptions = orig_assumptions return list(required_assumptions) def get(self, expr): @@ -1146,7 +1154,7 @@ class SmtIo: class SmtOpts: def __init__(self): self.shortopts = "s:S:v" - self.longopts = ["unroll", "noincr", "noprogress", "timeout=", "dump-smt2=", "logic=", "dummy=", "info=", "nocomments"] + self.longopts = ["unroll", "noincr", "noprogress", "timeout=", "dump-smt2=", "logic=", "dummy=", "info=", "nocomments", "smt2-option="] self.solver = "yices" self.solver_opts = list() self.debug_print = False @@ -1159,6 +1167,7 @@ class SmtOpts: self.logic = None self.info_stmts = list() self.nocomments = False + self.smt2_options = {} def handle(self, o, a): if o == "-s": @@ -1185,6 +1194,13 @@ class SmtOpts: self.info_stmts.append(a) elif o == "--nocomments": self.nocomments = True + elif o == "--smt2-option": + args = a.split('=', 1) + if len(args) != 2: + print("--smt2-option expects an