
Commit 867f98c

Hook for reporting and recording variables
* Add a --num_report_steps option to specify the reporting frequency.

  Currently the `global_step/sec` and `examples/sec` values get displayed (and recorded via the summary writer) after every step. The `--num_report_steps=N` option lets the user specify the frequency (i.e. every N steps) with which this information is displayed and recorded.

* Enable periodic printing and recording of throughput and loss.

  This commit adds the ability to report (i.e. display to stdout) the following information on a periodic basis:

  * step number
  * throughput
  * total_loss
  * mlm_loss
  * nsp_loss
  * learning_rate

  The reporting frequency is specified via the `--num_report_steps` option. Currently only the `throughput` and `total_loss` values get recorded (to the trace events file meant for TensorBoard consumption). Note that `throughput` is the same as `examples/sec` and `total_loss` is the same as `loss`, both of which are already reported and recorded via the `TPUEstimator` implementation.

  The `LogSessionRunHook` class is based on a similar class in the NVBERT implementation. It can easily be enhanced to report and record other variables.

* Disable log messages from being printed twice.

  Currently every message output via `tf.compat.v1.logging.info` gets printed twice. For example:

  ```
  INFO:tensorflow:**** Trainable Variables ****
  I0610 12:40:23.553335 139787876316928 run_pretraining.py:256] **** Trainable Variables ****
  ```

  Setting the `propagate` flag of the logger to `False` prevents this. For the above example, only one line will be printed:

  ```
  INFO:tensorflow:**** Trainable Variables ****
  ```

  This makes the output log file more readable.
1 parent bc6324f commit 867f98c
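As a standalone illustration of the double-logging fix described above, here is a minimal sketch, assuming a TF 1.x-compatible runtime; it reuses the log line from the commit message as its example:

```python
import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

# tf.compat.v1.logging routes messages through the Python 'tensorflow'
# logger, which by default also propagates each record up to the root
# logger's handler, so every message is emitted twice.
tf.compat.v1.get_logger().propagate = False

# With propagation disabled, a single INFO line is printed
# (per the commit message: "INFO:tensorflow:**** Trainable Variables ****").
tf.compat.v1.logging.info("**** Trainable Variables ****")
```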

File tree

2 files changed: +79 −8 lines


optimization.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -153,7 +153,7 @@ def __init__(self,
     """Constructs a AdamWeightDecayOptimizer."""
     super(AdamWeightDecayOptimizer, self).__init__(False, name)
 
-    self.learning_rate = learning_rate
+    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
     self.weight_decay_rate = weight_decay_rate
     self.beta_1 = beta_1
     self.beta_2 = beta_2
@@ -253,7 +253,7 @@ def __init__(self,
     """Constructs a LAMBOptimizer."""
     super(LAMBOptimizer, self).__init__(False, name)
 
-    self.learning_rate = learning_rate
+    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
     self.weight_decay_rate = weight_decay_rate
     self.beta_1 = beta_1
     self.beta_2 = beta_2
@@ -369,7 +369,7 @@ def __init__(self,
     """Constructs a NadamWeightDecayOptimizer."""
     super(NadamWeightDecayOptimizer, self).__init__(False, name)
 
-    self.learning_rate = learning_rate
+    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
     self.weight_decay_rate = weight_decay_rate
     self.beta_1 = beta_1
     self.beta_2 = beta_2
@@ -481,7 +481,7 @@ def __init__(self,
     """Constructs a NlamOptimizer."""
     super(NlambOptimizer, self).__init__(False, name)
 
-    self.learning_rate = learning_rate
+    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
     self.weight_decay_rate = weight_decay_rate
     self.beta_1 = beta_1
     self.beta_2 = beta_2
```
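These changes exist so that the learning rate has a predictable graph name: the `LogSessionRunHook` added in run_pretraining.py below fetches it by the string `'learning_rate:0'`. A minimal sketch of that naming trick, assuming TF 1.x graph mode (the constant 0.01 is a placeholder value, not from the commit):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# tf.identity pins a stable name on the tensor, so it can be looked up
# later as 'learning_rate:0' without holding a Python reference to it.
lr = tf.identity(0.01, name='learning_rate')

with tf.compat.v1.Session() as sess:
    # Fetching by string name is exactly what the hook's before_run()
    # does via tf.estimator.SessionRunArgs.
    print(sess.run('learning_rate:0'))  # -> 0.01
```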

run_pretraining.py

Lines changed: 75 additions & 4 deletions
```diff
@@ -24,6 +24,10 @@
 import tensorflow as tf
 tf.compat.v1.disable_resource_variables()
 
+import time
+from tensorflow.python.training.summary_io import SummaryWriterCache
+from tensorflow.core.framework.summary_pb2 import Summary
+
 # Add Horovod to run_pretraining
 try:
   import horovod.tensorflow as hvd
@@ -116,6 +120,63 @@
 
 flags.DEFINE_string("optimizer_type", "adam", "Optimizer used for training - adam (default), lamb, nadam and nlamb")
 
+flags.DEFINE_integer(
+    "num_report_steps", 10,
+    "How frequently should summary information be reported and recorded.")
+
+
+class LogSessionRunHook(tf.estimator.SessionRunHook):
+
+  def __init__(self,
+               global_batch_size,
+               num_report_steps=10,
+               output_dir=None):
+    self.global_batch_size = global_batch_size
+    self.num_report_steps = num_report_steps
+    self.output_dir = output_dir
+    self.summary_writer = None
+
+  def begin(self):
+    if self.summary_writer is None and self.output_dir:
+      self.summary_writer = SummaryWriterCache.get(self.output_dir)
+
+  def after_create_session(self, session, coord):
+    self.elapsed_secs = 0.
+    self.count = 0
+
+  def before_run(self, run_context):
+    self.t0 = time.time()
+    global_step = tf.compat.v1.train.get_global_step()
+    fetches = [global_step, 'learning_rate:0', 'total_loss:0', 'mlm_loss:0', 'nsp_loss:0']
+    return tf.estimator.SessionRunArgs(fetches=fetches)
+
+  def _log_and_record(self, global_step, learning_rate, total_loss, mlm_loss, nsp_loss):
+    time_per_step = self.elapsed_secs / self.count
+    throughput = self.global_batch_size / time_per_step
+    log_string = ' '
+    log_string += 'Step = %6i' % (global_step)
+    log_string += ', throughput = %6.1f' % (throughput)
+    log_string += ', total_loss = %6.3f' % (total_loss)
+    log_string += ', mlm_loss = %6.4e' % (mlm_loss)
+    log_string += ', nsp_loss = %6.4e' % (nsp_loss)
+    log_string += ', learning_rate = %6.4e' % (learning_rate)
+    tf.compat.v1.logging.info(log_string)
+
+    if self.summary_writer is not None:
+      throughput_summary = Summary(value=[Summary.Value(tag='throughput', simple_value=throughput)])
+      self.summary_writer.add_summary(throughput_summary, global_step)
+      total_loss_summary = Summary(value=[Summary.Value(tag='total_loss', simple_value=total_loss)])
+      self.summary_writer.add_summary(total_loss_summary, global_step)
+
+  def after_run(self, run_context, run_values):
+    self.elapsed_secs += time.time() - self.t0
+    self.count += 1
+    global_step, learning_rate, total_loss, mlm_loss, nsp_loss = run_values.results[0:5]
+    if (global_step % self.num_report_steps) == 0:
+      self._log_and_record(global_step, learning_rate, total_loss, mlm_loss, nsp_loss)
+      self.elapsed_secs = 0.
+      self.count = 0
+
 
 def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                      num_train_steps, num_warmup_steps, use_tpu,
@@ -158,6 +219,10 @@ def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
 
     total_loss = masked_lm_loss + next_sentence_loss
 
+    masked_lm_loss = tf.identity(masked_lm_loss, name='mlm_loss')
+    next_sentence_loss = tf.identity(next_sentence_loss, name='nsp_loss')
+    total_loss = tf.identity(total_loss, name='total_loss')
+
     tvars = tf.compat.v1.trainable_variables()
 
     initialized_variable_names = {}
@@ -416,7 +481,10 @@ def _decode_record(record, name_to_features):
 
 def main(_):
   tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
-
+
+  # disable the log messages from being printed twice
+  tf.compat.v1.get_logger().propagate = False
+
   use_hvd = False
   if FLAGS.use_horovod and hvd != None:
     use_hvd = True
@@ -468,7 +536,7 @@ def main(_):
           iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host),
-      log_step_count_steps=25,
+      log_step_count_steps=FLAGS.num_report_steps * FLAGS.iterations_per_loop,
       session_config=config)
 
   model_fn = model_fn_builder(
@@ -499,10 +567,13 @@ def main(_):
         max_predictions_per_seq=FLAGS.max_predictions_per_seq,
         is_training=True)
 
-    hooks = None
+    hooks = []
+    if (not use_hvd) or (hvd.rank() == 0):
+      global_batch_size = FLAGS.train_batch_size if not use_hvd else FLAGS.train_batch_size * hvd.size()
+      hooks.append(LogSessionRunHook(global_batch_size, FLAGS.num_report_steps, FLAGS.output_dir))
     if use_hvd:
      # [HVD] Ensure all GPU's start with the same weights.
-      hooks = [hvd.BroadcastGlobalVariablesHook(0)]
+      hooks.append(hvd.BroadcastGlobalVariablesHook(0))
     estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps, hooks=hooks)
 
   if FLAGS.do_eval:
```
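For reference, the recording half of `LogSessionRunHook` writes raw `Summary` protos through `SummaryWriterCache`, so the scalars land in the same events files TensorBoard reads. A minimal sketch of that mechanism, with a hypothetical log directory and a made-up throughput value:

```python
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.training.summary_io import SummaryWriterCache

# Hypothetical output directory; in the hook, FLAGS.output_dir plays
# this role and the writer is shared via the cache.
writer = SummaryWriterCache.get('/tmp/pretraining_output')

# Record one scalar point against a global step, mirroring what
# _log_and_record() does for the 'throughput' and 'total_loss' tags.
summary = Summary(value=[Summary.Value(tag='throughput', simple_value=512.0)])
writer.add_summary(summary, global_step=100)
writer.flush()
```

TensorBoard pointed at the same directory then plots these tags alongside the summaries the estimator already writes.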
