- 确定哪个节点导致了数据碰撞,可以使用以下代码:
import torch.distributed as dist
if dist.get_rank() == 0:
# define tensor to send
tensor_to_send = torch.ones([2, 2])
else:
tensor_to_send = torch.ones([2, 2]) * dist.get_rank()
dist.all_reduce(tensor_to_send)
- 确认哪些节点在 all_gather 中发生了碰撞。可以修改代码将 all_gather 的结果写入文件:
import os
import torch.distributed as dist
tensor_list = [torch.ones(2) * dist.get_rank()]
gathered = [torch.ones(2) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, tensor_list[dist.get_rank()])
# write the gathered results to a file
gathered_list = torch.stack(gathered, dim=0).tolist()
with open('gathered_results.txt', 'w') as f:
f.write(str(gathered_list))
然后可以查看 gathered_results.txt 文件,确定哪些结果发生了碰撞。
3. 在调试所有 all_gather 结果时,始终将结果写入文件以便查看结果。或者,可以使用 PyTorch 中的 DDP(分布式数据并行)来管理数据同步和通信。 DDP 可以在分布式训练中自动管理 Tensor 并确保所有操作的同步。使用 DDP 示例:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import os
def setup(rank, world_size):
# initialize the process group
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
dist