diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index ee49f8203..ff323664d 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -474,6 +474,7 @@ def _( min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, target_coordinate_system: str, + return_array: bool = False, ) -> SpatialImage | MultiscaleSpatialImage | None: """Implement bounding box query for SpatialImage. @@ -481,6 +482,22 @@ def _( ----- See https://github.com/scverse/spatialdata/pull/151 for a detailed overview of the logic of this code, and for the cases the comments refer to. + + Parameters + ---------- + image + The image to query. + axes + The axes the coordinates are expressed in. + min_coordinate + The upper left hand corner of the bounding box (i.e., minimum coordinates along all dimensions). + max_coordinate + The lower right hand corner of the bounding box (i.e., the maximum coordinates along all dimensions). + target_coordinate_system + The coordinate system the bounding box is defined in. + return_array + If `True`, return the raw array of the query result (its `.data`; for multiscale input, that of `scale0`) instead of parsing it + into a `SpatialImage` or `MultiscaleSpatialImage`. """ from spatialdata.transformations import get_transformation, set_transformation @@ -563,6 +580,8 @@ def _( if 0 in query_result.shape: return None assert isinstance(query_result, SpatialImage) + if return_array: + return query_result.data # rechunk the data to avoid irregular chunks image = image.chunk("auto") else: @@ -580,6 +599,8 @@ def _( return None else: d[k] = xdata + if return_array and k == "scale0": + return xdata.data # the list of scales may not be contiguous when the data has small shape (for instance with yx = 22 and # rotations we may end up having scale0 and scale2 but not scale1. Practically this may occur in torch tiler if # the tiles are request to be too small). 
@@ -820,7 +841,6 @@ def _( images: bool = True, labels: bool = True, ) -> SpatialData: - _check_deprecated_kwargs({"shapes": shapes, "points": points, "images": images, "labels": labels}) new_elements = {} for element_type in ["points", "images", "labels", "shapes"]: diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index b19c22379..c59a5d5da 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -81,6 +81,10 @@ class ImageTilesDataset(Dataset): system; this back-transforms the target tile into the pixel coordinates. If the back-transformed tile is not aligned with the pixel grid, the returned tile will correspond to the bounding box of the back-transformed tile (so that the returned tile is axis-aligned to the pixel grid). + return_genes + If not `None`, return the gene expression values from the table. The dictionary should have the following + structure: `{"layer_name": None}` or `{"layer_name": ["gene_name1", "gene_name2"]}`, where the key `"X"` selects :attr:`anndata.AnnData.X`. + If the value is `None`, all the genes are returned. return_annotations If not `None`, one or more values from the table are returned together with the image tile in a tuple. Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X` can be returned. @@ -96,6 +100,9 @@ class ImageTilesDataset(Dataset): It is a `Callable`, with `Any` as return type, that takes as input the (image, table_value) tuple (when `return_annotations` is not `None`) or a `Callable` that takes as input the `SpatialData` object (when `return_annotations` is `None`). + return_array + If `True`, the tile is returned as a :class:`dask.array.Array` object; otherwise, it is returned as a + :class:`spatial_image.SpatialImage` object. rasterize_kwargs Keyword arguments passed to :func:`spatialdata.rasterize` if `rasterize` is `True`. 
This argument can be used in particular to choose the pixel dimension of the produced image tiles; please refer @@ -119,9 +126,11 @@ def __init__( tile_scale: float = 1.0, tile_dim_in_units: float | None = None, rasterize: bool = False, + return_genes: Mapping[str, list[str] | None] | None = None, return_annotations: str | list[str] | None = None, table_name: str | None = None, transform: Callable[[Any], Any] | None = None, + return_array: bool = False, rasterize_kwargs: Mapping[str, Any] = MappingProxyType({}), ): from spatialdata import bounding_box_query @@ -144,9 +153,9 @@ def __init__( **dict(rasterize_kwargs), ) if rasterize - else bounding_box_query # type: ignore[assignment] + else partial(bounding_box_query, return_array=return_array) ) - self._return = self._get_return(return_annotations, table_name) + self._return = self._get_return(return_annotations, table_name, return_array=return_array) self.transform = transform def _validate( @@ -154,13 +163,14 @@ def _validate( sdata: SpatialData, regions_to_images: dict[str, str], regions_to_coordinate_systems: dict[str, str], return_annotations: str | list[str] | None, table_name: str | None, + return_genes: Mapping[str, list[str] | None] | None = None, ) -> None: """Validate input parameters.""" self.sdata = sdata - if return_annotations is not None and table_name is None: - raise ValueError("`table_name` must be provided if `return_annotations` is not `None`.") + if (return_annotations is not None or return_genes is not None) and table_name is None: + raise ValueError("`table_name` must be provided if `return_annotations` or `return_genes` is not `None`.") # check that the regions specified in the two dicts are the same assert set(regions_to_images.keys()) == set( @@ -260,6 +270,7 @@ def _preprocess( if table_name is not None: table_subset = filtered_table[filtered_table.obs[region_key] == region_name] + table_subset.uns["spatialdata_attrs"]["region"] = region_name circles_sdata = 
SpatialData.init_from_elements({region_name: circles}, tables=table_subset.copy()) _, table = join_spatialelement_table( sdata=circles_sdata, @@ -298,15 +309,18 @@ def _return_function( dataset_table: AnnData, dataset_index: pd.DataFrame, table_name: str | None, return_annot: str | list[str] | None, + return_array: bool = False, + return_genes: Mapping[str, list[str] | None] | None = None, ) -> tuple[Any, Any] | SpatialData: - tile = ImageTilesDataset._ensure_single_scale(tile) + if not return_array: + tile = ImageTilesDataset._ensure_single_scale(tile) if return_annot is not None: # table is always returned as array shape (1, len(return_annot)) # where return_table can be a single column or a list of columns return_annot = [return_annot] if isinstance(return_annot, str) else return_annot # return tuple of (tile, table) - if np.all([i in dataset_table.obs for i in return_annot]): + if set(return_annot).issubset(dataset_table.obs.columns): return tile, dataset_table.obs[return_annot].iloc[idx].values.reshape(1, -1) if np.all([i in dataset_table.var_names for i in return_annot]): if issparse(dataset_table.X): @@ -330,10 +344,48 @@ def _return_function( ) return SpatialData(images={dataset_index.iloc[idx][ImageTilesDataset.IMAGE_KEY]: tile}) + @staticmethod + def _return_annotations( + idx: int, + dataset_table: AnnData, + dataset_index: pd.DataFrame, + table_name: str | None, + return_annot: str | list[str], + ) -> np.ndarray: + # table is always returned as array shape (1, len(return_annot)) + # where return_table can be a single column or a list of columns + return_annot = [return_annot] if isinstance(return_annot, str) else return_annot + # every requested column must exist in obs; return the values of row idx + if set(return_annot).issubset(dataset_table.obs.columns): + return dataset_table.obs[return_annot].iloc[idx].values.reshape(1, -1) + else: + raise KeyError("Missing some valid annotations in the table.") + + @staticmethod + def _return_genes( + idx: int, + dataset_table: AnnData, + dataset_index: pd.DataFrame, 
+ table_name: str | None, + return_genes: Mapping[str, list[str] | None], + ) -> np.ndarray: + k, v = next(iter(return_genes.items())) + layer = dataset_table.X if k == "X" else dataset_table.layers[k] + if v is None: + row = layer[idx] + elif isinstance(v, list) and set(v).issubset(dataset_table.var_names): + row = layer[idx, dataset_table.var_names.get_indexer(v)] + else: + raise KeyError("Missing some valid genes in the table.") + if issparse(row): + row = row.toarray() + return np.asarray(row).reshape(1, -1) + def _get_return( self, return_annot: str | list[str] | None, table_name: str | None, + return_array: bool = False, ) -> Callable[[int, Any], tuple[Any, Any] | SpatialData]: """Get function to return values from the table of the dataset.""" return partial( @@ -342,6 +394,7 @@ def _get_return( dataset_index=self.dataset_index, table_name=table_name, return_annot=return_annot, + return_array=return_array, ) def __len__(self) -> int: