Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions .github/workflows/pr-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ jobs:
# coordinates must match exactly (deterministic tiling)
gt_coordinates = np.load('/gt/test-wsi.npy')
coordinates = np.load('/output/coordinates/test-wsi.npy')
assert len(gt_coordinates) == len(coordinates), f'Number of coordinates mismatch: {len(coordinates)} vs {len(gt_coordinates)}'
assert_array_equal(coordinates, gt_coordinates), f'Coordinates mismatch: {coordinates} vs {gt_coordinates}'
assert len(gt_coordinates) == len(coordinates), f'Number of coordinates mismatch: {len(coordinates)} vs {len(gt_coordinates)} ❌'
x_gt, y_gt = gt_coordinates['x'], gt_coordinates['y']
x, y = coordinates['x'], coordinates['y']
assert_array_equal(x, x_gt), 'x coordinates mismatch ❌'
assert_array_equal(y, y_gt), 'y coordinates mismatch ❌'
tile_level_gt = gt_coordinates['tile_level']
tile_level = coordinates['tile_level']
assert_array_equal(tile_level, tile_level_gt), 'tile_level mismatch ❌'
tile_size_gt = gt_coordinates['tile_size_resized']
tile_size = coordinates['tile_size_resized']
assert_array_equal(tile_size, tile_size_gt), 'tile_size_resized mismatch ❌'
resize_factor_gt = gt_coordinates['resize_factor']
resize_factor = coordinates['resize_factor']
assert_array_equal(resize_factor, resize_factor_gt), 'resize_factor mismatch ❌'
print("All coordinate checks passed ✅")
PY"
2 changes: 1 addition & 1 deletion hs2p/tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def process_slide(


def main(args):

cfg = setup(args)
output_dir = Path(cfg.output_dir)

Expand Down
35 changes: 23 additions & 12 deletions hs2p/wsi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,29 @@


class HasEnoughTissue(object):
def __init__(self, contour, contour_holes, tissue_mask, tile_size, scale, pct=0.01):
def __init__(self, contour, contour_holes, tissue_mask, tile_size, tile_spacing, resize_factor, seg_spacing, spacing_at_level_0, pct=0.01):
self.cont = contour
self.holes = contour_holes
self.mask = tissue_mask // 255
self.tile_size = tile_size
self.scale = scale
self.tile_spacing = tile_spacing
self.resize_factor = resize_factor
self.seg_spacing = seg_spacing
self.spacing_at_level_0 = spacing_at_level_0
self.pct = pct

self.downsampled_tile_size = int(round(self.tile_size * 1 / self.scale[0], 0))
# downsample tile size from target_spacing to seg_spacing
# where contour and tissue masks are defined
target_spacing = self.tile_spacing * self.resize_factor
scale = self.seg_spacing / target_spacing
self.downsampled_tile_size = int(round(self.tile_size * 1 / scale, 0))
assert (
self.downsampled_tile_size > 0
), "downsampled tile_size is equal to zero, aborting; please consider using a smaller seg_params.downsample parameter"

# Precompute the combined tissue mask
self.tile_size_resized = int(round(tile_size * resize_factor,0))

# precompute the combined tissue mask
self.precomputed_mask = self._precompute_tissue_mask()

def _precompute_tissue_mask(self):
Expand Down Expand Up @@ -58,8 +67,9 @@ def check_coordinates(self, coords):
- keep_flags is a list of 1s and 0s indicating whether each tile has enough tissue.
- tissue_pcts is a list of tissue percentages for each tile.
"""
# downsample coordinates
downsampled_coords = coords * 1 / self.scale[0]
# downsample coordinates from level 0 to seg_level
scale = self.seg_spacing / self.spacing_at_level_0
downsampled_coords = coords * 1 / scale
downsampled_coords = downsampled_coords.astype(int)

keep_flags = []
Expand Down Expand Up @@ -93,19 +103,20 @@ def get_tile_mask(self, x, y):
Returns:
np.ndarray: The binary mask for the tile (0 or 1).
"""
# downsample coordinates
x_tile = int(x / self.scale[0])
y_tile = int(y / self.scale[0])
# downsample coordinates from level 0 to seg_level
scale = self.seg_spacing / self.spacing_at_level_0
x_tile = int(x / scale)
y_tile = int(y / scale)

# extract the sub-mask for the tile
sub_mask = self._extract_sub_mask(x_tile, y_tile)

# handle edge cases where sub_mask is smaller than expected
if sub_mask.shape[0] != self.downsampled_tile_size or sub_mask.shape[1] != self.downsampled_tile_size:
padded_mask = np.zeros((self.downsampled_tile_size, self.downsampled_tile_size), dtype=sub_mask.dtype)
padded_mask[:sub_mask.shape[0], :sub_mask.shape[1]] = sub_mask
sub_mask = padded_mask

# upsample the mask to the original tile size
mask = cv2.resize(sub_mask, (self.tile_size, self.tile_size), interpolation=cv2.INTER_NEAREST)
mask = cv2.resize(sub_mask, (self.tile_size_resized, self.tile_size_resized), interpolation=cv2.INTER_NEAREST)
return mask
31 changes: 14 additions & 17 deletions hs2p/wsi/wsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(
def get_slide(self, spacing: float):
return self.wsi.get_slide(spacing=spacing)

def get_tile(self, x: int, y: int, width: int, height: int, spacing: float, mask: np.ndarray = None):
def get_tile(self, x: int, y: int, width: int, height: int, spacing: float):
"""
Extracts a tile from a whole slide image at the specified coordinates, size, and spacing.

Expand All @@ -145,7 +145,6 @@ def get_tile(self, x: int, y: int, width: int, height: int, spacing: float, mask
width (int): Tile width.
height (int): Tile height.
spacing (float): The spacing (resolution) at which the tile should be extracted.
mask (np.ndarray, optional): A binary mask to apply to the tile. Defaults to None.

Returns:
numpy.ndarray: The extracted tile as a numpy array.
Expand All @@ -158,13 +157,6 @@ def get_tile(self, x: int, y: int, width: int, height: int, spacing: float, mask
spacing=spacing,
center=False,
)

if mask is not None:
# ensure mask is the same size as the tile
assert mask.shape[:2] == tile.shape[:2], "Mask and tile shapes do not match"
# apply mask
tile = cv2.bitwise_and(tile, tile, mask=mask)

return tile

def get_downsamples(self):
Expand Down Expand Up @@ -638,8 +630,9 @@ def detect_contours(
current_scale = self.level_downsamples[spacing_level]
target_scale = self.level_downsamples[self.seg_level]
scale = tuple(a / b for a, b in zip(target_scale, current_scale))
ref_tile_size = filter_params.ref_tile_size
scaled_ref_tile_area = int(round(ref_tile_size**2 / (scale[0] * scale[1]),0))
ref_tile_size = (filter_params.ref_tile_size, filter_params.ref_tile_size)
ref_tile_size_at_target_scale = tuple(a / b for a, b in zip(ref_tile_size, scale))
scaled_ref_tile_area = int(ref_tile_size_at_target_scale[0] * ref_tile_size_at_target_scale[1])

adjusted_filter_params = FilterParameters(
ref_tile_size=filter_params.ref_tile_size,
Expand Down Expand Up @@ -927,7 +920,7 @@ def process_contour(
int(self.level_downsamples[tile_level][0]),
int(self.level_downsamples[tile_level][1]),
)
ref_tile_size = (
tile_size_at_level_0 = (
tile_size_resized * tile_downsample[0],
tile_size_resized * tile_downsample[1],
)
Expand All @@ -937,20 +930,24 @@ def process_contour(
stop_y = int(start_y + h)
stop_x = int(start_x + w)
else:
stop_y = min(start_y + h, img_h - ref_tile_size[1] + 1)
stop_x = min(start_x + w, img_w - ref_tile_size[0] + 1)
stop_y = min(start_y + h, img_h - tile_size_at_level_0[1] + 1)
stop_x = min(start_x + w, img_w - tile_size_at_level_0[0] + 1)

scale = self.level_downsamples[self.seg_level]
cont = self.scaleContourDim([contour], (1.0 / scale[0], 1.0 / scale[1]))[0]

mask = self.annotation_mask["tissue"] if annotation is None else self.annotation_mask[annotation]
pct = self.annotation_pct["tissue"] if annotation is None else self.annotation_pct[annotation]
seg_spacing = self.get_level_spacing(self.seg_level)
tissue_checker = HasEnoughTissue(
contour=cont,
contour_holes=contour_holes,
tissue_mask=mask,
tile_size=ref_tile_size[0],
scale=scale,
tile_size=tile_size,
tile_spacing=tile_spacing,
resize_factor=resize_factor,
seg_spacing=seg_spacing,
spacing_at_level_0=self.get_level_spacing(0),
pct=pct,
)

Expand All @@ -969,7 +966,7 @@ def process_contour(

if drop_holes:
keep_flags = [
flag and not self.isInHoles(contour_holes, coord, ref_tile_size[0])
flag and not self.isInHoles(contour_holes, coord, tile_size_at_level_0[0])
for flag, coord in zip(keep_flags, coord_candidates)
]

Expand Down
Loading