扩展到其他 RL(HF) 算法

最后更新时间:2025 年 2 月 25 日。

我们已经实现了 PPO 算法完整的训练流程。为了扩展到其他算法,我们分析了 verl 的高级原理,并提供了一个实现 DPO 算法的教程。用户可以遵循类似的模式来扩展到其他 RL 算法。

Note

核心思路:单进程驱动多进程计算和数据通信。

整体方法

步骤 1:考虑每个模型所需的多机多 GPU 计算,例如 actor_rollout 模型中的 generate_sequencecompute_log_probupdate_policy 。实现分布式的单进程多数据 (SPMD) 计算,并将其封装成 API。

步骤 2:根据不同的分布式场景,包括 Megatron-LM 中的 FSDP 和 3D 并行,实现对多进程计算之间数据交互的单进程控制。

步骤 3:利用封装好的 API 来实现控制流。

示例:在线 DPO

我们使用 verl 实现一个简单的在线 DPO 算法。在线 DPO 的算法流程如下:

  1. 存在一个 Prompt(rollout)生成器,其权重与 actor 模型相同。将一批 prompts 输入生成器后,它为每个 prompt 生成 N 个 responses。

  2. 将所有 prompts + responses 发送给一个 verifier 进行评分,它可以是 reward model 或基于规则的函数。然后将它们成对排序,形成一个训练批次。

  3. 使用这个训练批次通过 DPO 来训练 actor 模型。在此过程中,需要一个参考策略 (reference policy)。

步骤 1:所需的多机多 GPU 计算有哪些

Sample Generator

实现细节:

from verl.single_controller.base import Worker
from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
import ray

@ray.remote
class SampleGenerator(Worker):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def generate_sequences(self, data):
        pass

这里,SampleGenerator 可以看作是由 torchrun 启动的多进程,每个进程运行相同的代码 (SPMD)。SampleGenerator 需要实现一个 generate_sequences API 供控制流调用。内部实现细节可以使用任何推理引擎,包括 vllm、sglang 和 huggingface。用户可以很大程度上复用 verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py 中的代码,这里不再赘述。

ReferencePolicy 推理

API:计算参考对数概率 (compute reference log probability)

from verl.single_controller.base import Worker
import ray

@ray.remote
class ReferencePolicy(Worker):
    def __init__(self):
        super().__init__()
        self.model = Model()

    def infer(self, data):
        return self.model(data)

Actor 更新

API:更新 actor 模型参数

from verl.single_controller.base import Worker
import ray

@ray.remote
class DPOActor(Worker):
    def __init__(self):
        super().__init__()
        self.model = Model()
        self.model = FSDP(self.model)  # 或其他分布式策略
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.loss_fn = xxx

    def update(self, data):
        self.optimizer.zero_grad()
        logits = self.model(data)
        loss = self.loss_fn(logits)
        loss.backward()
        self.optimizer.step()

注意:如何区分控制进程和分布式计算进程

  • 控制进程通常是直接用 @ray.remote 装饰的函数。

  • 计算进程都封装在 RayWorkerGroup 中。

用户可以复用 PPO 算法中实现的大部分分布式计算逻辑,包括 verl/verl/trainer/ppo 中的 FSDP 和 Megatron-LM 后端。

步骤 2:根据不同的分布式场景,实现多进程数据交互的单进程控制

这里要解决的核心问题是:单个进程如何将数据发送到多个进程,驱动多进程计算,以及控制进程如何获取多进程计算的结果。 首先,我们在控制进程中初始化多进程 WorkerGroup

@ray.remote(num_cpus=1)
def main_task(config):
    # 构造 SampleGenerator
    resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs
    ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)
    # 将 SampleGenerator 放入资源池
    worker_group = RayWorkerGroup(resource_pool, ray_cls)

    # 构造参考策略

从图中可以看到,在控制进程中,多个进程被封装到一个 RayWorkerGroup 中。在这个 WorkerGroup 内部,有一个 self._workers 成员,其中每个 worker 都是一个 SampleGenerator 的 RayActor (https://docs.ray.io/en/latest/ray-core/actors.html)。ray_trainer.md 也提供了一个 MegatronRayWorkerGroup 的实现。

假设模型使用 FSDP 进行分布式部署,并且控制进程上有一批数据,对于数据并行,底层调用过程如下:

data = xxx
data_list = data.chunk(dp_size)

output = []
for d in data_list:
    # worker_group._workers[i] 是一个 SampleGenerator
    output.append(worker_group._workers[i].generate_sequences.remote(d))

output = ray.get(output)
output = torch.cat(output)

单进程调用多进程涉及以下 3 个步骤:

  1. 在控制进程上将数据拆分成 DP 部分。

  2. 将数据发送到远程,通过 RPC 调用远程计算,并利用多进程计算。

  3. 在控制进程上获取每个 worker 的计算结果并合并。

频繁在控制器进程中调用这 3 个步骤会严重影响代码的可读性。在 verl 中,我们抽象并封装了这 3 个步骤,使得 worker 的方法 + 分发 + 收集可以注册到 worker_group 中。

from verl.single_controller.base.decorator import register

def dispatch_data(worker_group, data):
    return data.chunk(worker_group.world_size)

def collect_data(worker_group, data):
    return torch.cat(data)

dispatch_mode = {
    'dispatch_fn': dispatch_data,
    'collect_fn': collect_data
}

@register(dispatch_mode=dispatch_mode)
def generate_sequences(self, data):
    pass

这样,我们就可以在控制(驱动)进程(这是一个单进程)上直接通过 worker_group 调用 worker 中的方法:

output = worker_group.generate_sequences(data)

这一行代码包含了数据拆分、数据分发和计算,以及数据收集。

此外,每个模型的模型并行大小通常是固定的,包括 dp、tp、pp。因此,对于这些常见的分布式场景,我们已经在 decorator.py 中预先实现了特定的分发和收集方法,可以直接用于封装计算。

from verl.single_controller.base.decorator import register, Dispatch

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, data: DataProto) -> DataProto:
    pass

这里要求数据接口为 DataProtoDataProto 的定义在 protocol.py 中。

步骤 3:主训练循环

有了上述训练流程,我们就可以实现算法的控制流。建议 main_task 也作为一个 ray remote 进程。

@ray.remote(num_cpus=1)
def main_task(config):
    # 构造 SampleGenerator
    resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs
    ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)
    # 将 SampleGenerator 放入资源池
    sample_gen = RayWorkerGroup(resource_pool, ray_cls)

    # 构造参考策略
    ray_cls = RayClassWithInitArgs(ReferencePolicy)
    ref_policy = RayWorkerGroup(resource_pool, ray_cls)

    # 构造 actor
    ray_cls = RayClassWithInitArgs(DPOActor)
    dpo_policy = RayWorkerGroup(resource_pool, ray_cls)

    dataloader = DataLoader()

    for data in dataloader:
        # 生成数据
        data = sample_gen.generate_sequences(data)
        # 为每个数据生成分数
        data = generate_scores(data)
        # 生成成对数据
        data = generate_pairwise_data(data)
        # 生成 ref_log_prob
        data.batch['ref_log_prob'] = ref_policy.infer(data)
        # 使用 dpo 更新
        dpo_policy.update(data)
        # 记录日志

这里,不同的 WorkerGroups 可以放置在同一个资源池中,或者使用 create_colocated_worker_cls 放置在不同的资源池中,就像在 ray_trainer.py 中一样。