# chatmlTest / pre_train.py
# fangshengren's picture
# Upload 59 files
# f4fac26 verified
# raw
# history blame
# 4.81 kB
# coding=utf-8
import time
import os
import pandas as pd
from dataclasses import dataclass
import torch
from typing import Dict
from tqdm import tqdm
import numpy as np
from transformers import PreTrainedTokenizerFast, Seq2SeqTrainer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from transformers.generation.configuration_utils import GenerationConfig
from datasets import Dataset, load_dataset
from model.chat_model import TextToTextModel
from model.dataset import MyDataset
from config import TrainConfig, T5ModelConfig
from utils.functions import json_to_dataclass, get_T5_config, MyTrainerCallback
# Register tqdm's pandas integration (enables the `progress_*` methods on pandas objects).
tqdm.pandas()
# Set TF_ENABLE_ONEDNN_OPTS to '0' for this process.
# NOTE(review): presumably meant to turn off TensorFlow's oneDNN custom ops; this assignment
# runs after the third-party imports above, so confirm it still takes effect where intended.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
def get_dataset(file: str, split: str, tokenizer: PreTrainedTokenizerFast, cache_dir: str='.cache') -> Dataset:
    """
    Load a parquet dataset and tokenize it into ``input_ids`` / ``labels``.

    The parquet file must contain a ``prompt`` and a ``response`` column.
    Both columns are tokenized without truncation or padding, and the
    tokenizer's EOS token id is appended to every encoded sample. All
    original columns are dropped from the returned dataset.

    Args:
        file: Path (or paths) of the parquet data file(s), passed to
            ``load_dataset`` as ``data_files``.
        split: Name of the split to load, e.g. ``'train'``.
        tokenizer: Tokenizer used to encode prompts and responses.
        cache_dir: Directory handed to ``load_dataset`` for caching.

    Returns:
        A dataset with exactly two columns, ``input_ids`` and ``labels``.

    Raises:
        ValueError: If the tokenizer does not define an EOS token id.
    """
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        raise ValueError('tokenizer.eos_token_id is None; an EOS token is required to terminate every sample')

    # uint16 halves the storage of the token ids, but it is only safe while
    # every id fits in 16 bits. Fall back to uint32 for larger vocabularies
    # instead of overflowing (NumPy either wraps silently or raises,
    # depending on its version).
    max_token_id = max(len(tokenizer) - 1, eos_token_id)
    ids_dtype = np.uint16 if max_token_id <= np.iinfo(np.uint16).max else np.uint32

    dataset = load_dataset('parquet', data_files=file, split=split, cache_dir=cache_dir)

    def tokens_to_ids(samples: dict) -> Dict[str, list]:
        """Encode one batch of prompt/response texts and append EOS to each sample."""
        batch_prompt = samples['prompt']
        batch_response = samples['response']

        encoded_prompt = tokenizer(batch_prompt, truncation=False, padding=False, return_attention_mask=False)
        encoded_response = tokenizer(batch_response, truncation=False, padding=False, return_attention_mask=False)

        # Every sample gets the EOS token id appended.
        input_ids = [np.array(item + [eos_token_id], dtype=ids_dtype) for item in encoded_prompt["input_ids"]]
        labels = [np.array(item + [eos_token_id], dtype=ids_dtype) for item in encoded_response["input_ids"]]

        return {
            'input_ids': input_ids,
            'labels': labels,
        }

    # Tokenize in large batches and keep only the freshly produced columns.
    dataset = dataset.map(tokens_to_ids, batched=True, batch_size=8192, remove_columns=dataset.column_names)
    return dataset
def pre_train(config: TrainConfig) -> None:
    """
    Pre-train the text-to-text (T5-style) model with the Hugging Face Seq2SeqTrainer.

    The function loads the tokenizer, builds the model from the T5 config,
    tokenizes the training file, runs training, writes the trainer's log
    history to a timestamped CSV under ``./logs`` and finally saves the model
    to ``config.output_dir``.

    Args:
        config: Training configuration. The fields read here are
            ``tokenizer_dir``, ``train_file``, ``output_dir``,
            ``batch_size_per_gpu``, ``gradient_accumulation_steps``,
            ``learn_rate``, ``logging_steps``, ``epochs``, ``save_steps``,
            ``mixed_precision``, ``warmup_steps``, ``seed`` and ``max_seq_len``.
    """
    # Step 1: load the tokenizer.
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # Step 2: build the model configuration; the pad token doubles as the decoder start token.
    t5_config = get_T5_config(
        T5ModelConfig(),
        vocab_size=len(tokenizer),
        decoder_start_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Step 3: initialise the model.
    model = TextToTextModel(t5_config)

    # Step 4: load and tokenize the training dataset.
    dataset = get_dataset(file=config.train_file, split='train', tokenizer=tokenizer)

    # Step 5: define the training arguments.
    # T5 is a sequence-to-sequence model, so Seq2SeqTrainingArguments,
    # DataCollatorForSeq2Seq and Seq2SeqTrainer are used; the Hugging Face SFT
    # tooling targets plain language models instead.
    generation_config = GenerationConfig()
    generation_config.remove_invalid_values = True
    generation_config.eos_token_id = tokenizer.eos_token_id
    generation_config.pad_token_id = tokenizer.pad_token_id
    generation_config.decoder_start_token_id = tokenizer.pad_token_id
    generation_config.max_new_tokens = 320
    generation_config.num_beams = 1  # greedy search
    generation_config.do_sample = False  # greedy search

    training_args = Seq2SeqTrainingArguments(
        output_dir=config.output_dir,
        per_device_train_batch_size=config.batch_size_per_gpu,
        auto_find_batch_size=True,  # guards against OOM
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learn_rate,
        logging_steps=config.logging_steps,
        num_train_epochs=config.epochs,
        optim="adafactor",
        report_to='tensorboard',
        log_level='info',
        save_steps=config.save_steps,
        save_total_limit=3,
        fp16=config.mixed_precision == 'fp16',
        bf16=config.mixed_precision == 'bf16',
        logging_first_step=True,
        warmup_steps=config.warmup_steps,
        seed=config.seed,
        generation_config=generation_config,
    )

    # Step 6: initialise the collator.
    # NOTE(review): get_dataset tokenizes with truncation=False; confirm that
    # max_length here actually bounds the batch length as intended.
    collator = DataCollatorForSeq2Seq(tokenizer, max_length=config.max_seq_len)
    empty_cuda_cache = MyTrainerCallback()

    # Step 7: define the trainer.
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=collator,
        callbacks=[empty_cuda_cache],
    )

    # Step 8: train.
    trainer.train(
        # resume_from_checkpoint=True
    )

    # Step 9: save the training log.
    loss_log = pd.DataFrame(trainer.state.log_history)
    log_dir = './logs'
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(log_dir, exist_ok=True)
    loss_log.to_csv(f"{log_dir}/pre_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")

    # Step 10: save the model.
    trainer.save_model(config.output_dir)
if __name__ == '__main__':
    # Script entry point: run pre-training with a TrainConfig built without arguments.
    pre_train(TrainConfig())