You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何从PyTorch的ImageFolder数据集中移除指定样本?支持按索引或名称吗?

如何从PyTorch的ImageFolder数据集中移除特定样本?

当然可以!ImageFolder本身没有提供直接的移除API,但它的内部核心结构是可手动修改的,我们可以通过操作它的关键属性来实现按索引或名称移除特定样本,下面给你详细拆解两种实现方式:

按索引移除样本

ImageFolder的samples属性是一个列表,每个元素是(图像完整路径, 对应标签)的元组;而targets属性是和样本一一对应的标签列表。要移除指定索引的样本,只需要同时删除这两个列表中对应位置的元素即可:

import torchvision

train_dir = "your/training/directory"
transform = torchvision.transforms.ToTensor()
train_data = torchvision.datasets.ImageFolder(train_dir, transform=transform)

# 移除索引为5的单个样本(示例)
del train_data.samples[5]
del train_data.targets[5]

如果要批量移除多个索引,一定要按从大到小的顺序删除——因为从前往后删会导致后续样本的索引错位,比如先删索引3,原来的索引5会变成新的索引4,再删原索引5就会删错样本:

indices_to_remove = [10, 5, 3]
# 先排序为从大到小的顺序
for idx in sorted(indices_to_remove, reverse=True):
    del train_data.samples[idx]
    del train_data.targets[idx]

按图像名称(路径)移除样本

如果只知道要移除的图像文件名或完整路径,我们可以先遍历samples列表找到对应的索引,再执行删除操作:

import os

# 替换成你要移除的图像完整路径
target_image_path = "your/training/directory/class_a/bad_sample.jpg"
idx_to_remove = None

# 遍历样本列表,匹配目标路径
for idx, (img_path, _) in enumerate(train_data.samples):
    if img_path == target_image_path:
        idx_to_remove = idx
        break

# 找到索引后执行删除
if idx_to_remove is not None:
    del train_data.samples[idx_to_remove]
    del train_data.targets[idx_to_remove]
else:
    print("未找到指定路径的图像样本")

如果只想按文件名(而非完整路径)匹配,只需修改判断条件:

target_filename = "bad_sample.jpg"
if os.path.basename(img_path) == target_filename:
    idx_to_remove = idx
    break

额外注意事项

  • 这种修改是直接在原数据集对象上进行的,如果需要保留原始数据集,建议先做深拷贝:import copy; train_data_copy = copy.deepcopy(train_data),再对拷贝后的对象进行修改。
  • 修改完成后,数据集的__len__方法会自动返回更新后的样本数量,后续用DataLoader加载时会正常使用修改后的样本集。

内容的提问来源于stack exchange,提问作者Guy Barash

火山引擎 最新活动