What does the gather function do in pytorch in layman terms?

What does torch.gather do? This answer is hard to understand.

63988 次浏览

torch.gather函数(或 torch.Tensor.gather)是一种多索引选择方法:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#        [ 4,  3]])

让我们从不同参数的语义开始: 第一个参数 input是我们要从中选择元素的源张量。第二个,dim,是我们想要沿着它收集的维度(或张量流/数字中的轴)。最后,以 index作为 input的索引。 至于操作的语义,官方文件是这样解释的:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

让我们来看看这个例子。

输入张量是 [[1, 2], [3, 4]],dim 参数是 1,也就是说,我们要从第二维收集。第二维的指数分别为 [0, 0][1, 0]

As we "skip" the first dimension (the dimension we want to collect along is 1), the first dimension of the result is implicitly given as the first dimension of the index. That means that the indices hold the second dimension, or the column indices, but not the row indices. Those are given by the indices of the index tensor itself. 例如,这意味着输出将在其第一行中包含 input张量第一行的元素的选择,就像 index张量第一行的第一行所给出的那样。由于列索引是由 [0, 0]给出的,因此我们两次选择输入第一行的第一个元素,结果是 [1, 1]。类似地,结果的第二行的元素是 input张量的第二行的元素对 index张量的第二行进行索引的结果,结果是 [4, 3]

为了进一步说明这一点,让我们交换示例中的维度:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  2],
#        [ 3,  2]])

正如您所看到的,索引现在沿着第一个维度收集。

对于你提到的例子,

current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))

gather将通过批处理操作列表对 q 值的行(即一批 q 值中的每个样本 q 值)进行索引。结果将与您执行以下操作相同(尽管它比循环快得多) :

q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
q_vals.append(qv[ac])
q_vals = torch.cat(q_vals, dim=0)

torch.gather从输入张量中创建一个新的张量,方法是获取沿输入维度 dim的每一行的值。torch.LongTensor中的值作为 index传递,指定从每个“行”获取哪个值。输出张量的尺寸与索引张量的尺寸相同。以下是官方文件中的说明,更加清楚地说明了这一点: Pictoral representation from the docs

(Note: In the illustration, indexing starts from 1 and not 0).

在第一个示例中,给定的维度是沿行(从上到下)的,因此对于 result的(1,1)位置,它从 index获取 src的行值,即 1。在(1,1)的源值是 1,因此,输出 result在(1,1)的 1。 类似地,对于(2,2) ,来自 src索引的行值是 3。在(3,2)处,src的值是 8,因此输出 8等。

类似地,对于第二个例子,索引是沿着列的,因此在 result的(2,2)位置,来自 src的索引的列值是 3,所以在 src的(2,3)处,取 6,并在(2,2)处输出到 result

@ Ritesh 和@cleros 给出了很好的回答(很多的赞成票) ,但是读完之后我还是有点困惑,我知道为什么。这篇文章可能会帮助像我这样的人。

对于这些行和列的练习,我认为使用非正方形对象有助于 真的,所以让我们从使用 source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])的更大的4x3 source(torch.Size([4, 3]))开始。这会给我们

\\ This is the source tensor
tensor([[ 1,  2,  3],
[ 4,  5,  6],
[ 7,  8,  9],
[10, 11, 12]])

现在让我们开始沿着列索引(dim=1)并创建 index = torch.tensor([[0,0],[1,1],[2,2],[0,1]]),它是一个列表列表。下面是 钥匙: 因为我们的维度是列,并且源代码有 4行,所以 index必须包含 4列表!我们需要每一行的列表。运行 source.gather(dim=1, index=index)会给我们

tensor([[ 1,  1],
[ 5,  5],
[ 9,  9],
[10, 11]])

So, each list within index gives us the columns from which to pull the values. The 1st list of the index ([0,0]) is telling us to take to look at the 1st row of the source and take the 1st column of that row (it's zero-indexed) twice, which is [1,1]. The 2nd list of the index ([1,1]) is telling us to take to look at the 2nd row of source and take the 2nd column of that row twice, which is [5,5]. Jumping to the 4th list of the index (index0), which is asking us to look at the 4th and final row of the source, is asking us to take the 1st column (index2) and then the 2nd column (index3) which gives us index4.

这里有一个巧妙的事情: 你的 index的每个列表必须是相同的长度,但他们可能是你喜欢的长度!例如,对于 index = torch.tensor([[0,1,2,1,0],[2,1,0,1,2],[1,2,0,2,1],[1,0,2,0,1]])source.gather(dim=1, index=index)会给我们

tensor([[ 1,  2,  3,  2,  1],
[ 6,  5,  4,  5,  6],
[ 8,  9,  7,  9,  8],
[11, 10, 12, 10, 11]])

输出将始终具有与 source相同的行数,但是列数将等于 index中每个列表的长度。例如,index的第2个列表([2,1,0,1,2])将分别到 source的第2行和第3、第2、第1、第2和第3个项目,即 [6,5,4,5,6]。注意,index中每个元素的值必须小于 source的列数(在本例中为 3) ,否则将得到一个 out of bounds错误。

切换到 dim=0,我们现在将使用行而不是列。使用相同的 source,我们现在需要一个 index,其中每个列表的长度等于 source中的列数。为什么?因为当我们逐列移动时,列表中的每个元素都表示来自 source的行。

因此,index = torch.tensor([[0,0,0],[0,1,2],[1,2,3],[3,2,0]])会让 source.gather(dim=0, index=index)给我们

tensor([[ 1,  2,  3],
[ 1,  5,  9],
[ 4,  8, 12],
[10,  8,  3]])

查看 index([0,0,0])中的第一个列表,我们可以看到我们正在移动 source的3个列,选择每个列的第一个元素(它的索引为零) ,即 [1,2,3]index([0,1,2])中的第2个列表告诉我们在各列之间移动,分别获取第1、第2和第3个项目,即 [1,5,9]。诸如此类。

对于 dim=1,我们的 index必须有一定数量的列表,这些列表相当于 source中的行数,但是每个列表可以是长的,也可以是短的,随你喜欢。对于 dim=0,我们的 index中的每个列表的长度必须与 source中的列数相同,但是我们现在可以拥有任意多的列表。然而,index中的每个值都需要小于 source中的行数(在本例中为 4)。

例如,index = torch.tensor([[0,0,0],[1,1,1],[2,2,2],[3,3,3],[0,1,2],[1,2,3],[3,2,0]])会让 source.gather(dim=0, index=index)给我们

tensor([[ 1,  2,  3],
[ 4,  5,  6],
[ 7,  8,  9],
[10, 11, 12],
[ 1,  5,  9],
[ 4,  8, 12],
[10,  8,  3]])

With dim=1 the output always has the same number of rows as the source, although the number of columns will equal the length of the lists in index. The number of lists in index has to equal the number of rows in source. Each value in index, however, needs to be less than the number of columns in source.

对于 dim=0,输出总是具有与 source相同的列数,但是行数将等于 index中的列数。index中每个列表的长度必须等于 source中的列数。但是,index中的每个值都需要小于 source中的行数。

这就是二维世界,超越二维世界会遵循同样的模式。

这是基于@Ritesh 回答(感谢@Ritesh!)和一些真正的代码。

torch.gather API 是

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

例子一

dim = 0,

enter image description here

dim = 0
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]]


output = torch.gather(input, dim, index))
# tensor([[10, 14, 18],
#         [13, 17, 12]])

例子2

dim = 1,

enter image description here

dim = 1
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1], [1, 2], [2, 0]]


output = torch.gather(input, dim, index))
# tensor([[10, 11],
#         [14, 15],
#         [18, 16]])