PyTorch 首选的复制张量的方法

似乎有几种方法可以在 PyTorch 中创建张量的副本,包括

y = tensor.new_tensor(x) #a


y = x.clone().detach() #b


y = torch.empty_like(x).copy_(x) #c


y = torch.tensor(x) #d

根据执行 ad时得到的用户警告,显式优先于 ad。为什么偏爱它?表演?我认为它不太好读。

使用 c的原因是什么?

115476 次浏览

Pytch’1.1.0’现在推荐 # b,并显示了 # d 的警告

根据 Pytorch 文档 # a 和 # b 是等价的,它还说

推荐使用 clone ()和 detach ()的等价物。

因此,如果你想复制张量,并从计算图中分离出来,你应该使用

y = x.clone().detach()

因为它是最干净和最可读的方式。在所有其他版本中都存在一些隐藏的逻辑,而且也不是100% 清楚计算图和梯度传播发生了什么。

关于 # c: 对于实际执行的操作似乎有点复杂,还可能引入一些开销,但我不确定这一点。

编辑: 既然在评论中有人问为什么不直接使用 .clone()

来自 Pytorch 医生

与 copy _ ()不同,这个函数记录在计算图中。传播到克隆张量的梯度将传播到原始张量。

因此,当 .clone()返回数据的副本时,它保留计算图并在其中记录克隆操作。如前所述,这将导致梯度传播到克隆张量也传播到原始张量。此行为可能导致错误,且不明显。由于这些可能的副作用,只有在明确需要这种行为时,才应该通过 .clone()克隆张量。为了避免这些副作用,增加了 .detach()以断开计算图与克隆张量之间的连接。

因为一般来说,对于复制操作,一个人需要一个干净的副本,不能导致不可预见的副作用,首选的方式复制张量是 .clone().detach()

DR

使用 .clone().detach()(最好是 .detach().clone())

如果您首先分离张量,然后克隆它,计算路径不会被复制,反之,计算路径会被复制,然后被放弃。因此,.detach().clone()的效率要略高一些。—— 火炬论坛

因为它的动作有点快,而且很明确。


使用 perflot,我绘制了各种方法复制一个热点张量的时间。

y = tensor.new_tensor(x) # method a


y = x.clone().detach() # method b


y = torch.empty_like(x).copy_(x) # method c


y = torch.tensor(x) # method d


y = x.detach().clone() # method e

X 轴是张量的尺寸,y 轴表示时间。这个图是线性比例的。正如您可以清楚地看到的,与其他三种方法相比,tensor()new_tensor()需要更多的时间。

enter image description here

注意: 在多次运行中,我注意到在 b、 c、 e 中,任何方法的时间都可能最短。对于 a 和 d 也是如此,但是方法 b,c,e 一直比 a 和 d 具有更低的时间。

import torch
import perfplot


perfplot.show(
setup=lambda n: torch.randn(n),
kernels=[
lambda a: a.new_tensor(a),
lambda a: a.clone().detach(),
lambda a: torch.empty_like(a).copy_(a),
lambda a: torch.tensor(a),
lambda a: a.detach().clone(),
],
labels=["new_tensor()", "clone().detach()", "empty_like().copy()", "tensor()", "detach().clone()"],
n_range=[2 ** k for k in range(15)],
xlabel="len(a)",
logx=False,
logy=False,
title='Timing comparison for copying a pytorch tensor',
)

一个检查张量是否被复制的例子:

import torch
def samestorage(x,y):
if x.storage().data_ptr()==y.storage().data_ptr():
print("same storage")
else:
print("different storage")
a = torch.ones((1,2), requires_grad=True)
print(a)
b = a
c = a.data
d = a.detach()
e = a.data.clone()
f = a.clone()
g = a.detach().clone()
i = torch.empty_like(a).copy_(a)
j = torch.tensor(a) # UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).




print("a:",end='');samestorage(a,a)
print("b:",end='');samestorage(a,b)
print("c:",end='');samestorage(a,c)
print("d:",end='');samestorage(a,d)
print("e:",end='');samestorage(a,e)
print("f:",end='');samestorage(a,f)
print("g:",end='');samestorage(a,g)
print("i:",end='');samestorage(a,i)


退出:

tensor([[1., 1.]], requires_grad=True)
a:same storage
b:same storage
c:same storage
d:same storage
e:different storage
f:different storage
g:different storage
i:different storage
j:different storage

如果 不同的储藏室出现,张量就是 复制。 PyTorch 有近100种不同的构造函数,因此您可以添加更多的方法。

如果我需要复制一个张量,我会使用 copy(),这也复制了广告相关的信息,所以如果我需要删除广告相关的信息,我会使用:

y = x.clone().detach()