You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

优化PyTorch DataLoader:全高清图像小补丁读取训练加速

优化PyTorch大尺寸图像数据加载的实用方案

很棒的问题!数据加载确实是小模型训练时的常见性能瓶颈——毕竟模型计算快,大部分时间都耗在IO和预处理上了。你提到的“一次读取多补丁”思路非常合理,能直接减少IO次数,下面我再分享几个针对性的优化方向,帮你进一步提升数据加载速度:

1. 内存缓存:避免重复读取完整图像

每次__getitem__都重新读取1920x1080的图像/EXR文件,是最大的IO开销来源。可以在Dataset中加入内存缓存,把已经读取过的完整图像存在字典里,后续再用到时直接从缓存取,不用重复读盘:

class Ours(data.Dataset):
    def __init__(self, data_dirpath, split_name, patch_size):
        super().__init__()
        # ... 原有初始化代码 ...
        self.cache = {}  # 缓存已加载的完整图像/深度/掩码

    def get_image(self, path: Path, patch_start_point: tuple):
        h, w = patch_start_point
        # 先检查缓存
        if path not in self.cache:
            # 用更快的读取库替代skimage,比如OpenCV
            import cv2
            image = cv2.imread(path.as_posix())[:, :, ::-1]  # BGR转RGB
            self.cache[path] = image.astype(numpy.float32) / 255 * 2 - 1
        # 直接从缓存裁剪
        image = self.cache[path][h:h + self.patch_size, w:w + self.patch_size, :3]
        return numpy.moveaxis(image, [0,1,2], [1,2,0])

    # 同理修改get_mask和get_depth,加入缓存逻辑

注意:如果数据集太大(比如上万张图),内存不够存所有缓存,可以用functools.lru_cache设置缓存上限,或者用磁盘缓存(提前把完整图像转成二进制格式存在临时文件夹)。

2. 替换更快的图像读取库

skimage.io.imread的读取速度不算快,换成OpenCVPIL能显著提速:

  • OpenCV:cv2.imread是C++底层实现,速度最快,注意转RGB通道(默认BGR);
  • PIL:Image.open(path).convert('RGB'),读取后转numpy数组也很方便。

对于EXR文件,除了OpenEXR,还可以试试pyexr库,它的API更简洁,读取速度也不错:

import pyexr
def get_depth(self, path: Path, patch_start_point: tuple, mask: numpy.ndarray):
    h, w = patch_start_point
    if path not in self.cache:
        depth = pyexr.read(path.as_posix())[..., 0]  # 取B通道(假设深度存在B通道)
        self.cache[path] = depth
    depth = self.cache[path][h:h+self.patch_size, w:w+self.patch_size][None]
    return depth.astype(numpy.float32) * mask

3. 优化DataLoader的多进程设置

你已经用了num_workers=4pin_memory=True,这两个是基础操作,再加上这两个设置能进一步优化:

  • persistent_workers=True:PyTorch 1.7+支持,worker进程在epoch之间不会销毁,避免重复启动进程的开销;
  • worker_init_fn:给每个worker设置独立的随机种子,避免多进程下随机裁剪的结果重复:
def worker_init_fn(worker_id):
    numpy.random.seed(numpy.random.get_state()[0] + worker_id)

train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
    persistent_workers=True,
    worker_init_fn=worker_init_fn
)

4. 提前计算有效补丁位置,避免无效裁剪

你的Dataset注释提到“补丁必须包含至少一个未知像素”,但当前代码没有做这个检查——如果随机生成的补丁不符合条件,相当于白做了一次裁剪和读取。可以在__init__阶段提前对每个图像计算所有有效的补丁位置,后续直接从有效列表中随机选取:

class Ours(data.Dataset):
    def __init__(self, data_dirpath, split_name, patch_size):
        super().__init__()
        # ... 原有初始化代码 ...
        self.valid_patches = []  # 存储每个样本的有效补丁位置列表
        for video_name, view_num in self.video_names:
            mask_path = self.dataroot / video_name / f'render/masks/{view_num + 1:04}.png'
            mask = skimage.io.imread(mask_path.as_posix())
            # 计算所有满足条件的(h,w):补丁内至少有一个非零像素
            valid_h = []
            valid_w = []
            for h in range(1080 - patch_size + 1):
                for w in range(1920 - patch_size + 1):
                    if numpy.any(mask[h:h+patch_size, w:w+patch_size]):
                        valid_h.append(h)
                        valid_w.append(w)
            self.valid_patches.append(list(zip(valid_h, valid_w)))

    def __getitem__(self, index):
        video_name, view_num = self.video_names[index]
        # 从有效列表中随机选一个位置
        patch_start_pt = numpy.random.choice(self.valid_patches[index])
        # ... 后续原有代码 ...

这样既保证了补丁符合要求,又避免了在__getitem__中重复检查的开销。

5. 你的多补丁方案:进一步放大IO效率

你提到的“一次读取生成4个补丁”方案非常值得尝试,这里给你补充一个具体的实现思路:

  • 修改__getitem__返回包含4个补丁的字典,每个字段的形状是(4, C, H, W)
  • 自定义collate_fn,把每个batch的4个补丁拼接成(BS*4, C, H, W)的大批次;
  • 同时把DataLoader的batch_size缩小为原来的1/4,保证总批次大小不变。

示例代码:

def collate_fn(batch):
    # batch是一个列表,每个元素是包含4个补丁的字典
    new_batch = {}
    for key in batch[0].keys():
        # 把每个样本的4个补丁拼接:(BS,4,C,H,W) -> (BS*4,C,H,W)
        new_batch[key] = torch.cat([item[key] for item in batch], dim=0)
    return new_batch

# 假设原来batch_size=16,现在用4
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
    persistent_workers=True,
    collate_fn=collate_fn
)

这个方案能把IO次数减少到原来的1/4,对于IO瓶颈的场景提升非常明显。

总结优化优先级

按照落地难度和收益排序,推荐你按这个顺序尝试:

  1. 替换更快的图像读取库(OpenCV/PIL + pyexr);
  2. 加入内存缓存;
  3. 开启persistent_workersworker_init_fn
  4. 提前计算有效补丁位置;
  5. 落地多补丁方案。

内容的提问来源于stack exchange,提问作者Nagabhushan S N

火山引擎 最新活动