You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

基于PyTorch的Video Vision Transformer(ViViT)自定义数据集训练时CrossEntropyLoss报multi-target不支持错误的求助

问题原因分析与解决方案

你遇到的 RuntimeError: multi-target not supported 错误,核心原因是目标标签的格式不符合CrossEntropyLoss的要求,同时代码中还存在几处其他关键问题,下面逐一拆解并给出修复方案:


1. 核心错误:目标标签格式错误

CrossEntropyLoss要求目标标签是1D整数张量(每个元素对应样本的类别索引,范围0~num_classes-1),但你的代码生成的标签是2D张量,导致PyTorch误判为多标签任务,从而抛出"multi-target not supported"错误。

修复:修正标签生成逻辑

DatasetProcessing类的__getitem__方法中,替换np.where为直接获取索引的方式,并确保标签是整数类型的标量张量:

def __getitem__(self, index):
    video_label = self.video_list[index].split('/')[-2]
    video_frames, len_ = get_frames(self.video_list[index], n_frames=15)
    video_frames = np.asarray(video_frames) / 255.0  # 归一化到[0,1]
    
    class_list = ['Run', 'Walk', 'Wave', 'Sit', 'Turn', 'Stand']
    # 直接获取类别索引(若标签不在列表中会抛出ValueError,便于排查脏数据)
    class_id = class_list.index(video_label)
    
    # 转换为符合要求的张量格式
    data = torch.tensor(video_frames, dtype=torch.float32)
    label = torch.tensor(class_id, dtype=torch.long)  # 必须是long类型
    
    return (data, label)

2. 模型初始化参数错误

你的ViViT模型初始化时,类别数和帧数参数与实际数据集不匹配

  • 数据集只有6个类别,但你传入了num_classes=100
  • 代码中加载15帧,但模型初始化时传入num_frames=16

修复:修正模型初始化

# 匹配数据集的6个类别和加载的15帧
model = ViViT(image_size=224, patch_size=16, num_classes=6, num_frames=15).cuda()

3. 损失函数使用错误

CrossEntropyLoss内部已经包含了LogSoftmax操作,你额外调用F.log_softmax会导致计算逻辑重复,进而影响损失计算的正确性。

修复:移除冗余的LogSoftmax

训练函数修正:

def train_epoch(model, optimizer, data_loader, loss_history, loss_func):
    total_samples = len(data_loader.dataset)
    model.train()
    for i, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        
        # 调整张量形状并移至GPU
        data = data.cuda()
        data = rearrange(data, 'b p h w c -> b p c h w')
        target = target.cuda()
        
        # 获取模型原始logits(无需提前做softmax)
        pred = model(data.float())
        
        # 直接用原始logits计算损失
        loss = loss_func(pred, target)
        
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            progress = f'[{i * len(data):5}/{total_samples:5} ({100 * i / len(data_loader):3.0f}%)] Loss: {loss.item():6.4f}'
            print(progress)
            loss_history.append(loss.item())

评估函数修正:

def evaluate(model, data_loader, loss_history, loss_func):
    model.eval()
    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0
    
    with torch.no_grad():
        for data, target in data_loader:
            data = data.cuda()
            data = rearrange(data, 'b p h w c -> b p c h w')
            target = target.cuda()
            
            output = model(data.float())
            loss = loss_func(output, target)
            
            _, pred = torch.max(output, dim=1)
            total_loss += loss.item() * data.size(0)  # 按批次累加损失
            correct_samples += pred.eq(target).sum()
    
    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    
    accuracy = 100.0 * correct_samples / total_samples
    print(f'\nAverage test loss: {avg_loss:.4f} Accuracy: {correct_samples:5}/{total_samples:5} ({accuracy:4.2f}%)\n')

4. 额外优化:帧提取逻辑改进

get_frames函数可能提取到的帧数不稳定(比如视频长度不足时),且效率较低,建议优化为直接跳转到目标帧:

def get_frames(filename, n_frames=1):
    frames = []
    v_cap = cv2.VideoCapture(filename)
    v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    # 确保提取到恰好n_frames帧
    if v_len <= n_frames:
        # 视频帧数不足时,重复最后一帧补全
        frame_indices = list(range(v_len)) + [v_len-1]*(n_frames - v_len)
    else:
        # 均匀选取n_frames帧
        frame_indices = np.linspace(0, v_len-1, n_frames, dtype=np.int16)
    
    frame_dims = (224, 224)
    for idx in frame_indices:
        v_cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        success, frame = v_cap.read()
        if not success:
            # 读取失败时用最后一帧或空白帧填充
            frame = frames[-1] if frames else np.zeros((frame_dims[0], frame_dims[1], 3), dtype=np.uint8)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, frame_dims)
        frames.append(frame)
    
    v_cap.release()
    return np.array(frames), v_len

总结

按照上述步骤修复后,你的代码应该可以正常运行:

  1. 标签格式符合CrossEntropyLoss的要求(1D整数张量)
  2. 模型参数与数据集匹配
  3. 损失函数计算逻辑正确
  4. 帧提取逻辑更稳定高效

内容的提问来源于stack exchange,提问作者Vaibhav Sah

火山引擎 最新活动