diff --git a/app/src/components/Generation/FloatingGenerateBox.tsx b/app/src/components/Generation/FloatingGenerateBox.tsx index a8d556a..b343663 100644 --- a/app/src/components/Generation/FloatingGenerateBox.tsx +++ b/app/src/components/Generation/FloatingGenerateBox.tsx @@ -13,7 +13,7 @@ import { } from '@/components/ui/select'; import { Textarea } from '@/components/ui/textarea'; import { useToast } from '@/components/ui/use-toast'; -import { LANGUAGE_OPTIONS } from '@/lib/constants/languages'; +import { LANGUAGE_OPTIONS, type LanguageCode } from '@/lib/constants/languages'; import { useGenerationForm } from '@/lib/hooks/useGenerationForm'; import { useProfile, useProfiles } from '@/lib/hooks/useProfiles'; import { useAddStoryItem, useStory } from '@/lib/hooks/useStories'; @@ -112,6 +112,13 @@ export function FloatingGenerateBox({ } }, [selectedProfileId, profiles, setSelectedProfileId]); + // Sync generation form language with selected profile's language + useEffect(() => { + if (selectedProfile?.language) { + form.setValue('language', selectedProfile.language as LanguageCode); + } + }, [selectedProfile, form]); + // Auto-resize textarea based on content (only when expanded) useEffect(() => { if (!isExpanded) { diff --git a/backend/backends/mlx_backend.py b/backend/backends/mlx_backend.py index c4ecc09..3de564e 100644 --- a/backend/backends/mlx_backend.py +++ b/backend/backends/mlx_backend.py @@ -14,6 +14,12 @@ from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback from ..utils.tasks import get_task_manager +LANGUAGE_CODE_TO_NAME = { + "zh": "chinese", "en": "english", "ja": "japanese", "ko": "korean", + "de": "german", "fr": "french", "ru": "russian", "pt": "portuguese", + "es": "spanish", "it": "italian", +} + class MLXTTSBackend: """MLX-based TTS backend using mlx-audio.""" @@ -316,24 +322,25 @@ def _generate_sync(): # MLX generate() returns a generator yielding GenerationResult objects audio_chunks = [] sample_rate = 24000 - + lang = 
LANGUAGE_CODE_TO_NAME.get(language, "auto") + # Set seed if provided (MLX uses numpy random) if seed is not None: import mlx.core as mx np.random.seed(seed) mx.random.seed(seed) - + # Extract voice prompt info ref_audio = voice_prompt.get("ref_audio") or voice_prompt.get("ref_audio_path") ref_text = voice_prompt.get("ref_text", "") - + # Validate that the audio file exists if ref_audio and not Path(ref_audio).exists(): print(f"Warning: Audio file not found: {ref_audio}") print("This may be due to a cached voice prompt referencing a deleted temp file.") print("Regenerating without voice prompt.") ref_audio = None - + # Check if model supports voice cloning via generate method # MLX API may support ref_audio parameter directly try: @@ -344,23 +351,23 @@ def _generate_sync(): sig = inspect.signature(self.model.generate) if "ref_audio" in sig.parameters: # Generate with voice cloning - for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text): + for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang):  # NOTE(review): the sig check above only tests ref_audio — confirm every MLX model's generate() accepts lang_code, else gate it on "lang_code" in sig.parameters audio_chunks.append(np.array(result.audio)) sample_rate = result.sample_rate else: # Fallback: generate without voice cloning - for result in self.model.generate(text): + for result in self.model.generate(text, lang_code=lang): audio_chunks.append(np.array(result.audio)) sample_rate = result.sample_rate else: # No voice prompt, generate normally - for result in self.model.generate(text): + for result in self.model.generate(text, lang_code=lang): audio_chunks.append(np.array(result.audio)) sample_rate = result.sample_rate except Exception as e: # If voice cloning fails, try without it print(f"Warning: Voice cloning failed, generating without voice prompt: {e}") - for result in self.model.generate(text): + for result in self.model.generate(text, lang_code=lang):  # NOTE(review): if lang_code itself raised the TypeError, this retry repeats the same failure — consider retrying without it audio_chunks.append(np.array(result.audio)) sample_rate = result.sample_rate diff --git a/backend/backends/pytorch_backend.py
b/backend/backends/pytorch_backend.py index 26f3872..1adeb22 100644 --- a/backend/backends/pytorch_backend.py +++ b/backend/backends/pytorch_backend.py @@ -15,6 +15,12 @@ from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback from ..utils.tasks import get_task_manager +LANGUAGE_CODE_TO_NAME = {  # NOTE(review): duplicated verbatim in mlx_backend.py — move to a shared constants module to avoid drift + "zh": "chinese", "en": "english", "ja": "japanese", "ko": "korean", + "de": "german", "fr": "french", "ru": "russian", "pt": "portuguese", + "es": "spanish", "it": "italian", +} + class PyTorchTTSBackend: """PyTorch-based TTS backend using Qwen3-TTS.""" @@ -335,6 +341,7 @@ def _generate_sync(): wavs, sample_rate = self.model.generate_voice_clone( text=text, voice_clone_prompt=voice_prompt, + language=LANGUAGE_CODE_TO_NAME.get(language, "auto"), instruct=instruct, ) return wavs[0], sample_rate