PyTorch DataLoader多进程未并行运行的原因及解决方法问询
问题描述
我编写了以下测试脚本:
from torch.utils.data import Dataset, DataLoader import time import multiprocessing as mp import torch class Sleep(Dataset): def __len__(self): return 20 def __getitem__(self, i): import time, os time.sleep(1) return os.getpid() if __name__ == "__main__": mp.set_start_method("fork", force=True) loader = DataLoader(Sleep(), batch_size=20, num_workers=10, persistent_workers=True) t0 = time.time() next(iter(loader)) print("wall = ", time.time() - t0) # should be ≈ 3-4 s, not 180 s next(iter(loader)) print("wall = ", time.time() - t0) # should be ≈ 3-4 s, not 180 s next(iter(loader)) print("wall = ", time.time() - t0) # should be ≈ 3-4 s, not 180 s import pdb; pdb.set_trace()
在均配备至少10核CPU的不同系统上运行时,每个batch的加载耗时均为20秒,说明数据加载处于串行状态而非多进程并行。请问这是什么原因?如何真正实现DataLoader的并行化?
原因分析与解决方案
核心原因
fork启动方式的兼容性问题:fork会直接复制父进程的地址空间,容易导致PyTorch内部的资源(如线程锁、状态变量)出现冲突,使得worker进程无法正常并行工作,最终退化为串行加载。- 迭代器使用方式错误:每次调用
iter(loader)都会重新创建迭代器,即便开启了persistent_workers,也会导致worker进程频繁重启,无法发挥并行优势。 - 数据集长度与batch_size不匹配:
__len__返回20,batch_size设为20,第一个迭代就取完所有样本,后续next操作本质是重新遍历整个数据集,进一步放大了串行问题。
解决方案
1. 更换进程启动方式为spawn
spawn是PyTorch推荐的跨平台启动方式,会创建全新的Python进程,避免父进程资源继承带来的冲突:
if __name__ == "__main__": mp.set_start_method("spawn", force=True)
2. 正确复用DataLoader迭代器
创建一次迭代器后重复使用,让persistent_workers真正发挥作用:
t0 = time.time() loader_iter = iter(loader) # 后续直接复用该迭代器 next(loader_iter)
3. 调整数据集长度与batch_size比例
修改__len__为60(支持3个batch),保证迭代器能正常获取多个批次数据:
class Sleep(Dataset): def __len__(self): return 60 def __getitem__(self, i): time.sleep(1) return os.getpid()
4. 移除__getitem__内的重复导入
将模块导入移到类外部,避免每次调用__getitem__都重复执行导入操作:
import os, time class Sleep(Dataset): # ... 其余代码不变
修改后的完整验证代码
from torch.utils.data import Dataset, DataLoader import time import multiprocessing as mp import torch import os class Sleep(Dataset): def __len__(self): return 60 def __getitem__(self, i): time.sleep(1) return os.getpid() if __name__ == "__main__": mp.set_start_method("spawn", force=True) loader = DataLoader(Sleep(), batch_size=20, num_workers=10, persistent_workers=True) t0 = time.time() loader_iter = iter(loader) # 第一个batch next(loader_iter) print("First batch wall time:", round(time.time() - t0, 2)) # 约2秒 # 第二个batch next(loader_iter) print("Second batch wall time:", round(time.time() - t0, 2)) # 约3秒 # 第三个batch next(loader_iter) print("Third batch wall time:", round(time.time() - t0, 2)) # 约4秒
运行后会看到每个batch的加载时间符合并行预期:10个worker同时处理,每个worker负责2个样本,单批次耗时约2秒,后续批次因persistent workers预加载,耗时进一步缩短。
内容的提问来源于stack exchange,提问作者user3180




