如何在 PyTorch 中乘以矩阵?

有了 numpy,我可以做这样一个简单的矩阵乘法:

a = numpy.ones((3, 2))
b = numpy.ones((2, 1))
result = a.dot(b)

然而,这并不适用于 PyTorch:

a = torch.ones((3, 2))
b = torch.ones((2, 1))
result = torch.dot(a, b)

此代码引发以下错误:

运行时错误: 一维张量预期,但得到2D 和2D 张量

在《火炬》里我该怎么表演矩阵乘法?

140352 次浏览

使用 torch.mm:

torch.mm(a, b)

torch.dot()的行为与 np.dot()不同。有一些关于什么是可取的 给你的讨论。具体来说,torch.dot()ab都视为一维矢量(不论其原始形状如何) ,并计算它们的内积。之所以抛出这个错误,是因为这个行为使您的 a为长度为6的向量,而您的 b为长度为2的向量; 因此它们的内积无法计算。对于 PyTorch 中的矩阵乘法,使用 torch.mm()。相比之下,Numpy 的 np.dot()更灵活,它计算一维数组的内积,并执行二维数组的矩阵乘法。

如果两个参数都是 2D,则 torch.matmul 执行矩阵乘法,如果两个参数都是 1D,则计算它们的点乘。对于这种尺寸的输入,其行为与 np.dot相同。它还允许您分批执行广播或 matrix x matrixmatrix x vectorvector x vector操作。

# 1D inputs, same as torch.dot
a = torch.rand(n)
b = torch.rand(n)
torch.matmul(a, b) # torch.Size([])


# 2D inputs, same as torch.mm
a = torch.rand(m, k)
b = torch.rand(k, j)
torch.matmul(a, b) # torch.Size([m, j])

要执行矩阵(秩2张量)乘法,请使用下列任何一种等效方法:

AB = A.mm(B)


AB = torch.mm(A, B)


AB = torch.matmul(A, B)


AB = A @ B  # Python 3.5+ only

有一些微妙之处。来自 PyTorch 文档:

torch.mm不广播。对于广播矩阵产品, 见 torch.matmul()

例如,您不能将两个一维向量与 torch.mm相乘,也不能将批矩阵相乘(秩3)。为此,您应该使用通用性更强的 torch.matmul。有关 torch.matmul广播行为的详细列表,请参阅 文件

对于元素式的乘法,您可以简单地进行(如果 A 和 B 具有相同的形状)

A * B  # element-wise matrix multiplication (Hadamard product)

使用 torch.mm(a, b)torch.matmul(a, b)
两者都是一样的。

>>> torch.mm
<built-in method mm of type object at 0x11712a870>
>>> torch.matmul
<built-in method matmul of type object at 0x11712a870>

还有一个选择。 这是 @接线员。

>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> a@b
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
[ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.mm(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
[ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.matmul(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
[ 0.8699, -0.3445,  1.4122, -0.5826]])

这三种方法得出的结果是一样的。

相关链接:
矩阵乘法操作员
PEP 465——专门用于矩阵乘法的 infix 操作符

你可以使用“@”来计算两个张量之间的点积。

a = torch.tensor([[1,2],
[3,4]])
b = torch.tensor([[5,6],
[7,8]])
c = a@b #For dot product
c


d = a*b #For elementwise multiplication
d