I am curious whether there is something similar to sklearn's http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html for Apache Spark in the latest 2.0.1 release.

So far I could only find https://spark.apache.org/docs/latest/mllib-statistics.html#stratified-sampling, which does not seem to be a great fit for splitting a heavily imbalanced dataset into train/test samples.
Let's assume we have a dataset like this:
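The original snippet is not preserved in this copy; as a stand-in, assume ten rows with an id column and a perfectly balanced 0.0/1.0 label column (the column names id and label are taken from the text below), sketched in plain Python:

```python
# Hypothetical stand-in for the missing dataset snippet: ten rows,
# an "id" column and a perfectly balanced 0.0/1.0 "label" column.
data = [(i, 1.0 if i % 2 == 0 else 0.0) for i in range(10)]
```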
This dataset is perfectly balanced, but this approach will work for unbalanced data as well.
Now, let's augment this DataFrame with additional information that will be useful in deciding which rows should go to the train set. The steps are as follows: given the desired train ratio, shuffle the rows, partition them by label, and then rank each label's observations using row_number(). We end up with the following data frame:
Note: the rows are shuffled (see: random order in the id column), partitioned by label (see: label column) and ranked.

Let's assume that we would like to make an 80% split. In this case, we would like four 1.0 labels and four 0.0 labels to go to the training dataset, and one 1.0 label and one 0.0 label to go to the test dataset. We have this information in the row_number column, so now we can simply use it in a user defined function (if row_number is less than or equal to four, the example goes to the train set).

After applying the UDF, the resulting data frame is as follows:
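The logic just described — shuffle, rank within each label, and flag rows whose rank falls under the ratio — can be sketched without Spark. In this plain-Python version the function name, the 0.8 ratio, and the row layout are illustrative, not from the original code:

```python
import random

def stratified_flags(rows, ratio=0.8, seed=42):
    """Shuffle (id, label) rows, rank within each label like
    row_number(), and flag the top `ratio` fraction of each
    label's rows as training examples."""
    rng = random.Random(seed)
    shuffled = rows[:]
    rng.shuffle(shuffled)
    # per-label totals, needed to turn the ratio into a rank cutoff
    totals = {}
    for _, label in rows:
        totals[label] = totals.get(label, 0) + 1
    # rank each label's observations (1-based) in shuffled order
    counts, flagged = {}, []
    for rid, label in shuffled:
        counts[label] = counts.get(label, 0) + 1
        is_train = counts[label] <= ratio * totals[label]
        flagged.append((rid, label, is_train))
    return flagged
```

With ten balanced rows and ratio 0.8, exactly four rows per label are flagged for training, matching the worked example above.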
Now, to get the train/test data one has to do:
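The filtering snippet itself is missing here; assuming each row ends up with a boolean train flag (as the UDF step above produces), the split is just a pair of filters. A plain-Python sketch with hypothetical rows:

```python
# Hypothetical (id, label, is_train) rows after the UDF step.
rows = [(0, 1.0, True), (1, 0.0, True), (2, 1.0, False), (3, 0.0, False)]
train = [r for r in rows if r[2]]      # rows flagged for training
test = [r for r in rows if not r[2]]   # everything else
```

In Spark this corresponds to two filter calls on the flag column, dropping the helper columns afterwards.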
These sorting and partitioning steps might be prohibitive for some really big datasets, so I suggest first filtering the dataset as much as possible. The physical plan is as follows:
Here's full working example (tested with Spark 2.3.0 and Scala 2.11.12):
Note: the labels are Doubles in this case. If your labels are Strings you'll have to switch types here and there.

Spark supports stratified samples as outlined in https://s3.amazonaws.com/sparksummit-share/ml-ams-1.0.1/6-sampling/scala/6-sampling_student.html
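That material uses Spark's sampleBy, which takes a label column, a map of label-to-fraction, and a seed, and draws an approximate (per-row Bernoulli) sample within each label rather than an exact count. A plain-Python mimic of that behavior, with illustrative names:

```python
import random

def sample_by(rows, fractions, seed=1):
    # Mimics Spark's df.stat.sampleBy(col, fractions, seed): each
    # (id, label) row is kept independently with the probability
    # given for its label, so per-label counts are only approximate.
    rng = random.Random(seed)
    return [(rid, label) for rid, label in rows
            if rng.random() < fractions.get(label, 0.0)]
```

Because it samples rather than partitions, sampleBy alone does not hand back complementary train and test sets — which is one reason the answer above builds the split manually.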
Although this answer is not specific to Spark, in Apache Beam I do this to split train 66% and test 33% (just an illustrative example; you can customize the partition_fn below to be more sophisticated and accept arguments to specify the number of buckets, bias selection towards something, or ensure randomization is fair across dimensions, etc.):
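The partition_fn itself is not preserved here; a minimal version (illustrative, as the answer says) could route each element into one of three buckets at random — Beam's beam.Partition calls it as fn(element, num_partitions):

```python
import random

def partition_fn(element, num_partitions=3):
    # Illustrative 3-bucket split: one bucket (~33%) becomes the
    # test set, the remaining two (~66%) the train set.
    return random.randint(0, num_partitions - 1)
```

In a pipeline this would be applied as something like `parts = pcoll | beam.Partition(partition_fn, 3)`, yielding one PCollection per bucket.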
Perhaps this method wasn't available when the OP posted this question, but I'm leaving this here for future reference:
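The referenced snippet is not preserved in this copy. If the method in question is DataFrame.randomSplit(weights, seed) — Spark's standard one-liner for a proportional (though not stratified) split — its behavior can be mimicked in plain Python like so:

```python
import random

def random_split(rows, weights, seed=42):
    # Plain-Python mimic of DataFrame.randomSplit: each row lands
    # in split i with probability weights[i] / sum(weights).
    rng = random.Random(seed)
    total = float(sum(weights))
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    splits = [[] for _ in weights]
    for row in rows:
        r = rng.random()
        for i, b in enumerate(bounds):
            if r < b:
                splits[i].append(row)
                break
        else:
            splits[-1].append(row)
    return splits
```

Note that randomSplit does not stratify by label, so on a heavily imbalanced dataset (the OP's concern) the rare class can end up unevenly represented across the splits.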