From aca17e31aa04ba5d0036784a94531ddebc55bdb1 Mon Sep 17 00:00:00 2001
From: zhou-haitao <1300182097@qq.com>
Date: Tue, 16 Dec 2025 14:41:33 +0800
Subject: [PATCH] Adapt the test tools to the latest store connector API

---
 test/test_ucm_connector_save_load.py       |  14 ++-
 ucm/store/test/e2e/nfsstore_embed_fetch.py | 100 ++++++++++-----------
 2 files changed, 61 insertions(+), 53 deletions(-)

diff --git a/test/test_ucm_connector_save_load.py b/test/test_ucm_connector_save_load.py
index c0def663f..99717b2d5 100644
--- a/test/test_ucm_connector_save_load.py
+++ b/test/test_ucm_connector_save_load.py
@@ -254,7 +254,13 @@ def run_once(
         load_block_ids=([], []),
         dump_block_ids=(dump_hashes, dump_vllm_block_ids),
     )
-    connector.connector.kv_caches = kv_caches
+
+    if (
+        not hasattr(connector.connector, "k_store")
+        or connector.connector.k_store is None
+    ):
+        connector.connector.register_kv_caches(kv_caches)
+
     connector.bind_connector_metadata(metadata)
 
     total_bytes = compute_total_bytes(kv_caches, batch_size, is_mla)
@@ -267,7 +273,7 @@ def run_once(
 
     write_bw = (total_bytes / (1024**3)) / write_time if write_time > 0 else 0.0
 
-    lookup = connector.connector.store.lookup(dump_hashes)
+    lookup = connector.connector.k_store.lookup(dump_hashes)
     if not all(lookup):
         raise RuntimeError("Found missing cache blocks before load test.")
 
@@ -277,7 +283,7 @@ def run_once(
         load_block_ids=(dump_hashes, load_vllm_block_ids),
         dump_block_ids=([], []),
     )
-    connector.connector.kv_caches = kv_caches
+
     connector.bind_connector_metadata(load_metadata)
 
     forward_context = build_forward_context(kv_caches, is_mla)
@@ -375,6 +381,8 @@ def broadcast(self, tensor, src):
         mla,
     )
 
+    connector.connector.register_kv_caches(kv_caches)
+
     w_sizes, w_times, w_bws = [], [], []
     r_sizes, r_times, r_bws = [], [], []
 
diff --git a/ucm/store/test/e2e/nfsstore_embed_fetch.py b/ucm/store/test/e2e/nfsstore_embed_fetch.py
index 1132afa50..bd2663271 100644
--- a/ucm/store/test/e2e/nfsstore_embed_fetch.py
+++ b/ucm/store/test/e2e/nfsstore_embed_fetch.py
@@ -32,7 +32,9 @@
 import torch
 
 from ucm.store.nfsstore.nfsstore_connector import UcmNfsStore
+from ucm.store.pcstore.pcstore_connector_v1 import UcmPcStoreV1
 from ucm.store.ucmstore import UcmKVStoreBase
+from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1
 
 
 def setup(
@@ -42,7 +44,7 @@ def setup(
     io_size,
     transferStreamNumber,
     transferIoDirect,
-) -> UcmKVStoreBase:
+) -> UcmKVStoreBaseV1:
     config = {
         "storage_backends": storage_backends,
         "kv_block_size": block_size,
@@ -51,8 +53,9 @@ def setup(
         "io_size": io_size,
         "transferStreamNumber": transferStreamNumber,
         "transferIoDirect": transferIoDirect,
+        "unique_id": secrets.token_hex(8),
     }
-    return UcmNfsStore(config)
+    return UcmPcStoreV1(config)
 
 
 def make_aligned_tensor(shape, dtype, device, alignment=4096):
@@ -79,64 +82,59 @@ def make_buffers(
     block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head, kv
 ):
-    hashes = [secrets.token_hex(16) for _ in range(block_number)]
-    kv_caches = {}
-    for i in range(block_layer):
-        kv_caches[i] = make_aligned_tensor(
+    hashes = [secrets.token_bytes(16) for _ in range(block_number)]
+    kvcaches = {}
+    for layer_id in range(block_layer):
+        kvcaches[layer_id] = make_aligned_tensor(
             [kv, block_number, block_len, num_head, head_dim],
-            dtype=torch.float16,
+            dtype=torch.bfloat16,
             device=f"cuda:{device_id}",
         )
-    return hashes, kv_caches
+        kvcaches[layer_id].random_()
+    return hashes, kvcaches
 
 
-def store_all_hashes(hashes: List[str]):
+def store_all_hashes(hashes: List[bytes]):
     file_path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt")
     with open(file_path, "w", encoding="utf-8") as f:
         for h in hashes:
-            f.write(h + "\n")
+            f.write(h.hex() + "\n")
 
 
-def load_hashes_from_file() -> List[str]:
+def load_hashes_from_file() -> List[bytes]:
     file_path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt")
     if not os.path.exists(file_path):
         return []
     with open(file_path, "r", encoding="utf-8") as f:
-        return [line.strip() for line in f.readlines()]
+        return [bytes.fromhex(line.strip()) for line in f.readlines()]
 
 
 def embed(
-    store: UcmKVStoreBase,
-    hashes: List[str],
+    store: UcmKVStoreBaseV1,
+    hashes: List[bytes],
     kvcaches: Dict[int, torch.Tensor],
     mla: bool,
 ):
     start_time = time.perf_counter()
 
-    total_block_ids, total_offsets, total_tensors = [], [], []
+    total_tensors = []
     total_size = 0
     for i, hash_val in enumerate(hashes):
-        offset = 0
+        tensors = []
         for layer_id, kv_layer in kvcaches.items():
-            k_tensor = kv_layer[0][i]  # kv=1
-            total_tensors.append(k_tensor)
-            total_block_ids.append(hash_val)
-            total_offsets.append(offset)
+            k_tensor = kv_layer[0][i].contiguous()
+            tensors.append(k_tensor)
             sz = k_tensor.numel() * k_tensor.element_size()
-            offset += sz
             total_size += sz
             if not mla:
-                v_tensor = kv_layer[1][i]
-                total_tensors.append(v_tensor)
-                total_block_ids.append(hash_val)
-                total_offsets.append(offset)
+                v_tensor = kv_layer[1][i].contiguous()
+                tensors.append(v_tensor)
                 sz = v_tensor.numel() * v_tensor.element_size()
-                offset += sz
                 total_size += sz
-
-    task = store.dump(total_block_ids, total_offsets, total_tensors)
+        total_tensors.append(tensors)
+    task = store.dump(hashes, [], total_tensors)
     store.wait(task)
     elapsed_time = time.perf_counter() - start_time
@@ -151,8 +149,8 @@ def embed(
 
 
 def fetch(
-    store: UcmKVStoreBase,
-    hashes: List[str],
+    store: UcmKVStoreBaseV1,
+    hashes: List[bytes],
     kvcaches: Dict[int, torch.Tensor],
     mla: bool,
 ):
@@ -162,32 +160,33 @@ def fetch(
     start_time = time.perf_counter()
 
     founds = store.lookup(hashes)
     for f in founds:
         assert f, "Cache block miss detected"
 
-    block_ids, offsets, tensors = [], [], []
+    total_tensors = []
     total_size = 0
     for i, hash_val in enumerate(hashes):
-        offset = 0
+        tensors = []
         for layer_id, kv_layer in kvcaches.items():
-            k_tensor = kv_layer[0][i]  # kv=1
-            block_ids.append(hash_val)
-            offsets.append(offset)
+            k_tensor = kv_layer[0][i].contiguous()
             tensors.append(k_tensor)
             sz = k_tensor.numel() * k_tensor.element_size()
-            offset += sz
             total_size += sz
             if not mla:
-                v_tensor = kv_layer[1][i]
-                block_ids.append(hash_val)
-                offsets.append(offset)
+                v_tensor = kv_layer[1][i].contiguous()
                 tensors.append(v_tensor)
                 sz = v_tensor.numel() * v_tensor.element_size()
-                offset += sz
                 total_size += sz
+        total_tensors.append(tensors)
 
-    task = store.load(block_ids, offsets, tensors)
-    ret = store.wait(task)
-    assert ret == 0, "Load operation failed"
+    task = store.load(hashes, [], total_tensors)
+    try:
+        ret = store.wait(task)
+        if ret is None:
+            ret = 0
+    except RuntimeError as e:
+        print(f"Load operation failed with error: {e}")
+        raise
+    assert ret == 0, f"Load operation failed with return code: {ret}"
 
     elapsed_time = time.perf_counter() - start_time
     throughput_gbps = (total_size / (1024**3)) / elapsed_time if elapsed_time > 0 else 0
@@ -226,6 +225,10 @@ def run(
     block_dim = head_size * num_head
     io_size = block_dim * block_len * block_elem_size
     block_size = io_size * block_layer
+
+    if not mla:
+        block_size = block_size * 2
+
     batch_size = int(num_tokens / block_len)
     real_blocks = batch_size + 10
 
@@ -257,16 +260,13 @@ def run(
         kv,
     )
 
-    results = store.create(hashes[:batch_size])
-    assert sum(results) == 0, "Create operation failed"
-
     w_size, w_time, w_bw = embed(
         store,
         hashes[:batch_size],
         kvcaches,
         mla,
     )
-    store.commit(hashes[:batch_size], True)
+    time.sleep(1)
 
     if r == 0:
         store_all_hashes(hashes[:batch_size])
@@ -349,10 +349,10 @@ def run(
     try:
         result = run(
             storage_backends=".",
-            device_id=1,
-            repeat=1,
+            device_id=6,
+            repeat=2,
             num_head=1,
-            block_len=128,
+            block_len=64,
             transferStreamNumber=32,
             num_tokens=4096,
             block_layer=61,
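
Note for reviewers: the V1 convention this patch migrates the tools to is easy to miss in the diff noise. The old API took flat, per-tensor (block_ids, offsets, tensors) triples; the V1 API takes raw-bytes block ids, an empty offsets list, and one nested tensor list per block. Below is a minimal sketch of that flow. It uses only the UcmPcStoreV1 constructor, the config keys, and the lookup/dump/load/wait calls that appear in the diff above; the geometry, dtype, CPU tensors, and the transferIoDirect value are illustrative assumptions (the test itself uses 4096-byte-aligned CUDA buffers, and setup() may pass further config keys that fall outside the hunks shown).

import secrets

import torch

from ucm.store.pcstore.pcstore_connector_v1 import UcmPcStoreV1

# Illustrative geometry, not the test's defaults.
num_blocks, block_layer, block_len, num_head, head_dim, kv = 4, 2, 64, 1, 128, 2
elem_size = torch.finfo(torch.bfloat16).bits // 8

io_size = head_dim * num_head * block_len * elem_size
block_size = io_size * block_layer * kv  # x2 for K and V, as run() does when not MLA

store = UcmPcStoreV1(
    {
        "storage_backends": ".",
        "kv_block_size": block_size,
        "io_size": io_size,
        "transferStreamNumber": 32,
        "transferIoDirect": False,  # assumed value
        "unique_id": secrets.token_hex(8),  # fresh per-run id, as in setup()
    }
)

# V1 block ids are raw bytes, and each block carries its own tensor list.
hashes = [secrets.token_bytes(16) for _ in range(num_blocks)]
blocks = [
    [
        torch.zeros(block_len, num_head, head_dim, dtype=torch.bfloat16)
        for _ in range(block_layer * kv)  # one K and one V slice per layer
    ]
    for _ in range(num_blocks)
]

task = store.dump(hashes, [], blocks)  # the offsets argument stays empty in V1
store.wait(task)

assert all(store.lookup(hashes)), "blocks should be resident after dump"

task = store.load(hashes, [], blocks)
ret = store.wait(task)
assert ret in (0, None), "wait() returns 0 (or None) on success, as fetch() now handles"

Note also that create()/commit() disappear from the flow: the updated run() dumps straight after make_buffers() and only sleeps briefly before reusing the hashes, so V1 stores are expected to make blocks visible to lookup() without an explicit commit.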