This is a demo of how to fine-tune the phi-2 model on a 24 GB VRAM A10 or 4090 card.

import torch
import datasets
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
import trl
from transformers import BitsAndBytesConfig

# Training data: the "train" split of HuggingFaceTB/cosmopedia-20k.
train_dataset = datasets.load_dataset(
    "HuggingFaceTB/cosmopedia-20k",
    split="train",
)

# Hyper-parameters handed to the trainer.
# NOTE(review): batch size 1 and Adafactor are presumably chosen to fit the
# 24 GB VRAM budget — confirm.
args = TrainingArguments(
    output_dir="./test-sft",  # where training output is written
    max_steps=20000,  # total number of training steps
    per_device_train_batch_size=1,
    optim="adafactor",
    report_to="none",  # no logging integrations
)

# Base checkpoint; the tokenizer is loaded from the same identifier.
# NOTE(review): this `tokenizer` object is not passed to the trainer in the
# visible code — confirm whether that is intentional.
model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# bitsandbytes settings: 4-bit NF4 weights with double quantization enabled
# and bfloat16 as the compute dtype.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the causal-LM checkpoint with the quantization config above;
# device_map="auto" leaves device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=nf4_config,
    device_map="auto",
)
print(model)  # print the module structure of the loaded model

from peft import LoraConfig

# LoRA adapter settings for a causal-LM task: rank 64, alpha 16,
# dropout 0.1, and no bias parameters trained.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    # Names of the modules that receive LoRA adapters.
    target_modules=[
        "q_proj",
        "v_proj",
        "k_proj",
        "dense",
        "lm_head",
        "fc1",
        "fc2",
    ],
)

# Attach the adapter described by `peft_config` to the loaded model.
model.add_adapter(peft_config)

# Supervised fine-tuning over the dataset's "text" column, with a maximum
# sequence length of 1024.
trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=1024,
)

# Run the training loop configured by `args`.
trainer.train()

# Write the trained model and the trainer's tokenizer into the "sft" directory.
# NOTE(review): `dtype` is not a documented parameter of `save_pretrained`;
# confirm it has any effect on the saved weights.
trainer.model.save_pretrained("sft", dtype=torch.bfloat16)
trainer.tokenizer.save_pretrained("sft")
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference API
Unable to determine this model's library. Check the docs.