From fd31f2bb09e4892f8c5f94ba25ec54f3c39c5880 Mon Sep 17 00:00:00 2001
From: diana
Date: Mon, 18 Mar 2019 11:59:17 +0100
Subject: [PATCH 01/10] Add gumbel softmax - soft and hard

---
 machine/util/__init__.py |  1 +
 machine/util/gumbel.py   | 21 +++++++++++++++++++++
 2 files changed, 22 insertions(+)
 create mode 100644 machine/util/gumbel.py

diff --git a/machine/util/__init__.py b/machine/util/__init__.py
index 41f5fcc2..4d27d9bc 100644
--- a/machine/util/__init__.py
+++ b/machine/util/__init__.py
@@ -1,2 +1,3 @@
 from .log import Log
 from .checkpoint import Checkpoint
+from .gumbel import gumbel_softmax
diff --git a/machine/util/gumbel.py b/machine/util/gumbel.py
new file mode 100644
index 00000000..e38007f0
--- /dev/null
+++ b/machine/util/gumbel.py
@@ -0,0 +1,21 @@
+import torch
+from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical
+
+def gumbel_softmax(probs, tau, hard):
+    """ Computes sampling from the Gumbel Softmax (GS) distribution
+
+    Args:
+        probs (torch.tensor): probabilities of shape [batch_size, n_classes]
+        tau (float): temperature parameter for the GS
+        hard (bool): discretize if True
+    """
+
+    rohc = RelaxedOneHotCategorical(tau, probs)
+    y = rohc.rsample()
+
+    if hard:
+        y_hard = torch.zeros_like(y)
+        y_hard.scatter_(-1, torch.argmax(y, dim=-1, keepdim=True), 1.0)
+        y = (y_hard - y).detach() + y
+
+    return y
\ No newline at end of file

From bac0e11ed23e77ee9126d7a57d460ef659131515 Mon Sep 17 00:00:00 2001
From: diana
Date: Mon, 18 Mar 2019 11:59:36 +0100
Subject: [PATCH 02/10] Add UTs for Gumbel Softmax

---
 test/test_gumbel.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)
 create mode 100644 test/test_gumbel.py

diff --git a/test/test_gumbel.py b/test/test_gumbel.py
new file mode 100644
index 00000000..17ae2dc0
--- /dev/null
+++ b/test/test_gumbel.py
@@ -0,0 +1,30 @@
+import unittest
+import torch
+
+from machine.util.gumbel import gumbel_softmax
+
+
+class TestGumbelSoftmax(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(self):
+        self.probs = torch.tensor([
+            [0.3, 0.4, 0.1, 0.1, 0.1],
+            [0.1, 0.3, 0.2, 0.3, 0.1],
+            [0.6, 0.1, 0.1, 0.1, 0.1],
+            [0.2, 0.2, 0.2, 0.2, 0.2]])
+        self.tau = 1.2
+
+    def test_hard(self):
+        res = gumbel_softmax(self.probs, self.tau, hard=True)
+
+        self.assertEqual(res.shape, (4, 5))
+        for r in res:
+            self.assertAlmostEqual(r.sum().item(), 1.0)
+
+    def test_soft(self):
+        res = gumbel_softmax(self.probs, self.tau, hard=False)
+
+        self.assertEqual(res.shape, (4, 5))
+
+
\ No newline at end of file
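The hard branch of gumbel_softmax is a straight-through estimator: (y_hard - y).detach() + y evaluates to the one-hot y_hard in the forward pass, while the gradient flows through the relaxed sample y. A minimal sketch of the intended behaviour (tensor values are illustrative):

import torch
from machine.util.gumbel import gumbel_softmax

probs = torch.tensor([[0.1, 0.7, 0.2]], requires_grad=True)
y = gumbel_softmax(probs, tau=1.0, hard=True)

# Forward pass: y is exactly one-hot.
assert ((y == 0.0) | (y == 1.0)).all()

# Backward pass: gradients reach probs through the relaxed sample,
# even though argmax itself is not differentiable.
loss = (y * torch.arange(3.0)).sum()
loss.backward()
assert probs.grad is not None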
From 52169e9963d6a9147f0baaf583251ec40f8227d1 Mon Sep 17 00:00:00 2001
From: diana
Date: Mon, 18 Mar 2019 12:05:09 +0100
Subject: [PATCH 03/10] Export sender

---
 machine/models/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/machine/models/__init__.py b/machine/models/__init__.py
index 38e08d7c..d3730142 100644
--- a/machine/models/__init__.py
+++ b/machine/models/__init__.py
@@ -4,3 +4,4 @@
 from .seq2seq import Seq2seq
 from .baseModel import BaseModel
 from .LanguageModel import LanguageModel
+from .Sender import Sender

From c119fdec8cb6bc24b6c2f5e23a41e16a04cfd5aa Mon Sep 17 00:00:00 2001
From: diana
Date: Mon, 18 Mar 2019 12:26:27 +0100
Subject: [PATCH 04/10] Add inputs and outputs

---
 machine/models/Sender.py | 178 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 178 insertions(+)
 create mode 100644 machine/models/Sender.py

diff --git a/machine/models/Sender.py b/machine/models/Sender.py
new file mode 100644
index 00000000..c6374f8a
--- /dev/null
+++ b/machine/models/Sender.py
@@ -0,0 +1,178 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.distributions.categorical import Categorical
+
+from .baseRNN import BaseRNN
+from machine.util.gumbel import gumbel_softmax
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Sender(BaseRNN):
+    """
+    Applies an RNN cell to an input sequence and uses Gumbel softmax sampling to
+    generate an output sequence.
+
+    Args:
+        vocab_size (int): size of the vocabulary
+        output_len (int): the length of the sequence to be generated
+        embedding_size (int): the size of the embedding of input variables
+        hidden_size (int): the size of the hidden dimension of the rnn
+        sos_id (int): index of the start of sequence symbol
+        eos_id (int): index of the end of sequence symbol 
+        rnn_cell (str, optional): type of RNN cell (default: gru)
+        greedy (bool, optional): if True, use argmax at prediction time; if False, sample (default: False) 
+
+    Inputs:
+        tau (float): Temperature to be used for Gumbel Softmax. 
+        hidden_state (torch.tensor, optional): The hidden state to start the decoding. (default=None) 
+            Shape [batch_size, hidden_size]. If None, batch_size=1. 
+
+    Outputs:
+        output_sequence (torch.tensor): The generated decoded sequences. Shape [batch_size, output_len+1]
+            E.g. a sequence at prediction time: [sos_id, predicted_1, predicted_2, ..., predicted_outputlen]
+        sequence_lengths (torch.tensor): The lengths of all the sequences in the batch. Shape [batch_size]
+
+    """
+
+    def __init__(self, vocab_size, output_len, embedding_size,
+                 hidden_size, sos_id, eos_id, rnn_cell='gru', greedy=False):
+        super().__init__(vocab_size, output_len, hidden_size,
+                         input_dropout_p=0, dropout_p=0,
+                         n_layers=1, rnn_cell=rnn_cell, relaxed=True)
+
+        self.output_len = output_len
+        self.embedding_size = embedding_size
+        self.sos_id = sos_id
+        self.eos_id = eos_id
+        self.greedy = greedy
+
+        self.embedding = nn.Parameter(torch.empty((self.vocab_size, self.embedding_size), dtype=torch.float32))
+        self.rnn = self.rnn_cell(self.embedding_size, self.hidden_size)
+        self.linear_out = nn.Linear(self.hidden_size, self.vocab_size)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        nn.init.normal_(self.embedding, 0.0, 0.1)
+
+        nn.init.constant_(self.linear_out.weight, 0)
+        nn.init.constant_(self.linear_out.bias, 0)
+
+        nn.init.xavier_uniform_(self.rnn.weight_ih)
+        nn.init.orthogonal_(self.rnn.weight_hh)
+        nn.init.constant_(self.rnn.bias_ih, val=0)
+        # cuDNN bias order: https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
+        # add some positive bias for the forget gates [b_i, b_f, b_g, b_o] = [0, 1, 0, 0]
+        nn.init.constant_(self.rnn.bias_hh, val=0)
+        nn.init.constant_(self.rnn.bias_hh[self.hidden_size:2 * self.hidden_size], val=1)
+
+    def _init_state(self, hidden_state, rnn_type):
+        """
+        Handles the initialization of the first hidden state of the decoder.
+        Hidden state + cell state in the case of an LSTM cell or
+        only hidden state in the case of a GRU cell.
+
+        Args:
+            hidden_state (torch.tensor): The state to initialize the decoding with.
+            rnn_type (type): Type of the rnn cell.
+
+        Returns:
+            state: (h, c) if LSTM cell, h if GRU cell
+            batch_size: Based on the given hidden_state if not None, 1 otherwise
+        """
+
+        # h0
+        if hidden_state is None:
+            batch_size = 1
+            h = torch.zeros([batch_size, self.hidden_size], device=device)
+        else:
+            batch_size = hidden_state.shape[0]
+            h = hidden_state  # batch_size, hidden_size
+
+        # c0
+        if rnn_type is nn.LSTMCell:
+            c = torch.zeros([batch_size, self.hidden_size], device=device)
+
+            state = (h, c)
+        else:
+            state = h
+
+        return state, batch_size
+
+    def _calculate_seq_len(self, seq_lengths, token, initial_length, seq_pos,
+                           n_sos_symbols, is_discrete):
+        """
+        Calculates the lengths of each sequence in the batch in-place.
+        The length goes from the start of the sequence up until the eos_id is predicted.
+        If it is not predicted, then the length is output_len + n_sos_symbols.
+
+        Args:
+            seq_lengths (torch.tensor): To keep track of the sequence lengths.
+            token (torch.tensor): Batch of predicted tokens at this timestep.
+            initial_length (int): The max possible sequence length (output_len + n_sos_symbols).
+            seq_pos (int): The current timestep.
+            n_sos_symbols (int): Number of sos symbols at the beginning of the sequence.
+            is_discrete (bool): True if tokens are discrete indices (prediction time), False if one-hot vectors (training time).
+
+        """
+
+        for idx, elem in enumerate(token):
+            if seq_lengths[idx] == initial_length:
+                if (is_discrete and elem == self.eos_id) or (not is_discrete and elem[self.eos_id] == 1.0):
+                    seq_lengths[idx] = seq_pos + n_sos_symbols
+
+    def forward(self, tau, hidden_state=None):
+        """
+        Performs a forward pass. If training, use Gumbel Softmax (hard) for sampling, else use
+        discrete sampling.
+        """
+
+        state, batch_size = self._init_state(hidden_state, type(self.rnn))
+
+        # Init output
+        if self.training:
+            output = [torch.zeros((batch_size, self.vocab_size), dtype=torch.float32, device=device)]
+            output[0][:, self.sos_id] = 1.0
+        else:
+            output = [torch.full((batch_size, ), fill_value=self.sos_id, dtype=torch.int64, device=device)]
+
+        # Keep track of sequence lengths
+        n_sos_symbols = 1
+        initial_length = self.output_len + n_sos_symbols
+        seq_lengths = torch.ones([batch_size], dtype=torch.int64, device=device) * initial_length
+
+        for i in range(self.output_len):
+            if self.training:
+                emb = torch.matmul(output[-1], self.embedding)
+            else:
+                emb = self.embedding[output[-1]]
+
+            state = self.rnn(emb, state)
+
+            if type(self.rnn) is nn.LSTMCell:
+                h, c = state
+            else:
+                h = state
+
+            p = F.softmax(self.linear_out(h), dim=1)
+
+            if self.training:
+                token = gumbel_softmax(p, tau, hard=True)
+            else:
+                if self.greedy:
+                    _, token = torch.max(p, -1)
+                else:
+                    token = Categorical(p).sample()
+
+                if batch_size == 1:
+                    token = token.unsqueeze(0)
+
+            output.append(token)
+
+            self._calculate_seq_len(
+                seq_lengths, token, initial_length, seq_pos=i+1,
+                n_sos_symbols=n_sos_symbols, is_discrete=not self.training)
+
+        return (torch.stack(output, dim=1), seq_lengths)
\ No newline at end of file

From 74101a51a09f94b006a250e2d60803c63f4ad822 Mon Sep 17 00:00:00 2001
From: diana
Date: Mon, 18 Mar 2019 12:35:05 +0100
Subject: [PATCH 05/10] Revert

---
 machine/models/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/machine/models/__init__.py b/machine/models/__init__.py
index d3730142..38e08d7c 100644
--- a/machine/models/__init__.py
+++ b/machine/models/__init__.py
@@ -4,4 +4,3 @@
 from .seq2seq import Seq2seq
 from .baseModel import BaseModel
 from .LanguageModel import LanguageModel
-from .Sender import Sender
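During training, Sender never indexes the embedding table with a token id; it multiplies the (relaxed or one-hot) token vector with the embedding matrix, which reduces to a row lookup but stays differentiable. A small sketch of the equivalence, with illustrative sizes:

import torch

vocab_size, embedding_size = 4, 8
embedding = torch.randn(vocab_size, embedding_size)

token_id = torch.tensor([2])
one_hot = torch.zeros(1, vocab_size)
one_hot[0, token_id] = 1.0

# torch.matmul(one_hot, embedding), as used in training, selects the
# same row as the plain indexing used at prediction time.
assert torch.allclose(one_hot @ embedding, embedding[token_id])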
From ffaff75eac6c938a0920f17e828680cbba4a39d7 Mon Sep 17 00:00:00 2001
From: diana
Date: Mon, 18 Mar 2019 12:35:48 +0100
Subject: [PATCH 06/10] Add relaxed flag

---
 machine/models/baseRNN.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/machine/models/baseRNN.py b/machine/models/baseRNN.py
index 5c52f93d..e4e113ff 100644
--- a/machine/models/baseRNN.py
+++ b/machine/models/baseRNN.py
@@ -28,7 +28,8 @@ class BaseRNN(nn.Module):
     SYM_EOS = "EOS"
 
     def __init__(self, vocab_size, max_len, hidden_size,
-                 input_dropout_p, dropout_p, n_layers, rnn_cell):
+                 input_dropout_p, dropout_p, n_layers, rnn_cell,
+                 relaxed=False):
         super(BaseRNN, self).__init__()
         self.vocab_size = vocab_size
         self.max_len = max_len
@@ -37,9 +38,15 @@ def __init__(self, vocab_size, max_len, hidden_size,
         self.input_dropout_p = input_dropout_p
         self.input_dropout = nn.Dropout(p=input_dropout_p)
         if rnn_cell.lower() == 'lstm':
-            self.rnn_cell = nn.LSTM
+            if relaxed:
+                self.rnn_cell = nn.LSTMCell
+            else:
+                self.rnn_cell = nn.LSTM
         elif rnn_cell.lower() == 'gru':
-            self.rnn_cell = nn.GRU
+            if relaxed:
+                self.rnn_cell = nn.GRUCell
+            else:
+                self.rnn_cell = nn.GRU
         else:
            raise ValueError("Unsupported RNN Cell: {0}".format(rnn_cell))
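The relaxed flag swaps the whole-sequence modules nn.LSTM/nn.GRU for their single-step counterparts nn.LSTMCell/nn.GRUCell. Gumbel-Softmax decoding needs this: each step's sampled token becomes the next step's input, so the sequence cannot be processed in one call. A sketch of the resulting step-by-step loop (sizes are illustrative):

import torch
import torch.nn as nn

cell = nn.GRUCell(input_size=8, hidden_size=16)
h = torch.zeros(1, 16)
emb = torch.zeros(1, 8)

for _ in range(5):
    h = cell(emb, h)           # one decoding step at a time
    emb = torch.randn(1, 8)    # in Sender, this is the embedded sampled token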
From 2f116fe953f26121044e11b4fc08b8f8c2d428bb Mon Sep 17 00:00:00 2001
From: diana
Date: Mon, 18 Mar 2019 12:36:09 +0100
Subject: [PATCH 07/10] Add UTs for Sender

---
 test/test_sender.py | 389 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 389 insertions(+)
 create mode 100644 test/test_sender.py

diff --git a/test/test_sender.py b/test/test_sender.py
new file mode 100644
index 00000000..22084ec2
--- /dev/null
+++ b/test/test_sender.py
@@ -0,0 +1,389 @@
+import unittest
+import mock
+
+import torch
+
+from machine.models.Sender import Sender
+
+
+class TestSender(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(self):
+        self.vocab_size = 4
+        self.max_len = 5
+        self.embedding_size = 8
+        self.hidden_size = 16
+        self.sos_id = 0
+        self.eos_id = self.vocab_size - 1
+        self.tau = 1.2
+
+    @mock.patch('machine.models.Sender.gumbel_softmax')
+    def test_lstm_hidden_train(self, mock_gumbel):
+        sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
+                        self.hidden_size, self.sos_id, self.eos_id,
+                        rnn_cell='lstm', greedy=True)
+
+        batch_size = 2
+        n_sos_tokens = 1
+        hidden_state = torch.rand([batch_size, self.hidden_size])
+
+        sender.train()
+
+        # One per max_len (2 elements in batch - a and b)
+        a1 = [1.0, 0.0, 0.0, 0.0]
+        a2 = [1.0, 0.0, 0.0, 0.0]
+        a3 = [0.0, 1.0, 0.0, 0.0]
+        a4 = [0.0, 0.0, 1.0, 0.0]
+        a5 = [0.0, 0.0, 0.0, 1.0]
+
+        b1 = [0.0, 1.0, 0.0, 0.0]
+        b2 = [0.0, 0.0, 1.0, 0.0]
+        b3 = [0.0, 0.0, 0.0, 1.0]
+        b4 = [0.0, 0.0, 1.0, 0.0]
+        b5 = [1.0, 0.0, 0.0, 0.0]
+
+        # Return max_len times a tensor with shape [batch_size, vocab_size]
+        mock_gumbel.side_effect = ([
+            torch.tensor([a1, b1]),
+            torch.tensor([a2, b2]),
+            torch.tensor([a3, b3]),
+            torch.tensor([a4, b4]),
+            torch.tensor([a5, b5])
+        ])
+
+        res, seq_lengths = sender(self.tau, hidden_state)
+
+        self.assertEqual(mock_gumbel.call_count, self.max_len)
+
+        self.assertEqual(res.shape[0], batch_size)
+        self.assertEqual(res.shape[1], self.max_len + n_sos_tokens)
+
+        self.assertTrue(torch.all(torch.eq(
+            res[0],
+            torch.tensor([[1.0, 0.0, 0.0, 0.0],  # sos token
+                          a1,
+                          a2,
+                          a3,
+                          a4,
+                          a5
+                          ]))))
+
+        self.assertTrue(torch.all(torch.eq(
+            res[1],
+            torch.tensor([[1.0, 0.0, 0.0, 0.0],  # sos token
+                          b1,
+                          b2,
+                          b3,
+                          b4,
+                          b5
+                          ]))))
+
+        self.assertEqual(seq_lengths[0], 6)
+        self.assertEqual(seq_lengths[1], 4)
+
+    @mock.patch('machine.models.Sender.gumbel_softmax')
+    def test_gru_hidden_train(self, mock_gumbel):
+        sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
+                        self.hidden_size, self.sos_id, self.eos_id,
+                        rnn_cell='gru', greedy=True)
+
+        batch_size = 2
+        n_sos_tokens = 1
+        hidden_state = torch.rand([batch_size, self.hidden_size])
+
+        sender.train()
+
+        # One per max_len (2 elements in batch - a and b)
+        a1 = [1.0, 0.0, 0.0, 0.0]
+        a2 = [1.0, 0.0, 0.0, 0.0]
+        a3 = [0.0, 1.0, 0.0, 0.0]
+        a4 = [0.0, 0.0, 1.0, 0.0]
+        a5 = [0.0, 0.0, 0.0, 1.0]
+
+        b1 = [0.0, 1.0, 0.0, 0.0]
+        b2 = [0.0, 0.0, 1.0, 0.0]
+        b3 = [0.0, 0.0, 0.0, 1.0]
+        b4 = [0.0, 0.0, 1.0, 0.0]
+        b5 = [1.0, 0.0, 0.0, 0.0]
+
+        # Return max_len times a tensor with shape [batch_size, vocab_size]
+        mock_gumbel.side_effect = ([
+            torch.tensor([a1, b1]),
+            torch.tensor([a2, b2]),
+            torch.tensor([a3, b3]),
+            torch.tensor([a4, b4]),
+            torch.tensor([a5, b5])
+        ])
+
+        res, seq_lengths = sender(self.tau, hidden_state)
+
+        self.assertEqual(mock_gumbel.call_count, self.max_len)
+
+        self.assertEqual(res.shape[0], batch_size)
+        self.assertEqual(res.shape[1], self.max_len + n_sos_tokens)
+
+        self.assertTrue(torch.all(torch.eq(
+            res[0],
+            torch.tensor([[1.0, 0.0, 0.0, 0.0],  # sos token
+                          a1,
+                          a2,
+                          a3,
+                          a4,
+                          a5
+                          ]))))
+
+        self.assertTrue(torch.all(torch.eq(
+            res[1],
+            torch.tensor([[1.0, 0.0, 0.0, 0.0],  # sos token
+                          b1,
+                          b2,
+                          b3,
+                          b4,
+                          b5
+                          ]))))
+
+        self.assertEqual(seq_lengths[0], 6)
+        self.assertEqual(seq_lengths[1], 4)
+
+    @mock.patch('machine.models.Sender.gumbel_softmax')
+    def test_lstm_not_hidden_train(self, mock_gumbel):
+        sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
+                        self.hidden_size, self.sos_id, self.eos_id,
+                        rnn_cell='lstm', greedy=True)
+
+        batch_size = 1
+        n_sos_tokens = 1
+
+        sender.train()
+
+        # One per max_len (1 element in batch - a)
+        a1 = [0.0, 0.0, 0.0, 1.0]
+        a2 = [1.0, 0.0, 0.0, 0.0]
+        a3 = [0.0, 1.0, 0.0, 0.0]
+        a4 = [0.0, 0.0, 1.0, 0.0]
+        a5 = [0.0, 0.0, 0.0, 1.0]
+
+        # Return max_len times a tensor with shape [batch_size, vocab_size]
+        mock_gumbel.side_effect = ([
+            torch.tensor([a1]),
+            torch.tensor([a2]),
+            torch.tensor([a3]),
+            torch.tensor([a4]),
+            torch.tensor([a5])
+        ])
+
+        res, seq_lengths = sender(self.tau)
+
+        self.assertEqual(mock_gumbel.call_count, self.max_len)
+
+        self.assertEqual(res.shape[0], batch_size)
+        self.assertEqual(res.shape[1], self.max_len + n_sos_tokens)
+
+        self.assertTrue(torch.all(torch.eq(
+            res[0],
+            torch.tensor([[1.0, 0.0, 0.0, 0.0],  # sos token
+                          a1,
+                          a2,
+                          a3,
+                          a4,
+                          a5
+                          ]))))
+
+        self.assertEqual(seq_lengths[0], 2)
+
+    @mock.patch('machine.models.Sender.F.softmax')
+    def test_lstm_hidden_eval_greedy(self, mock_softmax):
+        sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
+                        self.hidden_size, self.sos_id, self.eos_id,
+                        rnn_cell='lstm', greedy=True)
+
+        batch_size = 2
+        n_sos_tokens = 1
+        hidden_state = torch.rand([batch_size, self.hidden_size])
+
+        # Return max_len times a tensor with shape [batch_size, vocab_size]
+        mock_softmax.side_effect = [
+            torch.tensor([[0.3, 0.4, 0.1, 0.2], [0.4, 0.3, 0.1, 0.2]]),
+            torch.tensor([[0.2, 0.2, 0.2, 0.4], [0.1, 0.4, 0.2, 0.3]]),
+            torch.tensor([[0.6, 0.4, 0.0, 0.0], [0.5, 0.3, 0.1, 0.2]]),
+            torch.tensor([[0.4, 0.3, 0.0, 0.3], [0.8, 0.1, 0.1, 0.0]]),
+            torch.tensor([[0.1, 0.1, 0.1, 0.7], [0.0, 0.5, 0.2, 0.3]])
+        ]
+
+        sender.eval()
+
+        res, seq_lengths = sender(self.tau, hidden_state)
+
+        self.assertEqual(res.shape[0], batch_size)
+        self.assertEqual(res.shape[1], self.max_len + n_sos_tokens)
+
+        self.assertTrue(torch.all(torch.eq(
+            res[0],
+            torch.tensor([self.sos_id,
+                          1,
+                          3,
+                          0,
+                          0,
+                          3
+                          ]))))
+
+        self.assertTrue(torch.all(torch.eq(
+            res[1],
+            torch.tensor([self.sos_id,
+                          0,
+                          1,
+                          0,
+                          0,
+                          1
+                          ]))))
+
+        self.assertEqual(seq_lengths[0], 3)
+        self.assertEqual(seq_lengths[1], 6)
+
+    @mock.patch('machine.models.Sender.F.softmax')
+    def test_lstm_hidden_eval_not_greedy(self, mock_softmax):
+        sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
+                        self.hidden_size, self.sos_id, self.eos_id,
+                        rnn_cell='lstm', greedy=False)
+
+        batch_size = 2
+        n_sos_tokens = 1
+        hidden_state = torch.rand([batch_size, self.hidden_size])
+
+        # Return max_len times a tensor with shape [batch_size, vocab_size]
+        mock_softmax.side_effect = [
+            torch.tensor([[0.3, 0.4, 0.1, 0.2], [0.4, 0.3, 0.1, 0.2]]),
+            torch.tensor([[0.2, 0.2, 0.2, 0.4], [0.1, 0.4, 0.2, 0.3]]),
+            torch.tensor([[0.6, 0.4, 0.0, 0.0], [0.5, 0.3, 0.1, 0.2]]),
+            torch.tensor([[0.4, 0.3, 0.0, 0.3], [0.8, 0.1, 0.1, 0.0]]),
+            torch.tensor([[0.1, 0.1, 0.1, 0.7], [0.0, 0.5, 0.2, 0.3]])
+        ]
+
+        sender.eval()
+
+        res, seq_lengths = sender(self.tau, hidden_state)
+
+        self.assertEqual(mock_softmax.call_count, self.max_len)
+
+        self.assertEqual(res.shape[0], batch_size)
+        self.assertEqual(res.shape[1], self.max_len + n_sos_tokens)
+
+        self.assertEqual(res[0][0], self.sos_id)
+        self.assertEqual(res[1][0], self.sos_id)
+
+    @mock.patch('machine.models.Sender.F.softmax')
+    def test_lstm_not_hidden_eval_greedy(self, mock_softmax):
+        sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
+                        self.hidden_size, self.sos_id, self.eos_id,
+                        rnn_cell='lstm', greedy=True)
+
+        batch_size = 1
+        n_sos_tokens = 1
+
+        # Return max_len times a tensor with shape [batch_size, vocab_size]
+        mock_softmax.side_effect = [
+            torch.tensor([0.3, 0.2, 0.1, 0.4]),
+            torch.tensor([0.4, 0.3, 0.1, 0.2]),
+            torch.tensor([0.2, 0.4, 0.2, 0.2]),
+            torch.tensor([0.2, 0.5, 0.1, 0.2]),
+            torch.tensor([0.0, 0.0, 0.8, 0.2]),
+        ]
+
+        sender.eval()
+
+        res, seq_lengths = sender(self.tau)
+
+        self.assertEqual(mock_softmax.call_count, self.max_len)
+
+        self.assertEqual(res.shape[0], batch_size)
+        self.assertEqual(res.shape[1], self.max_len + n_sos_tokens)
+
+        self.assertTrue(torch.all(torch.eq(
+            res[0],
+            torch.tensor([self.sos_id,
+                          3,
+                          0,
+                          1,
+                          1,
+                          2
+                          ]))))
+
+        self.assertEqual(seq_lengths[0], 2)
+
+    @mock.patch('machine.models.Sender.F.softmax')
+    def test_lstm_not_hidden_eval_not_greedy(self, mock_softmax):
+        sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
+                        self.hidden_size, self.sos_id, self.eos_id,
+                        rnn_cell='lstm', greedy=False)
+
+        batch_size = 1
+        n_sos_tokens = 1
+
+        # Return max_len times a tensor with shape [batch_size, vocab_size]
+        mock_softmax.side_effect = [
+            torch.tensor([0.3, 0.2, 0.1, 0.4]),
+            torch.tensor([0.4, 0.3, 0.1, 0.2]),
+            torch.tensor([0.2, 0.4, 0.2, 0.2]),
+            torch.tensor([0.2, 0.5, 0.1, 0.2]),
+            torch.tensor([0.0, 0.0, 0.8, 0.2]),
+        ]
+
+        sender.eval()
+
+        res, seq_lengths = sender(self.tau)
+
+        self.assertEqual(mock_softmax.call_count, self.max_len)
+
+        self.assertEqual(res.shape[0], batch_size)
+        self.assertEqual(res.shape[1], self.max_len + n_sos_tokens)
+
+        self.assertEqual(res[0][0], self.sos_id)
+
+    @mock.patch('machine.models.Sender.F.softmax')
+    def test_gru_hidden_eval_greedy(self, mock_softmax):
+        sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
+                        self.hidden_size, self.sos_id, self.eos_id,
+                        rnn_cell='gru', greedy=True)
+
+        batch_size = 2
+        n_sos_tokens = 1
+        hidden_state = torch.rand([batch_size, self.hidden_size])
+
+        # Return max_len times a tensor with shape [batch_size, vocab_size]
+        mock_softmax.side_effect = [
+            torch.tensor([[0.3, 0.4, 0.1, 0.2], [0.4, 0.3, 0.1, 0.2]]),
+            torch.tensor([[0.2, 0.2, 0.2, 0.4], [0.1, 0.4, 0.2, 0.3]]),
+            torch.tensor([[0.6, 0.4, 0.0, 0.0], [0.5, 0.3, 0.1, 0.2]]),
+            torch.tensor([[0.4, 0.3, 0.0, 0.3], [0.8, 0.1, 0.1, 0.0]]),
+            torch.tensor([[0.1, 0.1, 0.1, 0.7], [0.0, 0.5, 0.2, 0.3]])
+        ]
+
+        sender.eval()
+
+        res, seq_lengths = sender(self.tau, hidden_state)
+
+        self.assertEqual(res.shape[0], batch_size)
+        self.assertEqual(res.shape[1], self.max_len + n_sos_tokens)
+
+        self.assertTrue(torch.all(torch.eq(
+            res[0],
+            torch.tensor([self.sos_id,
+                          1,
+                          3,
+                          0,
+                          0,
+                          3
+                          ]))))
+
+        self.assertTrue(torch.all(torch.eq(
+            res[1],
+            torch.tensor([self.sos_id,
+                          0,
+                          1,
+                          0,
+                          0,
+                          1
+                          ]))))
+
+        self.assertEqual(seq_lengths[0], 3)
+        self.assertEqual(seq_lengths[1], 6)
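These tests patch gumbel_softmax and F.softmax at the point of use (machine.models.Sender), not where they are defined; mock.patch has to replace the name in the namespace Sender actually looks it up in. Roughly, with an illustrative return value:

import torch
import mock  # the standalone mock package these tests import

with mock.patch('machine.models.Sender.gumbel_softmax') as fake_gs:
    # Deterministic one-hot "samples" instead of random draws make the
    # decoded sequences, and hence the assertions, predictable.
    fake_gs.return_value = torch.tensor([[0.0, 1.0, 0.0, 0.0]])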
From 7d563c5ccb7ce6fa5cdfd0986ec2873799e1d4ca Mon Sep 17 00:00:00 2001
From: diana
Date: Mon, 18 Mar 2019 20:24:18 +0100
Subject: [PATCH 08/10] Fix indentation

---
 machine/models/Sender.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/machine/models/Sender.py b/machine/models/Sender.py
index c6374f8a..5aa42cae 100644
--- a/machine/models/Sender.py
+++ b/machine/models/Sender.py
@@ -20,21 +20,21 @@ class Sender(BaseRNN):
         embedding_size (int): the size of the embedding of input variables
         hidden_size (int): the size of the hidden dimension of the rnn
         sos_id (int): index of the start of sequence symbol
-        eos_id (int): index of the end of sequence symbol 
+        eos_id (int): index of the end of sequence symbol
         rnn_cell (str, optional): type of RNN cell (default: gru)
-        greedy (bool, optional): if True, use argmax at prediction time; if False, sample (default: False) 
+        greedy (bool, optional): if True, use argmax at prediction time; if False, sample (default: False)
 
     Inputs:
-        tau (float): Temperature to be used for Gumbel Softmax. 
-        hidden_state (torch.tensor, optional): The hidden state to start the decoding. (default=None) 
-            Shape [batch_size, hidden_size]. If None, batch_size=1. 
+        tau (float): Temperature to be used for Gumbel Softmax.
+        hidden_state (torch.tensor, optional): The hidden state to start the decoding. (default=None)
+            Shape [batch_size, hidden_size]. If None, batch_size=1.
 
     Outputs:
-        output_sequence (torch.tensor): The generated decoded sequences. Shape [batch_size, output_len+1]
-            E.g. a sequence at prediction time: [sos_id, predicted_1, predicted_2, ..., predicted_outputlen]
-        sequence_lengths (torch.tensor): The lengths of all the sequences in the batch. Shape [batch_size]
+        output_sequence (torch.tensor): The generated decoded sequences. Shape [batch_size, output_len+1]
+            E.g. a sequence at prediction time: [sos_id, predicted_1, predicted_2, ..., predicted_outputlen]
+        sequence_lengths (torch.tensor): The lengths of all the sequences in the batch. Shape [batch_size]
 
-    """
+    """
 
     def __init__(self, vocab_size, output_len, embedding_size,
                  hidden_size, sos_id, eos_id, rnn_cell='gru', greedy=False):

From bd2c5986cf11eb830e0a5d3ced52386b132f47b3 Mon Sep 17 00:00:00 2001
From: diana
Date: Tue, 19 Mar 2019 14:05:03 +0100
Subject: [PATCH 09/10] Add option to return seq lengths

---
 machine/models/Sender.py | 30 +++++++++++++++++++++---------
 test/test_sender.py      | 29 +++++++++++++++++++++--------
 2 files changed, 42 insertions(+), 17 deletions(-)

diff --git a/machine/models/Sender.py b/machine/models/Sender.py
index 5aa42cae..b8cf460a 100644
--- a/machine/models/Sender.py
+++ b/machine/models/Sender.py
@@ -23,6 +23,8 @@ class Sender(BaseRNN):
         eos_id (int): index of the end of sequence symbol
         rnn_cell (str, optional): type of RNN cell (default: gru)
         greedy (bool, optional): if True, use argmax at prediction time; if False, sample (default: False)
+        compute_lengths (bool, optional): True if the length of each sequence in the batch is to be computed
+            by looking for eos tokens.
 
     Inputs:
         tau (float): Temperature to be used for Gumbel Softmax.
@@ -33,11 +35,12 @@ class Sender(BaseRNN):
         output_sequence (torch.tensor): The generated decoded sequences. Shape [batch_size, output_len+1]
             E.g. a sequence at prediction time: [sos_id, predicted_1, predicted_2, ..., predicted_outputlen]
         sequence_lengths (torch.tensor): The lengths of all the sequences in the batch. Shape [batch_size]
+            Only returned if compute_lengths=True
 
     """
 
-    def __init__(self, vocab_size, output_len, embedding_size,
-                 hidden_size, sos_id, eos_id, rnn_cell='gru', greedy=False):
+    def __init__(self, vocab_size, output_len, embedding_size, hidden_size,
+                 sos_id, eos_id, rnn_cell='gru', greedy=False, compute_lengths=False):
         super().__init__(vocab_size, output_len, hidden_size,
                          input_dropout_p=0, dropout_p=0,
                          n_layers=1, rnn_cell=rnn_cell, relaxed=True)
@@ -47,6 +50,7 @@ def __init__(self, vocab_size, output_len, embedding_size, hidden_size,
         self.sos_id = sos_id
         self.eos_id = eos_id
         self.greedy = greedy
+        self.compute_lengths = compute_lengths
 
         self.embedding = nn.Parameter(torch.empty((self.vocab_size, self.embedding_size), dtype=torch.float32))
         self.rnn = self.rnn_cell(self.embedding_size, self.hidden_size)
@@ -139,9 +143,10 @@ def forward(self, tau, hidden_state=None):
             output = [torch.full((batch_size, ), fill_value=self.sos_id, dtype=torch.int64, device=device)]
 
         # Keep track of sequence lengths
-        n_sos_symbols = 1
-        initial_length = self.output_len + n_sos_symbols
-        seq_lengths = torch.ones([batch_size], dtype=torch.int64, device=device) * initial_length
+        if self.compute_lengths:
+            n_sos_symbols = 1
+            initial_length = self.output_len + n_sos_symbols
+            seq_lengths = torch.ones([batch_size], dtype=torch.int64, device=device) * initial_length
 
         for i in range(self.output_len):
             if self.training:
@@ -171,8 +176,15 @@ def forward(self, tau, hidden_state=None):
 
             output.append(token)
 
-            self._calculate_seq_len(
-                seq_lengths, token, initial_length, seq_pos=i+1,
-                n_sos_symbols=n_sos_symbols, is_discrete=not self.training)
+            if self.compute_lengths:
+                self._calculate_seq_len(
+                    seq_lengths, token, initial_length, seq_pos=i+1,
+                    n_sos_symbols=n_sos_symbols, is_discrete=not self.training)
 
-        return (torch.stack(output, dim=1), seq_lengths)
\ No newline at end of file
+        outputs = torch.stack(output, dim=1)
+
+        if self.compute_lengths:
+            return (outputs, seq_lengths)
+        else:
+            return outputs
\ No newline at end of file
diff --git a/test/test_sender.py b/test/test_sender.py
index 22084ec2..e0934cac 100644
--- a/test/test_sender.py
+++ b/test/test_sender.py
@@ -21,7 +21,7 @@ def setUpClass(self):
     def test_lstm_hidden_train(self, mock_gumbel):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=True)
+                        rnn_cell='lstm', greedy=True, compute_lengths=True)
 
         batch_size = 2
         n_sos_tokens = 1
@@ -85,7 +85,7 @@ def test_lstm_hidden_train(self, mock_gumbel):
     def test_gru_hidden_train(self, mock_gumbel):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='gru', greedy=True)
+                        rnn_cell='gru', greedy=True, compute_lengths=True)
 
         batch_size = 2
         n_sos_tokens = 1
@@ -149,7 +149,7 @@ def test_gru_hidden_train(self, mock_gumbel):
     def test_lstm_not_hidden_train(self, mock_gumbel):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=True)
+                        rnn_cell='lstm', greedy=True, compute_lengths=True)
 
         batch_size = 1
         n_sos_tokens = 1
@@ -195,7 +195,7 @@ def test_lstm_not_hidden_train(self, mock_gumbel):
     def test_lstm_hidden_eval_greedy(self, mock_softmax):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=True)
+                        rnn_cell='lstm', greedy=True, compute_lengths=True)
 
         batch_size = 2
         n_sos_tokens = 1
@@ -244,7 +244,7 @@ def test_lstm_hidden_eval_greedy(self, mock_softmax):
     def test_lstm_hidden_eval_not_greedy(self, mock_softmax):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=False)
+                        rnn_cell='lstm', greedy=False, compute_lengths=True)
 
         batch_size = 2
         n_sos_tokens = 1
@@ -275,7 +275,7 @@ def test_lstm_hidden_eval_not_greedy(self, mock_softmax):
     def test_lstm_not_hidden_eval_greedy(self, mock_softmax):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=True)
+                        rnn_cell='lstm', greedy=True, compute_lengths=True)
 
         batch_size = 1
         n_sos_tokens = 1
@@ -314,7 +314,7 @@ def test_lstm_not_hidden_eval_greedy(self, mock_softmax):
     def test_lstm_not_hidden_eval_not_greedy(self, mock_softmax):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=False)
+                        rnn_cell='lstm', greedy=False, compute_lengths=True)
 
         batch_size = 1
         n_sos_tokens = 1
@@ -343,7 +343,7 @@ def test_lstm_not_hidden_eval_not_greedy(self, mock_softmax):
     def test_gru_hidden_eval_greedy(self, mock_softmax):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='gru', greedy=True)
+                        rnn_cell='gru', greedy=True, compute_lengths=True)
 
         batch_size = 2
         n_sos_tokens = 1
@@ -387,3 +387,16 @@ def test_gru_hidden_eval_greedy(self, mock_softmax):
 
         self.assertEqual(seq_lengths[0], 3)
         self.assertEqual(seq_lengths[1], 6)
+
+    def test_not_compute_lengths(self):
+        sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
+                        self.hidden_size, self.sos_id, self.eos_id,
+                        rnn_cell='lstm', greedy=False, compute_lengths=False)
+
+        batch_size = 1
+        n_sos_tokens = 1
+
+        res = sender(self.tau)
+
+        self.assertEqual(res.shape[0], batch_size)
+        self.assertEqual(res.shape[1], self.max_len + n_sos_tokens)
\ No newline at end of file
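With this patch, forward returns either just the stacked sequences or a (sequences, lengths) pair, depending on the flag. A usage sketch, with the same illustrative sizes the tests use:

from machine.models.Sender import Sender

sender = Sender(vocab_size=4, output_len=5, embedding_size=8, hidden_size=16,
                sos_id=0, eos_id=3, rnn_cell='lstm', compute_lengths=True)
sender.eval()

sequences, lengths = sender(tau=1.2)   # compute_lengths=True
# With compute_lengths=False (the default), only sequences is returned:
# sequences = sender(tau=1.2)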
From e3ed730b79123be2ff8a9217ba263f79a114a794 Mon Sep 17 00:00:00 2001
From: diana
Date: Tue, 19 Mar 2019 16:01:29 +0100
Subject: [PATCH 10/10] Change computing seq lengths code to Gautier's
 version, which is much faster

---
 machine/models/Sender.py |  42 ++++++-------
 test/test_sender.py      | 128 ++++++++++++++++++++++++++++++++++-----
 2 files changed, 131 insertions(+), 39 deletions(-)

diff --git a/machine/models/Sender.py b/machine/models/Sender.py
index b8cf460a..e73fe7c9 100644
--- a/machine/models/Sender.py
+++ b/machine/models/Sender.py
@@ -23,8 +23,6 @@ class Sender(BaseRNN):
         eos_id (int): index of the end of sequence symbol
         rnn_cell (str, optional): type of RNN cell (default: gru)
         greedy (bool, optional): if True, use argmax at prediction time; if False, sample (default: False)
-        compute_lengths (bool, optional): True if the length of each sequence in the batch is to be computed
-            by looking for eos tokens.
 
     Inputs:
         tau (float): Temperature to be used for Gumbel Softmax.
@@ -35,12 +33,11 @@ class Sender(BaseRNN):
         output_sequence (torch.tensor): The generated decoded sequences. Shape [batch_size, output_len+1]
             E.g. a sequence at prediction time: [sos_id, predicted_1, predicted_2, ..., predicted_outputlen]
         sequence_lengths (torch.tensor): The lengths of all the sequences in the batch. Shape [batch_size]
-            Only returned if compute_lengths=True
 
     """
 
-    def __init__(self, vocab_size, output_len, embedding_size, hidden_size,
-                 sos_id, eos_id, rnn_cell='gru', greedy=False, compute_lengths=False):
+    def __init__(self, vocab_size, output_len, embedding_size, hidden_size,
+                 sos_id, eos_id, rnn_cell='gru', greedy=False):
         super().__init__(vocab_size, output_len, hidden_size,
                          input_dropout_p=0, dropout_p=0,
                          n_layers=1, rnn_cell=rnn_cell, relaxed=True)
@@ -50,7 +47,6 @@ def __init__(self, vocab_size, output_len, embedding_size, hidden_size,
         self.sos_id = sos_id
         self.eos_id = eos_id
         self.greedy = greedy
-        self.compute_lengths = compute_lengths
 
         self.embedding = nn.Parameter(torch.empty((self.vocab_size, self.embedding_size), dtype=torch.float32))
         self.rnn = self.rnn_cell(self.embedding_size, self.hidden_size)
@@ -121,11 +117,15 @@ def _calculate_seq_len(self, seq_lengths, token, initial_length, seq_pos,
             is_discrete (bool): True if tokens are discrete indices (prediction time), False if one-hot vectors (training time).
 
         """
-
-        for idx, elem in enumerate(token):
-            if seq_lengths[idx] == initial_length:
-                if (is_discrete and elem == self.eos_id) or (not is_discrete and elem[self.eos_id] == 1.0):
-                    seq_lengths[idx] = seq_pos + n_sos_symbols
+
+        if is_discrete:
+            mask = token == self.eos_id
+        else:
+            max_predicted, vocab_index = torch.max(token, dim=1)
+            mask = (vocab_index == self.eos_id) * (max_predicted == 1.0)
+
+        mask *= seq_lengths == initial_length
+        seq_lengths[mask.nonzero()] = seq_pos + n_sos_symbols
 
     def forward(self, tau, hidden_state=None):
         """
@@ -143,10 +143,9 @@ def forward(self, tau, hidden_state=None):
             output = [torch.full((batch_size, ), fill_value=self.sos_id, dtype=torch.int64, device=device)]
 
         # Keep track of sequence lengths
-        if self.compute_lengths:
-            n_sos_symbols = 1
-            initial_length = self.output_len + n_sos_symbols
-            seq_lengths = torch.ones([batch_size], dtype=torch.int64, device=device) * initial_length
+        n_sos_symbols = 1
+        initial_length = self.output_len + n_sos_symbols
+        seq_lengths = torch.ones([batch_size], dtype=torch.int64, device=device) * initial_length
 
         for i in range(self.output_len):
             if self.training:
@@ -176,15 +175,8 @@ def forward(self, tau, hidden_state=None):
 
             output.append(token)
 
-            if self.compute_lengths:
-                self._calculate_seq_len(
-                    seq_lengths, token, initial_length, seq_pos=i+1,
-                    n_sos_symbols=n_sos_symbols, is_discrete=not self.training)
+            self._calculate_seq_len(
+                seq_lengths, token, initial_length, seq_pos=i+1,
+                n_sos_symbols=n_sos_symbols, is_discrete=not self.training)
-
-        outputs = torch.stack(output, dim=1)
-
-        if self.compute_lengths:
-            return (outputs, seq_lengths)
-        else:
-            return outputs
\ No newline at end of file
+        return (torch.stack(output, dim=1), seq_lengths)
diff --git a/test/test_sender.py b/test/test_sender.py
index e0934cac..bb25911b 100644
--- a/test/test_sender.py
+++ b/test/test_sender.py
@@ -21,7 +21,7 @@ def setUpClass(self):
     def test_lstm_hidden_train(self, mock_gumbel):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=True, compute_lengths=True)
+                        rnn_cell='lstm', greedy=True)
 
         batch_size = 2
         n_sos_tokens = 1
@@ -85,7 +85,7 @@ def test_lstm_hidden_train(self, mock_gumbel):
     def test_gru_hidden_train(self, mock_gumbel):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='gru', greedy=True, compute_lengths=True)
+                        rnn_cell='gru', greedy=True)
 
         batch_size = 2
         n_sos_tokens = 1
@@ -149,7 +149,7 @@ def test_gru_hidden_train(self, mock_gumbel):
     def test_lstm_not_hidden_train(self, mock_gumbel):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=True, compute_lengths=True)
+                        rnn_cell='lstm', greedy=True)
 
         batch_size = 1
         n_sos_tokens = 1
@@ -195,7 +195,7 @@ def test_lstm_not_hidden_train(self, mock_gumbel):
     def test_lstm_hidden_eval_greedy(self, mock_softmax):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=True, compute_lengths=True)
+                        rnn_cell='lstm', greedy=True)
 
         batch_size = 2
         n_sos_tokens = 1
@@ -244,7 +244,7 @@ def test_lstm_hidden_eval_greedy(self, mock_softmax):
     def test_lstm_hidden_eval_not_greedy(self, mock_softmax):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=False, compute_lengths=True)
+                        rnn_cell='lstm', greedy=False)
 
         batch_size = 2
         n_sos_tokens = 1
@@ -275,7 +275,7 @@ def test_lstm_hidden_eval_not_greedy(self, mock_softmax):
     def test_lstm_not_hidden_eval_greedy(self, mock_softmax):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=True, compute_lengths=True)
+                        rnn_cell='lstm', greedy=True)
 
         batch_size = 1
         n_sos_tokens = 1
@@ -314,7 +314,7 @@ def test_lstm_not_hidden_eval_greedy(self, mock_softmax):
     def test_lstm_not_hidden_eval_not_greedy(self, mock_softmax):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=False, compute_lengths=True)
+                        rnn_cell='lstm', greedy=False)
 
         batch_size = 1
         n_sos_tokens = 1
@@ -343,7 +343,7 @@ def test_lstm_not_hidden_eval_not_greedy(self, mock_softmax):
     def test_gru_hidden_eval_greedy(self, mock_softmax):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
                         self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='gru', greedy=True, compute_lengths=True)
+                        rnn_cell='gru', greedy=True)
 
         batch_size = 2
         n_sos_tokens = 1
@@ -388,15 +388,115 @@ def test_gru_hidden_eval_greedy(self, mock_softmax):
         self.assertEqual(seq_lengths[0], 3)
         self.assertEqual(seq_lengths[1], 6)
 
-    def test_not_compute_lengths(self):
+    @mock.patch('machine.models.Sender.gumbel_softmax')
+    def test_lstm_bound_sos_train(self, mock_gumbel):
         sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
-                        self.hidden_size, self.sos_id, self.eos_id,
-                        rnn_cell='lstm', greedy=False, compute_lengths=False)
+                        self.hidden_size, self.sos_id, self.sos_id,
+                        rnn_cell='lstm', greedy=True)
 
-        batch_size = 1
+        batch_size = 2
+        n_sos_tokens = 1
+        hidden_state = torch.rand([batch_size, self.hidden_size])
+
+        sender.train()
+
+        # One per max_len (2 elements in batch - a and b)
+        a1 = [1.0, 0.0, 0.0, 0.0]
+        a2 = [1.0, 0.0, 0.0, 0.0]
+        a3 = [0.0, 1.0, 0.0, 0.0]
+        a4 = [0.0, 0.0, 1.0, 0.0]
+        a5 = [0.0, 0.0, 0.0, 1.0]
+
+        b1 = [0.0, 1.0, 0.0, 0.0]
+        b2 = [0.0, 0.0, 1.0, 0.0]
+        b3 = [0.0, 0.0, 0.0, 1.0]
+        b4 = [0.0, 0.0, 1.0, 0.0]
+        b5 = [1.0, 0.0, 0.0, 0.0]
+
+        # Return max_len times a tensor with shape [batch_size, vocab_size]
+        mock_gumbel.side_effect = ([
+            torch.tensor([a1, b1]),
+            torch.tensor([a2, b2]),
+            torch.tensor([a3, b3]),
+            torch.tensor([a4, b4]),
+            torch.tensor([a5, b5])
+        ])
+
+        res, seq_lengths = sender(self.tau, hidden_state)
+
+        self.assertEqual(mock_gumbel.call_count, self.max_len)
+
+        self.assertEqual(res.shape[0], batch_size)
+        self.assertEqual(res.shape[1], self.max_len + n_sos_tokens)
+
+        self.assertTrue(torch.all(torch.eq(
+            res[0],
+            torch.tensor([[1.0, 0.0, 0.0, 0.0],  # sos token
+                          a1,
+                          a2,
+                          a3,
+                          a4,
+                          a5
+                          ]))))
+
+        self.assertTrue(torch.all(torch.eq(
+            res[1],
+            torch.tensor([[1.0, 0.0, 0.0, 0.0],  # sos token
+                          b1,
+                          b2,
+                          b3,
+                          b4,
+                          b5
+                          ]))))
+
+        self.assertEqual(seq_lengths[0], 2)
+        self.assertEqual(seq_lengths[1], 6)
+
+    @mock.patch('machine.models.Sender.F.softmax')
+    def test_lstm_bound_sos_eval(self, mock_softmax):
+        sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
+                        self.hidden_size, self.sos_id, self.sos_id,
+                        rnn_cell='lstm', greedy=True)
+
+        batch_size = 2
         n_sos_tokens = 1
+        hidden_state = torch.rand([batch_size, self.hidden_size])
+
+        # Return max_len times a tensor with shape [batch_size, vocab_size]
+        mock_softmax.side_effect = [
+            torch.tensor([[0.3, 0.4, 0.1, 0.2], [0.4, 0.3, 0.1, 0.2]]),
+            torch.tensor([[0.2, 0.2, 0.2, 0.4], [0.1, 0.4, 0.2, 0.3]]),
+            torch.tensor([[0.6, 0.4, 0.0, 0.0], [0.5, 0.3, 0.1, 0.2]]),
+            torch.tensor([[0.4, 0.3, 0.0, 0.3], [0.8, 0.1, 0.1, 0.0]]),
+            torch.tensor([[0.1, 0.1, 0.1, 0.7], [0.0, 0.5, 0.2, 0.3]])
+        ]
 
-        res = sender(self.tau)
+        sender.eval()
+
+        res, seq_lengths = sender(self.tau, hidden_state)
 
         self.assertEqual(res.shape[0], batch_size)
-        self.assertEqual(res.shape[1], self.max_len + n_sos_tokens)
\ No newline at end of file
+        self.assertEqual(res.shape[1], self.max_len + n_sos_tokens)
+
+        self.assertTrue(torch.all(torch.eq(
+            res[0],
+            torch.tensor([self.sos_id,
+                          1,
+                          3,
+                          0,
+                          0,
+                          3
+                          ]))))
+
+        self.assertTrue(torch.all(torch.eq(
+            res[1],
+            torch.tensor([self.sos_id,
+                          0,
+                          1,
+                          0,
+                          0,
+                          1
+                          ]))))
+
+        self.assertEqual(seq_lengths[0], 4)
+        self.assertEqual(seq_lengths[1], 2)
\ No newline at end of file
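
The vectorized _calculate_seq_len replaces the Python loop with tensor masks: eos positions are found for the whole batch at once, and only sequences still at their initial length are updated, so the first eos wins. A standalone sketch of the same logic, written with boolean indexing (the patch itself uses the byte-tensor idiom of PyTorch of that era, with * and nonzero()):

import torch

eos_id = 3
initial_length, seq_pos, n_sos_symbols = 6, 2, 1

seq_lengths = torch.tensor([6, 6, 4])  # 4: this sequence terminated earlier
token = torch.tensor([3, 1, 3])        # discrete tokens at this timestep

mask = token == eos_id                  # which sequences emit eos now
mask &= seq_lengths == initial_length   # ...and were not already terminated
seq_lengths[mask] = seq_pos + n_sos_symbols
print(seq_lengths)                      # tensor([3, 6, 4])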