Moses25 committed
Commit 9bcad37
1 Parent(s): e5d99f3

Create alpaca_lora.py

Files changed (1)
  1. alpaca_lora.py +343 -0
alpaca_lora.py ADDED
import os
import os.path as osp
import sys
import json
from typing import List, Union

import fire
import torch
import transformers
from datasets import load_dataset


def jload(data_path: str) -> List:
    # Read a JSON dataset file into memory. Not used below: train() loads data via
    # load_dataset, and the jload call there is commented out.
    with open(data_path, 'r') as f:
        data = json.load(f)
    return data


"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (
    LoraConfig,
    get_peft_model,
    PeftModel,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer, EarlyStoppingCallback

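# The Prompter reads a prompt template from data/templates/<template_name>.json. Judging from
# its use below, that JSON is expected to provide at least the keys "description",
# "prompt_input", "prompt_no_input", and "response_split" (the Alpaca-LoRA template layout).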
# os.environ["WANDB_DISABLED"] = "true"
class Prompter(object):
    __slots__ = ("template", "_verbose")

    def __init__(self, template_name: str = "", verbose: bool = False):
        self._verbose = verbose
        if not template_name:
            # Enforce the default here, so the constructor can be called with '' and will not break.
            template_name = "alpaca"
        file_name = osp.join("data/templates", f"{template_name}.json")
        if not osp.exists(file_name):
            raise ValueError(f"Prompt template file {file_name} does not exist")
        with open(file_name) as fp:
            self.template = json.load(fp)
        if self._verbose:
            print(
                f"Using prompt template {file_name}: {self.template['description']}"
            )

    def generate_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,
        label: Union[None, str] = None,
    ) -> str:
        # Returns the full prompt built from the instruction and the optional input.
        # If a label (= response, = output) is provided, it is appended as well.
        if input:
            res = self.template["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            res = self.template["prompt_no_input"].format(
                instruction=instruction
            )
        if label:
            res = f"{res}{label}"
        if self._verbose:
            print(res)
        return res

    def get_response(self, output: str) -> str:
        return output.split(self.template["response_split"])[1].strip()

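# train() fine-tunes a LLaMA base model with LoRA adapters on an Alpaca-format instruction
# dataset (records with "instruction", "input", and "output" fields) and writes the resulting
# adapter to output_dir.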
def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "data/alapa",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 12,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 512,
    val_set_size: int = 200,
    # lora hyperparams
    lora_r: int = 64,
    lora_alpha: int = 128,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
    ],
    cache_dir=None,
    peft_path='',
    report_to='none',
    # llm hyperparams
    train_on_inputs: bool = False,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca",  # the prompt template to use; defaults to alpaca
):
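    # batch_size is reached through gradient accumulation over micro-batches of micro_batch_size;
    # under DDP the accumulation steps are further divided by WORLD_SIZE so the effective global
    # batch size stays close to batch_size.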
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"lora_r: {lora_r}\n"
            f"cache_dir: {cache_dir}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"group_by_length: {group_by_length}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
            f"peft_path: {peft_path}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    prompter = Prompter(prompt_template_name)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=False,
        torch_dtype=torch.float16,
        device_map=device_map,
        cache_dir=cache_dir,
    )

    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token_id is None:
        # Pad with unk (id 0 for the LLaMA tokenizer): it must differ from the eos token and
        # stay inside the embedding table so the data collator can pad batches safely.
        tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"  # Allow batched inference
    if model.get_input_embeddings().weight.size(0) != len(tokenizer):
        print("Resize model embeddings to fit tokenizer")
        model.resize_token_embeddings(len(tokenizer))

    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result
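
    # generate_and_tokenize_prompt builds the full prompt (instruction + optional input + output)
    # and tokenizes it. When train_on_inputs is False, the labels for the prompt portion are set
    # to -100, the ignore index of the loss, so only the response tokens contribute to training.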
    def generate_and_tokenize_prompt(data_point):
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = prompter.generate_prompt(
                data_point["instruction"], data_point["input"]
            )
            tokenized_user_prompt = tokenize(
                user_prompt, add_eos_token=add_eos_token
            )
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            if add_eos_token:
                user_prompt_len -= 1

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt

    # model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
        # data = jload(data_path)
    else:
        data = load_dataset(data_path)

    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    if peft_path:
        adapters_weights = torch.load(f"{peft_path}/adapter_model.bin")
        set_peft_model_state_dict(model, adapters_weights)
    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=10,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            sharded_ddp=True,
            logging_steps=200,
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=200 if val_set_size > 0 else None,
            save_steps=200,
            logging_strategy="steps",
            output_dir=output_dir,
            save_total_limit=5,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to=report_to,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

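    # Override state_dict so that checkpoints written by the Trainer (and save_pretrained below)
    # contain only the LoRA adapter weights extracted by get_peft_model_state_dict, not the full
    # base model.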
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(
            self, old_state_dict()
        )
    ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=None)

    model.save_pretrained(output_dir)
    trainer.save_model(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


if __name__ == "__main__":
    fire.Fire(train)
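
# Example usage (illustrative; the dataset path and model name are placeholders):
#   python alpaca_lora.py \
#       --base_model 'huggyllama/llama-7b' \
#       --data_path 'data/alpaca_data.json' \
#       --output_dir './lora-alpaca'
# The script also expects the prompt template at data/templates/alpaca.json. For multi-GPU
# training, launch with torchrun so that WORLD_SIZE and LOCAL_RANK are set and the DDP branch
# above is taken.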