diff --git a/FPSim2/src/sim.cpp b/FPSim2/src/sim.cpp index 86dd7e2..42c106e 100644 --- a/FPSim2/src/sim.cpp +++ b/FPSim2/src/sim.cpp @@ -8,16 +8,6 @@ namespace py = pybind11; -inline uint32_t SubstructCoeff(const uint32_t &rel_co_popcnt, - const uint32_t &common_popcnt) -{ - uint32_t coeff = 0; - coeff = rel_co_popcnt + common_popcnt; - if (coeff != 0) - coeff = common_popcnt / coeff; - return coeff; -} - py::array_t SubstructureScreenout(const py::array_t py_query, const py::array_t py_db, const uint32_t start, @@ -35,21 +25,20 @@ py::array_t SubstructureScreenout(const py::array_t py_query auto results = new std::vector(); - uint32_t coeff; - uint64_t common_popcnt = 0; - uint64_t rel_co_popcnt = 0; - for (auto i = start; i < end; i++, dbptr += fp_shape, - common_popcnt = 0, rel_co_popcnt = 0) + // Substructure match requires all query bits to be present in db entry + for (auto i = start; i < end; i++, dbptr += fp_shape) { + bool is_match = true; for (size_t j = 1; j < popcnt_idx; j++) { - common_popcnt += popcntll(qptr[j] & dbptr[j]); - rel_co_popcnt += popcntll(qptr[j] & ~dbptr[j]); + // Early termination: if any query bit is not in db, skip the entry + if (qptr[j] & ~dbptr[j]) + { + is_match = false; + break; + } } - // calc optimised tversky with a=1, b=0 - coeff = SubstructCoeff(rel_co_popcnt, common_popcnt); - - if (coeff == 1) + if (is_match) results->push_back((uint32_t)dbptr[0]); } // acquire the GIL before calling Python code @@ -57,18 +46,6 @@ py::array_t SubstructureScreenout(const py::array_t py_query return utils::Vector2NumPy(results); } -inline float TverskyCoeff(const uint32_t &common_popcnt, - const uint32_t &rel_co_popcnt, - const uint32_t &rel_co_popcnt2, - const float &a, const float &b) -{ - float coeff = 0.0; - coeff = common_popcnt + a * rel_co_popcnt + b * rel_co_popcnt2; - if (coeff != 0.0) - coeff = common_popcnt / coeff; - return coeff; -} - py::array_t TverskySearch(const py::array_t py_query, const py::array_t py_db, const float threshold, @@ -86,24 +63,25 @@ py::array_t TverskySearch(const py::array_t py_query, const auto fp_shape = query.shape(0); const auto popcnt_idx = fp_shape - 1; + const auto q_popcnt = qptr[popcnt_idx]; + + const float one_minus_a_minus_b = 1.0f - a - b; + const float a_times_q = a * q_popcnt; auto results = new std::vector(); - float coeff; - uint64_t common_popcnt = 0; - uint64_t rel_co_popcnt = 0; - uint64_t rel_co_popcnt2 = 0; - for (auto i = start; i < end; i++, dbptr += fp_shape, - common_popcnt = 0, rel_co_popcnt = 0, rel_co_popcnt2 = 0) + for (auto i = start; i < end; i++, dbptr += fp_shape) { + const auto db_popcnt = dbptr[popcnt_idx]; + + uint64_t common_popcnt = 0; for (auto j = 1; j < popcnt_idx; j++) - { - // popcnts of both relative complements and intersection common_popcnt += popcntll(qptr[j] & dbptr[j]); - rel_co_popcnt += popcntll(qptr[j] & ~dbptr[j]); - rel_co_popcnt2 += popcntll(dbptr[j] & ~qptr[j]); - } - coeff = TverskyCoeff(common_popcnt, rel_co_popcnt, rel_co_popcnt2, a, b); + + // Tversky: common / (common*(1-a-b) + a*q_popcnt + b*db_popcnt) + float denom = common_popcnt * one_minus_a_minus_b + a_times_q + b * db_popcnt; + float coeff = (denom != 0.0f) ? common_popcnt / denom : 0.0f; + if (coeff >= threshold) results->push_back({i, (uint32_t)dbptr[0], coeff}); }