diff --git a/src/guppy/analysis/io_utils.py b/src/guppy/analysis/io_utils.py
index 742ab3b..97f7ecf 100644
--- a/src/guppy/analysis/io_utils.py
+++ b/src/guppy/analysis/io_utils.py
@@ -6,17 +6,10 @@
import h5py
import numpy as np
-import pandas as pd
-
-logger = logging.getLogger(__name__)
+from ..utils.utils import takeOnlyDirs
-def takeOnlyDirs(paths):
- removePaths = []
- for p in paths:
- if os.path.isfile(p):
- removePaths.append(p)
- return list(set(paths) - set(removePaths))
+logger = logging.getLogger(__name__)
# find files by ignoring the case sensitivity
@@ -143,20 +136,6 @@ def get_coords(filepath, name, tsNew, removeArtifacts): # TODO: Make less redun
return coords
-def get_all_stores_for_combining_data(folderNames):
- op = []
- for i in range(100):
- temp = []
- match = r"[\s\S]*" + "_output_" + str(i)
- for j in folderNames:
- temp.append(re.findall(match, j))
- temp = sorted(list(np.concatenate(temp).flatten()), key=str.casefold)
- if len(temp) > 0:
- op.append(temp)
-
- return op
-
-
# for combining data, reading storeslist file from both data and create a new storeslist array
def check_storeslistfile(folderNames):
storesList = np.array([[], []])
@@ -197,19 +176,6 @@ def get_control_and_signal_channel_names(storesList):
return channels_arr
-# function to read h5 file and make a dataframe from it
-def read_Df(filepath, event, name):
- event = event.replace("\\", "_")
- event = event.replace("/", "_")
- if name:
- op = os.path.join(filepath, event + "_{}.h5".format(name))
- else:
- op = os.path.join(filepath, event + ".h5")
- df = pd.read_hdf(op, key="df", mode="r")
-
- return df
-
-
def make_dir_for_cross_correlation(filepath):
op = os.path.join(filepath, "cross_correlation_output")
if not os.path.exists(op):
diff --git a/src/guppy/analysis/psth_average.py b/src/guppy/analysis/psth_average.py
index 664cc3d..5df8c87 100644
--- a/src/guppy/analysis/psth_average.py
+++ b/src/guppy/analysis/psth_average.py
@@ -10,10 +10,10 @@
from .io_utils import (
make_dir_for_cross_correlation,
makeAverageDir,
- read_Df,
write_hdf5,
)
from .psth_utils import create_Df_for_psth, getCorrCombinations
+from ..utils.utils import read_Df
logger = logging.getLogger(__name__)
diff --git a/src/guppy/analysis/standard_io.py b/src/guppy/analysis/standard_io.py
index d6dd9af..d0ade6d 100644
--- a/src/guppy/analysis/standard_io.py
+++ b/src/guppy/analysis/standard_io.py
@@ -331,3 +331,18 @@ def read_freq_and_amp_from_hdf5(filepath, name):
df = pd.read_hdf(op, key="df", mode="r")
return df
+
+
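+# Transient outputs are stored per channel in a single HDF5 event named
+# "transient_outputs_<name>", holding the z-score trace, its timestamps, and
+# the indices of detected transient peaks.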
+def write_transients_to_hdf5(filepath, name, z_score, ts, peaksInd):
+ event = f"transient_outputs_{name}"
+ write_hdf5(z_score, event, filepath, "z_score")
+ write_hdf5(ts, event, filepath, "timestamps")
+ write_hdf5(peaksInd, event, filepath, "peaksInd")
+
+
+def read_transients_from_hdf5(filepath, name):
+ event = f"transient_outputs_{name}"
+ z_score = read_hdf5(event, filepath, "z_score")
+ ts = read_hdf5(event, filepath, "timestamps")
+ peaksInd = read_hdf5(event, filepath, "peaksInd")
+ return z_score, ts, peaksInd
diff --git a/src/guppy/frontend/__init__.py b/src/guppy/frontend/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/guppy/frontend/artifact_removal.py b/src/guppy/frontend/artifact_removal.py
new file mode 100644
index 0000000..f626190
--- /dev/null
+++ b/src/guppy/frontend/artifact_removal.py
@@ -0,0 +1,75 @@
+import logging
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from ..visualization.preprocessing import visualize_control_signal_fit
+
+logger = logging.getLogger(__name__)
+
+# Only set matplotlib backend if not in CI environment
+if not os.getenv("CI"):
+    plt.switch_backend("TkAgg")
+
+
+class ArtifactRemovalWidget:
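+    """Interactive matplotlib widget for marking artifact boundaries.
+
+    Pressing space marks a boundary at the cursor, 'd' removes the most recent
+    mark, and closing the figure saves the marks to coordsForPreProcessing_<name>.npy.
+    """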
+
+ def __init__(self, filepath, x, y1, y2, y3, plot_name, removeArtifacts):
+ self.coords = [] # List to store selected coordinates
+ self.filepath = filepath
+ self.plot_name = plot_name
+
+        if np.all(y1 == 0):
+ y1 = np.zeros(x.shape[0])
+
+ coords_path = os.path.join(filepath, "coordsForPreProcessing_" + plot_name[0].split("_")[-1] + ".npy")
+ artifacts_have_been_removed = removeArtifacts and os.path.exists(coords_path)
+ name = os.path.basename(filepath)
+ self.fig, self.ax1, self.ax2, self.ax3 = visualize_control_signal_fit(
+ x, y1, y2, y3, plot_name, name, artifacts_have_been_removed
+ )
+
+ self.cid = self.fig.canvas.mpl_connect("key_press_event", self._on_key_press)
+ self.fig.canvas.mpl_connect("close_event", self._on_close)
+
+ def _on_key_press(self, event):
+ """Handle key press events for artifact selection.
+
+ Pressing 'space' draws a vertical line at the cursor position to mark artifact boundaries.
+ Pressing 'd' removes the most recently added line.
+ """
+        if event.key == " ":
+            ix, iy = event.xdata, event.ydata
+            if ix is None:  # key pressed with the cursor outside the axes
+                return
+            logger.info(f"x = {ix}, y = {iy}")
+            self.ax1.axvline(ix, c="black", ls="--")
+            self.ax2.axvline(ix, c="black", ls="--")
+            self.ax3.axvline(ix, c="black", ls="--")
+
+            self.fig.canvas.draw()
+
+            self.coords.append((ix, iy))
+
+ elif event.key == "d":
+ if len(self.coords) > 0:
+ logger.info(f"x = {self.coords[-1][0]}, y = {self.coords[-1][1]}; deleted")
+ del self.coords[-1]
+ self.ax1.lines[-1].remove()
+ self.ax2.lines[-1].remove()
+ self.ax3.lines[-1].remove()
+ self.fig.canvas.draw()
+
+ def _on_close(self, _event):
+ """Handle figure close event by saving coordinates and cleaning up."""
+        if self.coords:
+ name_1 = self.plot_name[0].split("_")[-1]
+ np.save(os.path.join(self.filepath, "coordsForPreProcessing_" + name_1 + ".npy"), self.coords)
+ logger.info(
+ f"Coordinates file saved at {os.path.join(self.filepath, 'coordsForPreProcessing_'+name_1+'.npy')}"
+ )
+ self.fig.canvas.mpl_disconnect(self.cid)
+ self.coords = []
diff --git a/src/guppy/frontend/frontend_utils.py b/src/guppy/frontend/frontend_utils.py
new file mode 100644
index 0000000..572e798
--- /dev/null
+++ b/src/guppy/frontend/frontend_utils.py
@@ -0,0 +1,19 @@
+import logging
+import socket
+from random import randint
+
+logger = logging.getLogger(__name__)
+
+
+def scanPortsAndFind(start_port=5000, end_port=5200, host="127.0.0.1"):
+    """Pick a random port in [start_port, end_port] that is not currently in use."""
+    while True:
+        port = randint(start_port, end_port)
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+            sock.settimeout(0.001)  # short timeout so closed ports fail fast
+            result = sock.connect_ex((host, port))
+        if result != 0:  # connection refused: nothing is listening, the port is free
+            return port
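+
+
+# Minimal usage sketch (the server wiring here is an assumption; the real caller
+# lives elsewhere in the GUI code):
+#     port = scanPortsAndFind()      # random free port in [5000, 5200]
+#     pn.serve(app, port=port)       # e.g. hand the free port to panel.serve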
diff --git a/src/guppy/frontend/input_parameters.py b/src/guppy/frontend/input_parameters.py
new file mode 100644
index 0000000..1ae8313
--- /dev/null
+++ b/src/guppy/frontend/input_parameters.py
@@ -0,0 +1,385 @@
+import logging
+import os
+
+import numpy as np
+import pandas as pd
+import panel as pn
+
+logger = logging.getLogger(__name__)
+
+
+def checkSameLocation(arr, abspath):
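+    """Collect the parent directory of each selected folder and check that all
+    folders share the same parent; raises otherwise."""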
+    for p in arr:
+        abspath.append(os.path.dirname(p))
+    abspath = np.unique(np.asarray(abspath))
+ if len(abspath) > 1:
+ logger.error("All the folders selected should be at the same location")
+ raise Exception("All the folders selected should be at the same location")
+
+ return abspath
+
+
+def getAbsPath(files_1, files_2):
+ arr_1, arr_2 = files_1.value, files_2.value
+ if len(arr_1) == 0 and len(arr_2) == 0:
+ logger.error("No folder is selected for analysis")
+ raise Exception("No folder is selected for analysis")
+
+ abspath = []
+ if len(arr_1) > 0:
+ abspath = checkSameLocation(arr_1, abspath)
+ else:
+ abspath = checkSameLocation(arr_2, abspath)
+
+ abspath = np.unique(abspath)
+ if len(abspath) > 1:
+ logger.error("All the folders selected should be at the same location")
+ raise Exception("All the folders selected should be at the same location")
+ return abspath
+
+
+class ParameterForm:
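+    """Builds the Panel widget tree for the individual-analysis, group-analysis,
+    and visualization parameters, and appends the three cards to the template."""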
+ def __init__(self, *, template, folder_path):
+ self.template = template
+ self.folder_path = folder_path
+ self.styles = dict(background="WhiteSmoke")
+ self.setup_individual_parameters()
+ self.setup_group_parameters()
+ self.setup_visualization_parameters()
+ self.add_to_template()
+
+ def setup_individual_parameters(self):
+ # Individual analysis components
+ self.mark_down_1 = pn.pane.Markdown(
+ """**Select folders for the analysis from the file selector below**""", width=600
+ )
+
+ self.files_1 = pn.widgets.FileSelector(self.folder_path, name="folderNames", width=950)
+
+ self.explain_modality = pn.pane.Markdown(
+ """
+ **Data Modality:** Select the type of data acquisition system used for your recordings:
+ - **tdt**: Tucker-Davis Technologies system
+ - **csv**: Generic CSV format
+ - **doric**: Doric Photometry system
+ - **npm**: Neurophotometrics system
+ """,
+ width=600,
+ )
+
+ self.modality_selector = pn.widgets.Select(
+ name="Data Modality", value="tdt", options=["tdt", "csv", "doric", "npm"], width=320
+ )
+
+ self.explain_time_artifacts = pn.pane.Markdown(
+ """
+            - ***Number of cores :*** Number of cores used for the analysis. Keep it
+            less than the number of cores in your machine.
+            - ***Combine Data? :*** Set this parameter to ``` True ``` to combine
+            the data, especially when there are two different
+            data files for the same recording session.
+            - ***Isosbestic Control Channel? :*** Set this parameter to ``` False ``` if you
+            do not want to use the isosbestic control channel in the analysis.
+            - ***Eliminate first few seconds :*** Cuts out the first x seconds
+            of the data. Default is 1 second.
+            - ***Window for Moving Average filter :*** Signals are filtered using a
+            moving average filter. The default window is 100 datapoints.
+            Change it as needed.
+            - ***Moving Window (transients detection) :*** Transients in the z-score
+            and/or \u0394F/F are detected using this moving window.
+            Default is 15 seconds. Change it as needed.
+            - ***High Amplitude filtering threshold (HAFT) (transients detection) :*** High amplitude
+            events greater than x times the MAD above the median are filtered out. Here, x is the
+            high amplitude filtering threshold. Default is 2.
+            - ***Transients detection threshold (TD Thresh) :*** Peaks with local maxima greater than x times
+            the MAD above the median of the trace (after filtering high amplitude events) are detected
+            as transients. Here, x is the transients detection threshold. Default is 3.
+            - ***Number of channels (Neurophotometrics only) :*** Number of
+            channels used while recording, for data files with no column named "Flags"
+            or "LedState".
+            - ***removeArtifacts? :*** Set this parameter to ``` True ``` if there are
+            artifacts you want to remove.
+            - ***removeArtifacts method :*** Selecting ```concatenate``` will remove bad
+            chunks and concatenate the selected good chunks together.
+            Selecting ```replace with NaN``` will replace bad chunks with NaN
+            values.
+ """,
+ width=350,
+ )
+
+ self.timeForLightsTurnOn = pn.widgets.LiteralInput(
+ name="Eliminate first few seconds (int)", value=1, type=int, width=320
+ )
+
+ self.isosbestic_control = pn.widgets.Select(
+ name="Isosbestic Control Channel? (bool)", value=True, options=[True, False], width=320
+ )
+
+ self.numberOfCores = pn.widgets.LiteralInput(name="# of cores (int)", value=2, type=int, width=150)
+
+ self.combine_data = pn.widgets.Select(
+ name="Combine Data? (bool)", value=False, options=[True, False], width=150
+ )
+
+ self.computePsth = pn.widgets.Select(
+ name="z_score and/or \u0394F/F? (psth)", options=["z_score", "dff", "Both"], width=320
+ )
+
+ self.transients = pn.widgets.Select(
+ name="z_score and/or \u0394F/F? (transients)", options=["z_score", "dff", "Both"], width=320
+ )
+
+ self.plot_zScore_dff = pn.widgets.Select(
+ name="z-score plot and/or \u0394F/F plot?",
+ options=["z_score", "dff", "Both", "None"],
+ value="None",
+ width=320,
+ )
+
+ self.moving_wd = pn.widgets.LiteralInput(
+ name="Moving Window for transients detection (s) (int)", value=15, type=int, width=320
+ )
+
+ self.highAmpFilt = pn.widgets.LiteralInput(name="HAFT (int)", value=2, type=int, width=150)
+
+ self.transientsThresh = pn.widgets.LiteralInput(name="TD Thresh (int)", value=3, type=int, width=150)
+
+ self.moving_avg_filter = pn.widgets.LiteralInput(
+ name="Window for Moving Average filter (int)", value=100, type=int, width=320
+ )
+
+ self.removeArtifacts = pn.widgets.Select(
+ name="removeArtifacts? (bool)", value=False, options=[True, False], width=150
+ )
+
+ self.artifactsRemovalMethod = pn.widgets.Select(
+ name="removeArtifacts method", value="concatenate", options=["concatenate", "replace with NaN"], width=150
+ )
+
+ self.no_channels_np = pn.widgets.LiteralInput(
+ name="Number of channels (Neurophotometrics only)", value=2, type=int, width=320
+ )
+
+ self.z_score_computation = pn.widgets.Select(
+ name="z-score computation Method",
+ options=["standard z-score", "baseline z-score", "modified z-score"],
+ value="standard z-score",
+ width=200,
+ )
+
+ self.baseline_wd_strt = pn.widgets.LiteralInput(
+ name="Baseline Window Start Time (s) (int)", value=0, type=int, width=200
+ )
+ self.baseline_wd_end = pn.widgets.LiteralInput(
+ name="Baseline Window End Time (s) (int)", value=0, type=int, width=200
+ )
+
+ self.explain_z_score = pn.pane.Markdown(
+ """
+ ***Note :***
+            - Details about the z-score computation methods are explained in the GitHub wiki.
+            - These details will help you decide which computation method to use for
+            your data.
+            - Baseline Window Parameters should be kept at 0 unless you are using the baseline
+            z-score computation method. The parameters are in seconds.
+ """,
+ width=580,
+ )
+
+ self.explain_nsec = pn.pane.Markdown(
+ """
+            - ***Time Interval :*** To omit bursts of event timestamps, a user-defined time interval
+            is set: if the time difference between two timestamps is less than this
+            interval, the later timestamp is dropped from the PSTH calculation.
+            - ***Compute Cross-correlation :*** Set this parameter to ```True``` to compute
+            cross-correlation between PSTHs of two different signals or of signals
+            recorded from different brain regions.
+ """,
+ width=580,
+ )
+
+ self.nSecPrev = pn.widgets.LiteralInput(name="Seconds before 0 (int)", value=-10, type=int, width=120)
+
+ self.nSecPost = pn.widgets.LiteralInput(name="Seconds after 0 (int)", value=20, type=int, width=120)
+
+ self.computeCorr = pn.widgets.Select(
+ name="Compute Cross-correlation (bool)", options=[True, False], value=False, width=200
+ )
+
+ self.timeInterval = pn.widgets.LiteralInput(name="Time Interval (s)", value=2, type=int, width=120)
+
+ self.use_time_or_trials = pn.widgets.Select(
+ name="Bin PSTH trials (str)", options=["Time (min)", "# of trials"], value="Time (min)", width=120
+ )
+
+ self.bin_psth_trials = pn.widgets.LiteralInput(
+ name="Time(min) / # of trials \n for binning? (int)", value=0, type=int, width=200
+ )
+
+ self.explain_baseline = pn.pane.Markdown(
+ """
+ ***Note :***
+            - If you do not want baseline correction,
+            set both parameters to 0.
+            - If the first event timestamp is less than the length of the baseline
+            window, it will be rejected in the PSTH computation step.
+ - Baseline parameters must be within the PSTH parameters
+ set in the PSTH parameters section.
+ """,
+ width=580,
+ )
+
+ self.baselineCorrectionStart = pn.widgets.LiteralInput(
+ name="Baseline Correction Start time(int)", value=-5, type=int, width=200
+ )
+
+ self.baselineCorrectionEnd = pn.widgets.LiteralInput(
+ name="Baseline Correction End time(int)", value=0, type=int, width=200
+ )
+
+ self.zscore_param_wd = pn.WidgetBox(
+ "### Z-score Parameters",
+ self.explain_z_score,
+ self.z_score_computation,
+ pn.Row(self.baseline_wd_strt, self.baseline_wd_end),
+ width=600,
+ )
+
+ self.psth_param_wd = pn.WidgetBox(
+ "### PSTH Parameters",
+ self.explain_nsec,
+ pn.Row(self.nSecPrev, self.nSecPost, self.computeCorr),
+ pn.Row(self.timeInterval, self.use_time_or_trials, self.bin_psth_trials),
+ width=600,
+ )
+
+ self.baseline_param_wd = pn.WidgetBox(
+ "### Baseline Parameters",
+ self.explain_baseline,
+ pn.Row(self.baselineCorrectionStart, self.baselineCorrectionEnd),
+ width=600,
+ )
+ self.peak_explain = pn.pane.Markdown(
+ """
+ ***Note :***
+            - Peak and area are computed within the window set below.
+            - Peak and AUC parameters must be within the PSTH parameters set in the PSTH parameters section.
+            - After changing a value in the table below, click on any other cell so that
+            the change is registered.
+ """,
+ width=580,
+ )
+
+ self.start_end_point_df = pd.DataFrame(
+ {
+ "Peak Start time": [-5, 0, 5, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
+ "Peak End time": [0, 3, 10, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
+ }
+ )
+
+ self.df_widget = pn.widgets.Tabulator(self.start_end_point_df, name="DataFrame", show_index=False, widths=280)
+
+ self.peak_param_wd = pn.WidgetBox("### Peak and AUC Parameters", self.peak_explain, self.df_widget, width=600)
+
+ self.individual_analysis_wd_2 = pn.Column(
+ self.explain_time_artifacts,
+ pn.Row(self.numberOfCores, self.combine_data),
+ self.isosbestic_control,
+ self.timeForLightsTurnOn,
+ self.moving_avg_filter,
+ self.computePsth,
+ self.transients,
+ self.plot_zScore_dff,
+ self.moving_wd,
+ pn.Row(self.highAmpFilt, self.transientsThresh),
+ self.no_channels_np,
+ pn.Row(self.removeArtifacts, self.artifactsRemovalMethod),
+ )
+
+ self.psth_baseline_param = pn.Column(
+ self.zscore_param_wd, self.psth_param_wd, self.baseline_param_wd, self.peak_param_wd
+ )
+
+ self.widget = pn.Column(
+ self.mark_down_1,
+ self.files_1,
+ self.explain_modality,
+ self.modality_selector,
+ pn.Row(self.individual_analysis_wd_2, self.psth_baseline_param),
+ )
+ self.individual = pn.Card(self.widget, title="Individual Analysis", styles=self.styles, width=1000)
+
+ def setup_group_parameters(self):
+ self.mark_down_2 = pn.pane.Markdown(
+ """**Select folders for the average analysis from the file selector below**""", width=600
+ )
+
+ self.files_2 = pn.widgets.FileSelector(self.folder_path, name="folderNamesForAvg", width=950)
+
+ self.averageForGroup = pn.widgets.Select(
+ name="Average Group? (bool)", value=False, options=[True, False], width=435
+ )
+
+ self.group_analysis_wd_1 = pn.Column(self.mark_down_2, self.files_2, self.averageForGroup, width=800)
+ self.group = pn.Card(self.group_analysis_wd_1, title="Group Analysis", styles=self.styles, width=1000)
+
+ def setup_visualization_parameters(self):
+ self.visualizeAverageResults = pn.widgets.Select(
+ name="Visualize Average Results? (bool)", value=False, options=[True, False], width=435
+ )
+
+ self.visualize_zscore_or_dff = pn.widgets.Select(
+ name="z-score or \u0394F/F? (for visualization)", options=["z_score", "dff"], width=435
+ )
+
+ self.visualization_wd = pn.Row(self.visualize_zscore_or_dff, pn.Spacer(width=60), self.visualizeAverageResults)
+ self.visualize = pn.Card(
+ self.visualization_wd, title="Visualization Parameters", styles=self.styles, width=1000
+ )
+
+ def add_to_template(self):
+ self.template.main.append(self.individual)
+ self.template.main.append(self.group)
+ self.template.main.append(self.visualize)
+
+ def getInputParameters(self):
+ abspath = getAbsPath(self.files_1, self.files_2)
+ inputParameters = {
+ "abspath": abspath[0],
+ "folderNames": self.files_1.value,
+ "modality": self.modality_selector.value,
+ "numberOfCores": self.numberOfCores.value,
+ "combine_data": self.combine_data.value,
+ "isosbestic_control": self.isosbestic_control.value,
+ "timeForLightsTurnOn": self.timeForLightsTurnOn.value,
+ "filter_window": self.moving_avg_filter.value,
+ "removeArtifacts": self.removeArtifacts.value,
+ "artifactsRemovalMethod": self.artifactsRemovalMethod.value,
+ "noChannels": self.no_channels_np.value,
+ "zscore_method": self.z_score_computation.value,
+ "baselineWindowStart": self.baseline_wd_strt.value,
+ "baselineWindowEnd": self.baseline_wd_end.value,
+ "nSecPrev": self.nSecPrev.value,
+ "nSecPost": self.nSecPost.value,
+ "computeCorr": self.computeCorr.value,
+ "timeInterval": self.timeInterval.value,
+ "bin_psth_trials": self.bin_psth_trials.value,
+ "use_time_or_trials": self.use_time_or_trials.value,
+ "baselineCorrectionStart": self.baselineCorrectionStart.value,
+ "baselineCorrectionEnd": self.baselineCorrectionEnd.value,
+ "peak_startPoint": list(self.df_widget.value["Peak Start time"]), # startPoint.value,
+ "peak_endPoint": list(self.df_widget.value["Peak End time"]), # endPoint.value,
+ "selectForComputePsth": self.computePsth.value,
+ "selectForTransientsComputation": self.transients.value,
+ "moving_window": self.moving_wd.value,
+ "highAmpFilt": self.highAmpFilt.value,
+ "transientsThresh": self.transientsThresh.value,
+ "plot_zScore_dff": self.plot_zScore_dff.value,
+ "visualize_zscore_or_dff": self.visualize_zscore_or_dff.value,
+ "folderNamesForAvg": self.files_2.value,
+ "averageForGroup": self.averageForGroup.value,
+ "visualizeAverageResults": self.visualizeAverageResults.value,
+ }
+ return inputParameters
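+
+
+# Sketch of the intended wiring (names here are assumptions; the real template
+# and event loop are created by the GUI entry point):
+#     template = pn.template.BootstrapTemplate(title="GuPPy")
+#     form = ParameterForm(template=template, folder_path=folder_path)
+#     inputParameters = form.getInputParameters()  # read widget values on demand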
diff --git a/src/guppy/frontend/npm_gui_prompts.py b/src/guppy/frontend/npm_gui_prompts.py
new file mode 100644
index 0000000..a6f2077
--- /dev/null
+++ b/src/guppy/frontend/npm_gui_prompts.py
@@ -0,0 +1,104 @@
+import logging
+import tkinter as tk
+from tkinter import StringVar, messagebox, ttk
+
+logger = logging.getLogger(__name__)
+
+
+def get_multi_event_responses(multiple_event_ttls):
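+    """For each TTL file flagged as having multiple event types, ask the user
+    whether to split it into one file per behavior type; returns one bool per file."""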
+ responses = []
+ for has_multiple in multiple_event_ttls:
+ if not has_multiple:
+ responses.append(False)
+ continue
+        window = tk.Tk()
+        window.withdraw()  # hide the bare root window behind the message box
+ response = messagebox.askyesno(
+ "Multiple event TTLs",
+ (
+ "Based on the TTL file, "
+ "it looks like TTLs "
+ "belong to multiple behavior types. "
+ "Do you want to create multiple files for each "
+ "behavior type?"
+ ),
+ )
+ window.destroy()
+ responses.append(response)
+ return responses
+
+
+def get_timestamp_configuration(ts_unit_needs, col_names_ts):
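+    """For each file that needs it, open a small tkinter dialog asking which
+    timestamp column to use and its unit; returns the chosen units and column names."""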
+ ts_units, npm_timestamp_column_names = [], []
+ for need in ts_unit_needs:
+ if not need:
+ ts_units.append("seconds")
+ npm_timestamp_column_names.append(None)
+ continue
+ window = tk.Tk()
+ window.title("Select appropriate options for timestamps")
+ window.geometry("500x200")
+ holdComboboxValues = dict()
+
+        ttk.Label(window, text="Select which timestamps to use : ").grid(row=0, column=1, pady=25, padx=25)
+ holdComboboxValues["timestamps"] = StringVar()
+ timestamps_combo = ttk.Combobox(window, values=col_names_ts, textvariable=holdComboboxValues["timestamps"])
+ timestamps_combo.grid(row=0, column=2, pady=25, padx=25)
+ timestamps_combo.current(0)
+
+        ttk.Label(window, text="Select timestamps unit : ").grid(row=1, column=1, pady=25, padx=25)
+ holdComboboxValues["time_unit"] = StringVar()
+ time_unit_combo = ttk.Combobox(
+ window,
+ values=["", "seconds", "milliseconds", "microseconds"],
+ textvariable=holdComboboxValues["time_unit"],
+ )
+ time_unit_combo.grid(row=1, column=2, pady=25, padx=25)
+ time_unit_combo.current(0)
+ window.lift()
+ window.after(500, lambda: window.lift())
+ window.mainloop()
+
+        error_msg = "All the options for timestamps were not selected. Please select appropriate options"
+        if holdComboboxValues["timestamps"].get():
+            npm_timestamp_column_name = holdComboboxValues["timestamps"].get()
+        else:
+            messagebox.showerror("All options not selected", error_msg)
+            logger.error(error_msg)
+            raise Exception(error_msg)
+        if holdComboboxValues["time_unit"].get():
+            ts_unit = holdComboboxValues["time_unit"].get()
+        else:
+            messagebox.showerror("All options not selected", error_msg)
+            logger.error(error_msg)
+            raise Exception(error_msg)
+ ts_units.append(ts_unit)
+ npm_timestamp_column_names.append(npm_timestamp_column_name)
+ return ts_units, npm_timestamp_column_names
diff --git a/src/guppy/frontend/parameterized_plotter.py b/src/guppy/frontend/parameterized_plotter.py
new file mode 100644
index 0000000..2d7d2f1
--- /dev/null
+++ b/src/guppy/frontend/parameterized_plotter.py
@@ -0,0 +1,582 @@
+import logging
+import math
+import os
+import re
+
+import datashader as ds
+import holoviews as hv
+import numpy as np
+import pandas as pd
+import panel as pn
+import param
+from bokeh.io import export_png, export_svgs
+from holoviews import opts
+from holoviews.operation.datashader import datashade
+from holoviews.plotting.util import process_cmap
+
+pn.extension()
+
+logger = logging.getLogger(__name__)
+
+
+# remove unnecessary column names
+def remove_cols(cols):
+    regex = re.compile("bin_err_*")
+    cols_to_remove = [c for c in cols if regex.match(c)]
+    cols_to_remove = cols_to_remove + ["err", "timestamps"]
+    cols = [c for c in cols if c not in cols_to_remove]
+
+ return cols
+
+
+# make a new directory for saving plots
+def make_dir(filepath):
+ op = os.path.join(filepath, "saved_plots")
+ if not os.path.exists(op):
+ os.mkdir(op)
+
+ return op
+
+
+# create a class to make GUI and plot different graphs
+class ParameterizedPlotter(param.Parameterized):
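+    # The *_objects lists mirror the selector params below: they are supplied at
+    # construction time and bound to the matching param's .objects in __init__.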
+ event_selector_objects = param.List(default=None)
+ event_selector_heatmap_objects = param.List(default=None)
+ selector_for_multipe_events_plot_objects = param.List(default=None)
+ color_map_objects = param.List(default=None)
+ x_objects = param.List(default=None)
+ y_objects = param.List(default=None)
+ heatmap_y_objects = param.List(default=None)
+ psth_y_objects = param.List(default=None)
+
+ filepath = param.Path(default=None)
+ # create different options and selectors
+ event_selector = param.ObjectSelector(default=None)
+ event_selector_heatmap = param.ObjectSelector(default=None)
+ selector_for_multipe_events_plot = param.ListSelector(default=None)
+ columns_dict = param.Dict(default=None)
+ df_new = param.DataFrame(default=None)
+ x_min = param.Number(default=None)
+ x_max = param.Number(default=None)
+ select_trials_checkbox = param.ListSelector(default=["just trials"], objects=["mean", "just trials"])
+ Y_Label = param.ObjectSelector(default="y", objects=["y", "z-score", "\u0394F/F"])
+ save_options = param.ObjectSelector(
+ default="None", objects=["None", "save_png_format", "save_svg_format", "save_both_format"]
+ )
+ save_options_heatmap = param.ObjectSelector(
+ default="None", objects=["None", "save_png_format", "save_svg_format", "save_both_format"]
+ )
+ color_map = param.ObjectSelector(default="plasma")
+ height_heatmap = param.ObjectSelector(default=600, objects=list(np.arange(0, 5100, 100))[1:])
+ width_heatmap = param.ObjectSelector(default=1000, objects=list(np.arange(0, 5100, 100))[1:])
+ Height_Plot = param.ObjectSelector(default=300, objects=list(np.arange(0, 5100, 100))[1:])
+ Width_Plot = param.ObjectSelector(default=1000, objects=list(np.arange(0, 5100, 100))[1:])
+ save_hm = param.Action(lambda x: x.param.trigger("save_hm"), label="Save")
+ save_psth = param.Action(lambda x: x.param.trigger("save_psth"), label="Save")
+ X_Limit = param.Range(default=(-5, 10))
+ Y_Limit = param.Range(bounds=(-50, 50.0))
+
+ x = param.ObjectSelector(default=None)
+ y = param.ObjectSelector(default=None)
+ heatmap_y = param.ListSelector(default=None)
+ psth_y = param.ListSelector(default=None)
+ results_hm = dict()
+ results_psth = dict()
+
+ def __init__(self, **params):
+ super().__init__(**params)
+ # Bind selector objects from companion params
+ self.param.event_selector.objects = self.event_selector_objects
+ self.param.event_selector_heatmap.objects = self.event_selector_heatmap_objects
+ self.param.selector_for_multipe_events_plot.objects = self.selector_for_multipe_events_plot_objects
+ self.param.color_map.objects = self.color_map_objects
+ self.param.x.objects = self.x_objects
+ self.param.y.objects = self.y_objects
+ self.param.heatmap_y.objects = self.heatmap_y_objects
+ self.param.psth_y.objects = self.psth_y_objects
+
+ # Set defaults
+ self.event_selector = self.event_selector_objects[0]
+ self.event_selector_heatmap = self.event_selector_heatmap_objects[0]
+ self.selector_for_multipe_events_plot = [self.selector_for_multipe_events_plot_objects[0]]
+ self.x = self.x_objects[0]
+ self.y = self.y_objects[-2]
+ self.heatmap_y = [self.heatmap_y_objects[-1]]
+
+ self.param.X_Limit.bounds = (self.x_min, self.x_max)
+
+ # function to save heatmaps when save button on heatmap tab is clicked
+ @param.depends("save_hm", watch=True)
+ def save_hm_plots(self):
+ plot = self.results_hm["plot"]
+ op = self.results_hm["op"]
+ save_opts = self.save_options_heatmap
+ logger.info(save_opts)
+ if save_opts == "save_svg_format":
+ p = hv.render(plot, backend="bokeh")
+ p.output_backend = "svg"
+ export_svgs(p, filename=op + ".svg")
+ elif save_opts == "save_png_format":
+ p = hv.render(plot, backend="bokeh")
+ export_png(p, filename=op + ".png")
+ elif save_opts == "save_both_format":
+ p = hv.render(plot, backend="bokeh")
+ p.output_backend = "svg"
+ export_svgs(p, filename=op + ".svg")
+ p_png = hv.render(plot, backend="bokeh")
+ export_png(p_png, filename=op + ".png")
+ else:
+ return 0
+
+ # function to save PSTH plots when save button on PSTH tab is clicked
+ @param.depends("save_psth", watch=True)
+ def save_psth_plot(self):
+ plot, op = [], []
+ plot.append(self.results_psth["plot_combine"])
+ op.append(self.results_psth["op_combine"])
+ plot.append(self.results_psth["plot"])
+ op.append(self.results_psth["op"])
+ for i in range(len(plot)):
+ temp_plot, temp_op = plot[i], op[i]
+ save_opts = self.save_options
+ if save_opts == "save_svg_format":
+ p = hv.render(temp_plot, backend="bokeh")
+ p.output_backend = "svg"
+ export_svgs(p, filename=temp_op + ".svg")
+ elif save_opts == "save_png_format":
+ p = hv.render(temp_plot, backend="bokeh")
+ export_png(p, filename=temp_op + ".png")
+ elif save_opts == "save_both_format":
+ p = hv.render(temp_plot, backend="bokeh")
+ p.output_backend = "svg"
+ export_svgs(p, filename=temp_op + ".svg")
+ p_png = hv.render(temp_plot, backend="bokeh")
+ export_png(p_png, filename=temp_op + ".png")
+ else:
+ return 0
+
+ # function to change Y values based on event selection
+ @param.depends("event_selector", watch=True)
+ def _update_x_y(self):
+ x_value = self.columns_dict[self.event_selector]
+ y_value = self.columns_dict[self.event_selector]
+ self.param["x"].objects = [x_value[-4]]
+ self.param["y"].objects = remove_cols(y_value)
+ self.x = x_value[-4]
+ self.y = self.param["y"].objects[-2]
+
+ @param.depends("event_selector_heatmap", watch=True)
+ def _update_df(self):
+ cols = self.columns_dict[self.event_selector_heatmap]
+ trial_no = range(1, len(remove_cols(cols)[:-2]) + 1)
+ trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(cols)[:-2])] + ["All"]
+ self.param["heatmap_y"].objects = trial_ts
+ self.heatmap_y = [trial_ts[-1]]
+
+ @param.depends("event_selector", watch=True)
+ def _update_psth_y(self):
+ cols = self.columns_dict[self.event_selector]
+ trial_no = range(1, len(remove_cols(cols)[:-2]) + 1)
+ trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(cols)[:-2])]
+ self.param["psth_y"].objects = trial_ts
+ self.psth_y = [trial_ts[0]]
+
+ # function to plot multiple PSTHs into one plot
+
+ @param.depends(
+ "selector_for_multipe_events_plot",
+ "Y_Label",
+ "save_options",
+ "X_Limit",
+ "Y_Limit",
+ "Height_Plot",
+ "Width_Plot",
+ )
+ def update_selector(self):
+ data_curve, cols_curve, data_spread, cols_spread = [], [], [], []
+ arr = self.selector_for_multipe_events_plot
+ df1 = self.df_new
+ for i in range(len(arr)):
+ if "bin" in arr[i]:
+ split = arr[i].rsplit("_", 2)
+                df_name = split[0]
+ col_name_mean = "{}_{}".format(split[-2], split[-1])
+ col_name_err = "{}_err_{}".format(split[-2], split[-1])
+ data_curve.append(df1[df_name][col_name_mean])
+ cols_curve.append(arr[i])
+ data_spread.append(df1[df_name][col_name_err])
+ cols_spread.append(arr[i])
+ else:
+ data_curve.append(df1[arr[i]]["mean"])
+ cols_curve.append(arr[i] + "_" + "mean")
+ data_spread.append(df1[arr[i]]["err"])
+ cols_spread.append(arr[i] + "_" + "mean")
+
+        if len(arr) == 0:
+            return None
+
+        if self.Y_Limit is None:
+            self.Y_Limit = (np.nanmin(np.asarray(data_curve)) - 0.5, np.nanmax(np.asarray(data_curve)) + 0.5)
+
+        # Append the shared timestamps column, taken from the last selected event
+        if "bin" in arr[-1]:
+            split = arr[-1].rsplit("_", 2)
+            df_name = split[0]
+            data_curve.append(df1[df_name]["timestamps"])
+            cols_curve.append("timestamps")
+            data_spread.append(df1[df_name]["timestamps"])
+            cols_spread.append("timestamps")
+        else:
+            data_curve.append(df1[arr[-1]]["timestamps"])
+            cols_curve.append("timestamps")
+            data_spread.append(df1[arr[-1]]["timestamps"])
+            cols_spread.append("timestamps")
+ df_curve = pd.concat(data_curve, axis=1)
+ df_spread = pd.concat(data_spread, axis=1)
+ df_curve.columns = cols_curve
+ df_spread.columns = cols_spread
+
+ ts = df_curve["timestamps"]
+ index = np.arange(0, ts.shape[0], 3)
+ df_curve = df_curve.loc[index, :]
+ df_spread = df_spread.loc[index, :]
+ overlay = hv.NdOverlay(
+ {
+ c: hv.Curve((df_curve["timestamps"], df_curve[c]), kdims=["Time (s)"]).opts(
+ width=int(self.Width_Plot),
+ height=int(self.Height_Plot),
+ xlim=self.X_Limit,
+ ylim=self.Y_Limit,
+ )
+ for c in cols_curve[:-1]
+ }
+ )
+ spread = hv.NdOverlay(
+ {
+ d: hv.Spread(
+ (df_spread["timestamps"], df_curve[d], df_spread[d], df_spread[d]),
+ vdims=["y", "yerrpos", "yerrneg"],
+ ).opts(line_width=0, fill_alpha=0.3)
+ for d in cols_spread[:-1]
+ }
+ )
+ plot_combine = ((overlay * spread).opts(opts.NdOverlay(xlabel="Time (s)", ylabel=self.Y_Label))).opts(
+ shared_axes=False
+ )
+        op = make_dir(self.filepath)
+        op_filename = os.path.join(op, str(arr) + "_mean")
+
+        self.results_psth["plot_combine"] = plot_combine
+        self.results_psth["op_combine"] = op_filename
+ return plot_combine
+
+ # function to plot mean PSTH, single trial in PSTH and all the trials of PSTH with mean
+ @param.depends(
+ "event_selector", "x", "y", "Y_Label", "save_options", "Y_Limit", "X_Limit", "Height_Plot", "Width_Plot"
+ )
+ def contPlot(self):
+ df1 = self.df_new[self.event_selector]
+        if self.y == "All":
+            if self.Y_Limit is None:
+                self.Y_Limit = (np.nanmin(np.asarray(df1)) - 0.5, np.nanmax(np.asarray(df1)) + 0.5)
+
+ options = self.param["y"].objects
+ regex = re.compile("bin_[(]")
+ remove_bin_trials = [options[i] for i in range(len(options)) if not regex.match(options[i])]
+
+ ndoverlay = hv.NdOverlay({c: hv.Curve((df1[self.x], df1[c])) for c in remove_bin_trials[:-2]})
+ img1 = datashade(ndoverlay, normalization="linear", aggregator=ds.count())
+ x_points = df1[self.x]
+ y_points = df1["mean"]
+ img2 = hv.Curve((x_points, y_points))
+ img = (img1 * img2).opts(
+ opts.Curve(
+ width=int(self.Width_Plot),
+ height=int(self.Height_Plot),
+ line_width=4,
+ color="black",
+ xlim=self.X_Limit,
+ ylim=self.Y_Limit,
+ xlabel="Time (s)",
+ ylabel=self.Y_Label,
+ )
+ )
+
+            op = make_dir(self.filepath)
+            op_filename = os.path.join(op, self.event_selector + "_" + self.y)
+            self.results_psth["plot"] = img
+            self.results_psth["op"] = op_filename
+
+ return img
+
+ elif self.y == "mean" or "bin" in self.y:
+
+ xpoints = df1[self.x]
+ ypoints = df1[self.y]
+ if self.y == "mean":
+ err = df1["err"]
+ else:
+ split = self.y.split("_")
+ err = df1["{}_err_{}".format(split[0], split[1])]
+
+ index = np.arange(0, xpoints.shape[0], 3)
+
+            if self.Y_Limit is None:
+ self.Y_Limit = (np.nanmin(ypoints) - 0.5, np.nanmax(ypoints) + 0.5)
+
+ ropts_curve = dict(
+ width=int(self.Width_Plot),
+ height=int(self.Height_Plot),
+ xlim=self.X_Limit,
+ ylim=self.Y_Limit,
+ color="blue",
+ xlabel="Time (s)",
+ ylabel=self.Y_Label,
+ )
+ ropts_spread = dict(
+ width=int(self.Width_Plot),
+ height=int(self.Height_Plot),
+ fill_alpha=0.3,
+ fill_color="blue",
+ line_width=0,
+ )
+
+            plot_curve = hv.Curve((xpoints[index], ypoints[index]))
+            plot_spread = hv.Spread((xpoints[index], ypoints[index], err[index], err[index]))
+ plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread})
+
+            op = make_dir(self.filepath)
+            op_filename = os.path.join(op, self.event_selector + "_" + self.y)
+            self.results_psth["plot"] = plot
+            self.results_psth["op"] = op_filename
+
+ return plot
+
+ else:
+ xpoints = df1[self.x]
+ ypoints = df1[self.y]
+            if self.Y_Limit is None:
+ self.Y_Limit = (np.nanmin(ypoints) - 0.5, np.nanmax(ypoints) + 0.5)
+
+ ropts_curve = dict(
+ width=int(self.Width_Plot),
+ height=int(self.Height_Plot),
+ xlim=self.X_Limit,
+ ylim=self.Y_Limit,
+ color="blue",
+ xlabel="Time (s)",
+ ylabel=self.Y_Label,
+ )
+ plot = hv.Curve((xpoints, ypoints)).opts({"Curve": ropts_curve})
+
+            op = make_dir(self.filepath)
+            op_filename = os.path.join(op, self.event_selector + "_" + self.y)
+            self.results_psth["plot"] = plot
+            self.results_psth["op"] = op_filename
+
+ return plot
+
+ # function to plot specific PSTH trials
+ @param.depends(
+ "event_selector",
+ "x",
+ "psth_y",
+ "select_trials_checkbox",
+ "Y_Label",
+ "save_options",
+ "Y_Limit",
+ "X_Limit",
+ "Height_Plot",
+ "Width_Plot",
+ )
+ def plot_specific_trials(self):
+ df_psth = self.df_new[self.event_selector]
+        if self.psth_y is None:
+ return None
+ else:
+ selected_trials = [s.split(" - ")[1] for s in list(self.psth_y)]
+
+ index = np.arange(0, df_psth["timestamps"].shape[0], 3)
+
+ if self.select_trials_checkbox == ["just trials"]:
+ overlay = hv.NdOverlay(
+ {
+ c: hv.Curve((df_psth["timestamps"][index], df_psth[c][index]), kdims=["Time (s)"])
+ for c in selected_trials
+ }
+ )
+ ropts = dict(
+ width=int(self.Width_Plot),
+ height=int(self.Height_Plot),
+ xlim=self.X_Limit,
+ ylim=self.Y_Limit,
+ xlabel="Time (s)",
+ ylabel=self.Y_Label,
+ )
+ return overlay.opts(**ropts)
+ elif self.select_trials_checkbox == ["mean"]:
+ arr = np.asarray(df_psth[selected_trials])
+ mean = np.nanmean(arr, axis=1)
+ err = np.nanstd(arr, axis=1) / math.sqrt(arr.shape[1])
+ ropts_curve = dict(
+ width=int(self.Width_Plot),
+ height=int(self.Height_Plot),
+ xlim=self.X_Limit,
+ ylim=self.Y_Limit,
+ color="blue",
+ xlabel="Time (s)",
+ ylabel=self.Y_Label,
+ )
+ ropts_spread = dict(
+ width=int(self.Width_Plot),
+ height=int(self.Height_Plot),
+ fill_alpha=0.3,
+ fill_color="blue",
+ line_width=0,
+ )
+ plot_curve = hv.Curve((df_psth["timestamps"][index], mean[index]))
+ plot_spread = hv.Spread((df_psth["timestamps"][index], mean[index], err[index], err[index]))
+ plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread})
+ return plot
+ elif self.select_trials_checkbox == ["mean", "just trials"]:
+ overlay = hv.NdOverlay(
+ {
+ c: hv.Curve((df_psth["timestamps"][index], df_psth[c][index]), kdims=["Time (s)"])
+ for c in selected_trials
+ }
+ )
+ ropts_overlay = dict(
+ width=int(self.Width_Plot),
+ height=int(self.Height_Plot),
+ xlim=self.X_Limit,
+ ylim=self.Y_Limit,
+ xlabel="Time (s)",
+ ylabel=self.Y_Label,
+ )
+
+ arr = np.asarray(df_psth[selected_trials])
+ mean = np.nanmean(arr, axis=1)
+ err = np.nanstd(arr, axis=1) / math.sqrt(arr.shape[1])
+ ropts_curve = dict(
+ width=int(self.Width_Plot),
+ height=int(self.Height_Plot),
+ xlim=self.X_Limit,
+ ylim=self.Y_Limit,
+ color="black",
+ xlabel="Time (s)",
+ ylabel=self.Y_Label,
+ )
+ ropts_spread = dict(
+ width=int(self.Width_Plot),
+ height=int(self.Height_Plot),
+ fill_alpha=0.3,
+ fill_color="black",
+ line_width=0,
+ )
+ plot_curve = hv.Curve((df_psth["timestamps"][index], mean[index]))
+ plot_spread = hv.Spread((df_psth["timestamps"][index], mean[index], err[index], err[index]))
+
+ plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread})
+ return overlay.opts(**ropts_overlay) * plot
+
+ # function to show heatmaps for each event
+ @param.depends("event_selector_heatmap", "color_map", "height_heatmap", "width_heatmap", "heatmap_y")
+ def heatmap(self):
+ height = self.height_heatmap
+ width = self.width_heatmap
+ df_hm = self.df_new[self.event_selector_heatmap]
+ cols = list(df_hm.columns)
+ regex = re.compile("bin_err_*")
+ drop_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])]
+ drop_cols = ["err", "mean"] + drop_cols
+ df_hm = df_hm.drop(drop_cols, axis=1)
+ cols = list(df_hm.columns)
+ bin_cols = [cols[i] for i in range(len(cols)) if re.compile("bin_*").match(cols[i])]
+ time = np.asarray(df_hm["timestamps"])
+ event_ts_for_each_event = np.arange(1, len(df_hm.columns[:-1]) + 1)
+ yticks = list(event_ts_for_each_event)
+ z_score = np.asarray(df_hm[df_hm.columns[:-1]]).T
+
+ if self.heatmap_y[0] == "All":
+ indices = np.arange(z_score.shape[0] - len(bin_cols))
+ z_score = z_score[indices, :]
+ event_ts_for_each_event = np.arange(1, z_score.shape[0] + 1)
+ yticks = list(event_ts_for_each_event)
+ else:
+ remove_all = list(set(self.heatmap_y) - set(["All"]))
+ indices = sorted([int(s.split("-")[0]) - 1 for s in remove_all])
+ z_score = z_score[indices, :]
+ event_ts_for_each_event = np.arange(1, z_score.shape[0] + 1)
+ yticks = list(event_ts_for_each_event)
+
+ clim = (np.nanmin(z_score), np.nanmax(z_score))
+ font_size = {"labels": 16, "yticks": 6}
+
+ if event_ts_for_each_event.shape[0] == 1:
+ dummy_image = hv.QuadMesh((time, event_ts_for_each_event, z_score)).opts(colorbar=True, clim=clim)
+ image = (
+ (dummy_image).opts(
+ opts.QuadMesh(
+ width=int(width),
+ height=int(height),
+ cmap=process_cmap(self.color_map, provider="matplotlib"),
+ colorbar=True,
+ ylabel="Trials",
+ xlabel="Time (s)",
+ fontsize=font_size,
+ yticks=yticks,
+ )
+ )
+ ).opts(shared_axes=False)
+
+            op = make_dir(self.filepath)
+            op_filename = os.path.join(op, self.event_selector_heatmap + "_" + "heatmap")
+            self.results_hm["plot"] = image
+            self.results_hm["op"] = op_filename
+ return image
+ else:
+ ropts = dict(
+ width=int(width),
+ height=int(height),
+ ylabel="Trials",
+ xlabel="Time (s)",
+ fontsize=font_size,
+ yticks=yticks,
+ invert_yaxis=True,
+ )
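+            # A small QuadMesh supplies the colorbar and color scale, while the full
+            # array is rendered through datashade for speed; overlaying the two gives
+            # a responsive heatmap with a proper colorbar.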
+ dummy_image = hv.QuadMesh((time[0:100], event_ts_for_each_event, z_score[:, 0:100])).opts(
+ colorbar=True, cmap=process_cmap(self.color_map, provider="matplotlib"), clim=clim
+ )
+ actual_image = hv.QuadMesh((time, event_ts_for_each_event, z_score))
+
+ dynspread_img = datashade(actual_image, cmap=process_cmap(self.color_map, provider="matplotlib")).opts(
+ **ropts
+            )
+ image = ((dummy_image * dynspread_img).opts(opts.QuadMesh(width=int(width), height=int(height)))).opts(
+ shared_axes=False
+ )
+
+ op = make_dir(self.filepath)
+ op_filename = os.path.join(op, self.event_selector_heatmap + "_" + "heatmap")
+ self.results_hm["plot"] = image
+ self.results_hm["op"] = op_filename
+
+ return image
diff --git a/src/guppy/frontend/path_selection.py b/src/guppy/frontend/path_selection.py
new file mode 100644
index 0000000..3d69efb
--- /dev/null
+++ b/src/guppy/frontend/path_selection.py
@@ -0,0 +1,40 @@
+import logging
+import os
+import tkinter as tk
+from tkinter import filedialog, ttk
+
+logger = logging.getLogger(__name__)
+
+
+def get_folder_path():
+ # Determine base folder path (headless-friendly via env var)
+ base_dir_env = os.environ.get("GUPPY_BASE_DIR")
+ is_headless = base_dir_env and os.path.isdir(base_dir_env)
+ if is_headless:
+ folder_path = base_dir_env
+ logger.info(f"Folder path set to {folder_path} (from GUPPY_BASE_DIR)")
+ return folder_path
+
+ # Create the main window
+ folder_selection = tk.Tk()
+ folder_selection.title("Select the folder path where your data is located")
+ folder_selection.geometry("700x200")
+
+ selected_path = {"value": None}
+
+ def select_folder():
+ selected = filedialog.askdirectory(title="Select the folder path where your data is located")
+ if selected:
+ logger.info(f"Folder path set to {selected}")
+ selected_path["value"] = selected
+ else:
+ default_path = os.path.expanduser("~")
+ logger.info(f"Folder path set to {default_path}")
+ selected_path["value"] = default_path
+ folder_selection.destroy()
+
+ select_button = ttk.Button(folder_selection, text="Select a Folder", command=select_folder)
+ select_button.pack(pady=5)
+ folder_selection.mainloop()
+
+ return selected_path["value"]
diff --git a/src/guppy/frontend/progress.py b/src/guppy/frontend/progress.py
new file mode 100644
index 0000000..fb5e7c2
--- /dev/null
+++ b/src/guppy/frontend/progress.py
@@ -0,0 +1,50 @@
+import logging
+import os
+import time
+
+logger = logging.getLogger(__name__)
+
+
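+# Progress is communicated through ~/pbSteps.txt (appended to via writeToFile below):
+# the first line holds the maximum step count, later lines hold the current
+# increment, and -1 signals failure (the bar turns red and the loop exits).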
+def readPBIncrementValues(progressBar):
+ logger.info("Read progress bar increment values function started...")
+ file_path = os.path.join(os.path.expanduser("~"), "pbSteps.txt")
+ if os.path.exists(file_path):
+ os.remove(file_path)
+ increment, maximum = 0, 100
+ progressBar.value = increment
+ progressBar.bar_color = "success"
+ while True:
+ try:
+ with open(file_path, "r") as file:
+ content = file.readlines()
+                if content:
+                    maximum = int(content[0])
+                    increment = int(content[-1])
+
+ if increment == -1:
+ progressBar.bar_color = "danger"
+ os.remove(file_path)
+ break
+ progressBar.max = maximum
+ progressBar.value = increment
+ time.sleep(0.001)
+        except (FileNotFoundError, PermissionError):
+            time.sleep(0.001)
+ except Exception as e:
+ # Handle other exceptions that may occur
+ logger.info(f"An error occurred while reading the file: {e}")
+ break
+ if increment == maximum:
+ os.remove(file_path)
+ break
+
+ logger.info("Read progress bar increment values stopped.")
+
+
+def writeToFile(value: str):
+ with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file:
+ file.write(value)
diff --git a/src/guppy/frontend/sidebar.py b/src/guppy/frontend/sidebar.py
new file mode 100644
index 0000000..466a639
--- /dev/null
+++ b/src/guppy/frontend/sidebar.py
@@ -0,0 +1,75 @@
+import logging
+
+import panel as pn
+
+logger = logging.getLogger(__name__)
+
+
+class Sidebar:
+ def __init__(self, template):
+ self.template = template
+ self.setup_markdown()
+ self.setup_buttons()
+ self.setup_progress_bars()
+
+ def setup_markdown(self):
+ self.mark_down_ip = pn.pane.Markdown("""**Step 1 : Save Input Parameters**""", width=300)
+ self.mark_down_ip_note = pn.pane.Markdown(
+ """***Note : ***
+        - Save Input Parameters will save the input parameters used for the analysis
+        in all the folders you selected (useful for future
+        reference). All analysis steps will run even if the input parameters are not saved.
+ """,
+ width=300,
+ )
+        self.mark_down_storenames = pn.pane.Markdown(
+            """**Step 2 : Open Storenames GUI and save storenames**""", width=300
+        )
+ self.mark_down_read = pn.pane.Markdown("""**Step 3 : Read Raw Data**""", width=300)
+ self.mark_down_preprocess = pn.pane.Markdown("""**Step 4 : Preprocess and Remove Artifacts**""", width=300)
+ self.mark_down_psth = pn.pane.Markdown("""**Step 5 : PSTH Computation**""", width=300)
+ self.mark_down_visualization = pn.pane.Markdown("""**Step 6 : Visualization**""", width=300)
+
+ def setup_buttons(self):
+ self.open_storenames = pn.widgets.Button(
+ name="Open Storenames GUI", button_type="primary", width=300, align="end"
+ )
+ self.read_rawData = pn.widgets.Button(name="Read Raw Data", button_type="primary", width=300, align="end")
+ self.preprocess = pn.widgets.Button(
+ name="Preprocess and Remove Artifacts", button_type="primary", width=300, align="end"
+ )
+ self.psth_computation = pn.widgets.Button(
+ name="PSTH Computation", button_type="primary", width=300, align="end"
+ )
+ self.open_visualization = pn.widgets.Button(
+ name="Open Visualization GUI", button_type="primary", width=300, align="end"
+ )
+ self.save_button = pn.widgets.Button(name="Save to file...", button_type="primary", width=300, align="end")
+
+ def attach_callbacks(self, button_name_to_onclick_fn: dict):
+ for button_name, onclick_fn in button_name_to_onclick_fn.items():
+ button = getattr(self, button_name)
+ button.on_click(onclick_fn)
+
+ def setup_progress_bars(self):
+ self.read_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300)
+ self.extract_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300)
+ self.psth_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300)
+
+ def add_to_template(self):
+ self.template.sidebar.append(self.mark_down_ip)
+ self.template.sidebar.append(self.mark_down_ip_note)
+ self.template.sidebar.append(self.save_button)
+ self.template.sidebar.append(self.mark_down_storenames)
+ self.template.sidebar.append(self.open_storenames)
+ self.template.sidebar.append(self.mark_down_read)
+ self.template.sidebar.append(self.read_rawData)
+ self.template.sidebar.append(self.read_progress)
+ self.template.sidebar.append(self.mark_down_preprocess)
+ self.template.sidebar.append(self.preprocess)
+ self.template.sidebar.append(self.extract_progress)
+ self.template.sidebar.append(self.mark_down_psth)
+ self.template.sidebar.append(self.psth_computation)
+ self.template.sidebar.append(self.psth_progress)
+ self.template.sidebar.append(self.mark_down_visualization)
+ self.template.sidebar.append(self.open_visualization)
diff --git a/src/guppy/frontend/storenames_config.py b/src/guppy/frontend/storenames_config.py
new file mode 100644
index 0000000..c7aa419
--- /dev/null
+++ b/src/guppy/frontend/storenames_config.py
@@ -0,0 +1,104 @@
+import logging
+
+import panel as pn
+
+pn.extension()
+
+logger = logging.getLogger(__name__)
+
+
+class StorenamesConfig:
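+    """Builds one row of widgets per storename (label, type dropdown, name textbox,
+    and contextual help text), reusing cached dropdown options when available."""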
+ def __init__(
+ self,
+ show_config_button,
+ storename_dropdowns,
+ storename_textboxes,
+ storenames,
+ storenames_cache,
+ ):
+ self.config_widgets = []
+ self._dropdown_help_map = {}
+ storename_dropdowns.clear()
+ storename_textboxes.clear()
+
+ if len(storenames) == 0:
+ return
+
+ self.config_widgets.append(
+ pn.pane.Markdown(
+ "## Configure Storenames\nSelect appropriate options for each storename and provide names as needed:"
+ )
+ )
+
+ for i, storename in enumerate(storenames):
+ self.setup_storename(i, storename, storename_dropdowns, storename_textboxes, storenames_cache)
+
+ # Add show button
+ self.config_widgets.append(pn.Spacer(height=20))
+ self.config_widgets.append(show_config_button)
+ self.config_widgets.append(
+ pn.pane.Markdown(
+ "*Click 'Show Selected Configuration' to apply your selections.*",
+ styles={"font-size": "12px", "color": "gray"},
+ )
+ )
+
+ def _on_dropdown_value_change(self, event):
+ help_pane = self._dropdown_help_map.get(event.obj)
+ if help_pane is None:
+ return
+ dropdown_value = event.new
+ help_pane.object = self._get_help_text(dropdown_value=dropdown_value)
+
+ def _get_help_text(self, dropdown_value):
+        if dropdown_value in ("control", "signal"):
+            return "*Type appropriate region name*"
+        elif dropdown_value == "event TTLs":
+            return "*Type event name for the TTLs*"
+        else:
+            return ""
+
+ def setup_storename(self, i, storename, storename_dropdowns, storename_textboxes, storenames_cache):
+ # Create a row for each storename
+ row_widgets = []
+
+ # Label
+ label = pn.pane.Markdown(f"**{storename}:**")
+ row_widgets.append(label)
+
+ # Dropdown options
+ if storename in storenames_cache:
+ options = storenames_cache[storename]
+ default_value = options[0] if options else ""
+ else:
+ options = ["", "control", "signal", "event TTLs"]
+ default_value = ""
+
+ # Create unique key for widget
+ widget_key = (
+ f"{storename}_{i}"
+ if f"{storename}_{i}" not in storename_dropdowns
+ else f"{storename}_{i}_{len(storename_dropdowns)}"
+ )
+
+ dropdown = pn.widgets.Select(name="Type", value=default_value, options=options, width=150)
+ storename_dropdowns[widget_key] = dropdown
+ row_widgets.append(dropdown)
+
+ # Text input (only show if not cached or if control/signal/event TTLs selected)
+ if storename not in storenames_cache or default_value in ["control", "signal", "event TTLs"]:
+ textbox = pn.widgets.TextInput(name="Name", value="", placeholder="Enter region/event name", width=200)
+ storename_textboxes[widget_key] = textbox
+ row_widgets.append(textbox)
+
+ # Add helper text based on selection
+ initial_help_text = self._get_help_text(default_value)
+ help_pane = pn.pane.Markdown(initial_help_text, styles={"color": "gray", "font-size": "12px"})
+ self._dropdown_help_map[dropdown] = help_pane
+ dropdown.param.watch(self._on_dropdown_value_change, "value")
+ row_widgets.append(help_pane)
+
+ # Add the row to config widgets
+ self.config_widgets.append(pn.Row(*row_widgets, margin=(5, 0)))
diff --git a/src/guppy/frontend/storenames_instructions.py b/src/guppy/frontend/storenames_instructions.py
new file mode 100644
index 0000000..ba5fe5d
--- /dev/null
+++ b/src/guppy/frontend/storenames_instructions.py
@@ -0,0 +1,109 @@
+import glob
+import logging
+import os
+
+import holoviews as hv
+import numpy as np
+import pandas as pd
+import panel as pn
+
+# hv.extension()
+pn.extension()
+
+logger = logging.getLogger(__name__)
+
+
+class StorenamesInstructions:
+ def __init__(self, folder_path):
+ # instructions about how to save the storeslist file
+ self.mark_down = pn.pane.Markdown(
+ """
+
+
+ ### Instructions to follow :
+
+    - Check the Storenames to repeat checkbox and see the instructions in the GitHub wiki for duplicating storenames.
+    Otherwise, leave the Storenames to repeat checkbox unchecked.
+ - Select storenames from list and click “Select Storenames” to populate area below.
+ - Enter names for storenames, in order, using the following naming convention:
+ Isosbestic = “control_region” (ex: Dv1A= control_DMS)
+ Signal= “signal_region” (ex: Dv2A= signal_DMS)
+ TTLs can be named using any convention (ex: PrtR = RewardedPortEntries) but should be kept consistent for later group analysis
+
+ ```
+ {"storenames": ["Dv1A", "Dv2A",
+ "Dv3B", "Dv4B",
+ "LNRW", "LNnR",
+ "PrtN", "PrtR",
+ "RNPS"],
+ "names_for_storenames": ["control_DMS", "signal_DMS",
+ "control_DLS", "signal_DLS",
+ "RewardedNosepoke", "UnrewardedNosepoke",
+ "UnrewardedPort", "RewardedPort",
+ "InactiveNosepoke"]}
+ ```
+    - If storenames were saved before, clicking the "Select Storenames" button will pop up a dialog box
+    showing previously used names for storenames. Select names by checking a checkbox and
+    click "Show" to populate the text area in the Storenames GUI, then close the dialog box.
+
+ - Select “create new” or “overwrite” to generate a new storenames list or replace a previous one
+ - Click Save
+
+ """,
+ width=550,
+ )
+
+ self.widget = pn.Column("# " + os.path.basename(folder_path), self.mark_down)
+
+
+class StorenamesInstructionsNPM(StorenamesInstructions):
+ def __init__(self, folder_path):
+ super().__init__(folder_path=folder_path)
+ path_chev = glob.glob(os.path.join(folder_path, "*chev*"))
+ path_chod = glob.glob(os.path.join(folder_path, "*chod*"))
+ path_chpr = glob.glob(os.path.join(folder_path, "*chpr*"))
+ combine_paths = path_chev + path_chod + path_chpr
+ self.d = dict()
+ for i in range(len(combine_paths)):
+ basename = (os.path.basename(combine_paths[i])).split(".")[0]
+ df = pd.read_csv(combine_paths[i])
+ self.d[basename] = {"x": np.array(df["timestamps"]), "y": np.array(df["data"])}
+ keys = list(self.d.keys())
+ self.mark_down_np = pn.pane.Markdown(
+ """
+ ### Extra Instructions to follow when using Neurophotometrics data :
+ - Guppy will take the NPM data, which has interleaved frames
+ from the signal and control channels, and divide it out into
+ separate channels for each site you recorded.
+ However, since NPM does not automatically annotate which
+ frames belong to the signal channel and which belong to the
+ control channel, the user must specify this for GuPPy.
+ - Each of your recording sites will have a channel
+ named “chod” and a channel named “chev”
+ - View the plots below and, for each site,
+ determine whether the “chev” or “chod” channel is signal or control
+ - When you give your storenames, name the channels appropriately.
+ For example, “chev1” might be “signal_A” and
+ “chod1” might be “control_A” (or vice versa).
+
+ """
+ )
+ self.plot_select = pn.widgets.Select(
+ name="Select channel to see correspondings channels", options=keys, value=keys[0]
+ )
+ self.plot_pane = pn.pane.HoloViews(self._make_plot(self.plot_select.value), width=550)
+ self.plot_select.param.watch(self._on_plot_select_change, "value")
+
+ self.widget = pn.Column(
+ "# " + os.path.basename(folder_path),
+ self.mark_down,
+ self.mark_down_np,
+ self.plot_select,
+ self.plot_pane,
+ )
+
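+ # build a HoloViews curve of the selected channel's raw trace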
+ def _make_plot(self, plot_key):
+ return hv.Curve((self.d[plot_key]["x"], self.d[plot_key]["y"])).opts(width=550)
+
+ def _on_plot_select_change(self, event):
+ self.plot_pane.object = self._make_plot(event.new)
diff --git a/src/guppy/frontend/storenames_selector.py b/src/guppy/frontend/storenames_selector.py
new file mode 100644
index 0000000..97919d6
--- /dev/null
+++ b/src/guppy/frontend/storenames_selector.py
@@ -0,0 +1,152 @@
+import json
+import logging
+
+import panel as pn
+
+from .storenames_config import StorenamesConfig
+
+pn.extension()
+
+logger = logging.getLogger(__name__)
+
+
+class StorenamesSelector:
+
+ def __init__(self, allnames):
+ self.alert = pn.pane.Alert("#### No alerts !!", alert_type="danger", height=80, width=600)
+ if len(allnames) == 0:
+ self.alert.object = (
+ "####Alert !! \n No storenames found. There are not any TDT files or csv files to look for storenames."
+ )
+
+ # creating different buttons and selectors for the GUI
+ self.cross_selector = pn.widgets.CrossSelector(
+ name="Store Names Selection", value=[], options=allnames, width=600
+ )
+ self.multi_choice = pn.widgets.MultiChoice(
+ name="Select Storenames which you want more than once (multi-choice: multiple options selection)",
+ value=[],
+ options=allnames,
+ )
+
+ self.literal_input_1 = pn.widgets.LiteralInput(
+ name="Number of times you want the above storename (list)", value=[], type=list
+ )
+ # self.literal_input_2 = pn.widgets.LiteralInput(name='Names for Storenames (list)', type=list)
+
+ self.repeat_storenames = pn.widgets.Checkbox(name="Storenames to repeat", value=False)
+ self.repeat_storename_wd = pn.WidgetBox("", width=600)
+
+ self.repeat_storenames.link(self.repeat_storename_wd, callbacks={"value": self.callback})
+ # self.repeat_storename_wd = pn.WidgetBox('Storenames to repeat (leave blank if not needed)', multi_choice, literal_input_1, background="white", width=600)
+
+ self.update_options = pn.widgets.Button(name="Select Storenames", width=600)
+ self.save = pn.widgets.Button(name="Save", width=600)
+
+ self.text = pn.widgets.LiteralInput(value=[], name="Selected Store Names", type=list, width=600)
+
+ self.path = pn.widgets.TextInput(name="Location to Stores List file", width=600)
+
+ self.mark_down_for_overwrite = pn.pane.Markdown(
+ """ Select option from below if user wants to over-write a file or create a new file.
+ **Creating a new file will make a new output folder and will get saved at that location.**
+ If user selects to over-write a file **Select location of the file to over-write** will provide
+ the existing options of the output folders where user needs to over-write the file""",
+ width=600,
+ )
+
+ self.select_location = pn.widgets.Select(
+ name="Select location of the file to over-write", value="None", options=["None"], width=600
+ )
+
+ self.overwrite_button = pn.widgets.MenuButton(
+ name="over-write storeslist file or create a new one? ",
+ items=["over_write_file", "create_new_file"],
+ button_type="default",
+ split=True,
+ width=600,
+ )
+
+ self.literal_input_2 = pn.widgets.CodeEditor(
+ value="""{}""", theme="tomorrow", language="json", height=250, width=600
+ )
+
+ self.take_widgets = pn.WidgetBox(self.multi_choice, self.literal_input_1)
+
+ self.change_widgets = pn.WidgetBox(self.text)
+
+ # Panel-based storename configuration (replaces Tkinter dialog)
+ self.storename_config_widgets = pn.Column(visible=False)
+ self.show_config_button = pn.widgets.Button(name="Show Selected Configuration", width=600)
+
+ self.widget = pn.Column(
+ self.repeat_storenames,
+ self.repeat_storename_wd,
+ pn.Spacer(height=20),
+ self.cross_selector,
+ self.update_options,
+ self.storename_config_widgets,
+ pn.Spacer(height=10),
+ self.text,
+ self.literal_input_2,
+ self.alert,
+ self.mark_down_for_overwrite,
+ self.overwrite_button,
+ self.select_location,
+ self.save,
+ self.path,
+ )
+
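+ # show the repeat-selection widgets while the checkbox is ticked; clear them otherwise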
+ def callback(self, target, event):
+ if event.new:
+ target.objects = [self.multi_choice, self.literal_input_1]
+ else:
+ target.clear()
+
+ def get_select_location(self):
+ return self.select_location.value
+
+ def set_select_location_options(self, options):
+ self.select_location.options = options
+
+ def set_alert_message(self, message):
+ self.alert.object = message
+
+ def get_literal_input_2(self): # TODO: come up with a better name for this method.
+ d = json.loads(self.literal_input_2.value)
+ return d
+
+ def set_literal_input_2(self, d): # TODO: come up with a better name for this method.
+ self.literal_input_2.value = str(json.dumps(d, indent=2))
+
+ def get_take_widgets(self):
+ return [w.value for w in self.take_widgets]
+
+ def set_change_widgets(self, value):
+ for w in self.change_widgets:
+ w.value = value
+
+ def get_cross_selector(self):
+ return self.cross_selector.value
+
+ def set_path(self, value):
+ self.path.value = value
+
+ def attach_callbacks(self, button_name_to_onclick_fn: dict):
+ for button_name, onclick_fn in button_name_to_onclick_fn.items():
+ button = getattr(self, button_name)
+ button.on_click(onclick_fn)
+
+ def configure_storenames(self, storename_dropdowns, storename_textboxes, storenames, storenames_cache):
+ # Create Panel widgets for storename configuration
+ self.storenames_config = StorenamesConfig(
+ show_config_button=self.show_config_button,
+ storename_dropdowns=storename_dropdowns,
+ storename_textboxes=storename_textboxes,
+ storenames=storenames,
+ storenames_cache=storenames_cache,
+ )
+
+ # Update the configuration panel
+ self.storename_config_widgets.objects = self.storenames_config.config_widgets
+ self.storename_config_widgets.visible = len(storenames) > 0
diff --git a/src/guppy/frontend/visualization_dashboard.py b/src/guppy/frontend/visualization_dashboard.py
new file mode 100644
index 0000000..1444b86
--- /dev/null
+++ b/src/guppy/frontend/visualization_dashboard.py
@@ -0,0 +1,158 @@
+import logging
+
+import panel as pn
+
+from .frontend_utils import scanPortsAndFind
+
+pn.extension()
+
+logger = logging.getLogger(__name__)
+
+
+class VisualizationDashboard:
+ """Dashboard for interactive PSTH and heatmap visualization.
+
+ Wraps a ``Viewer`` instance with Panel widgets and a tabbed layout.
+ Data loading, preparation, and Viewer instantiation are handled
+ externally; this class is responsible for widget creation, layout
+ assembly, and serving the application.
+
+ Parameters
+ ----------
+ plotter : ParameterizedPlotter
+ A fully configured ParameterizedPlotter instance that provides reactive plot
+ methods and param-based controls.
+ basename : str
+ Session name displayed as the tab title.
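+
+ Examples
+ --------
+ A minimal usage sketch, assuming ``plotter`` has already been built::
+
+ dashboard = VisualizationDashboard(plotter=plotter, basename="session1")
+ dashboard.show()  # serves the PSTH and heatmap tabs on a free local port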
+ """
+
+ def __init__(self, *, plotter, basename):
+ self.plotter = plotter
+ self.basename = basename
+ self._psth_tab = self._build_psth_tab()
+ self._heatmap_tab = self._build_heatmap_tab()
+
+ def _build_psth_tab(self):
+ """Build the PSTH tab with controls and plot panels."""
+ psth_checkbox = pn.Param(
+ self.plotter.param.select_trials_checkbox,
+ widgets={
+ "select_trials_checkbox": {
+ "type": pn.widgets.CheckBoxGroup,
+ "inline": True,
+ "name": "Select mean and/or just trials",
+ }
+ },
+ )
+ parameters = pn.Param(
+ self.plotter.param.selector_for_multipe_events_plot,
+ widgets={
+ "selector_for_multipe_events_plot": {"type": pn.widgets.CrossSelector, "width": 550, "align": "start"}
+ },
+ )
+ psth_y_parameters = pn.Param(
+ self.plotter.param.psth_y,
+ widgets={
+ "psth_y": {
+ "type": pn.widgets.MultiSelect,
+ "name": "Trial # - Timestamps",
+ "width": 200,
+ "size": 15,
+ "align": "start",
+ }
+ },
+ )
+
+ event_selector = pn.Param(
+ self.plotter.param.event_selector, widgets={"event_selector": {"type": pn.widgets.Select, "width": 400}}
+ )
+ x_selector = pn.Param(self.plotter.param.x, widgets={"x": {"type": pn.widgets.Select, "width": 180}})
+ y_selector = pn.Param(self.plotter.param.y, widgets={"y": {"type": pn.widgets.Select, "width": 180}})
+
+ width_plot = pn.Param(
+ self.plotter.param.Width_Plot, widgets={"Width_Plot": {"type": pn.widgets.Select, "width": 70}}
+ )
+ height_plot = pn.Param(
+ self.plotter.param.Height_Plot, widgets={"Height_Plot": {"type": pn.widgets.Select, "width": 70}}
+ )
+ ylabel = pn.Param(self.plotter.param.Y_Label, widgets={"Y_Label": {"type": pn.widgets.Select, "width": 70}})
+ save_opts = pn.Param(
+ self.plotter.param.save_options, widgets={"save_options": {"type": pn.widgets.Select, "width": 70}}
+ )
+
+ xlimit_plot = pn.Param(
+ self.plotter.param.X_Limit, widgets={"X_Limit": {"type": pn.widgets.RangeSlider, "width": 180}}
+ )
+ ylimit_plot = pn.Param(
+ self.plotter.param.Y_Limit, widgets={"Y_Limit": {"type": pn.widgets.RangeSlider, "width": 180}}
+ )
+ save_psth = pn.Param(
+ self.plotter.param.save_psth, widgets={"save_psth": {"type": pn.widgets.Button, "width": 400}}
+ )
+
+ options = pn.Column(
+ event_selector,
+ pn.Row(x_selector, y_selector),
+ pn.Row(xlimit_plot, ylimit_plot),
+ pn.Row(width_plot, height_plot, ylabel, save_opts),
+ save_psth,
+ )
+
+ options_selectors = pn.Row(options, parameters)
+
+ return pn.Column(
+ "## " + self.basename,
+ pn.Row(options_selectors, pn.Column(psth_checkbox, psth_y_parameters), width=1200),
+ self.plotter.contPlot,
+ self.plotter.update_selector,
+ self.plotter.plot_specific_trials,
+ )
+
+ def _build_heatmap_tab(self):
+ """Build the heatmap tab with controls and plot panels."""
+ heatmap_y_parameters = pn.Param(
+ self.plotter.param.heatmap_y,
+ widgets={
+ "heatmap_y": {"type": pn.widgets.MultiSelect, "name": "Trial # - Timestamps", "width": 200, "size": 30}
+ },
+ )
+ event_selector_heatmap = pn.Param(
+ self.plotter.param.event_selector_heatmap,
+ widgets={"event_selector_heatmap": {"type": pn.widgets.Select, "width": 150}},
+ )
+ color_map = pn.Param(
+ self.plotter.param.color_map, widgets={"color_map": {"type": pn.widgets.Select, "width": 150}}
+ )
+ width_heatmap = pn.Param(
+ self.plotter.param.width_heatmap, widgets={"width_heatmap": {"type": pn.widgets.Select, "width": 150}}
+ )
+ height_heatmap = pn.Param(
+ self.plotter.param.height_heatmap, widgets={"height_heatmap": {"type": pn.widgets.Select, "width": 150}}
+ )
+ save_hm = pn.Param(self.plotter.param.save_hm, widgets={"save_hm": {"type": pn.widgets.Button, "width": 150}})
+ save_options_heatmap = pn.Param(
+ self.plotter.param.save_options_heatmap,
+ widgets={"save_options_heatmap": {"type": pn.widgets.Select, "width": 150}},
+ )
+
+ return pn.Column(
+ "## " + self.basename,
+ pn.Row(
+ event_selector_heatmap,
+ color_map,
+ width_heatmap,
+ height_heatmap,
+ save_options_heatmap,
+ pn.Column(pn.Spacer(height=25), save_hm),
+ ),
+ pn.Row(self.plotter.heatmap, heatmap_y_parameters),
+ )
+
+ def show(self):
+ """Serve the dashboard in a browser on an available port."""
+ logger.info("app")
+ template = pn.template.MaterialTemplate(title="Visualization GUI")
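+ # scanPortsAndFind picks a random free port between 5000 and 5200, so several GuPPy windows can run side by side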
+ number = scanPortsAndFind(start_port=5000, end_port=5200)
+ app = pn.Tabs(("PSTH", self._psth_tab), ("Heat Map", self._heatmap_tab))
+ template.main.append(app)
+ template.show(port=number)
diff --git a/src/guppy/main.py b/src/guppy/main.py
index 478e991..cb5ebc8 100644
--- a/src/guppy/main.py
+++ b/src/guppy/main.py
@@ -11,12 +11,12 @@
import panel as pn
-from .savingInputParameters import savingInputParameters
+from .orchestration.home import build_homepage
def serve_app():
"""Serve the GuPPy application using Panel."""
- template = savingInputParameters()
+ template = build_homepage()
pn.serve(template, show=True)
diff --git a/src/guppy/orchestration/home.py b/src/guppy/orchestration/home.py
new file mode 100644
index 0000000..9b7e6c5
--- /dev/null
+++ b/src/guppy/orchestration/home.py
@@ -0,0 +1,101 @@
+import json
+import logging
+import os
+import subprocess
+import sys
+from threading import Thread
+
+import panel as pn
+
+from .save_parameters import save_parameters
+from .storenames import orchestrate_storenames_page
+from .visualize import visualizeResults
+from ..frontend.input_parameters import ParameterForm
+from ..frontend.path_selection import get_folder_path
+from ..frontend.progress import readPBIncrementValues
+from ..frontend.sidebar import Sidebar
+
+logger = logging.getLogger(__name__)
+
+
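+# Each stage below runs as a separate Python subprocess (via -m) with the input
+# parameters serialized to JSON, keeping the Panel GUI responsive; each stage
+# reports progress that readPBIncrementValues feeds into the sidebar progress bars.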
+def readRawData(parameter_form):
+ inputParameters = parameter_form.getInputParameters()
+ subprocess.call([sys.executable, "-m", "guppy.orchestration.read_raw_data", json.dumps(inputParameters)])
+
+
+def preprocess(parameter_form):
+ inputParameters = parameter_form.getInputParameters()
+ subprocess.call([sys.executable, "-m", "guppy.orchestration.preprocess", json.dumps(inputParameters)])
+
+
+def psthComputation(parameter_form, current_dir):
+ inputParameters = parameter_form.getInputParameters()
+ inputParameters["curr_dir"] = current_dir
+ subprocess.call([sys.executable, "-m", "guppy.orchestration.psth", json.dumps(inputParameters)])
+
+
+def build_homepage():
+ pn.extension()
+ folder_path = get_folder_path()
+ current_dir = os.getcwd()
+
+ template = pn.template.BootstrapTemplate(title="Input Parameters GUI")
+ parameter_form = ParameterForm(folder_path=folder_path, template=template)
+ sidebar = Sidebar(template=template)
+
+ # ------------------------------------------------------------------------------------------------------------------
+ # onclick closure functions for sidebar buttons
+ def onclickProcess(event=None):
+ inputParameters = parameter_form.getInputParameters()
+ save_parameters(inputParameters=inputParameters)
+
+ def onclickStorenames(event=None):
+ inputParameters = parameter_form.getInputParameters()
+ orchestrate_storenames_page(inputParameters)
+
+ def onclickVisualization(event=None):
+ inputParameters = parameter_form.getInputParameters()
+ visualizeResults(inputParameters)
+
+ def onclickreaddata(event=None):
+ thread = Thread(target=readRawData, args=(parameter_form,))
+ thread.start()
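+ # update the sidebar progress bar from the subprocess's reported progress until it finishes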
+ readPBIncrementValues(sidebar.read_progress)
+ thread.join()
+
+ def onclickpreprocess(event=None):
+ thread = Thread(target=preprocess, args=(parameter_form,))
+ thread.start()
+ readPBIncrementValues(sidebar.extract_progress)
+ thread.join()
+
+ def onclickpsth(event=None):
+ thread = Thread(target=psthComputation, args=(parameter_form, current_dir))
+ thread.start()
+ readPBIncrementValues(sidebar.psth_progress)
+ thread.join()
+
+ # ------------------------------------------------------------------------------------------------------------------
+
+ button_name_to_onclick_fn = {
+ "save_button": onclickProcess,
+ "open_storenames": onclickStorenames,
+ "read_rawData": onclickreaddata,
+ "preprocess": onclickpreprocess,
+ "psth_computation": onclickpsth,
+ "open_visualization": onclickVisualization,
+ }
+ sidebar.attach_callbacks(button_name_to_onclick_fn=button_name_to_onclick_fn)
+ sidebar.add_to_template()
+
+ # Expose minimal hooks and widgets to enable programmatic testing
+ template._hooks = {
+ "onclickProcess": onclickProcess,
+ "getInputParameters": parameter_form.getInputParameters,
+ }
+ template._widgets = {
+ "files_1": parameter_form.files_1,
+ }
+
+ return template
diff --git a/src/guppy/preprocess.py b/src/guppy/orchestration/preprocess.py
similarity index 76%
rename from src/guppy/preprocess.py
rename to src/guppy/orchestration/preprocess.py
index e4812a2..b5b9d0d 100755
--- a/src/guppy/preprocess.py
+++ b/src/guppy/orchestration/preprocess.py
@@ -3,23 +3,23 @@
import logging
import os
import sys
+from typing import Literal
import matplotlib.pyplot as plt
import numpy as np
-from .analysis.artifact_removal import remove_artifacts
-from .analysis.combine_data import combine_data
-from .analysis.control_channel import add_control_channel, create_control_channel
-from .analysis.io_utils import (
+from ..analysis.artifact_removal import remove_artifacts
+from ..analysis.combine_data import combine_data
+from ..analysis.control_channel import add_control_channel, create_control_channel
+from ..analysis.io_utils import (
check_storeslistfile,
check_TDT,
find_files,
- get_all_stores_for_combining_data, # noqa: F401 -- Necessary for other modules that depend on preprocess.py
get_coords,
read_hdf5,
takeOnlyDirs,
)
-from .analysis.standard_io import (
+from ..analysis.standard_io import (
read_control_and_signal,
read_coords_pairwise,
read_corrected_data,
@@ -37,8 +37,12 @@
write_corrected_ttl_timestamps,
write_zscore,
)
-from .analysis.timestamp_correction import correct_timestamps
-from .analysis.z_score import compute_z_score
+from ..analysis.timestamp_correction import correct_timestamps
+from ..analysis.z_score import compute_z_score
+from ..frontend.artifact_removal import ArtifactRemovalWidget
+from ..frontend.progress import writeToFile
+from ..utils.utils import get_all_stores_for_combining_data
+from ..visualization.preprocessing import visualize_preprocessing
logger = logging.getLogger(__name__)
@@ -47,17 +51,10 @@
plt.switch_backend("TKAgg")
-def writeToFile(value: str):
- with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file:
- file.write(value)
-
-
-# function to plot z_score
-def visualize_z_score(filepath):
-
+def execute_preprocessing_visualization(filepath, visualization_type: Literal["z_score", "dff"]):
name = os.path.basename(filepath)
- path = glob.glob(os.path.join(filepath, "z_score_*"))
+ path = glob.glob(os.path.join(filepath, f"{visualization_type}_*"))
path = sorted(path)
@@ -66,122 +63,7 @@ def visualize_z_score(filepath):
name_1 = basename.split("_")[-1]
x = read_hdf5("timeCorrection_" + name_1, filepath, "timestampNew")
y = read_hdf5("", path[i], "data")
- fig = plt.figure()
- ax = fig.add_subplot(111)
- ax.plot(x, y)
- ax.set_title(basename)
- fig.suptitle(name)
- # plt.show()
-
-
-# function to plot deltaF/F
-def visualize_dff(filepath):
- name = os.path.basename(filepath)
-
- path = glob.glob(os.path.join(filepath, "dff_*"))
-
- path = sorted(path)
-
- for i in range(len(path)):
- basename = (os.path.basename(path[i])).split(".")[0]
- name_1 = basename.split("_")[-1]
- x = read_hdf5("timeCorrection_" + name_1, filepath, "timestampNew")
- y = read_hdf5("", path[i], "data")
- fig = plt.figure()
- ax = fig.add_subplot(111)
- ax.plot(x, y)
- ax.set_title(basename)
- fig.suptitle(name)
- # plt.show()
-
-
-def visualize(filepath, x, y1, y2, y3, plot_name, removeArtifacts):
-
- # plotting control and signal data
-
- if (y1 == 0).all() == True:
- y1 = np.zeros(x.shape[0])
-
- coords_path = os.path.join(filepath, "coordsForPreProcessing_" + plot_name[0].split("_")[-1] + ".npy")
- name = os.path.basename(filepath)
- fig = plt.figure()
- ax1 = fig.add_subplot(311)
- (line1,) = ax1.plot(x, y1)
- ax1.set_title(plot_name[0])
- ax2 = fig.add_subplot(312)
- (line2,) = ax2.plot(x, y2)
- ax2.set_title(plot_name[1])
- ax3 = fig.add_subplot(313)
- (line3,) = ax3.plot(x, y2)
- (line3,) = ax3.plot(x, y3)
- ax3.set_title(plot_name[2])
- fig.suptitle(name)
-
- hfont = {"fontname": "DejaVu Sans"}
-
- if removeArtifacts == True and os.path.exists(coords_path):
- ax3.set_xlabel("Time(s) \n Note : Artifacts have been removed, but are not reflected in this plot.", **hfont)
- else:
- ax3.set_xlabel("Time(s)", **hfont)
-
- global coords
- coords = []
-
- # clicking 'space' key on keyboard will draw a line on the plot so that user can see what chunks are selected
- # and clicking 'd' key on keyboard will deselect the selected point
- def onclick(event):
- # global ix, iy
-
- if event.key == " ":
- ix, iy = event.xdata, event.ydata
- logger.info(f"x = {ix}, y = {iy}")
- y1_max, y1_min = np.amax(y1), np.amin(y1)
- y2_max, y2_min = np.amax(y2), np.amin(y2)
-
- # ax1.plot([ix,ix], [y1_max, y1_min], 'k--')
- # ax2.plot([ix,ix], [y2_max, y2_min], 'k--')
-
- ax1.axvline(ix, c="black", ls="--")
- ax2.axvline(ix, c="black", ls="--")
- ax3.axvline(ix, c="black", ls="--")
-
- fig.canvas.draw()
-
- global coords
- coords.append((ix, iy))
-
- # if len(coords) == 2:
- # fig.canvas.mpl_disconnect(cid)
-
- return coords
-
- elif event.key == "d":
- if len(coords) > 0:
- logger.info(f"x = {coords[-1][0]}, y = {coords[-1][1]}; deleted")
- del coords[-1]
- ax1.lines[-1].remove()
- ax2.lines[-1].remove()
- ax3.lines[-1].remove()
- fig.canvas.draw()
-
- return coords
-
- # close the plot will save coordinates for all the selected chunks in the data
- def plt_close_event(event):
- global coords
- if coords and len(coords) > 0:
- name_1 = plot_name[0].split("_")[-1]
- np.save(os.path.join(filepath, "coordsForPreProcessing_" + name_1 + ".npy"), coords)
- logger.info(f"Coordinates file saved at {os.path.join(filepath, 'coordsForPreProcessing_'+name_1+'.npy')}")
- fig.canvas.mpl_disconnect(cid)
- coords = []
-
- cid = fig.canvas.mpl_connect("key_press_event", onclick)
- cid = fig.canvas.mpl_connect("close_event", plt_close_event)
- # multi = MultiCursor(fig.canvas, (ax1, ax2), color='g', lw=1, horizOn=False, vertOn=True)
-
- # plt.show()
- # return fig
+ fig, ax = visualize_preprocessing(suptitle=name, title=basename, x=x, y=y)
# function to plot control and signal, also provide a feature to select chunks for artifacts removal
@@ -198,6 +80,7 @@ def visualizeControlAndSignal(filepath, removeArtifacts):
path = np.asarray(path).reshape(2, -1)
+ widgets = []
for i in range(path.shape[1]):
name_1 = ((os.path.basename(path[0, i])).split(".")[0]).split("_")
@@ -216,7 +99,9 @@ def visualizeControlAndSignal(filepath, removeArtifacts):
(os.path.basename(path[1, i])).split(".")[0],
(os.path.basename(cntrl_sig_fit_path)).split(".")[0],
]
- visualize(filepath, ts, control, signal, cntrl_sig_fit, plot_name, removeArtifacts)
+ widget = ArtifactRemovalWidget(filepath, ts, control, signal, cntrl_sig_fit, plot_name, removeArtifacts)
+ widgets.append(widget)
+ return widgets
# function to execute timestamps corrections using functions timestampCorrection and decide_naming_convention_and_applyCorrection
@@ -336,23 +221,45 @@ def execute_zscore(folderNames, inputParameters):
write_zscore(filepath, name, z_score, dff, control_fit, temp_control_arr)
logger.info(f"z-score for the data in {filepath} computed.")
+ writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n")
+ inputParameters["step"] += 1
+
+ plt.show()
+ logger.info("Z-score computation completed.")
+
+
+def visualize_z_score(inputParameters, folderNames):
+ plot_zScore_dff = inputParameters["plot_zScore_dff"]
+ combine_data = inputParameters["combine_data"]
+ remove_artifacts = inputParameters["removeArtifacts"]
+
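+ # collect every *_output_* folder (for combined data, only the first folder of each group) whose traces will be plotted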
+ storesListPath = []
+ for i in range(len(folderNames)):
+ if combine_data == True:
+ storesListPath.append([folderNames[i][0]])
+ else:
+ filepath = folderNames[i]
+ storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))))
+ storesListPath = np.concatenate(storesListPath)
+
+ widgets = []
+ for j in range(len(storesListPath)):
+ filepath = storesListPath[j]
if not remove_artifacts:
- visualizeControlAndSignal(filepath, removeArtifacts=remove_artifacts)
+ # a reference to the widgets must persist in the scope where plt.show() is called
+ widgets.extend(visualizeControlAndSignal(filepath, removeArtifacts=remove_artifacts))
if plot_zScore_dff == "z_score":
- visualize_z_score(filepath)
+ execute_preprocessing_visualization(filepath, visualization_type="z_score")
if plot_zScore_dff == "dff":
- visualize_dff(filepath)
+ execute_preprocessing_visualization(filepath, visualization_type="dff")
if plot_zScore_dff == "Both":
- visualize_z_score(filepath)
- visualize_dff(filepath)
-
- writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n")
- inputParameters["step"] += 1
+ execute_preprocessing_visualization(filepath, visualization_type="z_score")
+ execute_preprocessing_visualization(filepath, visualization_type="dff")
plt.show()
- logger.info("Z-score computation completed.")
+ logger.info("Visualization of z-score and dF/F completed.")
# function to remove artifacts from z-score data
@@ -394,15 +301,34 @@ def execute_artifact_removal(folderNames, inputParameters):
)
write_artifact_removal(filepath, name_to_data, pair_name_to_timestamps, compound_name_to_ttl_timestamps)
- visualizeControlAndSignal(filepath, removeArtifacts=True)
writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n")
inputParameters["step"] += 1
- plt.show()
+ visualize_artifact_removal(folderNames, inputParameters)
logger.info("Artifact removal completed.")
+def visualize_artifact_removal(folderNames, inputParameters):
+ combine_data = inputParameters["combine_data"]
+
+ storesListPath = []
+ for i in range(len(folderNames)):
+ if combine_data == True:
+ storesListPath.append([folderNames[i][0]])
+ else:
+ filepath = folderNames[i]
+ storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))))
+
+ storesListPath = np.concatenate(storesListPath)
+
+ for j in range(len(storesListPath)):
+ filepath = storesListPath[j]
+ visualizeControlAndSignal(filepath, removeArtifacts=True)
+ plt.show()
+ logger.info("Visualization of artifact removal completed.")
+
+
# function to combine data when there are two different data files for the same recording session
# it will combine the data, do timestamps processing and save the combined data in the first output folder.
def execute_combine_data(folderNames, inputParameters, storesList):
@@ -490,6 +416,7 @@ def extractTsAndSignal(inputParameters):
writeToFile(str((pbMaxValue + 1) * 10) + "\n" + str(10) + "\n")
execute_timestamp_correction(folderNames, inputParameters)
execute_zscore(folderNames, inputParameters)
+ visualize_z_score(inputParameters, folderNames)
if remove_artifacts == True:
execute_artifact_removal(folderNames, inputParameters)
else:
@@ -499,6 +426,7 @@ def extractTsAndSignal(inputParameters):
storesList = check_storeslistfile(folderNames)
op_folder = execute_combine_data(folderNames, inputParameters, storesList)
execute_zscore(op_folder, inputParameters)
+ visualize_z_score(inputParameters, op_folder)
if remove_artifacts == True:
execute_artifact_removal(op_folder, inputParameters)
diff --git a/src/guppy/computePsth.py b/src/guppy/orchestration/psth.py
similarity index 94%
rename from src/guppy/computePsth.py
rename to src/guppy/orchestration/psth.py
index 32d9be1..e1d78de 100755
--- a/src/guppy/computePsth.py
+++ b/src/guppy/orchestration/psth.py
@@ -13,44 +13,31 @@
import numpy as np
from scipy import signal as ss
-from .analysis.compute_psth import compute_psth
-from .analysis.cross_correlation import compute_cross_correlation
-from .analysis.io_utils import (
- get_all_stores_for_combining_data,
+from ..analysis.compute_psth import compute_psth
+from ..analysis.cross_correlation import compute_cross_correlation
+from ..analysis.io_utils import (
make_dir_for_cross_correlation,
makeAverageDir,
- read_Df,
read_hdf5,
write_hdf5,
)
-from .analysis.psth_average import averageForGroup
-from .analysis.psth_peak_and_area import compute_psth_peak_and_area
-from .analysis.psth_utils import (
+from ..analysis.psth_average import averageForGroup
+from ..analysis.psth_peak_and_area import compute_psth_peak_and_area
+from ..analysis.psth_utils import (
create_Df_for_cross_correlation,
create_Df_for_psth,
getCorrCombinations,
)
-from .analysis.standard_io import (
+from ..analysis.standard_io import (
write_peak_and_area_to_csv,
write_peak_and_area_to_hdf5,
)
+from ..frontend.progress import writeToFile
+from ..utils.utils import get_all_stores_for_combining_data, read_Df, takeOnlyDirs
logger = logging.getLogger(__name__)
-def takeOnlyDirs(paths):
- removePaths = []
- for p in paths:
- if os.path.isfile(p):
- removePaths.append(p)
- return list(set(paths) - set(removePaths))
-
-
-def writeToFile(value: str):
- with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file:
- file.write(value)
-
-
# function to create PSTH for each event using function helper_psth and save the PSTH to h5 file
def execute_compute_psth(filepath, event, inputParameters):
@@ -358,7 +345,7 @@ def psthForEachStorename(inputParameters):
def main(input_parameters):
try:
inputParameters = psthForEachStorename(input_parameters)
- subprocess.call([sys.executable, "-m", "guppy.findTransientsFreqAndAmp", json.dumps(inputParameters)])
+ subprocess.call([sys.executable, "-m", "guppy.orchestration.transients", json.dumps(inputParameters)])
logger.info("#" * 400)
except Exception as e:
with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file:
diff --git a/src/guppy/readTevTsq.py b/src/guppy/orchestration/read_raw_data.py
similarity index 90%
rename from src/guppy/readTevTsq.py
rename to src/guppy/orchestration/read_raw_data.py
index 19a0a4a..cd82634 100755
--- a/src/guppy/readTevTsq.py
+++ b/src/guppy/orchestration/read_raw_data.py
@@ -14,25 +14,14 @@
TdtRecordingExtractor,
read_and_save_all_events,
)
+from guppy.frontend.progress import writeToFile
+from guppy.utils.utils import takeOnlyDirs
logger = logging.getLogger(__name__)
-def takeOnlyDirs(paths):
- removePaths = []
- for p in paths:
- if os.path.isfile(p):
- removePaths.append(p)
- return list(set(paths) - set(removePaths))
-
-
-def writeToFile(value: str):
- with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file:
- file.write(value)
-
-
# function to read data from 'tsq' and 'tev' files
-def readRawData(inputParameters):
+def orchestrate_read_raw_data(inputParameters):
logger.debug("### Reading raw data... ###")
# get input parameters
@@ -99,7 +88,7 @@ def readRawData(inputParameters):
def main(input_parameters):
logger.info("run")
try:
- readRawData(input_parameters)
+ orchestrate_read_raw_data(input_parameters)
logger.info("#" * 400)
except Exception as e:
with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file:
diff --git a/src/guppy/orchestration/save_parameters.py b/src/guppy/orchestration/save_parameters.py
new file mode 100644
index 0000000..3af5e2c
--- /dev/null
+++ b/src/guppy/orchestration/save_parameters.py
@@ -0,0 +1,41 @@
+import json
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+def save_parameters(inputParameters: dict):
+ logger.debug("Saving Input Parameters file.")
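+ # persist only the analysis-relevant parameters; transient GUI-only state is left out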
+ analysisParameters = {
+ "combine_data": inputParameters["combine_data"],
+ "isosbestic_control": inputParameters["isosbestic_control"],
+ "timeForLightsTurnOn": inputParameters["timeForLightsTurnOn"],
+ "filter_window": inputParameters["filter_window"],
+ "removeArtifacts": inputParameters["removeArtifacts"],
+ "noChannels": inputParameters["noChannels"],
+ "zscore_method": inputParameters["zscore_method"],
+ "baselineWindowStart": inputParameters["baselineWindowStart"],
+ "baselineWindowEnd": inputParameters["baselineWindowEnd"],
+ "nSecPrev": inputParameters["nSecPrev"],
+ "nSecPost": inputParameters["nSecPost"],
+ "timeInterval": inputParameters["timeInterval"],
+ "bin_psth_trials": inputParameters["bin_psth_trials"],
+ "use_time_or_trials": inputParameters["use_time_or_trials"],
+ "baselineCorrectionStart": inputParameters["baselineCorrectionStart"],
+ "baselineCorrectionEnd": inputParameters["baselineCorrectionEnd"],
+ "peak_startPoint": inputParameters["peak_startPoint"],
+ "peak_endPoint": inputParameters["peak_endPoint"],
+ "selectForComputePsth": inputParameters["selectForComputePsth"],
+ "selectForTransientsComputation": inputParameters["selectForTransientsComputation"],
+ "moving_window": inputParameters["moving_window"],
+ "highAmpFilt": inputParameters["highAmpFilt"],
+ "transientsThresh": inputParameters["transientsThresh"],
+ }
+ for folder in inputParameters["folderNames"]:
+ with open(os.path.join(folder, "GuPPyParamtersUsed.json"), "w") as f:
+ json.dump(analysisParameters, f, indent=4)
+ logger.info(f"Input Parameters file saved at {folder}")
+
+ logger.info("#" * 400)
+ logger.info("Input Parameters File Saved.")
diff --git a/src/guppy/orchestration/storenames.py b/src/guppy/orchestration/storenames.py
new file mode 100755
index 0000000..4a015d9
--- /dev/null
+++ b/src/guppy/orchestration/storenames.py
@@ -0,0 +1,330 @@
+import glob
+import json
+import logging
+import os
+from pathlib import Path
+
+import holoviews as hv # noqa: F401
+import numpy as np
+import panel as pn
+
+from guppy.extractors import (
+ CsvRecordingExtractor,
+ DoricRecordingExtractor,
+ NpmRecordingExtractor,
+ TdtRecordingExtractor,
+)
+from guppy.frontend.frontend_utils import scanPortsAndFind
+from guppy.frontend.npm_gui_prompts import (
+ get_multi_event_responses,
+ get_timestamp_configuration,
+)
+from guppy.frontend.storenames_instructions import (
+ StorenamesInstructions,
+ StorenamesInstructionsNPM,
+)
+from guppy.frontend.storenames_selector import StorenamesSelector
+from guppy.utils.utils import takeOnlyDirs
+
+pn.extension()
+
+logger = logging.getLogger(__name__)
+
+
+# function to show location for over-writing or creating a new stores list file.
+def show_dir(filepath):
+ i = 1
+ while True:
+ basename = os.path.basename(filepath)
+ op = os.path.join(filepath, basename + "_output_" + str(i))
+ if not os.path.exists(op):
+ break
+ i += 1
+ return op
+
+
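+# Unlike show_dir, which only previews the next free output path
+# (e.g. <folder>/<basename>_output_1), make_dir also creates that directory on disk.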
+def make_dir(filepath):
+ i = 1
+ while True:
+ basename = os.path.basename(filepath)
+ op = os.path.join(filepath, basename + "_output_" + str(i))
+ if not os.path.exists(op):
+ os.mkdir(op)
+ break
+ i += 1
+
+ return op
+
+
+def _fetchValues(text, storenames, storename_dropdowns, storename_textboxes, d):
+ if not storename_dropdowns or len(storenames) == 0:
+ return "#### Alert !! \n No storenames selected."
+
+ storenames_cache = dict()
+ if os.path.exists(os.path.join(Path.home(), ".storesList.json")):
+ with open(os.path.join(Path.home(), ".storesList.json")) as f:
+ storenames_cache = json.load(f)
+
+ comboBoxValues, textBoxValues = [], []
+ dropdown_keys = list(storename_dropdowns.keys())
+ textbox_keys = list(storename_textboxes.keys()) if storename_textboxes else []
+
+ # Get dropdown values
+ for key in dropdown_keys:
+ comboBoxValues.append(storename_dropdowns[key].value)
+
+ # Get textbox values (matching with dropdown keys)
+ for key in dropdown_keys:
+ if key in storename_textboxes:
+ textbox_value = storename_textboxes[key].value or ""
+ textBoxValues.append(textbox_value)
+
+ # Validation: Check for whitespace
+ if len(textbox_value.split()) > 1:
+ return "####Alert !! \n Whitespace is not allowed in the text box entry."
+
+ # Validation: Check for empty required fields
+ dropdown_value = storename_dropdowns[key].value
+ if (
+ not textbox_value
+ and dropdown_value not in storenames_cache
+ and dropdown_value in ["control", "signal", "event TTLs"]
+ ):
+ return "####Alert !! \n One of the text box entry is empty."
+ else:
+ # For cached values, use the dropdown value directly
+ textBoxValues.append(storename_dropdowns[key].value)
+
+ if len(comboBoxValues) != len(textBoxValues):
+ return "####Alert !! \n Number of entries in combo box and text box should be same."
+
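+ # e.g. dropdown "control" + textbox "DMS" -> "control_DMS"; event TTLs keep the entered name as-is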
+ names_for_storenames = []
+ for i in range(len(comboBoxValues)):
+ if comboBoxValues[i] == "control" or comboBoxValues[i] == "signal":
+ if "_" in textBoxValues[i]:
+ return "####Alert !! \n Please do not use underscore in region name."
+ names_for_storenames.append("{}_{}".format(comboBoxValues[i], textBoxValues[i]))
+ elif comboBoxValues[i] == "event TTLs":
+ names_for_storenames.append(textBoxValues[i])
+ else:
+ names_for_storenames.append(comboBoxValues[i])
+
+ d["storenames"] = text.value
+ d["names_for_storenames"] = names_for_storenames
+ return "#### No alerts !!"
+
+
+def _save(d, select_location):
+ arr1, arr2 = np.asarray(d["storenames"]), np.asarray(d["names_for_storenames"])
+
+ if np.where(arr2 == "")[0].size > 0:
+ alert_message = "#### Alert !! \n Empty string in the list names_for_storenames."
+ logger.error("Empty string in the list names_for_storenames.")
+ return alert_message
+
+ if arr1.shape[0] != arr2.shape[0]:
+ alert_message = "#### Alert !! \n Length of list storenames and names_for_storenames is not equal."
+ logger.error("Length of list storenames and names_for_storenames is not equal.")
+ return alert_message
+
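+ # ~/.storesList.json caches, per raw store name, every GuPPy name previously assigned to it, e.g. {"Dv1A": ["control_DMS"]}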
+ cache_path = os.path.join(Path.home(), ".storesList.json")
+ storenames_cache = dict()
+ if os.path.exists(cache_path):
+ with open(cache_path) as f:
+ storenames_cache = json.load(f)
+
+ for i in range(arr1.shape[0]):
+ if arr1[i] in storenames_cache:
+ storenames_cache[arr1[i]].append(arr2[i])
+ storenames_cache[arr1[i]] = list(set(storenames_cache[arr1[i]]))
+ else:
+ storenames_cache[arr1[i]] = [arr2[i]]
+
+ with open(cache_path, "w") as f:
+ json.dump(storenames_cache, f, indent=4)
+
+ arr = np.asarray([arr1, arr2])
+ logger.info(arr)
+ if not os.path.exists(select_location):
+ os.mkdir(select_location)
+
+ np.savetxt(os.path.join(select_location, "storesList.csv"), arr, delimiter=",", fmt="%s")
+ logger.info(f"Storeslist file saved at {select_location}")
+ logger.info("Storeslist : \n" + str(arr))
+ return "#### No alerts !!"
+
+
+# function to build the Storenames GUI and save the stores list file
+def build_storenames_page(inputParameters, events, flags, folder_path):
+
+ logger.debug("Saving stores list file.")
+
+ # Headless path: if storenames_map provided, write storesList.csv without building the Panel UI
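+ # e.g. storenames_map = {"Dv1A": "control_DMS", "PrtR": "RewardedPort"} (raw store name -> GuPPy name)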
+ storenames_map = inputParameters.get("storenames_map")
+ if isinstance(storenames_map, dict) and len(storenames_map) > 0:
+ op = make_dir(folder_path)
+ arr = np.asarray([list(storenames_map.keys()), list(storenames_map.values())], dtype=str)
+ np.savetxt(os.path.join(op, "storesList.csv"), arr, delimiter=",", fmt="%s")
+ logger.info(f"Storeslist file saved at {op}")
+ logger.info("Storeslist : \n" + str(arr))
+ return
+
+ # Get storenames from extractor's events property
+ allnames = events
+
+ # creating GUI template
+ template = pn.template.BootstrapTemplate(title="Storenames GUI - {}".format(os.path.basename(folder_path)))
+
+ if "data_np_v2" in flags or "data_np" in flags or "event_np" in flags:
+ storenames_instructions = StorenamesInstructionsNPM(folder_path=folder_path)
+ else:
+ storenames_instructions = StorenamesInstructions(folder_path=folder_path)
+ storenames_selector = StorenamesSelector(allnames=allnames)
+
+ storenames = []
+ storename_dropdowns = {}
+ storename_textboxes = {}
+
+ # ------------------------------------------------------------------------------------------------------------------
+ # onclick closure functions
+ # on clicking overwrite_button, following function is executed
+ def overwrite_button_actions(event):
+ if event.new == "over_write_file":
+ options = takeOnlyDirs(glob.glob(os.path.join(folder_path, "*_output_*")))
+ storenames_selector.set_select_location_options(options=options)
+ else:
+ options = [show_dir(folder_path)]
+ storenames_selector.set_select_location_options(options=options)
+
+ def fetchValues(event):
+ d = dict()
+ alert_message = _fetchValues(
+ text=storenames_selector.text,
+ storenames=storenames,
+ storename_dropdowns=storename_dropdowns,
+ storename_textboxes=storename_textboxes,
+ d=d,
+ )
+ storenames_selector.set_alert_message(alert_message)
+ storenames_selector.set_literal_input_2(d=d)
+
+ # on clicking 'Select Storenames' button, following function is executed
+ def update_values(event):
+ nonlocal storenames
+
+ arr = storenames_selector.get_take_widgets()
+ new_arr = []
+ for i in range(len(arr[1])):
+ for j in range(arr[1][i]):
+ new_arr.append(arr[0][i])
+ if len(new_arr) > 0:
+ storenames = storenames_selector.get_cross_selector() + new_arr
+ else:
+ storenames = storenames_selector.get_cross_selector()
+ storenames_selector.set_change_widgets(storenames)
+
+ storenames_cache = dict()
+ if os.path.exists(os.path.join(Path.home(), ".storesList.json")):
+ with open(os.path.join(Path.home(), ".storesList.json")) as f:
+ storenames_cache = json.load(f)
+
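+ # rebuild the Panel-based configuration rows (the replacement for the old Tkinter dialog) using the cached names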
+ storenames_selector.configure_storenames(
+ storename_dropdowns=storename_dropdowns,
+ storename_textboxes=storename_textboxes,
+ storenames=storenames,
+ storenames_cache=storenames_cache,
+ )
+
+ # on clicking save button, following function is executed
+ def save_button(event=None):
+ d = storenames_selector.get_literal_input_2()
+ select_location = storenames_selector.get_select_location()
+ alert_message = _save(d=d, select_location=select_location)
+ storenames_selector.set_alert_message(alert_message)
+ storenames_selector.set_path(os.path.join(select_location, "storesList.csv"))
+
+ # ------------------------------------------------------------------------------------------------------------------
+
+ # Connect button callbacks
+ button_name_to_onclick_fn = {
+ "update_options": update_values,
+ "save": save_button,
+ "overwrite_button": overwrite_button_actions,
+ "show_config_button": fetchValues,
+ }
+ storenames_selector.attach_callbacks(button_name_to_onclick_fn)
+
+ template.main.append(pn.Row(storenames_instructions.widget, storenames_selector.widget))
+
+ # creating widgets, adding them to template and showing a GUI on a new browser window
+ number = scanPortsAndFind(start_port=5000, end_port=5200)
+ template.show(port=number)
+
+
+def read_header(inputParameters, num_ch, modality, folder_path, headless):
+ if modality == "tdt":
+ events, flags = TdtRecordingExtractor.discover_events_and_flags(folder_path=folder_path)
+ elif modality == "csv":
+ events, flags = CsvRecordingExtractor.discover_events_and_flags(folder_path=folder_path)
+
+ elif modality == "doric":
+ events, flags = DoricRecordingExtractor.discover_events_and_flags(folder_path=folder_path)
+
+ elif modality == "npm":
+ if not headless:
+ # Resolve multiple event TTLs
+ multiple_event_ttls = NpmRecordingExtractor.has_multiple_event_ttls(folder_path=folder_path)
+ responses = get_multi_event_responses(multiple_event_ttls)
+ inputParameters["npm_split_events"] = responses
+
+ # Resolve timestamp units and columns
+ ts_unit_needs, col_names_ts = NpmRecordingExtractor.needs_ts_unit(folder_path=folder_path, num_ch=num_ch)
+ ts_units, npm_timestamp_column_names = get_timestamp_configuration(ts_unit_needs, col_names_ts)
+ inputParameters["npm_time_units"] = ts_units if ts_units else None
+ inputParameters["npm_timestamp_column_names"] = (
+ npm_timestamp_column_names if npm_timestamp_column_names else None
+ )
+
+ events, flags = NpmRecordingExtractor.discover_events_and_flags(
+ folder_path=folder_path, num_ch=num_ch, inputParameters=inputParameters
+ )
+ else:
+ raise ValueError("Modality not recognized. Please use 'tdt', 'csv', 'doric', or 'npm'.")
+ return events, flags
+
+
+# function to read input parameters and build the storenames page for each folder
+def orchestrate_storenames_page(inputParameters):
+
+ folderNames = inputParameters["folderNames"]
+ isosbestic_control = inputParameters["isosbestic_control"]
+ num_ch = inputParameters["noChannels"]
+ modality = inputParameters.get("modality", "tdt")
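+ # a set GUPPY_BASE_DIR environment variable marks a scripted/headless run, so interactive prompts are skipped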
+ headless = bool(os.environ.get("GUPPY_BASE_DIR"))
+
+ logger.info(folderNames)
+
+ try:
+ for i in folderNames:
+ folder_path = os.path.join(inputParameters["abspath"], i)
+ events, flags = read_header(inputParameters, num_ch, modality, folder_path, headless)
+ build_storenames_page(inputParameters, events, flags, folder_path)
+ logger.info("#" * 400)
+ except Exception as e:
+ logger.error(str(e))
+ raise e
diff --git a/src/guppy/findTransientsFreqAndAmp.py b/src/guppy/orchestration/transients.py
similarity index 64%
rename from src/guppy/findTransientsFreqAndAmp.py
rename to src/guppy/orchestration/transients.py
index f6c3d6e..4d899da 100755
--- a/src/guppy/findTransientsFreqAndAmp.py
+++ b/src/guppy/orchestration/transients.py
@@ -8,39 +8,24 @@
import matplotlib.pyplot as plt
import numpy as np
-from .analysis.io_utils import (
- get_all_stores_for_combining_data,
+from ..analysis.io_utils import (
read_hdf5,
- takeOnlyDirs,
)
-from .analysis.standard_io import (
+from ..analysis.standard_io import (
+ read_transients_from_hdf5,
write_freq_and_amp_to_csv,
write_freq_and_amp_to_hdf5,
+ write_transients_to_hdf5,
)
-from .analysis.transients import analyze_transients
-from .analysis.transients_average import averageForGroup
+from ..analysis.transients import analyze_transients
+from ..analysis.transients_average import averageForGroup
+from ..frontend.progress import writeToFile
+from ..utils.utils import get_all_stores_for_combining_data, takeOnlyDirs
+from ..visualization.transients import visualize_peaks
logger = logging.getLogger(__name__)
-def writeToFile(value: str):
- with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file:
- file.write(value)
-
-
-def visuzlize_peaks(filepath, z_score, timestamps, peaksIndex):
-
- dirname = os.path.dirname(filepath)
-
- basename = (os.path.basename(filepath)).split(".")[0]
- fig = plt.figure()
- ax = fig.add_subplot(111)
- ax.plot(timestamps, z_score, "-", timestamps[peaksIndex], z_score[peaksIndex], "o")
- ax.set_title(basename)
- fig.suptitle(os.path.basename(dirname))
- # plt.show()
-
-
def findFreqAndAmp(filepath, inputParameters, window=15, numProcesses=mp.cpu_count()):
logger.debug("Calculating frequency and amplitude of transients in z-score data....")
@@ -76,10 +61,68 @@ def findFreqAndAmp(filepath, inputParameters, window=15, numProcesses=mp.cpu_cou
index=np.arange(peaks_occurrences.shape[0]),
columns=["timestamps", "amplitude"],
)
- visuzlize_peaks(path[i], z_score, ts, peaksInd)
+ write_transients_to_hdf5(filepath, basename, z_score, ts, peaksInd)
logger.info("Frequency and amplitude of transients in z_score data are calculated.")
+def execute_visualize_peaks(folderNames, inputParameters):
+ selectForTransientsComputation = inputParameters["selectForTransientsComputation"]
+ for i in range(len(folderNames)):
+ logger.debug(f"Finding transients in z-score data of {folderNames[i]} and calculating frequency and amplitude.")
+ filepath = folderNames[i]
+ storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))
+ for j in range(len(storesListPath)):
+ filepath = storesListPath[j]
+ if selectForTransientsComputation == "z_score":
+ path = glob.glob(os.path.join(filepath, "z_score_*"))
+ elif selectForTransientsComputation == "dff":
+ path = glob.glob(os.path.join(filepath, "dff_*"))
+ else:
+ path = glob.glob(os.path.join(filepath, "z_score_*")) + glob.glob(os.path.join(filepath, "dff_*"))
+
+ for k in range(len(path)):
+ basename = (os.path.basename(path[k])).split(".")[0]
+ z_score, ts, peaksInd = read_transients_from_hdf5(filepath, basename)
+
+ suptitle = os.path.basename(os.path.dirname(path[k]))
+ visualize_peaks(basename, suptitle, z_score, ts, peaksInd)
+
+ logger.info("Frequency and amplitude of transients in z_score data are visualized.")
+ plt.show()
+
+
+def execute_visualize_peaks_combined(folderNames, inputParameters):
+ selectForTransientsComputation = inputParameters["selectForTransientsComputation"]
+
+ storesListPath = []
+ for i in range(len(folderNames)):
+ filepath = folderNames[i]
+ storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))))
+ storesListPath = list(np.concatenate(storesListPath).flatten())
+ op = get_all_stores_for_combining_data(storesListPath)
+ for i in range(len(op)):
+ filepath = op[i][0]
+
+ if selectForTransientsComputation == "z_score":
+ path = glob.glob(os.path.join(filepath, "z_score_*"))
+ elif selectForTransientsComputation == "dff":
+ path = glob.glob(os.path.join(filepath, "dff_*"))
+ else:
+ path = glob.glob(os.path.join(filepath, "z_score_*")) + glob.glob(os.path.join(filepath, "dff_*"))
+
+ for k in range(len(path)):
+ basename = (os.path.basename(path[k])).split(".")[0]
+ z_score, ts, peaksInd = read_transients_from_hdf5(filepath, basename)
+
+ suptitle = os.path.basename(os.path.dirname(path[k]))
+ visualize_peaks(basename, suptitle, z_score, ts, peaksInd)
+
+ logger.info("Frequency and amplitude of transients in z_score data are calculated.")
+ plt.show()
+
+
def executeFindFreqAndAmp(inputParameters):
logger.info("Finding transients in z-score data and calculating frequency and amplitude....")
@@ -106,8 +149,10 @@ def executeFindFreqAndAmp(inputParameters):
else:
if combine_data == True:
execute_find_freq_and_amp_combined(inputParameters, folderNames, moving_window, numProcesses)
+ execute_visualize_peaks_combined(folderNames, inputParameters)
else:
execute_find_freq_and_amp(inputParameters, folderNames, moving_window, numProcesses)
+ execute_visualize_peaks(folderNames, inputParameters)
logger.info("Transients in z-score data found and frequency and amplitude are calculated.")
@@ -126,7 +171,6 @@ def execute_find_freq_and_amp(inputParameters, folderNames, moving_window, numPr
writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n")
inputParameters["step"] += 1
logger.info("Transients in z-score data found and frequency and amplitude are calculated.")
- plt.show()
def execute_find_freq_and_amp_combined(inputParameters, folderNames, moving_window, numProcesses):
@@ -142,7 +186,6 @@ def execute_find_freq_and_amp_combined(inputParameters, folderNames, moving_wind
findFreqAndAmp(filepath, inputParameters, window=moving_window, numProcesses=numProcesses)
writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n")
inputParameters["step"] += 1
- plt.show()
def execute_average_for_group(inputParameters, folderNamesForAvg):
diff --git a/src/guppy/orchestration/visualize.py b/src/guppy/orchestration/visualize.py
new file mode 100755
index 0000000..a4149f9
--- /dev/null
+++ b/src/guppy/orchestration/visualize.py
@@ -0,0 +1,257 @@
+import glob
+import logging
+import os
+import re
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+
+from ..frontend.parameterized_plotter import ParameterizedPlotter, remove_cols
+from ..frontend.visualization_dashboard import VisualizationDashboard
+from ..utils.utils import get_all_stores_for_combining_data, read_Df, takeOnlyDirs
+
+logger = logging.getLogger(__name__)
+
+
+# helper function to create plots
+def helper_plots(filepath, event, name, inputParameters):
+
+ basename = os.path.basename(filepath)
+ visualize_zscore_or_dff = inputParameters["visualize_zscore_or_dff"]
+
+ # note when there are no behavior event TTLs
+ if len(event) == 0:
+ logger.warning("\033[1m" + "There are no behavior event TTLs present to visualize." + "\033[0m")
+ return 0
+
+ if os.path.exists(os.path.join(filepath, "cross_correlation_output")):
+ event_corr, frames = [], []
+ if visualize_zscore_or_dff == "z_score":
+ corr_fp = glob.glob(os.path.join(filepath, "cross_correlation_output", "*_z_score_*"))
+ elif visualize_zscore_or_dff == "dff":
+ corr_fp = glob.glob(os.path.join(filepath, "cross_correlation_output", "*_dff_*"))
+ for i in range(len(corr_fp)):
+ filename = os.path.basename(corr_fp[i]).split(".")[0]
+ event_corr.append(filename)
+ df = pd.read_hdf(corr_fp[i], key="df", mode="r")
+ frames.append(df)
+ if len(frames) > 0:
+ df_corr = pd.concat(frames, keys=event_corr, axis=1)
+ else:
+ event_corr = []
+ df_corr = []
+ else:
+ event_corr = []
+ df_corr = None
+
+ # combine all the event PSTH so that it can be viewed together
+ if name:
+ event_name = event
+ new_event, frames, bins = [], [], {}
+ for i in range(len(event_name)):
+
+ for j in range(len(name)):
+ new_event.append(event_name[i] + "_" + name[j].split("_")[-1])
+ new_name = name[j]
+ temp_df = read_Df(filepath, new_event[-1], new_name)
+ cols = list(temp_df.columns)
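+ # collect the binned-PSTH columns, which presumably follow a "bin_(...)" naming pattern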
+ regex = re.compile("bin_[(]")
+ bins[new_event[-1]] = [c for c in cols if regex.match(c)]
+ frames.append(temp_df)
+
+ df = pd.concat(frames, keys=new_event, axis=1)
+ else:
+ new_event = list(np.unique(np.array(event)))
+ frames, bins = [], {}
+ for i in range(len(new_event)):
+ temp_df = read_Df(filepath, new_event[i], "")
+ cols = list(temp_df.columns)
+ regex = re.compile("bin_[(]")
+ bins[new_event[i]] = [c for c in cols if regex.match(c)]
+ frames.append(temp_df)
+
+ df = pd.concat(frames, keys=new_event, axis=1)
+
+ if isinstance(df_corr, pd.DataFrame):
+ new_event.extend(event_corr)
+ df = pd.concat([df, df_corr], axis=1, sort=False).reset_index()
+
+ columns_dict = dict()
+ for i in range(len(new_event)):
+ df_1 = df[new_event[i]]
+ columns = list(df_1.columns)
+ columns.append("All")
+ columns_dict[new_event[i]] = columns
+
+ # make options array for different selectors
+ multiple_plots_options = []
+ heatmap_options = new_event
+ bins_keys = list(bins.keys())
+ if len(bins_keys) > 0:
+ bins_new = bins
+ for i in range(len(bins_keys)):
+ arr = bins[bins_keys[i]]
+ if len(arr) > 0:
+ # heatmap_options.append('{}_bin'.format(bins_keys[i]))
+ for j in arr:
+ multiple_plots_options.append("{}_{}".format(bins_keys[i], j))
+
+ multiple_plots_options = new_event + multiple_plots_options
+ else:
+ multiple_plots_options = new_event
+ x_min = float(inputParameters["nSecPrev"]) - 20
+ x_max = float(inputParameters["nSecPost"]) + 20
+ colormaps = plt.colormaps()
+ new_colormaps = ["plasma", "plasma_r", "magma", "magma_r", "inferno", "inferno_r", "viridis", "viridis_r"]
+ set_a = set(colormaps)
+ set_b = set(new_colormaps)
+ colormaps = new_colormaps + list(set_a.difference(set_b))
+ x = [columns_dict[new_event[0]][-4]]
+ y = remove_cols(columns_dict[new_event[0]])
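+ # label each trial as "<trial #> - <timestamp column>" (e.g. "1 - 12.5"),
+ # plus a trailing "All" option to plot every trial at once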
+ trial_no = range(1, len(remove_cols(columns_dict[heatmap_options[0]])[:-2]) + 1)
+ trial_ts = [
+ "{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(columns_dict[heatmap_options[0]])[:-2])
+ ] + ["All"]
+
+ plotter = ParameterizedPlotter(
+ event_selector_objects=new_event,
+ event_selector_heatmap_objects=heatmap_options,
+ selector_for_multipe_events_plot_objects=multiple_plots_options,
+ columns_dict=columns_dict,
+ df_new=df,
+ x_min=x_min,
+ x_max=x_max,
+ color_map_objects=colormaps,
+ filepath=filepath,
+ x_objects=x,
+ y_objects=y,
+ heatmap_y_objects=trial_ts,
+ psth_y_objects=trial_ts[:-1],
+ )
+ dashboard = VisualizationDashboard(plotter=plotter, basename=basename)
+ dashboard.show()
+
+
+# function to gather the output data and preprocess it for use in the helper_plots function
+def createPlots(filepath, event, inputParameters):
+
+ for i in range(len(event)):
+ event[i] = event[i].replace("\\", "_")
+ event[i] = event[i].replace("/", "_")
+
+ average = inputParameters["visualizeAverageResults"]
+ visualize_zscore_or_dff = inputParameters["visualize_zscore_or_dff"]
+
+ if average == True:
+ path = []
+ for i in range(len(event)):
+ if visualize_zscore_or_dff == "z_score":
+ path.append(glob.glob(os.path.join(filepath, event[i] + "*_z_score_*")))
+ elif visualize_zscore_or_dff == "dff":
+ path.append(glob.glob(os.path.join(filepath, event[i] + "*_dff_*")))
+
+ path = np.concatenate(path)
+ else:
+ if visualize_zscore_or_dff == "z_score":
+ path = glob.glob(os.path.join(filepath, "z_score_*"))
+ elif visualize_zscore_or_dff == "dff":
+ path = glob.glob(os.path.join(filepath, "dff_*"))
+
+ name_arr = []
+ event_arr = []
+
+ index = []
+ for i in range(len(event)):
+ if "control" in event[i].lower() or "signal" in event[i].lower():
+ index.append(i)
+
+ event = np.delete(event, index)
+
+ for i in range(len(path)):
+ name = (os.path.basename(path[i])).split(".")
+ name = name[0]
+ name_arr.append(name)
+
+ if average == True:
+ logger.info("average")
+ helper_plots(filepath, name_arr, "", inputParameters)
+ else:
+ helper_plots(filepath, event, name_arr, inputParameters)
+
+
+def visualizeResults(inputParameters):
+
+ average = inputParameters["visualizeAverageResults"]
+    logger.info(f"visualizeAverageResults = {average}")
+
+ folderNames = inputParameters["folderNames"]
+ folderNamesForAvg = inputParameters["folderNamesForAvg"]
+ combine_data = inputParameters["combine_data"]
+
+    if average and len(folderNamesForAvg) > 0:
+ # folderNames = folderNamesForAvg
+ filepath_avg = os.path.join(inputParameters["abspath"], "average")
+ # filepath = os.path.join(inputParameters['abspath'], folderNames[0])
+ storesListPath = []
+ for i in range(len(folderNamesForAvg)):
+ filepath = folderNamesForAvg[i]
+ storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))))
+ storesListPath = np.concatenate(storesListPath)
+ storesList = np.asarray([[], []])
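+        # each storesList.csv holds two rows (raw store names, then their
+        # user-given names, e.g. "Dv1A,Dv2A" over "control_DMS,signal_DMS"),
+        # hence the reshape(2, -1) below.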
+ for i in range(storesListPath.shape[0]):
+ storesList = np.concatenate(
+ (
+ storesList,
+ np.genfromtxt(
+ os.path.join(storesListPath[i], "storesList.csv"), dtype="str", delimiter=","
+ ).reshape(2, -1),
+ ),
+ axis=1,
+ )
+ storesList = np.unique(storesList, axis=1)
+
+ createPlots(filepath_avg, np.unique(storesList[1, :]), inputParameters)
+
+ else:
+        if combine_data:
+ storesListPath = []
+ for i in range(len(folderNames)):
+ filepath = folderNames[i]
+ storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))))
+ storesListPath = list(np.concatenate(storesListPath).flatten())
+ op = get_all_stores_for_combining_data(storesListPath)
+ for i in range(len(op)):
+ storesList = np.asarray([[], []])
+ for j in range(len(op[i])):
+ storesList = np.concatenate(
+ (
+ storesList,
+ np.genfromtxt(os.path.join(op[i][j], "storesList.csv"), dtype="str", delimiter=",").reshape(
+ 2, -1
+ ),
+ ),
+ axis=1,
+ )
+ storesList = np.unique(storesList, axis=1)
+ filepath = op[i][0]
+ createPlots(filepath, storesList[1, :], inputParameters)
+ else:
+ for i in range(len(folderNames)):
+
+ filepath = folderNames[i]
+ storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))
+ for j in range(len(storesListPath)):
+ filepath = storesListPath[j]
+ storesList = np.genfromtxt(
+ os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=","
+ ).reshape(2, -1)
+
+ createPlots(filepath, storesList[1, :], inputParameters)
+
+
diff --git a/src/guppy/saveStoresList.py b/src/guppy/saveStoresList.py
deleted file mode 100755
index 20a5c94..0000000
--- a/src/guppy/saveStoresList.py
+++ /dev/null
@@ -1,712 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# In[1]:
-
-
-import glob
-import json
-import logging
-import os
-import socket
-import tkinter as tk
-from pathlib import Path
-from random import randint
-from tkinter import StringVar, messagebox, ttk
-
-import holoviews as hv
-import numpy as np
-import pandas as pd
-import panel as pn
-
-from guppy.extractors import (
- CsvRecordingExtractor,
- DoricRecordingExtractor,
- NpmRecordingExtractor,
- TdtRecordingExtractor,
-)
-
-# hv.extension()
-pn.extension()
-
-logger = logging.getLogger(__name__)
-
-
-def scanPortsAndFind(start_port=5000, end_port=5200, host="127.0.0.1"):
- while True:
- port = randint(start_port, end_port)
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.settimeout(0.001) # Set timeout to avoid long waiting on closed ports
- result = sock.connect_ex((host, port))
- if result == 0: # If the connection is successful, the port is open
- continue
- else:
- break
-
- return port
-
-
-def takeOnlyDirs(paths):
- removePaths = []
- for p in paths:
- if os.path.isfile(p):
- removePaths.append(p)
- return list(set(paths) - set(removePaths))
-
-
-# function to show location for over-writing or creating a new stores list file.
-def show_dir(filepath):
- i = 1
- while True:
- basename = os.path.basename(filepath)
- op = os.path.join(filepath, basename + "_output_" + str(i))
- if not os.path.exists(op):
- break
- i += 1
- return op
-
-
-def make_dir(filepath):
- i = 1
- while True:
- basename = os.path.basename(filepath)
- op = os.path.join(filepath, basename + "_output_" + str(i))
- if not os.path.exists(op):
- os.mkdir(op)
- break
- i += 1
-
- return op
-
-
-# function to show GUI and save
-def saveStorenames(inputParameters, events, flags, folder_path):
-
- logger.debug("Saving stores list file.")
- # getting input parameters
- inputParameters = inputParameters
-
- # Headless path: if storenames_map provided, write storesList.csv without building the Panel UI
- storenames_map = inputParameters.get("storenames_map")
- if isinstance(storenames_map, dict) and len(storenames_map) > 0:
- op = make_dir(folder_path)
- arr = np.asarray([list(storenames_map.keys()), list(storenames_map.values())], dtype=str)
- np.savetxt(os.path.join(op, "storesList.csv"), arr, delimiter=",", fmt="%s")
- logger.info(f"Storeslist file saved at {op}")
- logger.info("Storeslist : \n" + str(arr))
- return
-
- # Get storenames from extractor's events property
- allnames = events
-
- if "data_np_v2" in flags or "data_np" in flags or "event_np" in flags:
- path_chev = glob.glob(os.path.join(folder_path, "*chev*"))
- path_chod = glob.glob(os.path.join(folder_path, "*chod*"))
- path_chpr = glob.glob(os.path.join(folder_path, "*chpr*"))
- combine_paths = path_chev + path_chod + path_chpr
- d = dict()
- for i in range(len(combine_paths)):
- basename = (os.path.basename(combine_paths[i])).split(".")[0]
- df = pd.read_csv(combine_paths[i])
- d[basename] = {"x": np.array(df["timestamps"]), "y": np.array(df["data"])}
- keys = list(d.keys())
- mark_down_np = pn.pane.Markdown(
- """
- ### Extra Instructions to follow when using Neurophotometrics data :
- - Guppy will take the NPM data, which has interleaved frames
- from the signal and control channels, and divide it out into
- separate channels for each site you recordded.
- However, since NPM does not automatically annotate which
- frames belong to the signal channel and which belong to the
- control channel, the user must specify this for GuPPy.
- - Each of your recording sites will have a channel
- named “chod” and a channel named “chev”
- - View the plots below and, for each site,
- determine whether the “chev” or “chod” channel is signal or control
- - When you give your storenames, name the channels appropriately.
- For example, “chev1” might be “signal_A” and
- “chod1” might be “control_A” (or vice versa).
-
- """
- )
- plot_select = pn.widgets.Select(
- name="Select channel to see correspondings channels", options=keys, value=keys[0]
- )
-
- @pn.depends(plot_select=plot_select)
- def plot(plot_select):
- return hv.Curve((d[plot_select]["x"], d[plot_select]["y"])).opts(width=550)
-
- else:
- pass
-
- # instructions about how to save the storeslist file
- mark_down = pn.pane.Markdown(
- """
-
-
- ### Instructions to follow :
-
- - Check Storenames to repeat checkbox and see instructions in “Github Wiki” for duplicating storenames.
- Otherwise do not check the Storenames to repeat checkbox.
- - Select storenames from list and click “Select Storenames” to populate area below.
- - Enter names for storenames, in order, using the following naming convention:
- Isosbestic = “control_region” (ex: Dv1A= control_DMS)
- Signal= “signal_region” (ex: Dv2A= signal_DMS)
- TTLs can be named using any convention (ex: PrtR = RewardedPortEntries) but should be kept consistent for later group analysis
-
- ```
- {"storenames": ["Dv1A", "Dv2A",
- "Dv3B", "Dv4B",
- "LNRW", "LNnR",
- "PrtN", "PrtR",
- "RNPS"],
- "names_for_storenames": ["control_DMS", "signal_DMS",
- "control_DLS", "signal_DLS",
- "RewardedNosepoke", "UnrewardedNosepoke",
- "UnrewardedPort", "RewardedPort",
- "InactiveNosepoke"]}
- ```
- - If user has saved storenames before, clicking "Select Storenames" button will pop up a dialog box
- showing previously used names for storenames. Select names for storenames by checking a checkbox and
- click on "Show" to populate the text area in the Storenames GUI. Close the dialog box.
-
- - Select “create new” or “overwrite” to generate a new storenames list or replace a previous one
- - Click Save
-
- """,
- width=550,
- )
-
- # creating GUI template
- template = pn.template.BootstrapTemplate(
- title="Storenames GUI - {}".format(os.path.basename(folder_path), mark_down)
- )
-
- # creating different buttons and selectors for the GUI
- cross_selector = pn.widgets.CrossSelector(name="Store Names Selection", value=[], options=allnames, width=600)
- multi_choice = pn.widgets.MultiChoice(
- name="Select Storenames which you want more than once (multi-choice: multiple options selection)",
- value=[],
- options=allnames,
- )
-
- literal_input_1 = pn.widgets.LiteralInput(
- name="Number of times you want the above storename (list)", value=[], type=list
- )
- # literal_input_2 = pn.widgets.LiteralInput(name='Names for Storenames (list)', type=list)
-
- repeat_storenames = pn.widgets.Checkbox(name="Storenames to repeat", value=False)
- repeat_storename_wd = pn.WidgetBox("", width=600)
-
- def callback(target, event):
- if event.new == True:
- target.objects = [multi_choice, literal_input_1]
- elif event.new == False:
- target.clear()
-
- repeat_storenames.link(repeat_storename_wd, callbacks={"value": callback})
- # repeat_storename_wd = pn.WidgetBox('Storenames to repeat (leave blank if not needed)', multi_choice, literal_input_1, background="white", width=600)
-
- update_options = pn.widgets.Button(name="Select Storenames", width=600)
- save = pn.widgets.Button(name="Save", width=600)
-
- text = pn.widgets.LiteralInput(value=[], name="Selected Store Names", type=list, width=600)
-
- path = pn.widgets.TextInput(name="Location to Stores List file", width=600)
-
- mark_down_for_overwrite = pn.pane.Markdown(
- """ Select option from below if user wants to over-write a file or create a new file.
- **Creating a new file will make a new output folder and will get saved at that location.**
- If user selects to over-write a file **Select location of the file to over-write** will provide
- the existing options of the output folders where user needs to over-write the file""",
- width=600,
- )
-
- select_location = pn.widgets.Select(
- name="Select location of the file to over-write", value="None", options=["None"], width=600
- )
-
- overwrite_button = pn.widgets.MenuButton(
- name="over-write storeslist file or create a new one? ",
- items=["over_write_file", "create_new_file"],
- button_type="default",
- split=True,
- width=600,
- )
-
- literal_input_2 = pn.widgets.CodeEditor(value="""{}""", theme="tomorrow", language="json", height=250, width=600)
-
- alert = pn.pane.Alert("#### No alerts !!", alert_type="danger", height=80, width=600)
-
- take_widgets = pn.WidgetBox(multi_choice, literal_input_1)
-
- change_widgets = pn.WidgetBox(text)
-
- storenames = []
- storename_dropdowns = {}
- storename_textboxes = {}
-
- if len(allnames) == 0:
- alert.object = (
- "####Alert !! \n No storenames found. There are not any TDT files or csv files to look for storenames."
- )
-
- # on clicking overwrite_button, following function is executed
- def overwrite_button_actions(event):
- if event.new == "over_write_file":
- select_location.options = takeOnlyDirs(glob.glob(os.path.join(folder_path, "*_output_*")))
- # select_location.value = select_location.options[0]
- else:
- select_location.options = [show_dir(folder_path)]
- # select_location.value = select_location.options[0]
-
- def fetchValues(event):
- global storenames
- alert.object = "#### No alerts !!"
-
- if not storename_dropdowns or not len(storenames) > 0:
- alert.object = "####Alert !! \n No storenames selected."
- return
-
- storenames_cache = dict()
- if os.path.exists(os.path.join(Path.home(), ".storesList.json")):
- with open(os.path.join(Path.home(), ".storesList.json")) as f:
- storenames_cache = json.load(f)
-
- comboBoxValues, textBoxValues = [], []
- dropdown_keys = list(storename_dropdowns.keys())
- textbox_keys = list(storename_textboxes.keys()) if storename_textboxes else []
-
- # Get dropdown values
- for key in dropdown_keys:
- comboBoxValues.append(storename_dropdowns[key].value)
-
- # Get textbox values (matching with dropdown keys)
- for key in dropdown_keys:
- if key in storename_textboxes:
- textbox_value = storename_textboxes[key].value or ""
- textBoxValues.append(textbox_value)
-
- # Validation: Check for whitespace
- if len(textbox_value.split()) > 1:
- alert.object = "####Alert !! \n Whitespace is not allowed in the text box entry."
- return
-
- # Validation: Check for empty required fields
- dropdown_value = storename_dropdowns[key].value
- if (
- not textbox_value
- and dropdown_value not in storenames_cache
- and dropdown_value in ["control", "signal", "event TTLs"]
- ):
- alert.object = "####Alert !! \n One of the text box entry is empty."
- return
- else:
- # For cached values, use the dropdown value directly
- textBoxValues.append(storename_dropdowns[key].value)
-
- if len(comboBoxValues) != len(textBoxValues):
- alert.object = "####Alert !! \n Number of entries in combo box and text box should be same."
- return
-
- names_for_storenames = []
- for i in range(len(comboBoxValues)):
- if comboBoxValues[i] == "control" or comboBoxValues[i] == "signal":
- if "_" in textBoxValues[i]:
- alert.object = "####Alert !! \n Please do not use underscore in region name."
- return
- names_for_storenames.append("{}_{}".format(comboBoxValues[i], textBoxValues[i]))
- elif comboBoxValues[i] == "event TTLs":
- names_for_storenames.append(textBoxValues[i])
- else:
- names_for_storenames.append(comboBoxValues[i])
-
- d = dict()
- d["storenames"] = text.value
- d["names_for_storenames"] = names_for_storenames
- literal_input_2.value = str(json.dumps(d, indent=2))
-
- # Panel-based storename configuration (replaces Tkinter dialog)
- storename_config_widgets = pn.Column(visible=False)
- show_config_button = pn.widgets.Button(name="Show Selected Configuration", width=600)
-
- # on clicking 'Select Storenames' button, following function is executed
- def update_values(event):
- global storenames, vars_list
- arr = []
- for w in take_widgets:
- arr.append(w.value)
-
- new_arr = []
-
- for i in range(len(arr[1])):
- for j in range(arr[1][i]):
- new_arr.append(arr[0][i])
-
- if len(new_arr) > 0:
- storenames = cross_selector.value + new_arr
- else:
- storenames = cross_selector.value
-
- for w in change_widgets:
- w.value = storenames
-
- storenames_cache = dict()
- if os.path.exists(os.path.join(Path.home(), ".storesList.json")):
- with open(os.path.join(Path.home(), ".storesList.json")) as f:
- storenames_cache = json.load(f)
-
- # Create Panel widgets for storename configuration
- config_widgets = []
- storename_dropdowns.clear()
- storename_textboxes.clear()
-
- if len(storenames) > 0:
- config_widgets.append(
- pn.pane.Markdown(
- "## Configure Storenames\nSelect appropriate options for each storename and provide names as needed:"
- )
- )
-
- for i, storename in enumerate(storenames):
- # Create a row for each storename
- row_widgets = []
-
- # Label
- label = pn.pane.Markdown(f"**{storename}:**")
- row_widgets.append(label)
-
- # Dropdown options
- if storename in storenames_cache:
- options = storenames_cache[storename]
- default_value = options[0] if options else ""
- else:
- options = ["", "control", "signal", "event TTLs"]
- default_value = ""
-
- # Create unique key for widget
- widget_key = (
- f"{storename}_{i}"
- if f"{storename}_{i}" not in storename_dropdowns
- else f"{storename}_{i}_{len(storename_dropdowns)}"
- )
-
- dropdown = pn.widgets.Select(name="Type", value=default_value, options=options, width=150)
- storename_dropdowns[widget_key] = dropdown
- row_widgets.append(dropdown)
-
- # Text input (only show if not cached or if control/signal/event TTLs selected)
- if storename not in storenames_cache or default_value in ["control", "signal", "event TTLs"]:
- textbox = pn.widgets.TextInput(
- name="Name", value="", placeholder="Enter region/event name", width=200
- )
- storename_textboxes[widget_key] = textbox
- row_widgets.append(textbox)
-
- # Add helper text based on selection
- def create_help_function(dropdown_widget, help_pane_container):
- @pn.depends(dropdown_widget.param.value, watch=True)
- def update_help(dropdown_value):
- if dropdown_value == "control":
- help_pane_container[0] = pn.pane.Markdown(
- "*Type appropriate region name*", styles={"color": "gray", "font-size": "12px"}
- )
- elif dropdown_value == "signal":
- help_pane_container[0] = pn.pane.Markdown(
- "*Type appropriate region name*", styles={"color": "gray", "font-size": "12px"}
- )
- elif dropdown_value == "event TTLs":
- help_pane_container[0] = pn.pane.Markdown(
- "*Type event name for the TTLs*", styles={"color": "gray", "font-size": "12px"}
- )
- else:
- help_pane_container[0] = pn.pane.Markdown(
- "", styles={"color": "gray", "font-size": "12px"}
- )
-
- return update_help
-
- help_container = [pn.pane.Markdown("")]
- help_function = create_help_function(dropdown, help_container)
- help_function(dropdown.value) # Initialize
- row_widgets.append(help_container[0])
-
- # Add the row to config widgets
- config_widgets.append(pn.Row(*row_widgets, margin=(5, 0)))
-
- # Add show button
- config_widgets.append(pn.Spacer(height=20))
- config_widgets.append(show_config_button)
- config_widgets.append(
- pn.pane.Markdown(
- "*Click 'Show Selected Configuration' to apply your selections.*",
- styles={"font-size": "12px", "color": "gray"},
- )
- )
-
- # Update the configuration panel
- storename_config_widgets.objects = config_widgets
- storename_config_widgets.visible = len(storenames) > 0
-
- # on clicking save button, following function is executed
- def save_button(event=None):
- global storenames
-
- d = json.loads(literal_input_2.value)
- arr1, arr2 = np.asarray(d["storenames"]), np.asarray(d["names_for_storenames"])
-
- if np.where(arr2 == "")[0].size > 0:
- alert.object = "#### Alert !! \n Empty string in the list names_for_storenames."
- logger.error("Empty string in the list names_for_storenames.")
- raise Exception("Empty string in the list names_for_storenames.")
- else:
- alert.object = "#### No alerts !!"
-
- if arr1.shape[0] != arr2.shape[0]:
- alert.object = "#### Alert !! \n Length of list storenames and names_for_storenames is not equal."
- logger.error("Length of list storenames and names_for_storenames is not equal.")
- raise Exception("Length of list storenames and names_for_storenames is not equal.")
- else:
- alert.object = "#### No alerts !!"
-
- if not os.path.exists(os.path.join(Path.home(), ".storesList.json")):
- storenames_cache = dict()
-
- for i in range(arr1.shape[0]):
- if arr1[i] in storenames_cache:
- storenames_cache[arr1[i]].append(arr2[i])
- storenames_cache[arr1[i]] = list(set(storenames_cache[arr1[i]]))
- else:
- storenames_cache[arr1[i]] = [arr2[i]]
-
- with open(os.path.join(Path.home(), ".storesList.json"), "w") as f:
- json.dump(storenames_cache, f, indent=4)
- else:
- with open(os.path.join(Path.home(), ".storesList.json")) as f:
- storenames_cache = json.load(f)
-
- for i in range(arr1.shape[0]):
- if arr1[i] in storenames_cache:
- storenames_cache[arr1[i]].append(arr2[i])
- storenames_cache[arr1[i]] = list(set(storenames_cache[arr1[i]]))
- else:
- storenames_cache[arr1[i]] = [arr2[i]]
-
- with open(os.path.join(Path.home(), ".storesList.json"), "w") as f:
- json.dump(storenames_cache, f, indent=4)
-
- arr = np.asarray([arr1, arr2])
- logger.info(arr)
- if not os.path.exists(select_location.value):
- os.mkdir(select_location.value)
-
- np.savetxt(os.path.join(select_location.value, "storesList.csv"), arr, delimiter=",", fmt="%s")
- path.value = os.path.join(select_location.value, "storesList.csv")
- logger.info(f"Storeslist file saved at {select_location.value}")
- logger.info("Storeslist : \n" + str(arr))
-
- # Connect button callbacks
- update_options.on_click(update_values)
- show_config_button.on_click(fetchValues)
- save.on_click(save_button)
- overwrite_button.on_click(overwrite_button_actions)
-
- # creating widgets, adding them to template and showing a GUI on a new browser window
- number = scanPortsAndFind(start_port=5000, end_port=5200)
-
- if "data_np_v2" in flags or "data_np" in flags or "event_np" in flags:
- widget_1 = pn.Column("# " + os.path.basename(folder_path), mark_down, mark_down_np, plot_select, plot)
- widget_2 = pn.Column(
- repeat_storenames,
- repeat_storename_wd,
- pn.Spacer(height=20),
- cross_selector,
- update_options,
- storename_config_widgets,
- pn.Spacer(height=10),
- text,
- literal_input_2,
- alert,
- mark_down_for_overwrite,
- overwrite_button,
- select_location,
- save,
- path,
- )
- template.main.append(pn.Row(widget_1, widget_2))
-
- else:
- widget_1 = pn.Column("# " + os.path.basename(folder_path), mark_down)
- widget_2 = pn.Column(
- repeat_storenames,
- repeat_storename_wd,
- pn.Spacer(height=20),
- cross_selector,
- update_options,
- storename_config_widgets,
- pn.Spacer(height=10),
- text,
- literal_input_2,
- alert,
- mark_down_for_overwrite,
- overwrite_button,
- select_location,
- save,
- path,
- )
- template.main.append(pn.Row(widget_1, widget_2))
-
- template.show(port=number)
-
-
-# function to read input parameters and run the saveStorenames function
-def execute(inputParameters):
-
- inputParameters = inputParameters
- folderNames = inputParameters["folderNames"]
- isosbestic_control = inputParameters["isosbestic_control"]
- num_ch = inputParameters["noChannels"]
- modality = inputParameters.get("modality", "tdt")
-
- logger.info(folderNames)
-
- try:
- for i in folderNames:
- folder_path = os.path.join(inputParameters["abspath"], i)
- if modality == "tdt":
- events, flags = TdtRecordingExtractor.discover_events_and_flags(folder_path=folder_path)
- elif modality == "csv":
- events, flags = CsvRecordingExtractor.discover_events_and_flags(folder_path=folder_path)
-
- elif modality == "doric":
- events, flags = DoricRecordingExtractor.discover_events_and_flags(folder_path=folder_path)
-
- elif modality == "npm":
- headless = bool(os.environ.get("GUPPY_BASE_DIR"))
- if not headless:
- # Resolve multiple event TTLs
- multiple_event_ttls = NpmRecordingExtractor.has_multiple_event_ttls(folder_path=folder_path)
- responses = get_multi_event_responses(multiple_event_ttls)
- inputParameters["npm_split_events"] = responses
-
- # Resolve timestamp units and columns
- ts_unit_needs, col_names_ts = NpmRecordingExtractor.needs_ts_unit(
- folder_path=folder_path, num_ch=num_ch
- )
- ts_units, npm_timestamp_column_names = get_timestamp_configuration(ts_unit_needs, col_names_ts)
- inputParameters["npm_time_units"] = ts_units if ts_units else None
- inputParameters["npm_timestamp_column_names"] = (
- npm_timestamp_column_names if npm_timestamp_column_names else None
- )
-
- events, flags = NpmRecordingExtractor.discover_events_and_flags(
- folder_path=folder_path, num_ch=num_ch, inputParameters=inputParameters
- )
- else:
- raise ValueError("Modality not recognized. Please use 'tdt', 'csv', 'doric', or 'npm'.")
-
- saveStorenames(inputParameters, events, flags, folder_path)
- logger.info("#" * 400)
- except Exception as e:
- logger.error(str(e))
- raise e
-
-
-def get_multi_event_responses(multiple_event_ttls):
- responses = []
- for has_multiple in multiple_event_ttls:
- if not has_multiple:
- responses.append(False)
- continue
- window = tk.Tk()
- response = messagebox.askyesno(
- "Multiple event TTLs",
- (
- "Based on the TTL file, "
- "it looks like TTLs "
- "belong to multiple behavior types. "
- "Do you want to create multiple files for each "
- "behavior type?"
- ),
- )
- window.destroy()
- responses.append(response)
- return responses
-
-
-def get_timestamp_configuration(ts_unit_needs, col_names_ts):
- ts_units, npm_timestamp_column_names = [], []
- for need in ts_unit_needs:
- if not need:
- ts_units.append("seconds")
- npm_timestamp_column_names.append(None)
- continue
- window = tk.Tk()
- window.title("Select appropriate options for timestamps")
- window.geometry("500x200")
- holdComboboxValues = dict()
-
- timestamps_label = ttk.Label(window, text="Select which timestamps to use : ").grid(
- row=0, column=1, pady=25, padx=25
- )
- holdComboboxValues["timestamps"] = StringVar()
- timestamps_combo = ttk.Combobox(window, values=col_names_ts, textvariable=holdComboboxValues["timestamps"])
- timestamps_combo.grid(row=0, column=2, pady=25, padx=25)
- timestamps_combo.current(0)
- # timestamps_combo.bind("<>", comboBoxSelected)
-
- time_unit_label = ttk.Label(window, text="Select timestamps unit : ").grid(row=1, column=1, pady=25, padx=25)
- holdComboboxValues["time_unit"] = StringVar()
- time_unit_combo = ttk.Combobox(
- window,
- values=["", "seconds", "milliseconds", "microseconds"],
- textvariable=holdComboboxValues["time_unit"],
- )
- time_unit_combo.grid(row=1, column=2, pady=25, padx=25)
- time_unit_combo.current(0)
- # time_unit_combo.bind("<>", comboBoxSelected)
- window.lift()
- window.after(500, lambda: window.lift())
- window.mainloop()
-
- if holdComboboxValues["timestamps"].get():
- npm_timestamp_column_name = holdComboboxValues["timestamps"].get()
- else:
- messagebox.showerror(
- "All options not selected",
- "All the options for timestamps \
- were not selected. Please select appropriate options",
- )
- logger.error(
- "All the options for timestamps \
- were not selected. Please select appropriate options"
- )
- raise Exception(
- "All the options for timestamps \
- were not selected. Please select appropriate options"
- )
- if holdComboboxValues["time_unit"].get():
- if holdComboboxValues["time_unit"].get() == "seconds":
- ts_unit = holdComboboxValues["time_unit"].get()
- elif holdComboboxValues["time_unit"].get() == "milliseconds":
- ts_unit = holdComboboxValues["time_unit"].get()
- else:
- ts_unit = holdComboboxValues["time_unit"].get()
- else:
- messagebox.showerror(
- "All options not selected",
- "All the options for timestamps \
- were not selected. Please select appropriate options",
- )
- logger.error(
- "All the options for timestamps \
- were not selected. Please select appropriate options"
- )
- raise Exception(
- "All the options for timestamps \
- were not selected. Please select appropriate options"
- )
- ts_units.append(ts_unit)
- npm_timestamp_column_names.append(npm_timestamp_column_name)
- return ts_units, npm_timestamp_column_names
diff --git a/src/guppy/savingInputParameters.py b/src/guppy/savingInputParameters.py
deleted file mode 100644
index a1bd35e..0000000
--- a/src/guppy/savingInputParameters.py
+++ /dev/null
@@ -1,581 +0,0 @@
-import json
-import logging
-import os
-import subprocess
-import sys
-import time
-import tkinter as tk
-from threading import Thread
-from tkinter import filedialog, ttk
-
-import numpy as np
-import pandas as pd
-import panel as pn
-
-from .saveStoresList import execute
-from .visualizePlot import visualizeResults
-
-logger = logging.getLogger(__name__)
-
-
-def savingInputParameters():
- pn.extension()
-
- # Determine base folder path (headless-friendly via env var)
- base_dir_env = os.environ.get("GUPPY_BASE_DIR")
- is_headless = base_dir_env and os.path.isdir(base_dir_env)
- if is_headless:
- global folder_path
- folder_path = base_dir_env
- logger.info(f"Folder path set to {folder_path} (from GUPPY_BASE_DIR)")
- else:
- # Create the main window
- folder_selection = tk.Tk()
- folder_selection.title("Select the folder path where your data is located")
- folder_selection.geometry("700x200")
-
- def select_folder():
- global folder_path
- folder_path = filedialog.askdirectory(title="Select the folder path where your data is located")
- if folder_path:
- logger.info(f"Folder path set to {folder_path}")
- folder_selection.destroy()
- else:
- folder_path = os.path.expanduser("~")
- logger.info(f"Folder path set to {folder_path}")
-
- select_button = ttk.Button(folder_selection, text="Select a Folder", command=select_folder)
- select_button.pack(pady=5)
- folder_selection.mainloop()
-
- current_dir = os.getcwd()
-
- def make_dir(filepath):
- op = os.path.join(filepath, "inputParameters")
- if not os.path.exists(op):
- os.mkdir(op)
- return op
-
- def readRawData():
- inputParameters = getInputParameters()
- subprocess.call([sys.executable, "-m", "guppy.readTevTsq", json.dumps(inputParameters)])
-
- def extractTs():
- inputParameters = getInputParameters()
- subprocess.call([sys.executable, "-m", "guppy.preprocess", json.dumps(inputParameters)])
-
- def psthComputation():
- inputParameters = getInputParameters()
- inputParameters["curr_dir"] = current_dir
- subprocess.call([sys.executable, "-m", "guppy.computePsth", json.dumps(inputParameters)])
-
- def readPBIncrementValues(progressBar):
- logger.info("Read progress bar increment values function started...")
- file_path = os.path.join(os.path.expanduser("~"), "pbSteps.txt")
- if os.path.exists(file_path):
- os.remove(file_path)
- increment, maximum = 0, 100
- progressBar.value = increment
- progressBar.bar_color = "success"
- while True:
- try:
- with open(file_path, "r") as file:
- content = file.readlines()
- if len(content) == 0:
- pass
- else:
- maximum = int(content[0])
- increment = int(content[-1])
-
- if increment == -1:
- progressBar.bar_color = "danger"
- os.remove(file_path)
- break
- progressBar.max = maximum
- progressBar.value = increment
- time.sleep(0.001)
- except FileNotFoundError:
- time.sleep(0.001)
- except PermissionError:
- time.sleep(0.001)
- except Exception as e:
- # Handle other exceptions that may occur
- logger.info(f"An error occurred while reading the file: {e}")
- break
- if increment == maximum:
- os.remove(file_path)
- break
-
- logger.info("Read progress bar increment values stopped.")
-
- # progress bars = PB
- read_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300)
- extract_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300)
- psth_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300)
-
- template = pn.template.BootstrapTemplate(title="Input Parameters GUI")
-
- mark_down_1 = pn.pane.Markdown("""**Select folders for the analysis from the file selector below**""", width=600)
-
- files_1 = pn.widgets.FileSelector(folder_path, name="folderNames", width=950)
-
- explain_modality = pn.pane.Markdown(
- """
- **Data Modality:** Select the type of data acquisition system used for your recordings:
- - **tdt**: Tucker-Davis Technologies system
- - **csv**: Generic CSV format
- - **doric**: Doric Photometry system
- - **npm**: Neurophotometrics system
- """,
- width=600,
- )
-
- modality_selector = pn.widgets.Select(
- name="Data Modality", value="tdt", options=["tdt", "csv", "doric", "npm"], width=320
- )
-
- explain_time_artifacts = pn.pane.Markdown(
- """
- - ***Number of cores :*** Number of cores used for analysis. Try to
- keep it less than the number of cores in your machine.
- - ***Combine Data? :*** Make this parameter ``` True ``` if user wants to combine
- the data, especially when there is two different
- data files for the same recording session.
- - ***Isosbestic Control Channel? :*** Make this parameter ``` False ``` if user
- does not want to use isosbestic control channel in the analysis.
- - ***Eliminate first few seconds :*** It is the parameter to cut out first x seconds
- from the data. Default is 1 seconds.
- - ***Window for Moving Average filter :*** The filtering of signals
- is done using moving average filter. Default window used for moving
- average filter is 100 datapoints. Change it based on the requirement.
- - ***Moving Window (transients detection) :*** Transients in the z-score
- and/or \u0394F/F are detected using this moving window.
- Default is 15 seconds. Change it based on the requirement.
- - ***High Amplitude filtering threshold (HAFT) (transients detection) :*** High amplitude
- events greater than x times the MAD above the median are filtered out. Here, x is
- high amplitude filtering threshold. Default is 2.
- - ***Transients detection threshold (TD Thresh):*** Peaks with local maxima greater than x times
- the MAD above the median of the trace (after filtering high amplitude events) are detected
- as transients. Here, x is transients detection threshold. Default is 3.
- - ***Number of channels (Neurophotometrics only) :*** Number of
- channels used while recording, when data files has no column names mentioning "Flags"
- or "LedState".
- - ***removeArtifacts? :*** Make this parameter ``` True``` if there are
- artifacts and user wants to remove the artifacts.
- - ***removeArtifacts method :*** Selecting ```concatenate``` will remove bad
- chunks and concatenate the selected good chunks together.
- Selecting ```replace with NaN``` will replace bad chunks with NaN
- values.
- """,
- width=350,
- )
-
- timeForLightsTurnOn = pn.widgets.LiteralInput(
- name="Eliminate first few seconds (int)", value=1, type=int, width=320
- )
-
- isosbestic_control = pn.widgets.Select(
- name="Isosbestic Control Channel? (bool)", value=True, options=[True, False], width=320
- )
-
- numberOfCores = pn.widgets.LiteralInput(name="# of cores (int)", value=2, type=int, width=150)
-
- combine_data = pn.widgets.Select(name="Combine Data? (bool)", value=False, options=[True, False], width=150)
-
- computePsth = pn.widgets.Select(
- name="z_score and/or \u0394F/F? (psth)", options=["z_score", "dff", "Both"], width=320
- )
-
- transients = pn.widgets.Select(
- name="z_score and/or \u0394F/F? (transients)", options=["z_score", "dff", "Both"], width=320
- )
-
- plot_zScore_dff = pn.widgets.Select(
- name="z-score plot and/or \u0394F/F plot?", options=["z_score", "dff", "Both", "None"], value="None", width=320
- )
-
- moving_wd = pn.widgets.LiteralInput(
- name="Moving Window for transients detection (s) (int)", value=15, type=int, width=320
- )
-
- highAmpFilt = pn.widgets.LiteralInput(name="HAFT (int)", value=2, type=int, width=150)
-
- transientsThresh = pn.widgets.LiteralInput(name="TD Thresh (int)", value=3, type=int, width=150)
-
- moving_avg_filter = pn.widgets.LiteralInput(
- name="Window for Moving Average filter (int)", value=100, type=int, width=320
- )
-
- removeArtifacts = pn.widgets.Select(name="removeArtifacts? (bool)", value=False, options=[True, False], width=150)
-
- artifactsRemovalMethod = pn.widgets.Select(
- name="removeArtifacts method", value="concatenate", options=["concatenate", "replace with NaN"], width=150
- )
-
- no_channels_np = pn.widgets.LiteralInput(
- name="Number of channels (Neurophotometrics only)", value=2, type=int, width=320
- )
-
- z_score_computation = pn.widgets.Select(
- name="z-score computation Method",
- options=["standard z-score", "baseline z-score", "modified z-score"],
- value="standard z-score",
- width=200,
- )
-
- baseline_wd_strt = pn.widgets.LiteralInput(
- name="Baseline Window Start Time (s) (int)", value=0, type=int, width=200
- )
- baseline_wd_end = pn.widgets.LiteralInput(name="Baseline Window End Time (s) (int)", value=0, type=int, width=200)
-
- explain_z_score = pn.pane.Markdown(
- """
- ***Note :***
- - Details about z-score computation methods are explained in Github wiki.
- - The details will make user understand what computation method to use for
- their data.
- - Baseline Window Parameters should be kept 0 unless you are using baseline
- z-score computation method. The parameters are in seconds.
- """,
- width=580,
- )
-
- explain_nsec = pn.pane.Markdown(
- """
- - ***Time Interval :*** To omit bursts of event timestamps, user defined time interval
- is set so that if the time difference between two timestamps is less than this defined time
- interval, it will be deleted for the calculation of PSTH.
- - ***Compute Cross-correlation :*** Make this parameter ```True```, when user wants
- to compute cross-correlation between PSTHs of two different signals or signals
- recorded from different brain regions.
- """,
- width=580,
- )
-
- nSecPrev = pn.widgets.LiteralInput(name="Seconds before 0 (int)", value=-10, type=int, width=120)
-
- nSecPost = pn.widgets.LiteralInput(name="Seconds after 0 (int)", value=20, type=int, width=120)
-
- computeCorr = pn.widgets.Select(
- name="Compute Cross-correlation (bool)", options=[True, False], value=False, width=200
- )
-
- timeInterval = pn.widgets.LiteralInput(name="Time Interval (s)", value=2, type=int, width=120)
-
- use_time_or_trials = pn.widgets.Select(
- name="Bin PSTH trials (str)", options=["Time (min)", "# of trials"], value="Time (min)", width=120
- )
-
- bin_psth_trials = pn.widgets.LiteralInput(
- name="Time(min) / # of trials \n for binning? (int)", value=0, type=int, width=200
- )
-
- explain_baseline = pn.pane.Markdown(
- """
- ***Note :***
- - If user does not want to do baseline correction,
- put both parameters 0.
- - If the first event timestamp is less than the length of baseline
- window, it will be rejected in the PSTH computation step.
- - Baseline parameters must be within the PSTH parameters
- set in the PSTH parameters section.
- """,
- width=580,
- )
-
- baselineCorrectionStart = pn.widgets.LiteralInput(
- name="Baseline Correction Start time(int)", value=-5, type=int, width=200
- )
-
- baselineCorrectionEnd = pn.widgets.LiteralInput(
- name="Baseline Correction End time(int)", value=0, type=int, width=200
- )
-
- zscore_param_wd = pn.WidgetBox(
- "### Z-score Parameters",
- explain_z_score,
- z_score_computation,
- pn.Row(baseline_wd_strt, baseline_wd_end),
- width=600,
- )
-
- psth_param_wd = pn.WidgetBox(
- "### PSTH Parameters",
- explain_nsec,
- pn.Row(nSecPrev, nSecPost, computeCorr),
- pn.Row(timeInterval, use_time_or_trials, bin_psth_trials),
- width=600,
- )
-
- baseline_param_wd = pn.WidgetBox(
- "### Baseline Parameters", explain_baseline, pn.Row(baselineCorrectionStart, baselineCorrectionEnd), width=600
- )
- peak_explain = pn.pane.Markdown(
- """
- ***Note :***
- - Peak and area are computed between the window set below.
- - Peak and AUC parameters must be within the PSTH parameters set in the PSTH parameters section.
- - Please make sure when user changes the parameters in the table below, click on any other cell after
- changing a value in a particular cell.
- """,
- width=580,
- )
-
- start_end_point_df = pd.DataFrame(
- {
- "Peak Start time": [-5, 0, 5, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
- "Peak End time": [0, 3, 10, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
- }
- )
-
- df_widget = pn.widgets.Tabulator(start_end_point_df, name="DataFrame", show_index=False, widths=280)
-
- peak_param_wd = pn.WidgetBox("### Peak and AUC Parameters", peak_explain, df_widget, width=600)
-
- mark_down_2 = pn.pane.Markdown(
- """**Select folders for the average analysis from the file selector below**""", width=600
- )
-
- files_2 = pn.widgets.FileSelector(folder_path, name="folderNamesForAvg", width=950)
-
- averageForGroup = pn.widgets.Select(name="Average Group? (bool)", value=False, options=[True, False], width=435)
-
- visualizeAverageResults = pn.widgets.Select(
- name="Visualize Average Results? (bool)", value=False, options=[True, False], width=435
- )
-
- visualize_zscore_or_dff = pn.widgets.Select(
- name="z-score or \u0394F/F? (for visualization)", options=["z_score", "dff"], width=435
- )
-
- individual_analysis_wd_2 = pn.Column(
- explain_time_artifacts,
- pn.Row(numberOfCores, combine_data),
- isosbestic_control,
- timeForLightsTurnOn,
- moving_avg_filter,
- computePsth,
- transients,
- plot_zScore_dff,
- moving_wd,
- pn.Row(highAmpFilt, transientsThresh),
- no_channels_np,
- pn.Row(removeArtifacts, artifactsRemovalMethod),
- )
-
- group_analysis_wd_1 = pn.Column(mark_down_2, files_2, averageForGroup, width=800)
-
- visualization_wd = pn.Row(visualize_zscore_or_dff, pn.Spacer(width=60), visualizeAverageResults)
-
- def getInputParameters():
- abspath = getAbsPath()
- inputParameters = {
- "abspath": abspath[0],
- "folderNames": files_1.value,
- "modality": modality_selector.value,
- "numberOfCores": numberOfCores.value,
- "combine_data": combine_data.value,
- "isosbestic_control": isosbestic_control.value,
- "timeForLightsTurnOn": timeForLightsTurnOn.value,
- "filter_window": moving_avg_filter.value,
- "removeArtifacts": removeArtifacts.value,
- "artifactsRemovalMethod": artifactsRemovalMethod.value,
- "noChannels": no_channels_np.value,
- "zscore_method": z_score_computation.value,
- "baselineWindowStart": baseline_wd_strt.value,
- "baselineWindowEnd": baseline_wd_end.value,
- "nSecPrev": nSecPrev.value,
- "nSecPost": nSecPost.value,
- "computeCorr": computeCorr.value,
- "timeInterval": timeInterval.value,
- "bin_psth_trials": bin_psth_trials.value,
- "use_time_or_trials": use_time_or_trials.value,
- "baselineCorrectionStart": baselineCorrectionStart.value,
- "baselineCorrectionEnd": baselineCorrectionEnd.value,
- "peak_startPoint": list(df_widget.value["Peak Start time"]), # startPoint.value,
- "peak_endPoint": list(df_widget.value["Peak End time"]), # endPoint.value,
- "selectForComputePsth": computePsth.value,
- "selectForTransientsComputation": transients.value,
- "moving_window": moving_wd.value,
- "highAmpFilt": highAmpFilt.value,
- "transientsThresh": transientsThresh.value,
- "plot_zScore_dff": plot_zScore_dff.value,
- "visualize_zscore_or_dff": visualize_zscore_or_dff.value,
- "folderNamesForAvg": files_2.value,
- "averageForGroup": averageForGroup.value,
- "visualizeAverageResults": visualizeAverageResults.value,
- }
- return inputParameters
-
- def checkSameLocation(arr, abspath):
- # abspath = []
- for i in range(len(arr)):
- abspath.append(os.path.dirname(arr[i]))
- abspath = np.asarray(abspath)
- abspath = np.unique(abspath)
- if len(abspath) > 1:
- logger.error("All the folders selected should be at the same location")
- raise Exception("All the folders selected should be at the same location")
-
- return abspath
-
- def getAbsPath():
- arr_1, arr_2 = files_1.value, files_2.value
- if len(arr_1) == 0 and len(arr_2) == 0:
- logger.error("No folder is selected for analysis")
- raise Exception("No folder is selected for analysis")
-
- abspath = []
- if len(arr_1) > 0:
- abspath = checkSameLocation(arr_1, abspath)
- else:
- abspath = checkSameLocation(arr_2, abspath)
-
- abspath = np.unique(abspath)
- if len(abspath) > 1:
- logger.error("All the folders selected should be at the same location")
- raise Exception("All the folders selected should be at the same location")
- return abspath
-
- def onclickProcess(event=None):
-
- logger.debug("Saving Input Parameters file.")
- abspath = getAbsPath()
- analysisParameters = {
- "combine_data": combine_data.value,
- "isosbestic_control": isosbestic_control.value,
- "timeForLightsTurnOn": timeForLightsTurnOn.value,
- "filter_window": moving_avg_filter.value,
- "removeArtifacts": removeArtifacts.value,
- "noChannels": no_channels_np.value,
- "zscore_method": z_score_computation.value,
- "baselineWindowStart": baseline_wd_strt.value,
- "baselineWindowEnd": baseline_wd_end.value,
- "nSecPrev": nSecPrev.value,
- "nSecPost": nSecPost.value,
- "timeInterval": timeInterval.value,
- "bin_psth_trials": bin_psth_trials.value,
- "use_time_or_trials": use_time_or_trials.value,
- "baselineCorrectionStart": baselineCorrectionStart.value,
- "baselineCorrectionEnd": baselineCorrectionEnd.value,
- "peak_startPoint": list(df_widget.value["Peak Start time"]), # startPoint.value,
- "peak_endPoint": list(df_widget.value["Peak End time"]), # endPoint.value,
- "selectForComputePsth": computePsth.value,
- "selectForTransientsComputation": transients.value,
- "moving_window": moving_wd.value,
- "highAmpFilt": highAmpFilt.value,
- "transientsThresh": transientsThresh.value,
- }
- for folder in files_1.value:
- with open(os.path.join(folder, "GuPPyParamtersUsed.json"), "w") as f:
- json.dump(analysisParameters, f, indent=4)
- logger.info(f"Input Parameters file saved at {folder}")
-
- logger.info("#" * 400)
-
- # path.value = (os.path.join(op, 'inputParameters.json')).replace('\\', '/')
- logger.info("Input Parameters File Saved.")
-
- def onclickStoresList(event=None):
- inputParameters = getInputParameters()
- execute(inputParameters)
-
- def onclickVisualization(event=None):
- inputParameters = getInputParameters()
- visualizeResults(inputParameters)
-
- def onclickreaddata(event=None):
- thread = Thread(target=readRawData)
- thread.start()
- readPBIncrementValues(read_progress)
- thread.join()
-
- def onclickextractts(event=None):
- thread = Thread(target=extractTs)
- thread.start()
- readPBIncrementValues(extract_progress)
- thread.join()
-
- def onclickpsth(event=None):
- thread = Thread(target=psthComputation)
- thread.start()
- readPBIncrementValues(psth_progress)
- thread.join()
-
- mark_down_ip = pn.pane.Markdown("""**Step 1 : Save Input Parameters**""", width=300)
- mark_down_ip_note = pn.pane.Markdown(
- """***Note : ***
- - Save Input Parameters will save input parameters used for the analysis
- in all the folders you selected for the analysis (useful for future
- reference). All analysis steps will run without saving input parameters.
- """,
- width=300,
- )
- save_button = pn.widgets.Button(name="Save to file...", button_type="primary", width=300, align="end")
- mark_down_storenames = pn.pane.Markdown("""**Step 2 : Open Storenames GUI
and save storenames**""", width=300)
- open_storesList = pn.widgets.Button(name="Open Storenames GUI", button_type="primary", width=300, align="end")
- mark_down_read = pn.pane.Markdown("""**Step 3 : Read Raw Data**""", width=300)
- read_rawData = pn.widgets.Button(name="Read Raw Data", button_type="primary", width=300, align="end")
- mark_down_extract = pn.pane.Markdown("""**Step 4 : Extract timestamps
and its correction**""", width=300)
- extract_ts = pn.widgets.Button(
- name="Extract timestamps and it's correction", button_type="primary", width=300, align="end"
- )
- mark_down_psth = pn.pane.Markdown("""**Step 5 : PSTH Computation**""", width=300)
- psth_computation = pn.widgets.Button(name="PSTH Computation", button_type="primary", width=300, align="end")
- mark_down_visualization = pn.pane.Markdown("""**Step 6 : Visualization**""", width=300)
- open_visualization = pn.widgets.Button(name="Open Visualization GUI", button_type="primary", width=300, align="end")
- open_terminal = pn.widgets.Button(name="Open Terminal", button_type="primary", width=300, align="end")
-
- save_button.on_click(onclickProcess)
- open_storesList.on_click(onclickStoresList)
- read_rawData.on_click(onclickreaddata)
- extract_ts.on_click(onclickextractts)
- psth_computation.on_click(onclickpsth)
- open_visualization.on_click(onclickVisualization)
-
- template.sidebar.append(mark_down_ip)
- template.sidebar.append(mark_down_ip_note)
- template.sidebar.append(save_button)
- # template.sidebar.append(path)
- template.sidebar.append(mark_down_storenames)
- template.sidebar.append(open_storesList)
- template.sidebar.append(mark_down_read)
- template.sidebar.append(read_rawData)
- template.sidebar.append(read_progress)
- template.sidebar.append(mark_down_extract)
- template.sidebar.append(extract_ts)
- template.sidebar.append(extract_progress)
- template.sidebar.append(mark_down_psth)
- template.sidebar.append(psth_computation)
- template.sidebar.append(psth_progress)
- template.sidebar.append(mark_down_visualization)
- template.sidebar.append(open_visualization)
- # template.sidebar.append(open_terminal)
-
- psth_baseline_param = pn.Column(zscore_param_wd, psth_param_wd, baseline_param_wd, peak_param_wd)
-
- widget = pn.Column(
- mark_down_1, files_1, explain_modality, modality_selector, pn.Row(individual_analysis_wd_2, psth_baseline_param)
- )
-
- # file_selector = pn.WidgetBox(files_1)
- styles = dict(background="WhiteSmoke")
- individual = pn.Card(widget, title="Individual Analysis", styles=styles, width=1000)
- group = pn.Card(group_analysis_wd_1, title="Group Analysis", styles=styles, width=1000)
- visualize = pn.Card(visualization_wd, title="Visualization Parameters", styles=styles, width=1000)
-
- # template.main.append(file_selector)
- template.main.append(individual)
- template.main.append(group)
- template.main.append(visualize)
-
- # Expose minimal hooks and widgets to enable programmatic testing
- template._hooks = {
- "onclickProcess": onclickProcess,
- "getInputParameters": getInputParameters,
- }
- template._widgets = {
- "files_1": files_1,
- }
-
- return template
diff --git a/src/guppy/testing/api.py b/src/guppy/testing/api.py
index 98939cf..834a9e1 100644
--- a/src/guppy/testing/api.py
+++ b/src/guppy/testing/api.py
@@ -13,12 +13,12 @@
import os
from typing import Iterable
-from guppy.computePsth import psthForEachStorename
-from guppy.findTransientsFreqAndAmp import executeFindFreqAndAmp
-from guppy.preprocess import extractTsAndSignal
-from guppy.readTevTsq import readRawData
-from guppy.saveStoresList import execute
-from guppy.savingInputParameters import savingInputParameters
+from guppy.orchestration.home import build_homepage
+from guppy.orchestration.preprocess import extractTsAndSignal
+from guppy.orchestration.psth import psthForEachStorename
+from guppy.orchestration.read_raw_data import orchestrate_read_raw_data
+from guppy.orchestration.storenames import orchestrate_storenames_page
+from guppy.orchestration.transients import executeFindFreqAndAmp
def step1(*, base_dir: str, selected_folders: Iterable[str]) -> None:
@@ -50,7 +50,7 @@ def step1(*, base_dir: str, selected_folders: Iterable[str]) -> None:
os.environ["GUPPY_BASE_DIR"] = base_dir
# Build the template headlessly
- template = savingInputParameters()
+ template = build_homepage()
# Sanity checks: ensure hooks/widgets exposed
if not hasattr(template, "_hooks") or "onclickProcess" not in template._hooks:
@@ -144,7 +144,7 @@ def step2(
# Headless build: set base_dir and construct the template
os.environ["GUPPY_BASE_DIR"] = base_dir
- template = savingInputParameters()
+ template = build_homepage()
# Ensure hooks/widgets exposed
if not hasattr(template, "_hooks") or "getInputParameters" not in template._hooks:
@@ -168,7 +168,7 @@ def step2(
input_params["npm_split_events"] = npm_split_events
# Call the underlying Step 2 executor (now headless-aware)
- execute(input_params)
+ orchestrate_storenames_page(input_params)
def step3(
@@ -236,7 +236,7 @@ def step3(
# Headless build: set base_dir and construct the template
os.environ["GUPPY_BASE_DIR"] = base_dir
- template = savingInputParameters()
+ template = build_homepage()
# Ensure hooks/widgets exposed
if not hasattr(template, "_hooks") or "getInputParameters" not in template._hooks:
@@ -257,7 +257,7 @@ def step3(
input_params["modality"] = modality
# Call the underlying Step 3 worker directly (no subprocess)
- readRawData(input_params)
+ orchestrate_read_raw_data(input_params)
def step4(
@@ -328,7 +328,7 @@ def step4(
# Headless build: set base_dir and construct the template
os.environ["GUPPY_BASE_DIR"] = base_dir
- template = savingInputParameters()
+ template = build_homepage()
# Ensure hooks/widgets exposed
if not hasattr(template, "_hooks") or "getInputParameters" not in template._hooks:
@@ -420,7 +420,7 @@ def step5(
# Headless build: set base_dir and construct the template
os.environ["GUPPY_BASE_DIR"] = base_dir
- template = savingInputParameters()
+ template = build_homepage()
# Ensure hooks/widgets exposed
if not hasattr(template, "_hooks") or "getInputParameters" not in template._hooks:
diff --git a/src/guppy/utils/__init__.py b/src/guppy/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/guppy/utils/utils.py b/src/guppy/utils/utils.py
new file mode 100644
index 0000000..7f7bb17
--- /dev/null
+++ b/src/guppy/utils/utils.py
@@ -0,0 +1,43 @@
+import logging
+import os
+import re
+
+import numpy as np
+import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+
+def takeOnlyDirs(paths):
+ removePaths = []
+ for p in paths:
+ if os.path.isfile(p):
+ removePaths.append(p)
+ return list(set(paths) - set(removePaths))
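+
+# Note: the set arithmetic above does not preserve input order. Illustration
+# (hypothetical paths; assumes notes.txt exists as a file on disk):
+#   takeOnlyDirs(["/data/run_output_1", "/data/notes.txt"]) -> ["/data/run_output_1"]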
+
+
+def get_all_stores_for_combining_data(folderNames):
+ op = []
+ for i in range(100):
+ temp = []
+ match = r"[\s\S]*" + "_output_" + str(i)
+ for j in folderNames:
+ temp.append(re.findall(match, j))
+ temp = sorted(list(np.concatenate(temp).flatten()), key=str.casefold)
+ if len(temp) > 0:
+ op.append(temp)
+
+ return op
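+
+# Illustration (hypothetical folder names): folders are grouped by their
+# "_output_<i>" suffix:
+#   get_all_stores_for_combining_data(["A_output_1", "B_output_1", "A_output_2"])
+#   -> [["A_output_1", "B_output_1"], ["A_output_2"]]
+# The pattern is unanchored, so "_output_1" also matches inside names such as
+# "A_output_10".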
+
+
+# function to read h5 file and make a dataframe from it
+def read_Df(filepath, event, name):
+ event = event.replace("\\", "_")
+ event = event.replace("/", "_")
+ if name:
+ op = os.path.join(filepath, event + "_{}.h5".format(name))
+ else:
+ op = os.path.join(filepath, event + ".h5")
+ df = pd.read_hdf(op, key="df", mode="r")
+
+ return df
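+
+# Illustration (hypothetical arguments):
+#   read_Df("/data/day1_output_1", "RewardedPort_z_score", "DMS")
+# reads "/data/day1_output_1/RewardedPort_z_score_DMS.h5"; slashes and
+# backslashes in the event name are first replaced with underscores.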
diff --git a/src/guppy/visualization/__init__.py b/src/guppy/visualization/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/guppy/visualization/preprocessing.py b/src/guppy/visualization/preprocessing.py
new file mode 100644
index 0000000..cda8150
--- /dev/null
+++ b/src/guppy/visualization/preprocessing.py
@@ -0,0 +1,43 @@
+import logging
+import os
+
+import matplotlib.pyplot as plt
+
+logger = logging.getLogger(__name__)
+
+# Only set matplotlib backend if not in CI environment
+if not os.getenv("CI"):
+ plt.switch_backend("TKAgg")
+
+
+def visualize_preprocessing(*, suptitle, title, x, y):
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+ ax.plot(x, y)
+ ax.set_title(title)
+ fig.suptitle(suptitle)
+
+ return fig, ax
+
+
+def visualize_control_signal_fit(x, y1, y2, y3, plot_name, name, artifacts_have_been_removed):
+ fig = plt.figure()
+ ax1 = fig.add_subplot(311)
+ (line1,) = ax1.plot(x, y1)
+ ax1.set_title(plot_name[0])
+ ax2 = fig.add_subplot(312)
+ (line2,) = ax2.plot(x, y2)
+ ax2.set_title(plot_name[1])
+ ax3 = fig.add_subplot(313)
+    # plot the trace and the fitted control on the same axis so the fit can be compared
+    ax3.plot(x, y2)
+    (line3,) = ax3.plot(x, y3)
+ ax3.set_title(plot_name[2])
+ fig.suptitle(name)
+
+ hfont = {"fontname": "DejaVu Sans"}
+
+ if artifacts_have_been_removed:
+ ax3.set_xlabel("Time(s) \n Note : Artifacts have been removed, but are not reflected in this plot.", **hfont)
+ else:
+ ax3.set_xlabel("Time(s)", **hfont)
+ return fig, ax1, ax2, ax3
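+
+
+# Minimal usage sketch (hypothetical data and labels):
+#   fig, ax1, ax2, ax3 = visualize_control_signal_fit(
+#       ts, control, signal, fitted_control,
+#       ["control", "signal", "signal with fitted control"], "control_DMS", False)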
diff --git a/src/guppy/visualization/transients.py b/src/guppy/visualization/transients.py
new file mode 100644
index 0000000..030ac5e
--- /dev/null
+++ b/src/guppy/visualization/transients.py
@@ -0,0 +1,15 @@
+import logging
+
+import matplotlib.pyplot as plt
+
+logger = logging.getLogger(__name__)
+
+
+def visualize_peaks(title, suptitle, z_score, timestamps, peaksIndex):
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+ ax.plot(timestamps, z_score, "-", timestamps[peaksIndex], z_score[peaksIndex], "o")
+ ax.set_title(title)
+ fig.suptitle(suptitle)
+
+ return fig, ax
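+
+
+# Minimal usage sketch (hypothetical arrays):
+#   import numpy as np
+#   z = np.random.randn(1000); ts = np.arange(1000) / 30.0
+#   peaks = np.array([100, 400, 750])
+#   fig, ax = visualize_peaks("transients", "signal_DMS", z, ts, peaks)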
diff --git a/src/guppy/visualizePlot.py b/src/guppy/visualizePlot.py
deleted file mode 100755
index 929149e..0000000
--- a/src/guppy/visualizePlot.py
+++ /dev/null
@@ -1,936 +0,0 @@
-import glob
-import logging
-import math
-import os
-import re
-import socket
-from random import randint
-
-import datashader as ds
-import holoviews as hv
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-import panel as pn
-import param
-from bokeh.io import export_png, export_svgs
-from holoviews import opts
-from holoviews.operation.datashader import datashade
-from holoviews.plotting.util import process_cmap
-
-from .preprocess import get_all_stores_for_combining_data
-
-pn.extension()
-
-logger = logging.getLogger(__name__)
-
-
-def scanPortsAndFind(start_port=5000, end_port=5200, host="127.0.0.1"):
- while True:
- port = randint(start_port, end_port)
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.settimeout(0.001) # Set timeout to avoid long waiting on closed ports
- result = sock.connect_ex((host, port))
- if result == 0: # If the connection is successful, the port is open
- continue
- else:
- break
-
- return port
-
-
-def takeOnlyDirs(paths):
- removePaths = []
- for p in paths:
- if os.path.isfile(p):
- removePaths.append(p)
- return list(set(paths) - set(removePaths))
-
-
-# read h5 file as a dataframe
-def read_Df(filepath, event, name):
- event = event.replace("\\", "_")
- event = event.replace("/", "_")
- if name:
- op = os.path.join(filepath, event + "_{}.h5".format(name))
- else:
- op = os.path.join(filepath, event + ".h5")
- df = pd.read_hdf(op, key="df", mode="r")
-
- return df
-
-
-# make a new directory for saving plots
-def make_dir(filepath):
- op = os.path.join(filepath, "saved_plots")
- if not os.path.exists(op):
- os.mkdir(op)
-
- return op
-
-
-# remove unnecessary column names
-def remove_cols(cols):
- regex = re.compile("bin_err_*")
- remove_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])]
- remove_cols = remove_cols + ["err", "timestamps"]
- cols = [i for i in cols if i not in remove_cols]
-
- return cols
-
-
-# def look_psth_bins(event, name):
-
-
-# helper function to create plots
-def helper_plots(filepath, event, name, inputParameters):
-
- basename = os.path.basename(filepath)
- visualize_zscore_or_dff = inputParameters["visualize_zscore_or_dff"]
-
- # note when there are no behavior event TTLs
- if len(event) == 0:
- logger.warning("\033[1m" + "There are no behavior event TTLs present to visualize.".format(event) + "\033[0m")
- return 0
-
- if os.path.exists(os.path.join(filepath, "cross_correlation_output")):
- event_corr, frames = [], []
- if visualize_zscore_or_dff == "z_score":
- corr_fp = glob.glob(os.path.join(filepath, "cross_correlation_output", "*_z_score_*"))
- elif visualize_zscore_or_dff == "dff":
- corr_fp = glob.glob(os.path.join(filepath, "cross_correlation_output", "*_dff_*"))
- for i in range(len(corr_fp)):
- filename = os.path.basename(corr_fp[i]).split(".")[0]
- event_corr.append(filename)
- df = pd.read_hdf(corr_fp[i], key="df", mode="r")
- frames.append(df)
- if len(frames) > 0:
- df_corr = pd.concat(frames, keys=event_corr, axis=1)
- else:
- event_corr = []
- df_corr = []
- else:
- event_corr = []
- df_corr = None
-
- # combine all the event PSTH so that it can be viewed together
- if name:
- event_name, name = event, name
- new_event, frames, bins = [], [], {}
- for i in range(len(event_name)):
-
- for j in range(len(name)):
- new_event.append(event_name[i] + "_" + name[j].split("_")[-1])
- new_name = name[j]
- temp_df = read_Df(filepath, new_event[-1], new_name)
- cols = list(temp_df.columns)
- regex = re.compile("bin_[(]")
- bins[new_event[-1]] = [cols[i] for i in range(len(cols)) if regex.match(cols[i])]
- # bins.append(keep_cols)
- frames.append(temp_df)
-
- df = pd.concat(frames, keys=new_event, axis=1)
- else:
- new_event = list(np.unique(np.array(event)))
- frames, bins = [], {}
- for i in range(len(new_event)):
- temp_df = read_Df(filepath, new_event[i], "")
- cols = list(temp_df.columns)
- regex = re.compile("bin_[(]")
- bins[new_event[i]] = [cols[i] for i in range(len(cols)) if regex.match(cols[i])]
- frames.append(temp_df)
-
- df = pd.concat(frames, keys=new_event, axis=1)
-
- if isinstance(df_corr, pd.DataFrame):
- new_event.extend(event_corr)
- df = pd.concat([df, df_corr], axis=1, sort=False).reset_index()
-
- columns_dict = dict()
- for i in range(len(new_event)):
- df_1 = df[new_event[i]]
- columns = list(df_1.columns)
- columns.append("All")
- columns_dict[new_event[i]] = columns
-
- # create a class to make GUI and plot different graphs
- class Viewer(param.Parameterized):
-
- # class_event = new_event
-
- # make options array for different selectors
- multiple_plots_options = []
- heatmap_options = new_event
- bins_keys = list(bins.keys())
- if len(bins_keys) > 0:
- bins_new = bins
- for i in range(len(bins_keys)):
- arr = bins[bins_keys[i]]
- if len(arr) > 0:
- # heatmap_options.append('{}_bin'.format(bins_keys[i]))
- for j in arr:
- multiple_plots_options.append("{}_{}".format(bins_keys[i], j))
-
- multiple_plots_options = new_event + multiple_plots_options
- else:
- multiple_plots_options = new_event
-
- # create different options and selectors
- event_selector = param.ObjectSelector(default=new_event[0], objects=new_event)
- event_selector_heatmap = param.ObjectSelector(default=heatmap_options[0], objects=heatmap_options)
- columns = columns_dict
- df_new = df
-
- colormaps = plt.colormaps()
- new_colormaps = ["plasma", "plasma_r", "magma", "magma_r", "inferno", "inferno_r", "viridis", "viridis_r"]
- set_a = set(colormaps)
- set_b = set(new_colormaps)
- colormaps = new_colormaps + list(set_a.difference(set_b))
-
- x_min = float(inputParameters["nSecPrev"]) - 20
- x_max = float(inputParameters["nSecPost"]) + 20
- selector_for_multipe_events_plot = param.ListSelector(
- default=[multiple_plots_options[0]], objects=multiple_plots_options
- )
- x = param.ObjectSelector(default=columns[new_event[0]][-4], objects=[columns[new_event[0]][-4]])
- y = param.ObjectSelector(
- default=remove_cols(columns[new_event[0]])[-2], objects=remove_cols(columns[new_event[0]])
- )
-
- trial_no = range(1, len(remove_cols(columns[heatmap_options[0]])[:-2]) + 1)
- trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(columns[heatmap_options[0]])[:-2])] + [
- "All"
- ]
- heatmap_y = param.ListSelector(default=[trial_ts[-1]], objects=trial_ts)
- psth_y = param.ListSelector(objects=trial_ts[:-1])
- select_trials_checkbox = param.ListSelector(default=["just trials"], objects=["mean", "just trials"])
- Y_Label = param.ObjectSelector(default="y", objects=["y", "z-score", "\u0394F/F"])
- save_options = param.ObjectSelector(
- default="None", objects=["None", "save_png_format", "save_svg_format", "save_both_format"]
- )
- save_options_heatmap = param.ObjectSelector(
- default="None", objects=["None", "save_png_format", "save_svg_format", "save_both_format"]
- )
- color_map = param.ObjectSelector(default="plasma", objects=colormaps)
- height_heatmap = param.ObjectSelector(default=600, objects=list(np.arange(0, 5100, 100))[1:])
- width_heatmap = param.ObjectSelector(default=1000, objects=list(np.arange(0, 5100, 100))[1:])
- Height_Plot = param.ObjectSelector(default=300, objects=list(np.arange(0, 5100, 100))[1:])
- Width_Plot = param.ObjectSelector(default=1000, objects=list(np.arange(0, 5100, 100))[1:])
- save_hm = param.Action(lambda x: x.param.trigger("save_hm"), label="Save")
- save_psth = param.Action(lambda x: x.param.trigger("save_psth"), label="Save")
- X_Limit = param.Range(default=(-5, 10), bounds=(x_min, x_max))
- Y_Limit = param.Range(bounds=(-50, 50.0))
-
- # C_Limit = param.Range(bounds=(-20,20.0))
-
- results_hm = dict()
- results_psth = dict()
-
- # function to save heatmaps when save button on heatmap tab is clicked
- @param.depends("save_hm", watch=True)
- def save_hm_plots(self):
- plot = self.results_hm["plot"]
- op = self.results_hm["op"]
- save_opts = self.save_options_heatmap
- logger.info(save_opts)
- if save_opts == "save_svg_format":
- p = hv.render(plot, backend="bokeh")
- p.output_backend = "svg"
- export_svgs(p, filename=op + ".svg")
- elif save_opts == "save_png_format":
- p = hv.render(plot, backend="bokeh")
- export_png(p, filename=op + ".png")
- elif save_opts == "save_both_format":
- p = hv.render(plot, backend="bokeh")
- p.output_backend = "svg"
- export_svgs(p, filename=op + ".svg")
- p_png = hv.render(plot, backend="bokeh")
- export_png(p_png, filename=op + ".png")
- else:
- return 0
-
- # function to save PSTH plots when save button on PSTH tab is clicked
- @param.depends("save_psth", watch=True)
- def save_psth_plot(self):
- plot, op = [], []
- plot.append(self.results_psth["plot_combine"])
- op.append(self.results_psth["op_combine"])
- plot.append(self.results_psth["plot"])
- op.append(self.results_psth["op"])
- for i in range(len(plot)):
- temp_plot, temp_op = plot[i], op[i]
- save_opts = self.save_options
- if save_opts == "save_svg_format":
- p = hv.render(temp_plot, backend="bokeh")
- p.output_backend = "svg"
- export_svgs(p, filename=temp_op + ".svg")
- elif save_opts == "save_png_format":
- p = hv.render(temp_plot, backend="bokeh")
- export_png(p, filename=temp_op + ".png")
- elif save_opts == "save_both_format":
- p = hv.render(temp_plot, backend="bokeh")
- p.output_backend = "svg"
- export_svgs(p, filename=temp_op + ".svg")
- p_png = hv.render(temp_plot, backend="bokeh")
- export_png(p_png, filename=temp_op + ".png")
- else:
- return 0
-
- # function to change Y values based on event selection
- @param.depends("event_selector", watch=True)
- def _update_x_y(self):
- x_value = self.columns[self.event_selector]
- y_value = self.columns[self.event_selector]
- self.param["x"].objects = [x_value[-4]]
- self.param["y"].objects = remove_cols(y_value)
- self.x = x_value[-4]
- self.y = self.param["y"].objects[-2]
-
- @param.depends("event_selector_heatmap", watch=True)
- def _update_df(self):
- cols = self.columns[self.event_selector_heatmap]
- trial_no = range(1, len(remove_cols(cols)[:-2]) + 1)
- trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(cols)[:-2])] + ["All"]
- self.param["heatmap_y"].objects = trial_ts
- self.heatmap_y = [trial_ts[-1]]
-
- @param.depends("event_selector", watch=True)
- def _update_psth_y(self):
- cols = self.columns[self.event_selector]
- trial_no = range(1, len(remove_cols(cols)[:-2]) + 1)
- trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(cols)[:-2])]
- self.param["psth_y"].objects = trial_ts
- self.psth_y = [trial_ts[0]]
-
- # function to plot multiple PSTHs into one plot
-
- @param.depends(
- "selector_for_multipe_events_plot",
- "Y_Label",
- "save_options",
- "X_Limit",
- "Y_Limit",
- "Height_Plot",
- "Width_Plot",
- )
- def update_selector(self):
- data_curve, cols_curve, data_spread, cols_spread = [], [], [], []
- arr = self.selector_for_multipe_events_plot
- df1 = self.df_new
- for i in range(len(arr)):
- if "bin" in arr[i]:
- split = arr[i].rsplit("_", 2)
- df_name = split[0] #'{}_{}'.format(split[0], split[1])
- col_name_mean = "{}_{}".format(split[-2], split[-1])
- col_name_err = "{}_err_{}".format(split[-2], split[-1])
- data_curve.append(df1[df_name][col_name_mean])
- cols_curve.append(arr[i])
- data_spread.append(df1[df_name][col_name_err])
- cols_spread.append(arr[i])
- else:
- data_curve.append(df1[arr[i]]["mean"])
- cols_curve.append(arr[i] + "_" + "mean")
- data_spread.append(df1[arr[i]]["err"])
- cols_spread.append(arr[i] + "_" + "mean")
-
- if len(arr) > 0:
- if self.Y_Limit == None:
- self.Y_Limit = (np.nanmin(np.asarray(data_curve)) - 0.5, np.nanmax(np.asarray(data_curve)) + 0.5)
-
- if "bin" in arr[i]:
- split = arr[i].rsplit("_", 2)
- df_name = split[0]
- data_curve.append(df1[df_name]["timestamps"])
- cols_curve.append("timestamps")
- data_spread.append(df1[df_name]["timestamps"])
- cols_spread.append("timestamps")
- else:
- data_curve.append(df1[arr[i]]["timestamps"])
- cols_curve.append("timestamps")
- data_spread.append(df1[arr[i]]["timestamps"])
- cols_spread.append("timestamps")
- df_curve = pd.concat(data_curve, axis=1)
- df_spread = pd.concat(data_spread, axis=1)
- df_curve.columns = cols_curve
- df_spread.columns = cols_spread
-
- ts = df_curve["timestamps"]
- index = np.arange(0, ts.shape[0], 3)
- df_curve = df_curve.loc[index, :]
- df_spread = df_spread.loc[index, :]
- overlay = hv.NdOverlay(
- {
- c: hv.Curve((df_curve["timestamps"], df_curve[c]), kdims=["Time (s)"]).opts(
- width=int(self.Width_Plot),
- height=int(self.Height_Plot),
- xlim=self.X_Limit,
- ylim=self.Y_Limit,
- )
- for c in cols_curve[:-1]
- }
- )
- spread = hv.NdOverlay(
- {
- d: hv.Spread(
- (df_spread["timestamps"], df_curve[d], df_spread[d], df_spread[d]),
- vdims=["y", "yerrpos", "yerrneg"],
- ).opts(line_width=0, fill_alpha=0.3)
- for d in cols_spread[:-1]
- }
- )
- plot_combine = ((overlay * spread).opts(opts.NdOverlay(xlabel="Time (s)", ylabel=self.Y_Label))).opts(
- shared_axes=False
- )
- # plot_err = new_df.hvplot.area(x='timestamps', y=[], y2=[])
- save_opts = self.save_options
- op = make_dir(filepath)
- op_filename = os.path.join(op, str(arr) + "_mean")
-
- self.results_psth["plot_combine"] = plot_combine
- self.results_psth["op_combine"] = op_filename
- # self.save_plots(plot_combine, save_opts, op_filename)
- return plot_combine
-
- # function to plot mean PSTH, single trial in PSTH and all the trials of PSTH with mean
- @param.depends(
- "event_selector", "x", "y", "Y_Label", "save_options", "Y_Limit", "X_Limit", "Height_Plot", "Width_Plot"
- )
- def contPlot(self):
- df1 = self.df_new[self.event_selector]
- # height = self.Heigth_Plot
- # width = self.Width_Plot
- # logger.info(height, width)
- if self.y == "All":
- if self.Y_Limit == None:
- self.Y_Limit = (np.nanmin(np.asarray(df1)) - 0.5, np.nanmax(np.asarray(df1)) - 0.5)
-
- options = self.param["y"].objects
- regex = re.compile("bin_[(]")
- remove_bin_trials = [options[i] for i in range(len(options)) if not regex.match(options[i])]
-
- ndoverlay = hv.NdOverlay({c: hv.Curve((df1[self.x], df1[c])) for c in remove_bin_trials[:-2]})
- img1 = datashade(ndoverlay, normalization="linear", aggregator=ds.count())
- x_points = df1[self.x]
- y_points = df1["mean"]
- img2 = hv.Curve((x_points, y_points))
- img = (img1 * img2).opts(
- opts.Curve(
- width=int(self.Width_Plot),
- height=int(self.Height_Plot),
- line_width=4,
- color="black",
- xlim=self.X_Limit,
- ylim=self.Y_Limit,
- xlabel="Time (s)",
- ylabel=self.Y_Label,
- )
- )
-
- save_opts = self.save_options
-
- op = make_dir(filepath)
- op_filename = os.path.join(op, self.event_selector + "_" + self.y)
- self.results_psth["plot"] = img
- self.results_psth["op"] = op_filename
- # self.save_plots(img, save_opts, op_filename)
-
- return img
-
- elif self.y == "mean" or "bin" in self.y:
-
- xpoints = df1[self.x]
- ypoints = df1[self.y]
- if self.y == "mean":
- err = df1["err"]
- else:
- split = self.y.split("_")
- err = df1["{}_err_{}".format(split[0], split[1])]
-
- index = np.arange(0, xpoints.shape[0], 3)
-
- if self.Y_Limit == None:
- self.Y_Limit = (np.nanmin(ypoints) - 0.5, np.nanmax(ypoints) + 0.5)
-
- ropts_curve = dict(
- width=int(self.Width_Plot),
- height=int(self.Height_Plot),
- xlim=self.X_Limit,
- ylim=self.Y_Limit,
- color="blue",
- xlabel="Time (s)",
- ylabel=self.Y_Label,
- )
- ropts_spread = dict(
- width=int(self.Width_Plot),
- height=int(self.Height_Plot),
- fill_alpha=0.3,
- fill_color="blue",
- line_width=0,
- )
-
- plot_curve = hv.Curve((xpoints[index], ypoints[index])) # .opts(**ropts_curve)
- plot_spread = hv.Spread(
- (xpoints[index], ypoints[index], err[index], err[index])
- ) # .opts(**ropts_spread) #vdims=['y', 'yerrpos', 'yerrneg']
- plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread})
-
- save_opts = self.save_options
- op = make_dir(filepath)
- op_filename = os.path.join(op, self.event_selector + "_" + self.y)
- self.results_psth["plot"] = plot
- self.results_psth["op"] = op_filename
- # self.save_plots(plot, save_opts, op_filename)
-
- return plot
-
- else:
- xpoints = df1[self.x]
- ypoints = df1[self.y]
- if self.Y_Limit == None:
- self.Y_Limit = (np.nanmin(ypoints) - 0.5, np.nanmax(ypoints) + 0.5)
-
- ropts_curve = dict(
- width=int(self.Width_Plot),
- height=int(self.Height_Plot),
- xlim=self.X_Limit,
- ylim=self.Y_Limit,
- color="blue",
- xlabel="Time (s)",
- ylabel=self.Y_Label,
- )
- plot = hv.Curve((xpoints, ypoints)).opts({"Curve": ropts_curve})
-
- save_opts = self.save_options
- op = make_dir(filepath)
- op_filename = os.path.join(op, self.event_selector + "_" + self.y)
- self.results_psth["plot"] = plot
- self.results_psth["op"] = op_filename
- # self.save_plots(plot, save_opts, op_filename)
-
- return plot
-
- # function to plot specific PSTH trials
- @param.depends(
- "event_selector",
- "x",
- "psth_y",
- "select_trials_checkbox",
- "Y_Label",
- "save_options",
- "Y_Limit",
- "X_Limit",
- "Height_Plot",
- "Width_Plot",
- )
- def plot_specific_trials(self):
- df_psth = self.df_new[self.event_selector]
- # if self.Y_Limit==None:
- # self.Y_Limit = (np.nanmin(ypoints)-0.5, np.nanmax(ypoints)+0.5)
-
- if self.psth_y == None:
- return None
- else:
- selected_trials = [s.split(" - ")[1] for s in list(self.psth_y)]
-
- index = np.arange(0, df_psth["timestamps"].shape[0], 3)
-
- if self.select_trials_checkbox == ["just trials"]:
- overlay = hv.NdOverlay(
- {
- c: hv.Curve((df_psth["timestamps"][index], df_psth[c][index]), kdims=["Time (s)"])
- for c in selected_trials
- }
- )
- ropts = dict(
- width=int(self.Width_Plot),
- height=int(self.Height_Plot),
- xlim=self.X_Limit,
- ylim=self.Y_Limit,
- xlabel="Time (s)",
- ylabel=self.Y_Label,
- )
- return overlay.opts(**ropts)
- elif self.select_trials_checkbox == ["mean"]:
- arr = np.asarray(df_psth[selected_trials])
- mean = np.nanmean(arr, axis=1)
- err = np.nanstd(arr, axis=1) / math.sqrt(arr.shape[1])
- ropts_curve = dict(
- width=int(self.Width_Plot),
- height=int(self.Height_Plot),
- xlim=self.X_Limit,
- ylim=self.Y_Limit,
- color="blue",
- xlabel="Time (s)",
- ylabel=self.Y_Label,
- )
- ropts_spread = dict(
- width=int(self.Width_Plot),
- height=int(self.Height_Plot),
- fill_alpha=0.3,
- fill_color="blue",
- line_width=0,
- )
- plot_curve = hv.Curve((df_psth["timestamps"][index], mean[index]))
- plot_spread = hv.Spread((df_psth["timestamps"][index], mean[index], err[index], err[index]))
- plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread})
- return plot
- elif self.select_trials_checkbox == ["mean", "just trials"]:
- overlay = hv.NdOverlay(
- {
- c: hv.Curve((df_psth["timestamps"][index], df_psth[c][index]), kdims=["Time (s)"])
- for c in selected_trials
- }
- )
- ropts_overlay = dict(
- width=int(self.Width_Plot),
- height=int(self.Height_Plot),
- xlim=self.X_Limit,
- ylim=self.Y_Limit,
- xlabel="Time (s)",
- ylabel=self.Y_Label,
- )
-
- arr = np.asarray(df_psth[selected_trials])
- mean = np.nanmean(arr, axis=1)
- err = np.nanstd(arr, axis=1) / math.sqrt(arr.shape[1])
- ropts_curve = dict(
- width=int(self.Width_Plot),
- height=int(self.Height_Plot),
- xlim=self.X_Limit,
- ylim=self.Y_Limit,
- color="black",
- xlabel="Time (s)",
- ylabel=self.Y_Label,
- )
- ropts_spread = dict(
- width=int(self.Width_Plot),
- height=int(self.Height_Plot),
- fill_alpha=0.3,
- fill_color="black",
- line_width=0,
- )
- plot_curve = hv.Curve((df_psth["timestamps"][index], mean[index]))
- plot_spread = hv.Spread((df_psth["timestamps"][index], mean[index], err[index], err[index]))
-
- plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread})
- return overlay.opts(**ropts_overlay) * plot
-
- # function to show heatmaps for each event
- @param.depends("event_selector_heatmap", "color_map", "height_heatmap", "width_heatmap", "heatmap_y")
- def heatmap(self):
- height = self.height_heatmap
- width = self.width_heatmap
- df_hm = self.df_new[self.event_selector_heatmap]
- cols = list(df_hm.columns)
- regex = re.compile("bin_err_*")
- drop_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])]
- drop_cols = ["err", "mean"] + drop_cols
- df_hm = df_hm.drop(drop_cols, axis=1)
- cols = list(df_hm.columns)
- bin_cols = [cols[i] for i in range(len(cols)) if re.compile("bin_*").match(cols[i])]
- time = np.asarray(df_hm["timestamps"])
- event_ts_for_each_event = np.arange(1, len(df_hm.columns[:-1]) + 1)
- yticks = list(event_ts_for_each_event)
- z_score = np.asarray(df_hm[df_hm.columns[:-1]]).T
-
- if self.heatmap_y[0] == "All":
- indices = np.arange(z_score.shape[0] - len(bin_cols))
- z_score = z_score[indices, :]
- event_ts_for_each_event = np.arange(1, z_score.shape[0] + 1)
- yticks = list(event_ts_for_each_event)
- else:
- remove_all = list(set(self.heatmap_y) - set(["All"]))
- indices = sorted([int(s.split("-")[0]) - 1 for s in remove_all])
- z_score = z_score[indices, :]
- event_ts_for_each_event = np.arange(1, z_score.shape[0] + 1)
- yticks = list(event_ts_for_each_event)
-
- clim = (np.nanmin(z_score), np.nanmax(z_score))
- font_size = {"labels": 16, "yticks": 6}
-
- if event_ts_for_each_event.shape[0] == 1:
- dummy_image = hv.QuadMesh((time, event_ts_for_each_event, z_score)).opts(colorbar=True, clim=clim)
- image = (
- (dummy_image).opts(
- opts.QuadMesh(
- width=int(width),
- height=int(height),
- cmap=process_cmap(self.color_map, provider="matplotlib"),
- colorbar=True,
- ylabel="Trials",
- xlabel="Time (s)",
- fontsize=font_size,
- yticks=yticks,
- )
- )
- ).opts(shared_axes=False)
-
- save_opts = self.save_options_heatmap
- op = make_dir(filepath)
- op_filename = os.path.join(op, self.event_selector_heatmap + "_" + "heatmap")
- self.results_hm["plot"] = image
- self.results_hm["op"] = op_filename
- # self.save_plots(image, save_opts, op_filename)
- return image
- else:
- ropts = dict(
- width=int(width),
- height=int(height),
- ylabel="Trials",
- xlabel="Time (s)",
- fontsize=font_size,
- yticks=yticks,
- invert_yaxis=True,
- )
- dummy_image = hv.QuadMesh((time[0:100], event_ts_for_each_event, z_score[:, 0:100])).opts(
- colorbar=True, cmap=process_cmap(self.color_map, provider="matplotlib"), clim=clim
- )
- actual_image = hv.QuadMesh((time, event_ts_for_each_event, z_score))
-
- dynspread_img = datashade(actual_image, cmap=process_cmap(self.color_map, provider="matplotlib")).opts(
- **ropts
- ) # clims=self.C_Limit, cnorm='log'
- image = ((dummy_image * dynspread_img).opts(opts.QuadMesh(width=int(width), height=int(height)))).opts(
- shared_axes=False
- )
-
- save_opts = self.save_options_heatmap
- op = make_dir(filepath)
- op_filename = os.path.join(op, self.event_selector_heatmap + "_" + "heatmap")
- self.results_hm["plot"] = image
- self.results_hm["op"] = op_filename
-
- return image
-
- view = Viewer()
-
- # PSTH plot options
- psth_checkbox = pn.Param(
- view.param.select_trials_checkbox,
- widgets={
- "select_trials_checkbox": {
- "type": pn.widgets.CheckBoxGroup,
- "inline": True,
- "name": "Select mean and/or just trials",
- }
- },
- )
- parameters = pn.Param(
- view.param.selector_for_multipe_events_plot,
- widgets={
- "selector_for_multipe_events_plot": {"type": pn.widgets.CrossSelector, "width": 550, "align": "start"}
- },
- )
- heatmap_y_parameters = pn.Param(
- view.param.heatmap_y,
- widgets={
- "heatmap_y": {"type": pn.widgets.MultiSelect, "name": "Trial # - Timestamps", "width": 200, "size": 30}
- },
- )
- psth_y_parameters = pn.Param(
- view.param.psth_y,
- widgets={
- "psth_y": {
- "type": pn.widgets.MultiSelect,
- "name": "Trial # - Timestamps",
- "width": 200,
- "size": 15,
- "align": "start",
- }
- },
- )
-
- event_selector = pn.Param(
- view.param.event_selector, widgets={"event_selector": {"type": pn.widgets.Select, "width": 400}}
- )
- x_selector = pn.Param(view.param.x, widgets={"x": {"type": pn.widgets.Select, "width": 180}})
- y_selector = pn.Param(view.param.y, widgets={"y": {"type": pn.widgets.Select, "width": 180}})
-
- width_plot = pn.Param(view.param.Width_Plot, widgets={"Width_Plot": {"type": pn.widgets.Select, "width": 70}})
- height_plot = pn.Param(view.param.Height_Plot, widgets={"Height_Plot": {"type": pn.widgets.Select, "width": 70}})
- ylabel = pn.Param(view.param.Y_Label, widgets={"Y_Label": {"type": pn.widgets.Select, "width": 70}})
- save_opts = pn.Param(view.param.save_options, widgets={"save_options": {"type": pn.widgets.Select, "width": 70}})
-
- xlimit_plot = pn.Param(view.param.X_Limit, widgets={"X_Limit": {"type": pn.widgets.RangeSlider, "width": 180}})
- ylimit_plot = pn.Param(view.param.Y_Limit, widgets={"Y_Limit": {"type": pn.widgets.RangeSlider, "width": 180}})
- save_psth = pn.Param(view.param.save_psth, widgets={"save_psth": {"type": pn.widgets.Button, "width": 400}})
-
- options = pn.Column(
- event_selector,
- pn.Row(x_selector, y_selector),
- pn.Row(xlimit_plot, ylimit_plot),
- pn.Row(width_plot, height_plot, ylabel, save_opts),
- save_psth,
- )
-
- options_selectors = pn.Row(options, parameters)
-
- line_tab = pn.Column(
- "## " + basename,
- pn.Row(options_selectors, pn.Column(psth_checkbox, psth_y_parameters), width=1200),
- view.contPlot,
- view.update_selector,
- view.plot_specific_trials,
- )
-
- # Heatmap plot options
- event_selector_heatmap = pn.Param(
- view.param.event_selector_heatmap, widgets={"event_selector_heatmap": {"type": pn.widgets.Select, "width": 150}}
- )
- color_map = pn.Param(view.param.color_map, widgets={"color_map": {"type": pn.widgets.Select, "width": 150}})
- width_heatmap = pn.Param(
- view.param.width_heatmap, widgets={"width_heatmap": {"type": pn.widgets.Select, "width": 150}}
- )
- height_heatmap = pn.Param(
- view.param.height_heatmap, widgets={"height_heatmap": {"type": pn.widgets.Select, "width": 150}}
- )
- save_hm = pn.Param(view.param.save_hm, widgets={"save_hm": {"type": pn.widgets.Button, "width": 150}})
- save_options_heatmap = pn.Param(
- view.param.save_options_heatmap, widgets={"save_options_heatmap": {"type": pn.widgets.Select, "width": 150}}
- )
-
- hm_tab = pn.Column(
- "## " + basename,
- pn.Row(
- event_selector_heatmap,
- color_map,
- width_heatmap,
- height_heatmap,
- save_options_heatmap,
- pn.Column(pn.Spacer(height=25), save_hm),
- ),
- pn.Row(view.heatmap, heatmap_y_parameters),
- ) #
- logger.info("app")
-
- template = pn.template.MaterialTemplate(title="Visualization GUI")
-
- number = scanPortsAndFind(start_port=5000, end_port=5200)
- app = pn.Tabs(("PSTH", line_tab), ("Heat Map", hm_tab))
-
- template.main.append(app)
-
- template.show(port=number)
-
-
-# function to combine all the output folders together and preprocess them to use them in helper_plots function
-def createPlots(filepath, event, inputParameters):
-
- for i in range(len(event)):
- event[i] = event[i].replace("\\", "_")
- event[i] = event[i].replace("/", "_")
-
- average = inputParameters["visualizeAverageResults"]
- visualize_zscore_or_dff = inputParameters["visualize_zscore_or_dff"]
-
- if average == True:
- path = []
- for i in range(len(event)):
- if visualize_zscore_or_dff == "z_score":
- path.append(glob.glob(os.path.join(filepath, event[i] + "*_z_score_*")))
- elif visualize_zscore_or_dff == "dff":
- path.append(glob.glob(os.path.join(filepath, event[i] + "*_dff_*")))
-
- path = np.concatenate(path)
- else:
- if visualize_zscore_or_dff == "z_score":
- path = glob.glob(os.path.join(filepath, "z_score_*"))
- elif visualize_zscore_or_dff == "dff":
- path = glob.glob(os.path.join(filepath, "dff_*"))
-
- name_arr = []
- event_arr = []
-
- index = []
- for i in range(len(event)):
- if "control" in event[i].lower() or "signal" in event[i].lower():
- index.append(i)
-
- event = np.delete(event, index)
-
- for i in range(len(path)):
- name = (os.path.basename(path[i])).split(".")
- name = name[0]
- name_arr.append(name)
-
- if average == True:
- logger.info("average")
- helper_plots(filepath, name_arr, "", inputParameters)
- else:
- helper_plots(filepath, event, name_arr, inputParameters)
-
-
-def visualizeResults(inputParameters):
-
- inputParameters = inputParameters
-
- average = inputParameters["visualizeAverageResults"]
- logger.info(average)
-
- folderNames = inputParameters["folderNames"]
- folderNamesForAvg = inputParameters["folderNamesForAvg"]
- combine_data = inputParameters["combine_data"]
-
- if average == True and len(folderNamesForAvg) > 0:
- # folderNames = folderNamesForAvg
- filepath_avg = os.path.join(inputParameters["abspath"], "average")
- # filepath = os.path.join(inputParameters['abspath'], folderNames[0])
- storesListPath = []
- for i in range(len(folderNamesForAvg)):
- filepath = folderNamesForAvg[i]
- storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))))
- storesListPath = np.concatenate(storesListPath)
- storesList = np.asarray([[], []])
- for i in range(storesListPath.shape[0]):
- storesList = np.concatenate(
- (
- storesList,
- np.genfromtxt(
- os.path.join(storesListPath[i], "storesList.csv"), dtype="str", delimiter=","
- ).reshape(2, -1),
- ),
- axis=1,
- )
- storesList = np.unique(storesList, axis=1)
-
- createPlots(filepath_avg, np.unique(storesList[1, :]), inputParameters)
-
- else:
- if combine_data == True:
- storesListPath = []
- for i in range(len(folderNames)):
- filepath = folderNames[i]
- storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))))
- storesListPath = list(np.concatenate(storesListPath).flatten())
- op = get_all_stores_for_combining_data(storesListPath)
- for i in range(len(op)):
- storesList = np.asarray([[], []])
- for j in range(len(op[i])):
- storesList = np.concatenate(
- (
- storesList,
- np.genfromtxt(os.path.join(op[i][j], "storesList.csv"), dtype="str", delimiter=",").reshape(
- 2, -1
- ),
- ),
- axis=1,
- )
- storesList = np.unique(storesList, axis=1)
- filepath = op[i][0]
- createPlots(filepath, storesList[1, :], inputParameters)
- else:
- for i in range(len(folderNames)):
-
- filepath = folderNames[i]
- storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))
- for j in range(len(storesListPath)):
- filepath = storesListPath[j]
- storesList = np.genfromtxt(
- os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=","
- ).reshape(2, -1)
-
- createPlots(filepath, storesList[1, :], inputParameters)
-
-
-# logger.info(sys.argv[1:])
-# visualizeResults(sys.argv[1:][0])