PyTorch多DataLoader多数据集评估场景下如何避免进程反复创建与销毁?
PyTorch多DataLoader多数据集评估场景下如何避免进程反复创建与销毁?
兄弟,我太懂你这个痛点了!Windows下用PyTorch做多数据集多分辨率评估时,num_workers>0导致的进程反复创建销毁,简直是效率杀手——有时候进程启动销毁的时间比评估本身还长,太闹心了。而且你说的那个进程串行启动的问题,确实是Windows平台的坑,因为Windows的多进程机制和Linux不一样,每次启动子进程都要重新加载整个Python解释器和依赖,还得挨个来,CPU越多反而启动越慢。
针对你的场景,我有两个实用的解决方案,都是围绕复用进程池这个核心思路来的,不用反复折腾进程的创建销毁:
方案一:用全局multiprocessing.Pool统一管理进程,让所有DataLoader复用
这个思路是提前启动一批进程,让所有数据集的加载和transform都交给这个全局进程池处理,DataLoader本身只负责数据的批量整理,不再自己创建子进程。
import multiprocessing as mp import torch from torch.utils.data import IterableDataset, DataLoader # 提前启动全局进程池,进程数根据CPU情况调整(别拉满,留余量) global_pool = mp.Pool(processes=10) def load_and_transform(sample_path, transform): """通用数据加载+transform函数,适配所有分辨率""" # 这里替换成你实际的图片加载逻辑 img = ... # 从sample_path读取图片文件 # 应用对应分辨率的transform return transform(img) class PoolReuseDataset(IterableDataset): def __init__(self, sample_paths, transform): self.sample_paths = sample_paths self.transform = transform def __iter__(self): # 用全局进程池批量处理数据,所有DataLoader复用这些进程 task_list = [(path, self.transform) for path in self.sample_paths] # starmap会按顺序返回结果,imap_unordered是乱序的(速度更快,评估场景如果不要求顺序可以用这个) yield from global_pool.starmap(load_and_transform, task_list) # 准备不同分辨率的transform和样本路径 transform_256 = ... # 256分辨率的transform pipeline transform_512 = ... # 512分辨率的transform pipeline sample_paths_256 = [...] # 256分辨率样本的路径列表 sample_paths_512 = [...] # 512分辨率样本的路径列表 # 创建数据集和DataLoader(注意num_workers要设为0,因为我们用全局进程池了) dataset_256 = PoolReuseDataset(sample_paths_256, transform_256) dataloader_256 = DataLoader(dataset_256, batch_size=32, num_workers=0) dataset_512 = PoolReuseDataset(sample_paths_512, transform_512) dataloader_512 = DataLoader(dataset_512, batch_size=16, num_workers=0) # 开始评估 for model in your_model_list: model.eval() with torch.no_grad(): # 评估256分辨率 for batch in dataloader_256: outputs = model(batch) # 计算评估指标... # 评估512分辨率 for batch in dataloader_512: outputs = model(batch) # 计算评估指标... # 所有评估结束后,记得关闭全局进程池 global_pool.close() global_pool.join()
方案二:复用单个DataLoader的进程池(利用persistent_workers=True)
如果不想折腾全局进程池,也可以利用PyTorch自带的persistent_workers参数,只要不销毁DataLoader实例,进程就会一直存活。我们可以通过动态切换DataLoader的dataset属性,让它处理不同分辨率的数据集,全程复用同一批进程。
这个方案更贴近PyTorch原生用法,代码更简洁:
from torch.utils.data import Dataset, DataLoader class SingleResolutionDataset(Dataset): def __init__(self, sample_paths, transform): self.sample_paths = sample_paths self.transform = transform def __len__(self): return len(self.sample_paths) def __getitem__(self, idx): sample_path = self.sample_paths[idx] img = ... # 加载图片 return self.transform(img) # 准备不同分辨率的数据集 dataset_256 = SingleResolutionDataset(sample_paths_256, transform_256) dataset_512 = SingleResolutionDataset(sample_paths_512, transform_512) # 只创建一个DataLoader,设置persistent_workers=True shared_dataloader = DataLoader( dataset_256, batch_size=32, num_workers=10, persistent_workers=True, shuffle=False # 评估场景不需要shuffle ) # 开始评估 for model in your_model_list: model.eval() with torch.no_grad(): # 先评估256分辨率(当前dataset已经是256的,直接迭代) for batch in shared_dataloader: outputs = model(batch) # 计算指标... # 切换到512分辨率的数据集,进程池会复用,不需要重新启动 shared_dataloader.dataset = dataset_512 # 评估512分辨率 for batch in shared_dataloader: outputs = model(batch) # 计算指标...
额外的Windows平台优化小技巧
- 一定要把主逻辑放到
if __name__ == '__main__':里:Windows下的spawn多进程机制会让子进程重新加载整个脚本,如果主逻辑没加这个判断,会导致子进程无限创建,不仅慢还会崩溃。 - 进程数别拉满:比如12核CPU,设8-10个进程就好,留2-3个给模型推理和系统后台进程,避免CPU过载导致整体变慢。
- 把transform逻辑拆到单独模块:如果你的transform里有大的预训练模型(比如特征提取器),把它放到单独的py文件里,子进程加载时会更快,减少重复初始化的开销。
总结一下,不管用哪种方案,核心都是避免让每个DataLoader单独创建销毁进程,要么提前启动全局进程池统一管理,要么复用单个DataLoader的进程池。这两个方案都能解决你说的进程反复创建销毁的问题,也能避开Windows下进程串行启动的坑,提升评估的整体效率。




