Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions mart/attack/composer/modular.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,20 @@

class Composer(torch.nn.Module):
def __init__(
self, perturber: Perturber, modules, sequence, visualizer: Callable = None
self,
perturber: Perturber,
modules,
sequence,
visualizer: Callable = None,
return_final_output_only: bool = True,
) -> None:
"""_summary_

Args:
perturber (Perturber): Manage perturbations.
functions (dict[str, Function]): A dictionary of functions for composing pertured input.
visualizer (Callable): Visualize intermediate results of a composer.
return_final_output_only (bool): Only returns the final output instead of the output dict if True.
"""
super().__init__()

Expand All @@ -38,6 +44,7 @@ def __init__(
sequence = [sequence[key] for key in sorted(sequence)]
self.functions = SequentialDict(modules, {"composer": sequence})
self.visualizer = visualizer
self.return_final_output_only = return_final_output_only

def configure_perturbation(self, input: torch.Tensor | Iterable[torch.Tensor]):
return self.perturber.configure_perturbation(input)
Expand Down Expand Up @@ -84,10 +91,11 @@ def _compose(
if self.visualizer:
self.visualizer(output)

# SequentialDict returns a dictionary DotDict,
# but we only need the return value of the most recently executed module.
last_added_key = next(reversed(output))
output = output[last_added_key]
if self.return_final_output_only:
# SequentialDict returns a dictionary DotDict,
# but we only need the return value of the most recently executed module.
last_added_key = next(reversed(output))
output = output[last_added_key]

# Return the composed input.
return output
Expand Down
1 change: 1 addition & 0 deletions mart/configs/attack/composer/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ defaults:
- perturber: default

_target_: mart.attack.Composer
return_final_output_only: true
modules:
???
# Example: additive, mask, overlay
Expand Down
Loading