From d53ab71c6cd4a5f289ccfb9898b78a58e5dfb639 Mon Sep 17 00:00:00 2001
From: Ning Dong
Date: Sun, 27 Oct 2019 16:46:05 -0700
Subject: [PATCH] Flow side support for separate decoders / Fix max_iter at inference

Summary:
1. Add flow side support for separate decoder configs
2. Fix #iterations at inference. Will add adaptive decoding support in a separate diff.

Reviewed By: kahne

Differential Revision: D18165098

fbshipit-source-id: a3692b864aad3e87ab0b7edaf7548e291f7f434d
---
 pytorch_translate/ensemble_export.py | 31 ++++++++++------------------
 1 file changed, 11 insertions(+), 20 deletions(-)

diff --git a/pytorch_translate/ensemble_export.py b/pytorch_translate/ensemble_export.py
index 5b0de249..55e36afb 100644
--- a/pytorch_translate/ensemble_export.py
+++ b/pytorch_translate/ensemble_export.py
@@ -2073,7 +2073,7 @@ def finalize_hypos_loop_attns(
 
 class IterativeRefinementGenerateAndDecode(torch.jit.ScriptModule):
     def __init__(
-        self, checkpoint_files, src_dict_filename, tgt_dict_filename, max_iter=2
+        self, checkpoint_files, src_dict_filename, tgt_dict_filename, max_iter=1
     ):
         super().__init__()
         self.models, _, tgt_dict = load_models_from_checkpoints(
@@ -2209,21 +2209,16 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):
                 decoding_format=self.decoding_format,
             )
 
-            if self.adaptive:
-                # terminate if looping.
-                terminated, output_tokens, output_scores, output_attn = is_a_loop(
-                    self.pad,
-                    prev_output_tokens,
-                    decoder_out[0],
-                    decoder_out[1],
-                    decoder_out[2],
-                )
-                decoder_out[0] = output_tokens
-                decoder_out[1] = output_scores
-                decoder_out[2] = output_attn
-
-            else:
-                terminated = torch.zeros_like(decoder_out[0]).bool()
+            terminated, output_tokens, output_scores, output_attn = is_a_loop(
+                self.pad,
+                prev_output_tokens,
+                decoder_out[0],
+                decoder_out[1],
+                decoder_out[2],
+            )
+            decoder_out[0] = output_tokens
+            decoder_out[1] = output_scores
+            decoder_out[2] = output_attn
             terminated = last_step(step, self.max_iter, terminated)
 
             # collect finalized sentences
@@ -2257,10 +2252,6 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):
                 finalized_attn,
             )
 
-            # check if all terminated
-            if terminated.sum() == terminated.size(0):
-                break
-
             # for next step
             prev_decoder_out = [
                 script_skip_tensor(decoder_out[0], ~terminated),