How to iterate grouped data in Spark?

Posted 2019-08-17 04:35

Question:

I have a dataset like this:

uid    group_a    group_b
1      3          unknown
1      unknown    4
2      unknown    3
2      2          unknown

I want to get the result:

uid    group_a    group_b
1      3          4
2      2          3

I tried to group the data by "uid", iterate over each group, and select the non-unknown value as the final value, but I don't know how to do it.

Answer 1:

After you format the dataset as a PairRDD, you can use the reduceByKey operation to find the single known value. The following example assumes that there is only one known value per uid; otherwise it returns the first known value it sees.

val input = List(
    ("1", "3", "unknown"),
    ("1", "unknown", "4"),
    ("2", "unknown", "3"),
    ("2", "2", "unknown")
)

// key each line by uid: (uid, (group_a, group_b))
val pairRdd = sc.parallelize(input).map(l => (l._1, (l._2, l._3)))

// for each uid, keep the known value of each column
val result = pairRdd.reduceByKey { (a, b) =>
    val groupA = if (a._1 != "unknown") a._1 else b._1
    val groupB = if (a._2 != "unknown") a._2 else b._2
    (groupA, groupB)
}

The result will be a pair RDD that looks like this:

(uid, (group_a, group_b))
(1,(3,4))                                                                       
(2,(2,3))

You can return to the plain line format with a simple map operation.
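For example, a minimal sketch of that map (assuming you want plain (uid, group_a, group_b) tuples back):

// flatten (uid, (group_a, group_b)) into one tuple per line
val flattened = result.map { case (uid, (groupA, groupB)) => (uid, groupA, groupB) }
flattened.collect().foreach(println)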



Answer 2:

I would suggest you define a User Defined Aggregate Function (UDAF).

Built-in functions are a great way to go, but they are hard to customize. If you write your own UDAF, you have full control and can adapt it to your needs.

For your problem, the following could be a solution; adapt it as needed.

The first task is to define the UDAF:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{StringType, StructType}

class PingJiang extends UserDefinedAggregateFunction {

  def inputSchema = new StructType().add("group_a", StringType).add("group_b", StringType)
  def bufferSchema = new StructType().add("buff0", StringType).add("buff1", StringType)
  def dataType = StringType
  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer.update(0, "")
    buffer.update(1, "")
  }

  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0)) {
      val groupa = input.getString(0)
      val groupb = input.getString(1)

      // keep only known values in the buffer; a later known value overwrites an earlier one
      if (!groupa.equalsIgnoreCase("unknown")) {
        buffer.update(0, groupa)
      }
      if (!groupb.equalsIgnoreCase("unknown")) {
        buffer.update(1, groupb)
      }
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    // each partial buffer holds only known values (or ""), so concatenation keeps them
    buffer1.update(0, buffer1.getString(0) + buffer2.getString(0))
    buffer1.update(1, buffer1.getString(1) + buffer2.getString(1))
  }

  def evaluate(buffer: Row): String = {
    // return both groups as one comma-separated string, to be split again by the caller
    buffer.getString(0) + "," + buffer.getString(1)
  }
}

Then you call it from your main class and do a little manipulation to get the result you need:

import org.apache.spark.sql.functions.split
import spark.implicits._

val data = Seq(
  (1, "3", "unknown"),
  (1, "unknown", "4"),
  (2, "unknown", "3"),
  (2, "2", "unknown"))
  .toDF("uid", "group_a", "group_b")

val udaf = new PingJiang()

val result = data.groupBy("uid").agg(udaf($"group_a", $"group_b").as("ping"))
  .withColumn("group_a", split($"ping", ",")(0))
  .withColumn("group_b", split($"ping", ",")(1))
  .drop("ping")
result.show(false)

Visit databricks and augmentiq for a better understanding of UDAFs.

Note: the above solution keeps the latest known value for each group, if present (you can always adapt it to your needs).



Answer 3:

You could replace all "unknown" values with null, and then use the function first() inside a map (as shown here) to get the first non-null value in each column per group:

import org.apache.spark.sql.functions.{col, first, when}
// We only apply our function to the last 2 columns
val cols = df.columns.drop(1)
// Create one first() expression per column, ignoring nulls
val exprs = cols.map(first(_, true))
// Putting it all together
df.select(df.columns
          .map(c => when(col(c) === "unknown", null)
          .otherwise(col(c)).as(c)): _*)
  .groupBy("uid")
  .agg(exprs.head, exprs.tail: _*).show()
+---+--------------------+--------------------+
|uid|first(group_a, true)|first(group_b, true)|
+---+--------------------+--------------------+
|  1|                   3|                   4|
|  2|                   2|                   3|
+---+--------------------+--------------------+

Data:

val df = sc.parallelize(Array(("1","3","unknown"),("1","unknown","4"),
                              ("2","unknown","3"),("2","2","unknown"))).toDF("uid","group_a","group_b")