Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pydfc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
time_series --- time series class
dfc --- dfc class
data_loader --- load data
multi_analysis_utils --- utility functions for multi analysis
to implement multiple dFC methods
simultaneously
dfc_utils --- functions used for dFC analysis
comparison --- functions used for dFC results comparison

Expand All @@ -24,6 +27,7 @@
"TIME_SERIES",
"DFC",
"data_loader",
"multi_analysis_utils",
"dfc_methods",
"dfc_utils",
"comparison",
Expand Down
14 changes: 10 additions & 4 deletions pydfc/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Implementation of dFC methods.
Implementation of functions for loading fmri data.

Created on Jun 29 2023
@author: Mohammad Torabi
Expand Down Expand Up @@ -363,9 +363,15 @@ def load_TS(
if "{run}" in file_name:
assert run is not None, "run must be provided"
TS_file = TS_file.replace("{run}", run)
time_series = np.load(
f"{data_root}/{subj_fldr}/{TS_file}", allow_pickle="True"
).item()

try:
time_series = np.load(
f"{data_root}/{subj_fldr}/{TS_file}", allow_pickle="True"
).item()
except FileNotFoundError:
print(f"File {TS_file} not found for {subj}")
continue

if TS[session] is None:
TS[session] = time_series
else:
Expand Down
16 changes: 10 additions & 6 deletions pydfc/dfc.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def dFC2dict(self, TRs=None):
dFC_mat = self.get_dFC_mat(TRs=TRs)
dFC_dict = {}
for k, TR in enumerate(TRs):
dFC_dict["TR" + str(TR)] = dFC_mat[k, :, :]
dFC_dict[f"TR{TR}"] = dFC_mat[k, :, :]
return dFC_dict

# test this
Expand Down Expand Up @@ -199,7 +199,7 @@ def get_dFC_mat(self, TRs=None, num_samples=None):

dFC_mat = list()
for TR in TRs:
dFC_mat.append(self.FCSs[self.FCS_idx["TR" + str(TR)]])
dFC_mat.append(self.FCSs[self.FCS_idx[f"TR{TR}"]])

dFC_mat = np.array(dFC_mat)

Expand Down Expand Up @@ -231,6 +231,10 @@ def SWed_dFC_mat(self, W=None, n_overlap=None, tapered_window=False):
return dFC_mat_new

def set_dFC(self, FCSs, FCS_idx=None, TS_info=None, TR_array=None):
"""
FCSs: a 3D numpy array of FC matrices with shape (n_time, n_regions, n_regions)
FCS_idx: a list of indices that correspond to each FC matrix in FCSs over time
"""

if len(FCSs.shape) == 2:
FCSs = np.expand_dims(FCSs, axis=0)
Expand Down Expand Up @@ -267,11 +271,11 @@ def set_dFC(self, FCSs, FCS_idx=None, TS_info=None, TR_array=None):
# the input FCS_idx is ranged from 0 to len(FCS)-1 but we shift it to 1 to len(FCS)
self.FCSs_ = {}
for i, FCS in enumerate(FCSs):
self.FCSs_["FCS" + str(i + 1)] = FCS
self.FCSs_[f"FCS{i + 1}"] = FCS

self.FCS_idx_ = {}
for i, idx in enumerate(FCS_idx):
self.FCS_idx_["TR" + str(TR_array[i])] = "FCS" + str(idx + 1)
self.FCS_idx_[f"TR{TR_array[i]}"] = f"FCS{idx + 1}" # "FCS" + str(idx + 1)

self.TS_info_ = TS_info
self.n_regions_ = FCSs.shape[1]
Expand All @@ -287,7 +291,7 @@ def visualize_dFC(
threshold=0.0,
fix_lim=False,
save_image=False,
fig_name=None,
output_root=None,
):

assert not self.measure is None, "Measure is not provided."
Expand Down Expand Up @@ -321,5 +325,5 @@ def visualize_dFC(
cmap=cmap,
center_0=center_0,
save_image=save_image,
output_root=fig_name,
output_root=output_root,
)
5 changes: 3 additions & 2 deletions pydfc/dfc_methods/cap.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def act_vec2FCS(self, act_vecs):
def cluster_act_vec(self, act_vecs, n_clusters):

kmeans_ = KMeans(n_clusters=n_clusters, n_init=500).fit(act_vecs)
kmeans_.cluster_centers_ = kmeans_.cluster_centers_.astype(np.float32)
act_centroids = kmeans_.cluster_centers_

return act_centroids, kmeans_
Expand Down Expand Up @@ -122,7 +123,7 @@ def estimate_FCS(self, time_series):
act_vecs=act_center_1st_level, n_clusters=self.params["n_states"]
)
self.FCS_ = self.act_vec2FCS(group_act_centroids)
self.Z = self.kmeans_.predict(time_series.data.T)
self.Z = self.kmeans_.predict(time_series.data.T.astype(np.float32))

# mean activation of states
self.set_mean_activity(time_series)
Expand All @@ -149,7 +150,7 @@ def estimate_dFC(self, time_series):

act_vecs = time_series.data.T

Z = self.kmeans_.predict(act_vecs)
Z = self.kmeans_.predict(act_vecs.astype(np.float32))

# record time
self.set_dFC_assess_time(time.time() - tic)
Expand Down
22 changes: 19 additions & 3 deletions pydfc/dfc_methods/sliding_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import numpy as np
from scipy import signal
from sklearn.covariance import GraphicalLassoCV, graphical_lasso
from sklearn.covariance import GraphicalLasso, GraphicalLassoCV

from ..dfc import DFC
from ..time_series import TIME_SERIES
Expand Down Expand Up @@ -38,6 +38,7 @@ def __init__(self, **params):
self.FCS_ = []
self.FCS_fit_time_ = None
self.dFC_assess_time_ = None
self.graphical_lasso_alpha_ = None

self.params_name_lst = [
"measure_name",
Expand Down Expand Up @@ -96,8 +97,13 @@ def calc_MI(self, X, Y):
def FC(self, time_series):

if self.params["sw_method"] == "GraphLasso":
model = GraphicalLassoCV()
model.fit(time_series.T)
# Standardize the data (zero mean, unit variance for each feature)
mean = np.mean(time_series, axis=1, keepdims=True)
std = np.std(time_series, axis=1, keepdims=True)
time_series_standardized = np.where(std != 0, (time_series - mean) / std, 0)
model = GraphicalLasso(alpha=self.graphical_lasso_alpha_)
model.fit(time_series_standardized.T)
# the covariance matrix will equal the correlation matrix
C = model.covariance_
else:
C = np.zeros((time_series.shape[0], time_series.shape[0]))
Expand Down Expand Up @@ -129,6 +135,12 @@ def dFC(self, time_series, W=None, n_overlap=None, tapered_window=False):
if step == 0:
step = 1

# find the L1 penalty for GraphLasso
if self.params["sw_method"] == "GraphLasso":
model = GraphicalLassoCV()
model.fit(time_series.T)
self.graphical_lasso_alpha_ = model.alpha_

window_taper = signal.windows.gaussian(W, std=3 * W / 22)
# C = DFC(measure=self)
FCSs = list()
Expand Down Expand Up @@ -161,6 +173,10 @@ def dFC(self, time_series, W=None, n_overlap=None, tapered_window=False):

return np.array(FCSs), np.array(TR_array)

def estimate_FCS(self, time_series):

return self

def estimate_dFC(self, time_series):
"""
we assume calc is applied on subjects separately
Expand Down
5 changes: 3 additions & 2 deletions pydfc/dfc_methods/sliding_window_clustr.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def cluster_FC(self, FCS_raw, n_clusters, n_regions):
else:
########### Euclidean Clustering ##############
kmeans_ = KMeans(n_clusters=n_clusters, n_init=500).fit(F)
kmeans_.cluster_centers_ = kmeans_.cluster_centers_.astype(np.float32)
F_cent = kmeans_.cluster_centers_

FCS_ = self.dFC_vec2mat(F_cent, N=n_regions)
Expand Down Expand Up @@ -228,7 +229,7 @@ def estimate_FCS(self, time_series):
n_clusters=self.params["n_states"],
n_regions=dFC_raw.n_regions,
)
self.Z = self.kmeans_.predict(self.dFC_mat2vec(SW_dFC))
self.Z = self.kmeans_.predict(self.dFC_mat2vec(SW_dFC).astype(np.float32))

# mean activation of states
self.set_mean_activity(time_series)
Expand Down Expand Up @@ -270,7 +271,7 @@ def estimate_dFC(self, time_series):
# Z = self.clusters_lst2idx(self.kmeans_.get_clusters())
else:
########### Euclidean Clustering ##############
Z = self.kmeans_.predict(F)
Z = self.kmeans_.predict(F.astype(np.float32))

# record time
self.set_dFC_assess_time(time.time() - tic)
Expand Down
4 changes: 4 additions & 0 deletions pydfc/dfc_methods/time_freq.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ def WT_dFC(self, Y1, Y2, Fs, J, s0, dj):

return wt

def estimate_FCS(self, time_series):

return self

def estimate_dFC(self, time_series):
"""
we assume calc is applied on subjects separately
Expand Down
26 changes: 15 additions & 11 deletions pydfc/multi_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,30 +231,34 @@ def group_dFC_assess(self, time_series_dict):

return OUT

def subj_lvl_dFC_assess(self, time_series_dict):
def subj_lvl_dFC_assess(self, time_series):
"""
time_series can be a dict of time_series or a single time_series
if it is a dict, the time_series with key `measure.params["session"]`
will be used
"""

# time_series_dict is a dict of time_series
if isinstance(time_series, dict):
if not measure.params["session"] in time_series:
raise ValueError(
f"session {measure.params['session']} is not in time_series"
)
else:
time_series = time_series[measure.params["session"]]

dFC_dict = {}
# dFC_corr_assess_dict = {}

if self.params["n_jobs"] is None:
dFC_lst = list()
for measure in self.MEASURES_fit_lst_:
dFC_lst.append(
measure.estimate_dFC(
time_series=time_series_dict[measure.params["session"]]
)
)
dFC_lst.append(measure.estimate_dFC(time_series=time_series))
else:
dFC_lst = Parallel(
n_jobs=self.params["n_jobs"],
verbose=self.params["verbose"],
backend=self.params["backend"],
)(
delayed(measure.estimate_dFC)(
time_series=time_series_dict[measure.params["session"]]
)
delayed(measure.estimate_dFC)(time_series=time_series)
for measure in self.MEASURES_fit_lst_
)

Expand Down
Loading
Loading