You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何统计PyTorch模型的总参数数量(类似Keras的model.count_params())

在PyTorch中统计模型总参数数量(类似Keras的model.count_params()

嘿,我来教你怎么在PyTorch里实现和Keras中model.count_params()一样的功能——统计模型的总参数数量。其实方法很直观,主要有两种常用方式,咱们一步步来:

方法一:手动遍历参数统计(最基础可靠)

PyTorch里所有模型的参数都存在model.parameters()迭代器里,我们只需要遍历每个参数张量,计算它的元素个数,再累加起来就行。

示例代码:

首先先定义一个简单的测试模型:

import torch
import torch.nn as nn

class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)  # 10输入→20输出,参数是10*20 +20(偏置)=220
        self.fc2 = nn.Linear(20, 5)   # 20*5 +5=105
        self.dropout = nn.Dropout(0.5) # Dropout层没有可学习参数
        
model = SimpleMLP()

然后统计总参数数量(完全对应Keras的count_params()):

total_params = sum(p.numel() for p in model.parameters())
print(f"模型总参数数量: {total_params}")  # 输出结果为220+105=325

这里的p.numel()是PyTorch张量的内置方法,用来获取张量里的元素总数,我们用sum()把所有参数的元素数加起来,就得到了总参数。

如果需要区分可训练参数(比如有些层被冻结,requires_grad=False),可以加个判断条件:

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"可训练参数数量: {trainable_params}")

方法二:用torchsummary库一键查看(更便捷)

如果你想同时看到每层的输出形状、参数占比等额外信息,可以用torchsummary库,它会自动帮你统计总参数和可训练参数。

步骤:

  1. 先安装库:
pip install torchsummary
  1. 查看模型完整信息:
from torchsummary import summary

# 注意input_size要和你的模型输入形状匹配,这里我们的模型输入是10维向量
summary(model, input_size=(10,))

运行后会输出一个清晰的表格,最后一行就是Total paramsTrainable params,结果和手动统计完全一致。


内容的提问来源于stack exchange,提问作者Fábio Perez

火山引擎 最新活动