Caret and dummy variables

Posted 2019-08-02 17:02

When calling the train function of the caret package, the data is automatically transformed so that all factor variables are turned into a set of dummy variables.

How can I prevent this behaviour? Is it possible to tell caret "don't transform factors into dummy variables"?

For example:

If I run the rpart algorithm through caret's train function on the etitanic data:

library(caret)
library(earth)
data(etitanic)

etitanic$survived[etitanic$survived==1] <- 'YES'
etitanic$survived[etitanic$survived!='YES'] <- 'NO'

model<-train(survived~., data=etitanic, method='rpart')

Then the final model produced looks like so:

> model$finalModel
n= 1046 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 1046 427 NO (0.5917782 0.4082218)  
   2) sexmale>=0.5 658 135 NO (0.7948328 0.2051672)  
     4) age>=9.5 615 110 NO (0.8211382 0.1788618) *
     5) age< 9.5 43  18 YES (0.4186047 0.5813953)  
      10) sibsp>=2.5 16   1 NO (0.9375000 0.0625000) *
      11) sibsp< 2.5 27   3 YES (0.1111111 0.8888889) *
   3) sexmale< 0.5 388  96 YES (0.2474227 0.7525773) *

whereas if I run the rpart algorithm directly and build a tree, I get

> rpart(survived~., data=etitanic)
n= 1046 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 1046 427 NO (0.59177820 0.40822180)  
   2) sex=male 658 135 NO (0.79483283 0.20516717)  
     4) age>=9.5 615 110 NO (0.82113821 0.17886179) *
     5) age< 9.5 43  18 YES (0.41860465 0.58139535)  
      10) sibsp>=2.5 16   1 NO (0.93750000 0.06250000) *
      11) sibsp< 2.5 27   3 YES (0.11111111 0.88888889) *
   3) sex=female 388  96 YES (0.24742268 0.75257732)  
     6) pclass=3rd 152  72 NO (0.52631579 0.47368421)  
      12) age>=1.5 145  66 NO (0.54482759 0.45517241)  
        24) sibsp>=1.5 19   4 NO (0.78947368 0.21052632) *
        25) sibsp< 1.5 126  62 NO (0.50793651 0.49206349)  
          50) age>=27.5 44  15 NO (0.65909091 0.34090909) *
          51) age< 27.5 82  35 YES (0.42682927 0.57317073) *
      13) age< 1.5 7   1 YES (0.14285714 0.85714286) *
     7) pclass=1st,2nd 236  16 YES (0.06779661 0.93220339) *

Now, set aside the fact that the trees are different; I understand that they are built with different parameters. However, they are also built on different data sets. For example, the caret tree was built on a data set containing a column "sexmale", which is the dummy column made from the sex column in the original data.

Is there some way to tell caret not to perform this dummy variable creation before feeding the data to rpart?

Tags: r r-caret
1 answer
Root(大扎)
#2 · 2019-08-02 17:45

To make caret behave exactly like rpart, I first set method = "none" in trainControl and use a tuneGrid of one row with cp set to 0.01. The settings are then exactly the same as the defaults of rpart.

ctrl <- trainControl(method = "none")
#caret formula model
model<-train(survived ~ ., 
             data=etitanic, 
             method='rpart', 
             trControl = ctrl, 
             tuneGrid = expand.grid(cp = 0.01))

# rpart model
model_rp <- rpart(survived~., data=etitanic)

print(model$finalModel)

 1) root 1046 427 NO (0.59177820 0.40822180)  
   2) sexmale>=0.5 658 135 NO (0.79483283 0.20516717)  
     4) age>=9.5 615 110 NO (0.82113821 0.17886179) *
     5) age< 9.5 43  18 YES (0.41860465 0.58139535)  
      10) sibsp>=2.5 16   1 NO (0.93750000 0.06250000) *
      11) sibsp< 2.5 27   3 YES (0.11111111 0.88888889) *
   3) sexmale< 0.5 388  96 YES (0.24742268 0.75257732)  
     6) pclass3rd>=0.5 152  72 NO (0.52631579 0.47368421)  
      12) age>=1.5 145  66 NO (0.54482759 0.45517241)  
        24) sibsp>=1.5 19   4 NO (0.78947368 0.21052632) *
        25) sibsp< 1.5 126  62 NO (0.50793651 0.49206349)  
          50) age>=27.5 44  15 NO (0.65909091 0.34090909) *
          51) age< 27.5 82  35 YES (0.42682927 0.57317073) *
      13) age< 1.5 7   1 YES (0.14285714 0.85714286) *
     7) pclass3rd< 0.5 236  16 YES (0.06779661 0.93220339) *

print(model_rp)


 1) root 1046 427 NO (0.59177820 0.40822180)  
   2) sex=male 658 135 NO (0.79483283 0.20516717)  
     4) age>=9.5 615 110 NO (0.82113821 0.17886179) *
     5) age< 9.5 43  18 YES (0.41860465 0.58139535)  
      10) sibsp>=2.5 16   1 NO (0.93750000 0.06250000) *
      11) sibsp< 2.5 27   3 YES (0.11111111 0.88888889) *
   3) sex=female 388  96 YES (0.24742268 0.75257732)  
     6) pclass=3rd 152  72 NO (0.52631579 0.47368421)  
      12) age>=1.5 145  66 NO (0.54482759 0.45517241)  
        24) sibsp>=1.5 19   4 NO (0.78947368 0.21052632) *
        25) sibsp< 1.5 126  62 NO (0.50793651 0.49206349)  
          50) age>=27.5 44  15 NO (0.65909091 0.34090909) *
          51) age< 27.5 82  35 YES (0.42682927 0.57317073) *
      13) age< 1.5 7   1 YES (0.14285714 0.85714286) *
     7) pclass=1st,2nd 236  16 YES (0.06779661 0.93220339) *

Looking at both models, you can see that even though caret's formula interface converted the factor and character columns into dummy variables (with a default level as the reference class), the tree is exactly the same, with the same percentages in the nodes. You could use the partykit package and call as.party() on the models to get a better layout.
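To see where a name like sexmale comes from, here is a minimal sketch in base R. It assumes the formula interface expands factors roughly the way model.matrix does with the default treatment contrasts; the toy data frame is hypothetical and stands in for etitanic.

```r
# Hypothetical toy data (not etitanic): one factor and one numeric column
toy <- data.frame(
  sex = factor(c("female", "male", "male", "female")),
  age = c(29, 2, 30, 25)
)

# Expand the factor into 0/1 dummy columns using the default treatment contrasts
mm <- model.matrix(~ ., data = toy)

# Drop the intercept column, keeping only the predictors
mm <- mm[, colnames(mm) != "(Intercept)", drop = FALSE]

print(colnames(mm))
# "sexmale" "age"
```

The first level ("female") becomes the reference, so the only column left for sex is a 0/1 indicator for "male". That is why the tree reports sexmale>=0.5 instead of sex=male.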

But if you want exactly the same model as rpart, with the factors left as factors instead of dummy variables, you can use train's default (non-formula) interface and pass x and y directly.

#caret default model
model_xy <-train(x = etitanic[, -2], 
                 y = etitanic$survived, 
                 method='rpart', 
                 trControl = ctrl, 
                 tuneGrid = expand.grid(cp = 0.01))

print(model_xy$finalModel)

 1) root 1046 427 NO (0.59177820 0.40822180)  
   2) sex=male 658 135 NO (0.79483283 0.20516717)  
     4) age>=9.5 615 110 NO (0.82113821 0.17886179) *
     5) age< 9.5 43  18 YES (0.41860465 0.58139535)  
      10) sibsp>=2.5 16   1 NO (0.93750000 0.06250000) *
      11) sibsp< 2.5 27   3 YES (0.11111111 0.88888889) *
   3) sex=female 388  96 YES (0.24742268 0.75257732)  
     6) pclass=3rd 152  72 NO (0.52631579 0.47368421)  
      12) age>=1.5 145  66 NO (0.54482759 0.45517241)  
        24) sibsp>=1.5 19   4 NO (0.78947368 0.21052632) *
        25) sibsp< 1.5 126  62 NO (0.50793651 0.49206349)  
          50) age>=27.5 44  15 NO (0.65909091 0.34090909) *
          51) age< 27.5 82  35 YES (0.42682927 0.57317073) *
      13) age< 1.5 7   1 YES (0.14285714 0.85714286) *
     7) pclass=1st,2nd 236  16 YES (0.06779661 0.93220339) *
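If you want to check programmatically which kind of predictors a fitted tree saw, one option is to look at the split variables stored in the rpart object. The sketch below uses only rpart and a hypothetical toy data set; the helper split_vars is made up for illustration and is not part of caret or rpart.

```r
library(rpart)

# Hypothetical toy data: a 3-level factor and a numeric predictor
set.seed(1)
toy <- data.frame(
  pclass = factor(sample(c("1st", "2nd", "3rd"), 200, replace = TRUE)),
  age    = round(runif(200, 1, 70))
)
toy$survived <- factor(ifelse(toy$pclass == "3rd" & toy$age > 10, "NO", "YES"))

# Fit on the raw factor, as rpart() or the x/y interface would
fit_factor <- rpart(survived ~ pclass + age, data = toy)

# Fit on dummy columns, mimicking what a formula interface produces
dummies <- as.data.frame(model.matrix(~ pclass + age, data = toy)[, -1])
dummies$survived <- toy$survived
fit_dummy <- rpart(survived ~ ., data = dummies)

# Hypothetical helper: names of the variables used in splits
# ("<leaf>" marks terminal nodes in the rpart frame)
split_vars <- function(fit) {
  setdiff(unique(as.character(fit$frame$var)), "<leaf>")
}

print(split_vars(fit_factor))
print(split_vars(fit_dummy))
```

The first tree can only name the original columns (pclass, age), while the second can only name the dummy columns (pclass2nd, pclass3rd, age). Applying the same helper to model$finalModel and model_xy$finalModel should show the same distinction.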