Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 23 additions & 45 deletions FPSim2/src/sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,6 @@

namespace py = pybind11;

// Integer ratio common_popcnt / (rel_co_popcnt + common_popcnt).
//
// rel_co_popcnt: first addend of the denominator.
// common_popcnt: numerator, and second addend of the denominator.
//
// Returns 0 when the denominator is 0; otherwise the truncated unsigned
// quotient. Barring unsigned wrap-around of the sum, that quotient is 1
// exactly when rel_co_popcnt is 0 and common_popcnt is non-zero, and 0 in
// every other case.
inline uint32_t SubstructCoeff(const uint32_t &rel_co_popcnt,
const uint32_t &common_popcnt)
{
    const uint32_t denominator = rel_co_popcnt + common_popcnt;
    // Guard the division: an all-zero input yields a zero coefficient.
    return (denominator == 0) ? 0 : common_popcnt / denominator;
}

py::array_t<uint32_t> SubstructureScreenout(const py::array_t<uint64_t> py_query,
const py::array_t<uint64_t> py_db,
const uint32_t start,
Expand All @@ -35,40 +25,27 @@ py::array_t<uint32_t> SubstructureScreenout(const py::array_t<uint64_t> py_query

auto results = new std::vector<uint32_t>();

uint32_t coeff;
uint64_t common_popcnt = 0;
uint64_t rel_co_popcnt = 0;
for (auto i = start; i < end; i++, dbptr += fp_shape,
common_popcnt = 0, rel_co_popcnt = 0)
// Substructure match requires all query bits to be present in db entry
for (auto i = start; i < end; i++, dbptr += fp_shape)
{
bool is_match = true;
for (size_t j = 1; j < popcnt_idx; j++)
{
common_popcnt += popcntll(qptr[j] & dbptr[j]);
rel_co_popcnt += popcntll(qptr[j] & ~dbptr[j]);
// Early termination: if any query bit is not in db, skip the entry
if (qptr[j] & ~dbptr[j])
{
is_match = false;
break;
}
}
// calc optimised tversky with a=1, b=0
coeff = SubstructCoeff(rel_co_popcnt, common_popcnt);

if (coeff == 1)
if (is_match)
results->push_back((uint32_t)dbptr[0]);
}
// acquire the GIL before calling Python code
py::gil_scoped_acquire acquire;
return utils::Vector2NumPy<uint32_t>(results);
}

// Ratio common_popcnt / (common_popcnt + a * rel_co_popcnt + b * rel_co_popcnt2),
// evaluated in single precision.
//
// common_popcnt:  numerator, and unweighted term of the denominator.
// rel_co_popcnt:  count scaled by the weight a.
// rel_co_popcnt2: count scaled by the weight b.
// a, b:           weights applied to the two counts above.
//
// Returns the quotient, or the (zero) denominator itself when the
// denominator compares equal to zero, so no division by zero is performed.
inline float TverskyCoeff(const uint32_t &common_popcnt,
const uint32_t &rel_co_popcnt,
const uint32_t &rel_co_popcnt2,
const float &a, const float &b)
{
    const float denominator = common_popcnt + a * rel_co_popcnt + b * rel_co_popcnt2;
    // denominator is zero on the right-hand branch, giving a zero coefficient.
    return (denominator != 0.0) ? common_popcnt / denominator : denominator;
}

py::array_t<Result> TverskySearch(const py::array_t<uint64_t> py_query,
const py::array_t<uint64_t> py_db,
const float threshold,
Expand All @@ -86,24 +63,25 @@ py::array_t<Result> TverskySearch(const py::array_t<uint64_t> py_query,

const auto fp_shape = query.shape(0);
const auto popcnt_idx = fp_shape - 1;
const auto q_popcnt = qptr[popcnt_idx];

const float one_minus_a_minus_b = 1.0f - a - b;
const float a_times_q = a * q_popcnt;

auto results = new std::vector<Result>();

float coeff;
uint64_t common_popcnt = 0;
uint64_t rel_co_popcnt = 0;
uint64_t rel_co_popcnt2 = 0;
for (auto i = start; i < end; i++, dbptr += fp_shape,
common_popcnt = 0, rel_co_popcnt = 0, rel_co_popcnt2 = 0)
for (auto i = start; i < end; i++, dbptr += fp_shape)
{
const auto db_popcnt = dbptr[popcnt_idx];

uint64_t common_popcnt = 0;
for (auto j = 1; j < popcnt_idx; j++)
{
// popcnts of both relative complements and intersection
common_popcnt += popcntll(qptr[j] & dbptr[j]);
rel_co_popcnt += popcntll(qptr[j] & ~dbptr[j]);
rel_co_popcnt2 += popcntll(dbptr[j] & ~qptr[j]);
}
coeff = TverskyCoeff(common_popcnt, rel_co_popcnt, rel_co_popcnt2, a, b);

// Tversky: common / (common*(1-a-b) + a*q_popcnt + b*db_popcnt)
float denom = common_popcnt * one_minus_a_minus_b + a_times_q + b * db_popcnt;
float coeff = (denom != 0.0f) ? common_popcnt / denom : 0.0f;

if (coeff >= threshold)
results->push_back({i, (uint32_t)dbptr[0], coeff});
}
Expand Down