How to use whole training example to estimate clas

Posted 2019-07-21 20:39

I want to use scikit-learn RandomForestClassifier to estimate the probabilities of a given example to belong to a set of classes, after prior training of course.

I know I can get the class probabilities using the predict_proba method, that calculates them as

[...] the mean predicted class probabilities of the trees in the forest.
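To make the quoted sentence concrete, here is a minimal sketch that averages the per-tree estimates by hand and compares the result with predict_proba. The synthetic dataset and the parameter values are assumptions chosen only for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data, purely illustrative.
X, y = make_classification(n_samples=300, n_features=8, n_informative=4,
                           n_classes=3, random_state=0)

forest = RandomForestClassifier(n_estimators=25, random_state=0).fit(X, y)

# Average the probability estimates of the individual trees by hand.
manual_proba = np.mean(
    [tree.predict_proba(X) for tree in forest.estimators_], axis=0
)

# Compare with what the forest itself reports.
forest_proba = forest.predict_proba(X)
matches = np.allclose(manual_proba, forest_proba)
```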

In this question it is mentioned that:

The probabilities returned by a single tree are the normalized class histograms of the leaf a sample lands in.
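The quoted statement can be checked on a single tree. The sketch below looks up the leaf each sample lands in, normalizes the class values stored at that leaf, and compares the result with the tree's own predict_proba. The data and max_depth are assumptions for illustration; depending on the scikit-learn version, tree_.value holds counts or fractions, and the normalization covers both cases.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data, purely illustrative.
X, y = make_classification(n_samples=300, n_features=8, n_informative=4,
                           n_classes=3, random_state=0)

tree = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X, y)

# Index of the leaf node that each sample ends up in.
leaf_ids = tree.apply(X)

# tree_.value has shape (node_count, n_outputs, n_classes); take output 0.
leaf_values = tree.tree_.value[leaf_ids, 0, :]

# Normalize each row so it sums to 1 (the class histogram of the leaf).
leaf_hist = leaf_values / leaf_values.sum(axis=1, keepdims=True)

matches = np.allclose(leaf_hist, tree.predict_proba(X))
```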

Now, I've been reading some papers on probability estimation and realized there isn't a trivial solution. According to "Estimating Class Probabilities in Random Forests" (Boström):

using the same examples to both grow the trees and estimate the probabilities, [...] by necessity will lead to pure (and therefore small) estimation sets

And this is bad. The solution appears to be to use all the examples in the training set, instead of only the ones in the bootstrap sample used to grow the tree.

Scikit-learn does use only the bootstrap sample for each tree to calculate the probability estimate of each class, right? Does somebody have any pointers about how to proceed to make the class probabilities come from the whole training set of the RandomForest instead?

I assume this would need some special Tree subclassing that doesn't assign class probabilities to the leaves of the trees and then some procedure to assign them from the RandomForest classifier using the whole training set.
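One way this might be done without subclassing is sketched below: keep the fitted trees as they are, drop every training example down each tree with apply, rebuild the leaf class histograms from all of them, and average the per-tree lookups at prediction time. The helper names full_train_leaf_tables and predict_proba_full are hypothetical, not part of scikit-learn, and this is only one reading of the idea in the paper, not a faithful implementation of it.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def full_train_leaf_tables(forest, X_train, y_train):
    """Hypothetical helper: per-tree leaf class frequencies over ALL training rows."""
    # forest.classes_ is sorted, so searchsorted maps labels to column indices.
    y_idx = np.searchsorted(forest.classes_, y_train)
    leaves = forest.apply(X_train)  # shape (n_samples, n_trees)
    tables = []
    for t, tree in enumerate(forest.estimators_):
        counts = np.zeros((tree.tree_.node_count, len(forest.classes_)))
        # Count how many training rows of each class land in each node.
        np.add.at(counts, (leaves[:, t], y_idx), 1.0)
        totals = counts.sum(axis=1, keepdims=True)
        # Internal nodes receive no rows; avoid dividing by zero there.
        tables.append(counts / np.where(totals == 0, 1.0, totals))
    return tables


def predict_proba_full(forest, tables, X):
    """Hypothetical helper: average the re-estimated leaf frequencies over trees."""
    leaves = forest.apply(X)
    per_tree = [table[leaves[:, t]] for t, table in enumerate(tables)]
    return np.mean(per_tree, axis=0)


# Synthetic data, purely illustrative.
X, y = make_classification(n_samples=300, n_features=8, n_informative=4,
                           n_classes=3, random_state=0)
forest = RandomForestClassifier(n_estimators=25, min_samples_leaf=5,
                                random_state=0).fit(X, y)

tables = full_train_leaf_tables(forest, X, y)
proba = predict_proba_full(forest, tables, X)
```

Every leaf contains at least one bootstrap example, and bootstrap examples come from the training set, so no leaf ends up with an empty histogram here.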

1 answer
Reply #2 · 2019-07-21 20:47

Scikit-learn does use only the bootstrap sample for each tree to calculate the probability estimate of each class, right?

No, it uses only the in-sample part, and therefore will not give well-calibrated probability outputs (which I guess is what the paper suggests).

You could get better probability estimates using the out-of-sample estimates, and maybe that would even be done easily with the current code base. Maybe it would be better to use a calibration method as post-processing (using the out-of-bag samples).
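Both suggestions can be sketched with existing scikit-learn features. With oob_score=True the forest exposes oob_decision_function_, where each training row is scored only by trees that did not draw it in their bootstrap sample. For calibration, the example uses CalibratedClassifierCV, which calibrates on cross-validation folds rather than on the out-of-bag samples, so it is a stand-in for the idea above and not the same procedure. The data and parameter values are assumptions.

```python
import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data, purely illustrative.
X, y = make_classification(n_samples=400, n_features=8, n_informative=4,
                           n_classes=3, random_state=0)

# Out-of-bag probability estimates for the training rows.
forest = RandomForestClassifier(n_estimators=100, oob_score=True,
                                random_state=0).fit(X, y)
oob_proba = forest.oob_decision_function_  # shape (n_samples, n_classes)

# Cross-validated calibration as post-processing (fold-based, not OOB-based).
calibrated = CalibratedClassifierCV(
    RandomForestClassifier(n_estimators=50, random_state=0),
    method="sigmoid",
    cv=3,
).fit(X, y)
cal_proba = calibrated.predict_proba(X)
```

A row that happens to appear in every bootstrap sample has no out-of-bag estimate, so with very few trees oob_decision_function_ may contain NaN values.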

Anyhow, what you want to achieve is the default.
