diff --git a/pyproject.toml b/pyproject.toml
index 4527fba..b0973cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,6 +55,7 @@ dependencies = [
     "platformdirs",
     "scipy",
     "tables",
+    "pynwb",
 ]
 
 [dependency-groups]
diff --git a/src/guppy/preprocess.py b/src/guppy/preprocess.py
index 8b79039..7ab8c24 100755
--- a/src/guppy/preprocess.py
+++ b/src/guppy/preprocess.py
@@ -163,8 +163,9 @@ def add_control_channel(filepath, arr):
 
 
 # check if dealing with TDT files or csv files
+# NWB files are treated like TDT files
 def check_TDT(filepath):
-    path = glob.glob(os.path.join(filepath, "*.tsq"))
+    path = glob.glob(os.path.join(filepath, "*.tsq")) + glob.glob(os.path.join(filepath, "*.nwb"))
     if len(path) > 0:
         return True
     else:
diff --git a/src/guppy/readTevTsq.py b/src/guppy/readTevTsq.py
index 6deb3b1..2525753 100755
--- a/src/guppy/readTevTsq.py
+++ b/src/guppy/readTevTsq.py
@@ -8,6 +8,7 @@
 import time
 import warnings
 from itertools import repeat
+from pathlib import Path
 
 import h5py
 import numpy as np
@@ -471,6 +472,8 @@ def readRawData(inputParameters):
     logger.debug("### Reading raw data... ###")
     # get input parameters
     inputParameters = inputParameters
+    nwb_response_series_names = inputParameters["nwb_response_series_names"]
+    nwb_response_series_indices = inputParameters["nwb_response_series_indices"]
     folderNames = inputParameters["folderNames"]
     numProcesses = inputParameters["numberOfCores"]
     storesListPath = []
@@ -490,6 +493,8 @@
     step = 0
     for i in range(len(folderNames)):
         filepath = folderNames[i]
+        nwb_response_series_name = nwb_response_series_names[i]
+        indices = nwb_response_series_indices[i]
         logger.debug(f"### Reading raw data for folder {folderNames[i]}")
         storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))
         # reading tsq file
@@ -499,6 +504,8 @@
             pass
         else:
             flag = check_doric(filepath)
+            if flag == 0:  # doric file(s) not found
+                flag = check_nwb(filepath)
 
         # read data corresponding to each storename selected by user while saving the storeslist file
         for j in range(len(storesListPath)):
@@ -518,6 +525,9 @@
                 execute_import_doric(filepath, storesList, flag, op)
             elif flag == "doric_doric":
                 execute_import_doric(filepath, storesList, flag, op)
+            elif flag == "nwb":
+                filepath = Path(filepath)
+                read_nwb(filepath, op, nwb_response_series_name, indices)
             else:
                 execute_import_csv(filepath, np.unique(storesList[0, :]), op, numProcesses)
 
@@ -528,6 +538,99 @@
     logger.info("#" * 400)
 
 
+def check_nwb(filepath: str):
+    """
+    Check if an NWB file is present at the given location.
+
+    Parameters
+    ----------
+    filepath : str
+        Path to the folder containing the NWB file.
+
+    Returns
+    -------
+    flag : str or int
+        'nwb' if exactly one NWB file is present, 0 if none is found.
+
+    Raises
+    ------
+    Exception
+        If more than one NWB file is present at the location.
+    """
+    nwbfile_paths = glob.glob(os.path.join(filepath, "*.nwb"))
+    if len(nwbfile_paths) > 1:
+        logging.error("Multiple NWB files are present at the location.")
+        raise Exception("Multiple NWB files are present at the location.")
+    elif len(nwbfile_paths) == 0:
+        logging.error("\033[1m" + "NWB file not found." + "\033[0m")
+        return 0
+    else:
+        flag = "nwb"
+        return flag
+
+
+def read_nwb(filepath: str, outputPath: str, response_series_name: str, indices: list[int], npoints: int = 128):
+    """
+    Read photometry data from an NWB file and save the output to an HDF5 file.
+
+    Parameters
+    ----------
+    filepath : str
+        Path to the folder containing the NWB file.
+    outputPath : str
+        Path to the folder where the output data will be saved.
+    response_series_name : str
+        Name of the response series in the NWB file.
+    indices : list[int]
+        Indices of the channels (columns) of the response series to read.
+    npoints : int, optional
+        Number of points for each chunk. Timestamps are only saved for the first point in each chunk. Default is 128.
+
+    Raises
+    ------
+    Exception
+        If more than one NWB file is present at the location, or if the response series has neither a sampling rate nor timestamps.
+    """
+    # Dynamic import is necessary since pynwb isn't available in the main environment (Python 3.6)
+    from pynwb import NWBHDF5IO
+
+    nwbfilepath = glob.glob(os.path.join(filepath, "*.nwb"))
+    if len(nwbfilepath) > 1:
+        raise Exception("Multiple NWB files are present at the location.")
+    else:
+        nwbfilepath = nwbfilepath[0]
+    logging.info(f"Reading events {indices} from NWB file {nwbfilepath} to save to {outputPath}")
+
+    with NWBHDF5IO(nwbfilepath, "r") as io:
+        nwbfile = io.read()
+        fiber_photometry_response_series = nwbfile.acquisition[response_series_name]
+        data = fiber_photometry_response_series.data[:]
+        sampling_rate = getattr(fiber_photometry_response_series, "rate", None)
+        timestamps = getattr(fiber_photometry_response_series, "timestamps", None)
+        if sampling_rate is None and timestamps is not None:
+            # Infer a nominal rate from the median spacing of the stored timestamps
+            sampling_rate = 1 / np.median(np.diff(timestamps))
+        elif timestamps is None and sampling_rate is not None:
+            timestamps = np.arange(0, data.shape[0]) / sampling_rate
+        elif sampling_rate is None and timestamps is None:
+            raise Exception(f"Fiber photometry response series {response_series_name} must have rate or timestamps.")
+
+        for index in indices:
+            event = f"event_{index}"
+            logging.info(f"Reading data for event {event} ...")
+            S = {}
+            S["storename"] = str(event)
+            S["sampling_rate"] = sampling_rate
+            S["timestamps"] = timestamps[::npoints]
+            S["data"] = data[:, index]
+            S["npoints"] = npoints
+            S["channels"] = np.ones_like(S["timestamps"])
+
+            save_dict_to_hdf5(S, event, outputPath)
+            check_data(S, filepath, event, outputPath)
+            logging.info(f"Data for event {event} fetched and stored.")
+
+
 def main(input_parameters):
     logger.info("run")
     try:
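Note: a minimal sketch of the input parameters that drive the new NWB path. The folder path and series name below are hypothetical, and the real inputParameters dict carries many more keys than shown; only the keys touched by this change are illustrated. Each entry in nwb_response_series_names and nwb_response_series_indices lines up positionally with an entry in folderNames:

    inputParameters = {
        "folderNames": ["/data/session1"],  # folder containing exactly one .nwb file
        "numberOfCores": 1,
        "nwb_response_series_names": ["FiberPhotometryResponseSeries"],  # hypothetical series name
        "nwb_response_series_indices": [[0, 1]],  # columns 0 and 1 -> event_0, event_1
    }

With these values, readRawData() routes /data/session1 through check_nwb()/read_nwb() and saves one HDF5 store per requested index via save_dict_to_hdf5.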
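To find out what to pass for the series name and indices, the acquisition groups of a file can be listed with pynwb; this is a standalone sketch using only the pynwb API already relied on above, with a hypothetical file path:

    from pynwb import NWBHDF5IO

    with NWBHDF5IO("/data/session1/recording.nwb", "r") as io:
        nwbfile = io.read()
        for name, series in nwbfile.acquisition.items():
            data = getattr(series, "data", None)  # not every acquisition object carries data
            shape = data.shape if data is not None else "n/a"
            print(name, shape)  # e.g. FiberPhotometryResponseSeries (120000, 2)

A shape of (n_samples, n_channels) means valid indices run from 0 to n_channels - 1, matching the data[:, index] access in read_nwb.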