|
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
)
from trl import SFTTrainer
|
|
|
@dataclass
class ScriptArguments:
    """Command-line options for this fine-tuning script.

    These arguments vary depending on how many GPUs you have, what their
    capacity and features are, and what size model you want to train.

    The class is handed to ``HfArgumentParser`` further down in this file,
    so each field becomes a ``--<field_name>`` option whose help text is the
    field's ``metadata["help"]`` entry (where one is given).
    """

    # --- batch size / optimisation hyper-parameters ---
    per_device_train_batch_size: Optional[int] = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=1)
    gradient_accumulation_steps: Optional[int] = field(default=4)
    learning_rate: Optional[float] = field(default=2e-4)
    max_grad_norm: Optional[float] = field(default=0.3)
    # Annotated as float (was int): the default is fractional, and an int
    # annotation makes the generated CLI option reject values such as 0.01.
    weight_decay: Optional[float] = field(default=0.001)

    # --- LoRA adapter settings (passed to LoraConfig below in this file) ---
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=8)

    # Passed to SFTTrainer as ``max_seq_length``.
    max_seq_length: Optional[int] = field(default=2048)

    # --- model / data selection ---
    model_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
        },
    )
    dataset_name: Optional[str] = field(
        default="stingning/ultrachat",
        metadata={"help": "The preference dataset to use."},
    )

    # --- precision / memory switches ---
    fp16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables fp16 training."},
    )
    bf16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables bf16 training."},
    )
    packing: Optional[bool] = field(
        default=True,
        metadata={"help": "Use packing dataset creating."},
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables gradient checkpointing."},
    )
    use_flash_attention_2: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash Attention 2."},
    )

    # --- optimizer / schedule ---
    optim: Optional[str] = field(
        default="paged_adamw_32bit",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={"help": "Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis"},
    )
    max_steps: int = field(default=1000, metadata={"help": "How many optimizer update steps to take"})
    warmup_ratio: float = field(default=0.03, metadata={"help": "Fraction of steps to do a warmup for"})

    # --- checkpointing / logging ---
    save_steps: int = field(default=10, metadata={"help": "Save checkpoint every X updates steps."})
    logging_steps: int = field(default=10, metadata={"help": "Log every X updates steps."})
    output_dir: str = field(
        default="./results",
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
|
|
|
# Build a command-line parser from the ScriptArguments fields and parse the
# command line. ``parse_args_into_dataclasses`` returns one instance per
# dataclass given to the parser; only ScriptArguments was given, so take the
# first element.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
|
|
|
|
|
def formatting_func(example):
    """Render one dataset record as a single prompt/response training string.

    Args:
        example: A mapping with a ``"data"`` entry that is indexable and has
            at least two items. Item 0 is rendered as the user message and
            item 1 as the assistant message; any further items are ignored.

    Returns:
        The string ``"### USER: <data[0]>\\n### ASSISTANT: <data[1]>"``.

    Raises:
        KeyError: If ``example`` has no ``"data"`` entry.
        IndexError: If ``example["data"]`` holds fewer than two items.
    """
    text = f"### USER: {example['data'][0]}\n### ASSISTANT: {example['data'][1]}"
    return text
|
|
|
|
|
# Checkpoint to fine-tune. Honour ``--model_name`` when it is supplied;
# otherwise fall back to the checkpoint this script originally hard-coded.
model_id = script_args.model_name or "google/gemma-7b"

# Load the weights in 4-bit NF4 form and run the quantized matmuls in fp16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

# NOTE(review): ``torch_dtype=torch.float32`` is also used when
# ``--use_flash_attention_2`` is set; Flash Attention 2 presumably expects
# fp16/bf16 - confirm against the installed transformers version.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float32,
    attn_implementation="flash_attention_2" if script_args.use_flash_attention_2 else "sdpa",
)

# The tokenizer comes from the same checkpoint as the model; the
# end-of-sequence token id is reused as the padding token id.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
|
# LoRA adapter configuration. Rank, alpha and dropout come from the command
# line; adapters are attached to the listed projection modules and no bias
# terms are trained.
lora_config = LoraConfig(
    r=script_args.lora_r,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
)

# Only the first 5% of the dataset's "train" split is loaded.
train_dataset = load_dataset(script_args.dataset_name, split="train[:5%]")
|
|
|
|
|
# TODO: replace the placeholder with your Hugging Face username. It must be a
# string literal - the original bare ``xxx`` was an undefined name and raised
# NameError as soon as the script reached this line.
YOUR_HF_USERNAME = "xxx"

# NOTE(review): the ``--output_dir`` option declared on ScriptArguments is not
# consulted here; confirm whether it should replace this hard-coded path.
output_dir = f"{YOUR_HF_USERNAME}/gemma-qlora-ultrachat"

# Trainer hyper-parameters, taken from the parsed command-line arguments.
training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    optim=script_args.optim,
    save_steps=script_args.save_steps,
    logging_steps=script_args.logging_steps,
    learning_rate=script_args.learning_rate,
    max_grad_norm=script_args.max_grad_norm,
    # Previously the parsed --weight_decay value was never forwarded, so the
    # option had no effect on training.
    weight_decay=script_args.weight_decay,
    max_steps=script_args.max_steps,
    warmup_ratio=script_args.warmup_ratio,
    lr_scheduler_type=script_args.lr_scheduler_type,
    gradient_checkpointing=script_args.gradient_checkpointing,
    fp16=script_args.fp16,
    bf16=script_args.bf16,
)
|
|
|
# Supervised fine-tuning: the quantized base model is wrapped with the LoRA
# adapters described by ``lora_config`` and trained on ``train_dataset``.
#
# NOTE(review): ``dataset_text_field="id"`` is passed together with
# ``formatting_func``. Confirm against the installed TRL version which of the
# two takes effect, in particular when ``--packing`` is disabled.
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    peft_config=lora_config,
    packing=script_args.packing,
    dataset_text_field="id",
    tokenizer=tokenizer,
    max_seq_length=script_args.max_seq_length,
    formatting_func=formatting_func,
)

# Run the training loop.
trainer.train()