You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch保存GPU模型Checkpoint时报错RuntimeError: unsupported Storage type的解决方案求助

解决PyTorch从GPU保存模型时的RuntimeError: unsupported Storage type问题

我看到你在使用PyTorch 1.9.0+cu102时遇到了一个棘手的问题:模型在CUDA上训练一切正常,但调用torch.save(model.state_dict())时抛出RuntimeError: unsupported Storage type,而CPU环境下保存完全没问题,甚至用简单模型测试也复现了相同错误。结合你的环境和测试结果,这个问题基本是旧版本PyTorch的CUDA序列化bug导致的,下面给你几个靠谱的解决方案:

方案1:临时修复——将参数移到CPU后保存

这是最快解决问题的办法,不需要改动环境,只需要修改保存代码,把state_dict里的所有张量转到CPU再序列化:

# 替换你原来的模型保存代码
torch.save(
    {key: value.cpu() for key, value in model.state_dict().items()},
    model_path
)
# 优化器和调度器的state_dict也做同样处理
torch.save(
    {key: value.cpu() for key, value in optimizer.state_dict().items()},
    optimizer_path
)
torch.save(
    {key: value.cpu() for key, value in scheduler.state_dict().items()},
    scheduler_path
)

加载的时候,你可以直接指定目标设备,无需额外处理:

# 直接加载到GPU
model.load_state_dict(torch.load(model_path, map_location='cuda'))
# 或者先加载到CPU再移到GPU
model.load_state_dict(torch.load(model_path))
model.to('cuda')

方案2:彻底解决——升级PyTorch版本

PyTorch 1.9.0是2021年发布的旧版本,这个CUDA存储序列化的bug在后续的1.10及以上版本已经被官方修复了。如果你的环境允许升级,这是最彻底的解决方式。

用conda升级(适配CUDA环境)

根据你的CUDA版本选择对应命令,若打算升级到稳定的CUDA 11.8版本,可以执行:

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

如果想继续使用CUDA 10.2,可以查找适配的PyTorch版本(不过cu102已停止官方支持,建议同步升级CUDA到11.x以上)。

用pip升级

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

升级完成后,你就能用原来的保存代码直接保存GPU上的模型state_dict,不会再触发该错误。

额外检查项

如果上述方案都无法解决问题,可以排查以下两点:

  • 模型中是否使用了第三方库的特殊张量类型(比如自定义存储格式)
  • 确认CUDA驱动版本和PyTorch的CUDA版本是否兼容(例如1.9+cu102需要驱动版本≥440.33)

内容的提问来源于stack exchange,提问作者System1c

火山引擎 最新活动