Stratified Train/Validation/Test-split in scikit-learn

Posted 2019-03-28 14:12

There is already a description here of how to do a stratified train/test split in scikit-learn via train_test_split (Stratified Train/Test-split in scikit-learn) and a description of how to do a random train/validation/test split via np.split (How to split data into 3 sets (train, validation and test)?). But what about doing a stratified train/validation/test split?

The closest approximation that comes to mind for doing stratified (on class label) train/validation/test split is as follows, but I suspect there's a better way that can perhaps achieve this in one function call or in a more accurate way:

Let's say we want a 60/20/20 train/validation/test split. My current approach is to first do a stratified 60/40 split, then do a stratified 50/50 split on the remaining 40%, so as to ultimately get a stratified 60/20/20 split.

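from sklearn.model_selection import train_test_split  # sklearn.cross_validation was removed in scikit-learn 0.20

SEED = 2000
# First split: 60% train, 40% held out; stratify=y keeps the class proportions.
x_train, x_validation_and_test, y_train, y_validation_and_test = train_test_split(x, y, test_size=.4, stratify=y, random_state=SEED)
# Second split: halve the held-out 40% into validation and test, stratified again.
x_validation, x_test, y_validation, y_test = train_test_split(x_validation_and_test, y_validation_and_test, test_size=.5, stratify=y_validation_and_test, random_state=SEED)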
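
For ratios other than 60/20/20, the second test_size has to be expressed relative to the held-out part rather than the full dataset. A minimal sketch of that arithmetic, assuming a hypothetical helper named stratified_three_way_split (not part of scikit-learn) and synthetic data made up purely for illustration:

```python
import numpy as np
from sklearn.model_selection import train_test_split


def stratified_three_way_split(x, y, val_frac=0.2, test_frac=0.2, seed=2000):
    """Hypothetical helper: stratified train/validation/test split on labels y."""
    holdout_frac = val_frac + test_frac
    # First cut: training set versus everything that is held out.
    x_train, x_hold, y_train, y_hold = train_test_split(
        x, y, test_size=holdout_frac, stratify=y, random_state=seed)
    # Second cut: the test share is rescaled to the size of the held-out part.
    x_val, x_test, y_val, y_test = train_test_split(
        x_hold, y_hold, test_size=test_frac / holdout_frac,
        stratify=y_hold, random_state=seed)
    return x_train, x_val, x_test, y_train, y_val, y_test


# Synthetic data for illustration only: 1000 rows with an 80/20 class balance.
x = np.arange(1000).reshape(-1, 1)
y = np.array([0] * 800 + [1] * 200)
x_train, x_val, x_test, y_train, y_val, y_test = stratified_three_way_split(x, y)
print(len(x_train), len(x_val), len(x_test))
```

With val_frac=0.2 and test_frac=0.2 this reduces to the 0.4 and 0.5 values used above.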

Please let me know whether my approach is correct and/or whether you have a better approach.

Thank you

2 Answers
我想做一个坏孩纸
#2 · 2019-03-28 14:21

Yes, this is exactly how I would do it: run train_test_split() twice. Think of the first call as splitting off your training set; that training set may then get divided into different folds or holdouts down the line.

In fact, if you end up testing your model using a scikit-learn estimator that includes built-in cross-validation, you may not even have to explicitly run train_test_split() again. The same applies if you use the (very handy!) model_selection.cross_val_score function.
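
As a rough sketch of that second option: hold out a stratified test set once, then let cross_val_score handle validation on the remainder. The iris dataset, the LogisticRegression model, and the 5-fold StratifiedKFold setup are my own choices for illustration, not something from the question:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split

x, y = load_iris(return_X_y=True)
# Hold out a stratified test set once; it is not used during model selection.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, stratify=y, random_state=2000)

# Stratified 5-fold cross-validation on the training part stands in for
# a single fixed validation set.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=2000)
model = LogisticRegression(max_iter=1000)
scores = cross_val_score(model, x_train, y_train, cv=cv)
print(scores.mean())
```

Each fold keeps the class proportions of y_train, so no separate stratified validation split is needed in this setup.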

男人必须洒脱
#3 · 2019-03-28 14:41

One solution is to use StratifiedShuffleSplit twice, as shown below:

from sklearn.model_selection import StratifiedShuffleSplit

# df is a pandas DataFrame whose class label is in the "target" column.
# First split: 60% train, 40% held out for test + validation.
split = StratifiedShuffleSplit(n_splits=1, test_size=0.4, random_state=42)
for train_index, test_valid_index in split.split(df, df.target):
    train_set = df.iloc[train_index]
    test_valid_set = df.iloc[test_valid_index]

# Second split: halve the held-out part into test and validation sets.
split2 = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
for test_index, valid_index in split2.split(test_valid_set, test_valid_set.target):
    test_set = test_valid_set.iloc[test_index]
    valid_set = test_valid_set.iloc[valid_index]
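
To check that the class proportions actually carry through both splits, something like the following could be used. This is only a sketch: the DataFrame is a synthetic stand-in I made up (500 rows, 70/30 class balance, a placeholder "feature" column), and next() is used instead of the for loops since n_splits=1 yields a single pair of index arrays.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Synthetic stand-in for df: 500 rows with a 70/30 class balance in "target".
df = pd.DataFrame({"feature": np.arange(500),
                   "target": [0] * 350 + [1] * 150})

split = StratifiedShuffleSplit(n_splits=1, test_size=0.4, random_state=42)
train_index, rest_index = next(split.split(df, df["target"]))
train_set, rest_set = df.iloc[train_index], df.iloc[rest_index]

split2 = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
test_index, valid_index = next(split2.split(rest_set, rest_set["target"]))
test_set, valid_set = rest_set.iloc[test_index], rest_set.iloc[valid_index]

# Share of class 1 in each subset; each should be close to the overall 0.30.
for name, subset in [("train", train_set), ("valid", valid_set), ("test", test_set)]:
    print(name, len(subset), round(subset["target"].mean(), 3))
```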