PyTorch 1.9.1中nn.Parameter未被保存至model.state_dict的问题咨询
看起来你遇到的这个问题在PyTorch 1.x早期版本里其实挺常见的,我之前也碰到过类似的情况,核心原因是你对nn.Parameter的设备转换方式在1.9.1里会导致参数没有被正确注册到模型的参数列表中。
问题根源
当你执行self.class_token = nn.Parameter(torch.rand(1, self.hidden_d)).to(self.device)的时候,.to(self.device)会返回一个新的张量对象——而原来通过nn.Parameter创建的那个参数其实并没有被赋值给self.class_token,你实际赋值的是这个脱离了nn.Parameter注册流程的新张量。这就导致模型的state_dict里不会包含这个参数,因为它根本没被识别成模型的可训练参数。
解决方法
这里给你几个可行的方案,适配PyTorch 1.9.1的环境:
方案一:先注册参数,再整体移动模型到设备
把参数初始化和设备转换分开操作,优先完成参数注册:# 先初始化参数,此时参数在CPU上,但已被正确注册为模型参数 self.class_token = nn.Parameter(torch.rand(1, self.hidden_d)) # 在__init__方法末尾或外部调用时,整体移动模型到目标设备 self.to(self.device)这样模型的所有参数(包括
class_token)都会被正确移动到目标设备,同时保持注册状态,state_dict自然会包含它。方案二:直接在目标设备上创建参数
如果你需要单独处理这个参数,可以直接在目标设备上生成张量再包装成nn.Parameter:self.class_token = nn.Parameter(torch.rand(1, self.hidden_d, device=self.device))这种写法不需要后续的
.to()操作,参数从一开始就处于正确的设备,并且被模型正确注册。方案三:手动重新赋值转换后的参数
如果你因为某些原因必须先创建参数再转换设备,一定要把转换后的结果重新赋值给self.class_token:self.class_token = nn.Parameter(torch.rand(1, self.hidden_d)) # 重新赋值,确保模型注册的是转换后的参数 self.class_token = self.class_token.to(self.device)这样新的设备上的张量会被重新注册为模型的参数,从而被
state_dict捕获。
验证方式
你可以在保存模型前执行print(list(model.state_dict().keys())),查看输出结果里是否包含class_token这个键,如果有就说明参数已经被正确注册了。
关于版本差异的说明
你同事用新版本PyTorch没问题,是因为PyTorch 2.0及之后对.to()的行为做了优化,调整了参数注册的逻辑,可能直接修改了原参数的设备而不是返回新对象,所以他的写法能正常工作。但在1.9.1里必须遵循上述的正确写法。
备注:内容来源于stack exchange,提问作者gay-victorian-astronomer




