Text Generation · Transformers · PyTorch · Chinese · English · llama · text-generation-inference
fireballoon committed
Commit d3b83cf
1 Parent(s): fe26a60

Upload train_vicuna.py

Files changed (1):
  1. train_vicuna.py +28 -30
train_vicuna.py CHANGED
@@ -7,7 +7,7 @@ replace_llama_attn_with_flash_attn()
 import json
 from torch.utils.data import Dataset
 from accelerate import Accelerator
-from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW
+from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, AdamW
 import torch
 from torch.nn.utils.rnn import pad_sequence
 from tqdm import tqdm
@@ -79,6 +79,12 @@ def safe_ids(ids, max_value, pad_id):
     return [i if i < max_value else pad_id for i in ids]
 
 
+dummy_message = [{"role": "user", "content": "Who are you?"},
+                 {"role": "assistant", "content": "I am vicuna, a language model trained by researchers from open-source community."},
+                 {"role": "user", "content": "What can you do?"},
+                 {"role": "assistant", "content": "I can chat with you."}]
+
+
 def tokenize(messages, tokenizer):
     roles = {"user": "USER", "assistant": "ASSISTANT"}
     input_ids = []
@@ -95,23 +101,28 @@ def tokenize(messages, tokenizer):
         if role == 'ASSISTANT':
             content += '</s>'
         role_ids = tokenizer.encode(role + ":", add_special_tokens=False)
-        content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True, max_length=2048)
+        content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
+                                       max_length=tokenizer.model_max_length)
         input_ids += role_ids + content_ids
         if role == 'ASSISTANT':
             labels += [IGNORE_TOKEN_ID] * len(role_ids) + content_ids
         else:
             labels += [IGNORE_TOKEN_ID] * (len(role_ids) + len(content_ids))
 
-    input_ids = input_ids[:4096]
-    labels = labels[:4096]
+    if tokenizer.add_bos_token:
+        input_ids = [tokenizer.bos_token_id] + input_ids
+        labels = [IGNORE_TOKEN_ID] + labels
+
+    input_ids = input_ids[:tokenizer.model_max_length]
+    labels = labels[:tokenizer.model_max_length]
 
-    trunc_id = last_index(labels, -100) + 1
+    trunc_id = last_index(labels, IGNORE_TOKEN_ID) + 1
     input_ids = input_ids[:trunc_id]
     labels = labels[:trunc_id]
     if len(labels) == 0:
-        input_ids, labels = [0, 0], [-100, -100]
-    input_ids = safe_ids(input_ids, 64000, 0)
-    labels = safe_ids(labels, 64000, -100)
+        return tokenize(dummy_message, tokenizer)
+    input_ids = safe_ids(input_ids, tokenizer.vocab_size, tokenizer.pad_token_id)
+    labels = safe_ids(labels, tokenizer.vocab_size, IGNORE_TOKEN_ID)
     return input_ids, labels
 
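Note on this hunk: tokenize() supervises only assistant replies; every other
position gets IGNORE_TOKEN_ID so it drops out of the loss. A minimal sketch of
that labeling scheme, with toy_encode standing in for tokenizer.encode (the
helper and the sample messages are illustrative, not part of the commit):

    IGNORE_TOKEN_ID = -100  # positions with this label are excluded from the loss

    def toy_encode(text):
        # stand-in for tokenizer.encode(..., add_special_tokens=False)
        return [ord(c) for c in text]

    def build_labels(messages):
        input_ids, labels = [], []
        for m in messages:
            role = {"user": "USER", "assistant": "ASSISTANT"}[m["role"]]
            role_ids = toy_encode(role + ":")
            content_ids = toy_encode(m["content"])
            input_ids += role_ids + content_ids
            if role == "ASSISTANT":
                # learn the assistant's content, mask its "ASSISTANT:" prefix
                labels += [IGNORE_TOKEN_ID] * len(role_ids) + content_ids
            else:
                # user turns contribute no loss at all
                labels += [IGNORE_TOKEN_ID] * (len(role_ids) + len(content_ids))
        return input_ids, labels

    ids, labs = build_labels([{"role": "user", "content": "hi"},
                              {"role": "assistant", "content": "hello"}])
    assert len(ids) == len(labs)

The empty-conversation fallback now re-tokenizes dummy_message instead of
emitting a hand-built [0, 0] / [-100, -100] pair, so the fallback always
matches the real chat template.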
@@ -131,7 +142,7 @@ class VicunaData(Dataset):
     def collate_fn(self, data):
         input_ids, labels = zip(*data)
         input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
-        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
+        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
         attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
         features = {
             'input_ids': input_ids.long(),
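For reference, a runnable sketch of the padding behavior this hunk standardizes
on IGNORE_TOKEN_ID, with made-up ids (pad_token_id = 0 is an assumption; the
script actually reuses the tokenizer's unk token as pad):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    IGNORE_TOKEN_ID = -100
    pad_token_id = 0  # illustrative; the script sets pad_token = unk_token

    input_ids = pad_sequence([torch.tensor([5, 6, 7]), torch.tensor([8])],
                             batch_first=True, padding_value=pad_token_id)
    labels = pad_sequence([torch.tensor([IGNORE_TOKEN_ID, 6, 7]), torch.tensor([8])],
                          batch_first=True, padding_value=IGNORE_TOKEN_ID)
    attention_mask = input_ids.ne(pad_token_id)
    # input_ids      -> [[5, 6, 7], [8, 0, 0]]
    # labels         -> [[-100, 6, 7], [8, -100, -100]]
    # attention_mask -> [[True, True, True], [True, False, False]]

Padding labels with the ignore id means padded positions are excluded from the
loss without touching the attention mask.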
@@ -142,11 +153,11 @@ class VicunaData(Dataset):
 
 
 def main():
-    accelerator = Accelerator(gradient_accumulation_steps=8)
+    accelerator = Accelerator(gradient_accumulation_steps=4)
     batch_size = 4
 
     save_path = 'out/baichuan-vicuna-7b'
-    model_name = './models/baichuan-llama-7b'
+    model_name = 'fireballoon/baichuan-llama-7b'
 
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side="right", model_max_length=4096)
     tokenizer.pad_token = tokenizer.unk_token
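Halving gradient_accumulation_steps from 8 to 4 halves the effective global
batch. A quick back-of-the-envelope, assuming one process per GPU and an 8-GPU
run (the GPU count is not part of the commit):

    per_device_batch = 4    # batch_size above
    accumulation_steps = 4  # was 8 before this commit
    num_processes = 8       # assumption: one process per GPU, 8 GPUs
    effective_batch = per_device_batch * accumulation_steps * num_processes
    # 4 * 4 * 8 = 128 sequences per optimizer step (was 256)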
@@ -155,12 +166,10 @@ def main():
     model.config.use_cache = False
     model.gradient_checkpointing_enable()
 
-    share_gpt = VicunaData(json.load(open('data/new/share_gpt-90k.json')), tokenizer)
-    instruction = VicunaData(json.load(open('data/new/cot-75k.json')), tokenizer)
-    code = VicunaData(json.load(open('data/new/leet-9k.json')), tokenizer)
-
-    dataset = MixData([share_gpt, instruction, code],
-                      [len(share_gpt), len(instruction), len(code)], tokenizer)
+    dataset = VicunaData(
+        json.load(open('data/new/share_gpt-90k.json')) +
+        json.load(open('data/new/cot-75k.json')) +
+        json.load(open('data/new/leet-9k.json')), tokenizer)
 
     print(len(dataset))
 
@@ -178,17 +187,8 @@ def main():
         loss_report = []
         for batch in tk0:
             with accelerator.accumulate(model):
-                try:
-                    out = model(**batch)
-                    loss = out.loss
-
-                except:
-                    loss = torch.tensor(0., device=model.device, requires_grad=True)
-
-                if loss.isnan():
-                    print(loss)
-                    print(batch)
-                    loss = torch.tensor(0., device=model.device, requires_grad=True)
+                out = model(**batch)
+                loss = out.loss
 
                 accelerator.backward(loss)
                 accelerator.clip_grad_norm_(model.parameters(), 1.)
@@ -197,11 +197,9 @@ def main():
 
             loss_report.append(accelerator.gather(loss).mean().item())
             tk0.set_postfix(loss=sum(loss_report[-100:]) / len(loss_report[-100:]))
-
         accelerator.wait_for_everyone()
         model.save_checkpoint(f'{save_path}/{epoch}')
 
 
 if __name__ == '__main__':
     main()
-
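A quick check of the safe_ids helper shown in context above, which this commit
parameterizes with tokenizer.vocab_size instead of a hard-coded 64000 (the ids
below are made up):

    def safe_ids(ids, max_value, pad_id):
        # as defined in the file: clamp out-of-range ids to pad_id
        return [i if i < max_value else pad_id for i in ids]

    print(safe_ids([5, 70001, 7], 64000, 0))   # -> [5, 0, 7]
    print(safe_ids([5, 70001, 7], 128000, 0))  # -> [5, 70001, 7]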
 