# Custom DPO trainer for llmtuner (ref-model preparation handled manually).
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

import torch
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model

from llmtuner.extras.constants import IGNORE_INDEX

if TYPE_CHECKING:
    from transformers import PreTrainedModel
class CustomDPOTrainer(DPOTrainer):
    r"""
    DPO trainer that bypasses ``trl.DPOTrainer.__init__`` and initializes the
    plain ``transformers.Trainer`` instead, wiring up the reference model and
    DPO bookkeeping attributes by hand.

    Args:
        beta: Temperature of the DPO loss.
        model: The policy model being optimized.
        ref_model: Frozen reference model; may be ``None`` (e.g. when the
            policy itself provides reference log-probs via adapters).
        disable_dropout: If True, turn off all dropout in both models so that
            log-probabilities are deterministic.
        **kwargs: Forwarded verbatim to ``transformers.Trainer.__init__``.
    """

    def __init__(
        self,
        beta: float,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
        disable_dropout: Optional[bool] = True,
        **kwargs
    ):
        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
                disable_dropout_in_model(ref_model)

        # Attributes trl's DPO loss/metric code expects to find on the trainer.
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.ref_model = ref_model
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.beta = beta
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # Deliberately skip DPOTrainer.__init__ (it would rebuild the data
        # collator); go straight to the base Trainer.
        Trainer.__init__(self, model=model, **kwargs)
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        if ref_model is not None:
            if self.is_deepspeed_enabled:
                # _prepare_deepspeed returns a tuple; unpack its single element.
                self.ref_model, = self.accelerator._prepare_deepspeed(self.ref_model)
                self.ref_model.eval()
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

    def concatenated_forward(
        self,
        model: Optional[torch.nn.Module] = None,
        batch: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        r"""
        Run one forward pass over a batch whose first half is the "chosen"
        responses and second half the "rejected" ones, and split the results.

        Returns:
            ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits)``.
        """
        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error

        all_logits = model(
            input_ids=batch_copied["input_ids"],
            attention_mask=batch_copied["attention_mask"],
            return_dict=True
        ).logits.to(torch.float32)

        # Sum (not average) of per-token log-probs, ignoring label padding.
        all_logps = self._get_batch_logps(
            all_logits,
            batch["labels"],
            average_log_prob=False
        )
        # Batch layout assumed: [chosen..., rejected...] concatenated along dim 0.
        batch_size = batch["input_ids"].size(0) // 2
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        return chosen_logps, rejected_logps, chosen_logits, rejected_logits