| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | """
|
| | Core functions to implement PPO algorithms.
|
| | The function implemented in this file should be used by trainer with different distributed strategies to
|
| | implement PPO
|
| | """
|
| |
|
| | from collections import defaultdict
|
| |
|
| | import numpy as np
|
| | import torch
|
| |
|
| | import verl.utils.torch_functional as verl_F
|
| |
|
| |
|
class AdaptiveKLController:
    """
    Adaptive KL coefficient controller, following the scheme described in:
    https://arxiv.org/pdf/1909.08593.pdf

    The current coefficient is stored in ``value`` and is rescaled in place
    by ``update`` depending on how the observed KL compares to ``target``.
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        self.horizon = horizon
        self.target = target_kl
        self.value = init_kl_coef

    def update(self, current_kl, n_steps):
        """Rescale ``value`` from the observed KL; nothing is returned.

        Args:
            current_kl: observed KL, compared against ``self.target``.
            n_steps: multiplier applied to the clipped error before it is
                divided by ``self.horizon``.
        """
        # Relative deviation from the target, limited to the range [-0.2, 0.2].
        relative_error = np.clip(current_kl / self.target - 1, -0.2, 0.2)
        self.value *= 1 + relative_error * n_steps / self.horizon
|
| |
|
| |
|
class FixedKLController:
    """KL controller whose coefficient ``value`` never changes."""

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        """Accept an observed KL and step count, and deliberately ignore both."""
        return None
|
| |
|
| |
|
def get_kl_controller(kl_ctrl):
    """Build the KL controller selected by ``kl_ctrl.type``.

    Args:
        kl_ctrl: configuration object. ``type`` and ``kl_coef`` are read, plus
            ``horizon`` and ``target_kl`` when ``type`` is ``"adaptive"``.

    Returns:
        A ``FixedKLController`` for ``"fixed"``, or an ``AdaptiveKLController``
        for ``"adaptive"``.

    Raises:
        AssertionError: ``type`` is ``"adaptive"`` and ``horizon`` is not > 0.
        NotImplementedError: ``type`` is neither ``"fixed"`` nor ``"adaptive"``.
    """
    controller_type = kl_ctrl.type

    if controller_type == "fixed":
        return FixedKLController(kl_coef=kl_ctrl.kl_coef)

    if controller_type != "adaptive":
        raise NotImplementedError

    assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
    return AdaptiveKLController(
        init_kl_coef=kl_ctrl.kl_coef,
        target_kl=kl_ctrl.target_kl,
        horizon=kl_ctrl.horizon,
    )
|
| |
|
| |
|
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: torch.Tensor,
    lam: torch.Tensor,
):
    """Generalized Advantage Estimation over the response tokens.

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        values: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length). [EOS] mask. The tokens after [EOS] have mask zero.
        gamma: `(float)`
            discount factor used in RL
        lam: `(float)`
            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length), passed through `verl_F.masked_whiten`
        Returns: `(torch.Tensor)`
            shape: (bs, response_length), the un-whitened advantages plus `values`
    """
    with torch.no_grad():
        num_steps = token_level_rewards.shape[-1]
        final_step = num_steps - 1
        per_step_gae = [None] * num_steps
        running_gae = 0

        # Walk backwards in time. The recurrence itself does not consult
        # response_mask; the mask is only handed to masked_whiten below.
        for t in range(final_step, -1, -1):
            # No bootstrap value exists beyond the last position.
            next_value = 0.0 if t == final_step else values[:, t + 1]
            td_error = token_level_rewards[:, t] + gamma * next_value - values[:, t]
            running_gae = td_error + gamma * lam * running_gae
            per_step_gae[t] = running_gae

        advantages = torch.stack(per_step_gae, dim=1)

        # Returns are taken from the raw advantages, before whitening.
        returns = advantages + values
        advantages = verl_F.masked_whiten(advantages, response_mask)
    return advantages, returns
|
| |
|
| |
|
| |
|
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
):
    """
    Compute advantage for GRPO, operating only on Outcome reward
    (with only one scalar reward for each response).

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            one hashable group id per sample; samples sharing an id are
            normalized together.
        epsilon: `(float)`
            added to the group std before dividing, to avoid division by zero.
        norm_adv_by_std_in_grpo: (bool)
            whether to scale the GRPO advantage.
            If True, the advantage is scaled by the std, as in the original GRPO.
            If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor object as `advantages`.
    """
    # One scalar score per response.
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}
    id2std = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])

        # Every group holds at least one score, because a key is only created on append.
        for idx, group in id2score.items():
            if len(group) == 1:
                # A single sample has no group statistics: use mean 0 and
                # std 1 so its score passes through (almost) unchanged.
                id2mean[idx] = scores.new_tensor(0.0)
                id2std[idx] = scores.new_tensor(1.0)
            else:
                # Stack keeps the values on the scores' device and dtype.
                group_scores = torch.stack(group)
                id2mean[idx] = torch.mean(group_scores)
                id2std[idx] = torch.std(group_scores)

        # All statistics are computed before this loop overwrites scores in place.
        for i in range(bsz):
            centered = scores[i] - id2mean[index[i]]
            if norm_adv_by_std_in_grpo:
                centered = centered / (id2std[index[i]] + epsilon)
            scores[i] = centered

        # Broadcast the per-response advantage over its valid tokens.
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores
|
| |
|
| |
|
def compute_reinforce_plus_plus_baseline_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
):
    """
    Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward
    (with only one scalar reward for each response).

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            one hashable group id per sample; the group mean is the baseline.
            A 1-D `torch.Tensor` is also accepted and is converted to a list.
        epsilon: `(float)`
            not used by this function; kept so the signature stays compatible.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor object as `advantages`.
    """
    if isinstance(index, torch.Tensor):
        # Tensor elements hash by object identity, so they cannot act as
        # group keys; plain Python values hash by value.
        index = index.tolist()

    response_length = token_level_rewards.shape[-1]
    # One scalar score per response.
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])

        # Every group holds at least one score, because a key is only created on append.
        for idx, group in id2score.items():
            if len(group) == 1:
                # A lone sample has no peers: use a zero baseline.
                id2mean[idx] = scores.new_tensor(0.0)
            else:
                id2mean[idx] = torch.mean(torch.stack(group))

        # All baselines are computed before this loop overwrites scores in place.
        for i in range(bsz):
            scores[i] = scores[i] - id2mean[index[i]]

        # Spread the per-response value over its tokens, then whiten over the mask.
        scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
        scores = verl_F.masked_whiten(scores, response_mask)

    return scores, scores
|
| |
|
| |
|
def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6):
    """
    Leave-one-out (RLOO) outcome advantage, based on https://arxiv.org/abs/2402.14740

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            one hashable group id per sample.
        epsilon: `(float)`
            not used by this function.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor object as `advantages`.
    """
    scores = token_level_rewards.sum(dim=-1)

    groups = defaultdict(list)
    group_mean = {}

    with torch.no_grad():
        batch = scores.shape[0]
        for row in range(batch):
            groups[index[row]].append(scores[row])

        for key, members in groups.items():
            count = len(members)
            if count == 1:
                group_mean[key] = torch.tensor(0.0)
            elif count > 1:
                group_mean[key] = torch.mean(torch.tensor(members))
            else:
                raise ValueError(f"no score in prompt index: {key}")

        for row in range(batch):
            key = index[row]
            count = len(groups[key])
            if count <= 1:
                # A lone sample has no peers; its score is left untouched.
                continue
            # Scaling (score - mean) by n / (n - 1) equals subtracting the mean of the other samples.
            scores[row] = scores[row] * count / (count - 1) - group_mean[key] * count / (count - 1)

        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores
|
| |
|
| |
|
def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: torch.Tensor):
    """
    REINFORCE++ advantage: discounted returns whitened over the response mask.
    This implementation is based on the paper: https://arxiv.org/abs/2501.03262

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        gamma: discount factor applied to the running return at each step.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    with torch.no_grad():
        returns = torch.zeros_like(token_level_rewards)
        seq_len = token_level_rewards.shape[1]
        carried = 0

        # Accumulate from the last token back to the first.
        for offset in range(seq_len):
            t = seq_len - 1 - offset
            discounted = token_level_rewards[:, t] + gamma * carried
            returns[:, t] = discounted
            # Zero the carry wherever the mask is 0, so nothing accumulated
            # at a masked position leaks into earlier positions.
            carried = discounted * response_mask[:, t]

        advantages = verl_F.masked_whiten(returns, response_mask) * response_mask

    return advantages, returns
|
| |
|
| |
|
def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor):
    """
    Compute advantage for ReMax, operating only on Outcome reward
    (with only one scalar reward for each response).
    This implementation is based on the paper: https://arxiv.org/abs/2310.10505

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)`
            shape: (bs,)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    with torch.no_grad():
        masked_rewards = token_level_rewards * response_mask
        # Suffix sum along the token axis: flip, cumulative-sum, flip back.
        flipped_cumsum = torch.cumsum(torch.flip(masked_rewards, dims=[-1]), dim=-1)
        returns = torch.flip(flipped_cumsum, dims=[-1])
        # The per-response baseline is only subtracted at unmasked positions.
        baseline = reward_baselines.unsqueeze(-1) * response_mask
        advantages = returns - baseline

    return advantages, returns
|
| |
|
| |
|
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
    """Return ``token_level_scores - (old_log_prob - ref_log_prob) * kl_ratio``."""
    log_prob_gap = old_log_prob - ref_log_prob
    penalty = log_prob_gap * kl_ratio
    return token_level_scores - penalty
|
| |
|
| |
|
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
    """
    Reduce a per-token loss matrix to a scalar.

    Args:
        loss_mat: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_agg_mode: (str) choices: "token-mean" /
                                      "seq-mean-token-sum" /
                                      "seq-mean-token-mean" /
                                      "seq-mean-token-sum-norm" /
            "token-mean": masked mean over all tokens (via `verl_F.masked_mean`).
            "seq-mean-token-sum": masked sum per sequence, then mean over sequences.
            "seq-mean-token-mean": masked mean per sequence, then mean over sequences.
                NOTE: a sequence whose mask sums to zero is divided by zero here.
            "seq-mean-token-sum-norm": masked sum over everything, divided by
                the size of the last dimension of `loss_mask`.

    Returns:
        loss: `a scalar torch.Tensor`
            aggregated loss

    Raises:
        ValueError: `loss_agg_mode` is none of the choices above.
    """
    if loss_agg_mode == "token-mean":
        return verl_F.masked_mean(loss_mat, loss_mask)

    if loss_agg_mode == "seq-mean-token-sum":
        per_seq_sum = (loss_mat * loss_mask).sum(dim=-1)
        return per_seq_sum.mean()

    if loss_agg_mode == "seq-mean-token-mean":
        per_seq_sum = (loss_mat * loss_mask).sum(dim=-1)
        per_seq_count = loss_mask.sum(dim=-1)
        return (per_seq_sum / per_seq_count).mean()

    if loss_agg_mode == "seq-mean-token-sum-norm":
        per_seq_sum = (loss_mat * loss_mask).sum(dim=-1)
        return per_seq_sum.sum() / loss_mask.shape[-1]

    raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
|
| |
|
| |
|
def compute_policy_loss(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    clip_ratio_c=3.0,
    loss_agg_mode="token-mean",
):
    """Clipped PPO policy-gradient loss with the dual-clip extension.

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

    Args:
        old_log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange: (float)
            The clip range used in PPO. See https://arxiv.org/abs/1707.06347
            Used for whichever of `cliprange_low` / `cliprange_high` is None.
        cliprange_low: (float)
            The lower clip range used in PPO.
        cliprange_high: (float)
            The higher clip range used in PPO.
        clip_ratio_c: (float) default: 3.0
            The dual-clip PPO constant, must be > 1.0. See https://arxiv.org/pdf/1912.09729
        loss_agg_mode: (str) choices: "token-mean" /
                                      "seq-mean-token-sum" /
                                      "seq-mean-token-mean" /
                                      "seq-mean-token-sum-norm" /
            "token-mean" is the default behavior

    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via PPO
        pg_clipfrac:
            masked mean of the positions where the clipped term exceeds the unclipped one
        ppo_kl:
            masked mean of `old_log_prob - log_prob`, an estimate of the KL
            between the old sampling policy and the policy being updated
        pg_clipfrac_lower:
            masked mean of the positions with negative advantage where the
            dual-clip bound is the smaller term
    """
    assert clip_ratio_c > 1.0, "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + f" but get the value: {clip_ratio_c}."

    # Whichever side of the clip range was not given uses the symmetric range.
    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange

    log_ratio = log_prob - old_log_prob
    ratio = torch.exp(log_ratio)
    ppo_kl = verl_F.masked_mean(-log_ratio, response_mask)

    neg_advantages = -advantages

    # Standard PPO: take the larger of the unclipped and clipped loss terms.
    unclipped_losses = neg_advantages * ratio
    clipped_losses = neg_advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    ppo_losses = torch.maximum(unclipped_losses, clipped_losses)
    pg_clipfrac = verl_F.masked_mean((clipped_losses > unclipped_losses).float(), response_mask)

    # Dual clip: where the advantage is negative, cap the loss at clip_ratio_c * |advantage|.
    dual_bound_losses = neg_advantages * clip_ratio_c
    dual_clipped_losses = torch.min(dual_bound_losses, ppo_losses)
    negative_adv = advantages < 0
    pg_clipfrac_lower = verl_F.masked_mean((ppo_losses > dual_bound_losses) * negative_adv.float(), response_mask)

    pg_losses = torch.where(negative_adv, dual_clipped_losses, ppo_losses)
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
|
| |
|
| |
|
def compute_entropy_loss(logits, response_mask):
    """Masked mean of the categorical entropy computed from logits.

    Args:
        logits: `(torch.Tensor)`
            shape: (bs, response_length, vocab_size)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        entropy: a scalar torch.Tensor
    """
    token_entropy = verl_F.entropy_from_logits(logits)
    return verl_F.masked_mean(token_entropy, mask=response_mask)
|
| |
|
| |
|
def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value):
    """Clipped value-function loss.

    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

    Args:
        vpreds (`torch.FloatTensor`):
            Predicted values of the value head, shape (`batch_size`, `response_length`)
        returns: (`torch.FloatTensor`):
            Ground truth returns, shape (`batch_size`, `response_length`)
        values (`torch.FloatTensor`):
            Old values of value head, shape (`batch_size`, `response_length`)
        response_mask:
            Mask handed to `verl_F.masked_mean` for both outputs.
        cliprange_value:
            Half-width of the band around `values` that `vpreds` is clipped to.

    Returns:
        vf_loss: a scalar (`torch.FloatTensor`):
            value function loss
        vf_clipfrac:
            Masked mean of the positions where the clipped squared error
            exceeds the unclipped one.
    """
    clipped_preds = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
    unclipped_sq_err = (vpreds - returns) ** 2
    clipped_sq_err = (clipped_preds - returns) ** 2

    # Use the larger of the two errors, then halve the masked mean.
    worst_sq_err = torch.max(unclipped_sq_err, clipped_sq_err)
    vf_loss = 0.5 * verl_F.masked_mean(worst_sq_err, response_mask)

    was_clipped = (clipped_sq_err > unclipped_sq_err).float()
    vf_clipfrac = verl_F.masked_mean(was_clipped, response_mask)
    return vf_loss, vf_clipfrac
|
| |
|
| |
|
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute an elementwise KL penalty from logprob and ref_logprob.

    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104

    Args:
        logprob: log-probabilities under the current policy.
        ref_logprob: log-probabilities under the reference policy, same shape.
        kl_penalty: (str) which estimator to use:
            "kl": `logprob - ref_logprob`
            "abs": `|logprob - ref_logprob|`
            "mse": `0.5 * (logprob - ref_logprob) ** 2`
            "low_var_kl": `exp(d) - d - 1` with `d = ref_logprob - logprob`,
                clamped to [-10, 10]
            "full": recognised but not implemented

    Returns:
        A tensor with the same shape as the inputs holding the per-element penalty.

    Raises:
        NotImplementedError: `kl_penalty` is "full" or an unknown value.
    """
    if kl_penalty == "kl":
        return logprob - ref_logprob

    if kl_penalty == "abs":
        return (logprob - ref_logprob).abs()

    if kl_penalty == "mse":
        return 0.5 * (logprob - ref_logprob).square()

    if kl_penalty == "low_var_kl":
        kl = ref_logprob - logprob
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
        # The clamp bounds the penalty when the two log-probs are far apart.
        return torch.clamp(kld, min=-10, max=10)

    if kl_penalty == "full":
        raise NotImplementedError('kl_penalty mode "full" is not implemented')

    # Name the rejected value so a misconfiguration is easy to trace.
    raise NotImplementedError(f"Unknown kl_penalty mode: {kl_penalty!r}")
|
| |
|