From 2ba24c16d5ba1750911d3c632dec2e8c2bd83f38 Mon Sep 17 00:00:00 2001 From: mtorabi59 Date: Thu, 5 May 2022 13:51:46 -0400 Subject: [PATCH] review add --- BIC_codes/FCS_estimate.py | 149 + BIC_codes/dFC_assessment.py | 172 + BIC_codes/functions/__init__.py | 0 BIC_codes/functions/dFC_funcs.py | 3472 +++++++++++++++++ BIC_codes/main.py | 372 ++ BIC_codes/post_analysis.py | 87 + BIC_codes/test_dFC.py | 360 ++ README.md | 4 + demo.ipynb | 727 ++++ .../ROI_data_Gordon_333_surf.mat | Bin 0 -> 1498360 bytes .../ROI_data_Gordon_333_surf.mat | Bin 0 -> 1496005 bytes sampleDATA/Gordon333_Key.txt | 334 ++ sampleDATA/Gordon333_LOCS.mat | Bin 0 -> 2668 bytes 13 files changed, 5677 insertions(+) create mode 100644 BIC_codes/FCS_estimate.py create mode 100644 BIC_codes/dFC_assessment.py create mode 100644 BIC_codes/functions/__init__.py create mode 100755 BIC_codes/functions/dFC_funcs.py create mode 100644 BIC_codes/main.py create mode 100644 BIC_codes/post_analysis.py create mode 100644 BIC_codes/test_dFC.py create mode 100644 README.md create mode 100644 demo.ipynb create mode 100755 sampleDATA/100206_Rest1_LR/ROI_data_Gordon_333_surf.mat create mode 100755 sampleDATA/100307_Rest1_LR/ROI_data_Gordon_333_surf.mat create mode 100755 sampleDATA/Gordon333_Key.txt create mode 100644 sampleDATA/Gordon333_LOCS.mat diff --git a/BIC_codes/FCS_estimate.py b/BIC_codes/FCS_estimate.py new file mode 100644 index 0000000..4876bff --- /dev/null +++ b/BIC_codes/FCS_estimate.py @@ -0,0 +1,149 @@ +from functions.dFC_funcs import * +import numpy as np +import time +import hdf5storage +import scipy.io as sio +import os +os.environ["MKL_NUM_THREADS"] = '64' +os.environ["NUMEXPR_NUM_THREADS"] = '64' +os.environ["OMP_NUM_THREADS"] = '64' + +print('################################# CODE started running ... 
#################################') + +################################# Parameters ################################# + +###### DATA PARAMETERS ###### + +output_root = './../../../../../RESULTs/methods_implementation/' +# output_root = '/data/origami/dFC/RESULTs/methods_implementation/' +# output_root = '/Users/mte/Documents/McGill/Project/dFC/RESULTs/methods_implementation/' + +# DATA_type is either 'sample' or 'Gordon' or 'simulated' or 'ICA' +params_data_load = { \ + 'DATA_type': 'Gordon', \ + 'SESSIONs':['Rest1_LR' , 'Rest1_RL', 'Rest2_LR', 'Rest2_RL'], \ + + 'data_root_simul': './../../../../DATA/TVB data/', \ + 'data_root_sample': './sampleDATA/', \ + 'data_root_gordon': './../../../../DATA/HCP/HCP_Gordon/', \ + 'data_root_ica': './../../../../DATA/HCP/HCP_PTN1200/node_timeseries/3T_HCP1200_MSMAll_d50_ts2/' +} + +###### MEASUREMENT PARAMETERS ###### + +# W is in sec + +params_methods = { \ + # Sliding Parameters + 'W': 44, 'n_overlap': 0.5, 'sw_method':'pear_corr', 'tapered_window':True, \ + # TIME_FREQ + 'TF_method':'WTC', \ + # CLUSTERING AND DHMM + 'clstr_base_measure':'SlidingWindow', \ + # HMM + 'hmm_iter': 50, 'n_hid_states': 24, \ + # State Parameters + 'n_states': 12, 'n_subj_clstrs': 20, \ + # Parallelization Parameters + 'n_jobs': 2, 'verbose': 0, 'backend': 'loky', \ + # SESSION + 'session': 'Rest1_LR', \ + # Hyper Parameters + 'normalization': True, \ + 'num_subj': 395, \ + 'num_select_nodes': 333, \ + 'num_time_point': 1200, \ + 'Fs_ratio': 1.00, \ + 'noise_ratio': 0.00, \ + 'num_realization': 1 \ +} + +###### HYPER PARAMETERS ALTERNATIVE ###### + +MEASURES_name_lst = [ \ + 'SlidingWindow', \ + 'Time-Freq', \ + 'CAP', \ + 'ContinuousHMM', \ + 'Windowless', \ + 'Clustering', \ + 'DiscreteHMM' \ + ] + +alter_hparams = { \ + # 'session': [], \ + 'n_states': [6], \ + # 'normalization': [], \ + # 'num_subj': [5], \ + # 'num_select_nodes': [50], \ + # 'num_time_point': [500], \ + 'Fs_ratio': [0.40], \ + 'noise_ratio': [2.00], \ + # 'num_realization': [] \ + } + +###### dFC ANALYZER PARAMETERS ###### + +params_dFC_analyzer = { \ + # VISUALIZATION + 'vis_TR_idx': list(range(10, 20, 1)),'save_image': True, 'output_root': output_root, \ + # Parallelization Parameters + 'n_jobs': 8, 'verbose': 0, 'backend': 'loky' \ +} + + +################################# LOAD DATA ################################# + +data_loader = DATA_LOADER(**params_data_load) +BOLD = data_loader.load() + +################################# Visualize BOLD ################################# + +# for session in BOLD: +# BOLD[session].visualize(start_time=0, end_time=50, nodes_lst=list(range(10)), \ +# save_image=params_dFC_analyzer['save_image'], output_root=output_root+'BOLD_signal_'+session) + +################################# Measures of dFC ################################# + +dFC_analyzer = DFC_ANALYZER( \ + analysis_name='reproducibility assessment', \ + **params_dFC_analyzer \ +) + +MEASURES_lst = dFC_analyzer.measures_initializer( \ + MEASURES_name_lst, \ + params_methods, \ + alter_hparams \ + ) + +tic = time.time() +print('Measurement Started ...') + +################################# estimate FCS ################################# + +task_id = int(os.getenv("SGE_TASK_ID")) +MEASURE_id = task_id-1 # SGE_TASK_ID starts from 1 not 0 + + +if MEASURE_id >= len(MEASURES_lst): + print("MEASURE_id out of MEASURES_lst ") +else: + measure = MEASURES_lst[MEASURE_id] + + print("FCS estimation started...") + + time_series = BOLD[measure.params['session']] + if measure.is_state_based: + 
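+        # note: only state-based measures (e.g. CAP, HMM and clustering methods)
+        # have group-level FCSs to fit; for data-driven measures, estimate_FCS
+        # is a no-op in the dFC base class, so they are skipped here.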
measure.estimate_FCS(time_series=time_series) + + # dFC_analyzer.estimate_group_FCS(time_series_dict=BOLD) + print("FCS estimation done.") + + print('Measurement required %0.3f seconds.' % (time.time() - tic, )) + + # Save + np.save('./fitted_MEASURES/MEASURE_'+str(MEASURE_id)+'.npy', measure) + np.save('./dFC_analyzer.npy', dFC_analyzer) + np.save('./data_loader.npy', data_loader) + +################################################################################# \ No newline at end of file diff --git a/BIC_codes/dFC_assessment.py b/BIC_codes/dFC_assessment.py new file mode 100644 index 0000000..941f691 --- /dev/null +++ b/BIC_codes/dFC_assessment.py @@ -0,0 +1,172 @@ +from functions.dFC_funcs import * +import numpy as np +import time +import hdf5storage +import scipy.io as sio +import os +os.environ["MKL_NUM_THREADS"] = '64' +os.environ["NUMEXPR_NUM_THREADS"] = '64' +os.environ["OMP_NUM_THREADS"] = '64' + +print('################################# subject-level dFC assessment CODE started running ... #################################') + +################################# Parameters ################################# + +# subj_id = '100206' + +###### DATA PARAMETERS ###### + +output_root = './../../../../../RESULTs/methods_implementation/' +# output_root = '/data/origami/dFC/RESULTs/methods_implementation/' +# output_root = '/Users/mte/Documents/McGill/Project/dFC/RESULTs/methods_implementation/' + +################################# LOAD ################################# + +dFC_analyzer = np.load('./dFC_analyzer.npy',allow_pickle='TRUE').item() +data_loader = np.load('./data_loader.npy',allow_pickle='TRUE').item() + +################################# LOAD FIT MEASURES ################################# + +if dFC_analyzer.MEASURES_fit_lst==[]: + ALL_RECORDS = os.listdir('./fitted_MEASURES/') + ALL_RECORDS = [i for i in ALL_RECORDS if 'MEASURE' in i] + ALL_RECORDS.sort() + MEASURES_fit_lst = list() + for s in ALL_RECORDS: + fit_measure = np.load('./fitted_MEASURES/'+s,allow_pickle='TRUE').item() + MEASURES_fit_lst.append(fit_measure) + dFC_analyzer.set_MEASURES_fit_lst(MEASURES_fit_lst) + print('fitted MEASURES loaded ...') + # np.save('./dFC_analyzer.npy', dFC_analyzer) + +################################# LOAD DATA ################################# + +task_id = int(os.getenv("SGE_TASK_ID")) +subj_id = data_loader.SUBJECTS[task_id-1] # SGE_TASK_ID starts from 1 not 0 + +BOLD = data_loader.load(subj_id2load=subj_id) + +################################# dFC ASSESSMENT ################################# + +tic = time.time() +print('Measurement Started ...') + +print("dFCM estimation started...") +dFCM_dict = dFC_analyzer.subj_lvl_dFC_assess(time_series_dict=BOLD) +# SUBJ_output = dFC_analyzer.group_dFCM_assess(time_series_dict=BOLD) +print("dFCM estimation done.") + +print('Measurement required %0.3f seconds.' 
% (time.time() - tic, )) + + +################################# POST ANALYSIS ################################# + +SUBJ_output = {} + +dFCM_lst = dFCM_dict['dFCM_lst'] + +# Save dFC samples +common_TRs = TR_intersection(dFCM_lst) +dFCM_sample_dict = {} +dFCM_sample_dict['common_TRs'] = common_TRs +dFCM_sample_dict['dFCM'] = {} +for i, dFCM in enumerate(dFCM_lst): + dFCM_sample = dFCM.get_dFC_mat(TRs=common_TRs) + dFCM_sample_dict['dFCM'][str(i)] = {} + dFCM_sample_dict['dFCM'][str(i)]['mat'] = dFCM_sample + dFCM_sample_dict['dFCM'][str(i)]['info'] = dFCM.measure.info +np.save('./dFC_samples/SUBJ_'+str(subj_id)+'_dFC.npy', dFCM_sample_dict) + +########################## DEFAULT VALUES ####################### + +param_dict = dFC_analyzer.params_methods +analysis_name_lst = [ \ + 'corr_mat', \ + 'dFC_distance', \ + 'dFC_distance_var', \ + 'FO', \ + 'CO', \ + 'TP', \ + 'trans_freq' \ + ] +dFCM_lst2check = filter_dFCM_lst(dFCM_lst, **param_dict) +SUBJ_output['default_values'] = dFC_analyzer.post_analysis( \ + dFCM_lst=dFCM_lst2check, \ + analysis_name_lst=analysis_name_lst \ + ) + +########################## 6_states ####################### + +param_dict = {'n_states': [6], 'is_state_based': [True]} +analysis_name_lst = [ \ + 'corr_mat', \ + 'dFC_distance', \ + 'dFC_distance_var', \ + 'FO', \ + 'CO', \ + 'TP', \ + 'trans_freq' \ + ] +dFCM_lst2check = filter_dFCM_lst(dFCM_lst, **param_dict) +SUBJ_output['6_states'] = dFC_analyzer.post_analysis( \ + dFCM_lst=dFCM_lst2check, \ + analysis_name_lst=analysis_name_lst \ + ) + +########################## SlidingWindow_100_nodes ####################### + +param_dict = {'measure_name': ['SlidingWindow'], 'num_select_nodes': [100]} +analysis_name_lst = [ \ + 'corr_mat', \ + 'dFC_distance', \ + 'dFC_distance_var', \ + 'FO', \ + 'CO', \ + 'TP', \ + 'trans_freq' \ + ] +dFCM_lst2check = filter_dFCM_lst(dFCM_lst, **param_dict) +SUBJ_output['SlidingWindow_100_nodes'] = dFC_analyzer.post_analysis( \ + dFCM_lst=dFCM_lst2check, \ + analysis_name_lst=analysis_name_lst \ + ) + +########################## Fs_ratio_0.5 ####################### + +param_dict = {'Fs_ratio': [0.5]} +analysis_name_lst = [ \ + 'corr_mat', \ + 'dFC_distance', \ + 'dFC_distance_var', \ + 'FO', \ + 'CO', \ + 'TP', \ + 'trans_freq' \ + ] +dFCM_lst2check = filter_dFCM_lst(dFCM_lst, **param_dict) +SUBJ_output['Fs_ratio_0.5'] = dFC_analyzer.post_analysis( \ + dFCM_lst=dFCM_lst2check, \ + analysis_name_lst=analysis_name_lst \ + ) + +########################## noise_ratio_1 ####################### + +param_dict = {'noise_ratio': [1.0]} +analysis_name_lst = [ \ + 'corr_mat', \ + 'dFC_distance', \ + 'dFC_distance_var', \ + 'FO', \ + 'CO', \ + 'TP', \ + 'trans_freq' \ + ] +dFCM_lst2check = filter_dFCM_lst(dFCM_lst, **param_dict) +SUBJ_output['noise_ratio_1'] = dFC_analyzer.post_analysis( \ + dFCM_lst=dFCM_lst2check, \ + analysis_name_lst=analysis_name_lst \ + ) + +# Save +np.save('./dFC_assessed/SUBJ_'+str(subj_id)+'_output.npy', SUBJ_output) +################################################################################# \ No newline at end of file diff --git a/BIC_codes/functions/__init__.py b/BIC_codes/functions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/BIC_codes/functions/dFC_funcs.py b/BIC_codes/functions/dFC_funcs.py new file mode 100755 index 0000000..42e3923 --- /dev/null +++ b/BIC_codes/functions/dFC_funcs.py @@ -0,0 +1,3472 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Sun Jun 13 22:34:49 2021 + +@author: mte +""" + +import numpy as np +from 
scipy import signal +from copy import deepcopy +import matplotlib.pyplot as plt +import networkx as nx +from scipy.spatial import distance +from joblib import Parallel, delayed +import os +import time +import hdf5storage +import scipy.io as sio +from sklearn.preprocessing import power_transform + +# ########## bundled brain graph visualizer ########## + +# import pandas as pd +# import panel as pn +# import datashader as ds +# import datashader.transfer_functions as tf +# from datashader.layout import random_layout, circular_layout, forceatlas2_layout +# from datashader.bundling import connect_edges, hammer_bundle +# from datashader import utils +# import holoviews as hv +# from itertools import chain + +# import warnings + +# warnings.simplefilter('ignore') + +################################# Parameters #################################### + +fig_dpi = 120 +fig_bbox_inches = 'tight' +fig_pad = 0.1 + +################################# Other Functions #################################### + +# test +def get_subj_ts_dict(time_series_dict, subjs_id): + subj_ts_dict = {} + for session in time_series_dict: + subj_ts_dict[session] = time_series_dict[session].get_subj_ts(subjs_id=subjs_id) + return subj_ts_dict + +# test +def filter_dFCM_lst(dFCM_lst, **param_dict): + dFCM_lst2check = list() + for dFCM in dFCM_lst: + if dFCM.measure.param_match(**param_dict): + dFCM_lst2check.append(dFCM) + return dFCM_lst2check + +# test +def normalizeAdjacency(W): + """ + NormalizeAdjacency: Computes the [0, 1]-normalized adjacency matrix + + Input: + + W (np.array): adjacency matrix + + Output: + + W_norm (np.array): [0, 1] normalized adjacency matrix + """ + W_norm = W - np.min(W) + W_norm = np.divide(W_norm, np.max(W_norm)) + return W_norm + +# test +def normalized_euc_dist(x, y): + # https://stats.stackexchange.com/questions/136232/definition-of-normalized-euclidean-distance#:~:text=The%20normalized%20squared%20euclidean%20distance,not%20related%20to%20Mahalanobis%20distance. 
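+    # in closed form: d(x, y) = 0.5 * ||x_c - y_c||^2 / (||x_c||^2 + ||y_c||^2),
+    # where x_c = x - mean(x) and y_c = y - mean(y); d lies in [0, 1] and
+    # equals 0 only when the centered signals coincide.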
+
+    if np.linalg.norm(x-np.mean(x))**2==0 and np.linalg.norm(y-np.mean(y))**2==0:
+        return 0
+    return 0.5*((np.linalg.norm((x-np.mean(x)) - (y-np.mean(y)))**2)/(np.linalg.norm(x-np.mean(x))**2 + np.linalg.norm(y-np.mean(y))**2))
+
+def calc_ECM(A):
+    """
+    calc_ECM: Computes Eigenvector Centrality Mapping (ECM)
+    of adjacency matrix A
+
+    Input:
+
+        A (np.array): adjacency matrix
+
+    Output:
+
+        centrality (np.array): ECM vector
+    """
+    G = nx.from_numpy_matrix(A)
+    G.remove_edges_from(nx.selfloop_edges(G))
+    # G = G.to_undirected()
+    centrality = nx.eigenvector_centrality(G, weight='weight')
+    # centrality = nx.pagerank(G, alpha=0.85)
+
+    centrality = [centrality[node] for node in centrality]
+
+    return centrality
+
+# test
+def zip_name(name_lst):
+    # zip measure names
+    new_name_lst = list()
+    for name in name_lst:
+        if 'Clustering' in name:
+            new_name = 'SWC' + name[name.rfind('_'):]
+        if 'CAP' in name:
+            new_name = 'CAP' + name[name.rfind('_'):]
+        if 'ContinuousHMM' in name:
+            new_name = 'CHMM' + name[name.rfind('_'):]
+        if 'Windowless' in name:
+            new_name = 'WL' + name[name.rfind('_'):]
+        if 'DiscreteHMM' in name:
+            new_name = 'DHMM' + name[name.rfind('_'):]
+        if 'Time-Freq' in name:
+            new_name = 'TF' + name[name.rfind('_'):]
+        if 'SlidingWindow' in name:
+            new_name = 'SW' + name[name.rfind('_'):]
+        new_name_lst.append(new_name)
+    return new_name_lst
+
+# test
+# pear_corr problem
+def unzip_name(name):
+    # unzip measure names
+    flag=False
+    if not '_' in name:
+        name = name + '_'
+        flag=True
+    if 'CAP' in name:
+        new_name = 'CAP' + name[name.rfind('_'):]
+    if 'SWC' in name:
+        new_name = 'Clustering' + name[name.rfind('_'):]
+    if 'CHMM' in name:
+        new_name = 'ContinuousHMM' + name[name.rfind('_'):]
+    if 'WL' in name:
+        new_name = 'Windowless' + name[name.rfind('_'):]
+    if 'DHMM' in name:
+        new_name = 'DiscreteHMM' + name[name.rfind('_'):]
+    if 'TF' in name:
+        new_name = 'Time-Freq' + name[name.rfind('_'):]
+    if 'SW_' in name:
+        new_name = 'SlidingWindow' + name[name.rfind('_'):]
+    if flag:
+        new_name = new_name[:-1]
+    return new_name
+
+#test
+def dFC_mat2vec(C_t):
+    '''
+    C_t must be an array of matrices or a single matrix
+    diagonal values not included.
if you want to include + them set k=0 + ''' + if len(C_t.shape)==2: + assert C_t.shape[0]==C_t.shape[1],\ + 'C is not a square matrix' + return C_t[np.triu_indices(C_t.shape[1], k=1)] + + F = list() + for t in range(C_t.shape[0]): + C = C_t[t, : , :] + assert C.shape[0]==C.shape[1],\ + 'C is not a square matrix' + F.append(C[np.triu_indices(C_t.shape[1], k=1)]) + + F = np.array(F) + return F + +#test +def dFC_vec2mat(F, N): + ''' + diagonal values are set to 1.0 + F shape is (observations, features) + ''' + C = list() + iu = np.triu_indices(N, k=1) + for i in range(F.shape[0]): + K = np.zeros((N, N)) + K[iu] = F[i,:] + K = K + K.T + K = K + np.eye(N) + C.append(K) + C = np.array(C) + return C + +# test +def common_subj_lst(time_series_dict): + SUBJECTs = None + for session in time_series_dict: + if SUBJECTs is None: + SUBJECTs = time_series_dict[session].subj_id_lst + else: + SUBJECTs = intersection(SUBJECTs, time_series_dict[session].subj_id_lst) + return SUBJECTs + +def intersection(lst1, lst2): # input is a list + lst3 = [value for value in lst1 if value in lst2] + return lst3 + +def TR_intersection(dFCM_lst): # input is a list of dFCM objs + TRs_lst_old = dFCM_lst[0].TR_array + for dFCM in dFCM_lst: + TRs_lst_new = intersection(TRs_lst_old, dFCM.TR_array) + TRs_lst_old = TRs_lst_new + TRs_lst_old.sort() + if len(TRs_lst_old)==0: + print('No TR intersection.') + return TRs_lst_old + +def dFC_dict_slice(data, idx_lst): + data_sliced = {} + for i, k in enumerate(data): + if i in idx_lst: + data_sliced[k] = data[k] + return data_sliced + +def visualize_state_TC(TC_lst, \ + TRs, \ + state_lst, \ + TC_name_lst, \ + title='', \ + save_image=None, output_root=None\ + ): + + color_lst = ['k', 'b', 'g', 'r'] + + if 'on' in state_lst and 'off' in state_lst: + ticks = range(2) + else: + ticks = range(1, len(state_lst)+1) + + plt.figure(figsize=(25, 5)) + for i, TC in enumerate(TC_lst): + plt.plot(TRs, TC, color_lst[i], linewidth=2) + plt.xlabel('TR') + plt.yticks(ticks=ticks, labels=state_lst) + plt.legend(TC_name_lst) + plt.title(title) + if save_image: + folder = output_root[:output_root.rfind('/')] + if not os.path.exists(folder): + os.makedirs(folder) + plt.savefig(output_root+'.png', \ + dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad \ + ) + plt.close() + else: + plt.show() + + return + +def visualize_conn_mat(data, title='', \ + name_lst_key=None, mat_key=None, \ + cmap='viridis',\ + normalize=False,\ + disp_diag=True,\ + save_image=False, output_root=None, \ + fix_lim=True, lim_val=1.0 \ + ): + + ''' + - name_lst_key can be a list of names or the key to list of names + - data must be a dict of correlation/connectivity matrices + sample: + Suptitle1 + corr_mat + 0.00 0.31 0.76 + 0.31 0.00 0.43 + 0.76 0.43 0.00 + measure_lst + ContinuousHMM + Windowless + Clustering_pear_corr + Suptitle1 + corr_mat + 0.00 0.32 0.76 + 0.32 0.00 0.45 + 0.76 0.45 0.00 + measure_lst + ContinuousHMM + Windowless + Clustering_pear_corr + ''' + + if name_lst_key is None: + fig_width = 25*(len(data)/10) + else: + fig_width = 35*(len(data)/10) + 4 + fig_height = 10 + + fig, axs = plt.subplots(1, len(data), figsize=(fig_width, fig_height), \ + facecolor='w', edgecolor='k') + + if not type(axs) is np.ndarray: + axs = np.array([axs]) + + fig.suptitle(title) #, fontsize=20, size=20 + + axs = axs.ravel() + + for i, key in enumerate(data): + + name_lst = None + if not name_lst_key is None: + if type(name_lst_key) is str: + name_lst = data[key][name_lst_key] + if type(name_lst_key) is list: + name_lst = 
name_lst_key + + if mat_key is None: + C = data[key] + else: + C = data[key][mat_key] + + # C = np.abs(C) # ?????? should we do this? + + if normalize: + C = dFC_mat_normalize(C[None,:,:], global_normalization=False, threshold=0.0)[0] + + if not disp_diag: + C = np.multiply(C, 1-np.eye(len(C))) + C = C + np.mean(C.flatten()) * np.eye(len(C)) + + if np.any(C<0): # ?????? should we do this? + V_MIN = -1 + V_MAX = 1 + else: # ?????? should we do this? + V_MIN = 0 + V_MAX = lim_val + + if not fix_lim: + V_MAX = np.max(C) + V_MIN = np.min(C) + + if name_lst is None: + axs[i].set_axis_off() + + im = axs[i].imshow(C, interpolation='nearest', aspect='equal', cmap=cmap, # 'viridis' or 'jet' + vmin=V_MIN, vmax=V_MAX) + + if not name_lst is None: + axs[i].set_xticks(np.arange(len(name_lst))) + axs[i].set_yticks(np.arange(len(name_lst))) + axs[i].set_xticklabels(name_lst, rotation=90, fontsize=9) + axs[i].set_yticklabels(name_lst, fontsize=9) + axs[i].set_title(key) + + fig.subplots_adjust( + bottom=0.1, \ + top=1.5, \ + left=0.1, \ + right=0.9, + # wspace=0.02, \ + # hspace=0.02\ + ) + + if not name_lst is None: + fig.subplots_adjust( + wspace=0.85 + ) + + if name_lst is None: + cb_ax = fig.add_axes([0.91, 0.75, 0.007, 0.1]) + else: + cb_ax = fig.add_axes([0.91, 0.75, 0.02, 0.1]) + cbar = fig.colorbar(im, cax=cb_ax, shrink=0.8) # shrink=0.8?? + + # # set the colorbar ticks and tick labels + # cbar.set_ticks(np.arange(0, 1.1, 0.5)) + # cbar.set_ticklabels(['0', '0.5', '1']) + + if save_image: + folder = output_root[:output_root.rfind('/')] + if not os.path.exists(folder): + os.makedirs(folder) + plt.savefig(output_root+'.png', \ + dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad \ + ) + plt.close() + else: + plt.show() + +''' +########## bundled brain graph visualizer ########## + +cvsopts = dict(plot_height=400, plot_width=400) + +def thresh_G(G, threshold=None): + + G_copy = deepcopy(G) + + if threshold==None: + sig_edges = find_sig_edges(G_copy, min_num_edge=0) + threshold = G.edges()[sig_edges[-1]]['weight'] + else: + if threshold > 1: + labels = [d["weight"] for (u, v, d) in G_copy.edges(data=True)] + labels.sort() + threshold = labels[-1*threshold] + # sig_edges = find_sig_edges(G_copy, min_num_edge=threshold) + # threshold = G.edges()[sig_edges[-1]]['weight'] + + ebunch = [(u, v) for u, v, d in G_copy.edges(data=True) if np.abs(d['weight']) < threshold] + G_copy.remove_edges_from(ebunch) + + return G_copy + +def nodesplot(nodes, name=None, canvas=None, cat=None): + canvas = ds.Canvas(**cvsopts) if canvas is None else canvas + # aggregator=None if cat is None else ds.count_cat(cat) + # agg=canvas.points(nodes,'x','y',aggregator) + aggc = canvas.points(nodes, 'x', 'y', ds.count_cat('cat')) #ds.by('cat', ds.count()) + + color_key = dict(cat_normal='#FF3333', cat_sig='#00FF00') + + return tf.spread(tf.shade(aggc, color_key=color_key), px=4, name=name) + + +def edgesplot(edges, name=None, canvas=None): + canvas = ds.Canvas(**cvsopts) if canvas is None else canvas + return tf.shade(canvas.line(edges, 'x','y', agg=ds.count()), name=name) + +def graphplot(nodes, edges, name="", canvas=None, cat=None): + + if canvas is None: + xr = nodes.x.min(), nodes.x.max() + yr = nodes.y.min(), nodes.y.max() + canvas = ds.Canvas(x_range=xr, y_range=yr, **cvsopts) + + np = nodesplot(nodes, name + " nodes", canvas, cat) + ep = edgesplot(edges, name + " edges", canvas) + return tf.stack(ep, np, how="over", name=name) + +def ng(graph,name): + graph.name = name + return graph + +def nx_layout(graph, 
view_degree=0, threshold=0): + # layout = nx.circular_layout(graph) + + # Get node positions + pos = nx.get_node_attributes(graph, 'pos') + for key in pos: + if view_degree==0: + pos[key] = pos[key][:2] + if view_degree==1: + pos[key] = pos[key][1:3] + if view_degree==2: + pos[key] = pos[key][[0, 2]] + + # layout = pos + + cat = list() + for key in graph.nodes(): + cat.append('cat_normal') + # if key in find_sig_nodes(graph): # + # cat.append( 'cat_sig') + # else: + # cat.append('cat_normal') + + data = [[node]+pos[node].tolist()+[cat[i]] for i, node in enumerate(graph.nodes)] + + nodes = pd.DataFrame(data, columns=['id', 'x', 'y','cat']) + nodes.set_index('id', inplace=True) + nodes["cat"]=nodes["cat"].astype("category") + + graph_copy = thresh_G(graph, threshold=threshold) + + edges = pd.DataFrame(list(graph_copy.edges), columns=['source', 'target']) + return nodes, edges + +def nx_plot(graph, name="", view_degree=0, threshold=0): + # print(graph.name, len(graph.edges)) + nodes, edges = nx_layout(graph, view_degree=view_degree, threshold=threshold) + + direct = connect_edges(nodes, edges) + bundled_bw005 = hammer_bundle(nodes, edges) + bundled_bw030 = hammer_bundle(nodes, edges, initial_bandwidth=0.30) + bundled_bw100 = hammer_bundle(nodes, edges, initial_bandwidth=1) + + return [graphplot(nodes, direct, graph.name, cat=None), + graphplot(nodes, bundled_bw005, "Bundled bw=0.05", cat=None), + graphplot(nodes, bundled_bw030, "Bundled bw=0.30", cat=None), + graphplot(nodes, bundled_bw100, "Bundled bw=1.00", cat=None)] + +def batch_Adj2Net(FCS, nodes_info, is_digraph=False): + + np.fill_diagonal(FCS, 0) + if is_digraph: + G = nx.from_numpy_matrix(FCS, create_using=nx.DiGraph) + else: + G = nx.from_numpy_matrix(FCS) + + mapping = {} + for i, node_info in enumerate(nodes_info): + mapping[i] = node_info[4] + G = nx.relabel_nodes(G, mapping) + + return G + +def set_locs_G(G, locs): + + G_copy = deepcopy(G) + + pos = nx.circular_layout(G_copy) + + for i, key in enumerate(pos): + pos[key] = locs[i][0] + + nx.set_node_attributes(G_copy, pos, "pos") + + + return G_copy + +def visulize_brain_graph(FCS, nodes_info, locs, num_edges2show): + G = batch_Adj2Net(FCS=FCS, nodes_info=nodes_info, is_digraph=False) + G = set_locs_G(G, locs=locs) + plots = [nx_plot(ng(G, name="dFC"), view_degree=0, threshold=num_edges2show) ] + + return plots[0][0] + + +############################## +''' + +def dFC_dict_normalize(D, global_normalization=False, threshold=0.0): + + C = list() + for key in D: + C.append(D[key]) + C = np.array(C) + + C_z = dFC_mat_normalize(C, \ + global_normalization=global_normalization, \ + threshold=threshold \ + ) + + D_z = {} + for i, key in enumerate(D): + D_z[key] = C_z[i,:,:] + + return D_z + +def dFC_mat_normalize(C_t, global_normalization=False, threshold=0.0): + + # threshold is ratio of connections wanted to be zero + C_t_z = deepcopy(C_t) + if len(C_t_z.shape)<3: + C_t_z = np.expand_dims(C_t_z, axis=0) + + if global_normalization: + + # transform the whole abs(dFC mat) to [0, 1] + + signs = np.sign(C_t_z) + C_t_z = np.abs(C_t_z) + + miN = list() + for i in range(C_t_z.shape[0]): + slice = C_t_z[i,:,:] + slice_non_diag = slice[np.where(~np.eye(slice.shape[0],dtype=bool))] + miN.append(np.min(slice_non_diag)) + + C_t_z = C_t_z - np.min(miN) + + maX = list() + for i in range(C_t_z.shape[0]): + slice = C_t_z[i,:,:] + slice_non_diag = slice[np.where(~np.eye(slice.shape[0],dtype=bool))] + maX.append(np.max(slice_non_diag)) + + if np.max(maX) != 0: + C_t_z = np.divide(C_t_z, np.max(maX)) + 
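+        # e.g. with threshold=0.2, the smallest ~20% of the rescaled values are
+        # zeroed out below; threshold=0.0 leaves all connections untouched.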
+ # thresholding + d = deepcopy(np.ravel(C_t_z)) + d.sort() + new_threshold = d[int(threshold*len(d))] + C_t_z = np.multiply(C_t_z, (C_t_z>=new_threshold)) + C_t_z = np.multiply(C_t_z, signs) + + else: + + # transform abs of each time slice to [0, 1] + + signs = np.sign(C_t_z) + C_t_z = np.abs(C_t_z) + + for i in range(C_t_z.shape[0]): + slice = C_t_z[i,:,:] + slice_non_diag = slice[np.where(~np.eye(slice.shape[0],dtype=bool))] + slice = slice - np.min(slice_non_diag) + slice_non_diag = slice[np.where(~np.eye(slice.shape[0],dtype=bool))] + if np.max(slice_non_diag) != 0: + slice = np.divide(slice, np.max(slice_non_diag)) + + # thresholding + d = deepcopy(np.ravel(slice)) + d.sort() + new_threshold = d[int(threshold*len(d))] + slice = np.multiply(slice, (slice>=new_threshold)) + + C_t_z[i,:,:] = slice + + C_t_z = np.multiply(C_t_z, signs) + + # removing self connections + for i in range(C_t_z.shape[1]): + C_t_z[:, i, i] = np.mean(C_t_z) # ????????????????? + + return C_t_z + +def print_mat(mat, s=0): + if len(mat.shape)==1: + mat = np.expand_dims(mat, axis=0) + for i in mat: + print('\t'*s, end=" ") + for j in i: + print("{:.2f}".format(j), end=" ") + print() + +def print_dict(t, s=0): + if not isinstance(t,dict) and not isinstance(t,list): + if isinstance(t,np.ndarray): + print_mat(t, s) + else: + if isinstance(t,float): + print('\t'*s+"{:.2f}".format(t)) + else: + print('\t'*s+str(t)) + else: + for key in t: + print('\t'*s+str(key)) + if not isinstance(t,list): + print_dict(t[key],s+1) + +############################# dFC Analyzer class ################################ + +""" + +todo: +- +""" + +class DFC_ANALYZER: + # if self.n_jobs is None => no parallelization + + def __init__(self, analysis_name='', **params): + + self.analysis_name = analysis_name + + self.params = params + if not 'vis_TR_idx' in self.params: + self.params['vis_TR_idx'] = None + if not 'save_image' in self.params: + self.params['save_image'] = False + if not 'output_root' in self.params: + self.params['output_root'] = None + if not 'n_jobs' in self.params: + self.params['n_jobs'] = -1 + if not 'verbose' in self.params: + self.params['verbose'] = 1 + if not 'backend' in self.params: + self.params['backend'] = 'loky' + + self.MEASURES_lst_ = None + self.MEASURES_fit_lst_ = [] + self.MEASURES_name_lst = [] + self.params_methods = {} + self.alter_hparams = {} + + @property + def MEASURES_lst(self): + assert not self.MEASURES_lst_ is None, \ + 'first set the MEASURES_lst!' + return self.MEASURES_lst_ + + @property + def MEASURES_fit_lst(self): + return self.MEASURES_fit_lst_ + + def set_MEASURES_lst(self, MEASURES_lst): + self.MEASURES_lst_ = MEASURES_lst + + def set_MEASURES_fit_lst(self, MEASURES_fit_lst): + self.MEASURES_fit_lst_ = MEASURES_fit_lst + + def measures_initializer(self, MEASURES_name_lst, params_methods, alter_hparams): + + ''' + - this will test values in hyper_params other than + values already in self.params. 
values in self.params + will be considered the reference + sample: + hyper_params = { \ + 'n_states': [6, 12, 16], \ + 'normalization': [True], \ + 'num_subj': [50, 100, 395], \ + 'num_select_nodes': [50, 100, 333], \ + 'num_time_point': [500, 800, 1200], \ + 'Fs_ratio': [0.50, 1.00, 1.50], \ + 'noise_ratio': [0.00, 0.50, 1.00], \ + 'num_realization': [1, 2, 3], \ + } + + MEASURES_name_lst = ( \ + 'SlidingWindow', \ + 'Time-Freq', \ + 'CAP', \ + 'ContinuousHMM', \ + 'Windowless', \ + 'Clustering', \ + 'DiscreteHMM' \ + ) + ''' + + self.MEASURES_name_lst = MEASURES_name_lst + self.params_methods = params_methods + self.alter_hparams = alter_hparams + + # a list of MEASURES with default parameter values + MEASURES_lst = self.create_measure_obj(MEASURES_name_lst=MEASURES_name_lst, **params_methods) + + # adding MEASURES with alternative parameter values + for hyper_param in alter_hparams: + params = deepcopy(params_methods) + for value in alter_hparams[hyper_param]: + params[hyper_param] = value + new_MEASURES = self.create_measure_obj(MEASURES_name_lst=MEASURES_name_lst, **params) + for new_measure in new_MEASURES: + flag=0 + for MEASURE in MEASURES_lst: + if new_measure.issame(MEASURE): + flag=1 + if flag==0: + MEASURES_lst.append(new_measure) + + return MEASURES_lst + + def create_measure_obj(self, MEASURES_name_lst, **params): + + MEASURES_lst = list() + for MEASURES_name in MEASURES_name_lst: + + ###### CAP ###### + if MEASURES_name=='CAP': + measure = CAP(**params) + + ###### CONTINUOUS HMM ###### + if MEASURES_name=='ContinuousHMM': + measure = HMM_CONT(**params) + + ###### WINDOW_LESS ###### + if MEASURES_name=='Windowless': + measure = WINDOWLESS(**params) + + ###### SLIDING WINDOW ###### + if MEASURES_name=='SlidingWindow': + measure = SLIDING_WINDOW(**params) + + ###### TIME FREQUENCY ###### + if MEASURES_name=='Time-Freq': + measure = TIME_FREQ(**params) + + ###### SLIDING WINDOW + CLUSTERING ###### + if MEASURES_name=='Clustering': + measure = SLIDING_WINDOW_CLUSTR(**params) + + ###### DISCRETE HMM ###### + if MEASURES_name=='DiscreteHMM': + measure = HMM_DISC(**params) + + MEASURES_lst.append(measure) + + return MEASURES_lst + + def SB_MEASURES_lst(self, MEASURES_lst): # returns state_based measures + SB_MEASURES = list() + for measure in MEASURES_lst: + if measure.is_state_based: + SB_MEASURES.append(measure) + return SB_MEASURES + + def DD_MEASURES_lst(self, MEASURES_lst): # returns data_driven measures + DD_MEASURES = list() + for measure in MEASURES_lst: + if not measure.is_state_based: + DD_MEASURES.append(measure) + return DD_MEASURES + + ##################### MEASURE CHARACTERISTICS ###################### + + def dFCM_var(self, MEASURES_dFCM): + + MEASURES_dFC_var = {} + for measure in MEASURES_dFCM: + dFC_mat = MEASURES_dFCM[measure].get_dFC_mat(TRs = MEASURES_dFCM[measure].TR_array) + V = np.var(dFC_mat, axis=0) + MEASURES_dFC_var[measure] = V + return MEASURES_dFC_var + + ##################### POST ANALYSIS ###################### + + ##################### FCS ESTIMATION ###################### + + def estimate_group_FCS(self, time_series_dict): + + # time_series_dict is a dict of time_series + + for session in time_series_dict: + + time_series = time_series_dict[session] + SB_MEASURES_lst = self.SB_MEASURES_lst(self.MEASURES_lst) + if self.params['n_jobs'] is None: + SB_MEASURES_lst_NEW = list() + for measure in SB_MEASURES_lst: + SB_MEASURES_lst_NEW.append( \ + measure.estimate_FCS(time_series=time_series) \ + ) + else: + SB_MEASURES_lst_NEW = Parallel( \ + 
n_jobs=self.params['n_jobs'], verbose=self.params['verbose'], backend=self.params['backend'])( \ + delayed(measure.estimate_FCS)(time_series=time_series) \ + for measure in SB_MEASURES_lst) + self.MEASURES_fit_lst_[session] = self.DD_MEASURES_lst(self.MEASURES_lst) + SB_MEASURES_lst_NEW + + ##################### dFCM ASSESSMENT ###################### + + def group_dFCM_assess(self, time_series_dict): + + # time_series_dict is a dict of time_series + + SUBJ_s_dFCM_dict = {} + + SUBJECTs = common_subj_lst(time_series_dict) + + if self.params['n_jobs'] is None: + OUT = list() + for subject in SUBJECTs: + OUT.append( \ + self.subj_lvl_dFC_assess( \ + time_series_dict=get_subj_ts_dict(time_series_dict, subjs_id=subject), \ + )) + else: + OUT = Parallel( \ + n_jobs=self.params['n_jobs'], \ + verbose=self.params['verbose'], \ + backend=self.params['backend'])( \ + delayed(self.subj_lvl_dFC_assess)( \ + time_series_dict=get_subj_ts_dict(time_series_dict, subjs_id=subject), \ + ) \ + for subject in SUBJECTs) + + return OUT + + def subj_lvl_dFC_assess(self, time_series_dict): + + # time_series_dict is a dict of time_series + + dFCM_dict = {} + # dFC_corr_assess_dict = {} + + if self.params['n_jobs'] is None: + dFCM_lst = list() + for measure in self.MEASURES_fit_lst_: + dFCM_lst.append( \ + measure.estimate_dFCM(time_series=time_series_dict[measure.params['session']]) \ + ) + else: + dFCM_lst = Parallel( \ + n_jobs=self.params['n_jobs'], verbose=self.params['verbose'], backend=self.params['backend'])( \ + delayed(measure.estimate_dFCM)(time_series=time_series_dict[measure.params['session']]) \ + for measure in self.MEASURES_fit_lst_) + + # for session in time_series_dict: + # dFCM_dict[session] = {} + # time_series = time_series_dict[session] + # if self.params['n_jobs'] is None: + # dFCM_lst = list() + # for measure in self.MEASURES_fit_lst_[session]: + # dFCM_lst.append( \ + # measure.estimate_dFCM(time_series=time_series) \ + # ) + # else: + # dFCM_lst = Parallel( \ + # n_jobs=self.params['n_jobs'], verbose=self.params['verbose'], backend=self.params['backend'])( \ + # delayed(measure.estimate_dFCM)(time_series=time_series) \ + # for measure in self.MEASURES_fit_lst_[session]) + + # dFC_corr_assess_dict[session] = self.dFC_corr_assess(dFCM_lst=dFCM_lst) + + dFCM_dict['dFCM_lst'] = dFCM_lst + + # SUBJ_output['dFC_corr_assess_dict'] = dFC_corr_assess_dict + + return dFCM_dict + + + ##################### dFC CHARACTERISTICS ###################### + + def dFC_corr(self, dFCM_i, dFCM_j, TRs=None): + + # returns correlation of dFC measures over time + + if TRs is None: + TRs = TR_intersection([dFCM_i, dFCM_j]) + dFC_mat_i = dFCM_i.get_dFC_mat(TRs=TRs) + dFC_mat_j = dFCM_j.get_dFC_mat(TRs=TRs) + corr = list() + for t in range(len(TRs)): + corr.append(np.corrcoef(dFC_mat2vec(dFC_mat_i[t,:,:]), dFC_mat2vec(dFC_mat_j[t,:,:]))[0,1]) + corr= np.array(corr) + return corr + + def dFCM_lst_corr(self, dFCM_lst, common_TRs=None, a=0.1): + # a is portion of the dFCs to ignore from + # the beginning and the end + + if common_TRs is None: + common_TRs = TR_intersection(dFCM_lst) + + corr_mat = np.zeros((len(dFCM_lst), len(dFCM_lst))) + for i in range(len(dFCM_lst)): + for j in range(i+1, len(dFCM_lst)): + + corr_ij = self.dFC_corr( \ + dFCM_lst[i], dFCM_lst[j], \ + TRs=common_TRs \ + ) + corr_mat[i,j] = np.mean(corr_ij[ \ + int(len(corr_ij)*a) : int(len(corr_ij)*(1-a)) \ + ]) + corr_mat[j,i] = corr_mat[i,j] + + return corr_mat + + def FO_calc(self, dFCM_lst, common_TRs=None): + + # returns, for each state the 
Fractional Occupancy (FO)
+        # see Vidaurre et al., 2017
+        # it only considers TRs in common_TRs
+
+        if common_TRs is None:
+            common_TRs = TR_intersection(dFCM_lst)
+
+        FO_list = list()
+        for dFCM in dFCM_lst:
+
+            FO = {}
+
+            if dFCM.measure.is_state_based:
+
+                state_act_dict = dFCM.state_act_dict(TRs=common_TRs)
+
+                for FCS_key in state_act_dict['state_TC']:
+                    FO[FCS_key] = np.mean(state_act_dict['state_TC'][FCS_key]['act_TC'])
+
+            FO_list.append(FO)
+
+        return FO_list
+
+    def COM_calc(self, dFCM_lst, common_TRs=None, lag=0):
+        # returns co-occurrence (CO) with specified lag, a dict with:
+        # CO['Obs_seq']
+        # CO['FCS_name']
+        # CO['COM']
+        # automatically ignores DD methods
+
+        # TODO: DOWNSAMPLING problem and ignoring between-state transitions
+
+        if common_TRs is None:
+            common_TRs = TR_intersection(dFCM_lst)
+
+        TRs_lst = list()
+        for TR in common_TRs:
+            TRs_lst.append('TR'+str(TR))
+
+        # list of FCS names
+        ############## when combining COMs check if the order of FCS names is the same
+        FCS_name_lst = list()
+        for i, dFCM in enumerate(dFCM_lst):
+            if dFCM.measure.is_state_based:
+                for FCS in dFCM.FCSs:
+                    FCS_name_lst.append('measure_'+str(i)+'_'+FCS)
+        # print(FCS_name_lst)
+
+        # building the observation sequence
+        Obs_seq = list()
+        for TR in TRs_lst:
+            Obs_vec = list()
+            for i, dFCM in enumerate(dFCM_lst):
+                if dFCM.measure.is_state_based:
+                    Obs_vec.append('measure_' + str(i) + '_' + dFCM.FCS_idx[TR])
+            Obs_seq.append(Obs_vec)
+        # print(Obs_seq)
+
+        # computing COM
+        flag = 0
+        CO = {}
+        CO['Obs_seq'] = Obs_seq
+        CO['FCS_name'] = FCS_name_lst
+        CO['COM'] = np.zeros((len(FCS_name_lst), len(FCS_name_lst)))
+        for i, TR in enumerate(TRs_lst):
+            if i>=lag:
+                for last_FCS in CO['Obs_seq'][i-lag]:
+                    for current_FCS in CO['Obs_seq'][i]:
+                        CO['COM'][CO['FCS_name'].index(last_FCS), CO['FCS_name'].index(current_FCS)] += 1
+
+        return CO
+
+
+    def transition_freq(self, dFCM_lst, common_TRs=None):
+        # returns the total number of state transitions within common_TRs -> trans_freq
+        # and the total number of state transitions regardless of common_TRs,
+        # but normalized by the total number of TRs -> trans_norm
+
+        if common_TRs is None:
+            common_TRs = TR_intersection(dFCM_lst)
+
+        TRs_lst = list()
+        for TR in common_TRs:
+            TRs_lst.append('TR'+str(TR))
+
+        trans_freq_lst = list()
+        for dFCM in dFCM_lst:
+
+            trans_freq_dict = {}
+
+            if dFCM.measure.is_state_based:
+
+                trans_freq = 0
+                last_TR = None
+                for TR in dFCM.FCS_idx:
+                    if TR in TRs_lst:
+                        if not last_TR is None:
+                            if dFCM.FCS_idx[TR]!=dFCM.FCS_idx[last_TR]:
+                                trans_freq += 1
+                        last_TR = TR
+
+                trans_freq_dict['trans_freq'] = trans_freq
+
+                trans_norm = 0
+                last_TR = None
+                for TR in dFCM.FCS_idx:
+                    if not last_TR is None:
+                        if dFCM.FCS_idx[TR]!=dFCM.FCS_idx[last_TR]:
+                            trans_norm += 1
+                    last_TR = TR
+                trans_norm = trans_norm / len(dFCM.FCS_idx)
+
+                trans_freq_dict['trans_norm'] = trans_norm
+
+            trans_freq_lst.append(trans_freq_dict)
+
+        return trans_freq_lst
+
+    def dFC_distance(self, FC_t_i, FC_t_j, metric, normalize=True):
+        '''
+        FC_t_i and FC_t_j must be an
+        array of FC matrices = (n_time, n_regions, n_regions)
+        metric options: correlation, euclidean, ECM (Eigenvector Centrality Mapping)
+        normalize option is for ECM and euclidean metrics since correlation is already
+        normalized.
+        for ECM, the input can be an array of ECM_vecs of shape (n_time, n_regions)
+        or an array of FC matrices = (n_time, n_regions, n_regions)
+        '''
+        assert len(FC_t_i)==len(FC_t_j),\
+            'the inputs must have the same number of samples'
+
+        distance_out = list()
+        for t in range(FC_t_i.shape[0]):
+
+            if metric=='correlation' or metric=='euclidean':
+                assert FC_t_i[t].shape[0]==FC_t_i[t].shape[1],\
+                    'Matrices are not square'
+                assert FC_t_j[t].shape[0]==FC_t_j[t].shape[1],\
+                    'Matrices are not square'
+
+            if metric=='correlation':
+                FC_vec_i = dFC_mat2vec(FC_t_i[t])
+                FC_vec_j = dFC_mat2vec(FC_t_j[t])
+                distance_out.append(distance.correlation(FC_vec_i, FC_vec_j))
+
+            if metric=='euclidean':
+                FC_vec_i = dFC_mat2vec(FC_t_i[t])
+                FC_vec_j = dFC_mat2vec(FC_t_j[t])
+                if normalize:
+                    distance_out.append(normalized_euc_dist(FC_vec_i, FC_vec_j))
+                else:
+                    distance_out.append(distance.euclidean(FC_vec_i, FC_vec_j))
+
+            if metric=='ECM':
+                if len(FC_t_i[t].shape)==2:
+                    assert FC_t_i[t].shape[0]==FC_t_i[t].shape[1],\
+                        'Matrices are not square'
+                    assert FC_t_j[t].shape[0]==FC_t_j[t].shape[1],\
+                        'Matrices are not square'
+                    ECM_i = calc_ECM(np.abs(FC_t_i[t]))
+                    ECM_j = calc_ECM(np.abs(FC_t_j[t]))
+                else:
+                    ECM_i = FC_t_i[t]
+                    ECM_j = FC_t_j[t]
+                if normalize:
+                    distance_out.append(normalized_euc_dist(ECM_i, ECM_j))
+                else:
+                    distance_out.append(distance.euclidean(ECM_i, ECM_j))
+
+        return np.array(distance_out)
+
+        # #regression
+        # y = dFC_vec_j[t]
+        # xx = FCS_vecs_new_order
+        # reg = LinearRegression().fit(xx.T, y.T)
+        # reg_dist.append(reg.coef_)
+
+    def dFCM_lst_distance(self, dFCM_lst, metric, common_TRs=None, normalize=True):
+
+        if common_TRs is None:
+            common_TRs = TR_intersection(dFCM_lst)
+
+        distance_mat = np.zeros((len(common_TRs), len(dFCM_lst), len(dFCM_lst)))
+        for i, dFCM_i in enumerate(dFCM_lst):
+            for j, dFCM_j in enumerate(dFCM_lst):
+                dFC_mat_i = dFCM_i.get_dFC_mat(TRs=common_TRs)
+                dFC_mat_j = dFCM_j.get_dFC_mat(TRs=common_TRs)
+                distance_mat[:, i, j] = self.dFC_distance(\
+                    FC_t_i=dFC_mat_i, \
+                    FC_t_j=dFC_mat_j, \
+                    metric=metric, \
+                    normalize=normalize\
+                    )
+        return distance_mat
+
+    def dFCM_lst_var(self, dFCM_lst, metric, common_TRs=None, normalize=True):
+
+        if common_TRs is None:
+            common_TRs = TR_intersection(dFCM_lst)
+
+        dFC_mat_avg = None
+        for i, dFCM_i in enumerate(dFCM_lst):
+            dFC_mat_i = dFCM_i.get_dFC_mat(TRs=common_TRs)
+            if dFC_mat_avg is None:
+                dFC_mat_avg = dFC_mat_normalize(C_t=dFC_mat_i, global_normalization=True)
+            else:
+                dFC_mat_avg += dFC_mat_normalize(C_t=dFC_mat_i, global_normalization=True)
+        dFC_mat_avg = np.divide(dFC_mat_avg, len(dFCM_lst))
+
+        distance_var_mat = np.zeros((len(common_TRs), len(dFCM_lst)))
+        for i, dFCM_i in enumerate(dFCM_lst):
+            dFC_mat_i = dFCM_i.get_dFC_mat(TRs=common_TRs)
+            distance_var_mat[:, i] = self.dFC_distance(\
+                FC_t_i=dFC_mat_i, \
+                FC_t_j=dFC_mat_avg, \
+                metric=metric, \
+                normalize=normalize\
+                )
+        return distance_var_mat
+
+    def post_analysis(self, dFCM_lst, analysis_name_lst):
+        '''
+        analysis_name_lst = [ \
+            'corr_mat', \
+            'dFC_distance', \
+            'dFC_distance_var', \
+            'FO', \
+            'CO', \
+            'TP', \
+            'trans_freq' \
+            ]
+        '''
+
+        measure_lst = list()
+        TS_info_lst = list()
+        for dFCM in dFCM_lst:
+            measure_lst.append(dFCM.measure)
+            TS_info_lst.append(dFCM.TS_info)
+
+        common_TRs = TR_intersection(dFCM_lst)
+
+        ########## dFCM corr ##########
+        # returns averaged correlation of dFC measures
+
+        corr_mat = []
+        if 'corr_mat' in analysis_name_lst:
+            corr_mat = self.dFCM_lst_corr(dFCM_lst, \
+                common_TRs=common_TRs, \
+                a=0.1 \
+                )
+
+        ########## distance calc ##########
+
+        dFC_distance = {}
+        if 'dFC_distance' in analysis_name_lst:
+            dFC_distance['euclidean'] = self.dFCM_lst_distance(\
+                dFCM_lst, \
+                metric='euclidean', \
+                common_TRs=common_TRs, \
+                normalize=True \
+                )
+            dFC_distance['correlation'] = self.dFCM_lst_distance(\
+                dFCM_lst, \
+                metric='correlation', \
+                common_TRs=common_TRs, \
+                normalize=True \
+                )
+            dFC_distance['ECM'] = self.dFCM_lst_distance(\
+                dFCM_lst, \
+                metric='ECM', \
+                common_TRs=common_TRs, \
+                normalize=True \
+                )
+
+        ########## distance var calc ##########
+
+        dFC_distance_var = {}
+        if 'dFC_distance_var' in analysis_name_lst:
+            dFC_distance_var['euclidean'] = self.dFCM_lst_var(\
+                dFCM_lst, \
+                metric='euclidean', \
+                common_TRs=common_TRs, \
+                normalize=True \
+                )
+            dFC_distance_var['correlation'] = self.dFCM_lst_var(\
+                dFCM_lst, \
+                metric='correlation', \
+                common_TRs=common_TRs, \
+                normalize=True \
+                )
+            dFC_distance_var['ECM'] = self.dFCM_lst_var(\
+                dFCM_lst, \
+                metric='ECM', \
+                common_TRs=common_TRs, \
+                normalize=True \
+                )
+
+        ########## Fractional Occupancy ##########
+
+        FO_lst = []
+        if 'FO' in analysis_name_lst:
+            FO_lst = self.FO_calc(dFCM_lst, \
+                common_TRs=common_TRs \
+                )
+
+        ########## Co-Occurrence Matrix and Transition Probability Matrix ##########
+
+        CO = {}
+        if 'CO' in analysis_name_lst:
+            CO = self.COM_calc(dFCM_lst, \
+                common_TRs=common_TRs, \
+                lag=0 \
+                )
+
+        TP = {}
+        if 'TP' in analysis_name_lst:
+            TP = self.COM_calc(dFCM_lst, \
+                common_TRs=common_TRs, \
+                lag=1 \
+                )
+
+        ########## transition frequency ##########
+
+        trans_freq_lst = []
+        if 'trans_freq' in analysis_name_lst:
+            trans_freq_lst = self.transition_freq(dFCM_lst, \
+                common_TRs=common_TRs \
+                )
+
+        ##############################################
+
+        methods_assess = {}
+        methods_assess['measure_lst'] = measure_lst
+        methods_assess['TS_info_lst'] = TS_info_lst
+        methods_assess['common_TRs'] = common_TRs
+        methods_assess['corr_mat'] = corr_mat
+        methods_assess['dFC_distance'] = dFC_distance
+        methods_assess['dFC_distance_var'] = dFC_distance_var
+        # methods_assess['state_match'] = state_match
+        methods_assess['FO'] = FO_lst
+        methods_assess['CO'] = CO
+        methods_assess['TP'] = TP
+        methods_assess['trans_freq'] = trans_freq_lst
+
+        return methods_assess
+
+    def visualize_dFCMs(self, dFCM_lst=None, TR_idx=None, normalize=True, threshold=0.0, \
+        fix_lim=True, subj_id=''):
+
+        # TR_idx is not TR values, but their indices!
+
+        TRs = TR_intersection(dFCM_lst)
+        if not TR_idx is None:
+            assert not np.any(np.array(TR_idx)>=len(TRs)), \
+                'TR_idx out of range.'
+            TRs = [TRs[i] for i in TR_idx]
+
+        for dFCM in dFCM_lst:
+            if self.params['save_image']:
+                output_root = self.params['output_root']+'dFC/'
+                dFCM.visualize_dFC(TRs=TRs, normalize=normalize, threshold=threshold, \
+                    fix_lim=fix_lim, \
+                    save_image=self.params['save_image'], \
+                    fig_name= output_root+'subject'+subj_id+'_'+dFCM.measure.measure_name+'_dFC')
+            else:
+                dFCM.visualize_dFC(TRs=TRs, normalize=normalize, threshold=threshold, fix_lim=fix_lim)
+
+    def visualize_FCS(self, normalize=True, threshold=0.0):
+
+        for session in self.MEASURES_fit_lst:
+            for measure in self.MEASURES_fit_lst[session]:
+                if self.params['save_image']:
+                    output_root = self.params['output_root'] + 'FCS/'
+                    measure.visualize_FCS(normalize=normalize, threshold=threshold, save_image=True, \
+                        fig_name= output_root + measure.measure_name+'_FCS_'+session)
+                    # measure.visualize_TPM(normalize=normalize)
+                else:
+                    measure.visualize_FCS(normalize=normalize, threshold=threshold) # normalize?
+                    # measure.visualize_TPM(normalize=normalize)
+
+################################# dFC class ####################################
+
+"""
+
+todo:
+- separate the matrix visualizing function
+- brain or brain graph class
+- add an updating behavior -> we can segment subjects and time_series and update the model gradually ?
+- type annotation
+- remove sliding window type dFC visualization
+- normalization: C_t_z[:, i, i] = np.mean(C_t_z) # ?????????????????
+"""
+
+class dFC:
+
+    TF_methods_name_lst = [ \
+        'CWT_mag', \
+        'CWT_phase_r', \
+        'CWT_phase_a', \
+        'WTC' \
+        ]
+
+    sw_methods_name_lst = [ \
+        'pear_corr', \
+        'MI', \
+        'GraphLasso', \
+        ]
+
+    base_methods_name_lst = ['SlidingWindow', 'Time-Freq']
+
+    def __init__(self):
+        self.measure_name = ''
+        self._stat = []
+        self.TPM = []
+        self.params = {}
+        self.TS_info_ = {}
+        self.FCS_fit_time_ = None
+        self.dFC_assess_time_ = None
+
+    @property
+    def FCS_fit_time(self):
+        return self.FCS_fit_time_
+
+    @property
+    def dFC_assess_time(self):
+        return self.dFC_assess_time_
+
+    @property
+    def TS_info(self):
+        # info of the time series used to train/estimate FCSs
+        return self.TS_info_
+
+    @property
+    def is_state_based(self):
+        return self.params['is_state_based']
+
+    @property
+    def FCS(self):
+        return self.FCS_
+
+    # test
+    @property
+    def FCS_dict(self):
+        # returns a dict including each FCS to be fed to the similarity assessment
+
+        if not self.is_state_based:
+            return None
+
+        C_A = self.FCS
+        state_act_dict = {}
+        state_act_dict['state_TC'] = {}
+        for k in range(C_A.shape[0]):
+            state_act_dict['state_TC']['FCS'+str(k+1)] = {}
+            state_act_dict['state_TC']['FCS'+str(k+1)]['FCS'] = C_A[k,:,:]
+
+        return state_act_dict
+
+    @property
+    def info(self):
+        print_dict(self.params)
+
+    def issame(self, dFC):
+        if type(self)==type(dFC):
+            for param_name in self.params:
+                if self.params[param_name] != dFC.params[param_name]:
+                    return False
+        else:
+            return False
+        return True
+
+    #test
+    def param_match(self, **param_dict):
+        for param in param_dict:
+            if param in self.params:
+                if type(param_dict[param]) is list:
+                    if not self.params[param] in param_dict[param]:
+                        return False
+                else:
+                    if self.params[param]!=param_dict[param]:
+                        return False
+        return True
+
+    def set_FCS_fit_time(self, time):
+        self.FCS_fit_time_ = time
+
+    def set_dFC_assess_time(self, time):
+        self.dFC_assess_time_ = time
+
+    def estimate_FCS(self, time_series=None):
+        pass
+
+    def estimate_dFCM(self, time_series=None):
+        pass
+
+    def manipulate_time_series4FCS(self, time_series):
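+        # preprocessing pipeline shared by all measures before FCS fitting; each
+        # step below mirrors one hyperparameter in self.params: subject
+        # subsampling, spatial downsampling, temporal resampling, normalization,
+        # additive noise, and truncation to num_time_point samples.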
new_time_series = deepcopy(time_series) + + # SUBJECTs + new_time_series.select_subjs(num_subj=self.params['num_subj']) + # SPATIAL RESOLUTION + new_time_series.spatial_downsample(num_select_nodes=self.params['num_select_nodes'], rand_node_slct=True) + # TEMPORAL RESOLUTION + new_time_series.Fs_resample(Fs_ratio=self.params['Fs_ratio']) + # NORMALIZE + if self.params['normalization']: + new_time_series.normalize() + # NOISE + new_time_series.add_noise(noise_ratio=self.params['noise_ratio'], mean_noise=0) + # NUMBER OF TIME POINTS + new_time_series.truncate(start_point=0, end_point=self.params['num_time_point']-1) + + self.TS_info_ = new_time_series.info_dict + + return new_time_series + + def manipulate_time_series4dFC(self, time_series): + + new_time_series = deepcopy(time_series) + + # SPATIAL RESOLUTION + new_time_series.spatial_downsample(num_select_nodes=self.params['num_select_nodes'], rand_node_slct=True) + # TEMPORAL RESOLUTION + new_time_series.Fs_resample(Fs_ratio=self.params['Fs_ratio']) + # NORMALIZE + if self.params['normalization']: + new_time_series.normalize() + # NOISE + new_time_series.add_noise(noise_ratio=self.params['noise_ratio'], mean_noise=0) + # NUMBER OF TIME POINTS + new_time_series.truncate(start_point=0, end_point=self.params['num_time_point']-1) + + return new_time_series + + def visualize_states(self): + pass + + # todo : use FCS_dict func in this func + def visualize_FCS(self, normalize=True, threshold=0.0, save_image=False, fig_name=None): + + if self.FCS == []: + return + + if normalize: + C = dFC_mat_normalize(C_t=self.FCS, threshold=threshold) + else: + C = self.FCS + + FCS_dict = {} + for i in range(C.shape[0]): + FCS_dict['FCS '+str(i+1)] = C[i] + + visualize_conn_mat(data=FCS_dict, \ + title=self.measure_name+' FCS', \ + save_image=save_image, \ + output_root=fig_name, \ + fix_lim=True \ + ) + + def visualize_TPM(self, normalize=True, save_image=False, output_root=None): + + if self.TPM == []: + return + if normalize: + C = dFC_mat_normalize(C_t=np.expand_dims(self.TPM, axis=0), threshold=0.0) + else: + C = np.expand_dims(self.TPM, axis=0) + + plt.figure(figsize=(5, 5)) + plt.imshow(np.squeeze(C), interpolation='nearest', aspect='equal', cmap='jet') + cb=plt.colorbar(shrink=0.8) + plt.title(self.measure_name + ' TPM') + + if save_image: + folder = output_root[:output_root.rfind('/')] + if not os.path.exists(folder): + os.makedirs(folder) + plt.savefig(output_root+'.png', \ + dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad \ + ) + plt.close() + else: + plt.show() + + +################################## NEW METHOD ################################## + +''' +by : web link + +Reference: ## + +Parameters + ---------- + y1, y2 : numpy.ndarray, list + Input signals. + dt : float + Sample spacing. 
+
+todo:
+
+import needed_toolbox
+
+class method_name(dFC):
+
+    def __init__(self, **params):
+        self.FCS_ = []
+
+        self.params_name_lst = ['measure_name', 'is_state_based', 'n_states', \
+            'normalization', 'num_subj', 'num_select_nodes', 'num_time_point', \
+            'Fs_ratio', 'noise_ratio', 'num_realization', 'session']
+        self.params = {}
+        for params_name in self.params_name_lst:
+            if params_name in params:
+                self.params[params_name] = params[params_name]
+
+        self.params['specific_param'] = value
+        self.params['measure_name'] = 'method_name'
+        self.params['is_state_based'] = True/False
+
+    @property
+    def measure_name(self):
+        return self.params['measure_name']
+
+    def estimate_FCS(self, time_series):
+
+        assert type(time_series) is TIME_SERIES, \
+            "time_series must be of TIME_SERIES class."
+
+        time_series = self.manipulate_time_series4FCS(time_series)
+
+        # start timing
+        tic = time.time()
+
+        # calc FCSs
+
+        # record time
+        self.set_FCS_fit_time(time.time() - tic)
+
+        return self
+
+    def estimate_dFCM(self, time_series):
+
+        assert type(time_series) is TIME_SERIES, \
+            "time_series must be of TIME_SERIES class."
+
+        time_series = self.manipulate_time_series4dFC(time_series)
+
+        # start timing
+        tic = time.time()
+
+        # calc FCSs and FCS_idx
+
+        # record time
+        self.set_dFC_assess_time(time.time() - tic)
+
+        dFCM = DFCM(measure=self)
+        dFCM.set_dFC(FCSs=self.FCS_, FCS_idx=FCS_idx, TS_info=time_series.info_dict)
+        return dFCM
+'''
+
+################################## CAP ##################################
+
+'''
+by : web link
+
+Reference: ##
+
+Parameters
+    ----------
+    y1, y2 : numpy.ndarray, list
+        Input signals.
+    dt : float
+        Sample spacing.
+
+todo:
+'''
+from sklearn.cluster import KMeans
+
+class CAP(dFC):
+
+    def __init__(self, **params):
+        self.FCS_ = []
+        self.FCS_fit_time_ = None
+        self.dFC_assess_time_ = None
+
+        self.params_name_lst = ['measure_name', 'is_state_based', 'n_states', \
+            'n_subj_clstrs', 'normalization', 'num_subj', 'num_select_nodes', 'num_time_point', \
+            'Fs_ratio', 'noise_ratio', 'num_realization', 'session']
+        self.params = {}
+        for params_name in self.params_name_lst:
+            if params_name in params:
+                self.params[params_name] = params[params_name]
+
+        self.params['measure_name'] = 'CAP'
+        self.params['is_state_based'] = True
+
+    @property
+    def measure_name(self):
+        return self.params['measure_name']
+
+    def act_vec2FCS(self, act_vecs):
+        FCS_ = list()
+        for act_vec in act_vecs:
+            FCS_.append(np.multiply(act_vec[:, np.newaxis], act_vec[np.newaxis, :]))
+        return np.array(FCS_)
+
+    def cluster_act_vec(self, act_vecs, n_clusters):
+
+        kmeans_ = KMeans(n_clusters=n_clusters, n_init=500).fit(act_vecs)
+        Z = kmeans_.predict(act_vecs)
+        act_centroids = kmeans_.cluster_centers_
+
+        return act_centroids, kmeans_
+
+    def estimate_FCS(self, time_series):
+
+        assert type(time_series) is TIME_SERIES, \
+            "time_series must be of TIME_SERIES class."
+
+        time_series = self.manipulate_time_series4FCS(time_series)
+
+        # start timing
+        tic = time.time()
+
+        # 2-level clustering
+        SUBJECTs = time_series.subj_id_lst
+        act_center_1st_level = None
+        for subject in SUBJECTs:
+
+            act_vecs = time_series.get_subj_ts(subjs_id=subject).data.T
+
+            # test
+            if act_vecs.shape[0] no parallelization
+
+todo:
+
+- consider COI and edge effect in averaging:
+    => should we truncate the time points having less than 20 freqs as done in Savva et al. ?
+ +""" +import pycwt as wavelet + +class TIME_FREQ(dFC): + + def __init__(self, TF_method='WTC', coi_correction=True, **params): + + assert TF_method in self.TF_methods_name_lst, \ + "Time-frequency method not recognized." + + self.TPM = [] + self.FCS_ = [] + self.FCS_fit_time_ = None + self.dFC_assess_time_ = None + + self.params_name_lst = ['measure_name', 'is_state_based', 'TF_method', 'coi_correction', \ + 'n_jobs', 'verbose', 'backend', \ + 'normalization', 'num_select_nodes', 'num_time_point', \ + 'Fs_ratio', 'noise_ratio', 'num_realization', 'session'] + self.params = {} + for params_name in self.params_name_lst: + if params_name in params: + self.params[params_name] = params[params_name] + + self.params['measure_name'] = 'Time-Freq' + self.params['is_state_based'] = False + self.params['TF_method'] = TF_method + self.params['coi_correction'] = coi_correction + + @property + def measure_name(self): + return self.params['measure_name'] # + '_' + self.params['TF_method'] + + def coi_correct(self, X, coi, freqs): + # correct the edge effect in matrix X = [freqs, time] using coi + # if self.coi_correction=True + + if not self.params['coi_correction']: + return X + periods = 1/freqs + periods = np.repeat(periods[:, None], X.shape[1], axis=1) + coi = np.repeat(coi[None, :], X.shape[0], axis=0) + X_corrected = np.multiply(X, (coi>=periods)) + return X_corrected + + def WT_dFC(self, Y1, Y2, Fs, J, s0, dj): + if self.params['TF_method']=='CWT_mag' or self.params['TF_method']=='CWT_phase_r' or self.params['TF_method']=='CWT_phase_a': + # Cross Wavelet Transform + WT_xy, coi, freqs, _ = wavelet.xwt(Y1, Y2, dt=1/Fs, dj=dj, s0=s0, J=J, + significance_level=0.95, wavelet='morlet', normalize=True) + + if self.params['TF_method']=='CWT_mag': + WT_xy_corrected = self.coi_correct(WT_xy, coi, freqs) + wt = np.abs(np.mean(WT_xy_corrected, axis=0)) + + if self.params['TF_method']=='CWT_phase_r' or self.params['TF_method']=='CWT_phase_a': + cosA = np.cos(np.angle(WT_xy)) + sinA = np.sin(np.angle(WT_xy)) + + cosA_corrected = self.coi_correct(cosA, coi, freqs) + sinA_corrected = self.coi_correct(sinA, coi, freqs) + + A = (cosA_corrected + sinA_corrected * 1j) + + if self.params['TF_method']=='CWT_phase_r': + wt = np.abs(np.mean(A, axis=0)) + else: + wt = np.angle(np.mean(A, axis=0)) + + if self.params['TF_method']=='WTC': + # Wavelet Transform Coherence + WT_xy, _, coi, freqs, _ = wavelet.wct(Y1, Y2, dt=1/Fs, dj=dj, s0=s0, J=J, + sig=False, significance_level=0.95, wavelet='morlet', normalize=True) + WT_xy_corrected = self.coi_correct(WT_xy, coi, freqs) + wt = np.abs(np.mean(WT_xy_corrected, axis=0)) + + return wt + + def estimate_dFCM(self, time_series): + + ''' + we assume calc is applied on subjects separately + ''' + + # params + J = 50 # -1 + s0 = 1 # -1 + dj = 1/8 # 1/12 + + assert type(time_series) is TIME_SERIES, \ + "time_series must be of TIME_SERIES class." 
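+        # the wavelet grid above follows the Torrence & Compo convention used by
+        # pycwt: s0 is the smallest scale, dj the spacing between discrete
+        # scales, and J the number of scales minus one; each dFC value below is
+        # the (optionally COI-corrected) transform averaged over all scales.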
+
+        time_series = self.manipulate_time_series4dFC(time_series)
+
+        # start timing
+        tic = time.time()
+
+        # WT[t, i, j] holds the wavelet-based dFC between regions i and j at time t
+        WT = np.zeros((time_series.n_time, \
+            time_series.n_regions, time_series.n_regions))
+
+        for i in range(time_series.n_regions):
+            if self.params['n_jobs'] is None:
+                Q = list()
+                for j in range(time_series.n_regions):
+                    Q.append(self.WT_dFC( \
+                        Y1=time_series.data[i, :], \
+                        Y2=time_series.data[j, :], \
+                        Fs=time_series.Fs, \
+                        J=J, s0=s0, dj=dj))
+            else:
+                Q = Parallel( \
+                    n_jobs=self.params['n_jobs'], verbose=self.params['verbose'], backend=self.params['backend'])( \
+                    delayed(self.WT_dFC)( \
+                        Y1=time_series.data[i, :], \
+                        Y2=time_series.data[j, :], \
+                        Fs=time_series.Fs, \
+                        J=J, s0=s0, dj=dj) \
+                    for j in range(time_series.n_regions) \
+                )
+            WT[:, i, :] = np.array(Q).T
+
+        # record time
+        self.set_dFC_assess_time(time.time() - tic)
+
+        dFCM = DFCM(measure=self)
+        dFCM.set_dFC(FCSs=WT, TS_info=time_series.info_dict)
+        return dFCM
+
+################################# Sliding-Window #################################
+
+"""
+Parameters
+----------
+W : int
+    Window width in seconds (converted to samples in estimate_dFCM).
+n_overlap : float
+    Fractional overlap between consecutive windows (0 to 1).
+sw_method : str
+    Within-window connectivity estimator: 'GraphLasso', 'MI', or Pearson
+    correlation otherwise (see FC below).
+tapered_window : bool
+    If True, the rectangular window is tapered by convolving it with a Gaussian.
+
+todo:
+"""
+
+from sklearn.covariance import GraphicalLassoCV
+
+class SLIDING_WINDOW(dFC):
+
+    def __init__(self, **params):
+
+        self.TPM = []
+        self.FCS_ = []
+        self.FCS_fit_time_ = None
+        self.dFC_assess_time_ = None
+
+        self.params_name_lst = ['measure_name', 'is_state_based', 'sw_method', 'tapered_window', \
+            'W', 'n_overlap', 'normalization', \
+            'num_select_nodes', 'num_time_point', 'Fs_ratio', \
+            'noise_ratio', 'num_realization', 'session']
+        self.params = {}
+        for params_name in self.params_name_lst:
+            if params_name in params:
+                self.params[params_name] = params[params_name]
+
+        self.params['measure_name'] = 'SlidingWindow'
+        self.params['is_state_based'] = False
+
+        assert self.params['sw_method'] in self.sw_methods_name_lst, \
+            "sw_method not recognized."
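+
+    # Illustrative window arithmetic (assuming TR = 2 s, i.e. Fs = 0.5 Hz, as in
+    # Allen et al.): W = 44 s -> int(44 * 0.5) = 22 samples; with n_overlap = 0.5 the
+    # window advances in steps of int((1 - 0.5) * 22) = 11 samples, and each window's
+    # FC matrix is stamped at the centre sample, int((l + (l + W)) / 2).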
+
+    @property
+    def measure_name(self):
+        return self.params['measure_name'] #+ '_' + self.sw_method
+
+    def shan_entropy(self, c):
+        c_normalized = c / float(np.sum(c))
+        c_normalized = c_normalized[np.nonzero(c_normalized)]
+        H = -sum(c_normalized * np.log2(c_normalized))
+        return H
+
+    def calc_MI(self, X, Y):
+
+        bins = 20
+
+        c_XY = np.histogram2d(X, Y, bins)[0]
+        c_X = np.histogram(X, bins)[0]
+        c_Y = np.histogram(Y, bins)[0]
+
+        H_X = self.shan_entropy(c_X)
+        H_Y = self.shan_entropy(c_Y)
+        H_XY = self.shan_entropy(c_XY)
+
+        # MI(X;Y) = H(X) + H(Y) - H(X,Y)
+        MI = H_X + H_Y - H_XY
+        return MI
+
+    def FC(self, time_series):
+
+        if self.params['sw_method']=='GraphLasso':
+            model = GraphicalLassoCV()
+            model.fit(time_series.T)
+            C = model.covariance_
+        else:
+            C = np.zeros((time_series.shape[0], time_series.shape[0]))
+            for i in range(time_series.shape[0]):
+                for j in range(i, time_series.shape[0]):
+
+                    X = time_series[i, :]
+                    Y = time_series[j, :]
+
+                    if self.params['sw_method']=='MI':
+                        ########### Mutual Information ##############
+                        C[j, i] = self.calc_MI(X, Y)
+                    else:
+                        ########### Pearson Correlation ##############
+                        C[j, i] = np.corrcoef(X, Y)[0, 1]
+
+                    C[i, j] = C[j, i]
+
+        return C
+
+    def dFC(self, time_series, W=None, n_overlap=None, tapered_window=False):
+        # W is in time samples
+
+        L = time_series.shape[1]
+        step = int((1-n_overlap)*W)
+        if step == 0:
+            step = 1
+
+        # Gaussian taper; std = 3 TRs when W = 22 TRs, as in Allen et al. (2014)
+        window_taper = signal.windows.gaussian(W, std=3*W/22)
+        # C = DFCM(measure=self)
+        FCSs = list()
+        TR_array = list()
+        for l in range(0, L-W+1, step):
+
+            ######### creating a rectangle window ############
+            window = np.zeros((L))
+            window[l:l+W] = 1
+
+            ########### tapering the window ##############
+            if tapered_window:
+                window = signal.convolve(window, window_taper, mode='same') / sum(window_taper)
+
+            window = np.repeat(np.expand_dims(window, axis=0), time_series.shape[0], axis=0)
+
+            # int(l-W/2):int(l+3*W/2) is the nonzero interval after tapering
+            FCSs.append(self.FC( \
+                np.multiply(time_series, window)[ \
+                    :, max(int(l-W/2), 0):min(int(l+3*W/2), L) \
+                ] \
+                ) \
+            )
+            TR_array.append(int((l + (l + W)) / 2))
+
+        return np.array(FCSs), np.array(TR_array)
+
+    def estimate_dFCM(self, time_series):
+
+        '''
+        we assume the calculation is applied to each subject separately
+        '''
+
+        assert type(time_series) is TIME_SERIES, \
+            "time_series must be of TIME_SERIES class."
+
+        time_series = self.manipulate_time_series4dFC(time_series)
+
+        # start timing
+        tic = time.time()
+
+        # W is converted from sec to samples
+        FCSs, TR_array = self.dFC(time_series=time_series.data, \
+            W=int(self.params['W'] * time_series.Fs), \
+            n_overlap=self.params['n_overlap'], \
+            tapered_window=self.params['tapered_window'] \
+            )
+
+        # record time
+        self.set_dFC_assess_time(time.time() - tic)
+
+        dFCM = DFCM(measure=self)
+        dFCM.set_dFC(FCSs=FCSs, TR_array=TR_array, TS_info=time_series.info_dict)
+
+        return dFCM
+
+
+########################### Sliding_Window + Clustering ############################
+
+"""
+- We used a tapered window as in Allen et al., created by convolving a rectangle (width = 22 TRs = 44 s)
+    with a Gaussian (σ = 3 TRs) and slid in steps of 1 TR, resulting in 126 windows (Allen et al., 2014).
+- K-means clustering is repeated 500 times to escape local minima (Allen et al., 2014).
+- For clustering, we use 2-level k-means clustering. First, we cluster the FCSs of each subject. Then, we
+    cluster all cluster centers from all subjects. The final estimate_dFCM uses the second k-means
+    model (Allen et al., 2014; Ou et al., 2015).
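+- For example (hypothetical numbers), with 100 subjects and n_subj_clstrs = 20, the
+    second-level k-means runs on 100 * 20 = 2000 first-level centroids and reduces
+    them to n_states final FCSs.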
+
+Parameters
+----------
+clstr_base_measure : str
+    dFC measure whose outputs are clustered ('SlidingWindow' or 'Time-Freq').
+n_subj_clstrs : int
+    Number of first-level (within-subject) clusters.
+n_states : int
+    Number of second-level clusters, i.e. the final FC states.
+clstr_distance : str
+    Clustering distance, 'euclidean' or 'manhattan'.
+
+todo:
+- pyclustering (manhattan) has a problem when using predict
+"""
+
+from sklearn.cluster import KMeans
+from pyclustering.cluster.kmeans import kmeans
+from pyclustering.cluster.center_initializer import kmeans_plusplus_initializer
+from pyclustering.utils.metric import distance_metric, type_metric
+
+class SLIDING_WINDOW_CLUSTR(dFC):
+
+    def __init__(self, clstr_distance='euclidean', **params):
+
+        assert clstr_distance=='euclidean' or clstr_distance=='manhattan', \
+            "Clustering distance not recognized. It must be either \
+            euclidean or manhattan."
+
+        self.TPM = []
+        self.FCS_ = []
+        self.FCS_fit_time_ = None
+        self.dFC_assess_time_ = None
+
+        self.params_name_lst = ['measure_name', 'is_state_based', 'clstr_base_measure', 'sw_method', 'tapered_window', \
+            'clstr_distance', 'coi_correction', \
+            'n_subj_clstrs', 'W', 'n_overlap', 'n_states', 'normalization', \
+            'n_jobs', 'verbose', 'backend', \
+            'num_subj', 'num_select_nodes', 'num_time_point', 'Fs_ratio', \
+            'noise_ratio', 'num_realization', 'session']
+        self.params = {}
+        for params_name in self.params_name_lst:
+            if params_name in params:
+                self.params[params_name] = params[params_name]
+
+        self.params['measure_name'] = 'Clustering'
+        self.params['is_state_based'] = True
+        self.params['clstr_distance'] = clstr_distance
+
+        assert self.params['clstr_base_measure'] in self.base_methods_name_lst, \
+            "Base method not recognized."
+
+    @property
+    def measure_name(self):
+        return self.params['measure_name'] #+ '_' + self.base_method
+
+    def dFC_mat2vec(self, C_t):
+        # vectorize the upper triangle (including the diagonal, k=0) of each FC matrix;
+        # delegates to the module-level helper
+        return dFC_mat2vec(C_t)
+
+    def dFC_vec2mat(self, F, N):
+        # rebuild symmetric N x N FC matrices from upper-triangle vectors;
+        # delegates to the module-level helper
+        return dFC_vec2mat(F=F, N=N)
+
+    def clusters_lst2idx(self, clusters):
+        # convert pyclustering's list-of-clusters output into a flat label vector
+        n_samples = sum(len(cluster) for cluster in clusters)
+        Z = np.zeros((n_samples,))
+        for i, cluster in enumerate(clusters):
+            for sample in cluster:
+                Z[sample] = i
+        return Z.astype(int)
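+
+    # Size note (illustrative): dFC_mat2vec keeps the upper triangle including the
+    # diagonal, so an FC matrix over N regions becomes a vector of length N*(N+1)/2;
+    # e.g. N = 333 (Gordon atlas) gives 333*334/2 = 55611 features per window.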
+    def cluster_FC(self, FCS_raw, n_clusters, n_regions):
+
+        F = self.dFC_mat2vec(FCS_raw)
+
+        if self.params['clstr_distance']=='manhattan':
+            ########### Manhattan Clustering ##############
+            # # Prepare initial centers using the k-means++ method.
+            # initial_centers = kmeans_plusplus_initializer(F, self.params['n_states']).initialize()
+            # # Create the metric that will be used for clustering.
+            # manhattan_metric = distance_metric(type_metric.MANHATTAN)
+            # # Create an instance of the k-means algorithm with the prepared centers.
+            # kmeans_ = kmeans(F, initial_centers, metric=manhattan_metric)
+            # # Run the cluster analysis and obtain the results.
+            # kmeans_.process()
+            # Z = self.clusters_lst2idx(kmeans_.get_clusters())
+            # F_cent = np.array(kmeans_.get_centers())
+            # Disabled for now: pyclustering's predict is problematic (see todo above),
+            # and falling through silently would leave F_cent undefined below.
+            raise NotImplementedError("Manhattan-distance clustering is not supported yet.")
+        else:
+            ########### Euclidean Clustering ##############
+            kmeans_ = KMeans(n_clusters=n_clusters, n_init=500).fit(F)
+            Z = kmeans_.predict(F)
+            F_cent = kmeans_.cluster_centers_
+
+        FCS_ = self.dFC_vec2mat(F_cent, N=n_regions)
+        return FCS_, kmeans_
+
+    def estimate_FCS(self, time_series):
+
+        assert type(time_series) is TIME_SERIES, \
+            "time_series must be of TIME_SERIES class."
+
+        time_series = self.manipulate_time_series4FCS(time_series)
+
+        # start timing
+        tic = time.time()
+
+        base_dFC = None
+        if self.params['clstr_base_measure']=='Time-Freq':
+            base_dFC = TIME_FREQ(**self.params)
+        if self.params['clstr_base_measure']=='SlidingWindow':
+            base_dFC = SLIDING_WINDOW(**self.params)
+
+        # 1-level clustering
+        # dFCM_raw = base_dFC.estimate_dFCM( \
+        #     time_series=time_series \
+        #     )
+        # self.FCS_, self.kmeans_ = self.cluster_FC( \
+        #     dFCM_raw.get_dFC_mat(TRs=dFCM_raw.TR_array), \
+        #     n_regions=dFCM_raw.n_regions \
+        #     )
+
+        # 2-level clustering
+        SUBJECTs = time_series.subj_id_lst
+        FCS_1st_level = None
+        for subject in SUBJECTs:
+
+            dFCM_raw = base_dFC.estimate_dFCM( \
+                time_series=time_series.get_subj_ts(subjs_id=subject) \
+                )
+
+            # test
+            if dFCM_raw.n_time < self.params['n_subj_clstrs']:
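+                # n_subj_clstrs is capped here because k-means cannot form more
+                # clusters than the number of windows available for this subject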