I would like to extract the terminal nodes from the R implementation of random forest. As I understand random forests, you have a sequence of orthogonal trees. When you predict a new observation (in regression), it is passed down all of these trees and the predictions of the individual trees are then averaged. If, instead of averaging, I wanted to fit, say, a linear regression on the corresponding observations, I would need a list of the observations that are "associated" with this new observation. I have gone through the source code but haven't come up with a way to obtain this. Can anyone help me?
Answer 1:
There must be a better way to do this, but here's a workaround:
library(randomForest)
set.seed(713)
## data
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
## forest
rf <- randomForest(y ~ x, data = my.df, ntree = 10, keep.inbag = TRUE)
keep.inbag = TRUE saves the in-bag observations that are used to fit each of the 10 trees in this example.
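To get a feel for what is stored, the in-bag matrix can be inspected directly. The following is only a small sketch that refits the same forest as above; whether the entries are 0/1 flags or counts may depend on the installed package version, so the check only tests for values greater than zero.

```r
library(randomForest)
set.seed(713)
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
rf <- randomForest(y ~ x, data = my.df, ntree = 10, keep.inbag = TRUE)

## one row per training observation, one column per tree
dim(rf$inbag)

## number of distinct observations that are in-bag for each tree
colSums(rf$inbag > 0)
```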
## > 0 rather than == 1: depending on the package version, inbag may hold
## counts instead of 0/1 flags
predList <- lapply(seq_len(rf$ntree), function(z)
  predict(rf, newdata = my.df[rf$inbag[, z] > 0, ], nodes = TRUE))
nodes = TRUE tracks the terminal node that each observation ends up in, for every tree.
node.list <- lapply(seq_len(rf$ntree), function(z)
  split(x = my.df[rf$inbag[, z] > 0, "x"],
        f = attr(predList[[z]], "nodes")[, z]))
First three terminal nodes of the first tree:
node.list[[1]][1:3]
$`3`
[1] 2.028358 2.071939
$`7`
[1] 0.8306559
$`9`
[1] 1.660134 1.621299
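To get from here to the observations "associated" with a new observation, one possible approach is to look up the new observation's terminal node in each tree and collect the in-bag training rows that land in the same node. The sketch below assumes a hypothetical new observation new.obs, and it pools the matching rows across all trees before fitting lm(); both the pooling and the choice of lm() are illustrative choices, not something the package does for you.

```r
library(randomForest)
set.seed(713)
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
rf <- randomForest(y ~ x, data = my.df, ntree = 10, keep.inbag = TRUE)

## hypothetical new observation (not part of the original example)
new.obs <- data.frame(x = 0.5)

## terminal node of every training row, and of the new row, in every tree
train.nodes <- attr(predict(rf, newdata = my.df, nodes = TRUE), "nodes")
new.nodes <- as.vector(attr(predict(rf, newdata = new.obs, nodes = TRUE), "nodes"))

## per tree: in-bag training rows that share the new row's terminal node
neighbors <- lapply(seq_len(rf$ntree), function(z)
  which(train.nodes[, z] == new.nodes[z] & rf$inbag[, z] > 0))

## pool the row indices across trees and fit a linear model on them
idx <- unique(unlist(neighbors))
local.fit <- lm(y ~ x, data = my.df[idx, ])
predict(local.fit, newdata = new.obs)
```

neighbors[[z]] holds the row indices of my.df associated with new.obs in tree z, so a separate model per tree could be fitted instead of a pooled one.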