优化PyTorch DataLoader:全高清图像小补丁读取训练加速
很棒的问题!数据加载确实是小模型训练时的常见性能瓶颈——毕竟模型计算快,大部分时间都耗在IO和预处理上了。你提到的“一次读取多补丁”思路非常合理,能直接减少IO次数,下面我再分享几个针对性的优化方向,帮你进一步提升数据加载速度:
1. 内存缓存:避免重复读取完整图像
每次__getitem__都重新读取1920x1080的图像/EXR文件,是最大的IO开销来源。可以在Dataset中加入内存缓存,把已经读取过的完整图像存在字典里,后续再用到时直接从缓存取,不用重复读盘:
class Ours(data.Dataset): def __init__(self, data_dirpath, split_name, patch_size): super().__init__() # ... 原有初始化代码 ... self.cache = {} # 缓存已加载的完整图像/深度/掩码 def get_image(self, path: Path, patch_start_point: tuple): h, w = patch_start_point # 先检查缓存 if path not in self.cache: # 用更快的读取库替代skimage,比如OpenCV import cv2 image = cv2.imread(path.as_posix())[:, :, ::-1] # BGR转RGB self.cache[path] = image.astype(numpy.float32) / 255 * 2 - 1 # 直接从缓存裁剪 image = self.cache[path][h:h + self.patch_size, w:w + self.patch_size, :3] return numpy.moveaxis(image, [0,1,2], [1,2,0]) # 同理修改get_mask和get_depth,加入缓存逻辑
注意:如果数据集太大(比如上万张图),内存不够存所有缓存,可以用functools.lru_cache设置缓存上限,或者用磁盘缓存(提前把完整图像转成二进制格式存在临时文件夹)。
2. 替换更快的图像读取库
skimage.io.imread的读取速度不算快,换成OpenCV或PIL能显著提速:
- OpenCV:
cv2.imread是C++底层实现,速度最快,注意转RGB通道(默认BGR); - PIL:
Image.open(path).convert('RGB'),读取后转numpy数组也很方便。
对于EXR文件,除了OpenEXR,还可以试试pyexr库,它的API更简洁,读取速度也不错:
import pyexr def get_depth(self, path: Path, patch_start_point: tuple, mask: numpy.ndarray): h, w = patch_start_point if path not in self.cache: depth = pyexr.read(path.as_posix())[..., 0] # 取B通道(假设深度存在B通道) self.cache[path] = depth depth = self.cache[path][h:h+self.patch_size, w:w+self.patch_size][None] return depth.astype(numpy.float32) * mask
3. 优化DataLoader的多进程设置
你已经用了num_workers=4和pin_memory=True,这两个是基础操作,再加上这两个设置能进一步优化:
persistent_workers=True:PyTorch 1.7+支持,worker进程在epoch之间不会销毁,避免重复启动进程的开销;worker_init_fn:给每个worker设置独立的随机种子,避免多进程下随机裁剪的结果重复:
def worker_init_fn(worker_id): numpy.random.seed(numpy.random.get_state()[0] + worker_id) train_data_loader = DataLoader( dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4, persistent_workers=True, worker_init_fn=worker_init_fn )
4. 提前计算有效补丁位置,避免无效裁剪
你的Dataset注释提到“补丁必须包含至少一个未知像素”,但当前代码没有做这个检查——如果随机生成的补丁不符合条件,相当于白做了一次裁剪和读取。可以在__init__阶段提前对每个图像计算所有有效的补丁位置,后续直接从有效列表中随机选取:
class Ours(data.Dataset): def __init__(self, data_dirpath, split_name, patch_size): super().__init__() # ... 原有初始化代码 ... self.valid_patches = [] # 存储每个样本的有效补丁位置列表 for video_name, view_num in self.video_names: mask_path = self.dataroot / video_name / f'render/masks/{view_num + 1:04}.png' mask = skimage.io.imread(mask_path.as_posix()) # 计算所有满足条件的(h,w):补丁内至少有一个非零像素 valid_h = [] valid_w = [] for h in range(1080 - patch_size + 1): for w in range(1920 - patch_size + 1): if numpy.any(mask[h:h+patch_size, w:w+patch_size]): valid_h.append(h) valid_w.append(w) self.valid_patches.append(list(zip(valid_h, valid_w))) def __getitem__(self, index): video_name, view_num = self.video_names[index] # 从有效列表中随机选一个位置 patch_start_pt = numpy.random.choice(self.valid_patches[index]) # ... 后续原有代码 ...
这样既保证了补丁符合要求,又避免了在__getitem__中重复检查的开销。
5. 你的多补丁方案:进一步放大IO效率
你提到的“一次读取生成4个补丁”方案非常值得尝试,这里给你补充一个具体的实现思路:
- 修改
__getitem__返回包含4个补丁的字典,每个字段的形状是(4, C, H, W); - 自定义
collate_fn,把每个batch的4个补丁拼接成(BS*4, C, H, W)的大批次; - 同时把DataLoader的
batch_size缩小为原来的1/4,保证总批次大小不变。
示例代码:
def collate_fn(batch): # batch是一个列表,每个元素是包含4个补丁的字典 new_batch = {} for key in batch[0].keys(): # 把每个样本的4个补丁拼接:(BS,4,C,H,W) -> (BS*4,C,H,W) new_batch[key] = torch.cat([item[key] for item in batch], dim=0) return new_batch # 假设原来batch_size=16,现在用4 train_data_loader = DataLoader( dataset=train_dataset, batch_size=4, shuffle=True, pin_memory=True, num_workers=4, persistent_workers=True, collate_fn=collate_fn )
这个方案能把IO次数减少到原来的1/4,对于IO瓶颈的场景提升非常明显。
总结优化优先级
按照落地难度和收益排序,推荐你按这个顺序尝试:
- 替换更快的图像读取库(OpenCV/PIL + pyexr);
- 加入内存缓存;
- 开启
persistent_workers和worker_init_fn; - 提前计算有效补丁位置;
- 落地多补丁方案。
内容的提问来源于stack exchange,提问作者Nagabhushan S N




