diff --git a/core/file-service/src/main/scala/edu/uci/ics/texera/service/resource/DatasetResource.scala b/core/file-service/src/main/scala/edu/uci/ics/texera/service/resource/DatasetResource.scala
index 914e71ce5f2..b9dd00dbfd6 100644
--- a/core/file-service/src/main/scala/edu/uci/ics/texera/service/resource/DatasetResource.scala
+++ b/core/file-service/src/main/scala/edu/uci/ics/texera/service/resource/DatasetResource.scala
@@ -466,6 +466,7 @@ class DatasetResource {
       @PathParam("did") did: Integer,
       @QueryParam("filePath") encodedFilePath: String,
       @QueryParam("message") message: String,
+      @DefaultValue("true") @QueryParam("useMultipartUpload") useMultipartUpload: Boolean,
       fileStream: InputStream,
       @Context headers: HttpHeaders,
       @Auth user: SessionUser
@@ -486,74 +487,82 @@ class DatasetResource {
         repoName = dataset.getRepositoryName
         filePath = URLDecoder.decode(encodedFilePath, StandardCharsets.UTF_8.name)

-        // ---------- decide part-size & number-of-parts ----------
-        val declaredLen = Option(headers.getHeaderString(HttpHeaders.CONTENT_LENGTH)).map(_.toLong)
-        var partSize = StorageConfig.s3MultipartUploadPartSize
+        if (useMultipartUpload) {
+          // ---------- decide part-size & number-of-parts ----------
+          val declaredLen = Option(headers.getHeaderString(HttpHeaders.CONTENT_LENGTH)).map(_.toLong)
+          var partSize = StorageConfig.s3MultipartUploadPartSize
+
+          declaredLen.foreach { ln =>
+            val needed = ((ln + partSize - 1) / partSize).toInt
+            if (needed > MAXIMUM_NUM_OF_MULTIPART_S3_PARTS)
+              partSize = math.max(
+                MINIMUM_NUM_OF_MULTIPART_S3_PART,
+                ln / (MAXIMUM_NUM_OF_MULTIPART_S3_PARTS - 1)
+              )
+          }

-        declaredLen.foreach { ln =>
-          val needed = ((ln + partSize - 1) / partSize).toInt
-          if (needed > MAXIMUM_NUM_OF_MULTIPART_S3_PARTS)
-            partSize = math.max(
-              MINIMUM_NUM_OF_MULTIPART_S3_PART,
-              ln / (MAXIMUM_NUM_OF_MULTIPART_S3_PARTS - 1)
-            )
-        }
+          val expectedParts = declaredLen
+            .map(ln =>
+              ((ln + partSize - 1) / partSize).toInt + 1
+            ) // "+1" for the last (possibly small) part
+            .getOrElse(MAXIMUM_NUM_OF_MULTIPART_S3_PARTS)
+
+          // ---------- ask LakeFS for presigned URLs ----------
+          val presign = LakeFSStorageClient
+            .initiatePresignedMultipartUploads(repoName, filePath, expectedParts)
+          uploadId = presign.getUploadId
+          val presignedUrls = presign.getPresignedUrls.asScala.iterator
+          physicalAddress = presign.getPhysicalAddress
+
+          // ---------- stream & upload parts ----------
+          /*
+            1. Reads the input stream in chunks of 'partSize' bytes, accumulating them in a buffer
+            2. Uploads each chunk (part) using a presigned URL
+            3. Tracks each part number and ETag returned from S3
+            4. After all parts are uploaded, completes the multipart upload
+           */
+          val buf = new Array[Byte](partSize.toInt)
+          var buffered = 0
+          var partNumber = 1
+          val completedParts = ListBuffer[(Int, String)]()
+
+          @inline def flush(): Unit = {
+            if (buffered == 0) return
+            if (!presignedUrls.hasNext)
+              throw new WebApplicationException("Ran out of presigned part URLs – ask for more parts")
+
+            val etag = put(buf, buffered, presignedUrls.next(), partNumber)
+            completedParts += ((partNumber, etag))
+            partNumber += 1
+            buffered = 0
+          }

-        val expectedParts = declaredLen
-          .map(ln =>
-            ((ln + partSize - 1) / partSize).toInt + 1
-          ) // "+1" for the last (possibly small) part
-          .getOrElse(MAXIMUM_NUM_OF_MULTIPART_S3_PARTS)
-
-        // ---------- ask LakeFS for presigned URLs ----------
-        val presign = LakeFSStorageClient
-          .initiatePresignedMultipartUploads(repoName, filePath, expectedParts)
-        uploadId = presign.getUploadId
-        val presignedUrls = presign.getPresignedUrls.asScala.iterator
-        physicalAddress = presign.getPhysicalAddress
-
-        // ---------- stream & upload parts ----------
-        /*
-          1. Reads the input stream in chunks of 'partSize' bytes by stacking them in a buffer
-          2. Uploads each chunk (part) using a presigned URL
-          3. Tracks each part number and ETag returned from S3
-          4. After all parts are uploaded, completes the multipart upload
-         */
-        val buf = new Array[Byte](partSize.toInt)
-        var buffered = 0
-        var partNumber = 1
-        val completedParts = ListBuffer[(Int, String)]()
-
-        @inline def flush(): Unit = {
-          if (buffered == 0) return
-          if (!presignedUrls.hasNext)
-            throw new WebApplicationException("Ran out of presigned part URLs – ask for more parts")
-
-          val etag = put(buf, buffered, presignedUrls.next(), partNumber)
-          completedParts += ((partNumber, etag))
-          partNumber += 1
-          buffered = 0
-        }
+          var read = fileStream.read(buf, buffered, buf.length - buffered)
+          while (read != -1) {
+            buffered += read
+            if (buffered == buf.length) flush() // buffer full
+            read = fileStream.read(buf, buffered, buf.length - buffered)
+          }
+          fileStream.close()
+          flush()

-        var read = fileStream.read(buf, buffered, buf.length - buffered)
-        while (read != -1) {
-          buffered += read
-          if (buffered == buf.length) flush() // buffer full
-          read = fileStream.read(buf, buffered, buf.length - buffered)
-        }
-        fileStream.close()
-        flush()
-
-        // ---------- complete upload ----------
-        LakeFSStorageClient.completePresignedMultipartUploads(
-          repoName,
-          filePath,
-          uploadId,
-          completedParts.toList,
-          physicalAddress
-        )
+          // ---------- complete upload ----------
+          LakeFSStorageClient.completePresignedMultipartUploads(
+            repoName,
+            filePath,
+            uploadId,
+            completedParts.toList,
+            physicalAddress
+          )

-        Response.ok(Map("message" -> s"Uploaded $filePath in ${completedParts.size} parts")).build()
+          Response.ok(Map("message" -> s"Uploaded $filePath in ${completedParts.size} parts")).build()
+        } else {
+          // Use single file upload method
+          LakeFSStorageClient.writeFileToRepo(repoName, filePath, fileStream)
+          fileStream.close()
+
+          Response.ok(Map("message" -> s"Uploaded $filePath using single upload")).build()
+        }
       }
     } catch {
      case e: Exception =>
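
A note on the part-sizing logic: the adjustment in declaredLen.foreach is easiest to check with concrete numbers. The sketch below is a self-contained illustration, not project code; the 10,000-part ceiling (S3's documented limit), the 5 MiB minimum, and the 16 MiB default are assumed stand-ins for MAXIMUM_NUM_OF_MULTIPART_S3_PARTS, MINIMUM_NUM_OF_MULTIPART_S3_PART, and StorageConfig.s3MultipartUploadPartSize.

// Worked example of the sizing rule in the diff; all constants here are assumptions.
object PartSizingExample extends App {
  val maxParts    = 10000L                    // S3's documented part-count ceiling
  val minPartSize = 5L * 1024 * 1024          // S3's 5 MiB minimum for non-final parts
  val defaultSize = 16L * 1024 * 1024         // assumed default part size
  val declaredLen = 200L * 1024 * 1024 * 1024 // a 200 GiB declared Content-Length

  // ceil(declaredLen / defaultSize) = 12,800 parts, which exceeds the ceiling...
  val needed = (declaredLen + defaultSize - 1) / defaultSize
  // ...so the part size is grown until the count fits, exactly as in the diff.
  val partSize =
    if (needed > maxParts) math.max(minPartSize, declaredLen / (maxParts - 1))
    else defaultSize

  println(s"parts needed at default size: $needed")   // 12800
  println(s"adjusted partSize: $partSize bytes")      // 21476984 (about 20.5 MiB)
  println(s"parts at adjusted size: ${(declaredLen + partSize - 1) / partSize}") // 10000
}

In this example a 200 GiB stream would need 12,800 default-size parts, so the part size grows to declaredLen / 9,999, which lands the count exactly at the 10,000-part ceiling.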
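The streaming loop also relies on a put(buf, len, url, partNumber) helper that sits outside this hunk. For reviewers, here is a minimal sketch of what such a helper could look like, assuming a plain java.net.HttpURLConnection PUT against the presigned URL; the name, signature, and error handling are taken from the call site above, not from the actual implementation, which would more likely throw WebApplicationException as the rest of the diff does.

import java.net.{HttpURLConnection, URL}

object PresignedPartUploader {
  // Hypothetical stand-in for the `put` helper referenced in the diff: PUTs
  // exactly `len` bytes of `buf` to one presigned part URL and returns the
  // ETag header value, which S3 requires to complete the multipart upload.
  def put(buf: Array[Byte], len: Int, presignedUrl: String, partNumber: Int): String = {
    val conn = new URL(presignedUrl).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("PUT")
    conn.setDoOutput(true)
    conn.setFixedLengthStreamingMode(len) // stream the part without re-buffering it
    val out = conn.getOutputStream
    try out.write(buf, 0, len)
    finally out.close()
    val code = conn.getResponseCode
    if (code / 100 != 2)
      throw new RuntimeException(s"Uploading part $partNumber failed with HTTP $code")
    Option(conn.getHeaderField("ETag"))
      .getOrElse(throw new RuntimeException(s"No ETag header returned for part $partNumber"))
  }
}

S3 reports each part's ETag as a response header rather than in the body, and completePresignedMultipartUploads needs the full (partNumber, ETag) list, which is why the helper must surface that header to the caller.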