From 37d56d7708e08c86b7b5b31dca1a19831cd715d5 Mon Sep 17 00:00:00 2001
From: Benjamin Midtvedt
Date: Sat, 29 Jun 2024 17:47:13 +0200
Subject: [PATCH 1/9] Adding schedulers

---
 deeplay/applications/application.py        |  16 ++-
 deeplay/module.py                          |   7 ++
 deeplay/schedulers/__init__.py             |   5 +
 deeplay/schedulers/constant.py             |  18 ++++
 deeplay/schedulers/linear.py               |  42 ++++++++
 deeplay/schedulers/loglinear.py            |  26 +++++
 deeplay/schedulers/scheduler.py            |  47 ++++++++
 deeplay/schedulers/sequence.py             |  36 +++++++
 deeplay/tests/schedulers/__init__.py       |   0
 deeplay/tests/schedulers/test_constant.py  |  89 +++++++++++++++
 deeplay/tests/schedulers/test_linear.py    | 120 +++++++++++++++++++++
 deeplay/tests/schedulers/test_scheduler.py |  12 +++
 setup.py                                   |   2 +-
 13 files changed, 415 insertions(+), 5 deletions(-)
 create mode 100644 deeplay/schedulers/__init__.py
 create mode 100644 deeplay/schedulers/constant.py
 create mode 100644 deeplay/schedulers/linear.py
 create mode 100644 deeplay/schedulers/loglinear.py
 create mode 100644 deeplay/schedulers/scheduler.py
 create mode 100644 deeplay/schedulers/sequence.py
 create mode 100644 deeplay/tests/schedulers/__init__.py
 create mode 100644 deeplay/tests/schedulers/test_constant.py
 create mode 100644 deeplay/tests/schedulers/test_linear.py
 create mode 100644 deeplay/tests/schedulers/test_scheduler.py

diff --git a/deeplay/applications/application.py b/deeplay/applications/application.py
index e80f4dc1..754201ec 100644
--- a/deeplay/applications/application.py
+++ b/deeplay/applications/application.py
@@ -243,12 +243,15 @@ def configure_optimizers(self):
             ) from e
 
     def training_step(self, batch, batch_idx):
+
         x, y = self.train_preprocess(batch)
         y_hat = self(x)
         loss = self.compute_loss(y_hat, y)
         if not isinstance(loss, dict):
             loss = {"loss": loss}
 
+        assert "loss" in loss, "the output of compute_loss should contain a 'loss' key"
+
         for name, v in loss.items():
             self.log(
                 f"train_{name}",
@@ -263,7 +266,7 @@ def training_step(self, batch, batch_idx):
             "train", y_hat, y, on_step=True, on_epoch=True, prog_bar=True, logger=True
         )
 
-        return sum(loss.values())
+        return loss["loss"]
 
     def validation_step(self, batch, batch_idx):
         x, y = self.val_preprocess(batch)
@@ -290,7 +293,7 @@ def validation_step(self, batch, batch_idx):
             prog_bar=True,
             logger=True,
         )
-        return sum(loss.values())
+        return loss["loss"] if "loss" in loss else 0
 
     def test_step(self, batch, batch_idx):
         x, y = self.test_preprocess(batch)
@@ -318,7 +321,7 @@ def test_step(self, batch, batch_idx):
             logger=True,
         )
 
-        return sum(loss.values())
+        return loss["loss"] if "loss" in loss else 0
 
     def predict_step(self, batch, batch_idx, dataloader_idx=None):
         if isinstance(batch, (list, tuple)):
@@ -356,12 +359,17 @@ def trainer(self, trainer):
             if module is self:
                 continue
             try:
-                if hasattr(module, "trainer") and module.trainer is not trainer:
+                if isinstance(module, L.LightningModule) or hasattr(module, "trainer"):
+                    print("Aattaching trainer to", module)
                     module.trainer = trainer
+                    print("Aattached trainer to", module, module.trainer)
+
             except RuntimeError:
                 # hasattr can raise RuntimeError if the module is not attached to a trainer
                 if isinstance(module, L.LightningModule):
+                    print("Battaching trainer to", module)
                     module.trainer = trainer
+                    print("Battached trainer to", module)
 
     @staticmethod
     def clone_metrics(metrics: T) -> T:
diff --git a/deeplay/module.py b/deeplay/module.py
index e726c061..e23cf0d0 100644
--- a/deeplay/module.py
+++ b/deeplay/module.py
@@ -1314,6 +1314,13 @@ def __setattr__(self, name, value):
 #     # ensure that logs are stored in the correct place
 #     value.set_root_module(self.root_module)
 
+    def __getattr__(self, name):
+        from deeplay.schedulers import BaseScheduler
+        x = super().__getattr__(name)
+        if self._has_built and isinstance(x, BaseScheduler):
+            return x.__get__(self, type(self))
+        return x
+
     def _select_string(self, structure, selections, select, ellipsis=False):
         selects = select.split(",")
         selects = [select.strip() for select in selects]
diff --git a/deeplay/schedulers/__init__.py b/deeplay/schedulers/__init__.py
new file mode 100644
index 00000000..421de1b2
--- /dev/null
+++ b/deeplay/schedulers/__init__.py
@@ -0,0 +1,5 @@
+from .scheduler import BaseScheduler
+from .linear import LinearScheduler
+from .constant import ConstantScheduler
+from .loglinear import LogLinearScheduler
+from .sequence import SchedulerSequence
diff --git a/deeplay/schedulers/constant.py b/deeplay/schedulers/constant.py
new file mode 100644
index 00000000..87297929
--- /dev/null
+++ b/deeplay/schedulers/constant.py
@@ -0,0 +1,18 @@
+from . import BaseScheduler
+
+
+class ConstantScheduler(BaseScheduler):
+    """Scheduler that returns a constant value."""
+
+    def __init__(self, value, on_epoch=False):
+        super().__init__(on_epoch)
+        self.value = value
+
+    def __call__(self, step):
+        return self.value
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.value})"
+
+    def __str__(self):
+        return repr(self)
diff --git a/deeplay/schedulers/linear.py b/deeplay/schedulers/linear.py
new file mode 100644
index 00000000..0cf65c49
--- /dev/null
+++ b/deeplay/schedulers/linear.py
@@ -0,0 +1,42 @@
+from .scheduler import BaseScheduler
+
+
+class LinearScheduler(BaseScheduler):
+    """Scheduler that returns a linearly changing value from start_value to end_value.
+
+    For steps beyond n_steps, returns end_value.
+    For steps before 0, returns start_value.
+
+    Parameters
+    ----------
+    start_value : float
+        Initial value of the scheduler.
+    end_value : float
+        Final value of the scheduler.
+    n_steps : int
+        Number of steps to reach end_value.
+    on_epoch : bool
+        If True, the step is taken from the epoch counter of the trainer.
+        Otherwise, the step is taken from the global step counter of the trainer.
+    """
+
+    def __init__(self, start_value, end_value, n_steps, on_epoch=False):
+        super().__init__(on_epoch)
+        self.start_value = start_value
+        self.end_value = end_value
+        self.n_steps = n_steps
+
+    def __call__(self, step):
+        if step < 0:
+            return self.start_value
+        if step >= self.n_steps:
+            return self.end_value
+        return (
+            self.start_value + (self.end_value - self.start_value) * step / self.n_steps
+        )
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.start_value}, {self.end_value}, {self.n_steps})"
+
+    def __str__(self):
+        return repr(self)
diff --git a/deeplay/schedulers/loglinear.py b/deeplay/schedulers/loglinear.py
new file mode 100644
index 00000000..64ef6046
--- /dev/null
+++ b/deeplay/schedulers/loglinear.py
@@ -0,0 +1,26 @@
+from . import BaseScheduler
+
+
+class LogLinearScheduler(BaseScheduler):
+    """Scheduler that returns a log-linearly changing value from start_value to end_value.
+
+    For steps beyond n_steps, returns end_value."""
+
+    def __init__(self, start_value, end_value, n_steps, on_epoch=False):
+        super().__init__(on_epoch)
+        self.start_value = start_value
+        self.end_value = end_value
+        self.n_steps = n_steps
+
+    def __call__(self, step):
+        if step >= self.n_steps:
+            return self.end_value
+        return self.start_value * (self.end_value / self.start_value) ** (
+            step / self.n_steps
+        )
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.start_value}, {self.end_value}, {self.n_steps})"
+
+    def __str__(self):
+        return repr(self)
diff --git a/deeplay/schedulers/scheduler.py b/deeplay/schedulers/scheduler.py
new file mode 100644
index 00000000..dafce25c
--- /dev/null
+++ b/deeplay/schedulers/scheduler.py
@@ -0,0 +1,47 @@
+import lightning as L
+
+from deeplay.module import DeeplayModule
+from deeplay.trainer import Trainer
+
+
+class BaseScheduler(DeeplayModule, L.LightningModule):
+    """Base class for schedulers."""
+
+    step: int
+
+    def __init__(self, on_epoch=False):
+        super().__init__()
+        self.on_epoch = on_epoch
+        self._step = 0
+        self._x = None
+
+    def set_step(self, step):
+        self._step = step
+        self._x = self(step)
+
+    def update(self):
+        current_step = self._step
+
+        if self._trainer:
+            updated_step = (
+                self.trainer.current_epoch
+                if self.on_epoch
+                else self.trainer.global_step
+            )
+        else:
+            updated_step = self._step
+
+        if updated_step != current_step or self._x is None:
+            self.set_step(updated_step)
+
+    def __get__(self, obj, objtype=None):
+        if obj is None:
+            return self
+        self.update()
+        return self._x
+
+    def __set__(self, obj, value):
+        self._x = value
+
+    def __call__(self, step):
+        raise NotImplementedError
diff --git a/deeplay/schedulers/sequence.py b/deeplay/schedulers/sequence.py
new file mode 100644
index 00000000..4a3b25ce
--- /dev/null
+++ b/deeplay/schedulers/sequence.py
@@ -0,0 +1,36 @@
+from .scheduler import BaseScheduler
+
+
+class SchedulerSequence(BaseScheduler):
+    """Scheduler that returns the value from one of the schedulers in the chain.
+
+    The scheduler is chosen based on the current step.
+    """
+
+    def __init__(self, on_epoch=False):
+        super().__init__(on_epoch)
+        self.schedulers = []
+
+    def add(self, scheduler, n_steps=None):
+        if n_steps is None:
+            assert hasattr(
+                scheduler, "n_steps"
+            ), "For a scheduler without n_steps, n_steps must be specified"
+            n_steps = scheduler.n_steps
+
+        self.schedulers.append((n_steps, scheduler))
+
+    def __call__(self, step):
+        for n_steps, scheduler in self.schedulers:
+            if step < n_steps:
+                return scheduler(step)
+            step -= n_steps
+
+        final_step, final_scheduler = self.schedulers[-1]
+        return final_scheduler(final_step + step)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.schedulers})"
+
+    def __str__(self):
+        return repr(self)
diff --git a/deeplay/tests/schedulers/__init__.py b/deeplay/tests/schedulers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/deeplay/tests/schedulers/test_constant.py b/deeplay/tests/schedulers/test_constant.py
new file mode 100644
index 00000000..d9d3652c
--- /dev/null
+++ b/deeplay/tests/schedulers/test_constant.py
@@ -0,0 +1,89 @@
+import unittest
+
+import torch
+from deeplay.applications.application import Application
+from deeplay.components.mlp import MultiLayerPerceptron
+from deeplay.external.optimizers.adam import Adam
+from deeplay.module import DeeplayModule
+from deeplay.schedulers import ConstantScheduler
+
+
+class TestConstantScheduler(unittest.TestCase):
+
+    def test_scheduler_build(self):
+        scheduler = ConstantScheduler(1.0)
+        scheduler.build()
+
+        self.assertEqual(scheduler._step, 0)
+        self.assertIsNone(scheduler._x)
+
+    def test_scheduler_step(self):
+        scheduler = ConstantScheduler(1.0)
+        scheduler.build()
+
+        steps = [-999999999, -1, 0, 999, 9999999999999]
+        for step in steps:
+            value = scheduler(step)
+            self.assertEqual(value, 1.0)
+
+    def test_scheduler_attached_to_module(self):
+
+        class Module(DeeplayModule):
+            def __init__(self):
+                super().__init__()
+                self.x = ConstantScheduler(1.0)
+
+        module = Module()
+
+        # Before build, x is the scheduler
+        self.assertIsInstance(module.x, ConstantScheduler)
+
+        module.build()
+
+        # After build, x is the value of the scheduler
+        self.assertEqual(module.x, 1.0)
+
+    def test_scheduler_attached_configure(self):
+
+        class Module(DeeplayModule):
+            def __init__(self):
+                super().__init__()
+                self.x = ConstantScheduler(1.0)
+
+        module = Module()
+
+        # Before build, x is the scheduler
+        self.assertIsInstance(module.x, ConstantScheduler)
+
+        module.x.configure(value=2.0)
+
+        module.build()
+
+        # After build, x is the value of the scheduler
+        self.assertEqual(module.x, 2.0)
+
+    def test_scheduler_trainer(self):
+
+        class Module(Application):
+            def __init__(self):
+                super().__init__(optimizer=Adam(lr=1.0), loss=torch.nn.MSELoss())
+                self.x = ConstantScheduler(1.0)
+                self.net = MultiLayerPerceptron(1, [1], 1)
+
+            def forward(_self, x):
+                self.assertEqual(_self.x, 1.0)
+                return _self.net(x) * _self.x
+
+        module = Module()
+
+        module.build()
+
+        x = torch.randn(10, 1)
+        y = torch.randn(10, 1)
+
+        module.fit((x, y), max_steps=10)
+        module._has_built = False
+
+        self.assertEqual(module.x.trainer, module.trainer)
+        self.assertEqual(module.x._step, 9)
+        self.assertEqual(module.x._x, 1.0)
diff --git a/deeplay/tests/schedulers/test_linear.py b/deeplay/tests/schedulers/test_linear.py
new file mode 100644
index 00000000..b846185a
--- /dev/null
+++ b/deeplay/tests/schedulers/test_linear.py
@@ -0,0 +1,120 @@
+import unittest
+
+import torch
+from deeplay.applications.application import Application
+from deeplay.components.mlp import MultiLayerPerceptron
+from deeplay.external.optimizers.adam import Adam
+from deeplay.module import DeeplayModule
+from deeplay.schedulers import LinearScheduler
+
+
+class TestLinearScheduler(unittest.TestCase):
+
+    def test_scheduler_build(self):
+        scheduler = LinearScheduler(0.0, 1.0, 10)
+        scheduler.build()
+
+        self.assertEqual(scheduler._step, 0)
+        self.assertIsNone(scheduler._x)
+
+    def test_scheduler_step(self):
+        scheduler = LinearScheduler(0, 1, 10)
+        scheduler.build()
+
+        steps = [-1, 0, 1, 9, 10, 20]
+        exp_values = [0.0, 0.0, 0.1, 0.9, 1.0, 1.0]
+        for step, exp_value in zip(steps, exp_values):
+            value = scheduler(step)
+            self.assertEqual(value, exp_value)
+
+    def test_scheduler_step_negative(self):
+        scheduler = LinearScheduler(0, -1, 10)
+        scheduler.build()
+
+        steps = [-1, 0, 1, 9, 10, 20]
+        exp_values = [0.0, 0.0, -0.1, -0.9, -1.0, -1.0]
+        for step, exp_value in zip(steps, exp_values):
+            value = scheduler(step)
+            self.assertEqual(value, exp_value)
+
+    def test_scheduler_step_constant(self):
+        scheduler = LinearScheduler(0, 0, 10)
+        scheduler.build()
+
+        steps = [-1, 0, 1, 9, 10, 20]
+        exp_values = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+        for step, exp_value in zip(steps, exp_values):
+            value = scheduler(step)
+            self.assertEqual(value, exp_value)
+
+    def test_scheduler_step_zero_steps(self):
+        scheduler = LinearScheduler(0, 1, 0)
+        scheduler.build()
+
+        steps = [-1, 0, 1, 9, 10, 20]
+        exp_values = [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+        for step, exp_value in zip(steps, exp_values):
+            value = scheduler(step)
+            self.assertEqual(value, exp_value)
+
+    def test_scheduler_attached_to_module(self):
+
+        class Module(DeeplayModule):
+            def __init__(self):
+                super().__init__()
+                self.x = LinearScheduler(0.0, 1.0, 10)
+
+        module = Module()
+
+        # Before build, x is the scheduler
+        self.assertIsInstance(module.x, LinearScheduler)
+
+        module.build()
+
+        # After build, x is the value of the scheduler
+        self.assertEqual(module.x, 0.0)
+
+    def test_scheduler_attached_configure(self):
+
+        class Module(DeeplayModule):
+            def __init__(self):
+                super().__init__()
+                self.x = LinearScheduler(0.0, 1.0, 10)
+
+        module = Module()
+
+        # Before build, x is the scheduler
+        self.assertIsInstance(module.x, LinearScheduler)
+
+        module.x.configure(start_value=1.0, end_value=2.0)
+
+        module.build()
+
+        # After build, x is the value of the scheduler
+        self.assertEqual(module.x, 1.0)
+
+    def test_scheduler_trainer(self):
+
+        class Module(Application):
+            def __init__(self):
+                super().__init__(optimizer=Adam(lr=1.0), loss=torch.nn.MSELoss())
+                self.x = LinearScheduler(0.0, 1.0, 10)
+                self.net = MultiLayerPerceptron(1, [1], 1)
+
+            def forward(_self, x):
+                self.assertEqual(_self.x, _self.trainer.global_step / 10)
+                return _self.net(x) * _self.x
+
+        module = Module()
+
+        module.build()
+
+        x = torch.randn(10, 1)
+        y = torch.randn(10, 1)
+
+        module.fit((x, y), max_steps=10)
+        module._has_built = False
+
+        self.assertEqual(module.x.trainer, module.trainer)
+        self.assertEqual(module.x._step, 9)
+        self.assertEqual(module.x._x, 0.9)
diff --git a/deeplay/tests/schedulers/test_scheduler.py b/deeplay/tests/schedulers/test_scheduler.py
new file mode 100644
index 00000000..7439ab1b
--- /dev/null
+++ b/deeplay/tests/schedulers/test_scheduler.py
@@ -0,0 +1,12 @@
+import unittest
+from deeplay.schedulers import BaseScheduler
+
+
+class TestBaseScheduler(unittest.TestCase):
+
+    def test_scheduler_defaults(self):
+        scheduler = BaseScheduler()
+        scheduler.build()
+
+        self.assertEqual(scheduler._step, 0)
+        self.assertIsNone(scheduler._x)
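The tests above pin down the intended attribute semantics: before build(), a scheduler assigned to a module attribute is returned as-is, so it can still be configured; after build(), DeeplayModule.__getattr__ routes the access through BaseScheduler.__get__, which updates the step from the attached trainer (if any) and returns the current scheduled value. A minimal sketch of that behavior, assuming deeplay with this patch applied; the class name Warmup and the attribute name weight are illustrative, not part of the patch:

from deeplay.module import DeeplayModule
from deeplay.schedulers import LinearScheduler


class Warmup(DeeplayModule):
    def __init__(self):
        super().__init__()
        # Ramp the attribute from 0.0 to 1.0 over 10 steps.
        self.weight = LinearScheduler(0.0, 1.0, 10)


module = Warmup()

# Before build(), the attribute is still the scheduler object and can be
# reconfigured like any other deeplay module.
assert isinstance(module.weight, LinearScheduler)
module.weight.configure(start_value=1.0, end_value=2.0)

module.build()

# After build(), attribute access resolves through BaseScheduler.__get__.
# With no trainer attached, the step stays at 0, so the start value is returned.
assert module.weight == 1.0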
diff --git a/setup.py b/setup.py
index b2536714..2b28dbbd 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
 
 setup(
     name="deeplay",
-    version="0.0.7",
+    version="0.1.0",
     license="MIT",
     packages=find_packages(),
     author=(

From 408b25bdbfe813b884098f8c4ee74693e46901a1 Mon Sep 17 00:00:00 2001
From: Benjamin Midtvedt
Date: Sat, 29 Jun 2024 18:43:12 +0200
Subject: [PATCH 2/9] Refactor module.py and loglinear.py for scheduling
 functionality

---
 deeplay/module.py               | 32 ++++++++++++++++++++++++++++++++
 deeplay/schedulers/loglinear.py |  7 +++++++
 2 files changed, 39 insertions(+)

diff --git a/deeplay/module.py b/deeplay/module.py
index e23cf0d0..4aa35517 100644
--- a/deeplay/module.py
+++ b/deeplay/module.py
@@ -723,6 +723,37 @@ def replace(self, target: str, replacement: nn.Module):
 
         self._modules[target] = replacement
 
+    @after_init
+    def schedule(self, attr: str, scheduler: "BaseScheduler"):
+        from deeplay.schedulers import BaseScheduler
+        setattr(self, attr, scheduler)
+
+    @after_init
+    def schedule_linear(
+        self,
+        attr: str,
+        start_value: float,
+        end_value: float,
+        n_steps: int,
+        on_epoch: bool = False,
+    ):
+        from deeplay.schedulers import LinearScheduler
+
+        setattr(self, attr, LinearScheduler(start_value, end_value, n_steps, on_epoch))
+
+    @after_init
+    def schedule_loglinear(
+        self,
+        attr: str,
+        start_value: float,
+        end_value: float,
+        n_steps: int,
+        on_epoch: bool = False,
+    ):
+        from deeplay.schedulers import LogLinearScheduler
+
+        setattr(self, attr, LogLinearScheduler(start_value, end_value, n_steps, on_epoch))
+
     @stateful
     def configure(self, *args: Any, **kwargs: Any):
         """
@@ -1316,6 +1347,7 @@ def __setattr__(self, name, value):
 
     def __getattr__(self, name):
         from deeplay.schedulers import BaseScheduler
+
         x = super().__getattr__(name)
         if self._has_built and isinstance(x, BaseScheduler):
             return x.__get__(self, type(self))
diff --git a/deeplay/schedulers/loglinear.py b/deeplay/schedulers/loglinear.py
index 64ef6046..ffc7a362 100644
--- a/deeplay/schedulers/loglinear.py
+++ b/deeplay/schedulers/loglinear.py
@@ -1,4 +1,5 @@
 from . import BaseScheduler
+import numpy as np
 
 
 class LogLinearScheduler(BaseScheduler):
@@ -8,6 +9,12 @@ class LogLinearScheduler(BaseScheduler):
 
     def __init__(self, start_value, end_value, n_steps, on_epoch=False):
         super().__init__(on_epoch)
+        assert np.sign(start_value) == np.sign(
+            end_value
+        ), "Start and end values must have the same sign"
+        assert start_value != 0, "Start value must be non-zero"
+        assert end_value != 0, "End value must be non-zero"
+        assert n_steps > 0, "Number of steps must be greater than 0"
         self.start_value = start_value
         self.end_value = end_value
         self.n_steps = n_steps

From ac1f895ae0c22c6443bc07ea412687730f4cbecb Mon Sep 17 00:00:00 2001
From: Benjamin Midtvedt
Date: Sat, 29 Jun 2024 18:43:33 +0200
Subject: [PATCH 3/9] Refactor schedule method in DeeplayModule to remove type
 hint for scheduler parameter

---
 deeplay/module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deeplay/module.py b/deeplay/module.py
index 4aa35517..72357028 100644
--- a/deeplay/module.py
+++ b/deeplay/module.py
@@ -724,7 +724,7 @@ def replace(self, target: str, replacement: nn.Module):
         self._modules[target] = replacement
 
     @after_init
-    def schedule(self, attr: str, scheduler: "BaseScheduler"):
+    def schedule(self, attr: str, scheduler):
         from deeplay.schedulers import BaseScheduler
         setattr(self, attr, scheduler)
 

From 24de6453b98ef6a35955abe8b7ca604d77d2bb45 Mon Sep 17 00:00:00 2001
From: Benjamin Midtvedt
Date: Sat, 29 Jun 2024 18:47:56 +0200
Subject: [PATCH 4/9] Refactor schedule method in DeeplayModule to accept
 multiple schedulers

---
 deeplay/module.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/deeplay/module.py b/deeplay/module.py
index 72357028..5fe98ce5 100644
--- a/deeplay/module.py
+++ b/deeplay/module.py
@@ -724,9 +724,9 @@ def replace(self, target: str, replacement: nn.Module):
         self._modules[target] = replacement
 
     @after_init
-    def schedule(self, attr: str, scheduler):
-        from deeplay.schedulers import BaseScheduler
-        setattr(self, attr, scheduler)
+    def schedule(self, **schedulers):
+        for attr, scheduler in schedulers.items():
+            setattr(self, attr, scheduler)
 
     @after_init
     def schedule_linear(
@@ -752,7 +752,9 @@ def schedule_loglinear(
     ):
         from deeplay.schedulers import LogLinearScheduler
 
-        setattr(self, attr, LogLinearScheduler(start_value, end_value, n_steps, on_epoch))
+        setattr(
+            self, attr, LogLinearScheduler(start_value, end_value, n_steps, on_epoch)
+        )

From 950c162e3b5ad5acf7e596bf64e3eff233714e3d Mon Sep 17 00:00:00 2001
From: Benjamin Midtvedt
Date: Sun, 30 Jun 2024 17:09:40 +0200
Subject: [PATCH 5/9] Refactor test_decorators.py to remove unnecessary code

---
 deeplay/applications/application.py | 3 +--
 deeplay/tests/test_decorators.py    | 2 --
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/deeplay/applications/application.py b/deeplay/applications/application.py
index 754201ec..449b199a 100644
--- a/deeplay/applications/application.py
+++ b/deeplay/applications/application.py
@@ -360,9 +360,8 @@ def trainer(self, trainer):
                 continue
             try:
                 if isinstance(module, L.LightningModule) or hasattr(module, "trainer"):
-                    print("Aattaching trainer to", module)
+
                     module.trainer = trainer
-                    print("Aattached trainer to", module, module.trainer)
 
             except RuntimeError:
                 # hasattr can raise RuntimeError if the module is not attached to a trainer
diff --git a/deeplay/tests/test_decorators.py b/deeplay/tests/test_decorators.py
index 58475a8d..b35a8157 100644
--- a/deeplay/tests/test_decorators.py
+++ b/deeplay/tests/test_decorators.py
@@ -51,8 +51,6 @@ def __init__(self):
 
 
 # module["encoder"]
 
-# print("after:", module.encoder.p)
-
 
 class DummyClass: ...

From 70eabc2a57d403fdea6927e50872cbd208fd5f28 Mon Sep 17 00:00:00 2001
From: Benjamin Midtvedt
Date: Tue, 13 Aug 2024 12:24:06 +0200
Subject: [PATCH 6/9] Refactor RandomRotation2d method in transforms.py for
 improved functionality

---
 .../applications/detection/lodestar/transforms.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/deeplay/applications/detection/lodestar/transforms.py b/deeplay/applications/detection/lodestar/transforms.py
index 157eea5c..6e794ceb 100644
--- a/deeplay/applications/detection/lodestar/transforms.py
+++ b/deeplay/applications/detection/lodestar/transforms.py
@@ -90,6 +90,15 @@ def _backward(x, angle, indices):
         mat2d[:, indices[1], indices[0]] = -torch.sin(-angle)
         mat2d[:, indices[0], indices[1]] = torch.sin(-angle)
         mat2d[:, indices[0], indices[0]] = torch.cos(-angle)
-        out = torch.matmul(x.unsqueeze(1), mat2d).squeeze(1)
-
-        return out
+
+        if len(x.size()) == 2:
+            # (B, C) -> (B, 1, C)
+            x = x.unsqueeze(1)
+            return torch.matmul(x, mat2d).squeeze(1)
+
+        x_expanded = x.view(x.size(0), 1, x.size(1), -1)
+        y = torch.einsum("bijm,bjk->bikm", x_expanded, mat2d)
+
+        return y.view(x.size())
+
+        
\ No newline at end of file

From 7d0a31818c9494c7f10d8dfd732b51356527f666 Mon Sep 17 00:00:00 2001
From: Benjamin Midtvedt
Date: Fri, 27 Sep 2024 17:44:33 +0200
Subject: [PATCH 7/9] Refactor Application build method to include progress
 bar and multiple schedulers

---
 deeplay/applications/application.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/deeplay/applications/application.py b/deeplay/applications/application.py
index 449b199a..62bc4b5b 100644
--- a/deeplay/applications/application.py
+++ b/deeplay/applications/application.py
@@ -84,6 +84,7 @@ def fit(
         val_batch_size=None,
         val_steps_per_epoch=10,
         callbacks=[],
+        enable_progress_bar=True,
         **kwargs,
     ) -> LogHistory:
         """Train the model on the training data.
@@ -124,9 +125,11 @@ def fit(
         )
 
         history = LogHistory()
-        progressbar = RichProgressBar()
+        aux_callbacks = [history]
+        if enable_progress_bar:
+            aux_callbacks = aux_callbacks + [RichProgressBar()]
 
-        callbacks = callbacks + [history, progressbar]
+        callbacks = callbacks + aux_callbacks
 
         trainer = dl.Trainer(max_epochs=max_epochs, callbacks=callbacks, **kwargs)
         train_dataloader = torch.utils.data.DataLoader(

From cdbb4ab315e071fbbe4ffd83024c66ae11548b6e Mon Sep 17 00:00:00 2001
From: Benjamin Midtvedt
Date: Sat, 28 Sep 2024 14:30:47 +0200
Subject: [PATCH 8/9] Refactor Application fit to not add progressbar

---
 deeplay/applications/application.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/deeplay/applications/application.py b/deeplay/applications/application.py
index 62bc4b5b..0e27d74d 100644
--- a/deeplay/applications/application.py
+++ b/deeplay/applications/application.py
@@ -126,8 +126,6 @@ def fit(
 
         history = LogHistory()
         aux_callbacks = [history]
-        if enable_progress_bar:
-            aux_callbacks = aux_callbacks + [RichProgressBar()]
 
         callbacks = callbacks + aux_callbacks
         trainer = dl.Trainer(max_epochs=max_epochs, callbacks=callbacks, **kwargs)

From 2e6f2e1022fbfe01bc9f37e1dc044e52c6da09db Mon Sep 17 00:00:00 2001
From: Benjamin Midtvedt
Date: Sat, 28 Sep 2024 16:27:39 +0200
Subject: [PATCH 9/9] Refactor Application fit method to remove
 enable_progress_bar parameter

---
 deeplay/applications/application.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/deeplay/applications/application.py b/deeplay/applications/application.py
index 0e27d74d..f333c624 100644
--- a/deeplay/applications/application.py
+++ b/deeplay/applications/application.py
@@ -84,7 +84,6 @@ def fit(
         val_batch_size=None,
         val_steps_per_epoch=10,
         callbacks=[],
-        enable_progress_bar=True,
         **kwargs,
     ) -> LogHistory:
         """Train the model on the training data.
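Taken together, the series leaves DeeplayModule with a small scheduling API: schedule(**schedulers) from patch 4, the schedule_linear and schedule_loglinear conveniences from patch 2, and composable schedulers (ConstantScheduler, LinearScheduler, LogLinearScheduler, SchedulerSequence) from patch 1. A hedged usage sketch, assuming the full series is applied; the attribute name beta and the variable warmup are illustrative, and the MultiLayerPerceptron(1, [1], 1) constructor call follows the one used in the tests:

from deeplay.components.mlp import MultiLayerPerceptron
from deeplay.schedulers import ConstantScheduler, LinearScheduler, SchedulerSequence

model = MultiLayerPerceptron(1, [1], 1)

# Convenience form: attach a LinearScheduler to an attribute by name.
model.schedule_linear("beta", start_value=0.0, end_value=1.0, n_steps=1000)

# Alternatively, compose phases: hold at 0.0 for 100 steps, then ramp to
# 1.0 over 1000 steps. ConstantScheduler has no n_steps of its own, so the
# phase length is passed explicitly to add().
warmup = SchedulerSequence()
warmup.add(ConstantScheduler(0.0), n_steps=100)
warmup.add(LinearScheduler(0.0, 1.0, 1000))
model.schedule(beta=warmup)

model.build()

# From here on, reading model.beta returns the value for the trainer's
# current global step (or current epoch when on_epoch=True), e.g. inside
# a loss computation.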