Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions hypyp/analyses.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import statsmodels.stats.multitest
import copy
from collections import namedtuple
from fractions import Fraction
from typing import Union, List, Tuple
import matplotlib.pyplot as plt
from tqdm import tqdm
Expand Down Expand Up @@ -715,7 +716,7 @@ def compute_conn_mvar(complex_signal: np.ndarray, mvar_params: dict,
fit_mvar = mvar.fit(merged_signal[:, 0, 0, :][np.newaxis, ...])
is_stable = fit_mvar.stability()
aug_signal = merged_signal[np.newaxis, ...]
counter += counter
counter += 1

else:

Expand Down Expand Up @@ -929,7 +930,7 @@ def compute_nmPLV(data: np.ndarray, sampling_rate: int,

r = np.mean(freq_range2)/np.mean(freq_range1)
freq_range = [np.min(freq_range1), np.max(freq_range2)]
complex_signal = np.mean(compute_single_freq(data, sampling_rate, freq_range, **filter_options),3).squeeze()
complex_signal = np.mean(compute_single_freq(data, sampling_rate, freq_range),3).squeeze()

n_epoch, n_ch, n_freq, n_samp = complex_signal.shape[1], complex_signal.shape[2], \
complex_signal.shape[3], complex_signal.shape[4]
Expand All @@ -939,13 +940,13 @@ def compute_nmPLV(data: np.ndarray, sampling_rate: int,
transpose_axes = (0, 1, 3, 2)
phase = complex_signal / np.abs(complex_signal)

freqsn = freq_range
freqsm = [f * r for f in freqsn]
n_mult = (freqsn[0] + freqsm[0]) / (2 * freqsn[0])
m_mult = (freqsm[0] + freqsn[0]) / (2 * freqsm[0])
# Compute integer n:m ratio so that n*f1 = m*f2
frac = Fraction(r).limit_denominator(10)
n, m = frac.numerator, frac.denominator

phase[:, :, :, :n_ch] = n_mult * phase[:, :, :, :n_ch]
phase[:, :, :, n_ch:] = m_mult * phase[:, :, :, n_ch:]
# Raise phases to integer powers for n:m coupling
phase[:, :, :, :n_ch] = phase[:, :, :, :n_ch] ** n
phase[:, :, :, n_ch:] = phase[:, :, :, n_ch:] ** m

c = np.real(phase)
s = np.imag(phase)
Expand Down Expand Up @@ -1036,12 +1037,23 @@ def xwt(sig1: mne.Epochs, sig2: mne.Epochs, freqs: Union[int, np.ndarray],
assert n_samples1 == n_samples2, "n_samples1 and n_samples2 should have the same number of samples."

cross_sigs = np.zeros((n_chans1, n_chans2, n_epochs1, n_freqs, n_samples1), dtype=complex) * np.nan
wtcs = np.zeros((n_chans1, n_chans2, n_epochs1, n_freqs, n_samples1), dtype=complex) * np.nan
wtcs = np.zeros((n_chans1, n_chans2, n_epochs1, n_freqs, n_samples1)) * np.nan

# Set the mother wavelet
Ws = mne.time_frequency.tfr.morlet(sfreq, freqs,
n_cycles=n_cycles, sigma=None, zero_mean=True)

# Wavelet scales in samples (Morlet scale-frequency relation)
scales = n_cycles * sfreq / (2 * np.pi * freqs)

# Precompute Gaussian smoothing windows per frequency
smooth_wins = []
for s in scales:
win_size = max(3, int(2 * s))
win = scipy.signal.windows.gaussian(win_size, std=s)
win /= win.sum()
smooth_wins.append(win)

# Perform a continuous wavelet transform on all epochs of each signal
for ind1, ch_label1 in enumerate(sig1.ch_names):
for ind2, ch_label2 in enumerate(sig2.ch_names):
Expand All @@ -1052,15 +1064,24 @@ def xwt(sig1: mne.Epochs, sig2: mne.Epochs, freqs: Union[int, np.ndarray],
cur_sig2 = np.squeeze(sig2.get_data(mne.pick_channels(sig2.ch_names, [ch_label2])))
out2 = mne.time_frequency.tfr.cwt(cur_sig2, Ws, use_fft=True,
mode='same', decim=1)

# Compute cross-spectrum
wps1 = out1 * out1.conj()
wps2 = out2 * out2.conj()

# Compute cross-spectrum and auto-spectra
cross_sig = out1 * out2.conj()
cross_sigs[ind1, ind2, :, :, :] = cross_sig
coh = (cross_sig) / (np.sqrt(wps1*wps2))
abs_coh = np.abs(coh)
wtc = (abs_coh - np.min(abs_coh)) / (np.max(abs_coh) - np.min(abs_coh))
wps1 = np.abs(out1) ** 2
wps2 = np.abs(out2) ** 2

# Smooth in time with scale-dependent Gaussian window
# following Grinsted et al. (2004)
for fi, win in enumerate(smooth_wins):
wps1[:, fi, :] = np.apply_along_axis(
lambda x: np.convolve(x, win, mode='same'), -1, wps1[:, fi, :])
wps2[:, fi, :] = np.apply_along_axis(
lambda x: np.convolve(x, win, mode='same'), -1, wps2[:, fi, :])
cross_sig[:, fi, :] = np.apply_along_axis(
lambda x: np.convolve(x, win, mode='same'), -1, cross_sig[:, fi, :])

wtc = np.abs(cross_sig) ** 2 / (wps1 * wps2)
wtcs[ind1, ind2, :, :, :] = wtc

if mode == 'power':
Expand Down
2 changes: 1 addition & 1 deletion hypyp/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def AR_local(cleaned_epochs_ICA: List[mne.Epochs], strategy: str = 'union',

for subject_id, clean_epochs_subj in enumerate(cleaned_epochs_ICA): # per subj
picks = mne.pick_types(
clean_epochs_subj[subject_id].info,
clean_epochs_subj.info,
meg=False,
eeg=True,
stim=False,
Expand Down
25 changes: 13 additions & 12 deletions hypyp/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,7 @@ def statsCond(data: np.ndarray, epochs: mne.Epochs, n_permutations: int, alpha:
tail=0, n_jobs=1)
adj_p = mne.stats.fdr_correction(p_values, alpha=alpha, method='indep')

T_obs_plot = np.nan * np.ones_like(T_obs)
for c in adj_p[1]:
if c <= alpha:
i = np.where(adj_p[1] == c)
T_obs_plot[i] = T_obs[i]
T_obs_plot = np.nan_to_num(T_obs_plot)
T_obs_plot = np.where(adj_p[0], T_obs, 0)

# retrieving sensor position
pos = np.array([[0, 0]])
Expand Down Expand Up @@ -475,7 +470,7 @@ def statscondCluster(data: list, freqs_mean: list, ch_con_freq: scipy.sparse.csr
dfd = np.sum([len(d) for d in data]) - len(data) # Denominator degrees of freedom

if tail == 0:
threshold = f_dist.ppf(1 - alpha / 2, dfn, dfd) # 2-tailed F-test
threshold = f_dist.ppf(1 - alpha, dfn, dfd)
else:
threshold = None # One-tailed test uses MNE's default

Expand Down Expand Up @@ -605,30 +600,36 @@ def statscluster(data: list, test: str, factor_levels: List[int], ch_con_freq: s
if test == 'ind ttest':
def stat_fun(*arg):
return(scipy.stats.ttest_ind(arg[0], arg[1], equal_var=False)[0])
threshold = alpha
df = len(data[0]) + len(data[1]) - 2
p = alpha / 2 if tail == 0 else alpha
threshold = scipy.stats.t.ppf(1 - p, df)
elif test == 'rel ttest':
def stat_fun(*arg):
return(scipy.stats.ttest_rel(arg[0], arg[1])[0])
threshold = alpha
df = len(data[0]) - 1
p = alpha / 2 if tail == 0 else alpha
threshold = scipy.stats.t.ppf(1 - p, df)
elif test == 'f oneway':
def stat_fun(*arg):
return(scipy.stats.f_oneway(arg[0], arg[1])[0])
threshold = alpha
dfn = len(data) - 1
dfd = sum(len(d) for d in data) - len(data)
threshold = f_dist.ppf(1 - alpha, dfn, dfd)
elif test == 'f multipleway':
if max(factor_levels) > 2:
correction = True
else:
correction = False
def stat_fun(*arg):
return(mne.stats.f_mway_rm(np.swapaxes(args, 1, 0),
return(mne.stats.f_mway_rm(np.swapaxes(arg, 1, 0),
factor_levels,
effects='all',
correction=correction,
return_pvals=False)[0])
threshold = mne.stats.f_threshold_mway_rm(n_subjects=data.shape[1],
factor_levels=factor_levels,
effects='all',
pvalue=0.05)
pvalue=alpha)

# computing the cluster permutation t test
Stat_obs, clusters, cluster_p_values, H0 = permutation_cluster_test(data,
Expand Down
2 changes: 1 addition & 1 deletion hypyp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def fp(t, p):
dotp = omega - coupling + noise_phase_level * np.random.randn(n_chan) / n_samp
return dotp

p0 = 2 * np.pi * np.block([np.zeros(n_chan/2), np.zeros(n_chan/2) + np.random.rand(n_chan/2) + 0.5])
p0 = 2 * np.pi * np.block([np.zeros(n_chan//2), np.zeros(n_chan//2) + np.random.rand(n_chan//2) + 0.5])
ans = solve_ivp(fun=fp, t_span=(tv[0], tv[-1]), y0=p0, t_eval=tv)
phi = ans['y'].T % (2*np.pi)

Expand Down