PyTorch - parameters not changing

Posted 2019-04-14 18:05

Question:

In an effort to learn how PyTorch works, I am trying to do maximum likelihood estimation of some of the parameters of a multivariate normal distribution. However, it does not seem to work for any of the covariance-related parameters.

So my question is: why does this code not work?

import torch


def make_covariance_matrix(sigma, rho):
    return torch.tensor([[sigma[0]**2, rho * torch.prod(sigma)],
                         [rho * torch.prod(sigma), sigma[1]**2]])


mu_true = torch.randn(2)
rho_true = torch.rand(1)
sigma_true = torch.exp(torch.rand(2))

cov_true = make_covariance_matrix(sigma_true, rho_true)
dist_true = torch.distributions.MultivariateNormal(mu_true, cov_true)

samples = dist_true.sample((1_000,))

mu = torch.zeros(2, requires_grad=True)
log_sigma = torch.zeros(2, requires_grad=True)
atanh_rho = torch.zeros(1, requires_grad=True)

lbfgs = torch.optim.LBFGS([mu, log_sigma, atanh_rho])


def closure():
    lbfgs.zero_grad()
    sigma = torch.exp(log_sigma)
    rho = torch.tanh(atanh_rho)
    cov = make_covariance_matrix(sigma, rho)
    dist = torch.distributions.MultivariateNormal(mu, cov)
    loss = -torch.mean(dist.log_prob(samples))
    loss.backward()
    return loss


lbfgs.step(closure)

print("mu: {}, mu_hat: {}".format(mu_true, mu))
print("sigma: {}, sigma_hat: {}".format(sigma_true, torch.exp(log_sigma)))
print("rho: {}, rho_hat: {}".format(rho_true, torch.tanh(atanh_rho)))

output:

mu: tensor([0.4168, 0.1580]), mu_hat: tensor([0.4127, 0.1454], requires_grad=True)
sigma: tensor([1.1917, 1.7290]), sigma_hat: tensor([1., 1.], grad_fn=<ExpBackward>)
rho: tensor([0.3589]), rho_hat: tensor([0.], grad_fn=<TanhBackward>)

>>> torch.__version__
'1.0.0.dev20181127'

In other words, why have the estimates of log_sigma and atanh_rho not moved from their initial value?

Answer 1:

The way you create your covariance matrix is not backprop-able:

def make_covariance_matrix(sigma, rho):
    return torch.tensor([[sigma[0]**2, rho * torch.prod(sigma)],
                         [rho * torch.prod(sigma), sigma[1]**2]])

When you create a new tensor from other tensors with torch.tensor, only the values of the input tensors are copied. Everything else, including the autograd history, is dropped, so the new tensor has no graph connection to your parameters and backpropagation cannot pass through it.

Here is a short example to illustrate this:

import torch

param1 = torch.rand(1, requires_grad=True)
param2 = torch.rand(1, requires_grad=True)
tensor_from_params = torch.tensor([param1, param2])

print('Original parameter 1:')
print(param1, param1.requires_grad)
print('Original parameter 2:')
print(param2, param2.requires_grad)
print('New tensor from params:')
print(tensor_from_params, tensor_from_params.requires_grad)

Output:

Original parameter 1:
tensor([ 0.8913]) True
Original parameter 2:
tensor([ 0.4785]) True
New tensor from params:
tensor([ 0.8913,  0.4785]) False

As you can see, the tensor created from the parameters param1 and param2 does not keep track of the gradients of param1 and param2.
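For contrast, here is a minimal sketch of the same two parameters combined with differentiable ops instead of torch.tensor. The names joined and stacked are just for illustration; the point is that the results of torch.cat and torch.stack stay attached to the graph, so backward() reaches param1 and param2.

```python
import torch

param1 = torch.rand(1, requires_grad=True)
param2 = torch.rand(1, requires_grad=True)

# torch.cat and torch.stack are ordinary differentiable operations,
# so their results keep a grad_fn that links back to the inputs.
joined = torch.cat([param1, param2])     # shape (2,)
stacked = torch.stack([param1, param2])  # shape (2, 1)

print(joined.requires_grad, stacked.requires_grad)  # True True

# Backpropagating through the concatenated tensor fills in the
# gradients of the original parameters.
joined.sum().backward()
print(param1.grad, param2.grad)  # tensor([1.]) tensor([1.])
```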


So instead you can use this code, which keeps the graph connection and is backprop-able:

def make_covariance_matrix(sigma, rho):
    # Build every entry with differentiable ops so the autograd graph is kept.
    off_diag = rho * torch.prod(sigma)  # shape (1,)
    cov = torch.cat([(sigma[0]**2).view(-1), off_diag, off_diag, (sigma[1]**2).view(-1)])
    return cov.view(2, 2)

The values are concatenated into a flat tensor using torch.cat and then reshaped to 2x2 using view().
This produces the same matrix as your function, but it keeps the connection to your parameters log_sigma and atanh_rho.
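If you want to check that connection yourself, something along these lines should do it. This is only a sketch: cov.sum() is an arbitrary stand-in for the real loss (my assumption, not part of your model), used just to see whether gradients arrive at log_sigma and atanh_rho.

```python
import torch


def make_covariance_matrix(sigma, rho):
    # Same construction as above: torch.cat + view keep the autograd graph.
    off_diag = rho * torch.prod(sigma)
    cov = torch.cat([(sigma[0]**2).view(-1), off_diag, off_diag, (sigma[1]**2).view(-1)])
    return cov.view(2, 2)


log_sigma = torch.zeros(2, requires_grad=True)
atanh_rho = torch.zeros(1, requires_grad=True)

cov = make_covariance_matrix(torch.exp(log_sigma), torch.tanh(atanh_rho))
print(cov.requires_grad)  # True

# Stand-in scalar "loss": any differentiable function of cov works here.
cov.sum().backward()

# Both parameters now have gradients, so an optimizer can move them.
print(log_sigma.grad, atanh_rho.grad)  # tensor([2., 2.]) tensor([2.])
```

With the original torch.tensor version, cov.requires_grad is False and the backward() call fails because the result is not connected to any parameter.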

Here is the output before and after the step with the changed make_covariance_matrix. As you can see, the parameters are now optimized and their values do change:

Before:
mu: tensor([ 0.1191,  0.7215]), mu_hat: tensor([ 0.,  0.])
sigma: tensor([ 1.4222,  1.0949]), sigma_hat: tensor([ 1.,  1.])
rho: tensor([ 0.2558]), rho_hat: tensor([ 0.])

After:
mu: tensor([ 0.1191,  0.7215]), mu_hat: tensor([ 0.0712,  0.7781])
sigma: tensor([ 1.4222,  1.0949]), sigma_hat: tensor([ 1.4410,  1.0807])
rho: tensor([ 0.2558]), rho_hat: tensor([ 0.2235])

Hope this helps!