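Help request: CrossEntropyLoss raises "multi-target not supported" when training a Video Vision Transformer (ViViT) on a custom dataset in PyTorch
Cause analysis and solutions
The RuntimeError: multi-target not supported error comes from target labels whose format does not match what CrossEntropyLoss expects. The code also has several other significant problems. Each one is broken down below, together with a fix: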
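1. Core error: wrong target label format
For class-index targets, CrossEntropyLoss expects a 1D integer tensor of shape (batch_size,), where each element is the class index of one sample in the range 0 to num_classes-1. The labels your code produces carry an extra dimension, so each batch arrives as a 2D tensor. PyTorch interprets that as a multi-label target and raises "multi-target not supported".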
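Fix: correct the label generation logic
In the __getitem__ method of the DatasetProcessing class, replace np.where with a direct index lookup, and make sure each label is a scalar tensor of integer type:

    # Method of DatasetProcessing; assumes numpy (as np) and torch are imported at module level
    def __getitem__(self, index):
        video_label = self.video_list[index].split('/')[-2]
        video_frames, len_ = get_frames(self.video_list[index], n_frames=15)
        video_frames = np.asarray(video_frames) / 255.0  # normalize to [0, 1]
        class_list = ['Run', 'Walk', 'Wave', 'Sit', 'Turn', 'Stand']
        # Look up the class index directly (raises ValueError if the label
        # is not in the list, which helps surface dirty data)
        class_id = class_list.index(video_label)
        # Convert to the tensor formats the model and the loss expect
        data = torch.tensor(video_frames, dtype=torch.float32)
        label = torch.tensor(class_id, dtype=torch.long)  # must be long
        return (data, label)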
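To see why the lookup method changes the label shape, here is a minimal NumPy-only sketch. It assumes the original code looked the label up with something like np.where(np.array(class_list) == video_label); that exact expression is an assumption, since the question's code is not reproduced here. Stacking with np.stack stands in for what default batching does with per-sample labels.

```python
import numpy as np

class_list = ['Run', 'Walk', 'Wave', 'Sit', 'Turn', 'Stand']
video_label = 'Wave'

# Assumed original-style lookup: np.where returns a TUPLE of index arrays
where_result = np.where(np.array(class_list) == video_label)
nested_label = np.asarray(where_result)       # shape (1, 1), not a scalar

# Direct lookup: a plain Python int
scalar_label = class_list.index(video_label)

# Stand-in for batching: stack four per-sample labels
batch_nested = np.stack([nested_label] * 4)               # shape (4, 1, 1)
batch_scalar = np.stack([np.asarray(scalar_label)] * 4)   # shape (4,)

print(nested_label.shape, batch_nested.shape, batch_scalar.shape)
```

Only the last shape, (4,), is the 1D layout that class-index targets need.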
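2. Wrong model initialization parameters
The class count and frame count passed to the ViViT constructor do not match the actual dataset:
- The dataset has only 6 classes, but you pass num_classes=100
- The code loads 15 frames per video, but the model is initialized with num_frames=16
Fix: correct the model initialization

    # Match the dataset's 6 classes and the 15 frames that are loaded
    model = ViViT(image_size=224, patch_size=16, num_classes=6, num_frames=15).cuda()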
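One way to keep these numbers from drifting apart again is to define them once and derive the constructor arguments from them. This is a small sketch, not part of the original code: the names CLASS_LIST, N_FRAMES, and vivit_kwargs are hypothetical, and the keyword names simply mirror the ViViT call shown above.

```python
# Hypothetical shared constants: define the class list and frame count once
CLASS_LIST = ['Run', 'Walk', 'Wave', 'Sit', 'Turn', 'Stand']
N_FRAMES = 15

# Derive the model arguments instead of hard-coding them a second time
vivit_kwargs = dict(
    image_size=224,
    patch_size=16,
    num_classes=len(CLASS_LIST),  # always equals the number of labels
    num_frames=N_FRAMES,          # same value the frame loader uses
)
print(vivit_kwargs)
```

The model would then be built with ViViT(**vivit_kwargs).cuda(), and the dataset would call get_frames(path, n_frames=N_FRAMES), so both sides read the same values.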
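3. Loss function usage: redundant LogSoftmax
CrossEntropyLoss already applies LogSoftmax internally, so calling F.log_softmax on the model output first is redundant. Because log-softmax applied twice gives the same result as applied once, the loss value itself does not change, but the extra call adds needless computation and obscures the fact that the loss expects raw logits. If you do want to keep F.log_softmax in the model, the matching loss is NLLLoss.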
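The following pure-Python sketch illustrates what cross-entropy on logits computes for a single sample (log-softmax followed by the negative log-likelihood of the target class). It is not PyTorch's implementation; the helper names log_softmax and cross_entropy are defined here for illustration only. It also checks that applying log-softmax before the loss leaves the value unchanged, which is why the extra call is redundant work.

```python
import math

def log_softmax(xs):
    # Numerically stable log-softmax: subtract the log-sum-exp of the inputs
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def cross_entropy(logits, target):
    # Negative log-probability of the target class, computed from raw logits
    return -log_softmax(logits)[target]

logits = [2.0, -1.0, 0.5]
target = 0

loss_raw = cross_entropy(logits, target)                  # logits passed directly
loss_double = cross_entropy(log_softmax(logits), target)  # log_softmax applied first

print(loss_raw, loss_double)
```

The two printed values agree up to floating-point rounding.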
Fix: remove the redundant LogSoftmax
Corrected training function:

    from einops import rearrange

    def train_epoch(model, optimizer, data_loader, loss_history, loss_func):
        total_samples = len(data_loader.dataset)
        model.train()
        for i, (data, target) in enumerate(data_loader):
            optimizer.zero_grad()
            # Move to the GPU and reorder to (batch, frames, channels, height, width)
            data = data.cuda()
            data = rearrange(data, 'b p h w c -> b p c h w')
            target = target.cuda()
            # Raw logits from the model (no softmax beforehand)
            pred = model(data.float())
            # Compute the loss directly on the raw logits
            loss = loss_func(pred, target)
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                progress = f'[{i * len(data):5}/{total_samples:5} ({100 * i / len(data_loader):3.0f}%)] Loss: {loss.item():6.4f}'
                print(progress)
                loss_history.append(loss.item())
Corrected evaluation function:

    import torch
    from einops import rearrange

    def evaluate(model, data_loader, loss_history, loss_func):
        model.eval()
        total_samples = len(data_loader.dataset)
        correct_samples = 0
        total_loss = 0
        with torch.no_grad():
            for data, target in data_loader:
                data = data.cuda()
                data = rearrange(data, 'b p h w c -> b p c h w')
                target = target.cuda()
                output = model(data.float())
                loss = loss_func(output, target)
                _, pred = torch.max(output, dim=1)
                # Weight each batch's mean loss by its size
                total_loss += loss.item() * data.size(0)
                # .item() keeps the counter a plain Python int
                correct_samples += pred.eq(target).sum().item()
        avg_loss = total_loss / total_samples
        loss_history.append(avg_loss)
        accuracy = 100.0 * correct_samples / total_samples
        print(f'\nAverage test loss: {avg_loss:.4f} Accuracy: {correct_samples:5}/{total_samples:5} ({accuracy:4.2f}%)\n')
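The evaluation function multiplies each batch loss by the batch size before averaging. A short pure-Python sketch of why that matters, assuming loss_func uses the default mean reduction and the last batch is smaller than the others (the numbers are made up for illustration):

```python
# Each entry is (mean loss of the batch, batch size); the last batch is smaller
batches = [(1.0, 4), (3.0, 2)]

total_samples = sum(n for _, n in batches)

# Size-weighted average: every sample counts equally
per_sample_avg = sum(loss * n for loss, n in batches) / total_samples

# Naive average of batch means: the small batch is over-weighted
naive_avg = sum(loss for loss, _ in batches) / len(batches)

print(per_sample_avg, naive_avg)
```

The two averages differ whenever batch sizes are unequal, and the size-weighted one is the true per-sample mean.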
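4. Additional improvement: frame extraction logic
The original get_frames function may return an inconsistent number of frames (for example, when the video is shorter than n_frames). A suggested rewrite picks the frame indices up front and seeks to each one directly, so it always returns exactly n_frames frames. Note that seeking with CAP_PROP_POS_FRAMES can be inexact or slow for some codecs, so verify the result on your own videos:

    import cv2
    import numpy as np

    def get_frames(filename, n_frames=1):
        frames = []
        frame_dims = (224, 224)
        v_cap = cv2.VideoCapture(filename)
        v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Make sure exactly n_frames frames are returned
        if v_len <= n_frames:
            # Too few frames: repeat the last frame index as padding
            frame_indices = list(range(v_len)) + [v_len - 1] * (n_frames - v_len)
        else:
            # Pick n_frames indices spread evenly over the video
            # (plain int avoids int16 overflow on videos longer than 32767 frames)
            frame_indices = np.linspace(0, v_len - 1, n_frames, dtype=int)
        for idx in frame_indices:
            v_cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            success, frame = v_cap.read()
            if success:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, frame_dims)
            else:
                # Read failed: reuse the previous (already converted) frame,
                # or fall back to a blank frame
                frame = frames[-1] if frames else np.zeros((frame_dims[0], frame_dims[1], 3), dtype=np.uint8)
            frames.append(frame)
        v_cap.release()
        return np.array(frames), v_len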
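The index selection can be tested separately from OpenCV. The sketch below is a pure-Python stand-in for that step: pick_frame_indices is a hypothetical helper, not part of the original code, and it uses integer arithmetic in place of np.linspace. It pads short videos by repeating the last index and otherwise spreads the indices evenly from the first frame to the last.

```python
def pick_frame_indices(v_len, n_frames):
    """Return exactly n_frames frame indices for a video with v_len frames."""
    if v_len <= 0:
        raise ValueError("video has no readable frames")
    if v_len <= n_frames:
        # Short video: keep every frame, then repeat the last index as padding
        return list(range(v_len)) + [v_len - 1] * (n_frames - v_len)
    if n_frames == 1:
        return [0]
    # Evenly spaced positions from 0 to v_len - 1, using exact integer division
    return [i * (v_len - 1) // (n_frames - 1) for i in range(n_frames)]

print(pick_frame_indices(5, 8))    # short video, padded with the last index
print(pick_frame_indices(100, 5))  # long video, evenly spaced indices
```

Keeping this step in its own function makes it easy to unit-test edge cases such as very short videos without needing any video files.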
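Summary
After applying the fixes above, your code should run correctly:
- The labels match what CrossEntropyLoss expects (a 1D integer tensor per batch)
- The model parameters match the dataset
- The loss is computed directly from raw logits, without the redundant log_softmax call
- Frame extraction always returns exactly n_frames frames
This question originally comes from Stack Exchange; it was asked by Vaibhav Sah.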




