Extracting the terminal nodes of each tree associated with a new observation

Posted 2019-06-14 09:46

Question:

I would like to extract the terminal nodes from the R implementation of random forest (the randomForest package). As I understand random forest, you have a sequence of orthogonal trees. When you predict a new observation (in regression), it is passed down all of these trees and the predictions of the individual trees are averaged. If, instead of averaging, I wanted to do something like a linear regression on the corresponding observations, I would need a list of the training observations that are "associated" with the new observation. I have gone through the source code but haven't found a way to obtain this. Can anyone help me?
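The averaging step can be checked directly in the randomForest package: predict() with predict.all = TRUE returns each tree's prediction next to the forest prediction. This is a small sketch on simulated data; the new observation x = 0.5 is hypothetical.

```r
library(randomForest)

set.seed(713)
## simulated regression data and a small forest
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
rf <- randomForest(y ~ x, data = my.df, ntree = 10)

## hypothetical new observation
new.obs <- data.frame(x = 0.5)

## predict.all = TRUE keeps each tree's prediction ($individual, 1 x 10)
## next to the forest prediction ($aggregate)
p <- predict(rf, newdata = new.obs, predict.all = TRUE)
p$individual

## in regression the forest prediction is the mean of the tree predictions
all.equal(unname(p$aggregate), mean(p$individual))
```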

Answer 1:

There must be a better way to do this, but here's a workaround:

library(randomForest)
set.seed(713)
## data
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
## forest
rf <- randomForest(y ~ x, data = my.df, ntree = 10, keep.inbag = TRUE)

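keep.inbag = TRUE stores rf$inbag, an n-by-ntree matrix recording which observations were in-bag, i.e. used to grow each of the 10 trees in this example. With the default sampling with replacement, a case can be drawn more than once for the same tree, and its entry is then larger than 1 (check with table(rf$inbag)). The == 1 test below keeps only cases drawn exactly once, so rf$inbag[, z] > 0 is the safer way to select every in-bag case.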
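To see what keep.inbag stores, here is a sketch that refits the same simulated forest and inspects the matrix for the first tree. Tabulating a column shows how many cases were out-of-bag (0) and how often the others were drawn.

```r
library(randomForest)

set.seed(713)
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
rf <- randomForest(y ~ x, data = my.df, ntree = 10, keep.inbag = TRUE)

## one row per training observation, one column per tree
dim(rf$inbag)

## how many times each case was drawn for tree 1 (0 = out-of-bag)
table(rf$inbag[, 1])

## "> 0" selects every in-bag case, however many times it was drawn
inbag1 <- rf$inbag[, 1] > 0
sum(inbag1)   # number of distinct observations used by tree 1
```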

predList <- lapply(seq_len(rf$ntree), function(z) 
            predict(rf, newdata = my.df[rf$inbag[, z] == 1, ], nodes = TRUE))

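nodes = TRUE attaches a "nodes" attribute to each prediction: a matrix with one row per observation in newdata and one column per tree, giving the terminal node each observation ends up in.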
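Since the "nodes" matrix already covers every tree, one predict() call on the full training data is enough to get the terminal node of every observation in every tree. A sketch with the same simulated data:

```r
library(randomForest)

set.seed(713)
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
rf <- randomForest(y ~ x, data = my.df, ntree = 10, keep.inbag = TRUE)

## one call: terminal node of each of the 100 rows in each of the 10 trees
pred <- predict(rf, newdata = my.df, nodes = TRUE)
nodes <- attr(pred, "nodes")
dim(nodes)

## terminal nodes that tree 1 assigns to its in-bag observations
sort(unique(nodes[rf$inbag[, 1] > 0, 1]))
```

Column z of nodes gives the same node ids as attr(predList[[z]], "nodes")[, z] for the rows that were passed to that call, so the per-tree loop can be dropped.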

node.list <- lapply(seq_len(rf$ntree), function(z) 
            split(x = my.df[rf$inbag[, z] == 1, "x"], 
                    f = attr(predList[[z]], "nodes")[, z]))
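The same grouping can be written with a single n-by-ntree nodes matrix and with > 0 instead of == 1, so that cases drawn more than once are not lost. This sketch stores row indices rather than x values (an assumption about what you need), so y or any other column of my.df can be looked up afterwards.

```r
library(randomForest)

set.seed(713)
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
rf <- randomForest(y ~ x, data = my.df, ntree = 10, keep.inbag = TRUE)

## terminal node of every training observation in every tree (n x ntree)
nodes <- attr(predict(rf, newdata = my.df, nodes = TRUE), "nodes")

## for each tree, group the row indices of its in-bag observations by
## terminal node; "> 0" keeps cases that were drawn more than once
node.rows <- lapply(seq_len(rf$ntree), function(z) {
  inbag <- which(rf$inbag[, z] > 0)
  split(inbag, nodes[inbag, z])
})

## x values in the first three terminal nodes of tree 1
lapply(node.rows[[1]][1:3], function(i) my.df$x[i])
```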

The first three terminal nodes of the first tree (the list names are node ids), with the x values of the observations that fall into each:

node.list[[1]][1:3]

$`3`
[1] 2.028358 2.071939

$`7`
[1] 0.8306559

$`9`
[1] 1.660134 1.621299
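For the original goal (collecting, for a new observation, the training cases that share its terminal node in each tree and fitting a linear regression to them instead of averaging), one possible sketch follows. The new observation x = 0.5 is hypothetical, and pooling the neighbours of all trees into a single lm() fit is an illustrative design choice, not something the randomForest package provides.

```r
library(randomForest)

set.seed(713)
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
rf <- randomForest(y ~ x, data = my.df, ntree = 10, keep.inbag = TRUE)

## terminal nodes of the training data and of a hypothetical new observation
train.nodes <- attr(predict(rf, newdata = my.df, nodes = TRUE), "nodes")
new.obs <- data.frame(x = 0.5)
new.nodes <- attr(predict(rf, newdata = new.obs, nodes = TRUE), "nodes")

## for each tree, the in-bag training rows that land in the same terminal
## node as the new observation
neighbours <- lapply(seq_len(rf$ntree), function(z)
  which(rf$inbag[, z] > 0 & train.nodes[, z] == new.nodes[1, z]))

## pool the rows over all trees and fit a local linear model; a row that
## shares a node with new.obs in k trees appears k times in the fit
rows <- unlist(neighbours)
local.fit <- lm(y ~ x, data = my.df[rows, ])
predict(local.fit, newdata = new.obs)
```

Because rows that share a node with the new observation in several trees appear several times, they are weighted more heavily; using unique(rows) instead would give every neighbour equal weight.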