@@ -18,15 +18,18 @@

package com.indix.utils.spark.parquet

import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
import org.apache.parquet.Log
import org.apache.parquet.hadoop.codec.CodecConfig
import org.apache.parquet.hadoop.util.ContextUtil
import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat}
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat

/**
 * An output committer for writing Parquet files. In stead of writing to the `_temporary` folder
 * An output committer for writing Parquet files. Instead of writing to the `_temporary` folder
 * like what parquet.hadoop.ParquetOutputCommitter does, this output committer writes data directly to the
 * destination folder. This can be useful for data stored in S3, where directory operations are
 * relatively expensive.
@@ -37,9 +40,8 @@ import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetO
 *
 * *NOTE*
 *
 * NEVER use DirectParquetOutputCommitter when appending data, because currently there's
 * no safe way undo a failed appending job (that's why both `abortTask()` and `abortJob()` are
 * left empty).
 * NEVER use DirectParquetOutputCommitter when appending data, because currently there's
 * no safe way to undo a failed appending job.
 */

class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
@@ -48,7 +50,24 @@ class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext

override def getWorkPath: Path = outputPath

override def abortTask(taskContext: TaskAttemptContext): Unit = {}
override def abortTask(taskContext: TaskAttemptContext): Unit = {
  val fs = outputPath.getFileSystem(context.getConfiguration)
  val split = taskContext.getTaskAttemptID.getTaskID.getId

  // Output is written directly to the destination folder, so clean up any files this
  // task attempt has already produced there; their names contain the zero-padded task id.
  val lists = fs.listStatus(outputPath, new PathFilter {
    override def accept(path: Path): Boolean = path.getName.contains(f"-$split%05d-")
  })
  try {
    lists.foreach { l =>
      LOG.error(s"Abort Task - Deleting ${l.getPath.toUri}")
      fs.delete(l.getPath, false)
    }
  } catch {
    case e: Throwable => LOG.warn(s"Cannot clean $outputPath: ${e.getMessage}")
  }
}

override def commitTask(taskContext: TaskAttemptContext): Unit = {}

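For context, a minimal sketch (not part of this PR) of how a job might opt into the committer described above. The `spark.sql.parquet.output.committer.class` key is the same one DirectParquetOutputCommitterSpec below sets; the app name and output path here are placeholders.

import org.apache.spark.sql.{SaveMode, SparkSession}

object DirectCommitterExample {
  def main(args: Array[String]): Unit = {
    // Point Spark SQL's Parquet writer at the direct committer.
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("DirectCommitterExample")
      .config("spark.sql.parquet.output.committer.class",
        "com.indix.utils.spark.parquet.DirectParquetOutputCommitter")
      .getOrCreate()

    // Output files land directly in the destination folder; there is no `_temporary`
    // directory to rename, which is what makes this attractive on S3.
    // NEVER use this committer with SaveMode.Append (see the Scaladoc above).
    spark.range(10).toDF()
      .write
      .mode(SaveMode.Overwrite)
      .parquet("/tmp/direct-committer-example")

    spark.stop()
  }
}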
29 changes: 29 additions & 0 deletions util-spark/src/test/scala/com/indix/utils/spark/SparkJobSpec.scala
@@ -0,0 +1,29 @@
package com.indix.utils.spark

import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FlatSpec}

abstract class SparkJobSpec extends FlatSpec with BeforeAndAfterEach with BeforeAndAfterAll {
val appName: String
val taskRetries: Int = 2
val sparkConf: Map[String, String] = Map()

@transient var spark: SparkSession = _
lazy val sqlContext = spark.sqlContext

override protected def beforeAll(): Unit = {
spark = SparkSession.builder()
.master(s"local[2, $taskRetries]").appName(appName)
.getOrCreate()

sparkConf.foreach {
case (k, v) => spark.conf.set(k, v)
}
}

override protected def afterAll() = {
SparkSession.clearDefaultSession()
SparkSession.clearActiveSession()
}

}
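As a usage note, a hypothetical suite built on this base class only has to supply `appName`; `sparkConf` and `taskRetries` are optional overrides applied when the shared session is created in `beforeAll()`. The suite name and the `spark.sql.shuffle.partitions` setting below are illustrative, not part of this PR; the concrete specs that follow use the same pattern.

package com.indix.utils.spark

class WordCountSpec extends SparkJobSpec {
  override val appName = "WordCountSpec"
  // Runtime SQL settings are applied to the shared session in beforeAll().
  override val sparkConf = Map("spark.sql.shuffle.partitions" -> "2")

  "WordCountSpec" should "reuse the shared SparkSession" in {
    import sqlContext.implicits._
    val counts = spark.sparkContext
      .parallelize(Seq("spark", "parquet", "spark"))
      .toDF("word")
      .groupBy("word")
      .count()
      .collect()
    assert(counts.length == 2)
  }
}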
@@ -5,9 +5,8 @@ import java.util
import com.backtype.hadoop.pail.{PailFormatFactory, PailSpec, PailStructure}
import com.backtype.support.{Utils => PailUtils}
import com.google.common.io.Files
import com.indix.utils.spark.SparkJobSpec
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, FlatSpec}
import org.scalatest.Matchers._

import scala.collection.JavaConverters._
@@ -27,13 +26,8 @@ class UserPailStructure extends PailStructure[User] {
override def deserialize(serialized: Array[Byte]): User = PailUtils.deserialize(serialized).asInstanceOf[User]
}

class PailDataSourceSpec extends FlatSpec with BeforeAndAfterAll with PailDataSource {
private var spark: SparkSession = _

override protected def beforeAll(): Unit = {
super.beforeAll()
spark = SparkSession.builder().master("local[2]").appName("PailDataSource").getOrCreate()
}
class PailDataSourceSpec extends SparkJobSpec with PailDataSource {
override val appName: String = "PailDataSource"

val userPailSpec = new PailSpec(PailFormatFactory.SEQUENCE_FILE, new UserPailStructure)

@@ -0,0 +1,58 @@
package com.indix.utils.spark.parquet

import java.io.File
import java.nio.file.{Files, Paths}

import com.indix.utils.spark.SparkJobSpec
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.spark.SparkException
import org.apache.spark.sql.SaveMode
import org.scalatest.Matchers

class TestDirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
extends DirectParquetOutputCommitter(outputPath, context) {

override def commitTask(taskContext: TaskAttemptContext): Unit = {
if (taskContext.getTaskAttemptID.getId == 0)
throw new SparkException("Failing first attempt of task")
else
super.commitTask(taskContext)
}

}

class DirectParquetOutputCommitterSpec extends SparkJobSpec with Matchers {
override val appName = "DirectParquetOutputCommitterSpec"
override val sparkConf = Map(("spark.sql.parquet.output.committer.class", "com.indix.utils.spark.parquet.TestDirectParquetOutputCommitter"))
var file: File = _

override def beforeAll() = {
super.beforeAll()
file = File.createTempFile("parquet", "")
}

override def afterAll() = {
super.afterAll()
FileUtils.deleteDirectory(file)
}

it should "not fail with file already exists on subsequent retries" in {
try {
sqlContext
.range(10)
.toDF()
.write
.mode(SaveMode.Overwrite)
.parquet(file.toString)
} catch {
case e: Exception => println(e)
} finally {
val successPath = Paths.get(file.toString + "/_SUCCESS")
Files.exists(successPath) should be(true)
}

}

}
@@ -3,30 +3,16 @@ package com.indix.utils.spark.parquet
import java.io.File

import com.google.common.io.Files
import com.indix.utils.spark.SparkJobSpec
import com.indix.utils.spark.parquet.avro.ParquetAvroDataSource
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, FlatSpec}
import org.scalatest.Matchers.{be, convertToAnyShouldWrapper}
import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.scalatest.Matchers.{be, convertToAnyShouldWrapper}

case class SampleAvroRecord(a: Int, b: String, c: Seq[String], d: Boolean, e: Double, f: collection.Map[String,String], g: Seq[Byte])

class ParquetAvroDataSourceSpec extends FlatSpec with BeforeAndAfterAll with ParquetAvroDataSource {
private var spark: SparkSession = _

override protected def beforeAll(): Unit = {
super.beforeAll()
spark = SparkSession.builder().master("local[2]").appName("ParquetAvroDataSource").getOrCreate()
}

override protected def afterAll(): Unit = {
try {
spark.sparkContext.stop()
} finally {
super.afterAll()
}
}
class ParquetAvroDataSourceSpec extends SparkJobSpec with ParquetAvroDataSource {
override val appName = "ParquetAvroDataSource"

"AvroBasedParquetDataSource" should "read/write avro records as ParquetData" in {

@@ -44,9 +30,7 @@

sampleDf.rdd.saveAvroInParquet(outputLocation, sampleDf.schema, CompressionCodecName.GZIP)

val sparkVal = spark

import sparkVal.implicits._
import sqlContext.implicits._

val records: Array[SampleAvroRecord] = spark.read.parquet(outputLocation).as[SampleAvroRecord].collect()
