Image segmentation - custom loss function in Keras

Posted 2019-07-13 18:16

I am using a U-Net (https://arxiv.org/pdf/1505.04597.pdf) implemented in Keras to segment cell organelles in microscopy images. For my network to recognize individual objects that are separated by only 1 pixel, I want to use a weight map for each label image (the formula is given in the publication).

As far as I know, I have to write my own custom loss function (cross-entropy, in my case) to make use of these weight maps. However, a custom loss function only takes two parameters. How can I pass the weight map values into such a function?

Below is the code for my custom loss function:

import tensorflow as tf
from tensorflow.keras import backend as K

def pixelwise_crossentropy(self, ytrue, ypred):
    # normalize predictions so the class probabilities sum to 1 per pixel
    ypred /= tf.reduce_sum(ypred, axis=-1, keepdims=True)

    # manual computation of crossentropy; clip to avoid log(0)
    _epsilon = tf.convert_to_tensor(K.epsilon(), ypred.dtype.base_dtype)
    output = tf.clip_by_value(ypred, _epsilon, 1. - _epsilon)

    return - tf.reduce_sum(ytrue * tf.math.log(output))

Is there any way to combine the weight map values together with the label values in the ytrue tensor?

I apologize if this question seems stupid; I'm relatively new to the game. Any help or suggestions would be greatly appreciated!

1 answer
Aperson
#2 · 2019-07-13 18:39

If you are trying to implement a weighted binary cross-entropy loss, you can use TensorFlow's built-in loss function:

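import tensorflow as tf

pos_weight = tf.constant([[1.0, 2.0]])
tf.nn.weighted_cross_entropy_with_logits(
    y_true,      # labels
    y_pred,      # raw logits, not sigmoid/softmax outputs
    pos_weight,  # multiplier applied to the positive-label term
    name=None)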

Check out the documentation: https://www.tensorflow.org/api_docs/python/tf/nn/weighted_cross_entropy_with_logits
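To make the role of pos_weight concrete, here is a minimal NumPy sketch of the formula given in that documentation: labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits)). The function name weighted_bce_with_logits is made up for illustration and is not part of TensorFlow.

```python
import numpy as np

def weighted_bce_with_logits(labels, logits, pos_weight):
    # Hypothetical helper mirroring the documented formula, for illustration only.
    labels = np.asarray(labels, dtype=np.float64)
    logits = np.asarray(logits, dtype=np.float64)
    # log(sigmoid(x)) and log(1 - sigmoid(x)), written with logaddexp for stability
    log_p = -np.logaddexp(0.0, -logits)
    log_not_p = -np.logaddexp(0.0, logits)
    # only the positive-label term is scaled by pos_weight
    return -(pos_weight * labels * log_p + (1.0 - labels) * log_not_p)

losses = weighted_bce_with_logits(labels=[1.0, 0.0], logits=[0.0, 0.0], pos_weight=2.0)
```

With pos_weight = 2, an error on a positive pixel costs twice as much as the same error on a negative pixel. Note that this is one weight per class, not one weight per pixel.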

Implementation for Keras:

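import tensorflow as tf
from tensorflow.keras import backend as K

def pixel_wise_loss(y_true, y_pred):
    # y_pred must hold raw logits (no sigmoid on the output layer)
    pos_weight = tf.constant([[1.0, 2.0]])
    loss = tf.nn.weighted_cross_entropy_with_logits(
        y_true,
        y_pred,
        pos_weight,
        name=None
    )

    # average over the class axis; Keras reduces the remaining axes
    return K.mean(loss, axis=-1)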

If you are trying to implement softmax_cross_entropy_with_logits instead, follow the link to the previous answer, where it is explained: softmax_cross_entropy_with_logits
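For the per-pixel weight maps asked about in the question, one possible approach (a sketch, not something from the U-Net paper's code) is to append the weight map to the labels as an extra channel and slice it off again inside the loss. The names pack_targets and weighted_pixelwise_crossentropy are hypothetical. The sketch assumes channels-last tensors, a softmax output layer, weight maps that were already computed with the formula from the publication, and TensorFlow 2.x.

```python
import numpy as np
import tensorflow as tf

def pack_targets(labels_one_hot, weight_map):
    # labels_one_hot: (batch, H, W, n_classes); weight_map: (batch, H, W)
    # Assumption: the weight map is stored as one extra trailing channel.
    packed = np.concatenate([labels_one_hot, weight_map[..., np.newaxis]], axis=-1)
    return packed.astype("float32")

def weighted_pixelwise_crossentropy(y_true_and_weights, y_pred):
    # Split the packed target back into one-hot labels and per-pixel weights.
    y_true = y_true_and_weights[..., :-1]
    weights = y_true_and_weights[..., -1]
    # Clip the softmax probabilities to avoid log(0).
    eps = tf.keras.backend.epsilon()
    y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
    # Cross-entropy per pixel, then scaled by that pixel's weight.
    ce = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
    return tf.reduce_mean(weights * ce)
```

You would then train on pack_targets(labels, weight_maps) instead of the plain labels and pass weighted_pixelwise_crossentropy as the loss to model.compile. The sketch averages over pixels rather than summing, which only changes the scale of the loss. Any metric that compares y_true with y_pred directly would need the same slicing, since y_true now has one more channel than y_pred.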
