From b2ad3bf5f3ae9cb207280fe1023811c262cdc5e6 Mon Sep 17 00:00:00 2001 From: Lorenz Diener Date: Wed, 28 Oct 2015 10:29:49 +0100 Subject: [PATCH] Allow usage of any function as log printer --- brainstorm/hooks.py | 90 +++++++++++++++++++++------------- brainstorm/training/trainer.py | 35 ++++++++----- 2 files changed, 78 insertions(+), 47 deletions(-) diff --git a/brainstorm/hooks.py b/brainstorm/hooks.py index 6399fa3..7e07f72 100644 --- a/brainstorm/hooks.py +++ b/brainstorm/hooks.py @@ -29,14 +29,16 @@ class Hook(Describable): 'verbose': None } - def __init__(self, name=None, timescale='epoch', interval=1, verbose=None): + def __init__(self, name=None, timescale='epoch', interval=1, verbose=None, + logging_function=print): self.timescale = timescale self.interval = interval self.__name__ = name or self.__class__.__name__ self.priority = 0 self.verbose = verbose self.run_verbosity = None - + self.logging_function = logging_function + def start(self, net, stepper, verbose, named_data_iters): if self.verbose is None: self.run_verbosity = verbose @@ -46,7 +48,7 @@ def start(self, net, stepper, verbose, named_data_iters): def message(self, msg): """Print an output message if :attr:`run_verbosity` is True.""" if self.run_verbosity: - print("{} >> {}".format(self.__name__, msg)) + self.logging_function("{} >> {}".format(self.__name__, msg)) def __call__(self, epoch_nr, update_nr, net, stepper, logs): pass @@ -59,8 +61,10 @@ class SaveNetwork(Hook): the timescale and interval parameters. """ - def __init__(self, filename, name=None, timescale='epoch', interval=1): - super(SaveNetwork, self).__init__(name, timescale, interval) + def __init__(self, filename, name=None, timescale='epoch', interval=1, + logging_function=print): + super(SaveNetwork, self).__init__(name, timescale, interval, + logging_function) self.filename = filename def __call__(self, epoch_nr, update_nr, net, stepper, logs): @@ -79,9 +83,10 @@ class SaveBestNetwork(Hook): __default_values__ = {'filename': None} def __init__(self, log_name, filename=None, name=None, - criterion='max', timescale='epoch', interval=1, verbose=None): - super(SaveBestNetwork, self).__init__(name, timescale, - interval, verbose) + criterion='max', timescale='epoch', interval=1, verbose=None, + logging_function=print): + super(SaveBestNetwork, self).__init__(name, timescale, interval, + verbose, logging_function) self.log_name = log_name self.filename = filename self.parameters = None @@ -121,8 +126,10 @@ def load_parameters(self): class SaveLogs(Hook): - def __init__(self, filename, name=None, timescale='epoch', interval=1): - super(SaveLogs, self).__init__(name, timescale, interval) + def __init__(self, filename, name=None, timescale='epoch', interval=1, + logging_function=print): + super(SaveLogs, self).__init__(name, timescale, interval, + logging_function=logging_function) self.filename = filename def __call__(self, epoch_nr, update_nr, net, stepper, logs): @@ -144,9 +151,11 @@ def _save_recursively(group, logs): class ModifyStepperAttribute(Hook): """Modify an attribute of the training stepper.""" def __init__(self, schedule, attr_name='learning_rate', - timescale='epoch', interval=1, name=None, verbose=None): + timescale='epoch', interval=1, name=None, verbose=None, + logging_function=print): super(ModifyStepperAttribute, self).__init__(name, timescale, - interval, verbose) + interval, verbose, + logging_function) self.schedule = schedule self.attr_name = attr_name @@ -168,11 +177,12 @@ class MonitorLayerParameters(Hook): Monitor some properties of a layer. """ def __init__(self, layer_name, timescale='epoch', - interval=1, name=None, verbose=None): + interval=1, name=None, verbose=None, logging_function=print): if name is None: name = "MonitorParameters_{}".format(layer_name) super(MonitorLayerParameters, self).__init__(name, timescale, - interval, verbose) + interval, verbose, + logging_function) self.layer_name = layer_name def start(self, net, stepper, verbose, named_data_iters): @@ -202,11 +212,12 @@ class MonitorLayerGradients(Hook): Monitor some statistics about all the gradients of a layer. """ def __init__(self, layer_name, timescale='epoch', - interval=1, name=None, verbose=None): + interval=1, name=None, verbose=None, logging_function=print): if name is None: name = "MonitorGradients_{}".format(layer_name) super(MonitorLayerGradients, self).__init__(name, timescale, - interval, verbose) + interval, verbose, + logging_function) self.layer_name = layer_name def start(self, net, stepper, verbose, named_data_iters): @@ -230,7 +241,7 @@ class MonitorLayerDeltas(Hook): Monitor some statistics about all the deltas of a layer. """ def __init__(self, layer_name, timescale='epoch', - interval=1, name=None, verbose=None): + interval=1, name=None, verbose=None, logging_function=print): if name is None: name = "MonitorDeltas_{}".format(layer_name) super(MonitorLayerDeltas, self).__init__(name, timescale, @@ -275,11 +286,12 @@ class MonitorLayerInOuts(Hook): Monitor some statistics about all the inputs and outputs of a layer. """ def __init__(self, layer_name, timescale='epoch', - interval=1, name=None, verbose=None): + interval=1, name=None, verbose=None, logging_function=print): if name is None: name = "MonitorInOuts_{}".format(layer_name) super(MonitorLayerInOuts, self).__init__(name, timescale, - interval, verbose) + interval, verbose, + logging_function) self.layer_name = layer_name def start(self, net, stepper, verbose, named_data_iters): @@ -310,9 +322,10 @@ def __call__(self, epoch_nr, update_nr, net, stepper, logs): class StopAfterEpoch(Hook): def __init__(self, max_epochs, timescale='epoch', interval=1, name=None, - verbose=None): + verbose=None, logging_function=print): super(StopAfterEpoch, self).__init__(name, timescale, - interval, verbose) + interval, verbose, + logging_function) self.max_epochs = max_epochs def __call__(self, epoch_nr, update_nr, net, stepper, logs): @@ -325,8 +338,9 @@ def __call__(self, epoch_nr, update_nr, net, stepper, logs): class EarlyStopper(Hook): __default_values__ = {'patience': 1} - def __init__(self, log_name, patience=1, name=None): - super(EarlyStopper, self).__init__(name, 'epoch', 1) + def __init__(self, log_name, patience=1, name=None, logging_function=print): + super(EarlyStopper, self).__init__(name, 'epoch', 1, + logging_function=logging_function) self.log_name = log_name self.patience = patience @@ -346,8 +360,9 @@ class StopOnNan(Hook): """ def __init__(self, logs_to_check=(), check_parameters=True, check_training_loss=True, name=None, timescale='epoch', - interval=1): - super(StopOnNan, self).__init__(name, timescale, interval) + interval=1, logging_function=print): + super(StopOnNan, self).__init__(name, timescale, interval, + logging_function=logging_function) self.logs_to_check = ([logs_to_check] if isinstance(logs_to_check, string_types) else logs_to_check) @@ -378,8 +393,10 @@ def __call__(self, epoch_nr, update_nr, net, stepper, logs): class InfoUpdater(Hook): """ Save the information from logs to the Sacred custom info dict""" - def __init__(self, run, name=None, timescale='epoch', interval=1): - super(InfoUpdater, self).__init__(name, timescale, interval) + def __init__(self, run, name=None, timescale='epoch', interval=1, + logging_function=print): + super(InfoUpdater, self).__init__(name, timescale, interval, + logging_function=logging_function) self.run = run def __call__(self, epoch_nr, update_nr, net, stepper, logs): @@ -393,8 +410,9 @@ def __call__(self, epoch_nr, update_nr, net, stepper, logs): class MonitorLoss(Hook): def __init__(self, iter_name, timescale='epoch', interval=1, name=None, - verbose=None): - super(MonitorLoss, self).__init__(name, timescale, interval, verbose) + verbose=None, logging_function=print): + super(MonitorLoss, self).__init__(name, timescale, interval, verbose, + logging_function=logging_function) self.iter_name = iter_name self.iter = None @@ -439,10 +457,10 @@ class MonitorScores(Hook): """ def __init__(self, iter_name, scorers, timescale='epoch', interval=1, - name=None, verbose=None): + name=None, verbose=None, logging_function=print): super(MonitorScores, self).__init__(name, timescale, interval, - verbose) + verbose, logging_function) self.iter_name = iter_name self.iter = None self.scorers = scorers @@ -466,7 +484,8 @@ class StopOnSigQuit(Hook): """ __undescribed__ = {'quit': False} - def __init__(self, name=None, timescale='epoch', interval=1, verbose=None): + def __init__(self, name=None, timescale='epoch', interval=1, verbose=None, + logging_function=print): super(StopOnSigQuit, self).__init__(name, timescale, interval, verbose=verbose) self.quit = False @@ -527,10 +546,11 @@ class BokehVisualizer(Hook): and acts as a fallback verbosity for the used data iterator. If not set it defaults to the verbosity setting of the trainer. """ - def __init__(self, log_names, filename=None, timescale='epoch', interval=1, - name=None, verbose=None): + def __init__(self, log_names, filename=None, timescale='epoch', + interval=1, name=None, verbose=None, + logging_function=print): super(BokehVisualizer, self).__init__(name, timescale, interval, - verbose) + verbose, logging_function) if isinstance(log_names, basestring): self.log_names = [log_names] diff --git a/brainstorm/training/trainer.py b/brainstorm/training/trainer.py index 8467346..e5bd2e8 100644 --- a/brainstorm/training/trainer.py +++ b/brainstorm/training/trainer.py @@ -26,7 +26,7 @@ class Trainer(Describable): } __default_values__ = {'verbose': True} - def __init__(self, stepper, verbose=True): + def __init__(self, stepper, verbose=True, logging_function=print): """Create a new Trainer. Args: @@ -41,6 +41,7 @@ def __init__(self, stepper, verbose=True): self.current_update_nr = 0 self.logs = {} self.results = {} + self.logging_function = logging_function def add_hook(self, hook): """Add a hook to this trainer. @@ -70,7 +71,11 @@ def train(self, net, training_data_iter, **named_data_iters): iterators. """ if self.verbose: - print('\n\n', 10 * '- ', "Before Training", 10 * ' -') + if self.logging_function == print: + self.logging_function('\n\n' + 10 * '- ' + "Before Training" + + 10 * ' -') + else: + self.logging_function(10 * '- ' + "Before Training" + 10 * ' -') assert set(training_data_iter.data_shapes.keys()) == set( net.buffer.Input.outputs.keys()), \ "The data names provided by the training data iterator {} do not "\ @@ -91,8 +96,12 @@ def train(self, net, training_data_iter, **named_data_iters): train_scores.update({n: [] for n in net.get_loss_values()}) if self.verbose: - print('\n\n', 12 * '- ', "Epoch", self.current_epoch_nr, - 12 * ' -') + if self.logging_function == print: + self.logging_function('\n\n' + 12 * '- ' + "Epoch" + + str(self.current_epoch_nr) + 12 * ' -') + else: + self.logging_function(12 * '- ' + "Epoch" + + str(self.current_epoch_nr) + 12 * ' -') iterator = training_data_iter(handler=net.handler) for _ in run_network(net, iterator): self.current_update_nr += 1 @@ -133,8 +142,9 @@ def _start_hooks(self, net, named_data_iters): hook.start(net, self.stepper, self.verbose, named_data_iters) except Exception: - print('An error occurred while starting the "{}" hook:' - .format(name), file=sys.stderr) + self.logging_function( + 'An error occurred while starting the "{}" hook:' + .format(name), file=sys.stderr) raise def _emit_hooks(self, net, timescale, logs=None): @@ -163,9 +173,10 @@ def _call_hook(self, hook, net): except StopIteration as err: return getattr(err, 'value', None), True except Exception as e: - print('An error occurred while calling the "{}" hook:' - .format(hook.__name__), file=sys.stderr) - print(traceback.format_exc()) + self.logging_function( + 'An error occurred while calling the "{}" hook:' + .format(hook.__name__), file=sys.stderr) + self.logging_function(traceback.format_exc()) raise e def _add_log(self, name, val, verbose=None, logs=None, indent=0): @@ -178,14 +189,14 @@ def _add_log(self, name, val, verbose=None, logs=None, indent=0): if isinstance(val, dict): if verbose: - print(" " * indent + name) + self.logging_function(" " * indent + name) logs[name] = dict() if name not in logs else logs[name] for k, v in val.items(): self._add_log(k, v, verbose, logs[name], indent + 2) else: if verbose: - print(" " * indent + ("{0:%d}: {1}" % (40 - indent)) - .format(name, val)) + self.logging_function(" " * indent + ("{0:%d}: {1}" % + (40 - indent)).format(name, val)) logs[name] = [] if name not in logs else logs[name] logs[name].append(val)