How to view the number of parameters in PyTorch?

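In PyTorch, model.parameters() returns every parameter tensor of a model (each weight matrix and bias vector). To count the individual parameters, add up the number of elements in each tensor with p.numel(). Calling len() on these tensors would count only the tensors, not the values inside them. Here is an example: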

import torch
import torch.nn as nn

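# Create the model: a single linear layer with 10 inputs and 5 outputs
model = nn.Linear(10, 5)

# Count parameters: sum the number of elements in every parameter tensor
num_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of model parameters: {num_parameters}")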

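The script prints the total number of scalar parameters in the model. For nn.Linear(10, 5), that is a 5 × 10 weight matrix (50 values) plus a bias vector of 5 values, so the printed count is 55.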
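The sketch below uses the same nn.Linear(10, 5) layer to show three related counts side by side. It compares len(), which counts parameter tensors, with the numel() sum, which counts individual values. It also shows a trainable-only count that keeps only tensors with requires_grad set. Freezing the bias is purely illustrative and is not part of the original example.

```python
import torch.nn as nn

# Same single linear layer as in the example: 10 inputs -> 5 outputs
model = nn.Linear(10, 5)

# len() counts parameter *tensors* (weight and bias), not individual values
num_tensors = len(list(model.parameters()))

# numel() gives the number of scalar values in each tensor; summing gives the total
total = sum(p.numel() for p in model.parameters())

# Illustration only: freeze the bias so it no longer receives gradients
model.bias.requires_grad_(False)

# Count only the parameters that will still be updated during training
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Per-tensor breakdown by name
for name, p in model.named_parameters():
    print(f"{name}: shape={tuple(p.shape)}, numel={p.numel()}, trainable={p.requires_grad}")

print(f"tensors: {num_tensors}, total: {total}, trainable: {trainable}")
```

Here len() reports 2 tensors and the numel() sum reports 55 values. Once the bias is frozen, the trainable count drops to 50. Filtering on requires_grad is useful when some layers are frozen, for example during fine-tuning.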
