Imbalanced Dataset for Multi-Label Classification

Posted 2019-06-21 22:24

I trained a deep neural network on a multi-label dataset I created (about 20,000 samples). I replaced softmax with sigmoid and minimize the following loss with the Adam optimizer:

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_pred))

And I end up with this kind of prediction (pretty "constant"):

Prediction for Im1 : [ 0.59275776  0.08751075  0.37567005  0.1636796   0.42361438  0.08701646 0.38991812  0.54468459  0.34593087  0.82790571]

Prediction for Im2 : [ 0.52609032  0.07885984  0.45780018  0.04995904  0.32828355  0.07349177 0.35400775  0.36479294  0.30002621  0.84438241]

Prediction for Im3 : [ 0.58714485  0.03258472  0.3349618   0.03199361  0.54665488  0.02271551 0.43719986  0.54638696  0.20344526  0.88144571]

At first, I thought I just needed to find a threshold value for each class.
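Per-class thresholds are normally tuned on a held-out validation set, not on the training data. A minimal sketch of that search, assuming NumPy arrays of 0/1 labels and sigmoid outputs; `best_thresholds`, its candidate grid and the choice of F1 as the criterion are all hypothetical:

```python
import numpy as np

def best_thresholds(y_true, y_prob, candidates=None):
    """For each class, pick the threshold that maximizes F1 on validation data.

    y_true: (n_samples, n_classes) array of 0/1 labels
    y_prob: (n_samples, n_classes) array of sigmoid outputs
    """
    if candidates is None:
        candidates = np.linspace(0.05, 0.95, 19)  # hypothetical search grid
    n_classes = y_true.shape[1]
    thresholds = np.full(n_classes, 0.5)
    for c in range(n_classes):
        best_f1 = -1.0
        positives = y_true[:, c] == 1
        for t in candidates:
            pred = y_prob[:, c] >= t
            tp = np.sum(pred & positives)
            fp = np.sum(pred & ~positives)
            fn = np.sum(~pred & positives)
            denom = 2 * tp + fp + fn
            f1 = 2 * tp / denom if denom > 0 else 0.0
            if f1 > best_f1:  # keep the first threshold reaching the best F1
                best_f1, thresholds[c] = f1, t
    return thresholds
```

Thresholding only helps when the scores actually move with the input; if every image gets nearly the same score for a class, as in the predictions above, no single cut-off can separate positives from negatives.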

But I noticed that, for instance, the 1st class appears in about 10,800 of my 20,000 samples, a ratio of 0.54, and that is the value my prediction hovers around every time. So I think I need a way to tackle this "imbalanced dataset" issue.
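One quick way to confirm that the network is just reproducing the label frequencies is to compare, per class, the fraction of positive labels with the mean predicted probability, and to look at how much the prediction varies across inputs. A minimal NumPy sketch; `class_stats` is a hypothetical helper and the data below is synthetic, standing in for a model that ignores its input:

```python
import numpy as np

def class_stats(y_true, y_prob):
    """Per-class label frequency vs. model output statistics."""
    prevalence = y_true.mean(axis=0)  # fraction of samples where class c is present
    mean_pred = y_prob.mean(axis=0)   # average sigmoid output for class c
    spread = y_prob.std(axis=0)       # close to 0 => output barely depends on the input
    return prevalence, mean_pred, spread

# Synthetic example: 1000 samples, 3 classes with different frequencies
rng = np.random.default_rng(0)
y_true = (rng.random((1000, 3)) < [0.54, 0.05, 0.3]).astype(float)
# A "constant" model that always outputs the class priors
y_prob = np.tile(y_true.mean(axis=0), (1000, 1))
prev, mean_pred, spread = class_stats(y_true, y_prob)
print(np.round(prev, 2), np.round(mean_pred, 2), spread.max())
```

If `mean_pred` tracks `prevalence` closely and `spread` is small on real validation data, the model has mostly learned the output distribution rather than anything about the inputs.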

I thought about reducing my dataset (undersampling) to get about the same number of occurrences for each class, but only 26 samples correspond to one of my classes... That would make me lose a lot of samples.

I read about oversampling, or about penalizing errors on the rare classes more heavily, but did not really understand how these work.
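"Penalizing the rare classes more" usually means weighting the positive term of the sigmoid cross-entropy per class, commonly with `pos_weight = negatives / positives`. TensorFlow provides `tf.nn.weighted_cross_entropy_with_logits` with a `pos_weight` argument for this. The NumPy sketch below only illustrates the arithmetic; the helper names and the neg/pos weighting rule are assumptions, not the only option:

```python
import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x))
    return np.logaddexp(0.0, x)

def weighted_sigmoid_ce(labels, logits, pos_weight):
    """Mean sigmoid cross-entropy with per-class weights on the positive term.

    labels, logits: (batch, n_classes); pos_weight: (n_classes,)
    """
    # -log(sigmoid(x)) == softplus(-x) and -log(1 - sigmoid(x)) == softplus(x)
    per_elem = pos_weight * labels * softplus(-logits) + (1.0 - labels) * softplus(logits)
    return per_elem.mean()

def compute_pos_weight(y_true):
    """Hypothetical weighting rule: ratio of negatives to positives per class."""
    pos = y_true.sum(axis=0)
    neg = y_true.shape[0] - pos
    return neg / np.maximum(pos, 1)  # guard against a class with no positives
```

With the counts from the question, class 1 (10,800 positives out of 20,000) would get a weight of about 0.85, while the class with 26 positives would get about 768. A weight that large can make training noisy, so it may be worth clipping it.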

Can someone explain these methods, please?

In practice, are there TensorFlow functions that help with this?

Any other suggestions?

Thank you :)

PS: The question "Neural Network for Imbalanced Multi-Class Multi-Label Classification" raises the same problem but has no answer!

2 answers
Deceive 欺骗
#2 · 2019-06-21 22:42

Your problem is not the class imbalance but simply the lack of data. 26 samples is a very small dataset for practically any real machine learning task. The class imbalance could easily be handled by ensuring that each minibatch contains at least one sample from every class (this means some samples will be used much more often than others, but who cares).
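The "at least one sample from every class per minibatch" idea is a form of oversampling. A minimal NumPy sketch, assuming a 0/1 label matrix and a batch size no smaller than the number of classes; `balanced_batch_indices` is a hypothetical helper, and filling the rest of the batch uniformly is just one choice:

```python
import numpy as np

def balanced_batch_indices(y_true, batch_size, rng):
    """Draw one minibatch of row indices with >= 1 positive example per class.

    y_true: (n_samples, n_classes) 0/1 matrix; every class needs at least one positive.
    """
    n_samples, n_classes = y_true.shape
    if batch_size < n_classes:
        raise ValueError("batch_size must be >= number of classes")
    # One random positive row per class: rows of rare classes get reused very often
    forced = [rng.choice(np.flatnonzero(y_true[:, c])) for c in range(n_classes)]
    # Fill the remaining slots uniformly at random from the whole dataset
    rest = rng.choice(n_samples, size=batch_size - n_classes, replace=True)
    batch = np.concatenate([np.array(forced, dtype=int), rest])
    rng.shuffle(batch)
    return batch
```

With 10 classes, a class that has only 26 examples would show up in every batch, so each of those 26 rows is seen many times per epoch, which is exactly why the overfitting concern below applies.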

However, with only 26 samples this approach (and any other) will quickly lead to overfitting. This problem could be partly solved with some form of data augmentation, but there are still too few samples to build something reasonable.

So my suggestion is to collect more data.

ら.Afraid
#3 · 2019-06-21 22:51

Well, having 10,000 samples in one class and just 26 in a rare class will indeed be a problem.

However, what you are seeing looks to me more like "the outputs don't even see the inputs", so the net just learns your output distribution.

To debug this, I would create a reduced set (just for debugging) with, say, 26 samples per class and then try to heavily overfit it. If you get correct predictions, my theory is wrong. But if the net cannot even fit those few undersampled samples, then it is indeed an architecture/implementation problem and not caused by the skewed distribution (which you will still need to fix afterwards, but it will not be as bad as your current results).
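A minimal sketch of building that debugging subset with NumPy, assuming a 0/1 label matrix; `debug_subset` is a hypothetical helper. In multi-label data a row picked for one class may also be positive for others, so the subset is only roughly balanced:

```python
import numpy as np

def debug_subset(y_true, per_class=26, rng=None):
    """Row indices of a small subset with up to `per_class` positives drawn per class."""
    rng = np.random.default_rng(0) if rng is None else rng
    chosen = set()
    for c in range(y_true.shape[1]):
        pos = np.flatnonzero(y_true[:, c])
        # Take every positive if the class has fewer than per_class of them
        take = rng.choice(pos, size=min(per_class, len(pos)), replace=False)
        chosen.update(take.tolist())
    return np.array(sorted(chosen), dtype=int)
```

You would then train only on those rows until the loss is close to zero and predict on the same rows: if the outputs still sit near the class frequencies instead of matching the labels, suspect the input pipeline or the model wiring rather than the imbalance.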
