HybridFlow 编程指南
最后更新: 2025年6月2日。
作者: Chi Zhang
verl 是论文 HybridFlow [1] 的开源实现。在本节中,我们将介绍 HybridFlow 的基本概念、动机以及如何使用 verl API 进行编程。
动机与设计
我们使用数据流来表示强化学习(RL)系统。[4]。
数据流
数据流是对计算的一种抽象。神经网络训练就是一个典型的数据流。它可以表示为计算图。
该图 [2] 表示了一个多项式函数后接 Sigmoid 函数的计算图。在神经网络计算的数据流中,每个节点代表一个算子,每条边代表前向/后向传播的方向。计算图决定了神经网络的架构。
RL 作为数据流问题
强化学习(RL)训练也可以表示为数据流。以下是代表 RLHF 中使用的 PPO 算法的数据流图 [3]:
然而,RL 的数据流与神经网络训练的数据流存在根本性差异,如下所示:
对于表格型强化学习,每个算子都是一个简单的标量数学运算(例如 bellman 更新)。而在深度强化学习(DRL)中,每个算子都是一个高层神经网络计算,例如模型推理/更新。这使得 RL 成为一个两级数据流问题:
控制流:定义了高层算子如何执行(例如,在 PPO 中,我们首先进行 rollout。然后,进行优势计算。最后,进行训练)。它表达了**RL 算法的核心逻辑**。
计算流:定义了**神经网络计算**的数据流(例如,模型前向/后向传播/优化器)。
设计选择
在 LLM 出现之前的 DRL 中,模型尺寸通常较小。因此,高层神经网络计算可以在单个进程中完成。这使得将计算流嵌入到控制流中作为一个单一进程成为可能。
然而,在 LLM 时代,计算流(例如,训练神经网络)变成了一个多进程程序。这自然导致了两种设计选择:
将控制流也转换为多进程程序。然后与计算流共置(统一的多控制器)。
优点:
在固定的计算流和控制流下,通过最小化训练和数据传输中的通信开销,可以实现**最佳性能**。
缺点:
从软件角度来看,计算流和/或控制流**难以重用**,因为计算代码与特定的控制器代码耦合。例如,PPO 的训练循环是通用的。假设我们已经实现了一个 PPO 训练流程,并使用了特定的计算流,如 FSDP。如果我们想将计算流从 FSDP 切换到 Megatron,由于控制流和计算流的耦合,控制流或计算流都无法重用。
由于程序的 असतात多进程特性,对于灵活和动态的控制流,用户需要付出更多努力。
分离流:一个进程用于控制流,多个进程用于计算流。
优点:
在解耦后,在别处定义的计算流可以**轻松重用**。
控制器运行在单个进程上。实现一个具有**不同控制流的新 RL 算法非常简单容易**。
缺点:
控制器进程和计算进程每次交互时都会产生额外的**数据通信开销**。数据必须来回发送。
在 verl 中,我们采用了后一种策略,分离控制流和计算流。verl 的设计旨在解耦 RL 算法的控制流和计算引擎的实现。
整体执行图
下图是一个简化的图,描述了强化学习作业的执行过程。图中,控制器运行在单个进程上,而生成器/Actor 工作节点、Critic 工作节点则运行在多个进程上,并放置在特定的资源组中。对于 rollou,控制器将数据传递给生成器以执行样本生成。当 rollout 完成后,数据将传回控制器以进行算法的下一步。其他工作节点也执行类似的流程。通过混合控制器设计,数据流和计算被解耦,从而在计算效率和定义算法训练循环的灵活性方面都提供了优势。
代码结构解析(PPO)
入口函数
代码: https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py
在此文件中,我们定义了一个名为 main_task 的远程函数,它作为控制器(驱动程序)进程,如上图所示。我们还定义了一个 RewardManager,用户可以基于数据集中的数据源定制奖励函数。请注意,RewardManager 应返回 RL 算法所优化的最终 token 级奖励。用户可以组合基于模型的奖励和基于规则的奖励。 main_task 构建一个 RayPPOTrainer 实例并启动 fit。请注意,main_task 作为单个进程运行。
我们强烈建议不要将 main_task 调度在 Ray 集群的 head 节点上,因为 main_task 会消耗大量内存,而 head 节点通常资源很少。
Ray trainer
代码: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py
RayPPOTrainer 管理:
Worker 和 WorkerGroup 的构建
运行 PPO 算法的主循环
请注意,RayPPOTrainer 的 fit 函数**作为单个进程运行**。
Worker 和 WorkerGroup 构建
每个 workerGroup 管理一个远程运行的工作节点列表。请注意,worker group 在其构造函数的进程中运行。 WorkerGroup 中的每个 worker 运行在 GPU 上。worker group 作为控制器进程与一组工作节点交互的代理,以执行某些计算。为了做到这一点,我们必须将 worker 的方法绑定到 WorkerGroup 的方法,并定义数据分发和数据收集。这通过简单的装饰器完成,将在 Worker 定义部分介绍。
例如,在 PPO 中,我们定义了 3 个 worker group:
ActorRolloutRef:管理 actor、rollout 和 reference policy。ActorRolloutRefWorker 可以实例化为单个 actor、单个 rollout、单个 reference policy,或者一个组合的 actor/rollout,或者一个组合的 actor/rollout/ref。这种设计旨在最大程度地在各种场景下进行代码重用。将 actor 和 rollout 置于同一位置是为了通过 nccl 进行快速权重传输。将 actor 和 reference 置于同一位置是为了实现高效的 LoRA PPO,因为 reference policy 只是 LoRA 中基模型。这种共置是通过 verl.single_controller.ray.base.create_colocated_worker_cls 完成的,该函数创建一个单一的 Ray 远程类,暴露这些角色的所有类方法。
Critic:管理 critic 模型。
Reward:管理 reward 模型。
worker group 将在指定的资源池上构建。资源池是 Ray 集群中的一组 GPU。
Worker 定义
我们以 ActorRolloutRefWorker 为例。 它应该向控制器进程暴露的 API 包括:
init_model:构建底层模型。
generate_sequences:给定提示,生成响应。
compute_log_prob:使用 actor 计算生成序列的对数概率。
compute_ref_log_prob:使用 reference policy 计算生成序列的对数概率。
save_checkpoint:保存检查点。
请注意,这些方法定义在 worker 中,只能通过远程调用来调用。例如,如果控制器进程想初始化模型,它必须调用:
for worker in actor_rollout_ref_wg:
worker.init_model.remote()
如果控制器进程想生成序列,它必须调用:
data = xxx
# 将数据分割成 dp chunks
data_dp_lst = data.split(dp_size)
output_dp_lst = []
for i, worker in enumerate(actor_rollout_ref_wg):
output_future = worker.generate_sequences.remote(data_dp_lst[i])
output_dp_lst.append(output_future)
output = torch.cat(ray.get(output_dp_lst), dim=0)
我们观察到,控制器进程调用 worker group 方法通常可以分为 3 个部分:
将数据分割成数据并行大小。
将相应的数据分发给每个 worker。
计算完成后收集并连接数据。
在 verl 中,我们设计了一个语法糖来将这 3 个过程封装成控制器进程的一次调用。
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(data):
...
# 在 driver 上
output = actor_rollout_ref_wg.generate_sequences(data)
我们使用 register 装饰 worker 的方法,显式定义输入数据如何分割和分发给每个 worker,以及输出数据如何由控制器收集和连接。例如,Dispatch.DP_COMPUTE_PROTO 将输入数据分割成 dp chunks,将每个数据分发给每个 worker,收集输出并将结果连接起来。请注意,此函数要求输入和输出是此处定义的 DataProto (https://github.com/volcengine/verl/blob/main/verl/protocol.py)。
PPO 主循环
有了上述 API,我们就可以像单进程程序一样实现 PPO 的主循环。
for prompt in dataloader:
output = actor_rollout_ref_wg.generate_sequences(prompt)
old_log_prob = actor_rollout_ref_wg.compute_log_prob(output)
ref_log_prob = actor_rollout_ref_wg.compute_ref_log_prob(output)
values = critic_wg.compute_values(output)
rewards = reward_wg.compute_scores(output)
# compute_advantages 直接在控制进程上运行
advantages = compute_advantages(values, rewards)
output = output.union(old_log_prob)
output = output.union(ref_log_prob)
output = output.union(values)
output = output.union(rewards)
output = output.union(advantages)
# 更新 actor
actor_rollout_ref_wg.update_actor(output)
critic.update_critic(output)
要点
这种编程范式使用户无需修改控制进程即可使用不同的计算后端。
这种编程范式允许通过更改 WorkerGroup 和 ResourcePool 的映射来实现灵活的放置,而无需修改控制进程。
代码库组织
代码库中的重要文件组织如下:
verl # verl 包
trainer
main_ppo.py # RL 训练的入口点
ppo
ray_trainer.py # PPO 等 RL 算法的训练循环
fsdp_sft_trainer.py # 使用 FSDP 后端的 SFT 训练器
config
generation.yaml # rollout 的配置模板
ppo_trainer.yaml # RL 训练器的配置模板
workers
protocol.py # DataProto 的接口
fsdp_workers.py # FSDP worker 接口:ActorRolloutRefWorker, CriticWorker, RewardModelWorker
megatron_workers.py # Megatron worker 接口:ActorRolloutRefWorker, CriticWorker, RewardModelWorker
actor
dp_actor.py # 使用 FSDP 后端的数据并行 actor
megatron_actor.py # 使用 Megatron 后端的 nD 并行 actor
critic
dp_critic.py # 使用 FSDP 后端的数据并行 critic
megatron_critic.py # 使用 Megatron 后端的 nD 并行 critic
reward_model
megatron
reward_model.py # 使用 Megatron 后端的 reward model
rollout
vllm
vllm_rollout.py # 使用 vllm 后端的 rollout
hf_rollout.py # 使用 huggingface TGI 后端的 rollout
sharding_manager
fsdp_ulysses.py # 使用 FSDP + ulysses 时的 DDP 和模型重分片
fsdp_vllm.py # 使用 FSDP + ulysses + vllm 时的 DDP 和模型重分片
megatron_vllm.py # 使用 Megatron + vllm 时的 DDP 和模型重分片
utils
dataset # SFT/RM/RL 的数据集
reward_score # 基于函数的奖励
gsm8k.py # gsm8k 数据集的奖励函数
math.py # math 数据集的奖励函数
seqlen_balancing.py # 序列长度平衡优化
models
llama # Megatron 实现的 llama, deepseek, mistral 等
transformers # ulysses 与 transformer 模型(如 llama, qwen 等)集成
weight_loader_registery.py # 将 hf cpts 加载到 Megatron 的权重加载器注册表
third_party
vllm # adaptor 用于 RL 中的 vllm
vllm_spmd # vllm >= v0.7 adaptor
examples # 示例脚本
tests # 集成和单元测试
.github # 持续集成测试的配置