如何统计PyTorch模型的总参数数量(类似Keras的model.count_params())
在PyTorch中统计模型总参数数量(类似Keras的
model.count_params()) 嘿,我来教你怎么在PyTorch里实现和Keras中model.count_params()一样的功能——统计模型的总参数数量。其实方法很直观,主要有两种常用方式,咱们一步步来:
方法一:手动遍历参数统计(最基础可靠)
PyTorch里所有模型的参数都存在model.parameters()迭代器里,我们只需要遍历每个参数张量,计算它的元素个数,再累加起来就行。
示例代码:
首先先定义一个简单的测试模型:
import torch import torch.nn as nn class SimpleMLP(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(10, 20) # 10输入→20输出,参数是10*20 +20(偏置)=220 self.fc2 = nn.Linear(20, 5) # 20*5 +5=105 self.dropout = nn.Dropout(0.5) # Dropout层没有可学习参数 model = SimpleMLP()
然后统计总参数数量(完全对应Keras的count_params()):
total_params = sum(p.numel() for p in model.parameters()) print(f"模型总参数数量: {total_params}") # 输出结果为220+105=325
这里的p.numel()是PyTorch张量的内置方法,用来获取张量里的元素总数,我们用sum()把所有参数的元素数加起来,就得到了总参数。
如果需要区分可训练参数(比如有些层被冻结,requires_grad=False),可以加个判断条件:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"可训练参数数量: {trainable_params}")
方法二:用torchsummary库一键查看(更便捷)
如果你想同时看到每层的输出形状、参数占比等额外信息,可以用torchsummary库,它会自动帮你统计总参数和可训练参数。
步骤:
- 先安装库:
pip install torchsummary
- 查看模型完整信息:
from torchsummary import summary # 注意input_size要和你的模型输入形状匹配,这里我们的模型输入是10维向量 summary(model, input_size=(10,))
运行后会输出一个清晰的表格,最后一行就是Total params和Trainable params,结果和手动统计完全一致。
内容的提问来源于stack exchange,提问作者Fábio Perez




