What's the difference between torch.stack() an

Posted 2019-06-17 05:45

Question:

OpenAI's REINFORCE and actor-critic examples for reinforcement learning have the following code:

REINFORCE:

policy_loss = torch.cat(policy_loss).sum()

actor-critic:

loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

One uses torch.cat, the other uses torch.stack.

As far as I can tell, the documentation does not clearly distinguish between them.

I would be happy to know the differences between the functions.

Answer 1:

stack

Concatenates sequence of tensors along a new dimension.

cat

Concatenates the given sequence of seq tensors in the given dimension.

So if A and B are of shape (3, 4), torch.cat([A, B], dim=0) will be of shape (6, 4) and torch.stack([A, B], dim=0) will be of shape (2, 3, 4).
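The shapes above can be checked with a short sketch. The tensors A and B and all variable names here are illustrative only. The last part shows one plausible reason the two examples differ, which is an assumption on my part rather than something the source confirms: torch.cat needs tensors with at least one dimension, so a list of 0-dim scalar losses has to go through torch.stack.

```python
import torch

# Two illustrative tensors of shape (3, 4)
A = torch.zeros(3, 4)
B = torch.ones(3, 4)

# cat joins along an existing dimension: (3, 4) + (3, 4) -> (6, 4)
cat_result = torch.cat([A, B], dim=0)

# stack inserts a new dimension: two (3, 4) tensors -> (2, 3, 4)
stack_result = torch.stack([A, B], dim=0)

print(cat_result.shape)
print(stack_result.shape)

# stack gives the same result as unsqueezing each tensor at `dim`, then cat
via_cat = torch.cat([A.unsqueeze(0), B.unsqueeze(0)], dim=0)
same = torch.equal(stack_result, via_cat)

# 0-dim (scalar) tensors can be stacked into a 1-D tensor ...
scalars = [torch.tensor(1.0), torch.tensor(2.0)]
stacked_scalars = torch.stack(scalars)  # shape (2,)

# ... but cat rejects them, since there is no existing dimension to join on
try:
    torch.cat(scalars)
    cat_scalars_ok = True
except RuntimeError:
    cat_scalars_ok = False
```

So when the list holds tensors of shape (1,), cat(...).sum() and stack(...).sum() give the same number, while for 0-dim losses only stack works.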