diff --git a/FPSim2/src/include/popcnt.hpp b/FPSim2/src/include/popcnt.hpp
index d758566..bcdb48a 100644
--- a/FPSim2/src/include/popcnt.hpp
+++ b/FPSim2/src/include/popcnt.hpp
@@ -17,3 +17,96 @@ static inline uint64_t popcntll(const uint64_t X) {
 }
 
 #endif
+
+// AVX512 support for vectorized popcount
+#if defined(__AVX512VPOPCNTDQ__) && defined(__AVX512F__)
+
+#include <immintrin.h>
+
+// Sums popcount(qptr[j] & dbptr[j]) for j in [start, end): eight 64-bit
+// words per AVX512 iteration, then a scalar loop for the leftover words.
+// Returns 0 when end <= start.
+static inline uint64_t common_popcnt_avx512(const uint64_t* qptr,
+                                            const uint64_t* dbptr,
+                                            const size_t start,
+                                            const size_t end) {
+    // end - start is unsigned; an inverted range would wrap around and
+    // send the loops far out of bounds, so bail out early.
+    if (end <= start) {
+        return 0;
+    }
+
+    const size_t len = end - start;
+    const size_t vec_end = start + (len / 8) * 8;
+
+    __m512i sum = _mm512_setzero_si512();
+
+    for (size_t j = start; j < vec_end; j += 8) {
+        // Unaligned loads take const void*, so no cast is needed
+        __m512i q = _mm512_loadu_si512(qptr + j);
+        __m512i d = _mm512_loadu_si512(dbptr + j);
+        __m512i and_result = _mm512_and_si512(q, d);
+        __m512i popcnt = _mm512_popcnt_epi64(and_result);
+        sum = _mm512_add_epi64(sum, popcnt);
+    }
+
+    // Horizontal sum of 8 uint64_t values
+    uint64_t result = _mm512_reduce_add_epi64(sum);
+
+    // Handle remaining elements
+    for (size_t j = vec_end; j < end; j++) {
+        result += popcntll(qptr[j] & dbptr[j]);
+    }
+
+    return result;
+}
+
+#define HAS_AVX512_POPCNT 1
+
+#else
+
+#define HAS_AVX512_POPCNT 0
+
+#endif
+
+// Portable restrict keyword
+#if defined(_MSC_VER)
+    #define RESTRICT __restrict
+#else
+    #define RESTRICT __restrict__
+#endif
+
+// Generic common popcount function - uses AVX512 if available.
+// Sums popcount(qptr[j] & dbptr[j]) for j in [start, end); returns 0
+// when end <= start.
+static inline uint64_t common_popcnt(const uint64_t* RESTRICT qptr,
+                                     const uint64_t* RESTRICT dbptr,
+                                     const size_t start,
+                                     const size_t end) {
+#if HAS_AVX512_POPCNT
+    return common_popcnt_avx512(qptr, dbptr, start, end);
+#else
+    // end - start is unsigned; an inverted range would wrap around
+    // and send the loops far out of bounds, so bail out early.
+    if (end <= start) {
+        return 0;
+    }
+
+    uint64_t result = 0;
+    size_t j = start;
+
+    // Unroll by 4 for better instruction-level parallelism
+    const size_t unroll_end = start + ((end - start) / 4) * 4;
+    for (; j < unroll_end; j += 4) {
+        result += popcntll(qptr[j] & dbptr[j]);
+        result += popcntll(qptr[j+1] & dbptr[j+1]);
+        result += popcntll(qptr[j+2] & dbptr[j+2]);
+        result += popcntll(qptr[j+3] & dbptr[j+3]);
+    }
+    // Handle remainder
+    for (; j < end; j++) {
+        result += popcntll(qptr[j] & dbptr[j]);
+    }
+    return result;
+#endif
+}
diff --git a/FPSim2/src/sim.cpp b/FPSim2/src/sim.cpp
index 42c106e..cee5956 100644
--- a/FPSim2/src/sim.cpp
+++ b/FPSim2/src/sim.cpp
@@ -74,13 +74,11 @@ py::array_t TverskySearch(const py::array_t py_query,
     {
         const auto db_popcnt = dbptr[popcnt_idx];
 
-        uint64_t common_popcnt = 0;
-        for (auto j = 1; j < popcnt_idx; j++)
-            common_popcnt += popcntll(qptr[j] & dbptr[j]);
+        uint64_t common_popcnt_val = common_popcnt(qptr, dbptr, 1, popcnt_idx);
 
         // Tversky: common / (common*(1-a-b) + a*q_popcnt + b*db_popcnt)
-        float denom = common_popcnt * one_minus_a_minus_b + a_times_q + b * db_popcnt;
-        float coeff = (denom != 0.0f) ? common_popcnt / denom : 0.0f;
+        float denom = common_popcnt_val * one_minus_a_minus_b + a_times_q + b * db_popcnt;
+        float coeff = (denom != 0.0f) ? common_popcnt_val / denom : 0.0f;
 
         if (coeff >= threshold)
             results->push_back({i, (uint32_t)dbptr[0], coeff});
@@ -94,31 +92,37 @@ py::array_t TverskySearch(const py::array_t py_query,
 
 struct TanimotoCalculator
 {
-    static inline float calculate(const uint32_t &common_popcnt,
-                                  const uint32_t &qcount,
-                                  const uint32_t &ocount)
+    static inline float calculate(const uint64_t common_popcnt,
+                                  const uint64_t qcount,
+                                  const uint64_t ocount)
     {
-        return (float)common_popcnt / (qcount + ocount - common_popcnt);
+        // A zero denominator would produce NaN; return 0 as TverskySearch does
+        const uint64_t denom = qcount + ocount - common_popcnt;
+        return (denom != 0) ? (float)common_popcnt / (float)denom : 0.0f;
     }
 };
 
 struct CosineCalculator
 {
-    static inline float calculate(const uint32_t &common_popcnt,
-                                  const uint32_t &qcount,
-                                  const uint32_t &ocount)
+    static inline float calculate(const uint64_t common_popcnt,
+                                  const uint64_t qcount,
+                                  const uint64_t ocount)
     {
-        return (float)common_popcnt / sqrt(qcount * ocount);
+        // A zero product would produce NaN or inf; return 0 as TverskySearch does
+        const float prod = (float)qcount * (float)ocount;
+        return (prod != 0.0f) ? (float)common_popcnt / sqrtf(prod) : 0.0f;
     }
 };
 
 struct DiceCalculator
 {
-    static inline float calculate(const uint32_t &common_popcnt,
-                                  const uint32_t &qcount,
-                                  const uint32_t &ocount)
+    static inline float calculate(const uint64_t common_popcnt,
+                                  const uint64_t qcount,
+                                  const uint64_t ocount)
     {
-        return (2.0f * common_popcnt) / (qcount + ocount);
+        // A zero denominator would produce NaN; return 0 as TverskySearch does
+        const uint64_t total = qcount + ocount;
+        return (total != 0) ? (2.0f * common_popcnt) / (float)total : 0.0f;
     }
 };
 
@@ -168,14 +172,14 @@ py::array_t GenericSearchImpl(const py::array_t py_query,
     {
         std::priority_queue, utils::ResultComparator> top_k;
 
+        float min_coeff = threshold;
+
         for (uint32_t idx = start; idx < end; ++idx, dbptr += fp_shape)
         {
-            uint64_t common_popcnt = 0;
-            for (auto j = 1; j < popcnt_idx; j++)
-                common_popcnt += popcntll(qptr[j] & dbptr[j]);
+            uint64_t common_popcnt_val = common_popcnt(qptr, dbptr, 1, popcnt_idx);
 
-            float coeff = calc.calculate(common_popcnt, q_popcnt, dbptr[popcnt_idx]);
-            if (coeff < threshold)
+            float coeff = calc.calculate(common_popcnt_val, q_popcnt, dbptr[popcnt_idx]);
+            if (coeff < min_coeff)
                 continue;
 
             if (top_k.size() < k)
@@ -186,6 +190,7 @@ py::array_t GenericSearchImpl(const py::array_t py_query,
             {
                 top_k.pop();
                 top_k.push({idx, static_cast<uint32_t>(dbptr[0]), coeff});
+                min_coeff = top_k.top().coeff;
             }
         }
         results->reserve(top_k.size());
@@ -200,11 +205,9 @@ py::array_t GenericSearchImpl(const py::array_t py_query,
     {
         for (auto i = start; i < end; i++, dbptr += fp_shape)
         {
-            uint64_t common_popcnt = 0;
-            for (auto j = 1; j < popcnt_idx; j++)
-                common_popcnt += popcntll(qptr[j] & dbptr[j]);
+            uint64_t common_popcnt_val = common_popcnt(qptr, dbptr, 1, popcnt_idx);
 
-            float coeff = calc.calculate(common_popcnt, q_popcnt, dbptr[popcnt_idx]);
+            float coeff = calc.calculate(common_popcnt_val, q_popcnt, dbptr[popcnt_idx]);
             if (coeff < threshold)
                 continue;
             results->push_back({i, static_cast<uint32_t>(dbptr[0]), coeff});