Weighting true positives vs. true negatives

Posted 2019-08-17 05:19

This TensorFlow loss function (tf.nn.weighted_cross_entropy_with_logits) is used in Keras/TensorFlow to weight binary decisions.

It weights false positives against false negatives. The standard, unweighted binary cross-entropy is:

targets * -log(sigmoid(logits)) + (1 - targets) * -log(1 - sigmoid(logits))

The argument pos_weight is used as a multiplier for the positive targets:

targets * -log(sigmoid(logits)) * pos_weight + (1 - targets) * -log(1 - sigmoid(logits))
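For reference, here is a minimal plain-Python sketch of the two formulas above for a single example (assumptions: scalar inputs and 0/1 targets; softplus and weighted_bce are hypothetical helper names, not TensorFlow functions). Each -log term is rewritten as a softplus so that large logits do not overflow, and with pos_weight=1.0 it reduces to the unweighted formula.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def weighted_bce(target: float, logit: float, pos_weight: float = 1.0) -> float:
    """targets * -log(sigmoid(logits)) * pos_weight + (1 - targets) * -log(1 - sigmoid(logits))."""
    neg_log_p = softplus(-logit)     # -log(sigmoid(logit))
    neg_log_1mp = softplus(logit)    # -log(1 - sigmoid(logit))
    return target * neg_log_p * pos_weight + (1.0 - target) * neg_log_1mp


print(weighted_bce(1.0, 0.0))                  # log 2
print(weighted_bce(1.0, 0.0, pos_weight=3.0))  # 3 * log 2
print(weighted_bce(0.0, 0.0, pos_weight=3.0))  # log 2: pos_weight leaves negatives alone
```

Note that pos_weight only scales the term for positive targets, so on its own it cannot distinguish a true negative from a false positive.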

Does anybody have suggestions for how true positives could additionally be weighted against true negatives, in case their loss/reward should not carry equal weight?

1 answer
成全新的幸福
#2 · 2019-08-17 05:43

First, note that with cross-entropy loss there is some (possibly very small) penalty for every example, even a correctly classified one. For example, if the correct class is 1 and the logit is 10, the penalty is

-log(sigmoid(10)) = log(1 + e^-10) ≈ 4.5e-5

This loss (very slightly) pushes the network to produce an even higher logit for this example, moving its sigmoid even closer to 1. Similarly, for the negative class, even if the logit is -10, the loss will push it to be even more negative.
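These numbers can be checked with a small plain-Python sketch (assumption: a single logit; bce_and_grad is a hypothetical helper). It uses the standard identity that the derivative of binary cross-entropy with respect to the logit is sigmoid(logit) - target, so the gradient is tiny but nonzero, and its sign pushes a correct positive logit up and a correct negative logit down.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def bce_and_grad(target: float, logit: float):
    """Binary cross-entropy for one logit, and its derivative w.r.t. the logit."""
    p = sigmoid(logit)
    loss = -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    grad = p - target  # d(loss)/d(logit) = sigmoid(logit) - target
    return loss, grad


for target, logit in [(1.0, 10.0), (0.0, -10.0)]:
    loss, grad = bce_and_grad(target, logit)
    # A negative grad means gradient descent raises the logit; a positive one lowers it.
    print(f"target={target:.0f} logit={logit:+.0f} loss={loss:.2e} grad={grad:+.2e}")
```

Both cases give a loss of about 4.5e-5, with gradients of the same magnitude and opposite sign.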

This is usually fine because the loss from such terms is very small. If you would like correctly classified examples to stop being pushed at some point, you can use label_smoothing: the gradient then reaches zero at a finite logit, although the minimum loss is the entropy of the smoothed label rather than exactly zero. This is probably as close to "rewarding" the network as you can get in the classic setup of minimizing a loss. (You can, of course, "reward" the network by adding a negative constant to the loss, but that changes neither the gradient nor the training behavior.)
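A plain-Python sketch of what label smoothing does to a single positive example (assumptions: ε = 0.1 and the smoothing rule y * (1 - ε) + 0.5 * ε, which is how Keras's BinaryCrossentropy(label_smoothing=...) squeezes labels toward 0.5; smooth, bce and bce_grad are hypothetical helpers). With target 1 smoothed to about 0.95, the gradient vanishes at logit = log(0.95 / 0.05) = log 19 ≈ 2.94, and a logit of 10 is now pushed back down instead of further up.

```python
import math


def smooth(target: float, eps: float) -> float:
    # Squeeze a 0/1 label toward 0.5 (assumed rule: y * (1 - eps) + 0.5 * eps).
    return target * (1.0 - eps) + 0.5 * eps


def bce(target: float, logit: float) -> float:
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def bce_grad(target: float, logit: float) -> float:
    # d(bce)/d(logit) = sigmoid(logit) - target
    return 1.0 / (1.0 + math.exp(-logit)) - target


eps = 0.1
t = smooth(1.0, eps)                   # about 0.95
best_logit = math.log(t / (1.0 - t))   # sigmoid(best_logit) == t, so the gradient vanishes here

print("optimal logit:", best_logit)
print("gradient there:", bce_grad(t, best_logit))
print("loss there (entropy of the smoothed label, not 0):", bce(t, best_logit))
print("gradient at logit 10 (positive, so it is pushed down):", bce_grad(t, 10.0))
```

The point of the design is that the optimum moves from an infinite logit to a finite one; the loss value at that optimum is still positive.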

Having said that, you can penalize the network differently for the various cases (tp, tn, fp, fn), similarly to what is described in "Weight samples if incorrect guessed in binary cross entropy". (The implementation there actually seems incorrect: you want to use the corresponding elements of the weight_tensor to weight the individual log(sigmoid(...)) terms, not the final output of cross_entropy.)
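A minimal plain-Python sketch of per-case weighting (assumptions: one sigmoid output per example, a prediction counts as positive when logit > 0, i.e. sigmoid > 0.5, and the values in WEIGHTS are hypothetical placeholders to tune for your problem). The weight multiplies the individual -log(sigmoid(...)) or -log(1 - sigmoid(...)) term rather than an already-reduced cross-entropy. In TensorFlow or Keras the same idea would mean building a weight tensor with the same shape as the logits and multiplying it elementwise into those terms before averaging; that translation is not shown here.

```python
import math


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# Hypothetical per-case weights (tp = true positive, fn = false negative, ...).
WEIGHTS = {"tp": 1.0, "fn": 5.0, "tn": 0.5, "fp": 2.0}


def case_weighted_bce(target: int, logit: float, weights=WEIGHTS, threshold: float = 0.0) -> float:
    """Weight the single active -log term by the confusion-matrix cell the example falls in."""
    predicted_positive = logit > threshold
    if target == 1:
        w = weights["tp"] if predicted_positive else weights["fn"]
        return w * softplus(-logit)  # w * -log(sigmoid(logit))
    w = weights["fp"] if predicted_positive else weights["tn"]
    return w * softplus(logit)       # w * -log(1 - sigmoid(logit))


def mean_loss(targets, logits, weights=WEIGHTS) -> float:
    # Average of the individually weighted per-example terms.
    return sum(case_weighted_bce(t, z, weights) for t, z in zip(targets, logits)) / len(targets)


# One example of each case: tp, fn, tn, fp.
print(mean_loss([1, 1, 0, 0], [2.0, -2.0, -2.0, 2.0]))
```

Because the weight is picked by a hard threshold, it acts as a per-example constant: it rescales the gradient, but the loss jumps where the logit crosses the threshold (for a positive target the weight switches from fn to tp at logit 0).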

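Using this scheme, you might want to penalize very wrong answers much more than almost-right answers. Note, however, that this already happens to some degree because of the shape of -log(sigmoid(...)): the penalty grows roughly linearly as the logit moves further onto the wrong side, and flattens out toward zero for confident correct answers.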
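A quick plain-Python check of that shape for a positive target (neg_log_sigmoid is a hypothetical helper): the loss and the gradient magnitude 1 - sigmoid(logit) both shrink toward 0 for confident correct logits, while for very wrong logits the loss grows about one unit per unit of logit and the gradient magnitude approaches 1.

```python
import math


def neg_log_sigmoid(logit: float) -> float:
    # -log(sigmoid(logit)) = log(1 + exp(-logit)), computed stably.
    return max(-logit, 0.0) + math.log1p(math.exp(-abs(logit)))


for logit in [4.0, 0.0, -4.0, -8.0]:
    p = 1.0 / (1.0 + math.exp(-logit))
    print(f"logit={logit:+.0f}  loss={neg_log_sigmoid(logit):.4f}  |grad|={1.0 - p:.4f}")
```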
