From 208923b57e1111526b1cc47e13193361a63601cc Mon Sep 17 00:00:00 2001
From: Luke Metz
Date: Thu, 18 Aug 2022 16:56:10 -0700
Subject: [PATCH] Internal change

PiperOrigin-RevId: 468580163
---
 learned_optimization/eval_training.py | 47 ++++++++++++++++++++++-----
 learned_optimization/summary.py       | 28 +++++++++++++++-
 2 files changed, 66 insertions(+), 9 deletions(-)

diff --git a/learned_optimization/eval_training.py b/learned_optimization/eval_training.py
index f7cc981..c91da65 100644
--- a/learned_optimization/eval_training.py
+++ b/learned_optimization/eval_training.py
@@ -73,7 +73,9 @@ def fn(opt_state, key, data):
     key, summary_key = jax.random.split(key)
     (next_opt_state, loss, key), metrics = summary.with_summary_output_reduced(fn)(
-        opt_state, key, data, summary_sample_rng_key=summary_key)
+        opt_state, key, data, sample_rng_key=summary_key)
+    key, key1 = jax.random.split(key)
+    metrics = summary.aggregate_metric_list([metrics], use_jnp=True, key=key1)
   else:
     next_opt_state, loss, key = fn(opt_state, key, data)
     metrics = {}
@@ -142,6 +144,7 @@ def single_task_training_curves(
     last_eval_batches: int = 20,
     eval_task: Optional[tasks_base.Task] = None,
     device: Optional[jax.lib.xla_client.Device] = None,
+    metrics_every: Optional[int] = None,
     summary_writer: Optional[summary.SummaryWriterBase] = None,
 ) -> Mapping[str, jnp.ndarray]:
   """Compute training curves."""
@@ -160,11 +163,14 @@ def single_task_training_curves(
       opt.init, static_argnames=("num_steps",))(
          p, model_state=s, num_steps=num_steps)

-  losses = []
-  eval_auxs = []
-  use_data = task.datasets is not None
-  train_xs = []
-  eval_xs = []
+  losses = []
+  eval_auxs = []
+  use_data = task.datasets is not None
+  train_xs = []
+  eval_xs = []
+  metrics = []
+  metrics_xs = []
+
   for i in tqdm.trange(num_steps + 1, position=0):
     with profile.Profile("eval"):
       m = {}
@@ -196,16 +202,41 @@ def single_task_training_curves(
         batch = jax.device_put(batch, device=device)

     with profile.Profile("next_state"):
-      opt_state, l, key, _ = _next_state(
-          task, opt, opt_state, batch, key, with_metrics=False)
+      with_metrics = False if (
+          metrics_every is None) else i % metrics_every == 0
+      opt_state, l, key, m = _next_state(
+          task, opt, opt_state, batch, key, with_metrics=with_metrics)
       losses.append(l)
       train_xs.append(i)

+    if summary_writer:
+      summary_writer.scalar("train/loss", l, step=i)
+
+    if metrics_every:
+      if summary_writer:
+        for k, v in m.items():
+          agg, k = k.split("||")
+          if agg in ["mean", "sample"]:
+            summary_writer.scalar(k, v, step=i)
+          elif agg == "tensor":
+            summary_writer.tensor(k, v, step=i)
+          else:
+            logging.warning(f"Unsupported aggregation type {agg}. "  # pylint: disable=logging-fstring-interpolation
+                            f"Dropping data for key {k}.")
+      metrics.append(m)
+      metrics_xs.append(i)
+
   ret = {
       "train/xs": onp.asarray(train_xs),
       "train/loss": onp.asarray(losses),
   }

+  if metrics_every:
+    stacked_metrics = tree_utils.tree_zip_onp(metrics)
+    metric_dict = {f"train/metrics/{k}": v for k, v in stacked_metrics.items()}
+    ret = {**ret, **metric_dict}
+    ret["train/metrics/xs"] = onp.asarray(metrics_xs)
+
   if eval_batches:
     stacked_metrics = tree_utils.tree_zip_onp(eval_auxs)
     ret["eval/xs"] = onp.asarray(eval_xs)
diff --git a/learned_optimization/summary.py b/learned_optimization/summary.py
index 8925c0b..87bd341 100644
--- a/learned_optimization/summary.py
+++ b/learned_optimization/summary.py
@@ -101,6 +101,7 @@ class AggregationType(str, enum.Enum):
   sample = "sample"  # pylint: disable=invalid-name
   collect = "collect"  # pylint: disable=invalid-name
   none = "none"  # pylint: disable=invalid-name
+  tensor = "tensor"  # pylint: disable=invalid-name


 def summary(
@@ -138,8 +139,12 @@ def summary(
   oryx_name = aggregation + "||" + name

+  mode = "append"
+  if aggregation == AggregationType.tensor:
+    mode = "strict"
+
   if ORYX_LOGGING:
-    val = oryx.core.sow(val, tag=_SOW_TAG, name=oryx_name, mode="append")
+    val = oryx.core.sow(val, tag=_SOW_TAG, name=oryx_name, mode=mode)

   return val

@@ -203,6 +208,9 @@ def aggregate_metric(k: str,
   elif agg == AggregationType.collect:
     # This might be multi dim if vmap is used, so ravel first.
     return xnp.concatenate([xnp.asarray(v).ravel() for v in vs], axis=0)
+  elif agg == AggregationType.tensor:
+    assert len(vs) == 1
+    return vs[0]
   elif agg == AggregationType.none:
     if len(vs) != 1:
       raise ValueError("when using no aggregation one must ensure only scalar "
@@ -279,6 +287,8 @@ def out_fn(unused_in, *args):
           to_sample.append((k, v))
         elif agg == AggregationType.collect:
           new_metrics[k] = v.ravel()
+        elif agg == AggregationType.tensor:
+          new_metrics[k] = v
         else:
           raise ValueError(f"unsupported aggregation {agg}")

@@ -387,6 +397,9 @@ def scalar(self, name, value, step):
   def histogram(self, name, value, step):
     raise NotImplementedError()

+  def tensor(self, name, value, step):
+    raise NotImplementedError()
+
   def flush(self):
     raise NotImplementedError()

@@ -423,6 +436,10 @@ def histogram(self, name, value, step):
     if self.filter_fn(name):
       print(f"{step}] {name}={value}")

+  def tensor(self, name, value, step):
+    if self.filter_fn(name):
+      print(f"{step}] {name}=Tensor: {value.shape}")
+
   def flush(self):
     pass

@@ -439,6 +456,9 @@ def scalar(self, name, value, step):
   def flush(self):
     _ = [w.flush() for w in self.writers]

+  def tensor(self, name, value, step):
+    _ = [w.tensor(name, value, step) for w in self.writers]
+
   def histogram(self, name, value, step):
     _ = [w.histogram(name, value, step) for w in self.writers]

@@ -528,6 +548,12 @@ def text(self, name, textdata, step):
     tf.summary.text(name=name, data=tf.constant(textdata), step=step)

+  def tensor(self, name, tensor, step):
+    """Write a tensor summary."""
+    self._ensure_default()
+    tf.summary.write(tag=name, tensor=tensor, step=step, name=name)
+
+
 JaxboardWriter = TensorboardWriter
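
Below is a minimal usage sketch for the new eval_training.py arguments. The
metrics_every and summary_writer keywords come from the diff above; the task
and optimizer constructors, the leading positional arguments, and the
TensorboardWriter logdir argument are assumptions made for illustration, not
confirmed by this patch.

    import jax
    from learned_optimization import eval_training, summary
    from learned_optimization.optimizers import base as opt_base  # assumed path
    from learned_optimization.tasks import quadratics             # assumed path

    task = quadratics.QuadraticTask()  # assumed example task
    opt = opt_base.SGD(1e-2)           # assumed example optimizer
    writer = summary.TensorboardWriter("/tmp/tb_logs")  # logdir arg assumed

    curves = eval_training.single_task_training_curves(
        task, opt, num_steps=200, key=jax.random.PRNGKey(0),
        metrics_every=10,       # new: gather inner-training metrics every 10th step
        summary_writer=writer,  # new: streams "train/loss" plus metrics per step
    )

    # With metrics_every set, the returned mapping also contains
    # "train/metrics/xs" (the steps at which metrics were gathered) and
    # "train/metrics/<name>" (per-step values stacked via tree_zip_onp).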
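And a sketch of the new "tensor" aggregation path in summary.py. The
writer-side tensor(name, value, step) signature is taken from the diff; the
producer-side summary.summary call shape is an assumption and is left
commented out.

    import numpy as onp
    from learned_optimization import summary

    # Producer side (assumed call shape): inside a traced function, keep the
    # raw tensor instead of reducing it to a scalar. With aggregation="tensor"
    # the value is sown with mode="strict", so oryx requires it to be emitted
    # exactly once per name rather than appended across calls.
    # summary.summary("grad_snapshot", grads, aggregation="tensor")

    # Consumer side, confirmed by the diff: every SummaryWriterBase subclass
    # now exposes .tensor(name, value, step). Keys emitted by the training
    # loop look like "tensor||<name>" and are split on "||" before routing.
    writer = summary.TensorboardWriter("/tmp/tb_logs")  # logdir arg assumed
    writer.tensor("train/grad_snapshot", onp.arange(12.0).reshape(3, 4), step=0)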