How to check if two Torch tensors or matrices are

Posted 2020-06-06 17:14

I need a Torch command that checks if two tensors have the same content, and returns TRUE if they have the same content.

For example:

local tens_a = torch.Tensor({9,8,7,6});
local tens_b = torch.Tensor({9,8,7,6});

if (tens_a EQUIVALENCE_COMMAND tens_b) then ... end

What should I use in this script instead of EQUIVALENCE_COMMAND ?

I tried simply with == but it does not work.

Tags: lua torch
5 answers
成全新的幸福
#2 · 2020-06-06 17:57

To compare tensors element-wise, use torch.eq (the examples in this answer use PyTorch's Python API):

torch.eq(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
# tensor([[True, False], [False, True]])

Or use torch.equal, which returns a single boolean that is True only if both tensors have the same size and elements:

torch.equal(torch.tensor([[1., 2.], [3, 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
# False
torch.equal(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.], [3., 4.]]))
# True

However, floating-point results often differ by tiny amounts that you would like to ignore. For instance, 1.0 and 1.0000000001 are close enough that you may want to treat them as equal. For that kind of comparison, use torch.allclose, which compares within a tolerance:

torch.allclose(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.000000001], [3., 4.]]))
# True
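How close is "close enough" depends on the tolerances passed to torch.allclose. The following is a minimal sketch, assuming PyTorch and float64 tensors (the values of a and b are made up for illustration), showing that the same difference can pass with the default tolerances and fail with stricter ones:

```python
import torch

# Illustrative values: the second element differs by about 1e-6.
a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
b = torch.tensor([1.0, 2.0 + 1e-6, 3.0], dtype=torch.float64)

# Default tolerances are rtol=1e-05 and atol=1e-08.
default_close = torch.allclose(a, b)

# With no relative tolerance and a small absolute one,
# the same 1e-6 difference counts as a mismatch.
strict_close = torch.allclose(a, b, rtol=0.0, atol=1e-9)

print(default_close, strict_close)
```

The check applied per element is |a - b| <= atol + rtol * |b|, so choose rtol and atol to match the precision you actually expect from your computation.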

It can also be useful to check element-wise how many elements are equal, relative to the total number of elements. Given two tensors dt1 and dt2, dt1.nelement() returns the number of elements in dt1.

The following expression gives the fraction of matching elements (multiply by 100 for a percentage):

print(torch.sum(torch.eq(dt1, dt2)).item()/dt1.nelement())
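As a small worked example of that expression, assuming PyTorch and two made-up tensors dt1 and dt2 that differ in one position out of four:

```python
import torch

# Illustrative tensors: only the third element differs.
dt1 = torch.tensor([9.0, 8.0, 7.0, 6.0])
dt2 = torch.tensor([9.0, 8.0, 0.0, 6.0])

# Count the positions where the two tensors agree.
matches = torch.eq(dt1, dt2).sum().item()

# Divide by the total number of elements to get the matching fraction.
fraction = matches / dt1.nelement()

print(f"{matches} of {dt1.nelement()} elements match ({fraction:.0%})")
```

This assumes both tensors have the same shape; if the shapes differ but are broadcastable, torch.eq will broadcast them and the fraction is no longer meaningful.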
我命由我不由天
#3 · 2020-06-06 18:00

You can convert the two tensors to numpy arrays:

import numpy as np
import torch

# Python (PyTorch) syntax: no "local" keyword or trailing semicolons.
tens_a = torch.Tensor((9, 8, 7, 6))
tens_b = torch.Tensor((9, 8, 7, 6))

a = tens_a.numpy()
b = tens_b.numpy()

and then something like

np.sum(a == b)
# 4

would give you a fairly good idea of how equal they are.
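If a single yes/no answer is wanted rather than a count, NumPy's np.array_equal can be used on the converted arrays. A minimal sketch, assuming PyTorch CPU tensors (.numpy() does not work directly on GPU tensors or tensors that require gradients):

```python
import numpy as np
import torch

tens_a = torch.tensor([9.0, 8.0, 7.0, 6.0])
tens_b = torch.tensor([9.0, 8.0, 7.0, 6.0])

# Convert to NumPy arrays (these share memory with the CPU tensors).
a = tens_a.numpy()
b = tens_b.numpy()

# True only if the shapes match and every element is equal.
same = np.array_equal(a, b)

# Number of positions where the arrays agree.
n_equal = int(np.sum(a == b))

print(same, n_equal)
```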

地球回转人心会变
#4 · 2020-06-06 18:03

The solution below worked for me:

torch.equal(tensorA, tensorB)

From the documentation:

True if two tensors have the same size and elements, False otherwise.

萌系小妹纸
#5 · 2020-06-06 18:04

https://github.com/torch/torch7/blob/master/doc/maths.md#torcheqa-b

torch.eq(a, b)

Implements == operator comparing each element in a with b (if b is a number) or each element in a with corresponding element in b.

UPDATE

from @deltheil

torch.all(torch.eq(tens_a, tens_b))

or even simpler

torch.all(tens_a.eq(tens_b))
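The answer above refers to Lua Torch7. In PyTorch (Python) the same expression also works, with one caveat: torch.eq broadcasts its arguments, so the all-of-eq check can report True for tensors of different sizes, while torch.equal also requires the sizes to match. A minimal sketch, assuming PyTorch and made-up tensors named row and grid:

```python
import torch

# Illustrative tensors with different shapes: (2,) and (3, 2).
row = torch.tensor([1.0, 1.0])
grid = torch.ones(3, 2)

# torch.eq broadcasts row against every row of grid, so all elements match.
elementwise_all = torch.all(torch.eq(row, grid)).item()

# torch.equal compares sizes as well as elements, so it reports a mismatch.
same_tensor = torch.equal(row, grid)

print(elementwise_all, same_tensor)
```

If the tensors are expected to have identical shapes, torch.equal is the safer choice in PyTorch.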
兄弟一词,经得起流年.
#6 · 2020-06-06 18:12

Try this if you want to ignore small precision differences, which are common with floats:

torch.all(torch.lt(torch.abs(torch.add(tens_a, -tens_b)), 1e-12))
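The same absolute-tolerance idea can be wrapped in a small function. The sketch below assumes PyTorch; within_abs_tol is a hypothetical helper name, not part of any library, and the tolerance values are only illustrative:

```python
import torch


def within_abs_tol(a, b, tol):
    """Hypothetical helper: True if every |a - b| is strictly below tol."""
    return bool(torch.all(torch.lt(torch.abs(a - b), tol)))


# Illustrative float64 tensors that differ by about 1e-9 everywhere.
x = torch.tensor([1.0, 2.0], dtype=torch.float64)
y = x + 1e-9

loose = within_abs_tol(x, y, 1e-6)   # difference is below 1e-6
tight = within_abs_tol(x, y, 1e-12)  # difference is above 1e-12

print(loose, tight)
```

A threshold as small as 1e-12 is only meaningful for double-precision values of moderate size; for float32 tensors, whose precision near 1.0 is roughly 1e-7, it behaves almost like an exact comparison.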