Model Card for Llama3-8B-DPO
License: Apache-2.0
Datasets: CultriX/llama70B-dpo-dataset
Language: English
Base Model: NousResearch/Hermes-3-Llama-3.1-8B
Pipeline Tag: Text-Generation
Tags: DPO, Llama3, General
Library: Transformers
Performance
| Model Name | AGIEval | TruthfulQA | BigBench |
|---|---|---|---|
| Hermes-3-Llama-3.1-8B | 41.51 | 58.61 | 43.08 |
| Llama3-8B-DPO | 41.87 | 71.38 | 44.5 |
Training Script
# Install required libraries.
# NOTE(review): the leading "!" is notebook shell-escape syntax, not plain Python;
# presumably this script is meant to be run cell-by-cell in Jupyter/Colab — confirm
# before executing it as a standalone .py file.
!pip install --upgrade pip
# transformers, peft and trl are installed straight from their GitHub repositories,
# so no version is pinned and the installed code is whatever is current at install time.
!pip install git+https://github.com/huggingface/transformers
!pip install git+https://github.com/huggingface/peft.git
!pip install git+https://github.com/huggingface/trl.git
# Logging, launcher and dataset tooling come from PyPI, upgraded to the latest release.
!pip install --upgrade wandb accelerate datasets
import os
import gc
import torch
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from trl import DPOTrainer, DPOConfig
from huggingface_hub import notebook_login
# Log in to Hugging Face and WandB.
# Each service first looks for a credential in the environment (HF_TOKEN /
# WANDB_API_KEY) and only falls back to an interactive prompt when it is absent.
from huggingface_hub import login

hf_token = os.getenv('HF_TOKEN')
if hf_token:
    # notebook_login() only renders the interactive widget and accepts no
    # token argument, so a token taken from the environment goes to login().
    login(token=hf_token)
else:
    notebook_login()

wb_token = os.getenv('WANDB_API_KEY')
if wb_token:
    wandb.login(key=wb_token)
else:
    wandb.login()
# Set model names.
# model_name is the checkpoint being fine-tuned; base_model_name is an alias of
# it that the loading code below uses.
model_name = "NousResearch/Hermes-3-Llama-3.1-8B"
base_model_name = model_name
# Name of the resulting DPO fine-tune, matching the title of this model card
# (the previous value was an unrelated ORPO model name).
fine_tuned_model_name = "Llama3-8B-DPO"
# Load the tokenizer and the policy model from the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Weights are loaded in bfloat16 and placed automatically across available devices.
model_load_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained(base_model_name, **model_load_kwargs)
# Turn off the generation cache on the model config before training.
model.config.use_cache = False
# Wrap the model with LoRA adapters so only the adapter weights are fine-tuned.
lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=lora_target_modules,
)
model = get_peft_model(model, peft_config)
# Enable gradient checkpointing on the wrapped model.
model.gradient_checkpointing_enable()
# Load the preference dataset and keep only its "train" split;
# the rows are reformatted further below.
dataset_splits = load_dataset("CultriX/llama70B-dpo-dataset")
dataset = dataset_splits["train"]
def chatml_format(example: dict) -> dict:
    """Convert one preference record into ChatML-formatted DPO fields.

    Args:
        example: A dataset row. The keys ``system``, ``question``, ``chosen``
            and ``rejected`` are read. A key that is missing, or whose value
            is ``None``, is treated as an empty string so that the literal
            text "None" never leaks into the training data.

    Returns:
        A dict with three strings:

        * ``prompt`` - an optional system turn, the user turn, and the opening
          of the assistant turn (``<|im_start|>assistant\\n``).
        * ``chosen`` / ``rejected`` - the respective completion followed by
          the ``<|im_end|>`` terminator and a newline.
    """

    def _text(key):
        # dict.get's default only covers *missing* keys; null column values
        # arrive as None and must be normalised explicitly.
        value = example.get(key)
        return "" if value is None else value

    system = _text("system")
    question = _text("question")
    chosen = _text("chosen")
    rejected = _text("rejected")

    prompt = ""
    if system:
        # The system turn is emitted only when there is system text.
        prompt += f"<|im_start|>system\n{system}<|im_end|>\n"
    prompt += f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"

    return {
        "prompt": prompt,
        "chosen": f"{chosen}<|im_end|>\n",
        "rejected": f"{rejected}<|im_end|>\n",
    }
# Reformat every row with chatml_format and drop all of the original columns,
# leaving only the fields that chatml_format returns.
original_columns = dataset.column_names
dataset = dataset.map(chatml_format, remove_columns=original_columns)
# Fine-tune the model using DPO Trainer: build the training configuration.
# The settings are grouped by concern and merged into a single DPOConfig.
schedule_options = {
    "learning_rate": 1e-4,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.03,
    "num_train_epochs": 4,
    "optim": "adamw_torch",
}
batching_options = {
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "gradient_checkpointing": True,
    "bf16": True,
}
dpo_options = {
    "beta": 0.1,
    "max_prompt_length": 2048,
    "max_length": 4096,
    "disable_dropout": False,
    "force_use_ref_model": True,
}
training_args = DPOConfig(
    output_dir="model-output",
    logging_steps=50,
    save_strategy="no",  # no intermediate checkpoints are written
    report_to="wandb",
    **schedule_options,
    **batching_options,
    **dpo_options,
)
# Build the DPO trainer and run training.
# A second copy of the base checkpoint is loaded to serve as the DPO reference
# model; training_args sets force_use_ref_model=True so that it is used
# alongside the LoRA-wrapped policy model.
reference_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16,
)
trainer = DPOTrainer(
    model=model,
    ref_model=reference_model,
    args=training_args,
    # TRL is installed from its Git repository above; current DPOTrainer takes
    # the tokenizer as `processing_class` (the old `tokenizer=` keyword was
    # deprecated and then removed).
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
# Save the trained model and then the tokenizer into the same output directory.
for artifact in (trainer.model, tokenizer):
    artifact.save_pretrained("final_ckpt")
# Test the fine-tuned model: reload it from the "final_ckpt" directory and
# wrap it in a text-generation pipeline.
from transformers import pipeline

fine_tuned_model = AutoModelForCausalLM.from_pretrained(
    "final_ckpt",
    torch_dtype=torch.bfloat16,
)
pipeline_options = {
    "model": fine_tuned_model,
    "tokenizer": tokenizer,
    "max_length": 4096,
}
text_gen_pipeline = pipeline("text-generation", **pipeline_options)
# Sample conversation used to smoke-test the fine-tuned model.
messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant chatbot that provides concise answers.",
    },
    {
        "role": "user",
        "content": "What are GPUs and why would I use them for machine learning tasks?",
    },
]
# Render the conversation in the same ChatML layout that chatml_format produces
# for training prompts: one <|im_start|>role ... <|im_end|> turn per message.
prompt = "".join(
    f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n" for msg in messages
)
# Open the assistant turn, exactly as every training prompt does; without this
# header the model is not cued to answer as the assistant.
prompt += "<|im_start|>assistant\n"
# Sample a completion for the prompt and print the first returned sequence.
sampling_options = {"do_sample": True, "temperature": 0.7, "top_p": 0.9}
sequences = text_gen_pipeline(prompt, **sampling_options)
first_sequence = sequences[0]
print(first_sequence["generated_text"])
- Downloads last month: 11
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social
visibility and check back later, or deploy to Inference Endpoints (dedicated)
instead.
Model tree for CultriX/Llama3-8B-DPO
Base model
meta-llama/Llama-3.1-8B
Finetuned
NousResearch/Hermes-3-Llama-3.1-8B