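"""Fine-tune a T5-style seq2seq model on conversation data with the HuggingFace Trainer.

Conversations are wrapped with "### <role>:" speaker signals, tokenized, split
into (question, answer) pairs, and preprocessed once on rank 0 before training.
"""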
import copy
import json
import logging
import os
import pathlib
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import torch
import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AddedToken, Trainer

from fastchat.model.model_adapter import get_conversation_template

default_conversation = get_conversation_template("t5")
|
|
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
|
|
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
|
|
@dataclass
class DataArguments:
    data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    lazy_preprocess: bool = False
    num_data: int = -1
    preprocessed_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the preprocessed training data."}
    )
|
|
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
|
|
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collect the state dict and dump it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)
|
|
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    other_tokens: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    for new_token in other_tokens:
        num_new_tokens += tokenizer.add_tokens(AddedToken(new_token, normalized=False))

    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Initialize the new rows with the mean of the existing embeddings.
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True
        )
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True
        )

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
|
def _tokenize_fn(
    strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )
|
|
def _form_qa(
    q_list,
    a_list,
    tokenized_conversation,
    tokenized_lens,
    speakers,
    header_len,
    max_length,
    eos_id,
):
    """Split a tokenized conversation into (question, answer) id lists.

    For every "gpt" turn, the answer is that turn's tokens and the question is
    the conversation context preceding it; both are clipped to max_length and
    terminated with EOS.
    """
    cur_idx = header_len
    conv_len = len(tokenized_conversation)

    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if cur_idx >= conv_len:
            break
        if speaker == "gpt":
            # The answer is the assistant's turn, truncated to max_length.
            if tokenized_len > max_length:
                content_a = tokenized_conversation[cur_idx : cur_idx + max_length]
            else:
                content_a = tokenized_conversation[cur_idx : cur_idx + tokenized_len]
            content_a.append(eos_id)
            a_list.append(content_a)

            # The question is the preceding context, clipped to the last
            # max_length tokens before the answer.
            if cur_idx >= max_length:
                content_q = tokenized_conversation[cur_idx - max_length : cur_idx]
            else:
                content_q = tokenized_conversation[:cur_idx]
            content_q.append(eos_id)
            q_list.append(content_q)

            assert a_list[-1][-1] == eos_id, "Last token is not EOS!"
        cur_idx += tokenized_len
|
|
def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round."""
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header

    unknown_role = "unknown"
    roles = {
        "human": default_conversation.roles[0],
        "gpt": default_conversation.roles[1],
    }

    for i in range(len(source)):
        sentence = source[i]
        sentence_from = sentence["from"].lower()

        if sentence_from == "human":
            # If this is not the last sentence, also prepend the role of the
            # next turn so the model sees where the answer begins.
            if i != len(source) - 1:
                next_sentence = source[i + 1]
                sentence["value"] = (
                    BEGIN_SIGNAL
                    + roles.get(sentence_from, unknown_role)
                    + ": "
                    + sentence["value"]
                    + END_SIGNAL
                    + BEGIN_SIGNAL
                    + roles.get(next_sentence["from"].lower(), unknown_role)
                    + ": "
                )
            else:
                # A trailing human turn has no answer, so leave it unformatted.
                pass
        else:
            sentence["value"] = sentence["value"] + END_SIGNAL
        if get_conversation:
            conversation += sentence["value"]

    return conversation
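# For illustration, assuming the template's roles are ("Human", "Assistant"),
# a one-round conversation is rendered as:
#
#   <header>### Human: <question>
#   ### Assistant: <answer>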
|
|
def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """
    Given a list of sources, each of which is a conversation list, this transform:
    1. Adds the signal '### ' at the beginning of each sentence, with the end signal '\n';
    2. Concatenates the conversations together;
    3. Tokenizes the concatenated conversation;
    4. Splits each conversation into (question, answer) id lists for seq2seq training.
    """
    conversations = []
    header = f"{default_conversation.system_message}\n\n"
    for source in sources:
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)

    tokenized_conversations = tokenizer(conversations, max_length=None)["input_ids"]
    q_list = []
    a_list = []

    # Subtract one to drop the EOS token the tokenizer appends to the header.
    header_len = _tokenize_fn([header], tokenizer)["input_ids_lens"][0] - 1

    for tokenized_conversation, source in tqdm(zip(tokenized_conversations, sources)):
        tokenized_sentence = _tokenize_fn([s["value"] for s in source], tokenizer)
        tokenized_lens = tokenized_sentence["input_ids_lens"]
        # Strip the trailing EOS token from each per-sentence length.
        tokenized_lens = [l - 1 for l in tokenized_lens]
        speakers = [sentence["from"] for sentence in source]
        _form_qa(
            q_list,
            a_list,
            tokenized_conversation,
            tokenized_lens,
            speakers,
            header_len,
            tokenizer.model_max_length,
            tokenizer.eos_token_id,
        )
    return dict(input_ids=q_list, labels=a_list)
|
|
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        preprocessed_path,
        num_data,
    ):
        super(SupervisedDataset, self).__init__()

        # Make sure only rank 0 preprocesses the dataset; the other ranks block
        # here until rank 0 has loaded or written the preprocessed file.
        if dist.get_rank() != 0:
            dist.barrier()
        self.preprocessed_path = preprocessed_path
        if os.path.exists(self.preprocessed_path):
            logging.warning("loading from preprocessed data")
            with open(self.preprocessed_path, "r") as f:
                data_dict = json.load(f)
            if dist.get_rank() == 0:
                dist.barrier()
        else:
            if not os.path.exists("preprocessed_data"):
                os.mkdir("preprocessed_data")
            assert dist.get_rank() == 0, "Only the first process should process"
            logging.warning("Loading data...")
            with open(data_path, "r") as f:
                list_data_dict = json.load(f)

            logging.warning("Formatting inputs...")
            sources = [example["conversations"] for example in list_data_dict]

            data_dict = preprocess(sources, tokenizer)
            json_data_dict = json.dumps(data_dict)

            # Close the file before releasing the barrier to avoid concurrent r/w.
            with open(self.preprocessed_path, "w") as f:
                f.write(json_data_dict)

            # Release the other ranks, which then load the file from disk.
            dist.barrier()

        if num_data != -1:
            data_dict["input_ids"] = data_dict["input_ids"][:num_data]
            data_dict["labels"] = data_dict["labels"][:num_data]

        # Shuffle with a fixed seed so every rank sees the same ordering.
        temp = list(zip(data_dict["input_ids"], data_dict["labels"]))
        random.Random(42).shuffle(temp)
        res1, res2 = zip(*temp)
        data_dict["input_ids"], data_dict["labels"] = list(res1), list(res2)

        # Collect a coarse length histogram and drop degenerate examples whose
        # question or answer is at most 5 tokens long.
        length_arr = defaultdict(int)
        for input, label in zip(data_dict["input_ids"], data_dict["labels"]):
            length_arr[str(len(label) // 100)] += 1
        kept = [
            (input, label)
            for input, label in zip(data_dict["input_ids"], data_dict["labels"])
            if len(input) > 5 and len(label) > 5
        ]
        self.input_ids = copy.deepcopy([input for input, _ in kept])
        self.labels = copy.deepcopy([label for _, label in kept])

        for input, label in zip(self.input_ids, self.labels):
            assert len(input) > 5
            assert len(label) > 5

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
|
|
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple(
            [
                torch.as_tensor(instance[key], dtype=torch.int64)
                for instance in instances
            ]
            for key in ("input_ids", "labels")
        )
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        # Pad labels with IGNORE_INDEX so padding positions are excluded from the loss.
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
|
|
def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    dataset_cls = SupervisedDataset
    train_dataset = dataset_cls(
        tokenizer=tokenizer,
        data_path=data_args.data_path,
        preprocessed_path=data_args.preprocessed_path,
        num_data=data_args.num_data,
    )
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(
        train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
    )
|
|
def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    # Use the slow T5Tokenizer so added tokens are not given an extra space.
    tokenizer = transformers.T5Tokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    # Add a pad token plus characters (braces, whitespace, backslash, etc.)
    # that the stock T5 vocabulary does not cover.
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
        other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"],
        tokenizer=tokenizer,
        model=model,
    )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
|
|
if __name__ == "__main__":
    train()
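# Illustrative launch command (script name and argument values are
# placeholders; the flags map to the dataclasses above, and the use of
# dist.get_rank() assumes a distributed launcher such as torchrun):
#
#   torchrun --nproc_per_node=4 train_flant5.py \
#       --model_name_or_path google/flan-t5-xl \
#       --data_path data.json \
#       --preprocessed_path preprocessed_data/data.json \
#       --output_dir ./output \
#       --model_max_length 2048 \
#       --per_device_train_batch_size 1 \
#       --num_train_epochs 3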