Spark 2 iterating over a partition to create a new

Posted 2019-09-02 11:46

Question:

I have been scratching my head trying to come up with a way to reduce a DataFrame in Spark to a frame which records gaps in the DataFrame, preferably without completely killing parallelism. Here is a much-simplified example (it's a bit lengthy because I wanted it to be runnable):

import org.apache.spark.sql.SparkSession

case class Record(typ: String, start: Int, end: Int);

object Sample {
    def main(argv: Array[String]): Unit = {
        val sparkSession = SparkSession.builder()
            .master("local")
            .getOrCreate();

        val df = sparkSession.createDataFrame(
            Seq(
                Record("One", 0, 5),
                Record("One", 10, 15),
                Record("One", 5, 8),
                Record("Two", 10, 25),
                Record("Two", 40, 45),
                Record("Three", 30, 35)
            )
        );

        df.repartition(df("typ")).sortWithinPartitions(df("start")).show();
    }
}

When I get done I would like to be able to output a dataframe like this:

typ   start    end
---   -----    ---
One   0        8
One   10       15
Two   10       25
Two   40       45
Three 30       35

I guessed that partitioning by the 'typ' value would give me one partition per distinct value, 1-1; e.g. in the sample I would end up with three partitions, one each for 'One', 'Two' and 'Three'. Furthermore, the sortWithinPartitions call is intended to give me each partition sorted on 'start', so that I can iterate from beginning to end and record gaps. That last part is where I am stuck. Is this possible? If not, is there another approach that is?

Answer 1:

I propose to skip the repartitioning and the sorting steps, and jump directly to a distributed compressed merge sort (I've just invented the name for the algorithm, just like the algorithm itself).

Here is the part of the algorithm that is meant to be used as the reduce operation:

  type Gap = (Int, Int)

  def mergeIntervals(as: List[Gap], bs: List[Gap]): List[Gap] = {
    require(!as.isEmpty, "as must be non-empty")
    require(!bs.isEmpty, "bs must be non-empty")

    @annotation.tailrec
    def mergeRec(
      gaps: List[Gap],
      gapStart: Int,
      gapEndAccum: Int,
      as: List[Gap],
      bs: List[Gap]
    ): List[Gap] = {
      as match {
        case Nil => {
          bs match {
            case Nil => (gapStart, gapEndAccum) :: gaps
            case notEmpty => mergeRec(gaps, gapStart, gapEndAccum, bs, Nil)
          }
        }
        case (a0, a1) :: at => {
          if (a0 <= gapEndAccum) {
            mergeRec(gaps, gapStart, gapEndAccum max a1, at, bs)
          } else {
            bs match {
              case Nil => mergeRec((gapStart, gapEndAccum) :: gaps, a0, gapEndAccum max a1, at, bs)
              case (b0, b1) :: bt => if (b0 <= gapEndAccum) {
                mergeRec(gaps, gapStart, gapEndAccum max b1, as, bt)
              } else {
                if (a0 < b0) {
                  mergeRec((gapStart, gapEndAccum) :: gaps, a0, a1, at, bs)
                } else {
                  mergeRec((gapStart, gapEndAccum) :: gaps, b0, b1, as, bt)
                }
              }
            }
          }
        }
      }
    }
    val (a0, a1) :: at = as
    val (b0, b1) :: bt = bs

    val reverseRes = 
      if (a0 < b0) 
        mergeRec(Nil, a0, a1, at, bs)
      else
        mergeRec(Nil, b0, b1, as, bt)

    reverseRes.reverse
  }

It works as follows:

  println(mergeIntervals(
    List((0, 3), (4, 7), (9, 11), (15, 16), (18, 22)),
    List((1, 2), (4, 5), (6, 10), (12, 13), (15, 17))
  ))

  // Outputs:
  // List((0,3), (4,11), (12,13), (15,17), (18,22))

Now, if you combine it with the parallel reduce of Spark,

  // Needs `import sparkSession.implicits._` in scope for as[...] and toDF
  val mergedIntervals = df.
    as[(String, Int, Int)].
    rdd.
    map{case (t, s, e) => (t, List((s, e)))}.              // Convert start end to list with one interval
    reduceByKey(mergeIntervals).                           // perform parallel compressed merge-sort
    flatMap{ case (k, vs) => vs.map(v => (k, v._1, v._2))}.// explode resulting lists of merged intervals
    toDF("typ", "start", "end")                            // convert back to DF

  mergedIntervals.show()

you obtain something like a parallel merge sort that works directly on compressed representations of integer sequences (thus the name).

The result:

+-----+-----+---+
|  typ|start|end|
+-----+-----+---+
|  Two|   10| 25|
|  Two|   40| 45|
|  One|    0|  8|
|  One|   10| 15|
|Three|   30| 35|
+-----+-----+---+

Discussion

The mergeIntervals method implements a commutative, associative operation on lists of non-overlapping intervals that are already sorted in increasing order. It merges all overlapping intervals and returns the result sorted in increasing order again. Because the output has the same form as the inputs, the operation can be repeated in a reduce step until all interval sequences are merged.
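Commutativity and associativity matter here because reduceByKey combines partial results in whatever order and grouping the shuffle produces. The sketch below illustrates this in plain Scala. Note that `referenceMerge` and `reductions` are names made up for this illustration, and `referenceMerge` is a simple sort-and-fold stand-in, not the mergeIntervals code above:

```scala
object MergeLaws {
  type Gap = (Int, Int)

  // Stand-in merge (an assumption, not the answer's mergeIntervals):
  // concatenate, sort by start, then fold overlapping neighbours together.
  def referenceMerge(as: List[Gap], bs: List[Gap]): List[Gap] =
    (as ++ bs).sortBy(_._1).foldLeft(List.empty[Gap]) {
      case ((s, e) :: done, (s1, e1)) if s1 <= e => (s, e max e1) :: done
      case (done, gap)                            => gap :: done
    }.reverse

  // Reduce the same inputs with different groupings and orders.
  // For an associative, commutative operation the set has one element.
  def reductions(parts: List[List[Gap]]): Set[List[Gap]] = Set(
    parts.reduceLeft(referenceMerge),        // ((a + b) + c)
    parts.reduceRight(referenceMerge),       // (a + (b + c))
    parts.reverse.reduceLeft(referenceMerge) // ((c + b) + a)
  )

  def main(args: Array[String]): Unit = {
    // The three "One" records from the question, one interval per list
    val parts = List(List((0, 5)), List((10, 15)), List((5, 8)))
    println(reductions(parts)) // Set(List((0,8), (10,15)))
  }
}
```

A function like this can also serve as an oracle when testing mergeIntervals: for sorted, non-overlapping inputs, both should produce the same list.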

The interesting property of the algorithm is that it maximally compresses every intermediate result of the reduction. Thus, if you have many intervals with a lot of overlap, this algorithm might actually be faster than other algorithms that are based on sorting the input intervals.

However, if you have a very large number of intervals that rarely overlap, the per-key lists barely shrink, and this method might run out of memory and not work at all. In that case you need a different algorithm: one that first sorts the intervals and then scans them, merging adjacent intervals locally. So whether this approach works depends on the use case.
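For the sort-then-scan alternative, here is a minimal sketch of the local scan in plain Scala. The names `SortedScan` and `mergeSorted` are made up for this illustration. It assumes the rows arrive sorted by (typ, start). It compares `typ` as well as `start`, because repartitioning on 'typ' hash-partitions the rows, so one partition may hold several `typ` values:

```scala
object SortedScan {
  type Row = (String, Int, Int) // (typ, start, end)

  /** Merges overlapping intervals in a stream that is already
    * sorted by (typ, start). Only one open interval is kept in memory.
    */
  def mergeSorted(rows: Iterator[Row]): Iterator[Row] = new Iterator[Row] {
    private val in = rows.buffered // lets us peek at the next row via `head`

    def hasNext: Boolean = in.hasNext

    def next(): Row = {
      val (typ, start, firstEnd) = in.next()
      var end = firstEnd
      // Absorb every following row of the same typ that starts
      // inside (or exactly at the end of) the current interval.
      while (in.hasNext && in.head._1 == typ && in.head._2 <= end) {
        end = math.max(end, in.next()._3)
      }
      (typ, start, end)
    }
  }

  def main(args: Array[String]): Unit = {
    val rows = List(
      ("One", 0, 5), ("One", 10, 15), ("One", 5, 8),
      ("Two", 10, 25), ("Two", 40, 45), ("Three", 30, 35)
    )
    // Local stand-in for sorting within a partition
    val sorted = rows.sortBy { case (t, s, _) => (t, s) }
    mergeSorted(sorted.iterator).foreach(println)
  }
}
```

In Spark, a function of this shape could be passed to mapPartitions on the typed Dataset after repartitioning on 'typ' and sorting within partitions on both 'typ' and 'start'. I have not benchmarked it against the reduceByKey version.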


Full code

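  // Written for spark-shell; elsewhere, import the session's implicits
  // (e.g. `import sparkSession.implicits._`) so that toDF and as[...] resolve.
  val df = Seq(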
    ("One", 0, 5),
    ("One", 10, 15),
    ("One", 5, 8),
    ("Two", 10, 25),
    ("Two", 40, 45),
    ("Three", 30, 35)
  ).toDF("typ", "start", "end")

  type Gap = (Int, Int)
  /** The `merge`-step of a variant of merge-sort
    * that works directly on compressed sequences of integers,
    * where instead of individual integers, the sequence is 
    * represented by sorted, non-overlapping ranges of integers.
    */
  def mergeIntervals(as: List[Gap], bs: List[Gap]): List[Gap] = {
    require(!as.isEmpty, "as must be non-empty")
    require(!bs.isEmpty, "bs must be non-empty")
    // assuming that `as` and `bs` both are either lists with a single
    // interval, or sorted lists that arise as output of
    // this method, recursively merges them into a single list of
    // gaps, merging all overlapping gaps.
    @annotation.tailrec
    def mergeRec(
      gaps: List[Gap],
      gapStart: Int,
      gapEndAccum: Int,
      as: List[Gap],
      bs: List[Gap]
    ): List[Gap] = {
      as match {
        case Nil => {
          bs match {
            case Nil => (gapStart, gapEndAccum) :: gaps
            case notEmpty => mergeRec(gaps, gapStart, gapEndAccum, bs, Nil)
          }
        }
        case (a0, a1) :: at => {
          if (a0 <= gapEndAccum) {
            mergeRec(gaps, gapStart, gapEndAccum max a1, at, bs)
          } else {
            bs match {
              case Nil => mergeRec((gapStart, gapEndAccum) :: gaps, a0, gapEndAccum max a1, at, bs)
              case (b0, b1) :: bt => if (b0 <= gapEndAccum) {
                mergeRec(gaps, gapStart, gapEndAccum max b1, as, bt)
              } else {
                if (a0 < b0) {
                  mergeRec((gapStart, gapEndAccum) :: gaps, a0, a1, at, bs)
                } else {
                  mergeRec((gapStart, gapEndAccum) :: gaps, b0, b1, as, bt)
                }
              }
            }
          }
        }
      }
    }
    val (a0, a1) :: at = as
    val (b0, b1) :: bt = bs

    val reverseRes = 
      if (a0 < b0) 
        mergeRec(Nil, a0, a1, at, bs)
      else
        mergeRec(Nil, b0, b1, as, bt)

    reverseRes.reverse
  }


  val mergedIntervals = df.
    as[(String, Int, Int)].
    rdd.
    map{case (t, s, e) => (t, List((s, e)))}.              // Convert start end to list with one interval
    reduceByKey(mergeIntervals).                           // perform parallel compressed merge-sort
    flatMap{ case (k, vs) => vs.map(v => (k, v._1, v._2))}.// explode resulting lists of merged intervals
    toDF("typ", "start", "end")                            // convert back to DF

  mergedIntervals.show()

Testing

The implementation of mergeIntervals has only been tested lightly. If you want to actually incorporate it into your codebase, here is at least a sketch of a repeated randomized test for it:

  def randomIntervalSequence(): List[Gap] = {
    def recHelper(acc: List[Gap], open: Option[Int], currIdx: Int): List[Gap] = {
      if (math.random > 0.999) acc.reverse
      else {
        if (math.random > 0.90) {
          if (open.isEmpty) {
            recHelper(acc, Some(currIdx), currIdx + 1)
          } else {
            recHelper((open.get, currIdx) :: acc, None, currIdx + 1)
          }
        } else {
          recHelper(acc, open, currIdx + 1)
        }
      }
    }
    recHelper(Nil, None, 0)
  }

  def intervalsToInts(is: List[Gap]): List[Int] = is.flatMap{ case (a, b) => a to b }

  var numNonTrivialTests = 0
  while(numNonTrivialTests < 1000) {
    val as = randomIntervalSequence()
    val bs = randomIntervalSequence()
    if (!as.isEmpty && !bs.isEmpty) {
      numNonTrivialTests += 1
      val merged = mergeIntervals(as, bs)
      assert((intervalsToInts(as).toSet ++ intervalsToInts(bs)) == intervalsToInts(merged).toSet)
    }
  }

You would obviously have to replace the raw assert with something more civilized, depending on your test framework.