扩展到其他 RL(HF) 算法
最后更新时间:2025 年 2 月 25 日。
我们已经实现了 PPO 算法完整的训练流程。为了扩展到其他算法,我们分析了 verl 的高级原理,并提供了一个实现 DPO 算法的教程。用户可以遵循类似的模式来扩展到其他 RL 算法。
Note
核心思路:单进程驱动多进程计算和数据通信。
整体方法
步骤 1:考虑每个模型所需的多机多 GPU 计算,例如 actor_rollout 模型中的 generate_sequence 、compute_log_prob 和 update_policy 。实现分布式的单进程多数据 (SPMD) 计算,并将其封装成 API。
步骤 2:根据不同的分布式场景,包括 Megatron-LM 中的 FSDP 和 3D 并行,实现对多进程计算之间数据交互的单进程控制。
步骤 3:利用封装好的 API 来实现控制流。
示例:在线 DPO
我们使用 verl 实现一个简单的在线 DPO 算法。在线 DPO 的算法流程如下:
存在一个 Prompt(rollout)生成器,其权重与 actor 模型相同。将一批 prompts 输入生成器后,它为每个 prompt 生成 N 个 responses。
将所有 prompts + responses 发送给一个 verifier 进行评分,它可以是 reward model 或基于规则的函数。然后将它们成对排序,形成一个训练批次。
使用这个训练批次通过 DPO 来训练 actor 模型。在此过程中,需要一个参考策略 (reference policy)。
步骤 1:所需的多机多 GPU 计算有哪些
Sample Generator
实现细节:
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
import ray
@ray.remote
class SampleGenerator(Worker):
def __init__(self, config):
super().__init__()
self.config = config
def generate_sequences(self, data):
pass
这里,SampleGenerator 可以看作是由 torchrun 启动的多进程,每个进程运行相同的代码 (SPMD)。SampleGenerator 需要实现一个 generate_sequences API 供控制流调用。内部实现细节可以使用任何推理引擎,包括 vllm、sglang 和 huggingface。用户可以很大程度上复用 verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py 中的代码,这里不再赘述。
ReferencePolicy 推理
API:计算参考对数概率 (compute reference log probability)
from verl.single_controller.base import Worker
import ray
@ray.remote
class ReferencePolicy(Worker):
def __init__(self):
super().__init__()
self.model = Model()
def infer(self, data):
return self.model(data)
Actor 更新
API:更新 actor 模型参数
from verl.single_controller.base import Worker
import ray
@ray.remote
class DPOActor(Worker):
def __init__(self):
super().__init__()
self.model = Model()
self.model = FSDP(self.model) # 或其他分布式策略
self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
self.loss_fn = xxx
def update(self, data):
self.optimizer.zero_grad()
logits = self.model(data)
loss = self.loss_fn(logits)
loss.backward()
self.optimizer.step()
注意:如何区分控制进程和分布式计算进程
控制进程通常是直接用
@ray.remote装饰的函数。计算进程都封装在
RayWorkerGroup中。
用户可以复用 PPO 算法中实现的大部分分布式计算逻辑,包括 verl/verl/trainer/ppo 中的 FSDP 和 Megatron-LM 后端。
步骤 2:根据不同的分布式场景,实现多进程数据交互的单进程控制
这里要解决的核心问题是:单个进程如何将数据发送到多个进程,驱动多进程计算,以及控制进程如何获取多进程计算的结果。 首先,我们在控制进程中初始化多进程 WorkerGroup。
@ray.remote(num_cpus=1)
def main_task(config):
# 构造 SampleGenerator
resource_pool = RayResourcePool(process_on_nodes=[8] * 2) # 16 GPUs
ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)
# 将 SampleGenerator 放入资源池
worker_group = RayWorkerGroup(resource_pool, ray_cls)
# 构造参考策略
从图中可以看到,在控制进程中,多个进程被封装到一个 RayWorkerGroup 中。在这个 WorkerGroup 内部,有一个 self._workers 成员,其中每个 worker 都是一个 SampleGenerator 的 RayActor (https://docs.ray.io/en/latest/ray-core/actors.html)。ray_trainer.md 也提供了一个 MegatronRayWorkerGroup 的实现。
假设模型使用 FSDP 进行分布式部署,并且控制进程上有一批数据,对于数据并行,底层调用过程如下:
data = xxx
data_list = data.chunk(dp_size)
output = []
for d in data_list:
# worker_group._workers[i] 是一个 SampleGenerator
output.append(worker_group._workers[i].generate_sequences.remote(d))
output = ray.get(output)
output = torch.cat(output)
单进程调用多进程涉及以下 3 个步骤:
在控制进程上将数据拆分成 DP 部分。
将数据发送到远程,通过 RPC 调用远程计算,并利用多进程计算。
在控制进程上获取每个 worker 的计算结果并合并。
频繁在控制器进程中调用这 3 个步骤会严重影响代码的可读性。在 verl 中,我们抽象并封装了这 3 个步骤,使得 worker 的方法 + 分发 + 收集可以注册到 worker_group 中。
from verl.single_controller.base.decorator import register
def dispatch_data(worker_group, data):
return data.chunk(worker_group.world_size)
def collect_data(worker_group, data):
return torch.cat(data)
dispatch_mode = {
'dispatch_fn': dispatch_data,
'collect_fn': collect_data
}
@register(dispatch_mode=dispatch_mode)
def generate_sequences(self, data):
pass
这样,我们就可以在控制(驱动)进程(这是一个单进程)上直接通过 worker_group 调用 worker 中的方法:
output = worker_group.generate_sequences(data)
这一行代码包含了数据拆分、数据分发和计算,以及数据收集。
此外,每个模型的模型并行大小通常是固定的,包括 dp、tp、pp。因此,对于这些常见的分布式场景,我们已经在 decorator.py 中预先实现了特定的分发和收集方法,可以直接用于封装计算。
from verl.single_controller.base.decorator import register, Dispatch
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, data: DataProto) -> DataProto:
pass
这里要求数据接口为 DataProto。 DataProto 的定义在 protocol.py 中。
步骤 3:主训练循环
有了上述训练流程,我们就可以实现算法的控制流。建议 main_task 也作为一个 ray remote 进程。
@ray.remote(num_cpus=1)
def main_task(config):
# 构造 SampleGenerator
resource_pool = RayResourcePool(process_on_nodes=[8] * 2) # 16 GPUs
ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)
# 将 SampleGenerator 放入资源池
sample_gen = RayWorkerGroup(resource_pool, ray_cls)
# 构造参考策略
ray_cls = RayClassWithInitArgs(ReferencePolicy)
ref_policy = RayWorkerGroup(resource_pool, ray_cls)
# 构造 actor
ray_cls = RayClassWithInitArgs(DPOActor)
dpo_policy = RayWorkerGroup(resource_pool, ray_cls)
dataloader = DataLoader()
for data in dataloader:
# 生成数据
data = sample_gen.generate_sequences(data)
# 为每个数据生成分数
data = generate_scores(data)
# 生成成对数据
data = generate_pairwise_data(data)
# 生成 ref_log_prob
data.batch['ref_log_prob'] = ref_policy.infer(data)
# 使用 dpo 更新
dpo_policy.update(data)
# 记录日志
这里,不同的 WorkerGroups 可以放置在同一个资源池中,或者使用 create_colocated_worker_cls 放置在不同的资源池中,就像在 ray_trainer.py 中一样。