Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions olive/passes/onnx/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ def process_llm_pipeline(
output_dir: Union[str, Path],
decoder_config_extra: Optional[dict[str, Any]] = None,
group_session_options: Optional[dict[str, Any]] = None,
group_run_options: Optional[dict[str, Any]] = None,
) -> CompositeModelHandler:
"""Process an LLM pipeline with the given function.

Expand All @@ -631,6 +632,7 @@ def process_llm_pipeline(
:param output_dir: The directory to save the processed model.
:param decoder_config_extra: Extra configuration for the decoder.
:param group_session_options: Session options for the context and iterator groups.
:param group_run_options: Run options for the context and iterator groups.
:return: The processed composite model handler.
"""
output_dir = Path(output_dir)
Expand Down Expand Up @@ -673,6 +675,7 @@ def process_llm_pipeline(
source_llm_pipeline=llm_pipeline,
decoder_config_extra=decoder_config_extra,
group_session_options=group_session_options,
group_run_options=group_run_options,
)


Expand All @@ -681,13 +684,15 @@ def update_llm_pipeline_genai_config(
source_llm_pipeline: Optional[dict[str, Any]] = None,
decoder_config_extra: Optional[dict[str, Any]] = None,
group_session_options: Optional[dict[str, Any]] = None,
group_run_options: Optional[dict[str, Any]] = None,
) -> CompositeModelHandler:
"""Update the LLM pipeline in the model's genai_config.json file.

:param model: The composite model to update.
:param source_llm_pipeline: The source LLM pipeline to use for the update.
:param decoder_config_extra: Extra configuration for the decoder.
:param group_session_options: Session options for the context and iterator groups.
:param group_run_options: Run options for the context and iterator groups.
:return: The updated composite model.
"""
if not model.model_path or not Path(model.model_path).is_dir():
Expand Down Expand Up @@ -737,6 +742,9 @@ def update_llm_pipeline_genai_config(
group_session_options = group_session_options or decoder_config.get("pipeline", [{}])[0].get(
source_llm_pipeline["context"][0], {}
).get("session_options")
group_run_options = group_run_options or decoder_config.get("pipeline", [{}])[0].get(
source_llm_pipeline["context"][0], {}
).get("run_options")
# update pipeline config
component_models = dict(model.get_model_components())
pipeline_config = {}
Expand All @@ -758,6 +766,8 @@ def update_llm_pipeline_genai_config(
for name in llm_pipeline[group]:
if group_session_options:
pipeline_config[name]["session_options"] = group_session_options
if group_run_options:
pipeline_config[name]["run_options"] = group_run_options
pipeline_config[name][f"run_on_{dont_run_on}"] = False

pipeline_config[llm_pipeline["lm_head"]]["is_lm_head"] = True
Expand Down
7 changes: 7 additions & 0 deletions olive/passes/onnx/context_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
default_value=None,
description="Session options for the EP.",
),
"run_options": PassConfigParam(
type_=dict,
default_value=None,
description="Run options for the EP.",
),
"disable_cpu_fallback": PassConfigParam(
type_=bool,
default_value=False,
Expand Down Expand Up @@ -150,13 +155,15 @@ def process_context_iterator(component_models, llm_pipeline, output_dir):
group_session_options["provider_options"] = [
{self.accelerator_spec.execution_provider.lower().replace("executionprovider", ""): provider_options}
]
group_run_options = config.run_options

return process_llm_pipeline(
model,
pipeline,
process_context_iterator,
output_model_path,
group_session_options=group_session_options,
group_run_options=group_run_options,
)

new_component_models = self._generate_composite_binaries(
Expand Down
Loading