Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cli/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def log(
):
metrics_ = [
('{}', prefix),
('wer {:.2f}', metrics.avg('wer')),
('cer {:.2f}', metrics.avg('cer')),
('bleu {:.2f}', metrics.avg('bleu')),
('spbleu {:.2f}', metrics.avg('spbleu')),
('len_ratio {:.3f}', metrics.avg('len_ratio')),
Expand Down
2 changes: 2 additions & 0 deletions cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,8 @@ def log(
('ppl {:.2f}', ppl),
('lines {:.4g}', metrics.sum('num_lines')),
('tokens {:.4g}', metrics.sum('num_tokens')),
('wer {:.2f}', metrics.val('wer')),
('cer {:.2f}', metrics.val('cer')),
('bleu {:.2f}', metrics.val('bleu')),
('spbleu {:.2f}', metrics.val('spbleu')),
('len_ratio {:.3f}', metrics.val('len_ratio')),
Expand Down
1 change: 1 addition & 0 deletions pasero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ class EvalConfig(Config):
defaults={
'language_modeling': [],
'translation': ['chrf', 'bleu', 'chrf++', 'spbleu', 'len_ratio'],
'speech_translation': ['chrf', 'bleu', 'chrf++', 'spbleu', 'len_ratio', 'wer', 'cer'],
},
help='evaluation metrics to compute'
)
Expand Down
37 changes: 30 additions & 7 deletions pasero/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,40 @@
import functools
import itertools
import random
import jiwer
from typing import Iterable, Iterator, Sequence, Optional

# Metric names known to this module. 'wer' and 'cer' are error rates, so lower
# values are better for them.
METRICS = ['chrf', 'bleu', 'langid', 'len_ratio', 'chrf++', 'spbleu', 'wer', 'cer']
# The `TOKENIZERS` attribute of the BLEU entry in sacreBLEU's metric table
# (presumably the tokenizer names BLEU accepts -- confirm against sacreBLEU).
BLEU_TOKENIZERS = sacrebleu.metrics.METRICS['BLEU'].TOKENIZERS

# Normalisation applied to references and hypotheses before computing WER:
# lowercase, drop punctuation, tidy whitespace, then split into words.
#
# Punctuation is removed *before* whitespace is collapsed and stripped, because
# deleting a punctuation mark can leave a double space ("a , b" -> "a  b") or a
# trailing space ("a ." -> "a ") behind.
JIWER_TRANSFORMATIONS = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemovePunctuation(),
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
])

# Same normalisation for CER, ending with a split into characters. The ordering
# matters here: whitespace left over after punctuation removal would otherwise
# be scored as characters, so "a , b" and "a, b" would get different CER.
JICER_TRANSFORMATIONS = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemovePunctuation(),
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
    jiwer.ReduceToListOfListOfChars(),
])

def lower_is_better(metric: str) -> bool:
    """Tell whether smaller values of `metric` mean a better result.

    Losses, perplexity and the WER/CER error rates are minimised; every other
    metric is assumed to be maximised. Matching is by substring, so any name
    that merely contains one of the keys (e.g. 'nll_loss') is covered too.

    Args:
        metric: name of the metric.

    Returns:
        True if the metric should be minimised, False otherwise.
    """
    return any(key in metric for key in ('loss', 'ppl', 'wer', 'cer'))

def jiwer_scores(
    hyps: list[str],
    refs: list[str],
    metric: str = 'wer',
) -> float:
    """Compute a corpus-level error rate with jiwer, as a percentage.

    Both sides are normalised with the module-level jiwer transformations
    (`JIWER_TRANSFORMATIONS` for WER, `JICER_TRANSFORMATIONS` for CER).

    Args:
        hyps: system outputs, one string per line.
        refs: reference strings, aligned with `hyps`.
        metric: 'wer' (word error rate) or 'cer' (character error rate).

    Returns:
        The error rate multiplied by 100.

    Raises:
        NotImplementedError: if `metric` is neither 'wer' nor 'cer'.
        ValueError: if no pair with a non-empty reference is left to score.
    """
    # Validate the metric name first so that a typo fails fast, before any work.
    if metric not in ('wer', 'cer'):
        raise NotImplementedError(f"metric {metric} not implemented in jiwer_scores")

    # jiwer versions that accept `truth_transform` raise on empty references, so
    # pairs whose reference is blank are skipped. Empty hypotheses are kept on
    # purpose: they are real errors and must count against the system.
    # NOTE(review): a reference made only of punctuation still becomes empty
    # after normalisation -- confirm whether such lines occur in the data.
    pairs = [(hyp, ref) for hyp, ref in zip(hyps, refs) if ref.strip()]
    if not pairs:
        raise ValueError(f"cannot compute {metric}: no non-empty reference")
    kept_hyps = [hyp for hyp, _ in pairs]
    kept_refs = [ref for _, ref in pairs]

    if metric == 'wer':
        return 100 * jiwer.wer(
            kept_refs,
            kept_hyps,
            truth_transform=JIWER_TRANSFORMATIONS,
            hypothesis_transform=JIWER_TRANSFORMATIONS,
        )
    return 100 * jiwer.cer(
        kept_refs,
        kept_hyps,
        truth_transform=JICER_TRANSFORMATIONS,
        hypothesis_transform=JICER_TRANSFORMATIONS,
    )

def langid_py(line: str) -> str:
import langid
Expand Down Expand Up @@ -275,10 +300,8 @@ def score(
chrf_word_order = 2 if metric == 'chrf++' else 0
bleu_tok = 'flores200' if metric == 'spbleu' else bleu_tok

if metric == 'wer':
from jiwer import wer
hyps, refs = zip(*[(hyp, ref) for hyp, ref in zip(hyps, refs) if hyp and ref])
score = 100 * wer(list(refs), list(hyps))
if metric in ('wer', 'cer'):
score = jiwer_scores(hyps, refs, metric)
elif metric in ('bleu', 'spbleu'):
metric_ = sacrebleu.metrics.BLEU(tokenize=bleu_tok, force=True)
out = metric_.corpus_score(hyps, [refs])
Expand Down
1 change: 1 addition & 0 deletions pasero/models/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ def forward(

batch_size, tgt_len, embed_dim = query.size()
src_len = key.size(1)
return_attn = True

q = self.q_proj(query) # BxTxD
q = q.view(batch_size, tgt_len, self.num_heads, self.head_dim) # BxTxHxD'
Expand Down
2 changes: 1 addition & 1 deletion scripts/plot-logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@
})


# Loss, perplexity, WER and CER improve as they decrease; every other metric is
# treated as "higher is better". Matching is by substring of the metric name.
lower_is_better = 'loss' in args.metric or 'ppl' in args.metric or 'wer' in args.metric or 'cer' in args.metric

if args.sort:
plots.sort(key=lambda d: (min if lower_is_better else max)(d['y']),
Expand Down