diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py index 8d4a09af..8a41763d 100644 --- a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py +++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py @@ -56,9 +56,14 @@ def _apply_ascend_patch() -> None: # ========================= vllm_ascend/attention/attention_v1.py ========================= def _patch_attention_v1() -> None: - """Patch attention_v1.py for vLLM-Ascend.""" + """Patch attention_v1.py for vLLM-Ascend. + + Key points: + - Skip hooks during compile/fake/meta stage to keep graph stable + - Allow hook begin() to return None (in-place) or (q,k,v,out) tuple + """ try: - from typing import List, Optional + from typing import Optional import torch from vllm.forward_context import ForwardContext, get_forward_context @@ -66,6 +71,21 @@ def _patch_attention_v1() -> None: from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse + # ------------------------------------------------------------ + # 1) Graph-safe guards + # ------------------------------------------------------------ + def _should_skip_ucm_hooks(*tensors: torch.Tensor) -> bool: + # Skip during torch.compile / Dynamo capture + if hasattr(torch, "_dynamo") and torch._dynamo.is_compiling(): + return True + # Skip FakeTensor / meta tensors (tracing/fake phase) + for t in tensors: + if isinstance(t, torch.Tensor) and ( + t.is_meta or "Fake" in type(t).__name__ + ): + return True + return False + def maybe_execute_sparse_attention_begin( query: torch.Tensor, key: torch.Tensor, @@ -97,79 +117,104 @@ def maybe_execute_sparse_attention_finished( attn_output: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ): if not has_ucm_sparse(): return + ucm_sparse = get_ucm_sparse() + attn_metadata = forward_context.attn_metadata if attn_metadata is None: return + ucm_sparse.attention_finished( - query, key, value, attn_output, layer_name, forward_context + query, key, value, attn_output, layer_name, forward_context, phase ) attention_v1.maybe_execute_sparse_attention_finished = ( maybe_execute_sparse_attention_finished ) - vllm_ops = torch.ops.vllm - orig_unified_ascend_attention_with_output = ( - vllm_ops.unified_ascend_attention_with_output - ) - - def _wrap_op_overload(orig, impl): - class _Wrapper: - def __init__(self, orig): - self._orig = orig - - def __call__(self, *args, **kwargs): - return impl(*args, **kwargs) - - def __getattr__(self, name): - return getattr(self._orig, name) - - return _Wrapper(orig) - - def unified_ascend_attention_with_output_impl( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - layer_name: str, - ) -> None: - - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - self = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] - if not self.use_mla: - query, key, value, _ = maybe_execute_sparse_attention_begin( - query, key, value, layer_name, forward_context - ) - self.impl.forward( - self, - query, - key, - value, - kv_cache, - attn_metadata, - output, - trace_flag=False, + # ------------------------------------------------------------ + # 2) Patch Python attention_v1 functions in-place (dispatcher remains) + # ------------------------------------------------------------ + target = getattr(attention_v1, "unified_ascend_attention_with_output", None) + 
if target is None: + raise AttributeError( + "vllm_ascend.attention.attention_v1 has no unified_ascend_attention_with_output" ) - if not self.use_mla: - maybe_execute_sparse_attention_finished( - query, key, value, output, layer_name, forward_context - ) - return - vllm_ops.unified_ascend_attention_with_output = _wrap_op_overload( - orig_unified_ascend_attention_with_output, - unified_ascend_attention_with_output_impl, + g = target.__globals__ + g.update( + { + "torch": torch, + "Optional": Optional, + "ForwardContext": ForwardContext, + "get_forward_context": get_forward_context, + "_should_skip_ucm_hooks": _should_skip_ucm_hooks, + "maybe_execute_sparse_attention_begin": maybe_execute_sparse_attention_begin, + "maybe_execute_sparse_attention_finished": maybe_execute_sparse_attention_finished, + } ) - attention_v1.unified_ascend_attention_with_output = ( - unified_ascend_attention_with_output_impl + # NOTE: + # - Keep calling torch.ops.vllm.unified_ascend_attention_with_output inside this function. + # - We are NOT replacing torch.ops itself, only the Python caller. + src = r""" +def __ucm_hooked_unified_ascend_attention_with_output( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + + layer = forward_context.no_compile_layers[layer_name] + kv_cache = layer.kv_cache[forward_context.virtual_engine] + + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== + # Graph-safe: skip hooks during compile/fake/meta and MLA path + do_sparse_hooks = (not _should_skip_ucm_hooks(query, key, value, output)) and (not layer.use_mla) + + if do_sparse_hooks: + # begin() may return None (in-place) or tuple; helper already normalizes to tuple + query, key, value, _ = maybe_execute_sparse_attention_begin( + query, key, value, layer_name, forward_context, output=None, phase=None + ) + # ====================== UCM-SPARSE-PATCH-END ======================== + layer.impl.forward( + layer, + query, + key, + value, + kv_cache, + attn_metadata, + output, + trace_flag=False, + ) + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== + if do_sparse_hooks: + maybe_execute_sparse_attention_finished( + query, key, value, output, layer_name, forward_context, phase=None ) + # ====================== UCM-SPARSE-PATCH-END ======================== + return +""" + exec(src, g, g) + repl = g["__ucm_hooked_unified_ascend_attention_with_output"] + + # In-place swap code (0 freevars -> 0 freevars) + target.__code__ = repl.__code__ + target.__defaults__ = repl.__defaults__ + target.__kwdefaults__ = repl.__kwdefaults__ + try: + target.__name__ = "unified_ascend_attention_with_output_ucm_hooked" + except Exception: + pass + except ImportError as e: logger.error(f"Failed to patch attention_v1.py: {e}", exc_info=True) raise @@ -208,7 +253,9 @@ def forward( enable_multistream_mla: bool = False, ckq: Optional[torch.Tensor] = None, ) -> torch.Tensor: + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== forward_context: ForwardContext = get_forward_context() + # ====================== UCM-SPARSE-PATCH-END ======================== assert output is not None, "Output tensor must be provided." if attn_metadata is None: # Profiling run. 
@@ -389,6 +436,7 @@ def forward( if has_prefill: # FIX: aicore move should be also placed on the comm stream in dbo, # otherwise it may affect the accuracy + # ====================== UCM-SPARSE-PATCH-BEGIN-hook1 ====================== # TODO: use an elegant way to overlap prefill_q, prefill_k_c_normed, prefill_k_pe, _ = ( maybe_execute_sparse_attention_begin( @@ -400,6 +448,7 @@ def forward( phase="prefill", ) ) + # ====================== UCM-SPARSE-PATCH-END-hook1 ======================== output_prefill = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata ) @@ -410,6 +459,7 @@ def forward( current_ms_metadata.after_comm_event.record() else: output[num_decode_tokens:] = output_prefill + # ====================== UCM-SPARSE-PATCH-BEGIN-hook2 ====================== maybe_execute_sparse_attention_finished( prefill_q, prefill_k_c_normed, @@ -419,7 +469,9 @@ def forward( forward_context, "prefill", ) + # ====================== UCM-SPARSE-PATCH-END-hook2 ======================== if has_decode: + # ====================== UCM-SPARSE-PATCH-BEGIN-hook3 ====================== _, decode_ql_nope, decode_q_pe, _ = ( maybe_execute_sparse_attention_begin( torch.cat([decode_ql_nope, decode_q_pe], dim=-1), @@ -430,6 +482,7 @@ def forward( phase="decode", ) ) + # ====================== UCM-SPARSE-PATCH-END-hook3 ======================== if self.running_in_graph: return self._forward_decode( decode_ql_nope, @@ -456,6 +509,7 @@ def forward( current_ms_metadata.after_comm_event.record() else: output[:num_decode_tokens] = output_decode + # ====================== UCM-SPARSE-PATCH-BEGIN-hook4 ====================== maybe_execute_sparse_attention_finished( torch.cat([decode_ql_nope, decode_q_pe], dim=-1), decode_ql_nope, @@ -465,6 +519,7 @@ def forward( forward_context, "decode", ) + # ====================== UCM-SPARSE-PATCH-END-hook4 ======================== return output_padded @@ -532,7 +587,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """ # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: + # ====================== UCM-SPARSE-PATCH-BEGIN-hook1 ====================== self.ucm_sparse_request_finished_in_worker(req_id) + # ====================== UCM-SPARSE-PATCH-END-hook1 ======================== self.requests.pop(req_id, None) self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. @@ -640,14 +697,18 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the states of the running/resumed requests. req_data = scheduler_output.scheduled_cached_reqs - req_sparsed_slots = scheduler_output.req_sparsed_slots + req_sparsed_slots = ( + scheduler_output.req_sparsed_slots + ) ### UCM-SPARSE-PATCH ### is_last_rank = get_pp_group().is_last_rank for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_data.resumed_from_preemption[i] - is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT + is_sparsed_request = ( + req_sparsed_slots[req_id] != INVALID_SLOT + ) ### UCM-SPARSE-PATCH ### req_state.num_computed_tokens = num_computed_tokens if not is_last_rank: @@ -665,6 +726,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: new_token_ids[-num_new_tokens:] ) # Update the block IDs. 
+ # ====================== UCM-SPARSE-PATCH-BEGIN-hook2 ====================== if resumed_from_preemption or is_sparsed_request: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. @@ -675,6 +737,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_state.block_ids, new_block_ids ): block_ids.extend(new_ids) + # ====================== UCM-SPARSE-PATCH-END-hook2 ======================== req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: @@ -689,8 +752,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_computed_tokens ) + # ====================== UCM-SPARSE-PATCH-BEGIN-hook3 ====================== if is_sparsed_request: self.input_batch.block_table.reset_row(req_index) + # ====================== UCM-SPARSE-PATCH-END-hook3 ======================== self.input_batch.block_table.append_row(new_block_ids, req_index) @@ -844,6 +909,7 @@ def _process_reqs( ) seq_lens = self.seq_lens_cpu[:num_reqs] + # ====================== UCM-SPARSE-PATCH-BEGIN-hook1 ====================== # TODO: improve performance, no `positions_np.copy()` sparsed_positions = positions_np.copy() req_sparsed_slots = scheduler_output.req_sparsed_slots @@ -864,6 +930,8 @@ def _process_reqs( block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() block_offsets = sparsed_positions % self.block_size + # ====================== UCM-SPARSE-PATCH-END-hook1 ======================== + np.add( block_numbers * self.block_size, block_offsets, @@ -892,16 +960,20 @@ def _process_reqs( else: attn_state = AscendAttentionState.PrefillCacheHit + # ====================== UCM-SPARSE-PATCH-BEGIN-hook2 ====================== for req_id in self.input_batch.req_id_to_index: is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT req_index = self.input_batch.req_id_to_index[req_id] if is_sparsed_request: seq_lens[req_index] = req_sparsed_slots[req_id] + # ====================== UCM-SPARSE-PATCH-END-hook2 ======================== self.attn_mask = self._make_attention_mask( seq_lens=seq_lens, query_lens=num_scheduled_tokens, - position=torch.tensor(sparsed_positions).npu(), + position=torch.tensor( + sparsed_positions + ).npu(), ### UCM-SPARSE-PATCH ### attn_state=attn_state, ) self.attn_state = attn_state # type: ignore @@ -1053,10 +1125,13 @@ def _process_reqs( maybe_converting_weight_acl_format( self.model, ACL_FORMAT_FRACTAL_ND ) + + # ====================== UCM-SPARSE-PATCH-BEGIN-hook3 ====================== self.maybe_setup_kv_connector(scheduler_output) self.maybe_execute_ucm_sparse_begin( scheduler_output, attn_metadata ) + # ====================== UCM-SPARSE-PATCH-END-hook3 ======================== hidden_states = self.model( input_ids=input_ids, @@ -1065,8 +1140,13 @@ def _process_reqs( inputs_embeds=inputs_embeds, **model_kwargs, ) + + # ====================== UCM-SPARSE-PATCH-BEGIN-hook4 ====================== self.maybe_wait_for_kv_save() - self.maybe_execute_ucm_sparse_finished() + logits_indices = self.maybe_execute_ucm_sparse_finished( + logits_indices + ) + # ====================== UCM-SPARSE-PATCH-END-hook4 ======================== use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -1359,11 +1439,11 @@ def maybe_execute_ucm_sparse_begin( ) ucm_sparse.execute_begin(scheduler_output) - def maybe_execute_ucm_sparse_finished(self): + def maybe_execute_ucm_sparse_finished(self, 
logits_indices): if not has_ucm_sparse(): - return + return logits_indices ucm_sparse = get_ucm_sparse() - ucm_sparse.execute_finished() + return ucm_sparse.execute_finished(logits_indices) def ucm_sparse_request_finished_in_worker(self, request_id: str | int): if not has_ucm_sparse(): @@ -1390,6 +1470,7 @@ def _patch_worker_v1() -> None: import copy from typing import Optional + from vllm.distributed.kv_transfer import has_kv_transfer_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import logger from vllm.sequence import IntermediateTensors @@ -1422,6 +1503,9 @@ def execute_model( output.tensors, all_gather_group=get_tp_group() ) + # ====================== UCM-SPARSE-PATCH-BEGIN-hook1 ====================== + if not has_kv_transfer_group(): + return None kv_connector_output = output.kv_connector_output finished_sending = kv_connector_output.finished_sending finished_recving = kv_connector_output.finished_recving @@ -1432,6 +1516,7 @@ def execute_model( new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) new_output.kv_connector_output = kv_connector_output return new_output + # ====================== UCM-SPARSE-PATCH-END-hook1 ======================== assert isinstance(output, ModelRunnerOutput) return output @@ -1444,7 +1529,9 @@ def execute_model( def patched_init_worker_distributed_environment(self) -> None: original_init_worker_distributed_environment(self) + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== ensure_ucm_sparse_initialized(self.vllm_config) + # ====================== UCM-SPARSE-PATCH-END ======================== NPUWorker._init_worker_distributed_environment = ( patched_init_worker_distributed_environment diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py index 6ce7589e..27254c25 100644 --- a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py +++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py @@ -40,11 +40,11 @@ def _apply_sparse_adapt() -> None: """Apply sparse adapt patches.""" try: if _enable_sparse(): + _patch_attention_layer() + _patch_mla_common() _patch_block_table() _patch_kv_cache_manager() _patch_shared_storage_connector() - _patch_attention_layer() - _patch_mla_common() _patch_gpu_model_runner() _patch_gpu_worker() _patch_scheduler_output() @@ -120,8 +120,10 @@ class SchedulerOutput: # KV Cache Connector metadata. 
kv_connector_metadata: Optional[KVConnectorMetadata] = None + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== # modified slots by sparse algorithm req_sparsed_slots: dict[str, int] = None + # ====================== UCM-SPARSE-PATCH-END ======================== # Set module and qualname to make the class pickleable # This ensures pickle can find the class when serializing @@ -136,7 +138,12 @@ class SchedulerOutput: # ==================== vllm/attention/layer.py ==================== def _patch_attention_layer() -> None: - """Patch attention layer & unified_attention_with_output C++ op.""" + """ + Graph-safe patch for vLLM attention: + - DO NOT wrap torch.ops (breaks torch.compile / Dynamo graph capture) + - In-place replace Python functions' __code__ via exec() to keep 0 freevars + - Skip UCM hooks during compile/fake/meta stage + """ try: from typing import Optional @@ -149,82 +156,20 @@ def _patch_attention_layer() -> None: from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse - def attn_forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - # For some alternate attention backends like MLA the attention output - # shape does not match the query shape, so we optionally let the model - # definition specify the output tensor shape. - output_shape: Optional[torch.Size] = None, - ) -> torch.Tensor: - """ - The KV cache is stored inside this class and is accessed via - `self.kv_cache`. - - Attention metadata (`attn_metadata`) is set using a context manager in - the model runner's `execute_model` method. It is accessed via forward - context using - `vllm.forward_context.get_forward_context().attn_metadata`. - """ - if self.calculate_kv_scales: - attn_metadata = get_forward_context().attn_metadata - if attn_metadata.enable_kv_scales_calculation: - self.calc_kv_scales(query, key, value) - if self.use_output: - output_shape = output_shape if output_shape is not None else query.shape - output = torch.zeros( - output_shape, dtype=query.dtype, device=query.device - ) - hidden_size = output_shape[-1] - # We skip reshaping query, key and value tensors for the MLA - # backend since these tensors have different semantics and are - # processed differently. - if not self.use_mla: - # Reshape the query, key, and value tensors. - # NOTE(woosuk): We do this outside the custom op to minimize the - # CPU overheads from the non-CUDA-graph regions. 
- query = query.view(-1, self.num_heads, self.head_size) - output = output.view(-1, self.num_heads, self.head_size) - if key is not None: - key = key.view(-1, self.num_kv_heads, self.head_size) - if value is not None: - value = value.view(-1, self.num_kv_heads, self.head_size) - if self.use_direct_call: - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.layer_name] - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - self.impl.forward( - self, - query, - key, - value, - self_kv_cache, - attn_metadata, - output=output, - ) - else: - torch.ops.vllm.unified_attention_with_output( - query, key, value, output, self.layer_name - ) - return output.view(-1, hidden_size) - else: - if self.use_direct_call: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.layer_name] - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - return self.impl.forward( - self, query, key, value, self_kv_cache, attn_metadata - ) - else: - return torch.ops.vllm.unified_attention( - query, key, value, self.layer_name - ) + # ------------------------------------------------------------ + # 1) Graph-safe guards + # ------------------------------------------------------------ + def _should_skip_ucm_hooks(*tensors: torch.Tensor) -> bool: + # Skip during torch.compile / Dynamo capture + if hasattr(torch, "_dynamo") and torch._dynamo.is_compiling(): + return True + # Skip FakeTensor / meta tensors (tracing/fake phase) + for t in tensors: + if isinstance(t, torch.Tensor) and ( + t.is_meta or "Fake" in type(t).__name__ + ): + return True + return False def maybe_execute_sparse_attention_begin( query: torch.Tensor, @@ -270,100 +215,161 @@ def maybe_execute_sparse_attention_finished( query, key, value, attn_output, layer_name, forward_context, phase ) - vllm_ops = torch.ops.vllm - orig_unified_attention_with_output = vllm_ops.unified_attention_with_output - orig_unified_attention = vllm_ops.unified_attention - - def _wrap_op_overload(orig, impl): - class _Wrapper: - def __init__(self, orig): - self._orig = orig + # ---------------------------------------------------------------- + # 2) Patch Python layer functions in-place (dispatcher remains) + # ---------------------------------------------------------------- + from vllm.attention import layer as attn_layer - def __call__(self, *args, **kwargs): - return impl(*args, **kwargs) - - def __getattr__(self, name): - return getattr(self._orig, name) - - return _Wrapper(orig) - - def unified_attention_impl( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - layer_name: str, - ) -> torch.Tensor: - wait_for_kv_layer_from_connector(layer_name) + target_unified_attention = getattr(attn_layer, "unified_attention", None) + target_unified_attention_with_output = getattr( + attn_layer, "unified_attention_with_output", None + ) - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[layer_name] - self = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] - query, key, value, _ = maybe_execute_sparse_attention_begin( - query, key, value, layer_name, forward_context - ) - output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) - 
maybe_execute_sparse_attention_finished( - query, key, value, output, layer_name, forward_context + if ( + target_unified_attention is None + or target_unified_attention_with_output is None + ): + raise AttributeError( + "vllm.attention.layer missing unified_attention or unified_attention_with_output" ) - maybe_save_kv_layer_to_connector(layer_name, kv_cache) - return output - def unified_attention_with_output_impl( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - layer_name: str, - output_scale: Optional[torch.Tensor] = None, - ) -> None: - wait_for_kv_layer_from_connector(layer_name) - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[layer_name] - self = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] - if not self.use_mla: - query, key, value, output = maybe_execute_sparse_attention_begin( - query, key, value, layer_name, forward_context, output - ) - self.impl.forward( - self, - query, - key, - value, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - ) - if not self.use_mla: - maybe_execute_sparse_attention_finished( - query, key, value, output, layer_name, forward_context - ) + # Inject referenced symbols into the targets' globals so exec() has no freevars. + g = target_unified_attention_with_output.__globals__ + g.update( + { + "torch": torch, + "Optional": Optional, + "ForwardContext": ForwardContext, + "get_forward_context": get_forward_context, + "wait_for_kv_layer_from_connector": wait_for_kv_layer_from_connector, + "maybe_save_kv_layer_to_connector": maybe_save_kv_layer_to_connector, + "_should_skip_ucm_hooks": _should_skip_ucm_hooks, + "maybe_execute_sparse_attention_begin": maybe_execute_sparse_attention_begin, + "maybe_execute_sparse_attention_finished": maybe_execute_sparse_attention_finished, + } + ) - maybe_save_kv_layer_to_connector(layer_name, kv_cache) + # Use exec to define top-level (0 freevars) replacement functions. 
+ src = r""" +def __ucm_hooked_unified_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + + layer = forward_context.no_compile_layers[layer_name] + kv_cache = layer.kv_cache[forward_context.virtual_engine] + + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== + # Graph-safe skip: compile/fake/meta + do_sparse_hooks = (not _should_skip_ucm_hooks(query, key, value)) and (not layer.use_mla) + + if do_sparse_hooks: + query, key, value, _ = maybe_execute_sparse_attention_begin( + query, key, value, layer_name, forward_context + ) + # ====================== UCM-SPARSE-PATCH-END ======================== + out = layer.impl.forward(layer, query, key, value, kv_cache, attn_metadata) - vllm_ops.unified_attention_with_output = _wrap_op_overload( - orig_unified_attention_with_output, unified_attention_with_output_impl + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== + if do_sparse_hooks: + maybe_execute_sparse_attention_finished( + query, key, value, out, layer_name, forward_context ) - vllm_ops.unified_attention = _wrap_op_overload( - orig_unified_attention, unified_attention_impl + # ====================== UCM-SPARSE-PATCH-END ======================== + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return out + + +def __ucm_hooked_unified_attention_with_output( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, + output_scale: Optional[torch.Tensor] = None, +) -> None: + wait_for_kv_layer_from_connector(layer_name) + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + + layer = forward_context.no_compile_layers[layer_name] + kv_cache = layer.kv_cache[forward_context.virtual_engine] + + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== + # Graph-safe skip: compile/fake/meta + do_sparse_hooks = (not _should_skip_ucm_hooks(query, key, value, output)) and (not layer.use_mla) + + if do_sparse_hooks: + query, key, value, output = maybe_execute_sparse_attention_begin( + query, key, value, layer_name, forward_context, output ) - from vllm.attention import layer + # ====================== UCM-SPARSE-PATCH-END ======================== + layer.impl.forward( + layer, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + ) + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== + if do_sparse_hooks: + maybe_execute_sparse_attention_finished( + query, key, value, output, layer_name, forward_context + ) + # ====================== UCM-SPARSE-PATCH-END ======================== + maybe_save_kv_layer_to_connector(layer_name, kv_cache) +""" + exec(src, g, g) + repl_unified_attention = g["__ucm_hooked_unified_attention"] + repl_unified_attention_with_output = g[ + "__ucm_hooked_unified_attention_with_output" + ] + + # In-place code swap (freevars 0 -> 0) + target_unified_attention.__code__ = repl_unified_attention.__code__ + target_unified_attention.__defaults__ = repl_unified_attention.__defaults__ + target_unified_attention.__kwdefaults__ = repl_unified_attention.__kwdefaults__ + try: + 
target_unified_attention.__name__ = "unified_attention_ucm_hooked" + except Exception: + pass + + target_unified_attention_with_output.__code__ = ( + repl_unified_attention_with_output.__code__ + ) + target_unified_attention_with_output.__defaults__ = ( + repl_unified_attention_with_output.__defaults__ + ) + target_unified_attention_with_output.__kwdefaults__ = ( + repl_unified_attention_with_output.__kwdefaults__ + ) + try: + target_unified_attention_with_output.__name__ = ( + "unified_attention_with_output_ucm_hooked" + ) + except Exception: + pass - layer.maybe_execute_sparse_attention_begin = ( + attn_layer.maybe_execute_sparse_attention_begin = ( maybe_execute_sparse_attention_begin ) - layer.maybe_execute_sparse_attention_finished = ( + attn_layer.maybe_execute_sparse_attention_finished = ( maybe_execute_sparse_attention_finished ) - layer.Attention.forward = attn_forward - layer.unified_attention = unified_attention_impl - layer.unified_attention_with_output = unified_attention_with_output_impl except ImportError: logger.warning( @@ -389,8 +395,10 @@ def _patch_shared_storage_connector() -> None: @dataclass class SharedStorageConnectorMetadata(KVConnectorMetadata): + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== requests: list[ReqMeta] = field(default_factory=list) + # ====================== UCM-SPARSE-PATCH-END ======================== def add_request( self, token_ids: list[int], @@ -441,7 +449,9 @@ def forward( output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== forward_context: ForwardContext = get_forward_context() + # ====================== UCM-SPARSE-PATCH-END ======================== assert output is not None, "Output tensor must be provided." 
if output_scale is not None: @@ -493,6 +503,7 @@ def forward( ) if has_prefill: + # ====================== UCM-SPARSE-PATCH-BEGIN-hook1 ====================== prefill_q, prefill_k_c_normed, prefill_k_pe, _ = ( maybe_execute_sparse_attention_begin( prefill_q, @@ -503,9 +514,11 @@ def forward( phase="prefill", ) ) + # ====================== UCM-SPARSE-PATCH-END-hook1 ======================== output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata ) + # ====================== UCM-SPARSE-PATCH-BEGIN-hook2 ====================== maybe_execute_sparse_attention_finished( prefill_q, prefill_k_c_normed, @@ -515,6 +528,7 @@ def forward( forward_context, "prefill", ) + # ====================== UCM-SPARSE-PATCH-END-hook2 ======================== if has_decode: assert attn_metadata.decode is not None decode_q_nope, decode_q_pe = decode_q.split( @@ -526,6 +540,7 @@ def forward( decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) # Convert from (N, B, L) to (B, N, L) decode_ql_nope = decode_ql_nope.transpose(0, 1) + # ====================== UCM-SPARSE-PATCH-BEGIN-hook3 ====================== _, decode_ql_nope, decode_q_pe, _ = ( maybe_execute_sparse_attention_begin( torch.cat([decode_ql_nope, decode_q_pe], dim=-1), @@ -536,9 +551,11 @@ def forward( phase="decode", ) ) + # ====================== UCM-SPARSE-PATCH-END-hook3 ======================== output[:num_decode_tokens] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata ) + # ====================== UCM-SPARSE-PATCH-BEGIN-hook4 ====================== maybe_execute_sparse_attention_finished( torch.cat([decode_ql_nope, decode_q_pe], dim=-1), decode_ql_nope, @@ -548,6 +565,7 @@ def forward( forward_context, "decode", ) + # ====================== UCM-SPARSE-PATCH-END-hook4 ======================== return output_padded MLACommonImpl.forward = forward @@ -582,10 +600,12 @@ def patched_allocate_slots( ) -> Optional[KVCacheBlocks]: if num_new_tokens == 0: raise ValueError("num_new_tokens must be greater than 0") + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== # Only route to UCM sparse path when caller explicitly provided # a valid sparsified slot count. if (num_slots_sparsed is not None) and (num_slots_sparsed != INVALID_SLOT): return get_ucm_sparse().allocate_slots(self, request, num_slots_sparsed) + # ====================== UCM-SPARSE-PATCH-END ======================== return original_allocate_slots( self, request, @@ -695,17 +715,21 @@ def patched_schedule(self) -> SchedulerOutput: # First, schedule the RUNNING requests. req_index = 0 + # ====================== UCM-SPARSE-PATCH-BEGIN-hook1 ====================== req_sparsed_slots: dict[str, int] = {} if not hasattr(self, "ucm_sparse"): init_ucm_sparse(self) + # ====================== UCM-SPARSE-PATCH-END-hook1 ======================== while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] num_slots_sparsed = INVALID_SLOT + # ====================== UCM-SPARSE-PATCH-BEGIN-hook2 ====================== if self.ucm_sparse: num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed( request ) req_sparsed_slots.update({request.request_id: num_slots_sparsed}) + # ====================== UCM-SPARSE-PATCH-END-hook2 ======================== num_new_tokens = ( request.num_tokens_with_spec - request.num_computed_tokens @@ -1285,7 +1309,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """ # Remove finished requests from the cached states. 
for req_id in scheduler_output.finished_req_ids: + # ====================== UCM-SPARSE-PATCH-BEGIN-hook1 ====================== self.ucm_sparse_request_finished_in_worker(req_id) + # ====================== UCM-SPARSE-PATCH-END-hook1 ======================== self.requests.pop(req_id, None) self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. @@ -1390,7 +1416,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs + # ====================== UCM-SPARSE-PATCH-BEGIN-hook2 ====================== req_sparsed_slots = scheduler_output.req_sparsed_slots + # ====================== UCM-SPARSE-PATCH-END-hook2 ======================== for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] num_computed_tokens = req_data.num_computed_tokens[i] @@ -1420,10 +1448,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: ) # Update the block IDs. + # ====================== UCM-SPARSE-PATCH-BEGIN-hook3 ====================== if resumed_from_preemption or is_sparsed_request: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids + # ====================== UCM-SPARSE-PATCH-END-hook3 ======================== else: # Append the new blocks to the existing block IDs. for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): @@ -1441,8 +1471,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens ) + # ====================== UCM-SPARSE-PATCH-BEGIN-hook3 ====================== if is_sparsed_request: self.input_batch.block_table.reset_row(req_index) + # ====================== UCM-SPARSE-PATCH-END-hook3 ======================== self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu @@ -1532,6 +1564,7 @@ def _prepare_inputs( if self.uses_mrope: self._calc_mrope_positions(scheduler_output) + # ====================== UCM-SPARSE-PATCH-BEGIN-hook1 ====================== self.seq_lens_np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens @@ -1548,7 +1581,7 @@ def _prepare_inputs( ) # TODO: support MTP if is_sparsed_request: sparsed_positions[offset] = req_sparsed_slots[req_id] - 1 - + # ====================== UCM-SPARSE-PATCH-END-hook1 ======================== # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] @@ -1583,11 +1616,13 @@ def _prepare_inputs( # block_size. 
block_table_indices = ( req_indices * block_table.max_num_blocks_per_req - + sparsed_positions // block_size + + sparsed_positions // block_size ### UCM-SPARSE-PATCH ### ) block_table_cpu = block_table.get_cpu_tensor() block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + # ====================== UCM-SPARSE-PATCH-BEGIN-hook2 ====================== block_offsets = sparsed_positions % block_size + # ====================== UCM-SPARSE-PATCH-END-hook2 ====================== np.add( block_numbers * block_size, block_offsets, @@ -1598,6 +1633,7 @@ def _prepare_inputs( self.query_start_loc_np[0] = 0 self.query_start_loc_np[1 : num_reqs + 1] = cu_num_tokens + # ====================== UCM-SPARSE-PATCH-BEGIN-hook3 ====================== for req_id in self.input_batch.req_id_to_index: req_index = self.input_batch.req_id_to_index[req_id] is_sparsed_request = ( @@ -1607,6 +1643,7 @@ def _prepare_inputs( self.seq_lens_np[req_index] = scheduler_output.req_sparsed_slots[ req_id ] + # ====================== UCM-SPARSE-PATCH-END-hook3 ====================== # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( @@ -1620,9 +1657,11 @@ def _prepare_inputs( ) else: # Common case (1D positions) + # ====================== UCM-SPARSE-PATCH-BEGIN-hook4 ====================== self.positions_cpu[:total_num_scheduled_tokens] = torch.from_numpy( positions_np[:total_num_scheduled_tokens] ) + # ====================== UCM-SPARSE-PATCH-END-hook4 ====================== self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True ) @@ -1831,7 +1870,9 @@ def execute_model( skip_cuda_graphs=skip_cuda_graphs, ): self.maybe_setup_kv_connector(scheduler_output) + # ====================== UCM-SPARSE-PATCH-BEGIN-hook1 ====================== self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata) + # ====================== UCM-SPARSE-PATCH-END-hook1 ======================== model_output = self.model( input_ids=input_ids, @@ -1841,8 +1882,9 @@ def execute_model( ) self.maybe_wait_for_kv_save() + # ====================== UCM-SPARSE-PATCH-BEGIN-hook2 ====================== logits_indices = self.maybe_execute_ucm_sparse_finished(logits_indices) - + # ====================== UCM-SPARSE-PATCH-END-hook2 ======================== finished_sending, finished_recving = self.get_finished_kv_transfers( scheduler_output ) @@ -2075,7 +2117,9 @@ def patched_init_worker_distributed_environment( original_init_worker_distributed_environment( vllm_config, rank, distributed_init_method, local_rank, backend ) + # ====================== UCM-SPARSE-PATCH-BEGIN ====================== ensure_ucm_sparse_initialized(vllm_config) + # ====================== UCM-SPARSE-PATCH-END ======================== gpu_worker.init_worker_distributed_environment = ( patched_init_worker_distributed_environment
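The graph-safe guard added in both files can be exercised on its own. The sketch below restates it outside vLLM so the skip conditions are easy to sanity-check; it relies only on public tensor attributes plus torch._dynamo and FakeTensorMode (private but long-standing PyTorch helpers), and the __main__ section is purely illustrative.

import torch

def _should_skip_ucm_hooks(*tensors: torch.Tensor) -> bool:
    # Skip while torch.compile / Dynamo is capturing a graph.
    if hasattr(torch, "_dynamo") and torch._dynamo.is_compiling():
        return True
    # Skip FakeTensor / meta tensors seen during tracing and shape propagation.
    for t in tensors:
        if isinstance(t, torch.Tensor) and (t.is_meta or "Fake" in type(t).__name__):
            return True
    return False

if __name__ == "__main__":
    real = torch.randn(2, 4)
    meta = torch.empty(2, 4, device="meta")
    print(_should_skip_ucm_hooks(real))        # False: eager mode, real tensor
    print(_should_skip_ucm_hooks(real, meta))  # True: meta tensor present
    from torch._subclasses.fake_tensor import FakeTensorMode
    with FakeTensorMode() as mode:
        fake = mode.from_tensor(real)
    print(_should_skip_ucm_hooks(fake))        # True: FakeTensor detected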
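The mechanism both patches now rely on is an in-place function swap: a replacement is exec()'d into the target's __globals__ so it has zero free variables, then only __code__, __defaults__ and __kwdefaults__ are copied over. Every existing reference to the original function object (including the registered custom-op caller) therefore picks up the new behaviour without re-binding anything. A minimal, self-contained sketch of that pattern, with purely illustrative names (greet, _extra_suffix):

def greet(name: str) -> str:
    return f"hello {name}"

keep_reference = greet  # pre-existing references must keep working after the swap

g = greet.__globals__
g["_extra_suffix"] = "!"  # inject every symbol the replacement needs -> 0 freevars

src = r"""
def __patched_greet(name: str) -> str:
    # Same signature as the original, extra behaviour bolted on.
    return f"hello {name}{_extra_suffix}"
"""
exec(src, g, g)
repl = g["__patched_greet"]

# Both are module-level functions, so their code objects have no free variables
# and the code-object swap is legal (closure layouts match: 0 -> 0).
assert greet.__code__.co_freevars == repl.__code__.co_freevars == ()
greet.__code__ = repl.__code__
greet.__defaults__ = repl.__defaults__
greet.__kwdefaults__ = repl.__kwdefaults__

print(keep_reference("ucm"))  # "hello ucm!" -- the old reference sees the new code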
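maybe_execute_sparse_attention_begin is described in the comments as normalising the hook result, since a sparse attention_begin() may either mutate q/k/v in place and return None or return a (q, k, v, output) tuple. Its full body is outside this diff, so the helper below only illustrates that contract under those assumptions; it is not the patch's actual code.

from typing import Optional, Tuple
import torch

BeginResult = Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]

def _normalize_begin_result(
    result: BeginResult,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # None means the hook worked in place: hand the original tensors back,
    # so call sites can always unpack `q, k, v, out = ...`.
    return (query, key, value, output) if result is None else result

q, k, v = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
q2, k2, v2, out = _normalize_begin_result(None, q, k, v)              # in-place style hook
assert q2 is q and out is None
q3, k3, v3, out3 = _normalize_begin_result((q, k, v, None), q, k, v)  # tuple style hook
assert q3 is q and out3 is None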
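On the worker side, maybe_execute_ucm_sparse_finished now takes and returns logits_indices, so a sparse algorithm can remap which token positions receive logits while the no-sparse path stays a pure pass-through. A small sketch of that contract, with SparseStub standing in for the real object returned by get_ucm_sparse() (hypothetical name, not part of the patch):

from typing import Optional
import torch

class SparseStub:
    # Stand-in for the UCM sparse object; only the method signature mirrors the
    # patch, the body is illustrative.
    def execute_finished(self, logits_indices: torch.Tensor) -> torch.Tensor:
        return logits_indices  # a real algorithm could drop or remap positions

def maybe_execute_ucm_sparse_finished(
    ucm_sparse: Optional[SparseStub], logits_indices: torch.Tensor
) -> torch.Tensor:
    # Mirrors the patched model-runner method: no sparse module -> pass through.
    if ucm_sparse is None:
        return logits_indices
    return ucm_sparse.execute_finished(logits_indices)

idx = torch.arange(8)
assert torch.equal(maybe_execute_ucm_sparse_finished(None, idx), idx)
assert torch.equal(maybe_execute_ucm_sparse_finished(SparseStub(), idx), idx)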