diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 756d0ba7..81c47566 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -14,7 +14,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor from whisperx.audio import SAMPLE_RATE, load_audio -from whisperx.utils import interpolate_nans +from whisperx.utils import interpolate_nans, PUNKT_LANGUAGES from whisperx.schema import ( AlignedTranscriptionResult, SingleSegment, @@ -192,11 +192,13 @@ def align( clean_wdx.append(wdx) + # Use the language-specific Punkt model if available; otherwise fall back to English. + punkt_lang = PUNKT_LANGUAGES.get(model_lang, 'english') try: - sentence_splitter = nltk_load('tokenizers/punkt/english.pickle') + sentence_splitter = nltk_load(f'tokenizers/punkt_tab/{punkt_lang}.pickle') except LookupError: nltk.download('punkt_tab', quiet=True) - sentence_splitter = nltk_load('tokenizers/punkt/english.pickle') + sentence_splitter = nltk_load(f'tokenizers/punkt_tab/{punkt_lang}.pickle') sentence_spans = list(sentence_splitter.span_tokenize(text)) segment_data[sdx] = { diff --git a/whisperx/utils.py b/whisperx/utils.py index ada0deb9..8c997ce9 100644 --- a/whisperx/utils.py +++ b/whisperx/utils.py @@ -126,6 +126,29 @@ LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] +# Mapping of language codes to NLTK Punkt tokenizer model names +PUNKT_LANGUAGES = { + 'cs': 'czech', + 'da': 'danish', + 'de': 'german', + 'el': 'greek', + 'en': 'english', + 'es': 'spanish', + 'et': 'estonian', + 'fi': 'finnish', + 'fr': 'french', + 'it': 'italian', + 'nl': 'dutch', + 'no': 'norwegian', + 'pl': 'polish', + 'pt': 'portuguese', + 'sl': 'slovene', + 'sv': 'swedish', + 'tr': 'turkish', + "ml": "malayalam", + "ru": "russian", +} + system_encoding = sys.getdefaultencoding() if system_encoding != "utf-8":