multimodalart HF staff commited on
Commit
09e5977
1 Parent(s): 6deb349

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -15
app.py CHANGED
@@ -160,6 +160,7 @@ def recursive_update(d, u):
160
  def start_training(
161
  lora_name,
162
  concept_sentence,
 
163
  steps,
164
  lr,
165
  rank,
@@ -224,7 +225,12 @@ def start_training(
224
  config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
225
  else:
226
  config["config"]["process"][0]["train"]["disable_sampling"] = True
227
-
 
 
 
 
 
228
  if(use_more_advanced_options):
229
  more_advanced_options_dict = yaml.safe_load(more_advanced_options)
230
  config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
@@ -291,11 +297,13 @@ def update_pricing(steps, oauth_token: Union[gr.OAuthToken, None]):
291
  else:
292
  return gr.update(visible=False), "", gr.update(visible=False), gr.update(visible=True)
293
 
 
 
 
294
  config_yaml = '''
295
  device: cuda:0
296
  model:
297
  is_flux: true
298
- name_or_path: black-forest-labs/FLUX.1-dev
299
  quantize: true
300
  network:
301
  linear: 16 #it will overcome the 'rank' parameter
@@ -342,6 +350,7 @@ h3{margin-top: 0}
342
  .main_ui_logged_out{opacity: 0.3; pointer-events: none}
343
  .tabitem{border: 0px}
344
  .group_padding{padding: .55em}
 
345
  """
346
  with gr.Blocks(theme=theme, css=css) as demo:
347
  gr.Markdown(
@@ -352,18 +361,22 @@ with gr.Blocks(theme=theme, css=css) as demo:
352
  gr.LoginButton("Sign in with Hugging Face to train your LoRA on Spaces", visible=is_spaces)
353
  with gr.Tab("Train on Spaces" if is_spaces else "Train locally"):
354
  with gr.Column() as main_ui:
355
- with gr.Row():
356
- lora_name = gr.Textbox(
357
- label="The name of your LoRA",
358
- info="This has to be a unique name",
359
- placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
360
- )
361
- concept_sentence = gr.Textbox(
362
- label="Trigger word/sentence",
363
- info="Trigger word or sentence to be used",
364
- placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
365
- interactive=True,
366
- )
 
 
 
 
367
  with gr.Group(visible=True) as image_upload:
368
  with gr.Row():
369
  images = gr.File(
@@ -503,12 +516,18 @@ with gr.Blocks(theme=theme, css=css) as demo:
503
  inputs=[steps],
504
  outputs=[cost_preview, cost_preview_info, payment_update, start]
505
  )
506
-
 
 
 
 
 
507
  start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder).then(
508
  fn=start_training,
509
  inputs=[
510
  lora_name,
511
  concept_sentence,
 
512
  steps,
513
  lr,
514
  rank,
 
160
  def start_training(
161
  lora_name,
162
  concept_sentence,
163
+ which_model,
164
  steps,
165
  lr,
166
  rank,
 
225
  config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
226
  else:
227
  config["config"]["process"][0]["train"]["disable_sampling"] = True
228
+
229
+ if(which_model == "[schnell] (4 step fast model)"):
230
+ config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
231
+ config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
232
+ config["config"]["process"][0]["sample"]["sample_steps"] = 4
233
+
234
  if(use_more_advanced_options):
235
  more_advanced_options_dict = yaml.safe_load(more_advanced_options)
236
  config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
 
297
  else:
298
  return gr.update(visible=False), "", gr.update(visible=False), gr.update(visible=True)
299
 
300
+ def swap_base_model(model):
301
+ return gr.update(visible=True) if model == "[dev] (high quality model, non-commercial license)" else gr.update(visible=False)
302
+
303
  config_yaml = '''
304
  device: cuda:0
305
  model:
306
  is_flux: true
 
307
  quantize: true
308
  network:
309
  linear: 16 #it will overcome the 'rank' parameter
 
350
  .main_ui_logged_out{opacity: 0.3; pointer-events: none}
351
  .tabitem{border: 0px}
352
  .group_padding{padding: .55em}
353
+ #space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
354
  """
355
  with gr.Blocks(theme=theme, css=css) as demo:
356
  gr.Markdown(
 
361
  gr.LoginButton("Sign in with Hugging Face to train your LoRA on Spaces", visible=is_spaces)
362
  with gr.Tab("Train on Spaces" if is_spaces else "Train locally"):
363
  with gr.Column() as main_ui:
364
+ with gr.Group():
365
+ with gr.Row():
366
+ lora_name = gr.Textbox(
367
+ label="The name of your LoRA",
368
+ info="This has to be a unique name",
369
+ placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
370
+ )
371
+ concept_sentence = gr.Textbox(
372
+ label="Trigger word/sentence",
373
+ info="Trigger word or sentence to be used",
374
+ placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
375
+ interactive=True,
376
+ )
377
+ which_model = gr.Radio(["[schnell] (4 step fast model)", "[dev] (high quality model, non-commercial license - available when training locally)"], label="Which base model to train?", elem_id="space_model" if is_spaces else "local_model", value="[schnell] (4 step fast model)",)
378
+ model_warning = gr.Markdown("""> [dev] model license is non-commercial. By choosing to fine-tune [dev], you must agree with [its license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) and make sure the LoRA you will train and the training process you would start does not violate it.
379
+ """, visible=False)
380
  with gr.Group(visible=True) as image_upload:
381
  with gr.Row():
382
  images = gr.File(
 
516
  inputs=[steps],
517
  outputs=[cost_preview, cost_preview_info, payment_update, start]
518
  )
519
+
520
+ which_model.change(
521
+ fn=swap_base_model,
522
+ inputs=which_model,
523
+ outputs=model_warning
524
+ )
525
  start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder).then(
526
  fn=start_training,
527
  inputs=[
528
  lora_name,
529
  concept_sentence,
530
+ which_model,
531
  steps,
532
  lr,
533
  rank,