You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何调整与创建新Dataloader?以CIFAR10 Dataloader移除数据为例

嘿,我来帮你一步步搞定这两个问题~

一、调整与创建新Dataloader的通用方法

其实PyTorch里的Dataloader是基于Dataset对象工作的,所以调整或创建新Dataloader的核心,要么是修改Dataset的内容,要么是调整Dataloader的初始化参数。这里给你两种常见场景:

1. 直接基于现有Dataset创建新Dataloader

如果只是想调整加载参数(比如batch大小、是否打乱、加载线程数),直接用原Dataset重新初始化Dataloader就行:

# 假设你已经有了original_dataset这个Dataset对象
new_loader = torch.utils.data.DataLoader(
    original_dataset,
    batch_size=64,  # 把原来的128改成64
    shuffle=False,  # 关闭随机打乱顺序
    num_workers=4,  # 增加4个后台加载线程,提升速度
    drop_last=True  # 丢弃最后一个不足batch大小的批次
)

2. 先修改Dataset再创建新Dataloader

如果需要调整数据本身(比如增删样本、修改transform),先处理Dataset,再创建Dataloader。比如给原数据集换个预处理流程:

from torchvision import transforms

# 定义新的预处理流程
new_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 复制原数据集并替换transform
modified_dataset = original_dataset
modified_dataset.transform = new_transform
# 创建新Dataloader
new_loader = torch.utils.data.DataLoader(modified_dataset, batch_size=128, shuffle=True)
二、从已有CIFAR10 Dataloader中移除部分数据

注意:Dataloader本身只是加载器,不能直接修改它里面的数据。咱们得先拿到它对应的Dataset,筛选出要保留的样本,再用筛选后的Dataset创建新Dataloader。结合你给出的代码,具体步骤如下:

步骤1:获取原Dataset

先修改你现有的load_data_cifar10函数,让它能返回Dataset(后续处理更方便):

import torchvision
import torch

def load_data_cifar10(batch_size=128, test=False, return_dataset=False):
    # 补全代码里缺失的transform定义(用CIFAR10常用的标准化流程)
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    if not test:
        dset = torchvision.datasets.CIFAR10(
            root='/mnt/3CE35B99003D727B/input/pytorch/data', 
            train=True, 
            download=True, 
            transform=transform
        )
    else:
        dset = torchvision.datasets.CIFAR10(
            root='/mnt/3CE35B99003D727B/input/pytorch/data', 
            train=False, 
            download=True, 
            transform=transform
        )
    loader = torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=True)
    print(f"LOAD DATA, {len(loader)} batches")
    if return_dataset:
        return loader, dset
    else:
        return loader

# 调用函数同时拿到Dataloader和Dataset
train_loader, train_dset = load_data_cifar10(return_dataset=True)

如果已经有现成的Dataloader,也可以直接从它里面提取Dataset:

original_dataset = train_loader.dataset

步骤2:筛选要保留的样本

PyTorch提供了torch.utils.data.Subset类,可以快速创建数据集的子集。这里给你几个常见的筛选场景:

场景1:按类别移除(比如移除所有“飞机”类别,CIFAR10中类别0是飞机)

# 遍历数据集,收集所有不属于类别0的样本索引
keep_indices = [idx for idx, (_, label) in enumerate(original_dataset) if label != 0]
# 创建筛选后的子集数据集
filtered_dataset = torch.utils.data.Subset(original_dataset, keep_indices)

场景2:随机移除10%的样本

import random

total_samples = len(original_dataset)
# 保留90%的样本
keep_num = int(total_samples * 0.9)
# 随机选择要保留的索引
keep_indices = random.sample(range(total_samples), keep_num)
filtered_dataset = torch.utils.data.Subset(original_dataset, keep_indices)

场景3:移除前5000个样本

# 保留从第5000个到最后的样本
keep_indices = range(5000, len(original_dataset))
filtered_dataset = torch.utils.data.Subset(original_dataset, keep_indices)

步骤3:创建新的Dataloader

用筛选后的filtered_dataset创建新的Dataloader就行,参数可以按需调整:

new_train_loader = torch.utils.data.DataLoader(
    filtered_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2
)
print(f"NEW LOADER, {len(new_train_loader)} batches")

内容的提问来源于stack exchange,提问作者user11173832

火山引擎 最新活动