diff --git a/hypyp/analyses.py b/hypyp/analyses.py index 09efd42..b80d901 100644 --- a/hypyp/analyses.py +++ b/hypyp/analyses.py @@ -19,6 +19,7 @@ import statsmodels.stats.multitest import copy from collections import namedtuple +from fractions import Fraction from typing import Union, List, Tuple import matplotlib.pyplot as plt from tqdm import tqdm @@ -715,7 +716,7 @@ def compute_conn_mvar(complex_signal: np.ndarray, mvar_params: dict, fit_mvar = mvar.fit(merged_signal[:, 0, 0, :][np.newaxis, ...]) is_stable = fit_mvar.stability() aug_signal = merged_signal[np.newaxis, ...] - counter += counter + counter += 1 else: @@ -929,7 +930,7 @@ def compute_nmPLV(data: np.ndarray, sampling_rate: int, r = np.mean(freq_range2)/np.mean(freq_range1) freq_range = [np.min(freq_range1), np.max(freq_range2)] - complex_signal = np.mean(compute_single_freq(data, sampling_rate, freq_range, **filter_options),3).squeeze() + complex_signal = np.mean(compute_single_freq(data, sampling_rate, freq_range),3).squeeze() n_epoch, n_ch, n_freq, n_samp = complex_signal.shape[1], complex_signal.shape[2], \ complex_signal.shape[3], complex_signal.shape[4] @@ -939,13 +940,13 @@ def compute_nmPLV(data: np.ndarray, sampling_rate: int, transpose_axes = (0, 1, 3, 2) phase = complex_signal / np.abs(complex_signal) - freqsn = freq_range - freqsm = [f * r for f in freqsn] - n_mult = (freqsn[0] + freqsm[0]) / (2 * freqsn[0]) - m_mult = (freqsm[0] + freqsn[0]) / (2 * freqsm[0]) + # Compute integer n:m ratio so that n*f1 = m*f2 + frac = Fraction(r).limit_denominator(10) + n, m = frac.numerator, frac.denominator - phase[:, :, :, :n_ch] = n_mult * phase[:, :, :, :n_ch] - phase[:, :, :, n_ch:] = m_mult * phase[:, :, :, n_ch:] + # Raise phases to integer powers for n:m coupling + phase[:, :, :, :n_ch] = phase[:, :, :, :n_ch] ** n + phase[:, :, :, n_ch:] = phase[:, :, :, n_ch:] ** m c = np.real(phase) s = np.imag(phase) @@ -1036,12 +1037,23 @@ def xwt(sig1: mne.Epochs, sig2: mne.Epochs, freqs: Union[int, np.ndarray], assert n_samples1 == n_samples2, "n_samples1 and n_samples2 should have the same number of samples." cross_sigs = np.zeros((n_chans1, n_chans2, n_epochs1, n_freqs, n_samples1), dtype=complex) * np.nan - wtcs = np.zeros((n_chans1, n_chans2, n_epochs1, n_freqs, n_samples1), dtype=complex) * np.nan + wtcs = np.zeros((n_chans1, n_chans2, n_epochs1, n_freqs, n_samples1)) * np.nan # Set the mother wavelet Ws = mne.time_frequency.tfr.morlet(sfreq, freqs, n_cycles=n_cycles, sigma=None, zero_mean=True) + # Wavelet scales in samples (Morlet scale-frequency relation) + scales = n_cycles * sfreq / (2 * np.pi * freqs) + + # Precompute Gaussian smoothing windows per frequency + smooth_wins = [] + for s in scales: + win_size = max(3, int(2 * s)) + win = scipy.signal.windows.gaussian(win_size, std=s) + win /= win.sum() + smooth_wins.append(win) + # Perform a continuous wavelet transform on all epochs of each signal for ind1, ch_label1 in enumerate(sig1.ch_names): for ind2, ch_label2 in enumerate(sig2.ch_names): @@ -1052,15 +1064,24 @@ def xwt(sig1: mne.Epochs, sig2: mne.Epochs, freqs: Union[int, np.ndarray], cur_sig2 = np.squeeze(sig2.get_data(mne.pick_channels(sig2.ch_names, [ch_label2]))) out2 = mne.time_frequency.tfr.cwt(cur_sig2, Ws, use_fft=True, mode='same', decim=1) - - # Compute cross-spectrum - wps1 = out1 * out1.conj() - wps2 = out2 * out2.conj() + + # Compute cross-spectrum and auto-spectra cross_sig = out1 * out2.conj() cross_sigs[ind1, ind2, :, :, :] = cross_sig - coh = (cross_sig) / (np.sqrt(wps1*wps2)) - abs_coh = np.abs(coh) - wtc = (abs_coh - np.min(abs_coh)) / (np.max(abs_coh) - np.min(abs_coh)) + wps1 = np.abs(out1) ** 2 + wps2 = np.abs(out2) ** 2 + + # Smooth in time with scale-dependent Gaussian window + # following Grinsted et al. (2004) + for fi, win in enumerate(smooth_wins): + wps1[:, fi, :] = np.apply_along_axis( + lambda x: np.convolve(x, win, mode='same'), -1, wps1[:, fi, :]) + wps2[:, fi, :] = np.apply_along_axis( + lambda x: np.convolve(x, win, mode='same'), -1, wps2[:, fi, :]) + cross_sig[:, fi, :] = np.apply_along_axis( + lambda x: np.convolve(x, win, mode='same'), -1, cross_sig[:, fi, :]) + + wtc = np.abs(cross_sig) ** 2 / (wps1 * wps2) wtcs[ind1, ind2, :, :, :] = wtc if mode == 'power': diff --git a/hypyp/prep.py b/hypyp/prep.py index 3479a41..3e7f051 100644 --- a/hypyp/prep.py +++ b/hypyp/prep.py @@ -401,7 +401,7 @@ def AR_local(cleaned_epochs_ICA: List[mne.Epochs], strategy: str = 'union', for subject_id, clean_epochs_subj in enumerate(cleaned_epochs_ICA): # per subj picks = mne.pick_types( - clean_epochs_subj[subject_id].info, + clean_epochs_subj.info, meg=False, eeg=True, stim=False, diff --git a/hypyp/stats.py b/hypyp/stats.py index e67186a..4c659d3 100644 --- a/hypyp/stats.py +++ b/hypyp/stats.py @@ -109,12 +109,7 @@ def statsCond(data: np.ndarray, epochs: mne.Epochs, n_permutations: int, alpha: tail=0, n_jobs=1) adj_p = mne.stats.fdr_correction(p_values, alpha=alpha, method='indep') - T_obs_plot = np.nan * np.ones_like(T_obs) - for c in adj_p[1]: - if c <= alpha: - i = np.where(adj_p[1] == c) - T_obs_plot[i] = T_obs[i] - T_obs_plot = np.nan_to_num(T_obs_plot) + T_obs_plot = np.where(adj_p[0], T_obs, 0) # retrieving sensor position pos = np.array([[0, 0]]) @@ -475,7 +470,7 @@ def statscondCluster(data: list, freqs_mean: list, ch_con_freq: scipy.sparse.csr dfd = np.sum([len(d) for d in data]) - len(data) # Denominator degrees of freedom if tail == 0: - threshold = f_dist.ppf(1 - alpha / 2, dfn, dfd) # 2-tailed F-test + threshold = f_dist.ppf(1 - alpha, dfn, dfd) else: threshold = None # One-tailed test uses MNE's default @@ -605,22 +600,28 @@ def statscluster(data: list, test: str, factor_levels: List[int], ch_con_freq: s if test == 'ind ttest': def stat_fun(*arg): return(scipy.stats.ttest_ind(arg[0], arg[1], equal_var=False)[0]) - threshold = alpha + df = len(data[0]) + len(data[1]) - 2 + p = alpha / 2 if tail == 0 else alpha + threshold = scipy.stats.t.ppf(1 - p, df) elif test == 'rel ttest': def stat_fun(*arg): return(scipy.stats.ttest_rel(arg[0], arg[1])[0]) - threshold = alpha + df = len(data[0]) - 1 + p = alpha / 2 if tail == 0 else alpha + threshold = scipy.stats.t.ppf(1 - p, df) elif test == 'f oneway': def stat_fun(*arg): return(scipy.stats.f_oneway(arg[0], arg[1])[0]) - threshold = alpha + dfn = len(data) - 1 + dfd = sum(len(d) for d in data) - len(data) + threshold = f_dist.ppf(1 - alpha, dfn, dfd) elif test == 'f multipleway': if max(factor_levels) > 2: correction = True else: correction = False def stat_fun(*arg): - return(mne.stats.f_mway_rm(np.swapaxes(args, 1, 0), + return(mne.stats.f_mway_rm(np.swapaxes(arg, 1, 0), factor_levels, effects='all', correction=correction, @@ -628,7 +629,7 @@ def stat_fun(*arg): threshold = mne.stats.f_threshold_mway_rm(n_subjects=data.shape[1], factor_levels=factor_levels, effects='all', - pvalue=0.05) + pvalue=alpha) # computing the cluster permutation t test Stat_obs, clusters, cluster_p_values, H0 = permutation_cluster_test(data, diff --git a/hypyp/utils.py b/hypyp/utils.py index 6fa051c..45aeae0 100644 --- a/hypyp/utils.py +++ b/hypyp/utils.py @@ -586,7 +586,7 @@ def fp(t, p): dotp = omega - coupling + noise_phase_level * np.random.randn(n_chan) / n_samp return dotp - p0 = 2 * np.pi * np.block([np.zeros(n_chan/2), np.zeros(n_chan/2) + np.random.rand(n_chan/2) + 0.5]) + p0 = 2 * np.pi * np.block([np.zeros(n_chan//2), np.zeros(n_chan//2) + np.random.rand(n_chan//2) + 0.5]) ans = solve_ivp(fun=fp, t_span=(tv[0], tv[-1]), y0=p0, t_eval=tv) phi = ans['y'].T % (2*np.pi)