Agentic RL 训练

最后更新时间:2025年7月15日。

概览

Agentic RL 的目标是通过强化学习(RL)来提升后端模型的性能,并将其应用于 Agent。在训练过程中,开发了一系列功能:

  1. 基于服务器的异步 rollouts

  2. 多轮对话和工具调用

  3. 基于 LangGraph 的 Agent

本文档将解释系统原理和使用方法,以帮助用户实现 Agentic RL。

基于服务器的异步 Rollout

由于 Agent 需要通过各种工具调用与环境进行交互,为了避免在等待工具调用返回结果时 GPU 空闲,我们采用了基于 asyncio 的协程机制来异步执行每个 rollout 请求,从而提高训练性能。为了支持异步 rollout,我们将推理引擎(服务器)和 Agent(客户端)在架构上分离,实现了一个基于服务器的系统,其目标如下:

  1. 启用负载均衡机制,在多个 GPU 之间分配负载,并减少长尾请求对性能的影响。为此,我们将流模式(recipestream_mode)下的调度能力实现为一种 recipe。

  2. 防止 Agent 特有的功能(如 tracing)影响推理引擎。

系统架构

https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop.png?raw=true

有关内部设计的更多详情,请参阅 Agent Loop

系统组件

“generate” 接口

客户端和服务器之间使用的是基于 ray actor 的 “generate” 函数,而不是标准的 chat completion API。这是因为 token 和文本之间的转换可能是不可逆的。例如,从 “<think>” 转换而来的 token 将与 LLM 生成的 token 不同。在训练阶段,必须严格使用 LLM 推理生成的 token,以避免在计算 advantage 时出现不准确,这可能会影响模型性能。让服务器提供一个基于 token 的 API 有助于客户端维护由工具调用生成的文本与 LLM 返回的 token 之间的关系,从而输出正确的 token 用于训练。

推理引擎适配 AsyncServer 统一向上层提供 generate 函数,并为 SGLang 和 vLLM 分别提供不同的实现,以隐藏底层差异:

  1. SGLang AsyncServer 使用 SGLang 引擎的 async_generate 接口,该接口位于每个 TP 组的第一个 GPU 上。因此,AsyncServer 需要通过 ray actor 远程调用 async_generate

  2. vLLM AsyncServer 使用 vLLM 引擎的 generate 接口,该接口可以通过 ZMQ 与 TP 组中的 GPU 进行通信,并可以直接在 AsyncServer 中调用。

使用示例

请参考 GSM8K 示例 来准备数据集和模型检查点。

使用 agent loop 需要设置两个选项:

  • data.return_raw_chat=True

  • actor_rollout_ref.rollout.mode=async

此示例默认使用 sglang 推理引擎,您也可以修改 rollout_name 来使用 vllm。

bash examples/grpo_trainer/run_qwen2-7b_seq_balance.sh

多轮对话和工具调用

请参考 多轮 Rollout 支持 来准备工具和配置文件。

Tool Agent Loop 有一个额外的要求:向数据集中添加 “agent_name” 字段。在 rollout 过程中,它会根据此字段选择使用 tool_agent_loop 还是 `single_turn_agent`(默认)。

使用示例

# 安装 mlflow 以查看 toolcall 和 llm trace
pip install mlflow

# 这将下载并预处理 GSM8K 数据集到 ~/data/gsm8k/ 并添加 "agent_name" 字段。
python examples/data_preprocess/gsm8k_tool_agent_loop.py

# 启动训练,启用工具调用和 mlflow trace 以帮助调试 rollout 细节
bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh

# 训练完成后,启动一个 mlflow 服务器以查看 trace
mlflow ui -h 0.0.0.0 -p 5000 --backend-store-uri sqlite:////tmp/mlruns.db

# 然后您可以通过浏览器访问 http://<您的 IP 地址>:5000 来查看 trace

注意:在训练过程中,由于模型有时可能无法生成正确的 toolcall 标签,控制台会输出错误消息 “Failed to decode tool call”,这并不表示训练异常。

有关 trace 功能的更多信息,请参阅 Rollout trace

Agent 框架

系统架构

https://github.com/eric-haibin-lin/verl-community/blob/main/docs/langgraph_agent.png?raw=true

系统组件

有关更多详情,请参考 doc “recipe/langgraph_agent/example/README.md”。