Using R 3.2.0 with caret 6.0-41 and randomForest 4.6-10 on a 64-bit Linux machine.
When trying to use the predict() method on a randomForest object trained with the train() function from the caret package using a formula, the function returns an error. When training via randomForest() and/or using x= and y= rather than a formula, it all runs smoothly.
Here is a working example:
library(randomForest)
library(caret)
data(imports85)
imp85 <- imports85[, c("stroke", "price", "fuelType", "numOfDoors")]
imp85 <- imp85[complete.cases(imp85), ]
imp85 <- droplevels(imp85) ## Drop empty levels for factors.
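## Formula interface: fit directly with randomForest() and via caret::train().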
modRf1 <- randomForest(numOfDoors ~ ., data = imp85)
caretRf <- train(numOfDoors ~ ., data = imp85, method = "rf")
modRf2 <- caretRf$finalModel
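## x/y interface: fit directly with randomForest() and via caret::train().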
modRf3 <- randomForest(x=imp85[,c("stroke", "price", "fuelType")], y=imp85[, "numOfDoors"])
caretRf <- train(x=imp85[,c("stroke", "price", "fuelType")], y=imp85[, "numOfDoors"], method = "rf")
modRf4 <- caretRf$finalModel
p1 <- predict(modRf1, newdata=imp85)
p2 <- predict(modRf2, newdata=imp85)
p3 <- predict(modRf3, newdata=imp85)
p4 <- predict(modRf4, newdata=imp85)
Among these last four lines, only the second one, p2 <- predict(modRf2, newdata=imp85), returns the following error:
Error in predict.randomForest(modRf2, newdata = imp85) :
variables in the training data missing in newdata
It seems that the reason for this error is that the predict.randomForest method uses rownames(object$importance) to determine the names of the variables used to train the random forest object. Looking at those row names for the four models, we see:
rownames(modRf1$importance)
[1] "stroke" "price" "fuelType"
rownames(modRf2$importance)
[1] "stroke" "price" "fuelTypegas"
rownames(modRf3$importance)
[1] "stroke" "price" "fuelType"
rownames(modRf4$importance)
[1] "stroke" "price" "fuelType"
So somehow, using the caret train() function with a formula changes the names of the (factor) variables in the importance field of the randomForest object.
Is it really an inconsistency between the formula and non-formula versions of the caret train() function, or am I missing something?
This isn't an answer to your question, but I believe it will help others since it helped me: if any columns of your test data that were used during training contain NAs, predict() will not work. You need to impute these values first.

Another way around the original error is to explicitly encode the testing data using model.matrix, e.g. as in the sketch below.
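A minimal sketch of that idea, reusing modRf2 and imp85 from the question (the exact call is my reconstruction, assuming caret used the default dummy coding):
## model.matrix() expands factors into the same dummy columns
## (e.g. "fuelTypegas") that caret's formula interface created.
testX <- model.matrix(numOfDoors ~ ., data = imp85)[, -1]  ## drop the intercept column
p2 <- predict(modRf2, newdata = testX)  ## no missing-variable error now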
First, almost never use the $finalModel object for prediction. Use predict.train; this is one good example of why (see the sketch after this answer).

There is some inconsistency between how some functions (including randomForest and train) handle dummy variables. Most functions in R that use the formula method will convert factor predictors to dummy variables because their models require numerical representations of the data. The exceptions to this are tree- and rule-based models (which can split on categorical predictors), naive Bayes, and a few others. So randomForest will not create dummy variables when you use randomForest(y ~ ., data = dat), but train (and most others) will, using a call like train(y ~ ., data = dat).

The error occurs because fuelType is a factor. The dummy variables created by train don't have the same names, so predict.randomForest can't find them. Using the non-formula method with train will pass the factor predictors to randomForest and everything will work.

TL;DR: Use the non-formula method with train if you want the same levels, or use predict.train.

Max
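A minimal sketch of the predict.train route, reusing the formula-trained caretRf from the question:
caretRf <- train(numOfDoors ~ ., data = imp85, method = "rf")
## Calling predict() on the train object dispatches to predict.train,
## which re-applies the formula/dummy-variable handling before predicting.
p2 <- predict(caretRf, newdata = imp85)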
There can be two reasons why you get this error.
1. The categories of the categorical variables in the train and test sets don't match. To check that, you can run something like the following. Well, first of all, it is good practice to keep the independent variables/features in a list; say that list is vars. And say you have split Data into Train and Test. Let's go:
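(A minimal sketch of such a check, assuming vars, Train, and Test as just described:)
## Flag every factor in vars whose Test categories are missing from Train.
for (v in vars) {
  if (is.factor(Train[[v]])) {
    extra <- setdiff(levels(Test[[v]]), levels(Train[[v]]))
    if (length(extra) > 0) {
      cat("Non-matching categories in", v, ":", extra, "\n")
    }
  }
}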
Once you find the non-matching categorical variables, you can go back, impose the categories of the Test data onto the Train data, and then rebuild your model. In a loop similar to the one above, for each nonMatchingVar you can do:
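(A sketch of one safe variant of that loop body; it uses the union of levels rather than only the Test levels, to avoid introducing NAs on the Train side:)
## Give both sets the full set of categories so their levels always match.
allLevels <- union(levels(Train[[nonMatchingVar]]), levels(Test[[nonMatchingVar]]))
Train[[nonMatchingVar]] <- factor(Train[[nonMatchingVar]], levels = allLevels)
Test[[nonMatchingVar]]  <- factor(Test[[nonMatchingVar]],  levels = allLevels)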
2. A silly one: if you accidentally leave the dependent variable in the set of independent variables, you may run into this error message. I have made that mistake myself. Solution: just be more careful.