diff --git a/mart/models/modular.py b/mart/models/modular.py index 0c8ebb2c..605b237c 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -6,6 +6,7 @@ import logging from operator import attrgetter +from typing import Sequence import torch from lightning.pytorch import LightningModule @@ -137,9 +138,12 @@ def attack_step(self, batch, batch_idx): # Training # def training_step(self, batch, batch_idx): - # FIXME: Would be much nicer if batch was a dict! - input, target = batch - output = self(input=input, target=target, model=self.model, step="training") + # FIXME: Would be much nicer if batch is always a dictionary! + # We are going to feed the raw batch of a dictionary to self.model, but also making it backward-compatible with the tuple batch format. + input = target = None + if isinstance(batch, Sequence) and len(batch) == 2: + input, target = batch + output = self(input=input, target=target, batch=batch, model=self.model, step="training") for log_name, output_key in self.training_step_log.items(): self.log(f"training/{log_name}", output[output_key]) @@ -172,8 +176,10 @@ def on_train_epoch_end(self): # def validation_step(self, batch, batch_idx): # FIXME: Would be much nicer if batch was a dict! - input, target = batch - output = self(input=input, target=target, model=self.model, step="validation") + input = target = None + if isinstance(batch, Sequence) and len(batch) == 2: + input, target = batch + output = self(input=input, target=target, batch=batch, model=self.model, step="validation") for log_name, output_key in self.validation_step_log.items(): self.log(f"validation/{log_name}", output[output_key]) @@ -194,8 +200,10 @@ def on_validation_epoch_end(self): # def test_step(self, batch, batch_idx): # FIXME: Would be much nicer if batch was a dict! - input, target = batch - output = self(input=input, target=target, model=self.model, step="test") + input = target = None + if isinstance(batch, Sequence) and len(batch) == 2: + input, target = batch + output = self(input=input, target=target, batch=batch, model=self.model, step="test") for log_name, output_key in self.test_step_log.items(): self.log(f"test/{log_name}", output[output_key])