TensorFlow: Are my logits in the right format for cross entropy function?

Posted 2020-07-10 11:13

Question:

Alright, so I'm getting ready to run the tf.nn.softmax_cross_entropy_with_logits() function in TensorFlow.

It's my understanding that the 'logits' should be a Tensor of probabilities, each one corresponding to a certain pixel's probability that it is part of an image that will ultimately be a "dog" or a "truck" or whatever... a finite number of things.

These logits will get plugged into this cross entropy equation:

$$H(p, q) = -\sum_x p(x) \log q(x)$$

As I understand it, the logits are plugged into the right side of the equation. That is, they are the q of every x (image). If they were probabilities from 0 to 1... that would make sense to me. But when I'm running my code and ending up with a tensor of logits, I'm not getting probabilities. Instead I'm getting floats that are both positive and negative:

```
-0.07264724 -0.15262917  0.06612295 ..., -0.03235611  0.08587133  0.01897052
 0.04655019 -0.20552202  0.08725972 ..., -0.02107313 -0.00567073  0.03241089
 0.06872301 -0.20756687  0.01094618 ...,   etc
```

So my question is... is that right? Do I have to somehow calculate all my logits and turn them into probabilities from 0 to 1?
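For reference, here is a minimal plain-Python sketch of the cross entropy H(p, q) = -Σ_x p(x) log q(x) for the case the question has in mind, where q already holds probabilities. The helper name `cross_entropy` and the three-class numbers are hypothetical, chosen only for illustration.

```python
import math

def cross_entropy(p, q):
    """H(p, q) = -sum_x p(x) * log q(x); q must already be probabilities in (0, 1].

    Terms with p(x) == 0 are skipped, following the convention 0 * log 0 = 0.
    """
    return -sum(p_x * math.log(q_x) for p_x, q_x in zip(p, q) if p_x > 0)

# Hypothetical one-hot label: the true class is index 1 (say, "truck" of three classes).
p = [0.0, 1.0, 0.0]
# Hypothetical predicted probabilities that already sum to 1.
q = [0.2, 0.7, 0.1]

print(cross_entropy(p, q))  # -log(0.7), roughly 0.3567
```

With a one-hot p, the sum collapses to -log of the probability assigned to the true class, which is why q has to be a valid probability for this formula to make sense.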

Answer 1:

The crucial thing to note is that tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits) performs an internal softmax on each row of logits (that is, along the class dimension, which is the last axis by default) so that they are interpretable as probabilities before they are fed to the cross entropy equation.

Therefore, the "logits" need not be probabilities (or even true log probabilities, as the name would suggest), because of the internal normalization that happens within that op.
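To see why arbitrary positive and negative values are acceptable, the sketch below applies a plain-Python softmax to three of the raw values printed in the question, treated here as one made-up 3-class row. The `softmax` helper is hypothetical and only illustrates the normalization; it is not TensorFlow's implementation.

```python
import math

def softmax(logits):
    """Map arbitrary real-valued logits to probabilities that sum to 1."""
    m = max(logits)  # subtracting the max keeps exp() from overflowing; the result is unchanged
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Three of the raw values from the question, used as a hypothetical 3-class row.
row = [-0.07264724, -0.15262917, 0.06612295]
probs = softmax(row)
print(probs)       # every entry lies strictly between 0 and 1
print(sum(probs))  # 1.0 up to floating-point rounding
```

Adding the same constant to every logit leaves the softmax output unchanged, so only the differences between logits matter; that is also why subtracting the maximum is safe.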

An alternative way to write:

```python
import tensorflow as tf
xent = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
```

...would be:

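```python
import tensorflow as tf
# Two-step version: explicit softmax, then cross entropy summed over the class axis.
softmax = tf.nn.softmax(logits)
xent = -tf.reduce_sum(labels * tf.math.log(softmax), axis=1)
```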
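As a sanity check that the two formulations compute the same quantity, here is a pure-Python sketch that needs no TensorFlow. The helper names `naive_xent` and `fused_xent` are hypothetical. The fused version uses the identity log softmax(z)_i = z_i - log Σ_j exp(z_j), one standard way to write the loss directly on logits; it is not claimed to be TensorFlow's exact kernel.

```python
import math

def naive_xent(labels, logits):
    """Two-step form: explicit softmax, then -sum(labels * log(softmax))."""
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    softmax = [e / total for e in exps]
    return -sum(y * math.log(s) for y, s in zip(labels, softmax))

def fused_xent(labels, logits):
    """Same loss written on the logits: log_softmax(z)_i = z_i - logsumexp(z)."""
    m = max(logits)  # shift by the max so every exp() argument is <= 0
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(y * (z - lse) for y, z in zip(labels, logits))

labels = [0.0, 0.0, 1.0]  # hypothetical one-hot label: true class is index 2
logits = [2.0, 1.0, 0.1]  # hypothetical unnormalized scores, any real values

print(naive_xent(labels, logits))
print(fused_xent(labels, logits))  # same value up to floating-point rounding
```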

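However, this alternative would be (i) less numerically stable (exponentiating large logits can overflow to infinity, and a softmax output that underflows to zero makes the log return -inf, so the loss can end up as inf or NaN) and (ii) less efficient (backprop would have to differentiate through the softmax and the log as separate ops, doing redundant work that the fused op avoids). For real uses, we recommend that you use tf.nn.softmax_cross_entropy_with_logits().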
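Both points can be illustrated with a small pure-Python sketch. The helpers `stable_xent` and `xent_grad` are hypothetical and show the general technique, not TensorFlow's internal code. For (i), exponentiating a logit of 1000 overflows, while the log-sum-exp form returns the correct finite loss. For (ii), the gradient of the loss with respect to the logits has the closed form softmax(z) - labels when the labels sum to 1, so a fused op can compute it directly instead of backpropagating through a separate softmax and log; a finite-difference check confirms the formula.

```python
import math

def stable_xent(labels, logits):
    """Cross entropy on raw logits via log-sum-exp; every exp() argument is <= 0."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(y * (z - lse) for y, z in zip(labels, logits))

def xent_grad(labels, logits):
    """Closed-form d(loss)/d(logits): softmax(z) * sum(labels) - labels,
    which reduces to softmax(z) - labels for one-hot labels."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    s = sum(labels)
    return [e / total * s - y for e, y in zip(exps, labels)]

# (i) Numerical stability with a very large logit.
big_labels = [0.0, 1.0]
big_logits = [1000.0, 0.0]
try:
    [math.exp(z) for z in big_logits]  # the first step of an explicit softmax
    overflowed = False
except OverflowError:
    # Pure-Python math.exp raises here; float32 tensor code would
    # typically produce inf, and then NaN after inf / inf, instead.
    overflowed = True
print(overflowed)                           # True
print(stable_xent(big_labels, big_logits))  # 1000.0

# (ii) Check the closed-form gradient against central finite differences.
labels = [0.0, 0.0, 1.0]
logits = [2.0, 1.0, 0.1]
eps = 1e-6
numeric = []
for i in range(len(logits)):
    plus, minus = logits[:], logits[:]
    plus[i] += eps
    minus[i] -= eps
    numeric.append((stable_xent(labels, plus) - stable_xent(labels, minus)) / (2 * eps))
print(xent_grad(labels, logits))
print(numeric)  # agrees with the closed form to several decimal places
```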