From 624a1eaaf8e011eb6b565ab323392303b7ff7502 Mon Sep 17 00:00:00 2001 From: Kyle Harrington Date: Sun, 6 Apr 2025 21:42:41 -0400 Subject: [PATCH 1/2] Add particle detection features to copick-torch This commit adds several key features to copick-torch: 1. MONAI-based particle detector: A detector using MONAI's RetinaNet for 3D particle detection in cryoET data 2. Difference of Gaussian (DoG) detector: A classic blob detector reimplemented for particle picking 3. Evaluation metrics: Tools for analyzing detector performance with ground truth 4. CryoET Data Portal dataloader: A dataloader that automatically rescales tomograms to a target resolution Each component has a comprehensive test suite and example usage scripts. --- copick_torch/__init__.py | 12 +- copick_torch/dataloaders/__init__.py | 527 +++++++++++++++++++++++ copick_torch/detectors/__init__.py | 8 + copick_torch/detectors/dog_detector.py | 310 +++++++++++++ copick_torch/detectors/monai_detector.py | 476 ++++++++++++++++++++ copick_torch/metrics/__init__.py | 227 ++++++++++ examples/dog_detector_example.py | 200 +++++++++ examples/monai_detector_example.py | 417 ++++++++++++++++++ tests/test_dog_detector.py | 121 ++++++ tests/test_metrics.py | 168 ++++++++ tests/test_monai_detector.py | 203 +++++++++ 11 files changed, 2668 insertions(+), 1 deletion(-) create mode 100644 copick_torch/dataloaders/__init__.py create mode 100644 copick_torch/detectors/__init__.py create mode 100644 copick_torch/detectors/dog_detector.py create mode 100644 copick_torch/detectors/monai_detector.py create mode 100644 copick_torch/metrics/__init__.py create mode 100644 examples/dog_detector_example.py create mode 100644 examples/monai_detector_example.py create mode 100644 tests/test_dog_detector.py create mode 100644 tests/test_metrics.py create mode 100644 tests/test_monai_detector.py diff --git a/copick_torch/__init__.py b/copick_torch/__init__.py index 14b810d..3167fa3 100644 --- a/copick_torch/__init__.py +++ b/copick_torch/__init__.py @@ -1,4 +1,14 @@ from copick_torch.copick import CopickDataset from copick_torch.logging import setup_logging +from copick_torch.detectors.monai_detector import MONAIParticleDetector +from copick_torch.detectors.dog_detector import DoGParticleDetector +from copick_torch.dataloaders import CryoETDataPortalDataset, CryoETParticleDataset -__all__ = ["CopickDataset", "setup_logging"] \ No newline at end of file +__all__ = [ + "CopickDataset", + "setup_logging", + "MONAIParticleDetector", + "DoGParticleDetector", + "CryoETDataPortalDataset", + "CryoETParticleDataset" +] diff --git a/copick_torch/dataloaders/__init__.py b/copick_torch/dataloaders/__init__.py new file mode 100644 index 0000000..baa78c3 --- /dev/null +++ b/copick_torch/dataloaders/__init__.py @@ -0,0 +1,527 @@ +""" +CryoET Data Portal dataloader for copick. + +This module provides dataloaders specifically designed for loading data from the CryoET Data Portal, +with automatic rescaling to a target resolution. +""" + +import os +import logging +import warnings +from pathlib import Path +from typing import Dict, List, Tuple, Union, Optional, Any, Sequence + +import numpy as np +import torch +import zarr +from torch.utils.data import Dataset, DataLoader +from scipy.ndimage import zoom + +import copick +from monai.transforms import Compose, ScaleIntensity, EnsureChannelFirst, ToTensor + + +class CryoETDataPortalDataset(Dataset): + """ + A PyTorch dataset for working with CryoET data from the CZ CryoET Data Portal. 
+ + This dataset automatically rescales tomograms to the target resolution using scipy.ndimage.zoom. + + Args: + dataset_ids: list of dataset IDs from the CryoET Data Portal + overlay_root: root URL for the overlay storage + boxsize: size of extracted boxes (z, y, x) + voxel_spacing: target voxel spacing in Ångstroms + transform: transforms to apply to the tomogram + cache_dir: directory to cache rescaled tomograms + use_cache: whether to use cached tomograms if available + batch_size: batch size for returning data + shuffle: whether to shuffle the order of tomograms + num_workers: number of worker processes for data loading + """ + + def __init__( + self, + dataset_ids: List[int], + overlay_root: str, + boxsize: Tuple[int, int, int] = (32, 32, 32), + voxel_spacing: float = 10.0, + transform: Optional[Any] = None, + cache_dir: Optional[str] = None, + use_cache: bool = True, + batch_size: int = 1, + shuffle: bool = False, + num_workers: int = 0 + ): + self.dataset_ids = dataset_ids + self.overlay_root = overlay_root + self.boxsize = boxsize + self.voxel_spacing = voxel_spacing + self.transform = transform if transform is not None else Compose([ + EnsureChannelFirst(), + ScaleIntensity(), + ToTensor() + ]) + self.cache_dir = cache_dir + self.use_cache = use_cache + self.batch_size = batch_size + self.shuffle = shuffle + self.num_workers = num_workers + + self.logger = logging.getLogger(__name__) + + # Create cache directory if necessary + if self.cache_dir is not None: + os.makedirs(self.cache_dir, exist_ok=True) + + # Initialize copick project + self.logger.info(f"Initializing copick project for dataset ids: {dataset_ids}") + self.root = copick.from_czcdp_datasets( + dataset_ids=dataset_ids, + overlay_root=overlay_root + ) + + # Collect all runs and tomograms + self.runs = self.root.runs + self.tomograms = [] + self.tomogram_metadata = [] + + for run in self.runs: + # Get the closest voxel spacing available + available_spacings = [vs.voxel_size for vs in run.voxel_spacings] + closest_spacing = min(available_spacings, key=lambda x: abs(x - voxel_spacing)) + + vs = run.get_voxel_spacing(closest_spacing) + if vs is None: + continue + + # Get all tomograms for this voxel spacing + for tomo in vs.tomograms: + self.tomograms.append(tomo) + self.tomogram_metadata.append({ + 'run': run, + 'voxel_spacing': vs, + 'original_spacing': closest_spacing, + 'target_spacing': voxel_spacing + }) + + self.logger.info(f"Found {len(self.tomograms)} tomograms in {len(self.runs)} runs") + + def __len__(self): + return len(self.tomograms) + + def __getitem__(self, idx): + tomogram = self.tomograms[idx] + metadata = self.tomogram_metadata[idx] + + # Check if rescaled tomogram is cached + cache_path = None + if self.cache_dir is not None: + run_name = metadata['run'].name + vs_value = metadata['original_spacing'] + target_vs = metadata['target_spacing'] + tomo_type = tomogram.tomo_type + cache_filename = f"{run_name}_{vs_value:.2f}_to_{target_vs:.2f}_{tomo_type}.npy" + cache_path = Path(self.cache_dir) / cache_filename + + # Load from cache if available and requested + if cache_path is not None and cache_path.exists() and self.use_cache: + self.logger.info(f"Loading rescaled tomogram from cache: {cache_path}") + try: + rescaled_tomo = np.load(cache_path) + except Exception as e: + self.logger.error(f"Failed to load from cache: {e}") + rescaled_tomo = self._load_and_rescale_tomogram(tomogram, metadata) + else: + # Load and rescale tomogram + rescaled_tomo = self._load_and_rescale_tomogram(tomogram, metadata) + + # 
Cache the rescaled tomogram if requested + if cache_path is not None and self.use_cache: + self.logger.info(f"Caching rescaled tomogram: {cache_path}") + try: + np.save(cache_path, rescaled_tomo) + except Exception as e: + self.logger.error(f"Failed to save to cache: {e}") + + # Apply transforms + if self.transform: + rescaled_tomo = self.transform(rescaled_tomo) + + return rescaled_tomo, {'idx': idx, 'metadata': metadata} + + def _load_and_rescale_tomogram(self, tomogram, metadata): + """ + Load and rescale a tomogram to the target voxel spacing. + + Args: + tomogram: tomogram object from copick + metadata: metadata dictionary + + Returns: + rescaled_tomo: numpy array containing the rescaled tomogram + """ + self.logger.info(f"Loading tomogram for run {metadata['run'].name}") + + # Load the tomogram + tomo_array = tomogram.numpy() + + # Calculate zoom factors for rescaling + original_spacing = metadata['original_spacing'] + target_spacing = metadata['target_spacing'] + zoom_factors = [original_spacing / target_spacing] * 3 + + # Skip rescaling if spacing is already very close + if abs(original_spacing - target_spacing) < 0.01: + self.logger.info(f"Skipping rescaling as spacing is already close: {original_spacing:.2f}") + return tomo_array + + # Rescale the tomogram + self.logger.info(f"Rescaling tomogram from {original_spacing:.2f}Å to {target_spacing:.2f}Å") + self.logger.info(f"Tomogram shape before rescaling: {tomo_array.shape}") + + # Use scipy.ndimage.zoom for rescaling + try: + rescaled_tomo = zoom(tomo_array, zoom_factors, order=1, mode='constant') + self.logger.info(f"Tomogram shape after rescaling: {rescaled_tomo.shape}") + return rescaled_tomo + except Exception as e: + self.logger.error(f"Failed to rescale tomogram: {e}") + # Return original if rescaling fails + return tomo_array + + def get_dataloader(self): + """ + Get a DataLoader for this dataset. + + Returns: + DataLoader: PyTorch DataLoader + """ + return DataLoader( + self, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers + ) + + def get_picks_for_tomogram(self, idx): + """ + Get particle picks for a specific tomogram. + + Args: + idx: index of the tomogram + + Returns: + list of particle coordinates (rescaled to target resolution) + """ + if idx < 0 or idx >= len(self.tomograms): + raise IndexError(f"Index {idx} out of range for tomogram list of length {len(self.tomograms)}") + + metadata = self.tomogram_metadata[idx] + run = metadata['run'] + + # Get all picks for this run + picks = run.get_picks() + + # Initialize list to hold all coordinates + all_coords = [] + + # Process each pick set + for pick in picks: + try: + # Convert picks to numpy coordinates + points, _ = pick.numpy() + + # Rescale coordinates from original to target resolution + original_spacing = metadata['original_spacing'] + target_spacing = metadata['target_spacing'] + scale_factor = original_spacing / target_spacing + + # Scale coordinates + points = points * scale_factor + + # Add to list + all_coords.append(points) + except Exception as e: + self.logger.error(f"Error processing picks: {e}") + + # Combine all coordinates + if all_coords: + return np.vstack(all_coords) + else: + return np.zeros((0, 3)) + + def get_all_picks(self): + """ + Get all particle picks for all tomograms. 
+ + Returns: + dict: mapping from tomogram idx to particle coordinates + """ + all_picks = {} + for idx in range(len(self.tomograms)): + all_picks[idx] = self.get_picks_for_tomogram(idx) + return all_picks + + +class CryoETParticleDataset(Dataset): + """ + A PyTorch dataset for particle picking in CryoET data from the CZ CryoET Data Portal. + + This dataset extracts subvolumes around particle coordinates for training detection models. + + Args: + dataset_ids: list of dataset IDs from the CryoET Data Portal + overlay_root: root URL for the overlay storage + boxsize: size of extracted boxes (z, y, x) + voxel_spacing: target voxel spacing in Ångstroms + include_background: whether to include background (non-particle) samples + background_ratio: ratio of background samples to particle samples + min_background_distance: minimum distance from particles for background samples + transform: transforms to apply to the extracted subvolumes + cache_dir: directory to cache rescaled tomograms + use_cache: whether to use cached tomograms if available + """ + + def __init__( + self, + dataset_ids: List[int], + overlay_root: str, + boxsize: Tuple[int, int, int] = (32, 32, 32), + voxel_spacing: float = 10.0, + include_background: bool = True, + background_ratio: float = 0.5, + min_background_distance: float = 20.0, + transform: Optional[Any] = None, + cache_dir: Optional[str] = None, + use_cache: bool = True + ): + self.dataset_ids = dataset_ids + self.overlay_root = overlay_root + self.boxsize = boxsize + self.voxel_spacing = voxel_spacing + self.include_background = include_background + self.background_ratio = background_ratio + self.min_background_distance = min_background_distance + self.transform = transform if transform is not None else Compose([ + EnsureChannelFirst(), + ScaleIntensity(), + ToTensor() + ]) + self.cache_dir = cache_dir + self.use_cache = use_cache + + self.logger = logging.getLogger(__name__) + + # Create the base dataset to load and rescale tomograms + self.base_dataset = CryoETDataPortalDataset( + dataset_ids=dataset_ids, + overlay_root=overlay_root, + boxsize=boxsize, + voxel_spacing=voxel_spacing, + transform=None, # We'll apply transforms later + cache_dir=cache_dir, + use_cache=use_cache + ) + + # Load all tomograms and picks + self.tomograms = [] + self.particle_coords = [] + self.background_coords = [] + + # Flag to track if dataset is fully initialized + self._initialized = False + + def initialize(self): + """ + Fully initialize the dataset by loading all tomograms and extracting particles. + This is separated from __init__ to allow lazy loading. 
+ """ + if self._initialized: + return + + # Load all tomograms + for idx in range(len(self.base_dataset)): + # Load tomogram + tomo, metadata = self.base_dataset[idx] + self.tomograms.append(tomo) + + # Get particle picks for this tomogram + particle_coords = self.base_dataset.get_picks_for_tomogram(idx) + self.particle_coords.append(particle_coords) + + # Generate background samples + if self.include_background: + bg_coords = self._generate_background_coords(tomo, particle_coords) + self.background_coords.append(bg_coords) + else: + self.background_coords.append(np.zeros((0, 3))) + + # Set up indices for accessing particles and background + self.particle_indices = [] + self.background_indices = [] + + for tomo_idx, (particles, backgrounds) in enumerate(zip(self.particle_coords, self.background_coords)): + for particle_idx in range(len(particles)): + self.particle_indices.append((tomo_idx, particle_idx, True)) # True = particle + + for bg_idx in range(len(backgrounds)): + self.background_indices.append((tomo_idx, bg_idx, False)) # False = background + + self._initialized = True + + self.logger.info(f"Dataset initialized with {len(self.particle_indices)} particles and {len(self.background_indices)} background samples") + + def _generate_background_coords(self, tomogram, particle_coords): + """ + Generate random background coordinates away from particles. + + Args: + tomogram: tomogram as numpy array + particle_coords: array of particle coordinates + + Returns: + background_coords: array of background coordinates + """ + if len(particle_coords) == 0: + return np.zeros((0, 3)) + + # Calculate number of background samples to generate + num_particles = len(particle_coords) + num_backgrounds = int(num_particles * self.background_ratio) + + # Generate random coordinates + background_coords = [] + max_attempts = num_backgrounds * 10 # Limit attempts to avoid infinite loop + + # Get tomogram shape + z_max, y_max, x_max = tomogram.shape + half_box = np.array(self.boxsize) // 2 + + attempts = 0 + while len(background_coords) < num_backgrounds and attempts < max_attempts: + # Generate random coordinates within valid range + z = np.random.randint(half_box[0], z_max - half_box[0]) + y = np.random.randint(half_box[1], y_max - half_box[1]) + x = np.random.randint(half_box[2], x_max - half_box[2]) + + coord = np.array([z, y, x]) + + # Check distance to all particles + if len(particle_coords) > 0: + distances = np.sqrt(np.sum((particle_coords - coord)**2, axis=1)) + min_distance = np.min(distances) + + # Only accept if far enough from all particles + if min_distance >= self.min_background_distance: + background_coords.append(coord) + else: + # If no particles, just accept the coordinate + background_coords.append(coord) + + attempts += 1 + + return np.array(background_coords) + + def __len__(self): + # Make sure dataset is initialized + if not self._initialized: + self.initialize() + + return len(self.particle_indices) + len(self.background_indices) + + def __getitem__(self, idx): + # Make sure dataset is initialized + if not self._initialized: + self.initialize() + + # Determine if this is a particle or background sample + if idx < len(self.particle_indices): + tomo_idx, coord_idx, is_particle = self.particle_indices[idx] + coords = self.particle_coords[tomo_idx][coord_idx] + else: + # Adjust index for background samples + bg_idx = idx - len(self.particle_indices) + tomo_idx, coord_idx, is_particle = self.background_indices[bg_idx] + coords = self.background_coords[tomo_idx][coord_idx] + + # Get tomogram 
+ tomogram = self.tomograms[tomo_idx] + + # Extract subvolume centered at coordinates + subvolume = self._extract_subvolume(tomogram, coords) + + # Create target label (1 for particle, 0 for background) + label = 1 if is_particle else 0 + + # Apply transforms + if self.transform: + subvolume = self.transform(subvolume) + + return subvolume, label + + def _extract_subvolume(self, tomogram, coords): + """ + Extract a subvolume from the tomogram centered at the given coordinates. + + Args: + tomogram: tomogram as numpy array + coords: coordinates (z, y, x) of the center point + + Returns: + subvolume: extracted subvolume + """ + # Get tomogram shape + z_max, y_max, x_max = tomogram.shape + + # Calculate half box size + half_box = np.array(self.boxsize) // 2 + + # Calculate extraction ranges + z_start = max(0, int(coords[0] - half_box[0])) + z_end = min(z_max, int(coords[0] + half_box[0])) + y_start = max(0, int(coords[1] - half_box[1])) + y_end = min(y_max, int(coords[1] + half_box[1])) + x_start = max(0, int(coords[2] - half_box[2])) + x_end = min(x_max, int(coords[2] + half_box[2])) + + # Extract subvolume + subvolume = tomogram[z_start:z_end, y_start:y_end, x_start:x_end] + + # Handle case where extracted subvolume is smaller than desired size + if subvolume.shape != self.boxsize: + # Create padded subvolume + padded = np.zeros(self.boxsize, dtype=subvolume.dtype) + + # Calculate padding + z_pad = half_box[0] - int(coords[0]) if coords[0] < half_box[0] else 0 + y_pad = half_box[1] - int(coords[1]) if coords[1] < half_box[1] else 0 + x_pad = half_box[2] - int(coords[2]) if coords[2] < half_box[2] else 0 + + # Calculate actual dimensions to copy + z_size, y_size, x_size = subvolume.shape + + # Copy data to padded volume + padded[z_pad:z_pad+z_size, y_pad:y_pad+y_size, x_pad:x_pad+x_size] = subvolume + return padded + + return subvolume + + def get_dataloader(self, batch_size=8, shuffle=True, num_workers=0): + """ + Get a DataLoader for this dataset. + + Args: + batch_size: batch size for the dataloader + shuffle: whether to shuffle the dataset + num_workers: number of worker processes + + Returns: + DataLoader: PyTorch DataLoader + """ + return DataLoader( + self, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers + ) diff --git a/copick_torch/detectors/__init__.py b/copick_torch/detectors/__init__.py new file mode 100644 index 0000000..8500b45 --- /dev/null +++ b/copick_torch/detectors/__init__.py @@ -0,0 +1,8 @@ +""" +Particle detectors for CryoET data. +""" + +from copick_torch.detectors.monai_detector import MONAIParticleDetector +from copick_torch.detectors.dog_detector import DoGParticleDetector + +__all__ = ["MONAIParticleDetector", "DoGParticleDetector"] diff --git a/copick_torch/detectors/dog_detector.py b/copick_torch/detectors/dog_detector.py new file mode 100644 index 0000000..88cc63b --- /dev/null +++ b/copick_torch/detectors/dog_detector.py @@ -0,0 +1,310 @@ +""" +Difference of Gaussian (DoG) particle detector for CryoET data. + +This module implements a simple but effective particle detector based on the Difference of Gaussian (DoG) +method, which is widely used in particle picking for cryo-EM and cryo-ET data. +""" + +import logging +import numpy as np +from scipy import ndimage +from skimage.feature import peak_local_max +from typing import List, Tuple, Union, Optional, Dict, Any, Sequence + +class DoGParticleDetector: + """ + Difference of Gaussian (DoG) particle detector for CryoET data. 
+ + This detector applies Gaussian filters with two different sigma values to the input volume, + then subtracts the more blurred volume from the less blurred one to enhance particle-like features. + Local maxima in the resulting volume are identified as potential particle locations. + + Args: + sigma1: sigma value for the first Gaussian filter (smaller) + sigma2: sigma value for the second Gaussian filter (larger) + threshold_abs: absolute threshold for peak detection + min_distance: minimum distance between peaks (in voxels) + exclude_border: exclude border region of this size (in voxels) + normalize: whether to normalize the input volume before processing + invert: whether to invert the contrast (for dark particles on light background) + prefilter: optional filter to apply before DoG (e.g., 'median', 'gaussian') + prefilter_size: size parameter for prefilter + """ + + def __init__( + self, + sigma1: float = 1.0, + sigma2: float = 3.0, + threshold_abs: float = 0.1, + min_distance: int = 5, + exclude_border: int = 2, + normalize: bool = True, + invert: bool = False, + prefilter: Optional[str] = None, + prefilter_size: float = 1.0 + ): + self.sigma1 = sigma1 + self.sigma2 = sigma2 + self.threshold_abs = threshold_abs + self.min_distance = min_distance + self.exclude_border = exclude_border + self.normalize = normalize + self.invert = invert + self.prefilter = prefilter + self.prefilter_size = prefilter_size + + self.logger = logging.getLogger(__name__) + + def detect(self, volume: np.ndarray, return_scores: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """ + Detect particles in a 3D volume using Difference of Gaussian method. + + Args: + volume: input 3D volume + return_scores: whether to return the peak values at each detected location + + Returns: + np.ndarray: particle coordinates (N, 3) + np.ndarray (optional): peak values at each location (N,) + """ + # Check input dimensions + if volume.ndim != 3: + raise ValueError(f"Expected 3D volume, got shape {volume.shape}") + + # Make a copy to avoid modifying the input + vol = volume.copy() + + # Invert contrast if needed (for dark particles on light background) + if self.invert: + vol = -vol + + # Normalize if requested + if self.normalize: + vol = (vol - np.mean(vol)) / np.std(vol) + + # Apply prefilter if requested + if self.prefilter == 'median': + self.logger.info(f"Applying median filter with size {self.prefilter_size}") + vol = ndimage.median_filter(vol, size=self.prefilter_size) + elif self.prefilter == 'gaussian': + self.logger.info(f"Applying Gaussian filter with sigma {self.prefilter_size}") + vol = ndimage.gaussian_filter(vol, sigma=self.prefilter_size) + + # Apply Gaussian filters + self.logger.info(f"Applying DoG with sigma1={self.sigma1}, sigma2={self.sigma2}") + vol_smooth1 = ndimage.gaussian_filter(vol, sigma=self.sigma1) + vol_smooth2 = ndimage.gaussian_filter(vol, sigma=self.sigma2) + + # Calculate Difference of Gaussian + dog_vol = vol_smooth1 - vol_smooth2 + + # Find local maxima + self.logger.info(f"Finding peaks with min_distance={self.min_distance}, threshold={self.threshold_abs}") + peaks = peak_local_max( + dog_vol, + min_distance=self.min_distance, + threshold_abs=self.threshold_abs, + exclude_border=self.exclude_border, + indices=True + ) + + # Return peak coordinates and peak values if requested + if return_scores: + peak_values = np.array([dog_vol[tuple(peak)] for peak in peaks]) + return peaks, peak_values + + return peaks + + def detect_multiscale( + self, + volume: np.ndarray, + 
sigma_pairs: List[Tuple[float, float]], + return_scores: bool = False + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """ + Detect particles at multiple scales by applying DoG with different sigma pairs. + + Args: + volume: input 3D volume + sigma_pairs: list of (sigma1, sigma2) pairs to try + return_scores: whether to return the peak values at each detected location + + Returns: + np.ndarray: particle coordinates (N, 3) + np.ndarray (optional): peak values at each location (N,) + """ + all_peaks = [] + all_values = [] + + # Keep the original settings + orig_sigma1 = self.sigma1 + orig_sigma2 = self.sigma2 + + # Try each sigma pair + for sigma1, sigma2 in sigma_pairs: + self.sigma1 = sigma1 + self.sigma2 = sigma2 + + # Detect particles with this sigma pair + if return_scores: + peaks, values = self.detect(volume, return_scores=True) + all_peaks.append(peaks) + all_values.append(values) + else: + peaks = self.detect(volume, return_scores=False) + all_peaks.append(peaks) + + # Restore original settings + self.sigma1 = orig_sigma1 + self.sigma2 = orig_sigma2 + + # Combine results + if all_peaks: + combined_peaks = np.vstack(all_peaks) + + if return_scores: + combined_values = np.hstack(all_values) + return combined_peaks, combined_values + return combined_peaks + + # Return empty array if no peaks found + if return_scores: + return np.zeros((0, 3)), np.zeros(0) + return np.zeros((0, 3)) + + def optimize_parameters( + self, + volume: np.ndarray, + ground_truth: np.ndarray, + sigma1_range: Tuple[float, float, float] = (0.5, 3.0, 0.5), + sigma2_range: Tuple[float, float, float] = (1.0, 5.0, 0.5), + threshold_range: Tuple[float, float, float] = (0.05, 0.5, 0.05), + min_distance_range: Tuple[int, int, int] = (3, 10, 1), + tolerance: float = 5.0 + ) -> Dict[str, Any]: + """ + Optimize detector parameters by grid search against ground truth particles. 
+ + Args: + volume: input 3D volume + ground_truth: array of ground truth particle coordinates (N, 3) + sigma1_range: (start, stop, step) for sigma1 values to try + sigma2_range: (start, stop, step) for sigma2 values to try + threshold_range: (start, stop, step) for threshold values to try + min_distance_range: (start, stop, step) for min_distance values to try + tolerance: maximum distance for a detected particle to be considered a match + + Returns: + dict: optimal parameters found + """ + best_params = {} + best_f1 = 0.0 + + # Create parameter grid + sigma1_values = np.arange(sigma1_range[0], sigma1_range[1] + 1e-5, sigma1_range[2]) + sigma2_values = np.arange(sigma2_range[0], sigma2_range[1] + 1e-5, sigma2_range[2]) + threshold_values = np.arange(threshold_range[0], threshold_range[1] + 1e-5, threshold_range[2]) + min_distance_values = np.arange( + min_distance_range[0], min_distance_range[1] + 1, min_distance_range[2], dtype=int + ) + + total_combinations = ( + len(sigma1_values) * len(sigma2_values) * + len(threshold_values) * len(min_distance_values) + ) + self.logger.info(f"Grid search over {total_combinations} parameter combinations") + + # Iterate over all parameter combinations + for sigma1 in sigma1_values: + for sigma2 in sigma2_values: + # Skip invalid combinations + if sigma2 <= sigma1: + continue + + for threshold in threshold_values: + for min_distance in min_distance_values: + # Update detector parameters + self.sigma1 = sigma1 + self.sigma2 = sigma2 + self.threshold_abs = threshold + self.min_distance = min_distance + + # Detect particles with current parameters + detected_peaks = self.detect(volume) + + # Calculate metrics + precision, recall, f1 = self._calculate_metrics(detected_peaks, ground_truth, tolerance) + + # Update best parameters if F1 score improved + if f1 > best_f1: + best_f1 = f1 + best_params = { + 'sigma1': sigma1, + 'sigma2': sigma2, + 'threshold_abs': threshold, + 'min_distance': min_distance, + 'precision': precision, + 'recall': recall, + 'f1': f1 + } + + self.logger.info( + f"New best: F1={f1:.3f}, P={precision:.3f}, R={recall:.3f} with " + f"sigma1={sigma1}, sigma2={sigma2}, thresh={threshold}, min_dist={min_distance}" + ) + + # Set detector to best parameters + self.sigma1 = best_params['sigma1'] + self.sigma2 = best_params['sigma2'] + self.threshold_abs = best_params['threshold_abs'] + self.min_distance = best_params['min_distance'] + + return best_params + + def _calculate_metrics( + self, + detected: np.ndarray, + ground_truth: np.ndarray, + tolerance: float + ) -> Tuple[float, float, float]: + """ + Calculate precision, recall, and F1 score for particle detection. 
+ + Args: + detected: detected particle coordinates + ground_truth: ground truth particle coordinates + tolerance: maximum distance to consider a detection as correct + + Returns: + tuple: (precision, recall, f1_score) + """ + if len(detected) == 0: + return 0, 0, 0 + + if len(ground_truth) == 0: + return 0, 0, 0 + + # Calculate distances between all detections and ground truth particles + true_positives = 0 + matched_gt = set() + + # For each detected particle, find the closest ground truth + for det in detected: + # Calculate Euclidean distances to all ground truth particles + distances = np.sqrt(np.sum((ground_truth - det)**2, axis=1)) + + # Find the closest ground truth particle + min_idx = np.argmin(distances) + min_dist = distances[min_idx] + + # Consider it a match if the distance is within tolerance + if min_dist <= tolerance and min_idx not in matched_gt: + true_positives += 1 + matched_gt.add(min_idx) + + # Calculate metrics + precision = true_positives / len(detected) if len(detected) > 0 else 0 + recall = true_positives / len(ground_truth) if len(ground_truth) > 0 else 0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 + + return precision, recall, f1 diff --git a/copick_torch/detectors/monai_detector.py b/copick_torch/detectors/monai_detector.py new file mode 100644 index 0000000..fc2744a --- /dev/null +++ b/copick_torch/detectors/monai_detector.py @@ -0,0 +1,476 @@ +""" +MONAI-based particle detector for CryoET data. +This detector is based on MONAI's RetinaNet implementation. +""" + +import os +import logging +import warnings +from typing import Dict, List, Tuple, Union, Optional, Any, Sequence + +import numpy as np +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from pathlib import Path + +import monai +from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector +from monai.apps.detection.networks.retinanet_network import RetinaNet, resnet_fpn_feature_extractor +from monai.apps.detection.utils.anchor_utils import AnchorGenerator, AnchorGeneratorWithAnchorShape +from monai.apps.detection.utils.predict_utils import ensure_dict_value_to_list_ +from monai.networks.nets import resnet +from monai.data.box_utils import box_iou +from monai.transforms import Compose, ScaleIntensity, EnsureChannelFirst, ToTensor, SpatialPad, Resize +from monai.inferers import SlidingWindowInferer +from monai.utils import ensure_tuple_rep, BlendMode, PytorchPadMode + +class MONAIParticleDetector: + """ + MONAI-based particle detector for cryoET data based on RetinaNet. + + This detector uses MONAI's RetinaNet implementation for 3D object detection + and adapts it for particle picking in cryoET data. + + Args: + spatial_dims: number of spatial dimensions (2 or 3) + num_classes: number of output classes (particle types) to detect + feature_size: size of features maps (32, 64, 128, etc.) 
+ anchor_sizes: list of anchor sizes in voxels (e.g., [(8,8,8), (16,16,16)]) + pretrained: whether to use pretrained backbone weights + device: device to run the detector on ('cuda', 'cpu') + sliding_window_size: size of sliding window for inference (must be divisible by 16) + sliding_window_batch_size: batch size for sliding window inference + sliding_window_overlap: overlap of sliding windows during inference + detection_threshold: confidence threshold for detections + nms_threshold: non-maximum suppression threshold for overlapping detections + max_detections_per_volume: maximum number of detections per volume + """ + + def __init__( + self, + spatial_dims: int = 3, + num_classes: int = 1, + feature_size: int = 32, + anchor_sizes: Sequence[Sequence[int]] = None, + pretrained: bool = False, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + sliding_window_size: Sequence[int] = None, + sliding_window_batch_size: int = 4, + sliding_window_overlap: float = 0.25, + detection_threshold: float = 0.3, + nms_threshold: float = 0.1, + max_detections_per_volume: int = 2000, + ): + self.spatial_dims = spatial_dims + self.num_classes = num_classes + self.device = device + + # Set default anchor sizes if not provided + if anchor_sizes is None: + # Default sizes appropriate for cryoET particles + if spatial_dims == 3: + self.anchor_sizes = [(8, 8, 8), (16, 16, 16)] + else: + self.anchor_sizes = [(8, 8), (16, 16)] + else: + self.anchor_sizes = anchor_sizes + + # Set default sliding window size if not provided + if sliding_window_size is None: + # Default sliding window size (must be divisible by 16) + if spatial_dims == 3: + self.sliding_window_size = (64, 64, 64) + else: + self.sliding_window_size = (128, 128) + else: + self.sliding_window_size = sliding_window_size + + self.sliding_window_batch_size = sliding_window_batch_size + self.sliding_window_overlap = sliding_window_overlap + + # Detection parameters + self.detection_threshold = detection_threshold + self.nms_threshold = nms_threshold + self.max_detections_per_volume = max_detections_per_volume + + # Create the detector + self.detector = self._create_detector( + spatial_dims=spatial_dims, + num_classes=num_classes, + feature_size=feature_size, + pretrained=pretrained + ) + + # Configure detector for inference + self.detector.set_box_selector_parameters( + score_thresh=self.detection_threshold, + nms_thresh=self.nms_threshold, + detections_per_img=self.max_detections_per_volume, + apply_sigmoid=True + ) + + # Configure sliding window inference + self.detector.set_sliding_window_inferer( + roi_size=self.sliding_window_size, + sw_batch_size=self.sliding_window_batch_size, + overlap=self.sliding_window_overlap, + mode=BlendMode.CONSTANT, + padding_mode=PytorchPadMode.CONSTANT, + cval=0.0, + sw_device=self.device, + device=self.device, + progress=True + ) + + # Input pre-processing + self.transform = Compose([ + EnsureChannelFirst(), + ScaleIntensity(), + ToTensor(), + ]) + + # Move model to device + self.detector.to(self.device) + self.detector.eval() # Set to eval mode + + def _create_detector(self, spatial_dims, num_classes, feature_size, pretrained): + """ + Create the MONAI RetinaNet detector. 
+ + Args: + spatial_dims: number of spatial dimensions + num_classes: number of classes to predict + feature_size: size of feature maps + pretrained: whether to use pretrained weights + + Returns: + RetinaNetDetector: detector model + """ + # Create ResNet backbone + if spatial_dims == 3: + conv1_t_stride = (2, 2, 2) + conv1_t_size = (7, 7, 7) + else: + conv1_t_stride = (2, 2) + conv1_t_size = (7, 7) + + # Use ResNet18 as the backbone for faster inference + backbone = resnet.ResNet( + spatial_dims=spatial_dims, + block=resnet.ResNetBasicBlock, + layers=[2, 2, 2, 2], # ResNet18 architecture + block_inplanes=resnet.get_inplanes(), + n_input_channels=1, # Single channel for tomogram data + conv1_t_stride=conv1_t_stride, + conv1_t_size=conv1_t_size, + num_classes=num_classes + ) + + # Define which layers to return from the FPN + returned_layers = [1, 2, 3] + + # Create feature extractor with FPN + feature_extractor = resnet_fpn_feature_extractor( + backbone=backbone, + spatial_dims=spatial_dims, + pretrained_backbone=pretrained, + trainable_backbone_layers=None, + returned_layers=returned_layers, + ) + + # Calculate required divisibility for network input + size_divisible = tuple(2 * s * 2**max(returned_layers) for s in conv1_t_stride) + + # Create anchor generator with custom anchor sizes for cryoET particles + anchor_generator = AnchorGeneratorWithAnchorShape( + feature_map_scales=(1, 2, 4), # Scales for different feature map levels + base_anchor_shapes=self.anchor_sizes # Use custom anchor sizes + ) + + # Create RetinaNet network + network = RetinaNet( + spatial_dims=spatial_dims, + num_classes=num_classes, + num_anchors=anchor_generator.num_anchors_per_location()[0], + feature_extractor=feature_extractor, + size_divisible=size_divisible, + ) + + # Create RetinaNet detector + detector = RetinaNetDetector(network, anchor_generator) + + return detector + + def load_weights(self, weights_path): + """ + Load trained weights for the detector. + + Args: + weights_path: path to the weights file + """ + if os.path.exists(weights_path): + print(f"Loading weights from {weights_path}") + state_dict = torch.load(weights_path, map_location=self.device) + self.detector.network.load_state_dict(state_dict) + else: + warnings.warn(f"Weights file {weights_path} not found.") + + def save_weights(self, weights_path): + """ + Save the detector weights. + + Args: + weights_path: path to save the weights file + """ + Path(weights_path).parent.mkdir(parents=True, exist_ok=True) + torch.save(self.detector.network.state_dict(), weights_path) + print(f"Saved weights to {weights_path}") + + def detect(self, volume, return_scores=False, use_inferer=True): + """ + Detect particles in a 3D volume. 
+ + Args: + volume: 3D numpy array or torch tensor + return_scores: whether to return detection scores + use_inferer: whether to use sliding window inference + + Returns: + numpy array of particle coordinates [N, spatial_dims] + (optional) scores for each detection [N] + """ + # Ensure detector is in eval mode + self.detector.eval() + + # Ensure input is preprocessed correctly + if isinstance(volume, np.ndarray): + # Apply transforms to numpy array + input_tensor = self.transform(volume) + elif isinstance(volume, torch.Tensor): + # Ensure tensor is in correct format + if volume.ndim == self.spatial_dims: + # Add channel dimension if missing + input_tensor = volume.unsqueeze(0) + else: + input_tensor = volume + + # Move to device + input_tensor = input_tensor.to(self.device) + else: + raise ValueError(f"Unsupported input type: {type(volume)}") + + # Ensure input has batch dimension + if input_tensor.dim() == self.spatial_dims + 1: # [C, D, H, W] or [C, H, W] + input_tensor = input_tensor.unsqueeze(0) # Add batch dimension [B, C, D, H, W] or [B, C, H, W] + + with torch.no_grad(): + # Run inference + detections = self.detector.forward(input_tensor, use_inferer=use_inferer) + + # Extract coordinates + coordinates_list = [] + scores_list = [] + + for det in detections: + boxes = det["boxes"].cpu().numpy() + scores = det["labels_scores"].cpu().numpy() + + # Convert boxes [xmin, ymin, zmin, xmax, ymax, zmax] to coordinates [x, y, z] + if boxes.shape[0] > 0: + if self.spatial_dims == 3: + coordinates = np.zeros((boxes.shape[0], 3)) + coordinates[:, 0] = (boxes[:, 0] + boxes[:, 3]) / 2 # x center + coordinates[:, 1] = (boxes[:, 1] + boxes[:, 4]) / 2 # y center + coordinates[:, 2] = (boxes[:, 2] + boxes[:, 5]) / 2 # z center + else: + coordinates = np.zeros((boxes.shape[0], 2)) + coordinates[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2 # x center + coordinates[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2 # y center + + coordinates_list.append(coordinates) + scores_list.append(scores) + + # Concatenate results from all batches + if coordinates_list: + coordinates = np.concatenate(coordinates_list, axis=0) + scores = np.concatenate(scores_list, axis=0) + else: + # No detections + coordinates = np.zeros((0, self.spatial_dims)) + scores = np.zeros(0) + + if return_scores: + return coordinates, scores + return coordinates + + def train(self, + train_dataloader, + val_dataloader=None, + num_epochs=10, + learning_rate=1e-4, + weight_decay=1e-5, + save_path=None, + best_metric_name="val_loss"): + """ + Train the detector. 
+ + Args: + train_dataloader: dataloader for training data + val_dataloader: dataloader for validation data + num_epochs: number of epochs to train + learning_rate: learning rate for optimizer + weight_decay: weight decay for optimizer + save_path: path to save the best model weights + best_metric_name: metric name to determine the best model + + Returns: + dict: training metrics + """ + # Set detector to training mode + self.detector.train() + + # Define optimizer + optimizer = torch.optim.AdamW( + self.detector.network.parameters(), + lr=learning_rate, + weight_decay=weight_decay + ) + + # Define learning rate scheduler + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=2, verbose=True + ) + + # Training metrics + best_metric = float("inf") if "loss" in best_metric_name else 0.0 + metrics = {"train_loss": [], "val_loss": []} + + for epoch in range(num_epochs): + # Training loop + self.detector.train() + epoch_loss = 0 + + for batch_idx, batch_data in enumerate(train_dataloader): + # Extract data and target + images, targets = batch_data + images = images.to(self.device) + + # Prepare targets for RetinaNet format + formatted_targets = [] + for idx, target in enumerate(targets): + boxes = target.get("boxes", None) + labels = target.get("labels", None) + + if boxes is None or labels is None: + continue + + formatted_targets.append({ + "boxes": boxes.to(self.device), + "labels": labels.to(self.device) + }) + + # Skip batch if no valid targets + if not formatted_targets: + continue + + # Zero gradients + optimizer.zero_grad() + + # Forward pass + losses = self.detector(images, formatted_targets) + loss = losses["classification"] + losses["box_regression"] + + # Backward pass + loss.backward() + + # Update weights + optimizer.step() + + # Update metrics + epoch_loss += loss.item() + + # Print progress + if (batch_idx + 1) % 10 == 0: + print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_dataloader)}, Loss: {loss.item():.4f}") + + # Compute average loss for the epoch + avg_train_loss = epoch_loss / len(train_dataloader) + metrics["train_loss"].append(avg_train_loss) + + # Validation loop + if val_dataloader is not None: + val_loss = self.validate(val_dataloader) + metrics["val_loss"].append(val_loss) + + # Update learning rate + scheduler.step(val_loss) + + # Save best model + if save_path is not None: + if ("loss" in best_metric_name and val_loss < best_metric) or \ + ("loss" not in best_metric_name and val_loss > best_metric): + best_metric = val_loss + self.save_weights(save_path) + print(f"Saved best model with {best_metric_name} = {best_metric:.4f}") + + # Print epoch summary + print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}" + + (f", Val Loss: {val_loss:.4f}" if val_dataloader else "")) + + # Set back to eval mode + self.detector.eval() + + return metrics + + def validate(self, val_dataloader): + """ + Validate the detector on validation data. 
+ + Args: + val_dataloader: dataloader for validation data + + Returns: + float: validation loss + """ + # Set detector to evaluation mode + self.detector.eval() + + # Validation metrics + val_loss = 0 + + with torch.no_grad(): + for batch_data in val_dataloader: + # Extract data and target + images, targets = batch_data + images = images.to(self.device) + + # Prepare targets for RetinaNet format + formatted_targets = [] + for idx, target in enumerate(targets): + boxes = target.get("boxes", None) + labels = target.get("labels", None) + + if boxes is None or labels is None: + continue + + formatted_targets.append({ + "boxes": boxes.to(self.device), + "labels": labels.to(self.device) + }) + + # Skip batch if no valid targets + if not formatted_targets: + continue + + # Forward pass + losses = self.detector(images, formatted_targets) + loss = losses["classification"] + losses["box_regression"] + + # Update metrics + val_loss += loss.item() + + # Compute average loss for the validation set + avg_val_loss = val_loss / len(val_dataloader) + + return avg_val_loss diff --git a/copick_torch/metrics/__init__.py b/copick_torch/metrics/__init__.py new file mode 100644 index 0000000..d43c0e2 --- /dev/null +++ b/copick_torch/metrics/__init__.py @@ -0,0 +1,227 @@ +""" +Evaluation metrics for particle detection performance. + +This module provides various metrics for evaluating particle detection performance, +especially for cryoET particle picking. +""" + +import numpy as np +from typing import Dict, List, Tuple, Union, Optional +from sklearn.metrics import precision_recall_curve, average_precision_score + + +def calculate_distances(detected: np.ndarray, ground_truth: np.ndarray) -> np.ndarray: + """ + Calculate distances between detected particles and ground truth particles. + + Args: + detected: array of detected particle coordinates (N, 3) or (N, 2) + ground_truth: array of ground truth particle coordinates (M, 3) or (M, 2) + + Returns: + distances: 2D array of distances between all detected and ground truth particles (N, M) + """ + # Check inputs + if detected.shape[0] == 0 or ground_truth.shape[0] == 0: + return np.zeros((detected.shape[0], ground_truth.shape[0])) + + # Ensure both arrays have the same number of dimensions + if detected.shape[1] != ground_truth.shape[1]: + raise ValueError(f"Dimension mismatch: detected has shape {detected.shape}, ground_truth has shape {ground_truth.shape}") + + # Calculate distances between all pairs + distances = np.zeros((detected.shape[0], ground_truth.shape[0])) + for i, det in enumerate(detected): + for j, gt in enumerate(ground_truth): + distances[i, j] = np.sqrt(np.sum((det - gt) ** 2)) + + return distances + + +def calculate_precision_recall_f1( + detected: np.ndarray, + ground_truth: np.ndarray, + tolerance: float = 10.0 +) -> Tuple[float, float, float]: + """ + Calculate precision, recall, and F1 score for particle detection. 
+ + Args: + detected: array of detected particle coordinates (N, 3) or (N, 2) + ground_truth: array of ground truth particle coordinates (M, 3) or (M, 2) + tolerance: maximum distance (in pixels/voxels) for a detection to be considered correct + + Returns: + tuple: (precision, recall, F1 score) + """ + if detected.shape[0] == 0: + return 0.0, 0.0, 0.0 + + if ground_truth.shape[0] == 0: + return 0.0, 0.0, 0.0 + + # Calculate distances between all detections and ground truth + distances = calculate_distances(detected, ground_truth) + + # Find matches using greedy assignment + true_positives = 0 + matched_gt = set() + + # For each detection, find the closest unmatched ground truth within tolerance + for i in range(distances.shape[0]): + min_idx = np.argmin(distances[i]) + min_dist = distances[i, min_idx] + + if min_dist <= tolerance and min_idx not in matched_gt: + true_positives += 1 + matched_gt.add(min_idx) + + # Calculate metrics + precision = true_positives / detected.shape[0] if detected.shape[0] > 0 else 0.0 + recall = true_positives / ground_truth.shape[0] if ground_truth.shape[0] > 0 else 0.0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 + + return precision, recall, f1 + + +def calculate_average_precision( + detected: np.ndarray, + scores: np.ndarray, + ground_truth: np.ndarray, + tolerance: float = 10.0 +) -> Tuple[float, List[float], List[float], List[float]]: + """ + Calculate average precision (AP) for particle detection. + + Args: + detected: array of detected particle coordinates (N, 3) or (N, 2) + scores: confidence scores for each detection (N,) + ground_truth: array of ground truth particle coordinates (M, 3) or (M, 2) + tolerance: maximum distance (in pixels/voxels) for a detection to be considered correct + + Returns: + tuple: (average precision, precision values, recall values, thresholds) + """ + if detected.shape[0] == 0 or ground_truth.shape[0] == 0: + return 0.0, [], [], [] + + # Calculate distances between all detections and ground truth + distances = calculate_distances(detected, ground_truth) + + # For each detection, check if it's a true positive at minimum distance + y_true = np.zeros(detected.shape[0], dtype=bool) + matched_gt = set() + + # Sort detections by score in descending order + sort_indices = np.argsort(-scores) + sorted_dists = distances[sort_indices] + + # Find matches for sorted detections + for i in range(sorted_dists.shape[0]): + min_idx = np.argmin(sorted_dists[i]) + min_dist = sorted_dists[i, min_idx] + + if min_dist <= tolerance and min_idx not in matched_gt: + y_true[sort_indices[i]] = True + matched_gt.add(min_idx) + + # Calculate precision-recall curve + precision, recall, thresholds = precision_recall_curve(y_true, scores) + + # Calculate average precision + ap = average_precision_score(y_true, scores) + + return ap, precision.tolist(), recall.tolist(), thresholds.tolist() + + +def calculate_detector_metrics( + detected: np.ndarray, + ground_truth: np.ndarray, + scores: Optional[np.ndarray] = None, + tolerance: float = 10.0 +) -> Dict[str, Union[float, List[float]]]: + """ + Calculate a comprehensive set of metrics for particle detection. 
+ + Args: + detected: array of detected particle coordinates (N, 3) or (N, 2) + ground_truth: array of ground truth particle coordinates (M, 3) or (M, 2) + scores: confidence scores for each detection (N,) (optional) + tolerance: maximum distance (in pixels/voxels) for a detection to be considered correct + + Returns: + dict: dictionary of metrics including precision, recall, F1, and AP + """ + metrics = {} + + # Basic metrics + precision, recall, f1 = calculate_precision_recall_f1(detected, ground_truth, tolerance) + metrics["precision"] = precision + metrics["recall"] = recall + metrics["f1_score"] = f1 + + # Absolute counts + metrics["num_detections"] = detected.shape[0] + metrics["num_ground_truth"] = ground_truth.shape[0] + + # Calculate true positives, false positives, and false negatives + if detected.shape[0] > 0 and ground_truth.shape[0] > 0: + distances = calculate_distances(detected, ground_truth) + matched_gt = set() + true_positives = 0 + + for i in range(distances.shape[0]): + min_idx = np.argmin(distances[i]) + min_dist = distances[i, min_idx] + + if min_dist <= tolerance and min_idx not in matched_gt: + true_positives += 1 + matched_gt.add(min_idx) + + false_positives = detected.shape[0] - true_positives + false_negatives = ground_truth.shape[0] - true_positives + + metrics["true_positives"] = true_positives + metrics["false_positives"] = false_positives + metrics["false_negatives"] = false_negatives + + # Calculate average precision if scores are provided + if scores is not None and scores.shape[0] == detected.shape[0]: + ap, precision_values, recall_values, thresholds = calculate_average_precision( + detected, scores, ground_truth, tolerance + ) + metrics["average_precision"] = ap + metrics["precision_values"] = precision_values + metrics["recall_values"] = recall_values + metrics["thresholds"] = thresholds + + return metrics + + +def calculate_dog_detector_metrics( + volume: np.ndarray, + ground_truth: np.ndarray, + detector, + tolerance: float = 10.0 +) -> Dict[str, Union[float, List[float]]]: + """ + Calculate metrics for DoG detector by running the detector on the volume. + + Args: + volume: 3D volume to detect particles in + ground_truth: array of ground truth particle coordinates (M, 3) + detector: DoG detector instance + tolerance: maximum distance (in pixels/voxels) for a detection to be considered correct + + Returns: + dict: dictionary of metrics including precision, recall, F1 + """ + # Run detector on volume + detected, scores = detector.detect(volume, return_scores=True) + + # Calculate metrics + metrics = calculate_detector_metrics( + detected, ground_truth, scores=scores, tolerance=tolerance + ) + + return metrics diff --git a/examples/dog_detector_example.py b/examples/dog_detector_example.py new file mode 100644 index 0000000..a9915cb --- /dev/null +++ b/examples/dog_detector_example.py @@ -0,0 +1,200 @@ +""" +Example script demonstrating the Difference of Gaussian (DoG) particle detector. + +This script shows how to use the DoG detector for particle picking in cryoET tomograms. +""" + +import os +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +import logging + +from copick_torch.detectors.dog_detector import DoGParticleDetector +from copick_torch.metrics import calculate_detector_metrics +import copick + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def plot_results(tomogram, picks, ground_truth=None, slice_idx=None): + """ + Plot detection results on a tomogram slice. 
+ + Args: + tomogram: 3D numpy array of the tomogram + picks: detected particle coordinates (N, 3) + ground_truth: optional ground truth particle coordinates (M, 3) + slice_idx: optional slice index to plot, if None will use middle slice + """ + if slice_idx is None: + # Use middle slice + slice_idx = tomogram.shape[0] // 2 + + # Plot the slice + plt.figure(figsize=(10, 8)) + plt.imshow(tomogram[slice_idx], cmap='gray') + + # Find particles in this slice (within +/- 3 slices) + z_min, z_max = slice_idx - 3, slice_idx + 3 + + # Plot detected particles + if picks is not None and len(picks) > 0: + picks_in_slice = picks[(picks[:, 0] >= z_min) & (picks[:, 0] <= z_max)] + if len(picks_in_slice) > 0: + plt.scatter(picks_in_slice[:, 2], picks_in_slice[:, 1], + s=30, c='red', marker='o', alpha=0.7, label=f'Detected ({len(picks_in_slice)})') + + # Plot ground truth if provided + if ground_truth is not None and len(ground_truth) > 0: + gt_in_slice = ground_truth[(ground_truth[:, 0] >= z_min) & (ground_truth[:, 0] <= z_max)] + if len(gt_in_slice) > 0: + plt.scatter(gt_in_slice[:, 2], gt_in_slice[:, 1], + s=30, c='blue', marker='x', alpha=0.7, label=f'Ground Truth ({len(gt_in_slice)})') + + plt.title(f'Tomogram Slice {slice_idx}') + plt.colorbar(label='Intensity') + if (picks is not None and len(picks) > 0) or (ground_truth is not None and len(ground_truth) > 0): + plt.legend() + plt.tight_layout() + +def main(): + """Main function demonstrating the DoG detector.""" + # Check if cache directory exists, if not create it + cache_dir = Path('cache') + cache_dir.mkdir(exist_ok=True) + + # Define path to the copick project configuration + config_path = './examples/czii_object_detection_training.json' + + # Check if the config exists, if not, print instructions + if not os.path.exists(config_path): + logger.error(f"Config file {config_path} not found!") + logger.info("Please create a copick configuration file or specify an existing one.") + return + + logger.info(f"Loading copick project from {config_path}") + root = copick.from_file(config_path) + + # Get the first run + if not root.runs: + logger.error("No runs found in the copick project!") + return + + run = root.runs[0] + logger.info(f"Using run: {run.name}") + + # Get the first voxel spacing + if not run.voxel_spacings: + logger.error("No voxel spacings found in the run!") + return + + voxel_spacing = run.voxel_spacings[0] + logger.info(f"Using voxel spacing: {voxel_spacing.voxel_size}") + + # Get the first tomogram + if not voxel_spacing.tomograms: + logger.error("No tomograms found for this voxel spacing!") + return + + tomogram = voxel_spacing.tomograms[0] + logger.info(f"Using tomogram: {tomogram.tomo_type}") + + # Load the tomogram + logger.info("Loading tomogram data...") + tomogram_array = tomogram.numpy() + + # Create the DoG detector + logger.info("Creating DoG detector...") + detector = DoGParticleDetector( + sigma1=1.0, + sigma2=3.0, + threshold_abs=0.1, + min_distance=5, + normalize=True, + prefilter="median", + prefilter_size=1.0 + ) + + # Detect particles + logger.info("Detecting particles...") + picks, scores = detector.detect(tomogram_array, return_scores=True) + logger.info(f"Detected {len(picks)} particles") + + # Get ground truth picks if available + picks_sets = run.get_picks() + ground_truth = None + + if picks_sets: + logger.info(f"Found {len(picks_sets)} pick sets") + # Combine all pick sets + all_picks = [] + for pick_set in picks_sets: + try: + points, _ = pick_set.numpy() + all_picks.append(points) + except Exception as e: + 
logger.error(f"Error loading picks: {e}") + + if all_picks: + ground_truth = np.vstack(all_picks) + logger.info(f"Loaded {len(ground_truth)} ground truth picks") + + # Calculate metrics + logger.info("Calculating metrics...") + metrics = calculate_detector_metrics( + picks, ground_truth, scores, tolerance=10.0 + ) + + logger.info(f"Precision: {metrics['precision']:.3f}") + logger.info(f"Recall: {metrics['recall']:.3f}") + logger.info(f"F1 Score: {metrics['f1_score']:.3f}") + logger.info(f"Average Precision: {metrics.get('average_precision', 'N/A')}") + + # Plot results + logger.info("Plotting results...") + plot_results(tomogram_array, picks, ground_truth) + + # Save figure + plt.savefig('dog_detector_results.png') + logger.info("Saved figure to dog_detector_results.png") + + # Optimize detector parameters if ground truth is available + if ground_truth is not None and len(ground_truth) > 0: + logger.info("Optimizing detector parameters...") + best_params = detector.optimize_parameters( + tomogram_array, + ground_truth, + sigma1_range=(0.5, 2.0, 0.5), + sigma2_range=(1.5, 4.0, 0.5), + threshold_range=(0.05, 0.2, 0.05), + min_distance_range=(3, 7, 2), + tolerance=10.0 + ) + + logger.info("Best parameters:") + for key, value in best_params.items(): + logger.info(f" {key}: {value}") + + # Update detector with best parameters + detector.sigma1 = best_params['sigma1'] + detector.sigma2 = best_params['sigma2'] + detector.threshold_abs = best_params['threshold_abs'] + detector.min_distance = best_params['min_distance'] + + # Detect particles with optimized parameters + logger.info("Detecting particles with optimized parameters...") + picks_opt, scores_opt = detector.detect(tomogram_array, return_scores=True) + logger.info(f"Detected {len(picks_opt)} particles") + + # Plot results with optimized parameters + logger.info("Plotting results with optimized parameters...") + plot_results(tomogram_array, picks_opt, ground_truth) + + # Save figure + plt.savefig('dog_detector_optimized_results.png') + logger.info("Saved figure to dog_detector_optimized_results.png") + +if __name__ == "__main__": + main() diff --git a/examples/monai_detector_example.py b/examples/monai_detector_example.py new file mode 100644 index 0000000..7ab317c --- /dev/null +++ b/examples/monai_detector_example.py @@ -0,0 +1,417 @@ +""" +Example script demonstrating the MONAI-based particle detector. + +This script shows how to use the MONAI-based detector for particle picking in cryoET tomograms. +""" + +import os +import numpy as np +import matplotlib.pyplot as plt +import torch +from pathlib import Path +import logging +import argparse + +from copick_torch.detectors.monai_detector import MONAIParticleDetector +from copick_torch.metrics import calculate_detector_metrics +from copick_torch.dataloaders import CryoETDataPortalDataset +import copick + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def plot_results(tomogram, picks, ground_truth=None, slice_idx=None): + """ + Plot detection results on a tomogram slice. 
+ + Args: + tomogram: 3D numpy array of the tomogram + picks: detected particle coordinates (N, 3) + ground_truth: optional ground truth particle coordinates (M, 3) + slice_idx: optional slice index to plot, if None will use middle slice + """ + if slice_idx is None: + # Use middle slice + slice_idx = tomogram.shape[0] // 2 + + # Plot the slice + plt.figure(figsize=(10, 8)) + plt.imshow(tomogram[slice_idx], cmap='gray') + + # Find particles in this slice (within +/- 3 slices) + z_min, z_max = slice_idx - 3, slice_idx + 3 + + # Plot detected particles + if picks is not None and len(picks) > 0: + picks_in_slice = picks[(picks[:, 0] >= z_min) & (picks[:, 0] <= z_max)] + if len(picks_in_slice) > 0: + plt.scatter(picks_in_slice[:, 2], picks_in_slice[:, 1], + s=30, c='red', marker='o', alpha=0.7, label=f'Detected ({len(picks_in_slice)})') + + # Plot ground truth if provided + if ground_truth is not None and len(ground_truth) > 0: + gt_in_slice = ground_truth[(ground_truth[:, 0] >= z_min) & (ground_truth[:, 0] <= z_max)] + if len(gt_in_slice) > 0: + plt.scatter(gt_in_slice[:, 2], gt_in_slice[:, 1], + s=30, c='blue', marker='x', alpha=0.7, label=f'Ground Truth ({len(gt_in_slice)})') + + plt.title(f'Tomogram Slice {slice_idx}') + plt.colorbar(label='Intensity') + if (picks is not None and len(picks) > 0) or (ground_truth is not None and len(ground_truth) > 0): + plt.legend() + plt.tight_layout() + +def train_on_dataset(dataset_ids, overlay_root, output_dir, num_epochs=10, batch_size=4): + """ + Train a particle detector on the specified datasets. + + Args: + dataset_ids: list of dataset IDs from the CryoET Data Portal + overlay_root: root URL for the overlay storage + output_dir: directory to save model weights and outputs + num_epochs: number of epochs to train for + batch_size: batch size for training + """ + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # Create dataset + logger.info(f"Creating dataset for IDs: {dataset_ids}") + dataset = CryoETParticleDataset( + dataset_ids=dataset_ids, + overlay_root=overlay_root, + boxsize=(64, 64, 64), + voxel_spacing=10.0, + include_background=True, + background_ratio=0.5, + min_background_distance=20.0, + cache_dir=os.path.join(output_dir, 'cache') + ) + + # Initialize dataset (this will load data) + logger.info("Initializing dataset...") + dataset.initialize() + + # Create dataloader + logger.info("Creating dataloader...") + dataloader = dataset.get_dataloader(batch_size=batch_size, shuffle=True) + + # Create detector + logger.info("Creating MONAI detector...") + detector = MONAIParticleDetector( + spatial_dims=3, + num_classes=1, + feature_size=32, + anchor_sizes=[(8, 8, 8), (16, 16, 16), (32, 32, 32)], + device="cuda" if torch.cuda.is_available() else "cpu", + sliding_window_size=(64, 64, 64), + sliding_window_batch_size=batch_size, + sliding_window_overlap=0.25, + detection_threshold=0.3, + nms_threshold=0.1, + max_detections_per_volume=1000 + ) + + # Train the detector + logger.info("Training detector...") + detector.train( + train_dataloader=dataloader, + num_epochs=num_epochs, + learning_rate=1e-4, + weight_decay=1e-5, + save_path=os.path.join(output_dir, 'model_weights.pt') + ) + + logger.info(f"Training complete. Model saved to {os.path.join(output_dir, 'model_weights.pt')}") + +def inference_on_dataset(dataset_ids, overlay_root, weights_path, output_dir): + """ + Run inference on the specified datasets. 
+ + Args: + dataset_ids: list of dataset IDs from the CryoET Data Portal + overlay_root: root URL for the overlay storage + weights_path: path to trained model weights + output_dir: directory to save outputs + """ + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # Create dataset + logger.info(f"Creating dataset for IDs: {dataset_ids}") + dataset = CryoETDataPortalDataset( + dataset_ids=dataset_ids, + overlay_root=overlay_root, + voxel_spacing=10.0, + cache_dir=os.path.join(output_dir, 'cache') + ) + + # Create detector + logger.info("Creating MONAI detector...") + detector = MONAIParticleDetector( + spatial_dims=3, + num_classes=1, + feature_size=32, + anchor_sizes=[(8, 8, 8), (16, 16, 16), (32, 32, 32)], + device="cuda" if torch.cuda.is_available() else "cpu", + sliding_window_size=(64, 64, 64), + sliding_window_batch_size=4, + sliding_window_overlap=0.25, + detection_threshold=0.3, + nms_threshold=0.1, + max_detections_per_volume=1000 + ) + + # Load weights + if os.path.exists(weights_path): + logger.info(f"Loading weights from {weights_path}") + detector.load_weights(weights_path) + else: + logger.warning(f"Weights file {weights_path} not found, using untrained model") + + # Process each tomogram in the dataset + all_metrics = [] + + for idx in range(len(dataset)): + # Load tomogram + logger.info(f"Processing tomogram {idx+1}/{len(dataset)}") + tomogram, metadata = dataset[idx] + + # Get ground truth picks + ground_truth = dataset.get_picks_for_tomogram(idx) + + # Run inference + logger.info("Running inference...") + if isinstance(tomogram, torch.Tensor): + tomogram_np = tomogram.numpy() + if tomogram_np.ndim == 4: # Remove channel dimension if present + tomogram_np = tomogram_np[0] + else: + tomogram_np = tomogram + + picks, scores = detector.detect(tomogram_np, return_scores=True, use_inferer=True) + logger.info(f"Detected {len(picks)} particles") + + # Calculate metrics if ground truth is available + if ground_truth is not None and len(ground_truth) > 0: + logger.info(f"Calculating metrics against {len(ground_truth)} ground truth particles") + metrics = calculate_detector_metrics( + picks, ground_truth, scores, tolerance=10.0 + ) + + logger.info(f"Precision: {metrics['precision']:.3f}") + logger.info(f"Recall: {metrics['recall']:.3f}") + logger.info(f"F1 Score: {metrics['f1_score']:.3f}") + logger.info(f"Average Precision: {metrics.get('average_precision', 'N/A')}") + + all_metrics.append(metrics) + + # Plot and save results + logger.info("Plotting results...") + plot_results(tomogram_np, picks, ground_truth) + plt.savefig(os.path.join(output_dir, f'tomo_{idx}_results.png')) + + # Save overall metrics if available + if all_metrics: + # Calculate average metrics + avg_metrics = { + 'precision': np.mean([m['precision'] for m in all_metrics]), + 'recall': np.mean([m['recall'] for m in all_metrics]), + 'f1_score': np.mean([m['f1_score'] for m in all_metrics]), + } + + if 'average_precision' in all_metrics[0]: + avg_metrics['average_precision'] = np.mean([m['average_precision'] for m in all_metrics]) + + # Save to file + with open(os.path.join(output_dir, 'metrics.txt'), 'w') as f: + f.write(f"Average Metrics across {len(all_metrics)} tomograms:\n") + f.write(f"Precision: {avg_metrics['precision']:.3f}\n") + f.write(f"Recall: {avg_metrics['recall']:.3f}\n") + f.write(f"F1 Score: {avg_metrics['f1_score']:.3f}\n") + if 'average_precision' in avg_metrics: + f.write(f"Average Precision: {avg_metrics['average_precision']:.3f}\n") + +def inference_on_file(config_path, 
weights_path, output_dir): + """ + Run inference on a local copick configuration. + + Args: + config_path: path to the copick configuration file + weights_path: path to trained model weights + output_dir: directory to save outputs + """ + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # Load copick project + logger.info(f"Loading copick project from {config_path}") + root = copick.from_file(config_path) + + # Create detector + logger.info("Creating MONAI detector...") + detector = MONAIParticleDetector( + spatial_dims=3, + num_classes=1, + feature_size=32, + anchor_sizes=[(8, 8, 8), (16, 16, 16), (32, 32, 32)], + device="cuda" if torch.cuda.is_available() else "cpu", + sliding_window_size=(64, 64, 64), + sliding_window_batch_size=4, + sliding_window_overlap=0.25, + detection_threshold=0.3, + nms_threshold=0.1, + max_detections_per_volume=1000 + ) + + # Load weights + if os.path.exists(weights_path): + logger.info(f"Loading weights from {weights_path}") + detector.load_weights(weights_path) + else: + logger.warning(f"Weights file {weights_path} not found, using untrained model") + + # Process each run in the project + for run_idx, run in enumerate(root.runs): + logger.info(f"Processing run {run_idx+1}/{len(root.runs)}: {run.name}") + + # Get all voxel spacings + for vs_idx, vs in enumerate(run.voxel_spacings): + logger.info(f"Processing voxel spacing {vs_idx+1}/{len(run.voxel_spacings)}: {vs.voxel_size}") + + # Get all tomograms + for tomo_idx, tomo in enumerate(vs.tomograms): + logger.info(f"Processing tomogram {tomo_idx+1}/{len(vs.tomograms)}: {tomo.tomo_type}") + + # Load the tomogram + try: + logger.info("Loading tomogram data...") + tomogram_array = tomo.numpy() + except Exception as e: + logger.error(f"Error loading tomogram: {e}") + continue + + # Run inference + logger.info("Running inference...") + picks, scores = detector.detect(tomogram_array, return_scores=True, use_inferer=True) + logger.info(f"Detected {len(picks)} particles") + + # Get ground truth if available + ground_truth = None + picks_sets = run.get_picks() + + if picks_sets: + # Combine all pick sets + all_picks = [] + for pick_set in picks_sets: + try: + points, _ = pick_set.numpy() + all_picks.append(points) + except Exception as e: + logger.error(f"Error loading picks: {e}") + + if all_picks: + ground_truth = np.vstack(all_picks) + logger.info(f"Loaded {len(ground_truth)} ground truth picks") + + # Calculate metrics + logger.info("Calculating metrics...") + metrics = calculate_detector_metrics( + picks, ground_truth, scores, tolerance=10.0 + ) + + logger.info(f"Precision: {metrics['precision']:.3f}") + logger.info(f"Recall: {metrics['recall']:.3f}") + logger.info(f"F1 Score: {metrics['f1_score']:.3f}") + logger.info(f"Average Precision: {metrics.get('average_precision', 'N/A')}") + + # Save metrics + metrics_path = os.path.join(output_dir, f'run{run_idx}_vs{vs_idx}_tomo{tomo_idx}_metrics.txt') + with open(metrics_path, 'w') as f: + for key, value in metrics.items(): + if not isinstance(value, list): + f.write(f"{key}: {value}\n") + + # Plot and save results + logger.info("Plotting results...") + plot_results(tomogram_array, picks, ground_truth) + plt.savefig(os.path.join(output_dir, f'run{run_idx}_vs{vs_idx}_tomo{tomo_idx}_results.png')) + + # Save the picks + picks_path = os.path.join(output_dir, f'run{run_idx}_vs{vs_idx}_tomo{tomo_idx}_picks.npy') + np.save(picks_path, picks) + logger.info(f"Saved picks to {picks_path}") + +def main(): + """Main function.""" + parser = 
argparse.ArgumentParser(description="MONAI Particle Detector Example") + + # Create subparsers for different modes + subparsers = parser.add_subparsers(dest='mode', help='Mode to run') + + # Train mode + train_parser = subparsers.add_parser('train', help='Train a model on a dataset') + train_parser.add_argument('--dataset-ids', type=int, nargs='+', required=True, + help='Dataset IDs from the CryoET Data Portal') + train_parser.add_argument('--overlay-root', type=str, required=True, + help='Root URL for the overlay storage') + train_parser.add_argument('--output-dir', type=str, default='./output', + help='Directory to save outputs') + train_parser.add_argument('--epochs', type=int, default=10, + help='Number of epochs to train for') + train_parser.add_argument('--batch-size', type=int, default=4, + help='Batch size for training') + + # Inference on dataset mode + inference_dataset_parser = subparsers.add_parser('inference-dataset', + help='Run inference on a dataset') + inference_dataset_parser.add_argument('--dataset-ids', type=int, nargs='+', required=True, + help='Dataset IDs from the CryoET Data Portal') + inference_dataset_parser.add_argument('--overlay-root', type=str, required=True, + help='Root URL for the overlay storage') + inference_dataset_parser.add_argument('--weights', type=str, required=True, + help='Path to trained model weights') + inference_dataset_parser.add_argument('--output-dir', type=str, default='./output', + help='Directory to save outputs') + + # Inference on file mode + inference_file_parser = subparsers.add_parser('inference-file', + help='Run inference on a local copick configuration') + inference_file_parser.add_argument('--config', type=str, required=True, + help='Path to the copick configuration file') + inference_file_parser.add_argument('--weights', type=str, required=True, + help='Path to trained model weights') + inference_file_parser.add_argument('--output-dir', type=str, default='./output', + help='Directory to save outputs') + + # Parse arguments + args = parser.parse_args() + + # Run the specified mode + if args.mode == 'train': + train_on_dataset( + dataset_ids=args.dataset_ids, + overlay_root=args.overlay_root, + output_dir=args.output_dir, + num_epochs=args.epochs, + batch_size=args.batch_size + ) + elif args.mode == 'inference-dataset': + inference_on_dataset( + dataset_ids=args.dataset_ids, + overlay_root=args.overlay_root, + weights_path=args.weights, + output_dir=args.output_dir + ) + elif args.mode == 'inference-file': + inference_on_file( + config_path=args.config, + weights_path=args.weights, + output_dir=args.output_dir + ) + else: + parser.print_help() + +if __name__ == "__main__": + main() diff --git a/tests/test_dog_detector.py b/tests/test_dog_detector.py new file mode 100644 index 0000000..394c3e8 --- /dev/null +++ b/tests/test_dog_detector.py @@ -0,0 +1,121 @@ +import os +import unittest +import numpy as np +from unittest.mock import patch, MagicMock + +from copick_torch.detectors.dog_detector import DoGParticleDetector + +class TestDoGDetector(unittest.TestCase): + def setUp(self): + # Create a simple 3D test volume + self.volume = np.zeros((64, 64, 64), dtype=np.float32) + + # Add some particle-like features + self.particle_positions = [ + (20, 20, 20), + (40, 40, 40), + (20, 40, 20), + (40, 20, 40) + ] + + # Create Gaussian-like particles at specified positions + for pos in self.particle_positions: + z, y, x = pos + # Create a small Gaussian blob + z_grid, y_grid, x_grid = np.mgrid[z-5:z+5, y-5:y+5, x-5:x+5] + dist_sq = (z_grid - 
z)**2 + (y_grid - y)**2 + (x_grid - x)**2 + # Add Gaussian blob to the volume + self.volume[z-5:z+5, y-5:y+5, x-5:x+5] += np.exp(-dist_sq / 8.0) + + # Initialize the detector + self.detector = DoGParticleDetector( + sigma1=1.0, + sigma2=3.0, + threshold_abs=0.1, + min_distance=5 + ) + + def test_init(self): + """Test detector initialization.""" + self.assertEqual(self.detector.sigma1, 1.0) + self.assertEqual(self.detector.sigma2, 3.0) + self.assertEqual(self.detector.threshold_abs, 0.1) + self.assertEqual(self.detector.min_distance, 5) + + def test_detect(self): + """Test basic particle detection.""" + # Detect particles + peaks = self.detector.detect(self.volume) + + # Check if we get some peaks + self.assertGreater(len(peaks), 0) + + # Check that peaks array has correct shape + self.assertEqual(peaks.shape[1], 3) # Each peak should have (z, y, x) coordinates + + def test_detect_with_scores(self): + """Test particle detection with scores.""" + # Detect particles and get scores + peaks, scores = self.detector.detect(self.volume, return_scores=True) + + # Check if we get some peaks + self.assertGreater(len(peaks), 0) + + # Check that peaks and scores have same length + self.assertEqual(len(peaks), len(scores)) + + def test_detect_multiscale(self): + """Test multiscale particle detection.""" + # Define multiple scales + sigma_pairs = [(1.0, 2.0), (2.0, 4.0)] + + # Detect particles at multiple scales + peaks = self.detector.detect_multiscale(self.volume, sigma_pairs) + + # Check if we get some peaks + self.assertGreater(len(peaks), 0) + + def test_optimize_parameters(self): + """Test parameter optimization.""" + # Create a mock optimize function to avoid long computation + with patch.object(self.detector, '_calculate_metrics', return_value=(0.8, 0.7, 0.75)): + # Run optimization with reduced parameter space + result = self.detector.optimize_parameters( + self.volume, + np.array(self.particle_positions), + sigma1_range=(1.0, 1.5, 0.5), + sigma2_range=(2.0, 2.5, 0.5), + threshold_range=(0.1, 0.2, 0.1), + min_distance_range=(5, 6, 1) + ) + + # Check if we get a result + self.assertIsInstance(result, dict) + self.assertIn('f1', result) + + def test_metrics_calculation(self): + """Test metrics calculation.""" + # Create detected peaks close to ground truth + detected = np.array([ + [21, 21, 21], # Close to first particle + [41, 41, 41], # Close to second particle + [30, 30, 30] # False positive + ]) + + ground_truth = np.array(self.particle_positions) + + # Calculate metrics + precision, recall, f1 = self.detector._calculate_metrics( + detected, ground_truth, tolerance=5.0 + ) + + # Check reasonable metrics + self.assertGreaterEqual(precision, 0.0) + self.assertLessEqual(precision, 1.0) + self.assertGreaterEqual(recall, 0.0) + self.assertLessEqual(recall, 1.0) + self.assertGreaterEqual(f1, 0.0) + self.assertLessEqual(f1, 1.0) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..8e048d5 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,168 @@ +import unittest +import numpy as np + +from copick_torch.metrics import ( + calculate_distances, + calculate_precision_recall_f1, + calculate_average_precision, + calculate_detector_metrics +) + +class TestMetrics(unittest.TestCase): + def setUp(self): + # Create test data + self.detected = np.array([ + [10, 10, 10], + [30, 30, 30], + [50, 50, 50], + [70, 70, 70] + ]) + + self.ground_truth = np.array([ + [12, 12, 12], # Close to first detection + [32, 32, 32], # Close 
to second detection + [90, 90, 90] # No matching detection + ]) + + self.scores = np.array([0.9, 0.8, 0.7, 0.6]) + + def test_calculate_distances(self): + """Test distance calculation between detected and ground truth particles.""" + distances = calculate_distances(self.detected, self.ground_truth) + + # Check shape + self.assertEqual(distances.shape, (4, 3)) + + # Check a few specific distances + # Distance between [10,10,10] and [12,12,12] should be sqrt(12) + self.assertAlmostEqual(distances[0, 0], np.sqrt(12), delta=1e-6) + + # Distance between [30,30,30] and [32,32,32] should be sqrt(12) + self.assertAlmostEqual(distances[1, 1], np.sqrt(12), delta=1e-6) + + # Distance between [70,70,70] and [90,90,90] should be sqrt(1200) + self.assertAlmostEqual(distances[3, 2], np.sqrt(1200), delta=1e-6) + + def test_precision_recall_f1(self): + """Test precision, recall, and F1 calculation.""" + # With tolerance 5.0, only the first two detections should match + precision, recall, f1 = calculate_precision_recall_f1( + self.detected, self.ground_truth, tolerance=5.0 + ) + + # Expected results + # Precision = 2/4 = 0.5 (2 true positives out of 4 detections) + # Recall = 2/3 = 0.667 (2 true positives out of 3 ground truth) + # F1 = 2 * 0.5 * 0.667 / (0.5 + 0.667) = 0.572 + self.assertAlmostEqual(precision, 0.5, delta=1e-6) + self.assertAlmostEqual(recall, 2/3, delta=1e-6) + self.assertAlmostEqual(f1, 2 * 0.5 * (2/3) / (0.5 + 2/3), delta=1e-6) + + # With tolerance 50.0, three detections should match (leaving one ground truth unmatched) + precision, recall, f1 = calculate_precision_recall_f1( + self.detected, self.ground_truth, tolerance=50.0 + ) + + # Expected results + # Precision = 3/4 = 0.75 (3 true positives out of 4 detections) + # Recall = 3/3 = 1.0 (3 true positives out of 3 ground truth) + # F1 = 2 * 0.75 * 1.0 / (0.75 + 1.0) = 0.857 + self.assertAlmostEqual(precision, 0.75, delta=1e-6) + self.assertAlmostEqual(recall, 1.0, delta=1e-6) + self.assertAlmostEqual(f1, 2 * 0.75 * 1.0 / (0.75 + 1.0), delta=1e-6) + + def test_average_precision(self): + """Test average precision calculation.""" + # With tolerance 5.0, only the first two detections should match + ap, precision_values, recall_values, thresholds = calculate_average_precision( + self.detected, self.scores, self.ground_truth, tolerance=5.0 + ) + + # Check that values are returned + self.assertTrue(isinstance(ap, float)) + self.assertTrue(isinstance(precision_values, list)) + self.assertTrue(isinstance(recall_values, list)) + self.assertTrue(isinstance(thresholds, list)) + + # Check that values are in correct range + self.assertGreaterEqual(ap, 0.0) + self.assertLessEqual(ap, 1.0) + for p in precision_values: + self.assertGreaterEqual(p, 0.0) + self.assertLessEqual(p, 1.0) + for r in recall_values: + self.assertGreaterEqual(r, 0.0) + self.assertLessEqual(r, 1.0) + + def test_detector_metrics(self): + """Test comprehensive detector metrics calculation.""" + # Calculate metrics + metrics = calculate_detector_metrics( + self.detected, self.ground_truth, self.scores, tolerance=5.0 + ) + + # Check that expected keys are present + expected_keys = [ + "precision", "recall", "f1_score", + "num_detections", "num_ground_truth", + "true_positives", "false_positives", "false_negatives", + "average_precision", "precision_values", "recall_values", "thresholds" + ] + for key in expected_keys: + self.assertIn(key, metrics) + + # Check specific values + self.assertEqual(metrics["num_detections"], 4) + self.assertEqual(metrics["num_ground_truth"], 3) + 
self.assertEqual(metrics["true_positives"], 2) + self.assertEqual(metrics["false_positives"], 2) + self.assertEqual(metrics["false_negatives"], 1) + + def test_empty_detections(self): + """Test metrics with empty detections.""" + empty_detected = np.zeros((0, 3)) + + # Calculate metrics + precision, recall, f1 = calculate_precision_recall_f1( + empty_detected, self.ground_truth, tolerance=5.0 + ) + + # With no detections, precision is undefined (set to 0), recall is 0 + self.assertEqual(precision, 0.0) + self.assertEqual(recall, 0.0) + self.assertEqual(f1, 0.0) + + def test_empty_ground_truth(self): + """Test metrics with empty ground truth.""" + empty_ground_truth = np.zeros((0, 3)) + + # Calculate metrics + precision, recall, f1 = calculate_precision_recall_f1( + self.detected, empty_ground_truth, tolerance=5.0 + ) + + # With no ground truth, precision is 0, recall is undefined (set to 0) + self.assertEqual(precision, 0.0) + self.assertEqual(recall, 0.0) + self.assertEqual(f1, 0.0) + + def test_detector_metrics_no_scores(self): + """Test metrics calculation without confidence scores.""" + # Calculate metrics without scores + metrics = calculate_detector_metrics( + self.detected, self.ground_truth, tolerance=5.0 + ) + + # Check that AP-related keys are not present + self.assertNotIn("average_precision", metrics) + self.assertNotIn("precision_values", metrics) + self.assertNotIn("recall_values", metrics) + self.assertNotIn("thresholds", metrics) + + # Check that other metrics are still calculated + self.assertIn("precision", metrics) + self.assertIn("recall", metrics) + self.assertIn("f1_score", metrics) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_monai_detector.py b/tests/test_monai_detector.py new file mode 100644 index 0000000..0aca852 --- /dev/null +++ b/tests/test_monai_detector.py @@ -0,0 +1,203 @@ +import os +import unittest +import numpy as np +import torch +from unittest.mock import patch, MagicMock + +from copick_torch.detectors.monai_detector import MONAIParticleDetector + +class TestMONAIDetector(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Skip if CUDA is not available + if not torch.cuda.is_available(): + cls.device = "cpu" + else: + cls.device = "cuda" + + def setUp(self): + # Create a simple 3D test volume + self.volume = np.zeros((64, 64, 64), dtype=np.float32) + + # Add some particle-like features + self.particle_positions = [ + (20, 20, 20), + (40, 40, 40), + (20, 40, 20), + (40, 20, 40) + ] + + # Create Gaussian-like particles at specified positions + for pos in self.particle_positions: + z, y, x = pos + # Create a small Gaussian blob + z_grid, y_grid, x_grid = np.mgrid[z-5:z+5, y-5:y+5, x-5:x+5] + dist_sq = (z_grid - z)**2 + (y_grid - y)**2 + (x_grid - x)**2 + # Add Gaussian blob to the volume + self.volume[z-5:z+5, y-5:y+5, x-5:x+5] += np.exp(-dist_sq / 8.0) + + def test_init(self): + """Test detector initialization with minimal configuration.""" + try: + detector = MONAIParticleDetector( + spatial_dims=3, + num_classes=1, + device="cpu" # Use CPU for testing + ) + self.assertIsNotNone(detector) + except Exception as e: + self.fail(f"Detector initialization failed with exception: {e}") + + def test_detector_attributes(self): + """Test that detector has the expected attributes.""" + detector = MONAIParticleDetector( + spatial_dims=3, + num_classes=1, + device="cpu" + ) + + # Check attributes + self.assertEqual(detector.spatial_dims, 3) + self.assertEqual(detector.num_classes, 1) + self.assertEqual(detector.device, "cpu") + + 
# Check that MONAI detector is created + self.assertIsNotNone(detector.detector) + + @unittest.skipIf(torch.cuda.is_available() == False, "CUDA not available") + def test_cuda_detection(self): + """Test detection with CUDA if available.""" + detector = MONAIParticleDetector( + spatial_dims=3, + num_classes=1, + device="cuda", + detection_threshold=0.1 # Lower threshold for test + ) + + # Mock the detector forward method to return predefined detections + mock_detections = [{ + "boxes": torch.tensor([[15, 15, 15, 25, 25, 25], + [35, 35, 35, 45, 45, 45]]).cuda(), + "labels": torch.tensor([0, 0]).cuda(), + "labels_scores": torch.tensor([0.9, 0.8]).cuda() + }] + with patch.object(detector.detector, 'forward', return_value=mock_detections): + # Convert volume to tensor + volume_tensor = torch.from_numpy(self.volume).unsqueeze(0).cuda() # Add channel dimension + + # Detect particles + coords, scores = detector.detect(volume_tensor, return_scores=True) + + # Check shape and content + self.assertEqual(coords.shape, (2, 3)) # 2 particles, 3D coordinates + self.assertEqual(scores.shape, (2,)) # 2 confidence scores + + # Check coordinate calculation + np.testing.assert_allclose(coords[0], [20, 20, 20], atol=1.0) # Center of first box + np.testing.assert_allclose(coords[1], [40, 40, 40], atol=1.0) # Center of second box + + def test_cpu_detection(self): + """Test detection on CPU.""" + detector = MONAIParticleDetector( + spatial_dims=3, + num_classes=1, + device="cpu", + detection_threshold=0.1 # Lower threshold for test + ) + + # Mock the detector forward method to return predefined detections + mock_detections = [{ + "boxes": torch.tensor([[15, 15, 15, 25, 25, 25], + [35, 35, 35, 45, 45, 45]]), + "labels": torch.tensor([0, 0]), + "labels_scores": torch.tensor([0.9, 0.8]) + }] + with patch.object(detector.detector, 'forward', return_value=mock_detections): + # Detect particles + coords = detector.detect(self.volume) + + # Check shape + self.assertEqual(coords.shape, (2, 3)) # 2 particles, 3D coordinates + + # Check coordinate calculation + np.testing.assert_allclose(coords[0], [20, 20, 20], atol=1.0) # Center of first box + np.testing.assert_allclose(coords[1], [40, 40, 40], atol=1.0) # Center of second box + + def test_2d_detection(self): + """Test 2D detection.""" + # Create a 2D test image + image = np.zeros((64, 64), dtype=np.float32) + + # Add some particle-like features + particle_positions_2d = [(20, 20), (40, 40)] + + for pos in particle_positions_2d: + y, x = pos + # Create a small Gaussian blob + y_grid, x_grid = np.mgrid[y-5:y+5, x-5:x+5] + dist_sq = (y_grid - y)**2 + (x_grid - x)**2 + # Add Gaussian blob to the image + image[y-5:y+5, x-5:x+5] += np.exp(-dist_sq / 8.0) + + # Create 2D detector + detector = MONAIParticleDetector( + spatial_dims=2, + num_classes=1, + device="cpu", + detection_threshold=0.1 # Lower threshold for test + ) + + # Mock the detector forward method to return predefined detections + mock_detections = [{ + "boxes": torch.tensor([[15, 15, 25, 25], + [35, 35, 45, 45]]), + "labels": torch.tensor([0, 0]), + "labels_scores": torch.tensor([0.9, 0.8]) + }] + with patch.object(detector.detector, 'forward', return_value=mock_detections): + # Detect particles + coords = detector.detect(image) + + # Check shape + self.assertEqual(coords.shape, (2, 2)) # 2 particles, 2D coordinates + + # Check coordinate calculation + np.testing.assert_allclose(coords[0], [20, 20], atol=1.0) # Center of first box + np.testing.assert_allclose(coords[1], [40, 40], atol=1.0) # Center of second box + 
+ def test_save_load_weights(self): + """Test saving and loading weights.""" + import tempfile + + # Create temporary file + with tempfile.NamedTemporaryFile(suffix='.pt') as tmp: + # Create detector + detector = MONAIParticleDetector( + spatial_dims=3, + num_classes=1, + device="cpu" + ) + + # Save weights + detector.save_weights(tmp.name) + + # Check that file exists and has content + self.assertTrue(os.path.exists(tmp.name)) + self.assertGreater(os.path.getsize(tmp.name), 0) + + # Create a new detector + detector2 = MONAIParticleDetector( + spatial_dims=3, + num_classes=1, + device="cpu" + ) + + # Load weights + detector2.load_weights(tmp.name) + + # Check that weights are loaded (this is a basic check, not comprehensive) + # Here we just check that loading doesn't raise an exception + self.assertTrue(True) + +if __name__ == '__main__': + unittest.main() From dc00d83ce7ce8d81e419523218df07a2f816d815 Mon Sep 17 00:00:00 2001 From: Kyle Harrington Date: Sun, 6 Apr 2025 21:52:34 -0400 Subject: [PATCH 2/2] Fix CI test failures - Fix ResNetBasicBlock -> ResNetBlock in MONAI detector - Add compatibility for older versions of scikit-image in DoG detector by handling peak_local_max without the indices parameter --- copick_torch/detectors/dog_detector.py | 25 +++++++++++++++++------- copick_torch/detectors/monai_detector.py | 2 +- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/copick_torch/detectors/dog_detector.py b/copick_torch/detectors/dog_detector.py index 88cc63b..1a41ea3 100644 --- a/copick_torch/detectors/dog_detector.py +++ b/copick_torch/detectors/dog_detector.py @@ -100,13 +100,24 @@ def detect(self, volume: np.ndarray, return_scores: bool = False) -> Union[np.nd # Find local maxima self.logger.info(f"Finding peaks with min_distance={self.min_distance}, threshold={self.threshold_abs}") - peaks = peak_local_max( - dog_vol, - min_distance=self.min_distance, - threshold_abs=self.threshold_abs, - exclude_border=self.exclude_border, - indices=True - ) + # Handle peak_local_max with compatibility for different scikit-image versions + try: + # Newer versions of scikit-image + peaks = peak_local_max( + dog_vol, + min_distance=self.min_distance, + threshold_abs=self.threshold_abs, + exclude_border=self.exclude_border, + indices=True + ) + except TypeError: + # Older versions of scikit-image don't have the indices parameter (it's always True) + peaks = peak_local_max( + dog_vol, + min_distance=self.min_distance, + threshold_abs=self.threshold_abs, + exclude_border=self.exclude_border + ) # Return peak coordinates and peak values if requested if return_scores: diff --git a/copick_torch/detectors/monai_detector.py b/copick_torch/detectors/monai_detector.py index fc2744a..6e9bd55 100644 --- a/copick_torch/detectors/monai_detector.py +++ b/copick_torch/detectors/monai_detector.py @@ -158,7 +158,7 @@ def _create_detector(self, spatial_dims, num_classes, feature_size, pretrained): # Use ResNet18 as the backbone for faster inference backbone = resnet.ResNet( spatial_dims=spatial_dims, - block=resnet.ResNetBasicBlock, + block=resnet.ResNetBlock, layers=[2, 2, 2, 2], # ResNet18 architecture block_inplanes=resnet.get_inplanes(), n_input_channels=1, # Single channel for tomogram data
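
For quick reference, a minimal end-to-end sketch of the DoG detector and evaluation metrics added by this patch (assuming copick_torch is installed with these new modules; the synthetic volume and ground-truth coordinates below are illustrative only, not taken from any dataset):

    import numpy as np
    from copick_torch import DoGParticleDetector
    from copick_torch.metrics import calculate_detector_metrics

    # Build a tiny synthetic volume with two Gaussian blobs standing in for particles.
    volume = np.zeros((64, 64, 64), dtype=np.float32)
    ground_truth = np.array([[20, 20, 20], [40, 40, 40]], dtype=float)
    for z, y, x in ground_truth.astype(int):
        zg, yg, xg = np.mgrid[z - 5:z + 5, y - 5:y + 5, x - 5:x + 5]
        volume[z - 5:z + 5, y - 5:y + 5, x - 5:x + 5] += np.exp(
            -((zg - z) ** 2 + (yg - y) ** 2 + (xg - x) ** 2) / 8.0
        )

    # Detect particles (z, y, x coordinates plus confidence scores) and
    # score them against the known blob centers.
    detector = DoGParticleDetector(sigma1=1.0, sigma2=3.0, threshold_abs=0.1, min_distance=5)
    picks, scores = detector.detect(volume, return_scores=True)
    metrics = calculate_detector_metrics(picks, ground_truth, scores, tolerance=5.0)
    print(metrics["precision"], metrics["recall"], metrics["f1_score"])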