Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 233 additions & 43 deletions pydfc/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

import h5py
import numpy as np
from nilearn import datasets
from nilearn.interfaces.fmriprep import load_confounds, load_confounds_strategy
from nilearn.maskers import NiftiLabelsMasker, NiftiSpheresMasker
from nilearn.plotting import find_parcellation_cut_coords

from .dfc_utils import intersection, label2network
from .time_series import TIME_SERIES
Expand Down Expand Up @@ -150,13 +154,18 @@ def load_from_array(subj_id2load=None, **params):
return BOLD


def nifti2array(nifti_file, confound_strategy="none", standardize=False, n_rois=100):
def extract_region_signals(
nifti_file,
masker_type="NiftiLabelsMasker",
confound_strategy="none",
standardize=False,
labels_img=None,
seeds=None,
radius=None,
):
"""
this function uses nilearn maskers to extract
BOLD signals from nifti files
For now it only works with schaefer atlas,
but you can set the number of rois to extract
{100, 200, 300, 400, 500, 600, 700, 800, 900, 1000}

returns a numpy array of shape (time, roi)
and labels and locs of rois
Expand All @@ -167,37 +176,38 @@ def nifti2array(nifti_file, confound_strategy="none", standardize=False, n_rois=
'no_motion_no_gsr': motion parameters are used
and global signal regression
is applied.
"""
from nilearn import datasets
from nilearn.interfaces.fmriprep import load_confounds
from nilearn.maskers import NiftiLabelsMasker
from nilearn.plotting import find_parcellation_cut_coords

parc = datasets.fetch_atlas_schaefer_2018(n_rois=n_rois)
atlas_filename = parc.maps
labels = parc.labels
# The list of labels does not contain ‘Background’ by default.
# To have proper indexing, you should either manually add ‘Background’ to the list of labels:
# Prepend background label
labels = np.insert(labels, 0, "Background")

# extract locs
# test!
# check if order is the same as labels
locs, labels_ = find_parcellation_cut_coords(
atlas_filename, background_label=0, return_label_names=True
)

# create the masker for extracting time series
masker = NiftiLabelsMasker(
labels_img=atlas_filename,
labels=labels,
resampling_target="data",
standardize=standardize,
)
'simple': nilearn's simple preprocessing with
full motion and basic wm_csf
and high_pass

labels = np.delete(labels, 0) # remove the background label
labels = [label.decode() for label in labels]
For now it only works with NiftiLabelsMasker and NiftiSpheresMasker and not with NiftiMapsMasker
masker_type: "NiftiLabelsMasker" or "NiftiSpheresMasker"
"""
if masker_type == "NiftiSpheresMasker":
# check if seeds and radius are provided
if seeds is None or radius is None:
raise ValueError("For NiftiSpheresMasker, seeds and radius must be provided.")
# create the masker for extracting time series
masker = NiftiSpheresMasker(
seeds=seeds,
radius=radius, # radius in mm
standardize=standardize,
)
elif masker_type == "NiftiLabelsMasker":
# check if labels_img is provided
if labels_img is None:
raise ValueError("For NiftiLabelsMasker, labels_img must be provided.")
# create the masker for extracting time series
masker = NiftiLabelsMasker(
labels_img=labels_img,
resampling_target="data",
standardize=standardize,
)
else:
raise ValueError(
"masker_type must be 'NiftiLabelsMasker' or 'NiftiSpheresMasker', "
f"but got {masker_type}"
)

### extract the timeseries
if confound_strategy == "none":
Expand All @@ -223,16 +233,146 @@ def nifti2array(nifti_file, confound_strategy="none", standardize=False, n_rois=
time_series = masker.fit_transform(
nifti_file, confounds=confounds_simple, sample_mask=sample_mask
)
elif confound_strategy == "simple":
confounds_simple, sample_mask = load_confounds_strategy(
nifti_file, denoise_strategy="simple"
)
time_series = masker.fit_transform(
nifti_file, confounds=confounds_simple, sample_mask=sample_mask
)
else:
raise ValueError(
"confound_strategy must be one of 'none', 'no_motion', 'no_motion_no_gsr', or 'simple', "
f"but got {confound_strategy}"
)

return time_series


def nifti2array(
nifti_file,
masker_type="NiftiLabelsMasker",
confound_strategy="none",
standardize=False,
n_rois=100,
labels_img=None,
seeds=None,
radius=None,
region_names=None,
):
"""
this function uses nilearn maskers to extract
BOLD signals from nifti files

returns a numpy array of shape (time, roi)
and labels and locs of rois

confound_strategy:
'none': no confounds are used
'no_motion': motion parameters are used
'no_motion_no_gsr': motion parameters are used
and global signal regression
is applied.
'simple': nilearn's simple preprocessing with
full motion and basic wm_csf
and high_pass

For now it only works with NiftiLabelsMasker and NiftiSpheresMasker and not with NiftiMapsMasker
masker_type: "NiftiLabelsMasker" or "NiftiSpheresMasker"
if masker_type is "NiftiLabelsMasker",
labels_img must be provided or n_rois must be provided
if masker_type is "NiftiSpheresMasker",
seeds and radius must be provided

Note:
when not using Schaefer atlas, make sure
that the labels_img/seeds and region_names are in the same order.
"""
if masker_type == "NiftiLabelsMasker":
if labels_img is None:
# in this case, we will use the schaefer atlas
# we use n_rois to determine the number of rois
assert n_rois in [
100,
200,
300,
400,
500,
600,
700,
800,
900,
1000,
], "n_rois must be one of {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000}"
# fetch the schaefer atlas
parc = datasets.fetch_atlas_schaefer_2018(n_rois=n_rois)
labels_img = parc.maps
labels = parc.labels
labels = [label.decode() for label in labels]
else:
assert (
region_names is not None
), "region_names must be provided if labels_img is provided"
assert type(region_names) is list, "region_names must be a list of strings"

labels = region_names

# extract locs from labels_img
# check if order is the same as labels
locs, labels_ = find_parcellation_cut_coords(
labels_img, background_label=0, return_label_names=True
) # numpy.ndarray of shape (n_labels, 3)

elif masker_type == "NiftiSpheresMasker":

# make sure seeds is a list of tuples (x, y, z)
assert seeds is not None, "seeds must be provided for NiftiSpheresMasker"
assert radius is not None, "radius must be provided for NiftiSpheresMasker"
assert type(seeds) is list, "seeds must be a list of tuples (x, y, z)"
assert all(
isinstance(seed, tuple) and len(seed) == 3 for seed in seeds
), "seeds must be a list of tuples (x, y, z) with 3 elements each"

locs = np.array(seeds) # seeds should be a list of tuples (x, y, z)

assert (
region_names is not None
), "region_names must be provided if seeds are provided"
assert type(region_names) is list, "region_names must be a list of strings"

labels = region_names

else:
raise ValueError(
"masker_type must be 'NiftiLabelsMasker' or 'NiftiSpheresMasker', "
f"but got {masker_type}"
)

# extract the timeseries
time_series = extract_region_signals(
nifti_file=nifti_file,
masker_type=masker_type,
confound_strategy=confound_strategy,
standardize=standardize,
labels_img=labels_img,
seeds=seeds,
radius=radius,
)

return time_series, labels, locs


def nifti2timeseries(
nifti_file,
n_rois,
Fs,
subj_id,
confound_strategy="none",
masker_type="NiftiLabelsMasker",
n_rois=100,
labels_img=None,
seeds=None,
radius=None,
region_names=None,
standardize=False,
TS_name=None,
session=None,
Expand All @@ -242,15 +382,50 @@ def nifti2timeseries(
it uses nilearn maskers to extract ROI signals from nifti files
and returns a TIME_SERIES object

For now it only works with schaefer atlas,
but you can set the number of rois to extract
{100, 200, 300, 400, 500, 600, 700, 800, 900, 1000}
Parameters
----------
nifti_file : str
path to the nifti file
Fs : float
sampling frequency of the data
subj_id : str
subject ID, must start with 'sub-'
confound_strategy : str, optional
strategy for confound regression, by default "none"
masker_type : str, optional
type of masker to use, by default "NiftiLabelsMasker"
n_rois : int, optional
number of regions of interest to extract, by default 100
labels_img : str, optional
path to the labels image, by default None
seeds : list, optional
list of tuples (x, y, z) for NiftiSpheresMasker
by default None
radius : float, optional
radius in mm for NiftiSpheresMasker, by default None
region_names : list, optional
list of region names for NiftiLabelsMasker or NiftiSpheresMasker,
by default None
standardize : bool, optional
whether to standardize the time series, by default False
TS_name : str, optional
name of the time series, by default None
session : str, optional
session name, by default None

For more information on confound_strategy, masker_type, and other parameters,
see the documentation of the nifti2array function.
"""
time_series, labels, locs = nifti2array(
nifti_file=nifti_file,
confound_strategy=confound_strategy,
standardize=standardize,
masker_type=masker_type,
n_rois=n_rois,
labels_img=labels_img,
seeds=seeds,
radius=radius,
region_names=region_names,
)

assert type(locs) is np.ndarray, "locs must be a numpy array"
Expand Down Expand Up @@ -280,8 +455,13 @@ def nifti2timeseries(
def multi_nifti2timeseries(
nifti_files_list,
subj_id_list,
n_rois,
Fs,
masker_type="NiftiLabelsMasker",
n_rois=100,
labels_img=None,
seeds=None,
radius=None,
region_names=None,
confound_strategy="none",
standardize=False,
TS_name=None,
Expand All @@ -295,10 +475,15 @@ def multi_nifti2timeseries(
if BOLD_multi is None:
BOLD_multi = nifti2timeseries(
nifti_file=nifti_file,
n_rois=n_rois,
Fs=Fs,
subj_id=subj_id,
Fs=Fs,
confound_strategy=confound_strategy,
masker_type=masker_type,
n_rois=n_rois,
labels_img=labels_img,
seeds=seeds,
radius=radius,
region_names=region_names,
standardize=standardize,
TS_name=TS_name,
session=session,
Expand All @@ -307,10 +492,15 @@ def multi_nifti2timeseries(
BOLD_multi.concat_ts(
nifti2timeseries(
nifti_file=nifti_file,
n_rois=n_rois,
Fs=Fs,
subj_id=subj_id,
Fs=Fs,
confound_strategy=confound_strategy,
masker_type=masker_type,
n_rois=n_rois,
labels_img=labels_img,
seeds=seeds,
radius=radius,
region_names=region_names,
standardize=standardize,
TS_name=TS_name,
session=session,
Expand Down
Loading
Loading