diff --git a/whisperx/__main__.py b/whisperx/__main__.py index 5102bc0a..e7cc62f0 100644 --- a/whisperx/__main__.py +++ b/whisperx/__main__.py @@ -58,6 +58,7 @@ def cli(): parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly") parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") + parser.add_argument("--hotwords", type=str, default=None, help="hotwords/hint phrases to the model (e.g. \"WhisperX, PyAnnote, GPU\"); improves recognition of rare/technical terms") parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 11110c64..04c2ab36 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -106,6 +106,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): "no_speech_threshold": args.pop("no_speech_threshold"), "condition_on_previous_text": False, "initial_prompt": args.pop("initial_prompt"), + "hotwords": args.pop("hotwords"), "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")], "suppress_numerals": args.pop("suppress_numerals"), } diff --git a/whisperx/utils.py b/whisperx/utils.py index dfe3cf2e..ada0deb9 100644 --- a/whisperx/utils.py +++ b/whisperx/utils.py @@ -410,7 +410,7 @@ def write_result(self, result: dict, file: TextIO, options: dict): def get_writer( output_format: str, output_dir: str -) -> Callable[[dict, TextIO, dict], None]: +) -> Callable[[dict, str, dict], None]: writers = { "txt": WriteTXT, "vtt": WriteVTT, @@ -425,7 +425,7 @@ def get_writer( if output_format == "all": all_writers = [writer(output_dir) for writer in writers.values()] - def write_all(result: dict, file: TextIO, options: dict): + def write_all(result: dict, file: str, options: dict): for writer in all_writers: writer(result, file, options)