I understand the concept of automatic differentiation, but I couldn't find any explanation of how TensorFlow calculates the error gradient for non-differentiable functions, for example `tf.where` in my loss function or `tf.cond` in my graph. It works just fine, but I would like to understand how TensorFlow backpropagates the error through such nodes, since there is no formula to calculate the gradient for them.
In the case of `tf.where`, you have a function with three inputs, the condition `C`, the value on true `T` and the value on false `F`, and one output `Out`. The gradient function receives one incoming gradient and has to return three values. Currently, no gradient is computed for the condition (that would hardly make sense), so you only need to produce gradients for `T` and `F`. Assuming the inputs and the output are vectors, imagine `C[0]` is `True`. Then `Out[0]` comes from `T[0]`, and its gradient should propagate back; on the other hand, `F[0]` was discarded, so its gradient should be made zero. If `C[1]` were `False`, then the gradient for `F[1]` should propagate but not the one for `T[1]`. So, in short, for `T` you should propagate the given gradient where `C` is `True` and make it zero where it is `False`, and the opposite for `F`. If you look at the implementation of the gradient of `tf.where` (the `Select` operation), it does exactly that; see the sketch below. Note that the input values themselves are not used in this computation: that part is handled by the gradients of the operations that produce those inputs.
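A minimal sketch of that gradient logic, written as a standalone function for illustration (the real implementation is registered for the `Select` op in `tensorflow/python/ops/math_grad.py` via `@ops.RegisterGradient("Select")`, and the exact code may differ between versions):

```python
import tensorflow as tf

def select_grad(c, grad):
    """Sketch of the gradient of tf.where (the Select op).

    `c` is the boolean condition and `grad` is the upstream gradient of the
    output. Returns the gradients w.r.t. the three inputs (C, T, F).
    """
    zeros = tf.zeros_like(grad)
    grad_c = None                      # no gradient flows into the condition
    grad_t = tf.where(c, grad, zeros)  # keep grad where C is True, zero elsewhere
    grad_f = tf.where(c, zeros, grad)  # keep grad where C is False, zero elsewhere
    return grad_c, grad_t, grad_f
```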
For `tf.cond`, the code is a bit more complicated, because the same operation (`Merge`) is used in different contexts, and `tf.cond` also uses `Switch` operations inside. However, the idea is the same. Essentially, `Switch` operations are used for each input, so the input that was activated (the first if the condition was `True` and the second otherwise) gets the received gradient, while the other input gets a "switched off" gradient (like `None`) that does not propagate back any further.
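As a quick sanity check of the behaviour described above (assuming TensorFlow 2.x with eager execution; the tensors and variable names here are just for illustration), you can inspect the gradients directly:

```python
import tensorflow as tf

c = tf.constant([True, False, True])
t = tf.Variable([1.0, 2.0, 3.0])
f = tf.Variable([10.0, 20.0, 30.0])

with tf.GradientTape() as tape:
    out = tf.where(c, t, f)
    loss = tf.reduce_sum(out)

grad_t, grad_f = tape.gradient(loss, [t, f])
print(grad_t.numpy())  # [1. 0. 1.] -> gradient only where C is True
print(grad_f.numpy())  # [0. 1. 0.] -> gradient only where C is False

# Same idea with tf.cond: only the branch that actually ran gets a gradient.
x = tf.Variable(2.0)
y = tf.Variable(3.0)
with tf.GradientTape() as tape:
    z = tf.cond(x > 0, lambda: x * 2.0, lambda: y * 3.0)
gx, gy = tape.gradient(z, [x, y])
print(gx, gy)  # gx == 2.0, gy is None (the untaken branch receives nothing)
```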