# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Core functions to implement PPO algorithms.
The functions implemented in this file should be used by trainers with different distributed strategies to
implement PPO-like algorithms.
"""
__all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"]
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, Optional
import numpy as np
import torch
from omegaconf import DictConfig
import verl.utils.torch_functional as verl_F
from verl.trainer.config import AlgoConfig
from verl.utils import as_torch_index, group_mean_std
from verl.utils.import_utils import deprecated
from verl.workers.config import ActorConfig
# Call signature shared by the policy-loss implementations stored in POLICY_LOSS_REGISTRY.
# The list gives the positional arguments in order; the result is a 4-tuple of tensors
# (the losses visible in this file return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower).
PolicyLossFn = Callable[
    [
        torch.Tensor,  # old_log_prob
        torch.Tensor,  # log_prob
        torch.Tensor,  # advantages
        torch.Tensor,  # response_mask
        str,  # loss_agg_mode
        Optional[DictConfig | AlgoConfig],  # config
        torch.Tensor | None,  # rollout_is_weights (the name used by the registered losses in this file)
    ],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]
# Name -> policy loss implementation; filled in by the @register_policy_loss decorator.
POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}
def register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]:
    """Create a decorator that stores a policy loss function under ``name``.

    Args:
        name (str): Key under which the decorated function is placed in
            ``POLICY_LOSS_REGISTRY``.

    Returns:
        function: A decorator that records the function it wraps and hands it
        back unchanged.
    """

    def _store(loss_fn: PolicyLossFn) -> PolicyLossFn:
        # A plain assignment: registering the same name again replaces the previous entry.
        POLICY_LOSS_REGISTRY[name] = loss_fn
        return loss_fn

    return _store
def get_policy_loss_fn(name):
    """Return the policy loss function registered under ``name``.

    Args:
        name: `(str)`
            Key previously passed to ``register_policy_loss``.

    Returns:
        `(callable)`: The registered policy loss function.

    Raises:
        ValueError: If nothing is registered under ``name``; the message lists
            the names that are available.
    """
    loss_name = name
    # Happy path first: hand back the registered implementation.
    if loss_name in POLICY_LOSS_REGISTRY:
        return POLICY_LOSS_REGISTRY[loss_name]
    raise ValueError(
        f"Unsupported loss mode: {loss_name}. Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}"
    )
class AdvantageEstimator(str, Enum):
    """Names of the built-in advantage estimators.

    Members are also ``str`` instances, so a member compares equal to its
    plain string value. Using the enum instead of raw strings guards against
    typos in ``adv_estimator``.

    Note(haibin.lin): an enum cannot be extended after creation. Adding members
    for new estimators is usually unnecessary: a custom estimator can be
    registered under a plain string name with ``register_adv_est``.
    """

    # Declaration order is the iteration order of the enum; keep it stable.
    GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
    REMAX = "remax"
    RLOO = "rloo"
    OPO = "opo"
    GRPO_PASSK = "grpo_passk"
    GPG = "gpg"
    RLOO_VECTORIZED = "rloo_vectorized"
    GRPO_VECTORIZED = "grpo_vectorized"
# Estimator name (plain string) -> advantage estimator function; filled in by @register_adv_est.
ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
def register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any:
    """Create a decorator that registers an advantage estimator function.

    Args:
        name_or_enum: `(str)` or `(AdvantageEstimator)`
            Name of the estimator. Enum members are stored under their ``.value``.

    Returns:
        A decorator that records the function in ``ADV_ESTIMATOR_REGISTRY`` and
        returns it unchanged.

    Raises:
        ValueError: (from the decorator) if a *different* function is already
            registered under the same name. Re-registering the same function is
            allowed.
    """

    def _register(fn):
        if isinstance(name_or_enum, Enum):
            name = name_or_enum.value
        else:
            name = name_or_enum

        # Only a conflicting registration is an error; an identical one is a no-op overwrite.
        conflict = name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn
        if conflict:
            raise ValueError(
                f"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}"
            )

        ADV_ESTIMATOR_REGISTRY[name] = fn
        return fn

    return _register
def get_adv_estimator_fn(name_or_enum):
    """Get the advantage estimator function registered under a given name.

    Args:
        name_or_enum: `(str)` or `(AdvantageEstimator)`
            The name or enum of the advantage estimator. Enum members are
            looked up by their ``.value``.

    Returns:
        `(callable)`: The advantage estimator function.

    Raises:
        ValueError: If no estimator is registered under the name; the message
            lists the registered names, mirroring ``get_policy_loss_fn``.
    """
    name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum
    if name not in ADV_ESTIMATOR_REGISTRY:
        raise ValueError(
            f"Unknown advantage estimator: {name}. "
            f"Supported estimators are: {list(ADV_ESTIMATOR_REGISTRY.keys())}"
        )
    return ADV_ESTIMATOR_REGISTRY[name]
class AdaptiveKLController:
    """KL coefficient that adapts towards a target KL value.

    Follows the adaptive KL controller described in
    https://arxiv.org/pdf/1909.08593.pdf. The current coefficient is exposed as
    ``self.value``.
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        self.value, self.target, self.horizon = init_kl_coef, target_kl, horizon

    def update(self, current_kl, n_steps):
        """Rescale the KL coefficient in place from the observed KL divergence.

        Args:
            current_kl (float): Current KL divergence value.
            n_steps (int): Number of steps taken.
        """
        relative_error = current_kl / self.target - 1
        # Limit the per-update correction to +/-20% of the relative error.
        clipped_error = np.clip(relative_error, -0.2, 0.2)
        self.value *= 1 + clipped_error * n_steps / self.horizon
class FixedKLController:
    """KL controller whose coefficient (``self.value``) never changes."""

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        """Accept the same arguments as the adaptive controller and do nothing.

        Args:
            current_kl (float): Current KL divergence value (unused).
            n_steps (int): Number of steps taken (unused).
        """
        return None
def get_kl_controller(kl_ctrl):
    """Factory function to create the KL controller selected by the configuration.

    Args:
        kl_ctrl: Configuration object. ``kl_ctrl.type`` selects the controller;
            ``kl_ctrl.kl_coef`` is read for both types, and ``kl_ctrl.target_kl``
            and ``kl_ctrl.horizon`` are read for the adaptive type.

    Returns:
        KL controller instance (FixedKLController or AdaptiveKLController).

    Raises:
        NotImplementedError: If ``kl_ctrl.type`` is neither "fixed" nor "adaptive".
        AssertionError: If the adaptive controller horizon is not positive.
    """
    if kl_ctrl.type == "fixed":
        return FixedKLController(kl_coef=kl_ctrl.kl_coef)
    elif kl_ctrl.type == "adaptive":
        assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
    else:
        # Name the offending value so a misconfiguration is obvious from the traceback.
        raise NotImplementedError(
            f"Unsupported KL controller type: {kl_ctrl.type!r}. Expected 'fixed' or 'adaptive'."
        )
@register_adv_est(AdvantageEstimator.GAE)  # or simply: @register_adv_est("gae")
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: float,
    lam: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute GAE advantages and returns with a backward pass over the response tokens.

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        values: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length). [EOS] mask. The tokens after [EOS] have mask zero.
        gamma: `(float)`
            discount factor used in RL
        lam: `(float)`
            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length). Passed through ``verl_F.masked_whiten`` before being returned.
        returns: `(torch.Tensor)`
            shape: (bs, response_length). Computed as un-whitened advantages + values.
    """
    with torch.no_grad():
        # Running state of the backward recursion; both start at 0 past the last token.
        nextvalues = 0
        lastgaelam = 0
        advantages_reversed = []
        gen_len = token_level_rewards.shape[-1]
        for t in reversed(range(gen_len)):
            # One-step TD error at position t.
            delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
            lastgaelam_ = delta + gamma * lam * lastgaelam
            # skip values and TD-error on observation tokens:
            # where the mask is 0 the previous nextvalues / lastgaelam are carried through unchanged.
            nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues
            lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam
            advantages_reversed.append(lastgaelam)
        # The list was built from the last token to the first; flip it back to time order.
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        # Returns use the raw advantages, so they must be computed before whitening.
        returns = advantages + values
        advantages = verl_F.masked_whiten(advantages, response_mask)
    return advantages, returns
# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
@register_adv_est(AdvantageEstimator.GRPO)  # or simply: @register_adv_est("grpo")
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    GRPO advantage from outcome rewards (one scalar score per response).

    Each response's score is the sum of its token-level rewards. Scores are
    centred on the mean of the responses sharing the same ``index`` entry and,
    optionally, divided by that group's standard deviation. Groups with a
    single response use mean 0 and std 1.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length)
        index: `(np.ndarray)`
            group id of each sample; samples with equal ids form one group
        epsilon: `(float)`
            added to the std before dividing, to avoid division by zero
        norm_adv_by_std_in_grpo: `(bool)`
            True: divide by the group std, as in the original GRPO.
            False: only subtract the mean, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).
        config: `(Optional[AlgoConfig])`
            algorithm configuration object (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape is (bs, response_length)
        returns: `(torch.Tensor)`
            shape is (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)
    grouped = defaultdict(list)
    group_mean, group_std = {}, {}

    with torch.no_grad():
        batch_size = scores.shape[0]

        # Pass 1: bucket the per-response scores by group id.
        for row in range(batch_size):
            grouped[index[row]].append(scores[row])

        # Pass 2: per-group statistics.
        for idx, members in grouped.items():
            count = len(members)
            if count == 1:
                group_mean[idx] = torch.tensor(0.0)
                group_std[idx] = torch.tensor(1.0)
            elif count > 1:
                stacked = torch.stack(members)
                group_mean[idx] = torch.mean(stacked)
                group_std[idx] = torch.std(stacked)
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        # Pass 3: normalise every score in place against its own group.
        for row in range(batch_size):
            key = index[row]
            centered = scores[row] - group_mean[key]
            if norm_adv_by_std_in_grpo:
                scores[row] = centered / (group_std[key] + epsilon)
            else:
                scores[row] = centered

        # Broadcast the scalar to every response token.
        scores = scores.unsqueeze(-1) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.GRPO_VECTORIZED)
def compute_grpo_vectorized_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Vectorized GRPO (outcome-only).

    For each group g the per-response scalar is
        a_i = (r_i - mu_g) / (sigma_g + epsilon)
    or just ``r_i - mu_g`` when ``norm_adv_by_std_in_grpo`` is False. The scalar
    is then broadcast over the token dimension and multiplied by ``response_mask``.
    Group statistics come from ``group_mean_std``.

    Returns:
        (advantages, advantages): the same tensor twice.
    """
    with torch.no_grad():
        outcome = token_level_rewards.sum(dim=-1)
        group_ids = as_torch_index(index, device=outcome.device)
        group_mean, group_std, _ = group_mean_std(outcome, group_ids, eps=epsilon)

        # Gather each sample's group statistics by indexing with its group id.
        centered = outcome - group_mean[group_ids]
        if norm_adv_by_std_in_grpo:
            per_sample = centered / (group_std[group_ids] + epsilon)
        else:
            per_sample = centered

        advantages = per_sample.unsqueeze(-1) * response_mask
        return advantages, advantages
@register_adv_est(AdvantageEstimator.GRPO_PASSK)  # or simply: @register_adv_est("grpo_passk")
def compute_grpo_passk_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Pass@k advantage in a GRPO-style outcome-reward formulation.

    Within every group only the highest-scoring response receives a non-zero
    advantage, equal to ``r_max - r_second_max`` (optionally divided by the
    group std). Implemented as described in https://arxiv.org/abs/2503.19595.

    Args:
        token_level_rewards: (bs, response_length)
        response_mask: (bs, response_length)
        index: (bs,) → group ID per sample
        epsilon: float for numerical stability
        norm_adv_by_std_in_grpo: overwritten by the value read from ``config``
        config: (AlgoConfig) algorithm settings, which contains "norm_adv_by_std_in_grpo";
            must not be None

    Returns:
        advantages: (bs, response_length)
        returns: (bs, response_length); the same tensor as ``advantages``

    Raises:
        ValueError: If a group has fewer than 2 samples.
    """
    assert config is not None
    # if True, normalize advantage by std within group (the config value wins over the argument)
    norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", True)

    scores = token_level_rewards.sum(dim=-1)  # (bs,)
    advantages = torch.zeros_like(scores)
    # group id -> list of (batch position, score) pairs
    members = defaultdict(list)

    with torch.no_grad():
        for pos in range(scores.shape[0]):
            members[index[pos]].append((pos, scores[pos]))

        for idx, pairs in members.items():
            positions, values = zip(*pairs)
            rewards = torch.stack(values)  # (k,)
            if rewards.numel() < 2:
                raise ValueError(
                    f"Pass@k requires at least 2 samples per group. Got {rewards.numel()} for group {idx}."
                )

            top_values, top_slots = torch.topk(rewards, 2)
            winner = positions[top_slots[0].item()]
            advantage = top_values[0] - top_values[1]
            if norm_adv_by_std_in_grpo:
                advantage = advantage / (torch.std(rewards) + epsilon)
            advantages[winner] = advantage

    advantages = advantages.unsqueeze(-1) * response_mask
    return advantages, advantages
@register_adv_est(
    AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE
)  # or simply: @register_adv_est("reinforce_plus_plus_baseline")
def compute_reinforce_plus_plus_baseline_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward
    (with only one scalar reward for each response).

    The per-response score (sum of token-level rewards) is centred on the mean
    score of its group, broadcast over the response tokens, and then passed
    through ``verl_F.masked_whiten``.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            group id of each sample. Elements are used as dict keys, so they
            must hash by value (as elements of a NumPy array do); this matches
            the annotation of the other group-based estimators in this file.
        epsilon: `(float)`
            not read by this function
        config: (AlgoConfig) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)
    id2score = defaultdict(list)
    id2mean = {}
    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx in id2score:
            if len(id2score[idx]) == 1:
                # A lone response has no peers to form a baseline from.
                id2mean[idx] = torch.tensor(0.0)
            elif len(id2score[idx]) > 1:
                id2mean[idx] = torch.mean(torch.stack(id2score[idx]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")
        for i in range(bsz):
            scores[i] = scores[i] - id2mean[index[i]]
        # Broadcast to tokens first so the whitening runs over token-level values.
        scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
        scores = verl_F.masked_whiten(scores, response_mask) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.RLOO)  # or simply: @register_adv_est("rloo")
def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    RLOO (leave-one-out) advantage, based on https://arxiv.org/abs/2402.14740

    For a group with n > 1 responses each score becomes
    ``score * n / (n - 1) - group_mean * n / (n - 1)``; scores of single-response
    groups are left untouched.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            group id of each sample
        config: (AlgoConfig) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)
    grouped = defaultdict(list)
    group_mean = {}

    with torch.no_grad():
        batch_size = scores.shape[0]
        for row in range(batch_size):
            grouped[index[row]].append(scores[row])

        for idx, members in grouped.items():
            if len(members) == 1:
                group_mean[idx] = torch.tensor(0.0)
            elif len(members) > 1:
                group_mean[idx] = torch.mean(torch.stack(members))
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for row in range(batch_size):
            key = index[row]
            n = len(grouped[key])
            if n > 1:
                # Rescaling by n / (n - 1) turns "minus the full mean" into "minus the mean of the others".
                own_term = scores[row] * n / (n - 1)
                baseline_term = group_mean[key] * n / (n - 1)
                scores[row] = own_term - baseline_term

        scores = scores.unsqueeze(-1) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.OPO)  # or simply: @register_adv_est("opo")
def compute_opo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    OPO advantage, based on https://arxiv.org/pdf/2505.23585

    The baseline of a group is the length-weighted mean of its scores, where a
    response's length is the sum of its ``response_mask`` row. Single-response
    groups use a baseline of 0.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            group id of each sample
        config: (AlgoConfig) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    token_counts = response_mask.sum(dim=-1)
    scores = token_level_rewards.sum(dim=-1)
    group_scores = defaultdict(list)
    group_lengths = defaultdict(list)
    baseline = {}

    with torch.no_grad():
        batch_size = scores.shape[0]
        for row in range(batch_size):
            group_scores[index[row]].append(scores[row])
            group_lengths[index[row]].append(token_counts[row])

        for idx in group_scores:
            group_size = len(group_scores[idx])
            if group_size == 1:
                baseline[idx] = torch.tensor(0.0)
            elif group_size > 1:
                stacked_scores = torch.stack(group_scores[idx])
                stacked_lengths = torch.stack(group_lengths[idx])
                # Length-weighted average of the group's scores.
                weighted_total = (stacked_lengths * stacked_scores).sum()
                baseline[idx] = weighted_total / stacked_lengths.sum()
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for row in range(batch_size):
            scores[row] = scores[row] - baseline[index[row]]

        scores = scores.unsqueeze(-1) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS)  # or simply: @register_adv_est("reinforce_plus_plus")
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    REINFORCE++ advantage, based on the paper https://arxiv.org/abs/2501.03262

    Discounted returns are accumulated from the last token backwards using
    ``config.gamma``; the advantages are those returns passed through
    ``verl_F.masked_whiten`` and zeroed outside the response mask.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        config: (AlgoConfig) algorithm config; must not be None, ``gamma`` is read from it

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    assert config is not None
    gamma = config.gamma

    with torch.no_grad():
        returns = torch.zeros_like(token_level_rewards)
        num_steps = token_level_rewards.shape[1]
        carry = 0
        for step in range(num_steps - 1, -1, -1):
            carry = token_level_rewards[:, step] + gamma * carry
            returns[:, step] = carry
            # Reset after EOS: a masked-out position contributes nothing to earlier steps.
            carry = carry * response_mask[:, step]

        whitened = verl_F.masked_whiten(returns, response_mask)
        advantages = whitened * response_mask
    return advantages, returns
@register_adv_est(AdvantageEstimator.REMAX)  # or simply: @register_adv_est("remax")
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor,
    reward_baselines: torch.Tensor,
    response_mask: torch.Tensor,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    ReMax advantage on outcome rewards (one scalar reward per response).
    Based on the paper: https://arxiv.org/abs/2310.10505

    Returns are the reverse cumulative sum of the masked token rewards (the
    reward-to-go at each position); the advantage subtracts the per-sample
    baseline on the unmasked positions.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)`
            shape: (bs,)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        config: (AlgoConfig) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    with torch.no_grad():
        masked_rewards = token_level_rewards * response_mask
        # flip -> cumsum -> flip gives a suffix sum along the token axis.
        suffix_sum_reversed = masked_rewards.flip(dims=[-1]).cumsum(dim=-1)
        returns = suffix_sum_reversed.flip(dims=[-1])

        token_baseline = reward_baselines.unsqueeze(-1) * response_mask
        advantages = returns - token_baseline
    return advantages, returns
@register_adv_est(AdvantageEstimator.GPG)  # or simply: @register_adv_est("gpg")
def compute_gpg_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    f_norm: float = 1.0,
    alpha: float = 1.0,
    config=None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for GPG, operating only on Outcome reward
    (with only one scalar reward for each response).

    Each score is centred on its group mean, scaled by ``bsz / m`` (``m`` being
    the number of non-zero scores in the batch, at least 1) and divided by
    ``f_norm``. No per-group standard deviation is involved.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            shape: (bs,)
        epsilon: (float) not read by this function
        f_norm: (float) divisor applied to every advantage
        alpha: (float) NOTE: the passed value is ignored; it is overwritten with
            ``bsz / count_nonzero(scores).clamp(min=1)``
        config: (dict) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)
    id2score = defaultdict(list)
    id2mean = {}
    with torch.no_grad():
        bsz = scores.shape[0]
        # Batch-level scale factor; clamp keeps the division defined when every score is 0.
        m = torch.count_nonzero(scores)
        alpha = bsz / m.clamp(min=1)
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx, group in id2score.items():
            if len(group) == 1:
                id2mean[idx] = torch.tensor(0.0)
            elif len(group) > 1:
                id2mean[idx] = torch.mean(torch.stack(group))
            else:
                raise ValueError(f"no score in prompt index: {idx}")
        for i in range(bsz):
            scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm)
        scores = scores.unsqueeze(-1) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.RLOO_VECTORIZED)  # or simply: @register_adv_est("rloo_vectorized")
def compute_rloo_vectorized_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Loop-free RLOO advantage, based on https://arxiv.org/abs/2402.14740

    With group size c and group score total S, each sample gets
    ``(c * score - S) / (c - 1)``; samples in groups of size 1 get 0.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            group id of each sample
        config: (AlgoConfig) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)
    with torch.no_grad():
        # Map arbitrary group ids to dense integers 0..G-1.
        _, inverse = np.unique(index, return_inverse=True)
        group_of = torch.from_numpy(inverse).to(scores.device)

        # Per-sample group size and per-sample group total, gathered via the dense ids.
        group_size = torch.bincount(group_of)[group_of].to(scores.dtype)
        group_total = torch.bincount(group_of, weights=scores)[group_of]

        numerator = group_size * scores - group_total
        denominator = (group_size - 1).clamp_min(1)  # avoid dividing by 0 for singleton groups
        adv = (numerator / denominator) * (group_size > 1)
        adv = adv.unsqueeze(-1) * response_mask
    return adv, adv
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
    """Subtract a scaled KL term from the token-level scores.

    The KL term is the plain difference ``old_log_prob - ref_log_prob``.

    Args:
        token_level_scores (torch.Tensor): Token-level reward scores.
        old_log_prob (torch.Tensor): Log probabilities from current policy.
        ref_log_prob (torch.Tensor): Log probabilities from reference policy.
        kl_ratio (float): KL penalty coefficient.

    Returns:
        torch.Tensor: Token-level rewards with KL penalty applied.
    """
    kl_penalty = (old_log_prob - ref_log_prob) * kl_ratio
    return token_level_scores - kl_penalty
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
    """
    Aggregate the loss matrix into a scalar.

    Args:
        loss_mat: `(torch.Tensor)`:
            shape: (bs, response_length)
        loss_mask: `(torch.Tensor)`:
            shape: (bs, response_length)
        loss_agg_mode: (str) method to aggregate the loss matrix into a scalar. Choices:
            - "token-mean": masked mean over all tokens in the batch.
            - "seq-mean-token-sum": sum the masked tokens of each sequence, then average over sequences.
            - "seq-mean-token-mean": average the masked tokens of each sequence, then average over
              sequences. A sequence with no unmasked token contributes 0 instead of NaN.
            - "seq-mean-token-sum-norm": sum of all masked tokens divided by ``loss_mask.shape[-1]``.

    Returns:
        loss: `a scalar torch.Tensor`
            aggregated loss

    Raises:
        ValueError: If ``loss_agg_mode`` is not one of the choices above.
    """
    if loss_agg_mode == "token-mean":
        loss = verl_F.masked_mean(loss_mat, loss_mask)
    elif loss_agg_mode == "seq-mean-token-sum":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
        loss = torch.mean(seq_losses)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-mean":
        token_counts = torch.sum(loss_mask, dim=-1)
        # A fully masked row has a zero numerator; dividing by 1 instead of 0 yields 0 rather than NaN.
        safe_counts = torch.where(token_counts > 0, token_counts, torch.ones_like(token_counts))
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / safe_counts  # token-mean
        loss = torch.mean(seq_losses)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-sum-norm":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
        loss = torch.sum(seq_losses) / loss_mask.shape[-1]  # The divisor
        # (loss_mask.shape[-1]) should ideally be constant
        # throughout training to well-replicate the DrGRPO paper.
        # TODO: Perhaps add user-defined normalizer argument to
        # agg_loss to ensure divisor stays constant throughout.
    else:
        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
    return loss
@deprecated("verl.trainer.ppo.core_algos.compute_policy_loss_vanilla")
def compute_policy_loss(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    clip_ratio_c=3.0,
    loss_agg_mode: str = "token-mean",
):
    """
    Compute the clipped policy objective and related metrics for PPO.

    Deprecated in favour of ``compute_policy_loss_vanilla`` (see the decorator).

    Adapted from
    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        cliprange (float, optional):
            Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
            Defaults to None (must be provided unless both `cliprange_low` and `cliprange_high` are).
        cliprange_low (float, optional):
            Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.
        cliprange_high (float, optional):
            Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.
        clip_ratio_c (float, optional):
            Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
            Defaults to 3.0. Must be greater than 1.0.
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".

    Returns:
        tuple: ``(pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower)`` where
        ``pg_loss`` is the aggregated loss, ``pg_clipfrac`` the masked fraction of
        tokens where the clipped term exceeds the unclipped one, ``ppo_kl`` the
        masked mean of ``old_log_prob - log_prob`` (after clamping), and
        ``pg_clipfrac_lower`` the masked fraction of negative-advantage tokens
        bounded by the dual-clip term.
    """
    assert clip_ratio_c > 1.0, (
        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
        + f" but get the value: {clip_ratio_c}."
    )
    negative_approx_kl = log_prob - old_log_prob
    # Clamp negative_approx_kl for stability
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
    pg_losses1 = -advantages * ratio
    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange
    pg_losses2 = -advantages * torch.clamp(
        ratio, 1 - cliprange_low, 1 + cliprange_high
    )  # - clip(ratio, 1-cliprange, 1+cliprange) * A
    clip_pg_losses1 = torch.maximum(
        pg_losses1, pg_losses2
    )  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
    # Dual-clip: for negative advantages the loss is additionally capped at -A * clip_ratio_c.
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_clipfrac_lower = verl_F.masked_mean(
        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
    )
    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@register_policy_loss("vanilla")  # type: ignore[arg-type]
def compute_policy_loss_vanilla(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | AlgoConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the clipped policy objective and related metrics for PPO.

    Adapted from
    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config:
            Actor config; must not be None and must not be an ``AlgoConfig``. Reads ``clip_ratio``,
            ``clip_ratio_low``, ``clip_ratio_high`` (each falling back to ``clip_ratio`` when None)
            and ``clip_ratio_c`` (default 3.0, must be > 1.0).
        rollout_is_weights (torch.Tensor, optional):
            Rollout importance-sampling weights multiplied into the per-token losses before
            aggregation when provided.

    Returns:
        tuple: ``(pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower)``.
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig)
    clip_ratio = config.clip_ratio  # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
    clip_ratio_c = config.get(  # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
        "clip_ratio_c", 3.0
    )
    assert clip_ratio_c > 1.0, (
        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
        + f" but get the value: {clip_ratio_c}."
    )

    negative_approx_kl = log_prob - old_log_prob
    # Clamp negative_approx_kl for stability
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(
        ratio, 1 - clip_ratio_low, 1 + clip_ratio_high
    )  # - clip(ratio, 1-cliprange, 1+cliprange) * A
    clip_pg_losses1 = torch.maximum(
        pg_losses1, pg_losses2
    )  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

    # Dual-clip: for negative advantages the loss is additionally capped at -A * clip_ratio_c.
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_clipfrac_lower = verl_F.masked_mean(
        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
    )
    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

    # Apply rollout importance sampling weights if provided
    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@register_policy_loss("gspo")
def compute_policy_loss_gspo(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "seq-mean-token-mean",
    config: Optional[DictConfig | ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Clipped policy objective and metrics for GSPO.

    See https://arxiv.org/pdf/2507.18071 for more details.

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Accepted for signature compatibility; the loss is always aggregated with
            "seq-mean-token-mean" regardless of this value.
        config:
            Must be an ``ActorConfig`` instance; clip ratios are read from it.
        rollout_is_weights (torch.Tensor, optional):
            Rollout importance-sampling weights multiplied into the per-token losses when provided.

    Returns:
        tuple: ``(pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower)``; the last entry is always 0.
    """
    assert config is not None
    assert isinstance(config, ActorConfig)

    # Asymmetric clip bounds fall back to the shared clip_ratio when unset.
    clip_low = config.clip_ratio_low
    if clip_low is None:
        clip_low = config.clip_ratio
    clip_high = config.clip_ratio_high
    if clip_high is None:
        clip_high = config.clip_ratio

    token_log_ratio = log_prob - old_log_prob

    # Sequence-level importance ratio:
    # si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) =
    # exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]
    token_counts = torch.sum(response_mask, dim=-1).clamp(min=1)
    seq_log_ratio = torch.sum(token_log_ratio * response_mask, dim=-1) / token_counts

    # Token-level combined ratio:
    # s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]
    # In log space: log(s_i,t(θ)) = sg[log(s_i(θ))] + log_prob - sg[log_prob]
    log_ratio = log_prob - log_prob.detach() + seq_log_ratio.detach().unsqueeze(-1)
    log_ratio = torch.clamp(log_ratio, max=10.0)  # clamp for numerical stability
    importance_ratio = torch.exp(log_ratio)

    unclipped = -advantages * importance_ratio
    clipped = -advantages * torch.clamp(importance_ratio, 1 - clip_low, 1 + clip_high)
    pg_losses = torch.maximum(unclipped, clipped)

    # Apply rollout importance sampling weights if provided
    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    # for GSPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean)
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean")

    pg_clipfrac = verl_F.masked_mean(torch.gt(clipped, unclipped).float(), response_mask)
    # For compatibility, return zero for pg_clipfrac_lower (not used in standard GSPO)
    pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)
    ppo_kl = verl_F.masked_mean(-token_log_ratio, response_mask)
    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@register_policy_loss("gpg")
def compute_policy_loss_gpg(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | AlgoConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the GPG policy-gradient loss.

    Adapted from
    https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495

    The per-token loss is ``-log_prob * advantages``; there is no importance ratio and no clipping,
    so ``old_log_prob`` and ``config`` are not read by this function.

    Args:
        old_log_prob (torch.Tensor):
            Unused; accepted so the signature matches the other registered policy losses.
        log_prob (torch.Tensor):
            Log-probabilities under the current policy, shape (bs, response_length).
        advantages (torch.Tensor):
            Advantage estimates, shape (bs, response_length).
        response_mask (torch.Tensor):
            Mask of tokens to include in the loss, shape (bs, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode forwarded to `agg_loss`. Defaults to "token-mean".
        config:
            Unused; accepted for signature compatibility.
        rollout_is_weights (torch.Tensor, optional):
            Rollout importance-sampling weights multiplied element-wise into the per-token loss.

    Returns:
        tuple: ``(pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower)`` where ``pg_loss`` is the
        aggregated GPG loss and the other three entries are zero placeholders (GPG has no
        clipping or KL statistics).
    """
    token_losses = -log_prob * advantages

    # Apply rollout importance sampling weights if provided
    if rollout_is_weights is not None:
        token_losses = token_losses * rollout_is_weights

    pg_loss = agg_loss(loss_mat=token_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    # Keep the placeholder metrics on the same device as the loss, like the GSPO loss does,
    # so callers can combine the returned values without a device mismatch.
    device = pg_loss.device
    return (
        pg_loss,
        torch.tensor(0.0, device=device),
        torch.tensor(0.0, device=device),
        torch.tensor(0.0, device=device),
    )
@register_policy_loss("clip_cov")
def compute_policy_loss_clip_cov(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | AlgoConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the clipped policy objective and related metrics for Clip-Cov.

    Adapted from
    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py

    On top of the standard PPO clipped loss, a random subset of tokens whose advantage/log-prob
    covariance lies strictly between ``clip_cov_lb`` and ``clip_cov_ub`` has its loss zeroed out.

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config:
            Required, and must not be an ``AlgoConfig``. The following fields are read:
            ``clip_ratio`` (PPO clipping parameter, see https://arxiv.org/abs/1707.06347),
            ``clip_ratio_low`` / ``clip_ratio_high`` (each falls back to ``clip_ratio`` when None),
            ``policy_loss.clip_cov_ratio`` (fraction of valid tokens to zero out, 0.0002 when None),
            ``policy_loss.clip_cov_lb`` (lower covariance bound, 1.0 when None) and
            ``policy_loss.clip_cov_ub`` (upper covariance bound, 5.0 when None).
        rollout_is_weights (torch.Tensor, optional):
            Rollout importance-sampling weights multiplied element-wise into the per-token loss.

    Returns:
        tuple: ``(pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower)``. ``pg_clipfrac`` is the masked
        fraction of tokens zeroed by the covariance rule; ``pg_clipfrac_lower`` is a zero placeholder.
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
    assert config.policy_loss is not None

    clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002
    cliprange = config.clip_ratio
    cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange
    cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange
    clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0
    clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0

    assert clip_cov_ratio > 0, "clip_cov_ratio should be larger than 0."

    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    # Tokens where the ordinary PPO clip is already the active branch.
    clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0)

    # The covariance is only used to *select* tokens, never as part of the loss, so it is built
    # from detached tensors: no autograd graph is recorded and the in-place writes below do not
    # touch anything that participates in the backward pass.
    adv_detached = advantages.detach()
    logp_detached = log_prob.detach()
    cov_all = (adv_detached - verl_F.masked_mean(adv_detached, response_mask)) * (
        logp_detached - verl_F.masked_mean(logp_detached, response_mask)
    )
    # Exclude padding and already-clipped tokens from the candidate set.
    cov_all[response_mask == 0] = -torch.inf
    cov_all[clip_by_origin] = -torch.inf

    clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1)
    in_band = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0)
    # For the 2-D inputs documented above, torch.nonzero yields an (n, 2) index tensor,
    # which is (0, 2) when nothing matches, so the empty case needs no special handling.
    top_k_idx = torch.nonzero(in_band)
    if len(top_k_idx) > 0:
        # Randomly keep at most `clip_num` of the in-band tokens.
        perm = torch.randperm(len(top_k_idx))
        top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]]

    # corr is a 0/1 multiplier: 0 for the selected tokens, 1 everywhere else.
    corr = torch.ones_like(advantages)
    corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0

    pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)

    pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr

    # Apply rollout importance sampling weights if provided
    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    # Placeholder for pg_clipfrac_lower, kept on the loss device for consistency with GSPO.
    return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0, device=pg_loss.device)
@register_policy_loss("kl_cov")
def compute_policy_loss_kl_cov(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | AlgoConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the policy objective and related metrics for KL-Cov.

    Adapted from
    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py

    The base loss is the unclipped ``-advantages * ratio``. The valid tokens with the largest
    advantage/log-prob covariance additionally receive a ``ppo_kl_coef * |log_prob - old_log_prob|``
    penalty.

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config:
            Required, and must not be an ``AlgoConfig``. The following fields are read:
            ``policy_loss.kl_cov_ratio`` (fraction of valid tokens that receive the penalty,
            0.0002 when None) and ``policy_loss.ppo_kl_coef`` (penalty coefficient, 1.0 when None).
        rollout_is_weights (torch.Tensor, optional):
            Rollout importance-sampling weights multiplied element-wise into the per-token loss.

    Returns:
        tuple: ``(pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower)``. ``ppo_kl`` is the masked mean
        of ``|log_prob - old_log_prob|``; both clip fractions are zero placeholders.
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
    assert config.policy_loss is not None

    kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002
    ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0

    assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0."

    negative_approx_kl = log_prob - old_log_prob
    abs_kl = negative_approx_kl.abs()
    ratio = torch.exp(negative_approx_kl)
    ppo_kl_abs = verl_F.masked_mean(abs_kl, response_mask)

    pg_losses1 = -advantages * ratio
    pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl
    pg_losses = pg_losses1

    all_valid = response_mask > 0
    # Flat (row-major) positions of the valid tokens; same order as the boolean-indexed values below.
    all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0]
    all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu()
    all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu()

    num_valid = len(all_valid_adv)
    if num_valid > 0:
        cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean())
        # At least one token, and never more than exist: torch.topk raises when k exceeds the
        # input size, which would otherwise happen for kl_cov_ratio > 1.
        k_percent_nums = min(num_valid, max(1, int(num_valid * kl_cov_ratio)))
        large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices

        # Map positions within the valid-token list back to (row, col) in the full matrix.
        flat_idxs = all_valid_idx[large_cov_idxs.to(all_valid_idx.device)]
        rows = flat_idxs // advantages.shape[1]
        cols = flat_idxs % advantages.shape[1]
        pg_losses[rows, cols] = pg_losses_kl[rows, cols]

    # Apply rollout importance sampling weights if provided
    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    # Clip-fraction placeholders, kept on the loss device for consistency with GSPO.
    device = pg_loss.device
    return pg_loss, torch.tensor(0.0, device=device), ppo_kl_abs, torch.tensor(0.0, device=device)
@register_policy_loss("geo_mean")
def compute_policy_loss_geo_mean(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | AlgoConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Geometric-Mean Policy Optimization (GMPO) loss and its monitoring metrics.

    Adapted from paper https://arxiv.org/abs/2507.20673
    https://github.com/callsys/GMPO/blob/main/train_zero_math_gmpo.py

    Args:
        old_log_prob (torch.Tensor):
            Old-policy log-probabilities, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Current-policy log-probabilities, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Per-token advantage estimates, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask of tokens that contribute to the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            not used; the per-sequence losses are averaged with `torch.mean`.
        config:
            Required, and must not be an ``AlgoConfig``. Reads ``clip_ratio`` and the optional
            ``clip_ratio_low`` / ``clip_ratio_high`` overrides.
        rollout_is_weights (torch.Tensor, optional):
            Token-level rollout importance-sampling weights; reduced to one weight per sequence
            through a masked geometric mean before being applied.

    Returns:
        tuple: ``(pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower)``.
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig)

    # Clipping parameter, see https://arxiv.org/abs/1707.06347. Each side falls back to clip_ratio.
    base_clip = config.clip_ratio
    low_bound = base_clip if config.clip_ratio_low is None else config.clip_ratio_low
    high_bound = base_clip if config.clip_ratio_high is None else config.clip_ratio_high

    # NOTE: the log-ratio is deliberately left unclamped here; add a clamp (e.g. to [-20, 20])
    # at this point if extra numerical stability is wanted.
    log_ratio = log_prob - old_log_prob
    ppo_kl = verl_F.masked_mean(-log_ratio, response_mask)

    # Token-level clipping in log space; the sign of the advantage selects the pessimistic side.
    adv_sign = torch.sign(advantages)
    log_ratio_clamped = torch.clamp(log_ratio, -low_bound, high_bound)
    pessimistic_log_ratio = adv_sign * torch.min(adv_sign * log_ratio, adv_sign * log_ratio_clamped)

    # Geometric mean over each sequence == exp of the masked mean of the log-ratios.
    token_counts = response_mask.sum(dim=-1)
    safe_counts = token_counts + 1e-8
    seq_ratio = torch.exp(torch.sum(pessimistic_log_ratio * response_mask, dim=-1) / safe_counts)

    # Only sequence-level advantages are supported for now; otherwise the result
    # would not be consistent with the paper.
    seq_advantage = torch.sum(advantages * response_mask, dim=-1) / safe_counts
    seq_losses = -seq_advantage * seq_ratio

    if rollout_is_weights is not None:
        # rollout_is_weights is always 2-D regardless of rollout_is_level, so collapse it to one
        # weight per sequence with a geometric mean, mirroring how the ratio itself is reduced.
        masked_log_weights = torch.log(rollout_is_weights + 1e-10) * response_mask
        seq_losses = seq_losses * torch.exp(masked_log_weights.sum(dim=-1) / safe_counts)

    pg_loss = torch.mean(seq_losses)

    # A token counts as clipped when the clamp changed its log-ratio; split by advantage sign.
    was_clamped = torch.ne(log_ratio, log_ratio_clamped)
    pg_clipfrac = verl_F.masked_mean((was_clamped & (advantages > 0)).float(), response_mask)
    pg_clipfrac_lower = verl_F.masked_mean((was_clamped & (advantages < 0)).float(), response_mask)

    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"):
    """Aggregate the per-token categorical entropy into a loss (kept for backward compatibility).

    Args:
        logits (torch.Tensor): shape is (bs, response_length, vocab_size)
        response_mask (torch.Tensor): shape is (bs, response_length)
        loss_agg_mode (str, optional): aggregation mode forwarded to `agg_loss`.
            Defaults to "token-mean".

    Returns:
        entropy: a scalar torch.Tensor
    """
    # Per-token entropy, (bs, response_len), reduced by agg_loss under the response mask.
    per_token_entropy = verl_F.entropy_from_logits(logits)
    return agg_loss(loss_mat=per_token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    cliprange_value: float,
    loss_agg_mode: str = "token-mean",
):
    """
    PPO value-function loss with value clipping.

    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

    Args:
        vpreds (torch.FloatTensor):
            Predicted values from the value head, shape (batch_size, response_length).
        returns (torch.FloatTensor):
            Ground-truth returns, shape (batch_size, response_length).
        values (torch.FloatTensor):
            Old (baseline) values from the value head, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the value loss calculation.
        cliprange_value (float):
            Maximum distance a prediction may move away from the old value before it is clipped.
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".

    Returns:
        vf_loss (torch.FloatTensor):
            Half of the aggregated element-wise maximum of the clipped and unclipped squared errors.
        vf_clipfrac:
            Masked mean (via `verl_F.masked_mean`) of the indicator that the clipped error is
            strictly larger than the unclipped one.
    """
    lower_bound = values - cliprange_value
    upper_bound = values + cliprange_value
    clipped_preds = verl_F.clip_by_value(vpreds, lower_bound, upper_bound)

    unclipped_sq_err = (vpreds - returns).pow(2)
    clipped_sq_err = (clipped_preds - returns).pow(2)

    # Pessimistic choice: penalise with whichever squared error is larger.
    worst_sq_err = torch.max(unclipped_sq_err, clipped_sq_err)
    vf_loss = 0.5 * agg_loss(loss_mat=worst_sq_err, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    clip_active = torch.gt(clipped_sq_err, unclipped_sq_err).float()
    vf_clipfrac = verl_F.masked_mean(clip_active, response_mask)
    return vf_loss, vf_clipfrac
[docs]
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute KL divergence given logprob and ref_logprob. Optionally using straight through to bind k2 on other
    kl penalty compute method for unbiased KL gradient estimation.
    See more description in http://joschu.net/blog/kl-approx.html

    Args:
        logprob: log-probabilities under the current policy.
        ref_logprob: log-probabilities under the reference policy.
        kl_penalty: name of the estimator accepted by `kl_penalty_forward`, optionally followed by
            a trailing "+" (e.g. "k3+") to request the straight-through variant.

    Returns:
        kl_estimate: element-wise estimate. With the "+" variant the value is that of the named
        estimator while the gradient is that of the k2 estimator.

    Raises:
        NotImplementedError: if the estimator name is not supported.
    """
    use_straight_through = kl_penalty.endswith("+")
    # The forward estimator only knows the bare names, so the "+" marker has to be removed
    # before dispatching; passing e.g. "k3+" through unchanged matches no branch there.
    estimator = kl_penalty[:-1] if use_straight_through else kl_penalty

    forward_score = kl_penalty_forward(logprob, ref_logprob, estimator)
    if not use_straight_through or estimator in ("mse", "k2"):
        # k2 already is the backward estimator, so the straight-through trick would be a no-op.
        return forward_score

    # The expectation of the k1 and k3 estimators is the expected value of KL, but the expected
    # gradient of the k1 and k3 estimators is not the expected gradient of KL. The k2 estimator,
    # on the other hand, gives the right gradient estimator, so a straight-through trick is used
    # here when the kl_penalty method ends with '+', e.g. k3+.
    backward_score = 0.5 * (logprob - ref_logprob).square()
    return backward_score - backward_score.detach() + forward_score.detach()


def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute KL divergence given logprob and ref_logprob.
    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
    See more description in http://joschu.net/blog/kl-approx.html

    Args:
        logprob: log-probabilities under the current policy.
        ref_logprob: log-probabilities under the reference policy.
        kl_penalty: one of "kl"/"k1", "abs", "mse"/"k2", "low_var_kl"/"k3". "full" is recognised
            but not implemented.

    Returns:
        kl_estimate: element-wise estimate with the same shape as the inputs.

    Raises:
        NotImplementedError: for "full" and for any unrecognised estimator name.
    """
    if kl_penalty in ("kl", "k1"):
        return logprob - ref_logprob

    if kl_penalty == "abs":
        return (logprob - ref_logprob).abs()

    if kl_penalty in ("mse", "k2"):
        return 0.5 * (logprob - ref_logprob).square()

    # J. Schulman. Approximating kl divergence, 2020.
    # URL http://joschu.net/blog/kl-approx.html.
    if kl_penalty in ("low_var_kl", "k3"):
        kl = ref_logprob - logprob
        # For numerical stability
        kl = torch.clamp(kl, min=-20, max=20)
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)

    if kl_penalty == "full":
        # Here logprob and ref_logprob would have to contain the logits for every token in the vocabulary.
        raise NotImplementedError("kl_penalty='full' is not implemented")

    raise NotImplementedError(f"Unsupported kl_penalty: {kl_penalty!r}")
def compute_pf_ppo_reweight_data(
    data,
    reweight_method: str = "pow",
    weight_pow: float = 2.0,
):
    """Reweight the data based on the token_level_scores.

    Every sample gets a weight derived from its summed ``token_level_scores``; a new batch of the
    same size is then drawn with replacement proportionally to those weights.

    Args:
        data: DataProto object, containing batch, non_tensor_batch and meta_info
        reweight_method: str, choices: "pow", "max_min", "max_random"
        weight_pow: float, the power of the weight

    Returns:
        A deep copy of ``data`` whose ``batch``, ``non_tensor_batch`` and ``meta_info`` hold the
        resampled entries. ``meta_info`` values that are not lists of length batch-size are
        carried over unchanged. The input ``data`` is not modified.

    Raises:
        ValueError: If reweight_method is not supported.
    """
    from copy import deepcopy

    @torch.no_grad()
    def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor:
        """Compute importance weights for resampling based on scores.

        Args:
            scores (torch.Tensor): Tensor of scores to compute weights from.
            reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random').
            weight_pow (float): Power exponent for 'pow' method.

        Returns:
            torch.Tensor: Computed importance weights.

        Raises:
            ValueError: If reweight_method is not supported.
        """
        if reweight_method == "pow":
            weights = torch.pow(torch.abs(scores), weight_pow)
        elif reweight_method == "max_min":
            # Only the best- and worst-scoring samples get weight.
            max_score = torch.max(scores)
            min_score = torch.min(scores)
            weights = torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0)
        elif reweight_method == "max_random":
            # Favour the best-scoring samples while keeping the rest reachable.
            max_score = torch.max(scores)
            weights = torch.where(scores == max_score, 0.4, 0.1)
        else:
            raise ValueError(f"Unsupported reweight_method: {reweight_method}")
        return weights

    scores = data.batch["token_level_scores"].sum(dim=-1)
    weights = compute_weights(scores, reweight_method, weight_pow)
    # Keep every weight strictly positive so the sampling distribution is never all-zero.
    weights = torch.clamp(weights + 1e-8, min=1e-8)

    batch_size = scores.shape[0]
    sample_indices = torch.multinomial(weights, batch_size, replacement=True)

    resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()}

    # Tensor.numpy() only works for CPU tensors; move the indices first so the function also
    # works when the batch lives on an accelerator.
    sample_indices_np = sample_indices.cpu().numpy()

    resampled_non_tensor_batch = {}
    for key, array in data.non_tensor_batch.items():
        if isinstance(array, np.ndarray):
            resampled_non_tensor_batch[key] = array[sample_indices_np]
        else:
            resampled_non_tensor_batch[key] = [array[i] for i in sample_indices_np]

    resampled_meta_info = {}
    for key, value in data.meta_info.items():
        # Only per-sample lists are resampled; anything else is treated as batch-level metadata.
        if isinstance(value, list) and len(value) == batch_size:
            resampled_meta_info[key] = [value[i] for i in sample_indices_np]
        else:
            resampled_meta_info[key] = value

    resampled_data = deepcopy(data)
    resampled_data.batch = type(data.batch)(resampled_batch)
    resampled_data.batch.batch_size = data.batch.batch_size
    resampled_data.non_tensor_batch = resampled_non_tensor_batch
    resampled_data.meta_info = resampled_meta_info

    return resampled_data