multimodalart HF staff commited on
Commit
2caf84c
·
verified ·
1 Parent(s): 8458096

Add custom LoRA loading

Browse files
Files changed (1) hide show
  1. app.py +93 -4
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import json
3
  import logging
@@ -5,6 +6,7 @@ import torch
5
  from PIL import Image
6
  import spaces
7
  from diffusers import DiffusionPipeline
 
8
  import copy
9
  import random
10
  import time
@@ -35,7 +37,6 @@ class calculateDuration:
35
  else:
36
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
37
 
38
-
39
  def update_selection(evt: gr.SelectData, width, height):
40
  selected_lora = loras[evt.index]
41
  new_placeholder = f"Type a prompt for {selected_lora['title']}"
@@ -86,9 +87,10 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
86
  with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
87
  if "weights" in selected_lora:
88
  pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
 
89
  else:
90
  pipe.load_lora_weights(lora_path)
91
-
92
  # Set random seed for reproducibility
93
  with calculateDuration("Randomizing seed"):
94
  if randomize_seed:
@@ -96,9 +98,80 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
96
 
97
  image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
98
  pipe.to("cpu")
 
99
  pipe.unload_lora_weights()
100
  return image, seed
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  run_lora.zerogpu = True
103
 
104
  css = '''
@@ -107,6 +180,10 @@ css = '''
107
  #title h1{font-size: 3em; display:inline-flex; align-items:center}
108
  #title img{width: 100px; margin-right: 0.5em}
109
  #gallery .grid-wrap{height: 10vh}
 
 
 
 
110
  '''
111
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
112
  title = gr.HTML(
@@ -129,7 +206,11 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
129
  columns=3,
130
  elem_id="gallery"
131
  )
132
-
 
 
 
 
133
  with gr.Column(scale=4):
134
  result = gr.Image(label="Generated Image")
135
 
@@ -154,7 +235,15 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
154
  inputs=[width, height],
155
  outputs=[prompt, selected_info, selected_index, width, height]
156
  )
157
-
 
 
 
 
 
 
 
 
158
  gr.on(
159
  triggers=[generate_button.click, prompt.submit],
160
  fn=run_lora,
 
1
+ import os
2
  import gradio as gr
3
  import json
4
  import logging
 
6
  from PIL import Image
7
  import spaces
8
  from diffusers import DiffusionPipeline
9
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
10
  import copy
11
  import random
12
  import time
 
37
  else:
38
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
39
 
 
40
  def update_selection(evt: gr.SelectData, width, height):
41
  selected_lora = loras[evt.index]
42
  new_placeholder = f"Type a prompt for {selected_lora['title']}"
 
87
  with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
88
  if "weights" in selected_lora:
89
  pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
90
+ #pipe.fuse_lora()
91
  else:
92
  pipe.load_lora_weights(lora_path)
93
+ #pipe.fuse_lora()
94
  # Set random seed for reproducibility
95
  with calculateDuration("Randomizing seed"):
96
  if randomize_seed:
 
98
 
99
  image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
100
  pipe.to("cpu")
101
+ #pipe.unfuse_lora()
102
  pipe.unload_lora_weights()
103
  return image, seed
104
 
105
+ def get_huggingface_safetensors(link):
106
+ split_link = link.split("/")
107
+ if(len(split_link) == 2):
108
+ model_card = ModelCard.load(link)
109
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
110
+ trigger_word = model_card.data.get("instance_prompt", "")
111
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
112
+ fs = HfFileSystem()
113
+ try:
114
+ list_of_files = fs.ls(link, detail=False)
115
+ for file in list_of_files:
116
+ if(file.endswith(".safetensors")):
117
+ safetensors_name = file.split("/")[-1]
118
+ if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
119
+ image_elements = file.split("/")
120
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
121
+ except Exception as e:
122
+ print(e)
123
+ gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA {e}")
124
+ raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA {e}")
125
+ return split_link[1], link, safetensors_name, trigger_word, image_url
126
+
127
+ def check_custom_model(link):
128
+ if(link.startswith("https://")):
129
+ if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
130
+ link_split = link.split("huggingface.co/")
131
+ return get_huggingface_safetensors(link_split[1])
132
+ else:
133
+ return get_huggingface_safetensors(link)
134
+
135
+ def add_custom_lora(custom_lora):
136
+ global loras
137
+ if(custom_lora):
138
+ try:
139
+ title, repo, path, trigger_word, image = check_custom_model(custom_lora)
140
+ card = f'''
141
+ <div class="custom_lora_card">
142
+ <span>Loaded custom LoRA:</span>
143
+ <div class="card_internal">
144
+ <img src="{image}" />
145
+ <div>
146
+ <h3>{title}</h3>
147
+ <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
148
+ </div>
149
+ </div>
150
+ </div>
151
+ '''
152
+ existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
153
+ if(not existing_item_index):
154
+ new_item = {
155
+ "image": image,
156
+ "title": title,
157
+ "repo": repo,
158
+ "weights": path,
159
+ "trigger_word": trigger_word
160
+ }
161
+ print(new_item)
162
+ existing_item_index = len(loras)
163
+ loras.append(new_item)
164
+
165
+ return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index
166
+ except Exception as e:
167
+ gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-FLUX LoRA")
168
+ return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-FLUX LoRA"), gr.update(visible=True), gr.update(), "", None
169
+ else:
170
+ return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None
171
+
172
+ def remove_custom_lora():
173
+ return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
174
+
175
  run_lora.zerogpu = True
176
 
177
  css = '''
 
180
  #title h1{font-size: 3em; display:inline-flex; align-items:center}
181
  #title img{width: 100px; margin-right: 0.5em}
182
  #gallery .grid-wrap{height: 10vh}
183
+ #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
184
+ .card_internal{display: flex;height: 100px;margin-top: .5em}
185
+ .card_internal img{margin-right: 1em}
186
+ .styler{--form-gap-width: 0px !important}
187
  '''
188
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
189
  title = gr.HTML(
 
206
  columns=3,
207
  elem_id="gallery"
208
  )
209
+ with gr.Group():
210
+ custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="multimodalart/vintage-ads-flux")
211
+ gr.Markdown("[Check the list of FLUX LoRas](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
212
+ custom_lora_info = gr.HTML(visible=False)
213
+ custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
214
  with gr.Column(scale=4):
215
  result = gr.Image(label="Generated Image")
216
 
 
235
  inputs=[width, height],
236
  outputs=[prompt, selected_info, selected_index, width, height]
237
  )
238
+ custom_lora.input(
239
+ add_custom_lora,
240
+ inputs=[custom_lora],
241
+ outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index]
242
+ )
243
+ custom_lora_button.click(
244
+ remove_custom_lora,
245
+ outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
246
+ )
247
  gr.on(
248
  triggers=[generate_button.click, prompt.submit],
249
  fn=run_lora,