You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch中SWA训练模型时Loss突变为NaN的问题求助

解决PyTorch SWA训练时Loss突变为NaN的问题

从你的描述和代码来看,Loss突然爆炸到NaN的核心问题主要出在SWA的错误使用,同时你的Self Attention模块也存在潜在的数值稳定性bug,下面逐一分析并给出修复方案:

1. 修正SWA的使用逻辑(最关键的问题)

你当前在每个epoch的trainval阶段结束后都调用了optimizer.swap_swa_sgd(),这完全违背了SWA的正确使用流程:

  • swap_swa_sgd()是用来切换模型权重到SWA平均权重的操作,应该只在训练结束后、需要评估SWA模型性能时调用,而不是训练过程中频繁切换。
  • SWA的核心是在swa_start epoch之后,定期调用optimizer.update_swa()来累积权重,而不是用swap

修正后的训练函数

def train_model_withSWA(net,dataloaders_dict,criterion,optimizer,num_epochs):
    loss_list=[]
    acc_list=[]
    val_loss_list=[]
    val_acc_list=[]
    for epoch in tqdm(range(num_epochs)):
        print("Epoch{}/{}".format(epoch+1,num_epochs))
        print("--------------------------")
        for phase in ["train","val"]:
            if phase=="train":
                net.train()
                # 只有训练阶段才考虑SWA的权重累积
                if epoch >= optimizer.swa_start:
                    # 每swa_freq个epoch更新一次SWA权重
                    if (epoch - optimizer.swa_start) % optimizer.swa_freq == 0:
                        optimizer.update_swa()
            else:
                net.eval()
            epoch_loss=0.0
            epoch_corrects=0
            for inputs,labels in dataloaders_dict[phase]:
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase=="train"):
                    inputs=inputs.to(device)
                    labels=labels.to(device)
                    outputs=net(inputs)
                    loss=criterion(outputs,labels)
                    print("loss:",loss)
                    _,preds=torch.max(outputs,1)
                    if phase == "train":
                        loss.backward()
                        # 添加梯度裁剪,防止梯度爆炸
                        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                        optimizer.step()
                epoch_loss += loss.item()*inputs.size(0)
                epoch_corrects +=torch.sum(preds==labels.data)
        
        epoch_loss=epoch_loss/len(dataloaders_dict[phase].dataset)
        epoch_acc=epoch_corrects.double()/len(dataloaders_dict[phase].dataset)
        print("{} Loss:{:.4f} Acc:{:.4f}".format(phase,epoch_loss,epoch_acc))
        if phase=="train":
            loss_list.append(epoch_loss)
            acc_list.append(epoch_acc.item())
        else:
            val_loss_list.append(epoch_loss)
            val_acc_list.append(epoch_acc.item())
    
    # 训练结束后,更新最后一次SWA权重并切换到SWA模型
    optimizer.update_swa()
    optimizer.swap_swa_sgd()
    # 此时可以评估SWA模型的性能
    print("SWA model evaluation completed.")

2. 修复Self Attention模块的bug与数值稳定性

你的Self Attention代码存在两个严重问题:

  • 注释掉了attention_map_T = self.softmax(S),导致attention_map使用未定义的变量,这会直接引发运行错误(你说不用SWA时正常,可能是粘贴代码时的失误)。
  • 没有对注意力分数进行缩放,点积结果可能过大导致softmax饱和,进而引发数值爆炸。

修正后的Self Attention模块

class Self_Attention(nn.Module):
    """ Self-Attention Layer"""
    def __init__(self, in_dim):
        super(Self_Attention, self).__init__()
        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=-2)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        X = x.to(device)
        B, C, W, H = X.shape
        N = W * H
        
        proj_query = self.query_conv(X).view(B, -1, N).permute(0, 2, 1)  # [B, N, C]
        proj_key = self.key_conv(X).view(B, -1, N)  # [B, C, N]
        
        # 添加缩放因子,防止点积结果过大
        scale = math.sqrt(C)
        S = torch.bmm(proj_query, proj_key) / scale  # [B, N, N]
        
        # 恢复softmax操作
        attention_map_T = self.softmax(S)
        attention_map = attention_map_T.permute(0, 2, 1)
        
        proj_value = self.value_conv(X).view(B, -1, N)  # [B, C, N]
        o = torch.bmm(proj_value, attention_map.permute(0, 2, 1))  # [B, C, N]
        o = o.view(B, C, W, H)
        
        out = x + self.gamma * o
        return out, attention_map

3. 其他辅助优化建议

  • 降低学习率:你当前base lr设为0.1,对于带注意力的模型来说可能过高,建议先尝试base_opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9),swa_lr设为0.005。
  • 检查输入数据:确保训练数据中没有NaN/inf值,虽然不用SWA时正常,但SWA的权重累积可能放大潜在的数值问题。
  • 启用梯度检查:可以在训练初期添加torch.autograd.detect_anomaly()来定位梯度爆炸的具体位置:
    with torch.autograd.detect_anomaly():
        loss.backward()
    

内容的提问来源于stack exchange,提问作者yohei

火山引擎 最新活动