You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch带权重衰减的LeNet5模型保存加载报错问题解决

问题原因分析与解决方案

首先要明确一个关键事实:优化器的配置(包括weight_decay)不会影响模型的state_dict——模型的state_dict只存储自身可学习的参数(卷积层权重、偏置,全连接层参数等),和优化器的设置完全无关。你遇到的加载错误,本质是加载的模型权重文件对应的模型结构,和当前使用的LeNet5_mode0结构不一致

为什么会出现这些错误键?

看报错信息里的"conv1.1.weight""conv1.1.running_mean"等参数,这些是BatchNorm2d层的专属参数,但你现在提供的LeNet5_mode0模型中,conv1这个Sequential里的第1个元素(索引1)是ReLU层,而ReLU根本没有这些参数。这说明:你之前保存权重时使用的模型,在conv1conv2Sequential里包含了BatchNorm2d层,而现在的模型移除了这些BN层,导致键不匹配。你可能是在修改优化器配置的同时,不小心调整了模型结构,才误以为是优化器的问题。

修复方案

根据你的需求,有两种常见的解决方式:

方式1:快速兼容加载(忽略不匹配的参数)

如果只需要加载当前模型存在的参数,可以在load_state_dict时添加strict=False参数,它会自动忽略当前模型中不存在的键:

model = LeNet5_mode0()
model.load_state_dict(torch.load(loadpath), strict=False)

这种方式简单快捷,但会直接丢弃旧权重里的BN层参数,当前模型中新增的参数(如果有的话)会保留初始化值。

方式2:精准过滤匹配的参数(更严谨)

如果你想手动控制只加载当前模型存在的参数,可以先过滤旧的state_dict

model = LeNet5_mode0()
old_state_dict = torch.load(loadpath)
# 只保留当前模型拥有的键
filtered_state = {k: v for k, v in old_state_dict.items() if k in model.state_dict()}
# 加载过滤后的权重
model.load_state_dict(filtered_state, strict=False)

后续建议

为了避免类似问题,建议:

  • 保存模型时,同时记录对应的模型结构代码(比如用注释、单独的模型定义文件);
  • 如果需要保存训练的完整状态(包括优化器、调度器),可以把它们一起存入一个字典:
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict()
}, savepath)

加载时再分别取出,但前提是模型结构必须完全一致。

内容的提问来源于stack exchange,提问作者Nirit

火山引擎 最新活动