import inspect
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import PreTrainedModel
from trl.models.utils import prepare_deepspeed


class RLHFTrainerMixin:
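    """Shared plumbing for building swift's RLHF trainers on top of TRL.

    The mixin is meant to sit in front of a TRL trainer in the MRO, roughly
    ``class DPOTrainer(RLHFTrainerMixin, trl.DPOTrainer)`` (illustrative sketch;
    the concrete base classes live elsewhere in swift). It wires up the reference
    model, optionally disables dropout, prepares the reference model for
    DeepSpeed/accelerate, and adapts TRL's ``concatenated_forward`` to batches
    that swift's template has already collated.
    """
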
    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 *_args,
                 **kwargs):
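        """Set up RLHF-specific state, then delegate to the underlying trainer.

        Assumes ``kwargs`` carries at least ``args`` (the training arguments, read
        for ``beta``, ``disable_dropout`` and ``generate_during_eval``) and
        ``template`` (swift's template, used for ``is_encoder_decoder``); both are
        forwarded unchanged to ``super().__init__``.
        """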
        from trl.trainer import disable_dropout_in_model
        from swift.llm import HfConfigFactory
        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        args = kwargs['args']
        self.beta = getattr(args, 'beta', 0.0)
        if getattr(args, 'disable_dropout', False):
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        self.is_encoder_decoder = kwargs['template'].is_encoder_decoder
        self.aux_loss_enabled = getattr(model.config, 'output_router_logits', False)
        self._peft_has_been_casted_to_bf16 = False
        self.generate_during_eval = getattr(args, 'generate_during_eval', False)
        if self.is_encoder_decoder:
            self.decoder_start_token_id = HfConfigFactory.get_config_attr(model.config, 'decoder_start_token_id')
            self.pad_token_id = HfConfigFactory.get_config_attr(model.config, 'pad_token_id')

        # Compatibility defaults expected by TRL's trainers.
        self.is_vision_model = False
        self.label_pad_token_id = -100
        self.use_dpo_data_collator = True
        super().__init__(model, *_args, **kwargs)
        if ref_model is not None:
            # The reference model is only used for inference, so wrap it for
            # DeepSpeed or prepare it with accelerate in evaluation mode.
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        self.padding_value = self.tokenizer.pad_token_id

    def get_train_dataloader(self, *args, **kwargs):
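        """Build the train dataloader, then re-bind its ``worker_init_fn``.

        If the underlying dataloader's ``worker_init_fn`` takes a ``num_workers``
        parameter and has not already been wrapped in a ``functools.partial``, it
        is re-bound with the configured ``dataloader_num_workers`` and this
        process's rank, so the function sees the distributed context it expects.
        """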
        train_dataloader = super().get_train_dataloader(*args, **kwargs)
        base_dataloader = train_dataloader.base_dataloader if hasattr(
            train_dataloader, 'base_dataloader') and isinstance(train_dataloader.base_dataloader,
                                                                DataLoader) else train_dataloader
        if base_dataloader.worker_init_fn is not None and not isinstance(
                base_dataloader.worker_init_fn, partial) and 'num_workers' in inspect.signature(
                    base_dataloader.worker_init_fn).parameters:
            base_dataloader.worker_init_fn = partial(
                base_dataloader.worker_init_fn,
                num_workers=self.args.dataloader_num_workers,
                rank=self.args.process_index)
        return train_dataloader

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
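        """Run the model once on the pre-collated batch and hand the result to TRL.

        swift's template already provides a fully collated (concatenated) batch, so
        instead of letting TRL rebuild the inputs, this method computes the forward
        pass itself and then temporarily patches ``self.concatenated_inputs`` and
        the model's ``__call__`` to return the precomputed kwargs and outputs while
        ``super().concatenated_forward`` runs. This reuses TRL's loss bookkeeping
        without a second forward pass.
        """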
        model_kwargs = batch.copy()
        labels = model_kwargs.pop('labels', None)
        if self.is_encoder_decoder:
            model_kwargs['labels'] = labels

        if self.aux_loss_enabled:
            model_kwargs['output_router_logits'] = True
        outputs = model(**model_kwargs, use_cache=False)
        model_kwargs['labels'] = labels
        # Half the concatenated batch size; TRL only inspects the shape of
        # ``chosen_labels``, the values are never read.
        model_kwargs['chosen_labels'] = torch.zeros(model_kwargs['labels'].shape[0] // 2)
        if outputs.logits.shape[1] != labels.shape[1]:
            # Some multimodal models return logits for the full sequence, including
            # tokens (e.g. image placeholders) prepended before the text; keep only
            # the positions that align with the labels.
            outputs.logits = outputs.logits[:, -labels.shape[1]:]
        for key in ['input_ids', 'attention_mask', 'labels']:
            model_kwargs[f'concatenated_{key}'] = model_kwargs.pop(key, None)
        if self.__class__.__name__ == 'ORPOTrainer':
            # TRL's ORPOTrainer derives its labels from ``concatenated_input_ids``,
            # so route the labels through that key.
            model_kwargs['concatenated_input_ids'] = model_kwargs['concatenated_labels']

        @contextmanager
        def _patch_concatenated_forward():
            _old_concatenated_inputs = self.concatenated_inputs
            _old_model_call = model.__class__.__call__
            self.concatenated_inputs = lambda *args, **kwargs: model_kwargs
            model.__class__.__call__ = lambda *args, **kwargs: outputs
            try:
                yield
            finally:
                self.concatenated_inputs = _old_concatenated_inputs
                model.__class__.__call__ = _old_model_call

        with _patch_concatenated_forward():
            return super().concatenated_forward(model, model_kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
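        """Compute the loss, correcting for gradient accumulation when needed.

        In recent transformers versions, when ``num_items_in_batch`` is passed and
        the model accepts loss kwargs, the Trainer skips its own division by
        ``gradient_accumulation_steps`` (it expects token-normalised losses). The
        RLHF losses from TRL are not normalised that way, so the division is
        applied here instead.
        """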
        res = super().compute_loss(model, inputs, return_outputs=return_outputs)

        if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
            loss = res[0] if return_outputs else res
            loss = loss / self.args.gradient_accumulation_steps
            return (loss, res[1:]) if return_outputs else loss
        return res

    def _get_train_sampler(self, train_dataset=None):
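        """Call the parent sampler factory with a version-appropriate signature.

        Newer transformers versions pass ``train_dataset`` to ``_get_train_sampler``
        while older ones take no arguments, so inspect the parent signature and only
        forward the argument when it is accepted.
        """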
        get_train_sampler = super()._get_train_sampler
        parameters = inspect.signature(get_train_sampler).parameters
        kwargs = {'train_dataset': train_dataset} if 'train_dataset' in parameters else {}
        return get_train_sampler(**kwargs)