diff --git a/ucm/store/pcstore/cc/domain/trans/trans_manager.cc b/ucm/store/pcstore/cc/domain/trans/trans_manager.cc index aeb30543..63ee887b 100644 --- a/ucm/store/pcstore/cc/domain/trans/trans_manager.cc +++ b/ucm/store/pcstore/cc/domain/trans/trans_manager.cc @@ -94,6 +94,8 @@ Status TransManager::Wait(const size_t taskId) noexcept UC_ERROR("Task({}) timeout({}).", task->Str(), timeoutMs_); failureSet_.Insert(taskId); waiter->Wait(); + failureSet_.Remove(taskId); + return Status::Timeout(); } auto failure = failureSet_.Contains(taskId); if (failure) { diff --git a/ucm/store/pcstore/cpy/pcstore.py.cc b/ucm/store/pcstore/cpy/pcstore.py.cc index cadb1ebe..834f08a7 100644 --- a/ucm/store/pcstore/cpy/pcstore.py.cc +++ b/ucm/store/pcstore/cpy/pcstore.py.cc @@ -24,25 +24,61 @@ #include "pcstore.h" #include #include +#include "status/status.h" +#include "thread/latch.h" +#include "thread/thread_pool.h" namespace py = pybind11; namespace UC { class PcStorePy : public PcStore { + using LookupCtx = std::pair; + ThreadPool lookupService_; + Latch lookupWaiter_; + std::atomic lookupStatus_; + std::mutex lookupMtx_; + public: void* CCStoreImpl() { return this; } + int32_t SetupPy(const Config& config) + { + auto ret = Setup(config); + if (config.transferEnable || ret != Status::OK().Underlying()) { return ret; } + auto success = + lookupService_.SetNWorker(4) + .SetWorkerFn([this](auto& pair, auto) { OnLookup(pair.first, pair.second); }) + .SetWorkerTimeoutFn([this](auto&, auto) { OnLookupTimeouted(); }, 10000) + .Run(); + if (!success) { + UC_ERROR("Failed to start lookup service."); + return Status::Error().Underlying(); + } + return Status::OK().Underlying(); + } py::list AllocBatch(const py::list& blocks) { py::list results; for (auto& block : blocks) { results.append(this->Alloc(block.cast())); } return results; } - py::list LookupBatch(const py::list& blocks) + std::vector LookupBatch(const py::list& blocks) { - py::list founds; - for (auto& block : blocks) { founds.append(this->Lookup(block.cast())); } - return founds; + std::lock_guard lock{lookupMtx_}; + const auto number = blocks.size(); + const auto ok = Status::OK().Underlying(); + lookupStatus_ = ok; + lookupWaiter_.Set(number); + std::vector founds(number); + size_t idx = 0; + for (auto& block : blocks) { + lookupService_.Push({block.cast(), founds.data() + idx}); + idx++; + } + lookupWaiter_.Wait(); + const auto ret = lookupStatus_.load(); + if (ret == ok) { return founds; } + throw std::runtime_error{fmt::format("error({}) when performing LookupBatch", ret)}; } void CommitBatch(const py::list& blocks, const bool success) { @@ -77,9 +113,22 @@ class PcStorePy : public PcStore { } return this->Submit(std::move(task)); } + void OnLookup(const std::string& block, uint8_t* found) + { + const auto ok = Status::OK().Underlying(); + if (lookupStatus_ == ok) { *found = Lookup(block); } + lookupWaiter_.Done(); + } + void OnLookupTimeouted() + { + auto ok = Status::OK().Underlying(); + auto timeout = Status::Timeout().Underlying(); + lookupStatus_.compare_exchange_weak(ok, timeout, std::memory_order_acq_rel); + lookupWaiter_.Done(); + } }; -} // namespace UC +} // namespace UC PYBIND11_MODULE(ucmpcstore, module) { @@ -106,7 +155,7 @@ PYBIND11_MODULE(ucmpcstore, module) &UC::PcStorePy::Config::transferScatterGatherEnable); store.def(py::init<>()); store.def("CCStoreImpl", &UC::PcStorePy::CCStoreImpl); - store.def("Setup", &UC::PcStorePy::Setup); + store.def("Setup", &UC::PcStorePy::SetupPy); store.def("Alloc", py::overload_cast(&UC::PcStorePy::Alloc)); store.def("AllocBatch", &UC::PcStorePy::AllocBatch); store.def("Lookup", py::overload_cast(&UC::PcStorePy::Lookup)); diff --git a/ucm/store/test/e2e/pcstore_embed_v1.py b/ucm/store/test/e2e/pcstore_embed_v1.py index 68e5ffdc..edd4512d 100644 --- a/ucm/store/test/e2e/pcstore_embed_v1.py +++ b/ucm/store/test/e2e/pcstore_embed_v1.py @@ -37,7 +37,7 @@ def setup_store(storage_backends, block_size, device_id, io_size) -> UcmKVStoreB config = {} config["storage_backends"] = storage_backends config["kv_block_size"] = block_size - config["role"] = "worker" + config["role"] = "worker" if device_id != -1 else "scheduler" config["device"] = device_id config["io_size"] = io_size config["unique_id"] = secrets.token_hex(8) @@ -63,22 +63,29 @@ def make_buffers( def embed( - store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]] + worker: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]] ): - task = store.dump(hashes, [], tensors) + task = worker.dump(hashes, [], tensors) assert task.task_id > 0 - store.wait(task) + worker.wait(task) def fetch( - store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]] + scheduler: UcmKVStoreBaseV1, + worker: UcmKVStoreBaseV1, + hashes: List[bytes], + tensors: List[List[torch.Tensor]], ): - founds = store.lookup(hashes) + number = len(hashes) + tp = time.perf_counter() + founds = scheduler.lookup(hashes) + cost = time.perf_counter() - tp + print(f"Lookup {number} blocks cost {cost * 1e3:.03f}ms.") for found in founds: assert found - task = store.load(hashes, [], tensors) + task = worker.load(hashes, [], tensors) assert task.task_id > 0 - store.wait(task) + worker.wait(task) def cmp_and_print_diff(a, b, rtol=0.0, atol=0.0): @@ -104,8 +111,9 @@ def main(): block_layer = 61 io_size = block_dim * block_len * block_elem_size block_size = io_size * block_layer - batch_size = 64 - store = setup_store(storage_backends, block_size, device_id, io_size) + batch_size = 256 + worker = setup_store(storage_backends, block_size, device_id, io_size) + scheduler = setup_store(storage_backends, block_size, -1, io_size) hashes, tensors = make_buffers( block_number, device_id, batch_size, block_dim, block_len, block_layer ) @@ -114,9 +122,9 @@ def main(): start = batch_size * batch end = min(start + batch_size, block_number) tensors2 = [[torch.empty_like(t) for t in row] for row in tensors] - embed(store, hashes[start:end], tensors) + embed(worker, hashes[start:end], tensors) time.sleep(1) - fetch(store, hashes[start:end], tensors2) + fetch(scheduler, worker, hashes[start:end], tensors2) cmp_and_print_diff(tensors, tensors2)