You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch 1.9.1中nn.Parameter未被保存至model.state_dict的问题咨询

PyTorch 1.9.1中nn.Parameter未被保存至model.state_dict的问题咨询

看起来你遇到的这个问题在PyTorch 1.x早期版本里其实挺常见的,我之前也碰到过类似的情况,核心原因是你对nn.Parameter的设备转换方式在1.9.1里会导致参数没有被正确注册到模型的参数列表中。

问题根源

当你执行self.class_token = nn.Parameter(torch.rand(1, self.hidden_d)).to(self.device)的时候,.to(self.device)会返回一个新的张量对象——而原来通过nn.Parameter创建的那个参数其实并没有被赋值给self.class_token,你实际赋值的是这个脱离了nn.Parameter注册流程的新张量。这就导致模型的state_dict里不会包含这个参数,因为它根本没被识别成模型的可训练参数。

解决方法

这里给你几个可行的方案,适配PyTorch 1.9.1的环境:

  • 方案一:先注册参数,再整体移动模型到设备
    把参数初始化和设备转换分开操作,优先完成参数注册:

    # 先初始化参数,此时参数在CPU上,但已被正确注册为模型参数
    self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))
    # 在__init__方法末尾或外部调用时,整体移动模型到目标设备
    self.to(self.device)
    

    这样模型的所有参数(包括class_token)都会被正确移动到目标设备,同时保持注册状态,state_dict自然会包含它。

  • 方案二:直接在目标设备上创建参数
    如果你需要单独处理这个参数,可以直接在目标设备上生成张量再包装成nn.Parameter

    self.class_token = nn.Parameter(torch.rand(1, self.hidden_d, device=self.device))
    

    这种写法不需要后续的.to()操作,参数从一开始就处于正确的设备,并且被模型正确注册。

  • 方案三:手动重新赋值转换后的参数
    如果你因为某些原因必须先创建参数再转换设备,一定要把转换后的结果重新赋值给self.class_token

    self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))
    # 重新赋值,确保模型注册的是转换后的参数
    self.class_token = self.class_token.to(self.device)
    

    这样新的设备上的张量会被重新注册为模型的参数,从而被state_dict捕获。

验证方式

你可以在保存模型前执行print(list(model.state_dict().keys())),查看输出结果里是否包含class_token这个键,如果有就说明参数已经被正确注册了。

关于版本差异的说明

你同事用新版本PyTorch没问题,是因为PyTorch 2.0及之后对.to()的行为做了优化,调整了参数注册的逻辑,可能直接修改了原参数的设备而不是返回新对象,所以他的写法能正常工作。但在1.9.1里必须遵循上述的正确写法。

备注:内容来源于stack exchange,提问作者gay-victorian-astronomer

火山引擎 最新活动