Commit: finetune: support lora_modules_to_save
Author: zetavg

Files changed:
- llama_lora/lib/finetune.py (+35, -8)
- llama_lora/ui/finetune_ui.py (+175, -100)
- llama_lora/ui/main_page.py (+35, -5)
llama_lora/lib/finetune.py CHANGED

@@ -1,7 +1,8 @@
 import os
 import sys
+import re
 import importlib
-from typing import Any, List
+from typing import Any, List, Union

 import json

@@ -18,7 +19,7 @@ from peft import (
     prepare_model_for_int8_training,
     set_peft_model_state_dict,
 )
-from transformers import …
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer


 def train(
@@ -42,6 +43,7 @@ def train(
         "q_proj",
         "v_proj",
     ],
+    lora_modules_to_save: Union[List[str], None] = [],
     # llm hyperparams
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
     group_by_length: bool = False,  # faster, but produces an odd training loss curve
@@ -61,6 +63,8 @@ def train(
     wandb_watch: str = "false",  # options: false | gradients | all
     wandb_log_model: str = "true",  # options: false | true
 ):
+    if lora_modules_to_save is not None and len(lora_modules_to_save) <= 0:
+        lora_modules_to_save = None
     # for logging
     finetune_args = {
         'micro_batch_size': micro_batch_size,
@@ -81,6 +85,8 @@ def train(
     }
     if val_set_size and val_set_size > 0:
         finetune_args['val_set_size'] = val_set_size
+    if lora_modules_to_save:
+        finetune_args['lora_modules_to_save'] = lora_modules_to_save
     if resume_from_checkpoint:
         finetune_args['resume_from_checkpoint'] = resume_from_checkpoint

@@ -131,19 +137,39 @@ def train(

     model = base_model
     if isinstance(model, str):
-        …
+        model_name = model
+        model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=True,
             torch_dtype=torch.float16,
+            llm_int8_skip_modules=lora_modules_to_save,
             device_map=device_map,
         )
+        if re.match("[^/]+/llama", model_name):
+            model.config.pad_token_id = 0
+            model.config.bos_token_id = 1
+            model.config.eos_token_id = 2

     if isinstance(tokenizer, str):
-        …
+        tokenizer_name = tokenizer
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+        except Exception as e:
+            if 'LLaMATokenizer' in str(e):
+                tokenizer = LlamaTokenizer.from_pretrained(
+                    tokenizer_name,
+                )
+            else:
+                raise e
+
+        if re.match("[^/]+/llama", tokenizer_name):
+            tokenizer.pad_token_id = 0
+            tokenizer.bos_token_id = 1
+            tokenizer.eos_token_id = 2
+
+        # tokenizer.pad_token_id = (
+        #     0  # unk. we want this to be different from the eos token
+        # )
     tokenizer.padding_side = "left"  # Allow batched inference

     def tokenize(prompt, add_eos_token=True):
@@ -196,6 +222,7 @@ def train(
         r=lora_r,
         lora_alpha=lora_alpha,
         target_modules=lora_target_modules,
+        modules_to_save=lora_modules_to_save,
         lora_dropout=lora_dropout,
         bias="none",
         task_type="CAUSAL_LM",
llama_lora/ui/finetune_ui.py CHANGED

@@ -296,6 +296,7 @@ def do_train(
     lora_alpha,
     lora_dropout,
     lora_target_modules,
+    lora_modules_to_save,
     save_steps,
     save_total_limit,
     logging_steps,
@@ -314,16 +315,22 @@ def do_train(
     if continue_from_checkpoint == "-" or continue_from_checkpoint == "None":
         continue_from_checkpoint = None
     if continue_from_model:
-        resume_from_checkpoint = os.path.join(…
+        resume_from_checkpoint = os.path.join(
+            Global.data_dir, "lora_models", continue_from_model)
         if continue_from_checkpoint:
-            resume_from_checkpoint = os.path.join(…
-            …
+            resume_from_checkpoint = os.path.join(
+                resume_from_checkpoint, continue_from_checkpoint)
+            will_be_resume_from_checkpoint_file = os.path.join(
+                resume_from_checkpoint, "pytorch_model.bin")
             if not os.path.exists(will_be_resume_from_checkpoint_file):
-                raise ValueError(…
+                raise ValueError(
+                    f"Unable to resume from checkpoint {continue_from_model}/{continue_from_checkpoint}. Resuming is only possible from checkpoints stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
         else:
-            will_be_resume_from_checkpoint_file = os.path.join(…
+            will_be_resume_from_checkpoint_file = os.path.join(
+                resume_from_checkpoint, "adapter_model.bin")
             if not os.path.exists(will_be_resume_from_checkpoint_file):
-                raise ValueError(…
+                raise ValueError(
+                    f"Unable to continue from model {continue_from_model}. Continuation is only possible from models stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")

     output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
     if os.path.exists(output_dir):
@@ -334,7 +341,11 @@ def do_train(
     if not should_training_progress_track_tqdm:
         progress(0, desc="Preparing train data...")

-    …
+    # Need RAM for training
+    unload_models()
+    Global.new_base_model_that_is_ready_to_be_used = None
+    Global.name_of_new_base_model_that_is_ready_to_be_used = None
+    clear_cache()

     prompter = Prompter(template)
     # variable_names = prompter.get_variable_names()
@@ -363,23 +374,6 @@ def do_train(
     if Global.ui_dev_mode:
         Global.should_stop_training = False

-        for i in range(300):
-            if (Global.should_stop_training):
-                return
-            epochs = 3
-            epoch = i / 100
-            last_loss = None
-            if (i > 20):
-                last_loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
-
-            progress(
-                (i, 300),
-                desc="(Simulate) " +
-                get_progress_text(epoch, epochs, last_loss)
-            )
-
-            time.sleep(0.1)
-
         message = f"""Currently in UI dev mode, not doing the actual training.

 Train options: {json.dumps({
@@ -394,6 +388,7 @@ Train options: {json.dumps({
     'lora_alpha': lora_alpha,
     'lora_dropout': lora_dropout,
     'lora_target_modules': lora_target_modules,
+    'lora_modules_to_save': lora_modules_to_save,
     'model_name': model_name,
     'continue_from_model': continue_from_model,
     'continue_from_checkpoint': continue_from_checkpoint,
@@ -403,11 +398,30 @@ Train data (first 10):
 {json.dumps(train_data[:10], indent=2)}
 """
         print(message)
+
+        for i in range(300):
+            if (Global.should_stop_training):
+                return
+            epochs = 3
+            epoch = i / 100
+            last_loss = None
+            if (i > 20):
+                last_loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
+
+            progress(
+                (i, 300),
+                desc="(Simulate) " +
+                get_progress_text(epoch, epochs, last_loss)
+            )
+
+            time.sleep(0.1)
+
         time.sleep(2)
         return message

     if not should_training_progress_track_tqdm:
-        progress(…
+        progress(
+            0, desc=f"Preparing model {base_model_name} for training...")

     log_history = []

@@ -445,9 +459,6 @@ Train data (first 10):

     Global.should_stop_training = False

-    base_model = get_new_base_model(base_model_name)
-    tokenizer = get_tokenizer(tokenizer_name)
-
     # Do not let other tqdm iterations interfere the progress reporting after training starts.
     # progress.track_tqdm = False  # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.

@@ -498,33 +509,34 @@ Train data (first 10):
         wandb_tags.append(f"dataset:{dataset_from_data_dir}")

     train_output = Global.train_fn(
-        base_model,
-        tokenizer,
-        output_dir,
-        train_data,
+        base_model=base_model_name,
+        tokenizer=tokenizer_name,
+        output_dir=output_dir,
+        train_data=train_data,
         # 128, # batch_size (is not used, use gradient_accumulation_steps instead)
-        micro_batch_size,
-        gradient_accumulation_steps,
-        epochs,
-        learning_rate,
-        max_seq_length,
-        evaluate_data_count,
-        lora_r,
-        lora_alpha,
-        lora_dropout,
-        lora_target_modules,
-        …
-        Global.…
-        …
+        micro_batch_size=micro_batch_size,
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        num_epochs=epochs,
+        learning_rate=learning_rate,
+        cutoff_len=max_seq_length,
+        val_set_size=evaluate_data_count,
+        lora_r=lora_r,
+        lora_alpha=lora_alpha,
+        lora_dropout=lora_dropout,
+        lora_target_modules=lora_target_modules,
+        lora_modules_to_save=lora_modules_to_save,
+        train_on_inputs=train_on_inputs,
+        group_by_length=False,
+        resume_from_checkpoint=resume_from_checkpoint,
+        save_steps=save_steps,
+        save_total_limit=save_total_limit,
+        logging_steps=logging_steps,
+        callbacks=training_callbacks,
+        wandb_api_key=Global.wandb_api_key,
+        wandb_project=Global.default_wandb_project if Global.enable_wandb else None,
+        wandb_group=wandb_group,
+        wandb_run_name=model_name,
+        wandb_tags=wandb_tags
     )

     logs_str = "\n".join([json.dumps(log)
@@ -578,10 +590,12 @@ def handle_load_params_from_model(
     lora_alpha,
     lora_dropout,
     lora_target_modules,
+    lora_modules_to_save,
     save_steps,
     save_total_limit,
     logging_steps,
     lora_target_module_choices,
+    lora_modules_to_save_choices,
 ):
     error_message = ""
     notice_message = ""
@@ -633,6 +647,11 @@ def handle_load_params_from_model(
                 for element in value:
                     if element not in lora_target_module_choices:
                         lora_target_module_choices.append(element)
+            elif key == "lora_modules_to_save":
+                lora_modules_to_save = value
+                for element in value:
+                    if element not in lora_modules_to_save_choices:
+                        lora_modules_to_save_choices.append(element)
             elif key == "save_steps":
                 save_steps = value
             elif key == "save_total_limit":
@@ -670,15 +689,20 @@ def handle_load_params_from_model(
         lora_r,
         lora_alpha,
         lora_dropout,
-        gr.CheckboxGroup.update(value=lora_target_modules, …
+        gr.CheckboxGroup.update(value=lora_target_modules,
+                                choices=lora_target_module_choices),
+        gr.CheckboxGroup.update(
+            value=lora_modules_to_save, choices=lora_modules_to_save_choices),
         save_steps,
         save_total_limit,
         logging_steps,
         lora_target_module_choices,
+        lora_modules_to_save_choices
     )


 default_lora_target_module_choices = ["q_proj", "k_proj", "v_proj", "o_proj"]
+default_lora_modules_to_save_choices = ["model.embed_tokens", "lm_head"]


 def handle_lora_target_modules_add(choices, new_module, selected_modules):
@@ -688,6 +712,13 @@ def handle_lora_target_modules_add(choices, new_module, selected_modules):
     return (choices, "", gr.CheckboxGroup.update(value=selected_modules, choices=choices))


+def handle_lora_modules_to_save_add(choices, new_module, selected_modules):
+    choices.append(new_module)
+    selected_modules.append(new_module)
+
+    return (choices, "", gr.CheckboxGroup.update(value=selected_modules, choices=choices))
+
+
 def finetune_ui():
     things_that_might_timeout = []

@@ -863,12 +894,13 @@ def finetune_ui():
                     info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
                 )

-                …
+                with gr.Column():
+                    evaluate_data_count = gr.Slider(
+                        minimum=0, maximum=1, step=1, value=0,
+                        label="Evaluation Data Count",
+                        info="The number of data to be used for evaluation. This specific amount of data will be randomly chosen from the training dataset for evaluating the model's performance during the process, without contributing to the actual training.",
+                        elem_id="finetune_evaluate_data_count"
+                    )

             with gr.Box(elem_id="finetune_continue_from_model_box"):
                 with gr.Row():
@@ -923,30 +955,65 @@ def finetune_ui():
                        info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
                    )

-                …
+                with gr.Column(elem_id="finetune_lora_target_modules_box"):
+                    lora_target_modules = gr.CheckboxGroup(
+                        label="LoRA Target Modules",
+                        choices=default_lora_target_module_choices,
+                        value=["q_proj", "v_proj"],
+                        info="Modules to replace with LoRA.",
+                        elem_id="finetune_lora_target_modules"
+                    )
+                    lora_target_module_choices = gr.State(
+                        value=default_lora_target_module_choices)
+                    with gr.Box(elem_id="finetune_lora_target_modules_add_box"):
+                        with gr.Row():
+                            lora_target_modules_add = gr.Textbox(
+                                lines=1, max_lines=1, show_label=False,
+                                elem_id="finetune_lora_target_modules_add"
+                            )
+                            lora_target_modules_add_btn = gr.Button(
+                                "Add",
+                                elem_id="finetune_lora_target_modules_add_btn"
+                            )
+                            lora_target_modules_add_btn.style(
+                                full_width=False, size="sm")
+                    things_that_might_timeout.append(lora_target_modules_add_btn.click(
+                        handle_lora_target_modules_add,
+                        inputs=[lora_target_module_choices,
+                                lora_target_modules_add, lora_target_modules],
+                        outputs=[lora_target_module_choices,
+                                 lora_target_modules_add, lora_target_modules],
+                    ))
+
+                with gr.Column(elem_id="finetune_lora_modules_to_save_box"):
+                    lora_modules_to_save = gr.CheckboxGroup(
+                        label="LoRA Modules To Save",
+                        choices=default_lora_modules_to_save_choices,
+                        value=[],
+                        # info="",
+                        elem_id="finetune_lora_modules_to_save"
+                    )
+                    lora_modules_to_save_choices = gr.State(
+                        value=default_lora_modules_to_save_choices)
+                    with gr.Box(elem_id="finetune_lora_modules_to_save_add_box"):
+                        with gr.Row():
+                            lora_modules_to_save_add = gr.Textbox(
+                                lines=1, max_lines=1, show_label=False,
+                                elem_id="finetune_lora_modules_to_save_add"
+                            )
+                            lora_modules_to_save_add_btn = gr.Button(
+                                "Add",
+                                elem_id="finetune_lora_modules_to_save_add_btn"
+                            )
+                            lora_modules_to_save_add_btn.style(
+                                full_width=False, size="sm")
+                    things_that_might_timeout.append(lora_modules_to_save_add_btn.click(
+                        handle_lora_modules_to_save_add,
+                        inputs=[lora_modules_to_save_choices,
+                                lora_modules_to_save_add, lora_modules_to_save],
+                        outputs=[lora_modules_to_save_choices,
+                                 lora_modules_to_save_add, lora_modules_to_save],
+                    ))

                 with gr.Row():
                     logging_steps = gr.Number(
@@ -976,20 +1043,25 @@ def finetune_ui():
                     elem_id="finetune_model_name",
                 )

-                …
-                )
+                with gr.Row():
+                    with gr.Column():
+                        pass
+                    with gr.Column():

-                …
+                        with gr.Row():
+                            train_btn = gr.Button(
+                                "Train", variant="primary", label="Train",
+                                elem_id="finetune_start_btn"
+                            )
+
+                            abort_button = gr.Button(
+                                "Abort", label="Abort",
+                                elem_id="finetune_stop_btn"
+                            )
+                            confirm_abort_button = gr.Button(
+                                "Confirm Abort", label="Confirm Abort", variant="stop",
+                                elem_id="finetune_confirm_stop_btn"
+                            )

     things_that_might_timeout.append(reload_selections_button.click(
         reload_selections,
@@ -1031,6 +1103,7 @@ def finetune_ui():
             lora_alpha,
             lora_dropout,
             lora_target_modules,
+            lora_modules_to_save,
             save_steps,
             save_total_limit,
             logging_steps,
@@ -1039,8 +1112,10 @@ def finetune_ui():
     things_that_might_timeout.append(
         load_params_from_model_btn.click(
             fn=handle_load_params_from_model,
-            inputs=[continue_from_model] + finetune_args +
-            …
+            inputs=[continue_from_model] + finetune_args +
+            [lora_target_module_choices, lora_modules_to_save_choices],
+            outputs=[load_params_from_model_message] + finetune_args +
+            [lora_target_module_choices, lora_modules_to_save_choices]
         )
     )

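Both new "Add" boxes in the UI follow the same Gradio pattern: the list of available choices is kept in a `gr.State` so that user-added entries survive updates, and the click handler returns a `gr.CheckboxGroup.update(...)` that refreshes both the choices and the current selection. Below is a small standalone sketch of that pattern, assuming Gradio 3.x (where `gr.CheckboxGroup.update` exists); unlike the handlers in this commit, it also guards against duplicate entries.

# Standalone sketch of the choices-in-State pattern (assumes Gradio 3.x).
import gradio as gr

default_choices = ["model.embed_tokens", "lm_head"]

def handle_add(choices, new_module, selected):
    # Append the typed-in module name to both the choice list and the selection,
    # then clear the textbox and refresh the CheckboxGroup.
    if new_module and new_module not in choices:
        choices.append(new_module)
        selected.append(new_module)
    return choices, "", gr.CheckboxGroup.update(value=selected, choices=choices)

with gr.Blocks() as demo:
    modules = gr.CheckboxGroup(
        label="LoRA Modules To Save", choices=default_choices, value=[])
    choices_state = gr.State(value=default_choices)
    with gr.Row():
        new_module_box = gr.Textbox(show_label=False)
        add_btn = gr.Button("Add")
    add_btn.click(
        handle_add,
        inputs=[choices_state, new_module_box, modules],
        outputs=[choices_state, new_module_box, modules],
    )

if __name__ == "__main__":
    demo.launch()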
llama_lora/ui/main_page.py CHANGED

@@ -733,24 +733,54 @@ def main_page_custom_css():
         flex: 2;
     }

-    #finetune_lora_target_modules_add_box {
+    #finetune_lora_target_modules_box,
+    #finetune_lora_modules_to_save_box {
+        margin-top: -24px;
+    }
+    #finetune_lora_target_modules_box > .form,
+    #finetune_lora_modules_to_save_box > .form {
+        padding-top: 8px;
+        border-top: 0;
+        border-top-left-radius: 0;
+        border-top-right-radius: 0;
+        background: var(--block-background-fill);
+        position: relative;
+    }
+    #finetune_lora_target_modules_box > .form::before,
+    #finetune_lora_modules_to_save_box > .form::before {
+        content: "";
+        display: block;
+        position: absolute;
+        top: 8px;
+        left: 0;
+        right: 0;
+        height: 1px;
+        z-index: 1;
+        background: var(--block-border-color);
+    }
+    #finetune_lora_target_modules_add_box,
+    #finetune_lora_modules_to_save_add_box {
         margin-top: -24px;
         padding-top: 8px;
         border-top-left-radius: 0;
         border-top-right-radius: 0;
         border-top: 0;
     }
-    #finetune_lora_target_modules_add_box > * > .form {
+    #finetune_lora_target_modules_add_box > * > .form,
+    #finetune_lora_modules_to_save_add_box > * > .form {
         border: 0;
         box-shadow: none;
     }
-    #finetune_lora_target_modules_add {
+    #finetune_lora_target_modules_add,
+    #finetune_lora_modules_to_save_add {
         padding: 0;
     }
-    #finetune_lora_target_modules_add input {
+    #finetune_lora_target_modules_add input,
+    #finetune_lora_modules_to_save_add input {
         padding: 4px 8px;
     }
-    #finetune_lora_target_modules_add_btn {
+    #finetune_lora_target_modules_add_btn,
+    #finetune_lora_modules_to_save_add_btn {
         min-width: 60px;
     }
