ruslanmv committed on
Commit
bd77ee2
·
verified ·
1 Parent(s): b3fce51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -55
app.py CHANGED
@@ -31,6 +31,9 @@ from diffusers.utils import load_image
31
 
32
  import spaces
33
 
 
 
 
34
  # Attempt to import loras from lora.py; otherwise use a default placeholder.
35
  try:
36
  from lora import loras
@@ -205,15 +208,16 @@ base_model = "black-forest-labs/FLUX.1-dev"
205
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
206
  good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
207
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
208
- pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model,
209
- vae=good_vae,
210
- transformer=pipe.transformer,
211
- text_encoder=pipe.text_encoder,
212
- tokenizer=pipe.tokenizer,
213
- text_encoder_2=pipe.text_encoder_2,
214
- tokenizer_2=pipe.tokenizer_2,
215
- torch_dtype=dtype
216
- )
 
217
  MAX_SEED = 2**32-1
218
 
219
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
@@ -292,23 +296,32 @@ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps
292
  return final_image
293
 
294
  @spaces.GPU(duration=100)
295
- def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
 
296
  if selected_index is None:
297
  raise gr.Error("You must select a LoRA before proceeding.🧨")
298
  selected_lora = loras[selected_index]
299
  lora_path = selected_lora["repo"]
300
  trigger_word = selected_lora["trigger_word"]
301
- if(trigger_word):
302
- if "trigger_position" in selected_lora:
303
- if selected_lora["trigger_position"] == "prepend":
304
- prompt_mash = f"{trigger_word} {prompt}"
305
- else:
306
- prompt_mash = f"{prompt} {trigger_word}"
307
- else:
308
  prompt_mash = f"{trigger_word} {prompt}"
 
 
309
  else:
310
  prompt_mash = prompt
311
 
 
 
 
 
 
 
 
 
 
 
312
  with calculateDuration("Unloading LoRA"):
313
  pipe.unload_lora_weights()
314
  pipe_i2i.unload_lora_weights()
@@ -326,9 +339,9 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
326
  if randomize_seed:
327
  seed = random.randint(0, MAX_SEED)
328
 
329
- if(image_input is not None):
330
  final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
331
- yield final_image, seed, gr.update(visible=False)
332
  else:
333
  image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
334
  final_image = None
@@ -337,16 +350,16 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
337
  step_counter += 1
338
  final_image = image
339
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
340
- yield image, seed, gr.update(value=progress_bar, visible=True)
341
- yield final_image, seed, gr.update(value=progress_bar, visible=False)
342
 
343
  def get_huggingface_safetensors(link):
344
  split_link = link.split("/")
345
- if(len(split_link) == 2):
346
  model_card = ModelCard.load(link)
347
  base_model = model_card.data.get("base_model")
348
  print(base_model)
349
- if((base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell")):
350
  raise Exception("Flux LoRA Not Found!")
351
  image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
352
  trigger_word = model_card.data.get("instance_prompt", "")
@@ -355,30 +368,30 @@ def get_huggingface_safetensors(link):
355
  try:
356
  list_of_files = fs.ls(link, detail=False)
357
  for file in list_of_files:
358
- if(file.endswith(".safetensors")):
359
  safetensors_name = file.split("/")[-1]
360
- if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
361
  image_elements = file.split("/")
362
  image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
363
  except Exception as e:
364
  print(e)
365
- gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
366
- raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
367
  return split_link[1], link, safetensors_name, trigger_word, image_url
368
  else:
369
  raise Exception("Invalid LoRA link format")
370
 
371
  def check_custom_model(link):
372
- if(link.startswith("https://")):
373
- if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
374
  link_split = link.split("huggingface.co/")
375
  return get_huggingface_safetensors(link_split[1])
376
- else:
377
  return get_huggingface_safetensors(link)
378
 
379
  def add_custom_lora(custom_lora):
380
  global loras
381
- if(custom_lora):
382
  try:
383
  title, repo, path, trigger_word, image = check_custom_model(custom_lora)
384
  print(f"Loaded custom LoRA: {repo}")
@@ -389,13 +402,13 @@ def add_custom_lora(custom_lora):
389
  <img src="{image}" />
390
  <div>
391
  <h3>{title}</h3>
392
- <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
393
  </div>
394
  </div>
395
  </div>
396
  '''
397
  existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
398
- if(not existing_item_index):
399
  new_item = {
400
  "image": image,
401
  "title": title,
@@ -409,8 +422,8 @@ def add_custom_lora(custom_lora):
409
 
410
  return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
411
  except Exception as e:
412
- gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-FLUX LoRA")
413
- return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-FLUX LoRA"), gr.update(visible=False), gr.update(), "", None, ""
414
  else:
415
  return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
416
 
@@ -420,25 +433,25 @@ def remove_custom_lora():
420
  run_lora.zerogpu = True
421
 
422
  css = '''
423
- #gen_btn{height: 100%}
424
- #gen_column{align-self: stretch}
425
- #title{text-align: center}
426
- #title h1{font-size: 3em; display:inline-flex; align-items:center}
427
- #title img{width: 100px; margin-right: 0.5em}
428
- #gallery .grid-wrap{height: 10vh}
429
- #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
430
- .card_internal{display: flex;height: 100px;margin-top: .5em}
431
- .card_internal img{margin-right: 1em}
432
- .styler{--form-gap-width: 0px !important}
433
- #progress{height:30px}
434
- #progress .generating{display:none}
435
- .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
436
- .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
437
  '''
438
 
439
- with gr.Blocks(theme="YTheme/Minecraft", css=css, delete_cache=(60, 60)) as app:
440
  title = gr.HTML(
441
- """<h1>FLUX LoRA DLC🥳</h1>""",
442
  elem_id="title",
443
  )
444
  selected_index = gr.State(None)
@@ -464,7 +477,7 @@ with gr.Blocks(theme="YTheme/Minecraft", css=css, delete_cache=(60, 60)) as app:
464
  custom_lora_info = gr.HTML(visible=False)
465
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
466
  with gr.Column():
467
- progress_bar = gr.Markdown(elem_id="progress",visible=False)
468
  result = gr.Image(label="Generated Image")
469
  with gr.Row():
470
  with gr.Accordion("Advanced Settings", open=False):
@@ -507,9 +520,11 @@ with gr.Blocks(theme="YTheme/Minecraft", css=css, delete_cache=(60, 60)) as app:
507
  gr.on(
508
  triggers=[generate_button.click, prompt.submit],
509
  fn=run_lora,
510
- inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
511
- outputs=[result, seed, progress_bar]
512
  )
513
-
 
 
514
  app.queue()
515
  app.launch(debug=True)
 
31
 
32
  import spaces
33
 
34
+ # Import the prompt enhancer generator from enhance.py
35
+ from enhance import generate as enhance_generate
36
+
37
  # Attempt to import loras from lora.py; otherwise use a default placeholder.
38
  try:
39
  from lora import loras
 
208
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
209
  good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
210
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
211
+ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
212
+ base_model,
213
+ vae=good_vae,
214
+ transformer=pipe.transformer,
215
+ text_encoder=pipe.text_encoder,
216
+ tokenizer=pipe.tokenizer,
217
+ text_encoder_2=pipe.text_encoder_2,
218
+ tokenizer_2=pipe.tokenizer_2,
219
+ torch_dtype=dtype,
220
+ ).to(device)
221
  MAX_SEED = 2**32-1
222
 
223
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
296
  return final_image
297
 
298
  @spaces.GPU(duration=100)
299
+ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer, progress=gr.Progress(track_tqdm=True)):
300
+ # Check if a LoRA is selected.
301
  if selected_index is None:
302
  raise gr.Error("You must select a LoRA before proceeding.🧨")
303
  selected_lora = loras[selected_index]
304
  lora_path = selected_lora["repo"]
305
  trigger_word = selected_lora["trigger_word"]
306
+ # Prepare prompt by appending/prepending trigger word if available.
307
+ if trigger_word:
308
+ if "trigger_position" in selected_lora and selected_lora["trigger_position"] == "prepend":
 
 
 
 
309
  prompt_mash = f"{trigger_word} {prompt}"
310
+ else:
311
+ prompt_mash = f"{prompt} {trigger_word}"
312
  else:
313
  prompt_mash = prompt
314
 
315
+ # If prompt enhancer is enabled, stream the enhanced prompt.
316
+ enhanced_text = ""
317
+ if use_enhancer:
318
+ for enhanced_chunk in enhance_generate(prompt_mash):
319
+ enhanced_text = enhanced_chunk
320
+ # Yield intermediate output (no image yet, but update enhanced prompt textbox)
321
+ yield None, seed, gr.update(visible=False), enhanced_text
322
+ prompt_mash = enhanced_text # Use final enhanced prompt for generation
323
+ # Else, leave prompt_mash as is.
324
+
325
  with calculateDuration("Unloading LoRA"):
326
  pipe.unload_lora_weights()
327
  pipe_i2i.unload_lora_weights()
 
339
  if randomize_seed:
340
  seed = random.randint(0, MAX_SEED)
341
 
342
+ if image_input is not None:
343
  final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
344
+ yield final_image, seed, gr.update(visible=False), enhanced_text
345
  else:
346
  image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
347
  final_image = None
 
350
  step_counter += 1
351
  final_image = image
352
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
353
+ yield image, seed, gr.update(value=progress_bar, visible=True), enhanced_text
354
+ yield final_image, seed, gr.update(value=progress_bar, visible=False), enhanced_text
355
 
356
  def get_huggingface_safetensors(link):
357
  split_link = link.split("/")
358
+ if len(split_link) == 2:
359
  model_card = ModelCard.load(link)
360
  base_model = model_card.data.get("base_model")
361
  print(base_model)
362
+ if (base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell"):
363
  raise Exception("Flux LoRA Not Found!")
364
  image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
365
  trigger_word = model_card.data.get("instance_prompt", "")
 
368
  try:
369
  list_of_files = fs.ls(link, detail=False)
370
  for file in list_of_files:
371
+ if file.endswith(".safetensors"):
372
  safetensors_name = file.split("/")[-1]
373
+ if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
374
  image_elements = file.split("/")
375
  image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
376
  except Exception as e:
377
  print(e)
378
+ gr.Warning("You didn't include a link nor a valid Hugging Face repository with a *.safetensors LoRA")
379
+ raise Exception("Invalid LoRA repository")
380
  return split_link[1], link, safetensors_name, trigger_word, image_url
381
  else:
382
  raise Exception("Invalid LoRA link format")
383
 
384
  def check_custom_model(link):
385
+ if link.startswith("https://"):
386
+ if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
387
  link_split = link.split("huggingface.co/")
388
  return get_huggingface_safetensors(link_split[1])
389
+ else:
390
  return get_huggingface_safetensors(link)
391
 
392
  def add_custom_lora(custom_lora):
393
  global loras
394
+ if custom_lora:
395
  try:
396
  title, repo, path, trigger_word, image = check_custom_model(custom_lora)
397
  print(f"Loaded custom LoRA: {repo}")
 
402
  <img src="{image}" />
403
  <div>
404
  <h3>{title}</h3>
405
+ <small>{"Using: <code><b>" + trigger_word + "</b></code> as the trigger word" if trigger_word else "No trigger word found. Include it in your prompt"}<br></small>
406
  </div>
407
  </div>
408
  </div>
409
  '''
410
  existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
411
+ if not existing_item_index:
412
  new_item = {
413
  "image": image,
414
  "title": title,
 
422
 
423
  return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
424
  except Exception as e:
425
+ gr.Warning("Invalid LoRA: either you entered an invalid link or a non-FLUX LoRA")
426
+ return gr.update(visible=True, value="Invalid LoRA"), gr.update(visible=False), gr.update(), "", None, ""
427
  else:
428
  return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
429
 
 
433
  run_lora.zerogpu = True
434
 
435
  css = '''
436
+ #gen_btn { height: 100%; }
437
+ #gen_column { align-self: stretch; }
438
+ #title { text-align: center; }
439
+ #title h1 { font-size: 3em; display:inline-flex; align-items:center; }
440
+ #title img { width: 100px; margin-right: 0.5em; }
441
+ #gallery .grid-wrap { height: 10vh; }
442
+ #lora_list { background: var(--block-background-fill); padding: 0 1em .3em; font-size: 90%; }
443
+ .card_internal { display: flex; height: 100px; margin-top: .5em; }
444
+ .card_internal img { margin-right: 1em; }
445
+ .styler { --form-gap-width: 0px !important; }
446
+ #progress { height:30px; }
447
+ #progress .generating { display:none; }
448
+ .progress-container { width: 100%; height: 30px; background-color: #f0f0f0; border-radius: 15px; overflow: hidden; margin-bottom: 20px; }
449
+ .progress-bar { height: 100%; background-color: #4f46e5; width: calc(var(--current) / var(--total) * 100%); transition: width 0.5s ease-in-out; }
450
  '''
451
 
452
+ with gr.Blocks(theme=gr.themes.Base(), css=css, delete_cache=(60, 60)) as app:
453
  title = gr.HTML(
454
+ """<h1>Flux LoRA Generation</h1>""",
455
  elem_id="title",
456
  )
457
  selected_index = gr.State(None)
 
477
  custom_lora_info = gr.HTML(visible=False)
478
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
479
  with gr.Column():
480
+ progress_bar = gr.Markdown(elem_id="progress", visible=False)
481
  result = gr.Image(label="Generated Image")
482
  with gr.Row():
483
  with gr.Accordion("Advanced Settings", open=False):
 
520
  gr.on(
521
  triggers=[generate_button.click, prompt.submit],
522
  fn=run_lora,
523
+ inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer],
524
+ outputs=[result, seed, progress_bar, enhanced_prompt_box]
525
  )
526
+ with gr.Row():
527
+ gr.HTML("<div style='text-align:center; font-size:0.9em; margin-top:20px;'>Credits: <a href='https://ruslanmv.com' target='_blank'>ruslanmv.com</a></div>")
528
+
529
  app.queue()
530
  app.launch(debug=True)