diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/util/LakeFSStorageClient.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/util/LakeFSStorageClient.scala index 63c09f4c30b..2ad1c36d21e 100644 --- a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/util/LakeFSStorageClient.scala +++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/util/LakeFSStorageClient.scala @@ -358,4 +358,11 @@ object LakeFSStorageClient { branchesApi.resetBranch(repoName, branchName, resetCreation).execute() } + def parsePhysicalAddress(address: String): (String, String) = { + // expected: "://bucket/key..." + val uri = new java.net.URI(address) + val bucket = uri.getHost + val key = uri.getPath.stripPrefix("/") + (bucket, key) + } } diff --git a/common/workflow-core/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala b/common/workflow-core/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala index 94007e988e5..2c243b16d3c 100644 --- a/common/workflow-core/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala +++ b/common/workflow-core/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala @@ -259,4 +259,59 @@ object S3StorageClient { DeleteObjectRequest.builder().bucket(bucketName).key(objectKey).build() ) } + + /** + * Uploads a single part for an in-progress S3 multipart upload. + * + * This method wraps the AWS SDK v2 {@code UploadPart} API: + * it builds an {@link software.amazon.awssdk.services.s3.model.UploadPartRequest} + * and streams the part payload via a {@link software.amazon.awssdk.core.sync.RequestBody}. + * + * Payload handling: + * - If {@code contentLength} is provided, the payload is streamed directly from {@code inputStream} + * using {@code RequestBody.fromInputStream(inputStream, len)}. + * - If {@code contentLength} is {@code None}, the entire {@code inputStream} is read into memory + * ({@code readAllBytes}) and uploaded using {@code RequestBody.fromBytes(bytes)}. + * This is convenient but can be memory-expensive for large parts; prefer providing a known length. + * + * Notes: + * - {@code partNumber} must be in the valid S3 range (typically 1..10,000). + * - The caller is responsible for closing {@code inputStream}. + * - This method is synchronous and will block the calling thread until the upload completes. + * + * @param bucket S3 bucket name. + * @param key Object key (path) being uploaded. + * @param uploadId Multipart upload identifier returned by CreateMultipartUpload. + * @param partNumber 1-based part number for this upload. + * @param inputStream Stream containing the bytes for this part. + * @param contentLength Optional size (in bytes) of this part; provide it to avoid buffering in memory. + * @return The {@link software.amazon.awssdk.services.s3.model.UploadPartResponse}, + * including the part ETag used for completing the multipart upload. 
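+   *
+   * Usage sketch (illustrative only; {@code createResp} and {@code partBytes} are stand-ins
+   * for an earlier CreateMultipartUpload response and the caller's buffered part bytes):
+   * {{{
+   * val resp = S3StorageClient.uploadPart(
+   *   bucket = "my-bucket",
+   *   key = "path/to/object",
+   *   uploadId = createResp.uploadId(),
+   *   partNumber = 1,
+   *   inputStream = new java.io.ByteArrayInputStream(partBytes),
+   *   contentLength = Some(partBytes.length.toLong)
+   * )
+   * val etag = resp.eTag() // keep the ETag for CompleteMultipartUpload
+   * }}}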
+ */ + def uploadPart( + bucket: String, + key: String, + uploadId: String, + partNumber: Int, + inputStream: InputStream, + contentLength: Option[Long] + ): UploadPartResponse = { + val body: RequestBody = contentLength match { + case Some(len) => RequestBody.fromInputStream(inputStream, len) + case None => + val bytes = inputStream.readAllBytes() + RequestBody.fromBytes(bytes) + } + + val req = UploadPartRequest + .builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .partNumber(partNumber) + .build() + + s3Client.uploadPart(req, body) + } + } diff --git a/file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala b/file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala index 2a67440cf0e..cf9a2266ee5 100644 --- a/file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala +++ b/file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala @@ -54,7 +54,8 @@ import org.apache.texera.service.util.S3StorageClient.{ MINIMUM_NUM_OF_MULTIPART_S3_PART } import org.jooq.{DSLContext, EnumType} - +import org.jooq.impl.DSL +import org.jooq.impl.DSL.{inline => inl} import java.io.{InputStream, OutputStream} import java.net.{HttpURLConnection, URL, URLDecoder} import java.nio.charset.StandardCharsets @@ -65,6 +66,13 @@ import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable.ListBuffer import scala.jdk.CollectionConverters._ import scala.jdk.OptionConverters._ +import org.apache.texera.dao.jooq.generated.tables.DatasetUploadSession.DATASET_UPLOAD_SESSION +import org.apache.texera.dao.jooq.generated.tables.DatasetUploadSessionPart.DATASET_UPLOAD_SESSION_PART +import org.jooq.exception.DataAccessException +import software.amazon.awssdk.services.s3.model.UploadPartResponse + +import java.sql.SQLException +import scala.util.Try object DatasetResource { @@ -89,11 +97,11 @@ object DatasetResource { */ private def put(buf: Array[Byte], len: Int, url: String, partNum: Int): String = { val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection] - conn.setDoOutput(true); + conn.setDoOutput(true) conn.setRequestMethod("PUT") conn.setFixedLengthStreamingMode(len) val out = conn.getOutputStream - out.write(buf, 0, len); + out.write(buf, 0, len) out.close() val code = conn.getResponseCode @@ -401,7 +409,6 @@ class DatasetResource { e ) } - // delete the directory on S3 if ( S3StorageClient.directoryExists(StorageConfig.lakefsBucketName, dataset.getRepositoryName) @@ -639,138 +646,165 @@ class DatasetResource { @QueryParam("type") operationType: String, @QueryParam("ownerEmail") ownerEmail: String, @QueryParam("datasetName") datasetName: String, - @QueryParam("filePath") encodedUrl: String, - @QueryParam("uploadId") uploadId: Optional[String], + @QueryParam("filePath") filePath: String, @QueryParam("numParts") numParts: Optional[Integer], - payload: Map[ - String, - Any - ], // Expecting {"parts": [...], "physicalAddress": "s3://bucket/path"} @Auth user: SessionUser ): Response = { val uid = user.getUid + val dataset: Dataset = getDatasetBy(ownerEmail, datasetName) + + operationType.toLowerCase match { + case "init" => initMultipartUpload(dataset.getDid, filePath, numParts, uid) + case "finish" => finishMultipartUpload(dataset.getDid, filePath, uid) + case "abort" => abortMultipartUpload(dataset.getDid, filePath, uid) + case _ => + throw new BadRequestException("Invalid type parameter. 
Use 'init', 'finish', or 'abort'.") + } + } + + @POST + @RolesAllowed(Array("REGULAR", "ADMIN")) + @Path("/multipart-upload/part") + @Consumes(Array(MediaType.APPLICATION_OCTET_STREAM)) + def uploadPart( + @QueryParam("ownerEmail") ownerEmail: String, + @QueryParam("datasetName") datasetName: String, + @QueryParam("filePath") encodedFilePath: String, + @QueryParam("partNumber") partNumber: Int, + partStream: InputStream, + @Context headers: HttpHeaders, + @Auth user: SessionUser + ): Response = { + + val uid = user.getUid + val dataset: Dataset = getDatasetBy(ownerEmail, datasetName) + val did = dataset.getDid + + if (encodedFilePath == null || encodedFilePath.isEmpty) + throw new BadRequestException("filePath is required") + if (partNumber < 1) + throw new BadRequestException("partNumber must be >= 1") + + val filePath = validateFilePathOrThrow( + URLDecoder.decode(encodedFilePath, StandardCharsets.UTF_8.name()) + ) + + val contentLength = + Option(headers.getHeaderString(HttpHeaders.CONTENT_LENGTH)) + .map(_.trim) + .flatMap(s => Try(s.toLong).toOption) + .filter(_ > 0) + .getOrElse { + throw new BadRequestException("Invalid/Missing Content-Length") + } withTransaction(context) { ctx => - val dataset = context - .select(DATASET.fields: _*) - .from(DATASET) - .leftJoin(USER) - .on(USER.UID.eq(DATASET.OWNER_UID)) - .where(USER.EMAIL.eq(ownerEmail)) - .and(DATASET.NAME.eq(datasetName)) - .fetchOneInto(classOf[Dataset]) - if (dataset == null || !userHasWriteAccess(ctx, dataset.getDid, uid)) { + if (!userHasWriteAccess(ctx, did, uid)) throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE) - } - - // Decode the file path - val repositoryName = dataset.getRepositoryName - val filePath = URLDecoder.decode(encodedUrl, StandardCharsets.UTF_8.name()) - operationType.toLowerCase match { - case "init" => - val numPartsValue = numParts.toScala.getOrElse( - throw new BadRequestException("numParts is required for initialization") - ) + val session = ctx + .selectFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(uid) + .and(DATASET_UPLOAD_SESSION.DID.eq(did)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .fetchOne() - val presignedResponse = LakeFSStorageClient.initiatePresignedMultipartUploads( - repositoryName, - filePath, - numPartsValue - ) - Response - .ok( - Map( - "uploadId" -> presignedResponse.getUploadId, - "presignedUrls" -> presignedResponse.getPresignedUrls, - "physicalAddress" -> presignedResponse.getPhysicalAddress - ) - ) - .build() + if (session == null) + throw new NotFoundException("Upload session not found. 
Call type=init first.") - case "finish" => - val uploadIdValue = uploadId.toScala.getOrElse( - throw new BadRequestException("uploadId is required for completion") - ) + val expectedParts = session.getNumPartsRequested + if (partNumber > expectedParts) { + throw new BadRequestException( + s"$partNumber exceeds the requested parts on init: $expectedParts" + ) + } - // Extract parts from the payload - val partsList = payload.get("parts") match { - case Some(rawList: List[_]) => - try { - rawList.map { - case part: Map[_, _] => - val partMap = part.asInstanceOf[Map[String, Any]] - val partNumber = partMap.get("PartNumber") match { - case Some(i: Int) => i - case Some(s: String) => s.toInt - case _ => throw new BadRequestException("Invalid or missing PartNumber") - } - val eTag = partMap.get("ETag") match { - case Some(s: String) => s - case _ => throw new BadRequestException("Invalid or missing ETag") - } - (partNumber, eTag) - - case _ => - throw new BadRequestException("Each part must be a Map[String, Any]") - } - } catch { - case e: NumberFormatException => - throw new BadRequestException("PartNumber must be an integer", e) - } - - case _ => - throw new BadRequestException("Missing or invalid 'parts' list in payload") - } + if (partNumber < expectedParts && contentLength < MINIMUM_NUM_OF_MULTIPART_S3_PART) { + throw new BadRequestException( + s"Part $partNumber is too small ($contentLength bytes). " + + s"All non-final parts must be >= $MINIMUM_NUM_OF_MULTIPART_S3_PART bytes." + ) + } - // Extract physical address from payload - val physicalAddress = payload.get("physicalAddress") match { - case Some(address: String) => address - case _ => throw new BadRequestException("Missing physicalAddress in payload") - } + val physicalAddr = Option(session.getPhysicalAddress).map(_.trim).getOrElse("") + if (physicalAddr.isEmpty) { + throw new WebApplicationException( + "Upload session is missing physicalAddress. Re-init the upload.", + Response.Status.INTERNAL_SERVER_ERROR + ) + } - // Complete the multipart upload with parts and physical address - val objectStats = LakeFSStorageClient.completePresignedMultipartUploads( - repositoryName, - filePath, - uploadIdValue, - partsList, - physicalAddress - ) + val uploadId = session.getUploadId + val (bucket, key) = LakeFSStorageClient.parsePhysicalAddress(physicalAddr) - Response - .ok( - Map( - "message" -> "Multipart upload completed successfully", - "filePath" -> objectStats.getPath - ) + // Per-part lock: if another request is streaming the same part, fail fast. 
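+      // The lock below is a SELECT ... FOR UPDATE NOWAIT on the part row; if another
+      // transaction already holds it, Postgres raises SQLState 55P03 (lock_not_available),
+      // which is translated into a 409 CONFLICT instead of blocking this request thread.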
+ val partRow = + try { + ctx + .selectFrom(DATASET_UPLOAD_SESSION_PART) + .where( + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID + .eq(uploadId) + .and(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.eq(partNumber)) ) - .build() - - case "abort" => - val uploadIdValue = uploadId.toScala.getOrElse( - throw new BadRequestException("uploadId is required for abortion") - ) + .forUpdate() + .noWait() + .fetchOne() + } catch { + case e: DataAccessException + if Option(e.getCause) + .collect { case s: SQLException => s.getSQLState } + .contains("55P03") => + throw new WebApplicationException( + s"Part $partNumber is already being uploaded", + Response.Status.CONFLICT + ) + } - // Extract physical address from payload - val physicalAddress = payload.get("physicalAddress") match { - case Some(address: String) => address - case _ => throw new BadRequestException("Missing physicalAddress in payload") - } + if (partRow == null) { + // Should not happen if init pre-created rows + throw new WebApplicationException( + s"Part row not initialized for part $partNumber. Re-init the upload.", + Response.Status.INTERNAL_SERVER_ERROR + ) + } - // Abort the multipart upload - LakeFSStorageClient.abortPresignedMultipartUploads( - repositoryName, - filePath, - uploadIdValue, - physicalAddress + // Idempotency: if ETag already set, accept the retry quickly. + val existing = Option(partRow.getEtag).map(_.trim).getOrElse("") + if (existing.isEmpty) { + // Stream to S3 while holding the part lock (prevents concurrent streams for same part) + val response: UploadPartResponse = + S3StorageClient.uploadPart( + bucket = bucket, + key = key, + uploadId = uploadId, + partNumber = partNumber, + inputStream = partStream, + contentLength = Some(contentLength) ) - Response.ok(Map("message" -> "Multipart upload aborted successfully")).build() + val etagClean = Option(response.eTag()).map(_.replace("\"", "")).map(_.trim).getOrElse("") + if (etagClean.isEmpty) { + throw new WebApplicationException( + s"Missing ETag returned from S3 for part $partNumber", + Response.Status.INTERNAL_SERVER_ERROR + ) + } - case _ => - throw new BadRequestException("Invalid type parameter. 
Use 'init', 'finish', or 'abort'.") + ctx + .update(DATASET_UPLOAD_SESSION_PART) + .set(DATASET_UPLOAD_SESSION_PART.ETAG, etagClean) + .where( + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID + .eq(uploadId) + .and(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.eq(partNumber)) + ) + .execute() } + Response.ok().build() } } @@ -1014,9 +1048,8 @@ class DatasetResource { val ownerNode = DatasetFileNode .fromLakeFSRepositoryCommittedObjects( Map( - (user.getEmail, dataset.getName, latestVersion.getName) -> - LakeFSStorageClient - .retrieveObjectsOfVersion(dataset.getRepositoryName, latestVersion.getVersionHash) + (user.getEmail, dataset.getName, latestVersion.getName) -> LakeFSStorageClient + .retrieveObjectsOfVersion(dataset.getRepositoryName, latestVersion.getVersionHash) ) ) .head @@ -1372,4 +1405,378 @@ class DatasetResource { Right(response) } } + + // === Multipart helpers === + + private def getDatasetBy(ownerEmail: String, datasetName: String) = { + val dataset = context + .select(DATASET.fields: _*) + .from(DATASET) + .leftJoin(USER) + .on(USER.UID.eq(DATASET.OWNER_UID)) + .where(USER.EMAIL.eq(ownerEmail)) + .and(DATASET.NAME.eq(datasetName)) + .fetchOneInto(classOf[Dataset]) + if (dataset == null) { + throw new BadRequestException("Dataset not found") + } + dataset + } + + private def validateFilePathOrThrow(filePath: String): String = { + val p = Option(filePath).getOrElse("") + val s = p.replace("\\", "/") + if ( + p.isEmpty || + s.startsWith("/") || + s.split("/").exists(seg => seg == "." || seg == "..") || + s.exists(ch => ch == 0.toChar || ch < 0x20.toChar || ch == 0x7f.toChar) + ) throw new BadRequestException("Invalid filePath") + p + } + + private def initMultipartUpload( + did: Integer, + encodedFilePath: String, + numParts: Optional[Integer], + uid: Integer + ): Response = { + + withTransaction(context) { ctx => + if (!userHasWriteAccess(ctx, did, uid)) { + throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE) + } + + val dataset = getDatasetByID(ctx, did) + val repositoryName = dataset.getRepositoryName + + val filePath = + validateFilePathOrThrow(URLDecoder.decode(encodedFilePath, StandardCharsets.UTF_8.name())) + + val numPartsValue = numParts.toScala.getOrElse { + throw new BadRequestException("numParts is required for initialization") + } + if (numPartsValue < 1 || numPartsValue > MAXIMUM_NUM_OF_MULTIPART_S3_PARTS) { + throw new BadRequestException( + "numParts must be between 1 and " + MAXIMUM_NUM_OF_MULTIPART_S3_PARTS + ) + } + + // Reject if a session already exists + val exists = ctx.fetchExists( + ctx + .selectOne() + .from(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(uid) + .and(DATASET_UPLOAD_SESSION.DID.eq(did)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + ) + if (exists) { + throw new WebApplicationException( + "Upload already in progress for this filePath", + Response.Status.CONFLICT + ) + } + + val presign = LakeFSStorageClient.initiatePresignedMultipartUploads( + repositoryName, + filePath, + numPartsValue + ) + + val uploadIdStr = presign.getUploadId + val physicalAddr = presign.getPhysicalAddress + + // If anything fails after this point, abort LakeFS multipart + try { + val rowsInserted = ctx + .insertInto(DATASET_UPLOAD_SESSION) + .set(DATASET_UPLOAD_SESSION.FILE_PATH, filePath) + .set(DATASET_UPLOAD_SESSION.DID, did) + .set(DATASET_UPLOAD_SESSION.UID, uid) + .set(DATASET_UPLOAD_SESSION.UPLOAD_ID, uploadIdStr) + .set(DATASET_UPLOAD_SESSION.PHYSICAL_ADDRESS, physicalAddr) + 
.set(DATASET_UPLOAD_SESSION.NUM_PARTS_REQUESTED, numPartsValue) + .onDuplicateKeyIgnore() + .execute() + + if (rowsInserted != 1) { + LakeFSStorageClient.abortPresignedMultipartUploads( + repositoryName, + filePath, + uploadIdStr, + physicalAddr + ) + throw new WebApplicationException( + "Upload already in progress for this filePath", + Response.Status.CONFLICT + ) + } + + // Pre-create part rows 1..numPartsValue with empty ETag. + // This makes per-part locking cheap and deterministic. + + val gs = DSL.generateSeries(1, numPartsValue).asTable("gs", "pn") + val PN = gs.field("pn", classOf[Integer]) + + ctx + .insertInto( + DATASET_UPLOAD_SESSION_PART, + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID, + DATASET_UPLOAD_SESSION_PART.PART_NUMBER, + DATASET_UPLOAD_SESSION_PART.ETAG + ) + .select( + ctx + .select( + inl(uploadIdStr), + PN, + inl("") // placeholder empty etag + ) + .from(gs) + ) + .execute() + + Response.ok().build() + } catch { + case e: Exception => + // rollback will remove session + parts rows; we still must abort LakeFS + try { + LakeFSStorageClient.abortPresignedMultipartUploads( + repositoryName, + filePath, + uploadIdStr, + physicalAddr + ) + } catch { case _: Throwable => () } + throw e + } + } + } + + private def finishMultipartUpload( + did: Integer, + encodedFilePath: String, + uid: Int + ): Response = { + + val filePath = validateFilePathOrThrow( + URLDecoder.decode(encodedFilePath, StandardCharsets.UTF_8.name()) + ) + + withTransaction(context) { ctx => + if (!userHasWriteAccess(ctx, did, uid)) { + throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE) + } + + val dataset = getDatasetByID(ctx, did) + + // Lock the session so abort/finish don't race each other + val session = + try { + ctx + .selectFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(uid) + .and(DATASET_UPLOAD_SESSION.DID.eq(did)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .forUpdate() + .noWait() + .fetchOne() + } catch { + case e: DataAccessException + if Option(e.getCause) + .collect { case s: SQLException => s.getSQLState } + .contains("55P03") => + throw new WebApplicationException( + "Upload is already being finalized/aborted", + Response.Status.CONFLICT + ) + } + + if (session == null) { + throw new NotFoundException("Upload session not found or already finalized") + } + + val uploadId = session.getUploadId + val expectedParts = session.getNumPartsRequested + + val physicalAddr = Option(session.getPhysicalAddress).map(_.trim).getOrElse("") + if (physicalAddr.isEmpty) { + throw new WebApplicationException( + "Upload session is missing physicalAddress. Re-init the upload.", + Response.Status.INTERNAL_SERVER_ERROR + ) + } + + val total = DSL.count() + val done = + DSL + .count() + .filterWhere(DATASET_UPLOAD_SESSION_PART.ETAG.ne("")) + .as("done") + + val agg = ctx + .select(total.as("total"), done) + .from(DATASET_UPLOAD_SESSION_PART) + .where(DATASET_UPLOAD_SESSION_PART.UPLOAD_ID.eq(uploadId)) + .fetchOne() + + val totalCnt = agg.get("total", classOf[java.lang.Integer]).intValue() + val doneCnt = agg.get("done", classOf[java.lang.Integer]).intValue() + + if (totalCnt != expectedParts.toLong) { + throw new WebApplicationException( + s"Part table mismatch: expected $expectedParts rows but found $total. 
Re-init the upload.", + Response.Status.INTERNAL_SERVER_ERROR + ) + } + + if (doneCnt != expectedParts.toLong) { + val missing = ctx + .select(DATASET_UPLOAD_SESSION_PART.PART_NUMBER) + .from(DATASET_UPLOAD_SESSION_PART) + .where( + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID + .eq(uploadId) + .and(DATASET_UPLOAD_SESSION_PART.ETAG.eq("")) + ) + .orderBy(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.asc()) + .limit(50) + .fetch(DATASET_UPLOAD_SESSION_PART.PART_NUMBER) + .asScala + .toList + + throw new WebApplicationException( + s"Upload incomplete. Some missing ETags for parts are: ${missing.mkString(",")}", + Response.Status.CONFLICT + ) + } + + // Build partsList in order + val partsList: List[(Int, String)] = + ctx + .select(DATASET_UPLOAD_SESSION_PART.PART_NUMBER, DATASET_UPLOAD_SESSION_PART.ETAG) + .from(DATASET_UPLOAD_SESSION_PART) + .where(DATASET_UPLOAD_SESSION_PART.UPLOAD_ID.eq(uploadId)) + .orderBy(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.asc()) + .fetch() + .asScala + .map(r => + ( + r.get(DATASET_UPLOAD_SESSION_PART.PART_NUMBER).intValue(), + r.get(DATASET_UPLOAD_SESSION_PART.ETAG) + ) + ) + .toList + + val objectStats = LakeFSStorageClient.completePresignedMultipartUploads( + dataset.getRepositoryName, + filePath, + uploadId, + partsList, + physicalAddr + ) + + // Cleanup: delete the session; parts are removed by ON DELETE CASCADE + ctx + .deleteFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(uid) + .and(DATASET_UPLOAD_SESSION.DID.eq(did)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .execute() + + Response + .ok( + Map( + "message" -> "Multipart upload completed successfully", + "filePath" -> objectStats.getPath + ) + ) + .build() + } + } + + private def abortMultipartUpload( + did: Integer, + encodedFilePath: String, + uid: Int + ): Response = { + + val filePath = validateFilePathOrThrow( + URLDecoder.decode(encodedFilePath, StandardCharsets.UTF_8.name()) + ) + + withTransaction(context) { ctx => + if (!userHasWriteAccess(ctx, did, uid)) { + throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE) + } + + val dataset = getDatasetByID(ctx, did) + + val session = + try { + ctx + .selectFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(uid) + .and(DATASET_UPLOAD_SESSION.DID.eq(did)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .forUpdate() + .noWait() + .fetchOne() + } catch { + case e: DataAccessException + if Option(e.getCause) + .collect { case s: SQLException => s.getSQLState } + .contains("55P03") => + throw new WebApplicationException( + "Upload is already being finalized/aborted", + Response.Status.CONFLICT + ) + } + + if (session == null) { + throw new NotFoundException("Upload session not found or already finalized") + } + + val physicalAddr = Option(session.getPhysicalAddress).map(_.trim).getOrElse("") + if (physicalAddr.isEmpty) { + throw new WebApplicationException( + "Upload session is missing physicalAddress. 
Re-init the upload.", + Response.Status.INTERNAL_SERVER_ERROR + ) + } + + LakeFSStorageClient.abortPresignedMultipartUploads( + dataset.getRepositoryName, + filePath, + session.getUploadId, + physicalAddr + ) + + // Delete session; parts removed via ON DELETE CASCADE + ctx + .deleteFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(uid) + .and(DATASET_UPLOAD_SESSION.DID.eq(did)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .execute() + + Response.ok(Map("message" -> "Multipart upload aborted successfully")).build() + } + } } diff --git a/file-service/src/test/scala/org/apache/texera/service/MockLakeFS.scala b/file-service/src/test/scala/org/apache/texera/service/MockLakeFS.scala index fd1f0b8c903..554f7de055b 100644 --- a/file-service/src/test/scala/org/apache/texera/service/MockLakeFS.scala +++ b/file-service/src/test/scala/org/apache/texera/service/MockLakeFS.scala @@ -20,11 +20,18 @@ package org.apache.texera.service import com.dimafeng.testcontainers._ +import io.lakefs.clients.sdk.{ApiClient, RepositoriesApi} import org.apache.texera.amber.config.StorageConfig import org.apache.texera.service.util.S3StorageClient import org.scalatest.{BeforeAndAfterAll, Suite} import org.testcontainers.containers.Network import org.testcontainers.utility.DockerImageName +import software.amazon.awssdk.auth.credentials.{AwsBasicCredentials, StaticCredentialsProvider} +import software.amazon.awssdk.regions.Region +import software.amazon.awssdk.services.s3.S3Client +import software.amazon.awssdk.services.s3.S3Configuration + +import java.net.URI /** * Trait to spin up a LakeFS + MinIO + Postgres stack using Testcontainers, @@ -58,9 +65,14 @@ trait MockLakeFS extends ForAllTestContainer with BeforeAndAfterAll { self: Suit s"postgresql://${postgres.username}:${postgres.password}" + s"@${postgres.container.getNetworkAliases.get(0)}:5432/${postgres.databaseName}" + s"?sslmode=disable" + val lakefsUsername = "texera-admin" + + // These are the API credentials created/used during setup. + // In lakeFS, the access key + secret key are used as basic-auth username/password for the API. val lakefsAccessKeyID = "AKIAIOSFOLKFSSAMPLES" val lakefsSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + val lakefs = GenericContainer( dockerImage = "treeverse/lakefs:1.51", exposedPorts = Seq(8000), @@ -87,11 +99,44 @@ trait MockLakeFS extends ForAllTestContainer with BeforeAndAfterAll { self: Suit def lakefsBaseUrl: String = s"http://${lakefs.host}:${lakefs.mappedPort(8000)}" def minioEndpoint: String = s"http://${minio.host}:${minio.mappedPort(9000)}" + def lakefsApiBasePath: String = s"$lakefsBaseUrl/api/v1" + + // ---- Clients (lazy so they initialize after containers are started) ---- + + lazy val lakefsApiClient: ApiClient = { + val c = new ApiClient() + c.setBasePath(lakefsApiBasePath) + // basic-auth for lakeFS API uses accessKey as username, secretKey as password + c.setUsername(lakefsAccessKeyID) + c.setPassword(lakefsSecretAccessKey) + c + } + + lazy val repositoriesApi: RepositoriesApi = new RepositoriesApi(lakefsApiClient) + + /** + * S3 client pointed at MinIO. + * + * Notes: + * - Region can be any value for MinIO, but MUST match what your signing expects. + * so we use that. 
+ * - Path-style is important: http://host:port/bucket/key + */ + lazy val s3Client: S3Client = { + val creds = AwsBasicCredentials.create("texera_minio", "password") + S3Client + .builder() + .endpointOverride(URI.create(StorageConfig.s3Endpoint)) // set in afterStart() + .region(Region.US_WEST_2) + .credentialsProvider(StaticCredentialsProvider.create(creds)) + .serviceConfiguration(S3Configuration.builder().pathStyleAccessEnabled(true).build()) + .build() + } override def afterStart(): Unit = { super.afterStart() - // setup LakeFS + // setup LakeFS (idempotent-ish, but will fail if it truly cannot run) val lakefsSetupResult = lakefs.container.execInContainer( "lakefs", "setup", @@ -103,16 +148,14 @@ trait MockLakeFS extends ForAllTestContainer with BeforeAndAfterAll { self: Suit lakefsSecretAccessKey ) if (lakefsSetupResult.getExitCode != 0) { - throw new RuntimeException( - s"Failed to setup LakeFS: ${lakefsSetupResult.getStderr}" - ) + throw new RuntimeException(s"Failed to setup LakeFS: ${lakefsSetupResult.getStderr}") } // replace storage endpoints in StorageConfig StorageConfig.s3Endpoint = minioEndpoint - StorageConfig.lakefsEndpoint = s"$lakefsBaseUrl/api/v1" + StorageConfig.lakefsEndpoint = lakefsApiBasePath - // create S3 bucket + // create S3 bucket used by lakeFS in tests S3StorageClient.createBucketIfNotExist(StorageConfig.lakefsBucketName) } } diff --git a/file-service/src/test/scala/org/apache/texera/service/resource/DatasetMultipartUploadSpec.scala b/file-service/src/test/scala/org/apache/texera/service/resource/DatasetMultipartUploadSpec.scala new file mode 100644 index 00000000000..c1292991b22 --- /dev/null +++ b/file-service/src/test/scala/org/apache/texera/service/resource/DatasetMultipartUploadSpec.scala @@ -0,0 +1,1052 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.texera.service.resource + +import jakarta.ws.rs._ +import jakarta.ws.rs.core.{Cookie, HttpHeaders, MediaType, MultivaluedHashMap, Response} +import io.lakefs.clients.sdk.ApiException +import org.apache.texera.amber.core.storage.util.LakeFSStorageClient +import org.apache.texera.auth.SessionUser +import org.apache.texera.dao.MockTexeraDB +import org.apache.texera.dao.jooq.generated.enums.UserRoleEnum +import org.apache.texera.dao.jooq.generated.tables.DatasetUploadSession.DATASET_UPLOAD_SESSION +import org.apache.texera.dao.jooq.generated.tables.DatasetUploadSessionPart.DATASET_UPLOAD_SESSION_PART +import org.apache.texera.dao.jooq.generated.tables.daos.{DatasetDao, UserDao} +import org.apache.texera.dao.jooq.generated.tables.pojos.{Dataset, User} +import org.apache.texera.service.MockLakeFS +import org.jooq.SQLDialect +import org.jooq.impl.DSL +import org.scalatest.tagobjects.Slow +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Tag} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.jdk.CollectionConverters._ +import java.io.{ByteArrayInputStream, IOException, InputStream} +import java.net.URLEncoder +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} +import java.security.MessageDigest +import java.util.concurrent.CyclicBarrier +import java.util.{Collections, Date, Locale, Optional} +import scala.util.Random + +object StressMultipart extends Tag("org.apache.texera.stress.multipart") + +class DatasetMultipartUploadSpec + extends AnyFlatSpec + with Matchers + with MockTexeraDB + with MockLakeFS + with BeforeAndAfterAll + with BeforeAndAfterEach { + + // ---------- execution context (for race tests) ---------- + private implicit val ec: ExecutionContext = ExecutionContext.global + + // ---------- test fixtures ---------- + private val testUser: User = { + val u = new User + u.setName("multipart_user") + u.setPassword("123") + u.setEmail("multipart_user@test.com") + u.setRole(UserRoleEnum.ADMIN) + u + } + + // REGULAR user, but no WRITE access to someone else's dataset. 
+ private val testUser2: User = { + val u = new User + u.setName("multipart_user2") + u.setPassword("123") + u.setEmail("multipart_user2@test.com") + u.setRole(UserRoleEnum.REGULAR) + u + } + + private val testRepoName: String = + s"multipart-ds-${System.nanoTime()}-${Random.alphanumeric.take(6).mkString.toLowerCase}" + + private val testDataset: Dataset = { + val ds = new Dataset + ds.setName("multipart-ds") + ds.setRepositoryName(testRepoName) + ds.setIsPublic(true) + ds.setIsDownloadable(true) + ds.setDescription("dataset for multipart upload tests") + ds + } + + lazy val datasetDao = new DatasetDao(getDSLContext.configuration()) + lazy val datasetResource = new DatasetResource() + + lazy val sessionUser = new SessionUser(testUser) + lazy val sessionUser2 = new SessionUser(testUser2) + + // ---------- lifecycle ---------- + override protected def beforeAll(): Unit = { + super.beforeAll() + + initializeDBAndReplaceDSLContext() + + val userDao = new UserDao(getDSLContext.configuration()) + userDao.insert(testUser) + userDao.insert(testUser2) + + testDataset.setOwnerUid(testUser.getUid) + datasetDao.insert(testDataset) + } + + override protected def afterAll(): Unit = { + try shutdownDB() + finally super.afterAll() + } + + override protected def beforeEach(): Unit = { + super.beforeEach() + // Repo must exist for presigned multipart init to succeed. + // If it already exists, ignore 409. + try LakeFSStorageClient.initRepo(testDataset.getRepositoryName) + catch { + case e: ApiException if e.getCode == 409 => // ok + } + } + // ---------- SHA-256 Utils ---------- + private def sha256OfChunks(chunks: Seq[Array[Byte]]): Array[Byte] = { + val md = MessageDigest.getInstance("SHA-256") + chunks.foreach(md.update) + md.digest() + } + + private def sha256OfFile(path: java.nio.file.Path): Array[Byte] = { + val md = MessageDigest.getInstance("SHA-256") + val in = Files.newInputStream(path) + try { + val buf = new Array[Byte](8192) + var n = in.read(buf) + while (n != -1) { + md.update(buf, 0, n) + n = in.read(buf) + } + md.digest() + } finally in.close() + } + // ---------- helpers ---------- + private def enc(s: String): String = + URLEncoder.encode(s, StandardCharsets.UTF_8.name()) + + /** Minimum part-size rule (S3-style): every part except the LAST must be >= 5 MiB. 
*/ + private val MinNonFinalPartBytes: Int = 5 * 1024 * 1024 + private def minPartBytes(b: Byte): Array[Byte] = + Array.fill[Byte](MinNonFinalPartBytes)(b) + + private def tinyBytes(b: Byte, n: Int = 1): Array[Byte] = + Array.fill[Byte](n)(b) + + /** Minimal HttpHeaders impl needed by DatasetResource.uploadPart */ + private def mkHeaders(contentLength: Long): HttpHeaders = + new HttpHeaders { + private val headers = new MultivaluedHashMap[String, String]() + headers.putSingle(HttpHeaders.CONTENT_LENGTH, contentLength.toString) + + override def getHeaderString(name: String): String = headers.getFirst(name) + override def getRequestHeaders = headers + override def getRequestHeader(name: String) = + Option(headers.get(name)).getOrElse(Collections.emptyList[String]()) + + override def getAcceptableMediaTypes = Collections.emptyList[MediaType]() + override def getAcceptableLanguages = Collections.emptyList[Locale]() + override def getMediaType: MediaType = null + override def getLanguage: Locale = null + override def getCookies = Collections.emptyMap[String, Cookie]() + override def getDate: Date = null + override def getLength: Int = contentLength.toInt + } + + private def mkHeadersMissingContentLength: HttpHeaders = + new HttpHeaders { + private val headers = new MultivaluedHashMap[String, String]() + override def getHeaderString(name: String): String = null + override def getRequestHeaders = headers + override def getRequestHeader(name: String) = Collections.emptyList[String]() + override def getAcceptableMediaTypes = Collections.emptyList[MediaType]() + override def getAcceptableLanguages = Collections.emptyList[Locale]() + override def getMediaType: MediaType = null + override def getLanguage: Locale = null + override def getCookies = Collections.emptyMap[String, Cookie]() + override def getDate: Date = null + override def getLength: Int = -1 + } + + private def uniqueFilePath(prefix: String): String = + s"$prefix/${System.nanoTime()}-${Random.alphanumeric.take(8).mkString}.bin" + + private def initUpload( + filePath: String, + numParts: Int, + user: SessionUser = sessionUser + ): Response = + datasetResource.multipartUpload( + "init", + testUser.getEmail, + testDataset.getName, + enc(filePath), + Optional.of(numParts), + user + ) + + private def finishUpload(filePath: String, user: SessionUser = sessionUser): Response = + datasetResource.multipartUpload( + "finish", + testUser.getEmail, + testDataset.getName, + enc(filePath), + Optional.empty(), + user + ) + + private def abortUpload(filePath: String, user: SessionUser = sessionUser): Response = + datasetResource.multipartUpload( + "abort", + testUser.getEmail, + testDataset.getName, + enc(filePath), + Optional.empty(), + user + ) + + private def uploadPart( + filePath: String, + partNumber: Int, + bytes: Array[Byte], + user: SessionUser = sessionUser, + contentLengthOverride: Option[Long] = None, + missingContentLength: Boolean = false + ): Response = { + val hdrs = + if (missingContentLength) mkHeadersMissingContentLength + else mkHeaders(contentLengthOverride.getOrElse(bytes.length.toLong)) + + datasetResource.uploadPart( + testUser.getEmail, + testDataset.getName, + enc(filePath), + partNumber, + new ByteArrayInputStream(bytes), + hdrs, + user + ) + } + + private def uploadPartWithStream( + filePath: String, + partNumber: Int, + stream: InputStream, + contentLength: Long, + user: SessionUser = sessionUser + ): Response = + datasetResource.uploadPart( + testUser.getEmail, + testDataset.getName, + enc(filePath), + partNumber, + stream, 
+ mkHeaders(contentLength), + user + ) + + private def fetchSession(filePath: String) = + getDSLContext + .selectFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(testUser.getUid) + .and(DATASET_UPLOAD_SESSION.DID.eq(testDataset.getDid)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .fetchOne() + + private def fetchPartRows(uploadId: String) = + getDSLContext + .selectFrom(DATASET_UPLOAD_SESSION_PART) + .where(DATASET_UPLOAD_SESSION_PART.UPLOAD_ID.eq(uploadId)) + .fetch() + .asScala + .toList + + private def fetchUploadIdOrFail(filePath: String): String = { + val s = fetchSession(filePath) + s should not be null + s.getUploadId + } + + private def assertPlaceholdersCreated(uploadId: String, expectedParts: Int): Unit = { + val rows = fetchPartRows(uploadId).sortBy(_.getPartNumber) + rows.size shouldEqual expectedParts + rows.head.getPartNumber shouldEqual 1 + rows.last.getPartNumber shouldEqual expectedParts + rows.foreach { r => + r.getEtag should not be null + r.getEtag shouldEqual "" // placeholder convention + } + } + + private def assertStatus(ex: WebApplicationException, status: Int): Unit = + ex.getResponse.getStatus shouldEqual status + + // --------------------------------------------------------------------------- + // INIT TESTS + // --------------------------------------------------------------------------- + + "multipart-upload?type=init" should "create an upload session row + precreate part placeholders (happy path)" in { + val filePath = uniqueFilePath("init-happy") + val resp = initUpload(filePath, numParts = 3) + + resp.getStatus shouldEqual 200 + + val s = fetchSession(filePath) + s should not be null + s.getNumPartsRequested shouldEqual 3 + s.getUploadId should not be null + s.getPhysicalAddress should not be null + + assertPlaceholdersCreated(s.getUploadId, expectedParts = 3) + } + + it should "reject missing numParts" in { + val filePath = uniqueFilePath("init-missing-numparts") + val ex = intercept[BadRequestException] { + datasetResource.multipartUpload( + "init", + testUser.getEmail, + testDataset.getName, + enc(filePath), + Optional.empty(), + sessionUser + ) + } + assertStatus(ex, 400) + } + + it should "reject invalid numParts (0, negative, too large)" in { + val filePath = uniqueFilePath("init-bad-numparts") + assertStatus(intercept[BadRequestException] { initUpload(filePath, 0) }, 400) + assertStatus(intercept[BadRequestException] { initUpload(filePath, -1) }, 400) + assertStatus(intercept[BadRequestException] { initUpload(filePath, 1000000000) }, 400) + } + + it should "reject invalid filePath (empty, absolute, '.', '..', control chars)" in { + assertStatus(intercept[BadRequestException] { initUpload("./nope.bin", 2) }, 400) + assertStatus(intercept[BadRequestException] { initUpload("/absolute.bin", 2) }, 400) + assertStatus(intercept[BadRequestException] { initUpload("a/./b.bin", 2) }, 400) + + // traversal-like '..' 
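+    // validateFilePathOrThrow normalizes backslashes and rejects any "." or ".." segment,
+    // so both the leading and embedded traversal forms below should fail with 400.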
+ assertStatus(intercept[BadRequestException] { initUpload("../escape.bin", 2) }, 400) + assertStatus(intercept[BadRequestException] { initUpload("a/../escape.bin", 2) }, 400) + + // control char (0x00) + assertStatus( + intercept[BadRequestException] { + initUpload(s"a/${0.toChar}b.bin", 2) + }, + 400 + ) + } + + it should "reject invalid type parameter" in { + val filePath = uniqueFilePath("init-bad-type") + val ex = intercept[BadRequestException] { + datasetResource.multipartUpload( + "not-a-real-type", + testUser.getEmail, + testDataset.getName, + enc(filePath), + Optional.empty(), + sessionUser + ) + } + assertStatus(ex, 400) + } + + it should "reject init when caller lacks WRITE access" in { + val filePath = uniqueFilePath("init-forbidden") + val ex = intercept[ForbiddenException] { + initUpload(filePath, numParts = 2, user = sessionUser2) + } + assertStatus(ex, 403) + } + + it should "handle init race: exactly one succeeds, one gets 409 CONFLICT" in { + val filePath = uniqueFilePath("init-race") + val barrier = new CyclicBarrier(2) + + def callInit(): Either[Throwable, Response] = + try { + barrier.await() + Right(initUpload(filePath, numParts = 2)) + } catch { + case t: Throwable => Left(t) + } + + val f1 = Future(callInit()) + val f2 = Future(callInit()) + val results = Await.result(Future.sequence(Seq(f1, f2)), 30.seconds) + + val oks = results.collect { case Right(r) if r.getStatus == 200 => r } + val fails = results.collect { case Left(t) => t } + + oks.size shouldEqual 1 + fails.size shouldEqual 1 + + fails.head match { + case e: WebApplicationException => assertStatus(e, 409) + case other => + fail( + s"Expected WebApplicationException(CONFLICT), got: ${other.getClass} / ${other.getMessage}" + ) + } + + val s = fetchSession(filePath) + s should not be null + assertPlaceholdersCreated(s.getUploadId, expectedParts = 2) + } + + it should "reject sequential double init with 409 CONFLICT" in { + val filePath = uniqueFilePath("init-double") + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + + val ex = intercept[WebApplicationException] { initUpload(filePath, numParts = 2) } + assertStatus(ex, 409) + } + + // --------------------------------------------------------------------------- + // PART UPLOAD TESTS + // --------------------------------------------------------------------------- + + "multipart-upload/part" should "reject uploadPart if init was not called" in { + val filePath = uniqueFilePath("part-no-init") + val ex = intercept[NotFoundException] { + uploadPart(filePath, partNumber = 1, bytes = Array[Byte](1, 2, 3)) + } + assertStatus(ex, 404) + } + + it should "reject missing/invalid Content-Length" in { + val filePath = uniqueFilePath("part-bad-cl") + initUpload(filePath, numParts = 2) + + assertStatus( + intercept[BadRequestException] { + uploadPart( + filePath, + partNumber = 1, + bytes = Array[Byte](1, 2, 3), + missingContentLength = true + ) + }, + 400 + ) + + assertStatus( + intercept[BadRequestException] { + uploadPart( + filePath, + partNumber = 1, + bytes = Array[Byte](1, 2, 3), + contentLengthOverride = Some(0L) + ) + }, + 400 + ) + + assertStatus( + intercept[BadRequestException] { + uploadPart( + filePath, + partNumber = 1, + bytes = Array[Byte](1, 2, 3), + contentLengthOverride = Some(-5L) + ) + }, + 400 + ) + } + + it should "reject null/empty filePath param early without depending on error text" in { + val hdrs = mkHeaders(1L) + + val ex1 = intercept[BadRequestException] { + datasetResource.uploadPart( + testUser.getEmail, + testDataset.getName, 
+ null, // encodedFilePath null + 1, + new ByteArrayInputStream(Array.emptyByteArray), + hdrs, + sessionUser + ) + } + assertStatus(ex1, 400) + + val ex2 = intercept[BadRequestException] { + datasetResource.uploadPart( + testUser.getEmail, + testDataset.getName, + "", // empty + 1, + new ByteArrayInputStream(Array.emptyByteArray), + hdrs, + sessionUser + ) + } + assertStatus(ex2, 400) + } + + it should "reject invalid partNumber (< 1) and partNumber > requested" in { + val filePath = uniqueFilePath("part-bad-pn") + initUpload(filePath, numParts = 2) + + assertStatus( + intercept[BadRequestException] { + uploadPart(filePath, partNumber = 0, bytes = tinyBytes(1.toByte)) + }, + 400 + ) + + // Ensure we don't fail min-size check before we hit range validation. + assertStatus( + intercept[BadRequestException] { + uploadPart(filePath, partNumber = 3, bytes = minPartBytes(2.toByte)) + }, + 400 + ) + } + + it should "reject a non-final part smaller than the minimum size (without checking message)" in { + val filePath = uniqueFilePath("part-too-small-nonfinal") + initUpload(filePath, numParts = 2) // part 1 is NON-FINAL + + val ex = intercept[BadRequestException] { + uploadPart(filePath, partNumber = 1, bytes = tinyBytes(1.toByte)) + } + assertStatus(ex, 400) + + // DB should remain unchanged (etag still empty) + val uploadId = fetchUploadIdOrFail(filePath) + fetchPartRows(uploadId).find(_.getPartNumber == 1).get.getEtag shouldEqual "" + } + + it should "upload a part successfully and persist its ETag into DATASET_UPLOAD_SESSION_PART" in { + val filePath = uniqueFilePath("part-happy-db") + initUpload(filePath, numParts = 2) + + val uploadId = fetchUploadIdOrFail(filePath) + + // Before upload: placeholder etag empty + fetchPartRows(uploadId).find(_.getPartNumber == 1).get.getEtag shouldEqual "" + + val bytes = minPartBytes(7.toByte) + uploadPart(filePath, partNumber = 1, bytes = bytes).getStatus shouldEqual 200 + + val after = fetchPartRows(uploadId).find(_.getPartNumber == 1).get + after.getEtag should not equal "" + } + + it should "allow retrying the same part sequentially (no duplicates, etag ends non-empty)" in { + val filePath = uniqueFilePath("part-retry") + initUpload(filePath, numParts = 2) + val uploadId = fetchUploadIdOrFail(filePath) + + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 1, minPartBytes(2.toByte)).getStatus shouldEqual 200 + + val rows = fetchPartRows(uploadId).filter(_.getPartNumber == 1) + rows.size shouldEqual 1 + rows.head.getEtag should not equal "" + } + + it should "apply per-part locking: return 409 if that part row is locked by another uploader" in { + val filePath = uniqueFilePath("part-lock") + initUpload(filePath, numParts = 2) + val uploadId = fetchUploadIdOrFail(filePath) + + val cp = getDSLContext.configuration().connectionProvider() + val conn = cp.acquire() + conn.setAutoCommit(false) + + try { + val locking = DSL.using(conn, SQLDialect.POSTGRES) + locking + .selectFrom(DATASET_UPLOAD_SESSION_PART) + .where( + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID + .eq(uploadId) + .and(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.eq(1)) + ) + .forUpdate() + .fetchOne() + + val ex = intercept[WebApplicationException] { + uploadPart(filePath, 1, minPartBytes(1.toByte)) + } + assertStatus(ex, 409) + } finally { + conn.rollback() + cp.release(conn) + } + + // After releasing lock, upload should succeed + uploadPart(filePath, 1, minPartBytes(3.toByte)).getStatus shouldEqual 200 + } + + it should "not block other parts: locking part 1 
does not prevent uploading part 2" in { + val filePath = uniqueFilePath("part-lock-other-part") + initUpload(filePath, numParts = 2) + val uploadId = fetchUploadIdOrFail(filePath) + + val cp = getDSLContext.configuration().connectionProvider() + val conn = cp.acquire() + conn.setAutoCommit(false) + + try { + val locking = DSL.using(conn, SQLDialect.POSTGRES) + locking + .selectFrom(DATASET_UPLOAD_SESSION_PART) + .where( + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID + .eq(uploadId) + .and(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.eq(1)) + ) + .forUpdate() + .fetchOne() + + // part 2 is FINAL, can be tiny + uploadPart(filePath, 2, tinyBytes(9.toByte)).getStatus shouldEqual 200 + } finally { + conn.rollback() + cp.release(conn) + } + } + + it should "reject uploadPart when caller lacks WRITE access" in { + val filePath = uniqueFilePath("part-forbidden") + initUpload(filePath, numParts = 2) + + val ex = intercept[ForbiddenException] { + uploadPart(filePath, 1, minPartBytes(1.toByte), user = sessionUser2) + } + assertStatus(ex, 403) + } + + // --------------------------------------------------------------------------- + // FINISH TESTS + // --------------------------------------------------------------------------- + + "multipart-upload?type=finish" should "reject finish if init was not called" in { + val filePath = uniqueFilePath("finish-no-init") + val ex = intercept[NotFoundException] { finishUpload(filePath) } + assertStatus(ex, 404) + } + + it should "reject finish when no parts were uploaded (all placeholders empty) without checking messages" in { + val filePath = uniqueFilePath("finish-no-parts") + initUpload(filePath, numParts = 2) + + val ex = intercept[WebApplicationException] { finishUpload(filePath) } + assertStatus(ex, 409) + + // session remains + fetchSession(filePath) should not be null + } + + it should "reject finish when some parts are missing (etag empty treated as missing)" in { + val filePath = uniqueFilePath("finish-missing") + initUpload(filePath, numParts = 3) + + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + + val ex = intercept[WebApplicationException] { finishUpload(filePath) } + assertStatus(ex, 409) + + val uploadId = fetchUploadIdOrFail(filePath) + fetchPartRows(uploadId).find(_.getPartNumber == 2).get.getEtag shouldEqual "" + fetchPartRows(uploadId).find(_.getPartNumber == 3).get.getEtag shouldEqual "" + } + + it should "reject finish when extra part rows exist in DB (bypass endpoint) without checking messages" in { + val filePath = uniqueFilePath("finish-extra-db") + initUpload(filePath, numParts = 2) + + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 2, tinyBytes(2.toByte)).getStatus shouldEqual 200 + + val s = fetchSession(filePath) + val uploadId = s.getUploadId + + // Bypass: insert extra row partNumber=3 + getDSLContext + .insertInto(DATASET_UPLOAD_SESSION_PART) + .set(DATASET_UPLOAD_SESSION_PART.UPLOAD_ID, uploadId) + .set(DATASET_UPLOAD_SESSION_PART.PART_NUMBER, Integer.valueOf(3)) + .set(DATASET_UPLOAD_SESSION_PART.ETAG, "bogus-etag") + .execute() + + val ex = intercept[WebApplicationException] { finishUpload(filePath) } + assertStatus(ex, 500) + + // Ensure nothing got deleted + fetchSession(filePath) should not be null + fetchPartRows(uploadId).nonEmpty shouldEqual true + } + + it should "finish successfully when all parts have non-empty etags; delete session + part rows" in { + val filePath = uniqueFilePath("finish-happy") + initUpload(filePath, numParts = 3) + + 
uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 2, minPartBytes(2.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 3, tinyBytes(3.toByte)).getStatus shouldEqual 200 + + val uploadId = fetchUploadIdOrFail(filePath) + + val resp = finishUpload(filePath) + resp.getStatus shouldEqual 200 + + fetchSession(filePath) shouldBe null + fetchPartRows(uploadId) shouldBe empty + } + + it should "be idempotent-ish: second finish should return NotFound after successful finish" in { + val filePath = uniqueFilePath("finish-twice") + initUpload(filePath, numParts = 1) + uploadPart(filePath, 1, tinyBytes(1.toByte)).getStatus shouldEqual 200 + + finishUpload(filePath).getStatus shouldEqual 200 + + val ex = intercept[NotFoundException] { finishUpload(filePath) } + assertStatus(ex, 404) + } + + it should "reject finish when caller lacks WRITE access" in { + val filePath = uniqueFilePath("finish-forbidden") + initUpload(filePath, numParts = 1) + uploadPart(filePath, 1, tinyBytes(1.toByte)).getStatus shouldEqual 200 + + val ex = intercept[ForbiddenException] { finishUpload(filePath, user = sessionUser2) } + assertStatus(ex, 403) + } + + it should "return 409 CONFLICT if the session row is locked by another finalizer/aborter" in { + val filePath = uniqueFilePath("finish-lock-race") + initUpload(filePath, numParts = 1) + uploadPart(filePath, 1, tinyBytes(1.toByte)).getStatus shouldEqual 200 + + val cp = getDSLContext.configuration().connectionProvider() + val conn = cp.acquire() + conn.setAutoCommit(false) + + try { + val locking = DSL.using(conn, SQLDialect.POSTGRES) + locking + .selectFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(testUser.getUid) + .and(DATASET_UPLOAD_SESSION.DID.eq(testDataset.getDid)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .forUpdate() + .fetchOne() + + val ex = intercept[WebApplicationException] { finishUpload(filePath) } + assertStatus(ex, 409) + } finally { + conn.rollback() + cp.release(conn) + } + } + + // --------------------------------------------------------------------------- + // ABORT TESTS + // --------------------------------------------------------------------------- + + "multipart-upload?type=abort" should "reject abort if init was not called" in { + val filePath = uniqueFilePath("abort-no-init") + val ex = intercept[NotFoundException] { abortUpload(filePath) } + assertStatus(ex, 404) + } + + it should "abort successfully; delete session + part rows" in { + val filePath = uniqueFilePath("abort-happy") + initUpload(filePath, numParts = 2) + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + + val uploadId = fetchUploadIdOrFail(filePath) + + abortUpload(filePath).getStatus shouldEqual 200 + + fetchSession(filePath) shouldBe null + fetchPartRows(uploadId) shouldBe empty + } + + it should "reject abort when caller lacks WRITE access" in { + val filePath = uniqueFilePath("abort-forbidden") + initUpload(filePath, numParts = 1) + + val ex = intercept[ForbiddenException] { abortUpload(filePath, user = sessionUser2) } + assertStatus(ex, 403) + } + + it should "return 409 CONFLICT if the session row is locked by another finalizer/aborter" in { + val filePath = uniqueFilePath("abort-lock-race") + initUpload(filePath, numParts = 1) + + val cp = getDSLContext.configuration().connectionProvider() + val conn = cp.acquire() + conn.setAutoCommit(false) + + try { + val locking = DSL.using(conn, SQLDialect.POSTGRES) + locking + .selectFrom(DATASET_UPLOAD_SESSION) + 
.where( + DATASET_UPLOAD_SESSION.UID + .eq(testUser.getUid) + .and(DATASET_UPLOAD_SESSION.DID.eq(testDataset.getDid)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .forUpdate() + .fetchOne() + + val ex = intercept[WebApplicationException] { abortUpload(filePath) } + assertStatus(ex, 409) + } finally { + conn.rollback() + cp.release(conn) + } + } + + it should "be consistent: abort after finish should return NotFound" in { + val filePath = uniqueFilePath("abort-after-finish") + initUpload(filePath, numParts = 1) + uploadPart(filePath, 1, tinyBytes(1.toByte)).getStatus shouldEqual 200 + + finishUpload(filePath).getStatus shouldEqual 200 + + val ex = intercept[NotFoundException] { abortUpload(filePath) } + assertStatus(ex, 404) + } + + // --------------------------------------------------------------------------- + // FAILURE / RESILIENCE (still unit tests; simulated failures) + // --------------------------------------------------------------------------- + + "multipart upload implementation" should "release locks and keep DB consistent if the incoming stream fails mid-upload (simulated network drop)" in { + val filePath = uniqueFilePath("netfail-upload-stream") + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + val uploadId = fetchUploadIdOrFail(filePath) + + val payload = minPartBytes(5.toByte) + + // InputStream that throws after a few reads + val flaky = new InputStream { + private var pos = 0 + override def read(): Int = { + if (pos >= 1024) throw new IOException("simulated network drop") + val b = payload(pos) & 0xff + pos += 1 + b + } + } + + intercept[Throwable] { + uploadPartWithStream( + filePath, + partNumber = 1, + stream = flaky, + contentLength = payload.length.toLong + ) + } + + // ETag should still be empty (no partial DB commit) + fetchPartRows(uploadId).find(_.getPartNumber == 1).get.getEtag shouldEqual "" + + // And the lock must be released (retry should succeed) + uploadPart(filePath, 1, payload).getStatus shouldEqual 200 + fetchPartRows(uploadId).find(_.getPartNumber == 1).get.getEtag should not equal "" + } + + it should "not delete session/parts if finalize fails downstream (simulate by corrupting an ETag)" in { + val filePath = uniqueFilePath("netfail-finish") + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 2, tinyBytes(2.toByte)).getStatus shouldEqual 200 + + val uploadId = fetchUploadIdOrFail(filePath) + + // Corrupt one ETag to force backend finalize failure (S3/LakeFS should reject). 
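+    // CompleteMultipartUpload validates part ETags against what was actually uploaded, so
+    // finish is expected to fail here; the checks below assert the session and part rows
+    // survive the failure so the client can retry or abort.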
+    getDSLContext
+      .update(DATASET_UPLOAD_SESSION_PART)
+      .set(DATASET_UPLOAD_SESSION_PART.ETAG, "definitely-not-a-real-etag")
+      .where(
+        DATASET_UPLOAD_SESSION_PART.UPLOAD_ID
+          .eq(uploadId)
+          .and(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.eq(1))
+      )
+      .execute()
+
+    intercept[Throwable] { finishUpload(filePath) }
+
+    // Nothing should be deleted on failure
+    fetchSession(filePath) should not be null
+    fetchPartRows(uploadId).nonEmpty shouldEqual true
+  }
+
+  // ---------------------------------------------------------------------------
+  // CORRUPTION CHECKS
+  // ---------------------------------------------------------------------------
+
+  it should "upload without corruption (sha256 matches final object)" in {
+    val filePath = uniqueFilePath("sha256-positive")
+    initUpload(filePath, numParts = 3).getStatus shouldEqual 200
+
+    val p1 = minPartBytes(1.toByte)
+    val p2 = minPartBytes(2.toByte)
+    val p3 = Array.fill[Byte](123)(3.toByte)
+
+    uploadPart(filePath, 1, p1).getStatus shouldEqual 200
+    uploadPart(filePath, 2, p2).getStatus shouldEqual 200
+    uploadPart(filePath, 3, p3).getStatus shouldEqual 200
+
+    finishUpload(filePath).getStatus shouldEqual 200
+
+    val expected = sha256OfChunks(Seq(p1, p2, p3))
+
+    val repoName = testDataset.getRepositoryName
+    val ref = "main"
+    val downloaded = LakeFSStorageClient.getFileFromRepo(repoName, ref, filePath)
+
+    val got = sha256OfFile(Paths.get(downloaded.toURI))
+
+    got.toSeq shouldEqual expected.toSeq
+  }
+
+  it should "detect corruption (sha256 mismatch when a part is altered)" in {
+    val filePath = uniqueFilePath("sha256-negative")
+    initUpload(filePath, numParts = 3).getStatus shouldEqual 200
+
+    val p1 = minPartBytes(1.toByte)
+    val p2 = minPartBytes(2.toByte)
+    val p3 = Array.fill[Byte](123)(3.toByte)
+
+    // Intended bytes hash
+    val intendedHash = sha256OfChunks(Seq(p1, p2, p3))
+
+    // Corrupt one byte in part 2 before uploading
+    val p2Corrupt = p2.clone()
+    p2Corrupt(0) = (p2Corrupt(0) ^ 0x01).toByte
+
+    uploadPart(filePath, 1, p1).getStatus shouldEqual 200
+    uploadPart(filePath, 2, p2Corrupt).getStatus shouldEqual 200
+    uploadPart(filePath, 3, p3).getStatus shouldEqual 200
+
+    finishUpload(filePath).getStatus shouldEqual 200
+
+    val repoName = testDataset.getRepositoryName
+    val ref = "main"
+    val downloaded = LakeFSStorageClient.getFileFromRepo(repoName, ref, filePath)
+
+    val gotHash = sha256OfFile(Paths.get(downloaded.toURI))
+
+    // Must NOT equal the intended bytes
+    gotHash.toSeq should not equal intendedHash.toSeq
+
+    val corruptHash = sha256OfChunks(Seq(p1, p2Corrupt, p3))
+    gotHash.toSeq shouldEqual corruptHash.toSeq
+  }
+
+  // ---------------------------------------------------------------------------
+  // STRESS / SOAK TESTS (tagged)
+  // ---------------------------------------------------------------------------
+
+  it should "survive 2 concurrent multipart uploads (fan-out)" taggedAs (StressMultipart, Slow) in {
+    val parallelUploads = 2
+    val maxParts = 3
+
+    def oneUpload(i: Int): Future[Unit] =
+      Future {
+        val filePath = uniqueFilePath(s"stress-$i")
+        val numParts = 2 + Random.nextInt(maxParts - 1)
+
+        initUpload(filePath, numParts).getStatus shouldEqual 200
+
+        // Upload parts concurrently (different parts, so no per-part conflicts expected)
+        val sharedMin = minPartBytes((i % 127).toByte)
+        val partFuts = (1 to numParts).map { pn =>
+          Future {
+            val bytes =
+              if (pn < numParts) sharedMin
+              else tinyBytes((pn % 127).toByte, n = 1024) // final tail, 1KiB
+            uploadPart(filePath, pn, bytes).getStatus shouldEqual 200
+          }
+        }
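+        // All parts are now in flight; each future targets a distinct (upload_id, part_number)
+        // row, so the per-part locks should only contend when the SAME part is retried.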
+ + Await.result(Future.sequence(partFuts), 60.seconds) + + finishUpload(filePath).getStatus shouldEqual 200 + fetchSession(filePath) shouldBe null + } + + val all = Future.sequence((1 to parallelUploads).map(oneUpload)) + Await.result(all, 180.seconds) + } + + it should "throttle concurrent uploads of the SAME part via per-part locks" taggedAs (StressMultipart, Slow) in { + val filePath = uniqueFilePath("stress-same-part") + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + + val contenders = 4 + val barrier = new CyclicBarrier(contenders) + + def tryUploadStatus(): Future[Int] = + Future { + barrier.await() + try { + uploadPart(filePath, 1, minPartBytes(7.toByte)).getStatus + } catch { + case e: WebApplicationException => e.getResponse.getStatus + } + } + + val statuses = + Await.result(Future.sequence((1 to contenders).map(_ => tryUploadStatus())), 60.seconds) + + statuses.foreach { s => s should (be(200) or be(409)) } + statuses.count(_ == 200) should be >= 1 + + val uploadId = fetchUploadIdOrFail(filePath) + val part1 = fetchPartRows(uploadId).find(_.getPartNumber == 1).get + part1.getEtag.trim should not be "" + } + +} diff --git a/frontend/src/app/dashboard/component/user/user-dataset/user-dataset-explorer/dataset-detail.component.ts b/frontend/src/app/dashboard/component/user/user-dataset/user-dataset-explorer/dataset-detail.component.ts index b4d12f5a28e..a821879fe1a 100644 --- a/frontend/src/app/dashboard/component/user/user-dataset/user-dataset-explorer/dataset-detail.component.ts +++ b/frontend/src/app/dashboard/component/user/user-dataset/user-dataset-explorer/dataset-detail.component.ts @@ -104,8 +104,8 @@ export class DatasetDetailComponent implements OnInit { // List of upload tasks – each task tracked by its filePath public uploadTasks: Array< MultipartUploadProgress & { - filePath: string; - } + filePath: string; + } > = []; @Output() userMakeChanges = new EventEmitter(); @@ -416,8 +416,6 @@ export class DatasetDetailComponent implements OnInit { filePath: file.name, percentage: 0, status: "initializing", - uploadId: "", - physicalAddress: "", }); // Start multipart upload const subscription = this.datasetService @@ -558,21 +556,19 @@ export class DatasetDetailComponent implements OnInit { this.onUploadComplete(); } + this.datasetService .finalizeMultipartUpload( this.ownerEmail, this.datasetName, task.filePath, - task.uploadId, - [], - task.physicalAddress, true // abort flag ) .pipe(untilDestroyed(this)) .subscribe(() => { this.notificationService.info(`${task.filePath} uploading has been terminated`); }); - // Remove the aborted task immediately + this.uploadTasks = this.uploadTasks.filter(t => t.filePath !== task.filePath); } diff --git a/frontend/src/app/dashboard/service/user/dataset/dataset.service.ts b/frontend/src/app/dashboard/service/user/dataset/dataset.service.ts index c09125d73b1..f7381d9bfed 100644 --- a/frontend/src/app/dashboard/service/user/dataset/dataset.service.ts +++ b/frontend/src/app/dashboard/service/user/dataset/dataset.service.ts @@ -27,6 +27,7 @@ import { DashboardDataset } from "../../../type/dashboard-dataset.interface"; import { DatasetFileNode } from "../../../../common/type/datasetVersionFileTree"; import { DatasetStagedObject } from "../../../../common/type/dataset-staged-object"; import { GuiConfigService } from "../../../../common/service/gui-config.service"; +import { AuthService } from "src/app/common/service/user/auth.service"; export const DATASET_BASE_URL = "dataset"; export const DATASET_CREATE_URL = DATASET_BASE_URL 
+ "/create"; @@ -51,11 +52,9 @@ export interface MultipartUploadProgress { filePath: string; percentage: number; status: "initializing" | "uploading" | "finished" | "aborted"; - uploadId: string; - physicalAddress: string; - uploadSpeed?: number; // bytes per second - estimatedTimeRemaining?: number; // seconds - totalTime?: number; // total seconds taken + uploadSpeed?: number; // bytes per second + estimatedTimeRemaining?: number; // seconds + totalTime?: number; // total seconds taken } @Injectable({ @@ -122,6 +121,7 @@ export class DatasetService { public retrieveAccessibleDatasets(): Observable { return this.http.get(`${AppSettings.getApiEndpoint()}/${DATASET_LIST_URL}`); } + public createDatasetVersion(did: number, newVersion: string): Observable { return this.http .post<{ @@ -141,6 +141,12 @@ export class DatasetService { /** * Handles multipart upload for large files using RxJS, * with a concurrency limit on how many parts we process in parallel. + * + * Backend flow: + * POST /dataset/multipart-upload?type=init&ownerEmail=...&datasetName=...&filePath=...&numParts=N + * POST /dataset/multipart-upload/part?ownerEmail=...&datasetName=...&filePath=...&partNumber= (body: raw chunk) + * POST /dataset/multipart-upload?type=finish&ownerEmail=...&datasetName=...&filePath=... + * POST /dataset/multipart-upload?type=abort&ownerEmail=...&datasetName=...&filePath=... */ public multipartUpload( ownerEmail: string, @@ -152,8 +158,8 @@ export class DatasetService { ): Observable { const partCount = Math.ceil(file.size / partSize); - return new Observable(observer => { - // Track upload progress for each part independently + return new Observable(observer => { + // Track upload progress (bytes) for each part independently const partProgress = new Map(); // Progress tracking state @@ -162,8 +168,15 @@ export class DatasetService { let lastETA = 0; let lastUpdateTime = 0; - // Calculate stats with smoothing + const lastStats = { + uploadSpeed: 0, + estimatedTimeRemaining: 0, + totalTime: 0, + }; + const getTotalTime = () => (startTime ? (Date.now() - startTime) / 1000 : 0); + + // Calculate stats with smoothing and simple throttling (~1s) const calculateStats = (totalUploaded: number) => { if (startTime === null) { startTime = Date.now(); @@ -172,25 +185,28 @@ export class DatasetService { const now = Date.now(); const elapsed = getTotalTime(); - // Throttle updates to every 1s const shouldUpdate = now - lastUpdateTime >= 1000; if (!shouldUpdate) { - return null; + // keep totalTime fresh even when throttled + lastStats.totalTime = elapsed; + return lastStats; } lastUpdateTime = now; - // Calculate speed with moving average const currentSpeed = elapsed > 0 ? totalUploaded / elapsed : 0; speedSamples.push(currentSpeed); - if (speedSamples.length > 5) speedSamples.shift(); - const avgSpeed = speedSamples.reduce((a, b) => a + b, 0) / speedSamples.length; + if (speedSamples.length > 5) { + speedSamples.shift(); + } + const avgSpeed = + speedSamples.length > 0 + ? speedSamples.reduce((a, b) => a + b, 0) / speedSamples.length + : 0; - // Calculate smooth ETA const remaining = file.size - totalUploaded; let eta = avgSpeed > 0 ? 
remaining / avgSpeed : 0; - eta = Math.min(eta, 24 * 60 * 60); // cap ETA at 24h, 86400 sec + eta = Math.min(eta, 24 * 60 * 60); // cap ETA at 24h - // Smooth ETA changes (limit to 30% change) if (lastETA > 0 && eta > 0) { const maxChange = lastETA * 0.3; const diff = Math.abs(eta - lastETA); @@ -200,229 +216,226 @@ export class DatasetService { } lastETA = eta; - // Near completion optimization const percentComplete = (totalUploaded / file.size) * 100; if (percentComplete > 95) { eta = Math.min(eta, 10); } - return { - uploadSpeed: avgSpeed, - estimatedTimeRemaining: Math.max(0, Math.round(eta)), - totalTime: elapsed, - }; + lastStats.uploadSpeed = avgSpeed; + lastStats.estimatedTimeRemaining = Math.max(0, Math.round(eta)); + lastStats.totalTime = elapsed; + + return lastStats; }; - const subscription = this.initiateMultipartUpload(ownerEmail, datasetName, filePath, partCount) + // 1. INIT: ask backend to create a LakeFS multipart upload session + const initParams = new HttpParams() + .set("type", "init") + .set("ownerEmail", ownerEmail) + .set("datasetName", datasetName) + .set("filePath", encodeURIComponent(filePath)) + .set("numParts", partCount.toString()); + + const init$ = this.http.post<{}>( + `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, + {}, + { params: initParams } + ); + + const subscription = init$ .pipe( - switchMap(initiateResponse => { - const { uploadId, presignedUrls, physicalAddress } = initiateResponse; - if (!uploadId) { - observer.error(new Error("Failed to initiate multipart upload")); - return EMPTY; - } + switchMap(initResp => { + // Notify UI that upload is starting observer.next({ - filePath: filePath, + filePath, percentage: 0, status: "initializing", - uploadId: uploadId, - physicalAddress: physicalAddress, uploadSpeed: 0, estimatedTimeRemaining: 0, totalTime: 0, }); - // Keep track of all uploaded parts - const uploadedParts: { PartNumber: number; ETag: string }[] = []; - - // 1) Convert presignedUrls into a stream of URLs - return from(presignedUrls).pipe( - // 2) Use mergeMap with concurrency limit to upload chunk by chunk - mergeMap((url, index) => { - const partNumber = index + 1; - const start = index * partSize; - const end = Math.min(start + partSize, file.size); - const chunk = file.slice(start, end); - - // Upload the chunk - return new Observable(partObserver => { - const xhr = new XMLHttpRequest(); - - xhr.upload.addEventListener("progress", event => { - if (event.lengthComputable) { - // Update this specific part's progress - partProgress.set(partNumber, event.loaded); - - // Calculate total progress across all parts - let totalUploaded = 0; - partProgress.forEach(bytes => (totalUploaded += bytes)); - const percentage = Math.round((totalUploaded / file.size) * 100); - const stats = calculateStats(totalUploaded); - - observer.next({ - filePath, - percentage: Math.min(percentage, 99), // Cap at 99% until finalized - status: "uploading", - uploadId, - physicalAddress, - ...stats, - }); - } - }); - - xhr.addEventListener("load", () => { - if (xhr.status === 200 || xhr.status === 201) { - const etag = xhr.getResponseHeader("ETag")?.replace(/"/g, ""); - if (!etag) { - partObserver.error(new Error(`Missing ETag for part ${partNumber}`)); - return; + // 2. 
Upload each part to /multipart-upload/part using XMLHttpRequest + return from(Array.from({ length: partCount }, (_, i) => i)).pipe( + mergeMap( + index => { + const partNumber = index + 1; + const start = index * partSize; + const end = Math.min(start + partSize, file.size); + const chunk = file.slice(start, end); + + return new Observable(partObserver => { + const xhr = new XMLHttpRequest(); + + xhr.upload.addEventListener("progress", event => { + if (event.lengthComputable) { + partProgress.set(partNumber, event.loaded); + + let totalUploaded = 0; + partProgress.forEach(bytes => { + totalUploaded += bytes; + }); + + const percentage = Math.round((totalUploaded / file.size) * 100); + const stats = calculateStats(totalUploaded); + + observer.next({ + filePath, + percentage: Math.min(percentage, 99), + status: "uploading", + ...stats, + }); } + }); + + xhr.addEventListener("load", () => { + if (xhr.status === 200 || xhr.status === 204) { + // Mark part as fully uploaded + partProgress.set(partNumber, chunk.size); + + let totalUploaded = 0; + partProgress.forEach(bytes => { + totalUploaded += bytes; + }); + + // Force stats recompute on completion + lastUpdateTime = 0; + const percentage = Math.round((totalUploaded / file.size) * 100); + const stats = calculateStats(totalUploaded); + + observer.next({ + filePath, + percentage: Math.min(percentage, 99), + status: "uploading", + ...stats, + }); + + partObserver.complete(); + } else { + partObserver.error( + new Error(`Failed to upload part ${partNumber} (HTTP ${xhr.status})`) + ); + } + }); - // Mark this part as fully uploaded - partProgress.set(partNumber, chunk.size); - uploadedParts.push({ PartNumber: partNumber, ETag: etag }); - - // Recalculate progress - let totalUploaded = 0; - partProgress.forEach(bytes => (totalUploaded += bytes)); - const percentage = Math.round((totalUploaded / file.size) * 100); - lastUpdateTime = 0; - const stats = calculateStats(totalUploaded); - - observer.next({ - filePath, - percentage: Math.min(percentage, 99), - status: "uploading", - uploadId, - physicalAddress, - ...stats, - }); - partObserver.complete(); - } else { + xhr.addEventListener("error", () => { + // Remove failed part from progress + partProgress.delete(partNumber); partObserver.error(new Error(`Failed to upload part ${partNumber}`)); + }); + + const partUrl = + `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload/part` + + `?ownerEmail=${encodeURIComponent(ownerEmail)}` + + `&datasetName=${encodeURIComponent(datasetName)}` + + `&filePath=${encodeURIComponent(filePath)}` + + `&partNumber=${partNumber}`; + + xhr.open("POST", partUrl); + xhr.setRequestHeader("Content-Type", "application/octet-stream"); + const token = AuthService.getAccessToken(); + if (token) { + xhr.setRequestHeader("Authorization", `Bearer ${token}`); } + xhr.send(chunk); + return () => { + try { + xhr.abort(); + } catch {} + }; }); - - xhr.addEventListener("error", () => { - // Remove failed part from progress - partProgress.delete(partNumber); - partObserver.error(new Error(`Failed to upload part ${partNumber}`)); - }); - - xhr.open("PUT", url); - xhr.send(chunk); - }); - }, concurrencyLimit), - - // 3) Collect results from all uploads (like forkJoin, but respects concurrency) - toArray(), - // 4) Finalize if all parts succeeded - switchMap(() => - this.finalizeMultipartUpload( - ownerEmail, - datasetName, - filePath, - uploadId, - uploadedParts, - physicalAddress, - false - ) + }, + concurrencyLimit ), + toArray(), // wait for all parts + // 3. 
FINISH: notify backend that all parts are done + switchMap(() => { + const finishParams = new HttpParams() + .set("type", "finish") + .set("ownerEmail", ownerEmail) + .set("datasetName", datasetName) + .set("filePath", encodeURIComponent(filePath)); + + return this.http.post( + `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, + {}, + { params: finishParams } + ); + }), tap(() => { + const totalTime = getTotalTime(); observer.next({ filePath, percentage: 100, status: "finished", - uploadId: uploadId, - physicalAddress: physicalAddress, uploadSpeed: 0, estimatedTimeRemaining: 0, - totalTime: getTotalTime(), + totalTime, }); observer.complete(); }), - catchError((error: unknown) => { - // If an error occurred, abort the upload + catchError(error => { + // On error, compute best-effort percentage from bytes we've seen + let totalUploaded = 0; + partProgress.forEach(bytes => { + totalUploaded += bytes; + }); + const percentage = + file.size > 0 ? Math.round((totalUploaded / file.size) * 100) : 0; + observer.next({ filePath, - percentage: Math.round((uploadedParts.length / partCount) * 100), + percentage, status: "aborted", - uploadId: uploadId, - physicalAddress: physicalAddress, uploadSpeed: 0, estimatedTimeRemaining: 0, totalTime: getTotalTime(), }); - return this.finalizeMultipartUpload( - ownerEmail, - datasetName, - filePath, - uploadId, - uploadedParts, - physicalAddress, - true - ).pipe(switchMap(() => throwError(() => error))); + // Abort on backend + const abortParams = new HttpParams() + .set("type", "abort") + .set("ownerEmail", ownerEmail) + .set("datasetName", datasetName) + .set("filePath", encodeURIComponent(filePath)); + + return this.http + .post( + `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, + {}, + { params: abortParams } + ) + .pipe( + switchMap(() => throwError(() => error)), + catchError(() => throwError(() => error)) + ); }) ); }) ) .subscribe({ - error: (err: unknown) => observer.error(err), + error: err => observer.error(err), }); + return () => subscription.unsubscribe(); }); } - /** - * Initiates a multipart upload and retrieves presigned URLs for each part. - * @param ownerEmail Owner's email - * @param datasetName Dataset Name - * @param filePath File path within the dataset - * @param numParts Number of parts for the multipart upload - */ - private initiateMultipartUpload( - ownerEmail: string, - datasetName: string, - filePath: string, - numParts: number - ): Observable<{ uploadId: string; presignedUrls: string[]; physicalAddress: string }> { - const params = new HttpParams() - .set("type", "init") - .set("ownerEmail", ownerEmail) - .set("datasetName", datasetName) - .set("filePath", encodeURIComponent(filePath)) - .set("numParts", numParts.toString()); - - return this.http.post<{ uploadId: string; presignedUrls: string[]; physicalAddress: string }>( - `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, - {}, - { params } - ); - } - - /** - * Completes or aborts a multipart upload, sending part numbers and ETags to the backend. - */ public finalizeMultipartUpload( ownerEmail: string, datasetName: string, filePath: string, - uploadId: string, - parts: { PartNumber: number; ETag: string }[], - physicalAddress: string, isAbort: boolean ): Observable { const params = new HttpParams() .set("type", isAbort ? 
"abort" : "finish") .set("ownerEmail", ownerEmail) .set("datasetName", datasetName) - .set("filePath", encodeURIComponent(filePath)) - .set("uploadId", uploadId); + .set("filePath", encodeURIComponent(filePath)); return this.http.post( `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, - { parts, physicalAddress }, + {}, { params } ); } diff --git a/sql/texera_ddl.sql b/sql/texera_ddl.sql index 7b0f9b9063d..8a9f55e1b28 100644 --- a/sql/texera_ddl.sql +++ b/sql/texera_ddl.sql @@ -58,6 +58,9 @@ DROP TABLE IF EXISTS workflow_version CASCADE; DROP TABLE IF EXISTS project CASCADE; DROP TABLE IF EXISTS workflow_of_project CASCADE; DROP TABLE IF EXISTS workflow_executions CASCADE; +DROP TABLE IF EXISTS dataset_upload_session CASCADE; +DROP TABLE IF EXISTS dataset_upload_session_part CASCADE; + DROP TABLE IF EXISTS dataset CASCADE; DROP TABLE IF EXISTS dataset_user_access CASCADE; DROP TABLE IF EXISTS dataset_version CASCADE; @@ -274,6 +277,36 @@ CREATE TABLE IF NOT EXISTS dataset_version FOREIGN KEY (did) REFERENCES dataset(did) ON DELETE CASCADE ); +CREATE TABLE IF NOT EXISTS dataset_upload_session +( + did INT NOT NULL, + uid INT NOT NULL, + file_path TEXT NOT NULL, + upload_id VARCHAR(256) NOT NULL UNIQUE, + physical_address TEXT, + num_parts_requested INT NOT NULL, + + PRIMARY KEY (uid, did, file_path), + + FOREIGN KEY (did) REFERENCES dataset(did) ON DELETE CASCADE, + FOREIGN KEY (uid) REFERENCES "user"(uid) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS dataset_upload_session_part +( + upload_id VARCHAR(256) NOT NULL, + part_number INT NOT NULL, + etag TEXT NOT NULL DEFAULT '', + + PRIMARY KEY (upload_id, part_number), + + CONSTRAINT chk_part_number_positive CHECK (part_number > 0), + + FOREIGN KEY (upload_id) + REFERENCES dataset_upload_session(upload_id) + ON DELETE CASCADE +); + -- operator_executions (modified to match MySQL: no separate primary key; added console_messages_uri) CREATE TABLE IF NOT EXISTS operator_executions ( diff --git a/sql/updates/16.sql b/sql/updates/16.sql new file mode 100644 index 00000000000..9436c405286 --- /dev/null +++ b/sql/updates/16.sql @@ -0,0 +1,66 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ============================================ +-- 1. Connect to the texera_db database +-- ============================================ +\c texera_db + +SET search_path TO texera_db; + +-- ============================================ +-- 2. Update the table schema +-- ============================================ +BEGIN; + +-- 1. Drop old tables (if exist) +DROP TABLE IF EXISTS dataset_upload_session CASCADE; +DROP TABLE IF EXISTS dataset_upload_session_part CASCADE; + +-- 2. 
Create dataset upload session table +CREATE TABLE IF NOT EXISTS dataset_upload_session +( + did INT NOT NULL, + uid INT NOT NULL, + file_path TEXT NOT NULL, + upload_id VARCHAR(256) NOT NULL UNIQUE, + physical_address TEXT, + num_parts_requested INT NOT NULL, + + PRIMARY KEY (uid, did, file_path), + + FOREIGN KEY (did) REFERENCES dataset(did) ON DELETE CASCADE, + FOREIGN KEY (uid) REFERENCES "user"(uid) ON DELETE CASCADE + ); + +-- 3. Create dataset upload session parts table +CREATE TABLE IF NOT EXISTS dataset_upload_session_part +( + upload_id VARCHAR(256) NOT NULL, + part_number INT NOT NULL, + etag TEXT NOT NULL DEFAULT '', + + PRIMARY KEY (upload_id, part_number), + + CONSTRAINT chk_part_number_positive CHECK (part_number > 0), + + FOREIGN KEY (upload_id) + REFERENCES dataset_upload_session(upload_id) + ON DELETE CASCADE + ); + +COMMIT;
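Reviewer note, not part of the patch: the corruption checks above call sha256OfChunks and sha256OfFile, which are test helpers defined elsewhere in the suite. A minimal sketch of what such helpers could look like, assuming plain java.security.MessageDigest and that chunk order matches upload order:

  import java.io.InputStream
  import java.nio.file.{Files, Path}
  import java.security.MessageDigest

  // Digest in-memory chunks in upload order, mirroring how the backend concatenates parts.
  def sha256OfChunks(chunks: Seq[Array[Byte]]): Array[Byte] = {
    val md = MessageDigest.getInstance("SHA-256")
    chunks.foreach(c => md.update(c))
    md.digest()
  }

  // Digest a downloaded file by streaming it through the same algorithm.
  def sha256OfFile(path: Path): Array[Byte] = {
    val md = MessageDigest.getInstance("SHA-256")
    val in: InputStream = Files.newInputStream(path)
    try {
      val buf = new Array[Byte](8192)
      var n = in.read(buf)
      while (n != -1) {
        md.update(buf, 0, n)
        n = in.read(buf)
      }
    } finally in.close()
    md.digest()
  }

Comparing got.toSeq against expected.toSeq, as the tests do, avoids Array's reference equality.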
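Reviewer note, not part of the patch: the two "409 CONFLICT if the session row is locked" tests assume that finish/abort try to lock the dataset_upload_session row without waiting and translate a lock failure into HTTP 409. A sketch of that translation with jOOQ, assuming forUpdate().noWait() on PostgreSQL; the helper name and signature are illustrative, and the caller is assumed to turn a false result into a 409 response:

  import org.apache.texera.dao.jooq.generated.tables.DatasetUploadSession.DATASET_UPLOAD_SESSION
  import org.jooq.DSLContext
  import org.jooq.exception.DataAccessException

  // Returns true only if this transaction now holds the session row lock; false means another
  // finalizer/aborter holds it and the caller should answer 409 CONFLICT.
  def tryLockSession(ctx: DSLContext, uid: Integer, did: Integer, filePath: String): Boolean =
    try {
      ctx
        .selectFrom(DATASET_UPLOAD_SESSION)
        .where(
          DATASET_UPLOAD_SESSION.UID
            .eq(uid)
            .and(DATASET_UPLOAD_SESSION.DID.eq(did))
            .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath))
        )
        .forUpdate()
        .noWait()
        .fetchOne() != null
    } catch {
      case _: DataAccessException => false // PostgreSQL refused the lock because of NOWAIT
    }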