from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    tensor = tensor * mask
    tensor = tensor.sum(dim=dim)
    mask_sum = mask.sum(dim=dim)
    mean = tensor / (mask_sum + 1e-8)
    return mean
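
# Illustrative sketch (not part of the original module): masked_mean averages only over
# positions where mask == 1, e.g. the generated-action positions of a padded batch.
def _example_masked_mean() -> None:
    values = torch.tensor([[1.0, 2.0, 3.0],
                           [4.0, 5.0, 6.0]])
    mask = torch.tensor([[1.0, 1.0, 0.0],
                         [1.0, 0.0, 0.0]])
    # Row 0 averages 1.0 and 2.0 -> 1.5; row 1 keeps only 4.0 -> 4.0.
    print(masked_mean(values, mask))  # tensor([1.5000, 4.0000])
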
class GPTLMLoss(nn.Module):
    """
    GPT Language Model Loss
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Shift so that the logits at position t predict the token at position t + 1
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
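
# Usage sketch (illustrative, made-up shapes): logits are [batch, seq_len, vocab] and
# labels are the corresponding input token ids; the loss shifts them internally.
def _example_gpt_lm_loss() -> None:
    batch, seq_len, vocab = 2, 8, 50257
    logits = torch.randn(batch, seq_len, vocab)
    labels = torch.randint(0, vocab, (batch, seq_len))
    loss = GPTLMLoss()(logits, labels)
    print(loss.item())  # roughly ln(vocab) for random logits
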
class PolicyLoss(nn.Module):
    """
    Policy Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Probability ratio pi_theta / pi_theta_old, computed in log space
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        # Clipped surrogate objective: maximize min(surr1, surr2), so minimize its negative
        loss = -torch.min(surr1, surr2)
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        loss = loss.mean()
        return loss
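
# Usage sketch (illustrative, toy shapes): log_probs come from the current actor,
# old_log_probs from the policy that generated the rollout, and advantages from the
# trainer (e.g. reward minus value baseline); action_mask marks generated tokens.
def _example_policy_loss() -> None:
    batch, num_actions = 2, 5
    log_probs = torch.randn(batch, num_actions)
    old_log_probs = log_probs + 0.1 * torch.randn(batch, num_actions)
    advantages = torch.randn(batch, num_actions)
    action_mask = torch.ones(batch, num_actions)
    loss = PolicyLoss(clip_eps=0.2)(log_probs, old_log_probs, advantages, action_mask)
    print(loss.item())
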
class ValueLoss(nn.Module):
    """
    Value Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                values: torch.Tensor,
                old_values: torch.Tensor,
                reward: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Clip the value update so it stays within clip_eps of the rollout-time estimate
        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
        surr1 = (values_clipped - reward)**2
        surr2 = (values - reward)**2
        loss = torch.max(surr1, surr2)
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        loss = loss.mean()
        return 0.5 * loss
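
# Usage sketch (illustrative, toy shapes): values are the current critic outputs,
# old_values the critic outputs recorded at rollout time, and reward the return target;
# clipping keeps the value update close to the old estimate.
def _example_value_loss() -> None:
    batch, num_actions = 2, 5
    old_values = torch.randn(batch, num_actions)
    values = old_values + 0.5 * torch.randn(batch, num_actions)
    reward = torch.randn(batch, num_actions)
    action_mask = torch.ones(batch, num_actions)
    loss = ValueLoss(clip_eps=0.4)(values, old_values, reward, action_mask)
    print(loss.item())
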
class PPOPtxActorLoss(nn.Module):
    """
    PPO-ptx Actor Loss: PPO policy loss combined with a pretraining language-model loss
    Details: https://arxiv.org/abs/2203.02155
    """

    def __init__(self,
                 policy_clip_eps: float = 0.2,
                 pretrain_coef: float = 0.0,
                 pretrain_loss_fn: Optional[nn.Module] = None) -> None:
        super().__init__()
        self.pretrain_coef = pretrain_coef
        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
        # Avoid instantiating a module as a default argument; fall back to GPTLMLoss
        self.pretrain_loss_fn = pretrain_loss_fn if pretrain_loss_fn is not None else GPTLMLoss()

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                lm_logits: torch.Tensor,
                lm_input_ids: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
        return policy_loss + self.pretrain_coef * lm_loss
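
# Usage sketch (illustrative, toy shapes): the ptx term mixes a pretraining LM loss on a
# separate pretraining batch into the PPO actor objective, weighted by pretrain_coef
# (gamma in the InstructGPT paper).
def _example_ppo_ptx_actor_loss() -> None:
    batch, num_actions, seq_len, vocab = 2, 5, 8, 100
    log_probs = torch.randn(batch, num_actions)
    old_log_probs = log_probs + 0.1 * torch.randn(batch, num_actions)
    advantages = torch.randn(batch, num_actions)
    action_mask = torch.ones(batch, num_actions)
    lm_logits = torch.randn(batch, seq_len, vocab)
    lm_input_ids = torch.randint(0, vocab, (batch, seq_len))
    loss_fn = PPOPtxActorLoss(policy_clip_eps=0.2, pretrain_coef=0.3)
    loss = loss_fn(log_probs, old_log_probs, advantages, lm_logits, lm_input_ids, action_mask)
    print(loss.item())
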
class LogSigLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2203.02155
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # log(sigmoid(x)) computed stably; avoids log(0) when the margin is very negative
        loss = -F.logsigmoid(chosen_reward - reject_reward).mean()
        return loss

class LogExpLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2204.05862
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
        return loss
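
# Usage sketch (illustrative): both pairwise losses take the reward-model scores of the
# chosen and rejected responses for the same prompt. Since -log(sigmoid(x)) == log(1 + exp(-x)),
# LogSigLoss and LogExpLoss compute the same quantity and differ only in how the two
# referenced papers write it.
def _example_pairwise_reward_loss() -> None:
    chosen_reward = torch.tensor([1.2, 0.3, -0.5])
    reject_reward = torch.tensor([0.4, 0.9, -1.0])
    print(LogSigLoss()(chosen_reward, reject_reward).item())
    print(LogExpLoss()(chosen_reward, reject_reward).item())  # same value up to float error
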