Question:
I want to tune my model with grid search and cross-validation in Spark. In Spark, the base model must be put in a pipeline. The official pipeline demo uses LogisticRegression
as the base model, which can be instantiated with new. However, the RandomForest
model cannot be instantiated with new from client code, so it seems it cannot be used
in the pipeline API. I don't want to reinvent the wheel, so can anybody give some advice?
Thanks
Answer 1:
"However, the RandomForest model cannot be instantiated with new from client code, so it seems it cannot be used in the pipeline API."
Well, that is true, but you are simply trying to use the wrong class. Instead of mllib.tree.RandomForest you should use ml.classification.RandomForestClassifier. Here is an example based on the one from the MLlib docs.
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLUtils
import sqlContext.implicits._

// Row type for the DataFrames: the label is kept as a string so StringIndexer can index it
case class Record(category: String, features: Vector)

// Load the sample data and split it 70/30 into training and test sets
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainData, testData) = (splits(0), splits(1))

// Convert the RDDs of LabeledPoint to DataFrames
val trainDF = trainData.map(lp => Record(lp.label.toString, lp.features)).toDF
val testDF = testData.map(lp => Record(lp.label.toString, lp.features)).toDF

// Index the string category into a "label" column that carries nominal metadata
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("label")

val rf = new RandomForestClassifier()
  .setNumTrees(3)
  .setFeatureSubsetStrategy("auto")
  .setImpurity("gini")
  .setMaxDepth(4)
  .setMaxBins(32)

// Chain the indexer and the classifier into a single pipeline
val pipeline = new Pipeline()
  .setStages(Array(indexer, rf))

val model = pipeline.fit(trainDF)
model.transform(testDF)
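Since the question was about grid search with cross-validation, the pipeline above can in principle be handed to CrossValidator like any other estimator. The following is only a sketch: it reuses rf, pipeline, trainDF and testDF from the example, assumes Spark 1.5 or later (where MulticlassClassificationEvaluator is available), and the candidate values in the grid and the number of folds are arbitrary choices for illustration, not recommendations.

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Hypothetical search space: 3 x 3 = 9 parameter combinations
val paramGrid = new ParamGridBuilder()
  .addGrid(rf.numTrees, Array(3, 10, 20))
  .addGrid(rf.maxDepth, Array(2, 4, 6))
  .build()

// Scores each combination by F1 on the held-out fold
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("f1")

// The whole pipeline (indexer + forest) is the estimator being tuned
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEstimatorParamMaps(paramGrid)
  .setEvaluator(evaluator)
  .setNumFolds(3)

// Fits every combination on every fold, then refits the best one on all of trainDF
val cvModel = cv.fit(trainDF)
cvModel.transform(testDF)
```

The values set directly on rf (setNumTrees(3), setMaxDepth(4)) are overridden by the grid for each candidate; the parameters not listed in the grid keep the values set on rf.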
There is one thing I couldn't figure out here. As far as I can tell it should be possible to use labels extracted from LabeledPoints directly, but for some reason it doesn't work and pipeline.fit raises IllegalArgumentException:
RandomForestClassifier was given input with invalid label column label, without the number of classes specified.
Hence the ugly trick with StringIndexer. After applying it we get the required attributes ({"vals":["1.0","0.0"],"type":"nominal","name":"label"}), but some classes in ml seem to work just fine without it.
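One possible way to avoid StringIndexer, given that the error is about missing class metadata, is to attach the nominal attribute to the label column by hand. This is an untested sketch, not a confirmed fix: it assumes the labels are already the doubles 0.0 and 1.0 (as in sample_libsvm_data.txt), and the names labelMeta, rawDF, labeledDF and rfModel are made up for illustration. It reuses trainData and rf from the example above.

```scala
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.functions.col
import sqlContext.implicits._

// Declare two classes; the value at position i names the class with label i
val labelMeta = NominalAttribute.defaultAttr
  .withName("label")
  .withValues("0.0", "1.0")
  .toMetadata()

// Keep the numeric label as-is instead of converting it to a string
val rawDF = trainData
  .map(lp => (lp.label, lp.features))
  .toDF("rawLabel", "features")

// Alias the column with the metadata attached so the classifier can read the class count
val labeledDF = rawDF.select(
  col("rawLabel").as("label", labelMeta),
  col("features"))

val rfModel = rf.fit(labeledDF)
```

With this approach the label values must already be consecutive indices starting at 0.0; StringIndexer remains the safer choice when they are not.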