I have been through the official docs and this, but it is hard to understand what is going on.
I am trying to understand a DQN source code and it uses the gather function on line 197.
Could someone explain in simple terms what the gather function does? What is the purpose of that function?
The torch.gather function (or torch.Tensor.gather) is a multi-index selection method. Look at the following example from the official docs:
t = torch.tensor([[1, 2], [3, 4]])
r = torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
# r now holds:
# tensor([[1, 1],
#         [4, 3]])
Let's start by going through the semantics of the different arguments: the first argument, input, is the source tensor that we want to select elements from. The second, dim, is the dimension (or axis in TensorFlow/NumPy) that we want to collect along. And finally, index holds the indices used to index into input.
As for the semantics of the operation, this is how the official docs explain it:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
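These formulas can be checked directly. Below is a small sketch (the tensor values are made up for illustration) that verifies the dim == 1 rule element by element:

```python
import torch

inp = torch.tensor([[10, 20, 30],
                    [40, 50, 60]])
idx = torch.tensor([[2, 0],
                    [1, 1]])

out = torch.gather(inp, 1, idx)  # dim == 1

# Verify out[i][j] == inp[i][idx[i][j]] for every position.
for i in range(idx.size(0)):
    for j in range(idx.size(1)):
        assert out[i][j] == inp[i][idx[i][j]]

print(out)
# tensor([[30, 10],
#         [50, 50]])
```

Note that the output has the same shape as index, not as input.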
So let's go through the example. The input tensor is [[1, 2], [3, 4]], and the dim argument is 1, i.e. we want to collect from the second dimension. The indices for the second dimension are given as [0, 0] and [1, 0].
As we "skip" the first dimension (the dimension we want to collect along is 1), the first dimension of the result is implicitly given as the first dimension of index. That means the indices hold the second dimension, or the column indices, but not the row indices. Those are given by the positions within the index tensor itself.
For the example, this means that the first row of the output is a selection of elements from the input tensor's first row, as given by the first row of the index tensor. As the column indices are given by [0, 0], we select the first element of the first row of the input twice, resulting in [1, 1]. Similarly, the elements of the second row of the result come from indexing the second row of the input tensor by the elements of the second row of the index tensor, resulting in [4, 3].
To illustrate this even further, let's swap the dimension in the example:
t = torch.tensor([[1, 2], [3, 4]])
r = torch.gather(t, 0, torch.tensor([[0, 0], [1, 0]]))
# r now holds:
# tensor([[1, 2],
#         [3, 2]])
As you can see, the indices are now collected along the first dimension.
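The same per-element check from the dim == 1 case works here with the dim == 0 rule, out[i][j] = input[index[i][j]][j]; a quick sketch:

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 0]])

r = torch.gather(t, 0, idx)  # dim == 0

# Verify out[i][j] == t[idx[i][j]][j]: the index now selects rows,
# while the column comes from the position inside idx.
for i in range(2):
    for j in range(2):
        assert r[i][j] == t[idx[i][j]][j]

print(r)
# tensor([[1, 2],
#         [3, 2]])
```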
For the example you referred to,
current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))
gather will index the rows of the q-values (i.e. the per-sample q-values in a batch of q-values) by the batch-list of actions. The result will be the same as if you had done the following (though it will be much faster than a loop):
q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
    q_vals.append(qv[ac])
q_vals = torch.stack(q_vals)  # shape (N,); the gather result has shape (N, 1)
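To make this concrete, here is a self-contained sketch with a stand-in linear Q-network and a random batch; the names Q, obs_batch, and act_batch mirror the snippet above but are constructed here purely for illustration:

```python
import torch

torch.manual_seed(0)
n_actions, batch_size, obs_dim = 4, 5, 3

Q = torch.nn.Linear(obs_dim, n_actions)        # stand-in Q-network
obs_batch = torch.randn(batch_size, obs_dim)   # fake observations
act_batch = torch.randint(0, n_actions, (batch_size,))  # fake actions

# gather version: one q-value per sample, shape (batch_size, 1)
current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))

# loop version: shape (batch_size,)
q_vals = torch.stack([qv[ac] for qv, ac in zip(Q(obs_batch), act_batch)])

# Both select q-values for the taken actions; they agree up to the extra dim.
assert torch.allclose(current_Q_values.squeeze(1), q_vals)
```

The unsqueeze(1) is needed because index must have the same number of dimensions as the input: a (batch_size,) action vector becomes a (batch_size, 1) column of indices.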
torch.gather creates a new tensor from the input tensor by taking the values from each row along the input dimension dim. The values in the torch.LongTensor passed as index specify which value to take from each 'row'. The dimension of the output tensor is the same as the dimension of the index tensor. The following illustration from the official docs explains it more clearly:
(Note: In the illustration, indexing starts from 1 and not 0).
In the first example, the dimension given is along rows (top to bottom), so for position (1,1) of result, the row value taken from index for src is 1. At (1,1) the value in the source is 1, so it outputs 1 at (1,1) in result.
Similarly, for (2,2) the row value from index for src is 3. At (3,2) the value in src is 8, and hence it outputs 8, and so on.
Similarly, for the second example, indexing is along columns, and hence at position (2,2) of result, the column value from index for src is 3, so at (2,3) of src, 6 is taken and output to result at (2,2).
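Since the illustration itself is not reproduced here, the two lookups can be replayed in code. The src tensor below is an assumption consistent with the values cited above (1 at (1,1), 8 at (3,2), 6 at (2,3) in the note's 1-based indexing); PyTorch itself indexes from 0:

```python
import torch

# Assumed source tensor matching the values referenced in the walkthrough.
src = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])

# dim=0 (along rows): the index value picks the row, the position picks the
# column. 1-based (3,2) becomes 0-based src[2][1] == 8.
r0 = torch.gather(src, 0, torch.tensor([[0, 0, 0],
                                        [0, 2, 0]]))
# r0[1][1] == src[2][1] == 8

# dim=1 (along columns): the index value picks the column, the position picks
# the row. 1-based (2,3) becomes 0-based src[1][2] == 6.
r1 = torch.gather(src, 1, torch.tensor([[0, 0],
                                        [0, 2]]))
# r1[1][1] == src[1][2] == 6
```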