diff --git a/src/streamdiffusion/acceleration/tensorrt/builder.py b/src/streamdiffusion/acceleration/tensorrt/builder.py
index 8550f0c9..3843bb9c 100644
--- a/src/streamdiffusion/acceleration/tensorrt/builder.py
+++ b/src/streamdiffusion/acceleration/tensorrt/builder.py
@@ -46,6 +46,7 @@ def build(
         force_engine_build: bool = False,
         force_onnx_export: bool = False,
         force_onnx_optimize: bool = False,
+        timing_cache: str = None,
     ):
         if not force_onnx_export and os.path.exists(onnx_path):
             print(f"Found cached model: {onnx_path}")
@@ -89,6 +90,7 @@ def build(
             build_dynamic_shape=build_dynamic_shape,
             build_all_tactics=build_all_tactics,
             build_enable_refit=build_enable_refit,
+            timing_cache=timing_cache,
         )
         for file in os.listdir(os.path.dirname(engine_path)):
             if file.endswith('.engine'):
diff --git a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py
index f5cfc1bd..f2ca141e 100644
--- a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py
+++ b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py
@@ -221,6 +221,13 @@ def compile_and_load_engine(self,
         Moves compilation blocks from wrapper.py lines 1200-1252, 1254-1283, 1285-1313.
         """
 
+        if 'engine_build_options' not in kwargs:
+            kwargs['engine_build_options'] = {}
+
+        if 'timing_cache' not in kwargs['engine_build_options']:
+            timing_cache_path = self.engine_dir / "timing_cache"
+            kwargs['engine_build_options']['timing_cache'] = str(timing_cache_path)
+
         if not engine_path.exists():
             # Get the appropriate compile function for this engine type
             config = self._configs[engine_type]
diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py
index fe26a88d..d5d1f7b7 100644
--- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py
+++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py
@@ -289,7 +289,6 @@ def forward(self, sample, timestep, encoder_hidden_states, *args):
             return res
         else:
             return res[0]
-        return res
 
 
 def create_controlnet_wrapper(unet: UNet2DConditionModel,
diff --git a/src/streamdiffusion/acceleration/tensorrt/utilities.py b/src/streamdiffusion/acceleration/tensorrt/utilities.py
index b9bfa1c3..1e5be108 100644
--- a/src/streamdiffusion/acceleration/tensorrt/utilities.py
+++ b/src/streamdiffusion/acceleration/tensorrt/utilities.py
@@ -19,6 +19,7 @@
 #
 
 import gc
+import os
 from collections import OrderedDict
 from pathlib import Path
 from typing import Any, List, Optional, Tuple, Union
@@ -249,10 +250,14 @@ def build(
         if not enable_all_tactics:
             config_kwargs["tactic_sources"] = []
 
+        load_timing_cache = timing_cache
+        if isinstance(timing_cache, (str, Path)) and not os.path.exists(timing_cache):
+            load_timing_cache = None
+
         engine = engine_from_network(
             network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
             config=CreateConfig(
-                fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs
+                fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=load_timing_cache, **config_kwargs
             ),
             save_timing_cache=timing_cache,
         )
@@ -498,6 +503,7 @@ def build_engine(
     build_dynamic_shape: bool = False,
     build_all_tactics: bool = False,
     build_enable_refit: bool = False,
+    timing_cache: str = None,
 ):
     _, free_mem, _ = cudart.cudaMemGetInfo()
     GiB = 2**30
@@ -520,6 +526,7 @@ def build_engine(
         input_profile=input_profile,
         enable_refit=build_enable_refit,
         enable_all_tactics=build_all_tactics,
+        timing_cache=timing_cache,
         workspace_size=max_workspace_size,
     )
 
diff --git a/src/streamdiffusion/tools/compile_raft_tensorrt.py b/src/streamdiffusion/tools/compile_raft_tensorrt.py
index 8dec3a76..3a5716b2 100644
--- a/src/streamdiffusion/tools/compile_raft_tensorrt.py
+++ b/src/streamdiffusion/tools/compile_raft_tensorrt.py
@@ -87,6 +87,7 @@ def export_raft_to_onnx(
         opset_version=17,
         export_params=True,
         dynamic_axes=dynamic_axes,
+        dynamo=False,
     )
 
     del raft_model
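
Taken together, these changes thread a TensorRT timing cache through the build path: compile_and_load_engine defaults the cache path to <engine_dir>/timing_cache when no engine_build_options override is given, builder.py and build_engine forward it, and the build routine in utilities.py only loads the cache when the file already exists while always saving to it. The sketch below mirrors that defaulting and existence guard in isolation; the helper name resolve_timing_cache is illustrative and not part of the diff.

    from pathlib import Path
    from typing import Optional, Tuple

    def resolve_timing_cache(engine_dir: Path, requested: Optional[str] = None) -> Tuple[Optional[str], str]:
        # Default to <engine_dir>/timing_cache when the caller supplies no path,
        # matching the engine_manager change above.
        cache_path = Path(requested) if requested else Path(engine_dir) / "timing_cache"
        # Load the cache only if the file already exists (the first build starts
        # cold instead of failing); always save to the path so later builds can
        # reuse the recorded tactic timings.
        load_cache = str(cache_path) if cache_path.exists() else None
        save_cache = str(cache_path)
        return load_cache, save_cache

Reusing the saved cache lets repeated engine builds for the same model skip re-timing kernel tactics, which is the usual reason to persist a TensorRT timing cache across builds.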