diff --git a/trax/supervised/decoding.py b/trax/supervised/decoding.py index de54e4588..8f57a93a1 100644 --- a/trax/supervised/decoding.py +++ b/trax/supervised/decoding.py @@ -154,26 +154,26 @@ def autoregressive_sample(model, inputs=None, a batch of output sequences. output_length is the maximum length of the output sequences, where each sequence can be no longer than `max_length`. """ - result = [] - eos_seen = [] - counter = 0 - for sample in autoregressive_sample_stream( - model, inputs, batch_size=batch_size, temperature=temperature, - start_id=start_id, accelerate=accelerate, eval_mode=eval_mode, - eval_min_length=eval_min_length): - sample = sample[:, None] - result.append(sample) - counter += 1 - if counter >= max_length: - return np.concatenate(result, axis=1) - # Check at which batch positions have we already encountered EOS. - for j in range(batch_size): - if int(sample[j, 0]) == eos_id: - eos_seen.append(j) - # If EOS has been seen on all positions, stop. - if all([j in eos_seen for j in range(batch_size)]): - return np.concatenate(result, axis=1) - return np.concatenate(result, axis=1) + saved_state = model.state + try: + for sample in autoregressive_sample_stream( + model, inputs, batch_size=batch_size, temperature=temperature, + start_id=start_id, accelerate=accelerate): + sample = sample[:, None] + result.append(sample) + counter += 1 + if counter >= max_length: + return np.concatenate(result, axis=1) + # Check at which batch positions have we already encountered EOS. + for j in range(batch_size): + if int(sample[j, 0]) == eos_id: + eos_seen.append(j) + # If EOS has been seen on all positions, stop. + if all([j in eos_seen for j in range(batch_size)]): + return np.concatenate(result, axis=1) + return np.concatenate(result, axis=1) + finally: + model.state = saved_state def beam_search(model, inputs=None, batch_size=1, n_beams=2, start_id=0,