供应链场景下百万级 SKU 异构分布式强化学习系统:RL-Infra 工程实践全解析
前言:为什么供应链需要强化学习,以及为什么它如此之难
供应链的补货决策表面上是一个预测问题——预测未来需求,然后计算安全库存和补货量。但现实远比这复杂。每一个SKU的库存水平、在途量、供应商交期波动、促销计划、季节因素之间存在高度耦合。一个SKU的缺货可能导致替代品的需求暴增,连锁反应横跨整个品类。传统的运筹优化方法在面对数百万SKU、数十个仓、数百个供应商的组合爆炸时,要么求解时间不可接受,要么不得不做大量简化假设而失去精度。
强化学习(Reinforcement Learning,RL)的引入,是为了让每个SKU拥有自己的决策智能体(agent),通过与仿真环境的大量交互学习最优补货策略。每个agent观察自己的实时特征(库存水位、预测需求、在途量、历史销量等),输出补货量决策。训练目标变为了最大化长期累积收益——在满足服务水平的同时最小化库存持有成本。
这个设想听起来优雅,但工程落地是另一回事。当SKU数量达到100万,每个agent需要与仿真环境交互约百步来完成一个episode(每步代表一天,一个episode覆盖数月的仿真周期),仿真本身是纯CPU密集计算(涉及库存逻辑、约束校验、多级联动),而模型训练又需要GPU的并行算力——一个典型的CPU-GPU异构计算场景就此诞生。这篇文章不会花太多篇幅讨论算法本身(PPO、GAE这些在教材里都能找到),而是聚焦于把这个系统跑起来、跑稳定、跑快、跑得可控的全部工程细节。
一、系统全局架构概览
整个系统在架构上采用了”Go做数据面,Python做计算面”的异构设计。这完全就是现实约束的产物。
仿真器是用Go实现的。供应链仿真需要处理大量的状态机逻辑——库存扣减、在途到货、过期淘汰、多仓调拨——这些逻辑天然适合用强类型、高并发、低GC停顿的语言来写。一个仿真实例在处理一个SKU的一个episode(通常100步左右,每步代表一天)时,涉及的逻辑分支非常密集,但几乎没有浮点矩阵运算。Go的goroutine在这种场景下的表现远优于Python的线程模型,而CGo的开销又使得把仿真逻辑嵌入Python进程不太现实。
训练器使用Python + PyTorch,这没什么好争论的。整个深度学习生态都在Python上,DDP、NCCL、CUDA kernel这些基础设施的可用性决定了learner端必须是Python。
两者之间通过gRPC通信。选择gRPC而非REST或者自定义TCP协议,核心原因有三:Protocol Buffers提供的强类型schema能避免大量运行时数据格式错误;HTTP/2的多路复用在高频小包通信场景下效率显著优于HTTP/1.1;streaming RPC为权重同步提供了天然的推送机制。
系统的角色划分如下:
Learner集群:4张GPU卡组成一个DDP组,运行PPO算法的前向推理和反向传播。Learner是整个系统的”大脑”,它接收collector汇总的trajectory数据,完成梯度计算和参数更新,然后把新权重推送出去。
Collector集群:8台collector机器,每台运行8个仿真worker进程,每个worker独占4个CPU核心(即每台机器需要32核以上),总共64个仿真worker、256核。Collector的职责是拿到最新的模型权重,用它来驱动仿真环境产生trajectory数据,然后把这些数据打包发回给learner。它同时承担了动作推理的职责——在collector端用CPU做模型推理(模型只有几百KB,ONNX更小,CPU推理完全可行),避免了每一步仿真都要和GPU端做一次网络往返。
权重同步服务:基于Redis的pub/sub机制,带版本号的权重分发系统。Learner训练完一轮就把新权重写入Redis,所有collector通过订阅channel得到通知,拉取最新权重。当然用Redis有个前提条件,你所同步的模型权重是一个小到带宽可以接受的。
监控体系:Prometheus采集各组件的metrics,Grafana做可视化。训练指标(loss、reward、entropy)通过TensorBoard或MLflow记录。
整个数据流是一个闭环:Collector采集 → Trajectory传输 → Learner训练 → 权重同步 → Collector更新权重 → 继续采集。在PPO的on-policy约束下,这个闭环的每一步都不能出错,否则数据的分布就会和当前策略不匹配,导致训练不收敛。
二、仿真层——系统瓶颈所在
为什么仿真是瓶颈
100万个SKU,这不是我们计算的上限,而是不同品类尽量合并后的结果。每个episode执行约100步仿真,每步仿真涉及:读取当前库存状态、应用补货决策、模拟当天销量扣减、处理在途到货、计算缺货和过期、更新状态、计算reward。这些操作全部是标量/小向量运算,无法利用GPU的SIMD并行性。我们的实测数据是:10万个SKU在32核上完成一步仿真需要约1.5秒,由此推算单核单SKU单步约480μs。100万SKU × 100步 × 480μs = 48,000秒(约13小时)的纯CPU计算量。我们的集群配置是8台collector机器、每台8个仿真worker、每个worker独占4核,总计256核。每个worker负责约1.5万个SKU,凭借4核并行,单步仿真约1.9秒,一个完整episode约3分钟。仿真虽然已经被大规模并行化,但仍然是整个训练循环中的主要瓶颈。
仿真服务的Go实现
仿真环境被封装为一个gRPC服务,对外暴露Reset、Step、GetInfo、Close四个RPC接口,严格遵循OpenAI Gym的交互范式。每个仿真实例内部维护一个EnvironmentManager,管理多个并行环境的生命周期。
1
2
3
4
5
6
type Environment interface {
Reset() (map[string][]float32, error)
Step(action map[string][]float32) (map[string][]float32, float32, bool, map[string]string, error)
GetInfo() map[string]string
Close() error
}
之所以observation和action使用map[string][]float32而非简单的[]float32,是因为我们计算的商品量太大,希望能一次性推送多个action,让仿真程序能并行计算更多的数据。另一个重要原因就是一些商品并不是独立的,他们会互相影响。
64个Worker的编排
8台collector机器,每台启动8个仿真进程。这些仿真进程不是直接由collector fork出来的子进程,而是独立的gRPC服务。Collector通过配置文件知道自己要连接哪些仿真服务的地址(例如同一台机器上的:50051到:50058)。这种进程隔离的设计有一个重要的好处:仿真进程崩溃不会导致collector崩溃。供应链仿真逻辑经常需要迭代修改(加一种新的约束、调整过期逻辑),修改后的仿真进程可以独立重启,不影响collector的运行。
每台collector机器的内存规划也是一个需要认真对待的问题。每个仿真worker是一个独立的gRPC进程,拥有4个CPU核心,负责约1.5万个SKU。每个SKU的输入数据(历史销量、预测、参数等)约占10KB,一个worker的基础数据量在150MB左右。但仿真过程中会产生大量中间状态(在途订单、库存快照等),实际运行时每个仿真进程的内存消耗在200400MB左右。一台collector机器8个worker进程,峰值内存在23GB左右,CPU需求为32核以上(8 worker × 4核),这在现代服务器上不是问题。但需要注意的是,Go的GC在内存使用量大时可能导致较长的STW停顿。实际部署中我们将GOGC设置为200(默认100),以减少GC频率,用内存换延迟。
仿真数据的分配策略
100万SKU如何分配到64个worker上?最朴素的做法是均分,但这忽略了SKU之间计算复杂度的差异。一个高周转SKU(日销数千件)的仿真步计算量远大于一个长尾SKU(月销几件),因为前者涉及更频繁的库存变动和更复杂的补货触发逻辑。如果均匀分配,某些worker会显著慢于其他worker,整个episode的完成时间由最慢的worker决定(木桶效应)。
解决方案是在训练开始前做一次profiling pass:对所有SKU的仿真做一次试运行,记录每个SKU的平均单步耗时,然后用贪心算法(类似多机调度问题的近似解法)将SKU分配到各个worker,使得各worker的总计算时间尽量均衡。这个profiling只需在SKU集合变化或仿真逻辑修改后重做一次,开销可以接受。
实际操作中还有一种更灵活的方式:将SKU按计算复杂度排序后,用round-robin方式交替分配,这样即使个体差异较大,每个worker分到的”重”和”轻”SKU也大致均衡。我们在实验中发现,round-robin方式在不需要额外profiling的情况下,能将最慢worker与最快worker的时间差控制在15%以内,对于大多数场景已经足够。
三、Collector层——CPU端的推理与数据搬运
Collector的双重角色
在异构架构中,collector承担了两个关键职责:
第一,驱动仿真
Collector从仿真服务获取当前observation,将observation送入本地持有的策略模型做推理,得到action,再将action送回仿真服务执行一步。这个循环在每个episode中重复约100次。
第二,数据整理
一个完整的trajectory包含每一步的(state, action, reward, next_state, done, log_prob, value),collector需要把这些数据按正确的格式组装成RolloutBuffer,用于PPO的训练。
1
2
3
4
5
6
7
8
type Trajectory struct {
Observations []map[string][]float32
Actions []map[string][]float32
Rewards []float32
Dones []bool
Values []float32
LogProbs []float32
}
在Collector端做推理——一个关键的架构决策
一个常见的替代方案是让collector只做数据收集,把observation发给learner端做推理,拿回action再送给仿真。这种”远程推理”模式在很多RL框架中被采用(比如早期的RLlib),但对于我们的场景来说是不可接受的。
原因在于延迟
一个episode有100步,每一步都需要一次推理。如果推理在远端GPU上做,那每步需要一次网络往返,即使在同一个数据中心内(RTT约0.5ms),100步就是50ms。但真正的问题不在单次延迟,而在排队效应。64个worker同时发送推理请求,GPU端需要做batched inference来提高吞吐,这意味着请求需要等待凑批。凑批的窗口时间(通常几十毫秒)乘以100步,延迟会膨胀到不可接受的程度。
更根本的矛盾是:仿真是严格串行的(第N步的action依赖第N-1步的state),无法做流水线化。每一步都必须等推理结果返回后才能执行下一步仿真,这使得网络延迟成为不可掩盖的串行瓶颈。
所以我们选择在collector端做CPU推理。我们的模型是独立的Policy-Value双网络结构(两层隐藏层、每层256个神经元),总参数量约14万,模型文件约600KB。在CPU上对一个batch的SKU做一次前向推理不到1ms,远快于网络往返。64个worker各自独立做推理,天然并行,无需凑批。代价是每台collector机器需要加载一份模型权重到内存,但600KB乘以8个进程也不过5MB,完全可以承受。当机器数量增多或模型增大后,Redis就需要替换为其他存储介质来承接传输(比如S3或通过gRPC直接推送),Redis退化为只广播版本号信号,降低通信协议复杂度。
但这引出了另一个问题:模型权重的同步。
权重同步机制
在PPO这样的on-policy算法中,collector用来采集数据的策略必须和learner正在训练的策略一致。如果collector用旧版本的权重采集了数据,这些数据对当前策略来说就是off-policy的,直接用来更新会引入偏差。
权重同步的设计基于Redis的pub/sub机制。整个流程如下:
Learner完成一轮参数更新后,将新权重序列化为字节流,连同版本号一起写入Redis。权重数据使用
rl:weights:{model_key}作为key,版本号使用rl:version:{model_key}。Learner同时向
rl:updates:{model_key}channel发布一条消息,payload就是新的版本号。每个collector都有一个后台goroutine在订阅这个channel。收到消息后,它比较收到的版本号和自己当前持有的版本号,如果更新,就从Redis拉取新权重并加载。
所有操作通过Redis Pipeline批量执行,减少网络往返。
1
2
3
4
5
6
7
8
9
func (r *RedisSyncer) SetWeights(ctx context.Context, key string, data *WeightData) error {
jsonData, _ := json.Marshal(data)
pipe := r.client.Pipeline()
pipe.Set(ctx, r.weightKey(key), jsonData, 0)
pipe.Set(ctx, r.versionKey(key), data.Version, 0)
pipe.Publish(ctx, r.channelKey(key), data.Version)
_, err = pipe.Exec(ctx)
return err
}
一个值得讨论的工程细节是权重的序列化格式。最初的实现直接使用JSON序列化PyTorch的state_dict(通过torch.save保存到BytesIO再base64编码),600KB的模型权重在JSON化之后膨胀到约900KB(base64编码本身有33%的膨胀,加上JSON的结构开销)。虽然对于当前这个量级的模型来说差异不大,但出于规范性我们改用Protocol Buffers的bytes字段直接传输torch.save的二进制输出,体积保持在650KB左右。当前模型的权重同步通过Redis完成,延迟在毫秒级,不是瓶颈。但如果未来模型规模增长(比如引入更复杂的特征提取器或注意力机制),权重可能增大到数十MB,届时需要将权重存储迁移到S3或通过gRPC streaming直接传输,Redis只作为版本号广播的信号通道。
权重同步的时序问题
一个微妙的问题是:在PPO的强 on-policy要求下,collector采集数据时使用的权重版本必须和learner当前训练使用的权重版本完全一致。但由于网络延迟和处理时间的存在,权重同步不可能是瞬时的。
我们的解决方案是引入”同步屏障”(sync barrier)。训练流程被组织为严格的同步轮次:
Learner通知所有collector开始采集(附带当前权重版本号V)
每个collector检查自己持有的权重版本是否为V,如果不是,等待直到获取到版本V的权重
所有collector使用版本V的权重完成采集
所有collector将trajectory数据上传到learner
Learner汇总数据,执行K个epoch的PPO更新,产生版本V+1的权重
Learner将版本V+1的权重推送到Redis
回到步骤1
这种严格同步的方式虽然牺牲了一些并行度(collector在等待权重更新时是空闲的),但保证了数据的on-policy特性。在实际测量中,权重同步的等待时间约占总训练时间的3~5%,是可以接受的。
这里也调研过能否让on-policy没有那么“强”。一种优化手段是”双缓冲”:collector在使用当前权重V采集数据的同时,后台goroutine预先下载权重V+1(如果learner已经训练完并推送了的话)。这样当采集完成后,新权重可能已经准备好了,等待时间趋近于零。但在严格on-policy场景下,这个优化只在learner训练速度快于collector采集速度时才有效。对于我们的场景(仿真是瓶颈,learner通常先完成),这个优化效果有限。当然如果PPO能切换到近on-policy的模型,就能进一步提高GPU和 CPU利用率。
四、Learner层——GPU训练的核心
DDP + NCCL:4卡并行训练
Learner采用PyTorch的DistributedDataParallel(DDP),后端通信使用NCCL。DDP的工作原理在此不再赘述,但有几个在我们的场景下需要特别注意的地方。
首先是初始化。DDP要求所有参与训练的进程在启动时执行dist.init_process_group。我们使用torchrun(PyTorch的弹性分布式启动工具)来管理进程组的创建:
1
2
3
4
5
6
7
8
class DistributedTrainingContext(TrainingContext):
def _setup(self):
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend)
self.rank = dist.get_rank()
if torch.cuda.is_available():
self._device = torch.device(f"cuda:{self.rank}")
torch.cuda.set_device(self._device)
NCCL之所以是必选项而非GLOO,是因为NCCL专门为NVIDIA GPU间的集合通信优化(AllReduce、Broadcast等),在多GPU场景下的带宽利用率远高于GLOO。对于4张V100 32GB GPU,NCCL能充分利用NVLink或PCIe的带宽做梯度同步,AllReduce的开销几乎可以被计算所掩盖。
PPO训练的具体实现
PPO的核心是裁剪的策略梯度目标函数,加上价值函数损失和熵正则化:
```plain text L = L_clip + c1 * L_value - c2 * H[π]
1
2
3
4
5
6
7
8
9
10
11
12
13
其中`L_clip`使用importance sampling ratio的裁剪版本来避免策略更新过大,`L_value`使用均方误差拟合状态价值函数,`H[π]`是策略的熵用于鼓励探索。
在我们的实现中,一次PPO更新包含K=10个epoch的小批量梯度下降(mini-batch大小为1024,学习率3e-4,clip\_ratio=0.2)。这里有一个容易被忽视的性能细节:在每个epoch开始时,不应重新从CPU内存加载数据到GPU,而应该在第一个epoch时就把所有数据搬到GPU显存并保持在那里。100万SKU × 100步 × 每步约15个float32(状态10维 + action、reward、done、log\_prob、value各1维)= 约6GB的数据。4张GPU意味着每张GPU分到1.5GB数据,对于32GB显存的V100来说游刃有余。
```python
def update_policy(self, policy, optimizer, buffer, context):
states, actions, returns, advantages, old_log_probs = self.prepare_data(buffer, context, policy)
for _ in range(self.epochs):
optimizer.zero_grad(set_to_none=True)
loss, metrics = self.compute_loss(states, actions, returns, advantages, old_log_probs, policy)
loss.backward()
optimizer.step()
zero_grad(set_to_none=True)是一个值得注意的优化:将梯度置为None而非零,可以避免一次不必要的memset操作,在参数量较大时能节省几个百分点的训练时间。
Advantage的归一化问题
在PPO中,advantage的归一化(减均值除标准差)对训练稳定性至关重要。但在DDP场景下,每张GPU只持有全局数据的1/4,局部计算的均值和标准差可能与全局值存在偏差。正确的做法是先通过dist.all_reduce汇总所有GPU上的advantage的总和与平方总和,计算全局均值和标准差,然后在各GPU上分别做归一化。如果跳过这一步,直接在各GPU上做局部归一化,训练仍然能跑,但收敛速度会变慢,最终性能也会打折扣。
这是一个在小规模实验中很难发现的问题——当只有一张GPU或数据量较小时,局部归一化和全局归一化的差异微乎其微,但当数据量达到百万级且分布在多张GPU上时,差异就会显现。我们在最初的上线过程中观察到reward曲线在某个值附近震荡不收敛,经过排查才发现是advantage归一化不一致导致的。
梯度裁剪与数值稳定性
在大batch训练(100万个transition)中,梯度的方差非常低(大数定律),这是好事,但也意味着如果某个SKU的仿真数据出现异常(比如reward突然变得极大或极小),对应的梯度可能在batch中被稀释掉而不会被发现。更危险的情况是当异常数据的比例达到一定程度时,梯度突然变大,一步更新就把模型带飞。
为此我们做了两层防护:
数据层面:在trajectory进入buffer之前做范围检查,reward超过合理范围的transition被标记并记录到异常日志。这种数据清洗必须在collector端完成,因为到了learner端数据已经被混在一起,很难追溯来源。
梯度层面:使用
torch.nn.utils.clip_grad_norm_对全局梯度范数做裁剪,阈值设为0.5。在DDP模式下,梯度裁剪必须在AllReduce之后、optimizer.step之前执行,否则每张GPU裁剪自己的局部梯度会导致全局梯度方向不一致。
五、数据流与通信——gRPC、Protocol Buffers与序列化性能
为什么选择gRPC
系统中的跨进程通信主要发生在三个地方:
Collector ↔ 仿真服务:高频、低延迟、小payload(每步一个observation + action)
Collector ↔ Learner:低频、高吞吐、大payload(一个episode的所有trajectory数据)
Learner ↔ Redis ↔ Collector:低频、中等payload(模型权重)
gRPC的Protocol Buffers序列化在这三种场景下都表现良好。对于第一种场景,protobuf的编码/解码速度远快于JSON;对于第二种场景,protobuf的紧凑二进制格式减少了网络传输量;对于第三种场景,gRPC的streaming RPC使得权重更新可以主动推送而非轮询。
proto定义如下(节选):
1
2
3
4
5
6
7
8
9
10
11
12
13
service CollectorService {
rpc Collect(CollectRequest) returns (CollectResponse);
rpc GetBuffer(GetBufferRequest) returns (GetBufferResponse);
rpc UpdateWeights(UpdateWeightsRequest) returns (UpdateWeightsResponse);
rpc GetStats(GetStatsRequest) returns (GetStatsResponse);
}
service WeightSyncService {
rpc SetWeights(SetWeightsRequest) returns (SetWeightsResponse);
rpc GetWeights(GetWeightsRequest) returns (GetWeightsResponse);
rpc Subscribe(SubscribeRequest) returns (stream WeightUpdate);
rpc GetVersion(GetVersionRequest) returns (GetVersionResponse);
}
Subscribe接口使用server-side streaming,collector打开一个长连接,learner每次更新权重后通过stream推送通知。这比轮询机制(collector定时查询最新版本号)节省了大量无效请求,也降低了Redis的读压力。
Trajectory数据的传输优化
一个episode产生的trajectory数据量的估算:100万SKU × 100步 × 每步15个float32(状态10维 + action、reward、done、log_prob、value各1维)× 4字节 ≈ 6GB。这个数据量不可能一次性通过gRPC传输——gRPC默认的消息大小限制是4MB,即使调大限制,单个消息几GB也会导致内存压力和超时风险。
我们的方案是通过S3作为中间存储来完成RolloutBuffer的持久化与同步。每个collector将自己负责的SKU子集(约15,625个SKU)的trajectory数据序列化后上传到S3,每个collector的数据量约94MB。Learner端从S3拉取所有collector的数据,汇聚后开始训练。S3天然支持大文件的可靠传输、断点续传和并发读写,避免了gRPC分片传输的复杂性。同时S3上的trajectory数据自然形成了训练数据的归档,便于事后审计和问题排查。
这种S3中转的方式引入了一个权衡:延迟 vs 可靠性。S3的写入延迟(通常几秒到十几秒)比直接gRPC传输要高,但对于我们约3分钟一个episode的训练节奏来说,这个开销占比可控。更重要的是,S3的持久化特性解决了一个痛点:如果learner在训练过程中崩溃,不需要所有collector重新采集数据,直接从S3重新拉取即可恢复。
gRPC的连接管理
在长期运行的训练任务中(可能持续数天),gRPC连接的稳定性是一个不容忽视的问题。我们遇到过几类故障:
连接静默断开:TCP连接在中间设备(防火墙、负载均衡器)上被超时关闭,但客户端和服务端都不知道。下次RPC调用时才发现连接已死,导致超时错误。解决方案是启用gRPC的keepalive机制:客户端每30秒发送一次keepalive ping,服务端配置允许接收这些ping。同时设置keepalive_timeout,如果ping在5秒内没有收到pong,就认为连接已死并重建。
负载不均:当使用容器编排工具(如K8s)管理collector时,Service的负载均衡对gRPC的长连接不友好。gRPC默认维持长连接,新创建的连接会被路由到新的Pod,但旧连接仍然指向旧Pod,导致负载不均。我们的解决方案是在客户端侧实现连接池轮换:每个collector客户端维护一个连接池,定期(每5分钟)关闭最旧的连接并创建新连接,让K8s Service有机会重新做负载均衡。
消息体过大:gRPC默认的最大消息大小是4MB。在传输trajectory数据或模型权重时,需要在客户端和服务端都配置max_send_message_length和max_receive_message_length。这个配置必须在两端同时设置,否则会出现单端能发但对端拒收的情况,错误信息还不太明显(通常是RESOURCE_EXHAUSTED状态码),排查起来比较费时间。
六、RolloutBuffer——PPO数据管道的核心组件
RolloutBuffer是连接collector和learner的数据结构,它的设计直接影响训练效率。
在PPO中,RolloutBuffer需要存储每一步的transition,包括(state, action, reward, done, log_prob, value)。在所有步骤完成后,需要计算每一步的回报(return)和优势(advantage)。回报的计算涉及折扣累积奖励(discounted cumulative reward),通常使用GAE(Generalized Advantage Estimation)来平衡偏差和方差。
一个容易犯的错误是在多个episode的trajectory数据中错误地跨episode做了折扣计算。当一个trajectory中包含多个episode(中间有done=True的转折点)时,必须在done处截断折扣累积——因为done意味着环境重置,后续的reward不应影响之前的advantage估计。
1
2
3
4
5
6
7
8
9
def compute_returns(self, gamma: float) -> list[float]:
returns = []
discounted_sum = 0
for t in reversed(self.storage):
if t.done:
discounted_sum = 0
discounted_sum = t.reward + gamma * discounted_sum
returns.insert(0, discounted_sum)
return returns
内存管理的考量
100万SKU × 100步 × 每步15个float32值(state 10维 + action、reward、done、log_prob、value各1维)× 4字节 ≈ 6GB的数据。数据不会在任何单点完整汇聚,而是以分布式的方式存在:每个collector将自己负责的SKU子集的buffer上传到S3,learner从S3拉取后按GPU数量分片加载。
在DDP训练中,数据的分配策略是:将所有collector汇总的数据按SKU均匀分配到4张GPU上。每张GPU在一个PPO epoch中需要处理25万SKU × 100步的数据,约1.5GB显存占用,对于32GB的V100来说非常轻松。单个epoch内部还会进一步分成若干mini-batch。mini-batch的大小是一个需要调优的超参数:太小会导致梯度估计的方差大、GPU利用率低;太大会导致梯度过于平滑而不利于探索。我们的经验值是每个mini-batch包含1024个transition。
Buffer的清理时机
PPO是on-policy算法,每一轮训练完成后,buffer中的数据就不再有用(因为策略已经更新了,旧数据不再反映当前策略)。所以buffer必须在每轮训练后清空。这个”清空”操作看似简单,但如果实现不当,可能导致内存泄漏。Python的引用计数GC在tensor对象形成循环引用时不能及时回收,而PyTorch的CUDA tensor还涉及GPU显存的释放。正确的做法是显式地del掉所有tensor引用,然后调用torch.cuda.empty_cache()释放GPU端的缓存(注意:empty_cache不会释放正在使用的显存,只是释放PyTorch内存分配器缓存的空闲块)。
1
2
3
4
5
def train(self):
for ep in range(1, self.cfg.training.total_episodes + 1):
ep_reward = self._collect()
metrics = self.learner.learn(self.buffer)
self.buffer.clear()
七、模型架构与推理效率
Actor-Critic网络设计
供应链补货决策的模型使用Actor-Critic架构。我们采用独立的Policy网络和Value网络,不共享参数:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class PolicyNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super().__init__()
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.action_head = nn.Linear(hidden_dim, action_dim)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return torch.softmax(self.action_head(x), dim=-1)
class ValueNetwork(nn.Module):
def __init__(self, state_dim, hidden_dim=256):
super().__init__()
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.value_head = nn.Linear(hidden_dim, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return self.value_head(x)
之所以使用独立网络而非共享backbone,是因为在我们的实验中观察到:共享backbone在训练后期会出现actor和critic的学习目标冲突——actor需要特征对动作区分度高,critic需要特征对状态价值预测准确——两者互相拉扯导致收敛变慢。独立网络虽然参数量翻倍,但总量仍然很小,不构成负担。
以实际配置为例(state_dim=10,action_dim=20,hidden_dim=256),两个网络的参数量分别为约7.4万和6.9万,合计约14万参数,模型文件约600KB。状态空间维度约10维(包括当前库存、在途订单、预测销量、leadtime等),动作空间是离散的20个补货倍率档位。模型不需要太大的容量——用更大的模型在我们的实验中并没有显著提升,反而增加了推理延迟和通信开销。
DDP包装的注意事项
当模型被DistributedDataParallel包装后,原始模型被放在model.module属性下。在保存checkpoint或提取推理用的模型时,必须正确地解包装:
1
2
3
4
class ModelFactory:
@staticmethod
def get_model_for_saving(model: nn.Module, is_ddp: bool) -> nn.Module:
return model.module if is_ddp else model
如果直接保存DDP wrapper,checkpoint中会包含module.前缀的参数名(如module.shared.0.weight),在不使用DDP的推理环境中加载时会因为key不匹配而报错。这是一个DDP新手经常踩的坑,但对于有经验的工程师来说更大的风险在于:在训练过程中不小心绕过DDP wrapper直接调用了model.module的forward方法(比如在evaluation时),导致梯度同步失效而自己不知道,表现为训练缓慢但不会报错。
八、监控体系——你无法优化你看不见的东西
多层次监控架构
一个百万SKU规模的RL训练系统,如果没有完善的监控,就像盲人开飞机。我们建立了三个层次的监控:
系统层监控(Prometheus + Grafana):CPU使用率、内存使用率、GPU利用率、GPU显存使用率、网络IO、磁盘IO。这些指标通过node_exporter(系统)和nvidia-smi(GPU)采集到Prometheus,在Grafana上做时序展示。关键告警规则包括:GPU利用率持续低于50%(可能说明数据加载是瓶颈)、某台collector的CPU使用率异常高(可能SKU分配不均)、Redis内存使用率超过80%(可能权重数据没被及时清理)。
应用层监控(自定义Prometheus metrics):每个组件暴露自己的业务指标。
Collector端的关键指标:
collector_episodes_total:已完成的episode总数collector_steps_per_second:每秒完成的仿真步数collector_episode_duration_seconds:单个episode的耗时(histogram,用于发现长尾延迟)collector_weight_sync_lag_seconds:权重同步的延迟(从learner更新到collector收到的时间差)collector_buffer_size:当前buffer中的trajectory数量
Learner端的关键指标:
learner_training_step_duration_seconds:单次训练步骤的耗时learner_gpu_memory_allocated_bytes:GPU显存使用量learner_gradient_norm:梯度范数(用于检测梯度爆炸/消失)learner_weight_version:当前权重版本号
训练质量监控(TensorBoard / MLflow / Wandb):
1
2
3
4
5
6
7
class TensorboardLogger:
def __init__(self, log_dir: str):
self.writer = SummaryWriter(log_dir)
def log(self, metrics: dict, step: int = None):
for k, v in metrics.items():
self.writer.add_scalar(k, v, step)
训练质量指标包括:episode_reward的均值和分布(不仅看均值,还要看分位数——如果中位数在涨但P95在跌,说明策略在大多数SKU上改善了但在某些SKU上恶化了)、policy_loss、value_loss、entropy(entropy持续下降到零说明策略坍缩到确定性策略,可能过早收敛到局部最优)、KL散度(当前策略与旧策略之间的KL divergence,PPO通过裁剪来隐式控制,但监控它有助于判断学习率是否合适)。
训练曲线的解读——一些非显而易见的经验
Reward震荡不收敛:最常见的原因不是超参数不对,而是数据管道有bug——比如observation的某些字段没有正确归一化、advantage计算跨了episode边界、或者collector使用的权重版本和learner不一致。我们的调试经验是:先在小规模(1000个SKU,单GPU,单collector)上验证训练能收敛,确认算法正确后再逐步放大规模。
Value loss先降后升:这通常说明bootstrap value(用于GAE计算的V(s’))的估计偏差在累积。可能的原因是环境的reward scale发生了变化(比如某批新上线的SKU的reward量级和已有SKU差异很大),导致value network需要重新学习。解决方案是对reward做running normalization。
Entropy坍缩:策略的entropy快速降为零,意味着agent对所有状态都给出了确定性的动作。在供应链场景下这可能不是坏事(如果策略确实找到了最优补货量),但更常见的情况是策略过早收敛到了”永远不补货”或”永远补满”这样的退化策略。增大entropy系数或者使用entropy调度(training前半段entropy系数高,后半段逐渐降低)可以缓解。
九、工程化补充:版本管理、推理上线与部署运维
前面的章节聚焦于RL训练循环中核心的几个组件。这里简要补充系统外围的几个工程化话题——这些不是RL特有的问题,但在生产化过程中不可回避。
版本管理与实验追踪
模型版本分三层管理:训练内部用单调递增整数做权重版本号(用于权重同步一致性检查);每隔固定episode保存checkpoint(包含模型权重、优化器状态、归一化参数、episode编号);通过离线仿真回测后的模型进入MLflow Model Registry作为发布版本。
RL的实验追踪比传统ML更复杂,因为同样的配置跑两次可能差异显著(初始权重随机性、仿真中需求的随机波动)。我们在MLflow中使用parent-child run机制:一组配置作为parent,同配置的多次重复作为child run,每组至少3次取平均。
线上推理
线上推理采用batch inference模式,每天定时对所有SKU做一次推理生成次日补货建议。模型导出为ONNX格式后在CPU上推理,约600KB的模型对100万SKU的batch推理在分钟级完成。新模型通过A/B测试(分层抽样对照组,运行2~4周比较业务指标)和逐步放量(10% → 30% → 100%)上线,关键指标异常时配置驱动的快速回滚可在5分钟内完成。
部署与容错
系统使用Docker容器化部署:Go仿真服务通过多阶段构建生成<30MB的轻量镜像,Python trainer则包含PyTorch和CUDA runtime(~2GB)。生产环境通过Kubernetes编排,仿真服务用Deployment+HPA自动扩缩,collector用StatefulSet固定副本数,trainer用Job管理生命周期。
容错方面,核心策略是checkpoint + 优雅降级:训练中途崩溃可从最近checkpoint恢复(保存优化器状态、随机种子等以确保可复现);单个collector宕机如果丢失数据<5%则继续训练,>5%则暂停等待恢复;单个仿真进程崩溃不影响collector,其负责的SKU被临时分配到同机其他进程。Go仿真内部每个SKU的step被recover()包裹,单SKU的panic不会拖垮整个进程。
十、那些踩过的坑——从调试地狱中幸存
NCCL通信超时
问题表现:训练在某个episode莫名卡住,日志停止输出,CPU和GPU都有一定的负载但没有进展。几分钟后NCCL抛出超时错误。
根因:DDP在反向传播时使用AllReduce来同步梯度,这要求所有rank在同一时刻执行同一个AllReduce操作。如果某个rank在forward或backward中走了不同的代码路径(比如某个rank的数据触发了一个条件分支,导致计算图不同),另一些rank的AllReduce会等待这个rank,最终超时。
在我们的场景下,具体的触发条件是:PPO代码中有一个判断if buffer.size() > 0的分支。由于数据分配的不均匀,偶尔某个rank分到的数据恰好为空,跳过了训练步骤。其他rank在等待这个rank的AllReduce,造成死锁。
解决方案:在数据分配阶段确保每个rank至少分到一些数据,即使需要重复分配一些样本来填充。同时设置NCCL_ASYNC_ERROR_HANDLING=1环境变量,让NCCL在超时后能干净地退出而不是挂死整个进程。
Go仿真进程的内存泄漏
问题表现:仿真进程在运行数小时后,内存使用量缓慢但持续增长,最终触发OOM killer。
根因:Trajectory结构体中的Observations字段是[]map[string][]float32类型。每一步仿真都会创建一个新的map和若干新的[]float32切片。在Go中,即使Trajectory被回收,如果其中的切片被其他地方引用(比如通过gRPC传输时被protobuf序列化过程引用),这些切片就不会被GC回收。
调试过程中使用pprof的heap profile发现了大量的小内存分配来自map[string][]float32的创建。根本原因是gRPC的响应消息在发送后没有被及时释放——Go的gRPC库在内部会持有消息的引用直到底层HTTP/2帧完全发送完毕,而在高负载下这个过程可能延迟数秒。
解决方案:引入对象池(sync.Pool)来复用Trajectory结构体和其中的切片。每一步仿真不再创建新的map和切片,而是从池中取出已分配的对象、覆写数据、使用、然后放回池中。这彻底消除了高频小对象分配导致的GC压力和内存泄漏。
Redis权重同步的局限性
当前模型权重只有约600KB,通过Redis直接同步在延迟上不是问题(毫秒级)。但Redis作为权重同步的唯一通道存在几个隐患:
单点依赖:Redis宕机会阻塞所有collector的权重更新。虽然可以用Redis Sentinel或Cluster做高可用,但增加了运维复杂度。
扩展性天花板:如果未来模型复杂度增长(比如引入注意力机制或更大的特征提取器),权重可能增长到数十MB。Redis是单线程模型,大Key的读写会阻塞其他所有操作,届时所有collector的并发读取会成为瓶颈。
带宽争用:Redis同时还承担pub/sub通知职责,权重数据传输和消息分发共享同一连接,在高负载下可能互相干扰。
我们的演进方向是:将权重存储迁移到S3或通过gRPC streaming直接从learner推送到collector,Redis仅保留版本号广播的信号通道角色。S3方案的好处是天然支持大文件、高并发读取、且有版本管理能力;gRPC方案则延迟更低,适合对同步时效要求高的场景。
Collector和Learner的速度不匹配
问题表现:GPU利用率经常在0%和100%之间交替,形成明显的”锯齿”模式——GPU空闲等待数据 → 数据到达后满载训练 → 训练完成再次空闲等待。
这是CPU-GPU异构计算中最经典的瓶颈。仿真(CPU)和训练(GPU)的计算时间比值决定了系统的利用效率。如果仿真时间远大于训练时间(在我们的场景下确实如此),GPU大部分时间都在空闲。
在我们的场景下,这个不匹配仍然很明显:一个episode的仿真耗时约3分钟(64 worker × 4核 = 256核并行),而GPU训练只需要约10秒(14万参数的小模型,1M transitions / 1024 batch × 10 epochs),加上S3数据传输约10~20秒,GPU利用率约5%。增加collector的数量能提高仿真吞吐,但每台机器需要32核以上来支撑8个4核worker,资源成本线性增长。
但更根本的解决方案是在算法层面打破on-policy的严格约束。”近似on-policy”允许collector使用最近N个版本内的权重采集数据(而非严格使用最新版本),然后在PPO更新时用importance sampling ratio来修正off-policy偏差。这在理论上是PPO本身就支持的(PPO的clipping就是为了处理一定程度的off-policy数据),但在实践中N不能太大(通常不超过2~3),否则训练不稳定。这样可以让collector在learner训练的同时继续用旧权重采集下一轮数据,形成流水线,显著提高整体吞吐。
仿真环境的确定性问题
问题表现:同一组SKU、同样的初始条件、同样的策略,两次仿真得到不同的trajectory。
这在调试时是灾难性的——你无法确定某个问题是代码bug还是随机波动。供应链仿真中的随机因素包括:需求的随机波动(通常用泊松分布或正态分布采样)、供应商交期的随机延迟、随机的促销触发等。
解决方案是将所有随机源都通过可控的随机种子管理:
1
2
random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
但仅仅设置Python和PyTorch的种子是不够的。Go仿真进程有自己的随机数生成器,需要通过配置传入种子。多个goroutine并发使用math/rand时还需要注意:Go 1.20之后全局math/rand默认使用加锁的源,但如果手动创建了rand.New(rand.NewSource(seed)),这个source不是并发安全的。正确的做法是给每个goroutine创建独立的rand.Rand实例,种子为全局种子+goroutine编号。
此外,CUDA的计算也存在非确定性——某些CUDA kernel的原子操作不保证顺序,导致浮点累加的结果可能因为加法顺序不同而产生微小的差异。对于调试目的,可以设置torch.backends.cudnn.deterministic = True和torch.use_deterministic_algorithms(True),但这会降低训练性能(约10~20%),所以通常只在调试时启用。
训练中途SKU集合变化
问题表现:在训练进行到一半时,业务方通知有一批新SKU上线,需要纳入训练;同时有一批旧SKU下线,需要移除。
这在供应链场景下是常态——SKU的上下架是持续进行的。但对RL训练来说这是一个棘手的问题:
新SKU没有历史trajectory数据,但训练需要它们的数据。
旧SKU的数据已经在buffer中,如果不清理会浪费计算资源。
SKU的总数变了,数据分配方案需要重新调整。
如果模型的输入维度与SKU无关(每个SKU独立推理),模型本身不需要改。但如果存在SKU间的交互特征(如品类维度的共享embedding),模型架构可能需要调整。
我们的解决方案是”定期重平衡”:每隔一定周期(如每天),暂停训练,更新SKU列表,重新计算数据分配方案,然后从当前checkpoint恢复训练。新增SKU的初始策略使用同品类已训练SKU的平均策略作为warm-start,而不是从随机策略开始。
十一、性能调优的方法论
在百万SKU的规模下做性能调优,不能靠猜,必须靠数据。我们的性能调优遵循以下流程:
Profiling先行:使用Go的pprof分析仿真进程的CPU和内存热点;使用PyTorch Profiler分析GPU训练的算子耗时和显存使用;使用Prometheus分析端到端的各阶段耗时。
找到瓶颈:在典型的训练循环中,各阶段的耗时分布大约是:仿真采集
80%、数据传输(S3读写)10%、GPU训练5%、权重同步<1%、其他5%。仿真仍然是最主要的瓶颈——一个episode的仿真约3分钟,而GPU训练只需约10秒。针对瓶颈优化:
仿真层面:优化SKU仿真步骤的计算逻辑,减少不必要的map操作,使用数组代替map来存储高频访问的状态字段。
数据传输层面:使用protobuf的binary wire format代替JSON,压缩大型trajectory数据(使用snappy压缩,压缩比约2:1,解压速度极快)。
GPU训练层面:调优mini-batch大小以充分利用GPU的SM并行度,使用
torch.compile(PyTorch 2.0+)来编译模型获得约10~15%的加速。权重同步层面:当前模型小(600KB),不是瓶颈,但为未来模型增长预留了S3/gRPC的迁移路径。
验证优化效果:每次优化后重新做profiling,确认瓶颈确实被缓解了(而不是仅仅移到了别处),并检查优化是否引入了新的问题。
一个有意思的优化案例
在profiling中发现仿真进程的CPU时间有约30%花在了map[string][]float32的key lookup上——每一步仿真都需要从observation map中按key取值,Go的map lookup虽然是O(1)但常数系数不小(涉及hash计算和bucket遍历)。对于一个每秒执行数百万次的操作,这个常数系数就变得重要了。
优化方案:将observation的字段从map改为固定布局的struct。每个字段对应struct的一个字段,编译器可以直接通过偏移量访问,完全消除了hash计算的开销。这个改动涉及仿真代码和gRPC接口的重构(protobuf消息从map<string, Tensor>改为具名字段),工作量不小,但带来了约25%的仿真速度提升。这种优化在小规模时无关紧要,但在百万SKU的尺度下效果显著。
结语
回顾这个系统的建设过程,最深刻的体会是:RL-Infra的核心挑战不在于任何单个技术组件(gRPC、DDP、Redis这些都有成熟的文档和社区经验),而在于它们的集成——在一个CPU-GPU异构、多进程、多机器的分布式系统中,让数据以正确的顺序、正确的版本、正确的格式流动,并且在某一环节出错时能快速发现和恢复。
供应链场景给这个挑战增加了额外的维度:SKU的规模使得”小规模好用的方案大规模就不行了”成为常态(前面提到的map vs struct就是一个例子);业务逻辑的复杂性使得仿真器成为不可替代的性能瓶颈(不像Atari游戏那样有GPU加速的环境模拟器);决策的商业影响使得可靠性和可追溯性成为硬性要求。
这篇文章试图将我们在这个系统上的工程经验做一个完整的记录——不仅仅是”我们做了什么”,更重要的是”为什么这么做”以及”哪些地方走了弯路”。希望对正在或即将踏入RL-Infra这个领域的工程师们有所帮助。RL不只是算法工程师在Jupyter Notebook里调超参数的事,它的工程化落地是一个完整的系统工程问题,需要在分布式系统、高性能计算、DevOps、数据工程等多个方向上都有扎实的功底。而这些,正是RL-Infra工程师的价值所在。