Cross-validation predictions from caret in assigne

Posted 2019-08-22 11:43

I am wondering why the predictions labelled 'Fold1' actually come from the second of my predefined folds. Here is an example of what I mean:

# load the library
library(caret)
# load the cars dataset (caret's version, which has a Price column)
data(cars, package = "caret")
# fix the random seed so the folds are reproducible
set.seed(1)
# define folds
cv_folds <- createFolds(cars$Price, k = 5, list = TRUE, returnTrain = TRUE)
# define training control
train_control <- trainControl(method = "cv", index = cv_folds, savePredictions = "final")
# train the model
model <- caret::train(Price ~ ., data = cars, trControl = train_control, method = "gbm", verbose = FALSE)

# are the rows predicted in 'Fold1' part of the second element of cv_folds?
model$pred$rowIndex[model$pred$Resample == 'Fold1'] %in% cv_folds[[2]]

1 answer
冷血范
Post #2 · 2019-08-22 11:58

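Because you used returnTrain = TRUE, each element of cv_folds holds the training rows of a resample, not its held-out rows. The predictions with Resample 'Fold1' are therefore for the records that are not in cv_folds[[1]]. Those held-out records are part of the training sets cv_folds[[2]] to cv_folds[[5]], which is why your check against cv_folds[[2]] returns TRUE (it would for cv_folds[[3]], cv_folds[[4]] and cv_folds[[5]] as well). This is expected in a 5-fold cross-validation: resample Fold1 predicts the holdout rows of fold 1 with a model trained on folds 2-5, resample Fold2 predicts the holdout rows of fold 2 with a model trained on folds 1 and 3-5, and so on.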

In summary: the predictions labelled Fold1 are the holdout predictions of a model trained on the rows in cv_folds[[1]], i.e. on folds 2-5.
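To see why the check in the question returns TRUE, here is a base-R sketch that mimics the shape of what createFolds(..., returnTrain = TRUE) gives you: a list of training row indices, one per resample. It assigns folds by hand with sample() instead of calling caret (the names n, fold_id, train_idx and holdout_idx are only illustrative), and takes each resample's held-out rows as the complement of its training rows, which is what trainControl uses when index is given and indexOut is left at its default.

```r
set.seed(42)
n <- 20   # number of rows in a toy data set
k <- 5    # number of folds

# Assign each row to one of k holdout groups (hand-made stand-in for createFolds)
fold_id <- sample(rep(seq_len(k), length.out = n))

# Training indices per resample, analogous to returnTrain = TRUE:
# resample j trains on every row that is NOT in holdout group j
train_idx <- lapply(seq_len(k), function(j) which(fold_id != j))
names(train_idx) <- paste0("Fold", seq_len(k))

# Held-out rows of a resample are the complement of its training rows
holdout_idx <- lapply(train_idx, function(tr) setdiff(seq_len(n), tr))

# Rows held out in Fold1 never appear in the Fold1 training set ...
print(any(holdout_idx$Fold1 %in% train_idx$Fold1))   # FALSE
# ... but they appear in every other resample's training set
print(sapply(train_idx, function(tr) all(holdout_idx$Fold1 %in% tr)))
# Fold1 FALSE; Fold2 to Fold5 TRUE
```

So a TRUE result against cv_folds[[2]] only says that the Fold1 holdout rows were used for training in the other resamples, not that they belong to fold 2.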

Edit (based on a comment):

All the information you need is in the model$pred table. I added a bit of code to clarify:

# dplyr provides %>%, select(), rename(), mutate() and case_when()
library(dplyr)

model$pred %>% 
  select(rowIndex, pred, Resample) %>%
  rename(prediction = pred, holdout = Resample) %>% 
  mutate(trained_on = case_when(holdout == "Fold1" ~ "Folds 2, 3, 4, 5",
                                holdout == "Fold2" ~ "Folds 1, 3, 4, 5", 
                                holdout == "Fold3" ~ "Folds 1, 2, 4, 5", 
                                holdout == "Fold4" ~ "Folds 1, 2, 3, 5", 
                                holdout == "Fold5" ~ "Folds 1, 2, 3, 4"))

  rowIndex prediction holdout       trained_on
1      610   13922.60   Fold2 Folds 1, 3, 4, 5
2      623   38418.83   Fold2 Folds 1, 3, 4, 5
3      604   12383.55   Fold2 Folds 1, 3, 4, 5
4      607   15040.07   Fold2 Folds 1, 3, 4, 5
5       95   33549.40   Fold2 Folds 1, 3, 4, 5
6      624   40357.35   Fold2 Folds 1, 3, 4, 5
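The case_when() above hard-codes one label per fold. A base-R sketch that builds the same labels for any number of folds, assuming the Resample names follow the FoldN pattern seen in the output (the helper trained_on_label and the toy holdout vector are hypothetical):

```r
# Hypothetical stand-in for the holdout (Resample) column of model$pred
holdout <- c("Fold2", "Fold2", "Fold1", "Fold5")
k <- 5

# For resample "Fold<j>", the model was trained on all the other holdout groups
trained_on_label <- function(h, k) {
  j <- as.integer(sub("^Fold", "", h))   # "Fold2" -> 2
  vapply(j,
         function(i) paste("Folds", paste(setdiff(seq_len(k), i), collapse = ", ")),
         character(1))
}

trained_on_label(holdout, k)
# "Folds 1, 3, 4, 5" "Folds 1, 3, 4, 5" "Folds 2, 3, 4, 5" "Folds 1, 2, 3, 4"
```

In the dplyr pipeline this could replace the case_when() with something like mutate(trained_on = trained_on_label(holdout, 5)).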

Basically, for stacking the predictions later, all you need are the pred and rowIndex columns of the model$pred table.
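For stacking, the out-of-fold predictions usually have to line up with the original rows, and rowIndex gives each prediction's row position. A minimal base-R sketch, using a small hypothetical data frame in place of model$pred (with plain k-fold CV and savePredictions = 'final', each row should appear exactly once):

```r
# Toy stand-in for model$pred: one held-out prediction per original row
pred_df <- data.frame(
  rowIndex = c(4L, 1L, 3L, 2L, 5L),
  pred     = c(40, 10, 30, 20, 50),
  Resample = c("Fold2", "Fold1", "Fold2", "Fold1", "Fold3")
)
n <- 5  # number of rows in the original data

# Put each out-of-fold prediction at its original row position
oof_pred <- rep(NA_real_, n)
oof_pred[pred_df$rowIndex] <- pred_df$pred
oof_pred   # 10 20 30 40 50
```

With repeated CV, a row would appear once per repeat, so the predictions would need to be aggregated (for example averaged per rowIndex) before this assignment.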

rowIndex refers to the row in the original data, so rowIndex 610 is record 610 of the cars dataset. You can verify this with the obs column, which holds the value of the Price column for that record.
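That check can be written as a one-line comparison. Here is a base-R sketch on toy data (orig and pred_df are hypothetical stand-ins for the cars data and model$pred):

```r
# Toy original data with a Price column (stand-in for the cars data)
orig <- data.frame(Price = c(15000, 22000, 18000, 31000))

# Toy stand-in for model$pred: obs should equal the original Price at rowIndex
pred_df <- data.frame(
  rowIndex = c(3L, 1L, 4L, 2L),
  obs      = c(18000, 15000, 31000, 22000),
  pred     = c(17500, 15800, 29900, 23100)
)

# TRUE if every observed value matches the original row it points back to
all(pred_df$obs == orig$Price[pred_df$rowIndex])   # TRUE
```

With the objects from the question, the same check would presumably be all(model$pred$obs == cars$Price[model$pred$rowIndex]), assuming cars is the exact data frame passed to train().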
