它在 nn.Module中调用 forward()吗? 我认为当我们调用模型时,使用的是 forward方法。 为什么我们需要指定 train () ?
nn.Module
forward()
forward
model.train()告诉你的模型你正在训练模型。这有助于向 Dropout 和 BatchNorm 等层提供信息,这些层在培训和评估期间的行为是不同的。例如,在培训模式中,BatchNorm 更新每个新批处理的移动平均值; 而在评估模式中,这些更新是冻结的。
model.train()
更多细节: model.train()将模式设置为训练 (见 源代码)。您可以调用 model.eval()或 model.train(mode=False)来告诉您正在测试。 期望 train函数训练模型是有点直观的,但它并没有做到这一点。它只是设置模式。
model.eval()
model.train(mode=False)
train
有两种方法让模型知道你的意图,即你想训练模型还是你想用模型来评估。 在 model.train()的情况下,模型知道它必须学习层,当我们使用 model.eval()时,它表明模型没有什么新的东西要学习,模型用于测试。 model.eval()也是必要的,因为在 pytch 中,如果我们使用 batchscale,并且在测试期间,如果我们只想传递一个图像,如果没有指定 model.eval(),pytch 会抛出一个错误。
下面是 nn.Module.train()的代码:
nn.Module.train()
def train(self, mode=True): r"""Sets the module in training mode.""" self.training = mode for module in self.children(): module.train(mode) return self
下面是 nn.Module.eval()的代码:
nn.Module.eval()
def eval(self): r"""Sets the module in evaluation mode.""" return self.train(False)
默认情况下,self.training标志被设置为 True,也就是说,默认情况下模块处于 train 模式。当 self.training为 False时,模块处于相反的状态,即 eval 模式。
self.training
True
False
在最常用的图层中,只有 Dropout和 BatchNorm关心这个标志。
Dropout
BatchNorm
目前的 正式文件声明如下:
这只对某些模块有效。如果它们受到影响,请参阅特定模块的文档,了解它们在培训/评估模式中的行为细节,例如辍学、 BatchNorm 等。
model.train(False)
注意: 这两个函数调用都不向前/向后运行,它们告诉模型 怎么做执行 什么时候运行。
这一点很重要,因为 一些模块(层)(例如 Dropout,BatchNorm)在训练和推理期间的表现是不同的,因此如果运行在错误的模式下,模型将产生意想不到的结果。
考虑以下模型
import torch import torch.nn.functional as F from torch_geometric.nn import GCNConv class GraphNet(torch.nn.Module): def __init__(self, num_node_features, num_classes): super(GraphNet, self).__init__() self.conv1 = GCNConv(num_node_features, 16) self.conv2 = GCNConv(16, num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.dropout(x, training=self.training) #Look here x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)
在这里,dropout的功能在不同的操作模式下是不同的。正如您所看到的,它只在 self.training==True。因此,当你输入 model.train()时,模型的转发函数将执行辍学,否则它不会(比如当 model.eval()或 model.train(mode=False))。
dropout
self.training==True