How to get the index of a specific value in a PyTorch tensor

Posted 2020-02-08 14:49

Question:

For a Python list, we can use list.index(somevalue). How can this be done in PyTorch?
For example:

    a=[1,2,3]
    print(a.index(2))

This prints 1. How can I do the same with a PyTorch tensor without converting it to a Python list?

Answer 1:

I don't think there is a direct PyTorch equivalent of list.index(). However, you can get a similar result by comparing the tensor with the value (tensor == number) and then calling nonzero() on the resulting boolean tensor. For example:

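    import torch

    t = torch.Tensor([1, 2, 3])
    print((t == 2).nonzero())

This piece of code returns

    1
    [torch.LongTensor of size 1x1]

In current PyTorch versions the same result prints as tensor([[1]]): one row per match, holding that match's index.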
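To get behaviour closer to list.index() (a plain int for the first match, and an error when the value is missing), the nonzero() approach can be wrapped in a small function. This is a minimal sketch assuming a 1-D tensor; tensor_index is a hypothetical helper name, not a PyTorch API.

```python
import torch


def tensor_index(t: torch.Tensor, value) -> int:
    """Hypothetical helper mimicking list.index for a 1-D tensor:
    return the position of the first element equal to value,
    or raise ValueError if no element matches."""
    matches = (t == value).nonzero()  # shape (num_matches, 1) for a 1-D tensor
    if matches.numel() == 0:
        raise ValueError(f"{value} is not in tensor")
    return matches[0, 0].item()  # first match, as a Python int


t = torch.tensor([1, 2, 3, 2])
print(tensor_index(t, 2))  # 1
```

Unlike list.index(), the comparison t == value scans the whole tensor, so all matches are computed even though only the first is returned.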



Answer 2:

This can be done by converting the tensor to NumPy as follows:

    import numpy as np
    import torch

    x = torch.arange(1.0, 5.0)  # replaces the deprecated torch.range(1, 4)
    print(x)
    # tensor([1., 2., 3., 4.])
    nx = x.numpy()
    print(np.where(nx == 3)[0][0])
    # 2
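The NumPy round trip is not strictly necessary: in reasonably recent PyTorch versions, torch.where called with only a condition behaves like np.where and returns a tuple of index tensors, one per dimension. A minimal sketch of the same lookup kept entirely in PyTorch:

```python
import torch

x = torch.arange(1.0, 5.0)  # tensor([1., 2., 3., 4.]), the same values as x above

# Single-argument torch.where returns a tuple of index tensors,
# one per dimension, just like np.where
idx = torch.where(x == 3)[0][0].item()
print(idx)  # 2
```

This also avoids x.numpy(), which fails for tensors on a GPU or tensors that require grad.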


Answer 3:

For floating-point tensors, I use the following to get the index of an element in the tensor:

    print((torch.abs(torch.max(your_tensor).item() - your_tensor) < 0.0001).nonzero())

Here I get the index of the maximum value in a float tensor. You can also substitute your own value, as below, to get the index of any element in the tensor.

    print((torch.abs(YOUR_VALUE - your_tensor) < 0.0001).nonzero())
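The same ideas can be expressed with built-in functions, swapped in here in place of the manual abs-difference: torch.argmax returns the index of the maximum directly, and torch.isclose with rtol=0 performs a purely absolute-tolerance comparison. A sketch assuming a 1-D float tensor; the 1e-4 tolerance mirrors the 0.0001 above, though isclose tests |a - b| <= atol rather than a strict <.

```python
import torch

t = torch.tensor([0.1, 0.7, 0.3, 0.2])

# Index of the maximum element directly (a flattened index for n-D tensors)
max_idx = torch.argmax(t).item()

# Indices of elements within an absolute tolerance of a target value;
# rtol=0.0 makes the test |t - target| <= atol
target = 0.3
close = torch.isclose(t, torch.full_like(t, target), rtol=0.0, atol=1e-4)
close_idx = close.nonzero().flatten().tolist()

print(max_idx)    # 1
print(close_idx)  # [2]
```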


Answer 4:

    import torch

    # Variable wrappers are no longer needed (PyTorch 0.4+); a plain tensor works
    x_data = torch.tensor([[1.0], [2.0], [3.0]])
    print(x_data[0])
    # tensor([1.])
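Answer 4 goes the other way, reading the value at a known position. For a 2-D tensor like x_data, the (tensor == value).nonzero() approach from Answer 1 returns one row of coordinates per match, and those coordinates can be used to index back into the tensor. A minimal sketch, assuming the value 2.0 is the one being searched for:

```python
import torch

x_data = torch.tensor([[1.0], [2.0], [3.0]])  # shape (3, 1)

# For an n-D tensor, nonzero() returns one row of coordinates per match
coords = (x_data == 2.0).nonzero()
print(coords)  # tensor([[1, 0]])

# Indexing with those coordinates recovers the value
row, col = coords[0].tolist()
print(x_data[row, col])  # tensor(2.)
```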