I have been scratching my head trying to come up with a way to reduce a DataFrame in Spark to a frame that records gaps in the data, preferably without completely killing parallelism. Here is a much-simplified example (it's a bit lengthy because I wanted it to be runnable):
import org.apache.spark.sql.SparkSession

case class Record(typ: String, start: Int, end: Int);

object Sample {
  def main(argv: Array[String]): Unit = {
    val sparkSession = SparkSession.builder()
      .master("local")
      .getOrCreate();

    val df = sparkSession.createDataFrame(
      Seq(
        Record("One", 0, 5),
        Record("One", 10, 15),
        Record("One", 5, 8),
        Record("Two", 10, 25),
        Record("Two", 40, 45),
        Record("Three", 30, 35)
      )
    );

    df.repartition(df("typ")).sortWithinPartitions(df("start")).show();
  }
}
When I'm done, I would like to be able to output a DataFrame like this:
typ    start  end
---    -----  ---
One        0    8
One       10   15
Two       10   25
Two       40   45
Three     30   35
I guessed that partitioning by the 'typ' value would give me one partition per distinct value, e.g. in the sample I would end up with three partitions, one each for 'One', 'Two' and 'Three'. Furthermore, the sortWithinPartitions call is intended to give me each partition sorted on 'start' so that I can iterate from the beginning to the end and record gaps. That last part is where I am stuck. Is this possible? If not, is there another approach that is?
I propose to skip the repartitioning and the sorting steps, and jump directly to a distributed compressed merge sort (I've just invented the name for the algorithm, just like the algorithm itself).
Here is the part of the algorithm that is supposed to be used as the reduce operation:
type Gap = (Int, Int)

def mergeIntervals(as: List[Gap], bs: List[Gap]): List[Gap] = {
  require(!as.isEmpty, "as must be non-empty")
  require(!bs.isEmpty, "bs must be non-empty")

  @annotation.tailrec
  def mergeRec(
    gaps: List[Gap],
    gapStart: Int,
    gapEndAccum: Int,
    as: List[Gap],
    bs: List[Gap]
  ): List[Gap] = {
    as match {
      case Nil => {
        bs match {
          case Nil => (gapStart, gapEndAccum) :: gaps
          case notEmpty => mergeRec(gaps, gapStart, gapEndAccum, bs, Nil)
        }
      }
      case (a0, a1) :: at => {
        if (a0 <= gapEndAccum) {
          mergeRec(gaps, gapStart, gapEndAccum max a1, at, bs)
        } else {
          bs match {
            case Nil => mergeRec((gapStart, gapEndAccum) :: gaps, a0, gapEndAccum max a1, at, bs)
            case (b0, b1) :: bt => if (b0 <= gapEndAccum) {
              mergeRec(gaps, gapStart, gapEndAccum max b1, as, bt)
            } else {
              if (a0 < b0) {
                mergeRec((gapStart, gapEndAccum) :: gaps, a0, a1, at, bs)
              } else {
                mergeRec((gapStart, gapEndAccum) :: gaps, b0, b1, as, bt)
              }
            }
          }
        }
      }
    }
  }

  val (a0, a1) :: at = as
  val (b0, b1) :: bt = bs

  val reverseRes =
    if (a0 < b0)
      mergeRec(Nil, a0, a1, at, bs)
    else
      mergeRec(Nil, b0, b1, as, bt)

  reverseRes.reverse
}
It works as follows:
println(mergeIntervals(
  List((0, 3), (4, 7), (9, 11), (15, 16), (18, 22)),
  List((1, 2), (4, 5), (6, 10), (12, 13), (15, 17))
))
// Outputs:
// List((0,3), (4,11), (12,13), (15,17), (18,22))
Now, if you combine it with the parallel reduce of Spark,
val mergedIntervals = df.
  as[(String, Int, Int)].
  rdd.
  map{ case (t, s, e) => (t, List((s, e))) }.               // convert each (start, end) to a single-interval list
  reduceByKey(mergeIntervals).                              // perform the parallel compressed merge-sort
  flatMap{ case (k, vs) => vs.map(v => (k, v._1, v._2)) }.  // explode the resulting lists of merged intervals
  toDF("typ", "start", "end")                               // convert back to a DataFrame

mergedIntervals.show()
you obtain something like a parallel merge sort that works directly on compressed representations of integer sequences (thus the name).
The result:
+-----+-----+---+
| typ|start|end|
+-----+-----+---+
| Two| 10| 25|
| Two| 40| 45|
| One| 0| 8|
| One| 10| 15|
|Three| 30| 35|
+-----+-----+---+
Discussion
The mergeIntervals method implements a commutative, associative operation for merging lists of non-overlapping intervals that are already sorted in increasing order. All the overlapping intervals are then merged, and again stored in increasing order. This procedure can be repeated in a reduce step until all interval sequences are merged.
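As a quick sanity check of the commutativity and associativity claims, the operation can be spot-checked on a few hand-picked interval lists (the values below are made up for illustration):

val xs = List((0, 3), (9, 11))
val ys = List((2, 5), (15, 17))
val zs = List((4, 10))

// both argument orders collapse to the same merged list
assert(mergeIntervals(xs, ys) == mergeIntervals(ys, xs))
// grouping does not matter either
assert(mergeIntervals(mergeIntervals(xs, ys), zs) == mergeIntervals(xs, mergeIntervals(ys, zs)))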
The interesting property of the algorithm is that it maximally compresses every intermediate result of the reduction. Thus, if you have many intervals with a lot of overlap, this algorithm might actually be faster than other algorithms that are based on sorting the input intervals.
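To illustrate the compression with a toy example: a thousand heavily overlapping intervals collapse into a single interval under repeated merging, so the intermediate lists stay tiny:

val dense = (0 until 1000).map(i => List((i, i + 2))).toList
println(dense.reduce(mergeIntervals))
// List((0,1001))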
However, if you have very many intervals that rarely overlap, the intermediate lists can grow very large, so this method might run out of memory and not work at all. In that case other algorithms must be used that first sort the intervals and then merge adjacent intervals with a local scan (a rough sketch of that alternative follows). So, whether this will work or not depends on the use case.
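For completeness, here is a rough, untested sketch of that sort-then-scan alternative, not part of the solution above. It assumes sparkSession.implicits._ is in scope and reuses the column names from the question; it repartitions by key, sorts within each partition, and merges overlapping intervals with a local fold:

val locallyMerged = df.
  repartition(df("typ")).
  sortWithinPartitions(df("typ"), df("start")).
  as[(String, Int, Int)].
  mapPartitions { rows =>
    // Several `typ` values can share a hash partition, so the fold only
    // extends the current interval when the key matches and the intervals overlap.
    rows.foldLeft(List.empty[(String, Int, Int)]) {
      case ((t0, s0, e0) :: acc, (t, s, e)) if t == t0 && s <= e0 =>
        (t0, s0, e0 max e) :: acc
      case (acc, next) => next :: acc
    }.reverseIterator
  }.
  toDF("typ", "start", "end")

locallyMerged.show()

This sketch still materializes each partition's merged output as a list; a fully streaming variant would emit intervals lazily, but the idea of sorting first and merging locally is the same.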
Full code
val df = Seq(
  ("One", 0, 5),
  ("One", 10, 15),
  ("One", 5, 8),
  ("Two", 10, 25),
  ("Two", 40, 45),
  ("Three", 30, 35)
).toDF("typ", "start", "end")

type Gap = (Int, Int)

/** The `merge` step of a variant of merge sort
  * that works directly on compressed sequences of integers,
  * where, instead of individual integers, the sequence is
  * represented by sorted, non-overlapping ranges of integers.
  */
def mergeIntervals(as: List[Gap], bs: List[Gap]): List[Gap] = {
  require(!as.isEmpty, "as must be non-empty")
  require(!bs.isEmpty, "bs must be non-empty")

  // Assuming that `as` and `bs` are both either lists with a single
  // interval, or sorted lists that arise as output of
  // this method, recursively merges them into a single list of
  // gaps, merging all overlapping gaps.
  @annotation.tailrec
  def mergeRec(
    gaps: List[Gap],
    gapStart: Int,
    gapEndAccum: Int,
    as: List[Gap],
    bs: List[Gap]
  ): List[Gap] = {
    as match {
      case Nil => {
        bs match {
          case Nil => (gapStart, gapEndAccum) :: gaps
          case notEmpty => mergeRec(gaps, gapStart, gapEndAccum, bs, Nil)
        }
      }
      case (a0, a1) :: at => {
        if (a0 <= gapEndAccum) {
          mergeRec(gaps, gapStart, gapEndAccum max a1, at, bs)
        } else {
          bs match {
            case Nil => mergeRec((gapStart, gapEndAccum) :: gaps, a0, gapEndAccum max a1, at, bs)
            case (b0, b1) :: bt => if (b0 <= gapEndAccum) {
              mergeRec(gaps, gapStart, gapEndAccum max b1, as, bt)
            } else {
              if (a0 < b0) {
                mergeRec((gapStart, gapEndAccum) :: gaps, a0, a1, at, bs)
              } else {
                mergeRec((gapStart, gapEndAccum) :: gaps, b0, b1, as, bt)
              }
            }
          }
        }
      }
    }
  }

  val (a0, a1) :: at = as
  val (b0, b1) :: bt = bs

  val reverseRes =
    if (a0 < b0)
      mergeRec(Nil, a0, a1, at, bs)
    else
      mergeRec(Nil, b0, b1, as, bt)

  reverseRes.reverse
}

val mergedIntervals = df.
  as[(String, Int, Int)].
  rdd.
  map{ case (t, s, e) => (t, List((s, e))) }.               // convert each (start, end) to a single-interval list
  reduceByKey(mergeIntervals).                              // perform the parallel compressed merge-sort
  flatMap{ case (k, vs) => vs.map(v => (k, v._1, v._2)) }.  // explode the resulting lists of merged intervals
  toDF("typ", "start", "end")                               // convert back to a DataFrame

mergedIntervals.show()
Testing
The implementation of mergeIntervals is only lightly tested. If you want to actually incorporate it into your codebase, here is at least a sketch of a repeated randomized test for it:
def randomIntervalSequence(): List[Gap] = {
  def recHelper(acc: List[Gap], open: Option[Int], currIdx: Int): List[Gap] = {
    if (math.random > 0.999) acc.reverse
    else {
      if (math.random > 0.90) {
        if (open.isEmpty) {
          recHelper(acc, Some(currIdx), currIdx + 1)
        } else {
          recHelper((open.get, currIdx) :: acc, None, currIdx + 1)
        }
      } else {
        recHelper(acc, open, currIdx + 1)
      }
    }
  }
  recHelper(Nil, None, 0)
}

def intervalsToInts(is: List[Gap]): List[Int] = is.flatMap{ case (a, b) => a to b }

var numNonTrivialTests = 0
while (numNonTrivialTests < 1000) {
  val as = randomIntervalSequence()
  val bs = randomIntervalSequence()
  if (!as.isEmpty && !bs.isEmpty) {
    numNonTrivialTests += 1
    val merged = mergeIntervals(as, bs)
    assert((intervalsToInts(as).toSet ++ intervalsToInts(bs)) == intervalsToInts(merged).toSet)
  }
}
You would obviously have to replace the raw assert with something more civilized, depending on your framework.
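For example, with ScalaCheck (an assumed dependency, not used in the original), the same coverage check could become a property; the generator below is a hypothetical sketch and expects Gap, mergeIntervals and intervalsToInts from above to be in scope:

import org.scalacheck.{Gen, Prop, Properties}

object MergeIntervalsSpec extends Properties("mergeIntervals") {

  // Build a sorted list of disjoint intervals from random positive gaps and lengths.
  val genIntervals: Gen[List[Gap]] =
    Gen.nonEmptyListOf(Gen.zip(Gen.choose(1, 10), Gen.choose(0, 10))).map { steps =>
      steps.foldLeft((List.empty[Gap], 0)) { case ((acc, pos), (gap, len)) =>
        val start = pos + gap
        ((start, start + len) :: acc, start + len + 1)
      }._1.reverse
    }

  property("covers exactly the union of both inputs") =
    Prop.forAll(genIntervals, genIntervals) { (as, bs) =>
      val merged = mergeIntervals(as, bs)
      (intervalsToInts(as).toSet ++ intervalsToInts(bs)) == intervalsToInts(merged).toSet
    }
}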