Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions whisperx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def cli():
parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")

parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--hotwords", type=str, default=None, help="hotwords/hint phrases to the model (e.g. \"WhisperX, PyAnnote, GPU\"); improves recognition of rare/technical terms")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

Expand Down
1 change: 1 addition & 0 deletions whisperx/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
"no_speech_threshold": args.pop("no_speech_threshold"),
"condition_on_previous_text": False,
"initial_prompt": args.pop("initial_prompt"),
"hotwords": args.pop("hotwords"),
"suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
"suppress_numerals": args.pop("suppress_numerals"),
}
Expand Down
4 changes: 2 additions & 2 deletions whisperx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def write_result(self, result: dict, file: TextIO, options: dict):

def get_writer(
output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
) -> Callable[[dict, str, dict], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
Expand All @@ -425,7 +425,7 @@ def get_writer(
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]

def write_all(result: dict, file: TextIO, options: dict):
def write_all(result: dict, file: str, options: dict):
for writer in all_writers:
writer(result, file, options)

Expand Down