Suppose that, after performing median frequency balancing on the images used for segmentation, we have these class weights:
    class_weights = {0: 0.2595,
                     1: 0.1826,
                     2: 4.5640,
                     3: 0.1417,
                     4: 0.9051,
                     5: 0.3826,
                     6: 9.6446,
                     7: 1.8418,
                     8: 0.6823,
                     9: 6.2478,
                     10: 7.3614,
                     11: 0.0}
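For context, here is a minimal sketch of how weights like these might be computed. It assumes the common definition of median frequency balancing: freq(c) is the number of pixels of class c divided by the total number of pixels in the images where c appears, and weight(c) = median(freq) / freq(c). The function name and the choice to give unseen classes a weight of 0 are assumptions for illustration, not necessarily how the weights above were produced.

```python
import numpy as np

def median_frequency_weights(label_maps, num_classes):
    """label_maps: iterable of 2-D integer arrays of class ids, one per image."""
    pixel_count = np.zeros(num_classes)   # pixels of class c over the whole set
    image_pixels = np.zeros(num_classes)  # total pixels of images that contain c
    for labels in label_maps:
        counts = np.bincount(labels.ravel(), minlength=num_classes)
        present = counts > 0
        pixel_count += counts
        image_pixels[present] += labels.size
    seen = pixel_count > 0
    freq = np.zeros(num_classes)
    freq[seen] = pixel_count[seen] / image_pixels[seen]
    weights = np.zeros(num_classes)       # assumption: unseen classes get weight 0
    weights[seen] = np.median(freq[seen]) / freq[seen]
    return weights
```

Rare classes end up with weights above 1 and frequent classes with weights below 1, which matches the pattern in the table above.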
The idea is to create a weight_mask that can be multiplied by the cross entropy output of all classes. To create this weight mask, we can broadcast the class weights according to either the ground_truth labels or the predictions. Some of the mathematics in my implementation:
- Both labels and logits are of shape [batch_size, height, width, num_classes].
- The weight mask is of shape [batch_size, height, width, 1].
- The weight mask is broadcast across the num_classes channels of the product of the softmax of the logits and the labels, giving an output of shape [batch_size, height, width, num_classes]. In this case, num_classes is 12.
- Reduce sum for each example in the batch, then reduce mean over all examples in the batch to get a single scalar loss value.
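The shape bookkeeping above can be sketched in NumPy as follows. This is only an illustration under stated assumptions: it uses the log of the softmax (the usual cross-entropy term) in the product with the labels, and the function name `weighted_loss_from_mask` is hypothetical. How the weight mask itself is built is left open here.

```python
import numpy as np

def weighted_loss_from_mask(logits, onehot_labels, weight_mask):
    # logits, onehot_labels: [batch_size, height, width, num_classes]
    # weight_mask:           [batch_size, height, width, 1]
    shifted = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    per_class = -onehot_labels * log_softmax   # [B, H, W, C]
    weighted = weight_mask * per_class         # mask broadcasts over the C channels
    per_example = weighted.sum(axis=(1, 2, 3)) # reduce sum per example -> [B]
    return per_example.mean()                  # reduce mean over the batch -> scalar
```

Note that in this sketch the mask multiplies the per-pixel loss terms, not the logits.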
In this case, should we create the weight mask based on the predictions or the ground truth?
If we build it based on the ground_truth, then it means no matter what the predicted pixel labels are, they get penalized based on the actual labels of the class, which doesn't seem to guide the training in a sensible way.
But if we build it based on the predictions, then whatever logits are produced, if the predicted label (the argmax of the logits) is a dominant class, the logit values for that pixel will all be reduced by a significant amount.
Although the maximum logit will still be the maximum, since the logits in all 12 channels are scaled by the same value, the final softmax probability of the predicted label (which is the same label before and after scaling) will be lower than before scaling (I did some simple math to estimate this). So a lower loss is predicted.
But the problem is this: If a lower loss is predicted as a result of this weighting, then wouldn't it contradict the idea that predicting dominant labels should give you a greater loss?
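The scaling argument can be checked numerically with a small, self-contained sketch. The logits below are made up for a single pixel, and the two non-unit weights are taken from the table of class weights. For each weight it prints the argmax, the top softmax probability p, and -log(p), which is the cross entropy that pixel would get if the top-scoring class were also its label.

```python
import math

def softmax(values):
    m = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def top_probability_after_scaling(logits, weight):
    # Scale every logit of the pixel by the same weight, then apply softmax.
    probs = softmax([weight * v for v in logits])
    top = max(range(len(probs)), key=probs.__getitem__)
    return top, probs[top]

logits = [2.0, 1.0, 0.1]  # hypothetical logits for one pixel
for weight in (1.0, 0.1417, 9.6446):
    index, prob = top_probability_after_scaling(logits, weight)
    print(weight, index, round(prob, 4), round(-math.log(prob), 4))
```

For any positive weight the argmax stays the same; a weight below 1 flattens the distribution and a weight above 1 sharpens it.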
My overall impression of this method is that:
- For the dominant labels, they are penalized and rewarded much less.
- For the less dominant labels, they are rewarded highly if the predictions are correct, but they're also penalized heavily for a wrong prediction.
So how does this help to tackle the issue of class-balancing? I don't quite get the logic here.
IMPLEMENTATION
Here is my current implementation for calculating the weighted cross entropy loss, although I'm not sure if it is correct.
    import tensorflow as tf

    def weighted_cross_entropy(logits, onehot_labels, class_weights):
        if not logits.dtype == tf.float32:
            logits = tf.cast(logits, tf.float32)

        if not onehot_labels.dtype == tf.float32:
            onehot_labels = tf.cast(onehot_labels, tf.float32)

        # Obtain the logit label predictions and form a skeleton weight mask with the same shape
        logit_predictions = tf.argmax(logits, -1)
        weight_mask = tf.zeros_like(logit_predictions, dtype=tf.float32)

        # Obtain the number of class weights to add to the weight mask
        num_classes = logits.get_shape().as_list()[3]

        # Form the weight mask mapping for each pixel prediction
        for i in range(num_classes):
            binary_mask = tf.equal(logit_predictions, i)  # Positions where class i is predicted
            binary_mask = tf.cast(binary_mask, tf.float32)  # Convert boolean to ones and zeros
            class_mask = tf.multiply(binary_mask, class_weights[i])  # Scale the ones by the weight of class i
            weight_mask = tf.add(weight_mask, class_mask)  # Add to the weight mask

        # Multiply the logits by the weight mask, then perform cross entropy
        weight_mask = tf.expand_dims(weight_mask, 3)  # Add a trailing dimension of size 1 for broadcasting
        logits_scaled = tf.multiply(logits, weight_mask)

        return tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=logits_scaled)
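As a side note on the mask-building loop, here is a minimal NumPy sketch showing that the per-class loop is equivalent to a single table lookup indexed by class id (in TensorFlow, tf.gather plays a similar role). The function names are hypothetical, and the class ids can come from either the predictions or the ground truth; the sketch takes no position on which one to use.

```python
import numpy as np

class_weights = {0: 0.2595, 1: 0.1826, 2: 4.5640, 3: 0.1417,
                 4: 0.9051, 5: 0.3826, 6: 9.6446, 7: 1.8418,
                 8: 0.6823, 9: 6.2478, 10: 7.3614, 11: 0.0}

def weight_mask_loop(class_ids, class_weights):
    # Same idea as the loop above: one binary mask per class, scaled and summed.
    mask = np.zeros(class_ids.shape, dtype=np.float32)
    for class_id, weight in class_weights.items():
        mask += (class_ids == class_id).astype(np.float32) * weight
    return mask[..., np.newaxis]  # [batch, height, width, 1]

def weight_mask_lookup(class_ids, class_weights):
    # Put the weights in a vector ordered by class id, then index it per pixel.
    table = np.array([class_weights[i] for i in range(len(class_weights))],
                     dtype=np.float32)
    return table[class_ids][..., np.newaxis]  # [batch, height, width, 1]

class_ids = np.array([[[0, 6], [11, 3]]])  # made-up ids, shape [1, 2, 2]
print(np.allclose(weight_mask_loop(class_ids, class_weights),
                  weight_mask_lookup(class_ids, class_weights)))
```

The lookup form avoids building num_classes intermediate masks.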
Could anyone verify whether my understanding of this weighted loss and my implementation are correct? This is my first time working with a dataset with imbalanced classes, so I would really appreciate any confirmation.
TESTING RESULTS: After doing some tests, I found the implementation above results in a greater loss. Is this supposed to be the case? i.e. Would this make the training harder but produce a more accurate model eventually?
SIMILAR THREADS
Note that I have checked a similar thread here: How can I implement a weighted cross entropy loss in tensorflow using sparse_softmax_cross_entropy_with_logits
But it seems that TF only has a sample-wise weighting for loss but not a class-wise one.
Many thanks to all of you.