-
Notifications
You must be signed in to change notification settings - Fork 1
Open
Description
def plot_channels_on_grid_time_perm_cluster(evoke_data, std_err_data, channels_subset, mat, sample_rate=2048, dec_factor=8, plot_x_dim=6, plot_y_dim=6, *, channel_to_index=None, time_offset=-1.0):
    """
    Plots evoked EEG/MEG data for a subset of channels on a grid, overlaying significance markers.

    Parameters:
    - evoke_data: mne.Evoked object
        The evoked data to be plotted. Its `.times` is used as the x-axis and `.data[row, :]` as the trace.
    - std_err_data:
        Object whose `.data[row, :]` holds the standard error of the evoked data; drawn as a shaded band
        around the trace. Rows are indexed with the same channel index as `evoke_data`.
    - channels_subset: list of str
        Channel names to plot, one per axis. Channels beyond `plot_x_dim * plot_y_dim` are not plotted
        (the channel list is zipped against the available axes).
    - mat: numpy.array
        Binary matrix indexed as `mat[row]`; samples equal to 1 are marked as significant.
    - sample_rate: float
        The original sampling rate of the data, in Hz.
    - dec_factor: int
        Decimation factor; significance sample indices are converted to seconds with
        `index / (sample_rate / dec_factor)`.
    - plot_x_dim: int, optional (default=6)
        The number of columns in the grid layout.
    - plot_y_dim: int, optional (default=6)
        The number of rows in the grid layout.
    - channel_to_index: dict, optional (keyword-only)
        Mapping from channel name to row index. When omitted it is built from `evoke_data.ch_names`,
        so the function no longer depends on a module-level `channel_to_index` global.
        NOTE(review): the same index is applied to `evoke_data`, `std_err_data` and `mat`; all three
        must share one channel order - confirm against the caller.
    - time_offset: float, optional (keyword-only, default=-1.0)
        Seconds added to the significance time axis (which starts at 0) to align it with
        `evoke_data.times`. The default reproduces the previously hard-coded shift of -1 s.

    Returns:
    - fig: matplotlib.figure.Figure object
        The figure containing the grid of plots.
    """
    if channel_to_index is None:
        # Derive the mapping from the data being plotted instead of relying on a global.
        channel_to_index = {name: idx for idx, name in enumerate(evoke_data.ch_names)}

    # Rows first, then columns, so plot_x_dim really is the number of columns.
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() always works.
    fig, axes = plt.subplots(plot_y_dim, plot_x_dim, figsize=(20, 12), squeeze=False)
    fig.suptitle("Channels with Significance Overlay")
    axes_flat = axes.flatten()

    effective_rate = sample_rate / dec_factor  # sampling rate after decimation
    times = evoke_data.times

    for channel, ax in zip(channels_subset, axes_flat):
        ch_idx = channel_to_index[channel]
        signal = evoke_data.data[ch_idx, :]
        stderr = std_err_data.data[ch_idx, :]
        significance = np.asarray(mat[ch_idx])

        ax.plot(times, signal)
        # Add the standard error shading
        ax.fill_between(times, signal - stderr, signal + stderr, alpha=0.2)

        # Overlay significance as short horizontal segments at the trace's maximum.
        significant_points = np.where(significance == 1)[0]
        if significant_points.size:
            seg_start = significant_points / effective_rate + time_offset
            # One vectorised call (a single LineCollection) instead of one artist per sample.
            # hlines needs y to have the same length as xmin/xmax.
            ax.hlines(y=np.full(seg_start.shape, np.max(signal)), xmin=seg_start, xmax=seg_start + 0.005, color='red', linewidth=1)

        ax.set_title(channel)

    fig.tight_layout()
    fig.subplots_adjust(top=0.95)
    return fig
The function above produces the same plot for the raw data as for the z-scored data.
The function below plots incorrectly indexed (i.e., the wrong) channels for the z-scored data:
def plot_channels_across_subjects(electrode_dict, data_dict, std_err_dict, mat_dict, channel_to_index_dict, plot_x_dim=6, plot_y_dim=6, sample_rate=2048, dec_factor=8, y_label="Amplitude", *, time_offset=-1.0):
    """
    Plots evoked EEG/MEG data across multiple subjects for a set of electrodes, organized into subplots.

    This is a generator: it fills a grid of `plot_y_dim * plot_x_dim` axes and yields each figure as
    soon as it is full, then yields the last (possibly partially filled) figure.

    Parameters:
    - electrode_dict: dict
        Keys are subjects, values are lists of electrodes to plot for each subject. Electrodes that are
        not present in `channel_to_index_dict[subject]` are skipped silently.
    - data_dict: dict
        Keys are subjects, values are the evoked data (`.times` and `.data[row, :]` are used).
    - std_err_dict: dict
        Keys are subjects, values hold the standard error in `.data[row, :]`.
    - mat_dict: dict
        Keys are subjects, values are the significance matrix, indexed as `mat[row]`; samples equal to 1
        are marked as significant.
    - channel_to_index_dict: dict
        Keys are subjects, values map channel names to row indices for that subject.
        NOTE(review): the index looked up here is applied unchanged to `data_dict`, `std_err_dict` and
        `mat_dict`; if this mapping was built from a different object or channel order than the data
        passed in, the wrong rows are plotted - confirm against the caller.
    - plot_x_dim: int, optional
        Number of columns in the grid layout.
    - plot_y_dim: int, optional
        Number of rows in the grid layout.
    - sample_rate: float
        Sampling rate of the data in Hz.
    - dec_factor: int
        Decimation factor; significance sample indices are converted to seconds with
        `index / (sample_rate / dec_factor)`.
    - y_label: str, optional
        Label for the y-axis.
    - time_offset: float, optional (keyword-only, default=-1.0)
        Seconds added to the significance time axis (which starts at 0) to align it with the evoked
        time axis. The default reproduces the previously hard-coded shift of -1 s.

    Yields:
    - (fig, fig_num): tuple of (matplotlib.figure.Figure, int)
        Each figure together with its 1-based sequence number. Nothing is yielded when no electrode
        matches, so callers never receive an empty figure.
    """
    channels_per_fig = plot_x_dim * plot_y_dim
    effective_rate = sample_rate / dec_factor  # sampling rate after decimation

    fig = None
    axes_flat = None
    plot_index = 0
    fig_num = 0

    for subject, electrodes in electrode_dict.items():
        for electrode in electrodes:
            if electrode not in channel_to_index_dict[subject]:
                continue

            if fig is None or plot_index >= channels_per_fig:
                if fig is not None:
                    # The current figure is full: finish it and hand it to the caller.
                    fig.tight_layout()
                    fig.subplots_adjust(top=0.95)
                    yield fig, fig_num
                # Figures are created lazily, only once there is something to draw.
                # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so flatten() always works.
                fig, axes = plt.subplots(plot_y_dim, plot_x_dim, figsize=(20, 12), squeeze=False)
                fig.suptitle("Channels Across Subjects with Significance Overlay")
                axes_flat = axes.flatten()
                plot_index = 0
                fig_num += 1

            ax = axes_flat[plot_index]
            ch_idx = channel_to_index_dict[subject][electrode]
            evoked = data_dict[subject]
            signal = evoked.data[ch_idx, :]
            stderr = std_err_dict[subject].data[ch_idx, :]
            significance = np.asarray(mat_dict[subject][ch_idx])

            ax.plot(evoked.times, signal)
            # Add the standard error shading
            ax.fill_between(evoked.times, signal - stderr, signal + stderr, alpha=0.2)

            # Overlay significance as short horizontal segments at the trace's maximum.
            significant_points = np.where(significance == 1)[0]
            if significant_points.size:
                seg_start = significant_points / effective_rate + time_offset
                # One vectorised call (a single LineCollection) instead of one artist per sample.
                # hlines needs y to have the same length as xmin/xmax.
                ax.hlines(y=np.full(seg_start.shape, np.max(signal)), xmin=seg_start, xmax=seg_start + 0.005, color='red', linewidth=1)

            ax.set_title(f"{subject}: {electrode}")
            ax.set_ylabel(y_label)
            plot_index += 1

    if fig is not None:
        fig.tight_layout()
        fig.subplots_adjust(top=0.95)
        yield fig, fig_num
Metadata
Metadata
Assignees
Labels
No labels