nyanko7 commited on
Commit
4b30d84
1 Parent(s): 95a273b

Create app.py

Files changed (1)
  1. app.py +754 -0
app.py ADDED
@@ -0,0 +1,754 @@
import random
import tempfile
import time
import gradio as gr
import numpy as np
import torch

from gradio import inputs
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    UNet2DConditionModel,
)
from modules.model_pww import CrossAttnProcessor, StableDiffusionPipeline, load_lora_attn_procs
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel
from PIL import Image
from pathlib import Path
from safetensors.torch import load_file
import modules.safe as _

models = [
    ("AbyssOrangeMix_Base", "OrangeMix/AbyssOrangeMix2"),
]

base_name = "AbyssOrangeMix_Base"
base_model = "OrangeMix/AbyssOrangeMix2"

samplers_k_diffusion = [
    ("Euler a", "sample_euler_ancestral", {}),
    ("Euler", "sample_euler", {}),
    ("LMS", "sample_lms", {}),
    ("Heun", "sample_heun", {}),
    ("DPM2", "sample_dpm_2", {"discard_next_to_last_sigma": True}),
    ("DPM2 a", "sample_dpm_2_ancestral", {"discard_next_to_last_sigma": True}),
    ("DPM++ 2S a", "sample_dpmpp_2s_ancestral", {}),
    ("DPM++ 2M", "sample_dpmpp_2m", {}),
    ("DPM++ SDE", "sample_dpmpp_sde", {}),
    ("DPM fast", "sample_dpm_fast", {}),
    ("DPM adaptive", "sample_dpm_adaptive", {}),
    ("LMS Karras", "sample_lms", {"scheduler": "karras"}),
    (
        "DPM2 Karras",
        "sample_dpm_2",
        {"scheduler": "karras", "discard_next_to_last_sigma": True},
    ),
    (
        "DPM2 a Karras",
        "sample_dpm_2_ancestral",
        {"scheduler": "karras", "discard_next_to_last_sigma": True},
    ),
    ("DPM++ 2S a Karras", "sample_dpmpp_2s_ancestral", {"scheduler": "karras"}),
    ("DPM++ 2M Karras", "sample_dpmpp_2m", {"scheduler": "karras"}),
    ("DPM++ SDE Karras", "sample_dpmpp_sde", {"scheduler": "karras"}),
]
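
# Each entry is (UI label, sampler function name, extra sampler options).
# inference() looks the selected label up here and passes sampler_name /
# sampler_opt into the pipeline config, so the sampling itself is presumably
# done by k-diffusion code inside modules.model_pww (not shown in this file).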

start_time = time.time()

scheduler = DDIMScheduler.from_pretrained(
    base_model,
    subfolder="scheduler",
)
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-ema",
    torch_dtype=torch.float32
)
text_encoder = CLIPTextModel.from_pretrained(
    base_model,
    subfolder="text_encoder",
    torch_dtype=torch.float32,
)
tokenizer = CLIPTokenizer.from_pretrained(
    base_model,
    subfolder="tokenizer",
    torch_dtype=torch.float32,
)
unet = UNet2DConditionModel.from_pretrained(
    base_model,
    subfolder="unet",
    torch_dtype=torch.float32,
)
pipe = StableDiffusionPipeline(
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    vae=vae,
    scheduler=scheduler,
)
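
# Note: StableDiffusionPipeline here is the custom pipeline from
# modules.model_pww, not the stock diffusers class; it exposes setup_unet(),
# txt2img() and img2img() with k-diffusion sampler options and paint-with-words
# state, all of which are used in inference() below.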

unet.set_attn_processor(CrossAttnProcessor)
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

def get_model_list():
    model_available = []
    for model in models:
        if Path(model[1]).is_dir():
            model_available.append(model)
    return model_available


unet_cache = dict()


def get_model(name):
    keys = [k[0] for k in models]
    if name not in unet_cache:
        if name not in keys:
            raise ValueError(name)
        else:
            unet = UNet2DConditionModel.from_pretrained(
                models[keys.index(name)][1],
                subfolder="unet",
                torch_dtype=torch.float32,
            )
            unet_cache[name] = unet

    g_unet = unet_cache[name]
    g_unet.set_attn_processor(None)
    return g_unet


def error_str(error, title="Error"):
    return (
        f"""#### {title}
{error}"""
        if error
        else ""
    )


te_base_weight = text_encoder.get_input_embeddings().weight.data.detach().clone()


def restore_all():
    # Revert the token embeddings and reload a clean tokenizer from the base
    # model so embeddings added for one request do not leak into the next.
    global te_base_weight, tokenizer
    text_encoder.get_input_embeddings().weight.data = te_base_weight
    tokenizer = CLIPTokenizer.from_pretrained(
        base_model,
        subfolder="tokenizer",
        torch_dtype=torch.float32,
    )

def inference(
    prompt,
    guidance,
    steps,
    width=512,
    height=512,
    seed=0,
    neg_prompt="",
    state=None,
    g_strength=0.4,
    img_input=None,
    i2i_scale=0.5,
    hr_enabled=False,
    hr_method="Latent",
    hr_scale=1.5,
    hr_denoise=0.8,
    sampler="DPM++ 2M Karras",
    embs=None,
    model=None,
    lora_state=None,
    lora_scale=None,
):
    global pipe, unet, tokenizer, text_encoder
    if seed is None or seed == 0:
        seed = random.randint(0, 2147483647)
    if torch.cuda.is_available():
        generator = torch.Generator("cuda").manual_seed(int(seed))
    else:
        generator = torch.Generator().manual_seed(int(seed))

    local_unet = get_model(model)
    if lora_state is not None and lora_state != "":
        load_lora_attn_procs(lora_state, local_unet, lora_scale)
    else:
        local_unet.set_attn_processor(CrossAttnProcessor())

    pipe.setup_unet(local_unet)
    sampler_name, sampler_opt = None, None
    for label, funcname, options in samplers_k_diffusion:
        if label == sampler:
            sampler_name, sampler_opt = funcname, options

    if embs is not None and len(embs) > 0:
        delta_weight = []
        for name, file in embs.items():
            if str(file).endswith(".pt"):
                loaded_learned_embeds = torch.load(file, map_location="cpu")
            else:
                loaded_learned_embeds = load_file(file, device="cpu")
            loaded_learned_embeds = loaded_learned_embeds["string_to_param"]["*"]
            added_length = tokenizer.add_tokens(name)

            assert added_length == loaded_learned_embeds.shape[0]
            delta_weight.append(loaded_learned_embeds)

        delta_weight = torch.cat(delta_weight, dim=0)
        text_encoder.resize_token_embeddings(len(tokenizer))
        text_encoder.get_input_embeddings().weight.data[-delta_weight.shape[0]:] = delta_weight
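
    # Textual-inversion embeddings are injected by adding one new token per
    # embedding, growing the token-embedding matrix, and copying the learned
    # vectors into the new rows; restore_all() reverts the weights and the
    # tokenizer once the image has been generated (see the restore step below).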

    config = {
        "negative_prompt": neg_prompt,
        "num_inference_steps": int(steps),
        "guidance_scale": guidance,
        "generator": generator,
        "sampler_name": sampler_name,
        "sampler_opt": sampler_opt,
        "pww_state": state,
        "pww_attn_weight": g_strength,
    }

    if img_input is not None:
        ratio = min(height / img_input.height, width / img_input.width)
        img_input = img_input.resize(
            (int(img_input.width * ratio), int(img_input.height * ratio)), Image.LANCZOS
        )
        result = pipe.img2img(prompt, image=img_input, strength=i2i_scale, **config)
    elif hr_enabled:
        result = pipe.txt2img(
            prompt,
            width=width,
            height=height,
            upscale=True,
            upscale_x=hr_scale,
            upscale_denoising_strength=hr_denoise,
            **config,
            **latent_upscale_modes[hr_method],
        )
    else:
        result = pipe.txt2img(prompt, width=width, height=height, **config)

    # restore
    if embs is not None and len(embs) > 0:
        restore_all()
    return gr.Image.update(result[0][0], label=f"Initial Seed: {seed}")


color_list = []


def get_color(n):
    for _ in range(n - len(color_list)):
        color_list.append(tuple(np.random.random(size=3) * 256))
    return color_list


def create_mixed_img(current, state, w=512, h=512):
    w, h = int(w), int(h)
    image_np = np.full([h, w, 4], 255)
    colors = get_color(len(state))
    idx = 0

    for key, item in state.items():
        if item["map"] is not None:
            m = item["map"] < 255
            alpha = 150
            if current == key:
                alpha = 200
            image_np[m] = colors[idx] + (alpha,)
        idx += 1

    return image_np


# width.change(apply_new_res, inputs=[width, height, global_stats], outputs=[global_stats, sp, rendered])
def apply_new_res(w, h, state):
    w, h = int(w), int(h)

    for key, item in state.items():
        if item["map"] is not None:
            item["map"] = resize(item["map"], w, h)

    update_img = gr.Image.update(value=create_mixed_img("", state, w, h))
    return state, update_img


def detect_text(text, state, width, height):

    t = text.split(",")
    new_state = {}

    for item in t:
        item = item.strip()
        if item == "":
            continue
        if item in state:
            new_state[item] = {
                "map": state[item]["map"],
                "weight": state[item]["weight"],
            }
        else:
            new_state[item] = {
                "map": None,
                "weight": 0.5,
            }
    update = gr.Radio.update(choices=[key for key in new_state.keys()], value=None)
    update_img = gr.update(value=create_mixed_img("", new_state, width, height))
    update_sketch = gr.update(value=None, interactive=False)
    return new_state, update_sketch, update, update_img


def resize(img, w, h):
    trs = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize(min(h, w)),
            transforms.CenterCrop((h, w)),
        ]
    )
    result = np.array(trs(img), dtype=np.uint8)
    return result


def switch_canvas(entry, state, width, height):
    if entry is None:
        return None, 0.5, create_mixed_img("", state, width, height)
    return (
        gr.update(value=None, interactive=True),
        gr.update(value=state[entry]["weight"]),
        create_mixed_img(entry, state, width, height),
    )


def apply_canvas(selected, draw, state, w, h):
    w, h = int(w), int(h)
    state[selected]["map"] = resize(draw, w, h)
    return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))


def apply_weight(selected, weight, state):
    state[selected]["weight"] = weight
    return state


# sp2, radio, width, height, global_stats
def apply_image(image, selected, w, h, strength, state):
    if selected is not None:
        state[selected] = {"map": resize(image, w, h), "weight": strength}
    return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))


# [ti_state, lora_state, ti_vals, lora_vals, uploads]
def add_net(files: list[tempfile._TemporaryFileWrapper], ti_state, lora_state):
    if files is None:
        return ti_state, lora_state, gr.Text.update(""), gr.Text.update(""), gr.Files.update(value=None)

    for file in files:
        item = Path(file.name)
        stripedname = str(item.stem).strip()
        if item.suffix == ".pt":
            state_dict = torch.load(file.name, map_location="cpu")
        else:
            state_dict = load_file(file.name, device="cpu")
        if any("lora" in k for k in state_dict.keys()):
            lora_state = file.name
        else:
            ti_state[stripedname] = file.name

    return ti_state, lora_state, gr.Text.update(f"{[key for key in ti_state.keys()]}"), gr.Text.update(f"{lora_state}"), gr.Files.update(value=None)
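
# Heuristic used above: a checkpoint whose state dict contains any "lora" key is
# treated as a LoRA file (only the most recently uploaded one is kept), while
# everything else is assumed to be a textual-inversion embedding and is
# registered under its file stem as the trigger token.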

# [ti_state, lora_state, ti_vals, lora_vals, uploads]
def clean_states(ti_state, lora_state):
    return dict(), None, gr.Text.update(""), gr.Text.update(""), gr.Files.update(value=None)


latent_upscale_modes = {
    "Latent": {"upscale_method": "bilinear", "upscale_antialias": False},
    "Latent (antialiased)": {"upscale_method": "bilinear", "upscale_antialias": True},
    "Latent (bicubic)": {"upscale_method": "bicubic", "upscale_antialias": False},
    "Latent (bicubic antialiased)": {
        "upscale_method": "bicubic",
        "upscale_antialias": True,
    },
    "Latent (nearest)": {"upscale_method": "nearest", "upscale_antialias": False},
    "Latent (nearest-exact)": {
        "upscale_method": "nearest-exact",
        "upscale_antialias": False,
    },
}
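
# Upscale-method names and antialias flags for the hires-fix path; these are
# presumably forwarded to torch.nn.functional.interpolate when the latents are
# resized (see the upscale_* arguments passed to pipe.txt2img above).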

css = """
.finetuned-diffusion-div div{
  display:inline-flex;
  align-items:center;
  gap:.8rem;
  font-size:1.75rem;
  padding-top:2rem;
}
.finetuned-diffusion-div div h1{
  font-weight:900;
  margin-bottom:7px
}
.finetuned-diffusion-div p{
  margin-bottom:10px;
  font-size:94%
}
.box {
  float: left;
  height: 20px;
  width: 20px;
  margin-bottom: 15px;
  border: 1px solid black;
  clear: both;
}
a{
  text-decoration:underline
}
.tabs{
  margin-top:0;
  margin-bottom:0
}
#gallery{
  min-height:20rem
}
.no-border {
  border: none !important;
}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML(
        f"""
            <div class="finetuned-diffusion-div">
              <div>
                <h1>Demo for diffusion models</h1>
              </div>
              <p>Hso @ nyanko.sketch2img.gradio</p>
            </div>
        """
    )
    global_stats = gr.State(value={})

    with gr.Row():

        with gr.Column(scale=55):
            model = gr.Dropdown(
                choices=[k[0] for k in get_model_list()],
                label="Model",
                value=base_name,
            )
            image_out = gr.Image(height=512)
            # gallery = gr.Gallery(
            #     label="Generated images", show_label=False, elem_id="gallery"
            # ).style(grid=[1], height="auto")

        with gr.Column(scale=45):

            with gr.Group():

                with gr.Row():
                    with gr.Column(scale=70):

                        prompt = gr.Textbox(
                            label="Prompt",
                            value="cat girl, blue eyes, solo, long messy silver hair, blue capelet, garden, cat ears, cat tail, upper body",
                            show_label=True,
                            max_lines=4,
                            placeholder="Enter prompt.",
                        )
                        neg_prompt = gr.Textbox(
                            label="Negative Prompt",
                            value="bad quality, low quality, jpeg artifact, cropped",
                            show_label=True,
                            max_lines=4,
                            placeholder="Enter negative prompt.",
                        )

                    generate = gr.Button(value="Generate").style(
                        rounded=(False, True, True, False)
                    )

            with gr.Tab("Options"):

                with gr.Group():

                    # n_images = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1)
                    with gr.Row():
                        guidance = gr.Slider(
                            label="Guidance scale", value=7.5, maximum=15
                        )
                        steps = gr.Slider(
                            label="Steps", value=25, minimum=2, maximum=75, step=1
                        )

                    with gr.Row():
                        width = gr.Slider(
                            label="Width", value=512, minimum=64, maximum=2048, step=64
                        )
                        height = gr.Slider(
                            label="Height", value=512, minimum=64, maximum=2048, step=64
                        )

                    sampler = gr.Dropdown(
                        value="DPM++ 2M Karras",
                        label="Sampler",
                        choices=[s[0] for s in samplers_k_diffusion],
                    )
                    seed = gr.Number(label="Seed (0 = random)", value=0)

            with gr.Tab("Image to image"):
                with gr.Group():

                    inf_image = gr.Image(
                        label="Image", height=256, tool="editor", type="pil"
                    )
                    inf_strength = gr.Slider(
                        label="Transformation strength",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.5,
                    )

            def res_cap(g, w, h, x):
                if g:
                    return f"Enable upscaler: {w}x{h} to {int(w*x)}x{int(h*x)}"
                else:
                    return "Enable upscaler"

            with gr.Tab("Hires fix"):
                with gr.Group():

                    hr_enabled = gr.Checkbox(label="Enable upscaler", value=False)
                    hr_method = gr.Dropdown(
                        [key for key in latent_upscale_modes.keys()],
                        value="Latent",
                        label="Upscale method",
                    )
                    hr_scale = gr.Slider(
                        label="Upscale factor",
                        minimum=1.0,
                        maximum=3,
                        step=0.1,
                        value=1.5,
                    )
                    hr_denoise = gr.Slider(
                        label="Denoising strength",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=0.8,
                    )

                    hr_scale.change(
                        lambda g, x, w, h: gr.Checkbox.update(
                            label=res_cap(g, w, h, x)
                        ),
                        inputs=[hr_enabled, hr_scale, width, height],
                        outputs=hr_enabled,
                    )
                    hr_enabled.change(
                        lambda g, x, w, h: gr.Checkbox.update(
                            label=res_cap(g, w, h, x)
                        ),
                        inputs=[hr_enabled, hr_scale, width, height],
                        outputs=hr_enabled,
                    )

            with gr.Tab("Embeddings/Loras"):

                ti_state = gr.State(dict())
                lora_state = gr.State()

                with gr.Group():
                    with gr.Row():
                        with gr.Column(scale=90):
                            ti_vals = gr.Text(label="Loaded embeddings")

                    with gr.Row():
                        with gr.Column(scale=90):
                            lora_vals = gr.Text(label="Loaded loras")

                    with gr.Row():

                        uploads = gr.Files(label="Upload new embeddings/lora")

                        with gr.Column():
                            lora_scale = gr.Slider(
                                label="Lora scale",
                                minimum=0,
                                maximum=2,
                                step=0.01,
                                value=1.0,
                            )
                            btn = gr.Button(value="Upload")
                            btn_del = gr.Button(value="Reset")

                btn.click(
                    add_net,
                    inputs=[uploads, ti_state, lora_state],
                    outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads],
                )
                btn_del.click(
                    clean_states,
                    inputs=[ti_state, lora_state],
                    outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads],
                )

    # error_output = gr.Markdown()

    gr.HTML(
        f"""
            <div class="finetuned-diffusion-div">
              <div>
                <h1>Paint with words</h1>
              </div>
              <p>
                Will use the following formula: w = scale * token_weight_matrix * log(1 + sigma) * max(qk).
              </p>
            </div>
        """
    )
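
    # Paint-with-words, as summarized above (a sketch only; the actual math
    # lives in modules.model_pww, which is not part of this file):
    #   w = g_strength * token_weight_matrix * log(1 + sigma) * max(Q K^T)
    # i.e. each sketched region appears to boost the cross-attention scores of
    # its token, scaled by the user weight and the current noise level sigma.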

    with gr.Row():

        with gr.Column(scale=55):

            rendered = gr.Image(
                invert_colors=True,
                source="canvas",
                interactive=False,
                image_mode="RGBA",
            )

        with gr.Column(scale=45):

            with gr.Group():
                with gr.Row():
                    with gr.Column(scale=70):
                        g_strength = gr.Slider(
                            label="Weight scaling",
                            minimum=0,
                            maximum=0.8,
                            step=0.01,
                            value=0.4,
                        )

                        text = gr.Textbox(
                            lines=2,
                            interactive=True,
                            label="Token to Draw: (Separate by comma)",
                        )

                        radio = gr.Radio([], label="Tokens")

                    sk_update = gr.Button(value="Update").style(
                        rounded=(False, True, True, False)
                    )

            # g_strength.change(lambda b: gr.update(f"Scaled additional attn: $w = {b} \log (1 + \sigma) \std (Q^T K)$."), inputs=g_strength, outputs=[g_output])

            with gr.Tab("SketchPad"):

                sp = gr.Image(
                    image_mode="L",
                    tool="sketch",
                    source="canvas",
                    interactive=False,
                )

                strength = gr.Slider(
                    label="Token strength",
                    minimum=0,
                    maximum=0.8,
                    step=0.01,
                    value=0.5,
                )

                sk_update.click(
                    detect_text,
                    inputs=[text, global_stats, width, height],
                    outputs=[global_stats, sp, radio, rendered],
                )
                radio.change(
                    switch_canvas,
                    inputs=[radio, global_stats, width, height],
                    outputs=[sp, strength, rendered],
                )
                sp.edit(
                    apply_canvas,
                    inputs=[radio, sp, global_stats, width, height],
                    outputs=[global_stats, rendered],
                )
                strength.change(
                    apply_weight,
                    inputs=[radio, strength, global_stats],
                    outputs=[global_stats],
                )

            with gr.Tab("UploadFile"):

                sp2 = gr.Image(
                    image_mode="L",
                    source="upload",
                    shape=(512, 512),
                )

                strength2 = gr.Slider(
                    label="Token strength",
                    minimum=0,
                    maximum=0.8,
                    step=0.01,
                    value=0.5,
                )

                apply_style = gr.Button(value="Apply")
                apply_style.click(
                    apply_image,
                    inputs=[sp2, radio, width, height, strength2, global_stats],
                    outputs=[global_stats, rendered],
                )

            width.change(
                apply_new_res,
                inputs=[width, height, global_stats],
                outputs=[global_stats, rendered],
            )
            height.change(
                apply_new_res,
                inputs=[width, height, global_stats],
                outputs=[global_stats, rendered],
            )

    # color_stats = gr.State(value={})
    # text.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])
    # sp.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])

    inputs = [
        prompt,
        guidance,
        steps,
        width,
        height,
        seed,
        neg_prompt,
        global_stats,
        g_strength,
        inf_image,
        inf_strength,
        hr_enabled,
        hr_method,
        hr_scale,
        hr_denoise,
        sampler,
        ti_state,
        model,
        lora_state,
        lora_scale,
    ]
    outputs = [image_out]
    prompt.submit(inference, inputs=inputs, outputs=outputs)
    generate.click(inference, inputs=inputs, outputs=outputs)

print(f"Space built in {time.time() - start_time:.2f} seconds")
# demo.launch(share=True)
demo.launch(share=True, enable_queue=True)