diff --git a/ayon_server/api/files.py b/ayon_server/api/files.py index 33099c862..05d6c4a64 100644 --- a/ayon_server/api/files.py +++ b/ayon_server/api/files.py @@ -1,4 +1,5 @@ import os +from collections.abc import Callable import aiofiles from fastapi import Request, Response @@ -10,10 +11,17 @@ from ayon_server.helpers.statistics import update_traffic_stats -async def handle_upload(request: Request, target_path: str) -> int: +async def handle_upload( + request: Request, + target_path: str, + *, + content_validator: Callable[[bytes], None] | None = None, +) -> int: """Store raw body from the request to a file. Returns file size in bytes. + Content validator function can be passed to validate the first bytes of the file. + It should raise a ValueError if the content is invalid. """ directory, _ = os.path.split(target_path) @@ -25,9 +33,19 @@ async def handle_upload(request: Request, target_path: str) -> int: raise AyonException(f"Failed to create directory: {e}") from e i = 0 + validation_buffer = b"" + validated = False try: async with aiofiles.open(target_path, "wb") as f: async for chunk in request.stream(): + if content_validator and not validated: + validation_buffer += chunk + if len(validation_buffer) > 1024: + try: + content_validator(validation_buffer) + validated = True + except ValueError as e: + raise BadRequestException(f"Invalid file: {e}") from e await f.write(chunk) i += len(chunk) except Exception as e: diff --git a/ayon_server/files/project_storage.py b/ayon_server/files/project_storage.py index 6652951fd..7346efcfc 100644 --- a/ayon_server/files/project_storage.py +++ b/ayon_server/files/project_storage.py @@ -1,5 +1,6 @@ import os import time +from collections.abc import Callable from typing import Any import aiocache @@ -243,6 +244,7 @@ async def handle_upload( file_id: str, file_group: FileGroup = "uploads", *, + content_validator: Callable[[bytes], None] | None = None, content_type: str | None = None, content_disposition: str | None = None, ) -> int: @@ -254,7 +256,11 @@ async def handle_upload( logger.debug(f"Uploading file {file_id} to {self} ({file_group})") path = await self.get_path(file_id, file_group=file_group) if self.storage_type == "local": - return await handle_upload(request, path) + return await handle_upload( + request, + path, + content_validator=content_validator, + ) elif self.storage_type == "s3": return await handle_s3_upload( self,