import os
import json
import math
import torch
from tqdm import tqdm
from typing import Callable, Dict, List, Literal, Optional, Sequence, Tuple

from transformers import DataCollatorWithPadding, Seq2SeqTrainingArguments
from transformers.trainer import TRAINING_ARGS_NAME, TRAINER_STATE_NAME
from transformers.tokenization_utils import PreTrainedTokenizer

from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits

from .config import FinetuningArguments

from .other import (
    AverageMeter,
    get_logger,
    save_trainable_params,
    save_valuehead_params,
    get_logits_processor,
    FINETUNING_ARGS_NAME
)


logger = get_logger(__name__)


def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
    if target == "reward": # save the default head temporarily
        valuehead_state_dict = model.v_head.state_dict()
        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])

    model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
    model.v_head.load_state_dict({
        "summary.weight": getattr(model, "{}_head_weight".format(target)),
        "summary.bias": getattr(model, "{}_head_bias".format(target))
    })


@torch.no_grad()
def compute_rewards(
        input_ids: torch.Tensor, # (batch size x seq len) with format `X [gMASK] [BOS] Y [EOS] [PAD] ... [PAD]`
        model: AutoModelForCausalLMWithValueHead,
        tokenizer: PreTrainedTokenizer
) -> torch.Tensor:
    replace_model(model, target="reward")

    _, _, values = model(input_ids=input_ids)
    values = values.transpose(0, 1)

    rewards = []
    for i in range(input_ids.size(0)):
        eos_idx = (input_ids[i] == tokenizer.eos_token_id).nonzero() # Note: checking with the [EOS] token is unsafe
        if len(eos_idx):
            eos_idx = eos_idx[0].item()
        else:
            eos_idx = input_ids.size(1) - 1
        rewards.append(values[i][eos_idx])
    rewards = torch.stack(rewards, dim=0)

    replace_model(model, target="default")
    return rewards


def cast_layernorm_dtype(
        model: AutoModelForCausalLMWithValueHead,
        layer_norm_names: List[str] = ["layernorm"], # for the chatglm setting
        layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:

    layer_norm_state_dict = {}

    for name, param in model.named_parameters():
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
            if layer_norm_params is not None:
                param.data = layer_norm_params[name] # restore the stored float32 weights
            else:
                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
                param.data = param.data.to(torch.float16)

    return model, layer_norm_state_dict

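
# Illustrative sketch (not executed): how the helpers above are expected to interact
# during a PPO step. The call below is an example, not part of this module's API;
# only `replace_model` and `compute_rewards` are real names defined here.
#
#   rewards = compute_rewards(batch["input_ids"], model, tokenizer)   # -> tensor of shape (batch_size,)
#
# Internally, `compute_rewards` switches the LoRA adapter and value head to the reward
# model via `replace_model(model, target="reward")`, reads the value-head output at the
# first [EOS] position of each sequence (or the last position if no [EOS] is found),
# and then restores the default adapter and head before returning.
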

class PPODataCollatorForChatGLM(DataCollatorWithPadding):
    r"""
    Data collator for ChatGLM. It is capable of dynamically padding batched data.
    """
    def __init__(
            self,
            tokenizer: PreTrainedTokenizer,
            min_input_length: int,
            max_input_length: int,
            inference_mode: bool = False
    ):
        super().__init__(tokenizer, padding=True)
        self.inference_mode = inference_mode

        if min_input_length < max_input_length:
            self.input_size = LengthSampler(min_input_length, max_input_length)
        else:
            self.input_size = lambda: max_input_length # always use max_input_length

    def __call__(self, features: Sequence[Dict[str, Sequence]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch. We adopt left-padding for PPO data.

        Equips with a length sampler to generate sequences with variable lengths.

        ChatGLM is able to generate attention masks and position ids by itself.
        """
        if self.inference_mode:
            raise NotImplementedError

        input_ids = [torch.tensor(feature["input_ids"][:self.input_size()]).flip(0) for feature in features]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        features = {"input_ids": input_ids.flip(-1)} # flip, right-pad, then flip back to obtain left-padding
        return features


class PPOTrainerForChatGLM(PPOTrainer):
    r"""
    Inherits PPOTrainer.
    """
    def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.state = {"log_history": []}
        self.training_args = training_args
        self.finetuning_args = finetuning_args

    @torch.no_grad()
    def generate(
            self,
            query_tensor: torch.Tensor, # (batch size x seq len)
            length_sampler: Callable = None,
            return_prompt: bool = True,
            **generation_kwargs
    ) -> torch.Tensor:
        r"""
        Generates responses with the model given the query tensor.

        Inspired by: https://github.com/lvwerra/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/trl/trainer/ppo_trainer.py#L387
        """
        self.model, layer_norm_params = cast_layernorm_dtype(self.model)

        if length_sampler is not None:
            generation_kwargs["max_new_tokens"] = length_sampler()

        unwrapped_model = self.accelerator.unwrap_model(self.model)

        response = unwrapped_model.generate(input_ids=query_tensor, **generation_kwargs)

        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
        # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
        if unwrapped_model.pretrained_model.generation_config._from_model_config:
            unwrapped_model.pretrained_model.generation_config._from_model_config = False

        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params=layer_norm_params)

        if not return_prompt and not self.is_encoder_decoder:
            return response[:, query_tensor.size(1):]
        return response

    def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor) -> Dict[str, torch.Tensor]:
        input_ids = []
        for query, response in zip(queries, responses): # query is left-padded, response is right-padded
            start = (query != self.tokenizer.pad_token_id).nonzero()[0].item()
            input_ids.append(torch.cat((query[start:], response, query[:start]))) # change to right-padding

        model_inputs = {"input_ids": torch.stack(input_ids, dim=0).to(self.current_device)} # already padded to equal length
        model_inputs["attention_mask"] = torch.ones_like(model_inputs["input_ids"]) # unused indeed, avoids a distributed error

        return model_inputs

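    # Illustrative sketch (not executed): what `prepare_model_inputs` above does with padding.
    # The token ids are made up for the example; `P` stands for the pad token id.
    #
    #   query    (left-padded):  [P, P, x1, x2, x3]
    #   response (right-padded): [y1, y2, y3, P]
    #   concatenated input:      [x1, x2, x3, y1, y2, y3, P, P, P]   # right-padded
    #
    # Moving the leading pad tokens of the query to the end keeps every sequence at the same
    # length, so the batch can be stacked without re-padding.
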
""" bs = len(queries) fbs = self.config.mini_batch_size all_logprobs = [] all_logits = [] all_masks = [] all_values = [] for i in range(int(bs / fbs)): input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} input_ids = input_kwargs["input_ids"] logits, _, values = model(input_ids=input_ids) # chatglm only needs input_ids logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) values = values.transpose(0, 1) masks = torch.zeros_like(input_ids) for j in range(fbs): start = (input_ids[j] == self.tokenizer.bos_token_id).nonzero()[0].item() # always contain a [BOS] token end = (input_ids[j] == self.tokenizer.eos_token_id).nonzero() # Note: checking with [EOS] token is unsafe if len(end): end = end[0].item() else: end = masks.size(1) masks[j][start:end] = 1 if end - start < 2: raise ValueError("Responses are too short. Make sure they are at least 4 tokens long.") all_logits.append(logits) all_values.append(values) all_logprobs.append(logprobs) all_masks.append(masks) return ( torch.cat(all_logprobs), torch.cat(all_logits)[:, :-1], torch.cat(all_values)[:, :-1], torch.cat(all_masks)[:, :-1], ) def ppo_train(self, max_target_length: int) -> None: total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.training_args.world_size len_dataloader = len(self.dataloader) num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1) num_examples = len(self.dataset) num_train_epochs = self.training_args.num_train_epochs max_steps = math.ceil(num_train_epochs * num_steps_per_epoch) if self.is_world_process_zero(): logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples}") logger.info(f" Num Epochs = {num_train_epochs}") logger.info(f" Instantaneous batch size per device = {self.config.batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") logger.info(f" Gradient Accumulation steps = {self.config.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {max_steps}") logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}") # Keyword arguments for `model.generate` gen_kwargs = { "top_k": 0.0, "top_p": 1.0, "do_sample": True, "pad_token_id": self.tokenizer.pad_token_id, "eos_token_id": self.tokenizer.eos_token_id, "logits_processor": get_logits_processor() } output_length_sampler = LengthSampler(max_target_length // 2, max_target_length) unwrapped_model = self.accelerator.unwrap_model(self.model) dataiter = iter(self.dataloader) steps_trained = 0 loss_meter = AverageMeter() reward_meter = AverageMeter() for step in tqdm(range(max_steps)): for _ in range(self.config.gradient_accumulation_steps): batch = next(dataiter) steps_trained += 1 queries = batch["input_ids"] # left-padded sequences unwrapped_model.gradient_checkpointing_disable() unwrapped_model.config.use_cache = True # Get response from ChatGLM responses_with_queries = self.generate(queries, length_sampler=output_length_sampler, **gen_kwargs) responses = responses_with_queries[:, queries.size(1):].clone().detach() # right-padded sequences (remember to clone!!!) 
# batch["response"] = tokenizer.batch_decode(responses, skip_special_tokens=True) # comment to avoid decode error for i in range(responses_with_queries.size(0)): # change to right-padding start = (responses_with_queries[i] != self.tokenizer.pad_token_id).nonzero()[0].item() responses_with_queries[i] = torch.cat((responses_with_queries[i][start:], responses_with_queries[i][:start])) # Compute rewards rewards = compute_rewards(responses_with_queries, unwrapped_model, self.tokenizer) # Run PPO step unwrapped_model.gradient_checkpointing_enable() unwrapped_model.config.use_cache = False split_into_list = lambda x: [x[i] for i in range(x.size(0))] stats = self.step(*map(split_into_list, [queries, responses, rewards])) loss_meter.update(stats["ppo/loss/total"]) reward_meter.update(rewards.sum().item(), n=rewards.size(0)) if steps_trained == len_dataloader: dataiter = iter(self.dataloader) steps_trained = 0 if self.is_world_process_zero() and (step+1) % self.training_args.logging_steps == 0: logs = { "loss": round(loss_meter.avg, 4), "reward": round(reward_meter.avg, 4), "learning_rate": stats["ppo/learning_rate"], "epoch": round(step / num_steps_per_epoch, 2) } print(logs) logs["step"] = step self.state["log_history"].append(logs) loss_meter.reset() reward_meter.reset() if (step+1) % self.training_args.save_steps == 0: # save checkpoint self.save_model(os.path.join(self.training_args.output_dir, f"checkpoint-{step+1}")) def is_world_process_zero(self) -> bool: r""" Whether or not this process is the global main process (when training in a distributed fashion on several machines, this is only going to be `True` for one process). """ return self.training_args.process_index == 0 def save_state(self, output_dir: Optional[str] = None) -> None: r""" Saves trainer state. """ if not self.is_world_process_zero(): return output_dir = output_dir if output_dir is not None else self.training_args.output_dir os.makedirs(output_dir, exist_ok=True) json.dump(self.state, open(os.path.join(output_dir, TRAINER_STATE_NAME), "w", encoding="utf-8", newline="\n"), indent=2) def save_model(self, output_dir: Optional[str] = None) -> None: r""" Saves trainable parameters as model checkpoints. We use `self.model.pretrained_model` to refer to the backbone model. Override to inject custom behavior. """ if not self.is_world_process_zero(): return output_dir = output_dir if output_dir is not None else self.training_args.output_dir os.makedirs(output_dir, exist_ok=True) logger.info(f"Saving model checkpoint to {output_dir}") unwrapped_model = self.accelerator.unwrap_model(self.model) if hasattr(unwrapped_model.pretrained_model, "peft_config"): # peft methods unwrapped_model.pretrained_model.save_pretrained(output_dir) # save lora weights else: # non-peft methods save_trainable_params(output_dir, unwrapped_model.pretrained_model) if hasattr(unwrapped_model, "v_head"): save_valuehead_params(output_dir, unwrapped_model.v_head) # save valuehead weights torch.save(self.training_args, os.path.join(output_dir, TRAINING_ARGS_NAME)) torch.save(self.finetuning_args, os.path.join(output_dir, FINETUNING_ARGS_NAME))