108 changes: 76 additions & 32 deletions ucm/integration/vllm/ucm_connector.py
@@ -103,17 +103,17 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):

if current_platform.is_cuda_alike():
logger.info("CUDA device is available.")
torch_dev = torch
self.torch_dev = torch.cuda
dev_name = "cuda"
elif current_platform.device_type == "npu":
logger.info("NPU device is available.")
torch_dev = torch.npu
self.torch_dev = torch.npu
dev_name = "npu"
else:
raise RuntimeError("Unsupported device platform for UCMDirectConnector.")

if self.local_rank >= 0:
self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
self.device = torch.device(f"{dev_name}:{self.local_rank}")

self.k_store: UcmKVStoreBaseV1
self.v_store: Optional[UcmKVStoreBaseV1] = None
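Not part of this diff: the hunk above replaces the local torch_dev variable with a stored self.torch_dev backend module (torch.cuda or torch.npu), so stream creation and device synchronization later in the file can share one code path. A minimal sketch of that pattern, with every name outside the diff being illustrative:

import torch

class DeviceDispatchSketch:
    def __init__(self, local_rank: int):
        # Pick the backend module once; an Ascend build would pick torch.npu here.
        if torch.cuda.is_available():
            self.torch_dev = torch.cuda
            self.device = torch.device(f"cuda:{local_rank}")
        else:
            raise RuntimeError("no supported accelerator in this sketch")

    def make_stream(self):
        # Stream() and synchronize() now come from the same backend module,
        # which is why the separate self.synchronize attribute is dropped below.
        return self.torch_dev.Stream()

    def sync(self) -> None:
        self.torch_dev.synchronize()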
@@ -134,7 +134,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
if role == KVConnectorRole.WORKER:
self.group_coordinator = get_tp_group()
self.broadcast_fn = self.group_coordinator.broadcast
self.broadcast_stream = torch.cuda.Stream()
self.broadcast_stream = self.torch_dev.Stream()
self._broadcast_buffer = None
self._broadcast_buffer_size = 0

name = self.connector_configs[0].get("ucm_connector_name")
config = self.connector_configs[0].get("ucm_connector_config") or {}
@@ -167,12 +169,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
)
self.monitor = ucmmonitor.StatsMonitor.get_instance()

self.synchronize = (
torch.cuda.synchronize
if current_platform.is_cuda_alike()
else torch.npu.synchronize
)

# invalid block ids due to load errors
self._invalid_block_ids: set[int] = set()

@@ -461,24 +457,61 @@ def _generate_task(

return block_ids, shard_indexs, total_k_tensors, total_v_tensors

def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
rec_tensor: torch.Tensor = None
with torch.cuda.stream(self.broadcast_stream):
# TODO support broadcast when PP
if self.global_rank == 0:
tensor_to_broadcast = torch.stack(dst_tensor_addr, dim=0)
self.broadcast_fn(tensor_to_broadcast, 0)
else:
shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
# TODO create earlier
rec_tensor = torch.empty(
shape, dtype=self.kv_cache_dtype, device=self.device
def _ensure_buffer(self, total_numel: int):
"""
Initialize or grow the buffer used for broadcast;
typically this buffer's length equals the size of one layer's KV cache tensor.
"""
if self._broadcast_buffer is None or self._broadcast_buffer_size < total_numel:
self._broadcast_buffer = torch.empty(
total_numel,
dtype=self.kv_cache_dtype,
device=self.device,
)
self._broadcast_buffer_size = total_numel

def _broadcast(self, dst_tensor_addr: List[torch.Tensor]):
"""
Broadcast a list of tensors within the TP group.
"""
rec_tensor = None
total_numel = len(dst_tensor_addr) * dst_tensor_addr[0].numel()
group = self.group_coordinator.device_group
if self.global_rank == 0:
tensor_to_broadcast = torch.stack(dst_tensor_addr)
handle = torch.distributed.broadcast(
tensor_to_broadcast, src=0, async_op=True, group=group
)
else:
self._ensure_buffer(total_numel)
shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
rec_tensor = self._broadcast_buffer[:total_numel].view(shape)
handle = torch.distributed.broadcast(
rec_tensor, src=0, async_op=True, group=group
)
return handle, rec_tensor
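Not part of this diff: a small self-contained sketch of the grow-only buffer pattern _ensure_buffer implements, where a flat tensor is reallocated only when it is too small and then sliced and viewed to the stacked shape; the dtype and device here are placeholders.

import math
import torch

def ensure_buffer(buf, needed_numel, dtype=torch.float16, device="cpu"):
    # Reallocate only when the request outgrows the current buffer.
    if buf is None or buf.numel() < needed_numel:
        buf = torch.empty(needed_numel, dtype=dtype, device=device)
    return buf

buf = None
for shape in [(2, 4, 8), (2, 4, 8), (4, 4, 8)]:
    needed = math.prod(shape)
    buf = ensure_buffer(buf, needed)
    rec = buf[:needed].view(shape)  # zero-copy view into the reused storage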

def _broadcast_layers(self, dst_tensor_addr: list[torch.Tensor]):
"""
Broadcast kv caches by layer.
"""
num_layers = len(self.kv_caches)
total = len(dst_tensor_addr)
assert num_layers > 0 and total % num_layers == 0, (num_layers, total)
num_tensors_per_layer = total // num_layers

for layer_i in range(num_layers):
start = layer_i * num_tensors_per_layer
handle, rec_tensor = self._broadcast(
dst_tensor_addr[start : start + num_tensors_per_layer]
)
handle.wait()
if self.global_rank != 0 and rec_tensor is not None:
rec_tensor_list = list(torch.unbind(rec_tensor, dim=0))
torch._foreach_copy_(
dst_tensor_addr[start : start + num_tensors_per_layer],
rec_tensor_list,
)
self.broadcast_fn(rec_tensor, 0)
self.broadcast_stream.synchronize()
if self.global_rank != 0 and rec_tensor is not None:
for i, tensor in enumerate(dst_tensor_addr):
tensor.copy_(rec_tensor[i])
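Not part of this diff: the hunk above replaces the old single stacked broadcast (the removed lines ending just before this note) with the per-layer _broadcast_layers loop. An illustrative self-contained version of that loop, calling torch.distributed directly as the new code does; the process group, ranks, and tensor shapes are assumptions of the sketch.

import torch
import torch.distributed as dist

def broadcast_layers_sketch(dst: list, num_layers: int, rank: int, group) -> None:
    per_layer = len(dst) // num_layers
    for layer in range(num_layers):
        chunk = dst[layer * per_layer : (layer + 1) * per_layer]
        if rank == 0:
            # The source rank stacks the KV blocks it already loaded and sends them.
            payload = torch.stack(chunk)
            dist.broadcast(payload, src=0, async_op=True, group=group).wait()
        else:
            # Other ranks receive into scratch storage, then scatter the result
            # back into their destination KV blocks in one fused copy.
            rec = torch.empty((per_layer,) + chunk[0].shape,
                              dtype=chunk[0].dtype, device=chunk[0].device)
            dist.broadcast(rec, src=0, async_op=True, group=group).wait()
            torch._foreach_copy_(chunk, list(torch.unbind(rec, dim=0)))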

def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
metadata = self._get_connector_metadata()
@@ -512,9 +545,13 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
request_to_task[request_id].append(v_task)
else:
request_to_task[request_id] = None
req_broadcast_addr[request_id] = [t for row in k_tensors for t in row] + [
t for row in v_tensors for t in row
]
if v_tensors and self.v_store:
req_broadcast_addr[request_id] = (
[t for row in k_tensors for t in row],
[t for row in v_tensors for t in row],
)
else:
req_broadcast_addr[request_id] = [t for row in k_tensors for t in row]

for request_id, tasks in request_to_task.items():
# TODO error handling
@@ -529,7 +566,14 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
metadata.request_meta[request_id].load_block_ids[1]
)
if self.load_only_first_rank:
self._broadcast(req_broadcast_addr[request_id])
if isinstance(req_broadcast_addr[request_id], tuple) and self.v_store:
# In vllm_ascend >= 0.10.0, the MLA model's k cache is separated into (nope_dim, rope_dim)
k_nope, k_rope = req_broadcast_addr[request_id]
self._broadcast_layers(k_nope)
self._broadcast_layers(k_rope)
else:
self._broadcast_layers(req_broadcast_addr[request_id])
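Not part of this diff: the dispatch above in miniature. When a request's broadcast addresses were stored as a two-component tuple (for example the separated MLA key cache), each component is broadcast in its own pass; otherwise the flat list is broadcast once. All names here are illustrative.

def dispatch_broadcast(addr, broadcast_layers, has_second_store: bool) -> None:
    if isinstance(addr, tuple) and has_second_store:
        first, second = addr
        broadcast_layers(first)   # e.g. the nope part of the key cache
        broadcast_layers(second)  # e.g. the rope part
    else:
        broadcast_layers(addr)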

load_end_time = time.perf_counter() * 1000
load_speed = (
num_loaded_block
@@ -569,7 +613,7 @@ def wait_for_save(self) -> None:
if self.metrics_config or current_platform.device_type == "npu":
# When using vllm_ascend, we need to synchronize here, otherwise accuracy problems will arise.
# This has already been fixed in the latest main branch of vllm_ascend, so synchronize will no longer be needed in future versions.
self.synchronize()
self.torch_dev.synchronize()

metadata = self._get_connector_metadata()
assert isinstance(metadata, UCMConnectorMetadata)