I have a Spark SQL DataFrame:
user1 item1 rating1
user1 item2 rating2
user1 item3 rating3
user2 item1 rating4
...
How can I group by user and then return the top-N
items from each group using Scala?
Similar code in Python:
df.groupby("user").apply(the_func_get_TopN)
You can use the rank
window function as follows:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{rank, desc}
import spark.implicits._  // for the $"..." column syntax (assumes a SparkSession named spark)

val n: Int = ???  // number of top items to keep per user

// Window: one partition per user, items ordered by rating descending
val w = Window.partitionBy($"user").orderBy(desc("rating"))

// Rank rows within each window and keep the top n
df.withColumn("rank", rank().over(w)).where($"rank" <= n)
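For example, here is a minimal self-contained sketch (the sample rows and the app name are made up for illustration):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{rank, desc}

val spark = SparkSession.builder.appName("topN").getOrCreate()
import spark.implicits._

// Hypothetical sample data
val df = Seq(
  ("user1", "item1", 5.0),
  ("user1", "item2", 3.0),
  ("user1", "item3", 4.0),
  ("user2", "item1", 2.0)
).toDF("user", "item", "rating")

val n = 2
val w = Window.partitionBy($"user").orderBy(desc("rating"))
df.withColumn("rank", rank().over(w)).where($"rank" <= n).show()
// user1 keeps item1 (rank 1) and item3 (rank 2); user2 keeps item1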
Note that rank assigns the same rank to ties, so a user can end up
with more than n rows. If you don't care about ties, replace rank
with row_number, which guarantees exactly n rows per group.
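A minimal sketch of the row_number variant, reusing df, w, and n from above:

import org.apache.spark.sql.functions.row_number

// row_number breaks ties arbitrarily, so exactly n rows survive per user
df.withColumn("rn", row_number().over(w)).where($"rn" <= n)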