Pytorch DDP分布式数据合并通信 torch.distributed.all_gather()
import torch
import torch.distributed as dist
def all_gather_ddp(data):
"""
使用PyTorch的torch.distributed.all_gather()函数,
收集分布式训练中每个进程的数据。
这个函数将data的副本发送到所有其他进程,并收集
来自所有进程的这些副本,最后将它们连接起来。
"""
# 确定当前进程的设备
device = data.device
# 所有进程都需要知道收集数据的总大小
world_size = dist.get_world_size()
all_sizes = torch.tensor([data.size(0)], dtype=torch.int64, device=device)
dist.all_gather(all_sizes, all_sizes)
# 计算所有进程发送的数据总大小
max_size = all_sizes.max()
# 对输入tensor进行扩展以容纳从其他进程收集的数据
if data.dim() == 1:
output = data.new_full((max_size,), fill_value=0)
else:
output = data.new_full((max_size, data.size(1)), fill_value=0)
# 收集数据
all_data = [data.new_zeros(size) for size in all_sizes]
dist.all_gather(all_data, data)
# 将所有收集到的数据拼接起来
if data.dim() == 1:
output[:data.size(0)] = data
for i in range(world_size - 1):
offset = all_sizes[:i].sum()
output[offset:offset + all_sizes[i]] = all_data[i]
else:
for i in range(world_size):
offset = i * data.size(0)
output[offset:offset + all_sizes[i]] = all_data[i]
return output
# 示例使用
# 假设已经初始化了进程组并设置了当前进程的 rank
# 以下代码在每个进程上执行
rank = dist.get_rank()
tensor_to_gather = torch.tensor([rank] * 5, dtype=torch.float32)
gathered_tensor = all_gather_ddp(tensor_to_gather)
print(f"进程 {rank} 收集到的数据: {gathered_tensor}")
这个代码示例提供了一个简化版本的all_gather_ddp
函数,它可以在PyTorch的分布式数据并行(DDP)环境中使用。这个函数用于收集每个进程的数据,并将它们合并成一个包含所有进程数据的单一tensor。这对于在训练过程中收集每个模型参数的梯度或是每个batch的输出非常有用。
评论已关闭