Spark-SQL Window functions on Dataframe - Finding

Published 2019-08-14 16:50

Question:

I have the following DataFrame (call it UserData):

uid region  timestamp
a   1   1
a   1   2
a   1   3
a   1   4
a   2   5
a   2   6
a   2   7
a   3   8
a   4   9
a   4   10
a   4   11
a   4   12
a   1   13
a   1   14
a   3   15
a   3   16
a   5   17
a   5   18
a   5   19
a   5   20

This data represents a user (uid) travelling across different regions (region) at different times (timestamp). For simplicity, timestamp is shown as an int. Note that the DataFrame above will not necessarily be sorted by timestamp, and rows from other users may be interleaved with it. For simplicity, I have shown the data for a single user only, in increasing timestamp order.

My goal is to find how much time user 'a' spent in each region, and in what order. My final expected output looks like this:

uid region  regionTimeStart regionTimeEnd
a   1   1   5
a   2   5   8
a   3   8   9
a   4   9   13
a   1   13  15
a   3   15  17
a   5   17  20

Based on my research, Spark SQL window functions can be used for this purpose. I have tried the following:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank

val w = Window
  .partitionBy("region")
  .partitionBy("uid")
  .orderBy("timestamp")

val resultDF = UserData.select(
  UserData("uid"), UserData("timestamp"),
  UserData("region"), rank().over(w).as("Rank"))

But from here on, I am not sure how to get the regionTimeStart and regionTimeEnd columns. The regionTimeEnd column is simply the 'lead' of regionTimeStart, except for the last entry in the group.

I see that aggregate operations have 'first' and 'last' functions, but to use them I need to group the data by ('uid', 'region'), which destroys the time order of the path traversed: at times 13 and 14 the user has come back to region '1', and I want that visit kept separate instead of being merged with the initial visit to region '1' at time 1.
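To make the problem concrete, here is a small sketch using plain Scala collections (not Spark; the object and method names are hypothetical) that mimics grouping by ('uid', 'region') and taking the min/max timestamp of each group on the data above:

```scala
// Plain Scala collections, not Spark: illustrates why grouping by (uid, region)
// merges non-consecutive visits to the same region.
object GroupByMergesVisits {
  // (uid, region, timestamp) rows for user "a", copied from the table above
  val rows: Seq[(String, Int, Int)] = Seq(
    ("a", 1, 1), ("a", 1, 2), ("a", 1, 3), ("a", 1, 4),
    ("a", 2, 5), ("a", 2, 6), ("a", 2, 7),
    ("a", 3, 8),
    ("a", 4, 9), ("a", 4, 10), ("a", 4, 11), ("a", 4, 12),
    ("a", 1, 13), ("a", 1, 14),
    ("a", 3, 15), ("a", 3, 16),
    ("a", 5, 17), ("a", 5, 18), ("a", 5, 19), ("a", 5, 20)
  )

  // Mimics groupBy("uid", "region") followed by min/max of timestamp
  def firstLastPerRegion(data: Seq[(String, Int, Int)]): Map[(String, Int), (Int, Int)] =
    data
      .groupBy { case (uid, region, _) => (uid, region) }
      .map { case (key, grp) =>
        val ts = grp.map(_._3)
        key -> (ts.min, ts.max)
      }

  def main(args: Array[String]): Unit =
    firstLastPerRegion(rows).toSeq.sortBy(_._1).foreach(println)
}
```

The ("a", 1) entry spans timestamps 1 to 14, so the return visit at 13 and 14 is merged into the first visit; likewise ("a", 3) spans 8 to 16.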

It would be very helpful if anyone could guide me. I am new to Spark, and I understand the Scala Spark API better than the Python/Java Spark APIs.

Answer 1:

Window functions are indeed useful, although your approach can work only if you assume that the user visits a given region only once. Also, the window definition you use is incorrect: each call to partitionBy returns a new window specification that replaces the previous partitioning, so your w is effectively partitioned by uid only. If you want to partition by multiple columns, pass them in a single call (.partitionBy("region", "uid")).

Let's start by marking continuous visits to each region:

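import org.apache.spark.sql.functions.{lag, sum, not}
import org.apache.spark.sql.expressions.Window
import spark.implicits._ // enables the $"col" syntax; assumes a SparkSession named spark

// One partition per user, rows ordered by time
val w = Window.partitionBy($"uid").orderBy($"timestamp")

// 1 when the region differs from the previous row's region (or on the user's first row), else 0
val change = (not(lag($"region", 1).over(w) <=> $"region")).cast("int")
// Running total of changes: a new id for each continuous visit
val ind = sum(change).over(w)

// df is the UserData DataFrame from the question
val dfWithInd = df.withColumn("ind", ind)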
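The same idea can be traced on plain Scala collections. The sketch below is not Spark code: it is a hypothetical local re-implementation, assuming rows are (uid, region, timestamp) tuples, that mirrors each step above (partition by uid, order by timestamp, lag, change flag, running sum). The second user "b" in the sample is made up to show unsorted, interleaved input.

```scala
// Local, non-Spark sketch of the lag + running-sum trick for continuous visit ids.
object ContinuousVisitIds {
  // Hypothetical sample: unsorted rows for user "a", interleaved with a made-up user "b"
  val sample: Seq[(String, Int, Int)] = Seq(
    ("a", 1, 3), ("b", 7, 2), ("a", 1, 1), ("a", 2, 5),
    ("b", 7, 1), ("a", 1, 2), ("b", 8, 3), ("a", 1, 6)
  )

  // Returns (uid, region, timestamp, ind) rows, grouped by uid and sorted by timestamp
  def visitIds(rows: Seq[(String, Int, Int)]): Seq[(String, Int, Int, Int)] =
    rows.groupBy(_._1).toSeq.sortBy(_._1).flatMap { case (_, userRows) =>
      val sorted = userRows.sortBy(_._3) // like orderBy("timestamp")
      // like lag(region, 1): None for the user's first row
      val prevRegions: Seq[Option[Int]] = None +: sorted.init.map(r => Some(r._2))
      // like not(lag <=> region).cast("int")
      val changes = sorted.zip(prevRegions).map { case ((_, region, _), prev) =>
        if (prev.contains(region)) 0 else 1
      }
      // like sum(change).over(w): running total of the change flags
      val ind = changes.scanLeft(0)(_ + _).tail
      sorted.zip(ind).map { case ((uid, region, ts), i) => (uid, region, ts, i) }
    }
}
```

For user "a" the ids come out as 1, 1, 1, 2, 3: the return to region 1 at timestamp 6 gets a new id instead of reusing 1.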

Next, we simply aggregate over the groups and find the leads:

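import org.apache.spark.sql.functions.{lead, coalesce, min, max}

// End of a visit = start of the next visit; for the user's last visit,
// fall back to that visit's own max timestamp
val regionTimeEnd = coalesce(lead($"timestamp", 1).over(w), $"max_")

val result = dfWithInd
  .groupBy($"uid", $"region", $"ind") // one row per continuous visit
  .agg(min($"timestamp").alias("timestamp"), max($"timestamp").alias("max_"))
  .drop("ind")
  .withColumn("regionTimeEnd", regionTimeEnd)
  .withColumnRenamed("timestamp", "regionTimeStart")
  .drop("max_")

// Explicit ordering, since row order of an unsorted DataFrame is not guaranteed
result.orderBy($"uid", $"regionTimeStart").show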

// +---+------+---------------+-------------+
// |uid|region|regionTimeStart|regionTimeEnd|
// +---+------+---------------+-------------+
// |  a|     1|              1|            5|
// |  a|     2|              5|            8|
// |  a|     3|              8|            9|
// |  a|     4|              9|           13|
// |  a|     1|             13|           15|
// |  a|     3|             15|           17|
// |  a|     5|             17|           20|
// +---+------+---------------+-------------+
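As a cross-check of the aggregation and lead logic, here is a plain-Scala sketch (again hypothetical, not Spark) that takes (uid, region, timestamp, ind) rows, such as those produced by the previous step, and builds the intervals. The sample hard-codes the ind values for user 'a' from the question's data.

```scala
// Local, non-Spark sketch of: groupBy(uid, region, ind) + min/max + lead/coalesce.
object RegionIntervals {
  // (uid, region, timestamp, ind) rows for user "a", with ind values as the first step would assign them
  val sample: Seq[(String, Int, Int, Int)] =
    Seq((1, 1 to 4, 1), (2, 5 to 7, 2), (3, 8 to 8, 3), (4, 9 to 12, 4),
        (1, 13 to 14, 5), (3, 15 to 16, 6), (5, 17 to 20, 7))
      .flatMap { case (region, ts, ind) => ts.map(t => ("a", region, t, ind)) }

  // Returns (uid, region, regionTimeStart, regionTimeEnd) per continuous visit
  def intervals(rows: Seq[(String, Int, Int, Int)]): Seq[(String, Int, Int, Int)] =
    rows.groupBy(_._1).toSeq.sortBy(_._1).flatMap { case (uid, userRows) =>
      // like groupBy(uid, region, ind).agg(min(timestamp), max(timestamp))
      val visits = userRows
        .groupBy(r => (r._2, r._4))
        .toSeq
        .map { case ((region, _), grp) => (region, grp.map(_._3).min, grp.map(_._3).max) }
        .sortBy(_._2) // order visits by their start time
      // like lead(start, 1): None for the last visit
      val nextStarts: Seq[Option[Int]] = visits.drop(1).map(v => Some(v._2)) :+ None
      // like coalesce(lead(...), max_): last visit ends at its own max timestamp
      visits.zip(nextStarts).map { case ((region, start, maxTs), next) =>
        (uid, region, start, next.getOrElse(maxTs))
      }
    }
}
```

On the sample it returns the same seven rows as the Spark output above, from (a, 1, 1, 5) through (a, 5, 17, 20).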