# llama_7b_ft_2 / finetune.py
# Committed by gustavoaq - commit 5c9e3fb "Update finetune.py" (5.24 kB).
import os
import sys
import torch
import pickle
import random
import json
import torch.nn as nn
from datasets import load_dataset
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import (
prepare_model_for_int8_training,
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
)
# Optional Hub sync: when the TRL_TOKEN environment variable is set, clone the
# target model repo into ./checkpoints/ and pull its latest state, so that files
# written under ./checkpoints/ can later be committed and pushed back.
# NOTE(review): `Repository` is not imported anywhere in this file; it presumably
# comes from `huggingface_hub` (`from huggingface_hub import Repository`).
HF_TOKEN = os.environ.get("TRL_TOKEN", None)
repo = None  # stays None when no token is configured (no Hub sync)
if HF_TOKEN:
    # Never print HF_TOKEN: it is a secret and would end up in logs.
    repo = Repository(
        local_dir="./checkpoints/",
        clone_from="gustavoaq/llama_ft",
        use_auth_token=HF_TOKEN,
        # Valid repo types are "model" (default), "dataset" and "space".
        repo_type="model",
    )
    repo.git_pull()
# Parameters
MICRO_BATCH_SIZE = 16  # per-device batch size (train and eval)
BATCH_SIZE = 32  # effective batch size reached via gradient accumulation
size = "7b"  # LLaMA checkpoint size; selects the base model and output dir
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = 1
LEARNING_RATE = 1.5e-4
CUTOFF_LEN = 512  # every example is padded/truncated to this many tokens
# LoRA adapter hyper-parameters.
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
VAL_SET_SIZE = 2000  # examples held out for evaluation; 0 disables evaluation
# Module names that receive LoRA adapters.
TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "down_proj",
    "gate_proj",
    "up_proj",
]
DATA_PATH = "data/data_tmp.json"  # merged, shuffled training data is written here
OUTPUT_DIR = "checkpoints/{}".format(size)
# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs("data", exist_ok=True)
# Load data: concatenate the per-source chat datasets, shuffle them, write the
# merged list to DATA_PATH, then load it back as a Hugging Face dataset.
# NOTE(review): random.shuffle is unseeded, so the merged order (and therefore
# which rows land in the seeded validation split later) differs between runs.
data = []
for source in "alpaca,stackoverflow,quora".split(","):
    with open("data/{}_chat_data.json".format(source), encoding="utf-8") as f:
        data += json.load(f)
random.shuffle(data)
# Close (and flush) the file before load_dataset reads it back.
with open(DATA_PATH, "w", encoding="utf-8") as f:
    json.dump(data, f)
data = load_dataset("json", data_files=DATA_PATH)
# Load Model
# With more than one process (WORLD_SIZE / LOCAL_RANK, as set by a distributed
# launcher such as torchrun), each rank loads the whole model onto its own GPU;
# otherwise device_map="auto" lets the library place layers on available devices.
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    # The effective batch is shared across processes, so each one needs fewer
    # accumulation steps; integer division must never drive this to 0.
    GRADIENT_ACCUMULATION_STEPS = max(1, GRADIENT_ACCUMULATION_STEPS // world_size)
model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-{}-hf".format(size),
    load_in_8bit=True,
    # Use the computed map (not a literal "auto") so DDP ranks each load onto
    # their own GPU.
    device_map=device_map,
)
# Parameter counters, reported after the model is wrapped with LoRA.
total_params, params = 0, 0
# Tokenizer for the same base checkpoint; add_eos_token=True asks it to append
# the EOS token to encoded sequences. Model-loading options (8-bit CPU offload,
# device_map) do not apply to a tokenizer, so none are passed.
tokenizer = LlamaTokenizer.from_pretrained(
    "decapoda-research/llama-{}-hf".format(size),
    add_eos_token=True,
)
# LoRA adapter description: rank, scaling, dropout and the projection modules
# to adapt; the config is also written into the output directory.
config = LoraConfig(
    task_type="CAUSAL_LM",
    bias="none",
    target_modules=TARGET_MODULES,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
)
config.save_pretrained(OUTPUT_DIR)
# Prepare the 8-bit base model for training, then inject the LoRA adapters.
model = get_peft_model(prepare_model_for_int8_training(model), config)
# Use token id 0 as the padding id.
tokenizer.pad_token_id = 0
# Tally parameters of the wrapped model: `params` counts everything, while
# `total_params` counts only the LoRA adapter tensors (despite its name).
for param_name, tensor in model.model.named_parameters():
    count = tensor.numel()
    params += count
    if "lora" in param_name:
        total_params += count
lora_millions = total_params // 1000 / 1000
lora_share = round(total_params / params * 100, 2)
print("Total number of parameters: {}M, rate: {}%".format(lora_millions, lora_share))
# Data Preprocess
def generate_prompt(data_point):
    """Return the prompt for one dataset record: its "input" field, unchanged."""
    prompt_text = data_point["input"]
    return prompt_text
def tokenize(prompt):
    """Tokenize ``prompt`` into fixed-length ``input_ids`` / ``attention_mask`` lists.

    The text is padded/truncated to CUTOFF_LEN + 1 tokens and the final position
    is then dropped, so both returned lists have CUTOFF_LEN entries.
    """
    encoded = tokenizer(
        prompt,
        padding="max_length",
        max_length=CUTOFF_LEN + 1,
        truncation=True,
    )
    return {key: encoded[key][:-1] for key in ("input_ids", "attention_mask")}
def generate_and_tokenize_prompt(data_point):
    """Dataset map function: build the record's prompt and tokenize it."""
    return tokenize(generate_prompt(data_point))
# Carve out a seeded validation split when requested, then tokenize each split.
if VAL_SET_SIZE <= 0:
    train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = None
else:
    splits = data["train"].train_test_split(
        test_size=VAL_SET_SIZE, shuffle=True, seed=42
    )
    train_data = splits["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = splits["test"].shuffle().map(generate_and_tokenize_prompt)
# Training
training_args = transformers.TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    warmup_steps=100,
    per_device_train_batch_size=MICRO_BATCH_SIZE,
    per_device_eval_batch_size=MICRO_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    fp16=True,
    logging_steps=20,
    # Evaluation (and best-model reloading) only exists when a validation
    # split was carved out; checkpoints are written every 200 steps either way.
    evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
    eval_steps=200 if VAL_SET_SIZE > 0 else None,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=100,
    load_best_model_at_end=VAL_SET_SIZE > 0,
    ddp_find_unused_parameters=False if ddp else None,
)
trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    # Causal-LM collation (mlm=False): labels are derived from input_ids.
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
# Turn off the model's key/value cache for training; it is meant for
# incremental generation.
model.config.use_cache = False
# Replace model.state_dict with a bound lambda that runs the original
# state_dict() and passes it through get_peft_model_state_dict, so callers
# (presumably the Trainer's checkpoint saving) receive only the PEFT/LoRA
# adapter weights. Any extra positional/keyword arguments are ignored (*_, **__);
# __get__ binds the lambda to this model instance as a method.
# NOTE(review): newer peft releases already filter inside
# PeftModel.save_pretrained; combined with this patch the saved adapter has been
# reported to come out empty - confirm against the installed peft version.
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
import gradio as gr
def train(input_text):
    """Gradio callback: run fine-tuning, save the adapter, and push to the Hub.

    Args:
        input_text: Text from the interface's input box. It is unused; the
            callback only has to accept the interface's single input.

    Returns:
        A short status message shown in the interface's output box.
    """
    # Show what is already in the output directory before training starts.
    print(os.listdir(OUTPUT_DIR))
    trainer.train()
    print("Training complete.")
    model.save_pretrained(OUTPUT_DIR)
    # `repo` is only created at startup when TRL_TOKEN was set (HF_TOKEN truthy);
    # without a token there is no local clone to push.
    if HF_TOKEN:
        # Repository.push_to_hub takes no path: it commits and pushes the whole
        # local clone (./checkpoints/, which contains OUTPUT_DIR). Passing the
        # path positionally collided with commit_message and raised TypeError.
        repo.push_to_hub(commit_message="Ft model")
    return "Training complete."
# Minimal UI: submitting the form calls train() on the already-built trainer.
# gr.inputs / gr.outputs, `layout`, `allow_screenshot` and boolean
# `allow_flagging` are deprecated or removed in current Gradio; use the
# component classes and the string flagging mode instead.
# NOTE(review): share=True publishes a public URL, so anyone holding it can
# start a training run; consider removing it or adding authentication.
iface = gr.Interface(
    fn=train,
    inputs=gr.Textbox(label="Input text"),
    outputs=gr.Textbox(label="Output length"),
    title="Training Interface",
    description="Enter some text and click the button to start training.",
    theme="default",
    allow_flagging="never",
)
iface.launch(share=True)