PPO Ray Trainer

最后更新时间:2025 年 2 月 12 日。

我们实现了 RayPPOTrainer,这是一个在单个 CPU/GPU 节点(默认是 CPU)的驱动进程上运行的 trainer。

PPORayTrainer 包含 3 个核心函数,用于数据准备、WorkerGroup 初始化和 PPO 训练循环。

数据准备

PPORayTrainer 作为单个进程,负责从数据集中加载完整的样本批次(prompts),然后分发给在不同 GPU 上运行的 worker_groups。

为了通用化数据加载,我们实现了 RLHFDataset 类来加载预处理过的 parquet 文件,对 prompts 应用 chat templates,添加 padding,截断超出最大 prompt 长度的 prompts,然后进行 tokenization。

self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,
                                    tokenizer=self.tokenizer,
                                    config=self.config.data)

然后,dataloader 将在 PPO mini batch size 的约束下迭代数据集。

WorkerGroup 初始化

我们首先介绍一个在给定 GPU 集上初始化 actor 模型 WorkerGroup 的基本实现。

# max_colocate_count 表示每个 RayResourcePool 中的 WorkerGroup(即进程)数量
# 对于 FSDP 后端,我们建议使用 max_colocate_count=1,将所有 WorkerGroup 合并为一个。
# 对于 Megatron 后端,我们建议使用 max_colocate_count>1,以便为不同的模型利用不同的 WorkerGroup
resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes,
                                use_gpu=True,
                                max_colocate_count=1)
# 定义将在远程初始化的 actor rollout 类
actor_rollout_cls = RayClassWithInitArgs(cls=ActorRolloutWorker)
# 定义 actor_rollout worker group
actor_rollout_worker_group = MegatronRayWorkerGroup(resource_pool=resource_pool,
                                                    ray_cls_with_init=actor_rollout_cls,
                                                    default_megatron_kwargs=config.actor_rollout.megatron)

在上面的实现中,不同的 WorkerGroup,如 actor_rollout_worker_groupcritic_worker_groupref_worker_group 位于单独的进程中。

驱动进程随后可以调用 actor_rollout_worker_group 和其他角色的分布式计算函数来构建 RL 训练循环。

对于驻留在同一 GPU 集中的模型,我们提供了更细粒度的优化,即将同一进程中不同角色的 worker_group 进行合并。此优化可以节省不同进程中的冗余 CUDA/分布式上下文。

# 初始化 WorkerGroup
# 注意:如果您想为每个角色使用不同的资源池,以支持不同的并行大小,
# 则不应使用 `create_colocated_worker_cls`。而是直接将不同的资源池传递给不同的 worker groups。
# 更多信息请参见 TODO(url)。
all_wg = {}
for resource_pool, class_dict in self.resource_pool_to_cls.items():
    worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
    wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
    spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
    all_wg.update(spawn_wg)

if self.use_critic:
    self.critic_wg = all_wg['critic']
    self.critic_wg.init_model()

if self.use_reference_policy:
    self.ref_policy_wg = all_wg['ref']
    self.ref_policy_wg.init_model()

if self.use_rm:
    self.rm_wg = all_wg['rm']
    self.rm_wg.init_model()

# 我们应该最后创建 rollout,以便 vllm 能够更好地估计 kv 缓存内存
self.actor_rollout_wg = all_wg['actor_rollout']
self.actor_rollout_wg.init_model()

Note

对于 megatron 后端,如果我们把 worker_groups 合并到同一个进程中,所有角色将使用相同的 3D 并行大小。为了优化这一点,我们可能需要为同一分布式上下文中的每个角色维护几个 3D 进程组。如果您想为不同角色使用不同的 3D 并行大小,请遵循第一个代码块的类似架构来初始化每个角色的 worker_group

PPO 训练循环

我们通过调用每个角色的 worker_group 中的函数来实现 PPO 训练循环。每个函数的输入和输出数据是 protocol.py 中实现的 DataProto 对象。在训练循环中,trainer 会按照封装在 worker 函数中的传输协议,将数据分发到/从不同的 GPU 收集。PPO micro batches 的计算在 update_actorupdate_critic 函数中进行。

要扩展到其他 RLHF 算法,例如 DPO、GRPO,请参阅 扩展到其他 RL(HF) 算法

def fit(self):
    """
    PPO 的训练循环。
    驱动进程只需要通过 RPC 调用 worker group 的计算函数来构造 PPO 数据流。
    轻量级的 advantage 计算在驱动进程上完成。
    """
    from verl.utils.tracking import Tracking
    from omegaconf import OmegaConf

    logger = Tracking(project_name=self.config.trainer.project_name,
                        experiment_name=self.config.trainer.experiment_name,
                        default_backend=self.config.trainer.logger,
                        config=OmegaConf.to_container(self.config, resolve=True))

    global_steps = 0

    # 在训练前进行验证
    # 目前,我们只支持使用 reward_function 进行验证。
    if self.val_reward_fn is not None:
        val_metrics = self._validate()
        pprint(f'Initial validation metrics: {val_metrics}')

    for epoch in range(self.config.trainer.total_epochs):
        for batch_dict in self.train_dataloader:
            metrics = {}

            batch: DataProto = DataProto.from_single_dict(batch_dict)
            # batch = batch.to('cuda')

            # pop those keys for generation
            gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

            # generate a batch
            with Timer(name='gen', logger=None) as timer:
                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
            metrics['timing/gen'] = timer.last

            batch = batch.union(gen_batch_output)

            if self.use_reference_policy:
                # compute reference log_prob
                with Timer(name='ref', logger=None) as timer:
                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                    batch = batch.union(ref_log_prob)
                metrics['timing/ref'] = timer.last

            # compute values
            with Timer(name='values', logger=None) as timer:
                values = self.critic_wg.compute_values(batch)
                batch = batch.union(values)
            metrics['timing/values'] = timer.last

            with Timer(name='adv', logger=None) as timer:
                # compute scores. Support both model and function-based.
                # We first compute the scores using reward model. Then, we call reward_fn to combine
                # the results from reward model and rule-based results.
                if self.use_rm:
                    # we first compute reward model score
                    reward_tensor = self.rm_wg.compute_rm_score(batch)
                    batch = batch.union(reward_tensor)

                # we combine with rule-based rm
                reward_tensor = self.reward_fn(batch)
                batch.batch['token_level_scores'] = reward_tensor

                # compute rewards. apply_kl_penalty if available
                batch, kl_metrics = apply_kl_penalty(batch,
                                                        kl_ctrl=self.kl_ctrl_in_reward,
                                                        kl_penalty=self.config.algorithm.kl_penalty)
                metrics.update(kl_metrics)

                # compute advantages, executed on the driver process
                batch = compute_advantage(batch,
                                            self.config.algorithm.gamma,
                                            self.config.algorithm.lam,
                                            adv_estimator=self.config.algorithm.adv_estimator)
            metrics['timing/adv'] = timer.last

            # update critic
            if self.use_critic:
                with Timer(name='update_critic', logger=None) as timer:
                    critic_output = self.critic_wg.update_critic(batch)
                metrics['timing/update_critic'] = timer.last
                critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
                metrics.update(critic_output_metrics)

            # implement critic warmup
            if self.config.trainer.critic_warmup <= global_steps:
                # update actor
                with Timer(name='update_actor', logger=None) as timer:
                    actor_output = self.actor_rollout_wg.update_actor(batch)
                metrics['timing/update_actor'] = timer.last
                actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                metrics.update(actor_output_metrics)

            # validate
            if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0:
                with Timer(name='testing', logger=None) as timer:
                    val_metrics: dict = self._validate()
                    val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}
                metrics['timing/testing'] = timer.last
                metrics.update(val_metrics)

            # collect metrics
            data_metrics = compute_data_metrics(batch=batch)
            metrics.update(data_metrics)

            # TODO: make a canonical logger that supports various backend
            logger.log(data=metrics, step=global_steps)

            if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0:
                actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
                                                f'global_step_{global_steps}')
                actor_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'actor')
                self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)

                if self.use_critic:
                    critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
                                                        f'global_step_{global_steps}')
                    critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic')
                    self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)

            global_steps += 1

    # 在训练后进行验证
    if self.val_reward_fn is not None:
        val_metrics = self._validate()
        pprint(f'Final validation metrics: {val_metrics}')