Tensorflow: Sigmoid cross entropy loss does not fo

Posted 2019-07-22 15:12

I would like to train an image segmentation network in TensorFlow with pixel values in {0.0, 1.0}. I have two images, ground_truth and prediction, each of shape (120,160). Every pixel of ground_truth is either 0.0 or 1.0.

The prediction image is the output of a decoder whose last two layers are a tf.layers.conv2d_transpose and a tf.layers.conv2d:

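```python
import tensorflow as tf  # TensorFlow 1.x (tf.layers API)

# transforms (?,120,160,30) -> (?,120,160,15)
outputs = tf.layers.conv2d_transpose(outputs, filters=15, kernel_size=1, strides=1, padding='same')
# ReLU
outputs = activation(outputs)  # e.g. activation = tf.nn.relu

# transforms (?,120,160,15) -> (?,120,160,1)
# no activation here, so these values are raw logits
outputs = tf.layers.conv2d(outputs, filters=1, kernel_size=1, strides=1, padding='same')
```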

The last layer has no activation function, so its output is unbounded. I use the following loss function:

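```python
# flatten each prediction and label map to one row of 120*160 values per image
logits = tf.reshape(predicted, [-1, predicted.get_shape()[1] * predicted.get_shape()[2]])
labels = tf.reshape(ground_truth, [-1, ground_truth.get_shape()[1] * ground_truth.get_shape()[2]])

# element-wise sigmoid cross entropy on the raw logits, averaged over all pixels
loss = 0.5 * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
```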
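For reference, tf.nn.sigmoid_cross_entropy_with_logits computes, per element, the logistic loss in the numerically stable form given in the TensorFlow documentation, max(x, 0) - x * z + log(1 + exp(-|x|)), where x is the logit and z the label. Note that the sigmoid lives inside the loss, not inside the network. Below is a minimal pure-Python sketch of that formula; the function names and the four-pixel example are illustrative assumptions, not TensorFlow code.

```python
import math

def sigmoid_cross_entropy_with_logits(labels, logits):
    """Element-wise logistic loss in the stable form max(x, 0) - x*z + log(1 + exp(-|x|))."""
    return [max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
            for z, x in zip(labels, logits)]

def naive_loss(z, x):
    """Textbook form: -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x))."""
    s = 1.0 / (1.0 + math.exp(-x))
    return -z * math.log(s) - (1.0 - z) * math.log(1.0 - s)

# One flattened "image" of four pixels (made-up values).
labels = [0.0, 1.0, 1.0, 0.0]
logits = [-3.0, 2.5, -0.5, 0.7]

per_pixel = sigmoid_cross_entropy_with_logits(labels, logits)
loss = 0.5 * sum(per_pixel) / len(per_pixel)  # mirrors 0.5 * tf.reduce_mean(...)

# For moderate logits both forms agree.
for z, x, l in zip(labels, logits, per_pixel):
    assert math.isclose(l, naive_loss(z, x), rel_tol=1e-9)
print(loss)
```

The stable form never evaluates exp on a large positive argument, so it stays finite for logits such as ±1000, where the naive form would overflow or take log(0).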

This setup converges nicely. However, I have noticed that at validation time the outputs of my last layer seem to lie anywhere in [-inf, inf]. When I visualize the output, the object is not segmented, since almost all pixels are "activated". The distribution of values for a single output of the last conv2d layer looks like this:

imgur.com/a/kSPJneU
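A wide, unbounded spread like this is expected for a layer without an activation: its outputs are logits, i.e. log-odds log(p / (1 - p)) of a pixel belonging to the object, not probabilities. The small sketch below (helper names are hypothetical) shows how quickly confident probabilities map to large logits.

```python
import math

def sigmoid(x: float) -> float:
    """Numerically stable logistic function, maps any real x into [0, 1]."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def logit(p: float) -> float:
    """Inverse of sigmoid: the log-odds log(p / (1 - p)) for 0 < p < 1."""
    return math.log(p / (1.0 - p))

# A pixel the network is 99.9% sure about already has a logit of about +6.9,
# and every further gain in confidence pushes the logit further out.
for p in (0.001, 0.5, 0.999, 0.999999):
    print(f"p = {p:<9} logit = {logit(p):+.2f}")
```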

Question:

Do I have to post-process the outputs (clip negative values, run the output through a sigmoid activation, etc.)? What do I need to do to force my output values into {0,1}?

1 answer
孤傲高冷的网名
Posted 2019-07-22 15:36

Solved it. The problem was that tf.nn.sigmoid_cross_entropy_with_logits applies a sigmoid to the logits internally, but that sigmoid is not part of the network itself: the loss op is only evaluated during training, so at validation time the network returns the raw, unbounded logits. The solution therefore is:

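Make sure to run the network outputs through tf.nn.sigmoid at validation/test time, like this:

```python
# raw logits for the loss during training, probabilities in (0, 1) at validation/test time
return output if is_training else tf.nn.sigmoid(output)
```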
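The sigmoid only maps the logits into (0, 1). To get the hard {0, 1} mask asked for in the question, a common extra step (an assumption here, not part of the answer) is to threshold the probabilities, typically at 0.5. Because sigmoid(x) > 0.5 exactly when x > 0, thresholding the raw logits at 0 gives the same mask; in a TF 1.x graph this could be written as something like tf.cast(tf.nn.sigmoid(output) > 0.5, tf.float32). A pure-Python sketch with hypothetical helper names and made-up logits:

```python
import math

def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def to_mask(logits, threshold=0.5):
    """Hypothetical post-processing: sigmoid, then hard threshold to {0.0, 1.0}."""
    return [1.0 if sigmoid(x) > threshold else 0.0 for x in logits]

logits = [-7.2, -0.1, 0.0, 0.3, 12.5]  # made-up raw outputs of the last conv2d
probs = [sigmoid(x) for x in logits]   # now in (0, 1), but not yet binary
mask = to_mask(logits)

# Thresholding probabilities at 0.5 is the same as thresholding logits at 0.
assert mask == [1.0 if x > 0 else 0.0 for x in logits]
print(mask)  # [0.0, 0.0, 0.0, 1.0, 1.0]
```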