如何调整与创建新Dataloader?以CIFAR10 Dataloader移除数据为例
嘿,我来帮你一步步搞定这两个问题~
一、调整与创建新Dataloader的通用方法
其实PyTorch里的Dataloader是基于Dataset对象工作的,所以调整或创建新Dataloader的核心,要么是修改Dataset的内容,要么是调整Dataloader的初始化参数。这里给你两种常见场景:
1. 直接基于现有Dataset创建新Dataloader
如果只是想调整加载参数(比如batch大小、是否打乱、加载线程数),直接用原Dataset重新初始化Dataloader就行:
# 假设你已经有了original_dataset这个Dataset对象 new_loader = torch.utils.data.DataLoader( original_dataset, batch_size=64, # 把原来的128改成64 shuffle=False, # 关闭随机打乱顺序 num_workers=4, # 增加4个后台加载线程,提升速度 drop_last=True # 丢弃最后一个不足batch大小的批次 )
2. 先修改Dataset再创建新Dataloader
如果需要调整数据本身(比如增删样本、修改transform),先处理Dataset,再创建Dataloader。比如给原数据集换个预处理流程:
from torchvision import transforms # 定义新的预处理流程 new_transform = transforms.Compose([ transforms.Resize(64), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # 复制原数据集并替换transform modified_dataset = original_dataset modified_dataset.transform = new_transform # 创建新Dataloader new_loader = torch.utils.data.DataLoader(modified_dataset, batch_size=128, shuffle=True)
二、从已有CIFAR10 Dataloader中移除部分数据
注意:Dataloader本身只是加载器,不能直接修改它里面的数据。咱们得先拿到它对应的Dataset,筛选出要保留的样本,再用筛选后的Dataset创建新Dataloader。结合你给出的代码,具体步骤如下:
步骤1:获取原Dataset
先修改你现有的load_data_cifar10函数,让它能返回Dataset(后续处理更方便):
import torchvision import torch def load_data_cifar10(batch_size=128, test=False, return_dataset=False): # 补全代码里缺失的transform定义(用CIFAR10常用的标准化流程) transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) if not test: dset = torchvision.datasets.CIFAR10( root='/mnt/3CE35B99003D727B/input/pytorch/data', train=True, download=True, transform=transform ) else: dset = torchvision.datasets.CIFAR10( root='/mnt/3CE35B99003D727B/input/pytorch/data', train=False, download=True, transform=transform ) loader = torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=True) print(f"LOAD DATA, {len(loader)} batches") if return_dataset: return loader, dset else: return loader # 调用函数同时拿到Dataloader和Dataset train_loader, train_dset = load_data_cifar10(return_dataset=True)
如果已经有现成的Dataloader,也可以直接从它里面提取Dataset:
original_dataset = train_loader.dataset
步骤2:筛选要保留的样本
PyTorch提供了torch.utils.data.Subset类,可以快速创建数据集的子集。这里给你几个常见的筛选场景:
场景1:按类别移除(比如移除所有“飞机”类别,CIFAR10中类别0是飞机)
# 遍历数据集,收集所有不属于类别0的样本索引 keep_indices = [idx for idx, (_, label) in enumerate(original_dataset) if label != 0] # 创建筛选后的子集数据集 filtered_dataset = torch.utils.data.Subset(original_dataset, keep_indices)
场景2:随机移除10%的样本
import random total_samples = len(original_dataset) # 保留90%的样本 keep_num = int(total_samples * 0.9) # 随机选择要保留的索引 keep_indices = random.sample(range(total_samples), keep_num) filtered_dataset = torch.utils.data.Subset(original_dataset, keep_indices)
场景3:移除前5000个样本
# 保留从第5000个到最后的样本 keep_indices = range(5000, len(original_dataset)) filtered_dataset = torch.utils.data.Subset(original_dataset, keep_indices)
步骤3:创建新的Dataloader
用筛选后的filtered_dataset创建新的Dataloader就行,参数可以按需调整:
new_train_loader = torch.utils.data.DataLoader( filtered_dataset, batch_size=128, shuffle=True, num_workers=2 ) print(f"NEW LOADER, {len(new_train_loader)} batches")
内容的提问来源于stack exchange,提问作者user11173832




