Spark: migrate SQL window function to RDD for better performance

Posted 2019-02-18 07:01

A function should be executed for multiple columns of a DataFrame:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

def handleBias(df: DataFrame, colName: String, target: String = target) = {
  val w1 = Window.partitionBy(colName)
  val w2 = Window.partitionBy(colName, target)

  df.withColumn("cnt_group", count("*").over(w2))
    .withColumn("pre2_" + colName, mean(target).over(w1))
    .withColumn("pre_" + colName, coalesce(min(col("cnt_group") / col("cnt_foo_eq_1")).over(w1), lit(0D)))
    .drop("cnt_group")
}

This can be written nicely in Spark SQL with a fold over the columns, as shown above. However, it causes a lot of shuffles (see spark apply function to columns in parallel).

A minimal example:

  val df = Seq(
    (0, "A", "B", "C", "D"),
    (1, "A", "B", "C", "D"),
    (0, "d", "a", "jkl", "d"),
    (0, "d", "g", "C", "D"),
    (1, "A", "d", "t", "k"),
    (1, "d", "c", "C", "D"),
    (1, "c", "B", "C", "D")
  ).toDF("TARGET", "col1", "col2", "col3TooMany", "col4")

  val columnsToDrop = Seq("col3TooMany")
  val columnsToCode = Seq("col1", "col2")
  val target = "TARGET"

  val targetCounts = df.filter(df(target) === 1).groupBy(target)
    .agg(count(target).as("cnt_foo_eq_1"))
  val newDF = df.join(broadcast(targetCounts), Seq(target), "left")

  val result = (columnsToDrop ++ columnsToCode).toSet.foldLeft(newDF) {
    (currentDF, colName) => handleBias(currentDF, colName)
  }

  result.drop(columnsToDrop: _*).show

How can I formulate this more efficiently using the RDD API? aggregateByKey should be a good fit, but it is still not clear to me how to apply it here to replace the window functions.

(A bit more context and a bigger example: https://github.com/geoHeil/sparkContrastCoding)

edit

Initially, I started with the approach from Spark dynamic DAG is a lot slower and different from hard coded DAG, which is shown below. The good thing is that each column seems to run independently / in parallel. The downside is that the joins (even for a small dataset of 300 MB) get "too big" and lead to an unresponsive Spark.

handleBiasOriginal("col1", df)
  .join(handleBiasOriginal("col2", df), df.columns)
  .join(handleBiasOriginal("col3TooMany", df), df.columns)
  .drop(columnsToDrop: _*).show

  def handleBiasOriginal(col: String, df: DataFrame, target: String = target): DataFrame = {
    val pre1_1 = df
      .filter(df(target) === 1)
      .groupBy(col, target)
      .agg((count("*") / df.filter(df(target) === 1).count).alias("pre_" + col))
      .drop(target)

    val pre2_1 = df
      .groupBy(col)
      .agg(mean(target).alias("pre2_" + col))

    df
      .join(pre1_1, Seq(col), "left")
      .join(pre2_1, Seq(col), "left")
      .na.fill(0)
  }

This image is with Spark 2.1.0; the images from Spark dynamic DAG is a lot slower and different from hard coded DAG are with 2.0.2. [Image: too-complex DAG]

The DAG becomes a bit simpler when caching is applied: df.cache; handleBiasOriginal("col1", df); ...

What possibilities other than window functions do you see to optimize the SQL? Ideally, the SQL would be generated dynamically.

[Image: DAG with caching applied]

2 Answers

一纸荒年 Trace。 · 2019-02-18 07:25

The main point here is to avoid unnecessary shuffles. Right now your code shuffles twice for each column you want to include, and the resulting data layout cannot be reused between columns.

For simplicity I assume that target is always binary ({0, 1}) and that all the remaining columns you use are of StringType. Furthermore, I assume that the cardinality of the columns is low enough for the results to be grouped and handled locally. You can adjust these methods to handle other cases, but it requires more work.

RDD API

  • Reshape data from wide to long:

    import org.apache.spark.sql.functions._
    import spark.implicits._  // for the $"..." syntax and .as[...] below

    val exploded = explode(array(
      (columnsToDrop ++ columnsToCode).map(c =>
        struct(lit(c).alias("k"), col(c).alias("v"))): _*
    )).alias("level")

    val long = df.select(exploded, $"TARGET")
    
  • aggregateByKey, reshape and collect:

    import org.apache.spark.util.StatCounter
    
    val lookup = long.as[((String, String), Int)].rdd
      // You can use prefix partitioner (one that depends only on _._1)
      // to avoid reshuffling for groupByKey
      .aggregateByKey(StatCounter())(_ merge _, _ merge _)
      .map { case ((c, v), s) => (c, (v, s)) }
      .groupByKey
      .mapValues(_.toMap)
      .collectAsMap
    
  • You can use lookup to get statistics for individual columns and levels. For example:

    lookup("col1")("A")
    
    org.apache.spark.util.StatCounter = 
      (count: 3, mean: 0.666667, stdev: 0.471405, max: 1.000000, min: 0.000000)
    

    Gives you data for col1, level A. Based on the binary TARGET assumption this information is complete (you get count / fractions for both classes).

    You can use lookup like this to generate SQL expressions, or pass it to a udf and apply it to individual columns.
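Because TARGET is binary, a (count, mean) pair per level already determines both class counts, and the collected lookup can back a plain function in place of a udf. A minimal plain-Scala sketch of that idea (LevelStats, classCounts and encode are illustrative names; LevelStats stands in for the StatCounter values held by the real lookup):

```scala
// Per-level statistics as they would come out of the collected lookup:
// count = rows with this (column, level), mean = fraction of TARGET == 1.
case class LevelStats(count: Long, mean: Double)

// With a binary target, count and mean are enough to recover both class counts.
def classCounts(s: LevelStats): (Long, Long) = {
  val positives = math.round(s.count * s.mean)
  (s.count - positives, positives) // (rows with TARGET == 0, rows with TARGET == 1)
}

// A lookup-backed scoring function, a plain-Scala stand-in for a udf;
// unseen levels fall back to 0.0, mirroring coalesce(..., lit(0D)) above.
def encode(lookup: Map[String, Map[String, LevelStats]])
          (colName: String, value: String): Double =
  lookup.get(colName).flatMap(_.get(value)).map(_.mean).getOrElse(0.0)
```

With the example data, ("col1", "A") has count 3 and mean 2/3, so classCounts recovers one 0-row and two 1-rows.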

DataFrame API

  • Convert data to long as for RDD API.
  • Compute aggregates based on levels:

    val stats = long
      .groupBy($"level.k", $"level.v")
      .agg(mean($"TARGET"), sum($"TARGET"))
    
  • Depending on your preferences you can reshape this to enable efficient joins or convert to a local collection and similarly to the RDD solution.
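For the "convert to a local collection" option, the collected stats rows can be turned into a nested map analogous to the RDD lookup. A small plain-Scala sketch, assuming the rows have been collected as (column, level, mean, sum) tuples (collected and localStats are hypothetical names):

```scala
// Rows as they might be collected from the stats DataFrame:
// (column name, level, mean of TARGET, sum of TARGET).
val collected = Seq(
  ("col1", "A", 2.0 / 3.0, 2L),
  ("col1", "d", 1.0 / 3.0, 1L),
  ("col2", "B", 2.0 / 3.0, 2L)
)

// Nested map: column -> level -> (mean, sum), analogous to the RDD lookup.
val localStats: Map[String, Map[String, (Double, Long)]] =
  collected
    .groupBy(_._1)
    .map { case (colName, rows) =>
      (colName, rows.map { case (_, level, mean, sum) => (level, (mean, sum)) }.toMap)
    }
```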

孤傲高冷的网名 · 2019-02-18 07:32

Using aggregateByKey: a simple explanation of aggregateByKey can be found here. Basically, you use two functions: one which works inside a partition and one which works between partitions.

You would need to do something like aggregating by the first column, internally building a data structure with a map for every element of the second column, and aggregating and collecting data there (of course you could do two aggregateByKeys if you want). This will not solve the problem of doing multiple passes over the data, one for each column you want to work with (you could use aggregate instead of aggregateByKey to work on all the data and put it in a map, but that would probably give you even worse performance). The result would then be one line per key; if you want to move back to the original records (as the window function does), you would actually need to either join this value with the original RDD, or save all values internally and flatMap.
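The two-function contract described above can be simulated on plain Scala collections to make the mechanics concrete; here a simple count/sum accumulator stands in for StatCounter, and all names (Acc, seqOp, combOp, aggregateByKeySim) are illustrative:

```scala
// Accumulator playing the role of StatCounter: tracks count and sum of TARGET.
case class Acc(count: Long = 0L, sum: Double = 0.0) {
  def mean: Double = if (count == 0) 0.0 else sum / count
}

// seqOp: folds one record into an accumulator, within a partition.
def seqOp(acc: Acc, target: Int): Acc = Acc(acc.count + 1, acc.sum + target)

// combOp: merges accumulators coming from different partitions.
def combOp(a: Acc, b: Acc): Acc = Acc(a.count + b.count, a.sum + b.sum)

// Simulate aggregateByKey over "partitions" of ((column, value), target) pairs.
def aggregateByKeySim(
    partitions: Seq[Seq[((String, String), Int)]]): Map[(String, String), Acc] =
  partitions
    .map(_.groupBy(_._1).map { case (k, kvs) =>       // within each partition:
      (k, kvs.map(_._2).foldLeft(Acc())(seqOp))       //   apply seqOp
    })
    .foldLeft(Map.empty[(String, String), Acc]) { (merged, part) =>
      part.foldLeft(merged) { case (m, (k, acc)) =>   // across partitions:
        m.updated(k, m.get(k).map(combOp(_, acc)).getOrElse(acc))
      }
    }
```

With real Spark you would pass the two functions to rdd.aggregateByKey(Acc())(seqOp, combOp); the simulation only illustrates the within-partition / between-partition split.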

I do not believe this would provide you with any real performance improvement. You would be doing a lot of work to reimplement things that are done for you in SQL, and while doing so you would be losing most of the advantages of SQL (Catalyst optimization, Tungsten memory management, whole-stage code generation, etc.).

Improving the SQL

What I would do instead is attempt to improve the SQL itself. For example, the result of the window-function column appears to be the same for all rows of a partition. Do you really need a window function? You can do a groupBy instead, and if you really need the value per record, you can join the results back. This might provide better performance, as it would not necessarily mean shuffling everything twice at every step.
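The groupBy-plus-join idea can be sketched on plain Scala collections (a stand-in for df.groupBy(col).agg(mean(target)) followed by a join back on col; records, perLevelMean and joinedBack are illustrative names):

```scala
// Records: (col1 value, target).
val records = Seq(("A", 1), ("A", 0), ("A", 1), ("d", 0), ("d", 1))

// "groupBy + agg": one mean per level, instead of a window over every row.
val perLevelMean: Map[String, Double] =
  records.groupBy(_._1).map { case (level, rs) =>
    (level, rs.map(_._2).sum.toDouble / rs.size)
  }

// "join back": attach the aggregated value to each original record,
// which is what the window function produced row by row.
val joinedBack: Seq[(String, Int, Double)] =
  records.map { case (level, t) => (level, t, perLevelMean(level)) }
```

Each original record ends up carrying the per-level aggregate, which is exactly what mean(target).over(Window.partitionBy(col)) produced row by row.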
