From 87279903fd6bb83963af321c3be1a91204609072 Mon Sep 17 00:00:00 2001
From: Vincent Prins
Date: Fri, 27 Oct 2023 11:26:54 +0200
Subject: [PATCH] Fixed AttentionalDecoder truncating to 20 chars

Instead, take the length of the longest known token
---
 pie/models/decoder.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pie/models/decoder.py b/pie/models/decoder.py
index 55b6cb1..a08c359 100644
--- a/pie/models/decoder.py
+++ b/pie/models/decoder.py
@@ -315,8 +315,7 @@ def loss(self, logits, targets):
         return loss
 
     def predict_max(self, enc_outs, lengths,
-                    max_seq_len=20, bos=None, eos=None,
-                    context=None):
+                    bos=None, eos=None, context=None):
         """
         Decoding routine for inference with step-wise argmax procedure
 
@@ -331,6 +330,8 @@ def predict_max(self, enc_outs, lengths,
         mask = torch.ones(batch, dtype=torch.int64, device=device)
         inp = torch.zeros(batch, dtype=torch.int64, device=device) + bos
         hyps, scores = [], 0
+        # Take the longest known token as the maximum sequence length.
+        max_seq_len = len(max(self.label_encoder.known_tokens, key=len))
 
         for _ in range(max_seq_len):
             if mask.sum().item() == 0:
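
For reference, a minimal runnable sketch of the bound the new line computes. The
LabelEncoderStub class and its token set are hypothetical stand-ins for pie's
label_encoder.known_tokens; only the len(max(..., key=len)) idiom mirrors the patch:

    # Hypothetical stand-in for pie's label encoder; only the
    # known_tokens attribute matters for this sketch.
    class LabelEncoderStub:
        def __init__(self, known_tokens):
            self.known_tokens = known_tokens

    label_encoder = LabelEncoderStub({"cat", "horses", "elephants"})

    # max(..., key=len) picks the longest known token; its len() becomes
    # the decoding bound, replacing the old hard-coded max_seq_len=20.
    longest = max(label_encoder.known_tokens, key=len)
    max_seq_len = len(longest)
    print(longest, max_seq_len)  # elephants 9

Deriving the bound from the vocabulary means predictions are no longer cut off at
an arbitrary 20 characters; the trade-off is that one very long known token
lengthens the decoding loop for every batch, and max() raises ValueError if
known_tokens is empty.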