Context
I have two tables that I join/cogroup as part of my Spark jobs, which incurs a large shuffle each time a job runs. I want to amortise that cost across all jobs by cogrouping and storing the data once, then reading the already-cogrouped data in my regular Spark runs to avoid the shuffle.
To try to achieve this, I store data in HDFS in Parquet format, using Parquet repeated fields to produce the following schema:
(date, [aRecords], [bRecords])
Where [aRecords] indicates an array of aRecord values. I am also partitioning the data by date on HDFS using the usual write.partitionBy($"date").
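For reference, the one-off job that produces this layout would look roughly like the sketch below (input paths, column names, and the use of collect_list are illustrative assumptions, not my exact job); it pays the expensive shuffle once, up front:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.collect_list

val spark = SparkSession.builder.getOrCreate()

// Hypothetical inputs: a (date, a) table and a (date, b) table.
val aTable = spark.read.parquet("path/to/aTable")
val bTable = spark.read.parquet("path/to/bTable")

// Group each side by date, then join the grouped sides; this is the
// large shuffle that gets paid once and persisted.
aTable.groupBy("date").agg(collect_list("a").as("aRecords"))
  .join(bTable.groupBy("date").agg(collect_list("b").as("bRecords")), Seq("date"), "outer")
  .write
  .partitionBy("date")
  .parquet("path/to/data")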
In this situation, aRecords and bRecords appear to be effectively cogrouped by date. I can perform operations like the following:
import java.sql.Date
import spark.implicits._

case class CogroupedData(date: Date, aRecords: Array[Int], bRecords: Array[Int])

val cogroupedData = spark.read.parquet("path/to/data").as[CogroupedData]

// Dataset[(Date, Int)]: each Int is the product of paired a and b values
val results = cogroupedData
  .flatMap(el => el.aRecords.zip(el.bRecords).map(pair => (el.date, pair._1 * pair._2)))
This gives the same results as the equivalent groupByKey/cogroup operations on two separate tables of aRecords and bRecords keyed by date. The difference is that I avoid the shuffle: the data is already cogrouped on disk, so the cogrouping cost is paid once and amortised across jobs by persisting it on HDFS.
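For contrast, the shuffle-incurring equivalent over two separate tables would look something like this (reusing the hypothetical aTable/bTable from above, with assumed column names):

// Two plain tables keyed by date; every run re-shuffles both sides.
val aByDate = aTable.select($"date".as[Date], $"a".as[Int]).groupByKey(_._1)
val bByDate = bTable.select($"date".as[Date], $"b".as[Int]).groupByKey(_._1)

val viaCogroup = aByDate.cogroup(bByDate) { (date, as, bs) =>
  as.map(_._2).zip(bs.map(_._2)).map { case (a, b) => (date, a * b) }
}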
Question
Now for the question. From the cogrouped dataset, I would like to derive the two grouped datasets so I can use standard Spark SQL operators (cogroup, join, etc.) without incurring a shuffle. This seems as though it should be possible, since the first code example works, yet Spark still insists on hashing/shuffling the data whenever I join/groupByKey/cogroup.
Take the code sample below. I expect there is a way to run it without incurring a shuffle when the join is performed.
val cogroupedData = spark.read.parquet("path/to/data").as[CogroupedData]

// Re-derive the two grouped datasets from the cogrouped one.
val aRecords = cogroupedData
  .flatMap(cog => cog.aRecords.map(a => (cog.date, a)))
val bRecords = cogroupedData
  .flatMap(cog => cog.bRecords.map(b => (cog.date, b)))

// The tuple columns are named _1/_2, so rename them before joining on "date".
val joined = aRecords.toDF("date", "aRecord")
  .join(bRecords.toDF("date", "bRecord"), Seq("date"))
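One way to see the problem is to inspect the physical plan; with a plain Parquet read I would expect an Exchange hashpartitioning(date, ...) node on both sides of the join (unless one side is small enough to be broadcast):

// The Exchange nodes in the output are the shuffles I want to eliminate.
joined.explain()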
Looking at the literature, if cogroupedData had a known partitioner, the operations that follow should not incur a shuffle, since they could exploit the fact that the RDD is already partitioned and preserve that partitioner. What I think I need is a cogroupedData Dataset/RDD with a known partitioner, obtained without incurring a shuffle.
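At the RDD level, the behaviour the literature describes looks like the sketch below: after one explicit partitionBy, joins on the same partitioner are narrow. The catch is that the initial partitionBy is itself a shuffle, which is exactly the step I am trying to skip by reading pre-cogrouped data (the partition count is illustrative):

import org.apache.spark.HashPartitioner

val part = new HashPartitioner(200)  // illustrative partition count

// This partitionBy shuffles once; both sides then share a known partitioner.
val aPart = aRecords.rdd.partitionBy(part).cache()
val bPart = bRecords.rdd.partitionBy(part)

// Co-partitioned pair RDDs: this join runs without a further shuffle.
val joinedRdd = aPart.join(bPart)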
Other things I have tried already:
- Hive metadata - works fine for simple joins, but it only optimises the initial join, not subsequent transformations, and Hive does not help with cogroups at all (see the bucketing sketch below)
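For reference, the Hive-metadata attempt was roughly along these lines, using Spark's bucketing API (bucket count and table names are illustrative); bucketing lets the planner skip the exchange for the first join on the bucket key, but nothing downstream benefits:

// Bucket both tables by date so the metastore records the layout.
aTable.write.bucketBy(200, "date").sortBy("date").saveAsTable("a_bucketed")
bTable.write.bucketBy(200, "date").sortBy("date").saveAsTable("b_bucketed")

// This first join on the bucket key can avoid the exchange;
// later transformations still shuffle.
val bucketedJoin = spark.table("a_bucketed")
  .join(spark.table("b_bucketed"), Seq("date"))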
Anyone have any ideas?