import os
from typing import Dict, List

import fire
import torch
import transformers
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaTokenizerFast,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    set_peft_model_state_dict,
)

from utils.prompter import Prompter

# Note: torch.nn and bitsandbytes are not imported directly; bitsandbytes is
# used implicitly through BitsAndBytesConfig for 4-bit quantized loading.


class SavePeftModelCallback(transformers.TrainerCallback):
    """Save only the PEFT adapter weights at each checkpoint and drop the
    Trainer's full-size pytorch_model.bin copy."""

    def save_model(self, args, state, kwargs):
        print('Saving PEFT checkpoint...')
        if state.best_model_checkpoint is not None:
            checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model")
        else:
            checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        # Remove the full model weights written by the Trainer; only the
        # adapter is needed to resume or merge later.
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)

    def on_save(self, args, state, control, **kwargs):
        self.save_model(args, state, kwargs)
        return control

    def on_train_end(self, args, state, control, **kwargs):
        def touch(fname, times=None):
            with open(fname, 'a'):
                os.utime(fname, times)

        # Drop a marker file so external tooling can tell that training finished.
        touch(os.path.join(args.output_dir, 'completed'))
        self.save_model(args, state, kwargs)
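# Design note: the callback above follows the qlora-style checkpointing
# pattern: only the LoRA adapter (typically a few tens of MB) is kept per
# checkpoint, and any full pytorch_model.bin the Trainer may have written is
# deleted to save disk space.

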
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

DEFAULT_PAD_TOKEN = "[PAD]"
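# The config above loads the base weights as 4-bit NF4 with double quantization
# while compute (matmuls) runs in bfloat16, i.e. QLoRA-style loading.
# DEFAULT_PAD_TOKEN is only added later if the tokenizer lacks a pad token.

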
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )
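# With the default LoRA config below (r=8 on q_proj/v_proj of a 7B LLaMA-class
# model), expect this to report only a few million trainable parameters, well
# under 0.1% of the total; exact numbers depend on the base model.

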
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Initialize the new token embeddings with the mean of the existing ones.
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def train(
    # model/data params
    base_model: str = "",
    data_path: str = "",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
    ],
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, mask out the prompt tokens in the loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    resume_from_checkpoint: str = None,  # either a training checkpoint or the final adapter
    prompt_template_name: str = "alpaca",  # prompt template to use, defaults to alpaca
):
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    prompter = Prompter(prompt_template_name)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size
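    # Effective global batch size = micro_batch_size * gradient_accumulation_steps
    # * world_size; with the defaults (4 * 32 * 1) this works out to batch_size = 128.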
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map=device_map,
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Add a pad token (and resize the embeddings) only if the tokenizer does
    # not already define one.
    if tokenizer._pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if isinstance(tokenizer, LlamaTokenizerFast):
        # LLaMA fast tokenizers may not have their special token ids set
        # correctly; align them with the model config so padding and EOS
        # handling stay consistent.
        tokenizer.eos_token_id = model.config.eos_token_id
        tokenizer.pad_token_id = model.config.pad_token_id
        if hasattr(model.config, 'unk_token_id'):
            tokenizer.unk_token_id = model.config.unk_token_id
        else:
            tokenizer.unk_token_id = tokenizer.pad_token_id
    def tokenize(prompt, add_eos_token=True):
        # Tokenize without padding; the data collator pads dynamically per batch.
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        # Append an EOS token when the sequence was not truncated.
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result
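    # For illustration, a short prompt yields something like (ids made up):
    #   {"input_ids": [1, 910, ..., 2], "attention_mask": [1, 1, ...], "labels": [1, 910, ..., 2]}
    # "labels" starts as a copy of "input_ids"; the prompt portion is masked
    # below only when train_on_inputs is False.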
    def generate_and_tokenize_prompt(data_point):
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            # Mask the prompt tokens with -100 so the loss is computed only on
            # the response.
            user_prompt = prompter.generate_prompt(
                data_point["instruction"], data_point["input"]
            )
            tokenized_user_prompt = tokenize(
                user_prompt, add_eos_token=add_eos_token
            )
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            if add_eos_token:
                user_prompt_len -= 1

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]
        return tokenized_full_prompt
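    # Label masking sketch (lengths illustrative): a 10-token prompt followed
    # by a 5-token response produces
    #   labels = [-100] * 10 + [<5 response token ids>]
    # so cross-entropy ignores the prompt and is driven only by the response.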
    # Cast norms and the output head to fp32 and enable gradient checkpointing
    # for stable k-bit (4-bit) training, then wrap the model with LoRA adapters.
    model = prepare_model_for_kbit_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)
    if resume_from_checkpoint:
        # Check the available weights and load them.
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # so the trainer won't try loading its state
            )
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")
    print_trainable_parameters(model)

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=10,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            logging_steps=10,
            optim="paged_adamw_8bit",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=100 if val_set_size > 0 else None,
            save_steps=100,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to=None,
            run_name=None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
        callbacks=[SavePeftModelCallback],
    )
    # Caching is incompatible with gradient checkpointing, so disable it for training.
    model.config.use_cache = False
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


if __name__ == "__main__":
    fire.Fire(train)
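# Example invocation (script and data file names are illustrative; any
# Alpaca-format JSON dataset works):
#   python finetune.py \
#       --base_model 'huggyllama/llama-7b' \
#       --data_path './alpaca_data.json' \
#       --output_dir './lora-alpaca'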