diff --git a/machine/models/Sender.py b/machine/models/Sender.py
new file mode 100644
index 00000000..e73fe7c9
--- /dev/null
+++ b/machine/models/Sender.py
@@ -0,0 +1,182 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.distributions.categorical import Categorical
+
+from .baseRNN import BaseRNN
+from machine.util.gumbel import gumbel_softmax
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Sender(BaseRNN):
+    """
+    Applies an RNN cell to an input sequence and uses Gumbel-Softmax sampling
+    to generate an output sequence.
+
+    Args:
+        vocab_size (int): size of the vocabulary
+        output_len (int): the length of the sequence to be generated
+        embedding_size (int): the size of the embedding of input variables
+        hidden_size (int): the size of the hidden dimension of the rnn
+        sos_id (int): index of the start of sequence symbol
+        eos_id (int): index of the end of sequence symbol
+        rnn_cell (str, optional): type of RNN cell (default: gru)
+        greedy (bool, optional): if True, use argmax at prediction time; if False, sample (default: False)
+
+    Inputs:
+        tau (float): Temperature to be used for Gumbel-Softmax.
+        hidden_state (torch.tensor, optional): The hidden state to start the decoding. (default=None)
+            Shape [batch_size, hidden_size]. If None, batch_size=1.
+
+    Outputs:
+        output_sequence (torch.tensor): The generated sequences. Shape [batch_size, output_len+1]
+            at prediction time, e.g. [sos_id, predicted_1, predicted_2, ..., predicted_outputlen];
+            during training the entries are one-hot Gumbel-Softmax samples, giving shape
+            [batch_size, output_len+1, vocab_size].
+        sequence_lengths (torch.tensor): The lengths of all the sequences in the batch. Shape [batch_size]
+
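+    Example:
+        A minimal sketch at prediction time (the argument values are
+        illustrative only; they mirror this diff's unit tests):
+
+        >>> sender = Sender(vocab_size=4, output_len=5, embedding_size=8,
+        ...                 hidden_size=16, sos_id=0, eos_id=3, rnn_cell='gru')
+        >>> sender.eval()
+        >>> tokens, lengths = sender(tau=1.2)
+        >>> tokens.shape   # batch_size defaults to 1 without a hidden_state
+        torch.Size([1, 6])
+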
+    """
+
+    def __init__(self, vocab_size, output_len, embedding_size, hidden_size,
+                 sos_id, eos_id, rnn_cell='gru', greedy=False):
+        super().__init__(vocab_size, output_len, hidden_size,
+                         input_dropout_p=0, dropout_p=0,
+                         n_layers=1, rnn_cell=rnn_cell, relaxed=True)
+
+        self.output_len = output_len
+        self.embedding_size = embedding_size
+        self.sos_id = sos_id
+        self.eos_id = eos_id
+        self.greedy = greedy
+
+        self.embedding = nn.Parameter(torch.empty((self.vocab_size, self.embedding_size), dtype=torch.float32))
+        self.rnn = self.rnn_cell(self.embedding_size, self.hidden_size)
+        self.linear_out = nn.Linear(self.hidden_size, self.vocab_size)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        nn.init.normal_(self.embedding, 0.0, 0.1)
+
+        nn.init.constant_(self.linear_out.weight, 0)
+        nn.init.constant_(self.linear_out.bias, 0)
+
+        nn.init.xavier_uniform_(self.rnn.weight_ih)
+        nn.init.orthogonal_(self.rnn.weight_hh)
+        nn.init.constant_(self.rnn.bias_ih, val=0)
+        # cuDNN bias order: https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
+        # add some positive bias for the forget gate: [b_i, b_f, b_g, b_o] = [0, 1, 0, 0]
+        # (for a GRUCell the same slice corresponds to the update gate)
+        nn.init.constant_(self.rnn.bias_hh, val=0)
+        nn.init.constant_(self.rnn.bias_hh[self.hidden_size:2 * self.hidden_size], val=1)
+
+    def _init_state(self, hidden_state, rnn_type):
+        """
+        Handles the initialization of the first hidden state of the decoder:
+        hidden state plus cell state in the case of an LSTM cell, or
+        only a hidden state in the case of a GRU cell.
+
+        Args:
+            hidden_state (torch.tensor): The state to initialize the decoding with.
+            rnn_type (type): Type of the rnn cell.
+
+        Returns:
+            state: (h, c) if LSTM cell, h if GRU cell
+            batch_size: based on the given hidden_state if not None, 1 otherwise
+        """
+
+        # h0
+        if hidden_state is None:
+            batch_size = 1
+            h = torch.zeros([batch_size, self.hidden_size], device=device)
+        else:
+            batch_size = hidden_state.shape[0]
+            h = hidden_state  # batch_size, hidden_size
+
+        # c0
+        if rnn_type is nn.LSTMCell:
+            c = torch.zeros([batch_size, self.hidden_size], device=device)
+            state = (h, c)
+        else:
+            state = h
+
+        return state, batch_size
+
+    def _calculate_seq_len(self, seq_lengths, token, initial_length, seq_pos,
+                           n_sos_symbols, is_discrete):
+        """
+        Calculates the lengths of each sequence in the batch in-place.
+        The length runs from the start of the sequence up to and including the
+        first predicted eos_id. If eos_id is never predicted, the length is
+        output_len + n_sos_symbols.
+
+        Args:
+            seq_lengths (torch.tensor): To keep track of the sequence lengths.
+            token (torch.tensor): Batch of predicted tokens at this timestep.
+            initial_length (int): The max possible sequence length (output_len + n_sos_symbols).
+            seq_pos (int): The current timestep.
+            n_sos_symbols (int): Number of sos symbols at the beginning of the sequence.
+            is_discrete (bool): True if tokens are discrete indices (evaluation mode),
+                False if they are one-hot Gumbel-Softmax samples (training mode).
+        """
+
+        if is_discrete:
+            mask = token == self.eos_id
+        else:
+            max_predicted, vocab_index = torch.max(token, dim=1)
+            mask = (vocab_index == self.eos_id) * (max_predicted == 1.0)
+
+        # only update sequences whose length has not been fixed yet
+        mask *= seq_lengths == initial_length
+        seq_lengths[mask.nonzero()] = seq_pos + n_sos_symbols
+
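+    # Gradient note: with hard=True, gumbel_softmax returns one-hot samples
+    # whose backward pass flows through the underlying relaxed sample
+    # (straight-through estimator), so the embedding matmul in forward()
+    # remains differentiable during training.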
+    def forward(self, tau, hidden_state=None):
+        """
+        Performs a forward pass. If training, sample with (hard) Gumbel-Softmax;
+        otherwise sample discrete tokens (argmax if greedy, else from a
+        categorical distribution).
+        """
+
+        state, batch_size = self._init_state(hidden_state, type(self.rnn))
+
+        # Init output
+        if self.training:
+            output = [torch.zeros((batch_size, self.vocab_size), dtype=torch.float32, device=device)]
+            output[0][:, self.sos_id] = 1.0
+        else:
+            output = [torch.full((batch_size,), fill_value=self.sos_id, dtype=torch.int64, device=device)]
+
+        # Keep track of sequence lengths
+        n_sos_symbols = 1
+        initial_length = self.output_len + n_sos_symbols
+        seq_lengths = torch.ones([batch_size], dtype=torch.int64, device=device) * initial_length
+
+        for i in range(self.output_len):
+            if self.training:
+                # soft lookup: weight the embeddings by the (one-hot) GS sample
+                emb = torch.matmul(output[-1], self.embedding)
+            else:
+                # discrete lookup by token index
+                emb = self.embedding[output[-1]]
+
+            state = self.rnn(emb, state)
+
+            if type(self.rnn) is nn.LSTMCell:
+                h, c = state
+            else:
+                h = state
+
+            p = F.softmax(self.linear_out(h), dim=1)
+
+            if self.training:
+                token = gumbel_softmax(p, tau, hard=True)
+            else:
+                if self.greedy:
+                    _, token = torch.max(p, -1)
+                else:
+                    token = Categorical(p).sample()
+
+                if batch_size == 1:
+                    token = token.unsqueeze(0)
+
+            output.append(token)
+
+            self._calculate_seq_len(
+                seq_lengths, token, initial_length, seq_pos=i+1,
+                n_sos_symbols=n_sos_symbols, is_discrete=not self.training)
+
+        return (torch.stack(output, dim=1), seq_lengths)
diff --git a/machine/models/baseRNN.py b/machine/models/baseRNN.py
index 5c52f93d..e4e113ff 100644
--- a/machine/models/baseRNN.py
+++ b/machine/models/baseRNN.py
@@ -28,7 +28,8 @@ class BaseRNN(nn.Module):
     SYM_EOS = "EOS"
 
     def __init__(self, vocab_size, max_len, hidden_size,
-                 input_dropout_p, dropout_p, n_layers, rnn_cell):
+                 input_dropout_p, dropout_p, n_layers, rnn_cell,
+                 relaxed=False):
         super(BaseRNN, self).__init__()
         self.vocab_size = vocab_size
         self.max_len = max_len
@@ -37,9 +38,15 @@ def __init__(self, vocab_size, max_len, hidden_size,
         self.input_dropout_p = input_dropout_p
         self.input_dropout = nn.Dropout(p=input_dropout_p)
         if rnn_cell.lower() == 'lstm':
-            self.rnn_cell = nn.LSTM
+            if relaxed:
+                self.rnn_cell = nn.LSTMCell
+            else:
+                self.rnn_cell = nn.LSTM
         elif rnn_cell.lower() == 'gru':
-            self.rnn_cell = nn.GRU
+            if relaxed:
+                self.rnn_cell = nn.GRUCell
+            else:
+                self.rnn_cell = nn.GRU
         else:
             raise ValueError("Unsupported RNN Cell: {0}".format(rnn_cell))
diff --git a/machine/util/__init__.py b/machine/util/__init__.py
index 41f5fcc2..4d27d9bc 100644
--- a/machine/util/__init__.py
+++ b/machine/util/__init__.py
@@ -1,2 +1,3 @@
 from .log import Log
 from .checkpoint import Checkpoint
+from .gumbel import gumbel_softmax
diff --git a/machine/util/gumbel.py b/machine/util/gumbel.py
new file mode 100644
index 00000000..e38007f0
--- /dev/null
+++ b/machine/util/gumbel.py
@@ -0,0 +1,21 @@
+import torch
+from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical
+
+def gumbel_softmax(probs, tau, hard):
+    """ Samples from the Gumbel-Softmax (GS) distribution
+
+    Args:
+        probs (torch.tensor): probabilities of shape [batch_size, n_classes]
+        tau (float): temperature parameter for the GS
+        hard (bool): discretize the sample to a one-hot vector if True
+    """
+
+    rohc = RelaxedOneHotCategorical(tau, probs)
+    y = rohc.rsample()
+
+    if hard:
+        # straight-through: the forward pass uses the discrete y_hard, while
+        # gradients flow through the relaxed sample y
+        y_hard = torch.zeros_like(y)
+        y_hard.scatter_(-1, torch.argmax(y, dim=-1, keepdim=True), 1.0)
+        y = (y_hard - y).detach() + y
+
+    return y
\ No newline at end of file
diff --git a/test/test_gumbel.py b/test/test_gumbel.py
new file mode 100644
index 00000000..17ae2dc0
--- /dev/null
+++ b/test/test_gumbel.py
@@ -0,0 +1,30 @@
+import unittest
+import torch
+
+from machine.util.gumbel import gumbel_softmax
+
+
+class TestGumbelSoftmax(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.probs = torch.tensor([
+            [0.3, 0.4, 0.1, 0.1, 0.1],
+            [0.1, 0.3, 0.2, 0.3, 0.1],
+            [0.6, 0.1, 0.1, 0.1, 0.1],
+            [0.2, 0.2, 0.2, 0.2, 0.2]])
+        cls.tau = 1.2
+
+    def test_hard(self):
+        res = gumbel_softmax(self.probs, self.tau, hard=True)
+
+        self.assertEqual(res.shape, (4, 5))
+        for r in res:
+            # each hard sample is one-hot, so it sums to exactly 1
+            self.assertEqual(r.sum().item(), 1.0)
+
+    def test_soft(self):
+        res = gumbel_softmax(self.probs, self.tau, hard=False)
+
+        self.assertEqual(res.shape, (4, 5))
\ No newline at end of file
diff --git a/test/test_sender.py b/test/test_sender.py
new file mode 100644
index 00000000..bb25911b
--- /dev/null
+++ b/test/test_sender.py
@@ -0,0 +1,502 @@
+import unittest
+from unittest import mock
+
+import torch
+
+from machine.models.Sender import Sender
+
+class TestSender(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.vocab_size = 4
+        cls.max_len = 5
+        cls.embedding_size = 8
+        cls.hidden_size = 16
+        cls.sos_id = 0
+        cls.eos_id = cls.vocab_size - 1
+        cls.tau = 1.2
+
+    @mock.patch('machine.models.Sender.gumbel_softmax')
+    def test_lstm_hidden_train(self, mock_gumbel):
+        sender = Sender(self.vocab_size, self.max_len, self.embedding_size,
+                        self.hidden_size, self.sos_id, self.eos_id,
+                        rnn_cell='lstm', greedy=True)
+
+        batch_size = 2
+        n_sos_tokens = 1
+        hidden_state = torch.rand([batch_size, self.hidden_size])
+
+        sender.train()
+
+        # One per max_len (2 elements in batch - a and b)
+        a1 = [1.0, 0.0, 0.0, 0.0]
+        a2 = [1.0, 0.0, 0.0, 0.0]
+        a3 = [0.0, 1.0, 0.0, 0.0]
+        a4 = [0.0, 0.0, 1.0, 0.0]
+        a5 = [0.0, 0.0, 0.0, 1.0]
+
+        b1 = [0.0, 1.0, 0.0, 0.0]
+        b2 = [0.0, 0.0, 1.0, 0.0]
+        b3 = [0.0, 0.0, 0.0, 1.0]
+        b4 = [0.0, 0.0, 1.0, 0.0]
+        b5 = [1.0, 0.0, 0.0, 0.0]
+
+        # Return max_len times a tensor with shape [batch_size,
vocab_size] + mock_gumbel.side_effect = ([ + torch.tensor([a1, b1]), + torch.tensor([a2, b2]), + torch.tensor([a3, b3]), + torch.tensor([a4, b4]), + torch.tensor([a5, b5]) + ]) + + res, seq_lengths = sender(self.tau, hidden_state) + + self.assertEqual(mock_gumbel.call_count, self.max_len) + + self.assertEqual(res.shape[0], batch_size) + self.assertEqual(res.shape[1], self.max_len + n_sos_tokens) + + self.assertTrue(torch.all(torch.eq( + res[0], + torch.tensor([[1.0, 0.0, 0.0, 0.0], # sos token + a1, + a2, + a3, + a4, + a5 + ])))) + + self.assertTrue(torch.all(torch.eq( + res[1], + torch.tensor([[1.0, 0.0, 0.0, 0.0], # sos token + b1, + b2, + b3, + b4, + b5 + ])))) + + self.assertEqual(seq_lengths[0], 6) + self.assertEqual(seq_lengths[1], 4) + + @mock.patch('machine.models.Sender.gumbel_softmax') + def test_gru_hidden_train(self, mock_gumbel): + sender = Sender(self.vocab_size, self.max_len, self.embedding_size, + self.hidden_size, self.sos_id, self.eos_id, + rnn_cell='gru', greedy=True) + + batch_size = 2 + n_sos_tokens = 1 + hidden_state = torch.rand([batch_size, self.hidden_size]) + + sender.train() + + # One per max_len (2 elements in batch - a and b) + a1 = [1.0, 0.0, 0.0, 0.0] + a2 = [1.0, 0.0, 0.0, 0.0] + a3 = [0.0, 1.0, 0.0, 0.0] + a4 = [0.0, 0.0, 1.0, 0.0] + a5 = [0.0, 0.0, 0.0, 1.0] + + b1 = [0.0, 1.0, 0.0, 0.0] + b2 = [0.0, 0.0, 1.0, 0.0] + b3 = [0.0, 0.0, 0.0, 1.0] + b4 = [0.0, 0.0, 1.0, 0.0] + b5 = [1.0, 0.0, 0.0, 0.0] + + # Return max_len times a tensor with shape [batch_size, vocab_size] + mock_gumbel.side_effect = ([ + torch.tensor([a1, b1]), + torch.tensor([a2, b2]), + torch.tensor([a3, b3]), + torch.tensor([a4, b4]), + torch.tensor([a5, b5]) + ]) + + res, seq_lengths = sender(self.tau, hidden_state) + + self.assertEqual(mock_gumbel.call_count, self.max_len) + + self.assertEqual(res.shape[0], batch_size) + self.assertEqual(res.shape[1], self.max_len + n_sos_tokens) + + self.assertTrue(torch.all(torch.eq( + res[0], + torch.tensor([[1.0, 0.0, 0.0, 0.0], # sos token + a1, + a2, + a3, + a4, + a5 + ])))) + + self.assertTrue(torch.all(torch.eq( + res[1], + torch.tensor([[1.0, 0.0, 0.0, 0.0], # sos token + b1, + b2, + b3, + b4, + b5 + ])))) + + self.assertEqual(seq_lengths[0], 6) + self.assertEqual(seq_lengths[1], 4) + + @mock.patch('machine.models.Sender.gumbel_softmax') + def test_lstm_not_hidden_train(self, mock_gumbel): + sender = Sender(self.vocab_size, self.max_len, self.embedding_size, + self.hidden_size, self.sos_id, self.eos_id, + rnn_cell='lstm', greedy=True) + + batch_size = 1 + n_sos_tokens = 1 + + sender.train() + + # One per max_len (1 element in batch - a) + a1 = [0.0, 0.0, 0.0, 1.0] + a2 = [1.0, 0.0, 0.0, 0.0] + a3 = [0.0, 1.0, 0.0, 0.0] + a4 = [0.0, 0.0, 1.0, 0.0] + a5 = [0.0, 0.0, 0.0, 1.0] + + # Return max_len times a tensor with shape [batch_size, vocab_size] + mock_gumbel.side_effect = ([ + torch.tensor([a1]), + torch.tensor([a2]), + torch.tensor([a3]), + torch.tensor([a4]), + torch.tensor([a5]) + ]) + + res, seq_lengths = sender(self.tau) + + self.assertEqual(mock_gumbel.call_count, self.max_len) + + self.assertEqual(res.shape[0], batch_size) + self.assertEqual(res.shape[1], self.max_len + n_sos_tokens) + + self.assertTrue(torch.all(torch.eq( + res[0], + torch.tensor([[1.0, 0.0, 0.0, 0.0], # sos token + a1, + a2, + a3, + a4, + a5 + ])))) + + self.assertEqual(seq_lengths[0], 2) + + @mock.patch('machine.models.Sender.F.softmax') + def test_lstm_hidden_eval_greedy(self, mock_softmax): + sender = Sender(self.vocab_size, self.max_len, self.embedding_size, + 
self.hidden_size, self.sos_id, self.eos_id, + rnn_cell='lstm', greedy=True) + + batch_size = 2 + n_sos_tokens = 1 + hidden_state = torch.rand([batch_size, self.hidden_size]) + + # Return max_len times a tensor with shape [batch_size, vocab_size] + mock_softmax.side_effect = [ + torch.tensor([[0.3, 0.4, 0.1, 0.2], [0.4, 0.3, 0.1, 0.2]]), + torch.tensor([[0.2, 0.2, 0.2, 0.4], [0.1, 0.4, 0.2, 0.3]]), + torch.tensor([[0.6, 0.4, 0.0, 0.0], [0.5, 0.3, 0.1, 0.2]]), + torch.tensor([[0.4, 0.3, 0.0, 0.3], [0.8, 0.1, 0.1, 0.0]]), + torch.tensor([[0.1, 0.1, 0.1, 0.7], [0.0, 0.5, 0.2, 0.3]]) + ] + + sender.eval() + + res, seq_lengths = sender(self.tau, hidden_state) + + self.assertEqual(res.shape[0], batch_size) + self.assertEqual(res.shape[1], self.max_len + n_sos_tokens) + + self.assertTrue(torch.all(torch.eq( + res[0], + torch.tensor([self.sos_id, + 1, + 3, + 0, + 0, + 3 + ])))) + + self.assertTrue(torch.all(torch.eq( + res[1], + torch.tensor([self.sos_id, + 0, + 1, + 0, + 0, + 1 + ])))) + + self.assertEqual(seq_lengths[0], 3) + self.assertEqual(seq_lengths[1], 6) + + @mock.patch('machine.models.Sender.F.softmax') + def test_lstm_hidden_eval_not_greedy(self, mock_softmax): + sender = Sender(self.vocab_size, self.max_len, self.embedding_size, + self.hidden_size, self.sos_id, self.eos_id, + rnn_cell='lstm', greedy=False) + + batch_size = 2 + n_sos_tokens = 1 + hidden_state = torch.rand([batch_size, self.hidden_size]) + + # Return max_len times a tensor with shape [batch_size, vocab_size] + mock_softmax.side_effect = [ + torch.tensor([[0.3, 0.4, 0.1, 0.2], [0.4, 0.3, 0.1, 0.2]]), + torch.tensor([[0.2, 0.2, 0.2, 0.4], [0.1, 0.4, 0.2, 0.3]]), + torch.tensor([[0.6, 0.4, 0.0, 0.0], [0.5, 0.3, 0.1, 0.2]]), + torch.tensor([[0.4, 0.3, 0.0, 0.3], [0.8, 0.1, 0.1, 0.0]]), + torch.tensor([[0.1, 0.1, 0.1, 0.7], [0.0, 0.5, 0.2, 0.3]]) + ] + + sender.eval() + + res, seq_lengths = sender(self.tau, hidden_state) + + self.assertEqual(mock_softmax.call_count, self.max_len) + + self.assertEqual(res.shape[0], batch_size) + self.assertEqual(res.shape[1], self.max_len + n_sos_tokens) + + self.assertEqual(res[0][0], self.sos_id) + self.assertEqual(res[1][0], self.sos_id) + + @mock.patch('machine.models.Sender.F.softmax') + def test_lstm_not_hidden_eval_greedy(self, mock_softmax): + sender = Sender(self.vocab_size, self.max_len, self.embedding_size, + self.hidden_size, self.sos_id, self.eos_id, + rnn_cell='lstm', greedy=True) + + batch_size = 1 + n_sos_tokens = 1 + + # Return max_len times a tensor with shape [batch_size, vocab_size] + mock_softmax.side_effect = [ + torch.tensor([0.3, 0.2, 0.1, 0.4]), + torch.tensor([0.4, 0.3, 0.1, 0.2]), + torch.tensor([0.2, 0.4, 0.2, 0.2]), + torch.tensor([0.2, 0.5, 0.1, 0.2]), + torch.tensor([0.0, 0.0, 0.8, 0.2]), + ] + + sender.eval() + + res, seq_lengths = sender(self.tau) + + self.assertEqual(mock_softmax.call_count, self.max_len) + + self.assertEqual(res.shape[0], batch_size) + self.assertEqual(res.shape[1], self.max_len + n_sos_tokens) + + self.assertTrue(torch.all(torch.eq( + res[0], + torch.tensor([self.sos_id, + 3, + 0, + 1, + 1, + 2 + ])))) + + self.assertEqual(seq_lengths[0], 2) + + @mock.patch('machine.models.Sender.F.softmax') + def test_lstm_not_hidden_eval_not_greedy(self, mock_softmax): + sender = Sender(self.vocab_size, self.max_len, self.embedding_size, + self.hidden_size, self.sos_id, self.eos_id, + rnn_cell='lstm', greedy=False) + + batch_size = 1 + n_sos_tokens = 1 + + # Return max_len times a tensor with shape [batch_size, vocab_size] + mock_softmax.side_effect = [ + 
torch.tensor([0.3, 0.2, 0.1, 0.4]), + torch.tensor([0.4, 0.3, 0.1, 0.2]), + torch.tensor([0.2, 0.4, 0.2, 0.2]), + torch.tensor([0.2, 0.5, 0.1, 0.2]), + torch.tensor([0.0, 0.0, 0.8, 0.2]), + ] + + sender.eval() + + res, seq_lengths = sender(self.tau) + + self.assertEqual(mock_softmax.call_count, self.max_len) + + self.assertEqual(res.shape[0], batch_size) + self.assertEqual(res.shape[1], self.max_len + n_sos_tokens) + + self.assertEqual(res[0][0], self.sos_id) + + @mock.patch('machine.models.Sender.F.softmax') + def test_gru_hidden_eval_greedy(self, mock_softmax): + sender = Sender(self.vocab_size, self.max_len, self.embedding_size, + self.hidden_size, self.sos_id, self.eos_id, + rnn_cell='gru', greedy=True) + + batch_size = 2 + n_sos_tokens = 1 + hidden_state = torch.rand([batch_size, self.hidden_size]) + + # Return max_len times a tensor with shape [batch_size, vocab_size] + mock_softmax.side_effect = [ + torch.tensor([[0.3, 0.4, 0.1, 0.2], [0.4, 0.3, 0.1, 0.2]]), + torch.tensor([[0.2, 0.2, 0.2, 0.4], [0.1, 0.4, 0.2, 0.3]]), + torch.tensor([[0.6, 0.4, 0.0, 0.0], [0.5, 0.3, 0.1, 0.2]]), + torch.tensor([[0.4, 0.3, 0.0, 0.3], [0.8, 0.1, 0.1, 0.0]]), + torch.tensor([[0.1, 0.1, 0.1, 0.7], [0.0, 0.5, 0.2, 0.3]]) + ] + + sender.eval() + + res, seq_lengths = sender(self.tau, hidden_state) + + self.assertEqual(res.shape[0], batch_size) + self.assertEqual(res.shape[1], self.max_len + n_sos_tokens) + + self.assertTrue(torch.all(torch.eq( + res[0], + torch.tensor([self.sos_id, + 1, + 3, + 0, + 0, + 3 + ])))) + + self.assertTrue(torch.all(torch.eq( + res[1], + torch.tensor([self.sos_id, + 0, + 1, + 0, + 0, + 1 + ])))) + + self.assertEqual(seq_lengths[0], 3) + self.assertEqual(seq_lengths[1], 6) + + @mock.patch('machine.models.Sender.gumbel_softmax') + def test_lstm_bound_sos_train(self, mock_gumbel): + sender = Sender(self.vocab_size, self.max_len, self.embedding_size, + self.hidden_size, self.sos_id, self.sos_id, + rnn_cell='lstm', greedy=True) + + batch_size = 2 + n_sos_tokens = 1 + hidden_state = torch.rand([batch_size, self.hidden_size]) + + sender.train() + + # One per max_len (2 elements in batch - a and b) + a1 = [1.0, 0.0, 0.0, 0.0] + a2 = [1.0, 0.0, 0.0, 0.0] + a3 = [0.0, 1.0, 0.0, 0.0] + a4 = [0.0, 0.0, 1.0, 0.0] + a5 = [0.0, 0.0, 0.0, 1.0] + + b1 = [0.0, 1.0, 0.0, 0.0] + b2 = [0.0, 0.0, 1.0, 0.0] + b3 = [0.0, 0.0, 0.0, 1.0] + b4 = [0.0, 0.0, 1.0, 0.0] + b5 = [1.0, 0.0, 0.0, 0.0] + + # Return max_len times a tensor with shape [batch_size, vocab_size] + mock_gumbel.side_effect = ([ + torch.tensor([a1, b1]), + torch.tensor([a2, b2]), + torch.tensor([a3, b3]), + torch.tensor([a4, b4]), + torch.tensor([a5, b5]) + ]) + + res, seq_lengths = sender(self.tau, hidden_state) + + self.assertEqual(mock_gumbel.call_count, self.max_len) + + self.assertEqual(res.shape[0], batch_size) + self.assertEqual(res.shape[1], self.max_len + n_sos_tokens) + + self.assertTrue(torch.all(torch.eq( + res[0], + torch.tensor([[1.0, 0.0, 0.0, 0.0], # sos token + a1, + a2, + a3, + a4, + a5 + ])))) + + self.assertTrue(torch.all(torch.eq( + res[1], + torch.tensor([[1.0, 0.0, 0.0, 0.0], # sos token + b1, + b2, + b3, + b4, + b5 + ])))) + + self.assertEqual(seq_lengths[0], 2) + self.assertEqual(seq_lengths[1], 6) + + @mock.patch('machine.models.Sender.F.softmax') + def test_lstm_bound_sos_eval(self, mock_softmax): + sender = Sender(self.vocab_size, self.max_len, self.embedding_size, + self.hidden_size, self.sos_id, self.sos_id, + rnn_cell='lstm', greedy=True) + + batch_size = 2 + n_sos_tokens = 1 + hidden_state = 
torch.rand([batch_size, self.hidden_size]) + + # Return max_len times a tensor with shape [batch_size, vocab_size] + mock_softmax.side_effect = [ + torch.tensor([[0.3, 0.4, 0.1, 0.2], [0.4, 0.3, 0.1, 0.2]]), + torch.tensor([[0.2, 0.2, 0.2, 0.4], [0.1, 0.4, 0.2, 0.3]]), + torch.tensor([[0.6, 0.4, 0.0, 0.0], [0.5, 0.3, 0.1, 0.2]]), + torch.tensor([[0.4, 0.3, 0.0, 0.3], [0.8, 0.1, 0.1, 0.0]]), + torch.tensor([[0.1, 0.1, 0.1, 0.7], [0.0, 0.5, 0.2, 0.3]]) + ] + + sender.eval() + + res, seq_lengths = sender(self.tau, hidden_state) + + self.assertEqual(res.shape[0], batch_size) + self.assertEqual(res.shape[1], self.max_len + n_sos_tokens) + + self.assertTrue(torch.all(torch.eq( + res[0], + torch.tensor([self.sos_id, + 1, + 3, + 0, + 0, + 3 + ])))) + + self.assertTrue(torch.all(torch.eq( + res[1], + torch.tensor([self.sos_id, + 0, + 1, + 0, + 0, + 1 + ])))) + + self.assertEqual(seq_lengths[0], 4) + self.assertEqual(seq_lengths[1], 2) \ No newline at end of file