Spark/Scala repeated calls to withColumn() using the same function on multiple columns

Posted 2020-01-29 04:39

Question:

I currently have code in which I repeatedly apply the same procedure to multiple DataFrame Columns via multiple chains of .withColumn, and am wanting to create a function to streamline the procedure. In my case, I am finding cumulative sums over columns aggregated by keys:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

val newDF = oldDF
  .withColumn("cumA", sum("A").over(Window.partitionBy("ID").orderBy("time")))
  .withColumn("cumB", sum("B").over(Window.partitionBy("ID").orderBy("time")))
  .withColumn("cumC", sum("C").over(Window.partitionBy("ID").orderBy("time")))
  //.withColumn(...)

What I would like is either something like:

def createCumulativeColumns(cols: Array[String], df: DataFrame): DataFrame = {
  // Implement the above cumulative sums, partitioning, and ordering
}

or better yet:

def withColumns(cols: Array[String], df: DataFrame, f: Column => Column): DataFrame = {
  // Implement a udf/arbitrary function on all the specified columns
}

Answer 1:

You can use select with varargs including *:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum
import spark.implicits._

df.select($"*" +: Seq("A", "B", "C").map(c => 
  sum(c).over(Window.partitionBy("ID").orderBy("time")).alias(s"cum$c")
): _*)

This:

  • Maps column names to window expressions with Seq("A", ...).map(...)
  • Prepends all pre-existing columns with $"*" +: ...
  • Unpacks the combined sequence with ...: _*

and can be generalized as:

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.col

/**
 * @param cols a sequence of columns to transform
 * @param df an input DataFrame
 * @param f a function to be applied on each col in cols
 */
def withColumns(cols: Seq[String], df: DataFrame, f: String => Column): DataFrame =
  df.select(col("*") +: cols.map(c => f(c)): _*)

If you find the withColumn syntax more readable, you can use foldLeft:

Seq("A", "B", "C").foldLeft(df)((df, c) =>
  df.withColumn(s"cum$c",  sum(c).over(Window.partitionBy("ID").orderBy("time")))
)

which can be generalized, for example, to:

/**
 * @param cols a sequence of columns to transform
 * @param df an input DataFrame
 * @param f a function to be applied on each col in cols
 * @param name a function mapping an input column name to an output column name
 */
def withColumns(cols: Seq[String], df: DataFrame,
    f: String => Column, name: String => String = identity): DataFrame =
  cols.foldLeft(df)((df, c) => df.withColumn(name(c), f(c)))
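
Used the same way, the cumulative sums become (a hypothetical usage sketch; the cum prefix naming is carried over from the question):

val cumDF = withColumns(
  Seq("A", "B", "C"),
  oldDF,
  c => sum(c).over(Window.partitionBy("ID").orderBy("time")),
  c => s"cum$c"
)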


Answer 2:

The question is a bit old, but I thought it would be useful (perhaps for others) to note that folding over the list of columns with the DataFrame as the accumulator (repeated withColumn calls) and mapping the columns into a single select have substantially different performance characteristics once the number of columns is non-trivial (see here for the full explanation). In short: for a few columns foldLeft is fine; for many columns, building a single select by mapping over the columns is better.
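
A minimal sketch of the two approaches being compared, assuming the columns, window spec, and cum prefix from the question (this code is not part of the original answer):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum}

val w = Window.partitionBy("ID").orderBy("time")
val cols = Seq("A", "B", "C")

// foldLeft with withColumn: each call adds another projection to the plan,
// so planning overhead tends to grow with the number of columns
val viaFold = cols.foldLeft(oldDF)((acc, c) => acc.withColumn(s"cum$c", sum(c).over(w)))

// mapping the columns into a single select: one projection adds all new columns at once
val viaSelect = oldDF.select(col("*") +: cols.map(c => sum(c).over(w).alias(s"cum$c")): _*)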