Background
My original question here was "Why using DecisionTreeModel.predict inside map function raises an exception?" and it is related to "How to generate tuples of (original lable, predicted label) on Spark with MLlib?"
When we use the Scala API, the recommended way of getting predictions for an RDD[LabeledPoint] using DecisionTreeModel is to simply map over the RDD:
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
Unfortunately, a similar approach in PySpark doesn't work so well:
labelsAndPredictions = testData.map(
    lambda lp: (lp.label, model.predict(lp.features)))
labelsAndPredictions.first()
Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transforamtion. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.
Instead, the official documentation recommends something like this:
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
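The same pattern can be wrapped in a small helper. This is only a sketch: the name `labels_and_predictions` is hypothetical, and it assumes that `data` behaves like an RDD of LabeledPoint (it has `map` and `zip`) and that `model.predict` accepts an RDD of feature vectors, as in the snippet above.

```python
def labels_and_predictions(model, data):
    """Pair each label with its prediction without calling the model on workers.

    Assumes `data` is RDD-like (supports map and zip) and that
    `model.predict` accepts an RDD of feature vectors.
    """
    # predict() is called on the driver with a whole RDD, so the model
    # object is never captured by a closure that is shipped to workers.
    predictions = model.predict(data.map(lambda lp: lp.features))
    # Both sides are derived from `data` by map, so their partitions and
    # per-partition element counts line up, which is what zip requires.
    return data.map(lambda lp: lp.label).zip(predictions)
```

With the objects from the question this would be called as `labels_and_predictions(model, testData)`.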
So what is going on here? There is no broadcast variable here, and the Scala API defines predict as follows:
/**
 * Predict values for a single data point using the model trained.
 *
 * @param features array representing a single data point
 * @return Double prediction from the trained model
 */
def predict(features: Vector): Double = {
  topNode.predict(features)
}

/**
 * Predict values for the given data set using the model trained.
 *
 * @param features RDD representing data points to be predicted
 * @return RDD of predictions for each of the given data points
 */
def predict(features: RDD[Vector]): RDD[Double] = {
  features.map(x => predict(x))
}
So, at least at first glance, calling it from an action or transformation should not be a problem, since prediction seems to be a local operation.
Explanation
After some digging I figured out that the source of the problem is the JavaModelWrapper.call method invoked from DecisionTreeModel.predict. It accesses SparkContext, which is required to call the Java function:
callJavaFunc(self._sc, getattr(self._java_model, name), *a)
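The mechanism can be reproduced without Spark. The sketch below is an illustration, not PySpark's actual code: `FakeContext`, `FakeModelWrapper` and `can_ship_to_workers` are hypothetical names, and it assumes (as far as I can tell from the PySpark sources) that SparkContext guards itself by raising from `__getnewargs__` when something tries to pickle it. Any object that holds a reference to such a context then becomes unpicklable too, which is what happens to the model captured by the lambda.

```python
import pickle


class FakeContext:
    """Hypothetical stand-in for SparkContext: refuses to be pickled."""

    def __getnewargs__(self):
        # Assumption: mirrors the guard PySpark puts on SparkContext.
        raise RuntimeError("context can only be used on the driver")


class FakeModelWrapper:
    """Hypothetical stand-in for a Python wrapper around a JVM model."""

    def __init__(self, sc):
        self._sc = sc  # driver-only handle, like self._sc above

    def predict(self, features):
        # A real wrapper would need self._sc here to reach the JVM.
        return 0.0


def can_ship_to_workers(obj):
    """Return True if obj survives pickling, as captured closure values must."""
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        # Any pickling failure means the object cannot leave the driver.
        return False


model = FakeModelWrapper(FakeContext())
print(can_ship_to_workers(model))       # False: the wrapper drags the context along
print(can_ship_to_workers([1.0, 2.0]))  # True: plain data is fine
```

PySpark serializes closures with its own pickler rather than the standard `pickle` module, so this only approximates the real code path, but the captured objects go through the same reduce protocol.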
Question
In the case of DecisionTreeModel.predict there is a recommended workaround, and all the required code is already part of the Scala API, but is there any elegant way to handle problems like this in general?
The only solutions I can think of right now are rather heavyweight:
- pushing everything down to the JVM, either by extending Spark classes through implicit conversions or by adding some kind of wrappers
- using the Py4J gateway directly