Pytorch DDP分布式数据合并通信 torch.distributed.all_gather()




import torch
import torch.distributed as dist
 
def all_gather_ddp(data):
    """
    使用PyTorch的torch.distributed.all_gather()函数,
    收集分布式训练中每个进程的数据。
    这个函数将data的副本发送到所有其他进程,并收集
    来自所有进程的这些副本,最后将它们连接起来。
    """
    # 确定当前进程的设备
    device = data.device
    
    # 所有进程都需要知道收集数据的总大小
    world_size = dist.get_world_size()
    all_sizes = torch.tensor([data.size(0)], dtype=torch.int64, device=device)
    dist.all_gather(all_sizes, all_sizes)
    
    # 计算所有进程发送的数据总大小
    max_size = all_sizes.max()
    
    # 对输入tensor进行扩展以容纳从其他进程收集的数据
    if data.dim() == 1:
        output = data.new_full((max_size,), fill_value=0)
    else:
        output = data.new_full((max_size, data.size(1)), fill_value=0)
    
    # 收集数据
    all_data = [data.new_zeros(size) for size in all_sizes]
    dist.all_gather(all_data, data)
    
    # 将所有收集到的数据拼接起来
    if data.dim() == 1:
        output[:data.size(0)] = data
        for i in range(world_size - 1):
            offset = all_sizes[:i].sum()
            output[offset:offset + all_sizes[i]] = all_data[i]
    else:
        for i in range(world_size):
            offset = i * data.size(0)
            output[offset:offset + all_sizes[i]] = all_data[i]
    
    return output
 
# 示例使用
# 假设已经初始化了进程组并设置了当前进程的 rank
# 以下代码在每个进程上执行
rank = dist.get_rank()
tensor_to_gather = torch.tensor([rank] * 5, dtype=torch.float32)
gathered_tensor = all_gather_ddp(tensor_to_gather)
print(f"进程 {rank} 收集到的数据: {gathered_tensor}")

这个代码示例提供了一个简化版本的all_gather_ddp函数,它可以在PyTorch的分布式数据并行(DDP)环境中使用。这个函数用于收集每个进程的数据,并将它们合并成一个包含所有进程数据的单一tensor。这对于在训练过程中收集每个模型参数的梯度或是每个batch的输出非常有用。

最后修改于:2024年08月09日 12:33

评论已关闭

推荐阅读

DDPG 模型解析,附Pytorch完整代码
2024年11月24日
DQN 模型解析,附Pytorch完整代码
2024年11月24日
AIGC实战——Transformer模型
2024年12月01日
Socket TCP 和 UDP 编程基础(Python)
2024年11月30日
python , tcp , udp
如何使用 ChatGPT 进行学术润色?你需要这些指令
2024年12月01日
AI
最新 Python 调用 OpenAi 详细教程实现问答、图像合成、图像理解、语音合成、语音识别(详细教程)
2024年11月24日
ChatGPT 和 DALL·E 2 配合生成故事绘本
2024年12月01日
omegaconf,一个超强的 Python 库!
2024年11月24日
【视觉AIGC识别】误差特征、人脸伪造检测、其他类型假图检测
2024年12月01日
[超级详细]如何在深度学习训练模型过程中使用 GPU 加速
2024年11月29日
Python 物理引擎pymunk最完整教程
2024年11月27日
MediaPipe 人体姿态与手指关键点检测教程
2024年11月27日
深入了解 Taipy:Python 打造 Web 应用的全面教程
2024年11月26日
基于Transformer的时间序列预测模型
2024年11月25日
Python在金融大数据分析中的AI应用(股价分析、量化交易)实战
2024年11月25日
AIGC Gradio系列学习教程之Components
2024年12月01日
Python3 `asyncio` — 异步 I/O,事件循环和并发工具
2024年11月30日
llama-factory SFT系列教程:大模型在自定义数据集 LoRA 训练与部署
2024年12月01日
Python 多线程和多进程用法
2024年11月24日
Python socket详解,全网最全教程
2024年11月27日