
fix #908

52 changes: 26 additions & 26 deletions modelopt/torch/speculative/utils.py
@@ -499,32 +499,32 @@ def patch_transformers5_params_loading():
         https://github.com/huggingface/transformers/blob/v5.0.0.rc1-release/src/transformers/core_model_loading.py#L640
     """
-    # Skip patching for non-applicable transformers version
-    if importlib.util.find_spec("transformers.core_model_loading") is None:
-        return
-    from transformers import core_model_loading
-
-    if not hasattr(core_model_loading, "set_param_for_module"):
-        return
-
-    orig_set_param_for_module = core_model_loading.set_param_for_module
-
-    def patched_set_param_for_module(*args, **kwargs):
-        """Monkey-patch set_param_for_module to restore original requires_grad."""
-        model, target_name = args[:2]
-        module_path, _, param_name = target_name.rpartition(".")
-        module_obj = model.get_submodule(module_path) if module_path else model
-
-        # Get original requires_grad value
-        orig_requires_grad = getattr(module_obj, param_name).requires_grad
-
-        # Call set_param_for_module
-        orig_set_param_for_module(*args, **kwargs)
-
-        # Restore original requires_grad value
-        getattr(module_obj, param_name).requires_grad = orig_requires_grad
+    should_patch = False
+    orig_set_param_for_module = None
+    if importlib.util.find_spec("transformers.core_model_loading") is not None:
+        from transformers import core_model_loading
+
+        if hasattr(core_model_loading, "set_param_for_module"):
+            should_patch = True
+            orig_set_param_for_module = core_model_loading.set_param_for_module
+
+            def patched_set_param_for_module(*args, **kwargs):
+                """Monkey-patch set_param_for_module to restore original requires_grad."""
+                model, target_name = args[:2]
+                module_path, _, param_name = target_name.rpartition(".")
+                module_obj = model.get_submodule(module_path) if module_path else model
+                # Get original requires_grad value
+                orig_requires_grad = getattr(module_obj, param_name).requires_grad
+                # Call set_param_for_module
+                orig_set_param_for_module(*args, **kwargs)
+                # Restore original requires_grad value
+                getattr(module_obj, param_name).requires_grad = orig_requires_grad
 
+            core_model_loading.set_param_for_module = patched_set_param_for_module
     try:
-        core_model_loading.set_param_for_module = patched_set_param_for_module
         yield
     finally:
-        core_model_loading.set_param_for_module = orig_set_param_for_module
+        if should_patch:
+            from transformers import core_model_loading
+
+            core_model_loading.set_param_for_module = orig_set_param_for_module
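
The new control flow only installs the monkey-patch when transformers actually exposes set_param_for_module, and restores the original function in the finally block regardless of whether an early exit was possible, so the context manager always reaches its yield. A minimal, self-contained sketch of this patch-and-restore pattern (hypothetical names, no transformers dependency; not the modelopt code itself) could look like:

import contextlib
import types

# Stand-in for a module whose function we may or may not need to patch.
fake_loading = types.SimpleNamespace(
    set_param_for_module=lambda *args, **kwargs: print("original set_param_for_module")
)

@contextlib.contextmanager
def patch_set_param_for_module(module):
    # Only patch when the target attribute exists on this version of the module.
    should_patch = hasattr(module, "set_param_for_module")
    orig = getattr(module, "set_param_for_module", None)
    if should_patch:
        def patched(*args, **kwargs):
            # Pre/post hooks would go here (e.g. saving and restoring requires_grad).
            return orig(*args, **kwargs)
        module.set_param_for_module = patched
    try:
        yield
    finally:
        # Always restore the original function on exit, but only if we patched it.
        if should_patch:
            module.set_param_for_module = orig

with patch_set_param_for_module(fake_loading):
    fake_loading.set_param_for_module()  # runs the patched wrapper
fake_loading.set_param_for_module()      # original restored after the context exits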