# coding=utf-8
from typing import Dict
import time 
import os 
import pandas as pd 
import numpy as np
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import PreTrainedTokenizerFast, Seq2SeqTrainer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from transformers.generation.configuration_utils import GenerationConfig

from model.chat_model import TextToTextModel
from config import SFTconfig, T5ModelConfig
from utils.functions import get_T5_config, MyTrainerCallback

tqdm.pandas()
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
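
# The SFT data file is expected to expose 'prompt' and 'response' fields, e.g. an
# illustrative record (the exact on-disk layout is an assumption; anything that
# load_dataset('json', ...) accepts will work):
#
#   {"prompt": "What is the capital of France?", "response": "The capital of France is Paris."}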

def get_dataset(file: str, split: str, tokenizer: PreTrainedTokenizerFast, cache_dir: str='.cache') -> Dataset:
    """
    加载数据集
    """

    # 加载json数据集,如果要加载parquet,更改为'parquet'即可
    dataset = load_dataset('json', data_files=file,  split=split, cache_dir=cache_dir)

    def tokens_to_ids(samples: dict) -> Dict[str, list]:

        eos_token_id = tokenizer.eos_token_id

        batch_prompt = samples['prompt']
        batch_response = samples['response']

        encoded_prompt = tokenizer(batch_prompt, truncation=False, padding=False, return_attention_mask=False)
        encoded_response = tokenizer(batch_response, truncation=False, padding=False, return_attention_mask=False)

        # vocab size is below 65535, so uint16 is enough; every sample gets eos_token_id appended
        input_ids = [np.array(item + [eos_token_id], dtype=np.uint16) for item in encoded_prompt["input_ids"]]
        labels = [np.array(item + [eos_token_id], dtype=np.uint16) for item in encoded_response["input_ids"]]

        return {
            'input_ids': input_ids,
            'labels': labels,
        }

    dataset = dataset.map(tokens_to_ids, batched=True, batch_size=8192, remove_columns=dataset.column_names)
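    # remove_columns drops the raw 'prompt'/'response' text, so the mapped dataset only carries 'input_ids' and 'labels'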

    return dataset

def sft_train(config: SFTconfig) -> None:

    # Step 1: load the tokenizer
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)
    
    # Step 2: load the pretrained model
    model = None
    if os.path.isdir(config.finetune_from_ckp_file):
        # a directory means a saved checkpoint: load it via from_pretrained
        model = TextToTextModel.from_pretrained(config.finetune_from_ckp_file)
    else:
        # otherwise treat it as a raw state_dict file and load the weights manually
        t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
        model = TextToTextModel(t5_config)
        model.load_state_dict(torch.load(config.finetune_from_ckp_file, map_location='cpu'))  # load on CPU to avoid device-mismatch errors

    # Step 3: load the dataset
    dataset = get_dataset(file=config.sft_train_file, split="train", tokenizer=tokenizer)

    # Step 4: define the training arguments
    # T5 is a sequence-to-sequence model, so use Seq2SeqTrainingArguments, DataCollatorForSeq2Seq and Seq2SeqTrainer;
    # the SFT tooling on the Hugging Face site is aimed at (causal) language models instead
    generation_config = GenerationConfig()
    generation_config.remove_invalid_values = True
    generation_config.eos_token_id = tokenizer.eos_token_id
    generation_config.pad_token_id = tokenizer.pad_token_id
    generation_config.decoder_start_token_id = tokenizer.pad_token_id
    generation_config.max_new_tokens = 320
    generation_config.repetition_penalty = 1.5
    generation_config.num_beams = 1         # greedy search
    generation_config.do_sample = False     # greedy search
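    # For T5 the decoder starts from the pad token, hence decoder_start_token_id = pad_token_id.
    # This config is attached to Seq2SeqTrainingArguments below so the trainer uses it whenever it generates.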

    training_args = Seq2SeqTrainingArguments(
        output_dir=config.output_dir,
        per_device_train_batch_size=config.batch_size,
        auto_find_batch_size=True,  # automatically back off the batch size to avoid OOM
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        logging_steps=config.logging_steps,
        num_train_epochs=config.num_train_epochs,
        optim="adafactor",
        report_to='tensorboard',
        log_level='info',
        save_steps=config.save_steps,
        save_total_limit=3,
        fp16=config.fp16,
        logging_first_step=config.logging_first_step,
        warmup_steps=config.warmup_steps,
        seed=config.seed,
        generation_config=generation_config,
    )
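    # Note: 'adafactor' keeps optimizer state small; it is the optimizer T5 was originally trained with.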

    # Step 5: initialize the data collator and callbacks
    collator = DataCollatorForSeq2Seq(tokenizer, max_length=config.max_seq_len)
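    # DataCollatorForSeq2Seq pads input_ids and labels dynamically per batch; by default padded
    # label positions are filled with -100 so the loss ignores them.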
    empty_cuda_cache = MyTrainerCallback()  # custom callback; judging by its name it frees cached CUDA memory during training

    # Step 6: define the trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=collator,
        callbacks=[empty_cuda_cache],
    )

    # Step 7: train
    trainer.train(
        # resume_from_checkpoint=True
    )

    # dump the training loss history to a timestamped CSV for later inspection
    loss_log = pd.DataFrame(trainer.state.log_history)
    log_dir = './logs'
    os.makedirs(log_dir, exist_ok=True)
    loss_log.to_csv(f"{log_dir}/sft_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")

    # Step 8: save the model
    trainer.save_model(config.output_dir)

if __name__ == '__main__':
    config = SFTconfig()
    sft_train(config)
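
# A minimal inference sketch after training (assumptions: TextToTextModel exposes the usual
# Hugging Face from_pretrained/generate interface, trainer.save_model() wrote the model and
# tokenizer to config.output_dir, and the prompt below is purely illustrative):
#
#   tokenizer = PreTrainedTokenizerFast.from_pretrained(config.output_dir)
#   model = TextToTextModel.from_pretrained(config.output_dir)
#   input_ids = tokenizer('How do I brew green tea?', return_tensors='pt').input_ids
#   output_ids = model.generate(input_ids, eos_token_id=tokenizer.eos_token_id,
#                               pad_token_id=tokenizer.pad_token_id, max_new_tokens=320)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))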