Edit: The answer helps, but I described my solution in: memoryOverhead issue in Spark.
I have an RDD with 202092 partitions, which reads a dataset created by others. I can manually see that the data is not balanced across the partitions: for example, some of them have 0 images and others have 4k, while the mean lies at 432. When processing the data, I got this error:
Container killed by YARN for exceeding memory limits. 16.9 GB of 16 GB physical memory used. Consider boosting spark.yarn.executor.memoryOverhead.
even though memoryOverhead is already boosted. I suspect that memory usage spikes during processing, and that a spike exceeding the configured limit makes YARN kill my container.
So what should I do to make sure that my data is (roughly) balanced across partitions?
My idea was that repartition() would work, since it invokes a shuffle:
dataset = dataset.repartition(202092)
but I got the very same error, despite the programming guide's description:
repartition(numPartitions)
Reshuffle the data in the RDD randomly to create either more or fewer partitions and balance it across them. This always shuffles all data over the network.
Check my toy example though:
# Partition i holds (i + 1) * 1000 elements; next(x) works on both Python 2 and 3, unlike x.next()
data = sc.parallelize([0,1,2], 3).mapPartitions(lambda x: range((next(x) + 1) * 1000))
d = data.glom().collect()
len(d[0]) # 1000
len(d[1]) # 2000
len(d[2]) # 3000
repartitioned_data = data.repartition(3)
re_d = repartitioned_data.glom().collect()
len(re_d[0]) # 1854
len(re_d[1]) # 1754
len(re_d[2]) # 2392
repartitioned_data = data.repartition(6)
re_d = repartitioned_data.glom().collect()
len(re_d[0]) # 422
len(re_d[1]) # 845
len(re_d[2]) # 1643
len(re_d[3]) # 1332
len(re_d[4]) # 1547
len(re_d[5]) # 211
repartitioned_data = data.repartition(12)
re_d = repartitioned_data.glom().collect()
len(re_d[0]) # 132
len(re_d[1]) # 265
len(re_d[2]) # 530
len(re_d[3]) # 1060
len(re_d[4]) # 1025
len(re_d[5]) # 145
len(re_d[6]) # 290
len(re_d[7]) # 580
len(re_d[8]) # 1113
len(re_d[9]) # 272
len(re_d[10]) # 522
len(re_d[11]) # 66
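To put a number on the imbalance instead of eyeballing the lengths, here is a minimal sketch. The helper names count_records and skew_summary are my own, not part of Spark, and the sizes are copied from the repartition(12) run above. On a real RDD, counting inside each partition with mapPartitions should avoid pulling every element to the driver the way glom().collect() does, but I have left that call as a comment so the snippet runs without a cluster.

```python
def count_records(iterator):
    """Yield a single number: how many records this partition holds."""
    yield sum(1 for _ in iterator)


def skew_summary(sizes):
    """Return min, max, mean and max/mean ratio for a list of partition sizes."""
    if not sizes:
        raise ValueError("sizes must not be empty")
    mean = sum(sizes) / float(len(sizes))
    return {
        "min": min(sizes),
        "max": max(sizes),
        "mean": mean,
        # A ratio near 1.0 means balanced; larger means more skew.
        "max_over_mean": max(sizes) / mean if mean else float("inf"),
    }


# With a live RDD the sizes could be gathered like this (not run here):
#   sizes = data.mapPartitions(count_records).collect()
# Here the sizes are hard-coded from the repartition(12) output above.
sizes_12 = [132, 265, 530, 1060, 1025, 145, 290, 580, 1113, 272, 522, 66]
summary = skew_summary(sizes_12)
print(summary)
```

For the 12-partition run this gives a mean of 500 and a largest partition of 1113, so the biggest partition is still more than twice the mean after repartitioning.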