import os
import sys
from typing import Dict, List, Optional

import fire
import torch
import transformers
from datasets import load_dataset
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,  # NOTE(review): removed in newer peft releases; unused here
    prepare_model_for_kbit_training,
    set_peft_model_state_dict,
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
    LlamaTokenizerFast,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

from utils.prompter import Prompter


class SavePeftModelCallback(transformers.TrainerCallback):
    """Save only the PEFT adapter into Trainer checkpoint folders.

    The adapter is written to ``<checkpoint_folder>/adapter_model`` and any
    ``pytorch_model.bin`` in the checkpoint folder is deleted afterwards.
    """

    def save_model(self, args, state, kwargs, use_best_checkpoint=True):
        """Save the adapter of ``kwargs["model"]`` into a checkpoint folder.

        Args:
            args: the run's ``TrainingArguments`` (``output_dir`` is read).
            state: the current ``TrainerState``.
            kwargs: callback keyword arguments; must contain ``"model"``.
            use_best_checkpoint: when True and ``state.best_model_checkpoint``
                is set, save into that folder; otherwise save into the folder
                for the current ``state.global_step``.
        """
        # Callbacks fire on every rank. Only one process may write, otherwise
        # ranks race on the same files (and on the exists/remove pair below).
        if not state.is_world_process_zero:
            return
        print('Saving PEFT checkpoint...')
        if use_best_checkpoint and state.best_model_checkpoint is not None:
            # best_model_checkpoint is already a checkpoint folder; appending
            # "adapter_model" here used to nest the adapter twice.
            checkpoint_folder = state.best_model_checkpoint
        else:
            checkpoint_folder = os.path.join(
                args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
            )

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        # Drop the full-model weights file, if present, to save disk space.
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)

    def on_save(self, args, state, control, **kwargs):
        """Save the adapter into the checkpoint folder of the current step."""
        # Never target the best checkpoint here: that would overwrite it with
        # the current (possibly worse) weights on every save.
        self.save_model(args, state, kwargs, use_best_checkpoint=False)
        return control

    def on_train_end(self, args, state, control, **kwargs):
        """Save the final adapter, then create a ``completed`` marker file."""
        def touch(fname, times=None):
            with open(fname, 'a'):
                os.utime(fname, times)

        # Save first so the marker only exists once the adapter is on disk.
        self.save_model(args, state, kwargs)
        if state.is_world_process_zero:
            touch(os.path.join(args.output_dir, 'completed'))


# 4-bit NF4 quantization with nested (double) quantization; matmuls run in bf16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

DEFAULT_PAD_TOKEN = "[PAD]"


def print_trainable_parameters(model):
    """Print the number of trainable parameters in the model.

    Args:
        model: a module exposing ``named_parameters()``.
    """
    # NOTE(review): for bitsandbytes 4-bit weights numel() may count packed
    # storage rather than logical weights -- confirm before trusting "all params".
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    # Avoid ZeroDivisionError for a module without parameters.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens to the tokenizer and resize the model embeddings.

    New input/output embedding rows are initialised to the mean of the
    pre-existing rows.

    Note: This is the unoptimized version that may make your embedding size
    not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _align_llama_special_tokens(tokenizer, model):
    """Copy special-token ids from ``model.config`` onto a LLaMA fast tokenizer.

    LLaMA tokenizers may not have the correct special tokens set; the ids are
    present in the vocabulary, so they are assigned by id. A config value is
    copied only when it is a non-negative int: configs may leave an id as
    ``None`` (and some conversions use ``-1``), and copying that would clear
    the tokenizer's token -- an unset pad token breaks the padding collator.

    Note: for the original LLaMA checkpoints ``model.config.pad_token_id`` is
    0, which is the ``<unk>`` token, so the pad token becomes ``<unk>``.
    """
    def is_token_id(value):
        return isinstance(value, int) and value >= 0

    if is_token_id(model.config.eos_token_id):
        tokenizer.eos_token_id = model.config.eos_token_id
    if is_token_id(model.config.pad_token_id):
        tokenizer.pad_token_id = model.config.pad_token_id
    unk_token_id = getattr(model.config, "unk_token_id", None)
    if is_token_id(unk_token_id):
        tokenizer.unk_token_id = unk_token_id
    elif tokenizer.unk_token_id is None:
        # Only fall back to the pad id when the tokenizer has no unk token.
        tokenizer.unk_token_id = tokenizer.pad_token_id


def _load_resume_weights(model, resume_from_checkpoint):
    """Load adapter weights to resume from; return the value for ``Trainer.train``.

    Looks for ``pytorch_model.bin`` (full checkpoint) in
    ``resume_from_checkpoint`` and otherwise ``adapter_model.bin`` (LoRA
    weights only -- the LoRA config has to match). In the adapter-only case
    ``False`` is returned so the Trainer won't try loading its own state.
    """
    checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin")  # Full checkpoint
    if not os.path.exists(checkpoint_name):
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "adapter_model.bin"
        )  # only LoRA model - LoRA config above has to fit
        resume_from_checkpoint = False  # So the trainer won't try loading its state
    # The two files above have a different name depending on how they were saved, but are actually the same.
    if os.path.exists(checkpoint_name):
        print(f"Restarting from {checkpoint_name}")
        # map_location="cpu": otherwise tensors are restored onto the device
        # they were saved from, so every DDP rank would allocate on that GPU.
        # NOTE(security): torch.load unpickles the file -- only resume from
        # checkpoints you trust.
        adapters_weights = torch.load(checkpoint_name, map_location="cpu")
        set_peft_model_state_dict(model, adapters_weights)
    else:
        print(f"Checkpoint {checkpoint_name} not found")
    return resume_from_checkpoint


def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [  # never mutated, so the shared default is safe
        "q_proj",
        "v_proj",
    ],
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    resume_from_checkpoint: Optional[str] = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
):
    """Fine-tune a 4-bit quantized causal LM with LoRA adapters.

    Dataset records must provide ``instruction``, ``input`` and ``output``
    fields, which are rendered with the ``prompt_template_name`` template.
    ``data_path`` is either a ``.json``/``.jsonl`` file or a dataset name
    accepted by ``datasets.load_dataset``. The final adapter is saved to
    ``output_dir``.

    ``add_eos_token`` is kept for CLI compatibility; the input mask used when
    ``train_on_inputs`` is False no longer depends on it.
    """
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    # Clamp to 1: micro_batch_size > batch_size would otherwise give 0 steps.
    gradient_accumulation_steps = max(batch_size // micro_batch_size, 1)

    prompter = Prompter(prompt_template_name)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # Place the whole model on this process's local GPU.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        # The global batch is split across ranks; keep at least one step.
        gradient_accumulation_steps = max(gradient_accumulation_steps // world_size, 1)

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map=device_map,
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if isinstance(tokenizer, LlamaTokenizerFast):
        # NOTE(review): if the config supplies a pad id this replaces the
        # [PAD] token added just above, leaving its embedding row unused.
        _align_llama_special_tokens(tokenizer, model)

    def tokenize(prompt, add_eos_token=True):
        """Tokenize ``prompt`` to at most ``cutoff_len`` ids with ``labels`` = ``input_ids``.

        No padding here; DataCollatorForSeq2Seq pads each batch. An EOS id is
        appended when requested, there is room, and it is not already last.
        """
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        input_ids = result["input_ids"]
        if (
            add_eos_token
            and len(input_ids) < cutoff_len
            # Test emptiness first so an empty encoding cannot raise IndexError.
            and (not input_ids or input_ids[-1] != tokenizer.eos_token_id)
        ):
            input_ids.append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = input_ids.copy()
        return result

    def generate_and_tokenize_prompt(data_point):
        """Build and tokenize the full prompt; optionally mask the prompt part of the labels."""
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = prompter.generate_prompt(
                data_point["instruction"], data_point["input"]
            )
            # Tokenize without EOS so the length is exactly the number of
            # prompt tokens. Appending EOS and subtracting one afterwards was
            # off by one whenever the prompt reached cutoff_len (no EOS added).
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            labels = tokenized_full_prompt["labels"]
            user_prompt_len = min(len(tokenized_user_prompt["input_ids"]), len(labels))
            # -100 is ignored by the loss, so only the response is trained on.
            tokenized_full_prompt["labels"] = [-100] * user_prompt_len + labels[user_prompt_len:]
        return tokenized_full_prompt

    model = prepare_model_for_kbit_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    if data_path.endswith((".json", ".jsonl")):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    if resume_from_checkpoint:
        resume_from_checkpoint = _load_resume_weights(model, resume_from_checkpoint)

    print_trainable_parameters(model)  # Be more transparent about the % of trainable params.

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=10,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            logging_steps=10,
            optim="paged_adamw_8bit",
            # NOTE(review): renamed to `eval_strategy` in newer transformers.
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=100 if val_set_size > 0 else None,
            save_steps=100,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            # None means "all installed integrations"; "none" disables reporting.
            report_to="none",
            run_name=None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
        callbacks=[SavePeftModelCallback()],
    )
    model.config.use_cache = False

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Only one process writes the final adapter; concurrent writes can corrupt it.
    if trainer.is_world_process_zero():
        model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


if __name__ == "__main__":
    fire.Fire(train)