From a457511c95f3af7139d3e8b7527106b183871bf9 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Tue, 6 May 2025 16:24:58 -0700
Subject: [PATCH 1/5] Strip away the LogAndMetrics callback from LitModular.

---
 mart/callbacks/__init__.py    |   1 +
 mart/callbacks/log_metrics.py | 167 ++++++++++++++++++++++++++++++++++
 2 files changed, 168 insertions(+)
 create mode 100644 mart/callbacks/log_metrics.py

diff --git a/mart/callbacks/__init__.py b/mart/callbacks/__init__.py
index b84e09f2..01b64dcf 100644
--- a/mart/callbacks/__init__.py
+++ b/mart/callbacks/__init__.py
@@ -2,6 +2,7 @@
 from .adversary_connector import *
 from .eval_mode import *
 from .gradients import *
+from .log_metrics import *
 from .no_grad_mode import *
 from .progress_bar import *
 

diff --git a/mart/callbacks/log_metrics.py b/mart/callbacks/log_metrics.py
new file mode 100644
index 00000000..d5d12b72
--- /dev/null
+++ b/mart/callbacks/log_metrics.py
@@ -0,0 +1,167 @@
+#
+# Copyright (C) 2025 Intel Corporation
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+
+import logging
+from typing import Sequence
+
+import torch
+from lightning.pytorch.callbacks import Callback
+from torchmetrics import Metric
+
+from ..nn.nn import DotDict
+
+logger = logging.getLogger(__name__)
+
+
+class LogMetrics(Callback):
+    """For models that return a dictionary, log configured scalars from the step
+    outputs and compute and log metrics."""
+
+    def __init__(
+        self,
+        train_step_log: Sequence | dict = None,
+        val_step_log: Sequence | dict = None,
+        test_step_log: Sequence | dict = None,
+        train_metrics: Metric = None,
+        val_metrics: Metric = None,
+        test_metrics: Metric = None,
+        output_preds_key: str = "preds",
+        output_target_key: str = "target",
+        # We may display only some of the metrics on the progress bar if there are too many.
+        metrics_on_train_prog_bar: bool | Sequence[str] = True,
+        metrics_on_val_prog_bar: bool | Sequence[str] = True,
+        metrics_on_test_prog_bar: bool | Sequence[str] = True,
+    ):
+        super().__init__()
+
+        # Be backwards compatible by turning a list into a dict where each item is its own key-value pair.
+        if isinstance(train_step_log, (list, tuple)):
+            train_step_log = {item: {"key": item, "prog_bar": True} for item in train_step_log}
+        train_step_log = train_step_log or {}
+
+        # Be backwards compatible by turning a list into a dict where each item is its own key-value pair.
+        if isinstance(val_step_log, (list, tuple)):
+            val_step_log = {item: {"key": item, "prog_bar": True} for item in val_step_log}
+        val_step_log = val_step_log or {}
+
+        # Be backwards compatible by turning a list into a dict where each item is its own key-value pair.
+        if isinstance(test_step_log, (list, tuple)):
+            test_step_log = {item: {"key": item, "prog_bar": True} for item in test_step_log}
+        test_step_log = test_step_log or {}
+
+        self.step_log = {
+            "train": train_step_log,
+            "val": val_step_log,
+            "test": test_step_log,
+        }
+        self.metrics = {
+            "train": train_metrics,
+            "val": val_metrics,
+            "test": test_metrics,
+        }
+        self.metrics_on_prog_bar = {
+            "train": metrics_on_train_prog_bar,
+            "val": metrics_on_val_prog_bar,
+            "test": metrics_on_test_prog_bar,
+        }
+
+        self.output_preds_key = output_preds_key
+        self.output_target_key = output_target_key
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        return self.on_batch_end(pl_module, outputs, prefix="train")
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        return self.on_epoch_end(pl_module, prefix="train")
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        return self.on_batch_end(pl_module, outputs, prefix="val")
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        return self.on_epoch_end(pl_module, prefix="val")
+
+    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        return self.on_batch_end(pl_module, outputs, prefix="test")
+
+    def on_test_epoch_end(self, trainer, pl_module):
+        return self.on_epoch_end(pl_module, prefix="test")
+
+    #
+    # Utilities
+    #
+    def on_batch_end(self, pl_module, outputs, *, prefix: str):
+        # Convert to DotDict, so that we can use a dot-connected string as a key to find a value deep in the dictionary.
+        outputs = DotDict(outputs)
+
+        step_log = self.step_log[prefix]
+        for log_name, cfg in step_log.items():
+            key, prog_bar = cfg["key"], cfg["prog_bar"]
+            pl_module.log(f"{prefix}/{log_name}", outputs[key], prog_bar=prog_bar)
+
+        metric = self.metrics[prefix]
+        if metric is not None:
+            metric(outputs[self.output_preds_key], outputs[self.output_target_key])
+
+    def on_epoch_end(self, pl_module, *, prefix: str):
+        # Some models only return the loss in the train mode, so there may be no train metrics.
+        metric = self.metrics[prefix]
+        if metric is not None:
+            results = metric.compute()
+            results = self.flatten_metrics(results)
+            metric.reset()
+
+            self.log_metrics(pl_module, results, prefix=prefix)
+
+    def flatten_metrics(self, metrics):
+        # torchmetrics==0.6.0 does not flatten group metrics such as mAP (which includes mAP, mAP-50, etc.),
+        # while later versions do. We add this for forward compatibility while we downgrade to 0.6.0.
+        flat_metrics = {}
+
+        for k, v in metrics.items():
+            if isinstance(v, dict):
+                # Recursively flatten nested metric dictionaries.
+                v = self.flatten_metrics(v)
+                for k2, v2 in v.items():
+                    if k2 in flat_metrics:
+                        logger.warning(f"{k}/{k2} overrides existing metric!")
+
+                    flat_metrics[k2] = v2
+            else:
+                # Assume a raw metric value.
+                if k in flat_metrics:
+                    logger.warning(f"{k} overrides existing metric!")
+
+                flat_metrics[k] = v
+
+        return flat_metrics
+
+    def log_metrics(self, pl_module, metrics, prefix=""):
+        metrics_dict = {}
+
+        def enumerate_metric(metric, name):
+            # Metrics can have arbitrary depth.
+            if isinstance(metric, torch.Tensor):
+                # Ignore non-scalar results generated by Metrics, such as the list of classes from mAP.
+                if metric.shape == torch.Size([]):
+                    metrics_dict[name] = metric
+            else:
+                for k, v in metric.items():
+                    enumerate_metric(v, f"{name}/{k}")
+
+        enumerate_metric(metrics, prefix)
+
+        # sync_dist is not necessary for torchmetrics: https://torchmetrics.readthedocs.io/en/stable/pages/lightning.html
+        on_prog_bar = self.metrics_on_prog_bar[prefix]
+        if isinstance(on_prog_bar, bool):
+            pl_module.log_dict(metrics_dict, prog_bar=on_prog_bar)
+        elif isinstance(on_prog_bar, Sequence):
+            # Show only the selected metrics on the progress bar, but log all of them.
+            for metric_key in on_prog_bar:
+                metric_value = metrics_dict.pop(f"{prefix}/{metric_key}")
+                pl_module.log(f"{prefix}/{metric_key}", metric_value, prog_bar=True)
+            pl_module.log_dict(metrics_dict, prog_bar=False)
+        else:
+            raise ValueError(f"Unknown type: {type(self.metrics_on_prog_bar[prefix])=}")
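A minimal usage sketch (an illustration, not part of the patch): attaching the new callback to a Trainer, assuming a LightningModule whose step methods return a dictionary with "loss", "preds" and "target" entries. The model, the datamodule, and the nested "rpn.loss" key are hypothetical.

    import torchmetrics
    from lightning.pytorch import Trainer

    from mart.callbacks import LogMetrics

    log_metrics = LogMetrics(
        # Log outputs["loss"] as "train/loss" on the progress bar.
        train_step_log=["loss"],
        # Dict form: log outputs["rpn"]["loss"] as "val/rpn_loss" via a
        # dot-connected DotDict key, off the progress bar (hypothetical key).
        val_step_log={"rpn_loss": {"key": "rpn.loss", "prog_bar": False}},
        # Updated each step from outputs["preds"] and outputs["target"], then
        # computed, flattened, and logged at the end of each validation epoch.
        val_metrics=torchmetrics.MeanSquaredError(),
    )

    trainer = Trainer(callbacks=[log_metrics])
    # trainer.fit(model, datamodule=datamodule)  # hypothetical model/datamodule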
From 50965c6733baa023f4489c3cdd07ebcabc5130df Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Tue, 6 May 2025 16:41:40 -0700
Subject: [PATCH 2/5] Fix type annotations.

---
 mart/callbacks/log_metrics.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mart/callbacks/log_metrics.py b/mart/callbacks/log_metrics.py
index d5d12b72..969100f4 100644
--- a/mart/callbacks/log_metrics.py
+++ b/mart/callbacks/log_metrics.py
@@ -4,6 +4,8 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
+from __future__ import annotations
+
 import logging
 from typing import Sequence
 
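For context: without postponed evaluation of annotations, the PEP 604 unions in the signatures above are evaluated eagerly at class-definition time, which fails on Python versions before 3.10. A minimal illustration of what this patch fixes:

    from typing import Sequence

    # On Python 3.9 this def raises "TypeError: unsupported operand type(s)
    # for |" because typing.Sequence | dict is evaluated immediately.
    def configure(step_log: Sequence | dict = None):
        ...

    # With `from __future__ import annotations` at the top of the module,
    # annotations are stored as strings and never evaluated, so the same
    # signature imports cleanly on older interpreters.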
From 4e86f6fe303f2076bce997c73c71649338a48ca1 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Wed, 7 May 2025 11:25:44 -0700
Subject: [PATCH 3/5] Split the callback into two.

---
 mart/callbacks/logging.py                     | 73 +++++++++++++++++++
 mart/callbacks/{log_metrics.py => metrics.py} | 36 ++++----------
 2 files changed, 78 insertions(+), 31 deletions(-)
 create mode 100644 mart/callbacks/logging.py
 rename mart/callbacks/{log_metrics.py => metrics.py} (78%)

diff --git a/mart/callbacks/logging.py b/mart/callbacks/logging.py
new file mode 100644
index 00000000..9bd9d99d
--- /dev/null
+++ b/mart/callbacks/logging.py
@@ -0,0 +1,73 @@
+#
+# Copyright (C) 2025 Intel Corporation
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+
+from __future__ import annotations
+
+import logging
+from typing import Sequence
+
+from lightning.pytorch.callbacks import Callback
+
+from ..nn.nn import DotDict
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["Logging"]
+
+
+class Logging(Callback):
+    """For models that return a dictionary, log configured scalars from the step
+    outputs."""
+
+    def __init__(
+        self,
+        train_step_log: Sequence | dict = None,
+        val_step_log: Sequence | dict = None,
+        test_step_log: Sequence | dict = None,
+    ):
+        super().__init__()
+
+        # Be backwards compatible by turning a list into a dict where each item is its own key-value pair.
+        if isinstance(train_step_log, (list, tuple)):
+            train_step_log = {item: {"key": item, "prog_bar": True} for item in train_step_log}
+        train_step_log = train_step_log or {}
+
+        # Be backwards compatible by turning a list into a dict where each item is its own key-value pair.
+        if isinstance(val_step_log, (list, tuple)):
+            val_step_log = {item: {"key": item, "prog_bar": True} for item in val_step_log}
+        val_step_log = val_step_log or {}
+
+        # Be backwards compatible by turning a list into a dict where each item is its own key-value pair.
+        if isinstance(test_step_log, (list, tuple)):
+            test_step_log = {item: {"key": item, "prog_bar": True} for item in test_step_log}
+        test_step_log = test_step_log or {}
+
+        self.step_log = {
+            "train": train_step_log,
+            "val": val_step_log,
+            "test": test_step_log,
+        }
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        return self.on_batch_end(pl_module, outputs, prefix="train")
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        return self.on_batch_end(pl_module, outputs, prefix="val")
+
+    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        return self.on_batch_end(pl_module, outputs, prefix="test")
+
+    #
+    # Utilities
+    #
+    def on_batch_end(self, pl_module, outputs, *, prefix: str):
+        # Convert to DotDict, so that we can use a dot-connected string as a key to find a value deep in the dictionary.
+        outputs = DotDict(outputs)
+
+        step_log = self.step_log[prefix]
+        for log_name, cfg in step_log.items():
+            key, prog_bar = cfg["key"], cfg["prog_bar"]
+            pl_module.log(f"{prefix}/{log_name}", outputs[key], prog_bar=prog_bar)
diff --git a/mart/callbacks/log_metrics.py b/mart/callbacks/metrics.py
similarity index 78%
rename from mart/callbacks/log_metrics.py
rename to mart/callbacks/metrics.py
index 969100f4..93d3ece5 100644
--- a/mart/callbacks/log_metrics.py
+++ b/mart/callbacks/metrics.py
@@ -17,16 +17,15 @@
 
 logger = logging.getLogger(__name__)
 
+__all__ = ["Metrics"]
 
-class LogMetrics(Callback):
-    """For models that return a dictionary, log configured scalars from the step
-    outputs and compute and log metrics."""
+
+class Metrics(Callback):
+    """For models that return a dictionary, compute and log metrics from the step
+    outputs."""
 
     def __init__(
         self,
-        train_step_log: Sequence | dict = None,
-        val_step_log: Sequence | dict = None,
-        test_step_log: Sequence | dict = None,
         train_metrics: Metric = None,
         val_metrics: Metric = None,
         test_metrics: Metric = None,
@@ -39,26 +38,6 @@ def __init__(
     ):
         super().__init__()
 
-        # Be backwards compatible by turning a list into a dict where each item is its own key-value pair.
-        if isinstance(train_step_log, (list, tuple)):
-            train_step_log = {item: {"key": item, "prog_bar": True} for item in train_step_log}
-        train_step_log = train_step_log or {}
-
-        # Be backwards compatible by turning a list into a dict where each item is its own key-value pair.
-        if isinstance(val_step_log, (list, tuple)):
-            val_step_log = {item: {"key": item, "prog_bar": True} for item in val_step_log}
-        val_step_log = val_step_log or {}
-
-        # Be backwards compatible by turning a list into a dict where each item is its own key-value pair.
-        if isinstance(test_step_log, (list, tuple)):
-            test_step_log = {item: {"key": item, "prog_bar": True} for item in test_step_log}
-        test_step_log = test_step_log or {}
-
-        self.step_log = {
-            "train": train_step_log,
-            "val": val_step_log,
-            "test": test_step_log,
-        }
         self.metrics = {
             "train": train_metrics,
             "val": val_metrics,
@@ -98,11 +77,6 @@ def on_batch_end(self, pl_module, outputs, *, prefix: str):
         # Convert to DotDict, so that we can use a dot-connected string as a key to find a value deep in the dictionary.
         outputs = DotDict(outputs)
 
-        step_log = self.step_log[prefix]
-        for log_name, cfg in step_log.items():
-            key, prog_bar = cfg["key"], cfg["prog_bar"]
-            pl_module.log(f"{prefix}/{log_name}", outputs[key], prog_bar=prog_bar)
-
         metric = self.metrics[prefix]
         if metric is not None:
             metric(outputs[self.output_preds_key], outputs[self.output_target_key])
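A sketch of the resulting configuration after the split, with step-output logging and metric computation handled by separate callbacks (the imports work once PATCH 5/5 updates mart/callbacks/__init__.py; the model and datamodule are hypothetical as before):

    import torchmetrics
    from lightning.pytorch import Trainer

    from mart.callbacks import Logging, Metrics

    trainer = Trainer(
        callbacks=[
            # Scalars from the step outputs, e.g. "train/loss".
            Logging(train_step_log=["loss"]),
            # Epoch-level metrics from outputs["preds"] and outputs["target"].
            Metrics(val_metrics=torchmetrics.MeanSquaredError()),
        ]
    )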
From 6de5da46f38f6361429d76933482043da10ef32d Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Wed, 7 May 2025 13:53:29 -0700
Subject: [PATCH 4/5] Update type checking

---
 mart/callbacks/logging.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mart/callbacks/logging.py b/mart/callbacks/logging.py
index 9bd9d99d..17febf90 100644
--- a/mart/callbacks/logging.py
+++ b/mart/callbacks/logging.py
@@ -10,6 +10,7 @@
 from typing import Sequence
 
 from lightning.pytorch.callbacks import Callback
+from torch import Tensor
 
 from ..nn.nn import DotDict
 
@@ -31,17 +32,17 @@ def __init__(
         super().__init__()
 
         # Be backwards compatible by turning a list into a dict where each item is its own key-value pair.
-        if isinstance(train_step_log, (list, tuple)):
+        if isinstance(train_step_log, Sequence):
             train_step_log = {item: {"key": item, "prog_bar": True} for item in train_step_log}
         train_step_log = train_step_log or {}
 
         # Be backwards compatible by turning a list into a dict where each item is its own key-value pair.
-        if isinstance(val_step_log, (list, tuple)):
+        if isinstance(val_step_log, Sequence):
             val_step_log = {item: {"key": item, "prog_bar": True} for item in val_step_log}
         val_step_log = val_step_log or {}
 
         # Be backwards compatible by turning a list into a dict where each item is its own key-value pair.
-        if isinstance(test_step_log, (list, tuple)):
+        if isinstance(test_step_log, Sequence):
             test_step_log = {item: {"key": item, "prog_bar": True} for item in test_step_log}
         test_step_log = test_step_log or {}
 
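An illustration of what the relaxed check accepts: typing.Sequence supports isinstance() and matches any sequence type, not just list and tuple, while dict-style configurations still fall through unconverted. One caveat is that str is also a Sequence, so step logs should be passed as lists rather than bare strings:

    from typing import Sequence

    assert isinstance(["loss"], Sequence)   # converted to the dict form
    assert isinstance(("loss",), Sequence)  # tuples still work
    assert not isinstance({"loss": {"key": "loss", "prog_bar": True}}, Sequence)

    # A bare string is a Sequence too, and would be expanded character by
    # character by the backwards-compatibility branch.
    assert isinstance("loss", Sequence)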
From 2ee7ad5abded18301b4ea40c60f709bddf448f96 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Wed, 7 May 2025 13:58:31 -0700
Subject: [PATCH 5/5] Update imports.

---
 mart/callbacks/__init__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mart/callbacks/__init__.py b/mart/callbacks/__init__.py
index 01b64dcf..5e648370 100644
--- a/mart/callbacks/__init__.py
+++ b/mart/callbacks/__init__.py
@@ -2,7 +2,8 @@
 from .adversary_connector import *
 from .eval_mode import *
 from .gradients import *
-from .log_metrics import *
+from .logging import *
+from .metrics import *
 from .no_grad_mode import *
 from .progress_bar import *
 
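Finally, a small sketch of the flattening behavior that Metrics carries over for grouped results such as mAP; real values are scalar tensors, plain floats are used here only for brevity:

    from mart.callbacks import Metrics

    nested = {"map": {"map": 0.42, "map_50": 0.61}, "loss": 0.10}
    flat = Metrics().flatten_metrics(nested)
    assert flat == {"map": 0.42, "map_50": 0.61, "loss": 0.10}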