You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何在PyTorch中按名称提取模型的特定参数张量

在PyTorch中按名称提取特定参数张量的方法

嘿,这事儿其实挺简单的,给你两种常用的解决方案,按需选就行:

方法1:直接通过子模块属性访问(最便捷)

因为你的模型里的fc1是Linear子模块,PyTorch会把这些子模块作为模型对象的属性直接暴露出来,所以你可以直接通过属性链获取目标参数:

# 获取fc1的权重张量
fc1_weight = myModel.fc1.weight
# 获取fc1的偏置张量
fc1_bias = myModel.fc1.bias

如果想确认参数的具体内容,直接打印或者加上.data查看即可,比如print(fc1_weight)就能看到张量的具体数值。

方法2:使用named_parameters()遍历查找(更灵活)

如果你的模型结构复杂,或者需要批量筛选参数,named_parameters()方法会返回所有参数的名称-张量对,你可以遍历它们并筛选出目标参数:

target_params = {}
for name, param in myModel.named_parameters():
    if name in ['fc1.weight', 'fc1.bias']:
        target_params[name] = param

# 之后可通过名称直接访问目标参数
fc1_weight = target_params['fc1.weight']
fc1_bias = target_params['fc1.bias']

要是不确定参数的准确名称,先运行这段代码打印所有参数名称确认一下:

for name, param in myModel.named_parameters():
    print(name)

你的模型会输出fc1.weightfc1.biasfc2.weightfc2.bias这些完整命名,方便你精准筛选。


内容的提问来源于stack exchange,提问作者Amin Kaveh

火山引擎 最新活动