diff --git a/cebra/data/assets.py b/cebra/data/assets.py index adea8413..490b798a 100644 --- a/cebra/data/assets.py +++ b/cebra/data/assets.py @@ -20,6 +20,7 @@ # limitations under the License. # +import gzip import hashlib import re import warnings @@ -32,22 +33,31 @@ _MAX_RETRY_COUNT = 2 -def download_file_with_progress_bar(url: str, - expected_checksum: str, - location: str, - file_name: str, - retry_count: int = 0) -> Optional[str]: +def download_file_with_progress_bar( + url: str, + expected_checksum: str, + location: str, + file_name: str, + retry_count: int = 0, + gzipped_checksum: str = None) -> Optional[str]: """Download a file from the given URL. During download, progress is reported using a progress bar. The downloaded file's checksum is compared to the provided ``expected_checksum``. + If ``gzipped_checksum`` is provided, the file is expected to be gzipped. + The function will verify the gzipped file's checksum, extract it, and then + verify the extracted file's checksum. + Args: url: The URL to download the file from. - expected_checksum: The expected checksum value of the downloaded file. + expected_checksum: The expected checksum value of the downloaded file + (or extracted file if gzipped_checksum is provided). location: The directory where the file will be saved. file_name: The name of the file. retry_count: The number of retry attempts (default: 0). + gzipped_checksum: Optional MD5 checksum of the gzipped file. If provided, + the file will be extracted after download. Returns: The path of the downloaded file if the download is successful, None otherwise. @@ -78,30 +88,34 @@ def download_file_with_progress_bar(url: str, f"Error occurred while downloading the file. Response code: {response.status_code}" ) - # Check if the response headers contain the 'Content-Disposition' header - if 'Content-Disposition' not in response.headers: - raise ValueError( - "Unable to determine the filename. 'Content-Disposition' header not found." - ) - - # Extract the filename from the 'Content-Disposition' header - filename_match = re.search(r'filename="(.+)"', - response.headers.get("Content-Disposition")) - if not filename_match: - raise ValueError( - "Unable to determine the filename from the 'Content-Disposition' header." - ) + # For gzipped files, download to a .gz file first + if gzipped_checksum: + download_path = location_path / f"{file_name}.gz" + else: + # Check if the response headers contain the 'Content-Disposition' header + if 'Content-Disposition' not in response.headers: + raise ValueError( + "Unable to determine the filename. 'Content-Disposition' header not found." + ) + + # Extract the filename from the 'Content-Disposition' header + filename_match = re.search(r'filename="(.+)"', + response.headers.get("Content-Disposition")) + if not filename_match: + raise ValueError( + "Unable to determine the filename from the 'Content-Disposition' header." + ) + + filename = filename_match.group(1) + download_path = location_path / filename # Create the directory and any necessary parent directories location_path.mkdir(parents=True, exist_ok=True) - filename = filename_match.group(1) - file_path = location_path / filename - total_size = int(response.headers.get("Content-Length", 0)) checksum = hashlib.md5() # create checksum - with open(file_path, "wb") as file: + with open(download_path, "wb") as file: with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar: for data in response.iter_content(chunk_size=1024): @@ -111,18 +125,76 @@ def download_file_with_progress_bar(url: str, progress_bar.update(len(data)) downloaded_checksum = checksum.hexdigest() # Get the checksum value + + # If gzipped, verify gzipped checksum, extract, and verify extracted checksum + if gzipped_checksum: + if downloaded_checksum != gzipped_checksum: + warnings.warn( + f"Gzipped file checksum verification failed. Deleting '{download_path}'." + ) + download_path.unlink() + warnings.warn("Gzipped file deleted. Retrying download...") + return download_file_with_progress_bar(url, expected_checksum, + location, file_name, + retry_count + 1, + gzipped_checksum) + + print("Gzipped file checksum verified. Extracting...") + + # Extract the gzipped file + try: + with gzip.open(download_path, 'rb') as f_in: + with open(file_path, 'wb') as f_out: + while True: + chunk = f_in.read(8192) + if not chunk: + break + f_out.write(chunk) + except Exception as e: + warnings.warn( + f"Extraction failed: {e}. Deleting files and retrying...") + if download_path.exists(): + download_path.unlink() + if file_path.exists(): + file_path.unlink() + return download_file_with_progress_bar(url, expected_checksum, + location, file_name, + retry_count + 1, + gzipped_checksum) + + # Verify extracted file checksum + extracted_checksum = calculate_checksum(file_path) + if extracted_checksum != expected_checksum: + warnings.warn( + "Extracted file checksum verification failed. Deleting files.") + download_path.unlink() + file_path.unlink() + warnings.warn("Files deleted. Retrying download...") + return download_file_with_progress_bar(url, expected_checksum, + location, file_name, + retry_count + 1, + gzipped_checksum) + + # Clean up the gzipped file after successful extraction + download_path.unlink() + print(f"Extraction complete. Dataset saved in '{file_path}'") + return url + + # For non-gzipped files, verify checksum if downloaded_checksum != expected_checksum: - warnings.warn(f"Checksum verification failed. Deleting '{file_path}'.") - file_path.unlink() + warnings.warn( + f"Checksum verification failed. Deleting '{download_path}'.") + download_path.unlink() warnings.warn("File deleted. Retrying download...") # Retry download using a for loop for _ in range(retry_count + 1, _MAX_RETRY_COUNT + 1): return download_file_with_progress_bar(url, expected_checksum, location, file_name, - retry_count + 1) + retry_count + 1, + gzipped_checksum) else: - print(f"Download complete. Dataset saved in '{file_path}'") + print(f"Download complete. Dataset saved in '{download_path}'") return url diff --git a/cebra/data/base.py b/cebra/data/base.py index 51199cec..f5491e51 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -55,6 +55,7 @@ def __init__(self, download=False, data_url=None, data_checksum=None, + gzipped_checksum=None, location=None, file_name=None): @@ -64,6 +65,7 @@ def __init__(self, self.download = download self.data_url = data_url self.data_checksum = data_checksum + self.gzipped_checksum = gzipped_checksum self.location = location self.file_name = file_name @@ -82,7 +84,8 @@ def __init__(self, url=self.data_url, expected_checksum=self.data_checksum, location=self.location, - file_name=self.file_name) + file_name=self.file_name, + gzipped_checksum=self.gzipped_checksum) @property @abc.abstractmethod diff --git a/cebra/datasets/hippocampus.py b/cebra/datasets/hippocampus.py index 05c47acb..aa794d45 100644 --- a/cebra/datasets/hippocampus.py +++ b/cebra/datasets/hippocampus.py @@ -50,25 +50,33 @@ rat_dataset_urls = { "achilles": { "url": - "https://figshare.com/ndownloader/files/40849463?private_link=9f91576cbbcc8b0d8828", + "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/achilles.jl.gz", + "gzipped_checksum": + "5d7b243e07b24c387e5412cd5ff46f0b", "checksum": "c52f9b55cbc23c66d57f3842214058b8" }, "buddy": { "url": - "https://figshare.com/ndownloader/files/40849460?private_link=9f91576cbbcc8b0d8828", + "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/buddy.jl.gz", + "gzipped_checksum": + "339290585be2188f48a176f05aaf5df6", "checksum": "36341322907708c466871bf04bc133c2" }, "cicero": { "url": - "https://figshare.com/ndownloader/files/40849457?private_link=9f91576cbbcc8b0d8828", + "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/cicero.jl.gz", + "gzipped_checksum": + "f262a87d2e59f164cb404cd410015f3a", "checksum": "a83b02dbdc884fdd7e53df362499d42f" }, "gatsby": { "url": - "https://figshare.com/ndownloader/files/40849454?private_link=9f91576cbbcc8b0d8828", + "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/gatsby.jl.gz", + "gzipped_checksum": + "564e431c19e55db2286a9d64c86a94c4", "checksum": "2b889da48178b3155011c12555342813" } @@ -95,11 +103,13 @@ def __init__(self, name="achilles", root=_DEFAULT_DATADIR, download=True): location = pathlib.Path(root) / "rat_hippocampus" file_path = location / f"{name}.jl" - super().__init__(download=download, - data_url=rat_dataset_urls[name]["url"], - data_checksum=rat_dataset_urls[name]["checksum"], - location=location, - file_name=f"{name}.jl") + super().__init__( + download=download, + data_url=rat_dataset_urls[name]["url"], + data_checksum=rat_dataset_urls[name]["checksum"], + gzipped_checksum=rat_dataset_urls[name].get("gzipped_checksum"), + location=location, + file_name=f"{name}.jl") data = joblib.load(file_path) self.neural = torch.from_numpy(data["spikes"]).float() diff --git a/cebra/datasets/monkey_reaching.py b/cebra/datasets/monkey_reaching.py index 05071b12..080e83ae 100644 --- a/cebra/datasets/monkey_reaching.py +++ b/cebra/datasets/monkey_reaching.py @@ -160,73 +160,97 @@ def _get_info(trial_info, data): monkey_reaching_urls = { "all_all.jl": { "url": - "https://figshare.com/ndownloader/files/41668764?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_all.jl.gz", + "gzipped_checksum": + "399abc6e9ef0b23a0d6d057c6f508939", "checksum": "dea556301fa4fafa86e28cf8621cab5a" }, "all_train.jl": { "url": - "https://figshare.com/ndownloader/files/41668752?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_train.jl.gz", + "gzipped_checksum": + "eb52c8641fe83ae2a278b372ddec5f69", "checksum": "e280e4cd86969e6fd8bfd3a8f402b2fe" }, "all_test.jl": { "url": - "https://figshare.com/ndownloader/files/41668761?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_test.jl.gz", + "gzipped_checksum": + "7688245cf15e0b92503af943ce9f66aa", "checksum": "25d3ff2c15014db8b8bf2543482ae881" }, "all_valid.jl": { "url": - "https://figshare.com/ndownloader/files/41668755?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_valid.jl.gz", + "gzipped_checksum": + "b169fc008b4d092fe2a1b7e006cd17a7", "checksum": "8cd25169d31f83ae01b03f7b1b939723" }, "active_all.jl": { "url": - "https://figshare.com/ndownloader/files/41668776?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_all.jl.gz", + "gzipped_checksum": + "b7b86e2ae00bb71341de8fc352dae097", "checksum": "c626acea5062122f5a68ef18d3e45e51" }, "active_train.jl": { "url": - "https://figshare.com/ndownloader/files/41668770?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_train.jl.gz", + "gzipped_checksum": + "56687c633efcbff6c56bbcfa35597565", "checksum": "72a48056691078eee22c36c1992b1d37" }, "active_test.jl": { "url": - "https://figshare.com/ndownloader/files/41668773?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_test.jl.gz", + "gzipped_checksum": + "2057ef1846908a69486a61895d1198e8", "checksum": "35b7e060008a8722c536584c4748f2ea" }, "active_valid.jl": { "url": - "https://figshare.com/ndownloader/files/41668767?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_valid.jl.gz", + "gzipped_checksum": + "60b8e418f234877351fe36f1efc169ad", "checksum": "dd58eb1e589361b4132f34b22af56b79" }, "passive_all.jl": { "url": - "https://figshare.com/ndownloader/files/41668758?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_all.jl.gz", + "gzipped_checksum": + "afb257efa0cac3ccd69ec80478d63691", "checksum": "bbb1bc9d8eec583a46f6673470fc98ad" }, "passive_train.jl": { "url": - "https://figshare.com/ndownloader/files/41668743?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_train.jl.gz", + "gzipped_checksum": + "24d98d7d41a52591f838c41fe83dc2c6", "checksum": "f22e05a69f70e18ba823a0a89162a45c" }, "passive_test.jl": { "url": - "https://figshare.com/ndownloader/files/41668746?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_test.jl.gz", + "gzipped_checksum": + "f1ff4e9b7c4a0d7fa9dcd271893f57ab", "checksum": "42453ae3e4fd27d82d297f78c13cd6b7" }, "passive_valid.jl": { "url": - "https://figshare.com/ndownloader/files/41668749?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_valid.jl.gz", + "gzipped_checksum": + "311fcb6a3e86022f12d78828f7bd29d5", "checksum": "2dcc10c27631b95a075eaa2d2297bb4a" } @@ -270,6 +294,8 @@ def __init__(self, data_url=monkey_reaching_urls[f"{self.load_session}_all.jl"]["url"], data_checksum=monkey_reaching_urls[f"{self.load_session}_all.jl"] ["checksum"], + gzipped_checksum=monkey_reaching_urls[f"{self.load_session}_all.jl"] + .get("gzipped_checksum"), location=self.path, file_name=f"{self.load_session}_all.jl", ) @@ -297,6 +323,8 @@ def split(self, split): ["url"], data_checksum=monkey_reaching_urls[ f"{self.load_session}_{split}.jl"]["checksum"], + gzipped_checksum=monkey_reaching_urls[ + f"{self.load_session}_{split}.jl"].get("gzipped_checksum"), location=self.path, file_name=f"{self.load_session}_{split}.jl", ) diff --git a/cebra/datasets/synthetic_data.py b/cebra/datasets/synthetic_data.py index 9288a93d..dc65ff0a 100644 --- a/cebra/datasets/synthetic_data.py +++ b/cebra/datasets/synthetic_data.py @@ -33,49 +33,65 @@ synthetic_data_urls = { "continuous_label_refractory_poisson": { "url": - "https://figshare.com/ndownloader/files/41668815?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_refractory_poisson.jl.gz", + "gzipped_checksum": + "3641eed973b9cae972493c70b364e981", "checksum": "fcd92bd283c528d5294093190f55ceba" }, "continuous_label_t": { "url": - "https://figshare.com/ndownloader/files/41668818?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_t.jl.gz", + "gzipped_checksum": + "1dc8805e8f0836c7c99e864100a65bff", "checksum": "a6e76f274da571568fd2a4bf4cf48b66" }, "continuous_label_uniform": { "url": - "https://figshare.com/ndownloader/files/41668821?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_uniform.jl.gz", + "gzipped_checksum": + "71d33bc56b89bc227da0990bf16e584b", "checksum": "e67400e77ac009e8c9bc958aa5151973" }, "continuous_label_laplace": { "url": - "https://figshare.com/ndownloader/files/41668824?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_laplace.jl.gz", + "gzipped_checksum": + "1563e4958031392d2b2e30cc4cd79b3f", "checksum": "41d7ce4ce8901ae7a5136605ac3f5ffb" }, "continuous_label_poisson": { "url": - "https://figshare.com/ndownloader/files/41668827?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_poisson.jl.gz", + "gzipped_checksum": + "7691304ee061e0bf1e9bb5f2bb6b20e7", "checksum": "a789828f9cca5f3faf36d62ebc4cc8a1" }, "continuous_label_gaussian": { "url": - "https://figshare.com/ndownloader/files/41668830?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_gaussian.jl.gz", + "gzipped_checksum": + "0cb97a2c1eaa526e57d2248a333ea8e0", "checksum": "18d66a2020923e2cd67d2264d20890aa" }, "continuous_poisson_gaussian_noise": { "url": - "https://figshare.com/ndownloader/files/41668833?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_poisson_gaussian_noise.jl.gz", + "gzipped_checksum": + "5aa6b6eadf2b733562864d5b67bc6b8d", "checksum": "1a51461820c24a5bcaddaff3991f0ebe" }, "sim_100d_poisson_cont_label": { "url": - "https://figshare.com/ndownloader/files/41668836?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/sim_100d_poisson_cont_label.npz.gz", + "gzipped_checksum": + "768299435a167dedd57e29b1a6d5af63", "checksum": "306b9c646e7b76a52cfd828612d700cb" } @@ -98,11 +114,13 @@ def __init__(self, name, root=_DEFAULT_DATADIR, download=True): location = os.path.join(root, "synthetic") file_path = os.path.join(location, f"{name}.jl") - super().__init__(download=download, - data_url=synthetic_data_urls[name]["url"], - data_checksum=synthetic_data_urls[name]["checksum"], - location=location, - file_name=f"{name}.jl") + super().__init__( + download=download, + data_url=synthetic_data_urls[name]["url"], + data_checksum=synthetic_data_urls[name]["checksum"], + gzipped_checksum=synthetic_data_urls[name].get("gzipped_checksum"), + location=location, + file_name=f"{name}.jl") data = joblib.load(file_path) self.data = data #NOTE: making it backwards compatible with synth notebook. diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e8e03ff0..88af686c 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import hashlib import os import pathlib import tempfile @@ -384,6 +385,114 @@ def test_download_file_wrong_content_disposition(filename, url, file_name=filename) +def test_download_and_extract_gzipped_file(): + """Test downloading and extracting gzipped files with dual checksum verification.""" + import gzip + + with tempfile.TemporaryDirectory() as temp_dir: + # Create a test file + test_content = b"Test dataset content for gzipped download" + test_filename = "test_dataset.jl" + test_gz_filename = f"{test_filename}.gz" + + # Calculate checksums + unzipped_checksum = cebra_data_assets.calculate_checksum.__wrapped__(test_content) \ + if hasattr(cebra_data_assets.calculate_checksum, '__wrapped__') \ + else hashlib.md5(test_content).hexdigest() + + # Create gzipped content + gzipped_content = gzip.compress(test_content) + gzipped_checksum = hashlib.md5(gzipped_content).hexdigest() + + # Mock the HTTP response + with patch("requests.get") as mock_get: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.headers = { + "Content-Length": str(len(gzipped_content)) + } + mock_response.iter_content = lambda chunk_size: [gzipped_content] + + # Test successful download and extraction + result = cebra_data_assets.download_file_with_progress_bar( + url="http://example.com/test.jl.gz", + expected_checksum=unzipped_checksum, + location=temp_dir, + file_name=test_filename, + gzipped_checksum=gzipped_checksum) + + # Verify the file was extracted + assert result is not None + final_path = os.path.join(temp_dir, test_filename) + assert os.path.exists(final_path) + + # Verify the content is correct + with open(final_path, 'rb') as f: + extracted_content = f.read() + assert extracted_content == test_content + + # Verify the .gz file was cleaned up + gz_path = os.path.join(temp_dir, test_gz_filename) + assert not os.path.exists(gz_path) + + +def test_download_and_extract_gzipped_file_wrong_gzipped_checksum(): + """Test that wrong gzipped checksum raises error after retries.""" + import gzip + + with tempfile.TemporaryDirectory() as temp_dir: + test_content = b"Test content" + gzipped_content = gzip.compress(test_content) + wrong_gz_checksum = "0" * 32 # Wrong checksum + + with patch("requests.get") as mock_get: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.headers = { + "Content-Length": str(len(gzipped_content)) + } + mock_response.iter_content = lambda chunk_size: [gzipped_content] + + with pytest.raises(RuntimeError, + match="Exceeded maximum retry count"): + cebra_data_assets.download_file_with_progress_bar( + url="http://example.com/test.jl.gz", + expected_checksum=hashlib.md5(test_content).hexdigest(), + location=temp_dir, + file_name="test.jl", + retry_count=2, + gzipped_checksum=wrong_gz_checksum) + + +def test_download_and_extract_gzipped_file_wrong_unzipped_checksum(): + """Test that wrong unzipped checksum raises error after retries.""" + import gzip + + with tempfile.TemporaryDirectory() as temp_dir: + test_content = b"Test content" + gzipped_content = gzip.compress(test_content) + gzipped_checksum = hashlib.md5(gzipped_content).hexdigest() + wrong_unzipped_checksum = "0" * 32 # Wrong checksum + + with patch("requests.get") as mock_get: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.headers = { + "Content-Length": str(len(gzipped_content)) + } + mock_response.iter_content = lambda chunk_size: [gzipped_content] + + with pytest.raises(RuntimeError, + match="Exceeded maximum retry count"): + cebra_data_assets.download_file_with_progress_bar( + url="http://example.com/test.jl.gz", + expected_checksum=wrong_unzipped_checksum, + location=temp_dir, + file_name="test.jl", + retry_count=2, + gzipped_checksum=gzipped_checksum) + + @pytest.mark.parametrize("neural, continuous, discrete", [ (np.random.randn(100, 30), np.random.randn( 100, 2), np.random.randint(0, 5, (100,))),