Aggregate (Sum) over Window for a list of Columns

Posted 2019-05-31 03:29

Question:

I'm having trouble finding a generic way to calculate the sum (or any other aggregate function) over a given window for a list of columns in a DataFrame.

import spark.implicits._   // required for .toDF on an RDD of tuples

val inputDF = spark
.sparkContext
.parallelize(
    Seq(
        (1,2,1, 30, 100),
        (1,2,2, 30, 100), 
        (1,2,3, 30, 100),
        (11,21,1, 30, 100),
        (11,21,2, 30, 100), 
        (11,21,3, 30, 100)
    ),
    10)
.toDF("c1", "c2", "offset", "v1", "v2")

inputDF.show
+---+---+------+---+---+
| c1| c2|offset| v1| v2|
+---+---+------+---+---+
|  1|  2|     1| 30|100|
|  1|  2|     2| 30|100|
|  1|  2|     3| 30|100|
| 11| 21|     1| 30|100|
| 11| 21|     2| 30|100|
| 11| 21|     3| 30|100|
+---+---+------+---+---+

Given a DataFrame as shown above, it's easy to compute the sum for a list of columns, as in the code snippet below:

import org.apache.spark.sql.functions._   // col, sum, when
import org.apache.spark.sql.expressions.Window

val groupKey = List("c1", "c2").map(x => col(x.trim))
val orderByKey = List("offset").map(x => col(x.trim))

val aggKey = List("v1", "v2").map(c => sum(c).alias(c.trim))

val w = Window.partitionBy(groupKey: _*).orderBy(orderByKey: _*)

val outputDF = inputDF
    .groupBy(groupKey: _*)
    .agg(aggKey.head, aggKey.tail: _*)

outputDF.show

But I can't seem to find a similar approach for aggregate functions over a window spec. So far I've only been able to solve this by specifying each column individually, as shown below:

val outputDF2 = inputDF
    .withColumn("cumulative_v1", sum(when($"offset".between(-1, 1), inputDF("v1")).otherwise(0)).over(w))
    .withColumn("cumulative_v2", sum(when($"offset".between(-1, 1), inputDF("v2")).otherwise(0)).over(w))

I'd appreciate it if there is a way to do this aggregation over a dynamic list of columns. Thanks!

Answer 1:

I think I found an approach that works better than the one shown in the question above.

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.col

/**
    * Utility method that takes a DataFrame and a list of columns and returns it with rolling aggregates appended for those columns
    * @param colsToAggregate    Seq[String] of columns in the input DataFrame to be aggregated
    * @param inputDF            input DataFrame
    * @param f                  aggregate function that maps a column name to an aggregate Column (e.g. c => sum(col(c)))
    * @param partitionByColSeq  Seq of column names to partition the inputDF by before applying the aggregate
    * @param orderByColSeq      Seq of column names to order the inputDF by before applying the aggregate
    * @param name_prefix        String to prefix the new columns with, to avoid collisions
    * @param name               function to derive the new column names; defaults to identity, which reuses the aggregated column names
    * @return                   output DataFrame with the aggregate columns appended
    */
  def withRollingAggregateColumns(colsToAggregate: Seq[String],
                                  inputDF: DataFrame,
                                  f: String => Column,
                                  partitionByColSeq: Seq[String],
                                  orderByColSeq: Seq[String],
                                  name_prefix: String,
                                  name: String => String = identity): DataFrame = {

    val groupByKey = partitionByColSeq.map(x => col(x.trim))
    val orderByKey = orderByColSeq.map(x => col(x.trim))

    import org.apache.spark.sql.expressions.Window

    val w = Window.partitionBy(groupByKey: _*).orderBy(orderByKey: _*)

    colsToAggregate
      .foldLeft(inputDF)(
        (df, elementInCols) => df
          .withColumn(
            name_prefix + "_" + name(elementInCols),
            f(elementInCols).over(w)
          )
      )
  }

In this case, the utility method takes a DataFrame as input and appends new columns computed with the provided function f. It folds over the list of columns to be aggregated, calling withColumn once per column. To avoid column name collisions, it prefixes each new aggregate column with the user-provided name_prefix.
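
For reference, a hypothetical call on the inputDF from the question might look like the sketch below (assuming org.apache.spark.sql.functions.{col, sum} and spark.implicits._ are in scope; the rollingDF name is just illustrative). It adds cumulative_v1 and cumulative_v2 as running sums of v1 and v2 within each (c1, c2) partition, ordered by offset.

// minimal usage sketch of the utility method above
val rollingDF = withRollingAggregateColumns(
  colsToAggregate = Seq("v1", "v2"),     // columns to compute rolling sums for
  inputDF = inputDF,
  f = c => sum(col(c)),                  // any String => Column aggregate works here
  partitionByColSeq = Seq("c1", "c2"),
  orderByColSeq = Seq("offset"),
  name_prefix = "cumulative"             // produces cumulative_v1, cumulative_v2
)

rollingDF.show

Since the window has an orderBy but no explicit frame, Spark applies its default frame (range between unbounded preceding and the current row), so each new column holds a running sum per partition; pass a different f, or build the window with an explicit rowsBetween frame, if another behavior is needed.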