Skip to content
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,5 @@ cython_debug/

# custom
output/
outputs/
archive/
5 changes: 5 additions & 0 deletions slide2vec/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
__version__ = "2.0.0"

import sys
import os

sys.path.append(os.path.join(os.path.dirname(__file__), "hs2p"))
3 changes: 2 additions & 1 deletion slide2vec/configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ tiling:
use_otsu: false # use otsu's method instead of simple binary thresholding
tissue_pixel_value: 1 # value of tissue pixel in pre-computed segmentation masks
filter_params:
ref_tile_size: 16 # reference tile size at spacing tiling.spacing
ref_tile_size: ${tiling.params.tile_size} # reference tile size at spacing tiling.spacing
a_t: 4 # area filter threshold for tissue (positive integer, the minimum size of detected foreground contours to consider, relative to the reference tile size ref_tile_size, e.g. a value 10 means only detected foreground contours of size greater than 10 [ref_tile_size, ref_tile_size] tiles at spacing tiling.spacing will be kept)
a_h: 2 # area filter threshold for holes (positive integer, the minimum size of detected holes/cavities in foreground contours to avoid, once again relative to the reference tile size ref_tile_size)
max_n_holes: 8 # maximum of holes to consider per detected foreground contours (positive integer, higher values lead to more accurate patching but increase computational cost ; keeps the biggest holes)
Expand All @@ -43,6 +43,7 @@ model:
pretrained_weights: # path to the pretrained weights when using a custom model
batch_size: 256
tile_size: ${tiling.params.tile_size}
restrict_to_tissue: false # whether to restrict tile content to tissue pixels only when feeding tile through encoder
patch_size: 256 # if level is "region", size used to unroll the region into patches
save_tile_embeddings: false # whether to save tile embeddings alongside the pooled slide embedding when level is "slide"
save_latents: false # whether to save the latent representations from the model alongside the slide embedding (only supported for 'prism')
Expand Down
78 changes: 69 additions & 9 deletions slide2vec/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,72 @@
import cv2
import torch
import numpy as np
import wholeslidedata as wsd

from transformers.image_processing_utils import BaseImageProcessor
from PIL import Image
from pathlib import Path
from typing import Callable

from slide2vec.hs2p.hs2p.wsi import WholeSlideImage, SegmentationParameters, SamplingParameters, FilterParameters
from slide2vec.hs2p.hs2p.wsi.utils import HasEnoughTissue


class TileDataset(torch.utils.data.Dataset):
def __init__(self, wsi_path, tile_dir, target_spacing, backend, transforms=None):
def __init__(
self,
wsi_path: Path,
mask_path: Path,
coordinates_dir: Path,
target_spacing: float,
tolerance: float,
backend: str,
segment_params: SegmentationParameters | None = None,
sampling_params: SamplingParameters | None = None,
filter_params: FilterParameters | None = None,
transforms: BaseImageProcessor | Callable | None = None,
restrict_to_tissue: bool = False,
):
self.path = wsi_path
self.mask_path = mask_path
self.target_spacing = target_spacing
self.backend = backend
self.name = wsi_path.stem.replace(" ", "_")
self.load_coordinates(tile_dir)
self.load_coordinates(coordinates_dir)
self.transforms = transforms
self.restrict_to_tissue = restrict_to_tissue

if restrict_to_tissue:
_wsi = WholeSlideImage(
path=self.path,
mask_path=self.mask_path,
backend=self.backend,
segment_params=segment_params,
sampling_params=sampling_params,
)
contours, holes = _wsi.detect_contours(
target_spacing=target_spacing,
tolerance=tolerance,
filter_params=filter_params,
)
scale = _wsi.level_downsamples[_wsi.seg_level]
self.contours = _wsi.scaleContourDim(contours, (1.0 / scale[0], 1.0 / scale[1]))
self.holes = _wsi.scaleHolesDim(holes, (1.0 / scale[0], 1.0 / scale[1]))
self.tissue_mask = _wsi.annotation_mask["tissue"]
self.seg_spacing = _wsi.get_level_spacing(_wsi.seg_level)
self.spacing_at_level_0 = _wsi.get_level_spacing(0)

def load_coordinates(self, tile_dir):
coordinates = np.load(Path(tile_dir, f"{self.name}.npy"), allow_pickle=True)
def load_coordinates(self, coordinates_dir):
coordinates = np.load(Path(coordinates_dir, f"{self.name}.npy"), allow_pickle=True)
self.x = coordinates["x"]
self.y = coordinates["y"]
self.coordinates = (np.array([self.x, self.y]).T).astype(int)
self.scaled_coordinates = self.scale_coordinates()
self.contour_index = coordinates["contour_index"]
self.target_tile_size = coordinates["target_tile_size"]
self.tile_level = coordinates["tile_level"]
self.resize_factor = coordinates["resize_factor"]
self.tile_size_resized = coordinates["tile_size_resized"]
resize_factor = coordinates["resize_factor"]
self.tile_size = np.round(self.tile_size_resized / resize_factor).astype(int)
self.tile_size_lv0 = coordinates["tile_size_lv0"][0]

def scale_coordinates(self):
Expand Down Expand Up @@ -55,11 +96,30 @@ def __getitem__(self, idx):
spacing=tile_spacing,
center=False,
)
if self.restrict_to_tissue:
contour_idx = self.contour_index[idx]
contour = self.contours[contour_idx]
holes = self.holes[contour_idx]
tissue_checker = HasEnoughTissue(
contour=contour,
contour_holes=holes,
tissue_mask=self.tissue_mask,
tile_size=self.target_tile_size[idx],
tile_spacing=tile_spacing,
resize_factor=self.resize_factor[idx],
seg_spacing=self.seg_spacing,
spacing_at_level_0=self.spacing_at_level_0,
)
tissue_mask = tissue_checker.get_tile_mask(self.x[idx], self.y[idx])
# ensure mask is the same size as the tile
assert tissue_mask.shape[:2] == tile_arr.shape[:2], "Mask and tile shapes do not match"
# apply mask
tile_arr = cv2.bitwise_and(tile_arr, tile_arr, mask=tissue_mask)
tile = Image.fromarray(tile_arr).convert("RGB")
if self.tile_size[idx] != self.tile_size_resized[idx]:
tile = tile.resize((self.tile_size[idx], self.tile_size[idx]))
if self.target_tile_size[idx] != self.tile_size_resized[idx]:
tile = tile.resize((self.target_tile_size[idx], self.target_tile_size[idx]))
if self.transforms:
if isinstance(self.transforms, BaseImageProcessor): # Hugging Face (`transformer`)
if isinstance(self.transforms, BaseImageProcessor): # Hugging Face (`transformer`)
tile = self.transforms(tile, return_tensors="pt")["pixel_values"].squeeze(0)
else: # general callable such as torchvision transforms
tile = self.transforms(tile)
Expand Down
63 changes: 56 additions & 7 deletions slide2vec/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from slide2vec.utils.config import get_cfg_from_file, setup_distributed
from slide2vec.models import ModelFactory
from slide2vec.data import TileDataset, RegionUnfolding
from slide2vec.hs2p.hs2p.wsi import SamplingParameters

torchvision.disable_beta_transforms_warning()

Expand Down Expand Up @@ -60,13 +61,31 @@ def create_transforms(cfg, model):
raise ValueError(f"Unknown model level: {cfg.model.level}")


def create_dataset(wsi_fp, coordinates_dir, spacing, backend, transforms):
def create_dataset(
wsi_path,
mask_path,
coordinates_dir,
target_spacing,
tolerance,
backend,
segment_params,
sampling_params,
filter_params,
transforms,
restrict_to_tissue: bool,
):
return TileDataset(
wsi_fp,
coordinates_dir,
spacing,
wsi_path=wsi_path,
mask_path=mask_path,
coordinates_dir=coordinates_dir,
target_spacing=target_spacing,
tolerance=tolerance,
backend=backend,
segment_params=segment_params,
sampling_params=sampling_params,
filter_params=filter_params,
transforms=transforms,
restrict_to_tissue=restrict_to_tissue,
)


Expand Down Expand Up @@ -176,12 +195,30 @@ def main(args):
if not run_on_cpu:
torch.distributed.barrier()

pixel_mapping = {k: v for e in cfg.tiling.sampling_params.pixel_mapping for k, v in e.items()}
tissue_percentage = {k: v for e in cfg.tiling.sampling_params.tissue_percentage for k, v in e.items()}
if "tissue" not in tissue_percentage:
tissue_percentage["tissue"] = cfg.tiling.params.min_tissue_percentage
if cfg.tiling.sampling_params.color_mapping is not None:
color_mapping = {k: v for e in cfg.tiling.sampling_params.color_mapping for k, v in e.items()}
else:
color_mapping = None

sampling_params = SamplingParameters(
pixel_mapping=pixel_mapping,
color_mapping=color_mapping,
tissue_percentage=tissue_percentage,
)

# select slides that were successfully tiled but not yet processed for feature extraction
tiled_df = process_df[process_df.tiling_status == "success"]
mask = tiled_df["feature_status"] != "success"
process_stack = tiled_df[mask]
total = len(process_stack)

wsi_paths_to_process = [Path(x) for x in process_stack.wsi_path.values.tolist()]
mask_paths_to_process = [Path(x) for x in process_stack.mask_path.values.tolist()]
combined_paths = zip(wsi_paths_to_process, mask_paths_to_process)

features_dir = Path(cfg.output_dir, "features")
if distributed.is_main_process():
Expand All @@ -201,8 +238,8 @@ def main(args):
transforms = create_transforms(cfg, model)
print(f"transforms: {transforms}")

for wsi_fp in tqdm.tqdm(
wsi_paths_to_process,
for wsi_fp, mask_fp in tqdm.tqdm(
combined_paths,
desc="Inference",
unit="slide",
total=total,
Expand All @@ -211,7 +248,19 @@ def main(args):
position=1,
):
try:
dataset = create_dataset(wsi_fp, coordinates_dir, cfg.tiling.params.spacing, cfg.tiling.backend, transforms)
dataset = create_dataset(
wsi_path=wsi_fp,
mask_path=mask_fp,
coordinates_dir=coordinates_dir,
target_spacing=cfg.tiling.params.spacing,
tolerance=cfg.tiling.params.tolerance,
backend=cfg.tiling.backend,
segment_params=cfg.tiling.seg_params,
sampling_params=sampling_params,
filter_params=cfg.tiling.filter_params,
transforms=transforms,
restrict_to_tissue=cfg.model.restrict_to_tissue,
)
if distributed.is_enabled_and_multiple_gpus():
sampler = torch.utils.data.DistributedSampler(
dataset,
Expand Down
Binary file modified test/gt/test-wsi.npy
Binary file not shown.
Loading