verl.single_controller 的设计

最后更新: 05/21/2025.

作者: 王璋

前言

本文档是为 verl 的开发者准备的,特别是那些对理解或贡献 verl.single_controller 模块感兴趣的开发者。本文档并非面向最终用户,而是面向希望了解架构原理和内部机制的贡献者。


起源

single_controller 模块起源于我收到的一项请求——将一个简易的单进程 RLHF 脚本适配到一个分布式系统中,并且希望改动尽可能小,同时保持调试的便利性。

一种常见的做法,例如使用 PyTorch 的 Distributed Data Parallel (DDP),通常是通过包装 nn.Module 并启动多个进程,这些进程以不同的 rank 执行相同的函数。然而,在分布式 RLHF 的场景下,这种方法存在两个主要局限: - 难以表示 PPO 所需的多个 DAG; - 难以在训练过程中检查中间张量。

为了保持调试的便利性,我们选择了另一种方法——将训练循环分解为定义明确的阶段,如 generate_sequencescompute_advantages 等。

我们选择 Ray 作为 verl 的初始后端,因为它能够将 Python 类方法暴露为 RPC 端点。然而,Ray 的默认模型仅支持“一次方法调用,一次 RPC”,而 LLM 的训练通常需要跨多个进程进行协调。

为了隐藏这种为了单次方法调用而进行的多次 Ray actor 调用对用户的影响,我们引入了以下组件:

  • WorkerGroup – 管理一组远程 worker,并为多进程分布式计算提供统一的接口;

  • ResourcePool – 将计算资源绑定到 worker 进程;

  • ClassWithArgs – 支持使用指定的初始化参数进行延迟的远程实例化。


运行示例:“generate_sequences

为了说明其设计,我们 walkthrough 了 ActorRolloutRefWorker 类中的 generate_sequences 方法是如何被注册和跨分布式 worker 调用的。


步骤 1: 使用装饰器注册

第一步是定义 generate_sequences 方法,并用 @register 装饰它,因为它将在 driver 脚本中被调用。

来源: fsdp_workers.py

class ActorRolloutRefWorker(Worker):
    ...
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def generate_sequences(self, prompts: DataProto):
        prompts = prompts.to(torch.cuda.current_device())
        ...

@register 装饰器会为 generate_sequences 方法添加元数据。目前,它不改变函数的实际功能,而是通过一个特殊的键(MAGIC_ATTR)附加属性:

来源: decorator.py

def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
    ...
    def decorator(func):
        @wraps(func)
        def inner(*args, **kwargs):
            if materialize_futures:
                args, kwargs = _materialize_futures(*args, **kwargs)
            return func(*args, **kwargs)

        attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking}
        setattr(inner, MAGIC_ATTR, attrs)
        return inner

    return decorator

如代码所示,dispatch_modeexecute_modeblocking 的值被附加到了 generate_sequences 方法。


步骤 2: 初始化时绑定

这些附加的属性会在 ActorRolloutRefWorker``(包装在 ``RayClassWithArgs 中)被传递到 RayWorkerGroup 时被提取和使用。

来源: main_generation.py

ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)

RayWorkerGroup初始化 过程中,会发生两个关键步骤:

  1. 创建 Worker 实例(Ray actors): RayWorkerGroup._init_with_resource_pool

  2. 将带有 @register 装饰器的方法绑定到 RayWorkerGroupRayWorkerGroup._bind_worker_method

initialization_and_binding_of_worker_group

worker_group 的初始化和绑定

绑定过程是 verl.single_controller 的核心。

关键函数: WorkerGroup._bind_worker_method

def _bind_worker_method(self, user_defined_cls, func_generator):
    ...
    for method_name in dir(user_defined_cls):
        try:
            method = getattr(user_defined_cls, method_name)
            assert callable(method)
        except Exception:
            continue  # Skip properties
        <<<to be continue 1>>>

当方法带有 MAGIC_ATTR 时,@register 设置的属性就会被提取:

<<<continue 1>>>
if hasattr(method, MAGIC_ATTR):
    attribute = getattr(method, MAGIC_ATTR)
    dispatch_mode = attribute["dispatch_mode"]
    execute_mode = attribute["execute_mode"]
    blocking = attribute["blocking"]

    <<<to be continue 2>>>

如上流程图所示,这些属性会被传入 func_generator。但是,func_generator 需要 method_namedispatch_fncollect_fnexecute_fnblocking`。我们需要从 `DISPATCH_MODE_FN_REGISTRY <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L387>`__ 中找到与 ``dispatch_mode``(``DP_COMPUTE_PROTO)相对应的 dispatch_fncollect_fn

DISPATCH_MODE_FN_REGISTRY = {
    Dispatch.ONE_TO_ALL: {
        "dispatch_fn": dispatch_one_to_all,
        "collect_fn": collect_all_to_all,
    },
    ...
    Dispatch.DP_COMPUTE_PROTO: {
        "dispatch_fn": dispatch_dp_compute_data_proto,
        "collect_fn": collect_dp_compute_data_proto,
    },
    ...
}

类似地,execute_fn 是通过 execute_mode 选择的,并由以下代码提取:

<<<continue 2>>>
# get execute_fn_name
execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
wg_execute_fn_name = execute_mode["execute_fn_name"]

# get execute_fn from string
try:
    execute_fn = getattr(self, wg_execute_fn_name)
    assert callable(execute_fn), "execute_fn must be callable"
except Exception:
    print(f"execute_fn {wg_execute_fn_name} is invalid")
    raise
<<<to be continue 3>>>

generate_sequences 为例: - dispatch_mode = Dispatch.DP_COMPUTE_PROTO - dispatch_fn = dispatch_dp_compute_data_proto - collect_fn = collect_dp_compute_data_proto - execute_fn = RayWorkerGroup.execute_all

ONE_TO_ALL 对比 DP_COMPUTE_PROTO

dispatch_modedispatch_fncollect_fn 相关联。顾名思义,dispatch_fn 处理 WorkerGroup 中的输入参数,并生成一个输入参数列表(batch),每个参数将传递给 WorkerGroup 中的一个 worker。

ONE_TO_ALL` ``dispatch_fndispatch_one_to_all,它只是将所有输入参数复制 N 份,其中 N 等于附加到 worker_group 的 Worker 数量:

def dispatch_one_to_all(worker_group, *args, **kwargs):
    args = tuple([arg] * worker_group.world_size for arg in args)
    kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
    return args, kwargs

DP_COMPUTE_PROTOdispatch_fndispatch_dp_compute_data_proto,它使用 DataProto.chunk 将一个大的 DataProto 分割成 N 个小的 DataProto,其中 N 等于 worker_group 的 world_size(Worker 数量):

def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
    from verl.single_controller.base.worker_group import WorkerGroup

    assert isinstance(worker_group, WorkerGroup)
    # Note: enable auto padding for dp compute DatapProto
    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding(
        worker_group.world_size,
        *args,
        **kwargs,
    )
    return splitted_args, splitted_kwargs

collect_fn 遵循相同的模式,处理来自 WorkerGroup 所有 worker 的返回值列表(batch),并将其合并为一个列表(如 collect_all_to_all 所做)或一个大的 DataProto``(如 ``collect_dp_compute_data_proto 所做)。

最后,使用 func_generator 动态生成一个新方法,并将其添加到 WorkerGroup 实例中:

<<<continue 3>>>
# bind a new method to the RayWorkerGroup
func = func_generator(
    self,
    method_name,
    dispatch_fn=dispatch_fn,
    collect_fn=collect_fn,
    execute_fn=execute_fn,
    blocking=blocking,
)

try:
    setattr(self, method_name, func)
    method_names.append(method_name)
except Exception as e:
    raise ValueError(f"Fail to set method_name {method_name}") from e

这使得该方法可以通过 WorkerGroup 接口进行调用。


步骤 3: 调用链

以上所有机制确保了分布式调用与单进程调用具有相同的体验。在原始的单进程脚本中,代码如下所示:

rollout = Rollout()
rollout.generate_sequences(batch)

使用 verl,多进程程序变为:

rollout = RayWorkerGroup(resource_pool=[4], RayClassWithArgs(Rollout))
rollout.generate_sequences(batch)
call_chain_of_generate_sequences

generate_sequences 的调用链

在这个简单的调用背后: - dispatch_fn 将输入分散到各个 worker - execute_fn 执行实际的远程调用 - collect_fn 收集结果

所有这些都被抽象掉了,使得开发者能够以最小的改动来编写分布式代码,而无需改变现有逻辑。


超越 RL 后训练:verl.single_controller 的泛化能力

verl.single_controller 模块的泛化能力远超强化学习。它提供了一个简洁的抽象,用于批量处理远程方法调用,并自动处理输入/输出。

通过最小化单进程脚本和多进程脚本之间的差异,verl.single_controller 为更广泛的领域开启了分布式计算的大门——不仅仅局限于 RL 后训练。

我们希望这个设计能够激发社区提供更多的示例和扩展。