PyTorch中SWA训练模型时Loss突变为NaN的问题求助
解决PyTorch SWA训练时Loss突变为NaN的问题
从你的描述和代码来看,Loss突然爆炸到NaN的核心问题主要出在SWA的错误使用,同时你的Self Attention模块也存在潜在的数值稳定性bug,下面逐一分析并给出修复方案:
1. 修正SWA的使用逻辑(最关键的问题)
你当前在每个epoch的train和val阶段结束后都调用了optimizer.swap_swa_sgd(),这完全违背了SWA的正确使用流程:
swap_swa_sgd()是用来切换模型权重到SWA平均权重的操作,应该只在训练结束后、需要评估SWA模型性能时调用,而不是训练过程中频繁切换。- SWA的核心是在
swa_startepoch之后,定期调用optimizer.update_swa()来累积权重,而不是用swap。
修正后的训练函数
def train_model_withSWA(net,dataloaders_dict,criterion,optimizer,num_epochs): loss_list=[] acc_list=[] val_loss_list=[] val_acc_list=[] for epoch in tqdm(range(num_epochs)): print("Epoch{}/{}".format(epoch+1,num_epochs)) print("--------------------------") for phase in ["train","val"]: if phase=="train": net.train() # 只有训练阶段才考虑SWA的权重累积 if epoch >= optimizer.swa_start: # 每swa_freq个epoch更新一次SWA权重 if (epoch - optimizer.swa_start) % optimizer.swa_freq == 0: optimizer.update_swa() else: net.eval() epoch_loss=0.0 epoch_corrects=0 for inputs,labels in dataloaders_dict[phase]: optimizer.zero_grad() with torch.set_grad_enabled(phase=="train"): inputs=inputs.to(device) labels=labels.to(device) outputs=net(inputs) loss=criterion(outputs,labels) print("loss:",loss) _,preds=torch.max(outputs,1) if phase == "train": loss.backward() # 添加梯度裁剪,防止梯度爆炸 torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0) optimizer.step() epoch_loss += loss.item()*inputs.size(0) epoch_corrects +=torch.sum(preds==labels.data) epoch_loss=epoch_loss/len(dataloaders_dict[phase].dataset) epoch_acc=epoch_corrects.double()/len(dataloaders_dict[phase].dataset) print("{} Loss:{:.4f} Acc:{:.4f}".format(phase,epoch_loss,epoch_acc)) if phase=="train": loss_list.append(epoch_loss) acc_list.append(epoch_acc.item()) else: val_loss_list.append(epoch_loss) val_acc_list.append(epoch_acc.item()) # 训练结束后,更新最后一次SWA权重并切换到SWA模型 optimizer.update_swa() optimizer.swap_swa_sgd() # 此时可以评估SWA模型的性能 print("SWA model evaluation completed.")
2. 修复Self Attention模块的bug与数值稳定性
你的Self Attention代码存在两个严重问题:
- 注释掉了
attention_map_T = self.softmax(S),导致attention_map使用未定义的变量,这会直接引发运行错误(你说不用SWA时正常,可能是粘贴代码时的失误)。 - 没有对注意力分数进行缩放,点积结果可能过大导致softmax饱和,进而引发数值爆炸。
修正后的Self Attention模块
class Self_Attention(nn.Module): """ Self-Attention Layer""" def __init__(self, in_dim): super(Self_Attention, self).__init__() self.query_conv = nn.Conv2d( in_channels=in_dim, out_channels=in_dim, kernel_size=1) self.key_conv = nn.Conv2d( in_channels=in_dim, out_channels=in_dim, kernel_size=1) self.value_conv = nn.Conv2d( in_channels=in_dim, out_channels=in_dim, kernel_size=1) self.softmax = nn.Softmax(dim=-2) self.gamma = nn.Parameter(torch.zeros(1)) def forward(self, x): X = x.to(device) B, C, W, H = X.shape N = W * H proj_query = self.query_conv(X).view(B, -1, N).permute(0, 2, 1) # [B, N, C] proj_key = self.key_conv(X).view(B, -1, N) # [B, C, N] # 添加缩放因子,防止点积结果过大 scale = math.sqrt(C) S = torch.bmm(proj_query, proj_key) / scale # [B, N, N] # 恢复softmax操作 attention_map_T = self.softmax(S) attention_map = attention_map_T.permute(0, 2, 1) proj_value = self.value_conv(X).view(B, -1, N) # [B, C, N] o = torch.bmm(proj_value, attention_map.permute(0, 2, 1)) # [B, C, N] o = o.view(B, C, W, H) out = x + self.gamma * o return out, attention_map
3. 其他辅助优化建议
- 降低学习率:你当前base lr设为0.1,对于带注意力的模型来说可能过高,建议先尝试
base_opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9),swa_lr设为0.005。 - 检查输入数据:确保训练数据中没有NaN/inf值,虽然不用SWA时正常,但SWA的权重累积可能放大潜在的数值问题。
- 启用梯度检查:可以在训练初期添加
torch.autograd.detect_anomaly()来定位梯度爆炸的具体位置:with torch.autograd.detect_anomaly(): loss.backward()
内容的提问来源于stack exchange,提问作者yohei




