diff --git a/pie/default_settings.json b/pie/default_settings.json
index cc5279f..da1a848 100644
--- a/pie/default_settings.json
+++ b/pie/default_settings.json
@@ -124,10 +124,17 @@
"optimizer": "Adam", // optimizer type
"clip_norm": 5.0, // clip norm of gradient up to this value
"lr": 0.001,
- "lr_factor": 0.75, // lr schedule (decrease lr by this factor after `lr_patience` epochs
- // without improvement on dev-set data)
- "min_lr": 0.000001, // minimum learning rate
- "lr_patience": 2, // patience for lr schedule
+
+ "lr_scheduler": "ReduceLROnPlateau",
+ "lr_scheduler_delay": 0, // Number of steps without using the lr_scheduler
+ "lr_scheduler_params": {
+ // these arguments must be adapted if lr_scheduler is not ReduceLROnPlateau
+ "mode": "max",
+ "factor": 0.75,
+ "patience": 2,
+ "min_lr": 0.000001
+ },
+
"checks_per_epoch": 1, // check model on dev-set so many times during epoch
// * Model hyperparameters
diff --git a/pie/settings.py b/pie/settings.py
index d2a266a..969d584 100644
--- a/pie/settings.py
+++ b/pie/settings.py
@@ -106,6 +106,12 @@ def check_settings(settings):
     if not has_target:
         raise ValueError("Needs at least one target task")
 
+    # backward compatibility: legacy flat settings override lr_scheduler_params
+    if "lr_patience" in settings:
+        settings.lr_scheduler_params['patience'] = settings.lr_patience
+    if "lr_factor" in settings:
+        settings.lr_scheduler_params['factor'] = settings.lr_factor
+    if "min_lr" in settings:
+        settings.lr_scheduler_params['min_lr'] = settings.min_lr
+
     return settings
diff --git a/pie/trainer.py b/pie/trainer.py
index 2dceb7e..4dffcf3 100644
--- a/pie/trainer.py
+++ b/pie/trainer.py
@@ -1,4 +1,5 @@
+import inspect
import os
import uuid
import logging
@@ -7,7 +8,6 @@
import random
import tempfile
-
import tqdm
import torch
@@ -150,19 +150,43 @@ def get_weights(self):
class LRScheduler(object):
-    def __init__(self, optimizer, threshold=0.0, **kwargs):
-        self.lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer, mode='max', threshold=threshold, **kwargs)
+    def __init__(self, optimizer,
+                 lr_scheduler='ReduceLROnPlateau',
+                 delay=0, **kwargs):
+
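+        # `lr_scheduler` is the name of a scheduler class in torch.optim.lr_scheduler;
+        # `delay` is the number of step() calls to ignore before the scheduler is used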
+        self.nb_steps: int = 0
+        self.delay: int = delay
+        lr_scheduler_cls = getattr(optim.lr_scheduler, lr_scheduler)
+        params = inspect.signature(lr_scheduler_cls).parameters
+        self.lr_scheduler = lr_scheduler_cls(
+            optimizer,
+            # pick kwargs that fit the selected scheduler
+            **{kwarg: val for kwarg, val in kwargs.items() if kwarg in params})
     def step(self, score):
-        self.lr_scheduler.step(score)
+        self.nb_steps += 1
+
+        # skip the wrapped scheduler for the first `delay` calls to step()
+        if self.delay and self.nb_steps <= self.delay:
+            return
+
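+        # ReduceLROnPlateau is metric-driven and needs the dev score;
+        # any other scheduler is stepped without arguments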
+        if isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+            self.lr_scheduler.step(score)
+        else:
+            self.lr_scheduler.step()
     def __repr__(self):
-        return '<LRScheduler lr="{}" num_bad_epochs="{}" patience="{}" threshold="{}"/>' \
-            .format(self.lr_scheduler.optimizer.param_groups[0]['lr'],
-                    self.lr_scheduler.num_bad_epochs,
-                    self.lr_scheduler.patience,
-                    self.lr_scheduler.threshold)
+        res = '<LRScheduler scheduler="{}" lr="{}"/>'.format(
+            type(self.lr_scheduler).__name__,
+            self.lr_scheduler.optimizer.param_groups[0]['lr'])
+        return res
class Trainer(object):
@@ -200,8 +224,10 @@ def __init__(self, settings, model, dataset, num_instances):
         self.task_scheduler = TaskScheduler(settings)
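+        # entries in lr_scheduler_params that the chosen scheduler class
+        # does not accept are filtered out inside LRScheduler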
         self.lr_scheduler = LRScheduler(
-            self.optimizer, factor=settings.lr_factor,
-            patience=settings.lr_patience, min_lr=settings.min_lr)
+            self.optimizer,
+            lr_scheduler=settings.lr_scheduler,
+            delay=settings.lr_scheduler_delay,
+            **settings.lr_scheduler_params)
         if settings.verbose:
             print()