# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import PreTrainedModel
from trl.models.utils import prepare_deepspeed


class RLHFTrainerMixin:
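    """Shared glue between TRL's preference trainers (DPO, ORPO, ...) and swift.

    Subclasses combine it with a concrete TRL trainer via multiple inheritance,
    e.g. ``class DPOTrainer(RLHFTrainerMixin, SwiftMixin, HFDPOTrainer)``.
    """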
def __init__(self,
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
*_args,
**kwargs):
from trl.trainer import disable_dropout_in_model
from swift.llm import HfConfigFactory
self.ref_model = ref_model
self._stored_metrics = defaultdict(lambda: defaultdict(list))
args = kwargs['args']
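        # `beta` weights the preference-loss regularization term (e.g. DPO's implicit-reward scale).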
self.beta = getattr(args, 'beta', 0.0)
if getattr(args, 'disable_dropout', False):
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)
self.is_encoder_decoder = kwargs['template'].is_encoder_decoder
self.aux_loss_enabled = getattr(model.config, 'output_router_logits', False)
self._peft_has_been_casted_to_bf16 = False
self.generate_during_eval = getattr(args, 'generate_during_eval', False)
if self.is_encoder_decoder:
self.decoder_start_token_id = HfConfigFactory.get_config_attr(model.config, 'decoder_start_token_id')
self.pad_token_id = HfConfigFactory.get_config_attr(model.config, 'pad_token_id')
        # not used by swift; kept for compatibility with TRL internals
self.is_vision_model = False
self.label_pad_token_id = -100
self.use_dpo_data_collator = True
super().__init__(model, *_args, **kwargs)
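        # The reference model is prepared after super().__init__ so self.accelerator exists:
        # under DeepSpeed it gets a dedicated inference wrapper, otherwise plain eval-mode preparation.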
if ref_model is not None:
if self.is_deepspeed_enabled:
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.padding_value = self.tokenizer.pad_token_id

    def get_train_dataloader(self, *args, **kwargs):
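        # swift may wrap the torch DataLoader; unwrap to the base DataLoader, then bind
        # num_workers/rank into worker_init_fn when the init fn expects them and is not
        # already bound (the `partial` check avoids double-wrapping).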
train_dataloader = super().get_train_dataloader(*args, **kwargs)
        if hasattr(train_dataloader, 'base_dataloader') and isinstance(train_dataloader.base_dataloader, DataLoader):
            base_dataloader = train_dataloader.base_dataloader
        else:
            base_dataloader = train_dataloader
if base_dataloader.worker_init_fn is not None and not isinstance(
base_dataloader.worker_init_fn, partial) and 'num_workers' in inspect.signature(
base_dataloader.worker_init_fn).parameters:
base_dataloader.worker_init_fn = partial(
base_dataloader.worker_init_fn,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index)
return train_dataloader

    def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
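        """Run one forward pass over the pre-concatenated chosen/rejected batch.

        swift's collator already stacks the pairs, so the model is called once here and
        TRL's ``concatenated_forward`` is then replayed against the cached outputs.
        """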
model_kwargs = batch.copy()
labels = model_kwargs.pop('labels', None)
if self.is_encoder_decoder:
model_kwargs['labels'] = labels
if self.aux_loss_enabled:
model_kwargs['output_router_logits'] = True
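        # Single forward pass; the batch already holds the chosen and rejected halves stacked.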
outputs = model(**model_kwargs, use_cache=False)
model_kwargs['labels'] = labels
        model_kwargs['chosen_labels'] = torch.zeros(model_kwargs['labels'].shape[0] // 2)  # dummy; TRL reads only shape[0]
if outputs.logits.shape[1] != labels.shape[1]:
# for llava, the model returns logits for the entire sequence, including the image tokens
# (placed before the text tokens)
outputs.logits = outputs.logits[:, -labels.shape[1]:]
for key in ['input_ids', 'attention_mask', 'labels']:
model_kwargs[f'concatenated_{key}'] = model_kwargs.pop(key, None)
        if self.__class__.__name__ == 'ORPOTrainer':  # ORPO clones input_ids as labels; pass the labels through
model_kwargs['concatenated_input_ids'] = model_kwargs['concatenated_labels']
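        # Patch points so that super().concatenated_forward reuses the outputs computed above:
        # concatenated_inputs returns our kwargs unchanged, and the model call replays the
        # cached outputs instead of running a second forward pass.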
@contextmanager
def _patch_concatenated_forward():
_old_concatenated_inputs = self.concatenated_inputs
_old_model_call = model.__class__.__call__
self.concatenated_inputs = lambda *args, **kwargs: model_kwargs
model.__class__.__call__ = lambda *args, **kwargs: outputs
try:
yield
finally:
self.concatenated_inputs = _old_concatenated_inputs
model.__class__.__call__ = _old_model_call
with _patch_concatenated_forward():
return super().concatenated_forward(model, model_kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
res = super().compute_loss(model, inputs, return_outputs=return_outputs)
        # compat transformers>=4.46: when the model accepts loss kwargs, the Trainer skips the
        # division by gradient_accumulation_steps, so rescale the loss here to keep parity
if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
loss = res[0] if return_outputs else res
loss = loss / self.args.gradient_accumulation_steps
return (loss, res[1:]) if return_outputs else loss
return res

    def _get_train_sampler(self, train_dataset=None):
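        # Newer transformers releases added a `train_dataset` parameter to _get_train_sampler;
        # inspect the parent's signature so both old and new versions work.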
get_train_sampler = super()._get_train_sampler
parameters = inspect.signature(get_train_sampler).parameters
kwargs = {'train_dataset': train_dataset} if 'train_dataset' in parameters else {}
return get_train_sampler(**kwargs)