Tags: Text Generation · Transformers · PyTorch · Chinese · English · llama · text-generation-inference
fireballoon committed on
Commit 7f162f3
1 Parent(s): 9cbd412

Upload train_vicuna.py

Files changed (1):
  1. train_vicuna.py +207 -0
train_vicuna.py ADDED
@@ -0,0 +1,207 @@
from fastchat.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

# Patch LLaMA attention with FlashAttention; this must run before the model is created.
replace_llama_attn_with_flash_attn()

import json
from torch.utils.data import Dataset
from accelerate import Accelerator
# Note: AdamW here is transformers' implementation (deprecated in newer releases).
from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW
import torch
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import numpy as np


# Label value ignored by the cross-entropy loss.
IGNORE_TOKEN_ID = -100

class MixData(Dataset):
    """Mix several datasets, resampling each corpus according to `ratio`."""

    def __init__(self, dataset, ratio, tokenizer):
        super().__init__()
        self.dataset = dataset
        self.data_size = [len(c) for c in self.dataset]
        # Non-integer ratio entries fall back to the corpus's own size.
        ratio = [r if isinstance(r, int) else s for r, s in zip(ratio, self.data_size)]
        self.ratio = ratio
        self.tokenizer = tokenizer
        self.sample_size = [int(self.data_size[0] / self.ratio[0] * r) for r in self.ratio]
        print(self.data_size, self.sample_size, [c1 / c2 for c1, c2 in zip(self.sample_size, self.data_size)])

    @staticmethod
    def rounder(number):
        # Stochastic rounding: round up with probability equal to the fractional part.
        rand = np.random.rand()
        if rand < number - int(number):
            return int(number) + 1
        else:
            return int(number)

    @staticmethod
    def choice_index(number, sample_size):
        # Map a global index to (corpus id, offset inside that corpus's sample range).
        for i in range(len(sample_size)):
            if number < sum(sample_size[:i + 1]):
                return i, number - sum(sample_size[:i])

    def __getitem__(self, index):
        corpus_id, index = self.choice_index(index, self.sample_size)
        rand = np.random.rand()
        # Rescale the offset from the sample range onto the corpus's true size.
        index = self.rounder((index + rand) / self.sample_size[corpus_id] * self.data_size[corpus_id])
        index = min(index, len(self.dataset[corpus_id]) - 1)
        return self.dataset[corpus_id][index]

    def __len__(self):
        return sum(self.sample_size)

    def set_ratio(self, ratio):
        self.ratio = ratio
        self.data_size = [len(c) for c in self.dataset]
        self.sample_size = [int(self.data_size[0] / self.ratio[0] * r) for r in self.ratio]
        print(self.data_size, self.sample_size, [c1 / c2 for c1, c2 in zip(self.sample_size, self.data_size)])

    def collate_fn(self, data):
        # Right-pad input_ids with the pad token and labels with IGNORE_TOKEN_ID.
        input_ids, labels = zip(*data)
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        features = {
            'input_ids': input_ids.long(),
            'labels': labels.long(),
            'attention_mask': attention_mask.long(),
        }
        return features

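# Illustrative only (hypothetical sizes): with corpora of sizes
# [90000, 75000, 9000] and ratio=[1, 1, 1], sample_size becomes
# [90000, 90000, 90000], i.e. the smaller corpora are oversampled to match
# the first. With ratio equal to the corpus sizes, as main() passes below,
# sample_size == data_size and each corpus is traversed roughly once per
# epoch, with stochastic index jitter from rounder().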

def last_index(lst, value):
    # Index of the last element that differs from `value`, or -1 if none does.
    return next((len(lst) - i - 1 for i, x in enumerate(lst[::-1]) if x != value), -1)


def safe_ids(ids, max_value, pad_id):
    # Clamp out-of-range token ids to `pad_id`.
    return [i if i < max_value else pad_id for i in ids]

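# Quick sanity examples for the helpers above:
#   last_index([7, 8, -100, -100], -100) == 1
#   safe_ids([5, 70000], max_value=64000, pad_id=0) == [5, 0]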

def tokenize(messages, tokenizer):
    # Build a Vicuna-style prompt: system preamble, then alternating
    # USER/ASSISTANT turns. Only assistant content contributes to the loss.
    roles = {"user": "USER", "assistant": "ASSISTANT"}
    input_ids = []
    labels = []
    system = "A chat between a curious user and an artificial intelligence assistant. " \
             "The assistant gives helpful, detailed, and polite answers to the user's questions."
    system_ids = tokenizer.encode(system, add_special_tokens=False)
    input_ids += system_ids
    labels += [IGNORE_TOKEN_ID] * len(system_ids)
    for i, turn in enumerate(messages):
        role = roles.get(turn['role'], 'USER')
        content = turn['content'].strip()
        if role == 'ASSISTANT':
            content += '</s>'
        role_ids = tokenizer.encode(role + ":", add_special_tokens=False)
        content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True, max_length=2048)
        input_ids += role_ids + content_ids
        if role == 'ASSISTANT':
            labels += [IGNORE_TOKEN_ID] * len(role_ids) + content_ids
        else:
            labels += [IGNORE_TOKEN_ID] * (len(role_ids) + len(content_ids))

    input_ids = input_ids[:4096]
    labels = labels[:4096]

    # Drop a trailing masked span (e.g. a final user turn with no assistant reply).
    trunc_id = last_index(labels, IGNORE_TOKEN_ID) + 1
    input_ids = input_ids[:trunc_id]
    labels = labels[:trunc_id]
    if len(labels) == 0:
        input_ids, labels = [0, 0], [IGNORE_TOKEN_ID, IGNORE_TOKEN_ID]
    input_ids = safe_ids(input_ids, 64000, 0)
    labels = safe_ids(labels, 64000, IGNORE_TOKEN_ID)
    return input_ids, labels

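# Sketch of the resulting masking (token boundaries illustrative, not exact):
#   messages = [{'role': 'user', 'content': 'Hi'},
#               {'role': 'assistant', 'content': 'Hello!'}]
#   input_ids: <system> USER: Hi ASSISTANT: Hello!</s>
#   labels:    [masked] [masked] [masked]  Hello!</s>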

class VicunaData(Dataset):
    """Wrap a list of conversations, tokenizing each one on access."""

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        item = self.data[item]
        input_ids, labels = tokenize(item, self.tokenizer)
        return torch.tensor(input_ids), torch.tensor(labels)

    def collate_fn(self, data):
        # Same padding scheme as MixData.collate_fn.
        input_ids, labels = zip(*data)
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        features = {
            'input_ids': input_ids.long(),
            'labels': labels.long(),
            'attention_mask': attention_mask.long(),
        }
        return features


def main():
    accelerator = Accelerator(gradient_accumulation_steps=8)
    batch_size = 4

    save_path = 'out/baichuan-vicuna-7b'
    model_name = './models/baichuan-llama-7b'

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side="right", model_max_length=4096)
    # The tokenizer ships without a pad token; reuse unk for padding.
    tokenizer.pad_token = tokenizer.unk_token

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    share_gpt = VicunaData(json.load(open('data/new/share_gpt-90k.json')), tokenizer)
    instruction = VicunaData(json.load(open('data/new/cot-75k.json')), tokenizer)
    code = VicunaData(json.load(open('data/new/leet-9k.json')), tokenizer)

    dataset = MixData([share_gpt, instruction, code],
                      [len(share_gpt), len(instruction), len(code)], tokenizer)

    print(len(dataset))

    data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn,
                                              batch_size=batch_size, num_workers=0, shuffle=True)

    optimizer = AdamW(model.parameters(), 2e-5)
    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

    for epoch in range(10):
        accelerator.print(f'Training {save_path} {epoch}')
        accelerator.wait_for_everyone()
        model.train()
        tk0 = tqdm(data_loader, total=len(data_loader))
        loss_report = []
        for batch in tk0:
            with accelerator.accumulate(model):
                try:
                    out = model(**batch)
                    loss = out.loss
                except Exception:
                    # A failed forward pass (e.g. OOM) is skipped by substituting
                    # a zero loss that still flows through backward().
                    loss = torch.tensor(0., device=model.device, requires_grad=True)

                if loss.isnan():
                    print(loss)
                    print(batch)
                    loss = torch.tensor(0., device=model.device, requires_grad=True)

                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.)
                optimizer.step()
                optimizer.zero_grad()

                loss_report.append(accelerator.gather(loss).mean().item())
                tk0.set_postfix(loss=sum(loss_report[-100:]) / len(loss_report[-100:]))

        accelerator.wait_for_everyone()
        # save_checkpoint is the DeepSpeed engine's checkpoint API; the prepared
        # model only exposes it when accelerate runs with a DeepSpeed plugin.
        model.save_checkpoint(f'{save_path}/{epoch}')


if __name__ == '__main__':
    main()
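A minimal launch sketch, not part of the commit (deepspeed.yaml is a placeholder name; a DeepSpeed-backed accelerate config is assumed, since model.save_checkpoint above is DeepSpeed's checkpoint API):

    accelerate launch --config_file deepspeed.yaml train_vicuna.py

With gradient_accumulation_steps=8 and batch_size=4, each device contributes an effective batch of 32 examples per optimizer step.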