diff --git a/learned_optimization/learned_optimizers/nn_adam.py b/learned_optimization/learned_optimizers/nn_adam.py
index d374d08..0e62762 100644
--- a/learned_optimization/learned_optimizers/nn_adam.py
+++ b/learned_optimization/learned_optimizers/nn_adam.py
@@ -82,7 +82,7 @@ class NNAdamState:
   """
   params: Any
   state: Any
-  iteration: int
+  iteration: jax.Array
   rolling_features: MeanAndMeanSquareAccumulator
   per_layer_lr: Any
   per_layer_beta1: Any
diff --git a/learned_optimization/optimizers/optax_opts.py b/learned_optimization/optimizers/optax_opts.py
index e8c1edd..8df1a53 100644
--- a/learned_optimization/optimizers/optax_opts.py
+++ b/learned_optimization/optimizers/optax_opts.py
@@ -384,11 +384,13 @@ def __init__(self,

   # SM3 doesn't support scalars, so we have to reshape the params and grads.

-  def init(self,
-           params: Any,
-           model_state: Optional[Any] = None,
-           num_steps: Optional[int] = None,
-           key: chex.PRNGKey = None) -> SM3OptState:
+  def init(
+      self,
+      params: Any,
+      model_state: Optional[Any] = None,
+      num_steps: Optional[int] = None,
+      key: Optional[chex.PRNGKey] = None,
+  ) -> SM3OptState:
     should_reshape = jax.tree_util.tree_map(lambda x: len(x.shape) == 0, params)  # pylint: disable=g-explicit-length-test
     params = jax.tree_util.tree_map(_expand_scalar, params, should_reshape)
     out = super().init(params, model_state, num_steps, key)
diff --git a/learned_optimization/outer_train.py b/learned_optimization/outer_train.py
index 119f402..31b79d1 100644
--- a/learned_optimization/outer_train.py
+++ b/learned_optimization/outer_train.py
@@ -207,7 +207,7 @@ def metrics_and_info_from_gradients(
     gathered_grads: Sequence[GradientsFromWorker],
     steps: Sequence[int],
     current_step: int,
-) -> Tuple[Mapping[str, float], Sequence[int], int]:
+) -> Tuple[Mapping[str, float], Sequence[Union[int, jax.Array]], int]:
   """Perform one outer-iteration on a batch of gradients from workers.

   Args:
@@ -222,17 +222,17 @@
     applied_inner_steps: number if inner steps performed this outer step.
   """

-  worker_ids = jnp.asarray([t.worker_id for t in gathered_grads])
-  inner_steps = onp.asarray([t.total_inner_steps for t in gathered_grads])
+  worker_ids = [t.worker_id for t in gathered_grads]
+  inner_steps = [t.total_inner_steps for t in gathered_grads]

-  applied_inner_steps = onp.sum(inner_steps)
+  applied_inner_steps = int(onp.sum(inner_steps))

   metrics = {}
   metrics["unique_worker"] = float(len(onp.unique(worker_ids)))
-  avg_stale = current_step - onp.mean(steps)
+  avg_stale = current_step - float(onp.mean(steps))
   metrics["avg_staleness"] = avg_stale

-  max_stale = current_step - onp.min(steps)
+  max_stale = current_step - float(onp.min(steps))
   metrics["max_staleness"] = max_stale

   return metrics, worker_ids, applied_inner_steps
diff --git a/learned_optimization/outer_trainers/full_es.py b/learned_optimization/outer_trainers/full_es.py
index bd0b64d..dcb111b 100644
--- a/learned_optimization/outer_trainers/full_es.py
+++ b/learned_optimization/outer_trainers/full_es.py
@@ -249,7 +249,7 @@ def last_recompute_antithetic_es(
     recompute_samples: int,
     clip_loss_diff: Optional[float] = None,
     sign_delta_loss_scalar: Optional[float] = None,
-) -> Tuple[float, MetaParams]:
+) -> Tuple[jax.Array, MetaParams]:
   """Compute an ES gradient estimate by recomputing the loss on both unrolls.

   Args:
diff --git a/learned_optimization/outer_trainers/truncated_es.py b/learned_optimization/outer_trainers/truncated_es.py
index 8154472..5015cea 100644
--- a/learned_optimization/outer_trainers/truncated_es.py
+++ b/learned_optimization/outer_trainers/truncated_es.py
@@ -46,7 +46,9 @@ def compute_es_grad(
     vec_pos: MetaParams,
     std: float,
     sign_delta_loss_scalar: Optional[float] = None,
-) -> Tuple[float, MetaParams, truncated_step_mod.TruncatedUnrollOut, float]:
+) -> Tuple[
+    jax.Array, MetaParams, truncated_step_mod.TruncatedUnrollOut, jax.Array
+]:
   """Compute the ES gradient estimate from the outputs of many unrolls.

   Args:
diff --git a/learned_optimization/outer_trainers/truncated_pes.py b/learned_optimization/outer_trainers/truncated_pes.py
index 1572ce9..2c3c052 100644
--- a/learned_optimization/outer_trainers/truncated_pes.py
+++ b/learned_optimization/outer_trainers/truncated_pes.py
@@ -53,8 +53,13 @@ def compute_pes_grad(
     vec_pos: MetaParams,
     std: float,
     sign_delta_loss_scalar: Optional[float] = None,
-) -> Tuple[float, MetaParams, MetaParams, truncated_step_mod.TruncatedUnrollOut,
-           float]:
+) -> Tuple[
+    jax.Array,
+    MetaParams,
+    MetaParams,
+    truncated_step_mod.TruncatedUnrollOut,
+    jax.Array,
+]:
   """Compute the PES gradient estimate from the outputs of many unrolls.

   Args:
diff --git a/learned_optimization/research/hysteresis/truncated_pes_no_hyst_same_data.py b/learned_optimization/research/hysteresis/truncated_pes_no_hyst_same_data.py
index bee573b..cec8514 100644
--- a/learned_optimization/research/hysteresis/truncated_pes_no_hyst_same_data.py
+++ b/learned_optimization/research/hysteresis/truncated_pes_no_hyst_same_data.py
@@ -130,6 +130,7 @@ def compute_gradient_estimate(
     rng = hk.PRNGSequence(key)

     # train on the circular buffer's data.
+    assert isinstance(self.circular_buffer, circular_buffer.CircularBuffer)
     reordered_data, unused_mask = self.circular_buffer.stack_reorder(
         buffer_state)

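Note on the typing changes above: values produced inside jit-compiled JAX code come back to the caller as jax.Array rather than plain Python int/float, which is why annotations such as NNAdamState.iteration and the ES/PES gradient-estimate returns are declared jax.Array, while host-side bookkeeping in outer_train.py is cast back to Python scalars with int()/float(). A minimal sketch of that behavior, assuming only stock jax (the helper name bump_iteration is made up for illustration and is not part of the patch):

import jax


@jax.jit
def bump_iteration(iteration):
  # Under jit, `iteration` is traced, so the returned value is a jax.Array
  # even when the caller passes a plain Python int.
  return iteration + 1


step = bump_iteration(0)
print(isinstance(step, jax.Array))  # True: annotate such fields as jax.Array.
print(int(step))                    # 1: explicit cast for host-side Python code.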