2 changes: 1 addition & 1 deletion learned_optimization/learned_optimizers/nn_adam.py
@@ -82,7 +82,7 @@ class NNAdamState:
   """
   params: Any
   state: Any
-  iteration: int
+  iteration: jax.Array
   rolling_features: MeanAndMeanSquareAccumulator
   per_layer_lr: Any
   per_layer_beta1: Any
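Note on this change: once the optimizer step runs under jax.jit, the counter is traced and carried through the state as an array, so annotating it as a Python int was inaccurate. A minimal sketch of the behavior, assuming a jitted update (the names below are illustrative, not from the repo):

import jax
import jax.numpy as jnp

@jax.jit
def update(iteration):
  # Under jit, `iteration` is traced; the result is a jax.Array, never an int.
  return iteration + 1

it = jnp.asarray(0)  # the counter lives in optimizer state as an array
it = update(it)
print(type(it))      # a jax Array type, not <class 'int'>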
12 changes: 7 additions & 5 deletions learned_optimization/optimizers/optax_opts.py
@@ -384,11 +384,13 @@ def __init__(self,

   # SM3 doesn't support scalars, so we have to reshape the params and grads.

-  def init(self,
-           params: Any,
-           model_state: Optional[Any] = None,
-           num_steps: Optional[int] = None,
-           key: chex.PRNGKey = None) -> SM3OptState:
+  def init(
+      self,
+      params: Any,
+      model_state: Optional[Any] = None,
+      num_steps: Optional[int] = None,
+      key: Optional[chex.PRNGKey] = None,
+  ) -> SM3OptState:
     should_reshape = jax.tree_util.tree_map(lambda x: len(x.shape) == 0, params)  # pylint: disable=g-explicit-length-test
     params = jax.tree_util.tree_map(_expand_scalar, params, should_reshape)
     out = super().init(params, model_state, num_steps, key)
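Besides making the None default explicit in the type (Optional[chex.PRNGKey], which strict checkers require), this touches the scalar-reshape workaround described in the comment. A sketch of that workaround; the body of _expand_scalar is an assumption here, reconstructed from how it is called (rank-0 leaves are expanded so SM3 sees at least one dimension):

import jax
import jax.numpy as jnp

def _expand_scalar(x, should_reshape):
  # Give rank-0 leaves a shape of (1,); leave everything else untouched.
  return jnp.reshape(x, (1,)) if should_reshape else x

params = {"w": jnp.ones((3, 2)), "b": jnp.asarray(0.5)}  # "b" is a scalar
should_reshape = jax.tree_util.tree_map(lambda x: len(x.shape) == 0, params)
params = jax.tree_util.tree_map(_expand_scalar, params, should_reshape)
print(params["b"].shape)  # (1,)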
12 changes: 6 additions & 6 deletions learned_optimization/outer_train.py
@@ -207,7 +207,7 @@ def metrics_and_info_from_gradients(
     gathered_grads: Sequence[GradientsFromWorker],
     steps: Sequence[int],
     current_step: int,
-) -> Tuple[Mapping[str, float], Sequence[int], int]:
+) -> Tuple[Mapping[str, float], Sequence[Union[int, jax.Array]], int]:
   """Perform one outer-iteration on a batch of gradients from workers.

   Args:
@@ -222,17 +222,17 @@ def metrics_and_info_from_gradients(
     applied_inner_steps: number of inner steps performed this outer step.
"""

worker_ids = jnp.asarray([t.worker_id for t in gathered_grads])
inner_steps = onp.asarray([t.total_inner_steps for t in gathered_grads])
worker_ids = [t.worker_id for t in gathered_grads]
inner_steps = [t.total_inner_steps for t in gathered_grads]

applied_inner_steps = onp.sum(inner_steps)
applied_inner_steps = int(onp.sum(inner_steps))
metrics = {}
metrics["unique_worker"] = float(len(onp.unique(worker_ids)))

avg_stale = current_step - onp.mean(steps)
avg_stale = current_step - float(onp.mean(steps))
metrics["avg_staleness"] = avg_stale

max_stale = current_step - onp.min(steps)
max_stale = current_step - float(onp.min(steps))
metrics["max_staleness"] = max_stale

return metrics, worker_ids, applied_inner_steps
Expand Down
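The casts here look cosmetic but matter for the annotations above: numpy reductions return numpy scalar types, not Python ints/floats, and int()/float() force plain host-side values for logging. A small standalone illustration (not the repo's code):

import numpy as onp

steps = [10, 12, 15]
current_step = 20

applied = onp.sum(steps)                           # numpy.int64, not int
applied = int(onp.sum(steps))                      # plain Python int
avg_stale = current_step - float(onp.mean(steps))  # plain Python float
print(applied, avg_stale)                          # 37 7.666...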
2 changes: 1 addition & 1 deletion learned_optimization/outer_trainers/full_es.py
@@ -249,7 +249,7 @@ def last_recompute_antithetic_es(
     recompute_samples: int,
     clip_loss_diff: Optional[float] = None,
     sign_delta_loss_scalar: Optional[float] = None,
-) -> Tuple[float, MetaParams]:
+) -> Tuple[jax.Array, MetaParams]:
   """Compute an ES gradient estimate by recomputing the loss on both unrolls.

   Args:
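For context, an antithetic ES estimator has the form E[eps * (L(theta + std*eps) - L(theta - std*eps))] / (2*std); since the losses come out of jax.numpy reductions, the estimate is a jax.Array, which motivates the annotation change. A minimal sketch under the assumption of a 1-D parameter vector (names illustrative, not the repo's implementation):

import jax
import jax.numpy as jnp

def antithetic_es_grad(loss_fn, theta, key, std, num_samples=128):
  eps = jax.random.normal(key, (num_samples, theta.shape[0]))
  pos = jax.vmap(loss_fn)(theta + std * eps)  # losses on positive unrolls
  neg = jax.vmap(loss_fn)(theta - std * eps)  # losses on negative unrolls
  weights = (pos - neg) / (2.0 * std)
  return jnp.mean(weights[:, None] * eps, axis=0)  # a jax.Array, not a float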
4 changes: 3 additions & 1 deletion learned_optimization/outer_trainers/truncated_es.py
@@ -46,7 +46,9 @@ def compute_es_grad(
     vec_pos: MetaParams,
     std: float,
     sign_delta_loss_scalar: Optional[float] = None,
-) -> Tuple[float, MetaParams, truncated_step_mod.TruncatedUnrollOut, float]:
+) -> Tuple[
+    jax.Array, MetaParams, truncated_step_mod.TruncatedUnrollOut, jax.Array
+]:
   """Compute the ES gradient estimate from the outputs of many unrolls.

   Args:
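Same motivation as the full_es change: scalar-shaped values produced by jax.numpy are arrays, not Python floats, so the old float annotations on these returns were inaccurate. For example:

import jax.numpy as jnp

loss = jnp.mean(jnp.asarray([0.1, 0.2, 0.3]))
print(loss.shape)               # () -- scalar-shaped
print(isinstance(loss, float))  # False: it is a jax.Array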
9 changes: 7 additions & 2 deletions learned_optimization/outer_trainers/truncated_pes.py
@@ -53,8 +53,13 @@ def compute_pes_grad(
     vec_pos: MetaParams,
     std: float,
     sign_delta_loss_scalar: Optional[float] = None,
-) -> Tuple[float, MetaParams, MetaParams, truncated_step_mod.TruncatedUnrollOut,
-           float]:
+) -> Tuple[
+    jax.Array,
+    MetaParams,
+    MetaParams,
+    truncated_step_mod.TruncatedUnrollOut,
+    jax.Array,
+]:
   """Compute the PES gradient estimate from the outputs of many unrolls.

   Args:
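For reference, PES differs from plain truncated ES by weighting the perturbations accumulated across all truncations of an unroll, which removes truncation bias. A rough sketch of one PES update (an illustration, not the repo's compute_pes_grad):

import jax.numpy as jnp

def pes_step(accum_eps, eps, pos_loss, neg_loss, std):
  # accum_eps: perturbations summed over every truncation so far,
  # shape (num_samples, dim); eps: this truncation's perturbations.
  accum_eps = accum_eps + eps
  weights = (pos_loss - neg_loss) / (2.0 * std)          # (num_samples,)
  grad = jnp.mean(weights[:, None] * accum_eps, axis=0)  # (dim,)
  return grad, accum_eps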
@@ -130,6 +130,7 @@ def compute_gradient_estimate(
     rng = hk.PRNGSequence(key)

     # train on the circular buffer's data.
+    assert isinstance(self.circular_buffer, circular_buffer.CircularBuffer)
     reordered_data, unused_mask = self.circular_buffer.stack_reorder(
         buffer_state)
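The added assert presumably serves type narrowing as much as runtime safety: from that line on, static checkers can treat self.circular_buffer as a CircularBuffer rather than a looser type. A self-contained illustration of the pattern (toy classes, not the repo's):

from typing import Optional

class Buffer:
  def stack_reorder(self, state):
    return state, None

class Estimator:
  def __init__(self, buf: Optional[Buffer] = None):
    self.circular_buffer = buf

  def compute(self, state):
    # Narrows Optional[Buffer] to Buffer for checkers, and fails fast at
    # runtime if the attribute was never configured.
    assert isinstance(self.circular_buffer, Buffer)
    return self.circular_buffer.stack_reorder(state)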