diff --git a/pie/models/decoder.py b/pie/models/decoder.py index 55b6cb1..a08c359 100644 --- a/pie/models/decoder.py +++ b/pie/models/decoder.py @@ -315,8 +315,7 @@ def loss(self, logits, targets): return loss def predict_max(self, enc_outs, lengths, - max_seq_len=20, bos=None, eos=None, - context=None): + bos=None, eos=None, context=None): """ Decoding routine for inference with step-wise argmax procedure @@ -331,6 +330,8 @@ def predict_max(self, enc_outs, lengths, mask = torch.ones(batch, dtype=torch.int64, device=device) inp = torch.zeros(batch, dtype=torch.int64, device=device) + bos hyps, scores = [], 0 + # Take the longest known token as the maximum sequence length. + max_seq_len = len(max(self.label_encoder.known_tokens, key=len)) for _ in range(max_seq_len): if mask.sum().item() == 0: