From 406c7b1c5b321969a7b6062c5daefafe73792639 Mon Sep 17 00:00:00 2001
From: clemsgrs
Date: Fri, 6 Feb 2026 14:25:21 +0000
Subject: [PATCH] Distribute PathoROB feature extraction across GPUs

PathoROB was running feature extraction on a single GPU while N-1 GPUs sat
idle. Switch to the same distributed pattern as StandardProbePlugin:
DistributedSampler + extract_multiple_features with all_gather. Metric
computation (RI, APD, clustering) stays main-process-only.

Co-Authored-By: Claude Opus 4.6
---
 dino/eval/plugins/pathorob.py | 81 +++++++++++++++++++++--------------
 1 file changed, 49 insertions(+), 32 deletions(-)

diff --git a/dino/eval/plugins/pathorob.py b/dino/eval/plugins/pathorob.py
index b4c8f29..00980d2 100644
--- a/dino/eval/plugins/pathorob.py
+++ b/dino/eval/plugins/pathorob.py
@@ -9,13 +9,14 @@
 import numpy as np
 import pandas as pd
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from omegaconf import DictConfig
 from torchvision import transforms
 
-from dino.distributed import is_main_process
+from dino.distributed import is_main_process, is_enabled_and_multiple_gpus
 from dino.eval.dataset import EvalDataset
-from dino.eval.features import extract_features_single_process
+from dino.eval.features import extract_multiple_features
 from dino.eval.pathorob.apd import compute_apd
 from dino.eval.pathorob.clustering import compute_clustering_score
 from dino.eval.pathorob.datasets import load_manifest
@@ -36,6 +37,7 @@ def __init__(self, cfg: DictConfig, device: torch.device, output_dir: Path):
         self.device = device
         self.name = "pathorob"
         self.output_dir = Path(output_dir)
+        self.distributed = is_enabled_and_multiple_gpus()
 
         self.base_dir = self.output_dir / "pathorob"
         self.splits_dir = self.base_dir / "splits"
@@ -93,22 +95,31 @@ def _extract_features(
             image_col="image_path",
             label_col="label",
         )
+
+        if self.distributed:
+            sampler = torch.utils.data.DistributedSampler(dataset, shuffle=False)
+        else:
+            sampler = torch.utils.data.SequentialSampler(dataset)
+
         loader = torch.utils.data.DataLoader(
             dataset,
-            sampler=torch.utils.data.SequentialSampler(dataset),
+            sampler=sampler,
             batch_size=int(self.cfg.batch_size_per_gpu),
             num_workers=int(self.cfg.num_workers),
             pin_memory=True,
             drop_last=False,
         )
 
-        feats, _ = extract_features_single_process(
+        feats, _ = extract_multiple_features(
            student_backbone,
             teacher_backbone,
             loader,
             self.device,
         )
 
+        if not is_main_process():
+            return {}
+
         out: Dict[str, np.ndarray] = {}
         for model_name in ["student", "teacher"]:
             tensor = feats[model_name]
@@ -234,6 +245,10 @@ def _run_dataset(
             teacher_backbone,
             manifest_df,
         )
+
+        if not is_main_process():
+            return [], {}
+
         sample_to_idx = {sid: i for i, sid in enumerate(manifest_df["sample_id"].tolist())}
 
         metric_rows: List[Dict[str, Any]] = []
@@ -606,9 +621,6 @@ def _build_workspace(self, wandb, ws, wr, metric_keys: List[str]) -> None:
 
     @torch.no_grad()
     def run(self, student: nn.Module, teacher: nn.Module, epoch: int) -> PluginResult:
-        if not is_main_process():
-            return PluginResult(name=self.name)
-
         student_backbone = self._get_backbone(student)
         teacher_backbone = self._get_backbone(teacher)
         student_backbone.eval()
@@ -617,12 +629,13 @@ def run(self, student: nn.Module, teacher: nn.Module, epoch: int) -> PluginResul
         all_rows: List[Dict[str, Any]] = []
         all_logs: Dict[str, float] = {}
 
-        # add an assert to make sure that at least one dataset is enabled
-        any_enabled = any(
-            bool(dataset_cfg.enable) for dataset_cfg in self.cfg.datasets.values()
-        )
-        if not any_enabled:
-            logging.warning("[PathoROB] No datasets are enabled for evaluation.")
+        if is_main_process():
+            any_enabled = any(
+                bool(dataset_cfg.enable) for dataset_cfg in self.cfg.datasets.values()
+            )
+            if not any_enabled:
+                logging.warning("[PathoROB] No datasets are enabled for evaluation.")
+
         for dataset_name, dataset_cfg in self.cfg.datasets.items():
             if not bool(dataset_cfg.enable):
                 continue
@@ -637,27 +650,31 @@ def run(self, student: nn.Module, teacher: nn.Module, epoch: int) -> PluginResul
                 all_rows.extend(rows)
                 all_logs.update(logs)
             except Exception as exc:
-                err_file = self.metrics_dir / f"epoch_{epoch+1:04d}_{dataset_name}_error.txt"
-                err_file.write_text(traceback.format_exc())
+                if is_main_process():
+                    err_file = self.metrics_dir / f"epoch_{epoch+1:04d}_{dataset_name}_error.txt"
+                    err_file.write_text(traceback.format_exc())
                 logger.error(f"[PathoROB] {dataset_name} failed at epoch {epoch+1}: {exc}")
 
-        # Create wandb workspace with native LinePlot panels (once)
-        if all_logs:
-            self._setup_wandb_workspace(list(all_logs.keys()))
-
-        if all_rows:
-            df = pd.DataFrame(all_rows)
-            out_csv = self.metrics_dir / f"epoch_{epoch+1:04d}.csv"
-            out_json = self.metrics_dir / f"epoch_{epoch+1:04d}.json"
-            df.to_csv(out_csv, index=False)
-            out_json.write_text(json.dumps(all_rows, indent=2))
-
-            roll_path = self.metrics_dir / "all_metrics.csv"
-            if roll_path.exists():
-                old = pd.read_csv(roll_path)
-                pd.concat([old, df], axis=0).reset_index(drop=True).to_csv(roll_path, index=False)
-            else:
-                df.to_csv(roll_path, index=False)
+        if self.distributed:
+            dist.barrier()
+
+        if is_main_process():
+            if all_logs:
+                self._setup_wandb_workspace(list(all_logs.keys()))
+
+            if all_rows:
+                df = pd.DataFrame(all_rows)
+                out_csv = self.metrics_dir / f"epoch_{epoch+1:04d}.csv"
+                out_json = self.metrics_dir / f"epoch_{epoch+1:04d}.json"
+                df.to_csv(out_csv, index=False)
+                out_json.write_text(json.dumps(all_rows, indent=2))
+
+                roll_path = self.metrics_dir / "all_metrics.csv"
+                if roll_path.exists():
+                    old = pd.read_csv(roll_path)
+                    pd.concat([old, df], axis=0).reset_index(drop=True).to_csv(roll_path, index=False)
+                else:
+                    df.to_csv(roll_path, index=False)
 
         return PluginResult(
             name=self.name,
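
Note on the gather pattern (illustration only, not part of the patch): the diff
delegates the cross-GPU work to extract_multiple_features, whose body is not
shown here. The sketch below is a minimal, hypothetical rendering of the
DistributedSampler + all_gather pattern the commit message describes; the
function name gather_features_sketch and its signature are invented for
illustration and do not exist in dino.eval.features.

import torch
import torch.distributed as dist


@torch.no_grad()
def gather_features_sketch(model, loader, device):
    # Each rank embeds only its shard of the dataset; the loader is assumed to
    # use DistributedSampler(shuffle=False), as in the patched _extract_features.
    model.eval()
    local_chunks = [model(images.to(device, non_blocking=True)) for images, _ in loader]
    local = torch.cat(local_chunks, dim=0)

    if not (dist.is_available() and dist.is_initialized()):
        return local  # single-process fallback

    # all_gather needs same-shaped tensors on every rank; DistributedSampler
    # (with its default drop_last=False) pads the index list so all ranks see
    # the same number of samples, which makes the shapes match at the cost of
    # a few duplicated samples.
    gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)
    # Concatenation is in per-rank shard order, not original dataset order;
    # implementations that need dataset order usually gather sample indices too.
    return torch.cat(gathered, dim=0)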