fireballoon committed
Commit d3b83cf · Parent(s): fe26a60
Upload train_vicuna.py

train_vicuna.py CHANGED (+28 -30)
@@ -7,7 +7,7 @@ replace_llama_attn_with_flash_attn()
 import json
 from torch.utils.data import Dataset
 from accelerate import Accelerator
-from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW
+from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, AdamW
 import torch
 from torch.nn.utils.rnn import pad_sequence
 from tqdm import tqdm
@@ -79,6 +79,12 @@ def safe_ids(ids, max_value, pad_id):
     return [i if i < max_value else pad_id for i in ids]
 
 
+dummy_message = [{"role": "user", "content": "Who are you?"},
+                 {"role": "assistant", "content": "I am vicuna, a language model trained by researchers from open-source community."},
+                 {"role": "user", "content": "What can you do?"},
+                 {"role": "assistant", "content": "I can chat with you."}]
+
+
 def tokenize(messages, tokenizer):
     roles = {"user": "USER", "assistant": "ASSISTANT"}
     input_ids = []
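Note: the new dummy_message constant holds a four-turn conversation in the same {"role", "content"} schema that tokenize() consumes; the next hunk uses it as a fallback sample whenever a conversation yields no trainable labels. A hedged fragment (not part of the commit) showing how the loop in tokenize() assembles one assistant turn, assuming the tokenizer loaded in main() and the IGNORE_TOKEN_ID constant defined near the top of the file:

    # Hedged fragment: one assistant turn as assembled by the hunk below.
    # `tokenizer` and IGNORE_TOKEN_ID come from elsewhere in this file.
    role_ids = tokenizer.encode("ASSISTANT:", add_special_tokens=False)
    content_ids = tokenizer.encode("I can chat with you.</s>", add_special_tokens=False,
                                   truncation=True, max_length=tokenizer.model_max_length)
    input_ids = role_ids + content_ids                        # model sees the full turn
    labels = [IGNORE_TOKEN_ID] * len(role_ids) + content_ids  # loss only on the reply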
@@ -95,23 +101,28 @@ def tokenize(messages, tokenizer):
         if role == 'ASSISTANT':
             content += '</s>'
         role_ids = tokenizer.encode(role + ":", add_special_tokens=False)
-        content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
+        content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
+                                       max_length=tokenizer.model_max_length)
         input_ids += role_ids + content_ids
         if role == 'ASSISTANT':
             labels += [IGNORE_TOKEN_ID] * len(role_ids) + content_ids
         else:
             labels += [IGNORE_TOKEN_ID] * (len(role_ids) + len(content_ids))
 
-
-
+    if tokenizer.add_bos_token:
+        input_ids = [tokenizer.bos_token_id] + input_ids
+        labels = [IGNORE_TOKEN_ID] + labels
+
+    input_ids = input_ids[:tokenizer.model_max_length]
+    labels = labels[:tokenizer.model_max_length]
 
-    trunc_id = last_index(labels,
+    trunc_id = last_index(labels, IGNORE_TOKEN_ID) + 1
     input_ids = input_ids[:trunc_id]
     labels = labels[:trunc_id]
     if len(labels) == 0:
-
-        input_ids = safe_ids(input_ids,
-        labels = safe_ids(labels,
+        return tokenize(dummy_message, tokenizer)
+    input_ids = safe_ids(input_ids, tokenizer.vocab_size, tokenizer.pad_token_id)
+    labels = safe_ids(labels, tokenizer.vocab_size, IGNORE_TOKEN_ID)
     return input_ids, labels
 
 
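Note: the truncation logic above leans on two helpers. safe_ids() is visible at line 79 (it clamps any id at or above vocab_size to a pad id), while last_index() is defined earlier in the file and not shown in this diff. For trunc_id = last_index(labels, IGNORE_TOKEN_ID) + 1 to make the subsequent len(labels) == 0 check meaningful, last_index must return the position of the last element that is NOT the given value, and -1 when every element matches, so a conversation whose labels are all masked collapses to an empty list and falls back to tokenize(dummy_message, tokenizer). A plausible reconstruction under that assumption:

    # Hypothetical reconstruction (the real definition is outside this diff):
    # index of the last element of lst that differs from value, else -1.
    # Trimming at this index + 1 drops a trailing run of IGNORE_TOKEN_ID
    # labels, i.e. a dangling user turn that contributes nothing to the loss.
    def last_index(lst, value):
        return next((len(lst) - i - 1 for i, x in enumerate(reversed(lst)) if x != value), -1)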
@@ -131,7 +142,7 @@ class VicunaData(Dataset):
     def collate_fn(self, data):
         input_ids, labels = zip(*data)
         input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
-        labels = pad_sequence(labels, batch_first=True, padding_value
+        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
         attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
         features = {
             'input_ids': input_ids.long(),
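Note: the collate change pads the two tensors with different values on purpose. A self-contained toy example (assumed ids, not a real vocabulary) showing why: inputs pad with pad_token_id so the attention mask can be recovered with .ne(), while labels pad with IGNORE_TOKEN_ID, the index that torch's cross-entropy loss skips, so padding never contributes to the loss.

    import torch
    from torch.nn.utils.rnn import pad_sequence

    IGNORE_TOKEN_ID = -100  # assumed -100, the usual transformers ignore index
    pad_token_id = 0        # toy value for illustration

    seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
    input_ids = pad_sequence(seqs, batch_first=True, padding_value=pad_token_id)
    labels = pad_sequence(seqs, batch_first=True, padding_value=IGNORE_TOKEN_ID)
    attention_mask = input_ids.ne(pad_token_id)

    print(input_ids)       # tensor([[5, 6, 7], [8, 9, 0]])
    print(labels)          # tensor([[5, 6, 7], [8, 9, -100]])
    print(attention_mask)  # tensor([[ True,  True,  True], [ True,  True, False]])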
@@ -142,11 +153,11 @@ class VicunaData(Dataset):
 
 
 def main():
-    accelerator = Accelerator(gradient_accumulation_steps=
+    accelerator = Accelerator(gradient_accumulation_steps=4)
     batch_size = 4
 
     save_path = 'out/baichuan-vicuna-7b'
-    model_name = '
+    model_name = 'fireballoon/baichuan-llama-7b'
 
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side="right", model_max_length=4096)
     tokenizer.pad_token = tokenizer.unk_token
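Note: with gradient_accumulation_steps=4 and batch_size = 4, each optimizer step aggregates 4 × 4 = 16 sequences per process (multiplied again by the number of processes under accelerate launch), and model_max_length=4096 here is the same limit tokenize() uses for truncation.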
@@ -155,12 +166,10 @@ def main():
     model.config.use_cache = False
     model.gradient_checkpointing_enable()
 
-
-
-
-
-    dataset = MixData([share_gpt, instruction, code],
-                      [len(share_gpt), len(instruction), len(code)], tokenizer)
+    dataset = VicunaData(
+        json.load(open('data/new/share_gpt-90k.json')) +
+        json.load(open('data/new/cot-75k.json')) +
+        json.load(open('data/new/leet-9k.json')), tokenizer)
 
     print(len(dataset))
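Note: this hunk changes the data pipeline semantics, not just names. The removed MixData wrapper took the three corpora together with their per-corpus lengths, which suggests it resampled or reweighted them (its class body is not part of this diff), whereas the new code concatenates the three JSON lists so VicunaData iterates one flat pool of roughly 90k + 75k + 9k = 174k conversations, each weighted equally.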
@@ -178,17 +187,8 @@
         loss_report = []
         for batch in tk0:
             with accelerator.accumulate(model):
-                try:
-                    out = model(**batch)
-                    loss = out.loss
-
-                except:
-                    loss = torch.tensor(0., device=model.device, requires_grad=True)
-
-                if loss.isnan():
-                    print(loss)
-                    print(batch)
-                    loss = torch.tensor(0., device=model.device, requires_grad=True)
+                out = model(**batch)
+                loss = out.loss
 
                 accelerator.backward(loss)
                 accelerator.clip_grad_norm_(model.parameters(), 1.)
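Note: the deleted block wrapped the forward pass in a bare except and replaced a failed or NaN loss with torch.tensor(0., device=model.device, requires_grad=True), which silently turned bad batches into zero-gradient steps while keeping every process in sync. After this commit, an exception or NaN loss propagates instead of being swallowed, so data problems surface as hard failures rather than quiet no-op updates.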
@@ -197,11 +197,9 @@
 
                 loss_report.append(accelerator.gather(loss).mean().item())
                 tk0.set_postfix(loss=sum(loss_report[-100:]) / len(loss_report[-100:]))
-
         accelerator.wait_for_everyone()
         model.save_checkpoint(f'{save_path}/{epoch}')
 
 
 if __name__ == '__main__':
     main()
-
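Note: model.save_checkpoint(f'{save_path}/{epoch}') is not a plain nn.Module method; it matches DeepSpeed's engine API, which suggests the Accelerator here is configured with a DeepSpeed plugin so accelerator.prepare() returns a DeepSpeed engine. Under vanilla Accelerate one would call accelerator.save_state() or save the unwrapped model instead.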