From a168b933bbce3ffe1b665d63575187c5b77531ef Mon Sep 17 00:00:00 2001 From: Claire Shao <127550911+isclaireya@users.noreply.github.com> Date: Fri, 27 Jun 2025 05:28:06 +0000 Subject: [PATCH] Add Griffiths2022 example and update depr --- .../griffiths2022/bifurcation_analysis.py | 151 ++++ whobpyt/depr/griffiths2022/plotting.py | 711 ++++++++++++++++++ .../rww_pytorch_model_updated.py | 676 +++++++++++++++++ whobpyt/depr/griffiths2022/wwd_test.py | 151 ++++ 4 files changed, 1689 insertions(+) create mode 100644 whobpyt/depr/griffiths2022/bifurcation_analysis.py create mode 100644 whobpyt/depr/griffiths2022/plotting.py create mode 100644 whobpyt/depr/griffiths2022/rww_pytorch_model_updated.py create mode 100644 whobpyt/depr/griffiths2022/wwd_test.py diff --git a/whobpyt/depr/griffiths2022/bifurcation_analysis.py b/whobpyt/depr/griffiths2022/bifurcation_analysis.py new file mode 100644 index 00000000..2d312db4 --- /dev/null +++ b/whobpyt/depr/griffiths2022/bifurcation_analysis.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +"""bifurcation_analysis.ipynb + +Automatically generated by Colab. + +Original file is located at + https://colab.research.google.com/drive/1-N-GjmkkJdycabTr4S95Z5unyT-EyvtD +""" + +import numpy as np +import matplotlib.pyplot as plt +from scipy.optimize import fsolve + +def h_tf(a, b, d, x): + return (a*x-b)/(1.0000 -np.exp(-d*(a*x-b))) + +def dh_tf(a, b, d, x): + tmp_e = np.exp(-d*(a*x-b)) + tmp_d = 1. - np.exp(-d*(a*x-b)) + slope_E = (a*tmp_d - (a*x-b)*d*a*tmp_e) / tmp_d**2 + return slope_E + +def smooth_normalize(x): + return max(x, 0.000001) + +def derivative_orig(x, param): + E = x.reshape((2, 50))[0] + I = x.reshape((2, 50))[1] + + IE = param["W_E"]*param["I_0"] + param["g_EE"]*E - param["g_IE"]*I + II = param["W_I"]*param["I_0"] + param["g_EI"]*E - I + + rE = h_tf(param["aE"], param["bE"], param["dE"], IE) + rI = h_tf(param["aI"], param["bI"], param["dI"], II) + + ddE = -E / param["tau_E"] + param["gamma_E"] * (1. - E) * rE + ddI = -I / param["tau_I"] + param["gamma_I"] * rI + + return np.array([ddE, ddI]).ravel() + +def derivative(x, param): + E = x[0] + I = x[1] + + IE = param["W_E"]*param["I_0"] + param["g_EE"]*E - param["g_IE"]*I + II = param["W_I"]*param["I_0"] + param["g_EI"]*E - I + + rE = h_tf(param["aE"], param["bE"], param["dE"], IE) + rI = h_tf(param["aI"], param["bI"], param["dI"], II) + + ddE = -E / param["tau_E"] + param["gamma_E"] * (1. 
- E) * rE + ddI = -I / param["tau_I"] + param["gamma_I"] * rI + + return 10000.0 * np.array([ddE, ddI]) + +def get_eig_sys(E, I, param): + IE = param["W_E"]*param["I_0"] + param["g_EE"]*E - param["g_IE"]*I + II = param["W_I"]*param["I_0"] + param["g_EI"]*E - I + + rE = h_tf(param["aE"], param["bE"], param["dE"], IE) + rI = h_tf(param["aI"], param["bI"], param["dI"], II) + drEdIE = dh_tf(param["aE"], param["bE"], param["dE"], IE) + drIIdII = dh_tf(param["aI"], param["bI"], param["dI"], II) + + A = np.zeros((2, 2)) + A[0, 0] = -1 / param["tau_E"] - param["gamma_E"] * rE + (1 - E) * param["gamma_E"] * drEdIE * param["g_EE"] + A[0, 1] = -(1 - E) * param["gamma_E"] * drEdIE * param["g_IE"] + A[1, 0] = param["gamma_I"] * drIIdII * param["g_EI"] + A[1, 1] = -param["gamma_I"] * drIIdII + + A = np.nan_to_num(A) + d, _ = np.linalg.eig(A) + return d + +def regime_search_I0(I0_rng, gEE_rng, gIE_rng, gEI_rng, param): + num_param = 3 + num_trials = len(gEE_rng) + + c_I0 = [] + for I0 in I0_rng: + param["I_0"] = I0 + c = [] + for i in range(num_trials ** num_param): + ind_0 = i // (num_trials ** (num_param - 1)) + ind_1 = (i % (num_trials ** (num_param - 1))) // num_trials + ind_2 = i % num_trials + + param["g_EE"] = gEE_rng[ind_0] + param["g_IE"] = gIE_rng[ind_1] + param["g_EI"] = gEI_rng[ind_2] + + initial = np.random.uniform(0., 2, [2, 50]) + solns = [] + for j in range(initial.shape[1]): + x0 = initial[:, j] + x0 = np.round(fsolve(lambda x: derivative(x, param), x0), decimals=4) + if (np.abs(derivative(x0, param)) > 1.0).sum() == 0: + solns.append(tuple(x0)) + + good_sols = [] + for sol in set(solns): + sol_good = True + for g_sol in good_sols: + if np.sqrt(((np.array(g_sol)-np.array(sol))**2).mean()) < 1e-3: + sol_good = False + break + if sol_good: + good_sols.append(sol) + c.append(len(good_sols)) + c_I0.append(max(c)) + return c_I0 + +def regime_search_gEE(I0_rng, gEE_rng, gIE_rng, gEI_rng, param): + n_I0 = len(I0_rng) + n_gEE = len(gEE_rng) + n_gIE = len(gIE_rng) + n_gEI = len(gEI_rng) + + c_gEE = [] + for gEE in gEE_rng: + param["g_EE"] = gEE + c = [] + for i in range(n_I0 * n_gIE * n_gEI): + ind_0 = i // (n_gIE * n_gEI) + ind_1 = (i % (n_gIE * n_gEI)) // n_gEI + ind_2 = i % n_gEI + + param["I_0"] = I0_rng[ind_0] + param["g_IE"] = gIE_rng[ind_1] + param["g_EI"] = gEI_rng[ind_2] + + initial = np.random.uniform(0., 2, [2, 50]) + solns = [] + for j in range(initial.shape[1]): + x0 = initial[:, j] + x0 = np.round(fsolve(lambda x: derivative(x, param), x0), decimals=4) + if (np.abs(derivative(x0, param)) > 1.0).sum() == 0: + solns.append(tuple(x0)) + + good_sols = [] + for sol in set(solns): + sol_good = True + for g_sol in good_sols: + if np.sqrt(((np.array(g_sol)-np.array(sol))**2).mean()) < 1e-3: + sol_good = False + break + if sol_good: + good_sols.append(sol) + c.append(len(good_sols)) + c_gEE.append(max(c)) + return c_gEE diff --git a/whobpyt/depr/griffiths2022/plotting.py b/whobpyt/depr/griffiths2022/plotting.py new file mode 100644 index 00000000..61dcd8cd --- /dev/null +++ b/whobpyt/depr/griffiths2022/plotting.py @@ -0,0 +1,711 @@ +# -*- coding: utf-8 -*- +"""plotting.ipynb + +Automatically generated by Colab. 
+ +Original file is located at + https://colab.research.google.com/drive/12FUXZ35ir8sFAQRysSHUpa2ZKoCkCywc +""" + +import numpy as np +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +from scipy.optimize import fsolve +import sys +import json +import pandas as pd # for data manipulation +import seaborn as sns # for plotting +import time # for timer +import os +from numpy import exp,sin,cos,sqrt,pi, r_,floor,zeros +import scipy.io +from matplotlib.tri import Triangulation +#from nilearn import plotting, datasets +from matplotlib import cm +from matplotlib.pyplot import subplot +import warnings # for suppressing warnings and output +import nibabel as nib +import matplotlib.pyplot as plt # for plotting +warnings.filterwarnings('ignore') + +#@title get_rotation_matrix (rotation_axis, deg) and get_combined_rotation_matrix (rotations) +def get_rotation_matrix(rotation_axis, deg): + + '''Return rotation matrix in the x,y,or z plane''' + + + + # (note make deg minus to change from anticlockwise to clockwise rotation) + th = -deg * (pi/180) # convert degrees to radians + + if rotation_axis == 0: + return np.array( [[ 1, 0, 0 ], + [ 0, cos(th), -sin(th)], + [ 0, sin(th), cos(th)]]) + elif rotation_axis ==1: + return np.array( [[ cos(th), 0, sin(th)], + [ 0, 1, 0 ], + [ -sin(th), 0, cos(th)]]) + elif rotation_axis ==2: + return np.array([[ cos(th), -sin(th), 0 ], + [ sin(th), cos(th), 0 ], + [ 0, 0, 1 ]]) + + + +def get_combined_rotation_matrix(rotations): + '''Return a combined rotation matrix from a dictionary of rotations around + the x,y,or z axes''' + rotmat = np.eye(3) + + if type(rotations) is tuple: rotations = [rotations] + for r in rotations: + newrot = get_rotation_matrix(r[0],r[1]) + rotmat = np.dot(rotmat,newrot) + return rotmat + +def plot_surface_mpl_mv(vtx=None,tri=None,data=None,rm=None,hemi=None, # Option 1 + vtx_lh=None,tri_lh=None,data_lh=None,rm_lh=None, # Option 2 + vtx_rh=None,tri_rh=None,data_rh=None,rm_rh=None, + title=None,**kwargs): + + r"""Convenience wrapper on plot_surface_mpl for multiple views + + This function calls plot_surface_mpl five times to give a complete + picture of a surface- or region-based spatial pattern. + + As with plot_surface_mpl, this function is written so as to be + generally usable with neuroimaging surface-based data, and does not + require construction of of interaction with tvb datatype objects. + + In order for the medial surfaces to be displayed properly, it is + necessary to separate the left and right hemispheres. This can be + done in one of two ways: + + 1. Provide single arrays for vertices, faces, data, and + region mappings, and addition provide arrays of indices for + each of these (vtx_inds,tr_inds,rm_inds) with 0/False + indicating left hemisphere vertices/faces/regions, and 1/True + indicating right hemisphere. + + Note: this requires that + + 2. Provide separate vertices,faces,data,and region mappings for + each hemisphere (vtx_lh,tri_lh; vtx_rh,tri_rh,etc...) 
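+
+    (Note for Option 1: the `hemi` array must be supplied along with
+     vtx/tri/data/rm, since it is what allows the faces to be split into
+     left- and right-hemisphere sets before plotting.)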
+ + + + Parameters + ---------- + + (see also plot_surface_mpl parameters info for more details) + + (Option 1) + + vtx : surface vertices + + tri : surface faces + + data : spatial pattern to plot + + rm : surface vertex to region mapping + + hemi : hemisphere labels for each vertex + (1/True = right, 0/False = left) - + + + OR + + (Option 2) + + vtx_lh : left hemisphere surface_vertices + vtx_rh : right `` `` `` `` + + tri_lh : left hemisphere surface faces + tri_rh : right `` `` `` `` + + data_lh : left hemisphere surface_vertices + data_rh : right `` `` `` `` + + rm_lh : left hemisphere region_mapping + rm_rh : right `` `` `` `` + + + title : title to show above middle plot + + kwargs : additional tripcolor kwargs; see plot_surface_mpl + + + + Examples + ---------- + + # TVB default data + + # Plot one column of the region-based tract lengths + # connectivity matrix. The corresponding region is + # right auditory cortex ('rA1') + + ctx = cortex.Cortex.from_file(source_file = ctx_file, + region_mapping_file =rm_file) + vtx,tri,rm = ctx.vertices,ctx.triangles,ctx.region_mapping + conn = connectivity.Connectivity.from_file(conn_file); conn.configure() + isrh_reg = conn.is_right_hemisphere(range(conn.number_of_regions)) + isrh_vtx = np.array([isrh_reg[r] for r in rm]) + dat = conn.tract_lengths[:,5] + + plot_surface_mpl_mv(vtx=vtx,tri=tri,rm=rm,data=dat, + hemi=isrh_vtx,title=u'rA1 \ntract length') + + plot_surface_mpl_mv(vtx=vtx,tri=tri,rm=rm,data=dat, + hemi=isrh_vtx,title=u'rA1 \ntract length', + shade_kwargs = {'shading': 'gouraud', + 'cmap': 'rainbow'}) + + + """ + + + + if vtx is not None: # Option 1 + tri_hemi = hemi[tri].any(axis=1) + tri_lh,tri_rh = tri[tri_hemi==0],tri[tri_hemi==1] + elif vtx_lh is not None: # Option 2 + vtx = np.vstack([vtx_lh,vtx_rh]) + tri = np.vstack([tri_lh,tri_rh+tri_lh.max()+1]) + + if data_lh is not None: # Option 2 + data = np.hstack([data_lh,data_rh]) + + if rm_lh is not None: # Option 2 + rm = np.hstack([rm_lh,rm_rh + rm_lh.max() + 1]) + + + + # 2. 
Now do the plots for each view + + # (Note: for the single hemispheres we only need lh/rh arrays for the + # faces (tri); the full vertices, region mapping, and data arrays + # can be given as arguments, they just won't be shown if they aren't + # connected by the faces in tri ) + fig, ax = plt.subplots(2,3, figsize=(6,4)) + # LH lateral + plot_surface_mpl(vtx,tri_lh,data=data,rm=rm,view='lh_lat', + ax=subplot(2,3,1),**kwargs) + + # LH medial + plot_surface_mpl(vtx,tri_lh, data=data,rm=rm,view='lh_med', + ax=subplot(2,3,4),**kwargs) + + # RH lateral + plot_surface_mpl(vtx,tri_rh, data=data,rm=rm,view='rh_lat', + ax=subplot(2,3,3),**kwargs) + + # RH medial + plot_surface_mpl(vtx,tri_rh, data=data,rm=rm,view='rh_med', + ax=subplot(2,3,6),**kwargs) + + # Both superior + im =plot_surface_mpl(vtx,tri, data=data,rm=rm,view='superior', + ax=subplot(1,3,2),title=title,**kwargs) + + plt.subplots_adjust(left=0.0, right=.8, bottom=0.0, + top=.8, wspace=0, hspace=0) + cbar_ax = fig.add_axes([0.85,0.1,0.05,0.6]) + fig.colorbar(im, cax = cbar_ax) + + #@title function for plot brain in one picture + +def plot_surface_mpl(vtx,tri,data=None,rm=None,reorient='tvb',view='superior', + shaded=False,ax=None,figsize=(6,4), title=None, + lthr=None,uthr=None, nz_thr = 1E-20, + shade_kwargs = {'edgecolors': 'k', 'linewidth': 0.1, + 'alpha': None, 'cmap': 'coolwarm', + 'vmin': None, 'vmax': None}): + + r"""Plot surfaces, surface patterns, and region patterns with matplotlib + + This is a general-use function for neuroimaging surface-based data, and + does not necessarily require construction of or interaction with tvb + datatypes. + + See also: plot_surface_mpl_mv + + + + Parameters + ---------- + + vtx : N vertices x 3 array of surface vertex xyz coordinates + + tri : N faces x 3 array of surface faces + + data : array of numbers to colour surface with. Can be either + a pattern across surface vertices (N vertices x 1 array), + or a pattern across the surface's region mapping + (N regions x 1 array), in which case the region mapping + bust also be given as an argument. + + rm : region mapping - N vertices x 1 array with (up to) N + regions unique values; each element specifies which + region the corresponding surface vertex is mapped to + + reorient : modify the vertex coordinate frame and/or orientation + so that the same default rotations can subsequently be + used for image views. The standard coordinate frame is + xyz; i.e. first,second,third axis = left-right, + front-back, and up-down, respectively. The standard + starting orientation is axial view; i.e. looking down on + the brain in the x-y plane. + + Options: + + tvb (default) : swaps the first 2 axes and applies a rotation + + fs : for the standard freesurfer (RAS) orientation; + e.g. fsaverage lh.orig. + No transformations needed for this; so is + gives same result as reorient=None + + view : specify viewing angle. + + This can be done in one of two ways: by specifying a string + corresponding to a standard viewing angle, or by providing + a tuple or list of tuples detailing exact rotations to apply + around each axis. + + Standard view options are: + + lh_lat / lh_med / rh_lat / rh_med / + superior / inferior / posterior / anterior + + (Note: if the surface contains both hemispheres, then medial + surfaces will not be visible, so e.g. 
'rh_med' will look the + same as 'lh_lat') + + Arbitrary rotations can be specied by a tuple or a list of + tuples, each with two elements, the first defining the axis + to rotate around [0,1,2], the second specifying the angle in + degrees. When a list is given the rotations are applied + sequentially in the order given. + + Example: rotations = [(0,45),(1,-45)] applies 45 degrees + rotation around the first axis, followed by 45 degrees rotate + around the second axis. + + lthr/uthr : lower/upper thresholds - set to zero any datapoints below / + above these values + + nz_thr : near-zero threshold - set to zero all datapoints with absolute + values smaller than this number. Default is a very small + number (1E-20), which unless your data has very small numbers, + will only mask out actual zeros. + + shade_kwargs : dictionary specifiying shading options + + Most relevant options (see matplotlib 'tripcolor' for full details): + + - 'shading' (either 'gourand' or omit; + default is 'flat') + - 'edgecolors' 'k' = black is probably best + - 'linewidth' 0.1 works well; note that the visual + effect of this will depend on both the + surface density and the figure size + - 'cmap' colormap + - 'vmin'/'vmax' scale colormap to these values + - 'alpha' surface opacity + + ax : figure axis + + figsize : figure size (ignore if ax provided) + + title : text string to place above figure + + + + + Usage + ----- + + + Basic freesurfer example: + + import nibabel as nib + vtx,tri = nib.freesurfer.read_geometry('subjects/fsaverage/surf/lh.orig') + plot_surface_mpl(vtx,tri,view='lh_lat',reorient='fs') + + + + Basic tvb example: + + ctx = cortex.Cortex.from_file(source_file = ctx_file, + region_mapping_file =rm_file) + vtx,tri,rm = ctx.vertices,ctx.triangles,ctx.region_mapping + conn = connectivity.Connectivity.from_file(conn_file); conn.configure() + isrh_reg = conn.is_right_hemisphere(range(conn.number_of_regions)) + isrh_vtx = np.array([isrh_reg[r] for r in rm]) + dat = conn.tract_lengths[:,5] + + plot_surface_mpl(vtx=vtx,tri=tri,rm=rm,data=dat,view='inferior',title='inferior') + + fig, ax = plt.subplots() + plot_surface_mpl(vtx=vtx,tri=tri,rm=rm,data=dat, view=[(0,-90),(1,55)],ax=ax, + title='lh angle',shade_kwargs={'shading': 'gouraud', 'cmap': 'rainbow'}) + + + """ + + # Copy things to make sure we don't modify things + # in the namespace inadvertently. + + vtx,tri = vtx.copy(),tri.copy() + if data is not None: data = data.copy() + + # 1. Set the viewing angle + + if reorient == 'tvb': + # The tvb default brain has coordinates in the order + # yxz for some reason. 
So first change that: + vtx = np.array([vtx[:,1],vtx[:,0],vtx[:,2]]).T.copy() + + # Also need to reflect in the x axis + vtx[:,0]*=-1 + + # (reorient == 'fs' is same as reorient=None; so not strictly needed + # but is included for clarity) + + + + # ...get rotations for standard view options + + if view == 'lh_lat' : rots = [(0,-90),(1,90) ] + elif view == 'lh_med' : rots = [(0,-90),(1,-90) ] + elif view == 'rh_lat' : rots = [(0,-90),(1,-90) ] + elif view == 'rh_med' : rots = [(0,-90),(1,90) ] + elif view == 'superior' : rots = None + elif view == 'inferior' : rots = (1,180) + elif view == 'anterior' : rots = (0,-90) + elif view == 'posterior' : rots = [(0, -90),(1,180)] + elif (type(view) == tuple) or (type(view) == list): rots = view + + # (rh_lat is the default 'view' argument because no rotations are + # for that one; so if no view is specified when the function is called, + # the 'rh_lat' option is chose here and the surface is shown 'as is' + + + # ...apply rotations + + if rots is None: rotmat = np.eye(3) + else: rotmat = get_combined_rotation_matrix(rots) + vtx = np.dot(vtx,rotmat) + + + + # 2. Sort out the data + + + # ...if no data is given, plot a vector of 1s. + # if using region data, create corresponding surface vector + if data is None: + data = np.ones(vtx.shape[0]) + elif data.shape[0] != vtx.shape[0]: + data = np.array([data[r] for r in rm]) + + # ...apply thresholds + if uthr: data *= (data < uthr) + if lthr: data *= (data > lthr) + data *= (np.abs(data) > nz_thr) + + + # 3. Create the surface triangulation object + + x,y,z = vtx.T + tx,ty,tz = vtx[tri].mean(axis=1).T + tr = Triangulation(x,y,tri[np.argsort(tz)]) + + # 4. Make the figure + + if ax is None: fig, ax = plt.subplots(figsize=figsize) + + #if shade = 'gouraud': shade_opts['shade'] = + tc = ax.tripcolor(tr, np.squeeze(data), **shade_kwargs) + + ax.set_aspect('equal') + ax.axis('off') + + if title is not None: ax.set_title(title) + return tc +#@title function for plotting surface + +def plot_sim_states_outputs(ts, output): + """ + Plot the simulation states of trained input parameters. 
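+
+    (`ts` is the empirical BOLD array with shape num_TR x node_size; `output`
+     is the results dictionary returned by Model_fitting.train() or
+     Model_fitting.test() in rww_pytorch_model_updated.py, whose 'simBOLD',
+     'E', 'I', 'x', 'f', 'v' and 'q' entries are plotted below.)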
+ + Parameters + ---------- + ts_sim: tensor with node_size X datapoint + simulated BOLD + ts: tensor with node_size X datapoint + empirical BOLD + E_sim: tensor with node_size X datapoint + simulated E + I_sim: tensor with node_size X datapoint + simulated I + x_sim: tensor with node_size X datapoint + simulated x + f_sim: tensor with node_size X datapoint + simulated f + v_sim: tensor with node_size X datapoint + simulated v + q_sim: tensor with node_size X datapoint + simulated q + """ + ts_sim = output['simBOLD'] + E_sim = output['E'] + I_sim = output['I'] + x_sim = output['x'] + f_sim = output['f'] + v_sim = output['v'] + q_sim = output['q'] + fig, ax = plt.subplots(5, 2, figsize=(12,8)) + im1 = ax[0,0].imshow(np.corrcoef(ts_sim), cmap = 'bwr') + ax[0,0].set_title('simFC') + fig.colorbar(im1, ax=ax[0,0]) + im2 = ax[0,1].imshow(np.corrcoef(ts.T), cmap = 'bwr') + ax[0,1].set_title('expFC') + fig.colorbar(im2, ax=ax[0,1]) + ax[1,0].plot(ts_sim.T) + ax[1,0].set_title('simBOLD') + ax[1,1].plot(ts) + ax[1,1].set_title('expBOLD') + ax[2,0].plot(E_sim.T) + ax[2,0].set_title('sim E') + ax[2,1].plot(I_sim.T) + ax[2,1].set_title('sim I') + ax[3,0].plot(x_sim.T) + ax[3,0].set_title('sim x') + ax[3,1].plot(f_sim.T) + ax[3,1].set_title('sim f') + ax[4,0].plot(v_sim.T) + ax[4,0].set_title('sim v') + ax[4,1].plot(q_sim.T) + ax[4,1].set_title('sim q') + plt.show() + +def plot_fit_parameters(output): + g_par = output['g'] + gEE_par = output['gEE'] + gIE_par = output['gIE'] + gEI_par = output['gEI'] + g_mean_par = output['gmean'] + g_var_par = output['gvar'] + cA_par = output['cA'] + cB_par = output['cB'] + cC_par = output['cC'] + sigma_par = output['sigma_state'] + sigma_out_par = output['sigma_bold'] + """ + Plot the simulation states of fitted input parameters. + + Parameters + ---------- + g_par: list of fitted parameter values + for g + gEE_par: list of fitted parameter values + for gEE + gIE_par: list of fitted parameter values + for gIE + gEI_par: list of fitted parameter values + for gEI + sc_par: list of fitted parameter values + for structural connectivity + sc_par: list of fitted parameter values + for sigma + """ + fig, ax = plt.subplots(6,2, figsize=(12,8)) + im1 = ax[0,0].plot(g_par) + ax[0,0].set_title('g') + + ax[0,1].plot(gEE_par) + ax[0,1].set_title('gEE') + + ax[1,0].plot(gIE_par) + ax[1,0].set_title('gIE') + ax[1,1].plot(gEI_par) + ax[1,1].set_title('gEI') + + ax[2,0].plot(sigma_par) + ax[2,0].set_title('sc') + + ax[2,1].plot(sigma_out_par) + ax[2,1].set_title('σ') + ax[3,0].plot(g_mean_par) + ax[3,0].set_title('post mean: g') + + ax[3,1].plot(g_var_par) + ax[3,1].set_title('post var: g') + + ax[4,0].plot(cA_par) + ax[4,0].set_title('post poly:A') + + ax[4,1].plot(cB_par) + ax[4,1].set_title('post poly:B') + + ax[5,0].plot(cC_par) + ax[5,0].set_title('post poly:C') + # @title plot_sim_states_outputs(ts, output) and plot_fit_parameters(output) + +# @title function for plotting surface +def plot_sim_states_outputs(ts, output): + """ + Plot the simulation states of trained input parameters. 
+ + Parameters + ---------- + ts_sim: tensor with node_size X datapoint + simulated BOLD + ts: tensor with node_size X datapoint + empirical BOLD + E_sim: tensor with node_size X datapoint + simulated E + I_sim: tensor with node_size X datapoint + simulated I + x_sim: tensor with node_size X datapoint + simulated x + f_sim: tensor with node_size X datapoint + simulated f + v_sim: tensor with node_size X datapoint + simulated v + q_sim: tensor with node_size X datapoint + simulated q + """ + ts_sim = output['simBOLD'] + E_sim = output['E'] + I_sim = output['I'] + x_sim = output['x'] + f_sim = output['f'] + v_sim = output['v'] + q_sim = output['q'] + fig, ax = plt.subplots(5, 2, figsize=(12,8)) + im1 = ax[0,0].imshow(np.corrcoef(ts_sim), cmap = 'bwr') + ax[0,0].set_title('simFC') + fig.colorbar(im1, ax=ax[0,0]) + im2 = ax[0,1].imshow(np.corrcoef(ts.T), cmap = 'bwr') + ax[0,1].set_title('expFC') + fig.colorbar(im2, ax=ax[0,1]) + ax[1,0].plot(ts_sim.T) + ax[1,0].set_title('simBOLD') + ax[1,1].plot(ts) + ax[1,1].set_title('expBOLD') + ax[2,0].plot(E_sim.T) + ax[2,0].set_title('sim E') + ax[2,1].plot(I_sim.T) + ax[2,1].set_title('sim I') + ax[3,0].plot(x_sim.T) + ax[3,0].set_title('sim x') + ax[3,1].plot(f_sim.T) + ax[3,1].set_title('sim f') + ax[4,0].plot(v_sim.T) + ax[4,0].set_title('sim v') + ax[4,1].plot(q_sim.T) + ax[4,1].set_title('sim q') + plt.show() + +def plot_fit_parameters(output): + g_par = output['g'] + gEE_par = output['gEE'] + gIE_par = output['gIE'] + gEI_par = output['gEI'] + g_mean_par = output['gmean'] + g_var_par = output['gvar'] + cA_par = output['cA'] + cB_par = output['cB'] + cC_par = output['cC'] + sigma_par = output['sigma_state'] + sigma_out_par = output['sigma_bold'] + """ + Plot the simulation states of fitted input parameters. + + Parameters + ---------- + g_par: list of fitted parameter values + for g + gEE_par: list of fitted parameter values + for gEE + gIE_par: list of fitted parameter values + for gIE + gEI_par: list of fitted parameter values + for gEI + sc_par: list of fitted parameter values + for structural connectivity + sc_par: list of fitted parameter values + for sigma + """ + fig, ax = plt.subplots(6,2, figsize=(12,8)) + im1 = ax[0,0].plot(g_par) + ax[0,0].set_title('g') + + ax[0,1].plot(gEE_par) + ax[0,1].set_title('gEE') + + ax[1,0].plot(gIE_par) + ax[1,0].set_title('gIE') + ax[1,1].plot(gEI_par) + ax[1,1].set_title('gEI') + + ax[2,0].plot(sigma_par) + ax[2,0].set_title('sc') + + ax[2,1].plot(sigma_out_par) + ax[2,1].set_title('$\sigma$') + ax[3,0].plot(g_mean_par) + ax[3,0].set_title('post mean: g') + + ax[3,1].plot(g_var_par) + ax[3,1].set_title('post var: g') + + ax[4,0].plot(cA_par) + ax[4,0].set_title('post poly:A') + + ax[4,1].plot(cB_par) + ax[4,1].set_title('post poly:B') + + ax[5,0].plot(cC_par) + ax[5,0].set_title('post poly:C') + +def R_stat(data): + """ + Calculate the Gelman-Rubin convergence statistic, R-hat. + + Parameters + ---------- + data: input data dictionary for parameter value + """ + ###B step 1 variance of mean of m chains + num_data = data.shape[1] + num_param = data.shape[2] + num_chain = data.shape[0] + var_chs = num_data* data.mean(1).std(0)**2 + + ### step 2. 
the average of variance of m chains W + m_var_chs = (data.std(1)**2).mean(0) + + ### step 3 target of mean: mean of 4*datapoints + m_target = data.mean(1).mean(0) + + #### step 4 estimate of target variance + v_target = (num_data-1.0)/num_data*m_var_chs + 1.0/num_data*var_chs + + #### step 5 + V_hat = v_target + var_chs/num_chain/num_data + v_var = (data.std(1)**2).std(0)**2 + v_V_hat = (num_data/(num_data- 1.0))**2/num_chain*v_var + ((num_chain+1.0)/num_chain/num_data)**2*2/(num_chain -1.0)*var_chs**2 +\ + 2*(num_chain + 1.0 )*(num_data- 1.0)/num_chain/num_data/num_chain*(np.diag(np.cov((data.std(1)**2).T,\ + (data.mean(1)**2).T)[:num_param,:][:,num_param:])\ + - 2*m_target*np.diag(np.cov((data.std(1)**2).T, data.mean(1).T)[:num_param,:][:,num_param:])) + df = 2.0*V_hat**2/v_V_hat + + ### R + R = V_hat/m_var_chs*df/(df-2.0) + + return R diff --git a/whobpyt/depr/griffiths2022/rww_pytorch_model_updated.py b/whobpyt/depr/griffiths2022/rww_pytorch_model_updated.py new file mode 100644 index 00000000..a8fc1477 --- /dev/null +++ b/whobpyt/depr/griffiths2022/rww_pytorch_model_updated.py @@ -0,0 +1,676 @@ +# -*- coding: utf-8 -*- +""" +Authors: Zheng Wang, John Griffiths, Hussain Ather +WongWangDeco Model fitting +module for forward model (wwd) to simulate a batch of BOLD signals +input: noises, updated model parameters and current state (6) +outputs: updated current state and a batch of BOLD signals +""" +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.optim as optim +from torch.nn.parameter import Parameter +""" +WWD_Model_tf.py +This file performs the modelling and graph-building for the Wong-Wang neuronal mass model +in determining network activity. +This large-scale dynamic mean field model (DMF) approximates the average ensemble +behavior, instead of considering the detailed interactions between individual +neurons. This allow us varying the parameters and, furthermore, to greatly +simplify the system of stochastic differential equations by expressing it +in terms of the first and second-order statistical moments, means and +covariances, of network activity. In the following we present the +details of the model and its simplified versions. +""" +def h_tf(a, b, d, z): + """ + Neuronal input-output functions of excitatory pools and inhibitory pools. + + Take the variables a, x, and b and convert them to a linear equation (a*x - b) while adding a small + amount of noise 0.00001 while dividing that term to an exponential of the linear equation multiplied by the + d constant for the appropriate dimensions. 
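+
+    Up to the small stabilising constants, this is the Wong-Wang-Deco transfer
+    function
+        H(z) = (a*z - b) / (1 - exp(-d*(a*z - b))),
+    with the 1e-5 terms and the absolute values keeping the ratio (and its
+    gradient) well defined at a*z - b = 0, where it takes its limiting value 1/d.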
+ """ + num = 0.00001 + torch.abs(a * z - b) + den = 0.00001 * d + torch.abs(1.0000 - torch.exp(-d * (a * z - b))) + return torch.divide(num, den) +class RNNWWD(torch.nn.Module): + """ + A module for forward model (WWD) to simulate a batch of BOLD signals + + Attibutes + --------- + state_size : int + the number of states in the WWD model + input_size : int + the number of states with noise as input + tr : float + tr of fMRI image + step_size: float + Integration step for forward model + hidden_size: int + the number of step_size in a tr + batch_size: int + the number of BOLD signals to simulate + node_size: int + the number of ROIs + sc: float node_size x node_size array + structural connectivity + fit_gains: bool + flag for fitting gains 1: fit 0: not fit + g, g_EE, gIE, gEI: tensor with gradient on + model parameters to be fit + w_bb: tensor with node_size x node_size (grad on depends on fit_gains) + connection gains + std_in std_out: tensor with gradient on + std for state noise and output noise + g_m g_v sup_ca sup_cb sup_cc: tensor with gradient on + hyper parameters for prior distribution of g gIE and gEI + Methods + ------- + forward(input, noise_out, hx) + forward model (WWD) for generating a number of BOLD signals with current model parameters + """ + def __init__(self, input_size: int, node_size: int, + batch_size: int, step_size: float, tr: float, sc: float, fit_gains: bool, \ + g_mean_ini=100, g_std_ini = 2.5, gEE_mean_ini=2.5, gEE_std_ini = .5) -> None: + """ + Parameters + ---------- + state_size : int + the number of states in the WWD model + input_size : int + the number of states with noise as input + tr : float + tr of fMRI image + step_size: float + Integration step for forward model + hidden_size: int + the number of step_size in a tr + batch_size: int + the number of BOLD signals to simulate + node_size: int + the number of ROIs + sc: float node_size x node_size array + structural connectivity + fit_gains: bool + flag for fitting gains 1: fit 0: not fit + g_mean_ini: float, optional + prior mean of g (default 100) + g_std_ini: float, optional + prior std of g (default 2.5) + gEE_mean_ini: float, optional + prior mean of gEE (default 2.5) + gEE_std_ini: float, optional + prior std of gEE (default 0.5) + """ + super(RNNWWD, self).__init__() + self.state_size = 6 # 6 states WWD model + self.input_size = input_size # 1 or 2 + self.tr = tr # tr fMRI image + self.step_size = torch.tensor(step_size , dtype=torch.float32) # integration step 0.05 + self.hidden_size = np.int64(tr//step_size) + self.batch_size = batch_size # size of the batch used at each step + self.node_size = node_size # num of ROI + self.sc = sc # matrix node_size x node_size structure connectivity + self.fit_gains = fit_gains # flag for fitting gains + + # model parameters (variables: need to calculate gradient) + # fix initials + self.g_EI = Parameter(torch.tensor(.1, dtype=torch.float32)) # local gain E to I + self.g_IE = Parameter(torch.tensor(.1, dtype=torch.float32)) # local gain I to E # in the model we fix gII = 1.0 (local gain I to I) + self.std_in = Parameter(torch.tensor(0.05, dtype=torch.float32)) # noise for dynamics + self.std_out = Parameter(torch.tensor(0.02, dtype=torch.float32)) # noise for output + # define initals for g and gEE + if g_std_ini == 0: + self.g = Parameter(torch.tensor(g_mean_ini, dtype=torch.float32)) # global gain + else: + g_hi = g_mean_ini+3*g_std_ini + g_lo = g_mean_ini-3*g_std_ini + if g_lo < 0: + g_lo = 0 + self.g = 
Parameter(torch.tensor(np.random.uniform(g_lo,g_hi,1), dtype=torch.float32)) # global gain + if gEE_std_ini == 0: + self.g_EE = Parameter(torch.tensor(gEE_mean_ini, dtype=torch.float32)) # global gain + else: + gEE_hi = gEE_mean_ini+3*gEE_std_ini + gEE_lo = gEE_mean_ini-3*gEE_std_ini + if gEE_lo < 0: + gEE_lo = 0 + self.g_EE = Parameter(torch.tensor(np.random.uniform(gEE_lo,gEE_hi,1), dtype=torch.float32)) # local gain E to E + # hyper parameters (variables: need to calculate gradient) to fit density + # of gEI and gIE (the shape from the bifurcation analysis on an isloated node) + self.sup_ca = Parameter(torch.tensor(0.5, dtype=torch.float32)) + self.sup_cb = Parameter(torch.tensor(20, dtype=torch.float32)) + self.sup_cc = Parameter(torch.tensor(10, dtype=torch.float32)) + # gEE + self.g_m = Parameter(torch.tensor(g_mean_ini, dtype=torch.float32)) + self.g_v = Parameter(torch.tensor(1/g_std_ini, dtype=torch.float32)) + self.g_EE_m = Parameter(torch.tensor(gEE_mean_ini, dtype=torch.float32)) + self.g_EE_v = Parameter(torch.tensor(1/gEE_std_ini, dtype=torch.float32)) + if self.fit_gains == True: + self.w_bb = Parameter(torch.tensor(np.zeros((node_size,node_size)) + 0.05, dtype=torch.float32)) # connenction gain to modify empirical sc + else: + self.w_bb = torch.tensor(np.ones((node_size,node_size)), dtype=torch.float32) + # Parameters for the ODEs ( constants for pytorch: no need for gradient calculation) + # Excitatory population + self.W_E = torch.tensor(1., dtype=torch.float32) # scale of the external input + self.tau_E = torch.tensor(100., dtype=torch.float32) # decay time + self.gamma_E = torch.tensor(0.641/1000.,dtype=torch.float32) # other dynamic parameter (?) + # Inhibitory population + self.W_I = torch.tensor(0.7, dtype= torch.float32) # scale of the external input + self.tau_I = torch.tensor(10., dtype=torch.float32) # decay time + self.gamma_I = torch.tensor(1./1000., dtype=torch.float32) # other dynamic parameter (?) 
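+        # (tau_E/tau_I above are the decay time constants and gamma_E/gamma_I the
+        #  kinetic rate constants of the excitatory/inhibitory synaptic gating
+        #  variables of the Wong-Wang mean-field model; the values are the commonly
+        #  used dynamic mean-field defaults.)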
+ # External input + self.I_0 = torch.tensor(0.2, dtype=torch.float32) # external input + #self.sigma = torch.tensor(0.02, dtype=torch.float32) # noise std + # parameters for the nonlinear function of firing rate + self.aE = torch.tensor(310, dtype=torch.float32) + self.bE = torch.tensor(125, dtype=torch.float32) + self.dE = torch.tensor(0.16, dtype=torch.float32) + self.aI = torch.tensor(615, dtype=torch.float32) + self.bI = torch.tensor(177, dtype=torch.float32) + self.dI = torch.tensor(0.087, dtype=torch.float32) + # parameters for Balloon dynamics (BOLD signal) + self.alpha = torch.tensor(0.32, dtype=torch.float32) # stiffness component + self.rho = torch.tensor(0.34, dtype=torch.float32) # resting oxygen extraction fraction + self.k1 = torch.tensor(2.38, dtype=torch.float32) + self.k2 = torch.tensor(2.0, dtype=torch.float32) + self.k3 = torch.tensor(0.48, dtype=torch.float32) + self.V = torch.tensor(.02, dtype=torch.float32) # resting blood volume fraction + self.E0 = torch.tensor(0.34, dtype=torch.float32) # the capillary bed + self.tau_s = torch.tensor(0.65, dtype=torch.float32) # time constant of decay + self.tau_f = torch.tensor(0.41, dtype=torch.float32) # time constant of autoregulation + self.tau_0 = torch.tensor(0.98, dtype=torch.float32) # mean transit time through balloon at rest + + """def check_input(self, input: Tensor) -> None: + expected_input_dim = 2 + if input.dim() != expected_input_dim: + raise RuntimeError( + 'input must have {} dimensions, got {}'.format( + expected_input_dim, input.dim())) + if self.input_size != input.size(-1): + raise RuntimeError( + 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format( + self.input_size, input.size(-1))) + + if self.batch_size != input.size(0): + raise RuntimeError( + 'input.size(0) must be equal to batch_size. Expected {}, got {}'.format( + self.batch_size, input.size(0)))""" + + + def forward(self, input, noise_out, hx): + """ + Forward step in simulating the BOLD signal. + Parameters + ---------- + input: tensor with node_size x hidden_size x batch_size x input_size + noise for states + noise_out: tensor with node_size x batch_size + noise for BOLD + hx: tensor with node_size x state_size + states of WWD model + Outputs + ------- + next_state: dictionary with keys: + 'current_state''bold_batch''E_batch''I_batch''x_batch''f_batch''v_batch''q_batch' + record new states and BOLD + """ + next_state = {} + # hx is current state (6) 0: E 1:I (neural activitiveties) 2:x 3:f 4:v 5:f (BOLD) + + E = hx[:,0:1] + I = hx[:,1:2] + x = hx[:,2:3] + f = hx[:,3:4] + v = hx[:,4:5] + q = hx[:,5:6] + dt = self.step_size + con_1 = torch.ones_like(dt) + # Update the Laplacian based on the updated connection gains w_bb. 
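+        # (The effective connectivity is sc scaled elementwise by 1 + tanh(w_bb),
+        #  then symmetrised and normalised by its matrix norm; l_s = w_n - diag(rowsum(w_n))
+        #  is the negative graph Laplacian, so g * (l_s @ E) below implements the
+        #  long-range excitatory coupling.)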
+ w = (1.0+torch.tanh(self.w_bb))*torch.tensor(self.sc, dtype=torch.float32)#(con_1 + torch.tanh(self.w_bb))*torch.tensor(self.sc, dtype=torch.float32) + w_n = 0.5*(w + torch.transpose(w, 0, 1))/torch.linalg.norm(0.5*(w + torch.transpose(w, 0, 1))) + self.sc_m = w_n + l_s = -torch.diag(torch.sum(w_n, axis =1)) + w_n + + # placeholder for the updated corrent state + current_state = torch.zeros_like(hx) + + # Generate the ReLU module for model parameters gEE gEI and gIE + m_gEE = torch.nn.ReLU() + m_gIE = torch.nn.ReLU() + m_gEI = torch.nn.ReLU() + m_std_in = torch.nn.ReLU() + m_std_out = torch.nn.ReLU() + + + # placeholders for output BOLD, history of E I x f v and q + bold_batch = [] + E_batch = [] + I_batch = [] + x_batch = [] + f_batch = [] + v_batch = [] + q_batch = [] + + # Use the forward model to get BOLD signal at ith element in the batch. + for i_batch in range(self.batch_size): + # Get the noise for BOLD output. + noiseBold = noise_out[:,i_batch:i_batch+1] + # Generate the ReLU module for model parameter m_cBOLD (used for calculating the magnitude of the BOLD signal). + # The rectified linear activation function or ReLU for short is a piecewise linear function that + # will output the input directly if it is positive, otherwise, it will output zero. + m_cBOLD = torch.nn.ReLU() + + # Since tr is about second we need to use a small step size like 0.05 to integrate the model states. + for i_hidden in range(self.hidden_size): + + # Generate ReLU module for input recurrent IE and II. + m_IE = torch.nn.ReLU() + m_II = torch.nn.ReLU() + + # Input noise for E and I. + noiseE = input[:,i_hidden, i_batch,0:1] + noiseI = input[:,i_hidden, i_batch,1:2] + + # Calculate the input recurrents. + IE = torch.tanh(m_IE(self.W_E*self.I_0 + (0.001*con_1 + m_gEE(self.g_EE))*E \ + + self.g*torch.matmul(l_s, E) - (0.001*con_1 + m_gIE(self.g_IE))*I)) # input currents for E + II = torch.tanh(m_II(self.W_I*self.I_0 + (0.001*con_1 + m_gIE(self.g_EI))*E -I)) # input currents for I + + # Calculate the firing rates. + rE = h_tf(self.aE, self.bE, self.dE, IE) # firing rate for E + rI = h_tf(self.aI, self.bI, self.dI, II) # firing rate for I + # Update the states by step-size 0.05. + ddE = E + dt*(-E*torch.reciprocal(self.tau_E) +self.gamma_E*(1.-E)*rE) \ + + torch.sqrt(dt)*noiseE*(0.02*con_1 + m_std_in(self.std_in))### equlibrim point at E=(tau_E*gamma_E*rE)/(1+tau_E*gamma_E*rE) + ddI = I + dt*(-I*torch.reciprocal(self.tau_I) +self.gamma_I*rI) \ + + torch.sqrt(dt)*noiseI * (0.02*con_1 + m_std_in(self.std_in)) + + dx = x + dt*(E - torch.reciprocal(self.tau_s) * x - torch.reciprocal(self.tau_f)* (f - con_1)) + df = f + dt*x + dv = v + dt*(f - torch.pow(v, torch.reciprocal(self.alpha))) * torch.reciprocal(self.tau_0) + dq = q + dt*(f * (con_1 - torch.pow(con_1 - self.rho, torch.reciprocal(f)))*torch.reciprocal(self.rho) \ + - q * torch.pow(v, torch.reciprocal(self.alpha)) *torch.reciprocal(v)) \ + * torch.reciprocal(self.tau_0) + + # Calculate the saturation for model states (for stability and gradient calculation). + E = torch.tanh(0.00001+torch.nn.functional.relu(ddE)) + I = torch.tanh(0.00001+torch.nn.functional.relu(ddI)) + x = torch.tanh(dx) + f = (con_1 + torch.tanh(df - con_1)) + v = (con_1 + torch.tanh(dv - con_1)) + q = (con_1 + torch.tanh(dq - con_1)) + + # Put each time step E and I into placeholders (need them to calculate Entropy of E and I + # in the loss (maximize the entropy)). + E_batch.append(E) + I_batch.append(I) + + # Put x f v q from each tr to the placeholders for checking them visually. 
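+            # (x, f, v and q are the Balloon-model haemodynamic states: vasodilatory
+            #  signal, blood inflow, venous volume and deoxyhaemoglobin content.)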
+ x_batch.append(x) + f_batch.append(f) + v_batch.append(v) + q_batch.append(q) + # Put the BOLD signal each tr to the placeholder being used in the cost calculation. + bold_batch.append((0.001*con_1 + m_std_out(self.std_out))*noiseBold + \ + 100.0*self.V*torch.reciprocal(self.E0)*(self.k1 * (con_1 - q) \ + + self.k2 * (con_1 - q *torch.reciprocal(v)) + self.k3 * (con_1 - v)) ) + + # Update the current state. + current_state = torch.cat([E, I, x, f, v, q], axis = 1) + next_state['current_state'] = current_state + next_state['bold_batch'] = torch.cat(bold_batch, axis=1) + next_state['E_batch'] = torch.cat(E_batch, axis=1) + next_state['I_batch'] = torch.cat(I_batch, axis=1) + next_state['x_batch'] = torch.cat(x_batch, axis=1) + next_state['f_batch'] = torch.cat(f_batch, axis=1) + next_state['v_batch'] = torch.cat(v_batch, axis=1) + next_state['q_batch'] = torch.cat(q_batch, axis=1) + return next_state +def cost_r(logits_series_tf, labels_series_tf): + """ + Calculate the Pearson Correlation between the simFC and empFC. + From there, the probability and negative log-likelihood. + Parameters + ---------- + logits_series_tf: tensor with node_size X datapoint + simulated BOLD + labels_series_tf: tensor with node_size X datapoint + empirical BOLD + """ + # get node_size(batch_size) and batch_size() + node_size = logits_series_tf.shape[0] + truncated_backprop_length = logits_series_tf.shape[1] + + # remove mean across time + labels_series_tf_n = labels_series_tf - torch.reshape(torch.mean(labels_series_tf, 1), [node_size, 1])# - torch.matmul( + + logits_series_tf_n = logits_series_tf - torch.reshape(torch.mean(logits_series_tf, 1), [node_size, 1])#- torch.matmul( + + # correlation + cov_sim = torch.matmul(logits_series_tf_n, torch.transpose(logits_series_tf_n, 0, 1)) + cov_def = torch.matmul(labels_series_tf_n, torch.transpose(labels_series_tf_n, 0, 1)) + + # fc for sim and empirical BOLDs + FC_sim_T = torch.matmul(torch.matmul(torch.diag(torch.reciprocal(torch.sqrt(\ + torch.diag(cov_sim)))), cov_sim), + torch.diag(torch.reciprocal(torch.sqrt(torch.diag(cov_sim))))) + FC_T = torch.matmul(torch.matmul(torch.diag(torch.reciprocal(torch.sqrt(\ + torch.diag(cov_def)))), cov_def), + torch.diag(torch.reciprocal(torch.sqrt(torch.diag(cov_def))))) + + # mask for lower triangle without diagonal + ones_tri = torch.tril(torch.ones_like(FC_T), -1) + zeros = torch.zeros_like(FC_T) # create a tensor all ones + mask = torch.greater(ones_tri, zeros) # boolean tensor, mask[i] = True iff x[i] > 1 + # mask out fc to vector with elements of the lower triangle + FC_tri_v = torch.masked_select(FC_T, mask) + FC_sim_tri_v = torch.masked_select(FC_sim_T, mask) + + # remove the mean across the elements + FC_v = FC_tri_v - torch.mean(FC_tri_v) + FC_sim_v = FC_sim_tri_v - torch.mean(FC_sim_tri_v) + + # corr_coef + corr_FC =torch.sum(torch.multiply(FC_v,FC_sim_v))\ + *torch.reciprocal(torch.sqrt(torch.sum(torch.multiply(FC_v,FC_v))))\ + *torch.reciprocal(torch.sqrt(torch.sum(torch.multiply(FC_sim_v,FC_sim_v)))) + + # use surprise: corr to calculate probability and -log + losses_corr = -torch.log(0.5000 + 0.5*corr_FC) #torch.mean((FC_v -FC_sim_v)**2)# + return losses_corr +class Model_fitting(): + """ + Using ADAM and AutoGrad to fit WWD to empirical BOLD + Attributes + ---------- + model: instance of class RNNWWD + forward model WWD + ts: array with num_tr x node_size + empirical BOLD time-series + num_epoches: int + the times for repeating trainning + Methods: + train() + train model + test() + using the optimal model 
parater to simulate the BOLD + + """ + def __init__(self, model_wwd, ts, num_epoches): + """ + Parameters + ---------- + model: instance of class RNNWWD + forward model WWD + ts: array with num_tr x node_size + empirical BOLD time-series + num_epoches: int + the times for repeating trainning + """ + self.model = model_wwd + self.num_epoches = num_epoches + if ts.shape[1] != model_wwd.node_size: + print('ts is a matrix with the number of datapoint X the number of node') + else: + self.ts = ts + def train(self): + """ + Parameters + ---------- + None + Outputs: + output_train: dictionary with keys: + 'simBOLD_train''E_train''I_train''x_train''f_train''v_train''q_train' + 'gains''g''gEE''gIE''gEI''sigma_state''sigma_bold' + record states, BOLD, history of model parameters and loss + """ + output_train = {} + # define an optimizor(ADAM) + optimizer = optim.Adam(self.model.parameters(), lr=0.01, eps=1e-7) + # initial state + X = torch.tensor(0.45 * np.random.uniform(0, 1, (self.model.node_size, self.model.state_size)) + np.array( + [0, 0, 0, 1.0, 1.0, 1.0]), dtype=torch.float32) + # placeholders for model parameters + g_par = [] + sigma_par = [] + sigma_out_par = [] + gEE_par = [] + gIE_par = [] + gEI_par = [] + sc_par = [] + g_mean_par = [] + g_var_par = [] + cA_par = [] + cB_par = [] + cC_par = [] + loss_his =[] + + # define mask for geting lower triangle matrix + mask = np.tril_indices(self.model.node_size, -1) + # get initial values of the model parameters + sc_par.append(self.model.sc[mask].copy()) + g_par.append(self.model.g.detach().numpy().copy()) + sigma_par.append(self.model.std_in.detach().numpy().copy()) + sigma_out_par.append(self.model.std_out.detach().numpy().copy()) + gEE_par.append(self.model.g_EE.detach().numpy().copy()) + gIE_par.append(self.model.g_IE.detach().numpy().copy()) + gEI_par.append(self.model.g_EI.detach().numpy().copy()) + g_mean_par.append(self.model.g_m.detach().numpy().copy()) + g_var_par.append(self.model.g_v.detach().numpy().copy()) + cA_par.append(self.model.sup_ca.detach().numpy().copy()) + cB_par.append(self.model.sup_cb.detach().numpy().copy()) + cC_par.append(self.model.sup_cc.detach().numpy().copy()) + + + # define constant 1 tensor + + con_1 = torch.tensor(1.0, dtype=torch.float32) + # define num_batches + num_batches = self.ts.shape[0] // self.model.batch_size + for i_epoch in range(self.num_epoches): + + # Create placeholders for the simulated BOLD E I x f and q of entire time series. + bold_sim_train = [] + E_sim_train = [] + I_sim_train = [] + x_sim_train = [] + f_sim_train = [] + v_sim_train = [] + q_sim_train = [] + + # Perform the training in batches. + + for i_batch in range(num_batches): + + # Generate the ReLU module for hyper parameters: ca cb cc and model parameters gEI and gIE. + m_ca = torch.nn.ReLU() + m_cb = torch.nn.ReLU() + m_cc = torch.nn.ReLU() + m_gIE = torch.nn.ReLU() + m_gEI = torch.nn.ReLU() + m_g = torch.nn.ReLU() + m_g_m = torch.nn.ReLU() + m_g_v = torch.nn.ReLU() + + # Reset the gradient to zeros after update model parameters. + optimizer.zero_grad() + # Initialize the placeholder for the next state. + X_next = torch.zeros_like(X) + # Get the input and output noises for the module. 
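+                # (noise_in has shape node_size x hidden_size x batch_size x input_size,
+                #  one standard-normal sample per node, integration step and input channel;
+                #  noise_out has shape node_size x batch_size, one sample per TR for the
+                #  BOLD readout.)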
+ noise_in = torch.tensor(np.random.randn(self.model.node_size, self.model.hidden_size, \ + self.model.batch_size, self.model.input_size), dtype=torch.float32) + noise_out = torch.tensor(np.random.randn(self.model.node_size, self.model.batch_size), dtype=torch.float32) + + # Use the model.forward() function to update next state and get simulated BOLD in this batch. + next_batch = self.model(noise_in, noise_out, X) + E_batch=next_batch['E_batch'] + I_batch=next_batch['I_batch'] + + # Get the batch of emprical BOLD signal. + ts_batch = torch.tensor((self.ts[i_batch*self.model.batch_size:(i_batch+1)*self.model.batch_size,:]).T, dtype=torch.float32) + # Get the loss by comparing the simulated batch of BOLD against the batch of empircal BOLD, entropy of + # the batch of simulated E I, and prior density of gEI and gIE. + loss = cost_r(next_batch['bold_batch'], ts_batch) + 5.0*((0.001 + m_ca(self.model.sup_ca))\ + *(0.001*con_1 + m_gIE(self.model.g_IE))**2 \ + - (0.001*con_1 + m_cb(self.model.sup_cb))*(0.001*con_1 + m_gIE(self.model.g_IE)) \ + +(0.001*con_1 + m_cc(self.model.sup_cc)) -(0.001*con_1 + m_gEI(self.model.g_EI)))**2\ + + 0.1*torch.mean(torch.mean(E_batch*torch.log(E_batch) + (con_1 - E_batch)*torch.log(con_1 - E_batch)\ + + 0.5*I_batch*torch.log(I_batch) + 0.5*(con_1 - I_batch)*torch.log(con_1 - I_batch), axis=1))\ + + (0.001*con_1 + m_g_v(self.model.g_v))*(m_g(self.model.g) - m_g_m(self.model.g_m))**2\ + +torch.log(0.001*con_1 + m_g_v(self.model.g_v))\ + + (0.001*con_1 + m_g_v(self.model.g_EE_v))*(m_g(self.model.g_EE) - m_g_m(self.model.g_EE_m))**2\ + +torch.log(0.001*con_1 + m_g_v(self.model.g_EE_v)) + + # Put the batch of the simulated BOLD, E I x f v q in to placeholders for entire time-series. + bold_sim_train.append(next_batch['bold_batch'].detach().numpy()) + E_sim_train.append(next_batch['E_batch'].detach().numpy()) + I_sim_train.append(next_batch['I_batch'].detach().numpy()) + x_sim_train.append(next_batch['x_batch'].detach().numpy()) + f_sim_train.append(next_batch['f_batch'].detach().numpy()) + v_sim_train.append(next_batch['v_batch'].detach().numpy()) + q_sim_train.append(next_batch['q_batch'].detach().numpy()) + loss_his.append(loss.detach().numpy()) + + # Calculate gradient using backward (backpropagation) method of the loss function. + loss.backward(retain_graph=True) + + # Optimize the model based on the gradient method in updating the model parameters. + optimizer.step() + + # Put the updated model parameters into the history placeholders. + sc_par.append(self.model.w_bb.detach().numpy()[mask].copy()) + g_par.append(self.model.g.detach().numpy().copy()) + sigma_par.append(self.model.std_in.detach().numpy().copy()) + sigma_out_par.append(self.model.std_out.detach().numpy().copy()) + gEE_par.append(self.model.g_EE.detach().numpy().copy()) + gIE_par.append(self.model.g_IE.detach().numpy().copy()) + gEI_par.append(self.model.g_EI.detach().numpy().copy()) + g_mean_par.append(self.model.g_m.detach().numpy().copy()) + g_var_par.append(self.model.g_v.detach().numpy().copy()) + cA_par.append(self.model.sup_ca.detach().numpy().copy()) + cB_par.append(self.model.sup_cb.detach().numpy().copy()) + cC_par.append(self.model.sup_cc.detach().numpy().copy()) + + # last update current state using next state... 
(no direct use X = X_next, since gradient calculation only depends on one batch no history) + X = torch.tensor(next_batch['current_state'].detach().numpy(), dtype=torch.float32) + fc = np.corrcoef(self.ts.T) + ts_sim = np.concatenate(bold_sim_train, axis=1) + E_sim = np.concatenate(E_sim_train, axis=1) + I_sim = np.concatenate(I_sim_train, axis=1) + x_sim = np.concatenate(x_sim_train, axis=1) + f_sim = np.concatenate(f_sim_train, axis=1) + v_sim = np.concatenate(v_sim_train, axis=1) + q_sim = np.concatenate(q_sim_train, axis=1) + fc_sim = np.corrcoef(ts_sim[:, 10:]) + print('epoch: ', i_epoch, np.corrcoef(fc_sim[mask], fc[mask])[0, 1]) + output_train['simBOLD'] = ts_sim + output_train['E'] = E_sim + output_train['I'] = I_sim + output_train['x'] = x_sim + output_train['f'] = f_sim + output_train['v'] = v_sim + output_train['q'] = q_sim + output_train['gains'] = np.array(sc_par) + output_train['g'] = np.array(g_par) + output_train['gEE'] = np.array(gEE_par) + output_train['gIE'] = np.array(gIE_par) + output_train['gEI'] = np.array(gEI_par) + output_train['sigma_state'] = np.array(sigma_par) + output_train['sigma_bold'] = np.array(sigma_out_par) + output_train['gmean'] = np.array(g_mean_par) + output_train['gvar'] = np.array(g_var_par) + output_train['cA'] = np.array(cA_par) + output_train['cB'] = np.array(cB_par) + output_train['cC'] = np.array(cC_par) + output_train['loss'] = np.array(loss_his) + + return output_train + def test(self, num_batches, **kwargs): + """ + Parameters + ---------- + num_batches: int + length of simBOLD = batch_size x num_batches + g, gEE, gIE, gEI and gains if have + for values of model parameters + Outputs: + output_test: dictionary with keys: + 'simBOLD_test''E_test''I_test''x_test''f_test''v_test''q_test' + """ + g= kwargs.get('g', None) + gEE= kwargs.get('gEE', None) + gIE= kwargs.get('gIE', None) + gEI= kwargs.get('gEI', None) + gains = kwargs.get('gains', None) + #print(g) + if g is not None: + self.model.g.data = torch.tensor(g, dtype=torch.float32).data + if gIE is not None: + self.model.g_IE.data = torch.tensor(gIE, dtype=torch.float32).data + if gEI is not None: + self.model.g_EI.data = torch.tensor(gEI, dtype=torch.float32).data + if gEE is not None: + self.model.g_EE.data = torch.tensor(gEE, dtype=torch.float32).data + if gains is not None: + self.model.w_bb.data = torch.tensor(gains, dtype=torch.float32).data + output_test = {} + mask = np.tril_indices(self.model.node_size, -1) + # initial state + X = torch.tensor(0.45 * np.random.uniform(0, 1, (self.model.node_size, self.model.state_size)) + np.array( + [0, 0, 0, 1.0, 1.0, 1.0]), dtype=torch.float32) + + # Create placeholders for the simulated BOLD E I x f and q of entire time series. + bold_sim_test = [] + E_sim_test = [] + I_sim_test = [] + x_sim_test = [] + f_sim_test = [] + v_sim_test = [] + q_sim_test = [] + + # Perform the training in batches. + + for i_batch in range(num_batches+2): + + # Get the input and output noises for the module. + noise_in = torch.tensor(np.random.randn(self.model.node_size, self.model.hidden_size, \ + self.model.batch_size, self.model.input_size), dtype=torch.float32) + noise_out = torch.tensor(np.random.randn(self.model.node_size, self.model.batch_size), dtype=torch.float32) + + # Use the model.forward() function to update next state and get simulated BOLD in this batch. + next_batch = self.model(noise_in, noise_out, X) + + if i_batch >= 2: + # Put the batch of the simulated BOLD, E I x f v q in to placeholders for entire time-series. 
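+                    # (The first two batches are run but discarded as burn-in, letting the
+                    #  states settle from their random initialisation before BOLD is
+                    #  recorded; hence the loop over num_batches + 2 above.)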
+ bold_sim_test.append(next_batch['bold_batch'].detach().numpy()) + E_sim_test.append(next_batch['E_batch'].detach().numpy()) + I_sim_test.append(next_batch['I_batch'].detach().numpy()) + x_sim_test.append(next_batch['x_batch'].detach().numpy()) + f_sim_test.append(next_batch['f_batch'].detach().numpy()) + v_sim_test.append(next_batch['v_batch'].detach().numpy()) + q_sim_test.append(next_batch['q_batch'].detach().numpy()) + + # last update current state using next state... (no direct use X = X_next, since gradient calculation only depends on one batch no history) + X = next_batch['current_state'] + fc = np.corrcoef(self.ts.T) + ts_sim = np.concatenate(bold_sim_test, axis=1) + E_sim = np.concatenate(E_sim_test, axis=1) + I_sim = np.concatenate(I_sim_test, axis=1) + x_sim = np.concatenate(x_sim_test, axis=1) + f_sim = np.concatenate(f_sim_test, axis=1) + v_sim = np.concatenate(v_sim_test, axis=1) + q_sim = np.concatenate(q_sim_test, axis=1) + fc_sim = np.corrcoef(ts_sim) + print(np.corrcoef(fc_sim[mask], fc[mask])[0, 1]) + output_test['simBOLD'] = ts_sim + output_test['E'] = E_sim + output_test['I'] = I_sim + output_test['x'] = x_sim + output_test['f'] = f_sim + output_test['v'] = v_sim + output_test['q'] = q_sim + return output_test diff --git a/whobpyt/depr/griffiths2022/wwd_test.py b/whobpyt/depr/griffiths2022/wwd_test.py new file mode 100644 index 00000000..1da0cb5f --- /dev/null +++ b/whobpyt/depr/griffiths2022/wwd_test.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +"""WWD_test.ipynb + +Automatically generated by Colab. + +Original file is located at + https://colab.research.google.com/drive/1W-k8P5xZaopOKcoItFsYjOjsv5IIvo7- +""" +import matplotlib.pyplot as plt +import numpy as np +import time + +class WWD_test(): + def __init__(self, G, gEE, gIE, gEI, Ws, step_size, node_size, Tr): + self.G = G + self.gEE = gEE + self.gIE = gIE + self.gEI = gEI + L_s = -np.diag(np.sum(Ws, axis= 1)) + Ws + self.L_s = L_s + self.dt = step_size + self.Tr = Tr + self.node_size = node_size + self.hidden_size = np.int64(Tr//step_size) + X0 = np.random.uniform(0,1, (node_size,6)) + X0[:,3:] = 1.0 + X0[:,:3] + self.X = X0 + + def forward(self): + """ + Forward step in generating the BOLD signal. + """ + + def smooth_normalize_ct(x, center): + """ + Normalize the centers when smoothing + """ + return center+ (center-0.001)*np.tanh((x-center)/(center - 0.001)) + + def smooth_normalize(x): + """ + Normalize small values to 0.000001. + """ + x[x< 0.000001] = 0.000001 + return x + + def sigmoid(a, b, d, x): + """ + Sigmoid function for linearizing the current. + """ + return (a*x-b)/(1.0000-np.exp(-d*(a*x-b))) + + I0 = 0.2 + gamma = 0.641/1000. + gammaI = 1.0/1000. + + aE = 310 + bE = 125 + dE_0 = 0.16 + WE = 1.0 #18.4576 #1.0 + + aI = 615 + bI = 177 + dI_0 = 0.087 + WI = 0.7 + + tauE = 100. + tauI = 10. + + W1 = 0.02 + E = self.X[:,0] + I = self.X[:,1] + q = self.X[:,5] + v = self.X[:,4] + f = self.X[:,3] + x = self.X[:,2] + + rho = 0.34 + tau_0 = 0.98 + alpha =0.32 + tau_s = 0.65 + tau_f = 0.41 + k = 1 + + def fout(v, k): + """ + Outflow using the Balloon Model + """ + return (k*v)**(1.0/alpha)/k + + def Ef(f, k): + """ + Energy function of the capillary bed + """ + return 1.0 - (1.0 - rho)**(1.0/f/k) + + IE = np.tanh(smooth_normalize(WE*I0 + self.gEE*E + self.G*np.dot(self.L_s, E) -self.gIE*I)) + II = np.tanh(smooth_normalize(WI*I0 + self.gEI*E - I)) + + # Calculate the differential values. 
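+        # (One Euler-Maruyama step of the Wong-Wang rate equations,
+        #      dE/dt = -E/tau_E + gamma  * (1 - E) * H(I_E) + noise
+        #      dI/dt = -I/tau_I + gammaI *           H(I_I) + noise,
+        #  followed by the Balloon-model haemodynamic states x, f, v and q.)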
+ dE = E - self.dt*(E)/tauE \ + + self.dt*(1.0-E)*gamma*sigmoid(aE, bE, dE_0, IE)\ + + W1*np.sqrt(self.dt)*np.random.randn(self.node_size) + dI = I - self.dt*I/tauI \ + + self.dt*gammaI*sigmoid(aI, bI, dI_0, II)+W1*np.sqrt(self.dt)*np.random.randn(self.node_size) + dx = x + self.dt*(E\ + - 1.0/tau_s*x \ + - 1.0/tau_f *(k*f-1)) + f_tmp = f + self.dt*1/k*x + dv = v + self.dt*(f/tau_0 - fout(v,k)/tau_0) + dq = q + self.dt*(f*Ef(f,k)/rho/tau_0\ + -q/v*fout(v,k)/tau_0) + df = f_tmp + x = np.tanh(dx) # smooth_normalize(dx, 0.5) + v = 1.0 + np.tanh(dv - 1.0) # smooth_normalize_ct(dv, tf.constant(1., dtype=tf.float32), tf.constant(0., dtype=tf.float32), tf.constant(1.0, dtype=tf.float32)) + q = 1.0 + np.tanh(dq - 1.0) # smooth_normalize_ct(dq, tf.constant(1., dtype=tf.float32), tf.constant(0., dtype=tf.float32), tf.constant(1.0, dtype=tf.float32)) + f = 1.0 + np.tanh(f_tmp - 1.0) + E = np.tanh(smooth_normalize(dE)) # 1.0 + np.tanh(dE - 1.0) #smooth_normalize_ct(dE, 1.0)#np.tanh(dE/200.0) + I = np.tanh(smooth_normalize(dI)) # 1.0 + np.tanh(dI - 1.0)#smooth_normalize_ct(dI, 1.0)#np.tanh(dI/200.0) + + self.X = np.array([E, I, x, f, v, q]).T + + def output(self): + """ + Output parameters when generating the BOLD signal. + """ + W2 = 0.02 + E0 = 0.34 + rho = 0.34 + k1 = 7*E0 + k2 = 2. + k3 = 2*E0-0.2 + V = 0.02 + k = 1 + q = self.X[:,4] + v = self.X[:,5] + + y= k1*(1-k*q)+k2*(1-q/v)+k3*(1-k*v) + + return 100.0/E0*V*(y)+W2*np.random.randn(self.node_size) + + def generate_bold(self, num_tr): + """ + Generate the BOLD signal using the hidden states and making steps forward. + """ + bold =[] + for j in range((20+ num_tr)*self.hidden_size): + self.forward() + if (j+1) % self.hidden_size == 0: + bold.append(self.output()) + return np.array(bold[20:])
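+
+
+# ---------------------------------------------------------------------------
+# Minimal usage sketch (illustrative only): the connectivity matrix and all
+# parameter values below are placeholders, not part of the original example.
+# ---------------------------------------------------------------------------
+if __name__ == "__main__":
+    node_size = 50
+
+    # Random symmetric stand-in for a structural connectivity matrix.
+    sc = np.random.uniform(0, 1, (node_size, node_size))
+    sc = 0.5 * (sc + sc.T)
+    np.fill_diagonal(sc, 0.0)
+    sc = sc / np.linalg.norm(sc)
+
+    # Forward-simulate 200 TRs of BOLD with an integration step of 0.05.
+    sim = WWD_test(G=50., gEE=2.5, gIE=1.0, gEI=0.5, Ws=sc,
+                   step_size=0.05, node_size=node_size, Tr=0.75)
+    bold = sim.generate_bold(num_tr=200)   # shape: num_tr x node_size
+
+    plt.imshow(np.corrcoef(bold.T), cmap='bwr')
+    plt.title('simulated FC')
+    plt.colorbar()
+    plt.show()
+
+    # The synthetic time series can then be passed to the PyTorch model fitter
+    # (assuming rww_pytorch_model_updated.py is importable from this directory):
+    # from rww_pytorch_model_updated import RNNWWD, Model_fitting
+    # wwd = RNNWWD(input_size=2, node_size=node_size, batch_size=20,
+    #              step_size=0.05, tr=0.75, sc=sc, fit_gains=True)
+    # fit = Model_fitting(wwd, bold, num_epoches=10)
+    # output_train = fit.train()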