From fa5a9e6d9f9dbbc2b038ebd54083c91ea35b0335 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 30 Apr 2025 15:57:26 -0700 Subject: [PATCH 1/2] Add an option Uniform.round to round the initial numbers. --- mart/attack/initializer/base.py | 5 ++++- .../attack/composer/perturber/initializer/uniform.yaml | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mart/attack/initializer/base.py b/mart/attack/initializer/base.py index 2d28e964..ba98b684 100644 --- a/mart/attack/initializer/base.py +++ b/mart/attack/initializer/base.py @@ -40,13 +40,16 @@ def initialize_(self, parameter: torch.Tensor) -> None: class Uniform(Initializer): - def __init__(self, min: int | float, max: int | float): + def __init__(self, min: int | float, max: int | float, round: False): self.min = min self.max = max + self.round = round @torch.no_grad() def initialize_(self, parameter: torch.Tensor) -> None: torch.nn.init.uniform_(parameter, self.min, self.max) + if self.round: + parameter.round_() class UniformLp(Initializer): diff --git a/mart/configs/attack/composer/perturber/initializer/uniform.yaml b/mart/configs/attack/composer/perturber/initializer/uniform.yaml index 84df0cc1..404e1d22 100644 --- a/mart/configs/attack/composer/perturber/initializer/uniform.yaml +++ b/mart/configs/attack/composer/perturber/initializer/uniform.yaml @@ -1,3 +1,4 @@ _target_: mart.attack.initializer.Uniform min: ??? max: ??? +round: false From 6c234035bdf61860974a674991afbc1a64be98dc Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 1 May 2025 10:29:45 -0700 Subject: [PATCH 2/2] Fix a typo in the default parameter. --- mart/attack/initializer/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/attack/initializer/base.py b/mart/attack/initializer/base.py index ba98b684..49a5b173 100644 --- a/mart/attack/initializer/base.py +++ b/mart/attack/initializer/base.py @@ -40,7 +40,7 @@ def initialize_(self, parameter: torch.Tensor) -> None: class Uniform(Initializer): - def __init__(self, min: int | float, max: int | float, round: False): + def __init__(self, min: int | float, max: int | float, round: bool = False): self.min = min self.max = max self.round = round