如何在PyTorch训练循环中从ImageFolder加载器获取正负类图像批次?
如何在GAN训练循环中从ImageFolder加载的数据里分离正负类批次?
嘿,我来帮你搞定这个问题!你用torchvision.datasets.ImageFolder加载了包含Negatives和Positives两类的数据集,想要在GAN训练循环里分别拿到这两类的图像批次对吧?下面给你几种实用的解决方案,挑适合你的来用:
方法一:在迭代DataLoader时实时过滤批次样本
这是最直接的方法,每次从DataLoader拿到批次后,根据标签筛选出正负类样本:
import torch from torch.utils.data import DataLoader from torchvision.datasets import ImageFolder from torchvision import transforms # 先完成数据集和DataLoader的初始化 DATA_FOLDER = './tf_data/plasmodium_photos/' transform = transforms.Compose([ transforms.Resize((28, 28)), transforms.ToTensor(), ]) dataset = ImageFolder(root=DATA_FOLDER, transform=transform) dataloader = DataLoader(dataset, batch_size=32, shuffle=True) # 先确认类别对应的标签值(ImageFolder按文件夹字母顺序分配标签) print("类别与标签映射:", dataset.class_to_idx) # 输出类似 {'Negatives': 0, 'Positives': 1},具体看你的文件夹名称顺序 # 训练循环中分离正负类批次 for images, labels in dataloader: # 根据标签筛选 neg_mask = (labels == 0) # 假设Negatives对应标签0 pos_mask = (labels == 1) # 假设Positives对应标签1 neg_batch = images[neg_mask] pos_batch = images[pos_mask] # 注意:如果批次里某类样本数量为0,要做判断避免后续报错 if len(neg_batch) > 0 and len(pos_batch) > 0: # 这里就可以用两个批次进行GAN训练了 # 比如把正类给判别器,负类或生成样本参与训练... pass
这种方法的好处是不用额外修改数据集结构,适合快速验证,但如果你的batch size较小,可能会出现某个批次里没有其中一类样本的情况,记得加判断处理。
方法二:拆分数据集为两个独立的DataLoader
如果你希望每次拿到的都是纯的正负类批次,可以把原数据集拆分成两个子集,再分别创建DataLoader:
from torch.utils.data import Subset # 获取两类样本的索引 neg_indices = [i for i, (_, label) in enumerate(dataset) if label == 0] pos_indices = [i for i, (_, label) in enumerate(dataset) if label == 1] # 创建正负类子集 neg_dataset = Subset(dataset, neg_indices) pos_dataset = Subset(dataset, pos_indices) # 创建各自的DataLoader neg_dataloader = DataLoader(neg_dataset, batch_size=32, shuffle=True) pos_dataloader = DataLoader(pos_dataset, batch_size=32, shuffle=True) # 训练时同时迭代两个DataLoader(用cycle处理两类数据量不一致的情况) from itertools import cycle for neg_batch, pos_batch in zip(cycle(neg_dataloader), cycle(pos_dataloader)): # neg_batch是纯负类图像,pos_batch是纯正类图像 # 执行你的GAN训练逻辑 pass
这种方法更稳定,每次拿到的批次都是单一类别,适合两类数据量差异不大的场景;如果其中一类数据量少,用cycle可以循环迭代它,保证训练能持续进行。
方法三:使用自定义Sampler加载指定类别
如果你不想拆分数据集,还可以通过自定义Sampler来让DataLoader只加载指定类别的样本:
from torch.utils.data.sampler import Sampler class ClassSampler(Sampler): def __init__(self, dataset, target_class): # 筛选出目标类别的所有样本索引 self.indices = [i for i, (_, label) in enumerate(dataset) if label == target_class] def __iter__(self): return iter(self.indices) def __len__(self): return len(self.indices) # 创建正负类对应的采样器 neg_sampler = ClassSampler(dataset, target_class=0) pos_sampler = ClassSampler(dataset, target_class=1) # 创建只加载对应类别的DataLoader neg_dataloader = DataLoader(dataset, batch_size=32, sampler=neg_sampler) pos_dataloader = DataLoader(dataset, batch_size=32, sampler=pos_sampler) # 训练时直接迭代两个DataLoader即可 for neg_batch in neg_dataloader: # 处理负类批次 pass for pos_batch in pos_dataloader: # 处理正类批次 pass
这种方法和拆分数据集的效果类似,但不需要创建子集,直接通过采样器控制加载的样本,灵活性更高。
额外提示
- 如果你同时在用TensorFlow和PyTorch,记得注意张量的设备转换(比如用
.to('cuda')或.cpu()切换设备),避免跨框架的张量操作报错。 - 可以通过
len(neg_dataset)和len(pos_dataset)查看两类数据的数量,方便调整batch size和训练策略。
内容的提问来源于stack exchange,提问作者Vladimir Fomene




