``verl.single_controller`` 的设计 ============================================== 最后更新: 05/21/2025. **作者:** `王璋 `__ 前言 ------- 本文档是为 ``verl`` 的开发者准备的,特别是那些对理解或贡献 ``verl.single_controller`` 模块感兴趣的开发者。本文档并非面向最终用户,而是面向希望了解架构原理和内部机制的贡献者。 -------------- 起源 ------ ``single_controller`` 模块起源于我收到的一项请求——将一个简易的单进程 RLHF 脚本适配到一个分布式系统中,并且希望改动尽可能小,同时保持调试的便利性。 一种常见的做法,例如使用 PyTorch 的 Distributed Data Parallel (DDP),通常是通过包装 ``nn.Module`` 并启动多个进程,这些进程以不同的 rank 执行相同的函数。然而,在分布式 RLHF 的场景下,这种方法存在两个主要局限: - 难以表示 PPO 所需的多个 DAG; - 难以在训练过程中检查中间张量。 为了保持调试的便利性,我们选择了另一种方法——将训练循环分解为定义明确的阶段,如 ``generate_sequences``、``compute_advantages`` 等。 我们选择 `Ray `__ 作为 ``verl`` 的初始后端,因为它能够将 Python 类方法暴露为 RPC 端点。然而,Ray 的默认模型仅支持“一次方法调用,一次 RPC”,而 LLM 的训练通常需要跨多个进程进行协调。 为了隐藏这种为了单次方法调用而进行的多次 Ray actor 调用对用户的影响,我们引入了以下组件: - ``WorkerGroup`` – 管理一组远程 worker,并为多进程分布式计算提供统一的接口; - ``ResourcePool`` – 将计算资源绑定到 worker 进程; - ``ClassWithArgs`` – 支持使用指定的初始化参数进行延迟的远程实例化。 -------------- 运行示例:“``generate_sequences``” ----------------------------------------- 为了说明其设计,我们 walkthrough 了 ``ActorRolloutRefWorker`` 类中的 ``generate_sequences`` 方法是如何被注册和跨分布式 worker 调用的。 -------------- 步骤 1: 使用装饰器注册 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 第一步是定义 ``generate_sequences`` 方法,并用 ``@register`` 装饰它,因为它将在 driver 脚本中被调用。 **来源:** `fsdp_workers.py `__ .. code:: python class ActorRolloutRefWorker(Worker): ... @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): prompts = prompts.to(torch.cuda.current_device()) ... ``@register`` 装饰器会为 ``generate_sequences`` 方法添加元数据。目前,它不改变函数的实际功能,而是通过一个特殊的键(``MAGIC_ATTR``)附加属性: **来源:** `decorator.py `__ .. code:: python def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True): ... def decorator(func): @wraps(func) def inner(*args, **kwargs): if materialize_futures: args, kwargs = _materialize_futures(*args, **kwargs) return func(*args, **kwargs) attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking} setattr(inner, MAGIC_ATTR, attrs) return inner return decorator 如代码所示,``dispatch_mode``、``execute_mode`` 和 ``blocking`` 的值被附加到了 ``generate_sequences`` 方法。 -------------- 步骤 2: 初始化时绑定 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 这些附加的属性会在 ``ActorRolloutRefWorker``(包装在 ``RayClassWithArgs`` 中)被传递到 ``RayWorkerGroup`` 时被提取和使用。 **来源:** `main_generation.py `__ .. code:: python ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) 在 ``RayWorkerGroup`` 的 `初始化 `__ 过程中,会发生两个关键步骤: 1. 创建 Worker 实例(Ray actors): `RayWorkerGroup._init_with_resource_pool `__ 2. 将带有 ``@register`` 装饰器的方法绑定到 ``RayWorkerGroup``: `RayWorkerGroup._bind_worker_method `__ .. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/worker_group_init.png?raw=true :alt: initialization_and_binding_of_worker_group worker_group 的初始化和绑定 绑定过程是 ``verl.single_controller`` 的核心。 **关键函数:** `WorkerGroup._bind_worker_method `__ .. code:: python def _bind_worker_method(self, user_defined_cls, func_generator): ... for method_name in dir(user_defined_cls): try: method = getattr(user_defined_cls, method_name) assert callable(method) except Exception: continue # Skip properties <<>> 当方法带有 ``MAGIC_ATTR`` 时,``@register`` 设置的属性就会被提取: .. code:: python <<>> if hasattr(method, MAGIC_ATTR): attribute = getattr(method, MAGIC_ATTR) dispatch_mode = attribute["dispatch_mode"] execute_mode = attribute["execute_mode"] blocking = attribute["blocking"] <<>> 如上流程图所示,这些属性会被传入 ``func_generator``。但是,``func_generator`` 需要 ``method_name``、``dispatch_fn``、``collect_fn``、``execute_fn`` 和 ``blocking`。我们需要从 `DISPATCH_MODE_FN_REGISTRY `__ 中找到与 ``dispatch_mode``(``DP_COMPUTE_PROTO``)相对应的 ``dispatch_fn`` 和 ``collect_fn``: .. code:: python3 DISPATCH_MODE_FN_REGISTRY = { Dispatch.ONE_TO_ALL: { "dispatch_fn": dispatch_one_to_all, "collect_fn": collect_all_to_all, }, ... Dispatch.DP_COMPUTE_PROTO: { "dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute_data_proto, }, ... } 类似地,``execute_fn`` 是通过 ``execute_mode`` 选择的,并由以下代码提取: .. code:: python <<>> # get execute_fn_name execute_mode = get_predefined_execute_fn(execute_mode=execute_mode) wg_execute_fn_name = execute_mode["execute_fn_name"] # get execute_fn from string try: execute_fn = getattr(self, wg_execute_fn_name) assert callable(execute_fn), "execute_fn must be callable" except Exception: print(f"execute_fn {wg_execute_fn_name} is invalid") raise <<>> 以 ``generate_sequences`` 为例: - ``dispatch_mode = Dispatch.DP_COMPUTE_PROTO`` - ``dispatch_fn = dispatch_dp_compute_data_proto`` - ``collect_fn = collect_dp_compute_data_proto`` - ``execute_fn = RayWorkerGroup.execute_all`` ONE_TO_ALL 对比 DP_COMPUTE_PROTO ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ``dispatch_mode`` 与 ``dispatch_fn`` 和 ``collect_fn`` 相关联。顾名思义,``dispatch_fn`` 处理 ``WorkerGroup`` 中的输入参数,并生成一个输入参数列表(batch),每个参数将传递给 ``WorkerGroup`` 中的一个 worker。 ``ONE_TO_ALL` 的 ``dispatch_fn`` 是 `dispatch_one_to_all `__,它只是将所有输入参数复制 N 份,其中 N 等于附加到 ``worker_group`` 的 Worker 数量: .. code:: python def dispatch_one_to_all(worker_group, *args, **kwargs): args = tuple([arg] * worker_group.world_size for arg in args) kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()} return args, kwargs ``DP_COMPUTE_PROTO`` 的 ``dispatch_fn`` 是 `dispatch_dp_compute_data_proto `__,它使用 ``DataProto.chunk`` 将一个大的 ``DataProto`` 分割成 N 个小的 ``DataProto``,其中 N 等于 ``worker_group`` 的 world_size(Worker 数量): .. code:: python def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs): from verl.single_controller.base.worker_group import WorkerGroup assert isinstance(worker_group, WorkerGroup) # Note: enable auto padding for dp compute DatapProto splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding( worker_group.world_size, *args, **kwargs, ) return splitted_args, splitted_kwargs ``collect_fn`` 遵循相同的模式,处理来自 ``WorkerGroup`` 所有 worker 的返回值列表(batch),并将其合并为一个列表(如 ``collect_all_to_all`` 所做)或一个大的 ``DataProto``(如 ``collect_dp_compute_data_proto`` 所做)。 最后,使用 ``func_generator`` 动态生成一个新方法,并将其添加到 ``WorkerGroup`` 实例中: .. code:: python <<>> # bind a new method to the RayWorkerGroup func = func_generator( self, method_name, dispatch_fn=dispatch_fn, collect_fn=collect_fn, execute_fn=execute_fn, blocking=blocking, ) try: setattr(self, method_name, func) method_names.append(method_name) except Exception as e: raise ValueError(f"Fail to set method_name {method_name}") from e 这使得该方法可以通过 ``WorkerGroup`` 接口进行调用。 -------------- 步骤 3: 调用链 ~~~~~~~~~~~~~~~~~~ 以上所有机制确保了分布式调用与单进程调用具有相同的体验。在原始的单进程脚本中,代码如下所示: .. code:: python rollout = Rollout() rollout.generate_sequences(batch) 使用 ``verl``,多进程程序变为: .. code:: python rollout = RayWorkerGroup(resource_pool=[4], RayClassWithArgs(Rollout)) rollout.generate_sequences(batch) .. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/call_generate_sequences.png?raw=true :alt: call_chain_of_generate_sequences generate_sequences 的调用链 在这个简单的调用背后: - ``dispatch_fn`` 将输入分散到各个 worker - ``execute_fn`` 执行实际的远程调用 - ``collect_fn`` 收集结果 所有这些都被抽象掉了,使得开发者能够以最小的改动来编写分布式代码,而无需改变现有逻辑。 -------------- 超越 RL 后训练:``verl.single_controller`` 的泛化能力 ---------------------------------------------------------------- ``verl.single_controller`` 模块的泛化能力远超强化学习。它提供了一个简洁的抽象,用于批量处理远程方法调用,并自动处理输入/输出。 通过最小化单进程脚本和多进程脚本之间的差异,``verl.single_controller`` 为更广泛的领域开启了分布式计算的大门——不仅仅局限于 RL 后训练。 我们希望这个设计能够激发社区提供更多的示例和扩展。