Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/mritk/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from rich_argparse import RichHelpFormatter

from . import download_data, info
from . import download_data, info, statistics


def version_info():
Expand Down Expand Up @@ -48,16 +48,25 @@ def setup_parser():
subparsers = parser.add_subparsers(dest="command")

# Download test data parser
download_parser = subparsers.add_parser("download-test-data", help="Download test data")
download_parser = subparsers.add_parser(
"download-test-data", help="Download test data", formatter_class=parser.formatter_class
)
download_parser.add_argument("outdir", type=Path, help="Output directory to download test data")

info_parser = subparsers.add_parser("info", help="Display information about a file")
info_parser = subparsers.add_parser(
"info", help="Display information about a file", formatter_class=parser.formatter_class
)
info_parser.add_argument("file", type=Path, help="File to display information about")

info_parser.add_argument(
"--json", action="store_true", help="Output information in JSON format"
)

stats_parser = subparsers.add_parser(
"stats", help="Compute MRI statistics", formatter_class=parser.formatter_class
)
statistics.cli.add_arguments(stats_parser)

return parser


Expand All @@ -77,6 +86,8 @@ def dispatch(parser: argparse.ArgumentParser, argv: Optional[Sequence[str]] = No
elif command == "info":
file = args.pop("file")
info.nifty_info(file, json_output=args.pop("json"))
elif command == "stats":
statistics.cli.dispatch(args)
else:
logger.error(f"Unknown command {command}")
parser.print_help()
Expand Down
7 changes: 2 additions & 5 deletions src/mritk/statistics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
Copyright (C) 2026 Simula Research Laboratory
"""

from . import utils, compute_stats
from . import utils, compute_stats, cli

__all__ = [
"utils",
"compute_stats",
]
__all__ = ["utils", "compute_stats", "cli"]
210 changes: 210 additions & 0 deletions src/mritk/statistics/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import argparse
import typing
from pathlib import Path
import pandas as pd

from ..segmentation.groups import default_segmentation_groups
from .compute_stats import generate_stats_dataframe


def compute_mri_stats(
    segmentation: Path,
    mri: list[Path],
    output: Path,
    timetable: Path | None = None,
    timelabel: str | None = None,
    seg_regex: str | None = None,
    mri_regex: str | None = None,
    lut: Path | None = None,
    info: str | None = None,
    **kwargs,
):
    """Compute per-region MRI statistics and write them to a CSV file.

    For each MRI volume, builds a statistics dataframe against the given
    segmentation via ``generate_stats_dataframe``, concatenates the results,
    and saves them to ``output`` as a semicolon-separated CSV.

    Args:
        segmentation: Path to the segmentation volume (must exist).
        mri: One or more MRI volume paths (each must exist).
        output: Destination CSV file path.
        timetable: Optional path to a timetable file.
        timelabel: Optional time-label sequence name.
        seg_regex: Optional regex pattern for the segmentation filename.
        mri_regex: Optional regex pattern for MRI filenames.
        lut: Optional path to a lookup table.
        info: Optional JSON string parsed into an extra-info dict.
        **kwargs: Ignored; absorbs leftover argparse namespace entries.

    Exits the process with status 1 on invalid JSON, missing input files,
    or a failure while processing any MRI volume.
    """
    import sys
    import json
    from rich.console import Console
    from rich.panel import Panel

    console = Console()

    # Parse the optional extra-info dict from its JSON string form.
    info_dict = None
    if info:
        try:
            info_dict = json.loads(info)
        except json.JSONDecodeError:
            console.print("[bold red]Error:[/bold red] --info must be a valid JSON string.")
            sys.exit(1)

    if not segmentation.exists():
        console.print(f"[bold red]Error:[/bold red] Missing segmentation file: {segmentation}")
        sys.exit(1)

    # Validate all MRI paths up front so we fail before doing any work.
    for path in mri:
        if not path.exists():
            console.print(f"[bold red]Error:[/bold red] Missing MRI file: {path}")
            sys.exit(1)

    dataframes = []

    console.print("[bold green]Processing MRIs...[/bold green]")
    for path in mri:
        try:
            df = generate_stats_dataframe(
                seg_path=segmentation,
                mri_path=path,
                timestamp_path=timetable,
                timestamp_sequence=timelabel,
                seg_pattern=seg_regex,
                mri_data_pattern=mri_regex,
                lut_path=lut,
                info_dict=info_dict,
            )
            dataframes.append(df)
        except Exception as e:
            # Abort on the first failure rather than emitting a partial CSV.
            console.print(f"[bold red]Failed to process {path.name}:[/bold red] {e}")
            sys.exit(1)

    if dataframes:
        final_df = pd.concat(dataframes)
        final_df.to_csv(output, sep=";", index=False)
        console.print(
            Panel(
                f"Stats successfully saved to:\n[bold green]{output}[/bold green]",
                title="Success",
                expand=False,
            )
        )
    else:
        # Only reachable when `mri` is empty (argparse enforces nargs="+",
        # but direct callers may pass an empty list).
        console.print("[yellow]No dataframes generated.[/yellow]")


def get_stats_value(stats_file: Path, region: str, info: str, **kwargs):
    """Look up a single statistic for one region in a stats CSV file.

    Reads the semicolon-separated CSV produced by ``compute_mri_stats``,
    finds the row whose ``description`` column equals ``region``, prints the
    requested statistic, and returns it.

    Args:
        stats_file: Path to the stats CSV file (must exist).
        region: Region description; must be a key of
            ``default_segmentation_groups()``.
        info: Name of the statistic column to retrieve (e.g. "mean", "std").
        **kwargs: Ignored; absorbs leftover argparse namespace entries.

    Returns:
        The requested value from the matching row.

    Exits the process with status 1 on an unknown region/statistic name,
    a missing file, a region absent from the CSV, or a read error.
    """
    import sys
    from rich.console import Console

    console = Console()

    # Validate the region against the known segmentation groups.
    valid_regions = default_segmentation_groups().keys()
    if region not in valid_regions:
        console.print(
            f"[bold red]Error:[/bold red] Region '{region}' "
            "not found in default segmentation groups."
        )
        sys.exit(1)

    # Statistic columns produced by compute_mri_stats.
    valid_infos = [
        "sum",
        "mean",
        "median",
        "std",
        "min",
        "max",
        "PC1",
        "PC5",
        "PC25",
        "PC75",
        "PC90",
        "PC95",
        "PC99",
    ]
    if info not in valid_infos:
        console.print(
            f"[bold red]Error:[/bold red] Info '{info}' "
            f"is invalid. Choose from: {', '.join(valid_infos)}"
        )
        sys.exit(1)

    if not stats_file.exists():
        console.print(f"[bold red]Error:[/bold red] Stats file not found: {stats_file}")
        sys.exit(1)

    try:
        df = pd.read_csv(stats_file, sep=";")
        region_row = df.loc[df["description"] == region]

        if region_row.empty:
            console.print(f"[red]Region '{region}' not found in the stats file.[/red]")
            sys.exit(1)

        # NOTE(review): if the CSV holds multiple MRIs, this silently takes
        # the first matching row — confirm intended behavior upstream.
        info_value = region_row[info].values[0]

        console.print(
            f"[bold cyan]{info}[/bold cyan] for [bold green]{region}[/bold green] "
            f"= [bold white]{info_value}[/bold white]"
        )
        return info_value

    except Exception as e:
        console.print(f"[bold red]Error reading stats file:[/bold red] {e}")
        sys.exit(1)


def add_arguments(parser: argparse.ArgumentParser):
    """Register the `stats` subcommands (`compute`, `get`) on *parser*.

    The chosen subcommand is stored under the ``stats-command`` dest, which
    is what ``dispatch`` reads. The subcommand is required so that a bare
    invocation produces an argparse usage error instead of reaching
    ``dispatch`` with ``None``.
    """
    # required=True: fail fast with a usage message if no subcommand given.
    subparsers = parser.add_subparsers(
        dest="stats-command", required=True, help="Available commands"
    )

    # --- Compute Command ---
    parser_compute = subparsers.add_parser(
        "compute", help="Compute MRI statistics", formatter_class=parser.formatter_class
    )
    parser_compute.add_argument(
        "--segmentation", "-s", type=Path, required=True, help="Path to segmentation file"
    )
    parser_compute.add_argument(
        "--mri", "-m", type=Path, nargs="+", required=True, help="Path to MRI data file(s)"
    )
    parser_compute.add_argument(
        "--output", "-o", type=Path, required=True, help="Output CSV file path"
    )
    parser_compute.add_argument("--timetable", "-t", type=Path, help="Path to timetable file")
    parser_compute.add_argument(
        "--timelabel", "-l", dest="timelabel", type=str, help="Time label sequence"
    )
    parser_compute.add_argument(
        "--seg_regex",
        "-sr",
        dest="seg_regex",
        type=str,
        help="Regex pattern for segmentation filename",
    )
    parser_compute.add_argument(
        "--mri_regex", "-mr", dest="mri_regex", type=str, help="Regex pattern for MRI filename"
    )
    parser_compute.add_argument("--lut", "-lt", dest="lut", type=Path, help="Path to Lookup Table")
    parser_compute.add_argument("--info", "-i", type=str, help="Info dictionary as JSON string")

    # --- Get Command ---
    parser_get = subparsers.add_parser(
        "get", help="Get specific stats value", formatter_class=parser.formatter_class
    )
    parser_get.add_argument(
        "--stats_file", "-f", type=Path, required=True, help="Path to stats CSV file"
    )
    parser_get.add_argument("--region", "-r", type=str, required=True, help="Region description")
    parser_get.add_argument(
        "--info", "-i", type=str, required=True, help="Statistic to retrieve (mean, std, etc.)"
    )
    # NOTE: no set_defaults(func=...) here — `dispatch` routes on the
    # "stats-command" value, and a `func` entry would otherwise leak into
    # the **kwargs of the handler functions.


def dispatch(args: dict[str, typing.Any]):
    """Route a parsed `stats` argument dict to its handler function.

    Pops the "stats-command" entry and forwards the remaining entries as
    keyword arguments to the matching handler. Raises ValueError for an
    unrecognized (or missing) command.
    """
    command = args.pop("stats-command")
    if command == "compute":
        compute_mri_stats(**args)
        return
    if command == "get":
        get_stats_value(**args)
        return
    raise ValueError(f"Unknown command: {command}")
89 changes: 3 additions & 86 deletions src/mritk/statistics/compute_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import re
import numpy as np
import pandas as pd
import tqdm
import click
import tqdm.rich

from ..data.io import load_mri_data
from ..data.orientation import assert_same_space
Expand All @@ -20,89 +19,6 @@
from .utils import voxel_count_to_ml_scale, find_timestamp, prepend_info


@click.group()
def mristats():
pass


@mristats.command("compute")
@click.option("--segmentation", "-s", "seg_path", type=Path, required=True)
@click.option("--mri", "-m", "mri_paths", multiple=True, type=Path, required=True)
@click.option("--output", "-o", type=Path, required=True)
@click.option("--timetable", "-t", type=Path)
@click.option("--timelabel", "-l", "timetable_sequence", type=str)
@click.option("--seg_regex", "-sr", "seg_pattern", type=str)
@click.option("--mri_regex", "-mr", "mri_data_pattern", type=str)
@click.option("--lut", "-lt", "lut_path", type=Path)
@click.option("--info", "-i", "info_dict", type=dict)
## FIXME : Need to check that all the given mri in mri_paths
## are registered to the same baseline MRI - this is done in create_dataframe
def compute_mri_stats(
seg_path: str | Path,
mri_paths: tuple[str | Path],
output: str | Path,
timetable: Optional[str | Path],
timetable_sequence: Optional[str | Path],
seg_pattern: Optional[str | Path],
mri_data_pattern: Optional[str | Path],
lut_path: Optional[Path] = None,
info_dict: Optional[dict] = None,
):
if not Path(seg_path).exists():
raise RuntimeError(f"Missing segmentation: {seg_path}")

for path in mri_paths:
if not Path(path).exists():
raise RuntimeError(f"Missing: {path}")

dataframes = [
generate_stats_dataframe(
Path(seg_path),
Path(path),
timetable,
timetable_sequence,
seg_pattern,
mri_data_pattern,
lut_path,
info_dict,
)
for path in mri_paths
]
pd.concat(dataframes).to_csv(output, sep=";", index=False)


# FIXME : This function with one mri_path but should be able to handle dataframe with multiple MRIs
@mristats.command("get")
@click.option("--stats_file", "-f", "stats_file", type=Path, required=True)
@click.option("--region", "-r", "region", type=str)
@click.option("--info", "-i", "info", type=str)
def get_stats_value(stats_file: str | Path, region: str, info: str):
assert region in default_segmentation_groups().keys()
assert info in [
"sum",
"mean",
"median",
"std",
"min",
"max",
"PC1",
"PC5",
"PC25",
"PC75",
"PC90",
"PC95",
"PC99",
]

df = pd.read_csv(stats_file, sep=";")

region_row = df.loc[df["description"] == region]
info_value = region_row[info].values[0]
print(f"{info}[{region}] = {info_value}")

return info_value


def generate_stats_dataframe(
seg_path: Path,
mri_path: Path,
Expand Down Expand Up @@ -178,7 +94,8 @@ def generate_stats_dataframe(
records = []
finite_mask = np.isfinite(mri.data)
volscale = voxel_count_to_ml_scale(seg.affine)
for description, labels in tqdm.tqdm(regions.items()):

for description, labels in tqdm.rich.tqdm(regions.items(), total=len(regions)):
region_mask = np.isin(seg.data, labels)
voxelcount = region_mask.sum()
record = {
Expand Down
5 changes: 3 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pathlib import Path
import os
import pytest


@pytest.fixture(scope="session")
def mri_data_dir():
return os.getenv("MRITK_TEST_DATA_FOLDER", "test_data")
def mri_data_dir() -> Path:
return Path(os.getenv("MRITK_TEST_DATA_FOLDER", "test_data"))
Loading
Loading