From 213642cdd94c6093ecaa31b6f2270261d601dc88 Mon Sep 17 00:00:00 2001
From: harrisonyhq
Date: Sun, 14 Dec 2025 18:52:48 -0800
Subject: [PATCH 1/2] [Feat] Support broadcast on ucm store v1

---
 ucm/integration/vllm/ucm_connector.py | 96 ++++++++++++++++++---------
 1 file changed, 64 insertions(+), 32 deletions(-)

diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
index ee8558e8..f1ef9743 100644
--- a/ucm/integration/vllm/ucm_connector.py
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -103,17 +103,17 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
 
         if current_platform.is_cuda_alike():
             logger.info("CUDA device is available.")
-            torch_dev = torch
+            self.torch_dev = torch.cuda
             dev_name = "cuda"
         elif current_platform.device_type == "npu":
             logger.info("NPU device is available.")
-            torch_dev = torch.npu
+            self.torch_dev = torch.npu
             dev_name = "npu"
         else:
             raise RuntimeError("Unsupported device platform for UCMDirectConnector.")
 
         if self.local_rank >= 0:
-            self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
+            self.device = torch.device(f"{dev_name}:{self.local_rank}")
 
         self.k_store: UcmKVStoreBaseV1
         self.v_store: Optional[UcmKVStoreBaseV1] = None
@@ -134,7 +134,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         if role == KVConnectorRole.WORKER:
             self.group_coordinator = get_tp_group()
             self.broadcast_fn = self.group_coordinator.broadcast
-            self.broadcast_stream = torch.cuda.Stream()
+            self.broadcast_stream = self.torch_dev.Stream()
+            self._broadcast_buffer = None
+            self._broadcast_buffer_size = 0
 
             name = self.connector_configs[0].get("ucm_connector_name")
             config = self.connector_configs[0].get("ucm_connector_config") or {}
@@ -167,12 +169,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             )
             self.monitor = ucmmonitor.StatsMonitor.get_instance()
 
-        self.synchronize = (
-            torch.cuda.synchronize
-            if current_platform.is_cuda_alike()
-            else torch.npu.synchronize
-        )
-
         # invlalid block ids due to load errors
         self._invalid_block_ids: set[int] = set()
 
@@ -461,24 +457,51 @@ def _generate_task(
 
         return block_ids, shard_indexs, total_k_tensors, total_v_tensors
 
-    def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
-        rec_tensor: torch.Tensor = None
-        with torch.cuda.stream(self.broadcast_stream):
-            # TODO support broadcast when PP
-            if self.global_rank == 0:
-                tensor_to_broadcast = torch.stack(dst_tensor_addr, dim=0)
-                self.broadcast_fn(tensor_to_broadcast, 0)
-            else:
-                shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
-                # TODO create earlier
-                rec_tensor = torch.empty(
-                    shape, dtype=self.kv_cache_dtype, device=self.device
+    def _ensure_buffer(self, total_numel: int):
+        if self._broadcast_buffer is None or self._broadcast_buffer_size < total_numel:
+            self._broadcast_buffer = torch.empty(
+                total_numel,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+            self._broadcast_buffer_size = total_numel
+
+    def _broadcast(self, dst_tensor_addr: List[torch.Tensor]):
+        rec_tensor = None
+        total_numel = len(dst_tensor_addr) * dst_tensor_addr[0].numel()
+        group = self.group_coordinator.device_group
+        if self.global_rank == 0:
+            tensor_to_broadcast = torch.stack(dst_tensor_addr)
+            handle = torch.distributed.broadcast(
+                tensor_to_broadcast, src=0, async_op=True, group=group
+            )
+        else:
+            self._ensure_buffer(total_numel)
+            shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
+            rec_tensor = self._broadcast_buffer[:total_numel].view(shape)
+            handle = torch.distributed.broadcast(
+                rec_tensor, src=0, async_op=True, group=group
+            )
+        return handle, rec_tensor
+
+    def _broadcast_layers(self, dst_tensor_addr: list[torch.Tensor]):
+        num_layers = len(self.kv_caches)
+        total = total = len(dst_tensor_addr)
+        assert num_layers > 0 and total % num_layers == 0, (num_layers, total)
+        num_tensors_per_layer = total // num_layers
+
+        for layer_i in range(num_layers):
+            start = layer_i * num_tensors_per_layer
+            handle, rec_tensor = self._broadcast(
+                dst_tensor_addr[start : start + num_tensors_per_layer]
+            )
+            handle.wait()
+            if self.global_rank != 0 and rec_tensor is not None:
+                rec_tensor_list = list(torch.unbind(rec_tensor, dim=0))
+                torch._foreach_copy_(
+                    dst_tensor_addr[start : start + num_tensors_per_layer],
+                    rec_tensor_list,
                 )
-                self.broadcast_fn(rec_tensor, 0)
-        self.broadcast_stream.synchronize()
-        if self.global_rank != 0 and rec_tensor is not None:
-            for i, tensor in enumerate(dst_tensor_addr):
-                tensor.copy_(rec_tensor[i])
 
     def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         metadata = self._get_connector_metadata()
@@ -512,9 +535,12 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 request_to_task[request_id].append(v_task)
             else:
                 request_to_task[request_id] = None
-            req_broadcast_addr[request_id] = [t for row in k_tensors for t in row] + [
-                t for row in v_tensors for t in row
-            ]
+            if v_tensors and self.v_store:
+                req_broadcast_addr[request_id] = ([t for row in k_tensors for t in row], [
+                    t for row in v_tensors for t in row
+                ])
+            else:
+                req_broadcast_addr[request_id] = [t for row in k_tensors for t in row]
 
         for request_id, tasks in request_to_task.items():
             # TODO error handling
@@ -529,7 +555,13 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 metadata.request_meta[request_id].load_block_ids[1]
             )
             if self.load_only_first_rank:
-                self._broadcast(req_broadcast_addr[request_id])
+                if isinstance(req_broadcast_addr[request_id], tuple) and self.v_store:
+                    k_nope, k_rope = req_broadcast_addr[request_id]
+                    self._broadcast_layers(k_nope)
+                    self._broadcast_layers(k_rope)
+                else:
+                    self._broadcast_layers(req_broadcast_addr[request_id])
+
                 load_end_time = time.perf_counter() * 1000
                 load_speed = (
                     num_loaded_block
@@ -569,7 +601,7 @@ def wait_for_save(self) -> None:
         if self.metrics_config or current_platform.device_type == "npu":
             # When use vllm_ascend, we should add synchronize here, otherwise accuracy problem will raise
             # This has already been fixed in the latest main branch of vllm_ascend, so synchronize will no longer be needed in future versions.
-            self.synchronize()
+            self.torch_dev.synchronize()
 
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, UCMConnectorMetadata)

From 2da0897fae12dfcfd55d0d224dfacd84f9bfe1d4 Mon Sep 17 00:00:00 2001
From: harrisonyhq
Date: Sun, 14 Dec 2025 22:37:22 -0800
Subject: [PATCH 2/2] [Style] Format code and add docstring

---
 ucm/integration/vllm/ucm_connector.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
index f1ef9743..c79f5438 100644
--- a/ucm/integration/vllm/ucm_connector.py
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -458,6 +458,10 @@ def _generate_task(
         return block_ids, shard_indexs, total_k_tensors, total_v_tensors
 
     def _ensure_buffer(self, total_numel: int):
+        """
+        Initialize or grow the reusable broadcast buffer;
+        typically its length equals the size of one layer's KV cache tensors.
+        """
         if self._broadcast_buffer is None or self._broadcast_buffer_size < total_numel:
             self._broadcast_buffer = torch.empty(
                 total_numel,
@@ -467,6 +471,9 @@ def _ensure_buffer(self, total_numel: int):
             self._broadcast_buffer_size = total_numel
 
     def _broadcast(self, dst_tensor_addr: List[torch.Tensor]):
+        """
+        Broadcast a list of tensors within the TP group.
+        """
         rec_tensor = None
         total_numel = len(dst_tensor_addr) * dst_tensor_addr[0].numel()
         group = self.group_coordinator.device_group
@@ -485,8 +492,11 @@ def _broadcast(self, dst_tensor_addr: List[torch.Tensor]):
         return handle, rec_tensor
 
     def _broadcast_layers(self, dst_tensor_addr: list[torch.Tensor]):
+        """
+        Broadcast the KV caches layer by layer.
+        """
         num_layers = len(self.kv_caches)
-        total = total = len(dst_tensor_addr)
+        total = len(dst_tensor_addr)
         assert num_layers > 0 and total % num_layers == 0, (num_layers, total)
         num_tensors_per_layer = total // num_layers
 
@@ -536,9 +546,10 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
             else:
                 request_to_task[request_id] = None
             if v_tensors and self.v_store:
-                req_broadcast_addr[request_id] = ([t for row in k_tensors for t in row], [
-                    t for row in v_tensors for t in row
-                ])
+                req_broadcast_addr[request_id] = (
+                    [t for row in k_tensors for t in row],
+                    [t for row in v_tensors for t in row],
+                )
             else:
                 req_broadcast_addr[request_id] = [t for row in k_tensors for t in row]
 
@@ -556,12 +567,13 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
             )
             if self.load_only_first_rank:
                 if isinstance(req_broadcast_addr[request_id], tuple) and self.v_store:
+                    # In vllm_ascend >= 0.10.0, the MLA model's k cache is split into (nope_dim, rope_dim)
                     k_nope, k_rope = req_broadcast_addr[request_id]
                     self._broadcast_layers(k_nope)
                     self._broadcast_layers(k_rope)
                 else:
                     self._broadcast_layers(req_broadcast_addr[request_id])
-
+
                 load_end_time = time.perf_counter() * 1000
                 load_speed = (
                     num_loaded_block
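
For reference, the layer-wise broadcast introduced by this series can be read in isolation with the minimal Python sketch below. It assumes a plain torch.distributed process group rather than vLLM's TP coordinator, and the helper names (ensure_buffer, broadcast_one_layer) and the single-process gloo setup are illustrative only, not part of the connector: rank 0 stacks one layer's KV tensors and broadcasts them asynchronously, while the other ranks receive into a reusable flat buffer and scatter the result back with torch._foreach_copy_.

import os
from typing import Optional

import torch
import torch.distributed as dist

_buffer: Optional[torch.Tensor] = None  # reusable receive buffer, grown on demand


def ensure_buffer(numel: int, dtype: torch.dtype) -> torch.Tensor:
    # Allocate the flat receive buffer once and reuse it across layers,
    # growing it only if a layer needs more elements than any earlier one.
    global _buffer
    if _buffer is None or _buffer.numel() < numel:
        _buffer = torch.empty(numel, dtype=dtype)
    return _buffer


def broadcast_one_layer(tensors: list[torch.Tensor], rank: int, group) -> None:
    # Rank 0 stacks one layer's tensors and broadcasts the stack; other ranks
    # receive into a view of the shared buffer and copy back into their tensors.
    numel = len(tensors) * tensors[0].numel()
    if rank == 0:
        stacked = torch.stack(tensors)
        dist.broadcast(stacked, src=0, async_op=True, group=group).wait()
    else:
        buf = ensure_buffer(numel, tensors[0].dtype)
        rec = buf[:numel].view((len(tensors),) + tensors[0].shape)
        dist.broadcast(rec, src=0, async_op=True, group=group).wait()
        torch._foreach_copy_(tensors, list(torch.unbind(rec, dim=0)))


if __name__ == "__main__":
    # Single-process demo so the sketch runs without a distributed launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    layer = [torch.randn(4, 8) for _ in range(2)]
    broadcast_one_layer(layer, dist.get_rank(), dist.group.WORLD)
    dist.destroy_process_group()

Waiting on the handle once per layer keeps the receive buffer bounded to a single layer's size, which is the sizing the _ensure_buffer docstring above describes.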