2 changes: 2 additions & 0 deletions ucm/store/pcstore/cc/domain/trans/trans_manager.cc
@@ -94,6 +94,8 @@ Status TransManager::Wait(const size_t taskId) noexcept
         UC_ERROR("Task({}) timeout({}).", task->Str(), timeoutMs_);
         failureSet_.Insert(taskId);
         waiter->Wait();
+        failureSet_.Remove(taskId);
+        return Status::Timeout();
     }
     auto failure = failureSet_.Contains(taskId);
     if (failure) {
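The fix above matters because the timeout path previously left the task id in failureSet_ forever: a later task that recycled the id would spuriously read as failed at the Contains() check, and the set itself would grow without bound. With the added early return, the timeout path now also reports Timeout directly instead of falling through. The concrete type behind failureSet_ is not visible in this diff; the sketch below is only a minimal model of the Insert/Remove/Contains interface the call sites imply.

// Hypothetical stand-in for the failureSet_ type used in TransManager;
// the real class is not part of this diff.
#include <cstddef>
#include <mutex>
#include <unordered_set>

class FailureSet {
public:
    void Insert(size_t id)
    {
        std::lock_guard lock{mtx_};
        ids_.insert(id);
    }
    void Remove(size_t id)
    {
        std::lock_guard lock{mtx_};
        ids_.erase(id);
    }
    bool Contains(size_t id)
    {
        std::lock_guard lock{mtx_};
        return ids_.count(id) != 0;
    }

private:
    std::mutex mtx_;                 // guards ids_ across caller and worker threads
    std::unordered_set<size_t> ids_; // ids of tasks currently marked failed
};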
61 changes: 55 additions & 6 deletions ucm/store/pcstore/cpy/pcstore.py.cc
@@ -24,25 +24,61 @@
 #include "pcstore.h"
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
+#include "status/status.h"
+#include "thread/latch.h"
+#include "thread/thread_pool.h"

 namespace py = pybind11;

 namespace UC {

 class PcStorePy : public PcStore {
+    using LookupCtx = std::pair<std::string, uint8_t*>;
+    ThreadPool<LookupCtx> lookupService_;
+    Latch lookupWaiter_;
+    std::atomic<int32_t> lookupStatus_;
+    std::mutex lookupMtx_;
+
 public:
     void* CCStoreImpl() { return this; }
+    int32_t SetupPy(const Config& config)
+    {
+        auto ret = Setup(config);
+        if (config.transferEnable || ret != Status::OK().Underlying()) { return ret; }
+        auto success =
+            lookupService_.SetNWorker(4)
+                .SetWorkerFn([this](auto& pair, auto) { OnLookup(pair.first, pair.second); })
+                .SetWorkerTimeoutFn([this](auto&, auto) { OnLookupTimeouted(); }, 10000)
+                .Run();
+        if (!success) {
+            UC_ERROR("Failed to start lookup service.");
+            return Status::Error().Underlying();
+        }
+        return Status::OK().Underlying();
+    }
     py::list AllocBatch(const py::list& blocks)
     {
         py::list results;
         for (auto& block : blocks) { results.append(this->Alloc(block.cast<std::string>())); }
         return results;
     }
-    py::list LookupBatch(const py::list& blocks)
+    std::vector<uint8_t> LookupBatch(const py::list& blocks)
     {
-        py::list founds;
-        for (auto& block : blocks) { founds.append(this->Lookup(block.cast<std::string>())); }
-        return founds;
+        std::lock_guard lock{lookupMtx_};
+        const auto number = blocks.size();
+        const auto ok = Status::OK().Underlying();
+        lookupStatus_ = ok;
+        lookupWaiter_.Set(number);
+        std::vector<uint8_t> founds(number);
+        size_t idx = 0;
+        for (auto& block : blocks) {
+            lookupService_.Push({block.cast<std::string>(), founds.data() + idx});
+            idx++;
+        }
+        lookupWaiter_.Wait();
+        const auto ret = lookupStatus_.load();
+        if (ret == ok) { return founds; }
+        throw std::runtime_error{fmt::format("error({}) when performing LookupBatch", ret)};
+    }
     void CommitBatch(const py::list& blocks, const bool success)
     {
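For orientation between the two hunks: SetupPy wires the pool fluently with four workers, a per-item callback, and a 10-second per-item timeout hook, and LookupBatch feeds it one (block id, result slot) pair per block. The real ThreadPool<Ctx> lives in thread/thread_pool.h and is not shown here; the sketch below is a stripped-down model of that builder contract (the timeout hook is omitted) under those assumptions, not the actual implementation.

// Simplified model of a builder-style worker pool; Ctx is the per-item
// payload, e.g. std::pair<std::string, uint8_t*> as in LookupCtx above.
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

template <typename Ctx>
class ThreadPool {
public:
    ThreadPool& SetNWorker(size_t n)
    {
        nWorker_ = n;
        return *this;
    }
    ThreadPool& SetWorkerFn(std::function<void(Ctx&, size_t)> fn)
    {
        fn_ = std::move(fn);
        return *this;
    }
    bool Run() // spawn the workers; false if misconfigured
    {
        if (!fn_ || nWorker_ == 0) { return false; }
        for (size_t i = 0; i < nWorker_; i++) { workers_.emplace_back([this, i] { Loop(i); }); }
        return true;
    }
    void Push(Ctx ctx) // producer side: enqueue one item
    {
        {
            std::lock_guard lock{mtx_};
            queue_.push_back(std::move(ctx));
        }
        cv_.notify_one();
    }
    ~ThreadPool()
    {
        {
            std::lock_guard lock{mtx_};
            stop_ = true;
        }
        cv_.notify_all();
        for (auto& worker : workers_) { worker.join(); }
    }

private:
    void Loop(size_t id) // each worker drains the shared queue
    {
        for (;;) {
            std::unique_lock lock{mtx_};
            cv_.wait(lock, [this] { return stop_ || !queue_.empty(); });
            if (queue_.empty()) { return; } // only reachable when stopping
            auto ctx = std::move(queue_.front());
            queue_.pop_front();
            lock.unlock();
            fn_(ctx, id);
        }
    }
    size_t nWorker_{0};
    std::function<void(Ctx&, size_t)> fn_;
    std::vector<std::thread> workers_;
    std::deque<Ctx> queue_;
    std::mutex mtx_;
    std::condition_variable cv_;
    bool stop_{false};
};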
@@ -77,9 +113,22 @@ class PcStorePy : public PcStore {
         }
         return this->Submit(std::move(task));
     }
+    void OnLookup(const std::string& block, uint8_t* found)
+    {
+        const auto ok = Status::OK().Underlying();
+        if (lookupStatus_ == ok) { *found = Lookup(block); }
+        lookupWaiter_.Done();
+    }
+    void OnLookupTimeouted()
+    {
+        auto ok = Status::OK().Underlying();
+        auto timeout = Status::Timeout().Underlying();
+        lookupStatus_.compare_exchange_strong(ok, timeout, std::memory_order_acq_rel);
+        lookupWaiter_.Done();
+    }
 };

-} // namespace UC
+} // namespace UC

 PYBIND11_MODULE(ucmpcstore, module)
 {
@@ -106,7 +155,7 @@ PYBIND11_MODULE(ucmpcstore, module)
                   &UC::PcStorePy::Config::transferScatterGatherEnable);
     store.def(py::init<>());
     store.def("CCStoreImpl", &UC::PcStorePy::CCStoreImpl);
-    store.def("Setup", &UC::PcStorePy::Setup);
+    store.def("Setup", &UC::PcStorePy::SetupPy);
     store.def("Alloc", py::overload_cast<const std::string&>(&UC::PcStorePy::Alloc));
     store.def("AllocBatch", &UC::PcStorePy::AllocBatch);
     store.def("Lookup", py::overload_cast<const std::string&>(&UC::PcStorePy::Lookup));
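The batched lookup above is a scatter/gather: LookupBatch arms lookupWaiter_ with the batch size, each OnLookup writes one byte into the caller's vector and calls Done(), and the waiter releases once every callback (success or timeout) has reported in. lookupStatus_ makes at most one transition, OK to Timeout, which is why a strong compare-exchange fits the one-shot CAS in OnLookupTimeouted (the weak form may fail spuriously even when the expected value matches). Latch here is the project's thread/latch.h type, not std::latch, which cannot be re-armed; the sketch below is a guess at its interface from the Set/Done/Wait call sites, not the real header.

// Hypothetical re-armable countdown latch matching the call sites above.
#include <condition_variable>
#include <cstddef>
#include <mutex>

class Latch {
public:
    void Set(size_t n) // arm with n pending completions
    {
        std::lock_guard lock{mtx_};
        count_ = n;
    }
    void Done() // one completion; wake the waiter when the count hits zero
    {
        std::lock_guard lock{mtx_};
        if (count_ > 0 && --count_ == 0) { cv_.notify_all(); }
    }
    void Wait() // block until all pending completions arrive
    {
        std::unique_lock lock{mtx_};
        cv_.wait(lock, [this] { return count_ == 0; });
    }

private:
    std::mutex mtx_;
    std::condition_variable cv_;
    size_t count_{0};
};

The lookupMtx_ guard in LookupBatch is what makes this safe: only one batch may own the latch and the status word at a time, so concurrent callers serialize rather than corrupt each other's counters.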
32 changes: 20 additions & 12 deletions ucm/store/test/e2e/pcstore_embed_v1.py
@@ -37,7 +37,7 @@ def setup_store(storage_backends, block_size, device_id, io_size) -> UcmKVStoreBaseV1:
     config = {}
     config["storage_backends"] = storage_backends
     config["kv_block_size"] = block_size
-    config["role"] = "worker"
+    config["role"] = "worker" if device_id != -1 else "scheduler"
     config["device"] = device_id
     config["io_size"] = io_size
     config["unique_id"] = secrets.token_hex(8)
@@ -63,22 +63,29 @@ def make_buffers(


 def embed(
-    store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]]
+    worker: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]]
 ):
-    task = store.dump(hashes, [], tensors)
+    task = worker.dump(hashes, [], tensors)
     assert task.task_id > 0
-    store.wait(task)
+    worker.wait(task)


 def fetch(
-    store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]]
+    scheduler: UcmKVStoreBaseV1,
+    worker: UcmKVStoreBaseV1,
+    hashes: List[bytes],
+    tensors: List[List[torch.Tensor]],
 ):
-    founds = store.lookup(hashes)
+    number = len(hashes)
+    tp = time.perf_counter()
+    founds = scheduler.lookup(hashes)
+    cost = time.perf_counter() - tp
+    print(f"Lookup {number} blocks cost {cost * 1e3:.03f}ms.")
     for found in founds:
         assert found
-    task = store.load(hashes, [], tensors)
+    task = worker.load(hashes, [], tensors)
     assert task.task_id > 0
-    store.wait(task)
+    worker.wait(task)


 def cmp_and_print_diff(a, b, rtol=0.0, atol=0.0):
@@ -104,8 +111,9 @@ def main():
     block_layer = 61
     io_size = block_dim * block_len * block_elem_size
     block_size = io_size * block_layer
-    batch_size = 64
-    store = setup_store(storage_backends, block_size, device_id, io_size)
+    batch_size = 256
+    worker = setup_store(storage_backends, block_size, device_id, io_size)
+    scheduler = setup_store(storage_backends, block_size, -1, io_size)
     hashes, tensors = make_buffers(
         block_number, device_id, batch_size, block_dim, block_len, block_layer
     )
@@ -114,9 +122,9 @@
         start = batch_size * batch
         end = min(start + batch_size, block_number)
         tensors2 = [[torch.empty_like(t) for t in row] for row in tensors]
-        embed(store, hashes[start:end], tensors)
+        embed(worker, hashes[start:end], tensors)
         time.sleep(1)
-        fetch(store, hashes[start:end], tensors2)
+        fetch(scheduler, worker, hashes[start:end], tensors2)
         cmp_and_print_diff(tensors, tensors2)