multimodalart HF staff commited on
Commit
e306774
·
verified ·
1 Parent(s): 89cc8a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -5
app.py CHANGED
@@ -159,9 +159,228 @@ def randomize_loras(selected_indices):
159
  lora_image_2 = lora2['image']
160
  return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
161
 
162
- # ... (rest of your code remains unchanged)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
- # Update your UI components to include image previews
165
  run_lora.zerogpu = True
166
 
167
  css = '''
@@ -194,13 +413,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
194
  generate_button = gr.Button("Generate", variant="primary")
195
  with gr.Row():
196
  with gr.Column(scale=1):
197
- randomize_button = gr.Button("🎲", variant="secondary", scale=1, min_width=50)
198
- with gr.Column(scale=4):
199
  lora_image_1 = gr.Image(label="LoRA 1 Image", interactive=False)
200
  selected_info_1 = gr.Markdown("Select a LoRA 1")
201
  lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=0.95)
202
  remove_button_1 = gr.Button("Remove LoRA 1")
203
- with gr.Column(scale=4):
204
  lora_image_2 = gr.Image(label="LoRA 2 Image", interactive=False)
205
  selected_info_2 = gr.Markdown("Select a LoRA 2")
206
  lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=0.95)
 
159
  lora_image_2 = lora2['image']
160
  return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
161
 
162
+ @spaces.GPU(duration=70)
163
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
164
+ pipe.to("cuda")
165
+ generator = torch.Generator(device="cuda").manual_seed(seed)
166
+ with calculateDuration("Generating image"):
167
+ # Generate image
168
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
169
+ prompt=prompt_mash,
170
+ num_inference_steps=steps,
171
+ guidance_scale=cfg_scale,
172
+ width=width,
173
+ height=height,
174
+ generator=generator,
175
+ joint_attention_kwargs={"scale": 1.0},
176
+ output_type="pil",
177
+ good_vae=good_vae,
178
+ ):
179
+ yield img
180
+
181
+ @spaces.GPU(duration=70)
182
+ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, seed):
183
+ generator = torch.Generator(device="cuda").manual_seed(seed)
184
+ pipe_i2i.to("cuda")
185
+ image_input = load_image(image_input_path)
186
+ final_image = pipe_i2i(
187
+ prompt=prompt_mash,
188
+ image=image_input,
189
+ strength=image_strength,
190
+ num_inference_steps=steps,
191
+ guidance_scale=cfg_scale,
192
+ width=width,
193
+ height=height,
194
+ generator=generator,
195
+ joint_attention_kwargs={"scale": 1.0},
196
+ output_type="pil",
197
+ ).images[0]
198
+ return final_image
199
+
200
+ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, progress=gr.Progress(track_tqdm=True)):
201
+ if not selected_indices:
202
+ raise gr.Error("You must select at least one LoRA before proceeding.")
203
+
204
+ selected_loras = [loras[idx] for idx in selected_indices]
205
+
206
+ # Build the prompt with trigger words
207
+ prompt_mash = prompt
208
+ for lora in selected_loras:
209
+ trigger_word = lora.get('trigger_word', '')
210
+ if trigger_word:
211
+ if lora.get("trigger_position") == "prepend":
212
+ prompt_mash = f"{trigger_word} {prompt_mash}"
213
+ else:
214
+ prompt_mash = f"{prompt_mash} {trigger_word}"
215
+
216
+ # Unload previous LoRA weights
217
+ with calculateDuration("Unloading LoRA"):
218
+ pipe.unload_lora_weights()
219
+ pipe_i2i.unload_lora_weights()
220
+
221
+ # Load LoRA weights with respective scales
222
+ with calculateDuration("Loading LoRA weights"):
223
+ for idx, lora in enumerate(selected_loras):
224
+ lora_path = lora['repo']
225
+ scale = lora_scale_1 if idx == 0 else lora_scale_2
226
+ if image_input is not None:
227
+ if "weights" in lora:
228
+ pipe_i2i.load_lora_weights(lora_path, weight_name=lora["weights"], multiplier=scale)
229
+ else:
230
+ pipe_i2i.load_lora_weights(lora_path, multiplier=scale)
231
+ else:
232
+ if "weights" in lora:
233
+ pipe.load_lora_weights(lora_path, weight_name=lora["weights"], multiplier=scale)
234
+ else:
235
+ pipe.load_lora_weights(lora_path, multiplier=scale)
236
+
237
+ # Set random seed for reproducibility
238
+ with calculateDuration("Randomizing seed"):
239
+ if randomize_seed:
240
+ seed = random.randint(0, MAX_SEED)
241
+
242
+ # Generate image
243
+ if image_input is not None:
244
+ final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
245
+ yield final_image, seed, gr.update(visible=False)
246
+ else:
247
+ image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
248
+ # Consume the generator to get the final image
249
+ final_image = None
250
+ step_counter = 0
251
+ for image in image_generator:
252
+ step_counter+=1
253
+ final_image = image
254
+ progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
255
+ yield image, seed, gr.update(value=progress_bar, visible=True)
256
+ yield final_image, seed, gr.update(value=progress_bar, visible=False)
257
+
258
+ def get_huggingface_safetensors(link):
259
+ split_link = link.split("/")
260
+ if len(split_link) == 2:
261
+ model_card = ModelCard.load(link)
262
+ base_model = model_card.data.get("base_model")
263
+ print(base_model)
264
+ if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
265
+ raise Exception("Not a FLUX LoRA!")
266
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
267
+ trigger_word = model_card.data.get("instance_prompt", "")
268
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
269
+ fs = HfFileSystem()
270
+ safetensors_name = None
271
+ try:
272
+ list_of_files = fs.ls(link, detail=False)
273
+ for file in list_of_files:
274
+ if file.endswith(".safetensors"):
275
+ safetensors_name = file.split("/")[-1]
276
+ if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
277
+ image_elements = file.split("/")
278
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
279
+ except Exception as e:
280
+ print(e)
281
+ raise Exception("Invalid Hugging Face repository with a *.safetensors LoRA")
282
+ if not safetensors_name:
283
+ raise Exception("No *.safetensors file found in the repository")
284
+ return split_link[1], link, safetensors_name, trigger_word, image_url
285
+
286
+ def check_custom_model(link):
287
+ if link.startswith("https://"):
288
+ if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
289
+ link_split = link.split("huggingface.co/")
290
+ return get_huggingface_safetensors(link_split[1])
291
+ else:
292
+ return get_huggingface_safetensors(link)
293
+
294
+ def add_custom_lora(custom_lora, selected_indices):
295
+ global loras
296
+ if custom_lora:
297
+ try:
298
+ title, repo, path, trigger_word, image = check_custom_model(custom_lora)
299
+ print(f"Loaded custom LoRA: {repo}")
300
+ card = f'''
301
+ <div class="custom_lora_card">
302
+ <span>Loaded custom LoRA:</span>
303
+ <div class="card_internal">
304
+ <img src="{image}" />
305
+ <div>
306
+ <h3>{title}</h3>
307
+ <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
308
+ </div>
309
+ </div>
310
+ </div>
311
+ '''
312
+ existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
313
+ if existing_item_index is None:
314
+ new_item = {
315
+ "image": image,
316
+ "title": title,
317
+ "repo": repo,
318
+ "weights": path,
319
+ "trigger_word": trigger_word
320
+ }
321
+ print(new_item)
322
+ existing_item_index = len(loras)
323
+ loras.append(new_item)
324
+
325
+ # Update gallery
326
+ gallery_items = [(item["image"], item["title"]) for item in loras]
327
+ # Update selected_indices if there's room
328
+ if len(selected_indices) < 2:
329
+ selected_indices.append(existing_item_index)
330
+ selected_info_1 = ""
331
+ selected_info_2 = ""
332
+ lora_scale_1 = 0.95
333
+ lora_scale_2 = 0.95
334
+ lora_image_1 = None
335
+ lora_image_2 = None
336
+ if len(selected_indices) >= 1:
337
+ lora1 = loras[selected_indices[0]]
338
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
339
+ lora_image_1 = lora1['image']
340
+ if len(selected_indices) >= 2:
341
+ lora2 = loras[selected_indices[1]]
342
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
343
+ lora_image_2 = lora2['image']
344
+ return (gr.update(visible=True, value=card), gr.update(visible=True), gr.update(value=gallery_items),
345
+ selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2)
346
+ else:
347
+ return (gr.update(visible=True, value=card), gr.update(visible=True), gr.update(value=gallery_items),
348
+ gr.NoChange(), gr.NoChange(), selected_indices, gr.NoChange(), gr.NoChange(), gr.NoChange(), gr.NoChange())
349
+ except Exception as e:
350
+ print(e)
351
+ return gr.update(visible=True, value=str(e)), gr.update(visible=True), gr.NoChange(), gr.NoChange(), gr.NoChange(), selected_indices, gr.NoChange(), gr.NoChange(), gr.NoChange(), gr.NoChange()
352
+ else:
353
+ return gr.update(visible=False), gr.update(visible=False), gr.NoChange(), gr.NoChange(), gr.NoChange(), selected_indices, gr.NoChange(), gr.NoChange(), gr.NoChange(), gr.NoChange()
354
+
355
+ def remove_custom_lora(custom_lora_info, custom_lora_button, selected_indices):
356
+ global loras
357
+ if loras:
358
+ custom_lora_repo = loras[-1]['repo']
359
+ # Remove from loras list
360
+ loras = loras[:-1]
361
+ # Remove from selected_indices if selected
362
+ custom_lora_index = len(loras)
363
+ if custom_lora_index in selected_indices:
364
+ selected_indices.remove(custom_lora_index)
365
+ # Update gallery
366
+ gallery_items = [(item["image"], item["title"]) for item in loras]
367
+ # Update selected_info and images
368
+ selected_info_1 = ""
369
+ selected_info_2 = ""
370
+ lora_scale_1 = 0.95
371
+ lora_scale_2 = 0.95
372
+ lora_image_1 = None
373
+ lora_image_2 = None
374
+ if len(selected_indices) >= 1:
375
+ lora1 = loras[selected_indices[0]]
376
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
377
+ lora_image_1 = lora1['image']
378
+ if len(selected_indices) >= 2:
379
+ lora2 = loras[selected_indices[1]]
380
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
381
+ lora_image_2 = lora2['image']
382
+ return gr.update(visible=False), gr.update(visible=False), gr.update(value=gallery_items), selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
383
 
 
384
  run_lora.zerogpu = True
385
 
386
  css = '''
 
413
  generate_button = gr.Button("Generate", variant="primary")
414
  with gr.Row():
415
  with gr.Column(scale=1):
416
+ randomize_button = gr.Button("🎲", variant="secondary", scale=1)
417
+ with gr.Column(scale=3):
418
  lora_image_1 = gr.Image(label="LoRA 1 Image", interactive=False)
419
  selected_info_1 = gr.Markdown("Select a LoRA 1")
420
  lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=0.95)
421
  remove_button_1 = gr.Button("Remove LoRA 1")
422
+ with gr.Column(scale=3):
423
  lora_image_2 = gr.Image(label="LoRA 2 Image", interactive=False)
424
  selected_info_2 = gr.Markdown("Select a LoRA 2")
425
  lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=0.95)