PyTorch分布式训练随机种子设置方法及训练变慢问题咨询
一、种子设置是否需要移到main_worker中?
答案是必须移到main_worker里,而且要做进程专属的种子调整,原因如下:
主进程的种子设置无法保证子进程的随机性独立
你当前在main()里设置的种子,只会作用于主进程。通过mp.spawn创建的子进程虽然会继承主进程的随机状态,但如果所有进程用完全相同的种子,会导致数据加载、随机增强等操作在每个进程中完全一致——这会让分布式训练失去并行的意义(相当于多个进程在训练完全相同的数据)。正确的种子设置方式
在main_worker中,需要给每个进程的种子加上进程rank的偏移,区分模型初始化的全局随机性和数据处理的进程专属随机性:def main_worker(rank, args): # 1. 模型初始化的全局种子(所有进程保持一致,保证初始参数相同) torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) # 多GPU场景下更稳妥 # 2. 数据处理/增强的进程专属种子(每个进程不同,避免数据重复) local_seed = args.seed + rank np.random.seed(local_seed) random.seed(local_seed) # 如果用到Python标准库的random模块也要设置 # 后续初始化分布式环境、模型等 torch.distributed.init_process_group(backend='nccl', rank=rank, world_size=args.ngpus) # ... 你的训练代码另外,数据加载器还要配置
worker_init_fn,确保每个数据加载子进程的种子也独立:def worker_init_fn(worker_id): # 数据加载子进程的种子 = 全局种子 + 进程rank + 子进程ID worker_seed = args.seed + rank + worker_id np.random.seed(worker_seed) torch.manual_seed(worker_seed) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=True, worker_init_fn=worker_init_fn, # ... 其他参数 )
二、训练速度变慢一倍的原因分析
你遇到的速度下降问题,核心是cudnn.benchmark和cudnn.deterministic的冲突:
两个参数不能同时生效
cudnn.deterministic = True会强制CuDNN使用确定性的卷积算法,而cudnn.benchmark = True会让CuDNN自动测试并选择当前硬件下最快的卷积算法。这两个参数是互斥的——当你开启deterministic=True时,benchmark会被自动设置为False,此时CuDNN会放弃最快算法,转而使用确定性但更慢的实现,这直接导致训练速度大幅下降。你的旧代码中的矛盾
你在main()里同时设置了cudnn.benchmark = True和cudnn.deterministic = True,此时deterministic的优先级更高,benchmark其实已经失效了。但可能主进程的设置没有完全传递给子进程,导致之前的训练速度还能维持;当你把设置移到main_worker后,每个进程都严格开启了deterministic=True,速度就明显降下来了。解决速度问题的方案
根据你的需求二选一:- 追求速度(不需要严格复现):关闭
cudnn.deterministic,开启benchmark:cudnn.enabled = True cudnn.benchmark = True cudnn.deterministic = False - 追求复现性(接受速度下降):开启
cudnn.deterministic,关闭benchmark:cudnn.enabled = True cudnn.benchmark = False cudnn.deterministic = True
- 追求速度(不需要严格复现):关闭
总结
- 种子设置必须移到
main_worker中,并且区分全局种子(模型初始化)和进程专属种子(数据处理); - 训练速度变慢是因为
cudnn.deterministic开启导致的,根据需求调整CuDNN配置即可恢复速度。
内容的提问来源于stack exchange,提问作者sunshk1227




