diff --git a/.gitignore b/.gitignore index 0628429..f684eec 100755 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,5 @@ GuPPy/runFiberPhotometryAnalysis.ipynb .clinerules/ testing_data/ + +CLAUDE.md diff --git a/src/guppy/analysis/__init__.py b/src/guppy/analysis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/guppy/analysis/artifact_removal.py b/src/guppy/analysis/artifact_removal.py new file mode 100644 index 0000000..d3da042 --- /dev/null +++ b/src/guppy/analysis/artifact_removal.py @@ -0,0 +1,222 @@ +import logging + +import numpy as np + +logger = logging.getLogger(__name__) + + +def remove_artifacts( + timeForLightsTurnOn, + storesList, + pair_name_to_tsNew, + pair_name_to_sampling_rate, + pair_name_to_coords, + name_to_data, + compound_name_to_ttl_timestamps, + method, +): + if method == "concatenate": + name_to_corrected_data, pair_name_to_corrected_timestamps, compound_name_to_corrected_ttl_timestamps = ( + processTimestampsForArtifacts( + timeForLightsTurnOn, + storesList, + pair_name_to_tsNew, + pair_name_to_sampling_rate, + pair_name_to_coords, + name_to_data, + compound_name_to_ttl_timestamps, + ) + ) + logger.info("Artifacts removed using concatenate method.") + elif method == "replace with NaN": + name_to_corrected_data, compound_name_to_corrected_ttl_timestamps = addingNaNtoChunksWithArtifacts( + storesList, + pair_name_to_tsNew, + pair_name_to_coords, + name_to_data, + compound_name_to_ttl_timestamps, + ) + pair_name_to_corrected_timestamps = None + logger.info("Artifacts removed using NaN replacement method.") + else: + logger.error("Invalid artifact removal method specified.") + raise ValueError("Invalid artifact removal method specified.") + + return name_to_corrected_data, pair_name_to_corrected_timestamps, compound_name_to_corrected_ttl_timestamps + + +def addingNaNtoChunksWithArtifacts( + storesList, pair_name_to_tsNew, pair_name_to_coords, name_to_data, compound_name_to_ttl_timestamps +): + logger.debug("Replacing chunks with artifacts by NaN values.") + names_for_storenames = storesList[1, :] + pair_names = pair_name_to_tsNew.keys() + + name_to_corrected_data = {} + compound_name_to_corrected_ttl_timestamps = {} + for pair_name in pair_names: + tsNew = pair_name_to_tsNew[pair_name] + coords = pair_name_to_coords[pair_name] + for i in range(len(names_for_storenames)): + if ( + "control_" + pair_name.lower() in names_for_storenames[i].lower() + or "signal_" + pair_name.lower() in names_for_storenames[i].lower() + ): # changes done + data = name_to_data[names_for_storenames[i]].reshape(-1) + data = addingNaNValues(data=data, ts=tsNew, coords=coords) + name_to_corrected_data[names_for_storenames[i]] = data + else: + if "control" in names_for_storenames[i].lower() or "signal" in names_for_storenames[i].lower(): + continue + ttl_name = names_for_storenames[i] + compound_name = ttl_name + "_" + pair_name + ts = compound_name_to_ttl_timestamps[compound_name].reshape(-1) + ts = removeTTLs(ts=ts, coords=coords) + compound_name_to_corrected_ttl_timestamps[compound_name] = ts + logger.info("Chunks with artifacts are replaced by NaN values.") + + return name_to_corrected_data, compound_name_to_corrected_ttl_timestamps + + +# main function to align timestamps for control, signal and event timestamps for artifacts removal +def processTimestampsForArtifacts( + timeForLightsTurnOn, + storesList, + pair_name_to_tsNew, + pair_name_to_sampling_rate, + pair_name_to_coords, + name_to_data, + compound_name_to_ttl_timestamps, +): + logger.debug("Processing timestamps 
to get rid of artifacts using concatenate method...") + names_for_storenames = storesList[1, :] + pair_names = pair_name_to_tsNew.keys() + + name_to_corrected_data = {} + pair_name_to_corrected_timestamps = {} + compound_name_to_corrected_ttl_timestamps = {} + for pair_name in pair_names: + sampling_rate = pair_name_to_sampling_rate[pair_name] + tsNew = pair_name_to_tsNew[pair_name] + coords = pair_name_to_coords[pair_name] + + for i in range(len(names_for_storenames)): + if ( + "control_" + pair_name.lower() in names_for_storenames[i].lower() + or "signal_" + pair_name.lower() in names_for_storenames[i].lower() + ): # changes done + data = name_to_data[names_for_storenames[i]] + data, timestampNew = eliminateData( + data=data, + ts=tsNew, + coords=coords, + timeForLightsTurnOn=timeForLightsTurnOn, + sampling_rate=sampling_rate, + ) + name_to_corrected_data[names_for_storenames[i]] = data + pair_name_to_corrected_timestamps[pair_name] = timestampNew + else: + if "control" in names_for_storenames[i].lower() or "signal" in names_for_storenames[i].lower(): + continue + compound_name = names_for_storenames[i] + "_" + pair_name + ts = compound_name_to_ttl_timestamps[compound_name] + ts = eliminateTs( + ts=ts, + tsNew=tsNew, + coords=coords, + timeForLightsTurnOn=timeForLightsTurnOn, + sampling_rate=sampling_rate, + ) + compound_name_to_corrected_ttl_timestamps[compound_name] = ts + + logger.info("Timestamps processed, artifacts are removed and good chunks are concatenated.") + + return ( + name_to_corrected_data, + pair_name_to_corrected_timestamps, + compound_name_to_corrected_ttl_timestamps, + ) + + +# helper function to process control and signal timestamps +def eliminateData(*, data, ts, coords, timeForLightsTurnOn, sampling_rate): + + if (data == 0).all() == True: + data = np.zeros(ts.shape[0]) + + arr = np.array([]) + ts_arr = np.array([]) + for i in range(coords.shape[0]): + + index = np.where((ts > coords[i, 0]) & (ts < coords[i, 1]))[0] + + if len(arr) == 0: + arr = np.concatenate((arr, data[index])) + sub = ts[index][0] - timeForLightsTurnOn + new_ts = ts[index] - sub + ts_arr = np.concatenate((ts_arr, new_ts)) + else: + temp = data[index] + # new = temp + (arr[-1]-temp[0]) + temp_ts = ts[index] + new_ts = temp_ts - (temp_ts[0] - ts_arr[-1]) + arr = np.concatenate((arr, temp)) + ts_arr = np.concatenate((ts_arr, new_ts + (1 / sampling_rate))) + + # logger.info(arr.shape, ts_arr.shape) + return arr, ts_arr + + +# helper function to align event timestamps with the control and signal timestamps +def eliminateTs(*, ts, tsNew, coords, timeForLightsTurnOn, sampling_rate): + + ts_arr = np.array([]) + tsNew_arr = np.array([]) + for i in range(coords.shape[0]): + tsNew_index = np.where((tsNew > coords[i, 0]) & (tsNew < coords[i, 1]))[0] + ts_index = np.where((ts > coords[i, 0]) & (ts < coords[i, 1]))[0] + + if len(tsNew_arr) == 0: + sub = tsNew[tsNew_index][0] - timeForLightsTurnOn + tsNew_arr = np.concatenate((tsNew_arr, tsNew[tsNew_index] - sub)) + ts_arr = np.concatenate((ts_arr, ts[ts_index] - sub)) + else: + temp_tsNew = tsNew[tsNew_index] + temp_ts = ts[ts_index] + new_ts = temp_ts - (temp_tsNew[0] - tsNew_arr[-1]) + new_tsNew = temp_tsNew - (temp_tsNew[0] - tsNew_arr[-1]) + tsNew_arr = np.concatenate((tsNew_arr, new_tsNew + (1 / sampling_rate))) + ts_arr = np.concatenate((ts_arr, new_ts + (1 / sampling_rate))) + + return ts_arr + + +# adding nan values to removed chunks +# when using artifacts removal method - replace with NaN +def addingNaNValues(*, data, ts, coords): + + if (data == 
0).all() == True: + data = np.zeros(ts.shape[0]) + + arr = np.array([]) + ts_index = np.arange(ts.shape[0]) + for i in range(coords.shape[0]): + + index = np.where((ts > coords[i, 0]) & (ts < coords[i, 1]))[0] + arr = np.concatenate((arr, index)) + + nan_indices = list(set(ts_index).symmetric_difference(arr)) + data[nan_indices] = np.nan + + return data + + +# remove event TTLs which falls in the removed chunks +# when using artifacts removal method - replace with NaN +def removeTTLs(*, ts, coords): + ts_arr = np.array([]) + for i in range(coords.shape[0]): + ts_index = np.where((ts > coords[i, 0]) & (ts < coords[i, 1]))[0] + ts_arr = np.concatenate((ts_arr, ts[ts_index])) + + return ts_arr diff --git a/src/guppy/analysis/combine_data.py b/src/guppy/analysis/combine_data.py new file mode 100644 index 0000000..1eac5b6 --- /dev/null +++ b/src/guppy/analysis/combine_data.py @@ -0,0 +1,119 @@ +import logging +import os + +import numpy as np + +from .io_utils import ( + decide_naming_convention, +) + +logger = logging.getLogger(__name__) + + +def eliminateData(filepath_to_timestamps, filepath_to_data, timeForLightsTurnOn, sampling_rate): + + arr = np.array([]) + ts_arr = np.array([]) + filepaths = list(filepath_to_timestamps.keys()) + for filepath in filepaths: + ts = filepath_to_timestamps[filepath] + data = filepath_to_data[filepath] + + if len(arr) == 0: + arr = np.concatenate((arr, data)) + sub = ts[0] - timeForLightsTurnOn + new_ts = ts - sub + ts_arr = np.concatenate((ts_arr, new_ts)) + else: + temp = data + temp_ts = ts + new_ts = temp_ts - (temp_ts[0] - ts_arr[-1]) + arr = np.concatenate((arr, temp)) + ts_arr = np.concatenate((ts_arr, new_ts + (1 / sampling_rate))) + + return arr, ts_arr + + +def eliminateTs(filepath_to_timestamps, filepath_to_ttl_timestamps, timeForLightsTurnOn, sampling_rate): + + ts_arr = np.array([]) + tsNew_arr = np.array([]) + filepaths = list(filepath_to_timestamps.keys()) + for filepath in filepaths: + ts = filepath_to_timestamps[filepath] + tsNew = filepath_to_ttl_timestamps[filepath] + if len(tsNew_arr) == 0: + sub = tsNew[0] - timeForLightsTurnOn + tsNew_arr = np.concatenate((tsNew_arr, tsNew - sub)) + ts_arr = np.concatenate((ts_arr, ts - sub)) + else: + temp_tsNew = tsNew + temp_ts = ts + new_ts = temp_ts - (temp_tsNew[0] - tsNew_arr[-1]) + new_tsNew = temp_tsNew - (temp_tsNew[0] - tsNew_arr[-1]) + tsNew_arr = np.concatenate((tsNew_arr, new_tsNew + (1 / sampling_rate))) + ts_arr = np.concatenate((ts_arr, new_ts + (1 / sampling_rate))) + + # logger.info(event) + # logger.info(ts_arr) + return ts_arr + + +def combine_data( + filepaths_to_combine: list[str], + pair_name_to_filepath_to_timestamps: dict[str, dict[str, np.ndarray]], + display_name_to_filepath_to_data: dict[str, dict[str, np.ndarray]], + compound_name_to_filepath_to_ttl_timestamps: dict[str, dict[str, np.ndarray]], + timeForLightsTurnOn, + storesList, + sampling_rate, +): + # filepaths_to_combine = [folder1_output_i, folder2_output_i, ...] 
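+    # The nested dict arguments are keyed by channel / event name first and by
+    # per-session output folder second; a hypothetical sketch of their shape:
+    #
+    #     pair_name_to_filepath_to_timestamps = {"region1": {"/data/day1_output_1": ts, ...}}
+    #     display_name_to_filepath_to_data = {"signal_region1": {"/data/day1_output_1": data, ...}}
+    #     compound_name_to_filepath_to_ttl_timestamps = {"lick_region1": {"/data/day1_output_1": ttl_ts, ...}}
+    #
+    # eliminateData and eliminateTs above stitch the per-folder arrays end to end,
+    # shifting each later session so it continues 1/sampling_rate after the previous one.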
+ logger.debug("Processing timestamps for combining data...") + + names_for_storenames = storesList[1, :] + path = decide_naming_convention(filepaths_to_combine[0]) + + pair_name_to_tsNew = {} + display_name_to_data = {} + compound_name_to_ttl_timestamps = {} + for j in range(path.shape[1]): + name_1 = ((os.path.basename(path[0, j])).split(".")[0]).split("_")[-1] + name_2 = ((os.path.basename(path[1, j])).split(".")[0]).split("_")[-1] + if name_1 != name_2: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + pair_name = name_1 + + for i in range(len(names_for_storenames)): + if ( + "control_" + pair_name.lower() in names_for_storenames[i].lower() + or "signal_" + pair_name.lower() in names_for_storenames[i].lower() + ): + display_name = names_for_storenames[i] + filepath_to_timestamps = pair_name_to_filepath_to_timestamps[pair_name] + filepath_to_data = display_name_to_filepath_to_data[display_name] + data, timestampNew = eliminateData( + filepath_to_timestamps, + filepath_to_data, + timeForLightsTurnOn, + sampling_rate, + ) + pair_name_to_tsNew[pair_name] = timestampNew + display_name_to_data[display_name] = data + else: + if "control" in names_for_storenames[i].lower() or "signal" in names_for_storenames[i].lower(): + continue + compound_name = names_for_storenames[i] + "_" + pair_name + filepath_to_timestamps = pair_name_to_filepath_to_timestamps[pair_name] + filepath_to_ttl_timestamps = compound_name_to_filepath_to_ttl_timestamps[compound_name] + + ts = eliminateTs( + filepath_to_timestamps, + filepath_to_ttl_timestamps, + timeForLightsTurnOn, + sampling_rate, + ) + compound_name_to_ttl_timestamps[compound_name] = ts + + return pair_name_to_tsNew, display_name_to_data, compound_name_to_ttl_timestamps diff --git a/src/guppy/analysis/compute_psth.py b/src/guppy/analysis/compute_psth.py new file mode 100644 index 0000000..80aa8a7 --- /dev/null +++ b/src/guppy/analysis/compute_psth.py @@ -0,0 +1,186 @@ +import logging +import math + +import numpy as np + +logger = logging.getLogger(__name__) + + +# helper function to make PSTH for each event +def compute_psth( + z_score, + event, + filepath, + nSecPrev, + nSecPost, + timeInterval, + bin_psth_trials, + use_time_or_trials, + baselineStart, + baselineEnd, + naming, + just_use_signal, + sampling_rate, + ts, + corrected_timestamps, +): + + event = event.replace("\\", "_") + event = event.replace("/", "_") + + # calculate time before event timestamp and time after event timestamp + nTsPrev = int(round(nSecPrev * sampling_rate)) + nTsPost = int(round(nSecPost * sampling_rate)) + + totalTs = (-1 * nTsPrev) + nTsPost + increment = ((-1 * nSecPrev) + nSecPost) / totalTs + timeAxis = np.linspace(nSecPrev, nSecPost + increment, totalTs + 1) + timeAxisNew = np.concatenate((timeAxis, timeAxis[::-1])) + + # reject timestamps for which baseline cannot be calculated because of nan values + new_ts = [] + for i in range(ts.shape[0]): + thisTime = ts[i] # -1 not needed anymore + if thisTime < abs(baselineStart): + continue + else: + new_ts.append(ts[i]) + + # reject burst of timestamps + ts = np.asarray(new_ts) + # skip the event if there are no TTLs + if len(ts) == 0: + new_ts = np.array([]) + logger.info(f"Warning : No TTLs present for {event}. 
This will cause an error in Visualization step") + else: + new_ts = [ts[0]] + for i in range(1, ts.shape[0]): + thisTime = ts[i] + prevTime = new_ts[-1] + diff = thisTime - prevTime + if diff < timeInterval: + continue + else: + new_ts.append(ts[i]) + + # final timestamps + ts = np.asarray(new_ts) + nTs = ts.shape[0] + + # initialize PSTH vector + psth = np.full((nTs, totalTs + 1), np.nan) + psth_baselineUncorrected = np.full((nTs, totalTs + 1), np.nan) # extra + + # for each timestamp, create trial which will be saved in a PSTH vector + for i in range(nTs): + thisTime = ts[i] # -timeForLightsTurnOn + thisIndex = int(round(thisTime * sampling_rate)) + arr = rowFormation(z_score, thisIndex, -1 * nTsPrev, nTsPost) + if just_use_signal == True: + res = np.subtract(arr, np.nanmean(arr)) + z_score_arr = np.divide(res, np.nanstd(arr)) + arr = z_score_arr + else: + arr = arr + + psth_baselineUncorrected[i, :] = arr # extra + psth[i, :] = baselineCorrection(arr, timeAxis, baselineStart, baselineEnd) + + columns = list(ts) + + if use_time_or_trials == "Time (min)" and bin_psth_trials > 0: + corrected_timestamps = np.divide(corrected_timestamps, 60) + ts_min = np.divide(ts, 60) + bin_steps = np.arange(corrected_timestamps[0], corrected_timestamps[-1] + bin_psth_trials, bin_psth_trials) + indices_each_step = dict() + for i in range(1, bin_steps.shape[0]): + indices_each_step[f"{np.around(bin_steps[i-1],0)}-{np.around(bin_steps[i],0)}"] = np.where( + (ts_min >= bin_steps[i - 1]) & (ts_min <= bin_steps[i]) + )[0] + elif use_time_or_trials == "# of trials" and bin_psth_trials > 0: + bin_steps = np.arange(0, ts.shape[0], bin_psth_trials) + if bin_steps[-1] < ts.shape[0]: + bin_steps = np.concatenate((bin_steps, [ts.shape[0]]), axis=0) + indices_each_step = dict() + for i in range(1, bin_steps.shape[0]): + indices_each_step[f"{bin_steps[i-1]}-{bin_steps[i]}"] = np.arange(bin_steps[i - 1], bin_steps[i]) + else: + indices_each_step = dict() + + psth_bin, psth_bin_baselineUncorrected = [], [] + if indices_each_step: + keys = list(indices_each_step.keys()) + for k in keys: + # no trials in a given bin window, just put all the nan values + if indices_each_step[k].shape[0] == 0: + psth_bin.append(np.full(psth.shape[1], np.nan)) + psth_bin_baselineUncorrected.append(np.full(psth_baselineUncorrected.shape[1], np.nan)) + psth_bin.append(np.full(psth.shape[1], np.nan)) + psth_bin_baselineUncorrected.append(np.full(psth_baselineUncorrected.shape[1], np.nan)) + else: + index = indices_each_step[k] + arr = psth[index, :] + # mean of bins + psth_bin.append(np.nanmean(psth[index, :], axis=0)) + psth_bin_baselineUncorrected.append(np.nanmean(psth_baselineUncorrected[index, :], axis=0)) + psth_bin.append(np.nanstd(psth[index, :], axis=0) / math.sqrt(psth[index, :].shape[0])) + # error of bins + psth_bin_baselineUncorrected.append( + np.nanstd(psth_baselineUncorrected[index, :], axis=0) + / math.sqrt(psth_baselineUncorrected[index, :].shape[0]) + ) + + # adding column names + columns.append(f"bin_({k})") + columns.append(f"bin_err_({k})") + + psth = np.concatenate((psth, psth_bin), axis=0) + psth_baselineUncorrected = np.concatenate((psth_baselineUncorrected, psth_bin_baselineUncorrected), axis=0) + + timeAxis = timeAxis.reshape(1, -1) + psth = np.concatenate((psth, timeAxis), axis=0) + psth_baselineUncorrected = np.concatenate((psth_baselineUncorrected, timeAxis), axis=0) + columns.append("timestamps") + + return psth, psth_baselineUncorrected, columns, ts + + +# function to create PSTH trials corresponding to each event 
timestamp +def rowFormation(z_score, thisIndex, nTsPrev, nTsPost): + + if nTsPrev < thisIndex and z_score.shape[0] > (thisIndex + nTsPost): + res = z_score[thisIndex - nTsPrev - 1 : thisIndex + nTsPost] + elif nTsPrev >= thisIndex and z_score.shape[0] > (thisIndex + nTsPost): + mismatch = nTsPrev - thisIndex + 1 + res = np.zeros(nTsPrev + nTsPost + 1) + res[:mismatch] = np.nan + res[mismatch:] = z_score[: thisIndex + nTsPost] + elif nTsPrev >= thisIndex and z_score.shape[0] < (thisIndex + nTsPost): + mismatch1 = nTsPrev - thisIndex + 1 + mismatch2 = (thisIndex + nTsPost) - z_score.shape[0] + res1 = np.full(mismatch1, np.nan) + res2 = z_score + res3 = np.full(mismatch2, np.nan) + res = np.concatenate((res1, np.concatenate((res2, res3)))) + else: + mismatch = (thisIndex + nTsPost) - z_score.shape[0] + res1 = np.zeros(mismatch) + res1[:] = np.nan + res2 = z_score[thisIndex - nTsPrev - 1 : z_score.shape[0]] + res = np.concatenate((res2, res1)) + + return res + + +# function to calculate baseline for each PSTH trial and do baseline correction +def baselineCorrection(arr, timeAxis, baselineStart, baselineEnd): + baselineStrtPt = np.where(timeAxis >= baselineStart)[0] + baselineEndPt = np.where(timeAxis >= baselineEnd)[0] + + if baselineStart == 0 and baselineEnd == 0: + return arr + + baseline = np.nanmean(arr[baselineStrtPt[0] : baselineEndPt[0]]) + baselineSub = np.subtract(arr, baseline) + + return baselineSub diff --git a/src/guppy/analysis/control_channel.py b/src/guppy/analysis/control_channel.py new file mode 100644 index 0000000..605bd17 --- /dev/null +++ b/src/guppy/analysis/control_channel.py @@ -0,0 +1,122 @@ +import logging +import os +import shutil + +import numpy as np +import pandas as pd +from scipy import signal as ss +from scipy.optimize import curve_fit + +from .io_utils import ( + read_hdf5, + write_hdf5, +) + +logger = logging.getLogger(__name__) + + +# This function just creates placeholder Control-HDF5 files that are then immediately overwritten later on in the pipeline. +# TODO: Refactor this function to avoid unnecessary file creation. 
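+# The `arr` argument used throughout this module is the 2-row storesList array:
+# row 0 holds the raw store names and row 1 the user-facing channel names, e.g.
+# (hypothetical store names):
+#
+#     arr = np.array([["Dv1A", "Dv2A", "PrtN"],
+#                     ["control_region1", "signal_region1", "lick"]])
+#
+# For any "signal_<name>" without a matching "control_<name>", add_control_channel
+# copies the signal HDF5 file to a placeholder "cntrl<i>" store and appends a
+# matching "control_<name>" column before rewriting storesList.csv.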
+# function to add control channel when there is no +# isosbestic control channel and update the storeslist file +def add_control_channel(filepath, arr): + + storenames = arr[0, :] + storesList = np.char.lower(arr[1, :]) + + keep_control = np.array([]) + # check a case if there is isosbestic control channel present + for i in range(storesList.shape[0]): + if "control" in storesList[i].lower(): + name = storesList[i].split("_")[-1] + new_str = "signal_" + str(name).lower() + find_signal = [True for i in storesList if i == new_str] + if len(find_signal) > 1: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + if len(find_signal) == 0: + logger.error( + "Isosbectic control channel parameter is set to False and still \ + storeslist file shows there is control channel present" + ) + raise Exception( + "Isosbectic control channel parameter is set to False and still \ + storeslist file shows there is control channel present" + ) + else: + continue + + for i in range(storesList.shape[0]): + if "signal" in storesList[i].lower(): + name = storesList[i].split("_")[-1] + new_str = "control_" + str(name).lower() + find_signal = [True for i in storesList if i == new_str] + if len(find_signal) == 0: + src, dst = os.path.join(filepath, arr[0, i] + ".hdf5"), os.path.join( + filepath, "cntrl" + str(i) + ".hdf5" + ) + shutil.copyfile(src, dst) + arr = np.concatenate((arr, [["cntrl" + str(i)], ["control_" + str(arr[1, i].split("_")[-1])]]), axis=1) + + np.savetxt(os.path.join(filepath, "storesList.csv"), arr, delimiter=",", fmt="%s") + + return arr + + +# main function to create control channel using +# signal channel and save it to a file +def create_control_channel(filepath, arr, window=5001): + + storenames = arr[0, :] + storesList = arr[1, :] + + for i in range(storesList.shape[0]): + event_name, event = storesList[i], storenames[i] + if "control" in event_name.lower() and "cntrl" in event.lower(): + logger.debug("Creating control channel from signal channel using curve-fitting") + name = event_name.split("_")[-1] + signal = read_hdf5("signal_" + name, filepath, "data") + timestampNew = read_hdf5("timeCorrection_" + name, filepath, "timestampNew") + sampling_rate = np.full(timestampNew.shape, np.nan) + sampling_rate[0] = read_hdf5("timeCorrection_" + name, filepath, "sampling_rate")[0] + + control = helper_create_control_channel(signal, timestampNew, window) + + write_hdf5(control, event_name, filepath, "data") + d = {"timestamps": timestampNew, "data": control, "sampling_rate": sampling_rate} + df = pd.DataFrame(d) + df.to_csv(os.path.join(os.path.dirname(filepath), event.lower() + ".csv"), index=False) + logger.info("Control channel from signal channel created using curve-fitting") + + +# TODO: figure out why a control channel is created for both timestamp correction and z-score steps. 
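+# The helper below smooths the signal with a Savitzky-Golay filter and then fits it
+# to a decaying exponential, control(t) = a + b * exp(-t / c), returning the fitted
+# curve as the surrogate control channel. A minimal usage sketch on made-up data:
+#
+#     t = np.linspace(0, 600, 60000)                            # ~100 Hz for 600 s
+#     sig = 2 + 5 * np.exp(-t / 120) + np.random.normal(0, 0.1, t.size)
+#     control = helper_create_control_channel(sig, t, window=5001)
+#
+# Note that curve_fit is seeded with p0 = [5, 50, 60]; if the fit raises, the error
+# is only logged, so `popt` would be undefined on the following line.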
+# helper function to create control channel using signal channel +# by curve fitting signal channel to exponential function +# when there is no isosbestic control channel is present +def helper_create_control_channel(signal, timestamps, window): + # check if window is greater than signal shape + if window > signal.shape[0]: + window = ((signal.shape[0] + 1) / 2) + 1 + if window % 2 != 0: + window = window + else: + window = window + 1 + + filtered_signal = ss.savgol_filter(signal, window_length=window, polyorder=3) + + p0 = [5, 50, 60] + + try: + popt, pcov = curve_fit(curveFitFn, timestamps, filtered_signal, p0) + except Exception as e: + logger.error(str(e)) + + # logger.info('Curve Fit Parameters : ', popt) + control = curveFitFn(timestamps, *popt) + + return control + + +# curve fit exponential function +def curveFitFn(x, a, b, c): + return a + (b * np.exp(-(1 / c) * x)) diff --git a/src/guppy/analysis/cross_correlation.py b/src/guppy/analysis/cross_correlation.py new file mode 100644 index 0000000..726943d --- /dev/null +++ b/src/guppy/analysis/cross_correlation.py @@ -0,0 +1,24 @@ +import logging + +import numpy as np +from scipy import signal + +logger = logging.getLogger(__name__) + + +def compute_cross_correlation(arr_A, arr_B, sample_rate): + cross_corr = list() + for a, b in zip(arr_A, arr_B): + if np.isnan(a).any() or np.isnan(b).any(): + corr = signal.correlate(a, b, method="direct") + else: + corr = signal.correlate(a, b) + corr_norm = corr / np.max(np.abs(corr)) + cross_corr.append(corr_norm) + lag = signal.correlation_lags(len(a), len(b)) + lag_msec = np.array(lag / sample_rate, dtype="float32") + + cross_corr_arr = np.array(cross_corr, dtype="float32") + lag_msec = lag_msec.reshape(1, -1) + cross_corr_arr = np.concatenate((cross_corr_arr, lag_msec), axis=0) + return cross_corr_arr diff --git a/src/guppy/analysis/io_utils.py b/src/guppy/analysis/io_utils.py new file mode 100644 index 0000000..742ab3b --- /dev/null +++ b/src/guppy/analysis/io_utils.py @@ -0,0 +1,226 @@ +import fnmatch +import glob +import logging +import os +import re + +import h5py +import numpy as np +import pandas as pd + +logger = logging.getLogger(__name__) + + +def takeOnlyDirs(paths): + removePaths = [] + for p in paths: + if os.path.isfile(p): + removePaths.append(p) + return list(set(paths) - set(removePaths)) + + +# find files by ignoring the case sensitivity +def find_files(path, glob_path, ignore_case=False): + rule = ( + re.compile(fnmatch.translate(glob_path), re.IGNORECASE) + if ignore_case + else re.compile(fnmatch.translate(glob_path)) + ) + + no_bytes_path = os.listdir(os.path.expanduser(path)) + str_path = [] + + # converting byte object to string + for x in no_bytes_path: + try: + str_path.append(x.decode("utf-8")) + except: + str_path.append(x) + return [os.path.join(path, n) for n in str_path if rule.match(n)] + + +# check if dealing with TDT files or csv files +def check_TDT(filepath): + path = glob.glob(os.path.join(filepath, "*.tsq")) + if len(path) > 0: + return True + else: + return False + + +# function to read hdf5 file +def read_hdf5(event, filepath, key): + if event: + event = event.replace("\\", "_") + event = event.replace("/", "_") + op = os.path.join(filepath, event + ".hdf5") + else: + op = filepath + + if os.path.exists(op): + with h5py.File(op, "r") as f: + arr = np.asarray(f[key]) + else: + logger.error(f"{event}.hdf5 file does not exist") + raise Exception("{}.hdf5 file does not exist".format(event)) + + return arr + + +# function to write hdf5 file +def 
write_hdf5(data, event, filepath, key): + event = event.replace("\\", "_") + event = event.replace("/", "_") + op = os.path.join(filepath, event + ".hdf5") + + # if file does not exist create a new file + if not os.path.exists(op): + with h5py.File(op, "w") as f: + if type(data) is np.ndarray: + f.create_dataset(key, data=data, maxshape=(None,), chunks=True) + else: + f.create_dataset(key, data=data) + + # if file already exists, append data to it or add a new key to it + else: + with h5py.File(op, "r+") as f: + if key in list(f.keys()): + if type(data) is np.ndarray: + f[key].resize(data.shape) + arr = f[key] + arr[:] = data + else: + arr = f[key] + arr = data + else: + if type(data) is np.ndarray: + f.create_dataset(key, data=data, maxshape=(None,), chunks=True) + else: + f.create_dataset(key, data=data) + + +# function to check if the naming convention for saving storeslist file was followed or not +def decide_naming_convention(filepath): + path_1 = find_files(filepath, "control_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'control*')) + + path_2 = find_files(filepath, "signal_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'signal*')) + + path = sorted(path_1 + path_2, key=str.casefold) + if len(path) % 2 != 0: + logger.error("There are not equal number of Control and Signal data") + raise Exception("There are not equal number of Control and Signal data") + + path = np.asarray(path).reshape(2, -1) + + return path + + +# function to read coordinates file which was saved by selecting chunks for artifacts removal +def fetchCoords(filepath, naming, data): + + path = os.path.join(filepath, "coordsForPreProcessing_" + naming + ".npy") + + if not os.path.exists(path): + coords = np.array([0, data[-1]]) + else: + coords = np.load(os.path.join(filepath, "coordsForPreProcessing_" + naming + ".npy"))[:, 0] + + if coords.shape[0] % 2 != 0: + logger.error("Number of values in coordsForPreProcessing file is not even.") + raise Exception("Number of values in coordsForPreProcessing file is not even.") + + coords = coords.reshape(-1, 2) + + return coords + + +def get_coords(filepath, name, tsNew, removeArtifacts): # TODO: Make less redundant with fetchCoords + if removeArtifacts == True: + coords = fetchCoords(filepath, name, tsNew) + else: + dt = tsNew[1] - tsNew[0] + coords = np.array([[tsNew[0] - dt, tsNew[-1] + dt]]) + return coords + + +def get_all_stores_for_combining_data(folderNames): + op = [] + for i in range(100): + temp = [] + match = r"[\s\S]*" + "_output_" + str(i) + for j in folderNames: + temp.append(re.findall(match, j)) + temp = sorted(list(np.concatenate(temp).flatten()), key=str.casefold) + if len(temp) > 0: + op.append(temp) + + return op + + +# for combining data, reading storeslist file from both data and create a new storeslist array +def check_storeslistfile(folderNames): + storesList = np.array([[], []]) + for i in range(len(folderNames)): + filepath = folderNames[i] + storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) + for j in range(len(storesListPath)): + filepath = storesListPath[j] + storesList = np.concatenate( + ( + storesList, + np.genfromtxt(os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=",").reshape(2, -1), + ), + axis=1, + ) + + storesList = np.unique(storesList, axis=1) + + return storesList + + +def get_control_and_signal_channel_names(storesList): + storenames = storesList[0, :] + names_for_storenames = storesList[1, :] + + channels_arr = [] + for i in range(names_for_storenames.shape[0]): + if 
"control" in names_for_storenames[i].lower() or "signal" in names_for_storenames[i].lower(): + channels_arr.append(names_for_storenames[i]) + + channels_arr = sorted(channels_arr, key=str.casefold) + try: + channels_arr = np.asarray(channels_arr).reshape(2, -1) + except: + logger.error("Error in saving stores list file or spelling mistake for control or signal") + raise Exception("Error in saving stores list file or spelling mistake for control or signal") + + return channels_arr + + +# function to read h5 file and make a dataframe from it +def read_Df(filepath, event, name): + event = event.replace("\\", "_") + event = event.replace("/", "_") + if name: + op = os.path.join(filepath, event + "_{}.h5".format(name)) + else: + op = os.path.join(filepath, event + ".h5") + df = pd.read_hdf(op, key="df", mode="r") + + return df + + +def make_dir_for_cross_correlation(filepath): + op = os.path.join(filepath, "cross_correlation_output") + if not os.path.exists(op): + os.mkdir(op) + return op + + +def makeAverageDir(filepath): + + op = os.path.join(filepath, "average") + if not os.path.exists(op): + os.mkdir(op) + + return op diff --git a/src/guppy/analysis/psth_average.py b/src/guppy/analysis/psth_average.py new file mode 100644 index 0000000..664cc3d --- /dev/null +++ b/src/guppy/analysis/psth_average.py @@ -0,0 +1,215 @@ +import glob +import logging +import math +import os +import re + +import numpy as np +import pandas as pd + +from .io_utils import ( + make_dir_for_cross_correlation, + makeAverageDir, + read_Df, + write_hdf5, +) +from .psth_utils import create_Df_for_psth, getCorrCombinations + +logger = logging.getLogger(__name__) + + +# function to compute average of group of recordings +def averageForGroup(folderNames, event, inputParameters): + + event = event.replace("\\", "_") + event = event.replace("/", "_") + + logger.debug("Averaging group of data...") + path = [] + abspath = inputParameters["abspath"] + selectForComputePsth = inputParameters["selectForComputePsth"] + path_temp_len = [] + op = makeAverageDir(abspath) + + # combining paths to all the selected folders for doing average + for i in range(len(folderNames)): + if selectForComputePsth == "z_score": + path_temp = glob.glob(os.path.join(folderNames[i], "z_score_*")) + elif selectForComputePsth == "dff": + path_temp = glob.glob(os.path.join(folderNames[i], "dff_*")) + else: + path_temp = glob.glob(os.path.join(folderNames[i], "z_score_*")) + glob.glob( + os.path.join(folderNames[i], "dff_*") + ) + + path_temp_len.append(len(path_temp)) + # path_temp = glob.glob(os.path.join(folderNames[i], 'z_score_*')) + for j in range(len(path_temp)): + basename = (os.path.basename(path_temp[j])).split(".")[0] + write_hdf5(np.array([]), basename, op, "data") + name_1 = basename.split("_")[-1] + temp = [folderNames[i], event + "_" + name_1, basename] + path.append(temp) + + # processing of all the paths + path_temp_len = np.asarray(path_temp_len) + max_len = np.argmax(path_temp_len) + + naming = [] + for i in range(len(path)): + naming.append(path[i][2]) + naming = np.unique(np.asarray(naming)) + + new_path = [[] for _ in range(path_temp_len[max_len])] + for i in range(len(path)): + idx = np.where(naming == path[i][2])[0][0] + new_path[idx].append(path[i]) + + # read PSTH for each event and make the average of it. Save the final output to an average folder. 
+ for i in range(len(new_path)): + psth, psth_bins = [], [] + columns = [] + bins_cols = [] + temp_path = new_path[i] + for j in range(len(temp_path)): + # logger.info(os.path.join(temp_path[j][0], temp_path[j][1]+'_{}.h5'.format(temp_path[j][2]))) + if not os.path.exists(os.path.join(temp_path[j][0], temp_path[j][1] + "_{}.h5".format(temp_path[j][2]))): + continue + else: + df = read_Df(temp_path[j][0], temp_path[j][1], temp_path[j][2]) # filepath, event, name + cols = list(df.columns) + regex = re.compile("bin_[(]") + bins_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] + psth.append(np.asarray(df["mean"])) + columns.append(os.path.basename(temp_path[j][0])) + if len(bins_cols) > 0: + psth_bins.append(df[bins_cols]) + + if len(psth) == 0: + logger.warning("Something is wrong with the file search pattern.") + continue + + if len(bins_cols) > 0: + df_bins = pd.concat(psth_bins, axis=1) + df_bins_mean = df_bins.groupby(by=df_bins.columns, axis=1).mean() + df_bins_err = df_bins.groupby(by=df_bins.columns, axis=1).std() / math.sqrt(df_bins.shape[1]) + cols_err = list(df_bins_err.columns) + dict_err = {} + for i in cols_err: + split = i.split("_") + dict_err[i] = "{}_err_{}".format(split[0], split[1]) + df_bins_err = df_bins_err.rename(columns=dict_err) + columns = columns + list(df_bins_mean.columns) + list(df_bins_err.columns) + df_bins_mean_err = pd.concat([df_bins_mean, df_bins_err], axis=1).T + psth, df_bins_mean_err = np.asarray(psth), np.asarray(df_bins_mean_err) + psth = np.concatenate((psth, df_bins_mean_err), axis=0) + else: + psth = psth_shape_check(psth) + psth = np.asarray(psth) + + timestamps = np.asarray(df["timestamps"]).reshape(1, -1) + psth = np.concatenate((psth, timestamps), axis=0) + columns = columns + ["timestamps"] + create_Df_for_psth(op, temp_path[j][1], temp_path[j][2], psth, columns=columns) + + # read PSTH peak and area for each event and combine them. Save the final output to an average folder + for i in range(len(new_path)): + arr = [] + index = [] + temp_path = new_path[i] + for j in range(len(temp_path)): + if not os.path.exists( + os.path.join(temp_path[j][0], "peak_AUC_" + temp_path[j][1] + "_" + temp_path[j][2] + ".h5") + ): + continue + else: + df = read_Df_area_peak(temp_path[j][0], temp_path[j][1] + "_" + temp_path[j][2]) + arr.append(df) + index.append(list(df.index)) + + if len(arr) == 0: + logger.warning("Something is wrong with the file search pattern.") + continue + index = list(np.concatenate(index)) + new_df = pd.concat(arr, axis=0) # os.path.join(filepath, 'peak_AUC_'+name+'.csv') + new_df.to_csv(os.path.join(op, "peak_AUC_{}_{}.csv".format(temp_path[j][1], temp_path[j][2])), index=index) + new_df.to_hdf( + os.path.join(op, "peak_AUC_{}_{}.h5".format(temp_path[j][1], temp_path[j][2])), + key="df", + mode="w", + index=index, + ) + + # read cross-correlation files and combine them. 
Save the final output to an average folder + type = [] + for i in range(len(folderNames)): + _, temp_type = getCorrCombinations(folderNames[i], inputParameters) + type.append(temp_type) + + type = np.unique(np.array(type)) + for i in range(len(type)): + corr = [] + columns = [] + df = None + for j in range(len(folderNames)): + corr_info, _ = getCorrCombinations(folderNames[j], inputParameters) + for k in range(1, len(corr_info)): + path = os.path.join( + folderNames[j], + "cross_correlation_output", + "corr_" + event + "_" + type[i] + "_" + corr_info[k - 1] + "_" + corr_info[k], + ) + if not os.path.exists(path + ".h5"): + continue + else: + df = read_Df( + os.path.join(folderNames[j], "cross_correlation_output"), + "corr_" + event, + type[i] + "_" + corr_info[k - 1] + "_" + corr_info[k], + ) + corr.append(df["mean"]) + columns.append(os.path.basename(folderNames[j])) + + if not isinstance(df, pd.DataFrame): + break + + corr = np.array(corr) + timestamps = np.array(df["timestamps"]).reshape(1, -1) + corr = np.concatenate((corr, timestamps), axis=0) + columns.append("timestamps") + create_Df_for_psth( + make_dir_for_cross_correlation(op), + "corr_" + event, + type[i] + "_" + corr_info[k - 1] + "_" + corr_info[k], + corr, + columns=columns, + ) + + logger.info("Group of data averaged.") + + +def psth_shape_check(psth): + + each_ln = [] + for i in range(len(psth)): + each_ln.append(psth[i].shape[0]) + + each_ln = np.asarray(each_ln) + keep_ln = each_ln[-1] + + for i in range(len(psth)): + if psth[i].shape[0] > keep_ln: + psth[i] = psth[i][:keep_ln] + elif psth[i].shape[0] < keep_ln: + psth[i] = np.append(psth[i], np.full(keep_ln - len(psth[i]), np.nan)) + else: + psth[i] = psth[i] + + return psth + + +def read_Df_area_peak(filepath, name): + op = os.path.join(filepath, "peak_AUC_" + name + ".h5") + df = pd.read_hdf(op, key="df", mode="r") + + return df diff --git a/src/guppy/analysis/psth_peak_and_area.py b/src/guppy/analysis/psth_peak_and_area.py new file mode 100644 index 0000000..2c2c421 --- /dev/null +++ b/src/guppy/analysis/psth_peak_and_area.py @@ -0,0 +1,48 @@ +import logging +from collections import OrderedDict + +import numpy as np + +logger = logging.getLogger(__name__) + + +def compute_psth_peak_and_area(psth_mean, timestamps, sampling_rate, peak_startPoint, peak_endPoint): + + peak_startPoint = np.asarray(peak_startPoint) + peak_endPoint = np.asarray(peak_endPoint) + + peak_startPoint = peak_startPoint[~np.isnan(peak_startPoint)] + peak_endPoint = peak_endPoint[~np.isnan(peak_endPoint)] + + if peak_startPoint.shape[0] != peak_endPoint.shape[0]: + logger.error("Number of Peak Start Time and Peak End Time are unequal.") + raise Exception("Number of Peak Start Time and Peak End Time are unequal.") + + if np.less_equal(peak_endPoint, peak_startPoint).any() == True: + logger.error( + "Peak End Time is lesser than or equal to Peak Start Time. Please check the Peak parameters window." + ) + raise Exception( + "Peak End Time is lesser than or equal to Peak Start Time. Please check the Peak parameters window." 
+ ) + + peak_and_area = OrderedDict() + + if peak_startPoint.shape[0] == 0 or peak_endPoint.shape[0] == 0: + peak_and_area["peak"] = np.nan + peak_and_area["area"] = np.nan + + for i in range(peak_startPoint.shape[0]): + startPtForPeak = np.where(timestamps >= peak_startPoint[i])[0] + endPtForPeak = np.where(timestamps >= peak_endPoint[i])[0] + if len(startPtForPeak) >= 1 and len(endPtForPeak) >= 1: + peakPoint_pos = startPtForPeak[0] + np.argmax(psth_mean[startPtForPeak[0] : endPtForPeak[0], :], axis=0) + peakPoint_neg = startPtForPeak[0] + np.argmin(psth_mean[startPtForPeak[0] : endPtForPeak[0], :], axis=0) + peak_and_area["peak_pos_" + str(i + 1)] = np.amax(psth_mean[peakPoint_pos], axis=0) + peak_and_area["peak_neg_" + str(i + 1)] = np.amin(psth_mean[peakPoint_neg], axis=0) + peak_and_area["area_" + str(i + 1)] = np.trapz(psth_mean[startPtForPeak[0] : endPtForPeak[0], :], axis=0) + else: + peak_and_area["peak_" + str(i + 1)] = np.nan + peak_and_area["area_" + str(i + 1)] = np.nan + + return peak_and_area diff --git a/src/guppy/analysis/psth_utils.py b/src/guppy/analysis/psth_utils.py new file mode 100644 index 0000000..c351511 --- /dev/null +++ b/src/guppy/analysis/psth_utils.py @@ -0,0 +1,132 @@ +import glob +import logging +import math +import os +import re + +import numpy as np +import pandas as pd + +from .io_utils import read_hdf5 + +logger = logging.getLogger(__name__) + + +# function to create dataframe for each event PSTH and save it to h5 file +def create_Df_for_psth(filepath, event, name, psth, columns=[]): + event = event.replace("\\", "_") + event = event.replace("/", "_") + if name: + op = os.path.join(filepath, event + "_{}.h5".format(name)) + else: + op = os.path.join(filepath, event + ".h5") + + # check if file already exists + # if os.path.exists(op): + # return 0 + + # removing psth binned trials + columns = np.array(columns, dtype="str") + regex = re.compile("bin_*") + single_trials = columns[[i for i in range(len(columns)) if not regex.match(columns[i])]] + single_trials_index = [i for i in range(len(single_trials)) if single_trials[i] != "timestamps"] + + psth = psth.T + if psth.ndim > 1: + mean = np.nanmean(psth[:, single_trials_index], axis=1).reshape(-1, 1) + err = np.nanstd(psth[:, single_trials_index], axis=1) / math.sqrt(psth[:, single_trials_index].shape[1]) + err = err.reshape(-1, 1) + psth = np.hstack((psth, mean)) + psth = np.hstack((psth, err)) + # timestamps = np.asarray(read_Df(filepath, 'ts_psth', '')) + # psth = np.hstack((psth, timestamps)) + try: + ts = read_hdf5(event, filepath, "ts") + ts = np.append(ts, ["mean", "err"]) + except: + ts = None + + if len(columns) == 0: + df = pd.DataFrame(psth, index=None, columns=ts, dtype="float32") + else: + columns = np.asarray(columns) + columns = np.append(columns, ["mean", "err"]) + df = pd.DataFrame(psth, index=None, columns=list(columns), dtype="float32") + + df.to_hdf(op, key="df", mode="w") + + +# same function used to store PSTH in computePsth file +# Here, cross correlation dataframe is saved instead of PSTH +# cross correlation dataframe has the same structure as PSTH file +def create_Df_for_cross_correlation(filepath, event, name, psth, columns=[]): + if name: + op = os.path.join(filepath, event + "_{}.h5".format(name)) + else: + op = os.path.join(filepath, event + ".h5") + + # check if file already exists + # if os.path.exists(op): + # return 0 + + # removing psth binned trials + columns = list(np.array(columns, dtype="str")) + regex = re.compile("bin_*") + single_trials_index = [i for i in 
range(len(columns)) if not regex.match(columns[i])] + single_trials_index = [i for i in range(len(columns)) if columns[i] != "timestamps"] + + psth = psth.T + if psth.ndim > 1: + mean = np.nanmean(psth[:, single_trials_index], axis=1).reshape(-1, 1) + err = np.nanstd(psth[:, single_trials_index], axis=1) / math.sqrt(psth[:, single_trials_index].shape[1]) + err = err.reshape(-1, 1) + psth = np.hstack((psth, mean)) + psth = np.hstack((psth, err)) + # timestamps = np.asarray(read_Df(filepath, 'ts_psth', '')) + # psth = np.hstack((psth, timestamps)) + try: + ts = read_hdf5(event, filepath, "ts") + ts = np.append(ts, ["mean", "err"]) + except: + ts = None + + if len(columns) == 0: + df = pd.DataFrame(psth, index=None, columns=ts, dtype="float32") + else: + columns = np.asarray(columns) + columns = np.append(columns, ["mean", "err"]) + df = pd.DataFrame(psth, index=None, columns=columns, dtype="float32") + + df.to_hdf(op, key="df", mode="w") + + +def getCorrCombinations(filepath, inputParameters): + selectForComputePsth = inputParameters["selectForComputePsth"] + if selectForComputePsth == "z_score": + path = glob.glob(os.path.join(filepath, "z_score_*")) + elif selectForComputePsth == "dff": + path = glob.glob(os.path.join(filepath, "dff_*")) + else: + path = glob.glob(os.path.join(filepath, "z_score_*")) + glob.glob(os.path.join(filepath, "dff_*")) + + names = list() + type = list() + for i in range(len(path)): + basename = (os.path.basename(path[i])).split(".")[0] + names.append(basename.split("_")[-1]) + type.append((os.path.basename(path[i])).split(".")[0].split("_" + names[-1], 1)[0]) + + names = list(np.unique(np.array(names))) + type = list(np.unique(np.array(type))) + + corr_info = list() + if len(names) <= 1: + logger.info("Cross-correlation cannot be computed because only one signal is present.") + return corr_info, type + elif len(names) == 2: + corr_info = names + else: + corr_info = names + corr_info.append(names[0]) + + return corr_info, type diff --git a/src/guppy/analysis/standard_io.py b/src/guppy/analysis/standard_io.py new file mode 100644 index 0000000..d6dd9af --- /dev/null +++ b/src/guppy/analysis/standard_io.py @@ -0,0 +1,333 @@ +import logging +import os + +import numpy as np +import pandas as pd + +from .io_utils import ( + decide_naming_convention, + fetchCoords, + get_control_and_signal_channel_names, + read_hdf5, + write_hdf5, +) + +logger = logging.getLogger(__name__) + + +def read_control_and_signal(filepath, storesList): + channels_arr = get_control_and_signal_channel_names(storesList) + storenames = storesList[0, :] + names_for_storenames = storesList[1, :] + + name_to_data = {} + name_to_timestamps = {} + name_to_sampling_rate = {} + name_to_npoints = {} + + for i in range(channels_arr.shape[1]): + control_name = channels_arr[0, i] + signal_name = channels_arr[1, i] + idx_c = np.where(names_for_storenames == control_name)[0] + idx_s = np.where(names_for_storenames == signal_name)[0] + control_storename = storenames[idx_c[0]] + signal_storename = storenames[idx_s[0]] + + control_data = read_hdf5(control_storename, filepath, "data") + signal_data = read_hdf5(signal_storename, filepath, "data") + control_timestamps = read_hdf5(control_storename, filepath, "timestamps") + signal_timestamps = read_hdf5(signal_storename, filepath, "timestamps") + control_sampling_rate = read_hdf5(control_storename, filepath, "sampling_rate") + signal_sampling_rate = read_hdf5(signal_storename, filepath, "sampling_rate") + try: # TODO: define npoints for csv datasets + control_npoints 
= read_hdf5(control_storename, filepath, "npoints") + signal_npoints = read_hdf5(signal_storename, filepath, "npoints") + except KeyError: # npoints is not defined for csv datasets + control_npoints = None + signal_npoints = None + + name_to_data[control_name] = control_data + name_to_data[signal_name] = signal_data + name_to_timestamps[control_name] = control_timestamps + name_to_timestamps[signal_name] = signal_timestamps + name_to_sampling_rate[control_name] = control_sampling_rate + name_to_sampling_rate[signal_name] = signal_sampling_rate + name_to_npoints[control_name] = control_npoints + name_to_npoints[signal_name] = signal_npoints + + return name_to_data, name_to_timestamps, name_to_sampling_rate, name_to_npoints + + +def read_ttl(filepath, storesList): + channels_arr = get_control_and_signal_channel_names(storesList) + storenames = storesList[0, :] + names_for_storenames = storesList[1, :] + + name_to_timestamps = {} + for storename, name in zip(storenames, names_for_storenames): + if name in channels_arr: + continue + timestamps = read_hdf5(storename, filepath, "timestamps") + name_to_timestamps[name] = timestamps + + return name_to_timestamps + + +def write_corrected_timestamps( + filepath, corrected_name_to_timestamps, name_to_timestamps, name_to_sampling_rate, name_to_correctionIndex +): + for name, correctionIndex in name_to_correctionIndex.items(): + timestamps = name_to_timestamps[name] + corrected_timestamps = corrected_name_to_timestamps[name] + sampling_rate = name_to_sampling_rate[name] + if sampling_rate.shape == (): # numpy scalar + sampling_rate = np.asarray([sampling_rate]) + name_1 = name.split("_")[-1] + write_hdf5(np.asarray([timestamps[0]]), "timeCorrection_" + name_1, filepath, "timeRecStart") + write_hdf5(corrected_timestamps, "timeCorrection_" + name_1, filepath, "timestampNew") + write_hdf5(correctionIndex, "timeCorrection_" + name_1, filepath, "correctionIndex") + write_hdf5(sampling_rate, "timeCorrection_" + name_1, filepath, "sampling_rate") + + +def write_corrected_data(filepath, name_to_corrected_data): + for name, data in name_to_corrected_data.items(): + write_hdf5(data, name, filepath, "data") + + +def write_corrected_ttl_timestamps( + filepath, + compound_name_to_corrected_ttl_timestamps, +): + logger.debug("Applying correction of timestamps to the data and event timestamps") + for compound_name, corrected_ttl_timestamps in compound_name_to_corrected_ttl_timestamps.items(): + write_hdf5(corrected_ttl_timestamps, compound_name, filepath, "ts") + logger.info("Timestamps corrections applied to the data and event timestamps.") + + +def read_corrected_data(control_path, signal_path, filepath, name): + control = read_hdf5("", control_path, "data").reshape(-1) + signal = read_hdf5("", signal_path, "data").reshape(-1) + tsNew = read_hdf5("timeCorrection_" + name, filepath, "timestampNew") + + return control, signal, tsNew + + +def write_zscore(filepath, name, z_score, dff, control_fit, temp_control_arr): + write_hdf5(z_score, "z_score_" + name, filepath, "data") + write_hdf5(dff, "dff_" + name, filepath, "data") + write_hdf5(control_fit, "cntrl_sig_fit_" + name, filepath, "data") + if temp_control_arr is not None: + write_hdf5(temp_control_arr, "control_" + name, filepath, "data") + + +def read_corrected_timestamps_pairwise(filepath): + pair_name_to_tsNew = {} + pair_name_to_sampling_rate = {} + path = decide_naming_convention(filepath) + for j in range(path.shape[1]): + name_1 = ((os.path.basename(path[0, j])).split(".")[0]).split("_") + name_2 = 
((os.path.basename(path[1, j])).split(".")[0]).split("_") + if name_1[-1] != name_2[-1]: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + name = name_1[-1] + + tsNew = read_hdf5("timeCorrection_" + name, filepath, "timestampNew") + sampling_rate = read_hdf5("timeCorrection_" + name, filepath, "sampling_rate")[0] + pair_name_to_tsNew[name] = tsNew + pair_name_to_sampling_rate[name] = sampling_rate + return pair_name_to_tsNew, pair_name_to_sampling_rate + + +def read_coords_pairwise(filepath, pair_name_to_tsNew): + pair_name_to_coords = {} + path = decide_naming_convention(filepath) + for j in range(path.shape[1]): + name_1 = ((os.path.basename(path[0, j])).split(".")[0]).split("_") + name_2 = ((os.path.basename(path[1, j])).split(".")[0]).split("_") + if name_1[-1] != name_2[-1]: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + pair_name = name_1[-1] + + tsNew = pair_name_to_tsNew[pair_name] + coords = fetchCoords(filepath, pair_name, tsNew) + pair_name_to_coords[pair_name] = coords + return pair_name_to_coords + + +def read_corrected_data_dict(filepath, storesList): # TODO: coordinate with read_corrected_data + name_to_corrected_data = {} + storenames = storesList[0, :] + names_for_storenames = storesList[1, :] + control_and_signal_names = get_control_and_signal_channel_names(storesList) + + for storename, name in zip(storenames, names_for_storenames): + if name not in control_and_signal_names: + continue + data = read_hdf5(name, filepath, "data").reshape(-1) + name_to_corrected_data[name] = data + + return name_to_corrected_data + + +def read_corrected_ttl_timestamps(filepath, storesList): + compound_name_to_ttl_timestamps = {} + storenames = storesList[0, :] + names_for_storenames = storesList[1, :] + arr = get_control_and_signal_channel_names(storesList) + + for storename, name in zip(storenames, names_for_storenames): + if name in arr: + continue + ttl_name = name + for i in range(arr.shape[1]): + name_1 = arr[0, i].split("_")[-1] + name_2 = arr[1, i].split("_")[-1] + if name_1 != name_2: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + compound_name = ttl_name + "_" + name_1 + ts = read_hdf5(compound_name, filepath, "ts") + compound_name_to_ttl_timestamps[compound_name] = ts + + return compound_name_to_ttl_timestamps + + +def write_artifact_corrected_timestamps(filepath, pair_name_to_corrected_timestamps): + for pair_name, timestamps in pair_name_to_corrected_timestamps.items(): + write_hdf5(timestamps, "timeCorrection_" + pair_name, filepath, "timestampNew") + + +def write_artifact_removal( + filepath, + name_to_corrected_data, + pair_name_to_corrected_timestamps, + compound_name_to_corrected_ttl_timestamps=None, +): + write_corrected_data(filepath, name_to_corrected_data) + write_corrected_ttl_timestamps(filepath, compound_name_to_corrected_ttl_timestamps) + if pair_name_to_corrected_timestamps is not None: + write_artifact_corrected_timestamps(filepath, pair_name_to_corrected_timestamps) + + +def read_timestamps_for_combining_data(filepaths_to_combine): + path = decide_naming_convention(filepaths_to_combine[0]) + pair_name_to_filepath_to_timestamps: dict[str, dict[str, np.ndarray]] = {} + for j in range(path.shape[1]): + name_1 
= ((os.path.basename(path[0, j])).split(".")[0]).split("_")[-1] + name_2 = ((os.path.basename(path[1, j])).split(".")[0]).split("_")[-1] + if name_1 != name_2: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + pair_name = name_1 + pair_name_to_filepath_to_timestamps[pair_name] = {} + for filepath in filepaths_to_combine: + tsNew = read_hdf5("timeCorrection_" + pair_name, filepath, "timestampNew") + pair_name_to_filepath_to_timestamps[pair_name][filepath] = tsNew + + return pair_name_to_filepath_to_timestamps + + +def read_data_for_combining_data(filepaths_to_combine, storesList): + names_for_storenames = storesList[1, :] + path = decide_naming_convention(filepaths_to_combine[0]) + display_name_to_filepath_to_data: dict[str, dict[str, np.ndarray]] = {} + for j in range(path.shape[1]): + name_1 = ((os.path.basename(path[0, j])).split(".")[0]).split("_")[-1] + name_2 = ((os.path.basename(path[1, j])).split(".")[0]).split("_")[-1] + if name_1 != name_2: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + pair_name = name_1 + for i in range(len(names_for_storenames)): + if not ( + "control_" + pair_name.lower() in names_for_storenames[i].lower() + or "signal_" + pair_name.lower() in names_for_storenames[i].lower() + ): + continue + display_name = names_for_storenames[i] + display_name_to_filepath_to_data[display_name] = {} + for filepath in filepaths_to_combine: + data = read_hdf5(display_name, filepath, "data").reshape(-1) + display_name_to_filepath_to_data[display_name][filepath] = data + + return display_name_to_filepath_to_data + + +def read_ttl_timestamps_for_combining_data(filepaths_to_combine, storesList): + names_for_storenames = storesList[1, :] + path = decide_naming_convention(filepaths_to_combine[0]) + compound_name_to_filepath_to_ttl_timestamps: dict[str, dict[str, np.ndarray]] = {} + for j in range(path.shape[1]): + name_1 = ((os.path.basename(path[0, j])).split(".")[0]).split("_")[-1] + name_2 = ((os.path.basename(path[1, j])).split(".")[0]).split("_")[-1] + if name_1 != name_2: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + pair_name = name_1 + for i in range(len(names_for_storenames)): + if ( + "control_" + pair_name.lower() in names_for_storenames[i].lower() + or "signal_" + pair_name.lower() in names_for_storenames[i].lower() + ): + continue + compound_name = names_for_storenames[i] + "_" + pair_name + compound_name_to_filepath_to_ttl_timestamps[compound_name] = {} + for filepath in filepaths_to_combine: + if os.path.exists(os.path.join(filepath, names_for_storenames[i] + "_" + pair_name + ".hdf5")): + ts = read_hdf5(names_for_storenames[i] + "_" + pair_name, filepath, "ts").reshape(-1) + else: + ts = np.array([]) + compound_name_to_filepath_to_ttl_timestamps[compound_name][filepath] = ts + + return compound_name_to_filepath_to_ttl_timestamps + + +def write_combined_data(output_filepath, pair_name_to_tsNew, display_name_to_data, compound_name_to_ttl_timestamps): + for pair_name, tsNew in pair_name_to_tsNew.items(): + write_hdf5(tsNew, "timeCorrection_" + pair_name, output_filepath, "timestampNew") + for display_name, data in display_name_to_data.items(): + write_hdf5(data, display_name, output_filepath, "data") + for 
compound_name, ts in compound_name_to_ttl_timestamps.items(): + write_hdf5(ts, compound_name, output_filepath, "ts") + + +def write_peak_and_area_to_hdf5(filepath, arr, name, index=[]): + + op = os.path.join(filepath, "peak_AUC_" + name + ".h5") + dirname = os.path.dirname(filepath) + + df = pd.DataFrame(arr, index=index) + + df.to_hdf(op, key="df", mode="w") + + +def write_peak_and_area_to_csv(filepath, arr, name, index=[]): + op = os.path.join(filepath, "peak_AUC_" + name + ".csv") + df = pd.DataFrame(arr, index=index) + + df.to_csv(op) + + +def write_freq_and_amp_to_hdf5(filepath, arr, name, index=[], columns=[]): + + op = os.path.join(filepath, "freqAndAmp_" + name + ".h5") + dirname = os.path.dirname(filepath) + + df = pd.DataFrame(arr, index=index, columns=columns) + + df.to_hdf(op, key="df", mode="w") + + +def write_freq_and_amp_to_csv(filepath, arr, name, index=[], columns=[]): + op = os.path.join(filepath, name) + df = pd.DataFrame(arr, index=index, columns=columns) + df.to_csv(op) + + +def read_freq_and_amp_from_hdf5(filepath, name): + op = os.path.join(filepath, "freqAndAmp_" + name + ".h5") + df = pd.read_hdf(op, key="df", mode="r") + + return df diff --git a/src/guppy/analysis/timestamp_correction.py b/src/guppy/analysis/timestamp_correction.py new file mode 100644 index 0000000..0806fb8 --- /dev/null +++ b/src/guppy/analysis/timestamp_correction.py @@ -0,0 +1,200 @@ +import logging + +import numpy as np + +from .io_utils import get_control_and_signal_channel_names + +logger = logging.getLogger(__name__) + + +def correct_timestamps( + timeForLightsTurnOn, + storesList, + name_to_timestamps, + name_to_data, + name_to_sampling_rate, + name_to_npoints, + name_to_timestamps_ttl, + mode, +): + name_to_corrected_timestamps, name_to_correctionIndex, name_to_corrected_data = timestampCorrection( + timeForLightsTurnOn, + storesList, + name_to_timestamps, + name_to_data, + name_to_sampling_rate, + name_to_npoints, + mode=mode, + ) + compound_name_to_corrected_ttl_timestamps = decide_naming_and_applyCorrection_ttl( + timeForLightsTurnOn, + storesList, + name_to_timestamps_ttl, + name_to_timestamps, + name_to_data, + mode=mode, + ) + + return ( + name_to_corrected_timestamps, + name_to_correctionIndex, + name_to_corrected_data, + compound_name_to_corrected_ttl_timestamps, + ) + + +# function to correct timestamps after eliminating first few seconds of the data (for csv or TDT data depending on mode) +def timestampCorrection( + timeForLightsTurnOn, + storesList, + name_to_timestamps, + name_to_data, + name_to_sampling_rate, + name_to_npoints, + mode, +): + logger.debug( + f"Correcting timestamps by getting rid of the first {timeForLightsTurnOn} seconds and convert timestamps to seconds" + ) + if mode not in ["tdt", "csv"]: + logger.error("Mode should be either 'tdt' or 'csv'") + raise ValueError("Mode should be either 'tdt' or 'csv'") + name_to_corrected_timestamps = {} + name_to_correctionIndex = {} + name_to_corrected_data = {} + storenames = storesList[0, :] + names_for_storenames = storesList[1, :] + channels_arr = get_control_and_signal_channel_names(storesList) + + indices = check_cntrl_sig_length(channels_arr, name_to_data) + + for i in range(channels_arr.shape[1]): + control_name = channels_arr[0, i] + signal_name = channels_arr[1, i] + name_1 = channels_arr[0, i].split("_")[-1] + name_2 = channels_arr[1, i].split("_")[-1] + if name_1 != name_2: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files 
or Error in storesList file") + + # dirname = os.path.dirname(path[i]) + idx = np.where(names_for_storenames == indices[i])[0] + + if idx.shape[0] == 0: + logger.error(f"{channels_arr[0,i]} does not exist in the stores list file.") + raise Exception("{} does not exist in the stores list file.".format(channels_arr[0, i])) + + name = names_for_storenames[idx][0] + timestamp = name_to_timestamps[name] + sampling_rate = name_to_sampling_rate[name] + npoints = name_to_npoints[name] + + if mode == "tdt": + timeRecStart = timestamp[0] + timestamps = np.subtract(timestamp, timeRecStart) + adder = np.arange(npoints) / sampling_rate + lengthAdder = adder.shape[0] + timestampNew = np.zeros((len(timestamps), lengthAdder)) + for i in range(lengthAdder): + timestampNew[:, i] = np.add(timestamps, adder[i]) + timestampNew = (timestampNew.T).reshape(-1, order="F") + correctionIndex = np.where(timestampNew >= timeForLightsTurnOn)[0] + timestampNew = timestampNew[correctionIndex] + elif mode == "csv": + correctionIndex = np.where(timestamp >= timeForLightsTurnOn)[0] + timestampNew = timestamp[correctionIndex] + + for displayName in [control_name, signal_name]: + name_to_corrected_timestamps[displayName] = timestampNew + name_to_correctionIndex[displayName] = correctionIndex + data = name_to_data[displayName] + if (data == 0).all() == True: + name_to_corrected_data[displayName] = data + else: + name_to_corrected_data[displayName] = data[correctionIndex] + + logger.info("Timestamps corrected and converted to seconds.") + return name_to_corrected_timestamps, name_to_correctionIndex, name_to_corrected_data + + +def decide_naming_and_applyCorrection_ttl( + timeForLightsTurnOn, + storesList, + name_to_timestamps_ttl, + name_to_timestamps, + name_to_data, + mode, +): + logger.debug("Applying correction of timestamps to the data and event timestamps") + storenames = storesList[0, :] + names_for_storenames = storesList[1, :] + arr = get_control_and_signal_channel_names(storesList) + indices = check_cntrl_sig_length(arr, name_to_data) + + compound_name_to_corrected_ttl_timestamps = {} + for ttl_name, ttl_timestamps in name_to_timestamps_ttl.items(): + for i in range(arr.shape[1]): + name_1 = arr[0, i].split("_")[-1] + name_2 = arr[1, i].split("_")[-1] + if name_1 != name_2: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + + idx = np.where(names_for_storenames == indices[i])[0] + if idx.shape[0] == 0: + logger.error(f"{arr[0,i]} does not exist in the stores list file.") + raise Exception("{} does not exist in the stores list file.".format(arr[0, i])) + + name = names_for_storenames[idx][0] + timestamps = name_to_timestamps[name] + timeRecStart = timestamps[0] + corrected_ttl_timestamps = applyCorrection_ttl( + timeForLightsTurnOn, + timeRecStart, + ttl_timestamps, + mode, + ) + compound_name = ttl_name + "_" + name_1 + compound_name_to_corrected_ttl_timestamps[compound_name] = corrected_ttl_timestamps + + logger.info("Timestamps corrections applied to the data and event timestamps.") + return compound_name_to_corrected_ttl_timestamps + + +def applyCorrection_ttl( + timeForLightsTurnOn, + timeRecStart, + ttl_timestamps, + mode, +): + corrected_ttl_timestamps = ttl_timestamps + if mode == "tdt": + res = (corrected_ttl_timestamps >= timeRecStart).all() + if res == True: + corrected_ttl_timestamps = np.subtract(corrected_ttl_timestamps, timeRecStart) + corrected_ttl_timestamps = 
np.subtract(corrected_ttl_timestamps, timeForLightsTurnOn) + else: + corrected_ttl_timestamps = np.subtract(corrected_ttl_timestamps, timeForLightsTurnOn) + elif mode == "csv": + corrected_ttl_timestamps = np.subtract(corrected_ttl_timestamps, timeForLightsTurnOn) + return corrected_ttl_timestamps + + +# function to check control and signal channel has same length +# if not, take a smaller length and do pre-processing +def check_cntrl_sig_length(channels_arr, name_to_data): + + indices = [] + for i in range(channels_arr.shape[1]): + control_name = channels_arr[0, i] + signal_name = channels_arr[1, i] + control = name_to_data[control_name] + signal = name_to_data[signal_name] + if control.shape[0] < signal.shape[0]: + indices.append(control_name) + elif control.shape[0] > signal.shape[0]: + indices.append(signal_name) + else: + indices.append(signal_name) + + return indices diff --git a/src/guppy/analysis/transients.py b/src/guppy/analysis/transients.py new file mode 100644 index 0000000..5fd8645 --- /dev/null +++ b/src/guppy/analysis/transients.py @@ -0,0 +1,112 @@ +import logging +import math +import multiprocessing as mp +from itertools import repeat + +import numpy as np +from scipy.signal import argrelextrema + +logger = logging.getLogger(__name__) + + +def analyze_transients(ts, window, numProcesses, highAmpFilt, transientsThresh, sampling_rate, z_score): + not_nan_indices = ~np.isnan(z_score) + z_score = z_score[not_nan_indices] + z_score_chunks, z_score_chunks_index = createChunks(z_score, sampling_rate, window) + + with mp.Pool(numProcesses) as p: + result = p.starmap( + processChunks, zip(z_score_chunks, z_score_chunks_index, repeat(highAmpFilt), repeat(transientsThresh)) + ) + + result = np.asarray(result, dtype=object) + ts = ts[not_nan_indices] + freq, peaksAmp, peaksInd = calculate_freq_amp(result, z_score, z_score_chunks_index, ts) + peaks_occurrences = np.array([ts[peaksInd], peaksAmp]).T + arr = np.array([[freq, np.mean(peaksAmp)]]) + return z_score, ts, peaksInd, peaks_occurrences, arr + + +def processChunks(arrValues, arrIndexes, highAmpFilt, transientsThresh): + + arrValues = arrValues[~np.isnan(arrValues)] + median = np.median(arrValues) + + mad = np.median(np.abs(arrValues - median)) + + firstThreshold = median + (highAmpFilt * mad) + + greaterThanMad = np.where(arrValues > firstThreshold)[0] + + arr = np.arange(arrValues.shape[0]) + lowerThanMad = np.isin(arr, greaterThanMad, invert=True) + filteredOut = arrValues[np.where(lowerThanMad == True)[0]] + + filteredOutMedian = np.median(filteredOut) + filteredOutMad = np.median(np.abs(filteredOut - np.median(filteredOut))) + secondThreshold = filteredOutMedian + (transientsThresh * filteredOutMad) + + greaterThanThreshIndex = np.where(arrValues > secondThreshold)[0] + greaterThanThreshValues = arrValues[greaterThanThreshIndex] + temp = np.zeros(arrValues.shape[0]) + temp[greaterThanThreshIndex] = greaterThanThreshValues + peaks = argrelextrema(temp, np.greater)[0] + + firstThresholdY = np.full(arrValues.shape[0], firstThreshold) + secondThresholdY = np.full(arrValues.shape[0], secondThreshold) + + newPeaks = np.full(arrValues.shape[0], np.nan) + newPeaks[peaks] = peaks + arrIndexes[0] + + # madY = np.full(arrValues.shape[0], mad) + medianY = np.full(arrValues.shape[0], median) + filteredOutMedianY = np.full(arrValues.shape[0], filteredOutMedian) + + return peaks, mad, filteredOutMad, medianY, filteredOutMedianY, firstThresholdY, secondThresholdY + + +def createChunks(z_score, sampling_rate, window): + + 
logger.debug("Creating chunks for multiprocessing...") + windowPoints = math.ceil(sampling_rate * window) + remainderPoints = math.ceil((sampling_rate * window) - (z_score.shape[0] % windowPoints)) + + if remainderPoints == windowPoints: + padded_z_score = z_score + z_score_index = np.arange(padded_z_score.shape[0]) + else: + padding = np.full(remainderPoints, np.nan) + padded_z_score = np.concatenate((z_score, padding)) + z_score_index = np.arange(padded_z_score.shape[0]) + + reshape = padded_z_score.shape[0] / windowPoints + + if reshape.is_integer() == True: + z_score_chunks = padded_z_score.reshape(int(reshape), -1) + z_score_chunks_index = z_score_index.reshape(int(reshape), -1) + else: + logger.error("Reshaping values should be integer.") + raise Exception("Reshaping values should be integer.") + logger.info("Chunks are created for multiprocessing.") + return z_score_chunks, z_score_chunks_index + + +def calculate_freq_amp(arr, z_score, z_score_chunks_index, timestamps): + peaks = arr[:, 0] + filteredOutMedian = arr[:, 4] + count = 0 + peaksAmp = np.array([]) + peaksInd = np.array([]) + for i in range(z_score_chunks_index.shape[0]): + count += peaks[i].shape[0] + peaksIndexes = peaks[i] + z_score_chunks_index[i][0] + peaksInd = np.concatenate((peaksInd, peaksIndexes)) + amps = z_score[peaksIndexes] - filteredOutMedian[i][0] + peaksAmp = np.concatenate((peaksAmp, amps)) + + peaksInd = peaksInd.ravel() + peaksInd = peaksInd.astype(int) + # logger.info(timestamps) + freq = peaksAmp.shape[0] / ((timestamps[-1] - timestamps[0]) / 60) + + return freq, peaksAmp, peaksInd diff --git a/src/guppy/analysis/transients_average.py b/src/guppy/analysis/transients_average.py new file mode 100644 index 0000000..9e6d372 --- /dev/null +++ b/src/guppy/analysis/transients_average.py @@ -0,0 +1,81 @@ +import glob +import logging +import os + +import numpy as np + +from .io_utils import ( + makeAverageDir, +) +from .standard_io import ( + read_freq_and_amp_from_hdf5, + write_freq_and_amp_to_csv, + write_freq_and_amp_to_hdf5, +) + +logger = logging.getLogger(__name__) + + +def averageForGroup(folderNames, inputParameters): + + logger.debug("Combining results for frequency and amplitude of transients in z-score data...") + path = [] + abspath = inputParameters["abspath"] + selectForTransientsComputation = inputParameters["selectForTransientsComputation"] + path_temp_len = [] + + for i in range(len(folderNames)): + if selectForTransientsComputation == "z_score": + path_temp = glob.glob(os.path.join(folderNames[i], "z_score_*")) + elif selectForTransientsComputation == "dff": + path_temp = glob.glob(os.path.join(folderNames[i], "dff_*")) + else: + path_temp = glob.glob(os.path.join(folderNames[i], "z_score_*")) + glob.glob( + os.path.join(folderNames[i], "dff_*") + ) + + path_temp_len.append(len(path_temp)) + + for j in range(len(path_temp)): + basename = (os.path.basename(path_temp[j])).split(".")[0] + # name = name[0] + temp = [folderNames[i], basename] + path.append(temp) + + path_temp_len = np.asarray(path_temp_len) + max_len = np.argmax(path_temp_len) + + naming = [] + for i in range(len(path)): + naming.append(path[i][1]) + naming = np.unique(np.asarray(naming)) + + new_path = [[] for _ in range(path_temp_len[max_len])] + for i in range(len(path)): + idx = np.where(naming == path[i][1])[0][0] + new_path[idx].append(path[i]) + + op = makeAverageDir(abspath) + + for i in range(len(new_path)): + arr = [] # np.zeros((len(new_path[i]), 2)) + fileName = [] + temp_path = new_path[i] + for j in 
range(len(temp_path)): + if not os.path.exists(os.path.join(temp_path[j][0], "freqAndAmp_" + temp_path[j][1] + ".h5")): + continue + else: + df = read_freq_and_amp_from_hdf5(temp_path[j][0], temp_path[j][1]) + arr.append(np.array([df["freq (events/min)"].iloc[0], df["amplitude"].iloc[0]])) + fileName.append(os.path.basename(temp_path[j][0])) + + arr = np.asarray(arr) + write_freq_and_amp_to_hdf5(op, arr, temp_path[j][1], index=fileName, columns=["freq (events/min)", "amplitude"]) + write_freq_and_amp_to_csv( + op, + arr, + "freqAndAmp_" + temp_path[j][1] + ".csv", + index=fileName, + columns=["freq (events/min)", "amplitude"], + ) + logger.info("Results for frequency and amplitude of transients in z-score data are combined.") diff --git a/src/guppy/analysis/z_score.py b/src/guppy/analysis/z_score.py new file mode 100644 index 0000000..34b29ee --- /dev/null +++ b/src/guppy/analysis/z_score.py @@ -0,0 +1,148 @@ +import logging + +import numpy as np +from scipy import signal as ss + +from .control_channel import helper_create_control_channel + +logger = logging.getLogger(__name__) + + +# high-level function to compute z-score and deltaF/F +def compute_z_score( + control, + signal, + tsNew, + coords, + artifactsRemovalMethod, + filter_window, + isosbestic_control, + zscore_method, + baseline_start, + baseline_end, +): + if (control == 0).all() == True: + control = np.zeros(tsNew.shape[0]) + + z_score_arr = np.array([]) + norm_data_arr = np.full(tsNew.shape[0], np.nan) + control_fit_arr = np.full(tsNew.shape[0], np.nan) + temp_control_arr = np.full(tsNew.shape[0], np.nan) + + # for artifacts removal, each chunk which was selected by user is being processed individually and then + # z-score is calculated + for i in range(coords.shape[0]): + tsNew_index = np.where((tsNew > coords[i, 0]) & (tsNew < coords[i, 1]))[0] + if isosbestic_control == False: + control_arr = helper_create_control_channel(signal[tsNew_index], tsNew[tsNew_index], window=101) + signal_arr = signal[tsNew_index] + norm_data, control_fit = execute_controlFit_dff(control_arr, signal_arr, isosbestic_control, filter_window) + temp_control_arr[tsNew_index] = control_arr + if i < coords.shape[0] - 1: + blank_index = np.where((tsNew > coords[i, 1]) & (tsNew < coords[i + 1, 0]))[0] + temp_control_arr[blank_index] = np.full(blank_index.shape[0], np.nan) + else: + control_arr = control[tsNew_index] + signal_arr = signal[tsNew_index] + norm_data, control_fit = execute_controlFit_dff(control_arr, signal_arr, isosbestic_control, filter_window) + norm_data_arr[tsNew_index] = norm_data + control_fit_arr[tsNew_index] = control_fit + + if artifactsRemovalMethod == "concatenate": + norm_data_arr = norm_data_arr[~np.isnan(norm_data_arr)] + control_fit_arr = control_fit_arr[~np.isnan(control_fit_arr)] + z_score = z_score_computation(norm_data_arr, tsNew, zscore_method, baseline_start, baseline_end) + z_score_arr = np.concatenate((z_score_arr, z_score)) + + # handle the case if there are chunks being cut in the front and the end + if isosbestic_control == False: + coords = coords.flatten() + # front chunk + idx = np.where((tsNew >= tsNew[0]) & (tsNew < coords[0]))[0] + temp_control_arr[idx] = np.full(idx.shape[0], np.nan) + # end chunk + idx = np.where((tsNew > coords[-1]) & (tsNew <= tsNew[-1]))[0] + temp_control_arr[idx] = np.full(idx.shape[0], np.nan) + else: + temp_control_arr = None + + return z_score_arr, norm_data_arr, control_fit_arr, temp_control_arr + + +# function to filter control and signal channel, also execute above two function : 
controlFit and deltaFF +# function will also take care if there is only signal channel and no control channel +# if there is only signal channel, z-score will be computed using just signal channel +def execute_controlFit_dff(control, signal, isosbestic_control, filter_window): + + if isosbestic_control == False: + signal_smooth = filterSignal(filter_window, signal) # ss.filtfilt(b, a, signal) + control_fit = controlFit(control, signal_smooth) + norm_data = deltaFF(signal_smooth, control_fit) + else: + control_smooth = filterSignal(filter_window, control) # ss.filtfilt(b, a, control) + signal_smooth = filterSignal(filter_window, signal) # ss.filtfilt(b, a, signal) + control_fit = controlFit(control_smooth, signal_smooth) + norm_data = deltaFF(signal_smooth, control_fit) + + return norm_data, control_fit + + +# function to compute deltaF/F using fitted control channel and filtered signal channel +def deltaFF(signal, control): + + res = np.subtract(signal, control) + normData = np.divide(res, control) + # deltaFF = normData + normData = normData * 100 + + return normData + + +# function to fit control channel to signal channel +def controlFit(control, signal): + + p = np.polyfit(control, signal, 1) + arr = (p[0] * control) + p[1] + return arr + + +def filterSignal(filter_window, signal): + if filter_window == 0: + return signal + elif filter_window > 1: + b = np.divide(np.ones((filter_window,)), filter_window) + a = 1 + filtered_signal = ss.filtfilt(b, a, signal) + return filtered_signal + else: + raise Exception("Moving average filter window value is not correct.") + + +# function to compute z-score based on z-score computation method +def z_score_computation(dff, timestamps, zscore_method, baseline_start, baseline_end): + if zscore_method == "standard z-score": + numerator = np.subtract(dff, np.nanmean(dff)) + zscore = np.divide(numerator, np.nanstd(dff)) + elif zscore_method == "baseline z-score": + idx = np.where((timestamps > baseline_start) & (timestamps < baseline_end))[0] + if idx.shape[0] == 0: + logger.error( + "Baseline Window Parameters for baseline z-score computation zscore_method \ + are not correct." + ) + raise Exception( + "Baseline Window Parameters for baseline z-score computation zscore_method \ + are not correct." 
+ ) + else: + baseline_mean = np.nanmean(dff[idx]) + baseline_std = np.nanstd(dff[idx]) + numerator = np.subtract(dff, baseline_mean) + zscore = np.divide(numerator, baseline_std) + else: + median = np.median(dff) + mad = np.median(np.abs(dff - median)) + numerator = 0.6745 * (dff - median) + zscore = np.divide(numerator, mad) + + return zscore diff --git a/src/guppy/combineDataFn.py b/src/guppy/combineDataFn.py deleted file mode 100755 index 51e2bd0..0000000 --- a/src/guppy/combineDataFn.py +++ /dev/null @@ -1,341 +0,0 @@ -import fnmatch -import logging -import os -import re - -logger = logging.getLogger(__name__) - - -def find_files(path, glob_path, ignore_case=False): - rule = ( - re.compile(fnmatch.translate(glob_path), re.IGNORECASE) - if ignore_case - else re.compile(fnmatch.translate(glob_path)) - ) - no_bytes_path = os.listdir(os.path.expanduser(path)) - str_path = [] - - # converting byte object to string - for x in no_bytes_path: - try: - str_path.append(x.decode("utf-8")) - except: - str_path.append(x) - - return [os.path.join(path, n) for n in str_path if rule.match(n)] - - -def read_hdf5(event, filepath, key): - if event: - op = os.path.join(filepath, event + ".hdf5") - else: - op = filepath - - if os.path.exists(op): - with h5py.File(op, "r") as f: - arr = np.asarray(f[key]) - else: - raise Exception("{}.hdf5 file does not exist".format(event)) - - return arr - - -def write_hdf5(data, event, filepath, key): - op = os.path.join(filepath, event + ".hdf5") - - if not os.path.exists(op): - with h5py.File(op, "w") as f: - if type(data) is np.ndarray: - f.create_dataset(key, data=data, maxshape=(None,), chunks=True) - else: - f.create_dataset(key, data=data) - else: - with h5py.File(op, "r+") as f: - if key in list(f.keys()): - if type(data) is np.ndarray: - f[key].resize(data.shape) - arr = f[key] - arr[:] = data - else: - arr = f[key] - arr = data - else: - f.create_dataset(key, data=data, maxshape=(None,), chunks=True) - - -def decide_naming_convention(filepath): - path_1 = find_files(filepath, "control*", ignore_case=True) # glob.glob(os.path.join(filepath, 'control*')) - - path_2 = find_files(filepath, "signal*", ignore_case=True) # glob.glob(os.path.join(filepath, 'signal*')) - - path = sorted(path_1 + path_2, key=str.casefold) - - if len(path) % 2 != 0: - raise Exception("There are not equal number of Control and Signal data") - - path = np.asarray(path).reshape(2, -1) - - return path - - -def eliminateData(filepath, timeForLightsTurnOn, event, sampling_rate, naming): - - arr = np.array([]) - ts_arr = np.array([]) - for i in range(len(filepath)): - ts = read_hdf5("timeCorrection_" + naming, filepath[i], "timestampNew") - data = read_hdf5(event, filepath[i], "data").reshape(-1) - - # index = np.where((ts>coords[i,0]) & (tscoords[i,0]) & (ts 1: - mean = np.nanmean(psth[:, single_trials_index], axis=1).reshape(-1, 1) - err = np.nanstd(psth[:, single_trials_index], axis=1) / math.sqrt(psth[:, single_trials_index].shape[1]) - err = err.reshape(-1, 1) - psth = np.hstack((psth, mean)) - psth = np.hstack((psth, err)) - # timestamps = np.asarray(read_Df(filepath, 'ts_psth', '')) - # psth = np.hstack((psth, timestamps)) - try: - ts = read_hdf5(event, filepath, "ts") - ts = np.append(ts, ["mean", "err"]) - except: - ts = None - - if len(columns) == 0: - df = pd.DataFrame(psth, index=None, columns=ts, dtype="float32") - else: - columns = np.asarray(columns) - columns = np.append(columns, ["mean", "err"]) - df = pd.DataFrame(psth, index=None, columns=columns, dtype="float32") - - 
df.to_hdf(op, key="df", mode="w") - - -def getCorrCombinations(filepath, inputParameters): - selectForComputePsth = inputParameters["selectForComputePsth"] - if selectForComputePsth == "z_score": - path = glob.glob(os.path.join(filepath, "z_score_*")) - elif selectForComputePsth == "dff": - path = glob.glob(os.path.join(filepath, "dff_*")) - else: - path = glob.glob(os.path.join(filepath, "z_score_*")) + glob.glob(os.path.join(filepath, "dff_*")) - - names = list() - type = list() - for i in range(len(path)): - basename = (os.path.basename(path[i])).split(".")[0] - names.append(basename.split("_")[-1]) - type.append((os.path.basename(path[i])).split(".")[0].split("_" + names[-1], 1)[0]) - - names = list(np.unique(np.array(names))) - type = list(np.unique(np.array(type))) - - corr_info = list() - if len(names) <= 1: - logger.info("Cross-correlation cannot be computed because only one signal is present.") - return corr_info, type - elif len(names) == 2: - corr_info = names - else: - corr_info = names - corr_info.append(names[0]) - - return corr_info, type - - -def helperCrossCorrelation(arr_A, arr_B, sample_rate): - cross_corr = list() - for a, b in zip(arr_A, arr_B): - if np.isnan(a).any() or np.isnan(b).any(): - corr = signal.correlate(a, b, method="direct") - else: - corr = signal.correlate(a, b) - corr_norm = corr / np.max(np.abs(corr)) - cross_corr.append(corr_norm) - lag = signal.correlation_lags(len(a), len(b)) - lag_msec = np.array(lag / sample_rate, dtype="float32") - - cross_corr_arr = np.array(cross_corr, dtype="float32") - lag_msec = lag_msec.reshape(1, -1) - cross_corr_arr = np.concatenate((cross_corr_arr, lag_msec), axis=0) - return cross_corr_arr - - -def computeCrossCorrelation(filepath, event, inputParameters): - isCompute = inputParameters["computeCorr"] - removeArtifacts = inputParameters["removeArtifacts"] - artifactsRemovalMethod = inputParameters["artifactsRemovalMethod"] - if isCompute == True: - if removeArtifacts == True and artifactsRemovalMethod == "concatenate": - raise Exception( - "For cross-correlation, when removeArtifacts is True, artifacts removal method\ - should be replace with NaNs and not concatenate" - ) - corr_info, type = getCorrCombinations(filepath, inputParameters) - if "control" in event.lower() or "signal" in event.lower(): - return - else: - for i in range(1, len(corr_info)): - logger.debug(f"Computing cross-correlation for event {event}...") - for j in range(len(type)): - psth_a = read_Df(filepath, event + "_" + corr_info[i - 1], type[j] + "_" + corr_info[i - 1]) - psth_b = read_Df(filepath, event + "_" + corr_info[i], type[j] + "_" + corr_info[i]) - sample_rate = 1 / (psth_a["timestamps"][1] - psth_a["timestamps"][0]) - psth_a = psth_a.drop(columns=["timestamps", "err", "mean"]) - psth_b = psth_b.drop(columns=["timestamps", "err", "mean"]) - cols_a, cols_b = np.array(psth_a.columns), np.array(psth_b.columns) - if np.intersect1d(cols_a, cols_b).size > 0: - cols = list(np.intersect1d(cols_a, cols_b)) - else: - cols = list(cols_a) - arr_A, arr_B = np.array(psth_a).T, np.array(psth_b).T - cross_corr = helperCrossCorrelation(arr_A, arr_B, sample_rate) - cols.append("timestamps") - create_Df( - make_dir(filepath), - "corr_" + event, - type[j] + "_" + corr_info[i - 1] + "_" + corr_info[i], - cross_corr, - cols, - ) - logger.info(f"Cross-correlation for event {event} computed.") diff --git a/src/guppy/computePsth.py b/src/guppy/computePsth.py index 671d1d3..32d9be1 100755 --- a/src/guppy/computePsth.py +++ b/src/guppy/computePsth.py @@ -3,22 +3,37 @@ 
import glob import json import logging -import math import multiprocessing as mp import os import re import subprocess import sys -from collections import OrderedDict from itertools import repeat -import h5py import numpy as np -import pandas as pd from scipy import signal as ss -from .computeCorr import computeCrossCorrelation, getCorrCombinations, make_dir -from .preprocess import get_all_stores_for_combining_data +from .analysis.compute_psth import compute_psth +from .analysis.cross_correlation import compute_cross_correlation +from .analysis.io_utils import ( + get_all_stores_for_combining_data, + make_dir_for_cross_correlation, + makeAverageDir, + read_Df, + read_hdf5, + write_hdf5, +) +from .analysis.psth_average import averageForGroup +from .analysis.psth_peak_and_area import compute_psth_peak_and_area +from .analysis.psth_utils import ( + create_Df_for_cross_correlation, + create_Df_for_psth, + getCorrCombinations, +) +from .analysis.standard_io import ( + write_peak_and_area_to_csv, + write_peak_and_area_to_hdf5, +) logger = logging.getLogger(__name__) @@ -36,335 +51,20 @@ def writeToFile(value: str): file.write(value) -# function to read hdf5 file -def read_hdf5(event, filepath, key): - if event: - event = event.replace("\\", "_") - event = event.replace("/", "_") - op = os.path.join(filepath, event + ".hdf5") - else: - op = filepath - - if os.path.exists(op): - with h5py.File(op, "r") as f: - arr = np.asarray(f[key]) - else: - raise Exception("{}.hdf5 file does not exist".format(event)) - - return arr - - -# function to write hdf5 file -def write_hdf5(data, event, filepath, key): - event = event.replace("\\", "_") - event = event.replace("/", "_") - op = os.path.join(filepath, event + ".hdf5") - - # if file does not exist create a new file - if not os.path.exists(op): - with h5py.File(op, "w") as f: - if type(data) is np.ndarray: - f.create_dataset(key, data=data, maxshape=(None,), chunks=True) - else: - f.create_dataset(key, data=data) - # if file already exists, append data to it or add a new key to it - else: - with h5py.File(op, "r+") as f: - if key in list(f.keys()): - if type(data) is np.ndarray: - f[key].resize(data.shape) - arr = f[key] - arr[:] = data - else: - arr = f[key] - arr = data - else: - f.create_dataset(key, data=data, maxshape=(None,), chunks=True) - - -def create_Df_area_peak(filepath, arr, name, index=[]): - - op = os.path.join(filepath, "peak_AUC_" + name + ".h5") - dirname = os.path.dirname(filepath) - - df = pd.DataFrame(arr, index=index) - - df.to_hdf(op, key="df", mode="w") - - -def read_Df_area_peak(filepath, name): - op = os.path.join(filepath, "peak_AUC_" + name + ".h5") - df = pd.read_hdf(op, key="df", mode="r") - - return df - - -def create_csv_area_peak(filepath, arr, name, index=[]): - op = os.path.join(filepath, "peak_AUC_" + name + ".csv") - df = pd.DataFrame(arr, index=index) - - df.to_csv(op) - - -# function to create dataframe for each event PSTH and save it to h5 file -def create_Df(filepath, event, name, psth, columns=[]): - event = event.replace("\\", "_") - event = event.replace("/", "_") - if name: - op = os.path.join(filepath, event + "_{}.h5".format(name)) - else: - op = os.path.join(filepath, event + ".h5") - - # check if file already exists - # if os.path.exists(op): - # return 0 - - # removing psth binned trials - columns = np.array(columns, dtype="str") - regex = re.compile("bin_*") - single_trials = columns[[i for i in range(len(columns)) if not regex.match(columns[i])]] - single_trials_index = [i for i in 
range(len(single_trials)) if single_trials[i] != "timestamps"] - - psth = psth.T - if psth.ndim > 1: - mean = np.nanmean(psth[:, single_trials_index], axis=1).reshape(-1, 1) - err = np.nanstd(psth[:, single_trials_index], axis=1) / math.sqrt(psth[:, single_trials_index].shape[1]) - err = err.reshape(-1, 1) - psth = np.hstack((psth, mean)) - psth = np.hstack((psth, err)) - # timestamps = np.asarray(read_Df(filepath, 'ts_psth', '')) - # psth = np.hstack((psth, timestamps)) - try: - ts = read_hdf5(event, filepath, "ts") - ts = np.append(ts, ["mean", "err"]) - except: - ts = None - - if len(columns) == 0: - df = pd.DataFrame(psth, index=None, columns=ts, dtype="float32") - else: - columns = np.asarray(columns) - columns = np.append(columns, ["mean", "err"]) - df = pd.DataFrame(psth, index=None, columns=list(columns), dtype="float32") - - df.to_hdf(op, key="df", mode="w") - - -# function to read h5 file and make a dataframe from it -def read_Df(filepath, event, name): - event = event.replace("\\", "_") - event = event.replace("/", "_") - if name: - op = os.path.join(filepath, event + "_{}.h5".format(name)) - else: - op = os.path.join(filepath, event + ".h5") - df = pd.read_hdf(op, key="df", mode="r") - - return df - - -# function to create PSTH trials corresponding to each event timestamp -def rowFormation(z_score, thisIndex, nTsPrev, nTsPost): - - if nTsPrev < thisIndex and z_score.shape[0] > (thisIndex + nTsPost): - res = z_score[thisIndex - nTsPrev - 1 : thisIndex + nTsPost] - elif nTsPrev >= thisIndex and z_score.shape[0] > (thisIndex + nTsPost): - mismatch = nTsPrev - thisIndex + 1 - res = np.zeros(nTsPrev + nTsPost + 1) - res[:mismatch] = np.nan - res[mismatch:] = z_score[: thisIndex + nTsPost] - elif nTsPrev >= thisIndex and z_score.shape[0] < (thisIndex + nTsPost): - mismatch1 = nTsPrev - thisIndex + 1 - mismatch2 = (thisIndex + nTsPost) - z_score.shape[0] - res1 = np.full(mismatch1, np.nan) - res2 = z_score - res3 = np.full(mismatch2, np.nan) - res = np.concatenate((res1, np.concatenate((res2, res3)))) - else: - mismatch = (thisIndex + nTsPost) - z_score.shape[0] - res1 = np.zeros(mismatch) - res1[:] = np.nan - res2 = z_score[thisIndex - nTsPrev - 1 : z_score.shape[0]] - res = np.concatenate((res2, res1)) - - return res - - -# function to calculate baseline for each PSTH trial and do baseline correction -def baselineCorrection(filepath, arr, timeAxis, baselineStart, baselineEnd): - - # timeAxis = read_Df(filepath, 'ts_psth', '') - # timeAxis = np.asarray(timeAxis).reshape(-1) - baselineStrtPt = np.where(timeAxis >= baselineStart)[0] - baselineEndPt = np.where(timeAxis >= baselineEnd)[0] - - # logger.info(baselineStrtPt[0], baselineEndPt[0]) - if baselineStart == 0 and baselineEnd == 0: - return arr - - baseline = np.nanmean(arr[baselineStrtPt[0] : baselineEndPt[0]]) - baselineSub = np.subtract(arr, baseline) - - return baselineSub - - -# helper function to make PSTH for each event -def helper_psth( - z_score, - event, - filepath, - nSecPrev, - nSecPost, - timeInterval, - bin_psth_trials, - use_time_or_trials, - baselineStart, - baselineEnd, - naming, - just_use_signal, -): - - event = event.replace("\\", "_") - event = event.replace("/", "_") - - sampling_rate = read_hdf5("timeCorrection_" + naming, filepath, "sampling_rate")[0] - - # calculate time before event timestamp and time after event timestamp - nTsPrev = int(round(nSecPrev * sampling_rate)) - nTsPost = int(round(nSecPost * sampling_rate)) - - totalTs = (-1 * nTsPrev) + nTsPost - increment = ((-1 * nSecPrev) + nSecPost) / 
totalTs - timeAxis = np.linspace(nSecPrev, nSecPost + increment, totalTs + 1) - timeAxisNew = np.concatenate((timeAxis, timeAxis[::-1])) - - # avoid writing same data to same file in multi-processing - # if not os.path.exists(os.path.join(filepath, 'ts_psth.h5')): - # logger.info('file not exists') - # create_Df(filepath, 'ts_psth', '', timeAxis) - # time.sleep(2) - - ts = read_hdf5(event + "_" + naming, filepath, "ts") - - # reject timestamps for which baseline cannot be calculated because of nan values - new_ts = [] - for i in range(ts.shape[0]): - thisTime = ts[i] # -1 not needed anymore - if thisTime < abs(baselineStart): - continue - else: - new_ts.append(ts[i]) - - # reject burst of timestamps - ts = np.asarray(new_ts) - # skip the event if there are no TTLs - if len(ts) == 0: - new_ts = np.array([]) - logger.info(f"Warning : No TTLs present for {event}. This will cause an error in Visualization step") - else: - new_ts = [ts[0]] - for i in range(1, ts.shape[0]): - thisTime = ts[i] - prevTime = new_ts[-1] - diff = thisTime - prevTime - if diff < timeInterval: - continue - else: - new_ts.append(ts[i]) - - # final timestamps - ts = np.asarray(new_ts) - nTs = ts.shape[0] - - # initialize PSTH vector - psth = np.full((nTs, totalTs + 1), np.nan) - psth_baselineUncorrected = np.full((nTs, totalTs + 1), np.nan) # extra - - # for each timestamp, create trial which will be saved in a PSTH vector - for i in range(nTs): - thisTime = ts[i] # -timeForLightsTurnOn - thisIndex = int(round(thisTime * sampling_rate)) - arr = rowFormation(z_score, thisIndex, -1 * nTsPrev, nTsPost) - if just_use_signal == True: - res = np.subtract(arr, np.nanmean(arr)) - z_score_arr = np.divide(res, np.nanstd(arr)) - arr = z_score_arr - else: - arr = arr - - psth_baselineUncorrected[i, :] = arr # extra - psth[i, :] = baselineCorrection(filepath, arr, timeAxis, baselineStart, baselineEnd) - - write_hdf5(ts, event + "_" + naming, filepath, "ts") - columns = list(ts) - - if use_time_or_trials == "Time (min)" and bin_psth_trials > 0: - timestamps = read_hdf5("timeCorrection_" + naming, filepath, "timestampNew") - timestamps = np.divide(timestamps, 60) - ts_min = np.divide(ts, 60) - bin_steps = np.arange(timestamps[0], timestamps[-1] + bin_psth_trials, bin_psth_trials) - indices_each_step = dict() - for i in range(1, bin_steps.shape[0]): - indices_each_step[f"{np.around(bin_steps[i-1],0)}-{np.around(bin_steps[i],0)}"] = np.where( - (ts_min >= bin_steps[i - 1]) & (ts_min <= bin_steps[i]) - )[0] - elif use_time_or_trials == "# of trials" and bin_psth_trials > 0: - bin_steps = np.arange(0, ts.shape[0], bin_psth_trials) - if bin_steps[-1] < ts.shape[0]: - bin_steps = np.concatenate((bin_steps, [ts.shape[0]]), axis=0) - indices_each_step = dict() - for i in range(1, bin_steps.shape[0]): - indices_each_step[f"{bin_steps[i-1]}-{bin_steps[i]}"] = np.arange(bin_steps[i - 1], bin_steps[i]) - else: - indices_each_step = dict() - - psth_bin, psth_bin_baselineUncorrected = [], [] - if indices_each_step: - keys = list(indices_each_step.keys()) - for k in keys: - # no trials in a given bin window, just put all the nan values - if indices_each_step[k].shape[0] == 0: - psth_bin.append(np.full(psth.shape[1], np.nan)) - psth_bin_baselineUncorrected.append(np.full(psth_baselineUncorrected.shape[1], np.nan)) - psth_bin.append(np.full(psth.shape[1], np.nan)) - psth_bin_baselineUncorrected.append(np.full(psth_baselineUncorrected.shape[1], np.nan)) - else: - index = indices_each_step[k] - arr = psth[index, :] - # mean of bins - 
psth_bin.append(np.nanmean(psth[index, :], axis=0)) - psth_bin_baselineUncorrected.append(np.nanmean(psth_baselineUncorrected[index, :], axis=0)) - psth_bin.append(np.nanstd(psth[index, :], axis=0) / math.sqrt(psth[index, :].shape[0])) - # error of bins - psth_bin_baselineUncorrected.append( - np.nanstd(psth_baselineUncorrected[index, :], axis=0) - / math.sqrt(psth_baselineUncorrected[index, :].shape[0]) - ) - - # adding column names - columns.append(f"bin_({k})") - columns.append(f"bin_err_({k})") - - psth = np.concatenate((psth, psth_bin), axis=0) - psth_baselineUncorrected = np.concatenate((psth_baselineUncorrected, psth_bin_baselineUncorrected), axis=0) - - timeAxis = timeAxis.reshape(1, -1) - psth = np.concatenate((psth, timeAxis), axis=0) - psth_baselineUncorrected = np.concatenate((psth_baselineUncorrected, timeAxis), axis=0) - columns.append("timestamps") - - return psth, psth_baselineUncorrected, columns - - # function to create PSTH for each event using function helper_psth and save the PSTH to h5 file -def storenamePsth(filepath, event, inputParameters): +def execute_compute_psth(filepath, event, inputParameters): event = event.replace("\\", "_") event = event.replace("/", "_") + if "control" in event.lower() or "signal" in event.lower(): + return 0 selectForComputePsth = inputParameters["selectForComputePsth"] bin_psth_trials = inputParameters["bin_psth_trials"] use_time_or_trials = inputParameters["use_time_or_trials"] + nSecPrev, nSecPost = inputParameters["nSecPrev"], inputParameters["nSecPost"] + baselineStart, baselineEnd = inputParameters["baselineCorrectionStart"], inputParameters["baselineCorrectionEnd"] + timeInterval = inputParameters["timeInterval"] if selectForComputePsth == "z_score": path = glob.glob(os.path.join(filepath, "z_score_*")) @@ -376,100 +76,62 @@ def storenamePsth(filepath, event, inputParameters): b = np.divide(np.ones((100,)), 100) a = 1 - # storesList = storesList - # sampling_rate = read_hdf5(storesList[0,0], filepath, 'sampling_rate') - nSecPrev, nSecPost = inputParameters["nSecPrev"], inputParameters["nSecPost"] - baselineStart, baselineEnd = inputParameters["baselineCorrectionStart"], inputParameters["baselineCorrectionEnd"] - timeInterval = inputParameters["timeInterval"] - - if "control" in event.lower() or "signal" in event.lower(): - return 0 - else: - for i in range(len(path)): - logger.info(f"Computing PSTH for event {event}...") - basename = (os.path.basename(path[i])).split(".")[0] - name_1 = basename.split("_")[-1] - control = read_hdf5("control_" + name_1, os.path.dirname(path[i]), "data") - if (control == 0).all() == True: - signal = read_hdf5("signal_" + name_1, os.path.dirname(path[i]), "data") - z_score = ss.filtfilt(b, a, signal) - just_use_signal = True - else: - z_score = read_hdf5("", path[i], "data") - just_use_signal = False - psth, psth_baselineUncorrected, cols = helper_psth( - z_score, - event, - filepath, - nSecPrev, - nSecPost, - timeInterval, - bin_psth_trials, - use_time_or_trials, - baselineStart, - baselineEnd, - name_1, - just_use_signal, - ) - - create_Df( - filepath, - event + "_" + name_1 + "_baselineUncorrected", - basename, - psth_baselineUncorrected, - columns=cols, - ) # extra - create_Df(filepath, event + "_" + name_1, basename, psth, columns=cols) - logger.info(f"PSTH for event {event} computed.") - - -def helperPSTHPeakAndArea(psth_mean, timestamps, sampling_rate, peak_startPoint, peak_endPoint): - - peak_startPoint = np.asarray(peak_startPoint) - peak_endPoint = np.asarray(peak_endPoint) - - 
peak_startPoint = peak_startPoint[~np.isnan(peak_startPoint)] - peak_endPoint = peak_endPoint[~np.isnan(peak_endPoint)] - - if peak_startPoint.shape[0] != peak_endPoint.shape[0]: - logger.error("Number of Peak Start Time and Peak End Time are unequal.") - raise Exception("Number of Peak Start Time and Peak End Time are unequal.") - - if np.less_equal(peak_endPoint, peak_startPoint).any() == True: - logger.error( - "Peak End Time is lesser than or equal to Peak Start Time. Please check the Peak parameters window." - ) - raise Exception( - "Peak End Time is lesser than or equal to Peak Start Time. Please check the Peak parameters window." - ) + for i in range(len(path)): + logger.info(f"Computing PSTH for event {event}...") + basename = (os.path.basename(path[i])).split(".")[0] + name_1 = basename.split("_")[-1] + control = read_hdf5("control_" + name_1, os.path.dirname(path[i]), "data") + if (control == 0).all() == True: + signal = read_hdf5("signal_" + name_1, os.path.dirname(path[i]), "data") + z_score = ss.filtfilt(b, a, signal) + just_use_signal = True + else: + z_score = read_hdf5("", path[i], "data") + just_use_signal = False - peak_area = OrderedDict() - - if peak_startPoint.shape[0] == 0 or peak_endPoint.shape[0] == 0: - peak_area["peak"] = np.nan - peak_area["area"] = np.nan - - for i in range(peak_startPoint.shape[0]): - startPtForPeak = np.where(timestamps >= peak_startPoint[i])[0] - endPtForPeak = np.where(timestamps >= peak_endPoint[i])[0] - if len(startPtForPeak) >= 1 and len(endPtForPeak) >= 1: - peakPoint_pos = startPtForPeak[0] + np.argmax(psth_mean[startPtForPeak[0] : endPtForPeak[0], :], axis=0) - peakPoint_neg = startPtForPeak[0] + np.argmin(psth_mean[startPtForPeak[0] : endPtForPeak[0], :], axis=0) - peak_area["peak_pos_" + str(i + 1)] = np.amax(psth_mean[peakPoint_pos], axis=0) - peak_area["peak_neg_" + str(i + 1)] = np.amin(psth_mean[peakPoint_neg], axis=0) - peak_area["area_" + str(i + 1)] = np.trapz(psth_mean[startPtForPeak[0] : endPtForPeak[0], :], axis=0) + sampling_rate = read_hdf5("timeCorrection_" + name_1, filepath, "sampling_rate")[0] + ts = read_hdf5(event + "_" + name_1, filepath, "ts") + if use_time_or_trials == "Time (min)" and bin_psth_trials > 0: + corrected_timestamps = read_hdf5("timeCorrection_" + name_1, filepath, "timestampNew") else: - peak_area["peak_" + str(i + 1)] = np.nan - peak_area["area_" + str(i + 1)] = np.nan + corrected_timestamps = None + psth, psth_baselineUncorrected, cols, ts = compute_psth( + z_score, + event, + filepath, + nSecPrev, + nSecPost, + timeInterval, + bin_psth_trials, + use_time_or_trials, + baselineStart, + baselineEnd, + name_1, + just_use_signal, + sampling_rate, + ts, + corrected_timestamps, + ) + write_hdf5(ts, event + "_" + name_1, filepath, "ts") - return peak_area + create_Df_for_psth( + filepath, + event + "_" + name_1 + "_baselineUncorrected", + basename, + psth_baselineUncorrected, + columns=cols, + ) # extra + create_Df_for_psth(filepath, event + "_" + name_1, basename, psth, columns=cols) + logger.info(f"PSTH for event {event} computed.") # function to compute PSTH peak and area using the function helperPSTHPeakAndArea save the values to h5 and csv files. 
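
# Sketch (illustration only, hedged): the peak/area step described in the comment
# above scans the mean PSTH trace inside a user-supplied window for its positive
# peak, negative peak, and trapezoidal area. The function and argument names here
# (`psth_mean`, `window_start`, `window_end`) are placeholders, not this module's
# API, and the trace is simplified to 1-D; the window is assumed to lie inside
# the trace.
import numpy as np

def peak_and_area_sketch(psth_mean, timestamps, window_start, window_end):
    # first sample at or after each window edge
    start = np.where(timestamps >= window_start)[0][0]
    end = np.where(timestamps >= window_end)[0][0]
    segment = psth_mean[start:end]
    peak_pos = np.max(segment)   # largest positive excursion in the window
    peak_neg = np.min(segment)   # largest negative excursion in the window
    area = np.trapz(segment)     # area under the curve, in sample units
    return peak_pos, peak_neg, area
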
-def findPSTHPeakAndArea(filepath, event, inputParameters): +def execute_compute_psth_peak_and_area(filepath, event, inputParameters): event = event.replace("\\", "_") event = event.replace("/", "_") + if "control" in event.lower() or "signal" in event.lower(): + return 0 # sampling_rate = read_hdf5(storesList[0,0], filepath, 'sampling_rate') peak_startPoint = inputParameters["peak_startPoint"] @@ -483,229 +145,178 @@ def findPSTHPeakAndArea(filepath, event, inputParameters): else: path = glob.glob(os.path.join(filepath, "z_score_*")) + glob.glob(os.path.join(filepath, "dff_*")) - if "control" in event.lower() or "signal" in event.lower(): - return 0 - else: - for i in range(len(path)): - logger.info(f"Computing peak and area for PSTH mean signal for event {event}...") - basename = (os.path.basename(path[i])).split(".")[0] - name_1 = basename.split("_")[-1] - sampling_rate = read_hdf5("timeCorrection_" + name_1, filepath, "sampling_rate")[0] - psth = read_Df(filepath, event + "_" + name_1, basename) - cols = list(psth.columns) - regex = re.compile("bin_[(]") - bin_names = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] - regex_trials = re.compile("[+-]?([0-9]*[.])?[0-9]+") - trials_names = [cols[i] for i in range(len(cols)) if regex_trials.match(cols[i])] - psth_mean_bin_names = trials_names + bin_names + ["mean"] - psth_mean_bin_mean = np.asarray(psth[psth_mean_bin_names]) - timestamps = np.asarray(psth["timestamps"]).ravel() # np.asarray(read_Df(filepath, 'ts_psth', '')).ravel() - peak_area = helperPSTHPeakAndArea( - psth_mean_bin_mean, timestamps, sampling_rate, peak_startPoint, peak_endPoint - ) # peak, area = - # arr = np.array([[peak, area]]) - fileName = [os.path.basename(os.path.dirname(filepath))] - index = [fileName[0] + "_" + s for s in psth_mean_bin_names] - create_Df_area_peak( - filepath, peak_area, event + "_" + name_1 + "_" + basename, index=index - ) # columns=['peak', 'area'] - create_csv_area_peak(filepath, peak_area, event + "_" + name_1 + "_" + basename, index=index) - logger.info(f"Peak and Area for PSTH mean signal for event {event} computed.") - - -def makeAverageDir(filepath): - - op = os.path.join(filepath, "average") - if not os.path.exists(op): - os.mkdir(op) - - return op - - -def psth_shape_check(psth): - - each_ln = [] - for i in range(len(psth)): - each_ln.append(psth[i].shape[0]) - - each_ln = np.asarray(each_ln) - keep_ln = each_ln[-1] - - for i in range(len(psth)): - if psth[i].shape[0] > keep_ln: - psth[i] = psth[i][:keep_ln] - elif psth[i].shape[0] < keep_ln: - psth[i] = np.append(psth[i], np.full(keep_ln - len(psth[i]), np.nan)) + for i in range(len(path)): + logger.info(f"Computing peak and area for PSTH mean signal for event {event}...") + basename = (os.path.basename(path[i])).split(".")[0] + name_1 = basename.split("_")[-1] + sampling_rate = read_hdf5("timeCorrection_" + name_1, filepath, "sampling_rate")[0] + psth = read_Df(filepath, event + "_" + name_1, basename) + cols = list(psth.columns) + regex = re.compile("bin_[(]") + bin_names = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] + regex_trials = re.compile("[+-]?([0-9]*[.])?[0-9]+") + trials_names = [cols[i] for i in range(len(cols)) if regex_trials.match(cols[i])] + psth_mean_bin_names = trials_names + bin_names + ["mean"] + psth_mean_bin_mean = np.asarray(psth[psth_mean_bin_names]) + timestamps = np.asarray(psth["timestamps"]).ravel() # np.asarray(read_Df(filepath, 'ts_psth', '')).ravel() + peak_area = compute_psth_peak_and_area( + psth_mean_bin_mean, timestamps, 
sampling_rate, peak_startPoint, peak_endPoint + ) # peak, area = + # arr = np.array([[peak, area]]) + fileName = [os.path.basename(os.path.dirname(filepath))] + index = [fileName[0] + "_" + s for s in psth_mean_bin_names] + write_peak_and_area_to_hdf5( + filepath, peak_area, event + "_" + name_1 + "_" + basename, index=index + ) # columns=['peak', 'area'] + write_peak_and_area_to_csv(filepath, peak_area, event + "_" + name_1 + "_" + basename, index=index) + logger.info(f"Peak and Area for PSTH mean signal for event {event} computed.") + + +def execute_compute_cross_correlation(filepath, event, inputParameters): + isCompute = inputParameters["computeCorr"] + removeArtifacts = inputParameters["removeArtifacts"] + artifactsRemovalMethod = inputParameters["artifactsRemovalMethod"] + if isCompute == True: + if removeArtifacts == True and artifactsRemovalMethod == "concatenate": + raise Exception( + "For cross-correlation, when removeArtifacts is True, artifacts removal method\ + should be replace with NaNs and not concatenate" + ) + corr_info, type = getCorrCombinations(filepath, inputParameters) + if "control" in event.lower() or "signal" in event.lower(): + return else: - psth[i] = psth[i] + for i in range(1, len(corr_info)): + logger.debug(f"Computing cross-correlation for event {event}...") + for j in range(len(type)): + psth_a = read_Df(filepath, event + "_" + corr_info[i - 1], type[j] + "_" + corr_info[i - 1]) + psth_b = read_Df(filepath, event + "_" + corr_info[i], type[j] + "_" + corr_info[i]) + sample_rate = 1 / (psth_a["timestamps"][1] - psth_a["timestamps"][0]) + psth_a = psth_a.drop(columns=["timestamps", "err", "mean"]) + psth_b = psth_b.drop(columns=["timestamps", "err", "mean"]) + cols_a, cols_b = np.array(psth_a.columns), np.array(psth_b.columns) + if np.intersect1d(cols_a, cols_b).size > 0: + cols = list(np.intersect1d(cols_a, cols_b)) + else: + cols = list(cols_a) + arr_A, arr_B = np.array(psth_a).T, np.array(psth_b).T + cross_corr = compute_cross_correlation(arr_A, arr_B, sample_rate) + cols.append("timestamps") + create_Df_for_cross_correlation( + make_dir_for_cross_correlation(filepath), + "corr_" + event, + type[j] + "_" + corr_info[i - 1] + "_" + corr_info[i], + cross_corr, + cols, + ) + logger.info(f"Cross-correlation for event {event} computed.") + + +def orchestrate_psth(inputParameters): + folderNames = inputParameters["folderNames"] + numProcesses = inputParameters["numberOfCores"] + storesListPath = [] + for i in range(len(folderNames)): + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(folderNames[i], "*_output_*")))) + storesListPath = np.concatenate(storesListPath) + writeToFile(str((storesListPath.shape[0] + storesListPath.shape[0] + 1) * 10) + "\n" + str(10) + "\n") + for i in range(len(folderNames)): + logger.debug(f"Computing PSTH, Peak and Area for each event in {folderNames[i]}") + storesListPath = takeOnlyDirs(glob.glob(os.path.join(folderNames[i], "*_output_*"))) + for j in range(len(storesListPath)): + filepath = storesListPath[j] + storesList = np.genfromtxt(os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=",").reshape( + 2, -1 + ) - return psth + with mp.Pool(numProcesses) as p: + p.starmap(execute_compute_psth, zip(repeat(filepath), storesList[1, :], repeat(inputParameters))) + with mp.Pool(numProcesses) as pq: + pq.starmap( + execute_compute_psth_peak_and_area, zip(repeat(filepath), storesList[1, :], repeat(inputParameters)) + ) -# function to compute average of group of recordings -def averageForGroup(folderNames, event, 
inputParameters): + with mp.Pool(numProcesses) as cr: + cr.starmap( + execute_compute_cross_correlation, zip(repeat(filepath), storesList[1, :], repeat(inputParameters)) + ) - event = event.replace("\\", "_") - event = event.replace("/", "_") + # for k in range(storesList.shape[1]): + # storenamePsth(filepath, storesList[1,k], inputParameters) + # findPSTHPeakAndArea(filepath, storesList[1,k], inputParameters) - logger.debug("Averaging group of data...") - path = [] - abspath = inputParameters["abspath"] - selectForComputePsth = inputParameters["selectForComputePsth"] - path_temp_len = [] - op = makeAverageDir(abspath) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + logger.info(f"PSTH, Area and Peak are computed for all events in {folderNames[i]}.") - # combining paths to all the selected folders for doing average + +def execute_psth_combined(inputParameters): + folderNames = inputParameters["folderNames"] + storesListPath = [] for i in range(len(folderNames)): - if selectForComputePsth == "z_score": - path_temp = glob.glob(os.path.join(folderNames[i], "z_score_*")) - elif selectForComputePsth == "dff": - path_temp = glob.glob(os.path.join(folderNames[i], "dff_*")) - else: - path_temp = glob.glob(os.path.join(folderNames[i], "z_score_*")) + glob.glob( - os.path.join(folderNames[i], "dff_*") + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(folderNames[i], "*_output_*")))) + storesListPath = list(np.concatenate(storesListPath).flatten()) + op = get_all_stores_for_combining_data(storesListPath) + writeToFile(str((len(op) + len(op) + 1) * 10) + "\n" + str(10) + "\n") + for i in range(len(op)): + storesList = np.asarray([[], []]) + for j in range(len(op[i])): + storesList = np.concatenate( + ( + storesList, + np.genfromtxt(os.path.join(op[i][j], "storesList.csv"), dtype="str", delimiter=",").reshape(2, -1), + ), + axis=1, ) + storesList = np.unique(storesList, axis=1) + for k in range(storesList.shape[1]): + execute_compute_psth(op[i][0], storesList[1, k], inputParameters) + execute_compute_psth_peak_and_area(op[i][0], storesList[1, k], inputParameters) + execute_compute_cross_correlation(op[i][0], storesList[1, k], inputParameters) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 - path_temp_len.append(len(path_temp)) - # path_temp = glob.glob(os.path.join(folderNames[i], 'z_score_*')) - for j in range(len(path_temp)): - basename = (os.path.basename(path_temp[j])).split(".")[0] - write_hdf5(np.array([]), basename, op, "data") - name_1 = basename.split("_")[-1] - temp = [folderNames[i], event + "_" + name_1, basename] - path.append(temp) - - # processing of all the paths - path_temp_len = np.asarray(path_temp_len) - max_len = np.argmax(path_temp_len) - naming = [] - for i in range(len(path)): - naming.append(path[i][2]) - naming = np.unique(np.asarray(naming)) - - new_path = [[] for _ in range(path_temp_len[max_len])] - for i in range(len(path)): - idx = np.where(naming == path[i][2])[0][0] - new_path[idx].append(path[i]) - - # read PSTH for each event and make the average of it. Save the final output to an average folder. 
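
# Sketch (illustration only, hedged): how recordings are bucketed by their
# z_score_*/dff_* basename before group averaging, mirroring the
# np.unique / np.where grouping in the block above. `entries` is a hypothetical
# list of [folder, event, basename] triples, not a real data structure from
# this module.
import numpy as np

def group_by_basename_sketch(entries):
    basenames = np.unique(np.asarray([e[2] for e in entries]))
    groups = [[] for _ in basenames]
    for e in entries:
        idx = np.where(basenames == e[2])[0][0]  # bucket index for this basename
        groups[idx].append(e)
    return list(basenames), groups
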
- for i in range(len(new_path)): - psth, psth_bins = [], [] - columns = [] - bins_cols = [] - temp_path = new_path[i] - for j in range(len(temp_path)): - # logger.info(os.path.join(temp_path[j][0], temp_path[j][1]+'_{}.h5'.format(temp_path[j][2]))) - if not os.path.exists(os.path.join(temp_path[j][0], temp_path[j][1] + "_{}.h5".format(temp_path[j][2]))): - continue - else: - df = read_Df(temp_path[j][0], temp_path[j][1], temp_path[j][2]) # filepath, event, name - cols = list(df.columns) - regex = re.compile("bin_[(]") - bins_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] - psth.append(np.asarray(df["mean"])) - columns.append(os.path.basename(temp_path[j][0])) - if len(bins_cols) > 0: - psth_bins.append(df[bins_cols]) - - if len(psth) == 0: - logger.warning("Something is wrong with the file search pattern.") +def execute_average_for_group(inputParameters): + folderNamesForAvg = inputParameters["folderNamesForAvg"] + if len(folderNamesForAvg) == 0: + logger.error("Not a single folder name is provided in folderNamesForAvg in inputParamters File.") + raise Exception("Not a single folder name is provided in folderNamesForAvg in inputParamters File.") + + storesListPath = [] + for i in range(len(folderNamesForAvg)): + filepath = folderNamesForAvg[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = np.concatenate(storesListPath) + storesList = np.asarray([[], []]) + for i in range(storesListPath.shape[0]): + storesList = np.concatenate( + ( + storesList, + np.genfromtxt(os.path.join(storesListPath[i], "storesList.csv"), dtype="str", delimiter=",").reshape( + 2, -1 + ), + ), + axis=1, + ) + storesList = np.unique(storesList, axis=1) + op = makeAverageDir(inputParameters["abspath"]) + np.savetxt(os.path.join(op, "storesList.csv"), storesList, delimiter=",", fmt="%s") + pbMaxValue = 0 + for j in range(storesList.shape[1]): + if "control" in storesList[1, j].lower() or "signal" in storesList[1, j].lower(): continue - - if len(bins_cols) > 0: - df_bins = pd.concat(psth_bins, axis=1) - df_bins_mean = df_bins.groupby(by=df_bins.columns, axis=1).mean() - df_bins_err = df_bins.groupby(by=df_bins.columns, axis=1).std() / math.sqrt(df_bins.shape[1]) - cols_err = list(df_bins_err.columns) - dict_err = {} - for i in cols_err: - split = i.split("_") - dict_err[i] = "{}_err_{}".format(split[0], split[1]) - df_bins_err = df_bins_err.rename(columns=dict_err) - columns = columns + list(df_bins_mean.columns) + list(df_bins_err.columns) - df_bins_mean_err = pd.concat([df_bins_mean, df_bins_err], axis=1).T - psth, df_bins_mean_err = np.asarray(psth), np.asarray(df_bins_mean_err) - psth = np.concatenate((psth, df_bins_mean_err), axis=0) else: - psth = psth_shape_check(psth) - psth = np.asarray(psth) - - timestamps = np.asarray(df["timestamps"]).reshape(1, -1) - psth = np.concatenate((psth, timestamps), axis=0) - columns = columns + ["timestamps"] - create_Df(op, temp_path[j][1], temp_path[j][2], psth, columns=columns) - - # read PSTH peak and area for each event and combine them. 
Save the final output to an average folder - for i in range(len(new_path)): - arr = [] - index = [] - temp_path = new_path[i] - for j in range(len(temp_path)): - if not os.path.exists( - os.path.join(temp_path[j][0], "peak_AUC_" + temp_path[j][1] + "_" + temp_path[j][2] + ".h5") - ): - continue - else: - df = read_Df_area_peak(temp_path[j][0], temp_path[j][1] + "_" + temp_path[j][2]) - arr.append(df) - index.append(list(df.index)) - - if len(arr) == 0: - logger.warning("Something is wrong with the file search pattern.") + pbMaxValue += 1 + writeToFile(str((1 + pbMaxValue + 1) * 10) + "\n" + str(10) + "\n") + for k in range(storesList.shape[1]): + if "control" in storesList[1, k].lower() or "signal" in storesList[1, k].lower(): continue - index = list(np.concatenate(index)) - new_df = pd.concat(arr, axis=0) # os.path.join(filepath, 'peak_AUC_'+name+'.csv') - new_df.to_csv(os.path.join(op, "peak_AUC_{}_{}.csv".format(temp_path[j][1], temp_path[j][2])), index=index) - new_df.to_hdf( - os.path.join(op, "peak_AUC_{}_{}.h5".format(temp_path[j][1], temp_path[j][2])), - key="df", - mode="w", - index=index, - ) - - # read cross-correlation files and combine them. Save the final output to an average folder - type = [] - for i in range(len(folderNames)): - _, temp_type = getCorrCombinations(folderNames[i], inputParameters) - type.append(temp_type) - - type = np.unique(np.array(type)) - for i in range(len(type)): - corr = [] - columns = [] - df = None - for j in range(len(folderNames)): - corr_info, _ = getCorrCombinations(folderNames[j], inputParameters) - for k in range(1, len(corr_info)): - path = os.path.join( - folderNames[j], - "cross_correlation_output", - "corr_" + event + "_" + type[i] + "_" + corr_info[k - 1] + "_" + corr_info[k], - ) - if not os.path.exists(path + ".h5"): - continue - else: - df = read_Df( - os.path.join(folderNames[j], "cross_correlation_output"), - "corr_" + event, - type[i] + "_" + corr_info[k - 1] + "_" + corr_info[k], - ) - corr.append(df["mean"]) - columns.append(os.path.basename(folderNames[j])) - - if not isinstance(df, pd.DataFrame): - break - - corr = np.array(corr) - timestamps = np.array(df["timestamps"]).reshape(1, -1) - corr = np.concatenate((corr, timestamps), axis=0) - columns.append("timestamps") - create_Df( - make_dir(op), "corr_" + event, type[i] + "_" + corr_info[k - 1] + "_" + corr_info[k], corr, columns=columns - ) - - logger.info("Group of data averaged.") + else: + averageForGroup(storesListPath, storesList[1, k], inputParameters) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 def psthForEachStorename(inputParameters): @@ -715,8 +326,6 @@ def psthForEachStorename(inputParameters): # storesList = np.genfromtxt(inputParameters['storesListPath'], dtype='str', delimiter=',') - folderNames = inputParameters["folderNames"] - folderNamesForAvg = inputParameters["folderNamesForAvg"] average = inputParameters["averageForGroup"] combine_data = inputParameters["combine_data"] numProcesses = inputParameters["numberOfCores"] @@ -734,108 +343,14 @@ def psthForEachStorename(inputParameters): # for average following if statement will be executed if average == True: - if len(folderNamesForAvg) > 0: - storesListPath = [] - for i in range(len(folderNamesForAvg)): - filepath = folderNamesForAvg[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) - storesListPath = np.concatenate(storesListPath) - storesList = np.asarray([[], []]) - for i in range(storesListPath.shape[0]): - 
storesList = np.concatenate( - ( - storesList, - np.genfromtxt( - os.path.join(storesListPath[i], "storesList.csv"), dtype="str", delimiter="," - ).reshape(2, -1), - ), - axis=1, - ) - storesList = np.unique(storesList, axis=1) - op = makeAverageDir(inputParameters["abspath"]) - np.savetxt(os.path.join(op, "storesList.csv"), storesList, delimiter=",", fmt="%s") - pbMaxValue = 0 - for j in range(storesList.shape[1]): - if "control" in storesList[1, j].lower() or "signal" in storesList[1, j].lower(): - continue - else: - pbMaxValue += 1 - writeToFile(str((1 + pbMaxValue + 1) * 10) + "\n" + str(10) + "\n") - for k in range(storesList.shape[1]): - if "control" in storesList[1, k].lower() or "signal" in storesList[1, k].lower(): - continue - else: - averageForGroup(storesListPath, storesList[1, k], inputParameters) - writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") - inputParameters["step"] += 1 - - else: - logger.error("Not a single folder name is provided in folderNamesForAvg in inputParamters File.") - raise Exception("Not a single folder name is provided in folderNamesForAvg in inputParamters File.") + execute_average_for_group(inputParameters) # for individual analysis following else statement will be executed else: if combine_data == True: - storesListPath = [] - for i in range(len(folderNames)): - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(folderNames[i], "*_output_*")))) - storesListPath = list(np.concatenate(storesListPath).flatten()) - op = get_all_stores_for_combining_data(storesListPath) - writeToFile(str((len(op) + len(op) + 1) * 10) + "\n" + str(10) + "\n") - for i in range(len(op)): - storesList = np.asarray([[], []]) - for j in range(len(op[i])): - storesList = np.concatenate( - ( - storesList, - np.genfromtxt(os.path.join(op[i][j], "storesList.csv"), dtype="str", delimiter=",").reshape( - 2, -1 - ), - ), - axis=1, - ) - storesList = np.unique(storesList, axis=1) - for k in range(storesList.shape[1]): - storenamePsth(op[i][0], storesList[1, k], inputParameters) - findPSTHPeakAndArea(op[i][0], storesList[1, k], inputParameters) - computeCrossCorrelation(op[i][0], storesList[1, k], inputParameters) - writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") - inputParameters["step"] += 1 + execute_psth_combined(inputParameters) else: - storesListPath = [] - for i in range(len(folderNames)): - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(folderNames[i], "*_output_*")))) - storesListPath = np.concatenate(storesListPath) - writeToFile(str((storesListPath.shape[0] + storesListPath.shape[0] + 1) * 10) + "\n" + str(10) + "\n") - for i in range(len(folderNames)): - logger.debug(f"Computing PSTH, Peak and Area for each event in {folderNames[i]}") - storesListPath = takeOnlyDirs(glob.glob(os.path.join(folderNames[i], "*_output_*"))) - for j in range(len(storesListPath)): - filepath = storesListPath[j] - storesList = np.genfromtxt( - os.path.join(filepath, "storesList.csv"), dtype="str", delimiter="," - ).reshape(2, -1) - - with mp.Pool(numProcesses) as p: - p.starmap(storenamePsth, zip(repeat(filepath), storesList[1, :], repeat(inputParameters))) - - with mp.Pool(numProcesses) as pq: - pq.starmap( - findPSTHPeakAndArea, zip(repeat(filepath), storesList[1, :], repeat(inputParameters)) - ) - - with mp.Pool(numProcesses) as cr: - cr.starmap( - computeCrossCorrelation, zip(repeat(filepath), storesList[1, :], repeat(inputParameters)) - ) - - # for k in range(storesList.shape[1]): - # storenamePsth(filepath, storesList[1,k], inputParameters) 
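The per-folder branch deleted above fans each store name out to a process pool with starmap over zip(repeat(filepath), storesList[1, :], repeat(inputParameters)), so every store receives the same folder path and parameter dict. A minimal sketch of that fan-out pattern is shown here; `process_store`, the path and the store names are hypothetical stand-ins for storenamePsth / findPSTHPeakAndArea / computeCrossCorrelation.

import multiprocessing as mp
from itertools import repeat

def process_store(filepath, store_name, params):
    # Placeholder worker: in GuPPy this would compute a PSTH (or peak/area, or
    # cross-correlation) for one store name inside one output folder.
    return f"{filepath}:{store_name}:{params['bin']}"

if __name__ == "__main__":
    filepath = "/data/session_output_1"          # illustrative path
    store_names = ["RewardTTL_A", "LickTTL_A"]   # illustrative store names
    params = {"bin": 0.1}
    with mp.Pool(2) as pool:
        results = pool.starmap(process_store,
                               zip(repeat(filepath), store_names, repeat(params)))
    print(results)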
- # findPSTHPeakAndArea(filepath, storesList[1,k], inputParameters) - - writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") - inputParameters["step"] += 1 - logger.info(f"PSTH, Area and Peak are computed for all events in {folderNames[i]}.") + orchestrate_psth(inputParameters) logger.info("PSTH, Area and Peak are computed for all events.") return inputParameters diff --git a/src/guppy/findTransientsFreqAndAmp.py b/src/guppy/findTransientsFreqAndAmp.py index e9b696b..f6c3d6e 100755 --- a/src/guppy/findTransientsFreqAndAmp.py +++ b/src/guppy/findTransientsFreqAndAmp.py @@ -1,162 +1,33 @@ import glob import json import logging -import math import multiprocessing as mp import os import sys -from itertools import repeat -import h5py import matplotlib.pyplot as plt import numpy as np -import pandas as pd -from scipy.signal import argrelextrema -from .preprocess import get_all_stores_for_combining_data +from .analysis.io_utils import ( + get_all_stores_for_combining_data, + read_hdf5, + takeOnlyDirs, +) +from .analysis.standard_io import ( + write_freq_and_amp_to_csv, + write_freq_and_amp_to_hdf5, +) +from .analysis.transients import analyze_transients +from .analysis.transients_average import averageForGroup logger = logging.getLogger(__name__) -logger = logging.getLogger(__name__) - - -def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths) - set(removePaths)) - def writeToFile(value: str): with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: file.write(value) -def read_hdf5(event, filepath, key): - if event: - op = os.path.join(filepath, event + ".hdf5") - else: - op = filepath - - if os.path.exists(op): - with h5py.File(op, "r") as f: - arr = np.asarray(f[key]) - else: - logger.error(f"{event}.hdf5 file does not exist") - raise Exception("{}.hdf5 file does not exist".format(event)) - - return arr - - -def processChunks(arrValues, arrIndexes, highAmpFilt, transientsThresh): - - arrValues = arrValues[~np.isnan(arrValues)] - median = np.median(arrValues) - - mad = np.median(np.abs(arrValues - median)) - - firstThreshold = median + (highAmpFilt * mad) - - greaterThanMad = np.where(arrValues > firstThreshold)[0] - - arr = np.arange(arrValues.shape[0]) - lowerThanMad = np.isin(arr, greaterThanMad, invert=True) - filteredOut = arrValues[np.where(lowerThanMad == True)[0]] - - filteredOutMedian = np.median(filteredOut) - filteredOutMad = np.median(np.abs(filteredOut - np.median(filteredOut))) - secondThreshold = filteredOutMedian + (transientsThresh * filteredOutMad) - - greaterThanThreshIndex = np.where(arrValues > secondThreshold)[0] - greaterThanThreshValues = arrValues[greaterThanThreshIndex] - temp = np.zeros(arrValues.shape[0]) - temp[greaterThanThreshIndex] = greaterThanThreshValues - peaks = argrelextrema(temp, np.greater)[0] - - firstThresholdY = np.full(arrValues.shape[0], firstThreshold) - secondThresholdY = np.full(arrValues.shape[0], secondThreshold) - - newPeaks = np.full(arrValues.shape[0], np.nan) - newPeaks[peaks] = peaks + arrIndexes[0] - - # madY = np.full(arrValues.shape[0], mad) - medianY = np.full(arrValues.shape[0], median) - filteredOutMedianY = np.full(arrValues.shape[0], filteredOutMedian) - - return peaks, mad, filteredOutMad, medianY, filteredOutMedianY, firstThresholdY, secondThresholdY - - -def createChunks(z_score, sampling_rate, window): - - logger.debug("Creating chunks for multiprocessing...") - windowPoints = math.ceil(sampling_rate * window) - 
remainderPoints = math.ceil((sampling_rate * window) - (z_score.shape[0] % windowPoints)) - - if remainderPoints == windowPoints: - padded_z_score = z_score - z_score_index = np.arange(padded_z_score.shape[0]) - else: - padding = np.full(remainderPoints, np.nan) - padded_z_score = np.concatenate((z_score, padding)) - z_score_index = np.arange(padded_z_score.shape[0]) - - reshape = padded_z_score.shape[0] / windowPoints - - if reshape.is_integer() == True: - z_score_chunks = padded_z_score.reshape(int(reshape), -1) - z_score_chunks_index = z_score_index.reshape(int(reshape), -1) - else: - logger.error("Reshaping values should be integer.") - raise Exception("Reshaping values should be integer.") - logger.info("Chunks are created for multiprocessing.") - return z_score_chunks, z_score_chunks_index - - -def calculate_freq_amp(arr, z_score, z_score_chunks_index, timestamps): - peaks = arr[:, 0] - filteredOutMedian = arr[:, 4] - count = 0 - peaksAmp = np.array([]) - peaksInd = np.array([]) - for i in range(z_score_chunks_index.shape[0]): - count += peaks[i].shape[0] - peaksIndexes = peaks[i] + z_score_chunks_index[i][0] - peaksInd = np.concatenate((peaksInd, peaksIndexes)) - amps = z_score[peaksIndexes] - filteredOutMedian[i][0] - peaksAmp = np.concatenate((peaksAmp, amps)) - - peaksInd = peaksInd.ravel() - peaksInd = peaksInd.astype(int) - # logger.info(timestamps) - freq = peaksAmp.shape[0] / ((timestamps[-1] - timestamps[0]) / 60) - - return freq, peaksAmp, peaksInd - - -def create_Df(filepath, arr, name, index=[], columns=[]): - - op = os.path.join(filepath, "freqAndAmp_" + name + ".h5") - dirname = os.path.dirname(filepath) - - df = pd.DataFrame(arr, index=index, columns=columns) - - df.to_hdf(op, key="df", mode="w") - - -def create_csv(filepath, arr, name, index=[], columns=[]): - op = os.path.join(filepath, name) - df = pd.DataFrame(arr, index=index, columns=columns) - df.to_csv(op) - - -def read_Df(filepath, name): - op = os.path.join(filepath, "freqAndAmp_" + name + ".h5") - df = pd.read_hdf(op, key="df", mode="r") - - return df - - def visuzlize_peaks(filepath, z_score, timestamps, peaksIndex): dirname = os.path.dirname(filepath) @@ -189,27 +60,16 @@ def findFreqAndAmp(filepath, inputParameters, window=15, numProcesses=mp.cpu_cou name_1 = basename.split("_")[-1] sampling_rate = read_hdf5("timeCorrection_" + name_1, filepath, "sampling_rate")[0] z_score = read_hdf5("", path[i], "data") - not_nan_indices = ~np.isnan(z_score) - z_score = z_score[not_nan_indices] - z_score_chunks, z_score_chunks_index = createChunks(z_score, sampling_rate, window) - - with mp.Pool(numProcesses) as p: - result = p.starmap( - processChunks, zip(z_score_chunks, z_score_chunks_index, repeat(highAmpFilt), repeat(transientsThresh)) - ) - - result = np.asarray(result, dtype=object) ts = read_hdf5("timeCorrection_" + name_1, filepath, "timestampNew") - ts = ts[not_nan_indices] - freq, peaksAmp, peaksInd = calculate_freq_amp(result, z_score, z_score_chunks_index, ts) - peaks_occurrences = np.array([ts[peaksInd], peaksAmp]).T - arr = np.array([[freq, np.mean(peaksAmp)]]) + z_score, ts, peaksInd, peaks_occurrences, arr = analyze_transients( + ts, window, numProcesses, highAmpFilt, transientsThresh, sampling_rate, z_score + ) fileName = [os.path.basename(os.path.dirname(filepath))] - create_Df(filepath, arr, basename, index=fileName, columns=["freq (events/min)", "amplitude"]) - create_csv( + write_freq_and_amp_to_hdf5(filepath, arr, basename, index=fileName, columns=["freq (events/min)", "amplitude"]) + 
write_freq_and_amp_to_csv( filepath, arr, "freqAndAmp_" + basename + ".csv", index=fileName, columns=["freq (events/min)", "amplitude"] ) - create_csv( + write_freq_and_amp_to_csv( filepath, peaks_occurrences, "transientsOccurrences_" + basename + ".csv", @@ -220,80 +80,6 @@ def findFreqAndAmp(filepath, inputParameters, window=15, numProcesses=mp.cpu_cou logger.info("Frequency and amplitude of transients in z_score data are calculated.") -def makeAverageDir(filepath): - - op = os.path.join(filepath, "average") - if not os.path.exists(op): - os.mkdir(op) - - return op - - -def averageForGroup(folderNames, inputParameters): - - logger.debug("Combining results for frequency and amplitude of transients in z-score data...") - path = [] - abspath = inputParameters["abspath"] - selectForTransientsComputation = inputParameters["selectForTransientsComputation"] - path_temp_len = [] - - for i in range(len(folderNames)): - if selectForTransientsComputation == "z_score": - path_temp = glob.glob(os.path.join(folderNames[i], "z_score_*")) - elif selectForTransientsComputation == "dff": - path_temp = glob.glob(os.path.join(folderNames[i], "dff_*")) - else: - path_temp = glob.glob(os.path.join(folderNames[i], "z_score_*")) + glob.glob( - os.path.join(folderNames[i], "dff_*") - ) - - path_temp_len.append(len(path_temp)) - - for j in range(len(path_temp)): - basename = (os.path.basename(path_temp[j])).split(".")[0] - # name = name[0] - temp = [folderNames[i], basename] - path.append(temp) - - path_temp_len = np.asarray(path_temp_len) - max_len = np.argmax(path_temp_len) - - naming = [] - for i in range(len(path)): - naming.append(path[i][1]) - naming = np.unique(np.asarray(naming)) - - new_path = [[] for _ in range(path_temp_len[max_len])] - for i in range(len(path)): - idx = np.where(naming == path[i][1])[0][0] - new_path[idx].append(path[i]) - - op = makeAverageDir(abspath) - - for i in range(len(new_path)): - arr = [] # np.zeros((len(new_path[i]), 2)) - fileName = [] - temp_path = new_path[i] - for j in range(len(temp_path)): - if not os.path.exists(os.path.join(temp_path[j][0], "freqAndAmp_" + temp_path[j][1] + ".h5")): - continue - else: - df = read_Df(temp_path[j][0], temp_path[j][1]) - arr.append(np.array([df["freq (events/min)"].iloc[0], df["amplitude"].iloc[0]])) - fileName.append(os.path.basename(temp_path[j][0])) - - arr = np.asarray(arr) - create_Df(op, arr, temp_path[j][1], index=fileName, columns=["freq (events/min)", "amplitude"]) - create_csv( - op, - arr, - "freqAndAmp_" + temp_path[j][1] + ".csv", - index=fileName, - columns=["freq (events/min)", "amplitude"], - ) - logger.info("Results for frequency and amplitude of transients in z-score data are combined.") - - def executeFindFreqAndAmp(inputParameters): logger.info("Finding transients in z-score data and calculating frequency and amplitude....") @@ -316,57 +102,63 @@ def executeFindFreqAndAmp(inputParameters): numProcesses = mp.cpu_count() - 1 if average == True: - if len(folderNamesForAvg) > 0: - storesListPath = [] - for i in range(len(folderNamesForAvg)): - filepath = folderNamesForAvg[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) - storesListPath = np.concatenate(storesListPath) - averageForGroup(storesListPath, inputParameters) - writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") - inputParameters["step"] += 1 - else: - logger.error("Not a single folder name is provided in folderNamesForAvg in inputParamters File.") - raise Exception("Not a single folder name is provided 
in folderNamesForAvg in inputParamters File.") - + execute_average_for_group(inputParameters, folderNamesForAvg) else: if combine_data == True: - storesListPath = [] - for i in range(len(folderNames)): - filepath = folderNames[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) - storesListPath = list(np.concatenate(storesListPath).flatten()) - op = get_all_stores_for_combining_data(storesListPath) - for i in range(len(op)): - filepath = op[i][0] - storesList = np.genfromtxt( - os.path.join(filepath, "storesList.csv"), dtype="str", delimiter="," - ).reshape(2, -1) - findFreqAndAmp(filepath, inputParameters, window=moving_window, numProcesses=numProcesses) - writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") - inputParameters["step"] += 1 - plt.show() + execute_find_freq_and_amp_combined(inputParameters, folderNames, moving_window, numProcesses) else: - for i in range(len(folderNames)): - logger.debug( - f"Finding transients in z-score data of {folderNames[i]} and calculating frequency and amplitude." - ) - filepath = folderNames[i] - storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) - for j in range(len(storesListPath)): - filepath = storesListPath[j] - storesList = np.genfromtxt( - os.path.join(filepath, "storesList.csv"), dtype="str", delimiter="," - ).reshape(2, -1) - findFreqAndAmp(filepath, inputParameters, window=moving_window, numProcesses=numProcesses) - writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") - inputParameters["step"] += 1 - logger.info("Transients in z-score data found and frequency and amplitude are calculated.") - plt.show() + execute_find_freq_and_amp(inputParameters, folderNames, moving_window, numProcesses) logger.info("Transients in z-score data found and frequency and amplitude are calculated.") +def execute_find_freq_and_amp(inputParameters, folderNames, moving_window, numProcesses): + for i in range(len(folderNames)): + logger.debug(f"Finding transients in z-score data of {folderNames[i]} and calculating frequency and amplitude.") + filepath = folderNames[i] + storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) + for j in range(len(storesListPath)): + filepath = storesListPath[j] + storesList = np.genfromtxt(os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=",").reshape( + 2, -1 + ) + findFreqAndAmp(filepath, inputParameters, window=moving_window, numProcesses=numProcesses) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + logger.info("Transients in z-score data found and frequency and amplitude are calculated.") + plt.show() + + +def execute_find_freq_and_amp_combined(inputParameters, folderNames, moving_window, numProcesses): + storesListPath = [] + for i in range(len(folderNames)): + filepath = folderNames[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = list(np.concatenate(storesListPath).flatten()) + op = get_all_stores_for_combining_data(storesListPath) + for i in range(len(op)): + filepath = op[i][0] + storesList = np.genfromtxt(os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=",").reshape(2, -1) + findFreqAndAmp(filepath, inputParameters, window=moving_window, numProcesses=numProcesses) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + plt.show() + + +def execute_average_for_group(inputParameters, folderNamesForAvg): + if 
len(folderNamesForAvg) == 0: + logger.error("Not a single folder name is provided in folderNamesForAvg in inputParamters File.") + raise Exception("Not a single folder name is provided in folderNamesForAvg in inputParamters File.") + storesListPath = [] + for i in range(len(folderNamesForAvg)): + filepath = folderNamesForAvg[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = np.concatenate(storesListPath) + averageForGroup(storesListPath, inputParameters) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + + if __name__ == "__main__": try: executeFindFreqAndAmp(json.loads(sys.argv[1])) diff --git a/src/guppy/preprocess.py b/src/guppy/preprocess.py index 8b79039..e4812a2 100755 --- a/src/guppy/preprocess.py +++ b/src/guppy/preprocess.py @@ -1,22 +1,44 @@ -import fnmatch import glob import json import logging import os -import re -import shutil import sys -import h5py import matplotlib.pyplot as plt import numpy as np -import pandas as pd -from scipy import signal as ss -from scipy.optimize import curve_fit -from .combineDataFn import processTimestampsForCombiningData - -logger = logging.getLogger(__name__) +from .analysis.artifact_removal import remove_artifacts +from .analysis.combine_data import combine_data +from .analysis.control_channel import add_control_channel, create_control_channel +from .analysis.io_utils import ( + check_storeslistfile, + check_TDT, + find_files, + get_all_stores_for_combining_data, # noqa: F401 -- Necessary for other modules that depend on preprocess.py + get_coords, + read_hdf5, + takeOnlyDirs, +) +from .analysis.standard_io import ( + read_control_and_signal, + read_coords_pairwise, + read_corrected_data, + read_corrected_data_dict, + read_corrected_timestamps_pairwise, + read_corrected_ttl_timestamps, + read_data_for_combining_data, + read_timestamps_for_combining_data, + read_ttl, + read_ttl_timestamps_for_combining_data, + write_artifact_removal, + write_combined_data, + write_corrected_data, + write_corrected_timestamps, + write_corrected_ttl_timestamps, + write_zscore, +) +from .analysis.timestamp_correction import correct_timestamps +from .analysis.z_score import compute_z_score logger = logging.getLogger(__name__) @@ -25,404 +47,11 @@ plt.switch_backend("TKAgg") -def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths) - set(removePaths)) - - def writeToFile(value: str): with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: file.write(value) -# find files by ignoring the case sensitivity -def find_files(path, glob_path, ignore_case=False): - rule = ( - re.compile(fnmatch.translate(glob_path), re.IGNORECASE) - if ignore_case - else re.compile(fnmatch.translate(glob_path)) - ) - - no_bytes_path = os.listdir(os.path.expanduser(path)) - str_path = [] - - # converting byte object to string - for x in no_bytes_path: - try: - str_path.append(x.decode("utf-8")) - except: - str_path.append(x) - return [os.path.join(path, n) for n in str_path if rule.match(n)] - - -# curve fit exponential function -def curveFitFn(x, a, b, c): - return a + (b * np.exp(-(1 / c) * x)) - - -# helper function to create control channel using signal channel -# by curve fitting signal channel to exponential function -# when there is no isosbestic control channel is present -def helper_create_control_channel(signal, timestamps, window): - # check if window is greater than signal 
shape - if window > signal.shape[0]: - window = ((signal.shape[0] + 1) / 2) + 1 - if window % 2 != 0: - window = window - else: - window = window + 1 - - filtered_signal = ss.savgol_filter(signal, window_length=window, polyorder=3) - - p0 = [5, 50, 60] - - try: - popt, pcov = curve_fit(curveFitFn, timestamps, filtered_signal, p0) - except Exception as e: - logger.error(str(e)) - - # logger.info('Curve Fit Parameters : ', popt) - control = curveFitFn(timestamps, *popt) - - return control - - -# main function to create control channel using -# signal channel and save it to a file -def create_control_channel(filepath, arr, window=5001): - - storenames = arr[0, :] - storesList = arr[1, :] - - for i in range(storesList.shape[0]): - event_name, event = storesList[i], storenames[i] - if "control" in event_name.lower() and "cntrl" in event.lower(): - logger.debug("Creating control channel from signal channel using curve-fitting") - name = event_name.split("_")[-1] - signal = read_hdf5("signal_" + name, filepath, "data") - timestampNew = read_hdf5("timeCorrection_" + name, filepath, "timestampNew") - sampling_rate = np.full(timestampNew.shape, np.nan) - sampling_rate[0] = read_hdf5("timeCorrection_" + name, filepath, "sampling_rate")[0] - - control = helper_create_control_channel(signal, timestampNew, window) - - write_hdf5(control, event_name, filepath, "data") - d = {"timestamps": timestampNew, "data": control, "sampling_rate": sampling_rate} - df = pd.DataFrame(d) - df.to_csv(os.path.join(os.path.dirname(filepath), event.lower() + ".csv"), index=False) - logger.info("Control channel from signal channel created using curve-fitting") - - -# function to add control channel when there is no -# isosbestic control channel and update the storeslist file -def add_control_channel(filepath, arr): - - storenames = arr[0, :] - storesList = np.char.lower(arr[1, :]) - - keep_control = np.array([]) - # check a case if there is isosbestic control channel present - for i in range(storesList.shape[0]): - if "control" in storesList[i].lower(): - name = storesList[i].split("_")[-1] - new_str = "signal_" + str(name).lower() - find_signal = [True for i in storesList if i == new_str] - if len(find_signal) > 1: - logger.error("Error in naming convention of files or Error in storesList file") - raise Exception("Error in naming convention of files or Error in storesList file") - if len(find_signal) == 0: - logger.error( - "Isosbectic control channel parameter is set to False and still \ - storeslist file shows there is control channel present" - ) - raise Exception( - "Isosbectic control channel parameter is set to False and still \ - storeslist file shows there is control channel present" - ) - else: - continue - - for i in range(storesList.shape[0]): - if "signal" in storesList[i].lower(): - name = storesList[i].split("_")[-1] - new_str = "control_" + str(name).lower() - find_signal = [True for i in storesList if i == new_str] - if len(find_signal) == 0: - src, dst = os.path.join(filepath, arr[0, i] + ".hdf5"), os.path.join( - filepath, "cntrl" + str(i) + ".hdf5" - ) - shutil.copyfile(src, dst) - arr = np.concatenate((arr, [["cntrl" + str(i)], ["control_" + str(arr[1, i].split("_")[-1])]]), axis=1) - - np.savetxt(os.path.join(filepath, "storesList.csv"), arr, delimiter=",", fmt="%s") - - return arr - - -# check if dealing with TDT files or csv files -def check_TDT(filepath): - path = glob.glob(os.path.join(filepath, "*.tsq")) - if len(path) > 0: - return True - else: - return False - - -# function to read hdf5 file 
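helper_create_control_channel above (removed here in favour of the imported guppy.analysis.control_channel helpers) synthesizes a control trace when no isosbestic channel was recorded: the signal is Savitzky-Golay smoothed and fit to a single exponential a + b*exp(-t/c), and the fitted curve stands in for the control. A minimal standalone sketch of that idea on made-up data, reusing the p0 = [5, 50, 60] initial guess from the code above; the sampling grid and noise level are arbitrary.

import numpy as np
from scipy.optimize import curve_fit
from scipy.signal import savgol_filter

def exp_decay(t, a, b, c):
    # Same functional form as curveFitFn above: a + b * exp(-t / c)
    return a + b * np.exp(-t / c)

rng = np.random.default_rng(0)
t = np.linspace(0, 600, 6000)                              # illustrative 10 min trace
signal = exp_decay(t, 5.0, 40.0, 120.0) + rng.normal(0, 0.5, t.size)

smoothed = savgol_filter(signal, window_length=101, polyorder=3)
popt, _ = curve_fit(exp_decay, t, smoothed, p0=[5, 50, 60])
synthetic_control = exp_decay(t, *popt)                    # used in place of an isosbestic channel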
-def read_hdf5(event, filepath, key): - if event: - event = event.replace("\\", "_") - event = event.replace("/", "_") - op = os.path.join(filepath, event + ".hdf5") - else: - op = filepath - - if os.path.exists(op): - with h5py.File(op, "r") as f: - arr = np.asarray(f[key]) - else: - logger.error(f"{event}.hdf5 file does not exist") - raise Exception("{}.hdf5 file does not exist".format(event)) - - return arr - - -# function to write hdf5 file -def write_hdf5(data, event, filepath, key): - event = event.replace("\\", "_") - event = event.replace("/", "_") - op = os.path.join(filepath, event + ".hdf5") - - # if file does not exist create a new file - if not os.path.exists(op): - with h5py.File(op, "w") as f: - if type(data) is np.ndarray: - f.create_dataset(key, data=data, maxshape=(None,), chunks=True) - else: - f.create_dataset(key, data=data) - - # if file already exists, append data to it or add a new key to it - else: - with h5py.File(op, "r+") as f: - if key in list(f.keys()): - if type(data) is np.ndarray: - f[key].resize(data.shape) - arr = f[key] - arr[:] = data - else: - arr = f[key] - arr = data - else: - if type(data) is np.ndarray: - f.create_dataset(key, data=data, maxshape=(None,), chunks=True) - else: - f.create_dataset(key, data=data) - - -# function to check control and signal channel has same length -# if not, take a smaller length and do pre-processing -def check_cntrl_sig_length(filepath, channels_arr, storenames, storesList): - - indices = [] - for i in range(channels_arr.shape[1]): - idx_c = np.where(storesList == channels_arr[0, i])[0] - idx_s = np.where(storesList == channels_arr[1, i])[0] - control = read_hdf5(storenames[idx_c[0]], filepath, "data") - signal = read_hdf5(storenames[idx_s[0]], filepath, "data") - if control.shape[0] < signal.shape[0]: - indices.append(storesList[idx_c[0]]) - elif control.shape[0] > signal.shape[0]: - indices.append(storesList[idx_s[0]]) - else: - indices.append(storesList[idx_s[0]]) - - return indices - - -# function to correct timestamps after eliminating first few seconds of the data (for csv data) -def timestampCorrection_csv(filepath, timeForLightsTurnOn, storesList): - - logger.debug( - f"Correcting timestamps by getting rid of the first {timeForLightsTurnOn} seconds and convert timestamps to seconds" - ) - storenames = storesList[0, :] - storesList = storesList[1, :] - - arr = [] - for i in range(storesList.shape[0]): - if "control" in storesList[i].lower() or "signal" in storesList[i].lower(): - arr.append(storesList[i]) - - arr = sorted(arr, key=str.casefold) - try: - arr = np.asarray(arr).reshape(2, -1) - except: - logger.error("Error in saving stores list file or spelling mistake for control or signal") - raise Exception("Error in saving stores list file or spelling mistake for control or signal") - - indices = check_cntrl_sig_length(filepath, arr, storenames, storesList) - - for i in range(arr.shape[1]): - name_1 = arr[0, i].split("_")[-1] - name_2 = arr[1, i].split("_")[-1] - # dirname = os.path.dirname(path[i]) - idx = np.where(storesList == indices[i])[0] - - if idx.shape[0] == 0: - logger.error(f"{arr[0,i]} does not exist in the stores list file.") - raise Exception("{} does not exist in the stores list file.".format(arr[0, i])) - - timestamp = read_hdf5(storenames[idx][0], filepath, "timestamps") - sampling_rate = read_hdf5(storenames[idx][0], filepath, "sampling_rate") - - if name_1 == name_2: - correctionIndex = np.where(timestamp >= timeForLightsTurnOn)[0] - timestampNew = timestamp[correctionIndex] - 
write_hdf5(timestampNew, "timeCorrection_" + name_1, filepath, "timestampNew") - write_hdf5(correctionIndex, "timeCorrection_" + name_1, filepath, "correctionIndex") - write_hdf5(np.asarray(sampling_rate), "timeCorrection_" + name_1, filepath, "sampling_rate") - - else: - logger.error("Error in naming convention of files or Error in storesList file") - raise Exception("Error in naming convention of files or Error in storesList file") - - logger.info("Timestamps corrected and converted to seconds.") - - -# function to correct timestamps after eliminating first few seconds of the data (for TDT data) -def timestampCorrection_tdt(filepath, timeForLightsTurnOn, storesList): - - logger.debug( - f"Correcting timestamps by getting rid of the first {timeForLightsTurnOn} seconds and convert timestamps to seconds" - ) - storenames = storesList[0, :] - storesList = storesList[1, :] - - arr = [] - for i in range(storesList.shape[0]): - if "control" in storesList[i].lower() or "signal" in storesList[i].lower(): - arr.append(storesList[i]) - - arr = sorted(arr, key=str.casefold) - - try: - arr = np.asarray(arr).reshape(2, -1) - except: - logger.error("Error in saving stores list file or spelling mistake for control or signal") - raise Exception("Error in saving stores list file or spelling mistake for control or signal") - - indices = check_cntrl_sig_length(filepath, arr, storenames, storesList) - - for i in range(arr.shape[1]): - name_1 = arr[0, i].split("_")[-1] - name_2 = arr[1, i].split("_")[-1] - # dirname = os.path.dirname(path[i]) - idx = np.where(storesList == indices[i])[0] - - if idx.shape[0] == 0: - logger.error(f"{arr[0,i]} does not exist in the stores list file.") - raise Exception("{} does not exist in the stores list file.".format(arr[0, i])) - - timestamp = read_hdf5(storenames[idx][0], filepath, "timestamps") - npoints = read_hdf5(storenames[idx][0], filepath, "npoints") - sampling_rate = read_hdf5(storenames[idx][0], filepath, "sampling_rate") - - if name_1 == name_2: - timeRecStart = timestamp[0] - timestamps = np.subtract(timestamp, timeRecStart) - adder = np.arange(npoints) / sampling_rate - lengthAdder = adder.shape[0] - timestampNew = np.zeros((len(timestamps), lengthAdder)) - for i in range(lengthAdder): - timestampNew[:, i] = np.add(timestamps, adder[i]) - timestampNew = (timestampNew.T).reshape(-1, order="F") - correctionIndex = np.where(timestampNew >= timeForLightsTurnOn)[0] - timestampNew = timestampNew[correctionIndex] - - write_hdf5(np.asarray([timeRecStart]), "timeCorrection_" + name_1, filepath, "timeRecStart") - write_hdf5(timestampNew, "timeCorrection_" + name_1, filepath, "timestampNew") - write_hdf5(correctionIndex, "timeCorrection_" + name_1, filepath, "correctionIndex") - write_hdf5(np.asarray([sampling_rate]), "timeCorrection_" + name_1, filepath, "sampling_rate") - else: - logger.error("Error in naming convention of files or Error in storesList file") - raise Exception("Error in naming convention of files or Error in storesList file") - - logger.info("Timestamps corrected and converted to seconds.") - # return timeRecStart, correctionIndex, timestampNew - - -# function to apply correction to control, signal and event timestamps -def applyCorrection(filepath, timeForLightsTurnOn, event, displayName, naming): - - cond = check_TDT(os.path.dirname(filepath)) - - if cond == True: - timeRecStart = read_hdf5("timeCorrection_" + naming, filepath, "timeRecStart")[0] - - timestampNew = read_hdf5("timeCorrection_" + naming, filepath, "timestampNew") - correctionIndex = 
read_hdf5("timeCorrection_" + naming, filepath, "correctionIndex") - - if "control" in displayName.lower() or "signal" in displayName.lower(): - split_name = displayName.split("_")[-1] - if split_name == naming: - pass - else: - correctionIndex = read_hdf5("timeCorrection_" + split_name, filepath, "correctionIndex") - arr = read_hdf5(event, filepath, "data") - if (arr == 0).all() == True: - arr = arr - else: - arr = arr[correctionIndex] - write_hdf5(arr, displayName, filepath, "data") - else: - arr = read_hdf5(event, filepath, "timestamps") - if cond == True: - res = (arr >= timeRecStart).all() - if res == True: - arr = np.subtract(arr, timeRecStart) - arr = np.subtract(arr, timeForLightsTurnOn) - else: - arr = np.subtract(arr, timeForLightsTurnOn) - else: - arr = np.subtract(arr, timeForLightsTurnOn) - write_hdf5(arr, displayName + "_" + naming, filepath, "ts") - - # if isosbestic_control==False and 'control' in displayName.lower(): - # control = create_control_channel(filepath, displayName) - # write_hdf5(control, displayName, filepath, 'data') - - -# function to check if naming convention was followed while saving storeslist file -# and apply timestamps correction using the function applyCorrection -def decide_naming_convention_and_applyCorrection(filepath, timeForLightsTurnOn, event, displayName, storesList): - - logger.debug("Applying correction of timestamps to the data and event timestamps") - storesList = storesList[1, :] - - arr = [] - for i in range(storesList.shape[0]): - if "control" in storesList[i].lower() or "signal" in storesList[i].lower(): - arr.append(storesList[i]) - - arr = sorted(arr, key=str.casefold) - arr = np.asarray(arr).reshape(2, -1) - - for i in range(arr.shape[1]): - name_1 = arr[0, i].split("_")[-1] - name_2 = arr[1, i].split("_")[-1] - # dirname = os.path.dirname(path[i]) - if name_1 == name_2: - applyCorrection(filepath, timeForLightsTurnOn, event, displayName, name_1) - else: - logger.error("Error in naming convention of files or Error in storesList file") - raise Exception("Error in naming convention of files or Error in storesList file") - - logger.info("Timestamps corrections applied to the data and event timestamps.") - - # function to plot z_score def visualize_z_score(filepath): @@ -590,421 +219,6 @@ def visualizeControlAndSignal(filepath, removeArtifacts): visualize(filepath, ts, control, signal, cntrl_sig_fit, plot_name, removeArtifacts) -# function to check if the naming convention for saving storeslist file was followed or not -def decide_naming_convention(filepath): - path_1 = find_files(filepath, "control_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'control*')) - - path_2 = find_files(filepath, "signal_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'signal*')) - - path = sorted(path_1 + path_2, key=str.casefold) - if len(path) % 2 != 0: - logger.error("There are not equal number of Control and Signal data") - raise Exception("There are not equal number of Control and Signal data") - - path = np.asarray(path).reshape(2, -1) - - return path - - -# function to read coordinates file which was saved by selecting chunks for artifacts removal -def fetchCoords(filepath, naming, data): - - path = os.path.join(filepath, "coordsForPreProcessing_" + naming + ".npy") - - if not os.path.exists(path): - coords = np.array([0, data[-1]]) - else: - coords = np.load(os.path.join(filepath, "coordsForPreProcessing_" + naming + ".npy"))[:, 0] - - if coords.shape[0] % 2 != 0: - logger.error("Number of values in coordsForPreProcessing file is 
not even.") - raise Exception("Number of values in coordsForPreProcessing file is not even.") - - coords = coords.reshape(-1, 2) - - return coords - - -# helper function to process control and signal timestamps -def eliminateData(filepath, timeForLightsTurnOn, event, sampling_rate, naming): - - ts = read_hdf5("timeCorrection_" + naming, filepath, "timestampNew") - data = read_hdf5(event, filepath, "data").reshape(-1) - coords = fetchCoords(filepath, naming, ts) - - if (data == 0).all() == True: - data = np.zeros(ts.shape[0]) - - arr = np.array([]) - ts_arr = np.array([]) - for i in range(coords.shape[0]): - - index = np.where((ts > coords[i, 0]) & (ts < coords[i, 1]))[0] - - if len(arr) == 0: - arr = np.concatenate((arr, data[index])) - sub = ts[index][0] - timeForLightsTurnOn - new_ts = ts[index] - sub - ts_arr = np.concatenate((ts_arr, new_ts)) - else: - temp = data[index] - # new = temp + (arr[-1]-temp[0]) - temp_ts = ts[index] - new_ts = temp_ts - (temp_ts[0] - ts_arr[-1]) - arr = np.concatenate((arr, temp)) - ts_arr = np.concatenate((ts_arr, new_ts + (1 / sampling_rate))) - - # logger.info(arr.shape, ts_arr.shape) - return arr, ts_arr - - -# helper function to align event timestamps with the control and signal timestamps -def eliminateTs(filepath, timeForLightsTurnOn, event, sampling_rate, naming): - - tsNew = read_hdf5("timeCorrection_" + naming, filepath, "timestampNew") - ts = read_hdf5(event + "_" + naming, filepath, "ts").reshape(-1) - coords = fetchCoords(filepath, naming, tsNew) - - ts_arr = np.array([]) - tsNew_arr = np.array([]) - for i in range(coords.shape[0]): - tsNew_index = np.where((tsNew > coords[i, 0]) & (tsNew < coords[i, 1]))[0] - ts_index = np.where((ts > coords[i, 0]) & (ts < coords[i, 1]))[0] - - if len(tsNew_arr) == 0: - sub = tsNew[tsNew_index][0] - timeForLightsTurnOn - tsNew_arr = np.concatenate((tsNew_arr, tsNew[tsNew_index] - sub)) - ts_arr = np.concatenate((ts_arr, ts[ts_index] - sub)) - else: - temp_tsNew = tsNew[tsNew_index] - temp_ts = ts[ts_index] - new_ts = temp_ts - (temp_tsNew[0] - tsNew_arr[-1]) - new_tsNew = temp_tsNew - (temp_tsNew[0] - tsNew_arr[-1]) - tsNew_arr = np.concatenate((tsNew_arr, new_tsNew + (1 / sampling_rate))) - ts_arr = np.concatenate((ts_arr, new_ts + (1 / sampling_rate))) - - return ts_arr - - -# adding nan values to removed chunks -# when using artifacts removal method - replace with NaN -def addingNaNValues(filepath, event, naming): - - ts = read_hdf5("timeCorrection_" + naming, filepath, "timestampNew") - data = read_hdf5(event, filepath, "data").reshape(-1) - coords = fetchCoords(filepath, naming, ts) - - if (data == 0).all() == True: - data = np.zeros(ts.shape[0]) - - arr = np.array([]) - ts_index = np.arange(ts.shape[0]) - for i in range(coords.shape[0]): - - index = np.where((ts > coords[i, 0]) & (ts < coords[i, 1]))[0] - arr = np.concatenate((arr, index)) - - nan_indices = list(set(ts_index).symmetric_difference(arr)) - data[nan_indices] = np.nan - - return data - - -# remove event TTLs which falls in the removed chunks -# when using artifacts removal method - replace with NaN -def removeTTLs(filepath, event, naming): - tsNew = read_hdf5("timeCorrection_" + naming, filepath, "timestampNew") - ts = read_hdf5(event + "_" + naming, filepath, "ts").reshape(-1) - coords = fetchCoords(filepath, naming, tsNew) - - ts_arr = np.array([]) - for i in range(coords.shape[0]): - ts_index = np.where((ts > coords[i, 0]) & (ts < coords[i, 1]))[0] - ts_arr = np.concatenate((ts_arr, ts[ts_index])) - - return ts_arr - - -def 
addingNaNtoChunksWithArtifacts(filepath, events): - - logger.debug("Replacing chunks with artifacts by NaN values.") - storesList = events[1, :] - - path = decide_naming_convention(filepath) - - for j in range(path.shape[1]): - name_1 = ((os.path.basename(path[0, j])).split(".")[0]).split("_") - name_2 = ((os.path.basename(path[1, j])).split(".")[0]).split("_") - # dirname = os.path.dirname(path[i]) - if name_1[-1] == name_2[-1]: - name = name_1[-1] - sampling_rate = read_hdf5("timeCorrection_" + name, filepath, "sampling_rate")[0] - for i in range(len(storesList)): - if ( - "control_" + name.lower() in storesList[i].lower() - or "signal_" + name.lower() in storesList[i].lower() - ): # changes done - data = addingNaNValues(filepath, storesList[i], name) - write_hdf5(data, storesList[i], filepath, "data") - else: - if "control" in storesList[i].lower() or "signal" in storesList[i].lower(): - continue - else: - ts = removeTTLs(filepath, storesList[i], name) - write_hdf5(ts, storesList[i] + "_" + name, filepath, "ts") - - else: - logger.error("Error in naming convention of files or Error in storesList file") - raise Exception("Error in naming convention of files or Error in storesList file") - logger.info("Chunks with artifacts are replaced by NaN values.") - - -# main function to align timestamps for control, signal and event timestamps for artifacts removal -def processTimestampsForArtifacts(filepath, timeForLightsTurnOn, events): - - logger.debug("Processing timestamps to get rid of artifacts using concatenate method...") - storesList = events[1, :] - - path = decide_naming_convention(filepath) - - timestamp_dict = dict() - for j in range(path.shape[1]): - name_1 = ((os.path.basename(path[0, j])).split(".")[0]).split("_") - name_2 = ((os.path.basename(path[1, j])).split(".")[0]).split("_") - # dirname = os.path.dirname(path[i]) - if name_1[-1] == name_2[-1]: - name = name_1[-1] - sampling_rate = read_hdf5("timeCorrection_" + name, filepath, "sampling_rate")[0] - - for i in range(len(storesList)): - if ( - "control_" + name.lower() in storesList[i].lower() - or "signal_" + name.lower() in storesList[i].lower() - ): # changes done - data, timestampNew = eliminateData( - filepath, timeForLightsTurnOn, storesList[i], sampling_rate, name - ) - write_hdf5(data, storesList[i], filepath, "data") - else: - if "control" in storesList[i].lower() or "signal" in storesList[i].lower(): - continue - else: - ts = eliminateTs(filepath, timeForLightsTurnOn, storesList[i], sampling_rate, name) - write_hdf5(ts, storesList[i] + "_" + name, filepath, "ts") - - # timestamp_dict[name] = timestampNew - write_hdf5(timestampNew, "timeCorrection_" + name, filepath, "timestampNew") - else: - logger.error("Error in naming convention of files or Error in storesList file") - raise Exception("Error in naming convention of files or Error in storesList file") - logger.info("Timestamps processed, artifacts are removed and good chunks are concatenated.") - - -# function to compute deltaF/F using fitted control channel and filtered signal channel -def deltaFF(signal, control): - - res = np.subtract(signal, control) - normData = np.divide(res, control) - # deltaFF = normData - normData = normData * 100 - - return normData - - -# function to fit control channel to signal channel -def controlFit(control, signal): - - p = np.polyfit(control, signal, 1) - arr = (p[0] * control) + p[1] - return arr - - -def filterSignal(filter_window, signal): - if filter_window == 0: - return signal - elif filter_window > 1: - b = 
np.divide(np.ones((filter_window,)), filter_window) - a = 1 - filtered_signal = ss.filtfilt(b, a, signal) - return filtered_signal - else: - raise Exception("Moving average filter window value is not correct.") - - -# function to filter control and signal channel, also execute above two function : controlFit and deltaFF -# function will also take care if there is only signal channel and no control channel -# if there is only signal channel, z-score will be computed using just signal channel -def execute_controlFit_dff(control, signal, isosbestic_control, filter_window): - - if isosbestic_control == False: - signal_smooth = filterSignal(filter_window, signal) # ss.filtfilt(b, a, signal) - control_fit = controlFit(control, signal_smooth) - norm_data = deltaFF(signal_smooth, control_fit) - else: - control_smooth = filterSignal(filter_window, control) # ss.filtfilt(b, a, control) - signal_smooth = filterSignal(filter_window, signal) # ss.filtfilt(b, a, signal) - control_fit = controlFit(control_smooth, signal_smooth) - norm_data = deltaFF(signal_smooth, control_fit) - - return norm_data, control_fit - - -# function to compute z-score based on z-score computation method -def z_score_computation(dff, timestamps, inputParameters): - - zscore_method = inputParameters["zscore_method"] - baseline_start, baseline_end = inputParameters["baselineWindowStart"], inputParameters["baselineWindowEnd"] - - if zscore_method == "standard z-score": - numerator = np.subtract(dff, np.nanmean(dff)) - zscore = np.divide(numerator, np.nanstd(dff)) - elif zscore_method == "baseline z-score": - idx = np.where((timestamps > baseline_start) & (timestamps < baseline_end))[0] - if idx.shape[0] == 0: - logger.error( - "Baseline Window Parameters for baseline z-score computation zscore_method \ - are not correct." - ) - raise Exception( - "Baseline Window Parameters for baseline z-score computation zscore_method \ - are not correct." 
- ) - else: - baseline_mean = np.nanmean(dff[idx]) - baseline_std = np.nanstd(dff[idx]) - numerator = np.subtract(dff, baseline_mean) - zscore = np.divide(numerator, baseline_std) - else: - median = np.median(dff) - mad = np.median(np.abs(dff - median)) - numerator = 0.6745 * (dff - median) - zscore = np.divide(numerator, mad) - - return zscore - - -# helper function to compute z-score and deltaF/F -def helper_z_score(control, signal, filepath, name, inputParameters): # helper_z_score(control_smooth, signal_smooth): - - removeArtifacts = inputParameters["removeArtifacts"] - artifactsRemovalMethod = inputParameters["artifactsRemovalMethod"] - filter_window = inputParameters["filter_window"] - - isosbestic_control = inputParameters["isosbestic_control"] - tsNew = read_hdf5("timeCorrection_" + name, filepath, "timestampNew") - coords_path = os.path.join(filepath, "coordsForPreProcessing_" + name + ".npy") - - logger.info("Remove Artifacts : ", removeArtifacts) - - if (control == 0).all() == True: - control = np.zeros(tsNew.shape[0]) - - z_score_arr = np.array([]) - norm_data_arr = np.full(tsNew.shape[0], np.nan) - control_fit_arr = np.full(tsNew.shape[0], np.nan) - temp_control_arr = np.full(tsNew.shape[0], np.nan) - - if removeArtifacts == True: - coords = fetchCoords(filepath, name, tsNew) - - # for artifacts removal, each chunk which was selected by user is being processed individually and then - # z-score is calculated - for i in range(coords.shape[0]): - tsNew_index = np.where((tsNew > coords[i, 0]) & (tsNew < coords[i, 1]))[0] - if isosbestic_control == False: - control_arr = helper_create_control_channel(signal[tsNew_index], tsNew[tsNew_index], window=101) - signal_arr = signal[tsNew_index] - norm_data, control_fit = execute_controlFit_dff( - control_arr, signal_arr, isosbestic_control, filter_window - ) - temp_control_arr[tsNew_index] = control_arr - if i < coords.shape[0] - 1: - blank_index = np.where((tsNew > coords[i, 1]) & (tsNew < coords[i + 1, 0]))[0] - temp_control_arr[blank_index] = np.full(blank_index.shape[0], np.nan) - else: - control_arr = control[tsNew_index] - signal_arr = signal[tsNew_index] - norm_data, control_fit = execute_controlFit_dff( - control_arr, signal_arr, isosbestic_control, filter_window - ) - norm_data_arr[tsNew_index] = norm_data - control_fit_arr[tsNew_index] = control_fit - - if artifactsRemovalMethod == "concatenate": - norm_data_arr = norm_data_arr[~np.isnan(norm_data_arr)] - control_fit_arr = control_fit_arr[~np.isnan(control_fit_arr)] - z_score = z_score_computation(norm_data_arr, tsNew, inputParameters) - z_score_arr = np.concatenate((z_score_arr, z_score)) - else: - tsNew_index = np.arange(tsNew.shape[0]) - norm_data, control_fit = execute_controlFit_dff(control, signal, isosbestic_control, filter_window) - z_score = z_score_computation(norm_data, tsNew, inputParameters) - z_score_arr = np.concatenate((z_score_arr, z_score)) - norm_data_arr[tsNew_index] = norm_data # np.concatenate((norm_data_arr, norm_data)) - control_fit_arr[tsNew_index] = control_fit # np.concatenate((control_fit_arr, control_fit)) - - # handle the case if there are chunks being cut in the front and the end - if isosbestic_control == False and removeArtifacts == True: - coords = coords.flatten() - # front chunk - idx = np.where((tsNew >= tsNew[0]) & (tsNew < coords[0]))[0] - temp_control_arr[idx] = np.full(idx.shape[0], np.nan) - # end chunk - idx = np.where((tsNew > coords[-1]) & (tsNew <= tsNew[-1]))[0] - temp_control_arr[idx] = np.full(idx.shape[0], np.nan) - 
write_hdf5(temp_control_arr, "control_" + name, filepath, "data") - - return z_score_arr, norm_data_arr, control_fit_arr - - -# compute z-score and deltaF/F and save it to hdf5 file -def compute_z_score(filepath, inputParameters): - - logger.debug(f"Computing z-score for each of the data in {filepath}") - remove_artifacts = inputParameters["removeArtifacts"] - - path_1 = find_files(filepath, "control_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'control*')) - path_2 = find_files(filepath, "signal_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'signal*')) - - path = sorted(path_1 + path_2, key=str.casefold) - - b = np.divide(np.ones((100,)), 100) - a = 1 - - if len(path) % 2 != 0: - logger.error("There are not equal number of Control and Signal data") - raise Exception("There are not equal number of Control and Signal data") - - path = np.asarray(path).reshape(2, -1) - - for i in range(path.shape[1]): - name_1 = ((os.path.basename(path[0, i])).split(".")[0]).split("_") - name_2 = ((os.path.basename(path[1, i])).split(".")[0]).split("_") - # dirname = os.path.dirname(path[i]) - - if name_1[-1] == name_2[-1]: - name = name_1[-1] - control = read_hdf5("", path[0, i], "data").reshape(-1) - signal = read_hdf5("", path[1, i], "data").reshape(-1) - # control_smooth = ss.filtfilt(b, a, control) - # signal_smooth = ss.filtfilt(b, a, signal) - # _score, dff = helper_z_score(control_smooth, signal_smooth) - z_score, dff, control_fit = helper_z_score(control, signal, filepath, name, inputParameters) - if remove_artifacts == True: - write_hdf5(z_score, "z_score_" + name, filepath, "data") - write_hdf5(dff, "dff_" + name, filepath, "data") - write_hdf5(control_fit, "cntrl_sig_fit_" + name, filepath, "data") - else: - write_hdf5(z_score, "z_score_" + name, filepath, "data") - write_hdf5(dff, "dff_" + name, filepath, "data") - write_hdf5(control_fit, "cntrl_sig_fit_" + name, filepath, "data") - else: - logger.error("Error in naming convention of files or Error in storesList file") - raise Exception("Error in naming convention of files or Error in storesList file") - - logger.info(f"z-score for the data in {filepath} computed.") - - # function to execute timestamps corrections using functions timestampCorrection and decide_naming_convention_and_applyCorrection def execute_timestamp_correction(folderNames, inputParameters): @@ -1014,7 +228,7 @@ def execute_timestamp_correction(folderNames, inputParameters): for i in range(len(folderNames)): filepath = folderNames[i] storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) - cond = check_TDT(folderNames[i]) + mode = "tdt" if check_TDT(folderNames[i]) else "csv" logger.debug(f"Timestamps corrections started for {filepath}") for j in range(len(storesListPath)): filepath = storesListPath[j] @@ -1025,15 +239,36 @@ def execute_timestamp_correction(folderNames, inputParameters): if isosbestic_control == False: storesList = add_control_channel(filepath, storesList) - if cond == True: - timestampCorrection_tdt(filepath, timeForLightsTurnOn, storesList) - else: - timestampCorrection_csv(filepath, timeForLightsTurnOn, storesList) - - for k in range(storesList.shape[1]): - decide_naming_convention_and_applyCorrection( - filepath, timeForLightsTurnOn, storesList[0, k], storesList[1, k], storesList - ) + control_and_signal_dicts = read_control_and_signal(filepath, storesList) + name_to_data, name_to_timestamps, name_to_sampling_rate, name_to_npoints = control_and_signal_dicts + name_to_timestamps_ttl = read_ttl(filepath, 
storesList) + + timestamps_dicts = correct_timestamps( + timeForLightsTurnOn, + storesList, + name_to_timestamps, + name_to_data, + name_to_sampling_rate, + name_to_npoints, + name_to_timestamps_ttl, + mode=mode, + ) + ( + name_to_corrected_timestamps, + name_to_correctionIndex, + name_to_corrected_data, + compound_name_to_corrected_ttl_timestamps, + ) = timestamps_dicts + + write_corrected_timestamps( + filepath, + name_to_corrected_timestamps, + name_to_timestamps, + name_to_sampling_rate, + name_to_correctionIndex, + ) + write_corrected_data(filepath, name_to_corrected_data) + write_corrected_ttl_timestamps(filepath, compound_name_to_corrected_ttl_timestamps) # check if isosbestic control is false and also if new control channel is added if isosbestic_control == False: @@ -1044,45 +279,133 @@ def execute_timestamp_correction(folderNames, inputParameters): logger.info(f"Timestamps corrections finished for {filepath}") -# for combining data, reading storeslist file from both data and create a new storeslist array -def check_storeslistfile(folderNames): - storesList = np.array([[], []]) +# function to compute z-score and deltaF/F +def execute_zscore(folderNames, inputParameters): + + plot_zScore_dff = inputParameters["plot_zScore_dff"] + combine_data = inputParameters["combine_data"] + remove_artifacts = inputParameters["removeArtifacts"] + artifactsRemovalMethod = inputParameters["artifactsRemovalMethod"] + filter_window = inputParameters["filter_window"] + isosbestic_control = inputParameters["isosbestic_control"] + zscore_method = inputParameters["zscore_method"] + baseline_start, baseline_end = inputParameters["baselineWindowStart"], inputParameters["baselineWindowEnd"] + + storesListPath = [] for i in range(len(folderNames)): - filepath = folderNames[i] - storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) - for j in range(len(storesListPath)): - filepath = storesListPath[j] - storesList = np.concatenate( - ( - storesList, - np.genfromtxt(os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=",").reshape(2, -1), - ), - axis=1, + if combine_data == True: + storesListPath.append([folderNames[i][0]]) + else: + filepath = folderNames[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = np.concatenate(storesListPath) + + for j in range(len(storesListPath)): + filepath = storesListPath[j] + logger.debug(f"Computing z-score for each of the data in {filepath}") + path_1 = find_files(filepath, "control_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'control*')) + path_2 = find_files(filepath, "signal_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'signal*')) + path = sorted(path_1 + path_2, key=str.casefold) + if len(path) % 2 != 0: + logger.error("There are not equal number of Control and Signal data") + raise Exception("There are not equal number of Control and Signal data") + path = np.asarray(path).reshape(2, -1) + + for i in range(path.shape[1]): + name_1 = ((os.path.basename(path[0, i])).split(".")[0]).split("_") + name_2 = ((os.path.basename(path[1, i])).split(".")[0]).split("_") + if name_1[-1] != name_2[-1]: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + name = name_1[-1] + + control, signal, tsNew = read_corrected_data(path[0, i], path[1, i], filepath, name) + coords = get_coords(filepath, name, tsNew, remove_artifacts) + z_score, dff, control_fit, 
temp_control_arr = compute_z_score( + control, + signal, + tsNew, + coords, + artifactsRemovalMethod, + filter_window, + isosbestic_control, + zscore_method, + baseline_start, + baseline_end, ) + write_zscore(filepath, name, z_score, dff, control_fit, temp_control_arr) - storesList = np.unique(storesList, axis=1) + logger.info(f"z-score for the data in {filepath} computed.") - return storesList + if not remove_artifacts: + visualizeControlAndSignal(filepath, removeArtifacts=remove_artifacts) + if plot_zScore_dff == "z_score": + visualize_z_score(filepath) + if plot_zScore_dff == "dff": + visualize_dff(filepath) + if plot_zScore_dff == "Both": + visualize_z_score(filepath) + visualize_dff(filepath) -def get_all_stores_for_combining_data(folderNames): - op = [] - for i in range(100): - temp = [] - match = r"[\s\S]*" + "_output_" + str(i) - for j in folderNames: - temp.append(re.findall(match, j)) - temp = sorted(list(np.concatenate(temp).flatten()), key=str.casefold) - if len(temp) > 0: - op.append(temp) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + + plt.show() + logger.info("Z-score computation completed.") - return op + +# function to remove artifacts from z-score data +def execute_artifact_removal(folderNames, inputParameters): + + timeForLightsTurnOn = inputParameters["timeForLightsTurnOn"] + artifactsRemovalMethod = inputParameters["artifactsRemovalMethod"] + combine_data = inputParameters["combine_data"] + + storesListPath = [] + for i in range(len(folderNames)): + if combine_data == True: + storesListPath.append([folderNames[i][0]]) + else: + filepath = folderNames[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + + storesListPath = np.concatenate(storesListPath) + + for j in range(len(storesListPath)): + filepath = storesListPath[j] + storesList = np.genfromtxt(os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=",").reshape(2, -1) + + name_to_data = read_corrected_data_dict(filepath, storesList) + pair_name_to_tsNew, pair_name_to_sampling_rate = read_corrected_timestamps_pairwise(filepath) + pair_name_to_coords = read_coords_pairwise(filepath, pair_name_to_tsNew) + compound_name_to_ttl_timestamps = read_corrected_ttl_timestamps(filepath, storesList) + + logger.debug("Removing artifacts from the data...") + name_to_data, pair_name_to_timestamps, compound_name_to_ttl_timestamps = remove_artifacts( + timeForLightsTurnOn, + storesList, + pair_name_to_tsNew, + pair_name_to_sampling_rate, + pair_name_to_coords, + name_to_data, + compound_name_to_ttl_timestamps, + method=artifactsRemovalMethod, + ) + + write_artifact_removal(filepath, name_to_data, pair_name_to_timestamps, compound_name_to_ttl_timestamps) + visualizeControlAndSignal(filepath, removeArtifacts=True) + + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + + plt.show() + logger.info("Artifact removal completed.") # function to combine data when there are two different data files for the same recording session # it will combine the data, do timestamps processing and save the combined data in the first output folder. 
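The new execute_zscore above delegates the numeric work to the imported compute_z_score, while the helpers removed elsewhere in this diff (controlFit, deltaFF, z_score_computation) show the underlying arithmetic: a first-order least-squares fit of the control onto the signal, dF/F = 100 * (signal - fit) / fit, then a z-score of the dF/F trace. A minimal sketch of those three steps on synthetic traces follows; the refactored module additionally takes the filter window, artifact coordinates and the baseline / modified z-score variants as parameters, which this sketch omits.

import numpy as np

def control_fit(control, signal):
    p = np.polyfit(control, signal, 1)      # scale/offset the control onto the signal
    return p[0] * control + p[1]

def delta_f_over_f(signal, fitted_control):
    return 100.0 * (signal - fitted_control) / fitted_control

def standard_zscore(dff):
    return (dff - np.nanmean(dff)) / np.nanstd(dff)

rng = np.random.default_rng(1)
control = 10.0 + rng.normal(0, 0.1, 1000)                  # synthetic isosbestic-like trace
signal = 2.0 * control + 1.0 + rng.normal(0, 0.2, 1000)    # synthetic signal trace

dff = delta_f_over_f(signal, control_fit(control, signal))
z = standard_zscore(dff)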
-def combineData(folderNames, inputParameters, storesList): - +def execute_combine_data(folderNames, inputParameters, storesList): logger.debug("Combining Data from different data files...") timeForLightsTurnOn = inputParameters["timeForLightsTurnOn"] op_folder = [] @@ -1117,64 +440,28 @@ def combineData(folderNames, inputParameters, storesList): op = get_all_stores_for_combining_data(op_folder) # processing timestamps for combining the data - processTimestampsForCombiningData(op, timeForLightsTurnOn, storesList, sampling_rate[0]) + for filepaths_to_combine in op: + pair_name_to_filepath_to_timestamps = read_timestamps_for_combining_data(filepaths_to_combine) + display_name_to_filepath_to_data = read_data_for_combining_data(filepaths_to_combine, storesList) + compound_name_to_filepath_to_ttl_timestamps = read_ttl_timestamps_for_combining_data( + filepaths_to_combine, storesList + ) + pair_name_to_tsNew, display_name_to_data, compound_name_to_ttl_timestamps = combine_data( + filepaths_to_combine, + pair_name_to_filepath_to_timestamps, + display_name_to_filepath_to_data, + compound_name_to_filepath_to_ttl_timestamps, + timeForLightsTurnOn, + storesList, + sampling_rate[0], + ) + output_filepath = filepaths_to_combine[0] + write_combined_data(output_filepath, pair_name_to_tsNew, display_name_to_data, compound_name_to_ttl_timestamps) logger.info("Data is combined from different data files.") return op -# function to compute z-score and deltaF/F using functions : compute_z_score and/or processTimestampsForArtifacts -def execute_zscore(folderNames, inputParameters): - - timeForLightsTurnOn = inputParameters["timeForLightsTurnOn"] - remove_artifacts = inputParameters["removeArtifacts"] - artifactsRemovalMethod = inputParameters["artifactsRemovalMethod"] - plot_zScore_dff = inputParameters["plot_zScore_dff"] - combine_data = inputParameters["combine_data"] - isosbestic_control = inputParameters["isosbestic_control"] - - storesListPath = [] - for i in range(len(folderNames)): - if combine_data == True: - storesListPath.append([folderNames[i][0]]) - else: - filepath = folderNames[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) - - storesListPath = np.concatenate(storesListPath) - - for j in range(len(storesListPath)): - filepath = storesListPath[j] - storesList = np.genfromtxt(os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=",").reshape(2, -1) - - if remove_artifacts == True: - logger.debug("Removing Artifacts from the data and correcting timestamps...") - compute_z_score(filepath, inputParameters) - if artifactsRemovalMethod == "concatenate": - processTimestampsForArtifacts(filepath, timeForLightsTurnOn, storesList) - else: - addingNaNtoChunksWithArtifacts(filepath, storesList) - visualizeControlAndSignal(filepath, remove_artifacts) - logger.info("Artifacts from the data are removed and timestamps are corrected.") - else: - compute_z_score(filepath, inputParameters) - visualizeControlAndSignal(filepath, remove_artifacts) - - if plot_zScore_dff == "z_score": - visualize_z_score(filepath) - if plot_zScore_dff == "dff": - visualize_dff(filepath) - if plot_zScore_dff == "Both": - visualize_z_score(filepath) - visualize_dff(filepath) - - writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") - inputParameters["step"] += 1 - - plt.show() - logger.info("Signal data and event timestamps are extracted.") - - def extractTsAndSignal(inputParameters): logger.debug("Extracting signal data and event timestamps...") @@ -1203,13 +490,17 @@ def 
extractTsAndSignal(inputParameters): writeToFile(str((pbMaxValue + 1) * 10) + "\n" + str(10) + "\n") execute_timestamp_correction(folderNames, inputParameters) execute_zscore(folderNames, inputParameters) + if remove_artifacts == True: + execute_artifact_removal(folderNames, inputParameters) else: pbMaxValue = 1 + len(folderNames) writeToFile(str((pbMaxValue) * 10) + "\n" + str(10) + "\n") execute_timestamp_correction(folderNames, inputParameters) storesList = check_storeslistfile(folderNames) - op_folder = combineData(folderNames, inputParameters, storesList) + op_folder = execute_combine_data(folderNames, inputParameters, storesList) execute_zscore(op_folder, inputParameters) + if remove_artifacts == True: + execute_artifact_removal(op_folder, inputParameters) def main(input_parameters): diff --git a/src/guppy/testing/api.py b/src/guppy/testing/api.py index c647907..98939cf 100644 --- a/src/guppy/testing/api.py +++ b/src/guppy/testing/api.py @@ -268,6 +268,7 @@ def step4( npm_timestamp_column_names: list[str | None] | None = None, npm_time_units: list[str] | None = None, npm_split_events: list[bool] | None = None, + combine_data: bool = False, ) -> None: """ Run pipeline Step 4 (Extract timestamps and signal) via the Panel-backed logic, headlessly. @@ -293,6 +294,8 @@ def step4( List of time units for NPM files, one per CSV file (e.g., 'seconds', 'milliseconds'). None if not applicable. npm_split_events : list[bool] | None List of booleans indicating whether to split events for NPM files, one per CSV file. None if not applicable. + combine_data : bool + Whether to enable data combining logic in Step 4. Raises ------ @@ -345,6 +348,9 @@ def step4( # Inject modality input_params["modality"] = modality + # Inject combine_data + input_params["combine_data"] = combine_data + # Call the underlying Step 4 worker directly (no subprocess) extractTsAndSignal(input_params) diff --git a/tests/test_combine_data.py b/tests/test_combine_data.py new file mode 100644 index 0000000..f7c0261 --- /dev/null +++ b/tests/test_combine_data.py @@ -0,0 +1,138 @@ +import glob +import os +import shutil +from pathlib import Path + +import h5py +import pytest + +from guppy.testing.api import step2, step3, step4, step5 + + +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_combine_data(tmp_path, monkeypatch): + session_subdirs = [ + "SampleData_Clean/Photo_63_207-181030-103332", + "SampleData_with_artifacts/Photo_048_392-200728-121222", + ] + storenames_map = { + "Dv1A": "control_dms", + "Dv2A": "signal_dms", + "PrtN": "port_entries_dms", + } + expected_region = "dms" + expected_ttl = "port_entries_dms" + modality = "tdt" + + npm_timestamp_column_names = None + npm_time_units = None + npm_split_events = [True, True] + + # Use the CSV sample session + src_base_dir = str(Path(".") / "testing_data") + src_sessions = [os.path.join(src_base_dir, session_subdir) for session_subdir in session_subdirs] + for src_session in src_sessions: + if not os.path.isdir(src_session): + pytest.skip(f"Sample data not available at expected path: {src_session}") + + # Stub matplotlib.pyplot.show to avoid GUI blocking + import matplotlib.pyplot as plt # noqa: F401 + + monkeypatch.setattr("matplotlib.pyplot.show", lambda *args, **kwargs: None) + + # Stage a clean copy of the session into a temporary workspace + tmp_base = tmp_path / "data_root" + tmp_base.mkdir(parents=True, exist_ok=True) + session_copies = [] + for src_session in src_sessions: + dest_name = os.path.basename(src_session) + session_copy = tmp_base / dest_name + 
shutil.copytree(src_session, session_copy)
+        session_copies.append(session_copy)
+
+    for session_copy in session_copies:
+        # Remove any copied artifacts in the temp session (match only this session's output dirs)
+        for d in glob.glob(os.path.join(session_copy, f"{os.path.basename(session_copy)}_output_*")):
+            assert os.path.isdir(d), f"Expected output directory for cleanup, got non-directory: {d}"
+            shutil.rmtree(d)
+        params_fp = session_copy / "GuPPyParamtersUsed.json"
+        if params_fp.exists():
+            params_fp.unlink()
+
+    selected_folders = [str(session_copy) for session_copy in session_copies]
+    base_dir = str(tmp_base)
+
+    # Step 2: create storesList.csv in the temp copy
+    step2(
+        base_dir=base_dir,
+        selected_folders=selected_folders,
+        storenames_map=storenames_map,
+        modality=modality,
+        npm_timestamp_column_names=npm_timestamp_column_names,
+        npm_time_units=npm_time_units,
+        npm_split_events=npm_split_events,
+    )
+
+    # Step 3: read raw data in the temp copy
+    step3(
+        base_dir=base_dir,
+        selected_folders=selected_folders,
+        modality=modality,
+        npm_timestamp_column_names=npm_timestamp_column_names,
+        npm_time_units=npm_time_units,
+        npm_split_events=npm_split_events,
+    )
+
+    # Step 4: extract timestamps and signal in the temp copy
+    step4(
+        base_dir=base_dir,
+        selected_folders=selected_folders,
+        modality=modality,
+        npm_timestamp_column_names=npm_timestamp_column_names,
+        npm_time_units=npm_time_units,
+        npm_split_events=npm_split_events,
+        combine_data=True,
+    )
+
+    # Step 5: compute PSTH in the temp copy (headless)
+    step5(
+        base_dir=str(tmp_base),
+        selected_folders=[str(session_copy)],
+        modality=modality,
+        npm_timestamp_column_names=npm_timestamp_column_names,
+        npm_time_units=npm_time_units,
+        npm_split_events=npm_split_events,
+    )
+
+    # Validate outputs exist in the temp copy
+    session_copy = selected_folders[0]  # Outputs are written to the first session folder
+    basename = os.path.basename(session_copy)
+    output_dirs = sorted(glob.glob(os.path.join(session_copy, f"{basename}_output_*")))
+    assert output_dirs, f"No output directories found in {session_copy}"
+    out_dir = None
+    for d in output_dirs:
+        if os.path.exists(os.path.join(d, "storesList.csv")):
+            out_dir = d
+            break
+    assert out_dir is not None, f"No storesList.csv found in any output directory under {session_copy}"
+    stores_fp = os.path.join(out_dir, "storesList.csv")
+    assert os.path.exists(stores_fp), "Missing storesList.csv after Step 2/3/4"
+
+    # Ensure timeCorrection_<region>.hdf5 exists with 'timestampNew'
+    timecorr = os.path.join(out_dir, f"timeCorrection_{expected_region}.hdf5")
+    assert os.path.exists(timecorr), f"Missing {timecorr}"
+    with h5py.File(timecorr, "r") as f:
+        assert "timestampNew" in f, f"Expected 'timestampNew' dataset in {timecorr}"
+
+    # If TTLs exist, check their per-region 'ts' outputs
+    if expected_ttl is None:
+        expected_ttls = []
+    elif isinstance(expected_ttl, str):
+        expected_ttls = [expected_ttl]
+    else:
+        expected_ttls = expected_ttl
+    for expected_ttl in expected_ttls:
+        ttl_fp = os.path.join(out_dir, f"{expected_ttl}_{expected_region}.hdf5")
+        assert os.path.exists(ttl_fp), f"Missing TTL-aligned file {ttl_fp}"
+        with h5py.File(ttl_fp, "r") as f:
+            assert "ts" in f, f"Expected 'ts' dataset in {ttl_fp}"
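
A quick, informal way to inspect the combined outputs outside of pytest is to open the same HDF5 files the test above asserts on. The sketch below is illustrative only: it assumes the layout exercised by test_combine_data (an *_output_* directory containing timeCorrection_dms.hdf5 with a 'timestampNew' dataset and port_entries_dms_dms.hdf5 with a 'ts' dataset), and session_dir is a placeholder path, not part of this change set.

import glob
import os

import h5py

# Placeholder: a session folder after running step2 through step5 with combine_data=True
session_dir = "/path/to/Photo_048_392-200728-121222"
basename = os.path.basename(session_dir)

for out_dir in sorted(glob.glob(os.path.join(session_dir, f"{basename}_output_*"))):
    timecorr = os.path.join(out_dir, "timeCorrection_dms.hdf5")
    if not os.path.exists(timecorr):
        continue
    with h5py.File(timecorr, "r") as f:
        # 'timestampNew' is the corrected timestamp array the test asserts on
        print(out_dir, "timestampNew:", f["timestampNew"].shape)
    ttl_fp = os.path.join(out_dir, "port_entries_dms_dms.hdf5")
    if os.path.exists(ttl_fp):
        with h5py.File(ttl_fp, "r") as f:
            print(out_dir, "port_entries ts:", f["ts"].shape)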