PySpark: Calculate grouped-by AUC

Posted 2019-07-03 22:25

Question:

  • Spark version: 1.6.0

I tried computing AUC (area under ROC) grouped by the field id. Given the following data:

# Within each key-value pair
# key is "id"
# value is a list of (score, label)
data = sc.parallelize(
         [('id1', [(0.5, 1.0), (0.6, 0.0), (0.7, 1.0), (0.8, 0.0)]),
          ('id2', [(0.5, 1.0), (0.6, 0.0), (0.7, 1.0), (0.8, 0.0)])
         ])
The BinaryClassificationMetrics class can calculate the AUC given an RDD of (score, label) pairs.

I want to compute the AUC per key (i.e. id1, id2), but how do I "map" a class like this onto an RDD by key?

Update

I tried to wrap the BinaryClassificationMetrics in a function:

from pyspark.mllib.evaluation import BinaryClassificationMetrics

def auc(scoreAndLabels):
    return BinaryClassificationMetrics(scoreAndLabels).areaUnderROC

And then mapped the wrapper function over the values:

data.groupByKey()\
    .mapValues(auc)

But inside mapValues() the list of (score, label) pairs is actually of type ResultIterable, while BinaryClassificationMetrics expects an RDD.

Is there any way to convert the ResultIterable to an RDD so that the auc function can be applied? Or is there any other workaround for computing the group-by AUC (without importing third-party modules like scikit-learn)?

Answer 1:

Instead of using BinaryClassificationMetrics you can use sklearn.metrics.auc, mapping it over each RDD element's value to get an AUC value per key:

from sklearn.metrics import auc

data = sc.parallelize([
         ('id1', [(0.5, 1.0), (0.6, 0.0), (0.7, 1.0), (0.8, 0.0)]),
         ('id2', [(0.5, 1.0), (0.6, 0.0), (0.7, 1.0), (0.8, 0.0)])])

result_aucs = data.map(lambda x: (x[0] + '_auc', auc(*zip(*x[1]))))
result_aucs.collect()


Out [1]: [('id1_auc', 0.15000000000000002), ('id2_auc', 0.15000000000000002)]
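A caution about this result: sklearn.metrics.auc(x, y) computes the trapezoid-rule area under an arbitrary (x, y) curve, so passing scores as x and labels as y does not give the ROC AUC that BinaryClassificationMetrics.areaUnderROC returns. If sklearn is acceptable at all, roc_auc_score is the matching call; a minimal sketch comparing the two on the sample data:

```python
from sklearn.metrics import auc, roc_auc_score

scores = [0.5, 0.6, 0.7, 0.8]
labels = [1.0, 0.0, 1.0, 0.0]

# Trapezoid-rule area under the polyline through (score, label) points --
# this is what the answer above computes, and it is not the ROC AUC.
print(auc(scores, labels))            # 0.15000000000000002

# ROC AUC of the scores against the true labels, matching
# BinaryClassificationMetrics.areaUnderROC:
print(roc_auc_score(labels, scores))  # 0.25

# Per key, on the RDD from the answer above (not run here):
# per_key = data.mapValues(
#     lambda v: roc_auc_score([l for _, l in v], [s for s, _ in v]))
```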


Answer 2:

Here's a way to get the AUC without using sklearn:

from pyspark.mllib.evaluation import BinaryClassificationMetrics

keys = data.map(lambda x: x[0]).distinct().collect()
rslt = {}
for k in keys:
    # Build an RDD of (score, label) pairs for this key only
    scoreAndLabels = data.filter(lambda x: x[0] == k).flatMap(lambda x: x[1])
    rslt[k] = BinaryClassificationMetrics(scoreAndLabels).areaUnderROC

print(rslt)

Note: this solution requires that the number of keys is small enough to fit in memory. If you have so many keys that they can't be collect()ed into memory, don't use this approach.
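If collecting the keys (or making one filter() pass per key) is too expensive, the original groupByKey().mapValues() idea can still work with a plain-Python AUC function instead of BinaryClassificationMetrics, since a ResultIterable is just an iterable of pairs. A sketch, where auc_from_score_labels is a hypothetical helper computing ROC AUC via the rank-based (Mann-Whitney U) formulation with tied scores given average ranks:

```python
def auc_from_score_labels(pairs):
    """ROC AUC of a list of (score, label) pairs, labels in {0.0, 1.0},
    via the rank-sum (Mann-Whitney U) formulation."""
    pairs = sorted(pairs, key=lambda p: p[0])
    n = len(pairs)
    ranks = [0.0] * n
    i = 0
    while i < n:
        # Extend j over a group of tied scores
        j = i
        while j + 1 < n and pairs[j + 1][0] == pairs[i][0]:
            j += 1
        avg_rank = (i + j) / 2.0 + 1.0  # average 1-based rank of the group
        for k in range(i, j + 1):
            ranks[k] = avg_rank
        i = j + 1
    n_pos = sum(1 for _, label in pairs if label == 1.0)
    n_neg = n - n_pos
    pos_rank_sum = sum(r for r, (_, label) in zip(ranks, pairs) if label == 1.0)
    return (pos_rank_sum - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg)

# With Spark, this runs per key in a single shuffle (not run here):
# per_key_auc = data.groupByKey() \
#                   .mapValues(lambda it: auc_from_score_labels(list(it))) \
#                   .collect()

print(auc_from_score_labels([(0.5, 1.0), (0.6, 0.0), (0.7, 1.0), (0.8, 0.0)]))
# 0.25
```

Note the question's data stores each key's pairs in a single list, so mapValues(auc_from_score_labels) on data directly (without groupByKey) would also work.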