You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

训练PyTorch LSTM多分类模型时触发“CUDA error: device-side assert triggered”错误的求助

训练PyTorch LSTM多分类模型时触发“CUDA error: device-side assert triggered”错误的求助

你遇到的这个CUDA断言错误在PyTorch多分类任务里很常见,结合你的描述和代码,我帮你梳理几个最可能的原因和排查方向:

1. 标签索引不匹配(最可能的核心原因)

你的任务是9分类,原始输出类别是1-9,但PyTorch的CrossEntropyLoss要求标签必须是从0开始的连续索引(也就是0-8),而不是1-9。如果你的验证集标签里还存在9这个值,模型输出的logits维度是9(对应0-8类),此时标签9会超出索引范围,直接触发CUDA设备端的断言错误——这也是"训练正常但验证出错"的典型场景(比如训练集标签已经转成0-8,但验证集没同步处理)。

你可以先快速验证这个点:

# 打印验证集标签的最小值和最大值,确认是否超出0-8范围
print("Val labels min:", torch.min(your_val_labels_tensor))
print("Val labels max:", torch.max(your_val_labels_tensor))

如果最大值是9,你需要把所有标签统一减1,转成0-8的索引后再输入模型。另外你的代码里计算类别权重时用了classes=np.array([0,1,2,3,4,5,6,7,8]),这也侧面说明你应该把原始标签1-9转成0-8的索引格式。

2. 验证集输入的预处理不一致

你提到用collate_fn截断文本到1500长度,但要确认验证集的DataLoader是否和训练集用了同一个collate_fn。如果验证集的文本没有被截断,导致输入序列过长,可能会触发显存相关的断言错误(不过你的错误是在print(outputs)时触发,更偏向标签问题,但这个点也可以排查)。

3. 模型输入的维度/形状不匹配

虽然你检查了输入类型是long,但可以再确认验证集的输入维度是否和训练集完全一致:

  • 比如gene_inputvariance_input的形状:你的模型里对这两个输入用了squeeze(1),如果验证集的这两个输入维度和训练集不一样(比如训练集是(batch_size,1),验证集是(batch_size,)),squeeze后会导致特征维度错误,进而在拼接时出问题,触发断言。

你可以在验证循环里临时添加打印代码:

print("Text input shape:", text_input.shape)
print("Gene input shape:", gene_input.shape)
print("Variance input shape:", variance_input.shape)

对比训练集的输入形状,确保两者完全一致。

4. 启用CUDA DSA查看更详细的错误信息

你已经设置了TORCH_USE_CUDA_DSA=1,但需要重新运行代码才能看到更详细的设备端断言信息——它会明确告诉你是哪个断言失败了(比如"标签索引超出范围"、"特征维度不匹配"等),这能帮你一步定位问题根源。

快速修复尝试

先把验证集的标签统一减1转成0-8的索引格式,然后重新运行验证阶段。如果错误消失,那就是标签索引的问题;如果还有错误,再结合CUDA DSA输出的详细信息进一步排查。

另外,你用的RTX3050 4G显存,截断文本到1500是合理的,也可以尝试把batch_size再调小一点,避免显存波动导致的隐性问题。

备注:内容来源于stack exchange,提问作者atharva mishra

火山引擎 最新活动