fireballoon committed
Commit • 7f162f3
Parent(s): 9cbd412
Upload train_vicuna.py
Browse files: train_vicuna.py (+207 −0)

train_vicuna.py (ADDED)
# Apply the FlashAttention monkey patch before any LLaMA model is constructed,
# so the patched attention forward is the one actually picked up.
from fastchat.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()

import json

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW  # replaces the deprecated transformers.AdamW
from torch.utils.data import Dataset
from tqdm import tqdm

from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

# Label value ignored by PyTorch's CrossEntropyLoss (ignore_index=-100).
IGNORE_TOKEN_ID = -100


class MixData(Dataset):
    """Virtually mixes several datasets, drawing from each according to `ratio`."""

    def __init__(self, dataset, ratio, tokenizer):
        super().__init__()
        self.dataset = dataset
        self.data_size = [len(c) for c in self.dataset]
        # Non-int ratio entries fall back to that dataset's natural size.
        ratio = [r if isinstance(r, int) else s for r, s in zip(ratio, self.data_size)]
        self.ratio = ratio
        self.tokenizer = tokenizer
        # Samples drawn per epoch from each corpus, scaled so the first corpus
        # keeps its natural size.
        self.sample_size = [int(self.data_size[0] / self.ratio[0] * r) for r in self.ratio]
        print(self.data_size, self.sample_size,
              [c1 / c2 for c1, c2 in zip(self.sample_size, self.data_size)])

    @staticmethod
    def rounder(number):
        # Stochastic rounding: round up with probability equal to the fraction.
        rand = np.random.rand()
        if rand < number - int(number):
            return int(number) + 1
        else:
            return int(number)

    @staticmethod
    def choice_index(number, sample_size):
        # Map a flat index onto (corpus id, offset within that corpus's quota).
        for i in range(len(sample_size)):
            if number < sum(sample_size[:i + 1]):
                return i, number - sum(sample_size[:i])

    def __getitem__(self, index):
        corpus_id, index = self.choice_index(index, self.sample_size)
        # Rescale the within-quota offset onto the corpus's real size, with
        # random jitter so repeated epochs do not hit identical examples.
        rand = np.random.rand()
        index = self.rounder((index + rand) / self.sample_size[corpus_id] * self.data_size[corpus_id])
        index = min(index, len(self.dataset[corpus_id]) - 1)
        return self.dataset[corpus_id][index]

    def __len__(self):
        return sum(self.sample_size)

    def set_ratio(self, ratio):
        self.ratio = ratio
        self.data_size = [len(c) for c in self.dataset]
        self.sample_size = [int(self.data_size[0] / self.ratio[0] * r) for r in self.ratio]
        print(self.data_size, self.sample_size,
              [c1 / c2 for c1, c2 in zip(self.sample_size, self.data_size)])

    def collate_fn(self, data):
        # Pad variable-length examples into a batch; positions equal to the
        # pad token are zeroed out of the attention mask.
        input_ids, labels = zip(*data)
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        features = {
            'input_ids': input_ids.long(),
            'labels': labels.long(),
            'attention_mask': attention_mask.long(),
        }
        return features

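
# Editor's note: a minimal, hypothetical usage sketch of MixData's weighted
# sampling, not part of the original training flow. With toy corpora of sizes
# 4 and 2 and ratio [4, 4], the second corpus is oversampled roughly 2x per
# epoch. Plain lists stand in for tokenized datasets; indexing never touches
# the tokenizer, so None is fine here. Never called by the script.
def _demo_mixdata_sampling():
    a = ['a0', 'a1', 'a2', 'a3']
    b = ['b0', 'b1']
    mix = MixData([a, b], [4, 4], tokenizer=None)
    print(len(mix))  # 8 virtual samples per epoch: 4 from `a`, 4 drawn from `b`
    print([mix[i] for i in range(len(mix))])
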

def last_index(lst, value):
    """Return the index of the last element that differs from `value`, or -1."""
    return next((len(lst) - i - 1 for i, x in enumerate(lst[::-1]) if x != value), -1)


def safe_ids(ids, max_value, pad_id):
    """Clamp out-of-vocabulary token ids (>= max_value) to `pad_id`."""
    return [i if i < max_value else pad_id for i in ids]

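
# Editor's note: tiny sanity checks for the two helpers above, added for
# illustration only; never called by the script.
def _demo_helpers():
    # Last position whose value differs from -100:
    assert last_index([1, 2, -100, -100], -100) == 1
    # Ids at or above the vocab bound are clamped to the pad id:
    assert safe_ids([5, 70000, 9], 64000, 0) == [5, 0, 9]
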

def tokenize(messages, tokenizer):
    """Encode a Vicuna-style conversation; only assistant tokens are supervised."""
    roles = {"user": "USER", "assistant": "ASSISTANT"}
    input_ids = []
    labels = []
    system = "A chat between a curious user and an artificial intelligence assistant. " \
             "The assistant gives helpful, detailed, and polite answers to the user's questions."
    system_ids = tokenizer.encode(system, add_special_tokens=False)
    input_ids += system_ids
    labels += [IGNORE_TOKEN_ID] * len(system_ids)
    for turn in messages:
        role = roles.get(turn['role'], 'USER')
        content = turn['content'].strip()
        if role == 'ASSISTANT':
            # Close assistant turns with EOS so generation learns to stop.
            content += '</s>'
        role_ids = tokenizer.encode(role + ":", add_special_tokens=False)
        content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True, max_length=2048)
        input_ids += role_ids + content_ids
        if role == 'ASSISTANT':
            # Supervise only the assistant's reply, not the role prefix.
            labels += [IGNORE_TOKEN_ID] * len(role_ids) + content_ids
        else:
            labels += [IGNORE_TOKEN_ID] * (len(role_ids) + len(content_ids))

    input_ids = input_ids[:4096]
    labels = labels[:4096]

    # Drop any trailing run of ignored labels (e.g. a truncated user turn) so
    # the sequence ends on supervised tokens.
    trunc_id = last_index(labels, IGNORE_TOKEN_ID) + 1
    input_ids = input_ids[:trunc_id]
    labels = labels[:trunc_id]
    if len(labels) == 0:
        # Degenerate conversation: emit a dummy, fully-masked example.
        input_ids, labels = [0, 0], [IGNORE_TOKEN_ID, IGNORE_TOKEN_ID]
    # 64000 is the Baichuan vocabulary size; clamp any stray ids.
    input_ids = safe_ids(input_ids, 64000, 0)
    labels = safe_ids(labels, 64000, IGNORE_TOKEN_ID)
    return input_ids, labels

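
# Editor's note: an illustrative sketch of the masking produced by tokenize().
# It requires a real tokenizer (e.g. the Baichuan one loaded in main); system
# and user tokens get IGNORE_TOKEN_ID labels, assistant tokens are supervised.
# Never called by the script.
def _demo_tokenize(tokenizer):
    messages = [
        {'role': 'user', 'content': 'Hello!'},
        {'role': 'assistant', 'content': 'Hi, how can I help?'},
    ]
    input_ids, labels = tokenize(messages, tokenizer)
    # Expect a prefix of -100 labels, then the assistant reply's token ids.
    print(list(zip(input_ids, labels)))
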

class VicunaData(Dataset):
    """Wraps a list of conversations and tokenizes them on access."""

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        messages = self.data[item]
        input_ids, labels = tokenize(messages, self.tokenizer)
        return torch.tensor(input_ids), torch.tensor(labels)

    def collate_fn(self, data):
        # Identical to MixData.collate_fn; kept so the dataset can also be
        # used stand-alone with a DataLoader.
        input_ids, labels = zip(*data)
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        features = {
            'input_ids': input_ids.long(),
            'labels': labels.long(),
            'attention_mask': attention_mask.long(),
        }
        return features

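
# Editor's note: hypothetical sketch of how collate_fn pads a batch, never
# called by the script. Two examples of lengths 2 and 4 are padded to length
# 4; assuming the toy ids differ from pad_token_id, the attention mask zeroes
# exactly the padded positions.
def _demo_collate(dataset):
    short = (torch.tensor([1, 2]), torch.tensor([-100, 2]))
    long = (torch.tensor([3, 4, 5, 6]), torch.tensor([-100, 4, 5, 6]))
    batch = dataset.collate_fn([short, long])
    print(batch['input_ids'].shape)  # torch.Size([2, 4])
    print(batch['attention_mask'])   # padded positions are 0
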

def main():
    # Gradient accumulation of 8 with per-device batch size 4 gives an
    # effective batch of 32 per device.
    accelerator = Accelerator(gradient_accumulation_steps=8)
    batch_size = 4

    save_path = 'out/baichuan-vicuna-7b'
    model_name = './models/baichuan-llama-7b'

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side="right",
                                              model_max_length=4096)
    # Baichuan defines no pad token; reuse UNK for padding.
    tokenizer.pad_token = tokenizer.unk_token

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.config.use_cache = False  # incompatible with gradient checkpointing
    model.gradient_checkpointing_enable()

    share_gpt = VicunaData(json.load(open('data/new/share_gpt-90k.json')), tokenizer)
    instruction = VicunaData(json.load(open('data/new/cot-75k.json')), tokenizer)
    code = VicunaData(json.load(open('data/new/leet-9k.json')), tokenizer)

    # Passing each dataset's own length as its ratio samples all three corpora
    # at their natural proportions.
    dataset = MixData([share_gpt, instruction, code],
                      [len(share_gpt), len(instruction), len(code)], tokenizer)

    print(len(dataset))

    data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn,
                                              batch_size=batch_size, num_workers=0, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=2e-5)
    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

    for epoch in range(10):
        accelerator.print(f'Training {save_path} {epoch}')
        accelerator.wait_for_everyone()
        model.train()
        tk0 = tqdm(data_loader, total=len(data_loader))
        loss_report = []
        for batch in tk0:
            with accelerator.accumulate(model):
                try:
                    out = model(**batch)
                    loss = out.loss
                except Exception:
                    # Skip the step (e.g. on a transient failure) with a zero
                    # loss that still participates in the backward pass.
                    loss = torch.tensor(0., device=model.device, requires_grad=True)

                if loss.isnan():
                    print(loss)
                    print(batch)
                    loss = torch.tensor(0., device=model.device, requires_grad=True)

                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.)
                optimizer.step()
                optimizer.zero_grad()

                # Report a running average over the last 100 steps.
                loss_report.append(accelerator.gather(loss).mean().item())
                tk0.set_postfix(loss=sum(loss_report[-100:]) / len(loss_report[-100:]))
        accelerator.wait_for_everyone()
        # save_checkpoint is provided by the DeepSpeed engine, so this assumes
        # accelerate is configured with the DeepSpeed plugin.
        model.save_checkpoint(f'{save_path}/{epoch}')


if __name__ == '__main__':
    main()
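
# Editor's note: launch sketch, not part of the original file. The script is
# written for Hugging Face Accelerate, and model.save_checkpoint above implies
# the DeepSpeed plugin; a typical invocation (after `accelerate config`) is:
#
#   accelerate launch train_vicuna.py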