#####################################################################
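# Python script to check if heterogeneous blocks, e.g., RAM and multipliers,
# have been inferred during the OpenFPGA flow
# This script will
#   - Check the .csv file generated by openfpga task-run to find out
#     the number of each type of heterogeneous blocks
#   - Compare every metric listed in a metric checklist .csv file against
#     a reference .csv file, and fail if any value falls outside the
#     allowed tolerance range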
#####################################################################

import os
from os.path import dirname, abspath, isfile
import shutil
import re
import argparse
import logging
import csv

#####################################################################
# Constants
#####################################################################
# Column headers looked up in the csv files:
# - "name": benchmark name column (flow-run and reference csv files)
# - "metric": metric name column (metric checklist csv file)
csv_name_tag, csv_metric_tag = "name", "metric"

#####################################################################
# Initialize logger
# Each message is prefixed with its level name; DEBUG and above are shown
#####################################################################
logging.basicConfig(
    level=logging.DEBUG,
    format="%(levelname)s: %(message)s",
)

#####################################################################
# Parse the options
# - [mandatory option] the file path to the .csv file to be checked
# - [mandatory option] the file path to the reference .csv file
# - [mandatory option] the file path to the metric checklist .csv file
# - [optional] the tolerance factors applied to the reference values
#####################################################################
parser = argparse.ArgumentParser(
    description="A checker for heterogeneous block mapping in OpenFPGA flow"
)
parser.add_argument(
    "--check_csv_file",
    required=True,
    help="Specify the to-be-checked csv file containing flow-run information",
)
parser.add_argument(
    "--reference_csv_file",
    required=True,
    help="Specify the reference csv file containing flow-run information",
)
parser.add_argument(
    "--metric_checklist_csv_file",
    required=True,
    help="Specify the csv file containing metrics to be checked",
)
# By default, allow a 50% tolerance when checking metrics
parser.add_argument(
    "--check_tolerance",
    default="0.5,1.5",
    help="Specify the tolerance when checking metrics, as factors of the reference value. "
    "Format <lower_bound>,<upper_bound> (default: %(default)s)",
)
args = parser.parse_args()

#####################################################################
# Check options:
# - Input csv files must be valid
#   Otherwise, error out
#####################################################################
# Checked in this order; the first missing file aborts the script
for csv_error_prefix, csv_path in (
    ("Invalid csv file to check: ", args.check_csv_file),
    ("Invalid reference csv file: ", args.reference_csv_file),
    ("Invalid metric checklist csv file: ", args.metric_checklist_csv_file),
):
    if not isfile(csv_path):
        logging.error(csv_error_prefix + csv_path + "\nFile does not exist!\n")
        # SystemExit is a builtin; unlike exit(), it does not rely on the
        # site module, which only provides exit() for the interactive shell
        raise SystemExit(1)

#####################################################################
# Parse a checklist for metrics to be checked
# Skip any line starting with '#', which is treated as a comment
#####################################################################
# A context manager closes the file once parsing is done; newline="" is
# what the csv module expects for file objects it reads
with open(args.metric_checklist_csv_file, "r", newline="") as metric_checklist_csv_file:
    metric_checklist_csv_content = csv.DictReader(
        filter(lambda row: row[0] != "#", metric_checklist_csv_file), delimiter=","
    )
    # Collect the metric names (one per row, from the metric column) in file order;
    # the reader is lazy, so it must be consumed before the file is closed
    metric_checklist = [row[csv_metric_tag] for row in metric_checklist_csv_content]

#####################################################################
# Parse the reference csv file
# Skip any line start with '#' which is treated as comments
#####################################################################
# A context manager closes the file once parsing is done; newline="" is
# what the csv module expects for file objects it reads
with open(args.reference_csv_file, "r", newline="") as ref_csv_file:
    ref_csv_content = csv.DictReader(
        filter(lambda row: row[0] != "#", ref_csv_file), delimiter=","
    )
    # Index the reference rows by benchmark name; if a name appears more
    # than once, the last row wins
    ref_results = {row[csv_name_tag]: row for row in ref_csv_content}

#####################################################################
# Parse the tolerance to be applied when checking metrics
# Format: <lower_bound>,<upper_bound>, both multiplied with the reference value
#####################################################################
try:
    # Unpacking raises ValueError both for a wrong number of fields and
    # for a field that is not a number
    lower_bound_factor, upper_bound_factor = (
        float(field) for field in args.check_tolerance.split(",")
    )
except ValueError:
    logging.error(
        "Invalid tolerance: "
        + args.check_tolerance
        + "\nExpected format <lower_bound>,<upper_bound>, e.g., 0.5,1.5\n"
    )
    raise SystemExit(1)

#####################################################################
# Parse the csv file to check
# Each checklist metric of each benchmark is compared with the reference
# value of the same benchmark; a value passes when it lies within
# [lower_bound_factor * reference, upper_bound_factor * reference]
#####################################################################
with open(args.check_csv_file, newline="") as check_csv_file:
    results_to_check = csv.DictReader(check_csv_file, delimiter=",")
    checkpoint_count = 0
    check_error_count = 0
    for row in results_to_check:
        benchmark_name = row[csv_name_tag]
        ref_row = ref_results.get(benchmark_name)
        if ref_row is None:
            # Without a reference entry there is nothing to compare against;
            # report it as a failure and keep checking the other benchmarks
            # instead of aborting with a KeyError
            logging.error(
                "Benchmark "
                + str(benchmark_name)
                + " has no entry in reference csv file: "
                + args.reference_csv_file
            )
            check_error_count += 1
            continue
        for metric_to_check in metric_checklist:
            ref_value = float(ref_row[metric_to_check])
            found_value = float(row[metric_to_check])
            # Check if the metric is in the allowed range
            if (lower_bound_factor * ref_value > found_value) or (
                upper_bound_factor * ref_value < found_value
            ):
                # Check QoR failed, report it (the script errors out after all checks)
                logging.error(
                    "Benchmark "
                    + str(benchmark_name)
                    + " failed in checking '"
                    + str(metric_to_check)
                    + "'\n"
                    + "Found: "
                    + str(row[metric_to_check])
                    + " but expected: "
                    + str(ref_row[metric_to_check])
                    + " outside range ["
                    + str(lower_bound_factor * 100)
                    + "%, "
                    + str(upper_bound_factor * 100)
                    + "%]"
                )
                check_error_count += 1
            # Count every metric checked, whether it passed or failed
            checkpoint_count += 1
    logging.info("Checked " + str(checkpoint_count) + " metrics")
    logging.info("See " + str(check_error_count) + " QoR failures")

    if 0 < check_error_count:
        raise SystemExit(1)

#####################################################################
# Post checked results on stdout:
# reaching here, it means all the checks have passed
# Output is a space-separated table: a header line (name tag + metric
# names), then one line per benchmark; every line ends with a space
#####################################################################
with open(args.check_csv_file, newline="") as check_csv_file:
    results_to_check = csv.DictReader(check_csv_file, delimiter=",")

    # Header line: the name tag followed by every checklist metric
    header_fields = [str(csv_name_tag)]
    header_fields.extend(str(metric) for metric in metric_checklist)
    print(" ".join(header_fields) + " ")

    for row in results_to_check:
        # One line per benchmark: its name, then the checked metric values
        value_fields = [row[csv_name_tag]]
        value_fields.extend(row[metric] for metric in metric_checklist)
        print(" ".join(value_fields) + " ")