From 99acf8264838dfe5188b1f0a78b5f09dc1c72835 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 10 Jun 2025 11:10:02 +0200 Subject: [PATCH 1/8] replace asserts by proper errors --- src/spatialdata/_io/io_raster.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 541be3ead..9241c938a 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -126,17 +126,18 @@ def _write_raster( label_metadata: JSONDict | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: - assert raster_type in ["image", "labels"] + if raster_type not in ["image", "labels"]: + raise TypeError(f"Writing raster data is only supported for 'image' and 'labels'. Got: {raster_type}") # the argument "name" and "label_metadata" are only used for labels (to be precise, name is used in # write_multiscale_ngff() when writing metadata, but name is not used in write_image_ngff(). Maybe this is bug of # ome-zarr-py. In any case, we don't need that metadata and we use the argument name so that when we write labels # the correct group is created by the ome-zarr-py APIs. For images we do it manually in the function # _get_group_for_writing_data() - if raster_type == "image": - assert label_metadata is None - else: - metadata["name"] = name - metadata["label_metadata"] = label_metadata + if raster_type == "image" and label_metadata is not None: + raise ValueError("If the rastertype is 'image', 'label_metadata' should be None.") + + metadata["name"] = name + metadata["label_metadata"] = label_metadata write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff From 11d5249c12e06a304a69dced89ee653601fb61c1 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 10 Jun 2025 13:47:10 +0200 Subject: [PATCH 2/8] correct for labels --- src/spatialdata/_io/io_raster.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 9241c938a..ddad6a575 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -135,9 +135,9 @@ def _write_raster( # _get_group_for_writing_data() if raster_type == "image" and label_metadata is not None: raise ValueError("If the rastertype is 'image', 'label_metadata' should be None.") - - metadata["name"] = name - metadata["label_metadata"] = label_metadata + if raster_type == "labels": + metadata["name"] = name + metadata["label_metadata"] = label_metadata write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff From acda400323226c25c48d39886fc105e2c76e52fa Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 11 Jun 2025 00:02:00 +0200 Subject: [PATCH 3/8] add compressor --- src/spatialdata/_core/spatialdata.py | 18 +++++++++++++++++- src/spatialdata/_io/_utils.py | 19 +++++++++++++++++++ src/spatialdata/_io/io_raster.py | 13 +++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 3cb9f8d52..6df499763 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -29,6 +29,7 @@ raise_validation_errors, 
    validate_table_attr_keys,
 )
+from spatialdata._io._utils import _validate_compressor_args
 from spatialdata._logging import logger
 from spatialdata._types import ArrayLike, Raster_T
 from spatialdata._utils import (
@@ -1179,6 +1180,7 @@ def write(
        overwrite: bool = False,
        consolidate_metadata: bool = True,
        format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
+        compressor: dict[Literal["lz4", "zstd"], int] | None = None,
    ) -> None:
        """
        Write the `SpatialData` object to a Zarr store.
@@ -1204,7 +1206,13 @@ def write(
            By default, the latest format is used for all elements, i.e.
            :class:`~spatialdata._io.format.CurrentRasterFormat`, :class:`~spatialdata._io.format.CurrentShapesFormat`,
            :class:`~spatialdata._io.format.CurrentPointsFormat`, :class:`~spatialdata._io.format.CurrentTablesFormat`.
+        compressor
+            A dictionary with a single entry mapping the compression type to the compression level. Supported
+            compression types are `lz4` and `zstd`; the compression level must be between 0 and 9 (inclusive). If not
+            specified, `lz4` with compression level 5 is used.
        """
+        _validate_compressor_args(compressor)
+
        if isinstance(file_path, str):
            file_path = Path(file_path)
        self._validate_can_safely_write_to_path(file_path, overwrite=overwrite)
@@ -1223,6 +1231,7 @@ def write(
                    element_name=element_name,
                    overwrite=False,
                    format=format,
+                    compressor=compressor,
                )

        if self.path != file_path:
@@ -1241,6 +1250,7 @@ def _write_element(
        element_name: str,
        overwrite: bool,
        format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
+        compressor: dict[Literal["lz4", "zstd"], int] | None = None,
    ) -> None:
        if not isinstance(zarr_container_path, Path):
            raise ValueError(
@@ -1260,7 +1270,13 @@ def _write_element(
        parsed = _parse_formats(formats=format)

        if element_type == "images":
-            write_image(image=element, group=element_type_group, name=element_name, format=parsed["raster"])
+            write_image(
+                image=element,
+                group=element_type_group,
+                name=element_name,
+                format=parsed["raster"],
+                compressor=compressor,
+            )
        elif element_type == "labels":
            write_labels(labels=element, group=root_group, name=element_name, format=parsed["raster"])
        elif element_type == "points":
diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py
index bf70b3cc4..0d6b82570 100644
--- a/src/spatialdata/_io/_utils.py
+++ b/src/spatialdata/_io/_utils.py
@@ -442,3 +442,22 @@ def handle_read_errors(
    else:  # on_bad_files == BadFileHandleMethod.ERROR
        # Let it raise exceptions
        yield
+
+
+def _validate_compressor_args(compressor_dict: dict[Literal["lz4", "zstd"], int] | None) -> None:
+    if compressor_dict:
+        if not isinstance(compressor_dict, dict):
+            raise TypeError(
+                f"Expected a dictionary mapping the compression type ('lz4' or 'zstd') to the compression level, "
+                f"an integer between 0 and 9 (inclusive). "
+                f"Got type: {compressor_dict}"
+            )
+        if len(compressor_dict) != 1:
+            raise ValueError(
+                "Expected a dictionary with a single key indicating the compression type, either 'lz4' or "
+                "'zstd', and an integer between 0 and 9 (inclusive) as value, representing the compression level."
+            )
+        if compression := list(compressor_dict.keys())[0] not in ["lz4", "zstd"]:
+            raise ValueError(f"Compression must either be `lz4` or `zstd`, got: {compression}.")
+        if not isinstance(value := list(compressor_dict.values())[0], int) or 0 <= value <= 9:
+            raise ValueError(f"The compression level must be an integer between 0 and 9 (inclusive). Got: {value}")
Got: {value}") diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index ddad6a575..ebf59283e 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -123,9 +123,12 @@ def _write_raster( name: str, format: Format = CurrentRasterFormat(), storage_options: JSONDict | list[JSONDict] | None = None, + compressor: dict[Literal["lz4", "zstd"], int] | None = None, label_metadata: JSONDict | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: + from zarr.codecs import Blosc + if raster_type not in ["image", "labels"]: raise TypeError(f"Writing raster data is only supported for 'image' and 'labels'. Got: {raster_type}") # the argument "name" and "label_metadata" are only used for labels (to be precise, name is used in @@ -167,6 +170,10 @@ def _get_group_for_writing_transformations() -> zarr.Group: storage_options["chunks"] = chunks else: storage_options = {"chunks": chunks} + + if compressor and isinstance(storage_options, dict): + ((compression, compression_level),) = compressor.items() + storage_options["compressor"] = Blosc(cname=compression, clevel=compression_level, shuffle=1) # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. # We need this because the argument of write_image_ngff is called image while the argument of # write_labels_ngff is called label. @@ -200,6 +207,10 @@ def _get_group_for_writing_transformations() -> zarr.Group: # coords = iterate_pyramid_levels(raster_data, "coords") parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format) storage_options = [{"chunks": chunk} for chunk in chunks] + if compressor: + ((compression, compression_level),) = compressor.items() + for option in storage_options: + option["compressor"] = Blosc(cname=compression, clevel=compression_level, shuffle=1) dask_delayed = write_multi_scale_ngff( pyramid=data, group=group_data, @@ -238,6 +249,7 @@ def write_image( name: str, format: Format = CurrentRasterFormat(), storage_options: JSONDict | list[JSONDict] | None = None, + compressor: dict[Literal["lz4", "zstd"], int] | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: _write_raster( @@ -247,6 +259,7 @@ def write_image( name=name, format=format, storage_options=storage_options, + compressor=compressor, **metadata, ) From 977d0a952e71f161aa55f16eec046b7b295d6d28 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 11 Jun 2025 00:08:26 +0200 Subject: [PATCH 4/8] add comrpessor to write_element --- src/spatialdata/_core/spatialdata.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 6df499763..757b25571 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1209,7 +1209,8 @@ def write( compressor A dictionary with as key the type of compression to use for images and labels and as value the compression level which should be inclusive between 0 and 9. For compression, `lz4` and `zstd` are supported. If not - specified, the compression will be `lz4` with compression level 5. + specified, the compression will be `lz4` with compression level 5. Bytes are automatically ordered for more + efficient compression. 
""" _validate_compressor_args(compressor) @@ -1278,7 +1279,9 @@ def _write_element( compressor=compressor, ) elif element_type == "labels": - write_labels(labels=element, group=root_group, name=element_name, format=parsed["raster"]) + write_labels( + labels=element, group=root_group, name=element_name, format=parsed["raster"], compressor=compressor + ) elif element_type == "points": write_points(points=element, group=element_type_group, name=element_name, format=parsed["points"]) elif element_type == "shapes": @@ -1293,6 +1296,7 @@ def write_element( element_name: str | list[str], overwrite: bool = False, format: SpatialDataFormat | list[SpatialDataFormat] | None = None, + compressor: dict[Literal["lz4", "zstd"], int] | None = None, ) -> None: """ Write a single element, or a list of elements, to the Zarr store used for backing. @@ -1308,6 +1312,11 @@ def write_element( format It is recommended to leave this parameter equal to `None`. See more details in the documentation of `SpatialData.write()`. + compressor + A dictionary with as key the type of compression to use for images and labels and as value the compression + level which should be inclusive between 0 and 9. For compression, `lz4` and `zstd` are supported. If not + specified, the compression will be `lz4` with compression level 5. Bytes are automatically ordered for more + efficient compression. Notes ----- @@ -1317,7 +1326,7 @@ def write_element( if isinstance(element_name, list): for name in element_name: assert isinstance(name, str) - self.write_element(name, overwrite=overwrite) + self.write_element(name, overwrite=overwrite, compressor=compressor) return check_valid_name(element_name) @@ -1351,6 +1360,7 @@ def write_element( element_name=element_name, overwrite=overwrite, format=format, + compressor=compressor, ) def delete_element_from_disk(self, element_name: str | list[str]) -> None: From 44761f1058d012569e6607e37bd5781136b3914e Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 11 Jun 2025 00:18:04 +0200 Subject: [PATCH 5/8] minor corrections --- src/spatialdata/_core/spatialdata.py | 3 ++- src/spatialdata/_io/_utils.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 757b25571..094959718 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -29,7 +29,6 @@ raise_validation_errors, validate_table_attr_keys, ) -from spatialdata._io._utils import _validate_compressor_args from spatialdata._logging import logger from spatialdata._types import ArrayLike, Raster_T from spatialdata._utils import ( @@ -1212,6 +1211,8 @@ def write( specified, the compression will be `lz4` with compression level 5. Bytes are automatically ordered for more efficient compression. """ + from spatialdata._io._utils import _validate_compressor_args + _validate_compressor_args(compressor) if isinstance(file_path, str): diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 0d6b82570..ffe050999 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -450,14 +450,14 @@ def _validate_compressor_args(compressor_dict: dict[Literal["lz4", "zstd"], int] raise TypeError( f"Expected a dictionary with as key the type of compression to use for images and labels and " f"as value the compression level which should be inclusive between 1 and 9. 
" - f"Got type: {compressor_dict}" + f"Got type: {type(compressor_dict)}" ) if len(compressor_dict) != 1: raise ValueError( "Expected a dictionary with a single key indicating the type of compression, either 'lz4' or " "'zstd' and an `int` inclusive between 1 and 9 as value representing the compression level." ) - if compression := list(compressor_dict.keys())[0] not in ["lz4", "zstd"]: + if (compression := list(compressor_dict.keys())[0]) not in ["lz4", "zstd"]: raise ValueError(f"Compression must either be `lz4` or `zstd`, got: {compression}.") - if not isinstance(value := list(compressor_dict.values())[0], int) or 0 <= value <= 9: + if not isinstance(value := list(compressor_dict.values())[0], int) or not (0 <= value <= 9): raise ValueError(f"The compression level must be an integer inclusive between 0 and 9. Got: {value}") From 25388b0be1c8de292e4effe0b6209a81aac7c93c Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 11 Jun 2025 00:52:36 +0200 Subject: [PATCH 6/8] add tests --- tests/io/test_readwrite.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index ad8c66b4c..0744ca79e 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -7,6 +7,7 @@ import dask.dataframe as dd import numpy as np import pytest +import zarr from anndata import AnnData from numpy.random import default_rng @@ -82,6 +83,23 @@ def test_multiple_tables(self, tmp_path: str, tables: list[AnnData]) -> None: sdata_tables = SpatialData(tables={str(i): tables[i] for i in range(len(tables))}) self._test_table(tmp_path, sdata_tables) + def test_compression_roundtrip(self, tmp_path: str, full_sdata: SpatialData): + tmpdir = Path(tmp_path) / "tmp.zarr" + with pytest.raises(TypeError, match="Expected a dictionary with as"): + full_sdata.write(tmpdir, compressor="faulty") + with pytest.raises(ValueError, match="Expected a dictionary with a single"): + full_sdata.write(tmpdir, compressor={"zstd": 8, "other_item": 4}) + with pytest.raises(ValueError, match="Compression must either"): + full_sdata.write(tmpdir, compressor={"faulty": 8}) + with pytest.raises(ValueError, match="Compression must either"): + full_sdata.write(tmpdir, compressor={"The compression level": 10}) + + full_sdata.write(tmpdir, compressor={"zstd": 8}) + + compressor = zarr.open_group(tmpdir / "images", mode="r")["image2d"]["0"].compressor + assert compressor.cname == "zstd" + assert compressor.clevel == 8 + def test_roundtrip( self, tmp_path: str, @@ -252,6 +270,25 @@ def test_incremental_io_on_disk( sdata.delete_element_from_disk(name) sdata.write_element(name) + def test_write_element_compression(self, tmp_path: str, full_sdata: SpatialData): + tmpdir = Path(tmp_path) / "compression.zarr" + sdata = SpatialData() + sdata.write(tmpdir) + + sdata["image_lz4"] = full_sdata["image2d"] + sdata["image_zstd"] = full_sdata["image2d"] + + sdata.write_element("image_lz4", compressor={"lz4": 3}) + sdata.write_element("image_zstd", compressor={"zstd": 7}) + + compressor = zarr.open_group(tmpdir / "images", mode="r")["image_lz4"]["0"].compressor + assert compressor.cname == "lz4" + assert compressor.clevel == 3 + + compressor = zarr.open_group(tmpdir / "images", mode="r")["image_zstd"]["0"].compressor + assert compressor.cname == "zstd" + assert compressor.clevel == 7 + def test_incremental_io_table_legacy(self, table_single_annotation: SpatialData) -> None: s = table_single_annotation t = s["table"][:10, :].copy() From 
From eac6f3d5db669c56ab7a607e207930443b68c919 Mon Sep 17 00:00:00 2001
From: Wouter-Michiel Vierdag
Date: Wed, 11 Jun 2025 01:14:54 +0200
Subject: [PATCH 7/8] add tests

---
 tests/io/test_readwrite.py | 38 +++++++++++++++++++++-----------------
 1 file changed, 21 insertions(+), 17 deletions(-)

diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py
index 0744ca79e..1e13d7704 100644
--- a/tests/io/test_readwrite.py
+++ b/tests/io/test_readwrite.py
@@ -2,7 +2,7 @@
 import tempfile
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal

 import dask.dataframe as dd
 import numpy as np
@@ -96,9 +96,16 @@ def test_compression_roundtrip(self, tmp_path: str, full_sdata: SpatialData):

        full_sdata.write(tmpdir, compressor={"zstd": 8})

-        compressor = zarr.open_group(tmpdir / "images", mode="r")["image2d"]["0"].compressor
-        assert compressor.cname == "zstd"
-        assert compressor.clevel == 8
+        # sourcery skip: no-loop-in-tests
+        for element in ["image2d", "image2d_multiscale"]:
+            compressor = zarr.open_group(tmpdir / "images", mode="r")[element]["0"].compressor
+            assert compressor.cname == "zstd"
+            assert compressor.clevel == 8
+
+        for element in ["labels2d", "labels2d_multiscale"]:
+            compressor = zarr.open_group(tmpdir / "labels", mode="r")[element]["0"].compressor
+            assert compressor.cname == "zstd"
+            assert compressor.clevel == 8

    def test_roundtrip(
        self,
@@ -270,24 +277,21 @@ def test_incremental_io_on_disk(
                sdata.delete_element_from_disk(name)
                sdata.write_element(name)

-    def test_write_element_compression(self, tmp_path: str, full_sdata: SpatialData):
+    @pytest.mark.parametrize("compressor", [{"lz4": 3}, {"zstd": 7}])
+    @pytest.mark.parametrize("element", [("images", "image2d"), ("labels", "labels2d")])
+    def test_write_element_compression(
+        self, tmp_path: str, full_sdata: SpatialData, compressor: dict[Literal["lz4", "zstd"], int], element: tuple[str, str]
+    ):
        tmpdir = Path(tmp_path) / "compression.zarr"
        sdata = SpatialData()
        sdata.write(tmpdir)

-        sdata["image_lz4"] = full_sdata["image2d"]
-        sdata["image_zstd"] = full_sdata["image2d"]
-
-        sdata.write_element("image_lz4", compressor={"lz4": 3})
-        sdata.write_element("image_zstd", compressor={"zstd": 7})
-
-        compressor = zarr.open_group(tmpdir / "images", mode="r")["image_lz4"]["0"].compressor
-        assert compressor.cname == "lz4"
-        assert compressor.clevel == 3
+        sdata["element"] = full_sdata[element[1]]
+        sdata.write_element("element", compressor=compressor)

-        compressor = zarr.open_group(tmpdir / "images", mode="r")["image_zstd"]["0"].compressor
-        assert compressor.cname == "zstd"
-        assert compressor.clevel == 7
+        compression = zarr.open_group(tmpdir / element[0], mode="r")["element"]["0"].compressor
+        assert compression.cname == list(compressor.keys())[0]
+        assert compression.clevel == list(compressor.values())[0]

    def test_incremental_io_table_legacy(self, table_single_annotation: SpatialData) -> None:
        s = table_single_annotation

From b9be5d29bd6f40718d609687f883b425201f3ff0 Mon Sep 17 00:00:00 2001
From: Wouter-Michiel Vierdag
Date: Wed, 11 Jun 2025 14:10:17 +0200
Subject: [PATCH 8/8] refactor of write_raster

---
 src/spatialdata/_io/io_raster.py | 339 ++++++++++++++++++++++++-------
 1 file changed, 264 insertions(+), 75 deletions(-)

diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py
index ebf59283e..fa5b16508 100644
--- a/src/spatialdata/_io/io_raster.py
+++ b/src/spatialdata/_io/io_raster.py
@@ -1,3 +1,4 @@
+from collections.abc import Callable
 from pathlib 
import Path from typing import Any, Literal @@ -116,6 +117,213 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) return compute_coordinates(si) +def _get_group_for_writing_transformations( + raster_type: Literal["image", "labels"], group: zarr.Group, name: str +) -> zarr.Group: + """Get the appropriate zarr group for writing transformations. + + Parameters + ---------- + raster_type + Type of raster data, either "image" or "labels" + group + Parent zarr group + name + Name of the element + + Returns + ------- + The zarr group where transformations should be written + """ + if raster_type == "image": + return group.require_group(name) + return group["labels"][name] + + +def _apply_compression( + storage_options: JSONDict | list[JSONDict], compressor: dict[Literal["lz4", "zstd"], int] | None +) -> JSONDict | list[JSONDict]: + """Apply compression settings to storage options. + + Parameters + ---------- + storage_options + Storage options for zarr arrays + compressor + Compression settings as a dictionary with a single key-value pair + + Returns + ------- + Updated storage options with compression settings + """ + from zarr.codecs import Blosc + + if not compressor: + return storage_options + + ((compression, compression_level),) = compressor.items() + + if isinstance(storage_options, dict): + storage_options["compressor"] = Blosc(cname=compression, clevel=compression_level, shuffle=1) + elif isinstance(storage_options, list): + for option in storage_options: + option["compressor"] = Blosc(cname=compression, clevel=compression_level, shuffle=1) + + return storage_options + + +def _write_data_array( + raster_type: Literal["image", "labels"], + raster_data: DataArray, + group_data: zarr.Group, + format: Format, + storage_options: JSONDict | None, + compressor: dict[Literal["lz4", "zstd"], int] | None, + metadata: dict[str, Any], + get_transformations_group: Callable[[], zarr.Group], +) -> None: + """Write a DataArray to a zarr group. + + Parameters + ---------- + raster_type + Type of raster data, either "image" or "labels" + raster_data + The DataArray to write + group_data + The zarr group to write to + format + The spatialdata raster format to use for writing + storage_options + Storage options for zarr arrays (to be passed to ome-zarr) + compressor + Compression settings as a dictionary with a single key-value (compression, compression level) pair + metadata + Additional metadata + get_transformations_group + Function that returns the group for writing transformations + """ + data = raster_data.data + transformations = _get_transformations(raster_data) + input_axes: tuple[str, ...] = tuple(raster_data.dims) + chunks = raster_data.chunks + parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format) + + # Set up storage options with chunks + if storage_options is not None: + if "chunks" not in storage_options and isinstance(storage_options, dict): + storage_options["chunks"] = chunks + else: + storage_options = {"chunks": chunks} + + # Apply compression if specified + storage_options = _apply_compression(storage_options, compressor) + + # Determine which write function to use + write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff + + # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. + # We need this because the argument of write_image_ngff is called image while the argument of + # write_labels_ngff is called label. 
+ metadata[raster_type] = data + + # Write the data + write_single_scale_ngff( + group=group_data, + scaler=None, + fmt=format, + axes=parsed_axes, + coordinate_transformations=None, + storage_options=storage_options, + **metadata, + ) + + # Write transformations + assert transformations is not None + overwrite_coordinate_transformations_raster( + group=get_transformations_group(), transformations=transformations, axes=input_axes + ) + + +def _write_data_tree( + raster_type: Literal["image", "labels"], + raster_data: DataTree, + group_data: zarr.Group, + format: Format, + storage_options: JSONDict | list[JSONDict] | None, + compressor: dict[Literal["lz4", "zstd"], int] | None, + metadata: dict[str, Any], + get_transformations_group: Callable[[], zarr.Group], +) -> None: + """Write a DataTree to a zarr group. + + Parameters + ---------- + raster_type + Type of raster data, either "image" or "labels" + raster_data + The DataTree to write + group_data + The zarr group to write to + format + The SpatialData raster format to use for writing + storage_options + Storage options for zarr arrays (to be passed to ome-zarr) + compressor + Compression settings as a dictionary with a single key-value (compression, compression level) pair + metadata + Additional metadata + get_transformations_group + Function that returns the group for writing transformations + """ + data = get_pyramid_levels(raster_data, attr="data") + list_of_input_axes: list[Any] = get_pyramid_levels(raster_data, attr="dims") + assert len(set(list_of_input_axes)) == 1 + input_axes = list_of_input_axes[0] + + # Saving only the transformations of the first scale + d = dict(raster_data["scale0"]) + assert len(d) == 1 + xdata = d.values().__iter__().__next__() + transformations = _get_transformations_xarray(xdata) + assert transformations is not None + assert len(transformations) > 0 + + chunks = get_pyramid_levels(raster_data, "chunks") + parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format) + + # Set up storage options with chunks + if storage_options is None: + storage_options = [{"chunks": chunk} for chunk in chunks] + + # Apply compression if specified + storage_options = _apply_compression(storage_options, compressor) + + # Determine which write function to use + write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff + + # Write the data + dask_delayed = write_multi_scale_ngff( + pyramid=data, + group=group_data, + fmt=format, + axes=parsed_axes, + coordinate_transformations=None, + storage_options=storage_options, + **metadata, + compute=False, + ) + + # Compute all pyramid levels at once to allow Dask to optimize the computational graph. + da.compute(*dask_delayed) + + # Write transformations + assert transformations is not None + overwrite_coordinate_transformations_raster( + group=get_transformations_group(), transformations=transformations, axes=tuple(input_axes) + ) + + def _write_raster( raster_type: Literal["image", "labels"], raster_data: DataArray | DataTree, @@ -127,119 +335,98 @@ def _write_raster( label_metadata: JSONDict | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: - from zarr.codecs import Blosc + """Write raster data to a zarr group. + This function handles writing both image and label data, in both single-scale (DataArray) + and multi-scale (DataTree) formats. 
+ + Parameters + ---------- + raster_type + Type of raster data, either "image" or "labels" + raster_data + The data to write, either a DataArray (single-scale) or DataTree (multi-scale) + group + The zarr group to write to + name + Name of the element + format + The raster format to use for writing + storage_options + Storage options for zarr arrays (to be passed to ome-zarr) + compressor + Compression settings as a dictionary with a single key-value (compression, compression level) pair + label_metadata + Metadata specific to labels + **metadata + Additional metadata + """ + # Validate inputs if raster_type not in ["image", "labels"]: raise TypeError(f"Writing raster data is only supported for 'image' and 'labels'. Got: {raster_type}") - # the argument "name" and "label_metadata" are only used for labels (to be precise, name is used in + + # The argument "name" and "label_metadata" are only used for labels (to be precise, name is used in # write_multiscale_ngff() when writing metadata, but name is not used in write_image_ngff(). Maybe this is bug of # ome-zarr-py. In any case, we don't need that metadata and we use the argument name so that when we write labels # the correct group is created by the ome-zarr-py APIs. For images we do it manually in the function # _get_group_for_writing_data() if raster_type == "image" and label_metadata is not None: raise ValueError("If the rastertype is 'image', 'label_metadata' should be None.") + if raster_type == "labels": metadata["name"] = name metadata["label_metadata"] = label_metadata - write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff - write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff - + # Prepare the group for writing data group_data = group.require_group(name) if raster_type == "image" else group - def _get_group_for_writing_transformations() -> zarr.Group: - if raster_type == "image": - return group.require_group(name) - return group["labels"][name] + # Create a function to get the transformations group + get_transformations_group = lambda: _get_group_for_writing_transformations(raster_type, group, name) - # convert channel names to channel metadata in omero + # Convert channel names to channel metadata in omero for images if raster_type == "image": metadata["metadata"] = {"omero": {"channels": []}} channels = get_channel_names(raster_data) for c in channels: metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload] + # Write the data based on its type if isinstance(raster_data, DataArray): - data = raster_data.data - transformations = _get_transformations(raster_data) - input_axes: tuple[str, ...] = tuple(raster_data.dims) - chunks = raster_data.chunks - parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format) - if storage_options is not None: - if "chunks" not in storage_options and isinstance(storage_options, dict): - storage_options["chunks"] = chunks - else: - storage_options = {"chunks": chunks} - - if compressor and isinstance(storage_options, dict): - ((compression, compression_level),) = compressor.items() - storage_options["compressor"] = Blosc(cname=compression, clevel=compression_level, shuffle=1) - # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. - # We need this because the argument of write_image_ngff is called image while the argument of - # write_labels_ngff is called label. 
- metadata[raster_type] = data - write_single_scale_ngff( - group=group_data, - scaler=None, - fmt=format, - axes=parsed_axes, - coordinate_transformations=None, + _write_data_array( + raster_type=raster_type, + raster_data=raster_data, + group_data=group_data, + format=format, storage_options=storage_options, - **metadata, - ) - assert transformations is not None - overwrite_coordinate_transformations_raster( - group=_get_group_for_writing_transformations(), transformations=transformations, axes=input_axes + compressor=compressor, + metadata=metadata, + get_transformations_group=get_transformations_group, ) elif isinstance(raster_data, DataTree): - data = get_pyramid_levels(raster_data, attr="data") - list_of_input_axes: list[Any] = get_pyramid_levels(raster_data, attr="dims") - assert len(set(list_of_input_axes)) == 1 - input_axes = list_of_input_axes[0] - # saving only the transformations of the first scale - d = dict(raster_data["scale0"]) - assert len(d) == 1 - xdata = d.values().__iter__().__next__() - transformations = _get_transformations_xarray(xdata) - assert transformations is not None - assert len(transformations) > 0 - chunks = get_pyramid_levels(raster_data, "chunks") - # coords = iterate_pyramid_levels(raster_data, "coords") - parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format) - storage_options = [{"chunks": chunk} for chunk in chunks] - if compressor: - ((compression, compression_level),) = compressor.items() - for option in storage_options: - option["compressor"] = Blosc(cname=compression, clevel=compression_level, shuffle=1) - dask_delayed = write_multi_scale_ngff( - pyramid=data, - group=group_data, - fmt=format, - axes=parsed_axes, - coordinate_transformations=None, + _write_data_tree( + raster_type=raster_type, + raster_data=raster_data, + group_data=group_data, + format=format, storage_options=storage_options, - **metadata, - compute=False, - ) - # Compute all pyramid levels at once to allow Dask to optimize the computational graph. - da.compute(*dask_delayed) - assert transformations is not None - overwrite_coordinate_transformations_raster( - group=_get_group_for_writing_transformations(), transformations=transformations, axes=tuple(input_axes) + compressor=compressor, + metadata=metadata, + get_transformations_group=get_transformations_group, ) else: raise ValueError("Not a valid labels object") - # as explained in a comment in format.py, since coordinate transformations are not part of NGFF yet, we need to have + # Write format version metadata + # As explained in a comment in format.py, since coordinate transformations are not part of NGFF yet, we need to have # our spatialdata extension also for raster type (eventually it will be dropped in favor of pure NGFF). Until then, # saving the NGFF version (i.e. 
0.4) is not enough, and we need to also record which version of the spatialdata # format we are using for raster types - group = _get_group_for_writing_transformations() + group = get_transformations_group() if ATTRS_KEY not in group.attrs: group.attrs[ATTRS_KEY] = {} attrs = group.attrs[ATTRS_KEY] attrs["version"] = format.spatialdata_format_version - # triggers the write operation + # Triggers the write operation group.attrs[ATTRS_KEY] = attrs @@ -271,6 +458,7 @@ def write_labels( format: Format = CurrentRasterFormat(), storage_options: JSONDict | list[JSONDict] | None = None, label_metadata: JSONDict | None = None, + compressor: dict[Literal["lz4", "zstd"], int] | None = None, **metadata: JSONDict, ) -> None: _write_raster( @@ -280,6 +468,7 @@ def write_labels( name=name, format=format, storage_options=storage_options, + compressor=compressor, label_metadata=label_metadata, **metadata, )
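After the full series, the compression that actually landed on disk can be confirmed by reading the
Blosc settings back from the zarr arrays, just as the tests above do. A short sketch (the store path
and the element name "image2d" are placeholders; `.compressor` is the numcodecs Blosc instance
configured by `_apply_compression`):

    import zarr

    images = zarr.open_group("compressed.zarr/images", mode="r")
    blosc = images["image2d"]["0"].compressor  # scale-0 array of the element
    print(blosc.cname, blosc.clevel, blosc.shuffle)  # e.g. zstd 7 1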