diff --git a/alembic/versions/20260112_232724_ea6989325f62_s3_upload_delegation.py b/alembic/versions/20260112_232724_ea6989325f62_s3_upload_delegation.py new file mode 100644 index 00000000..e0ecd615 --- /dev/null +++ b/alembic/versions/20260112_232724_ea6989325f62_s3_upload_delegation.py @@ -0,0 +1,55 @@ +"""s3_upload_delegation + +Revision ID: ea6989325f62 +Revises: 642ef45b49c6 +Create Date: 2026-01-12 23:27:24.210332 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from alembic_postgresql_enum import TableReference +from sqlalchemy.dialects import postgresql + +from sqlalchemy import Text +import app.db.types + +# revision identifiers, used by Alembic. +revision: str = "ea6989325f62" +down_revision: Union[str, None] = "642ef45b49c6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "asset", sa.Column("upload_meta", postgresql.JSONB(astext_type=sa.Text()), nullable=True) + ) + op.sync_enum_values( + enum_schema="public", + enum_name="assetstatus", + new_values=["CREATED", "UPLOADING", "DELETED"], + affected_columns=[ + TableReference(table_schema="public", table_name="asset", column_name="status") + ], + enum_values_to_rename=[], + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values( + enum_schema="public", + enum_name="assetstatus", + new_values=["CREATED", "DELETED"], + affected_columns=[ + TableReference(table_schema="public", table_name="asset", column_name="status") + ], + enum_values_to_rename=[], + ) + op.drop_column("asset", "upload_meta") + # ### end Alembic commands ### diff --git a/alembic/versions/20260112_232729_4d9640ae6ba0_update_triggers.py b/alembic/versions/20260112_232729_4d9640ae6ba0_update_triggers.py new file mode 100644 index 00000000..be9abe92 --- /dev/null +++ b/alembic/versions/20260112_232729_4d9640ae6ba0_update_triggers.py @@ -0,0 +1,54 @@ +"""Update triggers + +Revision ID: 4d9640ae6ba0 +Revises: ea6989325f62 +Create Date: 2026-01-12 23:27:29.559400 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +from sqlalchemy import Text +import app.db.types + +# revision identifiers, used by Alembic. +revision: str = "4d9640ae6ba0" +down_revision: Union[str, None] = "ea6989325f62" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_index( + "ix_asset_full_path", + "asset", + ["full_path"], + unique=True, + postgresql_where=sa.text("status != 'DELETED'"), + ) + op.create_index( + "uq_asset_entity_id_path", + "asset", + ["path", "entity_id"], + unique=True, + postgresql_where=sa.text("status != 'DELETED'"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index( + "uq_asset_entity_id_path", + table_name="asset", + postgresql_where=sa.text("status != 'DELETED'"), + ) + op.drop_index( + "ix_asset_full_path", table_name="asset", postgresql_where=sa.text("status != 'DELETED'") + ) + # ### end Alembic commands ### diff --git a/app/config.py b/app/config.py index 0c694ab2..e437684a 100644 --- a/app/config.py +++ b/app/config.py @@ -48,6 +48,13 @@ class Settings(BaseSettings): S3_MULTIPART_THRESHOLD: int = 5 * 1024**2 # bytes # TODO: decide an appropriate value S3_PRESIGNED_URL_EXPIRATION: int = 600 # seconds # TODO: decide an appropriate value + S3_MULTIPART_UPLOAD_MAX_SIZE: int = 5 * 1024**3 # TODO: Set appropriate upper file size limit + S3_MULTIPART_UPLOAD_MIN_PART_SIZE: int = 5 * 1024**2 + S3_MULTIPART_UPLOAD_MAX_PART_SIZE: int = 5 * 1024**3 + S3_MULTIPART_UPLOAD_MIN_PARTS: int = 1 + S3_MULTIPART_UPLOAD_MAX_PARTS: int = 10_000 + S3_MULTIPART_UPLOAD_DEFAULT_PARTS: int = 100 + API_ASSET_POST_MAX_SIZE: int = 150 * 1024**2 # bytes # TODO: decide an appropriate value DB_ENGINE: str = "postgresql+psycopg2" diff --git a/app/db/events.py b/app/db/events.py index 4802c2a7..c33aad83 100644 --- a/app/db/events.py +++ b/app/db/events.py @@ -3,8 +3,9 @@ from sqlalchemy.orm.session import object_session from app.db.model import Asset +from app.db.types import AssetStatus from app.logger import L -from app.utils.s3 import delete_asset_storage_object, get_s3_client +from app.utils.s3 import delete_asset_storage_object, get_s3_client, multipart_upload_abort ASSETS_TO_DELETE_KEY = "assets_to_delete_from_storage" @@ -22,22 +23,55 @@ def collect_asset_for_storage_deletion(_mapper, _connection, target: Asset): @event.listens_for(Session, "after_commit") def delete_assets_from_storage(session: Session): - """Delete storage objects for assets removed in a committed transaction.""" + """Delete storage objects for assets removed in a committed transaction. + + Note: Because this listener iterates over all collected assets, it must not raise even if one + of the external side effects fails. Otherwise, after the rollback there might be db assets + that are not deleted while their s3 files are. + + By capturing the errors instead, it is ensured that db assets are always deleted, even if + that may leave behind orphan files or multipart uploads that failed to be aborted. + + TODO: Add a scheduled cleanup task that removes s3 orphans from time to time. 
+ """ to_delete = session.info.pop(ASSETS_TO_DELETE_KEY, set()) for asset in to_delete: - try: - delete_asset_storage_object( - storage_type=asset.storage_type, - s3_key=asset.full_path, - storage_client_factory=get_s3_client, - ) - except Exception: # noqa: BLE001 - L.exception( - "Failed to delete storage object for Asset id={} full_path={} storage_type={}", - asset.id, - asset.full_path, - asset.storage_type, - ) + match asset.status: + case AssetStatus.UPLOADING: + try: + multipart_upload_abort( + upload_id=asset.upload_meta["upload_id"], + storage_type=asset.storage_type, + s3_key=asset.full_path, + storage_client_factory=get_s3_client, + ) + except Exception: # noqa: BLE001 + L.exception( + ( + "Failed to abort multipart upload for Asset " + "id={} full_path={} storage_type={}" + ), + asset.id, + asset.full_path, + asset.storage_type, + ) + case _: + try: + delete_asset_storage_object( + storage_type=asset.storage_type, + s3_key=asset.full_path, + storage_client_factory=get_s3_client, + ) + except Exception: # noqa: BLE001 + L.exception( + ( + "Failed to delete storage object for Asset " + "id={} full_path={} storage_type={}" + ), + asset.id, + asset.full_path, + asset.storage_type, + ) @event.listens_for(Session, "after_rollback") diff --git a/app/db/model.py b/app/db/model.py index 41cc089d..34175672 100644 --- a/app/db/model.py +++ b/app/db/model.py @@ -1438,6 +1438,8 @@ class Asset(Identifiable): ) storage_type: Mapped[StorageType] + upload_meta: Mapped[JSON_DICT | None] + # partial unique index __table_args__ = ( Index( diff --git a/app/db/types.py b/app/db/types.py index 19f00a83..fea26990 100644 --- a/app/db/types.py +++ b/app/db/types.py @@ -229,6 +229,7 @@ class AnnotationBodyType(StrEnum): class AssetStatus(StrEnum): CREATED = auto() + UPLOADING = auto() DELETED = auto() diff --git a/app/errors.py b/app/errors.py index 08595ba7..4f80366e 100644 --- a/app/errors.py +++ b/app/errors.py @@ -46,6 +46,8 @@ class ApiErrorCode(UpperStrEnum): ASSET_NOT_A_DIRECTORY = auto() ASSET_INVALID_SCHEMA = auto() ASSET_INVALID_CONTENT_TYPE = auto() + ASSET_UPLOAD_INCOMPLETE = auto() + ASSET_NOT_UPLOADING = auto() ION_NAME_NOT_FOUND = auto() S3_CANNOT_CREATE_PRESIGNED_URL = auto() OPENAI_API_KEY_MISSING = auto() diff --git a/app/repository/asset.py b/app/repository/asset.py index 73881470..fffa8957 100644 --- a/app/repository/asset.py +++ b/app/repository/asset.py @@ -37,11 +37,13 @@ def get_entity_asset( return self.db.execute(query).scalar_one() - def create_entity_asset(self, entity_id: uuid.UUID, asset: AssetCreate) -> Asset: + def create_entity_asset( + self, entity_id: uuid.UUID, asset: AssetCreate, status: AssetStatus = AssetStatus.CREATED + ) -> Asset: """Create an asset associated with the given entity.""" sha256_digest = bytes.fromhex(asset.sha256_digest) if asset.sha256_digest else None db_asset = Asset( - status=AssetStatus.CREATED, + status=status, entity_id=entity_id, path=asset.path, full_path=asset.full_path, diff --git a/app/routers/asset.py b/app/routers/asset.py index 851c382e..93f6c573 100644 --- a/app/routers/asset.py +++ b/app/routers/asset.py @@ -9,7 +9,7 @@ from starlette.responses import RedirectResponse from app.config import settings, storages -from app.db.types import AssetLabel, ContentType, StorageType +from app.db.types import AssetLabel, AssetStatus, ContentType, StorageType from app.dependencies.auth import UserContextDep, UserContextWithProjectIdDep from app.dependencies.common import PaginationQuery from app.dependencies.db import RepoGroupDep @@ -22,7 +22,9 
@@ AssetRegister, DetailedFileList, DirectoryUpload, + InitiateUploadRequest, ) +from app.schemas.auth import UserContext from app.schemas.types import ListResponse from app.service import asset as asset_service from app.utils.files import calculate_sha256_digest, get_content_type @@ -34,6 +36,7 @@ upload_to_s3, validate_filename, validate_filesize, + validate_multipart_filesize, ) router = APIRouter( @@ -225,6 +228,12 @@ def download_entity_asset( entity_id=entity_id, asset_id=asset_id, ) + if asset.status == AssetStatus.UPLOADING: + raise ApiError( + message="Cannot download an uploading asset, because it is incomplete.", + error_code=ApiErrorCode.ASSET_UPLOAD_INCOMPLETE, + http_status_code=HTTPStatus.FORBIDDEN, + ) if asset.is_directory: if asset_path is None: msg = "Missing required parameter for downloading a directory file: asset_path" @@ -335,23 +344,105 @@ def entity_asset_directory_list( return files -@router.post("/{entity_route}/{entity_id}/assets/upload/initiate", include_in_schema=False) +@router.post("/{entity_route}/{entity_id}/assets/multipart-upload/initiate") def initiate_entity_asset_upload( repos: RepoGroupDep, + storage_client_factory: StorageClientFactoryDep, user_context: UserContextWithProjectIdDep, entity_route: EntityRoute, - entity_id: int, -): + entity_id: uuid.UUID, + json_model: InitiateUploadRequest, +) -> AssetRead: """Generate a signed URL with expiration that can be used to upload the file directly to S3.""" - raise NotImplementedError + if not validate_multipart_filesize(json_model.filesize): + msg = f"File not allowed because bigger than {settings.S3_MULTIPART_UPLOAD_MAX_SIZE}" + raise ApiError( + message=msg, + error_code=ApiErrorCode.ASSET_INVALID_FILE, + http_status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + ) + + if not validate_filename(json_model.filename): + msg = f"Invalid file name {json_model.filename!r}" + raise ApiError( + message=msg, + error_code=ApiErrorCode.ASSET_INVALID_PATH, + http_status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + ) + try: + content_type = get_content_type(json_model) + except ValueError as e: + msg = ( + f"Invalid content type for file {json_model.filename}. 
" + f"Supported content types: {sorted(c.value for c in ContentType)}.\n" + f"Exception: {e}" + ) + raise ApiError( + message=msg, + error_code=ApiErrorCode.ASSET_INVALID_CONTENT_TYPE, + http_status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + ) from None + + storage = storages[StorageType.aws_s3_internal] # hardcoded for now + s3_client = storage_client_factory(storage) -@router.post("/{entity_route}/{entity_id}/assets/upload/complete", include_in_schema=False) + entity_type = entity_route_to_type(entity_route) + + # create asset to fail early if full path already in progress or registered + asset_read = asset_service.create_entity_asset( + repos=repos, + user_context=user_context, + entity_type=entity_type, + entity_id=entity_id, + filename=json_model.filename, + content_type=content_type, + size=json_model.filesize, + sha256_digest=json_model.sha256_digest, + meta=None, + label=json_model.label, + is_directory=False, + storage_type=storage.type, + status=AssetStatus.UPLOADING, + ) + + # create presigned urls using the part count hint and filesize + # asset schemas is updated with the upload metadata + asset_read = asset_service.entity_asset_upload_initiate( + repos=repos, + user_context=UserContext.model_validate(user_context, from_attributes=True), + s3_client=s3_client, + entity_type=entity_type, + entity_id=entity_id, + asset_id=asset_read.id, + bucket=storage.bucket, + s3_key=asset_read.full_path, + filesize=json_model.filesize, + content_type=content_type, + preferred_part_count=json_model.preferred_part_count, + ) + + return asset_read + + +@router.post("/{entity_route}/{entity_id}/assets/{asset_id}/multipart-upload/complete") def complete_entity_asset_upload( repos: RepoGroupDep, + storage_client_factory: StorageClientFactoryDep, user_context: UserContextDep, entity_route: EntityRoute, - entity_id: int, -): + entity_id: uuid.UUID, + asset_id: uuid.UUID, +) -> AssetRead: """Register the uploaded file.""" - raise NotImplementedError + storage = storages[StorageType.aws_s3_internal] # hardcoded for now + + return asset_service.entity_asset_upload_complete( + repos=repos, + user_context=user_context, + entity_type=entity_route_to_type(entity_route), + entity_id=entity_id, + asset_id=asset_id, + storage=storage, + s3_client=storage_client_factory(storage), + ) diff --git a/app/schemas/asset.py b/app/schemas/asset.py index 6f822a7d..f7d7a529 100644 --- a/app/schemas/asset.py +++ b/app/schemas/asset.py @@ -3,10 +3,10 @@ from pathlib import Path from typing import Annotated -from pydantic import AfterValidator, BaseModel, ConfigDict, field_validator, model_validator +from pydantic import AfterValidator, BaseModel, ConfigDict, Field, field_validator, model_validator from pydantic.networks import AnyUrl -from app.config import storages +from app.config import settings, storages from app.db.types import ( ALLOWED_ASSET_LABELS_PER_ENTITY, CONTENT_TYPE_TO_SUFFIX, @@ -38,6 +38,17 @@ def validate_full_path(s: str) -> str: return s +class ToUploadPart(BaseModel): + part_number: int + url: Annotated[str, Field(description="Presigned url to upload file part.")] + + +class UploadMeta(BaseModel): + upload_id: str + part_size: int + parts: list[ToUploadPart] + + class AssetBase(BaseModel): """Asset model with common attributes.""" @@ -49,6 +60,7 @@ class AssetBase(BaseModel): meta: dict = {} label: AssetLabel storage_type: StorageType + upload_meta: UploadMeta | None = None class SizeAndDigestMixin(BaseModel): @@ -175,3 +187,27 @@ class DetailedFileList(BaseModel): class AssetAndPresignedURLS(BaseModel): 
asset: AssetRead files: dict[Path, AnyUrl] + + +class InitiateUploadRequest(BaseModel): + filename: Annotated[str, Field(description="File name to be uploaded.")] + filesize: Annotated[int, Field(description="File size to be uploaded in bytes.")] + sha256_digest: str + content_type: Annotated[ + str | None, + Field( + description=( + "Content type of file. " + "If not provided it will be deduced from the file's extension." + ) + ), + ] = None + label: AssetLabel + preferred_part_count: Annotated[int, Field(description="Hint of desired part count.")] = ( + settings.S3_MULTIPART_UPLOAD_DEFAULT_PARTS + ) + + +class UploadedPart(BaseModel): + part_number: int + etag: str diff --git a/app/service/asset.py b/app/service/asset.py index f28af62c..3edf2edc 100644 --- a/app/service/asset.py +++ b/app/service/asset.py @@ -22,6 +22,8 @@ AssetRead, DetailedFileList, DirectoryUpload, + ToUploadPart, + UploadMeta, ) from app.schemas.auth import UserContext, UserContextWithProjectId from app.schemas.types import ListResponse @@ -33,6 +35,11 @@ delete_from_s3, generate_presigned_url, list_directory_with_details, + multipart_compute_upload_plan, + multipart_upload_complete, + multipart_upload_create_part_presigned_url, + multipart_upload_initiate, + multipart_upload_list_parts, sanitize_directory_traversal, ) @@ -114,6 +121,7 @@ def create_entity_asset( # noqa: PLR0913 is_directory: bool, storage_type: StorageType, full_path: str | None = None, + status: AssetStatus = AssetStatus.CREATED, ) -> AssetRead: """Create an asset for an entity.""" entity = entity_service.get_writable_entity( @@ -158,6 +166,7 @@ def create_entity_asset( # noqa: PLR0913 asset_db = repos.asset.create_entity_asset( entity_id=entity_id, asset=asset_create, + status=status, ) return AssetRead.model_validate(asset_db) @@ -344,3 +353,139 @@ def list_directory( ) return DetailedFileList.model_validate({"files": ret}) + + +def entity_asset_upload_initiate( + *, + repos: RepositoryGroup, + entity_type: EntityType, + entity_id: uuid.UUID, + asset_id: uuid.UUID, + s3_client: S3Client, + s3_key: str, + bucket: str, + filesize: int, + content_type: str, + preferred_part_count: int, + user_context: UserContext, +) -> AssetRead: + _ = entity_service.get_readable_entity( + repos, + user_context=user_context, + entity_type=entity_type, + entity_id=entity_id, + ) + with ensure_result(f"Asset {asset_id} not found", error_code=ApiErrorCode.ASSET_NOT_FOUND): + asset = repos.asset.get_entity_asset( + entity_type=entity_type, + entity_id=entity_id, + asset_id=asset_id, + ) + + storage = storages[asset.storage_type] + + upload_id = multipart_upload_initiate( + s3_client=s3_client, + s3_key=s3_key, + bucket=bucket, + content_type=content_type, + ) + part_size, part_count = multipart_compute_upload_plan( + filesize=filesize, + preferred_part_count=preferred_part_count, + ) + parts = [ + ToUploadPart( + part_number=part_number, + url=multipart_upload_create_part_presigned_url( + s3_client=s3_client, + bucket=storage.bucket, + s3_key=asset.full_path, + upload_id=upload_id, + part_number=part_number, + ), + ) + for part_number in range(1, part_count + 1) + ] + asset.status = AssetStatus.UPLOADING + asset.upload_meta = UploadMeta( + upload_id=upload_id, + part_size=part_size, + parts=parts, + ).model_dump() + repos.db.flush() + return AssetRead.model_validate(asset) + + +def entity_asset_upload_complete( + *, + repos: RepositoryGroup, + user_context: UserContext, + entity_type: EntityType, + entity_id: uuid.UUID, + asset_id: uuid.UUID, + storage: StorageUnion, + 
s3_client: S3Client, +) -> AssetRead: + asset = get_entity_asset( + repos, + user_context=user_context, + entity_type=entity_type, + entity_id=entity_id, + asset_id=asset_id, + ) + + if asset.status != AssetStatus.UPLOADING or not asset.upload_meta: + raise ApiError( + message="Asset is not uploading. Operation cannot be performed.", + error_code=ApiErrorCode.ASSET_NOT_UPLOADING, + http_status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + ) + + # parts that have been already uploaded to s3 + uploaded_parts = multipart_upload_list_parts( + s3_client=s3_client, + bucket=storage.bucket, + s3_key=asset.full_path, + upload_id=asset.upload_meta.upload_id, + ) + + # parts that are expected to be uploaded via the presigned urls + parts = asset.upload_meta.parts + + # verify that expected and uploaded agree + uploaded_part_numbers = {p["PartNumber"] for p in uploaded_parts} + expected_part_numbers = {p.part_number for p in parts} + + if uploaded_part_numbers != expected_part_numbers: + raise ApiError( + message=( + "Expected parts are not uploaded. " + f"Expected: {len(expected_part_numbers)}, Actual: {len(uploaded_part_numbers)}" + ), + error_code=ApiErrorCode.ASSET_UPLOAD_INCOMPLETE, + http_status_code=HTTPStatus.CONFLICT, + ) + + multipart_upload_complete( + s3_client=s3_client, + s3_key=asset.full_path, + upload_id=asset.upload_meta.upload_id, + bucket=storage.bucket, + parts=uploaded_parts, + ) + _ = entity_service.get_readable_entity( + repos, + user_context=user_context, + entity_type=entity_type, + entity_id=entity_id, + ) + with ensure_result(f"Asset {asset_id} not found", error_code=ApiErrorCode.ASSET_NOT_FOUND): + asset_db = repos.asset.get_entity_asset( + entity_type=entity_type, entity_id=entity_id, asset_id=asset_id + ) + + asset_db.status = AssetStatus.CREATED + asset_db.upload_meta = None + repos.db.flush() + return AssetRead.model_validate(asset_db) diff --git a/app/utils/common.py b/app/utils/common.py index 990e76c2..5276e0a8 100644 --- a/app/utils/common.py +++ b/app/utils/common.py @@ -5,3 +5,8 @@ def is_ascii(s: str) -> bool: except UnicodeEncodeError: return False return True + + +def clip(x: int, min_value: int, max_value: int) -> int: + """Clamp x to the inclusive range [min_value, max_value].""" + return max(min_value, min(x, max_value)) diff --git a/app/utils/files.py b/app/utils/files.py index 13033e33..34c49e79 100644 --- a/app/utils/files.py +++ b/app/utils/files.py @@ -5,9 +5,10 @@ from app.db.types import ContentType from app.logger import L +from app.schemas.asset import InitiateUploadRequest -def get_content_type(file: UploadFile) -> ContentType: +def get_content_type(file: UploadFile | InitiateUploadRequest) -> ContentType: """Return the file content-type. In case of discrepancy with the original content-type, the discrepancy is just logged. 
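For reference, a worked example of the part-size planning implemented by multipart_compute_upload_plan in the app/utils/s3.py diff below. This is only an illustrative restatement: the plan helper and the inlined limits are assumptions taken from the new Settings fields in app/config.py (5 MiB minimum part size, 5 GiB maximum part size, 1 to 10_000 parts), not code that ships with this change.

import math

# Limits restated from the new Settings fields in app/config.py (assumed values).
MIN_PART_SIZE = 5 * 1024**2  # S3_MULTIPART_UPLOAD_MIN_PART_SIZE
MAX_PART_SIZE = 5 * 1024**3  # S3_MULTIPART_UPLOAD_MAX_PART_SIZE
MIN_PARTS, MAX_PARTS = 1, 10_000


def plan(filesize: int, preferred_part_count: int) -> tuple[int, int]:
    """Clip the preferred part count, derive a part size, clip it, then recompute the count."""
    part_count = min(max(preferred_part_count, MIN_PARTS), MAX_PARTS)
    part_size = min(max(math.ceil(filesize / part_count), MIN_PART_SIZE), MAX_PART_SIZE)
    return part_size, math.ceil(filesize / part_size)


# A 15 MiB upload with preferred_part_count=3 is split into 3 parts of 5 MiB each,
# matching the expectations in tests/routers/test_asset.py.
assert plan(3 * 5 * 1024**2, 3) == (5 * 1024**2, 3)
# A 1 MiB upload is clipped up to the 5 MiB minimum part size, so a single part suffices.
assert plan(1024**2, 100) == (5 * 1024**2, 1)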
diff --git a/app/utils/s3.py b/app/utils/s3.py index 72299d9c..4f29ef6a 100644 --- a/app/utils/s3.py +++ b/app/utils/s3.py @@ -1,3 +1,4 @@ +import math import os import uuid from pathlib import Path @@ -17,6 +18,7 @@ from app.db.types import EntityType, StorageType from app.logger import L from app.schemas.asset import validate_path +from app.utils.common import clip class StorageClientFactory(Protocol): @@ -53,6 +55,10 @@ def validate_filesize(filesize: int) -> bool: return filesize <= settings.API_ASSET_POST_MAX_SIZE +def validate_multipart_filesize(filesize: int) -> bool: + return filesize <= settings.S3_MULTIPART_UPLOAD_MAX_SIZE + + def get_s3_client(storage: StorageUnion) -> S3Client: """Return a new S3 client (not thread-safe). @@ -166,6 +172,100 @@ def generate_presigned_url( return url +def multipart_upload_initiate( + s3_client: S3Client, bucket: str, s3_key: str, content_type: str +) -> str: + res = s3_client.create_multipart_upload(Bucket=bucket, Key=s3_key, ContentType=content_type) + return res["UploadId"] + + +def multipart_compute_upload_plan(*, filesize: int, preferred_part_count: int) -> tuple[int, int]: + part_count = clip( + preferred_part_count, + min_value=settings.S3_MULTIPART_UPLOAD_MIN_PARTS, + max_value=settings.S3_MULTIPART_UPLOAD_MAX_PARTS, + ) + part_size = math.ceil(filesize / part_count) + + part_size = clip( + part_size, + min_value=settings.S3_MULTIPART_UPLOAD_MIN_PART_SIZE, + max_value=settings.S3_MULTIPART_UPLOAD_MAX_PART_SIZE, + ) + part_count = math.ceil(filesize / part_size) + + return part_size, part_count + + +def multipart_upload_create_part_presigned_url( + s3_client: S3Client, + bucket: str, + s3_key: str, + upload_id: str, + part_number: int, +): + return s3_client.generate_presigned_url( + "upload_part", + Params={ + "Bucket": bucket, + "Key": s3_key, + "UploadId": upload_id, + "PartNumber": part_number, + }, + ExpiresIn=settings.S3_PRESIGNED_URL_EXPIRATION, + ) + + +def multipart_upload_list_parts( + s3_client: S3Client, + bucket: str, + s3_key: str, + upload_id: str, +) -> list: + paginator = s3_client.get_paginator("list_parts") + page_iterator = paginator.paginate( + Bucket=bucket, + Key=s3_key, + UploadId=upload_id, + ) + return [part for page in page_iterator for part in page.get("Parts", [])] + + +def multipart_upload_complete( + s3_client: S3Client, s3_key: str, upload_id: str, bucket: str, parts: list +): + s3_client.complete_multipart_upload( + Bucket=bucket, + Key=s3_key, + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + { + "ETag": p["ETag"], + "PartNumber": p["PartNumber"], + } + for p in parts + ], + }, + ) + + +def multipart_upload_abort( + *, + upload_id: str, + storage_type: StorageType, + s3_key: str, + storage_client_factory: StorageClientFactory, +) -> None: + storage = storages[storage_type] + s3_client = storage_client_factory(storage) + s3_client.abort_multipart_upload( + Bucket=storage.bucket, + Key=s3_key, + UploadId=upload_id, + ) + + def list_directory_with_details( s3_client: S3Client, bucket_name: str, diff --git a/scripts/export/build_database_archive.sh b/scripts/export/build_database_archive.sh index 22c6343f..0d2efb7b 100755 --- a/scripts/export/build_database_archive.sh +++ b/scripts/export/build_database_archive.sh @@ -2,7 +2,7 @@ # Automatically generated, do not edit! 
set -euo pipefail SCRIPT_VERSION="1" -SCRIPT_DB_VERSION="642ef45b49c6" +SCRIPT_DB_VERSION="4d9640ae6ba0" echo "DB dump (version $SCRIPT_VERSION for db version $SCRIPT_DB_VERSION)" @@ -263,7 +263,7 @@ install -m 755 /dev/stdin "$WORK_DIR/load.sh" <<'EOF_LOAD_SCRIPT' # Automatically generated, do not edit! set -euo pipefail SCRIPT_VERSION="1" -SCRIPT_DB_VERSION="642ef45b49c6" +SCRIPT_DB_VERSION="4d9640ae6ba0" echo "DB load (version $SCRIPT_VERSION for db version $SCRIPT_DB_VERSION)" diff --git a/tests/conftest.py b/tests/conftest.py index 4c093d2f..9c9835bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -108,9 +108,19 @@ def s3(): @pytest.fixture(scope="session") -def _create_buckets(s3): - s3.create_bucket(Bucket=storages[StorageType.aws_s3_internal].bucket) - s3.create_bucket(Bucket=storages[StorageType.aws_s3_open].bucket, ACL="public-read") +def s3_internal_bucket(): + return storages[StorageType.aws_s3_internal].bucket + + +@pytest.fixture(scope="session") +def s3_open_bucket(): + return storages[StorageType.aws_s3_open].bucket + + +@pytest.fixture(scope="session") +def _create_buckets(s3, s3_internal_bucket, s3_open_bucket): + s3.create_bucket(Bucket=s3_internal_bucket) + s3.create_bucket(Bucket=s3_open_bucket, ACL="public-read") @pytest.fixture diff --git a/tests/routers/test_asset.py b/tests/routers/test_asset.py index a2315a42..200992ec 100644 --- a/tests/routers/test_asset.py +++ b/tests/routers/test_asset.py @@ -22,6 +22,7 @@ create_cell_morphology_id, route, s3_key_exists, + s3_multipart_upload_exists, upload_entity_asset, ) @@ -129,6 +130,24 @@ def asset(client, entity) -> AssetRead: return AssetRead.model_validate(data) +@pytest.fixture +def uploading_asset(client, entity) -> AssetRead: + """Create an asset the file of which is being uploaded with delegation.""" + data = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/multipart-upload/initiate", + json={ + "filename": "foo.swc", + "filesize": 3 * 5 * 1024**2, + "sha256_digest": "e3b7c1f0a9d4b8e6f2c0a5d9e1b4c8f6a0d3e7b2c9f4a6d8e5b1c0f9a2", + "preferred_part_count": 3, + "label": "morphology", + "content_type": "application/swc", + }, + ).json() + return AssetRead.model_validate(data) + + @pytest.fixture def asset_directory(db, root_circuit, person_id) -> Asset: s3_path = _get_expected_full_path(entity=root_circuit, path="my-directory") @@ -176,6 +195,7 @@ def test_upload_entity_asset(client, entity, monkeypatch): "status": "created", "label": "morphology", "storage_type": StorageType.aws_s3_internal, + "upload_meta": None, } # try to upload again the same file with the same path @@ -269,6 +289,7 @@ def test_upload_entity_asset(client, entity, monkeypatch): "status": "created", "label": "morphology", "storage_type": StorageType.aws_s3_internal, + "upload_meta": None, } # try to upload a file too big @@ -408,6 +429,7 @@ def test_register_entity_asset_as_file(client, entity): "size": -1, "status": "created", "storage_type": "aws_s3_open", + "upload_meta": None, } @@ -438,6 +460,7 @@ def test_register_entity_asset_as_directory(client, circuit): "size": -1, "status": "created", "storage_type": "aws_s3_open", + "upload_meta": None, } @@ -552,6 +575,7 @@ def test_get_entity_asset(client, entity, asset): "status": "created", "label": "morphology", "storage_type": StorageType.aws_s3_internal, + "upload_meta": None, } # try to get an asset with non-existent entity id @@ -596,6 +620,7 @@ def test_get_entity_assets(client, entity, asset): "status": "created", "label": "morphology", "storage_type": 
StorageType.aws_s3_internal, + "upload_meta": None, } ] @@ -628,6 +653,7 @@ def test_get_deleted_entity_assets__admin(db, client_admin, entity, asset): "status": "created", "label": "morphology", "storage_type": StorageType.aws_s3_internal, + "upload_meta": None, } assert data == [expected_asset_payload] @@ -692,6 +718,18 @@ def test_download_entity_asset(client, entity, asset): ) +def test_download_entity_asset__uploading(client, entity, uploading_asset): + """Test that downloading an uploading asset is forbidden.""" + + response = client.get( + f"{route(entity.type)}/{entity.id}/assets/{uploading_asset.id}/download", + follow_redirects=False, + ) + assert response.status_code == 403 + error = ErrorResponse.model_validate(response.json()) + assert error.error_code == ApiErrorCode.ASSET_UPLOAD_INCOMPLETE + + @pytest.mark.parametrize("client_fixture", ["client_user_2", "client_no_project"]) def test_download_entity_asset_non_authorized(request, client_fixture, entity, asset): client = request.getfixturevalue(client_fixture) @@ -1121,3 +1159,369 @@ def test_list_entity_asset_directory_non_authorized( assert response.status_code == expected_status error = ErrorResponse.model_validate(response.json()) assert error.error_code == expected_error + + +def _multipart_json_data( + filesize=3 * 5 * 1024**2, filename="foo.swc", content_type="application/swc" +): + return { + "filename": filename, + "filesize": filesize, + "sha256_digest": "e3b7c1f0a9d4b8e6f2c0a5d9e1b4c8f6a0d3e7b2c9f4a6d8e5b1c0f9a2", + "preferred_part_count": 3, + "label": "morphology", + "content_type": content_type, + } + + +def test_multipart_asset_upload(client, entity, s3, s3_internal_bucket): + """Test iniating, uploading parts, and completing a multipart upload.""" + filesize = 3 * 5 * 1024**2 + + data = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/multipart-upload/initiate", + json=_multipart_json_data(filesize), + ).json() + + expected_upload_meta = { + "upload_id": ANY, + "part_size": 5 * 1024**2, + "parts": [ + { + "part_number": 1, + "url": ANY, + }, + { + "part_number": 2, + "url": ANY, + }, + { + "part_number": 3, + "url": ANY, + }, + ], + } + + assert data["status"] == "uploading" + assert data["upload_meta"] == expected_upload_meta + + # check that asset is registered in the db with uploading status + asset_data = assert_request( + client.get, + url=f"{route(entity.type)}/{entity.id}/assets/{data['id']}", + ).json() + assert asset_data["status"] == "uploading" + assert asset_data["upload_meta"] == expected_upload_meta + + # check that multipart upload has initiated + assert s3_multipart_upload_exists(s3, data["upload_meta"]["upload_id"], s3_internal_bucket) + + # three lines with repeating a, b, c of 5Mb each + file_part_bytes = [(letter * 5_242_879 + "\n").encode("utf-8") for letter in ("a", "b", "c")] + + # upload the file to s3 using parts (no presigned urls) + part_size = data["upload_meta"]["part_size"] + n_full_parts, remainder = divmod(filesize, part_size) + sizes = [part_size] * n_full_parts + if remainder: + sizes += [remainder] + + assert len(sizes) == len(data["upload_meta"]["parts"]) + + for i, part in enumerate(data["upload_meta"]["parts"]): + s3.upload_part( + Bucket=s3_internal_bucket, + Key=data["full_path"], + UploadId=data["upload_meta"]["upload_id"], + PartNumber=part["part_number"], + Body=file_part_bytes[i], + ) + + # sanity check that part data is uploaded + parts = s3.list_parts( + Bucket=s3_internal_bucket, + Key=data["full_path"], + 
UploadId=data["upload_meta"]["upload_id"], + )["Parts"] + + assert len(parts) == 3 + for i, part in enumerate(parts, start=1): + assert part["PartNumber"] == i + assert part["Size"] == part_size + + completed_data = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/{data['id']}/multipart-upload/complete", + ).json() + + assert completed_data["status"] == "created" + assert completed_data["upload_meta"] is None + + # sanity check if db asset changes persisted + completed_data = assert_request( + client.get, + url=f"{route(entity.type)}/{entity.id}/assets/{data['id']}", + ).json() + assert completed_data["status"] == "created" + assert completed_data["upload_meta"] is None + + # get file from s3 and check concatenation is correct + content = s3.get_object(Bucket=s3_internal_bucket, Key=data["full_path"])["Body"].read() + assert content == b"".join(file_part_bytes) + + # check no uploads left following completion + assert not s3_multipart_upload_exists(s3, data["upload_meta"]["upload_id"], s3_internal_bucket) + + +def test_multipart_asset_upload_abort(client, entity, s3, s3_internal_bucket): + filesize = 3 * 5 * 1024**2 + + data = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/multipart-upload/initiate", + json=_multipart_json_data(filesize), + ).json() + + assert data["status"] == "uploading" + assert data["upload_meta"] + + # check that asset is registered in the db with uploading status + asset_data = assert_request( + client.get, + url=f"{route(entity.type)}/{entity.id}/assets/{data['id']}", + ).json() + assert asset_data["status"] == "uploading" + + # check that multipart upload has initiated + assert s3_multipart_upload_exists(s3, data["upload_meta"]["upload_id"], s3_internal_bucket) + + assert_request( + client.delete, + url=f"{route(entity.type)}/{entity.id}/assets/{data['id']}", + ) + + # uploading asset must be deleted + assert_request( + client.get, + url=f"{route(entity.type)}/{entity.id}/assets/{data['id']}", + expected_status_code=404, + ) + + # check that the asset deletion triggered multipart upload abort + assert not s3_multipart_upload_exists(s3, data["upload_meta"]["upload_id"], s3_internal_bucket) + + +def test_initiate_entity_asset_upload_invalid_filesize(client, entity): + """Test initiate_entity_asset_upload with invalid filesize.""" + # Test with filesize exceeding max + max_size = settings.S3_MULTIPART_UPLOAD_MAX_SIZE + response = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/multipart-upload/initiate", + json=_multipart_json_data(filesize=max_size + 1), + expected_status_code=422, + ) + error = ErrorResponse.model_validate(response.json()) + assert error.error_code == ApiErrorCode.ASSET_INVALID_FILE + assert f"bigger than {max_size}" in error.message + + +def test_initiate_entity_asset_upload_invalid_filename(client, entity): + """Test initiate_entity_asset_upload with invalid filename.""" + # Test with invalid filename (path traversal) + + response = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/multipart-upload/initiate", + json=_multipart_json_data(filename="../../etc/passwd"), + expected_status_code=422, + ) + error = ErrorResponse.model_validate(response.json()) + assert error.error_code == ApiErrorCode.ASSET_INVALID_PATH + assert "Invalid file name" in error.message + + # Test with invalid filename (absolute path) + response = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/multipart-upload/initiate", + 
json=_multipart_json_data(filename="/absolute/path/file.swc"), + expected_status_code=422, + ) + error = ErrorResponse.model_validate(response.json()) + assert error.error_code == ApiErrorCode.ASSET_INVALID_PATH + + +def test_initiate_entity_asset_upload_invalid_content_type(client, entity): + """Test initiate_entity_asset_upload with invalid content type.""" + # Test with invalid content type + response = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/multipart-upload/initiate", + json=_multipart_json_data(content_type="application/octet-stream"), + expected_status_code=422, + ) + error = ErrorResponse.model_validate(response.json()) + assert error.error_code == ApiErrorCode.ASSET_INVALID_CONTENT_TYPE + assert "Invalid content type" in error.message + + +def test_initiate_entity_asset_upload_duplicate(client, entity): + """Test initiate_entity_asset_upload with duplicate asset.""" + # Create first upload + assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/multipart-upload/initiate", + json=_multipart_json_data(), + ) + + # Try to create duplicate + response = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/multipart-upload/initiate", + json=_multipart_json_data(), + expected_status_code=409, + ) + error = ErrorResponse.model_validate(response.json()) + assert error.error_code == ApiErrorCode.ASSET_DUPLICATED + + +def test_initiate_entity_asset_upload_entity_not_found(client, entity): + """Test initiate_entity_asset_upload with non-existent entity.""" + response = assert_request( + client.post, + url=f"{route(entity.type)}/{MISSING_ID}/assets/multipart-upload/initiate", + json=_multipart_json_data(), + expected_status_code=404, + ) + error = ErrorResponse.model_validate(response.json()) + assert error.error_code == ApiErrorCode.ENTITY_NOT_FOUND + + +@pytest.mark.parametrize( + ("client_fixture", "expected_status", "expected_error"), + [ + ("client_user_2", 404, ApiErrorCode.ENTITY_NOT_FOUND), + ("client_no_project", 403, ApiErrorCode.NOT_AUTHORIZED), + ], +) +def test_initiate_entity_asset_upload_non_authorized( + request, client_fixture, expected_status, expected_error, entity +): + """Test initiate_entity_asset_upload with unauthorized user.""" + client = request.getfixturevalue(client_fixture) + response = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/multipart-upload/initiate", + json=_multipart_json_data(), + expected_status_code=expected_status, + ) + error = ErrorResponse.model_validate(response.json()) + assert error.error_code == expected_error + + +def test_complete_entity_asset_upload_asset_not_found(client, entity): + """Test complete_entity_asset_upload with non-existent asset.""" + response = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/{MISSING_ID}/multipart-upload/complete", + expected_status_code=404, + ) + error = ErrorResponse.model_validate(response.json()) + assert error.error_code == ApiErrorCode.ASSET_NOT_FOUND + + +def test_complete_entity_asset_upload_entity_not_found(client, entity, uploading_asset): + """Test complete_entity_asset_upload with non-existent entity.""" + response = assert_request( + client.post, + url=f"{route(entity.type)}/{MISSING_ID}/assets/{uploading_asset.id}/multipart-upload/complete", + expected_status_code=404, + ) + error = ErrorResponse.model_validate(response.json()) + assert error.error_code == ApiErrorCode.ENTITY_NOT_FOUND + + +def 
test_complete_entity_asset_upload_not_uploading(client, entity, asset): + """Test complete_entity_asset_upload when asset is not in uploading status.""" + data = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/{asset.id}/multipart-upload/complete", + expected_status_code=422, + ).json() + error = ErrorResponse.model_validate(data) + assert error.error_code == ApiErrorCode.ASSET_NOT_UPLOADING + assert "not uploading" in error.message.lower() + + +@pytest.mark.parametrize( + ("client_fixture", "expected_status", "expected_error"), + [ + ("client_user_2", 404, ApiErrorCode.ENTITY_NOT_FOUND), + ("client_no_project", 404, ApiErrorCode.ENTITY_NOT_FOUND), + ], +) +def test_complete_entity_asset_upload_non_authorized( + request, client_fixture, expected_status, expected_error, entity, uploading_asset +): + """Test complete_entity_asset_upload with unauthorized user.""" + client = request.getfixturevalue(client_fixture) + data = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/{uploading_asset.id}/multipart-upload/complete", + expected_status_code=expected_status, + ).json() + error = ErrorResponse.model_validate(data) + assert error.error_code == expected_error + + +def test_complete_entity_asset_upload_incomplete(client, entity, s3, s3_internal_bucket): + """Test complete_entity_asset_upload when upload is incomplete.""" + filesize = 3 * 5 * 1024**2 + + # Initiate upload + data = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/multipart-upload/initiate", + json=_multipart_json_data(filesize=filesize), + ).json() + + # Upload only 2 out of 3 parts (incomplete) + file_part_bytes = [(letter * 5_242_879 + "\n").encode("utf-8") for letter in ("a", "b")] + + for i, part in enumerate(data["upload_meta"]["parts"][:2]): # Only upload 2 parts + s3.upload_part( + Bucket=s3_internal_bucket, + Key=data["full_path"], + UploadId=data["upload_meta"]["upload_id"], + PartNumber=part["part_number"], + Body=file_part_bytes[i], + ) + + # Try to complete - should fail because not all parts are uploaded + response = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/{data['id']}/multipart-upload/complete", + expected_status_code=409, + ) + error = ErrorResponse.model_validate(response.json()) + assert error.error_code == ApiErrorCode.ASSET_UPLOAD_INCOMPLETE + assert "Expected parts are not uploaded" in error.message + + # upload last part + s3.upload_part( + Bucket=s3_internal_bucket, + Key=data["full_path"], + UploadId=data["upload_meta"]["upload_id"], + PartNumber=data["upload_meta"]["parts"][-1]["part_number"], + Body=file_part_bytes[-1], + ) + + # now should complete + data = assert_request( + client.post, + url=f"{route(entity.type)}/{entity.id}/assets/{data['id']}/multipart-upload/complete", + ).json() + + assert data["status"] == "created" + assert data["upload_meta"] is None diff --git a/tests/test_brain_atlas.py b/tests/test_brain_atlas.py index d4310182..6793e2d9 100644 --- a/tests/test_brain_atlas.py +++ b/tests/test_brain_atlas.py @@ -104,6 +104,7 @@ def test_brain_atlas(db, client, species_id, person_id): "size": 31, "status": "created", "storage_type": StorageType.aws_s3_internal, + "upload_meta": None, } ], "creation_date": ANY, @@ -177,6 +178,7 @@ def test_brain_atlas(db, client, species_id, person_id): "size": 31, "status": "created", "storage_type": StorageType.aws_s3_internal, + "upload_meta": None, } assert response.json()["data"] == [ { diff --git a/tests/test_utils/test_s3.py 
b/tests/test_utils/test_s3.py index ef05ee2f..80dcf4a6 100644 --- a/tests/test_utils/test_s3.py +++ b/tests/test_utils/test_s3.py @@ -1,5 +1,9 @@ +import math from pathlib import Path +import pytest + +from app.config import settings from app.db.types import EntityType from app.utils import s3 as test_module @@ -62,3 +66,112 @@ def test_sanitize_directory_traversal(): ] for p0, p1 in attack_paths: assert test_module.sanitize_directory_traversal(p0) == Path(p1), f"`{p0}` failed" + + +@pytest.mark.parametrize( + ("filesize", "preferred_part_count", "expected_part_size", "expected_part_count"), + [ + # preferred count below min + ( + 50 * 1024**2, + 1, + # clipped to min part count + math.ceil(50 * 1024**2 / settings.S3_MULTIPART_UPLOAD_MIN_PARTS), + settings.S3_MULTIPART_UPLOAD_MIN_PARTS, + ), + # preferred count above max + ( + 1024**4, + 300_000, + # part count clipped to max parts + math.ceil((1024**4) / settings.S3_MULTIPART_UPLOAD_MAX_PARTS), + settings.S3_MULTIPART_UPLOAD_MAX_PARTS, + ), + # filesize smaller than min part size + ( + settings.S3_MULTIPART_UPLOAD_MIN_PART_SIZE // 2, + 5, + settings.S3_MULTIPART_UPLOAD_MIN_PART_SIZE, + 1, # only one part needed + ), + # filesize larger than max part size (forces part count up) + ( + settings.S3_MULTIPART_UPLOAD_MAX_PART_SIZE * 10 + 123, + 5, + settings.S3_MULTIPART_UPLOAD_MAX_PART_SIZE, + 11, # ceil(10 + 123 / max_part_size) + ), + # exact multiple of preferred part count + ( + 50 * 1024**2, + 5, + math.ceil(50 * 1024 * 1024 / 5), + 5, + ), + # rounding up part size + ( + 50 * 1024**2 + 1, + 5, + math.ceil((50 * 1024 * 1024 + 1) / 5), + 5, + ), + # small file, preferred part count large + ( + 1024**2, + 100, + settings.S3_MULTIPART_UPLOAD_MIN_PART_SIZE, + 1, + ), + # huge file, preferred part count very small + ( + 500 * 1024**3, + 1, + settings.S3_MULTIPART_UPLOAD_MAX_PART_SIZE, + math.ceil(500 * 1024**3 / settings.S3_MULTIPART_UPLOAD_MAX_PART_SIZE), + ), + ], + ids=[ + "preferred-count-below-min", + "preferred-count-above-max", + "filesize-smaller-min-part-size", + "filesize-larger-max-part-size", + "exact-multiple-of-preferred-part-count", + "rounding-up-part-size", + "small-file-preferred-part-count-large", + "huge-file-preferred-part-count-very-small", + ], +) +def test_multipart_compute_upload_plan( + filesize, preferred_part_count, expected_part_size, expected_part_count +): + part_size, part_count = test_module.multipart_compute_upload_plan( + filesize=filesize, preferred_part_count=preferred_part_count + ) + + # check part count respects limits + assert ( + settings.S3_MULTIPART_UPLOAD_MIN_PARTS + <= part_count + <= settings.S3_MULTIPART_UPLOAD_MAX_PARTS + ) + + # check part size respects limits + assert ( + settings.S3_MULTIPART_UPLOAD_MIN_PART_SIZE + <= part_size + <= settings.S3_MULTIPART_UPLOAD_MAX_PART_SIZE + ) + + # check computed values match expectation + assert part_size == expected_part_size + assert part_count == expected_part_count + + # sanity check: all parts cover the full filesize + assert part_count * part_size >= filesize + + +def test_validate_multipart_filesize(): + """Test validate_multipart_filesize function.""" + max_size = settings.S3_MULTIPART_UPLOAD_MAX_SIZE + assert test_module.validate_multipart_filesize(max_size) is True + assert test_module.validate_multipart_filesize(max_size + 1) is False diff --git a/tests/utils.py b/tests/utils.py index 8af04084..6697f741 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1658,3 +1658,9 @@ def s3_key_exists(s3_client, key: str, 
storage_type=StorageType.aws_s3_internal) raise RuntimeError(msg) from e return True + + +def s3_multipart_upload_exists(s3_client, upload_id, bucket): + response = s3_client.list_multipart_uploads(Bucket=bucket) + upload_ids = {u["UploadId"] for u in response.get("Uploads", [])} + return upload_id in upload_ids
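Not part of the diff: a minimal client-side sketch of the delegated upload flow introduced above, showing how the two new endpoints are meant to be used together. The requests dependency, BASE_URL, and the bearer token are assumptions; the endpoint paths, the initiate payload, and the upload_meta shape follow the router and schema changes in this change set.

import hashlib
from pathlib import Path

import requests  # assumed HTTP client, not part of the service

BASE_URL = "http://localhost:8000"  # hypothetical deployment URL
HEADERS = {"Authorization": "Bearer <token>"}  # hypothetical auth header


def upload_file_multipart(entity_route: str, entity_id: str, path: Path, label: str) -> dict:
    """Initiate a delegated upload, PUT each part to its presigned URL, then complete it."""
    data = path.read_bytes()
    # 1. Initiate: the service creates the asset in UPLOADING state and returns upload_meta.
    #    content_type is omitted here, so it is deduced from the file extension.
    asset = requests.post(
        f"{BASE_URL}/{entity_route}/{entity_id}/assets/multipart-upload/initiate",
        headers=HEADERS,
        json={
            "filename": path.name,
            "filesize": len(data),
            "sha256_digest": hashlib.sha256(data).hexdigest(),
            "label": label,
            "preferred_part_count": 3,
        },
        timeout=30,
    ).json()
    meta = asset["upload_meta"]
    # 2. Upload each part directly to S3 via its presigned URL, in part_size chunks.
    part_size = meta["part_size"]
    for part in meta["parts"]:
        start = (part["part_number"] - 1) * part_size
        requests.put(part["url"], data=data[start : start + part_size], timeout=300)
    # 3. Complete: the service lists the uploaded parts, verifies them, and marks the asset CREATED.
    return requests.post(
        f"{BASE_URL}/{entity_route}/{entity_id}/assets/{asset['id']}/multipart-upload/complete",
        headers=HEADERS,
        timeout=30,
    ).json()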