From 9e9117ffdfd5a0fd72a1d05f36284fe7557c7f57 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 30 Apr 2025 16:04:02 -0700 Subject: [PATCH] Allow Composer to return intermediate representations. --- mart/attack/composer/modular.py | 18 +++++++++++++----- mart/configs/attack/composer/default.yaml | 1 + 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/mart/attack/composer/modular.py b/mart/attack/composer/modular.py index 64aa9d79..2c6c42e4 100644 --- a/mart/attack/composer/modular.py +++ b/mart/attack/composer/modular.py @@ -20,7 +20,12 @@ class Composer(torch.nn.Module): def __init__( - self, perturber: Perturber, modules, sequence, visualizer: Callable = None + self, + perturber: Perturber, + modules, + sequence, + visualizer: Callable = None, + return_final_output_only: bool = True, ) -> None: """_summary_ @@ -28,6 +33,7 @@ def __init__( perturber (Perturber): Manage perturbations. functions (dict[str, Function]): A dictionary of functions for composing pertured input. visualizer (Callable): Visualize intermediate results of a composer. + return_final_output_only (bool): Only returns the final output instead of the output dict if True. """ super().__init__() @@ -38,6 +44,7 @@ def __init__( sequence = [sequence[key] for key in sorted(sequence)] self.functions = SequentialDict(modules, {"composer": sequence}) self.visualizer = visualizer + self.return_final_output_only = return_final_output_only def configure_perturbation(self, input: torch.Tensor | Iterable[torch.Tensor]): return self.perturber.configure_perturbation(input) @@ -84,10 +91,11 @@ def _compose( if self.visualizer: self.visualizer(output) - # SequentialDict returns a dictionary DotDict, - # but we only need the return value of the most recently executed module. - last_added_key = next(reversed(output)) - output = output[last_added_key] + if self.return_final_output_only: + # SequentialDict returns a dictionary DotDict, + # but we only need the return value of the most recently executed module. + last_added_key = next(reversed(output)) + output = output[last_added_key] # Return the composed input. return output diff --git a/mart/configs/attack/composer/default.yaml b/mart/configs/attack/composer/default.yaml index 0c16c028..9d32df05 100644 --- a/mart/configs/attack/composer/default.yaml +++ b/mart/configs/attack/composer/default.yaml @@ -2,6 +2,7 @@ defaults: - perturber: default _target_: mart.attack.Composer +return_final_output_only: true modules: ??? # Example: additive, mask, overlay