Spark SQL: using collect_set over array values?

Posted 2019-06-01 06:30


I have an aggregated DataFrame with a column created using collect_set. I now need to aggregate over this DataFrame again and apply collect_set to the values of that column again. The problem is that I need to apply collect_set over the values of the sets, and so far the only way I see to do so is by exploding the aggregated DataFrame. Is there a better way?


Initial DataFrame:

country   | continent   | attributes
Canada    | America     | A
Belgium   | Europe      | Z
USA       | America     | A
Canada    | America     | B
France    | Europe      | Y
France    | Europe      | X

Aggregated DataFrame (the one I receive as input) - aggregation over country:

country   | continent   | attributes
Canada    | America     | A, B
Belgium   | Europe      | Z
USA       | America     | A
France    | Europe      | Y, X

My desired output - aggregation over continent:

continent   | attributes
America     | A, B
Europe      | X, Y, Z


Since you can have only a handful of rows at this point, you can just collect attributes as-is and flatten the result (Spark >= 2.4):

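import org.apache.spark.sql.functions.{collect_set, flatten, array_distinct}
import spark.implicits._  // toDF and $"..." syntax; assumes a SparkSession named spark (as in spark-shell)

val byState = Seq(
  ("Canada", "America", Seq("A", "B")),
  ("Belgium", "Europe", Seq("Z")),
  ("USA", "America", Seq("A")),
  ("France", "Europe", Seq("Y", "X"))
).toDF("country", "continent", "attributes")

byState
  .groupBy($"continent")
  // collect_set dedupes whole arrays, flatten merges them, array_distinct dedupes elements
  .agg(array_distinct(flatten(collect_set($"attributes"))) as "attributes")
  .show

|   Europe| [Y, X, Z]|
|  America|    [A, B]|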
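To see why both flatten and array_distinct are needed, here is a plain-Scala model of the per-group computation using ordinary collections instead of Spark. It illustrates the semantics only, not how Spark executes the query; the names setsOfArrays, flattened and result are hypothetical.

```scala
// Plain-Scala model (no Spark) of what
// array_distinct(flatten(collect_set(attributes))) computes per continent.
// The rows mirror the byState example above.
val byState: Seq[(String, String, Seq[String])] = Seq(
  ("Canada", "America", Seq("A", "B")),
  ("Belgium", "Europe", Seq("Z")),
  ("USA", "America", Seq("A")),
  ("France", "Europe", Seq("Y", "X"))
)

// collect_set on an array column deduplicates whole arrays, not elements:
// Seq("A", "B") and Seq("A") are different values, so "A" is kept twice.
val setsOfArrays: Map[String, Set[Seq[String]]] =
  byState.groupBy(_._2).map { case (continent, rows) =>
    continent -> rows.map(_._3).toSet
  }

// flatten concatenates the collected arrays...
val flattened: Map[String, Seq[String]] =
  setsOfArrays.map { case (continent, arrays) =>
    continent -> arrays.toSeq.flatten
  }

// ...and array_distinct removes the element-level duplicates.
val result: Map[String, Seq[String]] =
  flattened.map { case (continent, values) => continent -> values.distinct }

result.toSeq.sortBy(_._1).foreach { case (continent, values) =>
  println(s"$continent -> ${values.sorted.mkString(", ")}")
}
```

Without the array_distinct step, America would come out with "A" listed twice, because collect_set only compared the arrays as a whole.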

In the general case things are much harder to handle. If you expect large lists, with many duplicates and many values per group, the optimal solution* is often to simply recompute the result from scratch, i.e. aggregate the initial DataFrame directly:

input.groupBy($"continent").agg(collect_set($"attributes") as "attributes")
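Recomputing gives the same answer because set union is associative and commutative: grouping the raw rows by continent yields exactly the sets you would get by first grouping by country and then merging those sets. The plain-Scala sketch below checks this on the question's initial data with ordinary collections (no Spark); the names fromScratch, byCountry and merged are illustrative only, and the sketch says nothing about performance.

```scala
// Plain-Scala model (no Spark) comparing two routes to the same result.
// Rows are (country, continent, attribute), as in the initial DataFrame.
val initial: Seq[(String, String, String)] = Seq(
  ("Canada", "America", "A"),
  ("Belgium", "Europe", "Z"),
  ("USA", "America", "A"),
  ("Canada", "America", "B"),
  ("France", "Europe", "Y"),
  ("France", "Europe", "X")
)

// Route 1: recompute from scratch, grouping the raw rows by continent.
val fromScratch: Map[String, Set[String]] =
  initial.groupBy(_._2).map { case (continent, rows) =>
    continent -> rows.map(_._3).toSet
  }

// Route 2: aggregate per (country, continent) first, then union the sets.
val byCountry: Map[(String, String), Set[String]] =
  initial.groupBy(r => (r._1, r._2)).map { case (key, rows) =>
    key -> rows.map(_._3).toSet
  }
val merged: Map[String, Set[String]] =
  byCountry.toSeq.groupBy(_._1._2).map { case (continent, entries) =>
    continent -> entries.map(_._2).reduce(_ union _)
  }

// Both routes produce identical per-continent sets.
println(fromScratch == merged)
```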

One possible alternative is to use a custom typed Aggregator:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.{Encoder, Encoders}
import scala.collection.mutable.{Set => MSet}

class MergeSets[T, U](f: T => Seq[U])(implicit enc: Encoder[Seq[U]]) extends
     Aggregator[T, MSet[U], Seq[U]] with Serializable {

  // Empty buffer for each group
  def zero = MSet.empty[U]

  // Add every element of the input's collection to the buffer
  def reduce(acc: MSet[U], x: T) = {
    for { v <- f(x) } acc.add(v)
    acc
  }

  // Union two partial buffers (mutates and returns acc1)
  def merge(acc1: MSet[U], acc2: MSet[U]) = {
    acc1 ++= acc2
  }

  // Convert the final buffer to the output type
  def finish(acc: MSet[U]) = acc.toSeq
  def bufferEncoder: Encoder[MSet[U]] = Encoders.kryo[MSet[U]]
  def outputEncoder: Encoder[Seq[U]] = enc
}


and apply it as follows:

case class CountryAggregate(
  country: String, continent: String, attributes: Seq[String])

byState
  .as[CountryAggregate]
  .groupByKey(_.continent)
  .agg(new MergeSets[CountryAggregate, String](_.attributes).toColumn)
  .toDF("continent", "attributes")
  .show

|   Europe| [X, Y, Z]|
|  America|    [B, A]|

but that's clearly not a Java-friendly option.
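As a rough picture of how MergeSets is driven: Spark calls zero and reduce to build a buffer within each partition, merge to combine the partial buffers, and finish once per group. The plain-Scala sketch below replays that sequence by hand for the Europe group. The two-partition split is hypothetical, and Spark's real scheduling, the Kryo buffer serialization and the output ordering are not modeled.

```scala
import scala.collection.mutable.{Set => MSet}

// Plain-Scala model (no Spark) of the Aggregator lifecycle used by MergeSets.
case class CountryAggregate(
  country: String, continent: String, attributes: Seq[String])

// Same logic as MergeSets, specialised to T = CountryAggregate, U = String.
def zero: MSet[String] = MSet.empty[String]

def reduce(acc: MSet[String], x: CountryAggregate): MSet[String] = {
  x.attributes.foreach(acc.add)
  acc
}

def merge(acc1: MSet[String], acc2: MSet[String]): MSet[String] = acc1 ++= acc2

def finish(acc: MSet[String]): Seq[String] = acc.toSeq

// Hypothetical split of the Europe group across two partitions.
val partitions: Seq[Seq[CountryAggregate]] = Seq(
  Seq(CountryAggregate("Belgium", "Europe", Seq("Z"))),
  Seq(CountryAggregate("France", "Europe", Seq("Y", "X")))
)

// Each partition folds its rows into a fresh buffer (zero is a def, so a new set each time)...
val partials: Seq[MSet[String]] = partitions.map(_.foldLeft(zero)(reduce))

// ...then the partial buffers are merged and finish is called once.
val europe: Seq[String] = finish(partials.reduce((a, b) => merge(a, b)))

println(europe.sorted.mkString(", "))
```

Because the buffer is a set, the result does not depend on how rows are split across partitions, only the element order of the final Seq may differ (compare [B, A] in the output above).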

See also "How to aggregate values into collection after groupBy?" (similar, but without the uniqueness constraint).

* That's because explode can be quite expensive, especially in older Spark versions, as is access to the external representation of SQL collections.