This repository was archived by the owner on Dec 11, 2023. It is now read-only.
Closed
Changes from all commits
39 commits
9ccfde9
Added num_encoder_layers/num_decoder_layers to WMT16 standard hparams.
tiberiuscarlat Mar 12, 2018
8522665
Add TrainOutputTuple, EvalOutputTuple, InferOutputTuple
lmthang Jan 3, 2018
80ff1c9
Wrap try/except for saver.restore and print variables in loaded check…
lmthang Jan 5, 2018
653da78
Improve several behaviors regarding loading hparams for train/inference:
lmthang Jan 6, 2018
64e0436
pretrained our models on
a-googler Jan 7, 2018
cfa4f48
Factor out model creation code in train.py and inference.py.
lmthang Jan 8, 2018
989683d
Standardize vocab in test to use <unk>, <s>, </s>
lmthang Jan 9, 2018
656fb12
Clean up inference_test.py and add more sharing code to _createTestIn…
lmthang Jan 9, 2018
bdbba21
NMT: Improving GPU availability debugging.
eli7 Jan 9, 2018
e73f83e
Add add_info_summaries to automatically add summaries from info dict.
lmthang Jan 9, 2018
ae33cf1
Internal change only
lmthang Jan 10, 2018
4eba2d8
Make compute_encoder_states() in model.py more general
lmthang Jan 12, 2018
9be88ea
Set self.encoder_state_list in build_encoder() for gnmt_model.py
lmthang Jan 18, 2018
438e29a
Add language_model flag to train a language model by ignoring the enc…
lmthang Jan 18, 2018
3bb4930
Add infer_mode option to specify which type of decoder to use during …
oahziur Jan 23, 2018
eeed098
Replace compute_encoder_states with build_encoder_states (no sess.run).
lmthang Jan 24, 2018
17c4272
Add sampled_softmax_loss and minor cleanup.
oahziur Jan 25, 2018
4781773
Update standard hparams as num_layers is no longer a valid hparam.
oahziur Jan 29, 2018
005fef0
Add copyright text for model_helper.py
lmthang Jan 30, 2018
0642c53
Refactoring internal and external eval to allow injection of placehol…
eli7 Jan 30, 2018
bd936dd
Refactoring Model._build_encoder.
eli7 Feb 2, 2018
781201a
Minor clean-ups, update docstring
lmthang Feb 7, 2018
3cd5d33
Update vocab_utils.py to load embedding files under word2vec format.
lmthang Feb 11, 2018
d2f66df
Allow embedding file to have some misformated entry. Simply ignore th…
oahziur Feb 13, 2018
6caa994
Unify reading and writing of hparams file.
a-googler Feb 17, 2018
8f20c69
Colocate output_layer with last LSTM cell to improve training speed.
oahziur Feb 21, 2018
735b8b7
Add char-level embeddings for encoder only.
oahziur Apr 3, 2018
a35b1e3
Refactored *model.py files.
lmthang Apr 5, 2018
6853265
Pretty print hparams when writing file.
a-googler Apr 5, 2018
2411c44
Remove hparams.num_layers
lmthang Apr 5, 2018
9494594
Remove unused standard hparams.
oahziur May 11, 2018
3128dac
Fix typo.
a-googler Jul 4, 2018
eb88c18
[NMT] Adding support for sharded train sets.
eli7 Jul 20, 2018
7932a59
Decouple model loading from inference code
lmthang Aug 2, 2018
471e484
Use LooseVersion for version check
a-googler Aug 6, 2018
1355e32
Minor improvements and fixes
lmthang Aug 11, 2018
b278487
Add INFERENCE_KEYS and make extend_hparams() in nmt.py more robust
lmthang Aug 24, 2018
b333eea
hparams
tiberiu92 Oct 17, 2018
c045042
Added beam_search as default inference mode in hparams.
tiberiu92 Oct 17, 2018
README.md: 2 changes (1 addition, 1 deletion)
@@ -69,7 +69,7 @@ with the latest research ideas. We achieve this goal by:
 
 We believe that it is important to provide benchmarks that people can easily
 replicate. As a result, we have provided full experimental results and
-pretrained on models on the following publicly available datasets:
+pretrained our models on the following publicly available datasets:
 
 1. *Small-scale*: English-Vietnamese parallel corpus of TED talks (133K sentence
    pairs) provided by
nmt/attention_model.py: 57 changes (34 additions, 23 deletions)
@@ -44,11 +44,14 @@ def __init__(self,
                reverse_target_vocab_table=None,
                scope=None,
                extra_args=None):
+    self.has_attention = hparams.attention_architecture and hparams.attention
+
     # Set attention_mechanism_fn
-    if extra_args and extra_args.attention_mechanism_fn:
-      self.attention_mechanism_fn = extra_args.attention_mechanism_fn
-    else:
-      self.attention_mechanism_fn = create_attention_mechanism
+    if self.has_attention:
+      if extra_args and extra_args.attention_mechanism_fn:
+        self.attention_mechanism_fn = extra_args.attention_mechanism_fn
+      else:
+        self.attention_mechanism_fn = create_attention_mechanism
 
     super(AttentionModel, self).__init__(
         hparams=hparams,
@@ -60,23 +63,32 @@ def __init__(self,
         scope=scope,
         extra_args=extra_args)
 
-    if self.mode == tf.contrib.learn.ModeKeys.INFER:
-      self.infer_summary = self._get_infer_summary(hparams)
+  def _prepare_beam_search_decoder_inputs(
+      self, beam_width, memory, source_sequence_length, encoder_state):
+    memory = tf.contrib.seq2seq.tile_batch(
+        memory, multiplier=beam_width)
+    source_sequence_length = tf.contrib.seq2seq.tile_batch(
+        source_sequence_length, multiplier=beam_width)
+    encoder_state = tf.contrib.seq2seq.tile_batch(
+        encoder_state, multiplier=beam_width)
+    batch_size = self.batch_size * beam_width
+    return memory, source_sequence_length, encoder_state, batch_size
 
   def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                           source_sequence_length):
     """Build a RNN cell with attention mechanism that can be used by decoder."""
-    attention_option = hparams.attention
-    attention_architecture = hparams.attention_architecture
-
-    if attention_architecture != "standard":
+    # No Attention
+    if not self.has_attention:
+      return super(AttentionModel, self)._build_decoder_cell(
+          hparams, encoder_outputs, encoder_state, source_sequence_length)
+    elif hparams.attention_architecture != "standard":
       raise ValueError(
-          "Unknown attention architecture %s" % attention_architecture)
+          "Unknown attention architecture %s" % hparams.attention_architecture)
 
     num_units = hparams.num_units
     num_layers = self.num_decoder_layers
     num_residual_layers = self.num_decoder_residual_layers
-    beam_width = hparams.beam_width
+    infer_mode = hparams.infer_mode
 
     dtype = tf.float32
 
@@ -86,19 +98,18 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
     else:
       memory = encoder_outputs
 
-    if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0:
-      memory = tf.contrib.seq2seq.tile_batch(
-          memory, multiplier=beam_width)
-      source_sequence_length = tf.contrib.seq2seq.tile_batch(
-          source_sequence_length, multiplier=beam_width)
-      encoder_state = tf.contrib.seq2seq.tile_batch(
-          encoder_state, multiplier=beam_width)
-      batch_size = self.batch_size * beam_width
+    if (self.mode == tf.contrib.learn.ModeKeys.INFER and
+        infer_mode == "beam_search"):
+      memory, source_sequence_length, encoder_state, batch_size = (
+          self._prepare_beam_search_decoder_inputs(
+              hparams.beam_width, memory, source_sequence_length,
+              encoder_state))
     else:
       batch_size = self.batch_size
 
+    # Attention
     attention_mechanism = self.attention_mechanism_fn(
-        attention_option, num_units, memory, source_sequence_length, self.mode)
+        hparams.attention, num_units, memory, source_sequence_length, self.mode)
 
     cell = model_helper.create_rnn_cell(
         unit_type=hparams.unit_type,
@@ -113,7 +124,7 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
 
     # Only generate alignment in greedy INFER mode.
     alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and
-                         beam_width == 0)
+                         infer_mode != "beam_search")
     cell = tf.contrib.seq2seq.AttentionWrapper(
         cell,
         attention_mechanism,
@@ -136,7 +147,7 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
     return cell, decoder_initial_state
 
   def _get_infer_summary(self, hparams):
-    if hparams.beam_width > 0:
+    if not self.has_attention or hparams.infer_mode == "beam_search":
      return tf.no_op()
     return _create_attention_images_summary(self.final_context_state)
 
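Note on the attention_model.py changes above: the three inline tf.contrib.seq2seq.tile_batch calls were factored into _prepare_beam_search_decoder_inputs, and greedy-versus-beam behavior is now keyed off hparams.infer_mode rather than beam_width > 0, so alignment history and attention image summaries are produced only when infer_mode != "beam_search". The snippet below is a minimal standalone sketch, not part of this PR, of what the tiling does; it assumes TensorFlow 1.x with tf.contrib.seq2seq available, and the shapes and names are illustrative only.

import tensorflow as tf

batch_size, src_len, num_units, beam_width = 2, 5, 4, 3

# Stand-in encoder results; memory is [batch_size, src_len, num_units].
memory = tf.random_normal([batch_size, src_len, num_units])
source_sequence_length = tf.constant([5, 3])

# tile_batch repeats each batch entry beam_width times along axis 0, so every
# beam hypothesis attends over its own copy of the source sentence.
tiled_memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
tiled_lengths = tf.contrib.seq2seq.tile_batch(
    source_sequence_length, multiplier=beam_width)

with tf.Session() as sess:
  m, l = sess.run([tiled_memory, tiled_lengths])
  print(m.shape)  # (6, 5, 4), i.e. batch_size * beam_width
  print(l)        # [5 5 5 3 3 3]

Grouping the three calls in one helper lets AttentionModel and GNMTModel share the logic instead of duplicating it in each _build_decoder_cell.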
nmt/gnmt_model.py: 156 changes (102 additions, 54 deletions)
@@ -20,12 +20,10 @@
 
 import tensorflow as tf
 
-# TODO(rzhao): Use tf.contrib.framework.nest once 1.3 is out.
-from tensorflow.python.util import nest
-
 from . import attention_model
 from . import model_helper
 from .utils import misc_utils as utils
+from .utils import vocab_utils
 
 __all__ = ["GNMTModel"]
 
@@ -43,6 +41,9 @@ def __init__(self,
                reverse_target_vocab_table=None,
                scope=None,
                extra_args=None):
+    self.is_gnmt_attention = (
+        hparams.attention_architecture in ["gnmt", "gnmt_v2"])
+
     super(GNMTModel, self).__init__(
         hparams=hparams,
         mode=mode,
@@ -64,6 +65,7 @@ def _build_encoder(self, hparams):
     # Build GNMT encoder.
     num_bi_layers = 1
     num_uni_layers = self.num_encoder_layers - num_bi_layers
+    utils.print_out("# Build a GNMT encoder")
     utils.print_out("  num_bi_layers = %d" % num_bi_layers)
     utils.print_out("  num_uni_layers = %d" % num_uni_layers)
 
@@ -75,62 +77,109 @@ def _build_encoder(self, hparams):
     with tf.variable_scope("encoder") as scope:
       dtype = scope.dtype
 
-      # Look up embedding, emp_inp: [max_time, batch_size, num_units]
-      #   when time_major = True
-      encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder,
-                                               source)
+      self.encoder_emb_inp = self.encoder_emb_lookup_fn(
+          self.embedding_encoder, source)
 
       # Execute _build_bidirectional_rnn from Model class
       bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn(
-          inputs=encoder_emb_inp,
+          inputs=self.encoder_emb_inp,
           sequence_length=iterator.source_sequence_length,
           dtype=dtype,
           hparams=hparams,
           num_bi_layers=num_bi_layers,
           num_bi_residual_layers=0,  # no residual connection
       )
 
-      uni_cell = model_helper.create_rnn_cell(
-          unit_type=hparams.unit_type,
-          num_units=hparams.num_units,
-          num_layers=num_uni_layers,
-          num_residual_layers=self.num_encoder_residual_layers,
-          forget_bias=hparams.forget_bias,
-          dropout=hparams.dropout,
-          num_gpus=self.num_gpus,
-          base_gpu=1,
-          mode=self.mode,
-          single_cell_fn=self.single_cell_fn)
-
-      # encoder_outputs: size [max_time, batch_size, num_units]
-      #   when time_major = True
-      encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
-          uni_cell,
-          bi_encoder_outputs,
-          dtype=dtype,
-          sequence_length=iterator.source_sequence_length,
-          time_major=self.time_major)
+      # Build unidirectional layers
+      if self.extract_encoder_layers:
+        encoder_state, encoder_outputs = self._build_individual_encoder_layers(
+            bi_encoder_outputs, num_uni_layers, dtype, hparams)
+      else:
+        encoder_state, encoder_outputs = self._build_all_encoder_layers(
+            bi_encoder_outputs, num_uni_layers, dtype, hparams)
 
-      # Pass all encoder state except the first bi-directional layer's state to
-      # decoder.
+      # Pass all encoder states to the decoder
+      #   except the first bi-directional layer
       encoder_state = (bi_encoder_state[1],) + (
           (encoder_state,) if num_uni_layers == 1 else encoder_state)
 
       return encoder_outputs, encoder_state
 
+  def _build_all_encoder_layers(self, bi_encoder_outputs,
+                                num_uni_layers, dtype, hparams):
+    """Build encoder layers all at once."""
+    uni_cell = model_helper.create_rnn_cell(
+        unit_type=hparams.unit_type,
+        num_units=hparams.num_units,
+        num_layers=num_uni_layers,
+        num_residual_layers=self.num_encoder_residual_layers,
+        forget_bias=hparams.forget_bias,
+        dropout=hparams.dropout,
+        num_gpus=self.num_gpus,
+        base_gpu=1,
+        mode=self.mode,
+        single_cell_fn=self.single_cell_fn)
+    encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
+        uni_cell,
+        bi_encoder_outputs,
+        dtype=dtype,
+        sequence_length=self.iterator.source_sequence_length,
+        time_major=self.time_major)
+
+    # Use the top layer for now
+    self.encoder_state_list = [encoder_outputs]
+
+    return encoder_state, encoder_outputs
+
+  def _build_individual_encoder_layers(self, bi_encoder_outputs,
+                                       num_uni_layers, dtype, hparams):
+    """Run each of the encoder layer separately, not used in general seq2seq."""
+    uni_cell_lists = model_helper._cell_list(
+        unit_type=hparams.unit_type,
+        num_units=hparams.num_units,
+        num_layers=num_uni_layers,
+        num_residual_layers=self.num_encoder_residual_layers,
+        forget_bias=hparams.forget_bias,
+        dropout=hparams.dropout,
+        num_gpus=self.num_gpus,
+        base_gpu=1,
+        mode=self.mode,
+        single_cell_fn=self.single_cell_fn)
+
+    encoder_inp = bi_encoder_outputs
+    encoder_states = []
+    self.encoder_state_list = [bi_encoder_outputs[:, :, :hparams.num_units],
+                               bi_encoder_outputs[:, :, hparams.num_units:]]
+    with tf.variable_scope("rnn/multi_rnn_cell"):
+      for i, cell in enumerate(uni_cell_lists):
+        with tf.variable_scope("cell_%d" % i) as scope:
+          encoder_inp, encoder_state = tf.nn.dynamic_rnn(
+              cell,
+              encoder_inp,
+              dtype=dtype,
+              sequence_length=self.iterator.source_sequence_length,
+              time_major=self.time_major,
+              scope=scope)
+          encoder_states.append(encoder_state)
+          self.encoder_state_list.append(encoder_inp)
+
+    encoder_state = tuple(encoder_states)
+    encoder_outputs = self.encoder_state_list[-1]
+    return encoder_state, encoder_outputs
+
   def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                           source_sequence_length):
     """Build a RNN cell with GNMT attention architecture."""
     # Standard attention
-    if hparams.attention_architecture == "standard":
+    if not self.is_gnmt_attention:
       return super(GNMTModel, self)._build_decoder_cell(
           hparams, encoder_outputs, encoder_state, source_sequence_length)
 
     # GNMT attention
     attention_option = hparams.attention
     attention_architecture = hparams.attention_architecture
     num_units = hparams.num_units
-    beam_width = hparams.beam_width
+    infer_mode = hparams.infer_mode
 
     dtype = tf.float32
 
@@ -139,14 +188,12 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
     else:
       memory = encoder_outputs
 
-    if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0:
-      memory = tf.contrib.seq2seq.tile_batch(
-          memory, multiplier=beam_width)
-      source_sequence_length = tf.contrib.seq2seq.tile_batch(
-          source_sequence_length, multiplier=beam_width)
-      encoder_state = tf.contrib.seq2seq.tile_batch(
-          encoder_state, multiplier=beam_width)
-      batch_size = self.batch_size * beam_width
+    if (self.mode == tf.contrib.learn.ModeKeys.INFER and
+        infer_mode == "beam_search"):
+      memory, source_sequence_length, encoder_state, batch_size = (
+          self._prepare_beam_search_decoder_inputs(
+              hparams.beam_width, memory, source_sequence_length,
+              encoder_state))
     else:
       batch_size = self.batch_size
 
@@ -171,7 +218,7 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
 
     # Only generate alignment in greedy INFER mode.
     alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and
-                         beam_width == 0)
+                         infer_mode != "beam_search")
     attention_cell = tf.contrib.seq2seq.AttentionWrapper(
         attention_cell,
         attention_mechanism,
@@ -202,15 +249,13 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
     return cell, decoder_initial_state
 
   def _get_infer_summary(self, hparams):
-    # Standard attention
-    if hparams.attention_architecture == "standard":
-      return super(GNMTModel, self)._get_infer_summary(hparams)
-
-    # GNMT attention
-    if hparams.beam_width > 0:
+    if hparams.infer_mode == "beam_search":
       return tf.no_op()
-    return attention_model._create_attention_images_summary(
-        self.final_context_state[0])
+    elif self.is_gnmt_attention:
+      return attention_model._create_attention_images_summary(
+          self.final_context_state[0])
+    else:
+      return super(GNMTModel, self)._get_infer_summary(hparams)
 
 
 class GNMTAttentionMultiCell(tf.nn.rnn_cell.MultiRNNCell):
@@ -231,7 +276,7 @@ def __init__(self, attention_cell, cells, use_new_attention=False):
 
   def __call__(self, inputs, state, scope=None):
     """Run the cell with bottom layer's attention copied to all upper layers."""
-    if not nest.is_sequence(state):
+    if not tf.contrib.framework.nest.is_sequence(state):
       raise ValueError(
           "Expected state to be a tuple of length %d, but received: %s"
           % (len(self.state_size), state))
@@ -277,9 +322,12 @@ def split_input(inp, out):
     out_dim = out.get_shape().as_list()[-1]
     inp_dim = inp.get_shape().as_list()[-1]
     return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1)
-  actual_inputs, _ = nest.map_structure(split_input, inputs, outputs)
+  actual_inputs, _ = tf.contrib.framework.nest.map_structure(
+      split_input, inputs, outputs)
   def assert_shape_match(inp, out):
     inp.get_shape().assert_is_compatible_with(out.get_shape())
-  nest.assert_same_structure(actual_inputs, outputs)
-  nest.map_structure(assert_shape_match, actual_inputs, outputs)
-  return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs)
+  tf.contrib.framework.nest.assert_same_structure(actual_inputs, outputs)
+  tf.contrib.framework.nest.map_structure(
+      assert_shape_match, actual_inputs, outputs)
+  return tf.contrib.framework.nest.map_structure(
+      lambda inp, out: inp + out, actual_inputs, outputs)
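Note on the final gnmt_model.py hunks: the private tensorflow.python.util.nest import is replaced by the public tf.contrib.framework.nest alias. For reference, here is a minimal standalone sketch, not part of this PR, of the residual pattern those map_structure calls implement; it assumes TensorFlow 1.x and plain-tensor inputs and outputs as in the residual helper above, with illustrative shapes only.

import tensorflow as tf

nest = tf.contrib.framework.nest

inputs = tf.random_normal([2, 6])   # cell input, wider than the cell output
outputs = tf.random_normal([2, 4])  # cell output

def split_input(inp, out):
  # Keep only the leading slice of the input that matches the output width.
  out_dim = out.get_shape().as_list()[-1]
  inp_dim = inp.get_shape().as_list()[-1]
  return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1)

# map_structure applies split_input leaf by leaf over arbitrarily nested
# tuples of tensors; for plain tensors it degenerates to a single call.
actual_inputs, _ = nest.map_structure(split_input, inputs, outputs)
nest.assert_same_structure(actual_inputs, outputs)

# The residual connection: add the width-matched input to the output.
residual = nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs)

Using the nest utilities keeps the residual add working even when a cell's inputs and outputs are nested structures rather than single tensors, which is why the model prefers map_structure over a plain addition.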