PyTorch迁移学习教程中数据集代码及字典推导式疑问
PyTorch迁移学习数据集代码解析与字典推导式详解
一、整体代码功能概述
这段代码是PyTorch迁移学习任务中构建图像训练/验证数据集流水线的标准实现,针对hymenoptera_data(蜂类数据集,包含蜜蜂和胡蜂两类图像),核心作用是完成数据增强、归一化、批量加载等预处理步骤,为后续模型训练做准备。
二、逐模块解析
1. 数据变换配置(data_transforms)
定义了训练集和验证集各自的图像预处理流程:
- 训练集变换:
transforms.RandomResizedCrop(224):随机裁剪到224×224像素(模拟不同视角,增强数据多样性)transforms.RandomHorizontalFlip():随机水平翻转(进一步提升模型泛化能力)transforms.ToTensor():将PIL图像转为PyTorch张量(格式从HWC转为CHW,数值归一到0-1)transforms.Normalize(...):用ImageNet预训练模型的均值和标准差做归一化,让数据符合预训练模型的输入分布
- 验证集变换:
- 仅做固定尺寸的
Resize(256)+CenterCrop(224)(保证验证数据的一致性,避免引入随机因素),后续转张量和归一化步骤和训练集一致
- 仅做固定尺寸的
2. 构建图像数据集(image_datasets)
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
这是字典推导式的应用,等价于下面的普通循环代码:
image_datasets = {} for x in ['train', 'val']: # 拼接数据集路径:data/hymenoptera_data/train 或 data/hymenoptera_data/val dataset_path = os.path.join(data_dir, x) # 用ImageFolder加载数据集:自动按文件夹结构分类,文件夹名即为类别名 image_datasets[x] = datasets.ImageFolder(dataset_path, data_transforms[x])
datasets.ImageFolder是PyTorch内置的图像数据集加载器,要求数据集按类别文件夹/图像文件的结构存放。
3. 构建数据加载器(dataloaders)
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']}
同样是字典推导式,生成训练/验证集的数据加载器:
DataLoader负责将数据集打包成批量(batch_size=4),训练集开启shuffle=True打乱数据(避免模型学习到顺序规律),num_workers=4用4个进程并行加载数据(提升加载速度)
4. 统计数据集大小(dataset_sizes)
你困惑的{x: len(image_datasets[x]) for x in ['train', 'val']}就是Python字典推导式,是生成字典的简洁写法,完全等价于:
dataset_sizes = {} for x in ['train', 'val']: dataset_sizes[x] = len(image_datasets[x])
字典推导式核心语法:
{键表达式: 值表达式 for 循环变量 in 可迭代对象}
- 循环变量
x依次遍历['train', 'val']中的每个元素 - 每轮循环生成一组键值对:键是
x本身,值是对应数据集的样本数量(len(image_datasets[x])) - 最终生成的
dataset_sizes就是一个包含{'train': 训练集样本数, 'val': 验证集样本数}的字典
5. 获取类别名称(class_names)
class_names = image_datasets['train'].classes直接从训练集获取类别名称列表,ImageFolder会自动按文件夹名排序生成这个列表(比如这里会是['ants', 'bees'])
内容的提问来源于stack exchange,提问作者SayNo2Decaf




