zetavg committed on
Commit
0a36bb6
1 Parent(s): 9713284

add more finetune options

llama_lora/lib/finetune.py CHANGED
@@ -28,8 +28,11 @@ def train(
     tokenizer: Any,
     output_dir: str,
     train_data: List[Any],
+    #
     load_in_8bit=True,
     fp16=True,
+    bf16=False,
+    gradient_checkpointing=False,
     # training hyperparams
     micro_batch_size: int = 4,
     gradient_accumulation_steps: int = 32,
@@ -79,18 +82,21 @@ def train(
         'lora_alpha': lora_alpha,
         'lora_dropout': lora_dropout,
         'lora_target_modules': lora_target_modules,
+        'lora_modules_to_save': lora_modules_to_save or [],
         'train_on_inputs': train_on_inputs,
         'group_by_length': group_by_length,
         'load_in_8bit': load_in_8bit,
         'fp16': fp16,
+        'bf16': bf16,
+        'gradient_checkpointing': gradient_checkpointing,
         'save_steps': save_steps,
         'save_total_limit': save_total_limit,
         'logging_steps': logging_steps,
     }
     if val_set_size and val_set_size > 0:
         finetune_args['val_set_size'] = val_set_size
-    if lora_modules_to_save:
-        finetune_args['lora_modules_to_save'] = lora_modules_to_save
+    # if lora_modules_to_save:
+    #     finetune_args['lora_modules_to_save'] = lora_modules_to_save
     if resume_from_checkpoint:
         finetune_args['resume_from_checkpoint'] = resume_from_checkpoint

@@ -232,6 +238,8 @@ def train(
         task_type="CAUSAL_LM",
     )
     model = get_peft_model(model, config)
+    if bf16:
+        model = model.to(torch.bfloat16)

     # If train_data is a list, convert it to datasets.Dataset
     if isinstance(train_data, list):
@@ -289,11 +297,13 @@ def train(
         # https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments
         args=transformers.TrainingArguments(
             per_device_train_batch_size=micro_batch_size,
+            gradient_checkpointing=gradient_checkpointing,
             gradient_accumulation_steps=gradient_accumulation_steps,
             warmup_steps=100,
             num_train_epochs=num_train_epochs,
             learning_rate=learning_rate,
             fp16=fp16,
+            bf16=bf16,
             logging_steps=logging_steps,
             optim="adamw_torch",
             evaluation_strategy="steps" if val_set_size > 0 else "no",
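
Note: taken together, the finetune.py hunks route the two new flags, bf16 and gradient_checkpointing, into the PEFT-wrapped model (an explicit cast to torch.bfloat16) and into transformers.TrainingArguments. Below is a minimal, self-contained sketch of that pattern; it is illustrative only, and the model name, output path, and the bf16 capability check are placeholder assumptions rather than code from this repo.

import torch
import transformers

# bf16 needs hardware support (e.g. Ampere or newer GPUs); guard it so the
# sketch runs anywhere. This check is an assumption, not part of the diff.
bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
gradient_checkpointing = True

# Placeholder model; the project loads its own (PEFT-wrapped) LLaMA model instead.
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
if bf16:
    # Mirrors the hunk above: cast the model weights when bf16 is enabled.
    model = model.to(torch.bfloat16)

training_args = transformers.TrainingArguments(
    output_dir="./finetune-sketch-output",          # placeholder path
    per_device_train_batch_size=4,                  # micro_batch_size in the diff
    gradient_accumulation_steps=32,
    gradient_checkpointing=gradient_checkpointing,  # passed straight through
    fp16=False,                                     # at most one of fp16/bf16 may be True
    bf16=bf16,
    logging_steps=10,
    optim="adamw_torch",
)
print(training_args.bf16, training_args.gradient_checkpointing)

Since TrainingArguments accepts at most one of fp16/bf16, a caller enabling BF16 is expected to leave FP16 off.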
llama_lora/ui/finetune_ui.py CHANGED
@@ -299,6 +299,8 @@ def do_train(
     lora_modules_to_save,
     load_in_8bit,
     fp16,
+    bf16,
+    gradient_checkpointing,
     save_steps,
     save_total_limit,
     logging_steps,
@@ -393,6 +395,8 @@ Train options: {json.dumps({
     'lora_modules_to_save': lora_modules_to_save,
     'load_in_8bit': load_in_8bit,
     'fp16': fp16,
+    'bf16': bf16,
+    'gradient_checkpointing': gradient_checkpointing,
     'model_name': model_name,
     'continue_from_model': continue_from_model,
     'continue_from_checkpoint': continue_from_checkpoint,
@@ -532,6 +536,8 @@ Train data (first 10):
             train_on_inputs=train_on_inputs,
             load_in_8bit=load_in_8bit,
             fp16=fp16,
+            bf16=bf16,
+            gradient_checkpointing=gradient_checkpointing,
             group_by_length=False,
             resume_from_checkpoint=resume_from_checkpoint,
             save_steps=save_steps,
@@ -548,8 +554,9 @@ Train data (first 10):
         logs_str = "\n".join([json.dumps(log)
                               for log in log_history]) or "None"

-        result_message = f"Training ended:\n{str(train_output)}\n\nLogs:\n{logs_str}"
+        result_message = f"Training ended:\n{str(train_output)}"
         print(result_message)
+        # result_message += f"\n\nLogs:\n{logs_str}"

         clear_cache()

@@ -597,6 +604,8 @@ def handle_load_params_from_model(
     lora_modules_to_save,
     load_in_8bit,
     fp16,
+    bf16,
+    gradient_checkpointing,
     save_steps,
     save_total_limit,
     logging_steps,
@@ -650,18 +659,24 @@ def handle_load_params_from_model(
             lora_dropout = value
         elif key == "lora_target_modules":
             lora_target_modules = value
-            for element in value:
-                if element not in lora_target_module_choices:
-                    lora_target_module_choices.append(element)
+            if value:
+                for element in value:
+                    if element not in lora_target_module_choices:
+                        lora_target_module_choices.append(element)
         elif key == "lora_modules_to_save":
             lora_modules_to_save = value
-            for element in value:
-                if element not in lora_modules_to_save_choices:
-                    lora_modules_to_save_choices.append(element)
+            if value:
+                for element in value:
+                    if element not in lora_modules_to_save_choices:
+                        lora_modules_to_save_choices.append(element)
         elif key == "load_in_8bit":
            load_in_8bit = value
         elif key == "fp16":
             fp16 = value
+        elif key == "bf16":
+            bf16 = value
+        elif key == "gradient_checkpointing":
+            gradient_checkpointing = value
         elif key == "save_steps":
             save_steps = value
         elif key == "save_total_limit":
@@ -705,6 +720,8 @@ def handle_load_params_from_model(
             value=lora_modules_to_save, choices=lora_modules_to_save_choices),
         load_in_8bit,
         fp16,
+        bf16,
+        gradient_checkpointing,
         save_steps,
         save_total_limit,
         logging_steps,
@@ -949,9 +966,11 @@ def finetune_ui():
         )

         with gr.Accordion("Advanced Options", open=False, elem_id="finetune_advance_options_accordion"):
-            with gr.Row():
-                load_in_8bit = gr.Checkbox(label="8bit", value=True)
+            with gr.Row(elem_id="finetune_advanced_options_checkboxes"):
+                load_in_8bit = gr.Checkbox(label="8bit", value=False)
                 fp16 = gr.Checkbox(label="FP16", value=True)
+                bf16 = gr.Checkbox(label="BF16", value=False)
+                gradient_checkpointing = gr.Checkbox(label="gradient_checkpointing", value=False)

             with gr.Column():
                 lora_r = gr.Slider(
@@ -1002,57 +1021,62 @@ def finetune_ui():
                                 lora_target_modules_add, lora_target_modules],
                     ))

-                with gr.Column(elem_id="finetune_lora_modules_to_save_box"):
-                    lora_modules_to_save = gr.CheckboxGroup(
-                        label="LoRA Modules To Save",
-                        choices=default_lora_modules_to_save_choices,
-                        value=[],
-                        # info="",
-                        elem_id="finetune_lora_modules_to_save"
-                    )
-                    lora_modules_to_save_choices = gr.State(
-                        value=default_lora_modules_to_save_choices)
-                    with gr.Box(elem_id="finetune_lora_modules_to_save_add_box"):
-                        with gr.Row():
-                            lora_modules_to_save_add = gr.Textbox(
-                                lines=1, max_lines=1, show_label=False,
-                                elem_id="finetune_lora_modules_to_save_add"
-                            )
-                            lora_modules_to_save_add_btn = gr.Button(
-                                "Add",
-                                elem_id="finetune_lora_modules_to_save_add_btn"
-                            )
-                            lora_modules_to_save_add_btn.style(
-                                full_width=False, size="sm")
-                    things_that_might_timeout.append(lora_modules_to_save_add_btn.click(
-                        handle_lora_modules_to_save_add,
-                        inputs=[lora_modules_to_save_choices,
-                                lora_modules_to_save_add, lora_modules_to_save],
-                        outputs=[lora_modules_to_save_choices,
-                                 lora_modules_to_save_add, lora_modules_to_save],
-                    ))
-
-                with gr.Row():
-                    logging_steps = gr.Number(
-                        label="Logging Steps",
-                        precision=0,
-                        value=10,
-                        elem_id="finetune_logging_steps"
-                    )
-                    save_steps = gr.Number(
-                        label="Steps Per Save",
-                        precision=0,
-                        value=500,
-                        elem_id="finetune_save_steps"
-                    )
-                    save_total_limit = gr.Number(
-                        label="Saved Checkpoints Limit",
-                        precision=0,
-                        value=5,
-                        elem_id="finetune_save_total_limit"
-                    )
+                with gr.Accordion("Advanced LoRA Options", open=False, elem_id="finetune_advance_lora_options_accordion"):
+                    with gr.Column(elem_id="finetune_lora_modules_to_save_box"):
+                        lora_modules_to_save = gr.CheckboxGroup(
+                            label="LoRA Modules To Save",
+                            choices=default_lora_modules_to_save_choices,
+                            value=[],
+                            # info="",
+                            elem_id="finetune_lora_modules_to_save"
+                        )
+                        lora_modules_to_save_choices = gr.State(
+                            value=default_lora_modules_to_save_choices)
+                        with gr.Box(elem_id="finetune_lora_modules_to_save_add_box"):
+                            with gr.Row():
+                                lora_modules_to_save_add = gr.Textbox(
+                                    lines=1, max_lines=1, show_label=False,
+                                    elem_id="finetune_lora_modules_to_save_add"
+                                )
+                                lora_modules_to_save_add_btn = gr.Button(
+                                    "Add",
+                                    elem_id="finetune_lora_modules_to_save_add_btn"
+                                )
+                                lora_modules_to_save_add_btn.style(
+                                    full_width=False, size="sm")
+                        things_that_might_timeout.append(lora_modules_to_save_add_btn.click(
+                            handle_lora_modules_to_save_add,
+                            inputs=[lora_modules_to_save_choices,
+                                    lora_modules_to_save_add, lora_modules_to_save],
+                            outputs=[lora_modules_to_save_choices,
+                                     lora_modules_to_save_add, lora_modules_to_save],
+                        ))
+
+                # with gr.Column():
+                #     pass
+
+                with gr.Column(elem_id="finetune_log_and_save_options_group_container"):
+                    with gr.Row(elem_id="finetune_log_and_save_options_group"):
+                        logging_steps = gr.Number(
+                            label="Logging Steps",
+                            precision=0,
+                            value=10,
+                            elem_id="finetune_logging_steps"
+                        )
+                        save_steps = gr.Number(
+                            label="Steps Per Save",
+                            precision=0,
+                            value=500,
+                            elem_id="finetune_save_steps"
+                        )
+                        save_total_limit = gr.Number(
+                            label="Saved Checkpoints Limit",
+                            precision=0,
+                            value=5,
+                            elem_id="finetune_save_total_limit"
+                        )

-        with gr.Column():
+        with gr.Column(elem_id="finetune_model_name_group"):
             model_name = gr.Textbox(
                 lines=1, label="LoRA Model Name", value=random_name,
                 max_lines=1,
@@ -1123,6 +1147,8 @@ def finetune_ui():
         lora_modules_to_save,
         load_in_8bit,
         fp16,
+        bf16,
+        gradient_checkpointing,
         save_steps,
         save_total_limit,
         logging_steps,
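
Note: the UI changes follow the usual Gradio pattern for adding an option: create a Checkbox, add a matching parameter to do_train() and handle_load_params_from_model(), and append the component to each component list so positions stay aligned (the .click() wiring itself is outside the hunks shown above). Below is a small self-contained sketch of that pattern, with a hypothetical do_train_sketch handler standing in for the project's do_train().

import gradio as gr

# Hypothetical stand-in for the project's do_train(); only the new options are shown.
def do_train_sketch(load_in_8bit, fp16, bf16, gradient_checkpointing):
    return (
        f"load_in_8bit={load_in_8bit}, fp16={fp16}, "
        f"bf16={bf16}, gradient_checkpointing={gradient_checkpointing}"
    )

with gr.Blocks() as demo:
    with gr.Row():
        load_in_8bit = gr.Checkbox(label="8bit", value=False)
        fp16 = gr.Checkbox(label="FP16", value=True)
        bf16 = gr.Checkbox(label="BF16", value=False)
        gradient_checkpointing = gr.Checkbox(label="gradient_checkpointing", value=False)
    result = gr.Textbox(label="Result")
    train_btn = gr.Button("Train")
    # The order of `inputs` must match the handler's positional parameters,
    # which is why the diff adds bf16/gradient_checkpointing to every component list.
    train_btn.click(
        fn=do_train_sketch,
        inputs=[load_in_8bit, fp16, bf16, gradient_checkpointing],
        outputs=result,
    )

# demo.launch()  # uncomment to try the sketch locally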
llama_lora/ui/main_page.py CHANGED
@@ -734,11 +734,12 @@ def main_page_custom_css():
     }

     #finetune_lora_target_modules_box,
-    #finetune_lora_modules_to_save_box {
-        margin-top: calc((var(--layout-gap) + 8px) * -1)
+    #finetune_lora_target_modules_box + #finetune_lora_modules_to_save_box {
+        margin-top: calc((var(--layout-gap) + 8px) * -1);
+        flex-grow: 0 !important;
     }
     #finetune_lora_target_modules_box > .form,
-    #finetune_lora_modules_to_save_box > .form {
+    #finetune_lora_target_modules_box + #finetune_lora_modules_to_save_box > .form {
         padding-top: calc((var(--layout-gap) + 8px) / 3);
         border-top: 0;
         border-top-left-radius: 0;
@@ -747,7 +748,7 @@ def main_page_custom_css():
         position: relative;
     }
     #finetune_lora_target_modules_box > .form::before,
-    #finetune_lora_modules_to_save_box > .form::before {
+    #finetune_lora_target_modules_box + #finetune_lora_modules_to_save_box > .form::before {
         content: "";
         display: block;
         position: absolute;
@@ -802,6 +803,18 @@ def main_page_custom_css():
         padding: 4px 8px;
     }

+    #finetune_advanced_options_checkboxes > * > * {
+        min-width: auto;
+    }
+
+    #finetune_log_and_save_options_group_container {
+        flex-grow: 1 !important;
+        justify-content: flex-end;
+    }
+    #finetune_model_name_group {
+        flex-grow: 0 !important;
+    }
+
     @media screen and (max-width: 392px) {
         #inference_lora_model, #inference_lora_model_group, #finetune_template {
             border-bottom-left-radius: 0;