Get Top 3 values for every key in an RDD in Spark

Posted 2020-03-26 17:40

I'm a beginner with Spark, and I am trying to create an RDD that contains the top 3 values for every key (not just the top 3 values overall). My current RDD contains thousands of entries in the following format:

(key, String, value)

So imagine I had an RDD with content like this:

[("K1", "aaa", 6), ("K1", "bbb", 3), ("K1", "ccc", 2), ("K1", "ddd", 9),
("B1", "qwe", 4), ("B1", "rty", 7), ("B1", "iop", 8), ("B1", "zxc", 1)]

I can currently display the top 3 values in the RDD like so:

("K1", "ddd", 9)
("B1", "iop", 8)
("B1", "rty", 7)

Using:

top3 = rdd.takeOrdered(3, key=lambda x: -x[2])  # takeOrdered sorts ascending, so negate the value; returns a local list, not an RDD
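PySpark documents `takeOrdered` as returning the first N elements in ascending order (optionally by a key function), so `key=lambda x: x[2]` alone returns the 3 smallest values. A minimal local sketch, assuming `heapq.nsmallest` as a stand-in for `takeOrdered` and using a plain list `data` (the sample rows above) instead of an RDD, shows why the value is negated:

```python
from heapq import nsmallest

# Sample rows from the question: (key, String, value)
data = [("K1", "aaa", 6), ("K1", "bbb", 3), ("K1", "ccc", 2), ("K1", "ddd", 9),
        ("B1", "qwe", 4), ("B1", "rty", 7), ("B1", "iop", 8), ("B1", "zxc", 1)]

# Ascending by value: what key=lambda x: x[2] gives (the 3 smallest)
smallest3 = nsmallest(3, data, key=lambda x: x[2])

# Negating the value turns "3 smallest" into "3 largest"
largest3 = nsmallest(3, data, key=lambda x: -x[2])

print(smallest3)  # [('B1', 'zxc', 1), ('K1', 'ccc', 2), ('K1', 'bbb', 3)]
print(largest3)   # [('K1', 'ddd', 9), ('B1', 'iop', 8), ('B1', 'rty', 7)]
```

On a real RDD, `rdd.top(3, key=lambda x: x[2])` is an alternative that returns the largest elements directly, in descending order.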

Instead, I want to gather the top 3 values for every key in the RDD, so I would like to return this:

("K1", "ddd", 9)
("K1", "aaa", 6)
("K1", "bbb", 3)
("B1", "iop", 8)
("B1", "rty", 7)
("B1", "qwe", 4)

2 Answers
疯言疯语
Reply #2 · 2020-03-26 18:11

You need to groupBy the key and then you can use heapq.nlargest to take the top 3 values from each group:

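from heapq import nlargest

# groupBy yields (key, iterable_of_rows) pairs
rdd.groupBy(
    lambda x: x[0]
).flatMap(
    # keep the 3 rows with the largest value (x[2]) in each group
    lambda g: nlargest(3, g[1], key=lambda x: x[2])
).collect()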

[('B1', 'iop', 8), 
 ('B1', 'rty', 7), 
 ('B1', 'qwe', 4), 
 ('K1', 'ddd', 9), 
 ('K1', 'aaa', 6), 
 ('K1', 'bbb', 3)]
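To check the per-group logic without a Spark cluster, here is a minimal local sketch. The helper `top_n_per_key` is hypothetical: a `defaultdict` stands in for `groupBy`, and the loop over groups stands in for `flatMap`.

```python
from collections import defaultdict
from heapq import nlargest

# Sample rows from the question: (key, String, value)
data = [("K1", "aaa", 6), ("K1", "bbb", 3), ("K1", "ccc", 2), ("K1", "ddd", 9),
        ("B1", "qwe", 4), ("B1", "rty", 7), ("B1", "iop", 8), ("B1", "zxc", 1)]

def top_n_per_key(rows, n=3):
    """Hypothetical local stand-in for
    rdd.groupBy(lambda x: x[0]).flatMap(lambda g: nlargest(n, g[1], key=lambda x: x[2]))."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[0]].append(row)  # like groupBy(lambda x: x[0])
    result = []
    for _, group in groups.items():
        # like the flatMap: emit the n rows with the largest value in this group
        result.extend(nlargest(n, group, key=lambda x: x[2]))
    return result

print(top_n_per_key(data))
# [('K1', 'ddd', 9), ('K1', 'aaa', 6), ('K1', 'bbb', 3),
#  ('B1', 'iop', 8), ('B1', 'rty', 7), ('B1', 'qwe', 4)]
```

Unlike this local version, which keeps keys in insertion order, Spark's `groupBy` does not guarantee the order of groups in `collect()`, which is why the output above lists `B1` first.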
够拽才男人
Reply #3 · 2020-03-26 18:28

If you're open to converting your RDD to a DataFrame, you can define a Window that partitions by the key and sorts by the value in descending order. Use this Window to compute the row number, and keep the rows where the row number is less than or equal to 3.

import pyspark.sql.functions as f
from pyspark.sql import Window

# number rows within each key, highest value first
w = Window.partitionBy("key").orderBy(f.col("value").desc())

# rdd.toDF() is only available once a SparkSession has been created
rdd.toDF(["key", "String", "value"])\
    .select("*", f.row_number().over(w).alias("rowNum"))\
    .where(f.col("rowNum") <= 3)\
    .drop("rowNum")\
    .show()
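The Window approach can be imitated locally to see what `row_number()` does. In this minimal sketch, the helper `row_number_top_n` is hypothetical: sorting by `(key, -value)` plays the role of `partitionBy("key").orderBy(f.col("value").desc())`, and `enumerate(..., start=1)` plays the role of `row_number()`.

```python
from itertools import groupby

# Sample rows from the question: (key, String, value)
data = [("K1", "aaa", 6), ("K1", "bbb", 3), ("K1", "ccc", 2), ("K1", "ddd", 9),
        ("B1", "qwe", 4), ("B1", "rty", 7), ("B1", "iop", 8), ("B1", "zxc", 1)]

def row_number_top_n(rows, n=3):
    """Hypothetical local imitation of row_number() over a window
    partitioned by key and ordered by value descending, then rowNum <= n."""
    # sort by key (the partition), then by value descending (the window order)
    ordered = sorted(rows, key=lambda r: (r[0], -r[2]))
    kept = []
    for _, partition in groupby(ordered, key=lambda r: r[0]):
        # number rows from 1 within each partition, like row_number()
        for row_num, row in enumerate(partition, start=1):
            if row_num <= n:
                kept.append(row)
    return kept

print(row_number_top_n(data))
# [('B1', 'iop', 8), ('B1', 'rty', 7), ('B1', 'qwe', 4),
#  ('K1', 'ddd', 9), ('K1', 'aaa', 6), ('K1', 'bbb', 3)]
```

`row_number()` assigns distinct numbers even to tied values, so exactly 3 rows per key survive; using `f.rank()` or `f.dense_rank()` instead would keep every row tied at the cutoff.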