From 40326e956fce950e0ce7ee4b52db470d44c7e1cc Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 21 Aug 2024 20:39:35 +0200 Subject: [PATCH 01/20] implement selection --- src/spatialdata/dataloader/datasets.py | 47 ++++++++++++++++---------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index c65bc4f5a..7370b2a6a 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -128,7 +128,8 @@ def __init__( from spatialdata import bounding_box_query from spatialdata._core.operations.rasterize import rasterize as rasterize_fn - self._validate(sdata, regions_to_images, regions_to_coordinate_systems, return_annotations, table_name) + self.sdata = sdata + self._validate(regions_to_images, regions_to_coordinate_systems, return_annotations, table_name) self._preprocess(tile_scale, tile_dim_in_units, rasterize, table_name) if rasterize_kwargs is not None and len(rasterize_kwargs) > 0 and rasterize is False: @@ -145,21 +146,19 @@ def __init__( **dict(rasterize_kwargs), ) if rasterize - else bounding_box_query # type: ignore[assignment] + else partial(bounding_box_query, return_request_only=True) # type: ignore[assignment] ) self._return = self._get_return(return_annotations, table_name) self.transform = transform def _validate( self, - sdata: SpatialData, regions_to_images: dict[str, str], regions_to_coordinate_systems: dict[str, str], return_annotations: str | list[str] | None, table_name: str | None, ) -> None: """Validate input parameters.""" - self.sdata = sdata if return_annotations is not None and table_name is None: raise ValueError("`table_name` must be provided if `return_annotations` is not `None`.") @@ -174,8 +173,8 @@ def _validate( image_name = regions_to_images[region_name] # get elements - region_elem = sdata[region_name] - image_elem = sdata[image_name] + region_elem = self.sdata[region_name] + image_elem = self.sdata[image_name] # check that the elements are supported if get_model(region_elem) == PointsModel: @@ -200,13 +199,13 @@ def _validate( ) if table_name is not None: - _, region_key, instance_key = get_table_keys(sdata.tables[table_name]) + _, region_key, instance_key = get_table_keys(self.sdata.tables[table_name]) if get_model(region_elem) in [Labels2DModel, Labels3DModel]: indices = get_element_instances(region_elem).tolist() else: indices = region_elem.index.tolist() - table = sdata.tables[table_name] - if not isinstance(sdata.tables[table_name].obs[region_key].dtype, CategoricalDtype): + table = self.sdata.tables[table_name] + if not isinstance(self.sdata.tables[table_name].obs[region_key].dtype, CategoricalDtype): raise TypeError( f"The `regions_element` column `{region_key}` in the table must be a categorical dtype. " f"Please convert it." @@ -229,8 +228,10 @@ def _preprocess( table_name: str | None, ) -> None: """Preprocess the dataset.""" + from spatialdata import bounding_box_query + if table_name is not None: - _, region_key, instance_key = get_table_keys(self.sdata.tables[table_name]) + _, region_key, _ = get_table_keys(self.sdata.tables[table_name]) filtered_table = self.sdata.tables[table_name][ self.sdata.tables[table_name].obs[region_key].isin(self.regions) ] # filtered table for the data loader @@ -250,6 +251,17 @@ def _preprocess( tile_scale=tile_scale, tile_dim_in_units=tile_dim_in_units, ) + tile_coords["selection"] = tile_coords.apply( + lambda row: bounding_box_query( + self.sdata[image_name], + ("x", "y"), + min_coordinate=row[["minx", "miny"]].values, + max_coordinate=row[["maxx", "maxy"]].values, + target_coordinate_system=cs, + return_request_only=True, + ), + axis=1, + ) tile_coords_df.append(tile_coords) inst = circles.index.values @@ -359,13 +371,14 @@ def __getitem__(self, idx: int) -> Any | SpatialData: t_coords = self.tiles_coords.iloc[idx] image = self.sdata[row["image"]] - tile = self._crop_image( - image, - axes=tuple(self.dims), - min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, - max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, - target_coordinate_system=row["cs"], - ) + # tile = self._crop_image( + # image, + # axes=tuple(self.dims), + # min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, + # max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, + # target_coordinate_system=row["cs"], + # ) + tile = image.sel(t_coords["selection"]) if self.transform is not None: out = self._return(idx, tile) return self.transform(out) From aa339aa719c808afdce662bd8a6e1305e06f3654 Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 21 Aug 2024 20:47:50 +0200 Subject: [PATCH 02/20] update --- src/spatialdata/dataloader/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 7370b2a6a..6b7ec7ae9 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -252,7 +252,7 @@ def _preprocess( tile_dim_in_units=tile_dim_in_units, ) tile_coords["selection"] = tile_coords.apply( - lambda row: bounding_box_query( + lambda row, cs=cs, image_name=image_name: bounding_box_query( self.sdata[image_name], ("x", "y"), min_coordinate=row[["minx", "miny"]].values, From 92d578fea77735b2279ee624aa9a316132690bca Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 11:58:25 -0700 Subject: [PATCH 03/20] vectorize adjust_bounding_box_to_real_axes --- src/spatialdata/_core/query/spatial_query.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index dea2280a5..5a312149f 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -230,6 +230,7 @@ def _adjust_bounding_box_to_real_axes( The bounding box is defined by the user and its axes may not coincide with the axes of the transformation. """ + axis = min_coordinate.ndim - 1 if set(axes_bb) != set(axes_out_without_c): axes_only_in_bb = set(axes_bb) - set(axes_out_without_c) axes_only_in_output = set(axes_out_without_c) - set(axes_bb) @@ -246,8 +247,8 @@ def _adjust_bounding_box_to_real_axes( for ax in axes_only_in_output: axes_bb = axes_bb + (ax,) M = np.finfo(np.float32).max - 1 - min_coordinate = np.append(min_coordinate, -M) - max_coordinate = np.append(max_coordinate, M) + min_coordinate = np.append(min_coordinate, -M, axis=axis) + max_coordinate = np.append(max_coordinate, M, axis=axis) else: indices = [axes_bb.index(ax) for ax in axes_out_without_c] min_coordinate = min_coordinate[np.array(indices)] From 2bb5c35e34cee512bad60011a0c2b7ae3bcfb5a8 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 12:00:58 -0700 Subject: [PATCH 04/20] update --- src/spatialdata/_core/query/spatial_query.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 5a312149f..fbfbe9c89 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -239,8 +239,8 @@ def _adjust_bounding_box_to_real_axes( # 3D bounding box) indices_to_remove_from_bb = [axes_bb.index(ax) for ax in axes_only_in_bb] axes_bb = tuple(ax for ax in axes_bb if ax not in axes_only_in_bb) - min_coordinate = np.delete(min_coordinate, indices_to_remove_from_bb) - max_coordinate = np.delete(max_coordinate, indices_to_remove_from_bb) + min_coordinate = np.delete(min_coordinate, indices_to_remove_from_bb, axis=axis) + max_coordinate = np.delete(max_coordinate, indices_to_remove_from_bb, axis=axis) # if there are axes in the output axes that are not in the bounding box, we need to add them to the bounding box # with a range that includes everything (e.g. querying 3D points with a 2D bounding box) @@ -251,8 +251,8 @@ def _adjust_bounding_box_to_real_axes( max_coordinate = np.append(max_coordinate, M, axis=axis) else: indices = [axes_bb.index(ax) for ax in axes_out_without_c] - min_coordinate = min_coordinate[np.array(indices)] - max_coordinate = max_coordinate[np.array(indices)] + min_coordinate = np.take(min_coordinate, indices, axis=axis) + max_coordinate = np.take(max_coordinate, indices, axis=axis) axes_bb = axes_out_without_c return axes_bb, min_coordinate, max_coordinate From c89dcdf62c2a704df00874e81b5a3ee076aacf4d Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 13:59:35 -0700 Subject: [PATCH 05/20] replace append with insert --- src/spatialdata/_core/query/spatial_query.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index fbfbe9c89..6c12b792c 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -244,11 +244,11 @@ def _adjust_bounding_box_to_real_axes( # if there are axes in the output axes that are not in the bounding box, we need to add them to the bounding box # with a range that includes everything (e.g. querying 3D points with a 2D bounding box) + M = np.finfo(np.float32).max - 1 for ax in axes_only_in_output: axes_bb = axes_bb + (ax,) - M = np.finfo(np.float32).max - 1 - min_coordinate = np.append(min_coordinate, -M, axis=axis) - max_coordinate = np.append(max_coordinate, M, axis=axis) + min_coordinate = np.insert(min_coordinate, min_coordinate.shape[axis], -M, axis=axis) + max_coordinate = np.insert(max_coordinate, max_coordinate.shape[axis], M, axis=axis) else: indices = [axes_bb.index(ax) for ax in axes_out_without_c] min_coordinate = np.take(min_coordinate, indices, axis=axis) From 5bf0b43e1756e8db4f273ae2ef2c7e5917853d49 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 14:10:07 -0700 Subject: [PATCH 06/20] add comment --- src/spatialdata/_core/query/spatial_query.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 6c12b792c..00abbba3f 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -230,6 +230,7 @@ def _adjust_bounding_box_to_real_axes( The bounding box is defined by the user and its axes may not coincide with the axes of the transformation. """ + # axis for slicing, if axis > 0, then the min_/max_coordinate multiple bounding boxes along axis 0 axis = min_coordinate.ndim - 1 if set(axes_bb) != set(axes_out_without_c): axes_only_in_bb = set(axes_bb) - set(axes_out_without_c) From a60bf6f3ee1f645fb3bef81d4f4ff4751633b0f0 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 15:23:26 -0700 Subject: [PATCH 07/20] vectorize --- src/spatialdata/_core/query/_utils.py | 67 +++++++++++++++++---------- 1 file changed, 43 insertions(+), 24 deletions(-) diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py index 3b63470ed..c79e31d73 100644 --- a/src/spatialdata/_core/query/_utils.py +++ b/src/spatialdata/_core/query/_utils.py @@ -2,6 +2,7 @@ from typing import Any +import numpy as np from anndata import AnnData from xarray import DataArray @@ -36,37 +37,55 @@ def get_bounding_box_corners( min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) - if len(min_coordinate) not in (2, 3): + if min_coordinate.ndim == 1: + min_coordinate = min_coordinate[np.newaxis, :] + max_coordinate = max_coordinate[np.newaxis, :] + + if min_coordinate.shape[1] not in (2, 3): raise ValueError("bounding box must be 2D or 3D") - if len(min_coordinate) == 2: + num_boxes = min_coordinate.shape[0] + num_dims = min_coordinate.shape[1] + + if num_dims == 2: # 2D bounding box assert len(axes) == 2 - return DataArray( + corners = np.array( [ - [min_coordinate[0], min_coordinate[1]], - [min_coordinate[0], max_coordinate[1]], - [max_coordinate[0], max_coordinate[1]], - [max_coordinate[0], min_coordinate[1]], - ], - coords={"corner": range(4), "axis": list(axes)}, + [min_coordinate[:, 0], min_coordinate[:, 1]], + [min_coordinate[:, 0], max_coordinate[:, 1]], + [max_coordinate[:, 0], max_coordinate[:, 1]], + [max_coordinate[:, 0], min_coordinate[:, 1]], + ] ) - - # 3D bounding cube - assert len(axes) == 3 - return DataArray( - [ - [min_coordinate[0], min_coordinate[1], min_coordinate[2]], - [min_coordinate[0], min_coordinate[1], max_coordinate[2]], - [min_coordinate[0], max_coordinate[1], max_coordinate[2]], - [min_coordinate[0], max_coordinate[1], min_coordinate[2]], - [max_coordinate[0], min_coordinate[1], min_coordinate[2]], - [max_coordinate[0], min_coordinate[1], max_coordinate[2]], - [max_coordinate[0], max_coordinate[1], max_coordinate[2]], - [max_coordinate[0], max_coordinate[1], min_coordinate[2]], - ], - coords={"corner": range(8), "axis": list(axes)}, + corners = np.transpose(corners, (2, 0, 1)) + else: + # 3D bounding cube + assert len(axes) == 3 + corners = np.array( + [ + [min_coordinate[:, 0], min_coordinate[:, 1], min_coordinate[:, 2]], + [min_coordinate[:, 0], min_coordinate[:, 1], max_coordinate[:, 2]], + [min_coordinate[:, 0], max_coordinate[:, 1], max_coordinate[:, 2]], + [min_coordinate[:, 0], max_coordinate[:, 1], min_coordinate[:, 2]], + [max_coordinate[:, 0], min_coordinate[:, 1], min_coordinate[:, 2]], + [max_coordinate[:, 0], min_coordinate[:, 1], max_coordinate[:, 2]], + [max_coordinate[:, 0], max_coordinate[:, 1], max_coordinate[:, 2]], + [max_coordinate[:, 0], max_coordinate[:, 1], min_coordinate[:, 2]], + ] + ) + corners = np.transpose(corners, (2, 0, 1)) + output = DataArray( + corners, + coords={ + "box": range(num_boxes), + "corner": range(corners.shape[1]), + "axis": list(axes), + }, ) + if num_boxes > 1: + return output + return output.squeeze().drop_vars("box") def _get_filtered_or_unfiltered_tables( From 017967b57e384974f1b3a5e0ff549f8f81039ddf Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 16:19:32 -0700 Subject: [PATCH 08/20] update to handle multiple boxes --- src/spatialdata/_core/query/spatial_query.py | 67 +++++++++++++++----- 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 00abbba3f..b575d67a8 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -120,10 +120,18 @@ def _get_bounding_box_corners_in_intrinsic_coordinates( intrinsic_bounding_box_corners = bounding_box_corners.data @ rotation_matrix.T + translation + if bounding_box_corners.ndim > 2: # multiple boxes + coords = { + "box": range(len(bounding_box_corners)), + "corner": range(len(bounding_box_corners)), + "axis": list(inverse.output_axes), + } + else: + coords = {"corner": range(len(bounding_box_corners)), "axis": list(inverse.output_axes)} return ( DataArray( intrinsic_bounding_box_corners, - coords={"corner": range(len(bounding_box_corners)), "axis": list(inverse.output_axes)}, + coords=coords, ), input_axes_without_c, ) @@ -534,22 +542,47 @@ def _( # build the request: now that we have the bounding box corners in the intrinsic coordinate system, we can use them # to build the request to query the raster data using the xarray APIs - selection = {} - translation_vector = [] - for axis_name in axes: - # get the min value along the axis - min_value = intrinsic_bounding_box_corners.sel(axis=axis_name).min().item() - - # get max value, slices are open half interval - max_value = intrinsic_bounding_box_corners.sel(axis=axis_name).max().item() - - # add the - selection[axis_name] = slice(min_value, max_value) - - if min_value > 0: - translation_vector.append(np.ceil(min_value).item()) - else: - translation_vector.append(0) + # selection = {} + # translation_vector = [] + # for axis_name in axes: + # # get the min value along the axis + # min_value = intrinsic_bounding_box_corners.sel(axis=axis_name).min().item() + + # # get max value, slices are open half interval + # max_value = intrinsic_bounding_box_corners.sel(axis=axis_name).max().item() + + # # add the + # selection[axis_name] = slice(min_value, max_value) + + # if min_value > 0: + # translation_vector.append(np.ceil(min_value).item()) + # else: + # translation_vector.append(0) + + min_values = intrinsic_bounding_box_corners.min(dim="corner") + max_values = intrinsic_bounding_box_corners.max(dim="corner") + + # Convert to numpy arrays for faster operations + min_values_np = min_values.values + max_values_np = max_values.values + + if min_values.ndim == 2: # Multiple boxes + slices = np.array( + [ + [slice(min_val, max_val) for min_val, max_val in zip(box_min, box_max)] + for box_min, box_max in zip(min_values_np, max_values_np) + ] + ) + translation_vectors = np.ceil(np.maximum(min_values_np, 0)) + selection: list[dict[str, Any]] | dict[str, Any] = [ + {axis: slices[box_idx, axis_idx] for axis_idx, axis in enumerate(axes)} + for box_idx in range(len(min_values_np)) + ] + translation_vector = translation_vectors.tolist() + else: # Single box + slices = np.array([slice(min_val, max_val) for min_val, max_val in zip(min_values_np, max_values_np)]) + translation_vector = np.ceil(np.maximum(min_values_np, 0)).tolist() + selection = {axis: slices[axis_idx] for axis_idx, axis in enumerate(axes)} if return_request_only: return selection From ab774b7912cb3eb7f5cfb93de41df8fe36c90485 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 16:35:46 -0700 Subject: [PATCH 09/20] vectorize with numba --- src/spatialdata/_core/query/spatial_query.py | 67 ++++++++++---------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index b575d67a8..dd0e8ffa7 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -9,6 +9,7 @@ import dask.array as da import dask.dataframe as dd +import numba as nb import numpy as np from dask.dataframe import DataFrame as DaskDataFrame from datatree import DataTree @@ -44,6 +45,24 @@ ) +@nb.njit(parallel=False, nopython=True) +def create_slices_and_translation( + min_values: nb.types.Array[nb.float64, nb.float64], + max_values: nb.types.Array[nb.float64, nb.float64], +) -> tuple[nb.types.Array[nb.float64, nb.float64], nb.types.Array[nb.float64, nb.float64]]: + n_boxes, n_dims = min_values.shape + slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64) # (n_boxes, n_dims, [min, max]) + translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64) # (n_boxes, n_dims) + + for i in range(n_boxes): + for j in range(n_dims): + slices[i, j, 0] = min_values[i, j] + slices[i, j, 1] = max_values[i, j] + translation_vectors[i, j] = np.ceil(max(min_values[i, j], 0)) + + return slices, translation_vectors + + def _get_bounding_box_corners_in_intrinsic_coordinates( element: SpatialElement, axes: tuple[str, ...], @@ -540,49 +559,30 @@ def _( if TYPE_CHECKING: assert isinstance(intrinsic_bounding_box_corners, DataArray) - # build the request: now that we have the bounding box corners in the intrinsic coordinate system, we can use them - # to build the request to query the raster data using the xarray APIs - # selection = {} - # translation_vector = [] - # for axis_name in axes: - # # get the min value along the axis - # min_value = intrinsic_bounding_box_corners.sel(axis=axis_name).min().item() - - # # get max value, slices are open half interval - # max_value = intrinsic_bounding_box_corners.sel(axis=axis_name).max().item() - - # # add the - # selection[axis_name] = slice(min_value, max_value) - - # if min_value > 0: - # translation_vector.append(np.ceil(min_value).item()) - # else: - # translation_vector.append(0) - min_values = intrinsic_bounding_box_corners.min(dim="corner") max_values = intrinsic_bounding_box_corners.max(dim="corner") - # Convert to numpy arrays for faster operations - min_values_np = min_values.values - max_values_np = max_values.values + min_values_np = min_values.data + max_values_np = max_values.data + + if min_values_np.ndim == 1: + min_values_np = min_values_np[np.newaxis, :] + max_values_np = max_values_np[np.newaxis, :] + + slices, translation_vectors = create_slices_and_translation(min_values_np, max_values_np) if min_values.ndim == 2: # Multiple boxes - slices = np.array( - [ - [slice(min_val, max_val) for min_val, max_val in zip(box_min, box_max)] - for box_min, box_max in zip(min_values_np, max_values_np) - ] - ) - translation_vectors = np.ceil(np.maximum(min_values_np, 0)) selection: list[dict[str, Any]] | dict[str, Any] = [ - {axis: slices[box_idx, axis_idx] for axis_idx, axis in enumerate(axes)} + { + axis: slice(slices[box_idx, axis_idx, 0], slices[box_idx, axis_idx, 1]) + for axis_idx, axis in enumerate(axes) + } for box_idx in range(len(min_values_np)) ] translation_vector = translation_vectors.tolist() else: # Single box - slices = np.array([slice(min_val, max_val) for min_val, max_val in zip(min_values_np, max_values_np)]) - translation_vector = np.ceil(np.maximum(min_values_np, 0)).tolist() - selection = {axis: slices[axis_idx] for axis_idx, axis in enumerate(axes)} + selection = {axis: slice(slices[0, axis_idx, 0], slices[0, axis_idx, 1]) for axis_idx, axis in enumerate(axes)} + translation_vector = translation_vectors[0].tolist() if return_request_only: return selection @@ -858,7 +858,6 @@ def _( images: bool = True, labels: bool = True, ) -> SpatialData: - _check_deprecated_kwargs({"shapes": shapes, "points": points, "images": images, "labels": labels}) new_elements = {} for element_type in ["points", "images", "labels", "shapes"]: From 38dba2528a1f591bc9f52836b2c9cec49664d3df Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 16:56:59 -0700 Subject: [PATCH 10/20] fix corner len --- src/spatialdata/_core/query/spatial_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index dd0e8ffa7..792248f32 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -142,7 +142,7 @@ def _get_bounding_box_corners_in_intrinsic_coordinates( if bounding_box_corners.ndim > 2: # multiple boxes coords = { "box": range(len(bounding_box_corners)), - "corner": range(len(bounding_box_corners)), + "corner": range(bounding_box_corners.shape[1]), "axis": list(inverse.output_axes), } else: From b27607e91962123c01b1d91df7014fad05010721 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 17:02:37 -0700 Subject: [PATCH 11/20] update --- src/spatialdata/_core/query/spatial_query.py | 12 ++++----- src/spatialdata/dataloader/datasets.py | 28 +++++++++++++------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 792248f32..0107b49b4 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -546,12 +546,12 @@ def _( max_coordinate = _parse_list_into_array(max_coordinate) # for triggering validation - _ = BoundingBoxRequest( - target_coordinate_system=target_coordinate_system, - axes=axes, - min_coordinate=min_coordinate, - max_coordinate=max_coordinate, - ) + # _ = BoundingBoxRequest( + # target_coordinate_system=target_coordinate_system, + # axes=axes, + # min_coordinate=min_coordinate, + # max_coordinate=max_coordinate, + # ) intrinsic_bounding_box_corners, axes = _get_bounding_box_corners_in_intrinsic_coordinates( image, axes, min_coordinate, max_coordinate, target_coordinate_system diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 6b7ec7ae9..f6687d6b9 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -251,17 +251,25 @@ def _preprocess( tile_scale=tile_scale, tile_dim_in_units=tile_dim_in_units, ) - tile_coords["selection"] = tile_coords.apply( - lambda row, cs=cs, image_name=image_name: bounding_box_query( - self.sdata[image_name], - ("x", "y"), - min_coordinate=row[["minx", "miny"]].values, - max_coordinate=row[["maxx", "maxy"]].values, - target_coordinate_system=cs, - return_request_only=True, - ), - axis=1, + tile_coords["selection"] = bounding_box_query( + self.sdata[image_name], + ("x", "y"), + min_coordinate=tile_coords[["minx", "miny"]].values, + max_coordinate=tile_coords[["maxx", "maxy"]].values, + target_coordinate_system=cs, + return_request_only=True, ) + # tile_coords["selection"] = tile_coords.apply( + # lambda row, cs=cs, image_name=image_name: bounding_box_query( + # self.sdata[image_name], + # ("x", "y"), + # min_coordinate=row[["minx", "miny"]].values, + # max_coordinate=row[["maxx", "maxy"]].values, + # target_coordinate_system=cs, + # return_request_only=True, + # ), + # axis=1, + # ) tile_coords_df.append(tile_coords) inst = circles.index.values From a934e21a0421c925580a4319148e49f45f2e8bfd Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 17:16:30 -0700 Subject: [PATCH 12/20] fix validation --- src/spatialdata/_core/query/spatial_query.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 792248f32..fb363121d 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -380,7 +380,12 @@ def __post_init__(self) -> None: raise ValueError(f"Non-spatial axes specified: {non_spatial_axes}") # validate the axes - if len(self.axes) != len(self.min_coordinate) or len(self.axes) != len(self.max_coordinate): + if self.min_coordinate.shape != self.max_coordinate.shape: + raise ValueError("The `min_coordinate` and `max_coordinate` must have the same shape.") + + n_axes_coordinate = len(self.min_coordinate) if self.min_coordinate.ndim == 1 else self.min_coordinate.shape[1] + + if len(self.axes) != n_axes_coordinate: raise ValueError("The number of axes must match the number of coordinates.") # validate the coordinates From 77f73f471fbd61eac0c00cf96652d9ac76f0af95 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 3 Sep 2024 14:26:32 -0700 Subject: [PATCH 13/20] refactor --- src/spatialdata/_core/query/spatial_query.py | 68 +++++++++++--------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index fb363121d..260b22896 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -528,6 +528,40 @@ def _( return SpatialData(**new_elements, tables=tables) +def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None: + d = {} + for k, data_tree in query_result.items(): + v = data_tree.values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + if 0 in xdata.shape: + if k == "scale0": + return None + else: + d[k] = xdata + + # Remove scales after finding a missing scale + scales_to_keep = [] + for i, scale_name in enumerate(d.keys()): + if scale_name == f"scale{i}": + scales_to_keep.append(scale_name) + else: + break + + # Case in which scale0 is not present but other scales are + if len(scales_to_keep) == 0: + return None + + d = {k: d[k] for k in scales_to_keep} + result = DataTree.from_dict(d) + + # Rechunk the data to avoid irregular chunks + for scale in result: + result[scale]["image"] = result[scale]["image"].chunk("auto") + + return result + + @bounding_box_query.register(DataArray) @bounding_box_query.register(DataTree) def _( @@ -593,46 +627,20 @@ def _( return selection # query the data - query_result = image.sel(selection) + query_result = image.sel(selection) if isinstance(selection, dict) else [image.sel(sel) for sel in selection] if isinstance(image, DataArray): if 0 in query_result.shape: return None assert isinstance(query_result, DataArray) # rechunk the data to avoid irregular chunks - image = image.chunk("auto") + query_result = query_result.chunk("auto") else: assert isinstance(image, DataTree) assert isinstance(query_result, DataTree) - - d = {} - for k, data_tree in query_result.items(): - v = data_tree.values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - if 0 in xdata.shape: - if k == "scale0": - return None - else: - d[k] = xdata - # the list of scales may not be contiguous when the data has small shape (for instance with yx = 22 and - # rotations we may end up having scale0 and scale2 but not scale1. Practically this may occur in torch tiler if - # the tiles are request to be too small). - # Here we remove scales after we found a scale missing - scales_to_keep = [] - for i, scale_name in enumerate(d.keys()): - if scale_name == f"scale{i}": - scales_to_keep.append(scale_name) - else: - break - # case in which scale0 is not present but other scales are - if len(scales_to_keep) == 0: + query_result = _process_data_tree_query_result(query_result) + if query_result is None: return None - d = {k: d[k] for k in scales_to_keep} - query_result = DataTree.from_dict(d) - # rechunk the data to avoid irregular chunks - for scale in query_result: - query_result[scale]["image"] = query_result[scale]["image"].chunk("auto") query_result = compute_coordinates(query_result) # the bounding box, mapped back to the intrinsic coordinate system is a set of points. The bounding box of these From 3adfea8655d5a80f55d34b092db40dc3cddcfd3b Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 3 Sep 2024 14:42:53 -0700 Subject: [PATCH 14/20] refactor --- src/spatialdata/_core/query/_utils.py | 99 +++++++++++++++++ src/spatialdata/_core/query/spatial_query.py | 108 +++---------------- 2 files changed, 113 insertions(+), 94 deletions(-) diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py index c79e31d73..e45f19b44 100644 --- a/src/spatialdata/_core/query/_utils.py +++ b/src/spatialdata/_core/query/_utils.py @@ -2,14 +2,22 @@ from typing import Any +import numba as nb import numpy as np from anndata import AnnData +from datatree import DataTree from xarray import DataArray from spatialdata._core._elements import Tables from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata._utils import Number, _parse_list_into_array +from spatialdata.transformations._utils import compute_coordinates +from spatialdata.transformations.transformations import ( + BaseTransformation, + Sequence, + Translation, +) def get_bounding_box_corners( @@ -88,6 +96,97 @@ def get_bounding_box_corners( return output.squeeze().drop_vars("box") +@nb.njit(parallel=False, nopython=True) +def _create_slices_and_translation( + min_values: nb.types.Array[nb.float64, nb.float64], + max_values: nb.types.Array[nb.float64, nb.float64], +) -> tuple[nb.types.Array[nb.float64, nb.float64], nb.types.Array[nb.float64, nb.float64]]: + n_boxes, n_dims = min_values.shape + slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64) # (n_boxes, n_dims, [min, max]) + translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64) # (n_boxes, n_dims) + + for i in range(n_boxes): + for j in range(n_dims): + slices[i, j, 0] = min_values[i, j] + slices[i, j, 1] = max_values[i, j] + translation_vectors[i, j] = np.ceil(max(min_values[i, j], 0)) + + return slices, translation_vectors + + +def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None: + d = {} + for k, data_tree in query_result.items(): + v = data_tree.values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + if 0 in xdata.shape: + if k == "scale0": + return None + else: + d[k] = xdata + + # Remove scales after finding a missing scale + scales_to_keep = [] + for i, scale_name in enumerate(d.keys()): + if scale_name == f"scale{i}": + scales_to_keep.append(scale_name) + else: + break + + # Case in which scale0 is not present but other scales are + if len(scales_to_keep) == 0: + return None + + d = {k: d[k] for k in scales_to_keep} + result = DataTree.from_dict(d) + + # Rechunk the data to avoid irregular chunks + for scale in result: + result[scale]["image"] = result[scale]["image"].chunk("auto") + + return result + + +def _process_query_result( + result: DataArray | DataTree, translation_vector: ArrayLike, axes: tuple[str, ...] +) -> DataArray | DataTree | None: + from spatialdata.transformations import get_transformation, set_transformation + + if isinstance(result, DataArray): + if 0 in result.shape: + return None + # rechunk the data to avoid irregular chunks + result = result.chunk("auto") + elif isinstance(result, DataTree): + result = _process_data_tree_query_result(result) + if result is None: + return None + + result = compute_coordinates(result) + + if not np.allclose(np.array(translation_vector), 0): + translation_transform = Translation(translation=translation_vector, axes=axes) + + transformations = get_transformation(result, get_all=True) + assert isinstance(transformations, dict) + + new_transformations = {} + for coordinate_system, initial_transform in transformations.items(): + new_transformation: BaseTransformation = Sequence( + [translation_transform, initial_transform], + ) + new_transformations[coordinate_system] = new_transformation + set_transformation(result, new_transformations, set_all=True) + + # let's make a copy of the transformations so that we don't modify the original object + t = get_transformation(result, get_all=True) + assert isinstance(t, dict) + set_transformation(result, t.copy(), set_all=True) + + return result + + def _get_filtered_or_unfiltered_tables( filter_table: bool, elements: dict[str, Any], sdata: SpatialData ) -> dict[str, AnnData] | Tables: diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 260b22896..77c4a934b 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -9,7 +9,6 @@ import dask.array as da import dask.dataframe as dd -import numba as nb import numpy as np from dask.dataframe import DataFrame as DaskDataFrame from datatree import DataTree @@ -34,35 +33,14 @@ points_geopandas_to_dask_dataframe, ) from spatialdata.models._utils import ValidAxis_t, get_spatial_axes -from spatialdata.transformations._utils import compute_coordinates from spatialdata.transformations.operations import set_transformation from spatialdata.transformations.transformations import ( Affine, BaseTransformation, - Sequence, - Translation, _get_affine_for_element, ) -@nb.njit(parallel=False, nopython=True) -def create_slices_and_translation( - min_values: nb.types.Array[nb.float64, nb.float64], - max_values: nb.types.Array[nb.float64, nb.float64], -) -> tuple[nb.types.Array[nb.float64, nb.float64], nb.types.Array[nb.float64, nb.float64]]: - n_boxes, n_dims = min_values.shape - slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64) # (n_boxes, n_dims, [min, max]) - translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64) # (n_boxes, n_dims) - - for i in range(n_boxes): - for j in range(n_dims): - slices[i, j, 0] = min_values[i, j] - slices[i, j, 1] = max_values[i, j] - translation_vectors[i, j] = np.ceil(max(min_values[i, j], 0)) - - return slices, translation_vectors - - def _get_bounding_box_corners_in_intrinsic_coordinates( element: SpatialElement, axes: tuple[str, ...], @@ -528,40 +506,6 @@ def _( return SpatialData(**new_elements, tables=tables) -def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None: - d = {} - for k, data_tree in query_result.items(): - v = data_tree.values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - if 0 in xdata.shape: - if k == "scale0": - return None - else: - d[k] = xdata - - # Remove scales after finding a missing scale - scales_to_keep = [] - for i, scale_name in enumerate(d.keys()): - if scale_name == f"scale{i}": - scales_to_keep.append(scale_name) - else: - break - - # Case in which scale0 is not present but other scales are - if len(scales_to_keep) == 0: - return None - - d = {k: d[k] for k in scales_to_keep} - result = DataTree.from_dict(d) - - # Rechunk the data to avoid irregular chunks - for scale in result: - result[scale]["image"] = result[scale]["image"].chunk("auto") - - return result - - @bounding_box_query.register(DataArray) @bounding_box_query.register(DataTree) def _( @@ -579,7 +523,7 @@ def _( See https://github.com/scverse/spatialdata/pull/151 for a detailed overview of the logic of this code, and for the cases the comments refer to. """ - from spatialdata.transformations import get_transformation, set_transformation + from spatialdata._core.query._utils import _create_slices_and_translation, _process_query_result min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) @@ -608,7 +552,7 @@ def _( min_values_np = min_values_np[np.newaxis, :] max_values_np = max_values_np[np.newaxis, :] - slices, translation_vectors = create_slices_and_translation(min_values_np, max_values_np) + slices, translation_vectors = _create_slices_and_translation(min_values_np, max_values_np) if min_values.ndim == 2: # Multiple boxes selection: list[dict[str, Any]] | dict[str, Any] = [ @@ -627,43 +571,19 @@ def _( return selection # query the data - query_result = image.sel(selection) if isinstance(selection, dict) else [image.sel(sel) for sel in selection] - if isinstance(image, DataArray): - if 0 in query_result.shape: - return None - assert isinstance(query_result, DataArray) - # rechunk the data to avoid irregular chunks - query_result = query_result.chunk("auto") + query_result: DataArray | DataTree | list[DataArray | DataTree] = ( + image.sel(selection) if isinstance(selection, dict) else [image.sel(sel) for sel in selection] + ) + + if isinstance(query_result, list): + processed_results = [] + for result in query_result: + processed_result = _process_query_result(result, translation_vector, axes) + if processed_result is not None: + processed_results.append(processed_result) + query_result = processed_results if processed_results else None else: - assert isinstance(image, DataTree) - assert isinstance(query_result, DataTree) - query_result = _process_data_tree_query_result(query_result) - if query_result is None: - return None - - query_result = compute_coordinates(query_result) - - # the bounding box, mapped back to the intrinsic coordinate system is a set of points. The bounding box of these - # points is likely starting away from the origin (this is described by translation_vector), so we need to prepend - # this translation to every transformation in the new queries elements (unless the translation_vector is zero, - # in that case the translation is not needed) - if not np.allclose(np.array(translation_vector), 0): - translation_transform = Translation(translation=translation_vector, axes=axes) - - transformations = get_transformation(query_result, get_all=True) - assert isinstance(transformations, dict) - - new_transformations = {} - for coordinate_system, initial_transform in transformations.items(): - new_transformation: BaseTransformation = Sequence( - [translation_transform, initial_transform], - ) - new_transformations[coordinate_system] = new_transformation - set_transformation(query_result, new_transformations, set_all=True) - # let's make a copy of the transformations so that we don't modify the original object - t = get_transformation(query_result, get_all=True) - assert isinstance(t, dict) - set_transformation(query_result, t.copy(), set_all=True) + query_result = _process_query_result(query_result, translation_vector, axes) return query_result From dfdfdbfef7e7ae47a507d9c87ca49e2907c0af42 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 3 Sep 2024 15:03:26 -0700 Subject: [PATCH 15/20] add test for query with multiple bounding boxes --- src/spatialdata/_core/query/spatial_query.py | 8 +- tests/core/query/test_spatial_query.py | 82 ++++++++++++++------ 2 files changed, 63 insertions(+), 27 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 77c4a934b..36728adfb 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -562,10 +562,10 @@ def _( } for box_idx in range(len(min_values_np)) ] - translation_vector = translation_vectors.tolist() + translation_vectors = translation_vectors.tolist() else: # Single box selection = {axis: slice(slices[0, axis_idx, 0], slices[0, axis_idx, 1]) for axis_idx, axis in enumerate(axes)} - translation_vector = translation_vectors[0].tolist() + translation_vectors = translation_vectors[0].tolist() if return_request_only: return selection @@ -577,13 +577,13 @@ def _( if isinstance(query_result, list): processed_results = [] - for result in query_result: + for result, translation_vector in zip(query_result, translation_vectors): processed_result = _process_query_result(result, translation_vector, axes) if processed_result is not None: processed_results.append(processed_result) query_result = processed_results if processed_results else None else: - query_result = _process_query_result(query_result, translation_vector, axes) + query_result = _process_query_result(query_result, translation_vectors, axes) return query_result diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 9444d8e9e..5e4482a81 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -192,8 +192,15 @@ def test_query_points_no_points(): @pytest.mark.parametrize("is_bb_3d", [True, False]) @pytest.mark.parametrize("with_polygon_query", [True, False]) @pytest.mark.parametrize("return_request_only", [True, False]) +@pytest.mark.parametrize("multiple_boxes", [True, False]) def test_query_raster( - n_channels: int, is_labels: bool, is_3d: bool, is_bb_3d: bool, with_polygon_query: bool, return_request_only: bool + n_channels: int, + is_labels: bool, + is_3d: bool, + is_bb_3d: bool, + with_polygon_query: bool, + return_request_only: bool, + multiple_boxes: bool, ): """Apply a bounding box to a raster element.""" if is_labels and n_channels > 1: @@ -232,16 +239,16 @@ def test_query_raster( for image in images: if is_bb_3d: - _min_coordinate = np.array([2, 5, 0]) - _max_coordinate = np.array([7, 10, 5]) + _min_coordinate = np.array([[2, 5, 0], [1, 4, 0]]) if multiple_boxes else np.array([2, 5, 0]) + _max_coordinate = np.array([[7, 10, 5], [6, 9, 4]]) if multiple_boxes else np.array([7, 10, 5]) _axes = ("z", "y", "x") else: - _min_coordinate = np.array([5, 0]) - _max_coordinate = np.array([10, 5]) + _min_coordinate = np.array([[5, 0], [4, 0]]) if multiple_boxes else np.array([5, 0]) + _max_coordinate = np.array([[10, 5], [9, 4]]) if multiple_boxes else np.array([10, 5]) _axes = ("y", "x") if with_polygon_query: - if is_bb_3d: + if is_bb_3d or multiple_boxes: return # make a triangle whose bounding box is the same as the bounding box specified with the query polygon = Polygon([(0, 5), (5, 5), (5, 10)]) @@ -258,29 +265,58 @@ def test_query_raster( return_request_only=return_request_only, ) - slices = {"y": slice(5, 10), "x": slice(0, 5)} - if is_bb_3d and is_3d: - slices["z"] = slice(2, 7) + if multiple_boxes: + slices = [{"y": slice(5, 10), "x": slice(0, 5)}, {"y": slice(4, 9), "x": slice(0, 4)}] + if is_bb_3d and is_3d: + slices[0]["z"] = slice(2, 7) + slices[1]["z"] = slice(1, 6) + else: + slices = {"y": slice(5, 10), "x": slice(0, 5)} + if is_bb_3d and is_3d: + slices["z"] = slice(2, 7) + if return_request_only: - assert isinstance(image_result, dict) - if not (is_bb_3d and is_3d) and ("z" in image_result): - image_result.pop("z") # remove z from slices if `polygon_query` - for k, v in image_result.items(): - assert isinstance(v, slice) - assert image_result[k] == slices[k] + assert isinstance(image_result, (dict, list)) + if multiple_boxes: + for i, result in enumerate(image_result): + if not (is_bb_3d and is_3d) and ("z" in result): + result.pop("z") # remove z from slices if `polygon_query` + for k, v in result.items(): + assert isinstance(v, slice) + assert result[k] == slices[i][k] + else: + if not (is_bb_3d and is_3d) and ("z" in image_result): + image_result.pop("z") # remove z from slices if `polygon_query` + for k, v in image_result.items(): + assert isinstance(v, slice) + assert image_result[k] == slices[k] return - expected_image = ximage.sel(**slices) + if multiple_boxes: + expected_images = [ximage.sel(**s) for s in slices] + else: + expected_image = ximage.sel(**slices) if isinstance(image, DataArray): - assert isinstance(image, DataArray) - np.testing.assert_allclose(image_result, expected_image) + assert isinstance(image_result, (DataArray, list)) + if multiple_boxes: + for result, expected in zip(image_result, expected_images): + np.testing.assert_allclose(result, expected) + else: + np.testing.assert_allclose(image_result, expected_image) elif isinstance(image, DataTree): - assert isinstance(image_result, DataTree) - v = image_result["scale0"].values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - np.testing.assert_allclose(xdata, expected_image) + assert isinstance(image_result, (DataTree, list)) + if multiple_boxes: + for result, expected in zip(image_result, expected_images): + v = result["scale0"].values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + np.testing.assert_allclose(xdata, expected) + else: + v = image_result["scale0"].values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + np.testing.assert_allclose(xdata, expected_image) else: raise ValueError("Unexpected type") From 5c5560d517471ff6f67eb9b01dda2d58f02aeaa4 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 3 Sep 2024 15:08:50 -0700 Subject: [PATCH 16/20] fix typing --- src/spatialdata/_core/query/spatial_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 36728adfb..dc85556a2 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -515,7 +515,7 @@ def _( max_coordinate: list[Number] | ArrayLike, target_coordinate_system: str, return_request_only: bool = False, -) -> DataArray | DataTree | Mapping[str, slice] | None: +) -> DataArray | DataTree | Mapping[str, slice] | list[DataArray | DataTree] | None: """Implement bounding box query for Spatialdata supported DataArray. Notes From dd2c573d61932b06110dec07bc447f1496310c6a Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 3 Sep 2024 16:09:09 -0700 Subject: [PATCH 17/20] vectorize bounding box query on polygons --- src/spatialdata/_core/query/spatial_query.py | 34 ++++++++++++++------ tests/core/query/test_spatial_query.py | 21 ++++++++---- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index dc85556a2..9cb78b64d 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -673,7 +673,7 @@ def _( min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, target_coordinate_system: str, -) -> GeoDataFrame | None: +) -> GeoDataFrame | list[GeoDataFrame] | None: from spatialdata.transformations import get_transformation min_coordinate = _parse_list_into_array(min_coordinate) @@ -695,16 +695,32 @@ def _( max_coordinate=max_coordinate, target_coordinate_system=target_coordinate_system, ) - intrinsic_bounding_box_corners = intrinsic_bounding_box_corners.data - bounding_box_non_axes_aligned = Polygon(intrinsic_bounding_box_corners) - indices = polygons.geometry.intersects(bounding_box_non_axes_aligned) - queried = polygons[indices] - if len(queried) == 0: - return None + + # Create a list of Polygons for each bounding box old_transformations = get_transformation(polygons, get_all=True) assert isinstance(old_transformations, dict) - del queried.attrs[ShapesModel.TRANSFORM_KEY] - return ShapesModel.parse(queried, transformations=old_transformations.copy()) + + queried_polygons = [] + intrinsic_bounding_box_corners = ( + intrinsic_bounding_box_corners.expand_dims(dim="box") + if "box" not in intrinsic_bounding_box_corners.dims + else intrinsic_bounding_box_corners + ) + for box_corners in intrinsic_bounding_box_corners: + bounding_box_non_axes_aligned = Polygon(box_corners.data) + indices = polygons.geometry.intersects(bounding_box_non_axes_aligned) + queried = polygons[indices] + if len(queried) == 0: + queried_polygon = None + else: + del queried.attrs[ShapesModel.TRANSFORM_KEY] + queried_polygon = ShapesModel.parse(queried, transformations=old_transformations.copy()) + queried_polygons.append(queried_polygon) + if len(queried_polygons) == 0: + return None + if len(queried_polygons) == 1: + return queried_polygons[0] + return queried_polygons # TODO: we can replace the manually triggered deprecation warning heres with the decorator from Wouter diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 5e4482a81..c80270171 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -323,7 +323,8 @@ def test_query_raster( @pytest.mark.parametrize("is_bb_3d", [True, False]) @pytest.mark.parametrize("with_polygon_query", [True, False]) -def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool): +@pytest.mark.parametrize("multiple_boxes", [True, False]) +def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool, multiple_boxes: bool): centroids = np.array([[10, 10], [10, 80], [80, 20], [70, 60]]) half_widths = [6] * 4 sd_polygons = _make_squares(centroid_coordinates=centroids, half_widths=half_widths) @@ -339,12 +340,12 @@ def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool): ) else: if is_bb_3d: - _min_coordinate = np.array([2, 40, 40]) - _max_coordinate = np.array([7, 100, 100]) + _min_coordinate = np.array([[2, 40, 40], [2, 50, 50]]) if multiple_boxes else np.array([2, 40, 40]) + _max_coordinate = np.array([[7, 100, 100], [7, 110, 110]]) if multiple_boxes else np.array([7, 100, 100]) _axes = ("z", "y", "x") else: - _min_coordinate = np.array([40, 40]) - _max_coordinate = np.array([100, 100]) + _min_coordinate = np.array([[40, 40], [50, 50]]) if multiple_boxes else np.array([40, 40]) + _max_coordinate = np.array([[100, 100], [110, 110]]) if multiple_boxes else np.array([100, 100]) _axes = ("y", "x") polygons_result = bounding_box_query( @@ -355,8 +356,14 @@ def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool): max_coordinate=_max_coordinate, ) - assert len(polygons_result) == 1 - assert polygons_result.index[0] == 3 + if multiple_boxes and not with_polygon_query: + assert isinstance(polygons_result, list) + assert len(polygons_result) == 2 + assert polygons_result[0].index[0] == 3 + assert len(polygons_result[1]) == 1 + else: + assert len(polygons_result) == 1 + assert polygons_result.index[0] == 3 @pytest.mark.parametrize("is_bb_3d", [True, False]) From be9535834ea0ae1d89597fef0acefb669b6f7b5e Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 3 Sep 2024 16:31:31 -0700 Subject: [PATCH 18/20] add test to cover no polygon overlap (None) --- tests/core/query/test_spatial_query.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index c80270171..e58e3424e 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -324,7 +324,8 @@ def test_query_raster( @pytest.mark.parametrize("is_bb_3d", [True, False]) @pytest.mark.parametrize("with_polygon_query", [True, False]) @pytest.mark.parametrize("multiple_boxes", [True, False]) -def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool, multiple_boxes: bool): +@pytest.mark.parametrize("box_outside_polygon", [True, False]) +def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool, multiple_boxes: bool, box_outside_polygon: bool): centroids = np.array([[10, 10], [10, 80], [80, 20], [70, 60]]) half_widths = [6] * 4 sd_polygons = _make_squares(centroid_coordinates=centroids, half_widths=half_widths) @@ -342,10 +343,18 @@ def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool, multiple_boxes if is_bb_3d: _min_coordinate = np.array([[2, 40, 40], [2, 50, 50]]) if multiple_boxes else np.array([2, 40, 40]) _max_coordinate = np.array([[7, 100, 100], [7, 110, 110]]) if multiple_boxes else np.array([7, 100, 100]) + if box_outside_polygon: + _min_coordinate = np.array([[2, 100, 100], [2, 50, 50]]) if multiple_boxes else np.array([2, 40, 40]) + _max_coordinate = ( + np.array([[7, 110, 110], [7, 110, 110]]) if multiple_boxes else np.array([7, 100, 100]) + ) _axes = ("z", "y", "x") else: _min_coordinate = np.array([[40, 40], [50, 50]]) if multiple_boxes else np.array([40, 40]) _max_coordinate = np.array([[100, 100], [110, 110]]) if multiple_boxes else np.array([100, 100]) + if box_outside_polygon: + _min_coordinate = np.array([[100, 100], [50, 50]]) if multiple_boxes else np.array([40, 40]) + _max_coordinate = np.array([[110, 110], [110, 110]]) if multiple_boxes else np.array([100, 100]) _axes = ("y", "x") polygons_result = bounding_box_query( @@ -359,8 +368,13 @@ def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool, multiple_boxes if multiple_boxes and not with_polygon_query: assert isinstance(polygons_result, list) assert len(polygons_result) == 2 - assert polygons_result[0].index[0] == 3 - assert len(polygons_result[1]) == 1 + if box_outside_polygon: + + assert polygons_result[0] is None + assert polygons_result[1].index[0] == 3 + else: + assert polygons_result[0].index[0] == 3 + assert len(polygons_result[1]) == 1 else: assert len(polygons_result) == 1 assert polygons_result.index[0] == 3 From fad9b1aa2dc72e1c16985efea8ee51f846cfef0a Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 4 Sep 2024 12:52:35 -0700 Subject: [PATCH 19/20] vectorize bounding box query on points and tests --- src/spatialdata/_core/query/spatial_query.py | 123 +++++++++++++------ tests/core/query/test_spatial_query.py | 72 ++++++++--- 2 files changed, 138 insertions(+), 57 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 9cb78b64d..31e75864c 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -385,7 +385,7 @@ def _bounding_box_mask_points( min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, ) -> da.Array: - """Compute a mask that is true for the points inside an axis-aligned bounding box. + """Compute a mask that is true for the points inside axis-aligned bounding boxes. Parameters ---------- @@ -394,30 +394,42 @@ def _bounding_box_mask_points( axes The axes that min_coordinate and max_coordinate refer to. min_coordinate - The upper left hand corner of the bounding box (i.e., minimum coordinates along all dimensions). + The upper left hand corners of the bounding boxes (i.e., minimum coordinates along all dimensions). + Shape: (n_boxes, n_axes) or (n_axes,) for a single box. max_coordinate - The lower right hand corner of the bounding box (i.e., the maximum coordinates along all dimensions). + The lower right hand corners of the bounding boxes (i.e., the maximum coordinates along all dimensions). + Shape: (n_boxes, n_axes) or (n_axes,) for a single box. Returns ------- - The mask for the points inside the bounding box. + The masks for the points inside the bounding boxes. """ element_axes = get_axes_names(points) + min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) + + # Ensure min_coordinate and max_coordinate are 2D arrays + min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate + max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate + + n_boxes = min_coordinate.shape[0] in_bounding_box_masks = [] - for axis_index, axis_name in enumerate(axes): - if axis_name not in element_axes: - continue - min_value = min_coordinate[axis_index] - in_bounding_box_masks.append(points[axis_name].gt(min_value).to_dask_array(lengths=True)) - for axis_index, axis_name in enumerate(axes): - if axis_name not in element_axes: - continue - max_value = max_coordinate[axis_index] - in_bounding_box_masks.append(points[axis_name].lt(max_value).to_dask_array(lengths=True)) - in_bounding_box_masks = da.stack(in_bounding_box_masks, axis=-1) - return da.all(in_bounding_box_masks, axis=1) + + for box in range(n_boxes): + box_masks = [] + for axis_index, axis_name in enumerate(axes): + if axis_name not in element_axes: + continue + min_value = min_coordinate[box, axis_index] + max_value = max_coordinate[box, axis_index] + box_masks.append( + points[axis_name].gt(min_value).to_dask_array(lengths=True) + & points[axis_name].lt(max_value).to_dask_array(lengths=True) + ) + bounding_box_mask = da.stack(box_masks, axis=-1) + in_bounding_box_masks.append(da.all(bounding_box_mask, axis=1)) + return in_bounding_box_masks def _dict_query_dispatcher( @@ -601,6 +613,10 @@ def _( min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) + # Ensure min_coordinate and max_coordinate are 2D arrays + min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate + max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate + # for triggering validation _ = BoundingBoxRequest( target_coordinate_system=target_coordinate_system, @@ -617,9 +633,11 @@ def _( max_coordinate=max_coordinate, target_coordinate_system=target_coordinate_system, ) - intrinsic_bounding_box_corners = intrinsic_bounding_box_corners.data - min_coordinate_intrinsic = intrinsic_bounding_box_corners.min(axis=0) - max_coordinate_intrinsic = intrinsic_bounding_box_corners.max(axis=0) + min_coordinate_intrinsic = intrinsic_bounding_box_corners.min(dim="corner") + max_coordinate_intrinsic = intrinsic_bounding_box_corners.max(dim="corner") + + min_coordinate_intrinsic = min_coordinate_intrinsic.data + max_coordinate_intrinsic = max_coordinate_intrinsic.data # get the points in the intrinsic coordinate bounding box in_intrinsic_bounding_box = _bounding_box_mask_points( @@ -628,10 +646,20 @@ def _( min_coordinate=min_coordinate_intrinsic, max_coordinate=max_coordinate_intrinsic, ) - # if there aren't any points, just return - if in_intrinsic_bounding_box.sum() == 0: + + # assert that the number of bounding boxes is correct + assert len(in_intrinsic_bounding_box) == len(min_coordinate) + points_in_intrinsic_bounding_box: list[DaskDataFrame | None] = [] + for mask in in_intrinsic_bounding_box: + if mask.sum() == 0: + points_in_intrinsic_bounding_box.append(None) + else: + points_in_intrinsic_bounding_box.append(points.loc[mask]) + if len(points_in_intrinsic_bounding_box) == 0: return None - points_in_intrinsic_bounding_box = points.loc[in_intrinsic_bounding_box] + + # assert that the number of queried points is correct + assert len(points_in_intrinsic_bounding_box) == len(min_coordinate) # # we have to reset the index since we have subset # # https://stackoverflow.com/questions/61395351/how-to-reset-index-on-concatenated-dataframe-in-dask @@ -645,25 +673,42 @@ def _( # points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.drop(columns=["idx"]) # transform the element to the query coordinate system - points_query_coordinate_system = transform( - points_in_intrinsic_bounding_box, to_coordinate_system=target_coordinate_system, maintain_positioning=False - ) # type: ignore[union-attr] + output: list[DaskDataFrame | None] = [] + for p, min_c, max_c in zip(points_in_intrinsic_bounding_box, min_coordinate, max_coordinate): + if p is None: + output.append(None) + else: + points_query_coordinate_system = transform( + p, to_coordinate_system=target_coordinate_system, maintain_positioning=False + ) - # get a mask for the points in the bounding box - bounding_box_mask = _bounding_box_mask_points( - points=points_query_coordinate_system, - axes=axes, - min_coordinate=min_coordinate, - max_coordinate=max_coordinate, - ) - bounding_box_indices = np.where(bounding_box_mask.compute())[0] - if len(bounding_box_indices) == 0: + # get a mask for the points in the bounding box + bounding_box_mask = _bounding_box_mask_points( + points=points_query_coordinate_system, + axes=axes, + min_coordinate=min_c, + max_coordinate=max_c, + ) + if len(bounding_box_mask) == 1: + bounding_box_mask = bounding_box_mask[0] + bounding_box_indices = np.where(bounding_box_mask.compute())[0] + + if len(bounding_box_indices) == 0: + output.append(None) + else: + points_df = p.compute().iloc[bounding_box_indices] + old_transformations = get_transformation(p, get_all=True) + assert isinstance(old_transformations, dict) + output.append( + PointsModel.parse( + dd.from_pandas(points_df, npartitions=1), transformations=old_transformations.copy() + ) + ) + if len(output) == 0: return None - points_df = points_in_intrinsic_bounding_box.compute().iloc[bounding_box_indices] - old_transformations = get_transformation(points, get_all=True) - assert isinstance(old_transformations, dict) - # an alternative approach is to query for each partition in parallel - return PointsModel.parse(dd.from_pandas(points_df, npartitions=1), transformations=old_transformations.copy()) + if len(output) == 1: + return output[0] + return output @bounding_box_query.register(GeoDataFrame) diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index e58e3424e..496bd3e6e 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -108,11 +108,12 @@ def test_bounding_box_request_wrong_coordinate_order(): @pytest.mark.parametrize("is_3d", [True, False]) @pytest.mark.parametrize("is_bb_3d", [True, False]) @pytest.mark.parametrize("with_polygon_query", [True, False]) -def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool): +@pytest.mark.parametrize("multiple_boxes", [True, False]) +def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool, multiple_boxes: bool): """test the points bounding box_query""" - data_x = np.array([10, 20, 20, 20]) - data_y = np.array([10, 20, 30, 30]) - data_z = np.array([100, 200, 200, 300]) + data_x = np.array([10, 20, 20, 20, 40]) + data_y = np.array([10, 20, 30, 30, 50]) + data_z = np.array([100, 200, 200, 300, 500]) data = np.stack((data_x, data_y), axis=1) if is_3d: @@ -125,16 +126,24 @@ def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool): original_z = points_element["z"] if is_bb_3d: - _min_coordinate = np.array([18, 25, 250]) - _max_coordinate = np.array([22, 35, 350]) + if multiple_boxes: + _min_coordinate = np.array([[18, 25, 250], [35, 45, 450], [100, 110, 1100]]) + _max_coordinate = np.array([[22, 35, 350], [45, 55, 550], [110, 120, 1200]]) + else: + _min_coordinate = np.array([18, 25, 250]) + _max_coordinate = np.array([22, 35, 350]) _axes = ("x", "y", "z") else: - _min_coordinate = np.array([18, 25]) - _max_coordinate = np.array([22, 35]) + if multiple_boxes: + _min_coordinate = np.array([[18, 25], [35, 45], [100, 110]]) + _max_coordinate = np.array([[22, 35], [45, 55], [110, 120]]) + else: + _min_coordinate = np.array([18, 25]) + _max_coordinate = np.array([22, 35]) _axes = ("x", "y") if with_polygon_query: - if is_bb_3d: + if is_bb_3d or multiple_boxes: return polygon = Polygon([(18, 25), (18, 35), (22, 35), (22, 25)]) points_result = polygon_query(points_element, polygon=polygon, target_coordinate_system="global") @@ -147,22 +156,49 @@ def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool): target_coordinate_system="global", ) - # Check that the correct point was selected + # Check that the correct points were selected if is_3d: if is_bb_3d: - np.testing.assert_allclose(points_result["x"].compute(), [20]) - np.testing.assert_allclose(points_result["y"].compute(), [30]) - np.testing.assert_allclose(points_result["z"].compute(), [300]) + if multiple_boxes: + np.testing.assert_allclose(points_result[0]["x"].compute(), [20]) + np.testing.assert_allclose(points_result[0]["y"].compute(), [30]) + np.testing.assert_allclose(points_result[0]["z"].compute(), [300]) + np.testing.assert_allclose(points_result[1]["x"].compute(), [40]) + np.testing.assert_allclose(points_result[1]["y"].compute(), [50]) + np.testing.assert_allclose(points_result[1]["z"].compute(), [500]) + else: + np.testing.assert_allclose(points_result["x"].compute(), [20]) + np.testing.assert_allclose(points_result["y"].compute(), [30]) + np.testing.assert_allclose(points_result["z"].compute(), [300]) + else: + if multiple_boxes: + np.testing.assert_allclose(points_result[0]["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result[0]["y"].compute(), [30, 30]) + np.testing.assert_allclose(points_result[0]["z"].compute(), [200, 300]) + np.testing.assert_allclose(points_result[1]["x"].compute(), [40]) + np.testing.assert_allclose(points_result[1]["y"].compute(), [50]) + np.testing.assert_allclose(points_result[1]["z"].compute(), [500]) + else: + np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) + np.testing.assert_allclose(points_result["z"].compute(), [200, 300]) + else: + if multiple_boxes: + np.testing.assert_allclose(points_result[0]["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result[0]["y"].compute(), [30, 30]) + np.testing.assert_allclose(points_result[1]["x"].compute(), [40]) + np.testing.assert_allclose(points_result[1]["y"].compute(), [50]) + assert points_result[2] is None else: np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) - np.testing.assert_allclose(points_result["z"].compute(), [200, 300]) - else: - np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) - np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) # result should be valid points element - PointsModel.validate(points_result) + if multiple_boxes: + for result in points_result: + if result is None: + continue + PointsModel.validate(result) # original element should be unchanged np.testing.assert_allclose(points_element["x"].compute(), original_x) From 9b977d64b43704426406c230663ea7d8294a3ef2 Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 4 Sep 2024 12:53:52 -0700 Subject: [PATCH 20/20] fix type --- src/spatialdata/_core/query/spatial_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 31e75864c..b08e56be1 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -606,7 +606,7 @@ def _( min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, target_coordinate_system: str, -) -> DaskDataFrame | None: +) -> DaskDataFrame | list[DaskDataFrame] | None: from spatialdata import transform from spatialdata.transformations import get_transformation