Post

DDP做强化学习分布式多机多卡训练加速

DDP做强化学习分布式多机多卡训练加速

前话

前文提到了ray.rllib来做分布式训练. 但对整个代码的改造会很复杂, 如果只需要简单的进行分布式多机多卡训练,还可以使用pytorch自带的DDP

DDP分布式同步原理

在反向传播后,通过高效的AllReduce操作同步所有GPU的梯度均值,确保各GPU使用相同的梯度更新本地模型参数,从而实现分布式训练的参数一致性. 以下为DDP训练的数据拆分示意图:

代码修改关键点

初始化分布式进程组

每个进程需要初始化通信后端(如NCCL)并获取全局信息(rankworld_size)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch.distributed as dist

os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port

# 初始化进程组
dist.init_process_group(
    backend="nccl",  # 使用NCCL作为高性能通信后端(GPU场景, CPU用gloo)
    init_method="env://",  # 从环境变量获取MASTER_ADDR和MASTER_PORT
    # rank=0, # 可以自己配置rank和world_size
    # world_size=1,
)
rank = dist.get_rank()        # 当前进程的全局ID(0 ~ world_size-1)
world_size = dist.get_world_size()  # 总进程数(通常一个GPU一个进程)

包装模型为DDP模型

将策略网络(Policy Network)或价值网络(Value Network)用DistributedDataParallel封装。

1
2
3
4
5
from torch.nn.parallel import DistributedDataParallel as DDP

# 假设模型已在GPU上(需通过rank指定设备)
model = PolicyNetwork().cuda(rank)
ddp_model = DDP(model, device_ids=[rank])  # 关键修改:包装为DDP模型

确保各进程独立采集数据

强化学习的核心是环境交互,需为每个进程分配独立的环境和随机种子,避免数据冗余。

1
2
3
4
5
6
7
# 每个进程有不同的随机种子(例如用rank作为种子)
def make_env(seed):
    env = gym.make("Pendulum-v1")
    env.seed(seed)
    return env

env = make_env(seed=rank)  # 每个进程的env独立交互

梯度同步与参数更新

反向传播时,DDP自动同步梯度,无需手动操作。

1
2
3
4
5
6
7
8
# 训练循环(以策略梯度为例)
for episode in range(num_episodes):
    states, actions, rewards = collect_trajectory(env, ddp_model)
    loss = compute_loss(states, actions, rewards)

    optimizer.zero_grad()
    loss.backward()  # DDP自动在此处同步梯度
    optimizer.step()  # 各进程独立更新参数(因梯度已同步,参数保持一致)

处理分布式数据采样(如经验回放)

若使用经验回放缓冲区,需确保各进程独立填充和采样数据。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class DistributedReplayBuffer:
    def __init__(self):
        self.buffer = []

    def add(self, experience):
        # 各进程独立填充自己的缓冲区
        self.buffer.append(experience)

    def sample(self):
        # 各进程独立采样,无需同步
        return random.sample(self.buffer, batch_size)

# 每个进程维护独立的buffer
buffer = DistributedReplayBuffer()

启动多进程训练脚本

使用torch.multiprocessing.spawn启动多进程。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch.multiprocessing as mp

def main_worker(rank, world_size):
    # 上述初始化、模型包装、训练代码
    pass

if __name__ == "__main__":
    world_size = 4  # 假设启动4个进程
    mp.spawn( # spawn 方法是在启动多进程的时候启用全新的python解释器环境
        main_worker,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )

关键注意事项

  • 参数一致性:DDP会自动在初始化时广播模型参数,确保所有进程初始一致。

  • 梯度聚合loss.backward()内部通过AllReduce同步梯度,无需手动调用通信操作。

  • 环境独立性:各进程的环境需有不同的随机种子,避免生成重复数据。

  • 优化器状态:优化器(如Adam)的动量参数会因同步梯度而自然保持一致。

This post is licensed under CC BY 4.0 by the author.