diff --git a/mart/attack/initializer/base.py b/mart/attack/initializer/base.py index 2d28e964..49a5b173 100644 --- a/mart/attack/initializer/base.py +++ b/mart/attack/initializer/base.py @@ -40,13 +40,16 @@ def initialize_(self, parameter: torch.Tensor) -> None: class Uniform(Initializer): - def __init__(self, min: int | float, max: int | float): + def __init__(self, min: int | float, max: int | float, round: bool = False): self.min = min self.max = max + self.round = round @torch.no_grad() def initialize_(self, parameter: torch.Tensor) -> None: torch.nn.init.uniform_(parameter, self.min, self.max) + if self.round: + parameter.round_() class UniformLp(Initializer): diff --git a/mart/configs/attack/composer/perturber/initializer/uniform.yaml b/mart/configs/attack/composer/perturber/initializer/uniform.yaml index 84df0cc1..404e1d22 100644 --- a/mart/configs/attack/composer/perturber/initializer/uniform.yaml +++ b/mart/configs/attack/composer/perturber/initializer/uniform.yaml @@ -1,3 +1,4 @@ _target_: mart.attack.initializer.Uniform min: ??? max: ??? +round: false