策略回滚重要性采样

最后更新:2025年11月10日。

本文档提供了 verl 中策略回滚重要性采样 (IS) 实现的全面概述。

参考资料

概述

策略回滚重要性采样用于校正在以下两者之间的分布不匹配:

  • 回滚策略:例如,vLLM 配合 BFloat16

  • 训练策略:例如,FSDP 配合 FP32

这种不匹配可能导致梯度估计有偏差和训练不稳定。策略回滚 IS 应用重要性采样权重来纠正这些偏差。

配置

# 策略回滚 IS 配置(全部在算法配置中)
algorithm:
  # 主要控制:设置阈值以启用(null = 禁用)
  rollout_is_threshold: 2.0
  # 是否将权重应用于损失(默认为 false = 仅指标)
  rollout_is: true
  rollout_is_threshold_lower: null  # 自动倒数
  rollout_is_level: token
  rollout_is_mode: truncate
  rollout_is_veto_threshold: 1e-4

# 必需:启用 log prob 计算
actor_rollout_ref:
  rollout:
    calculate_log_probs: true

主要特点:

  • ✅ 三种聚合级别:token、sequence、geometric

  • ✅ 两种界限模式:truncate、mask

  • ✅ 双阈值支持(上限/下限)

  • ✅ 否决机制应对灾难性异常值

  • ✅ 30+ 项全面指标

  • ✅ 对数空间计算以提高数值稳定性

  • ✅ 内存高效的实现

文件

核心实现

  • verl/trainer/ppo/mismatch_helper.py - 包含 compute_rollout_importance_weights()compute_is_metrics()

  • verl/trainer/ppo/core_algos.py - PPO 与策略回滚 IS 的集成

  • verl/workers/actor/dp_actor.py - 指标收集和日志记录

配置文件

  • verl/trainer/config/algorithm.py - AlgoConfig 中的策略回滚 IS 参数

  • verl/workers/config/actor.py - ActorConfig 中的策略回滚 IS 参数

  • verl/trainer/config/actor/actor.yaml - 策略回滚 IS 配置部分

  • verl/trainer/config/ppo_trainer.yaml - 包含策略回滚 IS 的算法配置

文档

  • docs/examples/config.rst - 配置参数描述

示例脚本

  • recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh - 带有策略回滚 IS 的 DAPO 示例

  • examples/rollout_importance_sampling/README.md - 全面的使用指南

  • examples/rollout_importance_sampling/run_with_rollout_is.sh - 基本示例

测试

  • tests/trainer/ppo/test_rollout_is.py - 单元测试

  • tests/trainer/ppo/test_rollout_is_integration.py - 集成测试

配置参数

algorithm.rollout_is_threshold (float 或 null)

主要开关。 IS 权重的上限阈值。

  • null = 禁用(不计算,无指标)

  • float 值(例如,2.0) = 启用(计算权重和指标)

algorithm.rollout_is (bool)

是否将 IS 权重应用于策略损失。默认值:False

  • true = 将权重应用于损失(完整 IS 校正)

  • false = 仅计算指标(在启用前用于监控很有用)

推荐的阈值范围:

  • Token 级别:1.5 - 5.0

  • Sequence 级别:2.0 - 10.0

  • Geometric 级别:1.0002 - 1.001

algorithm.rollout_is_threshold_lower (float 或 null)

IS 权重的下限阈值。如果为 null,则默认为 1/upper(倒数)。

algorithm.rollout_is_level (str)

IS 权重的聚合级别:

  • "token":每个 token 的比率

  • "sequence":比率的乘积

  • "geometric":几何平均值(实验性)

algorithm.rollout_is_mode (str)

界限模式:

  • "truncate":仅将权重限制在上限

  • "mask":将权重置零,使其位于 [lower, upper] 之外

algorithm.rollout_is_veto_threshold (float)

每个 token 的否决阈值。如果任何 token 的比率 < 此值,则整个 sequence 被拒绝。 默认值:1e-4(比率误差 10,000 倍)

用法

基本设置

algorithm:
  rollout_is_threshold: 2.0  # 主要控制
  rollout_is: true           # 应用于损失(默认:false)
  rollout_is_level: token
  rollout_is_mode: truncate

actor_rollout_ref:
  rollout:
    calculate_log_probs: true  # 必需!

指标

所有指标都以 mismatch/ 为前缀。例如,rollout_is_mean 在日志中显示为 mismatch/rollout_is_mean

核心 IS 权重指标

  • rollout_is_mean:所有有效 token 的平均重要性采样权重

    • 理想值:接近 1.0(表示分布不匹配最小)

    • 警告:< 0.5 或 > 2.0 表明策略不匹配显著

  • rollout_is_std:IS 权重的标准差

    • 理想值:< 0.5 以实现稳定训练

    • 警告:> 1.0 表明方差较高,可能需要更严格的阈值

  • rollout_is_min:观察到的最小 IS 权重

    • 显示了被低估最多的 token/sequence

  • rollout_is_max:观察到的最大 IS 权重(在截断/屏蔽之前)

    • 显示了被高估最多的 token/sequence

    • rollout_is_threshold 比较以了解截断的影响

有效样本量

  • rollout_is_eff_sample_size:IS 加权后的有效样本量

    • 公式1 / mean(weights²),其中权重已归一化

    • 范围:0.0 到 1.0(占原始 batch 的比例)

    • 理想值:> 0.5(保留至少 50% 的有效样本)

    • 警告:< 0.3 表示方差很高,丢失了过多的有效样本

否决机制指标

  • rollout_is_veto_fraction:被否决机制拒绝的 sequences 的比例

    • 理想值:< 0.05(否决小于 5%)

    • 警告:> 0.1 表明策略差异过大或存在数值问题

  • rollout_is_catastrophic_token_fraction:低于否决阈值的 token 的比例

    • 在 sequence 级别否决之前识别有问题的 token

    • 警告:> 0.01 表明存在普遍的分布问题

阈值超出指标

  • rollout_is_ratio_fraction_high:超出上限阈值的权重的比例

    • 显示了在高位发生截断/屏蔽的频率

    • 理想值:< 0.1(大多数权重在界限内)

  • rollout_is_ratio_fraction_low:低于下限阈值的权重的比例

    • 显示了在低位发生屏蔽的频率(仅限 mask 模式)

    • 理想值:< 0.1

Sequence-Level 指标(用于 sequence/geometric 模式)

  • rollout_is_seq_mean:Sequence 级别的平均 IS 权重

    • 对于 sequence 级别的聚合,应与 rollout_is_mean 匹配

  • rollout_is_seq_std:Sequence 级别 IS 权重的标准差

  • rollout_is_seq_min:最小的 sequence 级别 IS 权重

  • rollout_is_seq_max:最大的 sequence 级别 IS 权重

  • rollout_is_seq_max_deviation:Sequence 级别上与 1.0 的最大绝对偏差

    • 理想值:< 1.0

    • 显示最坏情况的 sequence 不匹配

  • rollout_is_seq_fraction_high:超出上限阈值的 sequences 的比例

  • rollout_is_seq_fraction_low:低于下限阈值的 sequences 的比例

Masking 指标(仅限 mask 模式)

  • rollout_is_masked_fraction:被屏蔽(置零)的 token 的比例

    • 理想值:< 0.1

    • 警告:> 0.3 表示丢失了过多数据

  • rollout_is_seq_masked_fraction:至少有一个被屏蔽 token 的 sequences 的比例

    • 显示了屏蔽对 sequence 级别的 ao 响

分布不匹配指标(训练策略 vs 回滚策略)

  • mismatch_training_ppl:训练策略(例如,FSDP FP32)的困惑度

    • 公式exp(-mean(log_probs))

    • 值越低越好(模型越自信)

  • mismatch_rollout_ppl:回滚策略(例如,vLLM BF16)的困惑度

    • 如果策略匹配良好,应接近 mismatch_training_ppl

  • mismatch_ppl_ratio:训练 PPL 与回滚 PPL 的比率

    • 公式exp(mean(log(training_ppl / rollout_ppl)))

    • 理想值:接近 1.0

    • 含义:> 1.0 表示训练的置信度低于回滚

  • mismatch_training_log_ppl:训练策略的对数困惑度

    • 有助于识别趋势(线性尺度)

  • mismatch_rollout_log_ppl:回滚策略的对数困惑度

  • mismatch_log_ppl_diff:对数困惑度均值差

    • 公式mean(log_ppl_rollout - log_ppl_training)

    • 理想值:接近 0.0

    • 符号指示哪个策略更自信

  • mismatch_log_ppl_abs_diff:对数困惑度绝对差均值

    • 不考虑方向的错配幅度

  • mismatch_log_ppl_diff_max:sequences 之间的最大对数困惑度差

    • 识别最坏情况的 sequence

  • mismatch_log_ppl_diff_min:最小对数困惑度差

  • mismatch_kl:KL 散度 KL(π_rollout || π_training)

    • 公式mean(log_prob_rollout - log_prob_training)

    • 理想值:接近 0.0(策略匹配)

    • 警告:> 0.1 表明错配显著

    • 注意:可能为负(回滚策略的置信度较低)

  • mismatch_k3_kl:K3 KL 估计量

    • 公式mean(exp(log_ratio) - log_ratio - 1)

    • 对于小的 KL 值更稳定

    • 始终非负

示例:在代码中访问指标

# 指标由 compute_rollout_importance_weights 返回
from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights

weights_proto, metrics = compute_rollout_importance_weights(
    old_log_prob=training_log_probs,      # 来自训练策略
    rollout_log_prob=rollout_log_probs,   # 来自回滚策略
    response_mask=response_mask,
    rollout_is_level="token",
    rollout_is_mode="truncate",
    rollout_is_threshold=2.0,
    rollout_is_veto_threshold=1e-4,
)

# 所有指标都带有 'mismatch/' 前缀
print(f"平均 IS 权重: {metrics['mismatch/rollout_is_mean']:.3f}")
print(f"有效样本量: {metrics['mismatch/rollout_is_eff_sample_size']:.3f}")
print(f"否决比例: {metrics['mismatch/rollout_is_veto_fraction']:.3f}")
print(f"KL 散度: {metrics['mismatch/mismatch_kl']:.3f}")

# 检查警告条件
if metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0:
    print("⚠️  警告:平均 IS 权重远离 1.0,检测到显著的策略不匹配")

if metrics['mismatch/rollout_is_eff_sample_size'] < 0.3:
    print("⚠️  警告:有效样本量低,IS 权重方差高")

if metrics['mismatch/rollout_is_veto_fraction'] > 0.1:
    print("⚠️  警告:高否决比例,策略可能差异过大")

示例:在训练期间监控指标

# 在你的训练循环中
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(dataloader):
        # ... 回滚阶段 ...

        # 计算 IS 权重并获取指标
        weights_proto, metrics = compute_rollout_importance_weights(
            old_log_prob=batch.old_log_prob,
            rollout_log_prob=batch.rollout_log_prob,
            response_mask=batch.response_mask,
            rollout_is_level=config.rollout_is_level,
            rollout_is_mode=config.rollout_is_mode,
            rollout_is_threshold=config.rollout_is_threshold,
            rollout_is_veto_threshold=config.rollout_is_veto_threshold,
        )

        # 记录到 tensorboard/wandb
        for metric_name, metric_value in metrics.items():
            logger.log_scalar(metric_name, metric_value, step=global_step)

        # 在训练中使用 IS 权重
        is_weights = weights_proto.batch["rollout_is_weights"]
        # ... 将权重应用于策略梯度 ...

示例:基于指标的条件警报

def check_rollout_is_health(metrics, config):
    """检查策略回滚 IS 指标是否表明训练健康。"""
    warnings = []

    # 检查平均 IS 权重
    mean_weight = metrics['mismatch/rollout_is_mean']
    if mean_weight < 0.5 or mean_weight > 2.0:
        warnings.append(f"平均 IS 权重 {mean_weight:.3f} 远离 1.0")

    # 检查有效样本量
    ess = metrics['mismatch/rollout_is_eff_sample_size']
    if ess < 0.3:
        warnings.append(f"有效样本量 {ess:.3f} 过低")

    # 检查否决比例
    veto_frac = metrics['mismatch/rollout_is_veto_fraction']
    if veto_frac > 0.1:
        warnings.append(f"否决比例 {veto_frac:.3f} 过高")

    # 检查方差
    std = metrics['mismatch/rollout_is_std']
    if std > 1.0:
        warnings.append(f"IS 权重标准差 {std:.3f} 过高")

    # 检查 KL 散度
    kl = metrics['mismatch/mismatch_kl']
    if abs(kl) > 0.1:
        warnings.append(f"KL 散度 {kl:.3f} 表明错配显著")

    if warnings:
        print("⚠️  策略回滚 IS 健康警告:")
        for warning in warnings:
            print(f"  - {warning}")
        return False
    else:
        print("✅ 策略回滚 IS 指标看起来健康")
        return True

# 在训练中使用
_, metrics = compute_rollout_importance_weights(...)
is_healthy = check_rollout_is_health(metrics, config)

if not is_healthy:
    # 考虑调整配置或调查问题
    print("考虑:")
    print("  - 调整 rollout_is_threshold")
    print("  - 切换到 geometric 聚合级别")
    print("  - 检查回滚策略和训练策略是否过于不同")

运行示例

从基本的 token 级别截断配置开始:

bash examples/rollout_importance_sampling/run_with_rollout_is.sh

在调整参数之前,先监控 1-2 个 epoch 的指标。

配置示例

示例 1:完整 IS 校正

algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true  # 将权重应用于损失
  rollout_is_level: token
  rollout_is_mode: truncate

示例 2:仅指标(监控模式)

algorithm:
  rollout_is_threshold: 2.0
  rollout_is: false  # 计算指标,不应用权重
  rollout_is_level: token
  rollout_is_mode: truncate

示例 3:带 Mask 的 Geometric Mean

algorithm:
  rollout_is_threshold: 1.0002
  rollout_is: true
  rollout_is_threshold_lower: 0.9998
  rollout_is_level: geometric
  rollout_is_mode: mask

示例 4:非对称阈值

algorithm:
  rollout_is_threshold: 5.0
  rollout_is: true
  rollout_is_threshold_lower: 0.8
  rollout_is_level: token
  rollout_is_mode: mask

故障排除

问题:IS 权重方差过高

症状rollout_is_std > 1.0,rollout_is_eff_sample_size < 0.3

解决方案

  1. sequence 级别切换到 geometric 级别

  2. 调整阈值

  3. 验证回滚策略和训练策略差异不大

问题:否决的 sequences 过多

症状rollout_is_veto_fraction > 0.1

解决方案

  1. 放宽否决阈值:rollout_is_veto_threshold: 1e-3

  2. 检查对数概率计算是否存在数值问题

  3. 验证策略并非完全不同

问题:平均 IS 权重远离 1.0

症状rollout_is_mean < 0.5 或 > 2.0

解决方案

  1. 验证是否设置了 calculate_log_probs=True

  2. 检查 rollout_log_probs 是否正确传递

  3. 检查是否存在系统偏差

调试:可视化指标

示例:绘制 IS 权重分布

import matplotlib.pyplot as plt
import numpy as np

def plot_is_metrics(metrics_history):
    """随训练步数绘制策略回滚 IS 指标。"""
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # 图 1:平均 IS 权重随时间的变化
    axes[0, 0].plot(metrics_history['mismatch/rollout_is_mean'])
    axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='理想')
    axes[0, 0].set_title('平均 IS 权重')
    axes[0, 0].set_xlabel('步数')
    axes[0, 0].legend()

    # 图 2:有效样本量
    axes[0, 1].plot(metrics_history['mismatch/rollout_is_eff_sample_size'])
    axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='良好')
    axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='警告')
    axes[0, 1].set_title('有效样本量')
    axes[0, 1].set_xlabel('步数')
    axes[0, 1].legend()

    # 图 3:否决比例
    axes[0, 2].plot(metrics_history['mismatch/rollout_is_veto_fraction'])
    axes[0, 2].axhline(y=0.1, color='r', linestyle='--', label='警告')
    axes[0, 2].set_title('否决比例')
    axes[0, 2].set_xlabel('步数')
    axes[0, 2].legend()

    # 图 4:KL 散度随时间的变化
    axes[1, 0].plot(metrics_history['mismatch/mismatch_kl'], label='KL')
    axes[1, 0].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL')
    axes[1, 0].axhline(y=0, color='g', linestyle='--', alpha=0.3)
    axes[1, 0].set_title('KL 散度')
    axes[1, 0].set_xlabel('步数')
    axes[1, 0].legend()

    # 图 5:PPL 比例随时间的变化
    axes[1, 1].plot(metrics_history['mismatch/mismatch_ppl_ratio'])
    axes[1, 1].axhline(y=1.0, color='r', linestyle='--', label='理想')
    axes[1, 1].set_title('PPL 比例(训练/回滚)')
    axes[1, 1].set_xlabel('步数')
    axes[1, 1].legend()

    # 隐藏未使用的子图
    axes[1, 2].axis('off')

    plt.tight_layout()
    plt.savefig('rollout_is_metrics.png', dpi=150)
    print("已将图保存到 rollout_is_metrics.png")

示例:训练期间的指标收集

# 随时间收集指标
metrics_history = {
    'mismatch/rollout_is_mean': [],
    'mismatch/rollout_is_eff_sample_size': [],
    'mismatch/rollout_is_veto_fraction': [],
    'mismatch/mismatch_kl': [],
    'mismatch/mismatch_k3_kl': [],
    'mismatch/mismatch_ppl_ratio': [],
}

# 在训练循环中
for step in range(num_steps):
    # ... 计算 IS 权重 ...
    _, metrics = compute_rollout_importance_weights(...)

    # 存储指标
    for key in metrics_history.keys():
        if key in metrics:
            metrics_history[key].append(metrics[key])

    # 每 100 步绘制一次图
    if step % 100 == 0:
        plot_is_metrics(metrics_history)

性能影响

  • 内存开销:约占模型内存的 1%

  • 计算开销:1-3%,取决于级别

  • 训练稳定性:在存在不匹配时显著提高

测试

运行测试套件以验证一切正常:

# 基本单元测试
python test_rollout_is.py

# 集成测试(如果可用 pytest)
pytest tests/trainer/ppo/test_rollout_is_integration.py -v

预期输出:所有测试通过 ✓

附加资源

  • 实现verl/trainer/ppo/mismatch_helper.py

  • 示例examples/rollout_importance_sampling/

  • DAPO 示例recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh

总结

策略回滚重要性采样提供了:

  • ✅ 对分布不匹配的鲁棒处理

  • ✅ 数值稳定性

  • ✅ 用于监控的全面指标

  • ✅ 适应不同场景的灵活性

  • ✅ 内存高效的计算