From 6b82f3fd36fc9d3255f923e95e0206ab8f06176f Mon Sep 17 00:00:00 2001 From: vcodestar Date: Sun, 25 Jan 2026 14:14:29 +0200 Subject: [PATCH 1/6] pass entities --- hnswlib/ats_dummy.h | 9 +++ hnswlib/hnswalg.h | 44 ++++++++++++++ python_bindings/bindings.cpp | 112 +++++++++++++++++++++++++++++++++++ 3 files changed, 165 insertions(+) create mode 100644 hnswlib/ats_dummy.h diff --git a/hnswlib/ats_dummy.h b/hnswlib/ats_dummy.h new file mode 100644 index 00000000..89f919be --- /dev/null +++ b/hnswlib/ats_dummy.h @@ -0,0 +1,9 @@ +#pragma once +#include + +class ATSDummy { +public: + static void ping() { + // std::cout << "[ATS] ats_dummy.h included successfully" << std::endl; + } +}; diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index e269ae69..c4f8312d 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -10,6 +10,8 @@ #include #include +#include "ats_dummy.h" + namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; @@ -70,6 +72,8 @@ class HierarchicalNSW : public AlgorithmInterface { std::mutex deleted_elements_lock; // lock for deleted_elements std::unordered_set deleted_elements; // contains internal ids of deleted elements + std::vector> node_entities_; + HierarchicalNSW(SpaceInterface *s) { } @@ -96,6 +100,7 @@ class HierarchicalNSW : public AlgorithmInterface { : label_op_locks_(MAX_LABEL_OPERATION_LOCKS), link_list_locks_(max_elements), element_levels_(max_elements), + node_entities_(max_elements), allow_replace_deleted_(allow_replace_deleted) { max_elements_ = max_elements; num_deleted_ = 0; @@ -169,6 +174,18 @@ class HierarchicalNSW : public AlgorithmInterface { } }; + double getJaccardSimilarity(const std::unordered_set& a, + const std::unordered_set& b) { + + size_t intersection = 0; + for (const auto& x : a) { + if (b.count(x)) intersection++; + } + + size_t union_count = a.size() + b.size() - intersection; + return union_count == 0 ? 0.0 : static_cast(intersection) / union_count; + } + void setEf(size_t ef) { ef_ = ef; @@ -224,6 +241,12 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> searchBaseLayer(tableint ep_id, const void *data_point, int layer) { + ATSDummy::ping(); + // std::unordered_set setA = {1, 2, 3, 4}; + // std::unordered_set setB = {3, 4, 5, 6}; + // double sim = getJaccardSimilarity(setA, setB); + // std::cout << "Jaccard similarity: " << sim << std::endl; + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; @@ -683,6 +706,7 @@ class HierarchicalNSW : public AlgorithmInterface { } void saveIndex(const std::string &location) { + std::cout<< "======================SAVING=============================\n"; std::ofstream output(location, std::ios::binary); std::streampos position; @@ -714,7 +738,9 @@ class HierarchicalNSW : public AlgorithmInterface { void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i = 0) { + std::cout<< "======================LOADING=============================\n"; std::ifstream input(location, std::ios::binary); + node_entities_.resize(cur_element_count); if (!input.is_open()) throw std::runtime_error("Cannot open file"); @@ -1265,6 +1291,24 @@ class HierarchicalNSW : public AlgorithmInterface { } return cur_c; } + + + tableint addPointWithEntities( + const void* data_point, + labeltype label, + const std::vector& entities, + int level = -1) { + tableint id = addPoint(data_point, label, level); + + // Ensure capacity + if (id >= node_entities_.size()) { + node_entities_.resize(max_elements_); + } + + node_entities_[id] = entities; + return id; + } + std::priority_queue> diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index dd09e80a..3375343c 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -248,6 +248,17 @@ class Index { } + // std::vector getEntities(tableint internal_id) { + // if (!appr_alg) { + // throw std::runtime_error("Index not initialized"); + // } + // if (internal_id >= appr_alg->getCurrentElementCount()) { + // throw std::out_of_range("Invalid internal id"); + // } + // return appr_alg->node_entities_[internal_id] + // } + + void addItems(py::object input, py::object ids_ = py::none(), int num_threads = -1, bool replace_deleted = false) { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); @@ -303,6 +314,100 @@ class Index { } } + void addItemsWithEntities(py::object input, py::object ids_ = py::none(), py::object entities_ = py::none(), + int num_threads = -1, bool replace_deleted = false) + { + + std::cout << "CUSTOM FUNCTION CALLED" << std::endl; + if (!entities_.is_none()) { + std::cout << "Entities passed:" << std::endl; + + // Convert Python object to iterable + py::list entity_list = entities_; + for (size_t i = 0; i < entity_list.size(); i++) { + py::object ent = entity_list[i]; + // Convert to string for printing + std::string ent_str = py::str(ent); + std::cout << " " << i << ": " << ent_str << std::endl; + } + } else { + std::cout << "No entities provided." << std::endl; + } + + py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); + auto buffer = items.request(); + if (num_threads <= 0) + num_threads = num_threads_default; + + size_t rows, features; + get_input_array_shapes(buffer, &rows, &features); + + if (features != dim) + throw std::runtime_error("Wrong dimensionality of the vectors"); + + // avoid using threads when the number of additions is small: + if (rows <= num_threads * 4) { + num_threads = 1; + } + + std::vector ids = get_input_ids_and_check_shapes(ids_, rows); + + py::array_t entities_arr = + entities_.cast>(); + + auto buf = entities_arr.request(); + + size_t entity_rows = buf.shape[0]; + size_t cols = buf.shape[1]; + + auto* data = static_cast(buf.ptr); + + for (size_t id : ids) { + std::cout << "id is: " << id << "\n"; + std::cout << "entities[" << id << "]: "; + + for (size_t j = 0; j < cols; j++) { + std::cout << data[id * cols + j] << " "; + } + std::cout << "\n"; + } + + { + int start = 0; + if (!ep_added) { + size_t id = ids.size() ? ids.at(0) : (cur_l); + float* vector_data = (float*)items.data(0); + std::vector norm_array(dim); + if (normalize) { + normalize_vector(vector_data, norm_array.data()); + vector_data = norm_array.data(); + } + appr_alg->addPoint((void*)vector_data, (size_t)id, replace_deleted); + start = 1; + ep_added = true; + } + + py::gil_scoped_release l; + if (normalize == false) { + ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { + size_t id = ids.size() ? ids.at(row) : (cur_l + row); + appr_alg->addPoint((void*)items.data(row), (size_t)id, replace_deleted); + }); + } else { + std::vector norm_array(num_threads * dim); + ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { + // normalize vector: + size_t start_idx = threadId * dim; + normalize_vector((float*)items.data(row), (norm_array.data() + start_idx)); + + size_t id = ids.size() ? ids.at(row) : (cur_l + row); + appr_alg->addPoint((void*)(norm_array.data() + start_idx), (size_t)id, replace_deleted); + }); + } + cur_l += rows; + } + } + py::object getData(py::object ids_ = py::none(), std::string return_type = "numpy") { std::vector return_types{"numpy", "list"}; @@ -934,6 +1039,13 @@ PYBIND11_PLUGIN(hnswlib) { py::arg("ids") = py::none(), py::arg("num_threads") = -1, py::arg("replace_deleted") = false) + .def("add_items_with_entities", + &Index::addItemsWithEntities, + py::arg("data"), + py::arg("ids_") = py::none(), + py::arg("entities_") = py::none(), + py::arg("num_threads") = -1, + py::arg("replace_deleted") = false) .def("get_items", &Index::getData, py::arg("ids") = py::none(), py::arg("return_type") = "numpy") .def("get_ids_list", &Index::getIdsList) .def("set_ef", &Index::set_ef, py::arg("ef")) From a050e5b8248aebdf3c56fe734c322a0f4019f23f Mon Sep 17 00:00:00 2001 From: vcodestar Date: Sun, 25 Jan 2026 15:41:30 +0200 Subject: [PATCH 2/6] move entities inside array --- hnswlib/hnswalg.h | 13 +++++++++++++ python_bindings/bindings.cpp | 21 ++++++++++++++++----- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index c4f8312d..0d1030ba 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -186,6 +186,19 @@ class HierarchicalNSW : public AlgorithmInterface { return union_count == 0 ? 0.0 : static_cast(intersection) / union_count; } + void setNodeEntities(const std::vector>& entities) { + node_entities_ = entities; + + std::cout << "==================================Node entities set: " << std::endl; + for (size_t i = 0; i < node_entities_.size(); i++) { + std::cout << "Node " << i << ": "; + for (size_t j = 0; j < node_entities_[i].size(); j++) { + std::cout << node_entities_[i][j] << " "; + } + std::cout << std::endl; + } + } + void setEf(size_t ef) { ef_ = ef; diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 3375343c..60141677 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -362,15 +362,26 @@ class Index { auto* data = static_cast(buf.ptr); - for (size_t id : ids) { - std::cout << "id is: " << id << "\n"; - std::cout << "entities[" << id << "]: "; + // for (size_t id : ids) { + // std::cout << "id is: " << id << "\n"; + // std::cout << "entities[" << id << "]: "; + // for (size_t j = 0; j < cols; j++) { + // std::cout << data[id * cols + j] << " "; + // } + // std::cout << "\n"; + // } + + std::vector> entities_cpp(entity_rows, std::vector(cols)); + + for (size_t i = 0; i < entity_rows; i++) { for (size_t j = 0; j < cols; j++) { - std::cout << data[id * cols + j] << " "; + entities_cpp[i][j] = data[i * cols + j]; } - std::cout << "\n"; } + + appr_alg->setNodeEntities(entities_cpp); + { int start = 0; From 6987e3030b7615376665b8363265ef6abfef51d1 Mon Sep 17 00:00:00 2001 From: vcodestar Date: Sun, 25 Jan 2026 15:59:16 +0200 Subject: [PATCH 3/6] variable size entities lists --- hnswlib/hnswalg.h | 7 +++++++ python_bindings/bindings.cpp | 36 +++++++++++++----------------------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 0d1030ba..cba76224 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -197,6 +197,13 @@ class HierarchicalNSW : public AlgorithmInterface { } std::cout << std::endl; } + + std::cout << "Total nodes: " << node_entities_.size() << std::endl; + + for (size_t i = 0; i < node_entities_.size(); i++) { + std::cout << "Node " << i << " has " << node_entities_[i].size() << " entities." << std::endl; + } + } diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 60141677..74e2837b 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -352,34 +352,24 @@ class Index { std::vector ids = get_input_ids_and_check_shapes(ids_, rows); - py::array_t entities_arr = - entities_.cast>(); - - auto buf = entities_arr.request(); - - size_t entity_rows = buf.shape[0]; - size_t cols = buf.shape[1]; - - auto* data = static_cast(buf.ptr); - - // for (size_t id : ids) { - // std::cout << "id is: " << id << "\n"; - // std::cout << "entities[" << id << "]: "; + std::vector> entities_cpp; + if (!entities_.is_none()) { + py::list entity_list = entities_; + if (entity_list.size() != rows) + throw std::runtime_error("Number of entities lists must match number of vectors"); - // for (size_t j = 0; j < cols; j++) { - // std::cout << data[id * cols + j] << " "; - // } - // std::cout << "\n"; - // } + for (size_t i = 0; i < entity_list.size(); i++) { + py::list single_node = entity_list[i]; + std::vector node_entities; - std::vector> entities_cpp(entity_rows, std::vector(cols)); + for (size_t j = 0; j < single_node.size(); j++) { + node_entities.push_back(single_node[j].cast()); + } - for (size_t i = 0; i < entity_rows; i++) { - for (size_t j = 0; j < cols; j++) { - entities_cpp[i][j] = data[i * cols + j]; + entities_cpp.push_back(node_entities); } } - + appr_alg->setNodeEntities(entities_cpp); From 1bc6c8a757b61c64c8ea926ffc7be4ec157c30d9 Mon Sep 17 00:00:00 2001 From: vcodestar Date: Sun, 25 Jan 2026 17:22:28 +0200 Subject: [PATCH 4/6] save/load entities --- hnswlib/hnswalg.h | 73 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 16 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index cba76224..e23ce8e1 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -200,10 +200,16 @@ class HierarchicalNSW : public AlgorithmInterface { std::cout << "Total nodes: " << node_entities_.size() << std::endl; + size_t total_bytes = 0; + for (size_t i = 0; i < node_entities_.size(); i++) { - std::cout << "Node " << i << " has " << node_entities_[i].size() << " entities." << std::endl; + size_t node_size = node_entities_[i].size() * sizeof(tableint); + total_bytes += node_size; + std::cout << "Node " << i << " has " << node_entities_[i].size() + << " entities, approx " << node_size << " bytes" << std::endl; } + std::cout << "Approx total memory for all entities: " << total_bytes << " bytes" << std::endl; } @@ -722,6 +728,13 @@ class HierarchicalNSW : public AlgorithmInterface { size += sizeof(linkListSize); size += linkListSize; } + + for (size_t i = 0; i < node_entities_.size(); i++) { + unsigned int numEntities = node_entities_[i].size(); + size += sizeof(numEntities); + size += numEntities * sizeof(tableint); + } + return size; } @@ -733,6 +746,7 @@ class HierarchicalNSW : public AlgorithmInterface { writeBinaryPOD(output, offsetLevel0_); writeBinaryPOD(output, max_elements_); writeBinaryPOD(output, cur_element_count); + std::cout << "CURRENT ELEMENT COUNT: " << cur_element_count << "\n"; writeBinaryPOD(output, size_data_per_element_); writeBinaryPOD(output, label_offset_); writeBinaryPOD(output, offsetData_); @@ -753,6 +767,16 @@ class HierarchicalNSW : public AlgorithmInterface { if (linkListSize) output.write(linkLists_[i], linkListSize); } + + for (size_t i = 0; i < node_entities_.size(); i++) { + unsigned int numEntities = node_entities_[i].size(); + writeBinaryPOD(output, numEntities); + if (numEntities > 0) { + output.write(reinterpret_cast(node_entities_[i].data()), + numEntities * sizeof(tableint)); + } + } + output.close(); } @@ -760,7 +784,7 @@ class HierarchicalNSW : public AlgorithmInterface { void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i = 0) { std::cout<< "======================LOADING=============================\n"; std::ifstream input(location, std::ios::binary); - node_entities_.resize(cur_element_count); + // node_entities_.resize(cur_element_count); if (!input.is_open()) throw std::runtime_error("Cannot open file"); @@ -798,22 +822,22 @@ class HierarchicalNSW : public AlgorithmInterface { auto pos = input.tellg(); /// Optional - check if index is ok: - input.seekg(cur_element_count * size_data_per_element_, input.cur); - for (size_t i = 0; i < cur_element_count; i++) { - if (input.tellg() < 0 || input.tellg() >= total_filesize) { - throw std::runtime_error("Index seems to be corrupted or unsupported"); - } - - unsigned int linkListSize; - readBinaryPOD(input, linkListSize); - if (linkListSize != 0) { - input.seekg(linkListSize, input.cur); - } - } + // input.seekg(cur_element_count * size_data_per_element_, input.cur); + // for (size_t i = 0; i < cur_element_count; i++) { + // if (input.tellg() < 0 || input.tellg() >= total_filesize) { + // throw std::runtime_error("Index seems to be corrupted or unsupported"); + // } + + // unsigned int linkListSize; + // readBinaryPOD(input, linkListSize); + // if (linkListSize != 0) { + // input.seekg(linkListSize, input.cur); + // } + // } // throw exception if it either corrupted or old index - if (input.tellg() != total_filesize) - throw std::runtime_error("Index seems to be corrupted or unsupported"); + // if (input.tellg() != total_filesize) + // throw std::runtime_error("Index seems to be corrupted or unsupported"); input.clear(); /// Optional check end @@ -862,7 +886,24 @@ class HierarchicalNSW : public AlgorithmInterface { } } + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int numEntities; + readBinaryPOD(input, numEntities); + + std::vector entities(numEntities); + if (numEntities > 0) + input.read(reinterpret_cast(entities.data()), numEntities * sizeof(tableint)); + + node_entities_.push_back(std::move(entities)); + } + input.close(); + for (size_t i = 0; i < node_entities_.size(); i++) { + size_t node_size = node_entities_[i].size() * sizeof(tableint); + std::cout << "Node " << i << " has " << node_entities_[i].size() + << " entities, approx " << node_size << " bytes" << std::endl; + } + return; } From bc259509b9b04aebed9395322e52a62a1f699155 Mon Sep 17 00:00:00 2001 From: vcodestar Date: Sun, 25 Jan 2026 18:14:32 +0200 Subject: [PATCH 5/6] optional check fix --- hnswlib/hnswalg.h | 52 ++++++++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index e23ce8e1..2aba2cd2 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -821,23 +821,31 @@ class HierarchicalNSW : public AlgorithmInterface { auto pos = input.tellg(); - /// Optional - check if index is ok: - // input.seekg(cur_element_count * size_data_per_element_, input.cur); - // for (size_t i = 0; i < cur_element_count; i++) { - // if (input.tellg() < 0 || input.tellg() >= total_filesize) { - // throw std::runtime_error("Index seems to be corrupted or unsupported"); - // } - - // unsigned int linkListSize; - // readBinaryPOD(input, linkListSize); - // if (linkListSize != 0) { - // input.seekg(linkListSize, input.cur); - // } - // } - - // throw exception if it either corrupted or old index - // if (input.tellg() != total_filesize) - // throw std::runtime_error("Index seems to be corrupted or unsupported"); + // optional check + input.seekg(cur_element_count * size_data_per_element_, input.cur); + for (size_t i = 0; i < cur_element_count; i++) { + if (input.tellg() < 0 || input.tellg() > total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); + + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + + if (linkListSize) + input.seekg(linkListSize, std::ios::cur); + } + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int numEntities; + readBinaryPOD(input, numEntities); + + input.seekg(numEntities * sizeof(tableint), std::ios::cur); + + if (input.tellg() < 0 || input.tellg() > total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); + } + + if (input.tellg() != total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); input.clear(); /// Optional check end @@ -900,8 +908,14 @@ class HierarchicalNSW : public AlgorithmInterface { input.close(); for (size_t i = 0; i < node_entities_.size(); i++) { size_t node_size = node_entities_[i].size() * sizeof(tableint); - std::cout << "Node " << i << " has " << node_entities_[i].size() - << " entities, approx " << node_size << " bytes" << std::endl; + std::cout << "Loaded Node " << i << " has " << node_entities_[i].size() + << " entities, approx " << node_size << " bytes: "; + + for (size_t j = 0; j < node_entities_[i].size(); j++) { + std::cout << node_entities_[i][j] << " "; + } + + std::cout << std::endl; } From efd53ad8ce945f7f89d35bfa7490cc4fb1c3a8b4 Mon Sep 17 00:00:00 2001 From: vcodestar Date: Thu, 29 Jan 2026 11:40:10 +0200 Subject: [PATCH 6/6] added AhoCorasick Trie --- hnswlib/AhoCorasick.cpp | 142 ++++++++++++++++++++++++++++++++++++++++ hnswlib/AhoCorasick.h | 25 +++++++ hnswlib/TrieNode.h | 14 ++++ hnswlib/hnswalg.h | 24 +++++++ hnswlib/test.cpp | 32 +++++++++ setup.py | 24 +++++-- 6 files changed, 254 insertions(+), 7 deletions(-) create mode 100644 hnswlib/AhoCorasick.cpp create mode 100644 hnswlib/AhoCorasick.h create mode 100644 hnswlib/TrieNode.h create mode 100644 hnswlib/test.cpp diff --git a/hnswlib/AhoCorasick.cpp b/hnswlib/AhoCorasick.cpp new file mode 100644 index 00000000..9957d822 --- /dev/null +++ b/hnswlib/AhoCorasick.cpp @@ -0,0 +1,142 @@ +#include "AhoCorasick.h" +#include +#include +#include +#include +#include + +AhoCorasick::AhoCorasick() { + root = new TrieNode(); +} + +AhoCorasick::~AhoCorasick() { + deleteTrie(root); +} + +void AhoCorasick::deleteTrie(TrieNode* node) { + if (!node) return; + for (auto& kv: node->children) { + deleteTrie(kv.second); + } + delete node; +} + +size_t countUniqueWordsInTrie(TrieNode* node, std::unordered_set& seen) { + if (!node) return 0; + for (auto& word : node->outputs) { + seen.insert(word); + } + for (auto& [c, child] : node->children) { + countUniqueWordsInTrie(child, seen); + } + return seen.size(); +} + +size_t AhoCorasick::numWords() const { + std::unordered_set seen; + return countUniqueWordsInTrie(root, seen); +} + +void AhoCorasick::build(const std::vector& entities) { + for (const auto& word: entities) { + TrieNode* node = root; + for (char c : word) { + if (node->children.find(c) == node->children.end()) { + node->children[c] = new TrieNode(); + } + node = node->children[c]; + } + node->outputs.push_back(word); + } + buildFailureLinks(); +} + +void AhoCorasick::buildFailureLinks() { + std::queue q; + root->failure = root; + + for (auto& kv : root->children) { + kv.second ->failure = root; + q.push(kv.second); + } + + while (!q.empty()) { + TrieNode* current = q.front(); q.pop(); + for (auto& kv : current->children) { + char c = kv.first; + TrieNode* child = kv.second; + + TrieNode* fail = current->failure; + while (fail != root && fail -> children.find(c) == fail->children.end()) { + fail = fail->failure; + } + + if (fail->children.find(c) != fail->children.end()) { + child->failure = fail->children[c]; + } else { + child->failure = root; + } + + child->outputs.insert(child -> outputs.end(), + child->failure->outputs.begin(), + child->failure->outputs.end()); + + q.push(child); + } + } +} + +std::vector> AhoCorasick::search(const std::string& text) const { + std::vector> matches; + TrieNode* node = root; + + for (size_t i = 0; i < text.size(); i++) { + char c = text[i]; + while (node != root && node->children.find(c) == node->children.end()) + node = node->failure; + + if (node->children.find(c) != node->children.end()) + node = node->children.at(c); + + for (const auto& out : node->outputs) + matches.push_back({i - out.size() + 1, out}); // start index + } + + // sort matches by start index + std::sort(matches.begin(), matches.end(), + [](const auto &a, const auto &b){ return a.first < b.first; }); + + std::vector> longest_matches; +std::unordered_map> best_at_start; + +for (auto& m : matches) { + size_t start = m.first; + size_t len = m.second.size(); + // keep only the longest word for this start index + if (best_at_start.find(start) == best_at_start.end() || len > best_at_start[start].first) { + best_at_start[start] = {len, m.second}; + } +} + +// now collect and sort by start index +for (auto& kv : best_at_start) { + longest_matches.push_back({kv.first, kv.second.second}); +} +std::sort(longest_matches.begin(), longest_matches.end(), + [](const auto &a, const auto &b){ return a.first < b.first; }); + + + return longest_matches; +} + + + +void AhoCorasick::save(const std::string& filename) const { + std::ofstream out(filename, std::ios::binary); + std::cout << "Trie saving not implemented yet\n"; +} + +void AhoCorasick::load(const std::string& filename) { + std::ifstream in(filename, std::ios::binary); + std::cout << "Trie loading not implemented yet\n"; +} \ No newline at end of file diff --git a/hnswlib/AhoCorasick.h b/hnswlib/AhoCorasick.h new file mode 100644 index 00000000..aed42373 --- /dev/null +++ b/hnswlib/AhoCorasick.h @@ -0,0 +1,25 @@ +#pragma once +#include "TrieNode.h" +#include +#include + +class AhoCorasick { + +public: + AhoCorasick(); + ~AhoCorasick(); + + size_t numWords() const; + + void build(const std::vector &entities); + + std::vector> search(const std::string& text) const; + + void save(const std::string& filename) const; + void load(const std::string& filename); + +private: + TrieNode* root; + void buildFailureLinks(); + void deleteTrie(TrieNode* node); +}; \ No newline at end of file diff --git a/hnswlib/TrieNode.h b/hnswlib/TrieNode.h new file mode 100644 index 00000000..3e4bab24 --- /dev/null +++ b/hnswlib/TrieNode.h @@ -0,0 +1,14 @@ +#pragma once +#include +#include +#include + +struct TrieNode; + +struct TrieNode { + std::unordered_map children; + TrieNode* failure; + std::vector outputs; + + TrieNode() : failure(nullptr) {} +}; \ No newline at end of file diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 2aba2cd2..6cf2dea6 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -11,6 +11,7 @@ #include #include "ats_dummy.h" +#include "AhoCorasick.h" namespace hnswlib { typedef unsigned int tableint; @@ -743,6 +744,29 @@ class HierarchicalNSW : public AlgorithmInterface { std::ofstream output(location, std::ios::binary); std::streampos position; + // ---- Aho-Corasick small test ---- + { + std::vector patterns = { + "apple", + "banana", + "pie" + }; + + AhoCorasick ac; + ac.build(patterns); + + std::string text = "I like apple pie and banana bread"; + auto matches = ac.search(text); + + std::cout << "=== Aho-Corasick test ===\n"; + for (const auto &m : matches) { + std::cout << "Found \"" << m.second + << "\" at position " << m.first << "\n"; + } + std::cout << "========================\n"; + } + // ---- end test ---- + writeBinaryPOD(output, offsetLevel0_); writeBinaryPOD(output, max_elements_); writeBinaryPOD(output, cur_element_count); diff --git a/hnswlib/test.cpp b/hnswlib/test.cpp new file mode 100644 index 00000000..62f8953f --- /dev/null +++ b/hnswlib/test.cpp @@ -0,0 +1,32 @@ +#include "AhoCorasick.h" +#include +#include + +int main() { + // Example symbolic entities for Tesla Model Y retrieval + std::vector entities = { + "Tesla", "Tesla Model Y", "electric", "autopilot", "red", "blue", + "interior", "battery", "long range", "performance", "range" + }; + + // Build the Trie + AhoCorasick ac; + ac.build(entities); + + // Example user query + std::string query = "Show me a red Tesla Model Y with autopilot and long range battery"; + + // Search for symbolic entities + auto matches = ac.search(query); + + // Display found entities and positions + std::cout << "Detected symbolic entities in query:\n"; + for (auto& m : matches) { + std::cout << " - '" << m.second << "' at position " << m.first << "\n"; + } + + // Total number of entities in the Trie + std::cout << "\nTotal symbolic entities in Trie: " << ac.numWords() << "\n"; + + return 0; +} diff --git a/setup.py b/setup.py index d96aea49..0b1ca1fd 100644 --- a/setup.py +++ b/setup.py @@ -18,12 +18,21 @@ # compatibility when run in python_bindings bindings_dir = 'python_bindings' +this_dir = os.path.abspath(os.path.dirname(__file__)) + if bindings_dir in os.path.basename(os.getcwd()): - source_files = ['./bindings.cpp'] - include_dirs.extend(['../hnswlib/']) + source_files = [ + os.path.join(this_dir, 'python_bindings', 'bindings.cpp'), + os.path.join(this_dir, 'hnswlib', 'AhoCorasick.cpp'), + ] + include_dirs.extend([os.path.join(this_dir, 'hnswlib')]) else: - source_files = ['./python_bindings/bindings.cpp'] - include_dirs.extend(['./hnswlib/']) + source_files = [ + os.path.join(this_dir, 'python_bindings', 'bindings.cpp'), + os.path.join(this_dir, 'hnswlib', 'AhoCorasick.cpp'), + ] + include_dirs.extend([os.path.join(this_dir, 'hnswlib')]) + libraries = [] @@ -62,10 +71,10 @@ def cpp_flag(compiler): """Return the -std=c++[11/14] compiler flag. The c++14 is prefered over c++11 (when it is available). """ - if has_flag(compiler, '-std=c++14'): + if has_flag(compiler, '/std:c++17'): + return '/std:c++17' + elif has_flag(compiler, '-std=c++14'): return '-std=c++14' - elif has_flag(compiler, '-std=c++11'): - return '-std=c++11' else: raise RuntimeError('Unsupported compiler -- at least C++11 support ' 'is needed!') @@ -119,6 +128,7 @@ def build_extensions(self): else: print(f'flag: {BuildExt.compiler_flag_native} is available') elif ct == 'msvc': + opts.append('/std:c++17') opts.append('/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version()) for ext in self.extensions: