diff --git a/fedjax/core/optimizers.py b/fedjax/core/optimizers.py index b4d3d12..46204f3 100644 --- a/fedjax/core/optimizers.py +++ b/fedjax/core/optimizers.py @@ -286,7 +286,7 @@ def adafactor( min_dim_size_to_factor: int = 128, decay_rate: float = 0.8, decay_offset: int = 0, - multiply_by_parameter_scale: float = True, + multiply_by_parameter_scale: bool = True, clipping_threshold: Optional[float] = 1.0, momentum: Optional[float] = None, dtype_momentum: Any = jnp.float32,