You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何解决PyTorch分布式训练中多进程重复打印日志的问题?

解决PyTorch分布式训练中日志重复打印的问题

我在做PyTorch分布式训练时也碰到过这个一模一样的问题——多个进程同时输出日志,导致控制台或者日志文件里全是重复内容,看着特别闹心。这里分享几个我亲测有效的解决方法:

方法一:仅让主进程(rank=0)输出日志

这是最简单直接的方案,PyTorch分布式环境中每个进程都有一个唯一的rank标识,通常rank=0的进程是主进程。我们只需要在日志打印的代码前加一个判断,只有主进程才执行日志输出:

import torch
import logging

# 初始化分布式环境(假设已完成此步骤)
torch.distributed.init_process_group(backend="nccl")

# 配置logging基础设置
logging.basicConfig(level=logging.INFO)

# 仅主进程打印日志
if torch.distributed.get_rank() == 0:
    logging.info(f"Epoch {epoch}, Loss: {loss.item()}")

这种方法简洁高效,不会产生冗余日志,适合只关心全局训练状态的场景。

方法二:为每个进程生成独立的日志文件

如果需要排查单个进程的问题,不想把所有进程的日志混在一起,可以让每个进程将日志输出到单独的文件中,文件名包含进程的rank

import torch
import logging

rank = torch.distributed.get_rank()

# 为每个进程配置独立的日志文件
logging.basicConfig(
    filename=f"training_log_rank_{rank}.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

# 正常打印日志,每个进程的日志会写入专属文件
logging.info(f"Batch {batch_idx}, Accuracy: {acc.item()}")

这样每个进程的日志独立存储,既避免了控制台混乱,也方便后续排查特定进程的问题。

方法三:自定义Logging过滤器实现进程级过滤

如果觉得在每个日志语句前加判断太繁琐,可以自定义一个Logging过滤器,自动过滤掉非主进程的日志:

import torch
import logging

class RankFilter(logging.Filter):
    def filter(self, record):
        # 仅允许rank=0的进程的日志通过
        return torch.distributed.get_rank() == 0

# 初始化分布式环境
torch.distributed.init_process_group(backend="nccl")

# 获取根logger并添加过滤器
logger = logging.getLogger()
logger.addFilter(RankFilter())
logger.setLevel(logging.INFO)

# 后续正常使用logging即可,非主进程的日志会被自动过滤
logging.info(f"Learning rate updated to {lr}")

这种方法更优雅,不需要修改大量日志打印代码,适合大型项目使用。


内容的提问来源于stack exchange,提问作者Kirk Walla

火山引擎 最新活动