pytorch backprop through volatile variable error

Posted 2019-05-10 18:17

I'm trying to minimize some input with respect to some target by running it through several backward-pass iterations and updating the input at each step. The first pass runs successfully, but on the second pass I get the following error:

RuntimeError: element 0 of variables tuple is volatile

This code snippet demonstrates the problem:

import torch
from torch.autograd import Variable
import torch.nn as nn

inp = Variable(torch.Tensor([1]), requires_grad=True)
target = Variable(torch.Tensor([3]))

loss_fn = nn.MSELoss()

for i in range(2):
    loss = loss_fn(inp, target)
    loss.backward()
    gradient = inp.grad
    inp = inp - inp.grad * 0.01

When I inspect the value of inp before it is reassigned on the last line, inp.volatile => False and inp.requires_grad => True, but after it is reassigned those switch to True and False, respectively. Why does being a volatile variable prevent the second backprop run?
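
Here is roughly how I'm checking those flags inside the loop (the commented values are what I described above):

loss = loss_fn(inp, target)
loss.backward()
print(inp.volatile, inp.requires_grad)   # False, True before the reassignment
inp = inp - inp.grad * 0.01
print(inp.volatile, inp.requires_grad)   # True, False after the reassignment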

1 Answer
男人必须洒脱 · answered 2019-05-10 19:10

You need to zero out the accumulated gradient before each new backward pass, like this:

inp.grad.data.zero_()

But in your code, every time you update the input you create another Variable object, so you would have to zero the gradients for the entire history, like this:

import torch
from torch.autograd import Variable
import torch.nn as nn

inp_hist = []
inp = Variable(torch.Tensor([1]), requires_grad=True)
target = Variable(torch.Tensor([3]))

loss_fn = nn.MSELoss()

for i in range(2):
    loss = loss_fn(inp, target)
    loss.backward()
    gradient = inp.grad
    inp_hist.append(inp)
    inp = inp - inp.grad * 0.01
    # zero the accumulated gradients of every Variable created so far
    # (use a separate loop variable so it doesn't shadow the updated inp)
    for past_inp in inp_hist:
        past_inp.grad.data.zero_()

But this way you will compute the gradient for all the previous inputs you have created in the history (which is wasteful); a correct implementation looks like this:

import torch
from torch.autograd import Variable
import torch.nn as nn
inp = Variable(torch.Tensor([1]), requires_grad=True)
target = Variable(torch.Tensor([3]))

loss_fn = nn.MSELoss()

for i in range(2):
    loss = loss_fn(inp, target)
    loss.backward()
    gradient = inp.grad
    # update the underlying data in place so inp stays the same leaf Variable
    inp.data = inp.data - inp.grad.data * 0.01
    # clear the accumulated gradient before the next backward pass
    inp.grad.data.zero_()
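
As a side note, if you are on a newer PyTorch (0.4+), Variable has been merged into Tensor and volatile is gone; a rough sketch of the same update loop under that newer API would be:

import torch
import torch.nn as nn

inp = torch.tensor([1.0], requires_grad=True)
target = torch.tensor([3.0])
loss_fn = nn.MSELoss()

for i in range(2):
    loss = loss_fn(inp, target)
    loss.backward()
    with torch.no_grad():       # update the leaf tensor in place without recording the op
        inp -= inp.grad * 0.01
    inp.grad.zero_()            # clear the accumulated gradient before the next pass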