I have tried to run the experimental RandomForestClassifier from the spark.ml package (version 1.5.2). The dataset I used is the one from the LogisticRegression example in the Spark ML guide.
Here is the code:
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.mllib.linalg.Vectors

// Prepare training data from a list of (label, features) tuples.
val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

val rf = new RandomForestClassifier()
val model = rf.fit(training)
And here is the error I obtain:
java.lang.IllegalArgumentException: RandomForestClassifier was given input with invalid label column label, without the number of classes specified. See StringIndexer.
at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:87)
at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:42)
at org.apache.spark.ml.Predictor.fit(Predictor.scala:90)
... (the remaining frames are spark-shell REPL wrappers and spark-submit internals, elided)
The problem appears when the function tries to compute the number of classes in the "label" column. As you can see at line 84 in the source code of RandomForestClassifier, the function calls DataFrame.schema with the parameter "label". This call is fine and returns an org.apache.spark.sql.types.StructField object.
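For what it's worth, inspecting that field's metadata in the spark-shell suggests it is empty; a quick check (assuming I am reading the API correctly) would be:

// Print the ML metadata attached to the "label" column.
// For a DataFrame built from plain tuples this prints "{}" (empty).
println(training.schema("label").metadata)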
Then, the function org.apache.spark.ml.util.MetadataUtils.getNumClasses is called. As it does not return the expected output, an exception is raised at line 87.
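From my reading of the Spark 1.5 source, getNumClasses does roughly the following (paraphrased from memory, so treat it as an approximation rather than the exact code):

import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
import org.apache.spark.sql.types.StructField

// Approximate paraphrase of MetadataUtils.getNumClasses: a class count
// is only available when the label column's metadata marks it as a
// binary or nominal attribute; otherwise None is returned.
def getNumClasses(labelSchema: StructField): Option[Int] =
  Attribute.fromStructField(labelSchema) match {
    case binAttr: BinaryAttribute  => Some(2)
    case nomAttr: NominalAttribute => nomAttr.getNumValues
    case _                         => None // numeric or unresolved attribute
  }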
After a quick glance at the getNumClasses source code, I suppose that the error is due to the fact that the "label" column is neither a BinaryAttribute nor a NominalAttribute. However, I do not know how to fix this problem.
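The error message points to StringIndexer, so I guess the label column is supposed to get this metadata from an indexing stage first. Something like the following is what I would try (a rough sketch based on the hint, not a verified fix):

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.StringIndexer

// StringIndexer re-encodes the raw labels and attaches nominal metadata
// (including the number of classes) to its output column.
val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")

// Point the classifier at the indexed column instead of the raw one.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")

// Chain both stages so fit() sees the metadata produced by the indexer.
val pipeline = new Pipeline().setStages(Array(indexer, rf))
val model = pipeline.fit(training)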
My question:
How can I fix this problem?
Thanks a lot for reading my question and for your help!