diff --git a/README.rst b/README.rst index c7a6f19..5701b6f 100644 --- a/README.rst +++ b/README.rst @@ -164,54 +164,7 @@ Compatible pulse sequences for **fid** data sets: * STEAM.ppg * igFLASH.ppg -ParaVision v6.0.1 -""""""""""""""""" -Compatible data set types: - -* **fid** -* **2dseq** -* **rawdata.job0** -* **rawdata.Navigator** - -Compatible pulse sequences for **fid** data sets: - -* FLASH.ppg, -* FLASHAngio.ppg -* IgFLASH.ppg -* MGE.ppg -* MSME.ppg -* RARE.ppg -* FAIR_RARE.ppg -* RAREVTR.ppg -* RAREst.ppg -* MDEFT.ppg -* FISP.ppg -* FLOWMAP.ppg -* DtiStandard.ppg -* EPI.ppg -* FAIR_EPI.ppg -* CASL_EPI.ppg -* DtiEpi.ppg -* T1_EPI.ppg -* T2_EPI.ppg -* T2S_EPI.ppg -* SPIRAL.ppg -* DtiSpiral.ppg -* UTE.ppg -* UTE3D.ppg -* ZTE.ppg -* CSI.ppg -* FieldMap.ppg -* SINGLEPULSE.ppg -* NSPECT.ppg -* EPSI.ppg -* PRESS.ppg -* STEAM.ppg -* ISIS.ppg -* CPMG.ppg -* RfProfile.ppg - -ParaVision v7.0.0 +ParaVision v6.0.1 and v7.0.0 """"""""""""""""" Compatible data set types: @@ -259,8 +212,9 @@ Compatible pulse sequences for **fid** data sets: * RfProfile.ppg -ParaVision v360 +ParaVision 360 v1.1 v3.0-v3.7 """"""""""""""" +Reading rawdata is supported only in a basic form, no reshaping into k-space is supported at the moment. Compatible data set types: * **2dseq** diff --git a/brukerapi/config/properties_fid_core.json b/brukerapi/config/properties_fid_core.json index bf8d491..015d89a 100644 --- a/brukerapi/config/properties_fid_core.json +++ b/brukerapi/config/properties_fid_core.json @@ -223,9 +223,7 @@ ["#PULPROG[1:-1]", [ "EPI.ppg", - "DtiEpi.ppg", "navigatorEPI_OM.ppg", - "EPSI.ppg", "FAIR_EPI.ppg", "CASL_EPI.ppg", "T1_EPI.ppg", @@ -236,6 +234,18 @@ "#ACQ_sw_version in ['', '', '', '']" ] }, + { + "cmd": "'dEPI'" , + "conditions": [ + ["#PULPROG[1:-1]", + [ + "DtiEpi.ppg", + "EPSI.ppg" + ] + ], + "#ACQ_sw_version in ['', '', '', '']" + ] + }, { "cmd": "'SPECTROSCOPY'" , "conditions": [ @@ -322,7 +332,13 @@ ] }, { - "cmd": "#NR", + "cmd": "#NSegments*#NI*#NR*(#ACQ_size[2] if len(#ACQ_size)>2 else 1)", + "conditions": [ + "@scheme_id=='dEPI'" + ] + }, + { + "cmd": "#NR*#NI", "conditions": [ "@scheme_id=='SPECTROSCOPY'" ] @@ -430,9 +446,23 @@ "@scheme_id=='EPI'" ] }, + { + "cmd": [ + "#PVM_EncMatrix[0] * #PVM_EncMatrix[1] // #NSegments", + "#PVM_EncNReceivers", + "#NSegments", + "#NI", + "#NR", + "#ACQ_size[2] if len(#ACQ_size)>2 else 1" + ], + "conditions": [ + "@scheme_id=='dEPI'" + ] + }, { "cmd": [ "#ACQ_size.tuple[0] // 2", + "#NI", "#NR" ], "conditions": [ @@ -527,11 +557,24 @@ { "cmd": [0,2,3,4,1], "conditions": [ - "@scheme_id in ['EPI', 'SPIRAL']" + "@scheme_id in ['SPIRAL']" + ] + }, + { + "cmd": [0,2,3,4,1], + "conditions": [ + "@scheme_id in ['EPI']" + ] + }, + { + "cmd": [0,2,3,4,1,5], + "conditions": [ + "@scheme_id in ['dEPI']" ] }, + { - "cmd": [0,1], + "cmd": [0,1,2], "conditions": [ "@scheme_id=='SPECTROSCOPY'" ] @@ -604,9 +647,23 @@ "@scheme_id=='EPI'" ] }, + { + "cmd": [ + "#PVM_EncMatrix[0]", + "#PVM_EncMatrix[1]", + "#NI", + "#NR", + "#PVM_EncNReceivers", + "#ACQ_size[2] if len(#ACQ_size)>2 else 1" + ], + "conditions": [ + "@scheme_id=='dEPI'" + ] + }, { "cmd": [ "#ACQ_size.tuple[0] // 2", + "#NI", "#NR" ], "conditions": [ @@ -701,7 +758,7 @@ "'channel'" ], "conditions": [ - ["@scheme_id",["CART_2D","RADIAL","EPI","SPIRAL","ZTE"]] + ["@scheme_id",["CART_2D","RADIAL","EPI","dEPI","SPIRAL","ZTE"]] ] }, { diff --git a/brukerapi/config/properties_rawdata_core.json b/brukerapi/config/properties_rawdata_core.json index dba136b..36199ba 100644 --- a/brukerapi/config/properties_rawdata_core.json +++ b/brukerapi/config/properties_rawdata_core.json @@ -5,7 +5,7 @@ "conditions": [ "#GO_raw_data_format=='GO_32BIT_SGN_INT'", "#BYTORDA=='little'", - ["#ACQ_sw_version",["", "", ""]] + ["#ACQ_sw_version",["", "", "",""]] ] }, { @@ -13,7 +13,7 @@ "conditions": [ "#GO_raw_data_format=='GO_16BIT_SGN_INT'", "#BYTORDA=='little'", - ["#ACQ_sw_version",["", "", ""]] + ["#ACQ_sw_version",["", "", "",""]] ] }, { @@ -21,7 +21,7 @@ "conditions": [ "#GO_raw_data_format=='GO_32BIT_FLOAT'", "#BYTORDA=='little'", - ["#ACQ_sw_version",["", "", ""]] + ["#ACQ_sw_version",["", "", "",""]] ] }, { @@ -29,7 +29,7 @@ "conditions": [ "#GO_raw_data_format=='GO_32BIT_SGN_INT'", "#BYTORDA=='big'", - ["#ACQ_sw_version",["", "", ""]] + ["#ACQ_sw_version",["", "", "",""]] ] }, { @@ -37,7 +37,7 @@ "conditions": [ "#GO_raw_data_format=='GO_16BIT_SGN_INT'", "#BYTORDA=='big'", - ["#ACQ_sw_version",["", "", ""]] + ["#ACQ_sw_version",["", "", "",""]] ] }, { @@ -45,7 +45,7 @@ "conditions": [ "#GO_raw_data_format=='GO_32BIT_FLOAT'", "#BYTORDA=='big'", - ["#ACQ_sw_version",["", "", ""]] + ["#ACQ_sw_version",["", "", "",""]] ] }, { @@ -53,7 +53,7 @@ "conditions": [ "#ACQ_word_size=='_32_BIT'", "#BYTORDA=='little'", - ["#ACQ_sw_version",[""]] + ["#ACQ_sw_version",["","","","","","","","",""]] ] } ], @@ -61,9 +61,16 @@ { "cmd": "#ACQ_jobs.primed_dict(7)['<{}>'.format(@subtype)]", "conditions": [ - "#ACQ_sw_version in ['']" + ["#ACQ_sw_version",[""]] ] }, + { + "cmd": "[v for v in #ACQ_jobs.nested if v[-1] == '<{}>'.format(@subtype)][0]", + "conditions": [ + ["#ACQ_sw_version",["","","","","","","",""]] + ] + }, + { "cmd": "#ACQ_jobs.nested[0]", "conditions": [ @@ -88,7 +95,7 @@ ], "shape_storage": [ { - "cmd": "(@job_desc[0],) + (@job_desc[3],)", + "cmd": "(@job_desc[0],) + (#PVM_EncNReceivers,) + (@job_desc[3],)", "conditions": [] } ] diff --git a/brukerapi/config/properties_rawdata_custom.json b/brukerapi/config/properties_rawdata_custom.json index 9e26dfe..902a02d 100644 --- a/brukerapi/config/properties_rawdata_custom.json +++ b/brukerapi/config/properties_rawdata_custom.json @@ -1 +1,65 @@ -{} \ No newline at end of file +{ + "subj_id": [ + { + "cmd": "#SUBJECT_id[1:-1]", + "conditions": [ + + ] + }, + { + "cmd": "''", + "conditions": [ + + ] + } + ], + "study_id": [ + { + "cmd": "str(#SUBJECT_study_nr)", + "conditions": [ + + ] + }, + { + "cmd": "''", + "conditions": [ + + ] + } + ], + "exp_id": [ + { + "cmd": "@path.parent.name", + "conditions": [ + + ] + }, + { + "cmd": "''", + "conditions": [ + + ] + } + ], + "id": [ + { + "cmd": "f'RawData_{@subtype}_{@exp_id}_{@subj_id}_{@study_id}'", + "conditions": [ + ] + } + ], + "TR": [ + { + "cmd": "#PVM_RepetitionTime", + "conditions": [], + "unit": "ms" + } + ], + "TE": [ + { + "cmd": "#PVM_EchoTime", + "conditions": [], + "unit": "ms" + } + ] +} \ No newline at end of file diff --git a/brukerapi/dataset.py b/brukerapi/dataset.py index 959586f..f52fc2f 100644 --- a/brukerapi/dataset.py +++ b/brukerapi/dataset.py @@ -546,10 +546,10 @@ def _read_binary_file(self, path, dtype, shape): 1D ndarray containing the full data vector """ # TODO debug with this - # try: - # assert os.stat(str(path)).st_size == np.prod(shape) * dtype.itemsize - # except AssertionError: - # raise ValueError('Dimension missmatch') + try: + assert os.stat(str(path)).st_size == np.prod(shape) * dtype.itemsize + except AssertionError: + raise ValueError("Dimension mismatch") from AssertionError return np.array(np.memmap(path, dtype=dtype, shape=shape, order="F")[:]) @@ -594,7 +594,7 @@ def write(self, path, **kwargs): path = Path(path) - if path.name != self.type: + if path.name.split(".")[0] != self.type: raise DatasetTypeMissmatch parent = path.parent diff --git a/brukerapi/folders.py b/brukerapi/folders.py index 209398c..4e20546 100644 --- a/brukerapi/folders.py +++ b/brukerapi/folders.py @@ -217,7 +217,7 @@ def make_tree(self, *, recursive: bool = True) -> list: children.append(Folder(path, parent=self, recursive=recursive, dataset_index=self._dataset_index, dataset_state=self._dataset_state)) continue - if path.name in self._dataset_index: + if path.name in self._dataset_index or (path.name.partition(".")[0] in self._dataset_index and "rawdata" in path.name): try: children.append(Dataset(path, **self._dataset_state)) except (UnsuportedDatasetType, IncompleteDataset, NotADatasetDir): diff --git a/brukerapi/jcampdx.py b/brukerapi/jcampdx.py index 6ffa61f..8ca5f77 100644 --- a/brukerapi/jcampdx.py +++ b/brukerapi/jcampdx.py @@ -1,4 +1,3 @@ -import ast import json import re from collections import OrderedDict @@ -19,14 +18,31 @@ "SIZE_BRACKET": r"^\([^\(\)<>]*\)(?!$)", "LIST_DELIMETER": ", ", "EQUAL_SIGN": "=", - "SINGLE_NUMBER": r"-?[\d.]+(?:e[+-]?\d+)?", + "SINGLE_NUMBER": r"-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?", "PARALLEL_BRACKET": r"\) ", "GEO_OBJ": r"\(\(\([\s\S]*\)[\s\S]*\)[\s\S]*\)", "HEADER": "TITLE|JCAMPDX|JCAMP-DX|DATA TYPE|DATATYPE|ORIGIN|OWNER", "VERSION_TITLE": "JCAMPDX|JCAMP-DX", } + MAX_LINE_LEN = 78 +# Precompile all regexes +_COMPILED_GRAMMAR = {k: re.compile(v) if k not in ["LIST_DELIMETER", "EQUAL_SIGN"] else v for k, v in GRAMMAR.items()} + +# Example usage: +_COMMENT_RE = _COMPILED_GRAMMAR["COMMENT_LINE"] +_USER_DEFINED_RE = _COMPILED_GRAMMAR["USER_DEFINED"] +_TRAILING_EOL_RE = _COMPILED_GRAMMAR["TRAILING_EOL"] +_DATA_LABEL_RE = _COMPILED_GRAMMAR["DATA_LABEL"] +_SIZE_BRACKET_RE = _COMPILED_GRAMMAR["SIZE_BRACKET"] +_SINGLE_NUMBER_RE = _COMPILED_GRAMMAR["SINGLE_NUMBER"] +_PARALLEL_BRACKET_RE = _COMPILED_GRAMMAR["PARALLEL_BRACKET"] +_GEO_OBJ_RE = _COMPILED_GRAMMAR["GEO_OBJ"] +_HEADER_RE = _COMPILED_GRAMMAR["HEADER"] +_VERSION_TITLE_RE = _COMPILED_GRAMMAR["VERSION_TITLE"] +_PARAMETER_RE = _COMPILED_GRAMMAR["PARAMETER"] + class Parameter: """ @@ -106,7 +122,7 @@ def _encode_parameter(self, var): @property def key(self): - return re.sub("##", "", re.sub(r"\$", "", self.key_str)).rstrip() + return self.key_str.replace("##", "").replace("$", "").rstrip() @key.setter def key(self, key): @@ -115,7 +131,7 @@ def key(self, key): @property def user_defined(self): - return bool(re.search(GRAMMAR["USER_DEFINED"], self.key_str)) + return bool(_USER_DEFINED_RE.search(self.key_str)) @property def tuple(self): @@ -186,7 +202,7 @@ def from_values(cls, version, key, size, value, user_defined): @property def value(self): - val_str = re.sub(r"\n", "", self.val_str) + val_str = self.val_str.replace("\n", "") # unwrap wrapped list if re.match(r"@[0-9]*\*", val_str) is not None: @@ -296,7 +312,7 @@ def size(self, size): @classmethod def parse_value(cls, val_str, size_bracket=None): # remove \n - val_str = re.sub(r"\n", "", val_str) + val_str = val_str.replace("\n", "") # sharp string if val_str.startswith("<") and val_str.endswith(">"): @@ -307,9 +323,12 @@ def parse_value(cls, val_str, size_bracket=None): return np.array(val_strs) # int/float - if len(re.findall(GRAMMAR["SINGLE_NUMBER"], val_str)) == 1: + if _SINGLE_NUMBER_RE.fullmatch(val_str): try: - value = ast.literal_eval(val_str) + try: + value = int(val_str) + except ValueError: + value = float(val_str) # if value is int, or float, return, tuple will be parsed as list later on if isinstance(value, (float, int)): @@ -319,7 +338,7 @@ def parse_value(cls, val_str, size_bracket=None): # list if val_str.startswith("(") and val_str.endswith(""): - val_strs = re.split(GRAMMAR["LIST_DELIMETER"], val_str[1:-1]) + val_strs = val_str[1:-1].split(", ") value = [] for val_str in val_strs: @@ -327,7 +346,7 @@ def parse_value(cls, val_str, size_bracket=None): return value - val_strs = re.split(" ", val_str) + val_strs = val_str.split(" ") if len(val_strs) > 1: # try casting into int, or float array, if both of casts fail, it should be string array @@ -406,7 +425,7 @@ def serialize_ndarray(cls, value): @classmethod def split_parallel_lists(cls, val_str): - lst = re.split(GRAMMAR["PARALLEL_BRACKET"], val_str) + lst = _PARALLEL_BRACKET_RE.split(val_str) if len(lst) == 1: return lst[0] @@ -497,7 +516,7 @@ def __init__(self, version, key, size_bracket, value): @property def value(self): - val_list = re.split(GRAMMAR["DATA_DELIMETERS"], self.val_str) + val_list = self.val_str.replace("\n", ",").split(", ") data = [GenericParameter.parse_value(x) for x in val_list] return np.reshape(data, (2, -1)) @@ -641,16 +660,15 @@ def version(self): return self.params["JCAMPDX"] try: - _, version = JCAMPDX.load_parameter(self.path, "JCAMPDX") - return version.value - except (InvalidJcampdxFile, ParameterNotFound): - pass - - try: - _, version = JCAMPDX.load_parameter(self.path, "JCAMP-DX") - return version.value - except (InvalidJcampdxFile, ParameterNotFound): - pass + with self.path.open("r") as f: + for _ in range(10): + line = f.readline() + if line.startswith("##JCAMPDX="): + return line.strip().split("=", 1)[1] + if line.startswith("##JCAMP-DX="): + return line.strip().split("=", 1)[1] + except (UnicodeDecodeError, OSError) as e: + raise InvalidJcampdxFile from e raise InvalidJcampdxFile(self.path) @@ -790,13 +808,13 @@ def read_jcampdx(cls, path): raise JcampdxFileError(f"file {path} is not a text file") from e # remove all comments - content = re.sub(GRAMMAR["COMMENT_LINE"], "", content) + content = _COMMENT_RE.sub("", content) # split into individual entries - content = re.split(GRAMMAR["PARAMETER"], content)[1:-1] + content = _PARAMETER_RE.split(content)[1:-1] # strip trailing EOL - content = [re.sub(GRAMMAR["TRAILING_EOL"], "", x) for x in content] + content = [_TRAILING_EOL_RE.sub("", x) for x in content] # ASSUMPTION the jcampdx version string is in the second row try: @@ -821,11 +839,12 @@ def read_jcampdx(cls, path): @classmethod def handle_jcampdx_line(cls, line, version): key_str, size_str, val_str = cls.divide_jcampdx_line(line) - if re.search(GRAMMAR["GEO_OBJ"], line) is not None: + + if _GEO_OBJ_RE.search(line) is not None: parameter = GeometryParameter(key_str, size_str, val_str, version) - elif re.search(GRAMMAR["DATA_LABEL"], line): + elif _DATA_LABEL_RE.search(line): parameter = DataParameter(key_str, size_str, val_str, version) - elif re.search(GRAMMAR["HEADER"], key_str): + elif _HEADER_RE.search(key_str): parameter = HeaderParameter(key_str, size_str, val_str, version) else: parameter = GenericParameter(key_str, size_str, val_str, version) diff --git a/brukerapi/schemas.py b/brukerapi/schemas.py index 12fde31..b3b3db4 100644 --- a/brukerapi/schemas.py +++ b/brukerapi/schemas.py @@ -426,7 +426,7 @@ class SchemaRawdata(Schema): def layouts(self): layouts = {} layouts["raw"] = (int(self._dataset.job_desc[0] / 2), self._dataset.channels, int(self._dataset.job_desc[3])) - layouts["shape_storage"] = (2, int(self._dataset.job_desc[0] / 2), self._dataset.channels, int(self._dataset.job_desc[3])) + layouts["shape_storage"] = (int(self._dataset.job_desc[0]), self._dataset.channels, int(self._dataset.job_desc[3])) layouts["final"] = layouts["raw"] return layouts @@ -434,9 +434,13 @@ def deserialize(self, data, layouts): return data[0::2, ...] + 1j * data[1::2, ...] def serialize(self, data, layouts): - data_ = np.zeros(layouts["shape_storage"], dtype=self.numpy_dtype, order="F") - data_[0, ...] = data.real - data_[1, ...] = data.imag + # storage array + data_ = np.zeros(layouts["shape_storage"], dtype=self._dataset.numpy_dtype, order="F") + + # interlace real and imag along first axis + data_[0::2, ...] = data.real + data_[1::2, ...] = data.imag + return data_ diff --git a/docs/source/compatibility.rst b/docs/source/compatibility.rst index ace5f7a..3c3c234 100644 --- a/docs/source/compatibility.rst +++ b/docs/source/compatibility.rst @@ -23,7 +23,7 @@ Compatible pulse sequences for **fid** data sets: -ParaVision v6.0.1 +ParaVision v6.0.1 and v7.0.0 """"""""""""""""" Compatible data set types: @@ -70,7 +70,9 @@ Compatible pulse sequences for **fid** data sets: * CPMG.ppg * RfProfile.ppg -ParaVision v360 + + +ParaVision v360 1.1 3.0-3.7 """"""""""""""" Compatible data set types: diff --git a/docs/source/conf.py b/docs/source/conf.py index 0818ec2..3727ce6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -23,7 +23,7 @@ author = "Tomas Psorn" # The full version, including alpha/beta/rc tags -release = "0.1.2" +release = "0.2.0" # -- General configuration --------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 44a7d2d..66875b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,9 +4,12 @@ build-backend = "setuptools.build_meta" [project] name = "brukerapi" -version = "0.1.10" +version = "0.2.0" description = "Bruker API" -authors = [{ name = "Tomas Psorn", email = "tomaspsorn@isibrno.cz" },{ name = "Jiri Vitous", email = "vitous@isibrno.cz" }] +authors = [ + { name = "Tomas Psorn", email = "tomaspsorn@isibrno.cz" }, + { name = "Jiri Vitous", email = "vitous@isibrno.cz" }, +] license = { text = "MIT" } readme = "README.rst" requires-python = ">=3.8" @@ -26,32 +29,32 @@ include-package-data = true zip-safe = false [project.optional-dependencies] -dev = ["pytest", "zenodo_get","ruff","pytest-cov"] - +dev = ["pytest", "zenodo_get", "ruff", "pytest-cov"] [tool.ruff.lint] select = [ - "E","W", # pycodestyle - "F", # Pyflakes - "UP", # pyupgrade - "B", # flake8-bugbear - "SIM", # flake8-simplify - "I", # isort - "PERF",# Perflint - "C4", # Flake8 comprehensions - "RET", # Flake8 return - "FBT", # Flake8 boolean simplifications - "LOG", # Flake8 logging - "PL", # Flake8 pylint + "E", + "W", # pycodestyle + "F", # Pyflakes + "UP", # pyupgrade + "B", # flake8-bugbear + "SIM", # flake8-simplify + "I", # isort + "PERF", # Perflint + "C4", # Flake8 comprehensions + "RET", # Flake8 return + "FBT", # Flake8 boolean simplifications + "LOG", # Flake8 logging + "PL", # Flake8 pylint "B9", # Bugbear additional checks "TC", "RUF", "PT", "FLY", - "NPY" + "NPY", ] -ignore= [ +ignore = [ "PLR2004", "PLR0915", "PLR0913", @@ -63,7 +66,7 @@ ignore= [ "RET504", "RUF012", "PERF401", - "B905" + "B905", ] @@ -75,4 +78,3 @@ target-version = "py313" quote-style = "double" indent-style = "space" docstring-code-format = true - diff --git a/test/conftest.py b/test/conftest.py index e7e6fab..9b4ca92 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -167,12 +167,11 @@ def pytest_generate_tests(metafunc): data_items = [] for dataset_name in requested: dataset_root = TEST_DATA_ROOT / dataset_name - for subfolder in dataset_root.iterdir(): - if subfolder.is_dir(): - folder_obj = Folder(subfolder, dataset_state={"parameter_files": [], "property_files": [], "load": 2}) - for dataset in folder_obj.get_dataset_list_rec(): - data_ids.append(f"{dataset_name}/{dataset.id}") - data_items.append((dataset.path, ref_state.get(dataset.id, {}))) + folder_obj = Folder(dataset_root, dataset_state={"parameter_files": [], "property_files": [], "load": 2}) + for dataset in folder_obj.get_dataset_list_rec(): + data_ids.append(f"{dataset_name}/{dataset.id}") + data_items.append((dataset.path, ref_state.get(dataset.id, {}))) + metafunc.parametrize("test_data", data_items, indirect=True, ids=data_ids) # ------------------------------- @@ -183,12 +182,12 @@ def pytest_generate_tests(metafunc): ra_items = [] for dataset_name in requested: dataset_root = TEST_DATA_ROOT / dataset_name - for subfolder in dataset_root.iterdir(): - if subfolder.is_dir(): - folder_obj = Folder(subfolder, dataset_state={"parameter_files": [], "property_files": [], "load": 2}) - for dataset in _find_2dseq_datasets(dataset_name): - ra_ids.append(f"{dataset_name}/{dataset.id}") - ra_items.append((dataset.path, ref_state.get(dataset.id, {}))) + + folder_obj = Folder(dataset_root, dataset_state={"parameter_files": [], "property_files": [], "load": 2}) + for dataset in _find_2dseq_datasets(dataset_name): + ra_ids.append(f"{dataset_name}/{dataset.id}") + ra_items.append((dataset.path, ref_state.get(dataset.id, {}))) + metafunc.parametrize("test_ra_data", ra_items, indirect=True, ids=ra_ids) # ------------------------------- diff --git a/test/test_dataset.py b/test/test_dataset.py index 250f420..a642345 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -1,5 +1,6 @@ import contextlib import json +import os from pathlib import Path import numpy as np @@ -34,9 +35,11 @@ def test_data_load(test_data): dataset = Dataset(test_data[0]) return # For now Disable testing array equality + if not os.path.exists(str(dataset.path) + ".npz"): + return with np.load(str(dataset.path) + ".npz") as data: - assert np.array_equal(dataset.data, data["data"]) + assert np.array_equal(np.squeeze(dataset.data), np.squeeze(data["data"])) def test_data_save(test_data, tmp_path, WRITE_TOLERANCE):