/*
 *  ezSAT -- A simple and easy to use CNF generator for SAT solvers
 *
 *  Copyright (C) 2013  Clifford Wolf <clifford@clifford.at>
 *
 *  Permission to use, copy, modify, and/or distribute this software for any
 *  purpose with or without fee is hereby granted, provided that the above
 *  copyright notice and this permission notice appear in all copies.
 *
 *  THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 *  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 *  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 *  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 *  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 *  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 *  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 */

#include "ezminisat.h"
#include <stdio.h>

/*
 * Small deterministic pseudo random number generator with four 32-bit
 * words of state. The state always starts from the same fixed constants,
 * so every instance produces the same sequence of values.
 */
struct xorshift128 {
	uint32_t x, y, z, w;

	xorshift128() : x(123456789), y(362436069), z(521288629), w(88675123) { }

	// Advances the state by one step and returns the new value of `w`.
	uint32_t operator()() {
		const uint32_t mixed = x ^ (x << 11);

		// rotate the state words: the oldest word (x) drops out
		x = y;
		y = z;
		z = w;

		w = w ^ (w >> 19) ^ mixed ^ (mixed >> 8);
		return w;
	}
};

/*
 * Solves `sat` under the given assumption and prints the outcome.
 *
 * All literal ids from 1 to sat.numLiterals() for which sat.bound() is true
 * are queried; on success their values are printed as "name=0/1" pairs.
 *
 * Returns true if the solver found a solution, false otherwise.
 */
bool test(ezSAT &sat, int assumption = 0)
{
	std::vector<int> boundIds;
	std::vector<bool> boundValues;

	for (int id = 1; id <= sat.numLiterals(); id++) {
		if (!sat.bound(id))
			continue;
		boundIds.push_back(id);
	}

	if (!sat.solve(boundIds, boundValues, assumption)) {
		printf("not satisfiable.\n\n");
		return false;
	}

	printf("satisfiable:");
	for (size_t k = 0; k < boundIds.size(); k++)
		printf(" %s=%d", sat.to_string(boundIds[k]).c_str(), boundValues[k] ? 1 : 0);
	printf("\n\n");
	return true;
}

// ------------------------------------------------------------------------------------------------------------

/*
 * Minimal smoke test: assumes (A or B) and not (A and B) on a
 * non-incremental solver instance and prints the result via test().
 */
void test_simple()
{
	printf("==== %s ====\n\n", __PRETTY_FUNCTION__);

	ezMiniSAT sat;
	sat.non_incremental();

	// at least one of A, B ...
	const int atLeastOne = sat.OR("A", "B");
	sat.assume(atLeastOne);

	// ... but not both of them
	const int bothSet = sat.AND("A", "B");
	sat.assume(sat.NOT(bothSet));

	test(sat);
}

// ------------------------------------------------------------------------------------------------------------

/*
 * Runs the xorshift32 model stored in `sat` once forward and once backward
 * for a single input word.
 *
 * The expected output word is computed natively with the same
 * shift/xor sequence. The solver is then asked twice:
 *   - backward: the "o" bits are fixed to the expected output
 *   - forward:  the "i" bits are fixed to input_pattern
 * Both models (over all "i" and "o" bits) are printed and must be identical.
 *
 * Calls abort() if either solver run fails or the two models differ.
 */
void test_xorshift32_try(ezSAT &sat, uint32_t input_pattern)
{
	uint32_t output_pattern = input_pattern;
	output_pattern ^= output_pattern << 13;
	output_pattern ^= output_pattern >> 17;
	output_pattern ^= output_pattern << 5;

	// Look the two 32 bit vectors up once and reuse them below
	// (same lookup order as before: "i" first, then "o").
	const std::vector<int> inputBits = sat.vec_var("i", 32);
	const std::vector<int> outputBits = sat.vec_var("o", 32);

	std::vector<int> modelExpressions;
	std::vector<int> forwardAssumptions, backwardAssumptions;
	std::vector<bool> forwardModel, backwardModel;

	sat.vec_append(modelExpressions, inputBits);
	sat.vec_append(modelExpressions, outputBits);

	sat.vec_append_unsigned(forwardAssumptions, inputBits, input_pattern);
	sat.vec_append_unsigned(backwardAssumptions, outputBits, output_pattern);

	if (!sat.solve(modelExpressions, backwardModel, backwardAssumptions)) {
		printf("backward solving failed!\n");
		abort();
	}

	if (!sat.solve(modelExpressions, forwardModel, forwardAssumptions)) {
		printf("forward solving failed!\n");
		abort();
	}

	printf("xorshift32 test with input pattern 0x%08x:\n", input_pattern);

	printf("forward  solution: input=0x%08x output=0x%08x\n",
			(unsigned int)sat.vec_model_get_unsigned(modelExpressions, forwardModel, inputBits),
			(unsigned int)sat.vec_model_get_unsigned(modelExpressions, forwardModel, outputBits));

	printf("backward solution: input=0x%08x output=0x%08x\n",
			(unsigned int)sat.vec_model_get_unsigned(modelExpressions, backwardModel, inputBits),
			(unsigned int)sat.vec_model_get_unsigned(modelExpressions, backwardModel, outputBits));

	if (forwardModel != backwardModel) {
		printf("forward and backward results are inconsistent!\n");
		abort();
	}

	printf("passed.\n\n");
}

/*
 * Builds a 32 bit xorshift model (<<13, >>17, <<5) between the bit vectors
 * "i" and "o", checks it with two fixed and four pseudo random input words
 * and finally dumps the problem in DIMACS format to stdout.
 */
void test_xorshift32()
{
	printf("==== %s ====\n\n", __PRETTY_FUNCTION__);

	ezMiniSAT sat;
	sat.keep_cnf();

	// one named vector per xorshift stage
	const std::vector<int> stage0 = sat.vec_var("i", 32);
	const std::vector<int> stage1 = sat.vec_xor(stage0, sat.vec_shl(stage0, 13));
	const std::vector<int> stage2 = sat.vec_xor(stage1, sat.vec_shr(stage1, 17));
	const std::vector<int> stage3 = sat.vec_xor(stage2, sat.vec_shl(stage2, 5));

	sat.vec_set(stage3, sat.vec_var("o", 32));

	// fixed patterns first ...
	test_xorshift32_try(sat, 0);
	test_xorshift32_try(sat, 314159265);

	// ... followed by four words from the deterministic generator
	xorshift128 rng;
	for (int round = 0; round < 4; round++)
		test_xorshift32_try(sat, rng());

	sat.printDIMACS(stdout, true);
	printf("\n");
}

// ------------------------------------------------------------------------------------------------------------

// Passes both expressions to check() together with their source text.
#define CHECK(_expr1, _expr2) check(#_expr1, _expr1, #_expr2, _expr2)

/*
 * Compares two boolean results and prints a one-line report containing the
 * source text and the value of both expressions.
 *
 * Calls abort() after printing the report if the two values differ.
 */
void check(const char *expr1_str, bool expr1, const char *expr2_str, bool expr2)
{
	const char *const value1 = expr1 ? "true" : "false";
	const char *const value2 = expr2 ? "true" : "false";

	if (expr1 != expr2) {
		printf("[ %s ] != [ %s ] .. ERROR (%s != %s)\n", expr1_str, expr2_str, value1, value2);
		abort();
	}

	printf("[ %s ] == [ %s ] .. ok (%s == %s)\n", expr1_str, expr2_str, value1, value2);
}

/*
 * Cross-checks the signed comparison helpers against native C++ arithmetic.
 *
 * a, b and c are turned into 8 bit constant vectors. For each of the four
 * relations (<, <=, >, >=) the natively computed truth value is compared
 * with the result of solving the corresponding vector expression.
 * check() aborts the program on the first mismatch.
 *
 * NOTE(review): the native side is evaluated after integer promotion, the
 * vector side on 8 bit vectors -- presumably callers must pick operands for
 * which b+c and b-c fit into 8 signed bits; confirm against vec_add/vec_sub.
 */
void test_signed(int8_t a, int8_t b, int8_t c)
{
	ezMiniSAT sat;

	// 8 bit constant vectors for the three operands
	std::vector<int> av = sat.vec_const_signed(a, 8);
	std::vector<int> bv = sat.vec_const_signed(b, 8);
	std::vector<int> cv = sat.vec_const_signed(c, 8);

	printf("Testing signed arithmetic using: a=%+d, b=%+d, c=%+d\n", int(a), int(b), int(c));

	// The argument text of each CHECK is stringified into the report line,
	// so the expressions below are part of the program output.
	CHECK(a <  b+c, sat.solve(sat.vec_lt_signed(av, sat.vec_add(bv, cv))));
	CHECK(a <= b-c, sat.solve(sat.vec_le_signed(av, sat.vec_sub(bv, cv))));

	CHECK(a >  b+c, sat.solve(sat.vec_gt_signed(av, sat.vec_add(bv, cv))));
	CHECK(a >= b-c, sat.solve(sat.vec_ge_signed(av, sat.vec_sub(bv, cv))));

	printf("\n");
}

/*
 * Cross-checks the unsigned comparison helpers against native C++
 * arithmetic, using 8 bit constant vectors for a, b and c.
 *
 * b and c are exchanged first if necessary so that b >= c holds, i.e. the
 * native difference b-c is never negative. check() aborts the program on
 * the first mismatch between the native and the solver result.
 */
void test_unsigned(uint8_t a, uint8_t b, uint8_t c)
{
	ezMiniSAT sat;

	// make sure the larger of the two values ends up in b
	if (b < c) {
		const uint8_t smaller = b;
		b = c;
		c = smaller;
	}

	std::vector<int> av = sat.vec_const_unsigned(a, 8);
	std::vector<int> bv = sat.vec_const_unsigned(b, 8);
	std::vector<int> cv = sat.vec_const_unsigned(c, 8);

	printf("Testing unsigned arithmetic using: a=%d, b=%d, c=%d\n", int(a), int(b), int(c));

	// The argument text of each CHECK is stringified into the report line,
	// so the expressions below are part of the program output.
	CHECK(a <  b+c, sat.solve(sat.vec_lt_unsigned(av, sat.vec_add(bv, cv))));
	CHECK(a <= b-c, sat.solve(sat.vec_le_unsigned(av, sat.vec_sub(bv, cv))));

	CHECK(a >  b+c, sat.solve(sat.vec_gt_unsigned(av, sat.vec_add(bv, cv))));
	CHECK(a >= b-c, sat.solve(sat.vec_ge_unsigned(av, sat.vec_sub(bv, cv))));

	printf("\n");
}

/*
 * Checks vec_count() on the 32 bit constant vector for x.
 *
 * The number of set bits in x is counted natively and compared with
 *   - the 6 bit result of vec_count() without clipping, and
 *   - the 4 bit result of vec_count() with clipping, where the expected
 *     value is limited to 15.
 * Calls abort() if either result differs from the expected constant vector.
 */
void test_count(uint32_t x)
{
	ezMiniSAT sat;

	// native population count: peel off the lowest bit until nothing is left
	int setBits = 0;
	for (uint32_t rest = x; rest != 0; rest >>= 1)
		setBits += int(rest & 1);

	printf("Testing bit counting using x=0x%08x (%d set bits) .. ", x, setBits);

	const std::vector<int> v = sat.vec_const_unsigned(x, 32);

	const int clippedBits = setBits > 15 ? 15 : setBits;
	const std::vector<int> expected6 = sat.vec_const_unsigned(setBits, 6);
	const std::vector<int> expected4 = sat.vec_const_unsigned(clippedBits, 4);

	const bool wideOk = (sat.vec_count(v, 6, false) == expected6);
	if (!wideOk) {
		fprintf(stderr, "FAILED 6bit-no-clipping test!\n");
		abort();
	}

	const bool clippedOk = (sat.vec_count(v, 4, true) == expected4);
	if (!clippedOk) {
		fprintf(stderr, "FAILED 4bit-clipping test!\n");
		abort();
	}

	printf("ok.\n");
}

/*
 * Runs the arithmetic tests with pseudo random operands:
 *   - 100 rounds of test_signed() with operands in the range -10 .. +8
 *   - 100 rounds of test_unsigned() with operands in the range 0 .. 9
 *   - test_count() for all-zeros, all-ones and 30 pseudo random words
 *
 * The operands of each round are drawn in separate statements. Drawing
 * them inside one argument list would leave the order of the rng() calls
 * unspecified, so different compilers could test different operand triples.
 */
void test_arith()
{
	printf("==== %s ====\n\n", __PRETTY_FUNCTION__);

	xorshift128 rng;

	for (int i = 0; i < 100; i++) {
		// Convert to int before subtracting: in unsigned arithmetic
		// "rng() % 19 - 10" would wrap around for remainders below 10.
		const int a = int(rng() % 19) - 10;
		const int b = int(rng() % 19) - 10;
		const int c = int(rng() % 19) - 10;
		test_signed(int8_t(a), int8_t(b), int8_t(c));
	}

	for (int i = 0; i < 100; i++) {
		const uint8_t a = uint8_t(rng() % 10);
		const uint8_t b = uint8_t(rng() % 10);
		const uint8_t c = uint8_t(rng() % 10);
		test_unsigned(a, b, c);
	}

	test_count(0x00000000);
	test_count(0xffffffff);
	for (int i = 0; i < 30; i++)
		test_count(rng());

	printf("\n");
}

// ------------------------------------------------------------------------------------------------------------

/*
 * Enumerates all solutions of onehot() over the four frozen literals
 * a, b, c and d.
 *
 * After each solution the exact assignment is excluded by assuming the
 * negation of its conjunction, so the next solve() call has to come up
 * with a different one. Calls abort() if a solution does not have exactly
 * one literal set or if the total number of solutions is not 4.
 */
void test_onehot()
{
	printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
	ezMiniSAT ez;

	static const char *const names[4] = { "a", "b", "c", "d" };

	std::vector<int> abcd;
	for (int k = 0; k < 4; k++)
		abcd.push_back(ez.frozen_literal(names[k]));

	ez.assume(ez.onehot(abcd));

	int solution_counter = 0;
	for (;;)
	{
		std::vector<bool> modelValues;
		if (!ez.solve(abcd, modelValues))
			break;

		printf("Solution: %d %d %d %d\n", int(modelValues[0]), int(modelValues[1]), int(modelValues[2]), int(modelValues[3]));

		// collect the literals describing exactly this assignment
		int hotBits = 0;
		std::vector<int> assignment;
		for (size_t k = 0; k < abcd.size(); k++) {
			if (modelValues[k]) {
				hotBits++;
				assignment.push_back(abcd[k]);
			} else {
				assignment.push_back(ez.NOT(abcd[k]));
			}
		}

		// rule this assignment out for all following solver runs
		ez.assume(ez.NOT(ez.expression(ezSAT::OpAnd, assignment)));

		if (hotBits != 1) {
			fprintf(stderr, "Wrong number of hot bits!\n");
			abort();
		}

		solution_counter++;
	}

	if (solution_counter != 4) {
		fprintf(stderr, "Wrong number of one-hot solutions!\n");
		abort();
	}

	printf("\n");
}

/*
 * Enumerates all solutions of manyhot(abcd, 1, 2) over the four frozen
 * literals a, b, c and d.
 *
 * After each solution the exact assignment is excluded by assuming the
 * negation of its conjunction, so the next solve() call has to come up
 * with a different one. Calls abort() if a solution has neither one nor
 * two literals set, or if the total number of solutions is not
 * 4 + 4*3/2 = 10 (four single-bit plus six two-bit assignments).
 */
void test_manyhot()
{
	printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
	ezMiniSAT ez;

	int a = ez.frozen_literal("a");
	int b = ez.frozen_literal("b");
	int c = ez.frozen_literal("c");
	int d = ez.frozen_literal("d");

	std::vector<int> abcd;
	abcd.push_back(a);
	abcd.push_back(b);
	abcd.push_back(c);
	abcd.push_back(d);

	ez.assume(ez.manyhot(abcd, 1, 2));

	int solution_counter = 0;
	while (1)
	{
		std::vector<bool> modelValues;
		bool ok = ez.solve(abcd, modelValues);

		if (!ok)
			break;

		printf("Solution: %d %d %d %d\n", int(modelValues[0]), int(modelValues[1]), int(modelValues[2]), int(modelValues[3]));

		// collect the literals describing exactly this assignment
		int count_hot = 0;
		std::vector<int> sol;
		for (int i = 0; i < 4; i++) {
			if (modelValues[i])
				count_hot++;
			sol.push_back(modelValues[i] ? abcd[i] : ez.NOT(abcd[i]));
		}

		// rule this assignment out for all following solver runs
		ez.assume(ez.NOT(ez.expression(ezSAT::OpAnd, sol)));

		if (count_hot != 1 && count_hot != 2) {
			fprintf(stderr, "Wrong number of hot bits!\n");
			abort();
		}

		solution_counter++;
	}

	if (solution_counter != 4 + 4*3/2) {
		// this is the many-hot test, so do not report it as "one-hot"
		fprintf(stderr, "Wrong number of many-hot solutions!\n");
		abort();
	}

	printf("\n");
}

/*
 * Enumerates all solutions of ordered(abc, xyz) for two vectors of three
 * frozen literals each (a, b, c and x, y, z).
 *
 * After each solution the exact assignment of all six literals is excluded
 * by assuming the negation of its conjunction. Calls abort() if the total
 * number of solutions is not 8+7+6+5+4+3+2+1 = 36.
 */
void test_ordered()
{
	printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
	ezMiniSAT ez;

	std::vector<int> abc;
	abc.push_back(ez.frozen_literal("a"));
	abc.push_back(ez.frozen_literal("b"));
	abc.push_back(ez.frozen_literal("c"));

	std::vector<int> xyz;
	xyz.push_back(ez.frozen_literal("x"));
	xyz.push_back(ez.frozen_literal("y"));
	xyz.push_back(ez.frozen_literal("z"));

	ez.assume(ez.ordered(abc, xyz));

	// the queried literals are the same in every round: a b c x y z
	std::vector<int> modelVariables(abc);
	modelVariables.insert(modelVariables.end(), xyz.begin(), xyz.end());

	int solution_counter = 0;

	for (;;)
	{
		std::vector<bool> modelValues;
		if (!ez.solve(modelVariables, modelValues))
			break;

		printf("Solution: %d %d %d | %d %d %d\n",
				int(modelValues[0]), int(modelValues[1]), int(modelValues[2]),
				int(modelValues[3]), int(modelValues[4]), int(modelValues[5]));

		// rule this assignment out for all following solver runs
		std::vector<int> assignment;
		for (size_t k = 0; k < modelVariables.size(); k++) {
			if (modelValues[k])
				assignment.push_back(modelVariables[k]);
			else
				assignment.push_back(ez.NOT(modelVariables[k]));
		}
		ez.assume(ez.NOT(ez.expression(ezSAT::OpAnd, assignment)));

		solution_counter++;
	}

	if (solution_counter != 8+7+6+5+4+3+2+1) {
		fprintf(stderr, "Wrong number of solutions!\n");
		abort();
	}

	printf("\n");
}

// ------------------------------------------------------------------------------------------------------------


/*
 * Runs all test groups in a fixed order. Each group aborts the program on
 * failure, so the final message is only printed if every test passed.
 */
int main()
{
	typedef void (*test_function)();

	static const test_function all_tests[] = {
		test_simple,
		test_xorshift32,
		test_arith,
		test_onehot,
		test_manyhot,
		test_ordered,
	};

	const size_t num_tests = sizeof(all_tests) / sizeof(all_tests[0]);
	for (size_t k = 0; k < num_tests; k++)
		all_tests[k]();

	printf("Passed all tests.\n\n");
	return 0;
}