You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何在PyTorch训练循环中从ImageFolder加载器获取正负类图像批次?

如何在GAN训练循环中从ImageFolder加载的数据里分离正负类批次?

嘿,我来帮你搞定这个问题!你用torchvision.datasets.ImageFolder加载了包含Negatives和Positives两类的数据集,想要在GAN训练循环里分别拿到这两类的图像批次对吧?下面给你几种实用的解决方案,挑适合你的来用:

方法一:在迭代DataLoader时实时过滤批次样本

这是最直接的方法,每次从DataLoader拿到批次后,根据标签筛选出正负类样本:

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms

# 先完成数据集和DataLoader的初始化
DATA_FOLDER = './tf_data/plasmodium_photos/'
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])
dataset = ImageFolder(root=DATA_FOLDER, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 先确认类别对应的标签值(ImageFolder按文件夹字母顺序分配标签)
print("类别与标签映射:", dataset.class_to_idx)
# 输出类似 {'Negatives': 0, 'Positives': 1},具体看你的文件夹名称顺序

# 训练循环中分离正负类批次
for images, labels in dataloader:
    # 根据标签筛选
    neg_mask = (labels == 0)  # 假设Negatives对应标签0
    pos_mask = (labels == 1)  # 假设Positives对应标签1
    
    neg_batch = images[neg_mask]
    pos_batch = images[pos_mask]
    
    # 注意:如果批次里某类样本数量为0,要做判断避免后续报错
    if len(neg_batch) > 0 and len(pos_batch) > 0:
        # 这里就可以用两个批次进行GAN训练了
        # 比如把正类给判别器,负类或生成样本参与训练...
        pass

这种方法的好处是不用额外修改数据集结构,适合快速验证,但如果你的batch size较小,可能会出现某个批次里没有其中一类样本的情况,记得加判断处理。

方法二:拆分数据集为两个独立的DataLoader

如果你希望每次拿到的都是纯的正负类批次,可以把原数据集拆分成两个子集,再分别创建DataLoader:

from torch.utils.data import Subset

# 获取两类样本的索引
neg_indices = [i for i, (_, label) in enumerate(dataset) if label == 0]
pos_indices = [i for i, (_, label) in enumerate(dataset) if label == 1]

# 创建正负类子集
neg_dataset = Subset(dataset, neg_indices)
pos_dataset = Subset(dataset, pos_indices)

# 创建各自的DataLoader
neg_dataloader = DataLoader(neg_dataset, batch_size=32, shuffle=True)
pos_dataloader = DataLoader(pos_dataset, batch_size=32, shuffle=True)

# 训练时同时迭代两个DataLoader(用cycle处理两类数据量不一致的情况)
from itertools import cycle

for neg_batch, pos_batch in zip(cycle(neg_dataloader), cycle(pos_dataloader)):
    # neg_batch是纯负类图像,pos_batch是纯正类图像
    # 执行你的GAN训练逻辑
    pass

这种方法更稳定,每次拿到的批次都是单一类别,适合两类数据量差异不大的场景;如果其中一类数据量少,用cycle可以循环迭代它,保证训练能持续进行。

方法三:使用自定义Sampler加载指定类别

如果你不想拆分数据集,还可以通过自定义Sampler来让DataLoader只加载指定类别的样本:

from torch.utils.data.sampler import Sampler

class ClassSampler(Sampler):
    def __init__(self, dataset, target_class):
        # 筛选出目标类别的所有样本索引
        self.indices = [i for i, (_, label) in enumerate(dataset) if label == target_class]
    
    def __iter__(self):
        return iter(self.indices)
    
    def __len__(self):
        return len(self.indices)

# 创建正负类对应的采样器
neg_sampler = ClassSampler(dataset, target_class=0)
pos_sampler = ClassSampler(dataset, target_class=1)

# 创建只加载对应类别的DataLoader
neg_dataloader = DataLoader(dataset, batch_size=32, sampler=neg_sampler)
pos_dataloader = DataLoader(dataset, batch_size=32, sampler=pos_sampler)

# 训练时直接迭代两个DataLoader即可
for neg_batch in neg_dataloader:
    # 处理负类批次
    pass

for pos_batch in pos_dataloader:
    # 处理正类批次
    pass

这种方法和拆分数据集的效果类似,但不需要创建子集,直接通过采样器控制加载的样本,灵活性更高。

额外提示

  • 如果你同时在用TensorFlow和PyTorch,记得注意张量的设备转换(比如用.to('cuda').cpu()切换设备),避免跨框架的张量操作报错。
  • 可以通过len(neg_dataset)len(pos_dataset)查看两类数据的数量,方便调整batch size和训练策略。

内容的提问来源于stack exchange,提问作者Vladimir Fomene

火山引擎 最新活动