import os
import torch
from typing import Dict, Optional, Sequence

from transformers import Trainer, DataCollatorWithPadding
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
from transformers.tokenization_utils import PreTrainedTokenizer

from .config import FinetuningArguments

from .other import (
    get_logger,
    save_trainable_params,
    save_valuehead_params,
    FINETUNING_ARGS_NAME
)


logger = get_logger(__name__)


class PairwiseDataCollatorForChatGLM(DataCollatorWithPadding):
    r"""
    Data collator for ChatGLM. It is capable of dynamically padding batched data.

    Inspired by: https://github.com/tatsu-lab/stanford_alpaca/blob/65512697dc67779a6e53c267488aba0ec4d7c02a/train.py#L156
    """

    def __init__(
            self,
            tokenizer: PreTrainedTokenizer,
            inference_mode: bool = False
    ):
        super().__init__(tokenizer, padding=True)
        self.inference_mode = inference_mode

    def __call__(self, features: Sequence[Dict[str, Sequence]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch. We adopt right-padding for pairwise data.

        We generate 2 * n examples, where the first n examples represent the chosen examples and
        the last n examples represent the rejected examples.

        ChatGLM is able to generate attention masks and position ids by itself.
        """
        if self.inference_mode:
            raise NotImplementedError

        # Concatenate so that the chosen sequences come first, followed by the rejected ones.
        accept_ids, reject_ids = [[torch.tensor(feature[key]) for feature in features] for key in ("accept_ids", "reject_ids")]
        input_ids = accept_ids + reject_ids
        # Right-pad every sequence to the longest one in the doubled batch.
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        features = {"input_ids": input_ids}
        return features
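
# Usage sketch (illustrative only; the feature dicts and token ids below are hypothetical and
# would normally come from the dataset preprocessing step, which must provide "accept_ids"
# and "reject_ids" token id lists per example):
#
#   collator = PairwiseDataCollatorForChatGLM(tokenizer)
#   batch = collator([
#       {"accept_ids": [5, 6, 7, 2], "reject_ids": [5, 8, 2]},
#       {"accept_ids": [9, 2], "reject_ids": [9, 10, 11, 2]},
#   ])
#   # batch["input_ids"] has shape (4, 4): the two chosen sequences come first,
#   # followed by the two rejected ones, all right-padded with tokenizer.pad_token_id.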


class PairwiseTrainerForChatGLM(Trainer):
    r"""
    Inherits Trainer to compute pairwise loss.
    """

    def __init__(self, finetuning_args: FinetuningArguments, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.finetuning_args = finetuning_args

    def compute_loss(self, model, inputs, return_outputs=False):
        r"""
        Computes the pairwise loss. The first n examples are chosen and the last n examples are rejected.

        We use the score on the EOS token to represent the reward of the whole sentence.
        """
        batch_size = inputs["input_ids"].size(0) // 2
        _, _, values = model(input_ids=inputs["input_ids"])
        # Transpose values to (batch_size, seq_len) so they align with input_ids, then take the
        # score at each sequence's EOS token as the sentence-level reward.
        rewards = values.transpose(0, 1)[(inputs["input_ids"] == self.tokenizer.eos_token_id).nonzero(as_tuple=True)]
        r_accept, r_reject = rewards.split(batch_size, dim=0)
        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
        if return_outputs:
            return loss, {"r_accept": r_accept, "r_reject": r_reject}
        return loss
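
    # Note on the loss above: this is the standard pairwise ranking objective used for reward
    # models in InstructGPT-style RLHF, i.e.
    #     loss = -mean(log(sigmoid(r_accept - r_reject)))
    # which pushes the reward of the chosen response above the reward of the rejected one.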

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
        r"""
        Saves trainable parameters as model checkpoints. Use `self.model.pretrained_model` to refer to the backbone model.

        This function will only be executed at process zero.

        Override to inject custom behavior.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        model_to_save = unwrap_model(self.model)
        if hasattr(model_to_save.pretrained_model, "peft_config"): # peft methods
            model_to_save.pretrained_model.save_pretrained(output_dir) # save lora weights
        else: # non-peft methods
            save_trainable_params(output_dir, model_to_save.pretrained_model)
        if hasattr(model_to_save, "v_head"):
            save_valuehead_params(output_dir, model_to_save.v_head) # save valuehead weights
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        torch.save(self.finetuning_args, os.path.join(output_dir, FINETUNING_ARGS_NAME))
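

# Wiring sketch (hypothetical variable names; the actual training entry point lives elsewhere in
# this repository). The model is assumed to be the ChatGLM backbone wrapped with a value head,
# so that `model(input_ids=...)` returns a third output containing per-token value scores:
#
#   collator = PairwiseDataCollatorForChatGLM(tokenizer)
#   trainer = PairwiseTrainerForChatGLM(
#       finetuning_args=finetuning_args,
#       model=model,
#       args=training_args,
#       train_dataset=train_dataset,   # each example provides "accept_ids" and "reject_ids"
#       tokenizer=tokenizer,
#       data_collator=collator
#   )
#   trainer.train()
#   trainer.save_model()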