You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch模型首轮输出正常后续批次出现NaN问题求助

问题诊断与修复建议

首先,从你的代码和输出来看,第一个批次正常、第二个批次开始输出NaN的核心原因大概率是半精度(float16)训练的数值稳定性问题,再加上输入数据未做归一化导致的数值溢出。下面是具体的问题分析和修复步骤:

一、代码中的明显问题

1. 输入数据未做归一化

你的输入是70x70x3的图像,如果原始像素值是0-255的范围,直接输入模型的话,经过第一个全连接层(9800维度→100维度)的计算后,数值会非常大。而float16的数值范围远小于float32(最大约65504),很容易触发数值溢出变成inf,后续经过ReLU或其他运算后就会变成NaN

2. 纯半精度训练缺乏数值稳定性保障

你直接将模型和数据转为half(),但没有使用PyTorch专门的混合精度训练工具(torch.cuda.amp)。纯半精度训练时,梯度计算和参数更新过程中很容易出现溢出或精度丢失,进而产生NaN。

3. 冗余的类型转换

在损失计算中,你对已经是float16的model_out再次调用half()

loss = loss_fn(model_out.half(), torch.flatten(values))

这虽然不会直接导致NaN,但属于冗余操作,建议移除。

二、修复步骤

1. 先做数据归一化

将输入图像的像素值归一化到[0,1]或[-1,1]范围,这是计算机视觉模型训练的基础操作:

# 假设X_train是0-255的uint8数组
X_train = X_train / 255.0  # 归一化到0-1
# 或者用均值方差归一化(如果有统计值的话)
# X_train = (X_train - mean) / std

2. 改用混合精度训练

使用torch.cuda.amp来处理半精度训练,它会自动在合适的环节使用float32保存梯度和计算,避免数值溢出:
修改你的训练代码如下:

loss_fn = nn.MSELoss()
dev = torch.device('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
losses = []
max_batches = 2

# 初始化混合精度的GradScaler
scaler = torch.cuda.amp.GradScaler()

def process_batch():
    inputs = images.float().to(dev)  # 先转float32,由amp自动处理半精度
    values = scores.float().to(dev)
    optimizer.zero_grad()
    
    # 使用混合精度上下文管理器
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        model_out = torch.flatten(outputs)
        print(f"Outputs: {model_out}")
        loss = loss_fn(model_out, torch.flatten(values))
    
    losses.append(loss.item())
    # 用scaler反向传播,自动处理梯度缩放
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

# 模型不需要手动转half(),amp会自动处理
model.to(torch.device('cuda'))
model.train()
i = 0
for images, scores in train_loader:
    process_batch()
    i += 1
    if i > max_batches:
        break

3. 可选:调整权重初始化

对于ReLU激活的全连接层,He初始化比默认的Xavier初始化更合适,可以减少数值波动:
修改FullyConnected类的初始化:

layers.append(nn.Linear(in_channels, out_channels, bias=True))
# 添加He初始化
nn.init.kaiming_normal_(layers[-1].weight, mode='fan_in', nonlinearity='relu')

三、PyTorch模型调试最佳实践

1. 检查输入与目标数据

  • 打印输入数据的最大值、最小值,确认是否有异常值(比如infNaN,或数值范围过大):
    print(f"Input max: {images.max()}, min: {images.min()}, has nan: {torch.isnan(images).any()}")
    print(f"Target max: {scores.max()}, min: {scores.min()}, has nan: {torch.isnan(scores).any()}")
    

2. 监控中间层输出

在模型的forward方法中,添加对各层输出的数值检查,看哪一步开始出现NaNinf

def forward(self, x):
    x = torch.flatten(x, 1)
    x = self.fc1(x)
    print(f"fc1 output: max={x.max()}, min={x.min()}, has_nan={torch.isnan(x).any()}")
    x = self.fc2(x)
    print(f"fc2 output: max={x.max()}, min={x.min()}, has_nan={torch.isnan(x).any()}")
    x = self.last(x)
    x = torch.flatten(x)
    return x

3. 检查梯度状态

在反向传播后,打印各参数的梯度,确认是否有infNaN

# 在loss.backward()之后添加(如果用混合精度则在scaler.scale(loss).backward()之后)
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name} grad: max={param.grad.max()}, min={param.grad.min()}, has_nan={torch.isnan(param.grad).any()}")
    else:
        print(f"{name} grad is None")

4. 先禁用半精度测试

暂时将模型和数据都改为float32,如果问题消失,说明确实是半精度的数值稳定性问题,此时再用混合精度工具来解决。

5. 调整学习率与优化器

如果梯度爆炸是原因,尝试降低学习率(比如从1e-3降到1e-4),或者更换优化器(比如SGD代替Adam,Adam在半精度下更容易出现数值问题)。

6. 使用PyTorch的调试工具

  • torch.autograd.detect_anomaly():自动检测反向传播中的NaN/inf来源,使用方法:
    with torch.autograd.detect_anomaly():
        loss.backward()
    

内容的提问来源于stack exchange,提问作者GeneC

火山引擎 最新活动