Train ()在 PyTorch 中做什么?

它在 nn.Module中调用 forward()吗? 我认为当我们调用模型时,使用的是 forward方法。 为什么我们需要指定 train () ?

158526 次浏览

model.train()告诉你的模型你正在训练模型。这有助于向 Dropout 和 BatchNorm 等层提供信息,这些层在培训和评估期间的行为是不同的。例如,在培训模式中,BatchNorm 更新每个新批处理的移动平均值; 而在评估模式中,这些更新是冻结的。

更多细节: model.train()将模式设置为训练 (见 源代码)。您可以调用 model.eval()model.train(mode=False)来告诉您正在测试。 期望 train函数训练模型是有点直观的,但它并没有做到这一点。它只是设置模式。

有两种方法让模型知道你的意图,即你想训练模型还是你想用模型来评估。 在 model.train()的情况下,模型知道它必须学习层,当我们使用 model.eval()时,它表明模型没有什么新的东西要学习,模型用于测试。 model.eval()也是必要的,因为在 pytch 中,如果我们使用 batchscale,并且在测试期间,如果我们只想传递一个图像,如果没有指定 model.eval(),pytch 会抛出一个错误。

下面是 nn.Module.train()的代码:

def train(self, mode=True):
r"""Sets the module in training mode."""
self.training = mode
for module in self.children():
module.train(mode)
return self

下面是 nn.Module.eval()的代码:

def eval(self):
r"""Sets the module in evaluation mode."""
return self.train(False)

默认情况下,self.training标志被设置为 True,也就是说,默认情况下模块处于 train 模式。当 self.trainingFalse时,模块处于相反的状态,即 eval 模式。

在最常用的图层中,只有 DropoutBatchNorm关心这个标志。

目前的 正式文件声明如下:

这只对某些模块有效。如果它们受到影响,请参阅特定模块的文档,了解它们在培训/评估模式中的行为细节,例如辍学、 BatchNorm 等。

model.train() model.eval()
火车ing 模式下设置模型,即。BatchNorm层使用每批次的统计 < br > • Dropout层被激活 < a href = “ https://stackoverflow. com/questions/66534762/which-pytors- 模块-被-model-eval-and-model-train”> 等 设置模型在 Eval推理(推理)模式,即 < br > < br > • BatchNorm层使用运行统计 < br > • Dropout层停用等
相当于 model.train(False)

注意: 这两个函数调用都不向前/向后运行,它们告诉模型 怎么做执行 什么时候运行。

这一点很重要,因为 一些模块(层)(例如 DropoutBatchNorm)在训练和推理期间的表现是不同的,因此如果运行在错误的模式下,模型将产生意想不到的结果。

考虑以下模型

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class GraphNet(torch.nn.Module):
def __init__(self, num_node_features, num_classes):
super(GraphNet, self).__init__()
self.conv1 = GCNConv(num_node_features, 16)
self.conv2 = GCNConv(16, num_classes)


def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.dropout(x, training=self.training) #Look here
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)

在这里,dropout的功能在不同的操作模式下是不同的。正如您所看到的,它只在 self.training==True。因此,当你输入 model.train()时,模型的转发函数将执行辍学,否则它不会(比如当 model.eval()model.train(mode=False))。