Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-compatibility.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v4
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
urls = { repository = "https://github.com/m-bain/whisperx" }
authors = [{ name = "Max Bain" }]
name = "whisperx"
version = "3.5.0"
version = "3.7.0"
description = "Time-Accurate Automatic Speech Recognition using Whisper."
readme = "README.md"
requires-python = ">=3.9, <3.13"
requires-python = ">=3.9, <3.14"
license = { text = "BSD-2-Clause" }

dependencies = [
Expand All @@ -14,7 +14,7 @@ dependencies = [
"nltk>=3.9.1",
# Restrict numpy, onnxruntime, pandas, av to be compatible with Python 3.9
"numpy>=2.0.2,<2.1.0",
"onnxruntime>=1.19,<1.20.0",
"onnxruntime>=1.19,<1.20.0; python_version <'3.10'",
"pandas>=2.2.3,<2.3.0",
"av<16.0.0",
"pyannote-audio>=3.3.2,<4.0.0",
Expand Down
777 changes: 721 additions & 56 deletions uv.lock

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions whisperx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,29 @@ def load_audio(*args, **kwargs):
def assign_word_speakers(*args, **kwargs):
diarize = _lazy_import("diarize")
return diarize.assign_word_speakers(*args, **kwargs)


def setup_logging(*args, **kwargs):
    """
    Configure logging for WhisperX.

    Resolves the ``log_utils`` module through ``_lazy_import`` and forwards
    every positional and keyword argument to its ``setup_logging`` unchanged.

    Args:
        level: Logging level (debug, info, warning, error, critical). Default: info
        log_file: Optional path to log file. If None, logs only to console.

    Returns:
        Whatever ``log_utils.setup_logging`` returns.
    """
    logging_module = _lazy_import("log_utils")
    return logging_module.setup_logging(*args, **kwargs)


def get_logger(*args, **kwargs):
    """
    Return the logger for the given module name.

    Resolves the ``log_utils`` module through ``_lazy_import`` and delegates
    to its ``get_logger``, passing all arguments through untouched.

    Args:
        name: Logger name (typically __name__ from calling module)

    Returns:
        The value returned by ``log_utils.get_logger`` for these arguments
    """
    return _lazy_import("log_utils").get_logger(*args, **kwargs)
12 changes: 12 additions & 0 deletions whisperx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float,
optional_int, str2bool)
from whisperx.log_utils import setup_logging


def cli():
Expand All @@ -23,6 +24,7 @@ def cli():
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--log-level", type=str, default=None, choices=["debug", "info", "warning", "error", "critical"], help="logging level (overrides --verbose if set)")

parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
Expand Down Expand Up @@ -80,6 +82,16 @@ def cli():

args = parser.parse_args().__dict__

log_level = args.get("log_level")
verbose = args.get("verbose")

if log_level is not None:
setup_logging(level=log_level)
elif verbose:
setup_logging(level="info")
else:
setup_logging(level="warning")

from whisperx.transcribe import transcribe_task

transcribe_task(args, parser)
Expand Down
14 changes: 9 additions & 5 deletions whisperx/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
)
import nltk
from nltk.data import load as nltk_load
from whisperx.log_utils import get_logger

logger = get_logger(__name__)

LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

Expand Down Expand Up @@ -81,8 +84,9 @@ def load_align_model(language_code: str, device: str, model_name: Optional[str]
elif language_code in DEFAULT_ALIGN_MODELS_HF:
model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
else:
print(f"There is no default alignment model set for this language ({language_code}).\
Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
logger.error(f"No default alignment model for language: {language_code}. "
f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, "
f"then pass the model name via --align_model [MODEL_NAME]")
raise ValueError(f"No default align-model for language: {language_code}")

if model_name in torchaudio.pipelines.__all__:
Expand Down Expand Up @@ -223,12 +227,12 @@ def align(

# check we can align
if len(segment_data[sdx]["clean_char"]) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
logger.warning(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original')
aligned_segments.append(aligned_seg)
continue

if t1 >= MAX_DURATION:
print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
logger.warning(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping')
aligned_segments.append(aligned_seg)
continue

Expand Down Expand Up @@ -270,7 +274,7 @@ def align(
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)

if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
logger.warning(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original')
aligned_segments.append(aligned_seg)
continue

Expand Down
11 changes: 7 additions & 4 deletions whisperx/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from whisperx.schema import SingleSegment, TranscriptionResult
from whisperx.vads import Vad, Silero, Pyannote
from whisperx.log_utils import get_logger

logger = get_logger(__name__)


def find_numeral_symbol_tokens(tokenizer):
Expand Down Expand Up @@ -247,7 +250,7 @@ def data(audio, segments):
if self.suppress_numerals:
previous_suppress_tokens = self.options.suppress_tokens
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
print(f"Suppressing numeral and symbol tokens")
logger.info("Suppressing numeral and symbol tokens")
new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
new_suppressed_tokens = list(set(new_suppressed_tokens))
self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
Expand Down Expand Up @@ -285,7 +288,7 @@ def data(audio, segments):

def detect_language(self, audio: np.ndarray) -> str:
if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
logger.warning("Audio is shorter than 30s, language detection may be inaccurate")
model_n_mels = self.model.feat_kwargs.get("feature_size")
segment = log_mel_spectrogram(audio[: N_SAMPLES],
n_mels=model_n_mels if model_n_mels is not None else 80,
Expand All @@ -294,7 +297,7 @@ def detect_language(self, audio: np.ndarray) -> str:
results = self.model.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
logger.info(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio")
return language


Expand Down Expand Up @@ -344,7 +347,7 @@ def load_model(
if language is not None:
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
else:
print("No language specified, language will be first be detected for each audio file (increases inference time).")
logger.info("No language specified, language will be detected for each audio file (increases inference time)")
tokenizer = None

default_asr_options = {
Expand Down
4 changes: 4 additions & 0 deletions whisperx/diarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from whisperx.audio import load_audio, SAMPLE_RATE
from whisperx.schema import TranscriptionResult, AlignedTranscriptionResult
from whisperx.log_utils import get_logger

logger = get_logger(__name__)


class DiarizationPipeline:
Expand All @@ -18,6 +21,7 @@ def __init__(
if isinstance(device, str):
device = torch.device(device)
model_config = model_name or "pyannote/speaker-diarization-3.1"
logger.info(f"Loading diarization model: {model_config}")
self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)

def __call__(
Expand Down
67 changes: 67 additions & 0 deletions whisperx/log_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import logging
import sys
from typing import Optional, Union

_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"


def _resolve_log_level(level: Union[str, int]) -> int:
    """Translate a level name or numeric level into a ``logging`` level (WARNING if unknown)."""
    # bool is a subclass of int; True/False are never meaningful levels.
    if isinstance(level, int) and not isinstance(level, bool):
        return level
    if isinstance(level, str):
        candidate = getattr(logging, level.strip().upper(), None)
        # Only genuine numeric level constants qualify. Other uppercase module
        # attributes (e.g. BASIC_FORMAT) are strings and would make
        # Logger.setLevel raise ValueError.
        if isinstance(candidate, int) and not isinstance(candidate, bool):
            return candidate
    return logging.WARNING


def setup_logging(
    level: Union[str, int] = "info",
    log_file: Optional[str] = None,
) -> None:
    """
    Configure logging for WhisperX.

    Replaces every handler on the "whisperx" logger with a fresh stdout
    handler and, optionally, a file handler. Safe to call repeatedly:
    previously installed handlers are closed before being discarded.

    Args:
        level: Logging level, either a name (debug, info, warning, error,
            critical; case-insensitive) or a numeric ``logging`` level.
            Unrecognised values fall back to WARNING. Default: info
        log_file: Optional path to log file. If None, logs only to console.
            If the file cannot be opened, a warning is logged and logging
            continues on the console only.
    """
    logger = logging.getLogger("whisperx")

    # Close handlers before dropping them so that a FileHandler installed by
    # an earlier call does not leak its open file descriptor.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Don't propagate to root logger to avoid duplicate messages. Set this
    # before anything is logged below so the failure warnings are covered too.
    logger.propagate = False

    log_level = _resolve_log_level(level)
    logger.setLevel(log_level)

    formatter = logging.Formatter(_LOG_FORMAT, datefmt=_DATE_FORMAT)

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(log_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if log_file:
        try:
            # Explicit UTF-8: log messages can embed transcript text, which a
            # locale-dependent default encoding may be unable to represent.
            file_handler = logging.FileHandler(log_file, encoding="utf-8")
        except OSError as e:
            logger.warning("Failed to create log file '%s': %s", log_file, e)
            logger.warning("Continuing with console logging only")
        else:
            file_handler.setLevel(log_level)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)


def get_logger(name: str) -> logging.Logger:
    """
    Return the logger registered under ``name``.

    If the "whisperx" logger has no handlers yet, ``setup_logging()`` is
    called with its defaults first, so logging works even when the caller
    never configured it explicitly.

    Args:
        name: Logger name (typically __name__ from calling module). The
            special name "__main__" is mapped to the "whisperx" logger.

    Returns:
        The ``logging.Logger`` for ``name`` ("whisperx" when ``name`` is
        "__main__"). NOTE(review): only names inside the "whisperx" logger
        hierarchy inherit the handlers installed on "whisperx" - confirm
        whether other names are expected here.
    """
    if not logging.getLogger("whisperx").handlers:
        setup_logging()

    if name == "__main__":
        return logging.getLogger("whisperx")
    return logging.getLogger(name)
17 changes: 10 additions & 7 deletions whisperx/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
from whisperx.schema import AlignedTranscriptionResult, TranscriptionResult
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
from whisperx.log_utils import get_logger

logger = get_logger(__name__)


def transcribe_task(args: dict, parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -142,7 +145,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
for audio_path in args.pop("audio"):
audio = load_audio(audio_path)
# >> VAD & ASR
print(">>Performing transcription...")
logger.info("Performing transcription...")
result: TranscriptionResult = model.transcribe(
audio,
batch_size=batch_size,
Expand Down Expand Up @@ -175,13 +178,13 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
if align_model is not None and len(result["segments"]) > 0:
if result.get("language", "en") != align_metadata["language"]:
# load new language
print(
logger.info(
f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
)
align_model, align_metadata = load_align_model(
result["language"], device
)
print(">>Performing alignment...")
logger.info("Performing alignment...")
result: AlignedTranscriptionResult = align(
result["segments"],
align_model,
Expand All @@ -203,12 +206,12 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
# >> Diarize
if diarize:
if hf_token is None:
print(
"Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
logger.warning(
"No --hf_token provided, needs to be saved in environment variable, otherwise will throw error loading diarization model"
)
tmp_results = results
print(">>Performing diarization...")
print(">>Using model:", diarize_model_name)
logger.info("Performing diarization...")
logger.info(f"Using model: {diarize_model_name}")
results = []
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
Expand Down
7 changes: 5 additions & 2 deletions whisperx/vads/pyannote.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad
from whisperx.log_utils import get_logger

logger = get_logger(__name__)


def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
Expand Down Expand Up @@ -232,7 +235,7 @@ def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
class Pyannote(Vad):

def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
print(">>Performing voice activity detection using Pyannote...")
logger.info("Performing voice activity detection using Pyannote...")
super().__init__(kwargs['vad_onset'])
self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)

Expand All @@ -257,7 +260,7 @@ def merge_chunks(segments,
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))

if len(segments_list) == 0:
print("No active speech found in audio")
logger.warning("No active speech found in audio")
return []
assert segments_list, "segments_list is empty."
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
7 changes: 5 additions & 2 deletions whisperx/vads/silero.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@

from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad
from whisperx.log_utils import get_logger

logger = get_logger(__name__)

AudioFile = Union[Text, Path, IOBase, Mapping]


class Silero(Vad):
# check again default values
def __init__(self, **kwargs):
print(">>Performing voice activity detection using Silero...")
logger.info("Performing voice activity detection using Silero...")
super().__init__(kwargs['vad_onset'])

self.vad_onset = kwargs['vad_onset']
Expand Down Expand Up @@ -60,7 +63,7 @@ def merge_chunks(segments_list,
):
assert chunk_size > 0
if len(segments_list) == 0:
print("No active speech found in audio")
logger.warning("No active speech found in audio")
return []
assert segments_list, "segments_list is empty."
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)