From 3a7cdaa6b4ae6a14288c5dec988a3af2eb936697 Mon Sep 17 00:00:00 2001
From: qyh111
Date: Thu, 18 Dec 2025 01:03:35 -0800
Subject: [PATCH] add chunk size

---
 examples/ucm_config_example.yaml      |  2 +
 ucm/integration/vllm/ucm_connector.py | 73 ++++++++++++++++++---------
 2 files changed, 51 insertions(+), 24 deletions(-)

diff --git a/examples/ucm_config_example.yaml b/examples/ucm_config_example.yaml
index 632a9a38..0d2a1083 100644
--- a/examples/ucm_config_example.yaml
+++ b/examples/ucm_config_example.yaml
@@ -16,6 +16,8 @@ ucm_connectors:
     load_only_first_rank: false

+chunk_size: 256
+
 # Enable UCM metrics so they can be monitored online via Grafana and Prometheus.
 # metrics_config_path: "/workspace/unified-cache-management/examples/metrics/metrics_configs.yaml"

diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
index faf4051f..c661509e 100644
--- a/ucm/integration/vllm/ucm_connector.py
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -130,7 +130,17 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.group_coordinator = get_tp_group()
         self.broadcast_fn = self.group_coordinator.broadcast
         self.broadcast_stream = torch.cuda.Stream()
-        self.chunk_size = 1
+
+        self.chunk_size = self.launch_config.get("chunk_size", self.block_size)
+        if self.chunk_size % self.block_size != 0:
+            raise ValueError(
+                f"chunk_size ({self.chunk_size}) must be a multiple of "
+                f"block_size ({self.block_size})"
+            )
+        self.blocks_per_chunk = self.chunk_size // self.block_size
+        logger.info(
+            f"chunk_size = {self.chunk_size}, blocks_per_chunk = {self.blocks_per_chunk}"
+        )

         if role == KVConnectorRole.SCHEDULER:
             self.request_hasher = RequestHasher(vllm_config, 0)
@@ -258,7 +268,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         chunk_block_size = (
             tensor_size
             * self.num_layers
-            * self.chunk_size
+            * self.blocks_per_chunk
             * (1 if self.is_mla or self.is_dsa else 2)
         )
         self.block_data_size = chunk_block_size
@@ -267,7 +277,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         rope_tensor_size = (
             sample_kv_layer[1][0].numel() * sample_kv_layer[1][0].element_size()
         )
-        rope_chunk_block_size = rope_tensor_size * self.num_layers * self.chunk_size
+        rope_chunk_block_size = (
+            rope_tensor_size * self.num_layers * self.blocks_per_chunk
+        )
         self.rope_store = self._create_store(
             rope_tensor_size, rope_chunk_block_size, True
         )
@@ -279,9 +291,9 @@ def get_num_new_matched_tokens(
         num_computed_tokens: int,
     ) -> tuple[int, bool]:
         assert num_computed_tokens % self.block_size == 0
-        hbm_hit_block_num = num_computed_tokens // self.block_size
+        hbm_hit_block_num = num_computed_tokens // self.chunk_size

-        ucm_block_ids = self.generate_hash(self.block_size, request)
+        ucm_block_ids = self.generate_hash(self.chunk_size, request)
         external_block_ids = ucm_block_ids[hbm_hit_block_num:]

         if not external_block_ids:
@@ -307,12 +319,12 @@ def get_num_new_matched_tokens(

         total_hit_block_num = hbm_hit_block_num + external_hit_blocks

-        external_hit_tokens = external_hit_blocks * self.block_size
+        external_hit_tokens = external_hit_blocks * self.chunk_size

         # When all the tokens are cached in ssd or hbm,
         # we need to recompute the last token. This if condition will be removed
         # once vLLM scheduler provides a better solution in the future.
-        num_total_hit_tokens = total_hit_block_num * self.block_size
+        num_total_hit_tokens = total_hit_block_num * self.chunk_size
         if num_total_hit_tokens == request.num_tokens:
             external_hit_tokens -= 1

@@ -359,13 +371,19 @@ def _generate_dispatch_meta(
         dump_ucm_block_ids, dump_vllm_block_ids = [], []
         if need_load:
             load_ucm_block_ids = ucm_block_ids[hbm_hit_block_num:total_hit_block_num]
-            load_vllm_block_ids = vllm_block_ids[hbm_hit_block_num:total_hit_block_num]
+            load_vllm_block_ids = vllm_block_ids[
+                hbm_hit_block_num
+                * self.blocks_per_chunk : total_hit_block_num
+                * self.blocks_per_chunk
+            ]

         if req_meta.token_processed < req_meta.num_token_ids:
-            start_idx = req_meta.token_processed // self.block_size
-            end_idx = (req_meta.token_processed + new_tokens) // self.block_size
+            start_idx = req_meta.token_processed // self.chunk_size
+            end_idx = (req_meta.token_processed + new_tokens) // self.chunk_size
             dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx]
-            dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx]
+            dump_vllm_block_ids = req_meta.vllm_block_ids[
+                start_idx * self.blocks_per_chunk : end_idx * self.blocks_per_chunk
+            ]
             req_meta.token_processed += new_tokens

         return RequestDispatchMeta(
@@ -456,18 +474,25 @@ def _generate_task(
     ) -> Tuple[
         List[bytes], List[int], List[List[torch.Tensor]], List[List[torch.Tensor]]
     ]:
-        block_ids, shard_indexs, total_tensors, rope_tensors = [], [], [], []
-        for i, vllm_block_id in enumerate(vllm_block_ids):
-            k_tensors, v_tensors = self._get_tensors(vllm_block_id)
-            block_ids.append(ucm_block_ids[i])
-            if self.is_dsa:
-                total_tensors.append(k_tensors)
-                rope_tensors.append(v_tensors)
-            else:
-                total_tensors.append(k_tensors + v_tensors)
+        block_ids, shard_indexs, total_tensors, total_rope_tensors = [], [], [], []
+        for ucm_block_id, i in zip(
+            ucm_block_ids, range(0, len(vllm_block_ids), self.blocks_per_chunk)
+        ):
+            block_ids.append(ucm_block_id)
             shard_indexs.append(0)
-
-        return block_ids, shard_indexs, total_tensors, rope_tensors
+            tensors, rope_tensors = [], []
+            for j in range(self.blocks_per_chunk):
+                vllm_block_id = vllm_block_ids[i + j]
+                k_tensors, v_tensors = self._get_tensors(vllm_block_id)
+                if self.is_dsa:
+                    tensors.extend(k_tensors)
+                    rope_tensors.extend(v_tensors)
+                else:
+                    tensors.extend(k_tensors + v_tensors)
+            total_tensors.append(tensors)
+            total_rope_tensors.append(rope_tensors)
+
+        return block_ids, shard_indexs, total_tensors, total_rope_tensors

     def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
         rec_tensor: torch.Tensor = None
@@ -737,7 +762,7 @@ def get_num_new_matched_tokens(
         expect_hit_tokens = int(self._hit_ratio * request.num_prompt_tokens)
         if hit_tokens <= expect_hit_tokens:
             return hit_tokens, False
-        expect_hit_block_num = expect_hit_tokens // self.block_size
+        expect_hit_block_num = expect_hit_tokens // self.chunk_size
         request_meta = self.requests_meta[request.request_id]
         request_meta.total_hit_block_num = expect_hit_block_num
         request_meta.hbm_hit_block_num = min(
@@ -752,7 +777,7 @@ def get_num_new_matched_tokens(
             f"hit external: {request_meta.total_hit_block_num - request_meta.hbm_hit_block_num}"
         )

-        return expect_hit_block_num * self.block_size, False
+        return expect_hit_block_num * self.chunk_size, False


 class UCMConnector(KVConnectorBase_V1):
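Note on the design: with this patch, UCM hashes and tracks cache hits at chunk granularity (one chunk = blocks_per_chunk vLLM blocks), while vllm_block_ids stay per-block, so every chunk-level index is scaled by blocks_per_chunk before slicing. The sketch below models just that arithmetic in isolation as a way to sanity-check the slice bounds; the concrete numbers (block_size = 16, a 1024-token HBM hit, a 160-block request) are made-up examples, not values taken from the patch.

```python
# Minimal, self-contained model of the chunk/block index math in this patch.
# Assumed example values: block_size = 16 tokens per vLLM block; chunk_size = 256
# matches the example config; all request sizes below are hypothetical.

block_size = 16
chunk_size = 256
assert chunk_size % block_size == 0, "chunk_size must be a multiple of block_size"
blocks_per_chunk = chunk_size // block_size  # 16 vLLM blocks per UCM chunk

# Load path: hit accounting happens in chunks, but block ids stay per-block.
num_computed_tokens = 1024                                 # already resident in HBM
hbm_hit_chunks = num_computed_tokens // chunk_size         # 4
external_hit_chunks = 3                                    # found in the external store
total_hit_chunks = hbm_hit_chunks + external_hit_chunks    # 7
external_hit_tokens = external_hit_chunks * chunk_size     # 768

vllm_block_ids = list(range(160))                          # a 2560-token request
load_vllm_block_ids = vllm_block_ids[
    hbm_hit_chunks * blocks_per_chunk : total_hit_chunks * blocks_per_chunk
]
assert len(load_vllm_block_ids) == external_hit_chunks * blocks_per_chunk  # 48

# Dump path: a chunk-aligned token window maps to a block-id slice the same way.
token_processed, new_tokens = 0, 2560
start_idx = token_processed // chunk_size                  # 0
end_idx = (token_processed + new_tokens) // chunk_size     # 10
dump_vllm_block_ids = vllm_block_ids[
    start_idx * blocks_per_chunk : end_idx * blocks_per_chunk
]
assert len(dump_vllm_block_ids) == 160                     # every block of the request
```

Because chunk_size defaults to self.block_size when the config key is absent, blocks_per_chunk == 1 and every slice above degenerates to the previous per-block behavior, so deployments that do not set chunk_size are unaffected.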