Jeffiyyyy committed on
Commit
90ad87a
•
1 Parent(s): 97817d3
Files changed (7)
  1. .DS_Store +0 -0
  2. README.md +6 -6
  3. app.py +878 -0
  4. modules/lora.py +183 -0
  5. modules/model.py +897 -0
  6. modules/prompt_parser.py +391 -0
  7. modules/safe.py +188 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: LSPDEMO
3
- emoji: ๐Ÿƒ
4
- colorFrom: pink
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.18.0
8
  app_file: app.py
9
  pinned: false
10
- license: openrail
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: LSP LearningandStrivePartner Model
3
+ emoji: ๐Ÿ 
4
+ colorFrom: green
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.17.0
8
  app_file: app.py
9
  pinned: false
10
+ license: afl-3.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,878 @@
1
+ import random
2
+ import tempfile
3
+ import time
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ import math
8
+ import re
9
+
10
+ from gradio import inputs
11
+ from diffusers import (
12
+ AutoencoderKL,
13
+ DDIMScheduler,
14
+ UNet2DConditionModel,
15
+ )
16
+ from modules.model import (
17
+ CrossAttnProcessor,
18
+ StableDiffusionPipeline,
19
+ )
20
+ from torchvision import transforms
21
+ from transformers import CLIPTokenizer, CLIPTextModel
22
+ from PIL import Image
23
+ from pathlib import Path
24
+ from safetensors.torch import load_file
25
+ import modules.safe as _
26
+ from modules.lora import LoRANetwork
27
+
28
+ models = [
29
+ ("LSPV1", "Jeffsun/LSP", 2),
30
+ ("Pastal Mix", "andite/pastel-mix", 2),
31
+ ("Basil Mix", "nuigurumi/basil_mix", 2)
32
+ ]
33
+
34
+ keep_vram = ["Korakoe/AbyssOrangeMix2-HF", "andite/pastel-mix"]
35
+ base_name, base_model, clip_skip = models[0]
36
+
37
+ samplers_k_diffusion = [
38
+ ("Euler a", "sample_euler_ancestral", {}),
39
+ ("Euler", "sample_euler", {}),
40
+ ("LMS", "sample_lms", {}),
41
+ ("Heun", "sample_heun", {}),
42
+ ("DPM2", "sample_dpm_2", {"discard_next_to_last_sigma": True}),
43
+ ("DPM2 a", "sample_dpm_2_ancestral", {"discard_next_to_last_sigma": True}),
44
+ ("DPM++ 2S a", "sample_dpmpp_2s_ancestral", {}),
45
+ ("DPM++ 2M", "sample_dpmpp_2m", {}),
46
+ ("DPM++ SDE", "sample_dpmpp_sde", {}),
47
+ ("LMS Karras", "sample_lms", {"scheduler": "karras"}),
48
+ ("DPM2 Karras", "sample_dpm_2", {"scheduler": "karras", "discard_next_to_last_sigma": True}),
49
+ ("DPM2 a Karras", "sample_dpm_2_ancestral", {"scheduler": "karras", "discard_next_to_last_sigma": True}),
50
+ ("DPM++ 2S a Karras", "sample_dpmpp_2s_ancestral", {"scheduler": "karras"}),
51
+ ("DPM++ 2M Karras", "sample_dpmpp_2m", {"scheduler": "karras"}),
52
+ ("DPM++ SDE Karras", "sample_dpmpp_sde", {"scheduler": "karras"}),
53
+ ]
54
+
55
+ # samplers_diffusers = [
56
+ # ("DDIMScheduler", "diffusers.schedulers.DDIMScheduler", {})
57
+ # ("DDPMScheduler", "diffusers.schedulers.DDPMScheduler", {})
58
+ # ("DEISMultistepScheduler", "diffusers.schedulers.DEISMultistepScheduler", {})
59
+ # ]
60
+
61
+ start_time = time.time()
62
+ timeout = 90
63
+
64
+ scheduler = DDIMScheduler.from_pretrained(
65
+ base_model,
66
+ subfolder="scheduler",
67
+ )
68
+ vae = AutoencoderKL.from_pretrained(
69
+ "stabilityai/sd-vae-ft-ema",
70
+ torch_dtype=torch.float16
71
+ )
72
+ text_encoder = CLIPTextModel.from_pretrained(
73
+ base_model,
74
+ subfolder="text_encoder",
75
+ torch_dtype=torch.float16,
76
+ )
77
+ tokenizer = CLIPTokenizer.from_pretrained(
78
+ base_model,
79
+ subfolder="tokenizer",
80
+ torch_dtype=torch.float16,
81
+ )
82
+ unet = UNet2DConditionModel.from_pretrained(
83
+ base_model,
84
+ subfolder="unet",
85
+ torch_dtype=torch.float16,
86
+ )
87
+ pipe = StableDiffusionPipeline(
88
+ text_encoder=text_encoder,
89
+ tokenizer=tokenizer,
90
+ unet=unet,
91
+ vae=vae,
92
+ scheduler=scheduler,
93
+ )
94
+
95
+ unet.set_attn_processor(CrossAttnProcessor)
96
+ pipe.setup_text_encoder(clip_skip, text_encoder)
97
+ if torch.cuda.is_available():
98
+ pipe = pipe.to("cuda")
99
+
100
+ def get_model_list():
101
+ return models
102
+
103
+ te_cache = {
104
+ base_model: text_encoder
105
+ }
106
+
107
+ unet_cache = {
108
+ base_model: unet
109
+ }
110
+
111
+ lora_cache = {
112
+ base_model: LoRANetwork(text_encoder, unet)
113
+ }
114
+
115
+ te_base_weight_length = text_encoder.get_input_embeddings().weight.data.shape[0]
116
+ original_prepare_for_tokenization = tokenizer.prepare_for_tokenization
117
+ current_model = base_model
118
+
119
+ def setup_model(name, lora_state=None, lora_scale=1.0):
120
+ global pipe, current_model
121
+
122
+ keys = [k[0] for k in models]
123
+ model = models[keys.index(name)][1]
124
+ if model not in unet_cache:
125
+ unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet", torch_dtype=torch.float16)
126
+ text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder", torch_dtype=torch.float16)
127
+
128
+ unet_cache[model] = unet
129
+ te_cache[model] = text_encoder
130
+ lora_cache[model] = LoRANetwork(text_encoder, unet)
131
+
132
+ if current_model != model:
133
+ if current_model not in keep_vram:
134
+ # offload current model
135
+ unet_cache[current_model].to("cpu")
136
+ te_cache[current_model].to("cpu")
137
+ lora_cache[current_model].to("cpu")
138
+ current_model = model
139
+
140
+ local_te, local_unet, local_lora = te_cache[model], unet_cache[model], lora_cache[model]
141
+ local_unet.set_attn_processor(CrossAttnProcessor())
142
+ local_lora.reset()
143
+ clip_skip = models[keys.index(name)][2]
144
+
145
+ if torch.cuda.is_available():
146
+ local_unet.to("cuda")
147
+ local_te.to("cuda")
148
+
149
+ if lora_state is not None and lora_state != "":
150
+ local_lora.load(lora_state, lora_scale)
151
+ local_lora.to(local_unet.device, dtype=local_unet.dtype)
152
+
153
+ pipe.text_encoder, pipe.unet = local_te, local_unet
154
+ pipe.setup_unet(local_unet)
155
+ pipe.tokenizer.prepare_for_tokenization = original_prepare_for_tokenization
156
+ pipe.tokenizer.added_tokens_encoder = {}
157
+ pipe.tokenizer.added_tokens_decoder = {}
158
+ pipe.setup_text_encoder(clip_skip, local_te)
159
+ return pipe
160
+
161
+
162
+ def error_str(error, title="Error"):
163
+ return (
164
+ f"""#### {title}
165
+ {error}"""
166
+ if error
167
+ else ""
168
+ )
169
+
170
+ def make_token_names(embs):
171
+ all_tokens = []
172
+ for name, vec in embs.items():
173
+ tokens = [f'emb-{name}-{i}' for i in range(len(vec))]
174
+ all_tokens.append(tokens)
175
+ return all_tokens
176
+
177
+ def setup_tokenizer(tokenizer, embs):
178
+ reg_match = [re.compile(fr"(?:^|(?<=\s|,)){k}(?=,|\s|$)") for k in embs.keys()]
179
+ clip_keywords = [' '.join(s) for s in make_token_names(embs)]
180
+
181
+ def parse_prompt(prompt: str):
182
+ for m, v in zip(reg_match, clip_keywords):
183
+ prompt = m.sub(v, prompt)
184
+ return prompt
185
+
186
+ def prepare_for_tokenization(self, text: str, is_split_into_words: bool = False, **kwargs):
187
+ text = parse_prompt(text)
188
+ r = original_prepare_for_tokenization(text, is_split_into_words, **kwargs)
189
+ return r
190
+ tokenizer.prepare_for_tokenization = prepare_for_tokenization.__get__(tokenizer, CLIPTokenizer)
191
+ return [t for sublist in make_token_names(embs) for t in sublist]
192
+
193
+
194
+ def convert_size(size_bytes):
195
+ if size_bytes == 0:
196
+ return "0B"
197
+ size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
198
+ i = int(math.floor(math.log(size_bytes, 1024)))
199
+ p = math.pow(1024, i)
200
+ s = round(size_bytes / p, 2)
201
+ return "%s %s" % (s, size_name[i])
202
+
203
+ def inference(
204
+ prompt,
205
+ guidance,
206
+ steps,
207
+ width=512,
208
+ height=512,
209
+ seed=0,
210
+ neg_prompt="",
211
+ state=None,
212
+ g_strength=0.4,
213
+ img_input=None,
214
+ i2i_scale=0.5,
215
+ hr_enabled=False,
216
+ hr_method="Latent",
217
+ hr_scale=1.5,
218
+ hr_denoise=0.8,
219
+ sampler="DPM++ 2M Karras",
220
+ embs=None,
221
+ model=None,
222
+ lora_state=None,
223
+ lora_scale=None,
224
+ ):
225
+ if seed is None or seed == 0:
226
+ seed = random.randint(0, 2147483647)
227
+
228
+ pipe = setup_model(model, lora_state, lora_scale)
229
+ generator = torch.Generator("cuda").manual_seed(int(seed))
230
+ start_time = time.time()
231
+
232
+ sampler_name, sampler_opt = None, None
233
+ for label, funcname, options in samplers_k_diffusion:
234
+ if label == sampler:
235
+ sampler_name, sampler_opt = funcname, options
236
+
237
+ tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder
238
+ if embs is not None and len(embs) > 0:
239
+ ti_embs = {}
240
+ for name, file in embs.items():
241
+ if str(file).endswith(".pt"):
242
+ loaded_learned_embeds = torch.load(file, map_location="cpu")
243
+ else:
244
+ loaded_learned_embeds = load_file(file, device="cpu")
245
+ loaded_learned_embeds = loaded_learned_embeds["string_to_param"]["*"] if "string_to_param" in loaded_learned_embeds else loaded_learned_embeds
246
+ ti_embs[name] = loaded_learned_embeds
247
+
248
+ if len(ti_embs) > 0:
249
+ tokens = setup_tokenizer(tokenizer, ti_embs)
250
+ added_tokens = tokenizer.add_tokens(tokens)
251
+ delta_weight = torch.cat([val for val in ti_embs.values()], dim=0)
252
+
253
+ assert added_tokens == delta_weight.shape[0]
254
+ text_encoder.resize_token_embeddings(len(tokenizer))
255
+ token_embeds = text_encoder.get_input_embeddings().weight.data
256
+ token_embeds[-delta_weight.shape[0]:] = delta_weight
257
+
258
+ config = {
259
+ "negative_prompt": neg_prompt,
260
+ "num_inference_steps": int(steps),
261
+ "guidance_scale": guidance,
262
+ "generator": generator,
263
+ "sampler_name": sampler_name,
264
+ "sampler_opt": sampler_opt,
265
+ "pww_state": state,
266
+ "pww_attn_weight": g_strength,
267
+ "start_time": start_time,
268
+ "timeout": timeout,
269
+ }
270
+
271
+ if img_input is not None:
272
+ ratio = min(height / img_input.height, width / img_input.width)
273
+ img_input = img_input.resize(
274
+ (int(img_input.width * ratio), int(img_input.height * ratio)), Image.LANCZOS
275
+ )
276
+ result = pipe.img2img(prompt, image=img_input, strength=i2i_scale, **config)
277
+ elif hr_enabled:
278
+ result = pipe.txt2img(
279
+ prompt,
280
+ width=width,
281
+ height=height,
282
+ upscale=True,
283
+ upscale_x=hr_scale,
284
+ upscale_denoising_strength=hr_denoise,
285
+ **config,
286
+ **latent_upscale_modes[hr_method],
287
+ )
288
+ else:
289
+ result = pipe.txt2img(prompt, width=width, height=height, **config)
290
+
291
+ end_time = time.time()
292
+ vram_free, vram_total = torch.cuda.mem_get_info()
293
+ print(f"done: model={model}, res={width}x{height}, step={steps}, time={round(end_time-start_time, 2)}s, vram_alloc={convert_size(vram_total-vram_free)}/{convert_size(vram_total)}")
294
+ return gr.Image.update(result[0][0], label=f"Initial Seed: {seed}")
295
+
296
+
297
+ color_list = []
298
+
299
+
300
+ def get_color(n):
301
+ for _ in range(n - len(color_list)):
302
+ color_list.append(tuple(np.random.random(size=3) * 256))
303
+ return color_list
304
+
305
+
306
+ def create_mixed_img(current, state, w=512, h=512):
307
+ w, h = int(w), int(h)
308
+ image_np = np.full([h, w, 4], 255)
309
+ if state is None:
310
+ state = {}
311
+
312
+ colors = get_color(len(state))
313
+ idx = 0
314
+
315
+ for key, item in state.items():
316
+ if item["map"] is not None:
317
+ m = item["map"] < 255
318
+ alpha = 150
319
+ if current == key:
320
+ alpha = 200
321
+ image_np[m] = colors[idx] + (alpha,)
322
+ idx += 1
323
+
324
+ return image_np
325
+
326
+
327
+ # width.change(apply_new_res, inputs=[width, height, global_stats], outputs=[global_stats, sp, rendered])
328
+ def apply_new_res(w, h, state):
329
+ w, h = int(w), int(h)
330
+
331
+ for key, item in state.items():
332
+ if item["map"] is not None:
333
+ item["map"] = resize(item["map"], w, h)
334
+
335
+ update_img = gr.Image.update(value=create_mixed_img("", state, w, h))
336
+ return state, update_img
337
+
338
+
339
+ def detect_text(text, state, width, height):
340
+
341
+ if text is None or text == "":
342
+ return None, None, gr.Radio.update(value=None), None
343
+
344
+ t = text.split(",")
345
+ new_state = {}
346
+
347
+ for item in t:
348
+ item = item.strip()
349
+ if item == "":
350
+ continue
351
+ if state is not None and item in state:
352
+ new_state[item] = {
353
+ "map": state[item]["map"],
354
+ "weight": state[item]["weight"],
355
+ "mask_outsides": state[item]["mask_outsides"],
356
+ }
357
+ else:
358
+ new_state[item] = {
359
+ "map": None,
360
+ "weight": 0.5,
361
+ "mask_outsides": False
362
+ }
363
+ update = gr.Radio.update(choices=[key for key in new_state.keys()], value=None)
364
+ update_img = gr.update(value=create_mixed_img("", new_state, width, height))
365
+ update_sketch = gr.update(value=None, interactive=False)
366
+ return new_state, update_sketch, update, update_img
367
+
368
+
369
+ def resize(img, w, h):
370
+ trs = transforms.Compose(
371
+ [
372
+ transforms.ToPILImage(),
373
+ transforms.Resize(min(h, w)),
374
+ transforms.CenterCrop((h, w)),
375
+ ]
376
+ )
377
+ result = np.array(trs(img), dtype=np.uint8)
378
+ return result
379
+
380
+
381
+ def switch_canvas(entry, state, width, height):
382
+ if entry is None:
383
+ return None, 0.5, False, create_mixed_img("", state, width, height)
384
+
385
+ return (
386
+ gr.update(value=None, interactive=True),
387
+ gr.update(value=state[entry]["weight"] if entry in state else 0.5),
388
+ gr.update(value=state[entry]["mask_outsides"] if entry in state else False),
389
+ create_mixed_img(entry, state, width, height),
390
+ )
391
+
392
+
393
+ def apply_canvas(selected, draw, state, w, h):
394
+ if selected in state:
395
+ w, h = int(w), int(h)
396
+ state[selected]["map"] = resize(draw, w, h)
397
+ return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))
398
+
399
+
400
+ def apply_weight(selected, weight, state):
401
+ if selected in state:
402
+ state[selected]["weight"] = weight
403
+ return state
404
+
405
+
406
+ def apply_option(selected, mask, state):
407
+ if selected in state:
408
+ state[selected]["mask_outsides"] = mask
409
+ return state
410
+
411
+
412
+ # sp2, radio, width, height, global_stats
413
+ def apply_image(image, selected, w, h, strength, mask, state):
414
+ if selected in state:
415
+ state[selected] = {
416
+ "map": resize(image, w, h),
417
+ "weight": strength,
418
+ "mask_outsides": mask
419
+ }
420
+
421
+ return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))
422
+
423
+
424
+ # [ti_state, lora_state, ti_vals, lora_vals, uploads]
425
+ def add_net(files, ti_state, lora_state):
426
+ if files is None:
427
+ return ti_state, "", lora_state, None
428
+
429
+ for file in files:
430
+ item = Path(file.name)
431
+ stripedname = str(item.stem).strip()
432
+ if item.suffix == ".pt":
433
+ state_dict = torch.load(file.name, map_location="cpu")
434
+ else:
435
+ state_dict = load_file(file.name, device="cpu")
436
+ if any("lora" in k for k in state_dict.keys()):
437
+ lora_state = file.name
438
+ else:
439
+ ti_state[stripedname] = file.name
440
+
441
+ return (
442
+ ti_state,
443
+ lora_state,
444
+ gr.Text.update(f"{[key for key in ti_state.keys()]}"),
445
+ gr.Text.update(f"{lora_state}"),
446
+ gr.Files.update(value=None),
447
+ )
448
+
449
+
450
+ # [ti_state, lora_state, ti_vals, lora_vals, uploads]
451
+ def clean_states(ti_state, lora_state):
452
+ return (
453
+ dict(),
454
+ None,
455
+ gr.Text.update(f""),
456
+ gr.Text.update(f""),
457
+ gr.File.update(value=None),
458
+ )
459
+
460
+
461
+ latent_upscale_modes = {
462
+ "Latent": {"upscale_method": "bilinear", "upscale_antialias": False},
463
+ "Latent (antialiased)": {"upscale_method": "bilinear", "upscale_antialias": True},
464
+ "Latent (bicubic)": {"upscale_method": "bicubic", "upscale_antialias": False},
465
+ "Latent (bicubic antialiased)": {
466
+ "upscale_method": "bicubic",
467
+ "upscale_antialias": True,
468
+ },
469
+ "Latent (nearest)": {"upscale_method": "nearest", "upscale_antialias": False},
470
+ "Latent (nearest-exact)": {
471
+ "upscale_method": "nearest-exact",
472
+ "upscale_antialias": False,
473
+ },
474
+ }
475
+
476
+ css = """
477
+ .finetuned-diffusion-div div{
478
+ display:inline-flex;
479
+ align-items:center;
480
+ gap:.8rem;
481
+ font-size:1.75rem;
482
+ padding-top:2rem;
483
+ }
484
+ .finetuned-diffusion-div div h1{
485
+ font-weight:900;
486
+ margin-bottom:7px
487
+ }
488
+ .finetuned-diffusion-div p{
489
+ margin-bottom:10px;
490
+ font-size:94%
491
+ }
492
+ .box {
493
+ float: left;
494
+ height: 20px;
495
+ width: 20px;
496
+ margin-bottom: 15px;
497
+ border: 1px solid black;
498
+ clear: both;
499
+ }
500
+ a{
501
+ text-decoration:underline
502
+ }
503
+ .tabs{
504
+ margin-top:0;
505
+ margin-bottom:0
506
+ }
507
+ #gallery{
508
+ min-height:20rem
509
+ }
510
+ .no-border {
511
+ border: none !important;
512
+ }
513
+ """
514
+ with gr.Blocks(css=css) as demo:
515
+ gr.HTML(
516
+ f"""
517
+ <div class="finetuned-diffusion-div">
518
+ <div>
519
+ <h1>Demo for diffusion models</h1>
520
+ </div>
521
+ <p>Hso @ nyanko.sketch2img.gradio</p>
522
+ </div>
523
+ """
524
+ )
525
+ global_stats = gr.State(value={})
526
+
527
+ with gr.Row():
528
+
529
+ with gr.Column(scale=55):
530
+ model = gr.Dropdown(
531
+ choices=[k[0] for k in get_model_list()],
532
+ label="Model",
533
+ value=base_name,
534
+ )
535
+ image_out = gr.Image(height=512)
536
+ # gallery = gr.Gallery(
537
+ # label="Generated images", show_label=False, elem_id="gallery"
538
+ # ).style(grid=[1], height="auto")
539
+
540
+ with gr.Column(scale=45):
541
+
542
+ with gr.Group():
543
+
544
+ with gr.Row():
545
+ with gr.Column(scale=70):
546
+
547
+ prompt = gr.Textbox(
548
+ label="Prompt",
549
+ value="best quality, masterpiece, highres, an extremely delicate and beautiful, original, extremely detailed wallpaper, highres , 1girl",
550
+ show_label=True,
551
+ max_lines=4,
552
+ placeholder="Enter prompt.",
553
+ )
554
+ neg_prompt = gr.Textbox(
555
+ label="Negative Prompt",
556
+ value="simple background,monochrome, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits,twisting jawline, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, lowres, bad anatomy, bad hands, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, ugly,pregnant,vore,duplicate,morbid,mutilated,transsexual, hermaphrodite,long neck,mutated hands,poorly drawn hands,poorly drawn face,mutation,deformed,blurry,bad anatomy,bad proportions,malformed limbs,extra limbs,cloned face,disfigured,gross proportions, missing arms, missing legs, extra arms,extra legs,pubic hair, plump,bad legs,error legs,username,blurry,bad feet",
557
+ show_label=True,
558
+ max_lines=4,
559
+ placeholder="Enter negative prompt.",
560
+ )
561
+
562
+ generate = gr.Button(value="Generate").style(
563
+ rounded=(False, True, True, False)
564
+ )
565
+
566
+ with gr.Tab("Options"):
567
+
568
+ with gr.Group():
569
+
570
+ # n_images = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1)
571
+ with gr.Row():
572
+ guidance = gr.Slider(
573
+ label="Guidance scale", value=7.5, maximum=15
574
+ )
575
+ steps = gr.Slider(
576
+ label="Steps", value=25, minimum=2, maximum=50, step=1
577
+ )
578
+
579
+ with gr.Row():
580
+ width = gr.Slider(
581
+ label="Width", value=512, minimum=64, maximum=1024, step=64
582
+ )
583
+ height = gr.Slider(
584
+ label="Height", value=512, minimum=64, maximum=1024, step=64
585
+ )
586
+
587
+ sampler = gr.Dropdown(
588
+ value="DPM++ 2M Karras",
589
+ label="Sampler",
590
+ choices=[s[0] for s in samplers_k_diffusion],
591
+ )
592
+ seed = gr.Number(label="Seed (0 = random)", value=0)
593
+
594
+ with gr.Tab("Image to image"):
595
+ with gr.Group():
596
+
597
+ inf_image = gr.Image(
598
+ label="Image", height=256, tool="editor", type="pil"
599
+ )
600
+ inf_strength = gr.Slider(
601
+ label="Transformation strength",
602
+ minimum=0,
603
+ maximum=1,
604
+ step=0.01,
605
+ value=0.5,
606
+ )
607
+
608
+ def res_cap(g, w, h, x):
609
+ if g:
610
+ return f"Enable upscaler: {w}x{h} to {int(w*x)}x{int(h*x)}"
611
+ else:
612
+ return "Enable upscaler"
613
+
614
+ with gr.Tab("Hires fix"):
615
+ with gr.Group():
616
+
617
+ hr_enabled = gr.Checkbox(label="Enable upscaler", value=False)
618
+ hr_method = gr.Dropdown(
619
+ [key for key in latent_upscale_modes.keys()],
620
+ value="Latent",
621
+ label="Upscale method",
622
+ )
623
+ hr_scale = gr.Slider(
624
+ label="Upscale factor",
625
+ minimum=1.0,
626
+ maximum=2.0,
627
+ step=0.1,
628
+ value=1.5,
629
+ )
630
+ hr_denoise = gr.Slider(
631
+ label="Denoising strength",
632
+ minimum=0.0,
633
+ maximum=1.0,
634
+ step=0.1,
635
+ value=0.8,
636
+ )
637
+
638
+ hr_scale.change(
639
+ lambda g, x, w, h: gr.Checkbox.update(
640
+ label=res_cap(g, w, h, x)
641
+ ),
642
+ inputs=[hr_enabled, hr_scale, width, height],
643
+ outputs=hr_enabled,
644
+ queue=False,
645
+ )
646
+ hr_enabled.change(
647
+ lambda g, x, w, h: gr.Checkbox.update(
648
+ label=res_cap(g, w, h, x)
649
+ ),
650
+ inputs=[hr_enabled, hr_scale, width, height],
651
+ outputs=hr_enabled,
652
+ queue=False,
653
+ )
654
+
655
+ with gr.Tab("Embeddings/Loras"):
656
+
657
+ ti_state = gr.State(dict())
658
+ lora_state = gr.State()
659
+
660
+ with gr.Group():
661
+ with gr.Row():
662
+ with gr.Column(scale=90):
663
+ ti_vals = gr.Text(label="Loaded embeddings")
664
+
665
+ with gr.Row():
666
+ with gr.Column(scale=90):
667
+ lora_vals = gr.Text(label="Loaded loras")
668
+
669
+ with gr.Row():
670
+
671
+ uploads = gr.Files(label="Upload new embeddings/lora")
672
+
673
+ with gr.Column():
674
+ lora_scale = gr.Slider(
675
+ label="Lora scale",
676
+ minimum=0,
677
+ maximum=2,
678
+ step=0.01,
679
+ value=1.0,
680
+ )
681
+ btn = gr.Button(value="Upload")
682
+ btn_del = gr.Button(value="Reset")
683
+
684
+ btn.click(
685
+ add_net,
686
+ inputs=[uploads, ti_state, lora_state],
687
+ outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads],
688
+ queue=False,
689
+ )
690
+ btn_del.click(
691
+ clean_states,
692
+ inputs=[ti_state, lora_state],
693
+ outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads],
694
+ queue=False,
695
+ )
696
+
697
+ # error_output = gr.Markdown()
698
+
699
+ gr.HTML(
700
+ f"""
701
+ <div class="finetuned-diffusion-div">
702
+ <div>
703
+ <h1>Paint with words</h1>
704
+ </div>
705
+ <p>
706
+ Will use the following formula: w = scale * token_weight_matrix * log(1 + sigma) * max(qk).
707
+ </p>
708
+ </div>
709
+ """
710
+ )
711
+
712
+ with gr.Row():
713
+
714
+ with gr.Column(scale=55):
715
+
716
+ rendered = gr.Image(
717
+ invert_colors=True,
718
+ source="canvas",
719
+ interactive=False,
720
+ image_mode="RGBA",
721
+ )
722
+
723
+ with gr.Column(scale=45):
724
+
725
+ with gr.Group():
726
+ with gr.Row():
727
+ with gr.Column(scale=70):
728
+ g_strength = gr.Slider(
729
+ label="Weight scaling",
730
+ minimum=0,
731
+ maximum=0.8,
732
+ step=0.01,
733
+ value=0.4,
734
+ )
735
+
736
+ text = gr.Textbox(
737
+ lines=2,
738
+ interactive=True,
739
+ label="Token to Draw: (Separate by comma)",
740
+ )
741
+
742
+ radio = gr.Radio([], label="Tokens")
743
+
744
+ sk_update = gr.Button(value="Update").style(
745
+ rounded=(False, True, True, False)
746
+ )
747
+
748
+ # g_strength.change(lambda b: gr.update(f"Scaled additional attn: $w = {b} \log (1 + \sigma) \std (Q^T K)$."), inputs=g_strength, outputs=[g_output])
749
+
750
+ with gr.Tab("SketchPad"):
751
+
752
+ sp = gr.Image(
753
+ image_mode="L",
754
+ tool="sketch",
755
+ source="canvas",
756
+ interactive=False,
757
+ )
758
+
759
+ mask_outsides = gr.Checkbox(
760
+ label="Mask other areas",
761
+ value=False
762
+ )
763
+
764
+ strength = gr.Slider(
765
+ label="Token strength",
766
+ minimum=0,
767
+ maximum=0.8,
768
+ step=0.01,
769
+ value=0.5,
770
+ )
771
+
772
+
773
+ sk_update.click(
774
+ detect_text,
775
+ inputs=[text, global_stats, width, height],
776
+ outputs=[global_stats, sp, radio, rendered],
777
+ queue=False,
778
+ )
779
+ radio.change(
780
+ switch_canvas,
781
+ inputs=[radio, global_stats, width, height],
782
+ outputs=[sp, strength, mask_outsides, rendered],
783
+ queue=False,
784
+ )
785
+ sp.edit(
786
+ apply_canvas,
787
+ inputs=[radio, sp, global_stats, width, height],
788
+ outputs=[global_stats, rendered],
789
+ queue=False,
790
+ )
791
+ strength.change(
792
+ apply_weight,
793
+ inputs=[radio, strength, global_stats],
794
+ outputs=[global_stats],
795
+ queue=False,
796
+ )
797
+ mask_outsides.change(
798
+ apply_option,
799
+ inputs=[radio, mask_outsides, global_stats],
800
+ outputs=[global_stats],
801
+ queue=False,
802
+ )
803
+
804
+ with gr.Tab("UploadFile"):
805
+
806
+ sp2 = gr.Image(
807
+ image_mode="L",
808
+ source="upload",
809
+ shape=(512, 512),
810
+ )
811
+
812
+ mask_outsides2 = gr.Checkbox(
813
+ label="Mask other areas",
814
+ value=False,
815
+ )
816
+
817
+ strength2 = gr.Slider(
818
+ label="Token strength",
819
+ minimum=0,
820
+ maximum=0.8,
821
+ step=0.01,
822
+ value=0.5,
823
+ )
824
+
825
+ apply_style = gr.Button(value="Apply")
826
+ apply_style.click(
827
+ apply_image,
828
+ inputs=[sp2, radio, width, height, strength2, mask_outsides2, global_stats],
829
+ outputs=[global_stats, rendered],
830
+ queue=False,
831
+ )
832
+
833
+ width.change(
834
+ apply_new_res,
835
+ inputs=[width, height, global_stats],
836
+ outputs=[global_stats, rendered],
837
+ queue=False,
838
+ )
839
+ height.change(
840
+ apply_new_res,
841
+ inputs=[width, height, global_stats],
842
+ outputs=[global_stats, rendered],
843
+ queue=False,
844
+ )
845
+
846
+ # color_stats = gr.State(value={})
847
+ # text.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])
848
+ # sp.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])
849
+
850
+ inputs = [
851
+ prompt,
852
+ guidance,
853
+ steps,
854
+ width,
855
+ height,
856
+ seed,
857
+ neg_prompt,
858
+ global_stats,
859
+ g_strength,
860
+ inf_image,
861
+ inf_strength,
862
+ hr_enabled,
863
+ hr_method,
864
+ hr_scale,
865
+ hr_denoise,
866
+ sampler,
867
+ ti_state,
868
+ model,
869
+ lora_state,
870
+ lora_scale,
871
+ ]
872
+ outputs = [image_out]
873
+ prompt.submit(inference, inputs=inputs, outputs=outputs)
874
+ generate.click(inference, inputs=inputs, outputs=outputs)
875
+
876
+ print(f"Space built in {time.time() - start_time:.2f} seconds")
877
+ # demo.launch(share=True)
878
+ demo.launch(enable_queue=True, server_name="0.0.0.0", server_port=7860)
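
The "Paint with words" panel above states the weighting rule w = scale * token_weight_matrix * log(1 + sigma) * max(qk), where scale is the "Weight scaling" slider (g_strength). The actual weighting lives in modules/model.py; the snippet below is only a minimal standalone sketch of that formula, with hypothetical names and toy shapes.

```python
import math
import torch

def pww_attention_bias(token_weight_matrix: torch.Tensor,
                       qk_scores: torch.Tensor,
                       sigma: float,
                       scale: float = 0.4) -> torch.Tensor:
    # w = scale * token_weight_matrix * log(1 + sigma) * max(qk)
    # token_weight_matrix: per-region weights painted in the UI (0 where the token is not drawn)
    # qk_scores: raw query-key attention scores before softmax
    # sigma: the sampler's current noise level; scale: the g_strength slider value
    return scale * token_weight_matrix * math.log(1.0 + sigma) * qk_scores.max()

# Toy usage: bias a 4x4 score map toward a painted top-left region.
scores = torch.randn(4, 4)
mask = torch.zeros(4, 4)
mask[:2, :2] = 0.5  # token strength painted over the top-left quadrant
biased_scores = scores + pww_attention_bias(mask, scores, sigma=1.0)
```
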
modules/lora.py ADDED
@@ -0,0 +1,183 @@
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py#L48
6
+
7
+ import math
8
+ import os
9
+ import torch
10
+ import modules.safe as _
11
+ from safetensors.torch import load_file
12
+
13
+
14
+ class LoRAModule(torch.nn.Module):
15
+ """
16
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ lora_name,
22
+ org_module: torch.nn.Module,
23
+ multiplier=1.0,
24
+ lora_dim=4,
25
+ alpha=1,
26
+ ):
27
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
28
+ super().__init__()
29
+ self.lora_name = lora_name
30
+ self.lora_dim = lora_dim
31
+
32
+ if org_module.__class__.__name__ == "Conv2d":
33
+ in_dim = org_module.in_channels
34
+ out_dim = org_module.out_channels
35
+ self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
36
+ self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
37
+ else:
38
+ in_dim = org_module.in_features
39
+ out_dim = org_module.out_features
40
+ self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
41
+ self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
42
+
43
+ if type(alpha) == torch.Tensor:
44
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
45
+
46
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
47
+ self.scale = alpha / self.lora_dim
48
+ self.register_buffer("alpha", torch.tensor(alpha)) # treated as a constant
49
+
50
+ # same as microsoft's
51
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
52
+ torch.nn.init.zeros_(self.lora_up.weight)
53
+
54
+ self.multiplier = multiplier
55
+ self.org_module = org_module # remove in applying
56
+ self.enable = False
57
+
58
+ def resize(self, rank, alpha, multiplier):
59
+ self.alpha = torch.tensor(alpha)
60
+ self.multiplier = multiplier
61
+ self.scale = alpha / rank
62
+ if self.lora_down.__class__.__name__ == "Conv2d":
63
+ in_dim = self.lora_down.in_channels
64
+ out_dim = self.lora_up.out_channels
65
+ self.lora_down = torch.nn.Conv2d(in_dim, rank, (1, 1), bias=False)
66
+ self.lora_up = torch.nn.Conv2d(rank, out_dim, (1, 1), bias=False)
67
+ else:
68
+ in_dim = self.lora_down.in_features
69
+ out_dim = self.lora_up.out_features
70
+ self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
71
+ self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)
72
+
73
+ def apply(self):
74
+ if hasattr(self, "org_module"):
75
+ self.org_forward = self.org_module.forward
76
+ self.org_module.forward = self.forward
77
+ del self.org_module
78
+
79
+ def forward(self, x):
80
+ if self.enable:
81
+ return (
82
+ self.org_forward(x)
83
+ + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
84
+ )
85
+ return self.org_forward(x)
86
+
87
+
88
+ class LoRANetwork(torch.nn.Module):
89
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
90
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
91
+ LORA_PREFIX_UNET = "lora_unet"
92
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
93
+
94
+ def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
95
+ super().__init__()
96
+ self.multiplier = multiplier
97
+ self.lora_dim = lora_dim
98
+ self.alpha = alpha
99
+
100
+ # create module instances
101
+ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules):
102
+ loras = []
103
+ for name, module in root_module.named_modules():
104
+ if module.__class__.__name__ in target_replace_modules:
105
+ for child_name, child_module in module.named_modules():
106
+ if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
107
+ lora_name = prefix + "." + name + "." + child_name
108
+ lora_name = lora_name.replace(".", "_")
109
+ lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha,)
110
+ loras.append(lora)
111
+ return loras
112
+
113
+ if isinstance(text_encoder, list):
114
+ self.text_encoder_loras = text_encoder
115
+ else:
116
+ self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
117
+ print(f"Create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
118
+
119
+ self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
120
+ print(f"Create LoRA for U-Net: {len(self.unet_loras)} modules.")
121
+
122
+ self.weights_sd = None
123
+
124
+ # assertion
125
+ names = set()
126
+ for lora in self.text_encoder_loras + self.unet_loras:
127
+ assert (lora.lora_name not in names), f"duplicated lora name: {lora.lora_name}"
128
+ names.add(lora.lora_name)
129
+
130
+ lora.apply()
131
+ self.add_module(lora.lora_name, lora)
132
+
133
+ def reset(self):
134
+ for lora in self.text_encoder_loras + self.unet_loras:
135
+ lora.enable = False
136
+
137
+ def load(self, file, scale):
138
+
139
+ weights = None
140
+ if os.path.splitext(file)[1] == ".safetensors":
141
+ weights = load_file(file)
142
+ else:
143
+ weights = torch.load(file, map_location="cpu")
144
+
145
+ if not weights:
146
+ return
147
+
148
+ network_alpha = None
149
+ network_dim = None
150
+ for key, value in weights.items():
151
+ if network_alpha is None and "alpha" in key:
152
+ network_alpha = value
153
+ if network_dim is None and "lora_down" in key and len(value.size()) == 2:
154
+ network_dim = value.size()[0]
155
+
156
+ if network_alpha is None:
157
+ network_alpha = network_dim
158
+
159
+ weights_has_text_encoder = weights_has_unet = False
160
+ weights_to_modify = []
161
+
162
+ for key in weights.keys():
163
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
164
+ weights_has_text_encoder = True
165
+
166
+ if key.startswith(LoRANetwork.LORA_PREFIX_UNET):
167
+ weights_has_unet = True
168
+
169
+ if weights_has_text_encoder:
170
+ weights_to_modify += self.text_encoder_loras
171
+
172
+ if weights_has_unet:
173
+ weights_to_modify += self.unet_loras
174
+
175
+ for lora in self.text_encoder_loras + self.unet_loras:
176
+ lora.resize(network_dim, network_alpha, scale)
177
+ if lora in weights_to_modify:
178
+ lora.enable = True
179
+
180
+ info = self.load_state_dict(weights, False)
181
+ if len(info.unexpected_keys) > 0:
182
+ print(f"Weights are loaded. Unexpected keys={info.unexpected_keys}")
183
+
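
Each LoRAModule above composes with the frozen layer as y = org_forward(x) + lora_up(lora_down(x)) * multiplier * (alpha / lora_dim). Below is a minimal standalone sketch of that composition, with hypothetical shapes (the real rank and alpha come from the weights parsed in LoRANetwork.load).

```python
import torch

in_dim, out_dim, rank, alpha, multiplier = 768, 320, 4, 1.0, 1.0  # hypothetical values
x = torch.randn(1, in_dim)

org = torch.nn.Linear(in_dim, out_dim, bias=False)   # frozen original layer
down = torch.nn.Linear(in_dim, rank, bias=False)     # lora_down (kaiming-initialized in LoRAModule)
up = torch.nn.Linear(rank, out_dim, bias=False)      # lora_up (zero-initialized, so the delta starts at 0)
torch.nn.init.zeros_(up.weight)

scale = alpha / rank                                  # LoRAModule.scale
y = org(x) + up(down(x)) * multiplier * scale         # same composition as LoRAModule.forward
assert torch.allclose(y, org(x))                      # zero-init lora_up => no change until real weights load
```
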
modules/model.py ADDED
@@ -0,0 +1,897 @@