barreloflube commited on
Commit
8493260
1 Parent(s): 1046573

Refactor image_tab.py and flux_tab.py to update LoRA gallery functionality

Browse files
Files changed (1) hide show
  1. tabs/image_tab.py +207 -8
tabs/image_tab.py CHANGED
@@ -1,12 +1,7 @@
1
  # tabs/image_tab.py
2
 
3
  import gradio as gr
4
- from modules.events.flux_events import *
5
- from modules.events.sdxl_events import *
6
  from modules.helpers.common_helpers import *
7
- from modules.helpers.flux_helpers import *
8
- from modules.helpers.sdxl_helpers import *
9
- from config import flux_models, sdxl_models, flux_loras
10
 
11
 
12
  def image_tab():
@@ -18,6 +13,17 @@ def image_tab():
18
 
19
 
20
  def flux_tab():
 
 
 
 
 
 
 
 
 
 
 
21
  loras = flux_loras
22
  with gr.Row():
23
  with gr.Column():
@@ -122,8 +128,8 @@ def flux_tab():
122
  for column in range(2):
123
  with gr.Column():
124
  options = [
125
- ("Height", "image_height", 64, 1024, 64, 1024, True),
126
- ("Width", "image_width", 64, 1024, 64, 1024, True),
127
  ("Num Images Per Prompt", "image_num_images_per_prompt", 1, 4, 1, 1, True),
128
  ("Num Inference Steps", "image_num_inference_steps", 1, 100, 1, 20, True),
129
  ("Clip Skip", "image_clip_skip", 0, 2, 1, 2, False),
@@ -180,4 +186,197 @@ def flux_tab():
180
 
181
 
182
  def sdxl_tab():
183
- gr.Label("To be implemented")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # tabs/image_tab.py
2
 
3
  import gradio as gr
 
 
4
  from modules.helpers.common_helpers import *
 
 
 
5
 
6
 
7
  def image_tab():
 
13
 
14
 
15
  def flux_tab():
16
+ from modules.events.flux_events import (
17
+ update_fast_generation,
18
+ selected_lora_from_gallery,
19
+ update_selected_lora,
20
+ add_to_enabled_loras,
21
+ update_lora_sliders,
22
+ remove_from_enabled_loras,
23
+ generate_image
24
+ )
25
+ from config import flux_models, flux_loras
26
+
27
  loras = flux_loras
28
  with gr.Row():
29
  with gr.Column():
 
128
  for column in range(2):
129
  with gr.Column():
130
  options = [
131
+ ("Height", "image_height", 64, 2048, 64, 1024, True),
132
+ ("Width", "image_width", 64, 2048, 64, 1024, True),
133
  ("Num Images Per Prompt", "image_num_images_per_prompt", 1, 4, 1, 1, True),
134
  ("Num Inference Steps", "image_num_inference_steps", 1, 100, 1, 20, True),
135
  ("Clip Skip", "image_clip_skip", 0, 2, 1, 2, False),
 
186
 
187
 
188
  def sdxl_tab():
189
+ from modules.events.sdxl_events import (
190
+ update_fast_generation,
191
+ selected_lora_from_gallery,
192
+ update_selected_lora,
193
+ add_to_enabled_loras,
194
+ update_lora_sliders,
195
+ remove_from_enabled_loras,
196
+ add_to_embeddings,
197
+ update_custom_embedding,
198
+ remove_from_embeddings,
199
+ generate_image
200
+ )
201
+ from config import sdxl_models, sdxl_loras
202
+
203
+ loras = sdxl_loras
204
+ with gr.Row():
205
+ with gr.Column():
206
+ with gr.Group() as image_options:
207
+ model = gr.Dropdown(label="Models", choices=sdxl_models, value=sdxl_models[0], interactive=True)
208
+ prompt = gr.Textbox(lines=5, label="Prompt")
209
+ fast_generation = gr.Checkbox(label="Fast Generation (Hyper-SD) 🧪")
210
+
211
+
212
+ with gr.Accordion("Loras", open=True):
213
+ lora_gallery = gr.Gallery(
214
+ label="Gallery",
215
+ value=[(lora['image'], lora['title']) for lora in loras],
216
+ allow_preview=False,
217
+ columns=3,
218
+ rows=3,
219
+ type="pil"
220
+ )
221
+
222
+ with gr.Group():
223
+ with gr.Column():
224
+ with gr.Row():
225
+ custom_lora = gr.Textbox(label="Custom Lora", info="Enter a Huggingface repo path")
226
+ selected_lora = gr.Textbox(label="Selected Lora", info="Choose from the gallery or enter a custom LoRA")
227
+
228
+ custom_lora_info = gr.HTML(visible=False)
229
+ add_lora = gr.Button(value="Add LoRA")
230
+
231
+ enabled_loras = gr.State(value=[])
232
+ with gr.Group():
233
+ with gr.Row():
234
+ for i in range(6):
235
+ with gr.Column():
236
+ with gr.Column(scale=2):
237
+ globals()[f"lora_slider_{i}"] = gr.Slider(label=f"LoRA {i+1}", minimum=0, maximum=1, step=0.01, value=0.8, visible=False, interactive=True)
238
+ with gr.Column():
239
+ globals()[f"lora_remove_{i}"] = gr.Button(value="Remove LoRA", visible=False)
240
+
241
+ with gr.Accordion("Embeddings", open=False):
242
+ custom_embedding = gr.Textbox(label="Custom Embedding")
243
+ custom_embedding_info = gr.HTML(visible=False)
244
+ add_embedding = gr.Button(value="Add Embedding")
245
+ embeddings = gr.State(value=[])
246
+ with gr.Group():
247
+ with gr.Row():
248
+ for i in range(6):
249
+ with gr.Column():
250
+ with gr.Column(scale=2):
251
+ globals()[f"embedding_list_{i}"] = gr.Label(label=f"Embedding {i+1}", visible=False)
252
+ with gr.Column():
253
+ globals()[f"embedding_remove_{i}"] = gr.Button(value="Remove Embedding", visible=False)
254
+
255
+ with gr.Accordion("Image Options", open=False):
256
+ with gr.Tabs():
257
+ image_options = {
258
+ "img2img": "Upload Image",
259
+ "inpaint": "Upload Image",
260
+ "canny": "Upload Image",
261
+ "pose": "Upload Image",
262
+ "depth": "Upload Image",
263
+ "scribble": "Upload Image",
264
+ }
265
+
266
+ for image_option, label in image_options.items():
267
+ with gr.Tab(image_option):
268
+ if not image_option in ['inpaint', 'scribble']:
269
+ globals()[f"{image_option}_image"] = gr.Image(label=label, type="pil")
270
+ elif image_option in ['inpaint', 'scribble']:
271
+ globals()[f"{image_option}_image"] = gr.ImageEditor(
272
+ label=label,
273
+ image_mode='RGB',
274
+ layers=False,
275
+ brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed") if image_option == 'inpaint' else gr.Brush(),
276
+ interactive=True,
277
+ type="pil",
278
+ )
279
+
280
+ globals()[f"{image_option}_strength"] = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.01, value=1.0, interactive=True)
281
+
282
+ resize_mode = gr.Radio(
283
+ label="Resize Mode",
284
+ choices=["crop and resize", "resize only", "resize and fill"],
285
+ value="resize and fill",
286
+ interactive=True
287
+ )
288
+
289
+ with gr.Column():
290
+ with gr.Group():
291
+ output_images = gr.Gallery(
292
+ label="Output Images",
293
+ value=[],
294
+ allow_preview=True,
295
+ type="pil",
296
+ interactive=False,
297
+ )
298
+ generate_images = gr.Button(value="Generate Images", variant="primary")
299
+
300
+ with gr.Accordion("Advance Settings", open=True):
301
+ with gr.Row():
302
+ scheduler = gr.Dropdown(
303
+ label="Scheduler",
304
+ choices = [
305
+ "dpmpp_2m", "dpmpp_2m_k", "dpmpp_2m_sde", "dpmpp_2m_sde_k",
306
+ "dpmpp_sde", "dpmpp_sde_k", "dpm2", "dpm2_k", "dpm2_a",
307
+ "dpm2_a_k", "euler", "euler_a", "heun", "lms", "lms_k",
308
+ "deis", "unipc"
309
+ ]
310
+ value="dpmpp_2m_sde_k",
311
+ interactive=True
312
+ )
313
+
314
+ with gr.Row():
315
+ for column in range(2):
316
+ with gr.Column():
317
+ options = [
318
+ ("Height", "image_height", 64, 2048, 64, 1024, True),
319
+ ("Width", "image_width", 64, 2048, 64, 1024, True),
320
+ ("Num Images Per Prompt", "image_num_images_per_prompt", 1, 4, 1, 1, True),
321
+ ("Num Inference Steps", "image_num_inference_steps", 1, 100, 1, 20, True),
322
+ ("Clip Skip", "image_clip_skip", 0, 2, 1, 2, True),
323
+ ("Guidance Scale", "image_guidance_scale", 0, 20, 0.5, 7.0, True),
324
+ ("Seed", "image_seed", 0, 100000, 1, random.randint(0, 100000), True),
325
+ ]
326
+ for label, var_name, min_val, max_val, step, value, visible in options[column::2]:
327
+ globals()[var_name] = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=step, value=value, visible=visible, interactive=True)
328
+
329
+ with gr.Row():
330
+ refiner = gr.Checkbox(
331
+ label="Refiner 🧪",
332
+ value=False,
333
+ )
334
+ vae = gr.Checkbox(
335
+ label="VAE",
336
+ value=True,
337
+ )
338
+
339
+ # Events
340
+ # Base Options
341
+ fast_generation.change(update_fast_generation, [fast_generation], [image_guidance_scale, image_num_inference_steps]) # Fast Generation # type: ignore
342
+
343
+
344
+ # Lora Gallery
345
+ lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
346
+ custom_lora.change(update_selected_lora, custom_lora, [selected_lora, custom_lora_info])
347
+ add_lora.click(add_to_enabled_loras, [selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
348
+ enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
349
+
350
+ for i in range(6):
351
+ globals()[f"lora_remove_{i}"].click(
352
+ lambda enabled_loras, index=i: remove_from_enabled_loras(enabled_loras, index),
353
+ [enabled_loras],
354
+ [enabled_loras]
355
+ )
356
+
357
+
358
+ # Embeddings
359
+ custom_embedding.change(update_custom_embedding, custom_embedding, [custom_embedding_info])
360
+ add_embedding.click(add_to_embeddings, [custom_embedding, embeddings], [custom_embedding, custom_embedding_info, embeddings])
361
+ for i in range(6):
362
+ globals()[f"embedding_remove_{i}"].click(
363
+ lambda embeddings, index=i: remove_from_embeddings(embeddings, index),
364
+ [embeddings],
365
+ [embeddings]
366
+ )
367
+
368
+ # Generate Image
369
+ generate_images.click(
370
+ generate_image, # type: ignore
371
+ [
372
+ model, prompt, fast_generation, enabled_loras,
373
+ lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, # type: ignore
374
+ img2img_image, inpaint_image, canny_image, pose_image, depth_image, scribble_image, # type: ignore
375
+ img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, scribble_strength, # type: ignore
376
+ resize_mode,
377
+ scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
378
+ image_num_inference_steps, image_clip_skip, image_guidance_scale, image_seed, # type: ignore
379
+ refiner, vae
380
+ ],
381
+ [output_images]
382
+ )