Feature: allow picking whatever pytorch lr_scheduler #79
base: master
@@ -1,4 +1,5 @@
+import inspect
 import os
 import uuid
 import logging

@@ -7,7 +8,6 @@
 import random
 import tempfile

 import tqdm

 import torch
@@ -150,19 +150,43 @@ def get_weights(self):

 class LRScheduler(object):
-    def __init__(self, optimizer, threshold=0.0, **kwargs):
-        self.lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer, mode='max', threshold=threshold, **kwargs)
+    def __init__(self, optimizer,
+                 lr_scheduler='ReduceLROnPlateau',
+                 delay=0, **kwargs):
+
+        self.nb_steps: int = 0
+        self.delay: int = delay
+        lr_scheduler_cls = getattr(optim.lr_scheduler, lr_scheduler)
+        params = inspect.signature(lr_scheduler_cls).parameters
+        self.lr_scheduler = lr_scheduler_cls(
+            optimizer,
+            # pick kwargs that fit the selected scheduler
+            **{kwarg: val for kwarg, val in kwargs.items() if kwarg in params})

     def step(self, score):
-        self.lr_scheduler.step(score)
+        self.nb_steps += 1
+
+        # apply the step() method of the lr_scheduler when delay is reached
+        if self.delay and self.nb_steps <= self.delay:
+            return
+
+        if isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+            self.lr_scheduler.step(score)
+        else:
+            self.lr_scheduler.step()

     def __repr__(self):
-        return '<LrScheduler lr="{:g}" steps="{}" patience="{}" threshold="{}"/>' \
-            .format(self.lr_scheduler.optimizer.param_groups[0]['lr'],
-                    self.lr_scheduler.num_bad_epochs,
-                    self.lr_scheduler.patience,
-                    self.lr_scheduler.threshold)
+        res = '<LrScheduler="{}" lr="{:g}" delay="{}" steps="{}"'.format(
+            self.lr_scheduler.__class__.__name__,
+            self.lr_scheduler.optimizer.param_groups[0]['lr'],
+            self.delay,
+            self.nb_steps)
+        for key in dir(self.lr_scheduler):
Contributor: We are losing some manual filter on what we want to show here.

Owner (Author): Sure, but otherwise you have to do it manually per class...

Contributor: I know... Couldn't we curate it for specific schedulers, and automatically show everything for the others?

Owner (Author): As long as it doesn't introduce too much clutter, feel free!
+            val = getattr(self.lr_scheduler, key)
+            if (not key.startswith('__')) and isinstance(val, (str, float, int)):
+                res += ' {}="{}"'.format(key, val)
+        res += '/>'
+        return res
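For reference, here is a standalone sketch of the pattern this hunk implements: resolve a scheduler class by name from `torch.optim.lr_scheduler`, keep only the kwargs its constructor accepts, and special-case `ReduceLROnPlateau`, whose `step()` expects a metric. The toy model, optimizer, and the `build_scheduler` helper are illustrative only, not code from this PR.

```python
import inspect

import torch
from torch import optim


def build_scheduler(optimizer, name='ReduceLROnPlateau', **kwargs):
    """Instantiate any torch.optim.lr_scheduler class by name."""
    cls = getattr(optim.lr_scheduler, name)
    accepted = inspect.signature(cls).parameters
    # same filter as in the diff: drop kwargs the chosen class does not accept
    return cls(optimizer, **{k: v for k, v in kwargs.items() if k in accepted})


model = torch.nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 'patience' is silently dropped here because CosineAnnealingLR does not take it
scheduler = build_scheduler(optimizer, 'CosineAnnealingLR', T_max=10, patience=5)

for epoch in range(3):
    dev_score = 0.5  # placeholder validation score
    optimizer.step()
    if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
        scheduler.step(dev_score)  # plateau scheduling needs the monitored metric
    else:
        scheduler.step()           # every other scheduler steps unconditionally
    print(optimizer.param_groups[0]['lr'])
```

The `delay` counter in the diff simply skips these `step()` calls for the first few epochs before handing control to the wrapped scheduler.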
 class Trainer(object):

@@ -200,8 +224,10 @@ def __init__(self, settings, model, dataset, num_instances):

         self.task_scheduler = TaskScheduler(settings)
         self.lr_scheduler = LRScheduler(
-            self.optimizer, factor=settings.lr_factor,
-            patience=settings.lr_patience, min_lr=settings.min_lr)
+            self.optimizer,
+            lr_scheduler=settings.lr_scheduler,
+            delay=settings.lr_scheduler_delay,
+            **settings.lr_scheduler_params)

         if settings.verbose:
             print()
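The Trainer now reads three new settings. A sketch of how they might be filled in; the keys mirror the attributes used in the hunk above, but the values (and the exact settings format of the repo) are only assumptions.

```python
# hypothetical settings values; keys mirror settings.lr_scheduler,
# settings.lr_scheduler_delay and settings.lr_scheduler_params used above
settings = {
    "lr_scheduler": "CosineAnnealingLR",  # any class name in torch.optim.lr_scheduler
    "lr_scheduler_delay": 2,              # skip the first 2 step() calls
    "lr_scheduler_params": {              # forwarded as **kwargs, then filtered
        "T_max": 50,
        "eta_min": 1e-6,
    },
}
```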
Contributor: I definitely like this part.
Contributor: But, as I mentioned in my other PR, this basically breaks when instantiating something other than Adam. Another option would be to have { "lr_scheduler_params": { "ReduceLROnPlateau": { "mode": "max", ... } } } and this way provide defaults for the things we tested?
Owner (Author): But those parameters are not specific to the optimizer but to the scheduler, right? I have made some changes that should address this issue...
Contributor: I am sorry, I meant scheduler. I am a bit tired right now ^^. So I corrected the comment up there.
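For what it's worth, the nested `lr_scheduler_params` idea from this thread could be resolved with a small merge step. This is only a sketch of the suggestion: `DEFAULT_SCHEDULER_PARAMS`, `resolve_scheduler_params`, and the default values are hypothetical placeholders, not code from this PR.

```python
# curated defaults for the schedulers that were actually tested; anything else
# falls back to PyTorch's own constructor defaults
DEFAULT_SCHEDULER_PARAMS = {
    "ReduceLROnPlateau": {"mode": "max", "factor": 0.75, "patience": 2},
}


def resolve_scheduler_params(name, user_params):
    params = dict(DEFAULT_SCHEDULER_PARAMS.get(name, {}))
    params.update(user_params)  # explicit settings win over curated defaults
    return params


# resolve_scheduler_params("ReduceLROnPlateau", {"patience": 5})
# -> {"mode": "max", "factor": 0.75, "patience": 5}
```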