如何构建更高效的DataLoader以加载大型图像数据集?
嘿,我仔细看了你的问题和代码实现,针对你遇到的CPU内存耗尽、初始化缓慢以及多GPU训练下的加载效率问题,给你整理了几个实用的优化方案,应该能帮你大幅提升数据加载的性能:
一、解决CPU内存耗尽问题
你的核心痛点之一是单样本单独存储为.pt文件,导致文件IO次数爆炸,加上多worker并行加载时,每个worker都会缓存大量小文件的内存,很快就把200GB内存吃光了。这里有几个关键优化点:
1. 合并小文件,减少IO开销
把分散的单样本文件打包成批量文件(比如每个文件包含1000个样本对),这样能大幅减少文件打开/关闭的次数,同时降低内存碎片化。你可以先写一个预处理脚本完成这件事:
# 预处理脚本示例(仅参考) import torch from pathlib import Path save_folder = Path(args.data_save_folder) batch_size_per_file = 1000 train_samples = 3000000 for batch_idx in range(0, train_samples, batch_size_per_file): end_idx = min(batch_idx + batch_size_per_file, train_samples) batch_data = [] for idx in range(batch_idx, end_idx): A = torch.load(save_folder/f"train_A_images_{idx}.pt") B = torch.load(save_folder/f"train_B_images_{idx}.pt") label = torch.load(save_folder/f"train_labels_{idx}.pt") batch_data.append((A, B, label)) torch.save(batch_data, save_folder/f"train_batch_{batch_idx//batch_size_per_file}.pt")
然后修改你的ImagePairDataset,让它先加载批量文件,再取对应索引的样本:
class ImagePairDataset(Dataset): def __init__(self, data_save_folder, dataset_name, num_samples, batch_size_per_file=1000, transform=None): self.data_save_folder = Path(data_save_folder) self.dataset_name = dataset_name self.num_samples = num_samples self.batch_size_per_file = batch_size_per_file self.num_batches = (num_samples + batch_size_per_file - 1) // batch_size_per_file self.transform = transform # 预缓存批量文件的路径(不用提前加载,只是存路径) self.batch_paths = [self.data_save_folder/f"{dataset_name}_batch_{i}.pt" for i in range(self.num_batches)] def __len__(self): return self.num_samples def __getitem__(self, idx): # 计算当前样本属于哪个批量文件 batch_idx = idx // self.batch_size_per_file intra_idx = idx % self.batch_size_per_file # 加载对应批量文件(可加简单缓存避免重复加载同一文件) batch_data = torch.load(self.batch_paths[batch_idx]) A_image, B_image, label = batch_data[intra_idx] if self.transform: A_image = self.transform(A_image) B_image = self.transform(B_image) return A_image, B_image, label
这样每个worker只会加载少量的批量文件,而不是成千上万的小文件,内存占用会骤降。
2. 优化DataLoader的内存参数
- 给
DataLoader加上pin_memory=True:它会把加载的数据锁在CPU内存里,加速后续向GPU的传输,同时避免内存页交换导致的额外开销。 - 调整
prefetch_factor:如果你的worker数量是12,试试prefetch_factor=2(每个worker预取2个batch),平衡预取量和内存占用。 - 控制worker数量:12个worker可能太多了,尤其是当你的磁盘IO不是SSD的时候,反而会因为竞争IO导致内存堆积。可以试试降到8或者6,观察内存和加载速度的变化。
二、解决初始化缓慢问题
你提到的“每个epoch前初始化慢”,本质上是每次epoch重启worker时,都要重新处理大量小文件的IO。结合上面的文件合并方案,再加上以下调整:
1. 启用persistent_workers=True并正确配置
你之前注释掉了这个参数,现在要把它加到DataLoader里:
def train_dataloader(self): return DataLoader( self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, persistent_workers=True, # 启用持久化worker,避免epoch间重复初始化 pin_memory=True, prefetch_factor=2 )
这个参数会让worker在epoch之间保持活跃,不用每次都重新初始化,大幅减少epoch间的等待时间。你的dataset是固定的,完全适用这个设置。
2. 避免重复创建Dataset
Lightning的setup函数默认只会在fit/test开始时执行一次,但如果你的代码有特殊逻辑导致重复创建dataset,可以把dataset的初始化移到__init__里,确保只创建一次:
class ImagePairDataModule(pl.LightningDataModule): def __init__(self, data_save_folder, train_samples, val_samples, test_samples, batch_size=32, num_workers=4): super().__init__() self.data_save_folder = data_save_folder self.train_samples = train_samples self.val_samples = val_samples self.test_samples = test_samples self.batch_size = batch_size self.num_workers = num_workers self.train_transform = augmentation self.eval_transform = normalize # 提前创建dataset,避免setup重复创建 self.train_dataset = ImagePairDataset(self.data_save_folder, 'train', self.train_samples, transform=self.train_transform) self.val_dataset = ImagePairDataset(self.data_save_folder, 'val', self.val_samples, transform=self.eval_transform) self.test_dataset = ImagePairDataset(self.data_save_folder, 'test', self.test_samples, transform=self.eval_transform) def setup(self, stage=None): # 这里不用再创建dataset了 pass
三、多GPU(DDP)下的Batch Size优化
你用4个GPU做DDP训练,要注意一个关键细节:Lightning中的batch_size参数是单GPU的batch size,也就是说总batch size是4 * batch_size。如果你的1024是总batch size,那要把batch_size设为256,这样每个GPU处理256个样本,避免单GPU内存溢出。
另外,DDP训练时,shuffle=True会自动保证每个GPU拿到不同的样本,Lightning的DataLoader会处理好这件事,你只要确保Dataset的__getitem__是线程安全的就行(上面的合并文件方案是线程安全的)。
额外的小技巧
- 用
torchvision.transforms.v2替代旧版transform:新版transform支持批量处理,能提升增强操作的效率,同时兼容旧版代码。 - 升级磁盘硬件:如果你的磁盘是HDD,IO速度会成为瓶颈,换成SSD或者NVMe对大数据集的加载速度提升非常明显。
- 内存映射加载:如果你的批量文件很大,可以用
torch.load(..., mmap=True),这样文件会以内存映射的方式加载,不用一次性把整个文件读到内存里,进一步降低内存占用。
备注:内容来源于stack exchange,提问作者Nick Nick Nick




