You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

PyTorch DataLoader多进程未并行运行的原因及解决方法问询

问题描述

我编写了以下测试脚本:

from torch.utils.data import Dataset, DataLoader
import time
import multiprocessing as mp
import torch

class Sleep(Dataset):
    def __len__(self): return 20
    def __getitem__(self, i):
        import time, os
        time.sleep(1)
        return os.getpid()

if __name__ == "__main__":
    mp.set_start_method("fork", force=True)

    loader = DataLoader(Sleep(), batch_size=20, num_workers=10, persistent_workers=True)

    t0 = time.time()
    next(iter(loader))
    print("wall = ", time.time() - t0)   # should be ≈ 3-4 s, not 180 s

    next(iter(loader))
    print("wall = ", time.time() - t0)   # should be ≈ 3-4 s, not 180 s

    next(iter(loader))
    print("wall = ", time.time() - t0)   # should be ≈ 3-4 s, not 180 s

    import pdb; pdb.set_trace()

在均配备至少10核CPU的不同系统上运行时,每个batch的加载耗时均为20秒,说明数据加载处于串行状态而非多进程并行。请问这是什么原因?如何真正实现DataLoader的并行化?


原因分析与解决方案

核心原因

  1. fork启动方式的兼容性问题fork会直接复制父进程的地址空间,容易导致PyTorch内部的资源(如线程锁、状态变量)出现冲突,使得worker进程无法正常并行工作,最终退化为串行加载。
  2. 迭代器使用方式错误:每次调用iter(loader)都会重新创建迭代器,即便开启了persistent_workers,也会导致worker进程频繁重启,无法发挥并行优势。
  3. 数据集长度与batch_size不匹配__len__返回20,batch_size设为20,第一个迭代就取完所有样本,后续next操作本质是重新遍历整个数据集,进一步放大了串行问题。

解决方案

1. 更换进程启动方式为spawn

spawn是PyTorch推荐的跨平台启动方式,会创建全新的Python进程,避免父进程资源继承带来的冲突:

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)

2. 正确复用DataLoader迭代器

创建一次迭代器后重复使用,让persistent_workers真正发挥作用:

t0 = time.time()
loader_iter = iter(loader)
# 后续直接复用该迭代器
next(loader_iter)

3. 调整数据集长度与batch_size比例

修改__len__为60(支持3个batch),保证迭代器能正常获取多个批次数据:

class Sleep(Dataset):
    def __len__(self): return 60
    def __getitem__(self, i):
        time.sleep(1)
        return os.getpid()

4. 移除__getitem__内的重复导入

将模块导入移到类外部,避免每次调用__getitem__都重复执行导入操作:

import os, time

class Sleep(Dataset):
    # ... 其余代码不变

修改后的完整验证代码
from torch.utils.data import Dataset, DataLoader
import time
import multiprocessing as mp
import torch
import os

class Sleep(Dataset):
    def __len__(self): return 60
    def __getitem__(self, i):
        time.sleep(1)
        return os.getpid()

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)

    loader = DataLoader(Sleep(), batch_size=20, num_workers=10, persistent_workers=True)

    t0 = time.time()
    loader_iter = iter(loader)
    
    # 第一个batch
    next(loader_iter)
    print("First batch wall time:", round(time.time() - t0, 2))  # 约2秒
    
    # 第二个batch
    next(loader_iter)
    print("Second batch wall time:", round(time.time() - t0, 2))  # 约3秒
    
    # 第三个batch
    next(loader_iter)
    print("Third batch wall time:", round(time.time() - t0, 2))  # 约4秒

运行后会看到每个batch的加载时间符合并行预期:10个worker同时处理,每个worker负责2个样本,单批次耗时约2秒,后续批次因persistent workers预加载,耗时进一步缩短。

内容的提问来源于stack exchange,提问作者user3180

火山引擎 最新活动