# 策略回滚重要性采样 最后更新:2025年11月10日。 本文档提供了 verl 中策略回滚重要性采样 (IS) 实现的全面概述。 ## 参考资料 - [速度如何扼杀稳定性:从训练-推理不匹配中揭示强化学习崩溃](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - [你的高效强化学习框架在秘密地为你带来离策略训练](https://fengyao.notion.site/off-policy-rl) ## 概述 策略回滚重要性采样用于校正在以下两者之间的分布不匹配: - **回滚策略**:例如,vLLM 配合 BFloat16 - **训练策略**:例如,FSDP 配合 FP32 这种不匹配可能导致梯度估计有偏差和训练不稳定。策略回滚 IS 应用重要性采样权重来纠正这些偏差。 ## 配置 ```yaml # 策略回滚 IS 配置(全部在算法配置中) algorithm: # 主要控制:设置阈值以启用(null = 禁用) rollout_is_threshold: 2.0 # 是否将权重应用于损失(默认为 false = 仅指标) rollout_is: true rollout_is_threshold_lower: null # 自动倒数 rollout_is_level: token rollout_is_mode: truncate rollout_is_veto_threshold: 1e-4 # 必需:启用 log prob 计算 actor_rollout_ref: rollout: calculate_log_probs: true ``` 主要特点: - ✅ 三种聚合级别:token、sequence、geometric - ✅ 两种界限模式:truncate、mask - ✅ 双阈值支持(上限/下限) - ✅ 否决机制应对灾难性异常值 - ✅ 30+ 项全面指标 - ✅ 对数空间计算以提高数值稳定性 - ✅ 内存高效的实现 ## 文件 ### **核心实现** - `verl/trainer/ppo/mismatch_helper.py` - 包含 `compute_rollout_importance_weights()` 和 `compute_is_metrics()` - `verl/trainer/ppo/core_algos.py` - PPO 与策略回滚 IS 的集成 - `verl/workers/actor/dp_actor.py` - 指标收集和日志记录 ### **配置文件** - `verl/trainer/config/algorithm.py` - `AlgoConfig` 中的策略回滚 IS 参数 - `verl/workers/config/actor.py` - `ActorConfig` 中的策略回滚 IS 参数 - `verl/trainer/config/actor/actor.yaml` - 策略回滚 IS 配置部分 - `verl/trainer/config/ppo_trainer.yaml` - 包含策略回滚 IS 的算法配置 ### **文档** - `docs/examples/config.rst` - 配置参数描述 ### **示例脚本** - `recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh` - 带有策略回滚 IS 的 DAPO 示例 - `examples/rollout_importance_sampling/README.md` - 全面的使用指南 - `examples/rollout_importance_sampling/run_with_rollout_is.sh` - 基本示例 ### **测试** - `tests/trainer/ppo/test_rollout_is.py` - 单元测试 - `tests/trainer/ppo/test_rollout_is_integration.py` - 集成测试 ## 配置参数 ### `algorithm.rollout_is_threshold` (float 或 null) **主要开关。** IS 权重的上限阈值。 - `null` = 禁用(不计算,无指标) - `float` 值(例如,2.0) = 启用(计算权重和指标) ### `algorithm.rollout_is` (bool) 是否将 IS 权重应用于策略损失。默认值:`False` - `true` = 将权重应用于损失(完整 IS 校正) - `false` = 仅计算指标(在启用前用于监控很有用) **推荐的阈值范围:** - Token 级别:1.5 - 5.0 - Sequence 级别:2.0 - 10.0 - Geometric 级别:1.0002 - 1.001 ### `algorithm.rollout_is_threshold_lower` (float 或 null) IS 权重的下限阈值。如果为 `null`,则默认为 1/upper(倒数)。 ### `algorithm.rollout_is_level` (str) IS 权重的聚合级别: - `"token"`:每个 token 的比率 - `"sequence"`:比率的乘积 - `"geometric"`:几何平均值(实验性) ### `algorithm.rollout_is_mode` (str) 界限模式: - `"truncate"`:仅将权重限制在上限 - `"mask"`:将权重置零,使其位于 [lower, upper] 之外 ### `algorithm.rollout_is_veto_threshold` (float) 每个 token 的否决阈值。如果任何 token 的比率 < 此值,则整个 sequence 被拒绝。 默认值:`1e-4`(比率误差 10,000 倍) ## 用法 ### 基本设置 ```yaml algorithm: rollout_is_threshold: 2.0 # 主要控制 rollout_is: true # 应用于损失(默认:false) rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true # 必需! ``` ### 指标 所有指标都以 `mismatch/` 为前缀。例如,`rollout_is_mean` 在日志中显示为 `mismatch/rollout_is_mean`。 #### **核心 IS 权重指标** - **`rollout_is_mean`**:所有有效 token 的平均重要性采样权重 - **理想值**:接近 1.0(表示分布不匹配最小) - **警告**:< 0.5 或 > 2.0 表明策略不匹配显著 - **`rollout_is_std`**:IS 权重的标准差 - **理想值**:< 0.5 以实现稳定训练 - **警告**:> 1.0 表明方差较高,可能需要更严格的阈值 - **`rollout_is_min`**:观察到的最小 IS 权重 - 显示了被低估最多的 token/sequence - **`rollout_is_max`**:观察到的最大 IS 权重(在截断/屏蔽之前) - 显示了被高估最多的 token/sequence - 与 `rollout_is_threshold` 比较以了解截断的影响 #### **有效样本量** - **`rollout_is_eff_sample_size`**:IS 加权后的有效样本量 - **公式**:`1 / mean(weights²)`,其中权重已归一化 - **范围**:0.0 到 1.0(占原始 batch 的比例) - **理想值**:> 0.5(保留至少 50% 的有效样本) - **警告**:< 0.3 表示方差很高,丢失了过多的有效样本 #### **否决机制指标** - **`rollout_is_veto_fraction`**:被否决机制拒绝的 sequences 的比例 - **理想值**:< 0.05(否决小于 5%) - **警告**:> 0.1 表明策略差异过大或存在数值问题 - **`rollout_is_catastrophic_token_fraction`**:低于否决阈值的 token 的比例 - 在 sequence 级别否决之前识别有问题的 token - **警告**:> 0.01 表明存在普遍的分布问题 #### **阈值超出指标** - **`rollout_is_ratio_fraction_high`**:超出上限阈值的权重的比例 - 显示了在高位发生截断/屏蔽的频率 - **理想值**:< 0.1(大多数权重在界限内) - **`rollout_is_ratio_fraction_low`**:低于下限阈值的权重的比例 - 显示了在低位发生屏蔽的频率(仅限 mask 模式) - **理想值**:< 0.1 #### **Sequence-Level 指标**(用于 sequence/geometric 模式) - **`rollout_is_seq_mean`**:Sequence 级别的平均 IS 权重 - 对于 sequence 级别的聚合,应与 `rollout_is_mean` 匹配 - **`rollout_is_seq_std`**:Sequence 级别 IS 权重的标准差 - **`rollout_is_seq_min`**:最小的 sequence 级别 IS 权重 - **`rollout_is_seq_max`**:最大的 sequence 级别 IS 权重 - **`rollout_is_seq_max_deviation`**:Sequence 级别上与 1.0 的最大绝对偏差 - **理想值**:< 1.0 - 显示最坏情况的 sequence 不匹配 - **`rollout_is_seq_fraction_high`**:超出上限阈值的 sequences 的比例 - **`rollout_is_seq_fraction_low`**:低于下限阈值的 sequences 的比例 #### **Masking 指标**(仅限 mask 模式) - **`rollout_is_masked_fraction`**:被屏蔽(置零)的 token 的比例 - **理想值**:< 0.1 - **警告**:> 0.3 表示丢失了过多数据 - **`rollout_is_seq_masked_fraction`**:至少有一个被屏蔽 token 的 sequences 的比例 - 显示了屏蔽对 sequence 级别的 ao 响 #### **分布不匹配指标**(训练策略 vs 回滚策略) - **`mismatch_training_ppl`**:训练策略(例如,FSDP FP32)的困惑度 - **公式**:`exp(-mean(log_probs))` - 值越低越好(模型越自信) - **`mismatch_rollout_ppl`**:回滚策略(例如,vLLM BF16)的困惑度 - 如果策略匹配良好,应接近 `mismatch_training_ppl` - **`mismatch_ppl_ratio`**:训练 PPL 与回滚 PPL 的比率 - **公式**:`exp(mean(log(training_ppl / rollout_ppl)))` - **理想值**:接近 1.0 - **含义**:> 1.0 表示训练的置信度低于回滚 - **`mismatch_training_log_ppl`**:训练策略的对数困惑度 - 有助于识别趋势(线性尺度) - **`mismatch_rollout_log_ppl`**:回滚策略的对数困惑度 - **`mismatch_log_ppl_diff`**:对数困惑度均值差 - **公式**:`mean(log_ppl_rollout - log_ppl_training)` - **理想值**:接近 0.0 - 符号指示哪个策略更自信 - **`mismatch_log_ppl_abs_diff`**:对数困惑度绝对差均值 - 不考虑方向的错配幅度 - **`mismatch_log_ppl_diff_max`**:sequences 之间的最大对数困惑度差 - 识别最坏情况的 sequence - **`mismatch_log_ppl_diff_min`**:最小对数困惑度差 - **`mismatch_kl`**:KL 散度 KL(π_rollout || π_training) - **公式**:`mean(log_prob_rollout - log_prob_training)` - **理想值**:接近 0.0(策略匹配) - **警告**:> 0.1 表明错配显著 - **注意**:可能为负(回滚策略的置信度较低) - **`mismatch_k3_kl`**:K3 KL 估计量 - **公式**:`mean(exp(log_ratio) - log_ratio - 1)` - 对于小的 KL 值更稳定 - 始终非负 #### **示例:在代码中访问指标** ```python # 指标由 compute_rollout_importance_weights 返回 from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights weights_proto, metrics = compute_rollout_importance_weights( old_log_prob=training_log_probs, # 来自训练策略 rollout_log_prob=rollout_log_probs, # 来自回滚策略 response_mask=response_mask, rollout_is_level="token", rollout_is_mode="truncate", rollout_is_threshold=2.0, rollout_is_veto_threshold=1e-4, ) # 所有指标都带有 'mismatch/' 前缀 print(f"平均 IS 权重: {metrics['mismatch/rollout_is_mean']:.3f}") print(f"有效样本量: {metrics['mismatch/rollout_is_eff_sample_size']:.3f}") print(f"否决比例: {metrics['mismatch/rollout_is_veto_fraction']:.3f}") print(f"KL 散度: {metrics['mismatch/mismatch_kl']:.3f}") # 检查警告条件 if metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0: print("⚠️ 警告:平均 IS 权重远离 1.0,检测到显著的策略不匹配") if metrics['mismatch/rollout_is_eff_sample_size'] < 0.3: print("⚠️ 警告:有效样本量低,IS 权重方差高") if metrics['mismatch/rollout_is_veto_fraction'] > 0.1: print("⚠️ 警告:高否决比例,策略可能差异过大") ``` #### **示例:在训练期间监控指标** ```python # 在你的训练循环中 for epoch in range(num_epochs): for batch_idx, batch in enumerate(dataloader): # ... 回滚阶段 ... # 计算 IS 权重并获取指标 weights_proto, metrics = compute_rollout_importance_weights( old_log_prob=batch.old_log_prob, rollout_log_prob=batch.rollout_log_prob, response_mask=batch.response_mask, rollout_is_level=config.rollout_is_level, rollout_is_mode=config.rollout_is_mode, rollout_is_threshold=config.rollout_is_threshold, rollout_is_veto_threshold=config.rollout_is_veto_threshold, ) # 记录到 tensorboard/wandb for metric_name, metric_value in metrics.items(): logger.log_scalar(metric_name, metric_value, step=global_step) # 在训练中使用 IS 权重 is_weights = weights_proto.batch["rollout_is_weights"] # ... 将权重应用于策略梯度 ... ``` #### **示例:基于指标的条件警报** ```python def check_rollout_is_health(metrics, config): """检查策略回滚 IS 指标是否表明训练健康。""" warnings = [] # 检查平均 IS 权重 mean_weight = metrics['mismatch/rollout_is_mean'] if mean_weight < 0.5 or mean_weight > 2.0: warnings.append(f"平均 IS 权重 {mean_weight:.3f} 远离 1.0") # 检查有效样本量 ess = metrics['mismatch/rollout_is_eff_sample_size'] if ess < 0.3: warnings.append(f"有效样本量 {ess:.3f} 过低") # 检查否决比例 veto_frac = metrics['mismatch/rollout_is_veto_fraction'] if veto_frac > 0.1: warnings.append(f"否决比例 {veto_frac:.3f} 过高") # 检查方差 std = metrics['mismatch/rollout_is_std'] if std > 1.0: warnings.append(f"IS 权重标准差 {std:.3f} 过高") # 检查 KL 散度 kl = metrics['mismatch/mismatch_kl'] if abs(kl) > 0.1: warnings.append(f"KL 散度 {kl:.3f} 表明错配显著") if warnings: print("⚠️ 策略回滚 IS 健康警告:") for warning in warnings: print(f" - {warning}") return False else: print("✅ 策略回滚 IS 指标看起来健康") return True # 在训练中使用 _, metrics = compute_rollout_importance_weights(...) is_healthy = check_rollout_is_health(metrics, config) if not is_healthy: # 考虑调整配置或调查问题 print("考虑:") print(" - 调整 rollout_is_threshold") print(" - 切换到 geometric 聚合级别") print(" - 检查回滚策略和训练策略是否过于不同") ``` ### 运行示例 从基本的 token 级别截断配置开始: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` 在调整参数之前,先监控 1-2 个 epoch 的指标。 ## 配置示例 ### 示例 1:完整 IS 校正 ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true # 将权重应用于损失 rollout_is_level: token rollout_is_mode: truncate ``` ### 示例 2:仅指标(监控模式) ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # 计算指标,不应用权重 rollout_is_level: token rollout_is_mode: truncate ``` ### 示例 3:带 Mask 的 Geometric Mean ```yaml algorithm: rollout_is_threshold: 1.0002 rollout_is: true rollout_is_threshold_lower: 0.9998 rollout_is_level: geometric rollout_is_mode: mask ``` ### 示例 4:非对称阈值 ```yaml algorithm: rollout_is_threshold: 5.0 rollout_is: true rollout_is_threshold_lower: 0.8 rollout_is_level: token rollout_is_mode: mask ``` ## 故障排除 ### 问题:IS 权重方差过高 **症状**:`rollout_is_std` > 1.0,`rollout_is_eff_sample_size` < 0.3 **解决方案**: 1. 从 `sequence` 级别切换到 `geometric` 级别 2. 调整阈值 3. 验证回滚策略和训练策略差异不大 ### 问题:否决的 sequences 过多 **症状**:`rollout_is_veto_fraction` > 0.1 **解决方案**: 1. 放宽否决阈值:`rollout_is_veto_threshold: 1e-3` 2. 检查对数概率计算是否存在数值问题 3. 验证策略并非完全不同 ### 问题:平均 IS 权重远离 1.0 **症状**:`rollout_is_mean` < 0.5 或 > 2.0 **解决方案**: 1. 验证是否设置了 `calculate_log_probs=True` 2. 检查 `rollout_log_probs` 是否正确传递 3. 检查是否存在系统偏差 ### 调试:可视化指标 **示例:绘制 IS 权重分布** ```python import matplotlib.pyplot as plt import numpy as np def plot_is_metrics(metrics_history): """随训练步数绘制策略回滚 IS 指标。""" fig, axes = plt.subplots(2, 3, figsize=(15, 10)) # 图 1:平均 IS 权重随时间的变化 axes[0, 0].plot(metrics_history['mismatch/rollout_is_mean']) axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='理想') axes[0, 0].set_title('平均 IS 权重') axes[0, 0].set_xlabel('步数') axes[0, 0].legend() # 图 2:有效样本量 axes[0, 1].plot(metrics_history['mismatch/rollout_is_eff_sample_size']) axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='良好') axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='警告') axes[0, 1].set_title('有效样本量') axes[0, 1].set_xlabel('步数') axes[0, 1].legend() # 图 3:否决比例 axes[0, 2].plot(metrics_history['mismatch/rollout_is_veto_fraction']) axes[0, 2].axhline(y=0.1, color='r', linestyle='--', label='警告') axes[0, 2].set_title('否决比例') axes[0, 2].set_xlabel('步数') axes[0, 2].legend() # 图 4:KL 散度随时间的变化 axes[1, 0].plot(metrics_history['mismatch/mismatch_kl'], label='KL') axes[1, 0].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL') axes[1, 0].axhline(y=0, color='g', linestyle='--', alpha=0.3) axes[1, 0].set_title('KL 散度') axes[1, 0].set_xlabel('步数') axes[1, 0].legend() # 图 5:PPL 比例随时间的变化 axes[1, 1].plot(metrics_history['mismatch/mismatch_ppl_ratio']) axes[1, 1].axhline(y=1.0, color='r', linestyle='--', label='理想') axes[1, 1].set_title('PPL 比例(训练/回滚)') axes[1, 1].set_xlabel('步数') axes[1, 1].legend() # 隐藏未使用的子图 axes[1, 2].axis('off') plt.tight_layout() plt.savefig('rollout_is_metrics.png', dpi=150) print("已将图保存到 rollout_is_metrics.png") ``` **示例:训练期间的指标收集** ```python # 随时间收集指标 metrics_history = { 'mismatch/rollout_is_mean': [], 'mismatch/rollout_is_eff_sample_size': [], 'mismatch/rollout_is_veto_fraction': [], 'mismatch/mismatch_kl': [], 'mismatch/mismatch_k3_kl': [], 'mismatch/mismatch_ppl_ratio': [], } # 在训练循环中 for step in range(num_steps): # ... 计算 IS 权重 ... _, metrics = compute_rollout_importance_weights(...) # 存储指标 for key in metrics_history.keys(): if key in metrics: metrics_history[key].append(metrics[key]) # 每 100 步绘制一次图 if step % 100 == 0: plot_is_metrics(metrics_history) ``` ## 性能影响 - **内存开销**:约占模型内存的 1% - **计算开销**:1-3%,取决于级别 - **训练稳定性**:在存在不匹配时显著提高 ## 测试 运行测试套件以验证一切正常: ```bash # 基本单元测试 python test_rollout_is.py # 集成测试(如果可用 pytest) pytest tests/trainer/ppo/test_rollout_is_integration.py -v ``` 预期输出:所有测试通过 ✓ ## 附加资源 - **实现**:`verl/trainer/ppo/mismatch_helper.py` - **示例**:`examples/rollout_importance_sampling/` - **DAPO 示例**:`recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh` ## 总结 策略回滚重要性采样提供了: - ✅ 对分布不匹配的鲁棒处理 - ✅ 数值稳定性 - ✅ 用于监控的全面指标 - ✅ 适应不同场景的灵活性 - ✅ 内存高效的计算