This one-liner deduplicates parameters by their underlying storage address (`data_ptr()`), so a tensor shared between layers is counted only once:

sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
Here is a more detailed implementation that can optionally filter out non-trainable parameters:
import torch

def numel(m: torch.nn.Module, only_trainable: bool = False) -> int:
    """
    Returns the total number of parameters used by `m` (only counting
    shared parameters once); if `only_trainable` is True, then only
    includes parameters with `requires_grad = True`.
    """
    parameters = list(m.parameters())
    if only_trainable:
        parameters = [p for p in parameters if p.requires_grad]
    # Keying on data_ptr() collapses parameters that share the same storage,
    # so tied weights are counted exactly once.
    unique = {p.data_ptr(): p for p in parameters}.values()
    return sum(p.numel() for p in unique)
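To see the difference the data_ptr() dedup makes, here is a minimal sketch with two layers sharing one weight tensor. The tying shown below (re-wrapping the same storage in a new Parameter) is an illustrative assumption purely for counting purposes, not how you would tie weights for training:

import torch
from torch import nn

a = nn.Linear(10, 10, bias=False)
b = nn.Linear(10, 10, bias=False)
b.weight = nn.Parameter(a.weight.data)  # distinct Parameter object, same underlying storage
model = nn.Sequential(a, b)

naive = sum(p.numel() for p in model.parameters())  # 200: the shared weight is counted twice
deduped = numel(model)                              # 100: data_ptr() collapses it to one entry
print(naive, deduped)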
from torch import nn

def count_number_of_parameters(model: nn.Module, only_trainable: bool = True) -> int:
    """
    Counts the number of trainable parameters. To count all parameters
    (trainable and frozen alike), pass only_trainable=False.
    Ref:
        - https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9?u=brando_miranda
        - https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model/62764464#62764464
    """
    if only_trainable:
        num_params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
    else:  # counts both trainable and non-trainable parameters
        num_params: int = sum(p.numel() for p in model.parameters())
    assert num_params > 0, f'Err: {num_params=}'
    return int(num_params)
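A quick usage sketch (the model shape and layer sizes are arbitrary assumptions). Note that, unlike numel above, this function does not deduplicate shared parameters:

from torch import nn

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
for p in model[0].parameters():
    p.requires_grad = False  # freeze the first Linear layer

print(count_number_of_parameters(model))                        # 10 trainable: 4*2 + 2
print(count_number_of_parameters(model, only_trainable=False))  # 46 in total: (8*4 + 4) + (4*2 + 2)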