baichuan-vicuna-7b / train_vicuna.py
# Patch LLaMA attention with FlashAttention before the model is instantiated.
from fastchat.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()
import json

import numpy as np
import torch
from accelerate import Accelerator
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW  # transformers.AdamW is deprecated; torch's is equivalent here
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
# Label value skipped by the loss; see the sketch below.
IGNORE_TOKEN_ID = -100
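# Illustrative sketch (not called by this script): positions labeled
# IGNORE_TOKEN_ID contribute nothing to the loss, because
# torch.nn.CrossEntropyLoss uses ignore_index=-100 by default.
def _ignore_index_demo():
    logits = torch.randn(4, 10)
    labels = torch.tensor([1, IGNORE_TOKEN_ID, 3, IGNORE_TOKEN_ID])
    loss = torch.nn.CrossEntropyLoss()(logits, labels)  # averages over 2 positions
    print(loss)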
class MixData(Dataset):
    """Mixes several corpora, resampling each one to a target quota per epoch."""

    def __init__(self, dataset, ratio, tokenizer):
        super().__init__()
        self.dataset = dataset
        self.data_size = [len(c) for c in self.dataset]
        # A non-int ratio entry means "use that corpus at its natural size".
        ratio = [r if isinstance(r, int) else s for r, s in zip(ratio, self.data_size)]
        self.ratio = ratio
        self.tokenizer = tokenizer
        # Per-corpus sample quotas, scaled so the first corpus anchors the ratio.
        self.sample_size = [int(self.data_size[0] / self.ratio[0] * r) for r in self.ratio]
        print(self.data_size, self.sample_size,
              [c1 / c2 for c1, c2 in zip(self.sample_size, self.data_size)])
    @staticmethod
    def rounder(number):
        # Stochastic rounding: e.g. rounder(2.3) returns 3 about 30% of the
        # time and 2 otherwise, so expectations match the real-valued index.
        rand = np.random.rand()
        if rand < number - int(number):
            return int(number) + 1
        else:
            return int(number)
    @staticmethod
    def choice_index(number, sample_size):
        # Map a global sample index to (corpus id, offset within that corpus's quota).
        for i in range(len(sample_size)):
            if number < sum(sample_size[:i + 1]):
                return i, number - sum(sample_size[:i])
    def __getitem__(self, index):
        corpus_id, index = self.choice_index(index, self.sample_size)
        # Stretch the quota-local index onto the full corpus with stochastic
        # rounding, so over- and undersampled corpora are covered uniformly.
        rand = np.random.rand()
        index = self.rounder((index + rand) / self.sample_size[corpus_id] * self.data_size[corpus_id])
        index = min(index, len(self.dataset[corpus_id]) - 1)
        return self.dataset[corpus_id][index]
def __len__(self):
return sum(self.sample_size)
    def set_ratio(self, ratio):
        # Recompute per-corpus quotas for a new mixing ratio.
        self.ratio = ratio
        self.data_size = [len(c) for c in self.dataset]
        self.sample_size = [int(self.data_size[0] / self.ratio[0] * r) for r in self.ratio]
        print(self.data_size, self.sample_size,
              [c1 / c2 for c1, c2 in zip(self.sample_size, self.data_size)])
def collate_fn(self, data):
input_ids, labels = zip(*data)
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
features = {
'input_ids': input_ids.long(),
'labels': labels.long(),
'attention_mask': attention_mask.long(),
}
return features
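# Illustrative sketch (hypothetical, not called by this script): how MixData
# resamples two toy corpora. `corpus_a` and `corpus_b` are made-up stand-ins,
# not the training data used below.
def _mixdata_sampling_example(tokenizer):
    corpus_a = [('a', i) for i in range(100)]
    corpus_b = [('b', i) for i in range(1000)]
    # Anchor on corpus_a (quota 100) and cap corpus_b at 200 samples per epoch.
    mixed = MixData([corpus_a, corpus_b], ratio=[100, 200], tokenizer=tokenizer)
    assert len(mixed) == 300  # 100 + 200 samples per "epoch"
    print(mixed[0], mixed[150])  # one item drawn from each corpus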
def last_index(lst, value):
    # Index of the last element that differs from `value`, or -1 if none does.
    return next((len(lst) - i - 1 for i, x in enumerate(lst[::-1]) if x != value), -1)
def safe_ids(ids, max_value, pad_id):
    # Replace ids outside the vocabulary with `pad_id`.
    return [i if i < max_value else pad_id for i in ids]
# Fallback conversation, used when tokenization leaves nothing to supervise.
dummy_message = [
    {"role": "user", "content": "Who are you?"},
    {"role": "assistant",
     "content": "I am vicuna, a language model trained by researchers from open-source community."},
    {"role": "user", "content": "What can you do?"},
    {"role": "assistant", "content": "I can chat with you."},
]
def tokenize(messages, tokenizer):
    roles = {"user": "USER", "assistant": "ASSISTANT"}
    input_ids = []
    labels = []
    system = "A chat between a curious user and an artificial intelligence assistant. " \
             "The assistant gives helpful, detailed, and polite answers to the user's questions."
    system_ids = tokenizer.encode(system, add_special_tokens=False)
    input_ids += system_ids
    labels += [IGNORE_TOKEN_ID] * len(system_ids)
    for turn in messages:
        role = roles.get(turn['role'], 'USER')
        content = turn['content'].strip()
        if role == 'ASSISTANT':
            # Close each assistant turn with the end-of-sequence token.
            content += '</s>'
        role_ids = tokenizer.encode(role + ":", add_special_tokens=False)
        content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
                                       max_length=tokenizer.model_max_length)
        input_ids += role_ids + content_ids
        if role == 'ASSISTANT':
            # Supervise only the assistant's reply, not the role prefix.
            labels += [IGNORE_TOKEN_ID] * len(role_ids) + content_ids
        else:
            labels += [IGNORE_TOKEN_ID] * (len(role_ids) + len(content_ids))
    if tokenizer.add_bos_token:
        input_ids = [tokenizer.bos_token_id] + input_ids
        labels = [IGNORE_TOKEN_ID] + labels

    input_ids = input_ids[:tokenizer.model_max_length]
    labels = labels[:tokenizer.model_max_length]

    # Drop any trailing span that carries no supervised tokens.
    trunc_id = last_index(labels, IGNORE_TOKEN_ID) + 1
    input_ids = input_ids[:trunc_id]
    labels = labels[:trunc_id]
    if len(labels) == 0:
        # Nothing left to learn from; substitute the fallback conversation.
        return tokenize(dummy_message, tokenizer)
    input_ids = safe_ids(input_ids, tokenizer.vocab_size, tokenizer.pad_token_id)
    labels = safe_ids(labels, tokenizer.vocab_size, IGNORE_TOKEN_ID)
    return input_ids, labels
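# Illustrative sketch (not called by this script): inspect what `tokenize`
# produces for the fallback conversation. Only assistant tokens keep real
# labels; everything else is IGNORE_TOKEN_ID and excluded from the loss.
def _tokenize_example(tokenizer):
    input_ids, labels = tokenize(dummy_message, tokenizer)
    supervised = sum(l != IGNORE_TOKEN_ID for l in labels)
    print(f'{len(input_ids)} tokens total, {supervised} supervised (assistant) tokens')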
class VicunaData(Dataset):
    """Wraps a list of conversations, tokenizing each one on access."""

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        input_ids, labels = tokenize(self.data[item], self.tokenizer)
        return torch.tensor(input_ids), torch.tensor(labels)
def collate_fn(self, data):
input_ids, labels = zip(*data)
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
features = {
'input_ids': input_ids.long(),
'labels': labels.long(),
'attention_mask': attention_mask.long(),
}
return features
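# Illustrative sketch (hypothetical usage): collate two examples and check the
# padded batch. All three tensors share the shape (batch, max_seq_len).
def _batch_example(dataset):
    batch = dataset.collate_fn([dataset[0], dataset[1]])
    print(batch['input_ids'].shape, batch['labels'].shape,
          batch['attention_mask'].shape)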
def main():
    accelerator = Accelerator(gradient_accumulation_steps=4)
    batch_size = 4

    save_path = 'out/baichuan-vicuna-7b'
    model_name = 'fireballoon/baichuan-llama-7b'

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False,
                                              padding_side="right", model_max_length=4096)
    # The LLaMA tokenizer ships without a pad token; reuse unk for padding.
    tokenizer.pad_token = tokenizer.unk_token
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # The KV cache is useless during training and incompatible with
    # gradient checkpointing.
    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    dataset = VicunaData(
        json.load(open('data/new/share_gpt-90k.json')) +
        json.load(open('data/new/cot-75k.json')) +
        json.load(open('data/new/leet-9k.json')), tokenizer)
    print(len(dataset))

    data_loader = DataLoader(dataset, collate_fn=dataset.collate_fn,
                             batch_size=batch_size, num_workers=0, shuffle=True)
    optimizer = AdamW(model.parameters(), lr=2e-5)
    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
    for epoch in range(10):
        accelerator.print(f'Training {save_path} {epoch}')
        accelerator.wait_for_everyone()
        model.train()
        tk0 = tqdm(data_loader, total=len(data_loader))
        loss_report = []
        for batch in tk0:
            with accelerator.accumulate(model):
                out = model(**batch)
                loss = out.loss
                accelerator.backward(loss)
                # Clip only on steps where gradients are actually synchronized.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            loss_report.append(accelerator.gather(loss).mean().item())
            tk0.set_postfix(loss=sum(loss_report[-100:]) / len(loss_report[-100:]))
        accelerator.wait_for_everyone()
        # save_checkpoint is a DeepSpeed engine method; this assumes the script
        # is launched with an Accelerate + DeepSpeed configuration.
        model.save_checkpoint(f'{save_path}/{epoch}')
if __name__ == '__main__':
main()
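# Assumed launch recipe (not verified against the original repo): configure
# Accelerate with a DeepSpeed config (the save_checkpoint call above requires
# the DeepSpeed engine), then run:
#   accelerate launch train_vicuna.py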