From ca9a61d7e66941fdc31570c1a54c8b0ce34a2d07 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 31 Jan 2026 18:51:44 +0000 Subject: [PATCH 1/3] Update download urls from figshare to s3 bucket --- cebra/data/assets.py | 126 ++++++++++++++++++++++++++++++ cebra/data/base.py | 22 ++++-- cebra/datasets/hippocampus.py | 36 ++++++--- cebra/datasets/monkey_reaching.py | 76 ++++++++++++------ cebra/datasets/synthetic_data.py | 60 +++++++++----- tests/test_datasets.py | 109 ++++++++++++++++++++++++++ 6 files changed, 366 insertions(+), 63 deletions(-) diff --git a/cebra/data/assets.py b/cebra/data/assets.py index adea8413..683f3b6c 100644 --- a/cebra/data/assets.py +++ b/cebra/data/assets.py @@ -20,6 +20,7 @@ # limitations under the License. # +import gzip import hashlib import re import warnings @@ -140,3 +141,128 @@ def calculate_checksum(file_path: str) -> str: for chunk in iter(lambda: file.read(4096), b""): checksum.update(chunk) return checksum.hexdigest() + + +def download_and_extract_gzipped_file(url: str, + expected_checksum: str, + gzipped_checksum: str, + location: str, + file_name: str, + retry_count: int = 0) -> Optional[str]: + """Download a gzipped file from the given URL, verify checksums, and extract. + + This function downloads a gzipped file, verifies the checksum of the gzipped + file, extracts it, and then verifies the checksum of the extracted file. + + Args: + url: The URL to download the gzipped file from. + expected_checksum: The expected MD5 checksum of the extracted file. + gzipped_checksum: The expected MD5 checksum of the gzipped file. + location: The directory where the file will be saved. + file_name: The name of the final extracted file (without .gz extension). + retry_count: The number of retry attempts (default: 0). + + Returns: + The path of the extracted file if successful, None otherwise. + + Raises: + RuntimeError: If the maximum retry count is exceeded. + requests.HTTPError: If the download fails. + """ + + # Check if the final extracted file already exists with correct checksum + location_path = Path(location) + final_file_path = location_path / file_name + + if final_file_path.exists(): + existing_checksum = calculate_checksum(final_file_path) + if existing_checksum == expected_checksum: + return final_file_path + + if retry_count >= _MAX_RETRY_COUNT: + raise RuntimeError( + f"Exceeded maximum retry count ({_MAX_RETRY_COUNT}). " + f"Unable to download the file from {url}") + + # Create the directory and any necessary parent directories + location_path.mkdir(parents=True, exist_ok=True) + + # Download the gzipped file + gz_file_path = location_path / f"{file_name}.gz" + + response = requests.get(url, stream=True) + + # Check if the request was successful + if response.status_code != 200: + raise requests.HTTPError( + f"Error occurred while downloading the file. Response code: {response.status_code}" + ) + + total_size = int(response.headers.get("Content-Length", 0)) + checksum = hashlib.md5() # create checksum for gzipped file + + # Download the gzipped file + with open(gz_file_path, "wb") as file: + with tqdm.tqdm(total=total_size, + unit="B", + unit_scale=True, + desc="Downloading") as progress_bar: + for data in response.iter_content(chunk_size=1024): + file.write(data) + checksum.update(data) + progress_bar.update(len(data)) + + downloaded_gz_checksum = checksum.hexdigest() + + # Verify gzipped file checksum + if downloaded_gz_checksum != gzipped_checksum: + warnings.warn( + f"Gzipped file checksum verification failed. 
Deleting '{gz_file_path}'." + ) + gz_file_path.unlink() + warnings.warn("Gzipped file deleted. Retrying download...") + return download_and_extract_gzipped_file(url, expected_checksum, + gzipped_checksum, location, + file_name, retry_count + 1) + + print("Gzipped file checksum verified. Extracting...") + + # Extract the gzipped file + try: + with gzip.open(gz_file_path, 'rb') as f_in: + with open(final_file_path, 'wb') as f_out: + # Extract with progress (estimate based on typical compression ratio) + extracted_size = 0 + while True: + chunk = f_in.read(8192) + if not chunk: + break + f_out.write(chunk) + extracted_size += len(chunk) + except Exception as e: + warnings.warn(f"Extraction failed: {e}. Deleting files and retrying...") + if gz_file_path.exists(): + gz_file_path.unlink() + if final_file_path.exists(): + final_file_path.unlink() + return download_and_extract_gzipped_file(url, expected_checksum, + gzipped_checksum, location, + file_name, retry_count + 1) + + # Verify extracted file checksum + extracted_checksum = calculate_checksum(final_file_path) + if extracted_checksum != expected_checksum: + warnings.warn( + "Extracted file checksum verification failed. Deleting files.") + gz_file_path.unlink() + final_file_path.unlink() + warnings.warn("Files deleted. Retrying download...") + return download_and_extract_gzipped_file(url, expected_checksum, + gzipped_checksum, location, + file_name, retry_count + 1) + + # Clean up the gzipped file after successful extraction + gz_file_path.unlink() + + print(f"Extraction complete. Dataset saved in '{final_file_path}'") + return final_file_path diff --git a/cebra/data/base.py b/cebra/data/base.py index 51199cec..acdcff53 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -55,6 +55,7 @@ def __init__(self, download=False, data_url=None, data_checksum=None, + gzipped_checksum=None, location=None, file_name=None): @@ -64,6 +65,7 @@ def __init__(self, self.download = download self.data_url = data_url self.data_checksum = data_checksum + self.gzipped_checksum = gzipped_checksum self.location = location self.file_name = file_name @@ -78,11 +80,21 @@ def __init__(self, "Missing data checksum. Please provide the checksum to verify the data integrity." 
) - cebra_data_assets.download_file_with_progress_bar( - url=self.data_url, - expected_checksum=self.data_checksum, - location=self.location, - file_name=self.file_name) + # Use gzipped download if gzipped_checksum is provided + if self.gzipped_checksum is not None: + cebra_data_assets.download_and_extract_gzipped_file( + url=self.data_url, + expected_checksum=self.data_checksum, + gzipped_checksum=self.gzipped_checksum, + location=self.location, + file_name=self.file_name) + else: + # Fall back to legacy download for backward compatibility + cebra_data_assets.download_file_with_progress_bar( + url=self.data_url, + expected_checksum=self.data_checksum, + location=self.location, + file_name=self.file_name) @property @abc.abstractmethod diff --git a/cebra/datasets/hippocampus.py b/cebra/datasets/hippocampus.py index 05c47acb..a8ce12d1 100644 --- a/cebra/datasets/hippocampus.py +++ b/cebra/datasets/hippocampus.py @@ -50,27 +50,35 @@ rat_dataset_urls = { "achilles": { "url": - "https://figshare.com/ndownloader/files/40849463?private_link=9f91576cbbcc8b0d8828", + "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/achilles.jl.gz", "checksum": - "c52f9b55cbc23c66d57f3842214058b8" + "c52f9b55cbc23c66d57f3842214058b8", + "gzipped_checksum": + "5d7b243e07b24c387e5412cd5ff46f0b" }, "buddy": { "url": - "https://figshare.com/ndownloader/files/40849460?private_link=9f91576cbbcc8b0d8828", + "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/buddy.jl.gz", "checksum": - "36341322907708c466871bf04bc133c2" + "36341322907708c466871bf04bc133c2", + "gzipped_checksum": + "339290585be2188f48a176f05aaf5df6" }, "cicero": { "url": - "https://figshare.com/ndownloader/files/40849457?private_link=9f91576cbbcc8b0d8828", + "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/cicero.jl.gz", "checksum": - "a83b02dbdc884fdd7e53df362499d42f" + "a83b02dbdc884fdd7e53df362499d42f", + "gzipped_checksum": + "f262a87d2e59f164cb404cd410015f3a" }, "gatsby": { "url": - "https://figshare.com/ndownloader/files/40849454?private_link=9f91576cbbcc8b0d8828", + "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/gatsby.jl.gz", "checksum": - "2b889da48178b3155011c12555342813" + "2b889da48178b3155011c12555342813", + "gzipped_checksum": + "564e431c19e55db2286a9d64c86a94c4" } } @@ -95,11 +103,13 @@ def __init__(self, name="achilles", root=_DEFAULT_DATADIR, download=True): location = pathlib.Path(root) / "rat_hippocampus" file_path = location / f"{name}.jl" - super().__init__(download=download, - data_url=rat_dataset_urls[name]["url"], - data_checksum=rat_dataset_urls[name]["checksum"], - location=location, - file_name=f"{name}.jl") + super().__init__( + download=download, + data_url=rat_dataset_urls[name]["url"], + data_checksum=rat_dataset_urls[name]["checksum"], + gzipped_checksum=rat_dataset_urls[name].get("gzipped_checksum"), + location=location, + file_name=f"{name}.jl") data = joblib.load(file_path) self.neural = torch.from_numpy(data["spikes"]).float() diff --git a/cebra/datasets/monkey_reaching.py b/cebra/datasets/monkey_reaching.py index 05071b12..22479455 100644 --- a/cebra/datasets/monkey_reaching.py +++ b/cebra/datasets/monkey_reaching.py @@ -160,75 +160,99 @@ def _get_info(trial_info, data): monkey_reaching_urls = { "all_all.jl": { "url": - "https://figshare.com/ndownloader/files/41668764?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_all.jl.gz", "checksum": - "dea556301fa4fafa86e28cf8621cab5a" + 
"dea556301fa4fafa86e28cf8621cab5a", + "gzipped_checksum": + "399abc6e9ef0b23a0d6d057c6f508939" }, "all_train.jl": { "url": - "https://figshare.com/ndownloader/files/41668752?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_train.jl.gz", "checksum": - "e280e4cd86969e6fd8bfd3a8f402b2fe" + "e280e4cd86969e6fd8bfd3a8f402b2fe", + "gzipped_checksum": + "eb52c8641fe83ae2a278b372ddec5f69" }, "all_test.jl": { "url": - "https://figshare.com/ndownloader/files/41668761?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_test.jl.gz", "checksum": - "25d3ff2c15014db8b8bf2543482ae881" + "25d3ff2c15014db8b8bf2543482ae881", + "gzipped_checksum": + "7688245cf15e0b92503af943ce9f66aa" }, "all_valid.jl": { "url": - "https://figshare.com/ndownloader/files/41668755?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_valid.jl.gz", "checksum": - "8cd25169d31f83ae01b03f7b1b939723" + "8cd25169d31f83ae01b03f7b1b939723", + "gzipped_checksum": + "b169fc008b4d092fe2a1b7e006cd17a7" }, "active_all.jl": { "url": - "https://figshare.com/ndownloader/files/41668776?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_all.jl.gz", "checksum": - "c626acea5062122f5a68ef18d3e45e51" + "c626acea5062122f5a68ef18d3e45e51", + "gzipped_checksum": + "b7b86e2ae00bb71341de8fc352dae097" }, "active_train.jl": { "url": - "https://figshare.com/ndownloader/files/41668770?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_train.jl.gz", "checksum": - "72a48056691078eee22c36c1992b1d37" + "72a48056691078eee22c36c1992b1d37", + "gzipped_checksum": + "56687c633efcbff6c56bbcfa35597565" }, "active_test.jl": { "url": - "https://figshare.com/ndownloader/files/41668773?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_test.jl.gz", "checksum": - "35b7e060008a8722c536584c4748f2ea" + "35b7e060008a8722c536584c4748f2ea", + "gzipped_checksum": + "2057ef1846908a69486a61895d1198e8" }, "active_valid.jl": { "url": - "https://figshare.com/ndownloader/files/41668767?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_valid.jl.gz", "checksum": - "dd58eb1e589361b4132f34b22af56b79" + "dd58eb1e589361b4132f34b22af56b79", + "gzipped_checksum": + "60b8e418f234877351fe36f1efc169ad" }, "passive_all.jl": { "url": - "https://figshare.com/ndownloader/files/41668758?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_all.jl.gz", "checksum": - "bbb1bc9d8eec583a46f6673470fc98ad" + "bbb1bc9d8eec583a46f6673470fc98ad", + "gzipped_checksum": + "afb257efa0cac3ccd69ec80478d63691" }, "passive_train.jl": { "url": - "https://figshare.com/ndownloader/files/41668743?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_train.jl.gz", "checksum": - "f22e05a69f70e18ba823a0a89162a45c" + "f22e05a69f70e18ba823a0a89162a45c", + "gzipped_checksum": + "24d98d7d41a52591f838c41fe83dc2c6" }, "passive_test.jl": { "url": - "https://figshare.com/ndownloader/files/41668746?private_link=6fa4ee74a8f465ec7914", + 
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_test.jl.gz", "checksum": - "42453ae3e4fd27d82d297f78c13cd6b7" + "42453ae3e4fd27d82d297f78c13cd6b7", + "gzipped_checksum": + "f1ff4e9b7c4a0d7fa9dcd271893f57ab" }, "passive_valid.jl": { "url": - "https://figshare.com/ndownloader/files/41668749?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_valid.jl.gz", "checksum": - "2dcc10c27631b95a075eaa2d2297bb4a" + "2dcc10c27631b95a075eaa2d2297bb4a", + "gzipped_checksum": + "311fcb6a3e86022f12d78828f7bd29d5" } } @@ -270,6 +294,8 @@ def __init__(self, data_url=monkey_reaching_urls[f"{self.load_session}_all.jl"]["url"], data_checksum=monkey_reaching_urls[f"{self.load_session}_all.jl"] ["checksum"], + gzipped_checksum=monkey_reaching_urls[f"{self.load_session}_all.jl"] + .get("gzipped_checksum"), location=self.path, file_name=f"{self.load_session}_all.jl", ) @@ -297,6 +323,8 @@ def split(self, split): ["url"], data_checksum=monkey_reaching_urls[ f"{self.load_session}_{split}.jl"]["checksum"], + gzipped_checksum=monkey_reaching_urls[ + f"{self.load_session}_{split}.jl"].get("gzipped_checksum"), location=self.path, file_name=f"{self.load_session}_{split}.jl", ) diff --git a/cebra/datasets/synthetic_data.py b/cebra/datasets/synthetic_data.py index 9288a93d..eab8d6cf 100644 --- a/cebra/datasets/synthetic_data.py +++ b/cebra/datasets/synthetic_data.py @@ -33,51 +33,67 @@ synthetic_data_urls = { "continuous_label_refractory_poisson": { "url": - "https://figshare.com/ndownloader/files/41668815?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_refractory_poisson.jl.gz", "checksum": - "fcd92bd283c528d5294093190f55ceba" + "fcd92bd283c528d5294093190f55ceba", + "gzipped_checksum": + "3641eed973b9cae972493c70b364e981" }, "continuous_label_t": { "url": - "https://figshare.com/ndownloader/files/41668818?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_t.jl.gz", "checksum": - "a6e76f274da571568fd2a4bf4cf48b66" + "a6e76f274da571568fd2a4bf4cf48b66", + "gzipped_checksum": + "1dc8805e8f0836c7c99e864100a65bff" }, "continuous_label_uniform": { "url": - "https://figshare.com/ndownloader/files/41668821?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_uniform.jl.gz", "checksum": - "e67400e77ac009e8c9bc958aa5151973" + "e67400e77ac009e8c9bc958aa5151973", + "gzipped_checksum": + "71d33bc56b89bc227da0990bf16e584b" }, "continuous_label_laplace": { "url": - "https://figshare.com/ndownloader/files/41668824?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_laplace.jl.gz", "checksum": - "41d7ce4ce8901ae7a5136605ac3f5ffb" + "41d7ce4ce8901ae7a5136605ac3f5ffb", + "gzipped_checksum": + "1563e4958031392d2b2e30cc4cd79b3f" }, "continuous_label_poisson": { "url": - "https://figshare.com/ndownloader/files/41668827?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_poisson.jl.gz", "checksum": - "a789828f9cca5f3faf36d62ebc4cc8a1" + "a789828f9cca5f3faf36d62ebc4cc8a1", + "gzipped_checksum": + "7691304ee061e0bf1e9bb5f2bb6b20e7" }, "continuous_label_gaussian": { "url": - "https://figshare.com/ndownloader/files/41668830?private_link=7439c5302e99db36eebb", + 
"https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_gaussian.jl.gz", "checksum": - "18d66a2020923e2cd67d2264d20890aa" + "18d66a2020923e2cd67d2264d20890aa", + "gzipped_checksum": + "0cb97a2c1eaa526e57d2248a333ea8e0" }, "continuous_poisson_gaussian_noise": { "url": - "https://figshare.com/ndownloader/files/41668833?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_poisson_gaussian_noise.jl.gz", "checksum": - "1a51461820c24a5bcaddaff3991f0ebe" + "1a51461820c24a5bcaddaff3991f0ebe", + "gzipped_checksum": + "5aa6b6eadf2b733562864d5b67bc6b8d" }, "sim_100d_poisson_cont_label": { "url": - "https://figshare.com/ndownloader/files/41668836?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/sim_100d_poisson_cont_label.npz.gz", "checksum": - "306b9c646e7b76a52cfd828612d700cb" + "306b9c646e7b76a52cfd828612d700cb", + "gzipped_checksum": + "768299435a167dedd57e29b1a6d5af63" } } @@ -98,11 +114,13 @@ def __init__(self, name, root=_DEFAULT_DATADIR, download=True): location = os.path.join(root, "synthetic") file_path = os.path.join(location, f"{name}.jl") - super().__init__(download=download, - data_url=synthetic_data_urls[name]["url"], - data_checksum=synthetic_data_urls[name]["checksum"], - location=location, - file_name=f"{name}.jl") + super().__init__( + download=download, + data_url=synthetic_data_urls[name]["url"], + data_checksum=synthetic_data_urls[name]["checksum"], + gzipped_checksum=synthetic_data_urls[name].get("gzipped_checksum"), + location=location, + file_name=f"{name}.jl") data = joblib.load(file_path) self.data = data #NOTE: making it backwards compatible with synth notebook. diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e8e03ff0..36aa77f6 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 #
+import hashlib
 import os
 import pathlib
 import tempfile
@@ -384,6 +385,114 @@ def test_download_file_wrong_content_disposition(filename, url,
                                                  file_name=filename)
 
+
+def test_download_and_extract_gzipped_file():
+    """Test downloading and extracting gzipped files with dual checksum verification."""
+    import gzip
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Create a test file
+        test_content = b"Test dataset content for gzipped download"
+        test_filename = "test_dataset.jl"
+        test_gz_filename = f"{test_filename}.gz"
+
+        # calculate_checksum() expects a file path, so compute the expected
+        # checksum of the extracted file directly from the raw bytes instead
+        # of routing the in-memory content through that helper.
+        unzipped_checksum = hashlib.md5(test_content).hexdigest()
+
+        # Create gzipped content
+        gzipped_content = gzip.compress(test_content)
+        gzipped_checksum = hashlib.md5(gzipped_content).hexdigest()
+
+        # Mock the HTTP response
+        with patch("requests.get") as mock_get:
+            mock_response = mock_get.return_value
+            mock_response.status_code = 200
+            mock_response.headers = {
+                "Content-Length": str(len(gzipped_content))
+            }
+            mock_response.iter_content = lambda chunk_size: [gzipped_content]
+
+            # Test successful download and extraction
+            result = cebra_data_assets.download_and_extract_gzipped_file(
+                url="http://example.com/test.jl.gz",
+                expected_checksum=unzipped_checksum,
+                gzipped_checksum=gzipped_checksum,
+                location=temp_dir,
+                file_name=test_filename)
+
+            # Verify the file was extracted
+            assert result is not None
+            final_path = os.path.join(temp_dir, test_filename)
+            assert os.path.exists(final_path)
+
+            # Verify the content is correct
+            with open(final_path, 'rb') as f:
+                extracted_content = f.read()
+            assert extracted_content == test_content
+
+            # Verify the .gz file was cleaned up
+            gz_path = os.path.join(temp_dir, test_gz_filename)
+            assert not os.path.exists(gz_path)
+
+
+def test_download_and_extract_gzipped_file_wrong_gzipped_checksum():
+    """Test that wrong gzipped checksum raises error after retries."""
+    import gzip
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        test_content = b"Test content"
+        gzipped_content = gzip.compress(test_content)
+        wrong_gz_checksum = "0" * 32  # Wrong checksum
+
+        with patch("requests.get") as mock_get:
+            mock_response = mock_get.return_value
+            mock_response.status_code = 200
+            mock_response.headers = {
+                "Content-Length": str(len(gzipped_content))
+            }
+            mock_response.iter_content = lambda chunk_size: [gzipped_content]
+
+            with pytest.raises(RuntimeError,
+                               match="Exceeded maximum retry count"):
+                cebra_data_assets.download_and_extract_gzipped_file(
+                    url="http://example.com/test.jl.gz",
+                    expected_checksum=hashlib.md5(test_content).hexdigest(),
+                    gzipped_checksum=wrong_gz_checksum,
+                    location=temp_dir,
+                    file_name="test.jl",
+                    retry_count=2)
+
+
+def test_download_and_extract_gzipped_file_wrong_unzipped_checksum():
+    """Test that wrong unzipped checksum raises error after retries."""
+    import gzip
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        test_content = b"Test content"
+        gzipped_content = gzip.compress(test_content)
+        gzipped_checksum = hashlib.md5(gzipped_content).hexdigest()
+        wrong_unzipped_checksum = "0" * 32  # Wrong checksum
+
+        with patch("requests.get") as mock_get:
+            mock_response = mock_get.return_value
+            mock_response.status_code = 200
+            mock_response.headers = {
+                "Content-Length": str(len(gzipped_content))
+            }
+            mock_response.iter_content = lambda chunk_size: [gzipped_content]
+
+            with pytest.raises(RuntimeError,
                               match="Exceeded maximum
retry count"): + cebra_data_assets.download_and_extract_gzipped_file( + url="http://example.com/test.jl.gz", + expected_checksum=wrong_unzipped_checksum, + gzipped_checksum=gzipped_checksum, + location=temp_dir, + file_name="test.jl", + retry_count=2) + + @pytest.mark.parametrize("neural, continuous, discrete", [ (np.random.randn(100, 30), np.random.randn( 100, 2), np.random.randint(0, 5, (100,))), From 79a25e6d4ef3849592d2bf2235d5548ceb2f732e Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 31 Jan 2026 19:02:48 +0000 Subject: [PATCH 2/3] minimize the diff --- cebra/datasets/hippocampus.py | 24 +++++------ cebra/datasets/monkey_reaching.py | 72 +++++++++++++++---------------- cebra/datasets/synthetic_data.py | 48 ++++++++++----------- 3 files changed, 72 insertions(+), 72 deletions(-) diff --git a/cebra/datasets/hippocampus.py b/cebra/datasets/hippocampus.py index a8ce12d1..aa794d45 100644 --- a/cebra/datasets/hippocampus.py +++ b/cebra/datasets/hippocampus.py @@ -51,34 +51,34 @@ "achilles": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/achilles.jl.gz", - "checksum": - "c52f9b55cbc23c66d57f3842214058b8", "gzipped_checksum": - "5d7b243e07b24c387e5412cd5ff46f0b" + "5d7b243e07b24c387e5412cd5ff46f0b", + "checksum": + "c52f9b55cbc23c66d57f3842214058b8" }, "buddy": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/buddy.jl.gz", - "checksum": - "36341322907708c466871bf04bc133c2", "gzipped_checksum": - "339290585be2188f48a176f05aaf5df6" + "339290585be2188f48a176f05aaf5df6", + "checksum": + "36341322907708c466871bf04bc133c2" }, "cicero": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/cicero.jl.gz", - "checksum": - "a83b02dbdc884fdd7e53df362499d42f", "gzipped_checksum": - "f262a87d2e59f164cb404cd410015f3a" + "f262a87d2e59f164cb404cd410015f3a", + "checksum": + "a83b02dbdc884fdd7e53df362499d42f" }, "gatsby": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/gatsby.jl.gz", - "checksum": - "2b889da48178b3155011c12555342813", "gzipped_checksum": - "564e431c19e55db2286a9d64c86a94c4" + "564e431c19e55db2286a9d64c86a94c4", + "checksum": + "2b889da48178b3155011c12555342813" } } diff --git a/cebra/datasets/monkey_reaching.py b/cebra/datasets/monkey_reaching.py index 22479455..080e83ae 100644 --- a/cebra/datasets/monkey_reaching.py +++ b/cebra/datasets/monkey_reaching.py @@ -161,98 +161,98 @@ def _get_info(trial_info, data): "all_all.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_all.jl.gz", - "checksum": - "dea556301fa4fafa86e28cf8621cab5a", "gzipped_checksum": - "399abc6e9ef0b23a0d6d057c6f508939" + "399abc6e9ef0b23a0d6d057c6f508939", + "checksum": + "dea556301fa4fafa86e28cf8621cab5a" }, "all_train.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_train.jl.gz", - "checksum": - "e280e4cd86969e6fd8bfd3a8f402b2fe", "gzipped_checksum": - "eb52c8641fe83ae2a278b372ddec5f69" + "eb52c8641fe83ae2a278b372ddec5f69", + "checksum": + "e280e4cd86969e6fd8bfd3a8f402b2fe" }, "all_test.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_test.jl.gz", - "checksum": - "25d3ff2c15014db8b8bf2543482ae881", "gzipped_checksum": - "7688245cf15e0b92503af943ce9f66aa" + "7688245cf15e0b92503af943ce9f66aa", + "checksum": + "25d3ff2c15014db8b8bf2543482ae881" }, "all_valid.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_valid.jl.gz", - 
"checksum": - "8cd25169d31f83ae01b03f7b1b939723", "gzipped_checksum": - "b169fc008b4d092fe2a1b7e006cd17a7" + "b169fc008b4d092fe2a1b7e006cd17a7", + "checksum": + "8cd25169d31f83ae01b03f7b1b939723" }, "active_all.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_all.jl.gz", - "checksum": - "c626acea5062122f5a68ef18d3e45e51", "gzipped_checksum": - "b7b86e2ae00bb71341de8fc352dae097" + "b7b86e2ae00bb71341de8fc352dae097", + "checksum": + "c626acea5062122f5a68ef18d3e45e51" }, "active_train.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_train.jl.gz", - "checksum": - "72a48056691078eee22c36c1992b1d37", "gzipped_checksum": - "56687c633efcbff6c56bbcfa35597565" + "56687c633efcbff6c56bbcfa35597565", + "checksum": + "72a48056691078eee22c36c1992b1d37" }, "active_test.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_test.jl.gz", - "checksum": - "35b7e060008a8722c536584c4748f2ea", "gzipped_checksum": - "2057ef1846908a69486a61895d1198e8" + "2057ef1846908a69486a61895d1198e8", + "checksum": + "35b7e060008a8722c536584c4748f2ea" }, "active_valid.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_valid.jl.gz", - "checksum": - "dd58eb1e589361b4132f34b22af56b79", "gzipped_checksum": - "60b8e418f234877351fe36f1efc169ad" + "60b8e418f234877351fe36f1efc169ad", + "checksum": + "dd58eb1e589361b4132f34b22af56b79" }, "passive_all.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_all.jl.gz", - "checksum": - "bbb1bc9d8eec583a46f6673470fc98ad", "gzipped_checksum": - "afb257efa0cac3ccd69ec80478d63691" + "afb257efa0cac3ccd69ec80478d63691", + "checksum": + "bbb1bc9d8eec583a46f6673470fc98ad" }, "passive_train.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_train.jl.gz", - "checksum": - "f22e05a69f70e18ba823a0a89162a45c", "gzipped_checksum": - "24d98d7d41a52591f838c41fe83dc2c6" + "24d98d7d41a52591f838c41fe83dc2c6", + "checksum": + "f22e05a69f70e18ba823a0a89162a45c" }, "passive_test.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_test.jl.gz", - "checksum": - "42453ae3e4fd27d82d297f78c13cd6b7", "gzipped_checksum": - "f1ff4e9b7c4a0d7fa9dcd271893f57ab" + "f1ff4e9b7c4a0d7fa9dcd271893f57ab", + "checksum": + "42453ae3e4fd27d82d297f78c13cd6b7" }, "passive_valid.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_valid.jl.gz", - "checksum": - "2dcc10c27631b95a075eaa2d2297bb4a", "gzipped_checksum": - "311fcb6a3e86022f12d78828f7bd29d5" + "311fcb6a3e86022f12d78828f7bd29d5", + "checksum": + "2dcc10c27631b95a075eaa2d2297bb4a" } } diff --git a/cebra/datasets/synthetic_data.py b/cebra/datasets/synthetic_data.py index eab8d6cf..dc65ff0a 100644 --- a/cebra/datasets/synthetic_data.py +++ b/cebra/datasets/synthetic_data.py @@ -34,66 +34,66 @@ "continuous_label_refractory_poisson": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_refractory_poisson.jl.gz", - "checksum": - "fcd92bd283c528d5294093190f55ceba", "gzipped_checksum": - "3641eed973b9cae972493c70b364e981" + "3641eed973b9cae972493c70b364e981", + "checksum": + "fcd92bd283c528d5294093190f55ceba" }, "continuous_label_t": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_t.jl.gz", - "checksum": - 
"a6e76f274da571568fd2a4bf4cf48b66", "gzipped_checksum": - "1dc8805e8f0836c7c99e864100a65bff" + "1dc8805e8f0836c7c99e864100a65bff", + "checksum": + "a6e76f274da571568fd2a4bf4cf48b66" }, "continuous_label_uniform": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_uniform.jl.gz", - "checksum": - "e67400e77ac009e8c9bc958aa5151973", "gzipped_checksum": - "71d33bc56b89bc227da0990bf16e584b" + "71d33bc56b89bc227da0990bf16e584b", + "checksum": + "e67400e77ac009e8c9bc958aa5151973" }, "continuous_label_laplace": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_laplace.jl.gz", - "checksum": - "41d7ce4ce8901ae7a5136605ac3f5ffb", "gzipped_checksum": - "1563e4958031392d2b2e30cc4cd79b3f" + "1563e4958031392d2b2e30cc4cd79b3f", + "checksum": + "41d7ce4ce8901ae7a5136605ac3f5ffb" }, "continuous_label_poisson": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_poisson.jl.gz", - "checksum": - "a789828f9cca5f3faf36d62ebc4cc8a1", "gzipped_checksum": - "7691304ee061e0bf1e9bb5f2bb6b20e7" + "7691304ee061e0bf1e9bb5f2bb6b20e7", + "checksum": + "a789828f9cca5f3faf36d62ebc4cc8a1" }, "continuous_label_gaussian": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_gaussian.jl.gz", - "checksum": - "18d66a2020923e2cd67d2264d20890aa", "gzipped_checksum": - "0cb97a2c1eaa526e57d2248a333ea8e0" + "0cb97a2c1eaa526e57d2248a333ea8e0", + "checksum": + "18d66a2020923e2cd67d2264d20890aa" }, "continuous_poisson_gaussian_noise": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_poisson_gaussian_noise.jl.gz", - "checksum": - "1a51461820c24a5bcaddaff3991f0ebe", "gzipped_checksum": - "5aa6b6eadf2b733562864d5b67bc6b8d" + "5aa6b6eadf2b733562864d5b67bc6b8d", + "checksum": + "1a51461820c24a5bcaddaff3991f0ebe" }, "sim_100d_poisson_cont_label": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/sim_100d_poisson_cont_label.npz.gz", - "checksum": - "306b9c646e7b76a52cfd828612d700cb", "gzipped_checksum": - "768299435a167dedd57e29b1a6d5af63" + "768299435a167dedd57e29b1a6d5af63", + "checksum": + "306b9c646e7b76a52cfd828612d700cb" } } From 20ab07c44b6e3e3d376bc6ab2922ead31eb5dacc Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 31 Jan 2026 19:07:20 +0000 Subject: [PATCH 3/3] unify the dataset download funcs --- cebra/data/assets.py | 250 ++++++++++++++++------------------------- cebra/data/base.py | 21 +--- tests/test_datasets.py | 18 +-- 3 files changed, 113 insertions(+), 176 deletions(-) diff --git a/cebra/data/assets.py b/cebra/data/assets.py index 683f3b6c..490b798a 100644 --- a/cebra/data/assets.py +++ b/cebra/data/assets.py @@ -33,22 +33,31 @@ _MAX_RETRY_COUNT = 2 -def download_file_with_progress_bar(url: str, - expected_checksum: str, - location: str, - file_name: str, - retry_count: int = 0) -> Optional[str]: +def download_file_with_progress_bar( + url: str, + expected_checksum: str, + location: str, + file_name: str, + retry_count: int = 0, + gzipped_checksum: str = None) -> Optional[str]: """Download a file from the given URL. During download, progress is reported using a progress bar. The downloaded file's checksum is compared to the provided ``expected_checksum``. + If ``gzipped_checksum`` is provided, the file is expected to be gzipped. + The function will verify the gzipped file's checksum, extract it, and then + verify the extracted file's checksum. + Args: url: The URL to download the file from. 
- expected_checksum: The expected checksum value of the downloaded file. + expected_checksum: The expected checksum value of the downloaded file + (or extracted file if gzipped_checksum is provided). location: The directory where the file will be saved. file_name: The name of the file. retry_count: The number of retry attempts (default: 0). + gzipped_checksum: Optional MD5 checksum of the gzipped file. If provided, + the file will be extracted after download. Returns: The path of the downloaded file if the download is successful, None otherwise. @@ -79,30 +88,34 @@ def download_file_with_progress_bar(url: str, f"Error occurred while downloading the file. Response code: {response.status_code}" ) - # Check if the response headers contain the 'Content-Disposition' header - if 'Content-Disposition' not in response.headers: - raise ValueError( - "Unable to determine the filename. 'Content-Disposition' header not found." - ) - - # Extract the filename from the 'Content-Disposition' header - filename_match = re.search(r'filename="(.+)"', - response.headers.get("Content-Disposition")) - if not filename_match: - raise ValueError( - "Unable to determine the filename from the 'Content-Disposition' header." - ) + # For gzipped files, download to a .gz file first + if gzipped_checksum: + download_path = location_path / f"{file_name}.gz" + else: + # Check if the response headers contain the 'Content-Disposition' header + if 'Content-Disposition' not in response.headers: + raise ValueError( + "Unable to determine the filename. 'Content-Disposition' header not found." + ) + + # Extract the filename from the 'Content-Disposition' header + filename_match = re.search(r'filename="(.+)"', + response.headers.get("Content-Disposition")) + if not filename_match: + raise ValueError( + "Unable to determine the filename from the 'Content-Disposition' header." + ) + + filename = filename_match.group(1) + download_path = location_path / filename # Create the directory and any necessary parent directories location_path.mkdir(parents=True, exist_ok=True) - filename = filename_match.group(1) - file_path = location_path / filename - total_size = int(response.headers.get("Content-Length", 0)) checksum = hashlib.md5() # create checksum - with open(file_path, "wb") as file: + with open(download_path, "wb") as file: with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar: for data in response.iter_content(chunk_size=1024): @@ -112,18 +125,76 @@ def download_file_with_progress_bar(url: str, progress_bar.update(len(data)) downloaded_checksum = checksum.hexdigest() # Get the checksum value + + # If gzipped, verify gzipped checksum, extract, and verify extracted checksum + if gzipped_checksum: + if downloaded_checksum != gzipped_checksum: + warnings.warn( + f"Gzipped file checksum verification failed. Deleting '{download_path}'." + ) + download_path.unlink() + warnings.warn("Gzipped file deleted. Retrying download...") + return download_file_with_progress_bar(url, expected_checksum, + location, file_name, + retry_count + 1, + gzipped_checksum) + + print("Gzipped file checksum verified. Extracting...") + + # Extract the gzipped file + try: + with gzip.open(download_path, 'rb') as f_in: + with open(file_path, 'wb') as f_out: + while True: + chunk = f_in.read(8192) + if not chunk: + break + f_out.write(chunk) + except Exception as e: + warnings.warn( + f"Extraction failed: {e}. 
Deleting files and retrying...") + if download_path.exists(): + download_path.unlink() + if file_path.exists(): + file_path.unlink() + return download_file_with_progress_bar(url, expected_checksum, + location, file_name, + retry_count + 1, + gzipped_checksum) + + # Verify extracted file checksum + extracted_checksum = calculate_checksum(file_path) + if extracted_checksum != expected_checksum: + warnings.warn( + "Extracted file checksum verification failed. Deleting files.") + download_path.unlink() + file_path.unlink() + warnings.warn("Files deleted. Retrying download...") + return download_file_with_progress_bar(url, expected_checksum, + location, file_name, + retry_count + 1, + gzipped_checksum) + + # Clean up the gzipped file after successful extraction + download_path.unlink() + print(f"Extraction complete. Dataset saved in '{file_path}'") + return url + + # For non-gzipped files, verify checksum if downloaded_checksum != expected_checksum: - warnings.warn(f"Checksum verification failed. Deleting '{file_path}'.") - file_path.unlink() + warnings.warn( + f"Checksum verification failed. Deleting '{download_path}'.") + download_path.unlink() warnings.warn("File deleted. Retrying download...") # Retry download using a for loop for _ in range(retry_count + 1, _MAX_RETRY_COUNT + 1): return download_file_with_progress_bar(url, expected_checksum, location, file_name, - retry_count + 1) + retry_count + 1, + gzipped_checksum) else: - print(f"Download complete. Dataset saved in '{file_path}'") + print(f"Download complete. Dataset saved in '{download_path}'") return url @@ -141,128 +212,3 @@ def calculate_checksum(file_path: str) -> str: for chunk in iter(lambda: file.read(4096), b""): checksum.update(chunk) return checksum.hexdigest() - - -def download_and_extract_gzipped_file(url: str, - expected_checksum: str, - gzipped_checksum: str, - location: str, - file_name: str, - retry_count: int = 0) -> Optional[str]: - """Download a gzipped file from the given URL, verify checksums, and extract. - - This function downloads a gzipped file, verifies the checksum of the gzipped - file, extracts it, and then verifies the checksum of the extracted file. - - Args: - url: The URL to download the gzipped file from. - expected_checksum: The expected MD5 checksum of the extracted file. - gzipped_checksum: The expected MD5 checksum of the gzipped file. - location: The directory where the file will be saved. - file_name: The name of the final extracted file (without .gz extension). - retry_count: The number of retry attempts (default: 0). - - Returns: - The path of the extracted file if successful, None otherwise. - - Raises: - RuntimeError: If the maximum retry count is exceeded. - requests.HTTPError: If the download fails. - """ - - # Check if the final extracted file already exists with correct checksum - location_path = Path(location) - final_file_path = location_path / file_name - - if final_file_path.exists(): - existing_checksum = calculate_checksum(final_file_path) - if existing_checksum == expected_checksum: - return final_file_path - - if retry_count >= _MAX_RETRY_COUNT: - raise RuntimeError( - f"Exceeded maximum retry count ({_MAX_RETRY_COUNT}). 
" - f"Unable to download the file from {url}") - - # Create the directory and any necessary parent directories - location_path.mkdir(parents=True, exist_ok=True) - - # Download the gzipped file - gz_file_path = location_path / f"{file_name}.gz" - - response = requests.get(url, stream=True) - - # Check if the request was successful - if response.status_code != 200: - raise requests.HTTPError( - f"Error occurred while downloading the file. Response code: {response.status_code}" - ) - - total_size = int(response.headers.get("Content-Length", 0)) - checksum = hashlib.md5() # create checksum for gzipped file - - # Download the gzipped file - with open(gz_file_path, "wb") as file: - with tqdm.tqdm(total=total_size, - unit="B", - unit_scale=True, - desc="Downloading") as progress_bar: - for data in response.iter_content(chunk_size=1024): - file.write(data) - checksum.update(data) - progress_bar.update(len(data)) - - downloaded_gz_checksum = checksum.hexdigest() - - # Verify gzipped file checksum - if downloaded_gz_checksum != gzipped_checksum: - warnings.warn( - f"Gzipped file checksum verification failed. Deleting '{gz_file_path}'." - ) - gz_file_path.unlink() - warnings.warn("Gzipped file deleted. Retrying download...") - return download_and_extract_gzipped_file(url, expected_checksum, - gzipped_checksum, location, - file_name, retry_count + 1) - - print("Gzipped file checksum verified. Extracting...") - - # Extract the gzipped file - try: - with gzip.open(gz_file_path, 'rb') as f_in: - with open(final_file_path, 'wb') as f_out: - # Extract with progress (estimate based on typical compression ratio) - extracted_size = 0 - while True: - chunk = f_in.read(8192) - if not chunk: - break - f_out.write(chunk) - extracted_size += len(chunk) - except Exception as e: - warnings.warn(f"Extraction failed: {e}. Deleting files and retrying...") - if gz_file_path.exists(): - gz_file_path.unlink() - if final_file_path.exists(): - final_file_path.unlink() - return download_and_extract_gzipped_file(url, expected_checksum, - gzipped_checksum, location, - file_name, retry_count + 1) - - # Verify extracted file checksum - extracted_checksum = calculate_checksum(final_file_path) - if extracted_checksum != expected_checksum: - warnings.warn( - "Extracted file checksum verification failed. Deleting files.") - gz_file_path.unlink() - final_file_path.unlink() - warnings.warn("Files deleted. Retrying download...") - return download_and_extract_gzipped_file(url, expected_checksum, - gzipped_checksum, location, - file_name, retry_count + 1) - - # Clean up the gzipped file after successful extraction - gz_file_path.unlink() - - print(f"Extraction complete. Dataset saved in '{final_file_path}'") - return final_file_path diff --git a/cebra/data/base.py b/cebra/data/base.py index acdcff53..f5491e51 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -80,21 +80,12 @@ def __init__(self, "Missing data checksum. Please provide the checksum to verify the data integrity." 
) - # Use gzipped download if gzipped_checksum is provided - if self.gzipped_checksum is not None: - cebra_data_assets.download_and_extract_gzipped_file( - url=self.data_url, - expected_checksum=self.data_checksum, - gzipped_checksum=self.gzipped_checksum, - location=self.location, - file_name=self.file_name) - else: - # Fall back to legacy download for backward compatibility - cebra_data_assets.download_file_with_progress_bar( - url=self.data_url, - expected_checksum=self.data_checksum, - location=self.location, - file_name=self.file_name) + cebra_data_assets.download_file_with_progress_bar( + url=self.data_url, + expected_checksum=self.data_checksum, + location=self.location, + file_name=self.file_name, + gzipped_checksum=self.gzipped_checksum) @property @abc.abstractmethod diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 36aa77f6..88af686c 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -414,12 +414,12 @@ def test_download_and_extract_gzipped_file(): mock_response.iter_content = lambda chunk_size: [gzipped_content] # Test successful download and extraction - result = cebra_data_assets.download_and_extract_gzipped_file( + result = cebra_data_assets.download_file_with_progress_bar( url="http://example.com/test.jl.gz", expected_checksum=unzipped_checksum, - gzipped_checksum=gzipped_checksum, location=temp_dir, - file_name=test_filename) + file_name=test_filename, + gzipped_checksum=gzipped_checksum) # Verify the file was extracted assert result is not None @@ -455,13 +455,13 @@ def test_download_and_extract_gzipped_file_wrong_gzipped_checksum(): with pytest.raises(RuntimeError, match="Exceeded maximum retry count"): - cebra_data_assets.download_and_extract_gzipped_file( + cebra_data_assets.download_file_with_progress_bar( url="http://example.com/test.jl.gz", expected_checksum=hashlib.md5(test_content).hexdigest(), - gzipped_checksum=wrong_gz_checksum, location=temp_dir, file_name="test.jl", - retry_count=2) + retry_count=2, + gzipped_checksum=wrong_gz_checksum) def test_download_and_extract_gzipped_file_wrong_unzipped_checksum(): @@ -484,13 +484,13 @@ def test_download_and_extract_gzipped_file_wrong_unzipped_checksum(): with pytest.raises(RuntimeError, match="Exceeded maximum retry count"): - cebra_data_assets.download_and_extract_gzipped_file( + cebra_data_assets.download_file_with_progress_bar( url="http://example.com/test.jl.gz", expected_checksum=wrong_unzipped_checksum, - gzipped_checksum=gzipped_checksum, location=temp_dir, file_name="test.jl", - retry_count=2) + retry_count=2, + gzipped_checksum=gzipped_checksum) @pytest.mark.parametrize("neural, continuous, discrete", [
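---

For maintainers adding new entries to the URL registries above, the checksum pair can be generated with a short stdlib-only script along these lines. This is an illustrative sketch, not part of the series: the input path is hypothetical, and both hashes are MD5 to match calculate_checksum().

import gzip
import hashlib
import shutil
from pathlib import Path

raw_path = Path("achilles.jl")  # hypothetical local dataset file
gz_path = raw_path.with_suffix(raw_path.suffix + ".gz")

# Compress the raw file; this gzipped artifact is what gets uploaded.
with open(raw_path, "rb") as f_in, gzip.open(gz_path, "wb") as f_out:
    shutil.copyfileobj(f_in, f_out)

# "checksum" covers the extracted file, "gzipped_checksum" the download.
print('"checksum":', hashlib.md5(raw_path.read_bytes()).hexdigest())
print('"gzipped_checksum":', hashlib.md5(gz_path.read_bytes()).hexdigest())

Note that the gzip header embeds a timestamp, so the gzipped checksum is tied to the artifact produced by a particular compression run: always hash the exact .gz file that gets uploaded to the bucket.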
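After the third patch, all datasets go through the single entry point. A minimal usage sketch, reusing the "achilles" registry values shown above (the target directory is illustrative):

from cebra.data.assets import download_file_with_progress_bar

download_file_with_progress_bar(
    url="https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/achilles.jl.gz",
    expected_checksum="c52f9b55cbc23c66d57f3842214058b8",
    location="data/rat_hippocampus",  # illustrative download directory
    file_name="achilles.jl",
    gzipped_checksum="5d7b243e07b24c387e5412cd5ff46f0b")

Because gzipped_checksum is given, the call downloads achilles.jl.gz, verifies the gzipped hash, extracts to achilles.jl, verifies the extracted hash, and deletes the .gz file; omitting gzipped_checksum falls back to the legacy plain-file path.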