diff --git a/ranger21/ranger21.py b/ranger21/ranger21.py
index 777ac18..7299f1a 100644
--- a/ranger21/ranger21.py
+++ b/ranger21/ranger21.py
@@ -143,6 +143,7 @@ def __init__(
         warmup_type="linear",
         warmup_pct_default=0.22,
         logging_active=True,
+        verbose=True
     ):
 
         # todo - checks on incoming params
@@ -153,6 +154,7 @@ def __init__(
 
         # core
         self.logging = logging_active
+        self.verbose = verbose
 
         # engine
         self.use_madgrad = use_madgrad
@@ -294,8 +296,8 @@ def __init__(
         engine = "AdamW" if not self.use_madgrad else "MadGrad"
 
         # print out initial settings to make usage easier
-
-        self.show_settings()
+        if self.verbose:
+            self.show_settings()
 
     def __setstate__(self, state):
         super().__setstate__(state)
@@ -360,8 +362,8 @@ def show_settings(self):
     # lookahead functions
     def clear_cache(self):
         """clears the lookahead cached params """
-
-        print(f"clearing lookahead cache...")
+        if self.verbose:
+            print(f"clearing lookahead cache...")
         for group in self.param_groups:
             for p in group["params"]:
                 param_state = self.state[p]
@@ -373,7 +375,8 @@ def clear_cache(self):
             if len(la_params):
                 param_state["lookahead_params"] = torch.zeros_like(p.data)
 
-        print(f"lookahead cache cleared")
+        if self.verbose:
+            print(f"lookahead cache cleared")
 
     def clear_and_load_backup(self):
         for group in self.param_groups:
@@ -449,7 +452,8 @@ def warmup_dampening(self, lr, step):
                 )
 
                 self.warmup_complete = True
-                print(f"\n** Ranger21 update = Warmup complete - lr set to {lr}\n")
+                if self.verbose:
+                    print(f"\n** Ranger21 update = Warmup complete - lr set to {lr}\n")
             return lr
 
         if style == "linear":
@@ -472,9 +476,10 @@ def get_warm_down(self, lr, iteration):
         if iteration > self.start_warm_down - 1:
             # print when starting
             if not self.warmdown_displayed:
-                print(
+                if self.verbose:
+                    print(
                     f"\n** Ranger21 update: Warmdown starting now. Current iteration = {iteration}....\n"
-                )
+                    )
                 self.warmdown_displayed = True
 
             warmdown_iteration = (
@@ -697,9 +702,10 @@ def step(self, closure=None):
         # we will run this first epoch only and then memoize
         if not self.param_size:
             self.param_size = param_size
-            print(f"params size saved")
-            print(f"total param groups = {i+1}")
-            print(f"total params in groups = {j+1}")
+            if self.verbose:
+                print(f"params size saved")
+                print(f"total param groups = {i+1}")
+                print(f"total params in groups = {j+1}")
 
         if not self.param_size:
             raise ValueError("failed to set param size")
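The patch gates every console message behind a single `verbose` constructor flag, with the default `verbose=True` preserving the existing behavior. A minimal usage sketch of the new flag follows; the toy model, the `lr`, and the schedule values are illustrative assumptions, not part of the patch, though `num_epochs` and `num_batches_per_epoch` are Ranger21's usual required schedule arguments:

```python
import torch
from ranger21 import Ranger21

# Toy model purely for illustration.
model = torch.nn.Linear(10, 2)

# verbose=False suppresses the settings banner printed at construction
# as well as the warmup/warmdown, lookahead-cache, and param-size
# messages emitted during training. verbose=True keeps the old output.
optimizer = Ranger21(
    model.parameters(),
    lr=1e-3,
    num_epochs=10,              # arbitrary values; Ranger21 uses these
    num_batches_per_epoch=100,  # to size its warmup/warmdown phases
    verbose=False,
)

# Training proceeds as with any torch optimizer; with verbose=False,
# no Ranger21 status lines appear on stdout.
for _ in range(3):
    loss = model(torch.randn(4, 10)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Routing all prints through one flag rather than the existing `logging_active` switch keeps internal metric logging and console chatter independently controllable.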