Spark ML - Save OneVsRestModel

Posted 2019-01-19 04:28

I am in the middle of refactoring my code to take advantage of DataFrames, Estimators, and Pipelines. I was originally using MLlib's multiclass LogisticRegressionWithLBFGS on RDD[LabeledPoint]. I am enjoying learning and using the new API, but I am not sure how to save my new model and apply it to new data.

Currently, the ML implementation of LogisticRegression only supports binary classification, so I am using OneVsRest instead, like so:

import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

val lr = new LogisticRegression().setFitIntercept(true)
val ovr = new OneVsRest()
ovr.setClassifier(lr)
val ovrModel = ovr.fit(training)

I would now like to save my OneVsRestModel, but this does not seem to be supported by the API. I have tried:

ovrModel.save("my-ovr") // Cannot resolve symbol save
ovrModel.models.foreach(_.save("model-" + _.uid)) // Cannot resolve symbol save

Is there a way to save this, so I can load it in a new application for making new predictions?

1 answer
Lonely孤独者°
#2 · 2019-01-19 05:08

Spark >= 2.0.0

OneVsRestModel implements MLWritable, so it can be saved directly with save and loaded back with OneVsRestModel.load. The method shown below can still be useful for saving the individual models separately.
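A minimal sketch of the Spark 2.x round trip, written as a spark-shell / Scala script. It assumes the multiclass LIBSVM sample file that ships with the Spark distribution (data/mllib/sample_multiclass_classification_data.txt, path relative to the Spark home) as stand-in training data, and "my-ovr" as an arbitrary output directory; in practice the save and the load would live in two different applications.

```scala
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest, OneVsRestModel}
import org.apache.spark.sql.SparkSession

// Reuses the existing session in spark-shell; creates a local one in a script.
val spark = SparkSession.builder().appName("ovr-save-load").master("local[*]").getOrCreate()

// Assumed input: the multiclass LIBSVM sample bundled with Spark.
val training = spark.read.format("libsvm").load(
  "data/mllib/sample_multiclass_classification_data.txt")

val ovrModel = new OneVsRest().setClassifier(
  new LogisticRegression().setFitIntercept(true)).fit(training)

// Training application: persist the whole one-vs-rest model in one call.
// overwrite() lets the script be re-run without deleting "my-ovr" first.
ovrModel.write.overwrite().save("my-ovr")

// Scoring application: reload the model and apply it to (new) data.
val restored = OneVsRestModel.load("my-ovr")
restored.transform(training).select("label", "prediction").show(5)
```

The reloaded model carries all of its binary sub-models, so no per-model bookkeeping is needed in this version.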

Spark < 2.0.0

The problem here is that models returns an array of ClassificationModel[_, _], not an array of LogisticRegressionModel (or MLWritable), and ClassificationModel itself does not provide save. To make it work you'll have to be specific about the types:

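import org.apache.spark.ml.classification.LogisticRegressionModel

// Narrow each sub-model to LogisticRegressionModel (which is MLWritable)
// before calling save; the index keeps the output paths distinct.
ovrModel.models.zipWithIndex.foreach {
  case (model: LogisticRegressionModel, i: Int) =>
    model.save(s"model-${model.uid}-$i")
}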

or to be more generic:

import org.apache.spark.ml.util.MLWritable

// Matches any sub-model type that supports persistence, not only logistic regression.
ovrModel.models.zipWithIndex.foreach {
  case (model: MLWritable, i: Int) =>
    model.save(s"model-${model.uid}-$i")
}

Unfortunately, as of Spark 1.6, OneVsRestModel doesn't implement MLWritable, so it cannot be saved as a whole.

Note:

All models in the OneVsRestModel seem to share the same uid, hence the explicit index. The index is also useful for identifying each model later.
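A rough sketch of the Spark 1.6 save/reload cycle for the individual models, as a spark-shell / Scala script. It relies on the observation in the note that all sub-models share one uid, and assumes the multiclass LIBSVM sample bundled with Spark as stand-in training data. In a genuinely separate application the shared uid and the number of classes would have to be recorded somewhere at save time (for example in a small side file, a hypothetical choice); here they are simply read off ovrModel.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest}

// Reuses the existing context in spark-shell; otherwise starts a local one (Spark 1.6 APIs).
val sc = SparkContext.getOrCreate(new SparkConf().setAppName("ovr-submodels").setMaster("local[*]"))
val sqlContext = SQLContext.getOrCreate(sc)

// Assumed input: the multiclass LIBSVM sample bundled with Spark.
val training = sqlContext.read.format("libsvm").load(
  "data/mllib/sample_multiclass_classification_data.txt")
val ovrModel = new OneVsRest().setClassifier(new LogisticRegression()).fit(training)

// Save: the sub-models share one uid, so the index is what keeps the paths distinct.
ovrModel.models.zipWithIndex.foreach {
  case (model: LogisticRegressionModel, i: Int) =>
    model.write.overwrite().save(s"model-${model.uid}-$i")
}

// Load: rebuild the same paths from the shared uid and the class count
// (taken from ovrModel here; a separate job would need them stored explicitly).
val sharedUid = ovrModel.models.head.uid
val numClasses = ovrModel.models.length
val restored = (0 until numClasses).map(i => LogisticRegressionModel.load(s"model-$sharedUid-$i"))
```

Position i in restored corresponds to the binary "class i vs. rest" model, which is why keeping the index in the path matters when the uid alone cannot tell the models apart.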
