---
license: apache-2.0
datasets:
  - CultriX/llama70B-dpo-dataset
language:
  - en
base_model:
  - NousResearch/Hermes-3-Llama-3.1-8B
pipeline_tag: text-generation
tags:
  - dpo
  - Llama3
  - general
library_name: transformers
---

# Model Card for Llama3-8B-DPO

Llama3-8B-DPO is [NousResearch/Hermes-3-Llama-3.1-8B](https://huggingface.co/NousResearch/Hermes-3-Llama-3.1-8B) fine-tuned with Direct Preference Optimization (DPO) on the [CultriX/llama70B-dpo-dataset](https://huggingface.co/datasets/CultriX/llama70B-dpo-dataset).

- **License:** Apache-2.0
- **Datasets:** CultriX/llama70B-dpo-dataset
- **Language:** English
- **Base Model:** NousResearch/Hermes-3-Llama-3.1-8B
- **Pipeline Tag:** text-generation
- **Tags:** DPO, Llama3, General
- **Library:** Transformers


## Performance

| Model Name            | AGIEval | TruthfulQA | BigBench |
|-----------------------|---------|------------|----------|
| Hermes-3-Llama-3.1-8B | 41.51   | 58.61      | 43.08    |
| Llama3-8B-DPO         | 41.87   | 71.38      | 44.5     |
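## Usage

A minimal inference sketch. The repo id `CultriX/Llama3-8B-DPO` is an assumption based on the card title; adjust it if the weights are hosted elsewhere. The model was trained on ChatML-formatted prompts (see the training script below), so inference uses the same template.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "CultriX/Llama3-8B-DPO"  # assumed repo id; adjust if needed
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Build a ChatML prompt, matching the format used during training
prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\nWhat is Direct Preference Optimization?<|im_end|>\n"
    "<|im_start|>assistant\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.9
)
# Decode only the newly generated tokens
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:],
                       skip_special_tokens=True))
```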

## Training Script

```python
# Install required libraries
!pip install --upgrade pip
!pip install git+https://github.com/huggingface/transformers
!pip install git+https://github.com/huggingface/peft.git
!pip install git+https://github.com/huggingface/trl.git
!pip install --upgrade wandb accelerate datasets

import os
import gc
import torch
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from trl import DPOTrainer, DPOConfig
from huggingface_hub import login, notebook_login

# Log in to Hugging Face and Weights & Biases
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token=hf_token)  # non-interactive login with a token
else:
    notebook_login()       # fall back to the interactive widget

wb_token = os.getenv('WANDB_API_KEY')
if not wb_token:
    wandb.login()
else:
    wandb.login(key=wb_token)

# Set model names
model_name = "NousResearch/Hermes-3-Llama-3.1-8B"
base_model_name = model_name
fine_tuned_model_name = "Llama3-8B-DPO"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
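# Llama tokenizers ship without a pad token, so reuse EOS for padding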
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.config.use_cache = False  # KV caching is incompatible with gradient checkpointing

# Apply LoRA for fine-tuning
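# Low-rank adapters on the attention projections keep only a small fraction
# of parameters trainable; the base weights stay frozen.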
peft_config = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", 
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)
model = get_peft_model(model, peft_config)
model.gradient_checkpointing_enable()

# Load and format dataset
dataset = load_dataset("CultriX/llama70B-dpo-dataset")["train"]
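# Convert each example to ChatML: the prompt holds the system/user turns plus
# the assistant header; "chosen"/"rejected" are the completions DPO compares.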

def chatml_format(example):
    system = example.get("system", "")
    question = example.get("question", "")
    chosen = example.get("chosen", "")
    rejected = example.get("rejected", "")

    prompt = ""
    if system:
        prompt += f"<|im_start|>system\n{system}<|im_end|>\n"
    prompt += f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"

    return {
        "prompt": prompt,
        "chosen": f"{chosen}<|im_end|>\n",
        "rejected": f"{rejected}<|im_end|>\n",
    }

dataset = dataset.map(chatml_format, remove_columns=dataset.column_names)

# Fine-tune the model using DPO Trainer
training_args = DPOConfig(
    output_dir="model-output",
    logging_steps=50,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
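    # effective per-device batch size: 4 x 4 = 16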
    gradient_checkpointing=True,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    num_train_epochs=4,
    save_strategy="no",
    optim="adamw_torch",
    warmup_ratio=0.03,
    bf16=True,
    report_to="wandb",
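    # beta is the DPO loss temperature: larger values keep the policy closer
    # to the reference model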
    beta=0.1,
    max_prompt_length=2048,
    max_length=4096,
    disable_dropout=False,
    force_use_ref_model=True,
)

# Note: recent trl releases rename DPOTrainer's `tokenizer` argument to
# `processing_class`; with the from-git installs above, use whichever name
# your version expects.
trainer = DPOTrainer(
    model=model,
    ref_model=AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16),
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dataset,
)
trainer.train()

# Save the fine-tuned adapter (trainer.model is a PeftModel, so this writes
# only the LoRA weights and adapter config, not the full base model)
trainer.model.save_pretrained("final_ckpt")
tokenizer.save_pretrained("final_ckpt")

# Test the fine-tuned model. "final_ckpt" contains only the LoRA adapter,
# so load it with AutoPeftModelForCausalLM, which restores the base model
# and applies the adapter on top.
from transformers import pipeline
from peft import AutoPeftModelForCausalLM

fine_tuned_model = AutoPeftModelForCausalLM.from_pretrained(
    "final_ckpt", torch_dtype=torch.bfloat16, device_map="auto"
)
text_gen_pipeline = pipeline(
    "text-generation",
    model=fine_tuned_model,
    tokenizer=tokenizer,
    max_length=4096,
)

messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant chatbot that provides concise answers.",
    },
    {
        "role": "user",
        "content": "What are GPUs and why would I use them for machine learning tasks?",
    },
]
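# Assemble a ChatML prompt manually so it matches the training format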
prompt = "".join(f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n" for msg in messages)

sequences = text_gen_pipeline(prompt, do_sample=True, temperature=0.7, top_p=0.9)
print(sequences[0]["generated_text"])
```
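
The script above saves only the LoRA adapter. To publish a standalone checkpoint, the adapter can be merged into the base weights; a minimal sketch, where the output directory and Hub repo id are assumptions:

```python
# Merge the LoRA adapter into the base weights for a standalone checkpoint
from peft import AutoPeftModelForCausalLM

merged = AutoPeftModelForCausalLM.from_pretrained(
    "final_ckpt", torch_dtype=torch.bfloat16
).merge_and_unload()  # folds the adapter weights into the base model

merged.save_pretrained("Llama3-8B-DPO")
tokenizer.save_pretrained("Llama3-8B-DPO")
# merged.push_to_hub("CultriX/Llama3-8B-DPO")     # assumed repo id
# tokenizer.push_to_hub("CultriX/Llama3-8B-DPO")
```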