# coding=utf-8
from typing import Dict

import os
import time

import pandas as pd
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig, PeftModel, TaskType
from transformers import PreTrainedTokenizerFast, TrainingArguments
from trl import DPOTrainer

from config import DpoConfig, T5ModelConfig
from model.chat_model import TextToTextModel
from utils.functions import get_T5_config

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'


def get_dataset(split: str, file: str, cache_dir: str = '.cache') -> Dataset:
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format. | |
The dataset is converted to a dictionary with the following structure: | |
{ | |
'prompt': List[str], | |
'chosen': List[str], | |
'rejected': List[str], | |
} | |
""" | |
dataset = load_dataset('json', data_files=file, split=split, cache_dir=cache_dir) | |
def split_prompt_and_responses(sample: dict) -> Dict[str, str]: | |
return { | |
# add an eos token for signal that end of sentence, using in generate. | |
"prompt": f"{sample['prompt']}[EOS]", | |
"chosen": f"{sample['chosen']}[EOS]", | |
"rejected": f"{sample['rejected']}[EOS]", | |
} | |
return dataset.map(split_prompt_and_responses).shuffle(2333) | |
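# A minimal sketch of one record assumed to be in the JSON file passed to get_dataset
# (field names follow the docstring above; the text itself is purely illustrative):
# {"prompt": "How do I boil an egg?", "chosen": "Put the egg in boiling water for ...", "rejected": "I don't know."}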
def train_dpo(config: DpoConfig, peft_config: LoraConfig = None) -> None:

    # step 1. load the tokenizer
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # step 2. load the SFT model to train and a frozen copy as the reference model
    model_train, model_ref = None, None
    if os.path.isdir(config.sft_model_file):
        # a directory was given, so use from_pretrained
        model_train = TextToTextModel.from_pretrained(config.sft_model_file)
        model_ref = TextToTextModel.from_pretrained(config.sft_model_file)
    else:
        # a single checkpoint file was given, so rebuild the model and use load_state_dict
        t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)

        model_train = TextToTextModel(t5_config)
        model_train.load_state_dict(torch.load(config.sft_model_file, map_location='cpu'))  # load on CPU to avoid device-placement errors

        model_ref = TextToTextModel(t5_config)
        model_ref.load_state_dict(torch.load(config.sft_model_file, map_location='cpu'))
    # step 3. load the training dataset
    train_dataset = get_dataset("train", file=config.dpo_train_file)

    # step 4. load the evaluation dataset (optional)
    # eval_dataset = get_dataset("train", file=config.dpo_eval_file)
    eval_dataset = None
    # step 5. initialize the training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=config.per_device_train_batch_size,
        num_train_epochs=config.num_train_epochs,
        auto_find_batch_size=True,
        remove_unused_columns=False,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        logging_first_step=True,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        output_dir=config.output_dir,
        optim="adafactor",
        report_to="tensorboard",
        log_level='info',
        warmup_steps=config.warmup_steps,
        bf16=False,
        fp16=config.fp16,
        seed=config.seed,
        logging_dir=config.log_dir,
    )
    # step 6. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model_train,
        model_ref,
        peft_config=peft_config,
        args=training_args,
        beta=config.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=config.max_seq_len,
        max_target_length=config.max_seq_len,
        max_prompt_length=config.max_seq_len,
        generate_during_eval=True,
        is_encoder_decoder=True,
    )
    # step 7. train
    dpo_trainer.train(
        # resume_from_checkpoint=True
    )
    # step 8. save the training log
    loss_log = pd.DataFrame(dpo_trainer.state.log_history)
    log_dir = './logs'
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    loss_log.to_csv(f"{log_dir}/dpo_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")
    # step 9. save the full model or the LoRA adapter
    suffix = '/lora/' if peft_config is not None else '/dpo'
    model_save_dir = '/'.join(config.sft_model_file.split('/')[0: -1]) + suffix
    dpo_trainer.save_model(model_save_dir)
    print('save model or lora adapter to: {}'.format(model_save_dir))
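# Note: when a peft_config is passed in, only the LoRA adapter weights are expected to land in the
# '/lora/' directory above (an assumption based on how the adapter is reloaded below), so
# merge_lora_weight_into_model has to be run afterwards to obtain a standalone merged model.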
def merge_lora_weight_into_model(config: DpoConfig, peft_config: LoraConfig) -> None:

    # step 1. load the tokenizer
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # step 2. load the SFT base model
    sft_model = None
    if os.path.isdir(config.sft_model_file):
        # a directory was given, so use from_pretrained
        sft_model = TextToTextModel.from_pretrained(config.sft_model_file)
    else:
        # a single checkpoint file was given, so rebuild the model and use load_state_dict
        t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
        sft_model = TextToTextModel(t5_config)
        sft_model.load_state_dict(torch.load(config.sft_model_file, map_location='cpu'))  # load on CPU to avoid device-placement errors
    # Note: this path must match model_save_dir in train_dpo above, i.e. the code from
    # "step 9. save the full model or the LoRA adapter":
    #   suffix = '/lora/' if peft_config is not None else '/dpo'
    #   model_save_dir = '/'.join(config.sft_model_file.split('/')[0: -1]) + suffix
    adapter_save_dir = '/'.join(config.sft_model_file.split('/')[0: -1]) + '/lora'

    peft_model = PeftModel.from_pretrained(
        model=sft_model,
        model_id=adapter_save_dir,
        config=peft_config,
        adapter_name='adapter',
    )
    # peft_model = PeftModel(
    #     model=sft_model,
    #     peft_config=peft_config,
    #     adapter_name='adapter',
    # )

    # step 3. load the adapter weights
    print('load adapter from dir: {}'.format(adapter_save_dir))
    peft_model.load_adapter(model_id=adapter_save_dir, adapter_name='adapter')

    # step 4. merge the adapter weights into the base model
    peft_model = peft_model.merge_and_unload()

    # step 5. save the merged model
    save_merge_file = config.sft_model_file + '.dpo_lora_merged'
    sft_model.save_pretrained(save_merge_file)
    print('save merge model file to: {}'.format(save_merge_file))
if __name__ == "__main__":
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,  # seq2seq (text-to-text) LoRA task
        inference_mode=False,
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="all",
    )

    dpo_config = DpoConfig()

    # 1. train (pass peft_config=peft_config instead of None to train a LoRA adapter rather than the full model)
    train_dpo(dpo_config, peft_config=None)

    # 2. merge the LoRA adapter into the model (only needed after LoRA training)
    # merge_lora_weight_into_model(dpo_config, peft_config)
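    # After merging, the standalone model can presumably be reloaded with from_pretrained,
    # e.g. (illustrative; the path mirrors save_merge_file in merge_lora_weight_into_model):
    # merged_model = TextToTextModel.from_pretrained(dpo_config.sft_model_file + '.dpo_lora_merged')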