diff --git a/requirements.txt b/requirements.txt index c392a4f..7d74a5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ scikit-learn>=0.24.0 pandas>=1.3.0 opencv-python>=4.5.0 Pillow>=8.3.0 +numba==0.61.0 # Deep learning torch>=1.9.0 @@ -18,6 +19,7 @@ albumentations>=1.0.0 scikit-learn-extra>=0.2.0 connected-components-3d>=3.0.0 SimpleITK>=2.1.0 +cellpose ==3.1.1.1 # Development and utilities jupyterlab>=3.0.0 @@ -26,9 +28,9 @@ black>=21.6b0 isort>=5.9.0 flake8>=3.9.0 tqdm>=4.64.0 +pyyaml==6.0.2 # Logging and monitoring loguru>=0.7.0 # Optional GPU acceleration (install with: pip install cupy-cuda12x) # cupy-cuda12x>=13.0.0 # Uncomment if you have CUDA 12.x GPU support -cellpose \ No newline at end of file diff --git a/scripts/do_segmentation.py b/scripts/do_segmentation.py index ecf149b..f50b146 100644 --- a/scripts/do_segmentation.py +++ b/scripts/do_segmentation.py @@ -94,18 +94,17 @@ def main(): else: pool = multiprocessing.Pool(processes=args.workers) """ - # Completed: IF slide_id is provided, use the slide_id and data_dir to load the slides - # TODO: Multiprocessing data loader log.logger.debug("Loading slides...") slides = data_loader.load_slides(args.data_dir) - # TODO: A non offset based composite creation should be implemented in the data loader log.logger.debug("Creating composites...") - composite_images = data_loader.get_composites(slides, config.SLIDE_INDEX_OFFSET) # creaing composites should be preprocessing which is in segmentor + composite_images = cellposeSegmentor.preprocess(slides) log.logger.debug("Running Segmentation...") binary_masks = cellposeSegmentor.segment(composite_images) + image_crops, mask_crops, centers = cellposeSegmentor.postprocess() + log.logger.debug("Saving masks...") cellposeSegmentor.save_masks(binary_masks) diff --git a/src/deep_learning/base.py b/src/deep_learning/base.py index 237b8d8..35555fd 100644 --- a/src/deep_learning/base.py +++ b/src/deep_learning/base.py @@ -2,7 +2,7 @@ Base class for all traditional segmentation algorithms. """ from abc import ABC, abstractmethod -from scipy import ndimage as ndi +import numpy as np class BaseSegmenter(ABC): """Base class that all traditional segmentation algorithms should inherit from.""" @@ -17,7 +17,7 @@ def __init__(self, config=None): self.config = config or {} @abstractmethod - def segment(self, images): + def segment(self, images) -> np.ndarray: """ Segment the input images. @@ -29,7 +29,8 @@ def segment(self, images): """ pass - def preprocess(self, images_dir): # get_composites shouuld be here + @abstractmethod + def preprocess(self, images) -> np.ndarray: # get_composites shouuld be here """ Preprocess the input image before segmentation. @@ -41,15 +42,16 @@ def preprocess(self, images_dir): # get_composites shouuld be here """ pass - def postprocess(self, mask): + @abstractmethod + def postprocess(self, masks=None, images=None) -> list[np.ndarray]: """ - Postprocess the segmentation mask. - - Args: - mask (numpy.ndarray): Segmentation mask to postprocess. - + Postprocess the segmentation mask. Extracts cropped cell images using the segmented masks. + + Arguments: + masks (np.ndarray): Array of segmented masks with shape (N, C, H, W). + images (np.ndarray): Array of original images with shape (N, C, H, W). Returns: - numpy.ndarray: Postprocessed mask. + List[np.ndarray]: List of cropped cell images. """ pass \ No newline at end of file diff --git a/src/deep_learning/cellpose.py b/src/deep_learning/cellpose.py index a8a50fe..66ef54f 100644 --- a/src/deep_learning/cellpose.py +++ b/src/deep_learning/cellpose.py @@ -3,48 +3,187 @@ from pathlib import Path import numpy as np import cv2 -import matplotlib.pyplot as plt # use in debug console from src.deep_learning.base import BaseSegmenter +import os +import numpy as np +import cv2 +import multiprocessing +from .base import BaseSegmenter +from .utils.config import Config +from .utils.loader import load_img +from .utils.image import compute_composite +from .utils.crop import crop_single_image +from .utils.mask import binary_masks import loguru as log class CellposeSegmentor(BaseSegmenter): - def __init__(self, config): + def __init__(self, config: Config): """ Initialize the Cellpose segmentor. This class is a wrapper around the Cellpose deep learning model for image segmentation. It inherits from BaseSegmenter and implements the segment method. + + This class provides functionality for: + - Loading grayscale microscopy images from a directory + - Combining multi-channel scans into composite images + - Running segmentation using Cellpose + - Saving mask outputs + - Extracting cropped cell images from masks + + Attributes: + model (cellpose.models.CellposeModel): The loaded Cellpose model. + config (Config): Configuration object containing paths and settings. + + Methods: + load_images(image_dir): Loads images from a directory using multiprocessing. + combine_images(images): Combines 4-channel scans into RGB composites. + segment_frames(frames): Runs Cellpose segmentation on image frames. + save_masks(masks): Saves the predicted masks to disk. + get_cell_crops(masks, images): Extracts cropped cell images and their masks. + run(image_dir): Main workflow to segment images from a directory. """ self.config = config if core.use_gpu() == False: raise ImportError("No GPU access") - if not Path(self.config.DEEP_LEARNING_MODELS_DIR).exists(): - log.logger.warning("Pretrained model path does not exist, using default model.") - self.config.DEEP_LEARNING_CONFIG["model"]["name"] = "cpsam" # Default model if not specified # not sure why syntax is so cursed - - if self.config.MODEL == 'cellpose': - self.model = models.CellposeModel(gpu = True, - pretrained_model=str(Path(self.config.DEEP_LEARNING_MODELS_DIR, self.config.DEEP_LEARNING_CONFIG["model"]["name"])), - device=torch.device(self.config.DEEP_LEARNING_CONFIG["device"])) + if not self.config.data_dir.exists(): + raise FileNotFoundError(f"Data directory {self.config.data_dir} does not exist") - else: - pass # For future addition of models + if self.config.pretrained_model is None: + raise ValueError("Pretrained model must be specified") + + self.model = models.CellposeModel(gpu = True, + pretrained_model=str(self.config.pretrained_model), # ignore Pylance error, this code is correct + device=torch.device(self.config.device)) + self.image_data = np.empty(1) + self.composite_data = np.empty(1) + self.masks = np.empty(1) + self.stacked_scans_data = [] log.logger.debug("Cellpose Segmentor initialized.") - def save_masks(self, masks): - if not Path(self.config.PROCESSED_DATA_DIR).exists(): - self.config.PROCESSED_DATA_DIR.mkdir(parents=True, exist_ok=True) + def segment(self, images=None) -> np.ndarray: + """ + Segment the input images. + + Args: + List of images (numpy.ndarray with shape NUM IMAGES * HEIGHT * WIDTH * 3): Input images to segment. + + Returns: + numpy.ndarray: Insance mask where each cell gets its own ID + """ + if not images: + images = self.composite_data + + self.masks, _, _ = self.model.eval(self.composite_data, diameter=15, channels=[0, 0]) + return self.masks + + def preprocess(self, images=None) -> np.ndarray: + """ + Preprocess the loaded input images before segmentation by combining different scan types into a BRG image understood by the segmentation module. + + Args: + image (numpy.ndarray): Input image to preprocess. + + Returns: + numpy.ndarray: Preprocessed image. + """ + if not images: + images = self.image_data + + frames=[] + offset = int(len(images)/4) + for i in range(offset): + image0 = images[i] + image1 = images[i+offset] + image2 = images[i+2*offset] + # skip Bright Field scan + image3 = images[i+3*offset] + stacked = np.stack([image0, image1, image2, image3], axis=-1) + self.stacked_scans_data.append(stacked) + frames.append(compute_composite(image0, image1, image2, image3)) - for i, mask in enumerate(masks): - mask_path = Path(self.config.PROCESSED_DATA_DIR, f"mask_{i}.png") - cv2.imwrite(mask_path, mask) + self.stacked_scans_data = np.stack(self.stacked_scans_data[1:], axis=0) # remove the first empty array + self.composite_data = np.ndarray(frames) + return np.ndarray(frames) - def segment(self, images): - masks, _, _ = self.model.eval(images,diameter=15,channels=[0, 0]) # test if pasing all the frames at once or one at a time is faster + def postprocess(self, masks=None, images=None) -> list[np.ndarray]: + """ + Postprocess the segmentation mask. Extracts cropped cell images using the segmented masks. + + Arguments: + masks (np.ndarray): Array of segmented masks with shape (N, C, H, W). + images (np.ndarray): Array of original images with shape (N, C, H, W). + Returns: + List[np.ndarray]: List of cropped cell images. + """ + if not masks: + masks = self.masks + if not images: + images = self.stacked_scans_data + + args = [ + ( + masks[j], images[j], + ) + for j in range(len(images)) + ] + + with multiprocessing.Pool(processes=max(1, multiprocessing.cpu_count() - 2)) as pool: + results = pool.map(crop_single_image, args) + + # Flatten results + image_crops, mask_crops, centers = [], [], [] + for img_crops, msk_crops, ctrs in results: + image_crops.extend(img_crops) + mask_crops.extend(msk_crops) + centers.extend(ctrs) + + del self.image_data + del self.composite_data + del self.masks + del self.stacked_scans_data + + return ( + [ np.transpose(np.stack(image_crops, axis = 0), (0,3,1,2)), # Convert to (N, C, H, W,) because thats what the current extration model expects, + # it is probaly worth a look at why that choice was made and if it can be undone + binary_masks(np.stack((mask_crops), axis=0)), + np.stack(centers, axis=0)] + ) + + def load_data(self, image_dir) -> np.ndarray: + """ + Load images from the specified directory, and return a list of images as numpy arrays. + The returned value is optional to use and self.image_data is what the segment wants to use unless overwritten - # return np.array(masks).astype(bool).astype(np.uint8)*255 # binarize the masks for visual check + Args: + image_dir(Path): os-valid path (Use pathlib.Path) for the folder with slide data - return masks \ No newline at end of file + """ + image_files = sorted(os.listdir(image_dir)) # list index must match the order of scans + + with multiprocessing.Pool(multiprocessing.cpu_count() - 2) as p: # save one core for the system and one more for good luck + args = [(image_dir, f) for f in image_files] + frames = p.map(load_img, args) + + self.image_data = np.array(frames, dtype=np.uint16) + return self.image_data + + def save_masks(self, masks) -> None: + if not self.config.mask_output_dir.exists(): + self.config.mask_output_dir.mkdir(parents=True, exist_ok=True) + + for i, mask in enumerate(masks): + mask_path = self.config.mask_output_dir / f"mask_{i}.png" + + # Handle mask format + if mask.dtype == bool or mask.max() <= 1: + mask_to_save = (mask * 255).astype(np.uint8) + else: + mask_to_save = mask.astype(np.uint8) + + success = cv2.imwrite(str(mask_path), mask_to_save) + if not success: + print(f"Warning: Failed to save mask {i} to {mask_path}") \ No newline at end of file diff --git a/src/deep_learning/utils/config.py b/src/deep_learning/utils/config.py new file mode 100644 index 0000000..7591977 --- /dev/null +++ b/src/deep_learning/utils/config.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from pathlib import Path + +@dataclass +class Config: + pretrained_model: Path + device: str + data_dir: Path + image_extension: str + mask_output_dir: Path + offset: int = 10 + diff --git a/src/deep_learning/utils/crop.py b/src/deep_learning/utils/crop.py new file mode 100644 index 0000000..3727ab7 --- /dev/null +++ b/src/deep_learning/utils/crop.py @@ -0,0 +1,91 @@ +import numpy as np + +def crop_single_image(args): + """ + Util function for each worker in a multiprocessing tool. Takes in an image and returns an cropped cells, cropped masks, and (y, x) center of + each cell + """ + mask, image = args + image_crops = [] + mask_crops = [] + centers = [] + for i in range(1, np.max(mask)): + center = find_center(mask, i) + if (center[0] < 38 or center[1] < 38 or + center[0] > image.shape[0] - 38 or + center[1] > image.shape[1] - 38): + continue + centers.append(center) + crop = crop_img_from_center(center, image) + crop = multiplex_mask_on_crop(crop, mask, i, center) # 75 * 75 * 4 + image_crops.append(crop) + mask_crops.append(crop_mask_from_center(center, mask)) + return image_crops, mask_crops, centers + +def crop_img_from_center(center, image): + left = 0 # slighly assymetric, the left gets 38 pixels while the right gets 37 pixels + right = 75 + bottom = 75 + top = 0 + if(center[0]>38): # Make sure h is not out of range + if(center[0]38): # Make sure w is not out of range + if(center[1]38): # Make sure h is not out of range + if(center[0]38): # Make sure w is not out of range + if(center[1] max_val] = max_val # Clips overflow + + rgb = rgb.astype(dtype) + return rgb diff --git a/src/deep_learning/utils/loader.py b/src/deep_learning/utils/loader.py new file mode 100644 index 0000000..5ab4965 --- /dev/null +++ b/src/deep_learning/utils/loader.py @@ -0,0 +1,10 @@ +import cv2 +import os + +def load_img(args): + """ + Util function for each worker in a pool for loading in raw data images + """ + folder, filename = args + full_path = os.path.join(folder,filename) + return cv2.imread(full_path, cv2.IMREAD_GRAYSCALE) \ No newline at end of file diff --git a/src/deep_learning/utils/mask.py b/src/deep_learning/utils/mask.py new file mode 100644 index 0000000..0979cae --- /dev/null +++ b/src/deep_learning/utils/mask.py @@ -0,0 +1,15 @@ +import numpy as np + +def binary_masks(masks): + """ + Takes an np.array (shape N, 75, 75) with instance values and converts it to a np.array (shape N, 1, 75, 75) + with 1 or 0 in the second dimension to indicate mask or no mask. + Any nonzero value in the original mask is set to 1. + """ + # Ensure input is a numpy array + masks = np.asarray(masks) + # Create binary masks: 1 where mask > 0, else 0 + binary = (masks > 0).astype(np.uint8) + # Add a channel dimension (axis=1) + binary = binary[:, np.newaxis, :, :] + return binary diff --git a/src/utils/data_loader.py b/src/utils/data_loader.py index f465372..26678c9 100644 --- a/src/utils/data_loader.py +++ b/src/utils/data_loader.py @@ -9,6 +9,7 @@ from typing import List, Dict, Tuple, Optional, Union, Callable from dataclasses import dataclass from pathlib import Path +import multiprocessing @dataclass class SegmentationSample: @@ -19,6 +20,14 @@ class SegmentationSample: mask: Optional[np.ndarray] = None metadata: Optional[Dict] = None +def load_img(args): + """ + Util function for each worker in a pool for loading in raw data images + """ + folder, filename = args + full_path = os.path.join(folder,filename) + return cv2.imread(full_path, cv2.IMREAD_GRAYSCALE) + class SegmentationDataLoader: """ @@ -35,8 +44,8 @@ def __init__( self, image_dir: str, mask_dir: Optional[str] = None, - image_ext: Union[str, List[str]] = ('.png', '.jpg', '.jpeg', '.tif', '.tiff'), - mask_ext: Union[str, List[str]] = ('.png', '.jpg', '.jpeg', '.tif', '.tiff'), + image_ext: Union[str, List[str]] = ['.png', '.jpg', '.jpeg', '.tif', '.tiff'], + mask_ext: Union[str, List[str]] = ['.png', '.jpg', '.jpeg', '.tif', '.tiff'], image_preprocessing: Optional[Callable] = None, mask_preprocessing: Optional[Callable] = None, recursive: bool = False @@ -92,6 +101,9 @@ def _pair_images_with_masks(self) -> Dict[str, str]: Returns: Dictionary mapping image paths to mask paths """ + if self.mask_dir == None: + return {} # never going to hit this line but removes errors below + pairs = {} mask_files = self._find_files(self.mask_dir, self.mask_ext) @@ -184,37 +196,10 @@ def get_sample_with_mask(self, idx: int) -> Tuple[np.ndarray, np.ndarray]: raise ValueError(f"No mask available for image at index {idx}") return sample.image, sample.mask - def load_slides(self, data_dir, slide_id=None): - """ - Load images from the specified directory, and return a list of images as numpy arrays. - - Args: - slides_path (str): Path to the directory containing the slide images. - - Returns: - np.ndarray (16 bit): Array of images loaded from the directory, each image is a numpy array. - """ - if slide_id is None: - slide_path = Path(data_dir) - else: - slide_path = Path(data_dir, slide_id) - if not slide_path.exists(): - raise ValueError(f"Slide path {slide_path} does not exist") - - - image_files = sorted(os.listdir(slide_path)) # list index must match the order of scans - - frames = [] - for image_file in image_files: - image = cv2.imread(Path(slide_path, image_file), cv2.IMREAD_GRAYSCALE) - frames.append(image) - - return np.array(frames, dtype=np.uint16) - def compute_composite(self, dapi, ck, cd45, fitc): """ - COmbine DAPI, CK, CD45, and FITC channels into a single RGB composite image. Used by CellposeSegmentor. + Combine DAPI, CK, CD45, and FITC channels into a single RGB composite image. Used by CellposeSegmentor. Args: dapi (np.ndarray): DAPI channel image. @@ -253,10 +238,37 @@ def get_composites(self, slides, offset, save_composites=False): # skip Bright Field scan image3 = slides[i+3*offset] frames.append(self.compute_composite(image0, image1, image2, image3)) + if self.mask_dir == None: + return {} # never going to hit this line but removes errors below if save_composites: composite_path = Path(self.mask_dir, f"composite_{i}.png") cv2.imwrite(str(composite_path), frames[-1]) return frames - \ No newline at end of file + + # km + def load_slides(self, data_dir, slide_id=None): + """ + Load images from the specified directory, and return a list of images as numpy arrays. + + Args: + slides_path (str): Path to the directory containing the slide images. + + Returns: + np.ndarray (16 bit): Array of images loaded from the directory, each image is a numpy array. + """ + if slide_id is None: + slide_path = Path(data_dir) + else: + slide_path = Path(data_dir, slide_id) + if not slide_path.exists(): + raise ValueError(f"Slide path {slide_path} does not exist") + + image_files = sorted(os.listdir(slide_path)) # list index must match the order of scans + + with multiprocessing.Pool(multiprocessing.cpu_count() - 2) as p: # save one core for the system and one more for good luck + args = [(slide_path, f) for f in image_files] + frames = p.map(load_img, args) + + return np.array(frames, dtype=np.uint16) \ No newline at end of file