multimodalart HF staff commited on
Commit
03b43e9
·
1 Parent(s): 02a9af1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -44
app.py CHANGED
@@ -18,16 +18,14 @@ from pathlib import Path
18
  MAX_IMAGES = 50
19
 
20
  training_script_url = "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py"
21
- subprocess.run(['wget', training_script_url])
22
 
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
 
25
  FACES_DATASET_PATH = snapshot_download(repo_id="multimodalart/faces-prior-preservation", repo_type="dataset")
26
-
27
  #Delete .gitattributes to process things properly
28
  Path(FACES_DATASET_PATH, '.gitattributes').unlink(missing_ok=True)
29
 
30
-
31
  processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
32
  model = Blip2ForConditionalGeneration.from_pretrained(
33
  "Salesforce/blip2-opt-2.7b", device_map={"": 0}, torch_dtype=torch.float16
@@ -287,11 +285,22 @@ git+https://github.com/huggingface/datasets.git'''
287
  # The subprocess call for autotrain spacerunner
288
  api = HfApi(token=token)
289
  username = api.whoami()["name"]
290
- subprocess_command = ["autotrain", "spacerunner", "--project-name", slugged_lora_name, "--script-path", spacerunner_folder, "--username", username, "--token", token, "--backend", "spaces-a10gl", "--env","HF_TOKEN=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;HF_HUB_ENABLE_HF_TRANSFER=1", "--args", spacerunner_args]
291
  print(subprocess_command)
292
  subprocess.run(subprocess_command)
293
- return f"<h2>Your training has started. Run over to <a href='https://huggingface.co/spaces/{username}/autotrain-{slugged_lora_name}?logs=container'>{username}/autotrain-{slugged_lora_name}</a> to check the status (click the logs tab)</h2>"
 
 
294
 
 
 
 
 
 
 
 
 
 
295
  def start_training_og(
296
  lora_name,
297
  training_option,
@@ -443,23 +452,41 @@ def run_captioning(*inputs):
443
  def check_token(token):
444
  try:
445
  api = HfApi(token=token)
 
446
  except Exception as e:
447
- gr.Warning("Invalid user token. Make sure to get your Hugging Face")
448
  else:
449
- user_data = api.whoami()
450
- if (username['auth']['accessToken']['role'] != "write"):
451
  gr.Warning("Oops, you've uploaded a `Read` token. You need to use a Write token!")
452
  else:
453
  if user_data['canPay']:
454
  return gr.update(visible=False), gr.update(visible=True)
455
  else:
 
456
  return gr.update(visible=True), gr.update(visible=False)
457
 
458
  return gr.update(visible=False), gr.update(visible=False)
459
 
460
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  dataset_folder = gr.State()
462
- gr.Markdown("# SDXL LoRA Dreambooth Training")
 
 
463
  lora_name = gr.Textbox(label="The name of your LoRA", placeholder="e.g.: Persian Miniature Painting style, Cat Toy")
464
  training_option = gr.Radio(
465
  label="What are you training?", choices=["object", "style", "face", "custom"]
@@ -496,7 +523,7 @@ To improve the quality of your outputs, you can add a custom caption for each im
496
  with locals()[f"captioning_row_{i}"]:
497
  locals()[f"image_{i}"] = gr.Image(
498
  width=64,
499
- height=64,
500
  min_width=64,
501
  interactive=False,
502
  scale=1,
@@ -544,7 +571,6 @@ To improve the quality of your outputs, you can add a custom caption for each im
544
  step=0.0000001,
545
  value=1.0, # For prodigy you start high and it will optimize down
546
  )
547
- train_batch_size = gr.Number(label="Train batch size", value=2)
548
  max_train_steps = gr.Number(
549
  label="Max train steps", minimum=1, maximum=50000, value=1000
550
  )
@@ -589,7 +615,7 @@ To improve the quality of your outputs, you can add a custom caption for each im
589
  train_text_encoder_ti = gr.Checkbox(
590
  label="Do textual inversion",
591
  value=True,
592
- info="Will train a textual inversion embedding together with the LoRA. Increases quality significantly.",
593
  )
594
  with gr.Group(visible=True) as pivotal_tuning_params:
595
  train_text_encoder_ti_frac = gr.Number(
@@ -633,27 +659,48 @@ To improve the quality of your outputs, you can add a custom caption for each im
633
  with gr.Accordion(open=False, label="Even more advanced options"):
634
  with gr.Row():
635
  with gr.Column():
636
- num_train_epochs = gr.Number(label="num_train_epochs", value=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
637
  checkpointing_steps = gr.Number(
638
- label="checkpointing_steps", value=5000
 
 
639
  )
640
- prior_loss_weight = gr.Number(label="prior_loss_weight", value=1)
641
- gradient_accumulation_steps = gr.Number(
642
- label="gradient_accumulation_steps", value=1
643
  )
644
  gradient_checkpointing = gr.Checkbox(
645
  label="gradient_checkpointing",
646
  info="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass",
647
  value=True,
648
  )
649
- enable_xformers_memory_efficient_attention = gr.Checkbox(
650
- label="enable_xformers_memory_efficient_attention"
651
- )
652
  adam_beta1 = gr.Number(
653
- label="adam_beta1", value=0.9, minimum=0, maximum=1, step=0.01
 
 
 
 
654
  )
655
  adam_beta2 = gr.Number(
656
- label="adam_beta2", minimum=0, maximum=1, step=0.01, value=0.99
 
 
 
 
657
  )
658
  prodigy_beta3 = gr.Number(
659
  label="Prodigy Beta 3",
@@ -685,10 +732,12 @@ To improve the quality of your outputs, you can add a custom caption for each im
685
  maximum=1,
686
  )
687
  prodigy_use_bias_correction = gr.Checkbox(
688
- label="Prodigy Use Bias Correction", value=True
 
689
  )
690
  prodigy_safeguard_warmup = gr.Checkbox(
691
- label="Prodigy Safeguard Warmup", value=True
 
692
  )
693
  max_grad_norm = gr.Number(
694
  label="Max Grad Norm",
@@ -697,12 +746,18 @@ To improve the quality of your outputs, you can add a custom caption for each im
697
  maximum=10,
698
  step=0.1,
699
  )
 
 
 
700
  with gr.Column():
701
  scale_lr = gr.Checkbox(
702
  label="Scale learning rate",
703
  info="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size",
704
  )
705
- lr_num_cycles = gr.Number(label="lr_num_cycles", value=1)
 
 
 
706
  lr_scheduler = gr.Dropdown(
707
  label="lr_scheduler",
708
  choices=[
@@ -716,25 +771,32 @@ To improve the quality of your outputs, you can add a custom caption for each im
716
  value="constant",
717
  )
718
  lr_power = gr.Number(
719
- label="lr_power", value=1.0, minimum=0.1, maximum=10
 
 
 
 
 
 
 
720
  )
721
- lr_warmup_steps = gr.Number(label="lr_warmup_steps", value=0)
722
  dataloader_num_workers = gr.Number(
723
  label="Dataloader num workers", value=0, minimum=0, maximum=64
724
  )
725
- local_rank = gr.Number(label="local_rank", value=-1)
726
- with gr.Row(visible=False) as cost_estimation:
727
- with gr.Group():
728
- gr.Markdown('''### This training is estimated to cost <b>< US$ 1,50</b> with your current train settings
729
- Grab a Hugging Face <b>write</b> token [here](https://huggingface.co/settings/tokens)
730
- ''')
731
- token = gr.Textbox(label="Your Hugging Face write token", info="A Hugging Face write token you can obtain on the settings page")
 
732
  with gr.Group(visible=False) as no_payment_method:
733
  with gr.Row():
734
- gr.Markdown("Your Hugging Face account doesn't have a payment method. Set it up [here](https://huggingface.co/settings/billing/payment) to train your LoRA")
735
  payment_setup = gr.Button("I have set up my payment method")
736
- start = gr.Button("Start training", visible=False)
737
- progress_area = gr.HTML("")
738
  output_components.insert(1, advanced)
739
  output_components.insert(1, cost_estimation)
740
 
@@ -745,13 +807,14 @@ Grab a Hugging Face <b>write</b> token [here](https://huggingface.co/settings/to
745
  ],
746
  fn=check_token,
747
  inputs=token,
748
- outputs=[no_payment_method, start]
 
749
  )
750
  use_snr_gamma.change(
751
  lambda x: gr.update(visible=x),
752
  inputs=use_snr_gamma,
753
  outputs=snr_gamma,
754
- queue=False,
755
  )
756
  with_prior_preservation.change(
757
  lambda x: gr.update(visible=x),
@@ -783,26 +846,39 @@ Grab a Hugging Face <b>write</b> token [here](https://huggingface.co/settings/to
783
  queue=False
784
  )
785
  images.upload(
786
- load_captioning, inputs=[images, concept_sentence], outputs=output_components
 
 
 
787
  ).then(
788
  change_defaults,
789
  inputs=[training_option, images],
790
- outputs=[max_train_steps, repeats, lr_scheduler, lora_rank, with_prior_preservation, class_prompt, class_images]
 
791
  )
792
  images.change(
793
  check_removed_and_restart,
794
  inputs=[images],
795
  outputs=[captioning_area, advanced, cost_estimation],
 
796
  )
797
  training_option.change(
798
  make_options_visible,
799
  inputs=training_option,
800
  outputs=[concept_sentence, image_upload],
 
 
 
 
 
 
 
801
  )
802
  start.click(
803
  fn=create_dataset,
804
  inputs=[images] + caption_list,
805
- outputs=dataset_folder
 
806
  ).then(
807
  fn=start_training,
808
  inputs=[
@@ -856,7 +932,8 @@ Grab a Hugging Face <b>write</b> token [here](https://huggingface.co/settings/to
856
  dataset_folder,
857
  token
858
  ],
859
- outputs = progress_area
 
860
  )
861
 
862
  do_captioning.click(
 
18
  MAX_IMAGES = 50
19
 
20
  training_script_url = "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py"
21
+ subprocess.run(['wget', '-N', training_script_url])
22
 
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
 
25
  FACES_DATASET_PATH = snapshot_download(repo_id="multimodalart/faces-prior-preservation", repo_type="dataset")
 
26
  #Delete .gitattributes to process things properly
27
  Path(FACES_DATASET_PATH, '.gitattributes').unlink(missing_ok=True)
28
 
 
29
  processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
30
  model = Blip2ForConditionalGeneration.from_pretrained(
31
  "Salesforce/blip2-opt-2.7b", device_map={"": 0}, torch_dtype=torch.float16
 
285
  # The subprocess call for autotrain spacerunner
286
  api = HfApi(token=token)
287
  username = api.whoami()["name"]
288
+ subprocess_command = ["autotrain", "spacerunner", "--project-name", slugged_lora_name, "--script-path", spacerunner_folder, "--username", username, "--token", token, "--backend", "spaces-a10gs", "--env","HF_TOKEN=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;HF_HUB_ENABLE_HF_TRANSFER=1", "--args", spacerunner_args]
289
  print(subprocess_command)
290
  subprocess.run(subprocess_command)
291
+ return f"""# Your training has started.
292
+ ## - Model page: <a href='https://huggingface.co/{username}/{slugged_lora_name}'>{username}/{slugged_lora_name}</a> <small>(the model will be available when training finishes)</small>
293
+ ## - Training Status: <a href='https://huggingface.co/spaces/{username}/autotrain-{slugged_lora_name}?logs=container'>{username}/autotrain-{slugged_lora_name}</a> <small>(in the logs tab)</small>"""
294
 
295
+ def calculate_price(iterations):
296
+ seconds_per_iteration = 3.50
297
+ total_seconds = (iterations * seconds_per_iteration) + 210
298
+ cost_per_second = 1.05/60/60
299
+ cost = round(cost_per_second * total_seconds, 2)
300
+ return f'''To train this LoRA, we will duplicate the space and hook an A10G GPU under the hood.
301
+ ## Estimated to cost <b>< US$ {str(cost)}</b> with your current train settings <small>({int(iterations)} iterations at 3.50s/it in Spaces A10G at US$1.05/h)</small>
302
+ #### Grab a <b>write</b> token [here](https://huggingface.co/settings/tokens), enter it below ↓'''
303
+
304
  def start_training_og(
305
  lora_name,
306
  training_option,
 
452
  def check_token(token):
453
  try:
454
  api = HfApi(token=token)
455
+ user_data = api.whoami()
456
  except Exception as e:
457
+ raise gr.Warning("Invalid user token. Make sure to get your Hugging Face token from the settings page")
458
  else:
459
+ if (user_data['auth']['accessToken']['role'] != "write"):
 
460
  gr.Warning("Oops, you've uploaded a `Read` token. You need to use a Write token!")
461
  else:
462
  if user_data['canPay']:
463
  return gr.update(visible=False), gr.update(visible=True)
464
  else:
465
+ gr.Warning("Your payment methods aren't set up. You gotta set them up to start training")
466
  return gr.update(visible=True), gr.update(visible=False)
467
 
468
  return gr.update(visible=False), gr.update(visible=False)
469
 
470
+ css = '''.gr-group{background-color: transparent}
471
+ .gr-group .hide-container{padding: 1em; background: var(--block-background-fill) !important}
472
+ .gr-group img{object-fit: cover}
473
+ #main_title{text-align:center}
474
+ #main_title h1 {font-size: 2.25rem}
475
+ #main_title h3, #main_title p{margin-top: 0;font-size: 1.25em}
476
+ #training_cost h2{margin-top: 10px;padding: 0.5em;border: 1px solid var(--block-border-color);font-size: 1.25em}
477
+ #training_cost h4{margin-top: 1.25em;margin-bottom: 0}
478
+ #training_cost small{font-weight: normal}
479
+
480
+ '''
481
+ theme = gr.themes.Monochrome(
482
+ text_size="lg",
483
+ font=[gr.themes.GoogleFont('Source Sans Pro'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
484
+ )
485
+ with gr.Blocks(css=css, theme=theme) as demo:
486
  dataset_folder = gr.State()
487
+ gr.Markdown('''# Dreambooth Ease 🧞‍♂️
488
+ ### Train a high quality Dreambooth SDXL LoRA in a breeze ༄, using state-of-the-art techniques
489
+ <small>[blog about the training script](#), [Colab Pro](#), [run locally or in a cloud](#)</small>''', elem_id="main_title")
490
  lora_name = gr.Textbox(label="The name of your LoRA", placeholder="e.g.: Persian Miniature Painting style, Cat Toy")
491
  training_option = gr.Radio(
492
  label="What are you training?", choices=["object", "style", "face", "custom"]
 
523
  with locals()[f"captioning_row_{i}"]:
524
  locals()[f"image_{i}"] = gr.Image(
525
  width=64,
526
+ height=111,
527
  min_width=64,
528
  interactive=False,
529
  scale=1,
 
571
  step=0.0000001,
572
  value=1.0, # For prodigy you start high and it will optimize down
573
  )
 
574
  max_train_steps = gr.Number(
575
  label="Max train steps", minimum=1, maximum=50000, value=1000
576
  )
 
615
  train_text_encoder_ti = gr.Checkbox(
616
  label="Do textual inversion",
617
  value=True,
618
+ info="Will train a textual inversion embedding together with the LoRA. Increases quality significantly. If untoggled, you can remove the special TOK token from the prompts.",
619
  )
620
  with gr.Group(visible=True) as pivotal_tuning_params:
621
  train_text_encoder_ti_frac = gr.Number(
 
659
  with gr.Accordion(open=False, label="Even more advanced options"):
660
  with gr.Row():
661
  with gr.Column():
662
+ gradient_accumulation_steps = gr.Number(
663
+ info="If you change this setting, the pricing calculation will be wrong",
664
+ label="gradient_accumulation_steps",
665
+ value=1
666
+ )
667
+ train_batch_size = gr.Number(
668
+ info="If you change this setting, the pricing calculation will be wrong",
669
+ label="Train batch size",
670
+ value=2
671
+ )
672
+ num_train_epochs = gr.Number(
673
+ info="If you change this setting, the pricing calculation will be wrong",
674
+ label="num_train_epochs",
675
+ value=1
676
+ )
677
  checkpointing_steps = gr.Number(
678
+ info="How many steps to save intermediate checkpoints",
679
+ label="checkpointing_steps",
680
+ value=5000
681
  )
682
+ prior_loss_weight = gr.Number(
683
+ label="prior_loss_weight",
684
+ value=1
685
  )
686
  gradient_checkpointing = gr.Checkbox(
687
  label="gradient_checkpointing",
688
  info="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass",
689
  value=True,
690
  )
 
 
 
691
  adam_beta1 = gr.Number(
692
+ label="adam_beta1",
693
+ value=0.9,
694
+ minimum=0,
695
+ maximum=1,
696
+ step=0.01
697
  )
698
  adam_beta2 = gr.Number(
699
+ label="adam_beta2",
700
+ minimum=0,
701
+ maximum=1,
702
+ step=0.01,
703
+ value=0.99
704
  )
705
  prodigy_beta3 = gr.Number(
706
  label="Prodigy Beta 3",
 
732
  maximum=1,
733
  )
734
  prodigy_use_bias_correction = gr.Checkbox(
735
+ label="Prodigy Use Bias Correction",
736
+ value=True
737
  )
738
  prodigy_safeguard_warmup = gr.Checkbox(
739
+ label="Prodigy Safeguard Warmup",
740
+ value=True
741
  )
742
  max_grad_norm = gr.Number(
743
  label="Max Grad Norm",
 
746
  maximum=10,
747
  step=0.1,
748
  )
749
+ enable_xformers_memory_efficient_attention = gr.Checkbox(
750
+ label="enable_xformers_memory_efficient_attention"
751
+ )
752
  with gr.Column():
753
  scale_lr = gr.Checkbox(
754
  label="Scale learning rate",
755
  info="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size",
756
  )
757
+ lr_num_cycles = gr.Number(
758
+ label="lr_num_cycles",
759
+ value=1
760
+ )
761
  lr_scheduler = gr.Dropdown(
762
  label="lr_scheduler",
763
  choices=[
 
771
  value="constant",
772
  )
773
  lr_power = gr.Number(
774
+ label="lr_power",
775
+ value=1.0,
776
+ minimum=0.1,
777
+ maximum=10
778
+ )
779
+ lr_warmup_steps = gr.Number(
780
+ label="lr_warmup_steps",
781
+ value=0
782
  )
 
783
  dataloader_num_workers = gr.Number(
784
  label="Dataloader num workers", value=0, minimum=0, maximum=64
785
  )
786
+ local_rank = gr.Number(
787
+ label="local_rank",
788
+ value=-1
789
+ )
790
+ with gr.Column(visible=False) as cost_estimation:
791
+ with gr.Group(elem_id="cost_box"):
792
+ training_cost_estimate = gr.Markdown(elem_id="training_cost")
793
+ token = gr.Textbox(label="Your Hugging Face write token", info="A Hugging Face write token you can obtain on the settings page", type="password", placeholder="hf_OhHiThIsIsNoTaReALToKeNGOoDTry")
794
  with gr.Group(visible=False) as no_payment_method:
795
  with gr.Row():
796
+ gr.Markdown("## Your Hugging Face account doesn't have a payment method. Set it up [here](https://huggingface.co/settings/billing/payment) to train your LoRA")
797
  payment_setup = gr.Button("I have set up my payment method")
798
+ start = gr.Button("Start training", visible=False, interactive=True)
799
+ progress_area = gr.Markdown("")
800
  output_components.insert(1, advanced)
801
  output_components.insert(1, cost_estimation)
802
 
 
807
  ],
808
  fn=check_token,
809
  inputs=token,
810
+ outputs=[no_payment_method, start],
811
+ queue=False
812
  )
813
  use_snr_gamma.change(
814
  lambda x: gr.update(visible=x),
815
  inputs=use_snr_gamma,
816
  outputs=snr_gamma,
817
+ queue=False
818
  )
819
  with_prior_preservation.change(
820
  lambda x: gr.update(visible=x),
 
846
  queue=False
847
  )
848
  images.upload(
849
+ load_captioning,
850
+ inputs=[images, concept_sentence],
851
+ outputs=output_components,
852
+ queue=False
853
  ).then(
854
  change_defaults,
855
  inputs=[training_option, images],
856
+ outputs=[max_train_steps, repeats, lr_scheduler, lora_rank, with_prior_preservation, class_prompt, class_images],
857
+ queue=False
858
  )
859
  images.change(
860
  check_removed_and_restart,
861
  inputs=[images],
862
  outputs=[captioning_area, advanced, cost_estimation],
863
+ queue=False
864
  )
865
  training_option.change(
866
  make_options_visible,
867
  inputs=training_option,
868
  outputs=[concept_sentence, image_upload],
869
+ queue=False
870
+ )
871
+ max_train_steps.change(
872
+ calculate_price,
873
+ inputs=[max_train_steps],
874
+ outputs=[training_cost_estimate],
875
+ queue=False
876
  )
877
  start.click(
878
  fn=create_dataset,
879
  inputs=[images] + caption_list,
880
+ outputs=dataset_folder,
881
+ queue=False
882
  ).then(
883
  fn=start_training,
884
  inputs=[
 
932
  dataset_folder,
933
  token
934
  ],
935
+ outputs = progress_area,
936
+ queue=False
937
  )
938
 
939
  do_captioning.click(