# Phi-3-mini-4k-instruct / sample_finetune.py
# Commit 6bd8b8c by gugarosa: "chore(root): Initial files upload." (2.42 kB)
import torch
from datasets import load_dataset
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
"""
Please note that A100 or later generation GPUs are required to finetune Phi-3 models
1. Install accelerate:
conda install -c conda-forge accelerate
2. Setup accelerate config:
accelerate config
to simply use all the GPUs available:
python -c "from accelerate.utils import write_basic_config; write_basic_config(mixed_precision='bf16')"
check accelerate config:
accelerate env
3. Run the code:
accelerate launch phi3-mini-sample-ft.py
"""
###################
# Hyper-parameters
###################
# Keyword arguments forwarded verbatim to ``TrainingArguments``.
# Evaluation is switched off (``do_eval=False``, ``evaluation_strategy="no"``),
# so ``eval_steps`` and ``per_device_eval_batch_size`` have no effect as long
# as it stays off.
# NOTE(review): newer transformers releases rename ``evaluation_strategy`` to
# ``eval_strategy`` -- confirm against the pinned transformers version.
args = dict(
    # Mixed precision: matches the bf16 model weights loaded below.
    bf16=True,
    do_eval=False,
    evaluation_strategy="no",
    eval_steps=100,
    # Optimisation schedule: cosine decay after a 10% linear warm-up.
    learning_rate=5.0e-06,
    log_level="info",
    logging_steps=20,
    logging_strategy="steps",
    lr_scheduler_type="cosine",
    # Run length: one epoch; max_steps=-1 means "not capped by a step count".
    num_train_epochs=1,
    max_steps=-1,
    # Checkpoints go to the current directory, keeping only the newest one.
    output_dir=".",
    overwrite_output_dir=True,
    per_device_eval_batch_size=4,
    per_device_train_batch_size=8,
    remove_unused_columns=True,
    save_steps=100,
    save_total_limit=1,
    seed=0,
    # Trade recomputation for activation memory.
    gradient_checkpointing=True,
    gradient_accumulation_steps=1,
    warmup_ratio=0.1,
)
training_args = TrainingArguments(**args)
################
# Model Loading
################
# Hub id of the checkpoint to fine-tune; swap in the commented-out id to
# fine-tune the 128k-context variant instead.
checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
# checkpoint_path = "microsoft/Phi-3-mini-128k-instruct"
model_kwargs = dict(
    # Allow executing modelling code shipped with the checkpoint repository.
    trust_remote_code=True,
    # Load the model with FlashAttention-2 support (requires the flash-attn
    # package and an Ampere-or-newer GPU).
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    # The KV cache only helps generation and is incompatible with gradient
    # checkpointing (enabled in the training arguments); turn it off up front
    # rather than relying on the runtime override and its warning.
    use_cache=False,
    # NOTE(review): under a multi-GPU `accelerate launch`, confirm each process
    # places its replica on its own GPU (relies on the per-rank current CUDA
    # device having been set before this call).
    device_map="cuda",
)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
# Tokenizer from the same checkpoint, passed to the trainer below.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
################
# Data Loading
################
# IMDB from the Hugging Face Hub. Its "train" split is the fine-tuning data.
# The "test" split is bound to eval_dataset, but evaluation is disabled and
# it is not handed to the trainer.
dataset = load_dataset("imdb")
train_dataset, eval_dataset = dataset["train"], dataset["test"]
################
# Training
################
# Supervised fine-tuning on the raw "text" column of the training split,
# with sequences limited to 2048 tokens.
# NOTE(review): newer trl releases move max_seq_length / dataset_text_field
# into SFTConfig and rename `tokenizer` to `processing_class` -- confirm
# against the pinned trl version.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    max_seq_length=2048,
    dataset_text_field="text",
    tokenizer=tokenizer,
)
train_result = trainer.train()

# Report and persist the run's metrics and trainer state to output_dir.
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

# Persist the final fine-tuned weights (and tokenizer) to output_dir.
# Periodic checkpoints alone are not enough: they are written only every
# `save_steps` and pruned to `save_total_limit`.
trainer.save_model()