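# Author: Soonchan (commit 0107eb2, "Create ft.py", verified)
# Fine-tuning script: QLoRA (4-bit quantized base model + LoRA adapters)
# supervised fine-tuning of Gemma-2 9B for English-to-Korean colloquial translation.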
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainerCallback
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import Dataset
import json
from trl import SFTTrainer, SFTConfig
from transformers import TrainingArguments
import time
import os
class CustomCallback(TrainerCallback):
    """Print a one-line progress message for every training loss the Trainer logs.

    Elapsed time is measured from the start of training (reset in
    ``on_train_begin``), not from when this callback object was constructed.
    """

    def __init__(self):
        # Also initialised here so the attribute exists even if the callback
        # is driven without an on_train_begin event.
        self.start_time = time.time()

    def on_train_begin(self, args, state, control, **kwargs):
        """Reset the wall-clock timer and announce the start of training."""
        self.start_time = time.time()
        print("Training has begun!")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print step, loss and elapsed time whenever a training loss is logged.

        ``logs`` holds the metrics the Trainer has just computed for this step.
        Reading ``state.log_history`` from ``on_step_end`` instead would return
        the previous step's entry, because the Trainer fires ``on_step_end``
        before it logs the current step.
        """
        loss = (logs or {}).get("loss")
        if loss is None:
            # Not a training-loss log (e.g. eval metrics, or the end-of-training
            # summary, which reports ``train_loss`` instead of ``loss``).
            return
        elapsed_time = time.time() - self.start_time
        print(f"Step: {state.global_step}, Loss: {loss:.4f}, Time: {elapsed_time:.2f}s")

    def on_train_end(self, args, state, control, **kwargs):
        """Announce the end of training."""
        print("Training has ended!")
# Pin the run to one GPU. setdefault lets a CUDA_VISIBLE_DEVICES value exported
# in the shell take precedence over the hard-coded "7". It must be set before
# CUDA is first initialised to have any effect.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")
# NOTE(review): `device` appears unused in this script; placement is done by
# device_map="auto" below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Official Hugging Face Hub id of instruction-tuned Gemma-2 9B.
# Set MODEL_PATH to point at a local checkpoint directory instead.
model_path = os.environ.get("MODEL_PATH", "google/gemma-2-9b-it")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Gemma's tokenizer ships its own <pad> token. Fall back to EOS only when no pad
# token exists. The causal-LM collator masks pad ids out of the labels, so
# pad == eos would also mask every <eos> and the model would never learn to stop.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# QLoRA: base weights stored as 4-bit NF4 with double quantization; compute
# runs in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    # transformers recommends eager attention for training Gemma-2: the
    # attention-logit soft-capping it relies on is handled by the eager path.
    attn_implementation="eager",
)

# PEFT: prepare the quantized model for k-bit training, then attach LoRA
# adapters to every attention and MLP projection.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=6,
    lora_alpha=8,
    lora_dropout=0.05,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Data: the whole JSON file is loaded at once. Judging by how it is consumed
# further down, it is expected to map a key to a list of records with
# "en_original" and "ko" fields (TODO confirm against the data file).
with open('en_ko_data', 'r', encoding='utf-8') as f:
    data = json.load(f)
def generate_prompt(en_text, ko_text, *, include_bos=False):
    """Format one English->Korean pair as a single-turn Gemma chat transcript.

    Args:
        en_text: English source expression placed in the user turn.
        ko_text: Korean reference translation placed in the model turn.
        include_bos: Prepend a literal ``<bos>`` token. Leave this False when
            the text is later tokenized with special tokens enabled. The Gemma
            tokenizer already prepends BOS in that case, and a literal ``<bos>``
            would give every example two BOS tokens.

    Returns:
        The formatted training string, ending in ``<end_of_turn><eos>`` so the
        model learns where a translation stops.
    """
    bos = "<bos>" if include_bos else ""
    # The instruction text is kept verbatim (including ".:").
    return (
        f"{bos}<start_of_turn>user\n"
        "Please translate the following English colloquial expression into Korean.:\n"
        f"{en_text}<end_of_turn>\n"
        "<start_of_turn>model\n"
        f"{ko_text}<end_of_turn><eos>"
    )
# Build the training set from the records under the JSON file's first
# top-level key. Each record becomes one fully formatted string in "text".
key = [*data.keys()][0]
dataset = Dataset.from_list(
    [
        {"text": generate_prompt(record['en_original'], record['ko'])}
        for record in data[key]
    ]
)

# Trainer hyper-parameters, grouped by concern.
training_args = TrainingArguments(
    output_dir="./results",
    # schedule
    num_train_epochs=3,
    learning_rate=2e-4,
    weight_decay=0.01,
    # batching: 4 sequences x 4 accumulation steps per optimizer step, per device
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    # precision and optimizer
    fp16=True,
    optim="paged_adamw_8bit",
    # bookkeeping: log every step, checkpoint every 100 steps
    logging_steps=1,
    save_steps=100,
)

# Supervised fine-tuning on the "text" column, sequences capped at 512 tokens.
# NOTE(review): passing tokenizer / dataset_text_field / max_seq_length straight
# to SFTTrainer is the older TRL API; newer releases expect them on SFTConfig
# (some renamed). Confirm against the installed TRL version.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.add_callback(CustomCallback())

trainer.train()

# The model is PEFT-wrapped, so this is expected to write the LoRA adapter
# rather than merged full weights.
trainer.save_model("./gemma2_9b_ko_translator")