diff --git a/src/guppy/analysis/io_utils.py b/src/guppy/analysis/io_utils.py index 742ab3b..97f7ecf 100644 --- a/src/guppy/analysis/io_utils.py +++ b/src/guppy/analysis/io_utils.py @@ -6,17 +6,10 @@ import h5py import numpy as np -import pandas as pd - -logger = logging.getLogger(__name__) +from ..utils.utils import takeOnlyDirs -def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths) - set(removePaths)) +logger = logging.getLogger(__name__) # find files by ignoring the case sensitivity @@ -143,20 +136,6 @@ def get_coords(filepath, name, tsNew, removeArtifacts): # TODO: Make less redun return coords -def get_all_stores_for_combining_data(folderNames): - op = [] - for i in range(100): - temp = [] - match = r"[\s\S]*" + "_output_" + str(i) - for j in folderNames: - temp.append(re.findall(match, j)) - temp = sorted(list(np.concatenate(temp).flatten()), key=str.casefold) - if len(temp) > 0: - op.append(temp) - - return op - - # for combining data, reading storeslist file from both data and create a new storeslist array def check_storeslistfile(folderNames): storesList = np.array([[], []]) @@ -197,19 +176,6 @@ def get_control_and_signal_channel_names(storesList): return channels_arr -# function to read h5 file and make a dataframe from it -def read_Df(filepath, event, name): - event = event.replace("\\", "_") - event = event.replace("/", "_") - if name: - op = os.path.join(filepath, event + "_{}.h5".format(name)) - else: - op = os.path.join(filepath, event + ".h5") - df = pd.read_hdf(op, key="df", mode="r") - - return df - - def make_dir_for_cross_correlation(filepath): op = os.path.join(filepath, "cross_correlation_output") if not os.path.exists(op): diff --git a/src/guppy/analysis/psth_average.py b/src/guppy/analysis/psth_average.py index 664cc3d..5df8c87 100644 --- a/src/guppy/analysis/psth_average.py +++ b/src/guppy/analysis/psth_average.py @@ -10,10 +10,10 @@ from .io_utils import ( make_dir_for_cross_correlation, makeAverageDir, - read_Df, write_hdf5, ) from .psth_utils import create_Df_for_psth, getCorrCombinations +from ..utils.utils import read_Df logger = logging.getLogger(__name__) diff --git a/src/guppy/analysis/standard_io.py b/src/guppy/analysis/standard_io.py index d6dd9af..d0ade6d 100644 --- a/src/guppy/analysis/standard_io.py +++ b/src/guppy/analysis/standard_io.py @@ -331,3 +331,18 @@ def read_freq_and_amp_from_hdf5(filepath, name): df = pd.read_hdf(op, key="df", mode="r") return df + + +def write_transients_to_hdf5(filepath, name, z_score, ts, peaksInd): + event = f"transient_outputs_{name}" + write_hdf5(z_score, event, filepath, "z_score") + write_hdf5(ts, event, filepath, "timestamps") + write_hdf5(peaksInd, event, filepath, "peaksInd") + + +def read_transients_from_hdf5(filepath, name): + event = f"transient_outputs_{name}" + z_score = read_hdf5(event, filepath, "z_score") + ts = read_hdf5(event, filepath, "timestamps") + peaksInd = read_hdf5(event, filepath, "peaksInd") + return z_score, ts, peaksInd diff --git a/src/guppy/frontend/__init__.py b/src/guppy/frontend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/guppy/frontend/artifact_removal.py b/src/guppy/frontend/artifact_removal.py new file mode 100644 index 0000000..f626190 --- /dev/null +++ b/src/guppy/frontend/artifact_removal.py @@ -0,0 +1,75 @@ +import logging +import os + +import matplotlib.pyplot as plt +import numpy as np + +from ..visualization.preprocessing import visualize_control_signal_fit + +logger = logging.getLogger(__name__) + +# Only set matplotlib backend if not in CI environment +if not os.getenv("CI"): + plt.switch_backend("TKAgg") + + +class ArtifactRemovalWidget: + + def __init__(self, filepath, x, y1, y2, y3, plot_name, removeArtifacts): + self.coords = [] # List to store selected coordinates + self.filepath = filepath + self.plot_name = plot_name + + if (y1 == 0).all() == True: + y1 = np.zeros(x.shape[0]) + + coords_path = os.path.join(filepath, "coordsForPreProcessing_" + plot_name[0].split("_")[-1] + ".npy") + artifacts_have_been_removed = removeArtifacts and os.path.exists(coords_path) + name = os.path.basename(filepath) + self.fig, self.ax1, self.ax2, self.ax3 = visualize_control_signal_fit( + x, y1, y2, y3, plot_name, name, artifacts_have_been_removed + ) + + self.cid = self.fig.canvas.mpl_connect("key_press_event", self._on_key_press) + self.fig.canvas.mpl_connect("close_event", self._on_close) + + def _on_key_press(self, event): + """Handle key press events for artifact selection. + + Pressing 'space' draws a vertical line at the cursor position to mark artifact boundaries. + Pressing 'd' removes the most recently added line. + """ + if event.key == " ": + ix, iy = event.xdata, event.ydata + logger.info(f"x = {ix}, y = {iy}") + self.ax1.axvline(ix, c="black", ls="--") + self.ax2.axvline(ix, c="black", ls="--") + self.ax3.axvline(ix, c="black", ls="--") + + self.fig.canvas.draw() + + self.coords.append((ix, iy)) + + return self.coords + + elif event.key == "d": + if len(self.coords) > 0: + logger.info(f"x = {self.coords[-1][0]}, y = {self.coords[-1][1]}; deleted") + del self.coords[-1] + self.ax1.lines[-1].remove() + self.ax2.lines[-1].remove() + self.ax3.lines[-1].remove() + self.fig.canvas.draw() + + return self.coords + + def _on_close(self, _event): + """Handle figure close event by saving coordinates and cleaning up.""" + if self.coords and len(self.coords) > 0: + name_1 = self.plot_name[0].split("_")[-1] + np.save(os.path.join(self.filepath, "coordsForPreProcessing_" + name_1 + ".npy"), self.coords) + logger.info( + f"Coordinates file saved at {os.path.join(self.filepath, 'coordsForPreProcessing_'+name_1+'.npy')}" + ) + self.fig.canvas.mpl_disconnect(self.cid) + self.coords = [] diff --git a/src/guppy/frontend/frontend_utils.py b/src/guppy/frontend/frontend_utils.py new file mode 100644 index 0000000..572e798 --- /dev/null +++ b/src/guppy/frontend/frontend_utils.py @@ -0,0 +1,19 @@ +import logging +import socket +from random import randint + +logger = logging.getLogger(__name__) + + +def scanPortsAndFind(start_port=5000, end_port=5200, host="127.0.0.1"): + while True: + port = randint(start_port, end_port) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(0.001) # Set timeout to avoid long waiting on closed ports + result = sock.connect_ex((host, port)) + if result == 0: # If the connection is successful, the port is open + continue + else: + break + + return port diff --git a/src/guppy/frontend/input_parameters.py b/src/guppy/frontend/input_parameters.py new file mode 100644 index 0000000..1ae8313 --- /dev/null +++ b/src/guppy/frontend/input_parameters.py @@ -0,0 +1,385 @@ +import logging +import os + +import numpy as np +import pandas as pd +import panel as pn + +logger = logging.getLogger(__name__) + + +def checkSameLocation(arr, abspath): + # abspath = [] + for i in range(len(arr)): + abspath.append(os.path.dirname(arr[i])) + abspath = np.asarray(abspath) + abspath = np.unique(abspath) + if len(abspath) > 1: + logger.error("All the folders selected should be at the same location") + raise Exception("All the folders selected should be at the same location") + + return abspath + + +def getAbsPath(files_1, files_2): + arr_1, arr_2 = files_1.value, files_2.value + if len(arr_1) == 0 and len(arr_2) == 0: + logger.error("No folder is selected for analysis") + raise Exception("No folder is selected for analysis") + + abspath = [] + if len(arr_1) > 0: + abspath = checkSameLocation(arr_1, abspath) + else: + abspath = checkSameLocation(arr_2, abspath) + + abspath = np.unique(abspath) + if len(abspath) > 1: + logger.error("All the folders selected should be at the same location") + raise Exception("All the folders selected should be at the same location") + return abspath + + +class ParameterForm: + def __init__(self, *, template, folder_path): + self.template = template + self.folder_path = folder_path + self.styles = dict(background="WhiteSmoke") + self.setup_individual_parameters() + self.setup_group_parameters() + self.setup_visualization_parameters() + self.add_to_template() + + def setup_individual_parameters(self): + # Individual analysis components + self.mark_down_1 = pn.pane.Markdown( + """**Select folders for the analysis from the file selector below**""", width=600 + ) + + self.files_1 = pn.widgets.FileSelector(self.folder_path, name="folderNames", width=950) + + self.explain_modality = pn.pane.Markdown( + """ + **Data Modality:** Select the type of data acquisition system used for your recordings: + - **tdt**: Tucker-Davis Technologies system + - **csv**: Generic CSV format + - **doric**: Doric Photometry system + - **npm**: Neurophotometrics system + """, + width=600, + ) + + self.modality_selector = pn.widgets.Select( + name="Data Modality", value="tdt", options=["tdt", "csv", "doric", "npm"], width=320 + ) + + self.explain_time_artifacts = pn.pane.Markdown( + """ + - ***Number of cores :*** Number of cores used for analysis. Try to + keep it less than the number of cores in your machine. + - ***Combine Data? :*** Make this parameter ``` True ``` if user wants to combine + the data, especially when there is two different + data files for the same recording session.
+ - ***Isosbestic Control Channel? :*** Make this parameter ``` False ``` if user + does not want to use isosbestic control channel in the analysis.
+ - ***Eliminate first few seconds :*** It is the parameter to cut out first x seconds + from the data. Default is 1 seconds.
+ - ***Window for Moving Average filter :*** The filtering of signals + is done using moving average filter. Default window used for moving + average filter is 100 datapoints. Change it based on the requirement.
+ - ***Moving Window (transients detection) :*** Transients in the z-score + and/or \u0394F/F are detected using this moving window. + Default is 15 seconds. Change it based on the requirement.
+ - ***High Amplitude filtering threshold (HAFT) (transients detection) :*** High amplitude + events greater than x times the MAD above the median are filtered out. Here, x is + high amplitude filtering threshold. Default is 2. + - ***Transients detection threshold (TD Thresh):*** Peaks with local maxima greater than x times + the MAD above the median of the trace (after filtering high amplitude events) are detected + as transients. Here, x is transients detection threshold. Default is 3. + - ***Number of channels (Neurophotometrics only) :*** Number of + channels used while recording, when data files has no column names mentioning "Flags" + or "LedState". + - ***removeArtifacts? :*** Make this parameter ``` True``` if there are + artifacts and user wants to remove the artifacts. + - ***removeArtifacts method :*** Selecting ```concatenate``` will remove bad + chunks and concatenate the selected good chunks together. + Selecting ```replace with NaN``` will replace bad chunks with NaN + values. + """, + width=350, + ) + + self.timeForLightsTurnOn = pn.widgets.LiteralInput( + name="Eliminate first few seconds (int)", value=1, type=int, width=320 + ) + + self.isosbestic_control = pn.widgets.Select( + name="Isosbestic Control Channel? (bool)", value=True, options=[True, False], width=320 + ) + + self.numberOfCores = pn.widgets.LiteralInput(name="# of cores (int)", value=2, type=int, width=150) + + self.combine_data = pn.widgets.Select( + name="Combine Data? (bool)", value=False, options=[True, False], width=150 + ) + + self.computePsth = pn.widgets.Select( + name="z_score and/or \u0394F/F? (psth)", options=["z_score", "dff", "Both"], width=320 + ) + + self.transients = pn.widgets.Select( + name="z_score and/or \u0394F/F? (transients)", options=["z_score", "dff", "Both"], width=320 + ) + + self.plot_zScore_dff = pn.widgets.Select( + name="z-score plot and/or \u0394F/F plot?", + options=["z_score", "dff", "Both", "None"], + value="None", + width=320, + ) + + self.moving_wd = pn.widgets.LiteralInput( + name="Moving Window for transients detection (s) (int)", value=15, type=int, width=320 + ) + + self.highAmpFilt = pn.widgets.LiteralInput(name="HAFT (int)", value=2, type=int, width=150) + + self.transientsThresh = pn.widgets.LiteralInput(name="TD Thresh (int)", value=3, type=int, width=150) + + self.moving_avg_filter = pn.widgets.LiteralInput( + name="Window for Moving Average filter (int)", value=100, type=int, width=320 + ) + + self.removeArtifacts = pn.widgets.Select( + name="removeArtifacts? (bool)", value=False, options=[True, False], width=150 + ) + + self.artifactsRemovalMethod = pn.widgets.Select( + name="removeArtifacts method", value="concatenate", options=["concatenate", "replace with NaN"], width=150 + ) + + self.no_channels_np = pn.widgets.LiteralInput( + name="Number of channels (Neurophotometrics only)", value=2, type=int, width=320 + ) + + self.z_score_computation = pn.widgets.Select( + name="z-score computation Method", + options=["standard z-score", "baseline z-score", "modified z-score"], + value="standard z-score", + width=200, + ) + + self.baseline_wd_strt = pn.widgets.LiteralInput( + name="Baseline Window Start Time (s) (int)", value=0, type=int, width=200 + ) + self.baseline_wd_end = pn.widgets.LiteralInput( + name="Baseline Window End Time (s) (int)", value=0, type=int, width=200 + ) + + self.explain_z_score = pn.pane.Markdown( + """ + ***Note :***
+ - Details about z-score computation methods are explained in Github wiki.
+ - The details will make user understand what computation method to use for + their data.
+ - Baseline Window Parameters should be kept 0 unless you are using baseline
+ z-score computation method. The parameters are in seconds. + """, + width=580, + ) + + self.explain_nsec = pn.pane.Markdown( + """ + - ***Time Interval :*** To omit bursts of event timestamps, user defined time interval + is set so that if the time difference between two timestamps is less than this defined time + interval, it will be deleted for the calculation of PSTH. + - ***Compute Cross-correlation :*** Make this parameter ```True```, when user wants + to compute cross-correlation between PSTHs of two different signals or signals + recorded from different brain regions. + """, + width=580, + ) + + self.nSecPrev = pn.widgets.LiteralInput(name="Seconds before 0 (int)", value=-10, type=int, width=120) + + self.nSecPost = pn.widgets.LiteralInput(name="Seconds after 0 (int)", value=20, type=int, width=120) + + self.computeCorr = pn.widgets.Select( + name="Compute Cross-correlation (bool)", options=[True, False], value=False, width=200 + ) + + self.timeInterval = pn.widgets.LiteralInput(name="Time Interval (s)", value=2, type=int, width=120) + + self.use_time_or_trials = pn.widgets.Select( + name="Bin PSTH trials (str)", options=["Time (min)", "# of trials"], value="Time (min)", width=120 + ) + + self.bin_psth_trials = pn.widgets.LiteralInput( + name="Time(min) / # of trials \n for binning? (int)", value=0, type=int, width=200 + ) + + self.explain_baseline = pn.pane.Markdown( + """ + ***Note :***
+ - If user does not want to do baseline correction, + put both parameters 0.
+ - If the first event timestamp is less than the length of baseline + window, it will be rejected in the PSTH computation step.
+ - Baseline parameters must be within the PSTH parameters + set in the PSTH parameters section. + """, + width=580, + ) + + self.baselineCorrectionStart = pn.widgets.LiteralInput( + name="Baseline Correction Start time(int)", value=-5, type=int, width=200 + ) + + self.baselineCorrectionEnd = pn.widgets.LiteralInput( + name="Baseline Correction End time(int)", value=0, type=int, width=200 + ) + + self.zscore_param_wd = pn.WidgetBox( + "### Z-score Parameters", + self.explain_z_score, + self.z_score_computation, + pn.Row(self.baseline_wd_strt, self.baseline_wd_end), + width=600, + ) + + self.psth_param_wd = pn.WidgetBox( + "### PSTH Parameters", + self.explain_nsec, + pn.Row(self.nSecPrev, self.nSecPost, self.computeCorr), + pn.Row(self.timeInterval, self.use_time_or_trials, self.bin_psth_trials), + width=600, + ) + + self.baseline_param_wd = pn.WidgetBox( + "### Baseline Parameters", + self.explain_baseline, + pn.Row(self.baselineCorrectionStart, self.baselineCorrectionEnd), + width=600, + ) + self.peak_explain = pn.pane.Markdown( + """ + ***Note :***
+ - Peak and area are computed between the window set below.
+ - Peak and AUC parameters must be within the PSTH parameters set in the PSTH parameters section.
+ - Please make sure when user changes the parameters in the table below, click on any other cell after + changing a value in a particular cell. + """, + width=580, + ) + + self.start_end_point_df = pd.DataFrame( + { + "Peak Start time": [-5, 0, 5, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + "Peak End time": [0, 3, 10, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + } + ) + + self.df_widget = pn.widgets.Tabulator(self.start_end_point_df, name="DataFrame", show_index=False, widths=280) + + self.peak_param_wd = pn.WidgetBox("### Peak and AUC Parameters", self.peak_explain, self.df_widget, width=600) + + self.individual_analysis_wd_2 = pn.Column( + self.explain_time_artifacts, + pn.Row(self.numberOfCores, self.combine_data), + self.isosbestic_control, + self.timeForLightsTurnOn, + self.moving_avg_filter, + self.computePsth, + self.transients, + self.plot_zScore_dff, + self.moving_wd, + pn.Row(self.highAmpFilt, self.transientsThresh), + self.no_channels_np, + pn.Row(self.removeArtifacts, self.artifactsRemovalMethod), + ) + + self.psth_baseline_param = pn.Column( + self.zscore_param_wd, self.psth_param_wd, self.baseline_param_wd, self.peak_param_wd + ) + + self.widget = pn.Column( + self.mark_down_1, + self.files_1, + self.explain_modality, + self.modality_selector, + pn.Row(self.individual_analysis_wd_2, self.psth_baseline_param), + ) + self.individual = pn.Card(self.widget, title="Individual Analysis", styles=self.styles, width=1000) + + def setup_group_parameters(self): + self.mark_down_2 = pn.pane.Markdown( + """**Select folders for the average analysis from the file selector below**""", width=600 + ) + + self.files_2 = pn.widgets.FileSelector(self.folder_path, name="folderNamesForAvg", width=950) + + self.averageForGroup = pn.widgets.Select( + name="Average Group? (bool)", value=False, options=[True, False], width=435 + ) + + self.group_analysis_wd_1 = pn.Column(self.mark_down_2, self.files_2, self.averageForGroup, width=800) + self.group = pn.Card(self.group_analysis_wd_1, title="Group Analysis", styles=self.styles, width=1000) + + def setup_visualization_parameters(self): + self.visualizeAverageResults = pn.widgets.Select( + name="Visualize Average Results? (bool)", value=False, options=[True, False], width=435 + ) + + self.visualize_zscore_or_dff = pn.widgets.Select( + name="z-score or \u0394F/F? (for visualization)", options=["z_score", "dff"], width=435 + ) + + self.visualization_wd = pn.Row(self.visualize_zscore_or_dff, pn.Spacer(width=60), self.visualizeAverageResults) + self.visualize = pn.Card( + self.visualization_wd, title="Visualization Parameters", styles=self.styles, width=1000 + ) + + def add_to_template(self): + self.template.main.append(self.individual) + self.template.main.append(self.group) + self.template.main.append(self.visualize) + + def getInputParameters(self): + abspath = getAbsPath(self.files_1, self.files_2) + inputParameters = { + "abspath": abspath[0], + "folderNames": self.files_1.value, + "modality": self.modality_selector.value, + "numberOfCores": self.numberOfCores.value, + "combine_data": self.combine_data.value, + "isosbestic_control": self.isosbestic_control.value, + "timeForLightsTurnOn": self.timeForLightsTurnOn.value, + "filter_window": self.moving_avg_filter.value, + "removeArtifacts": self.removeArtifacts.value, + "artifactsRemovalMethod": self.artifactsRemovalMethod.value, + "noChannels": self.no_channels_np.value, + "zscore_method": self.z_score_computation.value, + "baselineWindowStart": self.baseline_wd_strt.value, + "baselineWindowEnd": self.baseline_wd_end.value, + "nSecPrev": self.nSecPrev.value, + "nSecPost": self.nSecPost.value, + "computeCorr": self.computeCorr.value, + "timeInterval": self.timeInterval.value, + "bin_psth_trials": self.bin_psth_trials.value, + "use_time_or_trials": self.use_time_or_trials.value, + "baselineCorrectionStart": self.baselineCorrectionStart.value, + "baselineCorrectionEnd": self.baselineCorrectionEnd.value, + "peak_startPoint": list(self.df_widget.value["Peak Start time"]), # startPoint.value, + "peak_endPoint": list(self.df_widget.value["Peak End time"]), # endPoint.value, + "selectForComputePsth": self.computePsth.value, + "selectForTransientsComputation": self.transients.value, + "moving_window": self.moving_wd.value, + "highAmpFilt": self.highAmpFilt.value, + "transientsThresh": self.transientsThresh.value, + "plot_zScore_dff": self.plot_zScore_dff.value, + "visualize_zscore_or_dff": self.visualize_zscore_or_dff.value, + "folderNamesForAvg": self.files_2.value, + "averageForGroup": self.averageForGroup.value, + "visualizeAverageResults": self.visualizeAverageResults.value, + } + return inputParameters diff --git a/src/guppy/frontend/npm_gui_prompts.py b/src/guppy/frontend/npm_gui_prompts.py new file mode 100644 index 0000000..a6f2077 --- /dev/null +++ b/src/guppy/frontend/npm_gui_prompts.py @@ -0,0 +1,104 @@ +import logging +import tkinter as tk +from tkinter import StringVar, messagebox, ttk + +logger = logging.getLogger(__name__) + + +def get_multi_event_responses(multiple_event_ttls): + responses = [] + for has_multiple in multiple_event_ttls: + if not has_multiple: + responses.append(False) + continue + window = tk.Tk() + response = messagebox.askyesno( + "Multiple event TTLs", + ( + "Based on the TTL file, " + "it looks like TTLs " + "belong to multiple behavior types. " + "Do you want to create multiple files for each " + "behavior type?" + ), + ) + window.destroy() + responses.append(response) + return responses + + +def get_timestamp_configuration(ts_unit_needs, col_names_ts): + ts_units, npm_timestamp_column_names = [], [] + for need in ts_unit_needs: + if not need: + ts_units.append("seconds") + npm_timestamp_column_names.append(None) + continue + window = tk.Tk() + window.title("Select appropriate options for timestamps") + window.geometry("500x200") + holdComboboxValues = dict() + + timestamps_label = ttk.Label(window, text="Select which timestamps to use : ").grid( + row=0, column=1, pady=25, padx=25 + ) + holdComboboxValues["timestamps"] = StringVar() + timestamps_combo = ttk.Combobox(window, values=col_names_ts, textvariable=holdComboboxValues["timestamps"]) + timestamps_combo.grid(row=0, column=2, pady=25, padx=25) + timestamps_combo.current(0) + # timestamps_combo.bind("<>", comboBoxSelected) + + time_unit_label = ttk.Label(window, text="Select timestamps unit : ").grid(row=1, column=1, pady=25, padx=25) + holdComboboxValues["time_unit"] = StringVar() + time_unit_combo = ttk.Combobox( + window, + values=["", "seconds", "milliseconds", "microseconds"], + textvariable=holdComboboxValues["time_unit"], + ) + time_unit_combo.grid(row=1, column=2, pady=25, padx=25) + time_unit_combo.current(0) + # time_unit_combo.bind("<>", comboBoxSelected) + window.lift() + window.after(500, lambda: window.lift()) + window.mainloop() + + if holdComboboxValues["timestamps"].get(): + npm_timestamp_column_name = holdComboboxValues["timestamps"].get() + else: + messagebox.showerror( + "All options not selected", + "All the options for timestamps \ + were not selected. Please select appropriate options", + ) + logger.error( + "All the options for timestamps \ + were not selected. Please select appropriate options" + ) + raise Exception( + "All the options for timestamps \ + were not selected. Please select appropriate options" + ) + if holdComboboxValues["time_unit"].get(): + if holdComboboxValues["time_unit"].get() == "seconds": + ts_unit = holdComboboxValues["time_unit"].get() + elif holdComboboxValues["time_unit"].get() == "milliseconds": + ts_unit = holdComboboxValues["time_unit"].get() + else: + ts_unit = holdComboboxValues["time_unit"].get() + else: + messagebox.showerror( + "All options not selected", + "All the options for timestamps \ + were not selected. Please select appropriate options", + ) + logger.error( + "All the options for timestamps \ + were not selected. Please select appropriate options" + ) + raise Exception( + "All the options for timestamps \ + were not selected. Please select appropriate options" + ) + ts_units.append(ts_unit) + npm_timestamp_column_names.append(npm_timestamp_column_name) + return ts_units, npm_timestamp_column_names diff --git a/src/guppy/frontend/parameterized_plotter.py b/src/guppy/frontend/parameterized_plotter.py new file mode 100644 index 0000000..2d7d2f1 --- /dev/null +++ b/src/guppy/frontend/parameterized_plotter.py @@ -0,0 +1,582 @@ +import logging +import math +import os +import re + +import datashader as ds +import holoviews as hv +import numpy as np +import pandas as pd +import panel as pn +import param +from bokeh.io import export_png, export_svgs +from holoviews import opts +from holoviews.operation.datashader import datashade +from holoviews.plotting.util import process_cmap + +pn.extension() + +logger = logging.getLogger(__name__) + + +# remove unnecessary column names +def remove_cols(cols): + regex = re.compile("bin_err_*") + remove_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] + remove_cols = remove_cols + ["err", "timestamps"] + cols = [i for i in cols if i not in remove_cols] + + return cols + + +# make a new directory for saving plots +def make_dir(filepath): + op = os.path.join(filepath, "saved_plots") + if not os.path.exists(op): + os.mkdir(op) + + return op + + +# create a class to make GUI and plot different graphs +class ParameterizedPlotter(param.Parameterized): + event_selector_objects = param.List(default=None) + event_selector_heatmap_objects = param.List(default=None) + selector_for_multipe_events_plot_objects = param.List(default=None) + color_map_objects = param.List(default=None) + x_objects = param.List(default=None) + y_objects = param.List(default=None) + heatmap_y_objects = param.List(default=None) + psth_y_objects = param.List(default=None) + + filepath = param.Path(default=None) + # create different options and selectors + event_selector = param.ObjectSelector(default=None) + event_selector_heatmap = param.ObjectSelector(default=None) + selector_for_multipe_events_plot = param.ListSelector(default=None) + columns_dict = param.Dict(default=None) + df_new = param.DataFrame(default=None) + x_min = param.Number(default=None) + x_max = param.Number(default=None) + select_trials_checkbox = param.ListSelector(default=["just trials"], objects=["mean", "just trials"]) + Y_Label = param.ObjectSelector(default="y", objects=["y", "z-score", "\u0394F/F"]) + save_options = param.ObjectSelector( + default="None", objects=["None", "save_png_format", "save_svg_format", "save_both_format"] + ) + save_options_heatmap = param.ObjectSelector( + default="None", objects=["None", "save_png_format", "save_svg_format", "save_both_format"] + ) + color_map = param.ObjectSelector(default="plasma") + height_heatmap = param.ObjectSelector(default=600, objects=list(np.arange(0, 5100, 100))[1:]) + width_heatmap = param.ObjectSelector(default=1000, objects=list(np.arange(0, 5100, 100))[1:]) + Height_Plot = param.ObjectSelector(default=300, objects=list(np.arange(0, 5100, 100))[1:]) + Width_Plot = param.ObjectSelector(default=1000, objects=list(np.arange(0, 5100, 100))[1:]) + save_hm = param.Action(lambda x: x.param.trigger("save_hm"), label="Save") + save_psth = param.Action(lambda x: x.param.trigger("save_psth"), label="Save") + X_Limit = param.Range(default=(-5, 10)) + Y_Limit = param.Range(bounds=(-50, 50.0)) + + x = param.ObjectSelector(default=None) + y = param.ObjectSelector(default=None) + heatmap_y = param.ListSelector(default=None) + psth_y = param.ListSelector(default=None) + results_hm = dict() + results_psth = dict() + + def __init__(self, **params): + super().__init__(**params) + # Bind selector objects from companion params + self.param.event_selector.objects = self.event_selector_objects + self.param.event_selector_heatmap.objects = self.event_selector_heatmap_objects + self.param.selector_for_multipe_events_plot.objects = self.selector_for_multipe_events_plot_objects + self.param.color_map.objects = self.color_map_objects + self.param.x.objects = self.x_objects + self.param.y.objects = self.y_objects + self.param.heatmap_y.objects = self.heatmap_y_objects + self.param.psth_y.objects = self.psth_y_objects + + # Set defaults + self.event_selector = self.event_selector_objects[0] + self.event_selector_heatmap = self.event_selector_heatmap_objects[0] + self.selector_for_multipe_events_plot = [self.selector_for_multipe_events_plot_objects[0]] + self.x = self.x_objects[0] + self.y = self.y_objects[-2] + self.heatmap_y = [self.heatmap_y_objects[-1]] + + self.param.X_Limit.bounds = (self.x_min, self.x_max) + + # function to save heatmaps when save button on heatmap tab is clicked + @param.depends("save_hm", watch=True) + def save_hm_plots(self): + plot = self.results_hm["plot"] + op = self.results_hm["op"] + save_opts = self.save_options_heatmap + logger.info(save_opts) + if save_opts == "save_svg_format": + p = hv.render(plot, backend="bokeh") + p.output_backend = "svg" + export_svgs(p, filename=op + ".svg") + elif save_opts == "save_png_format": + p = hv.render(plot, backend="bokeh") + export_png(p, filename=op + ".png") + elif save_opts == "save_both_format": + p = hv.render(plot, backend="bokeh") + p.output_backend = "svg" + export_svgs(p, filename=op + ".svg") + p_png = hv.render(plot, backend="bokeh") + export_png(p_png, filename=op + ".png") + else: + return 0 + + # function to save PSTH plots when save button on PSTH tab is clicked + @param.depends("save_psth", watch=True) + def save_psth_plot(self): + plot, op = [], [] + plot.append(self.results_psth["plot_combine"]) + op.append(self.results_psth["op_combine"]) + plot.append(self.results_psth["plot"]) + op.append(self.results_psth["op"]) + for i in range(len(plot)): + temp_plot, temp_op = plot[i], op[i] + save_opts = self.save_options + if save_opts == "save_svg_format": + p = hv.render(temp_plot, backend="bokeh") + p.output_backend = "svg" + export_svgs(p, filename=temp_op + ".svg") + elif save_opts == "save_png_format": + p = hv.render(temp_plot, backend="bokeh") + export_png(p, filename=temp_op + ".png") + elif save_opts == "save_both_format": + p = hv.render(temp_plot, backend="bokeh") + p.output_backend = "svg" + export_svgs(p, filename=temp_op + ".svg") + p_png = hv.render(temp_plot, backend="bokeh") + export_png(p_png, filename=temp_op + ".png") + else: + return 0 + + # function to change Y values based on event selection + @param.depends("event_selector", watch=True) + def _update_x_y(self): + x_value = self.columns_dict[self.event_selector] + y_value = self.columns_dict[self.event_selector] + self.param["x"].objects = [x_value[-4]] + self.param["y"].objects = remove_cols(y_value) + self.x = x_value[-4] + self.y = self.param["y"].objects[-2] + + @param.depends("event_selector_heatmap", watch=True) + def _update_df(self): + cols = self.columns_dict[self.event_selector_heatmap] + trial_no = range(1, len(remove_cols(cols)[:-2]) + 1) + trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(cols)[:-2])] + ["All"] + self.param["heatmap_y"].objects = trial_ts + self.heatmap_y = [trial_ts[-1]] + + @param.depends("event_selector", watch=True) + def _update_psth_y(self): + cols = self.columns_dict[self.event_selector] + trial_no = range(1, len(remove_cols(cols)[:-2]) + 1) + trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(cols)[:-2])] + self.param["psth_y"].objects = trial_ts + self.psth_y = [trial_ts[0]] + + # function to plot multiple PSTHs into one plot + + @param.depends( + "selector_for_multipe_events_plot", + "Y_Label", + "save_options", + "X_Limit", + "Y_Limit", + "Height_Plot", + "Width_Plot", + ) + def update_selector(self): + data_curve, cols_curve, data_spread, cols_spread = [], [], [], [] + arr = self.selector_for_multipe_events_plot + df1 = self.df_new + for i in range(len(arr)): + if "bin" in arr[i]: + split = arr[i].rsplit("_", 2) + df_name = split[0] #'{}_{}'.format(split[0], split[1]) + col_name_mean = "{}_{}".format(split[-2], split[-1]) + col_name_err = "{}_err_{}".format(split[-2], split[-1]) + data_curve.append(df1[df_name][col_name_mean]) + cols_curve.append(arr[i]) + data_spread.append(df1[df_name][col_name_err]) + cols_spread.append(arr[i]) + else: + data_curve.append(df1[arr[i]]["mean"]) + cols_curve.append(arr[i] + "_" + "mean") + data_spread.append(df1[arr[i]]["err"]) + cols_spread.append(arr[i] + "_" + "mean") + + if len(arr) > 0: + if self.Y_Limit == None: + self.Y_Limit = (np.nanmin(np.asarray(data_curve)) - 0.5, np.nanmax(np.asarray(data_curve)) + 0.5) + + if "bin" in arr[i]: + split = arr[i].rsplit("_", 2) + df_name = split[0] + data_curve.append(df1[df_name]["timestamps"]) + cols_curve.append("timestamps") + data_spread.append(df1[df_name]["timestamps"]) + cols_spread.append("timestamps") + else: + data_curve.append(df1[arr[i]]["timestamps"]) + cols_curve.append("timestamps") + data_spread.append(df1[arr[i]]["timestamps"]) + cols_spread.append("timestamps") + df_curve = pd.concat(data_curve, axis=1) + df_spread = pd.concat(data_spread, axis=1) + df_curve.columns = cols_curve + df_spread.columns = cols_spread + + ts = df_curve["timestamps"] + index = np.arange(0, ts.shape[0], 3) + df_curve = df_curve.loc[index, :] + df_spread = df_spread.loc[index, :] + overlay = hv.NdOverlay( + { + c: hv.Curve((df_curve["timestamps"], df_curve[c]), kdims=["Time (s)"]).opts( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + ) + for c in cols_curve[:-1] + } + ) + spread = hv.NdOverlay( + { + d: hv.Spread( + (df_spread["timestamps"], df_curve[d], df_spread[d], df_spread[d]), + vdims=["y", "yerrpos", "yerrneg"], + ).opts(line_width=0, fill_alpha=0.3) + for d in cols_spread[:-1] + } + ) + plot_combine = ((overlay * spread).opts(opts.NdOverlay(xlabel="Time (s)", ylabel=self.Y_Label))).opts( + shared_axes=False + ) + # plot_err = new_df.hvplot.area(x='timestamps', y=[], y2=[]) + save_opts = self.save_options + op = make_dir(self.filepath) + op_filename = os.path.join(op, str(arr) + "_mean") + + self.results_psth["plot_combine"] = plot_combine + self.results_psth["op_combine"] = op_filename + # self.save_plots(plot_combine, save_opts, op_filename) + return plot_combine + + # function to plot mean PSTH, single trial in PSTH and all the trials of PSTH with mean + @param.depends( + "event_selector", "x", "y", "Y_Label", "save_options", "Y_Limit", "X_Limit", "Height_Plot", "Width_Plot" + ) + def contPlot(self): + df1 = self.df_new[self.event_selector] + # height = self.Heigth_Plot + # width = self.Width_Plot + # logger.info(height, width) + if self.y == "All": + if self.Y_Limit == None: + self.Y_Limit = (np.nanmin(np.asarray(df1)) - 0.5, np.nanmax(np.asarray(df1)) - 0.5) + + options = self.param["y"].objects + regex = re.compile("bin_[(]") + remove_bin_trials = [options[i] for i in range(len(options)) if not regex.match(options[i])] + + ndoverlay = hv.NdOverlay({c: hv.Curve((df1[self.x], df1[c])) for c in remove_bin_trials[:-2]}) + img1 = datashade(ndoverlay, normalization="linear", aggregator=ds.count()) + x_points = df1[self.x] + y_points = df1["mean"] + img2 = hv.Curve((x_points, y_points)) + img = (img1 * img2).opts( + opts.Curve( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + line_width=4, + color="black", + xlim=self.X_Limit, + ylim=self.Y_Limit, + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + ) + + save_opts = self.save_options + + op = make_dir(self.filepath) + op_filename = os.path.join(op, self.event_selector + "_" + self.y) + self.results_psth["plot"] = img + self.results_psth["op"] = op_filename + # self.save_plots(img, save_opts, op_filename) + + return img + + elif self.y == "mean" or "bin" in self.y: + + xpoints = df1[self.x] + ypoints = df1[self.y] + if self.y == "mean": + err = df1["err"] + else: + split = self.y.split("_") + err = df1["{}_err_{}".format(split[0], split[1])] + + index = np.arange(0, xpoints.shape[0], 3) + + if self.Y_Limit == None: + self.Y_Limit = (np.nanmin(ypoints) - 0.5, np.nanmax(ypoints) + 0.5) + + ropts_curve = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + color="blue", + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + ropts_spread = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + fill_alpha=0.3, + fill_color="blue", + line_width=0, + ) + + plot_curve = hv.Curve((xpoints[index], ypoints[index])) # .opts(**ropts_curve) + plot_spread = hv.Spread( + (xpoints[index], ypoints[index], err[index], err[index]) + ) # .opts(**ropts_spread) #vdims=['y', 'yerrpos', 'yerrneg'] + plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread}) + + save_opts = self.save_options + op = make_dir(self.filepath) + op_filename = os.path.join(op, self.event_selector + "_" + self.y) + self.results_psth["plot"] = plot + self.results_psth["op"] = op_filename + # self.save_plots(plot, save_opts, op_filename) + + return plot + + else: + xpoints = df1[self.x] + ypoints = df1[self.y] + if self.Y_Limit == None: + self.Y_Limit = (np.nanmin(ypoints) - 0.5, np.nanmax(ypoints) + 0.5) + + ropts_curve = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + color="blue", + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + plot = hv.Curve((xpoints, ypoints)).opts({"Curve": ropts_curve}) + + save_opts = self.save_options + op = make_dir(self.filepath) + op_filename = os.path.join(op, self.event_selector + "_" + self.y) + self.results_psth["plot"] = plot + self.results_psth["op"] = op_filename + # self.save_plots(plot, save_opts, op_filename) + + return plot + + # function to plot specific PSTH trials + @param.depends( + "event_selector", + "x", + "psth_y", + "select_trials_checkbox", + "Y_Label", + "save_options", + "Y_Limit", + "X_Limit", + "Height_Plot", + "Width_Plot", + ) + def plot_specific_trials(self): + df_psth = self.df_new[self.event_selector] + # if self.Y_Limit==None: + # self.Y_Limit = (np.nanmin(ypoints)-0.5, np.nanmax(ypoints)+0.5) + + if self.psth_y == None: + return None + else: + selected_trials = [s.split(" - ")[1] for s in list(self.psth_y)] + + index = np.arange(0, df_psth["timestamps"].shape[0], 3) + + if self.select_trials_checkbox == ["just trials"]: + overlay = hv.NdOverlay( + { + c: hv.Curve((df_psth["timestamps"][index], df_psth[c][index]), kdims=["Time (s)"]) + for c in selected_trials + } + ) + ropts = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + return overlay.opts(**ropts) + elif self.select_trials_checkbox == ["mean"]: + arr = np.asarray(df_psth[selected_trials]) + mean = np.nanmean(arr, axis=1) + err = np.nanstd(arr, axis=1) / math.sqrt(arr.shape[1]) + ropts_curve = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + color="blue", + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + ropts_spread = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + fill_alpha=0.3, + fill_color="blue", + line_width=0, + ) + plot_curve = hv.Curve((df_psth["timestamps"][index], mean[index])) + plot_spread = hv.Spread((df_psth["timestamps"][index], mean[index], err[index], err[index])) + plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread}) + return plot + elif self.select_trials_checkbox == ["mean", "just trials"]: + overlay = hv.NdOverlay( + { + c: hv.Curve((df_psth["timestamps"][index], df_psth[c][index]), kdims=["Time (s)"]) + for c in selected_trials + } + ) + ropts_overlay = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + + arr = np.asarray(df_psth[selected_trials]) + mean = np.nanmean(arr, axis=1) + err = np.nanstd(arr, axis=1) / math.sqrt(arr.shape[1]) + ropts_curve = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + color="black", + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + ropts_spread = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + fill_alpha=0.3, + fill_color="black", + line_width=0, + ) + plot_curve = hv.Curve((df_psth["timestamps"][index], mean[index])) + plot_spread = hv.Spread((df_psth["timestamps"][index], mean[index], err[index], err[index])) + + plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread}) + return overlay.opts(**ropts_overlay) * plot + + # function to show heatmaps for each event + @param.depends("event_selector_heatmap", "color_map", "height_heatmap", "width_heatmap", "heatmap_y") + def heatmap(self): + height = self.height_heatmap + width = self.width_heatmap + df_hm = self.df_new[self.event_selector_heatmap] + cols = list(df_hm.columns) + regex = re.compile("bin_err_*") + drop_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] + drop_cols = ["err", "mean"] + drop_cols + df_hm = df_hm.drop(drop_cols, axis=1) + cols = list(df_hm.columns) + bin_cols = [cols[i] for i in range(len(cols)) if re.compile("bin_*").match(cols[i])] + time = np.asarray(df_hm["timestamps"]) + event_ts_for_each_event = np.arange(1, len(df_hm.columns[:-1]) + 1) + yticks = list(event_ts_for_each_event) + z_score = np.asarray(df_hm[df_hm.columns[:-1]]).T + + if self.heatmap_y[0] == "All": + indices = np.arange(z_score.shape[0] - len(bin_cols)) + z_score = z_score[indices, :] + event_ts_for_each_event = np.arange(1, z_score.shape[0] + 1) + yticks = list(event_ts_for_each_event) + else: + remove_all = list(set(self.heatmap_y) - set(["All"])) + indices = sorted([int(s.split("-")[0]) - 1 for s in remove_all]) + z_score = z_score[indices, :] + event_ts_for_each_event = np.arange(1, z_score.shape[0] + 1) + yticks = list(event_ts_for_each_event) + + clim = (np.nanmin(z_score), np.nanmax(z_score)) + font_size = {"labels": 16, "yticks": 6} + + if event_ts_for_each_event.shape[0] == 1: + dummy_image = hv.QuadMesh((time, event_ts_for_each_event, z_score)).opts(colorbar=True, clim=clim) + image = ( + (dummy_image).opts( + opts.QuadMesh( + width=int(width), + height=int(height), + cmap=process_cmap(self.color_map, provider="matplotlib"), + colorbar=True, + ylabel="Trials", + xlabel="Time (s)", + fontsize=font_size, + yticks=yticks, + ) + ) + ).opts(shared_axes=False) + + save_opts = self.save_options_heatmap + op = make_dir(self.filepath) + op_filename = os.path.join(op, self.event_selector_heatmap + "_" + "heatmap") + self.results_hm["plot"] = image + self.results_hm["op"] = op_filename + # self.save_plots(image, save_opts, op_filename) + return image + else: + ropts = dict( + width=int(width), + height=int(height), + ylabel="Trials", + xlabel="Time (s)", + fontsize=font_size, + yticks=yticks, + invert_yaxis=True, + ) + dummy_image = hv.QuadMesh((time[0:100], event_ts_for_each_event, z_score[:, 0:100])).opts( + colorbar=True, cmap=process_cmap(self.color_map, provider="matplotlib"), clim=clim + ) + actual_image = hv.QuadMesh((time, event_ts_for_each_event, z_score)) + + dynspread_img = datashade(actual_image, cmap=process_cmap(self.color_map, provider="matplotlib")).opts( + **ropts + ) # clims=self.C_Limit, cnorm='log' + image = ((dummy_image * dynspread_img).opts(opts.QuadMesh(width=int(width), height=int(height)))).opts( + shared_axes=False + ) + + save_opts = self.save_options_heatmap + op = make_dir(self.filepath) + op_filename = os.path.join(op, self.event_selector_heatmap + "_" + "heatmap") + self.results_hm["plot"] = image + self.results_hm["op"] = op_filename + + return image diff --git a/src/guppy/frontend/path_selection.py b/src/guppy/frontend/path_selection.py new file mode 100644 index 0000000..3d69efb --- /dev/null +++ b/src/guppy/frontend/path_selection.py @@ -0,0 +1,40 @@ +import logging +import os +import tkinter as tk +from tkinter import filedialog, ttk + +logger = logging.getLogger(__name__) + + +def get_folder_path(): + # Determine base folder path (headless-friendly via env var) + base_dir_env = os.environ.get("GUPPY_BASE_DIR") + is_headless = base_dir_env and os.path.isdir(base_dir_env) + if is_headless: + folder_path = base_dir_env + logger.info(f"Folder path set to {folder_path} (from GUPPY_BASE_DIR)") + return folder_path + + # Create the main window + folder_selection = tk.Tk() + folder_selection.title("Select the folder path where your data is located") + folder_selection.geometry("700x200") + + selected_path = {"value": None} + + def select_folder(): + selected = filedialog.askdirectory(title="Select the folder path where your data is located") + if selected: + logger.info(f"Folder path set to {selected}") + selected_path["value"] = selected + else: + default_path = os.path.expanduser("~") + logger.info(f"Folder path set to {default_path}") + selected_path["value"] = default_path + folder_selection.destroy() + + select_button = ttk.Button(folder_selection, text="Select a Folder", command=select_folder) + select_button.pack(pady=5) + folder_selection.mainloop() + + return selected_path["value"] diff --git a/src/guppy/frontend/progress.py b/src/guppy/frontend/progress.py new file mode 100644 index 0000000..fb5e7c2 --- /dev/null +++ b/src/guppy/frontend/progress.py @@ -0,0 +1,50 @@ +import logging +import os +import time + +logger = logging.getLogger(__name__) + + +def readPBIncrementValues(progressBar): + logger.info("Read progress bar increment values function started...") + file_path = os.path.join(os.path.expanduser("~"), "pbSteps.txt") + if os.path.exists(file_path): + os.remove(file_path) + increment, maximum = 0, 100 + progressBar.value = increment + progressBar.bar_color = "success" + while True: + try: + with open(file_path, "r") as file: + content = file.readlines() + if len(content) == 0: + pass + else: + maximum = int(content[0]) + increment = int(content[-1]) + + if increment == -1: + progressBar.bar_color = "danger" + os.remove(file_path) + break + progressBar.max = maximum + progressBar.value = increment + time.sleep(0.001) + except FileNotFoundError: + time.sleep(0.001) + except PermissionError: + time.sleep(0.001) + except Exception as e: + # Handle other exceptions that may occur + logger.info(f"An error occurred while reading the file: {e}") + break + if increment == maximum: + os.remove(file_path) + break + + logger.info("Read progress bar increment values stopped.") + + +def writeToFile(value: str): + with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: + file.write(value) diff --git a/src/guppy/frontend/sidebar.py b/src/guppy/frontend/sidebar.py new file mode 100644 index 0000000..466a639 --- /dev/null +++ b/src/guppy/frontend/sidebar.py @@ -0,0 +1,75 @@ +import logging + +import panel as pn + +logger = logging.getLogger(__name__) + + +class Sidebar: + def __init__(self, template): + self.template = template + self.setup_markdown() + self.setup_buttons() + self.setup_progress_bars() + + def setup_markdown(self): + self.mark_down_ip = pn.pane.Markdown("""**Step 1 : Save Input Parameters**""", width=300) + self.mark_down_ip_note = pn.pane.Markdown( + """***Note : ***
+ - Save Input Parameters will save input parameters used for the analysis + in all the folders you selected for the analysis (useful for future + reference). All analysis steps will run without saving input parameters. + """, + width=300, + ) + self.mark_down_storenames = pn.pane.Markdown( + """**Step 2 : Open Storenames GUI
and save storenames**""", width=300 + ) + self.mark_down_read = pn.pane.Markdown("""**Step 3 : Read Raw Data**""", width=300) + self.mark_down_preprocess = pn.pane.Markdown("""**Step 4 : Preprocess and Remove Artifacts**""", width=300) + self.mark_down_psth = pn.pane.Markdown("""**Step 5 : PSTH Computation**""", width=300) + self.mark_down_visualization = pn.pane.Markdown("""**Step 6 : Visualization**""", width=300) + + def setup_buttons(self): + self.open_storenames = pn.widgets.Button( + name="Open Storenames GUI", button_type="primary", width=300, align="end" + ) + self.read_rawData = pn.widgets.Button(name="Read Raw Data", button_type="primary", width=300, align="end") + self.preprocess = pn.widgets.Button( + name="Preprocess and Remove Artifacts", button_type="primary", width=300, align="end" + ) + self.psth_computation = pn.widgets.Button( + name="PSTH Computation", button_type="primary", width=300, align="end" + ) + self.open_visualization = pn.widgets.Button( + name="Open Visualization GUI", button_type="primary", width=300, align="end" + ) + self.save_button = pn.widgets.Button(name="Save to file...", button_type="primary", width=300, align="end") + + def attach_callbacks(self, button_name_to_onclick_fn: dict): + for button_name, onclick_fn in button_name_to_onclick_fn.items(): + button = getattr(self, button_name) + button.on_click(onclick_fn) + + def setup_progress_bars(self): + self.read_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300) + self.extract_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300) + self.psth_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300) + + def add_to_template(self): + self.template.sidebar.append(self.mark_down_ip) + self.template.sidebar.append(self.mark_down_ip_note) + self.template.sidebar.append(self.save_button) + self.template.sidebar.append(self.mark_down_storenames) + self.template.sidebar.append(self.open_storenames) + self.template.sidebar.append(self.mark_down_read) + self.template.sidebar.append(self.read_rawData) + self.template.sidebar.append(self.read_progress) + self.template.sidebar.append(self.mark_down_preprocess) + self.template.sidebar.append(self.preprocess) + self.template.sidebar.append(self.extract_progress) + self.template.sidebar.append(self.mark_down_psth) + self.template.sidebar.append(self.psth_computation) + self.template.sidebar.append(self.psth_progress) + self.template.sidebar.append(self.mark_down_visualization) + self.template.sidebar.append(self.open_visualization) diff --git a/src/guppy/frontend/storenames_config.py b/src/guppy/frontend/storenames_config.py new file mode 100644 index 0000000..c7aa419 --- /dev/null +++ b/src/guppy/frontend/storenames_config.py @@ -0,0 +1,104 @@ +import logging + +import panel as pn + +pn.extension() + +logger = logging.getLogger(__name__) + + +class StorenamesConfig: + def __init__( + self, + show_config_button, + storename_dropdowns, + storename_textboxes, + storenames, + storenames_cache, + ): + self.config_widgets = [] + self._dropdown_help_map = {} + storename_dropdowns.clear() + storename_textboxes.clear() + + if len(storenames) == 0: + return + + self.config_widgets.append( + pn.pane.Markdown( + "## Configure Storenames\nSelect appropriate options for each storename and provide names as needed:" + ) + ) + + for i, storename in enumerate(storenames): + self.setup_storename(i, storename, storename_dropdowns, storename_textboxes, storenames_cache) + + # Add show button + self.config_widgets.append(pn.Spacer(height=20)) + self.config_widgets.append(show_config_button) + self.config_widgets.append( + pn.pane.Markdown( + "*Click 'Show Selected Configuration' to apply your selections.*", + styles={"font-size": "12px", "color": "gray"}, + ) + ) + + def _on_dropdown_value_change(self, event): + help_pane = self._dropdown_help_map.get(event.obj) + if help_pane is None: + return + dropdown_value = event.new + help_pane.object = self._get_help_text(dropdown_value=dropdown_value) + + def _get_help_text(self, dropdown_value): + if dropdown_value == "control": + return "*Type appropriate region name*" + elif dropdown_value == "signal": + return "*Type appropriate region name*" + elif dropdown_value == "event TTLs": + return "*Type event name for the TTLs*" + else: + return "" + + def setup_storename(self, i, storename, storename_dropdowns, storename_textboxes, storenames_cache): + # Create a row for each storename + row_widgets = [] + + # Label + label = pn.pane.Markdown(f"**{storename}:**") + row_widgets.append(label) + + # Dropdown options + if storename in storenames_cache: + options = storenames_cache[storename] + default_value = options[0] if options else "" + else: + options = ["", "control", "signal", "event TTLs"] + default_value = "" + + # Create unique key for widget + widget_key = ( + f"{storename}_{i}" + if f"{storename}_{i}" not in storename_dropdowns + else f"{storename}_{i}_{len(storename_dropdowns)}" + ) + + dropdown = pn.widgets.Select(name="Type", value=default_value, options=options, width=150) + storename_dropdowns[widget_key] = dropdown + row_widgets.append(dropdown) + + # Text input (only show if not cached or if control/signal/event TTLs selected) + if storename not in storenames_cache or default_value in ["control", "signal", "event TTLs"]: + textbox = pn.widgets.TextInput(name="Name", value="", placeholder="Enter region/event name", width=200) + storename_textboxes[widget_key] = textbox + row_widgets.append(textbox) + + # Add helper text based on selection + initial_help_text = self._get_help_text(default_value) + help_pane = pn.pane.Markdown(initial_help_text, styles={"color": "gray", "font-size": "12px"}) + self._dropdown_help_map[dropdown] = help_pane + dropdown.param.watch(self._on_dropdown_value_change, "value") + row_widgets.append(help_pane) + + # Add the row to config widgets + self.config_widgets.append(pn.Row(*row_widgets, margin=(5, 0))) diff --git a/src/guppy/frontend/storenames_instructions.py b/src/guppy/frontend/storenames_instructions.py new file mode 100644 index 0000000..ba5fe5d --- /dev/null +++ b/src/guppy/frontend/storenames_instructions.py @@ -0,0 +1,109 @@ +import glob +import logging +import os + +import holoviews as hv +import numpy as np +import pandas as pd +import panel as pn + +# hv.extension() +pn.extension() + +logger = logging.getLogger(__name__) + + +class StorenamesInstructions: + def __init__(self, folder_path): + # instructions about how to save the storeslist file + self.mark_down = pn.pane.Markdown( + """ + + + ### Instructions to follow : + + - Check Storenames to repeat checkbox and see instructions in “Github Wiki” for duplicating storenames. + Otherwise do not check the Storenames to repeat checkbox.
+ - Select storenames from list and click “Select Storenames” to populate area below.
+ - Enter names for storenames, in order, using the following naming convention:
+ Isosbestic = “control_region” (ex: Dv1A= control_DMS)
+ Signal= “signal_region” (ex: Dv2A= signal_DMS)
+ TTLs can be named using any convention (ex: PrtR = RewardedPortEntries) but should be kept consistent for later group analysis + + ``` + {"storenames": ["Dv1A", "Dv2A", + "Dv3B", "Dv4B", + "LNRW", "LNnR", + "PrtN", "PrtR", + "RNPS"], + "names_for_storenames": ["control_DMS", "signal_DMS", + "control_DLS", "signal_DLS", + "RewardedNosepoke", "UnrewardedNosepoke", + "UnrewardedPort", "RewardedPort", + "InactiveNosepoke"]} + ``` + - If user has saved storenames before, clicking "Select Storenames" button will pop up a dialog box + showing previously used names for storenames. Select names for storenames by checking a checkbox and + click on "Show" to populate the text area in the Storenames GUI. Close the dialog box. + + - Select “create new” or “overwrite” to generate a new storenames list or replace a previous one + - Click Save + + """, + width=550, + ) + + self.widget = pn.Column("# " + os.path.basename(folder_path), self.mark_down) + + +class StorenamesInstructionsNPM(StorenamesInstructions): + def __init__(self, folder_path): + super().__init__(folder_path=folder_path) + path_chev = glob.glob(os.path.join(folder_path, "*chev*")) + path_chod = glob.glob(os.path.join(folder_path, "*chod*")) + path_chpr = glob.glob(os.path.join(folder_path, "*chpr*")) + combine_paths = path_chev + path_chod + path_chpr + self.d = dict() + for i in range(len(combine_paths)): + basename = (os.path.basename(combine_paths[i])).split(".")[0] + df = pd.read_csv(combine_paths[i]) + self.d[basename] = {"x": np.array(df["timestamps"]), "y": np.array(df["data"])} + keys = list(self.d.keys()) + self.mark_down_np = pn.pane.Markdown( + """ + ### Extra Instructions to follow when using Neurophotometrics data : + - Guppy will take the NPM data, which has interleaved frames + from the signal and control channels, and divide it out into + separate channels for each site you recordded. + However, since NPM does not automatically annotate which + frames belong to the signal channel and which belong to the + control channel, the user must specify this for GuPPy. + - Each of your recording sites will have a channel + named “chod” and a channel named “chev” + - View the plots below and, for each site, + determine whether the “chev” or “chod” channel is signal or control + - When you give your storenames, name the channels appropriately. + For example, “chev1” might be “signal_A” and + “chod1” might be “control_A” (or vice versa). + + """ + ) + self.plot_select = pn.widgets.Select( + name="Select channel to see correspondings channels", options=keys, value=keys[0] + ) + self.plot_pane = pn.pane.HoloViews(self._make_plot(self.plot_select.value), width=550) + self.plot_select.param.watch(self._on_plot_select_change, "value") + + self.widget = pn.Column( + "# " + os.path.basename(folder_path), + self.mark_down, + self.mark_down_np, + self.plot_select, + self.plot_pane, + ) + + def _make_plot(self, plot_key): + return hv.Curve((self.d[plot_key]["x"], self.d[plot_key]["y"])).opts(width=550) + + def _on_plot_select_change(self, event): + self.plot_pane.object = self._make_plot(event.new) diff --git a/src/guppy/frontend/storenames_selector.py b/src/guppy/frontend/storenames_selector.py new file mode 100644 index 0000000..97919d6 --- /dev/null +++ b/src/guppy/frontend/storenames_selector.py @@ -0,0 +1,152 @@ +import json +import logging + +import panel as pn + +from .storenames_config import StorenamesConfig + +pn.extension() + +logger = logging.getLogger(__name__) + + +class StorenamesSelector: + + def __init__(self, allnames): + self.alert = pn.pane.Alert("#### No alerts !!", alert_type="danger", height=80, width=600) + if len(allnames) == 0: + self.alert.object = ( + "####Alert !! \n No storenames found. There are not any TDT files or csv files to look for storenames." + ) + + # creating different buttons and selectors for the GUI + self.cross_selector = pn.widgets.CrossSelector( + name="Store Names Selection", value=[], options=allnames, width=600 + ) + self.multi_choice = pn.widgets.MultiChoice( + name="Select Storenames which you want more than once (multi-choice: multiple options selection)", + value=[], + options=allnames, + ) + + self.literal_input_1 = pn.widgets.LiteralInput( + name="Number of times you want the above storename (list)", value=[], type=list + ) + # self.literal_input_2 = pn.widgets.LiteralInput(name='Names for Storenames (list)', type=list) + + self.repeat_storenames = pn.widgets.Checkbox(name="Storenames to repeat", value=False) + self.repeat_storename_wd = pn.WidgetBox("", width=600) + + self.repeat_storenames.link(self.repeat_storename_wd, callbacks={"value": self.callback}) + # self.repeat_storename_wd = pn.WidgetBox('Storenames to repeat (leave blank if not needed)', multi_choice, literal_input_1, background="white", width=600) + + self.update_options = pn.widgets.Button(name="Select Storenames", width=600) + self.save = pn.widgets.Button(name="Save", width=600) + + self.text = pn.widgets.LiteralInput(value=[], name="Selected Store Names", type=list, width=600) + + self.path = pn.widgets.TextInput(name="Location to Stores List file", width=600) + + self.mark_down_for_overwrite = pn.pane.Markdown( + """ Select option from below if user wants to over-write a file or create a new file. + **Creating a new file will make a new output folder and will get saved at that location.** + If user selects to over-write a file **Select location of the file to over-write** will provide + the existing options of the output folders where user needs to over-write the file""", + width=600, + ) + + self.select_location = pn.widgets.Select( + name="Select location of the file to over-write", value="None", options=["None"], width=600 + ) + + self.overwrite_button = pn.widgets.MenuButton( + name="over-write storeslist file or create a new one? ", + items=["over_write_file", "create_new_file"], + button_type="default", + split=True, + width=600, + ) + + self.literal_input_2 = pn.widgets.CodeEditor( + value="""{}""", theme="tomorrow", language="json", height=250, width=600 + ) + + self.take_widgets = pn.WidgetBox(self.multi_choice, self.literal_input_1) + + self.change_widgets = pn.WidgetBox(self.text) + + # Panel-based storename configuration (replaces Tkinter dialog) + self.storename_config_widgets = pn.Column(visible=False) + self.show_config_button = pn.widgets.Button(name="Show Selected Configuration", width=600) + + self.widget = pn.Column( + self.repeat_storenames, + self.repeat_storename_wd, + pn.Spacer(height=20), + self.cross_selector, + self.update_options, + self.storename_config_widgets, + pn.Spacer(height=10), + self.text, + self.literal_input_2, + self.alert, + self.mark_down_for_overwrite, + self.overwrite_button, + self.select_location, + self.save, + self.path, + ) + + def callback(self, target, event): + if event.new == True: + target.objects = [self.multi_choice, self.literal_input_1] + elif event.new == False: + target.clear() + + def get_select_location(self): + return self.select_location.value + + def set_select_location_options(self, options): + self.select_location.options = options + + def set_alert_message(self, message): + self.alert.object = message + + def get_literal_input_2(self): # TODO: come up with a better name for this method. + d = json.loads(self.literal_input_2.value) + return d + + def set_literal_input_2(self, d): # TODO: come up with a better name for this method. + self.literal_input_2.value = str(json.dumps(d, indent=2)) + + def get_take_widgets(self): + return [w.value for w in self.take_widgets] + + def set_change_widgets(self, value): + for w in self.change_widgets: + w.value = value + + def get_cross_selector(self): + return self.cross_selector.value + + def set_path(self, value): + self.path.value = value + + def attach_callbacks(self, button_name_to_onclick_fn: dict): + for button_name, onclick_fn in button_name_to_onclick_fn.items(): + button = getattr(self, button_name) + button.on_click(onclick_fn) + + def configure_storenames(self, storename_dropdowns, storename_textboxes, storenames, storenames_cache): + # Create Panel widgets for storename configuration + self.storenames_config = StorenamesConfig( + show_config_button=self.show_config_button, + storename_dropdowns=storename_dropdowns, + storename_textboxes=storename_textboxes, + storenames=storenames, + storenames_cache=storenames_cache, + ) + + # Update the configuration panel + self.storename_config_widgets.objects = self.storenames_config.config_widgets + self.storename_config_widgets.visible = len(storenames) > 0 diff --git a/src/guppy/frontend/visualization_dashboard.py b/src/guppy/frontend/visualization_dashboard.py new file mode 100644 index 0000000..1444b86 --- /dev/null +++ b/src/guppy/frontend/visualization_dashboard.py @@ -0,0 +1,158 @@ +import logging + +import panel as pn + +from .frontend_utils import scanPortsAndFind + +pn.extension() + +logger = logging.getLogger(__name__) + + +class VisualizationDashboard: + """Dashboard for interactive PSTH and heatmap visualization. + + Wraps a ``Viewer`` instance with Panel widgets and a tabbed layout. + Data loading, preparation, and Viewer instantiation are handled + externally; this class is responsible for widget creation, layout + assembly, and serving the application. + + Parameters + ---------- + plotter : ParameterizedPlotter + A fully configured ParameterizedPlotter instance that provides reactive plot + methods and param-based controls. + basename : str + Session name displayed as the tab title. + """ + + def __init__(self, *, plotter, basename): + self.plotter = plotter + self.basename = basename + self._psth_tab = self._build_psth_tab() + self._heatmap_tab = self._build_heatmap_tab() + + def _build_psth_tab(self): + """Build the PSTH tab with controls and plot panels.""" + psth_checkbox = pn.Param( + self.plotter.param.select_trials_checkbox, + widgets={ + "select_trials_checkbox": { + "type": pn.widgets.CheckBoxGroup, + "inline": True, + "name": "Select mean and/or just trials", + } + }, + ) + parameters = pn.Param( + self.plotter.param.selector_for_multipe_events_plot, + widgets={ + "selector_for_multipe_events_plot": {"type": pn.widgets.CrossSelector, "width": 550, "align": "start"} + }, + ) + psth_y_parameters = pn.Param( + self.plotter.param.psth_y, + widgets={ + "psth_y": { + "type": pn.widgets.MultiSelect, + "name": "Trial # - Timestamps", + "width": 200, + "size": 15, + "align": "start", + } + }, + ) + + event_selector = pn.Param( + self.plotter.param.event_selector, widgets={"event_selector": {"type": pn.widgets.Select, "width": 400}} + ) + x_selector = pn.Param(self.plotter.param.x, widgets={"x": {"type": pn.widgets.Select, "width": 180}}) + y_selector = pn.Param(self.plotter.param.y, widgets={"y": {"type": pn.widgets.Select, "width": 180}}) + + width_plot = pn.Param( + self.plotter.param.Width_Plot, widgets={"Width_Plot": {"type": pn.widgets.Select, "width": 70}} + ) + height_plot = pn.Param( + self.plotter.param.Height_Plot, widgets={"Height_Plot": {"type": pn.widgets.Select, "width": 70}} + ) + ylabel = pn.Param(self.plotter.param.Y_Label, widgets={"Y_Label": {"type": pn.widgets.Select, "width": 70}}) + save_opts = pn.Param( + self.plotter.param.save_options, widgets={"save_options": {"type": pn.widgets.Select, "width": 70}} + ) + + xlimit_plot = pn.Param( + self.plotter.param.X_Limit, widgets={"X_Limit": {"type": pn.widgets.RangeSlider, "width": 180}} + ) + ylimit_plot = pn.Param( + self.plotter.param.Y_Limit, widgets={"Y_Limit": {"type": pn.widgets.RangeSlider, "width": 180}} + ) + save_psth = pn.Param( + self.plotter.param.save_psth, widgets={"save_psth": {"type": pn.widgets.Button, "width": 400}} + ) + + options = pn.Column( + event_selector, + pn.Row(x_selector, y_selector), + pn.Row(xlimit_plot, ylimit_plot), + pn.Row(width_plot, height_plot, ylabel, save_opts), + save_psth, + ) + + options_selectors = pn.Row(options, parameters) + + return pn.Column( + "## " + self.basename, + pn.Row(options_selectors, pn.Column(psth_checkbox, psth_y_parameters), width=1200), + self.plotter.contPlot, + self.plotter.update_selector, + self.plotter.plot_specific_trials, + ) + + def _build_heatmap_tab(self): + """Build the heatmap tab with controls and plot panels.""" + heatmap_y_parameters = pn.Param( + self.plotter.param.heatmap_y, + widgets={ + "heatmap_y": {"type": pn.widgets.MultiSelect, "name": "Trial # - Timestamps", "width": 200, "size": 30} + }, + ) + event_selector_heatmap = pn.Param( + self.plotter.param.event_selector_heatmap, + widgets={"event_selector_heatmap": {"type": pn.widgets.Select, "width": 150}}, + ) + color_map = pn.Param( + self.plotter.param.color_map, widgets={"color_map": {"type": pn.widgets.Select, "width": 150}} + ) + width_heatmap = pn.Param( + self.plotter.param.width_heatmap, widgets={"width_heatmap": {"type": pn.widgets.Select, "width": 150}} + ) + height_heatmap = pn.Param( + self.plotter.param.height_heatmap, widgets={"height_heatmap": {"type": pn.widgets.Select, "width": 150}} + ) + save_hm = pn.Param(self.plotter.param.save_hm, widgets={"save_hm": {"type": pn.widgets.Button, "width": 150}}) + save_options_heatmap = pn.Param( + self.plotter.param.save_options_heatmap, + widgets={"save_options_heatmap": {"type": pn.widgets.Select, "width": 150}}, + ) + + return pn.Column( + "## " + self.basename, + pn.Row( + event_selector_heatmap, + color_map, + width_heatmap, + height_heatmap, + save_options_heatmap, + pn.Column(pn.Spacer(height=25), save_hm), + ), + pn.Row(self.plotter.heatmap, heatmap_y_parameters), + ) + + def show(self): + """Serve the dashboard in a browser on an available port.""" + logger.info("app") + template = pn.template.MaterialTemplate(title="Visualization GUI") + number = scanPortsAndFind(start_port=5000, end_port=5200) + app = pn.Tabs(("PSTH", self._psth_tab), ("Heat Map", self._heatmap_tab)) + template.main.append(app) + template.show(port=number) diff --git a/src/guppy/main.py b/src/guppy/main.py index 478e991..cb5ebc8 100644 --- a/src/guppy/main.py +++ b/src/guppy/main.py @@ -11,12 +11,12 @@ import panel as pn -from .savingInputParameters import savingInputParameters +from .orchestration.home import build_homepage def serve_app(): """Serve the GuPPy application using Panel.""" - template = savingInputParameters() + template = build_homepage() pn.serve(template, show=True) diff --git a/src/guppy/orchestration/home.py b/src/guppy/orchestration/home.py new file mode 100644 index 0000000..9b7e6c5 --- /dev/null +++ b/src/guppy/orchestration/home.py @@ -0,0 +1,101 @@ +import json +import logging +import os +import subprocess +import sys +from threading import Thread + +import panel as pn + +from .save_parameters import save_parameters +from .storenames import orchestrate_storenames_page +from .visualize import visualizeResults +from ..frontend.input_parameters import ParameterForm +from ..frontend.path_selection import get_folder_path +from ..frontend.progress import readPBIncrementValues +from ..frontend.sidebar import Sidebar + +logger = logging.getLogger(__name__) + + +def readRawData(parameter_form): + inputParameters = parameter_form.getInputParameters() + subprocess.call([sys.executable, "-m", "guppy.orchestration.read_raw_data", json.dumps(inputParameters)]) + + +def preprocess(parameter_form): + inputParameters = parameter_form.getInputParameters() + subprocess.call([sys.executable, "-m", "guppy.orchestration.preprocess", json.dumps(inputParameters)]) + + +def psthComputation(parameter_form, current_dir): + inputParameters = parameter_form.getInputParameters() + inputParameters["curr_dir"] = current_dir + subprocess.call([sys.executable, "-m", "guppy.orchestration.psth", json.dumps(inputParameters)]) + + +def build_homepage(): + pn.extension() + global folder_path + folder_path = get_folder_path() + current_dir = os.getcwd() + + template = pn.template.BootstrapTemplate(title="Input Parameters GUI") + parameter_form = ParameterForm(folder_path=folder_path, template=template) + sidebar = Sidebar(template=template) + + # ------------------------------------------------------------------------------------------------------------------ + # onclick closure functions for sidebar buttons + def onclickProcess(event=None): + inputParameters = parameter_form.getInputParameters() + save_parameters(inputParameters=inputParameters) + + def onclickStorenames(event=None): + inputParameters = parameter_form.getInputParameters() + orchestrate_storenames_page(inputParameters) + + def onclickVisualization(event=None): + inputParameters = parameter_form.getInputParameters() + visualizeResults(inputParameters) + + def onclickreaddata(event=None): + thread = Thread(target=readRawData, args=(parameter_form,)) + thread.start() + readPBIncrementValues(sidebar.read_progress) + thread.join() + + def onclickpreprocess(event=None): + thread = Thread(target=preprocess, args=(parameter_form,)) + thread.start() + readPBIncrementValues(sidebar.extract_progress) + thread.join() + + def onclickpsth(event=None): + thread = Thread(target=psthComputation, args=(parameter_form, current_dir)) + thread.start() + readPBIncrementValues(sidebar.psth_progress) + thread.join() + + # ------------------------------------------------------------------------------------------------------------------ + + button_name_to_onclick_fn = { + "save_button": onclickProcess, + "open_storenames": onclickStorenames, + "read_rawData": onclickreaddata, + "preprocess": onclickpreprocess, + "psth_computation": onclickpsth, + "open_visualization": onclickVisualization, + } + sidebar.attach_callbacks(button_name_to_onclick_fn=button_name_to_onclick_fn) + sidebar.add_to_template() + + # Expose minimal hooks and widgets to enable programmatic testing + template._hooks = { + "onclickProcess": onclickProcess, + "getInputParameters": parameter_form.getInputParameters, + } + template._widgets = { + "files_1": parameter_form.files_1, + } + + return template diff --git a/src/guppy/preprocess.py b/src/guppy/orchestration/preprocess.py similarity index 76% rename from src/guppy/preprocess.py rename to src/guppy/orchestration/preprocess.py index e4812a2..b5b9d0d 100755 --- a/src/guppy/preprocess.py +++ b/src/guppy/orchestration/preprocess.py @@ -3,23 +3,23 @@ import logging import os import sys +from typing import Literal import matplotlib.pyplot as plt import numpy as np -from .analysis.artifact_removal import remove_artifacts -from .analysis.combine_data import combine_data -from .analysis.control_channel import add_control_channel, create_control_channel -from .analysis.io_utils import ( +from ..analysis.artifact_removal import remove_artifacts +from ..analysis.combine_data import combine_data +from ..analysis.control_channel import add_control_channel, create_control_channel +from ..analysis.io_utils import ( check_storeslistfile, check_TDT, find_files, - get_all_stores_for_combining_data, # noqa: F401 -- Necessary for other modules that depend on preprocess.py get_coords, read_hdf5, takeOnlyDirs, ) -from .analysis.standard_io import ( +from ..analysis.standard_io import ( read_control_and_signal, read_coords_pairwise, read_corrected_data, @@ -37,8 +37,12 @@ write_corrected_ttl_timestamps, write_zscore, ) -from .analysis.timestamp_correction import correct_timestamps -from .analysis.z_score import compute_z_score +from ..analysis.timestamp_correction import correct_timestamps +from ..analysis.z_score import compute_z_score +from ..frontend.artifact_removal import ArtifactRemovalWidget +from ..frontend.progress import writeToFile +from ..utils.utils import get_all_stores_for_combining_data +from ..visualization.preprocessing import visualize_preprocessing logger = logging.getLogger(__name__) @@ -47,17 +51,10 @@ plt.switch_backend("TKAgg") -def writeToFile(value: str): - with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: - file.write(value) - - -# function to plot z_score -def visualize_z_score(filepath): - +def execute_preprocessing_visualization(filepath, visualization_type: Literal["z_score", "dff"]): name = os.path.basename(filepath) - path = glob.glob(os.path.join(filepath, "z_score_*")) + path = glob.glob(os.path.join(filepath, f"{visualization_type}_*")) path = sorted(path) @@ -66,122 +63,7 @@ def visualize_z_score(filepath): name_1 = basename.split("_")[-1] x = read_hdf5("timeCorrection_" + name_1, filepath, "timestampNew") y = read_hdf5("", path[i], "data") - fig = plt.figure() - ax = fig.add_subplot(111) - ax.plot(x, y) - ax.set_title(basename) - fig.suptitle(name) - # plt.show() - - -# function to plot deltaF/F -def visualize_dff(filepath): - name = os.path.basename(filepath) - - path = glob.glob(os.path.join(filepath, "dff_*")) - - path = sorted(path) - - for i in range(len(path)): - basename = (os.path.basename(path[i])).split(".")[0] - name_1 = basename.split("_")[-1] - x = read_hdf5("timeCorrection_" + name_1, filepath, "timestampNew") - y = read_hdf5("", path[i], "data") - fig = plt.figure() - ax = fig.add_subplot(111) - ax.plot(x, y) - ax.set_title(basename) - fig.suptitle(name) - # plt.show() - - -def visualize(filepath, x, y1, y2, y3, plot_name, removeArtifacts): - - # plotting control and signal data - - if (y1 == 0).all() == True: - y1 = np.zeros(x.shape[0]) - - coords_path = os.path.join(filepath, "coordsForPreProcessing_" + plot_name[0].split("_")[-1] + ".npy") - name = os.path.basename(filepath) - fig = plt.figure() - ax1 = fig.add_subplot(311) - (line1,) = ax1.plot(x, y1) - ax1.set_title(plot_name[0]) - ax2 = fig.add_subplot(312) - (line2,) = ax2.plot(x, y2) - ax2.set_title(plot_name[1]) - ax3 = fig.add_subplot(313) - (line3,) = ax3.plot(x, y2) - (line3,) = ax3.plot(x, y3) - ax3.set_title(plot_name[2]) - fig.suptitle(name) - - hfont = {"fontname": "DejaVu Sans"} - - if removeArtifacts == True and os.path.exists(coords_path): - ax3.set_xlabel("Time(s) \n Note : Artifacts have been removed, but are not reflected in this plot.", **hfont) - else: - ax3.set_xlabel("Time(s)", **hfont) - - global coords - coords = [] - - # clicking 'space' key on keyboard will draw a line on the plot so that user can see what chunks are selected - # and clicking 'd' key on keyboard will deselect the selected point - def onclick(event): - # global ix, iy - - if event.key == " ": - ix, iy = event.xdata, event.ydata - logger.info(f"x = {ix}, y = {iy}") - y1_max, y1_min = np.amax(y1), np.amin(y1) - y2_max, y2_min = np.amax(y2), np.amin(y2) - - # ax1.plot([ix,ix], [y1_max, y1_min], 'k--') - # ax2.plot([ix,ix], [y2_max, y2_min], 'k--') - - ax1.axvline(ix, c="black", ls="--") - ax2.axvline(ix, c="black", ls="--") - ax3.axvline(ix, c="black", ls="--") - - fig.canvas.draw() - - global coords - coords.append((ix, iy)) - - # if len(coords) == 2: - # fig.canvas.mpl_disconnect(cid) - - return coords - - elif event.key == "d": - if len(coords) > 0: - logger.info(f"x = {coords[-1][0]}, y = {coords[-1][1]}; deleted") - del coords[-1] - ax1.lines[-1].remove() - ax2.lines[-1].remove() - ax3.lines[-1].remove() - fig.canvas.draw() - - return coords - - # close the plot will save coordinates for all the selected chunks in the data - def plt_close_event(event): - global coords - if coords and len(coords) > 0: - name_1 = plot_name[0].split("_")[-1] - np.save(os.path.join(filepath, "coordsForPreProcessing_" + name_1 + ".npy"), coords) - logger.info(f"Coordinates file saved at {os.path.join(filepath, 'coordsForPreProcessing_'+name_1+'.npy')}") - fig.canvas.mpl_disconnect(cid) - coords = [] - - cid = fig.canvas.mpl_connect("key_press_event", onclick) - cid = fig.canvas.mpl_connect("close_event", plt_close_event) - # multi = MultiCursor(fig.canvas, (ax1, ax2), color='g', lw=1, horizOn=False, vertOn=True) - - # plt.show() - # return fig + fig, ax = visualize_preprocessing(suptitle=name, title=basename, x=x, y=y) # function to plot control and signal, also provide a feature to select chunks for artifacts removal @@ -198,6 +80,7 @@ def visualizeControlAndSignal(filepath, removeArtifacts): path = np.asarray(path).reshape(2, -1) + widgets = [] for i in range(path.shape[1]): name_1 = ((os.path.basename(path[0, i])).split(".")[0]).split("_") @@ -216,7 +99,9 @@ def visualizeControlAndSignal(filepath, removeArtifacts): (os.path.basename(path[1, i])).split(".")[0], (os.path.basename(cntrl_sig_fit_path)).split(".")[0], ] - visualize(filepath, ts, control, signal, cntrl_sig_fit, plot_name, removeArtifacts) + widget = ArtifactRemovalWidget(filepath, ts, control, signal, cntrl_sig_fit, plot_name, removeArtifacts) + widgets.append(widget) + return widgets # function to execute timestamps corrections using functions timestampCorrection and decide_naming_convention_and_applyCorrection @@ -336,23 +221,45 @@ def execute_zscore(folderNames, inputParameters): write_zscore(filepath, name, z_score, dff, control_fit, temp_control_arr) logger.info(f"z-score for the data in {filepath} computed.") + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + + plt.show() + logger.info("Z-score computation completed.") + + +def visualize_z_score(inputParameters, folderNames): + plot_zScore_dff = inputParameters["plot_zScore_dff"] + combine_data = inputParameters["combine_data"] + remove_artifacts = inputParameters["removeArtifacts"] + + storesListPath = [] + for i in range(len(folderNames)): + if combine_data == True: + storesListPath.append([folderNames[i][0]]) + else: + filepath = folderNames[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = np.concatenate(storesListPath) + + widgets = [] + for j in range(len(storesListPath)): + filepath = storesListPath[j] if not remove_artifacts: - visualizeControlAndSignal(filepath, removeArtifacts=remove_artifacts) + # a reference to widgets has to persist in the same scope as plt.show() is called + widgets.extend(visualizeControlAndSignal(filepath, removeArtifacts=remove_artifacts)) if plot_zScore_dff == "z_score": - visualize_z_score(filepath) + execute_preprocessing_visualization(filepath, visualization_type="z_score") if plot_zScore_dff == "dff": - visualize_dff(filepath) + execute_preprocessing_visualization(filepath, visualization_type="dff") if plot_zScore_dff == "Both": - visualize_z_score(filepath) - visualize_dff(filepath) - - writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") - inputParameters["step"] += 1 + execute_preprocessing_visualization(filepath, visualization_type="z_score") + execute_preprocessing_visualization(filepath, visualization_type="dff") plt.show() - logger.info("Z-score computation completed.") + logger.info("Visualization of z-score and dF/F completed.") # function to remove artifacts from z-score data @@ -394,15 +301,34 @@ def execute_artifact_removal(folderNames, inputParameters): ) write_artifact_removal(filepath, name_to_data, pair_name_to_timestamps, compound_name_to_ttl_timestamps) - visualizeControlAndSignal(filepath, removeArtifacts=True) writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") inputParameters["step"] += 1 - plt.show() + visualize_artifact_removal(folderNames, inputParameters) logger.info("Artifact removal completed.") +def visualize_artifact_removal(folderNames, inputParameters): + combine_data = inputParameters["combine_data"] + + storesListPath = [] + for i in range(len(folderNames)): + if combine_data == True: + storesListPath.append([folderNames[i][0]]) + else: + filepath = folderNames[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + + storesListPath = np.concatenate(storesListPath) + + for j in range(len(storesListPath)): + filepath = storesListPath[j] + visualizeControlAndSignal(filepath, removeArtifacts=True) + plt.show() + logger.info("Visualization of artifact removal completed.") + + # function to combine data when there are two different data files for the same recording session # it will combine the data, do timestamps processing and save the combined data in the first output folder. def execute_combine_data(folderNames, inputParameters, storesList): @@ -490,6 +416,7 @@ def extractTsAndSignal(inputParameters): writeToFile(str((pbMaxValue + 1) * 10) + "\n" + str(10) + "\n") execute_timestamp_correction(folderNames, inputParameters) execute_zscore(folderNames, inputParameters) + visualize_z_score(inputParameters, folderNames) if remove_artifacts == True: execute_artifact_removal(folderNames, inputParameters) else: @@ -499,6 +426,7 @@ def extractTsAndSignal(inputParameters): storesList = check_storeslistfile(folderNames) op_folder = execute_combine_data(folderNames, inputParameters, storesList) execute_zscore(op_folder, inputParameters) + visualize_z_score(inputParameters, op_folder) if remove_artifacts == True: execute_artifact_removal(op_folder, inputParameters) diff --git a/src/guppy/computePsth.py b/src/guppy/orchestration/psth.py similarity index 94% rename from src/guppy/computePsth.py rename to src/guppy/orchestration/psth.py index 32d9be1..e1d78de 100755 --- a/src/guppy/computePsth.py +++ b/src/guppy/orchestration/psth.py @@ -13,44 +13,31 @@ import numpy as np from scipy import signal as ss -from .analysis.compute_psth import compute_psth -from .analysis.cross_correlation import compute_cross_correlation -from .analysis.io_utils import ( - get_all_stores_for_combining_data, +from ..analysis.compute_psth import compute_psth +from ..analysis.cross_correlation import compute_cross_correlation +from ..analysis.io_utils import ( make_dir_for_cross_correlation, makeAverageDir, - read_Df, read_hdf5, write_hdf5, ) -from .analysis.psth_average import averageForGroup -from .analysis.psth_peak_and_area import compute_psth_peak_and_area -from .analysis.psth_utils import ( +from ..analysis.psth_average import averageForGroup +from ..analysis.psth_peak_and_area import compute_psth_peak_and_area +from ..analysis.psth_utils import ( create_Df_for_cross_correlation, create_Df_for_psth, getCorrCombinations, ) -from .analysis.standard_io import ( +from ..analysis.standard_io import ( write_peak_and_area_to_csv, write_peak_and_area_to_hdf5, ) +from ..frontend.progress import writeToFile +from ..utils.utils import get_all_stores_for_combining_data, read_Df, takeOnlyDirs logger = logging.getLogger(__name__) -def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths) - set(removePaths)) - - -def writeToFile(value: str): - with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: - file.write(value) - - # function to create PSTH for each event using function helper_psth and save the PSTH to h5 file def execute_compute_psth(filepath, event, inputParameters): @@ -358,7 +345,7 @@ def psthForEachStorename(inputParameters): def main(input_parameters): try: inputParameters = psthForEachStorename(input_parameters) - subprocess.call([sys.executable, "-m", "guppy.findTransientsFreqAndAmp", json.dumps(inputParameters)]) + subprocess.call([sys.executable, "-m", "guppy.orchestration.transients", json.dumps(inputParameters)]) logger.info("#" * 400) except Exception as e: with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: diff --git a/src/guppy/readTevTsq.py b/src/guppy/orchestration/read_raw_data.py similarity index 90% rename from src/guppy/readTevTsq.py rename to src/guppy/orchestration/read_raw_data.py index 19a0a4a..cd82634 100755 --- a/src/guppy/readTevTsq.py +++ b/src/guppy/orchestration/read_raw_data.py @@ -14,25 +14,14 @@ TdtRecordingExtractor, read_and_save_all_events, ) +from guppy.frontend.progress import writeToFile +from guppy.utils.utils import takeOnlyDirs logger = logging.getLogger(__name__) -def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths) - set(removePaths)) - - -def writeToFile(value: str): - with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: - file.write(value) - - # function to read data from 'tsq' and 'tev' files -def readRawData(inputParameters): +def orchestrate_read_raw_data(inputParameters): logger.debug("### Reading raw data... ###") # get input parameters @@ -99,7 +88,7 @@ def readRawData(inputParameters): def main(input_parameters): logger.info("run") try: - readRawData(input_parameters) + orchestrate_read_raw_data(input_parameters) logger.info("#" * 400) except Exception as e: with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: diff --git a/src/guppy/orchestration/save_parameters.py b/src/guppy/orchestration/save_parameters.py new file mode 100644 index 0000000..3af5e2c --- /dev/null +++ b/src/guppy/orchestration/save_parameters.py @@ -0,0 +1,41 @@ +import json +import logging +import os + +logger = logging.getLogger(__name__) + + +def save_parameters(inputParameters: dict): + logger.debug("Saving Input Parameters file.") + analysisParameters = { + "combine_data": inputParameters["combine_data"], + "isosbestic_control": inputParameters["isosbestic_control"], + "timeForLightsTurnOn": inputParameters["timeForLightsTurnOn"], + "filter_window": inputParameters["filter_window"], + "removeArtifacts": inputParameters["removeArtifacts"], + "noChannels": inputParameters["noChannels"], + "zscore_method": inputParameters["zscore_method"], + "baselineWindowStart": inputParameters["baselineWindowStart"], + "baselineWindowEnd": inputParameters["baselineWindowEnd"], + "nSecPrev": inputParameters["nSecPrev"], + "nSecPost": inputParameters["nSecPost"], + "timeInterval": inputParameters["timeInterval"], + "bin_psth_trials": inputParameters["bin_psth_trials"], + "use_time_or_trials": inputParameters["use_time_or_trials"], + "baselineCorrectionStart": inputParameters["baselineCorrectionStart"], + "baselineCorrectionEnd": inputParameters["baselineCorrectionEnd"], + "peak_startPoint": inputParameters["peak_startPoint"], + "peak_endPoint": inputParameters["peak_endPoint"], + "selectForComputePsth": inputParameters["selectForComputePsth"], + "selectForTransientsComputation": inputParameters["selectForTransientsComputation"], + "moving_window": inputParameters["moving_window"], + "highAmpFilt": inputParameters["highAmpFilt"], + "transientsThresh": inputParameters["transientsThresh"], + } + for folder in inputParameters["folderNames"]: + with open(os.path.join(folder, "GuPPyParamtersUsed.json"), "w") as f: + json.dump(analysisParameters, f, indent=4) + logger.info(f"Input Parameters file saved at {folder}") + + logger.info("#" * 400) + logger.info("Input Parameters File Saved.") diff --git a/src/guppy/orchestration/storenames.py b/src/guppy/orchestration/storenames.py new file mode 100755 index 0000000..4a015d9 --- /dev/null +++ b/src/guppy/orchestration/storenames.py @@ -0,0 +1,330 @@ +import glob +import json +import logging +import os +from pathlib import Path + +import holoviews as hv # noqa: F401 +import numpy as np +import panel as pn + +from guppy.extractors import ( + CsvRecordingExtractor, + DoricRecordingExtractor, + NpmRecordingExtractor, + TdtRecordingExtractor, +) +from guppy.frontend.frontend_utils import scanPortsAndFind +from guppy.frontend.npm_gui_prompts import ( + get_multi_event_responses, + get_timestamp_configuration, +) +from guppy.frontend.storenames_instructions import ( + StorenamesInstructions, + StorenamesInstructionsNPM, +) +from guppy.frontend.storenames_selector import StorenamesSelector +from guppy.utils.utils import takeOnlyDirs + +pn.extension() + +logger = logging.getLogger(__name__) + + +# function to show location for over-writing or creating a new stores list file. +def show_dir(filepath): + i = 1 + while True: + basename = os.path.basename(filepath) + op = os.path.join(filepath, basename + "_output_" + str(i)) + if not os.path.exists(op): + break + i += 1 + return op + + +def make_dir(filepath): + i = 1 + while True: + basename = os.path.basename(filepath) + op = os.path.join(filepath, basename + "_output_" + str(i)) + if not os.path.exists(op): + os.mkdir(op) + break + i += 1 + + return op + + +def _fetchValues(text, storenames, storename_dropdowns, storename_textboxes, d): + if not storename_dropdowns or not len(storenames) > 0: + return "####Alert !! \n No storenames selected." + + storenames_cache = dict() + if os.path.exists(os.path.join(Path.home(), ".storesList.json")): + with open(os.path.join(Path.home(), ".storesList.json")) as f: + storenames_cache = json.load(f) + + comboBoxValues, textBoxValues = [], [] + dropdown_keys = list(storename_dropdowns.keys()) + textbox_keys = list(storename_textboxes.keys()) if storename_textboxes else [] + + # Get dropdown values + for key in dropdown_keys: + comboBoxValues.append(storename_dropdowns[key].value) + + # Get textbox values (matching with dropdown keys) + for key in dropdown_keys: + if key in storename_textboxes: + textbox_value = storename_textboxes[key].value or "" + textBoxValues.append(textbox_value) + + # Validation: Check for whitespace + if len(textbox_value.split()) > 1: + return "####Alert !! \n Whitespace is not allowed in the text box entry." + + # Validation: Check for empty required fields + dropdown_value = storename_dropdowns[key].value + if ( + not textbox_value + and dropdown_value not in storenames_cache + and dropdown_value in ["control", "signal", "event TTLs"] + ): + return "####Alert !! \n One of the text box entry is empty." + else: + # For cached values, use the dropdown value directly + textBoxValues.append(storename_dropdowns[key].value) + + if len(comboBoxValues) != len(textBoxValues): + return "####Alert !! \n Number of entries in combo box and text box should be same." + + names_for_storenames = [] + for i in range(len(comboBoxValues)): + if comboBoxValues[i] == "control" or comboBoxValues[i] == "signal": + if "_" in textBoxValues[i]: + return "####Alert !! \n Please do not use underscore in region name." + names_for_storenames.append("{}_{}".format(comboBoxValues[i], textBoxValues[i])) + elif comboBoxValues[i] == "event TTLs": + names_for_storenames.append(textBoxValues[i]) + else: + names_for_storenames.append(comboBoxValues[i]) + + d["storenames"] = text.value + d["names_for_storenames"] = names_for_storenames + return "#### No alerts !!" + + +def _save(d, select_location): + arr1, arr2 = np.asarray(d["storenames"]), np.asarray(d["names_for_storenames"]) + + if np.where(arr2 == "")[0].size > 0: + alert_message = "#### Alert !! \n Empty string in the list names_for_storenames." + logger.error("Empty string in the list names_for_storenames.") + return alert_message + + if arr1.shape[0] != arr2.shape[0]: + alert_message = "#### Alert !! \n Length of list storenames and names_for_storenames is not equal." + logger.error("Length of list storenames and names_for_storenames is not equal.") + return alert_message + + if not os.path.exists(os.path.join(Path.home(), ".storesList.json")): + storenames_cache = dict() + + for i in range(arr1.shape[0]): + if arr1[i] in storenames_cache: + storenames_cache[arr1[i]].append(arr2[i]) + storenames_cache[arr1[i]] = list(set(storenames_cache[arr1[i]])) + else: + storenames_cache[arr1[i]] = [arr2[i]] + + with open(os.path.join(Path.home(), ".storesList.json"), "w") as f: + json.dump(storenames_cache, f, indent=4) + else: + with open(os.path.join(Path.home(), ".storesList.json")) as f: + storenames_cache = json.load(f) + + for i in range(arr1.shape[0]): + if arr1[i] in storenames_cache: + storenames_cache[arr1[i]].append(arr2[i]) + storenames_cache[arr1[i]] = list(set(storenames_cache[arr1[i]])) + else: + storenames_cache[arr1[i]] = [arr2[i]] + + with open(os.path.join(Path.home(), ".storesList.json"), "w") as f: + json.dump(storenames_cache, f, indent=4) + + arr = np.asarray([arr1, arr2]) + logger.info(arr) + if not os.path.exists(select_location): + os.mkdir(select_location) + + np.savetxt(os.path.join(select_location, "storesList.csv"), arr, delimiter=",", fmt="%s") + logger.info(f"Storeslist file saved at {select_location}") + logger.info("Storeslist : \n" + str(arr)) + return "#### No alerts !!" + + +# function to show GUI and save +def build_storenames_page(inputParameters, events, flags, folder_path): + + logger.debug("Saving stores list file.") + # getting input parameters + inputParameters = inputParameters + + # Headless path: if storenames_map provided, write storesList.csv without building the Panel UI + storenames_map = inputParameters.get("storenames_map") + if isinstance(storenames_map, dict) and len(storenames_map) > 0: + op = make_dir(folder_path) + arr = np.asarray([list(storenames_map.keys()), list(storenames_map.values())], dtype=str) + np.savetxt(os.path.join(op, "storesList.csv"), arr, delimiter=",", fmt="%s") + logger.info(f"Storeslist file saved at {op}") + logger.info("Storeslist : \n" + str(arr)) + return + + # Get storenames from extractor's events property + allnames = events + + # creating GUI template + template = pn.template.BootstrapTemplate(title="Storenames GUI - {}".format(os.path.basename(folder_path))) + + if "data_np_v2" in flags or "data_np" in flags or "event_np" in flags: + storenames_instructions = StorenamesInstructionsNPM(folder_path=folder_path) + else: + storenames_instructions = StorenamesInstructions(folder_path=folder_path) + storenames_selector = StorenamesSelector(allnames=allnames) + + storenames = [] + storename_dropdowns = {} + storename_textboxes = {} + + # ------------------------------------------------------------------------------------------------------------------ + # onclick closure functions + # on clicking overwrite_button, following function is executed + def overwrite_button_actions(event): + if event.new == "over_write_file": + options = takeOnlyDirs(glob.glob(os.path.join(folder_path, "*_output_*"))) + storenames_selector.set_select_location_options(options=options) + else: + options = [show_dir(folder_path)] + storenames_selector.set_select_location_options(options=options) + + def fetchValues(event): + global storenames + d = dict() + alert_message = _fetchValues( + text=storenames_selector.text, + storenames=storenames, + storename_dropdowns=storename_dropdowns, + storename_textboxes=storename_textboxes, + d=d, + ) + storenames_selector.set_alert_message(alert_message) + storenames_selector.set_literal_input_2(d=d) + + # on clicking 'Select Storenames' button, following function is executed + def update_values(event): + global storenames, vars_list + + arr = storenames_selector.get_take_widgets() + new_arr = [] + for i in range(len(arr[1])): + for j in range(arr[1][i]): + new_arr.append(arr[0][i]) + if len(new_arr) > 0: + storenames = storenames_selector.get_cross_selector() + new_arr + else: + storenames = storenames_selector.get_cross_selector() + storenames_selector.set_change_widgets(storenames) + + storenames_cache = dict() + if os.path.exists(os.path.join(Path.home(), ".storesList.json")): + with open(os.path.join(Path.home(), ".storesList.json")) as f: + storenames_cache = json.load(f) + + storenames_selector.configure_storenames( + storename_dropdowns=storename_dropdowns, + storename_textboxes=storename_textboxes, + storenames=storenames, + storenames_cache=storenames_cache, + ) + + # on clicking save button, following function is executed + def save_button(event=None): + global storenames + d = storenames_selector.get_literal_input_2() + select_location = storenames_selector.get_select_location() + alert_message = _save(d=d, select_location=select_location) + storenames_selector.set_alert_message(alert_message) + storenames_selector.set_path(os.path.join(select_location, "storesList.csv")) + + # ------------------------------------------------------------------------------------------------------------------ + + # Connect button callbacks + button_name_to_onclick_fn = { + "update_options": update_values, + "save": save_button, + "overwrite_button": overwrite_button_actions, + "show_config_button": fetchValues, + } + storenames_selector.attach_callbacks(button_name_to_onclick_fn) + + template.main.append(pn.Row(storenames_instructions.widget, storenames_selector.widget)) + + # creating widgets, adding them to template and showing a GUI on a new browser window + number = scanPortsAndFind(start_port=5000, end_port=5200) + template.show(port=number) + + +def read_header(inputParameters, num_ch, modality, folder_path, headless): + if modality == "tdt": + events, flags = TdtRecordingExtractor.discover_events_and_flags(folder_path=folder_path) + elif modality == "csv": + events, flags = CsvRecordingExtractor.discover_events_and_flags(folder_path=folder_path) + + elif modality == "doric": + events, flags = DoricRecordingExtractor.discover_events_and_flags(folder_path=folder_path) + + elif modality == "npm": + if not headless: + # Resolve multiple event TTLs + multiple_event_ttls = NpmRecordingExtractor.has_multiple_event_ttls(folder_path=folder_path) + responses = get_multi_event_responses(multiple_event_ttls) + inputParameters["npm_split_events"] = responses + + # Resolve timestamp units and columns + ts_unit_needs, col_names_ts = NpmRecordingExtractor.needs_ts_unit(folder_path=folder_path, num_ch=num_ch) + ts_units, npm_timestamp_column_names = get_timestamp_configuration(ts_unit_needs, col_names_ts) + inputParameters["npm_time_units"] = ts_units if ts_units else None + inputParameters["npm_timestamp_column_names"] = ( + npm_timestamp_column_names if npm_timestamp_column_names else None + ) + + events, flags = NpmRecordingExtractor.discover_events_and_flags( + folder_path=folder_path, num_ch=num_ch, inputParameters=inputParameters + ) + else: + raise ValueError("Modality not recognized. Please use 'tdt', 'csv', 'doric', or 'npm'.") + return events, flags + + +# function to read input parameters and run the saveStorenames function +def orchestrate_storenames_page(inputParameters): + + inputParameters = inputParameters + folderNames = inputParameters["folderNames"] + isosbestic_control = inputParameters["isosbestic_control"] + num_ch = inputParameters["noChannels"] + modality = inputParameters.get("modality", "tdt") + headless = bool(os.environ.get("GUPPY_BASE_DIR")) + + logger.info(folderNames) + + try: + for i in folderNames: + folder_path = os.path.join(inputParameters["abspath"], i) + events, flags = read_header(inputParameters, num_ch, modality, folder_path, headless) + build_storenames_page(inputParameters, events, flags, folder_path) + logger.info("#" * 400) + except Exception as e: + logger.error(str(e)) + raise e diff --git a/src/guppy/findTransientsFreqAndAmp.py b/src/guppy/orchestration/transients.py similarity index 64% rename from src/guppy/findTransientsFreqAndAmp.py rename to src/guppy/orchestration/transients.py index f6c3d6e..4d899da 100755 --- a/src/guppy/findTransientsFreqAndAmp.py +++ b/src/guppy/orchestration/transients.py @@ -8,39 +8,24 @@ import matplotlib.pyplot as plt import numpy as np -from .analysis.io_utils import ( - get_all_stores_for_combining_data, +from ..analysis.io_utils import ( read_hdf5, - takeOnlyDirs, ) -from .analysis.standard_io import ( +from ..analysis.standard_io import ( + read_transients_from_hdf5, write_freq_and_amp_to_csv, write_freq_and_amp_to_hdf5, + write_transients_to_hdf5, ) -from .analysis.transients import analyze_transients -from .analysis.transients_average import averageForGroup +from ..analysis.transients import analyze_transients +from ..analysis.transients_average import averageForGroup +from ..frontend.progress import writeToFile +from ..utils.utils import get_all_stores_for_combining_data, takeOnlyDirs +from ..visualization.transients import visualize_peaks logger = logging.getLogger(__name__) -def writeToFile(value: str): - with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: - file.write(value) - - -def visuzlize_peaks(filepath, z_score, timestamps, peaksIndex): - - dirname = os.path.dirname(filepath) - - basename = (os.path.basename(filepath)).split(".")[0] - fig = plt.figure() - ax = fig.add_subplot(111) - ax.plot(timestamps, z_score, "-", timestamps[peaksIndex], z_score[peaksIndex], "o") - ax.set_title(basename) - fig.suptitle(os.path.basename(dirname)) - # plt.show() - - def findFreqAndAmp(filepath, inputParameters, window=15, numProcesses=mp.cpu_count()): logger.debug("Calculating frequency and amplitude of transients in z-score data....") @@ -76,10 +61,68 @@ def findFreqAndAmp(filepath, inputParameters, window=15, numProcesses=mp.cpu_cou index=np.arange(peaks_occurrences.shape[0]), columns=["timestamps", "amplitude"], ) - visuzlize_peaks(path[i], z_score, ts, peaksInd) + write_transients_to_hdf5(filepath, basename, z_score, ts, peaksInd) logger.info("Frequency and amplitude of transients in z_score data are calculated.") +def execute_visualize_peaks(folderNames, inputParameters): + selectForTransientsComputation = inputParameters["selectForTransientsComputation"] + for i in range(len(folderNames)): + logger.debug(f"Finding transients in z-score data of {folderNames[i]} and calculating frequency and amplitude.") + filepath = folderNames[i] + storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) + for j in range(len(storesListPath)): + filepath = storesListPath[j] + if selectForTransientsComputation == "z_score": + path = glob.glob(os.path.join(filepath, "z_score_*")) + elif selectForTransientsComputation == "dff": + path = glob.glob(os.path.join(filepath, "dff_*")) + else: + path = glob.glob(os.path.join(filepath, "z_score_*")) + glob.glob(os.path.join(filepath, "dff_*")) + + for i in range(len(path)): + basename = (os.path.basename(path[i])).split(".")[0] + z_score, ts, peaksInd = read_transients_from_hdf5(filepath, basename) + + suptitle = os.path.basename(os.path.dirname(path[i])) + title = (os.path.basename(path[i])).split(".")[0] + visualize_peaks(title, suptitle, z_score, ts, peaksInd) + + logger.info("Frequency and amplitude of transients in z_score data are visualized.") + plt.show() + + +def execute_visualize_peaks_combined(folderNames, inputParameters): + selectForTransientsComputation = inputParameters["selectForTransientsComputation"] + + storesListPath = [] + for i in range(len(folderNames)): + filepath = folderNames[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = list(np.concatenate(storesListPath).flatten()) + op = get_all_stores_for_combining_data(storesListPath) + for i in range(len(op)): + filepath = op[i][0] + + if selectForTransientsComputation == "z_score": + path = glob.glob(os.path.join(filepath, "z_score_*")) + elif selectForTransientsComputation == "dff": + path = glob.glob(os.path.join(filepath, "dff_*")) + else: + path = glob.glob(os.path.join(filepath, "z_score_*")) + glob.glob(os.path.join(filepath, "dff_*")) + + for i in range(len(path)): + basename = (os.path.basename(path[i])).split(".")[0] + z_score, ts, peaksInd = read_transients_from_hdf5(filepath, basename) + + suptitle = os.path.basename(os.path.dirname(path[i])) + title = (os.path.basename(path[i])).split(".")[0] + visualize_peaks(title, suptitle, z_score, ts, peaksInd) + + logger.info("Frequency and amplitude of transients in z_score data are calculated.") + plt.show() + + def executeFindFreqAndAmp(inputParameters): logger.info("Finding transients in z-score data and calculating frequency and amplitude....") @@ -106,8 +149,10 @@ def executeFindFreqAndAmp(inputParameters): else: if combine_data == True: execute_find_freq_and_amp_combined(inputParameters, folderNames, moving_window, numProcesses) + execute_visualize_peaks_combined(folderNames, inputParameters) else: execute_find_freq_and_amp(inputParameters, folderNames, moving_window, numProcesses) + execute_visualize_peaks(folderNames, inputParameters) logger.info("Transients in z-score data found and frequency and amplitude are calculated.") @@ -126,7 +171,6 @@ def execute_find_freq_and_amp(inputParameters, folderNames, moving_window, numPr writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") inputParameters["step"] += 1 logger.info("Transients in z-score data found and frequency and amplitude are calculated.") - plt.show() def execute_find_freq_and_amp_combined(inputParameters, folderNames, moving_window, numProcesses): @@ -142,7 +186,6 @@ def execute_find_freq_and_amp_combined(inputParameters, folderNames, moving_wind findFreqAndAmp(filepath, inputParameters, window=moving_window, numProcesses=numProcesses) writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") inputParameters["step"] += 1 - plt.show() def execute_average_for_group(inputParameters, folderNamesForAvg): diff --git a/src/guppy/orchestration/visualize.py b/src/guppy/orchestration/visualize.py new file mode 100755 index 0000000..a4149f9 --- /dev/null +++ b/src/guppy/orchestration/visualize.py @@ -0,0 +1,257 @@ +import glob +import logging +import os +import re + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from ..frontend.parameterized_plotter import ParameterizedPlotter, remove_cols +from ..frontend.visualization_dashboard import VisualizationDashboard +from ..utils.utils import get_all_stores_for_combining_data, read_Df, takeOnlyDirs + +logger = logging.getLogger(__name__) + + +# helper function to create plots +def helper_plots(filepath, event, name, inputParameters): + + basename = os.path.basename(filepath) + visualize_zscore_or_dff = inputParameters["visualize_zscore_or_dff"] + + # note when there are no behavior event TTLs + if len(event) == 0: + logger.warning("\033[1m" + "There are no behavior event TTLs present to visualize.".format(event) + "\033[0m") + return 0 + + if os.path.exists(os.path.join(filepath, "cross_correlation_output")): + event_corr, frames = [], [] + if visualize_zscore_or_dff == "z_score": + corr_fp = glob.glob(os.path.join(filepath, "cross_correlation_output", "*_z_score_*")) + elif visualize_zscore_or_dff == "dff": + corr_fp = glob.glob(os.path.join(filepath, "cross_correlation_output", "*_dff_*")) + for i in range(len(corr_fp)): + filename = os.path.basename(corr_fp[i]).split(".")[0] + event_corr.append(filename) + df = pd.read_hdf(corr_fp[i], key="df", mode="r") + frames.append(df) + if len(frames) > 0: + df_corr = pd.concat(frames, keys=event_corr, axis=1) + else: + event_corr = [] + df_corr = [] + else: + event_corr = [] + df_corr = None + + # combine all the event PSTH so that it can be viewed together + if name: + event_name, name = event, name + new_event, frames, bins = [], [], {} + for i in range(len(event_name)): + + for j in range(len(name)): + new_event.append(event_name[i] + "_" + name[j].split("_")[-1]) + new_name = name[j] + temp_df = read_Df(filepath, new_event[-1], new_name) + cols = list(temp_df.columns) + regex = re.compile("bin_[(]") + bins[new_event[-1]] = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] + # bins.append(keep_cols) + frames.append(temp_df) + + df = pd.concat(frames, keys=new_event, axis=1) + else: + new_event = list(np.unique(np.array(event))) + frames, bins = [], {} + for i in range(len(new_event)): + temp_df = read_Df(filepath, new_event[i], "") + cols = list(temp_df.columns) + regex = re.compile("bin_[(]") + bins[new_event[i]] = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] + frames.append(temp_df) + + df = pd.concat(frames, keys=new_event, axis=1) + + if isinstance(df_corr, pd.DataFrame): + new_event.extend(event_corr) + df = pd.concat([df, df_corr], axis=1, sort=False).reset_index() + + columns_dict = dict() + for i in range(len(new_event)): + df_1 = df[new_event[i]] + columns = list(df_1.columns) + columns.append("All") + columns_dict[new_event[i]] = columns + + # make options array for different selectors + multiple_plots_options = [] + heatmap_options = new_event + bins_keys = list(bins.keys()) + if len(bins_keys) > 0: + bins_new = bins + for i in range(len(bins_keys)): + arr = bins[bins_keys[i]] + if len(arr) > 0: + # heatmap_options.append('{}_bin'.format(bins_keys[i])) + for j in arr: + multiple_plots_options.append("{}_{}".format(bins_keys[i], j)) + + multiple_plots_options = new_event + multiple_plots_options + else: + multiple_plots_options = new_event + x_min = float(inputParameters["nSecPrev"]) - 20 + x_max = float(inputParameters["nSecPost"]) + 20 + colormaps = plt.colormaps() + new_colormaps = ["plasma", "plasma_r", "magma", "magma_r", "inferno", "inferno_r", "viridis", "viridis_r"] + set_a = set(colormaps) + set_b = set(new_colormaps) + colormaps = new_colormaps + list(set_a.difference(set_b)) + x = [columns_dict[new_event[0]][-4]] + y = remove_cols(columns_dict[new_event[0]]) + trial_no = range(1, len(remove_cols(columns_dict[heatmap_options[0]])[:-2]) + 1) + trial_ts = [ + "{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(columns_dict[heatmap_options[0]])[:-2]) + ] + ["All"] + + plotter = ParameterizedPlotter( + event_selector_objects=new_event, + event_selector_heatmap_objects=heatmap_options, + selector_for_multipe_events_plot_objects=multiple_plots_options, + columns_dict=columns_dict, + df_new=df, + x_min=x_min, + x_max=x_max, + color_map_objects=colormaps, + filepath=filepath, + x_objects=x, + y_objects=y, + heatmap_y_objects=trial_ts, + psth_y_objects=trial_ts[:-1], + ) + dashboard = VisualizationDashboard(plotter=plotter, basename=basename) + dashboard.show() + + +# function to combine all the output folders together and preprocess them to use them in helper_plots function +def createPlots(filepath, event, inputParameters): + + for i in range(len(event)): + event[i] = event[i].replace("\\", "_") + event[i] = event[i].replace("/", "_") + + average = inputParameters["visualizeAverageResults"] + visualize_zscore_or_dff = inputParameters["visualize_zscore_or_dff"] + + if average == True: + path = [] + for i in range(len(event)): + if visualize_zscore_or_dff == "z_score": + path.append(glob.glob(os.path.join(filepath, event[i] + "*_z_score_*"))) + elif visualize_zscore_or_dff == "dff": + path.append(glob.glob(os.path.join(filepath, event[i] + "*_dff_*"))) + + path = np.concatenate(path) + else: + if visualize_zscore_or_dff == "z_score": + path = glob.glob(os.path.join(filepath, "z_score_*")) + elif visualize_zscore_or_dff == "dff": + path = glob.glob(os.path.join(filepath, "dff_*")) + + name_arr = [] + event_arr = [] + + index = [] + for i in range(len(event)): + if "control" in event[i].lower() or "signal" in event[i].lower(): + index.append(i) + + event = np.delete(event, index) + + for i in range(len(path)): + name = (os.path.basename(path[i])).split(".") + name = name[0] + name_arr.append(name) + + if average == True: + logger.info("average") + helper_plots(filepath, name_arr, "", inputParameters) + else: + helper_plots(filepath, event, name_arr, inputParameters) + + +def visualizeResults(inputParameters): + + inputParameters = inputParameters + + average = inputParameters["visualizeAverageResults"] + logger.info(average) + + folderNames = inputParameters["folderNames"] + folderNamesForAvg = inputParameters["folderNamesForAvg"] + combine_data = inputParameters["combine_data"] + + if average == True and len(folderNamesForAvg) > 0: + # folderNames = folderNamesForAvg + filepath_avg = os.path.join(inputParameters["abspath"], "average") + # filepath = os.path.join(inputParameters['abspath'], folderNames[0]) + storesListPath = [] + for i in range(len(folderNamesForAvg)): + filepath = folderNamesForAvg[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = np.concatenate(storesListPath) + storesList = np.asarray([[], []]) + for i in range(storesListPath.shape[0]): + storesList = np.concatenate( + ( + storesList, + np.genfromtxt( + os.path.join(storesListPath[i], "storesList.csv"), dtype="str", delimiter="," + ).reshape(2, -1), + ), + axis=1, + ) + storesList = np.unique(storesList, axis=1) + + createPlots(filepath_avg, np.unique(storesList[1, :]), inputParameters) + + else: + if combine_data == True: + storesListPath = [] + for i in range(len(folderNames)): + filepath = folderNames[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = list(np.concatenate(storesListPath).flatten()) + op = get_all_stores_for_combining_data(storesListPath) + for i in range(len(op)): + storesList = np.asarray([[], []]) + for j in range(len(op[i])): + storesList = np.concatenate( + ( + storesList, + np.genfromtxt(os.path.join(op[i][j], "storesList.csv"), dtype="str", delimiter=",").reshape( + 2, -1 + ), + ), + axis=1, + ) + storesList = np.unique(storesList, axis=1) + filepath = op[i][0] + createPlots(filepath, storesList[1, :], inputParameters) + else: + for i in range(len(folderNames)): + + filepath = folderNames[i] + storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) + for j in range(len(storesListPath)): + filepath = storesListPath[j] + storesList = np.genfromtxt( + os.path.join(filepath, "storesList.csv"), dtype="str", delimiter="," + ).reshape(2, -1) + + createPlots(filepath, storesList[1, :], inputParameters) + + +# logger.info(sys.argv[1:]) +# visualizeResults(sys.argv[1:][0]) diff --git a/src/guppy/saveStoresList.py b/src/guppy/saveStoresList.py deleted file mode 100755 index 20a5c94..0000000 --- a/src/guppy/saveStoresList.py +++ /dev/null @@ -1,712 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 - -# In[1]: - - -import glob -import json -import logging -import os -import socket -import tkinter as tk -from pathlib import Path -from random import randint -from tkinter import StringVar, messagebox, ttk - -import holoviews as hv -import numpy as np -import pandas as pd -import panel as pn - -from guppy.extractors import ( - CsvRecordingExtractor, - DoricRecordingExtractor, - NpmRecordingExtractor, - TdtRecordingExtractor, -) - -# hv.extension() -pn.extension() - -logger = logging.getLogger(__name__) - - -def scanPortsAndFind(start_port=5000, end_port=5200, host="127.0.0.1"): - while True: - port = randint(start_port, end_port) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(0.001) # Set timeout to avoid long waiting on closed ports - result = sock.connect_ex((host, port)) - if result == 0: # If the connection is successful, the port is open - continue - else: - break - - return port - - -def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths) - set(removePaths)) - - -# function to show location for over-writing or creating a new stores list file. -def show_dir(filepath): - i = 1 - while True: - basename = os.path.basename(filepath) - op = os.path.join(filepath, basename + "_output_" + str(i)) - if not os.path.exists(op): - break - i += 1 - return op - - -def make_dir(filepath): - i = 1 - while True: - basename = os.path.basename(filepath) - op = os.path.join(filepath, basename + "_output_" + str(i)) - if not os.path.exists(op): - os.mkdir(op) - break - i += 1 - - return op - - -# function to show GUI and save -def saveStorenames(inputParameters, events, flags, folder_path): - - logger.debug("Saving stores list file.") - # getting input parameters - inputParameters = inputParameters - - # Headless path: if storenames_map provided, write storesList.csv without building the Panel UI - storenames_map = inputParameters.get("storenames_map") - if isinstance(storenames_map, dict) and len(storenames_map) > 0: - op = make_dir(folder_path) - arr = np.asarray([list(storenames_map.keys()), list(storenames_map.values())], dtype=str) - np.savetxt(os.path.join(op, "storesList.csv"), arr, delimiter=",", fmt="%s") - logger.info(f"Storeslist file saved at {op}") - logger.info("Storeslist : \n" + str(arr)) - return - - # Get storenames from extractor's events property - allnames = events - - if "data_np_v2" in flags or "data_np" in flags or "event_np" in flags: - path_chev = glob.glob(os.path.join(folder_path, "*chev*")) - path_chod = glob.glob(os.path.join(folder_path, "*chod*")) - path_chpr = glob.glob(os.path.join(folder_path, "*chpr*")) - combine_paths = path_chev + path_chod + path_chpr - d = dict() - for i in range(len(combine_paths)): - basename = (os.path.basename(combine_paths[i])).split(".")[0] - df = pd.read_csv(combine_paths[i]) - d[basename] = {"x": np.array(df["timestamps"]), "y": np.array(df["data"])} - keys = list(d.keys()) - mark_down_np = pn.pane.Markdown( - """ - ### Extra Instructions to follow when using Neurophotometrics data : - - Guppy will take the NPM data, which has interleaved frames - from the signal and control channels, and divide it out into - separate channels for each site you recordded. - However, since NPM does not automatically annotate which - frames belong to the signal channel and which belong to the - control channel, the user must specify this for GuPPy. - - Each of your recording sites will have a channel - named “chod” and a channel named “chev” - - View the plots below and, for each site, - determine whether the “chev” or “chod” channel is signal or control - - When you give your storenames, name the channels appropriately. - For example, “chev1” might be “signal_A” and - “chod1” might be “control_A” (or vice versa). - - """ - ) - plot_select = pn.widgets.Select( - name="Select channel to see correspondings channels", options=keys, value=keys[0] - ) - - @pn.depends(plot_select=plot_select) - def plot(plot_select): - return hv.Curve((d[plot_select]["x"], d[plot_select]["y"])).opts(width=550) - - else: - pass - - # instructions about how to save the storeslist file - mark_down = pn.pane.Markdown( - """ - - - ### Instructions to follow : - - - Check Storenames to repeat checkbox and see instructions in “Github Wiki” for duplicating storenames. - Otherwise do not check the Storenames to repeat checkbox.
- - Select storenames from list and click “Select Storenames” to populate area below.
- - Enter names for storenames, in order, using the following naming convention:
- Isosbestic = “control_region” (ex: Dv1A= control_DMS)
- Signal= “signal_region” (ex: Dv2A= signal_DMS)
- TTLs can be named using any convention (ex: PrtR = RewardedPortEntries) but should be kept consistent for later group analysis - - ``` - {"storenames": ["Dv1A", "Dv2A", - "Dv3B", "Dv4B", - "LNRW", "LNnR", - "PrtN", "PrtR", - "RNPS"], - "names_for_storenames": ["control_DMS", "signal_DMS", - "control_DLS", "signal_DLS", - "RewardedNosepoke", "UnrewardedNosepoke", - "UnrewardedPort", "RewardedPort", - "InactiveNosepoke"]} - ``` - - If user has saved storenames before, clicking "Select Storenames" button will pop up a dialog box - showing previously used names for storenames. Select names for storenames by checking a checkbox and - click on "Show" to populate the text area in the Storenames GUI. Close the dialog box. - - - Select “create new” or “overwrite” to generate a new storenames list or replace a previous one - - Click Save - - """, - width=550, - ) - - # creating GUI template - template = pn.template.BootstrapTemplate( - title="Storenames GUI - {}".format(os.path.basename(folder_path), mark_down) - ) - - # creating different buttons and selectors for the GUI - cross_selector = pn.widgets.CrossSelector(name="Store Names Selection", value=[], options=allnames, width=600) - multi_choice = pn.widgets.MultiChoice( - name="Select Storenames which you want more than once (multi-choice: multiple options selection)", - value=[], - options=allnames, - ) - - literal_input_1 = pn.widgets.LiteralInput( - name="Number of times you want the above storename (list)", value=[], type=list - ) - # literal_input_2 = pn.widgets.LiteralInput(name='Names for Storenames (list)', type=list) - - repeat_storenames = pn.widgets.Checkbox(name="Storenames to repeat", value=False) - repeat_storename_wd = pn.WidgetBox("", width=600) - - def callback(target, event): - if event.new == True: - target.objects = [multi_choice, literal_input_1] - elif event.new == False: - target.clear() - - repeat_storenames.link(repeat_storename_wd, callbacks={"value": callback}) - # repeat_storename_wd = pn.WidgetBox('Storenames to repeat (leave blank if not needed)', multi_choice, literal_input_1, background="white", width=600) - - update_options = pn.widgets.Button(name="Select Storenames", width=600) - save = pn.widgets.Button(name="Save", width=600) - - text = pn.widgets.LiteralInput(value=[], name="Selected Store Names", type=list, width=600) - - path = pn.widgets.TextInput(name="Location to Stores List file", width=600) - - mark_down_for_overwrite = pn.pane.Markdown( - """ Select option from below if user wants to over-write a file or create a new file. - **Creating a new file will make a new output folder and will get saved at that location.** - If user selects to over-write a file **Select location of the file to over-write** will provide - the existing options of the output folders where user needs to over-write the file""", - width=600, - ) - - select_location = pn.widgets.Select( - name="Select location of the file to over-write", value="None", options=["None"], width=600 - ) - - overwrite_button = pn.widgets.MenuButton( - name="over-write storeslist file or create a new one? ", - items=["over_write_file", "create_new_file"], - button_type="default", - split=True, - width=600, - ) - - literal_input_2 = pn.widgets.CodeEditor(value="""{}""", theme="tomorrow", language="json", height=250, width=600) - - alert = pn.pane.Alert("#### No alerts !!", alert_type="danger", height=80, width=600) - - take_widgets = pn.WidgetBox(multi_choice, literal_input_1) - - change_widgets = pn.WidgetBox(text) - - storenames = [] - storename_dropdowns = {} - storename_textboxes = {} - - if len(allnames) == 0: - alert.object = ( - "####Alert !! \n No storenames found. There are not any TDT files or csv files to look for storenames." - ) - - # on clicking overwrite_button, following function is executed - def overwrite_button_actions(event): - if event.new == "over_write_file": - select_location.options = takeOnlyDirs(glob.glob(os.path.join(folder_path, "*_output_*"))) - # select_location.value = select_location.options[0] - else: - select_location.options = [show_dir(folder_path)] - # select_location.value = select_location.options[0] - - def fetchValues(event): - global storenames - alert.object = "#### No alerts !!" - - if not storename_dropdowns or not len(storenames) > 0: - alert.object = "####Alert !! \n No storenames selected." - return - - storenames_cache = dict() - if os.path.exists(os.path.join(Path.home(), ".storesList.json")): - with open(os.path.join(Path.home(), ".storesList.json")) as f: - storenames_cache = json.load(f) - - comboBoxValues, textBoxValues = [], [] - dropdown_keys = list(storename_dropdowns.keys()) - textbox_keys = list(storename_textboxes.keys()) if storename_textboxes else [] - - # Get dropdown values - for key in dropdown_keys: - comboBoxValues.append(storename_dropdowns[key].value) - - # Get textbox values (matching with dropdown keys) - for key in dropdown_keys: - if key in storename_textboxes: - textbox_value = storename_textboxes[key].value or "" - textBoxValues.append(textbox_value) - - # Validation: Check for whitespace - if len(textbox_value.split()) > 1: - alert.object = "####Alert !! \n Whitespace is not allowed in the text box entry." - return - - # Validation: Check for empty required fields - dropdown_value = storename_dropdowns[key].value - if ( - not textbox_value - and dropdown_value not in storenames_cache - and dropdown_value in ["control", "signal", "event TTLs"] - ): - alert.object = "####Alert !! \n One of the text box entry is empty." - return - else: - # For cached values, use the dropdown value directly - textBoxValues.append(storename_dropdowns[key].value) - - if len(comboBoxValues) != len(textBoxValues): - alert.object = "####Alert !! \n Number of entries in combo box and text box should be same." - return - - names_for_storenames = [] - for i in range(len(comboBoxValues)): - if comboBoxValues[i] == "control" or comboBoxValues[i] == "signal": - if "_" in textBoxValues[i]: - alert.object = "####Alert !! \n Please do not use underscore in region name." - return - names_for_storenames.append("{}_{}".format(comboBoxValues[i], textBoxValues[i])) - elif comboBoxValues[i] == "event TTLs": - names_for_storenames.append(textBoxValues[i]) - else: - names_for_storenames.append(comboBoxValues[i]) - - d = dict() - d["storenames"] = text.value - d["names_for_storenames"] = names_for_storenames - literal_input_2.value = str(json.dumps(d, indent=2)) - - # Panel-based storename configuration (replaces Tkinter dialog) - storename_config_widgets = pn.Column(visible=False) - show_config_button = pn.widgets.Button(name="Show Selected Configuration", width=600) - - # on clicking 'Select Storenames' button, following function is executed - def update_values(event): - global storenames, vars_list - arr = [] - for w in take_widgets: - arr.append(w.value) - - new_arr = [] - - for i in range(len(arr[1])): - for j in range(arr[1][i]): - new_arr.append(arr[0][i]) - - if len(new_arr) > 0: - storenames = cross_selector.value + new_arr - else: - storenames = cross_selector.value - - for w in change_widgets: - w.value = storenames - - storenames_cache = dict() - if os.path.exists(os.path.join(Path.home(), ".storesList.json")): - with open(os.path.join(Path.home(), ".storesList.json")) as f: - storenames_cache = json.load(f) - - # Create Panel widgets for storename configuration - config_widgets = [] - storename_dropdowns.clear() - storename_textboxes.clear() - - if len(storenames) > 0: - config_widgets.append( - pn.pane.Markdown( - "## Configure Storenames\nSelect appropriate options for each storename and provide names as needed:" - ) - ) - - for i, storename in enumerate(storenames): - # Create a row for each storename - row_widgets = [] - - # Label - label = pn.pane.Markdown(f"**{storename}:**") - row_widgets.append(label) - - # Dropdown options - if storename in storenames_cache: - options = storenames_cache[storename] - default_value = options[0] if options else "" - else: - options = ["", "control", "signal", "event TTLs"] - default_value = "" - - # Create unique key for widget - widget_key = ( - f"{storename}_{i}" - if f"{storename}_{i}" not in storename_dropdowns - else f"{storename}_{i}_{len(storename_dropdowns)}" - ) - - dropdown = pn.widgets.Select(name="Type", value=default_value, options=options, width=150) - storename_dropdowns[widget_key] = dropdown - row_widgets.append(dropdown) - - # Text input (only show if not cached or if control/signal/event TTLs selected) - if storename not in storenames_cache or default_value in ["control", "signal", "event TTLs"]: - textbox = pn.widgets.TextInput( - name="Name", value="", placeholder="Enter region/event name", width=200 - ) - storename_textboxes[widget_key] = textbox - row_widgets.append(textbox) - - # Add helper text based on selection - def create_help_function(dropdown_widget, help_pane_container): - @pn.depends(dropdown_widget.param.value, watch=True) - def update_help(dropdown_value): - if dropdown_value == "control": - help_pane_container[0] = pn.pane.Markdown( - "*Type appropriate region name*", styles={"color": "gray", "font-size": "12px"} - ) - elif dropdown_value == "signal": - help_pane_container[0] = pn.pane.Markdown( - "*Type appropriate region name*", styles={"color": "gray", "font-size": "12px"} - ) - elif dropdown_value == "event TTLs": - help_pane_container[0] = pn.pane.Markdown( - "*Type event name for the TTLs*", styles={"color": "gray", "font-size": "12px"} - ) - else: - help_pane_container[0] = pn.pane.Markdown( - "", styles={"color": "gray", "font-size": "12px"} - ) - - return update_help - - help_container = [pn.pane.Markdown("")] - help_function = create_help_function(dropdown, help_container) - help_function(dropdown.value) # Initialize - row_widgets.append(help_container[0]) - - # Add the row to config widgets - config_widgets.append(pn.Row(*row_widgets, margin=(5, 0))) - - # Add show button - config_widgets.append(pn.Spacer(height=20)) - config_widgets.append(show_config_button) - config_widgets.append( - pn.pane.Markdown( - "*Click 'Show Selected Configuration' to apply your selections.*", - styles={"font-size": "12px", "color": "gray"}, - ) - ) - - # Update the configuration panel - storename_config_widgets.objects = config_widgets - storename_config_widgets.visible = len(storenames) > 0 - - # on clicking save button, following function is executed - def save_button(event=None): - global storenames - - d = json.loads(literal_input_2.value) - arr1, arr2 = np.asarray(d["storenames"]), np.asarray(d["names_for_storenames"]) - - if np.where(arr2 == "")[0].size > 0: - alert.object = "#### Alert !! \n Empty string in the list names_for_storenames." - logger.error("Empty string in the list names_for_storenames.") - raise Exception("Empty string in the list names_for_storenames.") - else: - alert.object = "#### No alerts !!" - - if arr1.shape[0] != arr2.shape[0]: - alert.object = "#### Alert !! \n Length of list storenames and names_for_storenames is not equal." - logger.error("Length of list storenames and names_for_storenames is not equal.") - raise Exception("Length of list storenames and names_for_storenames is not equal.") - else: - alert.object = "#### No alerts !!" - - if not os.path.exists(os.path.join(Path.home(), ".storesList.json")): - storenames_cache = dict() - - for i in range(arr1.shape[0]): - if arr1[i] in storenames_cache: - storenames_cache[arr1[i]].append(arr2[i]) - storenames_cache[arr1[i]] = list(set(storenames_cache[arr1[i]])) - else: - storenames_cache[arr1[i]] = [arr2[i]] - - with open(os.path.join(Path.home(), ".storesList.json"), "w") as f: - json.dump(storenames_cache, f, indent=4) - else: - with open(os.path.join(Path.home(), ".storesList.json")) as f: - storenames_cache = json.load(f) - - for i in range(arr1.shape[0]): - if arr1[i] in storenames_cache: - storenames_cache[arr1[i]].append(arr2[i]) - storenames_cache[arr1[i]] = list(set(storenames_cache[arr1[i]])) - else: - storenames_cache[arr1[i]] = [arr2[i]] - - with open(os.path.join(Path.home(), ".storesList.json"), "w") as f: - json.dump(storenames_cache, f, indent=4) - - arr = np.asarray([arr1, arr2]) - logger.info(arr) - if not os.path.exists(select_location.value): - os.mkdir(select_location.value) - - np.savetxt(os.path.join(select_location.value, "storesList.csv"), arr, delimiter=",", fmt="%s") - path.value = os.path.join(select_location.value, "storesList.csv") - logger.info(f"Storeslist file saved at {select_location.value}") - logger.info("Storeslist : \n" + str(arr)) - - # Connect button callbacks - update_options.on_click(update_values) - show_config_button.on_click(fetchValues) - save.on_click(save_button) - overwrite_button.on_click(overwrite_button_actions) - - # creating widgets, adding them to template and showing a GUI on a new browser window - number = scanPortsAndFind(start_port=5000, end_port=5200) - - if "data_np_v2" in flags or "data_np" in flags or "event_np" in flags: - widget_1 = pn.Column("# " + os.path.basename(folder_path), mark_down, mark_down_np, plot_select, plot) - widget_2 = pn.Column( - repeat_storenames, - repeat_storename_wd, - pn.Spacer(height=20), - cross_selector, - update_options, - storename_config_widgets, - pn.Spacer(height=10), - text, - literal_input_2, - alert, - mark_down_for_overwrite, - overwrite_button, - select_location, - save, - path, - ) - template.main.append(pn.Row(widget_1, widget_2)) - - else: - widget_1 = pn.Column("# " + os.path.basename(folder_path), mark_down) - widget_2 = pn.Column( - repeat_storenames, - repeat_storename_wd, - pn.Spacer(height=20), - cross_selector, - update_options, - storename_config_widgets, - pn.Spacer(height=10), - text, - literal_input_2, - alert, - mark_down_for_overwrite, - overwrite_button, - select_location, - save, - path, - ) - template.main.append(pn.Row(widget_1, widget_2)) - - template.show(port=number) - - -# function to read input parameters and run the saveStorenames function -def execute(inputParameters): - - inputParameters = inputParameters - folderNames = inputParameters["folderNames"] - isosbestic_control = inputParameters["isosbestic_control"] - num_ch = inputParameters["noChannels"] - modality = inputParameters.get("modality", "tdt") - - logger.info(folderNames) - - try: - for i in folderNames: - folder_path = os.path.join(inputParameters["abspath"], i) - if modality == "tdt": - events, flags = TdtRecordingExtractor.discover_events_and_flags(folder_path=folder_path) - elif modality == "csv": - events, flags = CsvRecordingExtractor.discover_events_and_flags(folder_path=folder_path) - - elif modality == "doric": - events, flags = DoricRecordingExtractor.discover_events_and_flags(folder_path=folder_path) - - elif modality == "npm": - headless = bool(os.environ.get("GUPPY_BASE_DIR")) - if not headless: - # Resolve multiple event TTLs - multiple_event_ttls = NpmRecordingExtractor.has_multiple_event_ttls(folder_path=folder_path) - responses = get_multi_event_responses(multiple_event_ttls) - inputParameters["npm_split_events"] = responses - - # Resolve timestamp units and columns - ts_unit_needs, col_names_ts = NpmRecordingExtractor.needs_ts_unit( - folder_path=folder_path, num_ch=num_ch - ) - ts_units, npm_timestamp_column_names = get_timestamp_configuration(ts_unit_needs, col_names_ts) - inputParameters["npm_time_units"] = ts_units if ts_units else None - inputParameters["npm_timestamp_column_names"] = ( - npm_timestamp_column_names if npm_timestamp_column_names else None - ) - - events, flags = NpmRecordingExtractor.discover_events_and_flags( - folder_path=folder_path, num_ch=num_ch, inputParameters=inputParameters - ) - else: - raise ValueError("Modality not recognized. Please use 'tdt', 'csv', 'doric', or 'npm'.") - - saveStorenames(inputParameters, events, flags, folder_path) - logger.info("#" * 400) - except Exception as e: - logger.error(str(e)) - raise e - - -def get_multi_event_responses(multiple_event_ttls): - responses = [] - for has_multiple in multiple_event_ttls: - if not has_multiple: - responses.append(False) - continue - window = tk.Tk() - response = messagebox.askyesno( - "Multiple event TTLs", - ( - "Based on the TTL file, " - "it looks like TTLs " - "belong to multiple behavior types. " - "Do you want to create multiple files for each " - "behavior type?" - ), - ) - window.destroy() - responses.append(response) - return responses - - -def get_timestamp_configuration(ts_unit_needs, col_names_ts): - ts_units, npm_timestamp_column_names = [], [] - for need in ts_unit_needs: - if not need: - ts_units.append("seconds") - npm_timestamp_column_names.append(None) - continue - window = tk.Tk() - window.title("Select appropriate options for timestamps") - window.geometry("500x200") - holdComboboxValues = dict() - - timestamps_label = ttk.Label(window, text="Select which timestamps to use : ").grid( - row=0, column=1, pady=25, padx=25 - ) - holdComboboxValues["timestamps"] = StringVar() - timestamps_combo = ttk.Combobox(window, values=col_names_ts, textvariable=holdComboboxValues["timestamps"]) - timestamps_combo.grid(row=0, column=2, pady=25, padx=25) - timestamps_combo.current(0) - # timestamps_combo.bind("<>", comboBoxSelected) - - time_unit_label = ttk.Label(window, text="Select timestamps unit : ").grid(row=1, column=1, pady=25, padx=25) - holdComboboxValues["time_unit"] = StringVar() - time_unit_combo = ttk.Combobox( - window, - values=["", "seconds", "milliseconds", "microseconds"], - textvariable=holdComboboxValues["time_unit"], - ) - time_unit_combo.grid(row=1, column=2, pady=25, padx=25) - time_unit_combo.current(0) - # time_unit_combo.bind("<>", comboBoxSelected) - window.lift() - window.after(500, lambda: window.lift()) - window.mainloop() - - if holdComboboxValues["timestamps"].get(): - npm_timestamp_column_name = holdComboboxValues["timestamps"].get() - else: - messagebox.showerror( - "All options not selected", - "All the options for timestamps \ - were not selected. Please select appropriate options", - ) - logger.error( - "All the options for timestamps \ - were not selected. Please select appropriate options" - ) - raise Exception( - "All the options for timestamps \ - were not selected. Please select appropriate options" - ) - if holdComboboxValues["time_unit"].get(): - if holdComboboxValues["time_unit"].get() == "seconds": - ts_unit = holdComboboxValues["time_unit"].get() - elif holdComboboxValues["time_unit"].get() == "milliseconds": - ts_unit = holdComboboxValues["time_unit"].get() - else: - ts_unit = holdComboboxValues["time_unit"].get() - else: - messagebox.showerror( - "All options not selected", - "All the options for timestamps \ - were not selected. Please select appropriate options", - ) - logger.error( - "All the options for timestamps \ - were not selected. Please select appropriate options" - ) - raise Exception( - "All the options for timestamps \ - were not selected. Please select appropriate options" - ) - ts_units.append(ts_unit) - npm_timestamp_column_names.append(npm_timestamp_column_name) - return ts_units, npm_timestamp_column_names diff --git a/src/guppy/savingInputParameters.py b/src/guppy/savingInputParameters.py deleted file mode 100644 index a1bd35e..0000000 --- a/src/guppy/savingInputParameters.py +++ /dev/null @@ -1,581 +0,0 @@ -import json -import logging -import os -import subprocess -import sys -import time -import tkinter as tk -from threading import Thread -from tkinter import filedialog, ttk - -import numpy as np -import pandas as pd -import panel as pn - -from .saveStoresList import execute -from .visualizePlot import visualizeResults - -logger = logging.getLogger(__name__) - - -def savingInputParameters(): - pn.extension() - - # Determine base folder path (headless-friendly via env var) - base_dir_env = os.environ.get("GUPPY_BASE_DIR") - is_headless = base_dir_env and os.path.isdir(base_dir_env) - if is_headless: - global folder_path - folder_path = base_dir_env - logger.info(f"Folder path set to {folder_path} (from GUPPY_BASE_DIR)") - else: - # Create the main window - folder_selection = tk.Tk() - folder_selection.title("Select the folder path where your data is located") - folder_selection.geometry("700x200") - - def select_folder(): - global folder_path - folder_path = filedialog.askdirectory(title="Select the folder path where your data is located") - if folder_path: - logger.info(f"Folder path set to {folder_path}") - folder_selection.destroy() - else: - folder_path = os.path.expanduser("~") - logger.info(f"Folder path set to {folder_path}") - - select_button = ttk.Button(folder_selection, text="Select a Folder", command=select_folder) - select_button.pack(pady=5) - folder_selection.mainloop() - - current_dir = os.getcwd() - - def make_dir(filepath): - op = os.path.join(filepath, "inputParameters") - if not os.path.exists(op): - os.mkdir(op) - return op - - def readRawData(): - inputParameters = getInputParameters() - subprocess.call([sys.executable, "-m", "guppy.readTevTsq", json.dumps(inputParameters)]) - - def extractTs(): - inputParameters = getInputParameters() - subprocess.call([sys.executable, "-m", "guppy.preprocess", json.dumps(inputParameters)]) - - def psthComputation(): - inputParameters = getInputParameters() - inputParameters["curr_dir"] = current_dir - subprocess.call([sys.executable, "-m", "guppy.computePsth", json.dumps(inputParameters)]) - - def readPBIncrementValues(progressBar): - logger.info("Read progress bar increment values function started...") - file_path = os.path.join(os.path.expanduser("~"), "pbSteps.txt") - if os.path.exists(file_path): - os.remove(file_path) - increment, maximum = 0, 100 - progressBar.value = increment - progressBar.bar_color = "success" - while True: - try: - with open(file_path, "r") as file: - content = file.readlines() - if len(content) == 0: - pass - else: - maximum = int(content[0]) - increment = int(content[-1]) - - if increment == -1: - progressBar.bar_color = "danger" - os.remove(file_path) - break - progressBar.max = maximum - progressBar.value = increment - time.sleep(0.001) - except FileNotFoundError: - time.sleep(0.001) - except PermissionError: - time.sleep(0.001) - except Exception as e: - # Handle other exceptions that may occur - logger.info(f"An error occurred while reading the file: {e}") - break - if increment == maximum: - os.remove(file_path) - break - - logger.info("Read progress bar increment values stopped.") - - # progress bars = PB - read_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300) - extract_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300) - psth_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300) - - template = pn.template.BootstrapTemplate(title="Input Parameters GUI") - - mark_down_1 = pn.pane.Markdown("""**Select folders for the analysis from the file selector below**""", width=600) - - files_1 = pn.widgets.FileSelector(folder_path, name="folderNames", width=950) - - explain_modality = pn.pane.Markdown( - """ - **Data Modality:** Select the type of data acquisition system used for your recordings: - - **tdt**: Tucker-Davis Technologies system - - **csv**: Generic CSV format - - **doric**: Doric Photometry system - - **npm**: Neurophotometrics system - """, - width=600, - ) - - modality_selector = pn.widgets.Select( - name="Data Modality", value="tdt", options=["tdt", "csv", "doric", "npm"], width=320 - ) - - explain_time_artifacts = pn.pane.Markdown( - """ - - ***Number of cores :*** Number of cores used for analysis. Try to - keep it less than the number of cores in your machine. - - ***Combine Data? :*** Make this parameter ``` True ``` if user wants to combine - the data, especially when there is two different - data files for the same recording session.
- - ***Isosbestic Control Channel? :*** Make this parameter ``` False ``` if user - does not want to use isosbestic control channel in the analysis.
- - ***Eliminate first few seconds :*** It is the parameter to cut out first x seconds - from the data. Default is 1 seconds.
- - ***Window for Moving Average filter :*** The filtering of signals - is done using moving average filter. Default window used for moving - average filter is 100 datapoints. Change it based on the requirement.
- - ***Moving Window (transients detection) :*** Transients in the z-score - and/or \u0394F/F are detected using this moving window. - Default is 15 seconds. Change it based on the requirement.
- - ***High Amplitude filtering threshold (HAFT) (transients detection) :*** High amplitude - events greater than x times the MAD above the median are filtered out. Here, x is - high amplitude filtering threshold. Default is 2. - - ***Transients detection threshold (TD Thresh):*** Peaks with local maxima greater than x times - the MAD above the median of the trace (after filtering high amplitude events) are detected - as transients. Here, x is transients detection threshold. Default is 3. - - ***Number of channels (Neurophotometrics only) :*** Number of - channels used while recording, when data files has no column names mentioning "Flags" - or "LedState". - - ***removeArtifacts? :*** Make this parameter ``` True``` if there are - artifacts and user wants to remove the artifacts. - - ***removeArtifacts method :*** Selecting ```concatenate``` will remove bad - chunks and concatenate the selected good chunks together. - Selecting ```replace with NaN``` will replace bad chunks with NaN - values. - """, - width=350, - ) - - timeForLightsTurnOn = pn.widgets.LiteralInput( - name="Eliminate first few seconds (int)", value=1, type=int, width=320 - ) - - isosbestic_control = pn.widgets.Select( - name="Isosbestic Control Channel? (bool)", value=True, options=[True, False], width=320 - ) - - numberOfCores = pn.widgets.LiteralInput(name="# of cores (int)", value=2, type=int, width=150) - - combine_data = pn.widgets.Select(name="Combine Data? (bool)", value=False, options=[True, False], width=150) - - computePsth = pn.widgets.Select( - name="z_score and/or \u0394F/F? (psth)", options=["z_score", "dff", "Both"], width=320 - ) - - transients = pn.widgets.Select( - name="z_score and/or \u0394F/F? (transients)", options=["z_score", "dff", "Both"], width=320 - ) - - plot_zScore_dff = pn.widgets.Select( - name="z-score plot and/or \u0394F/F plot?", options=["z_score", "dff", "Both", "None"], value="None", width=320 - ) - - moving_wd = pn.widgets.LiteralInput( - name="Moving Window for transients detection (s) (int)", value=15, type=int, width=320 - ) - - highAmpFilt = pn.widgets.LiteralInput(name="HAFT (int)", value=2, type=int, width=150) - - transientsThresh = pn.widgets.LiteralInput(name="TD Thresh (int)", value=3, type=int, width=150) - - moving_avg_filter = pn.widgets.LiteralInput( - name="Window for Moving Average filter (int)", value=100, type=int, width=320 - ) - - removeArtifacts = pn.widgets.Select(name="removeArtifacts? (bool)", value=False, options=[True, False], width=150) - - artifactsRemovalMethod = pn.widgets.Select( - name="removeArtifacts method", value="concatenate", options=["concatenate", "replace with NaN"], width=150 - ) - - no_channels_np = pn.widgets.LiteralInput( - name="Number of channels (Neurophotometrics only)", value=2, type=int, width=320 - ) - - z_score_computation = pn.widgets.Select( - name="z-score computation Method", - options=["standard z-score", "baseline z-score", "modified z-score"], - value="standard z-score", - width=200, - ) - - baseline_wd_strt = pn.widgets.LiteralInput( - name="Baseline Window Start Time (s) (int)", value=0, type=int, width=200 - ) - baseline_wd_end = pn.widgets.LiteralInput(name="Baseline Window End Time (s) (int)", value=0, type=int, width=200) - - explain_z_score = pn.pane.Markdown( - """ - ***Note :***
- - Details about z-score computation methods are explained in Github wiki.
- - The details will make user understand what computation method to use for - their data.
- - Baseline Window Parameters should be kept 0 unless you are using baseline
- z-score computation method. The parameters are in seconds. - """, - width=580, - ) - - explain_nsec = pn.pane.Markdown( - """ - - ***Time Interval :*** To omit bursts of event timestamps, user defined time interval - is set so that if the time difference between two timestamps is less than this defined time - interval, it will be deleted for the calculation of PSTH. - - ***Compute Cross-correlation :*** Make this parameter ```True```, when user wants - to compute cross-correlation between PSTHs of two different signals or signals - recorded from different brain regions. - """, - width=580, - ) - - nSecPrev = pn.widgets.LiteralInput(name="Seconds before 0 (int)", value=-10, type=int, width=120) - - nSecPost = pn.widgets.LiteralInput(name="Seconds after 0 (int)", value=20, type=int, width=120) - - computeCorr = pn.widgets.Select( - name="Compute Cross-correlation (bool)", options=[True, False], value=False, width=200 - ) - - timeInterval = pn.widgets.LiteralInput(name="Time Interval (s)", value=2, type=int, width=120) - - use_time_or_trials = pn.widgets.Select( - name="Bin PSTH trials (str)", options=["Time (min)", "# of trials"], value="Time (min)", width=120 - ) - - bin_psth_trials = pn.widgets.LiteralInput( - name="Time(min) / # of trials \n for binning? (int)", value=0, type=int, width=200 - ) - - explain_baseline = pn.pane.Markdown( - """ - ***Note :***
- - If user does not want to do baseline correction, - put both parameters 0.
- - If the first event timestamp is less than the length of baseline - window, it will be rejected in the PSTH computation step.
- - Baseline parameters must be within the PSTH parameters - set in the PSTH parameters section. - """, - width=580, - ) - - baselineCorrectionStart = pn.widgets.LiteralInput( - name="Baseline Correction Start time(int)", value=-5, type=int, width=200 - ) - - baselineCorrectionEnd = pn.widgets.LiteralInput( - name="Baseline Correction End time(int)", value=0, type=int, width=200 - ) - - zscore_param_wd = pn.WidgetBox( - "### Z-score Parameters", - explain_z_score, - z_score_computation, - pn.Row(baseline_wd_strt, baseline_wd_end), - width=600, - ) - - psth_param_wd = pn.WidgetBox( - "### PSTH Parameters", - explain_nsec, - pn.Row(nSecPrev, nSecPost, computeCorr), - pn.Row(timeInterval, use_time_or_trials, bin_psth_trials), - width=600, - ) - - baseline_param_wd = pn.WidgetBox( - "### Baseline Parameters", explain_baseline, pn.Row(baselineCorrectionStart, baselineCorrectionEnd), width=600 - ) - peak_explain = pn.pane.Markdown( - """ - ***Note :***
- - Peak and area are computed between the window set below.
- - Peak and AUC parameters must be within the PSTH parameters set in the PSTH parameters section.
- - Please make sure when user changes the parameters in the table below, click on any other cell after - changing a value in a particular cell. - """, - width=580, - ) - - start_end_point_df = pd.DataFrame( - { - "Peak Start time": [-5, 0, 5, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], - "Peak End time": [0, 3, 10, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], - } - ) - - df_widget = pn.widgets.Tabulator(start_end_point_df, name="DataFrame", show_index=False, widths=280) - - peak_param_wd = pn.WidgetBox("### Peak and AUC Parameters", peak_explain, df_widget, width=600) - - mark_down_2 = pn.pane.Markdown( - """**Select folders for the average analysis from the file selector below**""", width=600 - ) - - files_2 = pn.widgets.FileSelector(folder_path, name="folderNamesForAvg", width=950) - - averageForGroup = pn.widgets.Select(name="Average Group? (bool)", value=False, options=[True, False], width=435) - - visualizeAverageResults = pn.widgets.Select( - name="Visualize Average Results? (bool)", value=False, options=[True, False], width=435 - ) - - visualize_zscore_or_dff = pn.widgets.Select( - name="z-score or \u0394F/F? (for visualization)", options=["z_score", "dff"], width=435 - ) - - individual_analysis_wd_2 = pn.Column( - explain_time_artifacts, - pn.Row(numberOfCores, combine_data), - isosbestic_control, - timeForLightsTurnOn, - moving_avg_filter, - computePsth, - transients, - plot_zScore_dff, - moving_wd, - pn.Row(highAmpFilt, transientsThresh), - no_channels_np, - pn.Row(removeArtifacts, artifactsRemovalMethod), - ) - - group_analysis_wd_1 = pn.Column(mark_down_2, files_2, averageForGroup, width=800) - - visualization_wd = pn.Row(visualize_zscore_or_dff, pn.Spacer(width=60), visualizeAverageResults) - - def getInputParameters(): - abspath = getAbsPath() - inputParameters = { - "abspath": abspath[0], - "folderNames": files_1.value, - "modality": modality_selector.value, - "numberOfCores": numberOfCores.value, - "combine_data": combine_data.value, - "isosbestic_control": isosbestic_control.value, - "timeForLightsTurnOn": timeForLightsTurnOn.value, - "filter_window": moving_avg_filter.value, - "removeArtifacts": removeArtifacts.value, - "artifactsRemovalMethod": artifactsRemovalMethod.value, - "noChannels": no_channels_np.value, - "zscore_method": z_score_computation.value, - "baselineWindowStart": baseline_wd_strt.value, - "baselineWindowEnd": baseline_wd_end.value, - "nSecPrev": nSecPrev.value, - "nSecPost": nSecPost.value, - "computeCorr": computeCorr.value, - "timeInterval": timeInterval.value, - "bin_psth_trials": bin_psth_trials.value, - "use_time_or_trials": use_time_or_trials.value, - "baselineCorrectionStart": baselineCorrectionStart.value, - "baselineCorrectionEnd": baselineCorrectionEnd.value, - "peak_startPoint": list(df_widget.value["Peak Start time"]), # startPoint.value, - "peak_endPoint": list(df_widget.value["Peak End time"]), # endPoint.value, - "selectForComputePsth": computePsth.value, - "selectForTransientsComputation": transients.value, - "moving_window": moving_wd.value, - "highAmpFilt": highAmpFilt.value, - "transientsThresh": transientsThresh.value, - "plot_zScore_dff": plot_zScore_dff.value, - "visualize_zscore_or_dff": visualize_zscore_or_dff.value, - "folderNamesForAvg": files_2.value, - "averageForGroup": averageForGroup.value, - "visualizeAverageResults": visualizeAverageResults.value, - } - return inputParameters - - def checkSameLocation(arr, abspath): - # abspath = [] - for i in range(len(arr)): - abspath.append(os.path.dirname(arr[i])) - abspath = np.asarray(abspath) - abspath = np.unique(abspath) - if len(abspath) > 1: - logger.error("All the folders selected should be at the same location") - raise Exception("All the folders selected should be at the same location") - - return abspath - - def getAbsPath(): - arr_1, arr_2 = files_1.value, files_2.value - if len(arr_1) == 0 and len(arr_2) == 0: - logger.error("No folder is selected for analysis") - raise Exception("No folder is selected for analysis") - - abspath = [] - if len(arr_1) > 0: - abspath = checkSameLocation(arr_1, abspath) - else: - abspath = checkSameLocation(arr_2, abspath) - - abspath = np.unique(abspath) - if len(abspath) > 1: - logger.error("All the folders selected should be at the same location") - raise Exception("All the folders selected should be at the same location") - return abspath - - def onclickProcess(event=None): - - logger.debug("Saving Input Parameters file.") - abspath = getAbsPath() - analysisParameters = { - "combine_data": combine_data.value, - "isosbestic_control": isosbestic_control.value, - "timeForLightsTurnOn": timeForLightsTurnOn.value, - "filter_window": moving_avg_filter.value, - "removeArtifacts": removeArtifacts.value, - "noChannels": no_channels_np.value, - "zscore_method": z_score_computation.value, - "baselineWindowStart": baseline_wd_strt.value, - "baselineWindowEnd": baseline_wd_end.value, - "nSecPrev": nSecPrev.value, - "nSecPost": nSecPost.value, - "timeInterval": timeInterval.value, - "bin_psth_trials": bin_psth_trials.value, - "use_time_or_trials": use_time_or_trials.value, - "baselineCorrectionStart": baselineCorrectionStart.value, - "baselineCorrectionEnd": baselineCorrectionEnd.value, - "peak_startPoint": list(df_widget.value["Peak Start time"]), # startPoint.value, - "peak_endPoint": list(df_widget.value["Peak End time"]), # endPoint.value, - "selectForComputePsth": computePsth.value, - "selectForTransientsComputation": transients.value, - "moving_window": moving_wd.value, - "highAmpFilt": highAmpFilt.value, - "transientsThresh": transientsThresh.value, - } - for folder in files_1.value: - with open(os.path.join(folder, "GuPPyParamtersUsed.json"), "w") as f: - json.dump(analysisParameters, f, indent=4) - logger.info(f"Input Parameters file saved at {folder}") - - logger.info("#" * 400) - - # path.value = (os.path.join(op, 'inputParameters.json')).replace('\\', '/') - logger.info("Input Parameters File Saved.") - - def onclickStoresList(event=None): - inputParameters = getInputParameters() - execute(inputParameters) - - def onclickVisualization(event=None): - inputParameters = getInputParameters() - visualizeResults(inputParameters) - - def onclickreaddata(event=None): - thread = Thread(target=readRawData) - thread.start() - readPBIncrementValues(read_progress) - thread.join() - - def onclickextractts(event=None): - thread = Thread(target=extractTs) - thread.start() - readPBIncrementValues(extract_progress) - thread.join() - - def onclickpsth(event=None): - thread = Thread(target=psthComputation) - thread.start() - readPBIncrementValues(psth_progress) - thread.join() - - mark_down_ip = pn.pane.Markdown("""**Step 1 : Save Input Parameters**""", width=300) - mark_down_ip_note = pn.pane.Markdown( - """***Note : ***
- - Save Input Parameters will save input parameters used for the analysis - in all the folders you selected for the analysis (useful for future - reference). All analysis steps will run without saving input parameters. - """, - width=300, - ) - save_button = pn.widgets.Button(name="Save to file...", button_type="primary", width=300, align="end") - mark_down_storenames = pn.pane.Markdown("""**Step 2 : Open Storenames GUI
and save storenames**""", width=300) - open_storesList = pn.widgets.Button(name="Open Storenames GUI", button_type="primary", width=300, align="end") - mark_down_read = pn.pane.Markdown("""**Step 3 : Read Raw Data**""", width=300) - read_rawData = pn.widgets.Button(name="Read Raw Data", button_type="primary", width=300, align="end") - mark_down_extract = pn.pane.Markdown("""**Step 4 : Extract timestamps
and its correction**""", width=300) - extract_ts = pn.widgets.Button( - name="Extract timestamps and it's correction", button_type="primary", width=300, align="end" - ) - mark_down_psth = pn.pane.Markdown("""**Step 5 : PSTH Computation**""", width=300) - psth_computation = pn.widgets.Button(name="PSTH Computation", button_type="primary", width=300, align="end") - mark_down_visualization = pn.pane.Markdown("""**Step 6 : Visualization**""", width=300) - open_visualization = pn.widgets.Button(name="Open Visualization GUI", button_type="primary", width=300, align="end") - open_terminal = pn.widgets.Button(name="Open Terminal", button_type="primary", width=300, align="end") - - save_button.on_click(onclickProcess) - open_storesList.on_click(onclickStoresList) - read_rawData.on_click(onclickreaddata) - extract_ts.on_click(onclickextractts) - psth_computation.on_click(onclickpsth) - open_visualization.on_click(onclickVisualization) - - template.sidebar.append(mark_down_ip) - template.sidebar.append(mark_down_ip_note) - template.sidebar.append(save_button) - # template.sidebar.append(path) - template.sidebar.append(mark_down_storenames) - template.sidebar.append(open_storesList) - template.sidebar.append(mark_down_read) - template.sidebar.append(read_rawData) - template.sidebar.append(read_progress) - template.sidebar.append(mark_down_extract) - template.sidebar.append(extract_ts) - template.sidebar.append(extract_progress) - template.sidebar.append(mark_down_psth) - template.sidebar.append(psth_computation) - template.sidebar.append(psth_progress) - template.sidebar.append(mark_down_visualization) - template.sidebar.append(open_visualization) - # template.sidebar.append(open_terminal) - - psth_baseline_param = pn.Column(zscore_param_wd, psth_param_wd, baseline_param_wd, peak_param_wd) - - widget = pn.Column( - mark_down_1, files_1, explain_modality, modality_selector, pn.Row(individual_analysis_wd_2, psth_baseline_param) - ) - - # file_selector = pn.WidgetBox(files_1) - styles = dict(background="WhiteSmoke") - individual = pn.Card(widget, title="Individual Analysis", styles=styles, width=1000) - group = pn.Card(group_analysis_wd_1, title="Group Analysis", styles=styles, width=1000) - visualize = pn.Card(visualization_wd, title="Visualization Parameters", styles=styles, width=1000) - - # template.main.append(file_selector) - template.main.append(individual) - template.main.append(group) - template.main.append(visualize) - - # Expose minimal hooks and widgets to enable programmatic testing - template._hooks = { - "onclickProcess": onclickProcess, - "getInputParameters": getInputParameters, - } - template._widgets = { - "files_1": files_1, - } - - return template diff --git a/src/guppy/testing/api.py b/src/guppy/testing/api.py index 98939cf..834a9e1 100644 --- a/src/guppy/testing/api.py +++ b/src/guppy/testing/api.py @@ -13,12 +13,12 @@ import os from typing import Iterable -from guppy.computePsth import psthForEachStorename -from guppy.findTransientsFreqAndAmp import executeFindFreqAndAmp -from guppy.preprocess import extractTsAndSignal -from guppy.readTevTsq import readRawData -from guppy.saveStoresList import execute -from guppy.savingInputParameters import savingInputParameters +from guppy.orchestration.home import build_homepage +from guppy.orchestration.preprocess import extractTsAndSignal +from guppy.orchestration.psth import psthForEachStorename +from guppy.orchestration.read_raw_data import orchestrate_read_raw_data +from guppy.orchestration.storenames import orchestrate_storenames_page +from guppy.orchestration.transients import executeFindFreqAndAmp def step1(*, base_dir: str, selected_folders: Iterable[str]) -> None: @@ -50,7 +50,7 @@ def step1(*, base_dir: str, selected_folders: Iterable[str]) -> None: os.environ["GUPPY_BASE_DIR"] = base_dir # Build the template headlessly - template = savingInputParameters() + template = build_homepage() # Sanity checks: ensure hooks/widgets exposed if not hasattr(template, "_hooks") or "onclickProcess" not in template._hooks: @@ -144,7 +144,7 @@ def step2( # Headless build: set base_dir and construct the template os.environ["GUPPY_BASE_DIR"] = base_dir - template = savingInputParameters() + template = build_homepage() # Ensure hooks/widgets exposed if not hasattr(template, "_hooks") or "getInputParameters" not in template._hooks: @@ -168,7 +168,7 @@ def step2( input_params["npm_split_events"] = npm_split_events # Call the underlying Step 2 executor (now headless-aware) - execute(input_params) + orchestrate_storenames_page(input_params) def step3( @@ -236,7 +236,7 @@ def step3( # Headless build: set base_dir and construct the template os.environ["GUPPY_BASE_DIR"] = base_dir - template = savingInputParameters() + template = build_homepage() # Ensure hooks/widgets exposed if not hasattr(template, "_hooks") or "getInputParameters" not in template._hooks: @@ -257,7 +257,7 @@ def step3( input_params["modality"] = modality # Call the underlying Step 3 worker directly (no subprocess) - readRawData(input_params) + orchestrate_read_raw_data(input_params) def step4( @@ -328,7 +328,7 @@ def step4( # Headless build: set base_dir and construct the template os.environ["GUPPY_BASE_DIR"] = base_dir - template = savingInputParameters() + template = build_homepage() # Ensure hooks/widgets exposed if not hasattr(template, "_hooks") or "getInputParameters" not in template._hooks: @@ -420,7 +420,7 @@ def step5( # Headless build: set base_dir and construct the template os.environ["GUPPY_BASE_DIR"] = base_dir - template = savingInputParameters() + template = build_homepage() # Ensure hooks/widgets exposed if not hasattr(template, "_hooks") or "getInputParameters" not in template._hooks: diff --git a/src/guppy/utils/__init__.py b/src/guppy/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/guppy/utils/utils.py b/src/guppy/utils/utils.py new file mode 100644 index 0000000..7f7bb17 --- /dev/null +++ b/src/guppy/utils/utils.py @@ -0,0 +1,43 @@ +import logging +import os +import re + +import numpy as np +import pandas as pd + +logger = logging.getLogger(__name__) + + +def takeOnlyDirs(paths): + removePaths = [] + for p in paths: + if os.path.isfile(p): + removePaths.append(p) + return list(set(paths) - set(removePaths)) + + +def get_all_stores_for_combining_data(folderNames): + op = [] + for i in range(100): + temp = [] + match = r"[\s\S]*" + "_output_" + str(i) + for j in folderNames: + temp.append(re.findall(match, j)) + temp = sorted(list(np.concatenate(temp).flatten()), key=str.casefold) + if len(temp) > 0: + op.append(temp) + + return op + + +# function to read h5 file and make a dataframe from it +def read_Df(filepath, event, name): + event = event.replace("\\", "_") + event = event.replace("/", "_") + if name: + op = os.path.join(filepath, event + "_{}.h5".format(name)) + else: + op = os.path.join(filepath, event + ".h5") + df = pd.read_hdf(op, key="df", mode="r") + + return df diff --git a/src/guppy/visualization/__init__.py b/src/guppy/visualization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/guppy/visualization/preprocessing.py b/src/guppy/visualization/preprocessing.py new file mode 100644 index 0000000..cda8150 --- /dev/null +++ b/src/guppy/visualization/preprocessing.py @@ -0,0 +1,43 @@ +import logging +import os + +import matplotlib.pyplot as plt + +logger = logging.getLogger(__name__) + +# Only set matplotlib backend if not in CI environment +if not os.getenv("CI"): + plt.switch_backend("TKAgg") + + +def visualize_preprocessing(*, suptitle, title, x, y): + fig = plt.figure() + ax = fig.add_subplot(111) + ax.plot(x, y) + ax.set_title(title) + fig.suptitle(suptitle) + + return fig, ax + + +def visualize_control_signal_fit(x, y1, y2, y3, plot_name, name, artifacts_have_been_removed): + fig = plt.figure() + ax1 = fig.add_subplot(311) + (line1,) = ax1.plot(x, y1) + ax1.set_title(plot_name[0]) + ax2 = fig.add_subplot(312) + (line2,) = ax2.plot(x, y2) + ax2.set_title(plot_name[1]) + ax3 = fig.add_subplot(313) + (line3,) = ax3.plot(x, y2) + (line3,) = ax3.plot(x, y3) + ax3.set_title(plot_name[2]) + fig.suptitle(name) + + hfont = {"fontname": "DejaVu Sans"} + + if artifacts_have_been_removed: + ax3.set_xlabel("Time(s) \n Note : Artifacts have been removed, but are not reflected in this plot.", **hfont) + else: + ax3.set_xlabel("Time(s)", **hfont) + return fig, ax1, ax2, ax3 diff --git a/src/guppy/visualization/transients.py b/src/guppy/visualization/transients.py new file mode 100644 index 0000000..030ac5e --- /dev/null +++ b/src/guppy/visualization/transients.py @@ -0,0 +1,15 @@ +import logging + +import matplotlib.pyplot as plt + +logger = logging.getLogger(__name__) + + +def visualize_peaks(title, suptitle, z_score, timestamps, peaksIndex): + fig = plt.figure() + ax = fig.add_subplot(111) + ax.plot(timestamps, z_score, "-", timestamps[peaksIndex], z_score[peaksIndex], "o") + ax.set_title(title) + fig.suptitle(suptitle) + + return fig, ax diff --git a/src/guppy/visualizePlot.py b/src/guppy/visualizePlot.py deleted file mode 100755 index 929149e..0000000 --- a/src/guppy/visualizePlot.py +++ /dev/null @@ -1,936 +0,0 @@ -import glob -import logging -import math -import os -import re -import socket -from random import randint - -import datashader as ds -import holoviews as hv -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import panel as pn -import param -from bokeh.io import export_png, export_svgs -from holoviews import opts -from holoviews.operation.datashader import datashade -from holoviews.plotting.util import process_cmap - -from .preprocess import get_all_stores_for_combining_data - -pn.extension() - -logger = logging.getLogger(__name__) - - -def scanPortsAndFind(start_port=5000, end_port=5200, host="127.0.0.1"): - while True: - port = randint(start_port, end_port) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(0.001) # Set timeout to avoid long waiting on closed ports - result = sock.connect_ex((host, port)) - if result == 0: # If the connection is successful, the port is open - continue - else: - break - - return port - - -def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths) - set(removePaths)) - - -# read h5 file as a dataframe -def read_Df(filepath, event, name): - event = event.replace("\\", "_") - event = event.replace("/", "_") - if name: - op = os.path.join(filepath, event + "_{}.h5".format(name)) - else: - op = os.path.join(filepath, event + ".h5") - df = pd.read_hdf(op, key="df", mode="r") - - return df - - -# make a new directory for saving plots -def make_dir(filepath): - op = os.path.join(filepath, "saved_plots") - if not os.path.exists(op): - os.mkdir(op) - - return op - - -# remove unnecessary column names -def remove_cols(cols): - regex = re.compile("bin_err_*") - remove_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] - remove_cols = remove_cols + ["err", "timestamps"] - cols = [i for i in cols if i not in remove_cols] - - return cols - - -# def look_psth_bins(event, name): - - -# helper function to create plots -def helper_plots(filepath, event, name, inputParameters): - - basename = os.path.basename(filepath) - visualize_zscore_or_dff = inputParameters["visualize_zscore_or_dff"] - - # note when there are no behavior event TTLs - if len(event) == 0: - logger.warning("\033[1m" + "There are no behavior event TTLs present to visualize.".format(event) + "\033[0m") - return 0 - - if os.path.exists(os.path.join(filepath, "cross_correlation_output")): - event_corr, frames = [], [] - if visualize_zscore_or_dff == "z_score": - corr_fp = glob.glob(os.path.join(filepath, "cross_correlation_output", "*_z_score_*")) - elif visualize_zscore_or_dff == "dff": - corr_fp = glob.glob(os.path.join(filepath, "cross_correlation_output", "*_dff_*")) - for i in range(len(corr_fp)): - filename = os.path.basename(corr_fp[i]).split(".")[0] - event_corr.append(filename) - df = pd.read_hdf(corr_fp[i], key="df", mode="r") - frames.append(df) - if len(frames) > 0: - df_corr = pd.concat(frames, keys=event_corr, axis=1) - else: - event_corr = [] - df_corr = [] - else: - event_corr = [] - df_corr = None - - # combine all the event PSTH so that it can be viewed together - if name: - event_name, name = event, name - new_event, frames, bins = [], [], {} - for i in range(len(event_name)): - - for j in range(len(name)): - new_event.append(event_name[i] + "_" + name[j].split("_")[-1]) - new_name = name[j] - temp_df = read_Df(filepath, new_event[-1], new_name) - cols = list(temp_df.columns) - regex = re.compile("bin_[(]") - bins[new_event[-1]] = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] - # bins.append(keep_cols) - frames.append(temp_df) - - df = pd.concat(frames, keys=new_event, axis=1) - else: - new_event = list(np.unique(np.array(event))) - frames, bins = [], {} - for i in range(len(new_event)): - temp_df = read_Df(filepath, new_event[i], "") - cols = list(temp_df.columns) - regex = re.compile("bin_[(]") - bins[new_event[i]] = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] - frames.append(temp_df) - - df = pd.concat(frames, keys=new_event, axis=1) - - if isinstance(df_corr, pd.DataFrame): - new_event.extend(event_corr) - df = pd.concat([df, df_corr], axis=1, sort=False).reset_index() - - columns_dict = dict() - for i in range(len(new_event)): - df_1 = df[new_event[i]] - columns = list(df_1.columns) - columns.append("All") - columns_dict[new_event[i]] = columns - - # create a class to make GUI and plot different graphs - class Viewer(param.Parameterized): - - # class_event = new_event - - # make options array for different selectors - multiple_plots_options = [] - heatmap_options = new_event - bins_keys = list(bins.keys()) - if len(bins_keys) > 0: - bins_new = bins - for i in range(len(bins_keys)): - arr = bins[bins_keys[i]] - if len(arr) > 0: - # heatmap_options.append('{}_bin'.format(bins_keys[i])) - for j in arr: - multiple_plots_options.append("{}_{}".format(bins_keys[i], j)) - - multiple_plots_options = new_event + multiple_plots_options - else: - multiple_plots_options = new_event - - # create different options and selectors - event_selector = param.ObjectSelector(default=new_event[0], objects=new_event) - event_selector_heatmap = param.ObjectSelector(default=heatmap_options[0], objects=heatmap_options) - columns = columns_dict - df_new = df - - colormaps = plt.colormaps() - new_colormaps = ["plasma", "plasma_r", "magma", "magma_r", "inferno", "inferno_r", "viridis", "viridis_r"] - set_a = set(colormaps) - set_b = set(new_colormaps) - colormaps = new_colormaps + list(set_a.difference(set_b)) - - x_min = float(inputParameters["nSecPrev"]) - 20 - x_max = float(inputParameters["nSecPost"]) + 20 - selector_for_multipe_events_plot = param.ListSelector( - default=[multiple_plots_options[0]], objects=multiple_plots_options - ) - x = param.ObjectSelector(default=columns[new_event[0]][-4], objects=[columns[new_event[0]][-4]]) - y = param.ObjectSelector( - default=remove_cols(columns[new_event[0]])[-2], objects=remove_cols(columns[new_event[0]]) - ) - - trial_no = range(1, len(remove_cols(columns[heatmap_options[0]])[:-2]) + 1) - trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(columns[heatmap_options[0]])[:-2])] + [ - "All" - ] - heatmap_y = param.ListSelector(default=[trial_ts[-1]], objects=trial_ts) - psth_y = param.ListSelector(objects=trial_ts[:-1]) - select_trials_checkbox = param.ListSelector(default=["just trials"], objects=["mean", "just trials"]) - Y_Label = param.ObjectSelector(default="y", objects=["y", "z-score", "\u0394F/F"]) - save_options = param.ObjectSelector( - default="None", objects=["None", "save_png_format", "save_svg_format", "save_both_format"] - ) - save_options_heatmap = param.ObjectSelector( - default="None", objects=["None", "save_png_format", "save_svg_format", "save_both_format"] - ) - color_map = param.ObjectSelector(default="plasma", objects=colormaps) - height_heatmap = param.ObjectSelector(default=600, objects=list(np.arange(0, 5100, 100))[1:]) - width_heatmap = param.ObjectSelector(default=1000, objects=list(np.arange(0, 5100, 100))[1:]) - Height_Plot = param.ObjectSelector(default=300, objects=list(np.arange(0, 5100, 100))[1:]) - Width_Plot = param.ObjectSelector(default=1000, objects=list(np.arange(0, 5100, 100))[1:]) - save_hm = param.Action(lambda x: x.param.trigger("save_hm"), label="Save") - save_psth = param.Action(lambda x: x.param.trigger("save_psth"), label="Save") - X_Limit = param.Range(default=(-5, 10), bounds=(x_min, x_max)) - Y_Limit = param.Range(bounds=(-50, 50.0)) - - # C_Limit = param.Range(bounds=(-20,20.0)) - - results_hm = dict() - results_psth = dict() - - # function to save heatmaps when save button on heatmap tab is clicked - @param.depends("save_hm", watch=True) - def save_hm_plots(self): - plot = self.results_hm["plot"] - op = self.results_hm["op"] - save_opts = self.save_options_heatmap - logger.info(save_opts) - if save_opts == "save_svg_format": - p = hv.render(plot, backend="bokeh") - p.output_backend = "svg" - export_svgs(p, filename=op + ".svg") - elif save_opts == "save_png_format": - p = hv.render(plot, backend="bokeh") - export_png(p, filename=op + ".png") - elif save_opts == "save_both_format": - p = hv.render(plot, backend="bokeh") - p.output_backend = "svg" - export_svgs(p, filename=op + ".svg") - p_png = hv.render(plot, backend="bokeh") - export_png(p_png, filename=op + ".png") - else: - return 0 - - # function to save PSTH plots when save button on PSTH tab is clicked - @param.depends("save_psth", watch=True) - def save_psth_plot(self): - plot, op = [], [] - plot.append(self.results_psth["plot_combine"]) - op.append(self.results_psth["op_combine"]) - plot.append(self.results_psth["plot"]) - op.append(self.results_psth["op"]) - for i in range(len(plot)): - temp_plot, temp_op = plot[i], op[i] - save_opts = self.save_options - if save_opts == "save_svg_format": - p = hv.render(temp_plot, backend="bokeh") - p.output_backend = "svg" - export_svgs(p, filename=temp_op + ".svg") - elif save_opts == "save_png_format": - p = hv.render(temp_plot, backend="bokeh") - export_png(p, filename=temp_op + ".png") - elif save_opts == "save_both_format": - p = hv.render(temp_plot, backend="bokeh") - p.output_backend = "svg" - export_svgs(p, filename=temp_op + ".svg") - p_png = hv.render(temp_plot, backend="bokeh") - export_png(p_png, filename=temp_op + ".png") - else: - return 0 - - # function to change Y values based on event selection - @param.depends("event_selector", watch=True) - def _update_x_y(self): - x_value = self.columns[self.event_selector] - y_value = self.columns[self.event_selector] - self.param["x"].objects = [x_value[-4]] - self.param["y"].objects = remove_cols(y_value) - self.x = x_value[-4] - self.y = self.param["y"].objects[-2] - - @param.depends("event_selector_heatmap", watch=True) - def _update_df(self): - cols = self.columns[self.event_selector_heatmap] - trial_no = range(1, len(remove_cols(cols)[:-2]) + 1) - trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(cols)[:-2])] + ["All"] - self.param["heatmap_y"].objects = trial_ts - self.heatmap_y = [trial_ts[-1]] - - @param.depends("event_selector", watch=True) - def _update_psth_y(self): - cols = self.columns[self.event_selector] - trial_no = range(1, len(remove_cols(cols)[:-2]) + 1) - trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(cols)[:-2])] - self.param["psth_y"].objects = trial_ts - self.psth_y = [trial_ts[0]] - - # function to plot multiple PSTHs into one plot - - @param.depends( - "selector_for_multipe_events_plot", - "Y_Label", - "save_options", - "X_Limit", - "Y_Limit", - "Height_Plot", - "Width_Plot", - ) - def update_selector(self): - data_curve, cols_curve, data_spread, cols_spread = [], [], [], [] - arr = self.selector_for_multipe_events_plot - df1 = self.df_new - for i in range(len(arr)): - if "bin" in arr[i]: - split = arr[i].rsplit("_", 2) - df_name = split[0] #'{}_{}'.format(split[0], split[1]) - col_name_mean = "{}_{}".format(split[-2], split[-1]) - col_name_err = "{}_err_{}".format(split[-2], split[-1]) - data_curve.append(df1[df_name][col_name_mean]) - cols_curve.append(arr[i]) - data_spread.append(df1[df_name][col_name_err]) - cols_spread.append(arr[i]) - else: - data_curve.append(df1[arr[i]]["mean"]) - cols_curve.append(arr[i] + "_" + "mean") - data_spread.append(df1[arr[i]]["err"]) - cols_spread.append(arr[i] + "_" + "mean") - - if len(arr) > 0: - if self.Y_Limit == None: - self.Y_Limit = (np.nanmin(np.asarray(data_curve)) - 0.5, np.nanmax(np.asarray(data_curve)) + 0.5) - - if "bin" in arr[i]: - split = arr[i].rsplit("_", 2) - df_name = split[0] - data_curve.append(df1[df_name]["timestamps"]) - cols_curve.append("timestamps") - data_spread.append(df1[df_name]["timestamps"]) - cols_spread.append("timestamps") - else: - data_curve.append(df1[arr[i]]["timestamps"]) - cols_curve.append("timestamps") - data_spread.append(df1[arr[i]]["timestamps"]) - cols_spread.append("timestamps") - df_curve = pd.concat(data_curve, axis=1) - df_spread = pd.concat(data_spread, axis=1) - df_curve.columns = cols_curve - df_spread.columns = cols_spread - - ts = df_curve["timestamps"] - index = np.arange(0, ts.shape[0], 3) - df_curve = df_curve.loc[index, :] - df_spread = df_spread.loc[index, :] - overlay = hv.NdOverlay( - { - c: hv.Curve((df_curve["timestamps"], df_curve[c]), kdims=["Time (s)"]).opts( - width=int(self.Width_Plot), - height=int(self.Height_Plot), - xlim=self.X_Limit, - ylim=self.Y_Limit, - ) - for c in cols_curve[:-1] - } - ) - spread = hv.NdOverlay( - { - d: hv.Spread( - (df_spread["timestamps"], df_curve[d], df_spread[d], df_spread[d]), - vdims=["y", "yerrpos", "yerrneg"], - ).opts(line_width=0, fill_alpha=0.3) - for d in cols_spread[:-1] - } - ) - plot_combine = ((overlay * spread).opts(opts.NdOverlay(xlabel="Time (s)", ylabel=self.Y_Label))).opts( - shared_axes=False - ) - # plot_err = new_df.hvplot.area(x='timestamps', y=[], y2=[]) - save_opts = self.save_options - op = make_dir(filepath) - op_filename = os.path.join(op, str(arr) + "_mean") - - self.results_psth["plot_combine"] = plot_combine - self.results_psth["op_combine"] = op_filename - # self.save_plots(plot_combine, save_opts, op_filename) - return plot_combine - - # function to plot mean PSTH, single trial in PSTH and all the trials of PSTH with mean - @param.depends( - "event_selector", "x", "y", "Y_Label", "save_options", "Y_Limit", "X_Limit", "Height_Plot", "Width_Plot" - ) - def contPlot(self): - df1 = self.df_new[self.event_selector] - # height = self.Heigth_Plot - # width = self.Width_Plot - # logger.info(height, width) - if self.y == "All": - if self.Y_Limit == None: - self.Y_Limit = (np.nanmin(np.asarray(df1)) - 0.5, np.nanmax(np.asarray(df1)) - 0.5) - - options = self.param["y"].objects - regex = re.compile("bin_[(]") - remove_bin_trials = [options[i] for i in range(len(options)) if not regex.match(options[i])] - - ndoverlay = hv.NdOverlay({c: hv.Curve((df1[self.x], df1[c])) for c in remove_bin_trials[:-2]}) - img1 = datashade(ndoverlay, normalization="linear", aggregator=ds.count()) - x_points = df1[self.x] - y_points = df1["mean"] - img2 = hv.Curve((x_points, y_points)) - img = (img1 * img2).opts( - opts.Curve( - width=int(self.Width_Plot), - height=int(self.Height_Plot), - line_width=4, - color="black", - xlim=self.X_Limit, - ylim=self.Y_Limit, - xlabel="Time (s)", - ylabel=self.Y_Label, - ) - ) - - save_opts = self.save_options - - op = make_dir(filepath) - op_filename = os.path.join(op, self.event_selector + "_" + self.y) - self.results_psth["plot"] = img - self.results_psth["op"] = op_filename - # self.save_plots(img, save_opts, op_filename) - - return img - - elif self.y == "mean" or "bin" in self.y: - - xpoints = df1[self.x] - ypoints = df1[self.y] - if self.y == "mean": - err = df1["err"] - else: - split = self.y.split("_") - err = df1["{}_err_{}".format(split[0], split[1])] - - index = np.arange(0, xpoints.shape[0], 3) - - if self.Y_Limit == None: - self.Y_Limit = (np.nanmin(ypoints) - 0.5, np.nanmax(ypoints) + 0.5) - - ropts_curve = dict( - width=int(self.Width_Plot), - height=int(self.Height_Plot), - xlim=self.X_Limit, - ylim=self.Y_Limit, - color="blue", - xlabel="Time (s)", - ylabel=self.Y_Label, - ) - ropts_spread = dict( - width=int(self.Width_Plot), - height=int(self.Height_Plot), - fill_alpha=0.3, - fill_color="blue", - line_width=0, - ) - - plot_curve = hv.Curve((xpoints[index], ypoints[index])) # .opts(**ropts_curve) - plot_spread = hv.Spread( - (xpoints[index], ypoints[index], err[index], err[index]) - ) # .opts(**ropts_spread) #vdims=['y', 'yerrpos', 'yerrneg'] - plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread}) - - save_opts = self.save_options - op = make_dir(filepath) - op_filename = os.path.join(op, self.event_selector + "_" + self.y) - self.results_psth["plot"] = plot - self.results_psth["op"] = op_filename - # self.save_plots(plot, save_opts, op_filename) - - return plot - - else: - xpoints = df1[self.x] - ypoints = df1[self.y] - if self.Y_Limit == None: - self.Y_Limit = (np.nanmin(ypoints) - 0.5, np.nanmax(ypoints) + 0.5) - - ropts_curve = dict( - width=int(self.Width_Plot), - height=int(self.Height_Plot), - xlim=self.X_Limit, - ylim=self.Y_Limit, - color="blue", - xlabel="Time (s)", - ylabel=self.Y_Label, - ) - plot = hv.Curve((xpoints, ypoints)).opts({"Curve": ropts_curve}) - - save_opts = self.save_options - op = make_dir(filepath) - op_filename = os.path.join(op, self.event_selector + "_" + self.y) - self.results_psth["plot"] = plot - self.results_psth["op"] = op_filename - # self.save_plots(plot, save_opts, op_filename) - - return plot - - # function to plot specific PSTH trials - @param.depends( - "event_selector", - "x", - "psth_y", - "select_trials_checkbox", - "Y_Label", - "save_options", - "Y_Limit", - "X_Limit", - "Height_Plot", - "Width_Plot", - ) - def plot_specific_trials(self): - df_psth = self.df_new[self.event_selector] - # if self.Y_Limit==None: - # self.Y_Limit = (np.nanmin(ypoints)-0.5, np.nanmax(ypoints)+0.5) - - if self.psth_y == None: - return None - else: - selected_trials = [s.split(" - ")[1] for s in list(self.psth_y)] - - index = np.arange(0, df_psth["timestamps"].shape[0], 3) - - if self.select_trials_checkbox == ["just trials"]: - overlay = hv.NdOverlay( - { - c: hv.Curve((df_psth["timestamps"][index], df_psth[c][index]), kdims=["Time (s)"]) - for c in selected_trials - } - ) - ropts = dict( - width=int(self.Width_Plot), - height=int(self.Height_Plot), - xlim=self.X_Limit, - ylim=self.Y_Limit, - xlabel="Time (s)", - ylabel=self.Y_Label, - ) - return overlay.opts(**ropts) - elif self.select_trials_checkbox == ["mean"]: - arr = np.asarray(df_psth[selected_trials]) - mean = np.nanmean(arr, axis=1) - err = np.nanstd(arr, axis=1) / math.sqrt(arr.shape[1]) - ropts_curve = dict( - width=int(self.Width_Plot), - height=int(self.Height_Plot), - xlim=self.X_Limit, - ylim=self.Y_Limit, - color="blue", - xlabel="Time (s)", - ylabel=self.Y_Label, - ) - ropts_spread = dict( - width=int(self.Width_Plot), - height=int(self.Height_Plot), - fill_alpha=0.3, - fill_color="blue", - line_width=0, - ) - plot_curve = hv.Curve((df_psth["timestamps"][index], mean[index])) - plot_spread = hv.Spread((df_psth["timestamps"][index], mean[index], err[index], err[index])) - plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread}) - return plot - elif self.select_trials_checkbox == ["mean", "just trials"]: - overlay = hv.NdOverlay( - { - c: hv.Curve((df_psth["timestamps"][index], df_psth[c][index]), kdims=["Time (s)"]) - for c in selected_trials - } - ) - ropts_overlay = dict( - width=int(self.Width_Plot), - height=int(self.Height_Plot), - xlim=self.X_Limit, - ylim=self.Y_Limit, - xlabel="Time (s)", - ylabel=self.Y_Label, - ) - - arr = np.asarray(df_psth[selected_trials]) - mean = np.nanmean(arr, axis=1) - err = np.nanstd(arr, axis=1) / math.sqrt(arr.shape[1]) - ropts_curve = dict( - width=int(self.Width_Plot), - height=int(self.Height_Plot), - xlim=self.X_Limit, - ylim=self.Y_Limit, - color="black", - xlabel="Time (s)", - ylabel=self.Y_Label, - ) - ropts_spread = dict( - width=int(self.Width_Plot), - height=int(self.Height_Plot), - fill_alpha=0.3, - fill_color="black", - line_width=0, - ) - plot_curve = hv.Curve((df_psth["timestamps"][index], mean[index])) - plot_spread = hv.Spread((df_psth["timestamps"][index], mean[index], err[index], err[index])) - - plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread}) - return overlay.opts(**ropts_overlay) * plot - - # function to show heatmaps for each event - @param.depends("event_selector_heatmap", "color_map", "height_heatmap", "width_heatmap", "heatmap_y") - def heatmap(self): - height = self.height_heatmap - width = self.width_heatmap - df_hm = self.df_new[self.event_selector_heatmap] - cols = list(df_hm.columns) - regex = re.compile("bin_err_*") - drop_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] - drop_cols = ["err", "mean"] + drop_cols - df_hm = df_hm.drop(drop_cols, axis=1) - cols = list(df_hm.columns) - bin_cols = [cols[i] for i in range(len(cols)) if re.compile("bin_*").match(cols[i])] - time = np.asarray(df_hm["timestamps"]) - event_ts_for_each_event = np.arange(1, len(df_hm.columns[:-1]) + 1) - yticks = list(event_ts_for_each_event) - z_score = np.asarray(df_hm[df_hm.columns[:-1]]).T - - if self.heatmap_y[0] == "All": - indices = np.arange(z_score.shape[0] - len(bin_cols)) - z_score = z_score[indices, :] - event_ts_for_each_event = np.arange(1, z_score.shape[0] + 1) - yticks = list(event_ts_for_each_event) - else: - remove_all = list(set(self.heatmap_y) - set(["All"])) - indices = sorted([int(s.split("-")[0]) - 1 for s in remove_all]) - z_score = z_score[indices, :] - event_ts_for_each_event = np.arange(1, z_score.shape[0] + 1) - yticks = list(event_ts_for_each_event) - - clim = (np.nanmin(z_score), np.nanmax(z_score)) - font_size = {"labels": 16, "yticks": 6} - - if event_ts_for_each_event.shape[0] == 1: - dummy_image = hv.QuadMesh((time, event_ts_for_each_event, z_score)).opts(colorbar=True, clim=clim) - image = ( - (dummy_image).opts( - opts.QuadMesh( - width=int(width), - height=int(height), - cmap=process_cmap(self.color_map, provider="matplotlib"), - colorbar=True, - ylabel="Trials", - xlabel="Time (s)", - fontsize=font_size, - yticks=yticks, - ) - ) - ).opts(shared_axes=False) - - save_opts = self.save_options_heatmap - op = make_dir(filepath) - op_filename = os.path.join(op, self.event_selector_heatmap + "_" + "heatmap") - self.results_hm["plot"] = image - self.results_hm["op"] = op_filename - # self.save_plots(image, save_opts, op_filename) - return image - else: - ropts = dict( - width=int(width), - height=int(height), - ylabel="Trials", - xlabel="Time (s)", - fontsize=font_size, - yticks=yticks, - invert_yaxis=True, - ) - dummy_image = hv.QuadMesh((time[0:100], event_ts_for_each_event, z_score[:, 0:100])).opts( - colorbar=True, cmap=process_cmap(self.color_map, provider="matplotlib"), clim=clim - ) - actual_image = hv.QuadMesh((time, event_ts_for_each_event, z_score)) - - dynspread_img = datashade(actual_image, cmap=process_cmap(self.color_map, provider="matplotlib")).opts( - **ropts - ) # clims=self.C_Limit, cnorm='log' - image = ((dummy_image * dynspread_img).opts(opts.QuadMesh(width=int(width), height=int(height)))).opts( - shared_axes=False - ) - - save_opts = self.save_options_heatmap - op = make_dir(filepath) - op_filename = os.path.join(op, self.event_selector_heatmap + "_" + "heatmap") - self.results_hm["plot"] = image - self.results_hm["op"] = op_filename - - return image - - view = Viewer() - - # PSTH plot options - psth_checkbox = pn.Param( - view.param.select_trials_checkbox, - widgets={ - "select_trials_checkbox": { - "type": pn.widgets.CheckBoxGroup, - "inline": True, - "name": "Select mean and/or just trials", - } - }, - ) - parameters = pn.Param( - view.param.selector_for_multipe_events_plot, - widgets={ - "selector_for_multipe_events_plot": {"type": pn.widgets.CrossSelector, "width": 550, "align": "start"} - }, - ) - heatmap_y_parameters = pn.Param( - view.param.heatmap_y, - widgets={ - "heatmap_y": {"type": pn.widgets.MultiSelect, "name": "Trial # - Timestamps", "width": 200, "size": 30} - }, - ) - psth_y_parameters = pn.Param( - view.param.psth_y, - widgets={ - "psth_y": { - "type": pn.widgets.MultiSelect, - "name": "Trial # - Timestamps", - "width": 200, - "size": 15, - "align": "start", - } - }, - ) - - event_selector = pn.Param( - view.param.event_selector, widgets={"event_selector": {"type": pn.widgets.Select, "width": 400}} - ) - x_selector = pn.Param(view.param.x, widgets={"x": {"type": pn.widgets.Select, "width": 180}}) - y_selector = pn.Param(view.param.y, widgets={"y": {"type": pn.widgets.Select, "width": 180}}) - - width_plot = pn.Param(view.param.Width_Plot, widgets={"Width_Plot": {"type": pn.widgets.Select, "width": 70}}) - height_plot = pn.Param(view.param.Height_Plot, widgets={"Height_Plot": {"type": pn.widgets.Select, "width": 70}}) - ylabel = pn.Param(view.param.Y_Label, widgets={"Y_Label": {"type": pn.widgets.Select, "width": 70}}) - save_opts = pn.Param(view.param.save_options, widgets={"save_options": {"type": pn.widgets.Select, "width": 70}}) - - xlimit_plot = pn.Param(view.param.X_Limit, widgets={"X_Limit": {"type": pn.widgets.RangeSlider, "width": 180}}) - ylimit_plot = pn.Param(view.param.Y_Limit, widgets={"Y_Limit": {"type": pn.widgets.RangeSlider, "width": 180}}) - save_psth = pn.Param(view.param.save_psth, widgets={"save_psth": {"type": pn.widgets.Button, "width": 400}}) - - options = pn.Column( - event_selector, - pn.Row(x_selector, y_selector), - pn.Row(xlimit_plot, ylimit_plot), - pn.Row(width_plot, height_plot, ylabel, save_opts), - save_psth, - ) - - options_selectors = pn.Row(options, parameters) - - line_tab = pn.Column( - "## " + basename, - pn.Row(options_selectors, pn.Column(psth_checkbox, psth_y_parameters), width=1200), - view.contPlot, - view.update_selector, - view.plot_specific_trials, - ) - - # Heatmap plot options - event_selector_heatmap = pn.Param( - view.param.event_selector_heatmap, widgets={"event_selector_heatmap": {"type": pn.widgets.Select, "width": 150}} - ) - color_map = pn.Param(view.param.color_map, widgets={"color_map": {"type": pn.widgets.Select, "width": 150}}) - width_heatmap = pn.Param( - view.param.width_heatmap, widgets={"width_heatmap": {"type": pn.widgets.Select, "width": 150}} - ) - height_heatmap = pn.Param( - view.param.height_heatmap, widgets={"height_heatmap": {"type": pn.widgets.Select, "width": 150}} - ) - save_hm = pn.Param(view.param.save_hm, widgets={"save_hm": {"type": pn.widgets.Button, "width": 150}}) - save_options_heatmap = pn.Param( - view.param.save_options_heatmap, widgets={"save_options_heatmap": {"type": pn.widgets.Select, "width": 150}} - ) - - hm_tab = pn.Column( - "## " + basename, - pn.Row( - event_selector_heatmap, - color_map, - width_heatmap, - height_heatmap, - save_options_heatmap, - pn.Column(pn.Spacer(height=25), save_hm), - ), - pn.Row(view.heatmap, heatmap_y_parameters), - ) # - logger.info("app") - - template = pn.template.MaterialTemplate(title="Visualization GUI") - - number = scanPortsAndFind(start_port=5000, end_port=5200) - app = pn.Tabs(("PSTH", line_tab), ("Heat Map", hm_tab)) - - template.main.append(app) - - template.show(port=number) - - -# function to combine all the output folders together and preprocess them to use them in helper_plots function -def createPlots(filepath, event, inputParameters): - - for i in range(len(event)): - event[i] = event[i].replace("\\", "_") - event[i] = event[i].replace("/", "_") - - average = inputParameters["visualizeAverageResults"] - visualize_zscore_or_dff = inputParameters["visualize_zscore_or_dff"] - - if average == True: - path = [] - for i in range(len(event)): - if visualize_zscore_or_dff == "z_score": - path.append(glob.glob(os.path.join(filepath, event[i] + "*_z_score_*"))) - elif visualize_zscore_or_dff == "dff": - path.append(glob.glob(os.path.join(filepath, event[i] + "*_dff_*"))) - - path = np.concatenate(path) - else: - if visualize_zscore_or_dff == "z_score": - path = glob.glob(os.path.join(filepath, "z_score_*")) - elif visualize_zscore_or_dff == "dff": - path = glob.glob(os.path.join(filepath, "dff_*")) - - name_arr = [] - event_arr = [] - - index = [] - for i in range(len(event)): - if "control" in event[i].lower() or "signal" in event[i].lower(): - index.append(i) - - event = np.delete(event, index) - - for i in range(len(path)): - name = (os.path.basename(path[i])).split(".") - name = name[0] - name_arr.append(name) - - if average == True: - logger.info("average") - helper_plots(filepath, name_arr, "", inputParameters) - else: - helper_plots(filepath, event, name_arr, inputParameters) - - -def visualizeResults(inputParameters): - - inputParameters = inputParameters - - average = inputParameters["visualizeAverageResults"] - logger.info(average) - - folderNames = inputParameters["folderNames"] - folderNamesForAvg = inputParameters["folderNamesForAvg"] - combine_data = inputParameters["combine_data"] - - if average == True and len(folderNamesForAvg) > 0: - # folderNames = folderNamesForAvg - filepath_avg = os.path.join(inputParameters["abspath"], "average") - # filepath = os.path.join(inputParameters['abspath'], folderNames[0]) - storesListPath = [] - for i in range(len(folderNamesForAvg)): - filepath = folderNamesForAvg[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) - storesListPath = np.concatenate(storesListPath) - storesList = np.asarray([[], []]) - for i in range(storesListPath.shape[0]): - storesList = np.concatenate( - ( - storesList, - np.genfromtxt( - os.path.join(storesListPath[i], "storesList.csv"), dtype="str", delimiter="," - ).reshape(2, -1), - ), - axis=1, - ) - storesList = np.unique(storesList, axis=1) - - createPlots(filepath_avg, np.unique(storesList[1, :]), inputParameters) - - else: - if combine_data == True: - storesListPath = [] - for i in range(len(folderNames)): - filepath = folderNames[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) - storesListPath = list(np.concatenate(storesListPath).flatten()) - op = get_all_stores_for_combining_data(storesListPath) - for i in range(len(op)): - storesList = np.asarray([[], []]) - for j in range(len(op[i])): - storesList = np.concatenate( - ( - storesList, - np.genfromtxt(os.path.join(op[i][j], "storesList.csv"), dtype="str", delimiter=",").reshape( - 2, -1 - ), - ), - axis=1, - ) - storesList = np.unique(storesList, axis=1) - filepath = op[i][0] - createPlots(filepath, storesList[1, :], inputParameters) - else: - for i in range(len(folderNames)): - - filepath = folderNames[i] - storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) - for j in range(len(storesListPath)): - filepath = storesListPath[j] - storesList = np.genfromtxt( - os.path.join(filepath, "storesList.csv"), dtype="str", delimiter="," - ).reshape(2, -1) - - createPlots(filepath, storesList[1, :], inputParameters) - - -# logger.info(sys.argv[1:]) -# visualizeResults(sys.argv[1:][0])