Spaces: Runtime error
Jeffiyyyy committed • Commit 90ad87a • Parent(s): 97817d3
demo
Browse files
- .DS_Store +0 -0
- README.md +6 -6
- app.py +878 -0
- modules/lora.py +183 -0
- modules/model.py +897 -0
- modules/prompt_parser.py +391 -0
- modules/safe.py +188 -0
.DS_Store
ADDED
Binary file (6.15 kB)
README.md
CHANGED
@@ -1,13 +1,13 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: LSP LearningandStrivePartner Model
+emoji: 🐠
+colorFrom: green
+colorTo: green
 sdk: gradio
-sdk_version: 3.
+sdk_version: 3.17.0
 app_file: app.py
 pinned: false
-license:
+license: afl-3.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,878 @@
1 |
+
import random
|
2 |
+
import tempfile
|
3 |
+
import time
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import math
|
8 |
+
import re
|
9 |
+
|
10 |
+
from gradio import inputs
|
11 |
+
from diffusers import (
|
12 |
+
AutoencoderKL,
|
13 |
+
DDIMScheduler,
|
14 |
+
UNet2DConditionModel,
|
15 |
+
)
|
16 |
+
from modules.model import (
|
17 |
+
CrossAttnProcessor,
|
18 |
+
StableDiffusionPipeline,
|
19 |
+
)
|
20 |
+
from torchvision import transforms
|
21 |
+
from transformers import CLIPTokenizer, CLIPTextModel
|
22 |
+
from PIL import Image
|
23 |
+
from pathlib import Path
|
24 |
+
from safetensors.torch import load_file
|
25 |
+
import modules.safe as _
|
26 |
+
from modules.lora import LoRANetwork
|
27 |
+
|
28 |
+
models = [
|
29 |
+
("LSPV1", "Jeffsun/LSP", 2),
|
30 |
+
("Pastal Mix", "andite/pastel-mix", 2),
|
31 |
+
("Basil Mix", "nuigurumi/basil_mix", 2)
|
32 |
+
]
|
33 |
+
|
34 |
+
keep_vram = ["Korakoe/AbyssOrangeMix2-HF", "andite/pastel-mix"]
|
35 |
+
base_name, base_model, clip_skip = models[0]
|
36 |
+
|
37 |
+
samplers_k_diffusion = [
|
38 |
+
("Euler a", "sample_euler_ancestral", {}),
|
39 |
+
("Euler", "sample_euler", {}),
|
40 |
+
("LMS", "sample_lms", {}),
|
41 |
+
("Heun", "sample_heun", {}),
|
42 |
+
("DPM2", "sample_dpm_2", {"discard_next_to_last_sigma": True}),
|
43 |
+
("DPM2 a", "sample_dpm_2_ancestral", {"discard_next_to_last_sigma": True}),
|
44 |
+
("DPM++ 2S a", "sample_dpmpp_2s_ancestral", {}),
|
45 |
+
("DPM++ 2M", "sample_dpmpp_2m", {}),
|
46 |
+
("DPM++ SDE", "sample_dpmpp_sde", {}),
|
47 |
+
("LMS Karras", "sample_lms", {"scheduler": "karras"}),
|
48 |
+
("DPM2 Karras", "sample_dpm_2", {"scheduler": "karras", "discard_next_to_last_sigma": True}),
|
49 |
+
("DPM2 a Karras", "sample_dpm_2_ancestral", {"scheduler": "karras", "discard_next_to_last_sigma": True}),
|
50 |
+
("DPM++ 2S a Karras", "sample_dpmpp_2s_ancestral", {"scheduler": "karras"}),
|
51 |
+
("DPM++ 2M Karras", "sample_dpmpp_2m", {"scheduler": "karras"}),
|
52 |
+
("DPM++ SDE Karras", "sample_dpmpp_sde", {"scheduler": "karras"}),
|
53 |
+
]
|
54 |
+
|
55 |
+
# samplers_diffusers = [
|
56 |
+
# ("DDIMScheduler", "diffusers.schedulers.DDIMScheduler", {})
|
57 |
+
# ("DDPMScheduler", "diffusers.schedulers.DDPMScheduler", {})
|
58 |
+
# ("DEISMultistepScheduler", "diffusers.schedulers.DEISMultistepScheduler", {})
|
59 |
+
# ]
|
60 |
+
|
61 |
+
start_time = time.time()
|
62 |
+
timeout = 90
|
63 |
+
|
64 |
+
scheduler = DDIMScheduler.from_pretrained(
|
65 |
+
base_model,
|
66 |
+
subfolder="scheduler",
|
67 |
+
)
|
68 |
+
vae = AutoencoderKL.from_pretrained(
|
69 |
+
"stabilityai/sd-vae-ft-ema",
|
70 |
+
torch_dtype=torch.float16
|
71 |
+
)
|
72 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
73 |
+
base_model,
|
74 |
+
subfolder="text_encoder",
|
75 |
+
torch_dtype=torch.float16,
|
76 |
+
)
|
77 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
78 |
+
base_model,
|
79 |
+
subfolder="tokenizer",
|
80 |
+
torch_dtype=torch.float16,
|
81 |
+
)
|
82 |
+
unet = UNet2DConditionModel.from_pretrained(
|
83 |
+
base_model,
|
84 |
+
subfolder="unet",
|
85 |
+
torch_dtype=torch.float16,
|
86 |
+
)
|
87 |
+
pipe = StableDiffusionPipeline(
|
88 |
+
text_encoder=text_encoder,
|
89 |
+
tokenizer=tokenizer,
|
90 |
+
unet=unet,
|
91 |
+
vae=vae,
|
92 |
+
scheduler=scheduler,
|
93 |
+
)
|
94 |
+
|
95 |
+
unet.set_attn_processor(CrossAttnProcessor)
|
96 |
+
pipe.setup_text_encoder(clip_skip, text_encoder)
|
97 |
+
if torch.cuda.is_available():
|
98 |
+
pipe = pipe.to("cuda")
|
99 |
+
|
100 |
+
def get_model_list():
|
101 |
+
return models
|
102 |
+
|
103 |
+
te_cache = {
|
104 |
+
base_model: text_encoder
|
105 |
+
}
|
106 |
+
|
107 |
+
unet_cache = {
|
108 |
+
base_model: unet
|
109 |
+
}
|
110 |
+
|
111 |
+
lora_cache = {
|
112 |
+
base_model: LoRANetwork(text_encoder, unet)
|
113 |
+
}
|
114 |
+
|
115 |
+
te_base_weight_length = text_encoder.get_input_embeddings().weight.data.shape[0]
|
116 |
+
original_prepare_for_tokenization = tokenizer.prepare_for_tokenization
|
117 |
+
current_model = base_model
|
118 |
+
|
119 |
+
def setup_model(name, lora_state=None, lora_scale=1.0):
|
120 |
+
global pipe, current_model
|
121 |
+
|
122 |
+
keys = [k[0] for k in models]
|
123 |
+
model = models[keys.index(name)][1]
|
124 |
+
if model not in unet_cache:
|
125 |
+
unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet", torch_dtype=torch.float16)
|
126 |
+
text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder", torch_dtype=torch.float16)
|
127 |
+
|
128 |
+
unet_cache[model] = unet
|
129 |
+
te_cache[model] = text_encoder
|
130 |
+
lora_cache[model] = LoRANetwork(text_encoder, unet)
|
131 |
+
|
132 |
+
if current_model != model:
|
133 |
+
if current_model not in keep_vram:
|
134 |
+
# offload current model
|
135 |
+
unet_cache[current_model].to("cpu")
|
136 |
+
te_cache[current_model].to("cpu")
|
137 |
+
lora_cache[current_model].to("cpu")
|
138 |
+
current_model = model
|
139 |
+
|
140 |
+
local_te, local_unet, local_lora, = te_cache[model], unet_cache[model], lora_cache[model]
|
141 |
+
local_unet.set_attn_processor(CrossAttnProcessor())
|
142 |
+
local_lora.reset()
|
143 |
+
clip_skip = models[keys.index(name)][2]
|
144 |
+
|
145 |
+
if torch.cuda.is_available():
|
146 |
+
local_unet.to("cuda")
|
147 |
+
local_te.to("cuda")
|
148 |
+
|
149 |
+
if lora_state is not None and lora_state != "":
|
150 |
+
local_lora.load(lora_state, lora_scale)
|
151 |
+
local_lora.to(local_unet.device, dtype=local_unet.dtype)
|
152 |
+
|
153 |
+
pipe.text_encoder, pipe.unet = local_te, local_unet
|
154 |
+
pipe.setup_unet(local_unet)
|
155 |
+
pipe.tokenizer.prepare_for_tokenization = original_prepare_for_tokenization
|
156 |
+
pipe.tokenizer.added_tokens_encoder = {}
|
157 |
+
pipe.tokenizer.added_tokens_decoder = {}
|
158 |
+
pipe.setup_text_encoder(clip_skip, local_te)
|
159 |
+
return pipe
|
160 |
+
|
161 |
+
|
162 |
+
def error_str(error, title="Error"):
|
163 |
+
return (
|
164 |
+
f"""#### {title}
|
165 |
+
{error}"""
|
166 |
+
if error
|
167 |
+
else ""
|
168 |
+
)
|
169 |
+
|
170 |
+
def make_token_names(embs):
|
171 |
+
all_tokens = []
|
172 |
+
for name, vec in embs.items():
|
173 |
+
tokens = [f'emb-{name}-{i}' for i in range(len(vec))]
|
174 |
+
all_tokens.append(tokens)
|
175 |
+
return all_tokens
|
176 |
+
|
177 |
+
def setup_tokenizer(tokenizer, embs):
|
178 |
+
reg_match = [re.compile(fr"(?:^|(?<=\s|,)){k}(?=,|\s|$)") for k in embs.keys()]
|
179 |
+
clip_keywords = [' '.join(s) for s in make_token_names(embs)]
|
180 |
+
|
181 |
+
def parse_prompt(prompt: str):
|
182 |
+
for m, v in zip(reg_match, clip_keywords):
|
183 |
+
prompt = m.sub(v, prompt)
|
184 |
+
return prompt
|
185 |
+
|
186 |
+
def prepare_for_tokenization(self, text: str, is_split_into_words: bool = False, **kwargs):
|
187 |
+
text = parse_prompt(text)
|
188 |
+
r = original_prepare_for_tokenization(text, is_split_into_words, **kwargs)
|
189 |
+
return r
|
190 |
+
tokenizer.prepare_for_tokenization = prepare_for_tokenization.__get__(tokenizer, CLIPTokenizer)
|
191 |
+
return [t for sublist in make_token_names(embs) for t in sublist]
|
192 |
+
|
193 |
+
|
194 |
+
def convert_size(size_bytes):
|
195 |
+
if size_bytes == 0:
|
196 |
+
return "0B"
|
197 |
+
size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
|
198 |
+
i = int(math.floor(math.log(size_bytes, 1024)))
|
199 |
+
p = math.pow(1024, i)
|
200 |
+
s = round(size_bytes / p, 2)
|
201 |
+
return "%s %s" % (s, size_name[i])
|
202 |
+
|
203 |
+
def inference(
|
204 |
+
prompt,
|
205 |
+
guidance,
|
206 |
+
steps,
|
207 |
+
width=512,
|
208 |
+
height=512,
|
209 |
+
seed=0,
|
210 |
+
neg_prompt="",
|
211 |
+
state=None,
|
212 |
+
g_strength=0.4,
|
213 |
+
img_input=None,
|
214 |
+
i2i_scale=0.5,
|
215 |
+
hr_enabled=False,
|
216 |
+
hr_method="Latent",
|
217 |
+
hr_scale=1.5,
|
218 |
+
hr_denoise=0.8,
|
219 |
+
sampler="DPM++ 2M Karras",
|
220 |
+
embs=None,
|
221 |
+
model=None,
|
222 |
+
lora_state=None,
|
223 |
+
lora_scale=None,
|
224 |
+
):
|
225 |
+
if seed is None or seed == 0:
|
226 |
+
seed = random.randint(0, 2147483647)
|
227 |
+
|
228 |
+
pipe = setup_model(model, lora_state, lora_scale)
|
229 |
+
generator = torch.Generator("cuda").manual_seed(int(seed))
|
230 |
+
start_time = time.time()
|
231 |
+
|
232 |
+
sampler_name, sampler_opt = None, None
|
233 |
+
for label, funcname, options in samplers_k_diffusion:
|
234 |
+
if label == sampler:
|
235 |
+
sampler_name, sampler_opt = funcname, options
|
236 |
+
|
237 |
+
tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder
|
238 |
+
if embs is not None and len(embs) > 0:
|
239 |
+
ti_embs = {}
|
240 |
+
for name, file in embs.items():
|
241 |
+
if str(file).endswith(".pt"):
|
242 |
+
loaded_learned_embeds = torch.load(file, map_location="cpu")
|
243 |
+
else:
|
244 |
+
loaded_learned_embeds = load_file(file, device="cpu")
|
245 |
+
loaded_learned_embeds = loaded_learned_embeds["string_to_param"]["*"] if "string_to_param" in loaded_learned_embeds else loaded_learned_embeds
|
246 |
+
ti_embs[name] = loaded_learned_embeds
|
247 |
+
|
248 |
+
if len(ti_embs) > 0:
|
249 |
+
tokens = setup_tokenizer(tokenizer, ti_embs)
|
250 |
+
added_tokens = tokenizer.add_tokens(tokens)
|
251 |
+
delta_weight = torch.cat([val for val in ti_embs.values()], dim=0)
|
252 |
+
|
253 |
+
assert added_tokens == delta_weight.shape[0]
|
254 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
255 |
+
token_embeds = text_encoder.get_input_embeddings().weight.data
|
256 |
+
token_embeds[-delta_weight.shape[0]:] = delta_weight
|
257 |
+
|
258 |
+
config = {
|
259 |
+
"negative_prompt": neg_prompt,
|
260 |
+
"num_inference_steps": int(steps),
|
261 |
+
"guidance_scale": guidance,
|
262 |
+
"generator": generator,
|
263 |
+
"sampler_name": sampler_name,
|
264 |
+
"sampler_opt": sampler_opt,
|
265 |
+
"pww_state": state,
|
266 |
+
"pww_attn_weight": g_strength,
|
267 |
+
"start_time": start_time,
|
268 |
+
"timeout": timeout,
|
269 |
+
}
|
270 |
+
|
271 |
+
if img_input is not None:
|
272 |
+
ratio = min(height / img_input.height, width / img_input.width)
|
273 |
+
img_input = img_input.resize(
|
274 |
+
(int(img_input.width * ratio), int(img_input.height * ratio)), Image.LANCZOS
|
275 |
+
)
|
276 |
+
result = pipe.img2img(prompt, image=img_input, strength=i2i_scale, **config)
|
277 |
+
elif hr_enabled:
|
278 |
+
result = pipe.txt2img(
|
279 |
+
prompt,
|
280 |
+
width=width,
|
281 |
+
height=height,
|
282 |
+
upscale=True,
|
283 |
+
upscale_x=hr_scale,
|
284 |
+
upscale_denoising_strength=hr_denoise,
|
285 |
+
**config,
|
286 |
+
**latent_upscale_modes[hr_method],
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
result = pipe.txt2img(prompt, width=width, height=height, **config)
|
290 |
+
|
291 |
+
end_time = time.time()
|
292 |
+
vram_free, vram_total = torch.cuda.mem_get_info()
|
293 |
+
print(f"done: model={model}, res={width}x{height}, step={steps}, time={round(end_time-start_time, 2)}s, vram_alloc={convert_size(vram_total-vram_free)}/{convert_size(vram_total)}")
|
294 |
+
return gr.Image.update(result[0][0], label=f"Initial Seed: {seed}")
|
295 |
+
|
296 |
+
|
297 |
+
color_list = []
|
298 |
+
|
299 |
+
|
300 |
+
def get_color(n):
|
301 |
+
for _ in range(n - len(color_list)):
|
302 |
+
color_list.append(tuple(np.random.random(size=3) * 256))
|
303 |
+
return color_list
|
304 |
+
|
305 |
+
|
306 |
+
def create_mixed_img(current, state, w=512, h=512):
|
307 |
+
w, h = int(w), int(h)
|
308 |
+
image_np = np.full([h, w, 4], 255)
|
309 |
+
if state is None:
|
310 |
+
state = {}
|
311 |
+
|
312 |
+
colors = get_color(len(state))
|
313 |
+
idx = 0
|
314 |
+
|
315 |
+
for key, item in state.items():
|
316 |
+
if item["map"] is not None:
|
317 |
+
m = item["map"] < 255
|
318 |
+
alpha = 150
|
319 |
+
if current == key:
|
320 |
+
alpha = 200
|
321 |
+
image_np[m] = colors[idx] + (alpha,)
|
322 |
+
idx += 1
|
323 |
+
|
324 |
+
return image_np
|
325 |
+
|
326 |
+
|
327 |
+
# width.change(apply_new_res, inputs=[width, height, global_stats], outputs=[global_stats, sp, rendered])
|
328 |
+
def apply_new_res(w, h, state):
|
329 |
+
w, h = int(w), int(h)
|
330 |
+
|
331 |
+
for key, item in state.items():
|
332 |
+
if item["map"] is not None:
|
333 |
+
item["map"] = resize(item["map"], w, h)
|
334 |
+
|
335 |
+
update_img = gr.Image.update(value=create_mixed_img("", state, w, h))
|
336 |
+
return state, update_img
|
337 |
+
|
338 |
+
|
339 |
+
def detect_text(text, state, width, height):
|
340 |
+
|
341 |
+
if text is None or text == "":
|
342 |
+
return None, None, gr.Radio.update(value=None), None
|
343 |
+
|
344 |
+
t = text.split(",")
|
345 |
+
new_state = {}
|
346 |
+
|
347 |
+
for item in t:
|
348 |
+
item = item.strip()
|
349 |
+
if item == "":
|
350 |
+
continue
|
351 |
+
if state is not None and item in state:
|
352 |
+
new_state[item] = {
|
353 |
+
"map": state[item]["map"],
|
354 |
+
"weight": state[item]["weight"],
|
355 |
+
"mask_outsides": state[item]["mask_outsides"],
|
356 |
+
}
|
357 |
+
else:
|
358 |
+
new_state[item] = {
|
359 |
+
"map": None,
|
360 |
+
"weight": 0.5,
|
361 |
+
"mask_outsides": False
|
362 |
+
}
|
363 |
+
update = gr.Radio.update(choices=[key for key in new_state.keys()], value=None)
|
364 |
+
update_img = gr.update(value=create_mixed_img("", new_state, width, height))
|
365 |
+
update_sketch = gr.update(value=None, interactive=False)
|
366 |
+
return new_state, update_sketch, update, update_img
|
367 |
+
|
368 |
+
|
369 |
+
def resize(img, w, h):
|
370 |
+
trs = transforms.Compose(
|
371 |
+
[
|
372 |
+
transforms.ToPILImage(),
|
373 |
+
transforms.Resize(min(h, w)),
|
374 |
+
transforms.CenterCrop((h, w)),
|
375 |
+
]
|
376 |
+
)
|
377 |
+
result = np.array(trs(img), dtype=np.uint8)
|
378 |
+
return result
|
379 |
+
|
380 |
+
|
381 |
+
def switch_canvas(entry, state, width, height):
|
382 |
+
if entry == None:
|
383 |
+
return None, 0.5, False, create_mixed_img("", state, width, height)
|
384 |
+
|
385 |
+
return (
|
386 |
+
gr.update(value=None, interactive=True),
|
387 |
+
gr.update(value=state[entry]["weight"] if entry in state else 0.5),
|
388 |
+
gr.update(value=state[entry]["mask_outsides"] if entry in state else False),
|
389 |
+
create_mixed_img(entry, state, width, height),
|
390 |
+
)
|
391 |
+
|
392 |
+
|
393 |
+
def apply_canvas(selected, draw, state, w, h):
|
394 |
+
if selected in state:
|
395 |
+
w, h = int(w), int(h)
|
396 |
+
state[selected]["map"] = resize(draw, w, h)
|
397 |
+
return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))
|
398 |
+
|
399 |
+
|
400 |
+
def apply_weight(selected, weight, state):
|
401 |
+
if selected in state:
|
402 |
+
state[selected]["weight"] = weight
|
403 |
+
return state
|
404 |
+
|
405 |
+
|
406 |
+
def apply_option(selected, mask, state):
|
407 |
+
if selected in state:
|
408 |
+
state[selected]["mask_outsides"] = mask
|
409 |
+
return state
|
410 |
+
|
411 |
+
|
412 |
+
# sp2, radio, width, height, global_stats
|
413 |
+
def apply_image(image, selected, w, h, strength, mask, state):
|
414 |
+
if selected in state:
|
415 |
+
state[selected] = {
|
416 |
+
"map": resize(image, w, h),
|
417 |
+
"weight": strgength,
|
418 |
+
"mask_outsides": mask
|
419 |
+
}
|
420 |
+
|
421 |
+
return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))
|
422 |
+
|
423 |
+
|
424 |
+
# [ti_state, lora_state, ti_vals, lora_vals, uploads]
|
425 |
+
def add_net(files, ti_state, lora_state):
|
426 |
+
if files is None:
|
427 |
+
return ti_state, "", lora_state, None
|
428 |
+
|
429 |
+
for file in files:
|
430 |
+
item = Path(file.name)
|
431 |
+
stripedname = str(item.stem).strip()
|
432 |
+
if item.suffix == ".pt":
|
433 |
+
state_dict = torch.load(file.name, map_location="cpu")
|
434 |
+
else:
|
435 |
+
state_dict = load_file(file.name, device="cpu")
|
436 |
+
if any("lora" in k for k in state_dict.keys()):
|
437 |
+
lora_state = file.name
|
438 |
+
else:
|
439 |
+
ti_state[stripedname] = file.name
|
440 |
+
|
441 |
+
return (
|
442 |
+
ti_state,
|
443 |
+
lora_state,
|
444 |
+
gr.Text.update(f"{[key for key in ti_state.keys()]}"),
|
445 |
+
gr.Text.update(f"{lora_state}"),
|
446 |
+
gr.Files.update(value=None),
|
447 |
+
)
|
448 |
+
|
449 |
+
|
450 |
+
# [ti_state, lora_state, ti_vals, lora_vals, uploads]
|
451 |
+
def clean_states(ti_state, lora_state):
|
452 |
+
return (
|
453 |
+
dict(),
|
454 |
+
None,
|
455 |
+
gr.Text.update(f""),
|
456 |
+
gr.Text.update(f""),
|
457 |
+
gr.File.update(value=None),
|
458 |
+
)
|
459 |
+
|
460 |
+
|
461 |
+
latent_upscale_modes = {
|
462 |
+
"Latent": {"upscale_method": "bilinear", "upscale_antialias": False},
|
463 |
+
"Latent (antialiased)": {"upscale_method": "bilinear", "upscale_antialias": True},
|
464 |
+
"Latent (bicubic)": {"upscale_method": "bicubic", "upscale_antialias": False},
|
465 |
+
"Latent (bicubic antialiased)": {
|
466 |
+
"upscale_method": "bicubic",
|
467 |
+
"upscale_antialias": True,
|
468 |
+
},
|
469 |
+
"Latent (nearest)": {"upscale_method": "nearest", "upscale_antialias": False},
|
470 |
+
"Latent (nearest-exact)": {
|
471 |
+
"upscale_method": "nearest-exact",
|
472 |
+
"upscale_antialias": False,
|
473 |
+
},
|
474 |
+
}
|
475 |
+
|
476 |
+
css = """
|
477 |
+
.finetuned-diffusion-div div{
|
478 |
+
display:inline-flex;
|
479 |
+
align-items:center;
|
480 |
+
gap:.8rem;
|
481 |
+
font-size:1.75rem;
|
482 |
+
padding-top:2rem;
|
483 |
+
}
|
484 |
+
.finetuned-diffusion-div div h1{
|
485 |
+
font-weight:900;
|
486 |
+
margin-bottom:7px
|
487 |
+
}
|
488 |
+
.finetuned-diffusion-div p{
|
489 |
+
margin-bottom:10px;
|
490 |
+
font-size:94%
|
491 |
+
}
|
492 |
+
.box {
|
493 |
+
float: left;
|
494 |
+
height: 20px;
|
495 |
+
width: 20px;
|
496 |
+
margin-bottom: 15px;
|
497 |
+
border: 1px solid black;
|
498 |
+
clear: both;
|
499 |
+
}
|
500 |
+
a{
|
501 |
+
text-decoration:underline
|
502 |
+
}
|
503 |
+
.tabs{
|
504 |
+
margin-top:0;
|
505 |
+
margin-bottom:0
|
506 |
+
}
|
507 |
+
#gallery{
|
508 |
+
min-height:20rem
|
509 |
+
}
|
510 |
+
.no-border {
|
511 |
+
border: none !important;
|
512 |
+
}
|
513 |
+
"""
|
514 |
+
with gr.Blocks(css=css) as demo:
|
515 |
+
gr.HTML(
|
516 |
+
f"""
|
517 |
+
<div class="finetuned-diffusion-div">
|
518 |
+
<div>
|
519 |
+
<h1>Demo for diffusion models</h1>
|
520 |
+
</div>
|
521 |
+
<p>Hso @ nyanko.sketch2img.gradio</p>
|
522 |
+
</div>
|
523 |
+
"""
|
524 |
+
)
|
525 |
+
global_stats = gr.State(value={})
|
526 |
+
|
527 |
+
with gr.Row():
|
528 |
+
|
529 |
+
with gr.Column(scale=55):
|
530 |
+
model = gr.Dropdown(
|
531 |
+
choices=[k[0] for k in get_model_list()],
|
532 |
+
label="Model",
|
533 |
+
value=base_name,
|
534 |
+
)
|
535 |
+
image_out = gr.Image(height=512)
|
536 |
+
# gallery = gr.Gallery(
|
537 |
+
# label="Generated images", show_label=False, elem_id="gallery"
|
538 |
+
# ).style(grid=[1], height="auto")
|
539 |
+
|
540 |
+
with gr.Column(scale=45):
|
541 |
+
|
542 |
+
with gr.Group():
|
543 |
+
|
544 |
+
with gr.Row():
|
545 |
+
with gr.Column(scale=70):
|
546 |
+
|
547 |
+
prompt = gr.Textbox(
|
548 |
+
label="Prompt",
|
549 |
+
value="best quality, masterpiece, highres, an extremely delicate and beautiful, original, extremely detailed wallpaper, highres , 1girl",
|
550 |
+
show_label=True,
|
551 |
+
max_lines=4,
|
552 |
+
placeholder="Enter prompt.",
|
553 |
+
)
|
554 |
+
neg_prompt = gr.Textbox(
|
555 |
+
label="Negative Prompt",
|
556 |
+
value="simple background,monochrome ,lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits,twisting jawline, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, lowres, bad anatomy, bad hands, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, ugly,pregnant,vore,duplicate,morbid,mut ilated,tran nsexual, hermaphrodite,long neck,mutated hands,poorly drawn hands,poorly drawn face,mutation,deformed,blurry,bad anatomy,bad proportions,malformed limbs,extra limbs,cloned face,disfigured,gross proportions, missing arms, missing legs, extra arms,extra legs,pubic hair, plump,bad legs,error legs,username,blurry,bad feet",
|
557 |
+
show_label=True,
|
558 |
+
max_lines=4,
|
559 |
+
placeholder="Enter negative prompt.",
|
560 |
+
)
|
561 |
+
|
562 |
+
generate = gr.Button(value="Generate").style(
|
563 |
+
rounded=(False, True, True, False)
|
564 |
+
)
|
565 |
+
|
566 |
+
with gr.Tab("Options"):
|
567 |
+
|
568 |
+
with gr.Group():
|
569 |
+
|
570 |
+
# n_images = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1)
|
571 |
+
with gr.Row():
|
572 |
+
guidance = gr.Slider(
|
573 |
+
label="Guidance scale", value=7.5, maximum=15
|
574 |
+
)
|
575 |
+
steps = gr.Slider(
|
576 |
+
label="Steps", value=25, minimum=2, maximum=50, step=1
|
577 |
+
)
|
578 |
+
|
579 |
+
with gr.Row():
|
580 |
+
width = gr.Slider(
|
581 |
+
label="Width", value=512, minimum=64, maximum=1024, step=64
|
582 |
+
)
|
583 |
+
height = gr.Slider(
|
584 |
+
label="Height", value=512, minimum=64, maximum=1024, step=64
|
585 |
+
)
|
586 |
+
|
587 |
+
sampler = gr.Dropdown(
|
588 |
+
value="DPM++ 2M Karras",
|
589 |
+
label="Sampler",
|
590 |
+
choices=[s[0] for s in samplers_k_diffusion],
|
591 |
+
)
|
592 |
+
seed = gr.Number(label="Seed (0 = random)", value=0)
|
593 |
+
|
594 |
+
with gr.Tab("Image to image"):
|
595 |
+
with gr.Group():
|
596 |
+
|
597 |
+
inf_image = gr.Image(
|
598 |
+
label="Image", height=256, tool="editor", type="pil"
|
599 |
+
)
|
600 |
+
inf_strength = gr.Slider(
|
601 |
+
label="Transformation strength",
|
602 |
+
minimum=0,
|
603 |
+
maximum=1,
|
604 |
+
step=0.01,
|
605 |
+
value=0.5,
|
606 |
+
)
|
607 |
+
|
608 |
+
def res_cap(g, w, h, x):
|
609 |
+
if g:
|
610 |
+
return f"Enable upscaler: {w}x{h} to {int(w*x)}x{int(h*x)}"
|
611 |
+
else:
|
612 |
+
return "Enable upscaler"
|
613 |
+
|
614 |
+
with gr.Tab("Hires fix"):
|
615 |
+
with gr.Group():
|
616 |
+
|
617 |
+
hr_enabled = gr.Checkbox(label="Enable upscaler", value=False)
|
618 |
+
hr_method = gr.Dropdown(
|
619 |
+
[key for key in latent_upscale_modes.keys()],
|
620 |
+
value="Latent",
|
621 |
+
label="Upscale method",
|
622 |
+
)
|
623 |
+
hr_scale = gr.Slider(
|
624 |
+
label="Upscale factor",
|
625 |
+
minimum=1.0,
|
626 |
+
maximum=2.0,
|
627 |
+
step=0.1,
|
628 |
+
value=1.5,
|
629 |
+
)
|
630 |
+
hr_denoise = gr.Slider(
|
631 |
+
label="Denoising strength",
|
632 |
+
minimum=0.0,
|
633 |
+
maximum=1.0,
|
634 |
+
step=0.1,
|
635 |
+
value=0.8,
|
636 |
+
)
|
637 |
+
|
638 |
+
hr_scale.change(
|
639 |
+
lambda g, x, w, h: gr.Checkbox.update(
|
640 |
+
label=res_cap(g, w, h, x)
|
641 |
+
),
|
642 |
+
inputs=[hr_enabled, hr_scale, width, height],
|
643 |
+
outputs=hr_enabled,
|
644 |
+
queue=False,
|
645 |
+
)
|
646 |
+
hr_enabled.change(
|
647 |
+
lambda g, x, w, h: gr.Checkbox.update(
|
648 |
+
label=res_cap(g, w, h, x)
|
649 |
+
),
|
650 |
+
inputs=[hr_enabled, hr_scale, width, height],
|
651 |
+
outputs=hr_enabled,
|
652 |
+
queue=False,
|
653 |
+
)
|
654 |
+
|
655 |
+
with gr.Tab("Embeddings/Loras"):
|
656 |
+
|
657 |
+
ti_state = gr.State(dict())
|
658 |
+
lora_state = gr.State()
|
659 |
+
|
660 |
+
with gr.Group():
|
661 |
+
with gr.Row():
|
662 |
+
with gr.Column(scale=90):
|
663 |
+
ti_vals = gr.Text(label="Loaded embeddings")
|
664 |
+
|
665 |
+
with gr.Row():
|
666 |
+
with gr.Column(scale=90):
|
667 |
+
lora_vals = gr.Text(label="Loaded loras")
|
668 |
+
|
669 |
+
with gr.Row():
|
670 |
+
|
671 |
+
uploads = gr.Files(label="Upload new embeddings/lora")
|
672 |
+
|
673 |
+
with gr.Column():
|
674 |
+
lora_scale = gr.Slider(
|
675 |
+
label="Lora scale",
|
676 |
+
minimum=0,
|
677 |
+
maximum=2,
|
678 |
+
step=0.01,
|
679 |
+
value=1.0,
|
680 |
+
)
|
681 |
+
btn = gr.Button(value="Upload")
|
682 |
+
btn_del = gr.Button(value="Reset")
|
683 |
+
|
684 |
+
btn.click(
|
685 |
+
add_net,
|
686 |
+
inputs=[uploads, ti_state, lora_state],
|
687 |
+
outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads],
|
688 |
+
queue=False,
|
689 |
+
)
|
690 |
+
btn_del.click(
|
691 |
+
clean_states,
|
692 |
+
inputs=[ti_state, lora_state],
|
693 |
+
outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads],
|
694 |
+
queue=False,
|
695 |
+
)
|
696 |
+
|
697 |
+
# error_output = gr.Markdown()
|
698 |
+
|
699 |
+
gr.HTML(
|
700 |
+
f"""
|
701 |
+
<div class="finetuned-diffusion-div">
|
702 |
+
<div>
|
703 |
+
<h1>Paint with words</h1>
|
704 |
+
</div>
|
705 |
+
<p>
|
706 |
+
Will use the following formula: w = scale * token_weight_matrix * log(1 + sigma) * max(qk).
|
707 |
+
</p>
|
708 |
+
</div>
|
709 |
+
"""
|
710 |
+
)
|
711 |
+
|
712 |
+
with gr.Row():
|
713 |
+
|
714 |
+
with gr.Column(scale=55):
|
715 |
+
|
716 |
+
rendered = gr.Image(
|
717 |
+
invert_colors=True,
|
718 |
+
source="canvas",
|
719 |
+
interactive=False,
|
720 |
+
image_mode="RGBA",
|
721 |
+
)
|
722 |
+
|
723 |
+
with gr.Column(scale=45):
|
724 |
+
|
725 |
+
with gr.Group():
|
726 |
+
with gr.Row():
|
727 |
+
with gr.Column(scale=70):
|
728 |
+
g_strength = gr.Slider(
|
729 |
+
label="Weight scaling",
|
730 |
+
minimum=0,
|
731 |
+
maximum=0.8,
|
732 |
+
step=0.01,
|
733 |
+
value=0.4,
|
734 |
+
)
|
735 |
+
|
736 |
+
text = gr.Textbox(
|
737 |
+
lines=2,
|
738 |
+
interactive=True,
|
739 |
+
label="Token to Draw: (Separate by comma)",
|
740 |
+
)
|
741 |
+
|
742 |
+
radio = gr.Radio([], label="Tokens")
|
743 |
+
|
744 |
+
sk_update = gr.Button(value="Update").style(
|
745 |
+
rounded=(False, True, True, False)
|
746 |
+
)
|
747 |
+
|
748 |
+
# g_strength.change(lambda b: gr.update(f"Scaled additional attn: $w = {b} \log (1 + \sigma) \std (Q^T K)$."), inputs=g_strength, outputs=[g_output])
|
749 |
+
|
750 |
+
with gr.Tab("SketchPad"):
|
751 |
+
|
752 |
+
sp = gr.Image(
|
753 |
+
image_mode="L",
|
754 |
+
tool="sketch",
|
755 |
+
source="canvas",
|
756 |
+
interactive=False,
|
757 |
+
)
|
758 |
+
|
759 |
+
mask_outsides = gr.Checkbox(
|
760 |
+
label="Mask other areas",
|
761 |
+
value=False
|
762 |
+
)
|
763 |
+
|
764 |
+
strength = gr.Slider(
|
765 |
+
label="Token strength",
|
766 |
+
minimum=0,
|
767 |
+
maximum=0.8,
|
768 |
+
step=0.01,
|
769 |
+
value=0.5,
|
770 |
+
)
|
771 |
+
|
772 |
+
|
773 |
+
sk_update.click(
|
774 |
+
detect_text,
|
775 |
+
inputs=[text, global_stats, width, height],
|
776 |
+
outputs=[global_stats, sp, radio, rendered],
|
777 |
+
queue=False,
|
778 |
+
)
|
779 |
+
radio.change(
|
780 |
+
switch_canvas,
|
781 |
+
inputs=[radio, global_stats, width, height],
|
782 |
+
outputs=[sp, strength, mask_outsides, rendered],
|
783 |
+
queue=False,
|
784 |
+
)
|
785 |
+
sp.edit(
|
786 |
+
apply_canvas,
|
787 |
+
inputs=[radio, sp, global_stats, width, height],
|
788 |
+
outputs=[global_stats, rendered],
|
789 |
+
queue=False,
|
790 |
+
)
|
791 |
+
strength.change(
|
792 |
+
apply_weight,
|
793 |
+
inputs=[radio, strength, global_stats],
|
794 |
+
outputs=[global_stats],
|
795 |
+
queue=False,
|
796 |
+
)
|
797 |
+
mask_outsides.change(
|
798 |
+
apply_option,
|
799 |
+
inputs=[radio, mask_outsides, global_stats],
|
800 |
+
outputs=[global_stats],
|
801 |
+
queue=False,
|
802 |
+
)
|
803 |
+
|
804 |
+
with gr.Tab("UploadFile"):
|
805 |
+
|
806 |
+
sp2 = gr.Image(
|
807 |
+
image_mode="L",
|
808 |
+
source="upload",
|
809 |
+
shape=(512, 512),
|
810 |
+
)
|
811 |
+
|
812 |
+
mask_outsides2 = gr.Checkbox(
|
813 |
+
label="Mask other areas",
|
814 |
+
value=False,
|
815 |
+
)
|
816 |
+
|
817 |
+
strength2 = gr.Slider(
|
818 |
+
label="Token strength",
|
819 |
+
minimum=0,
|
820 |
+
maximum=0.8,
|
821 |
+
step=0.01,
|
822 |
+
value=0.5,
|
823 |
+
)
|
824 |
+
|
825 |
+
apply_style = gr.Button(value="Apply")
|
826 |
+
apply_style.click(
|
827 |
+
apply_image,
|
828 |
+
inputs=[sp2, radio, width, height, strength2, mask_outsides2, global_stats],
|
829 |
+
outputs=[global_stats, rendered],
|
830 |
+
queue=False,
|
831 |
+
)
|
832 |
+
|
833 |
+
width.change(
|
834 |
+
apply_new_res,
|
835 |
+
inputs=[width, height, global_stats],
|
836 |
+
outputs=[global_stats, rendered],
|
837 |
+
queue=False,
|
838 |
+
)
|
839 |
+
height.change(
|
840 |
+
apply_new_res,
|
841 |
+
inputs=[width, height, global_stats],
|
842 |
+
outputs=[global_stats, rendered],
|
843 |
+
queue=False,
|
844 |
+
)
|
845 |
+
|
846 |
+
# color_stats = gr.State(value={})
|
847 |
+
# text.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])
|
848 |
+
# sp.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])
|
849 |
+
|
850 |
+
inputs = [
|
851 |
+
prompt,
|
852 |
+
guidance,
|
853 |
+
steps,
|
854 |
+
width,
|
855 |
+
height,
|
856 |
+
seed,
|
857 |
+
neg_prompt,
|
858 |
+
global_stats,
|
859 |
+
g_strength,
|
860 |
+
inf_image,
|
861 |
+
inf_strength,
|
862 |
+
hr_enabled,
|
863 |
+
hr_method,
|
864 |
+
hr_scale,
|
865 |
+
hr_denoise,
|
866 |
+
sampler,
|
867 |
+
ti_state,
|
868 |
+
model,
|
869 |
+
lora_state,
|
870 |
+
lora_scale,
|
871 |
+
]
|
872 |
+
outputs = [image_out]
|
873 |
+
prompt.submit(inference, inputs=inputs, outputs=outputs)
|
874 |
+
generate.click(inference, inputs=inputs, outputs=outputs)
|
875 |
+
|
876 |
+
print(f"Space built in {time.time() - start_time:.2f} seconds")
|
877 |
+
# demo.launch(share=True)
|
878 |
+
demo.launch(enable_queue=True, server_name="0.0.0.0", server_port=7860)
|
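The "Paint with words" section of app.py above advertises the attention-weighting formula w = scale * token_weight_matrix * log(1 + sigma) * max(qk). The actual weight function is defined elsewhere in this Space (it reaches CrossAttnProcessor in modules/model.py through the "weight_func" entry of encoder_hidden_states), so the snippet below is only a minimal sketch of what such a function could look like; the argument names and the use of torch.log1p are assumptions, not the Space's definitive implementation.

import torch

def weight_func(w, sigma, qk):
    # Sketch only, assuming:
    # w: per-token weight map built from the sketches (token strength and the
    #    global "Weight scaling" slider are already folded in by encode_sketchs)
    # sigma: current noise level of the k-diffusion step
    # qk: raw cross-attention scores for this layer
    # Implements w * log(1 + sigma) * max(qk)
    return w * torch.log1p(torch.as_tensor(sigma, dtype=torch.float32)) * qk.max()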
modules/lora.py
ADDED
@@ -0,0 +1,183 @@
1 |
+
# LoRA network module
|
2 |
+
# reference:
|
3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
5 |
+
# https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py#L48
|
6 |
+
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import torch
|
10 |
+
import modules.safe as _
|
11 |
+
from safetensors.torch import load_file
|
12 |
+
|
13 |
+
|
14 |
+
class LoRAModule(torch.nn.Module):
|
15 |
+
"""
|
16 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
lora_name,
|
22 |
+
org_module: torch.nn.Module,
|
23 |
+
multiplier=1.0,
|
24 |
+
lora_dim=4,
|
25 |
+
alpha=1,
|
26 |
+
):
|
27 |
+
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
28 |
+
super().__init__()
|
29 |
+
self.lora_name = lora_name
|
30 |
+
self.lora_dim = lora_dim
|
31 |
+
|
32 |
+
if org_module.__class__.__name__ == "Conv2d":
|
33 |
+
in_dim = org_module.in_channels
|
34 |
+
out_dim = org_module.out_channels
|
35 |
+
self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
|
36 |
+
self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
|
37 |
+
else:
|
38 |
+
in_dim = org_module.in_features
|
39 |
+
out_dim = org_module.out_features
|
40 |
+
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
|
41 |
+
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
|
42 |
+
|
43 |
+
if type(alpha) == torch.Tensor:
|
44 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
45 |
+
|
46 |
+
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
47 |
+
self.scale = alpha / self.lora_dim
|
48 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # can be treated as a constant
|
49 |
+
|
50 |
+
# same as microsoft's
|
51 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
52 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
53 |
+
|
54 |
+
self.multiplier = multiplier
|
55 |
+
self.org_module = org_module # remove in applying
|
56 |
+
self.enable = False
|
57 |
+
|
58 |
+
def resize(self, rank, alpha, multiplier):
|
59 |
+
self.alpha = torch.tensor(alpha)
|
60 |
+
self.multiplier = multiplier
|
61 |
+
self.scale = alpha / rank
|
62 |
+
if self.lora_down.__class__.__name__ == "Conv2d":
|
63 |
+
in_dim = self.lora_down.in_channels
|
64 |
+
out_dim = self.lora_up.out_channels
|
65 |
+
self.lora_down = torch.nn.Conv2d(in_dim, rank, (1, 1), bias=False)
|
66 |
+
self.lora_up = torch.nn.Conv2d(rank, out_dim, (1, 1), bias=False)
|
67 |
+
else:
|
68 |
+
in_dim = self.lora_down.in_features
|
69 |
+
out_dim = self.lora_up.out_features
|
70 |
+
self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
|
71 |
+
self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)
|
72 |
+
|
73 |
+
def apply(self):
|
74 |
+
if hasattr(self, "org_module"):
|
75 |
+
self.org_forward = self.org_module.forward
|
76 |
+
self.org_module.forward = self.forward
|
77 |
+
del self.org_module
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
if self.enable:
|
81 |
+
return (
|
82 |
+
self.org_forward(x)
|
83 |
+
+ self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
84 |
+
)
|
85 |
+
return self.org_forward(x)
|
86 |
+
|
87 |
+
|
88 |
+
class LoRANetwork(torch.nn.Module):
|
89 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
90 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
91 |
+
LORA_PREFIX_UNET = "lora_unet"
|
92 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
93 |
+
|
94 |
+
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
|
95 |
+
super().__init__()
|
96 |
+
self.multiplier = multiplier
|
97 |
+
self.lora_dim = lora_dim
|
98 |
+
self.alpha = alpha
|
99 |
+
|
100 |
+
# create module instances
|
101 |
+
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules):
|
102 |
+
loras = []
|
103 |
+
for name, module in root_module.named_modules():
|
104 |
+
if module.__class__.__name__ in target_replace_modules:
|
105 |
+
for child_name, child_module in module.named_modules():
|
106 |
+
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
107 |
+
lora_name = prefix + "." + name + "." + child_name
|
108 |
+
lora_name = lora_name.replace(".", "_")
|
109 |
+
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha,)
|
110 |
+
loras.append(lora)
|
111 |
+
return loras
|
112 |
+
|
113 |
+
if isinstance(text_encoder, list):
|
114 |
+
self.text_encoder_loras = text_encoder
|
115 |
+
else:
|
116 |
+
self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
117 |
+
print(f"Create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
118 |
+
|
119 |
+
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
|
120 |
+
print(f"Create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
121 |
+
|
122 |
+
self.weights_sd = None
|
123 |
+
|
124 |
+
# assertion
|
125 |
+
names = set()
|
126 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
127 |
+
assert (lora.lora_name not in names), f"duplicated lora name: {lora.lora_name}"
|
128 |
+
names.add(lora.lora_name)
|
129 |
+
|
130 |
+
lora.apply()
|
131 |
+
self.add_module(lora.lora_name, lora)
|
132 |
+
|
133 |
+
def reset(self):
|
134 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
135 |
+
lora.enable = False
|
136 |
+
|
137 |
+
def load(self, file, scale):
|
138 |
+
|
139 |
+
weights = None
|
140 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
141 |
+
weights = load_file(file)
|
142 |
+
else:
|
143 |
+
weights = torch.load(file, map_location="cpu")
|
144 |
+
|
145 |
+
if not weights:
|
146 |
+
return
|
147 |
+
|
148 |
+
network_alpha = None
|
149 |
+
network_dim = None
|
150 |
+
for key, value in weights.items():
|
151 |
+
if network_alpha is None and "alpha" in key:
|
152 |
+
network_alpha = value
|
153 |
+
if network_dim is None and "lora_down" in key and len(value.size()) == 2:
|
154 |
+
network_dim = value.size()[0]
|
155 |
+
|
156 |
+
if network_alpha is None:
|
157 |
+
network_alpha = network_dim
|
158 |
+
|
159 |
+
weights_has_text_encoder = weights_has_unet = False
|
160 |
+
weights_to_modify = []
|
161 |
+
|
162 |
+
for key in weights.keys():
|
163 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
164 |
+
weights_has_text_encoder = True
|
165 |
+
|
166 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
167 |
+
weights_has_unet = True
|
168 |
+
|
169 |
+
if weights_has_text_encoder:
|
170 |
+
weights_to_modify += self.text_encoder_loras
|
171 |
+
|
172 |
+
if weights_has_unet:
|
173 |
+
weights_to_modify += self.unet_loras
|
174 |
+
|
175 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
176 |
+
lora.resize(network_dim, network_alpha, scale)
|
177 |
+
if lora in weights_to_modify:
|
178 |
+
lora.enable = True
|
179 |
+
|
180 |
+
info = self.load_state_dict(weights, False)
|
181 |
+
if len(info.unexpected_keys) > 0:
|
182 |
+
print(f"Weights are loaded. Unexpected keys={info.unexpected_keys}")
|
183 |
+
|
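app.py drives LoRANetwork by building one instance per base model, resetting it between generations, and calling load() with an uploaded LoRA file and a scale before moving it onto the UNet's device. The stand-alone sketch below mirrors that flow; the base model ID and the file name my_lora.safetensors are placeholders, not part of this Space.

import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel
from modules.lora import LoRANetwork

base_model = "runwayml/stable-diffusion-v1-5"  # placeholder; the Space uses its own model list
unet = UNet2DConditionModel.from_pretrained(base_model, subfolder="unet", torch_dtype=torch.float16)
text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder", torch_dtype=torch.float16)

# Wrapping the models hooks LoRAModule.forward into every matching Linear/Conv2d
# layer, but all modules stay disabled until a weight file is loaded.
network = LoRANetwork(text_encoder, unet)
network.reset()                           # disable any previously loaded LoRA
network.load("my_lora.safetensors", 1.0)  # placeholder file; resizes modules to its rank/alpha and enables them
network.to(unet.device, dtype=unet.dtype)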
modules/model.py
ADDED
@@ -0,0 +1,897 @@
1 |
+
import importlib
|
2 |
+
import inspect
|
3 |
+
import math
|
4 |
+
from pathlib import Path
|
5 |
+
import re
|
6 |
+
from collections import defaultdict
|
7 |
+
from typing import List, Optional, Union
|
8 |
+
|
9 |
+
import time
|
10 |
+
import k_diffusion
|
11 |
+
import numpy as np
|
12 |
+
import PIL
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from einops import rearrange
|
17 |
+
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
|
18 |
+
from modules.prompt_parser import FrozenCLIPEmbedderWithCustomWords
|
19 |
+
from torch import einsum
|
20 |
+
from torch.autograd.function import Function
|
21 |
+
|
22 |
+
from diffusers import DiffusionPipeline
|
23 |
+
from diffusers.utils import PIL_INTERPOLATION, is_accelerate_available
|
24 |
+
from diffusers.utils import logging, randn_tensor
|
25 |
+
|
26 |
+
import modules.safe as _
|
27 |
+
from safetensors.torch import load_file
|
28 |
+
|
29 |
+
xformers_available = False
|
30 |
+
try:
|
31 |
+
import xformers
|
32 |
+
|
33 |
+
xformers_available = True
|
34 |
+
except ImportError:
|
35 |
+
pass
|
36 |
+
|
37 |
+
EPSILON = 1e-6
|
38 |
+
exists = lambda val: val is not None
|
39 |
+
default = lambda val, d: val if exists(val) else d
|
40 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
41 |
+
|
42 |
+
|
43 |
+
def get_attention_scores(attn, query, key, attention_mask=None):
|
44 |
+
|
45 |
+
if attn.upcast_attention:
|
46 |
+
query = query.float()
|
47 |
+
key = key.float()
|
48 |
+
|
49 |
+
attention_scores = torch.baddbmm(
|
50 |
+
torch.empty(
|
51 |
+
query.shape[0],
|
52 |
+
query.shape[1],
|
53 |
+
key.shape[1],
|
54 |
+
dtype=query.dtype,
|
55 |
+
device=query.device,
|
56 |
+
),
|
57 |
+
query,
|
58 |
+
key.transpose(-1, -2),
|
59 |
+
beta=0,
|
60 |
+
alpha=attn.scale,
|
61 |
+
)
|
62 |
+
|
63 |
+
if attention_mask is not None:
|
64 |
+
attention_scores = attention_scores + attention_mask
|
65 |
+
|
66 |
+
if attn.upcast_softmax:
|
67 |
+
attention_scores = attention_scores.float()
|
68 |
+
|
69 |
+
return attention_scores
|
70 |
+
|
71 |
+
|
72 |
+
class CrossAttnProcessor(nn.Module):
|
73 |
+
def __call__(
|
74 |
+
self,
|
75 |
+
attn,
|
76 |
+
hidden_states,
|
77 |
+
encoder_hidden_states=None,
|
78 |
+
attention_mask=None,
|
79 |
+
):
|
80 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
81 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
82 |
+
|
83 |
+
encoder_states = hidden_states
|
84 |
+
is_xattn = False
|
85 |
+
if encoder_hidden_states is not None:
|
86 |
+
is_xattn = True
|
87 |
+
img_state = encoder_hidden_states["img_state"]
|
88 |
+
encoder_states = encoder_hidden_states["states"]
|
89 |
+
weight_func = encoder_hidden_states["weight_func"]
|
90 |
+
sigma = encoder_hidden_states["sigma"]
|
91 |
+
|
92 |
+
query = attn.to_q(hidden_states)
|
93 |
+
key = attn.to_k(encoder_states)
|
94 |
+
value = attn.to_v(encoder_states)
|
95 |
+
|
96 |
+
query = attn.head_to_batch_dim(query)
|
97 |
+
key = attn.head_to_batch_dim(key)
|
98 |
+
value = attn.head_to_batch_dim(value)
|
99 |
+
|
100 |
+
if is_xattn and isinstance(img_state, dict):
|
101 |
+
# use torch.baddbmm method (slow)
|
102 |
+
attention_scores = get_attention_scores(attn, query, key, attention_mask)
|
103 |
+
w = img_state[sequence_length].to(query.device)
|
104 |
+
cross_attention_weight = weight_func(w, sigma, attention_scores)
|
105 |
+
attention_scores += torch.repeat_interleave(
|
106 |
+
cross_attention_weight, repeats=attn.heads, dim=0
|
107 |
+
)
|
108 |
+
|
109 |
+
# calc probs
|
110 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
111 |
+
attention_probs = attention_probs.to(query.dtype)
|
112 |
+
hidden_states = torch.bmm(attention_probs, value)
|
113 |
+
|
114 |
+
elif xformers_available:
|
115 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
116 |
+
query.contiguous(),
|
117 |
+
key.contiguous(),
|
118 |
+
value.contiguous(),
|
119 |
+
attn_bias=attention_mask,
|
120 |
+
)
|
121 |
+
hidden_states = hidden_states.to(query.dtype)
|
122 |
+
|
123 |
+
else:
|
124 |
+
q_bucket_size = 512
|
125 |
+
k_bucket_size = 1024
|
126 |
+
|
127 |
+
# use flash-attention
|
128 |
+
hidden_states = FlashAttentionFunction.apply(
|
129 |
+
query.contiguous(),
|
130 |
+
key.contiguous(),
|
131 |
+
value.contiguous(),
|
132 |
+
attention_mask,
|
133 |
+
False,
|
134 |
+
q_bucket_size,
|
135 |
+
k_bucket_size,
|
136 |
+
)
|
137 |
+
hidden_states = hidden_states.to(query.dtype)
|
138 |
+
|
139 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
140 |
+
|
141 |
+
# linear proj
|
142 |
+
hidden_states = attn.to_out[0](hidden_states)
|
143 |
+
|
144 |
+
# dropout
|
145 |
+
hidden_states = attn.to_out[1](hidden_states)
|
146 |
+
|
147 |
+
return hidden_states
|
148 |
+
|
149 |
+
class ModelWrapper:
|
150 |
+
def __init__(self, model, alphas_cumprod):
|
151 |
+
self.model = model
|
152 |
+
self.alphas_cumprod = alphas_cumprod
|
153 |
+
|
154 |
+
def apply_model(self, *args, **kwargs):
|
155 |
+
if len(args) == 3:
|
156 |
+
encoder_hidden_states = args[-1]
|
157 |
+
args = args[:2]
|
158 |
+
if kwargs.get("cond", None) is not None:
|
159 |
+
encoder_hidden_states = kwargs.pop("cond")
|
160 |
+
return self.model(
|
161 |
+
*args, encoder_hidden_states=encoder_hidden_states, **kwargs
|
162 |
+
).sample
|
163 |
+
|
164 |
+
|
165 |
+
class StableDiffusionPipeline(DiffusionPipeline):
|
166 |
+
|
167 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
168 |
+
|
169 |
+
def __init__(
|
170 |
+
self,
|
171 |
+
vae,
|
172 |
+
text_encoder,
|
173 |
+
tokenizer,
|
174 |
+
unet,
|
175 |
+
scheduler,
|
176 |
+
):
|
177 |
+
super().__init__()
|
178 |
+
|
179 |
+
# get correct sigmas from LMS
|
180 |
+
self.register_modules(
|
181 |
+
vae=vae,
|
182 |
+
text_encoder=text_encoder,
|
183 |
+
tokenizer=tokenizer,
|
184 |
+
unet=unet,
|
185 |
+
scheduler=scheduler,
|
186 |
+
)
|
187 |
+
self.setup_unet(self.unet)
|
188 |
+
self.setup_text_encoder()
|
189 |
+
|
190 |
+
def setup_text_encoder(self, n=1, new_encoder=None):
|
191 |
+
if new_encoder is not None:
|
192 |
+
self.text_encoder = new_encoder
|
193 |
+
|
194 |
+
self.prompt_parser = FrozenCLIPEmbedderWithCustomWords(self.tokenizer, self.text_encoder)
|
195 |
+
self.prompt_parser.CLIP_stop_at_last_layers = n
|
196 |
+
|
197 |
+
def setup_unet(self, unet):
|
198 |
+
unet = unet.to(self.device)
|
199 |
+
model = ModelWrapper(unet, self.scheduler.alphas_cumprod)
|
200 |
+
if self.scheduler.prediction_type == "v_prediction":
|
201 |
+
self.k_diffusion_model = CompVisVDenoiser(model)
|
202 |
+
else:
|
203 |
+
self.k_diffusion_model = CompVisDenoiser(model)
|
204 |
+
|
205 |
+
def get_scheduler(self, scheduler_type: str):
|
206 |
+
library = importlib.import_module("k_diffusion")
|
207 |
+
sampling = getattr(library, "sampling")
|
208 |
+
return getattr(sampling, scheduler_type)
|
209 |
+
|
210 |
+
def encode_sketchs(self, state, scale_ratio=8, g_strength=1.0, text_ids=None):
|
211 |
+
uncond, cond = text_ids[0], text_ids[1]
|
212 |
+
|
213 |
+
img_state = []
|
214 |
+
if state is None:
|
215 |
+
return torch.FloatTensor(0)
|
216 |
+
|
217 |
+
for k, v in state.items():
|
218 |
+
if v["map"] is None:
|
219 |
+
continue
|
220 |
+
|
221 |
+
v_input = self.tokenizer(
|
222 |
+
k,
|
223 |
+
max_length=self.tokenizer.model_max_length,
|
224 |
+
truncation=True,
|
225 |
+
add_special_tokens=False,
|
226 |
+
).input_ids
|
227 |
+
|
228 |
+
dotmap = v["map"] < 255
|
229 |
+
out = dotmap.astype(float)
|
230 |
+
if v["mask_outsides"]:
|
231 |
+
out[out==0] = -1
|
232 |
+
|
233 |
+
arr = torch.from_numpy(
|
234 |
+
out * float(v["weight"]) * g_strength
|
235 |
+
)
|
236 |
+
img_state.append((v_input, arr))
|
237 |
+
|
238 |
+
if len(img_state) == 0:
|
239 |
+
return torch.FloatTensor(0)
|
240 |
+
|
241 |
+
w_tensors = dict()
|
242 |
+
cond = cond.tolist()
|
243 |
+
uncond = uncond.tolist()
|
244 |
+
for layer in self.unet.down_blocks:
|
245 |
+
c = int(len(cond))
|
246 |
+
w, h = img_state[0][1].shape
|
247 |
+
w_r, h_r = w // scale_ratio, h // scale_ratio
|
248 |
+
|
249 |
+
ret_cond_tensor = torch.zeros((1, int(w_r * h_r), c), dtype=torch.float32)
|
250 |
+
ret_uncond_tensor = torch.zeros((1, int(w_r * h_r), c), dtype=torch.float32)
|
251 |
+
|
252 |
+
for v_as_tokens, img_where_color in img_state:
|
253 |
+
is_in = 0
|
254 |
+
|
255 |
+
ret = (
|
256 |
+
F.interpolate(
|
257 |
+
img_where_color.unsqueeze(0).unsqueeze(1),
|
258 |
+
scale_factor=1 / scale_ratio,
|
259 |
+
mode="bilinear",
|
260 |
+
align_corners=True,
|
261 |
+
)
|
262 |
+
.squeeze()
|
263 |
+
.reshape(-1, 1)
|
264 |
+
.repeat(1, len(v_as_tokens))
|
265 |
+
)
|
266 |
+
|
267 |
+
for idx, tok in enumerate(cond):
|
268 |
+
if cond[idx : idx + len(v_as_tokens)] == v_as_tokens:
|
269 |
+
is_in = 1
|
270 |
+
ret_cond_tensor[0, :, idx : idx + len(v_as_tokens)] += ret
|
271 |
+
|
272 |
+
for idx, tok in enumerate(uncond):
|
273 |
+
if uncond[idx : idx + len(v_as_tokens)] == v_as_tokens:
|
274 |
+
is_in = 1
|
275 |
+
ret_uncond_tensor[0, :, idx : idx + len(v_as_tokens)] += ret
|
276 |
+
|
277 |
+
if not is_in == 1:
|
278 |
+
print(f"tokens {v_as_tokens} not found in text")
|
279 |
+
|
280 |
+
w_tensors[w_r * h_r] = torch.cat([ret_uncond_tensor, ret_cond_tensor])
|
281 |
+
scale_ratio *= 2
|
282 |
+
|
283 |
+
return w_tensors
|
284 |
+
|
285 |
+
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
286 |
+
r"""
|
287 |
+
Enable sliced attention computation.
|
288 |
+
|
289 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
290 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
291 |
+
|
292 |
+
Args:
|
293 |
+
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
294 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
295 |
+
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
296 |
+
`attention_head_dim` must be a multiple of `slice_size`.
|
297 |
+
"""
|
298 |
+
if slice_size == "auto":
|
299 |
+
# half the attention head size is usually a good trade-off between
|
300 |
+
# speed and memory
|
301 |
+
slice_size = self.unet.config.attention_head_dim // 2
|
302 |
+
self.unet.set_attention_slice(slice_size)
|
303 |
+
|
304 |
+
def disable_attention_slicing(self):
|
305 |
+
r"""
|
306 |
+
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
307 |
+
back to computing attention in one step.
|
308 |
+
"""
|
309 |
+
# set slice_size = `None` to disable `attention slicing`
|
310 |
+
self.enable_attention_slicing(None)
|
311 |
+
|
312 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
313 |
+
r"""
|
314 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
315 |
+
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
316 |
+
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
|
317 |
+
"""
|
318 |
+
if is_accelerate_available():
|
319 |
+
from accelerate import cpu_offload
|
320 |
+
else:
|
321 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
322 |
+
|
323 |
+
device = torch.device(f"cuda:{gpu_id}")
|
324 |
+
|
325 |
+
for cpu_offloaded_model in [
|
326 |
+
self.unet,
|
327 |
+
self.text_encoder,
|
328 |
+
self.vae,
|
329 |
+
self.safety_checker,
|
330 |
+
]:
|
331 |
+
if cpu_offloaded_model is not None:
|
332 |
+
cpu_offload(cpu_offloaded_model, device)
|
333 |
+
|
334 |
+
@property
|
335 |
+
def _execution_device(self):
|
336 |
+
r"""
|
337 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
338 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
339 |
+
hooks.
|
340 |
+
"""
|
341 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
342 |
+
return self.device
|
343 |
+
for module in self.unet.modules():
|
344 |
+
if (
|
345 |
+
hasattr(module, "_hf_hook")
|
346 |
+
and hasattr(module._hf_hook, "execution_device")
|
347 |
+
and module._hf_hook.execution_device is not None
|
348 |
+
):
|
349 |
+
return torch.device(module._hf_hook.execution_device)
|
350 |
+
return self.device
|
351 |
+
|
    def decode_latents(self, latents):
        latents = latents.to(self.device, dtype=self.vae.dtype)
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

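    # The 0.18215 factor above is the Stable Diffusion VAE latent scaling constant: latents are
    # multiplied by 0.18215 at encode time (see img2img below) and divided by it before decoding,
    # so the UNet works on roughly unit-variance latents.
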
    def check_inputs(self, prompt, height, width, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        if (callback_steps is None) or (
            callback_steps is not None
            and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (batch_size, num_channels_latents, height // 8, width // 8)
        if latents is None:
            if device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(
                    shape, generator=generator, device="cpu", dtype=dtype
                ).to(device)
            else:
                latents = torch.randn(
                    shape, generator=generator, device=device, dtype=dtype
                )
        else:
            # if latents.shape != shape:
            #     raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        return latents

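    # Note: the scaling mentioned in the comment above happens in the callers - txt2img multiplies
    # the returned noise by sigmas[0], and img2img adds noise scaled by sigma_sched[0] - so this
    # method itself returns plain standard-normal latents (or the user-supplied tensor, moved to
    # the target device).
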
    def preprocess(self, image):
        if isinstance(image, torch.Tensor):
            return image
        elif isinstance(image, PIL.Image.Image):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            w, h = image[0].size
            w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8

            image = [
                np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
                    None, :
                ]
                for i in image
            ]
            image = np.concatenate(image, axis=0)
            image = np.array(image).astype(np.float32) / 255.0
            image = image.transpose(0, 3, 1, 2)
            image = 2.0 * image - 1.0
            image = torch.from_numpy(image)
        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0)
        return image

    @torch.no_grad()
    def img2img(
        self,
        prompt: Union[str, List[str]],
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        generator: Optional[torch.Generator] = None,
        image: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        latents=None,
        strength=1.0,
        pww_state=None,
        pww_attn_weight=1.0,
        sampler_name="",
        sampler_opt={},
        start_time=-1,
        timeout=180,
        scale_ratio=8.0,
    ):
        sampler = self.get_scheduler(sampler_name)
        if image is not None:
            image = self.preprocess(image)
            image = image.to(self.vae.device, dtype=self.vae.dtype)

            init_latents = self.vae.encode(image).latent_dist.sample(generator)
            latents = 0.18215 * init_latents

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        latents = latents.to(device, dtype=self.unet.dtype)
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = True
        if guidance_scale <= 1.0:
            raise ValueError("has to use guidance_scale")

        # 3. Encode input prompt
        text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
        text_embeddings = text_embeddings.to(self.unet.dtype)

        init_timestep = (
            int(num_inference_steps / min(strength, 0.999)) if strength > 0 else 0
        )
        sigmas = self.get_sigmas(init_timestep, sampler_opt).to(
            text_embeddings.device, dtype=text_embeddings.dtype
        )

        t_start = max(init_timestep - num_inference_steps, 0)
        sigma_sched = sigmas[t_start:]

        noise = randn_tensor(
            latents.shape,
            generator=generator,
            device=device,
            dtype=text_embeddings.dtype,
        )
        latents = latents.to(device)
        latents = latents + noise * sigma_sched[0]

        # 5. Prepare latent variables
        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
            latents.device
        )

        img_state = self.encode_sketchs(
            pww_state,
            g_strength=pww_attn_weight,
            text_ids=text_ids,
        )

        def model_fn(x, sigma):

            if start_time > 0 and timeout > 0:
                assert (time.time() - start_time) < timeout, "inference process timed out"

            latent_model_input = torch.cat([x] * 2)
            weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
            encoder_state = {
                "img_state": img_state,
                "states": text_embeddings,
                "sigma": sigma[0],
                "weight_func": weight_func,
            }

            noise_pred = self.k_diffusion_model(
                latent_model_input, sigma, cond=encoder_state
            )
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
            return noise_pred

        sampler_args = self.get_sampler_extra_args_i2i(sigma_sched, sampler)
        latents = sampler(model_fn, latents, **sampler_args)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return (image,)

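    # Worked example for the strength handling in img2img above (illustrative numbers): with
    # num_inference_steps=30 and strength=0.5, init_timestep = int(30 / 0.5) = 60 and
    # t_start = max(60 - 30, 0) = 30, so sigma_sched keeps only the low-noise half of a 60-step
    # schedule - the input image is re-noised to that mid-level sigma and denoised for 30 steps.
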
    def get_sigmas(self, steps, params):
        discard_next_to_last_sigma = params.get("discard_next_to_last_sigma", False)
        steps += 1 if discard_next_to_last_sigma else 0

        if params.get("scheduler", None) == "karras":
            sigma_min, sigma_max = (
                self.k_diffusion_model.sigmas[0].item(),
                self.k_diffusion_model.sigmas[-1].item(),
            )
            sigmas = k_diffusion.sampling.get_sigmas_karras(
                n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=self.device
            )
        else:
            sigmas = self.k_diffusion_model.get_sigmas(steps)

        if discard_next_to_last_sigma:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        return sigmas

    # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454
    def get_sampler_extra_args_t2i(self, sigmas, eta, steps, func):
        extra_params_kwargs = {}

        if "eta" in inspect.signature(func).parameters:
            extra_params_kwargs["eta"] = eta

        if "sigma_min" in inspect.signature(func).parameters:
            extra_params_kwargs["sigma_min"] = sigmas[0].item()
            extra_params_kwargs["sigma_max"] = sigmas[-1].item()

        if "n" in inspect.signature(func).parameters:
            extra_params_kwargs["n"] = steps
        else:
            extra_params_kwargs["sigmas"] = sigmas

        return extra_params_kwargs

    # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454
    def get_sampler_extra_args_i2i(self, sigmas, func):
        extra_params_kwargs = {}

        if "sigma_min" in inspect.signature(func).parameters:
            ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
            extra_params_kwargs["sigma_min"] = sigmas[-2]

        if "sigma_max" in inspect.signature(func).parameters:
            extra_params_kwargs["sigma_max"] = sigmas[0]

        if "n" in inspect.signature(func).parameters:
            extra_params_kwargs["n"] = len(sigmas) - 1

        if "sigma_sched" in inspect.signature(func).parameters:
            extra_params_kwargs["sigma_sched"] = sigmas

        if "sigmas" in inspect.signature(func).parameters:
            extra_params_kwargs["sigmas"] = sigmas

        return extra_params_kwargs

    @torch.no_grad()
    def txt2img(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_steps: Optional[int] = 1,
        upscale=False,
        upscale_x: float = 2.0,
        upscale_method: str = "bicubic",
        upscale_antialias: bool = False,
        upscale_denoising_strength: float = 0.7,
        pww_state=None,
        pww_attn_weight=1.0,
        sampler_name="",
        sampler_opt={},
        start_time=-1,
        timeout=180,
    ):
        sampler = self.get_scheduler(sampler_name)
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = True
        if guidance_scale <= 1.0:
            raise ValueError("has to use guidance_scale")

        # 3. Encode input prompt
        text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
        text_embeddings = text_embeddings.to(self.unet.dtype)

        # 4. Prepare timesteps
        sigmas = self.get_sigmas(num_inference_steps, sampler_opt).to(
            text_embeddings.device, dtype=text_embeddings.dtype
        )

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents = latents * sigmas[0]
        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
            latents.device
        )

        img_state = self.encode_sketchs(
            pww_state,
            g_strength=pww_attn_weight,
            text_ids=text_ids,
        )

        def model_fn(x, sigma):

            if start_time > 0 and timeout > 0:
                assert (time.time() - start_time) < timeout, "inference process timed out"

            latent_model_input = torch.cat([x] * 2)
            weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
            encoder_state = {
                "img_state": img_state,
                "states": text_embeddings,
                "sigma": sigma[0],
                "weight_func": weight_func,
            }

            noise_pred = self.k_diffusion_model(
                latent_model_input, sigma, cond=encoder_state
            )
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
            return noise_pred

        extra_args = self.get_sampler_extra_args_t2i(
            sigmas, eta, num_inference_steps, sampler
        )
        latents = sampler(model_fn, latents, **extra_args)

        if upscale:
            target_height = height * upscale_x
            target_width = width * upscale_x
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
            latents = torch.nn.functional.interpolate(
                latents,
                size=(
                    int(target_height // vae_scale_factor),
                    int(target_width // vae_scale_factor),
                ),
                mode=upscale_method,
                antialias=upscale_antialias,
            )
            return self.img2img(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt,
                generator=generator,
                latents=latents,
                strength=upscale_denoising_strength,
                sampler_name=sampler_name,
                sampler_opt=sampler_opt,
                pww_state=None,
                pww_attn_weight=pww_attn_weight / 2,
            )

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return (image,)

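# Illustrative numbers for the highres-fix style upscale path in txt2img above: with the standard
# SD VAE, len(block_out_channels) == 4 gives vae_scale_factor == 8, so a 512x512 request with
# upscale_x=2.0 interpolates the latents to 128x128 and re-runs img2img on them with
# strength=upscale_denoising_strength and half of the original pattern-attention weight.
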
class FlashAttentionFunction(Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""

        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device=device)

        scale = q.shape[-1] ** -0.5

        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
                    min=EPSILON
                )

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = (
                    exp_row_max_diff * row_sums
                    + exp_block_row_max_diff * block_row_sums
                )

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
                    (exp_block_row_max_diff / new_row_sums) * exp_values
                )

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        lse = all_row_sums.log() + all_row_maxes

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, lse)

        return o

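    # The in-place update above is the streaming-softmax recurrence: for running row max m and row
    # sum l, a new block contributes m_new = max(m, m_blk) and
    # l_new = exp(m - m_new) * l + exp(m_blk - m_new) * l_blk, with the partial output rescaled by
    # the same factors, so `o` ends up equal to softmax(q k^T * scale) v without ever materialising
    # the full attention matrix; `lse` keeps log(l) + m per row for the backward pass.
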
    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper"""

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, lse = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            lse.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                p = torch.exp(attn_weights - lsec)

                if exists(row_mask):
                    p.masked_fill_(~row_mask, 0.0)

                dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
                dp = einsum("... i d, ... j d -> ... i j", doc, vc)

                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None
modules/prompt_parser.py
ADDED
@@ -0,0 +1,391 @@
import re
import math
import numpy as np
import torch

# Code from https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/8e2aeee4a127b295bfc880800e4a312e0f049b85, modified.

class PromptChunk:
    """
    This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
    If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
    Each PromptChunk contains an exact amount of tokens - 77, which includes one start token and one end token,
    so just 75 tokens from prompt.
    """

    def __init__(self):
        self.tokens = []
        self.multipliers = []
        self.fixes = []


class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
    """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. It enhances FrozenCLIPEmbedder, making it possible to
    have unlimited prompt length and assign weights to tokens in prompt.
    """

    def __init__(self, text_encoder, enable_emphasis=True):
        super().__init__()

        self.device = lambda: text_encoder.device
        self.enable_emphasis = enable_emphasis
        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
        depending on model."""

        self.chunk_length = 75

    def empty_chunk(self):
        """creates an empty PromptChunk and returns it"""

        chunk = PromptChunk()
        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
        chunk.multipliers = [1.0] * (self.chunk_length + 2)
        return chunk

    def get_target_prompt_token_count(self, token_count):
        """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""

        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length

    def tokenize_line(self, line):
        """
        this transforms a single prompt into a list of PromptChunk objects - as many as needed to
        represent the prompt.
        Returns the list and the total number of tokens in the prompt.
        """

        if self.enable_emphasis:
            parsed = parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        chunks = []
        chunk = PromptChunk()
        token_count = 0
        last_comma = -1

        def next_chunk(is_last=False):
            """puts current chunk into the list of results and produces the next one - empty;
            if is_last is true, <end-of-text> tokens at the end won't add to token_count"""
            nonlocal token_count
            nonlocal last_comma
            nonlocal chunk

            if is_last:
                token_count += len(chunk.tokens)
            else:
                token_count += self.chunk_length

            to_add = self.chunk_length - len(chunk.tokens)
            if to_add > 0:
                chunk.tokens += [self.id_end] * to_add
                chunk.multipliers += [1.0] * to_add

            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]

            last_comma = -1
            chunks.append(chunk)
            chunk = PromptChunk()

        comma_padding_backtrack = 20  # default value in https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/shared.py#L410
        for tokens, (text, weight) in zip(tokenized, parsed):
            if text == "BREAK" and weight == -1:
                next_chunk()
                continue

            position = 0
            while position < len(tokens):
                token = tokens[position]

                if token == self.comma_token:
                    last_comma = len(chunk.tokens)

                # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
                # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
                elif (
                    comma_padding_backtrack != 0
                    and len(chunk.tokens) == self.chunk_length
                    and last_comma != -1
                    and len(chunk.tokens) - last_comma <= comma_padding_backtrack
                ):
                    break_location = last_comma + 1

                    reloc_tokens = chunk.tokens[break_location:]
                    reloc_mults = chunk.multipliers[break_location:]

                    chunk.tokens = chunk.tokens[:break_location]
                    chunk.multipliers = chunk.multipliers[:break_location]

                    next_chunk()
                    chunk.tokens = reloc_tokens
                    chunk.multipliers = reloc_mults

                if len(chunk.tokens) == self.chunk_length:
                    next_chunk()

                chunk.tokens.append(token)
                chunk.multipliers.append(weight)
                position += 1

        if len(chunk.tokens) > 0 or len(chunks) == 0:
            next_chunk(is_last=True)

        return chunks, token_count

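    # Illustrative behaviour of tokenize_line above (token counts depend on the tokenizer): a
    # 100-token prompt becomes two PromptChunks of 77 ids each (start token + 75 + end token); an
    # explicit "BREAK" forces a chunk boundary; and if a chunk fills up within 20 tokens of the
    # last comma, the text after that comma is carried over to the next chunk
    # (comma_padding_backtrack).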
    def process_texts(self, texts):
        """
        Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
        length, in tokens, of all texts.
        """

        token_count = 0

        cache = {}
        batch_chunks = []
        for line in texts:
            if line in cache:
                chunks = cache[line]
            else:
                chunks, current_token_count = self.tokenize_line(line)
                token_count = max(current_token_count, token_count)

                cache[line] = chunks

            batch_chunks.append(chunks)

        return batch_chunks, token_count

    def forward(self, texts):
        """
        Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
        Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
        An example shape returned by this function can be: (2, 77, 768).
        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
        """

        batch_chunks, token_count = self.process_texts(texts)
        chunk_count = max([len(x) for x in batch_chunks])

        zs = []
        ts = []
        for i in range(chunk_count):
            batch_chunk = [
                chunks[i] if i < len(chunks) else self.empty_chunk()
                for chunks in batch_chunks
            ]

            tokens = [x.tokens for x in batch_chunk]
            multipliers = [x.multipliers for x in batch_chunk]
            # self.embeddings.fixes = [x.fixes for x in batch_chunk]

            # for fixes in self.embeddings.fixes:
            #     for position, embedding in fixes:
            #         used_embeddings[embedding.name] = embedding

            z = self.process_tokens(tokens, multipliers)
            zs.append(z)
            ts.append(tokens)

        return np.hstack(ts), torch.hstack(zs)

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """
        sends one single prompt chunk to be encoded by transformers neural network.
        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
        Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
        corresponds to one token.
        """
        tokens = torch.asarray(remade_batch_tokens).to(self.device())

        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
        if self.id_end != self.id_pad:
            for batch_pos in range(len(remade_batch_tokens)):
                index = remade_batch_tokens[batch_pos].index(self.id_end)
                tokens[batch_pos, index + 1 : tokens.shape[1]] = self.id_pad

        z = self.encode_with_transformers(tokens)

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(batch_multipliers).to(self.device())
        original_mean = z.mean()
        z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z = z * (original_mean / new_mean)

        return z

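    # The emphasis rescaling in process_tokens above amounts to z' = (z * m) * mean(z) / mean(z * m),
    # with one multiplier per token broadcast over the embedding channels: an emphasised token is
    # shifted relative to the others while the overall activation mean of the chunk stays unchanged.
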
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, tokenizer, text_encoder):
        super().__init__(text_encoder)
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder

        vocab = self.tokenizer.get_vocab()

        self.comma_token = vocab.get(",</w>", None)

        self.token_mults = {}
        tokens_with_parens = [
            (k, v)
            for k, v in vocab.items()
            if "(" in k or ")" in k or "[" in k or "]" in k
        ]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == "[":
                    mult /= 1.1
                if c == "]":
                    mult *= 1.1
                if c == "(":
                    mult *= 1.1
                if c == ")":
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

        self.id_start = self.tokenizer.bos_token_id
        self.id_end = self.tokenizer.eos_token_id
        self.id_pad = self.id_end

    def tokenize(self, texts):
        tokenized = self.tokenizer(
            texts, truncation=False, add_special_tokens=False
        )["input_ids"]

        return tokenized

    def encode_with_transformers(self, tokens):
        CLIP_stop_at_last_layers = 1
        tokens = tokens.to(self.text_encoder.device)
        outputs = self.text_encoder(tokens, output_hidden_states=True)

        if CLIP_stop_at_last_layers > 1:
            z = outputs.hidden_states[-CLIP_stop_at_last_layers]
            z = self.text_encoder.text_model.final_layer_norm(z)
        else:
            z = outputs.last_hidden_state

        return z


re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)

re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            parts = re.split(re_break, text)
            for i, part in enumerate(parts):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
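
# Weight arithmetic, following the doctest above: closing a bracket multiplies everything appended
# since its matching opening bracket, so 'house' gets 1.3 * 1.1 * 1.1 = 1.573 (the explicit weight,
# one ')' and the implicit close of the unbalanced '('), '[on]' nets out to 1.0 against that outer
# '(', and 'sky' reaches 1.1 ** 4 = 1.4641 because the still-open outer '(' also covers it.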
modules/safe.py
ADDED
@@ -0,0 +1,188 @@
# this code is adapted from the script contributed by anon from /h/
# modified, from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/safe.py

import io
import pickle
import collections
import sys
import traceback

import torch
import numpy
import _codecs
import zipfile
import re


# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage


def encode(*args):
    out = _codecs.encode(*args)
    return out


class RestrictedUnpickler(pickle.Unpickler):
    extra_handler = None

    def persistent_load(self, saved_id):
        assert saved_id[0] == 'storage'
        return TypedStorage()

    def find_class(self, module, name):
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ['dtype', 'ndarray']:
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")


# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")

def check_zip_filenames(filename, names):
    for name in names:
        if allowed_zip_names_re.match(name):
            continue

        raise Exception(f"bad file inside {filename}: {name}")


def check_pt(filename, extra_handler):
    try:

        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            check_zip_filenames(filename, z.namelist())

            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
            if len(data_pkl_filenames) == 0:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")
            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipfile:

        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for i in range(5):
                unpickler.load()

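# Illustrative effect of the allow-list above: a pickle that references, say, os.system (a common
# payload in malicious checkpoints) fails with Exception("global 'os/system' is forbidden"), while
# ordinary state dicts - OrderedDicts of torch storages rebuilt through torch._utils - load normally.
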
def load(filename, *args, **kwargs):
    return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)


def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    this function is intended to be used by extensions that want to load models with
    some extra classes in them that the usual unpickler would find suspicious.

    Use the extra_handler argument to specify a function that takes module and field name as text,
    and returns that field's value:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
    definitely unsafe.
    """

    try:
        check_pt(filename, extra_handler)

    except pickle.UnpicklingError:
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print("The file is most likely corrupted.", file=sys.stderr)
        return None

    except Exception:
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
        print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
        return None

    return unsafe_torch_load(filename, *args, **kwargs)


class Extra:
    """
    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
    (because it's not your code making the torch.load call). The intended use is like this:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```
    """

    def __init__(self, handler):
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        global_extra_handler = None


unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None
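
# Note on the last three lines: importing this module replaces torch.load with the checked loader
# defined above while keeping the original callable as unsafe_torch_load, so code that has already
# verified a file by other means can still call safe.unsafe_torch_load directly.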