You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何让Python多进程共享PyTorch导入内存以避免内存浪费?

如何让Python多进程共享PyTorch导入内存以避免内存浪费?

这个问题确实挺棘手的,尤其是要启动大量进程的时候,PyTorch导入时的内存开销会被成倍放大。我来分享几个实用的思路和解决方法:

一、让多进程共享PyTorch导入内存的方法

1. 坚持使用fork启动方式(你已经做对了一半)

Python的multiprocessing中,fork模式会直接复制父进程的整个地址空间,包括已经导入的PyTorch模块。这里的关键是写时复制(Copy-On-Write)机制:只要子进程不修改共享的内存区域(比如PyTorch的全局对象、模块代码等),这些内存就会在所有子进程间共享,不会被重复分配。

你例子里已经设置了mp.set_start_method("fork"),这是正确的选择。从输出也能看出来:父进程导入PyTorch后内存增加了200MB左右,但子进程的内存只有134MB,这就说明大部分内存是和父进程共享的,没有重复占用。

2. 务必在父进程创建进程池前导入PyTorch

一定要把import torch放在创建mp.Pool之前,这样所有子进程启动时都会继承父进程已经加载好的PyTorch模块,不需要自己再重新导入一遍。如果在子进程函数里才导入,那每个子进程都会独立加载PyTorch,内存开销会直接乘以进程数,那才是真的浪费。

你的示例代码已经遵循了这个原则,这很好,继续保持就行。

二、排查PyTorch导入时内存占用的原因

如果想搞清楚到底是PyTorch的哪部分占用了这么多内存,可以试试这些工具:

1. 用tracemalloc跟踪内存分配

Python内置的tracemalloc可以精准记录导入过程中的内存分配情况,帮你找到内存开销最大的部分:

import tracemalloc

# 启动内存跟踪
tracemalloc.start()

# 导入PyTorch
import torch

# 生成内存快照并分析
snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')

print("导入PyTorch时内存占用Top 10:")
for stat in top_stats[:10]:
    print(stat)

2. 用pympler分析对象内存占用

pympler库可以统计进程中所有Python对象的内存占用,能帮你看到PyTorch导入后创建了哪些大对象:

from pympler import muppy, summary

# 导入PyTorch
import torch

# 获取所有对象并生成统计
all_objects = muppy.get_objects()
memory_summary = summary.summarize(all_objects)

print("内存占用统计:")
summary.print_(memory_summary)

3. 查看PyTorch初始化逻辑

PyTorch在导入时会做不少初始化工作:加载CPU算子库、注册各类模块、初始化全局上下文等,这些都会占用内存。你可以去查看PyTorch源码里的__init__.py文件,或者相关核心模块的初始化代码,看看有没有可以通过环境变量或配置禁用的非必要项(比如CPU版可以试试设置TORCH_NO_COMPILE=1,不过这个参数不一定能大幅降低内存,仅供参考)。

补充说明

从你的示例输出来看,当前的fork模式已经在帮你共享内存了,子进程的内存开销远低于父进程导入PyTorch的增量。如果还想进一步优化,要尽量避免在子进程中修改PyTorch的全局对象(比如修改默认的dtype、设备等),否则会触发写时复制,导致共享内存被复制一份,增加内存占用。

备注:内容来源于stack exchange,提问作者Bojian Zheng

火山引擎 最新活动