diff --git a/torchrec/metrics/cpu_comms_metric_module.py b/torchrec/metrics/cpu_comms_metric_module.py
index 2eae617cc..ca87729ac 100644
--- a/torchrec/metrics/cpu_comms_metric_module.py
+++ b/torchrec/metrics/cpu_comms_metric_module.py
@@ -106,9 +106,6 @@ def _load_metric_states(
         Uses aggregated states.
         """

-        # All update() calls were done prior. Clear previous computed state.
-        # Otherwise, we get warnings that compute() was called before
-        # update() which is not the case.
        computation = cast(RecMetricComputation, computation)
        set_update_called(computation)
        computation._computed = None
@@ -157,8 +154,9 @@ def _clone_rec_metrics(self) -> RecMetricList:

 def set_update_called(computation: RecMetricComputation) -> None:
     """
-    Set _update_called to True for RecMetricComputation.
-    This is a workaround for torchmetrics 1.0.3+.
+    Set _update_called to True for RecMetricComputation.
+    All update() calls were done prior; this workaround for torchmetrics 1.0.3+
+    avoids warnings that compute() was called before update(), which is not the case.
     """
     try:
         computation._update_called = True
diff --git a/torchrec/metrics/cpu_offloaded_metric_module.py b/torchrec/metrics/cpu_offloaded_metric_module.py
index 83a9c6e48..bb866c2cc 100644
--- a/torchrec/metrics/cpu_offloaded_metric_module.py
+++ b/torchrec/metrics/cpu_offloaded_metric_module.py
@@ -22,7 +22,7 @@
     MetricUpdateJob,
     SynchronizationMarker,
 )
-from torchrec.metrics.metric_module import MetricValue, RecMetricModule
+from torchrec.metrics.metric_module import MetricsFuture, MetricsResult, RecMetricModule
 from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot
 from torchrec.metrics.model_utils import parse_task_model_outputs
 from torchrec.metrics.rec_metric import RecMetricException
@@ -62,6 +62,7 @@ class CPUOffloadedRecMetricModule(RecMetricModule):

     def __init__(
         self,
+        device: torch.device,
         update_queue_size: int = 100,
         compute_queue_size: int = 100,
         *args: Any,
@@ -69,12 +70,19 @@ def __init__(
     ) -> None:
         """
         Args:
-        All arguments are the same as RecMetricModule except for
-        - update_queue_size: Maximum size of the update queue. Default is 100.
-        - compute_queue_size: Maximum size of the update queue. Default is 100.
+            batch_size: batch size used by this trainer.
+            world_size: the number of trainers.
+            device: the device where the model is located (used to determine whether to perform GPU to CPU transfers).
+            update_queue_size: Maximum size of the update queue. Default is 100.
+            compute_queue_size: Maximum size of the compute queue. Default is 100.
+            *args: Additional positional arguments passed to RecMetricModule.
+            **kwargs: Additional keyword arguments passed to RecMetricModule.
""" super().__init__(*args, **kwargs) - self._shutdown_event = threading.Event() + self._device = device + self._shutdown_event: threading.Event = threading.Event() + self._captured_exception_event: threading.Event = threading.Event() + self._captured_exception: Optional[Exception] = None self.update_queue: queue.Queue[ Union[MetricUpdateJob, SynchronizationMarker] @@ -132,8 +140,16 @@ def _update_rec_metrics( if self._shutdown_event.is_set(): raise RecMetricException("metric processor thread is shut down.") + if self._captured_exception_event.is_set(): + assert self._captured_exception is not None + raise self._captured_exception + try: - cpu_model_out, transfer_completed_event = self._transfer_to_cpu(model_out) + cpu_model_out, transfer_completed_event = ( + self._transfer_to_cpu(model_out) + if self._device == torch.device("cuda") + else (model_out, None) + ) self.update_queue.put_nowait( MetricUpdateJob( model_out=cpu_model_out, @@ -191,31 +207,25 @@ def _process_metric_update_job(self, metric_update_job: MetricUpdateJob) -> None """ with record_function("## CPUOffloadedRecMetricModule:update ##"): - try: + if metric_update_job.transfer_completed_event is not None: metric_update_job.transfer_completed_event.synchronize() - labels, predictions, weights, required_inputs = ( - parse_task_model_outputs( - self.rec_tasks, - metric_update_job.model_out, - self.get_required_inputs(), - ) - ) - if required_inputs: - metric_update_job.kwargs["required_inputs"] = required_inputs - - self.rec_metrics.update( - predictions=predictions, - labels=labels, - weights=weights, - **metric_update_job.kwargs, - ) - - if self.throughput_metric: - self.throughput_metric.update() + labels, predictions, weights, required_inputs = parse_task_model_outputs( + self.rec_tasks, + metric_update_job.model_out, + self.get_required_inputs(), + ) + if required_inputs: + metric_update_job.kwargs["required_inputs"] = required_inputs + + self.rec_metrics.update( + predictions=predictions, + labels=labels, + weights=weights, + **metric_update_job.kwargs, + ) - except Exception as e: - logger.exception("Error processing metric update: %s", e) - raise e + if self.throughput_metric: + self.throughput_metric.update() @override def shutdown(self) -> None: @@ -248,30 +258,34 @@ def shutdown(self) -> None: logger.info("CPUOffloadedRecMetricModule has been successfully shutdown.") @override - def compute(self) -> Dict[str, MetricValue]: + def compute(self) -> MetricsResult: raise RecMetricException( - "compute() is not supported in CPUOffloadedRecMetricModule. Use async_compute() instead." + "CPUOffloadedRecMetricModule does not support compute(). Use async_compute() instead." ) @override - def async_compute( - self, future: concurrent.futures.Future[Dict[str, MetricValue]] - ) -> None: + def async_compute(self) -> MetricsFuture: """ Entry point for asynchronous metric compute. It enqueues a synchronization marker to the update queue. - Args: + Returns: future: Pre-created future where the computed metrics will be set. 
""" + metrics_future = concurrent.futures.Future() if self._shutdown_event.is_set(): - future.set_exception( + metrics_future.set_exception( RecMetricException("metric processor thread is shut down.") ) - return + return metrics_future + + if self._captured_exception_event.is_set(): + assert self._captured_exception is not None + raise self._captured_exception - self.update_queue.put_nowait(SynchronizationMarker(future)) + self.update_queue.put_nowait(SynchronizationMarker(metrics_future)) self.update_queue_size_logger.add(self.update_queue.qsize()) + return metrics_future def _process_synchronization_marker( self, synchronization_marker: SynchronizationMarker @@ -304,7 +318,7 @@ def _process_synchronization_marker( def _process_metric_compute_job( self, metric_compute_job: MetricComputeJob - ) -> Dict[str, MetricValue]: + ) -> MetricsResult: """ Process a metric compute job: 1. Comms module performs all gather @@ -355,6 +369,8 @@ def _update_loop(self) -> None: self._do_work(self.update_queue) except Exception as e: logger.exception(f"Exception in update loop: {e}") + self._captured_exception_event.set() + self._captured_exception = e raise e remaining = self._flush_remaining_work(self.update_queue) @@ -372,6 +388,8 @@ def _compute_loop(self) -> None: self._do_work(self.compute_queue) except Exception as e: logger.exception(f"Exception in compute loop: {e}") + self._captured_exception_event.set() + self._captured_exception = e raise e remaining = self._flush_remaining_work(self.compute_queue) diff --git a/torchrec/metrics/metric_job_types.py b/torchrec/metrics/metric_job_types.py index 5c0ef58f8..2e785cb4d 100644 --- a/torchrec/metrics/metric_job_types.py +++ b/torchrec/metrics/metric_job_types.py @@ -8,7 +8,7 @@ # pyre-strict import concurrent -from typing import Any, Dict +from typing import Any, Dict, Optional import torch from torchrec.metrics.metric_module import MetricValue @@ -26,7 +26,7 @@ class MetricUpdateJob: def __init__( self, model_out: Dict[str, torch.Tensor], - transfer_completed_event: torch.cuda.Event, + transfer_completed_event: Optional[torch.cuda.Event], kwargs: Dict[str, Any], ) -> None: """ @@ -37,7 +37,9 @@ def __init__( """ self.model_out: Dict[str, torch.Tensor] = model_out - self.transfer_completed_event: torch.cuda.Event = transfer_completed_event + self.transfer_completed_event: Optional[torch.cuda.Event] = ( + transfer_completed_event + ) self.kwargs: Dict[str, Any] = kwargs diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index 06ce6e3a3..a76223fc5 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -21,8 +21,20 @@ import torch.nn as nn from torch.distributed.tensor import DeviceMesh from torch.profiler import record_function +from torchmetrics.utilities.data import ( + dim_zero_cat, + dim_zero_max, + dim_zero_mean, + dim_zero_min, + dim_zero_sum, +) from torchrec.metrics.accuracy import AccuracyMetric -from torchrec.metrics.auc import AUCMetric + +from torchrec.metrics.auc import ( + _grouping_keys_state_reduction, + _state_reduction, + AUCMetric, +) from torchrec.metrics.auprc import AUPRCMetric from torchrec.metrics.cali_free_ne import CaliFreeNEMetric from torchrec.metrics.calibration import CalibrationMetric @@ -74,6 +86,65 @@ logger: logging.Logger = logging.getLogger(__name__) +# TorchRec-specific custom reduction functions. +# These work correctly with local+global reduction pattern. 
+# Requirements: Associative AND (Commutative OR post-processing makes result order-invariant) +SAFE_CALLABLE_REDUCTIONS: frozenset[Any] = frozenset( + { + _state_reduction, # Concatenation + AUC sorts data, making final result order-invariant + _grouping_keys_state_reduction, # Concatenation along dim=0 + sorting makes result order-invariant + } +) + +# torchmetrics.Metric built-in reduction functions. +# All dim_zero_* functions are both associative and commutative (dim_zero_cat is not commutative +# but torchmetrics.Metric also reduce before sync_dist to reduce number of collectives). +TORCHMETRICS_REDUCTIONS: frozenset[Any] = frozenset( + { + dim_zero_sum, + dim_zero_mean, + dim_zero_max, + dim_zero_min, + dim_zero_cat, + } +) + + +def _validate_reduction_function( + reduction_fn: Union[str, Any, None], + state_name: str, + metric_namespace: str, +) -> None: + """ + Validate that a reduction function is safe for local+global reduction pattern. + + Only validates custom reduction functions. TorchMetrics built-in functions + (dim_zero_*) are skipped as they're safe by construction (all are associative & commutative). + + Mathematical Requirements: + 1. **Associativity**: f([f([a,b]), f([c,d])]) = f([a,b,c,d]) + - Required so local reduction + global reduction = direct reduction + + 2. **Commutativity**: f([a, b]) = f([b, a]) + - Required so rank ordering doesn't affect the result + - OR the metric's computation must make the final result order-invariant + (e.g., AUC concatenates in rank order but sorts before computing, making final result order-invariant) + """ + # Skip validation for None and torchmetrics.Metric built-in functions (safe by construction) + if reduction_fn is None or reduction_fn in TORCHMETRICS_REDUCTIONS: + return + + # Validate custom callable reductions + if callable(reduction_fn): + if reduction_fn not in SAFE_CALLABLE_REDUCTIONS: + raise RecMetricException( + f"Unknown custom reduction '{reduction_fn}' for state '{state_name}' in '{metric_namespace}'. " + f"Must be associative: f([f([a,b]), f([c,d])]) == f([a,b,c,d]) " + f"AND commutative: f([a,b]) == f([b,a]) (or metric makes result order-invariant). " + f"Known safe custom reductions: {[f for f in SAFE_CALLABLE_REDUCTIONS if f not in TORCHMETRICS_REDUCTIONS]}. " + f"Add to SAFE_CALLABLE_REDUCTIONS if verified safe." + ) + REC_METRICS_MAPPING: Dict[RecMetricEnumBase, Type[RecMetric]] = { RecMetricEnum.NE: NEMetric, @@ -117,6 +188,9 @@ MetricValue = Union[torch.Tensor, float] +MetricsResult = Dict[str, MetricValue] +MetricsFuture = concurrent.futures.Future[MetricsResult] +MetricsOutput = Union[MetricsResult, MetricsFuture] class StateMetric(abc.ABC): @@ -125,7 +199,7 @@ class StateMetric(abc.ABC): """ @abc.abstractmethod - def get_metrics(self) -> Dict[str, MetricValue]: + def get_metrics(self) -> MetricsResult: pass @@ -208,6 +282,8 @@ def __init__( self.oom_count = 0 self.compute_count = 0 + self._validate_all_reduction_functions() + self.compute_interval_steps = compute_interval_steps self.min_compute_interval = min_compute_interval self.max_compute_interval = max_compute_interval @@ -230,6 +306,20 @@ def __init__( self._register_load_state_dict_pre_hook(self.load_state_dict_hook) + def _validate_all_reduction_functions(self) -> None: + """ + Validate all reduction functions in rec_metrics during initialization. + This ensures that all reduction functions are safe for the local+global reduction pattern. 
+ """ + for metric in self.rec_metrics.rec_metrics: + for computation in metric._metrics_computations: # pyre-ignore[16] + for state_name, reduction_fn in computation._reductions.items(): # pyre-ignore[16] + _validate_reduction_function( + reduction_fn, + state_name, + metric._namespace.value, # pyre-ignore[16] + ) + def load_state_dict_hook( self, state_dict: OrderedDict[str, torch.Tensor], @@ -335,12 +425,12 @@ def _adjust_compute_interval(self) -> None: def should_compute(self) -> bool: return self.trained_batches % self.compute_interval_steps == 0 - def compute(self) -> Dict[str, MetricValue]: + def compute(self) -> MetricsResult: r"""compute() is called when the global metrics are required, usually right before logging the metrics results to the data sink. """ self.compute_count += 1 - ret: Dict[str, MetricValue] = {} + ret: MetricsResult = {} with record_function("## RecMetricModule:compute ##"): if self.rec_metrics: self._adjust_compute_interval() @@ -357,11 +447,11 @@ def compute(self) -> Dict[str, MetricValue]: ) return ret - def local_compute(self) -> Dict[str, MetricValue]: + def local_compute(self) -> MetricsResult: r"""local_compute() is called when per-trainer metrics are required. It's can be used for debugging. Currently only rec_metrics is supported. """ - ret: Dict[str, MetricValue] = {} + ret: MetricsResult = {} if self.rec_metrics: ret.update(self.rec_metrics.local_compute()) return ret @@ -398,22 +488,24 @@ def _get_metric_states( # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `items`. for state_name, reduction_fn in computation._reductions.items(): - tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr( - computation, state_name - ) - - if isinstance(tensor_or_list, list): - gathered = _all_gather_tensor_list( - tensor_or_list, world_size, process_group - ) - else: - gathered = torch.stack( - _all_gather_tensor(tensor_or_list, world_size, process_group) + with record_function(f"## RecMetricModule: {state_name} all gather ##"): + tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr( + computation, state_name ) - reduced = ( - reduction_fn(gathered) if reduction_fn is not None else gathered - ) - result[task.name][state_name] = reduced + + if isinstance(tensor_or_list, list): + local_reduced = reduction_fn(tensor_or_list) + gathered = _all_gather_tensor_list( + local_reduced, world_size, process_group + ) + else: + gathered = torch.stack( + _all_gather_tensor( + tensor_or_list, world_size, process_group + ) + ) + global_reduced = reduction_fn(gathered) + result[task.name][state_name] = global_reduced return result @@ -462,7 +554,8 @@ def get_pre_compute_states( # throughput metric requires special handling, since it's not a RecMetric throughput_metric = self.throughput_metric if throughput_metric is not None: - aggregated_states[throughput_metric._namespace.value] = ( + # Merge in case there are rec metric namespaces that overlap with throughput metric namespace + aggregated_states.setdefault(throughput_metric._namespace.value, {}).update( self._get_throughput_metric_states(throughput_metric) ) @@ -512,9 +605,7 @@ def load_pre_compute_states( def shutdown(self) -> None: logger.info("Initiating graceful shutdown...") - def async_compute( - self, future: concurrent.futures.Future[Dict[str, MetricValue]] - ) -> None: + def async_compute(self) -> MetricsFuture: raise RecMetricException("async_compute is not supported in RecMetricModule") @@ -658,8 +749,23 @@ def _all_gather_tensor_list( world_size: int, pg: 
Union[dist.ProcessGroup, DeviceMesh], ) -> List[torch.Tensor]: - """All-gather every tensor in a list and flatten the result.""" - gathered: List[torch.Tensor] = [] # pragma: no cover + """ + All-gather every tensor in a list and flatten the result. + + Note: In the current implementation with local reduction in _get_metric_states, + this function should only receive a list with at most 1 tensor after local reduction. + """ + if not tensors: + return [] + + # After local reduction in _get_metric_states, tensors should contain at most 1 element + if len(tensors) > 1: + raise ValueError( + f"_all_gather_tensor_list expected at most 1 tensor after local reduction, " + f"but received {len(tensors)} tensors. This indicates a bug in _get_metric_states." + ) + + gathered: List[torch.Tensor] = [] for t in tensors: gathered.extend(_all_gather_tensor(t, world_size, pg)) return gathered diff --git a/torchrec/metrics/metrics_output_util.py b/torchrec/metrics/metrics_output_util.py new file mode 100644 index 000000000..6129fe38a --- /dev/null +++ b/torchrec/metrics/metrics_output_util.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +Utility functions for handling MetricsOutput (Union[MetricsResult, MetricsFuture]) from +- RecMetricModule.compute() +- CPUOffloadedRecMetricModule.async_compute() +""" + +import concurrent +import logging +from typing import Callable, TypeVar + +from torchrec.metrics.metric_module import MetricsFuture, MetricsOutput, MetricsResult + +logger: logging.Logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +def get_metrics_async( + metrics_output: MetricsOutput, + callback: Callable[[MetricsResult], T], + *, + on_error: Callable[[Exception], None] | None = None, +) -> T | None: + """ + Register a callback to execute when metrics are ready. + + Preserves CPUOffloadedRecMetricModule's async benefits by executing callbacks when Future resolves, + without blocking the critical training path. + + Args: + metrics_output: Either metrics dict (sync from RecMetricModule) or Future (async from CPUOffloadedRecMetricModule) + callback: Function to execute with resolved metrics + on_error: Optional error handler for exceptions + + Returns: + Result of callback if metrics are immediately available (Dict[str, MetricValue]), + None if async (Future) - callback will be invoked later + """ + + # Asynchronous path + if isinstance(metrics_output, concurrent.futures.Future): + + def on_complete(future: MetricsFuture) -> None: + try: + result = future.result() + callback(result) + except Exception as e: + if on_error: + on_error(e) + else: + logger.exception("Error in metrics callback") + raise + + metrics_output.add_done_callback(on_complete) + return None + else: + # Synchronous path + return callback(metrics_output) + + +def get_metrics_sync( + metrics_output: MetricsOutput, + timeout: float | None = None, +) -> MetricsResult: + """ + Synchronously resolve MetricsOutput to MetricsResult. + + Use this when you need the actual metrics dict immediately (e.g., to modify it). + For async handling, use get_metrics_async() instead. 
+
+    Args:
+        metrics_output: Either metrics dict (sync) or Future (async)
+        timeout: Optional timeout in seconds for Future resolution
+
+    Returns:
+        Resolved metrics dict
+
+    Raises:
+        TimeoutError: If Future doesn't resolve within timeout (if specified)
+        Exception: Any exception from Future computation
+
+    Example:
+        >>> metrics_output = self.metrics.compute()
+        >>> metrics_result = get_metrics_sync(metrics_output)  # wait until metrics are ready
+        >>> publish_metrics(metrics_result)
+    """
+    if isinstance(metrics_output, concurrent.futures.Future):
+        return metrics_output.result(timeout=timeout)
+    else:
+        return metrics_output
diff --git a/torchrec/metrics/tests/test_cpu_offloaded_metric_module.py b/torchrec/metrics/tests/test_cpu_offloaded_metric_module.py
index 3c3d557e3..320b59383 100644
--- a/torchrec/metrics/tests/test_cpu_offloaded_metric_module.py
+++ b/torchrec/metrics/tests/test_cpu_offloaded_metric_module.py
@@ -7,13 +7,12 @@

 # pyre-strict

-import concurrent.futures
 import os
 import queue
 import threading
 import time
 import unittest
-from typing import Callable, cast, Dict
+from typing import Callable, cast
 from unittest.mock import patch

 import torch
@@ -23,7 +22,7 @@
     CPUOffloadedRecMetricModule,
     MetricUpdateJob,
 )
-from torchrec.metrics.metric_module import MetricValue, RecMetricModule
+from torchrec.metrics.metric_module import RecMetricModule
 from torchrec.metrics.rec_metric import RecMetricException, RecMetricList
 from torchrec.metrics.test_utils import gen_test_tasks
 from torchrec.metrics.test_utils.mock_metrics import (
@@ -75,6 +74,7 @@ def setUp(self) -> None:
         self.cpu_module: CPUOffloadedRecMetricModule = CPUOffloadedRecMetricModule(
             batch_size=self.batch_size,
             world_size=self.world_size,
+            device=torch.device("cpu"),
             rec_tasks=self.tasks,
             rec_metrics=self.rec_metrics,
             throughput_metric=ThroughputMetric(
@@ -85,7 +85,8 @@ def setUp(self) -> None:
         )

     def tearDown(self) -> None:
-        dist.destroy_process_group()
+        if dist.is_initialized():
+            dist.destroy_process_group()
         if hasattr(self, "cpu_module"):
             try:
                 self.cpu_module.shutdown()
@@ -155,6 +156,7 @@ def test_update_rec_metrics_queue_full(self) -> None:
         cpu_module = CPUOffloadedRecMetricModule(
             batch_size=self.batch_size,
             world_size=self.world_size,
+            device=torch.device("cuda"),
             rec_tasks=self.tasks,
             rec_metrics=self.rec_metrics,
             update_queue_size=1,  # Small queue size
@@ -191,7 +193,7 @@ def controlled_process_job(_: MetricUpdateJob) -> None:

     def test_sync_compute_raises_exception(self) -> None:
         self.assertRaisesRegex(
             RecMetricException,
-            "compute\\(\\) is not supported in CPUOffloadedRecMetricModule.",
+            "CPUOffloadedRecMetricModule does not support compute\\(\\). Use async_compute\\(\\) instead.",
             self.cpu_module.compute,
         )
@@ -207,10 +209,6 @@ def test_async_compute_synchronization_marker(self) -> None:
         Note that the comms module's metrics are actually the ones that are
         computed.
""" - future: concurrent.futures.Future[Dict[str, MetricValue]] = ( - concurrent.futures.Future() - ) - model_out = { "task1-prediction": torch.tensor([0.5]), "task1-label": torch.tensor([0.7]), @@ -220,7 +218,7 @@ def test_async_compute_synchronization_marker(self) -> None: for _ in range(10): self.cpu_module.update(model_out) - self.cpu_module.async_compute(future) + self.cpu_module.async_compute() comms_mock_metric = cast( MockRecMetric, self.cpu_module.comms_module.rec_metrics.rec_metrics[0] @@ -234,10 +232,7 @@ def test_async_compute_synchronization_marker(self) -> None: def test_async_compute_after_shutdown(self) -> None: self.cpu_module.shutdown() - future: concurrent.futures.Future[Dict[str, MetricValue]] = ( - concurrent.futures.Future() - ) - self.cpu_module.async_compute(future) + future = self.cpu_module.async_compute() self.assertRaisesRegex( RecMetricException, "metric processor thread is shut down.", future.result @@ -275,7 +270,7 @@ def test_wait_until_queue_is_empty(self) -> None: "task1-weight": torch.tensor([1.0]), } self.cpu_module.update(model_out) - self.cpu_module.async_compute(concurrent.futures.Future()) + self.cpu_module.async_compute() self.cpu_module.wait_until_queue_is_empty(self.cpu_module.update_queue) self.cpu_module.wait_until_queue_is_empty(self.cpu_module.compute_queue) @@ -283,6 +278,84 @@ def test_wait_until_queue_is_empty(self) -> None: self.assertTrue(self.cpu_module.update_queue.empty()) self.assertTrue(self.cpu_module.compute_queue.empty()) + def test_update_thread_exception_captured(self) -> None: + """ + Test that exceptions in update thread are: + 1. Captured in _captured_exception + 2. Cause the update thread to terminate + 3. Main thread raises it on the next update() call + """ + test_exception = RuntimeError("Test exception from update thread") + + with patch.object( + self.cpu_module, + "_process_metric_update_job", + side_effect=test_exception, + ): + model_out = { + "task1-prediction": torch.tensor([0.5]), + "task1-label": torch.tensor([0.7]), + "task1-weight": torch.tensor([1.0]), + } + + self.cpu_module.update(model_out) + + # Wait for exception to be captured + captured = self.cpu_module._captured_exception_event.wait(timeout=5.0) + + self.assertTrue(captured, "Exception event should be set") + self.assertIsNotNone(self.cpu_module._captured_exception) + self.assertIsInstance(self.cpu_module._captured_exception, RuntimeError) + self.assertEqual( + str(self.cpu_module._captured_exception), + "Test exception from update thread", + ) + + self.cpu_module.update_thread.join(timeout=5.0) + self.assertFalse( + self.cpu_module.update_thread.is_alive(), + "Update thread should have terminated after exception", + ) + + with self.assertRaises(RuntimeError): + self.cpu_module.update(model_out) + + def test_compute_thread_exception_captured(self) -> None: + """ + Test that exceptions in compute thread are: + 1. Captured in _captured_exception + 2. Cause the compute thread to terminate + 3. 
Main thread raises it on the next compute() call + """ + test_exception = RuntimeError("Test exception from compute thread") + + with patch.object( + self.cpu_module, + "_process_metric_compute_job", + side_effect=test_exception, + ): + self.cpu_module.async_compute() + + # Wait for exception to be captured + captured = self.cpu_module._captured_exception_event.wait(timeout=5.0) + + self.assertTrue(captured, "Exception event should be set") + self.assertIsNotNone(self.cpu_module._captured_exception) + self.assertIsInstance(self.cpu_module._captured_exception, RuntimeError) + self.assertEqual( + str(self.cpu_module._captured_exception), + "Test exception from compute thread", + ) + + self.cpu_module.compute_thread.join(timeout=5.0) + self.assertFalse( + self.cpu_module.compute_thread.is_alive(), + "compute thread should have terminated after exception", + ) + + with self.assertRaises(RuntimeError): + self.cpu_module.async_compute() + # pyre-ignore[56] @unittest.skipIf( torch.cuda.device_count() < 1, @@ -313,6 +386,7 @@ def test_state_dict_save_load(self) -> None: offloaded_module = CPUOffloadedRecMetricModule( batch_size=self.batch_size, world_size=self.world_size, + device=torch.device("cuda"), rec_tasks=self.tasks, rec_metrics=RecMetricList([offloaded_metric]), ) @@ -379,6 +453,7 @@ def test_sync(self) -> None: offloaded_module = CPUOffloadedRecMetricModule( batch_size=self.batch_size, world_size=self.world_size, + device=torch.device("cuda"), rec_tasks=self.tasks, rec_metrics=RecMetricList([offloaded_metric]), ) @@ -533,10 +608,10 @@ def _compare_metric_results_worker( rec_metrics=RecMetricList([standard_metric]), ).to(device) - # Create CPUOffloadedRecMetricModule (automatically stays on CPU) cpu_offloaded_module = CPUOffloadedRecMetricModule( batch_size=batch_size, world_size=world_size, + device=torch.device("cuda"), rec_tasks=tasks, rec_metrics=RecMetricList([offloaded_metric]), ).to(device) @@ -576,10 +651,7 @@ def _compare_metric_results_worker( standard_results = standard_module.compute() - future: concurrent.futures.Future[Dict[str, MetricValue]] = ( - concurrent.futures.Future() - ) - cpu_offloaded_module.async_compute(future) + future = cpu_offloaded_module.async_compute() # Wait for async compute to finish. 
Compare the input to each update() offloaded_results = future.result(timeout=10.0) diff --git a/torchrec/metrics/tests/test_metric_module.py b/torchrec/metrics/tests/test_metric_module.py index 74e549aa7..31ac3d764 100644 --- a/torchrec/metrics/tests/test_metric_module.py +++ b/torchrec/metrics/tests/test_metric_module.py @@ -7,7 +7,6 @@ # pyre-strict -import concurrent import copy import dataclasses import logging @@ -24,10 +23,10 @@ MultiProcessContext, MultiProcessTestBase, ) -from torchrec.metrics.auc import AUCMetric +from torchrec.metrics.auc import _state_reduction, AUCMetric from torchrec.metrics.metric_module import ( generate_metric_module, - MetricValue, + MetricsResult, RecMetricModule, StateMetric, StateMetricEnum, @@ -44,7 +43,8 @@ ) from torchrec.metrics.model_utils import parse_task_model_outputs from torchrec.metrics.rec_metric import RecMetricException, RecMetricList, RecTaskInfo -from torchrec.metrics.test_utils import gen_test_batch +from torchrec.metrics.test_utils import gen_test_batch, gen_test_tasks +from torchrec.metrics.test_utils.mock_metrics import MockRecMetric from torchrec.metrics.throughput import ThroughputMetric from torchrec.test_utils import get_free_port, seed_and_log, skip_if_asan_class @@ -55,7 +55,7 @@ class MockOptimizer(StateMetric): def __init__(self) -> None: self.get_metrics_call = 0 - def get_metrics(self) -> Dict[str, MetricValue]: + def get_metrics(self) -> MetricsResult: self.get_metrics_call += 1 return {"learning_rate": torch.tensor(1.0)} @@ -662,7 +662,7 @@ def test_async_compute_raises_exception(self) -> None: RecMetricException, "async_compute is not supported in RecMetricModule", ): - metric_module.async_compute(concurrent.futures.Future()) + metric_module.async_compute() def test_load_state_dict_with_trained_batches_key(self) -> None: metric_module = generate_metric_module( @@ -710,6 +710,53 @@ def test_load_state_dict_without_trained_batches_key(self) -> None: self.assertIsInstance(result, dict) self.assertTrue(len(result) > 0) + def test_invalid_reduction_function(self) -> None: + """Test that invalid reduction functions raise RecMetricException.""" + + def invalid_reduction_fn(tensors: List[torch.Tensor]) -> List[torch.Tensor]: + """Custom reduction that's not in whitelist.""" + return [torch.cat(tensors, dim=0)] + + tasks = gen_test_tasks(["task1"]) + mock_metric = MockRecMetric( + world_size=1, + my_rank=0, + batch_size=10, + tasks=tasks, + is_tensor_list=True, + reduction_fn=invalid_reduction_fn, + initial_states={"predictions": []}, + ) + + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[1.0, 2.0]])} + ) + + with self.assertRaisesRegex(RecMetricException, "Unknown custom reduction"): + RecMetricModule( + batch_size=10, + world_size=1, + rec_tasks=tasks, + rec_metrics=RecMetricList([mock_metric]), + ) + + def test_all_gather_tensor_list_validates_single_tensor(self) -> None: + """Test that _all_gather_tensor_list raises ValueError when given >1 tensor after local reduction.""" + from torchrec.metrics.metric_module import _all_gather_tensor_list + + # Create multiple tensors (simulates bug where local reduction didn't happen) + tensors = [ + torch.tensor([[1.0, 2.0]]), + torch.tensor([[3.0, 4.0]]), + ] + + with self.assertRaisesRegex(ValueError, "expected at most 1 tensor"): + _all_gather_tensor_list( + tensors=tensors, + world_size=1, + pg=dist.group.WORLD, + ) + def metric_module_gather_state( rank: int, @@ -815,7 +862,7 @@ def test_post_init_raises_when_rec_tasks_is_none(self) -> None: # Execute & 
Assert: should raise ValueError about rec_tasks being None with self.assertRaises(ValueError) as context: - config = MetricsConfig( + _ = MetricsConfig( rec_tasks=None, # pyre-ignore[6]: Intentionally passing None for testing rec_metrics={ RecMetricEnum.AUC: RecMetricDef(rec_task_indices=[0]), @@ -835,7 +882,7 @@ def test_post_init_raises_when_rec_task_index_out_of_range(self) -> None: # Execute & Assert: should raise ValueError about index out of range with self.assertRaises(ValueError) as context: - config = MetricsConfig( + _ = MetricsConfig( rec_tasks=rec_tasks, rec_metrics={ RecMetricEnum.NE: RecMetricDef( @@ -889,3 +936,230 @@ def test_metric_module_gather_state(self) -> None: batch_size=batch_size, config=metrics_config, ) + + def test_get_metric_states_list_reduction(self) -> None: + """ + Test _get_metric_states with list states and concatenation reduction. + Validates the double-reduction optimization (local + global) for AUC-like metrics. + """ + world_size = 2 + backend = "nccl" + + self._run_multi_process_test( + callable=_test_get_metric_states_with_list_reduction, + world_size=world_size, + backend=backend, + ) + + def test_get_metric_states_tensor_reduction(self) -> None: + """ + Test _get_metric_states with tensor states and sum reduction. + Validates standard reduction for NE-like metrics. + """ + world_size = 2 + backend = "nccl" + + self._run_multi_process_test( + callable=_test_get_metric_states_with_tensor_reduction, + world_size=world_size, + backend=backend, + ) + + def test_get_metric_states_single_tensor(self) -> None: + """ + Test _get_metric_states with a single tensor in the list. + Edge case validation. + """ + world_size = 2 + backend = "nccl" + + self._run_multi_process_test( + callable=_test_get_metric_states_with_single_tensor, + world_size=world_size, + backend=backend, + ) + + +def _test_get_metric_states_with_list_reduction( + rank: int, + world_size: int, + backend: str, +) -> None: + """Test _get_metric_states with list states and concatenation reduction (AUC-like).""" + with MultiProcessContext(rank, world_size, backend) as ctx: + # Create mock metric with list state using concatenation reduction + tasks = gen_test_tasks(["task1"]) + mock_metric = MockRecMetric( + world_size=world_size, + my_rank=rank, + batch_size=10, + tasks=tasks, + is_tensor_list=True, + reduction_fn=_state_reduction, + initial_states={"predictions": []}, + ) + + # Each rank appends different local tensors to simulate batch updates + # Rank 0: [[1, 2], [3, 4]] -> after local concat: [[1, 2, 3, 4]] + # Rank 1: [[5, 6], [7, 8]] -> after local concat: [[5, 6, 7, 8]] + # After global gather+concat: [[1, 2, 3, 4, 5, 6, 7, 8]] + if rank == 0: + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[1.0, 2.0]], device=ctx.device)} + ) + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[3.0, 4.0]], device=ctx.device)} + ) + else: # rank == 1 + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[5.0, 6.0]], device=ctx.device)} + ) + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[7.0, 8.0]], device=ctx.device)} + ) + + # Execute: Call _get_metric_states + metric_module = RecMetricModule( + batch_size=10, + world_size=world_size, + rec_tasks=tasks, + rec_metrics=RecMetricList([mock_metric]), + ) + + result = metric_module._get_metric_states( + metric=mock_metric, + world_size=world_size, + process_group=ctx.pg or dist.group.WORLD, + ) + + # Assert: Verify result matches expected + # Expected: 
After local reduction + gather + global reduction + # Order: [rank0_local_concat, rank1_local_concat] concatenated + # = [[1, 2, 3, 4], [5, 6, 7, 8]] concatenated = [[1, 2, 3, 4, 5, 6, 7, 8]] + expected = [ + torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]], device=ctx.device) + ] + + actual = result["task1"]["predictions"] + + assert len(actual) == len( + expected + ), f"Expected {len(expected)} tensors, got {len(actual)}" + torch.testing.assert_close( + actual[0], + expected[0], + msg="Mismatch in gathered predictions", + ) + + +def _test_get_metric_states_with_tensor_reduction( + rank: int, + world_size: int, + backend: str, +) -> None: + """Test _get_metric_states with tensor states and sum reduction (NE-like).""" + with MultiProcessContext(rank, world_size, backend) as ctx: + # Create mock metric with tensor state using sum reduction + tasks = gen_test_tasks(["task1"]) + initial_value = torch.tensor( + [float(rank + 1)], device=ctx.device + ) # Rank 0: [1.0], Rank 1: [2.0] + mock_metric = MockRecMetric( + world_size=world_size, + my_rank=rank, + batch_size=10, + tasks=tasks, + is_tensor_list=False, + reduction_fn="sum", + initial_states={"state1": initial_value}, + ) + + # Execute: Call _get_metric_states + metric_module = RecMetricModule( + batch_size=10, + world_size=world_size, + rec_tasks=tasks, + rec_metrics=RecMetricList([mock_metric]), + ) + + result = metric_module._get_metric_states( + metric=mock_metric, + world_size=world_size, + process_group=ctx.pg or dist.group.WORLD, + ) + + # Assert: Verify result matches expected + # Expected: sum([rank0_value, rank1_value]) = sum([1.0, 2.0]) = 3.0 + expected = torch.tensor([3.0], device=ctx.device) + + actual = result["task1"]["state1"] + + torch.testing.assert_close( + actual, + expected, + msg="Mismatch in summed state", + ) + + +def _test_get_metric_states_with_single_tensor( + rank: int, + world_size: int, + backend: str, +) -> None: + """Test _get_metric_states with a single tensor in the list (edge case).""" + with MultiProcessContext(rank, world_size, backend) as ctx: + + # Create mock metric with list state containing a single tensor + tasks = gen_test_tasks(["task1"]) + mock_metric = MockRecMetric( + world_size=world_size, + my_rank=rank, + batch_size=10, + tasks=tasks, + is_tensor_list=True, + reduction_fn=_state_reduction, + initial_states={"predictions": []}, + ) + + # Each rank has a single tensor + # Rank 0: [[1, 2]] + # Rank 1: [[3, 4]] + # After local reduction (no-op since single tensor): [[1, 2]] and [[3, 4]] + # After global gather+concat: [[1, 2, 3, 4]] + if rank == 0: + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[1.0, 2.0]], device=ctx.device)} + ) + else: # rank == 1 + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[3.0, 4.0]], device=ctx.device)} + ) + + # Execute: Call _get_metric_states + metric_module = RecMetricModule( + batch_size=10, + world_size=world_size, + rec_tasks=tasks, + rec_metrics=RecMetricList([mock_metric]), + ) + + result = metric_module._get_metric_states( + metric=mock_metric, + world_size=world_size, + process_group=ctx.pg or dist.group.WORLD, + ) + + # Assert: Verify result matches expected + # Expected: [[1, 2, 3, 4]] + expected = [torch.tensor([[1.0, 2.0, 3.0, 4.0]], device=ctx.device)] + + actual = result["task1"]["predictions"] + + assert len(actual) == len( + expected + ), f"Expected {len(expected)} tensors, got {len(actual)}" + torch.testing.assert_close( + actual[0], + expected[0], + msg="Mismatch in gathered 
predictions for single tensor case", + ) diff --git a/torchrec/metrics/tests/test_metrics_output_util.py b/torchrec/metrics/tests/test_metrics_output_util.py new file mode 100644 index 000000000..9989ffb76 --- /dev/null +++ b/torchrec/metrics/tests/test_metrics_output_util.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +Tests for metrics_output_util - utilities for handling MetricsOutput (dict or Future). + +These tests focus on our utility functions' behavior, not on testing Python's +concurrent.futures module which is already well-tested. +""" + +import unittest +from concurrent.futures import Future +from typing import Callable + +import torch +from torchrec.metrics.metric_module import MetricsFuture, MetricsResult +from torchrec.metrics.metrics_output_util import get_metrics_async, get_metrics_sync + + +class OnMetricsReadyTest(unittest.TestCase): + """Tests that get_metrics_async correctly dispatches between sync/async paths.""" + + def setUp(self) -> None: + self.metrics: MetricsResult = {"loss": torch.tensor(0.5)} + self.received_metrics: MetricsResult | None = None + self.received_error: Exception | None = None + + def _callback(self, metrics: MetricsResult) -> None: + self.received_metrics = metrics + + def _error_handler(self, error: Exception) -> None: + self.received_error = error + + def test_synchronous_dict_path(self) -> None: + result = get_metrics_async(self.metrics, self._callback) + + self.assertIsNone(result) + self.assertIs(self.received_metrics, self.metrics) + + def test_synchronous_dict_returns_callback_value(self) -> None: + result = get_metrics_async(self.metrics, lambda metrics: "success") + self.assertEqual(result, "success") + + def test_asynchronous_future_path(self) -> None: + future: MetricsFuture = Future() + + get_metrics_async(future, self._callback) + self.assertIsNone(self.received_metrics) + + future.set_result(self.metrics) + self.assertEqual(self.received_metrics, self.metrics) + + def test_error_handler_receives_callback_exceptions(self) -> None: + """Error handler receives exceptions raised by callbacks.""" + future: MetricsFuture = Future() + + def failing_callback(metrics: MetricsResult) -> None: + raise ValueError("callback failed") + + get_metrics_async(future, failing_callback, on_error=self._error_handler) + future.set_result(self.metrics) + + self.assertIsInstance(self.received_error, ValueError) + + def test_error_handler_receives_future_exceptions(self) -> None: + """Error handler receives exceptions from Future resolution.""" + future: MetricsFuture = Future() + + get_metrics_async(future, self._callback, on_error=self._error_handler) + future.set_exception(RuntimeError("computation failed")) + + self.assertIsInstance(self.received_error, RuntimeError) + + +class GetMetricsSyncTest(unittest.TestCase): + """Tests that get_metrics_sync correctly handles both dict and Future inputs.""" + + def setUp(self) -> None: + self.metrics: MetricsResult = {"loss": torch.tensor(0.5)} + + def test_dict_returns_immediately(self) -> None: + result = get_metrics_sync(self.metrics) + self.assertEqual(result, self.metrics) + + def test_future_blocks_until_resolved(self) -> None: + future: MetricsFuture = Future() + future.set_result(self.metrics) + + result = get_metrics_sync(future) + self.assertEqual(result, self.metrics) + + 
def test_future_exception_propagates(self) -> None: + """Exceptions from Future are propagated to caller.""" + future: MetricsFuture = Future() + future.set_exception(RuntimeError("failed")) + + with self.assertRaises(RuntimeError): + get_metrics_sync(future) + + def test_timeout_raises_timeout_error(self) -> None: + """Timeout on unresolved Future raises TimeoutError.""" + future: MetricsFuture = Future() + + with self.assertRaises(TimeoutError): + get_metrics_sync(future, timeout=0.001) + + +class MultipleCallbacksTest(unittest.TestCase): + """Tests that multiple callbacks can be attached and all execute correctly.""" + + def setUp(self) -> None: + self.sample_data: MetricsResult = { + "metric_a": torch.tensor(1.0), + "metric_b": torch.tensor(2.0), + } + self.callback_executions: dict[str, bool] = { + "callback_1": False, + "callback_2": False, + "callback_3": False, + } + self.extracted_value: float | None = None + + def _make_tracking_callback(self, name: str) -> Callable[[MetricsResult], None]: + """Factory for callbacks that track execution.""" + + def callback(metrics: MetricsResult) -> None: + self.callback_executions[name] = True + + return callback + + def _value_extraction_callback(self, metrics: MetricsResult) -> None: + """Callback that extracts a value from metrics.""" + metric = metrics.get("metric_a") + if isinstance(metric, torch.Tensor): + self.extracted_value = metric.item() + + def test_multiple_callbacks_on_future(self) -> None: + """Multiple callbacks attached to same Future all execute when resolved.""" + future: Future[MetricsResult] = Future() + + # Attach multiple callbacks to the same Future + get_metrics_async(future, self._make_tracking_callback("callback_1")) + get_metrics_async(future, self._make_tracking_callback("callback_2")) + get_metrics_async(future, self._make_tracking_callback("callback_3")) + get_metrics_async(future, self._value_extraction_callback) + + # Verify no callbacks executed yet + self.assertEqual( + self.callback_executions, + {"callback_1": False, "callback_2": False, "callback_3": False}, + ) + self.assertIsNone(self.extracted_value) + + # Resolve the Future + future.set_result(self.sample_data) + + # Verify all callbacks executed + self.assertEqual( + self.callback_executions, + {"callback_1": True, "callback_2": True, "callback_3": True}, + ) + self.assertIsNotNone(self.extracted_value) + self.assertAlmostEqual(self.extracted_value, 1.0, places=5) + + def test_multiple_callbacks_on_dict(self) -> None: + """Multiple callbacks with dict input all execute immediately.""" + get_metrics_async(self.sample_data, self._make_tracking_callback("callback_1")) + get_metrics_async(self.sample_data, self._make_tracking_callback("callback_2")) + get_metrics_async(self.sample_data, self._value_extraction_callback) + + # Verify all callbacks executed immediately + self.assertEqual( + self.callback_executions, + {"callback_1": True, "callback_2": True, "callback_3": False}, + ) + self.assertIsNotNone(self.extracted_value) + self.assertAlmostEqual(self.extracted_value, 1.0, places=5)
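Usage sketch (not part of the patch): how a caller might consume the MetricsOutput union introduced above. async_compute(), get_metrics_async() and get_metrics_sync() are the APIs added in this diff; publish_metrics and the surrounding trainer wiring are hypothetical stand-ins.

import logging

from torchrec.metrics.metric_module import MetricsOutput, MetricsResult
from torchrec.metrics.metrics_output_util import get_metrics_async, get_metrics_sync

logger = logging.getLogger(__name__)


def publish_metrics(metrics: MetricsResult) -> None:
    # Hypothetical sink; stands in for whatever publisher the trainer actually uses.
    for name, value in metrics.items():
        logger.info("metric %s = %s", name, value)


def publish_without_blocking(metrics_output: MetricsOutput) -> None:
    # A Future returned by CPUOffloadedRecMetricModule.async_compute() invokes the
    # callback on the compute thread once it resolves; a plain dict returned by
    # RecMetricModule.compute() invokes the callback immediately.
    get_metrics_async(metrics_output, publish_metrics, on_error=logger.exception)


def publish_blocking(metrics_output: MetricsOutput) -> None:
    # Resolve to a concrete dict, waiting up to 60 seconds if metrics are still pending.
    publish_metrics(get_metrics_sync(metrics_output, timeout=60.0))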
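The reduction-function whitelist rests on the associativity requirement quoted in _validate_reduction_function's docstring. The following standalone check (plain torch, made-up shard values) shows why a sum-style reduction is safe under the local-then-global pattern and why a mean of local means is exactly the kind of callable the validation is meant to reject.

import torch

# Per-rank shards of a metric state; values are made up for illustration.
rank0 = torch.tensor([1.0, 2.0, 3.0])
rank1 = torch.tensor([10.0])

# Sum is associative: reducing each rank locally and then reducing across ranks
# matches a single global reduction, so the local+global pattern is safe.
local_then_global = torch.stack([rank0.sum(), rank1.sum()]).sum()
direct = torch.cat([rank0, rank1]).sum()
assert torch.equal(local_then_global, direct)  # both are 16.0

# A mean of local means is NOT associative when shard sizes differ, so a callable
# like this must not be registered as a state reduction.
mean_of_means = torch.stack([rank0.mean(), rank1.mean()]).mean()  # (2.0 + 10.0) / 2 = 6.0
direct_mean = torch.cat([rank0, rank1]).mean()  # 16.0 / 4 = 4.0
assert not torch.equal(mean_of_means, direct_mean)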
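For the list-state path, the rewritten _get_metric_states reduces locally before the all-gather, which is also why _all_gather_tensor_list now rejects inputs with more than one tensor. Below is a single-process sketch of the equivalence the new multi-process tests assert, using a stand-in concatenation reduction (the real _state_reduction from torchrec.metrics.auc may differ in detail).

from typing import List

import torch


def concat_reduction(tensors: List[torch.Tensor]) -> List[torch.Tensor]:
    # Stand-in for an AUC-style concatenation reduction (concatenate along the last dim).
    return [torch.cat(tensors, dim=-1)]


# Rank-local update lists for a list-typed state; values mirror the new test.
rank0_updates = [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0]])]
rank1_updates = [torch.tensor([[5.0, 6.0]]), torch.tensor([[7.0, 8.0]])]

# Step 1: local reduction, so each rank contributes exactly one tensor to the all-gather.
rank0_local = concat_reduction(rank0_updates)  # [[1, 2, 3, 4]]
rank1_local = concat_reduction(rank1_updates)  # [[5, 6, 7, 8]]

# Step 2: the all-gather is simulated by listing both ranks' reduced tensors in rank order.
gathered = rank0_local + rank1_local

# Step 3: global reduction gives the same result as concatenating every update directly,
# which is what the pre-change code computed by gathering the full lists.
global_reduced = concat_reduction(gathered)
direct = concat_reduction(rank0_updates + rank1_updates)
assert torch.equal(global_reduced[0], direct[0])  # [[1, 2, 3, 4, 5, 6, 7, 8]]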