扩展到其他 RL(HF) 算法 ================================= 最后更新时间:2025 年 2 月 25 日。 我们已经实现了 PPO 算法完整的训练流程。为了扩展到其他算法,我们分析了 verl 的高级原理,并提供了一个实现 DPO 算法的教程。用户可以遵循类似的模式来扩展到其他 RL 算法。 .. note:: **核心思路**:单进程驱动多进程计算和数据通信。 整体方法 ---------------- 步骤 1:考虑每个模型所需的多机多 GPU 计算,例如 actor_rollout 模型中的 ``generate_sequence`` 、``compute_log_prob`` 和 ``update_policy`` 。实现分布式的单进程多数据 (SPMD) 计算,并将其封装成 API。 步骤 2:根据不同的分布式场景,包括 Megatron-LM 中的 FSDP 和 3D 并行,实现对多进程计算之间数据交互的单进程控制。 步骤 3:利用封装好的 API 来实现控制流。 示例:在线 DPO ---------------- 我们使用 verl 实现一个简单的在线 DPO 算法。在线 DPO 的算法流程如下: 1. 存在一个 Prompt(rollout)生成器,其权重与 actor 模型相同。将一批 prompts 输入生成器后,它为每个 prompt 生成 N 个 responses。 2. 将所有 prompts + responses 发送给一个 verifier 进行评分,它可以是 reward model 或基于规则的函数。然后将它们成对排序,形成一个训练批次。 3. 使用这个训练批次通过 DPO 来训练 actor 模型。在此过程中,需要一个参考策略 (reference policy)。 步骤 1:所需的多机多 GPU 计算有哪些 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ **Sample Generator** 实现细节: .. code:: python from verl.single_controller.base import Worker from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool import ray @ray.remote class SampleGenerator(Worker): def __init__(self, config): super().__init__() self.config = config def generate_sequences(self, data): pass 这里,``SampleGenerator`` 可以看作是由 ``torchrun`` 启动的多进程,每个进程运行相同的代码 (SPMD)。``SampleGenerator`` 需要实现一个 ``generate_sequences`` API 供控制流调用。内部实现细节可以使用任何推理引擎,包括 vllm、sglang 和 huggingface。用户可以很大程度上复用 verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py 中的代码,这里不再赘述。 **ReferencePolicy 推理** API:计算参考对数概率 (compute reference log probability) .. code:: python from verl.single_controller.base import Worker import ray @ray.remote class ReferencePolicy(Worker): def __init__(self): super().__init__() self.model = Model() def infer(self, data): return self.model(data) **Actor 更新** API:更新 actor 模型参数 .. code:: python from verl.single_controller.base import Worker import ray @ray.remote class DPOActor(Worker): def __init__(self): super().__init__() self.model = Model() self.model = FSDP(self.model) # 或其他分布式策略 self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3) self.loss_fn = xxx def update(self, data): self.optimizer.zero_grad() logits = self.model(data) loss = self.loss_fn(logits) loss.backward() self.optimizer.step() **注意:如何区分控制进程和分布式计算进程** ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - 控制进程通常是直接用 ``@ray.remote`` 装饰的函数。 - 计算进程都封装在 ``RayWorkerGroup`` 中。 用户可以复用 PPO 算法中实现的大部分分布式计算逻辑,包括 verl/verl/trainer/ppo 中的 FSDP 和 Megatron-LM 后端。 步骤 2:根据不同的分布式场景,实现多进程数据交互的单进程控制 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ **这里要解决的核心问题是:单个进程如何将数据发送到多个进程,驱动多进程计算,以及控制进程如何获取多进程计算的结果。** 首先,我们在控制进程中初始化多进程 ``WorkerGroup``。 .. code:: python @ray.remote(num_cpus=1) def main_task(config): # 构造 SampleGenerator resource_pool = RayResourcePool(process_on_nodes=[8] * 2) # 16 GPUs ray_cls = RayClassWithInitArgs(SampleGenerator, config=config) # 将 SampleGenerator 放入资源池 worker_group = RayWorkerGroup(resource_pool, ray_cls) # 构造参考策略 从图中可以看到,在控制进程中,多个进程被封装到一个 ``RayWorkerGroup`` 中。在这个 ``WorkerGroup`` 内部,有一个 ``self._workers`` 成员,其中每个 worker 都是一个 SampleGenerator 的 RayActor (https://docs.ray.io/en/latest/ray-core/actors.html)。ray_trainer.md 也提供了一个 ``MegatronRayWorkerGroup`` 的实现。 假设模型使用 FSDP 进行分布式部署,并且控制进程上有一批数据,对于数据并行,底层调用过程如下: .. code:: python data = xxx data_list = data.chunk(dp_size) output = [] for d in data_list: # worker_group._workers[i] 是一个 SampleGenerator output.append(worker_group._workers[i].generate_sequences.remote(d)) output = ray.get(output) output = torch.cat(output) 单进程调用多进程涉及以下 3 个步骤: 1. 在控制进程上将数据拆分成 DP 部分。 2. 将数据发送到远程,通过 RPC 调用远程计算,并利用多进程计算。 3. 在控制进程上获取每个 worker 的计算结果并合并。 频繁在控制器进程中调用这 3 个步骤会严重影响代码的可读性。**在 verl 中,我们抽象并封装了这 3 个步骤,使得 worker 的方法 + 分发 + 收集可以注册到 worker_group 中。** .. code:: python from verl.single_controller.base.decorator import register def dispatch_data(worker_group, data): return data.chunk(worker_group.world_size) def collect_data(worker_group, data): return torch.cat(data) dispatch_mode = { 'dispatch_fn': dispatch_data, 'collect_fn': collect_data } @register(dispatch_mode=dispatch_mode) def generate_sequences(self, data): pass 这样,我们就可以在控制(驱动)进程(这是一个单进程)上直接通过 ``worker_group`` 调用 worker 中的方法: .. code:: python output = worker_group.generate_sequences(data) 这一行代码包含了数据拆分、数据分发和计算,以及数据收集。 此外,每个模型的模型并行大小通常是固定的,包括 dp、tp、pp。因此,对于这些常见的分布式场景,我们已经在 `decorator.py `_ 中预先实现了特定的分发和收集方法,可以直接用于封装计算。 .. code:: python from verl.single_controller.base.decorator import register, Dispatch @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, data: DataProto) -> DataProto: pass 这里要求数据接口为 ``DataProto``。 ``DataProto`` 的定义在 `protocol.py `_ 中。 步骤 3:主训练循环 ~~~~~~~~~~~~~~~~~~~~~~~~~~ 有了上述训练流程,我们就可以实现算法的控制流。建议 ``main_task`` 也作为一个 ray remote 进程。 .. code:: python @ray.remote(num_cpus=1) def main_task(config): # 构造 SampleGenerator resource_pool = RayResourcePool(process_on_nodes=[8] * 2) # 16 GPUs ray_cls = RayClassWithInitArgs(SampleGenerator, config=config) # 将 SampleGenerator 放入资源池 sample_gen = RayWorkerGroup(resource_pool, ray_cls) # 构造参考策略 ray_cls = RayClassWithInitArgs(ReferencePolicy) ref_policy = RayWorkerGroup(resource_pool, ray_cls) # 构造 actor ray_cls = RayClassWithInitArgs(DPOActor) dpo_policy = RayWorkerGroup(resource_pool, ray_cls) dataloader = DataLoader() for data in dataloader: # 生成数据 data = sample_gen.generate_sequences(data) # 为每个数据生成分数 data = generate_scores(data) # 生成成对数据 data = generate_pairwise_data(data) # 生成 ref_log_prob data.batch['ref_log_prob'] = ref_policy.infer(data) # 使用 dpo 更新 dpo_policy.update(data) # 记录日志 这里,不同的 ``WorkerGroups`` 可以放置在同一个资源池中,或者使用 ``create_colocated_worker_cls`` 放置在不同的资源池中,就像在 `ray_trainer.py `_ 中一样。