# Add WER and CER (computed with jiwer) as evaluation metrics.
#
# - cli/decode.py, cli/train.py: log 'wer' and 'cer' next to the existing metrics.
# - pasero/config.py: default metric list for the 'speech_translation' task.
# - pasero/evaluation.py: new 'cer' metric, shared jiwer_scores() helper, WER/CER treated as lower-is-better.
# - scripts/plot-logs.py: WER/CER treated as lower-is-better when sorting plots.
#
# Review changes relative to the first version of this patch:
# - jiwer_scores() skips pairs with an empty hypothesis or reference again, as the inline WER code it replaces did.
# - punctuation is removed before whitespace normalization, so leftover spaces do not count as characters in CER.
# - the unconditional 'return_attn = True' in pasero/models/modules.py is flagged for confirmation.
#
# NOTE(review): context-line indentation and blank lines were reconstructed from a whitespace-collapsed copy of
# the patch, and the post-image blob hashes on the 'index' lines are those of the first version; regenerate the
# patch against the tree before applying it.
diff --git a/cli/decode.py b/cli/decode.py
index 9bc4439..a9749b4 100755
--- a/cli/decode.py
+++ b/cli/decode.py
@@ -22,6 +22,8 @@ def log(
 ):
     metrics_ = [
         ('{}', prefix),
+        ('wer {:.2f}', metrics.avg('wer')),
+        ('cer {:.2f}', metrics.avg('cer')),
         ('bleu {:.2f}', metrics.avg('bleu')),
         ('spbleu {:.2f}', metrics.avg('spbleu')),
         ('len_ratio {:.3f}', metrics.avg('len_ratio')),
diff --git a/cli/train.py b/cli/train.py
index fe1dff9..a8d349a 100755
--- a/cli/train.py
+++ b/cli/train.py
@@ -561,6 +561,8 @@ def log(
         ('ppl {:.2f}', ppl),
         ('lines {:.4g}', metrics.sum('num_lines')),
         ('tokens {:.4g}', metrics.sum('num_tokens')),
+        ('wer {:.2f}', metrics.val('wer')),
+        ('cer {:.2f}', metrics.val('cer')),
         ('bleu {:.2f}', metrics.val('bleu')),
         ('spbleu {:.2f}', metrics.val('spbleu')),
         ('len_ratio {:.3f}', metrics.val('len_ratio')),
diff --git a/pasero/config.py b/pasero/config.py
index a71e0d5..6c26b6d 100644
--- a/pasero/config.py
+++ b/pasero/config.py
@@ -491,6 +491,7 @@ class EvalConfig(Config):
         defaults={
             'language_modeling': [],
             'translation': ['chrf', 'bleu', 'chrf++', 'spbleu', 'len_ratio'],
+            'speech_translation': ['chrf', 'bleu', 'chrf++', 'spbleu', 'len_ratio', 'wer', 'cer'],
         },
         help='evaluation metrics to compute'
     )
diff --git a/pasero/evaluation.py b/pasero/evaluation.py
index 8bc10c1..f501bfa 100755
--- a/pasero/evaluation.py
+++ b/pasero/evaluation.py
@@ -12,15 +12,51 @@
 import functools
 import itertools
 import random
+import jiwer
 from typing import Iterable, Iterator, Sequence, Optional
 
-METRICS = ['chrf', 'bleu', 'langid', 'len_ratio', 'chrf++', 'spbleu', 'wer']
+METRICS = ['chrf', 'bleu', 'langid', 'len_ratio', 'chrf++', 'spbleu', 'wer', 'cer']
 BLEU_TOKENIZERS = sacrebleu.metrics.METRICS['BLEU'].TOKENIZERS
-
+# Text normalization applied to both references and hypotheses before scoring. Punctuation is removed *before*
+# whitespace is collapsed and stripped, so that the spaces it leaves behind are not counted as characters in CER.
+JIWER_TRANSFORMATIONS = jiwer.Compose([
+    jiwer.ToLowerCase(),
+    jiwer.RemovePunctuation(),
+    jiwer.RemoveMultipleSpaces(),
+    jiwer.Strip(),
+    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
+])
+JICER_TRANSFORMATIONS = jiwer.Compose([
+    jiwer.ToLowerCase(),
+    jiwer.RemovePunctuation(),
+    jiwer.RemoveMultipleSpaces(),
+    jiwer.Strip(),
+    jiwer.ReduceToListOfListOfChars(),
+])
 
 def lower_is_better(metric: str) -> bool:
-    return 'loss' in metric or 'ppl' in metric
+    return 'loss' in metric or 'ppl' in metric or 'wer' in metric or 'cer' in metric
 
+def jiwer_scores(
+    hyps: list[str],
+    refs: list[str],
+    metric: str = 'wer',
+) -> float:
+    """
+    Compute the corpus-level word ('wer') or character ('cer') error rate with jiwer, as a percentage.
+
+    Pairs whose hypothesis or reference is empty are skipped, like the inline WER code this helper replaces
+    did (jiwer rejects empty reference strings). Raises ValueError if no pair is left to score.
+    """
+    if metric not in ('wer', 'cer'):
+        raise NotImplementedError(f"metric {metric} not implemented in jiwer_scores")
+    pairs = [(hyp, ref) for hyp, ref in zip(hyps, refs) if hyp and ref]
+    if not pairs:
+        raise ValueError(f"cannot compute {metric}: no pair with a non-empty hypothesis and reference")
+    hyps, refs = (list(side) for side in zip(*pairs))
+    if metric == 'wer':
+        return 100 * jiwer.wer(refs, hyps, truth_transform=JIWER_TRANSFORMATIONS, hypothesis_transform=JIWER_TRANSFORMATIONS)
+    return 100 * jiwer.cer(refs, hyps, truth_transform=JICER_TRANSFORMATIONS, hypothesis_transform=JICER_TRANSFORMATIONS)
 
 def langid_py(line: str) -> str:
     import langid
@@ -275,10 +311,8 @@ def score(
     chrf_word_order = 2 if metric == 'chrf++' else 0
     bleu_tok = 'flores200' if metric == 'spbleu' else bleu_tok
 
-    if metric == 'wer':
-        from jiwer import wer
-        hyps, refs = zip(*[(hyp, ref) for hyp, ref in zip(hyps, refs) if hyp and ref])
-        score = 100 * wer(list(refs), list(hyps))
+    if metric in ('wer', 'cer'):
+        score = jiwer_scores(hyps, refs, metric)
     elif metric in ('bleu', 'spbleu'):
         metric_ = sacrebleu.metrics.BLEU(tokenize=bleu_tok, force=True)
         out = metric_.corpus_score(hyps, [refs])
diff --git a/pasero/models/modules.py b/pasero/models/modules.py
index 194c4f0..3aab42e 100644
--- a/pasero/models/modules.py
+++ b/pasero/models/modules.py
@@ -596,6 +596,8 @@ def forward(
 
         batch_size, tgt_len, embed_dim = query.size()
         src_len = key.size(1)
+        # FIXME(review): unconditional override, unrelated to WER/CER; looks like a debugging leftover, confirm
+        return_attn = True
 
         q = self.q_proj(query)  # BxTxD
         q = q.view(batch_size, tgt_len, self.num_heads, self.head_dim)  # BxTxHxD'
diff --git a/scripts/plot-logs.py b/scripts/plot-logs.py
index 3e8e47b..01d792a 100755
--- a/scripts/plot-logs.py
+++ b/scripts/plot-logs.py
@@ -199,7 +199,7 @@
 })
 
 
-lower_is_better = 'loss' in args.metric or 'ppl' in args.metric
+lower_is_better = 'loss' in args.metric or 'ppl' in args.metric or 'wer' in args.metric or 'cer' in args.metric
 
 if args.sort:
     plots.sort(key=lambda d: (min if lower_is_better else max)(d['y']),