diff --git a/learned_optimization/optimizers/optax_opts.py b/learned_optimization/optimizers/optax_opts.py
index 8df1a53..39424f9 100644
--- a/learned_optimization/optimizers/optax_opts.py
+++ b/learned_optimization/optimizers/optax_opts.py
@@ -446,19 +446,21 @@ def __init__(self,
 class Adafactor(OptaxOptimizer):
   """Adafactor optimizer."""

-  def __init__(self,
-               learning_rate: float,
-               min_dim_size_to_factor: int = 128,
-               decay_rate: float = 0.8,
-               decay_offset: int = 0,
-               multiply_by_parameter_scale: float = True,
-               clipping_threshold: Optional[float] = 1.0,
-               momentum: Optional[float] = None,
-               dtype_momentum: Any = jnp.float32,
-               weight_decay_rate: Optional[float] = None,
-               eps: float = 1e-30,
-               factored: bool = True,
-               weight_decay_mask=None):
+  def __init__(
+      self,
+      learning_rate: float,
+      min_dim_size_to_factor: int = 128,
+      decay_rate: float = 0.8,
+      decay_offset: int = 0,
+      multiply_by_parameter_scale: bool = True,
+      clipping_threshold: Optional[float] = 1.0,
+      momentum: Optional[float] = None,
+      dtype_momentum: Any = jnp.float32,
+      weight_decay_rate: Optional[float] = None,
+      eps: float = 1e-30,
+      factored: bool = True,
+      weight_decay_mask=None,
+  ):
     opt = optax.adafactor(
         learning_rate=learning_rate,
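
The hunk above corrects the type annotation of `multiply_by_parameter_scale` from `float` to `bool` (its default is `True` and it is passed straight through to `optax.adafactor`, which treats it as a flag) and reformats the constructor signature; behavior is unchanged. As a minimal sketch of the call site this signature serves, assuming the `learned_optimization` package is importable and that `OptaxOptimizer` wraps the constructed optax transformation as elsewhere in `optax_opts.py`:

```python
# Hedged usage sketch, not part of the diff. All keyword arguments mirror
# optax.adafactor; the hyperparameter values below are illustrative only.
import jax.numpy as jnp

from learned_optimization.optimizers import optax_opts

opt = optax_opts.Adafactor(
    learning_rate=1e-3,
    multiply_by_parameter_scale=True,  # now annotated as bool, matching the default
    momentum=0.9,
    dtype_momentum=jnp.float32,
)
```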