From 1951991a11f797592aa7fdcaf6691d8fdc8a974d Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Thu, 19 Feb 2026 02:54:25 +0000
Subject: [PATCH] fix: always reach yield in patch_transformers5_params_loading
 when patching is skipped

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 modelopt/torch/speculative/utils.py | 52 ++++++++++++++---------------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py
index e34538665..778587b48 100644
--- a/modelopt/torch/speculative/utils.py
+++ b/modelopt/torch/speculative/utils.py
@@ -499,32 +499,32 @@ def patch_transformers5_params_loading():
     https://github.com/huggingface/transformers/blob/v5.0.0.rc1-release/src/transformers/core_model_loading.py#L640
     """
     # Skip patching for non-applicable transformers version
-    if importlib.util.find_spec("transformers.core_model_loading") is None:
-        return
-    from transformers import core_model_loading
-
-    if not hasattr(core_model_loading, "set_param_for_module"):
-        return
-
-    orig_set_param_for_module = core_model_loading.set_param_for_module
-
-    def patched_set_param_for_module(*args, **kwargs):
-        """Monkey-patch set_param_for_module to restore original requires_grad."""
-        model, target_name = args[:2]
-        module_path, _, param_name = target_name.rpartition(".")
-        module_obj = model.get_submodule(module_path) if module_path else model
-
-        # Get original requires_grad value
-        orig_requires_grad = getattr(module_obj, param_name).requires_grad
-
-        # Call set_param_for_module
-        orig_set_param_for_module(*args, **kwargs)
-
-        # Restore original requires_grad value
-        getattr(module_obj, param_name).requires_grad = orig_requires_grad
-
+    should_patch = False
+    orig_set_param_for_module = None
+    if importlib.util.find_spec("transformers.core_model_loading") is not None:
+        from transformers import core_model_loading
+
+        if hasattr(core_model_loading, "set_param_for_module"):
+            should_patch = True
+            orig_set_param_for_module = core_model_loading.set_param_for_module
+
+            def patched_set_param_for_module(*args, **kwargs):
+                """Monkey-patch set_param_for_module to restore original requires_grad."""
+                model, target_name = args[:2]
+                module_path, _, param_name = target_name.rpartition(".")
+                module_obj = model.get_submodule(module_path) if module_path else model
+                # Get original requires_grad value
+                orig_requires_grad = getattr(module_obj, param_name).requires_grad
+                # Call set_param_for_module
+                orig_set_param_for_module(*args, **kwargs)
+                # Restore original requires_grad value
+                getattr(module_obj, param_name).requires_grad = orig_requires_grad
+
+            core_model_loading.set_param_for_module = patched_set_param_for_module
     try:
-        core_model_loading.set_param_for_module = patched_set_param_for_module
         yield
     finally:
-        core_model_loading.set_param_for_module = orig_set_param_for_module
+        if should_patch:
+            from transformers import core_model_loading
+
+            core_model_loading.set_param_for_module = orig_set_param_for_module
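
Why the monkey-patch exists, as a minimal self-contained sketch of the failure
mode it guards against. The set_param_for_module below is a toy stand-in
written for illustration, not the transformers implementation: it only mimics
the relevant behavior, namely that the v5 loader replaces each parameter with
a freshly constructed torch.nn.Parameter, whose requires_grad defaults to
True, silently un-freezing parameters that were frozen before loading.

    import torch

    def set_param_for_module(model, target_name, value):
        # Toy stand-in: re-create the parameter, as the transformers v5
        # loader does; a new Parameter defaults to requires_grad=True.
        module_path, _, param_name = target_name.rpartition(".")
        module = model.get_submodule(module_path) if module_path else model
        setattr(module, param_name, torch.nn.Parameter(value))

    linear = torch.nn.Linear(2, 2)
    linear.weight.requires_grad = False  # e.g. a frozen base-model weight

    set_param_for_module(linear, "weight", torch.zeros(2, 2))
    print(linear.weight.requires_grad)  # True: the frozen flag was lost

Inside `with patch_transformers5_params_loading():` the wrapper reads the flag
before delegating to the original set_param_for_module and writes it back
afterwards, so the weight above would stay frozen. The restructuring in this
patch additionally guarantees that the generator body always reaches its
yield, which contextlib.contextmanager requires (a generator that returns
before yielding raises RuntimeError on entry), so the context manager is now
safe to enter on transformers versions that lack core_model_loading or
set_param_for_module.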