get TopN of all groups after group by using Spark

Posted 2020-01-26 08:25

Question:

I have a Spark SQL DataFrame:

user1 item1 rating1
user1 item2 rating2
user1 item3 rating3
user2 item1 rating4
...

How can I group by user and then return the top-N items from each group, using Scala?

Equivalent code in Python would be:

df.groupby("user").apply(the_func_get_TopN)

Answer 1:

You can use the rank window function as follows:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{rank, desc}
import spark.implicits._  // enables the $"..." column syntax

val n: Int = ???

// Window definition: one partition per user, highest rating first
val w = Window.partitionBy($"user").orderBy(desc("rating"))

// Rank rows within each partition and keep the top n
df.withColumn("rank", rank().over(w)).where($"rank" <= n)

If you don't care about ties, you can replace rank with row_number. Note that rank leaves gaps after ties (two items tied at rank 1 are followed by rank 3) and can return more than n rows per group, while row_number always returns exactly n.
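For completeness, here is a minimal self-contained sketch of the row_number variant. The column names follow the question; the sample ratings and the app name are assumptions for illustration only:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{row_number, desc}

val spark = SparkSession.builder.appName("topn-example").getOrCreate()
import spark.implicits._

// Hypothetical sample data matching the question's schema
val df = Seq(
  ("user1", "item1", 5.0),
  ("user1", "item2", 3.0),
  ("user1", "item3", 4.0),
  ("user2", "item1", 2.0)
).toDF("user", "item", "rating")

val n = 2

val w = Window.partitionBy($"user").orderBy(desc("rating"))

// row_number assigns unique, consecutive numbers within each partition,
// so every user contributes at most n rows even when ratings tie
df.withColumn("rn", row_number().over(w))
  .where($"rn" <= n)
  .drop("rn")
  .show()
```

With this data, user1 keeps item1 and item3 (the two highest ratings) and user2 keeps its single row.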