Commit 8cb0300 (unverified) · committed by zetavg · 1 Parent(s): dd931be

finetune: support lora_modules_to_save

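For context, modules_to_save is the peft LoraConfig option this commit wires through as lora_modules_to_save: any module listed there (e.g. lm_head or model.embed_tokens) is kept fully trainable and its weights are written out together with the LoRA adapter, instead of only being wrapped with low-rank matrices. A minimal sketch of that mapping, not code from this repo (the checkpoint id is a placeholder):

# Minimal sketch of what lora_modules_to_save maps to in peft (assumed setup, not repo code).
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = "decapoda-research/llama-7b-hf"  # placeholder id; any LLaMA-style checkpoint works
model = AutoModelForCausalLM.from_pretrained(
    base_model, torch_dtype=torch.float16, device_map="auto")

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],                # layers that get LoRA adapters
    modules_to_save=["model.embed_tokens", "lm_head"],  # layers trained in full and saved with the adapter
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # parameters of modules_to_save now count as trainable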
llama_lora/lib/finetune.py CHANGED
@@ -1,7 +1,8 @@
 import os
 import sys
+import re
 import importlib
-from typing import Any, List
+from typing import Any, List, Union
 
 import json
 
@@ -18,7 +19,7 @@ from peft import (
     prepare_model_for_int8_training,
     set_peft_model_state_dict,
 )
-from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
 
 
 def train(
@@ -42,6 +43,7 @@ def train(
         "q_proj",
         "v_proj",
     ],
+    lora_modules_to_save: Union[List[str], None] = [],
     # llm hyperparams
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
     group_by_length: bool = False,  # faster, but produces an odd training loss curve
@@ -61,6 +63,8 @@ def train(
     wandb_watch: str = "false",  # options: false | gradients | all
     wandb_log_model: str = "true",  # options: false | true
 ):
+    if lora_modules_to_save is not None and len(lora_modules_to_save) <= 0:
+        lora_modules_to_save = None
     # for logging
     finetune_args = {
         'micro_batch_size': micro_batch_size,
@@ -81,6 +85,8 @@ def train(
     }
     if val_set_size and val_set_size > 0:
         finetune_args['val_set_size'] = val_set_size
+    if lora_modules_to_save:
+        finetune_args['lora_modules_to_save'] = lora_modules_to_save
     if resume_from_checkpoint:
         finetune_args['resume_from_checkpoint'] = resume_from_checkpoint
 
@@ -131,19 +137,39 @@ def train(
 
     model = base_model
     if isinstance(model, str):
-        model = LlamaForCausalLM.from_pretrained(
+        model_name = model
+        model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=True,
             torch_dtype=torch.float16,
+            llm_int8_skip_modules=lora_modules_to_save,
             device_map=device_map,
         )
+        if re.match("[^/]+/llama", model_name):
+            model.config.pad_token_id = 0
+            model.config.bos_token_id = 1
+            model.config.eos_token_id = 2
 
     if isinstance(tokenizer, str):
-        tokenizer = LlamaTokenizer.from_pretrained(tokenizer)
-
-    tokenizer.pad_token_id = (
-        0  # unk. we want this to be different from the eos token
-    )
+        tokenizer_name = tokenizer
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+        except Exception as e:
+            if 'LLaMATokenizer' in str(e):
+                tokenizer = LlamaTokenizer.from_pretrained(
+                    tokenizer_name,
+                )
+            else:
+                raise e
+
+        if re.match("[^/]+/llama", tokenizer_name):
+            tokenizer.pad_token_id = 0
+            tokenizer.bos_token_id = 1
+            tokenizer.eos_token_id = 2
+
+    # tokenizer.pad_token_id = (
+    #     0  # unk. we want this to be different from the eos token
+    # )
    tokenizer.padding_side = "left"  # Allow batched inference
 
     def tokenize(prompt, add_eos_token=True):
@@ -196,6 +222,7 @@ def train(
         r=lora_r,
         lora_alpha=lora_alpha,
         target_modules=lora_target_modules,
+        modules_to_save=lora_modules_to_save,
         lora_dropout=lora_dropout,
         bias="none",
         task_type="CAUSAL_LM",
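One detail worth noting in the hunk above: because modules listed in lora_modules_to_save are trained in full, the same list is also passed as llm_int8_skip_modules so those layers are not converted to 8-bit at load time and stay in a trainable dtype. The diff passes llm_int8_skip_modules straight to from_pretrained; with a more recent transformers release the same intent would usually be expressed through BitsAndBytesConfig. A hedged sketch of that equivalent, not what this repo ships:

# Hedged equivalent using BitsAndBytesConfig (newer transformers API; the commit instead
# forwards llm_int8_skip_modules directly to from_pretrained).
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

lora_modules_to_save = ["model.embed_tokens", "lm_head"]
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=lora_modules_to_save,  # keep fully-trained modules out of int8
)
model = AutoModelForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",  # placeholder id; any LLaMA-style checkpoint works
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    device_map="auto",
)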
llama_lora/ui/finetune_ui.py CHANGED
@@ -296,6 +296,7 @@ def do_train(
     lora_alpha,
     lora_dropout,
     lora_target_modules,
+    lora_modules_to_save,
     save_steps,
     save_total_limit,
     logging_steps,
@@ -314,16 +315,22 @@ def do_train(
     if continue_from_checkpoint == "-" or continue_from_checkpoint == "None":
         continue_from_checkpoint = None
     if continue_from_model:
-        resume_from_checkpoint = os.path.join(Global.data_dir, "lora_models", continue_from_model)
+        resume_from_checkpoint = os.path.join(
+            Global.data_dir, "lora_models", continue_from_model)
         if continue_from_checkpoint:
-            resume_from_checkpoint = os.path.join(resume_from_checkpoint, continue_from_checkpoint)
-            will_be_resume_from_checkpoint_file = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
+            resume_from_checkpoint = os.path.join(
+                resume_from_checkpoint, continue_from_checkpoint)
+            will_be_resume_from_checkpoint_file = os.path.join(
+                resume_from_checkpoint, "pytorch_model.bin")
             if not os.path.exists(will_be_resume_from_checkpoint_file):
-                raise ValueError(f"Unable to resume from checkpoint {continue_from_model}/{continue_from_checkpoint}. Resuming is only possible from checkpoints stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
+                raise ValueError(
+                    f"Unable to resume from checkpoint {continue_from_model}/{continue_from_checkpoint}. Resuming is only possible from checkpoints stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
         else:
-            will_be_resume_from_checkpoint_file = os.path.join(resume_from_checkpoint, "adapter_model.bin")
+            will_be_resume_from_checkpoint_file = os.path.join(
+                resume_from_checkpoint, "adapter_model.bin")
             if not os.path.exists(will_be_resume_from_checkpoint_file):
-                raise ValueError(f"Unable to continue from model {continue_from_model}. Continuation is only possible from models stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
+                raise ValueError(
+                    f"Unable to continue from model {continue_from_model}. Continuation is only possible from models stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
 
     output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
     if os.path.exists(output_dir):
@@ -334,7 +341,11 @@ def do_train(
     if not should_training_progress_track_tqdm:
         progress(0, desc="Preparing train data...")
 
-    unload_models()  # Need RAM for training
+    # Need RAM for training
+    unload_models()
+    Global.new_base_model_that_is_ready_to_be_used = None
+    Global.name_of_new_base_model_that_is_ready_to_be_used = None
+    clear_cache()
 
     prompter = Prompter(template)
     # variable_names = prompter.get_variable_names()
@@ -363,23 +374,6 @@ def do_train(
     if Global.ui_dev_mode:
         Global.should_stop_training = False
 
-        for i in range(300):
-            if (Global.should_stop_training):
-                return
-            epochs = 3
-            epoch = i / 100
-            last_loss = None
-            if (i > 20):
-                last_loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
-
-            progress(
-                (i, 300),
-                desc="(Simulate) " +
-                get_progress_text(epoch, epochs, last_loss)
-            )
-
-            time.sleep(0.1)
-
         message = f"""Currently in UI dev mode, not doing the actual training.
 
 Train options: {json.dumps({
@@ -394,6 +388,7 @@ Train options: {json.dumps({
     'lora_alpha': lora_alpha,
     'lora_dropout': lora_dropout,
     'lora_target_modules': lora_target_modules,
+    'lora_modules_to_save': lora_modules_to_save,
     'model_name': model_name,
     'continue_from_model': continue_from_model,
     'continue_from_checkpoint': continue_from_checkpoint,
@@ -403,11 +398,30 @@ Train data (first 10):
 {json.dumps(train_data[:10], indent=2)}
 """
         print(message)
+
+        for i in range(300):
+            if (Global.should_stop_training):
+                return
+            epochs = 3
+            epoch = i / 100
+            last_loss = None
+            if (i > 20):
+                last_loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
+
+            progress(
+                (i, 300),
+                desc="(Simulate) " +
+                get_progress_text(epoch, epochs, last_loss)
+            )
+
+            time.sleep(0.1)
+
         time.sleep(2)
         return message
 
     if not should_training_progress_track_tqdm:
-        progress(0, desc=f"Preparing model {base_model_name} for training...")
+        progress(
+            0, desc=f"Preparing model {base_model_name} for training...")
 
     log_history = []
 
@@ -445,9 +459,6 @@ Train data (first 10):
 
     Global.should_stop_training = False
 
-    base_model = get_new_base_model(base_model_name)
-    tokenizer = get_tokenizer(tokenizer_name)
-
     # Do not let other tqdm iterations interfere the progress reporting after training starts.
     # progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
 
@@ -498,33 +509,34 @@ Train data (first 10):
         wandb_tags.append(f"dataset:{dataset_from_data_dir}")
 
     train_output = Global.train_fn(
-        base_model,  # base_model
-        tokenizer,  # tokenizer
-        output_dir,  # output_dir
-        train_data,
+        base_model=base_model_name,
+        tokenizer=tokenizer_name,
+        output_dir=output_dir,
+        train_data=train_data,
         # 128,  # batch_size (is not used, use gradient_accumulation_steps instead)
-        micro_batch_size,  # micro_batch_size
-        gradient_accumulation_steps,
-        epochs,  # num_epochs
-        learning_rate,  # learning_rate
-        max_seq_length,  # cutoff_len
-        evaluate_data_count,  # val_set_size
-        lora_r,  # lora_r
-        lora_alpha,  # lora_alpha
-        lora_dropout,  # lora_dropout
-        lora_target_modules,  # lora_target_modules
-        train_on_inputs,  # train_on_inputs
-        False,  # group_by_length
-        resume_from_checkpoint,  # resume_from_checkpoint
-        save_steps,  # save_steps
-        save_total_limit,  # save_total_limit
-        logging_steps,  # logging_steps
-        training_callbacks,  # callbacks
-        Global.wandb_api_key,  # wandb_api_key
-        Global.default_wandb_project if Global.enable_wandb else None,  # wandb_project
-        wandb_group,  # wandb_group
-        model_name,  # wandb_run_name
-        wandb_tags  # wandb_tags
+        micro_batch_size=micro_batch_size,
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        num_epochs=epochs,
+        learning_rate=learning_rate,
+        cutoff_len=max_seq_length,
+        val_set_size=evaluate_data_count,
+        lora_r=lora_r,
+        lora_alpha=lora_alpha,
+        lora_dropout=lora_dropout,
+        lora_target_modules=lora_target_modules,
+        lora_modules_to_save=lora_modules_to_save,
+        train_on_inputs=train_on_inputs,
+        group_by_length=False,
+        resume_from_checkpoint=resume_from_checkpoint,
+        save_steps=save_steps,
+        save_total_limit=save_total_limit,
+        logging_steps=logging_steps,
+        callbacks=training_callbacks,
+        wandb_api_key=Global.wandb_api_key,
+        wandb_project=Global.default_wandb_project if Global.enable_wandb else None,
+        wandb_group=wandb_group,
+        wandb_run_name=model_name,
+        wandb_tags=wandb_tags
     )
 
     logs_str = "\n".join([json.dumps(log)
@@ -578,10 +590,12 @@ def handle_load_params_from_model(
     lora_alpha,
     lora_dropout,
     lora_target_modules,
+    lora_modules_to_save,
     save_steps,
     save_total_limit,
     logging_steps,
     lora_target_module_choices,
+    lora_modules_to_save_choices,
 ):
     error_message = ""
     notice_message = ""
@@ -633,6 +647,11 @@ def handle_load_params_from_model(
             for element in value:
                 if element not in lora_target_module_choices:
                     lora_target_module_choices.append(element)
+        elif key == "lora_modules_to_save":
+            lora_modules_to_save = value
+            for element in value:
+                if element not in lora_modules_to_save_choices:
+                    lora_modules_to_save_choices.append(element)
         elif key == "save_steps":
             save_steps = value
         elif key == "save_total_limit":
@@ -670,15 +689,20 @@ def handle_load_params_from_model(
         lora_r,
         lora_alpha,
         lora_dropout,
-        gr.CheckboxGroup.update(value=lora_target_modules, choices=lora_target_module_choices),
+        gr.CheckboxGroup.update(value=lora_target_modules,
+                                choices=lora_target_module_choices),
+        gr.CheckboxGroup.update(
+            value=lora_modules_to_save, choices=lora_modules_to_save_choices),
        save_steps,
         save_total_limit,
         logging_steps,
         lora_target_module_choices,
+        lora_modules_to_save_choices
     )
 
 
 default_lora_target_module_choices = ["q_proj", "k_proj", "v_proj", "o_proj"]
+default_lora_modules_to_save_choices = ["model.embed_tokens", "lm_head"]
 
 
 def handle_lora_target_modules_add(choices, new_module, selected_modules):
@@ -688,6 +712,13 @@ def handle_lora_target_modules_add(choices, new_module, selected_modules):
     return (choices, "", gr.CheckboxGroup.update(value=selected_modules, choices=choices))
 
 
+def handle_lora_modules_to_save_add(choices, new_module, selected_modules):
+    choices.append(new_module)
+    selected_modules.append(new_module)
+
+    return (choices, "", gr.CheckboxGroup.update(value=selected_modules, choices=choices))
+
+
 def finetune_ui():
     things_that_might_timeout = []
 
@@ -863,12 +894,13 @@ def finetune_ui():
                     info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
                 )
 
-                evaluate_data_count = gr.Slider(
-                    minimum=0, maximum=1, step=1, value=0,
-                    label="Evaluation Data Count",
-                    info="The number of data to be used for evaluation. This specific amount of data will be randomly chosen from the training dataset for evaluating the model's performance during the process, without contributing to the actual training.",
-                    elem_id="finetune_evaluate_data_count"
-                )
+                with gr.Column():
+                    evaluate_data_count = gr.Slider(
+                        minimum=0, maximum=1, step=1, value=0,
+                        label="Evaluation Data Count",
+                        info="The number of data to be used for evaluation. This specific amount of data will be randomly chosen from the training dataset for evaluating the model's performance during the process, without contributing to the actual training.",
+                        elem_id="finetune_evaluate_data_count"
+                    )
 
             with gr.Box(elem_id="finetune_continue_from_model_box"):
                 with gr.Row():
@@ -923,30 +955,65 @@ def finetune_ui():
                     info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
                 )
 
-                lora_target_modules = gr.CheckboxGroup(
-                    label="LoRA Target Modules",
-                    choices=default_lora_target_module_choices,
-                    value=["q_proj", "v_proj"],
-                    info="Modules to replace with LoRA.",
-                    elem_id="finetune_lora_target_modules"
-                )
-                lora_target_module_choices = gr.State(value=default_lora_target_module_choices)
-                with gr.Box(elem_id="finetune_lora_target_modules_add_box"):
-                    with gr.Row():
-                        lora_target_modules_add = gr.Textbox(
-                            lines=1, max_lines=1, show_label=False,
-                            elem_id="finetune_lora_target_modules_add"
-                        )
-                        lora_target_modules_add_btn = gr.Button(
-                            "Add",
-                            elem_id="finetune_lora_target_modules_add_btn"
-                        )
-                        lora_target_modules_add_btn.style(full_width=False, size="sm")
-                        things_that_might_timeout.append(lora_target_modules_add_btn.click(
-                            handle_lora_target_modules_add,
-                            inputs=[lora_target_module_choices, lora_target_modules_add, lora_target_modules],
-                            outputs=[lora_target_module_choices, lora_target_modules_add, lora_target_modules],
-                        ))
+                with gr.Column(elem_id="finetune_lora_target_modules_box"):
+                    lora_target_modules = gr.CheckboxGroup(
+                        label="LoRA Target Modules",
+                        choices=default_lora_target_module_choices,
+                        value=["q_proj", "v_proj"],
+                        info="Modules to replace with LoRA.",
+                        elem_id="finetune_lora_target_modules"
+                    )
+                    lora_target_module_choices = gr.State(
+                        value=default_lora_target_module_choices)
+                    with gr.Box(elem_id="finetune_lora_target_modules_add_box"):
+                        with gr.Row():
+                            lora_target_modules_add = gr.Textbox(
+                                lines=1, max_lines=1, show_label=False,
+                                elem_id="finetune_lora_target_modules_add"
+                            )
+                            lora_target_modules_add_btn = gr.Button(
+                                "Add",
+                                elem_id="finetune_lora_target_modules_add_btn"
+                            )
+                            lora_target_modules_add_btn.style(
+                                full_width=False, size="sm")
+                            things_that_might_timeout.append(lora_target_modules_add_btn.click(
+                                handle_lora_target_modules_add,
+                                inputs=[lora_target_module_choices,
+                                        lora_target_modules_add, lora_target_modules],
+                                outputs=[lora_target_module_choices,
+                                         lora_target_modules_add, lora_target_modules],
+                            ))
+
+                with gr.Column(elem_id="finetune_lora_modules_to_save_box"):
+                    lora_modules_to_save = gr.CheckboxGroup(
+                        label="LoRA Modules To Save",
+                        choices=default_lora_modules_to_save_choices,
+                        value=[],
+                        # info="",
+                        elem_id="finetune_lora_modules_to_save"
+                    )
+                    lora_modules_to_save_choices = gr.State(
+                        value=default_lora_modules_to_save_choices)
+                    with gr.Box(elem_id="finetune_lora_modules_to_save_add_box"):
+                        with gr.Row():
+                            lora_modules_to_save_add = gr.Textbox(
+                                lines=1, max_lines=1, show_label=False,
+                                elem_id="finetune_lora_modules_to_save_add"
+                            )
+                            lora_modules_to_save_add_btn = gr.Button(
+                                "Add",
+                                elem_id="finetune_lora_modules_to_save_add_btn"
+                            )
+                            lora_modules_to_save_add_btn.style(
+                                full_width=False, size="sm")
+                            things_that_might_timeout.append(lora_modules_to_save_add_btn.click(
+                                handle_lora_modules_to_save_add,
+                                inputs=[lora_modules_to_save_choices,
+                                        lora_modules_to_save_add, lora_modules_to_save],
+                                outputs=[lora_modules_to_save_choices,
+                                         lora_modules_to_save_add, lora_modules_to_save],
+                            ))
 
             with gr.Row():
                 logging_steps = gr.Number(
@@ -976,20 +1043,25 @@ def finetune_ui():
                     elem_id="finetune_model_name",
                 )
 
-            with gr.Row():
-                train_btn = gr.Button(
-                    "Train", variant="primary", label="Train",
-                    elem_id="finetune_start_btn"
-                )
+            with gr.Row():
+                with gr.Column():
+                    pass
+                with gr.Column():
 
-                abort_button = gr.Button(
-                    "Abort", label="Abort",
-                    elem_id="finetune_stop_btn"
-                )
-                confirm_abort_button = gr.Button(
-                    "Confirm Abort", label="Confirm Abort", variant="stop",
-                    elem_id="finetune_confirm_stop_btn"
-                )
+                    with gr.Row():
+                        train_btn = gr.Button(
+                            "Train", variant="primary", label="Train",
+                            elem_id="finetune_start_btn"
+                        )
+
+                        abort_button = gr.Button(
+                            "Abort", label="Abort",
+                            elem_id="finetune_stop_btn"
+                        )
+                        confirm_abort_button = gr.Button(
+                            "Confirm Abort", label="Confirm Abort", variant="stop",
+                            elem_id="finetune_confirm_stop_btn"
+                        )
 
     things_that_might_timeout.append(reload_selections_button.click(
         reload_selections,
@@ -1031,6 +1103,7 @@ def finetune_ui():
         lora_alpha,
         lora_dropout,
         lora_target_modules,
+        lora_modules_to_save,
        save_steps,
         save_total_limit,
         logging_steps,
@@ -1039,8 +1112,10 @@ def finetune_ui():
     things_that_might_timeout.append(
         load_params_from_model_btn.click(
             fn=handle_load_params_from_model,
-            inputs=[continue_from_model] + finetune_args + [lora_target_module_choices],
-            outputs=[load_params_from_model_message] + finetune_args + [lora_target_module_choices]
+            inputs=[continue_from_model] + finetune_args +
+            [lora_target_module_choices, lora_modules_to_save_choices],
+            outputs=[load_params_from_model_message] + finetune_args +
+            [lora_target_module_choices, lora_modules_to_save_choices]
         )
     )
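The UI change above repeats an existing pattern for the new "LoRA Modules To Save" group: a gr.State holds the list of known choices, and an "Add" button appends a user-typed module name to both the choices and the current selection. A stripped-down, stand-alone sketch of that wiring (Gradio 3.x API, where gr.CheckboxGroup.update() changes choices at runtime; names here are illustrative, not the repo's):

# Stand-alone sketch of the add-a-choice pattern used for "LoRA Modules To Save" (assumes Gradio 3.x).
import gradio as gr

def handle_add(choices, new_module, selected):
    # Append the typed module name to both the known choices and the current selection.
    if new_module and new_module not in choices:
        choices.append(new_module)
        selected.append(new_module)
    # Clear the textbox and refresh the checkbox group with the extended choice list.
    return choices, "", gr.CheckboxGroup.update(value=selected, choices=choices)

with gr.Blocks() as demo:
    choices_state = gr.State(["model.embed_tokens", "lm_head"])
    group = gr.CheckboxGroup(label="LoRA Modules To Save",
                             choices=["model.embed_tokens", "lm_head"], value=[])
    new_module_box = gr.Textbox(lines=1, show_label=False)
    add_btn = gr.Button("Add")
    add_btn.click(handle_add,
                  inputs=[choices_state, new_module_box, group],
                  outputs=[choices_state, new_module_box, group])

demo.launch()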
 
llama_lora/ui/main_page.py CHANGED
@@ -733,24 +733,54 @@ def main_page_custom_css():
         flex: 2;
     }
 
-    #finetune_lora_target_modules_add_box {
+    #finetune_lora_target_modules_box,
+    #finetune_lora_modules_to_save_box {
+        margin-top: -24px;
+    }
+    #finetune_lora_target_modules_box > .form,
+    #finetune_lora_modules_to_save_box > .form {
+        padding-top: 8px;
+        border-top: 0;
+        border-top-left-radius: 0;
+        border-top-right-radius: 0;
+        background: var(--block-background-fill);
+        position: relative;
+    }
+    #finetune_lora_target_modules_box > .form::before,
+    #finetune_lora_modules_to_save_box > .form::before {
+        content: "";
+        display: block;
+        position: absolute;
+        top: 8px;
+        left: 0;
+        right: 0;
+        height: 1px;
+        z-index: 1;
+        background: var(--block-border-color);
+    }
+    #finetune_lora_target_modules_add_box,
+    #finetune_lora_modules_to_save_add_box {
         margin-top: -24px;
         padding-top: 8px;
         border-top-left-radius: 0;
         border-top-right-radius: 0;
         border-top: 0;
     }
-    #finetune_lora_target_modules_add_box > * > .form {
+    #finetune_lora_target_modules_add_box > * > .form,
+    #finetune_lora_modules_to_save_add_box > * > .form {
         border: 0;
         box-shadow: none;
     }
-    #finetune_lora_target_modules_add {
+    #finetune_lora_target_modules_add,
+    #finetune_lora_modules_to_save_add {
         padding: 0;
     }
-    #finetune_lora_target_modules_add input {
+    #finetune_lora_target_modules_add input,
+    #finetune_lora_modules_to_save_add input {
         padding: 4px 8px;
     }
-    #finetune_lora_target_modules_add_btn {
+    #finetune_lora_target_modules_add_btn,
+    #finetune_lora_modules_to_save_add_btn {
         min-width: 60px;
     }
 
786