I currently have code in which I repeatedly apply the same procedure to multiple DataFrame columns via multiple chains of .withColumn, and I want to create a function to streamline the procedure. In my case, I am computing cumulative sums over columns, aggregated by keys:
val newDF = oldDF
  .withColumn("cumA", sum("A").over(Window.partitionBy("ID").orderBy("time")))
  .withColumn("cumB", sum("B").over(Window.partitionBy("ID").orderBy("time")))
  .withColumn("cumC", sum("C").over(Window.partitionBy("ID").orderBy("time")))
  //.withColumn(...)
What I would like is either something like:
def createCumulativeColumns(cols: Array[String], df: DataFrame): DataFrame = {
// Implement the above cumulative sums, partitioning, and ordering
}
or better yet:
def withColumns(cols: Array[String], df: DataFrame, f: String => Column): DataFrame = {
  // Apply a UDF/arbitrary column expression to all the specified columns
}
You can use select with varargs, including *:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum
import spark.implicits._

df.select($"*" +: Seq("A", "B", "C").map(c =>
  sum(c).over(Window.partitionBy("ID").orderBy("time")).alias(s"cum$c")
): _*)
This:

- Maps column names to window expressions with Seq("A", ...).map(...).
- Prepends all pre-existing columns with $"*" +: ....
- Unpacks the combined sequence with ... : _*.
and can be generalized as:
import org.apache.spark.sql.{Column, DataFrame}
/**
* @param cols a sequence of columns to transform
* @param df an input DataFrame
* @param f a function to be applied on each col in cols
*/
def withColumns(cols: Seq[String], df: DataFrame, f: String => Column): DataFrame =
  df.select($"*" +: cols.map(c => f(c)): _*)
If you find the withColumn syntax more readable, you can use foldLeft:
Seq("A", "B", "C").foldLeft(df)((df, c) =>
df.withColumn(s"cum$c", sum(c).over(Window.partitionBy("ID").orderBy("time")))
)
which can be generalized for example to:
/**
* @param cols a sequence of columns to transform
* @param df an input DataFrame
* @param f a function to be applied on each col in cols
* @param name a function mapping from input to output name.
*/
def withColumns(cols: Seq[String], df: DataFrame,
    f: String => Column, name: String => String = identity): DataFrame =
  cols.foldLeft(df)((df, c) => df.withColumn(name(c), f(c)))
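Called with an explicit name mapping, a hypothetical invocation might look like:

val result = withColumns(
  Seq("A", "B", "C"),
  df,
  c => sum(c).over(Window.partitionBy("ID").orderBy("time")),
  c => s"cum$c"  // maps each input column name to its output name
)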
The question is a bit old, but I thought it would be useful (perhaps for others) to note that folding over the list of columns with the DataFrame as accumulator and mapping the columns into a single select have substantially different performance outcomes when the number of columns is not trivial (see here for the full explanation). Long story short: for a few columns foldLeft is fine; otherwise a single select over mapped columns is better.
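As a rough sketch of the difference (assuming a hypothetical DataFrame df with an ID column, a time column, and numeric columns c0 through c99), both variants below compute the same result, but the foldLeft version adds one projection to the logical plan per column, while the select version adds a single projection:

val w = Window.partitionBy("ID").orderBy("time")
val cols = (0 until 100).map(i => s"c$i")

// One withColumn call, and hence one extra projection in the plan, per column.
val viaFold = cols.foldLeft(df)((acc, c) => acc.withColumn(s"cum$c", sum(c).over(w)))

// A single projection covering all the new columns at once.
val viaSelect = df.select($"*" +: cols.map(c => sum(c).over(w).alias(s"cum$c")): _*)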