# coding=utf-8
from typing import Dict
import time
import os

import pandas as pd
import numpy as np
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import PreTrainedTokenizerFast, Seq2SeqTrainer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from transformers.generation.configuration_utils import GenerationConfig

from model.chat_model import TextToTextModel
from config import SFTconfig, T5ModelConfig
from utils.functions import get_T5_config, MyTrainerCallback

tqdm.pandas()
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

def get_dataset(file: str, split: str, tokenizer: PreTrainedTokenizerFast, cache_dir: str='.cache') -> Dataset:
    """
    Load the dataset.
    """
    # Load a JSON dataset; to load parquet instead, change 'json' to 'parquet'.
    dataset = load_dataset('json', data_files=file, split=split, cache_dir=cache_dir)

    def tokens_to_ids(samples: dict) -> Dict[str, list]:

        eos_token_id = tokenizer.eos_token_id

        batch_prompt = samples['prompt']
        batch_response = samples['response']

        encoded_prompt = tokenizer(batch_prompt, truncation=False, padding=False, return_attention_mask=False)
        encoded_response = tokenizer(batch_response, truncation=False, padding=False, return_attention_mask=False)

        # A vocab size under 65535 fits in uint16; append eos_token_id to every sample.
        input_ids = [np.array(item + [eos_token_id], dtype=np.uint16) for item in encoded_prompt["input_ids"]]
        labels = [np.array(item + [eos_token_id], dtype=np.uint16) for item in encoded_response["input_ids"]]

        return {
            'input_ids': input_ids,
            'labels': labels,
        }

    dataset = dataset.map(tokens_to_ids, batched=True, batch_size=8192, remove_columns=dataset.column_names)

    return dataset
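
# A minimal usage sketch for get_dataset (the paths below are hypothetical;
# it assumes a trained tokenizer directory and a JSON file whose records have
# 'prompt' and 'response' fields, as tokens_to_ids above expects):
#
#   tokenizer = PreTrainedTokenizerFast.from_pretrained('./model_save/tokenizer')
#   dataset = get_dataset(file='./data/sft_train.json', split='train', tokenizer=tokenizer)
#   print(dataset[0]['input_ids'][-1])  # every sample ends with tokenizer.eos_token_id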
def sft_train(config: SFTconfig) -> None:

    # step 1. Load the tokenizer.
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # step 2. Load the pretrained model.
    model = None
    if os.path.isdir(config.finetune_from_ckp_file):
        # A directory was passed in: load with from_pretrained.
        model = TextToTextModel.from_pretrained(config.finetune_from_ckp_file)
    else:
        # A single checkpoint file was passed in: build the model, then load_state_dict.
        t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
        model = TextToTextModel(t5_config)
        model.load_state_dict(torch.load(config.finetune_from_ckp_file, map_location='cpu'))  # map to CPU to avoid device-mismatch errors

    # step 3. Load the dataset.
    dataset = get_dataset(file=config.sft_train_file, split="train", tokenizer=tokenizer)

    # step 4. Define the training arguments.
    # T5 is a sequence-to-sequence model, so use Seq2SeqTrainingArguments,
    # DataCollatorForSeq2Seq and Seq2SeqTrainer.
    # Hugging Face's official SFT tooling targets (causal) language models.

    generation_config = GenerationConfig()
    generation_config.remove_invalid_values = True
    generation_config.eos_token_id = tokenizer.eos_token_id
    generation_config.pad_token_id = tokenizer.pad_token_id
    generation_config.decoder_start_token_id = tokenizer.pad_token_id
    generation_config.max_new_tokens = 320
    generation_config.repetition_penalty = 1.5
    generation_config.num_beams = 1      # greedy search
    generation_config.do_sample = False  # greedy search

    training_args = Seq2SeqTrainingArguments(
        output_dir=config.output_dir,
        per_device_train_batch_size=config.batch_size,
        auto_find_batch_size=True,  # guard against OOM
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        logging_steps=config.logging_steps,
        num_train_epochs=config.num_train_epochs,
        optim="adafactor",
        report_to='tensorboard',
        log_level='info',
        save_steps=config.save_steps,
        save_total_limit=3,
        fp16=config.fp16,
        logging_first_step=config.logging_first_step,
        warmup_steps=config.warmup_steps,
        seed=config.seed,
        generation_config=generation_config,
    )

    # step 5. Initialize the data collator.
    collator = DataCollatorForSeq2Seq(tokenizer, max_length=config.max_seq_len)
    empty_cuda_cache = MyTrainerCallback()

    # step 6. Define the trainer.
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=collator,
        callbacks=[empty_cuda_cache],
    )

    # step 7. Train.
    trainer.train(
        # resume_from_checkpoint=True
    )

    loss_log = pd.DataFrame(trainer.state.log_history)
    log_dir = './logs'
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    loss_log.to_csv(f"{log_dir}/sft_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")

    # step 8. Save the model.
    trainer.save_model(config.output_dir)

if __name__ == '__main__':
    config = SFTconfig()
    sft_train(config)
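
# A minimal post-training inference sketch (an assumption, not confirmed by this
# file: it presumes TextToTextModel inherits the standard generate() API from
# its T5 base and that config.output_dir holds the weights saved above):
#
#   tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)
#   model = TextToTextModel.from_pretrained(config.output_dir)
#   encoded = tokenizer('an example prompt', return_tensors='pt')
#   outputs = model.generate(encoded.input_ids, eos_token_id=tokenizer.eos_token_id,
#                            decoder_start_token_id=tokenizer.pad_token_id, max_new_tokens=320)
#   print(tokenizer.decode(outputs[0], skip_special_tokens=True))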