Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,13 +474,30 @@ def _(
min_coordinate: list[Number] | ArrayLike,
max_coordinate: list[Number] | ArrayLike,
target_coordinate_system: str,
return_array: bool = False,
) -> SpatialImage | MultiscaleSpatialImage | None:
"""Implement bounding box query for SpatialImage.

Notes
-----
See https://github.com/scverse/spatialdata/pull/151 for a detailed overview of the logic of this code,
and for the cases the comments refer to.

Parameters
----------
image
The image to query.
axes
The axes the coordinates are expressed in.
min_coordinate
The upper left hand corner of the bounding box (i.e., minimum coordinates along all dimensions).
max_coordinate
The lower right hand corner of the bounding box (i.e., the maximum coordinates along all dimensions
target_coordinate_system
The coordinate system the bounding box is defined in.
return_array
If `True`, return the query result as a `numpy.ndarray` and it does not parse it into a `SpatialImage`
or `MultiscaleSpatialImage`.
"""
from spatialdata.transformations import get_transformation, set_transformation

Expand Down Expand Up @@ -563,6 +580,8 @@ def _(
if 0 in query_result.shape:
return None
assert isinstance(query_result, SpatialImage)
if return_array:
return query_result.data
# rechunk the data to avoid irregular chunks
image = image.chunk("auto")
else:
Expand All @@ -580,6 +599,8 @@ def _(
return None
else:
d[k] = xdata
if return_array and k == "scale0":
return xdata.data
# the list of scales may not be contiguous when the data has small shape (for instance with yx = 22 and
# rotations we may end up having scale0 and scale2 but not scale1. Practically this may occur in torch tiler if
# the tiles are request to be too small).
Expand Down Expand Up @@ -820,7 +841,6 @@ def _(
images: bool = True,
labels: bool = True,
) -> SpatialData:

_check_deprecated_kwargs({"shapes": shapes, "points": points, "images": images, "labels": labels})
new_elements = {}
for element_type in ["points", "images", "labels", "shapes"]:
Expand Down
65 changes: 59 additions & 6 deletions src/spatialdata/dataloader/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class ImageTilesDataset(Dataset):
system; this back-transforms the target tile into the pixel coordinates. If the back-transformed tile is not
aligned with the pixel grid, the returned tile will correspond to the bounding box of the back-transformed tile
(so that the returned tile is axis-aligned to the pixel grid).
return_genes:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Two comments:

  1. I would specify that the layers are AnnData layers and the default layer is X.
  2. I would also allow to pass just a list instead of a dict, that would be interpreted as {'X': genes_list}

If not `None`, return the gene expression values from the table. The dictionary should have the following
structure: `{"layer_name": None}` or `{"layer": ["gene_name1", "gene_name2"]}`.
If the value is `None`, all the genes are returned.
return_annotations
If not `None`, one or more values from the table are returned together with the image tile in a tuple.
Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X` can be returned.
Expand All @@ -96,6 +100,9 @@ class ImageTilesDataset(Dataset):
It is a `Callable`, with `Any` as return type, that takes as input the (image, table_value) tuple (when
`return_annotations` is not `None`) or a `Callable` that takes as input the `SpatialData` object (when
`return_annotations` is `None`).
return_array
If `True`, the tile is returned as an :class:`dask.array.Array` object; otherwise, it is returned as a
:class:`spatial_image.SpatialImage` object.
rasterize_kwargs
Keyword arguments passed to :func:`spatialdata.rasterize` if `rasterize` is `True`.
This argument can be used in particular to choose the pixel dimension of the produced image tiles; please refer
Expand All @@ -119,9 +126,11 @@ def __init__(
tile_scale: float = 1.0,
tile_dim_in_units: float | None = None,
rasterize: bool = False,
return_genes: Mapping[str, list[str] | None] | None = None,
return_annotations: str | list[str] | None = None,
table_name: str | None = None,
transform: Callable[[Any], Any] | None = None,
return_array: bool = False,
rasterize_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
from spatialdata import bounding_box_query
Expand All @@ -144,23 +153,24 @@ def __init__(
**dict(rasterize_kwargs),
)
if rasterize
else bounding_box_query # type: ignore[assignment]
else partial(bounding_box_query, return_array=return_array)
)
self._return = self._get_return(return_annotations, table_name)
self._return = self._get_return(return_annotations, table_name, return_array=return_array)
self.transform = transform

def _validate(
self,
sdata: SpatialData,
regions_to_images: dict[str, str],
regions_to_coordinate_systems: dict[str, str],
return_genes: Mapping[str, list[str] | None] | None,
return_annotations: str | list[str] | None,
table_name: str | None,
) -> None:
"""Validate input parameters."""
self.sdata = sdata
if return_annotations is not None and table_name is None:
raise ValueError("`table_name` must be provided if `return_annotations` is not `None`.")
if (return_annotations is not None) or (return_genes is None) and table_name is None:
raise ValueError("`table_name` must be provided if `return_annotations` or `return_genes` is not `None`.")

# check that the regions specified in the two dicts are the same
assert set(regions_to_images.keys()) == set(
Expand Down Expand Up @@ -260,6 +270,7 @@ def _preprocess(

if table_name is not None:
table_subset = filtered_table[filtered_table.obs[region_key] == region_name]
table_subset.uns["spatialdata_attrs"]["region"] = region_name
circles_sdata = SpatialData.init_from_elements({region_name: circles}, tables=table_subset.copy())
_, table = join_spatialelement_table(
sdata=circles_sdata,
Expand Down Expand Up @@ -298,15 +309,18 @@ def _return_function(
dataset_table: AnnData,
dataset_index: pd.DataFrame,
table_name: str | None,
return_genes: Mapping[str, list[str] | None],
return_annot: str | list[str] | None,
return_array: bool = False,
) -> tuple[Any, Any] | SpatialData:
tile = ImageTilesDataset._ensure_single_scale(tile)
if not return_array:
tile = ImageTilesDataset._ensure_single_scale(tile)
if return_annot is not None:
# table is always returned as array shape (1, len(return_annot))
# where return_table can be a single column or a list of columns
return_annot = [return_annot] if isinstance(return_annot, str) else return_annot
# return tuple of (tile, table)
if np.all([i in dataset_table.obs for i in return_annot]):
if np.all(dataset_table.obs.columns.isin(return_annot)):
return tile, dataset_table.obs[return_annot].iloc[idx].values.reshape(1, -1)
if np.all([i in dataset_table.var_names for i in return_annot]):
if issparse(dataset_table.X):
Expand All @@ -330,10 +344,48 @@ def _return_function(
)
return SpatialData(images={dataset_index.iloc[idx][ImageTilesDataset.IMAGE_KEY]: tile})

@staticmethod
def _return_annotations(
idx: int,
dataset_table: AnnData,
dataset_index: pd.DataFrame,
table_name: str | None,
return_annot: str | list[str],
) -> pd.DataFrame:
# table is always returned as array shape (1, len(return_annot))
# where return_table can be a single column or a list of columns
return_annot = [return_annot] if isinstance(return_annot, str) else return_annot
# return tuple of (tile, table)
if np.all(dataset_table.obs.columns.isin(return_annot)):
return dataset_table.obs[return_annot].iloc[idx].values.reshape(1, -1)
else:
raise KeyError("Missing some valid annotations in the table.")

@staticmethod
def _return_genes(
idx: int,
dataset_table: AnnData,
dataset_index: pd.DataFrame,
table_name: str | None,
return_genes: Mapping[str, list[str] | None],
) -> pd.DataFrame:
k, v = next(iter(return_genes.items()))
layer = dataset_table.X if k == "X" else dataset_table.layers[k].X
if v is None:
if issparse(layer):
return layer[idx].X.A
return layer[idx].X
if isinstance(v, list) and np.all(dataset_table.var_names.isin(v)):
if issparse(layer):
return layer[idx, v].X.A
return layer[idx, v].X
raise KeyError("Missing some valid genes in the table.")

def _get_return(
self,
return_annot: str | list[str] | None,
table_name: str | None,
return_array: bool = False,
) -> Callable[[int, Any], tuple[Any, Any] | SpatialData]:
"""Get function to return values from the table of the dataset."""
return partial(
Expand All @@ -342,6 +394,7 @@ def _get_return(
dataset_index=self.dataset_index,
table_name=table_name,
return_annot=return_annot,
return_array=return_array,
)

def __len__(self) -> int:
Expand Down