Spark - Window with recursion? - Conditionally propagating values across rows

Published 2020-02-06 16:40

Question:

I have the following dataframe showing the revenue of purchases.

+-------+--------+-------+
|user_id|visit_id|revenue|
+-------+--------+-------+
|      1|       1|      0|
|      1|       2|      0|
|      1|       3|      0|
|      1|       4|    100|
|      1|       5|      0|
|      1|       6|      0|
|      1|       7|    200|
|      1|       8|      0|
|      1|       9|     10|
+-------+--------+-------+
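
For reference, this input can be reproduced with a minimal snippet like the following (not part of the original question; it assumes an active local SparkSession named spark):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()  # assumed local session

df = spark.createDataFrame(
    [(1, 1, 0), (1, 2, 0), (1, 3, 0), (1, 4, 100), (1, 5, 0),
     (1, 6, 0), (1, 7, 200), (1, 8, 0), (1, 9, 10)],
    ["user_id", "visit_id", "revenue"],
)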

Ultimately I want a new column purch_revenue that shows, in every row, the revenue generated by the purchase that row belongs to. As a workaround I have also tried to introduce a purchase identifier purch_id, which is incremented each time a purchase is made; it is listed below just for reference.

+-------+--------+-------+-------------+--------+
|user_id|visit_id|revenue|purch_revenue|purch_id|
+-------+--------+-------+-------------+--------+
|      1|       1|      0|          100|       1|
|      1|       2|      0|          100|       1|
|      1|       3|      0|          100|       1|
|      1|       4|    100|          100|       1|
|      1|       5|      0|          200|       2|
|      1|       6|      0|          200|       2|
|      1|       7|    200|          200|       2|
|      1|       8|      0|           10|       3|
|      1|       9|     10|           10|       3|
+-------+--------+-------+-------------+--------+

I've tried to use the lag/lead function like this:

from pyspark.sql import functions as fn
from pyspark.sql.window import Window

user_timeline = Window.partitionBy("user_id").orderBy("visit_id")
# Take the row's own revenue if positive, otherwise look one row ahead
find_rev = fn.when(fn.col("revenue") > 0, fn.col("revenue"))\
    .otherwise(fn.lead(fn.col("revenue"), 1).over(user_timeline))
df = df.withColumn("purch_revenue", find_rev)

This copies the revenue where revenue > 0 and otherwise pulls the value up from the next row, but only by one row. Clearly I can chain this for a finite N, as sketched below, but that's not a solution.
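
For illustration only, one more step of that chaining would look something like this sketch; every additional row of lookahead needs another explicit branch, so it can never cover an unbounded gap:

# N = 2: handles a purchase at most two rows ahead - doesn't generalize
find_rev_2 = fn.when(fn.col("revenue") > 0, fn.col("revenue"))\
    .when(fn.lead(fn.col("revenue"), 1).over(user_timeline) > 0,
          fn.lead(fn.col("revenue"), 1).over(user_timeline))\
    .otherwise(fn.lead(fn.col("revenue"), 2).over(user_timeline))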

  • Is there a way to apply this recursively until revenue > 0?
  • Alternatively, is there a way to increment a value based on a condition? I've tried to figure out a way to do that but struggled to find one.

Answer 1:

Window functions don't support recursion, but it is not required here. This type of sessionization can easily be handled with a cumulative sum: flag each purchase row, lag the flag by one row so that the purchase row itself still belongs to the current group, and take a running sum of the flags as the group id:

from pyspark.sql.functions import col, sum, when, lag
from pyspark.sql.window import Window

w = Window.partitionBy("user_id").orderBy("visit_id")

# Flag purchase rows, shift the flag down one row (so the purchase row
# stays in the current group), then take a running sum of the flags
purch_id = sum(lag(
    when(col("revenue") > 0, 1).otherwise(0),
    1, 0
).over(w)).over(w) + 1

df.withColumn("purch_id", purch_id).show()
+-------+--------+-------+--------+
|user_id|visit_id|revenue|purch_id|
+-------+--------+-------+--------+
|      1|       1|      0|       1|
|      1|       2|      0|       1|
|      1|       3|      0|       1|
|      1|       4|    100|       1|
|      1|       5|      0|       2|
|      1|       6|      0|       2|
|      1|       7|    200|       2|
|      1|       8|      0|       3|
|      1|       9|     10|       3|
+-------+--------+-------+--------+
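
To get from purch_id to the purch_revenue column the question ultimately asked for, one option is a second window keyed on the purchase group. Since each group contains exactly one row with a positive revenue, max broadcasts that value to every row of the group. A sketch building on the code above (max imported alongside the other functions):

from pyspark.sql.functions import max  # in addition to the imports above

# Partition-only window: aggregates over the whole purchase group
purch_window = Window.partitionBy("user_id", "purch_id")

(df
    .withColumn("purch_id", purch_id)
    .withColumn("purch_revenue", max("revenue").over(purch_window))
    .show())

With the sample data this fills purch_revenue with 100 for visits 1-4, 200 for visits 5-7, and 10 for visits 8-9.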