zetavg commited on
Commit
92a8f77
β€’
1 Parent(s): a1c44f4

finetune: support continue from models on HF hub

Browse files
Files changed (1) hide show
  1. llama_lora/ui/finetune_ui.py +36 -11
llama_lora/ui/finetune_ui.py CHANGED
@@ -9,6 +9,7 @@ import math
9
  from random_word import RandomWords
10
 
11
  from transformers import TrainerCallback
 
12
 
13
  from ..globals import Global
14
  from ..models import (
@@ -313,28 +314,49 @@ def do_train(
313
  base_model_name = Global.base_model_name
314
  tokenizer_name = Global.tokenizer_name or Global.base_model_name
315
 
316
- resume_from_checkpoint = None
317
  if continue_from_model == "-" or continue_from_model == "None":
318
  continue_from_model = None
319
  if continue_from_checkpoint == "-" or continue_from_checkpoint == "None":
320
  continue_from_checkpoint = None
321
  if continue_from_model:
322
- resume_from_checkpoint = os.path.join(
323
  Global.data_dir, "lora_models", continue_from_model)
 
324
  if continue_from_checkpoint:
325
- resume_from_checkpoint = os.path.join(
326
- resume_from_checkpoint, continue_from_checkpoint)
327
  will_be_resume_from_checkpoint_file = os.path.join(
328
- resume_from_checkpoint, "pytorch_model.bin")
329
  if not os.path.exists(will_be_resume_from_checkpoint_file):
330
  raise ValueError(
331
  f"Unable to resume from checkpoint {continue_from_model}/{continue_from_checkpoint}. Resuming is only possible from checkpoints stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
332
  else:
333
  will_be_resume_from_checkpoint_file = os.path.join(
334
- resume_from_checkpoint, "adapter_model.bin")
335
  if not os.path.exists(will_be_resume_from_checkpoint_file):
336
- raise ValueError(
337
- f"Unable to continue from model {continue_from_model}. Continuation is only possible from models stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
  output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
340
  if os.path.exists(output_dir):
@@ -400,6 +422,7 @@ Train options: {json.dumps({
400
  'model_name': model_name,
401
  'continue_from_model': continue_from_model,
402
  'continue_from_checkpoint': continue_from_checkpoint,
 
403
  }, indent=2)}
404
 
405
  Train data (first 10):
@@ -539,7 +562,7 @@ Train data (first 10):
539
  bf16=bf16,
540
  gradient_checkpointing=gradient_checkpointing,
541
  group_by_length=False,
542
- resume_from_checkpoint=resume_from_checkpoint,
543
  save_steps=save_steps,
544
  save_total_limit=save_total_limit,
545
  logging_steps=logging_steps,
@@ -937,6 +960,7 @@ def finetune_ui():
937
  value="-",
938
  label="Continue from Model",
939
  choices=["-"],
 
940
  elem_id="finetune_continue_from_model"
941
  )
942
  continue_from_checkpoint = gr.Dropdown(
@@ -970,7 +994,8 @@ def finetune_ui():
970
  load_in_8bit = gr.Checkbox(label="8bit", value=False)
971
  fp16 = gr.Checkbox(label="FP16", value=True)
972
  bf16 = gr.Checkbox(label="BF16", value=False)
973
- gradient_checkpointing = gr.Checkbox(label="gradient_checkpointing", value=False)
 
974
 
975
  with gr.Column():
976
  lora_r = gr.Slider(
@@ -1310,7 +1335,7 @@ def finetune_ui():
1310
  delay: [500, 0],
1311
  animation: 'scale-subtle',
1312
  content:
1313
- 'Select a LoRA model to train a new model on top of that model.<br /><br />πŸ’‘ To use the same training parameters of a previously trained model, select it here and click the <code>Load training parameters from selected model</code> button, then un-select it.',
1314
  allowHTML: true,
1315
  });
1316
 
 
9
  from random_word import RandomWords
10
 
11
  from transformers import TrainerCallback
12
+ from huggingface_hub import try_to_load_from_cache, snapshot_download
13
 
14
  from ..globals import Global
15
  from ..models import (
 
314
  base_model_name = Global.base_model_name
315
  tokenizer_name = Global.tokenizer_name or Global.base_model_name
316
 
317
+ resume_from_checkpoint_param = None
318
  if continue_from_model == "-" or continue_from_model == "None":
319
  continue_from_model = None
320
  if continue_from_checkpoint == "-" or continue_from_checkpoint == "None":
321
  continue_from_checkpoint = None
322
  if continue_from_model:
323
+ resume_from_model_path = os.path.join(
324
  Global.data_dir, "lora_models", continue_from_model)
325
+ resume_from_checkpoint_param = resume_from_model_path
326
  if continue_from_checkpoint:
327
+ resume_from_checkpoint_param = os.path.join(
328
+ resume_from_checkpoint_param, continue_from_checkpoint)
329
  will_be_resume_from_checkpoint_file = os.path.join(
330
+ resume_from_checkpoint_param, "pytorch_model.bin")
331
  if not os.path.exists(will_be_resume_from_checkpoint_file):
332
  raise ValueError(
333
  f"Unable to resume from checkpoint {continue_from_model}/{continue_from_checkpoint}. Resuming is only possible from checkpoints stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
334
  else:
335
  will_be_resume_from_checkpoint_file = os.path.join(
336
+ resume_from_checkpoint_param, "adapter_model.bin")
337
  if not os.path.exists(will_be_resume_from_checkpoint_file):
338
+ # Try to get model in Hugging Face cache
339
+ resume_from_checkpoint_param = None
340
+ possible_hf_model_name = None
341
+ possible_model_info_file = os.path.join(
342
+ resume_from_model_path, "info.json")
343
+ if "/" in continue_from_model:
344
+ possible_hf_model_name = continue_from_model
345
+ elif os.path.exists(possible_model_info_file):
346
+ with open(possible_model_info_file, "r") as file:
347
+ model_info = json.load(file)
348
+ possible_hf_model_name = model_info.get("hf_model_name")
349
+ if possible_hf_model_name:
350
+ possible_hf_model_cached_path = try_to_load_from_cache(possible_hf_model_name, 'adapter_model.bin')
351
+ if not possible_hf_model_cached_path:
352
+ snapshot_download(possible_hf_model_name)
353
+ possible_hf_model_cached_path = try_to_load_from_cache(possible_hf_model_name, 'adapter_model.bin')
354
+ if possible_hf_model_cached_path:
355
+ resume_from_checkpoint_param = os.path.dirname(possible_hf_model_cached_path)
356
+
357
+ if not resume_from_checkpoint_param:
358
+ raise ValueError(
359
+ f"Unable to continue from model {continue_from_model}. Continuation is only possible from models stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
360
 
361
  output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
362
  if os.path.exists(output_dir):
 
422
  'model_name': model_name,
423
  'continue_from_model': continue_from_model,
424
  'continue_from_checkpoint': continue_from_checkpoint,
425
+ 'resume_from_checkpoint_param': resume_from_checkpoint_param,
426
  }, indent=2)}
427
 
428
  Train data (first 10):
 
562
  bf16=bf16,
563
  gradient_checkpointing=gradient_checkpointing,
564
  group_by_length=False,
565
+ resume_from_checkpoint=resume_from_checkpoint_param,
566
  save_steps=save_steps,
567
  save_total_limit=save_total_limit,
568
  logging_steps=logging_steps,
 
960
  value="-",
961
  label="Continue from Model",
962
  choices=["-"],
963
+ allow_custom_value=True,
964
  elem_id="finetune_continue_from_model"
965
  )
966
  continue_from_checkpoint = gr.Dropdown(
 
994
  load_in_8bit = gr.Checkbox(label="8bit", value=False)
995
  fp16 = gr.Checkbox(label="FP16", value=True)
996
  bf16 = gr.Checkbox(label="BF16", value=False)
997
+ gradient_checkpointing = gr.Checkbox(
998
+ label="gradient_checkpointing", value=False)
999
 
1000
  with gr.Column():
1001
  lora_r = gr.Slider(
 
1335
  delay: [500, 0],
1336
  animation: 'scale-subtle',
1337
  content:
1338
+ 'Select a LoRA model to train a new model on top of that model. You can also type in a model name on Hugging Face Hub, such as <code>tloen/alpaca-lora-7b</code>.<br /><br />πŸ’‘ To reload the training parameters of one of your previously trained models, select it here and click the <code>Load training parameters from selected model</code> button, then un-select it.',
1339
  allowHTML: true,
1340
  });
1341