diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
index ee8558e8..c79f5438 100644
--- a/ucm/integration/vllm/ucm_connector.py
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -103,17 +103,17 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
 
         if current_platform.is_cuda_alike():
             logger.info("CUDA device is available.")
-            torch_dev = torch
+            self.torch_dev = torch.cuda
             dev_name = "cuda"
         elif current_platform.device_type == "npu":
             logger.info("NPU device is available.")
-            torch_dev = torch.npu
+            self.torch_dev = torch.npu
             dev_name = "npu"
         else:
             raise RuntimeError("Unsupported device platform for UCMDirectConnector.")
 
         if self.local_rank >= 0:
-            self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
+            self.device = torch.device(f"{dev_name}:{self.local_rank}")
 
         self.k_store: UcmKVStoreBaseV1
         self.v_store: Optional[UcmKVStoreBaseV1] = None
@@ -134,7 +134,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         if role == KVConnectorRole.WORKER:
             self.group_coordinator = get_tp_group()
             self.broadcast_fn = self.group_coordinator.broadcast
-            self.broadcast_stream = torch.cuda.Stream()
+            self.broadcast_stream = self.torch_dev.Stream()
+            self._broadcast_buffer = None
+            self._broadcast_buffer_size = 0
 
         name = self.connector_configs[0].get("ucm_connector_name")
         config = self.connector_configs[0].get("ucm_connector_config") or {}
@@ -167,12 +169,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         )
         self.monitor = ucmmonitor.StatsMonitor.get_instance()
 
-        self.synchronize = (
-            torch.cuda.synchronize
-            if current_platform.is_cuda_alike()
-            else torch.npu.synchronize
-        )
-
         # invalid block ids due to load errors
         self._invalid_block_ids: set[int] = set()
 
@@ -461,24 +457,61 @@ def _generate_task(
 
         return block_ids, shard_indexs, total_k_tensors, total_v_tensors
 
-    def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
-        rec_tensor: torch.Tensor = None
-        with torch.cuda.stream(self.broadcast_stream):
-            # TODO support broadcast when PP
-            if self.global_rank == 0:
-                tensor_to_broadcast = torch.stack(dst_tensor_addr, dim=0)
-                self.broadcast_fn(tensor_to_broadcast, 0)
-            else:
-                shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
-                # TODO create earlier
-                rec_tensor = torch.empty(
-                    shape, dtype=self.kv_cache_dtype, device=self.device
+    def _ensure_buffer(self, total_numel: int):
+        """
+        Initialize (or grow) the reusable broadcast buffer;
+        typically its length equals the size of one layer's KV cache tensors.
+        """
+        if self._broadcast_buffer is None or self._broadcast_buffer_size < total_numel:
+            self._broadcast_buffer = torch.empty(
+                total_numel,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+            self._broadcast_buffer_size = total_numel
+
+    def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
+        """
+        Broadcast a tensor list within the TP group.
+        """
+        rec_tensor = None
+        total_numel = len(dst_tensor_addr) * dst_tensor_addr[0].numel()
+        group = self.group_coordinator.device_group
+        if self.global_rank == 0:
+            tensor_to_broadcast = torch.stack(dst_tensor_addr)
+            handle = torch.distributed.broadcast(
+                tensor_to_broadcast, src=0, async_op=True, group=group
+            )
+        else:
+            self._ensure_buffer(total_numel)
+            shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
+            rec_tensor = self._broadcast_buffer[:total_numel].view(shape)
+            handle = torch.distributed.broadcast(
+                rec_tensor, src=0, async_op=True, group=group
+            )
+        return handle, rec_tensor
+
+    def _broadcast_layers(self, dst_tensor_addr: list[torch.Tensor]):
+        """
+        Broadcast the KV caches layer by layer.
+        """
+        num_layers = len(self.kv_caches)
+        total = len(dst_tensor_addr)
+        assert num_layers > 0 and total % num_layers == 0, (num_layers, total)
+        num_tensors_per_layer = total // num_layers
+
+        for layer_i in range(num_layers):
+            start = layer_i * num_tensors_per_layer
+            handle, rec_tensor = self._broadcast(
+                dst_tensor_addr[start : start + num_tensors_per_layer]
+            )
+            handle.wait()
+            if self.global_rank != 0 and rec_tensor is not None:
+                rec_tensor_list = list(torch.unbind(rec_tensor, dim=0))
+                torch._foreach_copy_(
+                    dst_tensor_addr[start : start + num_tensors_per_layer],
+                    rec_tensor_list,
                )
-                self.broadcast_fn(rec_tensor, 0)
-        self.broadcast_stream.synchronize()
-        if self.global_rank != 0 and rec_tensor is not None:
-            for i, tensor in enumerate(dst_tensor_addr):
-                tensor.copy_(rec_tensor[i])
 
     def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         metadata = self._get_connector_metadata()
@@ -512,9 +545,13 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 request_to_task[request_id].append(v_task)
             else:
                 request_to_task[request_id] = None
-            req_broadcast_addr[request_id] = [t for row in k_tensors for t in row] + [
-                t for row in v_tensors for t in row
-            ]
+            if v_tensors and self.v_store:
+                req_broadcast_addr[request_id] = (
+                    [t for row in k_tensors for t in row],
+                    [t for row in v_tensors for t in row],
+                )
+            else:
+                req_broadcast_addr[request_id] = [t for row in k_tensors for t in row]
 
         for request_id, tasks in request_to_task.items():
             # TODO error handling
@@ -529,7 +566,14 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 metadata.request_meta[request_id].load_block_ids[1]
            )
             if self.load_only_first_rank:
-                self._broadcast(req_broadcast_addr[request_id])
+                if isinstance(req_broadcast_addr[request_id], tuple) and self.v_store:
+                    # In vllm_ascend >= 0.10.0, the MLA model's k cache is split into (nope_dim, rope_dim) parts
+                    k_nope, k_rope = req_broadcast_addr[request_id]
+                    self._broadcast_layers(k_nope)
+                    self._broadcast_layers(k_rope)
+                else:
+                    self._broadcast_layers(req_broadcast_addr[request_id])
+
             load_end_time = time.perf_counter() * 1000
             load_speed = (
                 num_loaded_block
@@ -569,7 +613,7 @@ def wait_for_save(self) -> None:
         if self.metrics_config or current_platform.device_type == "npu":
             # When using vllm_ascend, we must synchronize here; otherwise an accuracy problem arises.
             # This has already been fixed on the latest main branch of vllm_ascend, so the synchronize will not be needed in future versions.
-            self.synchronize()
+            self.torch_dev.synchronize()
 
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, UCMConnectorMetadata)
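Note on the broadcast rewrite (commentary, not part of the patch): rank 0 stacks each layer's KV tensors into one contiguous tensor and broadcasts it asynchronously, while every other rank receives into a reused preallocated buffer and scatters the rows back into the real cache tensors with torch._foreach_copy_. The standalone sketch below illustrates that same pattern; the two-rank "gloo"/CPU setup and the helper names broadcast_chunk and run_rank are illustrative assumptions, not code from this PR.

import os

import torch
import torch.distributed as dist


def broadcast_chunk(tensors: list[torch.Tensor], buffer: torch.Tensor, rank: int) -> None:
    """Broadcast one chunk (e.g. one layer) of tensors from rank 0 to all ranks."""
    numel = len(tensors) * tensors[0].numel()
    if rank == 0:
        # Source rank: stack into one contiguous tensor and send it.
        handle = dist.broadcast(torch.stack(tensors), src=0, async_op=True)
        handle.wait()
    else:
        # Receiving ranks: reuse a slice of the preallocated buffer.
        rec = buffer[:numel].view((len(tensors),) + tensors[0].shape)
        handle = dist.broadcast(rec, src=0, async_op=True)
        handle.wait()
        # Scatter the received rows back into the destination tensors.
        torch._foreach_copy_(tensors, list(torch.unbind(rec, dim=0)))


def run_rank(rank: int, world_size: int = 2) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    num_layers, tensors_per_layer, numel = 4, 2, 8
    # Rank 0 holds the real data; other ranks hold zeros to be overwritten.
    fill = 1.0 if rank == 0 else 0.0
    caches = [torch.full((numel,), fill) for _ in range(num_layers * tensors_per_layer)]
    # One reusable receive buffer, sized for a single layer (mirrors _ensure_buffer).
    buffer = torch.empty(tensors_per_layer * numel)

    # Broadcast layer by layer (mirrors _broadcast_layers).
    for i in range(num_layers):
        chunk = caches[i * tensors_per_layer : (i + 1) * tensors_per_layer]
        broadcast_chunk(chunk, buffer, rank)

    assert all(t.eq(1.0).all() for t in caches)
    dist.destroy_process_group()


if __name__ == "__main__":
    torch.multiprocessing.spawn(run_rank, nprocs=2)

Broadcasting per layer keeps the staging buffer bounded to a single layer's size instead of materializing every layer at once, which is why _ensure_buffer in the patch only grows the buffer when a larger chunk arrives.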