PyTorch简易文本生成器失效,损失持续发散问题求助
解决你的PyTorch文本生成器损失发散&生成异常问题
我帮你梳理下代码里的几个关键问题,这些正是导致损失异常、模型完全学不到东西的根源:
1. 损失函数与输出层的核心不匹配
你用了nn.NLLLoss(),但模型最后一层是Softmax()——这完全搭错了!
NLLLoss要求输入是对数概率(也就是LogSoftmax的输出),而Softmax输出的是0-1之间的普通概率。用NLLLoss计算时会对这些正数取负对数,结果必然是负数,而且随着模型输出越来越集中到某个类别,损失会变成离谱的负数(比如你看到的-399),完全失去了损失应有的指导意义。- 正确的搭配有两种:
- 把
Softmax()换成LogSoftmax(),保留NLLLoss()不变 - 直接用
nn.CrossEntropyLoss()(它已经整合了LogSoftmax和NLLLoss的功能)
- 把
2. 输入张量的错误构造
你的charTensor函数把字符转成one-hot张量再转成long类型,但nn.Embedding的输入只需要字符对应的索引整数,根本不需要one-hot!
- 比如
all_chars.index(x)得到的就是字符的索引值,直接把这个整数包装成(1,)形状的long tensor就行。你之前的做法相当于每次给Embedding层喂一个全0除了某一位的张量,导致Embedding完全无法学习到有效的字符语义信息。
3. 训练循环的不合理操作
- 你在每个样本后都调用
loss.backward(retain_graph=True),其实完全不需要保留计算图——每个样本的计算图是独立的,retain_graph=True只会白白占用内存。 epoch_loss累加时应该用loss.item(),不然你打印的是带计算图的张量累加值,既占内存又不直观。- 每次反向传播前要记得清零梯度,你之前只在epoch开头清了一次,会导致梯度累积混乱。
修改后的完整代码
import torch import torch.nn as nn # 假设你已经提前定义了以下变量: # num_chars = len(all_chars) # all_chars = 包含所有字符的列表 # train_str = 你的训练文本字符串 class RNN(nn.Module): def __init__(self, embed_size, hidden_size, num_chars): super(RNN, self).__init__() self.embeds = nn.Embedding(num_chars, embed_size) self.l1 = nn.Linear(embed_size, hidden_size) self.l2 = nn.Linear(hidden_size, hidden_size) self.l3 = nn.Linear(hidden_size, num_chars) self.relu = nn.ReLU() # 替换成LogSoftmax,适配NLLLoss self.log_softmax = nn.LogSoftmax(dim=1) def forward(self, inp): out = self.embeds(inp) out = self.relu(self.l1(out)) out = self.relu(self.l2(out)) out = self.l3(out) return self.log_softmax(out) # 初始化模型,传入num_chars参数 rnn = RNN(10, 50, num_chars) optimizer = torch.optim.Adam(rnn.parameters(), lr=0.002) criterion = nn.NLLLoss() # 重构charTensor,直接返回索引张量 def charTensor(x): idx = all_chars.index(x) return torch.tensor([idx], dtype=torch.long) for epoch in range(5): epoch_loss = 0.0 rnn.train() for i in range(len(train_str[:400])-1): inp = charTensor(train_str[i]) output = rnn(inp) target = charTensor(train_str[i+1]) loss = criterion(output, target) epoch_loss += loss.item() # 标准的反向传播流程 optimizer.zero_grad() loss.backward() optimizer.step() print(f"Epoch {epoch+1} Loss:", epoch_loss) # 文本生成部分(推理时关闭梯度计算) rnn.eval() first_char = 'c' inp_t = charTensor(first_char) fin = first_char for i in range(10): with torch.no_grad(): next_t = rnn(inp_t) next_idx = torch.argmax(next_t).item() next_char = all_chars[next_idx] fin += next_char inp_t = charTensor(next_char) print(fin)
额外小建议
- 你现在的模型其实是个普通MLP,不是真正的RNN——如果要做文本生成,建议换成
nn.RNN或nn.LSTM,因为MLP无法捕捉文本的序列依赖关系,即使修复后能训练,生成效果也会很差。 - 如果调整后还是有发散情况,可以尝试把学习率降到0.001试试。
- 训练时尽量用批量数据,比单个样本训练的梯度更新更稳定。
内容的提问来源于stack exchange,提问作者Cracin




