# rm_code/dpo.py
from unsloth import PatchDPOTrainer
from unsloth import FastLanguageModel
from trl import DPOTrainer, DPOConfig
from datasets import DatasetDict, concatenate_datasets, load_dataset
PatchDPOTrainer()
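# PatchDPOTrainer swaps in Unsloth's memory-efficient DPO kernels, so it has
# to run before the DPOTrainer below is constructed.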
max_seq_length = 2048 # Choose any! Unsloth auto-supports RoPE scaling internally.
dtype = None # None for auto-detection; float16 for Tesla T4/V100, bfloat16 for Ampere+
load_in_4bit = True # Use 4-bit quantization to reduce memory usage. Can be False.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "hahayang012/Mistral-Small-3.1-24B-Base-2503-SFT", # Choose any! e.g. mistralai/Mistral-7B-Instruct-v0.2
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # needed for gated models like meta-llama/Llama-2-7b-hf
)
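# The 4-bit SFT checkpoint plus the LoRA adapters attached below make this a
# QLoRA-style DPO run: only the adapter weights are trained.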
ds1 = load_dataset("parquet", data_files="/home/dataset/data/ds1.parquet")
ds2 = load_dataset("parquet", data_files="/home/dataset/data/ds2.parquet")
ds3 = load_dataset("parquet", data_files="/home/dataset/data/ds3.parquet")
ds4 = load_dataset("parquet", data_files="/home/dataset/data/ds4.parquet")
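# Each load_dataset("parquet", ...) call returns a DatasetDict with a single
# "train" split; the four local paths above are assumed to exist.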
def prepare_dpo_dataset(dataset):
    """Rename the raw columns to the prompt/chosen/rejected schema DPOTrainer expects."""
    dataset = dataset.map(lambda x: {
        "prompt": x["chosen_prompt"],
        "chosen": x["chosen"],
        "rejected": x["reject"],
    })
    return dataset.select_columns(["prompt", "chosen", "rejected"])
ds1 = prepare_dpo_dataset(ds1)
ds2 = prepare_dpo_dataset(ds2)
ds3 = prepare_dpo_dataset(ds3)
ds4 = prepare_dpo_dataset(ds4)
# Merge the four prepared sources into the `raw_datasets` dict consumed by the trainer below.
raw_datasets = DatasetDict({
    "train": concatenate_datasets([ds1["train"], ds2["train"], ds3["train"], ds4["train"]])
})
model = FastLanguageModel.get_peft_model(
    model,
    r = 64, # Choose any number > 0! Suggested: 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 64,
    lora_dropout = 0, # Supports any value, but 0 is optimized
    bias = "none", # Supports any value, but "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False, # We support rank-stabilized LoRA
    loftq_config = None, # And LoftQ
)
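# With LoRA adapters attached, ref_model can stay None: TRL computes the
# reference log-probabilities by running the same model with its adapters
# disabled, so no second full-size copy of the base model is kept in memory.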
dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,
    args = DPOConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_ratio = 0.1,
        num_train_epochs = 3,
        learning_rate = 5e-6,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.0,
        lr_scheduler_type = "linear",
        seed = 42,
        output_dir = "outputs",
        report_to = "none", # Set to "wandb" etc. to enable experiment tracking
        # DPO-specific settings live on DPOConfig in current TRL releases:
        beta = 0.1,
        max_length = 1024,
        max_prompt_length = 512,
    ),
    train_dataset = raw_datasets["train"],
    # eval_dataset = raw_datasets["test"],
    tokenizer = tokenizer, # newer TRL versions take processing_class instead
)
dpo_trainer.train()
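# A minimal follow-up sketch (not part of the original script): persist the
# trained LoRA adapters and tokenizer. The output path name is an assumption.
model.save_pretrained("outputs/dpo_lora")
tokenizer.save_pretrained("outputs/dpo_lora")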