def apply_file_caption_to_clip(clip_data: Any, file_obj: Any, *, dedupe: bool = True) -> bool:
    """Attach a Caption effect to ``clip_data`` when the file's metadata carries caption text.

    Args:
        clip_data: Mapping-style clip dictionary; its "effects" list is mutated in place.
        file_obj: Object with a ``data`` mapping (project File); ``data["caption"]`` is read.
        dedupe: When True, skip attaching if an identical Caption effect already exists.

    Returns:
        True when a new Caption effect was appended; False when the inputs are
        unusable, the caption text is empty, or a duplicate was found.
    """
    if not isinstance(clip_data, Mapping):
        return False
    meta = getattr(file_obj, "data", None)
    if not isinstance(meta, Mapping):
        return False
    text = str(meta.get("caption", "") or "").strip()
    if not text:
        return False

    # Normalize the effects container to a mutable list stored on the clip.
    effect_list = clip_data.get("effects")
    if not isinstance(effect_list, list):
        effect_list = list(effect_list) if effect_list else []
    clip_data["effects"] = effect_list

    if dedupe:
        # An identical Caption (same class, same stripped text) means no-op.
        for existing in effect_list:
            if not isinstance(existing, Mapping):
                continue
            is_caption = str(existing.get("class_name", "")).lower() == "caption"
            if is_caption and str(existing.get("caption_text", "") or "").strip() == text:
                return False

    effect_obj = openshot.EffectInfo().CreateEffect("Caption")
    effect_obj.Id(get_app().project.generate_id())
    effect_json = json.loads(effect_obj.Json())
    effect_json["caption_text"] = text
    effect_list.append(effect_json)
    return True
+""" + +import json +import os +import ssl +import base64 +import uuid +from datetime import datetime +import re +import socket +import struct +from urllib.error import HTTPError +from urllib.request import Request, urlopen +from urllib.parse import quote, urlencode +from urllib.parse import urlparse + +from classes import info +from classes.logger import log + + +class ComfyProgressSocket: + """Minimal WebSocket client for ComfyUI /ws progress events.""" + + def __init__(self, base_url, client_id): + self.base_url = str(base_url or "").rstrip("/") + self.client_id = str(client_id or "") + self.sock = None + self._connect() + + def _connect(self): + parsed = urlparse(self.base_url) + scheme = parsed.scheme.lower() + host = parsed.hostname + if not host: + raise RuntimeError("Invalid ComfyUI URL for websocket") + port = parsed.port or (443 if scheme == "https" else 80) + base_path = (parsed.path or "").rstrip("/") + ws_path = "{}/ws".format(base_path) if base_path else "/ws" + path = "{}?clientId={}".format(ws_path, quote(self.client_id)) + + raw = socket.create_connection((host, port), timeout=6.0) + if scheme == "https": + ctx = ssl.create_default_context() + raw = ctx.wrap_socket(raw, server_hostname=host) + # Allow slower remote/proxied websocket handshakes. + raw.settimeout(6.0) + self.sock = raw + + key = base64.b64encode(os.urandom(16)).decode("ascii") + req = ( + "GET {} HTTP/1.1\r\n" + "Host: {}:{}\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Origin: {}://{}:{}\r\n" + "Pragma: no-cache\r\n" + "Cache-Control: no-cache\r\n" + "Sec-WebSocket-Key: {}\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n" + ).format(path, host, port, scheme, host, port, key) + self.sock.sendall(req.encode("utf-8")) + + response = self._recv_http_headers() + if " 101 " not in response.split("\r\n", 1)[0]: + raise RuntimeError("WebSocket upgrade failed: {}".format(response.split("\r\n", 1)[0])) + # Use short timeout for regular frame polling after successful handshake. 
+ self.sock.settimeout(0.25) + + def close(self): + if self.sock is not None: + try: + self.sock.close() + except OSError: + pass + self.sock = None + + def poll_progress(self, prompt_id=None, max_messages=8): + """Read available frames and return latest progress payload. + + If prompt_id is provided, events are filtered to that prompt id. + If prompt_id is None/empty, events from any prompt on this websocket + are accepted (useful for meta-batch follow-up prompts). + """ + if not self.sock: + return None + latest = None + latest_rank = None + prompt_key = str(prompt_id or "").strip() + for _ in range(max_messages): + frame = self._recv_frame_nonblocking() + if frame is None: + break + opcode, payload = frame + + # Ping -> pong + if opcode == 0x9: + self._send_control_frame(0xA, payload) + continue + if opcode == 0x8: + self.close() + break + if opcode != 0x1: + continue + try: + msg = json.loads(payload.decode("utf-8")) + except Exception: + continue + if not isinstance(msg, dict): + continue + + event_type = msg.get("type") + event_data = msg.get("data", {}) + if event_type == "progress": + if not isinstance(event_data, dict): + continue + event_prompt = str(event_data.get("prompt_id", "")) + if prompt_key and (not event_prompt or event_prompt != prompt_key): + continue + value = float(event_data.get("value", 0.0)) + maximum = float(event_data.get("max", 0.0)) + if maximum > 0: + candidate = { + "percent": int(max(0, min(99, round((value / maximum) * 100.0)))), + "value": value, + "max": maximum, + "node": str(event_data.get("node", "")), + "type": "progress", + "prompt_id": event_prompt, + } + # Prefer unfinished progress, and prefer explicit "progress" events. 
+ unfinished = (value + 1e-6) < maximum + rank = (1 if unfinished else 0, 2, maximum) + if latest is None or rank > latest_rank: + latest = candidate + latest_rank = rank + elif event_type == "progress_state": + # Newer Comfy events: data={prompt_id, nodes={node_id:{value,max}}} + if not isinstance(event_data, dict): + continue + event_prompt = str(event_data.get("prompt_id", "")) + if prompt_key and (not event_prompt or event_prompt != prompt_key): + continue + nodes = event_data.get("nodes", {}) + if not isinstance(nodes, dict): + continue + # Prefer unfinished node progress; only fall back to completed states. + best = None + best_rank = None + for node_id, node_state in nodes.items(): + if not isinstance(node_state, dict): + continue + value = float(node_state.get("value", 0.0)) + maximum = float(node_state.get("max", 0.0)) + if maximum > 0: + candidate = { + "percent": int(max(0, min(99, round((value / maximum) * 100.0)))), + "value": value, + "max": maximum, + "node": str(node_id), + "type": "progress_state", + "prompt_id": event_prompt, + } + unfinished = (value + 1e-6) < maximum + rank = (1 if unfinished else 0, maximum) + if best is None or rank > best_rank: + best = candidate + best_rank = rank + if best is not None: + rank = (best_rank[0], 1, float(best.get("max", 0.0))) + if latest is None or rank > latest_rank: + latest = best + latest_rank = rank + return latest + + def _recv_http_headers(self): + data = b"" + while b"\r\n\r\n" not in data: + chunk = self.sock.recv(4096) + if not chunk: + break + data += chunk + if len(data) > 65536: + break + return data.decode("utf-8", errors="replace") + + def _recv_exact(self, size): + chunks = [] + remaining = size + while remaining > 0: + chunk = self.sock.recv(remaining) + if not chunk: + raise RuntimeError("WebSocket connection closed") + chunks.append(chunk) + remaining -= len(chunk) + return b"".join(chunks) + + def _recv_frame_nonblocking(self): + try: + header = self.sock.recv(2) + if not header: + return 
None + except socket.timeout: + return None + except OSError: + return None + + if len(header) < 2: + return None + b1, b2 = header[0], header[1] + opcode = b1 & 0x0F + masked = (b2 & 0x80) != 0 + length = b2 & 0x7F + + if length == 126: + length = struct.unpack("!H", self._recv_exact(2))[0] + elif length == 127: + length = struct.unpack("!Q", self._recv_exact(8))[0] + + mask_key = b"" + if masked: + mask_key = self._recv_exact(4) + + payload = self._recv_exact(length) if length > 0 else b"" + if masked and payload: + payload = bytes(payload[i] ^ mask_key[i % 4] for i in range(len(payload))) + + return opcode, payload + + def _send_control_frame(self, opcode, payload=b""): + if self.sock is None: + return + payload = payload or b"" + first = 0x80 | (opcode & 0x0F) + # Client frames must be masked. + mask = os.urandom(4) + length = len(payload) + if length < 126: + header = bytes([first, 0x80 | length]) + elif length < (1 << 16): + header = bytes([first, 0x80 | 126]) + struct.pack("!H", length) + else: + header = bytes([first, 0x80 | 127]) + struct.pack("!Q", length) + masked_payload = bytes(payload[i] ^ mask[i % 4] for i in range(length)) + self.sock.sendall(header + mask + masked_payload) + + +class ComfyClient: + """Minimal ComfyUI client using stdlib HTTP.""" + ERROR_MAX_CHARS = 1800 + + def __init__(self, base_url): + self.base_url = str(base_url or "").rstrip("/") + + @staticmethod + def _write_debug_error(payload): + debug_dir = info.COMFYUI_PATH + try: + os.makedirs(debug_dir, exist_ok=True) + debug_path = os.path.join(debug_dir, "debug_error.json") + with open(debug_path, "w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + except Exception: + log.warning("Failed writing Comfy debug error payload", exc_info=True) + + def _write_debug_prompt_payload(self, prompt_graph, client_id): + debug_dir = info.COMFYUI_PATH + try: + os.makedirs(debug_dir, exist_ok=True) + debug_path = os.path.join(debug_dir, 
"debug.json") + payload = { + "generated_at_utc": datetime.utcnow().isoformat() + "Z", + "comfy_url": self.base_url, + "client_id": str(client_id or ""), + "prompt": prompt_graph, + } + with open(debug_path, "w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + except Exception: + log.warning("Failed writing Comfy sent prompt payload", exc_info=True) + + @staticmethod + def open_progress_socket(base_url, client_id): + return ComfyProgressSocket(base_url, client_id) + + def ping(self, timeout=0.5): + with urlopen("{}/system_stats".format(self.base_url), timeout=timeout) as response: + return int(response.status) >= 200 and int(response.status) < 300 + + def queue_prompt(self, prompt_graph, client_id): + prompt_graph = self._rewrite_prompt_local_file_inputs(prompt_graph) + self._write_debug_prompt_payload(prompt_graph, client_id) + payload = json.dumps({"prompt": prompt_graph, "client_id": client_id}).encode("utf-8") + req = Request( + "{}/prompt".format(self.base_url), + data=payload, + method="POST", + headers={"Content-Type": "application/json"}, + ) + try: + with urlopen(req, timeout=10.0) as response: + data = json.loads(response.read().decode("utf-8")) + except HTTPError as ex: + details = "" + try: + error_data = json.loads(ex.read().decode("utf-8")) + ComfyClient._write_debug_error(error_data) + error_obj = error_data.get("error", {}) + if isinstance(error_obj, dict): + details = error_obj.get("type") or error_obj.get("message") or "" + else: + details = str(error_obj or "") + + node_errors = error_data.get("node_errors", {}) + node_error_text = self._format_node_errors(node_errors) + if node_error_text: + details = "{}\n{}".format(details or "prompt validation failed", node_error_text) + elif not details: + details = ComfyClient.summarize_error_text(error_data) + else: + details = "{}\n{}".format(details, ComfyClient.summarize_error_text(error_data)) + except Exception: + details = str(ex) + raise 
RuntimeError("ComfyUI prompt rejected: {}".format(ComfyClient.summarize_error_text(details))) + return data.get("prompt_id") + + def _rewrite_prompt_local_file_inputs(self, prompt_graph): + """Rewrite local absolute paths for image/video loader nodes to uploaded Comfy input refs.""" + if not isinstance(prompt_graph, dict): + return prompt_graph + rewritten = dict(prompt_graph) + + def _annotated(path_text): + path_text = str(path_text or "").strip() + return path_text.endswith("[input]") or path_text.endswith("[output]") or path_text.endswith("[temp]") + + for node_id, node in rewritten.items(): + if not isinstance(node, dict): + continue + class_type = str(node.get("class_type", "")) + inputs = node.get("inputs", {}) + if not isinstance(inputs, dict): + continue + + if class_type == "LoadImage": + image_path = str(inputs.get("image", "")).strip() + if image_path and os.path.isabs(image_path) and os.path.exists(image_path) and not _annotated(image_path): + uploaded = self.upload_input_file(image_path) + inputs["image"] = uploaded + node["inputs"] = inputs + rewritten[node_id] = node + log.debug("ComfyClient rewrote LoadImage input node=%s path=%s -> %s", str(node_id), image_path, uploaded) + elif class_type == "LoadVideo": + video_path = str(inputs.get("file", "")).strip() + if video_path and os.path.isabs(video_path) and os.path.exists(video_path) and not _annotated(video_path): + uploaded = self.upload_input_file(video_path) + inputs["file"] = uploaded + node["inputs"] = inputs + rewritten[node_id] = node + log.debug("ComfyClient rewrote LoadVideo input node=%s path=%s -> %s", str(node_id), video_path, uploaded) + elif class_type in ("VHS_LoadVideo", "VHS_LoadVideoPath", "VHS_LoadVideoFFmpegPath"): + video_path = str(inputs.get("video", "")).strip() + if video_path and os.path.isabs(video_path) and os.path.exists(video_path) and not _annotated(video_path): + uploaded = self.upload_input_file(video_path) + # VHS_LoadVideo expects a plain filename from Comfy input 
options. + # Path-based VHS loaders accept a plain relative path as well. + if uploaded.endswith(" [input]"): + uploaded = uploaded[:-8].strip() + inputs["video"] = uploaded + node["inputs"] = inputs + rewritten[node_id] = node + log.debug( + "ComfyClient rewrote %s input node=%s path=%s -> %s", + class_type, + str(node_id), + video_path, + uploaded, + ) + + return rewritten + + @staticmethod + def _format_node_errors(node_errors): + if not isinstance(node_errors, dict) or not node_errors: + return "" + lines = [] + max_lines = 8 + for node_id, err in node_errors.items(): + if len(lines) >= max_lines: + break + if not isinstance(err, dict): + lines.append("node {}: {}".format(node_id, str(err))) + continue + err_type = str(err.get("type", "")).strip() + message = str(err.get("message", "")).strip() + if not message: + details = err.get("details") + if details: + message = str(details) + if err_type and message: + lines.append("node {} [{}]: {}".format(node_id, err_type, message)) + elif message: + lines.append("node {}: {}".format(node_id, message)) + elif err_type: + lines.append("node {} [{}]".format(node_id, err_type)) + if not lines: + return "" + return "Node validation errors: {}".format(" | ".join(lines)) + + @staticmethod + def summarize_error_text(value, max_chars=None): + """Return a compact Comfy error text safe for UI display.""" + if max_chars is None: + max_chars = ComfyClient.ERROR_MAX_CHARS + + if isinstance(value, (dict, list, tuple)): + value = ComfyClient._limit_error_structure(value) + try: + text = json.dumps(value, ensure_ascii=True) + except Exception: + text = str(value) + else: + text = str(value or "") + + # Remove huge numeric/tensor dumps that make dialogs unreadable. 
+ text = re.sub(r"tensor\(\[[\s\S]{250,}?\]\)", "tensor([])", text) + text = re.sub(r"array\(\[[\s\S]{250,}?\]\)", "array([])", text) + text = re.sub(r"\[[\d\.\-eE,\s]{350,}\]", "[]", text) + text = re.sub(r"\s+", " ", text).strip() + + max_chars = max(300, int(max_chars)) + if len(text) > max_chars: + truncated = len(text) - max_chars + text = "{} ... [truncated {} chars]".format(text[:max_chars], truncated) + return text + + @staticmethod + def _limit_error_structure(value, depth=0, max_depth=4, max_items=10, max_str=260): + if depth >= max_depth: + return "<...>" + if isinstance(value, dict): + out = {} + for index, key in enumerate(value.keys()): + if index >= max_items: + out[""] = len(value) - max_items + break + out[str(key)] = ComfyClient._limit_error_structure( + value.get(key), + depth=depth + 1, + max_depth=max_depth, + max_items=max_items, + max_str=max_str, + ) + return out + if isinstance(value, (list, tuple)): + out = [] + for index, item in enumerate(value): + if index >= max_items: + out.append("".format(len(value) - max_items)) + break + out.append( + ComfyClient._limit_error_structure( + item, + depth=depth + 1, + max_depth=max_depth, + max_items=max_items, + max_str=max_str, + ) + ) + return out + text = str(value) + if len(text) > max_str: + return text[:max_str] + "..." 
+ return text + + def list_checkpoints(self): + """Return available checkpoint names from ComfyUI object info.""" + with urlopen("{}/object_info/CheckpointLoaderSimple".format(self.base_url), timeout=8.0) as response: + data = json.loads(response.read().decode("utf-8")) + + # Expected path: + # CheckpointLoaderSimple -> input -> required -> ckpt_name + # Supports multiple schema variants: + # 1) [ [..names..], {...meta...} ] + # 2) ["COMBO", {"options":[..names..], ...}] + node_info = data.get("CheckpointLoaderSimple", {}) + required = node_info.get("input", {}).get("required", {}) + ckpt_input = required.get("ckpt_name", None) + values = self._extract_combo_options(ckpt_input) + return [str(v) for v in values if str(v).strip()] + + def list_upscale_models(self): + """Return available upscaler model names from ComfyUI object info.""" + models = [] + # Primary source: object_info schema for UpscaleModelLoader. + try: + with urlopen("{}/object_info/UpscaleModelLoader".format(self.base_url), timeout=8.0) as response: + data = json.loads(response.read().decode("utf-8")) + + node_info = data.get("UpscaleModelLoader", {}) + required = node_info.get("input", {}).get("required", {}) + model_input = required.get("model_name", None) + values = self._extract_combo_options(model_input) + if values: + models = [str(v) for v in values if str(v).strip()] + except Exception as ex: + log.debug("ComfyClient list_upscale_models object_info parse failed: %s", ex) + + # Fallback: direct model listing endpoint. + if not models: + try: + with urlopen("{}/models/upscale_models".format(self.base_url), timeout=8.0) as response: + data = json.loads(response.read().decode("utf-8")) + if isinstance(data, list): + models = [str(v) for v in data if str(v).strip()] + except Exception as ex: + log.debug("ComfyClient list_upscale_models /models fallback failed: %s", ex) + + # Dedupe while preserving order. 
+ seen = set() + ordered = [] + for name in models: + if name not in seen: + seen.add(name) + ordered.append(name) + return ordered + + def list_clip_models(self): + """Return available CLIP/text-encoder model names from ComfyUI object info.""" + with urlopen("{}/object_info/CLIPLoader".format(self.base_url), timeout=8.0) as response: + data = json.loads(response.read().decode("utf-8")) + + node_info = data.get("CLIPLoader", {}) + required = node_info.get("input", {}).get("required", {}) + clip_input = required.get("clip_name", None) + values = self._extract_combo_options(clip_input) + return [str(v) for v in values if str(v).strip()] + + def list_clip_vision_models(self): + """Return available CLIP vision model names from ComfyUI object info.""" + with urlopen("{}/object_info/CLIPVisionLoader".format(self.base_url), timeout=8.0) as response: + data = json.loads(response.read().decode("utf-8")) + + node_info = data.get("CLIPVisionLoader", {}) + required = node_info.get("input", {}).get("required", {}) + clip_input = required.get("clip_name", None) + values = self._extract_combo_options(clip_input) + return [str(v) for v in values if str(v).strip()] + + def list_rife_vfi_models(self): + """Return available RIFE checkpoint names from ComfyUI object info.""" + node_type = "RIFE VFI" + with urlopen( + "{}/object_info/{}".format(self.base_url, quote(node_type, safe="")), + timeout=8.0, + ) as response: + data = json.loads(response.read().decode("utf-8")) + + node_info = data.get(node_type, {}) + required = node_info.get("input", {}).get("required", {}) + ckpt_input = required.get("ckpt_name", None) + values = self._extract_combo_options(ckpt_input) + return [str(v) for v in values if str(v).strip()] + + @staticmethod + def _extract_combo_options(input_config): + """Extract valid options from Comfy object_info input config variants.""" + if input_config is None: + return [] + + # Variant: [ [options...], {meta...} ] + if isinstance(input_config, list) and input_config 
and isinstance(input_config[0], list): + return [str(v) for v in input_config[0]] + + # Variant: ["COMBO", {"options":[...], ...}] + if ( + isinstance(input_config, list) + and len(input_config) >= 2 + and str(input_config[0]).upper() == "COMBO" + and isinstance(input_config[1], dict) + ): + options = input_config[1].get("options", []) + if isinstance(options, list): + return [str(v) for v in options] + + # Variant: direct list of values + if isinstance(input_config, list): + scalar_values = [] + for item in input_config: + if isinstance(item, (str, int, float)): + scalar_values.append(str(item)) + return scalar_values + + return [] + + def history(self, prompt_id): + with urlopen("{}/history/{}".format(self.base_url, quote(str(prompt_id))), timeout=10.0) as response: + return json.loads(response.read().decode("utf-8")) + + def history_all(self): + with urlopen("{}/history".format(self.base_url), timeout=10.0) as response: + return json.loads(response.read().decode("utf-8")) + + def progress(self): + """Return ComfyUI /progress payload.""" + try: + with urlopen("{}/progress".format(self.base_url), timeout=8.0) as response: + return json.loads(response.read().decode("utf-8")) + except HTTPError as ex: + if int(getattr(ex, "code", 0)) == 404: + # Some ComfyUI versions don't expose /progress. 
+ return None + raise + + def interrupt(self, prompt_id=None): + payload = {} + if prompt_id: + payload["prompt_id"] = str(prompt_id) + log.debug("ComfyClient interrupt request base_url=%s prompt_id=%s", self.base_url, payload.get("prompt_id", "")) + req = Request( + "{}/interrupt".format(self.base_url), + data=json.dumps(payload).encode("utf-8"), + method="POST", + headers={"Content-Type": "application/json"}, + ) + with urlopen(req, timeout=8.0) as response: + log.debug("ComfyClient interrupt response status=%s", int(response.status)) + return int(response.status) >= 200 and int(response.status) < 300 + + def cancel_prompt(self, prompt_id): + """Request ComfyUI to delete/cancel a prompt from the queue.""" + log.debug("ComfyClient cancel_prompt request base_url=%s prompt_id=%s", self.base_url, str(prompt_id)) + payload = json.dumps({"delete": [str(prompt_id)]}).encode("utf-8") + req = Request( + "{}/queue".format(self.base_url), + data=payload, + method="POST", + headers={"Content-Type": "application/json"}, + ) + with urlopen(req, timeout=8.0) as response: + log.debug("ComfyClient cancel_prompt response status=%s", int(response.status)) + return int(response.status) >= 200 and int(response.status) < 300 + + def queue(self): + """Return ComfyUI queue state.""" + with urlopen("{}/queue".format(self.base_url), timeout=10.0) as response: + return json.loads(response.read().decode("utf-8")) + + def upload_input_file(self, local_path): + """Upload a local file into ComfyUI input dir via /upload/image.""" + local_path = str(local_path or "").strip() + if not local_path or not os.path.exists(local_path): + raise RuntimeError("Local file does not exist: {}".format(local_path)) + + boundary = "----OpenShotComfy{}".format(uuid.uuid4().hex) + filename = os.path.basename(local_path) + parts = [] + + def _add_field(name, value): + parts.append("--{}\r\n".format(boundary).encode("utf-8")) + parts.append('Content-Disposition: form-data; 
name="{}"\r\n\r\n'.format(name).encode("utf-8")) + parts.append(str(value).encode("utf-8")) + parts.append(b"\r\n") + + _add_field("type", "input") + parts.append("--{}\r\n".format(boundary).encode("utf-8")) + parts.append( + ( + 'Content-Disposition: form-data; name="image"; filename="{}"\r\n' + "Content-Type: application/octet-stream\r\n\r\n" + ).format(filename).encode("utf-8") + ) + with open(local_path, "rb") as handle: + parts.append(handle.read()) + parts.append(b"\r\n") + parts.append("--{}--\r\n".format(boundary).encode("utf-8")) + body = b"".join(parts) + + req = Request( + "{}/upload/image".format(self.base_url), + data=body, + method="POST", + headers={"Content-Type": "multipart/form-data; boundary={}".format(boundary)}, + ) + with urlopen(req, timeout=30.0) as response: + data = json.loads(response.read().decode("utf-8")) + + name = str(data.get("name", "")).strip() + subfolder = str(data.get("subfolder", "")).strip() + if not name: + raise RuntimeError("ComfyUI upload failed: invalid response") + rel = "{}/{}".format(subfolder, name) if subfolder else name + return "{} [input]".format(rel) + + @staticmethod + def prompt_in_queue(prompt_id, queue_data): + """Check if prompt_id appears in queue_running/queue_pending payload.""" + pid = str(prompt_id) + if not isinstance(queue_data, dict): + return False + + for key in ("queue_running", "queue_pending"): + entries = queue_data.get(key, []) + if not isinstance(entries, list): + continue + for entry in entries: + # Common format: [number, prompt_id, ...] 
+ if isinstance(entry, list) and len(entry) >= 2 and str(entry[1]) == pid: + return True + # Defensive fallback for dict-like entries + if isinstance(entry, dict): + if str(entry.get("prompt_id", "")) == pid: + return True + return False + + @staticmethod + def extract_file_outputs(history_entry, save_node_ids=None): + """Return a flat list of file refs from image/video/audio history outputs.""" + outputs = [] + if not isinstance(history_entry, dict): + return outputs + node_outputs = history_entry.get("outputs", {}) + if not isinstance(node_outputs, dict): + return outputs + save_node_ids = set(str(node_id) for node_id in (save_node_ids or [])) + + for node_id, node_out in node_outputs.items(): + if save_node_ids and str(node_id) not in save_node_ids: + continue + if isinstance(node_out, dict): + for key in ("images", "videos", "video", "gifs", "audios", "audio", "files", "filenames"): + refs = node_out.get(key, []) + if not isinstance(refs, list): + continue + for ref in refs: + if not isinstance(ref, dict): + continue + if ref.get("filename"): + outputs.append({ + "filename": str(ref.get("filename")), + "subfolder": str(ref.get("subfolder", "")), + "type": str(ref.get("type", "output")), + }) + # Also extract text-like outputs (for custom nodes such as Whisper/SRT pipelines). + for value in node_out.values(): + text_values = ComfyClient._extract_text_outputs(value) + for text_value in text_values: + output_format = "srt" if ComfyClient._looks_like_srt(text_value) else "txt" + outputs.append({ + "text": text_value, + "format": output_format, + "type": "text", + }) + else: + # Some custom nodes emit list/string outputs directly instead of dicts. 
+ text_values = ComfyClient._extract_text_outputs(node_out) + for text_value in text_values: + output_format = "srt" if ComfyClient._looks_like_srt(text_value) else "txt" + outputs.append({ + "text": text_value, + "format": output_format, + "type": "text", + }) + return outputs + + @staticmethod + def extract_image_outputs(history_entry, save_node_ids=None): + return ComfyClient.extract_file_outputs(history_entry, save_node_ids=save_node_ids) + + @staticmethod + def _extract_text_output(value): + """Extract text payloads from common Comfy output structures.""" + values = ComfyClient._extract_text_outputs(value) + return values[0] if values else "" + + @staticmethod + def _extract_text_outputs(value): + """Extract one or more text payloads from common Comfy output structures.""" + if isinstance(value, str): + text = value.strip() + return [text] if text else [] + if isinstance(value, list): + out = [] + for item in value: + if isinstance(item, str): + text = item.strip() + if text: + out.append(text) + return out + if isinstance(value, dict): + out = [] + for key in ("srt", "text", "value"): + text = value.get(key) + if isinstance(text, str) and text.strip(): + out.append(text.strip()) + return out + return [] + + @staticmethod + def _looks_like_srt(text): + text = str(text or "") + if "-->" not in text: + return False + return bool(re.search(r"\d{2}:\d{2}:\d{2}[,.:]\d{3}\s+-->\s+\d{2}:\d{2}:\d{2}[,.:]\d{3}", text)) + + def download_output_file(self, file_ref, destination_path): + """Download a Comfy output reference to a local file path.""" + params = { + "filename": file_ref.get("filename", ""), + "subfolder": file_ref.get("subfolder", ""), + "type": file_ref.get("type", "output"), + } + url = "{}/view?{}".format(self.base_url, urlencode(params)) + with urlopen(url, timeout=10.0) as response: + data = response.read() + + os.makedirs(os.path.dirname(destination_path), exist_ok=True) + with open(destination_path, "wb") as handle: + handle.write(data) + + def 
download_image(self, image_ref, destination_path): + self.download_output_file(image_ref, destination_path) diff --git a/src/classes/comfy_pipelines.py b/src/classes/comfy_pipelines.py new file mode 100644 index 000000000..9c2fd816d --- /dev/null +++ b/src/classes/comfy_pipelines.py @@ -0,0 +1,468 @@ +""" + @file + @brief This file contains built-in ComfyUI pipeline definitions. + @author Jonathan Thomas + + @section LICENSE + + Copyright (c) 2008-2026 OpenShot Studios, LLC + (http://www.openshotstudios.com). This file is part of + OpenShot Video Editor (http://www.openshot.org), an open-source project + dedicated to delivering high quality video editing and animation solutions + to the world. + + OpenShot Video Editor is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + OpenShot Video Editor is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with OpenShot Library. If not, see . 
+""" + +import random +import os + + +RASTER_IMAGE_EXTENSIONS = { + ".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff", ".gif", +} + +DEFAULT_SD_CHECKPOINT = "sd_xl_turbo_1.0_fp16.safetensors" +DEFAULT_SD_BASE_CHECKPOINT = "sd_xl_base_1.0.safetensors" +DEFAULT_UPSCALE_MODEL = "RealESRGAN_x4plus.safetensors" +DEFAULT_STABLE_AUDIO_CHECKPOINT = "stable-audio-open-1.0.safetensors" +DEFAULT_STABLE_AUDIO_CLIP = "t5-base.safetensors" +DEFAULT_SVD_CHECKPOINT = "svd_xt.safetensors" +DEFAULT_RIFE_VFI_MODEL = "rife47.pth" + + +def is_supported_img2img_path(path): + path_text = str(path or "").strip() + # Comfy annotated paths can look like: "image.jpg [input]" + if path_text.endswith("]") and " [" in path_text: + path_text = path_text.rsplit(" [", 1)[0].strip() + ext = os.path.splitext(path_text)[1].lower() + return ext in RASTER_IMAGE_EXTENSIONS + + +def _supports_img2img(source_file=None): + if not source_file: + return False + if source_file.data.get("media_type") != "image": + return False + path = source_file.data.get("path", "") + return is_supported_img2img_path(path) + + +def _supports_video_upscale(source_file=None): + if not source_file: + return False + return source_file.data.get("media_type") == "video" + + +def available_pipelines(source_file=None): + pipelines = [ + {"id": "txt2img-basic", "name": "Basic Text to Image"}, + {"id": "txt2video-svd", "name": "Text to Video (txt_to_image_to_video)"}, + {"id": "txt2audio-stable-open", "name": "Text to Audio (Stable Audio Open)"}, + ] + if _supports_img2img(source_file): + pipelines.insert(0, {"id": "img2img-basic", "name": "Basic Image Variation"}) + pipelines.insert(1, {"id": "upscale-realesrgan-x4", "name": "Upscale Image (RealESRGAN x4)"}) + pipelines.insert(2, {"id": "img2video-svd", "name": "Image to Video (WAN 2.2 TI2V)"}) + if _supports_video_upscale(source_file): + pipelines.append({"id": "video-segment-scenes-transnet", "name": "Segment Scenes (TransNetV2)"}) + pipelines.append({"id": 
"video-frame-interpolation-rife2x", "name": "Frame Interpolation (RIFE 2x FPS)"}) + pipelines.append({"id": "video-upscale-gan", "name": "Upscale Video (GAN x4, first 10s)"}) + pipelines.append({"id": "video2video-basic", "name": "Video + Text to Video (Style Transfer)"}) + pipelines.append({"id": "video-whisper-srt", "name": "Whisper Transcribe to SRT (Caption Effect)"}) + return pipelines + + +def pipeline_requires_checkpoint(pipeline_id): + return str(pipeline_id or "") in ( + "txt2img-basic", + "img2img-basic", + "txt2audio-stable-open", + "txt2video-svd", + "video2video-basic", + ) + + +def pipeline_requires_upscale_model(pipeline_id): + return str(pipeline_id or "") in ("upscale-realesrgan-x4", "video-upscale-gan") + + +def pipeline_requires_stable_audio_clip(pipeline_id): + return str(pipeline_id or "") in ("txt2audio-stable-open",) + + +def pipeline_requires_svd_checkpoint(pipeline_id): + return str(pipeline_id or "") in ("txt2video-svd", "img2video-svd") + + +def pipeline_requires_rife_model(pipeline_id): + return str(pipeline_id or "") in ("video-frame-interpolation-rife2x",) + + +def build_workflow( + pipeline_id, + prompt_text, + source_path, + output_prefix, + checkpoint_name=None, + upscale_model_name=None, + stable_audio_clip_name=None, + svd_checkpoint_name=None, + source_fps=None, + rife_model_name=None, +): + prompt_text = str(prompt_text or "cinematic shot, highly detailed").strip() + if not prompt_text: + prompt_text = "cinematic shot, highly detailed" + output_prefix = str(output_prefix or "openshot_gen").strip() or "openshot_gen" + checkpoint_name = str(checkpoint_name or "").strip() or DEFAULT_SD_CHECKPOINT + upscale_model_name = str(upscale_model_name or "").strip() or DEFAULT_UPSCALE_MODEL + stable_audio_clip_name = str(stable_audio_clip_name or "").strip() or DEFAULT_STABLE_AUDIO_CLIP + svd_checkpoint_name = str(svd_checkpoint_name or "").strip() or DEFAULT_SVD_CHECKPOINT + rife_model_name = str(rife_model_name or "").strip() or 
DEFAULT_RIFE_VFI_MODEL + try: + source_fps_value = float(source_fps) + except (TypeError, ValueError): + source_fps_value = 30.0 + if source_fps_value <= 0: + source_fps_value = 30.0 + target_fps = round(source_fps_value * 2.0, 6) + seed = random.randint(1, 2**31 - 1) + + if pipeline_id == "img2img-basic": + if not is_supported_img2img_path(source_path): + raise ValueError( + "The selected file is not a supported raster image for this pipeline. " + "Use PNG/JPG/WebP/BMP/TIFF or switch to Text to Image." + ) + return { + "1": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, + "2": {"inputs": {"text": prompt_text, "clip": ["1", 1]}, "class_type": "CLIPTextEncode"}, + "3": {"inputs": {"text": "low quality, blurry", "clip": ["1", 1]}, "class_type": "CLIPTextEncode"}, + "4": {"inputs": {"image": str(source_path or ""), "upload": "image"}, "class_type": "LoadImage"}, + "5": {"inputs": {"pixels": ["4", 0], "vae": ["1", 2]}, "class_type": "VAEEncode"}, + "6": { + "inputs": { + "seed": seed, "steps": 20, "cfg": 7.0, "sampler_name": "euler", "scheduler": "normal", + "denoise": 0.65, "model": ["1", 0], "positive": ["2", 0], "negative": ["3", 0], "latent_image": ["5", 0], + }, + "class_type": "KSampler", + }, + "7": {"inputs": {"samples": ["6", 0], "vae": ["1", 2]}, "class_type": "VAEDecode"}, + "8": {"inputs": {"filename_prefix": output_prefix, "images": ["7", 0]}, "class_type": "SaveImage"}, + } + + if pipeline_id == "upscale-realesrgan-x4": + if not is_supported_img2img_path(source_path): + raise ValueError( + "The selected file is not a supported raster image for this pipeline. " + "Use PNG/JPG/WebP/BMP/TIFF or switch to Text to Image." 
+ ) + return { + "1": {"inputs": {"image": str(source_path or ""), "upload": "image"}, "class_type": "LoadImage"}, + "2": {"inputs": {"model_name": upscale_model_name}, "class_type": "UpscaleModelLoader"}, + "3": {"inputs": {"upscale_model": ["2", 0], "image": ["1", 0]}, "class_type": "ImageUpscaleWithModel"}, + "4": {"inputs": {"filename_prefix": output_prefix, "images": ["3", 0]}, "class_type": "SaveImage"}, + } + + if pipeline_id == "video-upscale-gan": + source_path = str(source_path or "").strip() + if not source_path: + raise ValueError("A source video is required for this pipeline.") + return { + "1": {"inputs": {"file": source_path}, "class_type": "LoadVideo"}, + "2": { + "inputs": {"video": ["1", 0], "start_time": 0.0, "duration": 10.0, "strict_duration": False}, + "class_type": "Video Slice", + }, + "3": {"inputs": {"video": ["2", 0]}, "class_type": "GetVideoComponents"}, + "4": {"inputs": {"model_name": upscale_model_name}, "class_type": "UpscaleModelLoader"}, + "5": {"inputs": {"upscale_model": ["4", 0], "image": ["3", 0]}, "class_type": "ImageUpscaleWithModel"}, + "6": {"inputs": {"images": ["5", 0], "audio": ["3", 1], "fps": ["3", 2]}, "class_type": "CreateVideo"}, + "7": {"inputs": {"video": ["6", 0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, + } + + if pipeline_id == "video-whisper-srt": + source_path = str(source_path or "").strip() + if not source_path: + raise ValueError("A source video is required for this pipeline.") + return { + "1": { + "inputs": { + "video": source_path, + "force_rate": 0, + "custom_width": 0, + "custom_height": 0, + "frame_load_cap": 0, + "skip_first_frames": 0, + "select_every_nth": 1, + "format": "AnimateDiff", + }, + "class_type": "VHS_LoadVideo", + }, + "2": { + "inputs": { + "model": "medium", + "language": "auto", + "prompt": "", + "audio": ["1", 2], + }, + "class_type": "Apply Whisper", + }, + "3": { + "inputs": { + "name": 
"{}_segments".format(output_prefix), + "alignment": ["2", 1], + }, + "class_type": "Save SRT", + }, + "4": { + "inputs": { + "preview": "", + "previewMode": None, + "source": ["3", 0], + }, + "class_type": "PreviewAny", + }, + } + + if pipeline_id == "video-frame-interpolation-rife2x": + source_path = str(source_path or "").strip() + if not source_path: + raise ValueError("A source video is required for this pipeline.") + return { + "1": {"inputs": {"file": source_path}, "class_type": "LoadVideo"}, + "2": {"inputs": {"video": ["1", 0]}, "class_type": "GetVideoComponents"}, + "3": { + "inputs": { + "frames": ["2", 0], + "ckpt_name": rife_model_name, + "clear_cache_after_n_frames": 10, + "multiplier": 2, + "fast_mode": True, + "ensemble": True, + "scale_factor": 1, + }, + "class_type": "RIFE VFI", + "_meta": {"title": "RIFE VFI (recommend rife47 and rife49)"}, + }, + "4": {"inputs": {"images": ["3", 0], "audio": ["2", 1], "fps": target_fps}, "class_type": "CreateVideo"}, + "5": { + "inputs": { + "video": ["4", 0], + "filename_prefix": "video/{}".format(output_prefix), + "format": "auto", + "codec": "auto", + }, + "class_type": "SaveVideo", + }, + } + + if pipeline_id == "video-segment-scenes-transnet": + source_path = str(source_path or "").strip() + if not source_path: + raise ValueError("A source video is required for this pipeline.") + return { + "1": { + "inputs": { + "source_video_path": source_path, + "threshold": 0.5, + "min_scene_length_frames": 30, + "device": "auto", + }, + "class_type": "OpenShotTransNetSceneDetect", + "_meta": {"title": "OpenShot TransNet Scene Detect"}, + }, + "9": { + "inputs": { + "preview": "", + "previewMode": None, + "source": ["1", 0], + }, + "class_type": "PreviewAny", + "_meta": {"title": "Preview Any"}, + }, + } + + if pipeline_id == "txt2audio-stable-open": + return { + "3": { + "inputs": { + "seed": seed, + "steps": 50, + "cfg": 5.0, + "sampler_name": "dpmpp_3m_sde_gpu", + "scheduler": "exponential", + "denoise": 1.0, + 
"model": ["4", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["11", 0], + }, + "class_type": "KSampler", + }, + "4": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, + "6": {"inputs": {"text": prompt_text, "clip": ["10", 0]}, "class_type": "CLIPTextEncode"}, + "7": {"inputs": {"text": "", "clip": ["10", 0]}, "class_type": "CLIPTextEncode"}, + "10": {"inputs": {"clip_name": stable_audio_clip_name, "type": "stable_audio"}, "class_type": "CLIPLoader"}, + "11": {"inputs": {"seconds": 30.0, "batch_size": 1}, "class_type": "EmptyLatentAudio"}, + "12": {"inputs": {"samples": ["3", 0], "vae": ["4", 2]}, "class_type": "VAEDecodeAudio"}, + "13": {"inputs": {"filename_prefix": "audio/{}".format(output_prefix), "audio": ["12", 0]}, "class_type": "SaveAudio"}, + } + + if pipeline_id == "txt2video-svd": + return { + "1": {"inputs": {"ckpt_name": svd_checkpoint_name}, "class_type": "ImageOnlyCheckpointLoader"}, + "2": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, + "3": {"inputs": {"text": prompt_text, "clip": ["2", 1]}, "class_type": "CLIPTextEncode"}, + "4": {"inputs": {"text": "low quality, blurry", "clip": ["2", 1]}, "class_type": "CLIPTextEncode"}, + "5": {"inputs": {"width": 512, "height": 288, "batch_size": 1}, "class_type": "EmptyLatentImage"}, + "6": { + "inputs": { + "seed": seed, + "steps": 8, + "cfg": 6.0, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1.0, + "model": ["2", 0], + "positive": ["3", 0], + "negative": ["4", 0], + "latent_image": ["5", 0], + }, + "class_type": "KSampler", + }, + "7": {"inputs": {"samples": ["6", 0], "vae": ["2", 2]}, "class_type": "VAEDecode"}, + "8": { + "inputs": { + "clip_vision": ["1", 1], + "init_image": ["7", 0], + "vae": ["1", 2], + "width": 512, + "height": 288, + "video_frames": 24, + "motion_bucket_id": 127, + "fps": 12, + "augmentation_level": 0.0, + }, + "class_type": "SVD_img2vid_Conditioning", + }, + "9": 
{"inputs": {"model": ["1", 0], "min_cfg": 1.0}, "class_type": "VideoLinearCFGGuidance"}, + "10": { + "inputs": { + "seed": seed + 1, + "steps": 10, + "cfg": 2.5, + "sampler_name": "euler", + "scheduler": "karras", + "denoise": 1.0, + "model": ["9", 0], + "positive": ["8", 0], + "negative": ["8", 1], + "latent_image": ["8", 2], + }, + "class_type": "KSampler", + }, + "11": {"inputs": {"samples": ["10", 0], "vae": ["1", 2]}, "class_type": "VAEDecode"}, + "12": {"inputs": {"images": ["11", 0], "fps": 12}, "class_type": "CreateVideo"}, + "13": {"inputs": {"video": ["12", 0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, + } + + if pipeline_id == "img2video-svd": + if not is_supported_img2img_path(source_path): + raise ValueError( + "The selected file is not a supported raster image for this pipeline. " + "Use PNG/JPG/WebP/BMP/TIFF or switch to Text to Video." + ) + return { + "1": {"inputs": {"ckpt_name": svd_checkpoint_name}, "class_type": "ImageOnlyCheckpointLoader"}, + "2": {"inputs": {"image": str(source_path or ""), "upload": "image"}, "class_type": "LoadImage"}, + "3": { + "inputs": { + "clip_vision": ["1", 1], + "init_image": ["2", 0], + "vae": ["1", 2], + "width": 1024, + "height": 576, + "video_frames": 25, + "motion_bucket_id": 127, + "fps": 6, + "augmentation_level": 0.0, + }, + "class_type": "SVD_img2vid_Conditioning", + }, + "4": {"inputs": {"model": ["1", 0], "min_cfg": 1.0}, "class_type": "VideoLinearCFGGuidance"}, + "5": { + "inputs": { + "seed": seed + 1, + "steps": 20, + "cfg": 2.5, + "sampler_name": "euler", + "scheduler": "karras", + "denoise": 1.0, + "model": ["4", 0], + "positive": ["3", 0], + "negative": ["3", 1], + "latent_image": ["3", 2], + }, + "class_type": "KSampler", + }, + "6": {"inputs": {"samples": ["5", 0], "vae": ["1", 2]}, "class_type": "VAEDecode"}, + "7": {"inputs": {"images": ["6", 0], "fps": 6}, "class_type": "CreateVideo"}, + "8": {"inputs": {"video": ["7", 
0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, + } + + if pipeline_id == "video2video-basic": + source_path = str(source_path or "").strip() + if not source_path: + raise ValueError("A source video is required for this pipeline.") + return { + "1": {"inputs": {"file": source_path}, "class_type": "LoadVideo"}, + "2": { + "inputs": {"video": ["1", 0], "start_time": 0.0, "duration": 10.0, "strict_duration": False}, + "class_type": "Video Slice", + }, + "3": {"inputs": {"video": ["2", 0]}, "class_type": "GetVideoComponents"}, + "4": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, + "5": {"inputs": {"text": prompt_text, "clip": ["4", 1]}, "class_type": "CLIPTextEncode"}, + "6": {"inputs": {"text": "low quality, blurry", "clip": ["4", 1]}, "class_type": "CLIPTextEncode"}, + "7": {"inputs": {"pixels": ["3", 0], "vae": ["4", 2]}, "class_type": "VAEEncode"}, + "8": { + "inputs": { + "seed": seed, "steps": 16, "cfg": 6.0, "sampler_name": "euler", "scheduler": "normal", + "denoise": 0.55, "model": ["4", 0], "positive": ["5", 0], "negative": ["6", 0], "latent_image": ["7", 0], + }, + "class_type": "KSampler", + }, + "9": {"inputs": {"samples": ["8", 0], "vae": ["4", 2]}, "class_type": "VAEDecode"}, + "10": {"inputs": {"images": ["9", 0], "audio": ["3", 1], "fps": ["3", 2]}, "class_type": "CreateVideo"}, + "11": {"inputs": {"video": ["10", 0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, + } + + return { + "1": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, + "2": {"inputs": {"text": prompt_text, "clip": ["1", 1]}, "class_type": "CLIPTextEncode"}, + "3": {"inputs": {"text": "low quality, blurry", "clip": ["1", 1]}, "class_type": "CLIPTextEncode"}, + "4": {"inputs": {"width": 1024, "height": 576, "batch_size": 1}, "class_type": "EmptyLatentImage"}, + "5": { + 
"inputs": { + "seed": seed, "steps": 20, "cfg": 7.0, "sampler_name": "euler", "scheduler": "normal", + "denoise": 1.0, "model": ["1", 0], "positive": ["2", 0], "negative": ["3", 0], "latent_image": ["4", 0], + }, + "class_type": "KSampler", + }, + "6": {"inputs": {"samples": ["5", 0], "vae": ["1", 2]}, "class_type": "VAEDecode"}, + "7": {"inputs": {"filename_prefix": output_prefix, "images": ["6", 0]}, "class_type": "SaveImage"}, + } diff --git a/src/classes/comfy_templates.py b/src/classes/comfy_templates.py new file mode 100644 index 000000000..adcc0d345 --- /dev/null +++ b/src/classes/comfy_templates.py @@ -0,0 +1,405 @@ +""" + @file + @brief ComfyUI workflow template discovery and classification helpers. +""" + +import copy +import json +import os + +from classes import info +from classes.logger import log + + +IMAGE_INPUT_TYPES = { + "loadimage", + "load image", +} +VIDEO_INPUT_TYPES = { + "loadvideo", + "load video", + "vhs_loadvideo", +} +AUDIO_INPUT_TYPES = { + "loadaudio", + "load audio", +} + +IMAGE_OUTPUT_TYPES = { + "saveimage", + "save image", +} +VIDEO_OUTPUT_TYPES = { + "savevideo", + "save video", +} +AUDIO_OUTPUT_TYPES = { + "saveaudio", + "save audio", +} + +KNOWN_NODE_TYPES = { + # Input + "checkpointloadersimple", + "unetloader", + "cliptextencode", + "cliploader", + "vaeloader", + "loadimage", + "loadvideo", + "vhs_loadvideo", + "loadaudio", + # Core built-in/OpenShot workflows + "vaeencode", + "vaedecode", + "ksampler", + "upscalemodelloader", + "imageupscalewithmodel", + "videoslice", + "video slice", + "getvideocomponents", + "createvideo", + "saveimage", + "savevideo", + "saveaudio", + "save srt", + "emptylatentimage", + "emptyhunyuanlatentvideo", + "wan22imagetovideolatent", + "imageonlycheckpointloader", + "modelsamplingsd3", + "svd_img2vid_conditioning", + "videolinearcfgguidance", + "emptylatentaudio", + "vaedecodeaudio", + "previewany", + "apply whisper", + "riff vfi", + "rife vfi", + "downloadandloadtransnetmodel", + "transnetv2_run", 
+ "selectvideo", + "stableaudioprojectionmodel", + "stableaudiomodelloader", + "stableaudioemptylatentaudio", + "stableaudioembedding", + "kdiffusionsampler", + "stableaudiovaedecode", + "videocombine", + "imagescaleby", + "imagetosimage", + "imagetoimage", + "imagescaleto", + "imageblur", + "imagecompositemasked", + # Video Helper Suite + "vhs_batchmanager", + "vhs_loadvideo", + "vhs_loadvideopath", + "vhs_loadvideoffmpegpath", + "vhs_videocombine", + "vhs_videoinfo", + "vhs_videoinfoloaded", + "vhs_videoinfosource", + # ComfyUI-segment-anything-2 + "downloadandloadsam2model", + "sam2segmentation", + "sam2autosegmentation", + "sam2videosegmentationaddpoints", + "sam2videosegmentation", + # OpenShot-ComfyUI (custom SAM2) + "openshotdownloadandloadsam2model", + "openshotsam2segmentation", + "openshotsam2videosegmentationaddpoints", + "openshotsam2videosegmentationchunked", + "openshotimageblurmasked", + "openshotimagehighlightmasked", +} + + +class ComfyTemplateRegistry: + """Discovers ComfyUI templates from built-in + user folders.""" + + def __init__(self): + self._cache = None + self._cache_signature = None + + @staticmethod + def _is_ignored_filename(name): + return str(name or "").strip().lower() in ("debug.json", "debug_error.json", "debug_sent.json") + + def _template_roots(self): + return [ + (os.path.join(info.PATH, "comfyui"), False), + (info.COMFYUI_PATH, True), + ] + + def _current_signature(self): + signature = [] + for folder, _is_user in self._template_roots(): + if not os.path.isdir(folder): + continue + for name in sorted(os.listdir(folder)): + if not name.lower().endswith(".json"): + continue + if self._is_ignored_filename(name): + continue + path = os.path.join(folder, name) + try: + stat = os.stat(path) + signature.append((path, stat.st_mtime_ns, stat.st_size)) + except OSError: + continue + return tuple(signature) + + def discover(self, force=False): + signature = self._current_signature() + if not force and self._cache is not None and signature 
== self._cache_signature: + return self._cache + + templates = [] + existing_ids = set() + for folder, is_user in self._template_roots(): + if not os.path.isdir(folder): + continue + for name in sorted(os.listdir(folder)): + if not name.lower().endswith(".json"): + continue + if self._is_ignored_filename(name): + continue + path = os.path.join(folder, name) + template = self._load_template(path, is_user=is_user, existing_ids=existing_ids) + if template is None: + continue + templates.append(template) + + templates.sort(key=lambda t: (int(t.get("sort_order", 99999)), str(t.get("display_name", "")).lower())) + self._cache = templates + self._cache_signature = signature + return templates + + def _load_template(self, path, is_user, existing_ids): + try: + with open(path, "r", encoding="utf-8") as handle: + payload = json.load(handle) + except Exception as ex: + log.warning("Skipping invalid ComfyUI template JSON %s: %s", path, ex) + return None + if not isinstance(payload, dict): + log.warning("Skipping invalid ComfyUI template JSON %s: root must be an object", path) + return None + + workflow = self._extract_workflow(payload) + if workflow is None: + log.warning("Skipping invalid ComfyUI template JSON %s: no valid workflow graph found", path) + return None + + node_types = [] + input_types = set() + output_types = set() + unknown_node_types = set() + needs_prompt = False + for node in workflow.values(): + if not isinstance(node, dict): + continue + class_type = str(node.get("class_type", "")).strip() + if not class_type: + continue + inputs = node.get("inputs", {}) + if not isinstance(inputs, dict): + inputs = {} + class_key = class_type.lower().replace("_", "") + class_flat = class_type.lower().strip() + node_types.append(class_type) + + text_value = inputs.get("text", None) + if isinstance(text_value, str): + meta = node.get("_meta", {}) + meta_title = "" + if isinstance(meta, dict): + meta_title = str(meta.get("title", "")).strip().lower() + if "textencode" in 
class_key or "prompt" in meta_title: + needs_prompt = True + + if class_flat in IMAGE_INPUT_TYPES or class_key in IMAGE_INPUT_TYPES: + input_types.add("image") + if class_flat in VIDEO_INPUT_TYPES or class_key in VIDEO_INPUT_TYPES: + input_types.add("video") + if class_flat in AUDIO_INPUT_TYPES or class_key in AUDIO_INPUT_TYPES: + input_types.add("audio") + + if class_flat in IMAGE_OUTPUT_TYPES or class_key in IMAGE_OUTPUT_TYPES: + output_types.add("image") + if class_flat in VIDEO_OUTPUT_TYPES or class_key in VIDEO_OUTPUT_TYPES: + output_types.add("video") + if class_flat in AUDIO_OUTPUT_TYPES or class_key in AUDIO_OUTPUT_TYPES: + output_types.add("audio") + + if class_flat not in KNOWN_NODE_TYPES and class_key not in KNOWN_NODE_TYPES: + unknown_node_types.add(class_type) + + if unknown_node_types: + log.warning( + "ComfyUI template has unknown node types (%s): %s", + os.path.basename(path), + ", ".join(sorted(unknown_node_types)), + ) + + override_category = str(payload.get("menu_category") or payload.get("category") or "").strip().lower() + override_menu_parent = str(payload.get("menu_parent") or "").strip() + override_output_type = str(payload.get("output_type") or payload.get("media_output") or "").strip().lower() + override_icon = str(payload.get("action_icon") or payload.get("icon") or "").strip() + override_open_dialog = payload.get("open_dialog", None) + + inferred_category = "unknown" + requires_source = bool(input_types) + if output_types: + inferred_category = "enhance" if requires_source else "create" + else: + if override_category not in ("create", "enhance", "unknown"): + log.warning( + "ComfyUI template category unknown (%s): no output nodes detected", + os.path.basename(path), + ) + if override_category in ("create", "enhance", "unknown"): + inferred_category = override_category + + template_id = str(payload.get("template_id") or payload.get("id") or "").strip() + if not template_id: + template_id = os.path.splitext(os.path.basename(path))[0] + + 
unique_id = template_id + suffix = 2 + while unique_id in existing_ids: + unique_id = "{}__{}".format(template_id, suffix) + suffix += 1 + existing_ids.add(unique_id) + + display_name = self._extract_name(payload, path) + if is_user: + display_name = "(User) {}".format(display_name) + + try: + sort_order = int(payload.get("menu_order", 99999)) + except (TypeError, ValueError): + sort_order = 99999 + primary_output = self._primary_output_type(output_types) + if override_output_type in ("image", "video", "audio", "unknown"): + primary_output = override_output_type + + open_dialog = None + if isinstance(override_open_dialog, bool): + open_dialog = override_open_dialog + + return { + "id": unique_id, + "template_id": template_id, + "display_name": display_name, + "path": path, + "is_user": is_user, + "category": inferred_category, + "input_types": sorted(input_types), + "output_types": sorted(output_types), + "primary_output": primary_output, + "sort_order": sort_order, + "workflow": workflow, + "node_types": node_types, + "needs_prompt": needs_prompt, + "action_icon": override_icon, + "open_dialog": open_dialog, + "menu_parent": override_menu_parent, + } + + def _primary_output_type(self, output_types): + if "video" in output_types: + return "video" + if "image" in output_types: + return "image" + if "audio" in output_types: + return "audio" + return "unknown" + + def _extract_name(self, payload, path): + fields = [ + payload.get("name"), + payload.get("title"), + payload.get("workflow_name"), + ] + metadata = payload.get("metadata") + if isinstance(metadata, dict): + fields.extend([metadata.get("name"), metadata.get("title")]) + + for value in fields: + text = str(value or "").strip() + if text: + return text + + return os.path.splitext(os.path.basename(path))[0] + + def _extract_workflow(self, payload): + if self._looks_like_workflow(payload): + return payload + if isinstance(payload, dict): + workflow = payload.get("workflow") + if 
self._looks_like_workflow(workflow): + return workflow + return None + + def _looks_like_workflow(self, value): + if not isinstance(value, dict) or not value: + return False + for node in value.values(): + if isinstance(node, dict) and str(node.get("class_type", "")).strip(): + return True + return False + + def templates_for_context(self, source_file=None): + templates = self.discover() + media_type = "" + if source_file: + media_type = str(source_file.data.get("media_type", "")).strip().lower() + + filtered = [] + for template in templates: + category = str(template.get("category", "unknown")) + input_types = set(template.get("input_types", [])) + if source_file: + if category not in ("enhance", "unknown"): + continue + if input_types and media_type not in input_types: + continue + else: + if category not in ("create", "unknown"): + continue + if category == "unknown" and input_types: + continue + filtered.append(template) + return filtered + + def get_template(self, template_id): + template_id = str(template_id or "").strip() + if not template_id: + return None + for template in self.discover(): + if str(template.get("id")) == template_id: + return template + return None + + def get_workflow_copy(self, template_id): + template = self.get_template(template_id) + if not template: + return None + return copy.deepcopy(template.get("workflow") or {}) + + def output_icon_name(self, template): + explicit_icon = str((template or {}).get("action_icon") or "").strip() + if explicit_icon: + return explicit_icon + kind = str((template or {}).get("primary_output") or "unknown") + if kind == "video": + return "ai-action-create-video.svg" + if kind == "audio": + return "ai-action-create-audio.svg" + if kind == "image": + return "ai-action-create-image.svg" + return "tool-generate-sparkle.svg" diff --git a/src/classes/generation_queue.py b/src/classes/generation_queue.py new file mode 100644 index 000000000..04a985988 --- /dev/null +++ b/src/classes/generation_queue.py @@ -0,0 
+1,879 @@
"""
 @file
 @brief This file contains a lightweight in-memory generation queue for ComfyUI jobs.
 @author Jonathan Thomas

 @section LICENSE

 Copyright (c) 2008-2026 OpenShot Studios, LLC
 (http://www.openshotstudios.com). This file is part of
 OpenShot Video Editor (http://www.openshot.org), an open-source project
 dedicated to delivering high quality video editing and animation solutions
 to the world.

 OpenShot Video Editor is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation, either version 3 of the License, or
 (at your option) any later version.

 OpenShot Video Editor is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 GNU General Public License for more details.

 You should have received a copy of the GNU General Public License
 along with OpenShot Library. If not, see <http://www.gnu.org/licenses/>.
 """

import uuid
from collections import deque
from threading import Event
from time import monotonic

from PyQt5.QtCore import QObject, QThread, pyqtSignal, pyqtSlot

from classes.comfy_client import ComfyClient
from classes.logger import log


class _GenerationWorker(QObject):
    """Background worker that simulates generation progress for queued jobs."""

    # Signals: (job_id, percent), (job_id, detail text), (job_id, sub-percent),
    # and (job_id, success, canceled, error_text, outputs).
    progress_changed = pyqtSignal(str, int)
    progress_detail_changed = pyqtSignal(str, str)
    progress_sub_changed = pyqtSignal(str, int)
    job_finished = pyqtSignal(str, bool, bool, str, object)

    def __init__(self):
        super().__init__()
        # job ids whose cancellation was requested via cancel_job().
        self._cancel_requested = set()
        # job_id -> Comfy prompt_id for in-flight jobs.
        self._job_prompts = {}

    def _is_cancel_requested(self, job_id, cancel_event):
        """Return True when cancel was requested via slot or via the job's Event."""
        return (job_id in self._cancel_requested) or (cancel_event is not None and cancel_event.is_set())

    @staticmethod
    def _is_unfinished_meta_batch(history_entry):
        """Return True when any output node reports a truthy "unfinished_batch" flag."""
        outputs = history_entry.get("outputs", {}) if isinstance(history_entry, dict) else {}
        if not isinstance(outputs, dict):
            return False
        for node_out in outputs.values():
            if not isinstance(node_out, dict):
                continue
            unfinished = node_out.get("unfinished_batch", None)
            if isinstance(unfinished, list):
                if any(bool(v) for v in unfinished):
                    return True
            elif unfinished:
                return True
        return False

    @staticmethod
    def _history_prompt_meta(history_entry):
        """Return (client_id, create_time) from a history entry's prompt payload.

        The Comfy history "prompt" field is a list whose 4th element carries the
        submitting client's metadata; returns ("", 0) when absent/malformed.
        """
        prompt_payload = history_entry.get("prompt", []) if isinstance(history_entry, dict) else []
        if not isinstance(prompt_payload, list):
            return "", 0
        client_payload = prompt_payload[3] if len(prompt_payload) >= 4 else {}
        if not isinstance(client_payload, dict):
            return "", 0
        client_id = str(client_payload.get("client_id", "")).strip()
        create_time = int(client_payload.get("create_time", 0) or 0)
        return client_id, create_time

    @staticmethod
    def _allow_unfiltered_output_fallback(template_id):
        """Return True when it is safe to collect outputs from ALL save nodes."""
        template_id = str(template_id or "").strip().lower()
        # Track-object templates intentionally have multiple save nodes
        # (mask/debug + final), so we must not relax save-node filtering.
        if template_id in (
            "video-blur-anything-sam2",
            "video-mask-anything-sam2",
            "video-highlight-anything-sam2",
            "txt2music-ace-step",
        ):
            return False
        return True

    def _find_related_meta_batch_outputs(self, client, history_entry, save_node_ids, template_id=""):
        """Search all history for completed follow-up prompts of the same client_id.

        Meta-batch workflows finish in a later prompt; returns the outputs of
        the newest finished related entry, or [] when none qualify.
        """
        base_client_id, base_create_time = self._history_prompt_meta(history_entry)
        if not base_client_id:
            return []
        try:
            history_all = client.history_all() or {}
        except Exception:
            # Best effort: a failed history fetch simply yields no outputs.
            return []
        if not isinstance(history_all, dict):
            return []

        best_create_time = 0
        best_outputs = []
        for entry in history_all.values():
            if not isinstance(entry, dict):
                continue
            status_obj = entry.get("status", {}) if isinstance(entry, dict) else {}
            status_str = str(status_obj.get("status_str", "")).lower()
            if status_str not in ("success", "completed", ""):
                continue
            entry_client_id, entry_create_time = self._history_prompt_meta(entry)
            if entry_client_id != base_client_id:
                continue
            # Only consider entries created at-or-after the base prompt.
            if entry_create_time and base_create_time and entry_create_time < base_create_time:
                continue
            if self._is_unfinished_meta_batch(entry):
                continue

            outputs = ComfyClient.extract_file_outputs(entry, save_node_ids=save_node_ids)
            if (not outputs) and save_node_ids and self._allow_unfiltered_output_fallback(template_id):
                outputs = ComfyClient.extract_file_outputs(entry, save_node_ids=None)
            if not outputs:
                continue

            # Keep the newest qualifying entry's outputs.
            if entry_create_time >= best_create_time:
                best_create_time = entry_create_time
                best_outputs = outputs

        return best_outputs

    @pyqtSlot(str, object)
    def run_job(self, job_id, request):
        """Entry slot: run a real Comfy job when a workflow+URL are supplied,
        otherwise simulate ~5s of progress (20 x 250ms steps)."""
        request = request or {}
        cancel_event = request.get("cancel_event")
        if request.get("workflow") and request.get("comfy_url"):
            self._run_comfy_job(job_id, request)
            return

        canceled = False
        for step in range(1, 21):
            QThread.msleep(250)
            if self._is_cancel_requested(job_id, cancel_event):
                canceled = True
                break
            self.progress_changed.emit(job_id, min(step * 5, 99))

        self._cancel_requested.discard(job_id)

        if canceled:
            self.job_finished.emit(job_id, False, True, "", [])
        else:
            self.progress_changed.emit(job_id, 100)
            self.job_finished.emit(job_id, True, False, "", [])

    @pyqtSlot(str)
    def cancel_job(self, job_id):
        """Queued-connection slot: mark *job_id* for cancellation (picked up by run loop)."""
        log.debug("GenerationWorker cancel_job received job=%s", str(job_id))
        self._cancel_requested.add(job_id)

    def _run_comfy_job(self, job_id, request):
        """Submit *request*'s workflow to ComfyUI and poll until done/failed/canceled."""
        comfy_url = request.get("comfy_url")
        workflow = request.get("workflow")
        client_id = request.get("client_id") or "openshot-qt"
        timeout_s = int(request.get("timeout_s") or 86400)  # default 24 hours safety cap
        save_node_ids = list(request.get("save_node_ids") or [])
        template_id = str(request.get("template_id") or "")
        cancel_event = request.get("cancel_event")
        client = ComfyClient(comfy_url)
        ws_client = None

        try:
            prompt_id = client.queue_prompt(workflow, client_id)
            if not prompt_id:
                self.job_finished.emit(job_id, False, False, "ComfyUI returned an invalid prompt_id", [])
                return
            self._job_prompts[job_id] = prompt_id
            # Live progress via websocket is best-effort; polling still works without it.
            try:
                ws_client = ComfyClient.open_progress_socket(comfy_url, client_id)
                log.debug("Comfy progress websocket connected for prompt=%s", str(prompt_id))
            except Exception:
                log.debug("Comfy progress websocket unavailable; continuing without live progress", exc_info=True)

            # Bookkeeping for the poll loop: timeouts, throttled logging, and
            # websocket reconnect/backoff state.
            start_time = monotonic()
            last_in_queue_time = start_time
            last_contact_time = start_time
            last_progress_log_time = 0.0
            last_network_error_log_time = 0.0
            progress_endpoint_unavailable = False
            accepted_progress_started = False
            ws_retry_delay_s = 2.0
            ws_next_retry_at = start_time
            ws_last_progress_time = start_time
            ws_stale_reconnect_s = 60.0
            ws_stale_reconnect_max_s = 300.0
            prompt_key = str(prompt_id)
            last_progress_signature = None
            last_progress_detail = ""

            while True:
                if self._is_cancel_requested(job_id, cancel_event):
                    log.debug("Comfy cancel requested for job=%s prompt=%s", job_id, str(prompt_id))
                    cancel_ok = False
                    cancel_errors = []

                    # Retry cancellation a few times and verify prompt no longer appears in Comfy queue.
                    for attempt in range(1, 181):
                        try:
                            cancel_ok = client.cancel_prompt(prompt_id) or cancel_ok
                        except Exception as ex:
                            cancel_errors.append("queue: {}".format(ex))

                        try:
                            cancel_ok = client.interrupt(prompt_id=prompt_id) or cancel_ok
                        except Exception as ex:
                            cancel_errors.append("interrupt: {}".format(ex))

                        try:
                            history = client.history(prompt_id) or {}
                            prompt_key = str(prompt_id)
                            history_entry = history.get(prompt_key) or history.get(prompt_id) or None
                            if isinstance(history_entry, dict):
                                status_obj = history_entry.get("status", {}) if isinstance(history_entry, dict) else {}
                                status_str = str(status_obj.get("status_str", "")).lower()
                                # Comfy commonly marks interrupted runs as failed/error in history.
                                if status_str in ("error", "failed"):
                                    cancel_ok = True
                                    log.debug(
                                        "Comfy cancel confirmed by history status for job=%s prompt=%s status=%s",
                                        job_id,
                                        prompt_key,
                                        status_str,
                                    )
                                    break
                        except Exception as ex:
                            cancel_errors.append("history-check: {}".format(ex))

                        try:
                            queue_data = client.queue() or {}
                            if not ComfyClient.prompt_in_queue(prompt_id, queue_data):
                                cancel_ok = True
                                log.debug(
                                    "Comfy cancel confirmed by queue absence for job=%s prompt=%s on attempt=%s",
                                    job_id,
                                    str(prompt_id),
                                    attempt,
                                )
                                break
                        except Exception as ex:
                            cancel_errors.append("queue-check: {}".format(ex))

                        if attempt % 10 == 0:
                            log.debug(
                                "Comfy cancel still pending job=%s prompt=%s attempt=%s",
                                job_id,
                                str(prompt_id),
                                attempt,
                            )
                        QThread.msleep(500)

                    self._cancel_requested.discard(job_id)
                    self._job_prompts.pop(job_id, None)
                    if cancel_ok:
                        self.job_finished.emit(job_id, False, True, "", [])
                    else:
                        self.job_finished.emit(
                            job_id,
                            False,
                            False,
                            "ComfyUI did not accept cancel request ({})".format("; ".join(cancel_errors) or "unknown"),
                            [],
                        )
                    return

                # Poll history; transient network failures are logged (throttled) and retried.
                history_entry = None
                try:
                    history = client.history(prompt_id) or {}
                    history_entry = history.get(prompt_key) or history.get(prompt_id) or None
                    last_contact_time = monotonic()
                except Exception:
                    now_log = monotonic()
                    if (now_log - last_network_error_log_time) > 8.0:
                        log.debug(
                            "Comfy history poll temporarily unavailable for job=%s prompt=%s; retrying",
                            job_id,
                            prompt_key,
                            exc_info=True,
                        )
                        last_network_error_log_time = now_log
                if history_entry is not None:
                    status_obj = history_entry.get("status", {}) if isinstance(history_entry, dict) else {}
                    status_str = str(status_obj.get("status_str", "")).lower()
                    if status_str in ("error", "failed"):
                        error_text = "ComfyUI job failed."
                        messages = status_obj.get("messages", [])
                        if isinstance(messages, list) and messages:
                            error_text = ComfyClient.summarize_error_text(messages[-1])
                        self._job_prompts.pop(job_id, None)
                        self.job_finished.emit(job_id, False, False, error_text, [])
                        return
                    if self._is_unfinished_meta_batch(history_entry):
                        image_outputs = self._find_related_meta_batch_outputs(
                            client,
                            history_entry,
                            save_node_ids,
                            template_id=template_id,
                        )
                        if image_outputs:
                            self.progress_changed.emit(job_id, 100)
                            self._job_prompts.pop(job_id, None)
                            self.job_finished.emit(job_id, True, False, "", image_outputs)
                            return
                        # Meta batch uses follow-up prompts under the same client_id.
                        # Keep polling progress/queue while waiting for follow-up prompt outputs.
                    else:
                        image_outputs = ComfyClient.extract_file_outputs(history_entry, save_node_ids=save_node_ids)
                        if (not image_outputs) and save_node_ids and self._allow_unfiltered_output_fallback(template_id):
                            # Fallback for workflows whose output node ids shift or emit non-standard keys.
+ image_outputs = ComfyClient.extract_file_outputs(history_entry, save_node_ids=None) + self.progress_changed.emit(job_id, 100) + self._job_prompts.pop(job_id, None) + self.job_finished.emit(job_id, True, False, "", image_outputs) + return + + # Query ComfyUI's live progress values when available. + try: + ws_progress_emitted = False + now = monotonic() + if ws_client is None and now >= ws_next_retry_at: + try: + ws_client = ComfyClient.open_progress_socket(comfy_url, client_id) + ws_retry_delay_s = 2.0 + log.debug("Comfy progress websocket reconnected for prompt=%s", prompt_key) + except Exception: + ws_next_retry_at = now + ws_retry_delay_s + ws_retry_delay_s = min(60.0, ws_retry_delay_s * 1.5) + now_log = monotonic() + if (now_log - last_network_error_log_time) > 8.0: + log.debug( + "Comfy websocket reconnect failed for job=%s prompt=%s; retrying in %.1fs", + job_id, + prompt_key, + ws_retry_delay_s, + exc_info=True, + ) + last_network_error_log_time = now_log + + if ws_client is not None: + try: + # Accept progress from follow-up prompts as well (meta-batch). 
+ progress_event = ws_client.poll_progress(prompt_id=None) + except Exception: + progress_event = None + try: + ws_client.close() + except Exception: + pass + ws_client = None + ws_next_retry_at = monotonic() + ws_retry_delay_s + ws_retry_delay_s = min(60.0, ws_retry_delay_s * 1.5) + now_log = monotonic() + if (now_log - last_network_error_log_time) > 8.0: + log.debug( + "Comfy websocket progress read failed for job=%s prompt=%s; switching to retry mode", + job_id, + prompt_key, + exc_info=True, + ) + last_network_error_log_time = now_log + + if progress_event is not None: + elapsed = monotonic() - start_time + progress = int(progress_event.get("percent", 0)) + raw_value = float(progress_event.get("value", 0.0)) + raw_max = float(progress_event.get("max", 0.0)) + progress_type = str(progress_event.get("type", "")) + progress_node = str(progress_event.get("node", "")) + # Some workflows emit near-complete progress bursts at startup + # (e.g. tiny setup nodes), then reset to sampler progress. + # Ignore those bootstrap spikes for a short window. 
+ if ( + (not accepted_progress_started) + and progress >= 95 + and elapsed < 20.0 + and raw_max <= 1.0 + ): + log.debug( + "Comfy WS progress setup-node spike ignored job=%s prompt=%s node=%s type=%s value=%s max=%s percent=%s elapsed=%.2fs", + job_id, + prompt_key, + progress_node, + progress_type, + raw_value, + raw_max, + progress, + elapsed, + ) + elif (not accepted_progress_started) and progress >= 95 and elapsed < 20.0: + log.debug( + "Comfy WS progress bootstrap spike ignored job=%s prompt=%s percent=%s elapsed=%.2fs", + job_id, + prompt_key, + progress, + elapsed, + ) + else: + accepted_progress_started = True + progress_signature = ( + progress_type, + progress_node, + int(progress), + round(raw_value, 3), + round(raw_max, 3), + ) + if progress_signature != last_progress_signature: + inferred_progress = int(max(0, min(99, progress))) + detail_text = "" + if progress_node: + detail_text = "node {} {}%".format(progress_node, int(progress)) + + log.debug( + "Comfy WS progress emit job=%s prompt=%s node=%s type=%s value=%s max=%s percent=%s", + job_id, + prompt_key, + progress_node, + progress_type, + raw_value, + raw_max, + inferred_progress, + ) + self.progress_changed.emit(job_id, inferred_progress) + self.progress_sub_changed.emit(job_id, int(max(0, min(99, progress)))) + if detail_text != last_progress_detail: + self.progress_detail_changed.emit(job_id, detail_text) + last_progress_detail = detail_text + last_progress_signature = progress_signature + ws_progress_emitted = True + ws_last_progress_time = monotonic() + ws_stale_reconnect_s = 60.0 + last_contact_time = monotonic() + if ws_client is not None and not ws_progress_emitted: + stale_for = now - ws_last_progress_time + if stale_for >= ws_stale_reconnect_s: + try: + ws_client.close() + except Exception: + pass + ws_client = None + ws_next_retry_at = now + ws_retry_delay_s + ws_retry_delay_s = min(60.0, ws_retry_delay_s * 1.5) + next_stale_reconnect_s = min( + ws_stale_reconnect_max_s, + max(60.0, 
ws_stale_reconnect_s * 1.5), + ) + log.debug( + "Comfy websocket stalled for job=%s prompt=%s (%.1fs >= %.1fs); forcing reconnect, next stall timeout %.1fs", + job_id, + prompt_key, + stale_for, + ws_stale_reconnect_s, + next_stale_reconnect_s, + ) + ws_stale_reconnect_s = next_stale_reconnect_s + # Use HTTP /progress only when websocket progress is unavailable. + # If websocket is connected but temporarily quiet, keep waiting for WS + # instead of spamming a misleading 404 fallback warning. + if ws_client is None: + progress_data = client.progress() + if progress_data is None: + if not progress_endpoint_unavailable: + log.debug( + "Comfy progress endpoint unavailable (404); waiting for websocket progress for job=%s", + job_id, + ) + progress_endpoint_unavailable = True + progress_data = {} + + progress_block = progress_data.get("progress", progress_data) + if not isinstance(progress_block, dict): + progress_block = {} + + value = float(progress_block.get("value", progress_block.get("current", 0.0))) + maximum = float(progress_block.get("max", progress_block.get("total", 0.0))) + progress_prompt = str( + progress_data.get("prompt_id", progress_block.get("prompt_id", "")) + ) + prompt_matches = (not progress_prompt) or (progress_prompt == prompt_key) + + now_log = monotonic() + if (now_log - last_progress_log_time) > 8.0: + log.debug( + "Comfy progress poll job=%s prompt=%s payload_keys=%s value=%s max=%s progress_prompt=%s prompt_match=%s", + job_id, + prompt_key, + list(progress_data.keys()) if isinstance(progress_data, dict) else type(progress_data), + value, + maximum, + progress_prompt, + prompt_matches, + ) + last_progress_log_time = now_log + + if maximum > 0 and prompt_matches: + progress = int(max(0, min(99, round((value / maximum) * 100.0)))) + progress_signature = ("poll", "", int(progress), round(value, 3), round(maximum, 3)) + if progress_signature != last_progress_signature: + log.debug( + "Comfy progress emit job=%s prompt=%s value=%s max=%s 
percent=%s", + job_id, + prompt_key, + value, + maximum, + progress, + ) + self.progress_changed.emit(job_id, progress) + self.progress_sub_changed.emit(job_id, int(max(0, min(99, progress)))) + if last_progress_detail: + self.progress_detail_changed.emit(job_id, "") + last_progress_detail = "" + last_progress_signature = progress_signature + last_contact_time = monotonic() + except Exception: + # Keep polling history and queue even if /progress is unavailable. + now_log = monotonic() + if (now_log - last_network_error_log_time) > 8.0: + log.debug("Comfy progress poll failed for job=%s", job_id, exc_info=True) + last_network_error_log_time = now_log + + # Check queue to avoid timing out long-running but active jobs. + in_queue = False + try: + queue_data = client.queue() or {} + in_queue = ComfyClient.prompt_in_queue(prompt_id, queue_data) + last_contact_time = monotonic() + except Exception: + # If queue check fails, do not penalize the job immediately. + in_queue = True + now_log = monotonic() + if (now_log - last_network_error_log_time) > 8.0: + log.debug("Comfy queue check temporarily unavailable for job=%s", job_id, exc_info=True) + last_network_error_log_time = now_log + if in_queue: + last_in_queue_time = monotonic() + else: + now_log = monotonic() + if (now_log - last_progress_log_time) > 8.0: + log.debug( + "Comfy queue check: prompt=%s not found in queue_running/queue_pending yet", + prompt_key, + ) + last_progress_log_time = now_log + + now = monotonic() + if (now - start_time) > timeout_s: + self._job_prompts.pop(job_id, None) + self.job_finished.emit(job_id, False, False, "Timed out waiting for ComfyUI history result", []) + return + + if (now - last_contact_time) > 60.0: + now_log = monotonic() + if (now_log - last_network_error_log_time) > 8.0: + log.debug( + "Comfy connection degraded for job=%s prompt=%s (no successful API contact for %.1fs); continuing retries", + job_id, + prompt_key, + now - last_contact_time, + ) + last_network_error_log_time = 
now_log + + # If prompt vanished from queue for an extended period and still no history, treat as failure. + if (now - last_in_queue_time) > 600: + self._job_prompts.pop(job_id, None) + self.job_finished.emit( + job_id, + False, + False, + "ComfyUI prompt is no longer in queue and has no history result.", + [], + ) + return + QThread.msleep(500) + except Exception as ex: + self._job_prompts.pop(job_id, None) + self.job_finished.emit(job_id, False, False, ComfyClient.summarize_error_text(ex), []) + finally: + if ws_client is not None: + ws_client.close() + + +class GenerationQueueManager(QObject): + """Single-worker, in-memory generation queue with per-file active-job limits.""" + + ACTIVE_STATES = {"queued", "running", "canceling"} + + job_added = pyqtSignal(str, object) + job_updated = pyqtSignal(str, str, int) + job_finished = pyqtSignal(str, str) + job_removed = pyqtSignal(str) + file_job_changed = pyqtSignal(str) + queue_changed = pyqtSignal() + + _run_job = pyqtSignal(str, object) + _cancel_job = pyqtSignal(str) + + def __init__(self, parent=None): + super().__init__(parent) + self.jobs = {} + self._queued = deque() + self._running_job_id = None + self._active_file_jobs = {} + + self._thread = QThread(self) + self._thread.setObjectName("generation_queue_worker") + self._worker = _GenerationWorker() + self._worker.moveToThread(self._thread) + self._run_job.connect(self._worker.run_job) + self._cancel_job.connect(self._worker.cancel_job) + self._worker.progress_changed.connect(self._on_progress_changed) + self._worker.progress_detail_changed.connect(self._on_progress_detail_changed) + self._worker.progress_sub_changed.connect(self._on_progress_sub_changed) + self._worker.job_finished.connect(self._on_job_finished) + self._thread.start() + + def enqueue(self, name, template_id, prompt, source_file_id=None, request=None): + source_file_id = str(source_file_id or "") + if source_file_id and self.get_active_job_for_file(source_file_id): + return None + + job_id = 
str(uuid.uuid4()) + cancel_event = Event() + job_request = dict(request or {}) + job_request["cancel_event"] = cancel_event + job = { + "id": job_id, + "name": str(name or "").strip(), + "template_id": str(template_id or "").strip(), + "prompt": str(prompt or "").strip(), + "source_file_id": source_file_id, + "status": "queued", + "progress": 0, + "sub_progress": 0, + "progress_detail": "", + "error": "", + "request": job_request, + "cancel_event": cancel_event, + } + self.jobs[job_id] = job + self._queued.append(job_id) + if source_file_id: + self._active_file_jobs[source_file_id] = job_id + + self.job_added.emit(job_id, source_file_id) + self.job_updated.emit(job_id, "queued", 0) + self._emit_file_changed(source_file_id) + self.queue_changed.emit() + self._start_next_if_idle() + return job_id + + def cancel_job(self, job_id): + job = self.jobs.get(job_id) + if not job: + log.debug("GenerationQueue cancel_job ignored; unknown job=%s", str(job_id)) + return False + + log.debug( + "GenerationQueue cancel_job request job=%s status=%s source_file_id=%s", + str(job_id), + str(job.get("status", "")), + str(job.get("source_file_id", "")), + ) + if job["status"] == "queued": + cancel_event = job.get("cancel_event") + if cancel_event is not None: + cancel_event.set() + log.debug("GenerationQueue cancel_event set for queued job=%s", str(job_id)) + job["status"] = "canceled" + self._queued = deque([queued_id for queued_id in self._queued if queued_id != job_id]) + self._release_file_slot(job.get("source_file_id", "")) + self.job_updated.emit(job_id, "canceled", int(job.get("progress", 0))) + self.job_finished.emit(job_id, "canceled") + self._emit_file_changed(job.get("source_file_id", "")) + self.queue_changed.emit() + log.debug("GenerationQueue cancel_job completed for queued job=%s", str(job_id)) + return True + + if job["status"] == "running": + cancel_event = job.get("cancel_event") + if cancel_event is not None: + cancel_event.set() + log.debug("GenerationQueue 
cancel_event set for running job=%s", str(job_id)) + job["status"] = "canceling" + self.job_updated.emit(job_id, "canceling", int(job.get("progress", 0))) + self._cancel_job.emit(job_id) + self._emit_file_changed(job.get("source_file_id", "")) + self.queue_changed.emit() + log.debug("GenerationQueue cancel_job emitted worker cancel for running job=%s", str(job_id)) + return True + + log.debug("GenerationQueue cancel_job ignored for job=%s with status=%s", str(job_id), str(job.get("status", ""))) + return False + + def cancel_jobs_for_file(self, source_file_id): + source_file_id = str(source_file_id or "") + if not source_file_id: + return + for job in list(self.jobs.values()): + if job.get("source_file_id") == source_file_id and job.get("status") in self.ACTIVE_STATES: + self.cancel_job(job["id"]) + + def remove_job(self, job_id): + job = self.jobs.get(job_id) + if not job: + return False + if job.get("status") in self.ACTIVE_STATES: + return False + + source_file_id = job.get("source_file_id", "") + self.jobs.pop(job_id, None) + self.job_removed.emit(job_id) + self._emit_file_changed(source_file_id) + self.queue_changed.emit() + return True + + def get_job(self, job_id): + return self.jobs.get(job_id) + + def get_active_job_for_file(self, source_file_id): + source_file_id = str(source_file_id or "") + if not source_file_id: + return None + + job_id = self._active_file_jobs.get(source_file_id) + if not job_id: + return None + + job = self.jobs.get(job_id) + if not job or job.get("status") not in self.ACTIVE_STATES: + self._active_file_jobs.pop(source_file_id, None) + return None + return job + + def get_file_badge(self, source_file_id): + job = self.get_active_job_for_file(source_file_id) + if not job: + return None + + status = job.get("status") + progress = int(job.get("progress", 0)) + sub_progress = int(job.get("sub_progress", 0)) + detail = str(job.get("progress_detail", "") or "").strip() + if status == "queued": + label = "Queued" + elif status == "running": 
+ label = "Generating {}%".format(progress) + if detail: + label = "{} ({})".format(label, detail) + elif status == "canceling": + label = "Canceling..." + else: + label = status.capitalize() + + return { + "status": status, + "progress": progress, + "sub_progress": sub_progress, + "label": label, + "job_id": job.get("id"), + } + + def shutdown(self): + if self._thread.isRunning(): + self._thread.quit() + self._thread.wait(2000) + + def _start_next_if_idle(self): + if self._running_job_id is not None: + return + if not self._queued: + return + + next_job_id = self._queued.popleft() + job = self.jobs.get(next_job_id) + if not job: + self._start_next_if_idle() + return + + self._running_job_id = next_job_id + job["status"] = "running" + job["progress"] = int(job.get("progress", 0)) + job["sub_progress"] = int(job.get("sub_progress", 0)) + self.job_updated.emit(next_job_id, "running", int(job["progress"])) + self._emit_file_changed(job.get("source_file_id", "")) + self.queue_changed.emit() + self._run_job.emit(next_job_id, job.get("request", {})) + + def _release_file_slot(self, source_file_id): + source_file_id = str(source_file_id or "") + if source_file_id: + self._active_file_jobs.pop(source_file_id, None) + + def _emit_file_changed(self, source_file_id): + source_file_id = str(source_file_id or "") + if source_file_id: + self.file_job_changed.emit(source_file_id) + + @pyqtSlot(str, int) + def _on_progress_changed(self, job_id, progress): + job = self.jobs.get(job_id) + if not job: + return + if job.get("status") not in ("running", "canceling"): + return + job["progress"] = int(progress) + self.job_updated.emit(job_id, job.get("status"), int(progress)) + self._emit_file_changed(job.get("source_file_id", "")) + self.queue_changed.emit() + + @pyqtSlot(str, str) + def _on_progress_detail_changed(self, job_id, detail): + job = self.jobs.get(job_id) + if not job: + return + if job.get("status") not in ("running", "canceling"): + return + detail_text = str(detail or 
"").strip() + if str(job.get("progress_detail", "") or "") == detail_text: + return + job["progress_detail"] = detail_text + self.job_updated.emit(job_id, job.get("status"), int(job.get("progress", 0))) + self._emit_file_changed(job.get("source_file_id", "")) + self.queue_changed.emit() + + @pyqtSlot(str, int) + def _on_progress_sub_changed(self, job_id, progress): + job = self.jobs.get(job_id) + if not job: + return + if job.get("status") not in ("running", "canceling"): + return + p = int(max(0, min(99, progress))) + if int(job.get("sub_progress", 0)) == p: + return + job["sub_progress"] = p + self.job_updated.emit(job_id, job.get("status"), int(job.get("progress", 0))) + self._emit_file_changed(job.get("source_file_id", "")) + self.queue_changed.emit() + + @pyqtSlot(str, bool, bool, str, object) + def _on_job_finished(self, job_id, success, canceled, error, outputs): + job = self.jobs.get(job_id) + if not job: + return + + if canceled: + job["status"] = "canceled" + elif success: + job["status"] = "completed" + job["progress"] = 100 + job["outputs"] = list(outputs or []) + else: + job["status"] = "failed" + job["error"] = str(error or "") + + source_file_id = job.get("source_file_id", "") + self._release_file_slot(source_file_id) + + self.job_updated.emit(job_id, job["status"], int(job.get("progress", 0))) + self.job_finished.emit(job_id, job["status"]) + self._emit_file_changed(source_file_id) + self.queue_changed.emit() + + if self._running_job_id == job_id: + self._running_job_id = None + self._start_next_if_idle() diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py new file mode 100644 index 000000000..31bf272a5 --- /dev/null +++ b/src/classes/generation_service.py @@ -0,0 +1,1586 @@ +""" + @file + @brief This file contains Comfy generation orchestration logic. + @author Jonathan Thomas + + @section LICENSE + + Copyright (c) 2008-2026 OpenShot Studios, LLC + (http://www.openshotstudios.com). 
This file is part of + OpenShot Video Editor (http://www.openshot.org), an open-source project + dedicated to delivering high quality video editing and animation solutions + to the world. + + OpenShot Video Editor is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + OpenShot Video Editor is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with OpenShot Library. If not, see . + """ + +import os +import re +import tempfile +import json +import random +from time import time +from urllib.parse import unquote +from fractions import Fraction + +import openshot +from PyQt5.QtWidgets import QMessageBox, QDialog + +from classes import info +from classes import time_parts +from classes.app import get_app +from classes.comfy_client import ComfyClient +from classes.comfy_templates import ComfyTemplateRegistry +from classes.comfy_pipelines import ( + build_workflow, + is_supported_img2img_path, + pipeline_requires_checkpoint, + pipeline_requires_svd_checkpoint, + pipeline_requires_stable_audio_clip, + pipeline_requires_rife_model, + pipeline_requires_upscale_model, + DEFAULT_RIFE_VFI_MODEL, + DEFAULT_SD_CHECKPOINT, + DEFAULT_SD_BASE_CHECKPOINT, + DEFAULT_STABLE_AUDIO_CHECKPOINT, + DEFAULT_STABLE_AUDIO_CLIP, + DEFAULT_SVD_CHECKPOINT, + DEFAULT_UPSCALE_MODEL, +) +from classes.logger import log +from classes.query import File +from windows.generate import GenerateMediaDialog + + +class GenerationService: + """Encapsulates generation-specific UI + workflow behavior.""" + + LEGACY_PIPELINE_IDS = { + "txt2audio-stable-open", + "img2img-basic", + 
"upscale-realesrgan-x4", + "video-segment-scenes-transnet", + "video-frame-interpolation-rife2x", + "video-upscale-gan", + "video2video-basic", + "video-whisper-srt", + } + SAM2_DEFAULT_TARGET_BATCH_BYTES = 4 * 1024 * 1024 * 1024 # 4 GiB + SAM2_ESTIMATED_BYTES_PER_PIXEL = 24.0 + SAM2_ESTIMATED_BYTES_PER_PIXEL_HIGHLIGHT = 64.0 + SAM2_ESTIMATED_BYTES_PER_PIXEL_BLUR = 40.0 + SAM2_MIN_FRAMES_PER_BATCH = 4 + SAM2_MAX_FRAMES_PER_BATCH = 192 + + def __init__(self, win): + self.win = win + self._generation_temp_files = [] + self._comfy_status_cache = {"checked_at": 0.0, "available": False, "url": ""} + self._last_logged_comfy_state = None + self.template_registry = ComfyTemplateRegistry() + + def cleanup_temp_files(self): + for tmp_path in list(self._generation_temp_files): + try: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + except OSError: + pass + self._generation_temp_files = [] + + def comfy_ui_url(self): + url = get_app().get_settings().get("comfy-ui-url") or "http://127.0.0.1:8188" + return str(url).strip().rstrip("/") + + def is_comfy_available(self, force=False): + now = time() + if not force and (now - self._comfy_status_cache["checked_at"]) < 2.0: + return self._comfy_status_cache["available"] + + url = self.comfy_ui_url() + available = False + error_text = "" + try: + available = ComfyClient(url).ping(timeout=0.5) + except Exception as ex: + available = False + error_text = str(ex) + + previous_available = bool(self._comfy_status_cache.get("available")) + previous_url = str(self._comfy_status_cache.get("url", "")) + self._comfy_status_cache["checked_at"] = now + self._comfy_status_cache["available"] = available + self._comfy_status_cache["url"] = url + + state = (url, bool(available)) + if force or state != self._last_logged_comfy_state or previous_url != url or previous_available != available: + if available: + log.info("ComfyUI check passed at %s", url) + else: + if error_text: + log.info("ComfyUI check failed at %s (%s)", url, 
error_text) + else: + log.info("ComfyUI check failed at %s", url) + self._last_logged_comfy_state = state + return available + + def can_open_generate_dialog(self): + return len(self.win.selected_file_ids()) <= 1 + + def _prepare_generation_source_path(self, source_file, template_id): + if not source_file: + return "" + + source_path = source_file.data.get("path", "") + media_type = source_file.data.get("media_type") + if template_id not in ("img2img-basic", "upscale-realesrgan-x4", "img2video-svd") or media_type != "image": + return source_path + + if is_supported_img2img_path(source_path): + return source_path + + tmp_fd, tmp_png = tempfile.mkstemp(prefix="openshot-comfy-", suffix=".png") + os.close(tmp_fd) + try: + clip = openshot.Clip(source_path) + frame = clip.Reader().GetFrame(1) + frame.Save(tmp_png, 1.0) + self._generation_temp_files.append(tmp_png) + return tmp_png + except Exception: + try: + os.remove(tmp_png) + except OSError: + pass + raise + + def _prepare_generation_video_input(self, source_file, client): + if not source_file: + raise ValueError("A source video is required.") + source_path = source_file.data.get("path", "") + if not source_path: + raise ValueError("Source video path is invalid.") + return client.upload_input_file(source_path) + + def _prepare_generation_image_input(self, local_image_path, client): + local_image_path = str(local_image_path or "").strip() + if not local_image_path: + raise ValueError("A source image is required.") + return client.upload_input_file(local_image_path) + + def _get_source_fps(self, source_file): + if not source_file: + return None + fps_data = source_file.data.get("fps") + if isinstance(fps_data, dict): + try: + num = float(fps_data.get("num", 0)) + den = float(fps_data.get("den", 0)) + except (TypeError, ValueError): + num = den = 0.0 + if num > 0 and den > 0: + return num / den + return None + + def _default_generation_name(self, source_file): + default_name = "generation" + if source_file: + path = 
source_file.data.get("path", "") + if path: + default_name = "{}_gen".format(os.path.splitext(os.path.basename(path))[0]) + return default_name + + def _get_source_dimensions(self, source_file): + if not source_file: + return (0, 0) + data = source_file.data if hasattr(source_file, "data") and isinstance(source_file.data, dict) else {} + try: + width = int(data.get("width", 0) or 0) + except Exception: + width = 0 + try: + height = int(data.get("height", 0) or 0) + except Exception: + height = 0 + return (max(0, width), max(0, height)) + + def _sam2_target_batch_bytes(self): + settings = get_app().get_settings() + raw_bytes = settings.get("comfy-sam2-target-batch-bytes") + if raw_bytes is not None: + try: + value = int(raw_bytes) + if value > 0: + return value + except Exception: + pass + raw_gb = settings.get("comfy-sam2-target-batch-gb") + if raw_gb is not None: + try: + value = float(raw_gb) + if value > 0.0: + return int(value * 1024 * 1024 * 1024) + except Exception: + pass + return int(self.SAM2_DEFAULT_TARGET_BATCH_BYTES) + + def _estimate_sam2_frames_per_batch(self, width, height, bytes_per_pixel=None): + width = int(max(0, width)) + height = int(max(0, height)) + if width <= 0 or height <= 0: + return self.SAM2_MIN_FRAMES_PER_BATCH + target_bytes = self._sam2_target_batch_bytes() + if bytes_per_pixel is None: + bytes_per_pixel = self.SAM2_ESTIMATED_BYTES_PER_PIXEL + try: + bytes_per_pixel = float(bytes_per_pixel) + except Exception: + bytes_per_pixel = float(self.SAM2_ESTIMATED_BYTES_PER_PIXEL) + bytes_per_frame = max( + 1.0, + float(width) * float(height) * bytes_per_pixel, + ) + frames = int(target_bytes / bytes_per_frame) + frames = max(self.SAM2_MIN_FRAMES_PER_BATCH, min(self.SAM2_MAX_FRAMES_PER_BATCH, frames)) + # Keep chunk sizes aligned for more stable batching behavior. 
+ frames = max(self.SAM2_MIN_FRAMES_PER_BATCH, int((frames // 4) * 4)) + return frames + + def _apply_dynamic_sam2_meta_batch(self, workflow, source_file, template_id=None): + template_id = str(template_id or "").strip().lower() + # Only adjust non-legacy SAM2 video tracking templates/workflows. + if template_id and template_id not in ( + "video-blur-anything-sam2", + "video-highlight-anything-sam2", + "video-mask-anything-sam2", + ): + return + if not isinstance(workflow, dict): + return + + width, height = self._get_source_dimensions(source_file) + if width <= 0 or height <= 0: + return + + has_sam2_chunked = False + for node in workflow.values(): + if not isinstance(node, dict): + continue + class_type = str(node.get("class_type", "")).strip().lower() + if class_type == "openshotsam2videosegmentationchunked": + has_sam2_chunked = True + break + if not has_sam2_chunked: + return + + # Account for downstream per-frame processing memory: + # - Highlight path is the heaviest (multiple full-frame tensor intermediates) + # - Blur path is moderately heavy + # - Mask-only path is closest to baseline SAM2 estimate + estimated_bpp = float(self.SAM2_ESTIMATED_BYTES_PER_PIXEL) + for node in workflow.values(): + if not isinstance(node, dict): + continue + class_type = str(node.get("class_type", "")).strip().lower() + if class_type == "openshotimagehighlightmasked": + estimated_bpp = max(estimated_bpp, float(self.SAM2_ESTIMATED_BYTES_PER_PIXEL_HIGHLIGHT)) + elif class_type == "openshotimageblurmasked": + estimated_bpp = max(estimated_bpp, float(self.SAM2_ESTIMATED_BYTES_PER_PIXEL_BLUR)) + + dynamic_frames = self._estimate_sam2_frames_per_batch(width, height, bytes_per_pixel=estimated_bpp) + updated_chunk_nodes = 0 + updated_batch_nodes = 0 + for node in workflow.values(): + if not isinstance(node, dict): + continue + class_type = str(node.get("class_type", "")).strip().lower() + inputs = node.get("inputs", {}) + if not isinstance(inputs, dict): + continue + if class_type == 
"openshotsam2videosegmentationchunked" and "chunk_size_frames" in inputs: + inputs["chunk_size_frames"] = int(dynamic_frames) + updated_chunk_nodes += 1 + if class_type == "vhs_batchmanager" and "frames_per_batch" in inputs: + inputs["frames_per_batch"] = int(dynamic_frames) + updated_batch_nodes += 1 + if updated_chunk_nodes or updated_batch_nodes: + log.info( + "Dynamic SAM2 batch size: %s frames (source=%sx%s, target_bytes=%s, est_bpp=%s, template=%s, chunk_nodes=%s, batch_nodes=%s)", + dynamic_frames, + width, + height, + self._sam2_target_batch_bytes(), + round(estimated_bpp, 2), + template_id or "unknown", + updated_chunk_nodes, + updated_batch_nodes, + ) + + def templates_for_context(self, source_file=None): + templates = self.template_registry.templates_for_context(source_file=source_file) + return [ + {"id": t.get("id"), "name": t.get("display_name"), "template": t} + for t in templates + ] + + def build_menu_templates(self, source_file=None): + grouped = {"create": [], "enhance": [], "unknown": []} + for template in self.template_registry.templates_for_context(source_file=source_file): + category = str(template.get("category", "unknown")) + if category not in grouped: + category = "unknown" + grouped[category].append(template) + return grouped + + def icon_for_template(self, template): + return self.template_registry.output_icon_name(template) + + def _prepare_nonlegacy_workflow( + self, + template, + payload_name, + prompt_text, + source_file, + source_path, + coordinates_positive_text="", + coordinates_negative_text="", + rectangles_positive_text="", + rectangles_negative_text="", + auto_mode=False, + tracking_selection=None, + highlight_color="", + highlight_opacity=0.0, + border_color="", + border_width=0, + mask_brightness=1.0, + background_brightness=1.0, + ): + workflow = self.template_registry.get_workflow_copy(template.get("id")) + if not workflow: + raise ValueError("Template workflow not found.") + + template_dir = "" + template_path = 
str((template or {}).get("path") or "").strip() + if template_path: + template_dir = os.path.dirname(template_path) + + def _resolve_template_local_file(path_text): + path_text = str(path_text or "").strip() + if not path_text: + return "" + if os.path.isabs(path_text): + return path_text if os.path.exists(path_text) else "" + if not template_dir: + return "" + candidate = os.path.abspath(os.path.join(template_dir, path_text)) + if os.path.exists(candidate): + return candidate + return "" + + template_id = str((template or {}).get("id") or "").strip().lower() + prompt_text = str(prompt_text or "").strip() + music_prompt_text = prompt_text + music_lyrics_text = "" + if template_id == "txt2music-ace-step" and prompt_text: + # Optional inline format: + #