zetavg committed
Commit 87a0e23 • Parent: 341c612
llama_lora/globals.py CHANGED
@@ -3,6 +3,8 @@ import subprocess
 
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+from .lib.finetune import train
+
 
 class Global:
     version = None
@@ -15,11 +17,14 @@ class Global:
     loaded_base_model: Any = None
 
     # Functions
-    train_fn: Any = None
+    train_fn: Any = train
 
     # Training Control
     should_stop_training = False
 
+    # Model related
+    model_has_been_used = False
+
     # UI related
     ui_title: str = "LLaMA-LoRA"
     ui_emoji: str = "🦙🎛️"
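For context, `Global.train_fn` is now bound to `lib.finetune.train` at import time, so callers go through the `Global` indirection instead of importing `train` directly. A minimal sketch of the wiring (the stub below is hypothetical, not part of this commit):

# Minimal sketch: the default binding, plus a hypothetical stand-in for UI-only runs.
from llama_lora.globals import Global
from llama_lora.lib.finetune import train

assert Global.train_fn is train            # default wiring after this change

def fake_train(*args, **kwargs):           # hypothetical stub, e.g. for ui_dev_mode
    print("skipping real training")

Global.train_fn = fake_train               # callers keep using Global.train_fn(...)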
llama_lora/lib/finetune.py ADDED
@@ -0,0 +1,223 @@
+import os
+import sys
+from typing import Any, List
+
+import fire
+import torch
+import transformers
+from datasets import Dataset, load_dataset
+
+
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    get_peft_model_state_dict,
+    prepare_model_for_int8_training,
+    set_peft_model_state_dict,
+)
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
+
+def train(
+    # model/data params
+    base_model: Any,
+    tokenizer: Any,
+    output_dir: str,
+    train_dataset_data: List[Any],
+    # training hyperparams
+    micro_batch_size: int = 4,
+    gradient_accumulation_steps: int = 32,
+    num_epochs: int = 3,
+    learning_rate: float = 3e-4,
+    cutoff_len: int = 256,
+    val_set_size: int = 2000,
+    # lora hyperparams
+    lora_r: int = 8,
+    lora_alpha: int = 16,
+    lora_dropout: float = 0.05,
+    lora_target_modules: List[str] = [
+        "q_proj",
+        "v_proj",
+    ],
+    # llm hyperparams
+    train_on_inputs: bool = True,  # if False, masks out inputs in loss
+    group_by_length: bool = False,  # faster, but produces an odd training loss curve
+    # either training checkpoint or final adapter
+    resume_from_checkpoint: str = None,
+    # logging
+    callbacks: List[Any] = []
+):
+    device_map = "auto"
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    ddp = world_size != 1
+    if ddp:
+        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
+
+    model = base_model
+    if isinstance(model, str):
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+            load_in_8bit=True,
+            torch_dtype=torch.float16,
+            device_map=device_map,
+        )
+
+    if isinstance(tokenizer, str):
+        tokenizer = LlamaTokenizer.from_pretrained(tokenizer)
+
+    tokenizer.pad_token_id = (
+        0  # unk. we want this to be different from the eos token
+    )
+    tokenizer.padding_side = "left"  # Allow batched inference
+
+    def tokenize(prompt, add_eos_token=True):
+        # there's probably a way to do this with the tokenizer settings
+        # but again, gotta move fast
+        result = tokenizer(
+            prompt,
+            truncation=True,
+            max_length=cutoff_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != tokenizer.eos_token_id
+            and len(result["input_ids"]) < cutoff_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        result["labels"] = result["input_ids"].copy()
+
+        return result
+
+    def generate_and_tokenize_prompt(data_point):
+        full_prompt = data_point["prompt"] + data_point["completion"]
+        tokenized_full_prompt = tokenize(full_prompt)
+        if not train_on_inputs:
+            user_prompt = data_point["prompt"]
+            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
+            user_prompt_len = len(tokenized_user_prompt["input_ids"])
+
+            tokenized_full_prompt["labels"] = [
+                -100
+            ] * user_prompt_len + tokenized_full_prompt["labels"][
+                user_prompt_len:
+            ]  # could be sped up, probably
+        return tokenized_full_prompt
+
+    # will fail anyway.
+    try:
+        model = prepare_model_for_int8_training(model)
+    except Exception as e:
+        print(
+            f"Got error while running prepare_model_for_int8_training(model), maybe the model has already been prepared. Original error: {e}.")
+
+    # model = prepare_model_for_int8_training(model)
+
+    config = LoraConfig(
+        r=lora_r,
+        lora_alpha=lora_alpha,
+        target_modules=lora_target_modules,
+        lora_dropout=lora_dropout,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    model = get_peft_model(model, config)
+
+    # If train_dataset_data is a list, convert it to datasets.Dataset
+    if isinstance(train_dataset_data, list):
+        train_dataset_data = Dataset.from_list(train_dataset_data)
+
+    if resume_from_checkpoint:
+        # Check the available weights and load them
+        checkpoint_name = os.path.join(
+            resume_from_checkpoint, "pytorch_model.bin"
+        )  # Full checkpoint
+        if not os.path.exists(checkpoint_name):
+            checkpoint_name = os.path.join(
+                resume_from_checkpoint, "adapter_model.bin"
+            )  # only LoRA model - LoRA config above has to fit
+            resume_from_checkpoint = (
+                False  # So the trainer won't try loading its state
+            )
+        # The two files above have a different name depending on how they were saved, but are actually the same.
+        if os.path.exists(checkpoint_name):
+            print(f"Restarting from {checkpoint_name}")
+            adapters_weights = torch.load(checkpoint_name)
+            model = set_peft_model_state_dict(model, adapters_weights)
+        else:
+            print(f"Checkpoint {checkpoint_name} not found")
+
+    # Be more transparent about the % of trainable params.
+    model.print_trainable_parameters()
+
+    if val_set_size > 0:
+        train_val = train_dataset_data.train_test_split(
+            test_size=val_set_size, shuffle=True, seed=42
+        )
+        train_data = (
+            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
+        )
+        val_data = (
+            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+        )
+    else:
+        train_data = train_dataset_data.shuffle().map(generate_and_tokenize_prompt)
+        val_data = None
+
+    if not ddp and torch.cuda.device_count() > 1:
+        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
+        model.is_parallelizable = True
+        model.model_parallel = True
+
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=train_data,
+        eval_dataset=val_data,
+        args=transformers.TrainingArguments(
+            per_device_train_batch_size=micro_batch_size,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            warmup_steps=100,
+            num_train_epochs=num_epochs,
+            learning_rate=learning_rate,
+            fp16=True,
+            logging_steps=10,
+            optim="adamw_torch",
+            evaluation_strategy="steps" if val_set_size > 0 else "no",
+            save_strategy="steps",
+            eval_steps=200 if val_set_size > 0 else None,
+            save_steps=200,
+            output_dir=output_dir,
+            save_total_limit=3,
+            load_best_model_at_end=True if val_set_size > 0 else False,
+            ddp_find_unused_parameters=False if ddp else None,
+            group_by_length=group_by_length,
+            # report_to="wandb" if use_wandb else None,
+            # run_name=wandb_run_name if use_wandb else None,
+        ),
+        data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+        ),
+        callbacks=callbacks,
+    )
+    model.config.use_cache = False
+
+    old_state_dict = model.state_dict
+    model.state_dict = (
+        lambda self, *_, **__: get_peft_model_state_dict(
+            self, old_state_dict()
+        )
+    ).__get__(model, type(model))
+
+    if torch.__version__ >= "2" and sys.platform != "win32":
+        model = torch.compile(model)
+
+    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+
+    model.save_pretrained(output_dir)
+
+    print(
+        "\n If there's a warning about missing keys above, please disregard :)"
+    )
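As a rough guide to the interface added above, a minimal invocation might look like the following; the model id, output path, and dataset are illustrative placeholders, not values from this commit:

from llama_lora.lib.finetune import train

train(
    base_model="decapoda-research/llama-7b-hf",   # assumed LLaMA checkpoint id (any compatible one works)
    tokenizer="decapoda-research/llama-7b-hf",    # a string is loaded via LlamaTokenizer.from_pretrained
    output_dir="./lora-output",                   # hypothetical path
    train_dataset_data=[                          # a list of dicts is converted to a datasets.Dataset
        {"prompt": "### Instruction:\nSay hello.\n\n### Response:\n", "completion": "Hello!"},
    ],
    micro_batch_size=1,
    num_epochs=1,
    val_set_size=0,                               # too few examples here for an eval split
)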
llama_lora/models.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import sys
+import gc
 
 import torch
 import transformers
@@ -31,11 +32,14 @@ def get_base_model():
 
 
 def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
+    Global.model_has_been_used = True
+
     if device == "cuda":
         return PeftModel.from_pretrained(
            get_base_model(),
            lora_weights,
            torch_dtype=torch.float16,
+            device_map={'': 0},  # ? https://github.com/tloen/alpaca-lora/issues/21
        )
    elif device == "mps":
        return PeftModel.from_pretrained(
@@ -58,16 +62,21 @@ def get_tokenizer():
 
 
 def load_base_model():
+    if Global.ui_dev_mode:
+        return
+
     if Global.loaded_tokenizer is None:
         Global.loaded_tokenizer = LlamaTokenizer.from_pretrained(
-            Global.base_model)
+            Global.base_model
+        )
     if Global.loaded_base_model is None:
         if device == "cuda":
             Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
                 Global.base_model,
                 load_in_8bit=Global.load_8bit,
                 torch_dtype=torch.float16,
-                device_map="auto",
+                # device_map="auto",
+                device_map={'': 0},  # ? https://github.com/tloen/alpaca-lora/issues/21
             )
         elif device == "mps":
             Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
@@ -79,3 +88,24 @@ def load_base_model():
             model = LlamaForCausalLM.from_pretrained(
                 base_model, device_map={"": device}, low_cpu_mem_usage=True
             )
+
+
+def unload_models():
+    del Global.loaded_base_model
+    Global.loaded_base_model = None
+
+    del Global.loaded_tokenizer
+    Global.loaded_tokenizer = None
+
+    gc.collect()
+
+    # if not shared.args.cpu:  # will not be running on CPUs anyway
+    with torch.no_grad():
+        torch.cuda.empty_cache()
+
+    Global.model_has_been_used = False
+
+
+def unload_models_if_already_used():
+    if Global.model_has_been_used:
+        unload_models()
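A sketch of how the new unload helpers are expected to be exercised when switching from inference back to fine-tuning (illustrative flow, not code from this commit):

from llama_lora.globals import Global
from llama_lora.models import get_model_with_lora, unload_models_if_already_used

model = get_model_with_lora()      # marks Global.model_has_been_used = True
# ... run some generations ...

del model
unload_models_if_already_used()    # drops the cached base model/tokenizer, runs gc, empties the CUDA cache
assert Global.loaded_base_model is None and not Global.model_has_been_used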
llama_lora/ui/finetune_ui.py CHANGED
@@ -9,7 +9,7 @@ from random_word import RandomWords
 from transformers import TrainerCallback
 
 from ..globals import Global
-from ..models import get_base_model, get_tokenizer
+from ..models import get_base_model, get_tokenizer, unload_models_if_already_used
 from ..utils.data import (
     get_available_template_names,
     get_available_dataset_names,
@@ -353,6 +353,11 @@ Train data (first 10):
 
     training_callbacks = [UiTrainerCallback]
 
+    # If model has been used in inference, we need to unload it first.
+    # Otherwise, we'll get a 'Function MmBackward0 returned an invalid
+    # gradient at index 1 - expected device meta but got cuda:0' error.
+    unload_models_if_already_used()
+
     Global.should_stop_training = False
 
     return Global.train_fn(
llama_lora/ui/inference_ui.py CHANGED
@@ -26,6 +26,7 @@ def inference(
     repetition_penalty=1.2,
     max_new_tokens=128,
     stream_output=False,
+    progress=gr.Progress(track_tqdm=True),
     **kwargs,
 ):
     variables = [variable_0, variable_1, variable_2, variable_3,
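For reference, `gr.Progress(track_tqdm=True)` makes Gradio inject a progress tracker and mirror any tqdm loop inside the handler in the UI's progress bar. A self-contained sketch of the pattern (the demo function below is illustrative only, not part of this commit):

import time

import gradio as gr
from tqdm import tqdm

def slow_echo(text, progress=gr.Progress(track_tqdm=True)):
    # tqdm progress inside the handler is reflected in the Gradio progress bar
    for _ in tqdm(range(10), desc="generating"):
        time.sleep(0.1)
    return text

gr.Interface(fn=slow_echo, inputs="text", outputs="text").launch()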