-->

Find the data elements in a data frame that pass t

Posted 2019-02-10 20:04

Question:

So I have used the rpart package to create a tree model and I found an interesting rule and wondered if there was an easy way to see which observations in that data frame pass that rule.

It seems very tedious to use path.rpart to find the path it took down the tree, and manually enter those filters into the data frame to look for them. Is there a method where I can pass a tree and/or a node, and a data frame and return all the elements in that frame that ended at that node?

Answer 1:

I modified the code in path.rpart to return the subset of the data that falls within a particular node, rather than information about that node. It works either by clicking on the plot or by passing node numbers, just as path.rpart does. Here is the code:

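subset.rpart <- function (tree, df, nodes) {
    if (!inherits(tree, "rpart")) 
        stop("Not a legitimate \"rpart\" object")
    # df must be the data the tree was fit on: one row per entry of tree$where
    stopifnot(nrow(df) == length(tree$where))
    frame <- tree$frame
    n <- row.names(frame)
    node <- as.numeric(n)   # node numbers, in frame row order

    if (missing(nodes)) {
        # Interactive mode: click a node on the current plot of the tree
        xy <- rpart:::rpartco(tree)
        i <- identify(xy, n = 1L, plot = FALSE)
        # identify() returns integer(0) when no point was selected
        if (length(i) > 0L) {
            return(df[tree$where == i, ])
        } else {
            return(df[0, ])
        }
    }
    else {
        # Convert node numbers to row positions in tree$frame
        if (length(nodes <- rpart:::node.match(nodes, node)) == 0L) 
            return(df[0, ])
        # tree$where holds the frame row of the leaf each observation ended in,
        # so only terminal nodes return any rows
        return(df[tree$where %in% as.numeric(nodes), ])
    }
}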

I will use it on some sample data from the package

library(rpart)
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
plot(fit)
text(fit)

And then to find the observations at a particular node, run

subset.rpart(fit, kyphosis)

and click on a node in the plot. All the observations at that node are then returned. For this to work properly, you must pass the same data.frame that was used to fit the model. Rather than clicking on a point, you can also pass in a node number that you discover with path.rpart:

# path.rpart(fit)  
#  node number: 10  ---> looks interesting
#    root
#    Start>=8.5
#    Start< 14.5
#    Age< 55

subset.rpart(fit, kyphosis, 10)
#    Kyphosis Age Number Start
# 14   absent   1      4    12
# 20   absent  27      4     9
# 26   absent   9      5    13
# 37   absent   1      3     9
# 39   absent  20      6     9
# 42   absent  35      3    13
# 57   absent   2      3    13
# 59   absent  51      7     9
# 66   absent  17      4    10
# 69   absent  18      4    11
# 78   absent  26      7    13
# 81   absent  36      4    13
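
Both the click and the node-number routes rely on the `where` component of the fitted tree. As a minimal sketch of that mechanism (the variable name `leaf_node` is only illustrative), the leaf node number of every training row can be looked up directly:

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# fit$where gives, for each training row, the row of fit$frame for its leaf;
# the row names of fit$frame are the node numbers that path.rpart reports.
leaf_node <- as.numeric(rownames(fit$frame))[fit$where]

# Number of observations that ended in each leaf
table(leaf_node)

# Rows that ended in node 10, selected without calling subset.rpart
kyphosis[leaf_node == 10, ]
```

This only identifies terminal nodes, since every observation is recorded against the leaf it finally lands in.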


Answer 2:

#' subset of rpart node: return logical index
#' @param tree rpart model
#' @param node which node/leaf?
#' @export
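subset_rpart <- function (tree, node) {
  # rpart numbers the children of node n as 2n and 2n + 1, so the descendants
  # of `node` at depth d below it fill the range [node * 2^d, (node + 1) * 2^d).
  # On a log2 scale, their fractional parts fall in the same interval at every depth.
  nodes = as.numeric(rownames(tree$frame))
  nodes = log(nodes, 2)
  lower = log(node, 2)
  upper = log(node + 1, 2)
  a = floor(lower)
  lower_ = lower - a
  upper_  = upper - a
  nodes_ = nodes %% 1
  # frame rows for `node` itself and for every node beneath it
  w = which(((nodes_ >= lower_ & nodes_ < upper_) | (nodes_ + 1 < upper_)) & nodes >= lower)
  # tree$where holds the frame row of each observation's leaf
  tree$where %in% w
}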



#' subset df by subset_rpart
#' @param tree rpart model
#' @param node node number
#' @param df df
#' @export
subset.rpart = function(tree, node, df){
  df[subset_rpart(tree, node), ]
}
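
The log-based test in subset_rpart finds every node at or below a given node, which is what makes it work for internal nodes as well as leaves. As a hedged alternative sketch, the same descendant check can be done with integer arithmetic by repeatedly stepping up to the parent (node number `%/% 2`). The helper `is_descendant` below is hypothetical and not part of rpart:

```r
library(rpart)

# Hypothetical helper (not part of rpart): TRUE where an element of
# `candidates` equals `node` or lies beneath it, using rpart's numbering
# scheme in which the children of node n are 2n and 2n + 1.
is_descendant <- function(candidates, node) {
  vapply(candidates, function(m) {
    while (m > node) m <- m %/% 2   # step up to the parent
    m == node
  }, logical(1))
}

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
node_ids <- as.numeric(rownames(fit$frame))

# Frame rows for node 5 and everything beneath it
in_subtree <- which(is_descendant(node_ids, 5))

# Training rows whose leaf lies in that subtree
rows_under_5 <- kyphosis[fit$where %in% in_subtree, ]
nrow(rows_under_5)
```

The row count should agree with the `n` column of `fit$frame` for node 5, which is a quick way to check the selection.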


Tags: r rpart