2 changes: 2 additions & 0 deletions examples/ucm_config_example.yaml
@@ -16,6 +16,8 @@ ucm_connectors:

load_only_first_rank: false

chunk_size: 256

# Enable UCM metrics so they can be monitored online via Grafana and Prometheus.
# metrics_config_path: "/workspace/unified-cache-management/examples/metrics/metrics_configs.yaml"

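For scale, with an assumed vLLM block size of 16 tokens (the real value comes from the engine's cache config), the example value above groups 16 KV blocks into one UCM chunk; any value that is not an exact multiple of the block size is rejected by the connector at start-up. A rough arithmetic check:

```python
# Rough arithmetic for the example value above; block_size = 16 is an assumption.
block_size = 16
chunk_size = 256
assert chunk_size % block_size == 0          # enforced by the connector at start-up
blocks_per_chunk = chunk_size // block_size
print(blocks_per_chunk)                      # -> 16 vLLM blocks grouped into one UCM chunk
```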
73 changes: 49 additions & 24 deletions ucm/integration/vllm/ucm_connector.py
@@ -130,7 +130,17 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
self.group_coordinator = get_tp_group()
self.broadcast_fn = self.group_coordinator.broadcast
self.broadcast_stream = torch.cuda.Stream()
self.chunk_size = 1

self.chunk_size = self.launch_config.get("chunk_size", self.block_size)
if self.chunk_size % self.block_size != 0:
raise ValueError(
f"chunk_size ({self.chunk_size}) must be a multiple of "
f"block_size ({self.block_size})"
)
self.blocks_per_chunk = self.chunk_size // self.block_size
logger.info(
f"chunk_size = {self.chunk_size}, blocks_per_chunk = {self.blocks_per_chunk}"
)

if role == KVConnectorRole.SCHEDULER:
self.request_hasher = RequestHasher(vllm_config, 0)
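A minimal standalone sketch of the new chunk-size resolution, with `launch_config` reduced to a plain dict (the real attribute lives on the connector instance):

```python
def resolve_chunking(launch_config: dict, block_size: int) -> tuple[int, int]:
    # Default to one block per chunk when chunk_size is not configured,
    # and reject sizes that are not whole multiples of the KV block size.
    chunk_size = launch_config.get("chunk_size", block_size)
    if chunk_size % block_size != 0:
        raise ValueError(
            f"chunk_size ({chunk_size}) must be a multiple of block_size ({block_size})"
        )
    return chunk_size, chunk_size // block_size


print(resolve_chunking({}, 16))                    # (16, 1): legacy one-block-per-chunk behaviour
print(resolve_chunking({"chunk_size": 256}, 16))   # (256, 16)
# resolve_chunking({"chunk_size": 200}, 16) raises ValueError
```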
@@ -258,7 +268,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
chunk_block_size = (
tensor_size
* self.num_layers
* self.chunk_size
* self.blocks_per_chunk
* (1 if self.is_mla or self.is_dsa else 2)
)
self.block_data_size = chunk_block_size
@@ -267,7 +277,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
rope_tensor_size = (
sample_kv_layer[1][0].numel() * sample_kv_layer[1][0].element_size()
)
rope_chunk_block_size = rope_tensor_size * self.num_layers * self.chunk_size
rope_chunk_block_size = (
rope_tensor_size * self.num_layers * self.blocks_per_chunk
)
self.rope_store = self._create_store(
rope_tensor_size, rope_chunk_block_size, True
)
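Back-of-the-envelope sizing for one chunk in the external store, under assumed shapes (all dimensions below are illustrative, not taken from a real model config):

```python
# Illustrative sizing of one UCM chunk, non-MLA case (K and V kept separately).
elem_size = 2                                  # bytes per element, fp16/bf16 assumed
block_size, num_kv_heads, head_dim = 16, 8, 128
num_layers = 32
blocks_per_chunk = 256 // block_size           # chunk_size = 256 -> 16 blocks per chunk

tensor_size = block_size * num_kv_heads * head_dim * elem_size       # one K (or V) block
chunk_block_size = tensor_size * num_layers * blocks_per_chunk * 2   # K + V
print(chunk_block_size // (1024 * 1024), "MiB per chunk")            # 32 MiB with these numbers
```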
@@ -279,9 +291,9 @@ def get_num_new_matched_tokens(
num_computed_tokens: int,
) -> tuple[int, bool]:
assert num_computed_tokens % self.block_size == 0
hbm_hit_block_num = num_computed_tokens // self.block_size
hbm_hit_block_num = num_computed_tokens // self.chunk_size

ucm_block_ids = self.generate_hash(self.block_size, request)
ucm_block_ids = self.generate_hash(self.chunk_size, request)

external_block_ids = ucm_block_ids[hbm_hit_block_num:]
if not external_block_ids:
@@ -307,12 +319,12 @@ get_num_new_matched_tokens

total_hit_block_num = hbm_hit_block_num + external_hit_blocks

external_hit_tokens = external_hit_blocks * self.block_size
external_hit_tokens = external_hit_blocks * self.chunk_size

# When all the tokens are cached in ssd or hbm,
# we need to recompute the last token. This if condition will be removed
# once vLLM scheduler provides a better solution in the future.
num_total_hit_tokens = total_hit_block_num * self.block_size
num_total_hit_tokens = total_hit_block_num * self.chunk_size
if num_total_hit_tokens == request.num_tokens:
external_hit_tokens -= 1

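A worked example of the hit accounting, now done in chunk rather than block units (all counts below are assumed):

```python
# Assumed numbers: block_size = 16, chunk_size = 256, so one chunk = 16 blocks.
chunk_size = 256
num_computed_tokens = 512                          # tokens already resident in HBM
hbm_hit_block_num = num_computed_tokens // chunk_size          # 2 chunks

external_hit_blocks = 3                            # chunks found in the external store (assumed)
external_hit_tokens = external_hit_blocks * chunk_size         # 768 tokens to load

total_hit_block_num = hbm_hit_block_num + external_hit_blocks  # 5 chunks
num_total_hit_tokens = total_hit_block_num * chunk_size        # 1280 tokens
request_num_tokens = 1280                          # assumed prompt length
if num_total_hit_tokens == request_num_tokens:
    external_hit_tokens -= 1                       # force recompute of the last token
print(external_hit_tokens)                         # 767
```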
@@ -359,13 +371,19 @@ def _generate_dispatch_meta(
dump_ucm_block_ids, dump_vllm_block_ids = [], []
if need_load:
load_ucm_block_ids = ucm_block_ids[hbm_hit_block_num:total_hit_block_num]
load_vllm_block_ids = vllm_block_ids[hbm_hit_block_num:total_hit_block_num]
load_vllm_block_ids = vllm_block_ids[
hbm_hit_block_num
* self.blocks_per_chunk : total_hit_block_num
* self.blocks_per_chunk
]

if req_meta.token_processed < req_meta.num_token_ids:
start_idx = req_meta.token_processed // self.block_size
end_idx = (req_meta.token_processed + new_tokens) // self.block_size
start_idx = req_meta.token_processed // self.chunk_size
end_idx = (req_meta.token_processed + new_tokens) // self.chunk_size
dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx]
dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx]
dump_vllm_block_ids = req_meta.vllm_block_ids[
start_idx * self.blocks_per_chunk : end_idx * self.blocks_per_chunk
]
req_meta.token_processed += new_tokens

return RequestDispatchMeta(
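The slicing change keeps `ucm_block_ids` indexed per chunk while `vllm_block_ids` stays per block, so chunk-level hit counts are scaled by `blocks_per_chunk` before slicing the vLLM side. A small sketch with assumed counts:

```python
# ucm ids are one per chunk, vllm ids are one per block; counts below are assumed.
blocks_per_chunk = 4                               # e.g. chunk_size = 64, block_size = 16
ucm_block_ids = ["c0", "c1", "c2", "c3"]
vllm_block_ids = list(range(16))                   # 4 chunks * 4 blocks

hbm_hit_block_num, total_hit_block_num = 1, 3      # chunk-level hit counts
load_ucm_block_ids = ucm_block_ids[hbm_hit_block_num:total_hit_block_num]
load_vllm_block_ids = vllm_block_ids[
    hbm_hit_block_num * blocks_per_chunk : total_hit_block_num * blocks_per_chunk
]
print(load_ucm_block_ids)    # ['c1', 'c2']
print(load_vllm_block_ids)   # [4, 5, 6, 7, 8, 9, 10, 11]
```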
@@ -456,18 +474,25 @@ def _generate_task(
) -> Tuple[
List[bytes], List[int], List[List[torch.Tensor]], List[List[torch.Tensor]]
]:
block_ids, shard_indexs, total_tensors, rope_tensors = [], [], [], []
for i, vllm_block_id in enumerate(vllm_block_ids):
k_tensors, v_tensors = self._get_tensors(vllm_block_id)
block_ids.append(ucm_block_ids[i])
if self.is_dsa:
total_tensors.append(k_tensors)
rope_tensors.append(v_tensors)
else:
total_tensors.append(k_tensors + v_tensors)
block_ids, shard_indexs, total_tensors, total_rope_tensors = [], [], [], []
for ucm_block_id, i in zip(
ucm_block_ids, range(0, len(vllm_block_ids), self.blocks_per_chunk)
):
block_ids.append(ucm_block_id)
shard_indexs.append(0)

return block_ids, shard_indexs, total_tensors, rope_tensors
tensors, rope_tensors = [], []
for j in range(self.blocks_per_chunk):
vllm_block_id = vllm_block_ids[i + j]
k_tensors, v_tensors = self._get_tensors(vllm_block_id)
if self.is_dsa:
tensors.extend(k_tensors)
rope_tensors.extend(v_tensors)
else:
tensors.extend(k_tensors + v_tensors)
total_tensors.append(tensors)
total_rope_tensors.append(rope_tensors)

return block_ids, shard_indexs, total_tensors, total_rope_tensors

def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
rec_tensor: torch.Tensor = None
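Standalone sketch of the regrouping done in `_generate_task`: each UCM chunk now collects the K/V tensors of `blocks_per_chunk` consecutive vLLM blocks (strings stand in for the real torch tensors here, non-DSA path shown):

```python
# Stand-in for the per-chunk tensor grouping; strings replace torch tensors.
blocks_per_chunk = 2                               # assumed
ucm_block_ids = ["chunk0", "chunk1"]
vllm_block_ids = [10, 11, 12, 13]                  # 2 chunks * 2 blocks

def get_tensors(block_id):                         # stand-in for self._get_tensors
    return [f"k{block_id}"], [f"v{block_id}"]

block_ids, total_tensors = [], []
for ucm_block_id, i in zip(ucm_block_ids, range(0, len(vllm_block_ids), blocks_per_chunk)):
    block_ids.append(ucm_block_id)
    tensors = []
    for j in range(blocks_per_chunk):
        k_tensors, v_tensors = get_tensors(vllm_block_ids[i + j])
        tensors.extend(k_tensors + v_tensors)      # K then V for each block in the chunk
    total_tensors.append(tensors)

print(block_ids)       # ['chunk0', 'chunk1']
print(total_tensors)   # [['k10', 'v10', 'k11', 'v11'], ['k12', 'v12', 'k13', 'v13']]
```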
@@ -737,7 +762,7 @@ get_num_new_matched_tokens(
expect_hit_tokens = int(self._hit_ratio * request.num_prompt_tokens)
if hit_tokens <= expect_hit_tokens:
return hit_tokens, False
expect_hit_block_num = expect_hit_tokens // self.block_size
expect_hit_block_num = expect_hit_tokens // self.chunk_size
request_meta = self.requests_meta[request.request_id]
request_meta.total_hit_block_num = expect_hit_block_num
request_meta.hbm_hit_block_num = min(
@@ -752,7 +777,7 @@
f"hit external: {request_meta.total_hit_block_num - request_meta.hbm_hit_block_num}"
)

return expect_hit_block_num * self.block_size, False
return expect_hit_block_num * self.chunk_size, False


class UCMConnector(KVConnectorBase_V1):
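In the simulated hit-ratio path, the expected hit is likewise rounded down to whole chunks before being reported; for example, with assumed numbers:

```python
# Assumed numbers: the simulated hit is now rounded to chunk granularity.
chunk_size = 256
hit_ratio = 0.5
num_prompt_tokens = 1000
expect_hit_tokens = int(hit_ratio * num_prompt_tokens)      # 500
expect_hit_block_num = expect_hit_tokens // chunk_size      # 1 whole chunk
print(expect_hit_block_num * chunk_size)                    # 256 tokens reported as hit
```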