mirror of https://github.com/YosysHQ/yosys.git
180 lines
7.0 KiB
Python
180 lines
7.0 KiB
Python
|
import sys
|
||
|
import argparse
|
||
|
import random
|
||
|
import os
|
||
|
import smtio
|
||
|
import re
|
||
|
|
||
|
class SExprParserError(Exception):
|
||
|
pass
|
||
|
|
||
|
class SExprParser:
|
||
|
def __init__(self):
|
||
|
self.peekbuf = None
|
||
|
self.stack = [[]]
|
||
|
self.atom_pattern = re.compile(r'[a-zA-Z0-9~!@$%^&*_\-+=<>.?/#]+')
|
||
|
def parse_line(self, line):
|
||
|
ptr = 0
|
||
|
while ptr < len(line):
|
||
|
if line[ptr].isspace():
|
||
|
ptr += 1
|
||
|
elif line[ptr] == ';':
|
||
|
break
|
||
|
elif line[ptr] == '(':
|
||
|
ptr += 1
|
||
|
self.stack.append([])
|
||
|
elif line[ptr] == ')':
|
||
|
ptr += 1
|
||
|
assert len(self.stack) > 1, "too many closed parentheses"
|
||
|
v = self.stack.pop()
|
||
|
self.stack[-1].append(v)
|
||
|
else:
|
||
|
match = self.atom_pattern.match(line, ptr)
|
||
|
if match is None:
|
||
|
raise SExprParserError(f"invalid character '{line[ptr]}' in line '{line}'")
|
||
|
start, ptr = match.span()
|
||
|
self.stack[-1].append(line[start:ptr])
|
||
|
def finish(self):
|
||
|
assert len(self.stack) == 1, "too many open parentheses"
|
||
|
def retrieve(self):
|
||
|
rv, self.stack[0] = self.stack[0], []
|
||
|
return rv
|
||
|
|
||
|
def simulate_smt_with_smtio(smt_file_path, vcd_path, smt_io):
|
||
|
inputs = {}
|
||
|
outputs = {}
|
||
|
|
||
|
def handle_datatype(lst):
|
||
|
print(lst)
|
||
|
datatype_name = lst[1]
|
||
|
declarations = lst[2][0][1:] # Skip the first item (e.g., 'mk_inputs')
|
||
|
if datatype_name.endswith("_Inputs"):
|
||
|
for declaration in declarations:
|
||
|
input_name = declaration[0]
|
||
|
bitvec_size = declaration[1][2]
|
||
|
assert input_name.startswith("gold_Inputs_")
|
||
|
inputs[input_name[len("gold_Inputs_"):]] = int(bitvec_size)
|
||
|
elif datatype_name.endswith("_Outputs"):
|
||
|
for declaration in declarations:
|
||
|
output_name = declaration[0]
|
||
|
bitvec_size = declaration[1][2]
|
||
|
assert output_name.startswith("gold_Outputs_")
|
||
|
outputs[output_name[len("gold_Outputs_"):]] = int(bitvec_size)
|
||
|
|
||
|
parser = SExprParser()
|
||
|
with open(smt_file_path, 'r') as smt_file:
|
||
|
for line in smt_file:
|
||
|
parser.parse_line(line)
|
||
|
for expr in parser.retrieve():
|
||
|
smt_io.write(smt_io.unparse(expr))
|
||
|
if expr[0] == "declare-datatype":
|
||
|
handle_datatype(expr)
|
||
|
|
||
|
parser.finish()
|
||
|
assert smt_io.check_sat() == 'sat'
|
||
|
|
||
|
def set_step(inputs, step):
|
||
|
# This function assumes 'inputs' is a dictionary like {"A": 5, "B": 4}
|
||
|
# and 'input_values' is a dictionary like {"A": 5, "B": 13} specifying the concrete values for each input.
|
||
|
|
||
|
mk_inputs_parts = []
|
||
|
for input_name, width in inputs.items():
|
||
|
value = random.getrandbits(width) # Generate a random number up to the maximum value for the bit size
|
||
|
binary_string = format(value, '0{}b'.format(width)) # Convert value to binary with leading zeros
|
||
|
mk_inputs_parts.append(f"#b{binary_string}")
|
||
|
|
||
|
mk_inputs_call = "gold_Inputs " + " ".join(mk_inputs_parts)
|
||
|
define_inputs = f"(define-const test_inputs_step_n{step} gold_Inputs ({mk_inputs_call}))\n"
|
||
|
|
||
|
define_outputs = f"(define-const test_outputs_step_n{step} gold_Outputs (first (gold test_inputs_step_n{step} gold_State)))\n"
|
||
|
smt_commands = [define_inputs, define_outputs]
|
||
|
return smt_commands
|
||
|
|
||
|
num_steps = 1000
|
||
|
smt_commands = []
|
||
|
for step in range(num_steps):
|
||
|
for step_command in set_step(inputs, step):
|
||
|
smt_commands.append(step_command)
|
||
|
|
||
|
for command in smt_commands:
|
||
|
smt_io.write(command)
|
||
|
|
||
|
assert smt_io.check_sat() == 'sat'
|
||
|
|
||
|
# Store signal values
|
||
|
signals = {name: [] for name in list(inputs.keys()) + list(outputs.keys())}
|
||
|
# Retrieve and print values for each state
|
||
|
def hex_to_bin(value):
|
||
|
if value.startswith('x'):
|
||
|
hex_value = value[1:] # Remove the 'x' prefix
|
||
|
bin_value = bin(int(hex_value, 16))[2:] # Convert to binary and remove the '0b' prefix
|
||
|
return f'b{bin_value.zfill(len(hex_value) * 4)}' # Add 'b' prefix and pad with zeros
|
||
|
return value
|
||
|
|
||
|
combined_assertions = []
|
||
|
for step in range(num_steps):
|
||
|
print(f"Values for step {step + 1}:")
|
||
|
for input_name, width in inputs.items():
|
||
|
value = smt_io.get(f'(gold_Inputs_{input_name} test_inputs_step_n{step})')
|
||
|
value = hex_to_bin(value[1:])
|
||
|
print(f" {input_name}: {value}")
|
||
|
signals[input_name].append((step, value))
|
||
|
for output_name, width in outputs.items():
|
||
|
value = smt_io.get(f'(gold_Outputs_{output_name} test_outputs_step_n{step})')
|
||
|
value = hex_to_bin(value[1:])
|
||
|
print(f" {output_name}: {value}")
|
||
|
signals[output_name].append((step, value))
|
||
|
combined_assertions.append(f'(= (gold_Outputs_{output_name} test_outputs_step_n{step}) #{value})')
|
||
|
# Create a single assertion covering all timesteps
|
||
|
combined_condition = " ".join(combined_assertions)
|
||
|
smt_io.write(f'(assert (not (and {combined_condition})))')
|
||
|
|
||
|
# Check the combined assertion
|
||
|
assert smt_io.check_sat(["unsat"]) == "unsat"
|
||
|
|
||
|
def write_vcd(filename, signals, timescale='1 ns', date='today'):
|
||
|
with open(filename, 'w') as f:
|
||
|
# Write the header
|
||
|
f.write(f"$date\n {date}\n$end\n")
|
||
|
f.write(f"$timescale {timescale} $end\n")
|
||
|
|
||
|
# Declare signals
|
||
|
f.write("$scope module gold $end\n")
|
||
|
for signal_name, changes in signals.items():
|
||
|
signal_size = len(changes[0][1])
|
||
|
f.write(f"$var wire {signal_size - 1} {signal_name} {signal_name} $end\n")
|
||
|
f.write("$upscope $end\n")
|
||
|
f.write("$enddefinitions $end\n")
|
||
|
|
||
|
# Collect all unique timestamps
|
||
|
timestamps = sorted(set(time for changes in signals.values() for time, _ in changes))
|
||
|
|
||
|
# Write initial values
|
||
|
f.write("#0\n")
|
||
|
for signal_name, changes in signals.items():
|
||
|
for time, value in changes:
|
||
|
if time == 0:
|
||
|
f.write(f"{value} {signal_name}\n")
|
||
|
|
||
|
# Write value changes
|
||
|
for time in timestamps:
|
||
|
if time != 0:
|
||
|
f.write(f"#{time}\n")
|
||
|
for signal_name, changes in signals.items():
|
||
|
for change_time, value in changes:
|
||
|
if change_time == time:
|
||
|
f.write(f"{value} {signal_name}\n")
|
||
|
|
||
|
|
||
|
write_vcd(vcd_path, signals)
|
||
|
|
||
|
def simulate_smt(smt_file_path, vcd_path):
|
||
|
so = smtio.SmtOpts()
|
||
|
so.solver = "z3"
|
||
|
so.logic = "BV"
|
||
|
so.debug_print = True
|
||
|
smt_io = smtio.SmtIo(opts=so)
|
||
|
try:
|
||
|
simulate_smt_with_smtio(smt_file_path, vcd_path, smt_io)
|
||
|
finally:
|
||
|
smt_io.p_close()
|