zetavg commited on
Commit
e652ee3
·
unverified ·
1 Parent(s): 8e2e7b5

finetune: support adding additional_training_arguments and additional_lora_config

Browse files
llama_lora/lib/finetune.py CHANGED
@@ -57,6 +57,9 @@ def train(
57
  save_steps: int = 200,
58
  save_total_limit: int = 3,
59
  logging_steps: int = 10,
 
 
 
60
  # logging
61
  callbacks: List[Any] = [],
62
  # wandb params
@@ -70,6 +73,27 @@ def train(
70
  ):
71
  if lora_modules_to_save is not None and len(lora_modules_to_save) <= 0:
72
  lora_modules_to_save = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  # for logging
74
  finetune_args = {
75
  'micro_batch_size': micro_batch_size,
@@ -92,6 +116,8 @@ def train(
92
  'save_steps': save_steps,
93
  'save_total_limit': save_total_limit,
94
  'logging_steps': logging_steps,
 
 
95
  }
96
  if val_set_size and val_set_size > 0:
97
  finetune_args['val_set_size'] = val_set_size
@@ -243,6 +269,7 @@ def train(
243
  lora_dropout=lora_dropout,
244
  bias="none",
245
  task_type="CAUSAL_LM",
 
246
  )
247
  model = get_peft_model(model, config)
248
  if bf16:
@@ -324,6 +351,7 @@ def train(
324
  group_by_length=group_by_length,
325
  report_to="wandb" if use_wandb else None,
326
  run_name=wandb_run_name if use_wandb else None,
 
327
  ),
328
  data_collator=transformers.DataCollatorForSeq2Seq(
329
  tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
 
57
  save_steps: int = 200,
58
  save_total_limit: int = 3,
59
  logging_steps: int = 10,
60
+ #
61
+ additional_training_arguments: Union[dict, str, None] = None,
62
+ additional_lora_config: Union[dict, str, None] = None,
63
  # logging
64
  callbacks: List[Any] = [],
65
  # wandb params
 
73
  ):
74
  if lora_modules_to_save is not None and len(lora_modules_to_save) <= 0:
75
  lora_modules_to_save = None
76
+
77
+ if isinstance(additional_training_arguments, str):
78
+ additional_training_arguments = additional_training_arguments.strip()
79
+ if not additional_training_arguments:
80
+ additional_training_arguments = None
81
+ if isinstance(additional_training_arguments, str):
82
+ try:
83
+ additional_training_arguments = json.loads(additional_training_arguments)
84
+ except Exception as e:
85
+ raise ValueError(f"Could not parse additional_training_arguments: {e}")
86
+
87
+ if isinstance(additional_lora_config, str):
88
+ additional_lora_config = additional_lora_config.strip()
89
+ if not additional_lora_config:
90
+ additional_lora_config = None
91
+ if isinstance(additional_lora_config, str):
92
+ try:
93
+ additional_lora_config = json.loads(additional_lora_config)
94
+ except Exception as e:
95
+ raise ValueError(f"Could not parse additional_training_arguments: {e}")
96
+
97
  # for logging
98
  finetune_args = {
99
  'micro_batch_size': micro_batch_size,
 
116
  'save_steps': save_steps,
117
  'save_total_limit': save_total_limit,
118
  'logging_steps': logging_steps,
119
+ 'additional_training_arguments': additional_training_arguments,
120
+ 'additional_lora_config': additional_lora_config,
121
  }
122
  if val_set_size and val_set_size > 0:
123
  finetune_args['val_set_size'] = val_set_size
 
269
  lora_dropout=lora_dropout,
270
  bias="none",
271
  task_type="CAUSAL_LM",
272
+ **additional_lora_config,
273
  )
274
  model = get_peft_model(model, config)
275
  if bf16:
 
351
  group_by_length=group_by_length,
352
  report_to="wandb" if use_wandb else None,
353
  run_name=wandb_run_name if use_wandb else None,
354
+ **additional_training_arguments
355
  ),
356
  data_collator=transformers.DataCollatorForSeq2Seq(
357
  tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
llama_lora/ui/finetune_ui.py CHANGED
@@ -305,6 +305,8 @@ def do_train(
305
  save_steps,
306
  save_total_limit,
307
  logging_steps,
 
 
308
  model_name,
309
  continue_from_model,
310
  continue_from_checkpoint,
@@ -566,6 +568,8 @@ Train data (first 10):
566
  save_steps=save_steps,
567
  save_total_limit=save_total_limit,
568
  logging_steps=logging_steps,
 
 
569
  callbacks=training_callbacks,
570
  wandb_api_key=Global.wandb_api_key,
571
  wandb_project=Global.default_wandb_project if Global.enable_wandb else None,
@@ -632,6 +636,8 @@ def handle_load_params_from_model(
632
  save_steps,
633
  save_total_limit,
634
  logging_steps,
 
 
635
  lora_target_module_choices,
636
  lora_modules_to_save_choices,
637
  ):
@@ -706,6 +712,16 @@ def handle_load_params_from_model(
706
  save_total_limit = value
707
  elif key == "logging_steps":
708
  logging_steps = value
 
 
 
 
 
 
 
 
 
 
709
  elif key == "group_by_length":
710
  pass
711
  elif key == "resume_from_checkpoint":
@@ -748,6 +764,8 @@ def handle_load_params_from_model(
748
  save_steps,
749
  save_total_limit,
750
  logging_steps,
 
 
751
  lora_target_module_choices,
752
  lora_modules_to_save_choices
753
  )
@@ -946,13 +964,14 @@ def finetune_ui():
946
  info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
947
  )
948
 
949
- with gr.Column():
950
  evaluate_data_count = gr.Slider(
951
  minimum=0, maximum=1, step=1, value=0,
952
  label="Evaluation Data Count",
953
  info="The number of data to be used for evaluation. This specific amount of data will be randomly chosen from the training dataset for evaluating the model's performance during the process, without contributing to the actual training.",
954
  elem_id="finetune_evaluate_data_count"
955
  )
 
956
 
957
  with gr.Box(elem_id="finetune_continue_from_model_box"):
958
  with gr.Row():
@@ -996,6 +1015,18 @@ def finetune_ui():
996
  bf16 = gr.Checkbox(label="BF16", value=False)
997
  gradient_checkpointing = gr.Checkbox(
998
  label="gradient_checkpointing", value=False)
 
 
 
 
 
 
 
 
 
 
 
 
999
 
1000
  with gr.Column():
1001
  lora_r = gr.Slider(
@@ -1077,8 +1108,20 @@ def finetune_ui():
1077
  lora_modules_to_save_add, lora_modules_to_save],
1078
  ))
1079
 
1080
- # with gr.Column():
1081
- # pass
 
 
 
 
 
 
 
 
 
 
 
 
1082
 
1083
  with gr.Column(elem_id="finetune_log_and_save_options_group_container"):
1084
  with gr.Row(elem_id="finetune_log_and_save_options_group"):
@@ -1177,6 +1220,8 @@ def finetune_ui():
1177
  save_steps,
1178
  save_total_limit,
1179
  logging_steps,
 
 
1180
  ]
1181
 
1182
  things_that_might_timeout.append(
 
305
  save_steps,
306
  save_total_limit,
307
  logging_steps,
308
+ additional_training_arguments,
309
+ additional_lora_config,
310
  model_name,
311
  continue_from_model,
312
  continue_from_checkpoint,
 
568
  save_steps=save_steps,
569
  save_total_limit=save_total_limit,
570
  logging_steps=logging_steps,
571
+ additional_training_arguments=additional_training_arguments,
572
+ additional_lora_config=additional_lora_config,
573
  callbacks=training_callbacks,
574
  wandb_api_key=Global.wandb_api_key,
575
  wandb_project=Global.default_wandb_project if Global.enable_wandb else None,
 
636
  save_steps,
637
  save_total_limit,
638
  logging_steps,
639
+ additional_training_arguments,
640
+ additional_lora_config,
641
  lora_target_module_choices,
642
  lora_modules_to_save_choices,
643
  ):
 
712
  save_total_limit = value
713
  elif key == "logging_steps":
714
  logging_steps = value
715
+ elif key == "additional_training_arguments":
716
+ if value:
717
+ additional_training_arguments = json.dumps(value, indent=2)
718
+ else:
719
+ additional_training_arguments = ""
720
+ elif key == "additional_lora_config":
721
+ if value:
722
+ additional_lora_config = json.dumps(value, indent=2)
723
+ else:
724
+ additional_lora_config = ""
725
  elif key == "group_by_length":
726
  pass
727
  elif key == "resume_from_checkpoint":
 
764
  save_steps,
765
  save_total_limit,
766
  logging_steps,
767
+ additional_training_arguments,
768
+ additional_lora_config,
769
  lora_target_module_choices,
770
  lora_modules_to_save_choices
771
  )
 
964
  info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
965
  )
966
 
967
+ with gr.Column(elem_id="finetune_eval_data_group"):
968
  evaluate_data_count = gr.Slider(
969
  minimum=0, maximum=1, step=1, value=0,
970
  label="Evaluation Data Count",
971
  info="The number of data to be used for evaluation. This specific amount of data will be randomly chosen from the training dataset for evaluating the model's performance during the process, without contributing to the actual training.",
972
  elem_id="finetune_evaluate_data_count"
973
  )
974
+ gr.HTML(elem_classes="flex_vertical_grow_area")
975
 
976
  with gr.Box(elem_id="finetune_continue_from_model_box"):
977
  with gr.Row():
 
1015
  bf16 = gr.Checkbox(label="BF16", value=False)
1016
  gradient_checkpointing = gr.Checkbox(
1017
  label="gradient_checkpointing", value=False)
1018
+ with gr.Column(variant="panel", elem_id="finetune_additional_training_arguments_box"):
1019
+ gr.Textbox(
1020
+ label="Additional Training Arguments",
1021
+ info="Additional training arguments to be passed to the Trainer in JSON format. Note that this can override ALL other arguments set elsewhere. See https://bit.ly/hf20-transformers-training-arguments for more details.",
1022
+ elem_id="finetune_additional_training_arguments_textbox_for_label_display"
1023
+ )
1024
+ additional_training_arguments = gr.Code(
1025
+ show_label=False,
1026
+ language="json",
1027
+ value="",
1028
+ # lines=2,
1029
+ elem_id="finetune_additional_training_arguments")
1030
 
1031
  with gr.Column():
1032
  lora_r = gr.Slider(
 
1108
  lora_modules_to_save_add, lora_modules_to_save],
1109
  ))
1110
 
1111
+ with gr.Column(variant="panel", elem_id="finetune_additional_lora_config_box"):
1112
+ gr.Textbox(
1113
+ label="Additional LoRA Config",
1114
+ info="Additional LoraConfig in JSON format. Note that this can override ALL other arguments set elsewhere.",
1115
+ elem_id="finetune_additional_lora_config_textbox_for_label_display"
1116
+ )
1117
+ additional_lora_config = gr.Code(
1118
+ show_label=False,
1119
+ language="json",
1120
+ value="",
1121
+ # lines=2,
1122
+ elem_id="finetune_additional_lora_config")
1123
+
1124
+ gr.HTML(elem_classes="flex_vertical_grow_area no_limit")
1125
 
1126
  with gr.Column(elem_id="finetune_log_and_save_options_group_container"):
1127
  with gr.Row(elem_id="finetune_log_and_save_options_group"):
 
1220
  save_steps,
1221
  save_total_limit,
1222
  logging_steps,
1223
+ additional_training_arguments,
1224
+ additional_lora_config,
1225
  ]
1226
 
1227
  things_that_might_timeout.append(
llama_lora/ui/main_page.py CHANGED
@@ -250,6 +250,15 @@ def main_page_custom_css():
250
  display: none;
251
  }
252
 
 
 
 
 
 
 
 
 
 
253
  #page_title {
254
  flex-grow: 3;
255
  }
@@ -808,13 +817,32 @@ def main_page_custom_css():
808
  }
809
 
810
  #finetune_log_and_save_options_group_container {
811
- flex-grow: 1 !important;
812
- justify-content: flex-end;
813
  }
814
  #finetune_model_name_group {
815
  flex-grow: 0 !important;
816
  }
817
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
818
  @media screen and (max-width: 392px) {
819
  #inference_lora_model, #inference_lora_model_group, #finetune_template {
820
  border-bottom-left-radius: 0;
 
250
  display: none;
251
  }
252
 
253
+ .flex_vertical_grow_area {
254
+ margin-top: calc(var(--layout-gap) * -1) !important;
255
+ flex-grow: 1 !important;
256
+ max-height: calc(var(--layout-gap) * 2);
257
+ }
258
+ .flex_vertical_grow_area.no_limit {
259
+ max-height: unset;
260
+ }
261
+
262
  #page_title {
263
  flex-grow: 3;
264
  }
 
817
  }
818
 
819
  #finetune_log_and_save_options_group_container {
820
+ flex-grow: 0 !important;
 
821
  }
822
  #finetune_model_name_group {
823
  flex-grow: 0 !important;
824
  }
825
 
826
+ #finetune_eval_data_group {
827
+ flex-grow: 0 !important;
828
+ }
829
+
830
+ #finetune_additional_training_arguments_box > .form,
831
+ #finetune_additional_lora_config_box > .form {
832
+ border: 0;
833
+ background: transparent;
834
+ }
835
+ #finetune_additional_training_arguments_textbox_for_label_display,
836
+ #finetune_additional_lora_config_textbox_for_label_display {
837
+ padding: 0;
838
+ margin-bottom: -10px;
839
+ background: transparent;
840
+ }
841
+ #finetune_additional_training_arguments_textbox_for_label_display textarea,
842
+ #finetune_additional_lora_config_textbox_for_label_display textarea {
843
+ display: none;
844
+ }
845
+
846
  @media screen and (max-width: 392px) {
847
  #inference_lora_model, #inference_lora_model_group, #finetune_template {
848
  border-bottom-left-radius: 0;