Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions FPSim2/src/include/popcnt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,78 @@ static inline uint64_t popcntll(const uint64_t X) {
}

#endif

// AVX512 support for vectorized popcount
#if defined(__AVX512VPOPCNTDQ__) && defined(__AVX512F__)

#include <immintrin.h>

/// Counts the bits that are set in both `qptr` and `dbptr` over the 64-bit
/// word range [start, end), using AVX-512 VPOPCNTDQ.
///
/// @param qptr  First array of 64-bit words; must be readable up to index end - 1.
/// @param dbptr Second array of 64-bit words; must be readable up to index end - 1.
/// @param start Index of the first word to include.
/// @param end   One past the index of the last word to include.
/// @return Sum of popcount(qptr[j] & dbptr[j]) for every j in [start, end);
///         0 when the range is empty or inverted (end <= start).
static inline uint64_t common_popcnt_avx512(const uint64_t* qptr,
                                            const uint64_t* dbptr,
                                            const size_t start,
                                            const size_t end) {
    // size_t is unsigned: without this guard `end - start` would wrap around
    // for an inverted range and the loops below would read out of bounds.
    if (end <= start) {
        return 0;
    }

    const size_t len = end - start;
    // Last index (exclusive) that can be covered by whole 8-word vectors.
    const size_t vec_end = start + (len / 8) * 8;

    // Eight independent 64-bit partial sums, one per vector lane.
    __m512i sum = _mm512_setzero_si512();

    for (size_t j = start; j < vec_end; j += 8) {
        // Unaligned loads take `void const*`, so no const-dropping cast is needed.
        const __m512i q = _mm512_loadu_si512(qptr + j);
        const __m512i d = _mm512_loadu_si512(dbptr + j);
        const __m512i and_result = _mm512_and_si512(q, d);
        const __m512i popcnt = _mm512_popcnt_epi64(and_result);
        sum = _mm512_add_epi64(sum, popcnt);
    }

    // Horizontal sum of the 8 per-lane totals.
    uint64_t result = static_cast<uint64_t>(_mm512_reduce_add_epi64(sum));

    // Fewer than 8 words may remain; handle them with the scalar popcount.
    for (size_t j = vec_end; j < end; j++) {
        result += popcntll(qptr[j] & dbptr[j]);
    }

    return result;
}

#define HAS_AVX512_POPCNT 1

#else

#define HAS_AVX512_POPCNT 0

#endif

// Portable restrict keyword
#if defined(_MSC_VER)
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

/// Counts the bits that are set in both `qptr` and `dbptr` over the 64-bit
/// word range [start, end). Dispatches to the AVX-512 implementation when
/// HAS_AVX512_POPCNT is non-zero, otherwise uses a 4-way unrolled scalar loop.
///
/// Both pointers are declared RESTRICT; the function only reads through them.
///
/// @param qptr  First array of 64-bit words; must be readable up to index end - 1.
/// @param dbptr Second array of 64-bit words; must be readable up to index end - 1.
/// @param start Index of the first word to include.
/// @param end   One past the index of the last word to include.
/// @return Sum of popcount(qptr[j] & dbptr[j]) for every j in [start, end);
///         0 when the range is empty or inverted (end <= start).
static inline uint64_t common_popcnt(const uint64_t* RESTRICT qptr,
                                     const uint64_t* RESTRICT dbptr,
                                     const size_t start,
                                     const size_t end) {
    // size_t is unsigned: without this guard `end - start` would wrap around
    // for an inverted range and the loops below would read out of bounds.
    // Checked here so that both the SIMD and the scalar path are protected.
    if (end <= start) {
        return 0;
    }
#if HAS_AVX512_POPCNT
    return common_popcnt_avx512(qptr, dbptr, start, end);
#else
    uint64_t result = 0;
    size_t j = start;

    // Unroll by 4 for better instruction-level parallelism.
    const size_t unroll_end = start + ((end - start) / 4) * 4;
    for (; j < unroll_end; j += 4) {
        result += popcntll(qptr[j] & dbptr[j]);
        result += popcntll(qptr[j + 1] & dbptr[j + 1]);
        result += popcntll(qptr[j + 2] & dbptr[j + 2]);
        result += popcntll(qptr[j + 3] & dbptr[j + 3]);
    }
    // Handle the (at most 3) words left over after the unrolled loop.
    for (; j < end; j++) {
        result += popcntll(qptr[j] & dbptr[j]);
    }
    return result;
#endif
}
49 changes: 23 additions & 26 deletions FPSim2/src/sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,11 @@ py::array_t<Result> TverskySearch(const py::array_t<uint64_t> py_query,
{
const auto db_popcnt = dbptr[popcnt_idx];

uint64_t common_popcnt = 0;
for (auto j = 1; j < popcnt_idx; j++)
common_popcnt += popcntll(qptr[j] & dbptr[j]);
uint64_t common_popcnt_val = common_popcnt(qptr, dbptr, 1, popcnt_idx);

// Tversky: common / (common*(1-a-b) + a*q_popcnt + b*db_popcnt)
float denom = common_popcnt * one_minus_a_minus_b + a_times_q + b * db_popcnt;
float coeff = (denom != 0.0f) ? common_popcnt / denom : 0.0f;
float denom = common_popcnt_val * one_minus_a_minus_b + a_times_q + b * db_popcnt;
float coeff = (denom != 0.0f) ? common_popcnt_val / denom : 0.0f;

if (coeff >= threshold)
results->push_back({i, (uint32_t)dbptr[0], coeff});
Expand All @@ -94,31 +92,31 @@ py::array_t<Result> TverskySearch(const py::array_t<uint64_t> py_query,

struct TanimotoCalculator
{
static inline float calculate(const uint32_t &common_popcnt,
const uint32_t &qcount,
const uint32_t &ocount)
static inline float calculate(const uint64_t common_popcnt,
const uint64_t qcount,
const uint64_t ocount)
{
return (float)common_popcnt / (qcount + ocount - common_popcnt);
return (float)common_popcnt / (float)(qcount + ocount - common_popcnt);
}
};

struct CosineCalculator
{
static inline float calculate(const uint32_t &common_popcnt,
const uint32_t &qcount,
const uint32_t &ocount)
static inline float calculate(const uint64_t common_popcnt,
const uint64_t qcount,
const uint64_t ocount)
{
return (float)common_popcnt / sqrt(qcount * ocount);
return (float)common_popcnt / sqrtf((float)qcount * (float)ocount);
}
};

struct DiceCalculator
{
static inline float calculate(const uint32_t &common_popcnt,
const uint32_t &qcount,
const uint32_t &ocount)
static inline float calculate(const uint64_t common_popcnt,
const uint64_t qcount,
const uint64_t ocount)
{
return (2.0f * common_popcnt) / (qcount + ocount);
return (2.0f * common_popcnt) / (float)(qcount + ocount);
}
};

Expand Down Expand Up @@ -168,14 +166,14 @@ py::array_t<Result> GenericSearchImpl(const py::array_t<uint64_t> py_query,
{
std::priority_queue<Result, std::vector<Result>, utils::ResultComparator> top_k;

float min_coeff = threshold;

for (uint32_t idx = start; idx < end; ++idx, dbptr += fp_shape)
{
uint64_t common_popcnt = 0;
for (auto j = 1; j < popcnt_idx; j++)
common_popcnt += popcntll(qptr[j] & dbptr[j]);
uint64_t common_popcnt_val = common_popcnt(qptr, dbptr, 1, popcnt_idx);

float coeff = calc.calculate(common_popcnt, q_popcnt, dbptr[popcnt_idx]);
if (coeff < threshold)
float coeff = calc.calculate(common_popcnt_val, q_popcnt, dbptr[popcnt_idx]);
if (coeff < min_coeff)
continue;

if (top_k.size() < k)
Expand All @@ -186,6 +184,7 @@ py::array_t<Result> GenericSearchImpl(const py::array_t<uint64_t> py_query,
{
top_k.pop();
top_k.push({idx, static_cast<uint32_t>(dbptr[0]), coeff});
min_coeff = top_k.top().coeff;
}
}
results->reserve(top_k.size());
Expand All @@ -200,11 +199,9 @@ py::array_t<Result> GenericSearchImpl(const py::array_t<uint64_t> py_query,
{
for (auto i = start; i < end; i++, dbptr += fp_shape)
{
uint64_t common_popcnt = 0;
for (auto j = 1; j < popcnt_idx; j++)
common_popcnt += popcntll(qptr[j] & dbptr[j]);
uint64_t common_popcnt_val = common_popcnt(qptr, dbptr, 1, popcnt_idx);

float coeff = calc.calculate(common_popcnt, q_popcnt, dbptr[popcnt_idx]);
float coeff = calc.calculate(common_popcnt_val, q_popcnt, dbptr[popcnt_idx]);
if (coeff < threshold)
continue;
results->push_back({i, static_cast<uint32_t>(dbptr[0]), coeff});
Expand Down