Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 99 additions & 27 deletions cebra/data/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# limitations under the License.
#

import gzip
import hashlib
import re
import warnings
Expand All @@ -32,22 +33,31 @@
_MAX_RETRY_COUNT = 2


def download_file_with_progress_bar(url: str,
expected_checksum: str,
location: str,
file_name: str,
retry_count: int = 0) -> Optional[str]:
def download_file_with_progress_bar(
url: str,
expected_checksum: str,
location: str,
file_name: str,
retry_count: int = 0,
gzipped_checksum: str = None) -> Optional[str]:
"""Download a file from the given URL.

During download, progress is reported using a progress bar. The downloaded
file's checksum is compared to the provided ``expected_checksum``.

If ``gzipped_checksum`` is provided, the file is expected to be gzipped.
The function will verify the gzipped file's checksum, extract it, and then
verify the extracted file's checksum.

Args:
url: The URL to download the file from.
expected_checksum: The expected checksum value of the downloaded file.
expected_checksum: The expected checksum value of the downloaded file
(or extracted file if gzipped_checksum is provided).
location: The directory where the file will be saved.
file_name: The name of the file.
retry_count: The number of retry attempts (default: 0).
gzipped_checksum: Optional MD5 checksum of the gzipped file. If provided,
the file will be extracted after download.

Returns:
The path of the downloaded file if the download is successful, None otherwise.
Expand Down Expand Up @@ -78,30 +88,34 @@ def download_file_with_progress_bar(url: str,
f"Error occurred while downloading the file. Response code: {response.status_code}"
)

# Check if the response headers contain the 'Content-Disposition' header
if 'Content-Disposition' not in response.headers:
raise ValueError(
"Unable to determine the filename. 'Content-Disposition' header not found."
)

# Extract the filename from the 'Content-Disposition' header
filename_match = re.search(r'filename="(.+)"',
response.headers.get("Content-Disposition"))
if not filename_match:
raise ValueError(
"Unable to determine the filename from the 'Content-Disposition' header."
)
# For gzipped files, download to a .gz file first
if gzipped_checksum:
download_path = location_path / f"{file_name}.gz"
else:
# Check if the response headers contain the 'Content-Disposition' header
if 'Content-Disposition' not in response.headers:
raise ValueError(
"Unable to determine the filename. 'Content-Disposition' header not found."
)

# Extract the filename from the 'Content-Disposition' header
filename_match = re.search(r'filename="(.+)"',
response.headers.get("Content-Disposition"))
if not filename_match:
raise ValueError(
"Unable to determine the filename from the 'Content-Disposition' header."
)

filename = filename_match.group(1)
download_path = location_path / filename

# Create the directory and any necessary parent directories
location_path.mkdir(parents=True, exist_ok=True)

filename = filename_match.group(1)
file_path = location_path / filename

total_size = int(response.headers.get("Content-Length", 0))
checksum = hashlib.md5() # create checksum

with open(file_path, "wb") as file:
with open(download_path, "wb") as file:
with tqdm.tqdm(total=total_size, unit="B",
unit_scale=True) as progress_bar:
for data in response.iter_content(chunk_size=1024):
Expand All @@ -111,18 +125,76 @@ def download_file_with_progress_bar(url: str,
progress_bar.update(len(data))

downloaded_checksum = checksum.hexdigest() # Get the checksum value

# If gzipped, verify gzipped checksum, extract, and verify extracted checksum
if gzipped_checksum:
if downloaded_checksum != gzipped_checksum:
warnings.warn(
f"Gzipped file checksum verification failed. Deleting '{download_path}'."
)
download_path.unlink()
warnings.warn("Gzipped file deleted. Retrying download...")
return download_file_with_progress_bar(url, expected_checksum,
location, file_name,
retry_count + 1,
gzipped_checksum)

print("Gzipped file checksum verified. Extracting...")

# Extract the gzipped file
try:
with gzip.open(download_path, 'rb') as f_in:
with open(file_path, 'wb') as f_out:
while True:
chunk = f_in.read(8192)
if not chunk:
break
f_out.write(chunk)
except Exception as e:
warnings.warn(
f"Extraction failed: {e}. Deleting files and retrying...")
if download_path.exists():
download_path.unlink()
if file_path.exists():
file_path.unlink()
return download_file_with_progress_bar(url, expected_checksum,
location, file_name,
retry_count + 1,
gzipped_checksum)

# Verify extracted file checksum
extracted_checksum = calculate_checksum(file_path)
if extracted_checksum != expected_checksum:
warnings.warn(
"Extracted file checksum verification failed. Deleting files.")
download_path.unlink()
file_path.unlink()
warnings.warn("Files deleted. Retrying download...")
return download_file_with_progress_bar(url, expected_checksum,
location, file_name,
retry_count + 1,
gzipped_checksum)

# Clean up the gzipped file after successful extraction
download_path.unlink()
print(f"Extraction complete. Dataset saved in '{file_path}'")
return url

# For non-gzipped files, verify checksum
if downloaded_checksum != expected_checksum:
warnings.warn(f"Checksum verification failed. Deleting '{file_path}'.")
file_path.unlink()
warnings.warn(
f"Checksum verification failed. Deleting '{download_path}'.")
download_path.unlink()
warnings.warn("File deleted. Retrying download...")

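# Retry the download if attempts remain. The loop body returns on its first
# iteration, so this makes at most one recursive call per invocation.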
for _ in range(retry_count + 1, _MAX_RETRY_COUNT + 1):
return download_file_with_progress_bar(url, expected_checksum,
location, file_name,
retry_count + 1)
retry_count + 1,
gzipped_checksum)
else:
print(f"Download complete. Dataset saved in '{file_path}'")
print(f"Download complete. Dataset saved in '{download_path}'")
return url


Expand Down
5 changes: 4 additions & 1 deletion cebra/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self,
download=False,
data_url=None,
data_checksum=None,
gzipped_checksum=None,
location=None,
file_name=None):

Expand All @@ -64,6 +65,7 @@ def __init__(self,
self.download = download
self.data_url = data_url
self.data_checksum = data_checksum
self.gzipped_checksum = gzipped_checksum
self.location = location
self.file_name = file_name

Expand All @@ -82,7 +84,8 @@ def __init__(self,
url=self.data_url,
expected_checksum=self.data_checksum,
location=self.location,
file_name=self.file_name)
file_name=self.file_name,
gzipped_checksum=self.gzipped_checksum)

@property
@abc.abstractmethod
Expand Down
28 changes: 19 additions & 9 deletions cebra/datasets/hippocampus.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,33 @@
rat_dataset_urls = {
"achilles": {
"url":
"https://figshare.com/ndownloader/files/40849463?private_link=9f91576cbbcc8b0d8828",
"https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/achilles.jl.gz",
"gzipped_checksum":
"5d7b243e07b24c387e5412cd5ff46f0b",
"checksum":
"c52f9b55cbc23c66d57f3842214058b8"
},
"buddy": {
"url":
"https://figshare.com/ndownloader/files/40849460?private_link=9f91576cbbcc8b0d8828",
"https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/buddy.jl.gz",
"gzipped_checksum":
"339290585be2188f48a176f05aaf5df6",
"checksum":
"36341322907708c466871bf04bc133c2"
},
"cicero": {
"url":
"https://figshare.com/ndownloader/files/40849457?private_link=9f91576cbbcc8b0d8828",
"https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/cicero.jl.gz",
"gzipped_checksum":
"f262a87d2e59f164cb404cd410015f3a",
"checksum":
"a83b02dbdc884fdd7e53df362499d42f"
},
"gatsby": {
"url":
"https://figshare.com/ndownloader/files/40849454?private_link=9f91576cbbcc8b0d8828",
"https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/gatsby.jl.gz",
"gzipped_checksum":
"564e431c19e55db2286a9d64c86a94c4",
"checksum":
"2b889da48178b3155011c12555342813"
}
Expand All @@ -95,11 +103,13 @@ def __init__(self, name="achilles", root=_DEFAULT_DATADIR, download=True):
location = pathlib.Path(root) / "rat_hippocampus"
file_path = location / f"{name}.jl"

super().__init__(download=download,
data_url=rat_dataset_urls[name]["url"],
data_checksum=rat_dataset_urls[name]["checksum"],
location=location,
file_name=f"{name}.jl")
super().__init__(
download=download,
data_url=rat_dataset_urls[name]["url"],
data_checksum=rat_dataset_urls[name]["checksum"],
gzipped_checksum=rat_dataset_urls[name].get("gzipped_checksum"),
location=location,
file_name=f"{name}.jl")

data = joblib.load(file_path)
self.neural = torch.from_numpy(data["spikes"]).float()
Expand Down
52 changes: 40 additions & 12 deletions cebra/datasets/monkey_reaching.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,73 +160,97 @@ def _get_info(trial_info, data):
monkey_reaching_urls = {
"all_all.jl": {
"url":
"https://figshare.com/ndownloader/files/41668764?private_link=6fa4ee74a8f465ec7914",
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_all.jl.gz",
"gzipped_checksum":
"399abc6e9ef0b23a0d6d057c6f508939",
"checksum":
"dea556301fa4fafa86e28cf8621cab5a"
},
"all_train.jl": {
"url":
"https://figshare.com/ndownloader/files/41668752?private_link=6fa4ee74a8f465ec7914",
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_train.jl.gz",
"gzipped_checksum":
"eb52c8641fe83ae2a278b372ddec5f69",
"checksum":
"e280e4cd86969e6fd8bfd3a8f402b2fe"
},
"all_test.jl": {
"url":
"https://figshare.com/ndownloader/files/41668761?private_link=6fa4ee74a8f465ec7914",
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_test.jl.gz",
"gzipped_checksum":
"7688245cf15e0b92503af943ce9f66aa",
"checksum":
"25d3ff2c15014db8b8bf2543482ae881"
},
"all_valid.jl": {
"url":
"https://figshare.com/ndownloader/files/41668755?private_link=6fa4ee74a8f465ec7914",
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_valid.jl.gz",
"gzipped_checksum":
"b169fc008b4d092fe2a1b7e006cd17a7",
"checksum":
"8cd25169d31f83ae01b03f7b1b939723"
},
"active_all.jl": {
"url":
"https://figshare.com/ndownloader/files/41668776?private_link=6fa4ee74a8f465ec7914",
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_all.jl.gz",
"gzipped_checksum":
"b7b86e2ae00bb71341de8fc352dae097",
"checksum":
"c626acea5062122f5a68ef18d3e45e51"
},
"active_train.jl": {
"url":
"https://figshare.com/ndownloader/files/41668770?private_link=6fa4ee74a8f465ec7914",
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_train.jl.gz",
"gzipped_checksum":
"56687c633efcbff6c56bbcfa35597565",
"checksum":
"72a48056691078eee22c36c1992b1d37"
},
"active_test.jl": {
"url":
"https://figshare.com/ndownloader/files/41668773?private_link=6fa4ee74a8f465ec7914",
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_test.jl.gz",
"gzipped_checksum":
"2057ef1846908a69486a61895d1198e8",
"checksum":
"35b7e060008a8722c536584c4748f2ea"
},
"active_valid.jl": {
"url":
"https://figshare.com/ndownloader/files/41668767?private_link=6fa4ee74a8f465ec7914",
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_valid.jl.gz",
"gzipped_checksum":
"60b8e418f234877351fe36f1efc169ad",
"checksum":
"dd58eb1e589361b4132f34b22af56b79"
},
"passive_all.jl": {
"url":
"https://figshare.com/ndownloader/files/41668758?private_link=6fa4ee74a8f465ec7914",
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_all.jl.gz",
"gzipped_checksum":
"afb257efa0cac3ccd69ec80478d63691",
"checksum":
"bbb1bc9d8eec583a46f6673470fc98ad"
},
"passive_train.jl": {
"url":
"https://figshare.com/ndownloader/files/41668743?private_link=6fa4ee74a8f465ec7914",
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_train.jl.gz",
"gzipped_checksum":
"24d98d7d41a52591f838c41fe83dc2c6",
"checksum":
"f22e05a69f70e18ba823a0a89162a45c"
},
"passive_test.jl": {
"url":
"https://figshare.com/ndownloader/files/41668746?private_link=6fa4ee74a8f465ec7914",
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_test.jl.gz",
"gzipped_checksum":
"f1ff4e9b7c4a0d7fa9dcd271893f57ab",
"checksum":
"42453ae3e4fd27d82d297f78c13cd6b7"
},
"passive_valid.jl": {
"url":
"https://figshare.com/ndownloader/files/41668749?private_link=6fa4ee74a8f465ec7914",
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_valid.jl.gz",
"gzipped_checksum":
"311fcb6a3e86022f12d78828f7bd29d5",
"checksum":
"2dcc10c27631b95a075eaa2d2297bb4a"
}
Expand Down Expand Up @@ -270,6 +294,8 @@ def __init__(self,
data_url=monkey_reaching_urls[f"{self.load_session}_all.jl"]["url"],
data_checksum=monkey_reaching_urls[f"{self.load_session}_all.jl"]
["checksum"],
gzipped_checksum=monkey_reaching_urls[f"{self.load_session}_all.jl"]
.get("gzipped_checksum"),
location=self.path,
file_name=f"{self.load_session}_all.jl",
)
Expand Down Expand Up @@ -297,6 +323,8 @@ def split(self, split):
["url"],
data_checksum=monkey_reaching_urls[
f"{self.load_session}_{split}.jl"]["checksum"],
gzipped_checksum=monkey_reaching_urls[
f"{self.load_session}_{split}.jl"].get("gzipped_checksum"),
location=self.path,
file_name=f"{self.load_session}_{split}.jl",
)
Expand Down
Loading