How to obtain the best trained model from a CrossValidator

Asked 2019-04-08 04:55

I built a pipeline including a DecisionTreeClassifier (dt) like this:

val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))
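For context, the stages referenced above are not shown in the question. A typical setup, assuming a DataFrame called data with "label" and "features" columns (as in the standard Spark ML decision tree example), might look roughly like this:

import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Index string labels into a numeric "indexedLabel" column
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)

// Identify categorical features and index them
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// The decision tree trained on the indexed columns
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")

// Map indexed predictions back to the original label strings
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)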

Then I used this pipeline as the estimator in a CrossValidator, in order to get a model with the best set of hyperparameters, like this:

val c_v = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new MulticlassClassificationEvaluator()
    .setLabelCol("indexedLabel")
    .setPredictionCol("prediction"))
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)
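The paramGrid is not shown either; a hypothetical grid over the tree's hyperparameters could be built with ParamGridBuilder, for example:

import org.apache.spark.ml.tuning.ParamGridBuilder

// Hypothetical grid; the actual grid used in the question is not shown
val paramGrid = new ParamGridBuilder()
  .addGrid(dt.maxDepth, Array(3, 5, 7))
  .addGrid(dt.impurity, Array("gini", "entropy"))
  .build()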

Finally, I could train a model on the training set with this CrossValidator:

val model = c_v.fit(train)

But here is the question: I want to view the best trained decision tree model via the .toDebugString method of DecisionTreeClassificationModel, but model is a CrossValidatorModel. Yes, you can use model.bestModel, but it is still typed as Model, so you cannot call .toDebugString on it. I also assume the bestModel is still a pipeline including labelIndexer, featureIndexer, dt, and labelConverter.

So does anyone know how I can obtain the decision tree model from the model fitted by the CrossValidator, so that I can view the actual tree with toDebugString? Or is there any workaround for viewing the decision tree model?

2 Answers
别忘想泡老子
Answered 2019-04-08 05:20

Well, in cases like this one the answer is always the same - be specific about the types.

First extract the pipeline model, since what you are trying to train is a Pipeline:

import org.apache.spark.ml.PipelineModel

val bestModel: Option[PipelineModel] = model.bestModel match {
  case p: PipelineModel => Some(p)
  case _ => None
}

Then you'll need to extract the model from the underlying stage. In your case it's a decision tree classification model:

import org.apache.spark.ml.classification.DecisionTreeClassificationModel

val treeModel: Option[DecisionTreeClassificationModel] = bestModel.flatMap {
  _.stages.collect {
    case t: DecisionTreeClassificationModel => t
  }.headOption
}

To print the tree, for example:

treeModel.foreach(tree => println(tree.toDebugString))
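If you already know where the classifier sits in the pipeline, a more direct (but less type-safe) alternative is to cast the stage explicitly. In the question's pipeline dt is the third stage, so a sketch under that assumption would be:

import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.DecisionTreeClassificationModel

// Assumes bestModel is a PipelineModel and the tree is its third stage
val tree = model.bestModel
  .asInstanceOf[PipelineModel]
  .stages(2)
  .asInstanceOf[DecisionTreeClassificationModel]

println(tree.toDebugString)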
兄弟一词,经得起流年.
Answered 2019-04-08 05:24

(DISCLAIMER: There is another aspect which, IMHO, deserves its own answer. I know it is a little off-topic given the question; however, it questions the question itself. If somebody downvotes because they disagree with the content, please also leave a comment.)

Should you extract the "best" tree at all? The answer is typically no.

Why are we doing CV? We are trying to evaluate our choices: the classifier used, the hyperparameters used, and preprocessing steps such as feature selection. For the last one it is important that it happens on the training data only; e.g., do not normalise the features on all of the data. So the output of CV is an evaluation of the pipeline you generated, not a particular fitted model. On a side note: feature selection should itself be evaluated in an "internal" CV.

What we are not doing is generating a "pool of classifiers" from which we choose the best one. However, I've seen this surprisingly often. The problem is that you have an extremely high chance of a twinning effect: even in a perfectly i.i.d. dataset there are likely (near-)duplicated training examples, so there is a pretty good chance that the "best" CV classifier is just an indication of which fold happened to contain the most twinning.

Hence, what should you do? Once you have fixed your parameters, you should use the entire training data to build the final model. Hopefully (though hardly anybody does this) you have also set aside an additional evaluation set, never touched during the process, on which to evaluate that final model.
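A minimal sketch of that workflow with the question's c_v, assuming data is the full labelled DataFrame (which is not shown in the question), might look like this:

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// Hold out a final evaluation set before any model selection
val Array(trainData, evalData) = data.randomSplit(Array(0.8, 0.2), seed = 42)

// Cross-validate on the training data only; Spark's CrossValidator refits
// the best parameter combination on the whole training set internally
val cvModel = c_v.fit(trainData)

// One-time evaluation of the final model on the untouched hold-out set
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val holdOutAccuracy = evaluator.evaluate(cvModel.transform(evalData))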
