from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from peft import PeftModel
from transformers import PreTrainedModel
from trl import KTOTrainer as HFKTOTrainer

from swift.utils import get_logger
from ..mixin import SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin

logger = get_logger()

|
|
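# Drop the upstream __init__ so that attribute lookup falls through to the
# base Trainer, letting RLHFTrainerMixin/SwiftMixin drive construction via
# the MRO instead of trl's KTOTrainer.__init__.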
del HFKTOTrainer.__init__


class KTOTrainer(RLHFTrainerMixin, SwiftMixin, HFKTOTrainer):

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 *_args,
                 **kwargs):
        args = kwargs['args']
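        # KTO compares policy and reference log-probs, so dropout is disabled
        # to keep the two forward passes deterministic and comparable.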
        args.disable_dropout = True
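        # Relative loss weights for desirable (chosen) and undesirable
        # (rejected) examples, used to balance imbalanced KTO data.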
        self.desirable_weight = args.desirable_weight
        self.undesirable_weight = args.undesirable_weight
        self.precompute_ref_log_probs = args.precompute_ref_log_probs
        self.is_peft_model = isinstance(model, PeftModel)
        self.loss_type = getattr(args, 'loss_type', 'kto')

        self.ref_adapter_name = None

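        # The unpaired APO loss has no KL term, so the extra forward pass
        # over the KL completions can be skipped for it.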
        self.calculate_KL = True
        if self.loss_type in ['apo_zero_unpaired']:
            self.calculate_KL = False
        super().__init__(model, ref_model, *_args, **kwargs)

    def forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
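        # trl's KTOTrainer.forward calls the model once on the KL completions
        # (when calculate_KL is enabled) and once on the target completions.
        # The pre-forward hook below swaps the matching tensors from `batch`
        # into each call, alternating via the `is_kl` flag.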
        is_kl = self.calculate_KL  # with no KL pass, the first call is the completion pass

        def _add_data_hook(model, args, kwargs):
            nonlocal is_kl
            if is_kl:
                kwargs = {k[len('KL_completion_'):]: v for k, v in batch.items() if k.startswith('KL_completion_')}
            else:
                kwargs = {k[len('completion_'):]: v for k, v in batch.items() if k.startswith('completion_')}
            is_kl = not is_kl
            # A pre-hook registered with with_kwargs=True may return an
            # (args, kwargs) pair to replace the inputs of the hooked call.
            return (), kwargs

        @contextmanager
        def _patch_model_call():
            # prepend=True makes this hook fire before any previously
            # registered pre-hooks on the model.
            handle = model.register_forward_pre_hook(_add_data_hook, with_kwargs=True, prepend=True)
            try:
                yield
            finally:
                handle.remove()

        with _patch_model_call():
            return super().forward(model, batch)
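
# A minimal usage sketch (illustrative only; swift's RLHF pipeline normally
# constructs this trainer, and `model`, `ref_model`, `training_args`,
# `train_dataset`, and `data_collator` below are assumed to be prepared
# elsewhere):
#
#     trainer = KTOTrainer(
#         model=model,                  # policy model, optionally a PeftModel
#         ref_model=ref_model,          # frozen reference model
#         args=training_args,           # carries desirable_weight, loss_type, ...
#         train_dataset=train_dataset,
#         data_collator=data_collator,
#     )
#     trainer.train()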