Check the total number of parameters in a PyTorch model

How do I count the total number of parameters in a PyTorch model? Something similar to Keras' model.count_params().


PyTorch doesn't have a utility function like Keras' model.count_params() for this, but you can sum the number of elements over every parameter group:

pytorch_total_params = sum(p.numel() for p in model.parameters())

If you want to count only the trainable parameters:

pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
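
As a quick sanity check, here is a minimal sketch on a toy model (the architecture is made up purely for illustration):

import torch.nn as nn

# Toy model: (10*5 + 5) + (5*2 + 2) = 67 parameters
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))

total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total, trainable)  # 67 67, since nothing is frozen here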

Answer inspired by this answer on the PyTorch Forums.

Note: I'm answering my own question. If anyone has a better solution, please share it with us.

If you want to calculate the number of weights and biases in each layer without instantiating the model, you can simply load the raw file and iterate over the resulting collections.OrderedDict like so:

import torch


tensor_dict = torch.load('model.dat', map_location='cpu')  # collections.OrderedDict
tensor_list = list(tensor_dict.items())
for layer_tensor_name, tensor in tensor_list:
    print('{}: {}'.format(layer_tensor_name, torch.numel(tensor)))

You will get:

conv1.weight: 312
conv1.bias: 26
batch_norm1.weight: 26
batch_norm1.bias: 26
batch_norm1.running_mean: 26
batch_norm1.running_var: 26
conv2.weight: 2340
conv2.bias: 10
batch_norm2.weight: 10
batch_norm2.bias: 10
batch_norm2.running_mean: 10
batch_norm2.running_var: 10
fcs.layers.0.weight: 135200
fcs.layers.0.bias: 260
fcs.layers.1.weight: 33800
fcs.layers.1.bias: 130
fcs.batch_norm_layers.0.weight: 260
fcs.batch_norm_layers.0.bias: 260
fcs.batch_norm_layers.0.running_mean: 260
fcs.batch_norm_layers.0.running_var: 260
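
Note that this counts buffers such as running_mean and running_var too, which model.parameters() would not include. If the saved file is a full training checkpoint rather than a bare state dict (a common pattern; the 'state_dict' key below is an assumption about how the checkpoint was written), extract the inner dict first and sum over its values:

checkpoint = torch.load('model.dat', map_location='cpu')
# Fall back to the loaded object itself if it is already a bare state dict
tensor_dict = checkpoint.get('state_dict', checkpoint)
print(sum(torch.numel(t) for t in tensor_dict.values()))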

Here is another possible solution:

def model_summary(model):
    # NOTE: assumes a flat model whose direct children all carry parameters
    # (a weight, optionally a bias)
    print("model_summary")
    print()
    print("Layer_name" + "\t" * 7 + "Number of Parameters")
    print("=" * 100)
    model_parameters = [layer for layer in model.parameters() if layer.requires_grad]
    layer_name = [child for child in model.children()]
    j = 0
    total_params = 0
    print("\t" * 10)
    for i in layer_name:
        print()
        param = 0
        try:
            bias = (i.bias is not None)
        except AttributeError:
            bias = False
        if bias:
            # Layer has a bias: consume both the weight and the bias entries
            param = model_parameters[j].numel() + model_parameters[j + 1].numel()
            j = j + 2
        else:
            param = model_parameters[j].numel()
            j = j + 1
        print(str(i) + "\t" * 3 + str(param))
        total_params += param
    print("=" * 100)
    print(f"Total Params:{total_params}")


model_summary(net)

This will produce a result similar to the following:

model_summary


Layer_name                          Number of Parameters
====================================================================================================


Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))             60
Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))            880
Linear(in_features=576, out_features=120, bias=True)        69240
Linear(in_features=120, out_features=84, bias=True)         10164
Linear(in_features=84, out_features=10, bias=True)          850
====================================================================================================
Total Params:81194
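
For reference, net is not shown in this answer; a LeNet-style definition consistent with the parameter counts above would look like this (activations are omitted because model.children() must yield only parameterized layers here, and a real forward pass would also need pooling/flatten steps):

import torch.nn as nn

# Layer shapes inferred from the parameter counts printed above
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=3),   # 1*6*3*3 + 6   = 60
    nn.Conv2d(6, 16, kernel_size=3),  # 6*16*3*3 + 16 = 880
    nn.Linear(576, 120),              # 576*120 + 120 = 69240
    nn.Linear(120, 84),               # 120*84 + 84   = 10164
    nn.Linear(84, 10),                # 84*10 + 10    = 850
)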

To get the per-layer parameter counts like Keras, PyTorch has model.named_parameters(), which returns an iterator over both the parameter name and the parameter itself. For example:

from prettytable import PrettyTable


def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params


count_parameters(net)

Example output:

+-------------------+------------+
|      Modules      | Parameters |
+-------------------+------------+
| embeddings.weight |   922866   |
|    conv1.weight   |  1048576   |
|     conv1.bias    |    1024    |
|     bn1.weight    |    1024    |
|      bn1.bias     |    1024    |
|    conv2.weight   |  2097152   |
|     conv2.bias    |    1024    |
|     bn2.weight    |    1024    |
|      bn2.bias     |    1024    |
|    conv3.weight   |  2097152   |
|     conv3.bias    |    1024    |
|     bn3.weight    |    1024    |
|      bn3.bias     |    1024    |
|    lin1.weight    |  50331648  |
|     lin1.bias     |    512     |
|    lin2.weight    |   265728   |
|     lin2.bias     |    519     |
+-------------------+------------+
Total Trainable Params: 56773369

To avoid double counting shared parameters, use torch.Tensor.data_ptr. E.g.:

sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
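
A sketch of when this matters: two distinct Parameter objects backed by the same storage (the weight-tying pattern below is hypothetical):

import torch.nn as nn

encoder = nn.Linear(10, 10, bias=False)
decoder = nn.Linear(10, 10, bias=False)
# Tie the weights: the new Parameter wraps the encoder's storage, so two
# distinct Parameter objects share one underlying buffer
decoder.weight = nn.Parameter(encoder.weight.data)
model = nn.Sequential(encoder, decoder)

print(sum(p.numel() for p in model.parameters()))  # 200: the naive sum double counts
print(sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()))  # 100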

Here is a more verbose implementation that can optionally filter out non-trainable parameters:

def numel(m: torch.nn.Module, only_trainable: bool = False):
    """
    Returns the total number of parameters used by `m` (only counting
    shared parameters once); if `only_trainable` is True, then only
    includes parameters with `requires_grad = True`
    """
    parameters = list(m.parameters())
    if only_trainable:
        parameters = [p for p in parameters if p.requires_grad]
    unique = {p.data_ptr(): p for p in parameters}.values()
    return sum(p.numel() for p in unique)
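
A quick usage sketch, freezing one layer of a made-up toy model:

import torch

model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.Linear(5, 2))
for p in model[0].parameters():
    p.requires_grad = False  # freeze the first layer

print(numel(model))                       # 67: all parameters
print(numel(model, only_trainable=True))  # 12: only the unfrozen second layer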

You can use torchsummary to do the same thing; it's just two lines of code.

from torchsummary import summary


summary(model, input_shape)  # e.g. input_shape = (3, 224, 224) for an RGB image model

There is a built-in utility function that flattens an iterable of tensors into a single tensor: torch.nn.utils.parameters_to_vector. Combined with torch.numel:

torch.nn.utils.parameters_to_vector(model.parameters()).numel()

Or shorter, with a named import (from torch.nn.utils import parameters_to_vector):

parameters_to_vector(model.parameters()).numel()
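
One caveat: parameters_to_vector concatenates copies of all parameters into a single flat tensor, so it temporarily allocates memory on the order of the model size and requires all parameters to live on the same device; for very large models the sum-based one-liners above are cheaper.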

As @fabio-perez mentioned, there is no such built-in function in PyTorch.

However, I found this to be a compact and neat way of achieving the same result:

num_of_parameters = sum(map(torch.numel, model.parameters()))

Straightforward:

print(sum(p.numel() for p in model.parameters()))

A final answer you can plug in:

def count_number_of_parameters(model: nn.Module, only_trainable: bool = True) -> int:
    """
    Counts the number of trainable params. If all params, specify only_trainable = False.

    Ref:
        - https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9?u=brando_miranda
        - https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model/62764464#62764464
    :return:
    """
    if only_trainable:
        num_params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
    else:  # counts trainable and non-trainable
        num_params: int = sum(p.numel() for p in model.parameters())
    assert num_params > 0, f'Err: {num_params=}'
    return int(num_params)