How to compute “custom running total” in Spark 1.5

Posted 2019-08-07 09:31

Question:

I have a loan payment history for each LoanId stored in a Parquet file, and I am trying to calculate the "Past Due" amount for each period of each loan. This would be a simple partition-over-window task if not for the tricky way the past due amount is computed.

If the customer pays less than the due amount, the past due amount grows; on the other hand, if the customer makes an advance payment, the extra payment is ignored in subsequent periods (rows 5 & 6 in the sample below).

LoanID  Period  DueAmt  ActualPmt   PastDue
1       1       100     100             0
1       2       100     60              -40
1       3       100     100             -40
1       4       100     200             0   <== This advance payment is not rolled to next period
1       5       100     110             0   <== This advance payment is not rolled to next period
1       6       100     80              -20
1       7       100     60              -60
1       8       100     100             -60
2       1       150     150             0
2       2       150     150             0
2       3       150     150             0
3       1       200     200             0
3       2       200     120             -80
3       3       200     120             -160
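
In other words, the rule is a running balance that is clipped at zero whenever a period ends with an overpayment. Here is a minimal sketch in plain Python (no Spark; the function name is mine) that reproduces the PastDue column for LoanID 1:

    def past_due(payments):
        '''payments: list of (DueAmt, ActualPmt) tuples ordered by Period.'''
        balance = 0
        out = []
        for due_amt, actual_pmt in payments:
            # carry the balance forward, but never carry a surplus
            balance = min(0, balance + actual_pmt - due_amt)
            out.append(balance)
        return out

    # LoanID 1 from the sample above
    loan1 = [(100, 100), (100, 60), (100, 100), (100, 200),
             (100, 110), (100, 80), (100, 60), (100, 100)]
    print(past_due(loan1))   # [0, -40, -40, 0, 0, -20, -60, -60]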

To solve this I effectively need to apply a custom function to each partition (LoanID), ordered by Period.

What options are available in Spark?

A straightforward but clunky option seems to be DataFrame -> RDD -> groupByKey, apply a lambda, and convert back to a DataFrame.

More elegant would be a custom UDAF (in Scala?) with a window function, but I can't find a single implementation example of this.


OK, so I tried the first solution, a round trip from DataFrame to pair RDD and back:

    from pyspark.sql import Row

    def dueAmt(partition):
        '''
        partition: iterable of Rows for a single LoanID (from groupByKey)
        '''
        # first sort rows by period
        sp = sorted(partition, key=lambda r: r.Period)
        res = []
        due = 0
        for r in sp:
            due += r.ActualPmt - r.DueAmt
            if due > 0:  # advance payments are not carried forward
                due = 0
            # Row is immutable, so we create a new Row with the extra column
            d = r.asDict()
            d['CalcDueAmt'] = -due
            res.append(Row(**d))
        return res

    df = (sqlContext.read.format('com.databricks.spark.csv')
          .options(header='true', inferschema='true')
          .load('PmtDueSample2.csv')
          .cache())
    rd1 = df.rdd.map(lambda r: (r.LoanID, r))  # key by LoanID
    rd2 = rd1.groupByKey()                     # one group per loan
    rd3 = rd2.mapValues(dueAmt)                # apply the running-total logic
    rd4 = rd3.flatMap(lambda t: t[1])          # flatten back to Rows
    df2 = rd4.toDF()

Seems to work.
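
A quick way to eyeball the result next to the inputs, for example:

    # CalcDueAmt holds the positive past-due amount (the function stores -due)
    df2.select('LoanID', 'Period', 'DueAmt', 'ActualPmt', 'CalcDueAmt') \
       .orderBy('LoanID', 'Period') \
       .show()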

On this journey I actually discovered a couple of bugs in the PySpark implementation.

  1. The implementation of __call__ in class Row is wrong.
  2. An annoying bug in Row's constructor: for no obvious reason __new__ sorts the columns, so at the end of the journey my resulting table had its columns ordered alphabetically. This simply makes it harder to look at the final result.
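
For reference, point 2 is easy to reproduce on the 1.x line (a minimal illustration; the sorting behaviour was only dropped in much later Spark releases):

    from pyspark.sql import Row

    r = Row(LoanID=1, Period=2, CalcDueAmt=40)
    # field names come back sorted alphabetically, not in the order passed in
    print(r)   # Row(CalcDueAmt=40, LoanID=1, Period=2)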

Answer 1:

Neither pretty nor efficient, but it should give you something to work with. Let's start by creating and registering a table:

// toDF and $ need the SQLContext implicits; col is used later for flattening
// (both are already in scope in spark-shell)
import sqlContext.implicits._
import org.apache.spark.sql.functions.col

val df = sc.parallelize(Seq(
  (1, 1, 100, 100), (1, 2, 100, 60), (1, 3, 100, 100),
  (1, 4, 100, 200), (1, 5, 100, 110), (1, 6, 100, 80),
  (1, 7, 100, 60), (1, 8, 100, 100), (2, 1, 150, 150),
  (2, 2, 150, 150), (2, 3, 150, 150), (3, 1, 200, 200),
  (3, 2, 200, 120), (3, 3, 200, 120)
)).toDF("LoanID", "Period", "DueAmt", "ActualPmt")

df.registerTempTable("df")

Next, let's define and register a UDF:

case class Record(period: Int, dueAmt: Int, actualPmt: Int, pastDue: Int)

def runningPastDue(idxs: Seq[Int], dues: Seq[Int], pmts: Seq[Int]) = {
  // Prepend one entry per period, reading the running past-due balance from
  // the head of the accumulator and clipping it at zero (advance payments
  // are not carried forward).
  def f(acc: List[(Int, Int, Int, Int)], x: (Int, (Int, Int))) = 
    (acc.head, x) match {
      case ((_, _, _, pastDue), (idx, (due, pmt))) => 
        (idx, due, pmt, (pmt - due + pastDue).min(0)) :: acc
    }

  idxs.zip(dues.zip(pmts))
    .toList
    .sortBy(_._1)                     // order by period
    .foldLeft(List((0, 0, 0, 0)))(f)  // seed with a zero balance
    .reverse
    .tail                             // drop the seed
    .map{ case (i, due, pmt, past) => Record(i, due, pmt, past) }
}

sqlContext.udf.register("runningPastDue", runningPastDue _)

Aggregate per loan and apply the UDF (note that collect_list on Spark 1.5 requires a HiveContext):

val aggregated = sqlContext.sql("""
  SELECT LoanID, explode(pmts) pmts FROM (
    SELECT LoanId, 
           runningPastDue(
             collect_list(Period), 
             collect_list(DueAmt), 
             collect_list(ActualPmt)
           ) pmts
    FROM df GROUP BY LoanID) tmp""")

// The exploded column is named pmts; its fields come from Record
// (period, dueAmt, actualPmt, pastDue). Alias them back to the original names.
val flattenExprs = List(
  "period" -> "Period", "dueAmt" -> "DueAmt",
  "actualPmt" -> "ActualPmt", "pastDue" -> "PastDue"
).map { case (field, name) => col(s"pmts.$field").alias(name) }

Finally flatten:

val result = aggregated.select($"LoanID" :: flattenExprs: _*)