I have a question about PyTorch's backward function. I don't think I'm getting the right output.
import numpy as np
import torch
from torch.autograd import Variable
a = Variable(torch.FloatTensor([[1,2,3],[4,5,6]]), requires_grad=True)
out = a * a
out.backward(a)
print(a.grad)
The output is
tensor([[ 2., 8., 18.],
[32., 50., 72.]])
so it looks like 2*a*a,
but I think the output is supposed to be
tensor([[ 2., 4., 6.],
[8., 10., 12.]])
that is, 2*a,
because d(x^2)/dx = 2x.
Please read the documentation on backward() carefully to better understand it.
By default, PyTorch expects backward() to be called on the last output of the network: the loss function. The loss function always outputs a scalar, so the gradients of the scalar loss w.r.t. all other variables/parameters are well defined (using the chain rule).
Thus, by default, backward() is called on a scalar tensor and expects no arguments.
For example:
import torch
a = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float, requires_grad=True)
for i in range(2):
    for j in range(3):
        out = a[i,j] * a[i,j]  # scalar output
        out.backward()  # no argument needed; gradients accumulate in a.grad
print(a.grad)
yields
tensor([[ 2., 4., 6.],
[ 8., 10., 12.]])
As expected: d(a^2)/da = 2a.
However, when you call backward() on the 2-by-3 out tensor (no longer a scalar function), what do you expect a.grad to be? You would actually need a 2-by-3-by-2-by-3 output: d out[i,j] / d a[k,l] (!)
PyTorch's backward() does not compute this full derivative of a non-scalar function.
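To make the 2-by-3-by-2-by-3 shape concrete, here is a minimal sketch that builds the full derivative with torch.autograd.functional.jacobian, assuming a PyTorch version that ships this helper. The function name `square` is only for illustration. This is a separate utility, not something backward() does for you.

```python
import torch
from torch.autograd.functional import jacobian

def square(x):
    # Elementwise square, the same operation as `a * a` in the question.
    return x * x

a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float)
J = jacobian(square, a)  # J[i, j, k, l] = d out[i, j] / d a[k, l]

print(J.shape)
print(J[1, 2, 1, 2])  # d out[1, 2] / d a[1, 2], i.e. 2 * a[1, 2]
print(J[1, 2, 0, 0])  # out[1, 2] does not depend on a[0, 0]
```

Only the entries with (i,j) equal to (k,l) are nonzero here, because each output element depends on exactly one input element.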
Instead, PyTorch assumes out is only an intermediate tensor and that somewhere "upstream" there is a scalar loss function which, through the chain rule, provides d loss / d out[i,j]. This "upstream" gradient has size 2-by-3, and it is exactly the argument you pass to backward in this case: out.backward(g), where g_ij = d loss / d out_ij.
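The gradients are then calculated by the chain rule. Because out = a * a is elementwise, out[i,j] depends only on a[i,j], so the sum over all outputs collapses to a single term: d loss / d a[i,j] = (d loss / d out[i,j]) * (d out[i,j] / d a[i,j])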
Since you provided a as the "upstream" gradients, you got
a.grad[i,j] = 2 * a[i,j] * a[i,j]
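One way to check this is to write down an explicit scalar loss whose gradient with respect to out equals a, and compare the two results. The sketch below uses a made-up loss, (w * out).sum() with w a detached copy of a, chosen only because it gives d loss / d out[i,j] = w[i,j]. It is an illustration, not something taken from the question.

```python
import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float, requires_grad=True)

# Route 1: pass the "upstream" gradient explicitly, as in the question.
out = a * a
out.backward(a.detach())
grad_explicit = a.grad.clone()

# Route 2: build a scalar loss with d loss / d out[i, j] = w[i, j].
a.grad = None          # clear the gradient accumulated by route 1
w = a.detach()         # constant weights (hypothetical choice for this demo)
loss = (w * (a * a)).sum()
loss.backward()        # scalar loss, so no argument is needed
grad_from_loss = a.grad.clone()

print(grad_explicit)
print(torch.equal(grad_explicit, grad_from_loss))
```

Both routes give 2 * a * a, which is the tensor printed in the question.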
Providing all ones as the "upstream" gradients instead,
out.backward(torch.ones(2,3))
print(a.grad)
yields
tensor([[ 2., 4., 6.],
[ 8., 10., 12.]])
As expected.
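Passing all ones is the same as reducing out to a scalar by summing it, since d(sum of out) / d out[i,j] = 1 for every element. A minimal sketch of that equivalent form, starting from a fresh tensor so that no earlier gradient is accumulated:

```python
import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float, requires_grad=True)
out = a * a
out.sum().backward()  # the sum is a scalar, so backward() needs no argument
print(a.grad)         # elementwise 2 * a
```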
It's all in the chain rule.