Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 147 additions & 12 deletions olive/passes/onnx/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from olive.passes.onnx.common import get_external_data_config, ir_model_to_olive_model
from olive.passes.pass_config import BasePassConfig, PassConfigParam, get_user_script_data_config

# pylint: disable=W0212

logger = logging.getLogger(__name__)


Expand All @@ -57,6 +59,129 @@ def forward(self, *input_data, **input_dict):
return self.model(*input_data, **input_dict)


def _register_dynamic_cache_export_support():
    """Register `DynamicCache` with torch's pytree machinery so torch.export can trace it."""
    from transformers.cache_utils import DynamicCache, DynamicLayer, DynamicSlidingWindowLayer

    def _as_cache_dict(cache: DynamicCache):
        """Represent the cache as a dict so the standard dict pytree helpers apply."""
        for layer in cache.layers:
            if not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)):
                raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")

        # One (keys, values) pair per initialized layer; uninitialized layers are skipped.
        pairs = [(layer.keys, layer.values) for layer in cache.layers if layer.keys is not None]
        return {"cache": pairs}

    def _flatten(dynamic_cache):
        return torch.utils._pytree._dict_flatten(_as_cache_dict(dynamic_cache))

    def _flatten_with_keys(dynamic_cache):
        return torch.utils._pytree._dict_flatten_with_keys(_as_cache_dict(dynamic_cache))

    def _flatten_spec(cache, spec):
        return torch.fx._pytree._dict_flatten_spec(_as_cache_dict(cache), spec)

    try:
        torch.utils._pytree.register_pytree_node(
            DynamicCache,
            _flatten,
            _unflatten_dynamic_cache,
            serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
            flatten_with_keys_fn=_flatten_with_keys,
        )
        # TODO (team): This won't be needed in torch 2.7.
        torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_spec)
    # Catching this in case there are multiple runs for some test runs
    except ValueError as e:
        if "already registered as pytree node" not in str(e):
            raise


def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
    """Pytree unflatten counterpart: rebuild a DynamicCache from flattened leaves."""
    from transformers.cache_utils import DynamicCache

    as_dict = torch.utils._pytree._dict_unflatten(values, context)
    rebuilt = DynamicCache()
    # Re-populate the layers in order from the stored (keys, values) pairs.
    for layer_idx, (keys, vals) in enumerate(as_dict.get("cache", [])):
        rebuilt.update(keys, vals, layer_idx)
    return rebuilt


def _patch_dynamic_layer_for_export():
    """Patch DynamicLayer.lazy_initialization for torch.export compatibility (transformers >= 5.0).

    The upstream implementation uses torch.tensor([]), a 1-D empty tensor (shape [0]).
    torch.export needs consistent tensor ranks, so instead build a zero-length slice of
    key_states (torch.narrow along the sequence axis) and torch.empty_like it, preserving
    the full rank (e.g. [batch, heads, 0, head_dim]).
    """
    from transformers.cache_utils import DynamicLayer

    # Nothing to patch on transformers versions without lazy_initialization.
    if not hasattr(DynamicLayer, "lazy_initialization"):
        return

    def patched_lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor = None):
        self.dtype = key_states.dtype
        self.device = key_states.device
        # Zero-length slice along dim -2 (sequence) keeps key_states' full rank.
        template = torch.narrow(key_states, dim=-2, start=0, length=0)

        def _empty():
            return torch.empty_like(template, dtype=self.dtype, device=self.device)

        # NOTE(review): presumably key_states is a FakeTensor during export, in which case
        # the empties must be allocated under its fake mode — confirm against torch.export.
        if hasattr(key_states, "fake_mode"):
            with key_states.fake_mode:
                self.keys = _empty()
                self.values = _empty()
        else:
            self.keys = _empty()
            self.values = _empty()
        self.is_initialized = True

    DynamicLayer.lazy_initialization = patched_lazy_initialization
    logger.debug("Patched DynamicLayer.lazy_initialization for torch.export compatibility.")


def _convert_past_key_values_to_dynamic_cache(dummy_kwargs: dict, config=None) -> dict:
"""Convert legacy list-format past_key_values to DynamicCache (transformers >= 5.0).

Transformers 5.0 models expect DynamicCache objects, not lists of (key, value) tensors.
When config is provided, the DynamicCache will create correct layer types (e.g.
DynamicSlidingWindowLayer for models using sliding window attention).
"""
pkv = dummy_kwargs.get("past_key_values")
if pkv is None or not isinstance(pkv, (list, tuple)):
return dummy_kwargs

# Check if it's legacy format: list of [key, value] pairs (each with exactly 2 elements)
if not pkv or not isinstance(pkv[0], (list, tuple)) or len(pkv[0]) != 2:
return dummy_kwargs

from transformers.cache_utils import DynamicCache

dc = DynamicCache(config=config)
for layer_idx, kv in enumerate(pkv):
dc.update(kv[0], kv[1], layer_idx=layer_idx)
dummy_kwargs["past_key_values"] = dc
logger.debug("Converted past_key_values from legacy list format to DynamicCache.")
return dummy_kwargs


def _convert_dynamic_shapes_for_dynamic_cache(dynamic_shapes: dict) -> dict:
    """Convert dynamic_shapes for past_key_values from per-layer pairs to parallel key/value lists.

    Input format:  [[key0, val0], [key1, val1], ...]        (one [key, value] pair per layer)
    Output format: [[key0, key1, ...], [val0, val1, ...]]   (all key shapes, then all value shapes)

    NOTE(review): the output is two parallel lists, NOT a flat interleaved
    [key0, val0, key1, val1, ...] list. Whether this layout matches the pytree registered
    by `_register_dynamic_cache_export_support()` — which flattens a DynamicCache to
    {"cache": [(keys, values) per layer]} — should be confirmed against torch.export's
    dynamic_shapes matching rules.

    Mutates and returns `dynamic_shapes`; returns it unchanged when past_key_values is
    absent, empty, or not in the per-layer pair format.
    """
    pkv_shapes = dynamic_shapes.get("past_key_values")
    if pkv_shapes is None or not isinstance(pkv_shapes, (list, tuple)):
        return dynamic_shapes

    # Only convert the per-layer [key, value] pair format (each entry exactly 2 elements);
    # anything else is assumed to already be in the desired layout.
    if not pkv_shapes or not isinstance(pkv_shapes[0], (list, tuple)) or len(pkv_shapes[0]) != 2:
        return dynamic_shapes

    # [[key0, val0], [key1, val1], ...] -> [[key0, key1, ...], [val0, val1, ...]]
    dynamic_shapes["past_key_values"] = [
        [layer[0] for layer in pkv_shapes],
        [layer[1] for layer in pkv_shapes],
    ]
    logger.debug("Converted dynamic_shapes for past_key_values to DynamicCache pytree format.")
    return dynamic_shapes


def _patch_model_if_necessary(pytorch_model: torch.nn.Module):
if not isinstance(pytorch_model, PreTrainedModel):
return
Expand Down Expand Up @@ -179,9 +304,6 @@ def _export_pytorch_model(
if torch_dtype:
pytorch_model = pytorch_model.to(torch_dtype)

# Apply any necessary patches
_patch_model_if_necessary(pytorch_model)

# get input and output names, and dynamic axes
assert io_config is not None, "Cannot get io_config for the model."
io_config = validate_config(io_config, IoConfig)
Expand All @@ -194,8 +316,6 @@ def _export_pytorch_model(
# is taken, the old export always writes a model to the disk. When that happens we need to
# load the model back into IR and load all the external tensor to memory
with tempfile.TemporaryDirectory(prefix="olive_tmp") as tmp_dir:
tmp_model_path = resolve_onnx_path(tmp_dir)

if dynamo:
# Take the "release" version so that dev builds like 2.5.0dev1234 are treated as 2.5.0
if _torch_is_older_than("2.7.0") and (
Expand All @@ -212,24 +332,39 @@ def _export_pytorch_model(
"Please upgrade PyTorch to 2.6.0 or above."
)

# Register DynamicCache export support
from transformers.integrations.executorch import register_dynamic_cache_export_support

register_dynamic_cache_export_support()

if isinstance(dummy_inputs, dict):
dummy_kwargs = dummy_inputs
dummy_inputs = ()
else:
dummy_kwargs = {}
dummy_inputs = tuple(dummy_inputs)

# Apply patches for DynamicCache / past_key_values compatibility
if version.parse(transformers.__version__) >= version.parse("5.0"):
# transformers >= 5.0: DynamicCache refactored to use DynamicLayer

_register_dynamic_cache_export_support()
_patch_dynamic_layer_for_export()
model_config = getattr(pytorch_model, "config", None)
dummy_kwargs = _convert_past_key_values_to_dynamic_cache(dummy_kwargs, config=model_config)
if io_config.dynamic_shapes:
io_config.dynamic_shapes = _convert_dynamic_shapes_for_dynamic_cache(io_config.dynamic_shapes)
else:
# transformers < 5.0: patch forward to convert list <-> DynamicCache
_patch_model_if_necessary(pytorch_model)

# NOTE: Usually validation is done in io_config.py, but because
# dynamic_shapes has nested complexity, and it can't be validated multiple
# times like others, we validate it here.
io_config.dynamic_shapes, dummy_inputs, dummy_kwargs = _validate_dynamic_shapes(
io_config.dynamic_shapes, dummy_inputs, dummy_kwargs, pytorch_model
)
# torch.export requires strict type match between inputs and dynamic_shapes;
# _validate_dynamic_shapes may return OrderedDict, so convert back to plain dict
if isinstance(io_config.dynamic_shapes, collections.OrderedDict):
io_config.dynamic_shapes = dict(io_config.dynamic_shapes)
if isinstance(dummy_kwargs, collections.OrderedDict):
dummy_kwargs = dict(dummy_kwargs)

# When dynamo=True, PyTorch prefers dynamic_shapes over dynamic_axes.
# If dynamic_shapes is None and fallback is enabled, don't pass dynamic_axes
Expand All @@ -239,15 +374,13 @@ def _export_pytorch_model(
onnx_program = torch.onnx.export( # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
pytorch_model,
dummy_inputs,
tmp_model_path, # needed for fallback=True
kwargs=dummy_kwargs,
opset_version=config.target_opset,
input_names=io_config.input_names,
output_names=io_config.output_names,
dynamic_axes=dynamic_axes_for_export,
dynamic_shapes=io_config.dynamic_shapes,
dynamo=True,
fallback=False,
optimize=config.optimize,
report=logger.isEnabledFor(logging.DEBUG),
)
Expand All @@ -264,6 +397,8 @@ def _export_pytorch_model(
# default is True in 2.9.0 and later
dynamo_args["dynamo"] = False

tmp_model_path = resolve_onnx_path(tmp_dir)

torch.onnx.export(
pytorch_model,
dummy_inputs,
Expand Down
5 changes: 5 additions & 0 deletions olive/passes/pytorch/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def create_training_args(self) -> transformers.TrainingArguments:
if version.parse(transformers_version) < version.parse("4.41") and "eval_strategy" in args:
args["evaluation_strategy"] = args.pop("eval_strategy")
extra_args = args.pop("extra_args")
# Filter out fields that are not valid TrainingArguments parameters (e.g. overwrite_output_dir
# was removed in transformers 5.0 but is still used by Olive's own logic) and None values
# so that transformers uses its own defaults
training_args_fields = {f.name for f in dataclasses.fields(transformers.TrainingArguments) if f.init}
args = {k: v for k, v in args.items() if k in training_args_fields and v is not None}
return transformers.TrainingArguments(**args, **extra_args)


Expand Down
30 changes: 11 additions & 19 deletions test/model/test_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,16 @@ def setup(self):
self.local_path = huggingface_hub.snapshot_download(self.model_name, revision=self.revision)

@pytest.mark.parametrize("local", [True, False])
@pytest.mark.parametrize("trust_remote_code", [True, False])
def test_load_model(self, local, trust_remote_code):
def test_load_model(self, local):
olive_model = HfModelHandler(
model_path=self.local_path if local else self.model_name,
task=self.task,
load_kwargs={"trust_remote_code": trust_remote_code, "revision": self.revision},
load_kwargs={"revision": self.revision},
)

pytorch_model = olive_model.load_model()
actual_class_path = f"{pytorch_model.__module__}.{pytorch_model.__class__.__name__}"
if trust_remote_code:
# When using remote code, the model is loaded from transformers_modules
assert actual_class_path.startswith("transformers_modules.")
assert actual_class_path.endswith(".modeling_phi3.Phi3ForCausalLM")
else:
# When not using remote code, the model is loaded from transformers
assert actual_class_path == "transformers.models.phi3.modeling_phi3.Phi3ForCausalLM"
assert actual_class_path == "transformers.models.phi3.modeling_phi3.Phi3ForCausalLM"

@pytest.mark.parametrize("local", [True, False])
def test_load_model_with_kwargs(self, local):
Expand Down Expand Up @@ -73,19 +66,18 @@ def test_save_metadata(self, local, trust_remote_code, tokenizer_exists, tmp_pat
if tokenizer_exists:
olive_model.get_hf_tokenizer().save_pretrained(tmp_path)
saved_filepaths = olive_model.save_metadata(tmp_path)
# transformers>=4.53.x
assert len(saved_filepaths) == (4 if tokenizer_exists else 10)
# transformers>=5.0.0
assert len(saved_filepaths) == (4 if tokenizer_exists else 7)
assert all(Path(fp).exists() for fp in saved_filepaths)
assert isinstance(transformers.AutoConfig.from_pretrained(tmp_path), transformers.Phi3Config)
assert isinstance(transformers.AutoTokenizer.from_pretrained(tmp_path), transformers.LlamaTokenizerFast)
assert isinstance(transformers.AutoTokenizer.from_pretrained(tmp_path), transformers.PreTrainedTokenizerBase)

@pytest.mark.parametrize("local", [True, False])
@pytest.mark.parametrize("trust_remote_code", [True, False])
def test_save_pretrained_metadata(self, local, trust_remote_code, tmp_path):
def test_save_pretrained_metadata(self, local, tmp_path):
olive_model = HfModelHandler(
model_path=self.local_path if local else self.model_name,
task=self.task,
load_kwargs={"trust_remote_code": trust_remote_code, "revision": self.revision},
load_kwargs={"revision": self.revision},
)

# modify the config and save the model
Expand All @@ -94,8 +86,8 @@ def test_save_pretrained_metadata(self, local, trust_remote_code, tmp_path):
loaded_model.save_pretrained(tmp_path)

saved_filepaths = olive_model.save_metadata(tmp_path)
# generation config is also saved, transformers>=4.53.x
assert len(saved_filepaths) == 9
# generation config is also saved, transformers>=5.0.0
assert len(saved_filepaths) == 6

with open(tmp_path / "config.json") as f:
config = json.load(f)
Expand Down Expand Up @@ -126,7 +118,7 @@ def test_save_metadata_with_module_files(trust_remote_code, tmp_path):
assert f"{config.__module__}.{config.__class__.__name__}" == expected_class_name
assert isinstance(
transformers.AutoTokenizer.from_pretrained(tmp_path, **load_kwargs),
transformers.LlamaTokenizerFast,
transformers.PreTrainedTokenizerBase,
)


Expand Down
3 changes: 2 additions & 1 deletion test/passes/pytorch/test_rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def common_test_rotate(rotate_pass, tmp_path, model_path, rotate_mode, atol, **c
with torch.no_grad():
original_output = original_model(i)
rotated_output = rotated_model(i)
assert torch.allclose(original_output.logits, rotated_output.logits, atol=atol)
# Cast to same dtype before comparison since rotated model may be saved/loaded in a different dtype
assert torch.allclose(original_output.logits.float(), rotated_output.logits.float(), atol=atol)


@pytest.mark.parametrize("model_path", ["tiny-phi3", "tiny-llama"])
Expand Down
2 changes: 0 additions & 2 deletions test/requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,3 @@ sentencepiece
soundfile
tabulate
torchvision
# Remove version pin when the tests are fixed
transformers<5.0.0
Loading