DDP做强化学习分布式多机多卡训练加速
DDP做强化学习分布式多机多卡训练加速
前话
前文提到了ray.rllib来做分布式训练. 但对整个代码的改造会很复杂, 如果只需要简单的进行分布式多机多卡训练,还可以使用pytorch自带的DDP
DDP分布式同步原理
在反向传播后,通过高效的AllReduce操作同步所有GPU的梯度均值,确保各GPU使用相同的梯度更新本地模型参数,从而实现分布式训练的参数一致性. 以下为DDP训练的数据拆分示意图:
代码修改关键点
初始化分布式进程组
每个进程需要初始化通信后端(如NCCL)并获取全局信息(rank
, world_size
)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch.distributed as dist
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
# 初始化进程组
dist.init_process_group(
backend="nccl", # 使用NCCL作为高性能通信后端(GPU场景, CPU用gloo)
init_method="env://", # 从环境变量获取MASTER_ADDR和MASTER_PORT
# rank=0, # 可以自己配置rank和world_size
# world_size=1,
)
rank = dist.get_rank() # 当前进程的全局ID(0 ~ world_size-1)
world_size = dist.get_world_size() # 总进程数(通常一个GPU一个进程)
包装模型为DDP模型
将策略网络(Policy Network)或价值网络(Value Network)用DistributedDataParallel
封装。
1
2
3
4
5
from torch.nn.parallel import DistributedDataParallel as DDP
# 假设模型已在GPU上(需通过rank指定设备)
model = PolicyNetwork().cuda(rank)
ddp_model = DDP(model, device_ids=[rank]) # 关键修改:包装为DDP模型
确保各进程独立采集数据
强化学习的核心是环境交互,需为每个进程分配独立的环境和随机种子,避免数据冗余。
1
2
3
4
5
6
7
# 每个进程有不同的随机种子(例如用rank作为种子)
def make_env(seed):
env = gym.make("Pendulum-v1")
env.seed(seed)
return env
env = make_env(seed=rank) # 每个进程的env独立交互
梯度同步与参数更新
反向传播时,DDP自动同步梯度,无需手动操作。
1
2
3
4
5
6
7
8
# 训练循环(以策略梯度为例)
for episode in range(num_episodes):
states, actions, rewards = collect_trajectory(env, ddp_model)
loss = compute_loss(states, actions, rewards)
optimizer.zero_grad()
loss.backward() # DDP自动在此处同步梯度
optimizer.step() # 各进程独立更新参数(因梯度已同步,参数保持一致)
处理分布式数据采样(如经验回放)
若使用经验回放缓冲区,需确保各进程独立填充和采样数据。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
class DistributedReplayBuffer:
def __init__(self):
self.buffer = []
def add(self, experience):
# 各进程独立填充自己的缓冲区
self.buffer.append(experience)
def sample(self):
# 各进程独立采样,无需同步
return random.sample(self.buffer, batch_size)
# 每个进程维护独立的buffer
buffer = DistributedReplayBuffer()
启动多进程训练脚本
使用torch.multiprocessing.spawn
启动多进程。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch.multiprocessing as mp
def main_worker(rank, world_size):
# 上述初始化、模型包装、训练代码
pass
if __name__ == "__main__":
world_size = 4 # 假设启动4个进程
mp.spawn( # spawn 方法是在启动多进程的时候启用全新的python解释器环境
main_worker,
args=(world_size,),
nprocs=world_size,
join=True
)
关键注意事项
参数一致性:DDP会自动在初始化时广播模型参数,确保所有进程初始一致。
梯度聚合:
loss.backward()
内部通过AllReduce
同步梯度,无需手动调用通信操作。环境独立性:各进程的环境需有不同的随机种子,避免生成重复数据。
优化器状态:优化器(如Adam)的动量参数会因同步梯度而自然保持一致。
This post is licensed under CC BY 4.0 by the author.