90 changes: 55 additions & 35 deletions brainstorm/hooks.py
@@ -29,14 +29,16 @@ class Hook(Describable):
'verbose': None
}

def __init__(self, name=None, timescale='epoch', interval=1, verbose=None):
def __init__(self, name=None, timescale='epoch', interval=1, verbose=None,
logging_function=print):
self.timescale = timescale
self.interval = interval
self.__name__ = name or self.__class__.__name__
self.priority = 0
self.verbose = verbose
self.run_verbosity = None

self.logging_function = logging_function

def start(self, net, stepper, verbose, named_data_iters):
if self.verbose is None:
self.run_verbosity = verbose
@@ -46,7 +48,7 @@ def start(self, net, stepper, verbose, named_data_iters):
def message(self, msg):
"""Print an output message if :attr:`run_verbosity` is True."""
if self.run_verbosity:
print("{} >> {}".format(self.__name__, msg))
self.logging_function("{} >> {}".format(self.__name__, msg))

def __call__(self, epoch_nr, update_nr, net, stepper, logs):
pass
@@ -59,8 +61,10 @@ class SaveNetwork(Hook):
the timescale and interval parameters.
"""

def __init__(self, filename, name=None, timescale='epoch', interval=1):
super(SaveNetwork, self).__init__(name, timescale, interval)
def __init__(self, filename, name=None, timescale='epoch', interval=1,
logging_function=print):
super(SaveNetwork, self).__init__(name, timescale, interval,
logging_function)
self.filename = filename

def __call__(self, epoch_nr, update_nr, net, stepper, logs):
@@ -79,9 +83,10 @@ class SaveBestNetwork(Hook):
__default_values__ = {'filename': None}

def __init__(self, log_name, filename=None, name=None,
criterion='max', timescale='epoch', interval=1, verbose=None):
super(SaveBestNetwork, self).__init__(name, timescale,
interval, verbose)
criterion='max', timescale='epoch', interval=1, verbose=None,
logging_function=print):
super(SaveBestNetwork, self).__init__(name, timescale, interval,
verbose, logging_function)
self.log_name = log_name
self.filename = filename
self.parameters = None
@@ -121,8 +126,10 @@ def load_parameters(self):


class SaveLogs(Hook):
def __init__(self, filename, name=None, timescale='epoch', interval=1):
super(SaveLogs, self).__init__(name, timescale, interval)
def __init__(self, filename, name=None, timescale='epoch', interval=1,
logging_function=print):
super(SaveLogs, self).__init__(name, timescale, interval,
logging_function=logging_function)
self.filename = filename

def __call__(self, epoch_nr, update_nr, net, stepper, logs):
@@ -144,9 +151,11 @@ def _save_recursively(group, logs):
class ModifyStepperAttribute(Hook):
"""Modify an attribute of the training stepper."""
def __init__(self, schedule, attr_name='learning_rate',
timescale='epoch', interval=1, name=None, verbose=None):
timescale='epoch', interval=1, name=None, verbose=None,
logging_function=print):
super(ModifyStepperAttribute, self).__init__(name, timescale,
interval, verbose)
interval, verbose,
logging_function)
self.schedule = schedule
self.attr_name = attr_name

@@ -168,11 +177,12 @@ class MonitorLayerParameters(Hook):
Monitor some properties of a layer.
"""
def __init__(self, layer_name, timescale='epoch',
interval=1, name=None, verbose=None):
interval=1, name=None, verbose=None, logging_function=print):
if name is None:
name = "MonitorParameters_{}".format(layer_name)
super(MonitorLayerParameters, self).__init__(name, timescale,
interval, verbose)
interval, verbose,
logging_function)
self.layer_name = layer_name

def start(self, net, stepper, verbose, named_data_iters):
@@ -202,11 +212,12 @@ class MonitorLayerGradients(Hook):
Monitor some statistics about all the gradients of a layer.
"""
def __init__(self, layer_name, timescale='epoch',
interval=1, name=None, verbose=None):
interval=1, name=None, verbose=None, logging_function=print):
if name is None:
name = "MonitorGradients_{}".format(layer_name)
super(MonitorLayerGradients, self).__init__(name, timescale,
interval, verbose)
interval, verbose,
logging_function)
self.layer_name = layer_name

def start(self, net, stepper, verbose, named_data_iters):
@@ -230,7 +241,7 @@ class MonitorLayerDeltas(Hook):
Monitor some statistics about all the deltas of a layer.
"""
def __init__(self, layer_name, timescale='epoch',
interval=1, name=None, verbose=None):
interval=1, name=None, verbose=None, logging_function=print):
if name is None:
name = "MonitorDeltas_{}".format(layer_name)
super(MonitorLayerDeltas, self).__init__(name, timescale,
@@ -275,11 +286,12 @@ class MonitorLayerInOuts(Hook):
Monitor some statistics about all the inputs and outputs of a layer.
"""
def __init__(self, layer_name, timescale='epoch',
interval=1, name=None, verbose=None):
interval=1, name=None, verbose=None, logging_function=print):
if name is None:
name = "MonitorInOuts_{}".format(layer_name)
super(MonitorLayerInOuts, self).__init__(name, timescale,
interval, verbose)
interval, verbose,
logging_function)
self.layer_name = layer_name

def start(self, net, stepper, verbose, named_data_iters):
@@ -310,9 +322,10 @@ def __call__(self, epoch_nr, update_nr, net, stepper, logs):

class StopAfterEpoch(Hook):
def __init__(self, max_epochs, timescale='epoch', interval=1, name=None,
verbose=None):
verbose=None, logging_function=print):
super(StopAfterEpoch, self).__init__(name, timescale,
interval, verbose)
interval, verbose,
logging_function)
self.max_epochs = max_epochs

def __call__(self, epoch_nr, update_nr, net, stepper, logs):
@@ -325,8 +338,9 @@ def __call__(self, epoch_nr, update_nr, net, stepper, logs):
class EarlyStopper(Hook):
__default_values__ = {'patience': 1}

def __init__(self, log_name, patience=1, name=None):
super(EarlyStopper, self).__init__(name, 'epoch', 1)
def __init__(self, log_name, patience=1, name=None, logging_function=print):
super(EarlyStopper, self).__init__(name, 'epoch', 1,
logging_function=logging_function)
self.log_name = log_name
self.patience = patience

@@ -346,8 +360,9 @@ class StopOnNan(Hook):
"""
def __init__(self, logs_to_check=(), check_parameters=True,
check_training_loss=True, name=None, timescale='epoch',
interval=1):
super(StopOnNan, self).__init__(name, timescale, interval)
interval=1, logging_function=print):
super(StopOnNan, self).__init__(name, timescale, interval,
logging_function=logging_function)
self.logs_to_check = ([logs_to_check] if isinstance(logs_to_check,
string_types)
else logs_to_check)
@@ -378,8 +393,10 @@ def __call__(self, epoch_nr, update_nr, net, stepper, logs):

class InfoUpdater(Hook):
""" Save the information from logs to the Sacred custom info dict"""
def __init__(self, run, name=None, timescale='epoch', interval=1):
super(InfoUpdater, self).__init__(name, timescale, interval)
def __init__(self, run, name=None, timescale='epoch', interval=1,
logging_function=print):
super(InfoUpdater, self).__init__(name, timescale, interval,
logging_function=logging_function)
self.run = run

def __call__(self, epoch_nr, update_nr, net, stepper, logs):
@@ -393,8 +410,9 @@ def __call__(self, epoch_nr, update_nr, net, stepper, logs):

class MonitorLoss(Hook):
def __init__(self, iter_name, timescale='epoch', interval=1, name=None,
verbose=None):
super(MonitorLoss, self).__init__(name, timescale, interval, verbose)
verbose=None, logging_function=print):
super(MonitorLoss, self).__init__(name, timescale, interval, verbose,
logging_function=logging_function)
self.iter_name = iter_name
self.iter = None

@@ -439,10 +457,10 @@ class MonitorScores(Hook):

"""
def __init__(self, iter_name, scorers, timescale='epoch', interval=1,
name=None, verbose=None):
name=None, verbose=None, logging_function=print):

super(MonitorScores, self).__init__(name, timescale, interval,
verbose)
verbose, logging_function)
self.iter_name = iter_name
self.iter = None
self.scorers = scorers
@@ -466,7 +484,8 @@ class StopOnSigQuit(Hook):
"""
__undescribed__ = {'quit': False}

def __init__(self, name=None, timescale='epoch', interval=1, verbose=None):
def __init__(self, name=None, timescale='epoch', interval=1, verbose=None,
logging_function=print):
super(StopOnSigQuit, self).__init__(name, timescale, interval,
verbose=verbose,
logging_function=logging_function)
self.quit = False
@@ -527,10 +546,11 @@ class BokehVisualizer(Hook):
and acts as a fallback verbosity for the used data iterator.
If not set it defaults to the verbosity setting of the trainer.
"""
def __init__(self, log_names, filename=None, timescale='epoch', interval=1,
name=None, verbose=None):
def __init__(self, log_names, filename=None, timescale='epoch',
interval=1, name=None, verbose=None,
logging_function=print):
super(BokehVisualizer, self).__init__(name, timescale, interval,
verbose)
verbose, logging_function)

if isinstance(log_names, basestring):
self.log_names = [log_names]
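Taken together, the hooks.py changes thread an optional logging_function argument (defaulting to print) through every hook constructor and route Hook.message() output through it. A minimal usage sketch, assuming the classes keep the signatures shown above; the iterator name 'valid_getter' and the choice of logging.info are illustrative, not part of the change:

import logging

from brainstorm.hooks import MonitorLoss, StopAfterEpoch

logging.basicConfig(level=logging.INFO)

# Send hook messages through the logging module instead of print.
# Hook.message() calls logging_function with a single formatted string,
# so any one-argument callable such as logging.info should work.
loss_hook = MonitorLoss('valid_getter', logging_function=logging.info)
stop_hook = StopAfterEpoch(max_epochs=20, logging_function=logging.info)

With this wiring, verbose hooks emit their "HookName >> message" lines through the logger rather than directly to stdout.
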
35 changes: 23 additions & 12 deletions brainstorm/training/trainer.py
@@ -26,7 +26,7 @@ class Trainer(Describable):
}
__default_values__ = {'verbose': True}

def __init__(self, stepper, verbose=True):
def __init__(self, stepper, verbose=True, logging_function=print):
"""Create a new Trainer.

Args:
@@ -41,6 +41,7 @@ def __init__(self, stepper, verbose=True):
self.current_update_nr = 0
self.logs = {}
self.results = {}
self.logging_function = logging_function

def add_hook(self, hook):
"""Add a hook to this trainer.
@@ -70,7 +71,11 @@ def train(self, net, training_data_iter, **named_data_iters):
iterators.
"""
if self.verbose:
print('\n\n', 10 * '- ', "Before Training", 10 * ' -')
if self.logging_function == print:
self.logging_function('\n\n' + 10 * '- ' + "Before Training" +
10 * ' -')
else:
self.logging_function(10 * '- ' + "Before Training" + 10 * ' -')
assert set(training_data_iter.data_shapes.keys()) == set(
net.buffer.Input.outputs.keys()), \
"The data names provided by the training data iterator {} do not "\
@@ -91,8 +96,12 @@ def train(self, net, training_data_iter, **named_data_iters):
train_scores.update({n: [] for n in net.get_loss_values()})

if self.verbose:
print('\n\n', 12 * '- ', "Epoch", self.current_epoch_nr,
12 * ' -')
if self.logging_function == print:
self.logging_function('\n\n' + 12 * '- ' + "Epoch " +
str(self.current_epoch_nr) + 12 * ' -')
else:
self.logging_function(12 * '- ' + "Epoch " +
str(self.current_epoch_nr) + 12 * ' -')
iterator = training_data_iter(handler=net.handler)
for _ in run_network(net, iterator):
self.current_update_nr += 1
@@ -133,8 +142,9 @@ def _start_hooks(self, net, named_data_iters):
hook.start(net, self.stepper, self.verbose,
named_data_iters)
except Exception:
print('An error occurred while starting the "{}" hook:'
.format(name), file=sys.stderr)
self.logging_function(
'An error occurred while starting the "{}" hook:'
.format(name), file=sys.stderr)
raise

def _emit_hooks(self, net, timescale, logs=None):
Expand Down Expand Up @@ -163,9 +173,10 @@ def _call_hook(self, hook, net):
except StopIteration as err:
return getattr(err, 'value', None), True
except Exception as e:
print('An error occurred while calling the "{}" hook:'
.format(hook.__name__), file=sys.stderr)
print(traceback.format_exc())
self.logging_function(
'An error occurred while calling the "{}" hook:'
.format(hook.__name__), file=sys.stderr)
self.logging_function(traceback.format_exc())
raise e

def _add_log(self, name, val, verbose=None, logs=None, indent=0):
Expand All @@ -178,14 +189,14 @@ def _add_log(self, name, val, verbose=None, logs=None, indent=0):

if isinstance(val, dict):
if verbose:
print(" " * indent + name)
self.logging_function(" " * indent + name)
logs[name] = dict() if name not in logs else logs[name]

for k, v in val.items():
self._add_log(k, v, verbose, logs[name], indent + 2)
else:
if verbose:
print(" " * indent + ("{0:%d}: {1}" % (40 - indent))
.format(name, val))
self.logging_function(" " * indent + ("{0:%d}: {1}" %
(40 - indent)).format(name, val))
logs[name] = [] if name not in logs else logs[name]
logs[name].append(val)
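
On the trainer.py side, Trainer gains the same logging_function parameter and uses it for the pre-training and per-epoch banners, for hook error reporting, and for the log lines written by _add_log. A rough sketch of wiring it up, assuming the constructor shown above; make_stepper() is a placeholder for whatever stepper the project provides, and since the error paths in _start_hooks and _call_hook pass file=sys.stderr, a drop-in replacement for print should tolerate a file keyword:

import logging

from brainstorm.training.trainer import Trainer

logger = logging.getLogger('brainstorm.training')

def log_line(*args, **kwargs):
    # Accept print-style calls, including the file=sys.stderr used when
    # reporting hook errors, and forward one joined message to the logger.
    kwargs.pop('file', None)
    logger.info(' '.join(str(arg) for arg in args))

# make_stepper() stands in for the stepper construction used elsewhere.
trainer = Trainer(make_stepper(), verbose=True, logging_function=log_line)

Note that hooks added to the trainer still use their own logging_function (default print) unless one is passed to each hook explicitly, as in the hooks.py example above.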