Spaces:
Runtime error
Runtime error
File size: 5,075 Bytes
f4fac26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# coding=utf-8
import os
import time
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import (
    DataCollatorForSeq2Seq,
    PreTrainedTokenizerFast,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from transformers.generation.configuration_utils import GenerationConfig

from model.chat_model import TextToTextModel
from config import SFTconfig, T5ModelConfig
from utils.functions import get_T5_config, MyTrainerCallback
# Register tqdm's pandas integration.
tqdm.pandas()
# Turn off TensorFlow's oneDNN custom operations.
# NOTE(review): TensorFlow reads this variable when it is imported, so it only takes
# effect if TensorFlow was not already pulled in by the imports above — confirm.
os.environ.update({'TF_ENABLE_ONEDNN_OPTS': '0'})
def get_dataset(
    file: str,
    split: str,
    tokenizer: PreTrainedTokenizerFast,
    cache_dir: str = '.cache',
    max_len: Optional[int] = None,
) -> Dataset:
    """
    Load a prompt/response JSON dataset and tokenize it for seq2seq training.

    Each record must provide ``prompt`` and ``response`` text fields. Both are
    tokenized without padding, an EOS token is appended to every sequence, and the
    results are stored as ``input_ids`` (from ``prompt``) and ``labels`` (from
    ``response``). All original columns are removed.

    Args:
        file: JSON data file passed to ``load_dataset`` as ``data_files``.
        split: Split to load, e.g. ``"train"``.
        tokenizer: Tokenizer used for encoding; must define ``eos_token_id``.
        cache_dir: Cache directory passed to ``load_dataset``.
        max_len: Optional maximum length of each encoded sequence, including the
            appended EOS token. ``None`` (default) disables truncation.

    Returns:
        The tokenized dataset with ``input_ids`` and ``labels`` columns.

    Raises:
        ValueError: If the tokenizer has no EOS token, or ``max_len`` < 1.
    """
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        raise ValueError('tokenizer has no eos_token_id; cannot terminate sequences')
    if max_len is not None and max_len < 1:
        raise ValueError(f'max_len must be >= 1 (room for the EOS token), got {max_len}')

    # Store ids compactly: uint16 holds ids 0..65535, i.e. vocabularies of up to
    # 65536 tokens. Larger vocabularies fall back to uint32 instead of overflowing
    # (silent wrap-around on older NumPy, OverflowError on NumPy 2).
    id_dtype = np.uint16 if len(tokenizer) <= np.iinfo(np.uint16).max + 1 else np.uint32

    # Keep at most max_len - 1 content tokens so the EOS token always fits;
    # a None slice bound keeps the whole sequence.
    keep = None if max_len is None else max_len - 1

    # Load the JSON dataset; to load parquet instead, change 'json' to 'parquet'.
    dataset = load_dataset('json', data_files=file, split=split, cache_dir=cache_dir)

    def _encode(texts: List[str]) -> List[np.ndarray]:
        """Tokenize a batch of texts and append EOS to every sequence."""
        encoded = tokenizer(texts, truncation=False, padding=False, return_attention_mask=False)
        return [np.array(ids[:keep] + [eos_token_id], dtype=id_dtype) for ids in encoded['input_ids']]

    def tokens_to_ids(samples: Dict[str, List[str]]) -> Dict[str, List[np.ndarray]]:
        """Map a batch of prompt/response records to input_ids/labels arrays."""
        return {
            'input_ids': _encode(samples['prompt']),
            'labels': _encode(samples['response']),
        }

    dataset = dataset.map(tokens_to_ids, batched=True, batch_size=8192, remove_columns=dataset.column_names)
    return dataset
def _load_model(config: SFTconfig, tokenizer: PreTrainedTokenizerFast) -> TextToTextModel:
    """Load the model to fine-tune from a checkpoint directory or a state-dict file."""
    ckp = config.finetune_from_ckp_file
    if os.path.isdir(ckp):
        # A directory is loaded with from_pretrained().
        return TextToTextModel.from_pretrained(ckp)
    # Otherwise treat the path as a raw state dict: build the architecture from the
    # T5 config, then load the weights. map_location='cpu' avoids device-mismatch
    # errors when the checkpoint was saved on a different device.
    t5_config = get_T5_config(
        T5ModelConfig(),
        vocab_size=len(tokenizer),
        decoder_start_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    model = TextToTextModel(t5_config)
    model.load_state_dict(torch.load(ckp, map_location='cpu'))
    return model


def _build_generation_config(tokenizer: PreTrainedTokenizerFast) -> GenerationConfig:
    """Greedy-decoding generation settings attached to the training arguments."""
    generation_config = GenerationConfig()
    generation_config.remove_invalid_values = True
    generation_config.eos_token_id = tokenizer.eos_token_id
    generation_config.pad_token_id = tokenizer.pad_token_id
    # Decoding starts from the pad token, matching the decoder_start_token_id used
    # when the model config is built from scratch in _load_model.
    generation_config.decoder_start_token_id = tokenizer.pad_token_id
    generation_config.max_new_tokens = 320
    generation_config.repetition_penalty = 1.5
    generation_config.num_beams = 1  # greedy search
    generation_config.do_sample = False  # greedy search
    return generation_config


def _save_log_history(log_history: list, log_dir: str = './logs') -> None:
    """Write the trainer's log history to a timestamped CSV file under log_dir."""
    # exist_ok avoids the check-then-create race and also creates missing parents.
    os.makedirs(log_dir, exist_ok=True)
    loss_log = pd.DataFrame(log_history)
    loss_log.to_csv(f"{log_dir}/sft_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")


def sft_train(config: SFTconfig) -> None:
    """
    Supervised fine-tuning (SFT) of the T5-based TextToTextModel.

    Loads the tokenizer and model, tokenizes the training file, trains with a
    Seq2SeqTrainer, dumps the trainer log history to ./logs as CSV, and saves the
    final model to ``config.output_dir``.

    Args:
        config: Training configuration (paths, batch size, learning rate, ...).
    """
    # Step 1: load the tokenizer.
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # Step 2: load the pretrained model to fine-tune.
    model = _load_model(config, tokenizer)

    # Step 3: load and tokenize the training dataset.
    dataset = get_dataset(file=config.sft_train_file, split="train", tokenizer=tokenizer)

    # Step 4: training arguments. T5 is a sequence-to-sequence model, so the Seq2Seq
    # variants (Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer) are
    # used; Hugging Face's ready-made SFT tooling targets language models.
    training_args = Seq2SeqTrainingArguments(
        output_dir=config.output_dir,
        per_device_train_batch_size=config.batch_size,
        auto_find_batch_size=True,  # shrink the batch size automatically on OOM
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        logging_steps=config.logging_steps,
        num_train_epochs=config.num_train_epochs,
        optim="adafactor",
        report_to='tensorboard',
        log_level='info',
        save_steps=config.save_steps,
        save_total_limit=3,
        fp16=config.fp16,
        logging_first_step=config.logging_first_step,
        warmup_steps=config.warmup_steps,
        seed=config.seed,
        generation_config=_build_generation_config(tokenizer),
    )

    # Step 5: seq2seq batch collator.
    # NOTE(review): DataCollatorForSeq2Seq only applies max_length together with
    # padding='max_length'; with its default padding it neither truncates nor pads to
    # config.max_seq_len, so over-long samples are not cut here — confirm intent.
    collator = DataCollatorForSeq2Seq(tokenizer, max_length=config.max_seq_len)

    # Step 6: the trainer. MyTrainerCallback is presumably the CUDA-cache-emptying
    # callback (per its original variable name) — TODO confirm in utils.functions.
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=collator,
        callbacks=[MyTrainerCallback()],
    )

    # Step 7: train. Pass resume_from_checkpoint=True to resume from the last
    # checkpoint in output_dir.
    trainer.train()

    # Step 8: persist the log history, then save the final model.
    _save_log_history(trainer.state.log_history)
    trainer.save_model(config.output_dir)
if __name__ == '__main__':
    # Script entry point: fine-tune with the default SFT configuration.
    sft_config = SFTconfig()
    sft_train(config=sft_config)