diff --git a/configs/sd15_multicontrol.yaml.example b/configs/sd15_multicontrol.yaml.example index 7505d157..236375e5 100644 --- a/configs/sd15_multicontrol.yaml.example +++ b/configs/sd15_multicontrol.yaml.example @@ -32,11 +32,19 @@ seed: 789 frame_buffer_size: 1 delta: 0.7 use_denoising_batch: true -use_lcm_lora: true +# LoRA configuration - use lora_dict to load LCM LoRA and other LoRAs +lora_dict: + "latent-consistency/lcm-lora-sdv1-5": 1.0 # LCM LoRA for faster inference + # Add other LoRAs here: + # "your_custom_lora": 0.7 + use_tiny_vae: true acceleration: "tensorrt" # "xformers" for non-TensorRT setups cfg_type: "self" +scheduler: "lcm" # Supports "lcm" or "tcd" +sampler: "normal" + # Engine directory for TensorRT (engines will be built here if not found) engine_dir: "./engines/sd15" diff --git a/configs/sdturbo_multicontrol.yaml.example b/configs/sdturbo_multicontrol.yaml.example index 5f7b8561..a0fe0ce0 100644 --- a/configs/sdturbo_multicontrol.yaml.example +++ b/configs/sdturbo_multicontrol.yaml.example @@ -22,11 +22,19 @@ seed: 789 frame_buffer_size: 1 delta: 0.7 use_denoising_batch: true -use_lcm_lora: true # SD-Turbo benefits from LCM LoRA +# LoRA configuration - SD-Turbo can benefit from LCM LoRA +lora_dict: + "latent-consistency/lcm-lora-sdv1-5": 1.0 # LCM LoRA for faster inference + # Add other LoRAs here: + # "your_custom_lora": 0.7 + use_tiny_vae: true acceleration: "tensorrt" # "xformers" for non-TensorRT setups cfg_type: "self" +scheduler: "lcm" # Supports "lcm" or "tcd" +sampler: "normal" + # Engine directory for TensorRT engine_dir: "./engines/sdturbo" diff --git a/configs/sdxl_multicontrol.yaml.example b/configs/sdxl_multicontrol.yaml.example index f32f4dab..59116ffd 100644 --- a/configs/sdxl_multicontrol.yaml.example +++ b/configs/sdxl_multicontrol.yaml.example @@ -31,11 +31,20 @@ seed: 42 # Base seed (used with seed_blending above) frame_buffer_size: 1 delta: 0.7 use_denoising_batch: true -use_lcm_lora: false # SDXL has built-in optimizations +# LoRA configuration - SDXL can use LCM LoRA for faster inference +# lora_dict: +# "latent-consistency/lcm-lora-sdxl": 1.0 # Uncomment to enable LCM LoRA for SDXL +# # Add other LoRAs here: +# # "your_custom_lora": 0.7 + use_taesd: true # Use Tiny AutoEncoder for SDXL use_tiny_vae: true acceleration: "tensorrt" # "xformers" for non-TensorRT setups cfg_type: "self" + +scheduler: "lcm" # Supports "lcm" or "tcd" +sampler: "normal" + safety_checker: false # Engine directory for TensorRT diff --git a/demo/realtime-img2img/config.py b/demo/realtime-img2img/config.py index 8d74eda4..56d77404 100644 --- a/demo/realtime-img2img/config.py +++ b/demo/realtime-img2img/config.py @@ -20,6 +20,7 @@ class Args(NamedTuple): controlnet_config: str api_only: bool log_level: str + quiet: bool def pretty_print(self): print("\n") @@ -34,6 +35,7 @@ def pretty_print(self): ENGINE_DIR = os.environ.get("ENGINE_DIR", "engines") ACCELERATION = os.environ.get("ACCELERATION", "xformers") LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO") +QUIET = os.environ.get("QUIET", "False").lower() in ("true", "1", "yes", "on") default_host = os.getenv("HOST", "0.0.0.0") default_port = int(os.getenv("PORT", "7860")) @@ -129,5 +131,12 @@ def pretty_print(self): choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", ) +parser.add_argument( + "--quiet", + dest="quiet", + action="store_true", + default=QUIET, + help="Suppress uvicorn INFO messages (server access logs, etc.)", +) config = 
Args(**vars(parser.parse_args())) config.pretty_print() diff --git a/demo/realtime-img2img/frontend/package-lock.json b/demo/realtime-img2img/frontend/package-lock.json index aef6d66b..89eb9c49 100644 --- a/demo/realtime-img2img/frontend/package-lock.json +++ b/demo/realtime-img2img/frontend/package-lock.json @@ -3842,9 +3842,9 @@ } }, "node_modules/svelte-check/node_modules/picomatch": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.2.tgz", - "integrity": "sha512-M7BAV6Rlcy5u+m6oPhAPFgJTzAioX/6B0DxyvDlo9l8+T3nLKbrczg2WLUyzd45L8RqfUMyGPzekbMvX2Ldkwg==", + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", + "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", "optional": true, @@ -4238,20 +4238,6 @@ "node": ">=18" } }, - "node_modules/yaml": { - "version": "2.8.0", - "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.0.tgz", - "integrity": "sha512-4lLa/EcQCB0cJkyts+FpIRx5G/llPxfP6VQU5KByHEhLxY3IJCH0f0Hy1MHI8sClTvsIb8qwRJ6R/ZdlDJ/leQ==", - "license": "ISC", - "optional": true, - "peer": true, - "bin": { - "yaml": "bin.mjs" - }, - "engines": { - "node": ">= 14.6" - } - }, "node_modules/yocto-queue": { "version": "0.1.0", "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", diff --git a/demo/realtime-img2img/frontend/src/lib/components/PreprocessorDocs.svelte b/demo/realtime-img2img/frontend/src/lib/components/PreprocessorDocs.svelte index 830666d7..0c59002b 100644 --- a/demo/realtime-img2img/frontend/src/lib/components/PreprocessorDocs.svelte +++ b/demo/realtime-img2img/frontend/src/lib/components/PreprocessorDocs.svelte @@ -40,7 +40,6 @@ use_denoising_batch: true, delta: 0.7, frame_buffer_size: 1, - use_lcm_lora: true, use_tiny_vae: true, acceleration: "xformers", cfg_type: "self", diff --git a/demo/realtime-img2img/frontend/src/routes/+page.svelte b/demo/realtime-img2img/frontend/src/routes/+page.svelte index 11442df2..4eb886cb 100644 --- a/demo/realtime-img2img/frontend/src/routes/+page.svelte +++ b/demo/realtime-img2img/frontend/src/routes/+page.svelte @@ -1026,7 +1026,7 @@ diff --git a/demo/realtime-img2img/main.py b/demo/realtime-img2img/main.py index 932808d7..bbcac9d6 100644 --- a/demo/realtime-img2img/main.py +++ b/demo/realtime-img2img/main.py @@ -59,6 +59,13 @@ def setup_logging(log_level: str = "INFO"): # Initialize logger logger = setup_logging(config.log_level) +# Suppress uvicorn INFO messages +if config.quiet: + uvicorn_logger = logging.getLogger('uvicorn') + uvicorn_logger.setLevel(logging.WARNING) + uvicorn_access_logger = logging.getLogger('uvicorn.access') + uvicorn_access_logger.setLevel(logging.WARNING) + class AppState: """Centralized application state management - SINGLE SOURCE OF TRUTH""" diff --git a/demo/realtime-img2img/requirements.txt b/demo/realtime-img2img/requirements.txt index a379a58e..dd200a25 100644 --- a/demo/realtime-img2img/requirements.txt +++ b/demo/realtime-img2img/requirements.txt @@ -1,11 +1,11 @@ diffusers==0.35.0 -transformers==4.56.0 -peft==0.18.0 +transformers==4.55.4 +peft==0.17.1 accelerate==1.10.0 -huggingface_hub==0.35.0 +huggingface_hub==0.34.4 fastapi==0.115.0 uvicorn[standard]==0.32.0 -Pillow==10.5.0 +Pillow==10.4.0 compel==2.0.2 controlnet-aux==0.0.7 xformers; sys_platform != 'darwin' or platform_machine != 'arm64' diff --git a/demo/realtime-txt2img/config.py b/demo/realtime-txt2img/config.py index c0a14ba4..35494148 100644 --- 
a/demo/realtime-txt2img/config.py +++ b/demo/realtime-txt2img/config.py @@ -29,8 +29,7 @@ class Config: model_id_or_path: str = os.environ.get("MODEL", "KBlueLeaf/kohaku-v2.1") # LoRA dictionary write like field(default_factory=lambda: {'E:/stable-diffusion-webui/models/Lora_1.safetensors' : 1.0 , 'E:/stable-diffusion-webui/models/Lora_2.safetensors' : 0.2}) lora_dict: dict = None - # LCM-LORA model - lcm_lora_id: str = os.environ.get("LORA", "latent-consistency/lcm-lora-sdv1-5") + # LCM-LORA model (use lora_dict instead of lcm_lora_id) # TinyVAE model vae_id: str = os.environ.get("VAE", "madebyollin/taesd") # Device to use diff --git a/demo/realtime-txt2img/main.py b/demo/realtime-txt2img/main.py index 88967c4c..18931f9e 100644 --- a/demo/realtime-txt2img/main.py +++ b/demo/realtime-txt2img/main.py @@ -63,7 +63,6 @@ def __init__(self, config: Config) -> None: mode=config.mode, model_id_or_path=config.model_id_or_path, lora_dict=config.lora_dict, - lcm_lora_id=config.lcm_lora_id, vae_id=config.vae_id, device=config.device, dtype=config.dtype, diff --git a/examples/optimal-performance/multi.py b/examples/optimal-performance/multi.py index ac2c2a53..791d88b1 100644 --- a/examples/optimal-performance/multi.py +++ b/examples/optimal-performance/multi.py @@ -74,7 +74,6 @@ def image_generation_process( frame_buffer_size=batch_size, warmup=10, acceleration=acceleration, - use_lcm_lora=False, mode="txt2img", cfg_type="none", use_denoising_batch=True, diff --git a/examples/optimal-performance/single.py b/examples/optimal-performance/single.py index 4bc08b3f..a8020bb8 100644 --- a/examples/optimal-performance/single.py +++ b/examples/optimal-performance/single.py @@ -40,7 +40,6 @@ def image_generation_process( frame_buffer_size=1, warmup=10, acceleration=acceleration, - use_lcm_lora=False, mode="txt2img", cfg_type="none", use_denoising_batch=True, diff --git a/multi_test/README.md b/multi_test/README.md new file mode 100644 index 00000000..cbc75590 --- /dev/null +++ b/multi_test/README.md @@ -0,0 +1,716 @@ +# StreamDiffusion Multi-Config Test Suite + +This testing suite allows you to benchmark multiple StreamDiffusion configurations against multiple video files, providing comprehensive performance analysis and reports. 
+ +## Features + +- **Multi-Config Testing**: Test multiple YAML configuration files against multiple video files +- **Resume Functionality**: Continue processing from where you left off after interruptions +- **Individual Prompt Processing**: Process each prompt individually and merge results into combined videos +- **RAM-Based Frame Processing**: Loads all frames into RAM for maximum speed (similar to main.py) +- **RAM-Based Video Creation**: Creates MP4 videos directly from frames in memory (no disk I/O) +- **Automatic Frame Extraction**: Uses ffmpeg to extract frames from videos for processing +- **Framerate Matching**: Output videos maintain the same framerate and timing as input videos +- **Full ControlNet Support**: Automatically loads and configures ControlNets from YAML configs +- **Performance Metrics**: Measures FPS, frame processing times, and success rates +- **Comprehensive Reporting**: Generates multiple report formats (TXT, CSV, JSON) +- **Error Handling**: Gracefully handles failures and continues with remaining tests +- **Resource Management**: Automatically cleans up temporary files and RAM cache +- **Video Merging**: Combines output from multiple prompts into single merged videos +- **JSON Metadata Generation**: Creates detailed JSON files alongside output videos containing configuration details, performance metrics, and processing information for comprehensive analysis and resume support +- **Enhanced Video Wall**: Automatically generates a video wall with rich metadata overlays (config, model, FPS, etc.) using ffmpeg, supporting flipped layouts (videos as rows, configs as columns) with fallback to basic wall +- **Retry Failed Combinations**: Supports retrying previously failed config+video pairs during resume without reprocessing successful ones, preserving all prior results +- **Advanced Performance Metrics**: Includes coefficient of variation (CV) for FPS stability, segment-level FPS analysis, and detailed rankings/recommendations in reports for better optimization insights + +## Performance Optimizations + +The test suite is optimized for maximum speed by: + +1. **RAM-Based Processing**: All video frames are loaded into RAM once and reused across multiple prompts/configs +2. **RAM-Based Video Creation**: Creates MP4 videos directly from frames in memory (no temporary files) +3. **Frame Caching**: Frames are cached in memory to avoid reloading from disk +4. **Minimal Disk I/O**: Processing happens entirely in memory, with disk writes only for final output +5. **Efficient Memory Management**: Automatic cleanup of frame cache to prevent memory issues +6. **Batch Processing**: Multiple prompts are processed against the same frames without reloading +7. **Framerate Preservation**: Output videos maintain input video timing for seamless playback +8. **Dual Video Encoding**: Uses imageio (primary) + OpenCV (fallback) for maximum compatibility + +This approach makes the test suite run significantly faster than disk-based alternatives, similar to the real-time processing in `main.py`. 
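+
+As a rough illustration of this in-memory flow, the sketch below loads extracted frames into RAM once and encodes an MP4 straight from those arrays with imageio. It is a minimal sketch under stated assumptions: `load_frames_to_ram` and `frames_to_mp4` are hypothetical helper names used for illustration, not functions exported by the test suite, and the suite's actual implementation may differ.
+
+```python
+from pathlib import Path
+from typing import List
+
+import imageio
+import numpy as np
+from PIL import Image
+
+def load_frames_to_ram(frames_dir: str) -> List[np.ndarray]:
+    """Load every extracted frame into RAM so multiple configs/prompts can reuse them."""
+    frames = []
+    for frame_path in sorted(Path(frames_dir).glob("*.png")):
+        frames.append(np.asarray(Image.open(frame_path).convert("RGB")))
+    return frames
+
+def frames_to_mp4(frames: List[np.ndarray], output_path: str, fps: float) -> None:
+    """Encode an MP4 directly from in-memory frames (imageio + imageio-ffmpeg)."""
+    writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
+    try:
+        for frame in frames:
+            writer.append_data(frame)
+    finally:
+        writer.close()
+```
+
+Reusing one cached frame list across many configs, and only touching disk for the final MP4, is what the points above refer to as RAM-based processing.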
+ +## Requirements + +- Python 3.7+ +- StreamDiffusion installed +- ffmpeg-python (for enhanced video wall creation with metadata overlays; install with `pip install ffmpeg-python`) +- ffmpeg (for video frame extraction only) +- PyYAML +- PIL/Pillow +- **Video Creation Dependencies**: + - `imageio` (primary video creation) + - `imageio-ffmpeg` (for H.264 encoding) + - `opencv-python` (fallback video creation) +- Sufficient RAM to hold all video frames (typically 2-4GB per video depending on resolution) + +## Installation + +1. Install the core dependencies: +```bash +pip install pyyaml pillow +``` + +2. Install ffmpeg (for frame extraction only): + - **Windows**: Download from https://ffmpeg.org/download.html + - **macOS**: `brew install ffmpeg` + - **Linux**: `sudo apt install ffmpeg` or equivalent + +3. Make sure StreamDiffusion is properly installed and accessible + +## Usage + +### Basic Usage + +```bash +# Test with config prompts (original behavior) +python multi_test.py --configs ./myconfigdir --videos ./myinputvideos + +# Test with individual prompts from file +python multi_test.py --configs ./myconfigdir --videos ./myinputvideos --prompts ./my_prompts.txt +``` + +### Command Line Options + +- `--configs`: Directory containing YAML configuration files +- `--videos`: Directory containing video files +- `--output`: Output directory for results (default: `./output-test`) +- `--prompts`: Text file containing individual prompts (one per line, optional) +- `--timeout_seconds`: Maximum time to spend processing each video (default: 300) +- `--resume`: Resume from existing output directory (full path to directory) +- `--retry_failed`: Retry previously failed combinations during resume (default: false) + +### Memory Management Options + +For videos that cause CUDA out-of-memory errors, use these options: + +```bash +# Process fewer frames per batch to reduce memory usage +python multi_test.py --configs ./configs --videos ./videos --batch-size 5 + +# Lower memory threshold for more aggressive cleanup +python multi_test.py --configs ./configs --videos ./videos --memory-threshold 1.0 + +# Process every 2nd frame for very long videos (reduces processing time and memory) +python multi_test.py --configs ./configs --videos ./videos --frame-skip 2 + +# Combine all memory management options +python multi_test.py --configs ./configs --videos ./videos --batch-size 5 --memory-threshold 1.0 --frame-skip 2 +``` + +### Example + +```bash +# Test all configs in ./configs against all videos in ./videos +python multi_test.py --configs ./configs --videos ./videos --output ./benchmark_results + +# Test with individual prompts +python multi_test.py --configs ./configs --videos ./videos --prompts ./prompts.txt --output ./prompt_results + +# Test with custom output directory +python multi_test.py --configs ./my_configs --videos ./my_videos --output ./my_results +``` + +## Resume Functionality + +The test suite now supports resuming interrupted runs, allowing you to continue processing from where you left off without losing previous work. + +### How Resume Works + +1. **Automatic Detection**: Scans existing output directory for completed videos and JSON metadata files +2. **Smart Parsing**: Extracts config+video combinations from existing filenames and JSON metadata +3. **CSV Integration**: Loads existing results from CSV files, including both successful and failed tests +4. **JSON Enrichment**: Loads detailed performance data from JSON metadata files for enhanced analysis +5. 
**Skip Completed**: Only processes remaining config+video combinations, with option to retry failed ones +6. **Seamless Integration**: Updates existing reports and maintains all output files + +### Resume Usage + +```bash +# Start a new test run +python multi_test.py --configs ./configs --videos ./videos --output ./output-multi + +# Resume from existing directory (if interrupted) +python multi_test.py --configs ./configs --videos ./videos --resume "C:\sd\StreamDiffusion\multi_test\20250903_192109" + +# Resume with different prompts (will use same output directory) +python multi_test.py --configs ./configs --videos ./videos --prompts ./prompts.txt --resume "./output-multi/20250903_192109" + +# Resume with different timeout for remaining work +python multi_test.py --configs ./configs --videos ./videos --timeout_seconds 600 --resume "./output-multi/20250903_192109" +``` + +### Resume Benefits + +- **Time Saving**: No need to reprocess completed combinations +- **Memory Efficient**: Continues with existing memory management +- **Progress Preservation**: Maintains all existing output files and reports +- **Flexible**: Can change prompts or timeout for remaining work +- **Robust**: Handles various filename formats and edge cases + +### Resume Output + +When resuming, the test suite will show: + +``` +๐Ÿ”„ Resuming from existing directory: C:\sd\StreamDiffusion\multi_test\20250903_192109 + +๐Ÿ” Scanning for completed work in: C:\sd\StreamDiffusion\multi_test\20250903_192109 +๐Ÿ“Š Loading existing results from CSV: detailed_results.csv +โœ… Loaded 8 successful results from CSV +๐ŸŽฌ Scanning video files in directory... + ๐Ÿ“น Analyzing: sdxl_depth_trt_ta86_cn_lcm_20250903-2317-28.5538735_ta86AllrounderXL_sdxlV1_merged_7prompts + โœ… Found completed: sdxl_depth_trt_ta86_cn_lcm + 20250903-2317-28.5538735 + +๐Ÿ“‹ Resume Summary: + Found 8 completed config+video combinations + Found 8 results with performance data + Completed combinations: + โœ… sdxl_depth_trt_ta86_cn_lcm + 20250903-2317-28.5538735 + โœ… sdxl_depth_trt_ta86_cn_lcm + 20250903-2319-02.8209091 + ... + +๐Ÿ“Š Work Summary: + Total combinations: 12 + Already completed: 8 + Remaining to process: 4 + โญ๏ธ Skipping 8 completed combinations +๐Ÿš€ Starting processing of 4 remaining combinations... +``` + +### Important Notes + +- **Directory Path**: Resume directory must exist and contain previous results +- **Config/Video Consistency**: Use the same config and video directories as the original run +- **Flexible Parameters**: Prompts file and timeout can be different for remaining work +- **Safety Checks**: Double-checks combinations to prevent duplicate processing +- **Progress Tracking**: Shows clear distinction between resumed and new work + +## Prompt Processing Modes + +### Mode 1: Config Prompts (Default) +When no `--prompts` file is provided, the test suite uses the prompts defined in your YAML config files: +```yaml +prompt: "A beautiful landscape" +negative_prompt: "low quality, bad quality, blurry" +``` + +### Mode 2: Individual Prompts (Temporal Splitting) +When `--prompts` file is provided, the test suite: +1. **Ignores** the `prompt` field in your YAML configs +2. **Splits the video temporally** across prompts (e.g., 30s video with 3 prompts = 10s each) +3. **Processes each prompt against its time segment** (much more efficient than full video processing) +4. **Merges all prompt outputs** into a single combined video +5. **Reports performance** for each prompt separately +6. 
**No pipeline restart** - uses StreamDiffusion's dynamic prompt updating + +## Prompts File Format + +Create a text file with one prompt per line: + +```txt +A hyperrealistic close-up of a man in a crimson silk, windswept auburn hair framing a freckled face, standing on a sun-drenched beach; fine sand clinging to her bare feet, the vast ocean a turquoise expanse behind her, conveying a sense of serene solitude. +A cinematic portrait of a woman with flowing golden hair, wearing an elegant emerald dress, standing in a moonlit garden surrounded by blooming roses and twinkling fairy lights. +A dramatic close-up of a warrior with battle-scarred armor, steely blue eyes reflecting determination, standing against a stormy sky with lightning illuminating ancient castle ruins in the background. +``` + +## Temporal Prompt Splitting + +The test suite now uses **temporal splitting** for maximum efficiency when processing multiple prompts: + +### How It Works + +1. **Video Segmentation**: The input video is divided into equal time segments based on the number of prompts +2. **Frame Distribution**: Each prompt processes only its assigned frames (e.g., frames 1-100 for prompt 1, frames 101-200 for prompt 2) +3. **Dynamic Prompt Updates**: Uses `stream.update_prompt()` to change prompts without restarting the pipeline +4. **Efficient Processing**: Each frame is processed only once with its corresponding prompt + +### Example: 30-Second Video with 3 Prompts + +- **Total Frames**: 900 frames (30fps ร— 30 seconds) +- **Prompt 1**: Frames 1-300 (0-10 seconds) โ†’ "Stained glass style..." +- **Prompt 2**: Frames 301-600 (10-20 seconds) โ†’ "Cinematic portrait..." +- **Prompt 3**: Frames 601-900 (20-30 seconds) โ†’ "Dramatic warrior..." + +### Benefits + +- **3x Faster**: Video processed once instead of three times +- **Memory Efficient**: No duplicate frame storage +- **Seamless Transitions**: Smooth prompt changes between segments +- **Professional Quality**: Each time segment gets dedicated prompt processing +- **Pipeline Optimization**: Leverages StreamDiffusion's dynamic prompt updating + +## Directory Structure + +``` +project/ +โ”œโ”€โ”€ configs/ # Your YAML config files +โ”‚ โ”œโ”€โ”€ config1.yaml +โ”‚ โ”œโ”€โ”€ config2.yaml +โ”‚ โ””โ”€โ”€ ... +โ”œโ”€โ”€ videos/ # Your video files +โ”‚ โ”œโ”€โ”€ video1.mp4 +โ”‚ โ”œโ”€โ”€ video2.avi +โ”‚ โ””โ”€โ”€ ... +โ”œโ”€โ”€ prompts.txt # Individual prompts (optional) +โ”œโ”€โ”€ test_results/ # Output directory (created automatically) +โ”‚ โ”œโ”€โ”€ test_summary.txt +โ”‚ โ”œโ”€โ”€ detailed_results.csv +โ”‚ โ”œโ”€โ”€ performance_comparison.txt +โ”‚ โ”œโ”€โ”€ config1_video1_merged.mp4 # Merged video (when using prompts) +โ”‚ โ””โ”€โ”€ individual_results/ +โ””โ”€โ”€ multi_test.py # The test suite script +``` + +## Configuration File Format + +Your YAML config files should follow the StreamDiffusion format. 
When using `--prompts`, the `prompt` field is ignored: + +```yaml +model_id: "runwayml/stable-diffusion-v1-5" +width: 512 +height: 512 +t_index_list: [32, 40, 45] +acceleration: "xformers" +guidance_scale: 1.2 +num_inference_steps: 50 +# prompt: "This is ignored when using --prompts" +negative_prompt: "low quality, bad quality, blurry" +use_denoising_batch: true +cfg_type: "self" +seed: 42 +``` + +### Required Fields + +- `model_id`: Path to the model checkpoint +- `width`: Image width (must be multiple of 64) +- `height`: Image height (must be multiple of 64) + +### Optional Fields + +- `t_index_list`: Denoising timesteps (default: [32, 40, 45]) +- `acceleration`: Acceleration method (default: "xformers") +- `guidance_scale`: CFG scale (default: 1.2) +- `num_inference_steps`: Number of inference steps (default: 50) +- `negative_prompt`: Negative prompt (default: "low quality, bad quality, blurry") +- `use_denoising_batch`: Use denoising batch (default: true) +- `cfg_type`: CFG type (default: "self") +- `seed`: Random seed (default: 42) + +**Note**: When using `--prompts`, the `prompt` field in your config is ignored. + +## ControlNet Support + +The test suite automatically detects and configures ControlNets from your YAML configuration files: + +### ControlNet Configuration Format + +```yaml +model_id: "runwayml/stable-diffusion-v1-5" +width: 512 +height: 512 +acceleration: "xformers" + +# ControlNet configurations +controlnets: + - model_id: "lllyasviel/control_v11p_sd15_canny" + preprocessor: "canny" + conditioning_scale: 1.0 + enabled: true + preprocessor_params: + low_threshold: 100 + high_threshold: 200 + + - model_id: "lllyasviel/control_v11p_sd15_depth" + preprocessor: "depth" + conditioning_scale: 0.8 + enabled: true + preprocessor_params: + depth_estimator: "dpt_large" +``` + +### Supported Preprocessors + +- **canny**: Edge detection with configurable thresholds +- **depth**: Depth estimation using various models +- **openpose**: Human pose estimation +- **scribble**: Free-form drawing input +- **segmentation**: Semantic segmentation +- **passthrough**: Direct image input without preprocessing + +### ControlNet Integration + +- **Automatic Loading**: ControlNets are loaded when the pipeline is created +- **Preprocessor Setup**: Preprocessors are automatically configured with your parameters +- **Performance Impact**: ControlNet processing is included in FPS measurements +- **Memory Management**: ControlNet models are properly managed alongside the main pipeline + +## Video File Support + +The test suite supports common video formats: +- MP4 (.mp4) +- AVI (.avi) +- MOV (.mov) +- MKV (.mkv) +- WebM (.webm) +- FLV (.flv) + +Videos are automatically converted to frames at 30 FPS for processing. + +## Framerate Matching + +The test suite automatically detects and preserves the input video's framerate in all output videos: + +### How It Works + +1. **Automatic Detection**: Uses `ffprobe` to extract the exact framerate from input videos +2. **Timing Preservation**: Output videos maintain the same frame timing as input videos +3. 
**Frame Distribution**: Generated frames are distributed to match input video timing + +### Example Scenarios + +**Scenario 1: Input 30fps, Processing 30fps** +- Input: 100 frames at 30fps (3.33 seconds) +- Processing: Generates 100 frames +- Output: 100 frames at 30fps (3.33 seconds) - Perfect match + +**Scenario 2: Input 30fps, Processing 15fps** +- Input: 100 frames at 30fps (3.33 seconds) +- Processing: Generates 50 frames +- Output: 50 frames at 30fps (3.33 seconds) - Each frame displayed for 2 input frame durations + +**Scenario 3: Input 30fps, Processing 60fps** +- Input: 100 frames at 30fps (3.33 seconds) +- Processing: Generates 200 frames +- Output: 200 frames at 30fps (3.33 seconds) - Each input frame duration shows 2 generated frames + +### Benefits + +- **Seamless Playback**: Output videos can be played alongside input videos +- **Consistent Timing**: All output videos maintain original video timing +- **Professional Quality**: Suitable for video editing and compositing workflows +- **Frame Accuracy**: Precise frame duration calculations using ffmpeg + +## Memory Management + +The test suite uses intelligent memory management: + +1. **Frame Loading**: Frames are loaded into RAM once per video +2. **Caching**: Frames are cached and reused across multiple configs/prompts +3. **Automatic Cleanup**: Frame cache is cleared after processing to free memory +4. **Memory Estimation**: Each frame typically uses 2-4MB depending on resolution + +**Memory Requirements**: Ensure you have sufficient RAM to hold all frames from your longest video. For a 1000-frame 512x512 video, expect ~2-4GB RAM usage. + +## Output Files + +After running the test suite, you'll get several output files: + +### 1. Test Summary (`test_summary.txt`) +- Overall test statistics +- Results grouped by configuration +- Results grouped by video +- Top 5 performing configurations +- Prompt processing information (when using `--prompts`) + +### 2. Detailed Results (`detailed_results.csv`) +- CSV format with all test results +- Individual frame processing times +- Success/failure status +- Error messages for failed tests +- Prompt processing details (when using `--prompts`) + +### 3. Performance Comparison (`performance_comparison.txt`) +- Performance comparison between configurations +- Average FPS for each config +- Sorted results by performance +- Individual prompt performance (when using `--prompts`) + +### 4. Individual Results (`*_result.json`) +- JSON files for each config-video combination +- Detailed metrics and configuration parameters +- Frame-by-frame timing data +- Prompt-by-prompt results (when using `--prompts`) + +### 4.1. Video Metadata (`*_metadata.json`) +- Comprehensive JSON files generated alongside each output video +- Contains structured data for resume functionality and analysis +- Structure: + - **video_info**: Config filename, video filename, output filename, total frames, prompts used, processing date + - **config_details**: Model ID, resolution (width/height), inference steps, guidance scale, negative prompt + - **performance_metrics**: Overall FPS, min/max/avg FPS, standard deviation, CV percentage, segment FPS list, total processing time + - **technical_details**: Timeout seconds, start/end times, success status +- Used for enhanced video wall overlays and detailed performance tracking + +### 5. 
Merged Videos (when using `--prompts`) +- `{config}_{video}_merged.mp4`: Combined video from all prompts +- Each frame sequence from a prompt is concatenated into the final video +- Maintains input video framerate and timing + +### 6. Single Prompt Videos (when not using `--prompts`) +- `{config}_{video}_output.mp4`: Output video for single prompt processing +- Maintains input video framerate and timing +- Suitable for direct comparison with input videos + +### 7. Video Timing Information +All output videos automatically: +- Match the input video's framerate (e.g., 30fps, 24fps, 60fps) +- Preserve the original video's timing and duration +- Use precise frame duration calculations for professional quality + +## Example Output + +### With Config Prompts +``` +StreamDiffusion Multi-Config Test Suite Results +============================================================ + +Overall Results: + Total tests: 6 + Successful: 6 + Failed: 0 + Success rate: 100.0% + +Quick Performance Summary: +---------------------------------------------------------------------------------------------------- +Config Video Resolution Overall FPS Avg FPS Min FPS Max FPS Frames +---------------------------------------------------------------------------------------------------- +config1 video1.mp4 512x512 15.23 15.23 14.89 15.67 300 +config1 video2.mp4 512x512 14.89 14.89 14.50 15.28 300 +config2 video1.mp4 512x512 18.45 18.45 18.10 18.80 300 +---------------------------------------------------------------------------------------------------- + +Results by Config: + config1: + Tests: 2/2 successful + Model: runwayml/stable-diffusion-v1-5 + Resolution: 512x512 + โœ… video1.mp4 (300 frames) - Overall FPS: 15.23, Min FPS: 14.89, Max FPS: 15.67, Avg FPS: 15.23, CV: 2.5% + โœ… video2.mp4 (300 frames) - Overall FPS: 14.89, Min FPS: 14.50, Max FPS: 15.28, Avg FPS: 14.89, CV: 2.7% + +Overall FPS Rankings (Higher is Better): + 1. config2 - 18.45 FPS (Avg: 18.45, Range: 18.10-18.80) + 2. 
config1 - 15.23 FPS (Avg: 15.23, Range: 14.89-15.67) + +Performance Statistics: + Overall FPS - Best: 18.45, Worst: 14.89, Mean: 16.19 + Average FPS - Best: 18.45, Worst: 14.89, Mean: 16.19 + Min FPS - Best: 18.10, Worst: 14.50, Mean: 15.83 + Max FPS - Best: 18.80, Worst: 15.28, Mean: 16.58 + +Recommendations: +๐Ÿ† Best Overall Performance: config2 + - Highest sustained FPS: 18.45 + - Best for: Maximum throughput scenarios + +๐Ÿ“Š Most Consistent Performance: config1 + - Lowest variance: 2.5% CV + - Best for: Real-time applications requiring stable frame rates +``` + +### With Individual Prompts +``` +StreamDiffusion Multi-Config Test Suite Results +Using 3 individual prompts from prompts.txt +============================================================ + +Overall Results: + Total tests: 6 + Successful: 6 + Failed: 0 + Success rate: 100.0% + +Quick Performance Summary: +---------------------------------------------------------------------------------------------------- +Config Video Resolution Overall FPS Avg FPS Min FPS Max FPS Frames +---------------------------------------------------------------------------------------------------- +config1 video1.mp4 512x512 15.23 15.23 14.89 15.67 900 +config1 video2.mp4 512x512 14.89 14.89 14.50 15.28 900 +config2 video1.mp4 512x512 18.45 18.45 18.10 18.80 900 +---------------------------------------------------------------------------------------------------- + +Results by Config: + config1: + Tests: 2/2 successful + Model: runwayml/stable-diffusion-v1-5 + Resolution: 512x512 + Prompts processed: 3/3 successful + โœ… video1.mp4 (900 frames, 3 prompts) - Overall FPS: 15.23, Min FPS: 14.89, Max FPS: 15.67, Avg FPS: 15.23, CV: 2.5% + โœ… video2.mp4 (900 frames, 3 prompts) - Overall FPS: 14.89, Min FPS: 14.50, Max FPS: 15.28, Avg FPS: 14.89, CV: 2.7% + +Performance Consistency Analysis: +Configs ranked by FPS stability (lower variance = more stable): + 1. config1 - CV: 2.5% (Std: 0.38, Range: 0.78) + Mean: 15.23 FPS, Min: 14.89, Max: 15.67 + 2. config2 - CV: 2.1% (Std: 0.35, Range: 0.70) + Mean: 18.45 FPS, Min: 18.10, Max: 18.80 + +Best Config per Video (Overall FPS): +------------------------------------------------------------ +video1.mp4 -> config2 (18.45 FPS, Avg: 18.45) +video2.mp4 -> config1 (14.89 FPS, Avg: 14.89) + +Performance Improvement Analysis: +Best Overall Config: config2 (18.45 FPS) + +Performance vs Best (Overall FPS): + config1 - 15.23 FPS (+21.1% vs best) + +Recommendations: +โš–๏ธ Best Balanced (Performance + Consistency): config2 + - Balanced score: 0.850 + - Performance: 18.45 FPS, Consistency: 2.1% CV + - Best for: Production environments requiring both speed and reliability +``` + +## Performance Tips + +1. **Use Temporal Splitting**: With `--prompts`, videos are processed once instead of multiple times (3x faster) +2. **Use TensorRT**: Set `acceleration: "tensorrt"` in your configs for best performance +3. **Optimize t_index_list**: Lower values (e.g., [10, 15]) for faster processing, higher values for better quality +4. **Batch Processing**: Enable `use_denoising_batch: true` for better throughput +5. **Resolution**: Lower resolutions process faster but may reduce quality +6. **Model Selection**: Smaller models (SD1.5 vs SDXL) generally process faster +7. **Prompt Length**: Shorter prompts generally process faster than very long, detailed ones +8. **RAM Optimization**: The suite automatically caches frames in RAM for maximum speed +9. 
**Framerate Optimization**: Output videos automatically match input timing for professional workflows +10. **Dynamic Prompt Updates**: Leverages StreamDiffusion's built-in prompt switching without pipeline restarts +11. **ControlNet Optimization**: Use fewer ControlNets and lower conditioning scales for faster processing +12. **Preprocessor Selection**: Choose efficient preprocessors (e.g., passthrough > canny > depth > openpose) + +## Troubleshooting + +### Common Issues + +1. **ffmpeg not found**: Install ffmpeg and ensure it's in your PATH +2. **CUDA out of memory**: Use memory management options (see below) +3. **Config validation errors**: Check that required fields are present in your YAML files +4. **Model loading failures**: Verify model paths and ensure models are accessible +5. **Video merging fails**: Ensure ffmpeg supports the concat demuxer +6. **Out of memory**: Reduce video resolution or frame count, or close other applications +7. **Framerate detection fails**: Ensure ffprobe is available and input videos are valid + +### CUDA Memory Issues + +If you encounter `CUDA out of memory` errors, the test suite now includes several memory management features: + +#### 1. Batch Processing +Process frames in smaller batches to reduce memory usage: +```bash +# Default: 10 frames per batch +python multi_test.py --configs ./configs --videos ./videos + +# Reduce to 5 frames per batch for lower memory usage +python multi_test.py --configs ./configs --videos ./videos --batch-size 5 + +# Very conservative: 3 frames per batch +python multi_test.py --configs ./configs --videos ./videos --batch-size 3 +``` + +#### 2. Memory Threshold Management +Set when automatic memory cleanup should occur: +```bash +# Default: Cleanup when less than 2GB free +python multi_test.py --configs ./configs --videos ./videos + +# More aggressive: Cleanup when less than 1GB free +python multi_test.py --configs ./configs --videos ./videos --memory-threshold 1.0 + +# Very aggressive: Cleanup when less than 0.5GB free +python multi_test.py --configs ./configs --videos ./videos --memory-threshold 0.5 +``` + +#### 3. Frame Skipping +For very long videos, process every Nth frame to reduce memory and time: +```bash +# Process every frame (default) +python multi_test.py --configs ./configs --videos ./videos + +# Process every 2nd frame (2x faster, 2x less memory) +python multi_test.py --configs ./configs --videos ./videos --frame-skip 2 + +# Process every 3rd frame (3x faster, 3x less memory) +python multi_test.py --configs ./configs --videos ./videos --frame-skip 3 +``` + +#### 4. Combined Memory Management +Use all options together for maximum memory efficiency: +```bash +python multi_test.py --configs ./configs --videos ./videos \ + --batch-size 3 \ + --memory-threshold 0.5 \ + --frame-skip 2 +``` + +#### 5. Automatic Memory Recovery +The test suite now automatically: +- Monitors GPU memory usage in real-time +- Cleans up memory after each batch and prompt +- Retries failed frames after memory cleanup +- Provides detailed memory status information +- Gracefully handles out-of-memory errors without crashing + +### Debug Mode + +For detailed logging, you can modify the script to add more verbose output or check the individual result JSON files for specific error details. + +## Advanced Usage + +### Custom Frame Extraction + +You can modify the `extract_frames_from_video` method to customize frame extraction parameters (FPS, format, etc.). 
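+
+For reference, a customized extraction step might look roughly like the sketch below, which shells out to the ffmpeg CLI (assumed to be on PATH). `extract_frames` is a hypothetical stand-in used for illustration; the suite's actual `extract_frames_from_video` method may use different arguments and output naming.
+
+```python
+import subprocess
+from pathlib import Path
+
+def extract_frames(video_path: str, frames_dir: str, fps: int = 30) -> int:
+    """Extract frames from a video at the given FPS using the ffmpeg CLI."""
+    Path(frames_dir).mkdir(parents=True, exist_ok=True)
+    cmd = [
+        "ffmpeg", "-i", video_path,
+        "-vf", f"fps={fps}",  # resample to the requested extraction rate
+        str(Path(frames_dir) / "frame_%05d.png"),
+    ]
+    subprocess.run(cmd, check=True, capture_output=True)
+    return len(list(Path(frames_dir).glob("frame_*.png")))
+```
+
+Changing the `fps` argument or the output pattern (e.g. JPEG instead of PNG) is the kind of customization the paragraph above has in mind.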
+ +### Custom Metrics + +Extend the `TestResult` dataclass to include additional metrics like memory usage, GPU utilization, etc. + +### Parallel Processing + +For faster testing, you could modify the suite to process multiple configs in parallel (requires careful resource management). + +### Custom Video Merging + +Modify the `merge_videos_from_prompts` method to customize how videos are combined (different frame rates, transitions, etc.). + +### Framerate Customization + +You can modify the `get_video_framerate` method to implement custom framerate detection logic or override framerates for specific use cases. + +## Example Workflow + +### With Individual Prompts + +1. **Setup**: Create directories and add your configs, videos, and prompts file +2. **Run Tests**: Execute the test suite with `--prompts prompts.txt` +3. **Analyze Results**: Review the generated reports and merged videos +4. **Optimize**: Use results to tune your configurations and prompts +5. **Iterate**: Run tests again with optimized configs + +### Without Individual Prompts + +1. **Setup**: Create directories and add your configs and videos +2. **Run Tests**: Execute the test suite (uses config prompts) +3. **Analyze Results**: Review the generated reports and output videos +4. **Optimize**: Use results to tune your configurations +5. **Iterate**: Run tests again with optimized configs + +## Contributing + +Feel free to extend the test suite with additional features: +- Memory usage tracking +- GPU utilization monitoring +- Quality metrics (PSNR, SSIM) +- Automated optimization suggestions +- Integration with CI/CD pipelines +- Custom video effects and transitions +- Prompt performance analysis and optimization +- Advanced memory management strategies +- Custom framerate handling and video processing \ No newline at end of file diff --git a/multi_test/enhanced_video_wall.py b/multi_test/enhanced_video_wall.py new file mode 100644 index 00000000..42c09910 --- /dev/null +++ b/multi_test/enhanced_video_wall.py @@ -0,0 +1,515 @@ +#!/usr/bin/env python3 +""" +Enhanced Video Wall Creator using JSON metadata + +This module creates video walls with rich metadata information from JSON files +stored alongside each processed video, providing better data for resume functionality. +""" + +import os +import json +from pathlib import Path +from typing import Dict, List, Optional + +try: + import ffmpeg +except ImportError: + print("Warning: ffmpeg-python library not found. Enhanced video wall creation will be disabled.") + ffmpeg = None + +def load_video_metadata(results_dir: str) -> Dict[str, Dict]: + """ + Load all JSON metadata files from results directory. 
+ + Parameters + ---------- + results_dir : str + Directory containing video results and JSON metadata + + Returns + ------- + Dict[str, Dict] + Dictionary mapping video filenames to their metadata + """ + metadata_dict = {} + + if not os.path.exists(results_dir): + return metadata_dict + + print(f"๐Ÿ“‹ Loading video metadata from: {results_dir}") + + try: + json_files = [f for f in os.listdir(results_dir) if f.endswith('_metadata.json')] + + for json_file in json_files: + json_path = os.path.join(results_dir, json_file) + try: + with open(json_path, 'r', encoding='utf-8') as f: + metadata = json.load(f) + + # Extract video filename from metadata + video_info = metadata.get('video_info', {}) + output_filename = video_info.get('output_filename', '') + + if output_filename: + metadata_dict[output_filename] = metadata + + except Exception as e: + print(f" โš ๏ธ Warning: Could not load {json_file}: {e}") + + print(f" โœ… Loaded metadata for {len(metadata_dict)} videos") + + except Exception as e: + print(f" โŒ Error loading metadata: {e}") + + return metadata_dict + +def create_enhanced_video_with_metadata( + input_path: str, + output_path: str, + metadata: Dict, + width: int, + height: int +) -> bool: + """ + Create a scaled video with enhanced metadata overlay using ffmpeg-python + + Args: + input_path: Path to input video + output_path: Path to output video + metadata: Video metadata dictionary + width: Target width + height: Target height + + Returns: + True if successful, False otherwise + """ + if ffmpeg is None: + return False + + try: + # Extract key information from metadata + video_info = metadata.get('video_info', {}) + config_details = metadata.get('config_details', {}) + performance = metadata.get('performance_metrics', {}) + + config_name = video_info.get('config_filename', 'Unknown') + model_name = config_details.get('model_id', 'Unknown').split('/')[-1] + resolution = f"{config_details.get('width', '?')}x{config_details.get('height', '?')}" + overall_fps = performance.get('overall_fps', 0) + avg_fps = performance.get('avg_fps', 0) + total_frames = video_info.get('total_frames', 0) + processing_time = performance.get('total_processing_time', 0) + + # Create multi-line text overlay with rich information + text_lines = [ + f"Config: {config_name}", + f"Model: {model_name}", + f"Resolution: {resolution}", + f"Frames: {total_frames}", + f"Overall FPS: {overall_fps:.1f}", + f"Avg FPS: {avg_fps:.1f}", + f"Time: {processing_time:.1f}s" + ] + + # Create the processing pipeline + stream = ffmpeg.input(input_path) + + # Scale video to target size with padding + scaled = ffmpeg.filter( + stream, + 'scale', + width, height, + force_original_aspect_ratio='decrease' + ) + + padded = ffmpeg.filter( + scaled, + 'pad', + width, height, + '(ow-iw)/2', '(oh-ih)/2' + ) + + # Add multi-line text overlay + font_path = r'C:/Windows/Fonts/arial.ttf' + + # Start with the padded video + current_stream = padded + + # Add each line of text + for i, line in enumerate(text_lines): + y_position = 10 + (i * 25) # 25 pixels between lines + + if os.path.exists(font_path): + current_stream = ffmpeg.drawtext( + current_stream, + text=line, + fontfile=font_path, + fontcolor='white', + fontsize=16, + box=1, + boxcolor='black@0.8', + boxborderw=3, + x=10, + y=y_position + ) + else: + # Fallback without fontfile + current_stream = ffmpeg.drawtext( + current_stream, + text=line, + fontcolor='white', + fontsize=16, + box=1, + boxcolor='black@0.8', + boxborderw=3, + x=10, + y=y_position + ) + + # Output with encoding 
settings + output = ffmpeg.output( + current_stream, + output_path, + vcodec='libx264', + crf=23, + preset='medium' + ) + + # Run the pipeline + ffmpeg.run(output, overwrite_output=True, quiet=True) + return True + + except Exception as e: + print(f" Error processing enhanced video: {e}") + return False + +def create_enhanced_video_wall( + results_dir: str, + output_path: str, + grid_width: int = 3, # Kept for backward compatibility, but not used in new layout + video_width: int = 512, + video_height: int = 512 +) -> Optional[str]: + """ + Create an enhanced video wall using JSON metadata for rich information display. + + The layout is automatically determined: each row represents one original video, + and columns represent different configurations (including original). + + Parameters + ---------- + results_dir : str + Directory containing video results and JSON metadata + output_path : str + Path for the output video wall + grid_width : int, optional + Legacy parameter kept for backward compatibility (not used in new layout) + video_width : int, optional + Width of each video in the wall, by default 512 + video_height : int, optional + Height of each video in the wall, by default 512 + + Returns + ------- + Optional[str] + Path to the created video wall, or None if failed + """ + + print("\n๐ŸŽฌ Creating enhanced video wall with JSON metadata...") + + if ffmpeg is None: + print(" โŒ Skipping - ffmpeg-python not available") + return None + + # Load metadata for all videos + metadata_dict = load_video_metadata(results_dir) + + if not metadata_dict: + print(" โŒ No video metadata found") + return None + + # Find corresponding video files + video_files = [] + enhanced_metadata = [] + + for filename, metadata in metadata_dict.items(): + video_path = os.path.join(results_dir, filename) + if os.path.exists(video_path): + video_files.append(video_path) + enhanced_metadata.append(metadata) + else: + print(f" โš ๏ธ Warning: Video file not found: {filename}") + + if not video_files: + print(" โŒ No video files found") + return None + + print(f" ๐Ÿ“น Processing {len(video_files)} videos for enhanced wall") + + # Create temporary directory for processing + temp_dir = os.path.join(os.path.dirname(output_path), "temp_enhanced_wall") + os.makedirs(temp_dir, exist_ok=True) + + try: + # Process each video with enhanced metadata overlay + processed_videos = [] + min_duration = float('inf') + + print(" ๐Ÿ”„ Processing videos with enhanced metadata...") + for i, (video_path, metadata) in enumerate(zip(video_files, enhanced_metadata)): + print(f" Processing {i+1}/{len(video_files)}: {os.path.basename(video_path)}") + + # Get video duration + try: + probe = ffmpeg.probe(video_path) + duration = float(probe['format']['duration']) + min_duration = min(min_duration, duration) + except Exception as e: + print(f" Warning: Could not get duration: {e}") + duration = 10 + min_duration = min(min_duration, duration) + + # Create enhanced video with metadata overlay + enhanced_path = os.path.join(temp_dir, f"enhanced_{i:03d}.mp4") + + success = create_enhanced_video_with_metadata( + video_path, + enhanced_path, + metadata, + video_width, + video_height + ) + + if success: + processed_videos.append(enhanced_path) + print(f" โœ… Enhanced video created") + else: + print(f" โŒ Failed to enhance video") + + if not processed_videos: + print(" โŒ No videos were successfully processed") + return None + + if min_duration == float('inf'): + min_duration = 10 + + print(f" ๐ŸŽฌ Creating video wall with flipped layout...") + + # 
Create video wall grid + input_streams = [] + + # Load all processed videos as input streams + for video_path in processed_videos: + stream = ffmpeg.input(video_path) + input_streams.append(stream) + + # Pad with blank videos if needed to fill grid + total_videos = len(input_streams) + videos_per_row = grid_width + rows_needed = (total_videos + videos_per_row - 1) // videos_per_row + total_slots = rows_needed * videos_per_row + + # Create blank videos for empty slots + for i in range(total_videos, total_slots): + blank_path = os.path.join(temp_dir, f"blank_{i}.mp4") + + # Create a blank video + blank_input = ffmpeg.input( + f'color=c=gray:s={video_width}x{video_height}:d={min_duration}', + f='lavfi' + ) + + blank_with_text = ffmpeg.drawtext( + blank_input, + text='No Video', + fontcolor='white', + fontsize=24, + x='(w-text_w)/2', + y='(h-text_h)/2' + ) + + blank_output = ffmpeg.output(blank_with_text, blank_path, vcodec='libx264', crf=23) + ffmpeg.run(blank_output, overwrite_output=True, quiet=True) + + input_streams.append(ffmpeg.input(blank_path)) + + # Create rows with flipped layout: each row = one original video + all its config outputs + # This assumes videos are ordered as: original_video1, config1_video1, config2_video1, ..., original_video2, config1_video2, etc. + + # First, determine how many configs (including original) we have + # We need to figure this out from the metadata + if processed_videos: + # Get unique config names from metadata + config_names = set() + for metadata in enhanced_metadata: + config_names.add(metadata.get('video_info', {}).get('config_filename', 'Unknown')) + config_names = ['original'] + sorted(list(config_names)) + + # Get unique video names + video_names = set() + for metadata in enhanced_metadata: + # Extract original video name from the output filename + output_filename = metadata.get('video_info', {}).get('output_filename', '') + if output_filename: + # Try to extract video name - this depends on naming convention + # Assuming format like: configName_videoName_merged_5prompts.mp4 + parts = output_filename.replace('.mp4', '').split('_') + if len(parts) >= 2: + video_name = parts[1] # Second part should be video name + video_names.add(video_name) + video_names = sorted(list(video_names)) + + print(f" Detected {len(video_names)} videos and {len(config_names)} configs (including original)") + print(f" Flipped layout: {len(video_names)} rows x {len(config_names)} columns") + + # Reorder streams for flipped layout + reordered_streams = [] + for video_name in video_names: + for config_name in config_names: + # Find the stream for this video+config combination + found = False + for i, metadata in enumerate(enhanced_metadata): + video_info = metadata.get('video_info', {}) + if (video_info.get('config_filename') == config_name or + (config_name == 'original' and 'original' in video_info.get('output_filename', ''))): + # Check if this is the right video + output_filename = video_info.get('output_filename', '') + if video_name in output_filename: + reordered_streams.append(input_streams[i]) + found = True + break + + if not found: + # Create placeholder for missing combination + placeholder_path = os.path.join(temp_dir, f"placeholder_{video_name}_{config_name}.mp4") + placeholder_text = f"MISSING_{config_name}_{video_name}".replace(' ', '_') + + try: + blank_input = ffmpeg.input( + f'color=c=gray:s={video_width}x{video_height}:d={min_duration}', + f='lavfi' + ) + + blank_with_text = ffmpeg.drawtext( + blank_input, + text=placeholder_text, + fontcolor='white', 
+ fontsize=24, + x='(w-text_w)/2', + y='(h-text_h)/2' + ) + + blank_output = ffmpeg.output(blank_with_text, placeholder_path, vcodec='libx264', crf=23) + ffmpeg.run(blank_output, overwrite_output=True, quiet=True) + + reordered_streams.append(ffmpeg.input(placeholder_path)) + except Exception as e: + print(f" Failed to create placeholder: {e}") + return None + + # Create rows from reordered streams + rows = [] + for row_idx in range(len(video_names)): + start_idx = row_idx * len(config_names) + end_idx = start_idx + len(config_names) + row_streams = reordered_streams[start_idx:end_idx] + + if len(row_streams) > 1: + row_combined = ffmpeg.filter(row_streams, 'hstack', inputs=len(row_streams)) + else: + row_combined = row_streams[0] + + rows.append(row_combined) + + # Combine rows vertically + if len(rows) > 1: + final_grid = ffmpeg.filter(rows, 'vstack', inputs=len(rows)) + else: + final_grid = rows[0] + else: + # Fallback to original logic if metadata parsing fails + print(" Warning: Could not parse metadata for flipped layout, using original grid") + rows = [] + for row_idx in range(rows_needed): + start_idx = row_idx * videos_per_row + end_idx = min(start_idx + videos_per_row, len(input_streams)) + row_streams = input_streams[start_idx:end_idx] + + if len(row_streams) > 1: + row_combined = ffmpeg.filter(row_streams, 'hstack', inputs=len(row_streams)) + else: + row_combined = row_streams[0] + + rows.append(row_combined) + + # Combine rows vertically + if len(rows) > 1: + final_grid = ffmpeg.filter(rows, 'vstack', inputs=len(rows)) + else: + final_grid = rows[0] + + # Trim to minimum duration and output + trimmed = ffmpeg.filter(final_grid, 'trim', duration=min_duration) + final_output = ffmpeg.output( + trimmed, + output_path, + vcodec='libx264', + crf=20, + preset='medium' + ) + + print(" ๐ŸŽฌ Rendering final enhanced video wall...") + ffmpeg.run(final_output, overwrite_output=True, quiet=True) + + print(f" โœ… Enhanced video wall created: {output_path}") + + # Clean up temporary files + print(" ๐Ÿงน Cleaning up temporary files...") + try: + import shutil + shutil.rmtree(temp_dir) + except Exception as e: + print(f" Warning: Could not clean up temp directory: {e}") + + return output_path + + except Exception as e: + print(f" โŒ Error creating enhanced video wall: {e}") + import traceback + traceback.print_exc() + return None + +def main(): + """Example usage of enhanced video wall creation.""" + import argparse + + parser = argparse.ArgumentParser(description="Create enhanced video wall with JSON metadata") + parser.add_argument("--results_dir", required=True, help="Directory containing video results and JSON metadata") + parser.add_argument("--output", required=True, help="Output path for video wall") + parser.add_argument("--grid_width", type=int, default=3, help="Number of videos per row") + parser.add_argument("--video_width", type=int, default=512, help="Width of each video") + parser.add_argument("--video_height", type=int, default=512, help="Height of each video") + + args = parser.parse_args() + + result = create_enhanced_video_wall( + args.results_dir, + args.output, + args.grid_width, + args.video_width, + args.video_height + ) + + if result: + print(f"\nโœ… Enhanced video wall created successfully: {result}") + return 0 + else: + print(f"\nโŒ Failed to create enhanced video wall") + return 1 + +if __name__ == "__main__": + exit(main()) + + + diff --git a/multi_test/multi_test.py b/multi_test/multi_test.py new file mode 100644 index 00000000..e736a374 --- /dev/null +++ 
b/multi_test/multi_test.py @@ -0,0 +1,2076 @@ +#!/usr/bin/env python3 +""" +StreamDiffusion Multi-Config Test Suite + +This script processes multiple videos with multiple YAML configurations, +similar to main.py but for batch testing. It can use individual prompts +from a text file or config prompts. + +Key Features: +- Memory-efficient processing with automatic cleanup between configs +- One merged video output per config (combining all prompt segments) +- Real-time memory monitoring and cleanup +- Pipeline reset between configs to prevent memory issues + +Usage: + python multi_test.py --configs ./configs --videos ./videos --output ./results + python multi_test.py --configs ./configs --videos ./videos --prompts ./prompts.txt --output ./results + python multi_test.py --configs ./configs --videos ./videos --output ./results --timeout_seconds 600 + +Based on the StreamDiffusion framework and main.py architecture. +""" + +import os +import datetime +import sys +import time +import yaml +import argparse +import signal +import atexit +import subprocess +import csv +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Set + +try: + import fire +except ImportError: + print("Error: 'fire' package not found. Please install it with: pip install fire") + sys.exit(1) + +try: + import ffmpeg +except ImportError: + print("Warning: ffmpeg-python library not found. Video wall creation will be disabled.") + print("To enable video wall creation, install it with: pip install ffmpeg-python") + ffmpeg = None + +# Import enhanced video wall functions if available +try: + from enhanced_video_wall import create_enhanced_video_with_metadata +except ImportError: + print("Warning: Enhanced video wall module not found. Will use fallback video processing.") + create_enhanced_video_with_metadata = None + +import torch +from torchvision.io import read_video, write_video +from torchvision.transforms import functional as F +from tqdm import tqdm + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) + +from streamdiffusion import StreamDiffusionWrapper, load_config, create_wrapper_from_config + +# Global cleanup flag +_cleanup_completed = False + +def signal_handler(signum, frame): + """Handle system signals to ensure cleanup before exit.""" + print(f"\nReceived signal {signum}, cleaning up...") + cleanup_and_exit() + sys.exit(1) + +def cleanup_and_exit(): + """Ensure cleanup is performed before exit.""" + global _cleanup_completed + if not _cleanup_completed: + print("Performing final cleanup...") + try: + # Multiple rounds of cleanup to ensure everything is freed + for cleanup_round in range(3): + cleanup_gpu_memory() + + # Force garbage collection + import gc + gc.collect() + + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + + # Small delay between cleanup rounds + import time + time.sleep(0.1) + + except Exception as e: + print(f"Warning: Final cleanup failed: {e}") + _cleanup_completed = True + print("Cleanup completed.") + +# Register signal handlers and exit handler +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) +atexit.register(cleanup_and_exit) + +def cleanup_gpu_memory(): + """Thorough GPU memory cleanup.""" + try: + if torch.cuda.is_available(): + # Clear PyTorch cache + torch.cuda.empty_cache() + + # Synchronize to ensure all operations are complete + torch.cuda.synchronize() + + # Force garbage collection + import gc + gc.collect() + + # Log memory after cleanup + memory_info = 
get_memory_info() + if memory_info: + print(f" Memory after cleanup: GPU allocated: {memory_info['gpu_allocated']:.2f} GB, " + f"reserved: {memory_info['gpu_reserved']:.2f} GB, free: {memory_info['gpu_free']:.2f} GB") + except Exception as e: + print(f" Warning: Memory cleanup failed: {e}") + pass + +def cleanup_pipeline(pipeline): + """Properly cleanup a pipeline and free VRAM using StreamDiffusion's built-in cleanup""" + if pipeline is None: + return + + try: + print(" Starting pipeline cleanup...") + + # Use StreamDiffusion's built-in cleanup method which properly handles: + # - TensorRT engine cleanup + # - ControlNet engine cleanup + # - Multiple garbage collection cycles + # - CUDA cache clearing + # - Memory tracking + if hasattr(pipeline, 'stream') and pipeline.stream and hasattr(pipeline.stream, 'cleanup_gpu_memory'): + pipeline.stream.cleanup_gpu_memory() + print(" Pipeline cleanup completed using StreamDiffusion cleanup") + elif hasattr(pipeline, 'cleanup_gpu_memory') and callable(getattr(pipeline, 'cleanup_gpu_memory')): + pipeline.cleanup_gpu_memory() + print(" Pipeline cleanup completed using pipeline cleanup method") + elif hasattr(pipeline, 'cleanup') and callable(getattr(pipeline, 'cleanup')): + pipeline.cleanup() + print(" Pipeline cleanup completed using generic cleanup method") + else: + # Fallback cleanup if the method doesn't exist + print(" StreamDiffusion cleanup method not found, using fallback cleanup") + if hasattr(pipeline, 'stream') and pipeline.stream: + del pipeline.stream + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + except Exception as e: + print(f" Error during pipeline cleanup: {e}") + # Still try to clear CUDA cache even if cleanup fails + if torch.cuda.is_available(): + torch.cuda.empty_cache() + +def preprocess_video(input_video_path: str, target_width: int, target_height: int) -> torch.Tensor: + """Memory-efficient video preprocessing to target resolution, maintaining aspect ratio.""" + print(f"Preprocessing video: {input_video_path}") + print(f" Target resolution: {target_width}x{target_height}") + + # Load video metadata first to check size + video_data, _, info = read_video(input_video_path, pts_unit='sec') + original_fps = info["video_fps"] + num_frames = video_data.shape[0] + print(f" Original FPS: {original_fps}") + print(f" Loaded video shape: {video_data.shape}") + + # Calculate memory usage and warn if large + estimated_memory_gb = (num_frames * target_height * target_width * 3 * 4) / (1024**3) # 4 bytes per float32 + print(f" Estimated memory usage: {estimated_memory_gb:.2f} GB") + + if estimated_memory_gb > 4.0: + print(f" โš ๏ธ WARNING: Large video detected! 
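# Illustrative arithmetic for the float32 RGB estimate printed above, with hypothetical
# numbers: a 30 s clip at 30 fps (900 frames) preprocessed to 512x512 is roughly 2.6 GB.
def estimate_preprocessed_size_gb(num_frames: int, height: int, width: int) -> float:
    # 3 channels x 4 bytes per float32 element, reported in GiB
    return (num_frames * height * width * 3 * 4) / (1024 ** 3)

assert round(estimate_preprocessed_size_gb(900, 512, 512), 2) == 2.64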
Consider using batch processing for videos > 4GB") + + # Calculate resize parameters once + original_height, original_width = video_data.shape[1], video_data.shape[2] + original_aspect = original_width / original_height + target_aspect = target_width / target_height + + if original_aspect > target_aspect: + scale_height = target_height + scale_width = int(scale_height * original_aspect) + else: + scale_width = target_width + scale_height = int(scale_width / original_aspect) + + print(f" Resizing and cropping frames...") + + # Pre-allocate output tensor to avoid memory fragmentation + resized_video = torch.zeros(num_frames, target_height, target_width, 3, dtype=torch.float32) + + # Process frames in smaller batches to reduce peak memory usage + batch_size = min(50, num_frames) # Process 50 frames at a time + + for batch_start in tqdm(range(0, num_frames, batch_size), desc=" Processing batches"): + batch_end = min(batch_start + batch_size, num_frames) + + # Process batch of frames + for i in range(batch_start, batch_end): + # Convert to float and normalize (in-place to save memory) + frame = video_data[i].float() / 255.0 # Shape: (H, W, C) + frame_chw = frame.permute(2, 0, 1) + + # Resize maintaining aspect ratio + resized_frame_chw = F.resize(frame_chw, [scale_height, scale_width], antialias=True) + cropped_frame_chw = F.center_crop(resized_frame_chw, [target_height, target_width]) + final_frame = cropped_frame_chw.permute(1, 2, 0) + + # Store directly in pre-allocated tensor + resized_video[i] = final_frame + + # Clean up intermediate tensors + del frame, frame_chw, resized_frame_chw, cropped_frame_chw, final_frame + + # Force garbage collection after each batch + import gc + gc.collect() + + # Clean up original video data + del video_data + gc.collect() + + print(f" Final processed video shape: {resized_video.shape}") + print(f" Memory cleanup completed") + return resized_video + +def load_prompts(prompts_file: str) -> List[str]: + """Load prompts from text file.""" + with open(prompts_file, 'r', encoding='utf-8') as f: + prompts = [line.strip() for line in f.readlines() if line.strip()] + print(f"Loaded {len(prompts)} prompts from {prompts_file}") + return prompts + +def scan_completed_work(resume_dir: str) -> List[Dict]: + """ + Load existing results from CSV and JSON metadata if available. 
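# Sketch of the prompts file load_prompts() expects: one prompt per line, blank lines
# skipped. The filename and prompt text here are made up for illustration.
from pathlib import Path

Path("prompts.txt").write_text(
    "a watercolor painting of a rainy city street\n"
    "\n"
    "a neon-lit cyberpunk alleyway at night\n",
    encoding="utf-8",
)
assert load_prompts("prompts.txt") == [
    "a watercolor painting of a rainy city street",
    "a neon-lit cyberpunk alleyway at night",
]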
+ + Parameters + ---------- + resume_dir : str + Path to existing output directory to resume from + + Returns + ------- + List[Dict] + List of existing results (both successful and failed from CSV if available) + """ + print(f"\n๐Ÿ” Scanning for completed work in: {resume_dir}") + + if not os.path.exists(resume_dir): + print(f"โŒ Resume directory does not exist: {resume_dir}") + return [] + + existing_results = [] + json_metadata = {} # Store JSON metadata by video filename + + # First, scan for JSON metadata files + print(f"๐Ÿ“‹ Scanning for JSON metadata files...") + try: + import json + json_files = [f for f in os.listdir(resume_dir) if f.endswith('_metadata.json')] + for json_file in json_files: + json_path = os.path.join(resume_dir, json_file) + try: + with open(json_path, 'r', encoding='utf-8') as f: + metadata = json.load(f) + # Use output filename as key for easy lookup + output_filename = metadata.get('video_info', {}).get('output_filename', '') + if output_filename: + json_metadata[output_filename] = metadata + except Exception as e: + print(f" โš ๏ธ Warning: Could not load JSON metadata {json_file}: {e}") + + print(f" Found {len(json_metadata)} JSON metadata files") + except Exception as e: + print(f" โš ๏ธ Warning: Error scanning JSON files: {e}") + + # Try to load existing results from CSV + csv_path = os.path.join(resume_dir, "detailed_results.csv") + if os.path.exists(csv_path): + print(f"๐Ÿ“Š Loading existing results from CSV: {csv_path}") + try: + import csv + with open(csv_path, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + for row in reader: + # Load both successful AND failed results to preserve all data + is_successful = row['Success'] == 'Yes' + + if is_successful: + # Reconstruct successful result dict + result_dict = { + 'config': row['Config'], + 'video': row['Video'], + 'model_id': row['Model ID'], + 'resolution': row['Resolution'], + 'total_frames': int(row['Total Frames']) if row['Total Frames'].isdigit() else 0, + 'prompts_used': int(row['Prompts Used']) if row['Prompts Used'].isdigit() else 1, + 'success': True, + 'output_file': row['Output File'], + 'fps_metrics': { + 'overall_fps': float(row['Overall FPS']) if row['Overall FPS'] != 'N/A' else 0, + 'min_fps': float(row['Min FPS']) if row['Min FPS'] != 'N/A' else 0, + 'max_fps': float(row['Max FPS']) if row['Max FPS'] != 'N/A' else 0, + 'avg_fps': float(row['Avg FPS']) if row['Avg FPS'] != 'N/A' else 0, + 'std_dev_fps': float(row['Std Dev FPS']) if row['Std Dev FPS'] != 'N/A' else 0, + 'cv_percent': float(row['CV %']) if row['CV %'] != 'N/A' else 0 + } + } + + # Enhance with JSON metadata if available + output_file = row['Output File'] + if output_file in json_metadata: + result_dict['json_metadata'] = json_metadata[output_file] + print(f" โœ… Enhanced {output_file} with JSON metadata") + + existing_results.append(result_dict) + else: + # Reconstruct failed result dict + existing_results.append({ + 'config': row['Config'], + 'video': row['Video'], + 'model_id': row['Model ID'], + 'resolution': row['Resolution'], + 'total_frames': int(row['Total Frames']) if row['Total Frames'].isdigit() else 0, + 'prompts_used': int(row['Prompts Used']) if row['Prompts Used'].isdigit() else 1, + 'success': False, + 'error': row['Error Message'] if row['Error Message'] else 'Unknown error' + }) + + successful_count = sum(1 for r in existing_results if r['success']) + failed_count = len(existing_results) - successful_count + print(f"โœ… Loaded {len(existing_results)} total results from CSV:") + print(f" - 
{successful_count} successful") + print(f" - {failed_count} failed") + except Exception as e: + print(f"โš ๏ธ Warning: Could not load CSV results: {e}") + else: + print(f"๐Ÿ“Š No existing CSV found - will create new results") + + return existing_results + +def check_output_exists(output_dir: str, config_filename: str, video_filename: str, config: Dict, prompts: Optional[List[str]] = None, existing_results: Optional[List[Dict]] = None, retry_failed: bool = False) -> bool: + """ + Check if output file already exists for this config+video combination. + + Parameters + ---------- + output_dir : str + Output directory + config_filename : str + Config filename (without extension) + video_filename : str + Video filename (without extension) + config : Dict + Configuration dictionary + prompts : Optional[List[str]] + List of prompts (to determine filename format) + existing_results : Optional[List[Dict]] + List of existing results to check against (for resume functionality) + retry_failed : bool, optional + Whether to retry previously failed combinations, by default False + + Returns + ------- + bool + True if output file already exists or combination was already processed + """ + # First check if this combination was already processed (from loaded CSV data) + if existing_results: + for result in existing_results: + if result['config'] == config_filename and result['video'] == video_filename: + if result['success']: + print(f" โœ… Combination already completed successfully: {config_filename} + {video_filename}") + return True + else: + if retry_failed: + print(f" ๐Ÿ”„ Retrying previously failed combination: {config_filename} + {video_filename} (Previous error: {result.get('error', 'Unknown')})") + return False # Allow retry + else: + print(f" โš ๏ธ Combination previously failed: {config_filename} + {video_filename} (Error: {result.get('error', 'Unknown')})") + return True # Skip retry + + # Then check if output file exists on disk + # Generate the expected output filename using the same logic as process_video_with_config + config_name = config.get('model_id', 'unknown') + # Clean up the config name to make it filesystem-safe + if '/' in config_name: + config_name = config_name.split('/')[-1] + if '\\' in config_name: + config_name = config_name.split('\\')[-1] + # Remove file extensions + config_name = config_name.replace('.safetensors', '').replace('.ckpt', '').replace('.pth', '') + + # Create expected filename + num_prompts = len(prompts) if prompts else 1 + expected_filename = f"{config_filename}_{video_filename}_{config_name}_merged_{num_prompts}prompts.mp4" + expected_path = os.path.join(output_dir, expected_filename) + + exists = os.path.exists(expected_path) + if exists: + print(f" โœ… Output file already exists: {expected_filename}") + + return exists + +def process_video_with_config( + video: torch.Tensor, + config: Dict, + prompts: Optional[List[str]] = None, + output_dir: str = "./output", + config_filename: str = "unknown_config", + video_filename: str = "unknown_video", + timeout_seconds: int = 600 # 10 minutes timeout per video +) -> Optional[Dict]: + """Process a video with a config, optionally using custom prompts with temporal splitting. 
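# Sketch of the on-disk filename check_output_exists() reconstructs, using hypothetical
# config/video names; the model_id cleanup mirrors the logic above.
model_id_example = "stabilityai/sd-turbo"
config_name_example = model_id_example.split("/")[-1].split("\\")[-1]
for ext in (".safetensors", ".ckpt", ".pth"):
    config_name_example = config_name_example.replace(ext, "")
expected = f"sdturbo_multicontrol_dance_clip_{config_name_example}_merged_4prompts.mp4"
# -> "sdturbo_multicontrol_dance_clip_sd-turbo_merged_4prompts.mp4"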
+ + Parameters + ---------- + video : torch.Tensor + Input video tensor + config : Dict + Configuration dictionary + prompts : Optional[List[str]], optional + List of prompts for temporal splitting, by default None + output_dir : str, optional + Output directory for results, by default "./output" + config_filename : str, optional + Name of the config file (without extension) for output filename, by default "unknown_config" + video_filename : str, optional + Name of the video file (without extension) for output filename, by default "unknown_video" + timeout_seconds : int, optional + Maximum time to spend processing this video, by default 600 (10 minutes) + """ + + print(f"\nProcessing with config: {config.get('model_id', 'Unknown')}") + print(f" Timeout set to {timeout_seconds} seconds") + + # Track start time for timeout + start_time = time.time() + + # Clean GPU state before building pipeline + cleanup_gpu_memory() + log_memory_usage("before pipeline creation") + + stream = None + try: + # Check timeout before starting + if time.time() - start_time > timeout_seconds: + raise TimeoutError(f"Timeout exceeded before starting processing") + + # Create wrapper using config system + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch_dtype = torch.float16 + + overrides = { + 'device': device, + 'dtype': torch_dtype, + 'output_type': 'pt', + } + + print(" Creating pipeline...") + stream = create_wrapper_from_config(config, **overrides) + log_memory_usage("after pipeline creation") + + if stream is None: + raise RuntimeError("Failed to create pipeline - stream is None") + + # Check timeout after pipeline creation + if time.time() - start_time > timeout_seconds: + raise TimeoutError(f"Timeout exceeded after pipeline creation") + + # Debug ControlNet setup + print(f" Stream created successfully") + if hasattr(stream, 'preprocessors'): + print(f" ControlNet preprocessors found: {len(stream.preprocessors)}") + for idx, preproc in enumerate(stream.preprocessors): + if preproc: + print(f" Preprocessor {idx}: {preproc.__class__.__name__}") + if hasattr(preproc, 'params'): + print(f" Params: {preproc.params}") + else: + print(f" Preprocessor {idx}: None") + else: + print(f" No ControlNet preprocessors found on stream") + + # Check if ControlNet images are available + if hasattr(stream, 'controlnet_images'): + print(f" ControlNet images available: {len(stream.controlnet_images)}") + for idx, img in enumerate(stream.controlnet_images): + if img is not None: + print(f" ControlNet {idx} image shape: {img.shape if hasattr(img, 'shape') else 'Unknown'}") + else: + print(f" ControlNet {idx} image: None") + + # Check what ControlNet methods are available + controlnet_methods = [] + if hasattr(stream, 'update_control_image'): + controlnet_methods.append('update_control_image') + if hasattr(stream, 'update_control_image_efficient'): + controlnet_methods.append('update_control_image_efficient') + if hasattr(stream, 'stream') and hasattr(stream.stream, 'update_control_image'): + controlnet_methods.append('stream.update_control_image') + if hasattr(stream, 'stream') and hasattr(stream.stream, 'update_control_image_efficient'): + controlnet_methods.append('stream.update_control_image_efficient') + + print(f" Available ControlNet methods: {controlnet_methods}") + + # Check if we have a nested stream structure + if hasattr(stream, 'stream'): + print(f" Stream has nested stream object") + if hasattr(stream.stream, 'preprocessors'): + print(f" Nested stream has {len(stream.stream.preprocessors)} 
preprocessors") + else: + print(f" Stream is direct (no nested structure)") + + # Get base prompt from config if no custom prompts + if not prompts: + prompt_config = config.get('prompt_blending', {}) + if isinstance(prompt_config, dict) and 'prompt_list' in prompt_config: + first_prompt = prompt_config['prompt_list'][0][0] if prompt_config['prompt_list'] else "a beautiful landscape" + else: + first_prompt = config.get('prompt', 'a beautiful landscape') + prompts = [first_prompt] + + # Calculate frames per prompt for temporal splitting + total_frames = video.shape[0] + frames_per_prompt = total_frames // len(prompts) + remaining_frames = total_frames % len(prompts) + + print(f" Total frames: {total_frames}, Frames per prompt: {frames_per_prompt}") + print(f" Remaining frames: {remaining_frames} (will be distributed to first prompts)") + + # Process each prompt against its time segment and accumulate results + fps_metrics = [] # Track FPS for each segment + segment_times = [] # Track actual processing time for each segment + all_output_frames = [] # Accumulate all segments for final merged video + + for i, prompt in enumerate(prompts): + # Check timeout before processing each prompt + if time.time() - start_time > timeout_seconds: + raise TimeoutError(f"Timeout exceeded while processing prompt {i+1}") + + print(f" Processing prompt {i+1}/{len(prompts)}: '{prompt[:50]}...'") + + try: + # Calculate frame range for this prompt + start_frame = i * frames_per_prompt + end_frame = start_frame + frames_per_prompt + + # Distribute remaining frames to first prompts + if i < remaining_frames: + end_frame += 1 + + # Get frames for this time segment + segment_frames = video[start_frame:end_frame] + print(f" Processing frames {start_frame+1}-{end_frame} ({len(segment_frames)} frames)") + + # Update stream with new prompt (no pipeline restart needed) + stream.update_prompt(prompt) + + # Prepare the stream if this is the first prompt + if i == 0: + stream.prepare( + prompt=prompt, + negative_prompt=config.get('negative_prompt', ''), + num_inference_steps=config.get('num_inference_steps', 35), + guidance_scale=config.get('guidance_scale', 1.5), + ) + + # Process frames for this time segment + print(" Processing frames...") + segment_start_time = time.time() + + # Create output tensor for this segment + height, width = segment_frames.shape[1], segment_frames.shape[2] + segment_result = torch.zeros(len(segment_frames), height, width, 3, dtype=torch.float32) + + # Warmup on first frame if this is the first prompt + if i == 0: + print(" Warming up...") + try: + for _ in range(min(stream.batch_size, 3)): # Limit warmup to prevent memory issues + warmup_result = stream(image=segment_frames[0].permute(2, 0, 1)) + if warmup_result is None: + print(" Warning: Warmup returned None") + except Exception as e: + print(f" Warning: Warmup failed: {e}") + + # Process frames for this time segment + for j in tqdm(range(len(segment_frames)), desc=" Processing frames"): + # Check timeout periodically during frame processing + if j % 10 == 0 and time.time() - start_time > timeout_seconds: + raise TimeoutError(f"Timeout exceeded while processing frame {j}") + + try: + # Get the input frame + input_frame = segment_frames[j].permute(2, 0, 1) + + # Apply ControlNet preprocessing if available + if hasattr(stream, 'preprocessors') and stream.preprocessors: + # Convert frame to PIL Image for ControlNet preprocessing + import torchvision.transforms.functional as F + frame_pil = F.to_pil_image(input_frame) + + # Update control image 
for each ControlNet - call directly on the wrapper + for cn_idx in range(len(stream.preprocessors)): + if stream.preprocessors[cn_idx]: + try: + stream.update_control_image(index=cn_idx, image=frame_pil) + except Exception as e: + print(f" Warning: ControlNet {cn_idx} update failed: {e}") + elif hasattr(stream, 'stream') and hasattr(stream.stream, 'preprocessors') and stream.stream.preprocessors: + # Handle nested stream structure - still call update_control_image on the wrapper + import torchvision.transforms.functional as F + frame_pil = F.to_pil_image(input_frame) + + # Update control image for each nested ControlNet - call on wrapper, not nested stream + for cn_idx in range(len(stream.stream.preprocessors)): + if stream.stream.preprocessors[cn_idx]: + try: + stream.update_control_image(index=cn_idx, image=frame_pil) + except Exception as e: + print(f" Warning: Nested ControlNet {cn_idx} update failed: {e}") + + # Process frame through the stream - ControlNet preprocessing has been applied above + output_image = stream(image=input_frame) + + if output_image is None: + print(f" Warning: Frame {j} returned None, skipping") + continue + + # Handle batch dimension if present + if output_image.dim() == 4: + segment_result[j] = output_image.squeeze(0).permute(1, 2, 0).clamp(0, 1) + elif output_image.dim() == 3: + segment_result[j] = output_image.permute(1, 2, 0).clamp(0, 1) + else: + print(f" Warning: unexpected tensor dimensions: {output_image.shape}") + continue + + except Exception as e: + print(f" Error processing frame {j}: {e}") + # Continue with next frame instead of failing completely + continue + + processing_time = time.time() - segment_start_time + effective_fps = len(segment_frames) / processing_time + fps_metrics.append(effective_fps) + segment_times.append(processing_time) # Store actual processing time + print(f" Processed {len(segment_frames)} frames in {processing_time:.2f}s ({effective_fps:.2f} FPS)") + + # Add segment frames to overall result for final merged video + all_output_frames.append(segment_result) + + # Clean up segment processing memory + del segment_result + import gc + gc.collect() + + # Clean up GPU memory after each segment + cleanup_gpu_memory() + log_memory_usage(f"after segment {i+1} completion") + + except Exception as e: + print(f" ERROR processing prompt {i+1}: {e}") + import traceback + traceback.print_exc() + # Continue with next prompt instead of failing completely + continue + + if not all_output_frames: + raise RuntimeError("No segments were processed successfully") + + # Combine all segments into final merged video + print(" Combining all prompt segments...") + final_video = torch.cat(all_output_frames, dim=0) + + # Save final merged video with unique name per config + config_name = config.get('model_id', 'unknown') + # Clean up the config name to make it filesystem-safe + if '/' in config_name: + config_name = config_name.split('/')[-1] + if '\\' in config_name: + config_name = config_name.split('\\')[-1] + # Remove file extensions + config_name = config_name.replace('.safetensors', '').replace('.ckpt', '').replace('.pth', '') + + # Create unique filename for this config and video (merged from all prompts) + # Include config filename, model_id, and video name for clear identification + output_filename = f"{config_filename}_{video_filename}_{config_name}_merged_{len(prompts)}prompts.mp4" + + # Clean filename to ensure it's filesystem-safe + import re + output_filename = re.sub(r'[<>:"/\\|?*]', '_', output_filename) # Replace invalid chars + 
output_filename = output_filename[:200] + '.mp4' if len(output_filename) > 200 else output_filename # Limit length + + output_video_path = os.path.join(output_dir, output_filename) + + # Ensure output directory exists before writing video + os.makedirs(output_dir, exist_ok=True) + print(f" Saving video to: {output_video_path}") + print(f" Output directory: {output_dir}") + print(f" Directory exists: {os.path.exists(output_dir)}") + + # Convert to uint8 and save + final_video_uint8 = (final_video * 255).clamp(0, 255).to(torch.uint8) + + try: + write_video(output_video_path, final_video_uint8, fps=30) + print(f" โœ… Saved merged video: {output_video_path}") + except Exception as video_error: + print(f" โŒ Failed to save video: {video_error}") + print(f" Output path: {output_video_path}") + print(f" Path length: {len(output_video_path)}") + print(f" Parent dir exists: {os.path.exists(os.path.dirname(output_video_path))}") + print(f" Video shape: {final_video_uint8.shape}") + raise video_error + finally: + # CRITICAL: Clean up large video tensors immediately after saving + print(" Cleaning up video tensors from system RAM...") + try: + del final_video_uint8 + del final_video + del all_output_frames + # Force immediate garbage collection + import gc + gc.collect() + print(" โœ… Video tensors cleaned from system RAM") + except Exception as cleanup_err: + print(f" โš ๏ธ Warning: Video tensor cleanup failed: {cleanup_err}") + + # Calculate overall FPS metrics CORRECTLY + total_processing_time = sum(segment_times) # Sum of actual processing times + overall_fps = total_frames / total_processing_time if total_processing_time > 0 else 0 + min_fps = min(fps_metrics) if fps_metrics else 0 + max_fps = max(fps_metrics) if fps_metrics else 0 + avg_fps = sum(fps_metrics) / len(fps_metrics) if fps_metrics else 0 + + # Calculate consistency metrics + if len(fps_metrics) > 1: + variance = sum((fps - avg_fps) ** 2 for fps in fps_metrics) / len(fps_metrics) + std_dev_fps = variance ** 0.5 + cv_percent = (std_dev_fps / avg_fps) * 100 if avg_fps > 0 else 0 + else: + std_dev_fps = 0 + cv_percent = 0 + + print(f" Overall Performance:") + print(f" Total processing time: {total_processing_time:.2f}s") + print(f" Overall FPS: {overall_fps:.2f}") + print(f" FPS range: {min_fps:.2f} - {max_fps:.2f}") + print(f" Average FPS: {avg_fps:.2f}") + print(f" Standard Deviation: {std_dev_fps:.2f}") + print(f" Coefficient of Variation: {cv_percent:.1f}%") + + # Create comprehensive metadata for JSON storage + video_metadata = { + 'video_info': { + 'config_filename': config_filename, + 'video_filename': video_filename, + 'config_name': config_name, + 'output_filename': output_filename, + 'output_path': output_video_path, + 'total_frames': total_frames, + 'prompts_used': len(prompts), + 'prompts': prompts, + 'processing_date': datetime.datetime.now().isoformat(), + }, + 'config_details': { + 'model_id': config.get('model_id', 'Unknown'), + 'width': config.get('width', 'Unknown'), + 'height': config.get('height', 'Unknown'), + 'num_inference_steps': config.get('num_inference_steps', 'Unknown'), + 'guidance_scale': config.get('guidance_scale', 'Unknown'), + 'negative_prompt': config.get('negative_prompt', ''), + }, + 'performance_metrics': { + 'overall_fps': overall_fps, + 'min_fps': min_fps, + 'max_fps': max_fps, + 'avg_fps': avg_fps, + 'std_dev_fps': std_dev_fps, + 'cv_percent': cv_percent, + 'segment_fps': fps_metrics, + 'segment_times': segment_times, + 'total_processing_time': total_processing_time, + 'segments_processed': 
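# Worked example of the per-segment FPS bookkeeping above, with made-up numbers:
# two 150-frame segments processed in 10 s and 12 s.
frames_per_segment = [150, 150]
times_per_segment = [10.0, 12.0]
fps_per_segment = [f / t for f, t in zip(frames_per_segment, times_per_segment)]  # [15.0, 12.5]
overall_fps_example = sum(frames_per_segment) / sum(times_per_segment)            # 300 / 22 ~= 13.64
avg_fps_example = sum(fps_per_segment) / len(fps_per_segment)                     # 13.75
var_example = sum((f - avg_fps_example) ** 2 for f in fps_per_segment) / len(fps_per_segment)
cv_example = (var_example ** 0.5 / avg_fps_example) * 100                         # ~= 9.1 %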
len(fps_metrics) + }, + 'technical_details': { + 'timeout_seconds': timeout_seconds, + 'start_time': start_time, + 'end_time': time.time(), + 'success': True + } + } + + # Save metadata as JSON file alongside video + json_filename = output_filename.replace('.mp4', '_metadata.json') + json_path = os.path.join(output_dir, json_filename) + + try: + import json + with open(json_path, 'w', encoding='utf-8') as f: + json.dump(video_metadata, f, indent=2, ensure_ascii=False) + print(f" โœ… Saved metadata: {json_filename}") + except Exception as json_error: + print(f" โš ๏ธ Warning: Failed to save metadata JSON: {json_error}") + + # Return result with FPS metrics and output file + return { + 'output_file': output_filename, # Just the filename, not full path + 'metadata_file': json_filename, # JSON metadata filename + 'fps_metrics': { + 'overall_fps': overall_fps, + 'min_fps': min_fps, + 'max_fps': max_fps, + 'avg_fps': avg_fps, + 'std_dev_fps': std_dev_fps, + 'cv_percent': cv_percent, + 'segment_fps': fps_metrics, + 'segment_times': segment_times, # Add segment times for debugging + 'total_processing_time': total_processing_time + } + } + + except TimeoutError as e: + print(f" TIMEOUT ERROR: {e}") + return None + except Exception as e: + print(f" ERROR processing: {e}") + import traceback + traceback.print_exc() + return None + + finally: + # Always cleanup, even if there was an error + print(" Cleaning up pipeline...") + try: + if stream is not None: + # Use the dedicated cleanup function + cleanup_pipeline(stream) + stream = None + + except Exception as cleanup_error: + print(f" Warning: Cleanup failed: {cleanup_error}") + finally: + # Force cleanup regardless of any errors + cleanup_gpu_memory() + print(" GPU memory cleanup completed") + +def get_memory_info() -> Dict[str, float]: + """Get current GPU and system memory information.""" + memory_info = {} + + # GPU memory + if torch.cuda.is_available(): + memory_info['gpu_allocated'] = torch.cuda.memory_allocated() / (1024**3) # GB + memory_info['gpu_reserved'] = torch.cuda.memory_reserved() / (1024**3) # GB + memory_info['gpu_free'] = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved()) / (1024**3) # GB + + # System memory + try: + import psutil + memory_info['system_ram_used'] = psutil.virtual_memory().used / (1024**3) # GB + memory_info['system_ram_available'] = psutil.virtual_memory().available / (1024**3) # GB + memory_info['system_ram_percent'] = psutil.virtual_memory().percent + except ImportError: + # psutil not available, use basic info + import os + if hasattr(os, 'sysconf'): + try: + memory_info['system_ram_used'] = 'N/A (psutil not available)' + except: + pass + + return memory_info + +def log_memory_usage(stage: str): + """Log current memory usage for debugging.""" + memory_info = get_memory_info() + if memory_info: + gpu_info = f"GPU allocated: {memory_info['gpu_allocated']:.2f} GB, reserved: {memory_info['gpu_reserved']:.2f} GB, free: {memory_info['gpu_free']:.2f} GB" + + if 'system_ram_used' in memory_info and memory_info['system_ram_used'] != 'N/A (psutil not available)': + ram_info = f"RAM used: {memory_info['system_ram_used']:.2f} GB, available: {memory_info['system_ram_available']:.2f} GB ({memory_info['system_ram_percent']:.1f}%)" + print(f" Memory usage at {stage}: {gpu_info}, {ram_info}") + else: + print(f" Memory usage at {stage}: {gpu_info}") + +def create_video_with_text(input_path: str, output_path: str, text: str, + width: int, height: int, fontcolor: str = 'white') -> bool: + """ + Create 
a scaled video with text overlay using ffmpeg-python + + Args: + input_path: Path to input video + output_path: Path to output video + text: Text to overlay + width: Target width + height: Target height + fontcolor: Color of the text + + Returns: + True if successful, False otherwise + """ + if ffmpeg is None: + return False + + try: + # Create the processing pipeline + stream = ffmpeg.input(input_path) + + # Scale video to target size with padding + scaled = ffmpeg.filter( + stream, + 'scale', + width, height, + force_original_aspect_ratio='decrease' + ) + + padded = ffmpeg.filter( + scaled, + 'pad', + width, height, + '(ow-iw)/2', '(oh-ih)/2' + ) + + # Add text overlay - specify font file to avoid fontconfig issues on Windows + font_path = r'C:/Windows/Fonts/arial.ttf' + if os.path.exists(font_path): + with_text = ffmpeg.drawtext( + padded, + text=text, + fontfile=font_path, + fontcolor=fontcolor, + fontsize=20, + box=1, + boxcolor='black@0.8', + boxborderw=5, + x=10, + y='h-th-10' + ) + else: + # Fallback: try without fontfile (may use system default) + with_text = ffmpeg.drawtext( + padded, + text=text, + fontcolor=fontcolor, + fontsize=20, + box=1, + boxcolor='black@0.8', + boxborderw=5, + x=10, + y='h-th-10' + ) + + # Output with encoding settings + output = ffmpeg.output( + with_text, + output_path, + vcodec='libx264', + crf=23, + preset='medium' + ) + + # Run the pipeline with verbose error output + ffmpeg.run(output, overwrite_output=True, quiet=True) + return True + + except Exception as e: + print(f" Error processing video: {e}") + return False + +def create_placeholder_video(output_path: str, text: str, width: int, height: int, duration: float) -> bool: + """ + Create a placeholder video with text using ffmpeg-python + + Args: + output_path: Path to output video + text: Text to display + width: Video width + height: Video height + duration: Video duration in seconds + + Returns: + True if successful, False otherwise + """ + if ffmpeg is None: + return False + + try: + # Create gray color source + color_source = ffmpeg.input( + 'color=c=gray:s={}x{}:d={}'.format(width, height, duration), + f='lavfi' + ) + + # Add centered text - specify font file to avoid fontconfig issues on Windows + font_path = r'C:/Windows/Fonts/arial.ttf' + if os.path.exists(font_path): + with_text = ffmpeg.drawtext( + color_source, + text=text, + fontfile=font_path, + fontcolor='white', + fontsize=16, + box=1, + boxcolor='black@0.8', + boxborderw=5, + x='(w-text_w)/2', + y='(h-text_h)/2' + ) + else: + # Fallback: try without fontfile (may use system default) + with_text = ffmpeg.drawtext( + color_source, + text=text, + fontcolor='white', + fontsize=16, + box=1, + boxcolor='black@0.8', + boxborderw=5, + x='(w-text_w)/2', + y='(h-text_h)/2' + ) + + # Output + output = ffmpeg.output( + with_text, + output_path, + vcodec='libx264', + crf=23, + preset='medium' + ) + + ffmpeg.run(output, overwrite_output=True, quiet=True) + return True + + except Exception as e: + print(f" Error creating placeholder: {e}") + return False + +def create_video_wall( + results: List[Dict], + video_files: List[Path], + config_files: List[Path], + output_dir: str +) -> Optional[str]: + """ + Create a video wall showing original videos and processed results in a grid layout. 
+ + Layout: + - Top row: Original videos + - Subsequent rows: Processed videos for each config + + Parameters + ---------- + results : List[Dict] + List of processing results + video_files : List[Path] + List of original video files + config_files : List[Path] + List of config files used + output_dir : str + Output directory for the video wall + + Returns + ------- + Optional[str] + Path to the created video wall, or None if failed + """ + + print("\nCreating video wall...") + + if ffmpeg is None: + print(" Skipping video wall creation - ffmpeg-python not available") + return None + + # Filter successful results only + successful_results = [r for r in results if r.get('success', False)] + + if not successful_results: + print(" No successful results found for video wall") + return None + + # Extract unique video and config names from successful results + video_names = sorted(list(set([r['video'] for r in successful_results]))) + config_names = sorted(list(set([r['config'] for r in successful_results]))) + + print(f" Creating grid: {len(config_names)+1} rows x {len(video_names)} columns") + print(f" Videos: {video_names}") + print(f" Configs: {config_names}") + + # Create video wall output path + wall_output = os.path.join(output_dir, "video_wall.mp4") + + # Create temporary directory for processing + temp_dir = os.path.join(output_dir, "temp_video_wall") + os.makedirs(temp_dir, exist_ok=True) + + # Standard resolution for all videos in the wall + wall_video_width = 512 + wall_video_height = 512 + + try: + # Step 1: Process all videos + processed_videos = {} + min_duration = float('inf') + + # Get minimum duration first + print(" Getting video durations...") + all_video_paths = [] + + # Collect original video paths + for video_name in video_names: + for video_file in video_files: + if video_file.stem == video_name: + all_video_paths.append(str(video_file)) + break + + # Collect result video paths + for result in successful_results: + if 'output_file' in result and result['output_file']: + output_file_path = os.path.join(output_dir, result['output_file']) + if os.path.exists(output_file_path): + all_video_paths.append(output_file_path) + + # Get minimum duration + for video_path in all_video_paths: + try: + probe = ffmpeg.probe(video_path) + duration = float(probe['format']['duration']) + min_duration = min(min_duration, duration) + except Exception as e: + print(f" Warning: Could not get duration for {video_path}: {e}") + + if min_duration == float('inf') or min_duration < 1: + min_duration = 10 + + print(f" Using duration: {min_duration:.2f} seconds") + + # Process original videos + print(" Processing original videos...") + for video_name in video_names: + # Find the original video file + original_video_path = None + for video_file in video_files: + if video_file.stem == video_name: + original_video_path = str(video_file) + break + + if not original_video_path: + print(f" Warning: Original video not found for {video_name}") + continue + + scaled_path = os.path.join(temp_dir, f"scaled_original_{video_name.replace(' ', '_')}.mp4") + text_content = f"ORIGINAL_{video_name.replace(' ', '_')}" + + success = create_video_with_text( + original_video_path, + scaled_path, + text_content, + wall_video_width, + wall_video_height, + 'white' + ) + + if success: + processed_videos[('original', video_name)] = scaled_path + print(f" Processed original {video_name}") + else: + print(f" Failed to process original {video_name}") + + # Process result videos with enhanced metadata + print(" Processing result 
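# Sketch of the duration-matching step used above: probe every tile source with
# ffmpeg-python and clamp the wall to the shortest clip (paths are hypothetical).
def shortest_duration(paths, fallback=10.0):
    durations = []
    for p in paths:
        try:
            durations.append(float(ffmpeg.probe(p)["format"]["duration"]))
        except Exception:
            continue
    d = min(durations) if durations else fallback
    return d if d >= 1 else fallback

# e.g. shortest_duration(["videos/clip_a.mp4", "results/sd15_clip_a_merged_4prompts.mp4"])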
videos with enhanced metadata...") + for result in successful_results: + config_name = result['config'] + video_name = result['video'] + + # Find the output file + if 'output_file' not in result or not result['output_file']: + print(f" Warning: No output file for {config_name}_{video_name}") + continue + + output_file_path = os.path.join(output_dir, result['output_file']) + if not os.path.exists(output_file_path): + print(f" Warning: Output file not found: {output_file_path}") + continue + + scaled_path = os.path.join(temp_dir, f"scaled_{config_name}_{video_name.replace(' ', '_')}.mp4") + + # Create enhanced metadata for this result + metadata = { + 'video_info': { + 'config_filename': config_name, + 'output_filename': result['output_file'], + 'total_frames': result.get('total_frames', 0) + }, + 'config_details': { + 'model_id': result.get('model_id', 'Unknown'), + 'width': result.get('resolution', 'Unknown').split('x')[0] if 'x' in str(result.get('resolution', '')) else 'Unknown', + 'height': result.get('resolution', 'Unknown').split('x')[1] if 'x' in str(result.get('resolution', '')) else 'Unknown' + }, + 'performance_metrics': { + 'overall_fps': result.get('fps_metrics', {}).get('overall_fps', 0), + 'avg_fps': result.get('fps_metrics', {}).get('avg_fps', 0), + 'total_processing_time': result.get('fps_metrics', {}).get('total_processing_time', 0) + } + } + + if create_enhanced_video_with_metadata: + success = create_enhanced_video_with_metadata( + output_file_path, + scaled_path, + metadata, + wall_video_width, + wall_video_height + ) + else: + # Fallback to regular text overlay if enhanced function not available + fps_metrics = result.get('fps_metrics', {}) + avg_fps = fps_metrics.get('avg_fps', 0) + text_content = f"{config_name}_{avg_fps:.1f}_FPS" + success = create_video_with_text( + output_file_path, + scaled_path, + text_content, + wall_video_width, + wall_video_height, + 'yellow' + ) + + if success: + processed_videos[(config_name, video_name)] = scaled_path + fps_metrics = result.get('fps_metrics', {}) + avg_fps = fps_metrics.get('avg_fps', 0) + print(f" Processed {config_name}_{video_name} (FPS: {avg_fps:.1f})") + else: + print(f" Failed to process {config_name}_{video_name}") + + # Step 2: Create the video wall grid with flipped layout + print(" Assembling video wall with flipped layout...") + print(f" New layout: {len(video_names)} rows x {len(config_names) + 1} columns") + print(f" Each row = one original video + all its config outputs") + print(f" Each column = one config (including original)") + + # Collect all input streams for the grid + input_streams = [] + + # Build grid row by row - each row represents one original video + for row_idx, video_name in enumerate(video_names): + row_streams = [] + + # For each column (config), add the corresponding video + for col_idx, config_name in enumerate(['original'] + config_names): + if (config_name, video_name) in processed_videos: + # Use existing processed video + stream = ffmpeg.input(processed_videos[(config_name, video_name)]) + row_streams.append(stream) + else: + # Create placeholder + placeholder_path = os.path.join(temp_dir, f"placeholder_{row_idx}_{col_idx}.mp4") + placeholder_text = f"MISSING_{config_name}_{video_name}".replace(' ', '_') + + success = create_placeholder_video( + placeholder_path, + placeholder_text, + wall_video_width, + wall_video_height, + min_duration + ) + + if success: + stream = ffmpeg.input(placeholder_path) + row_streams.append(stream) + else: + print(f" Failed to create placeholder for 
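# Sketch of the flipped-grid bookkeeping used here: one row per source video, one column
# per config plus a leading "original" column, with tiles looked up by a (config, video)
# key. The names are hypothetical.
video_names_example = ["clip_a", "clip_b"]
config_names_example = ["sd15_multicontrol", "sdturbo_multicontrol"]
grid = [
    [("original", v)] + [(c, v) for c in config_names_example]
    for v in video_names_example
]
# grid is 2 rows x 3 columns; any (config, video) pair missing from processed_videos
# gets a gray placeholder tile instead.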
{config_name}_{video_name}") + return None + + # Horizontally stack this row (configs for one video) + if len(row_streams) > 1: + row_combined = ffmpeg.filter(row_streams, 'hstack', inputs=len(row_streams)) + else: + row_combined = row_streams[0] + + input_streams.append(row_combined) + + # Vertically stack all rows (different videos) + if len(input_streams) > 1: + final_grid = ffmpeg.filter(input_streams, 'vstack', inputs=len(input_streams)) + else: + final_grid = input_streams[0] + + # Trim to minimum duration and output + trimmed = ffmpeg.filter(final_grid, 'trim', duration=min_duration) + final_output = ffmpeg.output( + trimmed, + wall_output, + vcodec='libx264', + crf=20, + preset='medium' + ) + + print(" Running ffmpeg to create final video wall...") + ffmpeg.run(final_output, overwrite_output=True, quiet=True) + + print(f" โœ… Video wall created: {wall_output}") + + # Clean up temporary files + print(" Cleaning up temporary files...") + try: + import shutil + shutil.rmtree(temp_dir) + except Exception as e: + print(f" Warning: Could not clean up temp directory: {e}") + + return wall_output + + except Exception as e: + print(f" Error creating video wall: {e}") + import traceback + traceback.print_exc() + return None + +def main( + configs: str, + videos: str, + output: str = "./output-test", + prompts: Optional[str] = None, + timeout_seconds: int = 300, # 5 minutes timeout per video + resume: Optional[str] = None, # Resume from existing output directory + retry_failed: bool = False # Whether to retry previously failed combinations +): + """ + Test multiple configs against multiple videos. + + Parameters + ---------- + configs : str + Directory containing YAML configuration files + videos : str + Directory containing video files + output : str, optional + Output directory for results, by default "./output-test" + prompts : str, optional + Text file containing individual prompts (one per line) + timeout_seconds : int, optional + Maximum time to spend processing each video, by default 300 (5 minutes) + resume : str, optional + Resume from existing output directory (full path to directory) + retry_failed : bool, optional + Whether to retry previously failed combinations, by default False + """ + + # Handle resume vs new run + if resume: + if not os.path.exists(resume): + print(f"โŒ Error: Resume directory does not exist: {resume}") + return + output_dir = resume + print(f"๐Ÿ”„ Resuming from existing directory: {output_dir}") + else: + # Create timestamped output directory + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = f"{output}/{timestamp}" + os.makedirs(output_dir, exist_ok=True) + print(f"๐Ÿ†• Starting new run in directory: {output_dir}") + + print("StreamDiffusion Multi-Config Test Suite") + print("=" * 50) + print(f"Configs directory: {configs}") + print(f"Videos directory: {videos}") + print(f"Output directory: {output_dir}") + if prompts: + print(f"Prompts file: {prompts}") + if resume: + print(f"Resume mode: โœ… Enabled") + if retry_failed: + print(f"Retry failed: โœ… Enabled (will retry previously failed combinations)") + else: + print(f"Retry failed: โŒ Disabled (will skip previously failed combinations)") + print("=" * 50) + + # Load prompts if provided + prompt_list = None + if prompts: + if not os.path.exists(prompts): + print(f"Error: Prompts file not found: {prompts}") + return + prompt_list = load_prompts(prompts) + + # Scan for completed work if resuming + existing_results = [] + if resume: + existing_results = 
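# Example of resuming an interrupted run (directory name hypothetical), assuming main()
# is exposed on the command line via python-fire as the module docstring suggests:
#
#   python multi_test.py --configs ./configs --videos ./videos --prompts ./prompts.txt \
#       --output ./results --resume ./results/20240101_120000 --retry_failed=True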
scan_completed_work(output_dir) + + # Get config files + config_dir = Path(configs) + config_files = list(config_dir.glob("*.yaml")) + list(config_dir.glob("*.yml")) + if not config_files: + print(f"Error: No YAML config files found in {configs}") + return + + # Get video files + video_dir = Path(videos) + video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv'] + video_files = [] + for ext in video_extensions: + video_files.extend(video_dir.glob(f"*{ext}")) + + if not video_files: + print(f"Error: No video files found in {videos}") + return + + print(f"\nFound {len(config_files)} configs and {len(video_files)} videos") + + # Calculate total work + total_combinations = len(config_files) * len(video_files) + + print(f"\n๐Ÿ“Š Work Summary:") + print(f" Total combinations: {total_combinations}") + if resume and len(existing_results) > 0: + print(f" Previously completed: {len(existing_results)}") + print(f" Will check each combination for existing output files...") + + # Store results for performance summary (start with existing results) + results = existing_results.copy() + + # Process each config against each video + for config_path in config_files: + print(f"\n{'='*60}") + print(f"Processing config: {config_path.stem}") + print(f"{'='*60}") + + # Aggressive cleanup before starting new config to ensure clean slate + print(f"Pre-config cleanup for {config_path.stem}...") + for cleanup_round in range(2): + cleanup_gpu_memory() + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + import time + time.sleep(0.1) + log_memory_usage(f"before config {config_path.stem}") + + try: + config = load_config(config_path) + print(f"Config loaded: {config.get('model_id', 'Unknown')}") + print(f"Resolution: {config.get('width', 'Unknown')}x{config.get('height', 'Unknown')}") + except Exception as e: + print(f"Error loading config {config_path}: {e}") + continue + + for video_path in video_files: + print(f"\nProcessing video: {video_path.name}") + + # Check if output already exists (pass existing_results for resume functionality) + if check_output_exists(output_dir, config_path.stem, video_path.stem, config, prompt_list, existing_results, retry_failed): + print(f" โญ๏ธ Skipping - already processed") + continue + + try: + print(f" Starting video preprocessing...") + # Preprocess video + video = preprocess_video( + str(video_path), + config.get('width', 512), + config.get('height', 512) + ) + print(f" Video preprocessing completed, shape: {video.shape}") + + # Force CPU memory cleanup after video preprocessing + import gc + gc.collect() + + print(f" Starting video processing with config...") + # Process with config and get performance data + result = process_video_with_config( + video=video, + config=config, + prompts=prompt_list, + output_dir=output_dir, + config_filename=config_path.stem, + video_filename=video_path.stem, + timeout_seconds=timeout_seconds + ) + + print(f" Video processing completed, result: {'Success' if result else 'Failed'}") + + # Store video information before cleanup + video_frames = video.shape[0] + + # Clean up after each video to prevent memory accumulation + print(f" Cleaning up after video {video_path.name}...") + + # Clean up video tensor from CPU memory + del video + import gc + gc.collect() + + # Clean up GPU memory + cleanup_gpu_memory() + log_memory_usage(f"after video {video_path.name} completion") + + # If retrying, remove the old failed result to avoid duplicates + if retry_failed: + results = [r for r 
in results if not (r['config'] == config_path.stem and r['video'] == video_path.stem)] + + # Store result for summary + if result: + results.append({ + 'config': config_path.stem, + 'video': video_path.stem, + 'model_id': config.get('model_id', 'Unknown'), + 'resolution': f"{config.get('width', 'Unknown')}x{config.get('height', 'Unknown')}", + 'total_frames': video_frames, + 'prompts_used': len(prompt_list) if prompt_list else 1, + 'success': True, + 'output_file': result['output_file'], # Store merged video file + 'fps_metrics': result['fps_metrics'] + }) + print(f" โœ… Successfully processed {video_path.name}") + else: + results.append({ + 'config': config_path.stem, + 'video': video_path.stem, + 'model_id': config.get('model_id', 'Unknown'), + 'resolution': f"{config.get('width', 'Unknown')}x{config.get('height', 'Unknown')}", + 'total_frames': video_frames, + 'prompts_used': len(prompt_list) if prompt_list else 1, + 'success': False, + 'error': 'Processing failed' + }) + print(f" โŒ Failed to process {video_path.name}") + + except Exception as e: + print(f" Failed to process {video_path.name}: {e}") + import traceback + traceback.print_exc() + + # Store video frames count if video was successfully loaded + video_frames = video.shape[0] if 'video' in locals() else 0 + + # Clean up video tensor even on failure + try: + del video + import gc + gc.collect() + cleanup_gpu_memory() + except: + pass # video might not be defined if error occurred during preprocessing + + # If retrying, remove the old failed result to avoid duplicates + if retry_failed: + results = [r for r in results if not (r['config'] == config_path.stem and r['video'] == video_path.stem)] + + results.append({ + 'config': config_path.stem, + 'video': video_path.stem, + 'model_id': config.get('model_id', 'Unknown'), + 'resolution': f"{config.get('width', 'Unknown')}x{config.get('height', 'Unknown')}", + 'total_frames': 0, + 'prompts_used': len(prompt_list) if prompt_list else 1, + 'success': False, + 'error': str(e) + }) + continue + + # Force cleanup between configs to ensure memory is cleared + print(f"\nCleaning up after config {config_path.stem}...") + try: + # Multiple rounds of cleanup to ensure everything is freed + for cleanup_round in range(3): # Multiple cleanup rounds like in main.py + cleanup_gpu_memory() + + # Additional cleanup to ensure no lingering references + import gc + gc.collect() + + # Force CUDA synchronization to ensure all operations are complete + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + + # Small delay between cleanup rounds + import time + time.sleep(0.1) + + log_memory_usage(f"after config {config_path.stem} completion") + + except Exception as cleanup_error: + print(f" Warning: Config cleanup failed: {cleanup_error}") + + print(f" Config {config_path.stem} cleanup completed") + + # Update progress tracking + total_processed = len([r for r in results if r['success']]) + total_failed = len([r for r in results if not r['success']]) + print(f" Progress: {total_processed + total_failed}/{total_combinations} total tests") + print(f" Successful: {total_processed}, Failed: {total_failed}") + + # Generate performance summary + generate_performance_summary(results, output_dir, prompt_list) + + # Create video wall if we have successful results + # Try enhanced video wall first (with JSON metadata), fallback to regular wall + wall_path = None + try: + from enhanced_video_wall import create_enhanced_video_wall + print("\n๐ŸŽฌ Attempting to create enhanced video 
wall with JSON metadata...") + wall_path = create_enhanced_video_wall(output_dir, os.path.join(output_dir, "enhanced_video_wall.mp4")) + if wall_path: + print(f"โœ… Enhanced video wall created: {wall_path}") + else: + print("โš ๏ธ Enhanced video wall creation failed, falling back to regular wall") + except ImportError: + print("โš ๏ธ Enhanced video wall module not available, using regular wall") + except Exception as e: + print(f"โš ๏ธ Enhanced video wall creation failed: {e}, falling back to regular wall") + + # Fallback to regular video wall if enhanced version failed + if not wall_path: + wall_path = create_video_wall(results, video_files, config_files, output_dir) + + # Final summary + total_successful = len([r for r in results if r['success']]) + total_failed = len([r for r in results if not r['success']]) + + print(f"\n๐ŸŽฏ Final Summary:") + print(f" Total combinations: {total_combinations}") + if resume: + print(f" Previously completed: {len(existing_results)}") + print(f" Newly processed: {total_successful + total_failed - len(existing_results)}") + print(f" Total successful: {total_successful}") + print(f" Total failed: {total_failed}") + print(f" Success rate: {(total_successful/total_combinations*100):.1f}%") + + print(f"\n๐Ÿ“ Results saved to: {output_dir}") + print(f"๐Ÿ“Š Performance summary: {output_dir}/performance_summary.txt") + print(f"๐Ÿ“‹ Detailed CSV: {output_dir}/detailed_results.csv") + + if wall_path: + print(f"๐ŸŽฌ Video wall created: {wall_path}") + else: + print("๐ŸŽฌ Video wall creation skipped or failed") + + if resume: + print(f"\n๐Ÿ’ก To resume again later, use: --resume \"{output_dir}\"") + +def generate_performance_summary(results: List[Dict], output_dir: str, prompts: Optional[List[str]] = None): + """Generate a performance summary comparing all configs.""" + + if not results: + print("No results to summarize") + return + + # Define successful_results early to avoid UnboundLocalError + successful_results = [r for r in results if r['success']] + + summary_file = os.path.join(output_dir, "performance_summary.txt") + + with open(summary_file, 'w', encoding='utf-8') as f: + f.write("StreamDiffusion Multi-Config Performance Summary\n") + f.write("=" * 60 + "\n\n") + + if prompts: + f.write(f"Using {len(prompts)} individual prompts with temporal splitting\n\n") + + # Overall statistics + total_tests = len(results) + successful_tests = sum(1 for r in results if r['success']) + failed_tests = total_tests - successful_tests + + f.write(f"Overall Results:\n") + f.write(f" Total tests: {total_tests}\n") + f.write(f" Successful: {successful_tests}\n") + f.write(f" Failed: {failed_tests}\n") + f.write(f" Success rate: {successful_tests/total_tests*100:.1f}%\n\n") + + # Quick Performance Summary Table + if successful_results: + f.write("Quick Performance Summary:\n") + f.write("-" * 120 + "\n") + f.write(f"{'Config':<25} {'Video':<15} {'Resolution':<12} {'Overall FPS':<12} {'Avg FPS':<10} {'Min FPS':<10} {'Max FPS':<10} {'Frames':<8}\n") + f.write("-" * 120 + "\n") + + for result in successful_results: + fps = result['fps_metrics'] + f.write(f"{result['config']:<25} {result['video']:<15} {result['resolution']:<12} " + f"{fps['overall_fps']:<12.2f} {fps['avg_fps']:<10.2f} {fps['min_fps']:<10.2f} " + f"{fps['max_fps']:<10.2f} {result['total_frames']:<8}\n") + f.write("-" * 120 + "\n\n") + + # Results by config + configs = set(r['config'] for r in results) + f.write("Results by Config:\n") + f.write("-" * 40 + "\n") + + for config in sorted(configs): + config_results 
= [r for r in results if r['config'] == config] + config_success = sum(1 for r in config_results if r['success']) + + f.write(f"\n{config}:\n") + f.write(f" Tests: {len(config_results)}/{config_success} successful\n") + f.write(f" Model: {config_results[0]['model_id']}\n") + f.write(f" Resolution: {config_results[0]['resolution']}\n") + + # List videos processed + for result in config_results: + status = "โœ…" if result['success'] else "โŒ" + f.write(f" {status} {result['video']}") + if result['success']: + f.write(f" ({result['total_frames']} frames, {result['prompts_used']} prompts)") + f.write(f" - Overall FPS: {result['fps_metrics']['overall_fps']:.2f}") + f.write(f", Min FPS: {result['fps_metrics']['min_fps']:.2f}") + f.write(f", Max FPS: {result['fps_metrics']['max_fps']:.2f}") + f.write(f", Avg FPS: {result['fps_metrics']['avg_fps']:.2f}") + else: + f.write(f" - {result.get('error', 'Unknown error')}") + f.write("\n") + + # Results by video + f.write(f"\nResults by Video:\n") + f.write("-" * 40 + "\n") + + videos = set(r['video'] for r in results) + for video in sorted(videos): + video_results = [r for r in results if r['video'] == video] + video_success = sum(1 for r in video_results if r['success']) + + f.write(f"\n{video}:\n") + f.write(f" Tests: {len(video_results)}/{video_success} successful\n") + + for result in video_results: + status = "โœ…" if result['success'] else "โŒ" + f.write(f" {status} {result['config']} ({result['resolution']})") + if result['success']: + f.write(f" - {result['total_frames']} frames") + f.write(f" - Overall FPS: {result['fps_metrics']['overall_fps']:.2f}") + f.write(f", Min FPS: {result['fps_metrics']['min_fps']:.2f}") + f.write(f", Max FPS: {result['fps_metrics']['max_fps']:.2f}") + f.write(f", Avg FPS: {result['fps_metrics']['avg_fps']:.2f}") + f.write("\n") + + # Summary of successful outputs + if successful_results: + f.write(f"\nGenerated Outputs:\n") + f.write("-" * 40 + "\n") + + for result in successful_results: + if 'output_file' in result and result['output_file']: + # Extract just the filename from the full path for display + output_filename = os.path.basename(result['output_file']) + f.write(f"โœ… {output_filename}\n") + else: + f.write(f"โœ… {result['config']}_{result['video']}: No output files generated\n") + + # Performance Analysis and Rankings + if successful_results: + f.write(f"\nPerformance Analysis:\n") + f.write("-" * 40 + "\n") + + # Overall FPS Rankings + f.write(f"\nOverall FPS Rankings (Higher is Better):\n") + fps_rankings = sorted(successful_results, key=lambda x: x['fps_metrics']['overall_fps'], reverse=True) + for i, result in enumerate(fps_rankings): + f.write(f" {i+1:2d}. {result['config']:30s} - {result['fps_metrics']['overall_fps']:6.2f} FPS") + f.write(f" (Avg: {result['fps_metrics']['avg_fps']:5.2f}, Range: {result['fps_metrics']['min_fps']:5.2f}-{result['fps_metrics']['max_fps']:5.2f})\n") + + # Average FPS Rankings + f.write(f"\nAverage FPS Rankings (Higher is Better):\n") + avg_fps_rankings = sorted(successful_results, key=lambda x: x['fps_metrics']['avg_fps'], reverse=True) + for i, result in enumerate(avg_fps_rankings): + f.write(f" {i+1:2d}. 
{result['config']:30s} - {result['fps_metrics']['avg_fps']:6.2f} FPS") + f.write(f" (Overall: {result['fps_metrics']['overall_fps']:5.2f}, Range: {result['fps_metrics']['min_fps']:5.2f}-{result['fps_metrics']['max_fps']:5.2f})\n") + + # Performance Statistics + f.write(f"\nPerformance Statistics:\n") + overall_fps_values = [r['fps_metrics']['overall_fps'] for r in successful_results] + avg_fps_values = [r['fps_metrics']['avg_fps'] for r in successful_results] + min_fps_values = [r['fps_metrics']['min_fps'] for r in successful_results] + max_fps_values = [r['fps_metrics']['max_fps'] for r in successful_results] + + f.write(f" Overall FPS - Best: {max(overall_fps_values):.2f}, Worst: {min(overall_fps_values):.2f}, Mean: {sum(overall_fps_values)/len(overall_fps_values):.2f}\n") + f.write(f" Average FPS - Best: {max(avg_fps_values):.2f}, Worst: {min(avg_fps_values):.2f}, Mean: {sum(avg_fps_values)/len(avg_fps_values):.2f}\n") + f.write(f" Min FPS - Best: {max(min_fps_values):.2f}, Worst: {min(min_fps_values):.2f}, Mean: {sum(min_fps_values)/len(min_fps_values):.2f}\n") + f.write(f" Max FPS - Best: {max(max_fps_values):.2f}, Worst: {min(max_fps_values):.2f}, Mean: {sum(max_fps_values)/len(max_fps_values):.2f}\n") + + # Performance by Resolution + f.write(f"\nPerformance by Resolution:\n") + resolutions = set(r['resolution'] for r in successful_results) + for resolution in sorted(resolutions): + res_results = [r for r in successful_results if r['resolution'] == resolution] + res_overall_fps = [r['fps_metrics']['overall_fps'] for r in res_results] + res_avg_fps = [r['fps_metrics']['avg_fps'] for r in res_results] + + f.write(f" {resolution}:\n") + f.write(f" Configs tested: {len(res_results)}\n") + f.write(f" Best Overall FPS: {max(res_overall_fps):.2f} ({[r['config'] for r in res_results if r['fps_metrics']['overall_fps'] == max(res_overall_fps)][0]})\n") + f.write(f" Best Average FPS: {max(res_avg_fps):.2f} ({[r['config'] for r in res_results if r['fps_metrics']['avg_fps'] == max(res_avg_fps)][0]})\n") + f.write(f" Mean Overall FPS: {sum(res_overall_fps)/len(res_overall_fps):.2f}\n") + f.write(f" Mean Average FPS: {sum(res_avg_fps)/len(res_avg_fps):.2f}\n") + + # Performance by Video + f.write(f"\nPerformance by Video:\n") + videos = set(r['video'] for r in successful_results) + for video in sorted(videos): + vid_results = [r for r in successful_results if r['video'] == video] + vid_overall_fps = [r['fps_metrics']['overall_fps'] for r in vid_results] + vid_avg_fps = [r['fps_metrics']['avg_fps'] for r in vid_results] + + f.write(f" {video}:\n") + f.write(f" Configs tested: {len(vid_results)}\n") + f.write(f" Best Overall FPS: {max(vid_overall_fps):.2f} ({[r['config'] for r in vid_results if r['fps_metrics']['overall_fps'] == max(vid_overall_fps)][0]})\n") + f.write(f" Best Average FPS: {max(vid_avg_fps):.2f} ({[r['config'] for r in vid_results if r['fps_metrics']['avg_fps'] == max(vid_avg_fps)][0]})\n") + f.write(f" Mean Overall FPS: {sum(vid_overall_fps)/len(vid_overall_fps):.2f}\n") + f.write(f" Mean Average FPS: {sum(vid_avg_fps)/len(vid_avg_fps):.2f}\n") + + # Best Config per Video Summary + f.write(f"\nBest Config per Video (Overall FPS):\n") + f.write("-" * 60 + "\n") + for video in sorted(videos): + vid_results = [r for r in successful_results if r['video'] == video] + best_config = max(vid_results, key=lambda x: x['fps_metrics']['overall_fps']) + fps = best_config['fps_metrics'] + f.write(f" {video:<20} -> {best_config['config']:<25} ({fps['overall_fps']:6.2f} FPS, Avg: 
{fps['avg_fps']:5.2f})\n") + + f.write(f"\nBest Config per Video (Average FPS):\n") + f.write("-" * 60 + "\n") + for video in sorted(videos): + vid_results = [r for r in successful_results if r['video'] == video] + best_config = max(vid_results, key=lambda x: x['fps_metrics']['avg_fps']) + fps = best_config['fps_metrics'] + f.write(f" {video:<20} -> {best_config['config']:<25} ({fps['avg_fps']:6.2f} FPS, Overall: {fps['overall_fps']:5.2f})\n") + + # Performance Improvement Analysis + f.write(f"\nPerformance Improvement Analysis:\n") + f.write("-" * 60 + "\n") + + # Find the best overall config + best_overall_config = max(successful_results, key=lambda x: x['fps_metrics']['overall_fps']) + best_overall_fps = best_overall_config['fps_metrics']['overall_fps'] + + f.write(f"Best Overall Config: {best_overall_config['config']} ({best_overall_fps:.2f} FPS)\n\n") + f.write(f"Performance vs Best (Overall FPS):\n") + + for result in sorted(successful_results, key=lambda x: x['fps_metrics']['overall_fps'], reverse=True): + if result['config'] != best_overall_config['config']: + improvement = ((best_overall_fps - result['fps_metrics']['overall_fps']) / result['fps_metrics']['overall_fps']) * 100 + f.write(f" {result['config']:<30s} - {result['fps_metrics']['overall_fps']:6.2f} FPS") + f.write(f" ({improvement:+.1f}% vs best)\n") + + # Performance vs Average + avg_overall_fps = sum(r['fps_metrics']['overall_fps'] for r in successful_results) / len(successful_results) + f.write(f"\nPerformance vs Average ({avg_overall_fps:.2f} FPS):\n") + + for result in sorted(successful_results, key=lambda x: x['fps_metrics']['overall_fps'], reverse=True): + vs_avg = ((result['fps_metrics']['overall_fps'] - avg_overall_fps) / avg_overall_fps) * 100 + f.write(f" {result['config']:<30s} - {result['fps_metrics']['overall_fps']:6.2f} FPS") + f.write(f" ({vs_avg:+.1f}% vs avg)\n") + + # Performance Consistency Analysis + f.write(f"\nPerformance Consistency Analysis:\n") + f.write("-" * 60 + "\n") + f.write("Configs ranked by FPS stability (lower variance = more stable):\n") + + # Calculate FPS variance for each config + consistency_data = [] + for result in successful_results: + segment_fps = result['fps_metrics']['segment_fps'] + if len(segment_fps) > 1: + mean_fps = sum(segment_fps) / len(segment_fps) + variance = sum((fps - mean_fps) ** 2 for fps in segment_fps) / len(segment_fps) + std_dev = variance ** 0.5 + cv = (std_dev / mean_fps) * 100 # Coefficient of variation + else: + variance = 0 + std_dev = 0 + cv = 0 + + consistency_data.append({ + 'config': result['config'], + 'mean_fps': result['fps_metrics']['avg_fps'], + 'std_dev': std_dev, + 'cv': cv, + 'min_fps': result['fps_metrics']['min_fps'], + 'max_fps': result['fps_metrics']['max_fps'], + 'fps_range': result['fps_metrics']['max_fps'] - result['fps_metrics']['min_fps'] + }) + + # Sort by coefficient of variation (lower = more stable) + consistency_data.sort(key=lambda x: x['cv']) + + for i, data in enumerate(consistency_data): + f.write(f" {i+1:2d}. 
{data['config']:<30s} - CV: {data['cv']:5.1f}%") + f.write(f" (Std: {data['std_dev']:5.2f}, Range: {data['fps_range']:5.2f})\n") + f.write(f" Mean: {data['mean_fps']:6.2f} FPS, Min: {data['min_fps']:5.2f}, Max: {data['max_fps']:5.2f}\n") + + # Recommendations + f.write(f"\nRecommendations:\n") + f.write("-" * 60 + "\n") + + # Best overall performance + f.write(f"๐Ÿ† Best Overall Performance: {best_overall_config['config']}\n") + f.write(f" - Highest sustained FPS: {best_overall_config['fps_metrics']['overall_fps']:.2f}\n") + f.write(f" - Best for: Maximum throughput scenarios\n\n") + + # Most consistent performance + most_consistent = consistency_data[0] + f.write(f"๐Ÿ“Š Most Consistent Performance: {most_consistent['config']}\n") + f.write(f" - Lowest variance: {most_consistent['cv']:.1f}% CV\n") + f.write(f" - Best for: Real-time applications requiring stable frame rates\n\n") + + # Best value (good performance + consistency) + # Find config with good balance of performance and consistency + balanced_configs = [] + for data in consistency_data: + # Normalize both metrics (0-1 scale) + perf_score = data['mean_fps'] / max(d['mean_fps'] for d in consistency_data) + consistency_score = 1 - (data['cv'] / max(d['cv'] for d in consistency_data)) + balanced_score = (perf_score + consistency_score) / 2 + balanced_configs.append((data['config'], balanced_score, data['mean_fps'], data['cv'])) + + balanced_configs.sort(key=lambda x: x[1], reverse=True) + best_balanced = balanced_configs[0] + f.write(f"โš–๏ธ Best Balanced (Performance + Consistency): {best_balanced[0]}\n") + f.write(f" - Balanced score: {best_balanced[1]:.3f}\n") + f.write(f" - Performance: {best_balanced[2]:.2f} FPS, Consistency: {best_balanced[3]:.1f}% CV\n") + f.write(f" - Best for: Production environments requiring both speed and reliability\n\n") + + # Performance tiers + f.write(f"๐Ÿ“ˆ Performance Tiers:\n") + fps_values = [r['fps_metrics']['overall_fps'] for r in successful_results] + fps_values.sort(reverse=True) + + if len(fps_values) >= 3: + top_tier = fps_values[:len(fps_values)//3] + mid_tier = fps_values[len(fps_values)//3:2*len(fps_values)//3] + bottom_tier = fps_values[2*len(fps_values)//3:] + + f.write(f" ๐Ÿฅ‡ Top Tier (โ‰ฅ{min(top_tier):.2f} FPS): {len(top_tier)} configs\n") + f.write(f" ๐Ÿฅˆ Mid Tier ({min(mid_tier):.2f}-{max(mid_tier):.2f} FPS): {len(mid_tier)} configs\n") + f.write(f" ๐Ÿฅ‰ Bottom Tier (<{max(bottom_tier):.2f} FPS): {len(bottom_tier)} configs\n") + + f.write(f"\n๐Ÿ’ก Usage Tips:\n") + f.write(f" - For maximum speed: Use {best_overall_config['config']}\n") + f.write(f" - For stable real-time: Use {most_consistent['config']}\n") + f.write(f" - For production: Use {best_balanced[0]}\n") + f.write(f" - Consider resolution impact: Higher resolutions generally reduce FPS\n") + f.write(f" - Monitor VRAM usage: Some configs may be more memory-efficient\n") + + # Best Configs by Use Case + f.write(f"\n๐ŸŽฏ Best Configs by Use Case:\n") + f.write("-" * 60 + "\n") + + # Speed-focused use cases + f.write(f"๐Ÿš€ Speed-Focused Use Cases:\n") + speed_configs = sorted(successful_results, key=lambda x: x['fps_metrics']['overall_fps'], reverse=True)[:3] + for i, result in enumerate(speed_configs): + fps = result['fps_metrics'] + f.write(f" {i+1}. 
{result['config']:<25} - {fps['overall_fps']:6.2f} FPS") + f.write(f" (Avg: {fps['avg_fps']:5.2f}, CV: {fps['cv_percent']:4.1f}%)\n") + + # Consistency-focused use cases + f.write(f"\n๐Ÿ“Š Consistency-Focused Use Cases:\n") + consistency_configs = sorted(successful_results, key=lambda x: x['fps_metrics']['cv_percent'])[:3] + for i, result in enumerate(consistency_configs): + fps = result['fps_metrics'] + f.write(f" {i+1}. {result['config']:<25} - CV: {fps['cv_percent']:4.1f}%") + f.write(f" (Avg: {fps['avg_fps']:5.2f} FPS, Overall: {fps['overall_fps']:5.2f})\n") + + # Balanced use cases + f.write(f"\nโš–๏ธ Balanced Use Cases (Speed + Consistency):\n") + for i, (config, score, mean_fps, cv) in enumerate(balanced_configs[:3]): + f.write(f" {i+1}. {config:<25} - Score: {score:.3f}") + f.write(f" (Avg: {mean_fps:5.2f} FPS, CV: {cv:4.1f}%)\n") + + # Resolution-specific recommendations + f.write(f"\n๐Ÿ–ผ๏ธ Resolution-Specific Recommendations:\n") + for resolution in sorted(resolutions): + res_results = [r for r in successful_results if r['resolution'] == resolution] + best_speed = max(res_results, key=lambda x: x['fps_metrics']['overall_fps']) + best_consistency = min(res_results, key=lambda x: x['fps_metrics']['cv_percent']) + + f.write(f" {resolution}:\n") + f.write(f" - Best Speed: {best_speed['config']} ({best_speed['fps_metrics']['overall_fps']:.2f} FPS)\n") + f.write(f" - Best Consistency: {best_consistency['config']} (CV: {best_consistency['fps_metrics']['cv_percent']:.1f}%)\n") + + print(f"Performance summary saved to: {summary_file}") + + # Also save as CSV for easy analysis + csv_file = os.path.join(output_dir, "detailed_results.csv") + import csv + + with open(csv_file, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + + # Header + writer.writerow([ + 'Config', 'Video', 'Model ID', 'Resolution', 'Total Frames', + 'Prompts Used', 'Success', 'Output File', 'Error Message', + 'Overall FPS', 'Min FPS', 'Max FPS', 'Avg FPS', 'Std Dev FPS', 'CV %' + ]) + + # Data rows + for result in results: + if result['success']: + fps_metrics = result['fps_metrics'] + # Format output file path - just the filename for clarity + output_file_str = os.path.basename(result.get('output_file', '')) + + writer.writerow([ + result['config'], + result['video'], + result['model_id'], + result['resolution'], + result['total_frames'], + result['prompts_used'], + "Yes" if result['success'] else "No", + output_file_str, + result.get('error', ''), + f"{fps_metrics['overall_fps']:.2f}", + f"{fps_metrics['min_fps']:.2f}", + f"{fps_metrics['max_fps']:.2f}", + f"{fps_metrics['avg_fps']:.2f}", + f"{fps_metrics['std_dev_fps']:.2f}", + f"{fps_metrics['cv_percent']:.1f}" + ]) + else: + writer.writerow([ + result['config'], + result['video'], + result['model_id'], + result['resolution'], + result['total_frames'], + result['prompts_used'], + "Yes" if result['success'] else "No", + "", + result.get('error', ''), + "N/A", + "N/A", + "N/A", + "N/A", + "N/A", + "N/A" + ]) + + print(f"Detailed results saved to: {csv_file}") + +if __name__ == "__main__": + try: + fire.Fire(main) + except KeyboardInterrupt: + print("\nInterrupted by user") + cleanup_and_exit() + except Exception as e: + print(f"\nUnexpected error in main: {e}") + import traceback + traceback.print_exc() + cleanup_and_exit() + sys.exit(1) + finally: + cleanup_and_exit() diff --git a/multi_test/prompts.txt b/multi_test/prompts.txt new file mode 100644 index 00000000..92ed8a60 --- /dev/null +++ b/multi_test/prompts.txt @@ -0,0 +1,6 @@ +naruto anime 
+marble statue, high detail, roman empire, stones, garden +playstation, graphics, cutscene, ps2, shenmu +stained glass dream fantasy +Disney Aladdin, cartoon, pixar cg +1930s pinup girl \ No newline at end of file diff --git a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py index 85123c59..4f49bd43 100644 --- a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py +++ b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py @@ -1,3 +1,5 @@ + +import hashlib import logging from enum import Enum from pathlib import Path @@ -75,15 +77,30 @@ def __init__(self, engine_dir: str): 'loader': lambda path, cuda_stream, **kwargs: str(path) } } - + + def _lora_signature(self, lora_dict: Dict[str, float]) -> str: + """Create a short, stable signature for a set of LoRAs. + + Uses sorted basenames and weights, hashed to a short hex to avoid + long/invalid paths while keeping cache keys stable across runs. + """ + # Build canonical string of basename:weight pairs + parts = [] + for path, weight in sorted(lora_dict.items(), key=lambda x: str(x[0])): + base = Path(str(path)).name # basename only + parts.append(f"{base}:{weight}") + canon = "|".join(parts) + h = hashlib.sha1(canon.encode("utf-8")).hexdigest()[:10] + return f"{len(lora_dict)}-{h}" + def get_engine_path(self, engine_type: EngineType, model_id_or_path: str, max_batch_size: int, min_batch_size: int, mode: str, - use_lcm_lora: bool, use_tiny_vae: bool, + lora_dict: Optional[Dict[str, float]] = None, ipadapter_scale: Optional[float] = None, ipadapter_tokens: Optional[int] = None, controlnet_model_id: Optional[str] = None, @@ -114,7 +131,7 @@ def get_engine_path(self, base_name = maybe_path.stem if maybe_path.exists() else model_id_or_path # Create prefix (from wrapper.py lines 1005-1013) - prefix = f"{base_name}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--min_batch-{min_batch_size}--max_batch-{max_batch_size}" + prefix = f"{base_name}--tiny_vae-{use_tiny_vae}--min_batch-{min_batch_size}--max_batch-{max_batch_size}" # IP-Adapter differentiation: add type and (optionally) tokens # Keep scale out of identity for runtime control, but include a type flag to separate caches @@ -122,6 +139,10 @@ def get_engine_path(self, prefix += f"--fid" if ipadapter_tokens is not None: prefix += f"--tokens{ipadapter_tokens}" + + # Fused Loras - use concise hashed signature to avoid long/invalid paths + if lora_dict is not None and len(lora_dict) > 0: + prefix += f"--lora-{self._lora_signature(lora_dict)}" prefix += f"--mode-{mode}" @@ -287,7 +308,6 @@ def get_or_load_controlnet_engine(self, max_batch_size=max_batch_size, min_batch_size=min_batch_size, mode="", # Not used for ControlNet - use_lcm_lora=False, # Not used for ControlNet use_tiny_vae=False, # Not used for ControlNet controlnet_model_id=model_id ) diff --git a/src/streamdiffusion/acceleration/tensorrt/utilities.py b/src/streamdiffusion/acceleration/tensorrt/utilities.py index ce1124df..3f0a3f69 100644 --- a/src/streamdiffusion/acceleration/tensorrt/utilities.py +++ b/src/streamdiffusion/acceleration/tensorrt/utilities.py @@ -360,6 +360,29 @@ def reset_cuda_graph(self): self.graph = None def infer(self, feed_dict, stream, use_cuda_graph=False): + # Filter inputs to only those the engine actually exposes to avoid binding errors + try: + allowed_inputs = set() + for idx in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(idx) + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + 
allowed_inputs.add(name) + + # Drop any extra keys (e.g., text_embeds/time_ids) that the engine was not built to accept + if allowed_inputs: + filtered_feed_dict = {k: v for k, v in feed_dict.items() if k in allowed_inputs} + if len(filtered_feed_dict) != len(feed_dict): + missing = [k for k in feed_dict.keys() if k not in allowed_inputs] + if missing: + logger.debug( + "TensorRT Engine: filtering unsupported inputs %s (allowed=%s)", + missing, sorted(list(allowed_inputs)) + ) + feed_dict = filtered_feed_dict + except Exception: + # Be permissive if engine query fails; proceed with original dict + pass + for name, buf in feed_dict.items(): self.tensors[name].copy_(buf) diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py index 3ff5cb90..ac8b6f20 100644 --- a/src/streamdiffusion/config.py +++ b/src/streamdiffusion/config.py @@ -100,7 +100,6 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: 'lora_dict': config.get('lora_dict'), 'mode': config.get('mode', 'img2img'), 'output_type': config.get('output_type', 'pil'), - 'lcm_lora_id': config.get('lcm_lora_id'), 'vae_id': config.get('vae_id'), 'device': config.get('device', 'cuda'), 'dtype': _parse_dtype(config.get('dtype', 'float16')), @@ -111,7 +110,7 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: 'acceleration': config.get('acceleration', 'tensorrt'), 'do_add_noise': config.get('do_add_noise', True), 'device_ids': config.get('device_ids'), - 'use_lcm_lora': config.get('use_lcm_lora', True), + 'use_lcm_lora': config.get('use_lcm_lora'), # Backwards compatibility 'use_tiny_vae': config.get('use_tiny_vae', True), 'enable_similar_image_filter': config.get('enable_similar_image_filter', False), 'similar_image_filter_threshold': config.get('similar_image_filter_threshold', 0.98), @@ -124,6 +123,8 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: 'engine_dir': config.get('engine_dir', 'engines'), 'normalize_prompt_weights': config.get('normalize_prompt_weights', True), 'normalize_seed_weights': config.get('normalize_seed_weights', True), + 'scheduler': config.get('scheduler', 'lcm'), + 'sampler': config.get('sampler', 'normal'), 'compile_engines_only': config.get('compile_engines_only', False), } if 'controlnets' in config and config['controlnets']: diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index 1bca0bc2..0c372533 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -4,7 +4,7 @@ import numpy as np import PIL.Image import torch -from diffusers import LCMScheduler, StableDiffusionPipeline +from diffusers import LCMScheduler, TCDScheduler, StableDiffusionPipeline from diffusers.image_processor import VaeImageProcessor from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( retrieve_latents, @@ -36,8 +36,11 @@ def __init__( use_denoising_batch: bool = True, frame_buffer_size: int = 1, cfg_type: Literal["none", "full", "self", "initialize"] = "self", + lora_dict: Optional[Dict[str, float]] = None, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, + scheduler: Literal["lcm", "tcd"] = "lcm", + sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", ) -> None: self.device = torch.device(device) self.dtype = torch_dtype @@ -53,6 +56,8 @@ def __init__( self.denoising_steps_num = len(t_index_list) self.cfg_type = cfg_type + self.scheduler_type = scheduler + self.sampler_type = sampler # Detect model type 
detection_result = detect_model(pipe.unet, pipe) @@ -61,7 +66,16 @@ def __init__( self.is_turbo = detection_result['is_turbo'] self.detection_confidence = detection_result['confidence'] - if use_denoising_batch: + # TCD scheduler is incompatible with denoising batch optimization due to Strategic Stochastic Sampling + # Force sequential processing for TCD + if scheduler == "tcd": + logger.info("TCD scheduler detected: Disabling denoising batch optimization for compatibility") + logger.info("TCD now supports ControlNet through proper hook processing") + self.use_denoising_batch = False + self.batch_size = frame_buffer_size + self.trt_unet_batch_size = frame_buffer_size + elif use_denoising_batch: + self.use_denoising_batch = True self.batch_size = self.denoising_steps_num * frame_buffer_size if self.cfg_type == "initialize": self.trt_unet_batch_size = ( @@ -74,13 +88,12 @@ def __init__( else: self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size else: + self.use_denoising_batch = False self.trt_unet_batch_size = self.frame_bff_size self.batch_size = frame_buffer_size self.t_list = t_index_list - self.do_add_noise = do_add_noise - self.use_denoising_batch = use_denoising_batch self.similar_image_filter = False self.similar_filter = SimilarImageFilter() @@ -89,8 +102,8 @@ def __init__( self.pipe = pipe self.image_processor = VaeImageProcessor(pipe.vae_scale_factor) - - self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) + self.scheduler = self._initialize_scheduler(scheduler, sampler, pipe.scheduler.config) + self.text_encoder = pipe.text_encoder self.unet = pipe.unet self.vae = pipe.vae @@ -126,7 +139,31 @@ def __init__( self._cached_batch_size: Optional[int] = None self._cached_cfg_type: Optional[str] = None self._cached_guidance_scale: Optional[float] = None + + def _initialize_scheduler(self, scheduler_type: str, sampler_type: str, config): + """Initialize scheduler based on type and sampler configuration.""" + + # TODO: More testing and validation required on samplers. 
+ # Map sampler types to configuration parameters + sampler_config = { + "simple": {"timestep_spacing": "linspace"}, + "sgm uniform": {"timestep_spacing": "trailing"}, + "normal": {}, # Default configuration + "ddim": {"timestep_spacing": "leading"}, + "beta": {"beta_schedule": "scaled_linear"}, + "karras": {}, # Karras sigmas handled per scheduler + } + + # Get sampler-specific configuration + sampler_params = sampler_config.get(sampler_type, {}) + if scheduler_type == "lcm": + return LCMScheduler.from_config(config, **sampler_params) + elif scheduler_type == "tcd": + return TCDScheduler.from_config(config, **sampler_params) + else: + logger.warning(f"Unknown scheduler type '{scheduler_type}', falling back to LCM") + return LCMScheduler.from_config(config, **sampler_params) def _check_unet_tensorrt(self) -> bool: """Cache TensorRT detection to avoid repeated hasattr calls""" @@ -191,21 +228,8 @@ def _build_sdxl_conditioning(self, batch_size: int) -> Dict[str, torch.Tensor]: 'time_ids': add_time_ids } - def load_lcm_lora( - self, - pretrained_model_name_or_path_or_dict: Union[ - str, Dict[str, torch.Tensor] - ] = "latent-consistency/lcm-lora-sdv1-5", - adapter_name: Optional[Any] = None, - **kwargs, - ) -> None: - # Check for SDXL compatibility - if self.is_sdxl: - return - - self._load_lora_with_offline_fallback( - pretrained_model_name_or_path_or_dict, adapter_name, **kwargs - ) + + def load_lora( self, @@ -445,12 +469,11 @@ def prepare( self.stock_noise = torch.zeros_like(self.init_noise) + # Handle scheduler-specific scaling calculations c_skip_list = [] c_out_list = [] for timestep in self.sub_timesteps: - c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete( - timestep - ) + c_skip, c_out = self._get_scheduler_scalings(timestep) c_skip_list.append(c_skip) c_out_list.append(c_out) @@ -495,7 +518,9 @@ def prepare( #NOTE: this is a hack. Pipeline needs a major refactor along with stream parameter updater. self.update_prompt(prompt) - if not self.use_denoising_batch: + # Only collapse tensors to scalars for LCM non-batched mode + # TCD needs to keep tensor dimensions for iteration + if not self.use_denoising_batch and isinstance(self.scheduler, LCMScheduler): self.sub_timesteps_tensor = self.sub_timesteps_tensor[0] self.alpha_prod_t_sqrt = self.alpha_prod_t_sqrt[0] self.beta_prod_t_sqrt = self.beta_prod_t_sqrt[0] @@ -504,6 +529,19 @@ def prepare( self.c_skip = self.c_skip.to(self.device) self.c_out = self.c_out.to(self.device) + def _get_scheduler_scalings(self, timestep): + """Get LCM/TCD-specific scaling factors for boundary conditions.""" + if isinstance(self.scheduler, LCMScheduler): + c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(timestep) + return c_skip, c_out + else: + # TCD and other schedulers don't use boundary condition scaling like LCM + # They handle scaling internally in their step() method + # Return tensors that are compatible with torch.stack() + c_skip = torch.tensor(1.0, device=self.device, dtype=self.dtype) + c_out = torch.tensor(1.0, device=self.device, dtype=self.dtype) + return c_skip, c_out + @torch.no_grad() def update_prompt(self, prompt: str) -> None: self._param_updater.update_stream_params( @@ -525,6 +563,33 @@ def get_normalize_seed_weights(self) -> bool: + + + def set_scheduler(self, scheduler: Literal["lcm", "tcd"] = None, sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = None) -> None: + """ + Change the scheduler and/or sampler at runtime. 
+ + Parameters + ---------- + scheduler : str, optional + The scheduler type to use ("lcm" or "tcd"). If None, keeps current scheduler. + sampler : str, optional + The sampler type to use. If None, keeps current sampler. + """ + if scheduler is not None: + self.scheduler_type = scheduler + if sampler is not None: + self.sampler_type = sampler + + self.scheduler = self._initialize_scheduler(self.scheduler_type, self.sampler_type, self.pipe.scheduler.config) + logger.info(f"Scheduler changed to {self.scheduler_type} with {self.sampler_type} sampler") + + def _uses_lcm_logic(self) -> bool: + """Return True if scheduler uses LCM-style consistency boundary-condition math.""" + return isinstance(self.scheduler, LCMScheduler) + + + def add_noise( self, original_samples: torch.Tensor, @@ -543,7 +608,6 @@ def scheduler_step_batch( x_t_latent_batch: torch.Tensor, idx: Optional[int] = None, ) -> torch.Tensor: - # TODO: use t_list to select beta_prod_t_sqrt if idx is None: F_theta = ( x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch @@ -556,7 +620,6 @@ def scheduler_step_batch( denoised_batch = ( self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch ) - return denoised_batch def unet_step( @@ -565,7 +628,6 @@ def unet_step( t_list: Union[torch.Tensor, list[int]], idx: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0) t_list = torch.concat([t_list[0:1], t_list], dim=0) @@ -783,6 +845,113 @@ def unet_step( return denoised_batch, model_pred + def _call_unet( + self, + sample: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + """Call the UNet, handling SDXL kwargs and TensorRT engine calling convention.""" + added_cond_kwargs = added_cond_kwargs or {} + if self.is_sdxl: + try: + # Detect TensorRT engine vs PyTorch UNet + is_tensorrt_engine = hasattr(self.unet, 'engine') and hasattr(self.unet, 'stream') + if is_tensorrt_engine: + out = self.unet( + sample, + timestep, + encoder_hidden_states, + **added_cond_kwargs, + )[0] + else: + out = self.unet( + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + except Exception as e: + logger.error(f"[PIPELINE] _call_unet: SDXL UNet call failed: {e}") + import traceback + traceback.print_exc() + raise + else: + out = self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + return_dict=False, + )[0] + return out + + def _unet_predict_noise_cfg( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + cfg_mode: Literal["none", "full", "self", "initialize"], + ) -> torch.Tensor: + """ + Compute noise prediction from UNet with classifier-free guidance applied. + This function does not apply any scheduler math; it only returns the guided noise. 
+ """ + # Build latent batch for CFG + if self.guidance_scale > 1.0 and cfg_mode == "full": + latent_with_uc = torch.cat([latent_model_input, latent_model_input], dim=0) + elif self.guidance_scale > 1.0 and cfg_mode == "initialize": + latent_with_uc = torch.cat([latent_model_input[0:1], latent_model_input], dim=0) + else: + latent_with_uc = latent_model_input + + # SDXL added conditioning replication to match batch + added_cond_kwargs: Dict[str, torch.Tensor] = {} + if self.is_sdxl and hasattr(self, 'add_text_embeds') and hasattr(self, 'add_time_ids'): + if self.add_text_embeds is not None and self.add_time_ids is not None: + batch_size = latent_with_uc.shape[0] + if self.guidance_scale > 1.0 and cfg_mode == "initialize": + add_text_embeds = torch.cat([ + self.add_text_embeds[0:1], + self.add_text_embeds[1:2].repeat(batch_size - 1, 1), + ], dim=0) + add_time_ids = torch.cat([ + self.add_time_ids[0:1], + self.add_time_ids[1:2].repeat(batch_size - 1, 1), + ], dim=0) + elif self.guidance_scale > 1.0 and cfg_mode == "full": + repeat_factor = batch_size // 2 + add_text_embeds = self.add_text_embeds.repeat(repeat_factor, 1) + add_time_ids = self.add_time_ids.repeat(repeat_factor, 1) + else: + add_text_embeds = ( + self.add_text_embeds[1:2].repeat(batch_size, 1) + if self.add_text_embeds.shape[0] > 1 + else self.add_text_embeds.repeat(batch_size, 1) + ) + add_time_ids = ( + self.add_time_ids[1:2].repeat(batch_size, 1) + if self.add_time_ids.shape[0] > 1 + else self.add_time_ids.repeat(batch_size, 1) + ) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # Call UNet + model_pred = self._call_unet( + sample=latent_with_uc, + timestep=timestep, + encoder_hidden_states=self.prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + ) + + # Apply CFG + if self.guidance_scale > 1.0 and cfg_mode == "full": + noise_pred_uncond, noise_pred_text = model_pred.chunk(2) + guided = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + return guided + else: + return model_pred + def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor: image_tensors = image_tensors.to( device=self.device, @@ -808,23 +977,19 @@ def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor: def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: prev_latent_batch = self.x_t_latent_buffer - - - if self.use_denoising_batch: + # LCM supports our denoising-batch trick. 
TCD must use standard scheduler.step() sequentially + # but now properly processes ControlNet hooks through unet_step() + if self.use_denoising_batch and isinstance(self.scheduler, LCMScheduler): t_list = self.sub_timesteps_tensor - if self.denoising_steps_num > 1: x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0) - self.stock_noise = torch.cat( (self.init_noise[0:1], self.stock_noise[:-1]), dim=0 ) - x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list) if self.denoising_steps_num > 1: x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0) - if self.do_add_noise: self.x_t_latent_buffer = ( self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] @@ -838,25 +1003,43 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: x_0_pred_out = x_0_pred_batch self.x_t_latent_buffer = None else: - self.init_noise = x_t_latent - for idx, t in enumerate(self.sub_timesteps_tensor): - t = t.view(1,).repeat(self.frame_bff_size,) - - x_0_pred, model_pred = self.unet_step(x_t_latent, t, idx) - - if idx < len(self.sub_timesteps_tensor) - 1: - if self.do_add_noise: - x_t_latent = self.alpha_prod_t_sqrt[ - idx + 1 - ] * x_0_pred + self.beta_prod_t_sqrt[ - idx + 1 - ] * torch.randn_like( - x_0_pred, device=self.device, dtype=self.dtype - ) + # Standard scheduler loop for TCD and non-batched LCM + sample = x_t_latent + for idx, timestep in enumerate(self.sub_timesteps_tensor): + # Ensure timestep tensor on device with correct dtype + if not isinstance(timestep, torch.Tensor): + t = torch.tensor(timestep, device=self.device, dtype=torch.long) + else: + t = timestep.to(self.device) + + # For TCD, use the same UNet calling logic as LCM to ensure ControlNet hooks are processed + if isinstance(self.scheduler, TCDScheduler): + # Use unet_step to process ControlNet hooks and get proper noise prediction + t_expanded = t.view(1,).repeat(self.frame_bff_size,) + x_0_pred, model_pred = self.unet_step(sample, t_expanded, idx) + + # Apply TCD scheduler step to the guided noise prediction + step_out = self.scheduler.step(model_pred, t, sample) + sample = getattr(step_out, "prev_sample", step_out[0] if isinstance(step_out, (tuple, list)) else step_out) + else: + # Original LCM logic for non-batched mode + t = t.view(1,).repeat(self.frame_bff_size,) + x_0_pred, model_pred = self.unet_step(sample, t, idx) + if idx < len(self.sub_timesteps_tensor) - 1: + if self.do_add_noise: + sample = self.alpha_prod_t_sqrt[ + idx + 1 + ] * x_0_pred + self.beta_prod_t_sqrt[ + idx + 1 + ] * torch.randn_like( + x_0_pred, device=self.device, dtype=self.dtype + ) + else: + sample = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred else: - x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred - x_0_pred_out = x_0_pred + sample = x_0_pred + x_0_pred_out = sample return x_0_pred_out @torch.no_grad() diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index cc901606..8c8b055e 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -674,6 +674,20 @@ def _update_seed(self, seed: int) -> None: # Reset stock_noise to match the new init_noise self.stream.stock_noise = torch.zeros_like(self.stream.init_noise) + def _get_scheduler_scalings(self, timestep): + """Get LCM/TCD-specific scaling factors for boundary conditions.""" + from diffusers import LCMScheduler + if isinstance(self.stream.scheduler, LCMScheduler): + c_skip, c_out = self.stream.scheduler.get_scalings_for_boundary_condition_discrete(timestep) + return 
c_skip, c_out + else: + # TCD and other schedulers don't use boundary condition scaling like LCM + # They handle scaling internally in their step() method + # Return tensors that are compatible with torch.stack() + c_skip = torch.tensor(1.0, device=self.stream.device, dtype=self.stream.dtype) + c_out = torch.tensor(1.0, device=self.stream.device, dtype=self.stream.dtype) + return c_skip, c_out + def _update_timestep_calculations(self) -> None: """Update timestep-dependent calculations based on current t_list.""" self.stream.sub_timesteps = [] @@ -692,7 +706,7 @@ def _update_timestep_calculations(self) -> None: c_skip_list = [] c_out_list = [] for timestep in self.stream.sub_timesteps: - c_skip, c_out = self.stream.scheduler.get_scalings_for_boundary_condition_discrete(timestep) + c_skip, c_out = self._get_scheduler_scalings(timestep) c_skip_list.append(c_skip) c_out_list.append(c_out) diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 6a997b12..a0b7c59c 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -75,7 +75,6 @@ def __init__( lora_dict: Optional[Dict[str, float]] = None, mode: Literal["img2img", "txt2img"] = "img2img", output_type: Literal["pil", "pt", "np", "latent"] = "pil", - lcm_lora_id: Optional[str] = None, vae_id: Optional[str] = None, device: Literal["cpu", "cuda"] = "cuda", dtype: torch.dtype = torch.float16, @@ -86,7 +85,7 @@ def __init__( acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", do_add_noise: bool = True, device_ids: Optional[List[int]] = None, - use_lcm_lora: bool = True, + use_lcm_lora: Optional[bool] = None, # DEPRECATED: Backwards compatibility parameter use_tiny_vae: bool = True, enable_similar_image_filter: bool = False, similar_image_filter_threshold: float = 0.98, @@ -101,6 +100,9 @@ def __init__( build_engines_if_missing: bool = True, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, + # Scheduler and sampler options + scheduler: Literal["lcm", "tcd"] = "lcm", + sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", # ControlNet options use_controlnet: bool = False, controlnet_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, @@ -124,6 +126,10 @@ def __init__( The model id or path to load. t_index_list : List[int] The t_index_list to use for inference. + min_batch_size : int, optional + The minimum batch size for inference, by default 1. + max_batch_size : int, optional + The maximum batch size for inference, by default 4. lora_dict : Optional[Dict[str, float]], optional The lora_dict to load, by default None. Keys are the LoRA names and values are the LoRA scales. @@ -132,16 +138,14 @@ def __init__( txt2img or img2img, by default "img2img". output_type : Literal["pil", "pt", "np", "latent"], optional The output type of image, by default "pil". - lcm_lora_id : Optional[str], optional - The lcm_lora_id to load, by default None. - If None, the default LCM-LoRA - ("latent-consistency/lcm-lora-sdv1-5") will be used. vae_id : Optional[str], optional The vae_id to load, by default None. If None, the default TinyVAE ("madebyollin/taesd") will be used. device : Literal["cpu", "cuda"], optional The device to use for inference, by default "cuda". + device_ids : Optional[List[int]], optional + The device ids to use for DataParallel, by default None. dtype : torch.dtype, optional The dtype for inference, by default torch.float16. 
frame_buffer_size : int, optional @@ -159,8 +163,11 @@ def __init__( by default True. device_ids : Optional[List[int]], optional The device ids to use for DataParallel, by default None. - use_lcm_lora : bool, optional - Whether to use LCM-LoRA or not, by default True. + use_lcm_lora : Optional[bool], optional + DEPRECATED: Use lora_dict instead. For backwards compatibility only. + If True, automatically adds appropriate LCM LoRA to lora_dict based on model type. + SDXL models get "latent-consistency/lcm-lora-sdxl", others get "latent-consistency/lcm-lora-sdv1-5". + By default None (ignored). use_tiny_vae : bool, optional Whether to use TinyVAE or not, by default True. enable_similar_image_filter : bool, optional @@ -179,19 +186,42 @@ def __init__( seed : int, optional The seed, by default 2. use_safety_checker : bool, optional - Whether to use safety checker or not, by default False. Only supported for TensorRT acceleration. + Whether to use safety checker or not, by default False. + skip_diffusion : bool, optional + Whether to skip diffusion and apply only preprocessing/postprocessing hooks, by default False. + engine_dir : Optional[Union[str, Path]], optional + Directory path for storing/loading TensorRT engines, by default "engines". + build_engines_if_missing : bool, optional + Whether to build TensorRT engines if they don't exist, by default True. normalize_prompt_weights : bool, optional Whether to normalize prompt weights in blending to sum to 1, by default True. When False, weights > 1 will amplify embeddings. normalize_seed_weights : bool, optional Whether to normalize seed weights in blending to sum to 1, by default True. When False, weights > 1 will amplify noise. + scheduler : Literal["lcm", "tcd"], optional + The scheduler type to use for denoising, by default "lcm". + sampler : Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"], optional + The sampler type to use for noise scheduling, by default "normal". use_controlnet : bool, optional Whether to enable ControlNet support, by default False. controlnet_config : Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional ControlNet configuration(s), by default None. Can be a single config dict or list of config dicts for multiple ControlNets. Each config should contain: model_id, preprocessor (optional), conditioning_scale, etc. + use_ipadapter : bool, optional + Whether to enable IPAdapter support, by default False. + ipadapter_config : Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional + IPAdapter configuration(s), by default None. Can be a single config dict + or list of config dicts for multiple IPAdapters. + image_preprocessing_config : Optional[Dict[str, Any]], optional + Configuration for image preprocessing hooks, by default None. + image_postprocessing_config : Optional[Dict[str, Any]], optional + Configuration for image postprocessing hooks, by default None. + latent_preprocessing_config : Optional[Dict[str, Any]], optional + Configuration for latent preprocessing hooks, by default None. + latent_postprocessing_config : Optional[Dict[str, Any]], optional + Configuration for latent postprocessing hooks, by default None. safety_checker_fallback_type : Literal["blank", "previous"], optional Whether to use a blank image or the previous image as a fallback, by default "previous". 
safety_checker_threshold: float, optional @@ -201,7 +231,10 @@ def __init__( """ if compile_engines_only: logger.info("compile_engines_only is True, will only compile engines and not load the model") - + + # Store use_lcm_lora for backwards compatibility processing in _load_model + self.use_lcm_lora = use_lcm_lora + self.sd_turbo = "turbo" in model_id_or_path self.use_controlnet = use_controlnet self.use_ipadapter = use_ipadapter @@ -255,18 +288,19 @@ def __init__( self.stream: StreamDiffusion = self._load_model( model_id_or_path=model_id_or_path, lora_dict=lora_dict, - lcm_lora_id=lcm_lora_id, vae_id=vae_id, t_index_list=t_index_list, acceleration=acceleration, do_add_noise=do_add_noise, - use_lcm_lora=use_lcm_lora, + use_lcm_lora=use_lcm_lora, # Deprecated:Backwards compatibility use_tiny_vae=use_tiny_vae, cfg_type=cfg_type, engine_dir=engine_dir, build_engines_if_missing=build_engines_if_missing, normalize_prompt_weights=normalize_prompt_weights, normalize_seed_weights=normalize_seed_weights, + scheduler=scheduler, + sampler=sampler, use_controlnet=use_controlnet, controlnet_config=controlnet_config, use_ipadapter=use_ipadapter, @@ -802,33 +836,57 @@ def postprocess_image( def _denormalize_on_gpu(self, image_tensor: torch.Tensor) -> torch.Tensor: """ - Denormalize image tensor on GPU for efficiency - + Denormalize image tensor on GPU for efficiency. - Args: - image_tensor: Input tensor on GPU + Converts image tensor from diffusion range [-1, 1] to standard image range [0, 1]. + Parameters + ---------- + image_tensor : torch.Tensor + Input tensor in diffusion range [-1, 1], expected to be on GPU. - Returns: - Denormalized tensor on GPU, clamped to [0,1] + Returns + ------- + torch.Tensor + Denormalized tensor in range [0, 1], clamped and on GPU. """ return (image_tensor / 2 + 0.5).clamp(0, 1) def _normalize_on_gpu(self, image_tensor: torch.Tensor) -> torch.Tensor: - """Convert tensor from [0,1] (processor range) back to [-1,1] (diffusion range)""" + """ + Normalize tensor from processor range to diffusion range. + + Converts image tensor from standard image range [0, 1] to diffusion range [-1, 1]. + + Parameters + ---------- + image_tensor : torch.Tensor + Input tensor in standard image range [0, 1], expected to be on GPU. + + Returns + ------- + torch.Tensor + Normalized tensor in diffusion range [-1, 1], clamped and on GPU. + """ return (image_tensor * 2 - 1).clamp(-1, 1) def _tensor_to_pil_optimized(self, image_tensor: torch.Tensor) -> List[Image.Image]: """ - Optimized tensor to PIL conversion with minimal CPU transfers + Optimized tensor to PIL conversion with minimal CPU transfers. + Efficiently converts a batch of GPU tensors to PIL Images with minimal + CPU-GPU transfers and memory allocations. - Args: - image_tensor: Input tensor on GPU - + Parameters + ---------- + image_tensor : torch.Tensor + Input tensor in diffusion range [-1, 1], expected to be on GPU. + Shape should be (batch_size, channels, height, width). - Returns: - List of PIL Images + Returns + ------- + List[Image.Image] + List of PIL RGB images, one for each item in the batch. """ # Denormalize on GPU first denormalized = self._denormalize_on_gpu(image_tensor) @@ -866,6 +924,23 @@ def _tensor_to_pil_optimized(self, image_tensor: torch.Tensor) -> List[Image.Ima return pil_images def set_nsfw_fallback_img(self, height: int, width: int) -> None: + """ + Set the NSFW fallback image used when safety checker blocks content. 
+ + Creates a black RGB image of the specified dimensions that will be returned + when the safety checker determines content should be blocked. + + Parameters + ---------- + height : int + Height of the fallback image in pixels. + width : int + Width of the fallback image in pixels. + + Returns + ------- + None + """ self.nsfw_fallback_img = Image.new("RGB", (height, width), (0, 0, 0)) if self.output_type == "pt": self.nsfw_fallback_img = torch.from_numpy(np.array(self.nsfw_fallback_img)).unsqueeze(0) @@ -877,7 +952,6 @@ def _load_model( model_id_or_path: str, t_index_list: List[int], lora_dict: Optional[Dict[str, float]] = None, - lcm_lora_id: Optional[str] = None, vae_id: Optional[str] = None, acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", do_add_noise: bool = True, @@ -888,6 +962,8 @@ def _load_model( build_engines_if_missing: bool = True, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, + scheduler: Literal["lcm", "tcd"] = "lcm", + sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", use_controlnet: bool = False, controlnet_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, use_ipadapter: bool = False, @@ -906,7 +982,7 @@ def _load_model( This method does the following: 1. Loads the model from the model_id_or_path. - 2. Loads and fuses the LCM-LoRA model from the lcm_lora_id if needed. + 2. Loads and fuses LoRA models from lora_dict if provided. 3. Loads the VAE model from the vae_id if needed. 4. Enables acceleration if needed. 5. Prepares the model for inference. @@ -916,42 +992,72 @@ def _load_model( Parameters ---------- model_id_or_path : str - The model id or path to load. + The model id or path to load. Can be a Hugging Face model ID, local path to + safetensors/ckpt file, or directory containing model files. t_index_list : List[int] - The t_index_list to use for inference. + The t_index_list to use for inference. Specifies which denoising timesteps + to use from the diffusion schedule. lora_dict : Optional[Dict[str, float]], optional The lora_dict to load, by default None. Keys are the LoRA names and values are the LoRA scales. Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...} - lcm_lora_id : Optional[str], optional - The lcm_lora_id to load, by default None. + Use this to load LCM LoRA: {'latent-consistency/lcm-lora-sdv1-5': 1.0} vae_id : Optional[str], optional - The vae_id to load, by default None. - acceleration : Literal["none", "xfomers", "sfast", "tensorrt"], optional - The acceleration method, by default "tensorrt". - warmup : int, optional - The number of warmup steps to perform, by default 10. + The vae_id to load, by default None. If None, uses default TinyVAE + ("madebyollin/taesd" for SD1.5, "madebyollin/taesdxl" for SDXL). + acceleration : Literal["none", "xformers", "tensorrt"], optional + The acceleration method, by default "tensorrt". Note: docstring shows + "xfomers" and "sfast" but code uses "xformers". do_add_noise : bool, optional Whether to add noise for following denoising steps or not, by default True. use_lcm_lora : bool, optional - Whether to use LCM-LoRA or not, by default True. + DEPRECATED: Use lora_dict instead. For backwards compatibility only. + If True, automatically adds appropriate LCM LoRA to lora_dict based on model type. + SDXL models get "latent-consistency/lcm-lora-sdxl", others get "latent-consistency/lcm-lora-sdv1-5". + By default None (ignored). use_tiny_vae : bool, optional - Whether to use TinyVAE or not, by default True. 
- cfg_type : Literal["none", "full", "self", "initialize"], - optional + Whether to use TinyVAE or not, by default True. TinyVAE is a distilled, + smaller VAE model that provides faster encoding/decoding with minimal quality loss. + cfg_type : Literal["none", "full", "self", "initialize"], optional The cfg_type for img2img mode, by default "self". You cannot use anything other than "none" for txt2img mode. - seed : int, optional - The seed, by default 2. + engine_dir : Optional[Union[str, Path]], optional + Directory path for storing/loading TensorRT engines, by default "engines". + build_engines_if_missing : bool, optional + Whether to build TensorRT engines if they don't exist, by default True. + normalize_prompt_weights : bool, optional + Whether to normalize prompt weights in blending to sum to 1, by default True. + When False, weights > 1 will amplify embeddings. + normalize_seed_weights : bool, optional + Whether to normalize seed weights in blending to sum to 1, by default True. + When False, weights > 1 will amplify noise. + scheduler : Literal["lcm", "tcd"], optional + The scheduler type to use for denoising, by default "lcm". + sampler : Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"], optional + The sampler type to use for noise scheduling, by default "normal". use_controlnet : bool, optional - Whether to apply ControlNet patch, by default False. + Whether to enable ControlNet support, by default False. controlnet_config : Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional - ControlNet configuration(s), by default None. + ControlNet configuration(s), by default None. Can be a single config dict + or list of config dicts for multiple ControlNets. use_ipadapter : bool, optional - Whether to apply IPAdapter patch, by default False. + Whether to enable IPAdapter support, by default False. ipadapter_config : Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional - IPAdapter configuration(s), by default None. + IPAdapter configuration(s), by default None. Can be a single config dict + or list of config dicts for multiple IPAdapters. + image_preprocessing_config : Optional[Dict[str, Any]], optional + Configuration for image preprocessing hooks, by default None. + image_postprocessing_config : Optional[Dict[str, Any]], optional + Configuration for image postprocessing hooks, by default None. + latent_preprocessing_config : Optional[Dict[str, Any]], optional + Configuration for latent preprocessing hooks, by default None. + latent_postprocessing_config : Optional[Dict[str, Any]], optional + Configuration for latent postprocessing hooks, by default None. + safety_checker_model_id : Optional[str], optional + Model ID for the safety checker, by default "Freepik/nsfw_image_detector". + compile_engines_only : bool, optional + Whether to only compile engines and not load the model, by default False. 
Returns ------- @@ -1045,6 +1151,11 @@ def _load_model( pipe.text_encoder = pipe.text_encoder.to(device=self.device) if hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2 is not None: pipe.text_encoder_2 = pipe.text_encoder_2.to(device=self.device) + # Move main pipeline components to device, but skip UNet for TensorRT + if hasattr(pipe, "unet") and pipe.unet is not None and acceleration != "tensorrt": + pipe.unet = pipe.unet.to(device=self.device) + if hasattr(pipe, "vae") and pipe.vae is not None and acceleration != "tensorrt": + pipe.vae = pipe.vae.to(device=self.device) # If we get here, the model loaded successfully - break out of retry loop logger.info(f"Model loading succeeded") @@ -1063,6 +1174,26 @@ def _load_model( self._is_sdxl = is_sdxl logger.info(f"_load_model: Detected model type: {model_type} (confidence: {confidence:.2f})") + + # DEPRECATED: THIS WILL LOAD LCM_LORA IF USE_LCM_LORA IS TRUE + # Validate backwards compatibility LCM LoRA selection using proper model detection + if hasattr(self, 'use_lcm_lora') and self.use_lcm_lora is not None: + if self.use_lcm_lora and not self.sd_turbo and lora_dict is not None: + # Determine correct LCM LoRA based on actual model detection + lcm_lora = "latent-consistency/lcm-lora-sdxl" if is_sdxl else "latent-consistency/lcm-lora-sdv1-5" + + # Add to lora_dict if not already present + if lcm_lora not in lora_dict: + lora_dict[lcm_lora] = 1.0 + logger.info(f"Added {lcm_lora} with scale 1.0 to lora_dict") + else: + logger.info(f"LCM LoRA {lcm_lora} already present in lora_dict with scale {lora_dict[lcm_lora]}") + else: + logger.info(f"LCM LoRA will not be loaded because use_lcm_lora is {self.use_lcm_lora} and sd_turbo is {self.sd_turbo}") + + # Remove use_lcm_lora from self + self.use_lcm_lora = None + logger.info(f"use_lcm_lora has been removed from self") stream = StreamDiffusion( pipe=pipe, @@ -1075,32 +1206,67 @@ def _load_model( frame_buffer_size=self.frame_buffer_size, use_denoising_batch=self.use_denoising_batch, cfg_type=cfg_type, + lora_dict=lora_dict, # We pass this to include loras in engine path names normalize_prompt_weights=normalize_prompt_weights, normalize_seed_weights=normalize_seed_weights, + scheduler=scheduler, + sampler=sampler, ) - if not self.sd_turbo: - if use_lcm_lora: - if lcm_lora_id is not None: - stream.load_lcm_lora( - pretrained_model_name_or_path_or_dict=lcm_lora_id - ) - else: - stream.load_lcm_lora() - stream.fuse_lora() - if lora_dict is not None: - for lora_name, lora_scale in lora_dict.items(): - stream.load_lora(lora_name) - stream.fuse_lora(lora_scale=lora_scale) + + # Load and properly merge LoRA weights using the standard diffusers approach + lora_adapters_to_merge = [] + lora_scales_to_merge = [] + + # Collect all LoRA adapters and their scales from lora_dict + if lora_dict is not None: + for i, (lora_name, lora_scale) in enumerate(lora_dict.items()): + adapter_name = f"custom_lora_{i}" + logger.info(f"_load_model: Loading LoRA '{lora_name}' with scale {lora_scale}") + + try: + # Load LoRA weights with unique adapter name + stream.pipe.load_lora_weights(lora_name, adapter_name=adapter_name) + lora_adapters_to_merge.append(adapter_name) + lora_scales_to_merge.append(lora_scale) + logger.info(f"Successfully loaded LoRA adapter: {adapter_name}") + except Exception as e: + logger.error(f"Failed to load LoRA {lora_name}: {e}") + # Continue with other LoRAs even if one fails + continue + + # Merge all LoRA adapters using the proper diffusers method + if lora_adapters_to_merge: + try: + for 
adapter_name, scale in zip(lora_adapters_to_merge, lora_scales_to_merge): + logger.info(f"Merging individual LoRA: {adapter_name} with scale {scale}") + stream.pipe.fuse_lora(lora_scale=scale, adapter_names=[adapter_name]) + + # Clean up after individual merging + stream.pipe.unload_lora_weights() + logger.info("Successfully merged LoRAs individually") + + except Exception as fallback_error: + logger.error(f"LoRA merging fallback also failed: {fallback_error}") + logger.warning("Continuing without LoRA merging - LoRAs may not be applied correctly") + + # Clean up any partial state + try: + stream.pipe.unload_lora_weights() + except: + pass if use_tiny_vae: if vae_id is not None: - stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(dtype=pipe.dtype) + stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(dtype=pipe.dtype, device=self.device) else: # Use TAESD XL for SDXL models, regular TAESD for SD 1.5 taesd_model = "madebyollin/taesdxl" if is_sdxl else "madebyollin/taesd" - stream.vae = AutoencoderTiny.from_pretrained(taesd_model).to(dtype=pipe.dtype) - + stream.vae = AutoencoderTiny.from_pretrained(taesd_model).to(dtype=pipe.dtype, device=self.device) + elif acceleration != "tensorrt": + # For non-TensorRT acceleration, ensure VAE is on device if it wasn't moved earlier + if hasattr(pipe, "vae") and pipe.vae is not None: + pipe.vae = pipe.vae.to(device=self.device) try: if acceleration == "xformers": @@ -1229,8 +1395,8 @@ def _load_model( max_batch_size=self.max_batch_size, min_batch_size=self.min_batch_size, mode=self.mode, - use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, + lora_dict=lora_dict, ipadapter_scale=ipadapter_scale, ipadapter_tokens=ipadapter_tokens, is_faceid=is_faceid if use_ipadapter_trt else None @@ -1241,8 +1407,8 @@ def _load_model( max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, mode=self.mode, - use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, + lora_dict=lora_dict, ipadapter_scale=ipadapter_scale, ipadapter_tokens=ipadapter_tokens, is_faceid=is_faceid if use_ipadapter_trt else None @@ -1253,8 +1419,8 @@ def _load_model( max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, mode=self.mode, - use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, + lora_dict=lora_dict, ipadapter_scale=ipadapter_scale, ipadapter_tokens=ipadapter_tokens, is_faceid=is_faceid if use_ipadapter_trt else None @@ -1306,10 +1472,15 @@ def _load_model( except Exception: pass - # If using TensorRT with IP-Adapter, ensure processors and weights are installed BEFORE export - if use_ipadapter_trt and has_ipadapter and ipadapter_config and not hasattr(stream, '_ipadapter_module'): + # Note: LoRA weights have already been merged permanently during model loading + + # CRITICAL: Install IPAdapter module BEFORE TensorRT compilation to ensure processors are baked into engines + if use_ipadapter and ipadapter_config and not hasattr(stream, '_ipadapter_module'): try: from streamdiffusion.modules.ipadapter_module import IPAdapterModule, IPAdapterConfig, IPAdapterType + logger.info("Installing IPAdapter module before TensorRT compilation...") + + # Use first config if list provided cfg = ipadapter_config[0] if isinstance(ipadapter_config, list) else ipadapter_config ip_cfg = IPAdapterConfig( style_image_key=cfg.get('style_image_key') or 'ipadapter_main', @@ 
@@ -1321,17 +1492,28 @@ def _load_model(
                     type=IPAdapterType(cfg.get('type', "regular")),
                     insightface_model_name=cfg.get('insightface_model_name'),
                 )
-                ip_module_for_export = IPAdapterModule(ip_cfg)
-                ip_module_for_export.install(stream)
-                setattr(stream, '_ipadapter_module', ip_module_for_export)
-                try:
-                    logger.info("Installed IP-Adapter processors prior to TensorRT export")
-                except Exception:
-                    pass
+                ip_module = IPAdapterModule(ip_cfg)
+                ip_module.install(stream)
+                # Expose for later updates
+                stream._ipadapter_module = ip_module
+                logger.info("IPAdapter module installed successfully before TensorRT compilation")
+
+                # Cleanup after IPAdapter installation
+                import gc
+                gc.collect()
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()
+
+            except torch.cuda.OutOfMemoryError as oom_error:
+                logger.error(f"CUDA Out of Memory during early IPAdapter installation: {oom_error}")
+                logger.error("Try reducing batch size, using smaller models, or increasing GPU memory")
+                raise RuntimeError("Insufficient VRAM for IPAdapter installation. Consider using a GPU with more memory or reducing model complexity.")
+
             except Exception:
                 import traceback
                 traceback.print_exc()
-                logger.error("Failed to pre-install IP-Adapter prior to TensorRT export")
+                logger.error("Failed to install IPAdapterModule before TensorRT compilation")
+                raise
 
         # NOTE: When IPAdapter is enabled, we must pass num_ip_layers. We cannot know it until after
         # installing processors in the export wrapper. We construct the wrapper first to discover it,
@@ -1555,13 +1737,13 @@ def _load_model(
                 logger.error(f"TensorRT VAE engine loading failed (non-OOM): {e}")
                 raise e
 
+            # Safety checker engine (TensorRT-specific)
             safety_checker_path = engine_manager.get_engine_path(
                 EngineType.SAFETY_CHECKER,
                 model_id_or_path=safety_checker_model_id,
                 max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size,
                 min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size,
                 mode=self.mode,
-                use_lcm_lora=use_lcm_lora,
                 use_tiny_vae=use_tiny_vae,
             )
             safety_checker_engine_exists = os.path.exists(safety_checker_path)
@@ -1594,7 +1776,7 @@ def _load_model(
                 cuda_stream,
                 use_cuda_graph=True,
             )
-
+
         if acceleration == "sfast":
             from streamdiffusion.acceleration.sfast import (
                 accelerate_with_stable_fast,
@@ -1678,6 +1860,8 @@ def _load_model(
                 logger.error("Failed to install ControlNetModule")
                 raise
 
+        # IPAdapter module installation has been moved to before TensorRT compilation (see lines 1307-1345)
+        # This ensures processors are properly baked into the TensorRT engines
         if use_ipadapter and ipadapter_config and not hasattr(stream, '_ipadapter_module'):
             try:
                 from streamdiffusion.modules.ipadapter_module import IPAdapterModule, IPAdapterConfig, IPAdapterType
@@ -1707,6 +1891,8 @@ def _load_model(
                 logger.error("Failed to install IPAdapterModule")
                 raise
 
+        # Note: LoRA weights have already been merged permanently during model loading
+
         # Install pipeline hook modules (Phase 4: Configuration Integration)
         if image_preprocessing_config and image_preprocessing_config.get('enabled', True):
             try:
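The early-install hunk above pairs the IPAdapter setup with explicit CUDA OOM handling: catch `torch.cuda.OutOfMemoryError` separately from generic failures, free cached allocations, and re-raise with actionable guidance. A generic sketch of that pattern, independent of this repo's APIs (`heavy_setup` is a placeholder callable, not a function from the patch):

```python
# Generic CUDA OOM-guard pattern; heavy_setup is a placeholder for any
# VRAM-heavy initialization step (e.g. installing adapter weights before engine export).
import gc
import torch

def run_with_oom_guard(heavy_setup):
    try:
        result = heavy_setup()
        # Release transient allocations before the next memory-hungry phase.
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        return result
    except torch.cuda.OutOfMemoryError as oom_error:
        # Convert the raw allocator failure into an actionable error for the caller.
        raise RuntimeError(
            "Insufficient VRAM during setup; reduce batch size or resolution, "
            "or use a GPU with more memory."
        ) from oom_error
```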
@@ -1778,7 +1964,6 @@ def update_control_image(self, index: int, image: Union[str, Image.Image, torch.
         else:
             logger.debug("update_control_image: Skipping ControlNet update in skip diffusion mode")
 
-
     def update_style_image(self, image: Union[str, Image.Image, torch.Tensor], is_stream: bool = False, style_key = "ipadapter_main") -> None:
         """Update IPAdapter style image"""
         if not self.use_ipadapter:
@@ -1791,7 +1976,6 @@ def update_style_image(self, image: Union[str, Image.Image, torch.Tensor], is_st
-
     def clear_caches(self) -> None:
         """Clear all cached prompt embeddings and seed noise tensors."""
         self.stream._param_updater.clear_caches()
@@ -2114,8 +2298,3 @@ def cleanup_engines_and_rebuild(self, reduce_batch_size: bool = True, reduce_res
             logger.info(f" Reduced resolution: {old_width}x{old_height} -> {self.width}x{self.height}")
             logger.info(" Next model load will rebuild engines with these smaller settings")
-
-
-
-
-
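Across these hunks, `use_lcm_lora` is dropped from every `get_engine_path` call and replaced by `lora_dict`, so cached TensorRT engines are keyed by the actual set of merged LoRAs rather than a single boolean. A hedged sketch of how such a cache key could be derived (illustrative only; the repo's real `EngineManager` may combine different fields):

```python
# Illustrative engine-cache key derivation; not the repo's actual EngineManager logic.
import hashlib
from typing import Dict, Optional

def engine_cache_key(
    model_id_or_path: str,
    max_batch_size: int,
    min_batch_size: int,
    mode: str,
    use_tiny_vae: bool,
    lora_dict: Optional[Dict[str, float]] = None,
) -> str:
    # Sort entries so the key is independent of dictionary insertion order,
    # avoiding spurious engine rebuilds for the same LoRA set.
    lora_part = ",".join(f"{name}:{scale}" for name, scale in sorted((lora_dict or {}).items()))
    raw = f"{model_id_or_path}|{max_batch_size}|{min_batch_size}|{mode}|{use_tiny_vae}|{lora_part}"
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:16]

# Example: different LoRA sets map to different engine directories.
print(engine_cache_key("sd15", 2, 1, "img2img", True, {"latent-consistency/lcm-lora-sdv1-5": 1.0}))
```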