diff --git a/api/generate.py b/api/generate.py index 814da55..964799c 100644 --- a/api/generate.py +++ b/api/generate.py @@ -5,6 +5,7 @@ import json import logging +import os import re import threading import time @@ -21,11 +22,11 @@ def _uppercase_track_in_instruction(instruction): return instruction[: m.start(2)] + m.group(2).upper() + instruction[m.end(2) :] return instruction -from cdmf_paths import get_output_dir, get_user_data_dir, load_config +from cdmf_paths import get_output_dir, get_user_data_dir, get_models_folder, load_config from cdmf_tracks import get_audio_duration, list_lora_adapters, load_track_meta, save_track_meta from cdmf_generation_job import GenerationCancelled import cdmf_state -from generate_ace import register_job_progress_callback +from generate_ace import register_job_progress_callback, _resolve_lm_checkpoint_path bp = Blueprint("api_generate", __name__) @@ -183,12 +184,10 @@ def _run_generation(job_id: str) -> None: negative_prompt_str = (params.get("negativePrompt") or params.get("negative_prompt") or "").strip() try: d = params.get("duration") - duration = float(d if d is not None else 60) + # Keep <=0 as "Auto" and pass through to the model path. + duration = float(d if d is not None else -1) except (TypeError, ValueError): - duration = 60 - # UI may send duration=-1 or 0; for cover we may get duration from source file below - if duration > 0: - duration = max(15, min(240, duration)) + duration = -1 # Guide: 65 steps + CFG 4.0 for best quality; low CFG reduces artifacts (see community guide). try: steps = int(params.get("inferenceSteps") or 65) @@ -405,7 +404,7 @@ def _run_generation(job_id: str) -> None: path = Path(str(wav_path)) filename = path.name audio_url = f"/audio/{filename}" - actual_seconds = float(summary.get("actual_seconds") or duration) + actual_seconds = float(summary.get("actual_seconds") or (duration if duration > 0 else 0)) # Save title, lyrics, style to track metadata so they appear in the library (input params only; model does not return lyrics) try: @@ -750,16 +749,374 @@ def get_debug(task_id: str): return jsonify({"rawResponse": job}) -@bp.route("/format", methods=["POST"]) -def format_input(): - """POST /api/generate/format — stub; return same payload.""" - data = request.get_json(silent=True) or {} - return jsonify({ +def _format_with_lm(data: dict) -> tuple[dict | None, str | None]: + """ + Best-effort input formatting via ACE-Step LM. + Returns (response payload, unavailable_reason). + - payload is not None when formatter executed (including success=False responses from LM). + - unavailable_reason explains why LM formatting could not run at all. + """ + caption = (data.get("caption") or "").strip() + lyrics = (data.get("lyrics") or "").strip() + if not caption and not lyrics: + return None, "Please provide Style or Lyrics input to format." 
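+    # Illustrative payload (values are examples only; the keys mirror what this helper reads):
+    #   {"caption": "dark alt rock", "lyrics": "[Verse] ...", "bpm": 120,
+    #    "duration": 180, "keyScale": "A minor", "timeSignature": "4/4",
+    #    "language": "English", "temperature": 0.85, "topK": 50, "topP": 0.9}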
+ + LLMHandler = None + format_sample = None + import_errors: list[str] = [] + try: + from acestep.llm_inference import LLMHandler as _LLMHandler + from acestep.inference import format_sample as _format_sample + LLMHandler = _LLMHandler + format_sample = _format_sample + except Exception as e1: + import_errors.append(f"acestep.llm_inference + acestep.inference.format_sample: {e1}") + if LLMHandler is None or format_sample is None: + try: + from acestep.inference import LLMHandler as _LLMHandler # type: ignore[attr-defined] + from acestep.inference import format_sample as _format_sample + LLMHandler = _LLMHandler + format_sample = _format_sample + except Exception as e2: + import_errors.append(f"acestep.inference (LLMHandler, format_sample): {e2}") + if LLMHandler is None or format_sample is None: + reason = ( + "LM modules failed to import. " + "This build likely has a non-1.5 ACE-Step package. " + f"Tried: {' | '.join(import_errors)}" + ) + logging.info("[API format] %s", reason) + return None, reason + + cfg = load_config() or {} + lm_id = str(cfg.get("ace_step_lm") or "1.7B").strip() + if not lm_id or lm_id.lower() == "none": + return None, "LM model is set to 'none' in Settings > Models." + + try: + checkpoints_root = get_models_folder() / "checkpoints" + lm_checkpoint_path = _resolve_lm_checkpoint_path(lm_id, checkpoints_root) + except Exception as path_err: + reason = f"Could not resolve LM checkpoint path: {path_err}" + logging.info("[API format] %s", reason) + return None, reason + if not lm_checkpoint_path: + return None, f"LM checkpoint for '{lm_id}' not found. Download it in Settings > Models." + + user_metadata: dict = {} + try: + bpm = data.get("bpm") + if bpm is not None: + bpm_i = int(float(bpm)) + if bpm_i > 0: + user_metadata["bpm"] = bpm_i + except Exception: + pass + try: + duration = data.get("duration") + if duration is not None: + duration_f = float(duration) + if duration_f > 0: + user_metadata["duration"] = duration_f + except Exception: + pass + key_scale = (data.get("keyScale") or data.get("key_scale") or "").strip() + if key_scale: + user_metadata["keyscale"] = key_scale + time_sig = (data.get("timeSignature") or data.get("time_signature") or "").strip() + if time_sig: + user_metadata["timesignature"] = time_sig + language = (data.get("language") or "").strip() + if language and language.lower() not in ("unknown", "auto"): + user_metadata["language"] = language + + try: + temperature = float(data.get("temperature") or 0.85) + except Exception: + temperature = 0.85 + try: + top_k = int(data.get("topK")) if data.get("topK") not in (None, "") else None + except Exception: + top_k = None + try: + top_p = float(data.get("topP")) if data.get("topP") not in (None, "") else None + except Exception: + top_p = None + + try: + llm = LLMHandler() + device = "cuda" if bool(os.environ.get("CUDA_VISIBLE_DEVICES")) else "cpu" + lm_path = Path(str(lm_checkpoint_path)) + init_errors: list[str] = [] + init_ok = False + init_attempts = [ + # Prefer explicit non-vLLM backends first, especially on CPU/MPS runs. 
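+            # Each candidate signature below is tried in order; a TypeError or backend
+            # failure simply advances to the next entry, and only total failure raises.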
+ {"checkpoint_dir": str(lm_path.parent), "lm_model_path": lm_path.name, "backend": "pytorch", "device": device}, + {"checkpoint_dir": str(lm_path.parent), "lm_model_path": lm_path.name, "backend": "transformers", "device": device}, + {"checkpoint_dir": str(lm_path.parent), "lm_model_path": lm_path.name, "backend": "hf", "device": device}, + # ACE-Step 1.5 signature from docs (backend default) + {"checkpoint_dir": str(lm_path.parent), "lm_model_path": lm_path.name, "device": device}, + # Older/simple initialize signature + {"checkpoint_dir": str(lm_path), "device": device}, + # Some variants accept direct lm_model_path only + {"lm_model_path": str(lm_path), "device": device}, + ] + if device == "cuda": + init_attempts.append({"checkpoint_dir": str(lm_path.parent), "lm_model_path": lm_path.name, "backend": "vllm", "device": device}) + for kwargs in init_attempts: + try: + llm.initialize(**kwargs) + init_ok = True + break + except Exception as init_err: + init_errors.append(f"{kwargs}: {init_err}") + if not init_ok: + raise RuntimeError("LLMHandler.initialize failed for all known signatures: " + " | ".join(init_errors)) + def _run_format(): + return format_sample( + llm_handler=llm, + caption=caption, + lyrics=lyrics, + user_metadata=user_metadata or None, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + result = _run_format() + # Some ACE-Step builds return "LLM not initialized" without throwing. + status_msg = str(getattr(result, "status_message", "") or "") + if "llm not initialized" in status_msg.lower(): + retry_attempts = [ + {"checkpoint_dir": str(lm_path.parent), "lm_model_path": lm_path.name, "backend": "pytorch", "device": device}, + {"checkpoint_dir": str(lm_path.parent), "lm_model_path": lm_path.name, "backend": "transformers", "device": device}, + {"checkpoint_dir": str(lm_path.parent), "lm_model_path": lm_path.name, "backend": "hf", "device": device}, + ] + for kwargs in retry_attempts: + try: + llm.initialize(**kwargs) + result = _run_format() + status_msg = str(getattr(result, "status_message", "") or "") + if "llm not initialized" not in status_msg.lower(): + break + except Exception: + continue + if result is None: + return None, "LM formatter returned no result." 
+ except Exception as run_err: + reason = f"LM inference failed: {run_err}" + logging.warning("[API format] %s", reason) + return None, reason + + success = bool(getattr(result, "success", True)) + return { + "success": success, + "caption": getattr(result, "caption", None), + "lyrics": getattr(result, "lyrics", None), + "bpm": getattr(result, "bpm", None), + "duration": getattr(result, "duration", None), + "key_scale": getattr(result, "keyscale", None), + "language": getattr(result, "language", None), + "time_signature": getattr(result, "timesignature", None), + "status_message": getattr(result, "status_message", None), + "error": getattr(result, "error", None), + }, None + + +def _normalize_lyrics_sections(lyrics: str) -> str: + text = (lyrics or "").strip() + if not text: + return text + lines = [ln.rstrip() for ln in text.splitlines()] + tag_re = re.compile(r"^\s*\[\s*([A-Za-z][A-Za-z0-9 _-]*?)\s*\]\s*$") + has_any_tag = any(tag_re.match(ln) for ln in lines) + + def _clean_tag(tag: str) -> str: + low = re.sub(r"\s+", " ", tag.strip().lower()) + mapping = { + "intro": "Intro", + "verse": "Verse", + "chorus": "Chorus", + "pre-chorus": "Pre-Chorus", + "post-chorus": "Post-Chorus", + "bridge": "Bridge", + "outro": "Outro", + "hook": "Hook", + "refrain": "Refrain", + } + for k, v in mapping.items(): + if low == k or low.startswith(k + " "): + suffix = low[len(k):].strip() + return f"[{v}{(' ' + suffix) if suffix else ''}]" + return f"[{tag.strip().title()}]" + + if has_any_tag: + out: list[str] = [] + prev_blank = False + for ln in lines: + m = tag_re.match(ln) + if m: + out.append(_clean_tag(m.group(1))) + prev_blank = False + continue + if not ln.strip(): + if not prev_blank: + out.append("") + prev_blank = True + continue + out.append(ln.strip()) + prev_blank = False + return "\n".join(out).strip() + + # No structure tags: split paragraphs and apply a simple song structure. 
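+    # Illustrative mapping: three untagged paragraphs become [Verse 1] / [Chorus] / [Verse 2]
+    # per the `labels` sequence below; paragraphs beyond that list continue as [Verse 3], [Verse 4], ...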
+ paragraphs: list[str] = [] + cur: list[str] = [] + for ln in lines: + if ln.strip(): + cur.append(ln.strip()) + else: + if cur: + paragraphs.append("\n".join(cur)) + cur = [] + if cur: + paragraphs.append("\n".join(cur)) + if not paragraphs: + return text + + labels = ["[Verse 1]", "[Chorus]", "[Verse 2]", "[Chorus]", "[Bridge]", "[Chorus]", "[Outro]"] + out: list[str] = [] + verse_n = 3 + for i, block in enumerate(paragraphs): + if i < len(labels): + tag = labels[i] + else: + tag = f"[Verse {verse_n}]" + verse_n += 1 + out.append(tag) + out.append(block) + out.append("") + return "\n".join(out).strip() + + +def _infer_style_from_lyrics(lyrics: str) -> str: + t = (lyrics or "").lower() + if not t.strip(): + return "emotional vocal song with clear verse-chorus structure" + if any(w in t for w in ("black days", "fate", "fear", "night", "blind", "fall")): + return "dark alternative rock, melancholic grunge vibe, expressive male vocals, dynamic verse-chorus structure" + if any(w in t for w in ("dance", "party", "club", "tonight", "bailando")): + return "upbeat pop dance track, catchy hooks, energetic vocal delivery" + if any(w in t for w in ("love", "heart", "tears", "alone", "broken")): + return "emotional pop rock ballad, introspective lyrics, wide dynamic chorus" + return "vocal alt-pop/rock song, emotional tone, clear verse-chorus form" + + +def _expand_style_prompt(caption: str, lyrics: str) -> str: + cap = re.sub(r"\s+", " ", (caption or "").strip().strip(",")) + lyr = (lyrics or "").lower() + low = f"{cap.lower()} {lyr}".strip() + if not cap: + return _infer_style_from_lyrics(lyrics) + + def has_any(words: tuple[str, ...]) -> bool: + return any(w in low for w in words) + + additions: list[str] = [] + + # Genre/style axis + has_genre = has_any(( + "rock", "grunge", "metal", "pop", "dance", "edm", "house", "hip hop", "rap", + "r&b", "soul", "jazz", "blues", "folk", "country", "acoustic", "orchestral", + "cinematic", "ambient", "electronic", "alt", "alternative", + )) + if not has_genre: + if has_any(("black days", "fear", "fate", "night", "blind", "fall", "dark")): + additions.append("dark alternative rock with subtle grunge texture") + elif has_any(("dance", "party", "club", "tonight", "bailando")): + additions.append("upbeat pop dance production with modern electronic polish") + elif has_any(("love", "heart", "alone", "tears", "broken")): + additions.append("emotional pop-rock ballad character") + else: + additions.append("modern alt-pop/rock character") + + # Mood axis + has_mood = has_any(( + "dark", "melanch", "moody", "sad", "brooding", "uplift", "happy", "energetic", + "aggressive", "tender", "warm", "cinematic", "emotional", "introspective", + )) + if not has_mood: + if has_any(("black days", "fear", "fate", "night", "blind", "fall", "empty")): + additions.append("brooding, introspective mood") + elif has_any(("dance", "party", "club", "celebrate")): + additions.append("high-energy, hook-forward mood") + else: + additions.append("emotionally focused tone") + + # Instrumentation axis + has_instruments = has_any(( + "guitar", "bass", "drum", "synth", "piano", "string", "pad", "808", "perc", + "orchestra", "brass", "keys", + )) + if not has_instruments: + if has_any(("rock", "grunge", "alt", "alternative", "black days")): + additions.append("gritty electric guitars, driving bass, and punchy live drums") + elif has_any(("dance", "edm", "house", "electronic", "club")): + additions.append("tight electronic drums, deep bass, and bright synth hooks") + else: + additions.append("focused 
rhythm section with melodic lead layers") + + # Arrangement / structure axis + has_structure = has_any(("verse", "chorus", "bridge", "drop", "build", "hook", "arrangement", "structure")) + if not has_structure: + additions.append("clear verse-chorus contrast with a stronger chorus lift") + + # Vocal direction axis + has_vocal = has_any(("vocal", "voice", "sung", "singer", "male vocal", "female vocal", "duet", "harmony")) + if lyrics.strip() and not has_vocal: + additions.append("expressive lead vocals with natural phrasing") + + # Keep repeated clicks mostly idempotent. + deduped = [a for a in additions if a.lower() not in cap.lower()] + if not deduped: + return cap + return f"{cap}, {', '.join(deduped)}" + + +def _heuristic_format_input(data: dict, reason: str | None) -> dict: + mode = str(data.get("mode") or "general").strip().lower() + caption_in = (data.get("caption") or "").strip() + lyrics_in = (data.get("lyrics") or "").strip() + + caption_out = caption_in + lyrics_out = lyrics_in + + if mode in ("style", "general"): + if not caption_out and lyrics_in: + caption_out = _infer_style_from_lyrics(lyrics_in) + elif caption_out: + caption_out = _expand_style_prompt(caption_out, lyrics_in) + + if mode in ("lyrics", "general") and lyrics_in: + lyrics_out = _normalize_lyrics_sections(lyrics_in) + + return { "success": True, - "caption": data.get("caption"), - "lyrics": data.get("lyrics"), + "caption": caption_out, + "lyrics": lyrics_out, "bpm": data.get("bpm"), "duration": data.get("duration"), "key_scale": data.get("keyScale"), + "language": data.get("language"), "time_signature": data.get("timeSignature"), - }) + "status_message": f"LM formatter unavailable. Reason: {reason or 'unknown'}. Applied local heuristic formatting.", + } + + +@bp.route("/format", methods=["POST"]) +def format_input(): + """POST /api/generate/format — format caption/lyrics via ACE-Step LM when available.""" + data = request.get_json(silent=True) or {} + lm_payload, unavailable_reason = _format_with_lm(data) + if lm_payload is not None: + return jsonify(lm_payload) + return jsonify(_heuristic_format_input(data, unavailable_reason)) diff --git a/cdmf_pipeline_ace_step.py b/cdmf_pipeline_ace_step.py index 5780a74..22c52a2 100644 --- a/cdmf_pipeline_ace_step.py +++ b/cdmf_pipeline_ace_step.py @@ -329,37 +329,160 @@ def _refine_prompt_with_lm( ): """ Optionally refine prompt/lyrics using the bundled ACE-Step 5Hz LM (no external LLM). - Returns (refined_prompt, refined_lyrics) or None if LM inference is not available. + Returns metadata dict or None if LM inference is not available. 
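+    The dict contains "caption", "lyrics", and an optional LM-suggested "duration".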
""" try: - # ACE-Step 1.5 may expose LLMHandler and format_sample / create_sample + # ACE-Step 1.5 primary layout from acestep.llm_inference import LLMHandler from acestep.inference import format_sample - except ImportError: - return None + except Exception: + try: + # Fallback layout used by some builds + from acestep.inference import LLMHandler, format_sample # type: ignore[attr-defined] + except Exception: + return None try: llm = LLMHandler() - llm.initialize( - checkpoint_dir=str(lm_checkpoint_path), - device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu", - ) - result = format_sample( - llm, - caption=prompt or "", - lyrics=lyrics or "", - temperature=lm_temperature, - top_k=lm_top_k if lm_top_k else None, - top_p=lm_top_p, - ) + device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu" + lm_path = str(lm_checkpoint_path) + lm_parent = os.path.dirname(lm_path) + lm_name = os.path.basename(lm_path) + init_ok = False + init_attempts = [ + {"checkpoint_dir": lm_parent, "lm_model_path": lm_name, "backend": "pytorch", "device": device}, + {"checkpoint_dir": lm_parent, "lm_model_path": lm_name, "backend": "transformers", "device": device}, + {"checkpoint_dir": lm_parent, "lm_model_path": lm_name, "backend": "hf", "device": device}, + {"checkpoint_dir": lm_parent, "lm_model_path": lm_name, "device": device}, + {"checkpoint_dir": lm_path, "device": device}, + {"lm_model_path": lm_path, "device": device}, + ] + if device == "cuda": + init_attempts.append({"checkpoint_dir": lm_parent, "lm_model_path": lm_name, "backend": "vllm", "device": device}) + for kwargs in init_attempts: + try: + llm.initialize(**kwargs) + init_ok = True + break + except Exception: + continue + if not init_ok: + return None + def _run_format(): + return format_sample( + llm, + caption=prompt or "", + lyrics=lyrics or "", + temperature=lm_temperature, + top_k=lm_top_k if lm_top_k else None, + top_p=lm_top_p, + ) + + result = _run_format() + status_msg = str(getattr(result, "status_message", "") or "") + if "llm not initialized" in status_msg.lower(): + retry_attempts = [ + {"checkpoint_dir": lm_parent, "lm_model_path": lm_name, "backend": "pytorch", "device": device}, + {"checkpoint_dir": lm_parent, "lm_model_path": lm_name, "backend": "transformers", "device": device}, + {"checkpoint_dir": lm_parent, "lm_model_path": lm_name, "backend": "hf", "device": device}, + ] + for kwargs in retry_attempts: + try: + llm.initialize(**kwargs) + result = _run_format() + status_msg = str(getattr(result, "status_message", "") or "") + if "llm not initialized" not in status_msg.lower(): + break + except Exception: + continue if result and getattr(result, "caption", None) is not None: new_caption = getattr(result, "caption", None) or prompt new_lyrics = getattr(result, "lyrics", None) if hasattr(result, "lyrics") else lyrics - return (new_caption, new_lyrics or lyrics) + lm_duration = None + try: + if hasattr(result, "duration") and getattr(result, "duration") is not None: + lm_duration = float(getattr(result, "duration")) + elif hasattr(result, "cot_duration") and getattr(result, "cot_duration") is not None: + lm_duration = float(getattr(result, "cot_duration")) + except Exception: + lm_duration = None + return { + "caption": new_caption, + "lyrics": new_lyrics or lyrics, + "duration": lm_duration, + } + except Exception: + pass + return None + + +def _audio_duration_seconds(audio_path): + """ + Best-effort duration probe for local audio paths. + Returns duration in seconds or None when unavailable. 
+ """ + if not audio_path: + return None + try: + from pydub import AudioSegment + seg = AudioSegment.from_file(str(audio_path)) + sec = float(len(seg)) / 1000.0 + return sec if sec > 0 else None + except Exception: + pass + try: + if torchaudio is not None: + wav, sr = torchaudio.load(str(audio_path)) + if sr and getattr(wav, "shape", None) is not None and wav.shape[-1] > 0: + return float(wav.shape[-1]) / float(sr) except Exception: pass return None +def _estimate_duration_from_lyrics(lyrics_text): + """ + Heuristic duration estimate from lyric content when LM metadata is unavailable. + Returns seconds or None. + """ + text = (lyrics_text or "").strip() + if not text: + return None + low = text.lower() + if low in ("[inst]", "[instrumental]", "instrumental"): + return None + + lines = [ln.strip() for ln in text.splitlines()] + non_empty = [ln for ln in lines if ln] + if not non_empty: + return None + + # Ignore section tags like [Verse], [Chorus], etc. for word counting. + section_pat = re.compile(r"^\s*\[[^\]]+\]\s*$") + lyric_lines = [ln for ln in non_empty if not section_pat.match(ln)] + if not lyric_lines: + return None + + words = 0 + for ln in lyric_lines: + words += len(re.findall(r"[A-Za-z0-9']+", ln)) + + # Two complementary estimates: + # 1) word-rate estimate (sung lyrics are often ~95-120 wpm; use ~0.58 s/word) + # 2) line-rate estimate (phrases + rests; ~4.3 s per lyric line) + word_est = words * 0.58 + line_est = len(lyric_lines) * 4.3 + + # Take the larger estimate to reduce premature cutoffs. + est = max(word_est, line_est) + + # Add a small arrangement buffer for intro/outro/transitions. + est += 12.0 + + # Keep in sane generation bounds. + est = max(30.0, min(360.0, est)) + return float(est) + + # class ACEStepPipeline(DiffusionPipeline): class ACEStepPipeline: def __init__( @@ -1961,6 +2084,7 @@ def __call__( start_time = time.time() # LM planner: optional refinement of prompt/lyrics using bundled 5Hz LM (no external LLM) + lm_duration = None if thinking and lm_checkpoint_path: refined = _refine_prompt_with_lm( lm_checkpoint_path=lm_checkpoint_path, @@ -1976,9 +2100,13 @@ def __call__( lm_negative_prompt=lm_negative_prompt, ) if refined is not None: - prompt, lyrics = refined + prompt = refined.get("caption", prompt) + lyrics = refined.get("lyrics", lyrics) + lm_duration = refined.get("duration") if logger: logger.info("LM planner refined caption/lyrics; using for DiT.") + if logger and lm_duration and lm_duration > 0: + logger.info(f"LM planner suggested duration: {lm_duration:.1f}s") elif thinking and logger: logger.info( "Thinking on but no LM planner path (select an LM in Settings → Models). DiT-only." 
@@ -2045,8 +2173,23 @@ def __call__( ) if audio_duration <= 0: - audio_duration = random.uniform(30.0, 240.0) - logger.info(f"random audio duration: {audio_duration}") + if lm_duration is not None and lm_duration > 0: + audio_duration = lm_duration + logger.info(f"Using LM metadata duration: {audio_duration:.1f}s") + else: + ref_like_path = ref_audio_input if (audio2audio_enable and ref_audio_input) else src_audio_path + ref_duration = _audio_duration_seconds(ref_like_path) + if ref_duration is not None and ref_duration > 0: + audio_duration = ref_duration + logger.info(f"Using source/reference duration: {audio_duration:.1f}s") + else: + lyric_duration = _estimate_duration_from_lyrics(lyrics) + if lyric_duration is not None and lyric_duration > 0: + audio_duration = lyric_duration + logger.info(f"Using lyric-derived duration estimate: {audio_duration:.1f}s") + else: + audio_duration = random.uniform(30.0, 240.0) + logger.info(f"random audio duration: {audio_duration}") end_time = time.time() preprocess_time_cost = end_time - start_time diff --git a/ui/components/CreatePanel.tsx b/ui/components/CreatePanel.tsx index b902906..1f324cd 100644 --- a/ui/components/CreatePanel.tsx +++ b/ui/components/CreatePanel.tsx @@ -286,7 +286,10 @@ export const CreatePanel: React.FC = ({ onGenerate, isGenerati const [isUploadingReference, setIsUploadingReference] = useState(false); const [isUploadingSource, setIsUploadingSource] = useState(false); const [uploadError, setUploadError] = useState(null); - const [isFormatting, setIsFormatting] = useState(false); + const [isFormattingStyle, setIsFormattingStyle] = useState(false); + const [isFormattingLyrics, setIsFormattingLyrics] = useState(false); + const referenceInputRef = useRef(null); + const sourceInputRef = useRef(null); const [showAudioModal, setShowAudioModal] = useState(false); const [audioModalTarget, setAudioModalTarget] = useState<'reference' | 'source' | 'cover_style'>('reference'); const [tempAudioUrl, setTempAudioUrl] = useState(''); @@ -435,14 +438,37 @@ export const CreatePanel: React.FC = ({ onGenerate, isGenerati } }; - // Format handler - uses LLM to enhance style and auto-fill parameters - const handleFormat = async () => { - if (!token || !style.trim()) return; - setIsFormatting(true); + const handleFileSelect = (e: React.ChangeEvent, target: 'reference' | 'source') => { + const file = e.target.files?.[0]; + if (file) { + void uploadAudio(file, target); + } + e.target.value = ''; + }; + + const applyFormatMetadata = (result: { + bpm?: number; + duration?: number; + key_scale?: string; + time_signature?: string; + language?: string; + }) => { + if (result.bpm && result.bpm > 0) setBpm(result.bpm); + if (result.duration && result.duration > 0) setDuration(result.duration); + if (result.key_scale) setKeyScale(result.key_scale); + if (result.time_signature) setTimeSignature(result.time_signature); + if (result.language) setVocalLanguage(result.language); + }; + + // Style formatter: infer/enhance style prompt; if style is empty, it can infer from lyrics. + const handleFormatStyle = async () => { + if (!style.trim() && !lyrics.trim()) return; + setIsFormattingStyle(true); try { const result = await generateApi.formatInput({ + mode: 'style', caption: style, - lyrics: lyrics, + lyrics: lyrics || undefined, bpm: bpm > 0 ? bpm : undefined, duration: duration > 0 ? duration : undefined, keyScale: keyScale || undefined, @@ -450,27 +476,54 @@ export const CreatePanel: React.FC = ({ onGenerate, isGenerati temperature: lmTemperature, topK: lmTopK > 0 ? 
lmTopK : undefined, topP: lmTopP, - }, token); + }, token || undefined); if (result.success) { - // Update fields with LLM-generated values - if (result.caption) setStyle(result.caption); - if (result.lyrics) setLyrics(result.lyrics); - if (result.bpm && result.bpm > 0) setBpm(result.bpm); - if (result.duration && result.duration > 0) setDuration(result.duration); - if (result.key_scale) setKeyScale(result.key_scale); - if (result.time_signature) setTimeSignature(result.time_signature); - if (result.language) setVocalLanguage(result.language); + if (result.caption?.trim()) setStyle(result.caption.trim()); + applyFormatMetadata(result); setIsFormatCaption(true); } else { - console.error('Format failed:', result.error || result.status_message); - alert(result.error || result.status_message || 'Format failed. Make sure the LLM is initialized.'); + console.error('Style format failed:', result.error || result.status_message); + alert(result.error || result.status_message || 'Style format failed. Make sure the LM is initialized.'); + } + } catch (err) { + console.error('Style format error:', err); + alert('Style format failed. The LM may not be available.'); + } finally { + setIsFormattingStyle(false); + } + }; + + // Lyrics formatter: normalize / structure lyrics ([Verse], [Chorus], etc.). + const handleFormatLyrics = async () => { + if (!lyrics.trim()) return; + setIsFormattingLyrics(true); + try { + const result = await generateApi.formatInput({ + mode: 'lyrics', + caption: style || undefined, + lyrics, + bpm: bpm > 0 ? bpm : undefined, + duration: duration > 0 ? duration : undefined, + keyScale: keyScale || undefined, + timeSignature: timeSignature || undefined, + temperature: lmTemperature, + topK: lmTopK > 0 ? lmTopK : undefined, + topP: lmTopP, + }, token || undefined); + + if (result.success) { + if (result.lyrics?.trim()) setLyrics(result.lyrics.trim()); + applyFormatMetadata(result); + } else { + console.error('Lyrics format failed:', result.error || result.status_message); + alert(result.error || result.status_message || 'Lyrics format failed. Make sure the LM is initialized.'); } } catch (err) { - console.error('Format error:', err); - alert('Format failed. The LLM may not be available.'); + console.error('Lyrics format error:', err); + alert('Lyrics format failed. The LM may not be available.'); } finally { - setIsFormatting(false); + setIsFormattingLyrics(false); } }; @@ -1707,12 +1760,12 @@ export const CreatePanel: React.FC = ({ onGenerate, isGenerati

   [JSX body not recoverable: Style prompt card ("Genre, mood, instruments, vibe"), presumably rewired to handleFormatStyle / isFormattingStyle]
@@ -1784,12 +1837,12 @@ export const CreatePanel: React.FC = ({ onGenerate, isGenerati
   [JSX body not recoverable: Lyrics card with the {instrumental ? 'Instrumental' : 'Vocal'} toggle, presumably rewired to handleFormatLyrics / isFormattingLyrics]
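
Usage sketch for the reworked formatter endpoint. This is a minimal sketch, assuming the blueprint is mounted under /api/generate, the dev server listens on localhost:8000, and no auth header is needed locally; adjust all three to the real deployment.

import requests  # illustrative client; any HTTP client works

resp = requests.post(
    "http://localhost:8000/api/generate/format",  # host and port are assumptions
    json={
        "mode": "lyrics",
        "caption": "dark alternative rock",
        "lyrics": "We walk through black days\nNo fear of what comes next",
    },
    timeout=300,
)
data = resp.json()
# When the LM path is unavailable the server falls back to the local heuristics
# and still returns success=True with an explanatory status_message.
print(data["status_message"])
print(data["lyrics"])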