|
|
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
from peft import PeftModel |
|
from transformers import PreTrainedModel |
|
from trl import DPOTrainer as HFDPOTrainer |
|
from trl.trainer.utils import selective_log_softmax |
|
|
|
from swift.utils import get_logger |
|
from ..mixin import DataLoaderMixin, SwiftMixin |
|
from .rlhf_mixin import RLHFTrainerMixin |
|
|
|
# Remove TRL's DPOTrainer.__init__ so it is skipped in the cooperative __init__ chain;
# the Swift DPOTrainer below provides its own initializer.
del HFDPOTrainer.__init__
|
logger = get_logger() |
|
|
|
|
|
class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer): |
|
|
|
def __init__(self, |
|
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, |
|
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, |
|
*_args, |
|
**kwargs): |
|
from trl.trainer import FDivergenceConstants |
|
args = kwargs['args'] |
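        # Re-create the attributes that TRL's DPOTrainer.__init__ would normally set, since that
        # initializer was removed above.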
|
self.label_smoothing = args.label_smoothing |
|
self.loss_type = args.loss_type |
|
self.precompute_ref_log_probs = args.precompute_ref_log_probs |
|
self.f_divergence_type = args.f_divergence_type |
|
self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} |
|
self.is_peft_model = isinstance(model, PeftModel) |
|
|
|
self.ref_adapter_name = args.ref_adapter_name |
|
self.reference_free = args.reference_free |
|
self.use_weighting = False |
|
|
|
super().__init__(model, ref_model, *_args, **kwargs) |
|
|
|
def concatenated_forward( |
|
self, |
|
model: nn.Module, |
|
batch: Dict[str, Union[List, torch.LongTensor]], |
|
is_ref_model: bool = False |
|
    ) -> Dict[str, torch.Tensor]:
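        """Score a batch in which the chosen examples are concatenated ahead of the rejected ones.

        Runs a single forward pass and returns a dict of per-sequence log-probs and mean logits for
        the chosen and rejected halves, plus optional ``nll_loss`` / ``aux_loss`` entries.
        """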
|
batch = batch.copy() |
|
labels = batch.pop('labels', None) |
|
|
|
use_logits_to_keep = self.get_use_logits_to_keep(self.template.sequence_parallel_size == 1) |
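        # When supported, pass `logits_to_keep` so the model only materializes logits for the
        # labeled completion positions instead of the full prompt + completion sequence.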
|
if use_logits_to_keep: |
|
labels, logits_to_keep = self.get_logits_to_keep(labels) |
|
if logits_to_keep is not None: |
|
batch['logits_to_keep'] = logits_to_keep |
|
if self.aux_loss_enabled: |
|
batch['output_router_logits'] = True |
|
if self.is_encoder_decoder: |
|
batch['labels'] = labels |
|
position_ids = batch.pop('_position_ids', None) |
|
if position_ids is None: |
|
position_ids = batch.get('position_ids') |
|
outputs = model(**batch, use_cache=False) |
|
all_logits = outputs.logits |
|
|
|
if all_logits.shape[1] != labels.shape[1]: |
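            # Some multimodal models return logits for the full input (e.g. vision tokens placed
            # before the text), so keep only the trailing positions that align with the labels.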
|
|
|
|
|
all_logits = all_logits[:, -labels.shape[1]:] |
|
|
|
if not self.is_encoder_decoder and self.template.sequence_parallel_size == 1: |
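            # Shift the labels left by one so that the logits at position t are scored against token t+1.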
|
|
|
labels = torch.roll(labels, shifts=-1, dims=1) |
|
per_token_logps, mean_all_logits, loss_mask = self.get_per_token_logps( |
|
all_logits, labels, label_pad_token_id=self.label_pad_token_id) |
|
origin_per_token_logps = per_token_logps |
|
if self.loss_type == 'ipo': |
|
            # IPO averages the per-token log-probs over the completion length; keepdim keeps the
            # divisor broadcastable against the (batch, seq_len) per-token tensor.
            size_completion = loss_mask.sum(dim=-1, keepdim=True)
            per_token_logps = per_token_logps / size_completion
|
|
|
output = {} |
|
if self.template.padding_free: |
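            # Padding-free: the whole batch is packed into a single row, so recover the per-sequence
            # boundaries from the cumulative sequence lengths and sum log-probs per sequence.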
|
cu_seqlens = self.get_cu_seqlens(position_ids, batch.get('logits_to_keep')) |
|
all_logps = per_token_logps.new_zeros((cu_seqlens.shape[0] - 1, )) |
|
for i in range(cu_seqlens.shape[0] - 1): |
|
start, end = cu_seqlens[i], cu_seqlens[i + 1] |
|
all_logps[i] = per_token_logps[:, start:end].sum() |
|
num_examples = all_logps.shape[0] // 2 |
|
num_tokens = cu_seqlens[num_examples] |
|
if not is_ref_model: |
|
output['nll_loss'] = -origin_per_token_logps[:, :num_tokens][loss_mask[:, :num_tokens]].mean() |
|
output['chosen_logps'] = all_logps[:num_examples] |
|
output['rejected_logps'] = all_logps[num_examples:] |
|
output['mean_chosen_logits'] = mean_all_logits[:, :num_tokens][loss_mask[:, :num_tokens]].mean() |
|
output['mean_rejected_logits'] = mean_all_logits[:, num_tokens:][loss_mask[:, num_tokens:]].mean() |
|
else: |
|
all_logps = per_token_logps.sum(-1) |
|
num_examples = labels.shape[0] // 2 |
|
if not is_ref_model: |
|
output['nll_loss'] = -origin_per_token_logps[:num_examples][loss_mask[:num_examples]].mean() |
|
output['chosen_logps'] = all_logps[:num_examples] |
|
output['rejected_logps'] = all_logps[num_examples:] |
|
output['mean_chosen_logits'] = mean_all_logits[:num_examples][loss_mask[:num_examples]].mean() |
|
output['mean_rejected_logits'] = mean_all_logits[num_examples:][loss_mask[num_examples:]].mean() |
|
if self.aux_loss_enabled: |
|
output['aux_loss'] = outputs.aux_loss |
|
return output |
|
|
|
@staticmethod |
|
def get_per_token_logps( |
|
logits: torch.FloatTensor, |
|
labels: torch.LongTensor, |
|
label_pad_token_id=-100, |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
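        """Return the per-token log-probs of ``labels``, the per-position mean logits, and the loss mask."""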
|
if logits.shape[:-1] != labels.shape: |
|
            raise ValueError(f'Logits (batch and sequence length dim) {logits.shape[:-1]} '
                             f'and labels must have the same shape {labels.shape}')
|
loss_mask = labels != label_pad_token_id |
|
labels = labels.clone() |
|
        labels[~loss_mask] = 0  # dummy value for masked positions; their log-probs are zeroed out below
|
|
|
|
|
per_token_logps = selective_log_softmax(logits, labels) |
|
per_token_logps[~loss_mask] = 0 |
|
return per_token_logps, logits.mean(-1), loss_mask |
|
|
|
def training_step(self, model, inputs, *args, **kwargs): |
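        # Keep a copy of position_ids under a private key; concatenated_forward prefers `_position_ids`
        # when recovering the packed-sequence boundaries.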
|
inputs['_position_ids'] = inputs.get('position_ids') |
|
with self.template.forward_context(self.model, inputs): |
|
return super().training_step(model, inputs, *args, **kwargs) |
|
|
|
def prediction_step(self, model, inputs, *args, **kwargs): |
|
inputs['_position_ids'] = inputs.get('position_ids') |
|
with self.template.forward_context(self.model, inputs): |
|
return super().prediction_step(model, inputs, *args, **kwargs) |
|
|