diff --git a/README.md b/README.md index a16ffbf9b..809db9ab3 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ with the latest research ideas. We achieve this goal by: We believe that it is important to provide benchmarks that people can easily replicate. As a result, we have provided full experimental results and -pretrained on models on the following publicly available datasets: +pretrained our models on the following publicly available datasets: 1. *Small-scale*: English-Vietnamese parallel corpus of TED talks (133K sentence pairs) provided by diff --git a/nmt/attention_model.py b/nmt/attention_model.py index af0ee0b7e..d262b8e8e 100644 --- a/nmt/attention_model.py +++ b/nmt/attention_model.py @@ -44,11 +44,14 @@ def __init__(self, reverse_target_vocab_table=None, scope=None, extra_args=None): + self.has_attention = hparams.attention_architecture and hparams.attention + # Set attention_mechanism_fn - if extra_args and extra_args.attention_mechanism_fn: - self.attention_mechanism_fn = extra_args.attention_mechanism_fn - else: - self.attention_mechanism_fn = create_attention_mechanism + if self.has_attention: + if extra_args and extra_args.attention_mechanism_fn: + self.attention_mechanism_fn = extra_args.attention_mechanism_fn + else: + self.attention_mechanism_fn = create_attention_mechanism super(AttentionModel, self).__init__( hparams=hparams, @@ -60,23 +63,32 @@ def __init__(self, scope=scope, extra_args=extra_args) - if self.mode == tf.contrib.learn.ModeKeys.INFER: - self.infer_summary = self._get_infer_summary(hparams) + def _prepare_beam_search_decoder_inputs( + self, beam_width, memory, source_sequence_length, encoder_state): + memory = tf.contrib.seq2seq.tile_batch( + memory, multiplier=beam_width) + source_sequence_length = tf.contrib.seq2seq.tile_batch( + source_sequence_length, multiplier=beam_width) + encoder_state = tf.contrib.seq2seq.tile_batch( + encoder_state, multiplier=beam_width) + batch_size = self.batch_size * beam_width + return memory, source_sequence_length, encoder_state, batch_size def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, source_sequence_length): """Build a RNN cell with attention mechanism that can be used by decoder.""" - attention_option = hparams.attention - attention_architecture = hparams.attention_architecture - - if attention_architecture != "standard": + # No Attention + if not self.has_attention: + return super(AttentionModel, self)._build_decoder_cell( + hparams, encoder_outputs, encoder_state, source_sequence_length) + elif hparams.attention_architecture != "standard": raise ValueError( - "Unknown attention architecture %s" % attention_architecture) + "Unknown attention architecture %s" % hparams.attention_architecture) num_units = hparams.num_units num_layers = self.num_decoder_layers num_residual_layers = self.num_decoder_residual_layers - beam_width = hparams.beam_width + infer_mode = hparams.infer_mode dtype = tf.float32 @@ -86,19 +98,18 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, else: memory = encoder_outputs - if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0: - memory = tf.contrib.seq2seq.tile_batch( - memory, multiplier=beam_width) - source_sequence_length = tf.contrib.seq2seq.tile_batch( - source_sequence_length, multiplier=beam_width) - encoder_state = tf.contrib.seq2seq.tile_batch( - encoder_state, multiplier=beam_width) - batch_size = self.batch_size * beam_width + if (self.mode == tf.contrib.learn.ModeKeys.INFER and + infer_mode == "beam_search"): + 
memory, source_sequence_length, encoder_state, batch_size = ( + self._prepare_beam_search_decoder_inputs( + hparams.beam_width, memory, source_sequence_length, + encoder_state)) else: batch_size = self.batch_size + # Attention attention_mechanism = self.attention_mechanism_fn( - attention_option, num_units, memory, source_sequence_length, self.mode) + hparams.attention, num_units, memory, source_sequence_length, self.mode) cell = model_helper.create_rnn_cell( unit_type=hparams.unit_type, @@ -113,7 +124,7 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, # Only generate alignment in greedy INFER mode. alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and - beam_width == 0) + infer_mode != "beam_search") cell = tf.contrib.seq2seq.AttentionWrapper( cell, attention_mechanism, @@ -136,7 +147,7 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, return cell, decoder_initial_state def _get_infer_summary(self, hparams): - if hparams.beam_width > 0: + if not self.has_attention or hparams.infer_mode == "beam_search": return tf.no_op() return _create_attention_images_summary(self.final_context_state) diff --git a/nmt/gnmt_model.py b/nmt/gnmt_model.py index 6f1219ec3..468a5d00c 100644 --- a/nmt/gnmt_model.py +++ b/nmt/gnmt_model.py @@ -20,12 +20,10 @@ import tensorflow as tf -# TODO(rzhao): Use tf.contrib.framework.nest once 1.3 is out. -from tensorflow.python.util import nest - from . import attention_model from . import model_helper from .utils import misc_utils as utils +from .utils import vocab_utils __all__ = ["GNMTModel"] @@ -43,6 +41,9 @@ def __init__(self, reverse_target_vocab_table=None, scope=None, extra_args=None): + self.is_gnmt_attention = ( + hparams.attention_architecture in ["gnmt", "gnmt_v2"]) + super(GNMTModel, self).__init__( hparams=hparams, mode=mode, @@ -64,6 +65,7 @@ def _build_encoder(self, hparams): # Build GNMT encoder. 
num_bi_layers = 1 num_uni_layers = self.num_encoder_layers - num_bi_layers + utils.print_out("# Build a GNMT encoder") utils.print_out(" num_bi_layers = %d" % num_bi_layers) utils.print_out(" num_uni_layers = %d" % num_uni_layers) @@ -75,14 +77,12 @@ def _build_encoder(self, hparams): with tf.variable_scope("encoder") as scope: dtype = scope.dtype - # Look up embedding, emp_inp: [max_time, batch_size, num_units] - # when time_major = True - encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder, - source) + self.encoder_emb_inp = self.encoder_emb_lookup_fn( + self.embedding_encoder, source) # Execute _build_bidirectional_rnn from Model class bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn( - inputs=encoder_emb_inp, + inputs=self.encoder_emb_inp, sequence_length=iterator.source_sequence_length, dtype=dtype, hparams=hparams, @@ -90,39 +90,88 @@ def _build_encoder(self, hparams): num_bi_residual_layers=0, # no residual connection ) - uni_cell = model_helper.create_rnn_cell( - unit_type=hparams.unit_type, - num_units=hparams.num_units, - num_layers=num_uni_layers, - num_residual_layers=self.num_encoder_residual_layers, - forget_bias=hparams.forget_bias, - dropout=hparams.dropout, - num_gpus=self.num_gpus, - base_gpu=1, - mode=self.mode, - single_cell_fn=self.single_cell_fn) - - # encoder_outputs: size [max_time, batch_size, num_units] - # when time_major = True - encoder_outputs, encoder_state = tf.nn.dynamic_rnn( - uni_cell, - bi_encoder_outputs, - dtype=dtype, - sequence_length=iterator.source_sequence_length, - time_major=self.time_major) + # Build unidirectional layers + if self.extract_encoder_layers: + encoder_state, encoder_outputs = self._build_individual_encoder_layers( + bi_encoder_outputs, num_uni_layers, dtype, hparams) + else: + encoder_state, encoder_outputs = self._build_all_encoder_layers( + bi_encoder_outputs, num_uni_layers, dtype, hparams) - # Pass all encoder state except the first bi-directional layer's state to - # decoder. 
+ # Pass all encoder states to the decoder + # except the first bi-directional layer encoder_state = (bi_encoder_state[1],) + ( (encoder_state,) if num_uni_layers == 1 else encoder_state) return encoder_outputs, encoder_state + def _build_all_encoder_layers(self, bi_encoder_outputs, + num_uni_layers, dtype, hparams): + """Build encoder layers all at once.""" + uni_cell = model_helper.create_rnn_cell( + unit_type=hparams.unit_type, + num_units=hparams.num_units, + num_layers=num_uni_layers, + num_residual_layers=self.num_encoder_residual_layers, + forget_bias=hparams.forget_bias, + dropout=hparams.dropout, + num_gpus=self.num_gpus, + base_gpu=1, + mode=self.mode, + single_cell_fn=self.single_cell_fn) + encoder_outputs, encoder_state = tf.nn.dynamic_rnn( + uni_cell, + bi_encoder_outputs, + dtype=dtype, + sequence_length=self.iterator.source_sequence_length, + time_major=self.time_major) + + # Use the top layer for now + self.encoder_state_list = [encoder_outputs] + + return encoder_state, encoder_outputs + + def _build_individual_encoder_layers(self, bi_encoder_outputs, + num_uni_layers, dtype, hparams): + """Run each of the encoder layer separately, not used in general seq2seq.""" + uni_cell_lists = model_helper._cell_list( + unit_type=hparams.unit_type, + num_units=hparams.num_units, + num_layers=num_uni_layers, + num_residual_layers=self.num_encoder_residual_layers, + forget_bias=hparams.forget_bias, + dropout=hparams.dropout, + num_gpus=self.num_gpus, + base_gpu=1, + mode=self.mode, + single_cell_fn=self.single_cell_fn) + + encoder_inp = bi_encoder_outputs + encoder_states = [] + self.encoder_state_list = [bi_encoder_outputs[:, :, :hparams.num_units], + bi_encoder_outputs[:, :, hparams.num_units:]] + with tf.variable_scope("rnn/multi_rnn_cell"): + for i, cell in enumerate(uni_cell_lists): + with tf.variable_scope("cell_%d" % i) as scope: + encoder_inp, encoder_state = tf.nn.dynamic_rnn( + cell, + encoder_inp, + dtype=dtype, + sequence_length=self.iterator.source_sequence_length, + time_major=self.time_major, + scope=scope) + encoder_states.append(encoder_state) + self.encoder_state_list.append(encoder_inp) + + encoder_state = tuple(encoder_states) + encoder_outputs = self.encoder_state_list[-1] + return encoder_state, encoder_outputs + def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, source_sequence_length): """Build a RNN cell with GNMT attention architecture.""" # Standard attention - if hparams.attention_architecture == "standard": + if not self.is_gnmt_attention: return super(GNMTModel, self)._build_decoder_cell( hparams, encoder_outputs, encoder_state, source_sequence_length) @@ -130,7 +179,7 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, attention_option = hparams.attention attention_architecture = hparams.attention_architecture num_units = hparams.num_units - beam_width = hparams.beam_width + infer_mode = hparams.infer_mode dtype = tf.float32 @@ -139,14 +188,12 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, else: memory = encoder_outputs - if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0: - memory = tf.contrib.seq2seq.tile_batch( - memory, multiplier=beam_width) - source_sequence_length = tf.contrib.seq2seq.tile_batch( - source_sequence_length, multiplier=beam_width) - encoder_state = tf.contrib.seq2seq.tile_batch( - encoder_state, multiplier=beam_width) - batch_size = self.batch_size * beam_width + if (self.mode == tf.contrib.learn.ModeKeys.INFER and + infer_mode == "beam_search"): + 
memory, source_sequence_length, encoder_state, batch_size = ( + self._prepare_beam_search_decoder_inputs( + hparams.beam_width, memory, source_sequence_length, + encoder_state)) else: batch_size = self.batch_size @@ -171,7 +218,7 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, # Only generate alignment in greedy INFER mode. alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and - beam_width == 0) + infer_mode != "beam_search") attention_cell = tf.contrib.seq2seq.AttentionWrapper( attention_cell, attention_mechanism, @@ -202,15 +249,13 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, return cell, decoder_initial_state def _get_infer_summary(self, hparams): - # Standard attention - if hparams.attention_architecture == "standard": - return super(GNMTModel, self)._get_infer_summary(hparams) - - # GNMT attention - if hparams.beam_width > 0: + if hparams.infer_mode == "beam_search": return tf.no_op() - return attention_model._create_attention_images_summary( - self.final_context_state[0]) + elif self.is_gnmt_attention: + return attention_model._create_attention_images_summary( + self.final_context_state[0]) + else: + return super(GNMTModel, self)._get_infer_summary(hparams) class GNMTAttentionMultiCell(tf.nn.rnn_cell.MultiRNNCell): @@ -231,7 +276,7 @@ def __init__(self, attention_cell, cells, use_new_attention=False): def __call__(self, inputs, state, scope=None): """Run the cell with bottom layer's attention copied to all upper layers.""" - if not nest.is_sequence(state): + if not tf.contrib.framework.nest.is_sequence(state): raise ValueError( "Expected state to be a tuple of length %d, but received: %s" % (len(self.state_size), state)) @@ -277,9 +322,12 @@ def split_input(inp, out): out_dim = out.get_shape().as_list()[-1] inp_dim = inp.get_shape().as_list()[-1] return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1) - actual_inputs, _ = nest.map_structure(split_input, inputs, outputs) + actual_inputs, _ = tf.contrib.framework.nest.map_structure( + split_input, inputs, outputs) def assert_shape_match(inp, out): inp.get_shape().assert_is_compatible_with(out.get_shape()) - nest.assert_same_structure(actual_inputs, outputs) - nest.map_structure(assert_shape_match, actual_inputs, outputs) - return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs) + tf.contrib.framework.nest.assert_same_structure(actual_inputs, outputs) + tf.contrib.framework.nest.map_structure( + assert_shape_match, actual_inputs, outputs) + return tf.contrib.framework.nest.map_structure( + lambda inp, out: inp + out, actual_inputs, outputs) diff --git a/nmt/inference.py b/nmt/inference.py index 6f589337a..2cbef07c2 100644 --- a/nmt/inference.py +++ b/nmt/inference.py @@ -80,7 +80,32 @@ def load_data(inference_input_file, hparams=None): return inference_data -def inference(ckpt, +def get_model_creator(hparams): + """Get the right model class depending on configuration.""" + if (hparams.encoder_type == "gnmt" or + hparams.attention_architecture in ["gnmt", "gnmt_v2"]): + model_creator = gnmt_model.GNMTModel + elif hparams.attention_architecture == "standard": + model_creator = attention_model.AttentionModel + elif not hparams.attention: + model_creator = nmt_model.Model + else: + raise ValueError("Unknown attention architecture %s" % + hparams.attention_architecture) + return model_creator + + +def start_sess_and_load_model(infer_model, ckpt_path): + """Start session and load model.""" + sess = tf.Session( + graph=infer_model.graph, 
config=utils.get_config_proto()) + with infer_model.graph.as_default(): + loaded_infer_model = model_helper.load_model( + infer_model.model, ckpt_path, sess, "infer") + return sess, loaded_infer_model + + +def inference(ckpt_path, inference_input_file, inference_output_file, hparams, @@ -91,36 +116,34 @@ def inference(ckpt, if hparams.inference_indices: assert num_workers == 1 - if not hparams.attention: - model_creator = nmt_model.Model - elif hparams.attention_architecture == "standard": - model_creator = attention_model.AttentionModel - elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]: - model_creator = gnmt_model.GNMTModel - else: - raise ValueError("Unknown model architecture") + model_creator = get_model_creator(hparams) infer_model = model_helper.create_infer_model(model_creator, hparams, scope) + sess, loaded_infer_model = start_sess_and_load_model(infer_model, ckpt_path) if num_workers == 1: single_worker_inference( + sess, infer_model, - ckpt, + loaded_infer_model, inference_input_file, inference_output_file, hparams) else: multi_worker_inference( + sess, infer_model, - ckpt, + loaded_infer_model, inference_input_file, inference_output_file, hparams, num_workers=num_workers, jobid=jobid) + sess.close() -def single_worker_inference(infer_model, - ckpt, +def single_worker_inference(sess, + infer_model, + loaded_infer_model, inference_input_file, inference_output_file, hparams): @@ -130,10 +153,7 @@ def single_worker_inference(infer_model, # Read data infer_data = load_data(inference_input_file, hparams) - with tf.Session( - graph=infer_model.graph, config=utils.get_config_proto()) as sess: - loaded_infer_model = model_helper.load_model( - infer_model.model, ckpt, sess, "infer") + with infer_model.graph.as_default(): sess.run( infer_model.iterator.initializer, feed_dict={ @@ -162,11 +182,13 @@ def single_worker_inference(infer_model, subword_option=hparams.subword_option, beam_width=hparams.beam_width, tgt_eos=hparams.eos, - num_translations_per_input=hparams.num_translations_per_input) + num_translations_per_input=hparams.num_translations_per_input, + infer_mode=hparams.infer_mode) -def multi_worker_inference(infer_model, - ckpt, +def multi_worker_inference(sess, + infer_model, + loaded_infer_model, inference_input_file, inference_output_file, hparams, @@ -189,10 +211,7 @@ def multi_worker_inference(infer_model, end_position = min(start_position + load_per_worker, total_load) infer_data = infer_data[start_position:end_position] - with tf.Session( - graph=infer_model.graph, config=utils.get_config_proto()) as sess: - loaded_infer_model = model_helper.load_model( - infer_model.model, ckpt, sess, "infer") + with infer_model.graph.as_default(): sess.run(infer_model.iterator.initializer, { infer_model.src_placeholder: infer_data, @@ -210,7 +229,8 @@ def multi_worker_inference(infer_model, subword_option=hparams.subword_option, beam_width=hparams.beam_width, tgt_eos=hparams.eos, - num_translations_per_input=hparams.num_translations_per_input) + num_translations_per_input=hparams.num_translations_per_input, + infer_mode=hparams.infer_mode) # Change file name to indicate the file writing is completed. tf.gfile.Rename(output_infer, output_infer_done, overwrite=True) @@ -224,7 +244,7 @@ def multi_worker_inference(infer_model, for worker_id in range(num_workers): worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) while not tf.gfile.Exists(worker_infer_done): - utils.print_out(" waitting job %d to complete." 
% worker_id) + utils.print_out(" waiting job %d to complete." % worker_id) time.sleep(10) with codecs.getreader("utf-8")( diff --git a/nmt/inference_test.py b/nmt/inference_test.py index 7c342f9f9..317024b81 100644 --- a/nmt/inference_test.py +++ b/nmt/inference_test.py @@ -23,11 +23,8 @@ import numpy as np import tensorflow as tf -from . import attention_model -from . import model_helper -from . import model as nmt_model -from . import gnmt_model from . import inference +from . import model_helper from .utils import common_test_utils float32 = np.float32 @@ -37,24 +34,26 @@ class InferenceTest(tf.test.TestCase): - def _createTestInferCheckpoint(self, hparams, out_dir): - if not hparams.attention: - model_creator = nmt_model.Model - elif hparams.attention_architecture == "standard": - model_creator = attention_model.AttentionModel - elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]: - model_creator = gnmt_model.GNMTModel - else: - raise ValueError("Unknown model architecture") + def _createTestInferCheckpoint(self, hparams, name): + # Prepare + hparams.vocab_prefix = ( + "nmt/testdata/test_infer_vocab") + hparams.src_vocab_file = hparams.vocab_prefix + "." + hparams.src + hparams.tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt + out_dir = os.path.join(tf.test.get_temp_dir(), name) + os.makedirs(out_dir) + hparams.out_dir = out_dir + # Create check point + model_creator = inference.get_model_creator(hparams) infer_model = model_helper.create_infer_model(model_creator, hparams) with self.test_session(graph=infer_model.graph) as sess: loaded_model, global_step = model_helper.create_or_load_model( infer_model.model, out_dir, sess, "infer_name") - ckpt = loaded_model.saver.save( + ckpt_path = loaded_model.saver.save( sess, os.path.join(out_dir, "translate.ckpt"), global_step=global_step) - return ckpt + return ckpt_path def testBasicModel(self): hparams = common_test_utils.create_test_hparams( @@ -63,17 +62,10 @@ def testBasicModel(self): attention="", attention_architecture="", use_residual=False,) - vocab_prefix = "nmt/testdata/test_infer_vocab" - hparams.src_vocab_file = vocab_prefix + "." + hparams.src - hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt - + ckpt_path = self._createTestInferCheckpoint(hparams, "basic_infer") infer_file = "nmt/testdata/test_infer_file" - out_dir = os.path.join(tf.test.get_temp_dir(), "basic_infer") - hparams.out_dir = out_dir - os.makedirs(out_dir) - output_infer = os.path.join(out_dir, "output_infer") - ckpt = self._createTestInferCheckpoint(hparams, out_dir) - inference.inference(ckpt, infer_file, output_infer, hparams) + output_infer = os.path.join(hparams.out_dir, "output_infer") + inference.inference(ckpt_path, infer_file, output_infer, hparams) with open(output_infer) as f: self.assertEqual(5, len(list(f))) @@ -87,17 +79,12 @@ def testBasicModelWithMultipleTranslations(self): num_translations_per_input=2, beam_width=2, ) - vocab_prefix = "nmt/testdata/test_infer_vocab" - hparams.src_vocab_file = vocab_prefix + "." + hparams.src - hparams.tgt_vocab_file = vocab_prefix + "." 
+ hparams.tgt + hparams.infer_mode = "beam_search" + ckpt_path = self._createTestInferCheckpoint(hparams, "multi_basic_infer") infer_file = "nmt/testdata/test_infer_file" - out_dir = os.path.join(tf.test.get_temp_dir(), "multi_basic_infer") - hparams.out_dir = out_dir - os.makedirs(out_dir) - output_infer = os.path.join(out_dir, "output_infer") - ckpt = self._createTestInferCheckpoint(hparams, out_dir) - inference.inference(ckpt, infer_file, output_infer, hparams) + output_infer = os.path.join(hparams.out_dir, "output_infer") + inference.inference(ckpt_path, infer_file, output_infer, hparams) with open(output_infer) as f: self.assertEqual(10, len(list(f))) @@ -108,17 +95,10 @@ def testAttentionModel(self): attention="scaled_luong", attention_architecture="standard", use_residual=False,) - vocab_prefix = "nmt/testdata/test_infer_vocab" - hparams.src_vocab_file = vocab_prefix + "." + hparams.src - hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt - + ckpt_path = self._createTestInferCheckpoint(hparams, "attention_infer") infer_file = "nmt/testdata/test_infer_file" - out_dir = os.path.join(tf.test.get_temp_dir(), "attention_infer") - hparams.out_dir = out_dir - os.makedirs(out_dir) - output_infer = os.path.join(out_dir, "output_infer") - ckpt = self._createTestInferCheckpoint(hparams, out_dir) - inference.inference(ckpt, infer_file, output_infer, hparams) + output_infer = os.path.join(hparams.out_dir, "output_infer") + inference.inference(ckpt_path, infer_file, output_infer, hparams) with open(output_infer) as f: self.assertEqual(5, len(list(f))) @@ -129,15 +109,6 @@ def testMultiWorkers(self): attention="scaled_luong", attention_architecture="standard", use_residual=False,) - vocab_prefix = "nmt/testdata/test_infer_vocab" - hparams.src_vocab_file = vocab_prefix + "." + hparams.src - hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt - - infer_file = "nmt/testdata/test_infer_file" - out_dir = os.path.join(tf.test.get_temp_dir(), "multi_worker_infer") - hparams.out_dir = out_dir - os.makedirs(out_dir) - output_infer = os.path.join(out_dir, "output_infer") num_workers = 3 @@ -146,17 +117,19 @@ def testMultiWorkers(self): # cases. hparams.batch_size = 3 - ckpt = self._createTestInferCheckpoint(hparams, out_dir) + ckpt_path = self._createTestInferCheckpoint(hparams, "multi_worker_infer") + infer_file = "nmt/testdata/test_infer_file" + output_infer = os.path.join(hparams.out_dir, "output_infer") inference.inference( - ckpt, infer_file, output_infer, hparams, num_workers, jobid=1) + ckpt_path, infer_file, output_infer, hparams, num_workers, jobid=1) inference.inference( - ckpt, infer_file, output_infer, hparams, num_workers, jobid=2) + ckpt_path, infer_file, output_infer, hparams, num_workers, jobid=2) # Note: Need to start job 0 at the end; otherwise, it will block the testing # thread. inference.inference( - ckpt, infer_file, output_infer, hparams, num_workers, jobid=0) + ckpt_path, infer_file, output_infer, hparams, num_workers, jobid=0) with open(output_infer) as f: self.assertEqual(5, len(list(f))) @@ -169,17 +142,11 @@ def testBasicModelWithInferIndices(self): attention_architecture="", use_residual=False, inference_indices=[0]) - vocab_prefix = "nmt/testdata/test_infer_vocab" - hparams.src_vocab_file = vocab_prefix + "." + hparams.src - hparams.tgt_vocab_file = vocab_prefix + "." 
+ hparams.tgt - + ckpt_path = self._createTestInferCheckpoint(hparams, + "basic_infer_with_indices") infer_file = "nmt/testdata/test_infer_file" - out_dir = os.path.join(tf.test.get_temp_dir(), "basic_infer_with_indices") - hparams.out_dir = out_dir - os.makedirs(out_dir) - output_infer = os.path.join(out_dir, "output_infer") - ckpt = self._createTestInferCheckpoint(hparams, out_dir) - inference.inference(ckpt, infer_file, output_infer, hparams) + output_infer = os.path.join(hparams.out_dir, "output_infer") + inference.inference(ckpt_path, infer_file, output_infer, hparams) with open(output_infer) as f: self.assertEqual(1, len(list(f))) @@ -193,18 +160,11 @@ def testAttentionModelWithInferIndices(self): inference_indices=[1, 2]) # TODO(rzhao): Make infer indices support batch_size > 1. hparams.infer_batch_size = 1 - vocab_prefix = "nmt/testdata/test_infer_vocab" - hparams.src_vocab_file = vocab_prefix + "." + hparams.src - hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt - + ckpt_path = self._createTestInferCheckpoint(hparams, + "attention_infer_with_indices") infer_file = "nmt/testdata/test_infer_file" - out_dir = os.path.join(tf.test.get_temp_dir(), - "attention_infer_with_indices") - hparams.out_dir = out_dir - os.makedirs(out_dir) - output_infer = os.path.join(out_dir, "output_infer") - ckpt = self._createTestInferCheckpoint(hparams, out_dir) - inference.inference(ckpt, infer_file, output_infer, hparams) + output_infer = os.path.join(hparams.out_dir, "output_infer") + inference.inference(ckpt_path, infer_file, output_infer, hparams) with open(output_infer) as f: self.assertEqual(2, len(list(f))) self.assertTrue(os.path.exists(output_infer+str(1)+".png")) diff --git a/nmt/model.py b/nmt/model.py index fb789815a..e0c4f4e03 100644 --- a/nmt/model.py +++ b/nmt/model.py @@ -19,20 +19,42 @@ from __future__ import print_function import abc +import collections +import numpy as np import tensorflow as tf -from tensorflow.python.layers import core as layers_core - from . import model_helper from .utils import iterator_utils from .utils import misc_utils as utils +from .utils import vocab_utils utils.check_tensorflow_version() __all__ = ["BaseModel", "Model"] +class TrainOutputTuple(collections.namedtuple( + "TrainOutputTuple", ("train_summary", "train_loss", "predict_count", + "global_step", "word_count", "batch_size", "grad_norm", + "learning_rate"))): + """To allow for flexibily in returing different outputs.""" + pass + + +class EvalOutputTuple(collections.namedtuple( + "EvalOutputTuple", ("eval_loss", "predict_count", "batch_size"))): + """To allow for flexibily in returing different outputs.""" + pass + + +class InferOutputTuple(collections.namedtuple( + "InferOutputTuple", ("infer_logits", "infer_summary", "sample_id", + "sample_words"))): + """To allow for flexibily in returing different outputs.""" + pass + + class BaseModel(object): """Sequence-to-sequence base class. """ @@ -60,6 +82,33 @@ def __init__(self, extra_args: model_helper.ExtraArgs, for passing customizable functions. 
""" + # Set params + self._set_params_initializer(hparams, mode, iterator, + source_vocab_table, target_vocab_table, + scope, extra_args) + + # Not used in general seq2seq models; when True, ignore decoder & training + self.extract_encoder_layers = (hasattr(hparams, "extract_encoder_layers") + and hparams.extract_encoder_layers) + + # Train graph + res = self.build_graph(hparams, scope=scope) + if not self.extract_encoder_layers: + self._set_train_or_infer(res, reverse_target_vocab_table, hparams) + + # Saver + self.saver = tf.train.Saver( + tf.global_variables(), max_to_keep=hparams.num_keep_ckpts) + + def _set_params_initializer(self, + hparams, + mode, + iterator, + source_vocab_table, + target_vocab_table, + scope, + extra_args=None): + """Set various params for self and initialize.""" assert isinstance(iterator, iterator_utils.BatchedInput) self.iterator = iterator self.mode = mode @@ -71,11 +120,21 @@ def __init__(self, self.num_gpus = hparams.num_gpus self.time_major = hparams.time_major + if hparams.use_char_encode: + assert (not self.time_major), ("Can't use time major for" + " char-level inputs.") + + self.dtype = tf.float32 + self.num_sampled_softmax = hparams.num_sampled_softmax + # extra_args: to make it flexible for adding external customizable code self.single_cell_fn = None if extra_args: self.single_cell_fn = extra_args.single_cell_fn + # Set num units + self.num_units = hparams.num_units + # Set num layers self.num_encoder_layers = hparams.num_encoder_layers self.num_decoder_layers = hparams.num_decoder_layers @@ -90,24 +149,27 @@ def __init__(self, self.num_encoder_residual_layers = hparams.num_encoder_residual_layers self.num_decoder_residual_layers = hparams.num_decoder_residual_layers + # Batch size + self.batch_size = tf.size(self.iterator.source_sequence_length) + + # Global step + self.global_step = tf.Variable(0, trainable=False) + # Initializer + self.random_seed = hparams.random_seed initializer = model_helper.get_initializer( - hparams.init_op, hparams.random_seed, hparams.init_weight) + hparams.init_op, self.random_seed, hparams.init_weight) tf.get_variable_scope().set_initializer(initializer) # Embeddings + if extra_args and extra_args.encoder_emb_lookup_fn: + self.encoder_emb_lookup_fn = extra_args.encoder_emb_lookup_fn + else: + self.encoder_emb_lookup_fn = tf.nn.embedding_lookup self.init_embeddings(hparams, scope) - self.batch_size = tf.size(self.iterator.source_sequence_length) - - # Projection - with tf.variable_scope(scope or "build_network"): - with tf.variable_scope("decoder/output_projection"): - self.output_layer = layers_core.Dense( - hparams.tgt_vocab_size, use_bias=False, name="output_projection") - - ## Train graph - res = self.build_graph(hparams, scope=scope) + def _set_train_or_infer(self, res, reverse_target_vocab_table, hparams): + """Set up training and inference.""" if self.mode == tf.contrib.learn.ModeKeys.TRAIN: self.train_loss = res[1] self.word_count = tf.reduce_sum( @@ -125,11 +187,10 @@ def __init__(self, self.predict_count = tf.reduce_sum( self.iterator.target_sequence_length) - self.global_step = tf.Variable(0, trainable=False) params = tf.trainable_variables() # Gradients and SGD update operation for training the model. - # Arrage for the embedding vars to appear at the beginning. + # Arrange for the embedding vars to appear at the beginning. 
if self.mode == tf.contrib.learn.ModeKeys.TRAIN: self.learning_rate = tf.constant(hparams.learning_rate) # warm-up @@ -140,9 +201,10 @@ def __init__(self, # Optimizer if hparams.optimizer == "sgd": opt = tf.train.GradientDescentOptimizer(self.learning_rate) - tf.summary.scalar("lr", self.learning_rate) elif hparams.optimizer == "adam": opt = tf.train.AdamOptimizer(self.learning_rate) + else: + raise ValueError("Unknown optimizer type %s" % hparams.optimizer) # Gradients gradients = tf.gradients( @@ -152,26 +214,20 @@ def __init__(self, clipped_grads, grad_norm_summary, grad_norm = model_helper.gradient_clip( gradients, max_gradient_norm=hparams.max_gradient_norm) + self.grad_norm_summary = grad_norm_summary self.grad_norm = grad_norm self.update = opt.apply_gradients( zip(clipped_grads, params), global_step=self.global_step) # Summary - self.train_summary = tf.summary.merge([ - tf.summary.scalar("lr", self.learning_rate), - tf.summary.scalar("train_loss", self.train_loss), - ] + grad_norm_summary) - - if self.mode == tf.contrib.learn.ModeKeys.INFER: + self.train_summary = self._get_train_summary() + elif self.mode == tf.contrib.learn.ModeKeys.INFER: self.infer_summary = self._get_infer_summary(hparams) - # Saver - self.saver = tf.train.Saver( - tf.global_variables(), max_to_keep=hparams.num_keep_ckpts) - # Print trainable variables utils.print_out("# Trainable variables") + utils.print_out("Format: , , <(soft) device placement>") for param in params: utils.print_out(" %s, %s, %s" % (param.name, str(param.get_shape()), param.op.device)) @@ -201,8 +257,8 @@ def _get_learning_rate_warmup(self, hparams): lambda: self.learning_rate, name="learning_rate_warump_cond") - def _get_learning_rate_decay(self, hparams): - """Get learning rate decay.""" + def _get_decay_info(self, hparams): + """Return decay info based on decay_scheme.""" if hparams.decay_scheme in ["luong5", "luong10", "luong234"]: decay_factor = 0.5 if hparams.decay_scheme == "luong5": @@ -222,6 +278,11 @@ def _get_learning_rate_decay(self, hparams): decay_factor = 1.0 elif hparams.decay_scheme: raise ValueError("Unknown decay scheme %s" % hparams.decay_scheme) + return start_decay_step, decay_steps, decay_factor + + def _get_learning_rate_decay(self, hparams): + """Get learning rate decay.""" + start_decay_step, decay_steps, decay_factor = self._get_decay_info(hparams) utils.print_out(" decay_scheme=%s, start_decay_step=%d, decay_steps %d, " "decay_factor %g" % (hparams.decay_scheme, start_decay_step, @@ -244,32 +305,45 @@ def init_embeddings(self, hparams, scope): share_vocab=hparams.share_vocab, src_vocab_size=self.src_vocab_size, tgt_vocab_size=self.tgt_vocab_size, - src_embed_size=hparams.num_units, - tgt_embed_size=hparams.num_units, - num_partitions=hparams.num_embeddings_partitions, + src_embed_size=self.num_units, + tgt_embed_size=self.num_units, + num_enc_partitions=hparams.num_enc_emb_partitions, + num_dec_partitions=hparams.num_dec_emb_partitions, src_vocab_file=hparams.src_vocab_file, tgt_vocab_file=hparams.tgt_vocab_file, src_embed_file=hparams.src_embed_file, tgt_embed_file=hparams.tgt_embed_file, + use_char_encode=hparams.use_char_encode, scope=scope,)) + def _get_train_summary(self): + """Get train summary.""" + train_summary = tf.summary.merge( + [tf.summary.scalar("lr", self.learning_rate), + tf.summary.scalar("train_loss", self.train_loss)] + + self.grad_norm_summary) + return train_summary + def train(self, sess): + """Execute train graph.""" assert self.mode == tf.contrib.learn.ModeKeys.TRAIN - return 
sess.run([self.update, - self.train_loss, - self.predict_count, - self.train_summary, - self.global_step, - self.word_count, - self.batch_size, - self.grad_norm, - self.learning_rate]) + output_tuple = TrainOutputTuple(train_summary=self.train_summary, + train_loss=self.train_loss, + predict_count=self.predict_count, + global_step=self.global_step, + word_count=self.word_count, + batch_size=self.batch_size, + grad_norm=self.grad_norm, + learning_rate=self.learning_rate) + return sess.run([self.update, output_tuple]) def eval(self, sess): + """Execute eval graph.""" assert self.mode == tf.contrib.learn.ModeKeys.EVAL - return sess.run([self.eval_loss, - self.predict_count, - self.batch_size]) + output_tuple = EvalOutputTuple(eval_loss=self.eval_loss, + predict_count=self.predict_count, + batch_size=self.batch_size) + return sess.run(output_tuple) def build_graph(self, hparams, scope=None): """Subclass must implement this method. @@ -280,35 +354,51 @@ def build_graph(self, hparams, scope=None): scope: VariableScope for the created subgraph; default "dynamic_seq2seq". Returns: - A tuple of the form (logits, loss, final_context_state), + A tuple of the form (logits, loss_tuple, final_context_state, sample_id), where: logits: float32 Tensor [batch_size x num_decoder_symbols]. - loss: the total loss / batch_size. - final_context_state: The final state of decoder RNN. + loss: loss = the total loss / batch_size. + final_context_state: the final state of decoder RNN. + sample_id: sampling indices. Raises: ValueError: if encoder_type differs from mono and bi, or attention_option is not (luong | scaled_luong | bahdanau | normed_bahdanau). """ - utils.print_out("# creating %s graph ..." % self.mode) - dtype = tf.float32 + utils.print_out("# Creating %s graph ..." 
% self.mode) + + # Projection + if not self.extract_encoder_layers: + with tf.variable_scope(scope or "build_network"): + with tf.variable_scope("decoder/output_projection"): + self.output_layer = tf.layers.Dense( + self.tgt_vocab_size, use_bias=False, name="output_projection") - with tf.variable_scope(scope or "dynamic_seq2seq", dtype=dtype): + with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype): # Encoder - encoder_outputs, encoder_state = self._build_encoder(hparams) + if hparams.language_model: # no encoder for language modeling + utils.print_out(" language modeling: no encoder") + self.encoder_outputs = None + encoder_state = None + else: + self.encoder_outputs, encoder_state = self._build_encoder(hparams) + + # Skip decoder if extracting only encoder layers + if self.extract_encoder_layers: + return ## Decoder - logits, sample_id, final_context_state = self._build_decoder( - encoder_outputs, encoder_state, hparams) + logits, decoder_cell_outputs, sample_id, final_context_state = ( + self._build_decoder(self.encoder_outputs, encoder_state, hparams)) ## Loss if self.mode != tf.contrib.learn.ModeKeys.INFER: with tf.device(model_helper.get_device_str(self.num_encoder_layers - 1, self.num_gpus)): - loss = self._compute_loss(logits) + loss = self._compute_loss(logits, decoder_cell_outputs) else: - loss = None + loss = tf.constant(0.0) return logits, loss, final_context_state, sample_id @@ -332,7 +422,7 @@ def _build_encoder_cell(self, hparams, num_layers, num_residual_layers, return model_helper.create_rnn_cell( unit_type=hparams.unit_type, - num_units=hparams.num_units, + num_units=self.num_units, num_layers=num_layers, num_residual_layers=num_residual_layers, forget_bias=hparams.forget_bias, @@ -383,6 +473,11 @@ def _build_decoder(self, encoder_outputs, encoder_state, hparams): hparams, encoder_outputs, encoder_state, iterator.source_sequence_length) + # Optional ops depends on which mode we are in and which loss function we + # are using. + logits = tf.no_op() + decoder_cell_outputs = None + ## Train or eval if self.mode != tf.contrib.learn.ModeKeys.INFER: # decoder_emp_inp: [max_time, batch_size, num_units] @@ -412,22 +507,40 @@ def _build_decoder(self, encoder_outputs, encoder_state, hparams): sample_id = outputs.sample_id + if self.num_sampled_softmax > 0: + # Note: this is required when using sampled_softmax_loss. + decoder_cell_outputs = outputs.rnn_output + # Note: there's a subtle difference here between train and inference. # We could have set output_layer when create my_decoder # and shared more code between train and inference. # We chose to apply the output_layer to all timesteps for speed: # 10% improvements for small models & 20% for larger ones. # If memory is a concern, we should apply output_layer per timestep. - logits = self.output_layer(outputs.rnn_output) + num_layers = self.num_decoder_layers + num_gpus = self.num_gpus + device_id = num_layers if num_layers < num_gpus else (num_layers - 1) + # Colocate output layer with the last RNN cell if there is no extra GPU + # available. Otherwise, put last layer on a separate GPU. + with tf.device(model_helper.get_device_str(device_id, num_gpus)): + logits = self.output_layer(outputs.rnn_output) + + if self.num_sampled_softmax > 0: + logits = tf.no_op() # unused when using sampled softmax loss. 
## Inference else: - beam_width = hparams.beam_width - length_penalty_weight = hparams.length_penalty_weight + infer_mode = hparams.infer_mode start_tokens = tf.fill([self.batch_size], tgt_sos_id) end_token = tgt_eos_id + utils.print_out( + " decoder: infer_mode=%sbeam_width=%d, length_penalty=%f" % ( + infer_mode, hparams.beam_width, hparams.length_penalty_weight)) + + if infer_mode == "beam_search": + beam_width = hparams.beam_width + length_penalty_weight = hparams.length_penalty_weight - if beam_width > 0: my_decoder = tf.contrib.seq2seq.BeamSearchDecoder( cell=cell, embedding=self.embedding_decoder, @@ -437,19 +550,23 @@ def _build_decoder(self, encoder_outputs, encoder_state, hparams): beam_width=beam_width, output_layer=self.output_layer, length_penalty_weight=length_penalty_weight) - else: + elif infer_mode == "sample": # Helper sampling_temperature = hparams.sampling_temperature - if sampling_temperature > 0.0: - helper = tf.contrib.seq2seq.SampleEmbeddingHelper( - self.embedding_decoder, start_tokens, end_token, - softmax_temperature=sampling_temperature, - seed=hparams.random_seed) - else: - helper = tf.contrib.seq2seq.GreedyEmbeddingHelper( - self.embedding_decoder, start_tokens, end_token) - - # Decoder + assert sampling_temperature > 0.0, ( + "sampling_temperature must greater than 0.0 when using sample" + " decoder.") + helper = tf.contrib.seq2seq.SampleEmbeddingHelper( + self.embedding_decoder, start_tokens, end_token, + softmax_temperature=sampling_temperature, + seed=self.random_seed) + elif infer_mode == "greedy": + helper = tf.contrib.seq2seq.GreedyEmbeddingHelper( + self.embedding_decoder, start_tokens, end_token) + else: + raise ValueError("Unknown infer_mode '%s'", infer_mode) + + if infer_mode != "beam_search": my_decoder = tf.contrib.seq2seq.BasicDecoder( cell, helper, @@ -465,14 +582,13 @@ def _build_decoder(self, encoder_outputs, encoder_state, hparams): swap_memory=True, scope=decoder_scope) - if beam_width > 0: - logits = tf.no_op() + if infer_mode == "beam_search": sample_id = outputs.predicted_ids else: logits = outputs.rnn_output sample_id = outputs.sample_id - return logits, sample_id, final_context_state + return logits, decoder_cell_outputs, sample_id, final_context_state def get_max_time(self, tensor): time_axis = 0 if self.time_major else 1 @@ -490,21 +606,56 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, source_sequence_length: sequence length of encoder_outputs. Returns: - A tuple of a multi-layer RNN cell used by decoder - and the intial state of the decoder RNN. + A tuple of a multi-layer RNN cell used by decoder and the intial state of + the decoder RNN. 
""" pass - def _compute_loss(self, logits): + def _softmax_cross_entropy_loss( + self, logits, decoder_cell_outputs, labels): + """Compute softmax loss or sampled softmax loss.""" + if self.num_sampled_softmax > 0: + + is_sequence = (decoder_cell_outputs.shape.ndims == 3) + + if is_sequence: + labels = tf.reshape(labels, [-1, 1]) + inputs = tf.reshape(decoder_cell_outputs, [-1, self.num_units]) + + crossent = tf.nn.sampled_softmax_loss( + weights=tf.transpose(self.output_layer.kernel), + biases=self.output_layer.bias or tf.zeros([self.tgt_vocab_size]), + labels=labels, + inputs=inputs, + num_sampled=self.num_sampled_softmax, + num_classes=self.tgt_vocab_size, + partition_strategy="div", + seed=self.random_seed) + + if is_sequence: + if self.time_major: + crossent = tf.reshape(crossent, [-1, self.batch_size]) + else: + crossent = tf.reshape(crossent, [self.batch_size, -1]) + + else: + crossent = tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + + return crossent + + def _compute_loss(self, logits, decoder_cell_outputs): """Compute optimization loss.""" target_output = self.iterator.target_output if self.time_major: target_output = tf.transpose(target_output) max_time = self.get_max_time(target_output) - crossent = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=target_output, logits=logits) + + crossent = self._softmax_cross_entropy_loss( + logits, decoder_cell_outputs, target_output) + target_weights = tf.sequence_mask( - self.iterator.target_sequence_length, max_time, dtype=logits.dtype) + self.iterator.target_sequence_length, max_time, dtype=self.dtype) if self.time_major: target_weights = tf.transpose(target_weights) @@ -513,13 +664,16 @@ def _compute_loss(self, logits): return loss def _get_infer_summary(self, hparams): + del hparams return tf.no_op() def infer(self, sess): assert self.mode == tf.contrib.learn.ModeKeys.INFER - return sess.run([ - self.infer_logits, self.infer_summary, self.sample_id, self.sample_words - ]) + output_tuple = InferOutputTuple(infer_logits=self.infer_logits, + infer_summary=self.infer_summary, + sample_id=self.sample_id, + sample_words=self.sample_words) + return sess.run(output_tuple) def decode(self, sess): """Decode a batch. @@ -531,17 +685,34 @@ def decode(self, sess): A tuple consiting of outputs, infer_summary. outputs: of size [batch_size, time] """ - _, infer_summary, _, sample_words = self.infer(sess) + output_tuple = self.infer(sess) + sample_words = output_tuple.sample_words + infer_summary = output_tuple.infer_summary # make sure outputs is of shape [batch_size, time] or [beam_width, # batch_size, time] when using beam search. if self.time_major: sample_words = sample_words.transpose() - elif sample_words.ndim == 3: # beam search output in [batch_size, - # time, beam_width] shape. + elif sample_words.ndim == 3: + # beam search output in [batch_size, time, beam_width] shape. sample_words = sample_words.transpose([2, 0, 1]) return sample_words, infer_summary + def build_encoder_states(self, include_embeddings=False): + """Stack encoder states and return tensor [batch, length, layer, size].""" + assert self.mode == tf.contrib.learn.ModeKeys.INFER + if include_embeddings: + stack_state_list = tf.stack( + [self.encoder_emb_inp] + self.encoder_state_list, 2) + else: + stack_state_list = tf.stack(self.encoder_state_list, 2) + + # transform from [length, batch, ...] -> [batch, length, ...] 
+ if self.time_major: + stack_state_list = tf.transpose(stack_state_list, [1, 0, 2, 3]) + + return stack_state_list + class Model(BaseModel): """Sequence-to-sequence dynamic model. @@ -549,35 +720,45 @@ class Model(BaseModel): This class implements a multi-layer recurrent neural network as encoder, and a multi-layer recurrent neural network decoder. """ + def _build_encoder_from_sequence(self, hparams, sequence, sequence_length): + """Build an encoder from a sequence. - def _build_encoder(self, hparams): - """Build an encoder.""" + Args: + hparams: hyperparameters. + sequence: tensor with input sequence data. + sequence_length: tensor with length of the input sequence. + + Returns: + encoder_outputs: RNN encoder outputs. + encoder_state: RNN encoder state. + + Raises: + ValueError: if encoder_type is neither "uni" nor "bi". + """ num_layers = self.num_encoder_layers num_residual_layers = self.num_encoder_residual_layers - iterator = self.iterator - source = iterator.source if self.time_major: - source = tf.transpose(source) + sequence = tf.transpose(sequence) with tf.variable_scope("encoder") as scope: dtype = scope.dtype - # Look up embedding, emp_inp: [max_time, batch_size, num_units] - encoder_emb_inp = tf.nn.embedding_lookup( - self.embedding_encoder, source) + + self.encoder_emb_inp = self.encoder_emb_lookup_fn( + self.embedding_encoder, sequence) # Encoder_outputs: [max_time, batch_size, num_units] if hparams.encoder_type == "uni": utils.print_out(" num_layers = %d, num_residual_layers=%d" % (num_layers, num_residual_layers)) - cell = self._build_encoder_cell( - hparams, num_layers, num_residual_layers) + cell = self._build_encoder_cell(hparams, num_layers, + num_residual_layers) encoder_outputs, encoder_state = tf.nn.dynamic_rnn( cell, - encoder_emb_inp, + self.encoder_emb_inp, dtype=dtype, - sequence_length=iterator.source_sequence_length, + sequence_length=sequence_length, time_major=self.time_major, swap_memory=True) elif hparams.encoder_type == "bi": @@ -588,8 +769,8 @@ def _build_encoder(self, hparams): encoder_outputs, bi_encoder_state = ( self._build_bidirectional_rnn( - inputs=encoder_emb_inp, - sequence_length=iterator.source_sequence_length, + inputs=self.encoder_emb_inp, + sequence_length=sequence_length, dtype=dtype, hparams=hparams, num_bi_layers=num_bi_layers, @@ -606,8 +787,18 @@ def _build_encoder(self, hparams): encoder_state = tuple(encoder_state) else: raise ValueError("Unknown encoder_type %s" % hparams.encoder_type) + + # Use the top layer for now + self.encoder_state_list = [encoder_outputs] + return encoder_outputs, encoder_state + def _build_encoder(self, hparams): + """Build encoder from source.""" + utils.print_out("# Build a basic encoder") + return self._build_encoder_from_sequence( + hparams, self.iterator.source, self.iterator.source_sequence_length) + def _build_bidirectional_rnn(self, inputs, sequence_length, dtype, hparams, num_bi_layers, @@ -650,7 +841,7 @@ def _build_bidirectional_rnn(self, inputs, sequence_length, return tf.concat(bi_outputs, -1), bi_state def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, - source_sequence_length): + source_sequence_length, base_gpu=0): """Build an RNN cell that can be used by decoder.""" # We only make use of encoder_outputs in attention-based models if hparams.attention: @@ -658,17 +849,26 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, cell = model_helper.create_rnn_cell( unit_type=hparams.unit_type, - num_units=hparams.num_units, + num_units=self.num_units, 
num_layers=self.num_decoder_layers, num_residual_layers=self.num_decoder_residual_layers, forget_bias=hparams.forget_bias, dropout=hparams.dropout, num_gpus=self.num_gpus, mode=self.mode, - single_cell_fn=self.single_cell_fn) + single_cell_fn=self.single_cell_fn, + base_gpu=base_gpu + ) + + if hparams.language_model: + encoder_state = cell.zero_state(self.batch_size, self.dtype) + elif not hparams.pass_hidden_state: + raise ValueError("For non-attentional model, " + "pass_hidden_state needs to be set to True") # For beam search, we need to replicate encoder infos beam_width times - if self.mode == tf.contrib.learn.ModeKeys.INFER and hparams.beam_width > 0: + if (self.mode == tf.contrib.learn.ModeKeys.INFER and + hparams.infer_mode == "beam_search"): decoder_initial_state = tf.contrib.seq2seq.tile_batch( encoder_state, multiplier=hparams.beam_width) else: diff --git a/nmt/model_helper.py b/nmt/model_helper.py index efe18c1e8..65e111414 100644 --- a/nmt/model_helper.py +++ b/nmt/model_helper.py @@ -1,21 +1,33 @@ +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + """Utility functions for building models.""" from __future__ import print_function import collections -import six import os import time - import numpy as np +import six import tensorflow as tf from tensorflow.python.ops import lookup_ops - from .utils import iterator_utils from .utils import misc_utils as utils from .utils import vocab_utils - __all__ = [ "get_initializer", "get_device_str", "create_train_model", "create_eval_model", "create_infer_model", @@ -54,7 +66,7 @@ def get_device_str(device_id, num_gpus): class ExtraArgs(collections.namedtuple( "ExtraArgs", ("single_cell_fn", "model_device_fn", - "attention_mechanism_fn"))): + "attention_mechanism_fn", "encoder_emb_lookup_fn"))): pass @@ -79,8 +91,8 @@ def create_train_model( src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( src_vocab_file, tgt_vocab_file, hparams.share_vocab) - src_dataset = tf.data.TextLineDataset(src_file) - tgt_dataset = tf.data.TextLineDataset(tgt_file) + src_dataset = tf.data.TextLineDataset(tf.gfile.Glob(src_file)) + tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(tgt_file)) skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64) iterator = iterator_utils.get_iterator( @@ -97,7 +109,8 @@ def create_train_model( tgt_max_len=hparams.tgt_max_len, skip_count=skip_count_placeholder, num_shards=num_workers, - shard_index=jobid) + shard_index=jobid, + use_char_encode=hparams.use_char_encode) # Note: One can set model_device_fn to # `tf.train.replica_device_setter(ps_tasks)` for distributed training. 
@@ -136,6 +149,9 @@ def create_eval_model(model_creator, hparams, scope=None, extra_args=None): with graph.as_default(), tf.container(scope or "eval"): src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( src_vocab_file, tgt_vocab_file, hparams.share_vocab) + reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file( + tgt_vocab_file, default_value=vocab_utils.UNK) + src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) src_dataset = tf.data.TextLineDataset(src_file_placeholder) @@ -151,13 +167,15 @@ def create_eval_model(model_creator, hparams, scope=None, extra_args=None): random_seed=hparams.random_seed, num_buckets=hparams.num_buckets, src_max_len=hparams.src_max_len_infer, - tgt_max_len=hparams.tgt_max_len_infer) + tgt_max_len=hparams.tgt_max_len_infer, + use_char_encode=hparams.use_char_encode) model = model_creator( hparams, iterator=iterator, mode=tf.contrib.learn.ModeKeys.EVAL, source_vocab_table=src_vocab_table, target_vocab_table=tgt_vocab_table, + reverse_target_vocab_table=reverse_tgt_vocab_table, scope=scope, extra_args=extra_args) return EvalModel( @@ -197,7 +215,8 @@ def create_infer_model(model_creator, hparams, scope=None, extra_args=None): src_vocab_table, batch_size=batch_size_placeholder, eos=hparams.eos, - src_max_len=hparams.src_max_len_infer) + src_max_len=hparams.src_max_len_infer, + use_char_encode=hparams.use_char_encode) model = model_creator( hparams, iterator=iterator, @@ -274,11 +293,13 @@ def create_emb_for_encoder_and_decoder(share_vocab, src_embed_size, tgt_embed_size, dtype=tf.float32, - num_partitions=0, + num_enc_partitions=0, + num_dec_partitions=0, src_vocab_file=None, tgt_vocab_file=None, src_embed_file=None, tgt_embed_file=None, + use_char_encode=False, scope=None): """Create embedding matrix for both encoder and decoder. @@ -292,7 +313,10 @@ def create_emb_for_encoder_and_decoder(share_vocab, tgt_embed_size: An integer. The embedding dimension for the decoder's embedding. dtype: dtype of the embedding matrix. Default to float32. - num_partitions: number of partitions used for the embedding vars. + num_enc_partitions: number of partitions used for the encoder's embedding + vars. + num_dec_partitions: number of partitions used for the decoder's embedding + vars. scope: VariableScope for the created subgraph. Default to "embedding". Returns: @@ -303,22 +327,36 @@ def create_emb_for_encoder_and_decoder(share_vocab, ValueError: if use share_vocab but source and target have different vocab size. """ + if num_enc_partitions <= 1: + enc_partitioner = None + else: + # Note: num_partitions > 1 is required for distributed training due to + # embedding_lookup tries to colocate single partition-ed embedding variable + # with lookup ops. This may cause embedding variables being placed on worker + # jobs. + enc_partitioner = tf.fixed_size_partitioner(num_enc_partitions) - if num_partitions <= 1: - partitioner = None + if num_dec_partitions <= 1: + dec_partitioner = None else: # Note: num_partitions > 1 is required for distributed training due to # embedding_lookup tries to colocate single partition-ed embedding variable # with lookup ops. This may cause embedding variables being placed on worker # jobs. 
- partitioner = tf.fixed_size_partitioner(num_partitions) + dec_partitioner = tf.fixed_size_partitioner(num_dec_partitions) - if (src_embed_file or tgt_embed_file) and partitioner: + if src_embed_file and enc_partitioner: raise ValueError( - "Can't set num_partitions > 1 when using pretrained embedding") + "Can't set num_enc_partitions > 1 when using pretrained encoder " + "embedding") + + if tgt_embed_file and dec_partitioner: + raise ValueError( + "Can't set num_dec_partitions > 1 when using pretrained decdoer " + "embedding") with tf.variable_scope( - scope or "embeddings", dtype=dtype, partitioner=partitioner) as scope: + scope or "embeddings", dtype=dtype, partitioner=enc_partitioner) as scope: # Share embedding if share_vocab: if src_vocab_size != tgt_vocab_size: @@ -334,12 +372,15 @@ def create_emb_for_encoder_and_decoder(share_vocab, src_vocab_size, src_embed_size, dtype) embedding_decoder = embedding_encoder else: - with tf.variable_scope("encoder", partitioner=partitioner): - embedding_encoder = _create_or_load_embed( - "embedding_encoder", src_vocab_file, src_embed_file, - src_vocab_size, src_embed_size, dtype) - - with tf.variable_scope("decoder", partitioner=partitioner): + if not use_char_encode: + with tf.variable_scope("encoder", partitioner=enc_partitioner): + embedding_encoder = _create_or_load_embed( + "embedding_encoder", src_vocab_file, src_embed_file, + src_vocab_size, src_embed_size, dtype) + else: + embedding_encoder = None + + with tf.variable_scope("decoder", partitioner=dec_partitioner): embedding_decoder = _create_or_load_embed( "embedding_decoder", tgt_vocab_file, tgt_embed_file, tgt_vocab_size, tgt_embed_size, dtype) @@ -478,13 +519,29 @@ def gradient_clip(gradients, max_gradient_norm): return clipped_gradients, gradient_norm_summary, gradient_norm -def load_model(model, ckpt, session, name): +def print_variables_in_ckpt(ckpt_path): + """Print a list of variables in a checkpoint together with their shapes.""" + utils.print_out("# Variables in ckpt %s" % ckpt_path) + reader = tf.train.NewCheckpointReader(ckpt_path) + variable_map = reader.get_variable_to_shape_map() + for key in sorted(variable_map.keys()): + utils.print_out(" %s: %s" % (key, variable_map[key])) + + +def load_model(model, ckpt_path, session, name): + """Load model from a checkpoint.""" start_time = time.time() - model.saver.restore(session, ckpt) + try: + model.saver.restore(session, ckpt_path) + except tf.errors.NotFoundError as e: + utils.print_out("Can't load checkpoint") + print_variables_in_ckpt(ckpt_path) + utils.print_out("%s" % str(e)) + session.run(tf.tables_initializer()) utils.print_out( " loaded %s model parameters from %s, time %.2fs" % - (name, ckpt, time.time() - start_time)) + (name, ckpt_path, time.time() - start_time)) return model @@ -594,9 +651,9 @@ def compute_perplexity(model, sess, name): while True: try: - loss, predict_count, batch_size = model.eval(sess) - total_loss += loss * batch_size - total_predict_count += predict_count + output_tuple = model.eval(sess) + total_loss += output_tuple.eval_loss * output_tuple.batch_size + total_predict_count += output_tuple.predict_count except tf.errors.OutOfRangeError: break diff --git a/nmt/model_test.py b/nmt/model_test.py index 5af64df7f..168895844 100644 --- a/nmt/model_test.py +++ b/nmt/model_test.py @@ -34,6 +34,9 @@ int32 = np.int32 array = np.array +SOS = '' +EOS = '' + class ModelTest(tf.test.TestCase): @@ -143,6 +146,7 @@ def setUpClass(cls): 'UniEncoderStandardAttentionArchitecture/loss': 8.8519087, 
'InitializerGlorotNormal/loss': 8.9779415, 'InitializerGlorotUniform/loss': 8.7643699, + 'SampledSoftmaxLoss/loss': 5.83928, } cls.actual_eval_values = {} @@ -186,15 +190,15 @@ def setUpClass(cls): cls.actual_beam_sentences = {} cls.expected_beam_sentences = { 'BeamSearchAttentionModel: batch 0 of beam 0': '', - 'BeamSearchAttentionModel: batch 0 of beam 1': 'sos a sos a', + 'BeamSearchAttentionModel: batch 0 of beam 1': '%s a %s a' % (SOS, SOS), 'BeamSearchAttentionModel: batch 1 of beam 0': '', 'BeamSearchAttentionModel: batch 1 of beam 1': 'b', 'BeamSearchBasicModel: batch 0 of beam 0': 'b b b b', - 'BeamSearchBasicModel: batch 0 of beam 1': 'b b b sos', + 'BeamSearchBasicModel: batch 0 of beam 1': 'b b b %s' % SOS, 'BeamSearchBasicModel: batch 0 of beam 2': 'b b b c', 'BeamSearchBasicModel: batch 1 of beam 0': 'b b b b', 'BeamSearchBasicModel: batch 1 of beam 1': 'a b b b', - 'BeamSearchBasicModel: batch 1 of beam 2': 'b b b sos', + 'BeamSearchBasicModel: batch 1 of beam 2': 'b b b %s' % SOS, 'BeamSearchGNMTModel: batch 0 of beam 0': '', 'BeamSearchGNMTModel: batch 1 of beam 0': '', } @@ -250,8 +254,8 @@ def _assertModelVariable(self, variable, sess, name): def _assertTrainStepsLoss(self, m, sess, name, num_steps=1): for _ in range(num_steps): - _, loss, _, _, _, _, _, _, _ = m.train(sess) - + _, output_tuple = m.train(sess) + loss = output_tuple.train_loss print('{} {}-th step loss is: '.format(name, num_steps), loss) expected_loss = self.expected_train_values[name + '/loss'] self.actual_train_values[name + '/loss'] = loss @@ -259,8 +263,9 @@ def _assertTrainStepsLoss(self, m, sess, name, num_steps=1): self.assertAllClose(expected_loss, loss) def _assertEvalLossAndPredictCount(self, m, sess, name): - loss, predict_count, _ = m.eval(sess) - + output_tuple = m.eval(sess) + loss = output_tuple.eval_loss + predict_count = output_tuple.predict_count print('{} eval loss is: '.format(name), loss) print('{} predict count is: '.format(name), predict_count) expected_loss = self.expected_eval_values[name + '/loss'] @@ -272,8 +277,8 @@ def _assertEvalLossAndPredictCount(self, m, sess, name): self.assertAllClose(expected_predict_count, predict_count) def _assertInferLogits(self, m, sess, name): - results = m.infer(sess) - logits_sum = np.sum(results[0]) + output_tuple = m.infer(sess) + logits_sum = np.sum(output_tuple.infer_logits) print('{} infer logits sum is: '.format(name), logits_sum) expected_logits_sum = self.expected_infer_values[name + '/logits_sum'] @@ -288,7 +293,7 @@ def _assertBeamSearchOutputs(self, m, sess, assert_top_k_sentence, name): output_words = nmt_outputs[i] for j in range(output_words.shape[0]): sentence = nmt_utils.get_translation( - output_words, j, tgt_eos='eos', subword_option='') + output_words, j, tgt_eos=EOS, subword_option='') sentence_key = ('%s: batch %d of beam %d' % (name, j, i)) self.actual_beam_sentences[sentence_key] = sentence expected_sentence = self.expected_beam_sentences[sentence_key] @@ -939,6 +944,7 @@ def testBeamSearchBasicModel(self): attention_architecture='', use_residual=False,) hparams.beam_width = 3 + hparams.infer_mode = "beam_search" hparams.tgt_max_len_infer = 4 assert_top_k_sentence = 3 @@ -956,6 +962,7 @@ def testBeamSearchAttentionModel(self): num_layers=2, use_residual=False,) hparams.beam_width = 3 + hparams.infer_mode = "beam_search" hparams.tgt_max_len_infer = 4 assert_top_k_sentence = 2 @@ -972,6 +979,7 @@ def testBeamSearchGNMTModel(self): attention='scaled_luong', attention_architecture='gnmt') hparams.beam_width = 3 + 
hparams.infer_mode = "beam_search" hparams.tgt_max_len_infer = 4 assert_top_k_sentence = 1 @@ -1009,5 +1017,18 @@ def testInitializerGlorotUniform(self): self._assertTrainStepsLoss(train_m, sess, 'InitializerGlorotUniform') + def testSampledSoftmaxLoss(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='gnmt', + num_layers=4, + attention='scaled_luong', + attention_architecture='gnmt') + hparams.num_sampled_softmax = 3 + + with self.test_session() as sess: + train_m = self._createTestTrainModel(gnmt_model.GNMTModel, hparams, sess) + self._assertTrainStepsLoss(train_m, sess, + 'SampledSoftmaxLoss') + if __name__ == '__main__': tf.test.main() diff --git a/nmt/nmt.py b/nmt/nmt.py index aa18acd95..f5823d893 100644 --- a/nmt/nmt.py +++ b/nmt/nmt.py @@ -35,6 +35,11 @@ FLAGS = None +INFERENCE_KEYS = ["src_max_len_infer", "tgt_max_len_infer", "subword_option", + "infer_batch_size", "beam_width", + "length_penalty_weight", "sampling_temperature", + "num_translations_per_input", "infer_mode"] + def add_arguments(parser): """Build ArgumentParser.""" @@ -202,6 +207,9 @@ def add_arguments(parser): help="Limit on the size of training data (0: no limit).") parser.add_argument("--num_buckets", type=int, default=5, help="Put data into similar-length buckets.") + parser.add_argument("--num_sampled_softmax", type=int, default=0, + help=("Use sampled_softmax_loss if > 0." + "Otherwise, use full softmax loss.")) # SPM parser.add_argument("--subword_option", type=str, default="", @@ -210,6 +218,14 @@ def add_arguments(parser): Set to bpe or spm to activate subword desegmentation.\ """) + # Experimental encoding feature. + parser.add_argument("--use_char_encode", type="bool", default=False, + help="""\ + Whether to split each word or bpe into character, and then + generate the word-level representation from the character + reprentation. + """) + # Misc parser.add_argument("--num_gpus", type=int, default=1, help="Number of gpus in each worker.") @@ -240,6 +256,9 @@ def add_arguments(parser): Average the last N checkpoints for external evaluation. N can be controlled by setting --num_keep_ckpts.\ """)) + parser.add_argument("--language_model", type="bool", nargs="?", + const=True, default=False, + help="True to train a language model, ignoring encoder") # Inference parser.add_argument("--ckpt", type=str, default="", @@ -257,6 +276,11 @@ def add_arguments(parser): help=("""\ Reference file to compute evaluation scores (if provided).\ """)) + + # Advanced inference arguments + parser.add_argument("--infer_mode", type=str, default="greedy", + choices=["greedy", "sample", "beam_search"], + help="Which type of decoder to use during inference.") parser.add_argument("--beam_width", type=int, default=0, help=("""\ beam width when using beam search decoder. 
If 0 (default), use standard @@ -302,7 +326,6 @@ def create_hparams(flags): # Networks num_units=flags.num_units, - num_layers=flags.num_layers, # Compatible num_encoder_layers=(flags.num_encoder_layers or flags.num_layers), num_decoder_layers=(flags.num_decoder_layers or flags.num_layers), dropout=flags.dropout, @@ -330,6 +353,7 @@ def create_hparams(flags): warmup_scheme=flags.warmup_scheme, decay_scheme=flags.decay_scheme, colocate_gradients_with_ops=flags.colocate_gradients_with_ops, + num_sampled_softmax=flags.num_sampled_softmax, # Data constraints num_buckets=flags.num_buckets, @@ -341,6 +365,9 @@ def create_hparams(flags): src_max_len_infer=flags.src_max_len_infer, tgt_max_len_infer=flags.tgt_max_len_infer, infer_batch_size=flags.infer_batch_size, + + # Advanced inference arguments + infer_mode=flags.infer_mode, beam_width=flags.beam_width, length_penalty_weight=flags.length_penalty_weight, sampling_temperature=flags.sampling_temperature, @@ -351,6 +378,7 @@ def create_hparams(flags): eos=flags.eos if flags.eos else vocab_utils.EOS, subword_option=flags.subword_option, check_special_token=flags.check_special_token, + use_char_encode=flags.use_char_encode, # Misc forget_bias=flags.forget_bias, @@ -365,21 +393,23 @@ def create_hparams(flags): override_loaded_hparams=flags.override_loaded_hparams, num_keep_ckpts=flags.num_keep_ckpts, avg_ckpts=flags.avg_ckpts, + language_model=flags.language_model, num_intra_threads=flags.num_intra_threads, num_inter_threads=flags.num_inter_threads, ) -def extend_hparams(hparams): - """Extend training hparams.""" - assert hparams.num_encoder_layers and hparams.num_decoder_layers - if hparams.num_encoder_layers != hparams.num_decoder_layers: - hparams.pass_hidden_state = False - utils.print_out("Num encoder layer %d is different from num decoder layer" - " %d, so set pass_hidden_state to False" % ( - hparams.num_encoder_layers, - hparams.num_decoder_layers)) +def _add_argument(hparams, key, value, update=True): + """Add an argument to hparams; if exists, change the value if update==True.""" + if hasattr(hparams, key): + if update: + setattr(hparams, key, value) + else: + hparams.add_hparam(key, value) + +def extend_hparams(hparams): + """Add new arguments to hparams.""" # Sanity checks if hparams.encoder_type == "bi" and hparams.num_encoder_layers % 2 != 0: raise ValueError("For bi, num_encoder_layers %d should be even" % @@ -389,6 +419,23 @@ def extend_hparams(hparams): raise ValueError("For gnmt attention architecture, " "num_encoder_layers %d should be >= 2" % hparams.num_encoder_layers) + if hparams.subword_option and hparams.subword_option not in ["spm", "bpe"]: + raise ValueError("subword option must be either spm, or bpe") + if hparams.infer_mode == "beam_search" and hparams.beam_width <= 0: + raise ValueError("beam_width must greater than 0 when using beam_search" + "decoder.") + if hparams.infer_mode == "sample" and hparams.sampling_temperature <= 0.0: + raise ValueError("sampling_temperature must greater than 0.0 when using" + "sample decoder.") + + # Different number of encoder / decoder layers + assert hparams.num_encoder_layers and hparams.num_decoder_layers + if hparams.num_encoder_layers != hparams.num_decoder_layers: + hparams.pass_hidden_state = False + utils.print_out("Num encoder layer %d is different from num decoder layer" + " %d, so set pass_hidden_state to False" % ( + hparams.num_encoder_layers, + hparams.num_decoder_layers)) # Set residual layers num_encoder_residual_layers = 0 @@ -408,20 +455,20 @@ def extend_hparams(hparams): 
# Compatible for GNMT models if hparams.num_encoder_layers == hparams.num_decoder_layers: num_decoder_residual_layers = num_encoder_residual_layers - hparams.add_hparam("num_encoder_residual_layers", num_encoder_residual_layers) - hparams.add_hparam("num_decoder_residual_layers", num_decoder_residual_layers) - - if hparams.subword_option and hparams.subword_option not in ["spm", "bpe"]: - raise ValueError("subword option must be either spm, or bpe") - - # Flags - utils.print_out("# hparams:") - utils.print_out(" src=%s" % hparams.src) - utils.print_out(" tgt=%s" % hparams.tgt) - utils.print_out(" train_prefix=%s" % hparams.train_prefix) - utils.print_out(" dev_prefix=%s" % hparams.dev_prefix) - utils.print_out(" test_prefix=%s" % hparams.test_prefix) - utils.print_out(" out_dir=%s" % hparams.out_dir) + _add_argument(hparams, "num_encoder_residual_layers", + num_encoder_residual_layers) + _add_argument(hparams, "num_decoder_residual_layers", + num_decoder_residual_layers) + + # Language modeling + if getattr(hparams, "language_model", None): + hparams.attention = "" + hparams.attention_architecture = "" + hparams.pass_hidden_state = False + hparams.share_vocab = True + hparams.src = hparams.tgt + utils.print_out("For language modeling, we turn off attention and " + "pass_hidden_state; turn on share_vocab; set src to tgt.") ## Vocab # Get vocab file names first @@ -432,10 +479,11 @@ def extend_hparams(hparams): raise ValueError("hparams.vocab_prefix must be provided.") # Source vocab + check_special_token = getattr(hparams, "check_special_token", True) src_vocab_size, src_vocab_file = vocab_utils.check_vocab( src_vocab_file, hparams.out_dir, - check_special_token=hparams.check_special_token, + check_special_token=check_special_token, sos=hparams.sos, eos=hparams.eos, unk=vocab_utils.UNK) @@ -449,54 +497,75 @@ def extend_hparams(hparams): tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab( tgt_vocab_file, hparams.out_dir, - check_special_token=hparams.check_special_token, + check_special_token=check_special_token, sos=hparams.sos, eos=hparams.eos, unk=vocab_utils.UNK) - hparams.add_hparam("src_vocab_size", src_vocab_size) - hparams.add_hparam("tgt_vocab_size", tgt_vocab_size) - hparams.add_hparam("src_vocab_file", src_vocab_file) - hparams.add_hparam("tgt_vocab_file", tgt_vocab_file) - - # Pretrained Embeddings: - hparams.add_hparam("src_embed_file", "") - hparams.add_hparam("tgt_embed_file", "") - if hparams.embed_prefix: + _add_argument(hparams, "src_vocab_size", src_vocab_size) + _add_argument(hparams, "tgt_vocab_size", tgt_vocab_size) + _add_argument(hparams, "src_vocab_file", src_vocab_file) + _add_argument(hparams, "tgt_vocab_file", tgt_vocab_file) + + # Num embedding partitions + num_embeddings_partitions = getattr(hparams, "num_embeddings_partitions", 0) + _add_argument(hparams, "num_enc_emb_partitions", num_embeddings_partitions) + _add_argument(hparams, "num_dec_emb_partitions", num_embeddings_partitions) + + # Pretrained Embeddings + _add_argument(hparams, "src_embed_file", "") + _add_argument(hparams, "tgt_embed_file", "") + if getattr(hparams, "embed_prefix", None): src_embed_file = hparams.embed_prefix + "." + hparams.src tgt_embed_file = hparams.embed_prefix + "." 
+ hparams.tgt if tf.gfile.Exists(src_embed_file): + utils.print_out(" src_embed_file %s exist" % src_embed_file) hparams.src_embed_file = src_embed_file + utils.print_out( + "For pretrained embeddings, set num_enc_emb_partitions to 1") + hparams.num_enc_emb_partitions = 1 + else: + utils.print_out(" src_embed_file %s doesn't exist" % src_embed_file) + if tf.gfile.Exists(tgt_embed_file): + utils.print_out(" tgt_embed_file %s exist" % tgt_embed_file) hparams.tgt_embed_file = tgt_embed_file - # Check out_dir - if not tf.gfile.Exists(hparams.out_dir): - utils.print_out("# Creating output directory %s ..." % hparams.out_dir) - tf.gfile.MakeDirs(hparams.out_dir) + utils.print_out( + "For pretrained embeddings, set num_dec_emb_partitions to 1") + hparams.num_dec_emb_partitions = 1 + else: + utils.print_out(" tgt_embed_file %s doesn't exist" % tgt_embed_file) # Evaluation for metric in hparams.metrics: - hparams.add_hparam("best_" + metric, 0) # larger is better best_metric_dir = os.path.join(hparams.out_dir, "best_" + metric) - hparams.add_hparam("best_" + metric + "_dir", best_metric_dir) tf.gfile.MakeDirs(best_metric_dir) + _add_argument(hparams, "best_" + metric, 0, update=False) + _add_argument(hparams, "best_" + metric + "_dir", best_metric_dir) - if hparams.avg_ckpts: - hparams.add_hparam("avg_best_" + metric, 0) # larger is better + if getattr(hparams, "avg_ckpts", None): best_metric_dir = os.path.join(hparams.out_dir, "avg_best_" + metric) - hparams.add_hparam("avg_best_" + metric + "_dir", best_metric_dir) tf.gfile.MakeDirs(best_metric_dir) + _add_argument(hparams, "avg_best_" + metric, 0, update=False) + _add_argument(hparams, "avg_best_" + metric + "_dir", best_metric_dir) return hparams -def ensure_compatible_hparams(hparams, default_hparams, hparams_path): +def ensure_compatible_hparams(hparams, default_hparams, hparams_path=""): """Make sure the loaded hparams is compatible with new changes.""" default_hparams = utils.maybe_parse_standard_hparams( default_hparams, hparams_path) + # Set num encoder/decoder layers (for old checkpoints) + if hasattr(hparams, "num_layers"): + if not hasattr(hparams, "num_encoder_layers"): + hparams.add_hparam("num_encoder_layers", hparams.num_layers) + if not hasattr(hparams, "num_decoder_layers"): + hparams.add_hparam("num_decoder_layers", hparams.num_layers) + # For compatible reason, if there are new fields in default_hparams, # we add them to the current hparams default_config = default_hparams.values() @@ -506,13 +575,18 @@ def ensure_compatible_hparams(hparams, default_hparams, hparams_path): hparams.add_hparam(key, default_config[key]) # Update all hparams' keys if override_loaded_hparams=True - if default_hparams.override_loaded_hparams: - for key in default_config: - if getattr(hparams, key) != default_config[key]: - utils.print_out("# Updating hparams.%s: %s -> %s" % - (key, str(getattr(hparams, key)), - str(default_config[key]))) - setattr(hparams, key, default_config[key]) + if getattr(default_hparams, "override_loaded_hparams", None): + overwritten_keys = default_config.keys() + else: + # For inference + overwritten_keys = INFERENCE_KEYS + + for key in overwritten_keys: + if getattr(hparams, key) != default_config[key]: + utils.print_out("# Updating hparams.%s: %s -> %s" % + (key, str(getattr(hparams, key)), + str(default_config[key]))) + setattr(hparams, key, default_config[key]) return hparams @@ -524,9 +598,9 @@ def create_or_load_hparams( hparams = default_hparams hparams = utils.maybe_parse_standard_hparams( hparams, hparams_path) - 
hparams = extend_hparams(hparams) else: hparams = ensure_compatible_hparams(hparams, default_hparams, hparams_path) + hparams = extend_hparams(hparams) # Save HParams if save_hparams: @@ -546,6 +620,10 @@ def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""): num_workers = flags.num_workers utils.print_out("# Job id %d" % jobid) + # GPU device + utils.print_out( + "# Devices visible to TensorFlow: %s" % repr(tf.Session().list_devices())) + # Random random_seed = flags.random_seed if random_seed is not None and random_seed > 0: @@ -553,15 +631,36 @@ def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""): random.seed(random_seed + jobid) np.random.seed(random_seed + jobid) - ## Train / Decode + # Model output directory out_dir = flags.out_dir - if not tf.gfile.Exists(out_dir): tf.gfile.MakeDirs(out_dir) + if out_dir and not tf.gfile.Exists(out_dir): + utils.print_out("# Creating output directory %s ..." % out_dir) + tf.gfile.MakeDirs(out_dir) # Load hparams. - hparams = create_or_load_hparams( - out_dir, default_hparams, flags.hparams_path, save_hparams=(jobid == 0)) + loaded_hparams = False + if flags.ckpt: # Try to load hparams from the same directory as ckpt + ckpt_dir = os.path.dirname(flags.ckpt) + ckpt_hparams_file = os.path.join(ckpt_dir, "hparams") + if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path: + hparams = create_or_load_hparams( + ckpt_dir, default_hparams, flags.hparams_path, + save_hparams=False) + loaded_hparams = True + if not loaded_hparams: # Try to load from out_dir + assert out_dir + hparams = create_or_load_hparams( + out_dir, default_hparams, flags.hparams_path, + save_hparams=(jobid == 0)) + ## Train / Decode if flags.inference_input_file: + # Inference output directory + trans_file = flags.inference_output_file + assert trans_file + trans_dir = os.path.dirname(trans_file) + if not tf.gfile.Exists(trans_dir): tf.gfile.MakeDirs(trans_dir) + # Inference indices hparams.inference_indices = None if flags.inference_list: @@ -569,7 +668,6 @@ def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""): [int(token) for token in flags.inference_list.split(",")]) # Inference - trans_file = flags.inference_output_file ckpt = flags.ckpt if not ckpt: ckpt = tf.train.latest_checkpoint(out_dir) diff --git a/nmt/standard_hparams/iwslt15.json b/nmt/standard_hparams/iwslt15.json index ff5f46ec6..2b658eca1 100644 --- a/nmt/standard_hparams/iwslt15.json +++ b/nmt/standard_hparams/iwslt15.json @@ -13,7 +13,8 @@ "max_gradient_norm": 5.0, "metrics": ["bleu"], "num_buckets": 5, - "num_layers": 2, + "num_encoder_layers": 2, + "num_decoder_layers": 2, "num_train_steps": 12000, "decay_scheme": "luong234", "num_units": 512, diff --git a/nmt/standard_hparams/wmt16.json b/nmt/standard_hparams/wmt16.json index 8c1cb3fb0..ba57dc5ef 100644 --- a/nmt/standard_hparams/wmt16.json +++ b/nmt/standard_hparams/wmt16.json @@ -13,7 +13,8 @@ "max_gradient_norm": 5.0, "metrics": ["bleu"], "num_buckets": 5, - "num_layers": 4, + "num_encoder_layers": 4, + "num_decoder_layers": 4, "num_train_steps": 340000, "decay_scheme": "luong10", "num_units": 1024, diff --git a/nmt/standard_hparams/wmt16_gnmt_4_layer.json b/nmt/standard_hparams/wmt16_gnmt_4_layer.json index 0031a54e9..8f36e4133 100644 --- a/nmt/standard_hparams/wmt16_gnmt_4_layer.json +++ b/nmt/standard_hparams/wmt16_gnmt_4_layer.json @@ -13,7 +13,8 @@ "max_gradient_norm": 5.0, "metrics": ["bleu"], "num_buckets": 5, - "num_layers": 4, + "num_encoder_layers": 4, + 
"num_decoder_layers": 4, "num_train_steps": 340000, "decay_scheme": "luong10", "num_units": 1024, @@ -30,6 +31,7 @@ "tgt_max_len_infer": null, "time_major": true, "unit_type": "lstm", + "infer_mode": "beam_search", "beam_width": 10, "length_penalty_weight": 1.0 } diff --git a/nmt/standard_hparams/wmt16_gnmt_8_layer.json b/nmt/standard_hparams/wmt16_gnmt_8_layer.json index 438ddcf55..b96ec8782 100644 --- a/nmt/standard_hparams/wmt16_gnmt_8_layer.json +++ b/nmt/standard_hparams/wmt16_gnmt_8_layer.json @@ -13,7 +13,8 @@ "max_gradient_norm": 5.0, "metrics": ["bleu"], "num_buckets": 5, - "num_layers": 8, + "num_encoder_layers": 8, + "num_decoder_layers": 8, "num_train_steps": 340000, "decay_scheme": "luong10", "num_units": 1024, @@ -22,7 +23,6 @@ "share_vocab": false, "subword_option": "bpe", "sos": "", - "source_reverse": false, "src_max_len": 50, "src_max_len_infer": null, "steps_per_external_eval": null, @@ -31,6 +31,7 @@ "tgt_max_len_infer": null, "time_major": true, "unit_type": "lstm", + "infer_mode": "beam_search", "beam_width": 10, "length_penalty_weight": 1.0 } diff --git a/nmt/testdata/test_infer_vocab.src b/nmt/testdata/test_infer_vocab.src index ccecabbca..0e441f86b 100644 --- a/nmt/testdata/test_infer_vocab.src +++ b/nmt/testdata/test_infer_vocab.src @@ -1,5 +1,5 @@ -unk -eos -sos + + + test1 test2 diff --git a/nmt/testdata/test_infer_vocab.tgt b/nmt/testdata/test_infer_vocab.tgt index 6c60b1194..279be587b 100644 --- a/nmt/testdata/test_infer_vocab.tgt +++ b/nmt/testdata/test_infer_vocab.tgt @@ -1,6 +1,5 @@ -unk -eos -test1 -test2 + + + test3 test4 diff --git a/nmt/train.py b/nmt/train.py index 75978ec44..1f061486b 100644 --- a/nmt/train.py +++ b/nmt/train.py @@ -35,7 +35,8 @@ __all__ = [ "run_sample_decode", "run_internal_eval", "run_external_eval", "run_avg_external_eval", "run_full_eval", "init_stats", "update_stats", - "print_step_info", "process_stats", "train" + "print_step_info", "process_stats", "train", "get_model_creator", + "add_info_summaries", "get_best_results" ] @@ -52,20 +53,49 @@ def run_sample_decode(infer_model, infer_sess, model_dir, hparams, infer_model.batch_size_placeholder, summary_writer) -def run_internal_eval( - eval_model, eval_sess, model_dir, hparams, summary_writer, - use_test_set=True): - """Compute internal evaluation (perplexity) for both dev / test.""" +def run_internal_eval(eval_model, + eval_sess, + model_dir, + hparams, + summary_writer, + use_test_set=True, + dev_eval_iterator_feed_dict=None, + test_eval_iterator_feed_dict=None): + """Compute internal evaluation (perplexity) for both dev / test. + + Computes development and testing perplexities for given model. + + Args: + eval_model: Evaluation model for which to compute perplexities. + eval_sess: Evaluation TensorFlow session. + model_dir: Directory from which to load evaluation model from. + hparams: Model hyper-parameters. + summary_writer: Summary writer for logging metrics to TensorBoard. + use_test_set: Computes testing perplexity if true; does not otherwise. + Note that the development perplexity is always computed regardless of + value of this parameter. + dev_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + development evaluation. + test_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + testing evaluation. + Returns: + Pair containing development perplexity and testing perplexity, in this + order. 
+ """ + if dev_eval_iterator_feed_dict is None: + dev_eval_iterator_feed_dict = {} + if test_eval_iterator_feed_dict is None: + test_eval_iterator_feed_dict = {} with eval_model.graph.as_default(): loaded_eval_model, global_step = model_helper.create_or_load_model( eval_model.model, model_dir, eval_sess, "eval") dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src) dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt) - dev_eval_iterator_feed_dict = { - eval_model.src_file_placeholder: dev_src_file, - eval_model.tgt_file_placeholder: dev_tgt_file - } + dev_eval_iterator_feed_dict[eval_model.src_file_placeholder] = dev_src_file + dev_eval_iterator_feed_dict[eval_model.tgt_file_placeholder] = dev_tgt_file dev_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess, eval_model.iterator, dev_eval_iterator_feed_dict, @@ -74,30 +104,64 @@ def run_internal_eval( if use_test_set and hparams.test_prefix: test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src) test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt) - test_eval_iterator_feed_dict = { - eval_model.src_file_placeholder: test_src_file, - eval_model.tgt_file_placeholder: test_tgt_file - } + test_eval_iterator_feed_dict[ + eval_model.src_file_placeholder] = test_src_file + test_eval_iterator_feed_dict[ + eval_model.tgt_file_placeholder] = test_tgt_file test_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess, eval_model.iterator, test_eval_iterator_feed_dict, summary_writer, "test") return dev_ppl, test_ppl -def run_external_eval(infer_model, infer_sess, model_dir, hparams, - summary_writer, save_best_dev=True, use_test_set=True, - avg_ckpts=False): - """Compute external evaluation (bleu, rouge, etc.) for both dev / test.""" +def run_external_eval(infer_model, + infer_sess, + model_dir, + hparams, + summary_writer, + save_best_dev=True, + use_test_set=True, + avg_ckpts=False, + dev_infer_iterator_feed_dict=None, + test_infer_iterator_feed_dict=None): + """Compute external evaluation for both dev / test. + + Computes development and testing external evaluation (e.g. bleu, rouge) for + given model. + + Args: + infer_model: Inference model for which to compute perplexities. + infer_sess: Inference TensorFlow session. + model_dir: Directory from which to load inference model from. + hparams: Model hyper-parameters. + summary_writer: Summary writer for logging metrics to TensorBoard. + use_test_set: Computes testing external evaluation if true; does not + otherwise. Note that the development external evaluation is always + computed regardless of value of this parameter. + dev_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + development external evaluation. + test_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + testing external evaluation. + Returns: + Triple containing development scores, testing scores and the TensorFlow + Variable for the global step number, in this order. 
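The new dev_*/test_* feed-dict parameters default to None, get created inside the function, and are then updated in place (as the body immediately below shows), so a caller-supplied dictionary keeps any extra entries it carried. A small pure-Python sketch of that idiom; the string keys stand in for the real TensorFlow placeholders:

```python
def build_eval_feed(feed_dict=None, src_file="dev.src", tgt_file="dev.tgt"):
    # Defaulting to None (rather than {}) avoids sharing one mutable dict
    # across calls, which is the classic mutable-default-argument pitfall.
    if feed_dict is None:
        feed_dict = {}
    # Update in place so entries the caller passed in survive.
    feed_dict["src_file_placeholder"] = src_file
    feed_dict["tgt_file_placeholder"] = tgt_file
    return feed_dict

caller_feed = {"extra_placeholder": 0.5}
print(build_eval_feed(caller_feed))                # keeps 'extra_placeholder'
assert build_eval_feed() is not build_eval_feed()  # fresh dict on every call
```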
+ """ + if dev_infer_iterator_feed_dict is None: + dev_infer_iterator_feed_dict = {} + if test_infer_iterator_feed_dict is None: + test_infer_iterator_feed_dict = {} with infer_model.graph.as_default(): loaded_infer_model, global_step = model_helper.create_or_load_model( infer_model.model, model_dir, infer_sess, "infer") dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src) dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt) - dev_infer_iterator_feed_dict = { - infer_model.src_placeholder: inference.load_data(dev_src_file), - infer_model.batch_size_placeholder: hparams.infer_batch_size, - } + dev_infer_iterator_feed_dict[ + infer_model.src_placeholder] = inference.load_data(dev_src_file) + dev_infer_iterator_feed_dict[ + infer_model.batch_size_placeholder] = hparams.infer_batch_size dev_scores = _external_eval( loaded_infer_model, global_step, @@ -115,10 +179,10 @@ def run_external_eval(infer_model, infer_sess, model_dir, hparams, if use_test_set and hparams.test_prefix: test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src) test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt) - test_infer_iterator_feed_dict = { - infer_model.src_placeholder: inference.load_data(test_src_file), - infer_model.batch_size_placeholder: hparams.infer_batch_size, - } + test_infer_iterator_feed_dict[ + infer_model.src_placeholder] = inference.load_data(test_src_file) + test_infer_iterator_feed_dict[ + infer_model.batch_size_placeholder] = hparams.infer_batch_size test_scores = _external_eval( loaded_infer_model, global_step, @@ -156,16 +220,63 @@ def run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, return avg_dev_scores, avg_test_scores -def run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess, - hparams, summary_writer, sample_src_data, sample_tgt_data, - avg_ckpts=False): - """Wrapper for running sample_decode, internal_eval and external_eval.""" - run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer, - sample_src_data, sample_tgt_data) +def run_internal_and_external_eval(model_dir, + infer_model, + infer_sess, + eval_model, + eval_sess, + hparams, + summary_writer, + avg_ckpts=False, + dev_eval_iterator_feed_dict=None, + test_eval_iterator_feed_dict=None, + dev_infer_iterator_feed_dict=None, + test_infer_iterator_feed_dict=None): + """Compute internal evaluation (perplexity) for both dev / test. + + Computes development and testing perplexities for given model. + + Args: + model_dir: Directory from which to load models from. + infer_model: Inference model for which to compute perplexities. + infer_sess: Inference TensorFlow session. + eval_model: Evaluation model for which to compute perplexities. + eval_sess: Evaluation TensorFlow session. + hparams: Model hyper-parameters. + summary_writer: Summary writer for logging metrics to TensorBoard. + avg_ckpts: Whether to compute average external evaluation scores. + dev_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + internal development evaluation. + test_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + internal testing evaluation. + dev_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + external development evaluation. + test_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session. 
+ Can be used to pass in additional inputs necessary for running the + external testing evaluation. + Returns: + Triple containing results summary, global step Tensorflow Variable and + metrics in this order. + """ dev_ppl, test_ppl = run_internal_eval( - eval_model, eval_sess, model_dir, hparams, summary_writer) + eval_model, + eval_sess, + model_dir, + hparams, + summary_writer, + dev_eval_iterator_feed_dict=dev_eval_iterator_feed_dict, + test_eval_iterator_feed_dict=test_eval_iterator_feed_dict) dev_scores, test_scores, global_step = run_external_eval( - infer_model, infer_sess, model_dir, hparams, summary_writer) + infer_model, + infer_sess, + model_dir, + hparams, + summary_writer, + dev_infer_iterator_feed_dict=dev_infer_iterator_feed_dict, + test_infer_iterator_feed_dict=test_infer_iterator_feed_dict) metrics = { "dev_ppl": dev_ppl, @@ -196,25 +307,64 @@ def run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess, return result_summary, global_step, metrics +def run_full_eval(model_dir, + infer_model, + infer_sess, + eval_model, + eval_sess, + hparams, + summary_writer, + sample_src_data, + sample_tgt_data, + avg_ckpts=False): + """Wrapper for running sample_decode, internal_eval and external_eval. + + Args: + model_dir: Directory from which to load models from. + infer_model: Inference model for which to compute perplexities. + infer_sess: Inference TensorFlow session. + eval_model: Evaluation model for which to compute perplexities. + eval_sess: Evaluation TensorFlow session. + hparams: Model hyper-parameters. + summary_writer: Summary writer for logging metrics to TensorBoard. + sample_src_data: sample of source data for sample decoding. + sample_tgt_data: sample of target data for sample decoding. + avg_ckpts: Whether to compute average external evaluation scores. + Returns: + Triple containing results summary, global step Tensorflow Variable and + metrics in this order. 
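Further down in this file, init_stats and update_stats read named fields (train_loss, predict_count, grad_norm, ...) from the tuple returned by model.train() instead of unpacking nine positional values. A sketch of that pattern; the namedtuple here is a hypothetical stand-in, since the real TrainOutputTuple is defined in model.py, outside this section:

```python
import collections

# Hypothetical stand-in for the named output tuple returned by model.train().
TrainOutputTuple = collections.namedtuple(
    "TrainOutputTuple",
    ["train_summary", "train_loss", "predict_count", "global_step",
     "word_count", "batch_size", "grad_norm", "learning_rate"])

step = TrainOutputTuple(train_summary=b"", train_loss=8.7, predict_count=40,
                        global_step=1, word_count=95, batch_size=8,
                        grad_norm=3.2, learning_rate=1.0)

# Named access is self-documenting and order-insensitive, unlike the old
# `_, loss, _, _, _, _, _, _, _ = m.train(sess)` unpacking.
print(step.train_loss, step.batch_size)
```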
+ """ + run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer, + sample_src_data, sample_tgt_data) + return run_internal_and_external_eval(model_dir, infer_model, infer_sess, + eval_model, eval_sess, hparams, + summary_writer, avg_ckpts) + + def init_stats(): """Initialize statistics that we want to accumulate.""" - return {"step_time": 0.0, "loss": 0.0, "predict_count": 0.0, - "total_count": 0.0, "grad_norm": 0.0} + return {"step_time": 0.0, "train_loss": 0.0, + "predict_count": 0.0, # word count on the target side + "word_count": 0.0, # word counts for both source and target + "sequence_count": 0.0, # number of training examples processed + "grad_norm": 0.0} def update_stats(stats, start_time, step_result): """Update stats: write summary and accumulate statistics.""" - (_, step_loss, step_predict_count, step_summary, global_step, - step_word_count, batch_size, grad_norm, learning_rate) = step_result + _, output_tuple = step_result # Update statistics - stats["step_time"] += (time.time() - start_time) - stats["loss"] += (step_loss * batch_size) - stats["predict_count"] += step_predict_count - stats["total_count"] += float(step_word_count) - stats["grad_norm"] += grad_norm + batch_size = output_tuple.batch_size + stats["step_time"] += time.time() - start_time + stats["train_loss"] += output_tuple.train_loss * batch_size + stats["grad_norm"] += output_tuple.grad_norm + stats["predict_count"] += output_tuple.predict_count + stats["word_count"] += output_tuple.word_count + stats["sequence_count"] += batch_size - return global_step, learning_rate, step_summary + return (output_tuple.global_step, output_tuple.learning_rate, + output_tuple.train_summary) def print_step_info(prefix, global_step, info, result_summary, log_f): @@ -227,13 +377,25 @@ def print_step_info(prefix, global_step, info, result_summary, log_f): log_f) +def add_info_summaries(summary_writer, global_step, info): + """Add stuffs in info to summaries.""" + excluded_list = ["learning_rate"] + for key in info: + if key not in excluded_list: + utils.add_summary(summary_writer, global_step, key, info[key]) + + def process_stats(stats, info, global_step, steps_per_stats, log_f): """Update info and check for overflow.""" - # Update info + # Per-step info info["avg_step_time"] = stats["step_time"] / steps_per_stats info["avg_grad_norm"] = stats["grad_norm"] / steps_per_stats - info["train_ppl"] = utils.safe_exp(stats["loss"] / stats["predict_count"]) - info["speed"] = stats["total_count"] / (1000 * stats["step_time"]) + info["avg_sequence_count"] = stats["sequence_count"] / steps_per_stats + info["speed"] = stats["word_count"] / (1000 * stats["step_time"]) + + # Per-predict info + info["train_ppl"] = ( + utils.safe_exp(stats["train_loss"] / stats["predict_count"])) # Check for overflow is_overflow = False @@ -250,8 +412,10 @@ def before_train(loaded_train_model, train_model, train_sess, global_step, hparams, log_f): """Misc tasks to do before training.""" stats = init_stats() - info = {"train_ppl": 0.0, "speed": 0.0, "avg_step_time": 0.0, + info = {"train_ppl": 0.0, "speed": 0.0, + "avg_step_time": 0.0, "avg_grad_norm": 0.0, + "avg_sequence_count": 0.0, "learning_rate": loaded_train_model.learning_rate.eval( session=train_sess)} start_train_time = time.time() @@ -268,6 +432,21 @@ def before_train(loaded_train_model, train_model, train_sess, global_step, return stats, info, start_train_time +def get_model_creator(hparams): + """Get the right model class depending on configuration.""" + if (hparams.encoder_type == 
"gnmt" or + hparams.attention_architecture in ["gnmt", "gnmt_v2"]): + model_creator = gnmt_model.GNMTModel + elif hparams.attention_architecture == "standard": + model_creator = attention_model.AttentionModel + elif not hparams.attention: + model_creator = nmt_model.Model + else: + raise ValueError("Unknown attention architecture %s" % + hparams.attention_architecture) + return model_creator + + def train(hparams, scope=None, target_session=""): """Train a translation model.""" log_device_placement = hparams.log_device_placement @@ -281,18 +460,8 @@ def train(hparams, scope=None, target_session=""): if not steps_per_external_eval: steps_per_external_eval = 5 * steps_per_eval - if not hparams.attention: - model_creator = nmt_model.Model - else: # Attention - if (hparams.encoder_type == "gnmt" or - hparams.attention_architecture in ["gnmt", "gnmt_v2"]): - model_creator = gnmt_model.GNMTModel - elif hparams.attention_architecture == "standard": - model_creator = attention_model.AttentionModel - else: - raise ValueError("Unknown attention architecture %s" % - hparams.attention_architecture) - + # Create model + model_creator = get_model_creator(hparams) train_model = model_helper.create_train_model(model_creator, hparams, scope) eval_model = model_helper.create_eval_model(model_creator, hparams, scope) infer_model = model_helper.create_infer_model(model_creator, hparams, scope) @@ -381,7 +550,7 @@ def train(hparams, scope=None, target_session=""): last_stats_step = global_step is_overflow = process_stats( stats, info, global_step, steps_per_stats, log_f) - print_step_info(" ", global_step, info, _get_best_results(hparams), + print_step_info(" ", global_step, info, get_best_results(hparams), log_f) if is_overflow: break @@ -392,8 +561,7 @@ def train(hparams, scope=None, target_session=""): if global_step - last_eval_step >= steps_per_eval: last_eval_step = global_step utils.print_out("# Save eval, global step %d" % global_step) - utils.add_summary(summary_writer, global_step, "train_ppl", - info["train_ppl"]) + add_info_summaries(summary_writer, global_step, info) # Save checkpoint loaded_train_model.saver.save( @@ -482,7 +650,7 @@ def _format_results(name, ppl, scores, metrics): return result_str -def _get_best_results(hparams): +def get_best_results(hparams): """Summary of the current best results.""" tokens = [] for metric in hparams.metrics: @@ -514,7 +682,7 @@ def _sample_decode(model, global_step, sess, hparams, iterator, src_data, nmt_outputs, attention_summary = model.decode(sess) - if hparams.beam_width > 0: + if hparams.infer_mode == "beam_search": # get the top translation. 
nmt_outputs = nmt_outputs[0] @@ -558,7 +726,8 @@ def _external_eval(model, global_step, sess, hparams, iterator, subword_option=hparams.subword_option, beam_width=hparams.beam_width, tgt_eos=hparams.eos, - decode=decode) + decode=decode, + infer_mode=hparams.infer_mode) # Save on best metrics if decode: for metric in hparams.metrics: diff --git a/nmt/utils/common_test_utils.py b/nmt/utils/common_test_utils.py index f76fd3b10..68ff209f9 100644 --- a/nmt/utils/common_test_utils.py +++ b/nmt/utils/common_test_utils.py @@ -47,7 +47,7 @@ def create_test_hparams(unit_type="lstm", standard_hparams = standard_hparams_utils.create_standard_hparams() # Networks - standard_hparams.num_units = 5 + standard_hparams.num_units = 5 standard_hparams.num_encoder_layers = num_layers standard_hparams.num_decoder_layers = num_layers standard_hparams.dropout = 0.5 @@ -73,12 +73,13 @@ def create_test_hparams(unit_type="lstm", # Misc standard_hparams.forget_bias = 0.0 standard_hparams.random_seed = 3 + standard_hparams.language_model = False # Vocab standard_hparams.src_vocab_size = 5 standard_hparams.tgt_vocab_size = 5 - standard_hparams.eos = "eos" - standard_hparams.sos = "sos" + standard_hparams.eos = "" + standard_hparams.sos = "" standard_hparams.src_vocab_file = "" standard_hparams.tgt_vocab_file = "" standard_hparams.src_embed_file = "" diff --git a/nmt/utils/iterator_utils.py b/nmt/utils/iterator_utils.py index 623bf461a..31efb11ff 100644 --- a/nmt/utils/iterator_utils.py +++ b/nmt/utils/iterator_utils.py @@ -19,6 +19,9 @@ import tensorflow as tf +from ..utils import vocab_utils + + __all__ = ["BatchedInput", "get_iterator", "get_infer_iterator"] @@ -35,17 +38,34 @@ def get_infer_iterator(src_dataset, src_vocab_table, batch_size, eos, - src_max_len=None): - src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) + src_max_len=None, + use_char_encode=False): + if use_char_encode: + src_eos_id = vocab_utils.EOS_CHAR_ID + else: + src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) src_dataset = src_dataset.map(lambda src: tf.string_split([src]).values) if src_max_len: src_dataset = src_dataset.map(lambda src: src[:src_max_len]) - # Convert the word strings to ids - src_dataset = src_dataset.map( - lambda src: tf.cast(src_vocab_table.lookup(src), tf.int32)) + + if use_char_encode: + # Convert the word strings to character ids + src_dataset = src_dataset.map( + lambda src: tf.reshape(vocab_utils.tokens_to_bytes(src), [-1])) + else: + # Convert the word strings to ids + src_dataset = src_dataset.map( + lambda src: tf.cast(src_vocab_table.lookup(src), tf.int32)) + # Add in the word counts. 
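With use_char_encode, each source word becomes a fixed block of DEFAULT_CHAR_MAXLEN character ids, so the source length is recovered by integer division instead of tf.size. A pure-Python sketch of that bookkeeping (DEFAULT_CHAR_MAXLEN is 50 in vocab_utils, added later in this diff):

```python
DEFAULT_CHAR_MAXLEN = 50  # max number of char ids stored per word

def source_word_count(flat_char_ids):
    # The flattened source tensor holds num_words * DEFAULT_CHAR_MAXLEN ids.
    assert len(flat_char_ids) % DEFAULT_CHAR_MAXLEN == 0
    return len(flat_char_ids) // DEFAULT_CHAR_MAXLEN

print(source_word_count([0] * (3 * DEFAULT_CHAR_MAXLEN)))  # 3 words
```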
- src_dataset = src_dataset.map(lambda src: (src, tf.size(src))) + if use_char_encode: + src_dataset = src_dataset.map( + lambda src: (src, + tf.to_int32( + tf.size(src) / vocab_utils.DEFAULT_CHAR_MAXLEN))) + else: + src_dataset = src_dataset.map(lambda src: (src, tf.size(src))) def batching_func(x): return x.padded_batch( @@ -91,10 +111,16 @@ def get_iterator(src_dataset, skip_count=None, num_shards=1, shard_index=0, - reshuffle_each_iteration=True): + reshuffle_each_iteration=True, + use_char_encode=False): if not output_buffer_size: output_buffer_size = batch_size * 1000 - src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) + + if use_char_encode: + src_eos_id = vocab_utils.EOS_CHAR_ID + else: + src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) + tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32) tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32) @@ -124,12 +150,21 @@ def get_iterator(src_dataset, src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src, tgt[:tgt_max_len]), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) + # Convert the word strings to ids. Word strings that are not in the # vocab get the lookup table's default_value integer. - src_tgt_dataset = src_tgt_dataset.map( - lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32), - tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)), - num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) + if use_char_encode: + src_tgt_dataset = src_tgt_dataset.map( + lambda src, tgt: (tf.reshape(vocab_utils.tokens_to_bytes(src), [-1]), + tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)), + num_parallel_calls=num_parallel_calls) + else: + src_tgt_dataset = src_tgt_dataset.map( + lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32), + tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)), + num_parallel_calls=num_parallel_calls) + + src_tgt_dataset = src_tgt_dataset.prefetch(output_buffer_size) # Create a tgt_input prefixed with and a tgt_output suffixed with . src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src, @@ -137,10 +172,20 @@ def get_iterator(src_dataset, tf.concat((tgt, [tgt_eos_id]), 0)), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # Add in sequence lengths. - src_tgt_dataset = src_tgt_dataset.map( - lambda src, tgt_in, tgt_out: ( - src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)), - num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) + if use_char_encode: + src_tgt_dataset = src_tgt_dataset.map( + lambda src, tgt_in, tgt_out: ( + src, tgt_in, tgt_out, + tf.to_int32(tf.size(src) / vocab_utils.DEFAULT_CHAR_MAXLEN), + tf.size(tgt_in)), + num_parallel_calls=num_parallel_calls) + else: + src_tgt_dataset = src_tgt_dataset.map( + lambda src, tgt_in, tgt_out: ( + src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)), + num_parallel_calls=num_parallel_calls) + + src_tgt_dataset = src_tgt_dataset.prefetch(output_buffer_size) # Bucket by source sequence length (buckets for lengths 0-9, 10-19, ...) 
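The bucketing key_func itself is untouched by this diff, so the sketch below only illustrates the usual scheme the comment refers to, as an assumption rather than a copy of the repo's code: examples are keyed into width-10 length buckets, capped at the last bucket.

```python
def bucket_id(src_len, tgt_len, bucket_width=10, num_buckets=5):
    # Bucket by the longer of the two sequences, capped at the final bucket.
    bucket = max(src_len, tgt_len) // bucket_width
    return min(bucket, num_buckets - 1)

print(bucket_id(7, 12))   # 1 -> lengths 10-19
print(bucket_id(80, 3))   # 4 -> capped at the last bucket
```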
def batching_func(x): diff --git a/nmt/utils/misc_utils.py b/nmt/utils/misc_utils.py index a680a5cf2..63dc5a69c 100644 --- a/nmt/utils/misc_utils.py +++ b/nmt/utils/misc_utils.py @@ -23,6 +23,7 @@ import os import sys import time +from distutils import version import numpy as np import tensorflow as tf @@ -30,7 +31,8 @@ def check_tensorflow_version(): min_tf_version = "1.4.0-dev20171024" - if tf.__version__ < min_tf_version: + if (version.LooseVersion(tf.__version__) < + version.LooseVersion(min_tf_version)): raise EnvironmentError("Tensorflow version must >= %s" % min_tf_version) @@ -100,14 +102,10 @@ def load_hparams(model_dir): def maybe_parse_standard_hparams(hparams, hparams_path): """Override hparams values with existing standard hparams config.""" - if not hparams_path: - return hparams - - if tf.gfile.Exists(hparams_path): + if hparams_path and tf.gfile.Exists(hparams_path): print_out("# Loading standard hparams from %s" % hparams_path) - with tf.gfile.GFile(hparams_path, "r") as f: + with codecs.getreader("utf-8")(tf.gfile.GFile(hparams_path, "rb")) as f: hparams.parse_json(f.read()) - return hparams @@ -116,7 +114,7 @@ def save_hparams(out_dir, hparams): hparams_file = os.path.join(out_dir, "hparams") print_out(" saving hparams to %s" % hparams_file) with codecs.getwriter("utf-8")(tf.gfile.GFile(hparams_file, "wb")) as f: - f.write(hparams.to_json()) + f.write(hparams.to_json(indent=4, sort_keys=True)) def debug_tensor(s, msg=None, summarize=10): diff --git a/nmt/utils/nmt_utils.py b/nmt/utils/nmt_utils.py index 72f71b5c2..2115de942 100644 --- a/nmt/utils/nmt_utils.py +++ b/nmt/utils/nmt_utils.py @@ -37,11 +37,12 @@ def decode_and_evaluate(name, beam_width, tgt_eos, num_translations_per_input=1, - decode=True): + decode=True, + infer_mode="greedy"): """Decode a test set and compute a score according to the evaluation task.""" # Decode if decode: - utils.print_out(" decoding to output %s." % trans_file) + utils.print_out(" decoding to output %s" % trans_file) start_time = time.time() num_sentences = 0 @@ -49,12 +50,15 @@ def decode_and_evaluate(name, tf.gfile.GFile(trans_file, mode="wb")) as trans_f: trans_f.write("") # Write empty string to ensure file is created. 
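The switch in check_tensorflow_version (just above) from a plain string comparison to distutils' LooseVersion matters as soon as a version component has more than one digit:

```python
from distutils import version

# Lexicographic string comparison gets multi-digit components wrong ...
assert "1.12.0" < "1.4.0-dev20171024"
# ... while LooseVersion compares the numeric components correctly.
assert (version.LooseVersion("1.12.0") >
        version.LooseVersion("1.4.0-dev20171024"))
print("LooseVersion orders 1.12.0 after 1.4.0-dev20171024")
```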
- num_translations_per_input = max( - min(num_translations_per_input, beam_width), 1) + if infer_mode == "greedy": + num_translations_per_input = 1 + elif infer_mode == "beam_search": + num_translations_per_input = min(num_translations_per_input, beam_width) + while True: try: nmt_outputs, _ = model.decode(sess) - if beam_width == 0: + if infer_mode != "beam_search": nmt_outputs = np.expand_dims(nmt_outputs, 0) batch_size = nmt_outputs.shape[1] diff --git a/nmt/utils/standard_hparams_utils.py b/nmt/utils/standard_hparams_utils.py index 15f294a5d..c47a6f6b3 100644 --- a/nmt/utils/standard_hparams_utils.py +++ b/nmt/utils/standard_hparams_utils.py @@ -36,7 +36,6 @@ def create_standard_hparams(): # Networks num_units=512, - num_layers=2, num_encoder_layers=2, num_decoder_layers=2, dropout=0.2, @@ -45,6 +44,8 @@ def create_standard_hparams(): residual=False, time_major=True, num_embeddings_partitions=0, + num_enc_emb_partitions=0, + num_dec_emb_partitions=0, # Attention mechanisms attention="scaled_luong", @@ -64,6 +65,7 @@ def create_standard_hparams(): decay_scheme="luong234", colocate_gradients_with_ops=True, num_train_steps=12000, + num_sampled_softmax=0, # Data constraints num_buckets=5, @@ -77,6 +79,7 @@ def create_standard_hparams(): sos="", eos="", subword_option="", + use_char_encode=False, check_special_token=True, # Misc @@ -101,4 +104,8 @@ def create_standard_hparams(): infer_batch_size=32, sampling_temperature=0.0, num_translations_per_input=1, + infer_mode="greedy", + + # Language model + language_model=False, ) diff --git a/nmt/utils/vocab_utils.py b/nmt/utils/vocab_utils.py index d5de9a11d..5063bf2ef 100644 --- a/nmt/utils/vocab_utils.py +++ b/nmt/utils/vocab_utils.py @@ -27,12 +27,76 @@ from ..utils import misc_utils as utils - +# word level special token UNK = "" SOS = "" EOS = "" UNK_ID = 0 +# char ids 0-255 come from utf-8 encoding bytes +# assign 256-300 to special chars +BOS_CHAR_ID = 256 # +EOS_CHAR_ID = 257 # +BOW_CHAR_ID = 258 # +EOW_CHAR_ID = 259 # +PAD_CHAR_ID = 260 # + +DEFAULT_CHAR_MAXLEN = 50 # max number of chars for each word. + + +def _string_to_bytes(text, max_length): + """Given string and length, convert to byte seq of at most max_length. + + This process mimics docqa/elmo's preprocessing: + https://github.com/allenai/document-qa/blob/master/docqa/elmo/data.py + + Note that we make use of BOS_CHAR_ID and EOS_CHAR_ID in iterator_utils.py & + our usage differs from docqa/elmo. + + Args: + text: tf.string tensor of shape [] + max_length: max number of chars for each word. + + Returns: + A tf.int32 tensor of the byte encoded text. + """ + byte_ids = tf.to_int32(tf.decode_raw(text, tf.uint8)) + byte_ids = byte_ids[:max_length - 2] + padding = tf.fill([max_length - tf.shape(byte_ids)[0] - 2], PAD_CHAR_ID) + byte_ids = tf.concat( + [[BOW_CHAR_ID], byte_ids, [EOW_CHAR_ID], padding], axis=0) + tf.logging.info(byte_ids) + + byte_ids = tf.reshape(byte_ids, [max_length]) + tf.logging.info(byte_ids.get_shape().as_list()) + return byte_ids + 1 + + +def tokens_to_bytes(tokens): + """Given a sequence of strings, map to sequence of bytes. + + Args: + tokens: A tf.string tensor + + Returns: + A tensor of shape words.shape + [bytes_per_word] containing byte versions + of each word. 
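A pure-Python rendering of what _string_to_bytes above does to a single word, handy for checking the layout without building a graph; the constants match the ones added to vocab_utils in this diff:

```python
BOW_CHAR_ID, EOW_CHAR_ID, PAD_CHAR_ID = 258, 259, 260
DEFAULT_CHAR_MAXLEN = 50

def word_to_char_ids(word, max_length=DEFAULT_CHAR_MAXLEN):
    # UTF-8 bytes, truncated to leave room for the begin/end-of-word markers,
    # padded to a fixed width, then shifted by +1 exactly as in the graph code.
    byte_ids = list(word.encode("utf-8"))[:max_length - 2]
    padding = [PAD_CHAR_ID] * (max_length - len(byte_ids) - 2)
    ids = [BOW_CHAR_ID] + byte_ids + [EOW_CHAR_ID] + padding
    return [i + 1 for i in ids]

ids = word_to_char_ids("cafe")
print(len(ids), ids[:7])   # 50 [259, 100, 98, 103, 102, 260, 261]
```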
+ """ + bytes_per_word = DEFAULT_CHAR_MAXLEN + with tf.device("/cpu:0"): + tf.assert_rank(tokens, 1) + shape = tf.shape(tokens) + tf.logging.info(tokens) + tokens_flat = tf.reshape(tokens, [-1]) + as_bytes_flat = tf.map_fn( + fn=lambda x: _string_to_bytes(x, max_length=bytes_per_word), + elems=tokens_flat, + dtype=tf.int32, + back_prop=False) + tf.logging.info(as_bytes_flat) + as_bytes = tf.reshape(as_bytes_flat, [shape[0], bytes_per_word]) + return as_bytes + def load_vocab(vocab_file): vocab = [] @@ -91,13 +155,15 @@ def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab): def load_embed_txt(embed_file): """Load embed_file into a python dictionary. - Note: the embed_file should be a Glove formated txt file. Assuming - embed_size=5, for example: + Note: the embed_file should be a Glove/word2vec formatted txt file. Assuming + Here is an exampe assuming embed_size=5: the -0.071549 0.093459 0.023738 -0.090339 0.056123 to 0.57346 0.5417 -0.23477 -0.3624 0.4037 and 0.20327 0.47348 0.050877 0.002103 0.060547 + For word2vec format, the first line will be: . + Args: embed_file: file path to the embedding file. Returns: @@ -105,14 +171,24 @@ def load_embed_txt(embed_file): """ emb_dict = dict() emb_size = None - with codecs.getreader("utf-8")(tf.gfile.GFile(embed_file, 'rb')) as f: + + is_first_line = True + with codecs.getreader("utf-8")(tf.gfile.GFile(embed_file, "rb")) as f: for line in f: - tokens = line.strip().split(" ") + tokens = line.rstrip().split(" ") + if is_first_line: + is_first_line = False + if len(tokens) == 2: # header line + emb_size = int(tokens[1]) + continue word = tokens[0] vec = list(map(float, tokens[1:])) emb_dict[word] = vec if emb_size: - assert emb_size == len(vec), "All embedding size should be same." + if emb_size != len(vec): + utils.print_out( + "Ignoring %s since embeding size is inconsistent." % word) + del emb_dict[word] else: emb_size = len(vec) return emb_dict, emb_size
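To close this section, a stand-alone sketch of the behaviour load_embed_txt now has: skip an optional word2vec-style header line ("<vocab_size> <dim>") and drop rows whose vector width disagrees with the established embedding size. This simplifies the code above rather than reproducing it:

```python
def load_embed_lines(lines):
    emb, emb_size = {}, None
    for i, line in enumerate(lines):
        tokens = line.rstrip().split(" ")
        if i == 0 and len(tokens) == 2:   # word2vec header: "<vocab_size> <dim>"
            emb_size = int(tokens[1])
            continue
        word, vec = tokens[0], [float(t) for t in tokens[1:]]
        if emb_size is not None and emb_size != len(vec):
            continue                      # inconsistent row, ignore it
        emb[word] = vec
        emb_size = emb_size if emb_size is not None else len(vec)
    return emb, emb_size

emb, dim = load_embed_lines(["2 3", "the 0.1 0.2 0.3", "to 0.4 0.5 0.6"])
print(dim, sorted(emb))                   # 3 ['the', 'to']
```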