From 1db7aa4a3d208a77345a764cc590daded1c72052 Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Wed, 11 Feb 2026 23:53:52 -0600 Subject: [PATCH 01/27] Initial ComfyUI pipeline and client implementations. - Generate with no input file - Generate with 1 input image - Progress displayed in "Project Files" - Cancel Job menu - Full circle generations and editing proven --- src/classes/comfy_client.py | 383 ++++++++++++ src/classes/comfy_pipelines.py | 81 +++ src/classes/generation_queue.py | 576 ++++++++++++++++++ src/settings/_default.settings | 8 + .../cosmic/images/tool-generate-sparkle.svg | 8 + src/windows/generate_media.py | 218 +++++++ src/windows/main_window.py | 312 +++++++++- src/windows/models/files_model.py | 205 ++++++- src/windows/views/files_listview.py | 119 +++- src/windows/views/files_treeview.py | 131 +++- 10 files changed, 2026 insertions(+), 15 deletions(-) create mode 100644 src/classes/comfy_client.py create mode 100644 src/classes/comfy_pipelines.py create mode 100644 src/classes/generation_queue.py create mode 100644 src/themes/cosmic/images/tool-generate-sparkle.svg create mode 100644 src/windows/generate_media.py diff --git a/src/classes/comfy_client.py b/src/classes/comfy_client.py new file mode 100644 index 000000000..987433b43 --- /dev/null +++ b/src/classes/comfy_client.py @@ -0,0 +1,383 @@ +""" + @file + @brief Small ComfyUI HTTP client for queue/poll/cancel operations. 
+""" + +import json +import os +import ssl +import base64 +import socket +import struct +from urllib.error import HTTPError +from urllib.request import Request, urlopen +from urllib.parse import quote, urlencode +from urllib.parse import urlparse + +from classes.logger import log + + +class ComfyProgressSocket: + """Minimal WebSocket client for ComfyUI /ws progress events.""" + + def __init__(self, base_url, client_id): + self.base_url = str(base_url or "").rstrip("/") + self.client_id = str(client_id or "") + self.sock = None + self._connect() + + def _connect(self): + parsed = urlparse(self.base_url) + scheme = parsed.scheme.lower() + host = parsed.hostname + if not host: + raise RuntimeError("Invalid ComfyUI URL for websocket") + port = parsed.port or (443 if scheme == "https" else 80) + path = "/ws?clientId={}".format(self.client_id) + + raw = socket.create_connection((host, port), timeout=4.0) + if scheme == "https": + ctx = ssl.create_default_context() + raw = ctx.wrap_socket(raw, server_hostname=host) + raw.settimeout(0.25) + self.sock = raw + + key = base64.b64encode(os.urandom(16)).decode("ascii") + req = ( + "GET {} HTTP/1.1\r\n" + "Host: {}:{}\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: {}\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n" + ).format(path, host, port, key) + self.sock.sendall(req.encode("utf-8")) + + response = self._recv_http_headers() + if " 101 " not in response.split("\r\n", 1)[0]: + raise RuntimeError("WebSocket upgrade failed: {}".format(response.split("\r\n", 1)[0])) + + def close(self): + if self.sock is not None: + try: + self.sock.close() + except OSError: + pass + self.sock = None + + def poll_progress(self, prompt_id, max_messages=8): + """Read available frames and return latest progress payload for prompt_id.""" + if not self.sock: + return None + latest = None + prompt_key = str(prompt_id) + for _ in range(max_messages): + frame = self._recv_frame_nonblocking() + if frame is None: + break 
+ opcode, payload = frame + + # Ping -> pong + if opcode == 0x9: + self._send_control_frame(0xA, payload) + continue + if opcode == 0x8: + self.close() + break + if opcode != 0x1: + continue + try: + msg = json.loads(payload.decode("utf-8")) + except Exception: + continue + if not isinstance(msg, dict): + continue + + event_type = msg.get("type") + event_data = msg.get("data", {}) + if event_type == "progress": + if not isinstance(event_data, dict): + continue + event_prompt = str(event_data.get("prompt_id", "")) + if not event_prompt or event_prompt != prompt_key: + continue + value = float(event_data.get("value", 0.0)) + maximum = float(event_data.get("max", 0.0)) + if maximum > 0: + latest = { + "percent": int(max(0, min(99, round((value / maximum) * 100.0)))), + "value": value, + "max": maximum, + "node": str(event_data.get("node", "")), + "type": "progress", + } + elif event_type == "progress_state": + # Newer Comfy events: data={prompt_id, nodes={node_id:{value,max}}} + if not isinstance(event_data, dict): + continue + event_prompt = str(event_data.get("prompt_id", "")) + if not event_prompt or event_prompt != prompt_key: + continue + nodes = event_data.get("nodes", {}) + if not isinstance(nodes, dict): + continue + # Pick node state with the largest max to avoid setup-node 1/1 spikes. 
+ best = None + for node_id, node_state in nodes.items(): + if not isinstance(node_state, dict): + continue + value = float(node_state.get("value", 0.0)) + maximum = float(node_state.get("max", 0.0)) + if maximum > 0: + candidate = { + "percent": int(max(0, min(99, round((value / maximum) * 100.0)))), + "value": value, + "max": maximum, + "node": str(node_id), + "type": "progress_state", + } + if best is None or maximum > float(best.get("max", 0.0)): + best = candidate + if best is not None: + latest = best + return latest + + def _recv_http_headers(self): + data = b"" + while b"\r\n\r\n" not in data: + chunk = self.sock.recv(4096) + if not chunk: + break + data += chunk + if len(data) > 65536: + break + return data.decode("utf-8", errors="replace") + + def _recv_exact(self, size): + chunks = [] + remaining = size + while remaining > 0: + chunk = self.sock.recv(remaining) + if not chunk: + raise RuntimeError("WebSocket connection closed") + chunks.append(chunk) + remaining -= len(chunk) + return b"".join(chunks) + + def _recv_frame_nonblocking(self): + try: + header = self.sock.recv(2) + if not header: + return None + except socket.timeout: + return None + except OSError: + return None + + if len(header) < 2: + return None + b1, b2 = header[0], header[1] + opcode = b1 & 0x0F + masked = (b2 & 0x80) != 0 + length = b2 & 0x7F + + if length == 126: + length = struct.unpack("!H", self._recv_exact(2))[0] + elif length == 127: + length = struct.unpack("!Q", self._recv_exact(8))[0] + + mask_key = b"" + if masked: + mask_key = self._recv_exact(4) + + payload = self._recv_exact(length) if length > 0 else b"" + if masked and payload: + payload = bytes(payload[i] ^ mask_key[i % 4] for i in range(len(payload))) + + return opcode, payload + + def _send_control_frame(self, opcode, payload=b""): + if self.sock is None: + return + payload = payload or b"" + first = 0x80 | (opcode & 0x0F) + # Client frames must be masked. 
+ mask = os.urandom(4) + length = len(payload) + if length < 126: + header = bytes([first, 0x80 | length]) + elif length < (1 << 16): + header = bytes([first, 0x80 | 126]) + struct.pack("!H", length) + else: + header = bytes([first, 0x80 | 127]) + struct.pack("!Q", length) + masked_payload = bytes(payload[i] ^ mask[i % 4] for i in range(length)) + self.sock.sendall(header + mask + masked_payload) + + +class ComfyClient: + """Minimal ComfyUI client using stdlib HTTP.""" + + def __init__(self, base_url): + self.base_url = str(base_url or "").rstrip("/") + + @staticmethod + def open_progress_socket(base_url, client_id): + return ComfyProgressSocket(base_url, client_id) + + def ping(self, timeout=0.5): + with urlopen("{}/system_stats".format(self.base_url), timeout=timeout) as response: + return int(response.status) >= 200 and int(response.status) < 300 + + def queue_prompt(self, prompt_graph, client_id): + payload = json.dumps({"prompt": prompt_graph, "client_id": client_id}).encode("utf-8") + req = Request( + "{}/prompt".format(self.base_url), + data=payload, + method="POST", + headers={"Content-Type": "application/json"}, + ) + try: + with urlopen(req, timeout=5.0) as response: + data = json.loads(response.read().decode("utf-8")) + except HTTPError as ex: + details = "" + try: + error_data = json.loads(ex.read().decode("utf-8")) + details = error_data.get("error", {}).get("type") or error_data.get("error", {}).get("message") or str(error_data) + except Exception: + details = str(ex) + raise RuntimeError("ComfyUI prompt rejected: {}".format(details)) + return data.get("prompt_id") + + def list_checkpoints(self): + """Return available checkpoint names from ComfyUI object info.""" + with urlopen("{}/object_info/CheckpointLoaderSimple".format(self.base_url), timeout=3.0) as response: + data = json.loads(response.read().decode("utf-8")) + + # Expected path: + # CheckpointLoaderSimple -> input -> required -> ckpt_name -> [ [..names..], {...meta...} ] + node_info = 
data.get("CheckpointLoaderSimple", {}) + required = node_info.get("input", {}).get("required", {}) + ckpt_input = required.get("ckpt_name", []) + if not ckpt_input or not isinstance(ckpt_input, list): + return [] + values = ckpt_input[0] if len(ckpt_input) > 0 else [] + if not isinstance(values, list): + return [] + return [str(v) for v in values if str(v).strip()] + + def history(self, prompt_id): + with urlopen("{}/history/{}".format(self.base_url, quote(str(prompt_id))), timeout=3.0) as response: + return json.loads(response.read().decode("utf-8")) + + def progress(self): + """Return ComfyUI /progress payload.""" + try: + with urlopen("{}/progress".format(self.base_url), timeout=3.0) as response: + return json.loads(response.read().decode("utf-8")) + except HTTPError as ex: + if int(getattr(ex, "code", 0)) == 404: + # Some ComfyUI versions don't expose /progress. + return None + raise + + def interrupt(self, prompt_id=None): + payload = {} + if prompt_id: + payload["prompt_id"] = str(prompt_id) + log.debug("ComfyClient interrupt request base_url=%s prompt_id=%s", self.base_url, payload.get("prompt_id", "")) + req = Request( + "{}/interrupt".format(self.base_url), + data=json.dumps(payload).encode("utf-8"), + method="POST", + headers={"Content-Type": "application/json"}, + ) + with urlopen(req, timeout=3.0) as response: + log.debug("ComfyClient interrupt response status=%s", int(response.status)) + return int(response.status) >= 200 and int(response.status) < 300 + + def cancel_prompt(self, prompt_id): + """Request ComfyUI to delete/cancel a prompt from the queue.""" + log.debug("ComfyClient cancel_prompt request base_url=%s prompt_id=%s", self.base_url, str(prompt_id)) + payload = json.dumps({"delete": [str(prompt_id)]}).encode("utf-8") + req = Request( + "{}/queue".format(self.base_url), + data=payload, + method="POST", + headers={"Content-Type": "application/json"}, + ) + with urlopen(req, timeout=3.0) as response: + log.debug("ComfyClient cancel_prompt 
response status=%s", int(response.status)) + return int(response.status) >= 200 and int(response.status) < 300 + + def queue(self): + """Return ComfyUI queue state.""" + with urlopen("{}/queue".format(self.base_url), timeout=3.0) as response: + return json.loads(response.read().decode("utf-8")) + + @staticmethod + def prompt_in_queue(prompt_id, queue_data): + """Check if prompt_id appears in queue_running/queue_pending payload.""" + pid = str(prompt_id) + if not isinstance(queue_data, dict): + return False + + for key in ("queue_running", "queue_pending"): + entries = queue_data.get(key, []) + if not isinstance(entries, list): + continue + for entry in entries: + # Common format: [number, prompt_id, ...] + if isinstance(entry, list) and len(entry) >= 2 and str(entry[1]) == pid: + return True + # Defensive fallback for dict-like entries + if isinstance(entry, dict): + if str(entry.get("prompt_id", "")) == pid: + return True + return False + + @staticmethod + def extract_image_outputs(history_entry, save_node_ids=None): + """Return a flat list of image refs from a history entry.""" + outputs = [] + if not isinstance(history_entry, dict): + return outputs + node_outputs = history_entry.get("outputs", {}) + if not isinstance(node_outputs, dict): + return outputs + save_node_ids = set(str(node_id) for node_id in (save_node_ids or [])) + + for node_id, node_out in node_outputs.items(): + if save_node_ids and str(node_id) not in save_node_ids: + continue + if not isinstance(node_out, dict): + continue + images = node_out.get("images", []) + if not isinstance(images, list): + continue + for img in images: + if not isinstance(img, dict): + continue + if img.get("filename"): + outputs.append({ + "filename": str(img.get("filename")), + "subfolder": str(img.get("subfolder", "")), + "type": str(img.get("type", "output")), + }) + return outputs + + def download_image(self, image_ref, destination_path): + """Download a Comfy image reference to a local file path.""" + params = { + 
"filename": image_ref.get("filename", ""), + "subfolder": image_ref.get("subfolder", ""), + "type": image_ref.get("type", "output"), + } + url = "{}/view?{}".format(self.base_url, urlencode(params)) + with urlopen(url, timeout=10.0) as response: + data = response.read() + + os.makedirs(os.path.dirname(destination_path), exist_ok=True) + with open(destination_path, "wb") as handle: + handle.write(data) diff --git a/src/classes/comfy_pipelines.py b/src/classes/comfy_pipelines.py new file mode 100644 index 000000000..d5aec84cb --- /dev/null +++ b/src/classes/comfy_pipelines.py @@ -0,0 +1,81 @@ +""" + @file + @brief Basic built-in ComfyUI pipeline definitions. +""" + +import random +import os + + +RASTER_IMAGE_EXTENSIONS = { + ".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff", ".gif", +} + + +def is_supported_img2img_path(path): + ext = os.path.splitext(str(path or ""))[1].lower() + return ext in RASTER_IMAGE_EXTENSIONS + + +def _supports_img2img(source_file=None): + if not source_file: + return False + if source_file.data.get("media_type") != "image": + return False + path = source_file.data.get("path", "") + return is_supported_img2img_path(path) + + +def available_pipelines(source_file=None): + pipelines = [{"id": "txt2img-basic", "name": "Basic Text to Image"}] + if _supports_img2img(source_file): + pipelines.insert(0, {"id": "img2img-basic", "name": "Basic Image Variation"}) + return pipelines + + +def build_workflow(pipeline_id, prompt_text, source_path, output_prefix, checkpoint_name=None): + prompt_text = str(prompt_text or "cinematic shot, highly detailed").strip() + if not prompt_text: + prompt_text = "cinematic shot, highly detailed" + output_prefix = str(output_prefix or "openshot_gen").strip() or "openshot_gen" + checkpoint_name = str(checkpoint_name or "").strip() or "v1-5-pruned-emaonly.safetensors" + seed = random.randint(1, 2**31 - 1) + + if pipeline_id == "img2img-basic": + if not is_supported_img2img_path(source_path): + raise ValueError( + 
"The selected file is not a supported raster image for this pipeline. " + "Use PNG/JPG/WebP/BMP/TIFF or switch to Text to Image." + ) + return { + "1": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, + "2": {"inputs": {"text": prompt_text, "clip": ["1", 1]}, "class_type": "CLIPTextEncode"}, + "3": {"inputs": {"text": "low quality, blurry", "clip": ["1", 1]}, "class_type": "CLIPTextEncode"}, + "4": {"inputs": {"image": str(source_path or ""), "upload": "image"}, "class_type": "LoadImage"}, + "5": {"inputs": {"pixels": ["4", 0], "vae": ["1", 2]}, "class_type": "VAEEncode"}, + "6": { + "inputs": { + "seed": seed, "steps": 20, "cfg": 7.0, "sampler_name": "euler", "scheduler": "normal", + "denoise": 0.65, "model": ["1", 0], "positive": ["2", 0], "negative": ["3", 0], "latent_image": ["5", 0], + }, + "class_type": "KSampler", + }, + "7": {"inputs": {"samples": ["6", 0], "vae": ["1", 2]}, "class_type": "VAEDecode"}, + "8": {"inputs": {"filename_prefix": output_prefix, "images": ["7", 0]}, "class_type": "SaveImage"}, + } + + return { + "1": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, + "2": {"inputs": {"text": prompt_text, "clip": ["1", 1]}, "class_type": "CLIPTextEncode"}, + "3": {"inputs": {"text": "low quality, blurry", "clip": ["1", 1]}, "class_type": "CLIPTextEncode"}, + "4": {"inputs": {"width": 1024, "height": 576, "batch_size": 1}, "class_type": "EmptyLatentImage"}, + "5": { + "inputs": { + "seed": seed, "steps": 20, "cfg": 7.0, "sampler_name": "euler", "scheduler": "normal", + "denoise": 1.0, "model": ["1", 0], "positive": ["2", 0], "negative": ["3", 0], "latent_image": ["4", 0], + }, + "class_type": "KSampler", + }, + "6": {"inputs": {"samples": ["5", 0], "vae": ["1", 2]}, "class_type": "VAEDecode"}, + "7": {"inputs": {"filename_prefix": output_prefix, "images": ["6", 0]}, "class_type": "SaveImage"}, + } diff --git a/src/classes/generation_queue.py b/src/classes/generation_queue.py new 
"""
 @file
 @brief Lightweight in-memory generation queue for ComfyUI-backed jobs.

 OpenShot Video Editor is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation, either version 3 of the License, or
 (at your option) any later version.
"""

import uuid
from collections import deque
from threading import Event
from time import monotonic

from PyQt5.QtCore import QObject, QThread, pyqtSignal, pyqtSlot

from classes.comfy_client import ComfyClient
from classes.logger import log


class _GenerationWorker(QObject):
    """Background worker that simulates generation progress for queued jobs."""

    # (job_id, percent)
    progress_changed = pyqtSignal(str, int)
    # (job_id, success, canceled, error_text, image_outputs)
    job_finished = pyqtSignal(str, bool, bool, str, object)

    def __init__(self):
        super().__init__()
        self._cancel_requested = set()
        self._job_prompts = {}

    def _is_cancel_requested(self, job_id, cancel_event):
        """True when either the worker-side set or the shared Event flags a cancel."""
        return (job_id in self._cancel_requested) or (cancel_event is not None and cancel_event.is_set())

    @pyqtSlot(str, object)
    def run_job(self, job_id, request):
        """Run one job: a real Comfy job when workflow+url are set, else a simulation."""
        request = request or {}
        cancel_event = request.get("cancel_event")
        if request.get("workflow") and request.get("comfy_url"):
            self._run_comfy_job(job_id, request)
            return

        # Simulated job: 20 steps of 250ms, emitting 5%..99% then 100%.
        canceled = False
        for step in range(1, 21):
            QThread.msleep(250)
            if self._is_cancel_requested(job_id, cancel_event):
                canceled = True
                break
            self.progress_changed.emit(job_id, min(step * 5, 99))

        self._cancel_requested.discard(job_id)

        if canceled:
            self.job_finished.emit(job_id, False, True, "", [])
        else:
            self.progress_changed.emit(job_id, 100)
            self.job_finished.emit(job_id, True, False, "", [])

    @pyqtSlot(str)
    def cancel_job(self, job_id):
        """Flag a running job for cancellation (checked by the job loop)."""
        log.debug("GenerationWorker cancel_job received job=%s", str(job_id))
        self._cancel_requested.add(job_id)

    def _run_comfy_job(self, job_id, request):
        """Queue a workflow on ComfyUI, then poll history/progress/queue until done.

        Emits progress_changed during the run and exactly one job_finished at
        the end (success, cancel, failure, or timeout).
        """
        comfy_url = request.get("comfy_url")
        workflow = request.get("workflow")
        client_id = request.get("client_id") or "openshot-qt"
        timeout_s = int(request.get("timeout_s") or 86400)  # default 24 hours safety cap
        save_node_ids = list(request.get("save_node_ids") or [])
        cancel_event = request.get("cancel_event")
        client = ComfyClient(comfy_url)
        ws_client = None

        try:
            prompt_id = client.queue_prompt(workflow, client_id)
            if not prompt_id:
                self.job_finished.emit(job_id, False, False, "ComfyUI returned an invalid prompt_id", [])
                return
            self._job_prompts[job_id] = prompt_id
            try:
                ws_client = ComfyClient.open_progress_socket(comfy_url, client_id)
                log.debug("Comfy progress websocket connected for prompt=%s", str(prompt_id))
            except Exception:
                log.debug("Comfy progress websocket unavailable; continuing without live progress", exc_info=True)

            start_time = monotonic()
            last_in_queue_time = start_time
            last_progress_log_time = 0.0
            progress_endpoint_unavailable = False
            accepted_progress_started = False

            while True:
                if self._is_cancel_requested(job_id, cancel_event):
                    log.debug("Comfy cancel requested for job=%s prompt=%s", job_id, str(prompt_id))
                    cancel_ok = False
                    cancel_errors = []

                    # Retry cancellation a few times and verify prompt no longer appears in Comfy queue.
                    for attempt in range(1, 181):
                        try:
                            cancel_ok = client.cancel_prompt(prompt_id) or cancel_ok
                        except Exception as ex:
                            cancel_errors.append("queue: {}".format(ex))

                        try:
                            cancel_ok = client.interrupt(prompt_id=prompt_id) or cancel_ok
                        except Exception as ex:
                            cancel_errors.append("interrupt: {}".format(ex))

                        try:
                            history = client.history(prompt_id) or {}
                            prompt_key = str(prompt_id)
                            history_entry = history.get(prompt_key) or history.get(prompt_id) or None
                            if isinstance(history_entry, dict):
                                status_obj = history_entry.get("status", {}) if isinstance(history_entry, dict) else {}
                                status_str = str(status_obj.get("status_str", "")).lower()
                                # Comfy commonly marks interrupted runs as failed/error in history.
                                if status_str in ("error", "failed"):
                                    cancel_ok = True
                                    log.debug(
                                        "Comfy cancel confirmed by history status for job=%s prompt=%s status=%s",
                                        job_id,
                                        prompt_key,
                                        status_str,
                                    )
                                    break
                        except Exception as ex:
                            cancel_errors.append("history-check: {}".format(ex))

                        try:
                            queue_data = client.queue() or {}
                            if not ComfyClient.prompt_in_queue(prompt_id, queue_data):
                                cancel_ok = True
                                log.debug(
                                    "Comfy cancel confirmed by queue absence for job=%s prompt=%s on attempt=%s",
                                    job_id,
                                    str(prompt_id),
                                    attempt,
                                )
                                break
                        except Exception as ex:
                            cancel_errors.append("queue-check: {}".format(ex))

                        if attempt % 10 == 0:
                            log.debug(
                                "Comfy cancel still pending job=%s prompt=%s attempt=%s",
                                job_id,
                                str(prompt_id),
                                attempt,
                            )
                        QThread.msleep(500)

                    self._cancel_requested.discard(job_id)
                    self._job_prompts.pop(job_id, None)
                    if cancel_ok:
                        self.job_finished.emit(job_id, False, True, "", [])
                    else:
                        self.job_finished.emit(
                            job_id,
                            False,
                            False,
                            "ComfyUI did not accept cancel request ({})".format("; ".join(cancel_errors) or "unknown"),
                            [],
                        )
                    return

                # A history entry means the prompt finished (success or failure).
                history = client.history(prompt_id) or {}
                prompt_key = str(prompt_id)
                history_entry = history.get(prompt_key) or history.get(prompt_id) or None
                if history_entry is not None:
                    status_obj = history_entry.get("status", {}) if isinstance(history_entry, dict) else {}
                    status_str = str(status_obj.get("status_str", "")).lower()
                    if status_str in ("error", "failed"):
                        error_text = "ComfyUI job failed."
                        messages = status_obj.get("messages", [])
                        if isinstance(messages, list) and messages:
                            error_text = str(messages[-1])
                        self._job_prompts.pop(job_id, None)
                        self.job_finished.emit(job_id, False, False, error_text, [])
                        return
                    image_outputs = ComfyClient.extract_image_outputs(history_entry, save_node_ids=save_node_ids)
                    self.progress_changed.emit(job_id, 100)
                    self._job_prompts.pop(job_id, None)
                    self.job_finished.emit(job_id, True, False, "", image_outputs)
                    return

                # Query ComfyUI's live progress values when available.
                try:
                    if ws_client is not None:
                        progress_event = ws_client.poll_progress(prompt_id)
                        if progress_event is not None:
                            elapsed = monotonic() - start_time
                            progress = int(progress_event.get("percent", 0))
                            raw_value = float(progress_event.get("value", 0.0))
                            raw_max = float(progress_event.get("max", 0.0))
                            progress_type = str(progress_event.get("type", ""))
                            progress_node = str(progress_event.get("node", ""))
                            # Some workflows emit near-complete progress bursts at startup
                            # (e.g. tiny setup nodes), then reset to sampler progress.
                            # Ignore those bootstrap spikes for a short window.
                            if (
                                (not accepted_progress_started)
                                and progress >= 95
                                and elapsed < 20.0
                                and raw_max <= 1.0
                            ):
                                log.debug(
                                    "Comfy WS progress setup-node spike ignored job=%s prompt=%s node=%s type=%s value=%s max=%s percent=%s elapsed=%.2fs",
                                    job_id,
                                    prompt_key,
                                    progress_node,
                                    progress_type,
                                    raw_value,
                                    raw_max,
                                    progress,
                                    elapsed,
                                )
                            elif (not accepted_progress_started) and progress >= 95 and elapsed < 20.0:
                                log.debug(
                                    "Comfy WS progress bootstrap spike ignored job=%s prompt=%s percent=%s elapsed=%.2fs",
                                    job_id,
                                    prompt_key,
                                    progress,
                                    elapsed,
                                )
                            else:
                                accepted_progress_started = True
                                log.debug(
                                    "Comfy WS progress emit job=%s prompt=%s node=%s type=%s value=%s max=%s percent=%s",
                                    job_id,
                                    prompt_key,
                                    progress_node,
                                    progress_type,
                                    raw_value,
                                    raw_max,
                                    progress,
                                )
                                self.progress_changed.emit(job_id, progress)
                    else:
                        # BUGFIX: was `client.progress() or {}`, which coerced the
                        # None (endpoint-missing) sentinel to {} and made the
                        # unavailable-endpoint branch below unreachable dead code.
                        progress_data = client.progress()
                        if progress_data is None:
                            if not progress_endpoint_unavailable:
                                log.debug(
                                    "Comfy progress endpoint unavailable (404). Progress bar updates disabled for this job=%s",
                                    job_id,
                                )
                                progress_endpoint_unavailable = True
                            progress_data = {}

                        progress_block = progress_data.get("progress", progress_data)
                        if not isinstance(progress_block, dict):
                            progress_block = {}

                        value = float(progress_block.get("value", progress_block.get("current", 0.0)))
                        maximum = float(progress_block.get("max", progress_block.get("total", 0.0)))
                        progress_prompt = str(
                            progress_data.get("prompt_id", progress_block.get("prompt_id", ""))
                        )
                        prompt_matches = (not progress_prompt) or (progress_prompt == prompt_key)

                        now_log = monotonic()
                        if (now_log - last_progress_log_time) > 8.0:
                            log.debug(
                                "Comfy progress poll job=%s prompt=%s payload_keys=%s value=%s max=%s progress_prompt=%s prompt_match=%s",
                                job_id,
                                prompt_key,
                                list(progress_data.keys()) if isinstance(progress_data, dict) else type(progress_data),
                                value,
                                maximum,
                                progress_prompt,
                                prompt_matches,
                            )
                            last_progress_log_time = now_log

                        if maximum > 0 and prompt_matches:
                            progress = int(max(0, min(99, round((value / maximum) * 100.0))))
                            log.debug(
                                "Comfy progress emit job=%s prompt=%s value=%s max=%s percent=%s",
                                job_id,
                                prompt_key,
                                value,
                                maximum,
                                progress,
                            )
                            self.progress_changed.emit(job_id, progress)
                except Exception:
                    # Keep polling history and queue even if /progress is unavailable.
                    if not progress_endpoint_unavailable:
                        log.debug("Comfy progress poll failed for job=%s", job_id, exc_info=True)

                # Check queue to avoid timing out long-running but active jobs.
                in_queue = False
                try:
                    queue_data = client.queue() or {}
                    in_queue = ComfyClient.prompt_in_queue(prompt_id, queue_data)
                except Exception:
                    # If queue check fails, do not penalize the job immediately.
                    in_queue = True
                if in_queue:
                    last_in_queue_time = monotonic()
                else:
                    now_log = monotonic()
                    if (now_log - last_progress_log_time) > 8.0:
                        log.debug(
                            "Comfy queue check: prompt=%s not found in queue_running/queue_pending yet",
                            prompt_key,
                        )
                        last_progress_log_time = now_log

                now = monotonic()
                if (now - start_time) > timeout_s:
                    self._job_prompts.pop(job_id, None)
                    self.job_finished.emit(job_id, False, False, "Timed out waiting for ComfyUI history result", [])
                    return

                # If prompt vanished from queue for an extended period and still no history, treat as failure.
                if (now - last_in_queue_time) > 600:
                    self._job_prompts.pop(job_id, None)
                    self.job_finished.emit(
                        job_id,
                        False,
                        False,
                        "ComfyUI prompt is no longer in queue and has no history result.",
                        [],
                    )
                    return
                QThread.msleep(500)
        except Exception as ex:
            self._job_prompts.pop(job_id, None)
            self.job_finished.emit(job_id, False, False, str(ex), [])
        finally:
            if ws_client is not None:
                ws_client.close()


class GenerationQueueManager(QObject):
    """Single-worker, in-memory generation queue with per-file active-job limits."""

    # Job states that count as occupying a file's single active-job slot.
    ACTIVE_STATES = {"queued", "running", "canceling"}

    job_added = pyqtSignal(str, object)
    job_updated = pyqtSignal(str, str, int)
    job_finished = pyqtSignal(str, str)
    job_removed = pyqtSignal(str)
    file_job_changed = pyqtSignal(str)
    queue_changed = pyqtSignal()

    # Internal cross-thread dispatch to the worker object.
    _run_job = pyqtSignal(str, object)
    _cancel_job = pyqtSignal(str)

    def __init__(self, parent=None):
        super().__init__(parent)
        self.jobs = {}
        self._queued = deque()
        self._running_job_id = None
        self._active_file_jobs = {}

        self._thread = QThread(self)
        self._thread.setObjectName("generation_queue_worker")
        self._worker = _GenerationWorker()
        self._worker.moveToThread(self._thread)
        self._run_job.connect(self._worker.run_job)
        self._cancel_job.connect(self._worker.cancel_job)
        self._worker.progress_changed.connect(self._on_progress_changed)
        self._worker.job_finished.connect(self._on_job_finished)
        self._thread.start()

    def enqueue(self, name, template_id, prompt, source_file_id=None, request=None):
        """Add a job to the queue; returns job_id, or None when the source
        file already has an active job."""
        source_file_id = str(source_file_id or "")
        if source_file_id and self.get_active_job_for_file(source_file_id):
            return None

        job_id = str(uuid.uuid4())
        cancel_event = Event()
        job_request = dict(request or {})
        job_request["cancel_event"] = cancel_event
        job = {
            "id": job_id,
            "name": str(name or "").strip(),
            "template_id": str(template_id or "").strip(),
            "prompt": str(prompt or "").strip(),
            "source_file_id": source_file_id,
            "status": "queued",
            "progress": 0,
            "error": "",
            "request": job_request,
            "cancel_event": cancel_event,
        }
        self.jobs[job_id] = job
        self._queued.append(job_id)
        if source_file_id:
            self._active_file_jobs[source_file_id] = job_id

        self.job_added.emit(job_id, source_file_id)
        self.job_updated.emit(job_id, "queued", 0)
        self._emit_file_changed(source_file_id)
        self.queue_changed.emit()
        self._start_next_if_idle()
        return job_id

    def cancel_job(self, job_id):
        """Cancel a queued job immediately, or request cancel of a running one."""
        job = self.jobs.get(job_id)
        if not job:
            log.debug("GenerationQueue cancel_job ignored; unknown job=%s", str(job_id))
            return False

        log.debug(
            "GenerationQueue cancel_job request job=%s status=%s source_file_id=%s",
            str(job_id),
            str(job.get("status", "")),
            str(job.get("source_file_id", "")),
        )
        if job["status"] == "queued":
            cancel_event = job.get("cancel_event")
            if cancel_event is not None:
                cancel_event.set()
                log.debug("GenerationQueue cancel_event set for queued job=%s", str(job_id))
            job["status"] = "canceled"
            self._queued = deque([queued_id for queued_id in self._queued if queued_id != job_id])
            self._release_file_slot(job.get("source_file_id", ""))
            self.job_updated.emit(job_id, "canceled", int(job.get("progress", 0)))
            self.job_finished.emit(job_id, "canceled")
            self._emit_file_changed(job.get("source_file_id", ""))
            self.queue_changed.emit()
            log.debug("GenerationQueue cancel_job completed for queued job=%s", str(job_id))
            return True

        if job["status"] == "running":
            cancel_event = job.get("cancel_event")
            if cancel_event is not None:
                cancel_event.set()
                log.debug("GenerationQueue cancel_event set for running job=%s", str(job_id))
            job["status"] = "canceling"
            self.job_updated.emit(job_id, "canceling", int(job.get("progress", 0)))
            self._cancel_job.emit(job_id)
            self._emit_file_changed(job.get("source_file_id", ""))
            self.queue_changed.emit()
            log.debug("GenerationQueue cancel_job emitted worker cancel for running job=%s", str(job_id))
            return True

        log.debug("GenerationQueue cancel_job ignored for job=%s with status=%s", str(job_id), str(job.get("status", "")))
        return False

    def cancel_jobs_for_file(self, source_file_id):
        """Cancel every active job attached to the given source file."""
        source_file_id = str(source_file_id or "")
        if not source_file_id:
            return
        for job in list(self.jobs.values()):
            if job.get("source_file_id") == source_file_id and job.get("status") in self.ACTIVE_STATES:
                self.cancel_job(job["id"])

    def remove_job(self, job_id):
        """Remove a finished job from the registry; refuses active jobs."""
        job = self.jobs.get(job_id)
        if not job:
            return False
        if job.get("status") in self.ACTIVE_STATES:
            return False

        source_file_id = job.get("source_file_id", "")
        self.jobs.pop(job_id, None)
        self.job_removed.emit(job_id)
        self._emit_file_changed(source_file_id)
        self.queue_changed.emit()
        return True

    def get_job(self, job_id):
        """Return the job dict for job_id, or None."""
        return self.jobs.get(job_id)

    def get_active_job_for_file(self, source_file_id):
        """Return the active job for a file, clearing stale slot entries."""
        source_file_id = str(source_file_id or "")
        if not source_file_id:
            return None

        job_id = self._active_file_jobs.get(source_file_id)
        if not job_id:
            return None

        job = self.jobs.get(job_id)
        if not job or job.get("status") not in self.ACTIVE_STATES:
            self._active_file_jobs.pop(source_file_id, None)
            return None
        return job

    def get_file_badge(self, source_file_id):
        """Return a UI badge dict (status/progress/label/job_id) for a file, or None."""
        job = self.get_active_job_for_file(source_file_id)
        if not job:
            return None

        status = job.get("status")
        progress = int(job.get("progress", 0))
        if status == "queued":
            label = "Queued"
        elif status == "running":
            label = "Generating {}%".format(progress)
        elif status == "canceling":
            label = "Canceling..."
        else:
            label = status.capitalize()

        return {"status": status, "progress": progress, "label": label, "job_id": job.get("id")}

    def shutdown(self):
        """Stop the worker thread (waits up to 2 seconds)."""
        if self._thread.isRunning():
            self._thread.quit()
            self._thread.wait(2000)

    def _start_next_if_idle(self):
        """Dispatch the next queued job to the worker when nothing is running."""
        if self._running_job_id is not None:
            return
        if not self._queued:
            return

        next_job_id = self._queued.popleft()
        job = self.jobs.get(next_job_id)
        if not job:
            # Stale id (e.g. removed job); try the next one.
            self._start_next_if_idle()
            return

        self._running_job_id = next_job_id
        job["status"] = "running"
        job["progress"] = int(job.get("progress", 0))
        self.job_updated.emit(next_job_id, "running", int(job["progress"]))
        self._emit_file_changed(job.get("source_file_id", ""))
        self.queue_changed.emit()
        self._run_job.emit(next_job_id, job.get("request", {}))

    def _release_file_slot(self, source_file_id):
        """Free the single active-job slot for a source file."""
        source_file_id = str(source_file_id or "")
        if source_file_id:
            self._active_file_jobs.pop(source_file_id, None)

    def _emit_file_changed(self, source_file_id):
        """Emit file_job_changed when a real file id is supplied."""
        source_file_id = str(source_file_id or "")
        if source_file_id:
            self.file_job_changed.emit(source_file_id)

    @pyqtSlot(str, int)
    def _on_progress_changed(self, job_id, progress):
        """Worker callback: record and rebroadcast progress for live jobs."""
        job = self.jobs.get(job_id)
        if not job:
            return
        if job.get("status") not in ("running", "canceling"):
            return
        job["progress"] = int(progress)
        self.job_updated.emit(job_id, job.get("status"), int(progress))
        self._emit_file_changed(job.get("source_file_id", ""))
        self.queue_changed.emit()

    @pyqtSlot(str, bool, bool, str, object)
    def _on_job_finished(self, job_id, success, canceled, error, outputs):
        """Worker callback: finalize job state and start the next queued job."""
        job = self.jobs.get(job_id)
        if not job:
            return

        if canceled:
            job["status"] = "canceled"
        elif success:
            job["status"] = "completed"
            job["progress"] = 100
            job["outputs"] = list(outputs or [])
        else:
            job["status"] = "failed"
            job["error"] = str(error or "")

        source_file_id = job.get("source_file_id", "")
        self._release_file_slot(source_file_id)

        self.job_updated.emit(job_id, job["status"], int(job.get("progress", 0)))
        self.job_finished.emit(job_id, job["status"])
        self._emit_file_changed(source_file_id)
        self.queue_changed.emit()

        if self._running_job_id == job_id:
            self._running_job_id = None
            self._start_next_if_idle()


# --- Remainder of the original patch (truncated in this view) ---
# - src/settings/_default.settings: adds an Experimental "Comfy UI URL"
#   setting (key "comfy-ui-url", default "http://127.0.0.1:8188",
#   restart=false).
# - src/themes/cosmic/images/tool-generate-sparkle.svg: new icon asset
#   (markup not visible in this chunk).
# - src/windows/generate_media.py: Generate dialog for ComfyUI pipeline
#   jobs (file content cut off at the end of this chunk).
+""" + +import os + +from PyQt5.QtCore import Qt +from PyQt5.QtGui import QIcon, QPixmap +from PyQt5.QtWidgets import ( + QDialog, QVBoxLayout, QHBoxLayout, QFormLayout, QLabel, QLineEdit, + QComboBox, QTextEdit, QTabWidget, QWidget, QPushButton +) + +from classes import info +from classes.thumbnail import GetThumbPath + + +class GenerateMediaDialog(QDialog): + """Minimal generate dialog with a simple default-first layout.""" + + PREVIEW_WIDTH = 180 + PREVIEW_HEIGHT = 128 + + def __init__(self, source_file=None, templates=None, parent=None): + super().__init__(parent) + self.source_file = source_file + self.templates = templates or [] + self.setObjectName("generateDialog") + self.setWindowTitle("Generate") + self.setMinimumWidth(620) + self.setMinimumHeight(460) + + root = QVBoxLayout(self) + root.setContentsMargins(14, 14, 14, 14) + root.setSpacing(10) + + root.addLayout(self._build_top_block()) + + self.tabs = QTabWidget(self) + self.tabs.setObjectName("generateTabs") + self.tabs.addTab(self._build_prompt_tab(), "Prompt") + self.tabs.addTab(self._build_mask_tab(), "Mask") + self.tabs.addTab(self._build_advanced_tab(), "Advanced") + root.addWidget(self.tabs, 1) + + button_row = QHBoxLayout() + button_row.addStretch(1) + self.cancel_button = QPushButton("Cancel") + self.generate_button = QPushButton("Generate") + self.generate_button.setIcon(QIcon(":/icons/Humanity/actions/16/star.svg")) + self.cancel_button.clicked.connect(self.reject) + self.generate_button.clicked.connect(self._on_generate_clicked) + button_row.addWidget(self.cancel_button) + button_row.addWidget(self.generate_button) + root.addLayout(button_row) + self._apply_dialog_theme() + + def get_payload(self): + return { + "name": self.name_edit.text().strip(), + "template_id": self.template_combo.currentData() or self.template_combo.currentText(), + "prompt": self.prompt_edit.toPlainText().strip(), + } + + def _build_top_block(self): + block = QHBoxLayout() + block.setSpacing(12) + + if 
self.source_file: + self.thumbnail_label = QLabel() + self.thumbnail_label.setFixedSize(self.PREVIEW_WIDTH, self.PREVIEW_HEIGHT) + self.thumbnail_label.setAlignment(Qt.AlignCenter) + self.thumbnail_label.setStyleSheet("border: 1px solid palette(mid);") + self._load_thumbnail() + block.addWidget(self.thumbnail_label, 0) + + setup_form = QFormLayout() + setup_form.setContentsMargins(0, 0, 0, 0) + setup_form.setVerticalSpacing(8) + + default_name = "generation" + if self.source_file: + path = self.source_file.data.get("path", "") + if path: + default_name = "{}_gen".format(os.path.splitext(os.path.basename(path))[0]) + + self.name_edit = QLineEdit() + self.name_edit.setPlaceholderText("Output file name") + self.name_edit.setText(default_name) + setup_form.addRow("Name", self.name_edit) + + self.template_combo = QComboBox() + if self.templates: + for template in self.templates: + self.template_combo.addItem(template.get("name", ""), template.get("id", "")) + else: + self.template_combo.addItem("Basic Text to Image", "txt2img-basic") + setup_form.addRow("Template", self.template_combo) + + if self.source_file: + source_path = self.source_file.data.get("path", "") + source_label = QLabel(os.path.basename(source_path)) + source_label.setToolTip(source_path) + setup_form.addRow("Source", source_label) + + right_container = QWidget(self) + right_container.setLayout(setup_form) + block.addWidget(right_container, 1) + return block + + def _build_prompt_tab(self): + tab = QWidget(self) + tab.setObjectName("pagePrompt") + layout = QVBoxLayout(tab) + layout.setContentsMargins(8, 8, 8, 8) + self.prompt_edit = QTextEdit() + self.prompt_edit.setPlaceholderText("Describe what to generate...") + self.prompt_edit.setMinimumHeight(140) + layout.addWidget(self.prompt_edit) + return tab + + def _build_mask_tab(self): + tab = QWidget(self) + tab.setObjectName("pageMask") + layout = QVBoxLayout(tab) + layout.setContentsMargins(8, 8, 8, 8) + label = QLabel("Mask tools will appear for 
templates that support drawing.") + label.setWordWrap(True) + layout.addWidget(label) + layout.addStretch(1) + return tab + + def _build_advanced_tab(self): + tab = QWidget(self) + tab.setObjectName("pageAdvanced") + layout = QVBoxLayout(tab) + layout.setContentsMargins(8, 8, 8, 8) + label = QLabel("Advanced controls are template-driven and will appear here.") + label.setWordWrap(True) + layout.addWidget(label) + layout.addStretch(1) + return tab + + def _load_thumbnail(self): + path = "" + media_type = self.source_file.data.get("media_type") + if media_type in ["video", "image"]: + path = GetThumbPath(self.source_file.id, 1) + elif media_type == "audio": + path = os.path.join(info.PATH, "images", "AudioThumbnail.svg") + + pix = QPixmap(path) if path else QPixmap() + if not pix.isNull(): + pix = pix.scaled( + self.PREVIEW_WIDTH - 2, + self.PREVIEW_HEIGHT - 2, + Qt.KeepAspectRatio, + Qt.SmoothTransformation, + ) + self.thumbnail_label.setPixmap(pix) + else: + self.thumbnail_label.setText("No Preview") + + def _on_generate_clicked(self): + if not self.name_edit.text().strip(): + self.name_edit.setFocus(Qt.TabFocusReason) + return + self.accept() + + def _apply_dialog_theme(self): + self.setStyleSheet(""" +QDialog#generateDialog { + background-color: #192332; + color: #91C3FF; +} +QDialog#generateDialog QTabWidget#generateTabs QWidget#pagePrompt, +QDialog#generateDialog QTabWidget#generateTabs QWidget#pageMask, +QDialog#generateDialog QTabWidget#generateTabs QWidget#pageAdvanced { + background-color: #141923; + border: none; +} +QDialog#generateDialog QTabWidget#generateTabs QTabBar::tab { + margin-left: 14px; + margin-top: 10px; + padding: 6px 2px; + color: rgba(145, 195, 255, 0.5); +} +QDialog#generateDialog QTabWidget#generateTabs QTabBar::tab:selected { + color: rgba(145, 195, 255, 1.0); + border-bottom: 1.2px solid #53a0ed; +} +QDialog#generateDialog QLineEdit, +QDialog#generateDialog QTextEdit, +QDialog#generateDialog QComboBox { + background-color: #141923; + 
color: #91C3FF; + border: 1px solid rgba(145, 195, 255, 0.20); + border-radius: 4px; + padding: 6px 8px; +} +QDialog#generateDialog QPushButton { + background-color: #283241; + color: #91C3FF; + border: 1px solid rgba(145, 195, 255, 0.20); + border-radius: 4px; + padding: 6px 10px; +} +QDialog#generateDialog QPushButton:hover { + background-color: #323C50; +} +QDialog#generateDialog QPushButton:focus, +QDialog#generateDialog QLineEdit:focus, +QDialog#generateDialog QTextEdit:focus, +QDialog#generateDialog QComboBox:focus { + border: 1px solid #53a0ed; +} +""") diff --git a/src/windows/main_window.py b/src/windows/main_window.py index 536d90054..a06a77ff2 100644 --- a/src/windows/main_window.py +++ b/src/windows/main_window.py @@ -34,11 +34,14 @@ import shutil import uuid import webbrowser +import tempfile from time import sleep, time from datetime import datetime from uuid import uuid4 import zipfile import threading +from urllib.parse import urlparse +from urllib.request import urlopen import openshot # Python module for libopenshot (required video editing module installed separately) from PyQt5.QtCore import ( @@ -63,6 +66,9 @@ from classes.logger import log from classes.metrics import track_metric_session, track_metric_screen from classes.query import File, Clip, Transition, Marker, Track, Effect +from classes.generation_queue import GenerationQueueManager +from classes.comfy_pipelines import available_pipelines, build_workflow, is_supported_img2img_path +from classes.comfy_client import ComfyClient from classes.thumbnail import httpThumbnailServerThread, httpThumbnailException from classes.time_parts import secondsToTimecode from classes.timeline import TimelineSync @@ -86,6 +92,7 @@ from windows.views.transitions_listview import TransitionsListView from windows.views.transitions_treeview import TransitionsTreeView from windows.views.tutorial import TutorialManager +from windows.generate_media import GenerateMediaDialog class MainWindow(updates.UpdateWatcher, 
QMainWindow): @@ -209,6 +216,18 @@ def closeEvent(self, event): if self.http_server_thread: self.http_server_thread.kill() + # Stop generation queue worker thread (if any) + if getattr(self, "generation_queue", None): + self.generation_queue.shutdown() + + # Cleanup temporary generation source files + for tmp_path in getattr(self, "_generation_temp_files", []): + try: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + except OSError: + pass + # Stop ZMQ polling thread (if any) if app.logger_libopenshot: app.logger_libopenshot.kill() @@ -1980,6 +1999,277 @@ def actionSplitFile_trigger(self): else: log.info('Cutting Cancelled') + def comfy_ui_url(self): + url = get_app().get_settings().get("comfy-ui-url") or "http://127.0.0.1:8188" + return str(url).strip().rstrip("/") + + def _prepare_generation_source_path(self, source_file, template_id): + """Return a source path suitable for the selected template.""" + if not source_file: + return "" + + source_path = source_file.data.get("path", "") + media_type = source_file.data.get("media_type") + if template_id != "img2img-basic" or media_type != "image": + return source_path + + if is_supported_img2img_path(source_path): + return source_path + + # Convert unsupported image formats (such as SVG) into a temporary PNG. 
+ tmp_fd, tmp_png = tempfile.mkstemp(prefix="openshot-comfy-", suffix=".png") + os.close(tmp_fd) + try: + clip = openshot.Clip(source_path) + frame = clip.Reader().GetFrame(1) + frame.Save(tmp_png, 1.0) + self._generation_temp_files.append(tmp_png) + return tmp_png + except Exception: + try: + os.remove(tmp_png) + except OSError: + pass + raise + + def is_comfy_available(self, force=False): + now = time() + if not force and (now - self._comfy_status_cache["checked_at"]) < 2.0: + return self._comfy_status_cache["available"] + + url = self.comfy_ui_url() + parsed = urlparse(url) + available = False + if parsed.scheme in ("http", "https") and parsed.netloc: + try: + with urlopen("{}/system_stats".format(url), timeout=0.5) as response: + available = int(response.status) >= 200 and int(response.status) < 300 + except Exception: + available = False + + self._comfy_status_cache["checked_at"] = now + self._comfy_status_cache["available"] = available + return available + + def can_open_generate_dialog(self): + # Keep action clickable for valid selection counts. + # Comfy availability is validated when the action is triggered. 
+ return len(self.selected_file_ids()) <= 1 + + def active_generation_job_for_file(self, file_id): + if not getattr(self, "generation_queue", None): + return None + return self.generation_queue.get_active_job_for_file(file_id) + + def cancel_generation_job(self, job_id): + if not job_id: + log.debug("MainWindow cancel_generation_job ignored; empty job_id") + return + log.debug("MainWindow cancel_generation_job requested job=%s", str(job_id)) + if self.generation_queue.cancel_job(job_id): + log.debug("MainWindow cancel_generation_job accepted job=%s", str(job_id)) + self.statusBar.showMessage("Generation canceled", 3000) + else: + log.debug("MainWindow cancel_generation_job rejected job=%s", str(job_id)) + + def actionCancelGenerationJob_trigger(self, checked=True): + file_id = self.current_file_id() + if not file_id: + return + active_job = self.active_generation_job_for_file(file_id) + if active_job: + self.cancel_generation_job(active_job.get("id")) + + def actionGenerate_trigger(self, checked=True): + selected_files = self.selected_files() + if len(selected_files) > 1: + return + + if not self.is_comfy_available(force=True): + msg = QMessageBox(self) + msg.setWindowTitle("ComfyUI Unavailable") + msg.setText( + "OpenShot could not connect to ComfyUI at:\n{}\n\n" + "Start ComfyUI or update the URL in Preferences > Experimental.".format(self.comfy_ui_url()) + ) + msg.exec_() + return + + source_file = selected_files[0] if selected_files else None + templates = available_pipelines(source_file=source_file) + win = GenerateMediaDialog(source_file=source_file, templates=templates, parent=self) + if win.exec_() != QDialog.Accepted: + return + + payload = win.get_payload() + payload_name = self._next_generation_name(payload.get("name")) + source_file_id = source_file.id if source_file else None + try: + source_path = self._prepare_generation_source_path(source_file, payload.get("template_id")) + except Exception as ex: + QMessageBox.warning( + self, + "Source Conversion 
Failed", + "OpenShot could not convert this image into PNG for ComfyUI.\n\n{}".format(ex), + ) + return + checkpoint_name = None + try: + checkpoint_names = ComfyClient(self.comfy_ui_url()).list_checkpoints() + if checkpoint_names: + checkpoint_name = checkpoint_names[0] + except Exception as ex: + log.warning("Failed to query ComfyUI checkpoints: %s", ex) + + if not checkpoint_name: + QMessageBox.information( + self, + "No Checkpoints Found", + "ComfyUI has no checkpoints available for CheckpointLoaderSimple.\n" + "Add a model to ComfyUI/models/checkpoints and try again.", + ) + return + + try: + workflow = build_workflow( + payload.get("template_id"), + payload.get("prompt"), + source_path, + payload_name, + checkpoint_name=checkpoint_name, + ) + except Exception as ex: + QMessageBox.information(self, "Invalid Input", str(ex)) + return + request = { + "comfy_url": self.comfy_ui_url(), + "workflow": workflow, + "client_id": "openshot-qt", + "timeout_s": 21600, + "save_node_ids": [str(node_id) for node_id, node in workflow.items() if node.get("class_type") == "SaveImage"], + } + job_id = self.generation_queue.enqueue( + payload_name, + payload.get("template_id"), + payload.get("prompt"), + source_file_id=source_file_id, + request=request, + ) + if not job_id: + QMessageBox.information( + self, + "Generation Already Active", + "Only one active generation is allowed per source file.", + ) + return + + self.statusBar.showMessage("Queued generation job", 3000) + + def _on_generation_job_finished(self, job_id, status): + job = self.generation_queue.get_job(job_id) if getattr(self, "generation_queue", None) else None + if not job: + return + + if status == "completed": + imported = self._import_generation_outputs(job) + if imported > 0: + self.statusBar.showMessage("Generation completed and imported {} file(s)".format(imported), 5000) + else: + self.statusBar.showMessage("Generation completed (no output files found)", 5000) + return + + if status == "canceled": + 
self.statusBar.showMessage("Generation canceled", 3000) + return + + if status == "failed": + error_text = str(job.get("error") or "ComfyUI generation failed.") + self.statusBar.showMessage("Generation failed", 5000) + QMessageBox.warning(self, "Generation Failed", error_text) + return + + def _import_generation_outputs(self, job): + outputs = list(job.get("outputs", []) or []) + if not outputs: + return 0 + + request = job.get("request", {}) or {} + comfy_url = str(request.get("comfy_url") or self.comfy_ui_url()) + client = ComfyClient(comfy_url) + output_dir = os.path.join(info.USER_PATH, "comfy_outputs") + os.makedirs(output_dir, exist_ok=True) + + name_raw = str(job.get("name") or "generation") + safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", name_raw).strip("._") + if not safe_name: + safe_name = "generation" + + saved_paths = [] + for index, image_ref in enumerate(outputs, start=1): + original_name = str(image_ref.get("filename", "output.png")) + ext = os.path.splitext(original_name)[1] or ".png" + local_name = "{}_{}{}".format(safe_name, str(index).zfill(3), ext) + local_path = self._next_available_path(os.path.join(output_dir, local_name)) + try: + client.download_image(image_ref, local_path) + saved_paths.append(local_path) + except Exception as ex: + log.warning("Failed to download Comfy output %s: %s", image_ref, ex) + + if not saved_paths: + return 0 + + self.files_model.add_files( + saved_paths, + quiet=True, + prevent_image_seq=True, + prevent_recent_folder=True, + ) + return len(saved_paths) + + def _next_generation_name(self, requested_name): + """Return a unique generation-friendly name for project files/jobs.""" + base = re.sub(r"[^A-Za-z0-9._-]+", "_", str(requested_name or "").strip()).strip("._") + if not base: + base = "generation" + + existing_names = set() + for file_obj in File.filter(): + if not file_obj: + continue + display_name = str(file_obj.data.get("name") or os.path.basename(file_obj.data.get("path", "")) or "") + if display_name: + 
stem = os.path.splitext(display_name)[0] + existing_names.add(stem.lower()) + + if base.lower() not in existing_names: + return base + + # If the requested name already exists, append/increment _genN. + name_root = base + m = re.match(r"^(.*?)(?:_gen(\d+))?$", base, re.IGNORECASE) + if m: + name_root = (m.group(1) or base).rstrip("_") or "generation" + n = 1 + while True: + candidate = "{}_gen{}".format(name_root, n) + if candidate.lower() not in existing_names: + return candidate + n += 1 + + def _next_available_path(self, path): + """Return a non-colliding file path by appending _N when needed.""" + if not os.path.exists(path): + return path + folder = os.path.dirname(path) + stem, ext = os.path.splitext(os.path.basename(path)) + n = 2 + while True: + candidate = os.path.join(folder, "{}_{}{}".format(stem, n, ext)) + if not os.path.exists(candidate): + return candidate + n += 1 + def actionRemove_from_Project_trigger(self): log.debug("actionRemove_from_Project_trigger") @@ -1991,6 +2281,10 @@ def actionRemove_from_Project_trigger(self): if not f: continue + # Cancel queued/running generation jobs tied to this file + if getattr(self, "generation_queue", None): + self.generation_queue.cancel_jobs_for_file(f.data.get("id")) + # Find matching clips (if any) clips = Clip.filter(file_id=f.data.get("id")) for c in clips: @@ -3379,7 +3673,7 @@ def initModels(self): s = get_app().get_settings() # Setup files tree and list view (both share a model) - self.files_model = FilesModel() + self.files_model = FilesModel(generation_queue=self.generation_queue) self.filesTreeView = FilesTreeView(self.files_model) self.filesListView = FilesListView(self.files_model) self.files_model.update_model() @@ -3441,6 +3735,17 @@ def initModels(self): self.emojiListView = EmojisListView(self.emojis_model) self.tabEmojis.layout().addWidget(self.emojiListView) + def _init_generation_actions(self): + self.actionGenerate = QAction("Generate...", self) + 
self.actionGenerate.setObjectName("actionGenerate") + sparkle_icon_path = os.path.join(info.PATH, "themes", "cosmic", "images", "tool-generate-sparkle.svg") + self.actionGenerate.setIcon(QIcon(sparkle_icon_path)) + self.actionGenerate.triggered.connect(self.actionGenerate_trigger) + + self.actionCancelGenerationJob = QAction("Cancel Job", self) + self.actionCancelGenerationJob.setObjectName("actionCancelGenerationJob") + self.actionCancelGenerationJob.triggered.connect(self.actionCancelGenerationJob_trigger) + def actionInsertKeyframe(self): log.debug("actionInsertKeyframe") if self.selected_clips or self.selected_transitions: @@ -4010,6 +4315,7 @@ def __init__(self, *args): # Load UI from designer self.selected_items = [] + self._generation_temp_files = [] ui_util.load_ui(self, self.ui_path) # Init UI @@ -4017,6 +4323,10 @@ def __init__(self, *args): # Create dock toolbars, set initial state of items, etc self.setup_toolbars() + self._comfy_status_cache = {"checked_at": 0.0, "available": False} + self.generation_queue = GenerationQueueManager(self) + self.generation_queue.job_finished.connect(self._on_generation_job_finished) + self._init_generation_actions() # Add window as watcher to receive undo/redo status updates app.updates.add_watcher(self) diff --git a/src/windows/models/files_model.py b/src/windows/models/files_model.py index bf4597a33..5a76f0686 100644 --- a/src/windows/models/files_model.py +++ b/src/windows/models/files_model.py @@ -144,6 +144,7 @@ def __init__(self, **kwargs): class FilesModel(QObject, updates.UpdateInterface): ModelRefreshed = pyqtSignal() + PLACEHOLDER_PREFIX = "__genjob__:" # This method is invoked by the UpdateManager each time a change happens (i.e UpdateInterface) def changed(self, action): @@ -303,6 +304,7 @@ def update_model(self, clear=True, delete_file_id=None, update_file_id=None): # Emit signal when model is updated self.ModelRefreshed.emit() + self._rebuild_generation_placeholders() def add_files(self, files, 
image_seq_details=None, quiet=False, prevent_image_seq=False, prevent_recent_folder=False): @@ -606,8 +608,13 @@ def selected_file_ids(self): """ Get a list of file IDs for all selected files """ # Get the indexes for column 5 of all selected rows selected = self.selection_model.selectedRows(5) - - return [idx.data() for idx in selected] + ids = [] + for idx in selected: + file_id = idx.data() + if not file_id or self._is_generation_placeholder(file_id): + continue + ids.append(file_id) + return ids def selected_files(self): """ Get a list of File objects representing the current selection """ @@ -624,7 +631,10 @@ def current_file_id(self): cur = self.selection_model.selectedIndexes()[0] if cur and cur.isValid(): - return cur.sibling(cur.row(), 5).data() + file_id = cur.sibling(cur.row(), 5).data() + if self._is_generation_placeholder(file_id): + return None + return file_id def current_file(self): """ Get the File object for the current files-view item, or the first selection """ @@ -683,7 +693,8 @@ def _sync_list_to_tree_selection(self, selected, deselected): finally: self._syncing_selection = False - def __init__(self, *args): + def __init__(self, generation_queue=None, *args): + self.generation_queue = generation_queue # Add self as listener to project data updates # (undo/redo, as well as normal actions handled within this class all update the model) @@ -725,6 +736,13 @@ def __init__(self, *args): app.window.FileUpdated.connect(self.update_file_thumbnail) app.window.refreshFilesSignal.connect( functools.partial(self.update_model, clear=False)) + if self.generation_queue: + self.generation_queue.file_job_changed.connect(self._refresh_file_generation_display) + self.generation_queue.queue_changed.connect(self._refresh_all_generation_displays) + self.generation_queue.job_added.connect(self._on_generation_job_added) + self.generation_queue.job_updated.connect(self._on_generation_job_updated) + 
self.generation_queue.job_finished.connect(self._on_generation_job_finished) + self.generation_queue.job_removed.connect(self._on_generation_job_removed) # Call init for superclass QObject super(QObject, FilesModel).__init__(self, *args) @@ -744,3 +762,182 @@ def __init__(self, *args): log.info("Enabled {} model tests for emoji data".format(len(self.model_tests))) except ImportError: pass + + def _is_generation_placeholder(self, file_id): + return str(file_id or "").startswith(self.PLACEHOLDER_PREFIX) + + def _placeholder_id_for_job(self, job_id): + return "{}{}".format(self.PLACEHOLDER_PREFIX, str(job_id or "")) + + def _job_id_from_placeholder(self, file_id): + file_id = str(file_id or "") + if not self._is_generation_placeholder(file_id): + return None + return file_id[len(self.PLACEHOLDER_PREFIX):] + + def _placeholder_row_for_job(self, job_id): + placeholder_id = self._placeholder_id_for_job(job_id) + if placeholder_id not in self.model_ids: + return None + id_index = self.model_ids[placeholder_id] + if not id_index.isValid(): + return None + return id_index.row() + + def _add_generation_placeholder(self, job_id): + job = self.generation_queue.get_job(job_id) if self.generation_queue else None + if not job: + return + if job.get("source_file_id"): + return + + placeholder_id = self._placeholder_id_for_job(job_id) + if placeholder_id in self.model_ids and self.model_ids[placeholder_id].isValid(): + self._update_generation_placeholder(job_id) + return + + name = str(job.get("name") or "generation") + status = str(job.get("status") or "queued") + progress = int(job.get("progress", 0)) + label = name + if status == "running": + label = "{} ({}%)".format(name, progress) + elif status == "queued": + label = "{} (Queued)".format(name) + elif status == "canceling": + label = "{} (Canceling...)".format(name) + + row = [] + generate_icon_path = os.path.join(info.PATH, "themes", "cosmic", "images", "tool-generate-sparkle.svg") + emoji_icon_path = os.path.join(info.PATH, 
"emojis", "color", "svg", "2728.svg") + if os.path.exists(generate_icon_path): + icon = QIcon(generate_icon_path) + elif os.path.exists(emoji_icon_path): + icon = QIcon(emoji_icon_path) + else: + icon = QIcon(":/icons/Humanity/actions/16/media-record.svg") + flags = Qt.ItemIsSelectable | Qt.ItemIsEnabled | Qt.ItemNeverHasChildren + + col = QStandardItem(icon, label) + col.setFlags(flags) + row.append(col) + + col = QStandardItem(label) + col.setFlags(flags) + row.append(col) + + col = QStandardItem("generation") + col.setFlags(flags) + row.append(col) + + col = QStandardItem("generation_job") + col.setFlags(flags) + row.append(col) + + col = QStandardItem("") + col.setFlags(flags) + row.append(col) + + col = QStandardItem(placeholder_id) + col.setFlags(flags) + row.append(col) + + self.model.appendRow(row) + self.model_ids[placeholder_id] = QPersistentModelIndex(row[5].index()) + self.ModelRefreshed.emit() + + def _update_generation_placeholder(self, job_id): + row = self._placeholder_row_for_job(job_id) + if row is None: + self._add_generation_placeholder(job_id) + return + job = self.generation_queue.get_job(job_id) if self.generation_queue else None + if not job: + return + + name = str(job.get("name") or "generation") + status = str(job.get("status") or "queued") + progress = int(job.get("progress", 0)) + label = name + if status == "running": + label = "{} ({}%)".format(name, progress) + elif status == "queued": + label = "{} (Queued)".format(name) + elif status == "canceling": + label = "{} (Canceling...)".format(name) + + self.model.item(row, 0).setText(label) + self.model.item(row, 1).setText(label) + left = self.model.index(row, 0) + right = self.model.index(row, 1) + self.model.dataChanged.emit(left, right, [Qt.DisplayRole, Qt.AccessibleTextRole]) + self.ModelRefreshed.emit() + + def _remove_generation_placeholder(self, job_id): + placeholder_id = self._placeholder_id_for_job(job_id) + if placeholder_id not in self.model_ids: + return + id_index = 
self.model_ids.get(placeholder_id) + if not id_index or not id_index.isValid(): + self.model_ids.pop(placeholder_id, None) + return + row = id_index.row() + self.model.removeRows(row, 1, id_index.parent()) + self.model.submit() + self.model_ids.pop(placeholder_id, None) + self.ModelRefreshed.emit() + + def _rebuild_generation_placeholders(self): + if not self.generation_queue: + return + for job in list(self.generation_queue.jobs.values()): + if job.get("source_file_id"): + continue + if job.get("status") in ("completed", "failed", "canceled"): + self._remove_generation_placeholder(job.get("id")) + else: + self._add_generation_placeholder(job.get("id")) + + def _on_generation_job_added(self, job_id, source_file_id): + if source_file_id: + return + self._add_generation_placeholder(job_id) + + def _on_generation_job_updated(self, job_id, status, progress): + job = self.generation_queue.get_job(job_id) if self.generation_queue else None + if not job or job.get("source_file_id"): + return + if status in ("completed", "failed", "canceled"): + self._remove_generation_placeholder(job_id) + else: + self._update_generation_placeholder(job_id) + + def _on_generation_job_finished(self, job_id, status): + job = self.generation_queue.get_job(job_id) if self.generation_queue else None + if not job or job.get("source_file_id"): + return + self._remove_generation_placeholder(job_id) + + def _on_generation_job_removed(self, job_id): + self._remove_generation_placeholder(job_id) + + def _refresh_file_generation_display(self, file_id): + if file_id not in self.model_ids: + return + id_index = self.model_ids[file_id] + if not id_index.isValid(): + return + + row = id_index.row() + left = self.model.index(row, 0) + right = self.model.index(row, 0) + self.model.dataChanged.emit(left, right, [Qt.DisplayRole, Qt.AccessibleTextRole]) + self.ModelRefreshed.emit() + + def _refresh_all_generation_displays(self): + if self.model.rowCount() < 1: + return + left = self.model.index(0, 0) + right 
= self.model.index(self.model.rowCount() - 1, 0) + self.model.dataChanged.emit(left, right, [Qt.DisplayRole, Qt.AccessibleTextRole]) + self.ModelRefreshed.emit() diff --git a/src/windows/views/files_listview.py b/src/windows/views/files_listview.py index 02a2e4a4a..e6ab76bfd 100644 --- a/src/windows/views/files_listview.py +++ b/src/windows/views/files_listview.py @@ -29,8 +29,8 @@ import uuid from PyQt5.QtCore import QSize, Qt, QPoint, QRegExp -from PyQt5.QtGui import QDrag, QCursor, QPixmap, QPainter, QIcon -from PyQt5.QtWidgets import QListView, QAbstractItemView +from PyQt5.QtGui import QDrag, QCursor, QPixmap, QPainter, QIcon, QColor +from PyQt5.QtWidgets import QListView, QAbstractItemView, QStyledItemDelegate, QStyleOptionViewItem, QStyle from classes import info from classes.app import get_app @@ -39,6 +39,87 @@ from .menu import StyledContextMenu +def _is_generation_placeholder(file_id): + return str(file_id or "").startswith("__genjob__:") + + +def _job_id_from_placeholder(file_id): + file_id = str(file_id or "") + if not _is_generation_placeholder(file_id): + return None + return file_id.split(":", 1)[1] + + +class FilesListProgressDelegate(QStyledItemDelegate): + """Paint a thin progress line over list-view thumbnails.""" + + def __init__(self, view): + super().__init__(view) + self.view = view + + def paint(self, painter, option, index): + super().paint(painter, option, index) + + # list_proxy_model index -> proxy_model index -> source model index + proxy_index = self.view.files_model.list_proxy_model.mapToSource(index) + if not proxy_index or not proxy_index.isValid(): + return + source_index = self.view.files_model.proxy_model.mapToSource(proxy_index) + if not source_index or not source_index.isValid(): + return + + file_id = source_index.sibling(source_index.row(), 5).data(Qt.DisplayRole) + queue = getattr(self.view.win, "generation_queue", None) + if not file_id or not queue: + return + badge = queue.get_file_badge(file_id) + if not badge and 
_is_generation_placeholder(file_id): + job = queue.get_job(_job_id_from_placeholder(file_id)) + if job and job.get("status") in ("queued", "running", "canceling"): + label = "Queued" if job.get("status") == "queued" else "Generating" + badge = { + "status": job.get("status"), + "progress": int(job.get("progress", 0)), + "label": label, + "job_id": job.get("id"), + } + if not badge: + return + + progress = int(badge.get("progress", 0)) + status = badge.get("status", "") + if status == "queued": + progress = max(progress, 2) + if progress <= 0: + return + + opt = QStyleOptionViewItem(option) + self.initStyleOption(opt, index) + style = opt.widget.style() if opt.widget else self.view.style() + deco_rect = style.subElementRect(QStyle.SE_ItemViewItemDecoration, opt, opt.widget) + if not deco_rect.isValid(): + return + + bar_height = 3 + bar_margin = 2 + full_rect = deco_rect.adjusted(1, 0, -1, 0) + full_rect.setTop(deco_rect.bottom() - bar_height - bar_margin + 1) + full_rect.setHeight(bar_height) + if full_rect.width() <= 2: + return + + fill_width = max(1, int((full_rect.width() * min(progress, 100)) / 100.0)) + fill_rect = full_rect.adjusted(0, 0, -(full_rect.width() - fill_width), 0) + + painter.save() + painter.setPen(Qt.NoPen) + painter.setBrush(QColor("#283241")) + painter.drawRect(full_rect) + painter.setBrush(QColor("#53A0ED")) + painter.drawRect(fill_rect) + painter.restore() + + class FilesListView(QListView): """ A ListView QWidget used on the main window """ drag_item_size = QSize(48, 48) @@ -58,6 +139,30 @@ def contextMenuEvent(self, event): menu = StyledContextMenu(parent=self) menu.addAction(self.win.actionImportFiles) + self.win.actionGenerate.setEnabled(self.win.can_open_generate_dialog()) + menu.addAction(self.win.actionGenerate) + + active_job = None + file_id = None + if index.isValid(): + model = self.model() + source_index = model.mapToSource(index) + id_index = source_index.sibling(source_index.row(), 5) + file_id = 
model.sourceModel().data(id_index, Qt.DisplayRole) + if _is_generation_placeholder(file_id): + job_id = _job_id_from_placeholder(file_id) + queue = getattr(self.win, "generation_queue", None) + active_job = queue.get_job(job_id) if queue else None + if active_job and active_job.get("status") not in ("queued", "running", "canceling"): + active_job = None + else: + active_job = self.win.active_generation_job_for_file(file_id) + if active_job: + cancel_action = menu.addAction(_("Cancel Job")) + cancel_action.triggered.connect( + lambda checked=False, job_id=active_job.get("id"): self.win.cancel_generation_job(job_id) + ) + menu.addSeparator() menu.addAction(self.win.actionDetailsView) if index.isValid(): @@ -74,6 +179,9 @@ def contextMenuEvent(self, event): # Add edit title option (if svg file) file = File.get(id=file_id) + if not file: + menu.popup(event.globalPos()) + return if file and file.data.get("path").endswith(".svg"): menu.addAction(self.win.actionEditTitle) menu.addAction(self.win.actionDuplicate) @@ -133,6 +241,12 @@ def startDrag(self, supportedActions): # Get first column indexes for all selected rows selected = self.selectionModel().selectedRows(0) + selected = [ + idx for idx in selected + if not _is_generation_placeholder( + self.model().sourceModel().data(self.model().mapToSource(idx).sibling(self.model().mapToSource(idx).row(), 5), Qt.DisplayRole) + ) + ] # Check if there are any selected items if not selected: @@ -249,6 +363,7 @@ def __init__(self, model, *args): self.setSelectionMode(QAbstractItemView.ExtendedSelection) self.setSelectionBehavior(QAbstractItemView.SelectRows) self.setSelectionModel(self.files_model.list_selection_model) + self.setItemDelegate(FilesListProgressDelegate(self)) # Keep track of mouse press start position to determine when to start drag self.setAcceptDrops(True) diff --git a/src/windows/views/files_treeview.py b/src/windows/views/files_treeview.py index 07dbddc62..7d0e6231a 100644 --- 
a/src/windows/views/files_treeview.py +++ b/src/windows/views/files_treeview.py @@ -31,8 +31,8 @@ import uuid from PyQt5.QtCore import QSize, Qt, QPoint -from PyQt5.QtGui import QDrag, QCursor, QPixmap, QPainter, QIcon -from PyQt5.QtWidgets import QTreeView, QAbstractItemView, QSizePolicy, QHeaderView +from PyQt5.QtGui import QDrag, QCursor, QPixmap, QPainter, QIcon, QColor +from PyQt5.QtWidgets import QTreeView, QAbstractItemView, QSizePolicy, QHeaderView, QStyledItemDelegate, QStyleOptionViewItem, QStyle from classes import info from classes.app import get_app @@ -41,6 +41,82 @@ from .menu import StyledContextMenu +def _is_generation_placeholder(file_id): + return str(file_id or "").startswith("__genjob__:") + + +def _job_id_from_placeholder(file_id): + file_id = str(file_id or "") + if not _is_generation_placeholder(file_id): + return None + return file_id.split(":", 1)[1] + + +class FilesTreeProgressDelegate(QStyledItemDelegate): + """Paint a thin progress line over thumbnail cells.""" + + def __init__(self, view): + super().__init__(view) + self.view = view + + def paint(self, painter, option, index): + super().paint(painter, option, index) + + if index.column() != 0: + return + + file_id = index.sibling(index.row(), 5).data(Qt.DisplayRole) + queue = getattr(self.view.win, "generation_queue", None) + if not file_id or not queue: + return + badge = queue.get_file_badge(file_id) + if not badge and _is_generation_placeholder(file_id): + job = queue.get_job(_job_id_from_placeholder(file_id)) + if job and job.get("status") in ("queued", "running", "canceling"): + label = "Queued" if job.get("status") == "queued" else "Generating" + badge = { + "status": job.get("status"), + "progress": int(job.get("progress", 0)), + "label": label, + "job_id": job.get("id"), + } + if not badge: + return + + progress = int(badge.get("progress", 0)) + status = badge.get("status", "") + if status == "queued": + progress = max(progress, 2) + if progress <= 0: + return + + opt = 
QStyleOptionViewItem(option) + self.initStyleOption(opt, index) + style = opt.widget.style() if opt.widget else self.view.style() + deco_rect = style.subElementRect(QStyle.SE_ItemViewItemDecoration, opt, opt.widget) + if not deco_rect.isValid(): + return + + bar_height = 3 + bar_margin = 2 + full_rect = deco_rect.adjusted(1, 0, -1, 0) + full_rect.setTop(deco_rect.bottom() - bar_height - bar_margin + 1) + full_rect.setHeight(bar_height) + if full_rect.width() <= 2: + return + + fill_width = max(1, int((full_rect.width() * min(progress, 100)) / 100.0)) + fill_rect = full_rect.adjusted(0, 0, -(full_rect.width() - fill_width), 0) + + painter.save() + painter.setPen(Qt.NoPen) + painter.setBrush(QColor("#283241")) + painter.drawRect(full_rect) + painter.setBrush(QColor("#53A0ED")) + painter.drawRect(fill_rect) + painter.restore() + + class FilesTreeView(QTreeView): """ A TreeView QWidget used on the main window """ drag_item_size = QSize(48, 48) @@ -60,6 +136,28 @@ def contextMenuEvent(self, event): menu = StyledContextMenu(parent=self) menu.addAction(self.win.actionImportFiles) + self.win.actionGenerate.setEnabled(self.win.can_open_generate_dialog()) + menu.addAction(self.win.actionGenerate) + + active_job = None + file_id = None + if index.isValid(): + id_index = index.sibling(index.row(), 5) + file_id = index.model().data(id_index, Qt.DisplayRole) + if _is_generation_placeholder(file_id): + job_id = _job_id_from_placeholder(file_id) + queue = getattr(self.win, "generation_queue", None) + active_job = queue.get_job(job_id) if queue else None + if active_job and active_job.get("status") not in ("queued", "running", "canceling"): + active_job = None + else: + active_job = self.win.active_generation_job_for_file(file_id) + if active_job: + cancel_action = menu.addAction(_("Cancel Job")) + cancel_action.triggered.connect( + lambda checked=False, job_id=active_job.get("id"): self.win.cancel_generation_job(job_id) + ) + menu.addSeparator() 
menu.addAction(self.win.actionThumbnailView) if index.isValid(): @@ -75,6 +173,9 @@ def contextMenuEvent(self, event): # Add edit title option (if svg file) file = File.get(id=file_id) + if not file: + menu.popup(event.globalPos()) + return if file and file.data.get("path").endswith(".svg"): menu.addAction(self.win.actionEditTitle) menu.addAction(self.win.actionDuplicate) @@ -136,6 +237,7 @@ def startDrag(self, supportedActions): # Get first column indexes for all selected rows selected = self.selectionModel().selectedRows(0) + selected = [idx for idx in selected if not _is_generation_placeholder(idx.sibling(idx.row(), 5).data(Qt.DisplayRole))] # Check if there are any selected items if not selected: @@ -248,17 +350,29 @@ def value_updated(self, item): """ Name or tags updated """ if self.files_model.ignore_updates: return - - # Get translation method - _ = get_app()._tr + if item is None: + return + if item.column() not in (1, 2): + return # Determine what was changed - file_id = self.files_model.model.item(item.row(), 5).text() - name = self.files_model.model.item(item.row(), 1).text() - tags = self.files_model.model.item(item.row(), 2).text() + file_id_item = self.files_model.model.item(item.row(), 5) + name_item = self.files_model.model.item(item.row(), 1) + tags_item = self.files_model.model.item(item.row(), 2) + if not file_id_item or not name_item or not tags_item: + return + + file_id = file_id_item.text() + if _is_generation_placeholder(file_id): + return + + name = name_item.text() + tags = tags_item.text() # Get file object and update friendly name and tags attribute f = File.get(id=file_id) + if not f: + return f.data.update({"name": name or os.path.basename(f.data.get("path"))}) if "tags" in f.data or tags: f.data.update({"tags": tags}) @@ -286,6 +400,7 @@ def __init__(self, model, *args): self.setSelectionBehavior(QAbstractItemView.SelectRows) self.setSelectionModel(self.files_model.selection_model) self.setSortingEnabled(True) + 
self.setItemDelegate(FilesTreeProgressDelegate(self)) self.setAcceptDrops(True) self.setDragEnabled(True) From 32af63af1981e9f1d2c3af52e6e6a0740229fe76 Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Thu, 12 Feb 2026 16:44:30 -0600 Subject: [PATCH 02/27] Refactoring Comfy UI service code to it's own file, improving code to work with remote servers (i.e. LoadFile -> Upload files). Updating file headers with correct header and copyright. Updating some base/default models for image generation and upscaling. Much more stable now at reconnecting, losing connection, and not giving up. --- src/classes/comfy_client.py | 262 +++++++++++- src/classes/comfy_pipelines.py | 92 ++++- src/classes/generation_queue.py | 110 +++++- src/classes/generation_service.py | 374 ++++++++++++++++++ .../{generate_media.py => generate.py} | 24 +- src/windows/main_window.py | 262 +----------- src/windows/views/files_listview.py | 11 +- src/windows/views/files_treeview.py | 10 +- 8 files changed, 857 insertions(+), 288 deletions(-) create mode 100644 src/classes/generation_service.py rename src/windows/{generate_media.py => generate.py} (87%) diff --git a/src/classes/comfy_client.py b/src/classes/comfy_client.py index 987433b43..93b38d18a 100644 --- a/src/classes/comfy_client.py +++ b/src/classes/comfy_client.py @@ -1,12 +1,35 @@ """ @file - @brief Small ComfyUI HTTP client for queue/poll/cancel operations. + @brief This file contains a small ComfyUI HTTP/WebSocket client. + @author Jonathan Thomas + + @section LICENSE + + Copyright (c) 2008-2026 OpenShot Studios, LLC + (http://www.openshotstudios.com). This file is part of + OpenShot Video Editor (http://www.openshot.org), an open-source project + dedicated to delivering high quality video editing and animation solutions + to the world. 
+ + OpenShot Video Editor is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + OpenShot Video Editor is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with OpenShot Library. If not, see . """ import json import os import ssl import base64 +import uuid import socket import struct from urllib.error import HTTPError @@ -33,13 +56,16 @@ def _connect(self): if not host: raise RuntimeError("Invalid ComfyUI URL for websocket") port = parsed.port or (443 if scheme == "https" else 80) - path = "/ws?clientId={}".format(self.client_id) + base_path = (parsed.path or "").rstrip("/") + ws_path = "{}/ws".format(base_path) if base_path else "/ws" + path = "{}?clientId={}".format(ws_path, quote(self.client_id)) - raw = socket.create_connection((host, port), timeout=4.0) + raw = socket.create_connection((host, port), timeout=6.0) if scheme == "https": ctx = ssl.create_default_context() raw = ctx.wrap_socket(raw, server_hostname=host) - raw.settimeout(0.25) + # Allow slower remote/proxied websocket handshakes. 
+ raw.settimeout(6.0) self.sock = raw key = base64.b64encode(os.urandom(16)).decode("ascii") @@ -48,15 +74,20 @@ def _connect(self): "Host: {}:{}\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" + "Origin: {}://{}:{}\r\n" + "Pragma: no-cache\r\n" + "Cache-Control: no-cache\r\n" "Sec-WebSocket-Key: {}\r\n" "Sec-WebSocket-Version: 13\r\n" "\r\n" - ).format(path, host, port, key) + ).format(path, host, port, scheme, host, port, key) self.sock.sendall(req.encode("utf-8")) response = self._recv_http_headers() if " 101 " not in response.split("\r\n", 1)[0]: raise RuntimeError("WebSocket upgrade failed: {}".format(response.split("\r\n", 1)[0])) + # Use short timeout for regular frame polling after successful handshake. + self.sock.settimeout(0.25) def close(self): if self.sock is not None: @@ -230,6 +261,7 @@ def ping(self, timeout=0.5): return int(response.status) >= 200 and int(response.status) < 300 def queue_prompt(self, prompt_graph, client_id): + prompt_graph = self._rewrite_prompt_local_file_inputs(prompt_graph) payload = json.dumps({"prompt": prompt_graph, "client_id": client_id}).encode("utf-8") req = Request( "{}/prompt".format(self.base_url), @@ -238,43 +270,187 @@ def queue_prompt(self, prompt_graph, client_id): headers={"Content-Type": "application/json"}, ) try: - with urlopen(req, timeout=5.0) as response: + with urlopen(req, timeout=10.0) as response: data = json.loads(response.read().decode("utf-8")) except HTTPError as ex: details = "" try: error_data = json.loads(ex.read().decode("utf-8")) - details = error_data.get("error", {}).get("type") or error_data.get("error", {}).get("message") or str(error_data) + error_obj = error_data.get("error", {}) + if isinstance(error_obj, dict): + details = error_obj.get("type") or error_obj.get("message") or "" + else: + details = str(error_obj or "") + + node_errors = error_data.get("node_errors", {}) + node_error_text = self._format_node_errors(node_errors) + if node_error_text: + details = 
"{}\n{}".format(details or "prompt validation failed", node_error_text) + elif not details: + details = json.dumps(error_data, ensure_ascii=True) + else: + details = "{}\n{}".format(details, json.dumps(error_data, ensure_ascii=True)) except Exception: details = str(ex) raise RuntimeError("ComfyUI prompt rejected: {}".format(details)) return data.get("prompt_id") + def _rewrite_prompt_local_file_inputs(self, prompt_graph): + """Rewrite local absolute paths for LoadImage/LoadVideo nodes to uploaded [input] refs.""" + if not isinstance(prompt_graph, dict): + return prompt_graph + rewritten = dict(prompt_graph) + + def _annotated(path_text): + path_text = str(path_text or "").strip() + return path_text.endswith("[input]") or path_text.endswith("[output]") or path_text.endswith("[temp]") + + for node_id, node in rewritten.items(): + if not isinstance(node, dict): + continue + class_type = str(node.get("class_type", "")) + inputs = node.get("inputs", {}) + if not isinstance(inputs, dict): + continue + + if class_type == "LoadImage": + image_path = str(inputs.get("image", "")).strip() + if image_path and os.path.isabs(image_path) and os.path.exists(image_path) and not _annotated(image_path): + uploaded = self.upload_input_file(image_path) + inputs["image"] = uploaded + node["inputs"] = inputs + rewritten[node_id] = node + log.debug("ComfyClient rewrote LoadImage input node=%s path=%s -> %s", str(node_id), image_path, uploaded) + elif class_type == "LoadVideo": + video_path = str(inputs.get("file", "")).strip() + if video_path and os.path.isabs(video_path) and os.path.exists(video_path) and not _annotated(video_path): + uploaded = self.upload_input_file(video_path) + inputs["file"] = uploaded + node["inputs"] = inputs + rewritten[node_id] = node + log.debug("ComfyClient rewrote LoadVideo input node=%s path=%s -> %s", str(node_id), video_path, uploaded) + + return rewritten + + @staticmethod + def _format_node_errors(node_errors): + if not isinstance(node_errors, dict) or 
not node_errors: + return "" + lines = [] + max_lines = 8 + for node_id, err in node_errors.items(): + if len(lines) >= max_lines: + break + if not isinstance(err, dict): + lines.append("node {}: {}".format(node_id, str(err))) + continue + err_type = str(err.get("type", "")).strip() + message = str(err.get("message", "")).strip() + if not message: + details = err.get("details") + if details: + message = str(details) + if err_type and message: + lines.append("node {} [{}]: {}".format(node_id, err_type, message)) + elif message: + lines.append("node {}: {}".format(node_id, message)) + elif err_type: + lines.append("node {} [{}]".format(node_id, err_type)) + if not lines: + return "" + return "Node validation errors: {}".format(" | ".join(lines)) + def list_checkpoints(self): """Return available checkpoint names from ComfyUI object info.""" - with urlopen("{}/object_info/CheckpointLoaderSimple".format(self.base_url), timeout=3.0) as response: + with urlopen("{}/object_info/CheckpointLoaderSimple".format(self.base_url), timeout=8.0) as response: data = json.loads(response.read().decode("utf-8")) # Expected path: - # CheckpointLoaderSimple -> input -> required -> ckpt_name -> [ [..names..], {...meta...} ] + # CheckpointLoaderSimple -> input -> required -> ckpt_name + # Supports multiple schema variants: + # 1) [ [..names..], {...meta...} ] + # 2) ["COMBO", {"options":[..names..], ...}] node_info = data.get("CheckpointLoaderSimple", {}) required = node_info.get("input", {}).get("required", {}) - ckpt_input = required.get("ckpt_name", []) - if not ckpt_input or not isinstance(ckpt_input, list): - return [] - values = ckpt_input[0] if len(ckpt_input) > 0 else [] - if not isinstance(values, list): - return [] + ckpt_input = required.get("ckpt_name", None) + values = self._extract_combo_options(ckpt_input) return [str(v) for v in values if str(v).strip()] + def list_upscale_models(self): + """Return available upscaler model names from ComfyUI object info.""" + models = [] + 
# Primary source: object_info schema for UpscaleModelLoader. + try: + with urlopen("{}/object_info/UpscaleModelLoader".format(self.base_url), timeout=8.0) as response: + data = json.loads(response.read().decode("utf-8")) + + node_info = data.get("UpscaleModelLoader", {}) + required = node_info.get("input", {}).get("required", {}) + model_input = required.get("model_name", None) + values = self._extract_combo_options(model_input) + if values: + models = [str(v) for v in values if str(v).strip()] + except Exception as ex: + log.debug("ComfyClient list_upscale_models object_info parse failed: %s", ex) + + # Fallback: direct model listing endpoint. + if not models: + try: + with urlopen("{}/models/upscale_models".format(self.base_url), timeout=8.0) as response: + data = json.loads(response.read().decode("utf-8")) + if isinstance(data, list): + models = [str(v) for v in data if str(v).strip()] + except Exception as ex: + log.debug("ComfyClient list_upscale_models /models fallback failed: %s", ex) + + # Dedupe while preserving order. 
+ seen = set() + ordered = [] + for name in models: + if name not in seen: + seen.add(name) + ordered.append(name) + return ordered + + @staticmethod + def _extract_combo_options(input_config): + """Extract valid options from Comfy object_info input config variants.""" + if input_config is None: + return [] + + # Variant: [ [options...], {meta...} ] + if isinstance(input_config, list) and input_config and isinstance(input_config[0], list): + return [str(v) for v in input_config[0]] + + # Variant: ["COMBO", {"options":[...], ...}] + if ( + isinstance(input_config, list) + and len(input_config) >= 2 + and str(input_config[0]).upper() == "COMBO" + and isinstance(input_config[1], dict) + ): + options = input_config[1].get("options", []) + if isinstance(options, list): + return [str(v) for v in options] + + # Variant: direct list of values + if isinstance(input_config, list): + scalar_values = [] + for item in input_config: + if isinstance(item, (str, int, float)): + scalar_values.append(str(item)) + return scalar_values + + return [] + def history(self, prompt_id): - with urlopen("{}/history/{}".format(self.base_url, quote(str(prompt_id))), timeout=3.0) as response: + with urlopen("{}/history/{}".format(self.base_url, quote(str(prompt_id))), timeout=10.0) as response: return json.loads(response.read().decode("utf-8")) def progress(self): """Return ComfyUI /progress payload.""" try: - with urlopen("{}/progress".format(self.base_url), timeout=3.0) as response: + with urlopen("{}/progress".format(self.base_url), timeout=8.0) as response: return json.loads(response.read().decode("utf-8")) except HTTPError as ex: if int(getattr(ex, "code", 0)) == 404: @@ -293,7 +469,7 @@ def interrupt(self, prompt_id=None): method="POST", headers={"Content-Type": "application/json"}, ) - with urlopen(req, timeout=3.0) as response: + with urlopen(req, timeout=8.0) as response: log.debug("ComfyClient interrupt response status=%s", int(response.status)) return int(response.status) >= 200 and 
int(response.status) < 300 @@ -307,15 +483,61 @@ def cancel_prompt(self, prompt_id): method="POST", headers={"Content-Type": "application/json"}, ) - with urlopen(req, timeout=3.0) as response: + with urlopen(req, timeout=8.0) as response: log.debug("ComfyClient cancel_prompt response status=%s", int(response.status)) return int(response.status) >= 200 and int(response.status) < 300 def queue(self): """Return ComfyUI queue state.""" - with urlopen("{}/queue".format(self.base_url), timeout=3.0) as response: + with urlopen("{}/queue".format(self.base_url), timeout=10.0) as response: return json.loads(response.read().decode("utf-8")) + def upload_input_file(self, local_path): + """Upload a local file into ComfyUI input dir via /upload/image.""" + local_path = str(local_path or "").strip() + if not local_path or not os.path.exists(local_path): + raise RuntimeError("Local file does not exist: {}".format(local_path)) + + boundary = "----OpenShotComfy{}".format(uuid.uuid4().hex) + filename = os.path.basename(local_path) + parts = [] + + def _add_field(name, value): + parts.append("--{}\r\n".format(boundary).encode("utf-8")) + parts.append('Content-Disposition: form-data; name="{}"\r\n\r\n'.format(name).encode("utf-8")) + parts.append(str(value).encode("utf-8")) + parts.append(b"\r\n") + + _add_field("type", "input") + parts.append("--{}\r\n".format(boundary).encode("utf-8")) + parts.append( + ( + 'Content-Disposition: form-data; name="image"; filename="{}"\r\n' + "Content-Type: application/octet-stream\r\n\r\n" + ).format(filename).encode("utf-8") + ) + with open(local_path, "rb") as handle: + parts.append(handle.read()) + parts.append(b"\r\n") + parts.append("--{}--\r\n".format(boundary).encode("utf-8")) + body = b"".join(parts) + + req = Request( + "{}/upload/image".format(self.base_url), + data=body, + method="POST", + headers={"Content-Type": "multipart/form-data; boundary={}".format(boundary)}, + ) + with urlopen(req, timeout=30.0) as response: + data = 
json.loads(response.read().decode("utf-8")) + + name = str(data.get("name", "")).strip() + subfolder = str(data.get("subfolder", "")).strip() + if not name: + raise RuntimeError("ComfyUI upload failed: invalid response") + rel = "{}/{}".format(subfolder, name) if subfolder else name + return "{} [input]".format(rel) + @staticmethod def prompt_in_queue(prompt_id, queue_data): """Check if prompt_id appears in queue_running/queue_pending payload.""" diff --git a/src/classes/comfy_pipelines.py b/src/classes/comfy_pipelines.py index d5aec84cb..cab132644 100644 --- a/src/classes/comfy_pipelines.py +++ b/src/classes/comfy_pipelines.py @@ -1,6 +1,28 @@ """ @file - @brief Basic built-in ComfyUI pipeline definitions. + @brief This file contains built-in ComfyUI pipeline definitions. + @author Jonathan Thomas + + @section LICENSE + + Copyright (c) 2008-2026 OpenShot Studios, LLC + (http://www.openshotstudios.com). This file is part of + OpenShot Video Editor (http://www.openshot.org), an open-source project + dedicated to delivering high quality video editing and animation solutions + to the world. + + OpenShot Video Editor is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + OpenShot Video Editor is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with OpenShot Library. If not, see . 
""" import random @@ -11,9 +33,16 @@ ".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff", ".gif", } +DEFAULT_SD_CHECKPOINT = "sd_xl_turbo_1.0_fp16.safetensors" +DEFAULT_UPSCALE_MODEL = "RealESRGAN_x4plus.safetensors" + def is_supported_img2img_path(path): - ext = os.path.splitext(str(path or ""))[1].lower() + path_text = str(path or "").strip() + # Comfy annotated paths can look like: "image.jpg [input]" + if path_text.endswith("]") and " [" in path_text: + path_text = path_text.rsplit(" [", 1)[0].strip() + ext = os.path.splitext(path_text)[1].lower() return ext in RASTER_IMAGE_EXTENSIONS @@ -26,19 +55,44 @@ def _supports_img2img(source_file=None): return is_supported_img2img_path(path) +def _supports_video_upscale(source_file=None): + if not source_file: + return False + return source_file.data.get("media_type") == "video" + + def available_pipelines(source_file=None): pipelines = [{"id": "txt2img-basic", "name": "Basic Text to Image"}] if _supports_img2img(source_file): pipelines.insert(0, {"id": "img2img-basic", "name": "Basic Image Variation"}) + pipelines.insert(1, {"id": "upscale-realesrgan-x4", "name": "Upscale Image (RealESRGAN x4)"}) + if _supports_video_upscale(source_file): + pipelines.append({"id": "video-upscale-gan", "name": "Upscale Video (GAN x4, first 10s)"}) return pipelines -def build_workflow(pipeline_id, prompt_text, source_path, output_prefix, checkpoint_name=None): +def pipeline_requires_checkpoint(pipeline_id): + return str(pipeline_id or "") in ("txt2img-basic", "img2img-basic") + + +def pipeline_requires_upscale_model(pipeline_id): + return str(pipeline_id or "") in ("upscale-realesrgan-x4", "video-upscale-gan") + + +def build_workflow( + pipeline_id, + prompt_text, + source_path, + output_prefix, + checkpoint_name=None, + upscale_model_name=None, +): prompt_text = str(prompt_text or "cinematic shot, highly detailed").strip() if not prompt_text: prompt_text = "cinematic shot, highly detailed" output_prefix = str(output_prefix or 
"openshot_gen").strip() or "openshot_gen" - checkpoint_name = str(checkpoint_name or "").strip() or "v1-5-pruned-emaonly.safetensors" + checkpoint_name = str(checkpoint_name or "").strip() or DEFAULT_SD_CHECKPOINT + upscale_model_name = str(upscale_model_name or "").strip() or DEFAULT_UPSCALE_MODEL seed = random.randint(1, 2**31 - 1) if pipeline_id == "img2img-basic": @@ -64,6 +118,36 @@ def build_workflow(pipeline_id, prompt_text, source_path, output_prefix, checkpo "8": {"inputs": {"filename_prefix": output_prefix, "images": ["7", 0]}, "class_type": "SaveImage"}, } + if pipeline_id == "upscale-realesrgan-x4": + if not is_supported_img2img_path(source_path): + raise ValueError( + "The selected file is not a supported raster image for this pipeline. " + "Use PNG/JPG/WebP/BMP/TIFF or switch to Text to Image." + ) + return { + "1": {"inputs": {"image": str(source_path or ""), "upload": "image"}, "class_type": "LoadImage"}, + "2": {"inputs": {"model_name": upscale_model_name}, "class_type": "UpscaleModelLoader"}, + "3": {"inputs": {"upscale_model": ["2", 0], "image": ["1", 0]}, "class_type": "ImageUpscaleWithModel"}, + "4": {"inputs": {"filename_prefix": output_prefix, "images": ["3", 0]}, "class_type": "SaveImage"}, + } + + if pipeline_id == "video-upscale-gan": + source_path = str(source_path or "").strip() + if not source_path: + raise ValueError("A source video is required for this pipeline.") + return { + "1": {"inputs": {"file": source_path}, "class_type": "LoadVideo"}, + "2": { + "inputs": {"video": ["1", 0], "start_time": 0.0, "duration": 10.0, "strict_duration": False}, + "class_type": "Video Slice", + }, + "3": {"inputs": {"video": ["2", 0]}, "class_type": "GetVideoComponents"}, + "4": {"inputs": {"model_name": upscale_model_name}, "class_type": "UpscaleModelLoader"}, + "5": {"inputs": {"upscale_model": ["4", 0], "image": ["3", 0]}, "class_type": "ImageUpscaleWithModel"}, + "6": {"inputs": {"images": ["5", 0], "audio": ["3", 1], "fps": ["3", 2]}, 
"class_type": "CreateVideo"}, + "7": {"inputs": {"video": ["6", 0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, + } + return { "1": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, "2": {"inputs": {"text": prompt_text, "clip": ["1", 1]}, "class_type": "CLIPTextEncode"}, diff --git a/src/classes/generation_queue.py b/src/classes/generation_queue.py index 4dfed87bb..dea2f3c9f 100644 --- a/src/classes/generation_queue.py +++ b/src/classes/generation_queue.py @@ -1,11 +1,28 @@ """ @file - @brief Lightweight in-memory generation queue for ComfyUI-backed jobs. + @brief This file contains a lightweight in-memory generation queue for ComfyUI jobs. + @author Jonathan Thomas + + @section LICENSE + + Copyright (c) 2008-2026 OpenShot Studios, LLC + (http://www.openshotstudios.com). This file is part of + OpenShot Video Editor (http://www.openshot.org), an open-source project + dedicated to delivering high quality video editing and animation solutions + to the world. OpenShot Video Editor is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. + + OpenShot Video Editor is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with OpenShot Library. If not, see . 
""" import uuid @@ -86,9 +103,14 @@ def _run_comfy_job(self, job_id, request): start_time = monotonic() last_in_queue_time = start_time + last_contact_time = start_time last_progress_log_time = 0.0 + last_network_error_log_time = 0.0 progress_endpoint_unavailable = False accepted_progress_started = False + ws_retry_delay_s = 2.0 + ws_next_retry_at = start_time + prompt_key = str(prompt_id) while True: if self._is_cancel_requested(job_id, cancel_event): @@ -165,9 +187,21 @@ def _run_comfy_job(self, job_id, request): ) return - history = client.history(prompt_id) or {} - prompt_key = str(prompt_id) - history_entry = history.get(prompt_key) or history.get(prompt_id) or None + history_entry = None + try: + history = client.history(prompt_id) or {} + history_entry = history.get(prompt_key) or history.get(prompt_id) or None + last_contact_time = monotonic() + except Exception: + now_log = monotonic() + if (now_log - last_network_error_log_time) > 8.0: + log.debug( + "Comfy history poll temporarily unavailable for job=%s prompt=%s; retrying", + job_id, + prompt_key, + exc_info=True, + ) + last_network_error_log_time = now_log if history_entry is not None: status_obj = history_entry.get("status", {}) if isinstance(history_entry, dict) else {} status_str = str(status_obj.get("status_str", "")).lower() @@ -187,8 +221,48 @@ def _run_comfy_job(self, job_id, request): # Query ComfyUI's live progress values when available. 
try: + now = monotonic() + if ws_client is None and now >= ws_next_retry_at: + try: + ws_client = ComfyClient.open_progress_socket(comfy_url, client_id) + ws_retry_delay_s = 2.0 + log.debug("Comfy progress websocket reconnected for prompt=%s", prompt_key) + except Exception: + ws_next_retry_at = now + ws_retry_delay_s + ws_retry_delay_s = min(60.0, ws_retry_delay_s * 1.5) + now_log = monotonic() + if (now_log - last_network_error_log_time) > 8.0: + log.debug( + "Comfy websocket reconnect failed for job=%s prompt=%s; retrying in %.1fs", + job_id, + prompt_key, + ws_retry_delay_s, + exc_info=True, + ) + last_network_error_log_time = now_log + if ws_client is not None: - progress_event = ws_client.poll_progress(prompt_id) + try: + progress_event = ws_client.poll_progress(prompt_id) + except Exception: + progress_event = None + try: + ws_client.close() + except Exception: + pass + ws_client = None + ws_next_retry_at = monotonic() + ws_retry_delay_s + ws_retry_delay_s = min(60.0, ws_retry_delay_s * 1.5) + now_log = monotonic() + if (now_log - last_network_error_log_time) > 8.0: + log.debug( + "Comfy websocket progress read failed for job=%s prompt=%s; switching to retry mode", + job_id, + prompt_key, + exc_info=True, + ) + last_network_error_log_time = now_log + if progress_event is not None: elapsed = monotonic() - start_time progress = int(progress_event.get("percent", 0)) @@ -237,8 +311,9 @@ def _run_comfy_job(self, job_id, request): progress, ) self.progress_changed.emit(job_id, progress) - else: - progress_data = client.progress() or {} + last_contact_time = monotonic() + if ws_client is None: + progress_data = client.progress() if progress_data is None: if not progress_endpoint_unavailable: log.debug( @@ -284,19 +359,27 @@ def _run_comfy_job(self, job_id, request): progress, ) self.progress_changed.emit(job_id, progress) + last_contact_time = monotonic() except Exception: # Keep polling history and queue even if /progress is unavailable. 
- if not progress_endpoint_unavailable: + now_log = monotonic() + if (now_log - last_network_error_log_time) > 8.0: log.debug("Comfy progress poll failed for job=%s", job_id, exc_info=True) + last_network_error_log_time = now_log # Check queue to avoid timing out long-running but active jobs. in_queue = False try: queue_data = client.queue() or {} in_queue = ComfyClient.prompt_in_queue(prompt_id, queue_data) + last_contact_time = monotonic() except Exception: # If queue check fails, do not penalize the job immediately. in_queue = True + now_log = monotonic() + if (now_log - last_network_error_log_time) > 8.0: + log.debug("Comfy queue check temporarily unavailable for job=%s", job_id, exc_info=True) + last_network_error_log_time = now_log if in_queue: last_in_queue_time = monotonic() else: @@ -314,6 +397,17 @@ def _run_comfy_job(self, job_id, request): self.job_finished.emit(job_id, False, False, "Timed out waiting for ComfyUI history result", []) return + if (now - last_contact_time) > 60.0: + now_log = monotonic() + if (now_log - last_network_error_log_time) > 8.0: + log.debug( + "Comfy connection degraded for job=%s prompt=%s (no successful API contact for %.1fs); continuing retries", + job_id, + prompt_key, + now - last_contact_time, + ) + last_network_error_log_time = now_log + # If prompt vanished from queue for an extended period and still no history, treat as failure. if (now - last_in_queue_time) > 600: self._job_prompts.pop(job_id, None) diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py new file mode 100644 index 000000000..b756327df --- /dev/null +++ b/src/classes/generation_service.py @@ -0,0 +1,374 @@ +""" + @file + @brief This file contains Comfy generation orchestration logic. + @author Jonathan Thomas + + @section LICENSE + + Copyright (c) 2008-2026 OpenShot Studios, LLC + (http://www.openshotstudios.com). 
This file is part of + OpenShot Video Editor (http://www.openshot.org), an open-source project + dedicated to delivering high quality video editing and animation solutions + to the world. + + OpenShot Video Editor is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + OpenShot Video Editor is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with OpenShot Library. If not, see . + """ + +import os +import re +import tempfile +from time import time + +import openshot +from PyQt5.QtWidgets import QMessageBox, QDialog + +from classes import info +from classes.app import get_app +from classes.comfy_client import ComfyClient +from classes.comfy_pipelines import ( + available_pipelines, + build_workflow, + is_supported_img2img_path, + pipeline_requires_checkpoint, + pipeline_requires_upscale_model, + DEFAULT_SD_CHECKPOINT, + DEFAULT_UPSCALE_MODEL, +) +from classes.logger import log +from classes.query import File +from windows.generate import GenerateMediaDialog + + +class GenerationService: + """Encapsulates generation-specific UI + workflow behavior.""" + + def __init__(self, win): + self.win = win + self._generation_temp_files = [] + self._comfy_status_cache = {"checked_at": 0.0, "available": False} + + def cleanup_temp_files(self): + for tmp_path in list(self._generation_temp_files): + try: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + except OSError: + pass + self._generation_temp_files = [] + + def comfy_ui_url(self): + url = get_app().get_settings().get("comfy-ui-url") or "http://127.0.0.1:8188" + return 
str(url).strip().rstrip("/") + + def is_comfy_available(self, force=False): + now = time() + if not force and (now - self._comfy_status_cache["checked_at"]) < 2.0: + return self._comfy_status_cache["available"] + + url = self.comfy_ui_url() + available = False + try: + available = ComfyClient(url).ping(timeout=0.5) + except Exception: + available = False + + self._comfy_status_cache["checked_at"] = now + self._comfy_status_cache["available"] = available + return available + + def can_open_generate_dialog(self): + return len(self.win.selected_file_ids()) <= 1 + + def _prepare_generation_source_path(self, source_file, template_id): + if not source_file: + return "" + + source_path = source_file.data.get("path", "") + media_type = source_file.data.get("media_type") + if template_id not in ("img2img-basic", "upscale-realesrgan-x4") or media_type != "image": + return source_path + + if is_supported_img2img_path(source_path): + return source_path + + tmp_fd, tmp_png = tempfile.mkstemp(prefix="openshot-comfy-", suffix=".png") + os.close(tmp_fd) + try: + clip = openshot.Clip(source_path) + frame = clip.Reader().GetFrame(1) + frame.Save(tmp_png, 1.0) + self._generation_temp_files.append(tmp_png) + return tmp_png + except Exception: + try: + os.remove(tmp_png) + except OSError: + pass + raise + + def _prepare_generation_video_input(self, source_file, client): + if not source_file: + raise ValueError("A source video is required.") + source_path = source_file.data.get("path", "") + if not source_path: + raise ValueError("Source video path is invalid.") + return client.upload_input_file(source_path) + + def _prepare_generation_image_input(self, local_image_path, client): + local_image_path = str(local_image_path or "").strip() + if not local_image_path: + raise ValueError("A source image is required.") + return client.upload_input_file(local_image_path) + + def action_generate_trigger(self, checked=True): + selected_files = self.win.selected_files() + if len(selected_files) > 
1: + return + + if not self.is_comfy_available(force=True): + msg = QMessageBox(self.win) + msg.setWindowTitle("ComfyUI Unavailable") + msg.setText( + "OpenShot could not connect to ComfyUI at:\n{}\n\n" + "Start ComfyUI or update the URL in Preferences > Experimental.".format(self.comfy_ui_url()) + ) + msg.exec_() + return + + source_file = selected_files[0] if selected_files else None + templates = available_pipelines(source_file=source_file) + win = GenerateMediaDialog(source_file=source_file, templates=templates, parent=self.win) + if win.exec_() != QDialog.Accepted: + return + + payload = win.get_payload() + payload_name = self._next_generation_name(payload.get("name")) + source_file_id = source_file.id if source_file else None + try: + source_path = self._prepare_generation_source_path(source_file, payload.get("template_id")) + except Exception as ex: + QMessageBox.warning( + self.win, + "Source Conversion Failed", + "OpenShot could not convert this image into PNG for ComfyUI.\n\n{}".format(ex), + ) + return + pipeline_id = payload.get("template_id") + checkpoint_name = None + upscale_model_name = None + client = ComfyClient(self.comfy_ui_url()) + workflow_source = source_path + + if pipeline_id == "video-upscale-gan": + if not source_file or source_file.data.get("media_type") != "video": + QMessageBox.information(self.win, "Invalid Input", "This pipeline requires a source video file.") + return + try: + workflow_source = self._prepare_generation_video_input(source_file, client) + except Exception as ex: + QMessageBox.warning( + self.win, + "Video Upload Failed", + "OpenShot could not upload the source video into ComfyUI input.\n\n{}".format(ex), + ) + return + elif pipeline_id in ("img2img-basic", "upscale-realesrgan-x4"): + try: + workflow_source = self._prepare_generation_image_input(source_path, client) + except Exception as ex: + QMessageBox.warning( + self.win, + "Image Upload Failed", + "OpenShot could not upload the source image into ComfyUI 
input.\n\n{}".format(ex), + ) + return + + try: + if pipeline_requires_checkpoint(pipeline_id): + checkpoint_names = client.list_checkpoints() + if checkpoint_names: + checkpoint_name = ( + DEFAULT_SD_CHECKPOINT if DEFAULT_SD_CHECKPOINT in checkpoint_names else checkpoint_names[0] + ) + except Exception as ex: + log.warning("Failed to query ComfyUI checkpoints: %s", ex) + + if pipeline_requires_checkpoint(pipeline_id) and not checkpoint_name: + QMessageBox.information( + self.win, + "No Checkpoints Found", + "ComfyUI has no checkpoints available for CheckpointLoaderSimple.\n" + "Add a model to ComfyUI/models/checkpoints and try again.", + ) + return + + try: + if pipeline_requires_upscale_model(pipeline_id): + upscale_models = client.list_upscale_models() + if upscale_models: + upscale_model_name = ( + DEFAULT_UPSCALE_MODEL if DEFAULT_UPSCALE_MODEL in upscale_models else upscale_models[0] + ) + except Exception as ex: + log.warning("Failed to query ComfyUI upscale models: %s", ex) + + if pipeline_requires_upscale_model(pipeline_id) and not upscale_model_name: + QMessageBox.information( + self.win, + "No Upscale Models Found", + "ComfyUI has no upscaler models available for UpscaleModelLoader.\n" + "Add a model such as RealESRGAN_x4plus.safetensors to ComfyUI/models/upscale_models and try again.", + ) + return + + try: + workflow = build_workflow( + pipeline_id, + payload.get("prompt"), + workflow_source, + payload_name, + checkpoint_name=checkpoint_name, + upscale_model_name=upscale_model_name, + ) + except Exception as ex: + QMessageBox.information(self.win, "Invalid Input", str(ex)) + return + request = { + "comfy_url": self.comfy_ui_url(), + "workflow": workflow, + "client_id": "openshot-qt", + "timeout_s": 21600, + "save_node_ids": [ + str(node_id) + for node_id, node in workflow.items() + if node.get("class_type") in ("SaveImage", "SaveVideo") + ], + } + job_id = self.win.generation_queue.enqueue( + payload_name, + payload.get("template_id"), + 
payload.get("prompt"), + source_file_id=source_file_id, + request=request, + ) + if not job_id: + QMessageBox.information( + self.win, + "Generation Already Active", + "Only one active generation is allowed per source file.", + ) + return + + self.win.statusBar.showMessage("Queued generation job", 3000) + + def on_generation_job_finished(self, job_id, status): + job = self.win.generation_queue.get_job(job_id) if getattr(self.win, "generation_queue", None) else None + if not job: + return + + if status == "completed": + imported = self._import_generation_outputs(job) + if imported > 0: + self.win.statusBar.showMessage("Generation completed and imported {} file(s)".format(imported), 5000) + else: + self.win.statusBar.showMessage("Generation completed (no output files found)", 5000) + return + + if status == "canceled": + self.win.statusBar.showMessage("Generation canceled", 3000) + return + + if status == "failed": + error_text = str(job.get("error") or "ComfyUI generation failed.") + self.win.statusBar.showMessage("Generation failed", 5000) + QMessageBox.warning(self.win, "Generation Failed", error_text) + + def _import_generation_outputs(self, job): + outputs = list(job.get("outputs", []) or []) + if not outputs: + return 0 + + request = job.get("request", {}) or {} + comfy_url = str(request.get("comfy_url") or self.comfy_ui_url()) + client = ComfyClient(comfy_url) + output_dir = os.path.join(info.USER_PATH, "comfy_outputs") + os.makedirs(output_dir, exist_ok=True) + + name_raw = str(job.get("name") or "generation") + safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", name_raw).strip("._") + if not safe_name: + safe_name = "generation" + + saved_paths = [] + for index, image_ref in enumerate(outputs, start=1): + original_name = str(image_ref.get("filename", "output.png")) + ext = os.path.splitext(original_name)[1] or ".png" + local_name = "{}_{}{}".format(safe_name, str(index).zfill(3), ext) + local_path = self._next_available_path(os.path.join(output_dir, local_name)) + 
try: + client.download_image(image_ref, local_path) + saved_paths.append(local_path) + except Exception as ex: + log.warning("Failed to download Comfy output %s: %s", image_ref, ex) + + if not saved_paths: + return 0 + + self.win.files_model.add_files( + saved_paths, + quiet=True, + prevent_image_seq=True, + prevent_recent_folder=True, + ) + return len(saved_paths) + + def _next_generation_name(self, requested_name): + base = re.sub(r"[^A-Za-z0-9._-]+", "_", str(requested_name or "").strip()).strip("._") + if not base: + base = "generation" + + existing_names = set() + for file_obj in File.filter(): + if not file_obj: + continue + display_name = str(file_obj.data.get("name") or os.path.basename(file_obj.data.get("path", "")) or "") + if display_name: + stem = os.path.splitext(display_name)[0] + existing_names.add(stem.lower()) + + if base.lower() not in existing_names: + return base + + name_root = base + m = re.match(r"^(.*?)(?:_gen(\d+))?$", base, re.IGNORECASE) + if m: + name_root = (m.group(1) or base).rstrip("_") or "generation" + n = 1 + while True: + candidate = "{}_gen{}".format(name_root, n) + if candidate.lower() not in existing_names: + return candidate + n += 1 + + def _next_available_path(self, path): + if not os.path.exists(path): + return path + folder = os.path.dirname(path) + stem, ext = os.path.splitext(os.path.basename(path)) + n = 2 + while True: + candidate = os.path.join(folder, "{}_{}{}".format(stem, n, ext)) + if not os.path.exists(candidate): + return candidate + n += 1 diff --git a/src/windows/generate_media.py b/src/windows/generate.py similarity index 87% rename from src/windows/generate_media.py rename to src/windows/generate.py index 5d27a6cc9..a4d6c8591 100644 --- a/src/windows/generate_media.py +++ b/src/windows/generate.py @@ -1,6 +1,28 @@ """ @file - @brief Simple Generate dialog for ComfyUI pipeline jobs. + @brief This file contains the Generate media dialog. 
+ @author Jonathan Thomas + + @section LICENSE + + Copyright (c) 2008-2026 OpenShot Studios, LLC + (http://www.openshotstudios.com). This file is part of + OpenShot Video Editor (http://www.openshot.org), an open-source project + dedicated to delivering high quality video editing and animation solutions + to the world. + + OpenShot Video Editor is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + OpenShot Video Editor is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with OpenShot Library. If not, see . """ import os diff --git a/src/windows/main_window.py b/src/windows/main_window.py index a06a77ff2..eae5bcd57 100644 --- a/src/windows/main_window.py +++ b/src/windows/main_window.py @@ -34,14 +34,11 @@ import shutil import uuid import webbrowser -import tempfile from time import sleep, time from datetime import datetime from uuid import uuid4 import zipfile import threading -from urllib.parse import urlparse -from urllib.request import urlopen import openshot # Python module for libopenshot (required video editing module installed separately) from PyQt5.QtCore import ( @@ -67,8 +64,7 @@ from classes.metrics import track_metric_session, track_metric_screen from classes.query import File, Clip, Transition, Marker, Track, Effect from classes.generation_queue import GenerationQueueManager -from classes.comfy_pipelines import available_pipelines, build_workflow, is_supported_img2img_path -from classes.comfy_client import ComfyClient +from classes.generation_service import GenerationService from classes.thumbnail import 
httpThumbnailServerThread, httpThumbnailException from classes.time_parts import secondsToTimecode from classes.timeline import TimelineSync @@ -92,7 +88,6 @@ from windows.views.transitions_listview import TransitionsListView from windows.views.transitions_treeview import TransitionsTreeView from windows.views.tutorial import TutorialManager -from windows.generate_media import GenerateMediaDialog class MainWindow(updates.UpdateWatcher, QMainWindow): @@ -221,12 +216,8 @@ def closeEvent(self, event): self.generation_queue.shutdown() # Cleanup temporary generation source files - for tmp_path in getattr(self, "_generation_temp_files", []): - try: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - except OSError: - pass + if getattr(self, "generation_service", None): + self.generation_service.cleanup_temp_files() # Stop ZMQ polling thread (if any) if app.logger_libopenshot: @@ -2000,61 +1991,13 @@ def actionSplitFile_trigger(self): log.info('Cutting Cancelled') def comfy_ui_url(self): - url = get_app().get_settings().get("comfy-ui-url") or "http://127.0.0.1:8188" - return str(url).strip().rstrip("/") - - def _prepare_generation_source_path(self, source_file, template_id): - """Return a source path suitable for the selected template.""" - if not source_file: - return "" - - source_path = source_file.data.get("path", "") - media_type = source_file.data.get("media_type") - if template_id != "img2img-basic" or media_type != "image": - return source_path - - if is_supported_img2img_path(source_path): - return source_path - - # Convert unsupported image formats (such as SVG) into a temporary PNG. 
- tmp_fd, tmp_png = tempfile.mkstemp(prefix="openshot-comfy-", suffix=".png") - os.close(tmp_fd) - try: - clip = openshot.Clip(source_path) - frame = clip.Reader().GetFrame(1) - frame.Save(tmp_png, 1.0) - self._generation_temp_files.append(tmp_png) - return tmp_png - except Exception: - try: - os.remove(tmp_png) - except OSError: - pass - raise + return self.generation_service.comfy_ui_url() def is_comfy_available(self, force=False): - now = time() - if not force and (now - self._comfy_status_cache["checked_at"]) < 2.0: - return self._comfy_status_cache["available"] - - url = self.comfy_ui_url() - parsed = urlparse(url) - available = False - if parsed.scheme in ("http", "https") and parsed.netloc: - try: - with urlopen("{}/system_stats".format(url), timeout=0.5) as response: - available = int(response.status) >= 200 and int(response.status) < 300 - except Exception: - available = False - - self._comfy_status_cache["checked_at"] = now - self._comfy_status_cache["available"] = available - return available + return self.generation_service.is_comfy_available(force=force) def can_open_generate_dialog(self): - # Keep action clickable for valid selection counts. - # Comfy availability is validated when the action is triggered. 
- return len(self.selected_file_ids()) <= 1 + return self.generation_service.can_open_generate_dialog() def active_generation_job_for_file(self, file_id): if not getattr(self, "generation_queue", None): @@ -2081,194 +2024,10 @@ def actionCancelGenerationJob_trigger(self, checked=True): self.cancel_generation_job(active_job.get("id")) def actionGenerate_trigger(self, checked=True): - selected_files = self.selected_files() - if len(selected_files) > 1: - return - - if not self.is_comfy_available(force=True): - msg = QMessageBox(self) - msg.setWindowTitle("ComfyUI Unavailable") - msg.setText( - "OpenShot could not connect to ComfyUI at:\n{}\n\n" - "Start ComfyUI or update the URL in Preferences > Experimental.".format(self.comfy_ui_url()) - ) - msg.exec_() - return - - source_file = selected_files[0] if selected_files else None - templates = available_pipelines(source_file=source_file) - win = GenerateMediaDialog(source_file=source_file, templates=templates, parent=self) - if win.exec_() != QDialog.Accepted: - return - - payload = win.get_payload() - payload_name = self._next_generation_name(payload.get("name")) - source_file_id = source_file.id if source_file else None - try: - source_path = self._prepare_generation_source_path(source_file, payload.get("template_id")) - except Exception as ex: - QMessageBox.warning( - self, - "Source Conversion Failed", - "OpenShot could not convert this image into PNG for ComfyUI.\n\n{}".format(ex), - ) - return - checkpoint_name = None - try: - checkpoint_names = ComfyClient(self.comfy_ui_url()).list_checkpoints() - if checkpoint_names: - checkpoint_name = checkpoint_names[0] - except Exception as ex: - log.warning("Failed to query ComfyUI checkpoints: %s", ex) - - if not checkpoint_name: - QMessageBox.information( - self, - "No Checkpoints Found", - "ComfyUI has no checkpoints available for CheckpointLoaderSimple.\n" - "Add a model to ComfyUI/models/checkpoints and try again.", - ) - return - - try: - workflow = build_workflow( - 
payload.get("template_id"), - payload.get("prompt"), - source_path, - payload_name, - checkpoint_name=checkpoint_name, - ) - except Exception as ex: - QMessageBox.information(self, "Invalid Input", str(ex)) - return - request = { - "comfy_url": self.comfy_ui_url(), - "workflow": workflow, - "client_id": "openshot-qt", - "timeout_s": 21600, - "save_node_ids": [str(node_id) for node_id, node in workflow.items() if node.get("class_type") == "SaveImage"], - } - job_id = self.generation_queue.enqueue( - payload_name, - payload.get("template_id"), - payload.get("prompt"), - source_file_id=source_file_id, - request=request, - ) - if not job_id: - QMessageBox.information( - self, - "Generation Already Active", - "Only one active generation is allowed per source file.", - ) - return - - self.statusBar.showMessage("Queued generation job", 3000) + self.generation_service.action_generate_trigger(checked=checked) def _on_generation_job_finished(self, job_id, status): - job = self.generation_queue.get_job(job_id) if getattr(self, "generation_queue", None) else None - if not job: - return - - if status == "completed": - imported = self._import_generation_outputs(job) - if imported > 0: - self.statusBar.showMessage("Generation completed and imported {} file(s)".format(imported), 5000) - else: - self.statusBar.showMessage("Generation completed (no output files found)", 5000) - return - - if status == "canceled": - self.statusBar.showMessage("Generation canceled", 3000) - return - - if status == "failed": - error_text = str(job.get("error") or "ComfyUI generation failed.") - self.statusBar.showMessage("Generation failed", 5000) - QMessageBox.warning(self, "Generation Failed", error_text) - return - - def _import_generation_outputs(self, job): - outputs = list(job.get("outputs", []) or []) - if not outputs: - return 0 - - request = job.get("request", {}) or {} - comfy_url = str(request.get("comfy_url") or self.comfy_ui_url()) - client = ComfyClient(comfy_url) - output_dir = 
os.path.join(info.USER_PATH, "comfy_outputs") - os.makedirs(output_dir, exist_ok=True) - - name_raw = str(job.get("name") or "generation") - safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", name_raw).strip("._") - if not safe_name: - safe_name = "generation" - - saved_paths = [] - for index, image_ref in enumerate(outputs, start=1): - original_name = str(image_ref.get("filename", "output.png")) - ext = os.path.splitext(original_name)[1] or ".png" - local_name = "{}_{}{}".format(safe_name, str(index).zfill(3), ext) - local_path = self._next_available_path(os.path.join(output_dir, local_name)) - try: - client.download_image(image_ref, local_path) - saved_paths.append(local_path) - except Exception as ex: - log.warning("Failed to download Comfy output %s: %s", image_ref, ex) - - if not saved_paths: - return 0 - - self.files_model.add_files( - saved_paths, - quiet=True, - prevent_image_seq=True, - prevent_recent_folder=True, - ) - return len(saved_paths) - - def _next_generation_name(self, requested_name): - """Return a unique generation-friendly name for project files/jobs.""" - base = re.sub(r"[^A-Za-z0-9._-]+", "_", str(requested_name or "").strip()).strip("._") - if not base: - base = "generation" - - existing_names = set() - for file_obj in File.filter(): - if not file_obj: - continue - display_name = str(file_obj.data.get("name") or os.path.basename(file_obj.data.get("path", "")) or "") - if display_name: - stem = os.path.splitext(display_name)[0] - existing_names.add(stem.lower()) - - if base.lower() not in existing_names: - return base - - # If the requested name already exists, append/increment _genN. 
- name_root = base - m = re.match(r"^(.*?)(?:_gen(\d+))?$", base, re.IGNORECASE) - if m: - name_root = (m.group(1) or base).rstrip("_") or "generation" - n = 1 - while True: - candidate = "{}_gen{}".format(name_root, n) - if candidate.lower() not in existing_names: - return candidate - n += 1 - - def _next_available_path(self, path): - """Return a non-colliding file path by appending _N when needed.""" - if not os.path.exists(path): - return path - folder = os.path.dirname(path) - stem, ext = os.path.splitext(os.path.basename(path)) - n = 2 - while True: - candidate = os.path.join(folder, "{}_{}{}".format(stem, n, ext)) - if not os.path.exists(candidate): - return candidate - n += 1 + self.generation_service.on_generation_job_finished(job_id, status) def actionRemove_from_Project_trigger(self): log.debug("actionRemove_from_Project_trigger") @@ -3740,6 +3499,8 @@ def _init_generation_actions(self): self.actionGenerate.setObjectName("actionGenerate") sparkle_icon_path = os.path.join(info.PATH, "themes", "cosmic", "images", "tool-generate-sparkle.svg") self.actionGenerate.setIcon(QIcon(sparkle_icon_path)) + self.actionGenerate.setShortcut(QKeySequence("Ctrl+G")) + self.actionGenerate.setShortcutContext(Qt.ApplicationShortcut) self.actionGenerate.triggered.connect(self.actionGenerate_trigger) self.actionCancelGenerationJob = QAction("Cancel Job", self) @@ -4315,7 +4076,6 @@ def __init__(self, *args): # Load UI from designer self.selected_items = [] - self._generation_temp_files = [] ui_util.load_ui(self, self.ui_path) # Init UI @@ -4323,7 +4083,7 @@ def __init__(self, *args): # Create dock toolbars, set initial state of items, etc self.setup_toolbars() - self._comfy_status_cache = {"checked_at": 0.0, "available": False} + self.generation_service = GenerationService(self) self.generation_queue = GenerationQueueManager(self) self.generation_queue.job_finished.connect(self._on_generation_job_finished) self._init_generation_actions() diff --git 
a/src/windows/views/files_listview.py b/src/windows/views/files_listview.py index e6ab76bfd..c345b98e0 100644 --- a/src/windows/views/files_listview.py +++ b/src/windows/views/files_listview.py @@ -26,6 +26,7 @@ along with OpenShot Library. If not, see . """ +import os import uuid from PyQt5.QtCore import QSize, Qt, QPoint, QRegExp @@ -139,8 +140,6 @@ def contextMenuEvent(self, event): menu = StyledContextMenu(parent=self) menu.addAction(self.win.actionImportFiles) - self.win.actionGenerate.setEnabled(self.win.can_open_generate_dialog()) - menu.addAction(self.win.actionGenerate) active_job = None file_id = None @@ -157,8 +156,16 @@ def contextMenuEvent(self, event): active_job = None else: active_job = self.win.active_generation_job_for_file(file_id) + if not active_job: + self.win.actionGenerate.setEnabled(self.win.can_open_generate_dialog()) + menu.addAction(self.win.actionGenerate) if active_job: cancel_action = menu.addAction(_("Cancel Job")) + delete_icon_path = os.path.join(info.PATH, "themes", "cosmic", "images", "track-delete-enabled.svg") + if os.path.exists(delete_icon_path): + cancel_action.setIcon(QIcon(delete_icon_path)) + else: + cancel_action.setIcon(self.win.actionRemove_from_Project.icon()) cancel_action.triggered.connect( lambda checked=False, job_id=active_job.get("id"): self.win.cancel_generation_job(job_id) ) diff --git a/src/windows/views/files_treeview.py b/src/windows/views/files_treeview.py index 7d0e6231a..2dd51cecf 100644 --- a/src/windows/views/files_treeview.py +++ b/src/windows/views/files_treeview.py @@ -136,8 +136,6 @@ def contextMenuEvent(self, event): menu = StyledContextMenu(parent=self) menu.addAction(self.win.actionImportFiles) - self.win.actionGenerate.setEnabled(self.win.can_open_generate_dialog()) - menu.addAction(self.win.actionGenerate) active_job = None file_id = None @@ -152,8 +150,16 @@ def contextMenuEvent(self, event): active_job = None else: active_job = self.win.active_generation_job_for_file(file_id) + if not 
active_job: + self.win.actionGenerate.setEnabled(self.win.can_open_generate_dialog()) + menu.addAction(self.win.actionGenerate) if active_job: cancel_action = menu.addAction(_("Cancel Job")) + delete_icon_path = os.path.join(info.PATH, "themes", "cosmic", "images", "track-delete-enabled.svg") + if os.path.exists(delete_icon_path): + cancel_action.setIcon(QIcon(delete_icon_path)) + else: + cancel_action.setIcon(self.win.actionRemove_from_Project.icon()) cancel_action.triggered.connect( lambda checked=False, job_id=active_job.get("id"): self.win.cancel_generation_job(job_id) ) From 559aa8cd27315a25ed1b510a94252630c0d36541 Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Thu, 12 Feb 2026 20:42:29 -0600 Subject: [PATCH 03/27] Add Comfy audio/video generation pipelines and generalized output import MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - add txt2audio-stable-open, txt2video-svd, and video2video-basic - support SaveAudio plus image/video/audio output download/import - add model default selection for installed Comfy checkpoints/encoders - fix Stable Audio EmptyLatentAudio.batch_size validation - align txt→img→video flow with Comfy example and tune low-VRAM test defaults --- src/classes/comfy_client.py | 66 ++++++++++---- src/classes/comfy_pipelines.py | 142 +++++++++++++++++++++++++++++- src/classes/generation_queue.py | 2 +- src/classes/generation_service.py | 64 ++++++++++++-- 4 files changed, 246 insertions(+), 28 deletions(-) diff --git a/src/classes/comfy_client.py b/src/classes/comfy_client.py index 93b38d18a..89411ec59 100644 --- a/src/classes/comfy_client.py +++ b/src/classes/comfy_client.py @@ -412,6 +412,28 @@ def list_upscale_models(self): ordered.append(name) return ordered + def list_clip_models(self): + """Return available CLIP/text-encoder model names from ComfyUI object info.""" + with urlopen("{}/object_info/CLIPLoader".format(self.base_url), timeout=8.0) as response: + data = 
json.loads(response.read().decode("utf-8")) + + node_info = data.get("CLIPLoader", {}) + required = node_info.get("input", {}).get("required", {}) + clip_input = required.get("clip_name", None) + values = self._extract_combo_options(clip_input) + return [str(v) for v in values if str(v).strip()] + + def list_clip_vision_models(self): + """Return available CLIP vision model names from ComfyUI object info.""" + with urlopen("{}/object_info/CLIPVisionLoader".format(self.base_url), timeout=8.0) as response: + data = json.loads(response.read().decode("utf-8")) + + node_info = data.get("CLIPVisionLoader", {}) + required = node_info.get("input", {}).get("required", {}) + clip_input = required.get("clip_name", None) + values = self._extract_combo_options(clip_input) + return [str(v) for v in values if str(v).strip()] + @staticmethod def _extract_combo_options(input_config): """Extract valid options from Comfy object_info input config variants.""" @@ -560,8 +582,8 @@ def prompt_in_queue(prompt_id, queue_data): return False @staticmethod - def extract_image_outputs(history_entry, save_node_ids=None): - """Return a flat list of image refs from a history entry.""" + def extract_file_outputs(history_entry, save_node_ids=None): + """Return a flat list of file refs from image/video/audio history outputs.""" outputs = [] if not isinstance(history_entry, dict): return outputs @@ -575,26 +597,31 @@ def extract_image_outputs(history_entry, save_node_ids=None): continue if not isinstance(node_out, dict): continue - images = node_out.get("images", []) - if not isinstance(images, list): - continue - for img in images: - if not isinstance(img, dict): + for key in ("images", "videos", "video", "audios", "audio", "files"): + refs = node_out.get(key, []) + if not isinstance(refs, list): continue - if img.get("filename"): - outputs.append({ - "filename": str(img.get("filename")), - "subfolder": str(img.get("subfolder", "")), - "type": str(img.get("type", "output")), - }) + for ref in refs: + 
if not isinstance(ref, dict): + continue + if ref.get("filename"): + outputs.append({ + "filename": str(ref.get("filename")), + "subfolder": str(ref.get("subfolder", "")), + "type": str(ref.get("type", "output")), + }) return outputs - def download_image(self, image_ref, destination_path): - """Download a Comfy image reference to a local file path.""" + @staticmethod + def extract_image_outputs(history_entry, save_node_ids=None): + return ComfyClient.extract_file_outputs(history_entry, save_node_ids=save_node_ids) + + def download_output_file(self, file_ref, destination_path): + """Download a Comfy output reference to a local file path.""" params = { - "filename": image_ref.get("filename", ""), - "subfolder": image_ref.get("subfolder", ""), - "type": image_ref.get("type", "output"), + "filename": file_ref.get("filename", ""), + "subfolder": file_ref.get("subfolder", ""), + "type": file_ref.get("type", "output"), } url = "{}/view?{}".format(self.base_url, urlencode(params)) with urlopen(url, timeout=10.0) as response: @@ -603,3 +630,6 @@ def download_image(self, image_ref, destination_path): os.makedirs(os.path.dirname(destination_path), exist_ok=True) with open(destination_path, "wb") as handle: handle.write(data) + + def download_image(self, image_ref, destination_path): + self.download_output_file(image_ref, destination_path) diff --git a/src/classes/comfy_pipelines.py b/src/classes/comfy_pipelines.py index cab132644..39638ed89 100644 --- a/src/classes/comfy_pipelines.py +++ b/src/classes/comfy_pipelines.py @@ -34,7 +34,11 @@ } DEFAULT_SD_CHECKPOINT = "sd_xl_turbo_1.0_fp16.safetensors" +DEFAULT_SD_BASE_CHECKPOINT = "sd_xl_base_1.0.safetensors" DEFAULT_UPSCALE_MODEL = "RealESRGAN_x4plus.safetensors" +DEFAULT_STABLE_AUDIO_CHECKPOINT = "stable-audio-open-1.0.safetensors" +DEFAULT_STABLE_AUDIO_CLIP = "t5-base.safetensors" +DEFAULT_SVD_CHECKPOINT = "svd_xt.safetensors" def is_supported_img2img_path(path): @@ -62,23 +66,42 @@ def 
_supports_video_upscale(source_file=None): def available_pipelines(source_file=None): - pipelines = [{"id": "txt2img-basic", "name": "Basic Text to Image"}] + pipelines = [ + {"id": "txt2img-basic", "name": "Basic Text to Image"}, + {"id": "txt2video-svd", "name": "Text to Video (txt_to_image_to_video)"}, + {"id": "txt2audio-stable-open", "name": "Text to Audio (Stable Audio Open)"}, + ] if _supports_img2img(source_file): pipelines.insert(0, {"id": "img2img-basic", "name": "Basic Image Variation"}) pipelines.insert(1, {"id": "upscale-realesrgan-x4", "name": "Upscale Image (RealESRGAN x4)"}) if _supports_video_upscale(source_file): pipelines.append({"id": "video-upscale-gan", "name": "Upscale Video (GAN x4, first 10s)"}) + pipelines.append({"id": "video2video-basic", "name": "Video + Text to Video (Style Transfer)"}) return pipelines def pipeline_requires_checkpoint(pipeline_id): - return str(pipeline_id or "") in ("txt2img-basic", "img2img-basic") + return str(pipeline_id or "") in ( + "txt2img-basic", + "img2img-basic", + "txt2audio-stable-open", + "txt2video-svd", + "video2video-basic", + ) def pipeline_requires_upscale_model(pipeline_id): return str(pipeline_id or "") in ("upscale-realesrgan-x4", "video-upscale-gan") +def pipeline_requires_stable_audio_clip(pipeline_id): + return str(pipeline_id or "") in ("txt2audio-stable-open",) + + +def pipeline_requires_svd_checkpoint(pipeline_id): + return str(pipeline_id or "") in ("txt2video-svd",) + + def build_workflow( pipeline_id, prompt_text, @@ -86,6 +109,8 @@ def build_workflow( output_prefix, checkpoint_name=None, upscale_model_name=None, + stable_audio_clip_name=None, + svd_checkpoint_name=None, ): prompt_text = str(prompt_text or "cinematic shot, highly detailed").strip() if not prompt_text: @@ -93,6 +118,8 @@ def build_workflow( output_prefix = str(output_prefix or "openshot_gen").strip() or "openshot_gen" checkpoint_name = str(checkpoint_name or "").strip() or DEFAULT_SD_CHECKPOINT upscale_model_name = 
str(upscale_model_name or "").strip() or DEFAULT_UPSCALE_MODEL + stable_audio_clip_name = str(stable_audio_clip_name or "").strip() or DEFAULT_STABLE_AUDIO_CLIP + svd_checkpoint_name = str(svd_checkpoint_name or "").strip() or DEFAULT_SVD_CHECKPOINT seed = random.randint(1, 2**31 - 1) if pipeline_id == "img2img-basic": @@ -148,6 +175,117 @@ def build_workflow( "7": {"inputs": {"video": ["6", 0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, } + if pipeline_id == "txt2audio-stable-open": + return { + "3": { + "inputs": { + "seed": seed, + "steps": 50, + "cfg": 5.0, + "sampler_name": "dpmpp_3m_sde_gpu", + "scheduler": "exponential", + "denoise": 1.0, + "model": ["4", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["11", 0], + }, + "class_type": "KSampler", + }, + "4": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, + "6": {"inputs": {"text": prompt_text, "clip": ["10", 0]}, "class_type": "CLIPTextEncode"}, + "7": {"inputs": {"text": "", "clip": ["10", 0]}, "class_type": "CLIPTextEncode"}, + "10": {"inputs": {"clip_name": stable_audio_clip_name, "type": "stable_audio"}, "class_type": "CLIPLoader"}, + "11": {"inputs": {"seconds": 30.0, "batch_size": 1}, "class_type": "EmptyLatentAudio"}, + "12": {"inputs": {"samples": ["3", 0], "vae": ["4", 2]}, "class_type": "VAEDecodeAudio"}, + "13": {"inputs": {"filename_prefix": "audio/{}".format(output_prefix), "audio": ["12", 0]}, "class_type": "SaveAudio"}, + } + + if pipeline_id == "txt2video-svd": + return { + "1": {"inputs": {"ckpt_name": svd_checkpoint_name}, "class_type": "ImageOnlyCheckpointLoader"}, + "2": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, + "3": {"inputs": {"text": prompt_text, "clip": ["2", 1]}, "class_type": "CLIPTextEncode"}, + "4": {"inputs": {"text": "text, watermark", "clip": ["2", 1]}, "class_type": "CLIPTextEncode"}, + "5": {"inputs": 
{"width": 512, "height": 288, "batch_size": 1}, "class_type": "EmptyLatentImage"}, + "6": { + "inputs": { + "seed": seed, + "steps": 15, + "cfg": 8.0, + "sampler_name": "uni_pc_bh2", + "scheduler": "normal", + "denoise": 1.0, + "model": ["2", 0], + "positive": ["3", 0], + "negative": ["4", 0], + "latent_image": ["5", 0], + }, + "class_type": "KSampler", + }, + "7": {"inputs": {"samples": ["6", 0], "vae": ["2", 2]}, "class_type": "VAEDecode"}, + "8": { + "inputs": { + "clip_vision": ["1", 1], + "init_image": ["7", 0], + "vae": ["1", 2], + "width": 512, + "height": 288, + "video_frames": 24, + "motion_bucket_id": 127, + "fps": 24, + "augmentation_level": 0.0, + }, + "class_type": "SVD_img2vid_Conditioning", + }, + "9": {"inputs": {"model": ["1", 0], "min_cfg": 1.0}, "class_type": "VideoLinearCFGGuidance"}, + "10": { + "inputs": { + "seed": seed + 1, + "steps": 20, + "cfg": 2.5, + "sampler_name": "euler", + "scheduler": "karras", + "denoise": 1.0, + "model": ["9", 0], + "positive": ["8", 0], + "negative": ["8", 1], + "latent_image": ["8", 2], + }, + "class_type": "KSampler", + }, + "11": {"inputs": {"samples": ["10", 0], "vae": ["1", 2]}, "class_type": "VAEDecode"}, + "12": {"inputs": {"images": ["11", 0], "fps": 24}, "class_type": "CreateVideo"}, + "13": {"inputs": {"video": ["12", 0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, + } + + if pipeline_id == "video2video-basic": + source_path = str(source_path or "").strip() + if not source_path: + raise ValueError("A source video is required for this pipeline.") + return { + "1": {"inputs": {"file": source_path}, "class_type": "LoadVideo"}, + "2": { + "inputs": {"video": ["1", 0], "start_time": 0.0, "duration": 10.0, "strict_duration": False}, + "class_type": "Video Slice", + }, + "3": {"inputs": {"video": ["2", 0]}, "class_type": "GetVideoComponents"}, + "4": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, + 
"5": {"inputs": {"text": prompt_text, "clip": ["4", 1]}, "class_type": "CLIPTextEncode"}, + "6": {"inputs": {"text": "low quality, blurry", "clip": ["4", 1]}, "class_type": "CLIPTextEncode"}, + "7": {"inputs": {"pixels": ["3", 0], "vae": ["4", 2]}, "class_type": "VAEEncode"}, + "8": { + "inputs": { + "seed": seed, "steps": 16, "cfg": 6.0, "sampler_name": "euler", "scheduler": "normal", + "denoise": 0.55, "model": ["4", 0], "positive": ["5", 0], "negative": ["6", 0], "latent_image": ["7", 0], + }, + "class_type": "KSampler", + }, + "9": {"inputs": {"samples": ["8", 0], "vae": ["4", 2]}, "class_type": "VAEDecode"}, + "10": {"inputs": {"images": ["9", 0], "audio": ["3", 1], "fps": ["3", 2]}, "class_type": "CreateVideo"}, + "11": {"inputs": {"video": ["10", 0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, + } + return { "1": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, "2": {"inputs": {"text": prompt_text, "clip": ["1", 1]}, "class_type": "CLIPTextEncode"}, diff --git a/src/classes/generation_queue.py b/src/classes/generation_queue.py index dea2f3c9f..12aeceb5d 100644 --- a/src/classes/generation_queue.py +++ b/src/classes/generation_queue.py @@ -213,7 +213,7 @@ def _run_comfy_job(self, job_id, request): self._job_prompts.pop(job_id, None) self.job_finished.emit(job_id, False, False, error_text, []) return - image_outputs = ComfyClient.extract_image_outputs(history_entry, save_node_ids=save_node_ids) + image_outputs = ComfyClient.extract_file_outputs(history_entry, save_node_ids=save_node_ids) self.progress_changed.emit(job_id, 100) self._job_prompts.pop(job_id, None) self.job_finished.emit(job_id, True, False, "", image_outputs) diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index b756327df..c91e61284 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -41,8 +41,14 @@ build_workflow, 
is_supported_img2img_path, pipeline_requires_checkpoint, + pipeline_requires_svd_checkpoint, + pipeline_requires_stable_audio_clip, pipeline_requires_upscale_model, DEFAULT_SD_CHECKPOINT, + DEFAULT_SD_BASE_CHECKPOINT, + DEFAULT_STABLE_AUDIO_CHECKPOINT, + DEFAULT_STABLE_AUDIO_CLIP, + DEFAULT_SVD_CHECKPOINT, DEFAULT_UPSCALE_MODEL, ) from classes.logger import log @@ -167,10 +173,12 @@ def action_generate_trigger(self, checked=True): pipeline_id = payload.get("template_id") checkpoint_name = None upscale_model_name = None + stable_audio_clip_name = None + svd_checkpoint_name = None client = ComfyClient(self.comfy_ui_url()) workflow_source = source_path - if pipeline_id == "video-upscale-gan": + if pipeline_id in ("video-upscale-gan", "video2video-basic"): if not source_file or source_file.data.get("media_type") != "video": QMessageBox.information(self.win, "Invalid Input", "This pipeline requires a source video file.") return @@ -198,8 +206,17 @@ def action_generate_trigger(self, checked=True): if pipeline_requires_checkpoint(pipeline_id): checkpoint_names = client.list_checkpoints() if checkpoint_names: + preferred_checkpoint = DEFAULT_SD_CHECKPOINT + if pipeline_id == "txt2audio-stable-open": + preferred_checkpoint = DEFAULT_STABLE_AUDIO_CHECKPOINT + elif pipeline_id in ("txt2video-svd", "video2video-basic"): + preferred_checkpoint = DEFAULT_SD_BASE_CHECKPOINT checkpoint_name = ( - DEFAULT_SD_CHECKPOINT if DEFAULT_SD_CHECKPOINT in checkpoint_names else checkpoint_names[0] + preferred_checkpoint if preferred_checkpoint in checkpoint_names else checkpoint_names[0] + ) + if pipeline_requires_svd_checkpoint(pipeline_id): + svd_checkpoint_name = ( + DEFAULT_SVD_CHECKPOINT if DEFAULT_SVD_CHECKPOINT in checkpoint_names else None ) except Exception as ex: log.warning("Failed to query ComfyUI checkpoints: %s", ex) @@ -213,6 +230,15 @@ def action_generate_trigger(self, checked=True): ) return + if pipeline_requires_svd_checkpoint(pipeline_id) and not svd_checkpoint_name: + 
QMessageBox.information( + self.win, + "No SVD Checkpoint Found", + "ComfyUI could not find the SVD checkpoint required for txt_to_image_to_video.\n" + "Add {} to ComfyUI/models/checkpoints and try again.".format(DEFAULT_SVD_CHECKPOINT), + ) + return + try: if pipeline_requires_upscale_model(pipeline_id): upscale_models = client.list_upscale_models() @@ -232,6 +258,28 @@ def action_generate_trigger(self, checked=True): ) return + try: + if pipeline_requires_stable_audio_clip(pipeline_id): + clip_names = client.list_clip_models() + if clip_names: + for preferred in (DEFAULT_STABLE_AUDIO_CLIP, "t5_base.safetensors"): + if preferred in clip_names: + stable_audio_clip_name = preferred + break + if not stable_audio_clip_name: + stable_audio_clip_name = clip_names[0] + except Exception as ex: + log.warning("Failed to query ComfyUI CLIP models: %s", ex) + + if pipeline_requires_stable_audio_clip(pipeline_id) and not stable_audio_clip_name: + QMessageBox.information( + self.win, + "No Text Encoders Found", + "ComfyUI has no CLIP/text-encoder models available for CLIPLoader.\n" + "Add a text encoder such as t5-base.safetensors and try again.", + ) + return + try: workflow = build_workflow( pipeline_id, @@ -240,6 +288,8 @@ def action_generate_trigger(self, checked=True): payload_name, checkpoint_name=checkpoint_name, upscale_model_name=upscale_model_name, + stable_audio_clip_name=stable_audio_clip_name, + svd_checkpoint_name=svd_checkpoint_name, ) except Exception as ex: QMessageBox.information(self.win, "Invalid Input", str(ex)) @@ -252,7 +302,7 @@ def action_generate_trigger(self, checked=True): "save_node_ids": [ str(node_id) for node_id, node in workflow.items() - if node.get("class_type") in ("SaveImage", "SaveVideo") + if node.get("class_type") in ("SaveImage", "SaveVideo", "SaveAudio") ], } job_id = self.win.generation_queue.enqueue( @@ -311,16 +361,16 @@ def _import_generation_outputs(self, job): safe_name = "generation" saved_paths = [] - for index, image_ref in 
enumerate(outputs, start=1): - original_name = str(image_ref.get("filename", "output.png")) + for index, output_ref in enumerate(outputs, start=1): + original_name = str(output_ref.get("filename", "output.png")) ext = os.path.splitext(original_name)[1] or ".png" local_name = "{}_{}{}".format(safe_name, str(index).zfill(3), ext) local_path = self._next_available_path(os.path.join(output_dir, local_name)) try: - client.download_image(image_ref, local_path) + client.download_output_file(output_ref, local_path) saved_paths.append(local_path) except Exception as ex: - log.warning("Failed to download Comfy output %s: %s", image_ref, ex) + log.warning("Failed to download Comfy output %s: %s", output_ref, ex) if not saved_paths: return 0 From 263dd13019cc58a03897085678b312f0eb78716a Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Sat, 14 Feb 2026 14:49:35 -0600 Subject: [PATCH 04/27] Adding support for Whisper/SRT subtitles to the Generate dialog. Also, auto Caption effect when a file has caption data already. 
--- src/classes/clip_utils.py | 36 ++++++++ src/classes/comfy_client.py | 37 ++++++++ src/classes/comfy_pipelines.py | 45 ++++++++++ src/classes/generation_service.py | 143 ++++++++++++++++++++++++++++-- src/windows/add_to_timeline.py | 4 + src/windows/views/timeline.py | 5 +- 6 files changed, 262 insertions(+), 8 deletions(-) diff --git a/src/classes/clip_utils.py b/src/classes/clip_utils.py index 5c2fd3233..12a3dfea3 100644 --- a/src/classes/clip_utils.py +++ b/src/classes/clip_utils.py @@ -25,14 +25,50 @@ """ import logging +import json from fractions import Fraction from typing import Any, Mapping, Optional, Tuple +import openshot + from classes.app import get_app logger = logging.getLogger(__name__) +def apply_file_caption_to_clip(clip_data: Any, file_obj: Any, *, dedupe: bool = True) -> bool: + """Attach a Caption effect to clip_data when file metadata includes caption text.""" + if not isinstance(clip_data, Mapping): + return False + file_data = getattr(file_obj, "data", None) + if not isinstance(file_data, Mapping): + return False + caption_text = str(file_data.get("caption", "") or "").strip() + if not caption_text: + return False + + effects = clip_data.get("effects") + if not isinstance(effects, list): + effects = list(effects) if effects else [] + clip_data["effects"] = effects + + if dedupe: + for effect in effects: + if not isinstance(effect, Mapping): + continue + if str(effect.get("class_name", "")).lower() == "caption": + existing_text = str(effect.get("caption_text", "") or "").strip() + if existing_text == caption_text: + return False + + caption_effect = openshot.EffectInfo().CreateEffect("Caption") + caption_effect.Id(get_app().project.generate_id()) + caption_json = json.loads(caption_effect.Json()) + caption_json["caption_text"] = caption_text + effects.append(caption_json) + return True + + def _as_mapping(candidate: Any) -> Mapping[str, Any]: """Return dict-style metadata for clips, readers, or similar.""" if isinstance(candidate, Mapping): 
diff --git a/src/classes/comfy_client.py b/src/classes/comfy_client.py index 89411ec59..ff6929dff 100644 --- a/src/classes/comfy_client.py +++ b/src/classes/comfy_client.py @@ -30,6 +30,7 @@ import ssl import base64 import uuid +import re import socket import struct from urllib.error import HTTPError @@ -610,12 +611,48 @@ def extract_file_outputs(history_entry, save_node_ids=None): "subfolder": str(ref.get("subfolder", "")), "type": str(ref.get("type", "output")), }) + # Also extract text-like outputs (for custom nodes such as Whisper/SRT pipelines). + for value in node_out.values(): + text_value = ComfyClient._extract_text_output(value) + if not text_value: + continue + output_format = "srt" if ComfyClient._looks_like_srt(text_value) else "txt" + outputs.append({ + "text": text_value, + "format": output_format, + "type": "text", + }) return outputs @staticmethod def extract_image_outputs(history_entry, save_node_ids=None): return ComfyClient.extract_file_outputs(history_entry, save_node_ids=save_node_ids) + @staticmethod + def _extract_text_output(value): + """Extract text payloads from common Comfy output structures.""" + if isinstance(value, str): + text = value.strip() + return text if text else "" + if isinstance(value, list): + if len(value) == 1 and isinstance(value[0], str): + text = value[0].strip() + return text if text else "" + return "" + if isinstance(value, dict): + for key in ("srt", "text", "value"): + text = value.get(key) + if isinstance(text, str) and text.strip(): + return text.strip() + return "" + + @staticmethod + def _looks_like_srt(text): + text = str(text or "") + if "-->" not in text: + return False + return bool(re.search(r"\d{2}:\d{2}:\d{2}[,.:]\d{3}\s+-->\s+\d{2}:\d{2}:\d{2}[,.:]\d{3}", text)) + def download_output_file(self, file_ref, destination_path): """Download a Comfy output reference to a local file path.""" params = { diff --git a/src/classes/comfy_pipelines.py b/src/classes/comfy_pipelines.py index 39638ed89..45fd90458 100644 
--- a/src/classes/comfy_pipelines.py +++ b/src/classes/comfy_pipelines.py @@ -77,6 +77,7 @@ def available_pipelines(source_file=None): if _supports_video_upscale(source_file): pipelines.append({"id": "video-upscale-gan", "name": "Upscale Video (GAN x4, first 10s)"}) pipelines.append({"id": "video2video-basic", "name": "Video + Text to Video (Style Transfer)"}) + pipelines.append({"id": "video-whisper-srt", "name": "Whisper Transcribe to SRT (Caption Effect)"}) return pipelines @@ -175,6 +176,50 @@ def build_workflow( "7": {"inputs": {"video": ["6", 0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, } + if pipeline_id == "video-whisper-srt": + source_path = str(source_path or "").strip() + if not source_path: + raise ValueError("A source video is required for this pipeline.") + return { + "1": { + "inputs": { + "video": source_path, + "force_rate": 0, + "custom_width": 0, + "custom_height": 0, + "frame_load_cap": 0, + "skip_first_frames": 0, + "select_every_nth": 1, + "format": "AnimateDiff", + }, + "class_type": "VHS_LoadVideo", + }, + "2": { + "inputs": { + "model": "medium", + "language": "auto", + "prompt": "", + "audio": ["1", 2], + }, + "class_type": "Apply Whisper", + }, + "3": { + "inputs": { + "name": "{}_segments".format(output_prefix), + "alignment": ["2", 1], + }, + "class_type": "Save SRT", + }, + "4": { + "inputs": { + "preview": "", + "previewMode": None, + "source": ["3", 0], + }, + "class_type": "PreviewAny", + }, + } + if pipeline_id == "txt2audio-stable-open": return { "3": { diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index c91e61284..8cf7a4a13 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -29,6 +29,7 @@ import re import tempfile from time import time +from urllib.parse import unquote import openshot from PyQt5.QtWidgets import QMessageBox, QDialog @@ -178,7 +179,7 @@ def 
action_generate_trigger(self, checked=True): client = ComfyClient(self.comfy_ui_url()) workflow_source = source_path - if pipeline_id in ("video-upscale-gan", "video2video-basic"): + if pipeline_id in ("video-upscale-gan", "video2video-basic", "video-whisper-srt"): if not source_file or source_file.data.get("media_type") != "video": QMessageBox.information(self.win, "Invalid Input", "This pipeline requires a source video file.") return @@ -302,7 +303,7 @@ def action_generate_trigger(self, checked=True): "save_node_ids": [ str(node_id) for node_id, node in workflow.items() - if node.get("class_type") in ("SaveImage", "SaveVideo", "SaveAudio") + if node.get("class_type") in ("SaveImage", "SaveVideo", "SaveAudio", "Save SRT", "PreviewAny") ], } job_id = self.win.generation_queue.enqueue( @@ -328,9 +329,18 @@ def on_generation_job_finished(self, job_id, status): return if status == "completed": - imported = self._import_generation_outputs(job) - if imported > 0: + result = self._import_generation_outputs(job) + imported = int(result.get("imported", 0)) + caption_saved = bool(result.get("caption_saved", False)) + if imported > 0 and caption_saved: + self.win.statusBar.showMessage( + "Generation completed, imported {} file(s), and saved file caption data".format(imported), + 5000, + ) + elif imported > 0: self.win.statusBar.showMessage("Generation completed and imported {} file(s)".format(imported), 5000) + elif caption_saved: + self.win.statusBar.showMessage("Generation completed and saved file caption data", 5000) else: self.win.statusBar.showMessage("Generation completed (no output files found)", 5000) return @@ -347,7 +357,7 @@ def on_generation_job_finished(self, job_id, status): def _import_generation_outputs(self, job): outputs = list(job.get("outputs", []) or []) if not outputs: - return 0 + return {"imported": 0, "caption_saved": False} request = job.get("request", {}) or {} comfy_url = str(request.get("comfy_url") or self.comfy_ui_url()) @@ -361,7 +371,40 @@ 
def _import_generation_outputs(self, job): safe_name = "generation" saved_paths = [] + text_outputs = [] for index, output_ref in enumerate(outputs, start=1): + text_payload = str(output_ref.get("text", "")).strip() + if text_payload: + # Some Save SRT node variants return the output file path as text. + # Convert that path to a downloadable Comfy output ref when possible. + if text_payload.lower().endswith(".srt"): + srt_ref = self._comfy_output_ref_from_path(text_payload) + if srt_ref: + local_name = "{}_{}{}".format(safe_name, str(index).zfill(3), ".srt") + local_path = self._next_available_path(os.path.join(output_dir, local_name)) + try: + client.download_output_file(srt_ref, local_path) + with open(local_path, "r", encoding="utf-8") as handle: + srt_text = handle.read().strip() + if srt_text: + saved_paths.append(local_path) + text_outputs.append(srt_text) + continue + except Exception as ex: + log.warning("Failed to download/read SRT from Comfy path output %s: %s", text_payload, ex) + + ext = ".srt" if str(output_ref.get("format", "")).lower() == "srt" else ".txt" + local_name = "{}_{}{}".format(safe_name, str(index).zfill(3), ext) + local_path = self._next_available_path(os.path.join(output_dir, local_name)) + try: + with open(local_path, "w", encoding="utf-8") as handle: + handle.write(text_payload) + saved_paths.append(local_path) + text_outputs.append(text_payload) + except Exception as ex: + log.warning("Failed to write Comfy text output to %s: %s", local_path, ex) + continue + original_name = str(output_ref.get("filename", "output.png")) ext = os.path.splitext(original_name)[1] or ".png" local_name = "{}_{}{}".format(safe_name, str(index).zfill(3), ext) @@ -373,7 +416,7 @@ def _import_generation_outputs(self, job): log.warning("Failed to download Comfy output %s: %s", output_ref, ex) if not saved_paths: - return 0 + return {"imported": 0, "caption_saved": False} self.win.files_model.add_files( saved_paths, @@ -381,7 +424,93 @@ def 
_import_generation_outputs(self, job): prevent_image_seq=True, prevent_recent_folder=True, ) - return len(saved_paths) + + caption_saved = False + if str(job.get("template_id") or "") == "video-whisper-srt": + caption_text = self._resolve_caption_text(saved_paths, text_outputs) + caption_saved = self._store_caption_on_file( + source_file_id=job.get("source_file_id"), + caption_text=caption_text, + ) + return {"imported": len(saved_paths), "caption_saved": caption_saved} + + def _resolve_caption_text(self, saved_paths, text_outputs): + srt_path = "" + for path in saved_paths: + if str(path).lower().endswith(".srt"): + srt_path = path + break + if srt_path: + try: + with open(srt_path, "r", encoding="utf-8") as handle: + text = handle.read().strip() + if text: + return text + except Exception as ex: + log.warning("Failed reading SRT file for file caption metadata: %s", ex) + + for value in text_outputs: + text = str(value or "").strip() + if "-->" in text: + return text + + for value in text_outputs: + text = str(value or "").strip() + if text: + return text + + return "" + + def _store_caption_on_file(self, source_file_id, caption_text): + caption_text = str(caption_text or "").strip() + if not caption_text: + return False + + source_file_value = source_file_id + file_obj = File.get(id=source_file_value) + if file_obj is None: + file_obj = File.get(id=str(source_file_value or "")) + if file_obj is None: + log.info("No source file found for caption metadata update (file_id=%s)", source_file_value) + return False + + if not isinstance(file_obj.data, dict): + file_obj.data = {} + file_obj.data["caption"] = caption_text + file_obj.save() + self.win.FileUpdated.emit(str(file_obj.id)) + return True + + def _comfy_output_ref_from_path(self, path_text): + """Convert a Comfy output absolute/relative path into a /view-compatible output ref.""" + path_text = unquote(str(path_text or "").strip()) + if not path_text: + return None + normalized = path_text.replace("\\", "/") + 
filename = os.path.basename(normalized) + if not filename: + return None + + subfolder = "" + marker = "/output/" + if marker in normalized: + rel = normalized.split(marker, 1)[1].lstrip("/") + rel_dir = os.path.dirname(rel).strip("/") + subfolder = rel_dir + elif normalized.startswith("output/"): + rel = normalized[len("output/"):] + rel_dir = os.path.dirname(rel).strip("/") + subfolder = rel_dir + else: + rel_dir = os.path.dirname(normalized).strip("/") + if rel_dir and rel_dir != ".": + subfolder = rel_dir + + return { + "filename": filename, + "subfolder": subfolder, + "type": "output", + } def _next_generation_name(self, requested_name): base = re.sub(r"[^A-Za-z0-9._-]+", "_", str(requested_name or "").strip()).strip("._") diff --git a/src/windows/add_to_timeline.py b/src/windows/add_to_timeline.py index f11ea3e0c..810cdf9ea 100644 --- a/src/windows/add_to_timeline.py +++ b/src/windows/add_to_timeline.py @@ -40,6 +40,7 @@ from classes.query import Clip, Transition from classes.app import get_app from classes.metrics import track_metric_screen +from classes.clip_utils import apply_file_caption_to_clip from windows.views.add_to_timeline_treeview import TimelineTreeView import openshot @@ -215,6 +216,9 @@ def accept(self): if not new_clip.get("reader"): continue # Skip to next file + # If the source file has stored caption text, attach a Caption effect to this new clip. 
+ apply_file_caption_to_clip(new_clip, file) + # Check for optional start and end attributes start_time = 0 end_time = new_clip["reader"]["duration"] diff --git a/src/windows/views/timeline.py b/src/windows/views/timeline.py index 840f02942..d7131ef9f 100644 --- a/src/windows/views/timeline.py +++ b/src/windows/views/timeline.py @@ -57,7 +57,7 @@ from .timeline_backend.qwidget import TimelineWidget from .timeline_backend.colors import effect_color_hex from .menu import StyledContextMenu -from classes.clip_utils import clamp_timing_to_media +from classes.clip_utils import clamp_timing_to_media, apply_file_caption_to_clip from .retime import retime_clip from .repeat import apply_repeat, reset_repeat, RepeatDialog @@ -4023,6 +4023,9 @@ def addClip(self, file_id, position, track, ignore_refresh=False, call_manual_mo if not new_clip.get("reader"): return # Skip this clip + # If the source file has stored caption text, attach a Caption effect to this new clip. + apply_file_caption_to_clip(new_clip, file) + # Determine start, duration, and end using file metadata media_type = (file.data or {}).get("media_type") start_value = file.data.get("start", new_clip.get("start", 0.0)) From a20d997a1377338e1c92b284369559173026c51b Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Sat, 14 Feb 2026 21:21:43 -0600 Subject: [PATCH 05/27] - Add Comfy templates for img2video-svd and RIFE 2x frame interpolation, with full workflow wiring. - Use remote Comfy node/model discovery (RIFE + flexible SVD checkpoint selection). - Show immediate generation progress in Project Files (including queued state) without mutating file names. 
--- src/classes/comfy_client.py | 15 +++++ src/classes/comfy_pipelines.py | 95 ++++++++++++++++++++++++++++- src/classes/generation_service.py | 69 ++++++++++++++++++--- src/windows/models/files_model.py | 4 +- src/windows/views/files_listview.py | 9 ++- src/windows/views/files_treeview.py | 9 ++- 6 files changed, 186 insertions(+), 15 deletions(-) diff --git a/src/classes/comfy_client.py b/src/classes/comfy_client.py index ff6929dff..2eded71fb 100644 --- a/src/classes/comfy_client.py +++ b/src/classes/comfy_client.py @@ -435,6 +435,21 @@ def list_clip_vision_models(self): values = self._extract_combo_options(clip_input) return [str(v) for v in values if str(v).strip()] + def list_rife_vfi_models(self): + """Return available RIFE checkpoint names from ComfyUI object info.""" + node_type = "RIFE VFI" + with urlopen( + "{}/object_info/{}".format(self.base_url, quote(node_type, safe="")), + timeout=8.0, + ) as response: + data = json.loads(response.read().decode("utf-8")) + + node_info = data.get(node_type, {}) + required = node_info.get("input", {}).get("required", {}) + ckpt_input = required.get("ckpt_name", None) + values = self._extract_combo_options(ckpt_input) + return [str(v) for v in values if str(v).strip()] + @staticmethod def _extract_combo_options(input_config): """Extract valid options from Comfy object_info input config variants.""" diff --git a/src/classes/comfy_pipelines.py b/src/classes/comfy_pipelines.py index 45fd90458..69d3ffadf 100644 --- a/src/classes/comfy_pipelines.py +++ b/src/classes/comfy_pipelines.py @@ -39,6 +39,7 @@ DEFAULT_STABLE_AUDIO_CHECKPOINT = "stable-audio-open-1.0.safetensors" DEFAULT_STABLE_AUDIO_CLIP = "t5-base.safetensors" DEFAULT_SVD_CHECKPOINT = "svd_xt.safetensors" +DEFAULT_RIFE_VFI_MODEL = "rife47.pth" def is_supported_img2img_path(path): @@ -74,7 +75,9 @@ def available_pipelines(source_file=None): if _supports_img2img(source_file): pipelines.insert(0, {"id": "img2img-basic", "name": "Basic Image Variation"}) 
pipelines.insert(1, {"id": "upscale-realesrgan-x4", "name": "Upscale Image (RealESRGAN x4)"}) + pipelines.insert(2, {"id": "img2video-svd", "name": "Image to Video (img_to_video)"}) if _supports_video_upscale(source_file): + pipelines.append({"id": "video-frame-interpolation-rife2x", "name": "Frame Interpolation (RIFE 2x FPS)"}) pipelines.append({"id": "video-upscale-gan", "name": "Upscale Video (GAN x4, first 10s)"}) pipelines.append({"id": "video2video-basic", "name": "Video + Text to Video (Style Transfer)"}) pipelines.append({"id": "video-whisper-srt", "name": "Whisper Transcribe to SRT (Caption Effect)"}) @@ -100,7 +103,11 @@ def pipeline_requires_stable_audio_clip(pipeline_id): def pipeline_requires_svd_checkpoint(pipeline_id): - return str(pipeline_id or "") in ("txt2video-svd",) + return str(pipeline_id or "") in ("txt2video-svd", "img2video-svd") + + +def pipeline_requires_rife_model(pipeline_id): + return str(pipeline_id or "") in ("video-frame-interpolation-rife2x",) def build_workflow( @@ -112,6 +119,8 @@ def build_workflow( upscale_model_name=None, stable_audio_clip_name=None, svd_checkpoint_name=None, + source_fps=None, + rife_model_name=None, ): prompt_text = str(prompt_text or "cinematic shot, highly detailed").strip() if not prompt_text: @@ -121,6 +130,14 @@ def build_workflow( upscale_model_name = str(upscale_model_name or "").strip() or DEFAULT_UPSCALE_MODEL stable_audio_clip_name = str(stable_audio_clip_name or "").strip() or DEFAULT_STABLE_AUDIO_CLIP svd_checkpoint_name = str(svd_checkpoint_name or "").strip() or DEFAULT_SVD_CHECKPOINT + rife_model_name = str(rife_model_name or "").strip() or DEFAULT_RIFE_VFI_MODEL + try: + source_fps_value = float(source_fps) + except (TypeError, ValueError): + source_fps_value = 30.0 + if source_fps_value <= 0: + source_fps_value = 30.0 + target_fps = round(source_fps_value * 2.0, 6) seed = random.randint(1, 2**31 - 1) if pipeline_id == "img2img-basic": @@ -220,6 +237,38 @@ def build_workflow( }, } + if 
pipeline_id == "video-frame-interpolation-rife2x": + source_path = str(source_path or "").strip() + if not source_path: + raise ValueError("A source video is required for this pipeline.") + return { + "1": {"inputs": {"file": source_path}, "class_type": "LoadVideo"}, + "2": {"inputs": {"video": ["1", 0]}, "class_type": "GetVideoComponents"}, + "3": { + "inputs": { + "frames": ["2", 0], + "ckpt_name": rife_model_name, + "clear_cache_after_n_frames": 10, + "multiplier": 2, + "fast_mode": True, + "ensemble": True, + "scale_factor": 1, + }, + "class_type": "RIFE VFI", + "_meta": {"title": "RIFE VFI (recommend rife47 and rife49)"}, + }, + "4": {"inputs": {"images": ["3", 0], "audio": ["2", 1], "fps": target_fps}, "class_type": "CreateVideo"}, + "5": { + "inputs": { + "video": ["4", 0], + "filename_prefix": "video/{}".format(output_prefix), + "format": "auto", + "codec": "auto", + }, + "class_type": "SaveVideo", + }, + } + if pipeline_id == "txt2audio-stable-open": return { "3": { @@ -304,6 +353,50 @@ def build_workflow( "13": {"inputs": {"video": ["12", 0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, } + if pipeline_id == "img2video-svd": + if not is_supported_img2img_path(source_path): + raise ValueError( + "The selected file is not a supported raster image for this pipeline. " + "Use PNG/JPG/WebP/BMP/TIFF or switch to Text to Video." 
+ ) + return { + "1": {"inputs": {"ckpt_name": svd_checkpoint_name}, "class_type": "ImageOnlyCheckpointLoader"}, + "2": {"inputs": {"image": str(source_path or ""), "upload": "image"}, "class_type": "LoadImage"}, + "3": { + "inputs": { + "clip_vision": ["1", 1], + "init_image": ["2", 0], + "vae": ["1", 2], + "width": 512, + "height": 288, + "video_frames": 24, + "motion_bucket_id": 127, + "fps": 24, + "augmentation_level": 0.0, + }, + "class_type": "SVD_img2vid_Conditioning", + }, + "4": {"inputs": {"model": ["1", 0], "min_cfg": 1.0}, "class_type": "VideoLinearCFGGuidance"}, + "5": { + "inputs": { + "seed": seed + 1, + "steps": 20, + "cfg": 2.5, + "sampler_name": "euler", + "scheduler": "karras", + "denoise": 1.0, + "model": ["4", 0], + "positive": ["3", 0], + "negative": ["3", 1], + "latent_image": ["3", 2], + }, + "class_type": "KSampler", + }, + "6": {"inputs": {"samples": ["5", 0], "vae": ["1", 2]}, "class_type": "VAEDecode"}, + "7": {"inputs": {"images": ["6", 0], "fps": 24}, "class_type": "CreateVideo"}, + "8": {"inputs": {"video": ["7", 0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, + } + if pipeline_id == "video2video-basic": source_path = str(source_path or "").strip() if not source_path: diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index 8cf7a4a13..f4dd6a147 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -44,7 +44,9 @@ pipeline_requires_checkpoint, pipeline_requires_svd_checkpoint, pipeline_requires_stable_audio_clip, + pipeline_requires_rife_model, pipeline_requires_upscale_model, + DEFAULT_RIFE_VFI_MODEL, DEFAULT_SD_CHECKPOINT, DEFAULT_SD_BASE_CHECKPOINT, DEFAULT_STABLE_AUDIO_CHECKPOINT, @@ -103,7 +105,7 @@ def _prepare_generation_source_path(self, source_file, template_id): source_path = source_file.data.get("path", "") media_type = source_file.data.get("media_type") - if template_id not in 
("img2img-basic", "upscale-realesrgan-x4") or media_type != "image": + if template_id not in ("img2img-basic", "upscale-realesrgan-x4", "img2video-svd") or media_type != "image": return source_path if is_supported_img2img_path(source_path): @@ -138,6 +140,20 @@ def _prepare_generation_image_input(self, local_image_path, client): raise ValueError("A source image is required.") return client.upload_input_file(local_image_path) + def _get_source_fps(self, source_file): + if not source_file: + return None + fps_data = source_file.data.get("fps") + if isinstance(fps_data, dict): + try: + num = float(fps_data.get("num", 0)) + den = float(fps_data.get("den", 0)) + except (TypeError, ValueError): + num = den = 0.0 + if num > 0 and den > 0: + return num / den + return None + def action_generate_trigger(self, checked=True): selected_files = self.win.selected_files() if len(selected_files) > 1: @@ -176,10 +192,16 @@ def action_generate_trigger(self, checked=True): upscale_model_name = None stable_audio_clip_name = None svd_checkpoint_name = None + rife_model_name = None client = ComfyClient(self.comfy_ui_url()) workflow_source = source_path - if pipeline_id in ("video-upscale-gan", "video2video-basic", "video-whisper-srt"): + if pipeline_id in ( + "video-upscale-gan", + "video2video-basic", + "video-whisper-srt", + "video-frame-interpolation-rife2x", + ): if not source_file or source_file.data.get("media_type") != "video": QMessageBox.information(self.win, "Invalid Input", "This pipeline requires a source video file.") return @@ -192,7 +214,7 @@ def action_generate_trigger(self, checked=True): "OpenShot could not upload the source video into ComfyUI input.\n\n{}".format(ex), ) return - elif pipeline_id in ("img2img-basic", "upscale-realesrgan-x4"): + elif pipeline_id in ("img2img-basic", "upscale-realesrgan-x4", "img2video-svd"): try: workflow_source = self._prepare_generation_image_input(source_path, client) except Exception as ex: @@ -204,7 +226,8 @@ def 
action_generate_trigger(self, checked=True): return try: - if pipeline_requires_checkpoint(pipeline_id): + checkpoint_names = [] + if pipeline_requires_checkpoint(pipeline_id) or pipeline_requires_svd_checkpoint(pipeline_id): checkpoint_names = client.list_checkpoints() if checkpoint_names: preferred_checkpoint = DEFAULT_SD_CHECKPOINT @@ -216,9 +239,13 @@ def action_generate_trigger(self, checked=True): preferred_checkpoint if preferred_checkpoint in checkpoint_names else checkpoint_names[0] ) if pipeline_requires_svd_checkpoint(pipeline_id): - svd_checkpoint_name = ( - DEFAULT_SVD_CHECKPOINT if DEFAULT_SVD_CHECKPOINT in checkpoint_names else None - ) + if DEFAULT_SVD_CHECKPOINT in checkpoint_names: + svd_checkpoint_name = DEFAULT_SVD_CHECKPOINT + else: + # Prefer any checkpoint that appears to be an SVD model. + svd_candidates = [name for name in checkpoint_names if "svd" in str(name).lower()] + if svd_candidates: + svd_checkpoint_name = svd_candidates[0] except Exception as ex: log.warning("Failed to query ComfyUI checkpoints: %s", ex) @@ -235,8 +262,8 @@ def action_generate_trigger(self, checked=True): QMessageBox.information( self.win, "No SVD Checkpoint Found", - "ComfyUI could not find the SVD checkpoint required for txt_to_image_to_video.\n" - "Add {} to ComfyUI/models/checkpoints and try again.".format(DEFAULT_SVD_CHECKPOINT), + "ComfyUI could not find the SVD checkpoint required for the selected video generation template.\n" + "Add an SVD checkpoint (for example {}) to ComfyUI/models/checkpoints and try again.".format(DEFAULT_SVD_CHECKPOINT), ) return @@ -281,6 +308,28 @@ def action_generate_trigger(self, checked=True): ) return + try: + if pipeline_requires_rife_model(pipeline_id): + rife_models = client.list_rife_vfi_models() + if rife_models: + for preferred in (DEFAULT_RIFE_VFI_MODEL, "rife49.pth"): + if preferred in rife_models: + rife_model_name = preferred + break + if not rife_model_name: + rife_model_name = rife_models[0] + except Exception as ex: 
+ log.warning("Failed to query ComfyUI RIFE VFI models: %s", ex) + + if pipeline_requires_rife_model(pipeline_id) and not rife_model_name: + QMessageBox.information( + self.win, + "RIFE VFI Not Available", + "ComfyUI could not find the RIFE VFI node/models required for frame interpolation.\n" + "Install ComfyUI-Frame-Interpolation and add models such as rife47.pth.", + ) + return + try: workflow = build_workflow( pipeline_id, @@ -291,6 +340,8 @@ def action_generate_trigger(self, checked=True): upscale_model_name=upscale_model_name, stable_audio_clip_name=stable_audio_clip_name, svd_checkpoint_name=svd_checkpoint_name, + source_fps=self._get_source_fps(source_file), + rife_model_name=rife_model_name, ) except Exception as ex: QMessageBox.information(self.win, "Invalid Input", str(ex)) diff --git a/src/windows/models/files_model.py b/src/windows/models/files_model.py index 55e7a809a..792db5b89 100644 --- a/src/windows/models/files_model.py +++ b/src/windows/models/files_model.py @@ -923,12 +923,14 @@ def _on_generation_job_removed(self, job_id): self._remove_generation_placeholder(job_id) def _refresh_file_generation_display(self, file_id): + file_id = str(file_id or "") + if not file_id: + return if file_id not in self.model_ids: return id_index = self.model_ids[file_id] if not id_index.isValid(): return - row = id_index.row() left = self.model.index(row, 0) right = self.model.index(row, 0) diff --git a/src/windows/views/files_listview.py b/src/windows/views/files_listview.py index c345b98e0..048fae1d0 100644 --- a/src/windows/views/files_listview.py +++ b/src/windows/views/files_listview.py @@ -88,8 +88,9 @@ def paint(self, painter, option, index): return progress = int(badge.get("progress", 0)) - status = badge.get("status", "") - if status == "queued": + status = str(badge.get("status", "")).strip().lower() + if status in ("queued", "running", "canceling"): + # Keep active jobs visible even before numeric progress starts. 
progress = max(progress, 2) if progress <= 0: return @@ -118,6 +119,10 @@ def paint(self, painter, option, index): painter.drawRect(full_rect) painter.setBrush(QColor("#53A0ED")) painter.drawRect(fill_rect) + if status == "queued": + text_rect = full_rect.adjusted(0, -14, 0, -4) + painter.setPen(QColor("#9EC8F7")) + painter.drawText(text_rect, Qt.AlignLeft | Qt.AlignVCenter, "Queued") painter.restore() diff --git a/src/windows/views/files_treeview.py b/src/windows/views/files_treeview.py index 2dd51cecf..1f16c51f3 100644 --- a/src/windows/views/files_treeview.py +++ b/src/windows/views/files_treeview.py @@ -84,8 +84,9 @@ def paint(self, painter, option, index): return progress = int(badge.get("progress", 0)) - status = badge.get("status", "") - if status == "queued": + status = str(badge.get("status", "")).strip().lower() + if status in ("queued", "running", "canceling"): + # Keep active jobs visible even before numeric progress starts. progress = max(progress, 2) if progress <= 0: return @@ -114,6 +115,10 @@ def paint(self, painter, option, index): painter.drawRect(full_rect) painter.setBrush(QColor("#53A0ED")) painter.drawRect(fill_rect) + if status == "queued": + text_rect = full_rect.adjusted(0, -14, 0, -4) + painter.setPen(QColor("#9EC8F7")) + painter.drawText(text_rect, Qt.AlignLeft | Qt.AlignVCenter, "Queued") painter.restore() From 2eca52951eaa6c13af3596a4e99b3f55a86c407d Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Sat, 14 Feb 2026 21:52:55 -0600 Subject: [PATCH 06/27] Add TransNet scene segmentation template and robust segment import handling - Add video-segment-scenes-transnet Comfy workflow (LoadVideo + TransNetV2 + PreviewAny output) - Import all generated segment paths into Project Files, apply split-style scene naming, and add scene tags - Improve Comfy output extraction to handle list/string node outputs from custom nodes --- src/classes/comfy_client.py | 79 +++++++++------ src/classes/comfy_pipelines.py | 45 +++++++++ 
src/classes/generation_service.py | 158 +++++++++++++++++++++++++++++- 3 files changed, 250 insertions(+), 32 deletions(-) diff --git a/src/classes/comfy_client.py b/src/classes/comfy_client.py index 2eded71fb..da1542760 100644 --- a/src/classes/comfy_client.py +++ b/src/classes/comfy_client.py @@ -611,32 +611,40 @@ def extract_file_outputs(history_entry, save_node_ids=None): for node_id, node_out in node_outputs.items(): if save_node_ids and str(node_id) not in save_node_ids: continue - if not isinstance(node_out, dict): - continue - for key in ("images", "videos", "video", "audios", "audio", "files"): - refs = node_out.get(key, []) - if not isinstance(refs, list): - continue - for ref in refs: - if not isinstance(ref, dict): + if isinstance(node_out, dict): + for key in ("images", "videos", "video", "audios", "audio", "files"): + refs = node_out.get(key, []) + if not isinstance(refs, list): continue - if ref.get("filename"): + for ref in refs: + if not isinstance(ref, dict): + continue + if ref.get("filename"): + outputs.append({ + "filename": str(ref.get("filename")), + "subfolder": str(ref.get("subfolder", "")), + "type": str(ref.get("type", "output")), + }) + # Also extract text-like outputs (for custom nodes such as Whisper/SRT pipelines). + for value in node_out.values(): + text_values = ComfyClient._extract_text_outputs(value) + for text_value in text_values: + output_format = "srt" if ComfyClient._looks_like_srt(text_value) else "txt" outputs.append({ - "filename": str(ref.get("filename")), - "subfolder": str(ref.get("subfolder", "")), - "type": str(ref.get("type", "output")), + "text": text_value, + "format": output_format, + "type": "text", }) - # Also extract text-like outputs (for custom nodes such as Whisper/SRT pipelines). 
- for value in node_out.values(): - text_value = ComfyClient._extract_text_output(value) - if not text_value: - continue - output_format = "srt" if ComfyClient._looks_like_srt(text_value) else "txt" - outputs.append({ - "text": text_value, - "format": output_format, - "type": "text", - }) + else: + # Some custom nodes emit list/string outputs directly instead of dicts. + text_values = ComfyClient._extract_text_outputs(node_out) + for text_value in text_values: + output_format = "srt" if ComfyClient._looks_like_srt(text_value) else "txt" + outputs.append({ + "text": text_value, + "format": output_format, + "type": "text", + }) return outputs @staticmethod @@ -646,20 +654,31 @@ def extract_image_outputs(history_entry, save_node_ids=None): @staticmethod def _extract_text_output(value): """Extract text payloads from common Comfy output structures.""" + values = ComfyClient._extract_text_outputs(value) + return values[0] if values else "" + + @staticmethod + def _extract_text_outputs(value): + """Extract one or more text payloads from common Comfy output structures.""" if isinstance(value, str): text = value.strip() - return text if text else "" + return [text] if text else [] if isinstance(value, list): - if len(value) == 1 and isinstance(value[0], str): - text = value[0].strip() - return text if text else "" - return "" + out = [] + for item in value: + if isinstance(item, str): + text = item.strip() + if text: + out.append(text) + return out if isinstance(value, dict): + out = [] for key in ("srt", "text", "value"): text = value.get(key) if isinstance(text, str) and text.strip(): - return text.strip() - return "" + out.append(text.strip()) + return out + return [] @staticmethod def _looks_like_srt(text): diff --git a/src/classes/comfy_pipelines.py b/src/classes/comfy_pipelines.py index 69d3ffadf..5e5cee1e1 100644 --- a/src/classes/comfy_pipelines.py +++ b/src/classes/comfy_pipelines.py @@ -77,6 +77,7 @@ def available_pipelines(source_file=None): pipelines.insert(1, 
{"id": "upscale-realesrgan-x4", "name": "Upscale Image (RealESRGAN x4)"}) pipelines.insert(2, {"id": "img2video-svd", "name": "Image to Video (img_to_video)"}) if _supports_video_upscale(source_file): + pipelines.append({"id": "video-segment-scenes-transnet", "name": "Segment Scenes (TransNetV2)"}) pipelines.append({"id": "video-frame-interpolation-rife2x", "name": "Frame Interpolation (RIFE 2x FPS)"}) pipelines.append({"id": "video-upscale-gan", "name": "Upscale Video (GAN x4, first 10s)"}) pipelines.append({"id": "video2video-basic", "name": "Video + Text to Video (Style Transfer)"}) @@ -269,6 +270,50 @@ def build_workflow( }, } + if pipeline_id == "video-segment-scenes-transnet": + source_path = str(source_path or "").strip() + if not source_path: + raise ValueError("A source video is required for this pipeline.") + return { + "7": {"inputs": {"file": source_path}, "class_type": "LoadVideo"}, + "2": { + "inputs": { + "model": "transnetv2-pytorch-weights", + "device": "auto", + }, + "class_type": "DownloadAndLoadTransNetModel", + "_meta": {"title": "MiaoshouAI Load TransNet Model"}, + }, + "1": { + "inputs": { + "threshold": 0.5, + "min_scene_length": 30, + "output_dir": "output", + "TransNet_model": ["2", 0], + "video": ["7", 0], + }, + "class_type": "TransNetV2_Run", + "_meta": {"title": "MiaoshouAI Segment Video"}, + }, + "8": { + "inputs": { + "index": 0, + "segment_paths": ["1", 0], + }, + "class_type": "SelectVideo", + "_meta": {"title": "MiaoshouAI Select Video"}, + }, + "9": { + "inputs": { + "preview": "", + "previewMode": None, + "source": ["1", 0], + }, + "class_type": "PreviewAny", + "_meta": {"title": "Preview Any"}, + }, + } + if pipeline_id == "txt2audio-stable-open": return { "3": { diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index f4dd6a147..1014870fe 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -30,11 +30,13 @@ import tempfile from time import time from 
urllib.parse import unquote +from fractions import Fraction import openshot from PyQt5.QtWidgets import QMessageBox, QDialog from classes import info +from classes import time_parts from classes.app import get_app from classes.comfy_client import ComfyClient from classes.comfy_pipelines import ( @@ -201,6 +203,7 @@ def action_generate_trigger(self, checked=True): "video2video-basic", "video-whisper-srt", "video-frame-interpolation-rife2x", + "video-segment-scenes-transnet", ): if not source_file or source_file.data.get("media_type") != "video": QMessageBox.information(self.win, "Invalid Input", "This pipeline requires a source video file.") @@ -354,7 +357,14 @@ def action_generate_trigger(self, checked=True): "save_node_ids": [ str(node_id) for node_id, node in workflow.items() - if node.get("class_type") in ("SaveImage", "SaveVideo", "SaveAudio", "Save SRT", "PreviewAny") + if node.get("class_type") in ( + "SaveImage", + "SaveVideo", + "SaveAudio", + "Save SRT", + "PreviewAny", + "TransNetV2_Run", + ) ], } job_id = self.win.generation_queue.enqueue( @@ -383,11 +393,19 @@ def on_generation_job_finished(self, job_id, status): result = self._import_generation_outputs(job) imported = int(result.get("imported", 0)) caption_saved = bool(result.get("caption_saved", False)) + scenes_labeled = int(result.get("scenes_labeled", 0)) if imported > 0 and caption_saved: self.win.statusBar.showMessage( "Generation completed, imported {} file(s), and saved file caption data".format(imported), 5000, ) + elif imported > 0 and scenes_labeled > 0: + self.win.statusBar.showMessage( + "Generation completed, imported {} file(s), and labeled {} scene segment(s)".format( + imported, scenes_labeled + ), + 5000, + ) elif imported > 0: self.win.statusBar.showMessage("Generation completed and imported {} file(s)".format(imported), 5000) elif caption_saved: @@ -423,9 +441,42 @@ def _import_generation_outputs(self, job): saved_paths = [] text_outputs = [] + video_path_exts = {".mp4", ".mov", 
".mkv", ".webm", ".avi", ".m4v"} + seen_video_payload_paths = set() for index, output_ref in enumerate(outputs, start=1): text_payload = str(output_ref.get("text", "")).strip() if text_payload: + payload_video_paths = self._extract_video_paths_from_text(text_payload) + if not payload_video_paths: + payload_ext = os.path.splitext(text_payload)[1].lower() + if payload_ext in video_path_exts: + payload_video_paths = [text_payload] + downloaded_any_video = False + for raw_video_path in payload_video_paths: + norm_video_path = str(raw_video_path).strip().replace("\\", "/") + if not norm_video_path or norm_video_path in seen_video_payload_paths: + continue + seen_video_payload_paths.add(norm_video_path) + + payload_ext = os.path.splitext(norm_video_path)[1].lower() or ".mp4" + video_ref = self._comfy_output_ref_from_path(norm_video_path) + if not video_ref: + continue + local_name = "{}_{}{}".format(safe_name, str(index).zfill(3), payload_ext) + local_path = self._next_available_path(os.path.join(output_dir, local_name)) + try: + client.download_output_file(video_ref, local_path) + saved_paths.append(local_path) + downloaded_any_video = True + except Exception as ex: + log.warning( + "Failed to download segmented video from Comfy path output %s: %s", + raw_video_path, + ex, + ) + if downloaded_any_video: + continue + # Some Save SRT node variants return the output file path as text. # Convert that path to a downloadable Comfy output ref when possible. 
if text_payload.lower().endswith(".srt"): @@ -477,13 +528,30 @@ def _import_generation_outputs(self, job): ) caption_saved = False + scenes_labeled = 0 if str(job.get("template_id") or "") == "video-whisper-srt": caption_text = self._resolve_caption_text(saved_paths, text_outputs) caption_saved = self._store_caption_on_file( source_file_id=job.get("source_file_id"), caption_text=caption_text, ) - return {"imported": len(saved_paths), "caption_saved": caption_saved} + if str(job.get("template_id") or "") == "video-segment-scenes-transnet": + scenes_labeled = self._apply_scene_segment_metadata( + source_file_id=job.get("source_file_id"), + saved_paths=saved_paths, + ) + return {"imported": len(saved_paths), "caption_saved": caption_saved, "scenes_labeled": scenes_labeled} + + def _extract_video_paths_from_text(self, text_payload): + """Extract absolute video file paths from log/text payloads.""" + text_payload = str(text_payload or "") + if not text_payload: + return [] + pattern = re.compile( + r"([A-Za-z]:[\\/][^\r\n]+?\.(?:mp4|mov|mkv|webm|avi|m4v)|/[^\r\n]+?\.(?:mp4|mov|mkv|webm|avi|m4v))", + re.IGNORECASE, + ) + return [match.strip() for match in pattern.findall(text_payload) if match.strip()] def _resolve_caption_text(self, saved_paths, text_outputs): srt_path = "" @@ -532,6 +600,85 @@ def _store_caption_on_file(self, source_file_id, caption_text): self.win.FileUpdated.emit(str(file_obj.id)) return True + def _seconds_to_compact_timecode(self, seconds_value, fps_fraction, include_hours=False, include_minutes=False): + fps_fraction = fps_fraction if isinstance(fps_fraction, Fraction) and fps_fraction > 0 else Fraction(30, 1) + fps_float = float(fps_fraction) + frame_number = int(round(max(0.0, float(seconds_value or 0.0)) * fps_float)) + 1 + t = time_parts.secondsToTime((frame_number - 1) / fps_float, fps_fraction.numerator, fps_fraction.denominator) + hours = int(t.get("hour", 0)) + minutes = int(t.get("min", 0)) + secs = int(t.get("sec", 0)) + frames = 
int(t.get("frame", 0)) + if include_hours: + return "{:02d}:{:02d}:{:02d};{:02d}".format(hours, minutes, secs, frames) + if include_minutes: + return "{:02d}:{:02d};{:02d}".format(minutes, secs, frames) + return "{:02d};{:02d}".format(secs, frames) + + def _append_scene_tag(self, file_obj): + tags_raw = str(file_obj.data.get("tags", "") or "").strip() + if not tags_raw: + file_obj.data["tags"] = "scene" + return + tags = [part.strip() for part in tags_raw.split(",") if part.strip()] + if any(part.lower() == "scene" for part in tags): + return + tags.append("scene") + file_obj.data["tags"] = ", ".join(tags) + + def _apply_scene_segment_metadata(self, source_file_id, saved_paths): + source_file = File.get(id=source_file_id) if source_file_id else None + base_name = "scene" + fps_fraction = Fraction(30, 1) + if source_file: + source_path = str(source_file.data.get("path", "") or "") + if source_path: + base_name = os.path.splitext(os.path.basename(source_path))[0] or base_name + fps_data = source_file.data.get("fps", {}) + try: + num = int(fps_data.get("num", 30)) + den = int(fps_data.get("den", 1) or 1) + if num > 0 and den > 0: + fps_fraction = Fraction(num, den) + except (TypeError, ValueError, ZeroDivisionError): + fps_fraction = Fraction(30, 1) + + imported_files = [] + for path in saved_paths: + file_obj = File.get(path=path) + if file_obj and str(file_obj.data.get("media_type", "")) == "video": + imported_files.append(file_obj) + if not imported_files: + return 0 + + running_start = 0.0 + updated = 0 + for file_obj in imported_files: + duration = float(file_obj.data.get("duration") or 0.0) + if duration <= 0: + start_trim = float(file_obj.data.get("start") or 0.0) + end_trim = float(file_obj.data.get("end") or 0.0) + duration = max(0.0, end_trim - start_trim) + running_end = running_start + max(0.0, duration) + + include_hours = int(running_end // 3600) > 0 + include_minutes = include_hours or int((running_end % 3600) // 60) > 0 + start_tc = 
self._seconds_to_compact_timecode( + running_start, fps_fraction, include_hours=include_hours, include_minutes=include_minutes + ) + end_tc = self._seconds_to_compact_timecode( + running_end, fps_fraction, include_hours=include_hours, include_minutes=include_minutes + ) + file_obj.data["name"] = "{} ({} to {})".format(base_name, start_tc, end_tc) + self._append_scene_tag(file_obj) + file_obj.save() + self.win.FileUpdated.emit(str(file_obj.id)) + + running_start = running_end + updated += 1 + + return updated + def _comfy_output_ref_from_path(self, path_text): """Convert a Comfy output absolute/relative path into a /view-compatible output ref.""" path_text = unquote(str(path_text or "").strip()) @@ -553,6 +700,13 @@ def _comfy_output_ref_from_path(self, path_text): rel_dir = os.path.dirname(rel).strip("/") subfolder = rel_dir else: + if os.path.isabs(normalized): + # Unknown absolute location outside Comfy output tree; fallback to basename only. + return { + "filename": filename, + "subfolder": "", + "type": "output", + } rel_dir = os.path.dirname(normalized).strip("/") if rel_dir and rel_dir != ".": subfolder = rel_dir From af05981aa2f005af1be9124d78280cafa12e53af Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Sun, 15 Feb 2026 14:28:29 -0600 Subject: [PATCH 07/27] Lowering quality and fps of text to video (for now) --- src/classes/comfy_pipelines.py | 24 ++++++++++++------------ src/classes/generation_service.py | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/classes/comfy_pipelines.py b/src/classes/comfy_pipelines.py index 5e5cee1e1..fca53a48b 100644 --- a/src/classes/comfy_pipelines.py +++ b/src/classes/comfy_pipelines.py @@ -345,14 +345,14 @@ def build_workflow( "1": {"inputs": {"ckpt_name": svd_checkpoint_name}, "class_type": "ImageOnlyCheckpointLoader"}, "2": {"inputs": {"ckpt_name": checkpoint_name}, "class_type": "CheckpointLoaderSimple"}, "3": {"inputs": {"text": prompt_text, "clip": ["2", 1]}, "class_type": 
"CLIPTextEncode"}, - "4": {"inputs": {"text": "text, watermark", "clip": ["2", 1]}, "class_type": "CLIPTextEncode"}, + "4": {"inputs": {"text": "low quality, blurry", "clip": ["2", 1]}, "class_type": "CLIPTextEncode"}, "5": {"inputs": {"width": 512, "height": 288, "batch_size": 1}, "class_type": "EmptyLatentImage"}, "6": { "inputs": { "seed": seed, - "steps": 15, - "cfg": 8.0, - "sampler_name": "uni_pc_bh2", + "steps": 8, + "cfg": 6.0, + "sampler_name": "euler", "scheduler": "normal", "denoise": 1.0, "model": ["2", 0], @@ -372,7 +372,7 @@ def build_workflow( "height": 288, "video_frames": 24, "motion_bucket_id": 127, - "fps": 24, + "fps": 12, "augmentation_level": 0.0, }, "class_type": "SVD_img2vid_Conditioning", @@ -381,7 +381,7 @@ def build_workflow( "10": { "inputs": { "seed": seed + 1, - "steps": 20, + "steps": 10, "cfg": 2.5, "sampler_name": "euler", "scheduler": "karras", @@ -394,7 +394,7 @@ def build_workflow( "class_type": "KSampler", }, "11": {"inputs": {"samples": ["10", 0], "vae": ["1", 2]}, "class_type": "VAEDecode"}, - "12": {"inputs": {"images": ["11", 0], "fps": 24}, "class_type": "CreateVideo"}, + "12": {"inputs": {"images": ["11", 0], "fps": 12}, "class_type": "CreateVideo"}, "13": {"inputs": {"video": ["12", 0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, } @@ -412,11 +412,11 @@ def build_workflow( "clip_vision": ["1", 1], "init_image": ["2", 0], "vae": ["1", 2], - "width": 512, - "height": 288, - "video_frames": 24, + "width": 1024, + "height": 576, + "video_frames": 25, "motion_bucket_id": 127, - "fps": 24, + "fps": 6, "augmentation_level": 0.0, }, "class_type": "SVD_img2vid_Conditioning", @@ -438,7 +438,7 @@ def build_workflow( "class_type": "KSampler", }, "6": {"inputs": {"samples": ["5", 0], "vae": ["1", 2]}, "class_type": "VAEDecode"}, - "7": {"inputs": {"images": ["6", 0], "fps": 24}, "class_type": "CreateVideo"}, + "7": {"inputs": {"images": ["6", 0], "fps": 6}, 
"class_type": "CreateVideo"}, "8": {"inputs": {"video": ["7", 0], "filename_prefix": "video/{}".format(output_prefix), "format": "auto", "codec": "auto"}, "class_type": "SaveVideo"}, } diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index 1014870fe..5f13c085b 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -236,7 +236,7 @@ def action_generate_trigger(self, checked=True): preferred_checkpoint = DEFAULT_SD_CHECKPOINT if pipeline_id == "txt2audio-stable-open": preferred_checkpoint = DEFAULT_STABLE_AUDIO_CHECKPOINT - elif pipeline_id in ("txt2video-svd", "video2video-basic"): + elif pipeline_id == "video2video-basic": preferred_checkpoint = DEFAULT_SD_BASE_CHECKPOINT checkpoint_name = ( preferred_checkpoint if preferred_checkpoint in checkpoint_names else checkpoint_names[0] From 71951d9325c52b0c8ea9e44b0eeb4da243b1554d Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Sun, 15 Feb 2026 18:37:51 -0600 Subject: [PATCH 08/27] Large refactor of ComfyUI context menus, trying to simplify and improve integration with many of the AI tasks. Also, new icons. 
--- src/classes/generation_service.py | 45 ++++++-- .../cosmic/images/ai-action-captions.svg | 6 ++ .../cosmic/images/ai-action-create-audio.svg | 5 + .../cosmic/images/ai-action-create-image.svg | 5 + .../cosmic/images/ai-action-create-video.svg | 5 + .../cosmic/images/ai-action-restyle.svg | 4 + src/themes/cosmic/images/ai-action-scenes.svg | 6 ++ src/themes/cosmic/images/ai-action-smooth.svg | 5 + .../cosmic/images/ai-action-upscale.svg | 5 + .../cosmic/images/ai-category-create.svg | 5 + src/windows/generate.py | 16 ++- src/windows/main_window.py | 7 +- src/windows/models/files_model.py | 33 ++++-- src/windows/views/ai_tools_menu.py | 101 ++++++++++++++++++ src/windows/views/files_listview.py | 42 ++++++-- src/windows/views/files_treeview.py | 26 ++++- src/windows/views/timeline.py | 8 +- 17 files changed, 294 insertions(+), 30 deletions(-) create mode 100644 src/themes/cosmic/images/ai-action-captions.svg create mode 100644 src/themes/cosmic/images/ai-action-create-audio.svg create mode 100644 src/themes/cosmic/images/ai-action-create-image.svg create mode 100644 src/themes/cosmic/images/ai-action-create-video.svg create mode 100644 src/themes/cosmic/images/ai-action-restyle.svg create mode 100644 src/themes/cosmic/images/ai-action-scenes.svg create mode 100644 src/themes/cosmic/images/ai-action-smooth.svg create mode 100644 src/themes/cosmic/images/ai-action-upscale.svg create mode 100644 src/themes/cosmic/images/ai-category-create.svg create mode 100644 src/windows/views/ai_tools_menu.py diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index 5f13c085b..3ff6c5909 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -156,8 +156,16 @@ def _get_source_fps(self, source_file): return num / den return None - def action_generate_trigger(self, checked=True): - selected_files = self.win.selected_files() + def _default_generation_name(self, source_file): + default_name = "generation" + if 
source_file: + path = source_file.data.get("path", "") + if path: + default_name = "{}_gen".format(os.path.splitext(os.path.basename(path))[0]) + return default_name + + def action_generate_trigger(self, checked=True, source_file=None, template_id=None, open_dialog=True): + selected_files = [source_file] if source_file else self.win.selected_files() if len(selected_files) > 1: return @@ -173,11 +181,36 @@ def action_generate_trigger(self, checked=True): source_file = selected_files[0] if selected_files else None templates = available_pipelines(source_file=source_file) - win = GenerateMediaDialog(source_file=source_file, templates=templates, parent=self.win) - if win.exec_() != QDialog.Accepted: - return + available_template_ids = {str(t.get("id", "")).strip() for t in templates} + if open_dialog: + dialog_title = "Enhance with AI" if source_file else "Create with AI" + win = GenerateMediaDialog( + source_file=source_file, + templates=templates, + preselected_template_id=template_id, + dialog_title=dialog_title, + parent=self.win, + ) + if win.exec_() != QDialog.Accepted: + return + payload = win.get_payload() + else: + selected_template_id = str(template_id or "").strip() + if not selected_template_id: + return + if selected_template_id not in available_template_ids: + QMessageBox.information( + self.win, + "Invalid Input", + "The selected AI action is not available for this source type.", + ) + return + payload = { + "name": self._default_generation_name(source_file), + "template_id": selected_template_id, + "prompt": "", + } - payload = win.get_payload() payload_name = self._next_generation_name(payload.get("name")) source_file_id = source_file.id if source_file else None try: diff --git a/src/themes/cosmic/images/ai-action-captions.svg b/src/themes/cosmic/images/ai-action-captions.svg new file mode 100644 index 000000000..f434475da --- /dev/null +++ b/src/themes/cosmic/images/ai-action-captions.svg @@ -0,0 +1,6 @@ + + + + + + diff --git 
a/src/themes/cosmic/images/ai-action-create-audio.svg b/src/themes/cosmic/images/ai-action-create-audio.svg new file mode 100644 index 000000000..1f72fc6a3 --- /dev/null +++ b/src/themes/cosmic/images/ai-action-create-audio.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/themes/cosmic/images/ai-action-create-image.svg b/src/themes/cosmic/images/ai-action-create-image.svg new file mode 100644 index 000000000..63743d258 --- /dev/null +++ b/src/themes/cosmic/images/ai-action-create-image.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/themes/cosmic/images/ai-action-create-video.svg b/src/themes/cosmic/images/ai-action-create-video.svg new file mode 100644 index 000000000..f820a94b0 --- /dev/null +++ b/src/themes/cosmic/images/ai-action-create-video.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/themes/cosmic/images/ai-action-restyle.svg b/src/themes/cosmic/images/ai-action-restyle.svg new file mode 100644 index 000000000..ed62370bf --- /dev/null +++ b/src/themes/cosmic/images/ai-action-restyle.svg @@ -0,0 +1,4 @@ + + + + diff --git a/src/themes/cosmic/images/ai-action-scenes.svg b/src/themes/cosmic/images/ai-action-scenes.svg new file mode 100644 index 000000000..010b9a6cd --- /dev/null +++ b/src/themes/cosmic/images/ai-action-scenes.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/src/themes/cosmic/images/ai-action-smooth.svg b/src/themes/cosmic/images/ai-action-smooth.svg new file mode 100644 index 000000000..369860f53 --- /dev/null +++ b/src/themes/cosmic/images/ai-action-smooth.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/themes/cosmic/images/ai-action-upscale.svg b/src/themes/cosmic/images/ai-action-upscale.svg new file mode 100644 index 000000000..523d3f638 --- /dev/null +++ b/src/themes/cosmic/images/ai-action-upscale.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/themes/cosmic/images/ai-category-create.svg b/src/themes/cosmic/images/ai-category-create.svg new file mode 100644 index 000000000..ac9a400ba --- /dev/null +++ 
b/src/themes/cosmic/images/ai-category-create.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/windows/generate.py b/src/windows/generate.py index a4d6c8591..684df46cb 100644 --- a/src/windows/generate.py +++ b/src/windows/generate.py @@ -44,12 +44,20 @@ class GenerateMediaDialog(QDialog): PREVIEW_WIDTH = 180 PREVIEW_HEIGHT = 128 - def __init__(self, source_file=None, templates=None, parent=None): + def __init__( + self, + source_file=None, + templates=None, + preselected_template_id=None, + dialog_title=None, + parent=None, + ): super().__init__(parent) self.source_file = source_file self.templates = templates or [] + self.preselected_template_id = str(preselected_template_id or "").strip() self.setObjectName("generateDialog") - self.setWindowTitle("Generate") + self.setWindowTitle(str(dialog_title or "AI Tools")) self.setMinimumWidth(620) self.setMinimumHeight(460) @@ -118,6 +126,10 @@ def _build_top_block(self): self.template_combo.addItem(template.get("name", ""), template.get("id", "")) else: self.template_combo.addItem("Basic Text to Image", "txt2img-basic") + if self.preselected_template_id: + index = self.template_combo.findData(self.preselected_template_id) + if index >= 0: + self.template_combo.setCurrentIndex(index) setup_form.addRow("Template", self.template_combo) if self.source_file: diff --git a/src/windows/main_window.py b/src/windows/main_window.py index 1f76e9465..0c91c4355 100644 --- a/src/windows/main_window.py +++ b/src/windows/main_window.py @@ -1146,8 +1146,13 @@ def actionPreview_File_trigger(self, checked=True): """ Preview the selected media file """ log.info('actionPreview_File_trigger') - # Loop through selected files (set 1 selected file if more than 1) + # Prefer current file, but fall back to selected real files when a generation + # placeholder row has focus. 
f = self.files_model.current_file() + if not f: + selected_files = self.files_model.selected_files() + if selected_files: + f = selected_files[0] # Bail out if no file selected if not f: diff --git a/src/windows/models/files_model.py b/src/windows/models/files_model.py index 792db5b89..61a9be834 100644 --- a/src/windows/models/files_model.py +++ b/src/windows/models/files_model.py @@ -626,16 +626,25 @@ def selected_files(self): def current_file_id(self): """ Get the file ID of the current files-view item, or the first selection """ - cur = self.selection_model.currentIndex() - - if not cur or not cur.isValid() and self.selection_model.hasSelection(): - cur = self.selection_model.selectedIndexes()[0] + # Prefer selected rows first, since currentIndex can become stale when + # switching between details/list views with separate selection models. + selected_rows = self.selection_model.selectedRows(5) + if selected_rows: + current = self.selection_model.currentIndex() + if current and current.isValid(): + current_id = current.sibling(current.row(), 5).data() + if current_id and not self._is_generation_placeholder(current_id): + return current_id + for row_index in selected_rows: + file_id = row_index.data() + if file_id and not self._is_generation_placeholder(file_id): + return file_id + cur = self.selection_model.currentIndex() if cur and cur.isValid(): file_id = cur.sibling(cur.row(), 5).data() - if self._is_generation_placeholder(file_id): - return None - return file_id + if file_id and not self._is_generation_placeholder(file_id): + return file_id def current_file(self): """ Get the File object for the current files-view item, or the first selection """ @@ -664,14 +673,19 @@ def _sync_tree_to_list_selection(self, selected, deselected): try: # Map selected indexes from proxy_model to list_proxy_model list_selection = QItemSelection() + first_list_index = QModelIndex() for index in self.selection_model.selectedRows(0): list_index = 
self.list_proxy_model.mapFromSource(index) if list_index.isValid(): list_selection.select(list_index, list_index) + if not first_list_index.isValid(): + first_list_index = list_index self.list_selection_model.select( list_selection, QItemSelectionModel.ClearAndSelect | QItemSelectionModel.Rows ) + if first_list_index.isValid(): + self.list_selection_model.setCurrentIndex(first_list_index, QItemSelectionModel.NoUpdate) finally: self._syncing_selection = False @@ -683,14 +697,19 @@ def _sync_list_to_tree_selection(self, selected, deselected): try: # Map selected indexes from list_proxy_model to proxy_model tree_selection = QItemSelection() + first_tree_index = QModelIndex() for index in self.list_selection_model.selectedRows(0): tree_index = self.list_proxy_model.mapToSource(index) if tree_index.isValid(): tree_selection.select(tree_index, tree_index) + if not first_tree_index.isValid(): + first_tree_index = tree_index self.selection_model.select( tree_selection, QItemSelectionModel.ClearAndSelect | QItemSelectionModel.Rows ) + if first_tree_index.isValid(): + self.selection_model.setCurrentIndex(first_tree_index, QItemSelectionModel.NoUpdate) finally: self._syncing_selection = False diff --git a/src/windows/views/ai_tools_menu.py b/src/windows/views/ai_tools_menu.py new file mode 100644 index 000000000..50045aa5c --- /dev/null +++ b/src/windows/views/ai_tools_menu.py @@ -0,0 +1,101 @@ +""" + @file + @brief Shared AI Tools context-menu builder for project files and timeline. 
+""" + +import os +from functools import partial + +from PyQt5.QtGui import QIcon + +from classes.app import get_app +from classes import info +from .menu import StyledContextMenu + + +def _trigger_generation(win, template_id, source_file=None, open_dialog=False): + win.generation_service.action_generate_trigger( + source_file=source_file, + template_id=template_id, + open_dialog=open_dialog, + ) + + +def _icon(name): + icon_path = os.path.join(info.PATH, "themes", "cosmic", "images", name) + if os.path.exists(icon_path): + return QIcon(icon_path) + return QIcon() + + +def add_ai_tools_menu(win, parent_menu, source_file=None): + _ = get_app()._tr + media_type = str(source_file.data.get("media_type", "")) if source_file else "" + + if source_file: + ai_menu = StyledContextMenu(title=_("Enhance with AI"), parent=parent_menu) + ai_menu.setIcon(_icon("tool-generate-sparkle.svg")) + + if media_type == "image": + action = ai_menu.addAction(_("Increase Resolution (4x)")) + action.setIcon(_icon("ai-action-upscale.svg")) + action.triggered.connect( + partial(_trigger_generation, win, "upscale-realesrgan-x4", source_file, False) + ) + ai_menu.addSeparator() + action = ai_menu.addAction(_("Change Image Style...")) + action.setIcon(_icon("ai-action-restyle.svg")) + action.triggered.connect( + partial(_trigger_generation, win, "img2img-basic", source_file, True) + ) + parent_menu.addMenu(ai_menu) + return ai_menu + + elif media_type == "video": + action = ai_menu.addAction(_("Increase Resolution (4x)")) + action.setIcon(_icon("ai-action-upscale.svg")) + action.triggered.connect( + partial(_trigger_generation, win, "video-upscale-gan", source_file, False) + ) + action = ai_menu.addAction(_("Smooth Motion (2x Frame Rate)")) + action.setIcon(_icon("ai-action-smooth.svg")) + action.triggered.connect( + partial(_trigger_generation, win, "video-frame-interpolation-rife2x", source_file, False) + ) + action = ai_menu.addAction(_("Split into Scenes")) + 
action.setIcon(_icon("ai-action-scenes.svg")) + action.triggered.connect( + partial(_trigger_generation, win, "video-segment-scenes-transnet", source_file, False) + ) + action = ai_menu.addAction(_("Add Captions from Speech")) + action.setIcon(_icon("ai-action-captions.svg")) + action.triggered.connect( + partial(_trigger_generation, win, "video-whisper-srt", source_file, False) + ) + ai_menu.addSeparator() + action = ai_menu.addAction(_("Change Video Style...")) + action.setIcon(_icon("ai-action-restyle.svg")) + action.triggered.connect( + partial(_trigger_generation, win, "video2video-basic", source_file, True) + ) + else: + action = ai_menu.addAction(_("No AI enhancement actions available yet.")) + action.setEnabled(False) + + parent_menu.addMenu(ai_menu) + return ai_menu + + ai_menu = StyledContextMenu(title=_("Create with AI"), parent=parent_menu) + ai_menu.setIcon(_icon("tool-generate-sparkle.svg")) + action = ai_menu.addAction(_("Image...")) + action.setIcon(_icon("ai-action-create-image.svg")) + action.triggered.connect(partial(_trigger_generation, win, "txt2img-basic", source_file, True)) + action = ai_menu.addAction(_("Video...")) + action.setIcon(_icon("ai-action-create-video.svg")) + action.triggered.connect(partial(_trigger_generation, win, "txt2video-svd", source_file, True)) + action = ai_menu.addAction(_("Audio...")) + action.setIcon(_icon("ai-action-create-audio.svg")) + action.triggered.connect(partial(_trigger_generation, win, "txt2audio-stable-open", source_file, True)) + + parent_menu.addMenu(ai_menu) + return ai_menu diff --git a/src/windows/views/files_listview.py b/src/windows/views/files_listview.py index 048fae1d0..49c60e9d1 100644 --- a/src/windows/views/files_listview.py +++ b/src/windows/views/files_listview.py @@ -29,14 +29,15 @@ import os import uuid -from PyQt5.QtCore import QSize, Qt, QPoint, QRegExp -from PyQt5.QtGui import QDrag, QCursor, QPixmap, QPainter, QIcon, QColor +from PyQt5.QtCore import QSize, Qt, QPoint, QRegExp, 
QItemSelectionModel +from PyQt5.QtGui import QDrag, QCursor, QPixmap, QPainter, QIcon, QColor, QFontMetrics from PyQt5.QtWidgets import QListView, QAbstractItemView, QStyledItemDelegate, QStyleOptionViewItem, QStyle from classes import info from classes.app import get_app from classes.logger import log from classes.query import File +from .ai_tools_menu import add_ai_tools_menu from .menu import StyledContextMenu @@ -120,9 +121,21 @@ def paint(self, painter, option, index): painter.setBrush(QColor("#53A0ED")) painter.drawRect(fill_rect) if status == "queued": - text_rect = full_rect.adjusted(0, -14, 0, -4) - painter.setPen(QColor("#9EC8F7")) - painter.drawText(text_rect, Qt.AlignLeft | Qt.AlignVCenter, "Queued") + label = "Queued" + fm = QFontMetrics(painter.font()) + text_w = fm.horizontalAdvance(label) + text_h = fm.height() + pad_x = 5 + pad_y = 2 + badge_w = text_w + (pad_x * 2) + badge_h = text_h + (pad_y * 2) + badge_rect = deco_rect.adjusted(3, 3, 0, 0) + badge_rect.setWidth(badge_w) + badge_rect.setHeight(badge_h) + painter.setBrush(QColor(18, 22, 30, 220)) + painter.drawRoundedRect(badge_rect, 4, 4) + painter.setPen(QColor("#EAF5FF")) + painter.drawText(badge_rect, Qt.AlignCenter, label) painter.restore() @@ -140,12 +153,20 @@ def contextMenuEvent(self, event): app.context_menu_object = "files" index = self.indexAt(event.pos()) + if index.isValid(): + self.setCurrentIndex(index) + self.selectionModel().select( + index, + QItemSelectionModel.ClearAndSelect, + ) # Build menu menu = StyledContextMenu(parent=self) menu.addAction(self.win.actionImportFiles) + source_file = None + active_job = None file_id = None if index.isValid(): @@ -161,9 +182,11 @@ def contextMenuEvent(self, event): active_job = None else: active_job = self.win.active_generation_job_for_file(file_id) + source_file = File.get(id=file_id) + add_ai_tools_menu(self.win, menu, source_file=source_file) + if not active_job: self.win.actionGenerate.setEnabled(self.win.can_open_generate_dialog()) - 
menu.addAction(self.win.actionGenerate) if active_job: cancel_action = menu.addAction(_("Cancel Job")) delete_icon_path = os.path.join(info.PATH, "themes", "cosmic", "images", "track-delete-enabled.svg") @@ -232,6 +255,13 @@ def contextMenuEvent(self, event): def mouseDoubleClickEvent(self, event): super(FilesListView, self).mouseDoubleClickEvent(event) + index = self.indexAt(event.pos()) + if index.isValid(): + self.setCurrentIndex(index) + self.selectionModel().select( + index, + QItemSelectionModel.ClearAndSelect, + ) # Preview File, File Properties, or Split File (depending on Shift/Ctrl) if int(get_app().keyboardModifiers() & Qt.ShiftModifier) > 0: get_app().window.actionSplitFile.trigger() diff --git a/src/windows/views/files_treeview.py b/src/windows/views/files_treeview.py index 1f16c51f3..54a8e465e 100644 --- a/src/windows/views/files_treeview.py +++ b/src/windows/views/files_treeview.py @@ -31,13 +31,14 @@ import uuid from PyQt5.QtCore import QSize, Qt, QPoint -from PyQt5.QtGui import QDrag, QCursor, QPixmap, QPainter, QIcon, QColor +from PyQt5.QtGui import QDrag, QCursor, QPixmap, QPainter, QIcon, QColor, QFontMetrics from PyQt5.QtWidgets import QTreeView, QAbstractItemView, QSizePolicy, QHeaderView, QStyledItemDelegate, QStyleOptionViewItem, QStyle from classes import info from classes.app import get_app from classes.logger import log from classes.query import File +from .ai_tools_menu import add_ai_tools_menu from .menu import StyledContextMenu @@ -116,9 +117,21 @@ def paint(self, painter, option, index): painter.setBrush(QColor("#53A0ED")) painter.drawRect(fill_rect) if status == "queued": - text_rect = full_rect.adjusted(0, -14, 0, -4) - painter.setPen(QColor("#9EC8F7")) - painter.drawText(text_rect, Qt.AlignLeft | Qt.AlignVCenter, "Queued") + label = "Queued" + fm = QFontMetrics(painter.font()) + text_w = fm.horizontalAdvance(label) + text_h = fm.height() + pad_x = 5 + pad_y = 2 + badge_w = text_w + (pad_x * 2) + badge_h = text_h + (pad_y * 2) + 
badge_rect = deco_rect.adjusted(3, 3, 0, 0) + badge_rect.setWidth(badge_w) + badge_rect.setHeight(badge_h) + painter.setBrush(QColor(18, 22, 30, 220)) + painter.drawRoundedRect(badge_rect, 4, 4) + painter.setPen(QColor("#EAF5FF")) + painter.drawText(badge_rect, Qt.AlignCenter, label) painter.restore() @@ -142,6 +155,7 @@ def contextMenuEvent(self, event): menu.addAction(self.win.actionImportFiles) + source_file = None active_job = None file_id = None if index.isValid(): @@ -155,9 +169,11 @@ def contextMenuEvent(self, event): active_job = None else: active_job = self.win.active_generation_job_for_file(file_id) + source_file = File.get(id=file_id) + add_ai_tools_menu(self.win, menu, source_file=source_file) + if not active_job: self.win.actionGenerate.setEnabled(self.win.can_open_generate_dialog()) - menu.addAction(self.win.actionGenerate) if active_job: cancel_action = menu.addAction(_("Cancel Job")) delete_icon_path = os.path.join(info.PATH, "themes", "cosmic", "images", "track-delete-enabled.svg") diff --git a/src/windows/views/timeline.py b/src/windows/views/timeline.py index d7131ef9f..f3f8982aa 100644 --- a/src/windows/views/timeline.py +++ b/src/windows/views/timeline.py @@ -1057,7 +1057,7 @@ def ShowTimelineMenu(self, position, layer_number): if not has_clipboard and not found_gap: return - # Get track object (ignore locked tracks) + # Get track object (ignore locked tracks for edit operations) track = Track.get(number=layer_number) if not track: return @@ -1068,6 +1068,8 @@ def ShowTimelineMenu(self, position, layer_number): # New context menu menu = StyledContextMenu(parent=self) + has_edit_actions = False + if found_gap: # Add 'Remove Gap' Menu menu.addAction(self.window.actionRemoveGap) @@ -1079,8 +1081,7 @@ def ShowTimelineMenu(self, position, layer_number): self.window.actionRemoveGap.triggered.connect( partial(self.RemoveGap_Triggered, found_start, found_end, int(layer_number)) ) - if has_clipboard and found_gap: - menu.addSeparator() + 
has_edit_actions = True if has_clipboard: # Add 'Paste' Menu Paste_Clip = menu.addAction(_("Paste")) @@ -1088,6 +1089,7 @@ def ShowTimelineMenu(self, position, layer_number): Paste_Clip.triggered.connect( partial(self.Paste_Triggered, MenuCopy.PASTE, [], []) ) + has_edit_actions = True # Show context menu self.context_menu_cursor_position = QCursor.pos() From e08b0a39745965ff5124c3f92fe643758d66aa4a Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Mon, 16 Feb 2026 00:33:11 -0600 Subject: [PATCH 09/27] Fixing file multi-selections in this branch (so right click doesn't not clear them) --- src/windows/views/files_listview.py | 6 +----- src/windows/views/files_treeview.py | 4 +++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/windows/views/files_listview.py b/src/windows/views/files_listview.py index dad07a7a3..c43c7e410 100644 --- a/src/windows/views/files_listview.py +++ b/src/windows/views/files_listview.py @@ -156,11 +156,7 @@ def contextMenuEvent(self, event): if not index.isValid(): self.clearSelection() else: - self.setCurrentIndex(index) - self.selectionModel().select( - index, - QItemSelectionModel.ClearAndSelect, - ) + self.selectionModel().setCurrentIndex(index, QItemSelectionModel.NoUpdate) # Build menu menu = StyledContextMenu(parent=self) diff --git a/src/windows/views/files_treeview.py b/src/windows/views/files_treeview.py index 93ad4742a..2067dadfd 100644 --- a/src/windows/views/files_treeview.py +++ b/src/windows/views/files_treeview.py @@ -30,7 +30,7 @@ import os import uuid -from PyQt5.QtCore import QSize, Qt, QPoint +from PyQt5.QtCore import QSize, Qt, QPoint, QItemSelectionModel from PyQt5.QtGui import QDrag, QCursor, QPixmap, QPainter, QIcon, QColor, QFontMetrics from PyQt5.QtWidgets import QTreeView, QAbstractItemView, QSizePolicy, QHeaderView, QStyledItemDelegate, QStyleOptionViewItem, QStyle @@ -151,6 +151,8 @@ def contextMenuEvent(self, event): index = self.indexAt(event.pos()) if not index.isValid(): 
self.clearSelection() + else: + self.selectionModel().setCurrentIndex(index, QItemSelectionModel.NoUpdate) # Build menu menu = StyledContextMenu(parent=self) From ab7f0e801c33ec5584ae6d3af03275ca8321b00c Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Mon, 16 Feb 2026 18:49:45 -0600 Subject: [PATCH 10/27] Large refactor to support user-defined ComfyUI JSON files in .openshot_qt/comfyui/ folder, and write a debug.json file (what we send to ComfyUI for debugging purposes). Also split out all our built-in workflows as simple JSON files, with a few overrides to help menus/categories. Also now use the category icon when generating new items (audio icon, image icon, etc...). Menu is now dynamic, and will load user-defined JSON workflows as well. --- src/classes/assets.py | 6 + src/classes/comfy_client.py | 33 ++ src/classes/comfy_templates.py | 374 ++++++++++++++ src/classes/generation_service.py | 468 +++++++++++++----- src/classes/info.py | 2 + src/classes/project_data.py | 31 +- src/comfyui/img2img-basic.json | 107 ++++ src/comfyui/txt2audio-stable-open.json | 101 ++++ src/comfyui/txt2img-basic.json | 95 ++++ src/comfyui/txt2video-svd.json | 186 +++++++ src/comfyui/upscale-realesrgan-x4.json | 47 ++ .../video-frame-interpolation-rife2x.json | 70 +++ .../video-segment-scenes-transnet.json | 73 +++ src/comfyui/video-upscale-gan.json | 86 ++++ src/comfyui/video-whisper-srt.json | 57 +++ src/comfyui/video2video-basic.json | 146 ++++++ src/windows/main_window.py | 1 + src/windows/models/files_model.py | 35 +- src/windows/views/ai_tools_menu.py | 97 ++-- 19 files changed, 1818 insertions(+), 197 deletions(-) create mode 100644 src/classes/comfy_templates.py create mode 100644 src/comfyui/img2img-basic.json create mode 100644 src/comfyui/txt2audio-stable-open.json create mode 100644 src/comfyui/txt2img-basic.json create mode 100644 src/comfyui/txt2video-svd.json create mode 100644 src/comfyui/upscale-realesrgan-x4.json create mode 100644 
src/comfyui/video-frame-interpolation-rife2x.json create mode 100644 src/comfyui/video-segment-scenes-transnet.json create mode 100644 src/comfyui/video-upscale-gan.json create mode 100644 src/comfyui/video-whisper-srt.json create mode 100644 src/comfyui/video2video-basic.json diff --git a/src/classes/assets.py b/src/classes/assets.py index 559f4a609..0ae9a9676 100644 --- a/src/classes/assets.py +++ b/src/classes/assets.py @@ -90,6 +90,12 @@ def get_assets_path(file_path=None, create_paths=True): os.mkdir(asset_clipboard_folder) log.info("New clipboard folder: {}".format(asset_clipboard_folder)) + # Create asset ComfyUI output folder + asset_comfy_output_folder = os.path.join(asset_path, "comfyui-output") + if not os.path.exists(asset_comfy_output_folder): + os.mkdir(asset_comfy_output_folder) + log.info("New ComfyUI output folder: {}".format(asset_comfy_output_folder)) + return asset_path except Exception as ex: diff --git a/src/classes/comfy_client.py b/src/classes/comfy_client.py index da1542760..87aa611e6 100644 --- a/src/classes/comfy_client.py +++ b/src/classes/comfy_client.py @@ -30,6 +30,7 @@ import ssl import base64 import uuid +from datetime import datetime import re import socket import struct @@ -38,6 +39,7 @@ from urllib.parse import quote, urlencode from urllib.parse import urlparse +from classes import info from classes.logger import log @@ -253,6 +255,35 @@ class ComfyClient: def __init__(self, base_url): self.base_url = str(base_url or "").rstrip("/") + @staticmethod + def _write_debug_error(payload): + debug_dir = info.COMFYUI_PATH + try: + os.makedirs(debug_dir, exist_ok=True) + debug_path = os.path.join(debug_dir, "debug_error.json") + with open(debug_path, "w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + except Exception: + log.warning("Failed writing Comfy debug error payload", exc_info=True) + + def _write_debug_prompt_payload(self, prompt_graph, client_id): + debug_dir = 
info.COMFYUI_PATH + try: + os.makedirs(debug_dir, exist_ok=True) + debug_path = os.path.join(debug_dir, "debug.json") + payload = { + "generated_at_utc": datetime.utcnow().isoformat() + "Z", + "comfy_url": self.base_url, + "client_id": str(client_id or ""), + "prompt": prompt_graph, + } + with open(debug_path, "w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + except Exception: + log.warning("Failed writing Comfy sent prompt payload", exc_info=True) + @staticmethod def open_progress_socket(base_url, client_id): return ComfyProgressSocket(base_url, client_id) @@ -263,6 +294,7 @@ def ping(self, timeout=0.5): def queue_prompt(self, prompt_graph, client_id): prompt_graph = self._rewrite_prompt_local_file_inputs(prompt_graph) + self._write_debug_prompt_payload(prompt_graph, client_id) payload = json.dumps({"prompt": prompt_graph, "client_id": client_id}).encode("utf-8") req = Request( "{}/prompt".format(self.base_url), @@ -277,6 +309,7 @@ def queue_prompt(self, prompt_graph, client_id): details = "" try: error_data = json.loads(ex.read().decode("utf-8")) + ComfyClient._write_debug_error(error_data) error_obj = error_data.get("error", {}) if isinstance(error_obj, dict): details = error_obj.get("type") or error_obj.get("message") or "" diff --git a/src/classes/comfy_templates.py b/src/classes/comfy_templates.py new file mode 100644 index 000000000..19d40f2ea --- /dev/null +++ b/src/classes/comfy_templates.py @@ -0,0 +1,374 @@ +""" + @file + @brief ComfyUI workflow template discovery and classification helpers. 
+""" + +import copy +import json +import os + +from classes import info +from classes.logger import log + + +IMAGE_INPUT_TYPES = { + "loadimage", + "load image", +} +VIDEO_INPUT_TYPES = { + "loadvideo", + "load video", + "vhs_loadvideo", +} +AUDIO_INPUT_TYPES = { + "loadaudio", + "load audio", +} + +IMAGE_OUTPUT_TYPES = { + "saveimage", + "save image", +} +VIDEO_OUTPUT_TYPES = { + "savevideo", + "save video", +} +AUDIO_OUTPUT_TYPES = { + "saveaudio", + "save audio", +} + +KNOWN_NODE_TYPES = { + # Input + "checkpointloadersimple", + "cliptextencode", + "cliploader", + "loadimage", + "loadvideo", + "vhs_loadvideo", + "loadaudio", + # Core built-in/OpenShot workflows + "vaeencode", + "vaedecode", + "ksampler", + "upscalemodelloader", + "imageupscalewithmodel", + "videoslice", + "video slice", + "getvideocomponents", + "createvideo", + "saveimage", + "savevideo", + "saveaudio", + "save srt", + "emptylatentimage", + "imageonlycheckpointloader", + "svd_img2vid_conditioning", + "videolinearcfgguidance", + "emptylatentaudio", + "vaedecodeaudio", + "previewany", + "apply whisper", + "riff vfi", + "rife vfi", + "downloadandloadtransnetmodel", + "transnetv2_run", + "selectvideo", + "stableaudioprojectionmodel", + "stableaudiomodelloader", + "stableaudioemptylatentaudio", + "stableaudioembedding", + "kdiffusionsampler", + "stableaudiovaedecode", + "videocombine", + "imagescaleby", + "imagetosimage", + "imagetoimage", + "imagescaleto", +} + + +class ComfyTemplateRegistry: + """Discovers ComfyUI templates from built-in + user folders.""" + + def __init__(self): + self._cache = None + self._cache_signature = None + + @staticmethod + def _is_ignored_filename(name): + return str(name or "").strip().lower() in ("debug.json", "debug_error.json", "debug_sent.json") + + def _template_roots(self): + return [ + (os.path.join(info.PATH, "comfyui"), False), + (info.COMFYUI_PATH, True), + ] + + def _current_signature(self): + signature = [] + for folder, _is_user in self._template_roots(): 
+ if not os.path.isdir(folder): + continue + for name in sorted(os.listdir(folder)): + if not name.lower().endswith(".json"): + continue + if self._is_ignored_filename(name): + continue + path = os.path.join(folder, name) + try: + stat = os.stat(path) + signature.append((path, stat.st_mtime_ns, stat.st_size)) + except OSError: + continue + return tuple(signature) + + def discover(self, force=False): + signature = self._current_signature() + if not force and self._cache is not None and signature == self._cache_signature: + return self._cache + + templates = [] + existing_ids = set() + for folder, is_user in self._template_roots(): + if not os.path.isdir(folder): + continue + for name in sorted(os.listdir(folder)): + if not name.lower().endswith(".json"): + continue + if self._is_ignored_filename(name): + continue + path = os.path.join(folder, name) + template = self._load_template(path, is_user=is_user, existing_ids=existing_ids) + if template is None: + continue + templates.append(template) + + templates.sort(key=lambda t: (int(t.get("sort_order", 99999)), str(t.get("display_name", "")).lower())) + self._cache = templates + self._cache_signature = signature + return templates + + def _load_template(self, path, is_user, existing_ids): + try: + with open(path, "r", encoding="utf-8") as handle: + payload = json.load(handle) + except Exception as ex: + log.warning("Skipping invalid ComfyUI template JSON %s: %s", path, ex) + return None + if not isinstance(payload, dict): + log.warning("Skipping invalid ComfyUI template JSON %s: root must be an object", path) + return None + + workflow = self._extract_workflow(payload) + if workflow is None: + log.warning("Skipping invalid ComfyUI template JSON %s: no valid workflow graph found", path) + return None + + node_types = [] + input_types = set() + output_types = set() + unknown_node_types = set() + needs_prompt = False + for node in workflow.values(): + if not isinstance(node, dict): + continue + class_type = 
str(node.get("class_type", "")).strip() + if not class_type: + continue + inputs = node.get("inputs", {}) + if not isinstance(inputs, dict): + inputs = {} + class_key = class_type.lower().replace("_", "") + class_flat = class_type.lower().strip() + node_types.append(class_type) + + text_value = inputs.get("text", None) + if isinstance(text_value, str): + meta = node.get("_meta", {}) + meta_title = "" + if isinstance(meta, dict): + meta_title = str(meta.get("title", "")).strip().lower() + if "textencode" in class_key or "prompt" in meta_title: + needs_prompt = True + + if class_flat in IMAGE_INPUT_TYPES or class_key in IMAGE_INPUT_TYPES: + input_types.add("image") + if class_flat in VIDEO_INPUT_TYPES or class_key in VIDEO_INPUT_TYPES: + input_types.add("video") + if class_flat in AUDIO_INPUT_TYPES or class_key in AUDIO_INPUT_TYPES: + input_types.add("audio") + + if class_flat in IMAGE_OUTPUT_TYPES or class_key in IMAGE_OUTPUT_TYPES: + output_types.add("image") + if class_flat in VIDEO_OUTPUT_TYPES or class_key in VIDEO_OUTPUT_TYPES: + output_types.add("video") + if class_flat in AUDIO_OUTPUT_TYPES or class_key in AUDIO_OUTPUT_TYPES: + output_types.add("audio") + + if class_flat not in KNOWN_NODE_TYPES and class_key not in KNOWN_NODE_TYPES: + unknown_node_types.add(class_type) + + if unknown_node_types: + log.warning( + "ComfyUI template has unknown node types (%s): %s", + os.path.basename(path), + ", ".join(sorted(unknown_node_types)), + ) + + override_category = str(payload.get("menu_category") or payload.get("category") or "").strip().lower() + override_output_type = str(payload.get("output_type") or payload.get("media_output") or "").strip().lower() + override_icon = str(payload.get("action_icon") or payload.get("icon") or "").strip() + override_open_dialog = payload.get("open_dialog", None) + + inferred_category = "unknown" + requires_source = bool(input_types) + if output_types: + inferred_category = "enhance" if requires_source else "create" + else: + if 
override_category not in ("create", "enhance", "unknown"): + log.warning( + "ComfyUI template category unknown (%s): no output nodes detected", + os.path.basename(path), + ) + if override_category in ("create", "enhance", "unknown"): + inferred_category = override_category + + template_id = str(payload.get("template_id") or payload.get("id") or "").strip() + if not template_id: + template_id = os.path.splitext(os.path.basename(path))[0] + + unique_id = template_id + suffix = 2 + while unique_id in existing_ids: + unique_id = "{}__{}".format(template_id, suffix) + suffix += 1 + existing_ids.add(unique_id) + + display_name = self._extract_name(payload, path) + if is_user: + display_name = "(User) {}".format(display_name) + + try: + sort_order = int(payload.get("menu_order", 99999)) + except (TypeError, ValueError): + sort_order = 99999 + primary_output = self._primary_output_type(output_types) + if override_output_type in ("image", "video", "audio", "unknown"): + primary_output = override_output_type + + open_dialog = None + if isinstance(override_open_dialog, bool): + open_dialog = override_open_dialog + + return { + "id": unique_id, + "template_id": template_id, + "display_name": display_name, + "path": path, + "is_user": is_user, + "category": inferred_category, + "input_types": sorted(input_types), + "output_types": sorted(output_types), + "primary_output": primary_output, + "sort_order": sort_order, + "workflow": workflow, + "node_types": node_types, + "needs_prompt": needs_prompt, + "action_icon": override_icon, + "open_dialog": open_dialog, + } + + def _primary_output_type(self, output_types): + if "video" in output_types: + return "video" + if "image" in output_types: + return "image" + if "audio" in output_types: + return "audio" + return "unknown" + + def _extract_name(self, payload, path): + fields = [ + payload.get("name"), + payload.get("title"), + payload.get("workflow_name"), + ] + metadata = payload.get("metadata") + if isinstance(metadata, dict): + 
fields.extend([metadata.get("name"), metadata.get("title")]) + + for value in fields: + text = str(value or "").strip() + if text: + return text + + return os.path.splitext(os.path.basename(path))[0] + + def _extract_workflow(self, payload): + if self._looks_like_workflow(payload): + return payload + if isinstance(payload, dict): + workflow = payload.get("workflow") + if self._looks_like_workflow(workflow): + return workflow + return None + + def _looks_like_workflow(self, value): + if not isinstance(value, dict) or not value: + return False + for node in value.values(): + if isinstance(node, dict) and str(node.get("class_type", "")).strip(): + return True + return False + + def templates_for_context(self, source_file=None): + templates = self.discover() + media_type = "" + if source_file: + media_type = str(source_file.data.get("media_type", "")).strip().lower() + + filtered = [] + for template in templates: + category = str(template.get("category", "unknown")) + input_types = set(template.get("input_types", [])) + if source_file: + if category not in ("enhance", "unknown"): + continue + if input_types and media_type not in input_types: + continue + else: + if category not in ("create", "unknown"): + continue + if category == "unknown" and input_types: + continue + filtered.append(template) + return filtered + + def get_template(self, template_id): + template_id = str(template_id or "").strip() + if not template_id: + return None + for template in self.discover(): + if str(template.get("id")) == template_id: + return template + return None + + def get_workflow_copy(self, template_id): + template = self.get_template(template_id) + if not template: + return None + return copy.deepcopy(template.get("workflow") or {}) + + def output_icon_name(self, template): + explicit_icon = str((template or {}).get("action_icon") or "").strip() + if explicit_icon: + return explicit_icon + kind = str((template or {}).get("primary_output") or "unknown") + if kind == "video": + return 
"ai-action-create-video.svg" + if kind == "audio": + return "ai-action-create-audio.svg" + if kind == "image": + return "ai-action-create-image.svg" + return "tool-generate-sparkle.svg" diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index 3ff6c5909..890684206 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -39,8 +39,8 @@ from classes import time_parts from classes.app import get_app from classes.comfy_client import ComfyClient +from classes.comfy_templates import ComfyTemplateRegistry from classes.comfy_pipelines import ( - available_pipelines, build_workflow, is_supported_img2img_path, pipeline_requires_checkpoint, @@ -64,10 +64,25 @@ class GenerationService: """Encapsulates generation-specific UI + workflow behavior.""" + LEGACY_PIPELINE_IDS = { + "txt2img-basic", + "txt2video-svd", + "txt2audio-stable-open", + "img2img-basic", + "upscale-realesrgan-x4", + "img2video-svd", + "video-segment-scenes-transnet", + "video-frame-interpolation-rife2x", + "video-upscale-gan", + "video2video-basic", + "video-whisper-srt", + } + def __init__(self, win): self.win = win self._generation_temp_files = [] self._comfy_status_cache = {"checked_at": 0.0, "available": False} + self.template_registry = ComfyTemplateRegistry() def cleanup_temp_files(self): for tmp_path in list(self._generation_temp_files): @@ -164,6 +179,200 @@ def _default_generation_name(self, source_file): default_name = "{}_gen".format(os.path.splitext(os.path.basename(path))[0]) return default_name + def templates_for_context(self, source_file=None): + templates = self.template_registry.templates_for_context(source_file=source_file) + return [ + {"id": t.get("id"), "name": t.get("display_name"), "template": t} + for t in templates + ] + + def build_menu_templates(self, source_file=None): + grouped = {"create": [], "enhance": [], "unknown": []} + for template in self.template_registry.templates_for_context(source_file=source_file): + 
category = str(template.get("category", "unknown")) + if category not in grouped: + category = "unknown" + grouped[category].append(template) + return grouped + + def icon_for_template(self, template): + return self.template_registry.output_icon_name(template) + + def _prepare_nonlegacy_workflow(self, template, payload_name, prompt_text, source_file, source_path): + workflow = self.template_registry.get_workflow_copy(template.get("id")) + if not workflow: + raise ValueError("Template workflow not found.") + + template_dir = "" + template_path = str((template or {}).get("path") or "").strip() + if template_path: + template_dir = os.path.dirname(template_path) + + def _resolve_template_local_file(path_text): + path_text = str(path_text or "").strip() + if not path_text: + return "" + if os.path.isabs(path_text): + return path_text if os.path.exists(path_text) else "" + if not template_dir: + return "" + candidate = os.path.abspath(os.path.join(template_dir, path_text)) + if os.path.exists(candidate): + return candidate + return "" + + prompt_text = str(prompt_text or "").strip() + media_type = str(source_file.data.get("media_type", "")).strip().lower() if source_file else "" + applied_prompt = False + loadimage_node_ids = [] + loadvideo_node_ids = [] + loadaudio_node_ids = [] + + for node_id, node in workflow.items(): + if not isinstance(node, dict): + continue + class_flat = str(node.get("class_type", "")).strip().lower() + if class_flat == "loadimage": + loadimage_node_ids.append(str(node_id)) + elif class_flat in ("loadvideo", "load video", "vhs_loadvideo"): + loadvideo_node_ids.append(str(node_id)) + elif class_flat in ("loadaudio", "load audio"): + loadaudio_node_ids.append(str(node_id)) + + def _is_placeholder_value(path_text): + path_text = str(path_text or "").strip().lower() + return path_text in ("__openshot_input__", "{{openshot_input}}", "$openshot_input") + + def _select_bind_nodes(node_ids, path_key, preferred_upload=None): + explicit = [] + candidates 
= [] + for node_id in node_ids: + node = workflow.get(node_id, {}) + inputs = node.get("inputs", {}) if isinstance(node, dict) else {} + if not isinstance(inputs, dict): + continue + path_value = str(inputs.get(path_key, "")).strip() + upload_value = str(inputs.get("upload", "")).strip().lower() + if _is_placeholder_value(path_value): + explicit.append(node_id) + continue + if preferred_upload and upload_value == preferred_upload: + explicit.append(node_id) + continue + if not path_value: + candidates.append(node_id) + if explicit: + return set(explicit) + if candidates: + return {candidates[0]} + if node_ids: + return {node_ids[0]} + return set() + + image_bind_nodes = _select_bind_nodes(loadimage_node_ids, "image", preferred_upload="image") + video_bind_nodes = _select_bind_nodes(loadvideo_node_ids, "file") + audio_bind_nodes = _select_bind_nodes(loadaudio_node_ids, "audio") + + for node_id, node in workflow.items(): + if not isinstance(node, dict): + continue + class_type = str(node.get("class_type", "")).strip() + inputs = node.get("inputs", {}) + if not isinstance(inputs, dict): + continue + + class_flat = class_type.lower().strip() + + if "filename_prefix" in inputs: + prefix_value = str(inputs.get("filename_prefix", "")).strip() + if "/" in prefix_value: + head = prefix_value.rsplit("/", 1)[0] + inputs["filename_prefix"] = "{}/{}".format(head, payload_name) + else: + inputs["filename_prefix"] = payload_name + + if prompt_text and class_flat == "cliptextencode" and "text" in inputs and not applied_prompt: + inputs["text"] = prompt_text + applied_prompt = True + + if not source_path: + continue + + node_id = str(node_id) + + if class_flat == "loadimage" and media_type == "image" and node_id in image_bind_nodes: + if "image" in inputs: + inputs["image"] = source_path + if "upload" in inputs: + inputs["upload"] = "image" + elif class_flat == "loadimage": + # Resolve template-local reference images (relative filenames) to absolute paths, + # so ComfyClient can 
upload and rewrite them to [input] automatically. + configured_image = str(inputs.get("image", "")).strip() + local_image = _resolve_template_local_file(configured_image) + if local_image: + inputs["image"] = local_image + if "upload" in inputs: + inputs["upload"] = "image" + elif media_type == "image" and source_path: + # If additional reference images are missing, gracefully fallback to the selected source image. + # This keeps common exported workflows usable without hand-editing filenames. + missing_relative = configured_image and (not os.path.isabs(configured_image)) + missing_absolute = os.path.isabs(configured_image) and (not os.path.exists(configured_image)) + if missing_relative or missing_absolute: + inputs["image"] = source_path + if "upload" in inputs: + inputs["upload"] = "image" + log.warning( + "Comfy template missing LoadImage asset (%s). Falling back to selected source image.", + configured_image, + ) + elif class_flat in ("loadvideo", "load video") and media_type == "video" and node_id in video_bind_nodes: + if "file" in inputs: + inputs["file"] = source_path + elif "video" in inputs: + inputs["video"] = source_path + elif class_flat in ("loadvideo", "load video"): + local_video = _resolve_template_local_file(inputs.get("file", "")) + if local_video and "file" in inputs: + inputs["file"] = local_video + elif class_flat == "vhs_loadvideo" and media_type == "video" and node_id in video_bind_nodes: + if "video" in inputs: + inputs["video"] = source_path + elif class_flat in ("loadaudio", "load audio") and media_type == "audio" and node_id in audio_bind_nodes: + if "audio" in inputs: + inputs["audio"] = source_path + elif "file" in inputs: + inputs["file"] = source_path + elif class_flat in ("loadaudio", "load audio"): + local_audio = _resolve_template_local_file(inputs.get("audio", "") or inputs.get("file", "")) + if local_audio: + if "audio" in inputs: + inputs["audio"] = local_audio + elif "file" in inputs: + inputs["file"] = local_audio + + if 
media_type == "image" and image_bind_nodes: + log.debug("Comfy template image input binding nodes=%s source=%s", sorted(image_bind_nodes), source_path) + if media_type == "video" and video_bind_nodes: + log.debug("Comfy template video input binding nodes=%s source=%s", sorted(video_bind_nodes), source_path) + if media_type == "audio" and audio_bind_nodes: + log.debug("Comfy template audio input binding nodes=%s source=%s", sorted(audio_bind_nodes), source_path) + + return workflow + + def _save_nodes_for_workflow(self, workflow): + save_nodes = [] + for node_id, node in workflow.items(): + if not isinstance(node, dict): + continue + class_type = str(node.get("class_type", "")).strip().lower() + if not class_type: + continue + if class_type.startswith("save") or class_type in ("previewany", "transnetv2_run"): + save_nodes.append(str(node_id)) + return save_nodes + def action_generate_trigger(self, checked=True, source_file=None, template_id=None, open_dialog=True): selected_files = [source_file] if source_file else self.win.selected_files() if len(selected_files) > 1: @@ -180,7 +389,7 @@ def action_generate_trigger(self, checked=True, source_file=None, template_id=No return source_file = selected_files[0] if selected_files else None - templates = available_pipelines(source_file=source_file) + templates = self.templates_for_context(source_file=source_file) available_template_ids = {str(t.get("id", "")).strip() for t in templates} if open_dialog: dialog_title = "Enhance with AI" if source_file else "Create with AI" @@ -213,6 +422,10 @@ def action_generate_trigger(self, checked=True, source_file=None, template_id=No payload_name = self._next_generation_name(payload.get("name")) source_file_id = source_file.id if source_file else None + template_meta = self.template_registry.get_template(payload.get("template_id")) + if not template_meta: + QMessageBox.information(self.win, "Invalid Input", "The selected AI template was not found.") + return try: source_path = 
self._prepare_generation_source_path(source_file, payload.get("template_id")) except Exception as ex: @@ -261,144 +474,145 @@ def action_generate_trigger(self, checked=True, source_file=None, template_id=No ) return - try: - checkpoint_names = [] - if pipeline_requires_checkpoint(pipeline_id) or pipeline_requires_svd_checkpoint(pipeline_id): - checkpoint_names = client.list_checkpoints() - if checkpoint_names: - preferred_checkpoint = DEFAULT_SD_CHECKPOINT - if pipeline_id == "txt2audio-stable-open": - preferred_checkpoint = DEFAULT_STABLE_AUDIO_CHECKPOINT - elif pipeline_id == "video2video-basic": - preferred_checkpoint = DEFAULT_SD_BASE_CHECKPOINT - checkpoint_name = ( - preferred_checkpoint if preferred_checkpoint in checkpoint_names else checkpoint_names[0] - ) - if pipeline_requires_svd_checkpoint(pipeline_id): - if DEFAULT_SVD_CHECKPOINT in checkpoint_names: - svd_checkpoint_name = DEFAULT_SVD_CHECKPOINT - else: - # Prefer any checkpoint that appears to be an SVD model. - svd_candidates = [name for name in checkpoint_names if "svd" in str(name).lower()] - if svd_candidates: - svd_checkpoint_name = svd_candidates[0] - except Exception as ex: - log.warning("Failed to query ComfyUI checkpoints: %s", ex) + if pipeline_id in self.LEGACY_PIPELINE_IDS: + try: + checkpoint_names = [] + if pipeline_requires_checkpoint(pipeline_id) or pipeline_requires_svd_checkpoint(pipeline_id): + checkpoint_names = client.list_checkpoints() + if checkpoint_names: + preferred_checkpoint = DEFAULT_SD_CHECKPOINT + if pipeline_id == "txt2audio-stable-open": + preferred_checkpoint = DEFAULT_STABLE_AUDIO_CHECKPOINT + elif pipeline_id == "video2video-basic": + preferred_checkpoint = DEFAULT_SD_BASE_CHECKPOINT + checkpoint_name = ( + preferred_checkpoint if preferred_checkpoint in checkpoint_names else checkpoint_names[0] + ) + if pipeline_requires_svd_checkpoint(pipeline_id): + if DEFAULT_SVD_CHECKPOINT in checkpoint_names: + svd_checkpoint_name = DEFAULT_SVD_CHECKPOINT + else: + 
svd_candidates = [name for name in checkpoint_names if "svd" in str(name).lower()] + if svd_candidates: + svd_checkpoint_name = svd_candidates[0] + except Exception as ex: + log.warning("Failed to query ComfyUI checkpoints: %s", ex) - if pipeline_requires_checkpoint(pipeline_id) and not checkpoint_name: - QMessageBox.information( - self.win, - "No Checkpoints Found", - "ComfyUI has no checkpoints available for CheckpointLoaderSimple.\n" - "Add a model to ComfyUI/models/checkpoints and try again.", - ) - return + if pipeline_requires_checkpoint(pipeline_id) and not checkpoint_name: + QMessageBox.information( + self.win, + "No Checkpoints Found", + "ComfyUI has no checkpoints available for CheckpointLoaderSimple.\n" + "Add a model to ComfyUI/models/checkpoints and try again.", + ) + return - if pipeline_requires_svd_checkpoint(pipeline_id) and not svd_checkpoint_name: - QMessageBox.information( - self.win, - "No SVD Checkpoint Found", - "ComfyUI could not find the SVD checkpoint required for the selected video generation template.\n" - "Add an SVD checkpoint (for example {}) to ComfyUI/models/checkpoints and try again.".format(DEFAULT_SVD_CHECKPOINT), - ) - return + if pipeline_requires_svd_checkpoint(pipeline_id) and not svd_checkpoint_name: + QMessageBox.information( + self.win, + "No SVD Checkpoint Found", + "ComfyUI could not find the SVD checkpoint required for the selected video generation template.\n" + "Add an SVD checkpoint (for example {}) to ComfyUI/models/checkpoints and try again.".format(DEFAULT_SVD_CHECKPOINT), + ) + return - try: - if pipeline_requires_upscale_model(pipeline_id): - upscale_models = client.list_upscale_models() - if upscale_models: - upscale_model_name = ( - DEFAULT_UPSCALE_MODEL if DEFAULT_UPSCALE_MODEL in upscale_models else upscale_models[0] - ) - except Exception as ex: - log.warning("Failed to query ComfyUI upscale models: %s", ex) + try: + if pipeline_requires_upscale_model(pipeline_id): + upscale_models = 
client.list_upscale_models() + if upscale_models: + upscale_model_name = ( + DEFAULT_UPSCALE_MODEL if DEFAULT_UPSCALE_MODEL in upscale_models else upscale_models[0] + ) + except Exception as ex: + log.warning("Failed to query ComfyUI upscale models: %s", ex) - if pipeline_requires_upscale_model(pipeline_id) and not upscale_model_name: - QMessageBox.information( - self.win, - "No Upscale Models Found", - "ComfyUI has no upscaler models available for UpscaleModelLoader.\n" - "Add a model such as RealESRGAN_x4plus.safetensors to ComfyUI/models/upscale_models and try again.", - ) - return + if pipeline_requires_upscale_model(pipeline_id) and not upscale_model_name: + QMessageBox.information( + self.win, + "No Upscale Models Found", + "ComfyUI has no upscaler models available for UpscaleModelLoader.\n" + "Add a model such as RealESRGAN_x4plus.safetensors to ComfyUI/models/upscale_models and try again.", + ) + return - try: - if pipeline_requires_stable_audio_clip(pipeline_id): - clip_names = client.list_clip_models() - if clip_names: - for preferred in (DEFAULT_STABLE_AUDIO_CLIP, "t5_base.safetensors"): - if preferred in clip_names: - stable_audio_clip_name = preferred - break - if not stable_audio_clip_name: - stable_audio_clip_name = clip_names[0] - except Exception as ex: - log.warning("Failed to query ComfyUI CLIP models: %s", ex) + try: + if pipeline_requires_stable_audio_clip(pipeline_id): + clip_names = client.list_clip_models() + if clip_names: + for preferred in (DEFAULT_STABLE_AUDIO_CLIP, "t5_base.safetensors"): + if preferred in clip_names: + stable_audio_clip_name = preferred + break + if not stable_audio_clip_name: + stable_audio_clip_name = clip_names[0] + except Exception as ex: + log.warning("Failed to query ComfyUI CLIP models: %s", ex) - if pipeline_requires_stable_audio_clip(pipeline_id) and not stable_audio_clip_name: - QMessageBox.information( - self.win, - "No Text Encoders Found", - "ComfyUI has no CLIP/text-encoder models available for 
CLIPLoader.\n" - "Add a text encoder such as t5-base.safetensors and try again.", - ) - return + if pipeline_requires_stable_audio_clip(pipeline_id) and not stable_audio_clip_name: + QMessageBox.information( + self.win, + "No Text Encoders Found", + "ComfyUI has no CLIP/text-encoder models available for CLIPLoader.\n" + "Add a text encoder such as t5-base.safetensors and try again.", + ) + return - try: - if pipeline_requires_rife_model(pipeline_id): - rife_models = client.list_rife_vfi_models() - if rife_models: - for preferred in (DEFAULT_RIFE_VFI_MODEL, "rife49.pth"): - if preferred in rife_models: - rife_model_name = preferred - break - if not rife_model_name: - rife_model_name = rife_models[0] - except Exception as ex: - log.warning("Failed to query ComfyUI RIFE VFI models: %s", ex) + try: + if pipeline_requires_rife_model(pipeline_id): + rife_models = client.list_rife_vfi_models() + if rife_models: + for preferred in (DEFAULT_RIFE_VFI_MODEL, "rife49.pth"): + if preferred in rife_models: + rife_model_name = preferred + break + if not rife_model_name: + rife_model_name = rife_models[0] + except Exception as ex: + log.warning("Failed to query ComfyUI RIFE VFI models: %s", ex) - if pipeline_requires_rife_model(pipeline_id) and not rife_model_name: - QMessageBox.information( - self.win, - "RIFE VFI Not Available", - "ComfyUI could not find the RIFE VFI node/models required for frame interpolation.\n" - "Install ComfyUI-Frame-Interpolation and add models such as rife47.pth.", - ) - return + if pipeline_requires_rife_model(pipeline_id) and not rife_model_name: + QMessageBox.information( + self.win, + "RIFE VFI Not Available", + "ComfyUI could not find the RIFE VFI node/models required for frame interpolation.\n" + "Install ComfyUI-Frame-Interpolation and add models such as rife47.pth.", + ) + return - try: - workflow = build_workflow( - pipeline_id, - payload.get("prompt"), - workflow_source, - payload_name, - checkpoint_name=checkpoint_name, - 
upscale_model_name=upscale_model_name, - stable_audio_clip_name=stable_audio_clip_name, - svd_checkpoint_name=svd_checkpoint_name, - source_fps=self._get_source_fps(source_file), - rife_model_name=rife_model_name, - ) - except Exception as ex: - QMessageBox.information(self.win, "Invalid Input", str(ex)) - return + try: + workflow = build_workflow( + pipeline_id, + payload.get("prompt"), + workflow_source, + payload_name, + checkpoint_name=checkpoint_name, + upscale_model_name=upscale_model_name, + stable_audio_clip_name=stable_audio_clip_name, + svd_checkpoint_name=svd_checkpoint_name, + source_fps=self._get_source_fps(source_file), + rife_model_name=rife_model_name, + ) + except Exception as ex: + QMessageBox.information(self.win, "Invalid Input", str(ex)) + return + else: + try: + workflow = self._prepare_nonlegacy_workflow( + template_meta, + payload_name=payload_name, + prompt_text=payload.get("prompt"), + source_file=source_file, + source_path=source_path, + ) + except Exception as ex: + QMessageBox.information(self.win, "Invalid Input", str(ex)) + return request = { "comfy_url": self.comfy_ui_url(), "workflow": workflow, "client_id": "openshot-qt", "timeout_s": 21600, - "save_node_ids": [ - str(node_id) - for node_id, node in workflow.items() - if node.get("class_type") in ( - "SaveImage", - "SaveVideo", - "SaveAudio", - "Save SRT", - "PreviewAny", - "TransNetV2_Run", - ) - ], + "save_node_ids": self._save_nodes_for_workflow(workflow), } job_id = self.win.generation_queue.enqueue( payload_name, @@ -464,7 +678,7 @@ def _import_generation_outputs(self, job): request = job.get("request", {}) or {} comfy_url = str(request.get("comfy_url") or self.comfy_ui_url()) client = ComfyClient(comfy_url) - output_dir = os.path.join(info.USER_PATH, "comfy_outputs") + output_dir = info.COMFYUI_OUTPUT_PATH os.makedirs(output_dir, exist_ok=True) name_raw = str(job.get("name") or "generation") diff --git a/src/classes/info.py b/src/classes/info.py index 0cb26d304..07e746dfd 
100644 --- a/src/classes/info.py +++ b/src/classes/info.py @@ -63,6 +63,8 @@ USER_PROFILES_PATH = os.path.join(USER_PATH, "profiles") USER_PRESETS_PATH = os.path.join(USER_PATH, "presets") USER_TITLES_PATH = os.path.join(USER_PATH, "title_templates") +COMFYUI_PATH = os.path.join(USER_PATH, "comfyui") +COMFYUI_OUTPUT_PATH = os.path.join(USER_PATH, "comfyui-output") USER_COLORS_PATH = os.path.join(USER_PATH, "colors") PROTOBUF_DATA_PATH = os.path.join(USER_PATH, "protobuf_data") YOLO_PATH = os.path.join(USER_PATH, "yolo") diff --git a/src/classes/project_data.py b/src/classes/project_data.py index 15cda3ac1..9b0f20559 100644 --- a/src/classes/project_data.py +++ b/src/classes/project_data.py @@ -420,6 +420,7 @@ def load(self, file_path, clear_thumbnails=True): info.BLENDER_PATH = os.path.join(get_assets_path(self.current_filepath), "blender") info.PROTOBUF_DATA_PATH = os.path.join(get_assets_path(self.current_filepath), "protobuf_data") info.CLIPBOARD_PATH = os.path.join(get_assets_path(self.current_filepath), "clipboard") + info.COMFYUI_OUTPUT_PATH = os.path.join(get_assets_path(self.current_filepath), "comfyui-output") # Clear needs save flag self.has_unsaved_changes = False @@ -897,6 +898,7 @@ def save(self, file_path, backup_only=False): info.TITLE_PATH = os.path.join(get_assets_path(self.current_filepath), "title") info.BLENDER_PATH = os.path.join(get_assets_path(self.current_filepath), "blender") info.CLIPBOARD_PATH = os.path.join(get_assets_path(self.current_filepath), "clipboard") + info.COMFYUI_OUTPUT_PATH = os.path.join(get_assets_path(self.current_filepath), "comfyui-output") self.add_to_recent_files(file_path) self.has_unsaved_changes = False @@ -911,12 +913,13 @@ def move_temp_paths_to_project_folder(self, file_path, previous_path=None): target_blender_path = os.path.join(asset_path, "blender") target_protobuf_path = os.path.join(asset_path, "protobuf_data") target_clipboard_path = os.path.join(asset_path, "clipboard") + target_comfy_output_path = 
os.path.join(asset_path, "comfyui-output") # Create any missing target paths try: for target_dir in [asset_path, target_thumb_path, target_title_path, target_blender_path, target_protobuf_path, - target_clipboard_path]: + target_clipboard_path, target_comfy_output_path]: if not os.path.exists(target_dir): os.mkdir(target_dir) except OSError: @@ -931,12 +934,14 @@ def move_temp_paths_to_project_folder(self, file_path, previous_path=None): info.BLENDER_PATH = os.path.join(previous_asset_path, "blender") info.PROTOBUF_DATA_PATH = os.path.join(previous_asset_path, "protobuf_data") info.CLIPBOARD_PATH = os.path.join(previous_asset_path, "clipboard") + info.COMFYUI_OUTPUT_PATH = os.path.join(previous_asset_path, "comfyui-output") # Track assets we copy/update copied_assets = { "blender": set(), "title": set(), "clipboard": set(), + "comfyui_output": set(), } reader_paths = {} @@ -974,6 +979,20 @@ def move_temp_paths_to_project_folder(self, file_path, previous_path=None): if not os.path.exists(target_clipboard_filepath): shutil.copy2(working_clipboard_path, target_clipboard_filepath) + # Copy all ComfyUI output files/folders (fully) to assets folder + if os.path.exists(info.COMFYUI_OUTPUT_PATH) and ( + os.path.abspath(info.COMFYUI_OUTPUT_PATH) != os.path.abspath(target_comfy_output_path) + ): + for output_name in os.listdir(info.COMFYUI_OUTPUT_PATH): + working_output_path = os.path.join(info.COMFYUI_OUTPUT_PATH, output_name) + target_output_path = os.path.join(target_comfy_output_path, output_name) + if os.path.isdir(working_output_path): + if os.path.exists(target_output_path): + shutil.rmtree(target_output_path, True) + shutil.copytree(working_output_path, target_output_path) + else: + shutil.copy2(working_output_path, target_output_path) + # Copy all protobuf files (if not found in target asset folder) if os.path.abspath(info.PROTOBUF_DATA_PATH) != os.path.abspath(target_protobuf_path): for protobuf_path in os.listdir(info.PROTOBUF_DATA_PATH): @@ -1021,6 +1040,16 @@ 
def move_temp_paths_to_project_folder(self, file_path, previous_path=None): log.info("Copied clipboard %s to %s", asset_name, target_clipboard_path) new_asset_path = os.path.join(target_clipboard_path, asset_name) + comfy_output_abs = os.path.abspath(info.COMFYUI_OUTPUT_PATH) + path_abs = os.path.abspath(path) + if path_abs.startswith(comfy_output_abs + os.sep): + if os.path.abspath(os.path.dirname(path)) != os.path.abspath(target_comfy_output_path): + relative_output_path = os.path.relpath(path_abs, comfy_output_abs) + if relative_output_path not in copied_assets["comfyui_output"]: + copied_assets["comfyui_output"].add(relative_output_path) + log.info("Copied ComfyUI output %s to %s", relative_output_path, target_comfy_output_path) + new_asset_path = os.path.join(target_comfy_output_path, relative_output_path) + # Update path in File object to new location if new_asset_path: file["path"] = new_asset_path diff --git a/src/comfyui/img2img-basic.json b/src/comfyui/img2img-basic.json new file mode 100644 index 000000000..3a0f26ff3 --- /dev/null +++ b/src/comfyui/img2img-basic.json @@ -0,0 +1,107 @@ +{ + "action_icon": "ai-action-restyle.svg", + "menu_category": "enhance", + "menu_order": 20, + "name": "Change Image Style...", + "open_dialog": true, + "output_type": "image", + "template_id": "img2img-basic", + "workflow": { + "1": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "sd_xl_turbo_1.0_fp16.safetensors" + } + }, + "2": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "1", + 1 + ], + "text": "cinematic shot, highly detailed" + } + }, + "3": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "1", + 1 + ], + "text": "low quality, blurry" + } + }, + "4": { + "class_type": "LoadImage", + "inputs": { + "image": "/tmp/input.png", + "upload": "image" + } + }, + "5": { + "class_type": "VAEEncode", + "inputs": { + "pixels": [ + "4", + 0 + ], + "vae": [ + "1", + 2 + ] + } + }, + "6": { + "class_type": "KSampler", + 
"inputs": { + "cfg": 7.0, + "denoise": 0.65, + "latent_image": [ + "5", + 0 + ], + "model": [ + "1", + 0 + ], + "negative": [ + "3", + 0 + ], + "positive": [ + "2", + 0 + ], + "sampler_name": "euler", + "scheduler": "normal", + "seed": 1293041938, + "steps": 20 + } + }, + "7": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "6", + 0 + ], + "vae": [ + "1", + 2 + ] + } + }, + "8": { + "class_type": "SaveImage", + "inputs": { + "filename_prefix": "openshot_gen", + "images": [ + "7", + 0 + ] + } + } + } +} diff --git a/src/comfyui/txt2audio-stable-open.json b/src/comfyui/txt2audio-stable-open.json new file mode 100644 index 000000000..7c9cbe299 --- /dev/null +++ b/src/comfyui/txt2audio-stable-open.json @@ -0,0 +1,101 @@ +{ + "action_icon": "ai-action-create-audio.svg", + "menu_category": "create", + "menu_order": 30, + "name": "Audio...", + "open_dialog": true, + "output_type": "audio", + "template_id": "txt2audio-stable-open", + "workflow": { + "10": { + "class_type": "CLIPLoader", + "inputs": { + "clip_name": "t5-base.safetensors", + "type": "stable_audio" + } + }, + "11": { + "class_type": "EmptyLatentAudio", + "inputs": { + "batch_size": 1, + "seconds": 30.0 + } + }, + "12": { + "class_type": "VAEDecodeAudio", + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + } + }, + "13": { + "class_type": "SaveAudio", + "inputs": { + "audio": [ + "12", + 0 + ], + "filename_prefix": "audio/openshot_gen" + } + }, + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": 5.0, + "denoise": 1.0, + "latent_image": [ + "11", + 0 + ], + "model": [ + "4", + 0 + ], + "negative": [ + "7", + 0 + ], + "positive": [ + "6", + 0 + ], + "sampler_name": "dpmpp_3m_sde_gpu", + "scheduler": "exponential", + "seed": 806751699, + "steps": 50 + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "sd_xl_turbo_1.0_fp16.safetensors" + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "10", + 0 + ], + 
"text": "lofi ambient beat, soft texture, 90 bpm" + } + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "10", + 0 + ], + "text": "" + } + } + } +} diff --git a/src/comfyui/txt2img-basic.json b/src/comfyui/txt2img-basic.json new file mode 100644 index 000000000..5a656854b --- /dev/null +++ b/src/comfyui/txt2img-basic.json @@ -0,0 +1,95 @@ +{ + "action_icon": "ai-action-create-image.svg", + "menu_category": "create", + "menu_order": 10, + "name": "Image...", + "open_dialog": true, + "output_type": "image", + "template_id": "txt2img-basic", + "workflow": { + "1": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "sd_xl_turbo_1.0_fp16.safetensors" + } + }, + "2": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "1", + 1 + ], + "text": "cinematic shot, highly detailed" + } + }, + "3": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "1", + 1 + ], + "text": "low quality, blurry" + } + }, + "4": { + "class_type": "EmptyLatentImage", + "inputs": { + "batch_size": 1, + "height": 576, + "width": 1024 + } + }, + "5": { + "class_type": "KSampler", + "inputs": { + "cfg": 7.0, + "denoise": 1.0, + "latent_image": [ + "4", + 0 + ], + "model": [ + "1", + 0 + ], + "negative": [ + "3", + 0 + ], + "positive": [ + "2", + 0 + ], + "sampler_name": "euler", + "scheduler": "normal", + "seed": 687962524, + "steps": 20 + } + }, + "6": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "5", + 0 + ], + "vae": [ + "1", + 2 + ] + } + }, + "7": { + "class_type": "SaveImage", + "inputs": { + "filename_prefix": "openshot_gen", + "images": [ + "6", + 0 + ] + } + } + } +} diff --git a/src/comfyui/txt2video-svd.json b/src/comfyui/txt2video-svd.json new file mode 100644 index 000000000..73deac8e1 --- /dev/null +++ b/src/comfyui/txt2video-svd.json @@ -0,0 +1,186 @@ +{ + "action_icon": "ai-action-create-video.svg", + "menu_category": "create", + "menu_order": 20, + "name": "Video...", + "open_dialog": true, + 
"output_type": "video", + "template_id": "txt2video-svd", + "workflow": { + "1": { + "class_type": "ImageOnlyCheckpointLoader", + "inputs": { + "ckpt_name": "svd_xt.safetensors" + } + }, + "10": { + "class_type": "KSampler", + "inputs": { + "cfg": 2.5, + "denoise": 1.0, + "latent_image": [ + "8", + 2 + ], + "model": [ + "9", + 0 + ], + "negative": [ + "8", + 1 + ], + "positive": [ + "8", + 0 + ], + "sampler_name": "euler", + "scheduler": "karras", + "seed": 1825708738, + "steps": 10 + } + }, + "11": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "10", + 0 + ], + "vae": [ + "1", + 2 + ] + } + }, + "12": { + "class_type": "CreateVideo", + "inputs": { + "fps": 12, + "images": [ + "11", + 0 + ] + } + }, + "13": { + "class_type": "SaveVideo", + "inputs": { + "codec": "auto", + "filename_prefix": "video/openshot_gen", + "format": "auto", + "video": [ + "12", + 0 + ] + } + }, + "2": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "sd_xl_turbo_1.0_fp16.safetensors" + } + }, + "3": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "2", + 1 + ], + "text": "cinematic shot, highly detailed" + } + }, + "4": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "2", + 1 + ], + "text": "low quality, blurry" + } + }, + "5": { + "class_type": "EmptyLatentImage", + "inputs": { + "batch_size": 1, + "height": 288, + "width": 512 + } + }, + "6": { + "class_type": "KSampler", + "inputs": { + "cfg": 6.0, + "denoise": 1.0, + "latent_image": [ + "5", + 0 + ], + "model": [ + "2", + 0 + ], + "negative": [ + "4", + 0 + ], + "positive": [ + "3", + 0 + ], + "sampler_name": "euler", + "scheduler": "normal", + "seed": 1825708737, + "steps": 8 + } + }, + "7": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "6", + 0 + ], + "vae": [ + "2", + 2 + ] + } + }, + "8": { + "class_type": "SVD_img2vid_Conditioning", + "inputs": { + "augmentation_level": 0.0, + "clip_vision": [ + "1", + 1 + ], + "fps": 12, + "height": 288, + 
"init_image": [ + "7", + 0 + ], + "motion_bucket_id": 127, + "vae": [ + "1", + 2 + ], + "video_frames": 24, + "width": 512 + } + }, + "9": { + "class_type": "VideoLinearCFGGuidance", + "inputs": { + "min_cfg": 1.0, + "model": [ + "1", + 0 + ] + } + } + } +} diff --git a/src/comfyui/upscale-realesrgan-x4.json b/src/comfyui/upscale-realesrgan-x4.json new file mode 100644 index 000000000..b25cfd30b --- /dev/null +++ b/src/comfyui/upscale-realesrgan-x4.json @@ -0,0 +1,47 @@ +{ + "action_icon": "ai-action-upscale.svg", + "menu_category": "enhance", + "menu_order": 10, + "name": "Increase Resolution (4x)", + "open_dialog": false, + "output_type": "image", + "template_id": "upscale-realesrgan-x4", + "workflow": { + "1": { + "class_type": "LoadImage", + "inputs": { + "image": "/tmp/input.png", + "upload": "image" + } + }, + "2": { + "class_type": "UpscaleModelLoader", + "inputs": { + "model_name": "RealESRGAN_x4plus.safetensors" + } + }, + "3": { + "class_type": "ImageUpscaleWithModel", + "inputs": { + "image": [ + "1", + 0 + ], + "upscale_model": [ + "2", + 0 + ] + } + }, + "4": { + "class_type": "SaveImage", + "inputs": { + "filename_prefix": "openshot_gen", + "images": [ + "3", + 0 + ] + } + } + } +} diff --git a/src/comfyui/video-frame-interpolation-rife2x.json b/src/comfyui/video-frame-interpolation-rife2x.json new file mode 100644 index 000000000..c9554d6ad --- /dev/null +++ b/src/comfyui/video-frame-interpolation-rife2x.json @@ -0,0 +1,70 @@ +{ + "action_icon": "ai-action-smooth.svg", + "menu_category": "enhance", + "menu_order": 20, + "name": "Smooth Motion (2x Frame Rate)", + "open_dialog": false, + "output_type": "video", + "template_id": "video-frame-interpolation-rife2x", + "workflow": { + "1": { + "class_type": "LoadVideo", + "inputs": { + "file": "/tmp/input.mp4" + } + }, + "2": { + "class_type": "GetVideoComponents", + "inputs": { + "video": [ + "1", + 0 + ] + } + }, + "3": { + "_meta": { + "title": "RIFE VFI (recommend rife47 and rife49)" + }, + 
"class_type": "RIFE VFI", + "inputs": { + "ckpt_name": "rife47.pth", + "clear_cache_after_n_frames": 10, + "ensemble": true, + "fast_mode": true, + "frames": [ + "2", + 0 + ], + "multiplier": 2, + "scale_factor": 1 + } + }, + "4": { + "class_type": "CreateVideo", + "inputs": { + "audio": [ + "2", + 1 + ], + "fps": 60.0, + "images": [ + "3", + 0 + ] + } + }, + "5": { + "class_type": "SaveVideo", + "inputs": { + "codec": "auto", + "filename_prefix": "video/openshot_gen", + "format": "auto", + "video": [ + "4", + 0 + ] + } + } + } +} diff --git a/src/comfyui/video-segment-scenes-transnet.json b/src/comfyui/video-segment-scenes-transnet.json new file mode 100644 index 000000000..40a8c1970 --- /dev/null +++ b/src/comfyui/video-segment-scenes-transnet.json @@ -0,0 +1,73 @@ +{ + "action_icon": "ai-action-scenes.svg", + "menu_category": "enhance", + "menu_order": 30, + "name": "Split into Scenes", + "open_dialog": false, + "output_type": "video", + "template_id": "video-segment-scenes-transnet", + "workflow": { + "1": { + "_meta": { + "title": "MiaoshouAI Segment Video" + }, + "class_type": "TransNetV2_Run", + "inputs": { + "TransNet_model": [ + "2", + 0 + ], + "min_scene_length": 30, + "output_dir": "output", + "threshold": 0.5, + "video": [ + "7", + 0 + ] + } + }, + "2": { + "_meta": { + "title": "MiaoshouAI Load TransNet Model" + }, + "class_type": "DownloadAndLoadTransNetModel", + "inputs": { + "device": "auto", + "model": "transnetv2-pytorch-weights" + } + }, + "7": { + "class_type": "LoadVideo", + "inputs": { + "file": "/tmp/input.mp4" + } + }, + "8": { + "_meta": { + "title": "MiaoshouAI Select Video" + }, + "class_type": "SelectVideo", + "inputs": { + "index": 0, + "segment_paths": [ + "1", + 0 + ] + } + }, + "9": { + "_meta": { + "title": "Preview Any" + }, + "class_type": "PreviewAny", + "inputs": { + "preview": "", + "previewMode": null, + "source": [ + "1", + 0 + ] + } + } + } +} diff --git a/src/comfyui/video-upscale-gan.json 
b/src/comfyui/video-upscale-gan.json new file mode 100644 index 000000000..9ffb0dabf --- /dev/null +++ b/src/comfyui/video-upscale-gan.json @@ -0,0 +1,86 @@ +{ + "action_icon": "ai-action-upscale.svg", + "menu_category": "enhance", + "menu_order": 10, + "name": "Increase Resolution (4x)", + "open_dialog": false, + "output_type": "video", + "template_id": "video-upscale-gan", + "workflow": { + "1": { + "class_type": "LoadVideo", + "inputs": { + "file": "/tmp/input.mp4" + } + }, + "2": { + "class_type": "Video Slice", + "inputs": { + "duration": 10.0, + "start_time": 0.0, + "strict_duration": false, + "video": [ + "1", + 0 + ] + } + }, + "3": { + "class_type": "GetVideoComponents", + "inputs": { + "video": [ + "2", + 0 + ] + } + }, + "4": { + "class_type": "UpscaleModelLoader", + "inputs": { + "model_name": "RealESRGAN_x4plus.safetensors" + } + }, + "5": { + "class_type": "ImageUpscaleWithModel", + "inputs": { + "image": [ + "3", + 0 + ], + "upscale_model": [ + "4", + 0 + ] + } + }, + "6": { + "class_type": "CreateVideo", + "inputs": { + "audio": [ + "3", + 1 + ], + "fps": [ + "3", + 2 + ], + "images": [ + "5", + 0 + ] + } + }, + "7": { + "class_type": "SaveVideo", + "inputs": { + "codec": "auto", + "filename_prefix": "video/openshot_gen", + "format": "auto", + "video": [ + "6", + 0 + ] + } + } + } +} diff --git a/src/comfyui/video-whisper-srt.json b/src/comfyui/video-whisper-srt.json new file mode 100644 index 000000000..59edbdb1d --- /dev/null +++ b/src/comfyui/video-whisper-srt.json @@ -0,0 +1,57 @@ +{ + "action_icon": "ai-action-captions.svg", + "menu_category": "enhance", + "menu_order": 40, + "name": "Add Captions from Speech", + "open_dialog": false, + "output_type": "video", + "template_id": "video-whisper-srt", + "workflow": { + "1": { + "class_type": "VHS_LoadVideo", + "inputs": { + "custom_height": 0, + "custom_width": 0, + "force_rate": 0, + "format": "AnimateDiff", + "frame_load_cap": 0, + "select_every_nth": 1, + "skip_first_frames": 0, + "video": 
"/tmp/input.mp4" + } + }, + "2": { + "class_type": "Apply Whisper", + "inputs": { + "audio": [ + "1", + 2 + ], + "language": "auto", + "model": "medium", + "prompt": "" + } + }, + "3": { + "class_type": "Save SRT", + "inputs": { + "alignment": [ + "2", + 1 + ], + "name": "openshot_gen_segments" + } + }, + "4": { + "class_type": "PreviewAny", + "inputs": { + "preview": "", + "previewMode": null, + "source": [ + "3", + 0 + ] + } + } + } +} diff --git a/src/comfyui/video2video-basic.json b/src/comfyui/video2video-basic.json new file mode 100644 index 000000000..bc86fd212 --- /dev/null +++ b/src/comfyui/video2video-basic.json @@ -0,0 +1,146 @@ +{ + "action_icon": "ai-action-restyle.svg", + "menu_category": "enhance", + "menu_order": 50, + "name": "Change Video Style...", + "open_dialog": true, + "output_type": "video", + "template_id": "video2video-basic", + "workflow": { + "1": { + "class_type": "LoadVideo", + "inputs": { + "file": "/tmp/input.mp4" + } + }, + "10": { + "class_type": "CreateVideo", + "inputs": { + "audio": [ + "3", + 1 + ], + "fps": [ + "3", + 2 + ], + "images": [ + "9", + 0 + ] + } + }, + "11": { + "class_type": "SaveVideo", + "inputs": { + "codec": "auto", + "filename_prefix": "video/openshot_gen", + "format": "auto", + "video": [ + "10", + 0 + ] + } + }, + "2": { + "class_type": "Video Slice", + "inputs": { + "duration": 10.0, + "start_time": 0.0, + "strict_duration": false, + "video": [ + "1", + 0 + ] + } + }, + "3": { + "class_type": "GetVideoComponents", + "inputs": { + "video": [ + "2", + 0 + ] + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "sd_xl_turbo_1.0_fp16.safetensors" + } + }, + "5": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "cinematic shot, highly detailed" + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "low quality, blurry" + } + }, + "7": { + "class_type": "VAEEncode", + "inputs": { + 
"pixels": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + } + }, + "8": { + "class_type": "KSampler", + "inputs": { + "cfg": 6.0, + "denoise": 0.55, + "latent_image": [ + "7", + 0 + ], + "model": [ + "4", + 0 + ], + "negative": [ + "6", + 0 + ], + "positive": [ + "5", + 0 + ], + "sampler_name": "euler", + "scheduler": "normal", + "seed": 147747579, + "steps": 16 + } + }, + "9": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "8", + 0 + ], + "vae": [ + "4", + 2 + ] + } + } + } +} diff --git a/src/windows/main_window.py b/src/windows/main_window.py index 70a7c9101..c14e3290b 100644 --- a/src/windows/main_window.py +++ b/src/windows/main_window.py @@ -709,6 +709,7 @@ def clear_temporary_files(self): info.get_default_path("BLENDER_PATH"), info.get_default_path("TITLE_PATH"), info.get_default_path("CLIPBOARD_PATH"), + info.get_default_path("COMFYUI_OUTPUT_PATH"), ]: try: if os.path.exists(temp_dir): diff --git a/src/windows/models/files_model.py b/src/windows/models/files_model.py index 61a9be834..c4ea4c3ac 100644 --- a/src/windows/models/files_model.py +++ b/src/windows/models/files_model.py @@ -804,6 +804,31 @@ def _placeholder_row_for_job(self, job_id): return None return id_index.row() + def _generation_icon_for_job(self, job): + icon_name = "tool-generate-sparkle.svg" + try: + app = get_app() + window = getattr(app, "window", None) + generation_service = getattr(window, "generation_service", None) + if generation_service and isinstance(job, dict): + template_id = str(job.get("template_id") or "").strip() + template = generation_service.template_registry.get_template(template_id) + if template: + resolved_icon = generation_service.icon_for_template(template) + if resolved_icon: + icon_name = resolved_icon + except Exception: + pass + + icon_path = os.path.join(info.PATH, "themes", "cosmic", "images", icon_name) + if os.path.exists(icon_path): + return QIcon(icon_path) + + emoji_icon_path = os.path.join(info.PATH, "emojis", "color", "svg", "2728.svg") + 
if os.path.exists(emoji_icon_path): + return QIcon(emoji_icon_path) + return QIcon(":/icons/Humanity/actions/16/media-record.svg") + def _add_generation_placeholder(self, job_id): job = self.generation_queue.get_job(job_id) if self.generation_queue else None if not job: @@ -828,14 +853,7 @@ def _add_generation_placeholder(self, job_id): label = "{} (Canceling...)".format(name) row = [] - generate_icon_path = os.path.join(info.PATH, "themes", "cosmic", "images", "tool-generate-sparkle.svg") - emoji_icon_path = os.path.join(info.PATH, "emojis", "color", "svg", "2728.svg") - if os.path.exists(generate_icon_path): - icon = QIcon(generate_icon_path) - elif os.path.exists(emoji_icon_path): - icon = QIcon(emoji_icon_path) - else: - icon = QIcon(":/icons/Humanity/actions/16/media-record.svg") + icon = self._generation_icon_for_job(job) flags = Qt.ItemIsSelectable | Qt.ItemIsEnabled | Qt.ItemNeverHasChildren col = QStandardItem(icon, label) @@ -886,6 +904,7 @@ def _update_generation_placeholder(self, job_id): elif status == "canceling": label = "{} (Canceling...)".format(name) + self.model.item(row, 0).setIcon(self._generation_icon_for_job(job)) self.model.item(row, 0).setText(label) self.model.item(row, 1).setText(label) left = self.model.index(row, 0) diff --git a/src/windows/views/ai_tools_menu.py b/src/windows/views/ai_tools_menu.py index 50045aa5c..04070e89c 100644 --- a/src/windows/views/ai_tools_menu.py +++ b/src/windows/views/ai_tools_menu.py @@ -30,72 +30,47 @@ def _icon(name): def add_ai_tools_menu(win, parent_menu, source_file=None): _ = get_app()._tr - media_type = str(source_file.data.get("media_type", "")) if source_file else "" - + grouped = win.generation_service.build_menu_templates(source_file=source_file) + menu_defs = [] if source_file: - ai_menu = StyledContextMenu(title=_("Enhance with AI"), parent=parent_menu) - ai_menu.setIcon(_icon("tool-generate-sparkle.svg")) + menu_defs = [("enhance", _("Enhance with AI")), ("unknown", _("Unknown AI"))] + else: + 
menu_defs = [("create", _("Create with AI")), ("unknown", _("Unknown AI"))] - if media_type == "image": - action = ai_menu.addAction(_("Increase Resolution (4x)")) - action.setIcon(_icon("ai-action-upscale.svg")) - action.triggered.connect( - partial(_trigger_generation, win, "upscale-realesrgan-x4", source_file, False) - ) - ai_menu.addSeparator() - action = ai_menu.addAction(_("Change Image Style...")) - action.setIcon(_icon("ai-action-restyle.svg")) - action.triggered.connect( - partial(_trigger_generation, win, "img2img-basic", source_file, True) - ) - parent_menu.addMenu(ai_menu) - return ai_menu + created_menus = [] + for key, title in menu_defs: + templates = list(grouped.get(key, []) or []) + if not templates: + continue + ai_menu = StyledContextMenu(title=title, parent=parent_menu) + ai_menu.setIcon(_icon("tool-generate-sparkle.svg")) - elif media_type == "video": - action = ai_menu.addAction(_("Increase Resolution (4x)")) - action.setIcon(_icon("ai-action-upscale.svg")) - action.triggered.connect( - partial(_trigger_generation, win, "video-upscale-gan", source_file, False) - ) - action = ai_menu.addAction(_("Smooth Motion (2x Frame Rate)")) - action.setIcon(_icon("ai-action-smooth.svg")) - action.triggered.connect( - partial(_trigger_generation, win, "video-frame-interpolation-rife2x", source_file, False) - ) - action = ai_menu.addAction(_("Split into Scenes")) - action.setIcon(_icon("ai-action-scenes.svg")) - action.triggered.connect( - partial(_trigger_generation, win, "video-segment-scenes-transnet", source_file, False) - ) - action = ai_menu.addAction(_("Add Captions from Speech")) - action.setIcon(_icon("ai-action-captions.svg")) - action.triggered.connect( - partial(_trigger_generation, win, "video-whisper-srt", source_file, False) - ) - ai_menu.addSeparator() - action = ai_menu.addAction(_("Change Video Style...")) - action.setIcon(_icon("ai-action-restyle.svg")) + inserted_style_separator = False + for template in templates: + template_key = 
str(template.get("template_id") or template.get("id") or "") + if ( + key == "enhance" + and not inserted_style_separator + and template_key in ("img2img-basic", "video2video-basic") + ): + ai_menu.addSeparator() + inserted_style_separator = True + open_dialog = template.get("open_dialog") + if not isinstance(open_dialog, bool): + open_dialog = (source_file is None) or bool(template.get("needs_prompt", False)) + action = ai_menu.addAction(_(str(template.get("display_name", "")))) + action.setIcon(_icon(win.generation_service.icon_for_template(template))) action.triggered.connect( - partial(_trigger_generation, win, "video2video-basic", source_file, True) + partial( + _trigger_generation, + win, + template.get("id"), + source_file, + open_dialog, + ) ) - else: - action = ai_menu.addAction(_("No AI enhancement actions available yet.")) - action.setEnabled(False) parent_menu.addMenu(ai_menu) - return ai_menu - - ai_menu = StyledContextMenu(title=_("Create with AI"), parent=parent_menu) - ai_menu.setIcon(_icon("tool-generate-sparkle.svg")) - action = ai_menu.addAction(_("Image...")) - action.setIcon(_icon("ai-action-create-image.svg")) - action.triggered.connect(partial(_trigger_generation, win, "txt2img-basic", source_file, True)) - action = ai_menu.addAction(_("Video...")) - action.setIcon(_icon("ai-action-create-video.svg")) - action.triggered.connect(partial(_trigger_generation, win, "txt2video-svd", source_file, True)) - action = ai_menu.addAction(_("Audio...")) - action.setIcon(_icon("ai-action-create-audio.svg")) - action.triggered.connect(partial(_trigger_generation, win, "txt2audio-stable-open", source_file, True)) + created_menus.append(ai_menu) - parent_menu.addMenu(ai_menu) - return ai_menu + return created_menus[0] if created_menus else None From 5c9f81b756e239f7a4f533fe55d502d1adb260ff Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Tue, 17 Feb 2026 17:14:49 -0600 Subject: [PATCH 11/27] =?UTF-8?q?-=20Added=20new=20SAM2=20Blur=20Anything?= 
=?UTF-8?q?=20templates:=20=20=20=20=20=20=20-=20src/comfyui/image-blur-an?= =?UTF-8?q?ything-sam2.json=20=20=20=20=20=20=20-=20src/comfyui/video-blur?= =?UTF-8?q?-anything-sam2.json=20-=20Added=20Comfy=20availability=20gating?= =?UTF-8?q?=20+=20URL=20validation:=20=20=20=20=20=20=20-=20Hide=20AI=20me?= =?UTF-8?q?nus=20when=20Comfy=20is=20unreachable=20=20=20=20=20=20=20-=20P?= =?UTF-8?q?references=20now=20has=20Check=20for=20comfy-ui-url=20-=20Added?= =?UTF-8?q?=20Comfy=20error=20truncation/sanitizing=20so=20failures=20don?= =?UTF-8?q?=E2=80=99t=20dump=20huge=20payloads=20in=20UI=20dialogs.=20-=20?= =?UTF-8?q?Enabled=20proper=20maximize/min-max=20behavior=20for=20Region?= =?UTF-8?q?=20and=20Split=20dialogs.=20-=20Moved=20Queued=20badge=20to=20t?= =?UTF-8?q?he=20bottom=20of=20file=20thumbnails=20(list=20+=20tree=20views?= =?UTF-8?q?).?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/classes/comfy_client.py | 99 +++++++++++- src/classes/comfy_templates.py | 17 ++ src/classes/generation_queue.py | 95 ++++++++++- src/classes/generation_service.py | 150 ++++++++++++++++-- src/comfyui/image-blur-anything-sam2.json | 84 ++++++++++ src/comfyui/video-blur-anything-sam2.json | 124 +++++++++++++++ src/windows/cutting.py | 7 + src/windows/generate.py | 183 +++++++++++++++++++--- src/windows/preferences.py | 54 +++++++ src/windows/region.py | 88 +++++++++-- src/windows/video_widget.py | 80 +++++++++- src/windows/views/ai_tools_menu.py | 3 + src/windows/views/files_listview.py | 4 +- src/windows/views/files_treeview.py | 4 +- 14 files changed, 926 insertions(+), 66 deletions(-) create mode 100644 src/comfyui/image-blur-anything-sam2.json create mode 100644 src/comfyui/video-blur-anything-sam2.json diff --git a/src/classes/comfy_client.py b/src/classes/comfy_client.py index 87aa611e6..cc7a7a3cc 100644 --- a/src/classes/comfy_client.py +++ b/src/classes/comfy_client.py @@ -251,6 +251,7 @@ def _send_control_frame(self, opcode, 
payload=b""): class ComfyClient: """Minimal ComfyUI client using stdlib HTTP.""" + ERROR_MAX_CHARS = 1800 def __init__(self, base_url): self.base_url = str(base_url or "").rstrip("/") @@ -321,16 +322,16 @@ def queue_prompt(self, prompt_graph, client_id): if node_error_text: details = "{}\n{}".format(details or "prompt validation failed", node_error_text) elif not details: - details = json.dumps(error_data, ensure_ascii=True) + details = ComfyClient.summarize_error_text(error_data) else: - details = "{}\n{}".format(details, json.dumps(error_data, ensure_ascii=True)) + details = "{}\n{}".format(details, ComfyClient.summarize_error_text(error_data)) except Exception: details = str(ex) - raise RuntimeError("ComfyUI prompt rejected: {}".format(details)) + raise RuntimeError("ComfyUI prompt rejected: {}".format(ComfyClient.summarize_error_text(details))) return data.get("prompt_id") def _rewrite_prompt_local_file_inputs(self, prompt_graph): - """Rewrite local absolute paths for LoadImage/LoadVideo nodes to uploaded [input] refs.""" + """Rewrite local absolute paths for image/video loader nodes to uploaded Comfy input refs.""" if not isinstance(prompt_graph, dict): return prompt_graph rewritten = dict(prompt_graph) @@ -363,6 +364,24 @@ def _annotated(path_text): node["inputs"] = inputs rewritten[node_id] = node log.debug("ComfyClient rewrote LoadVideo input node=%s path=%s -> %s", str(node_id), video_path, uploaded) + elif class_type in ("VHS_LoadVideo", "VHS_LoadVideoPath", "VHS_LoadVideoFFmpegPath"): + video_path = str(inputs.get("video", "")).strip() + if video_path and os.path.isabs(video_path) and os.path.exists(video_path) and not _annotated(video_path): + uploaded = self.upload_input_file(video_path) + # VHS_LoadVideo expects a plain filename from Comfy input options. + # Path-based VHS loaders accept a plain relative path as well. 
+ if uploaded.endswith(" [input]"): + uploaded = uploaded[:-8].strip() + inputs["video"] = uploaded + node["inputs"] = inputs + rewritten[node_id] = node + log.debug( + "ComfyClient rewrote %s input node=%s path=%s -> %s", + class_type, + str(node_id), + video_path, + uploaded, + ) return rewritten @@ -394,6 +413,72 @@ def _format_node_errors(node_errors): return "" return "Node validation errors: {}".format(" | ".join(lines)) + @staticmethod + def summarize_error_text(value, max_chars=None): + """Return a compact Comfy error text safe for UI display.""" + if max_chars is None: + max_chars = ComfyClient.ERROR_MAX_CHARS + + if isinstance(value, (dict, list, tuple)): + value = ComfyClient._limit_error_structure(value) + try: + text = json.dumps(value, ensure_ascii=True) + except Exception: + text = str(value) + else: + text = str(value or "") + + # Remove huge numeric/tensor dumps that make dialogs unreadable. + text = re.sub(r"tensor\(\[[\s\S]{250,}?\]\)", "tensor([])", text) + text = re.sub(r"array\(\[[\s\S]{250,}?\]\)", "array([])", text) + text = re.sub(r"\[[\d\.\-eE,\s]{350,}\]", "[]", text) + text = re.sub(r"\s+", " ", text).strip() + + max_chars = max(300, int(max_chars)) + if len(text) > max_chars: + truncated = len(text) - max_chars + text = "{} ... 
[truncated {} chars]".format(text[:max_chars], truncated) + return text + + @staticmethod + def _limit_error_structure(value, depth=0, max_depth=4, max_items=10, max_str=260): + if depth >= max_depth: + return "<...>" + if isinstance(value, dict): + out = {} + for index, key in enumerate(value.keys()): + if index >= max_items: + out[""] = len(value) - max_items + break + out[str(key)] = ComfyClient._limit_error_structure( + value.get(key), + depth=depth + 1, + max_depth=max_depth, + max_items=max_items, + max_str=max_str, + ) + return out + if isinstance(value, (list, tuple)): + out = [] + for index, item in enumerate(value): + if index >= max_items: + out.append("".format(len(value) - max_items)) + break + out.append( + ComfyClient._limit_error_structure( + item, + depth=depth + 1, + max_depth=max_depth, + max_items=max_items, + max_str=max_str, + ) + ) + return out + text = str(value) + if len(text) > max_str: + return text[:max_str] + "..." + return text + def list_checkpoints(self): """Return available checkpoint names from ComfyUI object info.""" with urlopen("{}/object_info/CheckpointLoaderSimple".format(self.base_url), timeout=8.0) as response: @@ -518,6 +603,10 @@ def history(self, prompt_id): with urlopen("{}/history/{}".format(self.base_url, quote(str(prompt_id))), timeout=10.0) as response: return json.loads(response.read().decode("utf-8")) + def history_all(self): + with urlopen("{}/history".format(self.base_url), timeout=10.0) as response: + return json.loads(response.read().decode("utf-8")) + def progress(self): """Return ComfyUI /progress payload.""" try: @@ -645,7 +734,7 @@ def extract_file_outputs(history_entry, save_node_ids=None): if save_node_ids and str(node_id) not in save_node_ids: continue if isinstance(node_out, dict): - for key in ("images", "videos", "video", "audios", "audio", "files"): + for key in ("images", "videos", "video", "gifs", "audios", "audio", "files", "filenames"): refs = node_out.get(key, []) if not isinstance(refs, list): 
continue diff --git a/src/classes/comfy_templates.py b/src/classes/comfy_templates.py index 19d40f2ea..9aa8b7f28 100644 --- a/src/classes/comfy_templates.py +++ b/src/classes/comfy_templates.py @@ -85,6 +85,23 @@ "imagetosimage", "imagetoimage", "imagescaleto", + "imageblur", + "imagecompositemasked", + # Video Helper Suite + "vhs_batchmanager", + "vhs_loadvideo", + "vhs_loadvideopath", + "vhs_loadvideoffmpegpath", + "vhs_videocombine", + "vhs_videoinfo", + "vhs_videoinfoloaded", + "vhs_videoinfosource", + # ComfyUI-segment-anything-2 + "downloadandloadsam2model", + "sam2segmentation", + "sam2autosegmentation", + "sam2videosegmentationaddpoints", + "sam2videosegmentation", } diff --git a/src/classes/generation_queue.py b/src/classes/generation_queue.py index 12aeceb5d..96857c7b7 100644 --- a/src/classes/generation_queue.py +++ b/src/classes/generation_queue.py @@ -50,6 +50,74 @@ def __init__(self): def _is_cancel_requested(self, job_id, cancel_event): return (job_id in self._cancel_requested) or (cancel_event is not None and cancel_event.is_set()) + @staticmethod + def _is_unfinished_meta_batch(history_entry): + outputs = history_entry.get("outputs", {}) if isinstance(history_entry, dict) else {} + if not isinstance(outputs, dict): + return False + for node_out in outputs.values(): + if not isinstance(node_out, dict): + continue + unfinished = node_out.get("unfinished_batch", None) + if isinstance(unfinished, list): + if any(bool(v) for v in unfinished): + return True + elif unfinished: + return True + return False + + @staticmethod + def _history_prompt_meta(history_entry): + prompt_payload = history_entry.get("prompt", []) if isinstance(history_entry, dict) else [] + if not isinstance(prompt_payload, list): + return "", 0 + client_payload = prompt_payload[3] if len(prompt_payload) >= 4 else {} + if not isinstance(client_payload, dict): + return "", 0 + client_id = str(client_payload.get("client_id", "")).strip() + create_time = 
int(client_payload.get("create_time", 0) or 0) + return client_id, create_time + + def _find_related_meta_batch_outputs(self, client, history_entry, save_node_ids): + base_client_id, base_create_time = self._history_prompt_meta(history_entry) + if not base_client_id: + return [] + try: + history_all = client.history_all() or {} + except Exception: + return [] + if not isinstance(history_all, dict): + return [] + + best_create_time = 0 + best_outputs = [] + for entry in history_all.values(): + if not isinstance(entry, dict): + continue + status_obj = entry.get("status", {}) if isinstance(entry, dict) else {} + status_str = str(status_obj.get("status_str", "")).lower() + if status_str not in ("success", "completed", ""): + continue + entry_client_id, entry_create_time = self._history_prompt_meta(entry) + if entry_client_id != base_client_id: + continue + if entry_create_time and base_create_time and entry_create_time < base_create_time: + continue + if self._is_unfinished_meta_batch(entry): + continue + + outputs = ComfyClient.extract_file_outputs(entry, save_node_ids=save_node_ids) + if not outputs and save_node_ids: + outputs = ComfyClient.extract_file_outputs(entry, save_node_ids=None) + if not outputs: + continue + + if entry_create_time >= best_create_time: + best_create_time = entry_create_time + best_outputs = outputs + + return best_outputs + @pyqtSlot(str, object) def run_job(self, job_id, request): request = request or {} @@ -209,15 +277,28 @@ def _run_comfy_job(self, job_id, request): error_text = "ComfyUI job failed." 
messages = status_obj.get("messages", []) if isinstance(messages, list) and messages: - error_text = str(messages[-1]) + error_text = ComfyClient.summarize_error_text(messages[-1]) self._job_prompts.pop(job_id, None) self.job_finished.emit(job_id, False, False, error_text, []) return - image_outputs = ComfyClient.extract_file_outputs(history_entry, save_node_ids=save_node_ids) - self.progress_changed.emit(job_id, 100) - self._job_prompts.pop(job_id, None) - self.job_finished.emit(job_id, True, False, "", image_outputs) - return + if self._is_unfinished_meta_batch(history_entry): + image_outputs = self._find_related_meta_batch_outputs(client, history_entry, save_node_ids) + if image_outputs: + self.progress_changed.emit(job_id, 100) + self._job_prompts.pop(job_id, None) + self.job_finished.emit(job_id, True, False, "", image_outputs) + return + # Meta batch uses follow-up prompts under the same client_id. + # Keep polling progress/queue while waiting for follow-up prompt outputs. + else: + image_outputs = ComfyClient.extract_file_outputs(history_entry, save_node_ids=save_node_ids) + if not image_outputs and save_node_ids: + # Fallback for workflows whose output node ids shift or emit non-standard keys. + image_outputs = ComfyClient.extract_file_outputs(history_entry, save_node_ids=None) + self.progress_changed.emit(job_id, 100) + self._job_prompts.pop(job_id, None) + self.job_finished.emit(job_id, True, False, "", image_outputs) + return # Query ComfyUI's live progress values when available. 
try: @@ -422,7 +503,7 @@ def _run_comfy_job(self, job_id, request): QThread.msleep(500) except Exception as ex: self._job_prompts.pop(job_id, None) - self.job_finished.emit(job_id, False, False, str(ex), []) + self.job_finished.emit(job_id, False, False, ComfyClient.summarize_error_text(ex), []) finally: if ws_client is not None: ws_client.close() diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index 890684206..db186cb08 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -28,6 +28,7 @@ import os import re import tempfile +import json from time import time from urllib.parse import unquote from fractions import Fraction @@ -81,7 +82,8 @@ class GenerationService: def __init__(self, win): self.win = win self._generation_temp_files = [] - self._comfy_status_cache = {"checked_at": 0.0, "available": False} + self._comfy_status_cache = {"checked_at": 0.0, "available": False, "url": ""} + self._last_logged_comfy_state = None self.template_registry = ComfyTemplateRegistry() def cleanup_temp_files(self): @@ -104,13 +106,29 @@ def is_comfy_available(self, force=False): url = self.comfy_ui_url() available = False + error_text = "" try: available = ComfyClient(url).ping(timeout=0.5) - except Exception: + except Exception as ex: available = False + error_text = str(ex) + previous_available = bool(self._comfy_status_cache.get("available")) + previous_url = str(self._comfy_status_cache.get("url", "")) self._comfy_status_cache["checked_at"] = now self._comfy_status_cache["available"] = available + self._comfy_status_cache["url"] = url + + state = (url, bool(available)) + if force or state != self._last_logged_comfy_state or previous_url != url or previous_available != available: + if available: + log.info("ComfyUI check passed at %s", url) + else: + if error_text: + log.info("ComfyUI check failed at %s (%s)", url, error_text) + else: + log.info("ComfyUI check failed at %s", url) + self._last_logged_comfy_state 
= state return available def can_open_generate_dialog(self): @@ -198,7 +216,16 @@ def build_menu_templates(self, source_file=None): def icon_for_template(self, template): return self.template_registry.output_icon_name(template) - def _prepare_nonlegacy_workflow(self, template, payload_name, prompt_text, source_file, source_path): + def _prepare_nonlegacy_workflow( + self, + template, + payload_name, + prompt_text, + source_file, + source_path, + coordinates_positive_text="", + coordinates_negative_text="", + ): workflow = self.template_registry.get_workflow_copy(template.get("id")) if not workflow: raise ValueError("Template workflow not found.") @@ -221,7 +248,10 @@ def _resolve_template_local_file(path_text): return candidate return "" + template_id = str((template or {}).get("id") or "").strip().lower() prompt_text = str(prompt_text or "").strip() + coordinates_positive_text = str(coordinates_positive_text or "").strip() + coordinates_negative_text = str(coordinates_negative_text or "").strip() media_type = str(source_file.data.get("media_type", "")).strip().lower() if source_file else "" applied_prompt = False loadimage_node_ids = [] @@ -234,7 +264,7 @@ def _resolve_template_local_file(path_text): class_flat = str(node.get("class_type", "")).strip().lower() if class_flat == "loadimage": loadimage_node_ids.append(str(node_id)) - elif class_flat in ("loadvideo", "load video", "vhs_loadvideo"): + elif class_flat in ("loadvideo", "load video", "vhs_loadvideo", "vhs_loadvideopath", "vhs_loadvideoffmpegpath"): loadvideo_node_ids.append(str(node_id)) elif class_flat in ("loadaudio", "load audio"): loadaudio_node_ids.append(str(node_id)) @@ -243,6 +273,60 @@ def _is_placeholder_value(path_text): path_text = str(path_text or "").strip().lower() return path_text in ("__openshot_input__", "{{openshot_input}}", "$openshot_input") + def _is_prompt_placeholder_value(text_value): + text_value = str(text_value or "").strip().lower() + return text_value in 
("__openshot_prompt__", "{{openshot_prompt}}", "$openshot_prompt") + + def _normalize_sam2_coords_input(text_value, fallback_value): + text_value = str(text_value or "").strip() + fallback_value = str(fallback_value or "").strip() + if not text_value: + return fallback_value + # Accept raw JSON list format expected by Sam2VideoSegmentationAddPoints. + if text_value.startswith("[") and "x" in text_value and "y" in text_value: + return text_value + + # Also accept "x,y; x,y" shorthand and convert to JSON list. + points = [] + for chunk in text_value.split(";"): + chunk = chunk.strip() + if not chunk: + continue + parts = [p.strip() for p in chunk.split(",")] + if len(parts) != 2: + return fallback_value + try: + x_val = float(parts[0]) + y_val = float(parts[1]) + except (TypeError, ValueError): + return fallback_value + points.append({"x": x_val, "y": y_val}) + if not points: + return fallback_value + return str(points).replace("'", "\"") + + def _parse_sam2_points(coords_text): + coords_text = str(coords_text or "").strip() + if not coords_text: + return [] + try: + parsed = json.loads(coords_text.replace("'", "\"")) + except Exception: + return [] + if not isinstance(parsed, list): + return [] + points = [] + for item in parsed: + if not isinstance(item, dict): + continue + if "x" not in item or "y" not in item: + continue + try: + points.append({"x": float(item["x"]), "y": float(item["y"])}) + except Exception: + continue + return points + def _select_bind_nodes(node_ids, path_key, preferred_upload=None): explicit = [] candidates = [] @@ -291,9 +375,51 @@ def _select_bind_nodes(node_ids, path_key, preferred_upload=None): else: inputs["filename_prefix"] = payload_name - if prompt_text and class_flat == "cliptextencode" and "text" in inputs and not applied_prompt: - inputs["text"] = prompt_text - applied_prompt = True + if prompt_text: + text_value = inputs.get("text", None) + prompt_value = inputs.get("prompt", None) + + if isinstance(text_value, str) and 
_is_prompt_placeholder_value(text_value): + inputs["text"] = prompt_text + applied_prompt = True + elif isinstance(prompt_value, str) and _is_prompt_placeholder_value(prompt_value): + inputs["prompt"] = prompt_text + applied_prompt = True + elif class_flat == "cliptextencode" and "text" in inputs and not applied_prompt: + inputs["text"] = prompt_text + applied_prompt = True + elif "prompt" in inputs and isinstance(prompt_value, str) and not prompt_value.strip() and not applied_prompt: + # Support prompt-driven custom nodes (e.g. GroundingDINO/SAM2) that expose a plain string prompt input. + inputs["prompt"] = prompt_text + applied_prompt = True + + coords_value = inputs.get("coordinates_positive", None) + if ( + class_flat in ("sam2videosegmentationaddpoints", "sam2segmentation") + and "coordinates_positive" in inputs + and isinstance(coords_value, str) + ): + coords_text = coordinates_positive_text or prompt_text + points = _parse_sam2_points(coords_text) + if ("blur-anything-sam2" in template_id) and (not points): + raise ValueError("No SAM2 points were provided. Use Mask > Pick Point(s) on Source.") + inputs["coordinates_positive"] = _normalize_sam2_coords_input(coords_text, coords_value) + if class_flat == "sam2segmentation" and "individual_objects" in inputs: + # For Blur Anything, treat points as a single combined prompt. + # This is more stable with mixed positive/negative points and avoids + # per-object mask selection quirks in the current SAM2 single-image node. + if "blur-anything-sam2" in template_id: + inputs["individual_objects"] = False + else: + # Non-Blur-Anything templates keep multi-object behavior. 
+ inputs["individual_objects"] = bool(len(points) > 1) + if coordinates_negative_text: + neg_value = inputs.get("coordinates_negative", "") + if isinstance(neg_value, str) or "coordinates_negative" not in inputs: + inputs["coordinates_negative"] = _normalize_sam2_coords_input( + coordinates_negative_text, + str(neg_value or ""), + ) if not source_path: continue @@ -327,12 +453,12 @@ def _select_bind_nodes(node_ids, path_key, preferred_upload=None): "Comfy template missing LoadImage asset (%s). Falling back to selected source image.", configured_image, ) - elif class_flat in ("loadvideo", "load video") and media_type == "video" and node_id in video_bind_nodes: + elif class_flat in ("loadvideo", "load video", "vhs_loadvideopath", "vhs_loadvideoffmpegpath") and media_type == "video" and node_id in video_bind_nodes: if "file" in inputs: inputs["file"] = source_path elif "video" in inputs: inputs["video"] = source_path - elif class_flat in ("loadvideo", "load video"): + elif class_flat in ("loadvideo", "load video", "vhs_loadvideopath", "vhs_loadvideoffmpegpath"): local_video = _resolve_template_local_file(inputs.get("file", "")) if local_video and "file" in inputs: inputs["file"] = local_video @@ -369,7 +495,7 @@ def _save_nodes_for_workflow(self, workflow): class_type = str(node.get("class_type", "")).strip().lower() if not class_type: continue - if class_type.startswith("save") or class_type in ("previewany", "transnetv2_run"): + if class_type.startswith("save") or class_type in ("previewany", "transnetv2_run", "vhs_videocombine"): save_nodes.append(str(node_id)) return save_nodes @@ -603,6 +729,8 @@ def action_generate_trigger(self, checked=True, source_file=None, template_id=No prompt_text=payload.get("prompt"), source_file=source_file, source_path=source_path, + coordinates_positive_text=payload.get("coordinates_positive"), + coordinates_negative_text=payload.get("coordinates_negative"), ) except Exception as ex: QMessageBox.information(self.win, "Invalid Input", 
str(ex)) @@ -666,7 +794,7 @@ def on_generation_job_finished(self, job_id, status): return if status == "failed": - error_text = str(job.get("error") or "ComfyUI generation failed.") + error_text = ComfyClient.summarize_error_text(job.get("error") or "ComfyUI generation failed.") self.win.statusBar.showMessage("Generation failed", 5000) QMessageBox.warning(self.win, "Generation Failed", error_text) diff --git a/src/comfyui/image-blur-anything-sam2.json b/src/comfyui/image-blur-anything-sam2.json new file mode 100644 index 000000000..624df0b08 --- /dev/null +++ b/src/comfyui/image-blur-anything-sam2.json @@ -0,0 +1,84 @@ +{ + "action_icon": "ai-action-smooth.svg", + "menu_category": "enhance", + "menu_order": 61, + "name": "Blur Anything (Image)...", + "open_dialog": true, + "output_type": "image", + "template_id": "image-blur-anything-sam2", + "workflow": { + "1": { + "class_type": "LoadImage", + "inputs": { + "image": "__openshot_input__", + "upload": "image" + } + }, + "2": { + "class_type": "DownloadAndLoadSAM2Model", + "inputs": { + "model": "sam2.1_hiera_small.safetensors", + "segmentor": "single_image", + "device": "cuda", + "precision": "fp16" + } + }, + "3": { + "class_type": "Sam2Segmentation", + "inputs": { + "sam2_model": [ + "2", + 0 + ], + "image": [ + "1", + 0 + ], + "coordinates_positive": "", + "individual_objects": true, + "keep_model_loaded": false + } + }, + "4": { + "class_type": "ImageBlur", + "inputs": { + "image": [ + "1", + 0 + ], + "blur_radius": 12, + "sigma": 4.0 + } + }, + "5": { + "class_type": "ImageCompositeMasked", + "inputs": { + "destination": [ + "1", + 0 + ], + "source": [ + "4", + 0 + ], + "mask": [ + "3", + 0 + ], + "x": 0, + "y": 0, + "resize_source": false + } + }, + "6": { + "class_type": "SaveImage", + "inputs": { + "images": [ + "5", + 0 + ], + "filename_prefix": "image/openshot_gen" + } + } + } +} diff --git a/src/comfyui/video-blur-anything-sam2.json b/src/comfyui/video-blur-anything-sam2.json new file mode 100644 index 
000000000..e62362df7 --- /dev/null +++ b/src/comfyui/video-blur-anything-sam2.json @@ -0,0 +1,124 @@ +{ + "action_icon": "ai-action-smooth.svg", + "menu_category": "enhance", + "menu_order": 60, + "name": "Blur Anything...", + "open_dialog": true, + "output_type": "video", + "template_id": "video-blur-anything-sam2", + "workflow": { + "1": { + "class_type": "VHS_LoadVideo", + "inputs": { + "video": "/tmp/input.mp4", + "force_rate": 0, + "custom_width": 0, + "custom_height": 0, + "frame_load_cap": 0, + "skip_first_frames": 0, + "select_every_nth": 1 + } + }, + "2": { + "class_type": "VHS_VideoInfoLoaded", + "inputs": { + "video_info": [ + "1", + 3 + ] + } + }, + "3": { + "class_type": "DownloadAndLoadSAM2Model", + "inputs": { + "model": "sam2.1_hiera_small.safetensors", + "segmentor": "video", + "device": "cuda", + "precision": "fp16" + } + }, + "4": { + "class_type": "Sam2VideoSegmentationAddPoints", + "inputs": { + "sam2_model": [ + "3", + 0 + ], + "image": [ + "1", + 0 + ], + "coordinates_positive": "", + "frame_index": 0, + "object_index": 0 + } + }, + "5": { + "class_type": "Sam2VideoSegmentation", + "inputs": { + "sam2_model": [ + "4", + 0 + ], + "inference_state": [ + "4", + 1 + ], + "keep_model_loaded": false + } + }, + "6": { + "class_type": "ImageBlur", + "inputs": { + "image": [ + "1", + 0 + ], + "blur_radius": 12, + "sigma": 4.0 + } + }, + "7": { + "class_type": "ImageCompositeMasked", + "inputs": { + "destination": [ + "1", + 0 + ], + "source": [ + "6", + 0 + ], + "mask": [ + "5", + 0 + ], + "x": 0, + "y": 0, + "resize_source": false + } + }, + "8": { + "class_type": "VHS_VideoCombine", + "inputs": { + "images": [ + "7", + 0 + ], + "frame_rate": [ + "2", + 0 + ], + "loop_count": 0, + "filename_prefix": "video/openshot_gen", + "format": "video/h264-mp4", + "pingpong": false, + "save_output": true, + "audio": [ + "1", + 2 + ] + } + } + } +} diff --git a/src/windows/cutting.py b/src/windows/cutting.py index 52dee73c6..d6009721b 100644 --- 
a/src/windows/cutting.py +++ b/src/windows/cutting.py @@ -75,6 +75,13 @@ def __init__(self, file=None, preview=False): # Init UI ui_util.init_ui(self) + self.setWindowFlags( + (self.windowFlags() & ~Qt.Dialog) + | Qt.Window + | Qt.WindowMinMaxButtonsHint + | Qt.WindowMaximizeButtonHint + ) + self.setSizeGripEnabled(True) # Track metrics track_metric_screen("cutting-screen") diff --git a/src/windows/generate.py b/src/windows/generate.py index 684df46cb..6b5d632c3 100644 --- a/src/windows/generate.py +++ b/src/windows/generate.py @@ -26,16 +26,19 @@ """ import os +import json from PyQt5.QtCore import Qt from PyQt5.QtGui import QIcon, QPixmap from PyQt5.QtWidgets import ( QDialog, QVBoxLayout, QHBoxLayout, QFormLayout, QLabel, QLineEdit, - QComboBox, QTextEdit, QTabWidget, QWidget, QPushButton + QComboBox, QTextEdit, QTabWidget, QWidget, QPushButton, QMessageBox ) from classes import info +from classes.logger import log from classes.thumbnail import GetThumbPath +from windows.region import SelectRegion class GenerateMediaDialog(QDialog): @@ -56,6 +59,8 @@ def __init__( self.source_file = source_file self.templates = templates or [] self.preselected_template_id = str(preselected_template_id or "").strip() + self._coordinates_positive_text = "" + self._coordinates_negative_text = "" self.setObjectName("generateDialog") self.setWindowTitle(str(dialog_title or "AI Tools")) self.setMinimumWidth(620) @@ -69,9 +74,10 @@ def __init__( self.tabs = QTabWidget(self) self.tabs.setObjectName("generateTabs") - self.tabs.addTab(self._build_prompt_tab(), "Prompt") - self.tabs.addTab(self._build_mask_tab(), "Mask") - self.tabs.addTab(self._build_advanced_tab(), "Advanced") + self.page_prompt = self._build_prompt_tab() + self.page_points = self._build_points_tab() + self.prompt_tab_index = self.tabs.addTab(self.page_prompt, "Prompt") + self.points_tab_index = self.tabs.addTab(self.page_points, "Points") root.addWidget(self.tabs, 1) button_row = QHBoxLayout() @@ -86,11 +92,32 @@ def 
__init__( root.addLayout(button_row) self._apply_dialog_theme() + def _current_coordinates_text(self): + coordinates_positive = str(self._coordinates_positive_text or "").strip() + coordinates_negative = str(self._coordinates_negative_text or "").strip() + if not coordinates_positive and hasattr(self, "points_preview"): + preview_text = self.points_preview.toPlainText().strip() + if preview_text.startswith("{"): + try: + payload = json.loads(preview_text.replace("'", "\"")) + coordinates_positive = str(payload.get("positive", "")).strip() or coordinates_positive + coordinates_negative = str(payload.get("negative", "")).strip() or coordinates_negative + except Exception: + pass + prompt_text = self.prompt_edit.toPlainText().strip() + # Backward-compatible fallback: if prompt itself contains point JSON, treat it as coordinates. + if (not coordinates_positive) and prompt_text.startswith("[") and ("\"x\"" in prompt_text or "'x'" in prompt_text): + coordinates_positive = prompt_text + return coordinates_positive, coordinates_negative, prompt_text + def get_payload(self): + coordinates_positive, coordinates_negative, prompt_text = self._current_coordinates_text() return { "name": self.name_edit.text().strip(), "template_id": self.template_combo.currentData() or self.template_combo.currentText(), - "prompt": self.prompt_edit.toPlainText().strip(), + "prompt": prompt_text, + "coordinates_positive": coordinates_positive, + "coordinates_negative": coordinates_negative, } def _build_top_block(self): @@ -130,6 +157,7 @@ def _build_top_block(self): index = self.template_combo.findData(self.preselected_template_id) if index >= 0: self.template_combo.setCurrentIndex(index) + self.template_combo.currentIndexChanged.connect(self._on_template_changed) setup_form.addRow("Template", self.template_combo) if self.source_file: @@ -154,25 +182,31 @@ def _build_prompt_tab(self): layout.addWidget(self.prompt_edit) return tab - def _build_mask_tab(self): + def _build_points_tab(self): tab = 
QWidget(self) - tab.setObjectName("pageMask") + tab.setObjectName("pagePoints") layout = QVBoxLayout(tab) layout.setContentsMargins(8, 8, 8, 8) - label = QLabel("Mask tools will appear for templates that support drawing.") - label.setWordWrap(True) - layout.addWidget(label) - layout.addStretch(1) - return tab - - def _build_advanced_tab(self): - tab = QWidget(self) - tab.setObjectName("pageAdvanced") - layout = QVBoxLayout(tab) - layout.setContentsMargins(8, 8, 8, 8) - label = QLabel("Advanced controls are template-driven and will appear here.") - label.setWordWrap(True) - layout.addWidget(label) + self.mask_hint = QLabel( + "Select one or more tracking points on the source frame." + ) + self.mask_hint.setWordWrap(True) + layout.addWidget(self.mask_hint) + + controls = QHBoxLayout() + self.pick_points_button = QPushButton("Pick Point(s) on Source") + self.clear_points_button = QPushButton("Clear") + self.pick_points_button.clicked.connect(self._pick_points_clicked) + self.clear_points_button.clicked.connect(self._clear_points_clicked) + controls.addWidget(self.pick_points_button) + controls.addWidget(self.clear_points_button) + controls.addStretch(1) + layout.addLayout(controls) + + self.points_preview = QTextEdit() + self.points_preview.setReadOnly(True) + self.points_preview.setMinimumHeight(90) + layout.addWidget(self.points_preview) layout.addStretch(1) return tab @@ -200,8 +234,111 @@ def _on_generate_clicked(self): if not self.name_edit.text().strip(): self.name_edit.setFocus(Qt.TabFocusReason) return + if self._is_sam2_point_template(): + coordinates_positive, _coordinates_negative, _prompt_text = self._current_coordinates_text() + if not coordinates_positive: + QMessageBox.warning( + self, + "Missing Points", + "No SAM2 points were provided. 
Use the Points tab and click Pick Point(s) on Source.", + ) + self.tabs.setCurrentWidget(self.page_points) + return self.accept() + def _is_sam2_point_template(self): + template_id = str(self.template_combo.currentData() or "").strip().lower() + return "sam2" in template_id and "blur-anything" in template_id + + def _on_template_changed(self, index): + _ = index + is_point_template = self._is_sam2_point_template() + self._set_tab_visible(self.prompt_tab_index, not is_point_template) + self._set_tab_visible(self.points_tab_index, is_point_template) + self.pick_points_button.setEnabled(bool(self.source_file) and is_point_template) + self.clear_points_button.setEnabled(is_point_template) + if is_point_template: + self.mask_hint.setText( + "Select one or more tracking points on the source frame." + ) + self.tabs.setCurrentWidget(self.page_points) + else: + self.mask_hint.setText( + "Point selection is available for SAM2 Blur Anything templates." + ) + self.tabs.setCurrentWidget(self.page_prompt) + + def _pick_points_clicked(self): + if not self.source_file: + return + + win = SelectRegion(file=self.source_file, clip=None, selection_mode="point") + if win.exec_() != QDialog.Accepted: + return + + raw_points_pos = win.selected_points() + raw_points_neg = win.selected_points_negative() + log.info( + "Generate dialog captured raw SAM2 points positive=%s negative=%s", + len(raw_points_pos or []), + len(raw_points_neg or []), + ) + points_pos = [] + points_neg = [] + frame_size = win.videoPreview.curr_frame_size + if not frame_size: + frame_w = float(max(win.viewport_rect.width(), 1)) + frame_h = float(max(win.viewport_rect.height(), 1)) + else: + frame_w = float(max(frame_size.width(), 1)) + frame_h = float(max(frame_size.height(), 1)) + for point in raw_points_pos: + x_norm = max(min(float(point["x"]), float(max(frame_w - 1.0, 0.0))), 0.0) + y_norm = max(min(float(point["y"]), float(max(frame_h - 1.0, 0.0))), 0.0) + x_abs = int(round((x_norm / frame_w) * float(win.width))) 
+ y_abs = int(round((y_norm / frame_h) * float(win.height))) + points_pos.append({"x": x_abs, "y": y_abs}) + for point in raw_points_neg: + x_norm = max(min(float(point["x"]), float(max(frame_w - 1.0, 0.0))), 0.0) + y_norm = max(min(float(point["y"]), float(max(frame_h - 1.0, 0.0))), 0.0) + x_abs = int(round((x_norm / frame_w) * float(win.width))) + y_abs = int(round((y_norm / frame_h) * float(win.height))) + points_neg.append({"x": x_abs, "y": y_abs}) + + if not points_pos: + QMessageBox.warning( + self, + "No Points Found", + "No positive points were captured. Use Shift+Click to add positive points.", + ) + return + + points_pos_text = json.dumps(points_pos) + points_neg_text = json.dumps(points_neg) if points_neg else "" + log.info( + "Generate dialog normalized SAM2 points positive=%s negative=%s", + len(points_pos), + len(points_neg), + ) + self._coordinates_positive_text = points_pos_text + self._coordinates_negative_text = points_neg_text + self.points_preview.setPlainText( + json.dumps({"positive": points_pos_text, "negative": points_neg_text}, indent=2) + ) + self.tabs.setCurrentWidget(self.page_points) + + def _clear_points_clicked(self): + self._coordinates_positive_text = "" + self._coordinates_negative_text = "" + self.points_preview.clear() + + def _set_tab_visible(self, index, visible): + bar = self.tabs.tabBar() + if hasattr(bar, "setTabVisible"): + bar.setTabVisible(index, bool(visible)) + else: + self.tabs.setTabEnabled(index, bool(visible)) + def _apply_dialog_theme(self): self.setStyleSheet(""" QDialog#generateDialog { @@ -209,8 +346,7 @@ def _apply_dialog_theme(self): color: #91C3FF; } QDialog#generateDialog QTabWidget#generateTabs QWidget#pagePrompt, -QDialog#generateDialog QTabWidget#generateTabs QWidget#pageMask, -QDialog#generateDialog QTabWidget#generateTabs QWidget#pageAdvanced { +QDialog#generateDialog QTabWidget#generateTabs QWidget#pagePoints { background-color: #141923; border: none; } @@ -250,3 +386,4 @@ def 
_apply_dialog_theme(self): border: 1px solid #53a0ed; } """) + self._on_template_changed(self.template_combo.currentIndex()) diff --git a/src/windows/preferences.py b/src/windows/preferences.py index 5de53ed1c..c6f278483 100644 --- a/src/windows/preferences.py +++ b/src/windows/preferences.py @@ -43,6 +43,7 @@ from classes import info, ui_util, tabstops from classes import openshot_rc # noqa from classes.app import get_app +from classes.comfy_client import ComfyClient from classes.language import get_all_languages from classes.logger import log from classes.metrics import track_metric_screen @@ -284,6 +285,10 @@ def Populate(self, filter=""): # Add filesystem browser button extraWidget = QPushButton(_("Browse...")) extraWidget.clicked.connect(functools.partial(self.selectExecutable, widget, param)) + elif param.get("setting") == "comfy-ui-url": + # Add an explicit connectivity check for ComfyUI URL. + extraWidget = QPushButton(_("Check")) + extraWidget.clicked.connect(functools.partial(self.check_comfy_ui_url, widget, param)) elif param["type"] == "bool": # create spinner @@ -673,6 +678,55 @@ def text_value_changed(self, widget, param, value=None): # Check for restart self.check_for_restart(param) + def check_comfy_ui_url(self, widget, param): + _ = get_app()._tr + url = str(widget.text() or "").strip().rstrip("/") + if not url: + log.info("ComfyUI URL check failed: empty URL") + QMessageBox.warning(self, _("Comfy UI URL"), _("Comfy UI URL is empty.")) + return + + # Persist normalized URL before validation. + self.s.set(param["setting"], url) + widget.setText(url) + + available = False + error_text = "" + try: + available = ComfyClient(url).ping(timeout=2.0) + except Exception as ex: + error_text = str(ex) + + # Refresh cached availability so context menus update immediately. 
+ try: + if getattr(get_app(), "window", None): + get_app().window.is_comfy_available(force=True) + except Exception: + log.debug("ComfyUI availability cache refresh failed", exc_info=1) + + if available: + log.info("ComfyUI URL check succeeded at %s", url) + QMessageBox.information( + self, + _("Comfy UI URL"), + _("Connection successful. AI menus are enabled."), + ) + else: + if error_text: + log.info("ComfyUI URL check failed at %s (%s)", url, error_text) + message = _("Connection failed: {}").format(error_text) + else: + log.info("ComfyUI URL check failed at %s", url) + message = _("Connection failed.") + QMessageBox.warning( + self, + _("Comfy UI URL"), + "{}\n{}".format( + message, + _("AI menus are disabled until ComfyUI is reachable."), + ), + ) + def dropdown_index_changed(self, widget, param, index): # Save setting value = widget.itemData(index) diff --git a/src/windows/region.py b/src/windows/region.py index f9bb068dd..84d5b26b5 100644 --- a/src/windows/region.py +++ b/src/windows/region.py @@ -60,7 +60,7 @@ class SelectRegion(QDialog): SpeedSignal = pyqtSignal(float) StopSignal = pyqtSignal() - def __init__(self, file=None, clip=None): + def __init__(self, file=None, clip=None, selection_mode="rect"): _ = get_app()._tr # Create dialog class @@ -71,10 +71,23 @@ def __init__(self, file=None, clip=None): # Init UI ui_util.init_ui(self) + self.setWindowFlags( + (self.windowFlags() & ~Qt.Dialog) + | Qt.Window + | Qt.WindowMinMaxButtonsHint + | Qt.WindowMaximizeButtonHint + ) + self.setSizeGripEnabled(True) # Track metrics track_metric_screen("cutting-screen") + self.selection_mode = str(selection_mode or "rect").strip().lower() + if self.selection_mode not in ("rect", "point"): + self.selection_mode = "rect" + self._selected_points = [] + self._selected_points_negative = [] + self.start_frame = 1 self.start_image = None self.end_frame = 1 @@ -82,19 +95,31 @@ def __init__(self, file=None, clip=None): self.current_frame = 1 # Create region clip with Reader - 
self.clip = openshot.Clip(clip.Reader()) - self.clip.Open() - - # Set region clip start and end - self.clip.Start(clip.Start()) - self.clip.End(clip.End()) - self.clip.Id( get_app().project.generate_id() ) + if clip: + self.clip = openshot.Clip(clip.Reader()) + self.clip.Open() + # Set region clip start and end + self.clip.Start(clip.Start()) + self.clip.End(clip.End()) + else: + source_path = "" + if file: + if hasattr(file, "absolute_path"): + source_path = file.absolute_path() + else: + source_path = str(getattr(file, "data", {}).get("path", "")) + self.clip = openshot.Clip(source_path) + self.clip.Open() + self.clip.Id(get_app().project.generate_id()) # Keep track of file object self.file = file - self.file_path = file.absolute_path() + if file and hasattr(file, "absolute_path"): + self.file_path = file.absolute_path() + else: + self.file_path = str(getattr(file, "data", {}).get("path", "")) - c_info = clip.Reader().info + c_info = self.clip.Reader().info self.fps = c_info.fps.ToInt() self.fps_num = c_info.fps.num self.fps_den = c_info.fps.den @@ -106,16 +131,26 @@ def __init__(self, file=None, clip=None): self.video_length = int(self.clip.Duration() * self.fps) + 1 # Apply effects to region frames - for effect in clip.Effects(): - self.clip.AddEffect(effect) + if clip: + for effect in clip.Effects(): + self.clip.AddEffect(effect) # Open video file with Reader log.info(self.clip.Reader()) + # Set instruction text first so it remains above the preview widget. 
+ if self.selection_mode == "point": + self.lblInstructions.setText( + _("Click to add tracking point (SHIFT+Click for additional points, CTRL+Click for negative point)") + ) + else: + self.lblInstructions.setText(_("Draw a rectangle to select a region of the video frame.")) + # Add Video Widget self.videoPreview = VideoWidget() self.videoPreview.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding) - self.verticalLayout.insertWidget(0, self.videoPreview) + self.videoPreview.region_selection_mode = self.selection_mode + self.verticalLayout.insertWidget(1, self.videoPreview) # Set aspect ratio to match source content aspect_ratio = openshot.Fraction(self.width, self.height) @@ -172,7 +207,8 @@ def __init__(self, file=None, clip=None): # Add buttons self.cancel_button = QPushButton(_('Cancel')) - self.process_button = QPushButton(_('Select Region')) + process_label = _('Select Region') if self.selection_mode == "rect" else _('Select Point(s)') + self.process_button = QPushButton(process_label) self.buttonBox.addButton(self.process_button, QDialogButtonBox.AcceptRole) self.buttonBox.addButton(self.cancel_button, QDialogButtonBox.RejectRole) @@ -182,7 +218,7 @@ def __init__(self, file=None, clip=None): self.sliderVideo.valueChanged.connect(self.sliderVideo_valueChanged) self.initialized = True - get_app().window.SelectRegionSignal.emit(clip.Id()) + get_app().window.SelectRegionSignal.emit(self.clip.Id()) def actionPlay_Triggered(self): # Trigger play button (This action is invoked from the preview thread, so it must exist here) @@ -253,7 +289,13 @@ def accept(self): self.sliderVideo.setValue(self.sliderVideo.minimum()) return + if self.selection_mode == "point" and not self.videoPreview.region_points_positive: + QMessageBox.warning(self, _("Invalid Selection"), _("Please select at least one point.")) + return + # Continue with the rest of the accept method + self._selected_points = self.selected_points() + self._selected_points_negative = 
self.selected_points_negative() self.shutdownPlayer() get_app().window.SelectRegionSignal.emit("") super(SelectRegion, self).accept() @@ -279,4 +321,18 @@ def reject(self): get_app().window.SelectRegionSignal.emit("") super(SelectRegion, self).reject() - + def selected_points(self): + if self._selected_points: + return list(self._selected_points) + points = [] + for point in getattr(self.videoPreview, "region_points_positive", []) or []: + points.append({"x": float(point.x()), "y": float(point.y())}) + return points + + def selected_points_negative(self): + if self._selected_points_negative: + return list(self._selected_points_negative) + points = [] + for point in getattr(self.videoPreview, "region_points_negative", []) or []: + points.append({"x": float(point.x()), "y": float(point.y())}) + return points diff --git a/src/windows/video_widget.py b/src/windows/video_widget.py index 3bf4ff826..7e27d8080 100644 --- a/src/windows/video_widget.py +++ b/src/windows/video_widget.py @@ -660,7 +660,27 @@ def paintEvent(self, event, *args): painter.setTransform(self.region_transform) cs = self.cs - if self.regionTopLeftHandle and self.regionBottomRightHandle: + if self.region_selection_mode == "point": + point_radius = max(2.0, (cs * 0.4) / max(self.zoom, 0.001)) + if self.region_points_positive: + pos_color = QColor("#53a0ed") + pos_color.setAlphaF(self.handle_opacity) + pos_pen = QPen(QBrush(pos_color), 1.5) + pos_pen.setCosmetic(True) + painter.setPen(pos_pen) + painter.setBrush(QBrush(pos_color)) + for pt in self.region_points_positive: + painter.drawEllipse(pt, point_radius, point_radius) + if self.region_points_negative: + neg_color = QColor("#e05757") + neg_color.setAlphaF(self.handle_opacity) + neg_pen = QPen(QBrush(neg_color), 1.5) + neg_pen.setCosmetic(True) + painter.setPen(neg_pen) + painter.setBrush(QBrush(neg_color)) + for pt in self.region_points_negative: + painter.drawEllipse(pt, point_radius, point_radius) + elif self.regionTopLeftHandle and 
self.regionBottomRightHandle: color = QColor("#53a0ed") color.setAlphaF(self.handle_opacity) pen = QPen(QBrush(color), 1.5) @@ -737,6 +757,21 @@ def mousePressEvent(self, event): self.rotation_drag_value = None self.setCursor(self.hover_cursor) + if self.region_enabled and self.region_selection_mode == "point" and event.button() == Qt.LeftButton: + self._ensure_region_transform() + point = self.region_transform_inverted.map(event.pos()) + point = self._clamp_region_point(point) + mods = int(QCoreApplication.instance().keyboardModifiers()) + if mods & Qt.ControlModifier: + self.region_points_negative.append(point) + elif mods & Qt.ShiftModifier: + self.region_points_positive.append(point) + else: + # Default click resets to a single positive point. + self.region_points_positive = [point] + self.region_points_negative = [] + self.update() + # Ignore undo/redo history temporarily (to avoid a huge pile of undo/redo history) get_app().updates.ignore_history = True @@ -765,7 +800,7 @@ def mouseReleaseEvent(self, event): # Save region image data (as QImage) # This can be used other widgets to display the selected region - if self.region_enabled: + if self.region_enabled and self.region_selection_mode != "point": # Get region coordinates region_rect = QRectF( self.regionTopLeftHandle.x(), @@ -1189,6 +1224,12 @@ def mouseMoveEvent(self, event): self.update() if self.region_enabled: + if self.region_selection_mode == "point": + self.setCursor(Qt.CrossCursor) + self.mouse_position = event.pos() + self.mutex.unlock() + return + # Modify region selection (x, y, width, height) # Corner size cs = self.cs @@ -1579,6 +1620,30 @@ def updateClipProperty(self, clip_id, frame_number, property_key, new_value, ref if refresh: get_app().window.refreshFrameSignal.emit() + def _ensure_region_transform(self): + if self.region_transform: + return + viewport = self.centeredViewport(self.width(), self.height()) + self.region_transform = QTransform() + rx = viewport.x() + ry = viewport.y() + if 
rx or ry: + self.region_transform.translate(rx, ry) + if self.zoom: + self.region_transform.scale(self.zoom, self.zoom) + self.region_transform_inverted = self.region_transform.inverted()[0] + + def _clamp_region_point(self, point): + max_w = float(self.curr_frame_size.width()) if self.curr_frame_size else 0.0 + max_h = float(self.curr_frame_size.height()) if self.curr_frame_size else 0.0 + if max_w <= 0.0 or max_h <= 0.0: + viewport = self.centeredViewport(self.width(), self.height()) + max_w = float(viewport.width()) / max(self.zoom, 0.001) + max_h = float(viewport.height()) / max(self.zoom, 0.001) + x = min(max(float(point.x()), 0.0), max(max_w - 1.0, 0.0)) + y = min(max(float(point.y()), 0.0), max(max_h - 1.0, 0.0)) + return QPointF(x, y) + def updateEffectProperty(self, effect_id, frame_number, obj_id, property_key, new_value, refresh=True): """Update a keyframe property to a new value, adding or updating keyframes as needed""" found_point = False @@ -1919,6 +1984,12 @@ def regionTriggered(self, clip_id): """Handle the 'select region' signal when it's emitted""" # Clear transform self.region_enabled = bool(clip_id) + if not self.region_enabled: + self.region_points = [] + self.region_points_positive = [] + self.region_points_negative = [] + self.regionTopLeftHandle = None + self.regionBottomRightHandle = None get_app().window.refreshFrameSignal.emit() self.update_title() @@ -2029,7 +2100,12 @@ def __init__(self, watch_project=True, *args): self.original_effect_data = None self.region_qimage = None self.region_transform = None + self.region_transform_inverted = None self.region_enabled = False + self.region_selection_mode = "rect" + self.region_points = [] + self.region_points_positive = [] + self.region_points_negative = [] self.region_mode = None self.regionTopLeftHandle = None self.regionBottomRightHandle = None diff --git a/src/windows/views/ai_tools_menu.py b/src/windows/views/ai_tools_menu.py index 04070e89c..1cbb17409 100644 --- 
a/src/windows/views/ai_tools_menu.py +++ b/src/windows/views/ai_tools_menu.py @@ -30,6 +30,9 @@ def _icon(name): def add_ai_tools_menu(win, parent_menu, source_file=None): _ = get_app()._tr + if not win.is_comfy_available(force=False): + return None + grouped = win.generation_service.build_menu_templates(source_file=source_file) menu_defs = [] if source_file: diff --git a/src/windows/views/files_listview.py b/src/windows/views/files_listview.py index c43c7e410..40074d246 100644 --- a/src/windows/views/files_listview.py +++ b/src/windows/views/files_listview.py @@ -129,7 +129,9 @@ def paint(self, painter, option, index): pad_y = 2 badge_w = text_w + (pad_x * 2) badge_h = text_h + (pad_y * 2) - badge_rect = deco_rect.adjusted(3, 3, 0, 0) + badge_bottom = full_rect.top() - 3 + badge_top = max(deco_rect.top() + 3, badge_bottom - badge_h + 1) + badge_rect = deco_rect.adjusted(3, badge_top - deco_rect.top(), 0, 0) badge_rect.setWidth(badge_w) badge_rect.setHeight(badge_h) painter.setBrush(QColor(18, 22, 30, 220)) diff --git a/src/windows/views/files_treeview.py b/src/windows/views/files_treeview.py index 2067dadfd..ed7a2d596 100644 --- a/src/windows/views/files_treeview.py +++ b/src/windows/views/files_treeview.py @@ -125,7 +125,9 @@ def paint(self, painter, option, index): pad_y = 2 badge_w = text_w + (pad_x * 2) badge_h = text_h + (pad_y * 2) - badge_rect = deco_rect.adjusted(3, 3, 0, 0) + badge_bottom = full_rect.top() - 3 + badge_top = max(deco_rect.top() + 3, badge_bottom - badge_h + 1) + badge_rect = deco_rect.adjusted(3, badge_top - deco_rect.top(), 0, 0) badge_rect.setWidth(badge_w) badge_rect.setHeight(badge_h) painter.setBrush(QColor(18, 22, 30, 220)) From 84b606307a2c12c365c4ffb4661f50b070a7502b Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Wed, 18 Feb 2026 11:55:53 -0600 Subject: [PATCH 12/27] - Harden Comfy progress updates after sleep/wake: detect stale WS, reconnect with backoff (60s -> 300s), and fallback to /progress when WS is quiet. 
- Fix template input rewriting for multi-loader workflows by matching both file/video (and audio/file) keys. - Tune Blur Anything defaults: switch to sam2.1_hiera_tiny, keep model loaded, reduce batch size to 32. --- src/classes/generation_queue.py | 32 ++++++- src/classes/generation_service.py | 17 +++- src/comfyui/video-blur-anything-sam2.json | 112 ++++++++++++++++++++-- 3 files changed, 145 insertions(+), 16 deletions(-) diff --git a/src/classes/generation_queue.py b/src/classes/generation_queue.py index 96857c7b7..5a51c445c 100644 --- a/src/classes/generation_queue.py +++ b/src/classes/generation_queue.py @@ -178,6 +178,9 @@ def _run_comfy_job(self, job_id, request): accepted_progress_started = False ws_retry_delay_s = 2.0 ws_next_retry_at = start_time + ws_last_progress_time = start_time + ws_stale_reconnect_s = 60.0 + ws_stale_reconnect_max_s = 300.0 prompt_key = str(prompt_id) while True: @@ -302,6 +305,7 @@ def _run_comfy_job(self, job_id, request): # Query ComfyUI's live progress values when available. 
try: + ws_progress_emitted = False now = monotonic() if ws_client is None and now >= ws_next_retry_at: try: @@ -392,8 +396,34 @@ def _run_comfy_job(self, job_id, request): progress, ) self.progress_changed.emit(job_id, progress) + ws_progress_emitted = True + ws_last_progress_time = monotonic() + ws_stale_reconnect_s = 60.0 last_contact_time = monotonic() - if ws_client is None: + if ws_client is not None and not ws_progress_emitted: + stale_for = now - ws_last_progress_time + if stale_for >= ws_stale_reconnect_s: + try: + ws_client.close() + except Exception: + pass + ws_client = None + ws_next_retry_at = now + ws_retry_delay_s + ws_retry_delay_s = min(60.0, ws_retry_delay_s * 1.5) + next_stale_reconnect_s = min( + ws_stale_reconnect_max_s, + max(60.0, ws_stale_reconnect_s * 1.5), + ) + log.debug( + "Comfy websocket stalled for job=%s prompt=%s (%.1fs >= %.1fs); forcing reconnect, next stall timeout %.1fs", + job_id, + prompt_key, + stale_for, + ws_stale_reconnect_s, + next_stale_reconnect_s, + ) + ws_stale_reconnect_s = next_stale_reconnect_s + if ws_client is None or not ws_progress_emitted: progress_data = client.progress() if progress_data is None: if not progress_endpoint_unavailable: diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index db186cb08..5f2532cc5 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -327,7 +327,9 @@ def _parse_sam2_points(coords_text): continue return points - def _select_bind_nodes(node_ids, path_key, preferred_upload=None): + def _select_bind_nodes(node_ids, path_keys, preferred_upload=None): + if isinstance(path_keys, str): + path_keys = [path_keys] explicit = [] candidates = [] for node_id in node_ids: @@ -335,7 +337,12 @@ def _select_bind_nodes(node_ids, path_key, preferred_upload=None): inputs = node.get("inputs", {}) if isinstance(node, dict) else {} if not isinstance(inputs, dict): continue - path_value = str(inputs.get(path_key, "")).strip() + 
path_value = "" + for path_key in path_keys: + candidate_value = str(inputs.get(path_key, "")).strip() + if candidate_value: + path_value = candidate_value + break upload_value = str(inputs.get("upload", "")).strip().lower() if _is_placeholder_value(path_value): explicit.append(node_id) @@ -353,9 +360,9 @@ def _select_bind_nodes(node_ids, path_key, preferred_upload=None): return {node_ids[0]} return set() - image_bind_nodes = _select_bind_nodes(loadimage_node_ids, "image", preferred_upload="image") - video_bind_nodes = _select_bind_nodes(loadvideo_node_ids, "file") - audio_bind_nodes = _select_bind_nodes(loadaudio_node_ids, "audio") + image_bind_nodes = _select_bind_nodes(loadimage_node_ids, ["image"], preferred_upload="image") + video_bind_nodes = _select_bind_nodes(loadvideo_node_ids, ["file", "video"]) + audio_bind_nodes = _select_bind_nodes(loadaudio_node_ids, ["audio", "file"]) for node_id, node in workflow.items(): if not isinstance(node, dict): diff --git a/src/comfyui/video-blur-anything-sam2.json b/src/comfyui/video-blur-anything-sam2.json index e62362df7..ef58f3c0a 100644 --- a/src/comfyui/video-blur-anything-sam2.json +++ b/src/comfyui/video-blur-anything-sam2.json @@ -10,7 +10,7 @@ "1": { "class_type": "VHS_LoadVideo", "inputs": { - "video": "/tmp/input.mp4", + "video": "__openshot_input__", "force_rate": 0, "custom_width": 0, "custom_height": 0, @@ -31,7 +31,7 @@ "3": { "class_type": "DownloadAndLoadSAM2Model", "inputs": { - "model": "sam2.1_hiera_small.safetensors", + "model": "sam2.1_hiera_tiny.safetensors", "segmentor": "video", "device": "cuda", "precision": "fp16" @@ -64,33 +64,121 @@ "4", 1 ], - "keep_model_loaded": false + "keep_model_loaded": true } }, "6": { + "class_type": "MaskToImage", + "inputs": { + "mask": [ + "5", + 0 + ] + } + }, + "7": { + "class_type": "VHS_VideoCombine", + "inputs": { + "images": [ + "6", + 0 + ], + "frame_rate": [ + "2", + 0 + ], + "loop_count": 0, + "filename_prefix": "video/openshot_mask", + "format": 
"video/ffv1-mkv", + "pingpong": false, + "save_output": true + } + }, + "8": { + "class_type": "VHS_SelectFilename", + "inputs": { + "filenames": [ + "7", + 0 + ], + "index": -1 + } + }, + "9": { + "class_type": "VHS_BatchManager", + "inputs": { + "frames_per_batch": 32 + } + }, + "10": { + "class_type": "VHS_LoadVideo", + "inputs": { + "video": "__openshot_input__", + "force_rate": 0, + "custom_width": 0, + "custom_height": 0, + "frame_load_cap": 0, + "skip_first_frames": 0, + "select_every_nth": 1, + "meta_batch": [ + "9", + 0 + ] + } + }, + "11": { + "class_type": "VHS_LoadVideoPath", + "inputs": { + "video": [ + "8", + 0 + ], + "force_rate": 0, + "custom_width": 0, + "custom_height": 0, + "frame_load_cap": 0, + "skip_first_frames": 0, + "select_every_nth": 1, + "meta_batch": [ + "9", + 0 + ] + } + }, + "12": { + "class_type": "ImageToMask", + "inputs": { + "image": [ + "11", + 0 + ], + "channel": "red" + } + }, + "13": { "class_type": "ImageBlur", "inputs": { "image": [ - "1", + "10", 0 ], "blur_radius": 12, "sigma": 4.0 } }, - "7": { + "14": { "class_type": "ImageCompositeMasked", "inputs": { "destination": [ - "1", + "10", 0 ], "source": [ - "6", + "13", 0 ], "mask": [ - "5", + "12", 0 ], "x": 0, @@ -98,11 +186,11 @@ "resize_source": false } }, - "8": { + "15": { "class_type": "VHS_VideoCombine", "inputs": { "images": [ - "7", + "14", 0 ], "frame_rate": [ @@ -117,6 +205,10 @@ "audio": [ "1", 2 + ], + "meta_batch": [ + "9", + 0 ] } } From 0c5eefc4c49dd9c329eeef70a634edd9cc4fd6f4 Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Fri, 20 Feb 2026 14:29:21 -0600 Subject: [PATCH 13/27] - Fix Comfy progress tracking for meta-batch follow-up prompts and cleaner WS event selection. - Add queue progress metadata (progress_detail, sub_progress) and improve fallback behavior when /progress is unavailable. - Switch blur-anything SAM2 workflow to OpenShot windowed/chunked nodes with masked blur and 96-frame batching. 
--- src/classes/comfy_client.py | 39 +++++-- src/classes/comfy_templates.py | 6 ++ src/classes/generation_queue.py | 122 ++++++++++++++++++---- src/classes/generation_service.py | 16 ++- src/comfyui/image-blur-anything-sam2.json | 5 +- src/comfyui/video-blur-anything-sam2.json | 47 ++++++--- src/windows/models/files_model.py | 6 ++ 7 files changed, 190 insertions(+), 51 deletions(-) diff --git a/src/classes/comfy_client.py b/src/classes/comfy_client.py index cc7a7a3cc..7b2ece3d2 100644 --- a/src/classes/comfy_client.py +++ b/src/classes/comfy_client.py @@ -100,12 +100,18 @@ def close(self): pass self.sock = None - def poll_progress(self, prompt_id, max_messages=8): - """Read available frames and return latest progress payload for prompt_id.""" + def poll_progress(self, prompt_id=None, max_messages=8): + """Read available frames and return latest progress payload. + + If prompt_id is provided, events are filtered to that prompt id. + If prompt_id is None/empty, events from any prompt on this websocket + are accepted (useful for meta-batch follow-up prompts). 
+ """ if not self.sock: return None latest = None - prompt_key = str(prompt_id) + latest_rank = None + prompt_key = str(prompt_id or "").strip() for _ in range(max_messages): frame = self._recv_frame_nonblocking() if frame is None: @@ -134,30 +140,38 @@ def poll_progress(self, prompt_id, max_messages=8): if not isinstance(event_data, dict): continue event_prompt = str(event_data.get("prompt_id", "")) - if not event_prompt or event_prompt != prompt_key: + if prompt_key and (not event_prompt or event_prompt != prompt_key): continue value = float(event_data.get("value", 0.0)) maximum = float(event_data.get("max", 0.0)) if maximum > 0: - latest = { + candidate = { "percent": int(max(0, min(99, round((value / maximum) * 100.0)))), "value": value, "max": maximum, "node": str(event_data.get("node", "")), "type": "progress", + "prompt_id": event_prompt, } + # Prefer unfinished progress, and prefer explicit "progress" events. + unfinished = (value + 1e-6) < maximum + rank = (1 if unfinished else 0, 2, maximum) + if latest is None or rank > latest_rank: + latest = candidate + latest_rank = rank elif event_type == "progress_state": # Newer Comfy events: data={prompt_id, nodes={node_id:{value,max}}} if not isinstance(event_data, dict): continue event_prompt = str(event_data.get("prompt_id", "")) - if not event_prompt or event_prompt != prompt_key: + if prompt_key and (not event_prompt or event_prompt != prompt_key): continue nodes = event_data.get("nodes", {}) if not isinstance(nodes, dict): continue - # Pick node state with the largest max to avoid setup-node 1/1 spikes. + # Prefer unfinished node progress; only fall back to completed states. 
best = None + best_rank = None for node_id, node_state in nodes.items(): if not isinstance(node_state, dict): continue @@ -170,11 +184,18 @@ def poll_progress(self, prompt_id, max_messages=8): "max": maximum, "node": str(node_id), "type": "progress_state", + "prompt_id": event_prompt, } - if best is None or maximum > float(best.get("max", 0.0)): + unfinished = (value + 1e-6) < maximum + rank = (1 if unfinished else 0, maximum) + if best is None or rank > best_rank: best = candidate + best_rank = rank if best is not None: - latest = best + rank = (best_rank[0], 1, float(best.get("max", 0.0))) + if latest is None or rank > latest_rank: + latest = best + latest_rank = rank return latest def _recv_http_headers(self): diff --git a/src/classes/comfy_templates.py b/src/classes/comfy_templates.py index 9aa8b7f28..48df519c6 100644 --- a/src/classes/comfy_templates.py +++ b/src/classes/comfy_templates.py @@ -102,6 +102,12 @@ "sam2autosegmentation", "sam2videosegmentationaddpoints", "sam2videosegmentation", + # OpenShot-ComfyUI (custom SAM2) + "openshotdownloadandloadsam2model", + "openshotsam2segmentation", + "openshotsam2videosegmentationaddpoints", + "openshotsam2videosegmentationchunked", + "openshotimageblurmasked", } diff --git a/src/classes/generation_queue.py b/src/classes/generation_queue.py index 5a51c445c..d9702425e 100644 --- a/src/classes/generation_queue.py +++ b/src/classes/generation_queue.py @@ -40,6 +40,8 @@ class _GenerationWorker(QObject): """Background worker that simulates generation progress for queued jobs.""" progress_changed = pyqtSignal(str, int) + progress_detail_changed = pyqtSignal(str, str) + progress_sub_changed = pyqtSignal(str, int) job_finished = pyqtSignal(str, bool, bool, str, object) def __init__(self): @@ -182,6 +184,8 @@ def _run_comfy_job(self, job_id, request): ws_stale_reconnect_s = 60.0 ws_stale_reconnect_max_s = 300.0 prompt_key = str(prompt_id) + last_progress_signature = None + last_progress_detail = "" while True: if 
self._is_cancel_requested(job_id, cancel_event): @@ -328,7 +332,8 @@ def _run_comfy_job(self, job_id, request): if ws_client is not None: try: - progress_event = ws_client.poll_progress(prompt_id) + # Accept progress from follow-up prompts as well (meta-batch). + progress_event = ws_client.poll_progress(prompt_id=None) except Exception: progress_event = None try: @@ -385,17 +390,35 @@ def _run_comfy_job(self, job_id, request): ) else: accepted_progress_started = True - log.debug( - "Comfy WS progress emit job=%s prompt=%s node=%s type=%s value=%s max=%s percent=%s", - job_id, - prompt_key, - progress_node, + progress_signature = ( progress_type, - raw_value, - raw_max, - progress, + progress_node, + int(progress), + round(raw_value, 3), + round(raw_max, 3), ) - self.progress_changed.emit(job_id, progress) + if progress_signature != last_progress_signature: + inferred_progress = int(max(0, min(99, progress))) + detail_text = "" + if progress_node: + detail_text = "node {} {}%".format(progress_node, int(progress)) + + log.debug( + "Comfy WS progress emit job=%s prompt=%s node=%s type=%s value=%s max=%s percent=%s", + job_id, + prompt_key, + progress_node, + progress_type, + raw_value, + raw_max, + inferred_progress, + ) + self.progress_changed.emit(job_id, inferred_progress) + self.progress_sub_changed.emit(job_id, int(max(0, min(99, progress)))) + if detail_text != last_progress_detail: + self.progress_detail_changed.emit(job_id, detail_text) + last_progress_detail = detail_text + last_progress_signature = progress_signature ws_progress_emitted = True ws_last_progress_time = monotonic() ws_stale_reconnect_s = 60.0 @@ -423,12 +446,15 @@ def _run_comfy_job(self, job_id, request): next_stale_reconnect_s, ) ws_stale_reconnect_s = next_stale_reconnect_s - if ws_client is None or not ws_progress_emitted: + # Use HTTP /progress only when websocket progress is unavailable. 
+ # If websocket is connected but temporarily quiet, keep waiting for WS + # instead of spamming a misleading 404 fallback warning. + if ws_client is None: progress_data = client.progress() if progress_data is None: if not progress_endpoint_unavailable: log.debug( - "Comfy progress endpoint unavailable (404). Progress bar updates disabled for this job=%s", + "Comfy progress endpoint unavailable (404); waiting for websocket progress for job=%s", job_id, ) progress_endpoint_unavailable = True @@ -461,15 +487,22 @@ def _run_comfy_job(self, job_id, request): if maximum > 0 and prompt_matches: progress = int(max(0, min(99, round((value / maximum) * 100.0)))) - log.debug( - "Comfy progress emit job=%s prompt=%s value=%s max=%s percent=%s", - job_id, - prompt_key, - value, - maximum, - progress, - ) - self.progress_changed.emit(job_id, progress) + progress_signature = ("poll", "", int(progress), round(value, 3), round(maximum, 3)) + if progress_signature != last_progress_signature: + log.debug( + "Comfy progress emit job=%s prompt=%s value=%s max=%s percent=%s", + job_id, + prompt_key, + value, + maximum, + progress, + ) + self.progress_changed.emit(job_id, progress) + self.progress_sub_changed.emit(job_id, int(max(0, min(99, progress)))) + if last_progress_detail: + self.progress_detail_changed.emit(job_id, "") + last_progress_detail = "" + last_progress_signature = progress_signature last_contact_time = monotonic() except Exception: # Keep polling history and queue even if /progress is unavailable. 
@@ -568,6 +601,8 @@ def __init__(self, parent=None): self._run_job.connect(self._worker.run_job) self._cancel_job.connect(self._worker.cancel_job) self._worker.progress_changed.connect(self._on_progress_changed) + self._worker.progress_detail_changed.connect(self._on_progress_detail_changed) + self._worker.progress_sub_changed.connect(self._on_progress_sub_changed) self._worker.job_finished.connect(self._on_job_finished) self._thread.start() @@ -588,6 +623,8 @@ def enqueue(self, name, template_id, prompt, source_file_id=None, request=None): "source_file_id": source_file_id, "status": "queued", "progress": 0, + "sub_progress": 0, + "progress_detail": "", "error": "", "request": job_request, "cancel_event": cancel_event, @@ -694,16 +731,26 @@ def get_file_badge(self, source_file_id): status = job.get("status") progress = int(job.get("progress", 0)) + sub_progress = int(job.get("sub_progress", 0)) + detail = str(job.get("progress_detail", "") or "").strip() if status == "queued": label = "Queued" elif status == "running": label = "Generating {}%".format(progress) + if detail: + label = "{} ({})".format(label, detail) elif status == "canceling": label = "Canceling..." 
else: label = status.capitalize() - return {"status": status, "progress": progress, "label": label, "job_id": job.get("id")} + return { + "status": status, + "progress": progress, + "sub_progress": sub_progress, + "label": label, + "job_id": job.get("id"), + } def shutdown(self): if self._thread.isRunning(): @@ -725,6 +772,7 @@ def _start_next_if_idle(self): self._running_job_id = next_job_id job["status"] = "running" job["progress"] = int(job.get("progress", 0)) + job["sub_progress"] = int(job.get("sub_progress", 0)) self.job_updated.emit(next_job_id, "running", int(job["progress"])) self._emit_file_changed(job.get("source_file_id", "")) self.queue_changed.emit() @@ -752,6 +800,36 @@ def _on_progress_changed(self, job_id, progress): self._emit_file_changed(job.get("source_file_id", "")) self.queue_changed.emit() + @pyqtSlot(str, str) + def _on_progress_detail_changed(self, job_id, detail): + job = self.jobs.get(job_id) + if not job: + return + if job.get("status") not in ("running", "canceling"): + return + detail_text = str(detail or "").strip() + if str(job.get("progress_detail", "") or "") == detail_text: + return + job["progress_detail"] = detail_text + self.job_updated.emit(job_id, job.get("status"), int(job.get("progress", 0))) + self._emit_file_changed(job.get("source_file_id", "")) + self.queue_changed.emit() + + @pyqtSlot(str, int) + def _on_progress_sub_changed(self, job_id, progress): + job = self.jobs.get(job_id) + if not job: + return + if job.get("status") not in ("running", "canceling"): + return + p = int(max(0, min(99, progress))) + if int(job.get("sub_progress", 0)) == p: + return + job["sub_progress"] = p + self.job_updated.emit(job_id, job.get("status"), int(job.get("progress", 0))) + self._emit_file_changed(job.get("source_file_id", "")) + self.queue_changed.emit() + @pyqtSlot(str, bool, bool, str, object) def _on_job_finished(self, job_id, success, canceled, error, outputs): job = self.jobs.get(job_id) diff --git 
a/src/classes/generation_service.py b/src/classes/generation_service.py index 5f2532cc5..e4b8385d0 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -374,6 +374,13 @@ def _select_bind_nodes(node_ids, path_keys, preferred_upload=None): class_flat = class_type.lower().strip() + # Resolve generic OpenShot source placeholders in any string input + # (custom nodes may use keys like `video_path` instead of `video`/`file`). + if source_path: + for input_key, input_value in list(inputs.items()): + if isinstance(input_value, str) and _is_placeholder_value(input_value): + inputs[input_key] = source_path + if "filename_prefix" in inputs: prefix_value = str(inputs.get("filename_prefix", "")).strip() if "/" in prefix_value: @@ -402,7 +409,12 @@ def _select_bind_nodes(node_ids, path_keys, preferred_upload=None): coords_value = inputs.get("coordinates_positive", None) if ( - class_flat in ("sam2videosegmentationaddpoints", "sam2segmentation") + class_flat in ( + "sam2videosegmentationaddpoints", + "sam2segmentation", + "openshotsam2videosegmentationaddpoints", + "openshotsam2segmentation", + ) and "coordinates_positive" in inputs and isinstance(coords_value, str) ): @@ -411,7 +423,7 @@ def _select_bind_nodes(node_ids, path_keys, preferred_upload=None): if ("blur-anything-sam2" in template_id) and (not points): raise ValueError("No SAM2 points were provided. Use Mask > Pick Point(s) on Source.") inputs["coordinates_positive"] = _normalize_sam2_coords_input(coords_text, coords_value) - if class_flat == "sam2segmentation" and "individual_objects" in inputs: + if class_flat in ("sam2segmentation", "openshotsam2segmentation") and "individual_objects" in inputs: # For Blur Anything, treat points as a single combined prompt. # This is more stable with mixed positive/negative points and avoids # per-object mask selection quirks in the current SAM2 single-image node. 
diff --git a/src/comfyui/image-blur-anything-sam2.json b/src/comfyui/image-blur-anything-sam2.json index 624df0b08..39829377b 100644 --- a/src/comfyui/image-blur-anything-sam2.json +++ b/src/comfyui/image-blur-anything-sam2.json @@ -15,7 +15,7 @@ } }, "2": { - "class_type": "DownloadAndLoadSAM2Model", + "class_type": "OpenShotDownloadAndLoadSAM2Model", "inputs": { "model": "sam2.1_hiera_small.safetensors", "segmentor": "single_image", @@ -24,7 +24,7 @@ } }, "3": { - "class_type": "Sam2Segmentation", + "class_type": "OpenShotSam2Segmentation", "inputs": { "sam2_model": [ "2", @@ -35,7 +35,6 @@ 0 ], "coordinates_positive": "", - "individual_objects": true, "keep_model_loaded": false } }, diff --git a/src/comfyui/video-blur-anything-sam2.json b/src/comfyui/video-blur-anything-sam2.json index ef58f3c0a..f4c1e0f69 100644 --- a/src/comfyui/video-blur-anything-sam2.json +++ b/src/comfyui/video-blur-anything-sam2.json @@ -29,32 +29,31 @@ } }, "3": { - "class_type": "DownloadAndLoadSAM2Model", + "class_type": "OpenShotDownloadAndLoadSAM2Model", "inputs": { - "model": "sam2.1_hiera_tiny.safetensors", + "model": "sam2.1_hiera_base_plus.safetensors", "segmentor": "video", "device": "cuda", "precision": "fp16" } }, "4": { - "class_type": "Sam2VideoSegmentationAddPoints", + "class_type": "OpenShotSam2VideoSegmentationAddPoints", "inputs": { "sam2_model": [ "3", 0 ], - "image": [ - "1", - 0 - ], "coordinates_positive": "", "frame_index": 0, - "object_index": 0 + "object_index": 0, + "windowed_mode": true, + "offload_video_to_cpu": false, + "offload_state_to_cpu": false } }, "5": { - "class_type": "Sam2VideoSegmentation", + "class_type": "OpenShotSam2VideoSegmentationChunked", "inputs": { "sam2_model": [ "4", @@ -64,7 +63,17 @@ "4", 1 ], - "keep_model_loaded": true + "image": [ + "10", + 0 + ], + "start_frame": 0, + "chunk_size_frames": 96, + "keep_model_loaded": true, + "meta_batch": [ + "9", + 0 + ] } }, "6": { @@ -91,7 +100,11 @@ "filename_prefix": "video/openshot_mask", 
"format": "video/ffv1-mkv", "pingpong": false, - "save_output": true + "save_output": true, + "meta_batch": [ + "9", + 0 + ] } }, "8": { @@ -107,7 +120,7 @@ "9": { "class_type": "VHS_BatchManager", "inputs": { - "frames_per_batch": 32 + "frames_per_batch": 96 } }, "10": { @@ -156,12 +169,16 @@ } }, "13": { - "class_type": "ImageBlur", + "class_type": "OpenShotImageBlurMasked", "inputs": { "image": [ "10", 0 ], + "mask": [ + "5", + 0 + ], "blur_radius": 12, "sigma": 4.0 } @@ -178,7 +195,7 @@ 0 ], "mask": [ - "12", + "5", 0 ], "x": 0, @@ -190,7 +207,7 @@ "class_type": "VHS_VideoCombine", "inputs": { "images": [ - "14", + "13", 0 ], "frame_rate": [ diff --git a/src/windows/models/files_model.py b/src/windows/models/files_model.py index 56a2b5349..b78b7a057 100644 --- a/src/windows/models/files_model.py +++ b/src/windows/models/files_model.py @@ -852,9 +852,12 @@ def _add_generation_placeholder(self, job_id): name = str(job.get("name") or "generation") status = str(job.get("status") or "queued") progress = int(job.get("progress", 0)) + progress_detail = str(job.get("progress_detail") or "").strip() label = name if status == "running": label = "{} ({}%)".format(name, progress) + if progress_detail: + label = "{} [{}]".format(label, progress_detail) elif status == "queued": label = "{} (Queued)".format(name) elif status == "canceling": @@ -904,9 +907,12 @@ def _update_generation_placeholder(self, job_id): name = str(job.get("name") or "generation") status = str(job.get("status") or "queued") progress = int(job.get("progress", 0)) + progress_detail = str(job.get("progress_detail") or "").strip() label = name if status == "running": label = "{} ({}%)".format(name, progress) + if progress_detail: + label = "{} [{}]".format(label, progress_detail) elif status == "queued": label = "{} (Queued)".format(name) elif status == "canceling": From 8e7fa85517088a4acc1f1082c08acf666e14ad39 Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Fri, 20 Feb 2026 19:52:22 -0600 Subject: 
[PATCH 14/27] - Replaced built-in txt2img-basic and txt2video-svd workflows with WAN-based graphs (UNET/CLIP/VAE loaders, ModelSamplingSD3, WAN model paths), and set video defaults to 832x480, 16 fps, length: 64 (~4s). - Switched these template IDs (and img2video-svd) off legacy Python workflow generation so OpenShot now uses JSON templates directly at runtime. - Added WAN node types to template classification, renamed the image-to-video menu label to Image to Video (WAN 2.2 TI2V), and added/used a new img2video-svd WAN 2.2 TI2V template in the Enhance flow. --- src/classes/comfy_pipelines.py | 2 +- src/classes/comfy_templates.py | 5 + src/classes/generation_service.py | 3 - src/comfyui/img2video-svd.json | 148 +++++++++++++++++++++++ src/comfyui/txt2img-basic.json | 127 ++++++++++++-------- src/comfyui/txt2video-svd.json | 193 +++++++++++------------------- 6 files changed, 300 insertions(+), 178 deletions(-) create mode 100644 src/comfyui/img2video-svd.json diff --git a/src/classes/comfy_pipelines.py b/src/classes/comfy_pipelines.py index fca53a48b..93ca0661e 100644 --- a/src/classes/comfy_pipelines.py +++ b/src/classes/comfy_pipelines.py @@ -75,7 +75,7 @@ def available_pipelines(source_file=None): if _supports_img2img(source_file): pipelines.insert(0, {"id": "img2img-basic", "name": "Basic Image Variation"}) pipelines.insert(1, {"id": "upscale-realesrgan-x4", "name": "Upscale Image (RealESRGAN x4)"}) - pipelines.insert(2, {"id": "img2video-svd", "name": "Image to Video (img_to_video)"}) + pipelines.insert(2, {"id": "img2video-svd", "name": "Image to Video (WAN 2.2 TI2V)"}) if _supports_video_upscale(source_file): pipelines.append({"id": "video-segment-scenes-transnet", "name": "Segment Scenes (TransNetV2)"}) pipelines.append({"id": "video-frame-interpolation-rife2x", "name": "Frame Interpolation (RIFE 2x FPS)"}) diff --git a/src/classes/comfy_templates.py b/src/classes/comfy_templates.py index 48df519c6..285b18384 100644 --- a/src/classes/comfy_templates.py +++ 
b/src/classes/comfy_templates.py @@ -41,8 +41,10 @@ KNOWN_NODE_TYPES = { # Input "checkpointloadersimple", + "unetloader", "cliptextencode", "cliploader", + "vaeloader", "loadimage", "loadvideo", "vhs_loadvideo", @@ -62,7 +64,10 @@ "saveaudio", "save srt", "emptylatentimage", + "emptyhunyuanlatentvideo", + "wan22imagetovideolatent", "imageonlycheckpointloader", + "modelsamplingsd3", "svd_img2vid_conditioning", "videolinearcfgguidance", "emptylatentaudio", diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index e4b8385d0..92376f199 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -66,12 +66,9 @@ class GenerationService: """Encapsulates generation-specific UI + workflow behavior.""" LEGACY_PIPELINE_IDS = { - "txt2img-basic", - "txt2video-svd", "txt2audio-stable-open", "img2img-basic", "upscale-realesrgan-x4", - "img2video-svd", "video-segment-scenes-transnet", "video-frame-interpolation-rife2x", "video-upscale-gan", diff --git a/src/comfyui/img2video-svd.json b/src/comfyui/img2video-svd.json new file mode 100644 index 000000000..de8acc653 --- /dev/null +++ b/src/comfyui/img2video-svd.json @@ -0,0 +1,148 @@ +{ + "action_icon": "ai-action-create-video.svg", + "menu_category": "enhance", + "menu_order": 21, + "name": "Image to Video...", + "open_dialog": true, + "output_type": "video", + "template_id": "img2video-svd", + "workflow": { + "2": { + "class_type": "LoadImage", + "inputs": { + "image": "__openshot_input__", + "upload": "image" + } + }, + "3": { + "class_type": "KSampler", + "inputs": { + "seed": 82628696717253, + "steps": 30, + "cfg": 6, + "sampler_name": "uni_pc", + "scheduler": "simple", + "denoise": 1, + "model": [ + "48", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "40", + 0 + ] + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "text": "__openshot_prompt__", + "clip": [ + "38", + 0 + ] + } + }, + "7": { + 
"class_type": "CLIPTextEncode", + "inputs": { + "text": "low quality, blurry, overexposed, static scene, washed out, text subtitles, watermark, logo, distorted hands, deformed face, bad anatomy", + "clip": [ + "38", + 0 + ] + } + }, + "8": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "39", + 0 + ] + } + }, + "37": { + "class_type": "UNETLoader", + "inputs": { + "unet_name": "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors", + "weight_dtype": "default" + } + }, + "38": { + "class_type": "CLIPLoader", + "inputs": { + "clip_name": "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", + "type": "wan", + "device": "default" + } + }, + "39": { + "class_type": "VAELoader", + "inputs": { + "vae_name": "wan2.2_vae.safetensors" + } + }, + "40": { + "class_type": "Wan22ImageToVideoLatent", + "inputs": { + "start_image": [ + "2", + 0 + ], + "vae": [ + "39", + 0 + ], + "width": 832, + "height": 480, + "length": 64, + "batch_size": 1 + } + }, + "48": { + "class_type": "ModelSamplingSD3", + "inputs": { + "shift": 8, + "model": [ + "37", + 0 + ] + } + }, + "49": { + "class_type": "CreateVideo", + "inputs": { + "fps": 16, + "images": [ + "8", + 0 + ] + } + }, + "50": { + "class_type": "SaveVideo", + "inputs": { + "filename_prefix": "video/openshot_gen", + "format": "auto", + "codec": "auto", + "video": [ + "49", + 0 + ] + } + } + } +} diff --git a/src/comfyui/txt2img-basic.json b/src/comfyui/txt2img-basic.json index 5a656854b..7803472b3 100644 --- a/src/comfyui/txt2img-basic.json +++ b/src/comfyui/txt2img-basic.json @@ -7,86 +7,111 @@ "output_type": "image", "template_id": "txt2img-basic", "workflow": { - "1": { - "class_type": "CheckpointLoaderSimple", + "3": { + "class_type": "KSampler", "inputs": { - "ckpt_name": "sd_xl_turbo_1.0_fp16.safetensors" + "seed": 82628696717253, + "steps": 30, + "cfg": 6, + "sampler_name": "uni_pc", + "scheduler": "simple", + "denoise": 1, + "model": [ + "48", + 0 + ], + 
"positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "40", + 0 + ] } }, - "2": { + "6": { "class_type": "CLIPTextEncode", "inputs": { + "text": "__openshot_prompt__", "clip": [ - "1", - 1 - ], - "text": "cinematic shot, highly detailed" + "38", + 0 + ] } }, - "3": { + "7": { "class_type": "CLIPTextEncode", "inputs": { + "text": "low quality, blurry, overexposed, static scene, washed out, text subtitles, watermark, logo, distorted hands, deformed face, bad anatomy", "clip": [ - "1", - 1 - ], - "text": "low quality, blurry" - } - }, - "4": { - "class_type": "EmptyLatentImage", - "inputs": { - "batch_size": 1, - "height": 576, - "width": 1024 + "38", + 0 + ] } }, - "5": { - "class_type": "KSampler", + "8": { + "class_type": "VAEDecode", "inputs": { - "cfg": 7.0, - "denoise": 1.0, - "latent_image": [ - "4", - 0 - ], - "model": [ - "1", - 0 - ], - "negative": [ + "samples": [ "3", 0 ], - "positive": [ - "2", + "vae": [ + "39", 0 - ], - "sampler_name": "euler", - "scheduler": "normal", - "seed": 687962524, - "steps": 20 + ] } }, - "6": { - "class_type": "VAEDecode", + "37": { + "class_type": "UNETLoader", "inputs": { - "samples": [ - "5", + "unet_name": "split_files/diffusion_models/wan2.1_t2v_14B_fp8_scaled.safetensors", + "weight_dtype": "default" + } + }, + "38": { + "class_type": "CLIPLoader", + "inputs": { + "clip_name": "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", + "type": "wan", + "device": "default" + } + }, + "39": { + "class_type": "VAELoader", + "inputs": { + "vae_name": "wan_2.1_vae.safetensors" + } + }, + "40": { + "class_type": "EmptyLatentImage", + "inputs": { + "width": 832, + "height": 480, + "batch_size": 1 + } + }, + "48": { + "class_type": "ModelSamplingSD3", + "inputs": { + "shift": 8, + "model": [ + "37", 0 - ], - "vae": [ - "1", - 2 ] } }, - "7": { + "50": { "class_type": "SaveImage", "inputs": { "filename_prefix": "openshot_gen", "images": [ - "6", + "8", 0 ] } diff --git 
a/src/comfyui/txt2video-svd.json b/src/comfyui/txt2video-svd.json index 73deac8e1..73af62674 100644 --- a/src/comfyui/txt2video-svd.json +++ b/src/comfyui/txt2video-svd.json @@ -7,177 +7,124 @@ "output_type": "video", "template_id": "txt2video-svd", "workflow": { - "1": { - "class_type": "ImageOnlyCheckpointLoader", - "inputs": { - "ckpt_name": "svd_xt.safetensors" - } - }, - "10": { + "3": { "class_type": "KSampler", "inputs": { - "cfg": 2.5, - "denoise": 1.0, - "latent_image": [ - "8", - 2 - ], + "seed": 82628696717253, + "steps": 30, + "cfg": 6, + "sampler_name": "uni_pc", + "scheduler": "simple", + "denoise": 1, "model": [ - "9", + "48", 0 ], - "negative": [ - "8", - 1 - ], "positive": [ - "8", + "6", 0 ], - "sampler_name": "euler", - "scheduler": "karras", - "seed": 1825708738, - "steps": 10 - } - }, - "11": { - "class_type": "VAEDecode", - "inputs": { - "samples": [ - "10", + "negative": [ + "7", 0 ], - "vae": [ - "1", - 2 + "latent_image": [ + "40", + 0 ] } }, - "12": { - "class_type": "CreateVideo", + "6": { + "class_type": "CLIPTextEncode", "inputs": { - "fps": 12, - "images": [ - "11", + "text": "__openshot_prompt__", + "clip": [ + "38", 0 ] } }, - "13": { - "class_type": "SaveVideo", + "7": { + "class_type": "CLIPTextEncode", "inputs": { - "codec": "auto", - "filename_prefix": "video/openshot_gen", - "format": "auto", - "video": [ - "12", + "text": "low quality, blurry, overexposed, static scene, washed out, text subtitles, watermark, logo, distorted hands, deformed face, bad anatomy", + "clip": [ + "38", 0 ] } }, - "2": { - "class_type": "CheckpointLoaderSimple", + "8": { + "class_type": "VAEDecode", "inputs": { - "ckpt_name": "sd_xl_turbo_1.0_fp16.safetensors" + "samples": [ + "3", + 0 + ], + "vae": [ + "39", + 0 + ] } }, - "3": { - "class_type": "CLIPTextEncode", + "37": { + "class_type": "UNETLoader", "inputs": { - "clip": [ - "2", - 1 - ], - "text": "cinematic shot, highly detailed" + "unet_name": 
"split_files/diffusion_models/wan2.1_t2v_1.3B_fp16.safetensors", + "weight_dtype": "default" } }, - "4": { - "class_type": "CLIPTextEncode", + "38": { + "class_type": "CLIPLoader", "inputs": { - "clip": [ - "2", - 1 - ], - "text": "low quality, blurry" + "clip_name": "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", + "type": "wan", + "device": "default" } }, - "5": { - "class_type": "EmptyLatentImage", + "39": { + "class_type": "VAELoader", "inputs": { - "batch_size": 1, - "height": 288, - "width": 512 + "vae_name": "wan_2.1_vae.safetensors" } }, - "6": { - "class_type": "KSampler", + "40": { + "class_type": "EmptyHunyuanLatentVideo", "inputs": { - "cfg": 6.0, - "denoise": 1.0, - "latent_image": [ - "5", - 0 - ], - "model": [ - "2", - 0 - ], - "negative": [ - "4", - 0 - ], - "positive": [ - "3", - 0 - ], - "sampler_name": "euler", - "scheduler": "normal", - "seed": 1825708737, - "steps": 8 + "width": 832, + "height": 480, + "length": 64, + "batch_size": 1 } }, - "7": { - "class_type": "VAEDecode", + "48": { + "class_type": "ModelSamplingSD3", "inputs": { - "samples": [ - "6", + "shift": 8, + "model": [ + "37", 0 - ], - "vae": [ - "2", - 2 ] } }, - "8": { - "class_type": "SVD_img2vid_Conditioning", + "49": { + "class_type": "CreateVideo", "inputs": { - "augmentation_level": 0.0, - "clip_vision": [ - "1", - 1 - ], - "fps": 12, - "height": 288, - "init_image": [ - "7", + "fps": 16, + "images": [ + "8", 0 - ], - "motion_bucket_id": 127, - "vae": [ - "1", - 2 - ], - "video_frames": 24, - "width": 512 + ] } }, - "9": { - "class_type": "VideoLinearCFGGuidance", + "50": { + "class_type": "SaveVideo", "inputs": { - "min_cfg": 1.0, - "model": [ - "1", + "filename_prefix": "video/openshot_gen", + "format": "auto", + "codec": "auto", + "video": [ + "49", 0 ] } From 6803d8cf05cfa62dc6fcb587adbf3e7ee0af4ded Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Fri, 20 Feb 2026 20:17:06 -0600 Subject: [PATCH 15/27] Replaces the WAN-based image workflow with a 
standard SDXL graph (CheckpointLoaderSimple, CLIPTextEncode, KSampler, VAEDecode, SaveImage). - The image model/settings are now quality-oriented SDXL Base: ckpt_name = sd_xl_base_1.0.safetensors, steps = 28, and cfg = 6.5. --- src/comfyui/txt2img-basic.json | 127 +++++++++++++-------------------- 1 file changed, 51 insertions(+), 76 deletions(-) diff --git a/src/comfyui/txt2img-basic.json b/src/comfyui/txt2img-basic.json index 7803472b3..5302b4108 100644 --- a/src/comfyui/txt2img-basic.json +++ b/src/comfyui/txt2img-basic.json @@ -7,111 +7,86 @@ "output_type": "image", "template_id": "txt2img-basic", "workflow": { - "3": { - "class_type": "KSampler", + "1": { + "class_type": "CheckpointLoaderSimple", "inputs": { - "seed": 82628696717253, - "steps": 30, - "cfg": 6, - "sampler_name": "uni_pc", - "scheduler": "simple", - "denoise": 1, - "model": [ - "48", - 0 - ], - "positive": [ - "6", - 0 - ], - "negative": [ - "7", - 0 - ], - "latent_image": [ - "40", - 0 - ] + "ckpt_name": "sd_xl_base_1.0.safetensors" } }, - "6": { + "2": { "class_type": "CLIPTextEncode", "inputs": { - "text": "__openshot_prompt__", "clip": [ - "38", - 0 - ] + "1", + 1 + ], + "text": "cinematic shot, highly detailed" } }, - "7": { + "3": { "class_type": "CLIPTextEncode", "inputs": { - "text": "low quality, blurry, overexposed, static scene, washed out, text subtitles, watermark, logo, distorted hands, deformed face, bad anatomy", "clip": [ - "38", - 0 - ] - } - }, - "8": { - "class_type": "VAEDecode", - "inputs": { - "samples": [ - "3", - 0 + "1", + 1 ], - "vae": [ - "39", - 0 - ] - } - }, - "37": { - "class_type": "UNETLoader", - "inputs": { - "unet_name": "split_files/diffusion_models/wan2.1_t2v_14B_fp8_scaled.safetensors", - "weight_dtype": "default" + "text": "low quality, blurry" } }, - "38": { - "class_type": "CLIPLoader", - "inputs": { - "clip_name": "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", - "type": "wan", - "device": "default" - } - }, - "39": { - "class_type": 
"VAELoader", + "4": { + "class_type": "EmptyLatentImage", "inputs": { - "vae_name": "wan_2.1_vae.safetensors" + "batch_size": 1, + "height": 576, + "width": 1024 } }, - "40": { - "class_type": "EmptyLatentImage", + "5": { + "class_type": "KSampler", "inputs": { - "width": 832, - "height": 480, - "batch_size": 1 + "cfg": 6.5, + "denoise": 1.0, + "latent_image": [ + "4", + 0 + ], + "model": [ + "1", + 0 + ], + "negative": [ + "3", + 0 + ], + "positive": [ + "2", + 0 + ], + "sampler_name": "euler", + "scheduler": "normal", + "seed": 687962524, + "steps": 28 } }, - "48": { - "class_type": "ModelSamplingSD3", + "6": { + "class_type": "VAEDecode", "inputs": { - "shift": 8, - "model": [ - "37", + "samples": [ + "5", 0 + ], + "vae": [ + "1", + 2 ] } }, - "50": { + "7": { "class_type": "SaveImage", "inputs": { "filename_prefix": "openshot_gen", "images": [ - "8", + "6", 0 ] } From 6aa3abd20a2ef37b7cdac83c2738d47b9be65f79 Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Fri, 20 Feb 2026 20:33:56 -0600 Subject: [PATCH 16/27] - Comfy scene splitting now returns range metadata and OpenShot creates split-style file entries from it, reusing the original source path with start/end trims (no duplicate segment MP4 imports). - Scene split file names now match Split File dialog format exactly (base (start to end)), with only start and end attributes applied. 
--- src/classes/comfy_pipelines.py | 8 +- src/classes/generation_service.py | 158 +++++++++++++++++++++++++++--- 2 files changed, 150 insertions(+), 16 deletions(-) diff --git a/src/classes/comfy_pipelines.py b/src/classes/comfy_pipelines.py index 93ca0661e..02ce8c405 100644 --- a/src/classes/comfy_pipelines.py +++ b/src/classes/comfy_pipelines.py @@ -297,17 +297,17 @@ def build_workflow( }, "8": { "inputs": { - "index": 0, "segment_paths": ["1", 0], + "source_video_path": source_path, }, - "class_type": "SelectVideo", - "_meta": {"title": "MiaoshouAI Select Video"}, + "class_type": "OpenShotSceneRangesFromSegments", + "_meta": {"title": "OpenShot Build Scene Ranges"}, }, "9": { "inputs": { "preview": "", "previewMode": None, - "source": ["1", 0], + "source": ["8", 0], }, "class_type": "PreviewAny", "_meta": {"title": "Preview Any"}, diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index 92376f199..4b626d913 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -785,11 +785,17 @@ def on_generation_job_finished(self, job_id, status): imported = int(result.get("imported", 0)) caption_saved = bool(result.get("caption_saved", False)) scenes_labeled = int(result.get("scenes_labeled", 0)) + scene_splits_created = int(result.get("scene_splits_created", 0)) if imported > 0 and caption_saved: self.win.statusBar.showMessage( "Generation completed, imported {} file(s), and saved file caption data".format(imported), 5000, ) + elif scene_splits_created > 0: + self.win.statusBar.showMessage( + "Generation completed and created {} scene split file(s)".format(scene_splits_created), + 5000, + ) elif imported > 0 and scenes_labeled > 0: self.win.statusBar.showMessage( "Generation completed, imported {} file(s), and labeled {} scene segment(s)".format( @@ -817,13 +823,14 @@ def on_generation_job_finished(self, job_id, status): def _import_generation_outputs(self, job): outputs = list(job.get("outputs", []) or []) 
if not outputs: - return {"imported": 0, "caption_saved": False} + return {"imported": 0, "caption_saved": False, "scene_splits_created": 0} request = job.get("request", {}) or {} comfy_url = str(request.get("comfy_url") or self.comfy_ui_url()) client = ComfyClient(comfy_url) output_dir = info.COMFYUI_OUTPUT_PATH os.makedirs(output_dir, exist_ok=True) + template_id = str(job.get("template_id") or "").strip().lower() name_raw = str(job.get("name") or "generation") safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", name_raw).strip("._") @@ -832,11 +839,22 @@ def _import_generation_outputs(self, job): saved_paths = [] text_outputs = [] + scene_splits_created = 0 video_path_exts = {".mp4", ".mov", ".mkv", ".webm", ".avi", ".m4v"} seen_video_payload_paths = set() for index, output_ref in enumerate(outputs, start=1): text_payload = str(output_ref.get("text", "")).strip() if text_payload: + if template_id == "video-segment-scenes-transnet": + scene_ranges = self._extract_scene_ranges_from_text(text_payload) + if scene_ranges: + created = self._create_scene_split_files( + source_file_id=job.get("source_file_id"), + scene_ranges=scene_ranges, + ) + scene_splits_created += created + if created > 0: + continue payload_video_paths = self._extract_video_paths_from_text(text_payload) if not payload_video_paths: payload_ext = os.path.splitext(text_payload)[1].lower() @@ -908,30 +926,36 @@ def _import_generation_outputs(self, job): except Exception as ex: log.warning("Failed to download Comfy output %s: %s", output_ref, ex) - if not saved_paths: - return {"imported": 0, "caption_saved": False} + if saved_paths: + self.win.files_model.add_files( + saved_paths, + quiet=True, + prevent_image_seq=True, + prevent_recent_folder=True, + ) - self.win.files_model.add_files( - saved_paths, - quiet=True, - prevent_image_seq=True, - prevent_recent_folder=True, - ) + if not saved_paths and scene_splits_created <= 0: + return {"imported": 0, "caption_saved": False, "scene_splits_created": 0} 
caption_saved = False scenes_labeled = 0 - if str(job.get("template_id") or "") == "video-whisper-srt": + if template_id == "video-whisper-srt": caption_text = self._resolve_caption_text(saved_paths, text_outputs) caption_saved = self._store_caption_on_file( source_file_id=job.get("source_file_id"), caption_text=caption_text, ) - if str(job.get("template_id") or "") == "video-segment-scenes-transnet": + if template_id == "video-segment-scenes-transnet" and saved_paths: scenes_labeled = self._apply_scene_segment_metadata( source_file_id=job.get("source_file_id"), saved_paths=saved_paths, ) - return {"imported": len(saved_paths), "caption_saved": caption_saved, "scenes_labeled": scenes_labeled} + return { + "imported": len(saved_paths), + "caption_saved": caption_saved, + "scenes_labeled": scenes_labeled, + "scene_splits_created": scene_splits_created, + } def _extract_video_paths_from_text(self, text_payload): """Extract absolute video file paths from log/text payloads.""" @@ -944,6 +968,116 @@ def _extract_video_paths_from_text(self, text_payload): ) return [match.strip() for match in pattern.findall(text_payload) if match.strip()] + def _extract_scene_ranges_from_text(self, text_payload): + """Parse scene range metadata JSON from text output payloads.""" + text_payload = str(text_payload or "").strip() + if not text_payload: + return [] + if not text_payload.startswith("{"): + first = text_payload.find("{") + last = text_payload.rfind("}") + if first >= 0 and last > first: + text_payload = text_payload[first:last + 1] + else: + return [] + try: + payload = json.loads(text_payload) + except Exception: + return [] + + segment_entries = payload.get("segments") if isinstance(payload, dict) else None + if not isinstance(segment_entries, list): + return [] + + scene_ranges = [] + for segment in segment_entries: + if not isinstance(segment, dict): + continue + start_seconds = segment.get("start_seconds", segment.get("start")) + end_seconds = segment.get("end_seconds", 
segment.get("end")) + try: + start_value = float(start_seconds) + end_value = float(end_seconds) + except (TypeError, ValueError): + continue + if end_value <= start_value: + continue + scene_ranges.append((max(0.0, start_value), max(0.0, end_value))) + return scene_ranges + + def _create_scene_split_files(self, source_file_id, scene_ranges): + """Create split-style file entries pointing to the same source media path.""" + source_file = File.get(id=source_file_id) if source_file_id else None + if source_file is None: + source_file = File.get(id=str(source_file_id or "")) + if source_file is None: + return 0 + + source_data = source_file.data if isinstance(source_file.data, dict) else {} + source_path = str(source_data.get("path", "") or "") + if not source_path: + return 0 + if str(source_data.get("media_type", "")).lower() != "video": + return 0 + + fps_data = source_data.get("fps", {}) + fps_fraction = Fraction(30, 1) + try: + fps_num = int(fps_data.get("num", 30)) + fps_den = int(fps_data.get("den", 1) or 1) + if fps_num > 0 and fps_den > 0: + fps_fraction = Fraction(fps_num, fps_den) + except (TypeError, ValueError, ZeroDivisionError): + fps_fraction = Fraction(30, 1) + + source_duration = float(source_data.get("duration") or 0.0) + base_name = os.path.splitext(os.path.basename(source_path))[0] or "scene" + + created = 0 + for start_seconds, end_seconds in scene_ranges: + start_seconds = max(0.0, float(start_seconds or 0.0)) + end_seconds = max(start_seconds, float(end_seconds or start_seconds)) + if source_duration > 0.0: + start_seconds = min(start_seconds, source_duration) + end_seconds = min(end_seconds, source_duration) + if end_seconds <= start_seconds: + continue + + include_hours = int(end_seconds // 3600) > 0 + include_minutes = include_hours or int((end_seconds % 3600) // 60) > 0 + start_tc = self._seconds_to_compact_timecode( + start_seconds, + fps_fraction, + include_hours=include_hours, + include_minutes=include_minutes, + ) + end_tc = 
self._seconds_to_compact_timecode( + end_seconds, + fps_fraction, + include_hours=include_hours, + include_minutes=include_minutes, + ) + + split_data = json.loads(json.dumps(source_data)) + split_data.pop("id", None) + split_data["start"] = start_seconds + split_data["end"] = end_seconds + split_data["name"] = "{} ({} to {})".format( + base_name, + start_tc, + end_tc, + ) + + split_file = File() + split_file.id = None + split_file.key = None + split_file.type = "insert" + split_file.data = split_data + self._append_scene_tag(split_file) + split_file.save() + created += 1 + return created + def _resolve_caption_text(self, saved_paths, text_outputs): srt_path = "" for path in saved_paths: From e959c0767cc0165939c509473417190c3b03ac3d Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Sat, 21 Feb 2026 11:53:50 -0600 Subject: [PATCH 17/27] - Region UI now supports multi-tool annotations (pos/neg points + rects), frame markers, and clear-all, plus new toolbar icons. - Generate + SAM2 workflows now pass full tracking JSON (seed frame + per-frame points/rects/auto) instead of seed-only inputs. - Tracker preprocess flow was fixed (dialog parenting + guards/imports) to avoid region selection crashes/regressions. 
--- src/classes/comfy_pipelines.py | 30 +- src/classes/generation_service.py | 98 +++- src/comfyui/image-blur-anything-sam2.json | 6 +- src/comfyui/video-blur-anything-sam2.json | 9 +- .../cosmic/images/ai-track-point-negative.svg | 4 + .../cosmic/images/ai-track-point-positive.svg | 5 + .../cosmic/images/ai-track-rect-negative.svg | 4 + .../cosmic/images/ai-track-rect-positive.svg | 5 + src/windows/generate.py | 166 ++++-- src/windows/process_effect.py | 54 +- src/windows/region.py | 488 +++++++++++++++++- src/windows/video_widget.py | 109 +++- 12 files changed, 860 insertions(+), 118 deletions(-) create mode 100644 src/themes/cosmic/images/ai-track-point-negative.svg create mode 100644 src/themes/cosmic/images/ai-track-point-positive.svg create mode 100644 src/themes/cosmic/images/ai-track-rect-negative.svg create mode 100644 src/themes/cosmic/images/ai-track-rect-positive.svg diff --git a/src/classes/comfy_pipelines.py b/src/classes/comfy_pipelines.py index 02ce8c405..9c2fd816d 100644 --- a/src/classes/comfy_pipelines.py +++ b/src/classes/comfy_pipelines.py @@ -275,39 +275,21 @@ def build_workflow( if not source_path: raise ValueError("A source video is required for this pipeline.") return { - "7": {"inputs": {"file": source_path}, "class_type": "LoadVideo"}, - "2": { - "inputs": { - "model": "transnetv2-pytorch-weights", - "device": "auto", - }, - "class_type": "DownloadAndLoadTransNetModel", - "_meta": {"title": "MiaoshouAI Load TransNet Model"}, - }, "1": { "inputs": { - "threshold": 0.5, - "min_scene_length": 30, - "output_dir": "output", - "TransNet_model": ["2", 0], - "video": ["7", 0], - }, - "class_type": "TransNetV2_Run", - "_meta": {"title": "MiaoshouAI Segment Video"}, - }, - "8": { - "inputs": { - "segment_paths": ["1", 0], "source_video_path": source_path, + "threshold": 0.5, + "min_scene_length_frames": 30, + "device": "auto", }, - "class_type": "OpenShotSceneRangesFromSegments", - "_meta": {"title": "OpenShot Build Scene Ranges"}, + "class_type": 
"OpenShotTransNetSceneDetect", + "_meta": {"title": "OpenShot TransNet Scene Detect"}, }, "9": { "inputs": { "preview": "", "previewMode": None, - "source": ["8", 0], + "source": ["1", 0], }, "class_type": "PreviewAny", "_meta": {"title": "Preview Any"}, diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index 4b626d913..c4b5802b6 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -222,6 +222,10 @@ def _prepare_nonlegacy_workflow( source_path, coordinates_positive_text="", coordinates_negative_text="", + rectangles_positive_text="", + rectangles_negative_text="", + auto_mode=False, + tracking_selection=None, ): workflow = self.template_registry.get_workflow_copy(template.get("id")) if not workflow: @@ -249,6 +253,10 @@ def _resolve_template_local_file(path_text): prompt_text = str(prompt_text or "").strip() coordinates_positive_text = str(coordinates_positive_text or "").strip() coordinates_negative_text = str(coordinates_negative_text or "").strip() + rectangles_positive_text = str(rectangles_positive_text or "").strip() + rectangles_negative_text = str(rectangles_negative_text or "").strip() + auto_mode = bool(auto_mode) + tracking_selection = tracking_selection if isinstance(tracking_selection, dict) else {} media_type = str(source_file.data.get("media_type", "")).strip().lower() if source_file else "" applied_prompt = False loadimage_node_ids = [] @@ -404,22 +412,52 @@ def _select_bind_nodes(node_ids, path_keys, preferred_upload=None): inputs["prompt"] = prompt_text applied_prompt = True - coords_value = inputs.get("coordinates_positive", None) - if ( - class_flat in ( - "sam2videosegmentationaddpoints", - "sam2segmentation", - "openshotsam2videosegmentationaddpoints", - "openshotsam2segmentation", - ) - and "coordinates_positive" in inputs - and isinstance(coords_value, str) + if class_flat in ( + "sam2videosegmentationaddpoints", + "sam2segmentation", + 
"openshotsam2videosegmentationaddpoints", + "openshotsam2segmentation", ): + seed_frame_idx = 0 + if isinstance(tracking_selection, dict): + try: + seed_frame_idx = max(0, int(tracking_selection.get("seed_frame", 1)) - 1) + except Exception: + seed_frame_idx = 0 + if "frame_index" in inputs: + try: + inputs["frame_index"] = int(seed_frame_idx) + except Exception: + pass + if "tracking_selection_json" in inputs: + try: + inputs["tracking_selection_json"] = json.dumps(tracking_selection or {}) + except Exception: + inputs["tracking_selection_json"] = "{}" + coords_text = coordinates_positive_text or prompt_text points = _parse_sam2_points(coords_text) - if ("blur-anything-sam2" in template_id) and (not points): - raise ValueError("No SAM2 points were provided. Use Mask > Pick Point(s) on Source.") - inputs["coordinates_positive"] = _normalize_sam2_coords_input(coords_text, coords_value) + has_positive_rects = bool(rectangles_positive_text) + + auto_enabled = bool(inputs.get("auto_mode", False)) or auto_mode + if "auto_mode" in inputs: + inputs["auto_mode"] = bool(auto_enabled) + if ("blur-anything-sam2" in template_id) and (not points) and (not has_positive_rects) and (not auto_enabled): + raise ValueError("No SAM2 seed was provided. Use Points, Rectangle, or Auto mode.") + + # New OpenShot node contract. + if "positive_points_json" in inputs and isinstance(inputs.get("positive_points_json", None), str): + inputs["positive_points_json"] = _normalize_sam2_coords_input( + coords_text, + str(inputs.get("positive_points_json", "")), + ) + # Backward compatibility for third-party node variants. 
+ elif "coordinates_positive" in inputs and isinstance(inputs.get("coordinates_positive", None), str): + inputs["coordinates_positive"] = _normalize_sam2_coords_input( + coords_text, + str(inputs.get("coordinates_positive", "")), + ) + if class_flat in ("sam2segmentation", "openshotsam2segmentation") and "individual_objects" in inputs: # For Blur Anything, treat points as a single combined prompt. # This is more stable with mixed positive/negative points and avoids @@ -429,13 +467,28 @@ def _select_bind_nodes(node_ids, path_keys, preferred_upload=None): else: # Non-Blur-Anything templates keep multi-object behavior. inputs["individual_objects"] = bool(len(points) > 1) + if coordinates_negative_text: - neg_value = inputs.get("coordinates_negative", "") - if isinstance(neg_value, str) or "coordinates_negative" not in inputs: - inputs["coordinates_negative"] = _normalize_sam2_coords_input( + if "negative_points_json" in inputs and isinstance(inputs.get("negative_points_json", None), str): + inputs["negative_points_json"] = _normalize_sam2_coords_input( coordinates_negative_text, - str(neg_value or ""), + str(inputs.get("negative_points_json", "")), ) + elif "coordinates_negative" in inputs: + neg_value = inputs.get("coordinates_negative", "") + if isinstance(neg_value, str) or "coordinates_negative" not in inputs: + inputs["coordinates_negative"] = _normalize_sam2_coords_input( + coordinates_negative_text, + str(neg_value or ""), + ) + if rectangles_positive_text and ("positive_rects_json" in inputs) and isinstance( + inputs.get("positive_rects_json", None), str + ): + inputs["positive_rects_json"] = rectangles_positive_text + if rectangles_negative_text and ("negative_rects_json" in inputs) and isinstance( + inputs.get("negative_rects_json", None), str + ): + inputs["negative_rects_json"] = rectangles_negative_text if not source_path: continue @@ -511,7 +564,12 @@ def _save_nodes_for_workflow(self, workflow): class_type = str(node.get("class_type", 
"")).strip().lower() if not class_type: continue - if class_type.startswith("save") or class_type in ("previewany", "transnetv2_run", "vhs_videocombine"): + if class_type.startswith("save") or class_type in ( + "previewany", + "transnetv2_run", + "openshottransnetscenedetect", + "vhs_videocombine", + ): save_nodes.append(str(node_id)) return save_nodes @@ -747,6 +805,10 @@ def action_generate_trigger(self, checked=True, source_file=None, template_id=No source_path=source_path, coordinates_positive_text=payload.get("coordinates_positive"), coordinates_negative_text=payload.get("coordinates_negative"), + rectangles_positive_text=payload.get("rectangles_positive"), + rectangles_negative_text=payload.get("rectangles_negative"), + auto_mode=payload.get("auto_mode"), + tracking_selection=payload.get("tracking_selection"), ) except Exception as ex: QMessageBox.information(self.win, "Invalid Input", str(ex)) diff --git a/src/comfyui/image-blur-anything-sam2.json b/src/comfyui/image-blur-anything-sam2.json index 39829377b..3258fd639 100644 --- a/src/comfyui/image-blur-anything-sam2.json +++ b/src/comfyui/image-blur-anything-sam2.json @@ -34,7 +34,11 @@ "1", 0 ], - "coordinates_positive": "", + "auto_mode": false, + "positive_points_json": "", + "negative_points_json": "", + "positive_rects_json": "", + "negative_rects_json": "", "keep_model_loaded": false } }, diff --git a/src/comfyui/video-blur-anything-sam2.json b/src/comfyui/video-blur-anything-sam2.json index f4c1e0f69..b0cce2892 100644 --- a/src/comfyui/video-blur-anything-sam2.json +++ b/src/comfyui/video-blur-anything-sam2.json @@ -44,12 +44,17 @@ "3", 0 ], - "coordinates_positive": "", "frame_index": 0, "object_index": 0, "windowed_mode": true, "offload_video_to_cpu": false, - "offload_state_to_cpu": false + "offload_state_to_cpu": false, + "auto_mode": false, + "positive_points_json": "", + "negative_points_json": "", + "positive_rects_json": "", + "negative_rects_json": "", + "tracking_selection_json": "{}" } }, 
"5": { diff --git a/src/themes/cosmic/images/ai-track-point-negative.svg b/src/themes/cosmic/images/ai-track-point-negative.svg new file mode 100644 index 000000000..78befa0f4 --- /dev/null +++ b/src/themes/cosmic/images/ai-track-point-negative.svg @@ -0,0 +1,4 @@ + + + + diff --git a/src/themes/cosmic/images/ai-track-point-positive.svg b/src/themes/cosmic/images/ai-track-point-positive.svg new file mode 100644 index 000000000..8f885486c --- /dev/null +++ b/src/themes/cosmic/images/ai-track-point-positive.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/themes/cosmic/images/ai-track-rect-negative.svg b/src/themes/cosmic/images/ai-track-rect-negative.svg new file mode 100644 index 000000000..2ab0de07d --- /dev/null +++ b/src/themes/cosmic/images/ai-track-rect-negative.svg @@ -0,0 +1,4 @@ + + + + diff --git a/src/themes/cosmic/images/ai-track-rect-positive.svg b/src/themes/cosmic/images/ai-track-rect-positive.svg new file mode 100644 index 000000000..17382fe16 --- /dev/null +++ b/src/themes/cosmic/images/ai-track-rect-positive.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/windows/generate.py b/src/windows/generate.py index 6b5d632c3..2303240a3 100644 --- a/src/windows/generate.py +++ b/src/windows/generate.py @@ -61,6 +61,10 @@ def __init__( self.preselected_template_id = str(preselected_template_id or "").strip() self._coordinates_positive_text = "" self._coordinates_negative_text = "" + self._rectangles_positive_text = "" + self._rectangles_negative_text = "" + self._auto_mode = False + self._tracking_selection_payload = {} self.setObjectName("generateDialog") self.setWindowTitle(str(dialog_title or "AI Tools")) self.setMinimumWidth(620) @@ -95,6 +99,10 @@ def __init__( def _current_coordinates_text(self): coordinates_positive = str(self._coordinates_positive_text or "").strip() coordinates_negative = str(self._coordinates_negative_text or "").strip() + rects_positive = str(self._rectangles_positive_text or "").strip() + rects_negative = 
str(self._rectangles_negative_text or "").strip() + auto_mode = bool(self._auto_mode) + tracking_payload = dict(self._tracking_selection_payload or {}) if not coordinates_positive and hasattr(self, "points_preview"): preview_text = self.points_preview.toPlainText().strip() if preview_text.startswith("{"): @@ -102,22 +110,31 @@ def _current_coordinates_text(self): payload = json.loads(preview_text.replace("'", "\"")) coordinates_positive = str(payload.get("positive", "")).strip() or coordinates_positive coordinates_negative = str(payload.get("negative", "")).strip() or coordinates_negative + rects_positive = str(payload.get("positive_rects", "")).strip() or rects_positive + rects_negative = str(payload.get("negative_rects", "")).strip() or rects_negative + auto_mode = bool(payload.get("auto_mode", auto_mode)) + if isinstance(payload.get("tracking_selection"), dict): + tracking_payload = payload.get("tracking_selection") except Exception: pass prompt_text = self.prompt_edit.toPlainText().strip() # Backward-compatible fallback: if prompt itself contains point JSON, treat it as coordinates. 
if (not coordinates_positive) and prompt_text.startswith("[") and ("\"x\"" in prompt_text or "'x'" in prompt_text): coordinates_positive = prompt_text - return coordinates_positive, coordinates_negative, prompt_text + return coordinates_positive, coordinates_negative, rects_positive, rects_negative, auto_mode, tracking_payload, prompt_text def get_payload(self): - coordinates_positive, coordinates_negative, prompt_text = self._current_coordinates_text() + coordinates_positive, coordinates_negative, rects_positive, rects_negative, auto_mode, tracking_payload, prompt_text = self._current_coordinates_text() return { "name": self.name_edit.text().strip(), "template_id": self.template_combo.currentData() or self.template_combo.currentText(), "prompt": prompt_text, "coordinates_positive": coordinates_positive, "coordinates_negative": coordinates_negative, + "rectangles_positive": rects_positive, + "rectangles_negative": rects_negative, + "auto_mode": bool(auto_mode), + "tracking_selection": tracking_payload, } def _build_top_block(self): @@ -188,15 +205,15 @@ def _build_points_tab(self): layout = QVBoxLayout(tab) layout.setContentsMargins(8, 8, 8, 8) self.mask_hint = QLabel( - "Select one or more tracking points on the source frame." + "Open tracking selection tools to choose object regions across frames." 
) self.mask_hint.setWordWrap(True) layout.addWidget(self.mask_hint) controls = QHBoxLayout() - self.pick_points_button = QPushButton("Pick Point(s) on Source") + self.pick_points_button = QPushButton("Choose object(s) for tracking") self.clear_points_button = QPushButton("Clear") - self.pick_points_button.clicked.connect(self._pick_points_clicked) + self.pick_points_button.clicked.connect(self._choose_tracking_clicked) self.clear_points_button.clicked.connect(self._clear_points_clicked) controls.addWidget(self.pick_points_button) controls.addWidget(self.clear_points_button) @@ -235,12 +252,12 @@ def _on_generate_clicked(self): self.name_edit.setFocus(Qt.TabFocusReason) return if self._is_sam2_point_template(): - coordinates_positive, _coordinates_negative, _prompt_text = self._current_coordinates_text() - if not coordinates_positive: + coordinates_positive, _coordinates_negative, rects_positive, _rects_negative, auto_mode, _tracking_payload, _prompt_text = self._current_coordinates_text() + if (not auto_mode) and (not coordinates_positive) and (not rects_positive): QMessageBox.warning( self, - "Missing Points", - "No SAM2 points were provided. Use the Points tab and click Pick Point(s) on Source.", + "Missing Selection", + "No SAM2 seed was provided. Click 'Choose object(s) for tracking' in the Points tab.", ) self.tabs.setCurrentWidget(self.page_points) return @@ -259,8 +276,9 @@ def _on_template_changed(self, index): self.clear_points_button.setEnabled(is_point_template) if is_point_template: self.mask_hint.setText( - "Select one or more tracking points on the source frame." + "Use tracking tools to choose positive/negative points or rectangles on any frame." 
) + self.pick_points_button.setText("Choose object(s) for tracking") self.tabs.setCurrentWidget(self.page_points) else: self.mask_hint.setText( @@ -268,23 +286,15 @@ def _on_template_changed(self, index): ) self.tabs.setCurrentWidget(self.page_prompt) - def _pick_points_clicked(self): + def _choose_tracking_clicked(self): if not self.source_file: return - win = SelectRegion(file=self.source_file, clip=None, selection_mode="point") + win = SelectRegion(file=self.source_file, clip=None, selection_mode="annotate") if win.exec_() != QDialog.Accepted: return - raw_points_pos = win.selected_points() - raw_points_neg = win.selected_points_negative() - log.info( - "Generate dialog captured raw SAM2 points positive=%s negative=%s", - len(raw_points_pos or []), - len(raw_points_neg or []), - ) - points_pos = [] - points_neg = [] + selection_payload = win.selection_payload() frame_size = win.videoPreview.curr_frame_size if not frame_size: frame_w = float(max(win.viewport_rect.width(), 1)) @@ -292,44 +302,122 @@ def _pick_points_clicked(self): else: frame_w = float(max(frame_size.width(), 1)) frame_h = float(max(frame_size.height(), 1)) - for point in raw_points_pos: - x_norm = max(min(float(point["x"]), float(max(frame_w - 1.0, 0.0))), 0.0) - y_norm = max(min(float(point["y"]), float(max(frame_h - 1.0, 0.0))), 0.0) - x_abs = int(round((x_norm / frame_w) * float(win.width))) - y_abs = int(round((y_norm / frame_h) * float(win.height))) - points_pos.append({"x": x_abs, "y": y_abs}) - for point in raw_points_neg: - x_norm = max(min(float(point["x"]), float(max(frame_w - 1.0, 0.0))), 0.0) - y_norm = max(min(float(point["y"]), float(max(frame_h - 1.0, 0.0))), 0.0) - x_abs = int(round((x_norm / frame_w) * float(win.width))) - y_abs = int(round((y_norm / frame_h) * float(win.height))) - points_neg.append({"x": x_abs, "y": y_abs}) - - if not points_pos: + src_w = float(max(getattr(win, "width", 1), 1)) + src_h = float(max(getattr(win, "height", 1), 1)) + + def _scale_point_dict(p): + 
if not isinstance(p, dict): + return None + try: + x_in = float(p.get("x", 0.0)) + y_in = float(p.get("y", 0.0)) + except Exception: + return None + x_norm = max(min(x_in, float(max(frame_w - 1.0, 0.0))), 0.0) + y_norm = max(min(y_in, float(max(frame_h - 1.0, 0.0))), 0.0) + x_abs = int(round((x_norm / frame_w) * src_w)) + y_abs = int(round((y_norm / frame_h) * src_h)) + return {"x": x_abs, "y": y_abs} + + def _scale_rect_dict(r): + if not isinstance(r, dict): + return None + try: + x1_in = float(r.get("x1", 0.0)) + y1_in = float(r.get("y1", 0.0)) + x2_in = float(r.get("x2", 0.0)) + y2_in = float(r.get("y2", 0.0)) + except Exception: + return None + x1 = max(min(x1_in, float(max(frame_w - 1.0, 0.0))), 0.0) + y1 = max(min(y1_in, float(max(frame_h - 1.0, 0.0))), 0.0) + x2 = max(min(x2_in, float(max(frame_w - 1.0, 0.0))), 0.0) + y2 = max(min(y2_in, float(max(frame_h - 1.0, 0.0))), 0.0) + if x2 < x1: + x1, x2 = x2, x1 + if y2 < y1: + y1, y2 = y2, y1 + sx1 = int(round((x1 / frame_w) * src_w)) + sy1 = int(round((y1 / frame_h) * src_h)) + sx2 = int(round((x2 / frame_w) * src_w)) + sy2 = int(round((y2 / frame_h) * src_h)) + return {"x1": sx1, "y1": sy1, "x2": sx2, "y2": sy2} + + # Normalize all frame annotations to source frame coordinates. 
+ if isinstance(selection_payload, dict) and isinstance(selection_payload.get("frames"), dict): + normalized_frames = {} + for frame_key, frame_data in selection_payload.get("frames", {}).items(): + if not isinstance(frame_data, dict): + continue + pos_pts = [_scale_point_dict(p) for p in (frame_data.get("positive_points") or [])] + neg_pts = [_scale_point_dict(p) for p in (frame_data.get("negative_points") or [])] + pos_rects = [_scale_rect_dict(r) for r in (frame_data.get("positive_rects") or [])] + neg_rects = [_scale_rect_dict(r) for r in (frame_data.get("negative_rects") or [])] + normalized_frames[str(frame_key)] = { + "positive_points": [p for p in pos_pts if p is not None], + "negative_points": [p for p in neg_pts if p is not None], + "positive_rects": [r for r in pos_rects if r is not None], + "negative_rects": [r for r in neg_rects if r is not None], + } + selection_payload["frames"] = normalized_frames + + frames = selection_payload.get("frames", {}) if isinstance(selection_payload, dict) else {} + seed_frame = int(selection_payload.get("seed_frame", 1)) if isinstance(selection_payload, dict) else 1 + seed_data = frames.get(str(seed_frame), {}) if isinstance(frames, dict) else {} + points_pos = list(seed_data.get("positive_points", []) or []) + points_neg = list(seed_data.get("negative_points", []) or []) + rects_pos = list(seed_data.get("positive_rects", []) or []) + rects_neg = list(seed_data.get("negative_rects", []) or []) + + if (not points_pos) and (not rects_pos): QMessageBox.warning( self, - "No Points Found", - "No positive points were captured. 
Use Shift+Click to add positive points.", + "No Selections Found", + "No positive points or rectangles were captured.", ) return points_pos_text = json.dumps(points_pos) points_neg_text = json.dumps(points_neg) if points_neg else "" + rects_pos_text = json.dumps(rects_pos) if rects_pos else "" + rects_neg_text = json.dumps(rects_neg) if rects_neg else "" log.info( - "Generate dialog normalized SAM2 points positive=%s negative=%s", + "Generate dialog captured SAM2 seed frame=%s points_pos=%s points_neg=%s rects_pos=%s rects_neg=%s", + seed_frame, len(points_pos), len(points_neg), + len(rects_pos), + len(rects_neg), ) self._coordinates_positive_text = points_pos_text self._coordinates_negative_text = points_neg_text + self._rectangles_positive_text = rects_pos_text + self._rectangles_negative_text = rects_neg_text + self._auto_mode = False + self._tracking_selection_payload = selection_payload if isinstance(selection_payload, dict) else {} self.points_preview.setPlainText( - json.dumps({"positive": points_pos_text, "negative": points_neg_text}, indent=2) + json.dumps( + { + "seed_frame": seed_frame, + "auto_mode": False, + "positive": points_pos_text, + "negative": points_neg_text, + "positive_rects": rects_pos_text, + "negative_rects": rects_neg_text, + "tracking_selection": self._tracking_selection_payload, + }, + indent=2, + ) ) self.tabs.setCurrentWidget(self.page_points) def _clear_points_clicked(self): self._coordinates_positive_text = "" self._coordinates_negative_text = "" + self._rectangles_positive_text = "" + self._rectangles_negative_text = "" + self._auto_mode = False + self._tracking_selection_payload = {} self.points_preview.clear() def _set_tab_visible(self, index, visible): diff --git a/src/windows/process_effect.py b/src/windows/process_effect.py index 162c67f9a..b79a1603b 100644 --- a/src/windows/process_effect.py +++ b/src/windows/process_effect.py @@ -33,7 +33,7 @@ from PyQt5.QtCore import Qt, pyqtSignal, QCoreApplication from PyQt5.QtGui import 
QPainter -from PyQt5.QtWidgets import QPushButton, QDialog, QLabel, QDoubleSpinBox, QSpinBox, QLineEdit, QCheckBox, QComboBox, QDialogButtonBox, QSizePolicy +from PyQt5.QtWidgets import QPushButton, QDialog, QLabel, QDoubleSpinBox, QSpinBox, QLineEdit, QCheckBox, QComboBox, QDialogButtonBox, QSizePolicy, QMessageBox import openshot # Python module for libopenshot (required video editing module installed separately) from classes import info @@ -279,6 +279,7 @@ def text_value_changed(self, widget, param, value=None): def rect_select_clicked(self, widget, param): """Rect select button clicked""" + _ = get_app()._tr self.context[param["setting"]].update({"button-clicked": True}) # show dialog @@ -289,25 +290,41 @@ def rect_select_clicked(self, widget, param): reader_path = c.data.get('reader', {}).get('path','') f = File.get(path=reader_path) if f: - win = SelectRegion(f, self.clip_instance) + win = SelectRegion(f, self.clip_instance, parent=self) # Run the dialog event loop - blocking interaction on this window during that time result = win.exec_() if result == QDialog.Accepted: # self.first_frame = win.current_frame # Region selected (get coordinates if any) - topLeft = win.videoPreview.regionTopLeftHandle - bottomRight = win.videoPreview.regionBottomRightHandle - viewPortSize = win.viewport_rect - curr_frame_size = win.videoPreview.curr_frame_size - - x1 = topLeft.x() / curr_frame_size.width() - y1 = topLeft.y() / curr_frame_size.height() - x2 = bottomRight.x() / curr_frame_size.width() - y2 = bottomRight.y() / curr_frame_size.height() + selected_rect = win.selected_rect_normalized() if hasattr(win, "selected_rect_normalized") else None + if selected_rect: + x1 = float(selected_rect.get("normalized_x", 0.0)) + y1 = float(selected_rect.get("normalized_y", 0.0)) + xw = float(selected_rect.get("normalized_width", 0.0)) + yh = float(selected_rect.get("normalized_height", 0.0)) + else: + topLeft = win.videoPreview.regionTopLeftHandle + bottomRight = 
win.videoPreview.regionBottomRightHandle + curr_frame_size = win.videoPreview.curr_frame_size + if not topLeft or not bottomRight or not curr_frame_size: + QMessageBox.warning( + self, + _("Invalid Region"), + _("Please draw a rectangle region before clicking Select Region."), + ) + return + x1 = topLeft.x() / curr_frame_size.width() + y1 = topLeft.y() / curr_frame_size.height() + x2 = bottomRight.x() / curr_frame_size.width() + y2 = bottomRight.y() / curr_frame_size.height() + xw = x2 - x1 + yh = y2 - y1 # Get QImage of region - if win.videoPreview.region_qimage: + region_qimage = win.selected_region_qimage() if hasattr(win, "selected_region_qimage") else None + if region_qimage is None and win.videoPreview.region_qimage: region_qimage = win.videoPreview.region_qimage + if region_qimage: # Resize QImage to match button size resized_qimage = region_qimage.scaled(widget.size(), Qt.IgnoreAspectRatio, Qt.SmoothTransformation) @@ -317,13 +334,12 @@ def rect_select_clicked(self, widget, param): widget.setText("") # If data found, add to context - if topLeft and bottomRight: - self.context[param["setting"]].update({"normalized_x": x1, "normalized_y": y1, - "normalized_width": x2-x1, - "normalized_height": y2-y1, - "first-frame": win.current_frame, - }) - log.info(self.context) + self.context[param["setting"]].update({"normalized_x": x1, "normalized_y": y1, + "normalized_width": xw, + "normalized_height": yh, + "first-frame": win.current_frame, + }) + log.info(self.context) else: log.error('No file found with path: %s' % reader_path) diff --git a/src/windows/region.py b/src/windows/region.py index 84d5b26b5..0e0d0bfae 100644 --- a/src/windows/region.py +++ b/src/windows/region.py @@ -31,6 +31,7 @@ import math from PyQt5.QtCore import * +from PyQt5.QtGui import QIcon, QPainter, QColor, QPen, QBrush from PyQt5.QtWidgets import * import openshot # Python module for libopenshot (required video editing module installed separately) @@ -43,6 +44,126 @@ import json + +class 
RegionAnnotatedSlider(QSlider): + frameClicked = pyqtSignal(int) + + def __init__(self, orientation, parent=None): + super().__init__(orientation, parent) + self.setMouseTracking(True) + self._total_frames = 1 + self._current_frame = 1 + self._markers = [] # list of (frame, kind) + self._marker_positions = [] # list of (x, y, frame) + + def set_frames(self, total_frames, current_frame, markers): + self._total_frames = int(max(1, total_frames or 1)) + self._current_frame = int(max(1, current_frame or 1)) + self._markers = sorted(list(markers or []), key=lambda item: int(item[0])) + self.update() + + def _groove_rect(self): + opt = QStyleOptionSlider() + self.initStyleOption(opt) + return self.style().subControlRect(QStyle.CC_Slider, opt, QStyle.SC_SliderGroove, self) + + def _handle_rect(self): + opt = QStyleOptionSlider() + self.initStyleOption(opt) + return self.style().subControlRect(QStyle.CC_Slider, opt, QStyle.SC_SliderHandle, self) + + def _x_for_frame(self, frame): + groove = self._groove_rect() + left = float(groove.left()) + right = float(groove.right()) + span = max(1.0, right - left) + if self._total_frames <= 1: + return int(round(left)) + ratio = float(max(1, min(self._total_frames, int(frame))) - 1) / float(max(1, self._total_frames - 1)) + return int(round(left + ratio * span)) + + def mousePressEvent(self, event): + if event.button() != Qt.LeftButton: + return super().mousePressEvent(event) + handle = self._handle_rect().adjusted(-4, -4, 4, 4) + if handle.contains(event.pos()): + return super().mousePressEvent(event) + if not self._marker_positions: + return super().mousePressEvent(event) + click_x = int(event.pos().x()) + click_y = int(event.pos().y()) + nearest = min( + self._marker_positions, + key=lambda item: (int(item[0]) - click_x) * (int(item[0]) - click_x) + (int(item[1]) - click_y) * (int(item[1]) - click_y), + ) + dx = abs(int(nearest[0]) - click_x) + dy = abs(int(nearest[1]) - click_y) + if dx <= 8 and dy <= 8: + 
self.setValue(int(nearest[2])) + self.frameClicked.emit(int(nearest[2])) + event.accept() + return + return super().mousePressEvent(event) + + def mouseMoveEvent(self, event): + handle = self._handle_rect().adjusted(-4, -4, 4, 4) + if handle.contains(event.pos()): + self.unsetCursor() + return super().mouseMoveEvent(event) + if self._marker_positions: + x = int(event.pos().x()) + y = int(event.pos().y()) + nearest = min( + self._marker_positions, + key=lambda item: (int(item[0]) - x) * (int(item[0]) - x) + (int(item[1]) - y) * (int(item[1]) - y), + ) + if abs(int(nearest[0]) - x) <= 8 and abs(int(nearest[1]) - y) <= 8: + self.setCursor(Qt.PointingHandCursor) + else: + self.unsetCursor() + else: + self.unsetCursor() + return super().mouseMoveEvent(event) + + def leaveEvent(self, event): + self.unsetCursor() + return super().leaveEvent(event) + + def paintEvent(self, event): + super().paintEvent(event) + painter = QPainter(self) + painter.setRenderHint(QPainter.Antialiasing, True) + + groove = self._groove_rect() + # Center markers on the slider groove line. 
+ mid_y = int(groove.center().y()) + + self._marker_positions = [] + for frame, kind in self._markers: + x = self._x_for_frame(frame) + self._marker_positions.append((x, mid_y, int(frame))) + if kind == "both": + painter.setPen(Qt.NoPen) + painter.setBrush(QBrush(QColor("#53A0ED"))) + painter.drawEllipse(x - 4, mid_y - 4, 8, 8) + painter.setBrush(QBrush(QColor("#E05757"))) + painter.drawEllipse(x, mid_y, 8, 8) + elif kind == "negative": + painter.setPen(Qt.NoPen) + painter.setBrush(QBrush(QColor("#E05757"))) + painter.drawEllipse(x - 4, mid_y - 4, 8, 8) + else: + painter.setPen(Qt.NoPen) + painter.setBrush(QBrush(QColor("#53A0ED"))) + painter.drawEllipse(x - 4, mid_y - 4, 8, 8) + + # Current-frame indicator + cx = self._x_for_frame(self._current_frame) + cur_pen = QPen(QColor("#EAF5FF"), 1) + painter.setPen(cur_pen) + painter.drawLine(cx, max(0, groove.top() - 5), cx, min(self.height() - 1, groove.bottom() + 5)) + painter.end() + class SelectRegion(QDialog): """ SelectRegion Dialog """ @@ -60,33 +181,56 @@ class SelectRegion(QDialog): SpeedSignal = pyqtSignal(float) StopSignal = pyqtSignal() - def __init__(self, file=None, clip=None, selection_mode="rect"): + def __init__(self, file=None, clip=None, selection_mode="rect", parent=None): _ = get_app()._tr # Create dialog class - QDialog.__init__(self) + QDialog.__init__(self, parent) # Load UI from designer ui_util.load_ui(self, self.ui_path) # Init UI ui_util.init_ui(self) - self.setWindowFlags( - (self.windowFlags() & ~Qt.Dialog) - | Qt.Window - | Qt.WindowMinMaxButtonsHint - | Qt.WindowMaximizeButtonHint - ) + if parent is None: + self.setWindowFlags( + (self.windowFlags() & ~Qt.Dialog) + | Qt.Window + | Qt.WindowMinMaxButtonsHint + | Qt.WindowMaximizeButtonHint + ) + else: + self.setWindowModality(Qt.WindowModal) self.setSizeGripEnabled(True) # Track metrics track_metric_screen("cutting-screen") self.selection_mode = str(selection_mode or "rect").strip().lower() - if self.selection_mode not in ("rect", "point"): 
+ if self.selection_mode not in ("rect", "point", "annotate"): self.selection_mode = "rect" + if self.selection_mode == "annotate": + # Replace stock UI slider with custom-painted annotation slider + # so marker dots are perfectly aligned to the slider groove. + original_slider = self.sliderVideo + custom_slider = RegionAnnotatedSlider(Qt.Horizontal, original_slider.parent()) + custom_slider.setObjectName(original_slider.objectName()) + custom_slider.setTracking(original_slider.hasTracking()) + if hasattr(self, "horizontalLayout_3"): + self.horizontalLayout_3.replaceWidget(original_slider, custom_slider) + original_slider.hide() + original_slider.deleteLater() + self.sliderVideo = custom_slider + self.sliderVideo.frameClicked.connect(self._on_marker_frame_clicked) self._selected_points = [] self._selected_points_negative = [] + self._selected_payload = {} + self._selected_rect_normalized = None + self._selected_region_qimage = None + self.frame_annotations = {} + self._last_annotation_frame = 1 + self._frame_has_local_keyframe = False + self._frame_edited = False self.start_frame = 1 self.start_image = None @@ -143,6 +287,10 @@ def __init__(self, file=None, clip=None, selection_mode="rect"): self.lblInstructions.setText( _("Click to add tracking point (SHIFT+Click for additional points, CTRL+Click for negative point)") ) + elif self.selection_mode == "annotate": + self.lblInstructions.setText( + _("Choose a tool and mark positive/negative points or rectangles. 
Scrub to edit selections by frame.") + ) else: self.lblInstructions.setText(_("Draw a rectangle to select a region of the video frame.")) @@ -150,7 +298,10 @@ def __init__(self, file=None, clip=None, selection_mode="rect"): self.videoPreview = VideoWidget() self.videoPreview.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding) self.videoPreview.region_selection_mode = self.selection_mode + self.videoPreview.regionAnnotationChanged.connect(self._on_video_annotation_changed) self.verticalLayout.insertWidget(1, self.videoPreview) + if self.selection_mode == "annotate": + self._build_annotation_toolbar() # Set aspect ratio to match source content aspect_ratio = openshot.Fraction(self.width, self.height) @@ -207,7 +358,12 @@ def __init__(self, file=None, clip=None, selection_mode="rect"): # Add buttons self.cancel_button = QPushButton(_('Cancel')) - process_label = _('Select Region') if self.selection_mode == "rect" else _('Select Point(s)') + if self.selection_mode == "rect": + process_label = _('Select Region') + elif self.selection_mode == "annotate": + process_label = _('Apply Selections') + else: + process_label = _('Select Point(s)') self.process_button = QPushButton(process_label) self.buttonBox.addButton(self.process_button, QDialogButtonBox.AcceptRole) self.buttonBox.addButton(self.cancel_button, QDialogButtonBox.RejectRole) @@ -217,6 +373,10 @@ def __init__(self, file=None, clip=None, selection_mode="rect"): self.btnPlay.clicked.connect(self.btnPlay_clicked) self.sliderVideo.valueChanged.connect(self.sliderVideo_valueChanged) self.initialized = True + if self.selection_mode == "annotate": + self._load_frame_annotations(1) + self._update_defined_frames_label() + self._refresh_marker_bar() get_app().window.SelectRegionSignal.emit(self.clip.Id()) @@ -224,9 +384,263 @@ def actionPlay_Triggered(self): # Trigger play button (This action is invoked from the preview thread, so it must exist here) self.btnPlay.click() + def _icon_path(self, name): + for 
theme_name in ("cosmic", "cosmic-dusk"): + path = os.path.join(info.PATH, "themes", theme_name, "images", name) + if os.path.exists(path): + return path + return "" + + def _build_annotation_toolbar(self): + _ = get_app()._tr + self.annotation_toolbar = QHBoxLayout() + self.annotation_toolbar.setContentsMargins(0, 0, 0, 0) + self.annotation_toolbar.setSpacing(6) + + self.annotation_tool_group = QButtonGroup(self) + self.annotation_tool_group.setExclusive(True) + tool_style = ( + "QToolButton {" + " background-color: rgba(20, 25, 35, 0.95);" + " color: #91C3FF;" + " border: 1px solid rgba(145, 195, 255, 0.22);" + " border-radius: 4px;" + " padding: 4px;" + "}" + "QToolButton:hover {" + " border: 1px solid rgba(145, 195, 255, 0.55);" + "}" + "QToolButton:checked {" + " background-color: #1F3952;" + " border: 2px solid #53A0ED;" + "}" + ) + + tool_defs = [ + ("positive_point", _("Positive Point"), "ai-track-point-positive.svg"), + ("negative_point", _("Negative Point"), "ai-track-point-negative.svg"), + ("positive_rect", _("Positive Rectangle"), "ai-track-rect-positive.svg"), + ("negative_rect", _("Negative Rectangle"), "ai-track-rect-negative.svg"), + ] + self.annotation_tool_buttons = {} + for tool_id, tooltip, icon_name in tool_defs: + btn = QToolButton(self) + btn.setCheckable(True) + btn.setToolTip(tooltip) + btn.setIconSize(QSize(18, 18)) + btn.setMinimumSize(QSize(28, 28)) + icon_path = self._icon_path(icon_name) + if icon_path: + btn.setIcon(QIcon(icon_path)) + else: + btn.setText(tooltip) + btn.setStyleSheet(tool_style) + btn.clicked.connect(lambda checked=False, t=tool_id: self._on_annotation_tool_changed(t)) + self.annotation_tool_group.addButton(btn) + self.annotation_tool_buttons[tool_id] = btn + self.annotation_toolbar.addWidget(btn) + + self.btnClearAnnotation = QToolButton(self) + self.btnClearAnnotation.setToolTip(_("Clear All Selections")) + self.btnClearAnnotation.setIconSize(QSize(18, 18)) + self.btnClearAnnotation.setMinimumSize(QSize(28, 28)) + 
trash_icon = self._icon_path("track-delete-enabled.svg") + if trash_icon: + self.btnClearAnnotation.setIcon(QIcon(trash_icon)) + else: + self.btnClearAnnotation.setText(_("Reset")) + self.btnClearAnnotation.setStyleSheet(tool_style) + self.btnClearAnnotation.clicked.connect(self._clear_current_frame_annotations) + self.annotation_toolbar.addSpacing(8) + self.annotation_toolbar.addWidget(self.btnClearAnnotation) + self.annotation_toolbar.addStretch(1) + + self.lblDefinedFrames = QLabel("") + self.annotation_toolbar.addWidget(self.lblDefinedFrames) + self.verticalLayout.insertLayout(1, self.annotation_toolbar) + + # Default tool + default_btn = self.annotation_tool_buttons.get("positive_point") + if default_btn: + default_btn.setChecked(True) + self._on_annotation_tool_changed("positive_point") + + def _on_annotation_tool_changed(self, tool_id): + if hasattr(self, "videoPreview") and self.videoPreview is not None: + self.videoPreview.region_annotation_tool = str(tool_id or "positive_point") + + def _capture_current_annotation(self): + def _points_to_payload(items): + payload = [] + for p in items or []: + try: + payload.append({"x": float(p.x()), "y": float(p.y())}) + except Exception: + continue + return payload + + def _rects_to_payload(items): + payload = [] + for r in items or []: + if not isinstance(r, QRectF): + continue + n = r.normalized() + payload.append({ + "x1": float(n.left()), + "y1": float(n.top()), + "x2": float(n.right()), + "y2": float(n.bottom()), + }) + return payload + + return { + "positive_points": _points_to_payload(self.videoPreview.region_points_positive), + "negative_points": _points_to_payload(self.videoPreview.region_points_negative), + "positive_rects": _rects_to_payload(self.videoPreview.region_rects_positive), + "negative_rects": _rects_to_payload(self.videoPreview.region_rects_negative), + } + + def _has_any_annotation(self, payload): + return bool( + (payload.get("positive_points") or []) + or (payload.get("negative_points") or []) + 
or (payload.get("positive_rects") or []) + or (payload.get("negative_rects") or []) + ) + + def _save_current_frame_annotations(self, force=False): + if self.selection_mode != "annotate": + return + frame = int(max(1, self.current_frame)) + if (not force) and (not self._frame_has_local_keyframe) and (not self._frame_edited): + return + payload = self._capture_current_annotation() + if self._has_any_annotation(payload): + self.frame_annotations[frame] = payload + elif frame in self.frame_annotations: + self.frame_annotations.pop(frame, None) + self._frame_has_local_keyframe = frame in self.frame_annotations + self._frame_edited = False + self._update_defined_frames_label() + + def _load_frame_annotations(self, frame): + if self.selection_mode != "annotate": + return + frame = int(max(1, frame)) + payload = {} + inherited = False + if frame in self.frame_annotations: + payload = dict(self.frame_annotations.get(frame, {})) + else: + prior_frames = [f for f in self.frame_annotations.keys() if int(f) <= frame] + if prior_frames: + nearest = int(sorted(prior_frames)[-1]) + payload = dict(self.frame_annotations.get(nearest, {})) + inherited = True + self.videoPreview.region_points_positive = [ + QPointF(float(p.get("x", 0.0)), float(p.get("y", 0.0))) + for p in (payload.get("positive_points") or []) + if isinstance(p, dict) + ] + self.videoPreview.region_points_negative = [ + QPointF(float(p.get("x", 0.0)), float(p.get("y", 0.0))) + for p in (payload.get("negative_points") or []) + if isinstance(p, dict) + ] + self.videoPreview.region_rects_positive = [ + QRectF( + QPointF(float(r.get("x1", 0.0)), float(r.get("y1", 0.0))), + QPointF(float(r.get("x2", 0.0)), float(r.get("y2", 0.0))), + ).normalized() + for r in (payload.get("positive_rects") or []) + if isinstance(r, dict) + ] + self.videoPreview.region_rects_negative = [ + QRectF( + QPointF(float(r.get("x1", 0.0)), float(r.get("y1", 0.0))), + QPointF(float(r.get("x2", 0.0)), float(r.get("y2", 0.0))), + ).normalized() + 
for r in (payload.get("negative_rects") or []) + if isinstance(r, dict) + ] + self.videoPreview.region_rect_drag_start = None + self.videoPreview.region_rect_drag_current = None + self.videoPreview.region_annotation_inherited = bool(inherited) + self.videoPreview.update() + self._frame_has_local_keyframe = frame in self.frame_annotations + self._frame_edited = False + + def _clear_current_frame_annotations(self): + if self.selection_mode != "annotate": + return + self.frame_annotations = {} + self._frame_has_local_keyframe = False + self._frame_edited = False + self.videoPreview.region_points_positive = [] + self.videoPreview.region_points_negative = [] + self.videoPreview.region_rects_positive = [] + self.videoPreview.region_rects_negative = [] + self.videoPreview.region_rect_drag_start = None + self.videoPreview.region_rect_drag_current = None + self.videoPreview.region_annotation_inherited = False + self.videoPreview.update() + self._update_defined_frames_label() + self._refresh_marker_bar() + + def _on_video_annotation_changed(self): + if self.selection_mode != "annotate": + return + self._frame_edited = True + self._save_current_frame_annotations(force=True) + self._refresh_marker_bar() + + def _update_defined_frames_label(self): + if self.selection_mode != "annotate" or not hasattr(self, "lblDefinedFrames"): + return + frames = sorted(self.frame_annotations.keys()) + if not frames: + self.lblDefinedFrames.setText("") + return + preview = ", ".join(str(f) for f in frames[:10]) + if len(frames) > 10: + preview = "{} ...".format(preview) + self.lblDefinedFrames.setText(get_app()._tr("Frames: {}").format(preview)) + self._refresh_marker_bar() + + def _refresh_marker_bar(self): + if self.selection_mode != "annotate" or not hasattr(self, "sliderVideo"): + return + if not hasattr(self.sliderVideo, "set_frames"): + return + markers = [] + for frame in sorted(self.frame_annotations.keys()): + payload = self.frame_annotations.get(frame, {}) or {} + has_pos = 
bool((payload.get("positive_points") or []) or (payload.get("positive_rects") or [])) + has_neg = bool((payload.get("negative_points") or []) or (payload.get("negative_rects") or [])) + kind = "both" if (has_pos and has_neg) else ("negative" if has_neg else "positive") + markers.append((int(frame), kind)) + self.sliderVideo.set_frames(self.video_length, self.current_frame, markers) + + def _on_marker_frame_clicked(self, frame_number): + frame_number = int(max(1, min(int(frame_number), int(self.video_length)))) + self.sliderVideo.setValue(frame_number) + + def selection_payload(self): + return dict(self._selected_payload or {}) + + def selected_rect_normalized(self): + if isinstance(self._selected_rect_normalized, dict): + return dict(self._selected_rect_normalized) + return None + + def selected_region_qimage(self): + return self._selected_region_qimage + def movePlayhead(self, frame_number): """Update the playhead position""" + if self.selection_mode == "annotate" and int(frame_number) != int(self.current_frame): + self._save_current_frame_annotations() self.current_frame = frame_number # Move slider to correct frame position self.sliderIgnoreSignal = True @@ -242,6 +656,9 @@ def movePlayhead(self, frame_number): # Update label self.lblVideoTime.setText(timestamp) + if self.selection_mode == "annotate": + self._load_frame_annotations(frame_number) + self._refresh_marker_bar() def btnPlay_clicked(self, force=None): log.info("btnPlay_clicked") @@ -266,6 +683,11 @@ def btnPlay_clicked(self, force=None): def sliderVideo_valueChanged(self, new_frame): if self.preview_thread and not self.sliderIgnoreSignal: log.info('sliderVideo_valueChanged: %s' % new_frame) + if self.selection_mode == "annotate": + self._save_current_frame_annotations() + self.current_frame = int(new_frame) + self._load_frame_annotations(new_frame) + self._refresh_marker_bar() # Pause video self.btnPlay_clicked(force="pause") @@ -279,8 +701,8 @@ def accept(self): app = get_app() _ = app._tr - # Check 
if the sliderVideo is not at its minimum value - if self.sliderVideo.value() != self.sliderVideo.minimum(): + # Legacy behavior for rect/point modes: require frame 1 selection. + if self.selection_mode in ("rect", "point") and self.sliderVideo.value() != self.sliderVideo.minimum(): # Show a warning message box to the user QMessageBox.warning(self, _("Invalid Region"), _("Please choose a region at the beginning of the clip")) @@ -292,6 +714,48 @@ def accept(self): if self.selection_mode == "point" and not self.videoPreview.region_points_positive: QMessageBox.warning(self, _("Invalid Selection"), _("Please select at least one point.")) return + if self.selection_mode == "rect": + top_left = getattr(self.videoPreview, "regionTopLeftHandle", None) + bottom_right = getattr(self.videoPreview, "regionBottomRightHandle", None) + if top_left is None or bottom_right is None: + QMessageBox.warning(self, _("Invalid Selection"), _("Please draw a rectangle region.")) + return + curr_frame_size = getattr(self.videoPreview, "curr_frame_size", None) + if curr_frame_size and curr_frame_size.width() > 0 and curr_frame_size.height() > 0: + x1 = float(top_left.x()) / float(curr_frame_size.width()) + y1 = float(top_left.y()) / float(curr_frame_size.height()) + x2 = float(bottom_right.x()) / float(curr_frame_size.width()) + y2 = float(bottom_right.y()) / float(curr_frame_size.height()) + left = min(x1, x2) + top = min(y1, y2) + right = max(x1, x2) + bottom = max(y1, y2) + self._selected_rect_normalized = { + "normalized_x": left, + "normalized_y": top, + "normalized_width": max(0.0, right - left), + "normalized_height": max(0.0, bottom - top), + } + region_qimage = getattr(self.videoPreview, "region_qimage", None) + if region_qimage: + self._selected_region_qimage = region_qimage.copy() + if self.selection_mode == "annotate": + self._save_current_frame_annotations() + if not self.frame_annotations: + QMessageBox.warning(self, _("Invalid Selection"), _("Please select at least one point or 
rectangle.")) + return + sorted_frames = sorted(self.frame_annotations.keys()) + seed_frame = int(sorted_frames[0]) if sorted_frames else int(self.current_frame) + self._selected_payload = { + "version": 1, + "seed_frame": seed_frame, + "frames": { + str(frame): dict(self.frame_annotations.get(frame, {})) + for frame in sorted_frames + }, + } + else: + self._selected_payload = {} # Continue with the rest of the accept method self._selected_points = self.selected_points() diff --git a/src/windows/video_widget.py b/src/windows/video_widget.py index 7e27d8080..c671ea655 100644 --- a/src/windows/video_widget.py +++ b/src/windows/video_widget.py @@ -31,7 +31,7 @@ import uuid from PyQt5.QtCore import ( - Qt, QCoreApplication, QMutex, QTimer, + Qt, QCoreApplication, QMutex, QTimer, pyqtSignal, QPoint, QPointF, QSize, QSizeF, QRect, QRectF, ) from PyQt5.QtGui import ( @@ -50,6 +50,7 @@ class VideoWidget(QWidget, updates.UpdateInterface): """ A QWidget used on the video display widget """ + regionAnnotationChanged = pyqtSignal() def _snap_angle(self, angle_degrees, step_degrees=15.0): """Snap an angle to the nearest increment (degrees).""" @@ -660,7 +661,7 @@ def paintEvent(self, event, *args): painter.setTransform(self.region_transform) cs = self.cs - if self.region_selection_mode == "point": + if self.region_selection_mode in ("point", "annotate"): point_radius = max(2.0, (cs * 0.4) / max(self.zoom, 0.001)) if self.region_points_positive: pos_color = QColor("#53a0ed") @@ -680,6 +681,42 @@ def paintEvent(self, event, *args): painter.setBrush(QBrush(neg_color)) for pt in self.region_points_negative: painter.drawEllipse(pt, point_radius, point_radius) + # Draw positive rectangles + if self.region_rects_positive: + rect_pos_color = QColor("#53a0ed") + rect_pos_color.setAlphaF(self.handle_opacity) + rect_pos_pen = QPen(QBrush(rect_pos_color), 1.5) + rect_pos_pen.setCosmetic(True) + painter.setPen(rect_pos_pen) + painter.setBrush(Qt.NoBrush) + for rect in 
self.region_rects_positive: + if isinstance(rect, QRectF): + painter.drawRect(rect.normalized()) + + # Draw negative rectangles + if self.region_rects_negative: + rect_neg_color = QColor("#e05757") + rect_neg_color.setAlphaF(self.handle_opacity) + rect_neg_pen = QPen(QBrush(rect_neg_color), 1.5) + rect_neg_pen.setCosmetic(True) + painter.setPen(rect_neg_pen) + painter.setBrush(Qt.NoBrush) + for rect in self.region_rects_negative: + if isinstance(rect, QRectF): + painter.drawRect(rect.normalized()) + + # Draw current dragging rectangle preview + if self.region_rect_drag_start is not None and self.region_rect_drag_current is not None: + drag_color = QColor("#53a0ed") + if str(self.region_annotation_tool or "").endswith("negative_rect"): + drag_color = QColor("#e05757") + drag_color.setAlphaF(self.handle_opacity) + drag_pen = QPen(QBrush(drag_color), 1.5, Qt.DashLine) + drag_pen.setCosmetic(True) + painter.setPen(drag_pen) + painter.setBrush(Qt.NoBrush) + painter.drawRect(QRectF(self.region_rect_drag_start, self.region_rect_drag_current).normalized()) + elif self.regionTopLeftHandle and self.regionBottomRightHandle: color = QColor("#53a0ed") color.setAlphaF(self.handle_opacity) @@ -771,6 +808,32 @@ def mousePressEvent(self, event): self.region_points_positive = [point] self.region_points_negative = [] self.update() + elif self.region_enabled and self.region_selection_mode == "annotate" and event.button() == Qt.LeftButton: + self._ensure_region_transform() + point = self.region_transform_inverted.map(event.pos()) + point = self._clamp_region_point(point) + if bool(self.region_annotation_inherited): + # First edit on a carried frame should replace inherited selections. 
+ self.region_points_positive = [] + self.region_points_negative = [] + self.region_rects_positive = [] + self.region_rects_negative = [] + self.region_rect_drag_start = None + self.region_rect_drag_current = None + self.region_annotation_inherited = False + tool = str(self.region_annotation_tool or "positive_point") + if tool == "positive_point": + self.region_points_positive.append(point) + self.update() + self.regionAnnotationChanged.emit() + elif tool == "negative_point": + self.region_points_negative.append(point) + self.update() + self.regionAnnotationChanged.emit() + elif tool in ("positive_rect", "negative_rect"): + self.region_rect_drag_start = QPointF(point) + self.region_rect_drag_current = QPointF(point) + self.update() # Ignore undo/redo history temporarily (to avoid a huge pile of undo/redo history) get_app().updates.ignore_history = True @@ -798,9 +861,28 @@ def mouseReleaseEvent(self, event): self.rotation_drag_value = None self.region_mode = None + if self.region_enabled and self.region_selection_mode == "annotate": + if self.region_rect_drag_start is not None and self.region_rect_drag_current is not None: + rect = QRectF(self.region_rect_drag_start, self.region_rect_drag_current).normalized() + if rect.width() >= 2.0 and rect.height() >= 2.0: + tool = str(self.region_annotation_tool or "positive_rect") + if tool == "negative_rect": + self.region_rects_negative.append(rect) + else: + self.region_rects_positive.append(rect) + self.region_rect_drag_start = None + self.region_rect_drag_current = None + self.update() + self.regionAnnotationChanged.emit() + # Save region image data (as QImage) # This can be used other widgets to display the selected region - if self.region_enabled and self.region_selection_mode != "point": + if ( + self.region_enabled + and self.region_selection_mode not in ("point", "annotate") + and self.regionTopLeftHandle is not None + and self.regionBottomRightHandle is not None + ): # Get region coordinates region_rect = QRectF( 
self.regionTopLeftHandle.x(), @@ -1224,6 +1306,17 @@ def mouseMoveEvent(self, event): self.update() if self.region_enabled: + if self.region_selection_mode == "annotate": + self.setCursor(Qt.CrossCursor) + self._ensure_region_transform() + if self.region_rect_drag_start is not None and self.mouse_pressed: + current = self.region_transform_inverted.map(event.pos()) + self.region_rect_drag_current = self._clamp_region_point(current) + self.update() + self.mouse_position = event.pos() + self.mutex.unlock() + return + if self.region_selection_mode == "point": self.setCursor(Qt.CrossCursor) self.mouse_position = event.pos() @@ -1988,6 +2081,10 @@ def regionTriggered(self, clip_id): self.region_points = [] self.region_points_positive = [] self.region_points_negative = [] + self.region_rects_positive = [] + self.region_rects_negative = [] + self.region_rect_drag_start = None + self.region_rect_drag_current = None self.regionTopLeftHandle = None self.regionBottomRightHandle = None get_app().window.refreshFrameSignal.emit() @@ -2103,9 +2200,15 @@ def __init__(self, watch_project=True, *args): self.region_transform_inverted = None self.region_enabled = False self.region_selection_mode = "rect" + self.region_annotation_tool = "positive_point" self.region_points = [] self.region_points_positive = [] self.region_points_negative = [] + self.region_rects_positive = [] + self.region_rects_negative = [] + self.region_rect_drag_start = None + self.region_rect_drag_current = None + self.region_annotation_inherited = False self.region_mode = None self.regionTopLeftHandle = None self.regionBottomRightHandle = None From 1b6aa176867f23b0ce539e0f225d538ddd071ec8 Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Sat, 21 Feb 2026 14:54:39 -0600 Subject: [PATCH 18/27] - Fixed Track Object output routing: Blur/Highlight now only import final video, Mask only imports mask; blocked fallback paths that were re-adding extra outputs. 
- Fixed filename prefix handling so openshot_mask/openshot_gen suffixes are preserved (with payload appended), enabling reliable output filtering. - Expanded Generate/Highlight UX: simplified Prompt + Tracking tabs, added Qt color pickers, plus new Mask Brightness and Background Brightness controls wired end-to-end. --- src/classes/comfy_templates.py | 3 + src/classes/generation_queue.py | 27 ++- src/classes/generation_service.py | 78 +++++++- src/comfyui/video-blur-anything-sam2.json | 12 +- .../video-highlight-anything-sam2.json | 158 +++++++++++++++ src/comfyui/video-mask-anything-sam2.json | 144 ++++++++++++++ src/windows/generate.py | 181 ++++++++++++++---- src/windows/views/ai_tools_menu.py | 20 +- 8 files changed, 575 insertions(+), 48 deletions(-) create mode 100644 src/comfyui/video-highlight-anything-sam2.json create mode 100644 src/comfyui/video-mask-anything-sam2.json diff --git a/src/classes/comfy_templates.py b/src/classes/comfy_templates.py index 285b18384..adcc0d345 100644 --- a/src/classes/comfy_templates.py +++ b/src/classes/comfy_templates.py @@ -113,6 +113,7 @@ "openshotsam2videosegmentationaddpoints", "openshotsam2videosegmentationchunked", "openshotimageblurmasked", + "openshotimagehighlightmasked", } @@ -245,6 +246,7 @@ def _load_template(self, path, is_user, existing_ids): ) override_category = str(payload.get("menu_category") or payload.get("category") or "").strip().lower() + override_menu_parent = str(payload.get("menu_parent") or "").strip() override_output_type = str(payload.get("output_type") or payload.get("media_output") or "").strip().lower() override_icon = str(payload.get("action_icon") or payload.get("icon") or "").strip() override_open_dialog = payload.get("open_dialog", None) @@ -305,6 +307,7 @@ def _load_template(self, path, is_user, existing_ids): "needs_prompt": needs_prompt, "action_icon": override_icon, "open_dialog": open_dialog, + "menu_parent": override_menu_parent, } def _primary_output_type(self, output_types): diff 
--git a/src/classes/generation_queue.py b/src/classes/generation_queue.py index d9702425e..685ea51c9 100644 --- a/src/classes/generation_queue.py +++ b/src/classes/generation_queue.py @@ -80,7 +80,20 @@ def _history_prompt_meta(history_entry): create_time = int(client_payload.get("create_time", 0) or 0) return client_id, create_time - def _find_related_meta_batch_outputs(self, client, history_entry, save_node_ids): + @staticmethod + def _allow_unfiltered_output_fallback(template_id): + template_id = str(template_id or "").strip().lower() + # Track-object templates intentionally have multiple save nodes + # (mask/debug + final), so we must not relax save-node filtering. + if template_id in ( + "video-blur-anything-sam2", + "video-mask-anything-sam2", + "video-highlight-anything-sam2", + ): + return False + return True + + def _find_related_meta_batch_outputs(self, client, history_entry, save_node_ids, template_id=""): base_client_id, base_create_time = self._history_prompt_meta(history_entry) if not base_client_id: return [] @@ -109,7 +122,7 @@ def _find_related_meta_batch_outputs(self, client, history_entry, save_node_ids) continue outputs = ComfyClient.extract_file_outputs(entry, save_node_ids=save_node_ids) - if not outputs and save_node_ids: + if (not outputs) and save_node_ids and self._allow_unfiltered_output_fallback(template_id): outputs = ComfyClient.extract_file_outputs(entry, save_node_ids=None) if not outputs: continue @@ -155,6 +168,7 @@ def _run_comfy_job(self, job_id, request): client_id = request.get("client_id") or "openshot-qt" timeout_s = int(request.get("timeout_s") or 86400) # default 24 hours safety cap save_node_ids = list(request.get("save_node_ids") or []) + template_id = str(request.get("template_id") or "") cancel_event = request.get("cancel_event") client = ComfyClient(comfy_url) ws_client = None @@ -289,7 +303,12 @@ def _run_comfy_job(self, job_id, request): self.job_finished.emit(job_id, False, False, error_text, []) return if 
self._is_unfinished_meta_batch(history_entry): - image_outputs = self._find_related_meta_batch_outputs(client, history_entry, save_node_ids) + image_outputs = self._find_related_meta_batch_outputs( + client, + history_entry, + save_node_ids, + template_id=template_id, + ) if image_outputs: self.progress_changed.emit(job_id, 100) self._job_prompts.pop(job_id, None) @@ -299,7 +318,7 @@ def _run_comfy_job(self, job_id, request): # Keep polling progress/queue while waiting for follow-up prompt outputs. else: image_outputs = ComfyClient.extract_file_outputs(history_entry, save_node_ids=save_node_ids) - if not image_outputs and save_node_ids: + if (not image_outputs) and save_node_ids and self._allow_unfiltered_output_fallback(template_id): # Fallback for workflows whose output node ids shift or emit non-standard keys. image_outputs = ComfyClient.extract_file_outputs(history_entry, save_node_ids=None) self.progress_changed.emit(job_id, 100) diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index c4b5802b6..5fc2388ab 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -226,6 +226,12 @@ def _prepare_nonlegacy_workflow( rectangles_negative_text="", auto_mode=False, tracking_selection=None, + highlight_color="", + highlight_opacity=0.0, + border_color="", + border_width=0, + mask_brightness=1.0, + background_brightness=1.0, ): workflow = self.template_registry.get_workflow_copy(template.get("id")) if not workflow: @@ -257,6 +263,24 @@ def _resolve_template_local_file(path_text): rectangles_negative_text = str(rectangles_negative_text or "").strip() auto_mode = bool(auto_mode) tracking_selection = tracking_selection if isinstance(tracking_selection, dict) else {} + highlight_color = str(highlight_color or "").strip() + border_color = str(border_color or "").strip() + try: + highlight_opacity = float(highlight_opacity) + except (TypeError, ValueError): + highlight_opacity = 0.0 + try: + border_width = 
int(border_width) + except (TypeError, ValueError): + border_width = 0 + try: + mask_brightness = float(mask_brightness) + except (TypeError, ValueError): + mask_brightness = 1.0 + try: + background_brightness = float(background_brightness) + except (TypeError, ValueError): + background_brightness = 1.0 media_type = str(source_file.data.get("media_type", "")).strip().lower() if source_file else "" applied_prompt = False loadimage_node_ids = [] @@ -389,8 +413,9 @@ def _select_bind_nodes(node_ids, path_keys, preferred_upload=None): if "filename_prefix" in inputs: prefix_value = str(inputs.get("filename_prefix", "")).strip() if "/" in prefix_value: - head = prefix_value.rsplit("/", 1)[0] - inputs["filename_prefix"] = "{}/{}".format(head, payload_name) + head, tail = prefix_value.rsplit("/", 1) + tail = str(tail or "output").strip() + inputs["filename_prefix"] = "{}/{}_{}".format(head, tail, payload_name) else: inputs["filename_prefix"] = payload_name @@ -434,15 +459,18 @@ def _select_bind_nodes(node_ids, path_keys, preferred_upload=None): inputs["tracking_selection_json"] = json.dumps(tracking_selection or {}) except Exception: inputs["tracking_selection_json"] = "{}" + if "dino_prompt" in inputs: + inputs["dino_prompt"] = str(prompt_text or "") - coords_text = coordinates_positive_text or prompt_text + coords_text = coordinates_positive_text points = _parse_sam2_points(coords_text) has_positive_rects = bool(rectangles_positive_text) + has_dino_prompt = bool(str(prompt_text or "").strip()) and ("dino_prompt" in inputs) auto_enabled = bool(inputs.get("auto_mode", False)) or auto_mode if "auto_mode" in inputs: inputs["auto_mode"] = bool(auto_enabled) - if ("blur-anything-sam2" in template_id) and (not points) and (not has_positive_rects) and (not auto_enabled): + if ("anything-sam2" in template_id) and (not points) and (not has_positive_rects) and (not auto_enabled) and (not has_dino_prompt): raise ValueError("No SAM2 seed was provided. 
Use Points, Rectangle, or Auto mode.") # New OpenShot node contract. @@ -490,6 +518,20 @@ def _select_bind_nodes(node_ids, path_keys, preferred_upload=None): ): inputs["negative_rects_json"] = rectangles_negative_text + if class_flat in ("openshotimagehighlightmasked",) or class_type.strip() == "OpenShotImageHighlightMasked": + if "highlight_color" in inputs and highlight_color: + inputs["highlight_color"] = highlight_color + if "highlight_opacity" in inputs: + inputs["highlight_opacity"] = float(max(0.0, min(1.0, highlight_opacity))) + if "border_color" in inputs and border_color: + inputs["border_color"] = border_color + if "border_width" in inputs: + inputs["border_width"] = int(max(0, border_width)) + if "mask_brightness" in inputs: + inputs["mask_brightness"] = float(max(0.0, min(3.0, mask_brightness))) + if "background_brightness" in inputs: + inputs["background_brightness"] = float(max(0.0, min(3.0, background_brightness))) + if not source_path: continue @@ -556,7 +598,8 @@ def _select_bind_nodes(node_ids, path_keys, preferred_upload=None): return workflow - def _save_nodes_for_workflow(self, workflow): + def _save_nodes_for_workflow(self, workflow, template_id=None): + template_id = str(template_id or "").strip().lower() save_nodes = [] for node_id, node in workflow.items(): if not isinstance(node, dict): @@ -570,6 +613,22 @@ def _save_nodes_for_workflow(self, workflow): "openshottransnetscenedetect", "vhs_videocombine", ): + if class_type == "vhs_videocombine": + inputs = node.get("inputs", {}) if isinstance(node.get("inputs", {}), dict) else {} + prefix = str(inputs.get("filename_prefix", "")).strip().lower() + is_mask_output = ("openshot_mask" in prefix) + is_track_template = template_id in ( + "video-blur-anything-sam2", + "video-highlight-anything-sam2", + "video-mask-anything-sam2", + ) + if is_track_template: + if template_id == "video-mask-anything-sam2": + if not is_mask_output: + continue + else: + if is_mask_output: + continue 
save_nodes.append(str(node_id)) return save_nodes @@ -809,6 +868,12 @@ def action_generate_trigger(self, checked=True, source_file=None, template_id=No rectangles_negative_text=payload.get("rectangles_negative"), auto_mode=payload.get("auto_mode"), tracking_selection=payload.get("tracking_selection"), + highlight_color=payload.get("highlight_color"), + highlight_opacity=payload.get("highlight_opacity"), + border_color=payload.get("border_color"), + border_width=payload.get("border_width"), + mask_brightness=payload.get("mask_brightness"), + background_brightness=payload.get("background_brightness"), ) except Exception as ex: QMessageBox.information(self.win, "Invalid Input", str(ex)) @@ -818,7 +883,8 @@ def action_generate_trigger(self, checked=True, source_file=None, template_id=No "workflow": workflow, "client_id": "openshot-qt", "timeout_s": 21600, - "save_node_ids": self._save_nodes_for_workflow(workflow), + "save_node_ids": self._save_nodes_for_workflow(workflow, template_id=payload.get("template_id")), + "template_id": str(payload.get("template_id") or ""), } job_id = self.win.generation_queue.enqueue( payload_name, diff --git a/src/comfyui/video-blur-anything-sam2.json b/src/comfyui/video-blur-anything-sam2.json index b0cce2892..51c129794 100644 --- a/src/comfyui/video-blur-anything-sam2.json +++ b/src/comfyui/video-blur-anything-sam2.json @@ -2,7 +2,8 @@ "action_icon": "ai-action-smooth.svg", "menu_category": "enhance", "menu_order": 60, - "name": "Blur Anything...", + "name": "Blur...", + "menu_parent": "track_object", "open_dialog": true, "output_type": "video", "template_id": "video-blur-anything-sam2", @@ -54,7 +55,12 @@ "negative_points_json": "", "positive_rects_json": "", "negative_rects_json": "", - "tracking_selection_json": "{}" + "tracking_selection_json": "{}", + "dino_prompt": "", + "dino_model_id": "IDEA-Research/grounding-dino-tiny", + "dino_box_threshold": 0.35, + "dino_text_threshold": 0.25, + "dino_device": "auto" } }, "5": { @@ -235,4 
+241,4 @@ } } } -} +} \ No newline at end of file diff --git a/src/comfyui/video-highlight-anything-sam2.json b/src/comfyui/video-highlight-anything-sam2.json new file mode 100644 index 000000000..e6082b9d5 --- /dev/null +++ b/src/comfyui/video-highlight-anything-sam2.json @@ -0,0 +1,158 @@ +{ + "action_icon": "ai-action-smooth.svg", + "menu_category": "enhance", + "menu_parent": "track_object", + "menu_order": 62, + "name": "Highlight...", + "open_dialog": true, + "output_type": "video", + "template_id": "video-highlight-anything-sam2", + "workflow": { + "1": { + "class_type": "VHS_LoadVideo", + "inputs": { + "video": "__openshot_input__", + "force_rate": 0, + "custom_width": 0, + "custom_height": 0, + "frame_load_cap": 0, + "skip_first_frames": 0, + "select_every_nth": 1 + } + }, + "2": { + "class_type": "VHS_VideoInfoLoaded", + "inputs": { + "video_info": [ + "1", + 3 + ] + } + }, + "3": { + "class_type": "OpenShotDownloadAndLoadSAM2Model", + "inputs": { + "model": "sam2.1_hiera_base_plus.safetensors", + "segmentor": "video", + "device": "cuda", + "precision": "fp16" + } + }, + "4": { + "class_type": "OpenShotSam2VideoSegmentationAddPoints", + "inputs": { + "sam2_model": [ + "3", + 0 + ], + "frame_index": 0, + "object_index": 0, + "windowed_mode": true, + "offload_video_to_cpu": false, + "offload_state_to_cpu": false, + "auto_mode": false, + "positive_points_json": "", + "negative_points_json": "", + "positive_rects_json": "", + "negative_rects_json": "", + "tracking_selection_json": "{}", + "dino_prompt": "", + "dino_model_id": "IDEA-Research/grounding-dino-tiny", + "dino_box_threshold": 0.35, + "dino_text_threshold": 0.25, + "dino_device": "auto" + } + }, + "5": { + "class_type": "OpenShotSam2VideoSegmentationChunked", + "inputs": { + "sam2_model": [ + "4", + 0 + ], + "inference_state": [ + "4", + 1 + ], + "image": [ + "10", + 0 + ], + "start_frame": 0, + "chunk_size_frames": 96, + "keep_model_loaded": true, + "meta_batch": [ + "9", + 0 + ] + } + }, + "9": { + 
"class_type": "VHS_BatchManager", + "inputs": { + "frames_per_batch": 96 + } + }, + "10": { + "class_type": "VHS_LoadVideo", + "inputs": { + "video": "__openshot_input__", + "force_rate": 0, + "custom_width": 0, + "custom_height": 0, + "frame_load_cap": 0, + "skip_first_frames": 0, + "select_every_nth": 1, + "meta_batch": [ + "9", + 0 + ] + } + }, + "13": { + "class_type": "OpenShotImageHighlightMasked", + "inputs": { + "image": [ + "10", + 0 + ], + "mask": [ + "5", + 0 + ], + "highlight_color": "#F5D742", + "highlight_opacity": 0.35, + "border_color": "#000000", + "border_width": 0, + "mask_brightness": 1.15, + "background_brightness": 0.75 + } + }, + "15": { + "class_type": "VHS_VideoCombine", + "inputs": { + "images": [ + "13", + 0 + ], + "frame_rate": [ + "2", + 0 + ], + "loop_count": 0, + "filename_prefix": "video/openshot_gen", + "format": "video/h264-mp4", + "pingpong": false, + "save_output": true, + "audio": [ + "1", + 2 + ], + "meta_batch": [ + "9", + 0 + ] + } + } + } +} diff --git a/src/comfyui/video-mask-anything-sam2.json b/src/comfyui/video-mask-anything-sam2.json new file mode 100644 index 000000000..86431e73a --- /dev/null +++ b/src/comfyui/video-mask-anything-sam2.json @@ -0,0 +1,144 @@ +{ + "action_icon": "ai-action-smooth.svg", + "menu_category": "enhance", + "menu_parent": "track_object", + "menu_order": 61, + "name": "Mask...", + "open_dialog": true, + "output_type": "video", + "template_id": "video-mask-anything-sam2", + "workflow": { + "1": { + "class_type": "VHS_LoadVideo", + "inputs": { + "video": "__openshot_input__", + "force_rate": 0, + "custom_width": 0, + "custom_height": 0, + "frame_load_cap": 0, + "skip_first_frames": 0, + "select_every_nth": 1 + } + }, + "2": { + "class_type": "VHS_VideoInfoLoaded", + "inputs": { + "video_info": [ + "1", + 3 + ] + } + }, + "3": { + "class_type": "OpenShotDownloadAndLoadSAM2Model", + "inputs": { + "model": "sam2.1_hiera_base_plus.safetensors", + "segmentor": "video", + "device": "cuda", + 
"precision": "fp16" + } + }, + "4": { + "class_type": "OpenShotSam2VideoSegmentationAddPoints", + "inputs": { + "sam2_model": [ + "3", + 0 + ], + "frame_index": 0, + "object_index": 0, + "windowed_mode": true, + "offload_video_to_cpu": false, + "offload_state_to_cpu": false, + "auto_mode": false, + "positive_points_json": "", + "negative_points_json": "", + "positive_rects_json": "", + "negative_rects_json": "", + "tracking_selection_json": "{}", + "dino_prompt": "", + "dino_model_id": "IDEA-Research/grounding-dino-tiny", + "dino_box_threshold": 0.35, + "dino_text_threshold": 0.25, + "dino_device": "auto" + } + }, + "5": { + "class_type": "OpenShotSam2VideoSegmentationChunked", + "inputs": { + "sam2_model": [ + "4", + 0 + ], + "inference_state": [ + "4", + 1 + ], + "image": [ + "10", + 0 + ], + "start_frame": 0, + "chunk_size_frames": 96, + "keep_model_loaded": true, + "meta_batch": [ + "9", + 0 + ] + } + }, + "6": { + "class_type": "MaskToImage", + "inputs": { + "mask": [ + "5", + 0 + ] + } + }, + "7": { + "class_type": "VHS_VideoCombine", + "inputs": { + "images": [ + "6", + 0 + ], + "frame_rate": [ + "2", + 0 + ], + "loop_count": 0, + "filename_prefix": "video/openshot_mask", + "format": "video/h264-mp4", + "pingpong": false, + "save_output": true, + "meta_batch": [ + "9", + 0 + ] + } + }, + "9": { + "class_type": "VHS_BatchManager", + "inputs": { + "frames_per_batch": 96 + } + }, + "10": { + "class_type": "VHS_LoadVideo", + "inputs": { + "video": "__openshot_input__", + "force_rate": 0, + "custom_width": 0, + "custom_height": 0, + "frame_load_cap": 0, + "skip_first_frames": 0, + "select_every_nth": 1, + "meta_batch": [ + "9", + 0 + ] + } + } + } +} \ No newline at end of file diff --git a/src/windows/generate.py b/src/windows/generate.py index 2303240a3..a314d11c6 100644 --- a/src/windows/generate.py +++ b/src/windows/generate.py @@ -27,18 +27,21 @@ import os import json +import functools from PyQt5.QtCore import Qt -from PyQt5.QtGui import QIcon, QPixmap +from 
PyQt5.QtGui import QIcon, QPixmap, QColor from PyQt5.QtWidgets import ( QDialog, QVBoxLayout, QHBoxLayout, QFormLayout, QLabel, QLineEdit, - QComboBox, QTextEdit, QTabWidget, QWidget, QPushButton, QMessageBox + QComboBox, QTextEdit, QTabWidget, QWidget, QPushButton, QMessageBox, + QDoubleSpinBox, QSpinBox ) from classes import info from classes.logger import log from classes.thumbnail import GetThumbPath from windows.region import SelectRegion +from windows.color_picker import ColorPicker class GenerateMediaDialog(QDialog): @@ -80,8 +83,10 @@ def __init__( self.tabs.setObjectName("generateTabs") self.page_prompt = self._build_prompt_tab() self.page_points = self._build_points_tab() + self.page_highlight = self._build_highlight_tab() self.prompt_tab_index = self.tabs.addTab(self.page_prompt, "Prompt") - self.points_tab_index = self.tabs.addTab(self.page_points, "Points") + self.points_tab_index = self.tabs.addTab(self.page_points, "Tracking") + self.highlight_tab_index = self.tabs.addTab(self.page_highlight, "Highlight") root.addWidget(self.tabs, 1) button_row = QHBoxLayout() @@ -118,13 +123,16 @@ def _current_coordinates_text(self): except Exception: pass prompt_text = self.prompt_edit.toPlainText().strip() - # Backward-compatible fallback: if prompt itself contains point JSON, treat it as coordinates. 
- if (not coordinates_positive) and prompt_text.startswith("[") and ("\"x\"" in prompt_text or "'x'" in prompt_text): - coordinates_positive = prompt_text return coordinates_positive, coordinates_negative, rects_positive, rects_negative, auto_mode, tracking_payload, prompt_text def get_payload(self): coordinates_positive, coordinates_negative, rects_positive, rects_negative, auto_mode, tracking_payload, prompt_text = self._current_coordinates_text() + highlight_color = self.highlight_color.name(QColor.HexArgb) if hasattr(self, "highlight_color") else "" + border_color = self.border_color.name(QColor.HexArgb) if hasattr(self, "border_color") else "" + border_width = int(self.border_width_spin.value()) if hasattr(self, "border_width_spin") else 0 + highlight_opacity = float(self.highlight_opacity_spin.value()) if hasattr(self, "highlight_opacity_spin") else 0.0 + mask_brightness = float(self.mask_brightness_spin.value()) if hasattr(self, "mask_brightness_spin") else 1.0 + background_brightness = float(self.background_brightness_spin.value()) if hasattr(self, "background_brightness_spin") else 1.0 return { "name": self.name_edit.text().strip(), "template_id": self.template_combo.currentData() or self.template_combo.currentText(), @@ -135,6 +143,12 @@ def get_payload(self): "rectangles_negative": rects_negative, "auto_mode": bool(auto_mode), "tracking_selection": tracking_payload, + "highlight_color": highlight_color, + "highlight_opacity": highlight_opacity, + "border_color": border_color, + "border_width": border_width, + "mask_brightness": mask_brightness, + "background_brightness": background_brightness, } def _build_top_block(self): @@ -194,7 +208,7 @@ def _build_prompt_tab(self): layout = QVBoxLayout(tab) layout.setContentsMargins(8, 8, 8, 8) self.prompt_edit = QTextEdit() - self.prompt_edit.setPlaceholderText("Describe what to generate...") + self.prompt_edit.setPlaceholderText("Prompt (optional)") self.prompt_edit.setMinimumHeight(140) 
layout.addWidget(self.prompt_edit) return tab @@ -204,14 +218,8 @@ def _build_points_tab(self): tab.setObjectName("pagePoints") layout = QVBoxLayout(tab) layout.setContentsMargins(8, 8, 8, 8) - self.mask_hint = QLabel( - "Open tracking selection tools to choose object regions across frames." - ) - self.mask_hint.setWordWrap(True) - layout.addWidget(self.mask_hint) - controls = QHBoxLayout() - self.pick_points_button = QPushButton("Choose object(s) for tracking") + self.pick_points_button = QPushButton("Select objects for tracking") self.clear_points_button = QPushButton("Clear") self.pick_points_button.clicked.connect(self._choose_tracking_clicked) self.clear_points_button.clicked.connect(self._clear_points_clicked) @@ -227,6 +235,103 @@ def _build_points_tab(self): layout.addStretch(1) return tab + def _build_highlight_tab(self): + tab = QWidget(self) + tab.setObjectName("pageHighlight") + layout = QFormLayout(tab) + layout.setContentsMargins(8, 8, 8, 8) + layout.setVerticalSpacing(8) + self.highlight_color = QColor("#2EA6FF") + self.highlight_color.setAlphaF(0.70) + self.border_color = QColor("#FFFFFF") + self.border_color.setAlphaF(1.0) + self.highlight_color_button = QPushButton("Choose Color") + self.highlight_color_button.clicked.connect(self._pick_highlight_color) + self.border_color_button = QPushButton("Choose Color") + self.border_color_button.clicked.connect(self._pick_border_color) + self.highlight_opacity_spin = QDoubleSpinBox() + self.highlight_opacity_spin.setRange(0.0, 1.0) + self.highlight_opacity_spin.setSingleStep(0.05) + self.highlight_opacity_spin.setValue(0.28) + self.border_width_spin = QSpinBox() + self.border_width_spin.setRange(0, 64) + self.border_width_spin.setValue(2) + self.mask_brightness_spin = QDoubleSpinBox() + self.mask_brightness_spin.setRange(0.0, 3.0) + self.mask_brightness_spin.setSingleStep(0.05) + self.mask_brightness_spin.setValue(1.15) + self.background_brightness_spin = QDoubleSpinBox() + 
self.background_brightness_spin.setRange(0.0, 3.0) + self.background_brightness_spin.setSingleStep(0.05) + self.background_brightness_spin.setValue(0.75) + layout.addRow("Background Color", self.highlight_color_button) + layout.addRow("Background Opacity", self.highlight_opacity_spin) + layout.addRow("Border Color", self.border_color_button) + layout.addRow("Border Width", self.border_width_spin) + layout.addRow("Mask Brightness", self.mask_brightness_spin) + layout.addRow("Background Brightness", self.background_brightness_spin) + self._update_highlight_color_button() + self._update_border_color_button() + return tab + + @staticmethod + def _best_contrast(bg): + colrgb = bg.getRgbF() + lum = (0.299 * colrgb[0] + 0.587 * colrgb[1] + 0.114 * colrgb[2]) + return QColor(Qt.white) if lum < 0.5 else QColor(Qt.black) + + def _color_callback(self, setter_fn, refresh_fn, color): + if not color or not color.isValid(): + return + setter_fn(color) + refresh_fn() + + def _pick_highlight_color(self): + callback = functools.partial( + self._color_callback, + self._set_highlight_color, + self._update_highlight_color_button, + ) + ColorPicker( + self.highlight_color, + parent=self, + title="Select a Color", + callback=callback, + ) + + def _pick_border_color(self): + callback = functools.partial( + self._color_callback, + self._set_border_color, + self._update_border_color_button, + ) + ColorPicker( + self.border_color, + parent=self, + title="Select a Color", + callback=callback, + ) + + def _set_highlight_color(self, color): + self.highlight_color = QColor(color) + + def _set_border_color(self, color): + self.border_color = QColor(color) + + def _update_highlight_color_button(self): + fg = self._best_contrast(self.highlight_color) + self.highlight_color_button.setStyleSheet( + "QPushButton{background-color:%s;color:%s;}" % + (self.highlight_color.name(QColor.HexArgb), fg.name()) + ) + + def _update_border_color_button(self): + fg = self._best_contrast(self.border_color) + 
self.border_color_button.setStyleSheet( + "QPushButton{background-color:%s;color:%s;}" % + (self.border_color.name(QColor.HexArgb), fg.name()) + ) + def _load_thumbnail(self): path = "" media_type = self.source_file.data.get("media_type") @@ -251,39 +356,46 @@ def _on_generate_clicked(self): if not self.name_edit.text().strip(): self.name_edit.setFocus(Qt.TabFocusReason) return - if self._is_sam2_point_template(): - coordinates_positive, _coordinates_negative, rects_positive, _rects_negative, auto_mode, _tracking_payload, _prompt_text = self._current_coordinates_text() - if (not auto_mode) and (not coordinates_positive) and (not rects_positive): + if self._is_track_object_template(): + coordinates_positive, _coordinates_negative, rects_positive, _rects_negative, auto_mode, _tracking_payload, prompt_text = self._current_coordinates_text() + if (not auto_mode) and (not coordinates_positive) and (not rects_positive) and (not str(prompt_text or "").strip()): QMessageBox.warning( self, "Missing Selection", - "No SAM2 seed was provided. Click 'Choose object(s) for tracking' in the Points tab.", + "No SAM2 seed was provided. 
Add tracking points/rectangles or enter a prompt.", ) self.tabs.setCurrentWidget(self.page_points) return self.accept() - def _is_sam2_point_template(self): + def _is_track_object_template(self): + template_id = str(self.template_combo.currentData() or "").strip().lower() + return template_id in ( + "video-blur-anything-sam2", + "video-mask-anything-sam2", + "video-highlight-anything-sam2", + ) + + def _is_highlight_template(self): template_id = str(self.template_combo.currentData() or "").strip().lower() - return "sam2" in template_id and "blur-anything" in template_id + return template_id == "video-highlight-anything-sam2" def _on_template_changed(self, index): _ = index - is_point_template = self._is_sam2_point_template() - self._set_tab_visible(self.prompt_tab_index, not is_point_template) - self._set_tab_visible(self.points_tab_index, is_point_template) - self.pick_points_button.setEnabled(bool(self.source_file) and is_point_template) - self.clear_points_button.setEnabled(is_point_template) - if is_point_template: - self.mask_hint.setText( - "Use tracking tools to choose positive/negative points or rectangles on any frame." - ) - self.pick_points_button.setText("Choose object(s) for tracking") + is_track_template = self._is_track_object_template() + is_highlight_template = self._is_highlight_template() + self._set_tab_visible(self.prompt_tab_index, is_track_template) + self._set_tab_visible(self.points_tab_index, is_track_template) + self._set_tab_visible(self.highlight_tab_index, is_track_template and is_highlight_template) + self.pick_points_button.setEnabled(bool(self.source_file) and is_track_template) + self.clear_points_button.setEnabled(is_track_template) + if is_track_template: + self.pick_points_button.setText("Select objects for tracking") self.tabs.setCurrentWidget(self.page_points) else: - self.mask_hint.setText( - "Point selection is available for SAM2 Blur Anything templates." 
- ) + self._set_tab_visible(self.prompt_tab_index, True) + self._set_tab_visible(self.points_tab_index, False) + self._set_tab_visible(self.highlight_tab_index, False) self.tabs.setCurrentWidget(self.page_prompt) def _choose_tracking_clicked(self): @@ -434,7 +546,8 @@ def _apply_dialog_theme(self): color: #91C3FF; } QDialog#generateDialog QTabWidget#generateTabs QWidget#pagePrompt, -QDialog#generateDialog QTabWidget#generateTabs QWidget#pagePoints { +QDialog#generateDialog QTabWidget#generateTabs QWidget#pagePoints, +QDialog#generateDialog QTabWidget#generateTabs QWidget#pageHighlight { background-color: #141923; border: none; } diff --git a/src/windows/views/ai_tools_menu.py b/src/windows/views/ai_tools_menu.py index 1cbb17409..e70a0d8d4 100644 --- a/src/windows/views/ai_tools_menu.py +++ b/src/windows/views/ai_tools_menu.py @@ -48,6 +48,11 @@ def add_ai_tools_menu(win, parent_menu, source_file=None): ai_menu = StyledContextMenu(title=title, parent=parent_menu) ai_menu.setIcon(_icon("tool-generate-sparkle.svg")) + parent_labels = { + "track_object": _("Track an Object"), + } + parent_menus = {} + inserted_style_separator = False for template in templates: template_key = str(template.get("template_id") or template.get("id") or "") @@ -58,10 +63,23 @@ def add_ai_tools_menu(win, parent_menu, source_file=None): ): ai_menu.addSeparator() inserted_style_separator = True + open_dialog = template.get("open_dialog") if not isinstance(open_dialog, bool): open_dialog = (source_file is None) or bool(template.get("needs_prompt", False)) - action = ai_menu.addAction(_(str(template.get("display_name", "")))) + + template_parent = str(template.get("menu_parent") or "").strip().lower() + target_menu = ai_menu + if key == "enhance" and template_parent: + if template_parent not in parent_menus: + submenu_title = parent_labels.get(template_parent, template_parent.replace("_", " ").title()) + submenu = StyledContextMenu(title=submenu_title, parent=ai_menu) + 
submenu.setIcon(_icon("tool-generate-sparkle.svg")) + ai_menu.addMenu(submenu) + parent_menus[template_parent] = submenu + target_menu = parent_menus[template_parent] + + action = target_menu.addAction(_(str(template.get("display_name", "")))) action.setIcon(_icon(win.generation_service.icon_for_template(template))) action.triggered.connect( partial( From 5da1f6a87cb39f7341e2f9c4bf65274080692d39 Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Sat, 21 Feb 2026 15:21:35 -0600 Subject: [PATCH 19/27] Added a loop icon on video preview dialog, and SPACE now controls play/pause, and the play button resets correctly when end of video is reached. --- .../cosmic/images/tool-media-repeat.svg | 6 ++ src/windows/cutting.py | 63 ++++++++++++++++++- src/windows/preview_thread.py | 25 +++++--- src/windows/region.py | 1 + 4 files changed, 85 insertions(+), 10 deletions(-) create mode 100644 src/themes/cosmic/images/tool-media-repeat.svg diff --git a/src/themes/cosmic/images/tool-media-repeat.svg b/src/themes/cosmic/images/tool-media-repeat.svg new file mode 100644 index 000000000..637a484c1 --- /dev/null +++ b/src/themes/cosmic/images/tool-media-repeat.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/src/windows/cutting.py b/src/windows/cutting.py index d6009721b..e9b24001e 100644 --- a/src/windows/cutting.py +++ b/src/windows/cutting.py @@ -30,7 +30,8 @@ import json from PyQt5.QtCore import pyqtSignal, QTimer -from PyQt5.QtWidgets import QDialog, QMessageBox, QSizePolicy, QSlider +from PyQt5.QtGui import QIcon +from PyQt5.QtWidgets import QDialog, QMessageBox, QSizePolicy, QSlider, QToolButton, QLineEdit from PyQt5.QtCore import Qt, QEvent import openshot # Python module for libopenshot (required video editing module installed separately) @@ -66,6 +67,7 @@ def __init__(self, file=None, preview=False): self._preview_autoplay_attempts = 0 self._shutdown_in_progress = False self._close_after_shutdown = False + self.loop_playback = bool(preview) # Create dialog class 
QDialog.__init__(self) @@ -190,6 +192,8 @@ def __init__(self, file=None, preview=False): self.sliderVideo.setMaximum(self.video_length) self.sliderVideo.setSingleStep(1) self.sliderVideo.setPageStep(24) + if self.is_preview_mode: + self._build_preview_repeat_button() # Initialize first frame display. # For cutting mode, preserve the legacy two-step seek refresh. @@ -221,6 +225,42 @@ def __init__(self, file=None, preview=False): self.slider_timer.timeout.connect(self.sliderVideo_timeout) self.initialized = True + def _build_preview_repeat_button(self): + _ = get_app()._tr + self.btnRepeat = QToolButton(self) + self.btnRepeat.setObjectName("btnRepeat") + self.btnRepeat.setCheckable(True) + self.btnRepeat.setChecked(True) + self.btnRepeat.setAutoRaise(True) + self.btnRepeat.setFixedSize(24, 24) + self.btnRepeat.setToolTip(_("Repeat")) + self.btnRepeat.setStyleSheet( + "QToolButton#btnRepeat { border-radius: 4px; }" + "QToolButton#btnRepeat:checked { background-color: rgba(83,160,237,80); }" + ) + self.btnRepeat.toggled.connect(self._on_repeat_toggled) + self.horizontalLayout_3.insertWidget(2, self.btnRepeat) + + icon = ui_util.get_icon("media-playlist-repeat") + if not icon: + icon_path = os.path.join(info.PATH, "themes", "cosmic", "images", "tool-media-repeat.svg") + icon = QIcon(icon_path) + self.btnRepeat.setIcon(icon) + + def _on_repeat_toggled(self, checked): + self.loop_playback = bool(checked) + + def keyPressEvent(self, event): + if event and event.key() == Qt.Key_Space: + focused = self.focusWidget() + if focused and isinstance(focused, QLineEdit): + return super(Cutting, self).keyPressEvent(event) + if hasattr(self, "btnPlay") and self.btnPlay is not None: + self.btnPlay.click() + event.accept() + return + return super(Cutting, self).keyPressEvent(event) + def eventFilter(self, obj, event): if event.type() == event.KeyPress and obj is self.txtName: # Handle ENTER key to create new clip @@ -294,6 +334,14 @@ def btnPlay_clicked(self, force=None): if 
self.btnPlay.isChecked(): log.info('play (icon to pause)') ui_util.setup_icon(self, self.btnPlay, "actionPlay", "media-playback-pause") + # In non-loop mode, replay from the beginning when currently at end. + if not self.loop_playback: + try: + current_pos = int(self.preview_thread.player.Position()) + except Exception: + current_pos = 1 + if current_pos >= int(self.video_length): + self.SeekSignal.emit(1) self.PlaySignal.emit() else: log.info('pause (icon to play)') @@ -339,9 +387,20 @@ def _preview_ready(self): QTimer.singleShot(0, self._start_preview_autoplay) def _preview_mode_changed(self, mode): + play_mode = getattr(openshot, "PLAYBACK_PLAY", None) + paused_mode = getattr(openshot, "PLAYBACK_PAUSED", getattr(openshot, "PLAYBACK_PAUSE", None)) + stop_mode = getattr(openshot, "PLAYBACK_STOPPED", getattr(openshot, "PLAYBACK_STOP", None)) + + # Keep the play button state visually in sync with current playback mode. + if mode == play_mode and not self.btnPlay.isChecked(): + self.btnPlay.setChecked(True) + ui_util.setup_icon(self, self.btnPlay, "actionPlay", "media-playback-pause") + elif mode in (paused_mode, stop_mode) and self.btnPlay.isChecked(): + self.btnPlay.setChecked(False) + ui_util.setup_icon(self, self.btnPlay, "actionPlay", "media-playback-start") + if not self.is_preview_mode or not self._preview_autoplay_active: return - paused_mode = getattr(openshot, "PLAYBACK_PAUSED", getattr(openshot, "PLAYBACK_PAUSE", None)) if paused_mode is not None and mode == paused_mode: QTimer.singleShot(0, self._start_preview_autoplay) diff --git a/src/windows/preview_thread.py b/src/windows/preview_thread.py index ef668eccd..024286011 100644 --- a/src/windows/preview_thread.py +++ b/src/windows/preview_thread.py @@ -62,16 +62,25 @@ def onPositionChanged(self, current_frame): # Check if we are at the end of the timeline if self.worker.player.Mode() == openshot.PLAYBACK_PLAY: + loop_preview = bool(getattr(self.parent, "loop_playback", False)) if self.worker.player.Speed() 
> 0.0 and current_frame >= self.timeline_max_length: - # Yes, pause the video - self.parent.PauseSignal.emit() - # If the player got past the end of the project, go back. - self.worker.Seek(self.timeline_max_length) + if loop_preview: + # Loop preview playback back to the beginning. + self.worker.Seek(1) + else: + # Yes, pause the video + self.parent.PauseSignal.emit() + # If the player got past the end of the project, go back. + self.worker.Seek(self.timeline_max_length) if self.worker.player.Speed() < 0.0 and current_frame <= 1: - # If rewinding, and the player got past the first frame, - # pause and go to frame 1 - self.parent.PauseSignal.emit() - self.worker.Seek(1) + if loop_preview: + # Loop rewind to the end frame. + self.worker.Seek(self.timeline_max_length) + else: + # If rewinding, and the player got past the first frame, + # pause and go to frame 1 + self.parent.PauseSignal.emit() + self.worker.Seek(1) # Signal when the playback mode changes in the preview player (i.e PLAY, PAUSE, STOP) def onModeChanged(self, current_mode): diff --git a/src/windows/region.py b/src/windows/region.py index 0e0d0bfae..41949f14f 100644 --- a/src/windows/region.py +++ b/src/windows/region.py @@ -227,6 +227,7 @@ def __init__(self, file=None, clip=None, selection_mode="rect", parent=None): self._selected_payload = {} self._selected_rect_normalized = None self._selected_region_qimage = None + self.loop_playback = False self.frame_annotations = {} self._last_annotation_frame = 1 self._frame_has_local_keyframe = False From 58afbf2771217512013104a71416c88eddf3ae54 Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Sat, 21 Feb 2026 17:25:36 -0600 Subject: [PATCH 20/27] - Prevented cross-job state contamination with a per-run cache key for windowed SAM2 state. - Improved chunk-boundary stability using richer carries (point+bbox) plus 4-frame boundary replay. - Added detailed DINO/SAM2 debug logging (currently always on). 
--- src/comfyui/video-blur-anything-sam2.json | 8 +++++++- src/comfyui/video-highlight-anything-sam2.json | 6 ++++++ src/comfyui/video-mask-anything-sam2.json | 8 +++++++- src/windows/generate.py | 7 ++++--- 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/comfyui/video-blur-anything-sam2.json b/src/comfyui/video-blur-anything-sam2.json index 51c129794..5c2674b60 100644 --- a/src/comfyui/video-blur-anything-sam2.json +++ b/src/comfyui/video-blur-anything-sam2.json @@ -47,6 +47,7 @@ ], "frame_index": 0, "object_index": 0, + "video_path": "__openshot_input__", "windowed_mode": true, "offload_video_to_cpu": false, "offload_state_to_cpu": false, @@ -61,6 +62,11 @@ "dino_box_threshold": 0.35, "dino_text_threshold": 0.25, "dino_device": "auto" + , + "meta_batch": [ + "9", + 0 + ] } }, "5": { @@ -241,4 +247,4 @@ } } } -} \ No newline at end of file +} diff --git a/src/comfyui/video-highlight-anything-sam2.json b/src/comfyui/video-highlight-anything-sam2.json index e6082b9d5..596f67c9d 100644 --- a/src/comfyui/video-highlight-anything-sam2.json +++ b/src/comfyui/video-highlight-anything-sam2.json @@ -47,6 +47,7 @@ ], "frame_index": 0, "object_index": 0, + "video_path": "__openshot_input__", "windowed_mode": true, "offload_video_to_cpu": false, "offload_state_to_cpu": false, @@ -61,6 +62,11 @@ "dino_box_threshold": 0.35, "dino_text_threshold": 0.25, "dino_device": "auto" + , + "meta_batch": [ + "9", + 0 + ] } }, "5": { diff --git a/src/comfyui/video-mask-anything-sam2.json b/src/comfyui/video-mask-anything-sam2.json index 86431e73a..60df8e421 100644 --- a/src/comfyui/video-mask-anything-sam2.json +++ b/src/comfyui/video-mask-anything-sam2.json @@ -47,6 +47,7 @@ ], "frame_index": 0, "object_index": 0, + "video_path": "__openshot_input__", "windowed_mode": true, "offload_video_to_cpu": false, "offload_state_to_cpu": false, @@ -61,6 +62,11 @@ "dino_box_threshold": 0.35, "dino_text_threshold": 0.25, "dino_device": "auto" + , + "meta_batch": [ + "9", + 0 + ] } 
}, "5": { @@ -141,4 +147,4 @@ } } } -} \ No newline at end of file +} diff --git a/src/windows/generate.py b/src/windows/generate.py index a314d11c6..56a1d95d7 100644 --- a/src/windows/generate.py +++ b/src/windows/generate.py @@ -481,15 +481,16 @@ def _scale_rect_dict(r): rects_pos = list(seed_data.get("positive_rects", []) or []) rects_neg = list(seed_data.get("negative_rects", []) or []) - if (not points_pos) and (not rects_pos): + has_any_selection = bool(points_pos or points_neg or rects_pos or rects_neg) + if not has_any_selection: QMessageBox.warning( self, "No Selections Found", - "No positive points or rectangles were captured.", + "No points or rectangles were captured.", ) return - points_pos_text = json.dumps(points_pos) + points_pos_text = json.dumps(points_pos) if points_pos else "" points_neg_text = json.dumps(points_neg) if points_neg else "" rects_pos_text = json.dumps(rects_pos) if rects_pos else "" rects_neg_text = json.dumps(rects_neg) if rects_neg else "" From c41c2223e430cf886ab532031af2d7a2c96a8d1f Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Sun, 22 Feb 2026 14:34:02 -0600 Subject: [PATCH 21/27] - Added Music... ACE 1.5 template + music icon. - Renamed Audio... to Sound.... - Fixed music runs: random seeds, no stale output reuse, and prompt-to-tags/lyrics handling. 
--- src/classes/generation_queue.py | 1 + src/classes/generation_service.py | 40 ++++++ src/comfyui/txt2audio-stable-open.json | 2 +- src/comfyui/txt2music-ace-step.json | 118 ++++++++++++++++++ .../cosmic/images/ai-action-create-music.svg | 9 ++ 5 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 src/comfyui/txt2music-ace-step.json create mode 100644 src/themes/cosmic/images/ai-action-create-music.svg diff --git a/src/classes/generation_queue.py b/src/classes/generation_queue.py index 685ea51c9..04a985988 100644 --- a/src/classes/generation_queue.py +++ b/src/classes/generation_queue.py @@ -89,6 +89,7 @@ def _allow_unfiltered_output_fallback(template_id): "video-blur-anything-sam2", "video-mask-anything-sam2", "video-highlight-anything-sam2", + "txt2music-ace-step", ): return False return True diff --git a/src/classes/generation_service.py b/src/classes/generation_service.py index 5fc2388ab..17892aa9c 100644 --- a/src/classes/generation_service.py +++ b/src/classes/generation_service.py @@ -29,6 +29,7 @@ import re import tempfile import json +import random from time import time from urllib.parse import unquote from fractions import Fraction @@ -257,6 +258,19 @@ def _resolve_template_local_file(path_text): template_id = str((template or {}).get("id") or "").strip().lower() prompt_text = str(prompt_text or "").strip() + music_prompt_text = prompt_text + music_lyrics_text = "" + if template_id == "txt2music-ace-step" and prompt_text: + # Optional inline format: + #