Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion ayon_server/api/files.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from collections.abc import Callable

import aiofiles
from fastapi import Request, Response
Expand All @@ -10,10 +11,17 @@
from ayon_server.helpers.statistics import update_traffic_stats


async def handle_upload(request: Request, target_path: str) -> int:
async def handle_upload(
request: Request,
target_path: str,
*,
content_validator: Callable[[bytes], None] | None = None,
) -> int:
"""Store raw body from the request to a file.

Returns file size in bytes.
Content validator function can be passed to validate the first bytes of the file.
It should raise a ValueError if the content is invalid.
"""

directory, _ = os.path.split(target_path)
Expand All @@ -25,9 +33,19 @@ async def handle_upload(request: Request, target_path: str) -> int:
raise AyonException(f"Failed to create directory: {e}") from e

i = 0
validation_buffer = b""
validated = False
try:
async with aiofiles.open(target_path, "wb") as f:
async for chunk in request.stream():
if content_validator and not validated:
validation_buffer += chunk
if len(validation_buffer) > 1024:
try:
content_validator(validation_buffer)
validated = True
except ValueError as e:
raise BadRequestException(f"Invalid file: {e}") from e
await f.write(chunk)
i += len(chunk)
except Exception as e:
Expand Down
8 changes: 7 additions & 1 deletion ayon_server/files/project_storage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import time
from collections.abc import Callable
from typing import Any

import aiocache
Expand Down Expand Up @@ -243,6 +244,7 @@ async def handle_upload(
file_id: str,
file_group: FileGroup = "uploads",
*,
content_validator: Callable[[bytes], None] | None = None,
content_type: str | None = None,
content_disposition: str | None = None,
) -> int:
Expand All @@ -254,7 +256,11 @@ async def handle_upload(
logger.debug(f"Uploading file {file_id} to {self} ({file_group})")
path = await self.get_path(file_id, file_group=file_group)
if self.storage_type == "local":
return await handle_upload(request, path)
return await handle_upload(
request,
path,
content_validator=content_validator,
)
elif self.storage_type == "s3":
return await handle_s3_upload(
self,
Expand Down