Spark MultilayerPerceptronClassifier Class Probabi

Posted 2019-08-28 19:08

Question:

I am an experienced Python programmer trying to transition some Python code to Spark for a classification task. This is my first time working in Spark/Scala.

In Python, both Keras/TensorFlow and scikit-learn neural networks handle multi-class classification well, and I can easily return the top 3 most probable classes along with their probabilities, which are key to this project.

I have been generally successful in moving the code to Spark (Scala) and can generate the correct predictions, but I have not found a way to return probabilities for the top predicted classes from the MultilayerPerceptronClassifier in MLlib.

The closest solution I found was in this post: How to get classification probabilities from MultilayerPerceptronClassifier? However, I can't get the solution in that post to work, either because it's missing a key piece of code or because I'm too new to Scala (probably the latter) to make the needed adjustments.

Has anyone solved this problem?

These are the current versions in my environment: Spark 2.1.1, Scala 2.11.8.

Thanks for your help,

RKB

Answer 1:

If you look carefully at the output of MultilayerPerceptronClassificationModel.transform (with model and test defined as in the example pipeline in the official documentation):

val result = model.transform(test)

result.printSchema
root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)

you'll see that it contains a probability column.

It is stored as an org.apache.spark.ml.linalg.Vector column:

result.select($"probability").show(3, false)
+---------------------------------------------------+
|probability                                        |
+---------------------------------------------------+
|[2.630203838780848E-29,1.7323171642231641E-19,1.0] |
|[1.0,1.448487547623119E-121,4.530084532282489E-44] |
|[1.0,5.157808976162274E-122,2.5702890543589884E-44]|
+---------------------------------------------------+
only showing top 3 rows

and can be accessed using the standard Vector methods (for example toArray or apply).
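Since the question asks for the top 3 classes with their probabilities, here is a minimal sketch of one way to derive them from the probability column. The names topK, top3 and withTop3 are hypothetical, result is the DataFrame from model.transform(test) above, and the sketch assumes that the position in the vector corresponds to the class label (0.0, 1.0, 2.0, ...).

```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions.{col, udf}

// Hypothetical helper: (class index, probability) pairs for the k largest
// entries, highest probability first.
def topK(probs: Array[Double], k: Int): Seq[(Int, Double)] =
  probs.zipWithIndex
    .sortBy { case (p, _) => -p }   // descending by probability
    .take(k)
    .map { case (p, i) => (i, p) }
    .toSeq

// Wrap the helper in a UDF so it can be applied to the Vector column.
// Assumption: the vector index equals the class label index.
val top3 = udf((v: Vector) => topK(v.toArray, 3))

// `result` is assumed to be the output of model.transform(test) shown above.
val withTop3 = result.withColumn("top3", top3(col("probability")))
withTop3.select("prediction", "top3").show(3, false)
```

The new top3 column is an array of structs with fields _1 (class index) and _2 (probability). Keeping the ranking logic in a plain function makes it easy to check outside Spark before wrapping it in a UDF.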

This feature has been available since Spark 2.3 (SPARK-12664 Expose probability, rawPrediction in MultilayerPerceptronClassificationModel), so it is not present in the Spark 2.1.1 environment described in the question.