
Commit 9c6d44d

feat(search): Simple multi-reader multi-writer mutex for hnsw index (#6156)
Implementation of a multi-reader multi-writer mutex that supports concurrent writes or concurrent reads, but not a mix of reads and writes at the same time. Adds a unit test for the MRMWMutex class. A minimal usage sketch (not part of the commit) follows the metadata below.

Signed-off-by: mkaruza <mario@dragonflydb.io>
1 parent e32a4ed commit 9c6d44d
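
A minimal usage sketch of the new primitive (hypothetical caller code, not part of this commit; only MRMWMutex and MRMWMutexLock come from the diff below): each side takes the RAII guard with its mode, and any number of holders of the same mode run together while the opposite mode waits.

  #include "core/search/mrmw_mutex.h"

  using dfly::search::MRMWMutex;
  using dfly::search::MRMWMutexLock;

  MRMWMutex index_mutex;  // hypothetical shared instance

  void MutateIndex() {
    // Writers may overlap each other; any active readers make this block.
    MRMWMutexLock lock(&index_mutex, MRMWMutex::LockMode::kWriteLock);
    // ... mutate shared state that tolerates concurrent writers ...
  }

  void SearchIndex() {
    // Readers may overlap each other; any active writers make this block.
    MRMWMutexLock lock(&index_mutex, MRMWMutex::LockMode::kReadLock);
    // ... read shared state ...
  }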

File tree

4 files changed: +321 −2 lines changed

src/core/search/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ cxx_test(range_tree_test dfly_search_core absl::random_random LABELS DFLY)
 cxx_test(rax_tree_test redis_test_lib LABELS DFLY)
 cxx_test(search_parser_test dfly_search_core LABELS DFLY)
 cxx_test(search_test redis_test_lib dfly_search_core LABELS DFLY)
+cxx_test(mrmw_mutex_test redis_test_lib dfly_search_core fibers2 LABELS DFLY)
 
 if(WITH_SIMSIMD)
   target_link_libraries(search_test TRDP::simsimd)

src/core/search/hnsw_index.cc

Lines changed: 7 additions & 2 deletions

@@ -5,13 +5,13 @@
 #include "core/search/hnsw_index.h"
 
 #include <absl/strings/match.h>
-#include <absl/synchronization/mutex.h>
 #include <hnswlib/hnswlib.h>
 #include <hnswlib/space_ip.h>
 #include <hnswlib/space_l2.h>
 
 #include "base/logging.h"
 #include "core/search/hnsw_alg.h"
+#include "core/search/mrmw_mutex.h"
 #include "core/search/vector_utils.h"
 
 namespace dfly::search {
@@ -70,7 +70,8 @@ struct HnswlibAdapter {
   void Add(const float* data, GlobalDocId id) {
     while (true) {
       try {
-        absl::ReaderMutexLock lock(&resize_mutex_);
+        MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock);
+        absl::ReaderMutexLock resize_lock(&resize_mutex_);
         world_.addPoint(data, id);
         return;
       } catch (const std::exception& e) {
@@ -86,6 +87,7 @@ struct HnswlibAdapter {
 
   void Remove(GlobalDocId id) {
     try {
+      MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock);
       world_.markDelete(id);
     } catch (const std::exception& e) {
       LOG(WARNING) << "HnswlibAdapter::Remove exception: " << e.what();
@@ -94,6 +96,7 @@ struct HnswlibAdapter {
 
   vector<pair<float, GlobalDocId>> Knn(float* target, size_t k, std::optional<size_t> ef) {
     world_.setEf(ef.value_or(kDefaultEfRuntime));
+    MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock);
     return QueueToVec(world_.searchKnn(target, k));
   }
 
@@ -111,6 +114,7 @@ struct HnswlibAdapter {
 
     world_.setEf(ef.value_or(kDefaultEfRuntime));
     BinsearchFilter filter{&allowed};
+    MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock);
     return QueueToVec(world_.searchKnn(target, k, &filter));
   }
 
@@ -153,6 +157,7 @@ struct HnswlibAdapter {
   HnswSpace space_;
   HierarchicalNSW<float> world_;
   absl::Mutex resize_mutex_;
+  mutable MRMWMutex mrmw_mutex_;
 };
 
 HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
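
A hedged summary of the locking pattern above (my reading of the diff, not text from the commit): mutations and searches map onto the two modes so that they exclude each other but not themselves.

  // Add()/Remove()   -> MRMWMutex::LockMode::kWriteLock: mutations may overlap each
  //                     other, presumably relying on hnswlib's internal locking.
  // Knn() overloads  -> MRMWMutex::LockMode::kReadLock: searches may overlap each other.
  // The MRMW mutex only forbids a search overlapping a mutation; the pre-existing
  // resize_mutex_ still guards Add() against index resizing and now nests inside
  // the MRMW write lock.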

src/core/search/mrmw_mutex.h

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
+// Copyright 2025, DragonflyDB authors. All rights reserved.
+// See LICENSE for licensing terms.
+//
+
+#include "base/logging.h"
+#include "util/fibers/synchronization.h"
+
+namespace dfly::search {
+
+// Simple implementation of a multi-Reader multi-Writer Mutex.
+// MRMWMutex supports concurrent reads or concurrent writes, but not a mix of
+// concurrent reads and writes at the same time.
+
+class MRMWMutex {
+ public:
+  enum class LockMode : uint8_t { kReadLock, kWriteLock };
+
+  MRMWMutex() : lock_mode_(LockMode::kReadLock) {
+  }
+
+  void Lock(LockMode mode) {
+    std::unique_lock lk(mutex_);
+
+    // If we have any active runners, we need to check the lock mode.
+    if (active_runners_) {
+      auto& waiters = GetWaiters(mode);
+      waiters++;
+      GetCondVar(mode).wait(lk, [&] { return lock_mode_ == mode; });
+      waiters--;
+    } else {
+      // No active runners, so just update to the requested lock mode.
+      lock_mode_ = mode;
+    }
+    active_runners_++;
+  }
+
+  void Unlock(LockMode mode) {
+    std::unique_lock lk(mutex_);
+    LockMode inverse_mode = GetInverseMode(mode);
+    active_runners_--;
+    // If this was the last runner and there are waiters on the inverse mode.
+    if (!active_runners_ && GetWaiters(inverse_mode) > 0) {
+      lock_mode_ = inverse_mode;
+      GetCondVar(inverse_mode).notify_all();
+    }
+  }
+
+ private:
+  inline size_t& GetWaiters(LockMode target_mode) {
+    return target_mode == LockMode::kReadLock ? reader_waiters_ : writer_waiters_;
+  }
+
+  inline util::fb2::CondVar& GetCondVar(LockMode target_mode) {
+    return target_mode == LockMode::kReadLock ? reader_cond_var_ : writer_cond_var_;
+  }
+
+  static inline LockMode GetInverseMode(LockMode mode) {
+    return mode == LockMode::kReadLock ? LockMode::kWriteLock : LockMode::kReadLock;
+  }
+
+  util::fb2::Mutex mutex_;
+  util::fb2::CondVar reader_cond_var_, writer_cond_var_;
+  size_t writer_waiters_ = 0, reader_waiters_ = 0;
+  size_t active_runners_ = 0;
+  LockMode lock_mode_;
+};
+
+class MRMWMutexLock {
+ public:
+  explicit MRMWMutexLock(MRMWMutex* mutex, MRMWMutex::LockMode mode)
+      : mutex_(mutex), lock_mode_(mode) {
+    mutex->Lock(lock_mode_);
+  }
+
+  ~MRMWMutexLock() {
+    mutex_->Unlock(lock_mode_);
+  }
+
+  MRMWMutexLock(const MRMWMutexLock&) = delete;
+  MRMWMutexLock(MRMWMutexLock&&) = delete;
+  MRMWMutexLock& operator=(const MRMWMutexLock&) = delete;
+  MRMWMutexLock& operator=(MRMWMutexLock&&) = delete;
+
+ private:
+  MRMWMutex* const mutex_;
+  MRMWMutex::LockMode lock_mode_;
+};
+
+}  // namespace dfly::search
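
One subtlety of the implementation above, stated as an observation rather than documented behavior: Lock() only blocks while holders of the opposite mode are active, and Unlock() flips lock_mode_ only when the last holder leaves and the opposite side has waiters, so a steady stream of same-mode acquirers can keep the opposite mode waiting; the design trades fairness for simplicity. A minimal sketch of the resulting schedule, reusing the fiber-pool setup from the test file below (the driver code here is hypothetical):

  MRMWMutex mu;

  auto r1 = pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
    MRMWMutexLock lock(&mu, MRMWMutex::LockMode::kReadLock);  // no active holders: acquires, mode = read
    util::ThisFiber::SleepFor(std::chrono::milliseconds(50));
  });
  auto r2 = pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
    MRMWMutexLock lock(&mu, MRMWMutex::LockMode::kReadLock);  // mode already read: joins r1 at once
    util::ThisFiber::SleepFor(std::chrono::milliseconds(50));
  });
  auto w = pp_->at(1)->LaunchFiber(util::fb2::Launch::post, [&] {
    // Opposite mode is active: waits until both readers unlock, then mode flips to write.
    MRMWMutexLock lock(&mu, MRMWMutex::LockMode::kWriteLock);
  });
  r1.Join();
  r2.Join();
  w.Join();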

src/core/search/mrmw_mutex_test.cc

Lines changed: 224 additions & 0 deletions

@@ -0,0 +1,224 @@
+// Copyright 2025, DragonflyDB authors. All rights reserved.
+// See LICENSE for licensing terms.
+//
+
+#include "core/search/mrmw_mutex.h"
+
+#include <random>
+#include <thread>
+
+#include "absl/flags/flag.h"
+#include "base/gtest.h"
+#include "base/logging.h"
+#include "util/fibers/pool.h"
+
+ABSL_FLAG(bool, force_epoll, false, "If true, uses the epoll API instead of iouring to run tests");
+
+namespace dfly::search {
+
+namespace {
+
+// Helper function to simulate a reading operation.
+void ReadTask(MRMWMutex* mutex, std::atomic<size_t>& read_count, size_t sleep_time) {
+  read_count.fetch_add(1, std::memory_order_relaxed);
+  MRMWMutexLock lock(mutex, MRMWMutex::LockMode::kReadLock);
+  util::ThisFiber::SleepFor(std::chrono::milliseconds(sleep_time));
+  read_count.fetch_sub(1, std::memory_order_relaxed);
+}
+
+// Helper function to simulate a writing operation.
+void WriteTask(MRMWMutex* mutex, std::atomic<size_t>& write_count, size_t sleep_time) {
+  write_count.fetch_add(1, std::memory_order_relaxed);
+  MRMWMutexLock lock(mutex, MRMWMutex::LockMode::kWriteLock);
+  util::ThisFiber::SleepFor(std::chrono::milliseconds(sleep_time));
+  write_count.fetch_sub(1, std::memory_order_relaxed);
+}
+
+constexpr size_t kReadTaskSleepTime = 50;
+constexpr size_t kWriteTaskSleepTime = 100;
+
+}  // namespace
+
+class MRMWMutexTest : public ::testing::Test {
+ protected:
+  MRMWMutex mutex_;
+  std::mt19937 generator_;
+  void SetUp() override {
+#ifdef __linux__
+    if (absl::GetFlag(FLAGS_force_epoll)) {
+      pp_.reset(util::fb2::Pool::Epoll(2));
+    } else {
+      pp_.reset(util::fb2::Pool::IOUring(16, 2));
+    }
+#else
+    pp_.reset(util::fb2::Pool::Epoll(2));
+#endif
+    pp_->Run();
+  }
+  void TearDown() override {
+    pp_->Stop();
+    pp_.reset();
+  }
+  std::unique_ptr<util::ProactorPool> pp_;
+};
+
+// Test 1: Multiple readers can lock concurrently.
+TEST_F(MRMWMutexTest, MultipleReadersConcurrently) {
+  std::atomic<size_t> read_count(0);
+  const int num_readers = 5;
+
+  std::vector<util::fb2::Fiber> readers;
+  readers.reserve(num_readers);
+
+  for (int i = 0; i < num_readers; ++i) {
+    readers.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
+      ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
+    }));
+  }
+
+  // Wait for all reader fibers to finish.
+  for (auto& t : readers) {
+    t.Join();
+  }
+
+  // All readers should have been able to lock the mutex concurrently.
+  EXPECT_EQ(read_count.load(), 0);
+}
+
+// Test 2: Active readers block a writer; the writer gets the lock after they finish.
+TEST_F(MRMWMutexTest, ReadersBlockWriters) {
+  std::atomic<size_t> read_count(0);
+  std::atomic<size_t> write_count(0);
+
+  const int num_readers = 10;
+
+  // Start multiple readers.
+  std::vector<util::fb2::Fiber> readers;
+  readers.reserve(num_readers);
+
+  for (int i = 0; i < num_readers; ++i) {
+    readers.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
+      ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
+    }));
+  }
+
+  // Give the readers time to acquire the lock.
+  util::ThisFiber::SleepFor(std::chrono::milliseconds(10));
+
+  pp_->at(1)
+      ->LaunchFiber(util::fb2::Launch::post,
+                    [&] { WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime); })
+      .Join();
+
+  // Wait for all reader fibers to finish.
+  for (auto& t : readers) {
+    t.Join();
+  }
+
+  EXPECT_EQ(read_count.load(), 0);
+  EXPECT_EQ(write_count.load(), 0);
+}
+
+// Test 3: Unlock transitions correctly and wakes up waiting fibers.
+TEST_F(MRMWMutexTest, ReaderAfterWriter) {
+  std::atomic<size_t> write_count(0);
+  std::atomic<size_t> read_count(0);
+
+  // Start a writer fiber.
+  auto writer = pp_->at(1)->LaunchFiber(util::fb2::Launch::post, [&] {
+    WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime);
+  });
+
+  // Give the writer time to acquire the lock.
+  util::ThisFiber::SleepFor(std::chrono::milliseconds(10));
+
+  // Now start a reader task that will block until the writer is done.
+  pp_->at(0)
+      ->LaunchFiber(util::fb2::Launch::post,
+                    [&] { ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime); })
+      .Join();
+
+  // Ensure that the writer has completed.
+  writer.Join();
+
+  EXPECT_EQ(read_count.load(), 0);
+  EXPECT_EQ(write_count.load(), 0);
+}
+
+// Test 4: Ensure the writer gets the lock after the readers finish.
+TEST_F(MRMWMutexTest, WriterAfterReaders) {
+  std::atomic<size_t> read_count(0);
+  std::atomic<size_t> write_count(0);
+
+  // Start multiple readers.
+  const int num_readers = 10;
+  std::vector<util::fb2::Fiber> readers;
+  readers.reserve(num_readers);
+
+  for (int i = 0; i < num_readers; ++i) {
+    readers.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
+      ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
+    }));
+  }
+
+  // Wait for all readers to acquire and release the lock.
+  for (auto& t : readers) {
+    t.Join();
+  }
+
+  // Start the writer after all readers are done.
+  pp_->at(1)
+      ->LaunchFiber(util::fb2::Launch::post,
+                    [&] { WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime); })
+      .Join();
+
+  EXPECT_EQ(read_count.load(), 0);
+  EXPECT_EQ(write_count.load(), 0);
+}
+
+// Test 5: Mix of readers and writers.
+TEST_F(MRMWMutexTest, MixWritersReaders) {
+  std::atomic<size_t> read_count(0);
+  std::atomic<size_t> write_count(0);
+
+  // Start multiple readers and writers.
+  const int num_threads = 100;
+  std::vector<util::fb2::Fiber> threads;
+  threads.reserve(num_threads + 1);
+
+  // Add a long read task that will block all write tasks.
+  threads.emplace_back(
+      pp_->at(0)->LaunchFiber([&] { ReadTask(&mutex_, std::ref(read_count), 2000); }));
+
+  // Give the long reader time to acquire the lock.
+  util::ThisFiber::SleepFor(std::chrono::milliseconds(100));
+
+  size_t write_threads = 0;
+  for (int i = 0; i < num_threads; ++i) {
+    size_t fiber_id = rand() % 2;
+    if (rand() % 3) {
+      threads.emplace_back(pp_->at(fiber_id)->LaunchFiber(util::fb2::Launch::post, [&] {
+        ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
+      }));
+    } else {
+      write_threads++;
+      threads.emplace_back(pp_->at(fiber_id)->LaunchFiber(util::fb2::Launch::post, [&] {
+        WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime);
+      }));
+    }
+  }
+
+  // By now all short read tasks should be done; only the long reader remains,
+  // and the write tasks are still blocked waiting for it.
+  util::ThisFiber::SleepFor(std::chrono::milliseconds(500));
+
+  EXPECT_EQ(read_count.load(), 1);
+
+  EXPECT_EQ(write_count.load(), write_threads);
+
+  // Wait for all fibers to finish.
+  for (auto& t : threads) {
+    t.Join();
+  }
+}
+
+}  // namespace dfly::search
