From 34446abc8099dba3b043645c409c297b7e830652 Mon Sep 17 00:00:00 2001 From: pghosh Date: Wed, 23 Jul 2025 10:24:03 +0200 Subject: [PATCH 1/2] Overriding the newdefault of 'sdpa' --- generate_multitalk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generate_multitalk.py b/generate_multitalk.py index 9a3b525..10ed931 100644 --- a/generate_multitalk.py +++ b/generate_multitalk.py @@ -252,9 +252,9 @@ def _parse_args(): return args def custom_init(device, wav2vec): - audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device) + audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True, attn_implementation="eager").to(device) audio_encoder.feature_extractor._freeze_parameters() - wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True) + wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True, attn_implementation="eager") return wav2vec_feature_extractor, audio_encoder def loudness_norm(audio_array, sr=16000, lufs=-23): From 61ed0074205bf4fed71799bf918c43d519f843c0 Mon Sep 17 00:00:00 2001 From: Partha Ghosh Date: Mon, 19 Jan 2026 10:13:54 +0100 Subject: [PATCH 2/2] Building CLI interface so it can be triggered from hubspot --- cli.py | 267 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ config.py | 6 ++ 2 files changed, 273 insertions(+) create mode 100644 cli.py create mode 100644 config.py diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..899bf26 --- /dev/null +++ b/cli.py @@ -0,0 +1,267 @@ +""" +@file cli.py +@brief Backend-facing CLI wrapper for the MultiTalk generator. +""" + +import argparse +import json +import os +import shutil +import subprocess +import sys +import uuid +from typing import Any, Dict, Tuple + +import config + + +def _run_command_streaming(command: list[str], cwd: str) -> None: + """ + @brief Run a subprocess while streaming stdout/stderr to the current process. + @param command Command list to execute. + @param cwd Working directory for the subprocess. + @throws RuntimeError when the command exits non-zero. + """ + + proc = subprocess.Popen( + command, + cwd=cwd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + + tail: list[str] = [] + assert proc.stdout is not None + for line in proc.stdout: + sys.stdout.write(line) + sys.stdout.flush() + tail.append(line) + if len(tail) > 200: + tail.pop(0) + + rc = proc.wait() + if rc != 0: + tail_text = "".join(tail).strip() + raise RuntimeError( + f"multitalk generation failed with exit code {rc}. Last output:\n{tail_text}" + ) + + +def _resolve_path(base_dir: str, path_value: str) -> str: + """ + @brief Resolve a path relative to the multitalk repo if needed. + @param base_dir Base directory for relative resolution. + @param path_value Path to resolve. + @return Absolute path for the input value. + """ + + if os.path.isabs(path_value): + return path_value + return os.path.abspath(os.path.join(base_dir, path_value)) + + +def _ensure_kokoro_weights(repo_dir: str) -> None: + """ + @brief Ensure the Kokoro weights are reachable via `weights/Kokoro-82M` from repo_dir. + @details MultiTalk's `generate_multitalk.py` uses a relative repo_id (`weights/Kokoro-82M`), + so we provide a symlink to the absolute directory configured in `config.KOKORO_DIR`. + @param repo_dir MultiTalk repo directory. + """ + + kokoro_dir = getattr(config, "KOKORO_DIR", "") + if not kokoro_dir: + return + + weights_dir = os.path.join(repo_dir, "weights") + os.makedirs(weights_dir, exist_ok=True) + link_path = os.path.join(weights_dir, "Kokoro-82M") + + if os.path.exists(link_path): + return + + try: + os.symlink(kokoro_dir, link_path) + except Exception: + # If symlinks are not permitted, fall back to doing nothing; the generator will error clearly. + pass + + +def _select_avatar_assets(avatar_dir: str) -> Tuple[str, str]: + """ + @brief Select the avatar JSON and image file from the avatar directory. + @param avatar_dir Directory containing avatar assets. + @return Tuple of (json_path, image_path). + @throws RuntimeError if required files are missing. + """ + + if not os.path.isdir(avatar_dir): + raise RuntimeError(f"Avatar directory not found: {avatar_dir}") + + entries = sorted(os.listdir(avatar_dir)) + json_path = "" + image_path = "" + for name in entries: + path = os.path.join(avatar_dir, name) + if os.path.isdir(path): + continue + lower = name.lower() + if lower.endswith(".json") and not json_path: + json_path = path + elif lower.endswith((".png", ".jpg", ".jpeg", ".webp")) and not image_path: + image_path = path + + if not json_path or not image_path: + raise RuntimeError( + f"Avatar directory must contain one json and one image file: {avatar_dir}" + ) + + return json_path, image_path + + +def _build_input_payload( + data: Dict[str, Any], base_dir: str, avatar_json: str, avatar_image: str +) -> Dict[str, Any]: + """ + @brief Build the input payload expected by the multitalk generator. + @param data Raw job data from the backend JSON file. + @param base_dir Base directory for resolving paths. + @param avatar_json Path to the base JSON file. + @param avatar_image Path to the avatar image file. + @return Payload dictionary ready for multitalk input_json. + @throws RuntimeError when required fields are missing. + """ + + payload = _load_json(avatar_json) + payload["cond_image"] = _resolve_path(base_dir, avatar_image) + if "cond_audio" not in payload: + payload["cond_audio"] = {} + + tts_audio: Dict[str, Any] = data.get("tts_audio", {}) + speech_text = data.get("speech_text") + if not speech_text: + raise RuntimeError("speech_text is required for multitalk TTS mode") + + tts_audio["text"] = speech_text + + if tts_audio: + if "text" not in tts_audio: + raise RuntimeError("tts_audio provided but missing 'text'") + if "human1_voice" not in tts_audio: + tts_audio["human1_voice"] = config.TTS_VOICE + if "human2_voice" in tts_audio and tts_audio["human2_voice"]: + tts_audio["human2_voice"] = _resolve_path( + base_dir, tts_audio["human2_voice"] + ) + tts_audio["human1_voice"] = _resolve_path(base_dir, tts_audio["human1_voice"]) + payload["tts_audio"] = tts_audio + + return payload + + +def _load_json(path: str) -> Dict[str, Any]: + """ + @brief Load JSON content from disk. + @param path File path to read. + @return Parsed JSON dictionary. + @throws RuntimeError if JSON cannot be read. + """ + + try: + with open(path, "r", encoding="utf-8") as handle: + return json.load(handle) + except Exception as exc: + raise RuntimeError(f"Failed to read JSON from {path}: {exc}") from exc + + +def _write_json(path: str, payload: Dict[str, Any]) -> None: + """ + @brief Write JSON content to disk. + @param path File path to write. + @param payload JSON-serializable dictionary. + """ + + with open(path, "w", encoding="utf-8") as handle: + json.dump(payload, handle) + + +def main() -> None: + """ + @brief CLI entrypoint for backend-triggered multitalk generation. + @throws RuntimeError on invalid input or generation failure. + """ + + parser = argparse.ArgumentParser( + description="Backend wrapper for MultiTalk generation", + ) + parser.add_argument("--job-id", required=True) + parser.add_argument("--output", required=True) + parser.add_argument("--data", required=True) + args = parser.parse_args() + + repo_dir = os.path.dirname(os.path.abspath(__file__)) + work_dir = os.path.join(repo_dir, "backend_runs", args.job_id) + os.makedirs(work_dir, exist_ok=True) + + input_json_path = os.path.join(work_dir, f"{uuid.uuid4().hex}.json") + audio_save_dir = os.path.join(work_dir, "audio") + + data = _load_json(args.data) + avatar_json, avatar_image = _select_avatar_assets(config.AVATAR_DIR) + payload = _build_input_payload( + data=data, + base_dir=repo_dir, + avatar_json=avatar_json, + avatar_image=avatar_image, + ) + _write_json(input_json_path, payload) + + ckpt_dir = config.CKPT_DIR + wav2vec_dir = config.WAV2VEC_DIR + if not ckpt_dir or not wav2vec_dir: + raise RuntimeError("Missing CKPT_DIR or WAV2VEC_DIR in config.py") + + ckpt_dir = _resolve_path(repo_dir, ckpt_dir) + wav2vec_dir = _resolve_path(repo_dir, wav2vec_dir) + + _ensure_kokoro_weights(repo_dir) + + command = [ + sys.executable, + os.path.join(repo_dir, "generate_multitalk.py"), + "--ckpt_dir", + ckpt_dir, + "--wav2vec_dir", + wav2vec_dir, + "--input_json", + input_json_path, + "--sample_steps", + str(data.get("sample_steps", getattr(config, "SAMPLE_STEPS", 40))), + "--mode", + str(data.get("mode", "streaming")), + "--num_persistent_param_in_dit", + str(data.get("num_persistent_param_in_dit", 0)), + "--audio_mode", + "tts", + "--audio_save_dir", + audio_save_dir, + "--save_file", + os.path.splitext(args.output)[0], + ] + + if data.get("use_teacache", True): + command.append("--use_teacache") + + try: + _run_command_streaming(command, cwd=repo_dir) + finally: + try: + shutil.rmtree(work_dir) + except Exception: + pass + + +if __name__ == "__main__": + main() + diff --git a/config.py b/config.py new file mode 100644 index 0000000..a96405c --- /dev/null +++ b/config.py @@ -0,0 +1,6 @@ +AVATAR_DIR = '/home/web/partha/dev/MultiTalk/Input_outputs/input_files/sales_executive' +CKPT_DIR = '/home/web/partha/dev/MultiTalk/weights/Wan2.1-I2V-14B-480P' +WAV2VEC_DIR = '/home/web/partha/dev/MultiTalk/weights/chinese-wav2vec2-base' +KOKORO_DIR = '/home/web/partha/dev/MultiTalk/weights/Kokoro-82M' +TTS_VOICE = '/home/web/partha/dev/MultiTalk/weights/Kokoro-82M/voices/af_heart.pt' +SAMPLE_STEPS = 1 \ No newline at end of file