|
|
|
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import ORPOConfig, ORPOTrainer
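
# Every hyperparameter below is exposed as a CLI flag through HfArgumentParser.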
|
|
|
@dataclass
class ScriptArguments:
    model_name: Optional[str] = field(default="microsoft/phi-2", metadata={"help": "the model name"})
    optim: Optional[str] = field(default="adamw_torch", metadata={"help": "the optimizer"})
    data_name: Optional[str] = field(
        default="argilla/ultrafeedback-binarized-preferences-cleaned", metadata={"help": "the dataset name"}
    )
    cache_dir: Optional[str] = field(default="", metadata={"help": "the cache directory for models and datasets"})
    log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
    output_dir: Optional[str] = field(default="", metadata={"help": "the output directory"})
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the learning rate scheduler"})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "the batch size per device"})
    num_train_epochs: Optional[int] = field(default=5, metadata={"help": "the number of training epochs"})
    beta: Optional[float] = field(default=0.25, metadata={"help": "weighting hyperparameter for the odds-ratio loss L_OR"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
|
|
|
|
|
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
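
# ORPO augments the standard NLL loss with an odds-ratio preference term
# (L_OR); `beta` below weights that term, and no reference model is needed.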
|
|
|
config = ORPOConfig(
    output_dir=script_args.output_dir,
    max_prompt_length=1024,
    max_completion_length=2048,
    max_length=2048,
    logging_steps=100,
    save_strategy="no",
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    remove_unused_columns=False,
    learning_rate=script_args.learning_rate,
    optim=script_args.optim,
    lr_scheduler_type=script_args.lr_scheduler_type,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": True},
    beta=script_args.beta,
    report_to=script_args.log_with,
    num_train_epochs=script_args.num_train_epochs,
    bf16=True,
    do_eval=False,
)
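
# Note: bf16 training and flash_attention_2 (used below) assume an
# Ampere-or-newer NVIDIA GPU with the flash-attn package installed.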
|
|
|
model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    cache_dir=script_args.cache_dir,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, cache_dir=script_args.cache_dir)
# phi-2 has no dedicated pad token; reuse EOS so batches can be padded.
tokenizer.pad_token_id = tokenizer.eos_token_id
# Zephyr-style chat template: <|system|>/<|user|>/<|assistant|> turns, each ended with EOS.
tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
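# With add_generation_prompt=True, a single user turn renders roughly as:
#   <|user|>
#   {content}<|endoftext|>
#   <|assistant|>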
|
|
|
def build_dataset(tokenizer):
    ds_train = load_dataset(script_args.data_name, split="train", cache_dir=script_args.cache_dir)

    def chat_template_to_text(sample):
        # Each "chosen"/"rejected" entry is a [user, assistant] message pair;
        # keep only the assistant reply as a plain string.
        sample["chosen"] = [item_chosen[1]["content"] for item_chosen in sample["chosen"]]
        sample["rejected"] = [item_rejected[1]["content"] for item_rejected in sample["rejected"]]
        # Render each raw prompt through the chat template, appending the
        # generation prompt so completions start at the assistant turn.
        sample["prompt"] = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": item_prompt}], tokenize=False, add_generation_prompt=True
            )
            for item_prompt in sample["prompt"]
        ]
        return sample

    ds_train = ds_train.map(chat_template_to_text, batched=True, num_proc=8)
    return ds_train
|
|
|
train = build_dataset(tokenizer=tokenizer) |
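
# ORPOTrainer consumes string columns "prompt", "chosen", and "rejected",
# which is exactly the shape build_dataset produced above.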
|
|
|
trainer = ORPOTrainer(
    model=model,
    args=config,
    tokenizer=tokenizer,
    train_dataset=train,
)
|
|
|
trainer.train() |
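
# Example launch (hypothetical filename; assumes a bf16-capable GPU with
# flash-attn installed and wandb configured):
#   python train_orpo_phi2.py --output_dir ./phi2-orpo --num_train_epochs 3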