diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 13c37c8..3d88eb8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -17,7 +17,7 @@ jobs: name: Test on ${{ matrix.os }} with Python ${{ matrix.python-version }} runs-on: ${{ matrix.os }} strategy: - fail-fast: false + fail-fast: true matrix: python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] os: [ubuntu-latest, windows-latest, macos-latest] diff --git a/pyproject.toml b/pyproject.toml index fe880fa..52ae7bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,27 +13,39 @@ license = "MIT" license-files = ["LICENSE"] dependencies = [ - "click", "tqdm", "numpy", + "rich-argparse", "nibabel", "pandas"] [project.optional-dependencies] -extra = [] +show = [ + "textual-image", # Used to show images in terminal + "pillow", # Used to handle image data + "matplotlib" # Used for colormaps +] +napari = [ + "napari[all]", # Used for interactive visualization +] test = [ "pytest", "pytest-cov", + ] pypi = [ "build" # Used for building wheels and uploading to pypi ] -all = ["mri-toolkit[extra,test,pypi]"] +all = ["mri-toolkit[test,pypi,show,napari]"] [project.urls] Homepage = "https://github.com/scientificcomputing/mri-toolkit.git" + +[project.scripts] +mritk = "mritk.cli:main" + [tool.mypy] ignore_missing_imports = true # Does not show errors when importing untyped libraries files = [ # Folder to which files that should be checked by mypy diff --git a/src/mritk/__init__.py b/src/mritk/__init__.py index cc1e6bc..24ab37f 100644 --- a/src/mritk/__init__.py +++ b/src/mritk/__init__.py @@ -12,7 +12,7 @@ meta = metadata("mri-toolkit") __version__ = meta["Version"] __author__ = meta["Author-email"] -__license__ = meta["License"] +__license__ = meta["license-expression"] __email__ = meta["Author-email"] __program_name__ = meta["Name"] diff --git a/src/mritk/cli.py b/src/mritk/cli.py index d74f0a2..1e126f4 100644 --- a/src/mritk/cli.py +++ b/src/mritk/cli.py @@ -1,25 +1,91 
@@ import logging +from importlib.metadata import metadata from pathlib import Path import argparse from typing import Sequence, Optional -from . import download_data +from rich_argparse import RichHelpFormatter + +from . import download_data, info, statistics, show, napari + + +def version_info(): + from rich.console import Console + from rich.table import Table + from rich import box + import sys + import nibabel as nib + import numpy as np + + console = Console() + + meta = metadata("mri-toolkit") + toolkit_version = meta["Version"] + python_version = sys.version.split()[0] + + table = Table( + title="MRI Toolkit Environment", + box=box.ROUNDED, # Nice rounded corners + show_lines=True, # Separator lines between rows + header_style="bold magenta", + ) + + table.add_column("Package", style="cyan", no_wrap=True) + table.add_column("Version", style="green", justify="right") + + table.add_row("mri-toolkit", toolkit_version) + table.add_row("Python", python_version) + table.add_row("Nibabel", nib.__version__) + table.add_row("Numpy", np.__version__) + + console.print(table) def setup_parser(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser(formatter_class=RichHelpFormatter) + parser.add_argument("--version", action="store_true") subparsers = parser.add_subparsers(dest="command") # Download test data parser - download_parser = subparsers.add_parser("download-test-data", help="Download test data") + download_parser = subparsers.add_parser( + "download-test-data", help="Download test data", formatter_class=parser.formatter_class + ) download_parser.add_argument("outdir", type=Path, help="Output directory to download test data") + info_parser = subparsers.add_parser( + "info", help="Display information about a file", formatter_class=parser.formatter_class + ) + info_parser.add_argument("file", type=Path, help="File to display information about") + + info_parser.add_argument( + "--json", 
action="store_true", help="Output information in JSON format" + ) + + stats_parser = subparsers.add_parser( + "stats", help="Compute MRI statistics", formatter_class=parser.formatter_class + ) + statistics.cli.add_arguments(stats_parser) + + show_parser = subparsers.add_parser( + "show", help="Show MRI data in a terminal", formatter_class=parser.formatter_class + ) + show.add_arguments(show_parser) + + napari_parser = subparsers.add_parser( + "napari", help="Show MRI data using napari", formatter_class=parser.formatter_class + ) + napari.add_arguments(napari_parser) + return parser def dispatch(parser: argparse.ArgumentParser, argv: Optional[Sequence[str]] = None) -> int: args = vars(parser.parse_args(argv)) + + if args.pop("version"): + version_info() + return 0 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") command = args.pop("command") logger = logging.getLogger(__name__) @@ -27,6 +93,15 @@ def dispatch(parser: argparse.ArgumentParser, argv: Optional[Sequence[str]] = No if command == "download-test-data": outdir = args.pop("outdir") download_data.download_test_data(outdir) + elif command == "info": + file = args.pop("file") + info.nifty_info(file, json_output=args.pop("json")) + elif command == "stats": + statistics.cli.dispatch(args) + elif command == "show": + show.dispatch(args) + elif command == "napari": + napari.dispatch(args) else: logger.error(f"Unknown command {command}") parser.print_help() diff --git a/src/mritk/data/io.py b/src/mritk/data/io.py index c678a11..95d6c6f 100644 --- a/src/mritk/data/io.py +++ b/src/mritk/data/io.py @@ -18,7 +18,7 @@ def load_mri_data( path: Path | str, - dtype: type, + dtype: type = np.float64, orient: bool = True, ) -> MRIData: suffix_regex = re.compile(r".+(?P(\.nii(\.gz|)|\.mg(z|h)))") diff --git a/src/mritk/info.py b/src/mritk/info.py new file mode 100644 index 0000000..8ee9ba5 --- /dev/null +++ b/src/mritk/info.py @@ -0,0 +1,91 @@ +import json +import typing +from pathlib 
import Path +import numpy as np +import nibabel as nib +from rich.console import Console +from rich.table import Table +from rich.panel import Panel +from rich import box + + +def custom_json(obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + elif np.isscalar(obj): + return float(obj) + else: + return str(obj) + + +def nifty_info(filename: Path, json_output: bool = False) -> dict[str, typing.Any]: + console = Console() + + # 1. Load the NIfTI file + + img = nib.load(str(filename)) + header = img.header + affine = img.affine + + # --- Part A: Extracting Dimensions & Resolution --- + img_shape = img.shape + zooms = header.get_zooms() + data_type = header.get_data_dtype() + + data = { + "filename": str(filename), + "shape": img_shape, + "voxel_size_mm": zooms, + "data_type": data_type, + "affine": affine, + } + + if json_output: + print(json.dumps(data, default=custom_json, indent=4)) + return data + + # Create a nice header panel + console.print( + Panel( + f"[bold blue]NIfTI File Analysis[/bold blue]\n[green]{filename}[/green]", expand=False + ) + ) + + # Create a table for Basic Info + info_table = Table( + title="Basic Information", + box=box.SIMPLE_HEAVY, + show_header=True, + header_style="bold magenta", + ) + info_table.add_column("Property", style="cyan") + info_table.add_column("Value", style="white") + + # Format the tuples/lists as strings for the table + shape_str = ", ".join(map(str, img_shape)) + zoom_str = ", ".join([f"{z:.2f}" for z in zooms]) + + info_table.add_row("Shape (x, y, z)", f"({shape_str})") + info_table.add_row("Voxel Size (mm)", f"({zoom_str})") + info_table.add_row("Data Type", str(data_type)) + + console.print(info_table) + + # --- Part B: The Affine Matrix --- + + console.print("\n[bold]Affine Transformation Matrix[/bold] (Voxel → World)", style="yellow") + + # Create a specific table for the matrix to align numbers nicely + matrix_table = Table(show_header=False, box=box.ROUNDED, border_style="dim") + + # Add 4 columns for 
the 4x4 matrix + for _ in range(4): + matrix_table.add_column(justify="right", style="green") + + for row in affine: + # Format numbers to 4 decimal places for cleanliness + matrix_table.add_row(*[f"{val: .4f}" for val in row]) + + console.print(matrix_table) + + return data diff --git a/src/mritk/show.py b/src/mritk/show.py new file mode 100644 index 0000000..89e51fe --- /dev/null +++ b/src/mritk/show.py @@ -0,0 +1,134 @@ +import argparse +from pathlib import Path + +import numpy as np +from rich.console import Console +from rich.panel import Panel +from rich.columns import Columns + +# Assuming relative imports based on your previous file structure +from .data.io import load_mri_data + + +def add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("file", type=Path, help="File to show") + parser.add_argument( + "--cmap", + type=str, + default="gray", + help="Colormap to use for displaying the image (default: gray)", + ) + parser.add_argument( + "--slice-x", + type=float, + default=0.5, + help="Relative position (0-1) of the sagittal slice to display (default: 0.5)", + ) + parser.add_argument( + "--slice-y", + type=float, + default=0.5, + help="Relative position (0-1) of the coronal slice to display (default: 0.5)", + ) + parser.add_argument( + "--slice-z", + type=float, + default=0.5, + help="Relative position (0-1) of the axial slice to display (default: 0.5)", + ) + + +def normalize_to_uint8(data: np.ndarray) -> np.ndarray: + """Normalize array values to 0-255 uint8 range for image display.""" + # Handle NaNs and Infs + data = np.nan_to_num(data) + + d_min, d_max = data.min(), data.max() + if d_max > d_min: + # Linear scaling to 0-255 + normalized = (data - d_min) / (d_max - d_min) * 255 + else: + normalized = np.zeros_like(data) + + return normalized.astype(np.uint8) + + +def dispatch(args): + """ + Displays three orthogonal slices (Sagittal, Coronal, Axial) of an MRI file + in the terminal. 
+ """ + try: + from textual_image.renderable import Image as TermImage + import PIL.Image + except ImportError: + console = Console() + console.print( + "[bold red]Error:[/bold red] The 'textual_image' and 'pillow' " + "packages are required to use the 'show' command. " + "Please install with: 'pip install mri-toolkit[show]'" + ) + return + + # 1. Load Data + # Assuming args is a dict or Namespace. Adapting to your snippet's usage: + + file_path = args.pop("file") + cmap_name = args.pop("cmap", "gray") + slize_x = args.pop("slice_x", 0.5) + slize_y = args.pop("slice_y", 0.5) + slize_z = args.pop("slice_z", 0.5) + + console = Console() + console.print(f"[bold green]Loading MRI data from:[/bold green] {file_path}") + + mri_resource = load_mri_data(file_path) + data = mri_resource.data + + # 2. Define Slice Indices (Middle of the brain) + x_idx = int(data.shape[0] * slize_x) + y_idx = int(data.shape[1] * slize_y) + z_idx = int(data.shape[2] * slize_z) + + # 3. Extract Slices + # orientation in load_mri_data is typically RAS (Right, Anterior, Superior) + # Numpy origin is top-left. We often need to rotate/flip for correct medical view. + + # Sagittal View (Side): Fix X. Axes are Y (Ant) and Z (Sup). + # We rotate 90 deg so Superior is "Up". + slice_sagittal = np.rot90(data[x_idx, :, :]) + + # Coronal View (Front): Fix Y. Axes are X (Right) and Z (Sup). + slice_coronal = np.rot90(data[:, y_idx, :]) + + # Axial View (Top-down): Fix Z. Axes are X (Right) and Y (Ant). + slice_axial = np.rot90(data[:, :, z_idx]) + + # 4. 
Prepare Images + slices = [("Sagittal", slice_sagittal), ("Coronal", slice_coronal), ("Axial", slice_axial)] + + panels = [] + + try: + from matplotlib import colormaps + except ImportError: + cmap = lambda x: x / 255 # Identity if matplotlib not available + else: + cmap = colormaps[cmap_name] + + for title, slice_data in slices: + # Normalize data to 0-255 + img_uint8 = normalize_to_uint8(slice_data) + + # Create PIL Image + # pil_image = PIL.Image.fromarray(img_uint8) + pil_image = PIL.Image.fromarray((cmap(img_uint8) * 255).astype(np.uint8)) + + # Create Terminal Image + term_img = TermImage(pil_image) + + # Add to list as a Panel + panels.append(Panel(term_img, title=title, expand=False)) + + # 5. Display + console.print(Columns(panels, equal=True)) diff --git a/src/mritk/statistics/__init__.py b/src/mritk/statistics/__init__.py index 1ba5c7d..496fa01 100644 --- a/src/mritk/statistics/__init__.py +++ b/src/mritk/statistics/__init__.py @@ -4,9 +4,6 @@ Copyright (C) 2026 Simula Research Laboratory """ -from . import utils, compute_stats +from . 
import utils, compute_stats, cli -__all__ = [ - "utils", - "compute_stats", -] +__all__ = ["utils", "compute_stats", "cli"] diff --git a/src/mritk/statistics/cli.py b/src/mritk/statistics/cli.py new file mode 100644 index 0000000..b49957b --- /dev/null +++ b/src/mritk/statistics/cli.py @@ -0,0 +1,210 @@ +import argparse +import typing +from pathlib import Path +import pandas as pd + +from ..segmentation.groups import default_segmentation_groups +from .compute_stats import generate_stats_dataframe + + +def compute_mri_stats( + segmentation: Path, + mri: list[Path], + output: Path, + timetable: Path | None = None, + timelabel: str | None = None, + seg_regex: str | None = None, + mri_regex: str | None = None, + lut: Path | None = None, + info: str | None = None, + **kwargs, +): + import sys + import json + from rich.console import Console + from rich.panel import Panel + + # Setup Rich + console = Console() + + # Parse info dict from JSON string if provided + info_dict = None + if info: + try: + info_dict = json.loads(info) + except json.JSONDecodeError: + console.print("[bold red]Error:[/bold red] --info must be a valid JSON string.") + sys.exit(1) + + if not segmentation.exists(): + console.print(f"[bold red]Error:[/bold red] Missing segmentation file: {segmentation}") + sys.exit(1) + + # Validate all MRI paths before starting + for path in mri: + if not path.exists(): + console.print(f"[bold red]Error:[/bold red] Missing MRI file: {path}") + sys.exit(1) + + dataframes = [] + + # Loop through MRI paths + console.print("[bold green]Processing MRIs...[/bold green]") + for i, path in enumerate(mri): + # console.print(f"[blue]Processing MRI {i + 1}/{len(mri)}:[/blue] {path.name}") + + try: + # Call the logic function + df = generate_stats_dataframe( + seg_path=segmentation, + mri_path=path, + timestamp_path=timetable, + timestamp_sequence=timelabel, + seg_pattern=seg_regex, + mri_data_pattern=mri_regex, + lut_path=lut, + info_dict=info_dict, + ) + dataframes.append(df) 
+ except Exception as e: + console.print(f"[bold red]Failed to process {path.name}:[/bold red] {e}") + sys.exit(1) + + if dataframes: + final_df = pd.concat(dataframes) + final_df.to_csv(output, sep=";", index=False) + console.print( + Panel( + f"Stats successfully saved to:\n[bold green]{output}[/bold green]", + title="Success", + expand=False, + ) + ) + else: + console.print("[yellow]No dataframes generated.[/yellow]") + + +def get_stats_value(stats_file: Path, region: str, info: str, **kwargs): + """ + Replaces the @click.command('get') decorated function. + """ + import sys + from rich.console import Console + + # Setup Rich + console = Console() + + # Validate inputs + valid_regions = default_segmentation_groups().keys() + if region not in valid_regions: + console.print( + f"[bold red]Error:[/bold red] Region '{region}' " + "not found in default segmentation groups." + ) + sys.exit(1) + + valid_infos = [ + "sum", + "mean", + "median", + "std", + "min", + "max", + "PC1", + "PC5", + "PC25", + "PC75", + "PC90", + "PC95", + "PC99", + ] + if info not in valid_infos: + console.print( + f"[bold red]Error:[/bold red] Info '{info}' " + f"is invalid. 
Choose from: {', '.join(valid_infos)}" + ) + sys.exit(1) + + if not stats_file.exists(): + console.print(f"[bold red]Error:[/bold red] Stats file not found: {stats_file}") + sys.exit(1) + + # Process + try: + df = pd.read_csv(stats_file, sep=";") + region_row = df.loc[df["description"] == region] + + if region_row.empty: + console.print(f"[red]Region '{region}' not found in the stats file.[/red]") + sys.exit(1) + + info_value = region_row[info].values[0] + + # Output + console.print( + f"[bold cyan]{info}[/bold cyan] for [bold green]{region}[/bold green] " + f"= [bold white]{info_value}[/bold white]" + ) + return info_value + + except Exception as e: + console.print(f"[bold red]Error reading stats file:[/bold red] {e}") + sys.exit(1) + + +def add_arguments(parser: argparse.ArgumentParser): + subparsers = parser.add_subparsers(dest="stats-command", help="Available commands") + + # --- Compute Command --- + parser_compute = subparsers.add_parser( + "compute", help="Compute MRI statistics", formatter_class=parser.formatter_class + ) + parser_compute.add_argument( + "--segmentation", "-s", type=Path, required=True, help="Path to segmentation file" + ) + parser_compute.add_argument( + "--mri", "-m", type=Path, nargs="+", required=True, help="Path to MRI data file(s)" + ) + parser_compute.add_argument( + "--output", "-o", type=Path, required=True, help="Output CSV file path" + ) + parser_compute.add_argument("--timetable", "-t", type=Path, help="Path to timetable file") + parser_compute.add_argument( + "--timelabel", "-l", dest="timelabel", type=str, help="Time label sequence" + ) + parser_compute.add_argument( + "--seg_regex", + "-sr", + dest="seg_regex", + type=str, + help="Regex pattern for segmentation filename", + ) + parser_compute.add_argument( + "--mri_regex", "-mr", dest="mri_regex", type=str, help="Regex pattern for MRI filename" + ) + parser_compute.add_argument("--lut", "-lt", dest="lut", type=Path, help="Path to Lookup Table") + 
parser_compute.add_argument("--info", "-i", type=str, help="Info dictionary as JSON string") + parser_compute.set_defaults(func=compute_mri_stats) + + # --- Get Command --- + parser_get = subparsers.add_parser( + "get", help="Get specific stats value", formatter_class=parser.formatter_class + ) + parser_get.add_argument( + "--stats_file", "-f", type=Path, required=True, help="Path to stats CSV file" + ) + parser_get.add_argument("--region", "-r", type=str, required=True, help="Region description") + parser_get.add_argument( + "--info", "-i", type=str, required=True, help="Statistic to retrieve (mean, std, etc.)" + ) + parser_get.set_defaults(func=get_stats_value) + + +def dispatch(args: dict[str, typing.Any]): + command = args.pop("stats-command") + if command == "compute": + compute_mri_stats(**args) + elif command == "get": + get_stats_value(**args) + else: + raise ValueError(f"Unknown command: {command}") diff --git a/src/mritk/statistics/compute_stats.py b/src/mritk/statistics/compute_stats.py index b38388f..7234d46 100644 --- a/src/mritk/statistics/compute_stats.py +++ b/src/mritk/statistics/compute_stats.py @@ -10,8 +10,7 @@ import re import numpy as np import pandas as pd -import tqdm -import click +import tqdm.rich from ..data.io import load_mri_data from ..data.orientation import assert_same_space @@ -20,89 +19,6 @@ from .utils import voxel_count_to_ml_scale, find_timestamp, prepend_info -@click.group() -def mristats(): - pass - - -@mristats.command("compute") -@click.option("--segmentation", "-s", "seg_path", type=Path, required=True) -@click.option("--mri", "-m", "mri_paths", multiple=True, type=Path, required=True) -@click.option("--output", "-o", type=Path, required=True) -@click.option("--timetable", "-t", type=Path) -@click.option("--timelabel", "-l", "timetable_sequence", type=str) -@click.option("--seg_regex", "-sr", "seg_pattern", type=str) -@click.option("--mri_regex", "-mr", "mri_data_pattern", type=str) -@click.option("--lut", "-lt", 
"lut_path", type=Path) -@click.option("--info", "-i", "info_dict", type=dict) -## FIXME : Need to check that all the given mri in mri_paths -## are registered to the same baseline MRI - this is done in create_dataframe -def compute_mri_stats( - seg_path: str | Path, - mri_paths: tuple[str | Path], - output: str | Path, - timetable: Optional[str | Path], - timetable_sequence: Optional[str | Path], - seg_pattern: Optional[str | Path], - mri_data_pattern: Optional[str | Path], - lut_path: Optional[Path] = None, - info_dict: Optional[dict] = None, -): - if not Path(seg_path).exists(): - raise RuntimeError(f"Missing segmentation: {seg_path}") - - for path in mri_paths: - if not Path(path).exists(): - raise RuntimeError(f"Missing: {path}") - - dataframes = [ - generate_stats_dataframe( - Path(seg_path), - Path(path), - timetable, - timetable_sequence, - seg_pattern, - mri_data_pattern, - lut_path, - info_dict, - ) - for path in mri_paths - ] - pd.concat(dataframes).to_csv(output, sep=";", index=False) - - -# FIXME : This function with one mri_path but should be able to handle dataframe with multiple MRIs -@mristats.command("get") -@click.option("--stats_file", "-f", "stats_file", type=Path, required=True) -@click.option("--region", "-r", "region", type=str) -@click.option("--info", "-i", "info", type=str) -def get_stats_value(stats_file: str | Path, region: str, info: str): - assert region in default_segmentation_groups().keys() - assert info in [ - "sum", - "mean", - "median", - "std", - "min", - "max", - "PC1", - "PC5", - "PC25", - "PC75", - "PC90", - "PC95", - "PC99", - ] - - df = pd.read_csv(stats_file, sep=";") - - region_row = df.loc[df["description"] == region] - info_value = region_row[info].values[0] - print(f"{info}[{region}] = {info_value}") - - return info_value - - def generate_stats_dataframe( seg_path: Path, mri_path: Path, @@ -178,7 +94,8 @@ def generate_stats_dataframe( records = [] finite_mask = np.isfinite(mri.data) volscale = 
voxel_count_to_ml_scale(seg.affine) - for description, labels in tqdm.tqdm(regions.items()): + + for description, labels in tqdm.rich.tqdm(regions.items(), total=len(regions)): region_mask = np.isin(seg.data, labels) voxelcount = region_mask.sum() record = { diff --git a/test/conftest.py b/test/conftest.py index 29b9c5e..dbfe221 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,7 +1,8 @@ +from pathlib import Path import os import pytest @pytest.fixture(scope="session") -def mri_data_dir(): - return os.getenv("MRITK_TEST_DATA_FOLDER", "test_data") +def mri_data_dir() -> Path: + return Path(os.getenv("MRITK_TEST_DATA_FOLDER", "test_data")) diff --git a/test/test_cli.py b/test/test_cli.py new file mode 100644 index 0000000..6c1eb28 --- /dev/null +++ b/test/test_cli.py @@ -0,0 +1,38 @@ +import json +import mritk +import mritk.cli as cli + + +def test_cli_version(capsys): + cli.main(["--version"]) + captured = capsys.readouterr() + assert "MRI Toolkit Environment" in captured.out + assert f"mri-toolkit │ {mritk.__version__} │" in captured.out + + +def test_cli_info(capsys, mri_data_dir): + test_file = ( + mri_data_dir + / "mri-processed/mri_processed_data/sub-01" + / "concentrations/sub-01_ses-01_concentration.nii.gz" + ) + args = ["info", str(test_file)] + cli.main(args) + captured = capsys.readouterr() + assert "Voxel Size (mm) (0.50, 0.50, 0.50)" in captured.out + assert "Shape (x, y, z) (368, 512, 512)" in captured.out + + +def test_cli_info_json(capsys, mri_data_dir): + test_file = ( + mri_data_dir + / "mri-processed/mri_processed_data/sub-01" + / "concentrations/sub-01_ses-01_concentration.nii.gz" + ) + args = ["info", str(test_file), "--json"] + cli.main(args) + captured = capsys.readouterr() + data = json.loads(captured.out) + + assert "sub-01_ses-01_concentration.nii.gz" in data["filename"] + assert data["shape"] == [368, 512, 512] diff --git a/test/test_mri_io.py b/test/test_mri_io.py index e66e14c..cf68b42 100644 --- a/test/test_mri_io.py +++ 
b/test/test_mri_io.py @@ -6,16 +6,15 @@ """ import numpy as np -import os - from mritk.data.io import load_mri_data, save_mri_data def test_mri_io_nifti(tmp_path, mri_data_dir): - input_file = os.path.join( - mri_data_dir, - "mri-processed/mri_dataset/derivatives/sub-01/ses-01/sub-01_ses-01_acq-mixed_T1map.nii.gz", + input_file = ( + mri_data_dir + / "mri-processed/mri_dataset/derivatives/sub-01/ses-01/sub-01_ses-01_acq-mixed_T1map.nii.gz" ) + output_file = tmp_path / "output_nifti.nii.gz" mri = load_mri_data(input_file, dtype=np.single, orient=False) ## TODO : Test orient=True case diff --git a/test/test_mri_stats.py b/test/test_mri_stats.py index 28e2304..638f8b7 100644 --- a/test/test_mri_stats.py +++ b/test/test_mri_stats.py @@ -5,21 +5,23 @@ Copyright (C) 2026 Simula Research Laboratory """ -import os -from click.testing import CliRunner from pathlib import Path -from mritk.statistics.compute_stats import generate_stats_dataframe, compute_mri_stats +from mritk.statistics.compute_stats import generate_stats_dataframe # , compute_mri_stats +import mritk.cli as cli -def test_compute_stats_default(mri_data_dir): - seg_path = os.path.join( - mri_data_dir, - "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", + +def test_compute_stats_default(mri_data_dir: Path): + seg_path = ( + mri_data_dir + / "mri-processed/mri_processed_data/sub-01" + / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" ) - mri_path = os.path.join( - mri_data_dir, - "mri-processed/mri_processed_data/sub-01/concentrations/sub-01_ses-01_concentration.nii.gz", + mri_path = ( + mri_data_dir + / "mri-processed/mri_processed_data/sub-01" + / "concentrations/sub-01_ses-01_concentration.nii.gz" ) dataframe = generate_stats_dataframe(seg_path, mri_path) @@ -52,14 +54,16 @@ def test_compute_stats_default(mri_data_dir): } -def test_compute_stats_patterns(mri_data_dir): - seg_path = os.path.join( - mri_data_dir, - 
"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", +def test_compute_stats_patterns(mri_data_dir: Path): + seg_path = ( + mri_data_dir + / "mri-processed/mri_processed_data/sub-01" + / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" ) - mri_path = os.path.join( - mri_data_dir, - "mri-processed/mri_processed_data/sub-01/concentrations/sub-01_ses-01_concentration.nii.gz", + mri_path = ( + mri_data_dir + / "mri-processed/mri_processed_data/sub-01" + / "concentrations/sub-01_ses-01_concentration.nii.gz" ) seg_pattern = "(?Psub-(control|patient)*\\d{2})_seg-(?P[^\\.]+)" mri_data_pattern = ( @@ -80,20 +84,22 @@ def test_compute_stats_patterns(mri_data_dir): assert dataframe["session"].iloc[0] == "ses-01" -def test_compute_stats_timestamp(mri_data_dir): - seg_path = os.path.join( - mri_data_dir, - "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", +def test_compute_stats_timestamp(mri_data_dir: Path): + seg_path = ( + mri_data_dir + / "mri-processed/mri_processed_data/sub-01" + / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" ) - mri_path = os.path.join( - mri_data_dir, - "mri-processed/mri_processed_data/sub-01/concentrations/sub-01_ses-01_concentration.nii.gz", + mri_path = ( + mri_data_dir + / "mri-processed/mri_processed_data/sub-01" + / "concentrations/sub-01_ses-01_concentration.nii.gz" ) seg_pattern = "(?Psub-(control|patient)*\\d{2})_seg-(?P[^\\.]+)" mri_data_pattern = ( "(?Psub-(control|patient)*\\d{2})_(?Pses-\\d{2})_(?P[^\\.]+)" ) - timetable = os.path.join(mri_data_dir, "timetable/timetable.tsv") + timetable = mri_data_dir / "timetable/timetable.tsv" timetable_sequence = "mixed" dataframe = generate_stats_dataframe( @@ -108,14 +114,16 @@ def test_compute_stats_timestamp(mri_data_dir): assert dataframe["timestamp"].iloc[0] == -6414.9 -def test_compute_stats_info(mri_data_dir): - seg_path = os.path.join( - mri_data_dir, - 
"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", +def test_compute_stats_info(mri_data_dir: Path): + seg_path = ( + mri_data_dir + / "mri-processed/mri_processed_data/sub-01" + / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" ) - mri_path = os.path.join( - mri_data_dir, - "mri-processed/mri_processed_data/sub-01/concentrations/sub-01_ses-01_concentration.nii.gz", + mri_path = ( + mri_data_dir + / "mri-processed/mri_processed_data/sub-01" + / "concentrations/sub-01_ses-01_concentration.nii.gz" ) info = { "mri_data": "concentration", @@ -133,40 +141,44 @@ def test_compute_stats_info(mri_data_dir): assert dataframe["session"].iloc[0] == "ses-01" -def test_compute_mri_stats_cli(tmp_path, mri_data_dir): - runner = CliRunner() - seg_path = os.path.join( - mri_data_dir, - "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", +def test_compute_mri_stats_cli(capsys, tmp_path: Path, mri_data_dir: Path): + seg_path = ( + mri_data_dir + / "mri-processed/mri_processed_data/sub-01" + / "segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" ) - mri_path = os.path.join( - mri_data_dir, - "mri-processed/mri_processed_data/sub-01/concentrations/sub-01_ses-01_concentration.nii.gz", + mri_path = ( + mri_data_dir + / "mri-processed/mri_processed_data/sub-01" + / "concentrations/sub-01_ses-01_concentration.nii.gz" ) seg_pattern = "(?Psub-(control|patient)*\\d{2})_seg-(?P[^\\.]+)" mri_data_pattern = ( "(?Psub-(control|patient)*\\d{2})_(?Pses-\\d{2})_(?P[^\\.]+)" ) - timetable = os.path.join(mri_data_dir, "timetable/timetable.tsv") + timetable = mri_data_dir / "timetable/timetable.tsv" timetable_sequence = "mixed" - result = runner.invoke( - compute_mri_stats, - [ - "--segmentation", - seg_path, - "--mri", - mri_path, - "--output", - Path(str(tmp_path / "mri_stats_output.csv")), - "--timetable", - timetable, - "--timelabel", - timetable_sequence, - "--seg_regex", - seg_pattern, - "--mri_regex", - 
mri_data_pattern, - ], - ) - assert result.exit_code == 0 + args = [ + "--segmentation", + str(seg_path), + "--mri", + str(mri_path), + "--output", + str(tmp_path / "mri_stats_output.csv"), + "--timetable", + str(timetable), + "--timelabel", + timetable_sequence, + "--seg_regex", + seg_pattern, + "--mri_regex", + mri_data_pattern, + ] + + ret = cli.main(["stats", "compute"] + args) + assert ret == 0 + captured = capsys.readouterr() + assert "Processing MRIs..." in captured.out + assert "Stats successfully saved to" in captured.out + assert (tmp_path / "mri_stats_output.csv").exists()