You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

PyTorch迁移学习教程中数据集代码及字典推导式疑问

PyTorch迁移学习数据集代码解析与字典推导式详解

一、整体代码功能概述

这段代码是PyTorch迁移学习任务中构建图像训练/验证数据集流水线的标准实现,针对hymenoptera_data(蜂类数据集,包含蜜蜂和胡蜂两类图像),核心作用是完成数据增强、归一化、批量加载等预处理步骤,为后续模型训练做准备。

二、逐模块解析

1. 数据变换配置(data_transforms

定义了训练集和验证集各自的图像预处理流程:

  • 训练集变换
    • transforms.RandomResizedCrop(224):随机裁剪到224×224像素(模拟不同视角,增强数据多样性)
    • transforms.RandomHorizontalFlip():随机水平翻转(进一步提升模型泛化能力)
    • transforms.ToTensor():将PIL图像转为PyTorch张量(格式从HWC转为CHW,数值归一到0-1)
    • transforms.Normalize(...):用ImageNet预训练模型的均值和标准差做归一化,让数据符合预训练模型的输入分布
  • 验证集变换
    • 仅做固定尺寸的Resize(256)+CenterCrop(224)(保证验证数据的一致性,避免引入随机因素),后续转张量和归一化步骤和训练集一致

2. 构建图像数据集(image_datasets

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}

这是字典推导式的应用,等价于下面的普通循环代码:

image_datasets = {}
for x in ['train', 'val']:
    # 拼接数据集路径:data/hymenoptera_data/train 或 data/hymenoptera_data/val
    dataset_path = os.path.join(data_dir, x)
    # 用ImageFolder加载数据集:自动按文件夹结构分类,文件夹名即为类别名
    image_datasets[x] = datasets.ImageFolder(dataset_path, data_transforms[x])

datasets.ImageFolder是PyTorch内置的图像数据集加载器,要求数据集按类别文件夹/图像文件的结构存放。

3. 构建数据加载器(dataloaders

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4)
              for x in ['train', 'val']}

同样是字典推导式,生成训练/验证集的数据加载器:

  • DataLoader负责将数据集打包成批量(batch_size=4),训练集开启shuffle=True打乱数据(避免模型学习到顺序规律),num_workers=4用4个进程并行加载数据(提升加载速度)

4. 统计数据集大小(dataset_sizes

你困惑的{x: len(image_datasets[x]) for x in ['train', 'val']}就是Python字典推导式,是生成字典的简洁写法,完全等价于:

dataset_sizes = {}
for x in ['train', 'val']:
    dataset_sizes[x] = len(image_datasets[x])

字典推导式核心语法:

{键表达式: 值表达式 for 循环变量 in 可迭代对象}

  • 循环变量x依次遍历['train', 'val']中的每个元素
  • 每轮循环生成一组键值对:键是x本身,值是对应数据集的样本数量(len(image_datasets[x])
  • 最终生成的dataset_sizes就是一个包含{'train': 训练集样本数, 'val': 验证集样本数}的字典

5. 获取类别名称(class_names

class_names = image_datasets['train'].classes直接从训练集获取类别名称列表,ImageFolder会自动按文件夹名排序生成这个列表(比如这里会是['ants', 'bees']


内容的提问来源于stack exchange,提问作者SayNo2Decaf

火山引擎 最新活动