81 changes: 49 additions & 32 deletions dino/eval/plugins/pathorob.py
@@ -9,13 +9,14 @@
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.nn as nn
from omegaconf import DictConfig
from torchvision import transforms

from dino.distributed import is_main_process
from dino.distributed import is_main_process, is_enabled_and_multiple_gpus
from dino.eval.dataset import EvalDataset
from dino.eval.features import extract_features_single_process
from dino.eval.features import extract_multiple_features
from dino.eval.pathorob.apd import compute_apd
from dino.eval.pathorob.clustering import compute_clustering_score
from dino.eval.pathorob.datasets import load_manifest
@@ -36,6 +37,7 @@ def __init__(self, cfg: DictConfig, device: torch.device, output_dir: Path):
self.device = device
self.name = "pathorob"
self.output_dir = Path(output_dir)
self.distributed = is_enabled_and_multiple_gpus()

self.base_dir = self.output_dir / "pathorob"
self.splits_dir = self.base_dir / "splits"
@@ -93,22 +95,31 @@ def _extract_features(
image_col="image_path",
label_col="label",
)

if self.distributed:
sampler = torch.utils.data.DistributedSampler(dataset, shuffle=False)
else:
sampler = torch.utils.data.SequentialSampler(dataset)

loader = torch.utils.data.DataLoader(
dataset,
sampler=torch.utils.data.SequentialSampler(dataset),
sampler=sampler,
batch_size=int(self.cfg.batch_size_per_gpu),
num_workers=int(self.cfg.num_workers),
pin_memory=True,
drop_last=False,
)

feats, _ = extract_features_single_process(
feats, _ = extract_multiple_features(
student_backbone,
teacher_backbone,
loader,
self.device,
)

if not is_main_process():
return {}

out: Dict[str, np.ndarray] = {}
for model_name in ["student", "teacher"]:
tensor = feats[model_name]
@@ -234,6 +245,10 @@ def _run_dataset(
teacher_backbone,
manifest_df,
)

if not is_main_process():
return [], {}

sample_to_idx = {sid: i for i, sid in enumerate(manifest_df["sample_id"].tolist())}

metric_rows: List[Dict[str, Any]] = []
@@ -606,9 +621,6 @@ def _build_workspace(self, wandb, ws, wr, metric_keys: List[str]) -> None:

@torch.no_grad()
def run(self, student: nn.Module, teacher: nn.Module, epoch: int) -> PluginResult:
if not is_main_process():
return PluginResult(name=self.name)

student_backbone = self._get_backbone(student)
teacher_backbone = self._get_backbone(teacher)
student_backbone.eval()
@@ -617,12 +629,13 @@ def run(self, student: nn.Module, teacher: nn.Module, epoch: int) -> PluginResult:
all_rows: List[Dict[str, Any]] = []
all_logs: Dict[str, float] = {}

# add an assert to make sure that at least one dataset is enabled
any_enabled = any(
bool(dataset_cfg.enable) for dataset_cfg in self.cfg.datasets.values()
)
if not any_enabled:
logging.warning("[PathoROB] No datasets are enabled for evaluation.")
if is_main_process():
any_enabled = any(
bool(dataset_cfg.enable) for dataset_cfg in self.cfg.datasets.values()
)
if not any_enabled:
logging.warning("[PathoROB] No datasets are enabled for evaluation.")

for dataset_name, dataset_cfg in self.cfg.datasets.items():
if not bool(dataset_cfg.enable):
continue
@@ -637,27 +650,31 @@
all_rows.extend(rows)
all_logs.update(logs)
except Exception as exc:
err_file = self.metrics_dir / f"epoch_{epoch+1:04d}_{dataset_name}_error.txt"
err_file.write_text(traceback.format_exc())
if is_main_process():
err_file = self.metrics_dir / f"epoch_{epoch+1:04d}_{dataset_name}_error.txt"
err_file.write_text(traceback.format_exc())
logger.error(f"[PathoROB] {dataset_name} failed at epoch {epoch+1}: {exc}")

# Create wandb workspace with native LinePlot panels (once)
if all_logs:
self._setup_wandb_workspace(list(all_logs.keys()))

if all_rows:
df = pd.DataFrame(all_rows)
out_csv = self.metrics_dir / f"epoch_{epoch+1:04d}.csv"
out_json = self.metrics_dir / f"epoch_{epoch+1:04d}.json"
df.to_csv(out_csv, index=False)
out_json.write_text(json.dumps(all_rows, indent=2))

roll_path = self.metrics_dir / "all_metrics.csv"
if roll_path.exists():
old = pd.read_csv(roll_path)
pd.concat([old, df], axis=0).reset_index(drop=True).to_csv(roll_path, index=False)
else:
df.to_csv(roll_path, index=False)
if self.distributed:
dist.barrier()

if is_main_process():
if all_logs:
self._setup_wandb_workspace(list(all_logs.keys()))

if all_rows:
df = pd.DataFrame(all_rows)
out_csv = self.metrics_dir / f"epoch_{epoch+1:04d}.csv"
out_json = self.metrics_dir / f"epoch_{epoch+1:04d}.json"
df.to_csv(out_csv, index=False)
out_json.write_text(json.dumps(all_rows, indent=2))

roll_path = self.metrics_dir / "all_metrics.csv"
if roll_path.exists():
old = pd.read_csv(roll_path)
pd.concat([old, df], axis=0).reset_index(drop=True).to_csv(roll_path, index=False)
else:
df.to_csv(roll_path, index=False)

return PluginResult(
name=self.name,
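For context on how the new code paths fit together, below is a minimal, self-contained sketch of the pattern this change adopts: shard the evaluation dataset across ranks with a `DistributedSampler`, gather the per-rank feature shards with `all_gather`, let only the main process compute and write metrics, and synchronize with `barrier()` before moving on. The helper names (`is_distributed`, `extract_shard_features`, `run_eval`) and the metric step are illustrative placeholders, not the repo's actual `extract_multiple_features` / `is_enabled_and_multiple_gpus` API.

```python
# Illustrative sketch only -- helper names here are hypothetical and do not
# mirror the actual dino.eval.features / dino.distributed APIs.
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler


def is_distributed() -> bool:
    # True only when a process group is initialized and more than one rank exists.
    return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1


@torch.no_grad()
def extract_shard_features(model, loader, device) -> torch.Tensor:
    # Each rank embeds only the samples its sampler assigned to it.
    model.eval()
    chunks = [model(images.to(device)) for images, _ in loader]
    return torch.cat(chunks, dim=0)


def run_eval(model, dataset, device, batch_size: int = 64) -> None:
    sampler = (
        DistributedSampler(dataset, shuffle=False)
        if is_distributed()
        else SequentialSampler(dataset)
    )
    loader = DataLoader(dataset, sampler=sampler, batch_size=batch_size, drop_last=False)

    feats = extract_shard_features(model, loader, device)

    if is_distributed():
        # Gather every rank's shard; real code must also gather sample indices to
        # restore dataset order and drop the padding DistributedSampler adds.
        gathered = [torch.zeros_like(feats) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, feats)
        feats = torch.cat(gathered, dim=0)

    if not is_distributed() or dist.get_rank() == 0:
        # Only the main process computes metrics and touches the filesystem.
        ...  # e.g. compute metrics on `feats`, write CSV/JSON

    if is_distributed():
        dist.barrier()  # keep ranks aligned before the next dataset / epoch
```

In the PR itself, `dist.barrier()` is called just before the main-process write block in `run()`; the sketch shows the same rank-0-only I/O plus synchronization idea in isolation.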