I am trying to repartition a DataFrame according to a column. The DataFrame has N (let's say N=3) different values in the partition column x, e.g.:
val myDF = sc.parallelize(Seq(1,1,2,2,3,3)).toDF("x") // create dummy data
What I'd like to achieve is to repartition myDF by x without producing empty partitions. Is there a better way than doing this?
val numParts = myDF.select($"x").distinct().count.toInt
myDF.repartition(numParts,$"x")
(If I don't specify numParts in repartition, most of my partitions are empty, as repartition creates 200 partitions by default ...)
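(For context, the 200 comes from Spark's spark.sql.shuffle.partitions setting; a minimal sketch for inspecting or overriding it, assuming a SparkSession named spark:)
println(spark.conf.get("spark.sql.shuffle.partitions")) // "200" by default
spark.conf.set("spark.sql.shuffle.partitions", numParts) // applies to subsequent shuffles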
One solution I can think of is to iterate over the df partitions and count the records in each one to find the non-empty partitions.
val nonEmptyPart = sparkContext.longAccumulator("nonEmptyPart")
df.foreachPartition(partition =>
  // hasNext checks emptiness without traversing the whole iterator
  if (partition.hasNext) nonEmptyPart.add(1))
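(An alternative sketch, not part of the approach above: the same count can be obtained without an accumulator by emitting one flag per partition and counting the non-empty ones:)
val nonEmptyCount = df.rdd
  .mapPartitions(it => Iterator(it.hasNext)) // one Boolean per partition
  .filter(identity)                          // keep only non-empty partitions
  .count()                                   // Long: number of non-empty partitions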
Once we have the number of non-empty partitions (nonEmptyPart), we can clean up the empty ones by using coalesce() (check coalesce() vs repartition()).
val finalDf = df.coalesce(nonEmptyPart.value.toInt) //coalesce() accepts only Int type
It may or may not be the best approach, but this solution avoids shuffling, as we are not using repartition().
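(A quick sanity check, just a sketch: the RDD lineage should show a CoalescedRDD but no shuffle introduced by coalesce():)
println(finalDf.rdd.toDebugString) // lineage contains CoalescedRDD, no extra ShuffledRDD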
Example to address a comment:
val df1 = sc.parallelize(Seq(1, 1, 2, 2, 3, 3)).toDF("x").repartition($"x")
val nonEmptyPart = sc.longAccumulator("nonEmptyPart")
df1.foreachPartition(partition =>
  if (partition.hasNext) nonEmptyPart.add(1))
val finalDf = df1.coalesce(nonEmptyPart.value.toInt)

println(s"nonEmptyPart => ${nonEmptyPart.value.toInt}")
println(s"df1.rdd.partitions.length => ${df1.rdd.partitions.length}")
println(s"finalDf.rdd.partitions.length => ${finalDf.rdd.partitions.length}")
Output
nonEmptyPart => 3
df1.rdd.partitions.length => 200
finalDf.rdd.partitions.length => 3