diff --git a/backend/backends/pytorch_backend.py b/backend/backends/pytorch_backend.py index 26f3872..cbb7e8b 100644 --- a/backend/backends/pytorch_backend.py +++ b/backend/backends/pytorch_backend.py @@ -348,7 +348,7 @@ def _generate_sync(): class PyTorchSTTBackend: """PyTorch-based STT backend using Whisper.""" - def __init__(self, model_size: str = "base"): + def __init__(self, model_size: str = "turbo"): self.model = None self.processor = None self.model_size = model_size @@ -379,7 +379,10 @@ def _is_model_cached(self, model_size: str) -> bool: """ try: from huggingface_hub import constants as hf_constants - model_name = f"openai/whisper-{model_size}" + model_size_to_hf = { + "turbo": "openai/whisper-large-v3-turbo", + } + model_name = model_size_to_hf.get(model_size, f"openai/whisper-{model_size}") repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_name.replace("/", "--")) if not repo_cache.exists(): @@ -457,7 +460,10 @@ def _load_model_sync(self, model_size: str): # Import transformers from transformers import WhisperProcessor, WhisperForConditionalGeneration - model_name = f"openai/whisper-{model_size}" + model_size_to_hf = { + "turbo": "openai/whisper-large-v3-turbo", + } + model_name = model_size_to_hf.get(model_size, f"openai/whisper-{model_size}") print(f"[DEBUG] Model name: {model_name}") print(f"Loading Whisper model {model_size} on {self.device}...") @@ -546,21 +552,20 @@ def _transcribe_sync(): ) inputs = inputs.to(self.device) - # Set language if provided - forced_decoder_ids = None + # Generate transcription + # If language is provided, force it; otherwise let Whisper auto-detect + generate_kwargs = {} if language: - # Support all languages from frontend: en, zh, ja, ko, de, fr, ru, pt, es, it - # Whisper supports these and many more forced_decoder_ids = self.processor.get_decoder_prompt_ids( language=language, task="transcribe", ) + generate_kwargs["forced_decoder_ids"] = forced_decoder_ids - # Generate transcription with 
torch.no_grad(): predicted_ids = self.model.generate( inputs["input_features"], - forced_decoder_ids=forced_decoder_ids, + **generate_kwargs, ) # Decode diff --git a/backend/main.py b/backend/main.py index 59fb9e1..eeb8969 100644 --- a/backend/main.py +++ b/backend/main.py @@ -20,9 +20,20 @@ from pathlib import Path import uuid import asyncio -import signal import os +# Set HSA_OVERRIDE_GFX_VERSION for AMD GPUs that aren't officially listed in ROCm +# (e.g., RX 6600 is gfx1032 which maps to gfx1030 target) +# This must be set BEFORE any torch.cuda calls +if not os.environ.get("HSA_OVERRIDE_GFX_VERSION"): + os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0" + +# Suppress noisy MIOpen workspace warnings on AMD GPUs +if not os.environ.get("MIOPEN_LOG_LEVEL"): + os.environ["MIOPEN_LOG_LEVEL"] = "4" + +import signal + from . import database, models, profiles, history, tts, transcribe, config, export_import, channels, stories, __version__ from .database import get_db, Generation as DBGeneration, VoiceProfile as DBVoiceProfile from .utils.progress import get_progress_manager @@ -816,9 +827,12 @@ async def transcribe_audio( # Transcribe whisper_model = transcribe.get_whisper_model() - # Check if Whisper model is downloaded (uses default size "base") + # Check if Whisper model is downloaded model_size = whisper_model.model_size - model_name = f"openai/whisper-{model_size}" + model_size_to_hf = { + "turbo": "openai/whisper-large-v3-turbo", + } + model_name = model_size_to_hf.get(model_size, f"openai/whisper-{model_size}") # Check if model is cached from huggingface_hub import constants as hf_constants @@ -1248,6 +1262,13 @@ def check_whisper_loaded(model_size: str): "model_size": "large", "check_loaded": lambda: check_whisper_loaded("large"), }, + { + "model_name": "whisper-turbo", + "display_name": "Whisper Turbo", + "hf_repo_id": "openai/whisper-large-v3-turbo", + "model_size": "turbo", + "check_loaded": lambda: check_whisper_loaded("turbo"), + }, ] # Build a mapping of 
model_name -> hf_repo_id so we can check if shared repos are downloading @@ -1642,7 +1663,12 @@ def _get_gpu_status() -> str: """Get GPU availability status.""" backend_type = get_backend_type() if torch.cuda.is_available(): - return f"CUDA ({torch.cuda.get_device_name(0)})" + device_name = torch.cuda.get_device_name(0) + # Check if this is ROCm (AMD) or CUDA (NVIDIA) + is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None + if is_rocm: + return f"ROCm ({device_name})" + return f"CUDA ({device_name})" elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): return "MPS (Apple Silicon)" elif backend_type == "mlx": diff --git a/package.json b/package.json index c0f3c21..925bd35 100644 --- a/package.json +++ b/package.json @@ -9,7 +9,8 @@ "landing" ], "scripts": { - "dev": "bun run setup:dev && cd tauri && bun run tauri dev", + "dev": "cd tauri && bun run dev", + "dev:tauri": "bun run setup:dev && cd tauri && bun run tauri dev", "dev:web": "cd web && bun run dev", "dev:landing": "cd landing && bun run dev", "dev:server": "uvicorn backend.main:app --reload --port 17493", @@ -41,4 +42,4 @@ "bun": ">=1.0.0" }, "packageManager": "bun@1.3.8" -} +} \ No newline at end of file diff --git a/tauri/src-tauri/src/audio_capture/linux.rs b/tauri/src-tauri/src/audio_capture/linux.rs index 8af26e9..3cae59e 100644 --- a/tauri/src-tauri/src/audio_capture/linux.rs +++ b/tauri/src-tauri/src/audio_capture/linux.rs @@ -1,16 +1,312 @@ use crate::audio_capture::AudioCaptureState; +use base64::{engine::general_purpose, Engine as _}; +use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; +use cpal::{SampleFormat, StreamConfig}; +use hound::{WavSpec, WavWriter}; +use std::io::Cursor; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread; +/// Start capturing system audio on Linux using PulseAudio monitor sources. 
+/// +/// PulseAudio exposes "monitor" devices that mirror the output of each sink, +/// allowing us to capture whatever audio is currently playing on the system. +/// We use `cpal` with the default host (which will be PulseAudio or PipeWire +/// on modern Linux) and look for monitor input devices. pub async fn start_capture( state: &AudioCaptureState, max_duration_secs: u32, ) -> Result<(), String> { - todo!("implement Linux audio capture") + // Reset previous samples + state.reset(); + + let samples = state.samples.clone(); + let sample_rate_arc = state.sample_rate.clone(); + let channels_arc = state.channels.clone(); + let stop_tx = state.stop_tx.clone(); + let error_arc = state.error.clone(); + + // Use AtomicBool for stop signal (works across threads) + let stop_flag = Arc::new(AtomicBool::new(false)); + let stop_flag_clone = stop_flag.clone(); + + // Create tokio channel and spawn a task to bridge it to the AtomicBool + let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(1); + *stop_tx.lock().unwrap() = Some(tx); + + tokio::spawn(async move { + rx.recv().await; + stop_flag_clone.store(true, Ordering::Relaxed); + }); + + // Spawn capture on a dedicated thread + thread::spawn(move || { + let host = cpal::default_host(); + + // Try to find a monitor device for system audio capture. + // On PulseAudio/PipeWire, monitor sources have "monitor" in their name. 
+ let device = { + let mut monitor_device = None; + + if let Ok(devices) = host.input_devices() { + for d in devices { + if let Ok(name) = d.name() { + let name_lower = name.to_lowercase(); + if name_lower.contains("monitor") { + eprintln!("Linux audio capture: Found monitor device: {}", name); + monitor_device = Some(d); + break; + } + } + } + } + + match monitor_device { + Some(d) => d, + None => { + // Fallback to default input device (microphone) + eprintln!("Linux audio capture: No monitor device found, falling back to default input"); + match host.default_input_device() { + Some(d) => d, + None => { + let error_msg = "No audio input device available".to_string(); + eprintln!("{}", error_msg); + *error_arc.lock().unwrap() = Some(error_msg); + return; + } + } + } + } + }; + + let device_name = device.name().unwrap_or_else(|_| "unknown".to_string()); + eprintln!("Linux audio capture: Using device: {}", device_name); + + // Get supported config + let config = match device.default_input_config() { + Ok(c) => c, + Err(e) => { + let error_msg = format!("Failed to get default input config: {}", e); + eprintln!("{}", error_msg); + *error_arc.lock().unwrap() = Some(error_msg); + return; + } + }; + + let sample_rate = config.sample_rate().0; + let channels = config.channels(); + let sample_format = config.sample_format(); + + eprintln!( + "Linux audio capture: Config - {}Hz, {} channels, format: {:?}", + sample_rate, channels, sample_format + ); + + *sample_rate_arc.lock().unwrap() = sample_rate; + *channels_arc.lock().unwrap() = channels; + + let stream_config = StreamConfig { + channels, + sample_rate: cpal::SampleRate(sample_rate), + buffer_size: cpal::BufferSize::Default, + }; + + let samples_clone = samples.clone(); + let error_arc_clone = error_arc.clone(); + let stop_flag_for_stream = stop_flag.clone(); + + let err_fn = { + let error_arc = error_arc.clone(); + move |err: cpal::StreamError| { + let error_msg = format!("Stream error: {}", err); + eprintln!("{}", 
error_msg); + *error_arc.lock().unwrap() = Some(error_msg); + } + }; + + let stream = match sample_format { + SampleFormat::F32 => { + let samples = samples_clone.clone(); + let stop = stop_flag_for_stream.clone(); + device.build_input_stream( + &stream_config, + move |data: &[f32], _: &cpal::InputCallbackInfo| { + if stop.load(Ordering::Relaxed) { + return; + } + let mut guard = samples.lock().unwrap(); + guard.extend_from_slice(data); + }, + err_fn, + None, + ) + } + SampleFormat::I16 => { + let samples = samples_clone.clone(); + let stop = stop_flag_for_stream.clone(); + device.build_input_stream( + &stream_config, + move |data: &[i16], _: &cpal::InputCallbackInfo| { + if stop.load(Ordering::Relaxed) { + return; + } + let mut guard = samples.lock().unwrap(); + for &s in data { + guard.push(s as f32 / 32768.0); + } + }, + err_fn, + None, + ) + } + SampleFormat::U16 => { + let samples = samples_clone.clone(); + let stop = stop_flag_for_stream.clone(); + device.build_input_stream( + &stream_config, + move |data: &[u16], _: &cpal::InputCallbackInfo| { + if stop.load(Ordering::Relaxed) { + return; + } + let mut guard = samples.lock().unwrap(); + for &s in data { + guard.push((s as f32 / 32768.0) - 1.0); + } + }, + err_fn, + None, + ) + } + _ => { + let error_msg = format!("Unsupported sample format: {:?}", sample_format); + eprintln!("{}", error_msg); + *error_arc_clone.lock().unwrap() = Some(error_msg); + return; + } + }; + + let stream = match stream { + Ok(s) => s, + Err(e) => { + let error_msg = format!("Failed to build input stream: {}", e); + eprintln!("{}", error_msg); + *error_arc_clone.lock().unwrap() = Some(error_msg); + return; + } + }; + + if let Err(e) = stream.play() { + let error_msg = format!("Failed to start stream: {}", e); + eprintln!("{}", error_msg); + *error_arc_clone.lock().unwrap() = Some(error_msg); + return; + } + + eprintln!("Linux audio capture: Stream started successfully"); + + // Keep thread alive until stop signal + loop { + if 
stop_flag.load(Ordering::Relaxed) { + break; + } + std::thread::sleep(std::time::Duration::from_millis(100)); + } + + // Stream will be dropped here, stopping capture + eprintln!("Linux audio capture: Stream stopped"); + }); + + // Spawn timeout task + let stop_tx_clone = state.stop_tx.clone(); + tokio::spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_secs(max_duration_secs as u64)).await; + let tx = stop_tx_clone.lock().unwrap().take(); + if let Some(tx) = tx { + let _ = tx.send(()).await; + } + }); + + Ok(()) } pub async fn stop_capture(state: &AudioCaptureState) -> Result<String, String> { - todo!("implement Linux audio capture stop") + // Signal stop + if let Some(tx) = state.stop_tx.lock().unwrap().take() { + let _ = tx.send(()); + } + + // Wait a bit for capture to stop + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + + // Check if there was an error during capture + if let Some(error) = state.error.lock().unwrap().as_ref() { + return Err(error.clone()); + } + + // Get samples + let samples = state.samples.lock().unwrap().clone(); + let sample_rate = *state.sample_rate.lock().unwrap(); + let channels = *state.channels.lock().unwrap(); + + if samples.is_empty() { + return Err( + "No audio samples captured. Make sure audio is playing on your system during recording." 
+ .to_string(), + ); + } + + // Convert to WAV + let wav_data = samples_to_wav(&samples, sample_rate, channels)?; + + // Encode to base64 + let base64_data = general_purpose::STANDARD.encode(&wav_data); + + Ok(base64_data) } pub fn is_supported() -> bool { - false + // Check if we can find a monitor device for system audio capture + let host = cpal::default_host(); + if let Ok(devices) = host.input_devices() { + for d in devices { + if let Ok(name) = d.name() { + if name.to_lowercase().contains("monitor") { + return true; + } + } + } + } + // Even without a monitor, basic input capture is available + host.default_input_device().is_some() +} + +fn samples_to_wav(samples: &[f32], sample_rate: u32, channels: u16) -> Result<Vec<u8>, String> { + let mut buffer = Vec::new(); + let cursor = Cursor::new(&mut buffer); + + let spec = WavSpec { + channels, + sample_rate, + bits_per_sample: 16, + sample_format: hound::SampleFormat::Int, + }; + + let mut writer = + WavWriter::new(cursor, spec).map_err(|e| format!("Failed to create WAV writer: {}", e))?; + + // Convert f32 samples to i16 + for sample in samples { + let clamped = sample.clamp(-1.0, 1.0); + let i16_sample = (clamped * 32767.0) as i16; + writer + .write_sample(i16_sample) + .map_err(|e| format!("Failed to write sample: {}", e))?; + } + + writer + .finalize() + .map_err(|e| format!("Failed to finalize WAV: {}", e))?; + + Ok(buffer) } diff --git a/tauri/src-tauri/src/main.rs b/tauri/src-tauri/src/main.rs index 255655a..a706f25 100644 --- a/tauri/src-tauri/src/main.rs +++ b/tauri/src-tauri/src/main.rs @@ -675,7 +675,7 @@ pub fn run() { }); // Wait for frontend response or timeout - tokio::spawn(async move { + tauri::async_runtime::spawn(async move { tokio::select!
{ _ = rx.recv() => { // Frontend responded, close window diff --git a/tauri/src-tauri/tauri.conf.json b/tauri/src-tauri/tauri.conf.json index 53b95d1..278a9a5 100644 --- a/tauri/src-tauri/tauri.conf.json +++ b/tauri/src-tauri/tauri.conf.json @@ -13,7 +13,9 @@ "active": true, "targets": "all", "createUpdaterArtifacts": true, - "externalBin": ["binaries/voicebox-server"], + "externalBin": [ + "binaries/voicebox-server" + ], "icon": [ "icons/32x32.png", "icons/128x128.png", @@ -27,16 +29,14 @@ "infoPlist": "Info.plist", "entitlements": "Entitlements.plist" }, - "resources": { - "gen/Assets.car": "./", - "gen/voicebox.icns": "./", - "gen/partial.plist": "./" - } + "resources": {} }, "app": { "security": { "csp": null, - "capabilities": ["default"] + "capabilities": [ + "default" + ] }, "windows": [ { @@ -60,7 +60,9 @@ }, "updater": { "pubkey": "dW50cnVzdGVkIGNvbW1lbnQ6IG1pbmlzaWduIHB1YmxpYyBrZXk6IEUxRENBQkRBQjdBNTM1OTIKUldTU05hVzMycXZjNGJGcUxmcVVocll2QjdSaTJNdlFxR2M3VDJsMnVvbDdyZGRPMmRlOW9aWTcK", - "endpoints": ["https://github.com/jamiepine/voicebox/releases/latest/download/latest.json"] + "endpoints": [ + "https://github.com/jamiepine/voicebox/releases/latest/download/latest.json" + ] } } -} +} \ No newline at end of file