diff --git a/libs/ezsat/ezsat.cc b/libs/ezsat/ezsat.cc index 177bcd8a3..8c666ca1f 100644 --- a/libs/ezsat/ezsat.cc +++ b/libs/ezsat/ezsat.cc @@ -1371,24 +1371,39 @@ int ezSAT::onehot(const std::vector &vec, bool max_only) if (max_only == false) formula.push_back(expression(OpOr, vec)); - // create binary vector - int num_bits = clog2(vec.size()); - std::vector bits; - for (int k = 0; k < num_bits; k++) - bits.push_back(literal()); + if (vec.size() < 8) + { + // fall-back to simple O(n^2) solution for small cases + for (size_t i = 0; i < vec.size(); i++) + for (size_t j = i+1; j < vec.size(); j++) { + std::vector clause; + clause.push_back(NOT(vec[i])); + clause.push_back(NOT(vec[j])); + formula.push_back(expression(OpOr, clause)); + } + } + else + { + // create binary vector + int num_bits = clog2(vec.size()); + std::vector bits; + for (int k = 0; k < num_bits; k++) + bits.push_back(literal()); - // add at-most-one clauses using binary encoding - for (size_t i = 0; i < vec.size(); i++) - for (int k = 0; k < num_bits; k++) { - std::vector clause; - clause.push_back(NOT(vec[i])); - clause.push_back((i & (1 << k)) != 0 ? bits[k] : NOT(bits[k])); - formula.push_back(expression(OpOr, clause)); - } + // add at-most-one clauses using binary encoding + for (size_t i = 0; i < vec.size(); i++) + for (int k = 0; k < num_bits; k++) { + std::vector clause; + clause.push_back(NOT(vec[i])); + clause.push_back((i & (1 << k)) != 0 ? bits[k] : NOT(bits[k])); + formula.push_back(expression(OpOr, clause)); + } + } return expression(OpAnd, formula); } +#if 0 int ezSAT::manyhot(const std::vector &vec, int min_hot, int max_hot) { // many-hot encoding using a simple sorting network @@ -1426,6 +1441,123 @@ int ezSAT::manyhot(const std::vector &vec, int min_hot, int max_hot) return expression(OpAnd, formula); } +#else +static std::vector lfsr_sym(ezSAT *that, const std::vector &vec, int poly) +{ + std::vector out; + + for (int i = 0; i < int(vec.size()); i++) + if ((poly & (1 << (i+1))) != 0) { + if (out.empty()) + out.push_back(vec.at(i)); + else + out.at(0) = that->XOR(out.at(0), vec.at(i)); + } + + for (int i = 0; i+1 < int(vec.size()); i++) + out.push_back(vec.at(i)); + + return out; +} + +static int lfsr_num(int vec, int poly, int cnt = 1) +{ + int mask = poly >> 1; + mask |= mask >> 1; + mask |= mask >> 2; + mask |= mask >> 4; + mask |= mask >> 8; + mask |= mask >> 16; + + while (cnt-- > 0) { + int bits = vec & (poly >> 1); + bits = ((bits & 0xAAAAAAAA) >> 1) ^ (bits & 0x55555555); + bits = ((bits & 0x44444444) >> 2) ^ (bits & 0x11111111); + bits = ((bits & 0x10101010) >> 4) ^ (bits & 0x01010101); + bits = ((bits & 0x01000100) >> 8) ^ (bits & 0x00010001); + bits = ((bits & 0x00010000) >> 16) ^ (bits & 0x00000001); + vec = ((vec << 1) | bits) & mask; + } + + return vec; +} + +int ezSAT::manyhot(const std::vector &vec, int min_hot, int max_hot) +{ + // many-hot encoding using LFSR as counter + + int poly = 0; + int nbits = 0; + + if (vec.size() < 3) { + poly = (1 << 2) | (1 << 1) | 1; + nbits = 2; + } else + if (vec.size() < 7) { + poly = (1 << 3) | (1 << 2) | 1; + nbits = 3; + } else + if (vec.size() < 15) { + poly = (1 << 4) | (1 << 3) | 1; + nbits = 4; + } else + if (vec.size() < 31) { + poly = (1 << 5) | (1 << 3) | 1; + nbits = 5; + } else + if (vec.size() < 63) { + poly = (1 << 6) | (1 << 5) | 1; + nbits = 6; + } else + if (vec.size() < 127) { + poly = (1 << 7) | (1 << 6) | 1; + nbits = 7; + } else + // if (vec.size() < 255) { + // poly = (1 << 8) | (1 << 6) | (1 << 5) | (1 << 4) | 1; + // nbits = 8; + // } else + if (vec.size() < 511) { + poly = (1 << 9) | (1 << 5) | 1; + nbits = 9; + } else { + assert(0); + } + + std::vector min_val; + std::vector max_val; + + if (min_hot > 1) + min_val = vec_const_unsigned(lfsr_num(1, poly, min_hot), nbits); + + if (max_hot >= 0) + max_val = vec_const_unsigned(lfsr_num(1, poly, max_hot+1), nbits); + + std::vector state = vec_const_unsigned(1, nbits); + + std::vector match_min; + std::vector match_max; + + if (min_hot == 1) + match_min = vec; + + for (int i = 0; i < int(vec.size()); i++) + { + state = vec_ite(vec[i], lfsr_sym(this, state, poly), state); + + if (!min_val.empty() && i+1 >= min_hot) + match_min.push_back(vec_eq(min_val, state)); + + if (!max_val.empty() && i >= max_hot) + match_max.push_back(vec_eq(max_val, state)); + } + + int min_matched = min_hot ? vec_reduce_or(match_min) : CONST_TRUE; + int max_matched = vec_reduce_or(match_max); + + return AND(min_matched, NOT(max_matched)); +} +#endif int ezSAT::ordered(const std::vector &vec1, const std::vector &vec2, bool allow_equal) {