I'm a beginner with Spark and I'm trying to create an RDD that contains the top 3 values for every key (not just the top 3 values overall). My current RDD contains thousands of entries in the following format:
(key, String, value)
So imagine I had an RDD with content like this:
[("K1", "aaa", 6), ("K1", "bbb", 3), ("K1", "ccc", 2), ("K1", "ddd", 9),
("B1", "qwe", 4), ("B1", "rty", 7), ("B1", "iop", 8), ("B1", "zxc", 1)]
I can currently display the top 3 values in the RDD like so:
("K1", "ddd", 9)
("B1", "iop", 8)
("B1", "rty", 7)
Using:
top3RDD = rdd.takeOrdered(3, key=lambda x: -x[2])
What I want instead is to gather the top 3 values for every key in the RDD, so I would like to return this:
("K1", "ddd", 9)
("K1", "aaa", 6)
("K1", "bbb", 3)
("B1", "iop", 8)
("B1", "rty", 7)
("B1", "qwe", 4)
You need to groupBy the key and then you can use heapq.nlargest to take the top 3 values from each group:
from heapq import nlargest

rdd.groupBy(
    lambda x: x[0]                                    # group records by key
).flatMap(
    lambda g: nlargest(3, g[1], key=lambda x: x[2])   # keep the 3 largest records per group, by value
).collect()
[('B1', 'iop', 8),
('B1', 'rty', 7),
('B1', 'qwe', 4),
('K1', 'ddd', 9),
('K1', 'aaa', 6),
('K1', 'bbb', 3)]
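If the per-key groups can get large, a variation on the same idea is to keep only three records per key while aggregating, instead of materializing whole groups. Here is a minimal sketch of that, assuming the same rdd as above:

from heapq import nlargest

top3_per_key = rdd.map(
    lambda x: (x[0], x)                               # key each record by its first field
).aggregateByKey(
    [],                                               # start with an empty list per key
    lambda acc, v: nlargest(3, acc + [v], key=lambda x: x[2]),
    lambda a, b: nlargest(3, a + b, key=lambda x: x[2])
).flatMap(
    lambda kv: kv[1]                                  # drop the grouping key, keep the records
)

top3_per_key.collect()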
If you're open to converting your rdd to a DataFrame, you can define a Window to partition by the key and sort by the value descending. Use this Window to compute the row number, and pick the rows where the row number is less than or equal to 3.
import pyspark.sql.functions as f
from pyspark.sql import Window

# rank rows within each key by value, descending
w = Window.partitionBy("key").orderBy(f.col("value").desc())

rdd.toDF(["key", "String", "value"])\
    .select("*", f.row_number().over(w).alias("rowNum"))\
    .where(f.col("rowNum") <= 3)\
    .drop("rowNum")\
    .show()
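For completeness, a sketch of the setup this assumes (an active SparkSession named spark and the sample data from the question), so the snippet above runs as-is:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [("K1", "aaa", 6), ("K1", "bbb", 3), ("K1", "ccc", 2), ("K1", "ddd", 9),
        ("B1", "qwe", 4), ("B1", "rty", 7), ("B1", "iop", 8), ("B1", "zxc", 1)]
rdd = spark.sparkContext.parallelize(data)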