From 72f0124fa26c9a8caf4b4f96dc56d371e03cfe8e Mon Sep 17 00:00:00 2001 From: clemsgrs Date: Wed, 31 Dec 2025 12:38:48 +0000 Subject: [PATCH] improve tissue checker logic --- .github/workflows/pr-test.yaml | 17 +++++++++++++++-- hs2p/tiling.py | 2 +- hs2p/wsi/utils.py | 35 ++++++++++++++++++++++------------ hs2p/wsi/wsi.py | 31 ++++++++++++++---------------- 4 files changed, 53 insertions(+), 32 deletions(-) diff --git a/.github/workflows/pr-test.yaml b/.github/workflows/pr-test.yaml index a13ceb0..45c8bb2 100644 --- a/.github/workflows/pr-test.yaml +++ b/.github/workflows/pr-test.yaml @@ -78,6 +78,19 @@ jobs: # coordinates must match exactly (deterministic tiling) gt_coordinates = np.load('/gt/test-wsi.npy') coordinates = np.load('/output/coordinates/test-wsi.npy') - assert len(gt_coordinates) == len(coordinates), f'Number of coordinates mismatch: {len(coordinates)} vs {len(gt_coordinates)}' - assert_array_equal(coordinates, gt_coordinates), f'Coordinates mismatch: {coordinates} vs {gt_coordinates}' + assert len(gt_coordinates) == len(coordinates), f'Number of coordinates mismatch: {len(coordinates)} vs {len(gt_coordinates)} ❌' + x_gt, y_gt = gt_coordinates['x'], gt_coordinates['y'] + x, y = coordinates['x'], coordinates['y'] + assert_array_equal(x, x_gt), 'x coordinates mismatch ❌' + assert_array_equal(y, y_gt), 'y coordinates mismatch ❌' + tile_level_gt = gt_coordinates['tile_level'] + tile_level = coordinates['tile_level'] + assert_array_equal(tile_level, tile_level_gt), 'tile_level mismatch ❌' + tile_size_gt = gt_coordinates['tile_size_resized'] + tile_size = coordinates['tile_size_resized'] + assert_array_equal(tile_size, tile_size_gt), 'tile_size_resized mismatch ❌' + resize_factor_gt = gt_coordinates['resize_factor'] + resize_factor = coordinates['resize_factor'] + assert_array_equal(resize_factor, resize_factor_gt), 'resize_factor mismatch ❌' + print("All coordinate checks passed ✅") PY" diff --git a/hs2p/tiling.py b/hs2p/tiling.py index e47786a..816b2ca 100644 --- a/hs2p/tiling.py +++ b/hs2p/tiling.py @@ -103,7 +103,7 @@ def process_slide( def main(args): - + cfg = setup(args) output_dir = Path(cfg.output_dir) diff --git a/hs2p/wsi/utils.py b/hs2p/wsi/utils.py index 7418c4c..c1b95be 100644 --- a/hs2p/wsi/utils.py +++ b/hs2p/wsi/utils.py @@ -3,20 +3,29 @@ class HasEnoughTissue(object): - def __init__(self, contour, contour_holes, tissue_mask, tile_size, scale, pct=0.01): + def __init__(self, contour, contour_holes, tissue_mask, tile_size, tile_spacing, resize_factor, seg_spacing, spacing_at_level_0, pct=0.01): self.cont = contour self.holes = contour_holes self.mask = tissue_mask // 255 self.tile_size = tile_size - self.scale = scale + self.tile_spacing = tile_spacing + self.resize_factor = resize_factor + self.seg_spacing = seg_spacing + self.spacing_at_level_0 = spacing_at_level_0 self.pct = pct - self.downsampled_tile_size = int(round(self.tile_size * 1 / self.scale[0], 0)) + # downsample tile size from target_spacing to seg_spacing + # where contour and tissue masks are defined + target_spacing = self.tile_spacing * self.resize_factor + scale = self.seg_spacing / target_spacing + self.downsampled_tile_size = int(round(self.tile_size * 1 / scale, 0)) assert ( self.downsampled_tile_size > 0 ), "downsampled tile_size is equal to zero, aborting; please consider using a smaller seg_params.downsample parameter" - # Precompute the combined tissue mask + self.tile_size_resized = int(round(tile_size * resize_factor,0)) + + # precompute the combined tissue mask self.precomputed_mask = self._precompute_tissue_mask() def _precompute_tissue_mask(self): @@ -58,8 +67,9 @@ def check_coordinates(self, coords): - keep_flags is a list of 1s and 0s indicating whether each tile has enough tissue. - tissue_pcts is a list of tissue percentages for each tile. """ - # downsample coordinates - downsampled_coords = coords * 1 / self.scale[0] + # downsample coordinates from level 0 to seg_level + scale = self.seg_spacing / self.spacing_at_level_0 + downsampled_coords = coords * 1 / scale downsampled_coords = downsampled_coords.astype(int) keep_flags = [] @@ -93,19 +103,20 @@ def get_tile_mask(self, x, y): Returns: np.ndarray: The binary mask for the tile (0 or 1). """ - # downsample coordinates - x_tile = int(x / self.scale[0]) - y_tile = int(y / self.scale[0]) + # downsample coordinates from level 0 to seg_level + scale = self.seg_spacing / self.spacing_at_level_0 + x_tile = int(x / scale) + y_tile = int(y / scale) # extract the sub-mask for the tile sub_mask = self._extract_sub_mask(x_tile, y_tile) - + # handle edge cases where sub_mask is smaller than expected if sub_mask.shape[0] != self.downsampled_tile_size or sub_mask.shape[1] != self.downsampled_tile_size: padded_mask = np.zeros((self.downsampled_tile_size, self.downsampled_tile_size), dtype=sub_mask.dtype) padded_mask[:sub_mask.shape[0], :sub_mask.shape[1]] = sub_mask sub_mask = padded_mask - + # upsample the mask to the original tile size - mask = cv2.resize(sub_mask, (self.tile_size, self.tile_size), interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(sub_mask, (self.tile_size_resized, self.tile_size_resized), interpolation=cv2.INTER_NEAREST) return mask diff --git a/hs2p/wsi/wsi.py b/hs2p/wsi/wsi.py index f4c5dd8..0b7e4bc 100644 --- a/hs2p/wsi/wsi.py +++ b/hs2p/wsi/wsi.py @@ -135,7 +135,7 @@ def __init__( def get_slide(self, spacing: float): return self.wsi.get_slide(spacing=spacing) - def get_tile(self, x: int, y: int, width: int, height: int, spacing: float, mask: np.ndarray = None): + def get_tile(self, x: int, y: int, width: int, height: int, spacing: float): """ Extracts a tile from a whole slide image at the specified coordinates, size, and spacing. @@ -145,7 +145,6 @@ def get_tile(self, x: int, y: int, width: int, height: int, spacing: float, mask width (int): Tile width. height (int): Tile height. spacing (float): The spacing (resolution) at which the tile should be extracted. - mask (np.ndarray, optional): A binary mask to apply to the tile. Defaults to None. Returns: numpy.ndarray: The extracted tile as a numpy array. @@ -158,13 +157,6 @@ def get_tile(self, x: int, y: int, width: int, height: int, spacing: float, mask spacing=spacing, center=False, ) - - if mask is not None: - # ensure mask is the same size as the tile - assert mask.shape[:2] == tile.shape[:2], "Mask and tile shapes do not match" - # apply mask - tile = cv2.bitwise_and(tile, tile, mask=mask) - return tile def get_downsamples(self): @@ -638,8 +630,9 @@ def detect_contours( current_scale = self.level_downsamples[spacing_level] target_scale = self.level_downsamples[self.seg_level] scale = tuple(a / b for a, b in zip(target_scale, current_scale)) - ref_tile_size = filter_params.ref_tile_size - scaled_ref_tile_area = int(round(ref_tile_size**2 / (scale[0] * scale[1]),0)) + ref_tile_size = (filter_params.ref_tile_size, filter_params.ref_tile_size) + ref_tile_size_at_target_scale = tuple(a / b for a, b in zip(ref_tile_size, scale)) + scaled_ref_tile_area = int(ref_tile_size_at_target_scale[0] * ref_tile_size_at_target_scale[1]) adjusted_filter_params = FilterParameters( ref_tile_size=filter_params.ref_tile_size, @@ -927,7 +920,7 @@ def process_contour( int(self.level_downsamples[tile_level][0]), int(self.level_downsamples[tile_level][1]), ) - ref_tile_size = ( + tile_size_at_level_0 = ( tile_size_resized * tile_downsample[0], tile_size_resized * tile_downsample[1], ) @@ -937,20 +930,24 @@ def process_contour( stop_y = int(start_y + h) stop_x = int(start_x + w) else: - stop_y = min(start_y + h, img_h - ref_tile_size[1] + 1) - stop_x = min(start_x + w, img_w - ref_tile_size[0] + 1) + stop_y = min(start_y + h, img_h - tile_size_at_level_0[1] + 1) + stop_x = min(start_x + w, img_w - tile_size_at_level_0[0] + 1) scale = self.level_downsamples[self.seg_level] cont = self.scaleContourDim([contour], (1.0 / scale[0], 1.0 / scale[1]))[0] mask = self.annotation_mask["tissue"] if annotation is None else self.annotation_mask[annotation] pct = self.annotation_pct["tissue"] if annotation is None else self.annotation_pct[annotation] + seg_spacing = self.get_level_spacing(self.seg_level) tissue_checker = HasEnoughTissue( contour=cont, contour_holes=contour_holes, tissue_mask=mask, - tile_size=ref_tile_size[0], - scale=scale, + tile_size=tile_size, + tile_spacing=tile_spacing, + resize_factor=resize_factor, + seg_spacing=seg_spacing, + spacing_at_level_0=self.get_level_spacing(0), pct=pct, ) @@ -969,7 +966,7 @@ def process_contour( if drop_holes: keep_flags = [ - flag and not self.isInHoles(contour_holes, coord, ref_tile_size[0]) + flag and not self.isInHoles(contour_holes, coord, tile_size_at_level_0[0]) for flag, coord in zip(keep_flags, coord_candidates) ]