# chatlawv1 / finetune_4bit.py
# Uploaded by teachyourselfcoding ("Upload 245 files", commit fa6856c).
import os
import sys
import torch
import torch.nn as nn
import bitsandbytes as bnb
from datasets import load_dataset, Dataset
import transformers
import argparse
import warnings
from huggingface_hub import snapshot_download
# Fail fast with an actionable message when the installed transformers build
# does not expose the LLaMA tokenizer. An explicit raise (instead of `assert`)
# keeps the check active under `python -O`, and `.get` turns a build with no
# "models.llama" entry at all into the same ImportError rather than a KeyError.
if "LlamaTokenizer" not in transformers._import_structure.get("models.llama", ()):
    raise ImportError(
        "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
    )
from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig
from peft import (
prepare_model_for_kbit_training,
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
set_peft_model_state_dict,
)
def generate_prompt(data_point):
    """Render one instruction record as a full prompt, response included.

    ``data_point`` must provide "instruction", "input" and "output". A falsy
    "input" selects the shorter template that has no "### Input:" section.
    """
    context = data_point["input"]
    instruction = data_point["instruction"]
    response = data_point["output"]
    if not context:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
{response}"""
    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{context}
### Response:
{response}"""
def tokenize(prompt):
    """Tokenize *prompt* into fixed-length ``input_ids`` / ``attention_mask``.

    Uses the module-level ``tokenizer``: the text is truncated or padded to
    CUTOFF_LEN + 1 tokens, and the last position of each field is then
    dropped so both lists are CUTOFF_LEN long.
    """
    # A tokenizer setting could likely replace the pad-then-trim trick.
    encoding = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN + 1,
        padding="max_length",
    )
    return {field: encoding[field][:-1] for field in ("input_ids", "attention_mask")}
def generate_and_tokenize_prompt(data_point):
    """Format one instruction record and tokenize it for causal-LM training.

    The prompt (everything through the "### Response:" line) is excluded from
    the loss so that the training signal comes only from the response.

    Args:
        data_point: mapping with "instruction", "input" and "output" text
            fields; a falsy "input" selects the template without an input
            section.

    Returns:
        dict with "input_ids", "labels" and "attention_mask", each CUTOFF_LEN
        long. Padding positions get attention 0 and label -100. If the prompt
        alone fills CUTOFF_LEN, every label is -100 and the example
        contributes no loss.
    """
    if data_point["input"]:
        user_prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{data_point["instruction"]}
### Input:
{data_point["input"]}
### Response:
"""
    else:
        user_prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{data_point["instruction"]}
### Response:
"""
    # Prompt length in tokens. The tokenizer is created with add_eos_token=True,
    # so the EOS appended to the standalone prompt is subtracted.
    len_user_prompt_tokens = (
        len(
            tokenizer(
                user_prompt,
                truncation=True,
                max_length=CUTOFF_LEN + 1,
            )["input_ids"]
        )
        - 1
    )
    # Pad/truncate to CUTOFF_LEN + 1 and drop the final slot so every example
    # is exactly CUTOFF_LEN long (assumes right padding, so the dropped slot
    # is padding unless the text was truncated).
    full = tokenizer(
        user_prompt + data_point["output"],
        truncation=True,
        max_length=CUTOFF_LEN + 1,
        padding="max_length",
    )
    full_tokens = full["input_ids"][:-1]
    # Use the tokenizer's own mask so padding is not attended to.
    attention_mask = full["attention_mask"][:-1]
    # Learn only response tokens: mask the prompt prefix and every pad slot.
    labels = [
        token if position >= len_user_prompt_tokens and attended else -100
        for position, (token, attended) in enumerate(zip(full_tokens, attention_mask))
    ]
    return {
        "input_ids": full_tokens,
        "labels": labels,
        "attention_mask": attention_mask,
    }
def _str2bool(value):
    """Parse a command-line boolean such as "True"/"False", "1"/"0", "yes"/"no".

    With ``type=str`` argparse would keep the literal text, and any non-empty
    string (including "False") is truthy when used as a flag.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "1", "yes", "y", "on"):
        return True
    if normalized in ("false", "0", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--wandb", action="store_true", default=False)
parser.add_argument("--data_path", type=str, default="merge.json")
parser.add_argument("--output_path", type=str, default="lora-Vicuna")
parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf")
parser.add_argument("--eval_steps", type=int, default=200)
parser.add_argument("--save_steps", type=int, default=200)
parser.add_argument("--test_size", type=int, default=200)
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
parser.add_argument("--lora_remote_checkpoint", type=str, default=None)
# argparse also runs `type` on string defaults, so the default becomes False.
parser.add_argument("--ignore_data_skip", type=_str2bool, default="False")
args = parser.parse_args()
if not args.wandb:
    # wandb recognizes "disabled" (not "disable") as the off switch.
    os.environ["WANDB_MODE"] = "disabled"
# optimized for RTX 4090. for larger GPUs, increase some of these?
MICRO_BATCH_SIZE = 8  # per-device batch; a power of two by preference
BATCH_SIZE = 128  # target global batch size (micro batch x accumulation x GPUs)
MAX_STEPS = None  # placeholder; the real schedule is decided after data loading
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = 3  # we don't always need 3 tbh
LEARNING_RATE = 3e-4  # the Karpathy constant
CUTOFF_LEN = 256  # reportedly covers ~96% of examples; re-check for DATA_PATH
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
VAL_SET_SIZE = args.test_size  # 2000
TARGET_MODULES = [
    "q_proj",
    "v_proj",
]
DATA_PATH = args.data_path
OUTPUT_DIR = args.output_path  # "lora-Vicuna"
device_map = {"": 0}  # "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    # Pin each DDP process to its own GPU.
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    # Split accumulation across processes, but never below 1: a value of 0
    # would make the Trainer divide by zero when world_size exceeds it.
    GRADIENT_ACCUMULATION_STEPS = max(1, GRADIENT_ACCUMULATION_STEPS // world_size)
print(args.model_path)
# An explicit BitsAndBytesConfig(load_in_4bit=True) is what the bare
# `load_in_4bit=True` kwarg expands to, without relying on the shortcut kwarg.
model = LlamaForCausalLM.from_pretrained(
    args.model_path,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map=device_map,
)
tokenizer = LlamaTokenizer.from_pretrained(
    args.model_path, add_eos_token=True
)
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
# The tokenization helpers drop the final position and mask a prompt prefix,
# both of which assume padding is appended on the right. Pin it explicitly
# instead of relying on the checkpoint's tokenizer config.
tokenizer.padding_side = "right"
data = load_dataset("json", data_files=DATA_PATH)
# Print a few tokenized samples as a sanity check of the prompt format and the
# label mask. The start index is clamped so that small datasets still yield
# examples instead of an empty slice.
import random

num_train_rows = len(data["train"])
start = random.randint(0, max(0, min(100, num_train_rows - 5)))
examples = Dataset.from_dict(data["train"][start:start + 5]).map(generate_and_tokenize_prompt)
for example in examples:
    print(f'>>> prompt example:\n { tokenizer.decode(example["input_ids"]) }')
    # Masked (-100) label positions are shown as token id 0 so they can be decoded.
    print(f'>>> tokenizer labels: { tokenizer.decode([0 if tok == -100 else tok for tok in example["labels"]]) }')
    print(f'>>> tokenizer example: { example["input_ids"][:250] }...{ example["input_ids"][-10:] }')
# Optimizer steps for a fresh run: EPOCHS passes over the training split at
# the global batch size, floored to at least EPOCHS.
now_max_steps = max((len(data["train"]) - VAL_SET_SIZE) // BATCH_SIZE * EPOCHS, EPOCHS)
# Default schedule. Setting it up front guarantees MAX_STEPS is an int even
# when no trainer state is found; a resumed state may extend it below.
MAX_STEPS = now_max_steps
if args.resume_from_checkpoint:
    if args.lora_remote_checkpoint is not None:
        snapshot_download(
            repo_id=args.lora_remote_checkpoint,
            allow_patterns=["*.pt", "*.bin", "*.json"],
            local_dir=args.resume_from_checkpoint,
        )
    # Keep the directory: args.resume_from_checkpoint may be cleared below.
    resume_dir = args.resume_from_checkpoint
    # Check the available weights and load them
    checkpoint_name = os.path.join(resume_dir, "pytorch_model.bin")  # Full checkpoint
    if not os.path.exists(checkpoint_name):
        pytorch_bin_path = checkpoint_name
        checkpoint_name = os.path.join(
            resume_dir, "adapter_model.bin"
        )  # only LoRA model - LoRA config above has to fit
        if os.path.exists(checkpoint_name):
            os.rename(checkpoint_name, pytorch_bin_path)
            # The adapter weights now live under the new name; load from there.
            checkpoint_name = pytorch_bin_path
            warnings.warn("The file name of the lora checkpoint'adapter_model.bin' is replaced with 'pytorch_model.bin'")
        else:
            args.resume_from_checkpoint = (
                None  # So the trainer won't try loading its state
            )
    # The two files above have a different name depending on how they were saved, but are actually the same.
    if os.path.exists(checkpoint_name):
        print(f"Restarting from {checkpoint_name}")
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code; only resume from trusted checkpoints, including
        # ones fetched via --lora_remote_checkpoint.
        adapters_weights = torch.load(checkpoint_name)
        # Loads into `model` in place. Recent peft releases return the
        # load_state_dict result rather than the model, so the return value
        # must not replace `model`.
        set_peft_model_state_dict(model, adapters_weights)
    else:
        print(f"Checkpoint {checkpoint_name} not found")
    train_args_path = os.path.join(resume_dir, "trainer_state.json")
    # Only adopt the previous schedule when the Trainer will actually resume.
    if args.resume_from_checkpoint and os.path.exists(train_args_path):
        import json

        with open(train_args_path, "r") as f:
            base_train_args = json.load(f)
        base_max_steps = base_train_args["max_steps"]
        if base_max_steps > now_max_steps:
            # Keep the longer original schedule; max_steps overrides epochs.
            warnings.warn("epoch {} replace to the base_max_steps {}".format(EPOCHS, base_max_steps))
            EPOCHS = None
            MAX_STEPS = base_max_steps
model.print_trainable_parameters()
num_proc = os.cpu_count()  # may be None, which datasets treats as single-process
# Every shuffle is seeded so the dataset order is identical across runs. When
# resuming, the Trainer skips already-seen batches by position (unless
# ignore_data_skip is set), which only lines up if the order is unchanged.
if VAL_SET_SIZE > 0:
    train_val = data["train"].train_test_split(
        test_size=VAL_SET_SIZE, shuffle=True, seed=42
    )
    train_data = train_val["train"].shuffle(seed=42).map(generate_and_tokenize_prompt, num_proc=num_proc)
    val_data = train_val["test"].shuffle(seed=42).map(generate_and_tokenize_prompt, num_proc=num_proc)
else:
    train_data = data["train"].shuffle(seed=42).map(generate_and_tokenize_prompt, num_proc=num_proc)
    val_data = None
def _collate_preserving_labels(features):
    """Stack pre-padded examples into tensors while keeping their ``labels``.

    ``DataCollatorForLanguageModeling(mlm=False)`` rebuilds ``labels`` from
    ``input_ids`` and so discards the -100 prompt mask set during
    tokenization. The examples here are already padded to one fixed length,
    so plain stacking is enough. Pad positions are then excluded from the
    loss, matching what the LM collator did.
    """
    batch = transformers.default_data_collator(features)
    if "labels" in batch and tokenizer.pad_token_id is not None:
        batch["labels"][batch["input_ids"] == tokenizer.pad_token_id] = -100
    return batch


trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=MICRO_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=100,
        num_train_epochs=EPOCHS,
        max_steps=MAX_STEPS,
        learning_rate=LEARNING_RATE,
        fp16=True,
        logging_steps=20,
        # NOTE(review): newer transformers releases rename this to eval_strategy.
        evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
        save_strategy="steps",
        eval_steps=args.eval_steps if VAL_SET_SIZE > 0 else None,
        save_steps=args.save_steps,
        output_dir=OUTPUT_DIR,
        save_total_limit=30,
        load_best_model_at_end=VAL_SET_SIZE > 0,
        ddp_find_unused_parameters=False if ddp else None,
        report_to="wandb" if args.wandb else [],
        ignore_data_skip=args.ignore_data_skip,
        optim="paged_adamw_8bit",
    ),
    data_collator=_collate_preserving_labels,
)
model.config.use_cache = False
# Route model.state_dict() through get_peft_model_state_dict so that checkpoint
# saves contain only the PEFT (LoRA) weights.
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
print("\n If there's a warning about missing keys above, please disregard :)")
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
model.save_pretrained(OUTPUT_DIR)