You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch训练时RuntimeError:Double与Float数据类型不匹配问题排查

解决PyTorch卷积层Double/Float类型不匹配错误

这个RuntimeError是PyTorch里非常常见的类型不匹配问题——说白了就是你的模型权重和输入数据的张量类型对不上:卷积层期望权重是Double(float64)类型,但你喂进去的输入是Float(float32),或者反过来。下面我一步步帮你排查和解决:

第一步:定位问题根源

先搞清楚到底是输入数据的类型不对,还是模型参数的类型不对,加两行调试代码就能搞定:

  1. 检查输入数据的类型:
# 从训练数据加载器取一个样本
sample_data, _ = next(iter(train_loader))
print("输入图像的 dtype:", sample_data.dtype)
  1. 检查模型参数的类型:
# 取模型第一个可训练参数(比如第一个卷积层的权重)
print("模型参数的 dtype:", next(model.parameters()).dtype)

对比这两个输出,如果一个是torch.float32,另一个是torch.float64,那就是问题所在了。

常见触发场景及修复方案

场景1:输入数据是Double,模型是Float(最常见)

如果你的数据加载代码里,不小心把图像转成了float64(比如用了numpy.array的默认类型,或者手动调用了torch.double()),而PyTorch的nn.Module默认参数是float32

修复方法:在数据加载的transform里统一转成float32:

# 比如你的transform列表里加这一项
transforms.Compose([
    # 其他transform...
    transforms.Lambda(lambda x: x.float())  # 把图像转成float32
])

或者在训练循环里拿到数据后直接转:

data = data.float()

场景2:模型是Double,输入数据是Float

如果之前不小心调用了model.double(),或者模型里的自定义层手动创建了float64的张量,导致整个模型参数变成了Double类型。

修复方法:把模型转成float32:

model = CircleNet().float()  # 初始化后直接转类型

或者在训练前调用:

model.float()

场景3:自定义模型里的手动张量类型不匹配

如果你的CircleNet里有手动创建的张量(比如用torch.tensor()初始化的固定参数,而不是nn.Parameter()),要确保这些张量的类型和模型其他参数一致:

# 不要直接写torch.tensor([1.0, 2.0]),而是用模型的dtype
custom_tensor = torch.tensor([1.0, 2.0], dtype=model.dtype)

额外注意事项

  • 关于CrossEntropyLossAdam优化器:这两个组件会自动适配模型参数的类型,只要输入数据和模型参数类型一致,它们就不会出问题。
  • 标签的类型:CrossEntropyLoss要求标签是torch.long类型,这个不用和输入/模型参数的类型一致,只要你的标签是整数类型就行,别转成float/double。

快速验证

修改完之后,再跑一遍之前的调试代码,确认输入数据和模型参数的dtype完全一致,然后重新训练,这个RuntimeError应该就消失了。

内容的提问来源于stack exchange,提问作者Michael Smith

火山引擎 最新活动