I am reading through the documentation of PyTorch and found an example where they write
gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)
where x was an initial variable, from which y was constructed (a 3-vector). The question is: what are the 0.1, 1.0 and 0.0001 arguments of the gradients tensor? The documentation is not very clear on that.
Explanation
For neural networks, we usually use a loss to assess how well the network has learned to classify the input image (or other tasks). The loss term is usually a scalar value. In order to update the parameters of the network, we need to calculate the gradient of the loss w.r.t. the parameters, which are actually the leaf nodes in the computation graph (by the way, these parameters are mostly the weights and biases of various layers such as Convolution, Linear and so on).

According to the chain rule, in order to calculate the gradient of the loss w.r.t. a leaf node, we can compute the derivative of the loss w.r.t. some intermediate variable, and the gradient of that intermediate variable w.r.t. the leaf variable, then take the dot product and sum all these up.

The gradient argument of a Variable's backward() method is used to calculate a weighted sum of each element of a Variable w.r.t. the leaf Variable. These weights are just the derivatives of the final loss w.r.t. each element of the intermediate variable.

A concrete example
Let's take a concrete and simple example to understand this.
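Here is a minimal sketch of such an example (the particular definitions, x as a 1x4 leaf Variable, z = 2*x and loss = z.sum(), are just illustrative assumptions); the four print calls in it are the ones the following paragraphs refer to.

import torch
from torch.autograd import Variable

x = Variable(torch.FloatTensor([[1, 2, 3, 4]]), requires_grad=True)
z = 2 * x          # intermediate variable, a 4-vector [z_1, z_2, z_3, z_4]
loss = z.sum()     # scalar loss

# 1st print: backward only through the first element of z
z.backward(torch.FloatTensor([[1, 0, 0, 0]]), retain_graph=True)
print(x.grad.data)      # [[2, 0, 0, 0]], i.e. dz_1/dx
x.grad.data.zero_()     # clear the gradient, otherwise it accumulates

# 2nd print: backward only through the second element of z
z.backward(torch.FloatTensor([[0, 1, 0, 0]]), retain_graph=True)
print(x.grad.data)      # [[0, 2, 0, 0]], i.e. dz_2/dx
x.grad.data.zero_()

# 3rd print: backward through all elements of z with weight vector [1, 1, 1, 1]
z.backward(torch.FloatTensor([[1, 1, 1, 1]]), retain_graph=True)
print(x.grad.data)      # [[2, 2, 2, 2]]
x.grad.data.zero_()

# 4th print: backprop from the scalar loss directly
loss.backward()         # equivalent to loss.backward(torch.FloatTensor([1.0]))
print(x.grad.data)      # [[2, 2, 2, 2]] again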
In the above example, the outcome of the first print is exactly the derivative of z_1 w.r.t. x.
The outcome of the second print is the derivative of z_2 w.r.t. x.
Now if we use a weight of [1, 1, 1, 1] to calculate the derivative of z w.r.t. x, the outcome is 1*dz_1/dx + 1*dz_2/dx + 1*dz_3/dx + 1*dz_4/dx. So, not surprisingly, that weighted sum is exactly what the 3rd print outputs.

It should be noted that the weight vector [1, 1, 1, 1] is exactly the derivative of the loss w.r.t. z_1, z_2, z_3 and z_4. The derivative of the loss w.r.t. x is calculated as:

d(loss)/dx = d(loss)/dz_1 * dz_1/dx + d(loss)/dz_2 * dz_2/dx + d(loss)/dz_3 * dz_3/dx + d(loss)/dz_4 * dz_4/dx

So the output of the 4th print is the same as that of the 3rd print.
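Any other weight vector works the same way and simply produces the corresponding weighted sum, which is what the question's [0.1, 1.0, 0.0001] argument does. Continuing the illustrative sketch above (the fourth weight 0.001 is arbitrary, since z has four elements there, and z has to be rebuilt because the last backward call freed the graph):

x.grad.data.zero_()
z = 2 * x          # rebuild the intermediate variable
z.backward(torch.FloatTensor([[0.1, 1.0, 0.0001, 0.001]]))
print(x.grad.data)      # [[0.2, 2.0, 0.0002, 0.002]],
                        # i.e. 0.1*dz_1/dx + 1.0*dz_2/dx + 0.0001*dz_3/dx + 0.001*dz_4/dx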
Here, the output of forward(), i.e. y, is a 3-vector.
The three values are the gradients at the output of the network. They are usually set to 1.0 if y is the final output, but can have other values as well, especially if y is part of a bigger network.
For example, if x is the input and y = [y1, y2, y3] is an intermediate output which is used to compute the final output z, then

dz/dx = dz/dy1 * dy1/dx + dz/dy2 * dy2/dx + dz/dy3 * dy3/dx

So here, the three values passed to backward are

[dz/dy1, dz/dy2, dz/dy3]

and then backward() computes dz/dx.
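A small sketch of that setup, with the illustrative assumptions y = x**2 (elementwise) and z = y.sum(), so that dz/dy = [1, 1, 1]:

import torch
from torch.autograd import Variable

x = Variable(torch.FloatTensor([1, 2, 3]), requires_grad=True)
y = x ** 2       # intermediate 3-vector
z = y.sum()      # final scalar output

# pass dz/dy (here [1, 1, 1]) to y.backward(); it fills x.grad with dz/dx
y.backward(torch.FloatTensor([1, 1, 1]), retain_graph=True)
print(x.grad.data)      # [2, 4, 6] = dz/dx

x.grad.data.zero_()

# calling z.backward() directly gives the same dz/dx
z.backward()
print(x.grad.data)      # [2, 4, 6] again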
Typically, your computational graph has one scalar output, say loss. Then you can compute the gradient of loss w.r.t. the weights (w) by loss.backward(), where the default gradient argument of backward() is 1.0.
If your output has multiple values (e.g. loss = [loss1, loss2, loss3]), you can compute the gradients of the loss w.r.t. the weights by loss.backward(torch.FloatTensor([1.0, 1.0, 1.0])).
Furthermore, if you want to give different weights or importances to the different losses, you can use loss.backward(torch.FloatTensor([-0.1, 1.0, 0.0001])). This computes -0.1*d(loss1)/dw + 1.0*d(loss2)/dw + 0.0001*d(loss3)/dw in a single backward pass and accumulates the result into w.grad.
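A minimal sketch of that, with illustrative assumed losses for a single weight w (loss1 = 3*w, loss2 = w*w, loss3 = exp(w)):

import torch
from torch.autograd import Variable

w = Variable(torch.FloatTensor([2.0]), requires_grad=True)
loss = torch.cat([3 * w, w * w, torch.exp(w)])   # loss = [loss1, loss2, loss3]

# unweighted: w.grad = d(loss1)/dw + d(loss2)/dw + d(loss3)/dw = 3 + 2*w + exp(w)
loss.backward(torch.FloatTensor([1.0, 1.0, 1.0]), retain_graph=True)
print(w.grad.data)      # about 14.39 for w = 2

w.grad.data.zero_()

# weighted: w.grad = -0.1*d(loss1)/dw + 1.0*d(loss2)/dw + 0.0001*d(loss3)/dw
loss.backward(torch.FloatTensor([-0.1, 1.0, 0.0001]))
print(w.grad.data)      # about 3.70 for w = 2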