如何让Python多进程共享PyTorch导入内存以避免内存浪费?
这个问题确实挺棘手的,尤其是要启动大量进程的时候,PyTorch导入时的内存开销会被成倍放大。我来分享几个实用的思路和解决方法:
一、让多进程共享PyTorch导入内存的方法
1. 坚持使用fork启动方式(你已经做对了一半)
Python的multiprocessing中,fork模式会直接复制父进程的整个地址空间,包括已经导入的PyTorch模块。这里的关键是写时复制(Copy-On-Write)机制:只要子进程不修改共享的内存区域(比如PyTorch的全局对象、模块代码等),这些内存就会在所有子进程间共享,不会被重复分配。
你例子里已经设置了mp.set_start_method("fork"),这是正确的选择。从输出也能看出来:父进程导入PyTorch后内存增加了200MB左右,但子进程的内存只有134MB,这就说明大部分内存是和父进程共享的,没有重复占用。
2. 务必在父进程创建进程池前导入PyTorch
一定要把import torch放在创建mp.Pool之前,这样所有子进程启动时都会继承父进程已经加载好的PyTorch模块,不需要自己再重新导入一遍。如果在子进程函数里才导入,那每个子进程都会独立加载PyTorch,内存开销会直接乘以进程数,那才是真的浪费。
你的示例代码已经遵循了这个原则,这很好,继续保持就行。
二、排查PyTorch导入时内存占用的原因
如果想搞清楚到底是PyTorch的哪部分占用了这么多内存,可以试试这些工具:
1. 用tracemalloc跟踪内存分配
Python内置的tracemalloc可以精准记录导入过程中的内存分配情况,帮你找到内存开销最大的部分:
import tracemalloc # 启动内存跟踪 tracemalloc.start() # 导入PyTorch import torch # 生成内存快照并分析 snapshot = tracemalloc.take_snapshot() top_stats = snapshot.statistics('lineno') print("导入PyTorch时内存占用Top 10:") for stat in top_stats[:10]: print(stat)
2. 用pympler分析对象内存占用
pympler库可以统计进程中所有Python对象的内存占用,能帮你看到PyTorch导入后创建了哪些大对象:
from pympler import muppy, summary # 导入PyTorch import torch # 获取所有对象并生成统计 all_objects = muppy.get_objects() memory_summary = summary.summarize(all_objects) print("内存占用统计:") summary.print_(memory_summary)
3. 查看PyTorch初始化逻辑
PyTorch在导入时会做不少初始化工作:加载CPU算子库、注册各类模块、初始化全局上下文等,这些都会占用内存。你可以去查看PyTorch源码里的__init__.py文件,或者相关核心模块的初始化代码,看看有没有可以通过环境变量或配置禁用的非必要项(比如CPU版可以试试设置TORCH_NO_COMPILE=1,不过这个参数不一定能大幅降低内存,仅供参考)。
补充说明
从你的示例输出来看,当前的fork模式已经在帮你共享内存了,子进程的内存开销远低于父进程导入PyTorch的增量。如果还想进一步优化,要尽量避免在子进程中修改PyTorch的全局对象(比如修改默认的dtype、设备等),否则会触发写时复制,导致共享内存被复制一份,增加内存占用。
备注:内容来源于stack exchange,提问作者Bojian Zheng




