PyTorch模型首轮输出正常后续批次出现NaN问题求助
首先,从你的代码和输出来看,第一个批次正常、第二个批次开始输出NaN的核心原因大概率是半精度(float16)训练的数值稳定性问题,再加上输入数据未做归一化导致的数值溢出。下面是具体的问题分析和修复步骤:
一、代码中的明显问题
1. 输入数据未做归一化
你的输入是70x70x3的图像,如果原始像素值是0-255的范围,直接输入模型的话,经过第一个全连接层(9800维度→100维度)的计算后,数值会非常大。而float16的数值范围远小于float32(最大约65504),很容易触发数值溢出变成inf,后续经过ReLU或其他运算后就会变成NaN。
2. 纯半精度训练缺乏数值稳定性保障
你直接将模型和数据转为half(),但没有使用PyTorch专门的混合精度训练工具(torch.cuda.amp)。纯半精度训练时,梯度计算和参数更新过程中很容易出现溢出或精度丢失,进而产生NaN。
3. 冗余的类型转换
在损失计算中,你对已经是float16的model_out再次调用half():
loss = loss_fn(model_out.half(), torch.flatten(values))
这虽然不会直接导致NaN,但属于冗余操作,建议移除。
二、修复步骤
1. 先做数据归一化
将输入图像的像素值归一化到[0,1]或[-1,1]范围,这是计算机视觉模型训练的基础操作:
# 假设X_train是0-255的uint8数组 X_train = X_train / 255.0 # 归一化到0-1 # 或者用均值方差归一化(如果有统计值的话) # X_train = (X_train - mean) / std
2. 改用混合精度训练
使用torch.cuda.amp来处理半精度训练,它会自动在合适的环节使用float32保存梯度和计算,避免数值溢出:
修改你的训练代码如下:
loss_fn = nn.MSELoss() dev = torch.device('cuda') optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4) losses = [] max_batches = 2 # 初始化混合精度的GradScaler scaler = torch.cuda.amp.GradScaler() def process_batch(): inputs = images.float().to(dev) # 先转float32,由amp自动处理半精度 values = scores.float().to(dev) optimizer.zero_grad() # 使用混合精度上下文管理器 with torch.cuda.amp.autocast(): outputs = model(inputs) model_out = torch.flatten(outputs) print(f"Outputs: {model_out}") loss = loss_fn(model_out, torch.flatten(values)) losses.append(loss.item()) # 用scaler反向传播,自动处理梯度缩放 scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() # 模型不需要手动转half(),amp会自动处理 model.to(torch.device('cuda')) model.train() i = 0 for images, scores in train_loader: process_batch() i += 1 if i > max_batches: break
3. 可选:调整权重初始化
对于ReLU激活的全连接层,He初始化比默认的Xavier初始化更合适,可以减少数值波动:
修改FullyConnected类的初始化:
layers.append(nn.Linear(in_channels, out_channels, bias=True)) # 添加He初始化 nn.init.kaiming_normal_(layers[-1].weight, mode='fan_in', nonlinearity='relu')
三、PyTorch模型调试最佳实践
1. 检查输入与目标数据
- 打印输入数据的最大值、最小值,确认是否有异常值(比如
inf、NaN,或数值范围过大):print(f"Input max: {images.max()}, min: {images.min()}, has nan: {torch.isnan(images).any()}") print(f"Target max: {scores.max()}, min: {scores.min()}, has nan: {torch.isnan(scores).any()}")
2. 监控中间层输出
在模型的forward方法中,添加对各层输出的数值检查,看哪一步开始出现NaN或inf:
def forward(self, x): x = torch.flatten(x, 1) x = self.fc1(x) print(f"fc1 output: max={x.max()}, min={x.min()}, has_nan={torch.isnan(x).any()}") x = self.fc2(x) print(f"fc2 output: max={x.max()}, min={x.min()}, has_nan={torch.isnan(x).any()}") x = self.last(x) x = torch.flatten(x) return x
3. 检查梯度状态
在反向传播后,打印各参数的梯度,确认是否有inf或NaN:
# 在loss.backward()之后添加(如果用混合精度则在scaler.scale(loss).backward()之后) for name, param in model.named_parameters(): if param.grad is not None: print(f"{name} grad: max={param.grad.max()}, min={param.grad.min()}, has_nan={torch.isnan(param.grad).any()}") else: print(f"{name} grad is None")
4. 先禁用半精度测试
暂时将模型和数据都改为float32,如果问题消失,说明确实是半精度的数值稳定性问题,此时再用混合精度工具来解决。
5. 调整学习率与优化器
如果梯度爆炸是原因,尝试降低学习率(比如从1e-3降到1e-4),或者更换优化器(比如SGD代替Adam,Adam在半精度下更容易出现数值问题)。
6. 使用PyTorch的调试工具
torch.autograd.detect_anomaly():自动检测反向传播中的NaN/inf来源,使用方法:with torch.autograd.detect_anomaly(): loss.backward()
内容的提问来源于stack exchange,提问作者GeneC




