multimodalart HF staff commited on
Commit
2406cac
·
verified ·
1 Parent(s): 09e5977

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -3,6 +3,7 @@ import subprocess
3
  from typing import Union
4
  from huggingface_hub import whoami
5
  is_spaces = True if os.environ.get("SPACE_ID") else False
 
6
 
7
  if is_spaces:
8
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
@@ -226,7 +227,7 @@ def start_training(
226
  else:
227
  config["config"]["process"][0]["train"]["disable_sampling"] = True
228
 
229
- if(which_model == "[schnell] (4 step fast model)"):
230
  config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
231
  config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
232
  config["config"]["process"][0]["sample"]["sample_steps"] = 4
@@ -374,7 +375,13 @@ with gr.Blocks(theme=theme, css=css) as demo:
374
  placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
375
  interactive=True,
376
  )
377
- which_model = gr.Radio(["[schnell] (4 step fast model)", "[dev] (high quality model, non-commercial license - available when training locally)"], label="Which base model to train?", elem_id="space_model" if is_spaces else "local_model", value="[schnell] (4 step fast model)",)
 
 
 
 
 
 
378
  model_warning = gr.Markdown("""> [dev] model license is non-commercial. By choosing to fine-tune [dev], you must agree with [its license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) and make sure the LoRA you will train and the training process you would start does not violate it.
379
  """, visible=False)
380
  with gr.Group(visible=True) as image_upload:
 
3
  from typing import Union
4
  from huggingface_hub import whoami
5
  is_spaces = True if os.environ.get("SPACE_ID") else False
6
+ is_canonical = True if os.environ.get("SPACE_ID") == "autotrain-projects/train-flux-lora-ease" else False
7
 
8
  if is_spaces:
9
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
227
  else:
228
  config["config"]["process"][0]["train"]["disable_sampling"] = True
229
 
230
+ if(which_model == "[schnell]"):
231
  config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
232
  config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
233
  config["config"]["process"][0]["sample"]["sample_steps"] = 4
 
375
  placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
376
  interactive=True,
377
  )
378
+ which_model = gr.Radio(
379
+ [("[schnell] (4 step fast model)", "[schnell]"),
380
+ ("[dev] (high quality model, non-commercial license - available if you duplicate this space or locally)" if is_canonical else "[dev] (high quality model, non-commercial license)", "[dev]")],
381
+ label="Which base model to train?",
382
+ elem_id="space_model" if is_canonical else "local_model",
383
+ value="[schnell]" if is_canonical else "[dev]"
384
+ )
385
  model_warning = gr.Markdown("""> [dev] model license is non-commercial. By choosing to fine-tune [dev], you must agree with [its license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) and make sure the LoRA you will train and the training process you would start does not violate it.
386
  """, visible=False)
387
  with gr.Group(visible=True) as image_upload: