PySpark: Cumulative Sum with a Reset Condition

Posted 2019-05-22 06:45

Question:

We have a DataFrame like the one below:

+------+--------------------+
| Flag |               value|
+------+--------------------+
|1     |5                   |
|1     |4                   |
|1     |3                   |
|1     |5                   |
|1     |6                   |
|1     |4                   |
|1     |7                   |
|1     |5                   |
|1     |2                   |
|1     |3                   |
|1     |2                   |
|1     |6                   |
|1     |9                   |      
+------+--------------------+

After a normal cumulative sum, we get this:

+------+--------------------+----------+
| Flag |               value|cumsum    |
+------+--------------------+----------+
|1     |5                   |5         |
|1     |4                   |9         |
|1     |3                   |12        |
|1     |5                   |17        |
|1     |6                   |23        |
|1     |4                   |27        |
|1     |7                   |34        |
|1     |5                   |39        |
|1     |2                   |41        |
|1     |3                   |44        |
|1     |2                   |46        |
|1     |6                   |52        |
|1     |9                   |61        |       
+------+--------------------+----------+
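The `cumsum` column is an ordinary running total of `value` in row order. As a quick plain-Python cross-check (values copied from the table above; this is not Spark code):

```python
from itertools import accumulate

values = [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]
# A plain running total never resets, so it keeps growing past 20
cumsum = list(accumulate(values))
print(cumsum)  # [5, 9, 12, 17, 23, 27, 34, 39, 41, 44, 46, 52, 61]
```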

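Now we want the cumulative sum to reset when a specific condition is met, for example when it crosses 20: the row whose running total first goes above 20 keeps that total, and the sum restarts from the next row's value.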

Below is the expected output:

+------+--------------------+----------+---------+
| Flag |               value|cumsum    |expected |
+------+--------------------+----------+---------+
|1     |5                   |5         |5        |
|1     |4                   |9         |9        |
|1     |3                   |12        |12       |
|1     |5                   |17        |17       |
|1     |6                   |23        |23       |
|1     |4                   |27        |4        |  <-----reset 
|1     |7                   |34        |11       |
|1     |5                   |39        |16       |
|1     |2                   |41        |18       |
|1     |3                   |44        |21       |
|1     |2                   |46        |2        |  <-----reset
|1     |6                   |52        |8        |
|1     |9                   |61        |17       |         
+------+--------------------+----------+---------+
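The reset rule can be pinned down with a small plain-Python reference implementation (not a Spark solution), which is handy for checking any Spark version against the `expected` column. It assumes "crosses 20" means the running total goes strictly above 20; the table has no row that hits exactly 20, so it does not settle that case.

```python
from typing import Iterable, List


def cumsum_with_reset(values: Iterable[int], threshold: int = 20) -> List[int]:
    """Running total that restarts after any row whose total exceeds threshold.

    The row that crosses the threshold keeps its total (e.g. 23); the next
    row starts a fresh sum from its own value.
    """
    result: List[int] = []
    running = 0
    for v in values:
        running += v
        result.append(running)
        if running > threshold:  # assumption: "crosses" means strictly above
            running = 0          # next row starts from its own value
    return result


values = [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]
print(cumsum_with_reset(values))
# [5, 9, 12, 17, 23, 4, 11, 16, 18, 21, 2, 8, 17]
```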

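This is how we are calculating the cumulative sum:

from pyspark.sql import Window, functions as F

# orderBy is required for a running total (without it every row gets the partition total, 61); "id" is a hypothetical ordering column
win_counter = Window.partitionBy("Flag").orderBy("id").rowsBetween(Window.unboundedPreceding, Window.currentRow)

df_partitioned = df_partitioned.withColumn('cumsum', F.sum(F.col('value')).over(win_counter))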
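Because each row's result depends on whether the previous row's total was reset, the reset version cannot be written as a single `F.sum(...).over(...)`; it is a sequential fold over each ordered partition. Below is a minimal plain-Python sketch of that fold, assuming a hypothetical ordering column `id` and reading "crosses 20" as "greater than 20". In Spark, one approach worth exploring is to gather each `Flag` group's values in `id` order and apply the same step with the Spark SQL `aggregate` higher-order function, or to run a per-group function via `groupBy("Flag").applyInPandas(...)` (Spark 3.0+). Note that `collect_list` on its own does not guarantee element order.

```python
from functools import reduce
from itertools import groupby
from operator import itemgetter
from typing import Dict, List, Tuple

THRESHOLD = 20  # assumption: "crosses 20" means strictly greater than 20


def step(acc: List[int], x: int) -> List[int]:
    """One fold step: start over from x if the previous total crossed the
    threshold, otherwise add x to the previous total."""
    if not acc or acc[-1] > THRESHOLD:
        return acc + [x]
    return acc + [acc[-1] + x]


# Rows as (Flag, id, value); "id" is a hypothetical ordering column.
rows: List[Tuple[int, int, int]] = [
    (1, i, v) for i, v in enumerate([5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9])
]

# Mimic partitionBy("Flag").orderBy("id"): sort, group by Flag, fold each group.
expected: Dict[int, List[int]] = {}
for flag, group in groupby(sorted(rows), key=itemgetter(0)):
    ordered_values = [v for _flag, _id, v in group]
    expected[flag] = reduce(step, ordered_values, [])

print(expected)
# {1: [5, 9, 12, 17, 23, 4, 11, 16, 18, 21, 2, 8, 17]}
```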