From 6663ce6779332d26f6817f3c37b3fd762a64bdec Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 1 May 2025 09:49:49 -0700 Subject: [PATCH 1/2] Add support to the dictionary batch in LitModular. --- mart/models/modular.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/mart/models/modular.py b/mart/models/modular.py index afab122e..27658348 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -125,9 +125,13 @@ def attack_step(self, batch, batch_idx): # Training # def training_step(self, batch, batch_idx): - # FIXME: Would be much nicer if batch was a dict! - input, target = batch - output = self(input=input, target=target, model=self.model, step="training") + # FIXME: Would be much nicer if batch is always a dictionary! + # We are going to feed the raw batch of a dictionary to self.model, but also making it backward-compatible with the tuple batch format. + if isinstance(batch, dict): + input = target = None + else: + input, target = batch + output = self(input=input, target=target, batch=batch, model=self.model, step="training") for log_name, output_key in self.training_step_log.items(): self.log(f"training/{log_name}", output[output_key]) @@ -160,8 +164,11 @@ def on_train_epoch_end(self): # def validation_step(self, batch, batch_idx): # FIXME: Would be much nicer if batch was a dict! - input, target = batch - output = self(input=input, target=target, model=self.model, step="validation") + if isinstance(batch, dict): + input = target = None + else: + input, target = batch + output = self(input=input, target=target, batch=batch, model=self.model, step="validation") for log_name, output_key in self.validation_step_log.items(): self.log(f"validation/{log_name}", output[output_key]) @@ -182,8 +189,11 @@ def on_validation_epoch_end(self): # def test_step(self, batch, batch_idx): # FIXME: Would be much nicer if batch was a dict! - input, target = batch - output = self(input=input, target=target, model=self.model, step="test") + if isinstance(batch, dict): + input = target = None + else: + input, target = batch + output = self(input=input, target=target, batch=batch, model=self.model, step="test") for log_name, output_key in self.test_step_log.items(): self.log(f"test/{log_name}", output[output_key]) From 5454d0a4bf5ad7c48fe4374fcb55c3db75d60718 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 1 May 2025 10:50:56 -0700 Subject: [PATCH 2/2] First assume the batch is a dictionary. --- mart/models/modular.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/mart/models/modular.py b/mart/models/modular.py index 27658348..f534e634 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -6,6 +6,7 @@ import logging from operator import attrgetter +from typing import Sequence import torch from lightning.pytorch import LightningModule @@ -127,9 +128,8 @@ def attack_step(self, batch, batch_idx): def training_step(self, batch, batch_idx): # FIXME: Would be much nicer if batch is always a dictionary! # We are going to feed the raw batch of a dictionary to self.model, but also making it backward-compatible with the tuple batch format. - if isinstance(batch, dict): - input = target = None - else: + input = target = None + if isinstance(batch, Sequence) and len(batch) == 2: input, target = batch output = self(input=input, target=target, batch=batch, model=self.model, step="training") @@ -164,9 +164,8 @@ def on_train_epoch_end(self): # def validation_step(self, batch, batch_idx): # FIXME: Would be much nicer if batch was a dict! - if isinstance(batch, dict): - input = target = None - else: + input = target = None + if isinstance(batch, Sequence) and len(batch) == 2: input, target = batch output = self(input=input, target=target, batch=batch, model=self.model, step="validation") @@ -189,9 +188,8 @@ def on_validation_epoch_end(self): # def test_step(self, batch, batch_idx): # FIXME: Would be much nicer if batch was a dict! - if isinstance(batch, dict): - input = target = None - else: + input = target = None + if isinstance(batch, Sequence) and len(batch) == 2: input, target = batch output = self(input=input, target=target, batch=batch, model=self.model, step="test")