diff --git a/.gitignore b/.gitignore
index e56499d..3dba468 100755
--- a/.gitignore
+++ b/.gitignore
@@ -167,3 +167,6 @@ tutorials/dataset-PixelPandemonium/*
 #_*.py
 dicom_select
 examples
+slicerio_data
+*.nrrd
+nnUNet_results
\ No newline at end of file
diff --git a/TPTBox/core/bids_files.py b/TPTBox/core/bids_files.py
index d669844..1085cab 100755
--- a/TPTBox/core/bids_files.py
+++ b/TPTBox/core/bids_files.py
@@ -662,6 +662,10 @@ def parent(self):
     def bids_format(self):
         return self.format

+    @property
+    def mod(self):
+        return self.format  # modality alias; `return self.mod` here would recurse forever
+
     def get_parent(self, file_type=None):
         return self.get_path_decomposed(file_type)[1]

@@ -1213,7 +1217,7 @@ def filter_format(self, filter_fun: list[str] | str | typing.Callable[[str | obj
             return self.filter_format(lambda x: x in filter_fun)
         return self.filter("format", filter_fun=filter_fun, required=True)

-    def filter_filetype(self, filter_fun: str | typing.Callable[[str | object], bool], required=True):
+    def filter_filetype(self, filter_fun: list[str] | str | typing.Callable[[str | object], bool], required=True):
         return self.filter("filetype", filter_fun=filter_fun, required=required)

     def filter_non_existence(
diff --git a/TPTBox/core/internal/slicer_nrrd.py b/TPTBox/core/internal/slicer_nrrd.py
new file mode 100644
index 0000000..a17e9f7
--- /dev/null
+++ b/TPTBox/core/internal/slicer_nrrd.py
@@ -0,0 +1,775 @@
+from __future__ import annotations
+
+import re
+from collections import OrderedDict
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+from TPTBox.core.vert_constants import log, logging
+from TPTBox.logger.log_constants import Log_Type
+
+if TYPE_CHECKING:
+    from TPTBox.core.nii_wrapper import NII
+
+
+def _read(filename, skip_voxels=False, verbos=True):
+    """Read segmentation metadata from a .seg.nrrd or NIfTI file and store it in a dict.
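+
+    Usage (sketch; the file name is illustrative and the optional `pynrrd` dependency must be installed)::
+
+        segmentation = _read("input.seg.nrrd", skip_voxels=True)  # metadata only, "voxels" is None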
+ + Example header: + + NRRD0004 + # Complete NRRD file format specification at: + # http://teem.sourceforge.net/nrrd/format.html + type: unsigned char + dimension: 3 + space: left-posterior-superior + sizes: 128 128 34 + space directions: (-3.04687595367432,0,0) (0,-3.04687595367432,0) (0,0,9.9999999999999964) + kinds: domain domain domain + encoding: gzip + space origin: (193.09599304199222,216.39599609374994,-340.24999999999994) + Segment0_Color:=0.992157 0.909804 0.619608 + Segment0_ColorAutoGenerated:=1 + Segment0_Extent:=0 124 0 127 0 33 + Segment0_ID:=Segment_1 + Segment0_LabelValue:=1 + Segment0_Layer:=0 + Segment0_Name:=ribs + Segment0_NameAutoGenerated:=1 + Segment0_Tags:=Segmentation.Status:inprogress|TerminologyEntry:Segmentation category and type - 3D Slicer General Anatomy list~SCT^123037004^Anatomical Structure~SCT^113197003^Rib~^^~Anatomic codes - DICOM master list~^^~^^| + Segment1_Color:=1 1 0.811765 + Segment1_ColorAutoGenerated:=1 + Segment1_Extent:=0 124 0 127 0 33 + Segment1_ID:=Segment_2 + Segment1_LabelValue:=2 + Segment1_Layer:=0 + Segment1_Name:=cervical vertebral column + Segment1_NameAutoGenerated:=1 + Segment1_Tags:=Segmentation.Status:inprogress|TerminologyEntry:Segmentation category and type - 3D Slicer General Anatomy list~SCT^123037004^Anatomical Structure~SCT^122494005^Cervical spine~^^~Anatomic codes - DICOM master list~^^~^^| + Segment2_Color:=0.886275 0.792157 0.52549 + Segment2_ColorAutoGenerated:=1 + Segment2_Extent:=0 124 0 127 0 33 + Segment2_ID:=Segment_3 + Segment2_LabelValue:=3 + Segment2_Layer:=0 + Segment2_Name:=thoracic vertebral column + Segment2_NameAutoGenerated:=1 + Segment2_Tags:=Some field:some value|Segmentation.Status:inprogress|TerminologyEntry:Segmentation category and type - 3D Slicer General Anatomy list~SCT^123037004^Anatomical Structure~SCT^122495006^Thoracic spine~^^~Anatomic codes - DICOM master list~^^~^^| + Segmentation_ContainedRepresentationNames:=Binary labelmap|Closed surface| + Segmentation_ConversionParameters:=Collapse labelmaps|1|Merge the labelmaps into as few shared labelmaps as possible 1 = created labelmaps will be shared if possible without overwriting each other.&Compute surface normals|1|Compute surface normals. 1 (default) = surface normals are computed. 0 = surface normals are not computed (slightly faster but produces less smooth surface display).&Crop to reference image geometry|0|Crop the model to the extent of reference geometry. 0 (default) = created labelmap will contain the entire model. 1 = created labelmap extent will be within reference image extent.&Decimation factor|0.0|Desired reduction in the total number of polygons. Range: 0.0 (no decimation) to 1.0 (as much simplification as possible). Value of 0.8 typically reduces data set size by 80% without losing too much details.&Default slice thickness|0.0|Default thickness for contours if slice spacing cannot be calculated.&End capping|1|Create end cap to close surface inside contours on the top and bottom of the structure.\n0 = leave contours open on surface exterior.\n1 (default) = close surface by generating smooth end caps.\n2 = close surface by generating straight end caps.&Fractional labelmap oversampling factor|1|Determines the oversampling of the reference image geometry. All segments are oversampled with the same value (value of 1 means no oversampling).&Joint smoothing|0|Perform joint smoothing.&Oversampling factor|1|Determines the oversampling of the reference image geometry. 
If it's a number, then all segments are oversampled with the same value (value of 1 means no oversampling). If it has the value "A", then automatic oversampling is calculated.&Reference image geometry|3.0468759536743195;0;0;-193.0959930419922;0;3.0468759536743195;0;-216.39599609374994;0;0;9.999999999999998;-340.24999999999994;0;0;0;1;0;127;0;127;0;33;|Image geometry description string determining the geometry of the labelmap that is created in course of conversion. Can be copied from a volume, using the button.&Smoothing factor|-0.5|Smoothing factor. Range: 0.0 (no smoothing) to 1.0 (strong smoothing).&Threshold fraction|0.5|Determines the threshold that the closed surface is created at as a fractional value between 0 and 1.& + Segmentation_MasterRepresentation:=Binary labelmap + Segmentation_ReferenceImageExtentOffset:=0 0 0 + + Example header in case of overlapping segments: + + NRRD0004 + # Complete NRRD file format specification at: + # http://teem.sourceforge.net/nrrd/format.html + type: unsigned char + dimension: 4 + space: left-posterior-superior + sizes: 5 256 256 130 + space directions: none (0,1,0) (0,0,-1) (-1.2999954223632812,0,0) + kinds: list domain domain domain + encoding: gzip + space origin: (86.644897460937486,-133.92860412597656,116.78569793701172) + Segment0_... + + Returned segmentation object: + + { + "voxels": (numpy array of voxel values), + "ijkToLPS": [[ -3.04687595, 0. , 0. , 193.09599304], + [ 0. , -3.04687595, 0. , 216.39599609], + [ 0. , 0. , 10. , -340.25 ], + [ 0. , 0. , 0. , 1. ]], + "encoding": "gzip", + "containedRepresentationNames": ["Binary labelmap", "Closed surface"], + "conversionParameters": [ + {"name": "Collapse labelmaps", "value": "1", "description": "Merge the labelmaps into as few shared labelmaps as possible 1 = created labelmaps will be shared if possible without overwriting each other."}, + {"name": "Compute surface normals", "value": "1", "description": "Compute surface normals. 1 (default) = surface normals are computed. 0 = surface normals are not computed (slightly faster but produces less smooth surface display)."}, + {"name": "Crop to reference image geometry", "value": "0", "description": "Crop the model to the extent of reference geometry. 0 (default) = created labelmap will contain the entire model. 1 = created labelmap extent will be within reference image extent."}, + {"name": "Decimation factor", "value": "0.0", "description": "Desired reduction in the total number of polygons. Range: 0.0 (no decimation) to 1.0 (as much simplification as possible). Value of 0.8 typically reduces data set size by 80% without losing too much details."}, + {"name": "Default slice thickness", "value": "0.0", "description": "Default thickness for contours if slice spacing cannot be calculated."}, + {"name": "End capping", "value": "1", "description": "Create end cap to close surface inside contours on the top and bottom of the structure.\n0 = leave contours open on surface exterior.\n1 (default) = close surface by generating smooth end caps.\n2 = close surface by generating straight end caps."}, + {"name": "Fractional labelmap oversampling factor", "value": "1", "description": "Determines the oversampling of the reference image geometry. All segments are oversampled with the same value (value of 1 means no oversampling)."}, + {"name": "Joint smoothing", "value": "0", "description": "Perform joint smoothing."}, + {"name": "Oversampling factor", "value": "1", "description": "Determines the oversampling of the reference image geometry. 
If it's a number, then all segments are oversampled with the same value (value of 1 means no oversampling). If it has the value \"A\", then automatic oversampling is calculated."}, + {"name": "Reference image geometry", "value": "3.0468759536743195;0;0;-193.0959930419922;0;3.0468759536743195;0;-216.39599609374994;0;0;9.999999999999998;-340.24999999999994;0;0;0;1;0;127;0;127;0;33;", "description": "Image geometry description string determining the geometry of the labelmap that is created in course of conversion. Can be copied from a volume, using the button."}, + {"name": "Smoothing factor", "value": "-0.5", "description": "Smoothing factor. Range: 0.0 (no smoothing) to 1.0 (strong smoothing)."}, + {"name": "Threshold fraction", "value": "0.5", "description": "Determines the threshold that the closed surface is created at as a fractional value between 0 and 1."} + ], + "masterRepresentation": "Binary labelmap", + "referenceImageExtentOffset": [0, 0, 0] + "segments": [ + { + "color": [0.992157, 0.909804, 0.619608], + "colorAutoGenerated": true, + "extent": [0, 124, 0, 127, 0, 33], + "id": "Segment_1", + "labelValue": 1, + "layer": 0, + "name": "ribs", + "nameAutoGenerated": true, + "status": "inprogress", + "terminology": { + "contextName": "Segmentation category and type - 3D Slicer General Anatomy list", + "category": ["SCT", "123037004", "Anatomical Structure"], + "type": ["SCT", "113197003", "Rib"] } + }, + { + "color": [1.0, 1.0, 0.811765], + "colorAutoGenerated": true, + "extent": [0, 124, 0, 127, 0, 33], + "id": "Segment_2", + "labelValue": 2, + "layer": 0, + "name": "cervical vertebral column", + "nameAutoGenerated": true, + "status": "inprogress", + "terminology": { + "contextName": "Segmentation category and type - 3D Slicer General Anatomy list", + "category": ["SCT", "123037004", "Anatomical Structure"], + "type": ["SCT", "122494005", "Cervical spine"] }, + "tags": { + "Some field": "some value" } + } + ] + } + """ + try: + import nrrd + except ModuleNotFoundError: + raise ImportError("The `pynrrd` package is required but not installed. 
Install it with `pip install pynrrd`.") from None + + if skip_voxels: + header = nrrd.read_header(filename) + voxels = None + else: + voxels, header = nrrd.read(filename) + + segmentation = OrderedDict() + + segments_fields = {} # map from segment index to key:value map + + spaceToLps = np.eye(4) + ijkToSpace = np.eye(4) + + # Store header fields + for header_key in header: + if header_key in ["type", "endian", "dimension", "sizes"]: + # these are stored in the voxel array, it would be redundant to store in metadata + continue + + if header_key == "space": + if header[header_key] == "left-posterior-superior": + spaceToLps = np.eye(4) + elif header[header_key] == "right-anterior-superior": + spaceToLps = np.diag([-1.0, -1.0, 1.0, 1.0]) + else: + # LPS and RAS are the most commonly used image orientations, for now we only support these + raise OSError("space field must be 'left-posterior-superior' or 'right-anterior-superior'") + continue + elif header_key == "kinds": + if header[header_key] == ["domain", "domain", "domain"]: + # multiple_layers = False + pass + elif header[header_key] == ["list", "domain", "domain", "domain"]: + # multiple_layers = True + pass + else: + raise OSError("kinds field must be 'domain domain domain' or 'list domain domain domain'") + continue + elif header_key == "space origin": + ijkToSpace[0:3, 3] = header[header_key] + continue + elif header_key == "space directions": + space_directions = header[header_key] + if space_directions.shape[0] == 4: + # 4D segmentation, skip first (nan) row + ijkToSpace[0:3, 0:3] = header[header_key][1:4, 0:3].T + else: + ijkToSpace[0:3, 0:3] = header[header_key].T + continue + elif header_key == "Segmentation_ContainedRepresentationNames": + # Segmentation_ContainedRepresentationNames:=Binary labelmap|Closed surface| + representations = header[header_key].split("|") + representations[:] = [item for item in representations if item != ""] # Remove empty elements + segmentation["containedRepresentationNames"] = representations + continue + elif header_key == "Segmentation_ConversionParameters": + parameters = [] + # Segmentation_ConversionParameters:=Collapse labelmaps|1|Merge the labelmaps into as few...&Compute surface normals|1|Compute...&Crop to reference image geometry|0|Crop the model...& + parameters_str = header[header_key].split("&") + for parameter_str in parameters_str: + if not parameter_str.strip(): + # empty parameter description is ignored + continue + parameter_info = parameter_str.split("|") + if len(parameter_info) != 3: + raise OSError("Segmentation_ConversionParameters field value is invalid (each parameter must be defined by 3 strings)") + parameters.append({"name": parameter_info[0], "value": parameter_info[1], "description": parameter_info[2]}) + if parameters: + segmentation["conversionParameters"] = parameters + continue + elif header_key == "Segmentation_MasterRepresentation": + # Segmentation_MasterRepresentation:=Binary labelmap + segmentation["masterRepresentation"] = header[header_key] + continue + elif header_key == "Segmentation_ReferenceImageExtentOffset": + # Segmentation_ReferenceImageExtentOffset:=0 0 0 + segmentation["referenceImageExtentOffset"] = [int(i) for i in header[header_key].split(" ")] + continue + + segment_match = re.match("^Segment([0-9]+)_(.+)", header_key) + if segment_match: + # Store in segment_fields (segmentation field) + segment_index = int(segment_match.groups()[0]) + segment_key = segment_match.groups()[1] + if segment_index not in segments_fields: + 
                segments_fields[segment_index] = {}
+            segments_fields[segment_index][segment_key] = header[header_key]
+            continue
+
+        segmentation[header_key] = header[header_key]
+
+    # Compute voxel to physical transformation matrix
+    ijkToLps = np.dot(spaceToLps, ijkToSpace)
+    segmentation["ijkToLPS"] = ijkToLps
+
+    segmentation["voxels"] = voxels
+
+    # Process segment_fields to build segment_info
+
+    # Get all used segment IDs (necessary for creating unique segment IDs)
+    segment_ids = set()
+    for segments_field in segments_fields.values():
+        if "ID" in segments_field:
+            segment_ids.add(segments_field["ID"])
+
+    # Store segment metadata in segments_info
+    segments_info = []
+    for segment_index in sorted(segments_fields.keys()):
+        segment_fields = segments_fields[segment_index]
+        if "ID" in segment_fields:  # Segment0_ID:=Segment_1
+            segment_id = segment_fields["ID"]
+        else:
+            segment_id = _generate_unique_segment_id(segment_ids)
+            segment_ids.add(segment_id)
+            log.on_fail(f"Segment ID was not found for index {segment_index}, using automatically generated ID: {segment_id}", verbose=verbos)
+
+        segment_info = {}
+        segment_info["id"] = segment_id
+        if "Color" in segment_fields:
+            segment_info["color"] = [float(i) for i in segment_fields["Color"].split(" ")]  # Segment0_Color:=0.501961 0.682353 0.501961
+        if "ColorAutoGenerated" in segment_fields:
+            segment_info["colorAutoGenerated"] = int(segment_fields["ColorAutoGenerated"]) != 0  # Segment0_ColorAutoGenerated:=1
+        if "Extent" in segment_fields:
+            segment_info["extent"] = [int(i) for i in segment_fields["Extent"].split(" ")]  # Segment0_Extent:=68 203 53 211 24 118
+        if "LabelValue" in segment_fields:
+            segment_info["labelValue"] = int(segment_fields["LabelValue"])  # Segment0_LabelValue:=1
+        if "Layer" in segment_fields:
+            segment_info["layer"] = int(segment_fields["Layer"])  # Segment0_Layer:=0
+        if "Name" in segment_fields:
+            segment_info["name"] = segment_fields["Name"]  # Segment0_Name:=Segment_1
+        if "NameAutoGenerated" in segment_fields:
+            segment_info["nameAutoGenerated"] = int(segment_fields["NameAutoGenerated"]) != 0  # Segment0_NameAutoGenerated:=1
+        # Segment0_Tags:=Segmentation.Status:inprogress|TerminologyEntry:Segmentation category and type - 3D Slicer General Anatomy list
+        # ~SCT^85756007^Tissue~SCT^85756007^Tissue~^^~Anatomic codes - DICOM master list~^^~^^|
+        if "Tags" in segment_fields:
+            tags = {}
+            tags_str = segment_fields["Tags"].split("|")
+            for tag_str in tags_str:
+                tag_str = tag_str.strip()  # noqa: PLW2901
+                if not tag_str:
+                    continue
+                key, value = tag_str.split(":", maxsplit=1)
+                # Process known tags: TerminologyEntry and Segmentation.Status, store all other tags as they are
+                if key == "TerminologyEntry":
+                    segment_info["terminology"] = _terminology_entry_from_string(value)
+                elif key == "Segmentation.Status":
+                    segment_info["status"] = value
+                else:
+                    tags[key] = value
+            if tags:
+                segment_info["tags"] = tags
+        segments_info.append(segment_info)
+
+    segmentation["segments"] = segments_info
+
+    return segmentation
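+
+
+# Usage sketch for `_read` (the file name is illustrative; assumes `pynrrd` is installed):
+#
+#   seg = _read("Segmentation.seg.nrrd")
+#   m = seg["ijkToLPS"]                    # 4x4 homogeneous voxel-to-physical matrix
+#   lps = m @ np.array([0, 0, 0, 1])       # voxel (0, 0, 0) in LPS millimeters
+#   for s in seg["segments"]:
+#       print(s.get("labelValue"), s.get("name"))
+#
+# The "space directions" header rows become the columns of the upper-left 3x3 block
+# of m and "space origin" its translation column, so x_lps = m @ [i, j, k, 1].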
+
+
+def _write_segmentation(file, segmentation, compression_level=9, index_order=None):
+    """
+    Extracts segments from a segmentation volume and header.
+    :param segmentation: segmentation metadata and voxels
+    """
+    try:
+        import nrrd
+    except ModuleNotFoundError:
+        raise ImportError("The `pynrrd` package is required but not installed. Install it with `pip install pynrrd`.") from None
+
+    voxels = segmentation["voxels"]
+    if voxels is None:
+        raise ValueError("Segmentation does not contain voxels")
+
+    # Copy non-segmentation fields to the extracted header
+    output_header = {}
+    ijkToLPS = None
+    for key in segmentation:
+        if key == "voxels":
+            # written separately
+            continue
+        if key == "segments":
+            # written later
+            continue
+        elif key == "ijkToLPS":
+            # image geometry will be set later in space directions, space origin fields
+            ijkToLPS = segmentation[key]
+            continue
+        elif key == "containedRepresentationNames":
+            # Segmentation_ContainedRepresentationNames:=Binary labelmap|Closed surface|
+            # An extra closing "|" is added as this is required by some older Slicer versions.
+            representations = "|".join(segmentation[key]) + "|"
+            output_header["Segmentation_ContainedRepresentationNames"] = representations
+        elif key == "conversionParameters":
+            # Segmentation_ConversionParameters:=Collapse labelmaps|1|Merge the labelmaps into as few...&Compute surface normals|1|Compute...&Crop to reference image geometry|0|Crop the model...&
+            parameters_str = ""
+            parameters = segmentation[key]
+            for parameter in parameters:
+                if parameters_str != "":
+                    parameters_str += "&"
+                parameters_str += f"{parameter['name']}|{parameter['value']}|{parameter['description']}"
+            output_header["Segmentation_ConversionParameters"] = parameters_str
+        elif key == "masterRepresentation":
+            # Segmentation_MasterRepresentation:=Binary labelmap
+            output_header["Segmentation_MasterRepresentation"] = segmentation[key]
+        elif key == "referenceImageExtentOffset":
+            # Segmentation_ReferenceImageExtentOffset:=0 0 0
+            offset = segmentation[key]
+            output_header["Segmentation_ReferenceImageExtentOffset"] = " ".join([str(i) for i in offset])
+        else:
+            output_header[key] = segmentation[key]
+
+    # Add kinds, space directions, space origin to the header
+    # kinds: list domain domain domain
+    kinds = ["domain", "domain", "domain"]
+
+    # space directions: (0,1,0) (0,0,-1) (-1.2999954223632812,0,0)
+    # 'space directions', array([
+    #     [ 0. , 1. , 0. ],
+    #     [ 0. , 0. , -1. ],
+    #     [-1.29999542, 0. , 0. ]]))
+    assert ijkToLPS is not None, "no spatial directions"
+    space_directions = np.array(ijkToLPS)[0:3, 0:3].T
+
+    # Add 4th dimension metadata if array is 4-dimensional (there are overlapping segments)
+    dims = len(voxels.shape)
+    if dims == 4:
+        # kinds: list domain domain domain
+        # ('kinds', ['list', 'domain', 'domain', 'domain'])
+        kinds = ["list", *kinds]
+        # space directions: none (0,1,0) (0,0,-1) (-1.2999954223632812,0,0)
+        # 'space directions', array([
+        #     [ nan, nan, nan],
+        #     [ 0. , 1. , 0. ],
+        #     [ 0. , 0. , -1. ],
+        #     [-1.29999542, 0. , 0.
]])) + space_directions = np.vstack(([np.nan, np.nan, np.nan], space_directions)) + elif dims != 3: + raise ValueError("Unsupported number of dimensions: " + str(dims)) + + output_header["kinds"] = kinds + output_header["space directions"] = space_directions + output_header["space origin"] = np.array(ijkToLPS)[0:3, 3] + output_header["space"] = "left-posterior-superior" # DICOM uses LPS coordinate system + + # Set defaults + if "encoding" not in segmentation: + output_header["encoding"] = "gzip" + if "referenceImageExtentOffset" not in segmentation: + output_header["Segmentation_ReferenceImageExtentOffset"] = "0 0 0" + if "masterRepresentation" not in segmentation: + output_header["Segmentation_MasterRepresentation"] = "Binary labelmap" + + # Add segments fields to the header + + # Get list of segment IDs (needed if we need to generate new ID) + segment_ids = set() + for _, segment in enumerate(segmentation["segments"]): + if "id" in segment: + segment_ids.add(segment["id"]) + + for output_segment_index, segment in enumerate(segmentation["segments"]): + # Copy all segment fields corresponding to this segment + output_tags = [] + for segment_key in segment: + if segment_key == "labelValue": + # Segment0_LabelValue:=1 + field_name = "LabelValue" + value = str(segment[segment_key]) + elif segment_key == "layer": + # Segment0_Layer:=0 + field_name = "Layer" + value = str(segment[segment_key]) + elif segment_key == "name": + # Segment0_Name:=Segment_1 + field_name = "Name" + value = segment[segment_key] + elif segment_key == "id": + # Segment0_ID:=Segment_1 + field_name = "ID" + value = segment[segment_key] + elif segment_key == "color": + # Segment0_Color:=0.501961 0.682353 0.501961 + field_name = "Color" + value = " ".join([str(i) for i in segment[segment_key]]) + elif segment_key == "nameAutoGenerated": + # Segment0_NameAutoGenerated:=1 + field_name = "NameAutoGenerated" + value = 1 if segment[segment_key] else 0 + elif segment_key == "colorAutoGenerated": + # Segment0_ColorAutoGenerated:=1 + field_name = "ColorAutoGenerated" + value = 1 if segment[segment_key] else 0 + # Process information stored in tags, for example: + # Segment0_Tags:=Segmentation.Status:inprogress|TerminologyEntry:Segmentation category and type - 3D Slicer General Anatomy list + # ~SCT^85756007^Tissue~SCT^85756007^Tissue~^^~Anatomic codes - DICOM master list~^^~^^| + elif segment_key == "terminology": + # Terminology is stored in a tag + terminology_str = _terminology_entry_to_string(segment[segment_key]) + output_tags.append(f"TerminologyEntry:{terminology_str}") + # Add tags later + continue + elif segment_key == "status": + # Segmentation status is stored in a tag + output_tags.append(f"Segmentation.Status:{segment[segment_key]}") + # Add tags later + continue + elif segment_key == "tags": + # Other tags + tags = segment[segment_key] + for tag_key in tags: + output_tags.append(f"{tag_key}:{tags[tag_key]}") # noqa: PERF401 + # Add tags later + continue + elif segment_key == "extent": + # Segment0_Extent:=68 203 53 211 24 118 + field_name = "Extent" + value = " ".join([str(i) for i in segment[segment_key]]) + else: + field_name = segment_key + value = segment[segment_key] + + output_header[f"Segment{output_segment_index}_{field_name}"] = value + + if "id" not in segment: + # If user has not specified ID, generate a unique one + new_segment_id = _generate_unique_segment_id(segment_ids) + output_header[f"Segment{output_segment_index}_ID"] = new_segment_id + segment_ids.add(new_segment_id) + + if "layer" not in segment: + 
# If user has not specified layer, set it to 0 + output_header[f"Segment{output_segment_index}_Layer"] = "0" + + if "extent" not in segment: + # If user has not specified extent, set it to the full extent + output_shape = voxels.shape[-3:] + output_header[f"Segment{output_segment_index}_Extent"] = ( + f"0 {output_shape[0] - 1} 0 {output_shape[1] - 1} 0 {output_shape[2] - 1}" + ) + + # Add tags + # Need to end with "|" as earlier Slicer versions require this + output_header[f"Segment{output_segment_index}_Tags"] = "|".join(output_tags) + "|" + + # Write segmentation to file + if index_order is None: + index_order = "F" + import nrrd + + nrrd.write(file, voxels, output_header, compression_level=compression_level, index_order=index_order) + + +def _terminology_entry_from_string(terminology_str): + """Converts a terminology string to a dict. + + Example terminology string: + + Segmentation category and type - 3D Slicer General Anatomy list + ~SCT^49755003^Morphologically Altered Structure + ~SCT^4147007^Mass + ~^^ + ~Anatomic codes - DICOM master list + ~SCT^23451007^Adrenal gland + ~SCT^24028007^Right + + Resulting dict: + + { + 'contextName': 'Segmentation category and type - 3D Slicer General Anatomy list', + 'category': ['SCT', '49755003', 'Morphologically Altered Structure'], + 'type': ['SCT', '4147007', 'Mass'], + 'anatomicContextName': 'Anatomic codes - DICOM master list', + 'anatomicRegion': ['SCT', '23451007', 'Adrenal gland'], + 'anatomicRegionModifier': ['SCT', '24028007', 'Right'] + } + + Specification of terminology entry string is available at + https://slicer.readthedocs.io/en/latest/developer_guide/modules/segmentations.html#terminologyentry-tag + """ + + terminology_items = terminology_str.split("~") + + terminology = {} + terminology["contextName"] = terminology_items[0] + + terminology["category"] = terminology_items[1].split("^") + terminology["type"] = terminology_items[2].split("^") + typeModifier = terminology_items[3].split("^") + if any(item != "" for item in typeModifier): + terminology["typeModifier"] = typeModifier + + anatomicContextName = terminology_items[4] + if anatomicContextName: + terminology["anatomicContextName"] = anatomicContextName + anatomicRegion = terminology_items[5].split("^") + if any(item != "" for item in anatomicRegion): + terminology["anatomicRegion"] = anatomicRegion + anatomicRegionModifier = terminology_items[6].split("^") + if any(item != "" for item in anatomicRegionModifier): + terminology["anatomicRegionModifier"] = anatomicRegionModifier + + return terminology + + +def _terminology_entry_to_string(terminology): + """Converts a terminology dict to string.""" + terminology_str = "" + + if "contextName" in terminology: + terminology_str += terminology["contextName"] + else: + terminology_str += "" + terminology_str += "~" + "^".join(terminology["category"]) + terminology_str += "~" + "^".join(terminology["type"]) + typeModifier = terminology.get("typeModifier", ["", "", ""]) + terminology_str += "~" + "^".join(typeModifier) + + if "anatomicContextName" in terminology: + terminology_str += "~" + terminology["anatomicContextName"] + else: + terminology_str += "~" + anatomic_region = terminology.get("anatomicRegion", ["", "", ""]) + terminology_str += "~" + "^".join(anatomic_region) + anatomic_region_modifier = terminology.get("anatomicRegionModifier", ["", "", ""]) + terminology_str += "~" + "^".join(anatomic_region_modifier) + + return terminology_str + + +def _generate_unique_segment_id(existing_segment_ids): + """Generate a unique 
segment ID, i.e., an ID that is not among existing_segment_ids.
+    It follows DICOM convention to allow using this ID in DICOM Segmentation objects."""
+    import uuid
+
+    while True:
+        segment_id = f"2.25.{uuid.uuid4().int}"
+        if segment_id not in existing_segment_ids:
+            return segment_id
+
+
+def remove_not_supported_values(nrrd_dict: dict):
+    nrrd_dict.pop("conversionParameters", None)
+
+    if "segments" in nrrd_dict:
+        for i in nrrd_dict["segments"]:
+            # print(i)
+            i.pop("extent", None)
+
+
+def load_slicer_nrrd(filename, seg, skip_voxels=False, verbos=True) -> NII:
+    """
+    Load a 3D/4D Slicer NRRD segmentation and return a NII object (wrapper around NIfTI).
+
+    :param filename: path to .seg.nrrd file
+    :param seg: whether the file is a segmentation
+    :param skip_voxels: if True, only metadata is read
+    :param verbos: whether to log warnings encountered while parsing
+    :return: NII object with segmentation data
+    """
+    import nibabel as nib
+    import numpy as np
+
+    from TPTBox import NII
+
+    # Read segmentation
+    nrrd_dict = _read(filename, skip_voxels=skip_voxels, verbos=verbos)
+
+    # Voxel array
+    arr = nrrd_dict.pop("voxels")
+    if arr is None:
+        raise ValueError("Segmentation file does not contain voxel data")
+    remove_not_supported_values(nrrd_dict)
+
+    # Affine: convert ijkToLPS to nifti RAS affine
+    ijkToLPS = nrrd_dict.pop("ijkToLPS")
+
+    # Convert LPS -> RAS for NIfTI
+    # NIfTI uses RAS; LPS = diag(-1,-1,1) * RAS
+    lps2ras = np.diag([-1, -1, 1, 1])
+    affine = lps2ras @ ijkToLPS
+
+    # Background value (usually 0)
+    c_val = 0
+
+    # Create NIfTI object
+    nib_obj = nib.nifti1.Nifti1Image(arr, affine)
+
+    # Wrap in NII
+    return NII(nib_obj, seg=seg, c_val=c_val, desc="", info=nrrd_dict)
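+
+
+# Round-trip sketch (paths are illustrative): load a Slicer segmentation as a NII,
+# then write it back with `save_slicer_nrrd` below. The LPS <-> RAS conversion only
+# flips the sign of the first two affine rows, so applying diag(-1, -1, 1, 1) twice
+# restores the original ijkToLPS matrix.
+#
+#   nii = load_slicer_nrrd("Segmentation.seg.nrrd", seg=True)
+#   save_slicer_nrrd(nii, "Segmentation.roundtrip")  # ".seg.nrrd" is appended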
+
+
+def save_slicer_nrrd(nii: NII, file: str | Path, make_parents=True, verbose: logging = True, compression_level=9, index_order=None):
+    """
+    Save a NII object (segmentation) to a Slicer-compatible NRRD file.
+
+    :param nii: NII object with seg=True
+    :param file: path to save .seg.nrrd
+    :param make_parents: create missing parent directories
+    :param compression_level: gzip compression level
+    :param index_order: NRRD index order, e.g., "F" for Fortran
+    """
+
+    if not str(file).endswith("nrrd"):
+        file = str(file) + (".seg.nrrd" if nii.seg else ".nrrd")
+    if make_parents:
+        Path(file).parent.mkdir(0o771, exist_ok=True, parents=True)
+
+    # Get voxel array
+    arr = nii.get_array()
+
+    # Get affine matrix
+    affine = nii.affine
+
+    # Convert RAS → LPS for Slicer
+    ras2lps = np.diag([-1, -1, 1, 1])
+    ijkToLPS = ras2lps @ affine
+    info = getattr(nii, "info", {})
+    # Build minimal segmentation dict
+    info.pop("conversionParameters", None)
+
+    segmentation = {
+        "voxels": arr,
+        "ijkToLPS": ijkToLPS,
+        "encoding": info.get("encoding", "gzip"),
+        **info,
+    }
+    if nii.seg:
+        segmentation["containedRepresentationNames"] = info.get("containedRepresentationNames", ["Binary labelmap"])
+        segmentation["masterRepresentation"] = info.get("masterRepresentation", "Binary labelmap")
+        segmentation["referenceImageExtentOffset"] = info.get("referenceImageExtentOffset", [0, 0, 0])
+    remove_not_supported_values(segmentation)
+    # Write using the previously defined _write_segmentation
+    log.print(f"Saving {file}", verbose=verbose, ltype=Log_Type.SAVE, end="\r")
+    _write_segmentation(str(file), segmentation, compression_level=compression_level, index_order=index_order)
+    log.print(f"Saved {file}     ", verbose=verbose, ltype=Log_Type.SAVE)
+
+
+if __name__ == "__main__":
+    # Expected structure of a segmentation dict:
+    # {
+    #     "voxels": np.ndarray or None,
+    #     "ijkToLPS": 4x4 np.ndarray,
+    #     "encoding": "gzip" or other,
+    #     "containedRepresentationNames": [...],
+    #     "conversionParameters": [...],
+    #     "masterRepresentation": "Binary labelmap",
+    #     "referenceImageExtentOffset": [...],
+    #     "segments": [ {...}, {...}, ... ],
+    #     ... (any additional non-segment NRRD header fields)
+    # }
+    import os
+
+    import numpy as np
+    import requests
+
+    slicerio_data = Path(__file__).parent / "slicerio_data"
+
+    def download_file(url: str, out_path: str):
+        """Download a URL to a local file."""
+        resp = requests.get(url, stream=True)
+        resp.raise_for_status()
+        with open(out_path, "wb") as f:
+            f.writelines(resp.iter_content(chunk_size=8192))
+        print(f"Downloaded {url} → {out_path}")
+
+    # Create data directory
+    os.makedirs(slicerio_data, exist_ok=True)
+    base = "https://raw.githubusercontent.com/lassoan/slicerio/main/slicerio/data"
+
+    files = {
+        "CT": f"{base}/CTChest4.nrrd",
+        "Seg": f"{base}/Segmentation.seg.nrrd",
+        "SegOverlap": f"{base}/SegmentationOverlapping.seg.nrrd",
+    }
+    # Download
+    for url in files.values():
+        out_local = os.path.join(slicerio_data, url.split("/")[-1])
+        if not os.path.exists(out_local):
+            download_file(url, out_local)
+
+    # Test loading/saving
+    def test_file(path: str):
+        print(f"\nTesting: {path}")
+        seg = "seg." in path
+        # Load
+        nii = load_slicer_nrrd(path, seg)
+        arr = nii.get_array()
+        print("Shape:", arr.shape, "Unique labels:", np.unique(arr))
+
+        # Save to a new file (avoid overwriting the source when it is not a .seg.nrrd)
+        if seg:
+            out_seg = path.replace(".seg.nrrd", ".roundtrip.seg.nrrd")
+        else:
+            out_seg = path.replace(".nrrd", ".roundtrip.nrrd")
+        save_slicer_nrrd(nii, out_seg)
+
+        # Reload saved file
+        nii2 = load_slicer_nrrd(out_seg, seg)
+        arr2 = nii2.get_array()
+
+        # Compare
+        same = np.array_equal(arr, arr2)
+        print(f"Round trip match: {same}")
+        # Path(out_seg).unlink(missing_ok=True)
+
+    for file in os.listdir(slicerio_data):
+        filepath = os.path.join(slicerio_data, file)
+        if "roundtrip" in filepath:
+            continue
+        test_file(filepath)
diff --git a/TPTBox/core/nii_poi_abstract.py b/TPTBox/core/nii_poi_abstract.py
index 0ca2632..13f3b64 100755
--- a/TPTBox/core/nii_poi_abstract.py
+++ b/TPTBox/core/nii_poi_abstract.py
@@ -125,12 +125,45 @@ def change_affine(
     inplace=False,
 ):
     """
-        Apply a transformation (translation, rotation, scaling) to the affine matrix.
- - Parameters: - translation: (n,) array-like in mm in (R, A, S) - rotation_degrees: (n,) array-like (pitch, yaw, roll) in degrees - scaling: (n,) array-like scaling factors along x, y, z + Apply a transformation (scaling, rotation, translation) to the affine matrix. + + Assumptions + ----------- + - `self.affine` is a square homogeneous affine matrix of shape (n, n), + where the spatial dimensionality is n-1 (typically n=4 for 3D). + - The affine follows the convention: + x_world = A @ x_homogeneous + - Transformations are applied in the following order (right-multiplied): + 1. Scaling + 2. Rotation + 3. Translation + i.e. the final update is: + self.affine = (T @ R @ S) @ self.affine + - Rotation is specified as Euler angles in the "xyz" convention + (pitch, yaw, roll) using scipy.spatial.transform.Rotation. + - Translation is specified in world units (e.g. mm) in (x, y, z) + corresponding to the affine axes. + - Scaling is applied along the affine axes, not object-local axes. + - If `inplace=False`, a copy of the object is returned. + If `inplace=True`, the object is modified in place. + + Parameters + ---------- + translation : (n-1,) array-like, optional + Translation vector in world coordinates. + rotation_degrees : (n-1,) array-like, optional + Euler angles (x, y, z) in degrees by default. + scaling : (n-1,) array-like, optional + Scaling factors along each axis. + degrees : bool, default=True + Whether rotation angles are given in degrees. + inplace : bool, default=False + Whether to modify the object in place. + + Returns + ------- + self or copy of self + Object with updated affine. """ # warnings.warn("change_affine is untested", stacklevel=2) n = self.affine.shape[0] @@ -312,21 +345,6 @@ def get_axis(self, direction: DIRECTIONS = "S"): direction = _same_direction[direction] return self.orientation.index(direction) - def get_empty_POI(self, points: dict | None = None): - warnings.warn("get_empty_POI id deprecated use make_empty_POI instead", stacklevel=5) # TODO remove in version 1.0 - - from TPTBox import POI - - p = {} if points is None else points - return POI( - p, - orientation=self.orientation, - zoom=self.zoom, - shape=self.shape, - rotation=self.rotation, - origin=self.origin, - ) - def make_empty_POI(self, points: dict | None = None): from TPTBox import POI diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 93c62fa..8bb0b89 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -248,7 +248,7 @@ def load(cls, path: Image_Reference, seg, c_val=None)-> Self: return nii @classmethod - def load_nrrd(cls, path: str | Path, seg: bool): + def load_nrrd(cls, path: str | Path, seg: bool,verbos=False): """ Load an NRRD file and convert it into a Nifti1Image object. @@ -272,76 +272,8 @@ def load_nrrd(cls, path: str | Path, seg: bool): import nrrd # pip install pynrrd, if pynrrd is not already installed except ModuleNotFoundError: raise ImportError("The `pynrrd` package is required but not installed. Install it with `pip install pynrrd`.") from None - _nrrd = nrrd.read(path) - data = _nrrd[0] - - header = dict(_nrrd[1]) - #print(data.shape, header) - #print(header) - # Example print out: OrderedDict([ - # ('type', 'short'), ('dimension', 3), ('space', 'left-posterior-superior'), - # ('sizes', array([512, 512, 1637])), - # ('space directions', array([[0.9765625, 0. , 0. ], - # [0. , 0.9765625, 0. ], - # [0. , 0. 
, 0.6997555]])), - # ('kinds', ['domain', 'domain', 'domain']), ('endian', 'little'), - # ('encoding', 'gzip'), - # ('space origin', array([-249.51171875, -392.51171875, 119.7]))]) - - # Construct the affine transformation matrix - #print(header) - try: - #print(header['space directions']) - #print(header['space origin']) - space_directions = np.array(header['space directions']) - space_origin = np.array(header['space origin']) - #space_directions = space_directions[~np.isnan(space_directions).any(axis=1)] #Filter NAN - n = header['dimension'] - #print(data.shape) - if space_directions.shape != (n, n): - space_directions = space_directions[~np.isnan(space_directions).all(axis=1)] - m = len(space_directions[0]) - if m != n: - n=m - data = data.sum(axis=0) - space_directions = space_directions.T - if space_directions.shape != (n, n): - raise ValueError(f"Expected 'space directions' to be a nxn matrix. n = {n} is not {space_directions.shape}",space_directions) - if space_origin.shape != (n,): - raise ValueError("Expected 'space origin' to be a n-element vector. n = ", n, "is not",space_origin.shape ) - space = header.get("space","left-posterior-superior") - affine = np.eye(n+1) # Initialize 4x4 identity matrix - affine[:n, :n] = space_directions # Set rotation and scaling - affine[:n, n] = space_origin # Set translation - #print(affine,space) - if space =="left-posterior-superior": #LPS (SITK-space) - affine[0] *=-1 - affine[1] *=-1 - elif space == "right-posterior-superior": #RPS - affine[0] *=-1 - elif space == "left-anterior-superior": #LAS - affine[1] *=-1 - elif space == "right-anterior-superior": #RAS - pass - else: - raise ValueError(space) - #print(affine) - - except KeyError as e: - raise KeyError(f"Missing expected header field: {e}") from None - if len(data.shape) != n: - raise ValueError(f"{len(data.shape)=} diffrent from n = ", n) - ref_orientation = header.get("ref_orientation") - for i in ["ref_orientation","dimension","space directions","space origin""space","type","endian"]: - header.pop(i, None) - for key in list(header.keys()): - if "_Extent" in key: - del header[key] - nii = NII((data,affine,None),seg=seg,info = header) - if ref_orientation is not None: - nii.reorient_(ref_orientation) - return nii - + from TPTBox.core.internal.slicer_nrrd import load_slicer_nrrd + return load_slicer_nrrd(path,seg,verbos=verbos) @classmethod def load_bids(cls, nii_bids: bids_files.BIDS_FILE): nifty = None @@ -1355,7 +1287,7 @@ def fill_holes(self, labels: LABEL_REFERENCE = None, slice_wise_dim: int|str | N filled = np_fill_holes(seg_arr, label_ref=labels, slice_wise_dim=slice_wise_dim, use_crop=use_crop) return self.set_array(filled,inplace=inplace) - def fill_holes_(self, labels: LABEL_REFERENCE = None, slice_wise_dim: int | None = None, verbose:logging=True,use_crop=True): + def fill_holes_(self, labels: LABEL_REFERENCE = None, slice_wise_dim: int |str| None = None, verbose:logging=True,use_crop=True): return self.fill_holes(labels, slice_wise_dim, verbose, inplace=True,use_crop=use_crop) def calc_convex_hull( @@ -1376,7 +1308,7 @@ def calc_convex_hull( return self.set_array_(convex_hull_arr) return self.set_array(convex_hull_arr) - def calc_convex_hull_(self, axis: DIRECTIONS="S", verbose: bool = False,): + def calc_convex_hull_(self, axis: None|DIRECTIONS="S", verbose: bool = False,): return self.calc_convex_hull(axis=axis, inplace=True, verbose=verbose) @@ -1662,67 +1594,6 @@ def truncate_labels_beyond_reference( ): return 
self.truncate_labels_beyond_reference_(idx,not_beyond,fill,axis,inclusion) - def infect_conv(self: NII, reference_mask: NII, max_iters=100,inplace=False): - """ - Expands labels from self_mask into regions of reference_mask == 1 via breadth-first diffusion. - - Args: - self_mask (ndarray): (H, W) or (D, H, W) integer-labeled array. - reference_mask (ndarray): Binary array of same shape as self_mask. - max_iters (int): Maximum number of propagation steps. - - Returns: - ndarray: Updated label mask. - """ - from scipy.ndimage import convolve - crop = reference_mask.compute_crop(0,1) - self.assert_affine(reference_mask) - self_mask = self.apply_crop(crop).get_seg_array().copy() - ref_mask = np.clip(reference_mask.apply_crop(crop).get_seg_array(), 0, 1) - - ndim = len(self_mask.shape) - - # Define neighborhood kernel - if ndim == 2: - kernel = np.array([[0, 1, 0], - [1, 0, 1], - [0, 1, 0]], dtype=np.uint8) - elif ndim == 3: - kernel = np.zeros((3, 3, 3), dtype=np.uint8) - kernel[1, 1, 0] = kernel[1, 1, 2] = 1 - kernel[1, 0, 1] = kernel[1, 2, 1] = 1 - kernel[0, 1, 1] = kernel[2, 1, 1] = 1 - else: - raise NotImplementedError("Only 2D or 3D masks are supported.") - try: - from tqdm import tqdm - r = tqdm(range(max_iters),desc="infect") - except Exception: - r = range(max_iters) - for _ in r: - unlabeled = (self_mask == 0) & (ref_mask == 1) - updated = False - - for label in np_unique(self_mask): - if label == 0: - continue # skip background - - binary_label_mask = (self_mask == label).astype(np.uint8) - neighbor_count = convolve(binary_label_mask, kernel, mode="constant", cval=0) - - # Find unlabeled voxels adjacent to current label - new_voxels = (neighbor_count > 0) & unlabeled - - if np.any(new_voxels): - self_mask[new_voxels] = label - updated = True - - if not updated: - break - org = self.get_seg_array() - org[crop] = self_mask - return self.set_array(org,inplace=inplace) - def infect(self: NII, reference_mask: NII, inplace=False,verbose=True,axis:int|str|None=None): """ Expands labels from self_mask into regions of reference_mask == 1 via breadth-first diffusion. @@ -1846,7 +1717,7 @@ def flip(self, axis:int|str,keep_global_coords=True,inplace=False): axis = self.get_axis(axis) if not isinstance(axis,int) else axis if keep_global_coords: orient = list(self.orientation) - orient[axis] = _same_direction[orient[axis] ] + orient[axis] = _same_direction[orient[axis]] return self.reorient(tuple(orient),inplace=inplace) else: return self.set_array(np.flip(self.get_array(),axis),inplace=inplace) @@ -1856,8 +1727,7 @@ def clone(self): @secure_save def save(self,file:str|Path,make_parents=True,verbose:logging=True, dtype = None): if make_parents: - Path(file).parent.mkdir(exist_ok=True,parents=True) - + Path(file).parent.mkdir(0o777,exist_ok=True,parents=True) arr = self.get_array() if not self.seg else self.get_seg_array() if isinstance(arr,np.floating) and self.seg: self.set_dtype_("smallest_uint") @@ -1894,43 +1764,10 @@ def save_nrrd(self:Self, file: str | Path|bids_files.BIDS_FILE,make_parents=True raise ImportError("The `pynrrd` package is required but not installed. Install it with `pip install pynrrd`." 
) from None
         if isinstance(file, bids_files.BIDS_FILE):
             file = file.file['nrrd']
-        if not str(file).endswith(".nrrd"):
-            file = str(file)+".nrrd"
-        if make_parents:
-            Path(file).parent.mkdir(exist_ok=True,parents=True)
-        _header = {}
-        #if self.orientation not in [("L","P","S")]: #,("R","P","S"),("R","A","S"),("L","A","S")
-        #    _header = {"ref_orientation": "".join(self.orientation)}
-        #    self = self.reorient(("P","L","S")) # Convert to LAS-SimpleITK # noqa: PLW0642
-        # Slicer only allows LPS and flip of L and P axis
-        ori = "left-posterior-superior"# "-".join([_dirction_name_itksnap_dict[i] for i in self.orientation])
-        data = self.get_array()
-        affine = self.affine.copy()
-        affine[0] *=-1
-        affine[1] *=-1
-        # Extract header fields from the affine matrix
-        n = affine.shape[0] - 1
-        space_directions = affine[:n, :n]
-        space_origin = affine[:n, n]
-        _header["kinds"]= ['domain'] * n if "kinds" not in self.info else self.info["kinds"]
-        header = {
-            'type': str(data.dtype),
-            'dimension': n,
-            'space': ori,
-            'sizes': data.shape,#(data.shape[1],data.shape[0],data.shape[2]),
-            'space directions': space_directions.tolist(),
-            'space origin': space_origin,
-            'endian': 'little',
-            'encoding': 'gzip',
-            **_header,**self.info
-        }
-        header.pop("Segmentation_ConversionParameters", None)
-        # Save NRRD file
+        from TPTBox.core.internal.slicer_nrrd import save_slicer_nrrd
+        save_slicer_nrrd(self,file,make_parents=make_parents,verbose=verbose,**args)
-
-        log.print(f"Saveing {file}",verbose=verbose,ltype=Log_Type.SAVE,end='\r')
-        nrrd.write(str(file), data=data, header=header,**args) # nrrd only acepts strings...
-        log.print(f"Save {file} as {header['type']}",verbose=verbose,ltype=Log_Type.SAVE)

     def __str__(self) -> str:
         return f"{super().__str__()}, seg={self.seg}"  # type: ignore
@@ -1970,6 +1807,8 @@ def __getitem__(self, key)-> Any:
             return self.get_array()[key.get_array()==1]
         elif isinstance(key,np.ndarray):
             return self.get_array()[key]
+        elif isinstance(key,slice):
+            return self.__getitem__((key, Ellipsis))  # `return` was missing; an index may only contain a single Ellipsis
         else:
             raise TypeError("Invalid argument type:", type(key))
     def __setitem__(self, key,value):
@@ -2086,9 +1925,11 @@ def voxel_volume(self):
         product = math.prod(self.spacing)
         return product

-    def volumes(self, include_zero: bool = False, in_mm3=False) -> dict[int, float]|dict[int, int]:
+    def volumes(self, include_zero: bool = False, in_mm3=False,sort=False) -> dict[int, float]|dict[int, int]:
         '''Returns a dict stating how many pixels are present for each label'''
         dic = np_volume(self.get_seg_array(), include_zero=include_zero)
+        if sort:
+            dic = dict(sorted(dic.items()))
         if in_mm3:
             voxel_size = self.voxel_volume()
             dic = {k:v*voxel_size for k,v in dic.items()}
diff --git a/TPTBox/core/nii_wrapper_math.py b/TPTBox/core/nii_wrapper_math.py
index 8d12fb9..4f6cf2b 100755
--- a/TPTBox/core/nii_wrapper_math.py
+++ b/TPTBox/core/nii_wrapper_math.py
@@ -56,8 +56,8 @@ def _binary_opt(self, other:C, opt,inplace = False)-> Self:
         if isinstance(other,NII_Math):
             other = other.get_array()
         return self.set_array(opt(self.get_array(),other),inplace=inplace,verbose=False)
-    def _uni_opt(self, opt,inplace = False)-> Self:
-        return self.set_array(opt(self.get_array()),inplace=inplace,verbose=False)
+    def _uni_opt(self, opt,inplace = False,**args)-> Self:
+        return self.set_array(opt(self.get_array(),**args),inplace=inplace,verbose=False)
     def __add__(self,p2):
         return self._binary_opt(p2,operator.add)
     def __radd__(self,p2):
@@ -81,12 +81,20 @@ def __lshift__(self,p2):
     def __rshift__(self,p2):
         return self._binary_opt(p2,operator.rshift)
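+    # The dtype guards added below reject bitwise operations on float-backed images,
+    # where &, |, ^ and ~ are almost never what the caller meant. Hedged sketch of the
+    # resulting behavior (`nii_int` / `nii_float` are illustrative NII instances):
+    #
+    #   nii_int & nii_int   # ok: both wrap integer arrays
+    #   ~nii_float          # raises TypeError("Bitwise operations require integer arrays")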
     def __and__(self,p2):
+        if not np.issubdtype(self.get_array().dtype, np.integer):
+            raise TypeError("Bitwise operations require integer arrays")
         return self._binary_opt(p2,operator.and_)
     def __or__(self,p2):
+        if not np.issubdtype(self.get_array().dtype, np.integer):
+            raise TypeError("Bitwise operations require integer arrays")
         return self._binary_opt(p2,operator.or_)
     def __xor__(self,p2):
+        if not np.issubdtype(self.get_array().dtype, np.integer):
+            raise TypeError("Bitwise operations require integer arrays")
         return self._binary_opt(p2,operator.xor)
     def __invert__(self):
+        if not np.issubdtype(self.get_array().dtype, np.integer):
+            raise TypeError("Bitwise operations require integer arrays")
         return self._uni_opt(operator.invert)

     def __lt__(self,p2):
@@ -125,7 +133,7 @@ def __abs__(self):
         return self._uni_opt(operator.abs)

     def __round__(self, decimals=0):
-        return self._uni_opt(np.round)
+        return self._uni_opt(np.round,decimals=decimals)
     def round(self,decimals):
         return self.__round__(decimals=decimals)
     def __floor__(self):
diff --git a/TPTBox/mesh3D/snapshot3D.py b/TPTBox/mesh3D/snapshot3D.py
index d0729e1..12f985e 100644
--- a/TPTBox/mesh3D/snapshot3D.py
+++ b/TPTBox/mesh3D/snapshot3D.py
@@ -120,7 +120,7 @@ def make_snapshot3D(
     scene.projection(proj_type="parallel")
     scene.reset_camera_tight(margin_factor=1.02)
     window.record(
-        scene,
+        scene=scene,
         size=window_size,
         out_path=output_path,
         reset_camera=False,
diff --git a/TPTBox/segmentation/VibeSeg/inference_nnunet.py b/TPTBox/segmentation/VibeSeg/inference_nnunet.py
index fac4ec6..f25f5c5 100644
--- a/TPTBox/segmentation/VibeSeg/inference_nnunet.py
+++ b/TPTBox/segmentation/VibeSeg/inference_nnunet.py
@@ -18,16 +18,24 @@
 model_path = out_base / "nnUNet_results"

-def get_ds_info(idx) -> dict:
+def get_ds_info(idx, _model_path: str | Path | None = None, exit_on_fail=True) -> dict | None:
+    if _model_path is None:
+        _model_path = out_base  # fall back to the module-level default location
+    model_path = Path(_model_path) / "nnUNet_results"  # local shadow; keeps the reads below bound
+    assert model_path.exists(), model_path
+
     try:
         nnunet_path = next(next(iter(model_path.glob(f"*{idx}*"))).glob("*__nnUNetPlans*"))
     except StopIteration:
         try:
             nnunet_path = next(next(iter(model_path.glob(f"*{idx}*"))).glob("*__nnUNet*ResEnc*"))
         except StopIteration:
-            Print_Logger().print(f"Please add Dataset {idx} to {model_path}", Log_Type.FAIL)
-            model_path.mkdir(exist_ok=True, parents=True)
-            sys.exit()
+            if exit_on_fail:
+                Print_Logger().print(f"Please add Dataset {idx} to {model_path}", Log_Type.FAIL)
+                model_path.mkdir(exist_ok=True, parents=True)
+                sys.exit()
+            else:
+                return None
     with open(Path(nnunet_path, "dataset.json")) as f:
         ds_info = json.load(f)
     return ds_info
@@ -129,15 +137,15 @@ def run_inference_on_file(
     og_nii = input_nii[0].copy()
     try:
-        zoom_old = ds_info.get("spacing")
-        if idx not in [527] and zoom_old is not None:
-            zoom_old = zoom_old[::-1]
-
-        zoom_old = ds_info.get("resolution_range", zoom_old)
-        if zoom_old is None:
-            zoom = plans_info["configurations"]["3d_fullres"]["spacing"]
-            if all(zoom[0] == z for z in zoom):
-                zoom_old = zoom
+        zoom = ds_info.get("spacing")
+        if idx not in [527] and zoom is not None:
+            zoom = zoom[::-1]
+
+        zoom = ds_info.get("resolution_range", zoom)
+        if zoom is None:
+            zoom_ = plans_info["configurations"]["3d_fullres"]["spacing"]
+            if all(zoom_[0] == z for z in zoom_):
+                zoom = zoom_
         # order = plans_info["transpose_backward"]
         ## order2 = plans_info["transpose_forward"]
         # zoom = [zoom[order[0]], zoom[order[1]], zoom[order[2]]][::-1]
@@ -150,7 +158,7 @@ def
run_inference_on_file( # zoom_old = zoom_old[::-1] - zoom_old = [float(z) for z in zoom_old] + zoom = [float(z) for z in zoom] except Exception: pass assert len(ds_info["channel_names"]) == len(input_nii), ( @@ -163,9 +171,9 @@ def run_inference_on_file( print("orientation", orientation, f"{orientation_ref=}") if verbose else None input_nii = [i.reorient(orientation) for i in input_nii] - if zoom_old is not None: - print("rescale", zoom, f"{zoom_old=}") if verbose else None - input_nii = [i.rescale_(zoom_old, mode=mode) for i in input_nii] + if zoom is not None: + print("rescale", input_nii[0].orientation, f"{zoom=}") if verbose else None + input_nii = [i.rescale_(zoom, mode=mode) for i in input_nii] print(input_nii) print("squash to float16") if verbose else None input_nii = [squash_so_it_fits_in_float16(i) for i in input_nii] @@ -199,8 +207,8 @@ def run_inference_on_file( def run_VibeSeg( - img: Path | str | list[Path] | list[NII], - out_path: Path, + img: Path | str | list[Path] | list[NII] | Image_Reference, + out_path: str | Path | None, override=False, dataset_id=None, gpu: int | None = None, @@ -215,7 +223,9 @@ def run_VibeSeg( **_kargs, ): global model_path # noqa: PLW0603 - if out_path.exists() and not override: + if isinstance(out_path, str): + out_path = Path(out_path) + if out_path is not None and out_path.exists() and not override: logger.print(out_path, "already exists. SKIP!", Log_Type.OK) return out_path @@ -243,13 +253,14 @@ def run_VibeSeg( gpu = "auto" # type: ignore logger.print("run", f"{dataset_id=}, {gpu=}", Log_Type.STAGE) ds_info = get_ds_info(dataset_id) - orientation = ds_info["orientation"] - if not isinstance(img, Sequence): + orientation = ds_info.get("orientation", ("R", "A", "S")) + if not isinstance(img, Sequence) or isinstance(img, str): img = [img] + if "roi" in ds_info: raise NotImplementedError("roi") else: - in_niis = [to_nii(i) for i in img] + in_niis = [to_nii(i) for i in img] # type: ignore in_niis = [i.resample_from_to_(in_niis[0]) if i.shape != in_niis[0].shape else i for i in in_niis] if (in_niis[0].affine == np.eye(4)).all(): warn( diff --git a/TPTBox/segmentation/_deface.py b/TPTBox/segmentation/_deface.py new file mode 100644 index 0000000..95ff3a7 --- /dev/null +++ b/TPTBox/segmentation/_deface.py @@ -0,0 +1,75 @@ +from pathlib import Path + +from TPTBox import BIDS_FILE, BIDS_Global_info, Image_Reference, to_nii +from TPTBox.segmentation.VibeSeg.inference_nnunet import run_inference_on_file, run_VibeSeg + + +def compute_deface_mask_cta(ct_img: Image_Reference, outpath: str | Path | None = None, override=False, gpu=None, **args): + """ + Mahmutoglu, M.A., Rastogi, A., Schell, M. et al. Deep learning-based defacing tool for CT angiography: CTA-DEFACE. Eur Radiol Exp 8, 111 (2024). 
https://doi.org/10.1186/s41747-024-00510-9 + + """ + if isinstance(outpath, str): + outpath = Path(outpath) + if isinstance(ct_img, BIDS_FILE) and outpath is None: + outpath = ct_img.get_changed_path("nii.gz", "msk", parent="derivatives-defacing", info={"seg": "defacting", "mod": ct_img.format}) + if outpath is not None and not override and outpath.exists(): + return outpath + return run_VibeSeg(ct_img, out_path=outpath, dataset_id=1, keep_size=False, override=override, gpu=gpu, **args) + + +if __name__ == "__main__": + # bgi = BIDS_Global_info("/DATA/NAS/datasets_processed/CT_spine/dataset-myelom") + # snps = [] + # msk = [] + # for sub, subj in bgi.iter_subjects(): + # q = subj.new_query(flatten=True) + # q.filter_format(lambda x: "ct" in str(x)) + # q.filter_filetype(["nii.gz", "nii", "nrrd", "mrk"]) + # for ct_img in q.loop_list(): + # compute_deface_mask_cta(ct_img, gpu=5) + # outpath = ct_img.get_changed_path( + # "nii.gz", "msk", parent="derivatives-defacing", info={"seg": "defacting", "mod": ct_img.format} + # ) + # snp = ( + # ct_img.dataset + # / "derivatives-defacing" + # / "snapshots" + # / ct_img.get_changed_path("jpg", "msk", parent="derivatives-defacing", info={"seg": "defacting", "mod": ct_img.format}).name + # ) + # msk.append(outpath) + # snps.append(snp) + # from TPTBox.mesh3D.snapshot3D import make_snapshot3D_parallel + # make_snapshot3D_parallel(msk, snps, ["A", "R", "S"]) + snps = [] + msk = [] + + bgi = BIDS_Global_info( + "/DATA/NAS/ongoing_projects/robert/datasets/Carotis-CoW-Projekt/Carotis-CoW-Projekt/CT_Datensatz_TUM_20250827/", + parents=["CT_CAROTIS"], + ) + for _, subj in bgi.iter_subjects(): + q = subj.new_query(flatten=True) + q.filter_format(lambda x: "ct" in str(x)) + q.filter_filetype(["nii.gz", "nii", "nrrd", "mrk"]) + for ct_img in q.loop_list(): + try: + compute_deface_mask_cta(ct_img, gpu=5) + outpath = ct_img.get_changed_path( + "nii.gz", "msk", parent="derivatives-defacing", info={"seg": "defacting", "mod": ct_img.format} + ) + snp = ( + ct_img.dataset + / "derivatives-defacing" + / "snapshots" + / ct_img.get_changed_path( + "jpg", "msk", parent="derivatives-defacing", info={"seg": "defacting", "mod": ct_img.format} + ).name + ) + msk.append(outpath) + snps.append(snp) + except Exception: + pass + from TPTBox.mesh3D.snapshot3D import make_snapshot3D_parallel + + make_snapshot3D_parallel(msk, snps, ["A", "R", "S"]) diff --git a/TPTBox/spine/snapshot2D/snapshot_modular.py b/TPTBox/spine/snapshot2D/snapshot_modular.py index 99549c8..ff86d95 100755 --- a/TPTBox/spine/snapshot2D/snapshot_modular.py +++ b/TPTBox/spine/snapshot2D/snapshot_modular.py @@ -30,7 +30,7 @@ v_idx2name, v_idx_order, ) -from TPTBox.mesh3D.mesh_colors import _color_map_in_row # vert_color_map +from TPTBox.mesh3D.mesh_colors import _color_map_in_row, get_color_by_label NII.suppress_dtype_change_printout_in_set_array(True) """ @@ -430,6 +430,8 @@ def make_isotropic2dpluscolor(arr3d, zms2d, msk=False): def get_contrasting_stroke_color(rgb): # Convert RGBA to RGB if necessary + if isinstance(rgb, int): + rgb = list(get_color_by_label(rgb).rgb / 255.0) if len(rgb) == 4: rgb = rgb[:3] luminance = 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2] @@ -505,7 +507,15 @@ def plot_sag_centroids( elif len(x) == 4: v = (np.array(ctd[x[0], x[1]]) + np.array(ctd[x[2], x[3]])) / 2 - axs.add_patch(FancyArrow(v[1] * zms[1], v[0] * zms[0], c, d, color=cmap(color - 1 % LABEL_MAX % cmap.N))) + axs.add_patch( + FancyArrow( + v[1] * zms[1], + v[0] * zms[0], + c, + d, + color=cmap(color - 1 % LABEL_MAX % 
diff --git a/TPTBox/spine/snapshot2D/snapshot_modular.py b/TPTBox/spine/snapshot2D/snapshot_modular.py
index 99549c8..ff86d95 100755
--- a/TPTBox/spine/snapshot2D/snapshot_modular.py
+++ b/TPTBox/spine/snapshot2D/snapshot_modular.py
@@ -30,7 +30,7 @@
     v_idx2name,
     v_idx_order,
 )
-from TPTBox.mesh3D.mesh_colors import _color_map_in_row  # vert_color_map
+from TPTBox.mesh3D.mesh_colors import _color_map_in_row, get_color_by_label

 NII.suppress_dtype_change_printout_in_set_array(True)
 """
@@ -430,6 +430,8 @@ def make_isotropic2dpluscolor(arr3d, zms2d, msk=False):

 def get_contrasting_stroke_color(rgb):
     # Convert RGBA to RGB if necessary
+    if isinstance(rgb, int):
+        rgb = list(get_color_by_label(rgb).rgb / 255.0)
     if len(rgb) == 4:
         rgb = rgb[:3]
     luminance = 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]
@@ -505,7 +507,15 @@ def plot_sag_centroids(
             elif len(x) == 4:
                 v = (np.array(ctd[x[0], x[1]]) + np.array(ctd[x[2], x[3]])) / 2
-            axs.add_patch(FancyArrow(v[1] * zms[1], v[0] * zms[0], c, d, color=cmap(color - 1 % LABEL_MAX % cmap.N)))
+            axs.add_patch(
+                FancyArrow(
+                    v[1] * zms[1],
+                    v[0] * zms[0],
+                    c,
+                    d,
+                    color=cmap((color - 1) % LABEL_MAX % cmap.N),
+                )
+            )
     if "text_sag" in ctd.info:
         for color, x in ctd.info["text_sag"]:
             backgroundcolor = get_contrasting_stroke_color(color)
@@ -578,15 +588,29 @@ def plot_cor_centroids(
             )
         except Exception:
             pass
+
     if "line_segments_cor" in ctd.info:
         for color, x, (c, d) in ctd.info["line_segments_cor"]:
+            # if isinstance(color, int):
+            #     color = list(get_color_by_label(color).rgb / 255.0)
+
             if len(x) == 2:
                 v = ctd[x]
             elif len(x) == 4:
                 v = (np.array(ctd[x[0], x[1]]) + np.array(ctd[x[2], x[3]])) / 2
-            axs.add_patch(FancyArrow(v[2] * zms[2], v[0] * zms[0], c, d, color=cmap(color - 1 % LABEL_MAX % cmap.N)))
+            axs.add_patch(
+                FancyArrow(
+                    v[2] * zms[2],
+                    v[0] * zms[0],
+                    c,
+                    d,
+                    color=cmap((color - 1) % LABEL_MAX % cmap.N),
+                )
+            )
     if "text_cor" in ctd.info:
         for color, x in ctd.info["text_cor"]:
+            if isinstance(color, int):
+                color = list(get_color_by_label(color).rgb / 255.0)  # noqa: PLW2901
             backgroundcolor = get_contrasting_stroke_color(color)
             if isinstance(color, Sequence) and len(color) == 2:
                 color, curve_location = color  # noqa: PLW2901
@@ -1038,9 +1062,21 @@ def create_snapshot(  # noqa: C901
             )
         fig, axs = create_figure(dpi, img_list, has_title=frame.title is None)
-        for ax, (img, msk, ctd, wdw, is_sag, alpha, cmap, zms, curve_location, poi_labelmap, hide_centroid_labels, title, frame) in zip(
-            axs, frame_list
-        ):
+        for ax, (
+            img,
+            msk,
+            ctd,
+            wdw,
+            is_sag,
+            alpha,
+            cmap,
+            zms,
+            curve_location,
+            poi_labelmap,
+            hide_centroid_labels,
+            title,
+            frame,
+        ) in zip(axs, frame_list):
             if title is not None:
                 ax.set_title(title, fontdict={"fontsize": 18, "color": "black"}, loc="center")
             if img.ndim == 3:
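One behavioural note on the two FancyArrow blocks above: in Python, % binds tighter than -, so the original cmap(color - 1 % LABEL_MAX % cmap.N) evaluates as cmap(color - 1). The reflowed + lines parenthesise the index as cmap((color - 1) % LABEL_MAX % cmap.N), on the assumption that the label was meant to wrap into the colormap range:

    # Precedence demo with illustrative values:
    color, LABEL_MAX, N = 25, 256, 20
    old = color - 1 % LABEL_MAX % N    # == 25 - 1 == 24, outside a 20-entry colormap
    new = (color - 1) % LABEL_MAX % N  # == 4, wrapped into [0, N)
    assert (old, new) == (24, 4)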
"^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" # Flag errors (`C901`) whenever the complexity level exceeds 5. max-complexity = 20 - +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "if __name__ == .__main__.:", + "def __repr__", + "except", + "warnings.warn", + "print\\(f?\"\\[!\\].*", + "def __str__(self):", + "def __hash__(self) -> int:", + "assert ", + "if TYPE_CHECKING: ", + "raise " +] +omit = [ + "TPTBox/test/speedtest/*", +] [tool.ruff.format] # Like Black, use double quotes for strings. quote-style = "double" diff --git a/unit_tests/test_auto_segmentation.py b/unit_tests/test_auto_segmentation.py index 41ed643..5b31b2b 100644 --- a/unit_tests/test_auto_segmentation.py +++ b/unit_tests/test_auto_segmentation.py @@ -31,6 +31,14 @@ has_spineps = False +try: + import torch + + has_torch = True +except ModuleNotFoundError: + has_torch = False + + class Test_test_samples(unittest.TestCase): # def test_load_ct(self): # ct_nii, subreg_nii, vert_nii, label = get_test_ct() @@ -103,3 +111,15 @@ def test_VIBESeg_ct(self): assert isinstance(out, (NII, Path)) assert seg_out_path.exists() seg_out_path.unlink(missing_ok=True) + + @unittest.skipIf(not has_torch, "requires torch to be installed") + def test_get_device(self): + import torch + + from TPTBox.core.internal.deep_learning_utils import get_device + + assert get_device("cpu", 0) == torch.device("cpu") + assert get_device("cuda", 0) == torch.device("cuda:0") + assert get_device("cuda", 1) == torch.device("cuda:1") + assert get_device("cuda", 1) != torch.device("cuda:0") + assert get_device("mps", 0) == torch.device("mps") diff --git a/unit_tests/test_nii.py b/unit_tests/test_nii.py index 0677604..f8d1d3f 100755 --- a/unit_tests/test_nii.py +++ b/unit_tests/test_nii.py @@ -4,6 +4,7 @@ # coverage html from __future__ import annotations +import operator import random import sys import unittest @@ -11,6 +12,7 @@ import nibabel as nib import numpy as np +import pytest from TPTBox import NII, v_idx2name from TPTBox.core import np_utils @@ -32,6 +34,126 @@ def get_all_corner_points(affine, shape) -> np.ndarray: return a +class TestNII_MathOperators(unittest.TestCase): + """Tests that NII_Math operators match explicit NumPy array operations. 
diff --git a/unit_tests/test_nii.py b/unit_tests/test_nii.py
index 0677604..f8d1d3f 100755
--- a/unit_tests/test_nii.py
+++ b/unit_tests/test_nii.py
@@ -4,6 +4,7 @@
 # coverage html
 from __future__ import annotations

+import operator
 import random
 import sys
 import unittest
@@ -11,6 +12,7 @@
 import nibabel as nib
 import numpy as np
+import pytest

 from TPTBox import NII, v_idx2name
 from TPTBox.core import np_utils
@@ -32,6 +34,126 @@ def get_all_corner_points(affine, shape) -> np.ndarray:
     return a

+class TestNII_MathOperators(unittest.TestCase):
+    """Tests that NII_Math operators match explicit NumPy array operations.
+
+    Operator domain assumptions:
+    - Arithmetic & comparison ops operate on float arrays
+    - Bitwise ops (and, or, xor, invert) require integer arrays
+    """
+
+    @staticmethod
+    def make_nii(shape=(8, 9, 10), seed=0, dtype=float):
+        rng = np.random.default_rng(seed)
+        arr = rng.normal(size=shape) if dtype is float else rng.integers(0, 8, size=shape, dtype=dtype)
+        return NII((arr, np.eye(4), nib.nifti1.Nifti1Header()))
+
+    def test_binary_operator_equivalence_float(self):
+        """Binary operators valid for float arrays."""
+        binary_ops = [
+            operator.add,
+            operator.sub,
+            operator.mul,
+            operator.truediv,
+            operator.floordiv,
+            operator.mod,
+            operator.pow,
+            operator.lt,
+            operator.le,
+            operator.eq,
+            operator.ne,
+            operator.gt,
+            operator.ge,
+        ]
+
+        for op in binary_ops:
+            with self.subTest(op=op):
+                nii1 = self.make_nii(seed=1, dtype=float)
+                nii2 = self.make_nii(seed=2, dtype=float)
+
+                out_op = op(nii1, nii2)
+                expected = op(nii1.get_array(), nii2.get_array())
+
+                out_manual = nii1.set_array(expected, inplace=False)
+
+                self.assertTrue(np.allclose(out_op.get_array(), out_manual.get_array(), equal_nan=True))
+
+    def test_binary_operator_equivalence_bitwise(self):
+        """Bitwise binary operators require integer arrays."""
+        bitwise_ops = [
+            operator.and_,
+            operator.or_,
+            operator.xor,
+        ]
+
+        for op in bitwise_ops:
+            with self.subTest(op=op):
+                nii1 = self.make_nii(seed=3, dtype=np.int32)
+                nii2 = self.make_nii(seed=4, dtype=np.int32)
+
+                out_op = op(nii1, nii2)
+                expected = op(nii1.get_array(), nii2.get_array())
+
+                out_manual = nii1.set_array(expected, inplace=False)
+
+                np.testing.assert_array_equal(out_op.get_array(), out_manual.get_array())
+
+    def test_unary_operator_equivalence_float(self):
+        """Unary operators valid for float arrays."""
+        unary_ops = [operator.neg, operator.pos, operator.abs, np.floor, np.ceil]
+
+        for op in unary_ops:
+            with self.subTest(op=op):
+                nii = self.make_nii(seed=5, dtype=float)
+
+                out_op = op(nii)
+                expected = op(nii.get_array())
+
+                out_manual = nii.set_array(expected, inplace=False)
+
+                self.assertTrue(np.allclose(out_op.get_array(), out_manual.get_array()))
+
+    def test_unary_operator_invert_integer(self):
+        """Bitwise invert requires integer arrays."""
+        nii = self.make_nii(seed=6, dtype=np.int32)
+
+        out_op = ~nii
+        expected = ~nii.get_array()
+
+        out_manual = nii.set_array(expected, inplace=False)
+
+        np.testing.assert_array_equal(out_op.get_array(), out_manual.get_array())
+
+    def test_inplace_binary_operator(self):
+        nii1 = self.make_nii(seed=7, dtype=float)
+        nii2 = self.make_nii(seed=8, dtype=float)
+
+        arr_before = nii1.get_array().copy()
+
+        nii1 += nii2
+        expected = arr_before + nii2.get_array()
+
+        self.assertTrue(np.allclose(nii1.get_array(), expected))
+
+    def test_inplace_unary_operator(self):
+        nii = self.make_nii(seed=9, dtype=float)
+        arr_before = nii.get_array().copy()
+
+        nii *= 2
+        self.assertTrue(np.allclose(nii.get_array(), arr_before * 2))
+
+    def test_round_equivalence(self):
+        nii = self.make_nii(seed=4)
+        out_op = round(nii, 2)
+        expected = np.round(nii.get_array(), 2)
+        out_manual = nii.set_array(expected, inplace=False)
+        self.assertTrue(np.allclose(out_op.get_array(), out_manual.get_array()))
+
+
 class Test_bids_file(unittest.TestCase):
     def test_rescale_corners(self):
         for _ in range(repeats // 4):
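For context, a sketch of the behaviour these tests pin down (fixture style mirrors make_nii above; NII wraps an (array, affine, header) triple, and operators return new NII objects with the affine preserved — the scalar comparison is an assumption, not covered by the tests):

    import nibabel as nib
    import numpy as np
    from TPTBox import NII

    nii = NII((np.ones((4, 4, 4)), np.eye(4), nib.nifti1.Nifti1Header()))
    doubled = nii * 2     # arithmetic returns a new NII
    mask = doubled > 1.5  # comparisons also return an NII
    doubled += 1          # in-place variants mutate the left operand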
diff --git a/unit_tests/test_nrrd.py b/unit_tests/test_nrrd.py
new file mode 100644
index 0000000..6c2c4a1
--- /dev/null
+++ b/unit_tests/test_nrrd.py
@@ -0,0 +1,144 @@
+import os
+import unittest
+from pathlib import Path
+
+import numpy as np
+import requests
+
+# Functions under test
+from TPTBox.core.internal.slicer_nrrd import load_slicer_nrrd, save_slicer_nrrd
+from TPTBox.core.nii_wrapper import NII
+from TPTBox.tests.test_utils import get_nii_paths_ct
+
+try:
+    import ants
+
+    has_ants = True
+except Exception:
+    has_ants = False
+
+
+class TestAnts(unittest.TestCase):
+    @unittest.skipIf(not has_ants, "requires antspyx to be installed")
+    def test_segmentation_CT(self):
+        """Test NIfTI -> ANTs -> NIfTI round-trip on the sample CT."""
+        ct, subreg, vert = get_nii_paths_ct()
+        from TPTBox import NII, to_nii
+        from TPTBox.core.internal import ants_to_nifti, nifti_to_ants
+
+        nii = to_nii(ct)
+        nii2 = ants_to_nifti(nifti_to_ants(nii.nii), nii.header)
+        nii2 = NII(nii2)
+        assert nii.orientation == nii2.orientation
+        assert np.isclose(nii.affine, nii2.affine).all()
+        assert np.isclose(nii.get_array(), nii2.get_array()).all()
+
+    # @unittest.skipIf(not has_ants, "requires antspyx to be installed")
+    # def test_raf_ants():
+    #     ct, subreg, vert = get_nii_paths_ct()
+    #     from TPTBox.core.internal import get_ras_affine_from_ants
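A condensed sketch of the converter path this test exercises (requires antspyx; the image path is illustrative):

    from TPTBox import NII, to_nii
    from TPTBox.core.internal import ants_to_nifti, nifti_to_ants

    nii = to_nii("image.nii.gz")
    ants_img = nifti_to_ants(nii.nii)                # nibabel image -> ANTs image
    back = NII(ants_to_nifti(ants_img, nii.header))  # ANTs image -> NII, header preserved
    assert back.orientation == nii.orientation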
+
+
+class TestSlicerSegmentationIO(unittest.TestCase):
+    slicerio_data = Path(__file__).parent / "slicerio_data"
+    base_url = "https://raw.githubusercontent.com/lassoan/slicerio/main/slicerio/data"
+
+    files = {  # noqa: RUF012
+        "CT": "CTChest4.nrrd",
+        "Seg": "Segmentation.seg.nrrd",
+        "SegOverlap": "SegmentationOverlapping.seg.nrrd",
+    }
+
+    @classmethod
+    def setUpClass(cls):
+        """Ensure the data directory exists and download all test files."""
+        os.makedirs(cls.slicerio_data, exist_ok=True)
+        for filename in cls.files.values():
+            url = f"{cls.base_url}/{filename}"
+            out_local = cls.slicerio_data / filename
+            if not out_local.exists():
+                cls.download_file(url, out_local)
+
+    @staticmethod
+    def download_file(url: str, out_path: Path):
+        """Download a file from a URL."""
+        resp = requests.get(url, stream=True)
+        resp.raise_for_status()
+        with open(out_path, "wb") as f:
+            f.writelines(resp.iter_content(chunk_size=8192))
+        print(f"Downloaded {url} → {out_path}")
+
+    def roundtrip_test(self, filename: str):
+        """Load, save, reload, and compare arrays via the low-level slicer_nrrd functions."""
+        path = self.slicerio_data / filename
+        seg = "seg." in filename
+        nii = load_slicer_nrrd(path, seg)
+        arr = nii.get_array()
+
+        # Save to roundtrip file
+        out_seg = path.with_name(path.stem + ".roundtrip.seg.nrrd")
+        save_slicer_nrrd(nii, out_seg)
+
+        # Reload saved file
+        nii2 = load_slicer_nrrd(out_seg, seg)
+        arr2 = nii2.get_array()
+
+        # Compare arrays
+        self.assertTrue(np.array_equal(arr, arr2), f"Round-trip arrays differ for {filename}")
+
+        # Optional: remove roundtrip file
+        out_seg.unlink(missing_ok=True)
+
+    def roundtrip_test2(self, filename: str):
+        """Load, save, reload, and compare arrays via the high-level NII interface."""
+        path = self.slicerio_data / filename
+        seg = "seg." in filename
+        nii = NII.load(path, seg)
+        arr = nii.get_array()
+
+        # Save to roundtrip file
+        out_seg = path.with_name(path.stem + ".roundtrip.seg.nrrd")
+        nii.save_nrrd(out_seg)
+
+        # Reload saved file
+        nii2 = NII.load(out_seg, seg)
+        arr2 = nii2.get_array()
+
+        # Compare arrays
+        self.assertTrue(np.array_equal(arr, arr2), f"Round-trip arrays differ for {filename}")
+
+        # Optional: remove roundtrip file
+        out_seg.unlink(missing_ok=True)
+
+    def test_segmentation(self):
+        """Test round-trip for Segmentation.seg.nrrd."""
+        self.roundtrip_test(self.files["Seg"])
+
+    def test_segmentation_CT(self):
+        """Test round-trip for CTChest4.nrrd."""
+        self.roundtrip_test(self.files["CT"])
+
+    def test_segmentation_overlapping(self):
+        """Test round-trip for SegmentationOverlapping.seg.nrrd."""
+        self.roundtrip_test(self.files["SegOverlap"])
+
+    def test_ct_file_exists(self):
+        """Just check that the CT file exists."""
+        path = self.slicerio_data / self.files["CT"]
+        self.assertTrue(path.exists(), f"{path} should exist")
+
+    def test_segmentation2(self):
+        """Test round-trip for Segmentation.seg.nrrd via the NII interface."""
+        self.roundtrip_test2(self.files["Seg"])
+
+    def test_segmentation_CT2(self):
+        """Test round-trip for CTChest4.nrrd via the NII interface."""
+        self.roundtrip_test2(self.files["CT"])
+
+    def test_segmentation_overlapping2(self):
+        """Test round-trip for SegmentationOverlapping.seg.nrrd via the NII interface."""
+        self.roundtrip_test2(self.files["SegOverlap"])
+
+
+if __name__ == "__main__":
+    unittest.main()
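The two helpers exercise both IO layers added in this diff; in user code the NII front end is usually enough (sketch, file names illustrative):

    from TPTBox.core.nii_wrapper import NII

    seg = NII.load("Segmentation.seg.nrrd", True)  # second argument: treat the file as a segmentation
    seg.save_nrrd("copy.seg.nrrd")                 # writes a Slicer-compatible .seg.nrrd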
diff --git a/unit_tests/test_poi_autogen.py b/unit_tests/test_poi_autogen.py
index beebb91..6fc4435 100755
--- a/unit_tests/test_poi_autogen.py
+++ b/unit_tests/test_poi_autogen.py
@@ -2,6 +2,7 @@
 from __future__ import annotations

 import os
+import random
 import tempfile
 import unittest
 from pathlib import Path
@@ -10,6 +11,7 @@

 from TPTBox.core.poi import POI
 from TPTBox.core.poi_fun.poi_global import POI_Global
+from TPTBox.tests.test_utils import get_random_ax_code


 class TestPOI(unittest.TestCase):
@@ -77,22 +79,143 @@ def test_local_to_global_valid_coordinates(self):
         # Check if the returned global coordinates match the expected coordinates
         assert global_coords == expected_coords

-    # Test that the 'affine' property returns the expected value when all required attributes are set.
     def test_affine_property(self):
-        # Create a POI object with all required attributes set
+        # Random but valid parameters
+        zoom = tuple(np.random.uniform(0.1, 5.0, size=3))
+        rotation = np.linalg.qr(np.random.randn(3, 3))[0]  # orthonormal rotation
+        origin = tuple(np.random.uniform(-100.0, 100.0, size=3))
+        shape = tuple(np.random.randint(1, 512, size=3))
+
+        # Create a POI object
         poi = POI()
-        poi.zoom = (1.0, 1.0, 1.0)
-        poi.rotation = np.eye(3)
-        poi.origin = (0.0, 0.0, 0.0)
-        poi.shape = (10, 10, 10)
+        poi.zoom = zoom
+        poi.rotation = rotation
+        poi.origin = origin
+        poi.shape = shape

-        # Calculate the expected affine matrix
+        # Manually construct expected affine
         expected_affine = np.eye(4)
-        expected_affine[:3, :3] = np.eye(3)
-        expected_affine[:3, 3] = (0.0, 0.0, 0.0)
+        expected_affine[:3, :3] = rotation @ np.diag(zoom)
+        expected_affine[:3, 3] = origin

         # Check that the 'affine' property returns the expected value
-        assert np.array_equal(poi.affine, expected_affine)
+        assert np.allclose(poi.affine, expected_affine)
+
+    def test_affine_property_2(self):
+        for _ in range(10):
+            # Random but valid parameters
+            zoom = tuple(np.random.uniform(0.1, 5.0, size=3))
+            rotation = np.linalg.qr(np.random.randn(3, 3))[0]  # random orthonormal rotation
+            origin = tuple(np.random.uniform(-100.0, 100.0, size=3))
+            shape = tuple(np.random.randint(1, 512, size=3))
+
+            # Create first POI
+            poi = POI({(1, 1): (0, 0, 0)})
+            poi.zoom = zoom
+            poi.rotation = rotation
+            poi.origin = origin
+            poi.shape = shape
+            poi = poi.reorient(get_random_ax_code())
+
+            # Create second POI from affine
+            poi2 = POI()
+            poi2.affine = poi.affine
+
+            # Zoom must be preserved
+            assert np.allclose(poi.zoom, poi2.zoom, rtol=1e-6, atol=1e-8)
+
+            # Rotation must be preserved
+            assert np.allclose(poi.rotation, poi2.rotation, rtol=1e-5, atol=1e-6)
+
+            # Origin (translation) must be preserved
+            assert np.allclose(poi.origin, poi2.origin, rtol=1e-6, atol=1e-8)
+
+            # Full affine round-trip consistency
+            assert np.allclose(poi.affine, poi2.affine, rtol=1e-6, atol=1e-8)
+
+    def test_change_affine_translation_only(self):
+        poi = POI()
+        poi.affine = np.eye(4)
+
+        translation = np.array([10.0, -5.0, 2.5])
+
+        out = poi.change_affine(translation=translation)
+
+        expected = np.eye(4)
+        expected[:3, 3] = translation
+
+        assert np.allclose(out.affine, expected, rtol=1e-6, atol=1e-8)
+        # not inplace
+        assert not np.allclose(poi.affine, expected, rtol=1e-6, atol=1e-8)
+
+    def test_change_affine_scaling_only(self):
+        poi = POI()
+        poi.affine = np.eye(4)
+
+        scaling = np.array([2.0, 0.5, 3.0])
+
+        out = poi.change_affine(scaling=scaling)
+
+        expected = np.eye(4)
+        expected[:3, :3] = np.diag(scaling)
+
+        assert np.allclose(out.affine, expected, rtol=1e-6, atol=1e-8)
+
+    def test_change_affine_rotation_only(self):
+        from scipy.spatial.transform import Rotation
+
+        poi = POI()
+        poi.affine = np.eye(4)
+
+        rotation = np.array([90.0, 0.0, 0.0])
+
+        out = poi.change_affine(rotation_degrees=rotation)
+
+        R = Rotation.from_euler("xyz", rotation, degrees=True).as_matrix()
+
+        expected = np.eye(4)
+        expected[:3, :3] = R
+
+        assert np.allclose(out.affine, expected, rtol=1e-6, atol=1e-8)
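The combined-order test below pins down the composition convention assumed for change_affine: scaling first, then rotation, then translation, i.e. affine = T @ R @ S. A small numeric check of that ordering:

    import numpy as np

    S = np.diag([2.0, 2.0, 2.0, 1.0])   # scale
    T = np.eye(4)
    T[:3, 3] = [1.0, 2.0, 3.0]          # translate
    p = np.array([1.0, 0.0, 0.0, 1.0])  # homogeneous point
    # T @ S scales first, then translates: (1, 0, 0) -> (2, 0, 0) -> (3, 2, 3)
    assert np.allclose((T @ S) @ p, [3.0, 2.0, 3.0, 1.0])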
+    def test_change_affine_combined_order(self):
+        from scipy.spatial.transform import Rotation
+
+        poi = POI()
+        poi.affine = np.eye(4)
+
+        scaling = np.array([2.0, 2.0, 2.0])
+        rotation = np.array([0.0, 0.0, 90.0])
+        translation = np.array([1.0, 2.0, 3.0])
+
+        out = poi.change_affine(scaling=scaling, rotation_degrees=rotation, translation=translation)
+
+        S = np.eye(4)
+        S[:3, :3] = np.diag(scaling)
+
+        R = np.eye(4)
+        R[:3, :3] = Rotation.from_euler("xyz", rotation, degrees=True).as_matrix()
+
+        T = np.eye(4)
+        T[:3, 3] = translation
+
+        expected = T @ R @ S
+
+        assert np.allclose(out.affine, expected, rtol=1e-6, atol=1e-8)
+
+    def test_change_affine_inplace(self):
+        poi = POI()
+        poi.affine = np.eye(4)
+
+        translation = np.array([1.0, 2.0, 3.0])
+
+        out = poi.change_affine_(translation=translation)
+
+        expected = np.eye(4)
+        expected[:3, 3] = translation
+
+        assert out is poi
+        assert np.allclose(poi.affine, expected)

     # Test that the 'POI' instance can be saved to a file.
     def test_save_poi(self):