Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- external/llite/library/__init__.py +0 -0
- external/llite/library/attention_processors.py +227 -0
- external/llite/library/config_util.py +621 -0
- external/llite/library/custom_train_functions.py +529 -0
- external/llite/library/huggingface_util.py +81 -0
- external/llite/library/hypernetwork.py +223 -0
- external/llite/library/ipex/__init__.py +169 -0
- external/llite/library/ipex/attention.py +151 -0
- external/llite/library/ipex/diffusers.py +120 -0
- external/llite/library/ipex/gradscaler.py +183 -0
- external/llite/library/ipex/hijacks.py +252 -0
- external/llite/library/lpw_stable_diffusion.py +1254 -0
- external/llite/library/model_util.py +1350 -0
- external/llite/library/original_unet.py +1915 -0
- external/llite/library/sai_model_spec.py +305 -0
- external/llite/library/sdxl_lpw_stable_diffusion.py +1342 -0
- external/llite/library/sdxl_model_util.py +578 -0
- external/llite/library/sdxl_original_unet.py +1281 -0
- external/llite/library/sdxl_train_util.py +367 -0
- external/llite/library/slicing_vae.py +679 -0
- external/llite/library/train_util.py +0 -0
- external/llite/library/utils.py +6 -0
- external/llite/networks/.ipynb_checkpoints/control_net_lllite-checkpoint.py +446 -0
- external/llite/networks/check_lora_weights.py +45 -0
- external/llite/networks/control_net_lllite.py +446 -0
- external/llite/networks/control_net_lllite_for_train.py +502 -0
- external/llite/networks/dylora.py +450 -0
- external/llite/networks/extract_lora_from_dylora.py +125 -0
- external/llite/networks/extract_lora_from_models.py +296 -0
- external/llite/networks/lora.py +1225 -0
- external/llite/networks/lora_diffusers.py +609 -0
- external/llite/networks/lora_fa.py +1241 -0
- external/llite/networks/lora_interrogator.py +139 -0
- external/llite/networks/merge_lora.py +357 -0
- external/llite/networks/merge_lora_old.py +185 -0
- external/llite/networks/oft.py +430 -0
- external/llite/networks/resize_lora.py +362 -0
- external/llite/networks/sdxl_merge_lora.py +348 -0
- external/llite/networks/svd_merge_lora.py +260 -0
- external/llite/tools/cache_latents.py +194 -0
- external/llite/tools/cache_text_encoder_outputs.py +191 -0
- external/llite/tools/canny.py +30 -0
- external/llite/tools/convert_diffusers20_original_sd.py +160 -0
- external/llite/tools/detect_face_rotate.py +246 -0
- external/llite/tools/latent_upscaler.py +348 -0
- external/llite/tools/merge_models.py +168 -0
- external/llite/tools/original_control_net.py +337 -0
- external/llite/tools/resize_images_to_resolution.py +128 -0
- external/llite/tools/show_metadata.py +19 -0
- inference.py +38 -16
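
A commit with this title is typically produced with the huggingface_hub upload API. As a minimal sketch (the local path and repo_id below are placeholders, not taken from this repository):

# Sketch only: how an "Upload folder using huggingface_hub" commit is usually created.
from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` by default
api.upload_folder(
    folder_path=".",                 # placeholder: local folder containing external/llite, inference.py, ...
    repo_id="user/repo",             # placeholder: target repository
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)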
external/llite/library/__init__.py
ADDED
File without changes
external/llite/library/attention_processors.py
ADDED
@@ -0,0 +1,227 @@
import math
from typing import Any
from einops import rearrange
import torch
from diffusers.models.attention_processor import Attention


# flash attention forwards and backwards

# https://arxiv.org/abs/2205.14135

EPSILON = 1e-6


class FlashAttentionFunction(torch.autograd.function.Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""

        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full(
            (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
        )

        scale = q.shape[-1] ** -0.5

        if mask is None:
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if row_mask is not None:
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if row_mask is not None:
                    exp_weights.masked_fill_(~row_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
                    min=EPSILON
                )

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = torch.einsum(
                    "... i j, ... j d -> ... i d", exp_weights, vc
                )

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = (
                    exp_row_max_diff * row_sums
                    + exp_block_row_max_diff * block_row_sums
                )

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
                    (exp_block_row_max_diff / new_row_sums) * exp_values
                )

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper"""

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)

                if row_mask is not None:
                    exp_attn_weights.masked_fill_(~row_mask, 0.0)

                p = exp_attn_weights / lc

                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None


class FlashAttnProcessor:
    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ) -> Any:
        q_bucket_size = 512
        k_bucket_size = 1024

        h = attn.heads
        q = attn.to_q(hidden_states)

        encoder_hidden_states = (
            encoder_hidden_states
            if encoder_hidden_states is not None
            else hidden_states
        )
        encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)

        if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
            context_k, context_v = attn.hypernetwork.forward(
                hidden_states, encoder_hidden_states
            )
            context_k = context_k.to(hidden_states.dtype)
            context_v = context_v.to(hidden_states.dtype)
        else:
            context_k = encoder_hidden_states
            context_v = encoder_hidden_states

        k = attn.to_k(context_k)
        v = attn.to_v(context_v)
        del encoder_hidden_states, hidden_states

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = FlashAttentionFunction.apply(
            q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
        )

        out = rearrange(out, "b h n d -> b n (h d)")

        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out
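
The processor above can be swapped into a diffusers UNet via set_attn_processor. A minimal usage sketch, assuming a UNet2DConditionModel and a placeholder model id (the q/k bucket sizes are hard-coded inside the processor):

import torch
from diffusers import UNet2DConditionModel

# Placeholder model id; any SD-style UNet built from diffusers Attention modules works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
unet.set_attn_processor(FlashAttnProcessor())  # apply the bucketed flash attention to every Attention layer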
external/llite/library/config_util.py
ADDED
@@ -0,0 +1,621 @@
import argparse
from dataclasses import (
    asdict,
    dataclass,
)
import functools
import random
from textwrap import dedent, indent
import json
from pathlib import Path
# from toolz import curry
from typing import (
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import toml
import voluptuous
from voluptuous import (
    Any,
    ExactSequence,
    MultipleInvalid,
    Object,
    Required,
    Schema,
)
from transformers import CLIPTokenizer

from . import train_util
from .train_util import (
    DreamBoothSubset,
    FineTuningSubset,
    ControlNetSubset,
    DreamBoothDataset,
    FineTuningDataset,
    ControlNetDataset,
    DatasetGroup,
)


def add_config_arguments(parser: argparse.ArgumentParser):
    parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")


# TODO: inherit Params class in Subset, Dataset

@dataclass
class BaseSubsetParams:
    image_dir: Optional[str] = None
    num_repeats: int = 1
    shuffle_caption: bool = False
    caption_separator: str = ',',
    keep_tokens: int = 0
    keep_tokens_separator: str = None,
    color_aug: bool = False
    flip_aug: bool = False
    face_crop_aug_range: Optional[Tuple[float, float]] = None
    random_crop: bool = False
    caption_prefix: Optional[str] = None
    caption_suffix: Optional[str] = None
    caption_dropout_rate: float = 0.0
    caption_dropout_every_n_epochs: int = 0
    caption_tag_dropout_rate: float = 0.0
    token_warmup_min: int = 1
    token_warmup_step: float = 0

@dataclass
class DreamBoothSubsetParams(BaseSubsetParams):
    is_reg: bool = False
    class_tokens: Optional[str] = None
    caption_extension: str = ".caption"

@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
    metadata_file: Optional[str] = None

@dataclass
class ControlNetSubsetParams(BaseSubsetParams):
    conditioning_data_dir: str = None
    caption_extension: str = ".caption"

@dataclass
class BaseDatasetParams:
    tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
    max_token_length: int = None
    resolution: Optional[Tuple[int, int]] = None
    debug_dataset: bool = False

@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False
    prior_loss_weight: float = 1.0

@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False

@dataclass
class ControlNetDatasetParams(BaseDatasetParams):
    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False

@dataclass
class SubsetBlueprint:
    params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]

@dataclass
class DatasetBlueprint:
    is_dreambooth: bool
    is_controlnet: bool
    params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
    subsets: Sequence[SubsetBlueprint]

@dataclass
class DatasetGroupBlueprint:
    datasets: Sequence[DatasetBlueprint]

@dataclass
class Blueprint:
    dataset_group: DatasetGroupBlueprint


class ConfigSanitizer:
    # @curry
    @staticmethod
    def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
        Schema(ExactSequence([klass, klass]))(value)
        return tuple(value)

    # @curry
    @staticmethod
    def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
        Schema(Any(klass, ExactSequence([klass, klass])))(value)
        try:
            Schema(klass)(value)
            return (value, value)
        except:
            return ConfigSanitizer.__validate_and_convert_twodim(klass, value)

    # subset schema
    SUBSET_ASCENDABLE_SCHEMA = {
        "color_aug": bool,
        "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
        "flip_aug": bool,
        "num_repeats": int,
        "random_crop": bool,
        "shuffle_caption": bool,
        "keep_tokens": int,
        "keep_tokens_separator": str,
        "token_warmup_min": int,
        "token_warmup_step": Any(float, int),
        "caption_prefix": str,
        "caption_suffix": str,
    }
    # DO means DropOut
    DO_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_dropout_every_n_epochs": int,
        "caption_dropout_rate": Any(float, int),
        "caption_tag_dropout_rate": Any(float, int),
    }
    # DB means DreamBooth
    DB_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "class_tokens": str,
    }
    DB_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
        "is_reg": bool,
    }
    # FT means FineTuning
    FT_SUBSET_DISTINCT_SCHEMA = {
        Required("metadata_file"): str,
        "image_dir": str,
    }
    CN_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
    }
    CN_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
        Required("conditioning_data_dir"): str,
    }

    # datasets schema
    DATASET_ASCENDABLE_SCHEMA = {
        "batch_size": int,
        "bucket_no_upscale": bool,
        "bucket_reso_steps": int,
        "enable_bucket": bool,
        "max_bucket_reso": int,
        "min_bucket_reso": int,
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
    }

    # options handled by argparse but not handled by user config
    ARGPARSE_SPECIFIC_SCHEMA = {
        "debug_dataset": bool,
        "max_token_length": Any(None, int),
        "prior_loss_weight": Any(float, int),
    }
    # for handling default None value of argparse
    ARGPARSE_NULLABLE_OPTNAMES = [
        "face_crop_aug_range",
        "resolution",
    ]
    # prepare map because option name may differ among argparse and user config
    ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
        "train_batch_size": "batch_size",
        "dataset_repeats": "num_repeats",
    }

    def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
        assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"

        self.db_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_DISTINCT_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.ft_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.FT_SUBSET_DISTINCT_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.cn_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.CN_SUBSET_DISTINCT_SCHEMA,
            self.CN_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.db_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
            {"subsets": [self.db_subset_schema]},
        )

        self.ft_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
            {"subsets": [self.ft_subset_schema]},
        )

        self.cn_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.CN_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
            {"subsets": [self.cn_subset_schema]},
        )

        if support_dreambooth and support_finetuning:
            def validate_flex_dataset(dataset_config: dict):
                subsets_config = dataset_config.get("subsets", [])

                if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
                    return Schema(self.cn_dataset_schema)(dataset_config)
                # check dataset meets FT style
                # NOTE: all FT subsets should have "metadata_file"
                elif all(["metadata_file" in subset for subset in subsets_config]):
                    return Schema(self.ft_dataset_schema)(dataset_config)
                # check dataset meets DB style
                # NOTE: all DB subsets should have no "metadata_file"
                elif all(["metadata_file" not in subset for subset in subsets_config]):
                    return Schema(self.db_dataset_schema)(dataset_config)
                else:
                    raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。")

            self.dataset_schema = validate_flex_dataset
        elif support_dreambooth:
            self.dataset_schema = self.db_dataset_schema
        elif support_finetuning:
            self.dataset_schema = self.ft_dataset_schema
        elif support_controlnet:
            self.dataset_schema = self.cn_dataset_schema

        self.general_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
            self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.user_config_validator = Schema({
            "general": self.general_schema,
            "datasets": [self.dataset_schema],
        })

        self.argparse_schema = self.__merge_dict(
            self.general_schema,
            self.ARGPARSE_SPECIFIC_SCHEMA,
            {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
            {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
        )

        self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)

    def sanitize_user_config(self, user_config: dict) -> dict:
        try:
            return self.user_config_validator(user_config)
        except MultipleInvalid:
            # TODO: make the error message on validation failure easier to understand
            print("Invalid user config / ユーザ設定の形式が正しくないようです")
            raise

    # NOTE: In nature, argument parser result is not needed to be sanitize
    # However this will help us to detect program bug
    def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
        try:
            return self.argparse_config_validator(argparse_namespace)
        except MultipleInvalid:
            # XXX: this should be a bug
            print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
            raise

    # NOTE: value would be overwritten by latter dict if there is already the same key
    @staticmethod
    def __merge_dict(*dict_list: dict) -> dict:
        merged = {}
        for schema in dict_list:
            # merged |= schema
            for k, v in schema.items():
                merged[k] = v
        return merged


class BlueprintGenerator:
    BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {
    }

    def __init__(self, sanitizer: ConfigSanitizer):
        self.sanitizer = sanitizer

    # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
    def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
        sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
        sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)

        # convert argparse namespace to dict like config
        # NOTE: it is ok to have extra entries in dict
        optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
        argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()}

        general_config = sanitized_user_config.get("general", {})

        dataset_blueprints = []
        for dataset_config in sanitized_user_config.get("datasets", []):
            # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
            subsets = dataset_config.get("subsets", [])
            is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
            is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
            if is_controlnet:
                subset_params_klass = ControlNetSubsetParams
                dataset_params_klass = ControlNetDatasetParams
            elif is_dreambooth:
                subset_params_klass = DreamBoothSubsetParams
                dataset_params_klass = DreamBoothDatasetParams
            else:
                subset_params_klass = FineTuningSubsetParams
                dataset_params_klass = FineTuningDatasetParams

            subset_blueprints = []
            for subset_config in subsets:
                params = self.generate_params_by_fallbacks(subset_params_klass,
                                                           [subset_config, dataset_config, general_config, argparse_config, runtime_params])
                subset_blueprints.append(SubsetBlueprint(params))

            params = self.generate_params_by_fallbacks(dataset_params_klass,
                                                       [dataset_config, general_config, argparse_config, runtime_params])
            dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))

        dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)

        return Blueprint(dataset_group_blueprint)

    @staticmethod
    def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
        name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
        search_value = BlueprintGenerator.search_value
        default_params = asdict(param_klass())
        param_names = default_params.keys()

        params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}

        return param_klass(**params)

    @staticmethod
    def search_value(key: str, fallbacks: Sequence[dict], default_value = None):
        for cand in fallbacks:
            value = cand.get(key)
            if value is not None:
                return value

        return default_value


def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
    datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

    for dataset_blueprint in dataset_group_blueprint.datasets:
        if dataset_blueprint.is_controlnet:
            subset_klass = ControlNetSubset
            dataset_klass = ControlNetDataset
        elif dataset_blueprint.is_dreambooth:
            subset_klass = DreamBoothSubset
            dataset_klass = DreamBoothDataset
        else:
            subset_klass = FineTuningSubset
            dataset_klass = FineTuningDataset

        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
        dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
        datasets.append(dataset)

    # print info
    info = ""
    for i, dataset in enumerate(datasets):
        is_dreambooth = isinstance(dataset, DreamBoothDataset)
        is_controlnet = isinstance(dataset, ControlNetDataset)
        info += dedent(f"""\
            [Dataset {i}]
              batch_size: {dataset.batch_size}
              resolution: {(dataset.width, dataset.height)}
              enable_bucket: {dataset.enable_bucket}
            """)

        if dataset.enable_bucket:
            info += indent(dedent(f"""\
                min_bucket_reso: {dataset.min_bucket_reso}
                max_bucket_reso: {dataset.max_bucket_reso}
                bucket_reso_steps: {dataset.bucket_reso_steps}
                bucket_no_upscale: {dataset.bucket_no_upscale}
            \n"""), "  ")
        else:
            info += "\n"

        for j, subset in enumerate(dataset.subsets):
            info += indent(dedent(f"""\
                [Subset {j} of Dataset {i}]
                image_dir: "{subset.image_dir}"
                image_count: {subset.img_count}
                num_repeats: {subset.num_repeats}
                shuffle_caption: {subset.shuffle_caption}
                keep_tokens: {subset.keep_tokens}
                keep_tokens_separator: {subset.keep_tokens_separator}
                caption_dropout_rate: {subset.caption_dropout_rate}
                caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
                caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
                caption_prefix: {subset.caption_prefix}
                caption_suffix: {subset.caption_suffix}
                color_aug: {subset.color_aug}
                flip_aug: {subset.flip_aug}
                face_crop_aug_range: {subset.face_crop_aug_range}
                random_crop: {subset.random_crop}
                token_warmup_min: {subset.token_warmup_min},
                token_warmup_step: {subset.token_warmup_step},
            """), "  ")

            if is_dreambooth:
                info += indent(dedent(f"""\
                    is_reg: {subset.is_reg}
                    class_tokens: {subset.class_tokens}
                    caption_extension: {subset.caption_extension}
                \n"""), "    ")
            elif not is_controlnet:
                info += indent(dedent(f"""\
                    metadata_file: {subset.metadata_file}
                \n"""), "    ")

    print(info)

    # make buckets first because it determines the length of dataset
    # and set the same seed for all datasets
    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
    for i, dataset in enumerate(datasets):
        print(f"[Dataset {i}]")
        dataset.make_buckets()
        dataset.set_seed(seed)

    return DatasetGroup(datasets)


def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
    def extract_dreambooth_params(name: str) -> Tuple[int, str]:
        tokens = name.split('_')
        try:
            n_repeats = int(tokens[0])
        except ValueError as e:
            print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
            return 0, ""
        caption_by_folder = '_'.join(tokens[1:])
        return n_repeats, caption_by_folder

    def generate(base_dir: Optional[str], is_reg: bool):
        if base_dir is None:
            return []

        base_dir: Path = Path(base_dir)
        if not base_dir.is_dir():
            return []

        subsets_config = []
        for subdir in base_dir.iterdir():
            if not subdir.is_dir():
                continue

            num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
            if num_repeats < 1:
                continue

            subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
            subsets_config.append(subset_config)

        return subsets_config

    subsets_config = []
    subsets_config += generate(train_data_dir, False)
    subsets_config += generate(reg_data_dir, True)

    return subsets_config


def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"):
    def generate(base_dir: Optional[str]):
        if base_dir is None:
            return []

        base_dir: Path = Path(base_dir)
        if not base_dir.is_dir():
            return []

        subsets_config = []
        subset_config = {"image_dir": train_data_dir, "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1}
        subsets_config.append(subset_config)

        return subsets_config

    subsets_config = []
    subsets_config += generate(train_data_dir)

    return subsets_config


def load_user_config(file: str) -> dict:
    file: Path = Path(file)
    if not file.is_file():
        raise ValueError(f"file not found / ファイルが見つかりません: {file}")

    if file.name.lower().endswith('.json'):
        try:
            with open(file, 'r') as f:
                config = json.load(f)
        except Exception:
            print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
            raise
    elif file.name.lower().endswith('.toml'):
        try:
            config = toml.load(file)
        except Exception:
            print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
            raise
    else:
        raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")

    return config


# for config test
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--support_dreambooth", action="store_true")
    parser.add_argument("--support_finetuning", action="store_true")
    parser.add_argument("--support_controlnet", action="store_true")
    parser.add_argument("--support_dropout", action="store_true")
    parser.add_argument("dataset_config")
    config_args, remain = parser.parse_known_args()

    parser = argparse.ArgumentParser()
    train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
    train_util.add_training_arguments(parser, config_args.support_dreambooth)
    argparse_namespace = parser.parse_args(remain)
    train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)

    print("[argparse_namespace]")
    print(vars(argparse_namespace))

    user_config = load_user_config(config_args.dataset_config)

    print("\n[user_config]")
    print(user_config)

    sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout)
    sanitized_user_config = sanitizer.sanitize_user_config(user_config)

    print("\n[sanitized_user_config]")
    print(sanitized_user_config)

    blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)

    print("\n[blueprint]")
    print(blueprint)
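
The intended flow of this module is config file -> ConfigSanitizer -> BlueprintGenerator -> generate_dataset_group_by_blueprint. A minimal usage sketch, assuming a CLIP tokenizer, a placeholder config path, and an argparse namespace `args` built with train_util's add_dataset_arguments/add_training_arguments helpers:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
user_config = load_user_config("dataset_config.toml")  # placeholder path

sanitizer = ConfigSanitizer(
    support_dreambooth=True, support_finetuning=True,
    support_controlnet=True, support_dropout=True,
)
# `args` is assumed to come from train_util's argument parsers (see the __main__ block above).
blueprint = BlueprintGenerator(sanitizer).generate(user_config, args, tokenizer=tokenizer)
dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)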
external/llite/library/custom_train_functions.py
ADDED
@@ -0,0 +1,529 @@
import torch
import argparse
import random
import re
from typing import List, Optional, Union


def prepare_scheduler_for_custom_training(noise_scheduler, device):
    if hasattr(noise_scheduler, "all_snr"):
        return

    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    alpha = sqrt_alphas_cumprod
    sigma = sqrt_one_minus_alphas_cumprod
    all_snr = (alpha / sigma) ** 2

    noise_scheduler.all_snr = all_snr.to(device)


def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
    # fix beta: zero terminal SNR
    print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

    def enforce_zero_terminal_snr(betas):
        # Convert betas to alphas_bar_sqrt
        alphas = 1 - betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
        # Shift so last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        # Scale so first timestep is back to old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        # Convert alphas_bar_sqrt to betas
        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas
        return betas

    betas = noise_scheduler.betas
    betas = enforce_zero_terminal_snr(betas)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # print("original:", noise_scheduler.betas)
    # print("fixed:", betas)

    noise_scheduler.betas = betas
    noise_scheduler.alphas = alphas
    noise_scheduler.alphas_cumprod = alphas_cumprod


def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
    if v_prediction:
        snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
    else:
        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
    loss = loss * snr_weight
    return loss


def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
    scale = get_snr_scale(timesteps, noise_scheduler)
    loss = loss * scale
    return loss


def get_snr_scale(timesteps, noise_scheduler):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    scale = snr_t / (snr_t + 1)
    # # show debug info
    # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
    return scale


def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
    scale = get_snr_scale(timesteps, noise_scheduler)
    # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
    loss = loss + loss / scale * v_pred_like_loss
    return loss


def apply_debiased_estimation(loss, timesteps, noise_scheduler):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    weight = 1/torch.sqrt(snr_t)
    loss = weight * loss
    return loss


# TODO: this logic is split between here and train_util; consolidate it in one place


def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
    parser.add_argument(
        "--min_snr_gamma",
        type=float,
        default=None,
        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
    )
    parser.add_argument(
        "--scale_v_pred_loss_like_noise_pred",
        action="store_true",
        help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
    )
    parser.add_argument(
        "--v_pred_like_loss",
        type=float,
        default=None,
        help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
    )
    parser.add_argument(
        "--debiased_estimation_loss",
        action="store_true",
        help="debiased estimation loss / debiased estimation loss",
    )
    if support_weighted_captions:
        parser.add_argument(
            "--weighted_captions",
            action="store_true",
            default=False,
            help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
        )


re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return its tokens with weights of each token.

    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = tokenizer(word).input_ids[1:-1]
            text_token += token
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


def get_unweighted_text_embeddings(
    tokenizer,
    text_encoder,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            if pad == eos:  # v1
                text_input_chunk[:, -1] = text_input[0, -1]
            else:  # v2
                for j in range(len(text_input_chunk)):
                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the chunk ends with a regular token (not EOS/PAD)
                        text_input_chunk[j, -1] = eos
                    if text_input_chunk[j, 1] == pad:  # only BOS, the rest is PAD
                        text_input_chunk[j, 1] = eos

            if clip_skip is None or clip_skip == 1:
                text_embedding = text_encoder(text_input_chunk)[0]
            else:
                enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
                text_embedding = enc_out["hidden_states"][-clip_skip]
                text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        if clip_skip is None or clip_skip == 1:
            text_embeddings = text_encoder(text_input)[0]
        else:
            enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
            text_embeddings = enc_out["hidden_states"][-clip_skip]
            text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
    return text_embeddings


def get_weighted_text_embeddings(
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    device,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    clip_skip=None,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

    Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
            ending token in each of the chunk in the middle.
        skip_parsing (`bool`, *optional*, defaults to `False`):
            Skip the parsing of brackets.
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When the parsing is skipped, it is forced True.
    """
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = tokenizer.bos_token_id
    eos = tokenizer.eos_token_id
    pad = tokenizer.pad_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        tokenizer,
        text_encoder,
        prompt_tokens,
        tokenizer.model_max_length,
        clip_skip,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)

    # assign weights to the prompts and normalize in the sense of mean
    previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
    current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    return text_embeddings


# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
    b, c, w, h = noise.shape  # EDIT: w and h get over-written, rename for a different variant!
    u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
    for i in range(iterations):
        r = random.random() * 2 + 2  # Rather than always going 2x,
        wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
        noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
|
452 |
+
if wn == 1 or hn == 1:
|
453 |
+
break # Lowest resolution is 1x1
|
454 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
455 |
+
|
456 |
+
|
457 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
458 |
+
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
459 |
+
if noise_offset is None:
|
460 |
+
return noise
|
461 |
+
if adaptive_noise_scale is not None:
|
462 |
+
# latent shape: (batch_size, channels, height, width)
|
463 |
+
# abs mean value for each channel
|
464 |
+
latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
|
465 |
+
|
466 |
+
# multiply adaptive noise scale to the mean value and add it to the noise offset
|
467 |
+
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
|
468 |
+
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
469 |
+
|
470 |
+
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
471 |
+
return noise
|
472 |
+
|
473 |
+
|
474 |
+
"""
|
475 |
+
##########################################
|
476 |
+
# Perlin Noise
|
477 |
+
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
478 |
+
delta = (res[0] / shape[0], res[1] / shape[1])
|
479 |
+
d = (shape[0] // res[0], shape[1] // res[1])
|
480 |
+
|
481 |
+
grid = (
|
482 |
+
torch.stack(
|
483 |
+
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
484 |
+
dim=-1,
|
485 |
+
)
|
486 |
+
% 1
|
487 |
+
)
|
488 |
+
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
489 |
+
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
490 |
+
|
491 |
+
tile_grads = (
|
492 |
+
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
493 |
+
.repeat_interleave(d[0], 0)
|
494 |
+
.repeat_interleave(d[1], 1)
|
495 |
+
)
|
496 |
+
dot = lambda grad, shift: (
|
497 |
+
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
498 |
+
* grad[: shape[0], : shape[1]]
|
499 |
+
).sum(dim=-1)
|
500 |
+
|
501 |
+
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
502 |
+
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
503 |
+
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
504 |
+
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
505 |
+
t = fade(grid[: shape[0], : shape[1]])
|
506 |
+
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
507 |
+
|
508 |
+
|
509 |
+
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
510 |
+
noise = torch.zeros(shape, device=device)
|
511 |
+
frequency = 1
|
512 |
+
amplitude = 1
|
513 |
+
for _ in range(octaves):
|
514 |
+
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
515 |
+
frequency *= 2
|
516 |
+
amplitude *= persistence
|
517 |
+
return noise
|
518 |
+
|
519 |
+
|
520 |
+
def perlin_noise(noise, device, octaves):
|
521 |
+
_, c, w, h = noise.shape
|
522 |
+
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
523 |
+
noise_perlin = []
|
524 |
+
for _ in range(c):
|
525 |
+
noise_perlin.append(perlin())
|
526 |
+
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
527 |
+
noise += noise_perlin # broadcast for each batch
|
528 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
529 |
+
"""
|
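For orientation, here is a minimal usage sketch of get_weighted_text_embeddings as defined above (editor's sketch, not part of the diff; the CLIP checkpoint name and prompt are illustrative assumptions):

    # Illustrative sketch only: checkpoint name and prompt are assumptions, not taken from this repo.
    import torch
    from transformers import CLIPTokenizer, CLIPTextModel

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # "(very beautiful)" multiplies those token embeddings by 1.1; prompts longer than
    # one chunk are split into chunks and re-joined by the helpers above.
    embeddings = get_weighted_text_embeddings(
        tokenizer,
        text_encoder,
        "A (very beautiful) masterpiece",
        device="cpu",
        max_embeddings_multiples=3,
        clip_skip=None,
    )
    print(embeddings.shape)  # (1, sequence_length, hidden_size)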
external/llite/library/huggingface_util.py ADDED @@ -0,0 +1,81 @@
from typing import Union, BinaryIO
from huggingface_hub import HfApi
from pathlib import Path
import argparse
import os
from external.llite.library.utils import fire_in_thread


def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
    api = HfApi(
        token=token,
    )
    try:
        api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
        return True
    except:
        return False


def upload(
    args: argparse.Namespace,
    src: Union[str, Path, bytes, BinaryIO],
    dest_suffix: str = "",
    force_sync_upload: bool = False,
):
    repo_id = args.huggingface_repo_id
    repo_type = args.huggingface_repo_type
    token = args.huggingface_token
    path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
    private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
    api = HfApi(token=token)
    if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
        try:
            api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
        except Exception as e:  # RepositoryNotFoundError has been confirmed; catch broadly in case other errors occur
            print("===========================================")
            print(f"failed to create HuggingFace repo: {e}")
            print("===========================================")

    is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())

    def uploader():
        try:
            if is_folder:
                api.upload_folder(
                    repo_id=repo_id,
                    repo_type=repo_type,
                    folder_path=src,
                    path_in_repo=path_in_repo,
                )
            else:
                api.upload_file(
                    repo_id=repo_id,
                    repo_type=repo_type,
                    path_or_fileobj=src,
                    path_in_repo=path_in_repo,
                )
        except Exception as e:  # RuntimeError has been confirmed; catch broadly in case other errors occur
            print("===========================================")
            print(f"failed to upload to HuggingFace: {e}")
            print("===========================================")

    if args.async_upload and not force_sync_upload:
        fire_in_thread(uploader)
    else:
        uploader()


def list_dir(
    repo_id: str,
    subfolder: str,
    repo_type: str,
    revision: str = "main",
    token: str = None,
):
    api = HfApi(
        token=token,
    )
    repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
    file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
    return file_list
external/llite/library/hypernetwork.py ADDED @@ -0,0 +1,223 @@
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import (
    Attention,
    AttnProcessor2_0,
    SlicedAttnProcessor,
    XFormersAttnProcessor
)

try:
    import xformers.ops
except:
    xformers = None


loaded_networks = []


def apply_single_hypernetwork(
    hypernetwork, hidden_states, encoder_hidden_states
):
    context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
    return context_k, context_v


def apply_hypernetworks(context_k, context_v, layer=None):
    if len(loaded_networks) == 0:
        return context_v, context_v
    for hypernetwork in loaded_networks:
        context_k, context_v = hypernetwork.forward(context_k, context_v)

    context_k = context_k.to(dtype=context_k.dtype)
    context_v = context_v.to(dtype=context_k.dtype)

    return context_k, context_v


def xformers_forward(
    self: XFormersAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )

    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    query = attn.head_to_batch_dim(query).contiguous()
    key = attn.head_to_batch_dim(key).contiguous()
    value = attn.head_to_batch_dim(value).contiguous()

    hidden_states = xformers.ops.memory_efficient_attention(
        query,
        key,
        value,
        attn_bias=attention_mask,
        op=self.attention_op,
        scale=attn.scale,
    )
    hidden_states = hidden_states.to(query.dtype)
    hidden_states = attn.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states


def sliced_attn_forward(
    self: SlicedAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )
    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)
    dim = query.shape[-1]
    query = attn.head_to_batch_dim(query)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)
    key = attn.head_to_batch_dim(key)
    value = attn.head_to_batch_dim(value)

    batch_size_attention, query_tokens, _ = query.shape
    hidden_states = torch.zeros(
        (batch_size_attention, query_tokens, dim // attn.heads),
        device=query.device,
        dtype=query.dtype,
    )

    for i in range(batch_size_attention // self.slice_size):
        start_idx = i * self.slice_size
        end_idx = (i + 1) * self.slice_size

        query_slice = query[start_idx:end_idx]
        key_slice = key[start_idx:end_idx]
        attn_mask_slice = (
            attention_mask[start_idx:end_idx] if attention_mask is not None else None
        )

        attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

        attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

        hidden_states[start_idx:end_idx] = attn_slice

    hidden_states = attn.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)

    return hidden_states


def v2_0_forward(
    self: AttnProcessor2_0,
    attn: Attention,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
):
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )
    inner_dim = hidden_states.shape[-1]

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )
        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.view(
            batch_size, attn.heads, -1, attention_mask.shape[-1]
        )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    head_dim = inner_dim // attn.heads
    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(
        batch_size, -1, attn.heads * head_dim
    )
    hidden_states = hidden_states.to(query.dtype)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states


def replace_attentions_for_hypernetwork():
    import diffusers.models.attention_processor

    diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
        xformers_forward
    )
    diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
        sliced_attn_forward
    )
    diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
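As a rough usage note (an assumption about intent, not documented in the diff), the patch above is applied globally and every network in loaded_networks then rewrites the cross-attention context:

    # Illustrative sketch only: IdentityHypernetwork is a placeholder; a real
    # hypernetwork returns transformed (context_k, context_v) tensors.
    import torch

    class IdentityHypernetwork(torch.nn.Module):
        def forward(self, context_k, context_v):
            return context_k, context_v

    loaded_networks.append(IdentityHypernetwork())
    replace_attentions_for_hypernetwork()  # monkey-patches diffusers' attention processors in place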
external/llite/library/ipex/__init__.py ADDED @@ -0,0 +1,169 @@
import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks

# pylint: disable=protected-access, missing-function-docstring, line-too-long

def ipex_init():  # pylint: disable=too-many-statements
    try:
        # Replace cuda with xpu:
        torch.cuda.current_device = torch.xpu.current_device
        torch.cuda.current_stream = torch.xpu.current_stream
        torch.cuda.device = torch.xpu.device
        torch.cuda.device_count = torch.xpu.device_count
        torch.cuda.device_of = torch.xpu.device_of
        torch.cuda.get_device_name = torch.xpu.get_device_name
        torch.cuda.get_device_properties = torch.xpu.get_device_properties
        torch.cuda.init = torch.xpu.init
        torch.cuda.is_available = torch.xpu.is_available
        torch.cuda.is_initialized = torch.xpu.is_initialized
        torch.cuda.is_current_stream_capturing = lambda: False
        torch.cuda.set_device = torch.xpu.set_device
        torch.cuda.stream = torch.xpu.stream
        torch.cuda.synchronize = torch.xpu.synchronize
        torch.cuda.Event = torch.xpu.Event
        torch.cuda.Stream = torch.xpu.Stream
        torch.cuda.FloatTensor = torch.xpu.FloatTensor
        torch.Tensor.cuda = torch.Tensor.xpu
        torch.Tensor.is_cuda = torch.Tensor.is_xpu
        torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
        torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
        torch.cuda._initialized = torch.xpu.lazy_init._initialized
        torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
        torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
        torch.cuda._tls = torch.xpu.lazy_init._tls
        torch.cuda.threading = torch.xpu.lazy_init.threading
        torch.cuda.traceback = torch.xpu.lazy_init.traceback
        torch.cuda.Optional = torch.xpu.Optional
        torch.cuda.__cached__ = torch.xpu.__cached__
        torch.cuda.__loader__ = torch.xpu.__loader__
        torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
        torch.cuda.Tuple = torch.xpu.Tuple
        torch.cuda.streams = torch.xpu.streams
        torch.cuda._lazy_new = torch.xpu._lazy_new
        torch.cuda.FloatStorage = torch.xpu.FloatStorage
        torch.cuda.Any = torch.xpu.Any
        torch.cuda.__doc__ = torch.xpu.__doc__
        torch.cuda.default_generators = torch.xpu.default_generators
        torch.cuda.HalfTensor = torch.xpu.HalfTensor
        torch.cuda._get_device_index = torch.xpu._get_device_index
        torch.cuda.__path__ = torch.xpu.__path__
        torch.cuda.Device = torch.xpu.Device
        torch.cuda.IntTensor = torch.xpu.IntTensor
        torch.cuda.ByteStorage = torch.xpu.ByteStorage
        torch.cuda.set_stream = torch.xpu.set_stream
        torch.cuda.BoolStorage = torch.xpu.BoolStorage
        torch.cuda.os = torch.xpu.os
        torch.cuda.torch = torch.xpu.torch
        torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
        torch.cuda.Union = torch.xpu.Union
        torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
        torch.cuda.ShortTensor = torch.xpu.ShortTensor
        torch.cuda.LongTensor = torch.xpu.LongTensor
        torch.cuda.IntStorage = torch.xpu.IntStorage
        torch.cuda.LongStorage = torch.xpu.LongStorage
        torch.cuda.__annotations__ = torch.xpu.__annotations__
        torch.cuda.__package__ = torch.xpu.__package__
        torch.cuda.__builtins__ = torch.xpu.__builtins__
        torch.cuda.CharTensor = torch.xpu.CharTensor
        torch.cuda.List = torch.xpu.List
        torch.cuda._lazy_init = torch.xpu._lazy_init
        torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
        torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
        torch.cuda.ByteTensor = torch.xpu.ByteTensor
        torch.cuda.StreamContext = torch.xpu.StreamContext
        torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
        torch.cuda.ShortStorage = torch.xpu.ShortStorage
        torch.cuda._lazy_call = torch.xpu._lazy_call
        torch.cuda.HalfStorage = torch.xpu.HalfStorage
        torch.cuda.random = torch.xpu.random
        torch.cuda._device = torch.xpu._device
        torch.cuda.classproperty = torch.xpu.classproperty
        torch.cuda.__name__ = torch.xpu.__name__
        torch.cuda._device_t = torch.xpu._device_t
        torch.cuda.warnings = torch.xpu.warnings
        torch.cuda.__spec__ = torch.xpu.__spec__
        torch.cuda.BoolTensor = torch.xpu.BoolTensor
        torch.cuda.CharStorage = torch.xpu.CharStorage
        torch.cuda.__file__ = torch.xpu.__file__
        torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
        # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing

        # Memory:
        torch.cuda.memory = torch.xpu.memory
        if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
            torch.xpu.empty_cache = lambda: None
        torch.cuda.empty_cache = torch.xpu.empty_cache
        torch.cuda.memory_stats = torch.xpu.memory_stats
        torch.cuda.memory_summary = torch.xpu.memory_summary
        torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
        torch.cuda.memory_allocated = torch.xpu.memory_allocated
        torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
        torch.cuda.memory_reserved = torch.xpu.memory_reserved
        torch.cuda.memory_cached = torch.xpu.memory_reserved
        torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
        torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
        torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
        torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
        torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
        torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
        torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats

        # RNG:
        torch.cuda.get_rng_state = torch.xpu.get_rng_state
        torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
        torch.cuda.set_rng_state = torch.xpu.set_rng_state
        torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
        torch.cuda.manual_seed = torch.xpu.manual_seed
        torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
        torch.cuda.seed = torch.xpu.seed
        torch.cuda.seed_all = torch.xpu.seed_all
        torch.cuda.initial_seed = torch.xpu.initial_seed

        # AMP:
        torch.cuda.amp = torch.xpu.amp
        if not hasattr(torch.cuda.amp, "common"):
            torch.cuda.amp.common = contextlib.nullcontext()
        torch.cuda.amp.common.amp_definitely_not_available = lambda: False
        try:
            torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
        except Exception:  # pylint: disable=broad-exception-caught
            try:
                from .gradscaler import gradscaler_init  # pylint: disable=import-outside-toplevel, import-error
                gradscaler_init()
                torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
            except Exception:  # pylint: disable=broad-exception-caught
                torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler

        # C
        torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
        ipex._C._DeviceProperties.major = 2023
        ipex._C._DeviceProperties.minor = 2

        # Fix functions with ipex:
        torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
        torch._utils._get_available_device_type = lambda: "xpu"
        torch.has_cuda = True
        torch.cuda.has_half = True
        torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
        torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
        torch.version.cuda = "11.7"
        torch.cuda.get_device_capability = lambda *args, **kwargs: [11, 7]
        torch.cuda.get_device_properties.major = 11
        torch.cuda.get_device_properties.minor = 7
        torch.cuda.ipc_collect = lambda *args, **kwargs: None
        torch.cuda.utilization = lambda *args, **kwargs: 0

        ipex_hijacks()
        if not torch.xpu.has_fp64_dtype():
            try:
                from .diffusers import ipex_diffusers
                ipex_diffusers()
            except Exception:  # pylint: disable=broad-exception-caught
                pass
    except Exception as e:
        return False, e
    return True, None
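A minimal sketch of how this shim is typically activated (editor's sketch, assuming an XPU build of PyTorch plus intel_extension_for_pytorch is installed):

    # Illustrative sketch only: requires an Intel XPU device and IPEX.
    from external.llite.library.ipex import ipex_init

    ok, err = ipex_init()
    if not ok:
        print(f"IPEX initialization failed: {err}")
    else:
        import torch
        print(torch.cuda.is_available())  # now answered by torch.xpu under the hood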
external/llite/library/ipex/attention.py ADDED @@ -0,0 +1,151 @@
import torch
import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import

# pylint: disable=protected-access, missing-function-docstring, line-too-long

original_torch_bmm = torch.bmm
def torch_bmm_32_bit(input, mat2, *, out=None):
    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
    batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
    block_multiply = input.element_size()
    slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
    block_size = batch_size_attention * slice_block_size

    split_slice_size = batch_size_attention
    if block_size > 4:
        do_split = True
        # Find something divisible with the input_tokens
        while (split_slice_size * slice_block_size) > 4:
            split_slice_size = split_slice_size // 2
            if split_slice_size <= 1:
                split_slice_size = 1
                break
        split_2_slice_size = input_tokens
        if split_slice_size * slice_block_size > 4:
            slice_block_size_2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
            do_split_2 = True
            # Find something divisible with the input_tokens
            while (split_2_slice_size * slice_block_size_2) > 4:
                split_2_slice_size = split_2_slice_size // 2
                if split_2_slice_size <= 1:
                    split_2_slice_size = 1
                    break
        else:
            do_split_2 = False
    else:
        do_split = False

    if do_split:
        hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
        for i in range(batch_size_attention // split_slice_size):
            start_idx = i * split_slice_size
            end_idx = (i + 1) * split_slice_size
            if do_split_2:
                for i2 in range(input_tokens // split_2_slice_size):  # pylint: disable=invalid-name
                    start_idx_2 = i2 * split_2_slice_size
                    end_idx_2 = (i2 + 1) * split_2_slice_size
                    hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
                        input[start_idx:end_idx, start_idx_2:end_idx_2],
                        mat2[start_idx:end_idx, start_idx_2:end_idx_2],
                        out=out
                    )
            else:
                hidden_states[start_idx:end_idx] = original_torch_bmm(
                    input[start_idx:end_idx],
                    mat2[start_idx:end_idx],
                    out=out
                )
    else:
        return original_torch_bmm(input, mat2, out=out)
    return hidden_states

original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
    if len(query.shape) == 3:
        batch_size_attention, query_tokens, shape_three = query.shape
        shape_four = 1
    else:
        batch_size_attention, query_tokens, shape_three, shape_four = query.shape

    block_multiply = query.element_size()
    slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * block_multiply
    block_size = batch_size_attention * slice_block_size

    split_slice_size = batch_size_attention
    if block_size > 4:
        do_split = True
        # Find something divisible with the batch_size_attention
        while (split_slice_size * slice_block_size) > 4:
            split_slice_size = split_slice_size // 2
            if split_slice_size <= 1:
                split_slice_size = 1
                break
        split_2_slice_size = query_tokens
        if split_slice_size * slice_block_size > 4:
            slice_block_size_2 = split_slice_size * shape_three * shape_four / 1024 / 1024 * block_multiply
            do_split_2 = True
            # Find something divisible with the query_tokens
            while (split_2_slice_size * slice_block_size_2) > 4:
                split_2_slice_size = split_2_slice_size // 2
                if split_2_slice_size <= 1:
                    split_2_slice_size = 1
                    break
            split_3_slice_size = shape_three
            if split_2_slice_size * slice_block_size_2 > 4:
                slice_block_size_3 = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * block_multiply
                do_split_3 = True
                # Find something divisible with the shape_three
                while (split_3_slice_size * slice_block_size_3) > 4:
                    split_3_slice_size = split_3_slice_size // 2
                    if split_3_slice_size <= 1:
                        split_3_slice_size = 1
                        break
            else:
                do_split_3 = False
        else:
            do_split_2 = False
    else:
        do_split = False

    if do_split:
        hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
        for i in range(batch_size_attention // split_slice_size):
            start_idx = i * split_slice_size
            end_idx = (i + 1) * split_slice_size
            if do_split_2:
                for i2 in range(query_tokens // split_2_slice_size):  # pylint: disable=invalid-name
                    start_idx_2 = i2 * split_2_slice_size
                    end_idx_2 = (i2 + 1) * split_2_slice_size
                    if do_split_3:
                        for i3 in range(shape_three // split_3_slice_size):  # pylint: disable=invalid-name
                            start_idx_3 = i3 * split_3_slice_size
                            end_idx_3 = (i3 + 1) * split_3_slice_size
                            hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
                                query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
                                dropout_p=dropout_p, is_causal=is_causal
                            )
                    else:
                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
                            query[start_idx:end_idx, start_idx_2:end_idx_2],
                            key[start_idx:end_idx, start_idx_2:end_idx_2],
                            value[start_idx:end_idx, start_idx_2:end_idx_2],
                            attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
                            dropout_p=dropout_p, is_causal=is_causal
                        )
            else:
                hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
                    query[start_idx:end_idx],
                    key[start_idx:end_idx],
                    value[start_idx:end_idx],
                    attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
                    dropout_p=dropout_p, is_causal=is_causal
                )
    else:
        return original_scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
        )
    return hidden_states
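Both wrappers above keep the signatures of their torch counterparts and only fall back to slicing when the estimated block size is too large; a minimal sketch of exercising one of them directly (editor's sketch; the tensor shapes are arbitrary assumptions):

    # Illustrative sketch only: small shapes, so no slicing is triggered and the
    # wrapper simply delegates to the original implementation.
    import torch

    q = torch.randn(8, 64, 32)
    k = torch.randn(8, 64, 32)
    v = torch.randn(8, 64, 32)

    out = scaled_dot_product_attention_32_bit(q, k, v)
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    print(torch.allclose(out, ref, atol=1e-5))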
external/llite/library/ipex/diffusers.py ADDED @@ -0,0 +1,120 @@
import torch
import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
import diffusers  # 0.24.0 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention

# pylint: disable=protected-access, missing-function-docstring, line-too-long

class SlicedAttnProcessor:  # pylint: disable=too-few-public-methods
    r"""
    Processor for implementing sliced attention.

    Args:
        slice_size (`int`, *optional*):
            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
            `attention_head_dim` must be a multiple of the `slice_size`.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):  # pylint: disable=too-many-statements, too-many-locals, too-many-branches
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention, query_tokens, shape_three = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
        block_multiply = query.element_size()
        slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
        block_size = query_tokens * slice_block_size
        split_2_slice_size = query_tokens
        if block_size > 4:
            do_split_2 = True
            # Find something divisible with the query_tokens
            while (split_2_slice_size * slice_block_size) > 4:
                split_2_slice_size = split_2_slice_size // 2
                if split_2_slice_size <= 1:
                    split_2_slice_size = 1
                    break
        else:
            do_split_2 = False

        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            if do_split_2:
                for i2 in range(query_tokens // split_2_slice_size):  # pylint: disable=invalid-name
                    start_idx_2 = i2 * split_2_slice_size
                    end_idx_2 = (i2 + 1) * split_2_slice_size

                    query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
                    key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
                    attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None

                    attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                    attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])

                    hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
            else:
                query_slice = query[start_idx:end_idx]
                key_slice = key[start_idx:end_idx]
                attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

                attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

                attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

                hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

def ipex_diffusers():
    # ARC GPUs can't allocate more than 4GB to a single block:
    diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
external/llite/library/ipex/gradscaler.py ADDED @@ -0,0 +1,183 @@
from collections import defaultdict
import torch
import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch._C as core  # pylint: disable=import-error, unused-import

# pylint: disable=protected-access, missing-function-docstring, line-too-long

device_supports_fp64 = torch.xpu.has_fp64_dtype()
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state

def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):  # pylint: disable=unused-argument
    per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
    per_device_found_inf = _MultiDeviceReplicator(found_inf)

    # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
    # There could be hundreds of grads, so we'd like to iterate through them just once.
    # However, we don't know their devices or dtypes in advance.

    # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
    # Google says mypy struggles with defaultdicts type annotations.
    per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
    # sync grad to master weight
    if hasattr(optimizer, "sync_grad"):
        optimizer.sync_grad()
    with torch.no_grad():
        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                if (not allow_fp16) and param.grad.dtype == torch.float16:
                    raise ValueError("Attempting to unscale FP16 gradients.")
                if param.grad.is_sparse:
                    # is_coalesced() == False means the sparse grad has values with duplicate indices.
                    # coalesce() deduplicates indices and adds all values that have the same index.
                    # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                    # so we should check the coalesced _values().
                    if param.grad.dtype is torch.float16:
                        param.grad = param.grad.coalesce()
                    to_unscale = param.grad._values()
                else:
                    to_unscale = param.grad

                # -: is there a way to split by device and dtype without appending in the inner loop?
                to_unscale = to_unscale.to("cpu")
                per_device_and_dtype_grads[to_unscale.device][
                    to_unscale.dtype
                ].append(to_unscale)

    for _, per_dtype_grads in per_device_and_dtype_grads.items():
        for grads in per_dtype_grads.values():
            core._amp_foreach_non_finite_check_and_unscale_(
                grads,
                per_device_found_inf.get("cpu"),
                per_device_inv_scale.get("cpu"),
            )

    return per_device_found_inf._per_device_tensors

def unscale_(self, optimizer):
    """
    Divides ("unscales") the optimizer's gradient tensors by the scale factor.
    :meth:`unscale_` is optional, serving cases where you need to
    :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
    between the backward pass(es) and :meth:`step`.
    If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
    Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
        ...
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()
    Args:
        optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
    .. warning::
        :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
        and only after all gradients for that optimizer's assigned parameters have been accumulated.
        Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
    .. warning::
        :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
    """
    if not self._enabled:
        return

    self._check_scale_growth_tracker("unscale_")

    optimizer_state = self._per_optimizer_states[id(optimizer)]

    if optimizer_state["stage"] is OptState.UNSCALED:  # pylint: disable=no-else-raise
        raise RuntimeError(
            "unscale_() has already been called on this optimizer since the last update()."
        )
    elif optimizer_state["stage"] is OptState.STEPPED:
        raise RuntimeError("unscale_() is being called after step().")

    # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
    assert self._scale is not None
    if device_supports_fp64:
        inv_scale = self._scale.double().reciprocal().float()
    else:
        inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
    found_inf = torch.full(
        (1,), 0.0, dtype=torch.float32, device=self._scale.device
    )

    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
        optimizer, inv_scale, found_inf, False
    )
    optimizer_state["stage"] = OptState.UNSCALED

def update(self, new_scale=None):
    """
    Updates the scale factor.
    If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
    to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
    the scale is multiplied by ``growth_factor`` to increase it.
    Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
    used directly, it's used to fill GradScaler's internal scale tensor. So if
    ``new_scale`` was a tensor, later in-place changes to that tensor will not further
    affect the scale GradScaler uses internally.)
    Args:
        new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
    .. warning::
        :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
        been invoked for all optimizers used this iteration.
    """
    if not self._enabled:
        return

    _scale, _growth_tracker = self._check_scale_growth_tracker("update")

    if new_scale is not None:
        # Accept a new user-defined scale.
        if isinstance(new_scale, float):
            self._scale.fill_(new_scale)  # type: ignore[union-attr]
        else:
            reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
            assert isinstance(new_scale, torch.FloatTensor), reason  # type: ignore[attr-defined]
            assert new_scale.numel() == 1, reason
            assert new_scale.requires_grad is False, reason
            self._scale.copy_(new_scale)  # type: ignore[union-attr]
    else:
        # Consume shared inf/nan data collected from optimizers to update the scale.
        # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
        found_infs = [
            found_inf.to(device="cpu", non_blocking=True)
            for state in self._per_optimizer_states.values()
            for found_inf in state["found_inf_per_device"].values()
        ]

        assert len(found_infs) > 0, "No inf checks were recorded prior to update."

        found_inf_combined = found_infs[0]
        if len(found_infs) > 1:
            for i in range(1, len(found_infs)):
                found_inf_combined += found_infs[i]

        to_device = _scale.device
        _scale = _scale.to("cpu")
        _growth_tracker = _growth_tracker.to("cpu")

        core._amp_update_scale_(
            _scale,
            _growth_tracker,
            found_inf_combined,
            self._growth_factor,
            self._backoff_factor,
            self._growth_interval,
        )

        _scale = _scale.to(to_device)
        _growth_tracker = _growth_tracker.to(to_device)
    # To prepare for next iteration, clear the data collected from optimizers this iteration.
    self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

def gradscaler_init():
    torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
    torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
    torch.xpu.amp.GradScaler.unscale_ = unscale_
    torch.xpu.amp.GradScaler.update = update
    return torch.xpu.amp.GradScaler
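The patched class keeps the standard GradScaler calling convention (the unscale_ docstring above shows the same pattern); a minimal training-step sketch under that assumption (editor's sketch; model, optimizer and data are placeholders and an XPU device is assumed):

    # Illustrative sketch only: placeholders throughout; requires an XPU device.
    import torch

    model = torch.nn.Linear(4, 1).to("xpu")
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = gradscaler_init()()  # gradscaler_init() returns the patched class; instantiate it

    x = torch.randn(2, 4, device="xpu")
    with torch.autocast("xpu", dtype=torch.float16):
        loss = model(x).mean()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # runs the CPU-side inf/NaN check defined above
    scaler.step(optimizer)
    scaler.update()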
external/llite/library/ipex/hijacks.py ADDED @@ -0,0 +1,252 @@
import contextlib
import importlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import

# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return

class CondFunc: # pylint: disable=missing-class-docstring
    def __new__(cls, orig_func, sub_func, cond_func):
        self = super(CondFunc, cls).__new__(cls)
        if isinstance(orig_func, str):
            func_path = orig_func.split('.')
            for i in range(len(func_path)-1, -1, -1):
                try:
                    resolved_obj = importlib.import_module('.'.join(func_path[:i]))
                    break
                except ImportError:
                    pass
            for attr_name in func_path[i:-1]:
                resolved_obj = getattr(resolved_obj, attr_name)
            orig_func = getattr(resolved_obj, func_path[-1])
            setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
        self.__init__(orig_func, sub_func, cond_func)
        return lambda *args, **kwargs: self(*args, **kwargs)
    def __init__(self, orig_func, sub_func, cond_func):
        self.__orig_func = orig_func
        self.__sub_func = sub_func
        self.__cond_func = cond_func
    def __call__(self, *args, **kwargs):
        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
            return self.__sub_func(self.__orig_func, *args, **kwargs)
        else:
            return self.__orig_func(*args, **kwargs)

_utils = torch.utils.data._utils
def _shutdown_workers(self):
    if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
        return
    if hasattr(self, "_shutdown") and not self._shutdown:
        self._shutdown = True
        try:
            if hasattr(self, '_pin_memory_thread'):
                self._pin_memory_thread_done_event.set()
                self._worker_result_queue.put((None, None))
                self._pin_memory_thread.join()
                self._worker_result_queue.cancel_join_thread()
                self._worker_result_queue.close()
            self._workers_done_event.set()
            for worker_id in range(len(self._workers)):
                if self._persistent_workers or self._workers_status[worker_id]:
                    self._mark_worker_as_unavailable(worker_id, shutdown=True)
            for w in self._workers: # pylint: disable=invalid-name
                w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
            for q in self._index_queues: # pylint: disable=invalid-name
                q.cancel_join_thread()
                q.close()
        finally:
            if self._worker_pids_set:
                torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
                self._worker_pids_set = False
            for w in self._workers: # pylint: disable=invalid-name
                if w.is_alive():
                    w.terminate()

class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
    def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
        if isinstance(device_ids, list) and len(device_ids) > 1:
            print("IPEX backend doesn't support DataParallel on multiple XPU devices")
        return module.to("xpu")

def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
    return contextlib.nullcontext()

def check_device(device):
    return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))

def return_xpu(device):
    return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"

def ipex_no_cuda(orig_func, *args, **kwargs):
    torch.cuda.is_available = lambda: False
    orig_func(*args, **kwargs)
    torch.cuda.is_available = torch.xpu.is_available

original_autocast = torch.autocast
def ipex_autocast(*args, **kwargs):
    if len(args) > 0 and args[0] == "cuda":
        return original_autocast("xpu", *args[1:], **kwargs)
    else:
        return original_autocast(*args, **kwargs)

# Embedding BF16
original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs):
    if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
        return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
    else:
        return original_torch_cat(tensor, *args, **kwargs)

# Latent antialias:
original_interpolate = torch.nn.functional.interpolate
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
    if antialias or align_corners is not None:
        return_device = tensor.device
        return_dtype = tensor.dtype
        return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
            align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
    else:
        return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
            align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)

original_linalg_solve = torch.linalg.solve
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
    if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
        return_device = A.device
        return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
    else:
        return original_linalg_solve(A, B, *args, **kwargs)

if torch.xpu.has_fp64_dtype():
    original_torch_bmm = torch.bmm
    original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else:
    # 64 bit attention workarounds for Alchemist:
    try:
        from .attention import torch_bmm_32_bit as original_torch_bmm
        from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
    except Exception: # pylint: disable=broad-exception-caught
        original_torch_bmm = torch.bmm
        original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention

# dtype errors:
def torch_bmm(input, mat2, *, out=None):
    if input.dtype != mat2.dtype:
        mat2 = mat2.to(input.dtype)
    return original_torch_bmm(input, mat2, out=out)

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    if query.dtype != key.dtype:
        key = key.to(dtype=query.dtype)
    if query.dtype != value.dtype:
        value = value.to(dtype=query.dtype)
    return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)

@property
def is_cuda(self):
    return self.device.type == 'xpu'

def ipex_hijacks():
    CondFunc('torch.tensor',
        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
        lambda orig_func, *args, device=None, **kwargs: check_device(device))
    CondFunc('torch.Tensor.to',
        lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
        lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
    CondFunc('torch.Tensor.cuda',
        lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
        lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
    CondFunc('torch.UntypedStorage.__init__',
        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
        lambda orig_func, *args, device=None, **kwargs: check_device(device))
    CondFunc('torch.UntypedStorage.cuda',
        lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
        lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
    CondFunc('torch.empty',
        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
        lambda orig_func, *args, device=None, **kwargs: check_device(device))
    CondFunc('torch.randn',
        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
        lambda orig_func, *args, device=None, **kwargs: check_device(device))
    CondFunc('torch.ones',
        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
        lambda orig_func, *args, device=None, **kwargs: check_device(device))
    CondFunc('torch.zeros',
        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
        lambda orig_func, *args, device=None, **kwargs: check_device(device))
    CondFunc('torch.linspace',
        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
        lambda orig_func, *args, device=None, **kwargs: check_device(device))
    CondFunc('torch.load',
        lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
        orig_func(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
        lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
    if hasattr(torch.xpu, "Generator"):
        CondFunc('torch.Generator',
            lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
            lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
    else:
        CondFunc('torch.Generator',
            lambda orig_func, device=None: orig_func(return_xpu(device)),
            lambda orig_func, device=None: check_device(device))

    # TiledVAE and ControlNet:
    CondFunc('torch.batch_norm',
        lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
            weight if weight is not None else torch.ones(input.size()[1], device=input.device),
            bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
        lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
    CondFunc('torch.instance_norm',
        lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
            weight if weight is not None else torch.ones(input.size()[1], device=input.device),
            bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
        lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))

    # Functions with dtype errors:
    CondFunc('torch.nn.modules.GroupNorm.forward',
        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
    # Training:
    CondFunc('torch.nn.modules.linear.Linear.forward',
        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
    CondFunc('torch.nn.modules.conv.Conv2d.forward',
        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
    # BF16:
    CondFunc('torch.nn.functional.layer_norm',
        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
        orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
        weight is not None and input.dtype != weight.data.dtype)
    # SwinIR BF16:
    CondFunc('torch.nn.functional.pad',
        lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
        lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)

    # Diffusers Float64 (Alchemist GPUs don't support 64 bit):
    if not torch.xpu.has_fp64_dtype():
        CondFunc('torch.from_numpy',
            lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
            lambda orig_func, ndarray: ndarray.dtype == float)

    # Broken functions when torch.cuda.is_available is True:
    # Pin Memory:
    CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
        lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
        lambda orig_func, *args, **kwargs: True)

    # Functions that make compile mad with CondFunc:
    torch.nn.DataParallel = DummyDataParallel
    torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers

    torch.autocast = ipex_autocast
    torch.backends.cuda.sdp_kernel = return_null_context
    torch.UntypedStorage.is_cuda = is_cuda

    torch.nn.functional.interpolate = interpolate
    torch.linalg.solve = linalg_solve

    torch.bmm = torch_bmm
    torch.cat = torch_cat
    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
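A minimal usage sketch of the CondFunc patching pattern above, not part of the committed file: the dotted path is resolved with importlib, the attribute is replaced with a wrapper, and the substitute only runs when the condition callable returns True. The target here (json.dumps) is an arbitrary illustration; importing the module still requires an IPEX-enabled PyTorch build because of the intel_extension_for_pytorch import at the top.

    from library.ipex.hijacks import CondFunc  # assumes external/llite is on sys.path
    import json

    # Substitute runs only when the condition holds; otherwise the original is called.
    CondFunc(
        'json.dumps',                                        # dotted path of the callable to patch
        lambda orig, obj, **kw: orig(obj, **kw) + "\n",      # sub_func: append a trailing newline
        lambda orig, obj, **kw: kw.get("sort_keys", False),  # cond_func: only when sort_keys=True
    )

    print(repr(json.dumps({"b": 1, "a": 2}, sort_keys=True)))  # patched: '{"a": 2, "b": 1}\n'
    print(repr(json.dumps({"b": 1, "a": 2})))                  # falls through to the original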
external/llite/library/lpw_stable_diffusion.py
ADDED
@@ -0,0 +1,1254 @@
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
# and modify to support SD2.x

import inspect
import re
from typing import Callable, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import logging


try:
    from diffusers.utils import PIL_INTERPOLATION
except ImportError:
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        PIL_INTERPOLATION = {
            "linear": PIL.Image.Resampling.BILINEAR,
            "bilinear": PIL.Image.Resampling.BILINEAR,
            "bicubic": PIL.Image.Resampling.BICUBIC,
            "lanczos": PIL.Image.Resampling.LANCZOS,
            "nearest": PIL.Image.Resampling.NEAREST,
        }
    else:
        PIL_INTERPOLATION = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
            "nearest": PIL.Image.NEAREST,
        }
# ------------------------------------------------------------------------------

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return its tokens with weights of each token.

    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = pipe.tokenizer(word).input_ids[1:-1]
            text_token += token
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


def get_unweighted_text_embeddings(
    pipe: StableDiffusionPipeline,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            if pad == eos:  # v1
                text_input_chunk[:, -1] = text_input[0, -1]
            else:  # v2
                for j in range(len(text_input_chunk)):
                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the chunk ends with an ordinary token
                        text_input_chunk[j, -1] = eos
                    if text_input_chunk[j, 1] == pad:  # the chunk holds only bos, the rest is pad
                        text_input_chunk[j, 1] = eos

            if clip_skip is None or clip_skip == 1:
                text_embedding = pipe.text_encoder(text_input_chunk)[0]
            else:
                enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
                text_embedding = enc_out["hidden_states"][-clip_skip]
                text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        if clip_skip is None or clip_skip == 1:
            text_embeddings = pipe.text_encoder(text_input)[0]
        else:
            enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
            text_embeddings = enc_out["hidden_states"][-clip_skip]
            text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
    return text_embeddings
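For orientation, a small worked example of the chunking arithmetic used above, assuming the SD1.x CLIP tokenizer (model_max_length = 77); the token count is illustrative only:

    # 150 prompt tokens (hypothetical) exceed one 75-token chunk, so two chunks are used.
    model_max_length = 77                                      # 75 content tokens + bos + eos
    n_tokens = 150
    multiples = (n_tokens - 1) // (model_max_length - 2) + 1   # 2
    padded = (model_max_length - 2) * multiples + 2            # 152 ids including bos/eos
    chunks = (padded - 2) // (model_max_length - 2)            # 2 chunks of 77 ids sent to the text encoder
    print(multiples, padded, chunks)                           # 2 152 2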


def get_weighted_text_embeddings(
    pipe: StableDiffusionPipeline,
    prompt: Union[str, List[str]],
    uncond_prompt: Optional[Union[str, List[str]]] = None,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
    clip_skip=None,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.

    Args:
        pipe (`StableDiffusionPipeline`):
            Pipe to provide access to the tokenizer and the text encoder.
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        uncond_prompt (`str` or `List[str]`):
            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
            is provided, the embeddings of prompt and uncond_prompt are concatenated.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            When the tokenized text is longer than the capacity of the text encoder and is split into chunks,
            whether to keep the starting and ending tokens in each of the middle chunks.
        skip_parsing (`bool`, *optional*, defaults to `False`):
            Skip the parsing of brackets.
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When the parsing is skipped, it is forced True.
    """
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    if not skip_parsing:
        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
    else:
        prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens = [
                token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
            ]
            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])
    if uncond_prompt is not None:
        max_length = max(max_length, max([len(token) for token in uncond_tokens]))

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    pad = pipe.tokenizer.pad_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
    if uncond_prompt is not None:
        uncond_tokens, uncond_weights = pad_tokens_and_weights(
            uncond_tokens,
            uncond_weights,
            max_length,
            bos,
            eos,
            no_boseos_middle=no_boseos_middle,
            chunk_length=pipe.tokenizer.model_max_length,
        )
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        pipe.tokenizer.model_max_length,
        clip_skip,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
    if uncond_prompt is not None:
        uncond_embeddings = get_unweighted_text_embeddings(
            pipe,
            uncond_tokens,
            pipe.tokenizer.model_max_length,
            clip_skip,
            eos,
            pad,
            no_boseos_middle=no_boseos_middle,
        )
        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)

    # assign weights to the prompts and normalize in the sense of mean
    # TODO: should we normalize by chunk or in a whole (current implementation)?
    if (not skip_parsing) and (not skip_weighting):
        previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        text_embeddings *= prompt_weights.unsqueeze(-1)
        current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
        if uncond_prompt is not None:
            previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= uncond_weights.unsqueeze(-1)
            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    if uncond_prompt is not None:
        return text_embeddings, uncond_embeddings
    return text_embeddings, None
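A hypothetical call of get_weighted_text_embeddings outside the pipeline; the model id, device, and import path are assumptions and not part of this file:

    import torch
    from diffusers import StableDiffusionPipeline
    from library.lpw_stable_diffusion import get_weighted_text_embeddings  # assumed import path

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # "(vivid:1.3)" boosts those tokens; "[blurry]" downweights by 1/1.1.
    cond, uncond = get_weighted_text_embeddings(
        pipe,
        prompt="a (vivid:1.3) watercolor of a lighthouse at dusk",
        uncond_prompt="[blurry], lowres",
        max_embeddings_multiples=3,
    )
    print(cond.shape, uncond.shape)  # e.g. torch.Size([1, 77, 768]) each for a short SD1.x prompt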


def preprocess_image(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


def preprocess_mask(mask, scale_factor=8):
    mask = mask.convert("L")
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
    mask = 1 - mask  # repaint white, keep black
    mask = torch.from_numpy(mask)
    return mask


def prepare_controlnet_image(
    image: PIL.Image.Image,
    width: int,
    height: int,
    batch_size: int,
    num_images_per_prompt: int,
    device: torch.device,
    dtype: torch.dtype,
    do_classifier_free_guidance: bool = False,
    guess_mode: bool = False,
):
    if not isinstance(image, torch.Tensor):
        if isinstance(image, PIL.Image.Image):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            images = []

            for image_ in image:
                image_ = image_.convert("RGB")
                image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                image_ = np.array(image_)
                image_ = image_[None, :]
                images.append(image_)

            image = images

            image = np.concatenate(image, axis=0)
            image = np.array(image).astype(np.float32) / 255.0
            image = image.transpose(0, 3, 1, 2)
            image = torch.from_numpy(image)
        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0)

    image_batch_size = image.shape[0]

    if image_batch_size == 1:
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_images_per_prompt

    image = image.repeat_interleave(repeat_by, dim=0)

    image = image.to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        image = torch.cat([image] * 2)

    return image
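A quick shape check for the preprocessing helpers above, a sketch only; it assumes the two functions are in scope and Pillow is installed:

    import PIL.Image

    img = PIL.Image.new("RGB", (515, 389))               # arbitrary size, not a multiple of 32
    x = preprocess_image(img)                            # rounded down to 512x384, scaled to [-1, 1]
    m = preprocess_mask(PIL.Image.new("L", (515, 389)))  # latent-resolution, 4-channel mask
    print(x.shape, m.shape)                              # torch.Size([1, 3, 384, 512]) torch.Size([1, 4, 48, 64])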


class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
    weighting in prompt.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
        # clip_skip: int,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
        clip_skip: int = 1,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )
        self.clip_skip = clip_skip
        self.__init__additional__()

    # else:
    #     def __init__(
    #         self,
    #         vae: AutoencoderKL,
    #         text_encoder: CLIPTextModel,
    #         tokenizer: CLIPTokenizer,
    #         unet: UNet2DConditionModel,
    #         scheduler: SchedulerMixin,
    #         safety_checker: StableDiffusionSafetyChecker,
    #         feature_extractor: CLIPFeatureExtractor,
    #     ):
    #         super().__init__(
    #             vae=vae,
    #             text_encoder=text_encoder,
    #             tokenizer=tokenizer,
    #             unet=unet,
    #             scheduler=scheduler,
    #             safety_checker=safety_checker,
    #             feature_extractor=feature_extractor,
    #         )
    #         self.__init__additional__()

    def __init__additional__(self):
        if not hasattr(self, "vae_scale_factor"):
            setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        max_embeddings_multiples,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list(int)`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
            pipe=self,
            prompt=prompt,
            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
            max_embeddings_multiples=max_embeddings_multiples,
            clip_skip=self.clip_skip,
        )
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            bs_embed, seq_len, _ = uncond_embeddings.shape
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def check_inputs(self, prompt, height, width, strength, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if height % 8 != 0 or width % 8 != 0:
            print(height, width)
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
            )

    def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
        if is_text2img:
            return self.scheduler.timesteps.to(device), num_inference_steps
        else:
            # get the original timestep using init_timestep
            offset = self.scheduler.config.get("steps_offset", 0)
            init_timestep = int(num_inference_steps * strength) + offset
            init_timestep = min(init_timestep, num_inference_steps)

            t_start = max(num_inference_steps - init_timestep + offset, 0)
            timesteps = self.scheduler.timesteps[t_start:].to(device)
            return timesteps, num_inference_steps - t_start

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
        if image is None:
            shape = (
                batch_size,
                self.unet.in_channels,
                height // self.vae_scale_factor,
                width // self.vae_scale_factor,
            )

            if latents is None:
                if device.type == "mps":
                    # randn does not work reproducibly on mps
                    latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
                else:
                    latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
            else:
                if latents.shape != shape:
                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
                latents = latents.to(device)

            # scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma
            return latents, None, None
        else:
            init_latent_dist = self.vae.encode(image).latent_dist
            init_latents = init_latent_dist.sample(generator=generator)
            init_latents = 0.18215 * init_latents
            init_latents = torch.cat([init_latents] * batch_size, dim=0)
            init_latents_orig = init_latents
            shape = init_latents.shape

            # add noise to latents using the timesteps
            if device.type == "mps":
                noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
            else:
                noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
            latents = self.scheduler.add_noise(init_latents, noise, timestep)
            return latents, init_latents_orig, noise
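The img2img timestep window computed by get_timesteps above, written out as plain arithmetic (scheduler offset assumed to be 0):

    num_inference_steps, strength, offset = 50, 0.8, 0
    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)  # 40
    t_start = max(num_inference_steps - init_timestep + offset, 0)                          # 10
    print(t_start, num_inference_steps - t_start)  # 10 40 -> the last 40 of 50 scheduler steps are run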

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        strength: float = 0.8,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        controlnet=None,
        controlnet_image=None,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        is_cancelled_callback: Optional[Callable[[], bool]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
                noise will be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            controlnet (`diffusers.ControlNetModel`, *optional*):
                A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
            controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
                `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
                inference.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            is_cancelled_callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. If the function returns
                `True`, the inference will be cancelled.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            `None` if cancelled by `is_cancelled_callback`,
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        if controlnet is not None and controlnet_image is None:
            raise ValueError("controlnet_image must be provided if controlnet is not None.")

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, strength, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            max_embeddings_multiples,
        )
        dtype = text_embeddings.dtype

        # 4. Preprocess image and mask
        if isinstance(image, PIL.Image.Image):
            image = preprocess_image(image)
        if image is not None:
            image = image.to(device=self.device, dtype=dtype)
        if isinstance(mask_image, PIL.Image.Image):
            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
        if mask_image is not None:
            mask = mask_image.to(device=self.device, dtype=dtype)
            mask = torch.cat([mask] * batch_size * num_images_per_prompt)
        else:
            mask = None

        if controlnet_image is not None:
            controlnet_image = prepare_controlnet_image(
                controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
            )

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents, init_latents_orig, noise = self.prepare_latents(
            image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            height,
            width,
            dtype,
            device,
            generator,
            latents,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            unet_additional_args = {}
            if controlnet is not None:
                down_block_res_samples, mid_block_res_sample = controlnet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings,
                    controlnet_cond=controlnet_image,
                    conditioning_scale=1.0,
                    guess_mode=False,
                    return_dict=False,
                )
                unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
                unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
942 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
943 |
+
|
944 |
+
# compute the previous noisy sample x_t -> x_t-1
|
945 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
946 |
+
|
947 |
+
if mask is not None:
|
948 |
+
# masking
|
949 |
+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
950 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
951 |
+
|
952 |
+
# call the callback, if provided
|
953 |
+
if i % callback_steps == 0:
|
954 |
+
if callback is not None:
|
955 |
+
callback(i, t, latents)
|
956 |
+
if is_cancelled_callback is not None and is_cancelled_callback():
|
957 |
+
return None
|
958 |
+
|
959 |
+
return latents
|
960 |
+
|
961 |
+
def latents_to_image(self, latents):
|
962 |
+
# 9. Post-processing
|
963 |
+
image = self.decode_latents(latents.to(self.vae.dtype))
|
964 |
+
image = self.numpy_to_pil(image)
|
965 |
+
return image
|
966 |
+
|
967 |
+
def text2img(
|
968 |
+
self,
|
969 |
+
prompt: Union[str, List[str]],
|
970 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
971 |
+
height: int = 512,
|
972 |
+
width: int = 512,
|
973 |
+
num_inference_steps: int = 50,
|
974 |
+
guidance_scale: float = 7.5,
|
975 |
+
num_images_per_prompt: Optional[int] = 1,
|
976 |
+
eta: float = 0.0,
|
977 |
+
generator: Optional[torch.Generator] = None,
|
978 |
+
latents: Optional[torch.FloatTensor] = None,
|
979 |
+
max_embeddings_multiples: Optional[int] = 3,
|
980 |
+
output_type: Optional[str] = "pil",
|
981 |
+
return_dict: bool = True,
|
982 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
983 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
984 |
+
callback_steps: int = 1,
|
985 |
+
):
|
986 |
+
r"""
|
987 |
+
Function for text-to-image generation.
|
988 |
+
Args:
|
989 |
+
prompt (`str` or `List[str]`):
|
990 |
+
The prompt or prompts to guide the image generation.
|
991 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
992 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
993 |
+
if `guidance_scale` is less than `1`).
|
994 |
+
height (`int`, *optional*, defaults to 512):
|
995 |
+
The height in pixels of the generated image.
|
996 |
+
width (`int`, *optional*, defaults to 512):
|
997 |
+
The width in pixels of the generated image.
|
998 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
999 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1000 |
+
expense of slower inference.
|
1001 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1002 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1003 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1004 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1005 |
+
1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
|
1006 |
+
usually at the expense of lower image quality.
|
1007 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1008 |
+
The number of images to generate per prompt.
|
1009 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1010 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1011 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1012 |
+
generator (`torch.Generator`, *optional*):
|
1013 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1014 |
+
deterministic.
|
1015 |
+
latents (`torch.FloatTensor`, *optional*):
|
1016 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
1017 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
1018 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
1019 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1020 |
+
The maximum allowed length of the prompt embeddings, as a multiple of the text encoder's maximum output length.
|
1021 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1022 |
+
The output format of the generated image. Choose between
|
1023 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1024 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1025 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1026 |
+
plain tuple.
|
1027 |
+
callback (`Callable`, *optional*):
|
1028 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1029 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1030 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1031 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1032 |
+
`True`, the inference will be cancelled.
|
1033 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1034 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1035 |
+
called at every step.
|
1036 |
+
Returns:
|
1037 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1038 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
1039 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1040 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1041 |
+
(nsfw) content, according to the `safety_checker`.
|
1042 |
+
"""
|
1043 |
+
return self.__call__(
|
1044 |
+
prompt=prompt,
|
1045 |
+
negative_prompt=negative_prompt,
|
1046 |
+
height=height,
|
1047 |
+
width=width,
|
1048 |
+
num_inference_steps=num_inference_steps,
|
1049 |
+
guidance_scale=guidance_scale,
|
1050 |
+
num_images_per_prompt=num_images_per_prompt,
|
1051 |
+
eta=eta,
|
1052 |
+
generator=generator,
|
1053 |
+
latents=latents,
|
1054 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1055 |
+
output_type=output_type,
|
1056 |
+
return_dict=return_dict,
|
1057 |
+
callback=callback,
|
1058 |
+
is_cancelled_callback=is_cancelled_callback,
|
1059 |
+
callback_steps=callback_steps,
|
1060 |
+
)
|
1061 |
+
|
1062 |
+
def img2img(
|
1063 |
+
self,
|
1064 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
1065 |
+
prompt: Union[str, List[str]],
|
1066 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1067 |
+
strength: float = 0.8,
|
1068 |
+
num_inference_steps: Optional[int] = 50,
|
1069 |
+
guidance_scale: Optional[float] = 7.5,
|
1070 |
+
num_images_per_prompt: Optional[int] = 1,
|
1071 |
+
eta: Optional[float] = 0.0,
|
1072 |
+
generator: Optional[torch.Generator] = None,
|
1073 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1074 |
+
output_type: Optional[str] = "pil",
|
1075 |
+
return_dict: bool = True,
|
1076 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1077 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1078 |
+
callback_steps: int = 1,
|
1079 |
+
):
|
1080 |
+
r"""
|
1081 |
+
Function for image-to-image generation.
|
1082 |
+
Args:
|
1083 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1084 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1085 |
+
process.
|
1086 |
+
prompt (`str` or `List[str]`):
|
1087 |
+
The prompt or prompts to guide the image generation.
|
1088 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1089 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1090 |
+
if `guidance_scale` is less than `1`).
|
1091 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1092 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
1093 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
1094 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
1095 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
1096 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
1097 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1098 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1099 |
+
expense of slower inference. This parameter will be modulated by `strength`.
|
1100 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1101 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1102 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1103 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1104 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1105 |
+
usually at the expense of lower image quality.
|
1106 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1107 |
+
The number of images to generate per prompt.
|
1108 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1109 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1110 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1111 |
+
generator (`torch.Generator`, *optional*):
|
1112 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1113 |
+
deterministic.
|
1114 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1115 |
+
The maximum allowed length of the prompt embeddings, as a multiple of the text encoder's maximum output length.
|
1116 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1117 |
+
The output format of the generate image. Choose between
|
1118 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1119 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1120 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1121 |
+
plain tuple.
|
1122 |
+
callback (`Callable`, *optional*):
|
1123 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1124 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1125 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1126 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1127 |
+
`True`, the inference will be cancelled.
|
1128 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1129 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1130 |
+
called at every step.
|
1131 |
+
Returns:
|
1132 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1133 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
1134 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1135 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1136 |
+
(nsfw) content, according to the `safety_checker`.
|
1137 |
+
"""
|
1138 |
+
return self.__call__(
|
1139 |
+
prompt=prompt,
|
1140 |
+
negative_prompt=negative_prompt,
|
1141 |
+
image=image,
|
1142 |
+
num_inference_steps=num_inference_steps,
|
1143 |
+
guidance_scale=guidance_scale,
|
1144 |
+
strength=strength,
|
1145 |
+
num_images_per_prompt=num_images_per_prompt,
|
1146 |
+
eta=eta,
|
1147 |
+
generator=generator,
|
1148 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1149 |
+
output_type=output_type,
|
1150 |
+
return_dict=return_dict,
|
1151 |
+
callback=callback,
|
1152 |
+
is_cancelled_callback=is_cancelled_callback,
|
1153 |
+
callback_steps=callback_steps,
|
1154 |
+
)
|
1155 |
+
|
1156 |
+
def inpaint(
|
1157 |
+
self,
|
1158 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
1159 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
1160 |
+
prompt: Union[str, List[str]],
|
1161 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1162 |
+
strength: float = 0.8,
|
1163 |
+
num_inference_steps: Optional[int] = 50,
|
1164 |
+
guidance_scale: Optional[float] = 7.5,
|
1165 |
+
num_images_per_prompt: Optional[int] = 1,
|
1166 |
+
eta: Optional[float] = 0.0,
|
1167 |
+
generator: Optional[torch.Generator] = None,
|
1168 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1169 |
+
output_type: Optional[str] = "pil",
|
1170 |
+
return_dict: bool = True,
|
1171 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1172 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1173 |
+
callback_steps: int = 1,
|
1174 |
+
):
|
1175 |
+
r"""
|
1176 |
+
Function for inpainting.
|
1177 |
+
Args:
|
1178 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1179 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1180 |
+
process. This is the image whose masked region will be inpainted.
|
1181 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1182 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
1183 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
1184 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
1185 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
1186 |
+
prompt (`str` or `List[str]`):
|
1187 |
+
The prompt or prompts to guide the image generation.
|
1188 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1189 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1190 |
+
if `guidance_scale` is less than `1`).
|
1191 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1192 |
+
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
|
1193 |
+
is 1, the denoising process will be run on the masked area for the full number of iterations specified
|
1194 |
+
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
|
1195 |
+
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
|
1196 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1197 |
+
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
|
1198 |
+
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
|
1199 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1200 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1201 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1202 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1203 |
+
1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
|
1204 |
+
usually at the expense of lower image quality.
|
1205 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1206 |
+
The number of images to generate per prompt.
|
1207 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1208 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1209 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1210 |
+
generator (`torch.Generator`, *optional*):
|
1211 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1212 |
+
deterministic.
|
1213 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1214 |
+
The maximum allowed length of the prompt embeddings, as a multiple of the text encoder's maximum output length.
|
1215 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1216 |
+
The output format of the generated image. Choose between
|
1217 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1218 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1219 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1220 |
+
plain tuple.
|
1221 |
+
callback (`Callable`, *optional*):
|
1222 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1223 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1224 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1225 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1226 |
+
`True`, the inference will be cancelled.
|
1227 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1228 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1229 |
+
called at every step.
|
1230 |
+
Returns:
|
1231 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1232 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
1233 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1234 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1235 |
+
(nsfw) content, according to the `safety_checker`.
|
1236 |
+
"""
|
1237 |
+
return self.__call__(
|
1238 |
+
prompt=prompt,
|
1239 |
+
negative_prompt=negative_prompt,
|
1240 |
+
image=image,
|
1241 |
+
mask_image=mask_image,
|
1242 |
+
num_inference_steps=num_inference_steps,
|
1243 |
+
guidance_scale=guidance_scale,
|
1244 |
+
strength=strength,
|
1245 |
+
num_images_per_prompt=num_images_per_prompt,
|
1246 |
+
eta=eta,
|
1247 |
+
generator=generator,
|
1248 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1249 |
+
output_type=output_type,
|
1250 |
+
return_dict=return_dict,
|
1251 |
+
callback=callback,
|
1252 |
+
is_cancelled_callback=is_cancelled_callback,
|
1253 |
+
callback_steps=callback_steps,
|
1254 |
+
)
|
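For orientation, a minimal usage sketch of the convenience wrappers defined above. The pipeline class name, checkpoint id, and file paths are placeholders (this diff only shows the methods); note that `__call__` as written returns latents rather than decoded images, so `latents_to_image` handles decoding.

import torch
from PIL import Image

# Hypothetical: the concrete pipeline class and loader are assumptions, not shown in this diff.
pipe = LongPromptWeightingPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# text-to-image: returns latents, decoded separately
latents = pipe.text2img("a photo of a cat", num_inference_steps=30)
images = pipe.latents_to_image(latents)

# image-to-image: start from an existing picture
init = Image.open("input.png").convert("RGB")
latents = pipe.img2img(image=init, prompt="a watercolor cat", strength=0.6)

# inpainting: white mask pixels are repainted, black pixels are preserved
mask = Image.open("mask.png").convert("L")
latents = pipe.inpaint(image=init, mask_image=mask, prompt="a cat wearing a hat")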
external/llite/library/model_util.py
ADDED
@@ -0,0 +1,1350 @@
|
1 |
+
# v1: split from train_db_fixed.py.
|
2 |
+
# v2: support safetensors
|
3 |
+
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
try:
|
8 |
+
import intel_extension_for_pytorch as ipex
|
9 |
+
if torch.xpu.is_available():
|
10 |
+
from library.ipex import ipex_init
|
11 |
+
ipex_init()
|
12 |
+
except Exception:
|
13 |
+
pass
|
14 |
+
import diffusers
|
15 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
16 |
+
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
17 |
+
from safetensors.torch import load_file, save_file
|
18 |
+
from external.llite.library.original_unet import UNet2DConditionModel
|
19 |
+
|
20 |
+
# Model parameters for the Diffusers version of Stable Diffusion
|
21 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
22 |
+
BETA_START = 0.00085
|
23 |
+
BETA_END = 0.0120
|
24 |
+
|
25 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
26 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
27 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
28 |
+
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
29 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
30 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
31 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
32 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
33 |
+
UNET_PARAMS_NUM_HEADS = 8
|
34 |
+
# UNET_PARAMS_USE_LINEAR_PROJECTION = False
|
35 |
+
|
36 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
37 |
+
VAE_PARAMS_RESOLUTION = 256
|
38 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
39 |
+
VAE_PARAMS_OUT_CH = 3
|
40 |
+
VAE_PARAMS_CH = 128
|
41 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
42 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
43 |
+
|
44 |
+
# V2
|
45 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
46 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
47 |
+
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
|
48 |
+
|
49 |
+
# Reference models used to load the Diffusers configuration
|
50 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
51 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
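As a point of reference, these noise-schedule constants match the usual Stable Diffusion scheduler settings; a minimal, illustrative sketch (the `scaled_linear` beta schedule is an assumption, not stated in this file):

from diffusers import DDIMScheduler

# Illustrative only: build a DDIM scheduler from the constants above.
scheduler = DDIMScheduler(
    num_train_timesteps=NUM_TRAIN_TIMESTEPS,
    beta_start=BETA_START,
    beta_end=BETA_END,
    beta_schedule="scaled_linear",  # assumed; the standard SD setting
)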
52 |
+
|
53 |
+
|
54 |
+
# region StableDiffusion -> Diffusers conversion code
|
55 |
+
# Copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
|
56 |
+
|
57 |
+
|
58 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
59 |
+
"""
|
60 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
61 |
+
"""
|
62 |
+
if n_shave_prefix_segments >= 0:
|
63 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
64 |
+
else:
|
65 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
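A quick illustrative check of `shave_segments` with made-up key paths:

# Positive n drops that many leading dot-separated segments; negative n drops trailing ones.
assert shave_segments("input_blocks.1.0.in_layers.0.weight", 2) == "0.in_layers.0.weight"
assert shave_segments("input_blocks.1.0.in_layers.0.weight", -2) == "input_blocks.1.0.in_layers"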
66 |
+
|
67 |
+
|
68 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
69 |
+
"""
|
70 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
71 |
+
"""
|
72 |
+
mapping = []
|
73 |
+
for old_item in old_list:
|
74 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
75 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
76 |
+
|
77 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
78 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
79 |
+
|
80 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
81 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
82 |
+
|
83 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
84 |
+
|
85 |
+
mapping.append({"old": old_item, "new": new_item})
|
86 |
+
|
87 |
+
return mapping
|
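For illustration, the kind of mapping `renew_resnet_paths` produces for a couple of hypothetical keys:

mapping = renew_resnet_paths([
    "input_blocks.1.0.in_layers.2.weight",  # in_layers.2 -> conv1
    "input_blocks.1.0.emb_layers.1.bias",   # emb_layers.1 -> time_emb_proj
])
# [{'old': 'input_blocks.1.0.in_layers.2.weight', 'new': 'input_blocks.1.0.conv1.weight'},
#  {'old': 'input_blocks.1.0.emb_layers.1.bias', 'new': 'input_blocks.1.0.time_emb_proj.bias'}]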
88 |
+
|
89 |
+
|
90 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
91 |
+
"""
|
92 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
93 |
+
"""
|
94 |
+
mapping = []
|
95 |
+
for old_item in old_list:
|
96 |
+
new_item = old_item
|
97 |
+
|
98 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
99 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
100 |
+
|
101 |
+
mapping.append({"old": old_item, "new": new_item})
|
102 |
+
|
103 |
+
return mapping
|
104 |
+
|
105 |
+
|
106 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
107 |
+
"""
|
108 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
109 |
+
"""
|
110 |
+
mapping = []
|
111 |
+
for old_item in old_list:
|
112 |
+
new_item = old_item
|
113 |
+
|
114 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
115 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
116 |
+
|
117 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
118 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
119 |
+
|
120 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
121 |
+
|
122 |
+
mapping.append({"old": old_item, "new": new_item})
|
123 |
+
|
124 |
+
return mapping
|
125 |
+
|
126 |
+
|
127 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
128 |
+
"""
|
129 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
130 |
+
"""
|
131 |
+
mapping = []
|
132 |
+
for old_item in old_list:
|
133 |
+
new_item = old_item
|
134 |
+
|
135 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
136 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
137 |
+
|
138 |
+
if diffusers.__version__ < "0.17.0":
|
139 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
140 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
141 |
+
|
142 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
143 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
144 |
+
|
145 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
146 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
147 |
+
|
148 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
149 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
150 |
+
else:
|
151 |
+
new_item = new_item.replace("q.weight", "to_q.weight")
|
152 |
+
new_item = new_item.replace("q.bias", "to_q.bias")
|
153 |
+
|
154 |
+
new_item = new_item.replace("k.weight", "to_k.weight")
|
155 |
+
new_item = new_item.replace("k.bias", "to_k.bias")
|
156 |
+
|
157 |
+
new_item = new_item.replace("v.weight", "to_v.weight")
|
158 |
+
new_item = new_item.replace("v.bias", "to_v.bias")
|
159 |
+
|
160 |
+
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
161 |
+
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
162 |
+
|
163 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
164 |
+
|
165 |
+
mapping.append({"old": old_item, "new": new_item})
|
166 |
+
|
167 |
+
return mapping
|
168 |
+
|
169 |
+
|
170 |
+
def assign_to_checkpoint(
|
171 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
172 |
+
):
|
173 |
+
"""
|
174 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
175 |
+
to them. It splits attention layers, and takes into account additional replacements
|
176 |
+
that may arise.
|
177 |
+
|
178 |
+
Assigns the weights to the new checkpoint.
|
179 |
+
"""
|
180 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
181 |
+
|
182 |
+
# Splits the attention layers into three variables.
|
183 |
+
if attention_paths_to_split is not None:
|
184 |
+
for path, path_map in attention_paths_to_split.items():
|
185 |
+
old_tensor = old_checkpoint[path]
|
186 |
+
channels = old_tensor.shape[0] // 3
|
187 |
+
|
188 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
189 |
+
|
190 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
191 |
+
|
192 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
193 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
194 |
+
|
195 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
196 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
197 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
198 |
+
|
199 |
+
for path in paths:
|
200 |
+
new_path = path["new"]
|
201 |
+
|
202 |
+
# These have already been assigned
|
203 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
204 |
+
continue
|
205 |
+
|
206 |
+
# Global renaming happens here
|
207 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
208 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
209 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
210 |
+
|
211 |
+
if additional_replacements is not None:
|
212 |
+
for replacement in additional_replacements:
|
213 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
214 |
+
|
215 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
216 |
+
reshaping = False
|
217 |
+
if diffusers.__version__ < "0.17.0":
|
218 |
+
if "proj_attn.weight" in new_path:
|
219 |
+
reshaping = True
|
220 |
+
else:
|
221 |
+
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
|
222 |
+
reshaping = True
|
223 |
+
|
224 |
+
if reshaping:
|
225 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
|
226 |
+
else:
|
227 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
228 |
+
|
229 |
+
|
230 |
+
def conv_attn_to_linear(checkpoint):
|
231 |
+
keys = list(checkpoint.keys())
|
232 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
233 |
+
for key in keys:
|
234 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
235 |
+
if checkpoint[key].ndim > 2:
|
236 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
237 |
+
elif "proj_attn.weight" in key:
|
238 |
+
if checkpoint[key].ndim > 2:
|
239 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
240 |
+
|
241 |
+
|
242 |
+
def linear_transformer_to_conv(checkpoint):
|
243 |
+
keys = list(checkpoint.keys())
|
244 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
245 |
+
for key in keys:
|
246 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
247 |
+
if checkpoint[key].ndim == 2:
|
248 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
249 |
+
|
250 |
+
|
251 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
252 |
+
"""
|
253 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
254 |
+
"""
|
255 |
+
|
256 |
+
# extract state_dict for UNet
|
257 |
+
unet_state_dict = {}
|
258 |
+
unet_key = "model.diffusion_model."
|
259 |
+
keys = list(checkpoint.keys())
|
260 |
+
for key in keys:
|
261 |
+
if key.startswith(unet_key):
|
262 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
263 |
+
|
264 |
+
new_checkpoint = {}
|
265 |
+
|
266 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
267 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
268 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
269 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
270 |
+
|
271 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
272 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
273 |
+
|
274 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
275 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
276 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
277 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
278 |
+
|
279 |
+
# Retrieves the keys for the input blocks only
|
280 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
281 |
+
input_blocks = {
|
282 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
|
283 |
+
}
|
284 |
+
|
285 |
+
# Retrieves the keys for the middle blocks only
|
286 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
287 |
+
middle_blocks = {
|
288 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
|
289 |
+
}
|
290 |
+
|
291 |
+
# Retrieves the keys for the output blocks only
|
292 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
293 |
+
output_blocks = {
|
294 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
|
295 |
+
}
|
296 |
+
|
297 |
+
for i in range(1, num_input_blocks):
|
298 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
299 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
300 |
+
|
301 |
+
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
|
302 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
303 |
+
|
304 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
305 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
306 |
+
f"input_blocks.{i}.0.op.weight"
|
307 |
+
)
|
308 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
|
309 |
+
|
310 |
+
paths = renew_resnet_paths(resnets)
|
311 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
312 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
313 |
+
|
314 |
+
if len(attentions):
|
315 |
+
paths = renew_attention_paths(attentions)
|
316 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
317 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
318 |
+
|
319 |
+
resnet_0 = middle_blocks[0]
|
320 |
+
attentions = middle_blocks[1]
|
321 |
+
resnet_1 = middle_blocks[2]
|
322 |
+
|
323 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
324 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
325 |
+
|
326 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
327 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
328 |
+
|
329 |
+
attentions_paths = renew_attention_paths(attentions)
|
330 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
331 |
+
assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
332 |
+
|
333 |
+
for i in range(num_output_blocks):
|
334 |
+
block_id = i // (config["layers_per_block"] + 1)
|
335 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
336 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
337 |
+
output_block_list = {}
|
338 |
+
|
339 |
+
for layer in output_block_layers:
|
340 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
341 |
+
if layer_id in output_block_list:
|
342 |
+
output_block_list[layer_id].append(layer_name)
|
343 |
+
else:
|
344 |
+
output_block_list[layer_id] = [layer_name]
|
345 |
+
|
346 |
+
if len(output_block_list) > 1:
|
347 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
348 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
349 |
+
|
350 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
351 |
+
paths = renew_resnet_paths(resnets)
|
352 |
+
|
353 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
354 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
355 |
+
|
356 |
+
# Original:
|
357 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
358 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
359 |
+
|
360 |
+
# Avoid depending on the order of bias and weight; there is probably a better way to do this
|
361 |
+
for l in output_block_list.values():
|
362 |
+
l.sort()
|
363 |
+
|
364 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
365 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
366 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
367 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
368 |
+
]
|
369 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
370 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
371 |
+
]
|
372 |
+
|
373 |
+
# Clear attentions as they have been attributed above.
|
374 |
+
if len(attentions) == 2:
|
375 |
+
attentions = []
|
376 |
+
|
377 |
+
if len(attentions):
|
378 |
+
paths = renew_attention_paths(attentions)
|
379 |
+
meta_path = {
|
380 |
+
"old": f"output_blocks.{i}.1",
|
381 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
382 |
+
}
|
383 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
384 |
+
else:
|
385 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
386 |
+
for path in resnet_0_paths:
|
387 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
388 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
389 |
+
|
390 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
391 |
+
|
392 |
+
# In SD v2, the 1x1 conv2d layers were changed to linear layers
|
393 |
+
# The Diffusers side was mistakenly left as conv2d, so conversion is needed
|
394 |
+
if v2 and not config.get("use_linear_projection", False):
|
395 |
+
linear_transformer_to_conv(new_checkpoint)
|
396 |
+
|
397 |
+
return new_checkpoint
|
398 |
+
|
399 |
+
|
400 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
401 |
+
# extract state dict for VAE
|
402 |
+
vae_state_dict = {}
|
403 |
+
vae_key = "first_stage_model."
|
404 |
+
keys = list(checkpoint.keys())
|
405 |
+
for key in keys:
|
406 |
+
if key.startswith(vae_key):
|
407 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
408 |
+
# if len(vae_state_dict) == 0:
|
409 |
+
# # the passed checkpoint is a VAE state_dict, not a checkpoint loaded from a .ckpt file
|
410 |
+
# vae_state_dict = checkpoint
|
411 |
+
|
412 |
+
new_checkpoint = {}
|
413 |
+
|
414 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
415 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
416 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
417 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
418 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
419 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
420 |
+
|
421 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
422 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
423 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
424 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
425 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
426 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
427 |
+
|
428 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
429 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
430 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
431 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
432 |
+
|
433 |
+
# Retrieves the keys for the encoder down blocks only
|
434 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
435 |
+
down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
|
436 |
+
|
437 |
+
# Retrieves the keys for the decoder up blocks only
|
438 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
439 |
+
up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
440 |
+
|
441 |
+
for i in range(num_down_blocks):
|
442 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
443 |
+
|
444 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
445 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
446 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
447 |
+
)
|
448 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
449 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
450 |
+
)
|
451 |
+
|
452 |
+
paths = renew_vae_resnet_paths(resnets)
|
453 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
454 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
455 |
+
|
456 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
457 |
+
num_mid_res_blocks = 2
|
458 |
+
for i in range(1, num_mid_res_blocks + 1):
|
459 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
460 |
+
|
461 |
+
paths = renew_vae_resnet_paths(resnets)
|
462 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
463 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
464 |
+
|
465 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
466 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
467 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
468 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
469 |
+
conv_attn_to_linear(new_checkpoint)
|
470 |
+
|
471 |
+
for i in range(num_up_blocks):
|
472 |
+
block_id = num_up_blocks - 1 - i
|
473 |
+
resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
|
474 |
+
|
475 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
476 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
477 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
478 |
+
]
|
479 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
480 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
481 |
+
]
|
482 |
+
|
483 |
+
paths = renew_vae_resnet_paths(resnets)
|
484 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
485 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
486 |
+
|
487 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
488 |
+
num_mid_res_blocks = 2
|
489 |
+
for i in range(1, num_mid_res_blocks + 1):
|
490 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
491 |
+
|
492 |
+
paths = renew_vae_resnet_paths(resnets)
|
493 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
494 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
495 |
+
|
496 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
497 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
498 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
499 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
500 |
+
conv_attn_to_linear(new_checkpoint)
|
501 |
+
return new_checkpoint
|
502 |
+
|
503 |
+
|
504 |
+
def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
|
505 |
+
"""
|
506 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
507 |
+
"""
|
508 |
+
# unet_params = original_config.model.params.unet_config.params
|
509 |
+
|
510 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
511 |
+
|
512 |
+
down_block_types = []
|
513 |
+
resolution = 1
|
514 |
+
for i in range(len(block_out_channels)):
|
515 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
516 |
+
down_block_types.append(block_type)
|
517 |
+
if i != len(block_out_channels) - 1:
|
518 |
+
resolution *= 2
|
519 |
+
|
520 |
+
up_block_types = []
|
521 |
+
for i in range(len(block_out_channels)):
|
522 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
523 |
+
up_block_types.append(block_type)
|
524 |
+
resolution //= 2
|
525 |
+
|
526 |
+
config = dict(
|
527 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
528 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
529 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
530 |
+
down_block_types=tuple(down_block_types),
|
531 |
+
up_block_types=tuple(up_block_types),
|
532 |
+
block_out_channels=tuple(block_out_channels),
|
533 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
534 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
535 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
536 |
+
# use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
|
537 |
+
)
|
538 |
+
if v2 and use_linear_projection_in_v2:
|
539 |
+
config["use_linear_projection"] = True
|
540 |
+
|
541 |
+
return config
|
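For reference, the v1 configuration this function yields from the constants above (values follow directly from the code; shown as a sketch):

cfg = create_unet_diffusers_config(v2=False)
# cfg["block_out_channels"]  == (320, 640, 1280, 1280)
# cfg["down_block_types"]    == ("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",)
# cfg["up_block_types"]      == ("UpBlock2D",) + ("CrossAttnUpBlock2D",) * 3
# cfg["cross_attention_dim"] == 768, cfg["attention_head_dim"] == 8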
542 |
+
|
543 |
+
|
544 |
+
def create_vae_diffusers_config():
|
545 |
+
"""
|
546 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
547 |
+
"""
|
548 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
549 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
550 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
551 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
552 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
553 |
+
|
554 |
+
config = dict(
|
555 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
556 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
557 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
558 |
+
down_block_types=tuple(down_block_types),
|
559 |
+
up_block_types=tuple(up_block_types),
|
560 |
+
block_out_channels=tuple(block_out_channels),
|
561 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
562 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
563 |
+
)
|
564 |
+
return config
|
565 |
+
|
566 |
+
|
567 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
568 |
+
keys = list(checkpoint.keys())
|
569 |
+
text_model_dict = {}
|
570 |
+
for key in keys:
|
571 |
+
if key.startswith("cond_stage_model.transformer"):
|
572 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
573 |
+
|
574 |
+
# support checkpoint without position_ids (invalid checkpoint)
|
575 |
+
if "text_model.embeddings.position_ids" not in text_model_dict:
|
576 |
+
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
|
577 |
+
|
578 |
+
return text_model_dict
|
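A small sketch of what the v1 text-encoder conversion does, using a made-up single-key state dict:

sd = {"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": torch.zeros(49408, 768)}
te_sd = convert_ldm_clip_checkpoint_v1(sd)
# The prefix "cond_stage_model.transformer." is stripped, leaving
#   "text_model.embeddings.token_embedding.weight",
# and a missing "text_model.embeddings.position_ids" is backfilled with torch.arange(77).unsqueeze(0).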
579 |
+
|
580 |
+
|
581 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
582 |
+
# The key layouts differ to an annoying degree!
|
583 |
+
def convert_key(key):
|
584 |
+
if not key.startswith("cond_stage_model"):
|
585 |
+
return None
|
586 |
+
|
587 |
+
# common conversion
|
588 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
589 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
590 |
+
|
591 |
+
if "resblocks" in key:
|
592 |
+
# resblocks conversion
|
593 |
+
key = key.replace(".resblocks.", ".layers.")
|
594 |
+
if ".ln_" in key:
|
595 |
+
key = key.replace(".ln_", ".layer_norm")
|
596 |
+
elif ".mlp." in key:
|
597 |
+
key = key.replace(".c_fc.", ".fc1.")
|
598 |
+
key = key.replace(".c_proj.", ".fc2.")
|
599 |
+
elif ".attn.out_proj" in key:
|
600 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
601 |
+
elif ".attn.in_proj" in key:
|
602 |
+
key = None # special case, handled separately below
|
603 |
+
else:
|
604 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
605 |
+
elif ".positional_embedding" in key:
|
606 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
607 |
+
elif ".text_projection" in key:
|
608 |
+
key = None # apparently not used (?)
|
609 |
+
elif ".logit_scale" in key:
|
610 |
+
key = None # apparently not used (?)
|
611 |
+
elif ".token_embedding" in key:
|
612 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
613 |
+
elif ".ln_final" in key:
|
614 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
615 |
+
return key
|
616 |
+
|
617 |
+
keys = list(checkpoint.keys())
|
618 |
+
new_sd = {}
|
619 |
+
for key in keys:
|
620 |
+
# remove resblocks 23
|
621 |
+
if ".resblocks.23." in key:
|
622 |
+
continue
|
623 |
+
new_key = convert_key(key)
|
624 |
+
if new_key is None:
|
625 |
+
continue
|
626 |
+
new_sd[new_key] = checkpoint[key]
|
627 |
+
|
628 |
+
# convert the attention weights
|
629 |
+
for key in keys:
|
630 |
+
if ".resblocks.23." in key:
|
631 |
+
continue
|
632 |
+
if ".resblocks" in key and ".attn.in_proj_" in key:
|
633 |
+
# split into three chunks (q, k, v)
|
634 |
+
values = torch.chunk(checkpoint[key], 3)
|
635 |
+
|
636 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
637 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
638 |
+
key_pfx = key_pfx.replace("_weight", "")
|
639 |
+
key_pfx = key_pfx.replace("_bias", "")
|
640 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
641 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
642 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
643 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
644 |
+
|
645 |
+
# rename or add position_ids
|
646 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
647 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
648 |
+
# waifu diffusion v1.4
|
649 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
650 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
651 |
+
else:
|
652 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
653 |
+
|
654 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
655 |
+
return new_sd
|
656 |
+
|
657 |
+
|
658 |
+
# endregion
|
659 |
+
|
660 |
+
|
661 |
+
# region conversion code: Diffusers -> StableDiffusion
|
662 |
+
# copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
|
663 |
+
|
664 |
+
|
665 |
+
def conv_transformer_to_linear(checkpoint):
|
666 |
+
keys = list(checkpoint.keys())
|
667 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
668 |
+
for key in keys:
|
669 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
670 |
+
if checkpoint[key].ndim > 2:
|
671 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
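# Note (added for clarity): the original SD v2 state_dict stores the transformer
# proj_in/proj_out weights as linear layers, so any 4-D conv-style weight of shape
# (out_ch, in_ch, 1, 1) coming from a Diffusers model is reduced above to (out_ch, in_ch)
# by dropping the two trailing 1x1 spatial dimensions.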
672 |
+
|
673 |
+
|
674 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
675 |
+
unet_conversion_map = [
|
676 |
+
# (stable-diffusion, HF Diffusers)
|
677 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
678 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
679 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
680 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
681 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
682 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
683 |
+
("out.0.weight", "conv_norm_out.weight"),
|
684 |
+
("out.0.bias", "conv_norm_out.bias"),
|
685 |
+
("out.2.weight", "conv_out.weight"),
|
686 |
+
("out.2.bias", "conv_out.bias"),
|
687 |
+
]
|
688 |
+
|
689 |
+
unet_conversion_map_resnet = [
|
690 |
+
# (stable-diffusion, HF Diffusers)
|
691 |
+
("in_layers.0", "norm1"),
|
692 |
+
("in_layers.2", "conv1"),
|
693 |
+
("out_layers.0", "norm2"),
|
694 |
+
("out_layers.3", "conv2"),
|
695 |
+
("emb_layers.1", "time_emb_proj"),
|
696 |
+
("skip_connection", "conv_shortcut"),
|
697 |
+
]
|
698 |
+
|
699 |
+
unet_conversion_map_layer = []
|
700 |
+
for i in range(4):
|
701 |
+
# loop over downblocks/upblocks
|
702 |
+
|
703 |
+
for j in range(2):
|
704 |
+
# loop over resnets/attentions for downblocks
|
705 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
706 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
707 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
708 |
+
|
709 |
+
if i < 3:
|
710 |
+
# no attention layers in down_blocks.3
|
711 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
712 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
713 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
714 |
+
|
715 |
+
for j in range(3):
|
716 |
+
# loop over resnets/attentions for upblocks
|
717 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
718 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
719 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
720 |
+
|
721 |
+
if i > 0:
|
722 |
+
# no attention layers in up_blocks.0
|
723 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
724 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
725 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
726 |
+
|
727 |
+
if i < 3:
|
728 |
+
# no downsample in down_blocks.3
|
729 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
730 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
731 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
732 |
+
|
733 |
+
# no upsample in up_blocks.3
|
734 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
735 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
736 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
737 |
+
|
738 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
739 |
+
sd_mid_atn_prefix = "middle_block.1."
|
740 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
741 |
+
|
742 |
+
for j in range(2):
|
743 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
744 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
745 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
746 |
+
|
747 |
+
# buyer beware: this is a *brittle* function,
|
748 |
+
# and correct output requires that all of these pieces interact in
|
749 |
+
# the exact order in which I have arranged them.
|
750 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
751 |
+
for sd_name, hf_name in unet_conversion_map:
|
752 |
+
mapping[hf_name] = sd_name
|
753 |
+
for k, v in mapping.items():
|
754 |
+
if "resnets" in k:
|
755 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
756 |
+
v = v.replace(hf_part, sd_part)
|
757 |
+
mapping[k] = v
|
758 |
+
for k, v in mapping.items():
|
759 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
760 |
+
v = v.replace(hf_part, sd_part)
|
761 |
+
mapping[k] = v
|
762 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
763 |
+
|
764 |
+
if v2:
|
765 |
+
conv_transformer_to_linear(new_state_dict)
|
766 |
+
|
767 |
+
return new_state_dict
|
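# Illustrative usage sketch: convert a Diffusers-format UNet state dict into the
# Stable Diffusion layout and add the "model.diffusion_model." prefix, the same way
# save_stable_diffusion_checkpoint does further below.
def _example_unet_to_sd_state_dict(unet, v2=False):
    sd_unet = convert_unet_state_dict_to_sd(v2, unet.state_dict())
    return {"model.diffusion_model." + k: v for k, v in sd_unet.items()}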
768 |
+
|
769 |
+
|
770 |
+
def controlnet_conversion_map():
|
771 |
+
unet_conversion_map = [
|
772 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
773 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
774 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
775 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
776 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
777 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
778 |
+
("middle_block_out.0.weight", "controlnet_mid_block.weight"),
|
779 |
+
("middle_block_out.0.bias", "controlnet_mid_block.bias"),
|
780 |
+
]
|
781 |
+
|
782 |
+
unet_conversion_map_resnet = [
|
783 |
+
("in_layers.0", "norm1"),
|
784 |
+
("in_layers.2", "conv1"),
|
785 |
+
("out_layers.0", "norm2"),
|
786 |
+
("out_layers.3", "conv2"),
|
787 |
+
("emb_layers.1", "time_emb_proj"),
|
788 |
+
("skip_connection", "conv_shortcut"),
|
789 |
+
]
|
790 |
+
|
791 |
+
unet_conversion_map_layer = []
|
792 |
+
for i in range(4):
|
793 |
+
for j in range(2):
|
794 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
795 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
796 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
797 |
+
|
798 |
+
if i < 3:
|
799 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
800 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
801 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
802 |
+
|
803 |
+
if i < 3:
|
804 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
805 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
806 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
807 |
+
|
808 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
809 |
+
sd_mid_atn_prefix = "middle_block.1."
|
810 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
811 |
+
|
812 |
+
for j in range(2):
|
813 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
814 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
815 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
816 |
+
|
817 |
+
controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
|
818 |
+
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
|
819 |
+
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
|
820 |
+
sd_prefix = f"input_hint_block.{i*2}."
|
821 |
+
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
822 |
+
|
823 |
+
for i in range(12):
|
824 |
+
hf_prefix = f"controlnet_down_blocks.{i}."
|
825 |
+
sd_prefix = f"zero_convs.{i}.0."
|
826 |
+
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
827 |
+
|
828 |
+
return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
|
829 |
+
|
830 |
+
|
831 |
+
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
832 |
+
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
833 |
+
|
834 |
+
mapping = {k: k for k in controlnet_state_dict.keys()}
|
835 |
+
for sd_name, diffusers_name in unet_conversion_map:
|
836 |
+
mapping[diffusers_name] = sd_name
|
837 |
+
for k, v in mapping.items():
|
838 |
+
if "resnets" in k:
|
839 |
+
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
840 |
+
v = v.replace(diffusers_part, sd_part)
|
841 |
+
mapping[k] = v
|
842 |
+
for k, v in mapping.items():
|
843 |
+
for sd_part, diffusers_part in unet_conversion_map_layer:
|
844 |
+
v = v.replace(diffusers_part, sd_part)
|
845 |
+
mapping[k] = v
|
846 |
+
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
847 |
+
return new_state_dict
|
848 |
+
|
849 |
+
|
850 |
+
def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
|
851 |
+
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
852 |
+
|
853 |
+
mapping = {k: k for k in controlnet_state_dict.keys()}
|
854 |
+
for sd_name, diffusers_name in unet_conversion_map:
|
855 |
+
mapping[sd_name] = diffusers_name
|
856 |
+
for k, v in mapping.items():
|
857 |
+
for sd_part, diffusers_part in unet_conversion_map_layer:
|
858 |
+
v = v.replace(sd_part, diffusers_part)
|
859 |
+
mapping[k] = v
|
860 |
+
for k, v in mapping.items():
|
861 |
+
if "resnets" in v:
|
862 |
+
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
863 |
+
v = v.replace(sd_part, diffusers_part)
|
864 |
+
mapping[k] = v
|
865 |
+
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
866 |
+
return new_state_dict
|
867 |
+
|
868 |
+
|
869 |
+
# ================#
|
870 |
+
# VAE Conversion #
|
871 |
+
# ================#
|
872 |
+
|
873 |
+
|
874 |
+
def reshape_weight_for_sd(w):
|
875 |
+
# convert HF linear weights to SD conv2d weights
|
876 |
+
return w.reshape(*w.shape, 1, 1)
|
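# Note (added for clarity): e.g. a (512, 512) Diffusers linear weight becomes
# (512, 512, 1, 1) here, so it can be loaded as a 1x1 convolution in the SD VAE.
# It is applied below only to the mid-block attention weights (q/k/v/proj_out).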
877 |
+
|
878 |
+
|
879 |
+
def convert_vae_state_dict(vae_state_dict):
|
880 |
+
vae_conversion_map = [
|
881 |
+
# (stable-diffusion, HF Diffusers)
|
882 |
+
("nin_shortcut", "conv_shortcut"),
|
883 |
+
("norm_out", "conv_norm_out"),
|
884 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
885 |
+
]
|
886 |
+
|
887 |
+
for i in range(4):
|
888 |
+
# down_blocks have two resnets
|
889 |
+
for j in range(2):
|
890 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
891 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
892 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
893 |
+
|
894 |
+
if i < 3:
|
895 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
896 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
897 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
898 |
+
|
899 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
900 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
901 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
902 |
+
|
903 |
+
# up_blocks have three resnets
|
904 |
+
# also, up blocks in hf are numbered in reverse from sd
|
905 |
+
for j in range(3):
|
906 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
907 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
908 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
909 |
+
|
910 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
911 |
+
for i in range(2):
|
912 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
913 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
914 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
915 |
+
|
916 |
+
if diffusers.__version__ < "0.17.0":
|
917 |
+
vae_conversion_map_attn = [
|
918 |
+
# (stable-diffusion, HF Diffusers)
|
919 |
+
("norm.", "group_norm."),
|
920 |
+
("q.", "query."),
|
921 |
+
("k.", "key."),
|
922 |
+
("v.", "value."),
|
923 |
+
("proj_out.", "proj_attn."),
|
924 |
+
]
|
925 |
+
else:
|
926 |
+
vae_conversion_map_attn = [
|
927 |
+
# (stable-diffusion, HF Diffusers)
|
928 |
+
("norm.", "group_norm."),
|
929 |
+
("q.", "to_q."),
|
930 |
+
("k.", "to_k."),
|
931 |
+
("v.", "to_v."),
|
932 |
+
("proj_out.", "to_out.0."),
|
933 |
+
]
|
934 |
+
|
935 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
936 |
+
for k, v in mapping.items():
|
937 |
+
for sd_part, hf_part in vae_conversion_map:
|
938 |
+
v = v.replace(hf_part, sd_part)
|
939 |
+
mapping[k] = v
|
940 |
+
for k, v in mapping.items():
|
941 |
+
if "attentions" in k:
|
942 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
943 |
+
v = v.replace(hf_part, sd_part)
|
944 |
+
mapping[k] = v
|
945 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
946 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
947 |
+
for k, v in new_state_dict.items():
|
948 |
+
for weight_name in weights_to_convert:
|
949 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
950 |
+
# print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
|
951 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
952 |
+
|
953 |
+
return new_state_dict
|
954 |
+
|
955 |
+
|
956 |
+
# endregion
|
957 |
+
|
958 |
+
# region custom model loading/saving helpers
|
959 |
+
|
960 |
+
|
961 |
+
def is_safetensors(path):
|
962 |
+
return os.path.splitext(path)[1].lower() == ".safetensors"
|
963 |
+
|
964 |
+
|
965 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
966 |
+
# handle models that store the text encoder in a different layout (missing 'text_model')
|
967 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
968 |
+
("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
|
969 |
+
("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
|
970 |
+
("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
|
971 |
+
]
|
972 |
+
|
973 |
+
if is_safetensors(ckpt_path):
|
974 |
+
checkpoint = None
|
975 |
+
state_dict = load_file(ckpt_path) # , device) # may cause an error
|
976 |
+
else:
|
977 |
+
checkpoint = torch.load(ckpt_path, map_location=device)
|
978 |
+
if "state_dict" in checkpoint:
|
979 |
+
state_dict = checkpoint["state_dict"]
|
980 |
+
else:
|
981 |
+
state_dict = checkpoint
|
982 |
+
checkpoint = None
|
983 |
+
|
984 |
+
key_reps = []
|
985 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
986 |
+
for key in state_dict.keys():
|
987 |
+
if key.startswith(rep_from):
|
988 |
+
new_key = rep_to + key[len(rep_from) :]
|
989 |
+
key_reps.append((key, new_key))
|
990 |
+
|
991 |
+
for key, new_key in key_reps:
|
992 |
+
state_dict[new_key] = state_dict[key]
|
993 |
+
del state_dict[key]
|
994 |
+
|
995 |
+
return checkpoint, state_dict
|
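# Illustrative usage sketch (the file name is hypothetical): the helper returns the raw
# checkpoint dict (None for safetensors files) together with a state_dict whose
# text-encoder keys are normalized to the "cond_stage_model.transformer.text_model."
# layout expected by the converters above.
def _example_load_state_dict(ckpt_path="model.safetensors"):
    _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
    return state_dict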
996 |
+
|
997 |
+
|
998 |
+
# TODO the dtype handling is questionable and should be checked; whether the text_encoder can be created in the specified dtype is unverified
|
999 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
|
1000 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
1001 |
+
|
1002 |
+
# Convert the UNet2DConditionModel model.
|
1003 |
+
unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
|
1004 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
1005 |
+
|
1006 |
+
unet = UNet2DConditionModel(**unet_config).to(device)
|
1007 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
1008 |
+
print("loading u-net:", info)
|
1009 |
+
|
1010 |
+
# Convert the VAE model.
|
1011 |
+
vae_config = create_vae_diffusers_config()
|
1012 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
1013 |
+
|
1014 |
+
vae = AutoencoderKL(**vae_config).to(device)
|
1015 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
1016 |
+
print("loading vae:", info)
|
1017 |
+
|
1018 |
+
# convert text_model
|
1019 |
+
if v2:
|
1020 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
1021 |
+
cfg = CLIPTextConfig(
|
1022 |
+
vocab_size=49408,
|
1023 |
+
hidden_size=1024,
|
1024 |
+
intermediate_size=4096,
|
1025 |
+
num_hidden_layers=23,
|
1026 |
+
num_attention_heads=16,
|
1027 |
+
max_position_embeddings=77,
|
1028 |
+
hidden_act="gelu",
|
1029 |
+
layer_norm_eps=1e-05,
|
1030 |
+
dropout=0.0,
|
1031 |
+
attention_dropout=0.0,
|
1032 |
+
initializer_range=0.02,
|
1033 |
+
initializer_factor=1.0,
|
1034 |
+
pad_token_id=1,
|
1035 |
+
bos_token_id=0,
|
1036 |
+
eos_token_id=2,
|
1037 |
+
model_type="clip_text_model",
|
1038 |
+
projection_dim=512,
|
1039 |
+
torch_dtype="float32",
|
1040 |
+
transformers_version="4.25.0.dev0",
|
1041 |
+
)
|
1042 |
+
text_model = CLIPTextModel._from_config(cfg)
|
1043 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
1044 |
+
else:
|
1045 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
1046 |
+
|
1047 |
+
# logging.set_verbosity_error() # don't show annoying warning
|
1048 |
+
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
1049 |
+
# logging.set_verbosity_warning()
|
1050 |
+
# print(f"config: {text_model.config}")
|
1051 |
+
cfg = CLIPTextConfig(
|
1052 |
+
vocab_size=49408,
|
1053 |
+
hidden_size=768,
|
1054 |
+
intermediate_size=3072,
|
1055 |
+
num_hidden_layers=12,
|
1056 |
+
num_attention_heads=12,
|
1057 |
+
max_position_embeddings=77,
|
1058 |
+
hidden_act="quick_gelu",
|
1059 |
+
layer_norm_eps=1e-05,
|
1060 |
+
dropout=0.0,
|
1061 |
+
attention_dropout=0.0,
|
1062 |
+
initializer_range=0.02,
|
1063 |
+
initializer_factor=1.0,
|
1064 |
+
pad_token_id=1,
|
1065 |
+
bos_token_id=0,
|
1066 |
+
eos_token_id=2,
|
1067 |
+
model_type="clip_text_model",
|
1068 |
+
projection_dim=768,
|
1069 |
+
torch_dtype="float32",
|
1070 |
+
)
|
1071 |
+
text_model = CLIPTextModel._from_config(cfg)
|
1072 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
1073 |
+
print("loading text encoder:", info)
|
1074 |
+
|
1075 |
+
return text_model, vae, unet
|
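# Illustrative usage sketch (the checkpoint path is hypothetical): load the text encoder,
# VAE and U-Net from a single SD 1.x checkpoint file.
def _example_load_sd_v1_models(ckpt_path="sd-v1-5.safetensors", device="cuda"):
    text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(False, ckpt_path, device)
    return text_encoder, vae, unet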
1076 |
+
|
1077 |
+
|
1078 |
+
def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
|
1079 |
+
# only for reference
|
1080 |
+
version_str = "sd"
|
1081 |
+
if v2:
|
1082 |
+
version_str += "_v2"
|
1083 |
+
else:
|
1084 |
+
version_str += "_v1"
|
1085 |
+
if v_parameterization:
|
1086 |
+
version_str += "_v"
|
1087 |
+
return version_str
|
1088 |
+
|
1089 |
+
|
1090 |
+
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
1091 |
+
def convert_key(key):
|
1092 |
+
# drop position_ids
|
1093 |
+
if ".position_ids" in key:
|
1094 |
+
return None
|
1095 |
+
|
1096 |
+
# common
|
1097 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
1098 |
+
key = key.replace("text_model.", "")
|
1099 |
+
if "layers" in key:
|
1100 |
+
# resblocks conversion
|
1101 |
+
key = key.replace(".layers.", ".resblocks.")
|
1102 |
+
if ".layer_norm" in key:
|
1103 |
+
key = key.replace(".layer_norm", ".ln_")
|
1104 |
+
elif ".mlp." in key:
|
1105 |
+
key = key.replace(".fc1.", ".c_fc.")
|
1106 |
+
key = key.replace(".fc2.", ".c_proj.")
|
1107 |
+
elif ".self_attn.out_proj" in key:
|
1108 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
1109 |
+
elif ".self_attn." in key:
|
1110 |
+
key = None # special case, handled separately below
|
1111 |
+
else:
|
1112 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
1113 |
+
elif ".position_embedding" in key:
|
1114 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
1115 |
+
elif ".token_embedding" in key:
|
1116 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
1117 |
+
elif "final_layer_norm" in key:
|
1118 |
+
key = key.replace("final_layer_norm", "ln_final")
|
1119 |
+
return key
|
1120 |
+
|
1121 |
+
keys = list(checkpoint.keys())
|
1122 |
+
new_sd = {}
|
1123 |
+
for key in keys:
|
1124 |
+
new_key = convert_key(key)
|
1125 |
+
if new_key is None:
|
1126 |
+
continue
|
1127 |
+
new_sd[new_key] = checkpoint[key]
|
1128 |
+
|
1129 |
+
# convert the attention weights
|
1130 |
+
for key in keys:
|
1131 |
+
if "layers" in key and "q_proj" in key:
|
1132 |
+
# concatenate the three projections (q, k, v)
|
1133 |
+
key_q = key
|
1134 |
+
key_k = key.replace("q_proj", "k_proj")
|
1135 |
+
key_v = key.replace("q_proj", "v_proj")
|
1136 |
+
|
1137 |
+
value_q = checkpoint[key_q]
|
1138 |
+
value_k = checkpoint[key_k]
|
1139 |
+
value_v = checkpoint[key_v]
|
1140 |
+
value = torch.cat([value_q, value_k, value_v])
|
1141 |
+
|
1142 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
1143 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
1144 |
+
new_sd[new_key] = value
|
1145 |
+
|
1146 |
+
# whether to fabricate dummy weights for the final layer etc.
|
1147 |
+
if make_dummy_weights:
|
1148 |
+
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
1149 |
+
keys = list(new_sd.keys())
|
1150 |
+
for key in keys:
|
1151 |
+
if key.startswith("transformer.resblocks.22."):
|
1152 |
+
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # must be cloned, otherwise saving with safetensors fails
|
1153 |
+
|
1154 |
+
# create weights that do not exist in the Diffusers model
|
1155 |
+
new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
1156 |
+
new_sd["logit_scale"] = torch.tensor(1)
|
1157 |
+
|
1158 |
+
return new_sd
|
1159 |
+
|
1160 |
+
|
1161 |
+
def save_stable_diffusion_checkpoint(
|
1162 |
+
v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
|
1163 |
+
):
|
1164 |
+
if ckpt_path is not None:
|
1165 |
+
# reuse epoch/step from the reference checkpoint; also reload it including the VAE, e.g. when the VAE is not in memory
|
1166 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
1167 |
+
if checkpoint is None: # safetensors file, or a ckpt that is only a state_dict
|
1168 |
+
checkpoint = {}
|
1169 |
+
strict = False
|
1170 |
+
else:
|
1171 |
+
strict = True
|
1172 |
+
if "state_dict" in state_dict:
|
1173 |
+
del state_dict["state_dict"]
|
1174 |
+
else:
|
1175 |
+
# build a new checkpoint from scratch
|
1176 |
+
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
1177 |
+
checkpoint = {}
|
1178 |
+
state_dict = {}
|
1179 |
+
strict = False
|
1180 |
+
|
1181 |
+
def update_sd(prefix, sd):
|
1182 |
+
for k, v in sd.items():
|
1183 |
+
key = prefix + k
|
1184 |
+
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
1185 |
+
if save_dtype is not None:
|
1186 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
1187 |
+
state_dict[key] = v
|
1188 |
+
|
1189 |
+
# Convert the UNet model
|
1190 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1191 |
+
update_sd("model.diffusion_model.", unet_state_dict)
|
1192 |
+
|
1193 |
+
# Convert the text encoder model
|
1194 |
+
if v2:
|
1195 |
+
make_dummy = ckpt_path is None # if there is no reference checkpoint, insert dummy weights (e.g. build the last layer by duplicating the previous one)
|
1196 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
1197 |
+
update_sd("cond_stage_model.model.", text_enc_dict)
|
1198 |
+
else:
|
1199 |
+
text_enc_dict = text_encoder.state_dict()
|
1200 |
+
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
1201 |
+
|
1202 |
+
# Convert the VAE
|
1203 |
+
if vae is not None:
|
1204 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1205 |
+
update_sd("first_stage_model.", vae_dict)
|
1206 |
+
|
1207 |
+
# Put together new checkpoint
|
1208 |
+
key_count = len(state_dict.keys())
|
1209 |
+
new_ckpt = {"state_dict": state_dict}
|
1210 |
+
|
1211 |
+
# epoch and global_step are sometimes not int
|
1212 |
+
try:
|
1213 |
+
if "epoch" in checkpoint:
|
1214 |
+
epochs += checkpoint["epoch"]
|
1215 |
+
if "global_step" in checkpoint:
|
1216 |
+
steps += checkpoint["global_step"]
|
1217 |
+
except:
|
1218 |
+
pass
|
1219 |
+
|
1220 |
+
new_ckpt["epoch"] = epochs
|
1221 |
+
new_ckpt["global_step"] = steps
|
1222 |
+
|
1223 |
+
if is_safetensors(output_file):
|
1224 |
+
# TODO should non-Tensor values be removed from the dict?
|
1225 |
+
save_file(state_dict, output_file, metadata)
|
1226 |
+
else:
|
1227 |
+
torch.save(new_ckpt, output_file)
|
1228 |
+
|
1229 |
+
return key_count
|
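# Illustrative usage sketch (the output path is hypothetical): save trained modules as a
# single SD-format safetensors file. Without a reference ckpt_path a VAE must be passed
# explicitly, and for v2 dummy text-encoder layers would be fabricated.
def _example_save_sd_checkpoint(text_encoder, unet, vae, output_file="trained.safetensors"):
    return save_stable_diffusion_checkpoint(
        False, output_file, text_encoder, unet, None, 0, 0, {}, save_dtype=torch.float16, vae=vae
    )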
1230 |
+
|
1231 |
+
|
1232 |
+
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
1233 |
+
if pretrained_model_name_or_path is None:
|
1234 |
+
# load default settings for v1/v2
|
1235 |
+
if v2:
|
1236 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1237 |
+
else:
|
1238 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1239 |
+
|
1240 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
1241 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
1242 |
+
if vae is None:
|
1243 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
1244 |
+
|
1245 |
+
pipeline = StableDiffusionPipeline(
|
1246 |
+
unet=unet,
|
1247 |
+
text_encoder=text_encoder,
|
1248 |
+
vae=vae,
|
1249 |
+
scheduler=scheduler,
|
1250 |
+
tokenizer=tokenizer,
|
1251 |
+
safety_checker=None,
|
1252 |
+
feature_extractor=None,
|
1253 |
+
requires_safety_checker=None,
|
1254 |
+
)
|
1255 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1256 |
+
|
1257 |
+
|
1258 |
+
VAE_PREFIX = "first_stage_model."
|
1259 |
+
|
1260 |
+
|
1261 |
+
def load_vae(vae_id, dtype):
|
1262 |
+
print(f"load VAE: {vae_id}")
|
1263 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
1264 |
+
# Diffusers local/remote
|
1265 |
+
try:
|
1266 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
1267 |
+
except EnvironmentError as e:
|
1268 |
+
print(f"exception occurs in loading vae: {e}")
|
1269 |
+
print("retry with subfolder='vae'")
|
1270 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
1271 |
+
return vae
|
1272 |
+
|
1273 |
+
# local
|
1274 |
+
vae_config = create_vae_diffusers_config()
|
1275 |
+
|
1276 |
+
if vae_id.endswith(".bin"):
|
1277 |
+
# SD 1.5 VAE on Huggingface
|
1278 |
+
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
1279 |
+
else:
|
1280 |
+
# StableDiffusion
|
1281 |
+
vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
|
1282 |
+
vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
|
1283 |
+
|
1284 |
+
# vae only or full model
|
1285 |
+
full_model = False
|
1286 |
+
for vae_key in vae_sd:
|
1287 |
+
if vae_key.startswith(VAE_PREFIX):
|
1288 |
+
full_model = True
|
1289 |
+
break
|
1290 |
+
if not full_model:
|
1291 |
+
sd = {}
|
1292 |
+
for key, value in vae_sd.items():
|
1293 |
+
sd[VAE_PREFIX + key] = value
|
1294 |
+
vae_sd = sd
|
1295 |
+
del sd
|
1296 |
+
|
1297 |
+
# Convert the VAE model.
|
1298 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
1299 |
+
|
1300 |
+
vae = AutoencoderKL(**vae_config)
|
1301 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
1302 |
+
return vae
|
1303 |
+
|
1304 |
+
|
1305 |
+
# endregion
|
1306 |
+
|
1307 |
+
|
1308 |
+
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
1309 |
+
max_width, max_height = max_reso
|
1310 |
+
max_area = max_width * max_height
|
1311 |
+
|
1312 |
+
resos = set()
|
1313 |
+
|
1314 |
+
width = int(math.sqrt(max_area) // divisible) * divisible
|
1315 |
+
resos.add((width, width))
|
1316 |
+
|
1317 |
+
width = min_size
|
1318 |
+
while width <= max_size:
|
1319 |
+
height = min(max_size, int((max_area // width) // divisible) * divisible)
|
1320 |
+
if height >= min_size:
|
1321 |
+
resos.add((width, height))
|
1322 |
+
resos.add((height, width))
|
1323 |
+
|
1324 |
+
# # make additional resos
|
1325 |
+
# if width >= height and width - divisible >= min_size:
|
1326 |
+
# resos.add((width - divisible, height))
|
1327 |
+
# resos.add((height, width - divisible))
|
1328 |
+
# if height >= width and height - divisible >= min_size:
|
1329 |
+
# resos.add((width, height - divisible))
|
1330 |
+
# resos.add((height - divisible, width))
|
1331 |
+
|
1332 |
+
width += divisible
|
1333 |
+
|
1334 |
+
resos = list(resos)
|
1335 |
+
resos.sort()
|
1336 |
+
return resos
|
1337 |
+
|
1338 |
+
|
1339 |
+
if __name__ == "__main__":
|
1340 |
+
resos = make_bucket_resolutions((512, 768))
|
1341 |
+
print(len(resos))
|
1342 |
+
print(resos)
|
1343 |
+
aspect_ratios = [w / h for w, h in resos]
|
1344 |
+
print(aspect_ratios)
|
1345 |
+
|
1346 |
+
ars = set()
|
1347 |
+
for ar in aspect_ratios:
|
1348 |
+
if ar in ars:
|
1349 |
+
print("error! duplicate ar:", ar)
|
1350 |
+
ars.add(ar)
|
external/llite/library/original_unet.py
ADDED
@@ -0,0 +1,1915 @@
1 |
+
# Bring over only the parts needed for Stable Diffusion from Diffusers 0.10.2
|
2 |
+
# Unneeded parts such as unused conditional branches have been removed
|
3 |
+
# Most of the code is copied from Diffusers
|
4 |
+
# As a constraint, the model's state_dict must be in the same format as that of Diffusers 0.10.2
|
5 |
+
|
6 |
+
# Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers.
|
7 |
+
# Unnecessary parts are deleted by condition branching.
|
8 |
+
# As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2
|
9 |
+
|
10 |
+
"""
|
11 |
+
Differences between v1.5 and v2.1:
|
12 |
+
- attention_head_dim is an int vs a list[int]
|
13 |
+
- cross_attention_dim is 768 vs 1024
|
14 |
+
- use_linear_projection: absent (=False, 1.5) vs present (2.1)
|
15 |
+
- upcast_attention is False (1.5) vs True (2.1)
|
16 |
+
- (the following can probably be ignored)
|
17 |
+
- sample_size is 64 vs 96
|
18 |
+
- dual_cross_attention present or absent
|
19 |
+
- num_class_embeds present or absent
|
20 |
+
- only_cross_attention present or absent
|
21 |
+
|
22 |
+
v1.5
|
23 |
+
{
|
24 |
+
"_class_name": "UNet2DConditionModel",
|
25 |
+
"_diffusers_version": "0.6.0",
|
26 |
+
"act_fn": "silu",
|
27 |
+
"attention_head_dim": 8,
|
28 |
+
"block_out_channels": [
|
29 |
+
320,
|
30 |
+
640,
|
31 |
+
1280,
|
32 |
+
1280
|
33 |
+
],
|
34 |
+
"center_input_sample": false,
|
35 |
+
"cross_attention_dim": 768,
|
36 |
+
"down_block_types": [
|
37 |
+
"CrossAttnDownBlock2D",
|
38 |
+
"CrossAttnDownBlock2D",
|
39 |
+
"CrossAttnDownBlock2D",
|
40 |
+
"DownBlock2D"
|
41 |
+
],
|
42 |
+
"downsample_padding": 1,
|
43 |
+
"flip_sin_to_cos": true,
|
44 |
+
"freq_shift": 0,
|
45 |
+
"in_channels": 4,
|
46 |
+
"layers_per_block": 2,
|
47 |
+
"mid_block_scale_factor": 1,
|
48 |
+
"norm_eps": 1e-05,
|
49 |
+
"norm_num_groups": 32,
|
50 |
+
"out_channels": 4,
|
51 |
+
"sample_size": 64,
|
52 |
+
"up_block_types": [
|
53 |
+
"UpBlock2D",
|
54 |
+
"CrossAttnUpBlock2D",
|
55 |
+
"CrossAttnUpBlock2D",
|
56 |
+
"CrossAttnUpBlock2D"
|
57 |
+
]
|
58 |
+
}
|
59 |
+
|
60 |
+
v2.1
|
61 |
+
{
|
62 |
+
"_class_name": "UNet2DConditionModel",
|
63 |
+
"_diffusers_version": "0.10.0.dev0",
|
64 |
+
"act_fn": "silu",
|
65 |
+
"attention_head_dim": [
|
66 |
+
5,
|
67 |
+
10,
|
68 |
+
20,
|
69 |
+
20
|
70 |
+
],
|
71 |
+
"block_out_channels": [
|
72 |
+
320,
|
73 |
+
640,
|
74 |
+
1280,
|
75 |
+
1280
|
76 |
+
],
|
77 |
+
"center_input_sample": false,
|
78 |
+
"cross_attention_dim": 1024,
|
79 |
+
"down_block_types": [
|
80 |
+
"CrossAttnDownBlock2D",
|
81 |
+
"CrossAttnDownBlock2D",
|
82 |
+
"CrossAttnDownBlock2D",
|
83 |
+
"DownBlock2D"
|
84 |
+
],
|
85 |
+
"downsample_padding": 1,
|
86 |
+
"dual_cross_attention": false,
|
87 |
+
"flip_sin_to_cos": true,
|
88 |
+
"freq_shift": 0,
|
89 |
+
"in_channels": 4,
|
90 |
+
"layers_per_block": 2,
|
91 |
+
"mid_block_scale_factor": 1,
|
92 |
+
"norm_eps": 1e-05,
|
93 |
+
"norm_num_groups": 32,
|
94 |
+
"num_class_embeds": null,
|
95 |
+
"only_cross_attention": false,
|
96 |
+
"out_channels": 4,
|
97 |
+
"sample_size": 96,
|
98 |
+
"up_block_types": [
|
99 |
+
"UpBlock2D",
|
100 |
+
"CrossAttnUpBlock2D",
|
101 |
+
"CrossAttnUpBlock2D",
|
102 |
+
"CrossAttnUpBlock2D"
|
103 |
+
],
|
104 |
+
"use_linear_projection": true,
|
105 |
+
"upcast_attention": true
|
106 |
+
}
|
107 |
+
"""
|
108 |
+
|
109 |
+
import math
|
110 |
+
from types import SimpleNamespace
|
111 |
+
from typing import Dict, Optional, Tuple, Union
|
112 |
+
import torch
|
113 |
+
from torch import nn
|
114 |
+
from torch.nn import functional as F
|
115 |
+
from einops import rearrange
|
116 |
+
|
117 |
+
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
|
118 |
+
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
|
119 |
+
TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4
|
120 |
+
IN_CHANNELS: int = 4
|
121 |
+
OUT_CHANNELS: int = 4
|
122 |
+
LAYERS_PER_BLOCK: int = 2
|
123 |
+
LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1
|
124 |
+
TIME_EMBED_FLIP_SIN_TO_COS: bool = True
|
125 |
+
TIME_EMBED_FREQ_SHIFT: int = 0
|
126 |
+
NORM_GROUPS: int = 32
|
127 |
+
NORM_EPS: float = 1e-5
|
128 |
+
TRANSFORMER_NORM_NUM_GROUPS = 32
|
129 |
+
|
130 |
+
DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
|
131 |
+
UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
|
132 |
+
|
133 |
+
|
134 |
+
# region memory efficient attention
|
135 |
+
|
136 |
+
# CrossAttention that uses FlashAttention
|
137 |
+
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
138 |
+
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
139 |
+
|
140 |
+
# constants
|
141 |
+
|
142 |
+
EPSILON = 1e-6
|
143 |
+
|
144 |
+
# helper functions
|
145 |
+
|
146 |
+
|
147 |
+
def exists(val):
|
148 |
+
return val is not None
|
149 |
+
|
150 |
+
|
151 |
+
def default(val, d):
|
152 |
+
return val if exists(val) else d
|
153 |
+
|
154 |
+
|
155 |
+
# flash attention forwards and backwards
|
156 |
+
|
157 |
+
# https://arxiv.org/abs/2205.14135
|
158 |
+
|
159 |
+
|
160 |
+
class FlashAttentionFunction(torch.autograd.Function):
|
161 |
+
@staticmethod
|
162 |
+
@torch.no_grad()
|
163 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
164 |
+
"""Algorithm 2 in the paper"""
|
165 |
+
|
166 |
+
device = q.device
|
167 |
+
dtype = q.dtype
|
168 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
169 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
170 |
+
|
171 |
+
o = torch.zeros_like(q)
|
172 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
173 |
+
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
|
174 |
+
|
175 |
+
scale = q.shape[-1] ** -0.5
|
176 |
+
|
177 |
+
if not exists(mask):
|
178 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
179 |
+
else:
|
180 |
+
mask = rearrange(mask, "b n -> b 1 1 n")
|
181 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
182 |
+
|
183 |
+
row_splits = zip(
|
184 |
+
q.split(q_bucket_size, dim=-2),
|
185 |
+
o.split(q_bucket_size, dim=-2),
|
186 |
+
mask,
|
187 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
188 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
189 |
+
)
|
190 |
+
|
191 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
192 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
193 |
+
|
194 |
+
col_splits = zip(
|
195 |
+
k.split(k_bucket_size, dim=-2),
|
196 |
+
v.split(k_bucket_size, dim=-2),
|
197 |
+
)
|
198 |
+
|
199 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
200 |
+
k_start_index = k_ind * k_bucket_size
|
201 |
+
|
202 |
+
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
203 |
+
|
204 |
+
if exists(row_mask):
|
205 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
206 |
+
|
207 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
208 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
209 |
+
q_start_index - k_start_index + 1
|
210 |
+
)
|
211 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
212 |
+
|
213 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
214 |
+
attn_weights -= block_row_maxes
|
215 |
+
exp_weights = torch.exp(attn_weights)
|
216 |
+
|
217 |
+
if exists(row_mask):
|
218 |
+
exp_weights.masked_fill_(~row_mask, 0.0)
|
219 |
+
|
220 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
|
221 |
+
|
222 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
223 |
+
|
224 |
+
exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
|
225 |
+
|
226 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
227 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
228 |
+
|
229 |
+
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
|
230 |
+
|
231 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
|
232 |
+
|
233 |
+
row_maxes.copy_(new_row_maxes)
|
234 |
+
row_sums.copy_(new_row_sums)
|
235 |
+
|
236 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
237 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
238 |
+
|
239 |
+
return o
|
240 |
+
|
241 |
+
@staticmethod
|
242 |
+
@torch.no_grad()
|
243 |
+
def backward(ctx, do):
|
244 |
+
"""Algorithm 4 in the paper"""
|
245 |
+
|
246 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
247 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
248 |
+
|
249 |
+
device = q.device
|
250 |
+
|
251 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
252 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
253 |
+
|
254 |
+
dq = torch.zeros_like(q)
|
255 |
+
dk = torch.zeros_like(k)
|
256 |
+
dv = torch.zeros_like(v)
|
257 |
+
|
258 |
+
row_splits = zip(
|
259 |
+
q.split(q_bucket_size, dim=-2),
|
260 |
+
o.split(q_bucket_size, dim=-2),
|
261 |
+
do.split(q_bucket_size, dim=-2),
|
262 |
+
mask,
|
263 |
+
l.split(q_bucket_size, dim=-2),
|
264 |
+
m.split(q_bucket_size, dim=-2),
|
265 |
+
dq.split(q_bucket_size, dim=-2),
|
266 |
+
)
|
267 |
+
|
268 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
269 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
270 |
+
|
271 |
+
col_splits = zip(
|
272 |
+
k.split(k_bucket_size, dim=-2),
|
273 |
+
v.split(k_bucket_size, dim=-2),
|
274 |
+
dk.split(k_bucket_size, dim=-2),
|
275 |
+
dv.split(k_bucket_size, dim=-2),
|
276 |
+
)
|
277 |
+
|
278 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
279 |
+
k_start_index = k_ind * k_bucket_size
|
280 |
+
|
281 |
+
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
282 |
+
|
283 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
284 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
285 |
+
q_start_index - k_start_index + 1
|
286 |
+
)
|
287 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
288 |
+
|
289 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
290 |
+
|
291 |
+
if exists(row_mask):
|
292 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.0)
|
293 |
+
|
294 |
+
p = exp_attn_weights / lc
|
295 |
+
|
296 |
+
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
|
297 |
+
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
|
298 |
+
|
299 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
300 |
+
ds = p * scale * (dp - D)
|
301 |
+
|
302 |
+
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
|
303 |
+
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
|
304 |
+
|
305 |
+
dqc.add_(dq_chunk)
|
306 |
+
dkc.add_(dk_chunk)
|
307 |
+
dvc.add_(dv_chunk)
|
308 |
+
|
309 |
+
return dq, dk, dv, None, None, None, None
|
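# Illustrative usage sketch: the autograd function expects per-head tensors of shape
# (batch, heads, seq_len, dim_head); mask=None and causal=False give plain bidirectional
# attention. The bucket sizes are memory/speed tuning knobs, values below are examples.
def _example_flash_attention(q, k, v, q_bucket_size=512, k_bucket_size=1024):
    return FlashAttentionFunction.apply(q, k, v, None, False, q_bucket_size, k_bucket_size)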
310 |
+
|
311 |
+
|
312 |
+
# endregion
|
313 |
+
|
314 |
+
|
315 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
316 |
+
return next(parameter.parameters()).dtype
|
317 |
+
|
318 |
+
|
319 |
+
def get_parameter_device(parameter: torch.nn.Module):
|
320 |
+
return next(parameter.parameters()).device
|
321 |
+
|
322 |
+
|
323 |
+
def get_timestep_embedding(
|
324 |
+
timesteps: torch.Tensor,
|
325 |
+
embedding_dim: int,
|
326 |
+
flip_sin_to_cos: bool = False,
|
327 |
+
downscale_freq_shift: float = 1,
|
328 |
+
scale: float = 1,
|
329 |
+
max_period: int = 10000,
|
330 |
+
):
|
331 |
+
"""
|
332 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
333 |
+
|
334 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
335 |
+
These may be fractional.
|
336 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
337 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
338 |
+
"""
|
339 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
340 |
+
|
341 |
+
half_dim = embedding_dim // 2
|
342 |
+
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
343 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
344 |
+
|
345 |
+
emb = torch.exp(exponent)
|
346 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
347 |
+
|
348 |
+
# scale embeddings
|
349 |
+
emb = scale * emb
|
350 |
+
|
351 |
+
# concat sine and cosine embeddings
|
352 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
353 |
+
|
354 |
+
# flip sine and cosine embeddings
|
355 |
+
if flip_sin_to_cos:
|
356 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
357 |
+
|
358 |
+
# zero pad
|
359 |
+
if embedding_dim % 2 == 1:
|
360 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
361 |
+
return emb
|
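# Illustrative usage sketch: for two timesteps and the SD input dimension (320), the
# result is a (2, 320) tensor of concatenated sin/cos features.
def _example_timestep_embedding():
    timesteps = torch.tensor([0, 999])
    emb = get_timestep_embedding(
        timesteps, TIMESTEP_INPUT_DIM, flip_sin_to_cos=TIME_EMBED_FLIP_SIN_TO_COS, downscale_freq_shift=TIME_EMBED_FREQ_SHIFT
    )
    return emb.shape  # torch.Size([2, 320])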
362 |
+
|
363 |
+
|
364 |
+
# Deep Shrink: this helper is intentionally not factored out into a shared module, to minimize dependencies.
|
365 |
+
def resize_like(x, target, mode="bicubic", align_corners=False):
|
366 |
+
org_dtype = x.dtype
|
367 |
+
if org_dtype == torch.bfloat16:
|
368 |
+
x = x.to(torch.float32)
|
369 |
+
|
370 |
+
if x.shape[-2:] != target.shape[-2:]:
|
371 |
+
if mode == "nearest":
|
372 |
+
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
|
373 |
+
else:
|
374 |
+
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
|
375 |
+
|
376 |
+
if org_dtype == torch.bfloat16:
|
377 |
+
x = x.to(org_dtype)
|
378 |
+
return x
|
379 |
+
|
380 |
+
|
381 |
+
class SampleOutput:
|
382 |
+
def __init__(self, sample):
|
383 |
+
self.sample = sample
|
384 |
+
|
385 |
+
|
386 |
+
class TimestepEmbedding(nn.Module):
|
387 |
+
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
|
388 |
+
super().__init__()
|
389 |
+
|
390 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
391 |
+
self.act = None
|
392 |
+
if act_fn == "silu":
|
393 |
+
self.act = nn.SiLU()
|
394 |
+
elif act_fn == "mish":
|
395 |
+
self.act = nn.Mish()
|
396 |
+
|
397 |
+
if out_dim is not None:
|
398 |
+
time_embed_dim_out = out_dim
|
399 |
+
else:
|
400 |
+
time_embed_dim_out = time_embed_dim
|
401 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
402 |
+
|
403 |
+
def forward(self, sample):
|
404 |
+
sample = self.linear_1(sample)
|
405 |
+
|
406 |
+
if self.act is not None:
|
407 |
+
sample = self.act(sample)
|
408 |
+
|
409 |
+
sample = self.linear_2(sample)
|
410 |
+
return sample
|
411 |
+
|
412 |
+
|
413 |
+
class Timesteps(nn.Module):
|
414 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
415 |
+
super().__init__()
|
416 |
+
self.num_channels = num_channels
|
417 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
418 |
+
self.downscale_freq_shift = downscale_freq_shift
|
419 |
+
|
420 |
+
def forward(self, timesteps):
|
421 |
+
t_emb = get_timestep_embedding(
|
422 |
+
timesteps,
|
423 |
+
self.num_channels,
|
424 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
425 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
426 |
+
)
|
427 |
+
return t_emb
|
428 |
+
|
429 |
+
|
430 |
+
class ResnetBlock2D(nn.Module):
|
431 |
+
def __init__(
|
432 |
+
self,
|
433 |
+
in_channels,
|
434 |
+
out_channels,
|
435 |
+
):
|
436 |
+
super().__init__()
|
437 |
+
self.in_channels = in_channels
|
438 |
+
self.out_channels = out_channels
|
439 |
+
|
440 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)
|
441 |
+
|
442 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
443 |
+
|
444 |
+
self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels)
|
445 |
+
|
446 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
|
447 |
+
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
448 |
+
|
449 |
+
# if non_linearity == "swish":
|
450 |
+
self.nonlinearity = lambda x: F.silu(x)
|
451 |
+
|
452 |
+
self.use_in_shortcut = self.in_channels != self.out_channels
|
453 |
+
|
454 |
+
self.conv_shortcut = None
|
455 |
+
if self.use_in_shortcut:
|
456 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
457 |
+
|
458 |
+
def forward(self, input_tensor, temb):
|
459 |
+
hidden_states = input_tensor
|
460 |
+
|
461 |
+
hidden_states = self.norm1(hidden_states)
|
462 |
+
hidden_states = self.nonlinearity(hidden_states)
|
463 |
+
|
464 |
+
hidden_states = self.conv1(hidden_states)
|
465 |
+
|
466 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
467 |
+
hidden_states = hidden_states + temb
|
468 |
+
|
469 |
+
hidden_states = self.norm2(hidden_states)
|
470 |
+
hidden_states = self.nonlinearity(hidden_states)
|
471 |
+
|
472 |
+
hidden_states = self.conv2(hidden_states)
|
473 |
+
|
474 |
+
if self.conv_shortcut is not None:
|
475 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
476 |
+
|
477 |
+
output_tensor = input_tensor + hidden_states
|
478 |
+
|
479 |
+
return output_tensor
|
480 |
+
|
481 |
+
|
482 |
+
class DownBlock2D(nn.Module):
|
483 |
+
def __init__(
|
484 |
+
self,
|
485 |
+
in_channels: int,
|
486 |
+
out_channels: int,
|
487 |
+
add_downsample=True,
|
488 |
+
):
|
489 |
+
super().__init__()
|
490 |
+
|
491 |
+
self.has_cross_attention = False
|
492 |
+
resnets = []
|
493 |
+
|
494 |
+
for i in range(LAYERS_PER_BLOCK):
|
495 |
+
in_channels = in_channels if i == 0 else out_channels
|
496 |
+
resnets.append(
|
497 |
+
ResnetBlock2D(
|
498 |
+
in_channels=in_channels,
|
499 |
+
out_channels=out_channels,
|
500 |
+
)
|
501 |
+
)
|
502 |
+
self.resnets = nn.ModuleList(resnets)
|
503 |
+
|
504 |
+
if add_downsample:
|
505 |
+
self.downsamplers = [Downsample2D(out_channels, out_channels=out_channels)]
|
506 |
+
else:
|
507 |
+
self.downsamplers = None
|
508 |
+
|
509 |
+
self.gradient_checkpointing = False
|
510 |
+
|
511 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
512 |
+
pass
|
513 |
+
|
514 |
+
def set_use_sdpa(self, sdpa):
|
515 |
+
pass
|
516 |
+
|
517 |
+
def forward(self, hidden_states, temb=None):
|
518 |
+
output_states = ()
|
519 |
+
|
520 |
+
for resnet in self.resnets:
|
521 |
+
if self.training and self.gradient_checkpointing:
|
522 |
+
|
523 |
+
def create_custom_forward(module):
|
524 |
+
def custom_forward(*inputs):
|
525 |
+
return module(*inputs)
|
526 |
+
|
527 |
+
return custom_forward
|
528 |
+
|
529 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
530 |
+
else:
|
531 |
+
hidden_states = resnet(hidden_states, temb)
|
532 |
+
|
533 |
+
output_states += (hidden_states,)
|
534 |
+
|
535 |
+
if self.downsamplers is not None:
|
536 |
+
for downsampler in self.downsamplers:
|
537 |
+
hidden_states = downsampler(hidden_states)
|
538 |
+
|
539 |
+
output_states += (hidden_states,)
|
540 |
+
|
541 |
+
return hidden_states, output_states
|
542 |
+
|
543 |
+
|
544 |
+
class Downsample2D(nn.Module):
    def __init__(self, channels, out_channels):
        super().__init__()

        self.channels = channels
        self.out_channels = out_channels

        self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states


class CrossAttention(nn.Module):
|
561 |
+
def __init__(
|
562 |
+
self,
|
563 |
+
query_dim: int,
|
564 |
+
cross_attention_dim: Optional[int] = None,
|
565 |
+
heads: int = 8,
|
566 |
+
dim_head: int = 64,
|
567 |
+
upcast_attention: bool = False,
|
568 |
+
):
|
569 |
+
super().__init__()
|
570 |
+
inner_dim = dim_head * heads
|
571 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
572 |
+
self.upcast_attention = upcast_attention
|
573 |
+
|
574 |
+
self.scale = dim_head**-0.5
|
575 |
+
self.heads = heads
|
576 |
+
|
577 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
578 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
579 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
580 |
+
|
581 |
+
self.to_out = nn.ModuleList([])
|
582 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
583 |
+
# no dropout here
|
584 |
+
|
585 |
+
self.use_memory_efficient_attention_xformers = False
|
586 |
+
self.use_memory_efficient_attention_mem_eff = False
|
587 |
+
self.use_sdpa = False
|
588 |
+
|
589 |
+
# Attention processor
|
590 |
+
self.processor = None
|
591 |
+
|
592 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
593 |
+
self.use_memory_efficient_attention_xformers = xformers
|
594 |
+
self.use_memory_efficient_attention_mem_eff = mem_eff
|
595 |
+
|
596 |
+
def set_use_sdpa(self, sdpa):
|
597 |
+
self.use_sdpa = sdpa
|
598 |
+
|
599 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
600 |
+
batch_size, seq_len, dim = tensor.shape
|
601 |
+
head_size = self.heads
|
602 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
603 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
604 |
+
return tensor
|
605 |
+
|
606 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
607 |
+
batch_size, seq_len, dim = tensor.shape
|
608 |
+
head_size = self.heads
|
609 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
610 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
611 |
+
return tensor
|
612 |
+
|
613 |
+
def set_processor(self):
|
614 |
+
return self.processor
|
615 |
+
|
616 |
+
def get_processor(self):
|
617 |
+
return self.processor
|
618 |
+
|
619 |
+
def forward(self, hidden_states, context=None, mask=None, **kwargs):
|
620 |
+
if self.processor is not None:
|
621 |
+
(
|
622 |
+
hidden_states,
|
623 |
+
encoder_hidden_states,
|
624 |
+
attention_mask,
|
625 |
+
) = translate_attention_names_from_diffusers(
|
626 |
+
hidden_states=hidden_states, context=context, mask=mask, **kwargs
|
627 |
+
)
|
628 |
+
return self.processor(
|
629 |
+
attn=self,
|
630 |
+
hidden_states=hidden_states,
|
631 |
+
encoder_hidden_states=context,
|
632 |
+
attention_mask=mask,
|
633 |
+
**kwargs
|
634 |
+
)
|
635 |
+
if self.use_memory_efficient_attention_xformers:
|
636 |
+
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
|
637 |
+
if self.use_memory_efficient_attention_mem_eff:
|
638 |
+
return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
|
639 |
+
if self.use_sdpa:
|
640 |
+
return self.forward_sdpa(hidden_states, context, mask)
|
641 |
+
|
642 |
+
query = self.to_q(hidden_states)
|
643 |
+
context = context if context is not None else hidden_states
|
644 |
+
key = self.to_k(context)
|
645 |
+
value = self.to_v(context)
|
646 |
+
|
647 |
+
query = self.reshape_heads_to_batch_dim(query)
|
648 |
+
key = self.reshape_heads_to_batch_dim(key)
|
649 |
+
value = self.reshape_heads_to_batch_dim(value)
|
650 |
+
|
651 |
+
hidden_states = self._attention(query, key, value)
|
652 |
+
|
653 |
+
# linear proj
|
654 |
+
hidden_states = self.to_out[0](hidden_states)
|
655 |
+
# hidden_states = self.to_out[1](hidden_states) # no dropout
|
656 |
+
return hidden_states
|
657 |
+
|
658 |
+
def _attention(self, query, key, value):
|
659 |
+
if self.upcast_attention:
|
660 |
+
query = query.float()
|
661 |
+
key = key.float()
|
662 |
+
|
663 |
+
attention_scores = torch.baddbmm(
|
664 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
665 |
+
query,
|
666 |
+
key.transpose(-1, -2),
|
667 |
+
beta=0,
|
668 |
+
alpha=self.scale,
|
669 |
+
)
|
670 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
671 |
+
|
672 |
+
# cast back to the original dtype
|
673 |
+
attention_probs = attention_probs.to(value.dtype)
|
674 |
+
|
675 |
+
# compute attention output
|
676 |
+
hidden_states = torch.bmm(attention_probs, value)
|
677 |
+
|
678 |
+
# reshape hidden_states
|
679 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
680 |
+
return hidden_states
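As a sanity check on the manual attention path above: it computes the same result as PyTorch's fused scaled_dot_product_attention. A rough sketch, with the head count and shapes chosen purely for illustration rather than taken from the model config:

attn = CrossAttention(query_dim=320, heads=8, dim_head=40)
x = torch.randn(2, 64, 320)                        # (batch, seq_len, query_dim)
q = attn.reshape_heads_to_batch_dim(attn.to_q(x))  # -> (batch * heads, seq_len, dim_head)
k = attn.reshape_heads_to_batch_dim(attn.to_k(x))
v = attn.reshape_heads_to_batch_dim(attn.to_v(x))
out = attn._attention(q, k, v)                     # baddbmm + softmax + bmm, then heads merged back
ref = attn.reshape_batch_dim_to_heads(F.scaled_dot_product_attention(q, k, v))
assert torch.allclose(out, ref, atol=1e-4)
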
|
681 |
+
|
682 |
+
# TODO support Hypernetworks
|
683 |
+
def forward_memory_efficient_xformers(self, x, context=None, mask=None):
|
684 |
+
import xformers.ops
|
685 |
+
|
686 |
+
h = self.heads
|
687 |
+
q_in = self.to_q(x)
|
688 |
+
context = context if context is not None else x
|
689 |
+
context = context.to(x.dtype)
|
690 |
+
k_in = self.to_k(context)
|
691 |
+
v_in = self.to_v(context)
|
692 |
+
|
693 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
694 |
+
del q_in, k_in, v_in
|
695 |
+
|
696 |
+
q = q.contiguous()
|
697 |
+
k = k.contiguous()
|
698 |
+
v = v.contiguous()
|
699 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # picks the best available kernel automatically
|
700 |
+
|
701 |
+
out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
702 |
+
|
703 |
+
out = self.to_out[0](out)
|
704 |
+
return out
|
705 |
+
|
706 |
+
def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
|
707 |
+
flash_func = FlashAttentionFunction
|
708 |
+
|
709 |
+
q_bucket_size = 512
|
710 |
+
k_bucket_size = 1024
|
711 |
+
|
712 |
+
h = self.heads
|
713 |
+
q = self.to_q(x)
|
714 |
+
context = context if context is not None else x
|
715 |
+
context = context.to(x.dtype)
|
716 |
+
k = self.to_k(context)
|
717 |
+
v = self.to_v(context)
|
718 |
+
del context, x
|
719 |
+
|
720 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
721 |
+
|
722 |
+
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
|
723 |
+
|
724 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
725 |
+
|
726 |
+
out = self.to_out[0](out)
|
727 |
+
return out
|
728 |
+
|
729 |
+
def forward_sdpa(self, x, context=None, mask=None):
|
730 |
+
h = self.heads
|
731 |
+
q_in = self.to_q(x)
|
732 |
+
context = context if context is not None else x
|
733 |
+
context = context.to(x.dtype)
|
734 |
+
k_in = self.to_k(context)
|
735 |
+
v_in = self.to_v(context)
|
736 |
+
|
737 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
|
738 |
+
del q_in, k_in, v_in
|
739 |
+
|
740 |
+
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
741 |
+
|
742 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
743 |
+
|
744 |
+
out = self.to_out[0](out)
|
745 |
+
return out
|
746 |
+
|
747 |
+
def translate_attention_names_from_diffusers(
|
748 |
+
hidden_states: torch.FloatTensor,
|
749 |
+
context: Optional[torch.FloatTensor] = None,
|
750 |
+
mask: Optional[torch.FloatTensor] = None,
|
751 |
+
# HF naming
|
752 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
753 |
+
attention_mask: Optional[torch.FloatTensor] = None
|
754 |
+
):
|
755 |
+
# translate from hugging face diffusers
|
756 |
+
context = context if context is not None else encoder_hidden_states
|
757 |
+
|
758 |
+
# translate from hugging face diffusers
|
759 |
+
mask = mask if mask is not None else attention_mask
|
760 |
+
|
761 |
+
return hidden_states, context, mask
|
762 |
+
|
763 |
+
# feedforward
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
    ):
        super().__init__()
        inner_dim = int(dim * 4)  # mult is always 4

        self.net = nn.ModuleList([])
        # project in
        self.net.append(GEGLU(dim, inner_dim))
        # project dropout
        self.net.append(nn.Identity())  # nn.Dropout(0)) # dummy for dropout with 0
        # project out
        self.net.append(nn.Linear(inner_dim, dim))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states

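GEGLU projects to twice the output width and uses one half to gate the other (value * gelu(gate)); a small shape sketch with illustrative dimensions:

geglu = GEGLU(dim_in=320, dim_out=1280)
h = torch.randn(2, 64, 320)
out = geglu(h)       # proj -> (2, 64, 2560), chunk -> two (2, 64, 1280) halves
ff = FeedForward(dim=320)
out2 = ff(h)         # GEGLU(320 -> 1280) -> Identity -> Linear(1280 -> 320), shape preserved
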
class BasicTransformerBlock(nn.Module):
|
811 |
+
def __init__(
|
812 |
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
|
813 |
+
):
|
814 |
+
super().__init__()
|
815 |
+
|
816 |
+
# 1. Self-Attn
|
817 |
+
self.attn1 = CrossAttention(
|
818 |
+
query_dim=dim,
|
819 |
+
cross_attention_dim=None,
|
820 |
+
heads=num_attention_heads,
|
821 |
+
dim_head=attention_head_dim,
|
822 |
+
upcast_attention=upcast_attention,
|
823 |
+
)
|
824 |
+
self.ff = FeedForward(dim)
|
825 |
+
|
826 |
+
# 2. Cross-Attn
|
827 |
+
self.attn2 = CrossAttention(
|
828 |
+
query_dim=dim,
|
829 |
+
cross_attention_dim=cross_attention_dim,
|
830 |
+
heads=num_attention_heads,
|
831 |
+
dim_head=attention_head_dim,
|
832 |
+
upcast_attention=upcast_attention,
|
833 |
+
)
|
834 |
+
|
835 |
+
self.norm1 = nn.LayerNorm(dim)
|
836 |
+
self.norm2 = nn.LayerNorm(dim)
|
837 |
+
|
838 |
+
# 3. Feed-forward
|
839 |
+
self.norm3 = nn.LayerNorm(dim)
|
840 |
+
|
841 |
+
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
|
842 |
+
self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
|
843 |
+
self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
|
844 |
+
|
845 |
+
def set_use_sdpa(self, sdpa: bool):
|
846 |
+
self.attn1.set_use_sdpa(sdpa)
|
847 |
+
self.attn2.set_use_sdpa(sdpa)
|
848 |
+
|
849 |
+
def forward(self, hidden_states, context=None, timestep=None):
|
850 |
+
# 1. Self-Attention
|
851 |
+
norm_hidden_states = self.norm1(hidden_states)
|
852 |
+
|
853 |
+
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
854 |
+
|
855 |
+
# 2. Cross-Attention
|
856 |
+
norm_hidden_states = self.norm2(hidden_states)
|
857 |
+
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
|
858 |
+
|
859 |
+
# 3. Feed-forward
|
860 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
861 |
+
|
862 |
+
return hidden_states
|
863 |
+
|
864 |
+
|
865 |
+
class Transformer2DModel(nn.Module):
|
866 |
+
def __init__(
|
867 |
+
self,
|
868 |
+
num_attention_heads: int = 16,
|
869 |
+
attention_head_dim: int = 88,
|
870 |
+
in_channels: Optional[int] = None,
|
871 |
+
cross_attention_dim: Optional[int] = None,
|
872 |
+
use_linear_projection: bool = False,
|
873 |
+
upcast_attention: bool = False,
|
874 |
+
):
|
875 |
+
super().__init__()
|
876 |
+
self.in_channels = in_channels
|
877 |
+
self.num_attention_heads = num_attention_heads
|
878 |
+
self.attention_head_dim = attention_head_dim
|
879 |
+
inner_dim = num_attention_heads * attention_head_dim
|
880 |
+
self.use_linear_projection = use_linear_projection
|
881 |
+
|
882 |
+
self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True)
|
883 |
+
|
884 |
+
if use_linear_projection:
|
885 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
886 |
+
else:
|
887 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
888 |
+
|
889 |
+
self.transformer_blocks = nn.ModuleList(
|
890 |
+
[
|
891 |
+
BasicTransformerBlock(
|
892 |
+
inner_dim,
|
893 |
+
num_attention_heads,
|
894 |
+
attention_head_dim,
|
895 |
+
cross_attention_dim=cross_attention_dim,
|
896 |
+
upcast_attention=upcast_attention,
|
897 |
+
)
|
898 |
+
]
|
899 |
+
)
|
900 |
+
|
901 |
+
if use_linear_projection:
|
902 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
903 |
+
else:
|
904 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
905 |
+
|
906 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
907 |
+
for transformer in self.transformer_blocks:
|
908 |
+
transformer.set_use_memory_efficient_attention(xformers, mem_eff)
|
909 |
+
|
910 |
+
def set_use_sdpa(self, sdpa):
|
911 |
+
for transformer in self.transformer_blocks:
|
912 |
+
transformer.set_use_sdpa(sdpa)
|
913 |
+
|
914 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
915 |
+
# 1. Input
|
916 |
+
batch, _, height, weight = hidden_states.shape
|
917 |
+
residual = hidden_states
|
918 |
+
|
919 |
+
hidden_states = self.norm(hidden_states)
|
920 |
+
if not self.use_linear_projection:
|
921 |
+
hidden_states = self.proj_in(hidden_states)
|
922 |
+
inner_dim = hidden_states.shape[1]
|
923 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
924 |
+
else:
|
925 |
+
inner_dim = hidden_states.shape[1]
|
926 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
927 |
+
hidden_states = self.proj_in(hidden_states)
|
928 |
+
|
929 |
+
# 2. Blocks
|
930 |
+
for block in self.transformer_blocks:
|
931 |
+
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
|
932 |
+
|
933 |
+
# 3. Output
|
934 |
+
if not self.use_linear_projection:
|
935 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
936 |
+
hidden_states = self.proj_out(hidden_states)
|
937 |
+
else:
|
938 |
+
hidden_states = self.proj_out(hidden_states)
|
939 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
940 |
+
|
941 |
+
output = hidden_states + residual
|
942 |
+
|
943 |
+
if not return_dict:
|
944 |
+
return (output,)
|
945 |
+
|
946 |
+
return SampleOutput(sample=output)
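The forward above flattens the spatial map into a token sequence before the transformer blocks and restores it afterwards (note that the original code names the width dimension `weight`). A shape sketch with illustrative sizes:

x = torch.randn(2, 320, 32, 32)                            # (batch, inner_dim, height, width)
tokens = x.permute(0, 2, 3, 1).reshape(2, 32 * 32, 320)    # (batch, height * width, inner_dim)
restored = tokens.reshape(2, 32, 32, 320).permute(0, 3, 1, 2).contiguous()
assert restored.shape == x.shape
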
|
947 |
+
|
948 |
+
|
949 |
+
class CrossAttnDownBlock2D(nn.Module):
|
950 |
+
def __init__(
|
951 |
+
self,
|
952 |
+
in_channels: int,
|
953 |
+
out_channels: int,
|
954 |
+
add_downsample=True,
|
955 |
+
cross_attention_dim=1280,
|
956 |
+
attn_num_head_channels=1,
|
957 |
+
use_linear_projection=False,
|
958 |
+
upcast_attention=False,
|
959 |
+
):
|
960 |
+
super().__init__()
|
961 |
+
self.has_cross_attention = True
|
962 |
+
resnets = []
|
963 |
+
attentions = []
|
964 |
+
|
965 |
+
self.attn_num_head_channels = attn_num_head_channels
|
966 |
+
|
967 |
+
for i in range(LAYERS_PER_BLOCK):
|
968 |
+
in_channels = in_channels if i == 0 else out_channels
|
969 |
+
|
970 |
+
resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels))
|
971 |
+
attentions.append(
|
972 |
+
Transformer2DModel(
|
973 |
+
attn_num_head_channels,
|
974 |
+
out_channels // attn_num_head_channels,
|
975 |
+
in_channels=out_channels,
|
976 |
+
cross_attention_dim=cross_attention_dim,
|
977 |
+
use_linear_projection=use_linear_projection,
|
978 |
+
upcast_attention=upcast_attention,
|
979 |
+
)
|
980 |
+
)
|
981 |
+
self.attentions = nn.ModuleList(attentions)
|
982 |
+
self.resnets = nn.ModuleList(resnets)
|
983 |
+
|
984 |
+
if add_downsample:
|
985 |
+
self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
|
986 |
+
else:
|
987 |
+
self.downsamplers = None
|
988 |
+
|
989 |
+
self.gradient_checkpointing = False
|
990 |
+
|
991 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
992 |
+
for attn in self.attentions:
|
993 |
+
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
994 |
+
|
995 |
+
def set_use_sdpa(self, sdpa):
|
996 |
+
for attn in self.attentions:
|
997 |
+
attn.set_use_sdpa(sdpa)
|
998 |
+
|
999 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
1000 |
+
output_states = ()
|
1001 |
+
|
1002 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1003 |
+
if self.training and self.gradient_checkpointing:
|
1004 |
+
|
1005 |
+
def create_custom_forward(module, return_dict=None):
|
1006 |
+
def custom_forward(*inputs):
|
1007 |
+
if return_dict is not None:
|
1008 |
+
return module(*inputs, return_dict=return_dict)
|
1009 |
+
else:
|
1010 |
+
return module(*inputs)
|
1011 |
+
|
1012 |
+
return custom_forward
|
1013 |
+
|
1014 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1015 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1016 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
1017 |
+
)[0]
|
1018 |
+
else:
|
1019 |
+
hidden_states = resnet(hidden_states, temb)
|
1020 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
1021 |
+
|
1022 |
+
output_states += (hidden_states,)
|
1023 |
+
|
1024 |
+
if self.downsamplers is not None:
|
1025 |
+
for downsampler in self.downsamplers:
|
1026 |
+
hidden_states = downsampler(hidden_states)
|
1027 |
+
|
1028 |
+
output_states += (hidden_states,)
|
1029 |
+
|
1030 |
+
return hidden_states, output_states
|
1031 |
+
|
1032 |
+
|
1033 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
1034 |
+
def __init__(
|
1035 |
+
self,
|
1036 |
+
in_channels: int,
|
1037 |
+
attn_num_head_channels=1,
|
1038 |
+
cross_attention_dim=1280,
|
1039 |
+
use_linear_projection=False,
|
1040 |
+
):
|
1041 |
+
super().__init__()
|
1042 |
+
|
1043 |
+
self.has_cross_attention = True
|
1044 |
+
self.attn_num_head_channels = attn_num_head_channels
|
1045 |
+
|
1046 |
+
# Middle block has two resnets and one attention
|
1047 |
+
resnets = [
|
1048 |
+
ResnetBlock2D(
|
1049 |
+
in_channels=in_channels,
|
1050 |
+
out_channels=in_channels,
|
1051 |
+
),
|
1052 |
+
ResnetBlock2D(
|
1053 |
+
in_channels=in_channels,
|
1054 |
+
out_channels=in_channels,
|
1055 |
+
),
|
1056 |
+
]
|
1057 |
+
attentions = [
|
1058 |
+
Transformer2DModel(
|
1059 |
+
attn_num_head_channels,
|
1060 |
+
in_channels // attn_num_head_channels,
|
1061 |
+
in_channels=in_channels,
|
1062 |
+
cross_attention_dim=cross_attention_dim,
|
1063 |
+
use_linear_projection=use_linear_projection,
|
1064 |
+
)
|
1065 |
+
]
|
1066 |
+
|
1067 |
+
self.attentions = nn.ModuleList(attentions)
|
1068 |
+
self.resnets = nn.ModuleList(resnets)
|
1069 |
+
|
1070 |
+
self.gradient_checkpointing = False
|
1071 |
+
|
1072 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
1073 |
+
for attn in self.attentions:
|
1074 |
+
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
1075 |
+
|
1076 |
+
def set_use_sdpa(self, sdpa):
|
1077 |
+
for attn in self.attentions:
|
1078 |
+
attn.set_use_sdpa(sdpa)
|
1079 |
+
|
1080 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
1081 |
+
for i, resnet in enumerate(self.resnets):
|
1082 |
+
attn = None if i == 0 else self.attentions[i - 1]
|
1083 |
+
|
1084 |
+
if self.training and self.gradient_checkpointing:
|
1085 |
+
|
1086 |
+
def create_custom_forward(module, return_dict=None):
|
1087 |
+
def custom_forward(*inputs):
|
1088 |
+
if return_dict is not None:
|
1089 |
+
return module(*inputs, return_dict=return_dict)
|
1090 |
+
else:
|
1091 |
+
return module(*inputs)
|
1092 |
+
|
1093 |
+
return custom_forward
|
1094 |
+
|
1095 |
+
if attn is not None:
|
1096 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1097 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
1098 |
+
)[0]
|
1099 |
+
|
1100 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1101 |
+
else:
|
1102 |
+
if attn is not None:
|
1103 |
+
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
1104 |
+
hidden_states = resnet(hidden_states, temb)
|
1105 |
+
|
1106 |
+
return hidden_states
|
1107 |
+
|
1108 |
+
|
1109 |
+
class Upsample2D(nn.Module):
    def __init__(self, channels, out_channels):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, hidden_states, output_size):
        assert hidden_states.shape[1] == self.channels

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input was bfloat16, cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        hidden_states = self.conv(hidden_states)

        return hidden_states


class UpBlock2D(nn.Module):
|
1146 |
+
def __init__(
|
1147 |
+
self,
|
1148 |
+
in_channels: int,
|
1149 |
+
prev_output_channel: int,
|
1150 |
+
out_channels: int,
|
1151 |
+
add_upsample=True,
|
1152 |
+
):
|
1153 |
+
super().__init__()
|
1154 |
+
|
1155 |
+
self.has_cross_attention = False
|
1156 |
+
resnets = []
|
1157 |
+
|
1158 |
+
for i in range(LAYERS_PER_BLOCK_UP):
|
1159 |
+
res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
|
1160 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1161 |
+
|
1162 |
+
resnets.append(
|
1163 |
+
ResnetBlock2D(
|
1164 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1165 |
+
out_channels=out_channels,
|
1166 |
+
)
|
1167 |
+
)
|
1168 |
+
|
1169 |
+
self.resnets = nn.ModuleList(resnets)
|
1170 |
+
|
1171 |
+
if add_upsample:
|
1172 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
|
1173 |
+
else:
|
1174 |
+
self.upsamplers = None
|
1175 |
+
|
1176 |
+
self.gradient_checkpointing = False
|
1177 |
+
|
1178 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
1179 |
+
pass
|
1180 |
+
|
1181 |
+
def set_use_sdpa(self, sdpa):
|
1182 |
+
pass
|
1183 |
+
|
1184 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
1185 |
+
for resnet in self.resnets:
|
1186 |
+
# pop res hidden states
|
1187 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1188 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1189 |
+
|
1190 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1191 |
+
|
1192 |
+
if self.training and self.gradient_checkpointing:
|
1193 |
+
|
1194 |
+
def create_custom_forward(module):
|
1195 |
+
def custom_forward(*inputs):
|
1196 |
+
return module(*inputs)
|
1197 |
+
|
1198 |
+
return custom_forward
|
1199 |
+
|
1200 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1201 |
+
else:
|
1202 |
+
hidden_states = resnet(hidden_states, temb)
|
1203 |
+
|
1204 |
+
if self.upsamplers is not None:
|
1205 |
+
for upsampler in self.upsamplers:
|
1206 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1207 |
+
|
1208 |
+
return hidden_states
|
1209 |
+
|
1210 |
+
|
1211 |
+
class CrossAttnUpBlock2D(nn.Module):
|
1212 |
+
def __init__(
|
1213 |
+
self,
|
1214 |
+
in_channels: int,
|
1215 |
+
out_channels: int,
|
1216 |
+
prev_output_channel: int,
|
1217 |
+
attn_num_head_channels=1,
|
1218 |
+
cross_attention_dim=1280,
|
1219 |
+
add_upsample=True,
|
1220 |
+
use_linear_projection=False,
|
1221 |
+
upcast_attention=False,
|
1222 |
+
):
|
1223 |
+
super().__init__()
|
1224 |
+
resnets = []
|
1225 |
+
attentions = []
|
1226 |
+
|
1227 |
+
self.has_cross_attention = True
|
1228 |
+
self.attn_num_head_channels = attn_num_head_channels
|
1229 |
+
|
1230 |
+
for i in range(LAYERS_PER_BLOCK_UP):
|
1231 |
+
res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
|
1232 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1233 |
+
|
1234 |
+
resnets.append(
|
1235 |
+
ResnetBlock2D(
|
1236 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1237 |
+
out_channels=out_channels,
|
1238 |
+
)
|
1239 |
+
)
|
1240 |
+
attentions.append(
|
1241 |
+
Transformer2DModel(
|
1242 |
+
attn_num_head_channels,
|
1243 |
+
out_channels // attn_num_head_channels,
|
1244 |
+
in_channels=out_channels,
|
1245 |
+
cross_attention_dim=cross_attention_dim,
|
1246 |
+
use_linear_projection=use_linear_projection,
|
1247 |
+
upcast_attention=upcast_attention,
|
1248 |
+
)
|
1249 |
+
)
|
1250 |
+
|
1251 |
+
self.attentions = nn.ModuleList(attentions)
|
1252 |
+
self.resnets = nn.ModuleList(resnets)
|
1253 |
+
|
1254 |
+
if add_upsample:
|
1255 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
|
1256 |
+
else:
|
1257 |
+
self.upsamplers = None
|
1258 |
+
|
1259 |
+
self.gradient_checkpointing = False
|
1260 |
+
|
1261 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
1262 |
+
for attn in self.attentions:
|
1263 |
+
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
1264 |
+
|
1265 |
+
def set_use_sdpa(self, sdpa):
|
1266 |
+
for attn in self.attentions:
|
1267 |
+
attn.set_use_sdpa(sdpa)
|
1268 |
+
|
1269 |
+
def forward(
|
1270 |
+
self,
|
1271 |
+
hidden_states,
|
1272 |
+
res_hidden_states_tuple,
|
1273 |
+
temb=None,
|
1274 |
+
encoder_hidden_states=None,
|
1275 |
+
upsample_size=None,
|
1276 |
+
):
|
1277 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1278 |
+
# pop res hidden states
|
1279 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1280 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1281 |
+
|
1282 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1283 |
+
|
1284 |
+
if self.training and self.gradient_checkpointing:
|
1285 |
+
|
1286 |
+
def create_custom_forward(module, return_dict=None):
|
1287 |
+
def custom_forward(*inputs):
|
1288 |
+
if return_dict is not None:
|
1289 |
+
return module(*inputs, return_dict=return_dict)
|
1290 |
+
else:
|
1291 |
+
return module(*inputs)
|
1292 |
+
|
1293 |
+
return custom_forward
|
1294 |
+
|
1295 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1296 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1297 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
1298 |
+
)[0]
|
1299 |
+
else:
|
1300 |
+
hidden_states = resnet(hidden_states, temb)
|
1301 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
1302 |
+
|
1303 |
+
if self.upsamplers is not None:
|
1304 |
+
for upsampler in self.upsamplers:
|
1305 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1306 |
+
|
1307 |
+
return hidden_states
|
1308 |
+
|
1309 |
+
|
1310 |
+
def get_down_block(
|
1311 |
+
down_block_type,
|
1312 |
+
in_channels,
|
1313 |
+
out_channels,
|
1314 |
+
add_downsample,
|
1315 |
+
attn_num_head_channels,
|
1316 |
+
cross_attention_dim,
|
1317 |
+
use_linear_projection,
|
1318 |
+
upcast_attention,
|
1319 |
+
):
|
1320 |
+
if down_block_type == "DownBlock2D":
|
1321 |
+
return DownBlock2D(
|
1322 |
+
in_channels=in_channels,
|
1323 |
+
out_channels=out_channels,
|
1324 |
+
add_downsample=add_downsample,
|
1325 |
+
)
|
1326 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
1327 |
+
return CrossAttnDownBlock2D(
|
1328 |
+
in_channels=in_channels,
|
1329 |
+
out_channels=out_channels,
|
1330 |
+
add_downsample=add_downsample,
|
1331 |
+
cross_attention_dim=cross_attention_dim,
|
1332 |
+
attn_num_head_channels=attn_num_head_channels,
|
1333 |
+
use_linear_projection=use_linear_projection,
|
1334 |
+
upcast_attention=upcast_attention,
|
1335 |
+
)
|
1336 |
+
|
1337 |
+
|
1338 |
+
def get_up_block(
|
1339 |
+
up_block_type,
|
1340 |
+
in_channels,
|
1341 |
+
out_channels,
|
1342 |
+
prev_output_channel,
|
1343 |
+
add_upsample,
|
1344 |
+
attn_num_head_channels,
|
1345 |
+
cross_attention_dim=None,
|
1346 |
+
use_linear_projection=False,
|
1347 |
+
upcast_attention=False,
|
1348 |
+
):
|
1349 |
+
if up_block_type == "UpBlock2D":
|
1350 |
+
return UpBlock2D(
|
1351 |
+
in_channels=in_channels,
|
1352 |
+
prev_output_channel=prev_output_channel,
|
1353 |
+
out_channels=out_channels,
|
1354 |
+
add_upsample=add_upsample,
|
1355 |
+
)
|
1356 |
+
elif up_block_type == "CrossAttnUpBlock2D":
|
1357 |
+
return CrossAttnUpBlock2D(
|
1358 |
+
in_channels=in_channels,
|
1359 |
+
out_channels=out_channels,
|
1360 |
+
prev_output_channel=prev_output_channel,
|
1361 |
+
attn_num_head_channels=attn_num_head_channels,
|
1362 |
+
cross_attention_dim=cross_attention_dim,
|
1363 |
+
add_upsample=add_upsample,
|
1364 |
+
use_linear_projection=use_linear_projection,
|
1365 |
+
upcast_attention=upcast_attention,
|
1366 |
+
)
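get_down_block and get_up_block simply dispatch on the block-type string; for example (a sketch only, the channel and head values here are illustrative rather than the module's configured constants):

block = get_down_block(
    "CrossAttnDownBlock2D",
    in_channels=320,
    out_channels=640,
    add_downsample=True,
    attn_num_head_channels=8,
    cross_attention_dim=768,
    use_linear_projection=False,
    upcast_attention=False,
)
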
|
1367 |
+
|
1368 |
+
|
1369 |
+
class UNet2DConditionModel(nn.Module):
|
1370 |
+
_supports_gradient_checkpointing = True
|
1371 |
+
|
1372 |
+
def __init__(
|
1373 |
+
self,
|
1374 |
+
sample_size: Optional[int] = None,
|
1375 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
1376 |
+
cross_attention_dim: int = 1280,
|
1377 |
+
use_linear_projection: bool = False,
|
1378 |
+
upcast_attention: bool = False,
|
1379 |
+
**kwargs,
|
1380 |
+
):
|
1381 |
+
super().__init__()
|
1382 |
+
assert sample_size is not None, "sample_size must be specified"
|
1383 |
+
print(
|
1384 |
+
f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
|
1385 |
+
)
|
1386 |
+
|
1387 |
+
# defined here so that they can be referenced from outside
|
1388 |
+
self.in_channels = IN_CHANNELS
|
1389 |
+
self.out_channels = OUT_CHANNELS
|
1390 |
+
|
1391 |
+
self.sample_size = sample_size
|
1392 |
+
self.prepare_config(sample_size=sample_size)
|
1393 |
+
|
1394 |
+
# the way modules are held cannot be changed, because doing so would change the state_dict format
|
1395 |
+
|
1396 |
+
# input
|
1397 |
+
self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1))
|
1398 |
+
|
1399 |
+
# time
|
1400 |
+
self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT)
|
1401 |
+
|
1402 |
+
self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM)
|
1403 |
+
|
1404 |
+
self.down_blocks = nn.ModuleList([])
|
1405 |
+
self.mid_block = None
|
1406 |
+
self.up_blocks = nn.ModuleList([])
|
1407 |
+
|
1408 |
+
if isinstance(attention_head_dim, int):
|
1409 |
+
attention_head_dim = (attention_head_dim,) * 4
|
1410 |
+
|
1411 |
+
# down
|
1412 |
+
output_channel = BLOCK_OUT_CHANNELS[0]
|
1413 |
+
for i, down_block_type in enumerate(DOWN_BLOCK_TYPES):
|
1414 |
+
input_channel = output_channel
|
1415 |
+
output_channel = BLOCK_OUT_CHANNELS[i]
|
1416 |
+
is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
|
1417 |
+
|
1418 |
+
down_block = get_down_block(
|
1419 |
+
down_block_type,
|
1420 |
+
in_channels=input_channel,
|
1421 |
+
out_channels=output_channel,
|
1422 |
+
add_downsample=not is_final_block,
|
1423 |
+
attn_num_head_channels=attention_head_dim[i],
|
1424 |
+
cross_attention_dim=cross_attention_dim,
|
1425 |
+
use_linear_projection=use_linear_projection,
|
1426 |
+
upcast_attention=upcast_attention,
|
1427 |
+
)
|
1428 |
+
self.down_blocks.append(down_block)
|
1429 |
+
|
1430 |
+
# mid
|
1431 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
1432 |
+
in_channels=BLOCK_OUT_CHANNELS[-1],
|
1433 |
+
attn_num_head_channels=attention_head_dim[-1],
|
1434 |
+
cross_attention_dim=cross_attention_dim,
|
1435 |
+
use_linear_projection=use_linear_projection,
|
1436 |
+
)
|
1437 |
+
|
1438 |
+
# count how many layers upsample the images
|
1439 |
+
self.num_upsamplers = 0
|
1440 |
+
|
1441 |
+
# up
|
1442 |
+
reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS))
|
1443 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
1444 |
+
output_channel = reversed_block_out_channels[0]
|
1445 |
+
for i, up_block_type in enumerate(UP_BLOCK_TYPES):
|
1446 |
+
is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
|
1447 |
+
|
1448 |
+
prev_output_channel = output_channel
|
1449 |
+
output_channel = reversed_block_out_channels[i]
|
1450 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)]
|
1451 |
+
|
1452 |
+
# add upsample block for all BUT final layer
|
1453 |
+
if not is_final_block:
|
1454 |
+
add_upsample = True
|
1455 |
+
self.num_upsamplers += 1
|
1456 |
+
else:
|
1457 |
+
add_upsample = False
|
1458 |
+
|
1459 |
+
up_block = get_up_block(
|
1460 |
+
up_block_type,
|
1461 |
+
in_channels=input_channel,
|
1462 |
+
out_channels=output_channel,
|
1463 |
+
prev_output_channel=prev_output_channel,
|
1464 |
+
add_upsample=add_upsample,
|
1465 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
1466 |
+
cross_attention_dim=cross_attention_dim,
|
1467 |
+
use_linear_projection=use_linear_projection,
|
1468 |
+
upcast_attention=upcast_attention,
|
1469 |
+
)
|
1470 |
+
self.up_blocks.append(up_block)
|
1471 |
+
prev_output_channel = output_channel
|
1472 |
+
|
1473 |
+
# out
|
1474 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS)
|
1475 |
+
self.conv_act = nn.SiLU()
|
1476 |
+
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
1477 |
+
|
1478 |
+
# region diffusers compatibility
|
1479 |
+
def prepare_config(self, *args, **kwargs):
|
1480 |
+
self.config = SimpleNamespace(**kwargs)
|
1481 |
+
|
1482 |
+
@property
|
1483 |
+
def dtype(self) -> torch.dtype:
|
1484 |
+
# `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
1485 |
+
return get_parameter_dtype(self)
|
1486 |
+
|
1487 |
+
@property
|
1488 |
+
def device(self) -> torch.device:
|
1489 |
+
# `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
|
1490 |
+
return get_parameter_device(self)
|
1491 |
+
|
1492 |
+
def set_attention_slice(self, slice_size):
|
1493 |
+
raise NotImplementedError("Attention slicing is not supported for this model.")
|
1494 |
+
|
1495 |
+
def is_gradient_checkpointing(self) -> bool:
|
1496 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
1497 |
+
|
1498 |
+
def enable_gradient_checkpointing(self):
|
1499 |
+
self.set_gradient_checkpointing(value=True)
|
1500 |
+
|
1501 |
+
def disable_gradient_checkpointing(self):
|
1502 |
+
self.set_gradient_checkpointing(value=False)
|
1503 |
+
|
1504 |
+
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
|
1505 |
+
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
1506 |
+
for module in modules:
|
1507 |
+
module.set_use_memory_efficient_attention(xformers, mem_eff)
|
1508 |
+
|
1509 |
+
def set_use_sdpa(self, sdpa: bool) -> None:
|
1510 |
+
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
1511 |
+
for module in modules:
|
1512 |
+
module.set_use_sdpa(sdpa)
|
1513 |
+
|
1514 |
+
def set_gradient_checkpointing(self, value=False):
|
1515 |
+
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
1516 |
+
for module in modules:
|
1517 |
+
print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
|
1518 |
+
module.gradient_checkpointing = value
|
1519 |
+
|
1520 |
+
# endregion
|
1521 |
+
|
1522 |
+
def forward(
|
1523 |
+
self,
|
1524 |
+
sample: torch.FloatTensor,
|
1525 |
+
timestep: Union[torch.Tensor, float, int],
|
1526 |
+
encoder_hidden_states: torch.Tensor,
|
1527 |
+
class_labels: Optional[torch.Tensor] = None,
|
1528 |
+
return_dict: bool = True,
|
1529 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1530 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
1531 |
+
) -> Union[Dict, Tuple]:
|
1532 |
+
r"""
|
1533 |
+
Args:
|
1534 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
1535 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
1536 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
1537 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1538 |
+
Whether or not to return a dict instead of a plain tuple.
|
1539 |
+
|
1540 |
+
Returns:
|
1541 |
+
`SampleOutput` or `tuple`:
|
1542 |
+
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
1543 |
+
"""
|
1544 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
1545 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
1546 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
1547 |
+
# on the fly if necessary.
|
1548 |
+
# By default the sample size must be a multiple of 2**(number of upsamplers), i.e. 64.
# To support other sizes as well, the upsample size is adjusted when necessary.
# Image quality probably suffers, so keeping the size divisible by 64 is recommended.
|
1551 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
1552 |
+
|
1553 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
1554 |
+
# when the size is not divisible by 64, pass the size to the upsampler
|
1555 |
+
forward_upsample_size = False
|
1556 |
+
upsample_size = None
|
1557 |
+
|
1558 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
1559 |
+
# logger.info("Forward upsample size to force interpolation output size.")
|
1560 |
+
forward_upsample_size = True
|
1561 |
+
|
1562 |
+
# 1. time
|
1563 |
+
timesteps = timestep
|
1564 |
+
timesteps = self.handle_unusual_timesteps(sample, timesteps)  # only handles unusual inputs
|
1565 |
+
|
1566 |
+
t_emb = self.time_proj(timesteps)
|
1567 |
+
|
1568 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
1569 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
1570 |
+
# there might be better ways to encapsulate this.
|
1571 |
+
# timesteps contains no weights and always comes back as a float32 tensor,
# but time_embedding may be running in fp16, so it has to be cast here
# (casting inside time_proj might be a cleaner place for this)
|
1574 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
1575 |
+
emb = self.time_embedding(t_emb)
|
1576 |
+
|
1577 |
+
# 2. pre-process
|
1578 |
+
sample = self.conv_in(sample)
|
1579 |
+
|
1580 |
+
down_block_res_samples = (sample,)
|
1581 |
+
for downsample_block in self.down_blocks:
|
1582 |
+
# the down blocks could be made to always take encoder_hidden_states in forward,
# but this way is probably easier to follow
|
1584 |
+
if downsample_block.has_cross_attention:
|
1585 |
+
sample, res_samples = downsample_block(
|
1586 |
+
hidden_states=sample,
|
1587 |
+
temb=emb,
|
1588 |
+
encoder_hidden_states=encoder_hidden_states,
|
1589 |
+
)
|
1590 |
+
else:
|
1591 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
1592 |
+
|
1593 |
+
down_block_res_samples += res_samples
|
1594 |
+
|
1595 |
+
# add the ControlNet outputs to the skip connections
|
1596 |
+
if down_block_additional_residuals is not None:
|
1597 |
+
down_block_res_samples = list(down_block_res_samples)
|
1598 |
+
for i in range(len(down_block_res_samples)):
|
1599 |
+
down_block_res_samples[i] += down_block_additional_residuals[i]
|
1600 |
+
down_block_res_samples = tuple(down_block_res_samples)
|
1601 |
+
|
1602 |
+
# 4. mid
|
1603 |
+
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
1604 |
+
|
1605 |
+
# add the ControlNet output
|
1606 |
+
if mid_block_additional_residual is not None:
|
1607 |
+
sample += mid_block_additional_residual
|
1608 |
+
|
1609 |
+
# 5. up
|
1610 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
1611 |
+
is_final_block = i == len(self.up_blocks) - 1
|
1612 |
+
|
1613 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1614 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
1615 |
+
|
1616 |
+
# if we have not reached the final block and need to forward the upsample size, we do it here
|
1617 |
+
# as noted above, pass upsample_size to every block except the final one
|
1618 |
+
if not is_final_block and forward_upsample_size:
|
1619 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1620 |
+
|
1621 |
+
if upsample_block.has_cross_attention:
|
1622 |
+
sample = upsample_block(
|
1623 |
+
hidden_states=sample,
|
1624 |
+
temb=emb,
|
1625 |
+
res_hidden_states_tuple=res_samples,
|
1626 |
+
encoder_hidden_states=encoder_hidden_states,
|
1627 |
+
upsample_size=upsample_size,
|
1628 |
+
)
|
1629 |
+
else:
|
1630 |
+
sample = upsample_block(
|
1631 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
1632 |
+
)
|
1633 |
+
|
1634 |
+
# 6. post-process
|
1635 |
+
sample = self.conv_norm_out(sample)
|
1636 |
+
sample = self.conv_act(sample)
|
1637 |
+
sample = self.conv_out(sample)
|
1638 |
+
|
1639 |
+
if not return_dict:
|
1640 |
+
return (sample,)
|
1641 |
+
|
1642 |
+
return SampleOutput(sample=sample)
|
1643 |
+
|
1644 |
+
def handle_unusual_timesteps(self, sample, timesteps):
|
1645 |
+
r"""
|
1646 |
+
If timesteps is not a Tensor, convert it to one. Also broadcast it to the batch size so that it is compatible with ONNX/Core ML.
|
1647 |
+
"""
|
1648 |
+
if not torch.is_tensor(timesteps):
|
1649 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
1650 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
1651 |
+
is_mps = sample.device.type == "mps"
|
1652 |
+
if isinstance(timesteps, float):
|
1653 |
+
dtype = torch.float32 if is_mps else torch.float64
|
1654 |
+
else:
|
1655 |
+
dtype = torch.int32 if is_mps else torch.int64
|
1656 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
1657 |
+
elif len(timesteps.shape) == 0:
|
1658 |
+
timesteps = timesteps[None].to(sample.device)
|
1659 |
+
|
1660 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
1661 |
+
timesteps = timesteps.expand(sample.shape[0])
|
1662 |
+
|
1663 |
+
return timesteps
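A minimal end-to-end call of this UNet, as a sketch only; it assumes SD v1-style shapes (4-channel 64x64 latents, 77x768 text-encoder states) and the module-level channel constants defined earlier in the file:

unet = UNet2DConditionModel(sample_size=64, attention_head_dim=8, cross_attention_dim=768)
latents = torch.randn(1, 4, 64, 64)
timestep = torch.tensor([999])
text_states = torch.randn(1, 77, 768)
with torch.no_grad():
    noise_pred = unet(latents, timestep, text_states).sample   # -> (1, 4, 64, 64)
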
|
1664 |
+
|
1665 |
+
|
1666 |
+
class InferUNet2DConditionModel:
|
1667 |
+
def __init__(self, original_unet: UNet2DConditionModel):
|
1668 |
+
self.delegate = original_unet
|
1669 |
+
|
1670 |
+
# override original model's forward method: because forward is not called by `__call__`
|
1671 |
+
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
|
1672 |
+
self.delegate.forward = self.forward
|
1673 |
+
|
1674 |
+
# override original model's up blocks' forward method
|
1675 |
+
for up_block in self.delegate.up_blocks:
|
1676 |
+
if up_block.__class__.__name__ == "UpBlock2D":
|
1677 |
+
|
1678 |
+
def resnet_wrapper(func, block):
|
1679 |
+
def forward(*args, **kwargs):
|
1680 |
+
return func(block, *args, **kwargs)
|
1681 |
+
|
1682 |
+
return forward
|
1683 |
+
|
1684 |
+
up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
|
1685 |
+
|
1686 |
+
elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
1687 |
+
|
1688 |
+
def cross_attn_up_wrapper(func, block):
|
1689 |
+
def forward(*args, **kwargs):
|
1690 |
+
return func(block, *args, **kwargs)
|
1691 |
+
|
1692 |
+
return forward
|
1693 |
+
|
1694 |
+
up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
|
1695 |
+
|
1696 |
+
# Deep Shrink
|
1697 |
+
self.ds_depth_1 = None
|
1698 |
+
self.ds_depth_2 = None
|
1699 |
+
self.ds_timesteps_1 = None
|
1700 |
+
self.ds_timesteps_2 = None
|
1701 |
+
self.ds_ratio = None
|
1702 |
+
|
1703 |
+
# call original model's methods
|
1704 |
+
def __getattr__(self, name):
|
1705 |
+
return getattr(self.delegate, name)
|
1706 |
+
|
1707 |
+
def __call__(self, *args, **kwargs):
|
1708 |
+
return self.delegate(*args, **kwargs)
|
1709 |
+
|
1710 |
+
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
1711 |
+
if ds_depth_1 is None:
|
1712 |
+
print("Deep Shrink is disabled.")
|
1713 |
+
self.ds_depth_1 = None
|
1714 |
+
self.ds_timesteps_1 = None
|
1715 |
+
self.ds_depth_2 = None
|
1716 |
+
self.ds_timesteps_2 = None
|
1717 |
+
self.ds_ratio = None
|
1718 |
+
else:
|
1719 |
+
print(
|
1720 |
+
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
1721 |
+
)
|
1722 |
+
self.ds_depth_1 = ds_depth_1
|
1723 |
+
self.ds_timesteps_1 = ds_timesteps_1
|
1724 |
+
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
1725 |
+
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
1726 |
+
self.ds_ratio = ds_ratio
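Deep Shrink, as wired below, downsamples the latent by ds_ratio at the chosen down-block depth while the timestep is still above the threshold, and the wrapped up blocks resize it back via resize_like. A usage sketch (the depth and timestep values are just examples):

infer_unet = InferUNet2DConditionModel(unet)
infer_unet.set_deep_shrink(ds_depth_1=3, ds_timesteps_1=650, ds_ratio=0.5)
# run the pipeline as usual; call set_deep_shrink(None) later to disable it again
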
|
1727 |
+
|
1728 |
+
def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
1729 |
+
for resnet in _self.resnets:
|
1730 |
+
# pop res hidden states
|
1731 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1732 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1733 |
+
|
1734 |
+
# Deep Shrink
|
1735 |
+
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
1736 |
+
hidden_states = resize_like(hidden_states, res_hidden_states)
|
1737 |
+
|
1738 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1739 |
+
hidden_states = resnet(hidden_states, temb)
|
1740 |
+
|
1741 |
+
if _self.upsamplers is not None:
|
1742 |
+
for upsampler in _self.upsamplers:
|
1743 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1744 |
+
|
1745 |
+
return hidden_states
|
1746 |
+
|
1747 |
+
def cross_attn_up_block_forward(
|
1748 |
+
self,
|
1749 |
+
_self,
|
1750 |
+
hidden_states,
|
1751 |
+
res_hidden_states_tuple,
|
1752 |
+
temb=None,
|
1753 |
+
encoder_hidden_states=None,
|
1754 |
+
upsample_size=None,
|
1755 |
+
):
|
1756 |
+
for resnet, attn in zip(_self.resnets, _self.attentions):
|
1757 |
+
# pop res hidden states
|
1758 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1759 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1760 |
+
|
1761 |
+
# Deep Shrink
|
1762 |
+
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
1763 |
+
hidden_states = resize_like(hidden_states, res_hidden_states)
|
1764 |
+
|
1765 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1766 |
+
hidden_states = resnet(hidden_states, temb)
|
1767 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
1768 |
+
|
1769 |
+
if _self.upsamplers is not None:
|
1770 |
+
for upsampler in _self.upsamplers:
|
1771 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1772 |
+
|
1773 |
+
return hidden_states
|
1774 |
+
|
1775 |
+
def forward(
|
1776 |
+
self,
|
1777 |
+
sample: torch.FloatTensor,
|
1778 |
+
timestep: Union[torch.Tensor, float, int],
|
1779 |
+
encoder_hidden_states: torch.Tensor,
|
1780 |
+
class_labels: Optional[torch.Tensor] = None,
|
1781 |
+
return_dict: bool = True,
|
1782 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1783 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
1784 |
+
) -> Union[Dict, Tuple]:
|
1785 |
+
r"""
|
1786 |
+
current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.
|
1787 |
+
"""
|
1788 |
+
|
1789 |
+
r"""
|
1790 |
+
Args:
|
1791 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
1792 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
1793 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
1794 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1795 |
+
Whether or not to return a dict instead of a plain tuple.
|
1796 |
+
|
1797 |
+
Returns:
|
1798 |
+
`SampleOutput` or `tuple`:
|
1799 |
+
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
1800 |
+
"""
|
1801 |
+
|
1802 |
+
_self = self.delegate
|
1803 |
+
|
1804 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
1805 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
1806 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
1807 |
+
# on the fly if necessary.
|
1808 |
+
# By default the sample size must be a multiple of 2**(number of upsamplers), i.e. 64.
# To support other sizes as well, the upsample size is adjusted when necessary.
# Image quality probably suffers, so keeping the size divisible by 64 is recommended.
|
1811 |
+
default_overall_up_factor = 2**_self.num_upsamplers
|
1812 |
+
|
1813 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
1814 |
+
# when the size is not divisible by 64, pass the size to the upsampler
|
1815 |
+
forward_upsample_size = False
|
1816 |
+
upsample_size = None
|
1817 |
+
|
1818 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
1819 |
+
# logger.info("Forward upsample size to force interpolation output size.")
|
1820 |
+
forward_upsample_size = True
|
1821 |
+
|
1822 |
+
# 1. time
|
1823 |
+
timesteps = timestep
|
1824 |
+
timesteps = _self.handle_unusual_timesteps(sample, timesteps)  # only handles unusual inputs
|
1825 |
+
|
1826 |
+
t_emb = _self.time_proj(timesteps)
|
1827 |
+
|
1828 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
1829 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
1830 |
+
# there might be better ways to encapsulate this.
|
1831 |
+
# timesteps contains no weights and always comes back as a float32 tensor,
# but time_embedding may be running in fp16, so it has to be cast here
# (casting inside time_proj might be a cleaner place for this)
|
1834 |
+
t_emb = t_emb.to(dtype=_self.dtype)
|
1835 |
+
emb = _self.time_embedding(t_emb)
|
1836 |
+
|
1837 |
+
# 2. pre-process
|
1838 |
+
sample = _self.conv_in(sample)
|
1839 |
+
|
1840 |
+
down_block_res_samples = (sample,)
|
1841 |
+
for depth, downsample_block in enumerate(_self.down_blocks):
|
1842 |
+
# Deep Shrink
|
1843 |
+
if self.ds_depth_1 is not None:
|
1844 |
+
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
1845 |
+
self.ds_depth_2 is not None
|
1846 |
+
and depth == self.ds_depth_2
|
1847 |
+
and timesteps[0] < self.ds_timesteps_1
|
1848 |
+
and timesteps[0] >= self.ds_timesteps_2
|
1849 |
+
):
|
1850 |
+
org_dtype = sample.dtype
|
1851 |
+
if org_dtype == torch.bfloat16:
|
1852 |
+
sample = sample.to(torch.float32)
|
1853 |
+
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
1854 |
+
|
1855 |
+
# the down blocks could be made to always take encoder_hidden_states in forward,
# but this way is probably easier to follow
|
1857 |
+
if downsample_block.has_cross_attention:
|
1858 |
+
sample, res_samples = downsample_block(
|
1859 |
+
hidden_states=sample,
|
1860 |
+
temb=emb,
|
1861 |
+
encoder_hidden_states=encoder_hidden_states,
|
1862 |
+
)
|
1863 |
+
else:
|
1864 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
1865 |
+
|
1866 |
+
down_block_res_samples += res_samples
|
1867 |
+
|
1868 |
+
# add the ControlNet outputs to the skip connections
|
1869 |
+
if down_block_additional_residuals is not None:
|
1870 |
+
down_block_res_samples = list(down_block_res_samples)
|
1871 |
+
for i in range(len(down_block_res_samples)):
|
1872 |
+
down_block_res_samples[i] += down_block_additional_residuals[i]
|
1873 |
+
down_block_res_samples = tuple(down_block_res_samples)
|
1874 |
+
|
1875 |
+
# 4. mid
|
1876 |
+
sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
1877 |
+
|
1878 |
+
# add the ControlNet output
|
1879 |
+
if mid_block_additional_residual is not None:
|
1880 |
+
sample += mid_block_additional_residual
|
1881 |
+
|
1882 |
+
# 5. up
|
1883 |
+
for i, upsample_block in enumerate(_self.up_blocks):
|
1884 |
+
is_final_block = i == len(_self.up_blocks) - 1
|
1885 |
+
|
1886 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1887 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
1888 |
+
|
1889 |
+
# if we have not reached the final block and need to forward the upsample size, we do it here
|
1890 |
+
# as noted above, pass upsample_size to every block except the final one
|
1891 |
+
if not is_final_block and forward_upsample_size:
|
1892 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1893 |
+
|
1894 |
+
if upsample_block.has_cross_attention:
|
1895 |
+
sample = upsample_block(
|
1896 |
+
hidden_states=sample,
|
1897 |
+
temb=emb,
|
1898 |
+
res_hidden_states_tuple=res_samples,
|
1899 |
+
encoder_hidden_states=encoder_hidden_states,
|
1900 |
+
upsample_size=upsample_size,
|
1901 |
+
)
|
1902 |
+
else:
|
1903 |
+
sample = upsample_block(
|
1904 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
1905 |
+
)
|
1906 |
+
|
1907 |
+
# 6. post-process
|
1908 |
+
sample = _self.conv_norm_out(sample)
|
1909 |
+
sample = _self.conv_act(sample)
|
1910 |
+
sample = _self.conv_out(sample)
|
1911 |
+
|
1912 |
+
if not return_dict:
|
1913 |
+
return (sample,)
|
1914 |
+
|
1915 |
+
return SampleOutput(sample=sample)
|
external/llite/library/sai_model_spec.py
ADDED
@@ -0,0 +1,305 @@
# based on https://github.com/Stability-AI/ModelSpec
import datetime
import hashlib
from io import BytesIO
import os
from typing import List, Optional, Tuple, Union
import safetensors

r"""
# Metadata Example
metadata = {
    # === Must ===
    "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
    "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
    "modelspec.implementation": "sgm",
    "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
    # === Should ===
    "modelspec.author": "Example Corp", # Your name or company name
    "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
    "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
    # === Can ===
    "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
    "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
}
"""

BASE_METADATA = {
    # === Must ===
    "modelspec.sai_model_spec": "1.0.0",  # Required version ID for the spec
    "modelspec.architecture": None,
    "modelspec.implementation": None,
    "modelspec.title": None,
    "modelspec.resolution": None,
    # === Should ===
    "modelspec.description": None,
    "modelspec.author": None,
    "modelspec.date": None,
    # === Can ===
    "modelspec.license": None,
    "modelspec.tags": None,
    "modelspec.merged_from": None,
    "modelspec.prediction_type": None,
    "modelspec.timestep_range": None,
    "modelspec.encoder_layer": None,
}

# separately, define only the keys that are actually used
MODELSPEC_TITLE = "modelspec.title"

ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"

ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"

IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_DIFFUSERS = "diffusers"

PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"


def load_bytes_in_safetensors(tensors):
    bytes = safetensors.torch.save(tensors)
    b = BytesIO(bytes)

    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")

    offset = n + 8
    b.seek(offset)

    return b.read()


def precalculate_safetensors_hashes(state_dict):
    # calculate each tensor one by one to reduce memory usage
    hash_sha256 = hashlib.sha256()
    for tensor in state_dict.values():
        single_tensor_sd = {"tensor": tensor}
        bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
        hash_sha256.update(bytes_for_tensor)

    return f"0x{hash_sha256.hexdigest()}"


def update_hash_sha256(metadata: dict, state_dict: dict):
    raise NotImplementedError


def build_metadata(
    state_dict: Optional[dict],
    v2: bool,
    v_parameterization: bool,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    timestamp: float,
    title: Optional[str] = None,
    reso: Optional[Union[int, Tuple[int, int]]] = None,
    is_stable_diffusion_ckpt: Optional[bool] = None,
    author: Optional[str] = None,
    description: Optional[str] = None,
    license: Optional[str] = None,
    tags: Optional[str] = None,
    merged_from: Optional[str] = None,
    timesteps: Optional[Tuple[int, int]] = None,
    clip_skip: Optional[int] = None,
):
    # if state_dict is None, hash is not calculated

    metadata = {}
    metadata.update(BASE_METADATA)

    # TODO implement this once a correct hash can be computed without consuming memory
    # if state_dict is not None:
    # hash = precalculate_safetensors_hashes(state_dict)
    # metadata["modelspec.hash_sha256"] = hash

    if sdxl:
        arch = ARCH_SD_XL_V1_BASE
    elif v2:
        if v_parameterization:
            arch = ARCH_SD_V2_768_V
        else:
            arch = ARCH_SD_V2_512
    else:
        arch = ARCH_SD_V1

    if lora:
        arch += f"/{ADAPTER_LORA}"
    elif textual_inversion:
        arch += f"/{ADAPTER_TEXTUAL_INVERSION}"

    metadata["modelspec.architecture"] = arch

    if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
        is_stable_diffusion_ckpt = True  # default is stable diffusion ckpt if not lora and not textual_inversion

    if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
        # Stable Diffusion ckpt, TI, SDXL LoRA
        impl = IMPL_STABILITY_AI
    else:
        # v1/v2 LoRA or Diffusers
        impl = IMPL_DIFFUSERS
    metadata["modelspec.implementation"] = impl

    if title is None:
        if lora:
            title = "LoRA"
        elif textual_inversion:
            title = "TextualInversion"
        else:
            title = "Checkpoint"
        title += f"@{timestamp}"
    metadata[MODELSPEC_TITLE] = title

    if author is not None:
        metadata["modelspec.author"] = author
    else:
        del metadata["modelspec.author"]

    if description is not None:
        metadata["modelspec.description"] = description
    else:
        del metadata["modelspec.description"]

    if merged_from is not None:
        metadata["modelspec.merged_from"] = merged_from
    else:
        del metadata["modelspec.merged_from"]

    if license is not None:
        metadata["modelspec.license"] = license
    else:
        del metadata["modelspec.license"]

    if tags is not None:
        metadata["modelspec.tags"] = tags
    else:
        del metadata["modelspec.tags"]

    # remove microsecond from time
    int_ts = int(timestamp)

    # time to iso-8601 compliant date
    date = datetime.datetime.fromtimestamp(int_ts).isoformat()
    metadata["modelspec.date"] = date

    if reso is not None:
        # comma separated to tuple
        if isinstance(reso, str):
            reso = tuple(map(int, reso.split(",")))
            if len(reso) == 1:
                reso = (reso[0], reso[0])
    else:
        # resolution is defined in dataset, so use default
        if sdxl:
            reso = 1024
        elif v2 and v_parameterization:
            reso = 768
        else:
            reso = 512
    if isinstance(reso, int):
        reso = (reso, reso)

    metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"

    if v_parameterization:
        metadata["modelspec.prediction_type"] = PRED_TYPE_V
    else:
        metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON

    if timesteps is not None:
        if isinstance(timesteps, str) or isinstance(timesteps, int):
            timesteps = (timesteps, timesteps)
        if len(timesteps) == 1:
            timesteps = (timesteps[0], timesteps[0])
        metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
    else:
        del metadata["modelspec.timestep_range"]

    if clip_skip is not None:
        metadata["modelspec.encoder_layer"] = f"{clip_skip}"
    else:
        del metadata["modelspec.encoder_layer"]

    # # assert all values are filled
    # assert all([v is not None for v in metadata.values()]), metadata
    if not all([v is not None for v in metadata.values()]):
        print(f"Internal error: some metadata values are None: {metadata}")

    return metadata


# region utils


def get_title(metadata: dict) -> Optional[str]:
    return metadata.get(MODELSPEC_TITLE, None)


def load_metadata_from_safetensors(model: str) -> dict:
    if not model.endswith(".safetensors"):
        return {}

    with safetensors.safe_open(model, framework="pt") as f:
        metadata = f.metadata()
    if metadata is None:
        metadata = {}
    return metadata


def build_merged_from(models: List[str]) -> str:
    def get_title(model: str):
        metadata = load_metadata_from_safetensors(model)
        title = metadata.get(MODELSPEC_TITLE, None)
        if title is None:
            title = os.path.splitext(os.path.basename(model))[0]  # use filename
        return title

    titles = [get_title(model) for model in models]
    return ", ".join(titles)


# endregion


r"""
if __name__ == "__main__":
    import argparse
    import torch
    from safetensors.torch import load_file
    from library import train_util

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True)
    args = parser.parse_args()

    print(f"Loading {args.ckpt}")
    state_dict = load_file(args.ckpt)

    print(f"Calculating metadata")
    metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
    print(metadata)
    del state_dict

    # by reference implementation
    with open(args.ckpt, mode="rb") as file_data:
        file_hash = hashlib.sha256()
        head_len = struct.unpack("Q", file_data.read(8))  # int64 header length prefix
        header = json.loads(file_data.read(head_len[0]))  # header itself, json string
        content = (
            file_data.read()
        )  # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
        file_hash.update(content)
        # ===== Update the hash for modelspec =====
        by_ref = f"0x{file_hash.hexdigest()}"
        print(by_ref)
        print("is same?", by_ref == metadata["modelspec.hash_sha256"])

"""
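For context, a minimal sketch of how `build_metadata` above might be used when saving an SDXL LoRA with safetensors. The tensor name, file name, and title are hypothetical; the call relies on the defaulting logic shown in the file (resolution falls back to 1024 for SDXL, prediction type to epsilon, architecture gets the `/lora` suffix):

```python
import time
import torch
from safetensors.torch import save_file

from external.llite.library import sai_model_spec  # assumes this repo is on the import path

# toy state dict standing in for real SDXL LoRA weights (hypothetical key and shape)
state_dict = {"lora_unet_dummy.lora_down.weight": torch.zeros(4, 320)}

metadata = sai_model_spec.build_metadata(
    None,  # state_dict: the hash calculation is currently skipped either way
    v2=False,
    v_parameterization=False,
    sdxl=True,
    lora=True,
    textual_inversion=False,
    timestamp=time.time(),
    title="Example SDXL LoRA",
    reso=1024,
)
# metadata["modelspec.architecture"] is "stable-diffusion-xl-v1-base/lora"
save_file(state_dict, "example_lora.safetensors", metadata=metadata)
```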
external/llite/library/sdxl_lpw_stable_diffusion.py
ADDED
@@ -0,0 +1,1342 @@
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
# and modify to support SD2.x

import inspect
import re
from typing import Callable, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from packaging import version
from tqdm import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import logging
from PIL import Image

from external.llite.library import sdxl_model_util, sdxl_train_util, train_util


try:
    from diffusers.utils import PIL_INTERPOLATION
except ImportError:
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        PIL_INTERPOLATION = {
            "linear": PIL.Image.Resampling.BILINEAR,
            "bilinear": PIL.Image.Resampling.BILINEAR,
            "bicubic": PIL.Image.Resampling.BICUBIC,
            "lanczos": PIL.Image.Resampling.LANCZOS,
            "nearest": PIL.Image.Resampling.NEAREST,
        }
    else:
        PIL_INTERPOLATION = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
            "nearest": PIL.Image.NEAREST,
        }
# ------------------------------------------------------------------------------

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return its tokens with weights of each token.

    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = pipe.tokenizer(word).input_ids[1:-1]
            text_token += token
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device):
    if not is_sdxl_text_encoder2:
        # text_encoder1: same as SD1/2
        enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
        hidden_states = enc_out["hidden_states"][11]
        pool = None
    else:
        # text_encoder2
        enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
        hidden_states = enc_out["hidden_states"][-2]  # penultimate layer
        # pool = enc_out["text_embeds"]
        pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id)
    hidden_states = hidden_states.to(device)
    if pool is not None:
        pool = pool.to(device)
    return hidden_states, pool


def get_unweighted_text_embeddings(
    pipe: StableDiffusionPipeline,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    is_sdxl_text_encoder2: bool,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    text_pool = None
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            if pad == eos:  # v1
                text_input_chunk[:, -1] = text_input[0, -1]
            else:  # v2
                for j in range(len(text_input_chunk)):
                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the last token is a regular token
                        text_input_chunk[j, -1] = eos
                    if text_input_chunk[j, 1] == pad:  # only BOS, the rest is PAD
                        text_input_chunk[j, 1] = eos

            text_embedding, current_text_pool = get_hidden_states(
                pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device
            )
            if text_pool is None:
                text_pool = current_text_pool

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device)
    return text_embeddings, text_pool


def get_weighted_text_embeddings(
    pipe,  # : SdxlStableDiffusionLongPromptWeightingPipeline,
    prompt: Union[str, List[str]],
    uncond_prompt: Optional[Union[str, List[str]]] = None,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
    clip_skip=None,
    is_sdxl_text_encoder2=False,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

    Also, to regularize the embedding, the weighted embedding is rescaled to preserve the original mean.

    Args:
        pipe (`StableDiffusionPipeline`):
            Pipe to provide access to the tokenizer and the text encoder.
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        uncond_prompt (`str` or `List[str]`):
            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
            is provided, the embeddings of prompt and uncond_prompt are concatenated.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep
            the starting and ending tokens in each of the chunks in the middle.
        skip_parsing (`bool`, *optional*, defaults to `False`):
            Skip the parsing of brackets.
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When the parsing is skipped, it is forced True.
    """
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    if not skip_parsing:
        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
    else:
        prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens = [
                token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
            ]
            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])
    if uncond_prompt is not None:
        max_length = max(max_length, max([len(token) for token in uncond_tokens]))

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    pad = pipe.tokenizer.pad_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
    if uncond_prompt is not None:
        uncond_tokens, uncond_weights = pad_tokens_and_weights(
            uncond_tokens,
            uncond_weights,
            max_length,
            bos,
            eos,
            pad,
            no_boseos_middle=no_boseos_middle,
            chunk_length=pipe.tokenizer.model_max_length,
        )
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)

    # get the embeddings
    text_embeddings, text_pool = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        pipe.tokenizer.model_max_length,
        clip_skip,
        eos,
        pad,
        is_sdxl_text_encoder2,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)

    if uncond_prompt is not None:
        uncond_embeddings, uncond_pool = get_unweighted_text_embeddings(
            pipe,
            uncond_tokens,
            pipe.tokenizer.model_max_length,
            clip_skip,
            eos,
            pad,
            is_sdxl_text_encoder2,
            no_boseos_middle=no_boseos_middle,
        )
        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)

    # assign weights to the prompts and normalize in the sense of mean
    # TODO: should we normalize by chunk or in a whole (current implementation)?
    if (not skip_parsing) and (not skip_weighting):
        previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        text_embeddings *= prompt_weights.unsqueeze(-1)
        current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
        if uncond_prompt is not None:
            previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= uncond_weights.unsqueeze(-1)
            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    if uncond_prompt is not None:
        return text_embeddings, text_pool, uncond_embeddings, uncond_pool
    return text_embeddings, text_pool, None, None


def preprocess_image(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


def preprocess_mask(mask, scale_factor=8):
    mask = mask.convert("L")
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
    mask = 1 - mask  # repaint white, keep black
    mask = torch.from_numpy(mask)
    return mask


def prepare_controlnet_image(
    image: PIL.Image.Image,
    width: int,
    height: int,
    batch_size: int,
    num_images_per_prompt: int,
    device: torch.device,
    dtype: torch.dtype,
    do_classifier_free_guidance: bool = False,
    guess_mode: bool = False,
):
    if not isinstance(image, torch.Tensor):
        if isinstance(image, PIL.Image.Image):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            images = []

            for image_ in image:
                image_ = image_.convert("RGB")
                image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                image_ = np.array(image_)
                image_ = image_[None, :]
                images.append(image_)

            image = images

            image = np.concatenate(image, axis=0)
            image = np.array(image).astype(np.float32) / 255.0
            image = image.transpose(0, 3, 1, 2)
            image = torch.from_numpy(image)
        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0)

    image_batch_size = image.shape[0]

    if image_batch_size == 1:
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_images_per_prompt

    image = image.repeat_interleave(repeat_by, dim=0)

    image = image.to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        image = torch.cat([image] * 2)

    return image


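Before the pipeline class, it is worth spelling out what the weighting in get_weighted_text_embeddings does numerically: the embeddings are multiplied token-wise by the parsed weights and then rescaled so the overall mean matches the unweighted embeddings. A standalone sketch of that rescaling, with made-up tensor shapes (77 tokens and 768 dimensions are assumptions for illustration, not values taken from the file):

```python
import torch

# embeddings as a text encoder might return them: (batch, tokens, dim), illustrative shapes
emb = torch.randn(1, 77, 768)
weights = torch.ones(1, 77)
weights[0, 5:10] = 1.1  # tokens inside "(...)" get a 1.1 multiplier

previous_mean = emb.float().mean(dim=[-2, -1])
emb = emb * weights.unsqueeze(-1)  # apply per-token weights
current_mean = emb.float().mean(dim=[-2, -1])
emb = emb * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)  # restore the original mean
```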
class SdxlStableDiffusionLongPromptWeightingPipeline:
    r"""
    Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
    weighting in prompt.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: List[CLIPTextModel],
        tokenizer: List[CLIPTokenizer],
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
        # clip_skip: int,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
        clip_skip: int = 1,
    ):
        # clip skip is ignored currently
        self.tokenizer = tokenizer[0]
        self.text_encoder = text_encoder[0]
        self.unet = unet
        self.scheduler = scheduler
        self.safety_checker = safety_checker
        self.feature_extractor = feature_extractor
        self.requires_safety_checker = requires_safety_checker
        self.vae = vae
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.progress_bar = lambda x: tqdm(x, leave=False)

        self.clip_skip = clip_skip
        self.tokenizers = tokenizer
        self.text_encoders = text_encoder

        # self.__init__additional__()

    # def __init__additional__(self):
    #     if not hasattr(self, "vae_scale_factor"):
    #         setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))

    def to(self, device=None, dtype=None):
        if device is not None:
            self.device = device
            # self.vae.to(device=self.device)
        if dtype is not None:
            self.dtype = dtype

        # do not move Text Encoders to device, because Text Encoder should be on CPU

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        max_embeddings_multiples,
        is_sdxl_text_encoder2,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list(int)`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
            pipe=self,
            prompt=prompt,
            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
            max_embeddings_multiples=max_embeddings_multiples,
            clip_skip=self.clip_skip,
            is_sdxl_text_encoder2=is_sdxl_text_encoder2,
        )
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)  # ??
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
        if text_pool is not None:
            text_pool = text_pool.repeat(1, num_images_per_prompt)
            text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1)

        if do_classifier_free_guidance:
            bs_embed, seq_len, _ = uncond_embeddings.shape
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
            if uncond_pool is not None:
                uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
                uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)

            return text_embeddings, text_pool, uncond_embeddings, uncond_pool

        return text_embeddings, text_pool, None, None

    def check_inputs(self, prompt, height, width, strength, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
            )

    def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
        if is_text2img:
            return self.scheduler.timesteps.to(device), num_inference_steps
        else:
            # get the original timestep using init_timestep
            offset = self.scheduler.config.get("steps_offset", 0)
            init_timestep = int(num_inference_steps * strength) + offset
            init_timestep = min(init_timestep, num_inference_steps)

            t_start = max(num_inference_steps - init_timestep + offset, 0)
            timesteps = self.scheduler.timesteps[t_start:].to(device)
            return timesteps, num_inference_steps - t_start

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        with torch.no_grad():
            latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents

            # print("post_quant_conv dtype:", self.vae.post_quant_conv.weight.dtype)  # torch.float32
            # x = torch.nn.functional.conv2d(latents, self.vae.post_quant_conv.weight.detach(), stride=1, padding=0)
            # print("latents dtype:", latents.dtype, "x dtype:", x.dtype)  # torch.float32, torch.float16
            # self.vae.to("cpu")
            # self.vae.set_use_memory_efficient_attention_xformers(False)
            # image = self.vae.decode(latents.to("cpu")).sample

            image = self.vae.decode(latents.to(self.vae.dtype)).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
            return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
        if image is None:
            shape = (
                batch_size,
                self.unet.in_channels,
                height // self.vae_scale_factor,
                width // self.vae_scale_factor,
            )

            if latents is None:
                if device.type == "mps":
                    # randn does not work reproducibly on mps
                    latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
                else:
                    latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
            else:
                if latents.shape != shape:
                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
                latents = latents.to(device)

            # scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma
            return latents, None, None
        else:
            init_latent_dist = self.vae.encode(image).latent_dist
            init_latents = init_latent_dist.sample(generator=generator)
            init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents
            init_latents = torch.cat([init_latents] * batch_size, dim=0)
            init_latents_orig = init_latents
            shape = init_latents.shape

            # add noise to latents using the timesteps
            if device.type == "mps":
                noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
            else:
                noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
            latents = self.scheduler.add_noise(init_latents, noise, timestep)
            return latents, init_latents_orig, noise

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        strength: float = 0.8,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        controlnet=None,
        controlnet_image=None,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        is_cancelled_callback: Optional[Callable[[], bool]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
                noise will be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            controlnet (`diffusers.ControlNetModel`, *optional*):
                A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
            controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
                `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
                inference.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            is_cancelled_callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. If the function returns
                `True`, the inference will be cancelled.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            `None` if cancelled by `is_cancelled_callback`,
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        if controlnet is not None and controlnet_image is None:
            raise ValueError("controlnet_image must be provided if controlnet is not None.")

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, strength, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        # To simplify the implementation, switch the tokenizer/text encoder and call it twice
        text_embeddings_list = []
        text_pool = None
        uncond_embeddings_list = []
        uncond_pool = None
        for i in range(len(self.tokenizers)):
            self.tokenizer = self.tokenizers[i]
            self.text_encoder = self.text_encoders[i]

            text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                do_classifier_free_guidance,
                negative_prompt,
                max_embeddings_multiples,
                is_sdxl_text_encoder2=i == 1,
            )
            text_embeddings_list.append(text_embeddings)
            uncond_embeddings_list.append(uncond_embeddings)

            if tp1 is not None:
                text_pool = tp1
            if up1 is not None:
                uncond_pool = up1

        dtype = self.unet.dtype

        # 4. Preprocess image and mask
        if isinstance(image, PIL.Image.Image):
            image = preprocess_image(image)
        if image is not None:
            image = image.to(device=self.device, dtype=dtype)
|
933 |
+
if isinstance(mask_image, PIL.Image.Image):
|
934 |
+
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
|
935 |
+
if mask_image is not None:
|
936 |
+
mask = mask_image.to(device=self.device, dtype=dtype)
|
937 |
+
mask = torch.cat([mask] * batch_size * num_images_per_prompt)
|
938 |
+
else:
|
939 |
+
mask = None
|
940 |
+
|
941 |
+
# ControlNet is not working yet in SDXL, but keep the code here for future use
|
942 |
+
if controlnet_image is not None:
|
943 |
+
controlnet_image = prepare_controlnet_image(
|
944 |
+
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
|
945 |
+
)
|
946 |
+
|
947 |
+
# 5. set timesteps
|
948 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
949 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
|
950 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
951 |
+
|
952 |
+
# 6. Prepare latent variables
|
953 |
+
latents, init_latents_orig, noise = self.prepare_latents(
|
954 |
+
image,
|
955 |
+
latent_timestep,
|
956 |
+
batch_size * num_images_per_prompt,
|
957 |
+
height,
|
958 |
+
width,
|
959 |
+
dtype,
|
960 |
+
device,
|
961 |
+
generator,
|
962 |
+
latents,
|
963 |
+
)
|
964 |
+
|
965 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
966 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
967 |
+
|
968 |
+
# create size embs and concat embeddings for SDXL
|
969 |
+
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype)
|
970 |
+
crop_size = torch.zeros_like(orig_size)
|
971 |
+
target_size = orig_size
|
972 |
+
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
|
973 |
+
|
974 |
+
# make conditionings
|
975 |
+
if do_classifier_free_guidance:
|
976 |
+
text_embeddings = torch.cat(text_embeddings_list, dim=2)
|
977 |
+
uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
|
978 |
+
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
|
979 |
+
|
980 |
+
cond_vector = torch.cat([text_pool, embs], dim=1)
|
981 |
+
uncond_vector = torch.cat([uncond_pool, embs], dim=1)
|
982 |
+
vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
|
983 |
+
else:
|
984 |
+
text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
|
985 |
+
vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
|
986 |
+
|
987 |
+
# 8. Denoising loop
|
988 |
+
for i, t in enumerate(self.progress_bar(timesteps)):
|
989 |
+
# expand the latents if we are doing classifier free guidance
|
990 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
991 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
992 |
+
|
993 |
+
unet_additional_args = {}
|
994 |
+
if controlnet is not None:
|
995 |
+
down_block_res_samples, mid_block_res_sample = controlnet(
|
996 |
+
latent_model_input,
|
997 |
+
t,
|
998 |
+
encoder_hidden_states=text_embeddings,
|
999 |
+
controlnet_cond=controlnet_image,
|
1000 |
+
conditioning_scale=1.0,
|
1001 |
+
guess_mode=False,
|
1002 |
+
return_dict=False,
|
1003 |
+
)
|
1004 |
+
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
|
1005 |
+
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
|
1006 |
+
|
1007 |
+
# predict the noise residual
|
1008 |
+
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
|
1009 |
+
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
|
1010 |
+
|
1011 |
+
# perform guidance
|
1012 |
+
if do_classifier_free_guidance:
|
1013 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1014 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1015 |
+
|
1016 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1017 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
1018 |
+
|
1019 |
+
if mask is not None:
|
1020 |
+
# masking
|
1021 |
+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
1022 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
1023 |
+
|
1024 |
+
# call the callback, if provided
|
1025 |
+
if i % callback_steps == 0:
|
1026 |
+
if callback is not None:
|
1027 |
+
callback(i, t, latents)
|
1028 |
+
if is_cancelled_callback is not None and is_cancelled_callback():
|
1029 |
+
return None
|
1030 |
+
|
1031 |
+
return latents
|
1032 |
+
|
1033 |
+
def latents_to_image(self, latents):
|
1034 |
+
# 9. Post-processing
|
1035 |
+
image = self.decode_latents(latents.to(self.vae.dtype))
|
1036 |
+
image = self.numpy_to_pil(image)
|
1037 |
+
return image
|
1038 |
+
|
1039 |
+
# copy from pil_utils.py
|
1040 |
+
def numpy_to_pil(self, images: np.ndarray) -> Image.Image:
|
1041 |
+
"""
|
1042 |
+
Convert a numpy image or a batch of images to a PIL image.
|
1043 |
+
"""
|
1044 |
+
if images.ndim == 3:
|
1045 |
+
images = images[None, ...]
|
1046 |
+
images = (images * 255).round().astype("uint8")
|
1047 |
+
if images.shape[-1] == 1:
|
1048 |
+
# special case for grayscale (single channel) images
|
1049 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
1050 |
+
else:
|
1051 |
+
pil_images = [Image.fromarray(image) for image in images]
|
1052 |
+
|
1053 |
+
return pil_images
|
1054 |
+
|
1055 |
+
def text2img(
|
1056 |
+
self,
|
1057 |
+
prompt: Union[str, List[str]],
|
1058 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1059 |
+
height: int = 512,
|
1060 |
+
width: int = 512,
|
1061 |
+
num_inference_steps: int = 50,
|
1062 |
+
guidance_scale: float = 7.5,
|
1063 |
+
num_images_per_prompt: Optional[int] = 1,
|
1064 |
+
eta: float = 0.0,
|
1065 |
+
generator: Optional[torch.Generator] = None,
|
1066 |
+
latents: Optional[torch.FloatTensor] = None,
|
1067 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1068 |
+
output_type: Optional[str] = "pil",
|
1069 |
+
return_dict: bool = True,
|
1070 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1071 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1072 |
+
callback_steps: int = 1,
|
1073 |
+
):
|
1074 |
+
r"""
|
1075 |
+
Function for text-to-image generation.
|
1076 |
+
Args:
|
1077 |
+
prompt (`str` or `List[str]`):
|
1078 |
+
The prompt or prompts to guide the image generation.
|
1079 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1080 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1081 |
+
if `guidance_scale` is less than `1`).
|
1082 |
+
height (`int`, *optional*, defaults to 512):
|
1083 |
+
The height in pixels of the generated image.
|
1084 |
+
width (`int`, *optional*, defaults to 512):
|
1085 |
+
The width in pixels of the generated image.
|
1086 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1087 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1088 |
+
expense of slower inference.
|
1089 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1090 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1091 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1092 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1093 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1094 |
+
usually at the expense of lower image quality.
|
1095 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1096 |
+
The number of images to generate per prompt.
|
1097 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1098 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1099 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1100 |
+
generator (`torch.Generator`, *optional*):
|
1101 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1102 |
+
deterministic.
|
1103 |
+
latents (`torch.FloatTensor`, *optional*):
|
1104 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
1105 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
1106 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
1107 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1108 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1109 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1110 |
+
The output format of the generate image. Choose between
|
1111 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1112 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1113 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1114 |
+
plain tuple.
|
1115 |
+
callback (`Callable`, *optional*):
|
1116 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1117 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1118 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1119 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1120 |
+
`True`, the inference will be cancelled.
|
1121 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1122 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1123 |
+
called at every step.
|
1124 |
+
Returns:
|
1125 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1126 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1127 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1128 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1129 |
+
(nsfw) content, according to the `safety_checker`.
|
1130 |
+
"""
|
1131 |
+
return self.__call__(
|
1132 |
+
prompt=prompt,
|
1133 |
+
negative_prompt=negative_prompt,
|
1134 |
+
height=height,
|
1135 |
+
width=width,
|
1136 |
+
num_inference_steps=num_inference_steps,
|
1137 |
+
guidance_scale=guidance_scale,
|
1138 |
+
num_images_per_prompt=num_images_per_prompt,
|
1139 |
+
eta=eta,
|
1140 |
+
generator=generator,
|
1141 |
+
latents=latents,
|
1142 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1143 |
+
output_type=output_type,
|
1144 |
+
return_dict=return_dict,
|
1145 |
+
callback=callback,
|
1146 |
+
is_cancelled_callback=is_cancelled_callback,
|
1147 |
+
callback_steps=callback_steps,
|
1148 |
+
)
|
1149 |
+
|
1150 |
+
def img2img(
|
1151 |
+
self,
|
1152 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
1153 |
+
prompt: Union[str, List[str]],
|
1154 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1155 |
+
strength: float = 0.8,
|
1156 |
+
num_inference_steps: Optional[int] = 50,
|
1157 |
+
guidance_scale: Optional[float] = 7.5,
|
1158 |
+
num_images_per_prompt: Optional[int] = 1,
|
1159 |
+
eta: Optional[float] = 0.0,
|
1160 |
+
generator: Optional[torch.Generator] = None,
|
1161 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1162 |
+
output_type: Optional[str] = "pil",
|
1163 |
+
return_dict: bool = True,
|
1164 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1165 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1166 |
+
callback_steps: int = 1,
|
1167 |
+
):
|
1168 |
+
r"""
|
1169 |
+
Function for image-to-image generation.
|
1170 |
+
Args:
|
1171 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1172 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1173 |
+
process.
|
1174 |
+
prompt (`str` or `List[str]`):
|
1175 |
+
The prompt or prompts to guide the image generation.
|
1176 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1177 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1178 |
+
if `guidance_scale` is less than `1`).
|
1179 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1180 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
1181 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
1182 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
1183 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
1184 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
1185 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1186 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1187 |
+
expense of slower inference. This parameter will be modulated by `strength`.
|
1188 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1189 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1190 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1191 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1192 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1193 |
+
usually at the expense of lower image quality.
|
1194 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1195 |
+
The number of images to generate per prompt.
|
1196 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1197 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1198 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1199 |
+
generator (`torch.Generator`, *optional*):
|
1200 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1201 |
+
deterministic.
|
1202 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1203 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1204 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1205 |
+
The output format of the generate image. Choose between
|
1206 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1207 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1208 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1209 |
+
plain tuple.
|
1210 |
+
callback (`Callable`, *optional*):
|
1211 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1212 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1213 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1214 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1215 |
+
`True`, the inference will be cancelled.
|
1216 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1217 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1218 |
+
called at every step.
|
1219 |
+
Returns:
|
1220 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1221 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1222 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1223 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1224 |
+
(nsfw) content, according to the `safety_checker`.
|
1225 |
+
"""
|
1226 |
+
return self.__call__(
|
1227 |
+
prompt=prompt,
|
1228 |
+
negative_prompt=negative_prompt,
|
1229 |
+
image=image,
|
1230 |
+
num_inference_steps=num_inference_steps,
|
1231 |
+
guidance_scale=guidance_scale,
|
1232 |
+
strength=strength,
|
1233 |
+
num_images_per_prompt=num_images_per_prompt,
|
1234 |
+
eta=eta,
|
1235 |
+
generator=generator,
|
1236 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1237 |
+
output_type=output_type,
|
1238 |
+
return_dict=return_dict,
|
1239 |
+
callback=callback,
|
1240 |
+
is_cancelled_callback=is_cancelled_callback,
|
1241 |
+
callback_steps=callback_steps,
|
1242 |
+
)
|
1243 |
+
|
1244 |
+
def inpaint(
|
1245 |
+
self,
|
1246 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
1247 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
1248 |
+
prompt: Union[str, List[str]],
|
1249 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1250 |
+
strength: float = 0.8,
|
1251 |
+
num_inference_steps: Optional[int] = 50,
|
1252 |
+
guidance_scale: Optional[float] = 7.5,
|
1253 |
+
num_images_per_prompt: Optional[int] = 1,
|
1254 |
+
eta: Optional[float] = 0.0,
|
1255 |
+
generator: Optional[torch.Generator] = None,
|
1256 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1257 |
+
output_type: Optional[str] = "pil",
|
1258 |
+
return_dict: bool = True,
|
1259 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1260 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1261 |
+
callback_steps: int = 1,
|
1262 |
+
):
|
1263 |
+
r"""
|
1264 |
+
Function for inpaint.
|
1265 |
+
Args:
|
1266 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1267 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1268 |
+
process. This is the image whose masked region will be inpainted.
|
1269 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1270 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
1271 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
1272 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
1273 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
1274 |
+
prompt (`str` or `List[str]`):
|
1275 |
+
The prompt or prompts to guide the image generation.
|
1276 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1277 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1278 |
+
if `guidance_scale` is less than `1`).
|
1279 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1280 |
+
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
|
1281 |
+
is 1, the denoising process will be run on the masked area for the full number of iterations specified
|
1282 |
+
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
|
1283 |
+
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
|
1284 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1285 |
+
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
|
1286 |
+
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
|
1287 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1288 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1289 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1290 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1291 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1292 |
+
usually at the expense of lower image quality.
|
1293 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1294 |
+
The number of images to generate per prompt.
|
1295 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1296 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1297 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1298 |
+
generator (`torch.Generator`, *optional*):
|
1299 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1300 |
+
deterministic.
|
1301 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1302 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1303 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1304 |
+
The output format of the generate image. Choose between
|
1305 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1306 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1307 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1308 |
+
plain tuple.
|
1309 |
+
callback (`Callable`, *optional*):
|
1310 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1311 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1312 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1313 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1314 |
+
`True`, the inference will be cancelled.
|
1315 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1316 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1317 |
+
called at every step.
|
1318 |
+
Returns:
|
1319 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1320 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1321 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1322 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1323 |
+
(nsfw) content, according to the `safety_checker`.
|
1324 |
+
"""
|
1325 |
+
return self.__call__(
|
1326 |
+
prompt=prompt,
|
1327 |
+
negative_prompt=negative_prompt,
|
1328 |
+
image=image,
|
1329 |
+
mask_image=mask_image,
|
1330 |
+
num_inference_steps=num_inference_steps,
|
1331 |
+
guidance_scale=guidance_scale,
|
1332 |
+
strength=strength,
|
1333 |
+
num_images_per_prompt=num_images_per_prompt,
|
1334 |
+
eta=eta,
|
1335 |
+
generator=generator,
|
1336 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1337 |
+
output_type=output_type,
|
1338 |
+
return_dict=return_dict,
|
1339 |
+
callback=callback,
|
1340 |
+
is_cancelled_callback=is_cancelled_callback,
|
1341 |
+
callback_steps=callback_steps,
|
1342 |
+
)
|
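
Editor's sketch (not part of the uploaded file): a minimal way to drive the wrappers above. It assumes `pipe` is an already-constructed instance of the SDXL long-prompt-weighting pipeline defined earlier in this file, with its models loaded and moved to the target device. Note that `__call__` returns latents, so `latents_to_image` is used to decode them; the input file name is hypothetical.

from PIL import Image

def run_examples(pipe):
    # text-to-image: the wrapper forwards to __call__, which returns latents
    latents = pipe.text2img(
        "a watercolor painting of a lighthouse",
        negative_prompt="lowres, blurry",
        num_inference_steps=30,
        guidance_scale=7.5,
    )
    images = pipe.latents_to_image(latents)  # decode with the VAE and convert to PIL

    # image-to-image: `strength` controls how much noise is added to the init image
    init_image = Image.open("sketch.png").convert("RGB")  # hypothetical input file
    latents = pipe.img2img(image=init_image, prompt="a watercolor painting of a lighthouse", strength=0.6)
    images += pipe.latents_to_image(latents)
    return images
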
external/llite/library/sdxl_model_util.py
ADDED
@@ -0,0 +1,578 @@
import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from external.llite.library import model_util
from external.llite.library import sdxl_original_unet


VAE_SCALE_FACTOR = 0.13025
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"

# reference model for loading the Diffusers configuration
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"

DIFFUSERS_SDXL_UNET_CONFIG = {
    "act_fn": "silu",
    "addition_embed_type": "text_time",
    "addition_embed_type_num_heads": 64,
    "addition_time_embed_dim": 256,
    "attention_head_dim": [5, 10, 20],
    "block_out_channels": [320, 640, 1280],
    "center_input_sample": False,
    "class_embed_type": None,
    "class_embeddings_concat": False,
    "conv_in_kernel": 3,
    "conv_out_kernel": 3,
    "cross_attention_dim": 2048,
    "cross_attention_norm": None,
    "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
    "downsample_padding": 1,
    "dual_cross_attention": False,
    "encoder_hid_dim": None,
    "encoder_hid_dim_type": None,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 4,
    "layers_per_block": 2,
    "mid_block_only_cross_attention": None,
    "mid_block_scale_factor": 1,
    "mid_block_type": "UNetMidBlock2DCrossAttn",
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "num_attention_heads": None,
    "num_class_embeds": None,
    "only_cross_attention": False,
    "out_channels": 4,
    "projection_class_embeddings_input_dim": 2816,
    "resnet_out_scale_factor": 1.0,
    "resnet_skip_time_act": False,
    "resnet_time_scale_shift": "default",
    "sample_size": 128,
    "time_cond_proj_dim": None,
    "time_embedding_act_fn": None,
    "time_embedding_dim": None,
    "time_embedding_type": "positional",
    "timestep_post_act": None,
    "transformer_layers_per_block": [1, 2, 10],
    "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
    "upcast_attention": False,
    "use_linear_projection": True,
}


def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
    SDXL_KEY_PREFIX = "conditioner.embedders.1.model."

    # basically the same as SD2's; logit_scale is used later, so it is returned as well
    # logit_scale is used when saving the checkpoint
    def convert_key(key):
        # common conversion
        key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
        key = key.replace(SDXL_KEY_PREFIX, "text_model.")

        if "resblocks" in key:
            # resblocks conversion
            key = key.replace(".resblocks.", ".layers.")
            if ".ln_" in key:
                key = key.replace(".ln_", ".layer_norm")
            elif ".mlp." in key:
                key = key.replace(".c_fc.", ".fc1.")
                key = key.replace(".c_proj.", ".fc2.")
            elif ".attn.out_proj" in key:
                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
            elif ".attn.in_proj" in key:
                key = None  # special case, handled later
            else:
                raise ValueError(f"unexpected key in SD: {key}")
        elif ".positional_embedding" in key:
            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        elif ".text_projection" in key:
            key = key.replace("text_model.text_projection", "text_projection.weight")
        elif ".logit_scale" in key:
            key = None  # handled later
        elif ".token_embedding" in key:
            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        elif ".ln_final" in key:
            key = key.replace(".ln_final", ".final_layer_norm")
        # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
        elif ".embeddings.position_ids" in key:
            key = None  # remove this key: make position_ids by ourselves
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # convert attention weights
    for key in keys:
        if ".resblocks" in key and ".attn.in_proj_" in key:
            # split into three
            values = torch.chunk(checkpoint[key], 3)

            key_suffix = ".weight" if "weight" in key else ".bias"
            key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
            key_pfx = key_pfx.replace("_weight", "")
            key_pfx = key_pfx.replace("_bias", "")
            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

    # position_ids do not exist in the original SD, so add them
    position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
    new_sd["text_model.embeddings.position_ids"] = position_ids

    # logit_scale is not included in the Diffusers model, but return it separately so it can be restored when saving
    logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)

    # temporary workaround for text_projection.weight.weight for Playground-v2
    if "text_projection.weight.weight" in new_sd:
        print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
        new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
        del new_sd["text_projection.weight.weight"]

    return new_sd, logit_scale


# load state_dict without allocating new tensors
def _load_state_dict_on_device(model, state_dict, device, dtype=None):
    # dtype will use fp32 as default
    missing_keys = list(model.state_dict().keys() - state_dict.keys())
    unexpected_keys = list(state_dict.keys() - model.state_dict().keys())

    # similar to model.load_state_dict()
    if not missing_keys and not unexpected_keys:
        for k in list(state_dict.keys()):
            set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
        return "<All keys matched successfully>"

    # error_msgs
    error_msgs: List[str] = []
    if missing_keys:
        error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
    if unexpected_keys:
        error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))

    raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))


def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
    # model_version is reserved for future use
    # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching

    # Load the state dict
    if model_util.is_safetensors(ckpt_path):
        checkpoint = None
        try:
            state_dict = load_file(ckpt_path, device=map_location)
        except:
            state_dict = load_file(ckpt_path)  # prevent device invalid Error
        epoch = None
        global_step = None
    else:
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint.get("epoch", 0)
            global_step = checkpoint.get("global_step", 0)
        else:
            state_dict = checkpoint
            epoch = 0
            global_step = 0
        checkpoint = None

    # U-Net
    print("building U-Net")
    with init_empty_weights():
        unet = sdxl_original_unet.SdxlUNet2DConditionModel()

    print("loading U-Net from checkpoint")
    unet_sd = {}
    for k in list(state_dict.keys()):
        if k.startswith("model.diffusion_model."):
            unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
    info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
    print("U-Net: ", info)

    # Text Encoders
    print("building text encoders")

    # Text Encoder 1 is the same as Stability AI's SDXL
    text_model1_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=77,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=768,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    with init_empty_weights():
        text_model1 = CLIPTextModel._from_config(text_model1_cfg)

    # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
    # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
    text_model2_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=1280,
        intermediate_size=5120,
        num_hidden_layers=32,
        num_attention_heads=20,
        max_position_embeddings=77,
        hidden_act="gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=1280,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    with init_empty_weights():
        text_model2 = CLIPTextModelWithProjection(text_model2_cfg)

    print("loading text encoders from checkpoint")
    te1_sd = {}
    te2_sd = {}
    for k in list(state_dict.keys()):
        if k.startswith("conditioner.embedders.0.transformer."):
            te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
        elif k.startswith("conditioner.embedders.1.model."):
            te2_sd[k] = state_dict.pop(k)

    # add position_ids for models that lack them
    if "text_model.embeddings.position_ids" not in te1_sd:
        te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)

    info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)  # remain fp32
    print("text encoder 1:", info1)

    converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
    info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location)  # remain fp32
    print("text encoder 2:", info2)

    # prepare vae
    print("building VAE")
    vae_config = model_util.create_vae_diffusers_config()
    with init_empty_weights():
        vae = AutoencoderKL(**vae_config)

    print("loading VAE from checkpoint")
    converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
    info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
    print("VAE:", info)

    ckpt_info = (epoch, global_step) if epoch is not None else None
    return text_model1, text_model2, vae, unet, logit_scale, ckpt_info

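Editor's note, not part of the file: a minimal loading sketch for `load_models_from_sdxl_checkpoint` as defined above. The checkpoint path is hypothetical; the model-version constant and the return tuple come from this module.

def _example_load():
    text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info = load_models_from_sdxl_checkpoint(
        MODEL_VERSION_SDXL_BASE_V1_0,
        "/path/to/sd_xl_base_1.0.safetensors",  # hypothetical path
        map_location="cpu",
        dtype=torch.float16,  # applied to the U-Net and VAE; the text encoders stay fp32
    )
    print("U-Net parameters:", sum(p.numel() for p in unet.parameters()))
    print("logit_scale:", logit_scale, "ckpt_info:", ckpt_info)
    return text_encoder1, text_encoder2, vae, unet
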
def make_unet_conversion_map():
    unet_conversion_map_layer = []

    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            # if i > 0: commentout for sdxl
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
        sd_time_embed_prefix = f"time_embed.{j*2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
        sd_label_embed_prefix = f"label_emb.0.{j*2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    return unet_conversion_map


def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
    unet_conversion_map = make_unet_conversion_map()

    conversion_map = {hf: sd for sd, hf in unet_conversion_map}
    return convert_unet_state_dict(du_sd, conversion_map)


def convert_unet_state_dict(src_sd, conversion_map):
    converted_sd = {}
    for src_key, value in src_sd.items():
        # checking every mapping entry would be slow, so trim elements from the right while searching for a matching prefix
        src_key_fragments = src_key.split(".")[:-1]  # remove weight/bias
        while len(src_key_fragments) > 0:
            src_key_prefix = ".".join(src_key_fragments) + "."
            if src_key_prefix in conversion_map:
                converted_prefix = conversion_map[src_key_prefix]
                converted_key = converted_prefix + src_key[len(src_key_prefix) :]
                converted_sd[converted_key] = value
                break
            src_key_fragments.pop(-1)
        assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"

    return converted_sd


def convert_sdxl_unet_state_dict_to_diffusers(sd):
    unet_conversion_map = make_unet_conversion_map()

    conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
    return convert_unet_state_dict(sd, conversion_dict)

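Editor's note, not part of the file: a small illustration of the longest-prefix key conversion implemented above. A single SDXL-format U-Net tensor name is renamed to the Diffusers layout via the map built by `make_unet_conversion_map`; the tensor value is a dummy.

def _example_key_conversion():
    dummy_sd = {"input_blocks.1.0.in_layers.0.weight": torch.zeros(320)}
    converted = convert_sdxl_unet_state_dict_to_diffusers(dummy_sd)
    # "input_blocks.1.0." maps to "down_blocks.0.resnets.0." and "in_layers.0." to "norm1."
    assert "down_blocks.0.resnets.0.norm1.weight" in converted
    return converted
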
def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
    def convert_key(key):
        # remove position_ids
        if ".position_ids" in key:
            return None

        # common
        key = key.replace("text_model.encoder.", "transformer.")
        key = key.replace("text_model.", "")
        if "layers" in key:
            # resblocks conversion
            key = key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in key:
                key = key.replace(".layer_norm", ".ln_")
            elif ".mlp." in key:
                key = key.replace(".fc1.", ".c_fc.")
                key = key.replace(".fc2.", ".c_proj.")
            elif ".self_attn.out_proj" in key:
                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif ".self_attn." in key:
                key = None  # special case, handled later
            else:
                raise ValueError(f"unexpected key in DiffUsers model: {key}")
        elif ".position_embedding" in key:
            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif ".token_embedding" in key:
            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif "text_projection" in key:  # no dot in key
            key = key.replace("text_projection.weight", "text_projection")
        elif "final_layer_norm" in key:
            key = key.replace("final_layer_norm", "ln_final")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # convert attention weights
    for key in keys:
        if "layers" in key and "q_proj" in key:
            # concatenate the three projections
            key_q = key
            key_k = key.replace("q_proj", "k_proj")
            key_v = key.replace("q_proj", "v_proj")

            value_q = checkpoint[key_q]
            value_k = checkpoint[key_k]
            value_v = checkpoint[key_v]
            value = torch.cat([value_q, value_k, value_v])

            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            new_sd[new_key] = value

    if logit_scale is not None:
        new_sd["logit_scale"] = logit_scale

    return new_sd


def save_stable_diffusion_checkpoint(
    output_file,
    text_encoder1,
    text_encoder2,
    unet,
    epochs,
    steps,
    ckpt_info,
    vae,
    logit_scale,
    metadata,
    save_dtype=None,
):
    state_dict = {}

    def update_sd(prefix, sd):
        for k, v in sd.items():
            key = prefix + k
            if save_dtype is not None:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    # Convert the UNet model
    update_sd("model.diffusion_model.", unet.state_dict())

    # Convert the text encoders
    update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())

    text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale)
    update_sd("conditioner.embedders.1.model.", text_enc2_dict)

    # Convert the VAE
    vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
    update_sd("first_stage_model.", vae_dict)

    # Put together new checkpoint
    key_count = len(state_dict.keys())
    new_ckpt = {"state_dict": state_dict}

    # epoch and global_step are sometimes not int
    if ckpt_info is not None:
        epochs += ckpt_info[0]
        steps += ckpt_info[1]

    new_ckpt["epoch"] = epochs
    new_ckpt["global_step"] = steps

    if model_util.is_safetensors(output_file):
        save_file(state_dict, output_file, metadata)
    else:
        torch.save(new_ckpt, output_file)

    return key_count


def save_diffusers_checkpoint(
    output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
):
    from diffusers import StableDiffusionXLPipeline

    # convert U-Net
    unet_sd = unet.state_dict()
    du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)

    diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
    if save_dtype is not None:
        diffusers_unet.to(save_dtype)
    diffusers_unet.load_state_dict(du_unet_sd)

    # create pipeline to save
    if pretrained_model_name_or_path is None:
        pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL

    scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
    if vae is None:
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    # prevent local path from being saved
    def remove_name_or_path(model):
        if hasattr(model, "config"):
            model.config._name_or_path = None
            model.config._name_or_path = None

    remove_name_or_path(diffusers_unet)
    remove_name_or_path(text_encoder1)
    remove_name_or_path(text_encoder2)
    remove_name_or_path(scheduler)
    remove_name_or_path(tokenizer1)
    remove_name_or_path(tokenizer2)
    remove_name_or_path(vae)

    pipeline = StableDiffusionXLPipeline(
        unet=diffusers_unet,
        text_encoder=text_encoder1,
        text_encoder_2=text_encoder2,
        vae=vae,
        scheduler=scheduler,
        tokenizer=tokenizer1,
        tokenizer_2=tokenizer2,
    )
    if save_dtype is not None:
        pipeline.to(None, save_dtype)
    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
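
Editor's sketch, not part of the diff: a load/re-save round trip using the two helpers in sdxl_model_util.py above. Paths and metadata are hypothetical.

import torch
from external.llite.library.sdxl_model_util import (
    MODEL_VERSION_SDXL_BASE_V1_0,
    load_models_from_sdxl_checkpoint,
    save_stable_diffusion_checkpoint,
)

te1, te2, vae, unet, logit_scale, ckpt_info = load_models_from_sdxl_checkpoint(
    MODEL_VERSION_SDXL_BASE_V1_0, "/path/to/sd_xl_base_1.0.safetensors", "cpu"  # hypothetical input path
)
key_count = save_stable_diffusion_checkpoint(
    "/path/to/roundtrip.safetensors",  # hypothetical output path
    te1,
    te2,
    unet,
    epochs=0,
    steps=0,
    ckpt_info=ckpt_info,
    vae=vae,
    logit_scale=logit_scale,
    metadata={"purpose": "round-trip example"},
    save_dtype=torch.float16,
)
print(f"saved {key_count} tensors")
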
external/llite/library/sdxl_original_unet.py
ADDED
@@ -0,0 +1,1281 @@
# U-Net for sd_xl_base, based on the Diffusers code
# the state dict format matches SDXL

"""
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
  adm_in_channels: 2816
  num_classes: sequential
  use_checkpoint: True
  in_channels: 4
  out_channels: 4
  model_channels: 320
  attention_resolutions: [4, 2]
  num_res_blocks: 2
  channel_mult: [1, 2, 4]
  num_head_channels: 64
  use_spatial_transformer: True
  use_linear_in_transformer: True
  transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
  context_dim: 2048
  spatial_transformer_attn_type: softmax-xformers
  legacy: False
"""

import math
from types import SimpleNamespace
from typing import Any, Optional
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange


IN_CHANNELS: int = 4
OUT_CHANNELS: int = 4
ADM_IN_CHANNELS: int = 2816
CONTEXT_DIM: int = 2048
MODEL_CHANNELS: int = 320
TIME_EMBED_DIM = 320 * 4

USE_REENTRANT = True

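A quick sanity check on the ADM_IN_CHANNELS constant. This breakdown is an assumption based on how SDXL is usually conditioned (pooled output of the second text encoder plus six 256-dim size/crop Fourier embeddings), not something stated in this file:

    # Assumed composition of the 2816-dim "y" vector fed to label_emb (hypothetical breakdown):
    pooled_text_dim = 1280       # pooled output of the OpenCLIP bigG text encoder
    size_embed_dim = 6 * 256     # (orig_h, orig_w, crop_top, crop_left, target_h, target_w), 256 dims each
    assert pooled_text_dim + size_embed_dim == 2816  # == ADM_IN_CHANNELS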
# region memory efficient attention

# CrossAttention that uses FlashAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE

# constants

EPSILON = 1e-6

# helper functions


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


# flash attention forwards and backwards

# https://arxiv.org/abs/2205.14135


class FlashAttentionFunction(torch.autograd.Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""

        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)

        scale = q.shape[-1] ** -0.5

        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                        q_start_index - k_start_index + 1
                    )
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper"""

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                        q_start_index - k_start_index + 1
                    )
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.0)

                p = exp_attn_weights / lc

                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None


# endregion


def get_parameter_dtype(parameter: torch.nn.Module):
    return next(parameter.parameters()).dtype


def get_parameter_device(parameter: torch.nn.Module):
    return next(parameter.parameters()).device


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings: flipped from Diffusers original ver because always flip_sin_to_cos=True
    emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
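For orientation, a minimal sketch of what this function produces (shapes only; the timestep values are illustrative):

    # Illustrative check of the embedding shape consumed by SdxlUNet2DConditionModel.forward
    timesteps = torch.tensor([999.0, 500.0])         # one timestep per batch element
    t_emb = get_timestep_embedding(timesteps, 320)   # MODEL_CHANNELS = 320
    print(t_emb.shape)                               # torch.Size([2, 320]); this is fed to self.time_embed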


# Deep Shrink: this helper is kept local (not shared) to minimize dependencies.
def resize_like(x, target, mode="bicubic", align_corners=False):
    org_dtype = x.dtype
    if org_dtype == torch.bfloat16:
        x = x.to(torch.float32)

    if x.shape[-2:] != target.shape[-2:]:
        if mode == "nearest":
            x = F.interpolate(x, size=target.shape[-2:], mode=mode)
        else:
            x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)

    if org_dtype == torch.bfloat16:
        x = x.to(org_dtype)
    return x


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        if self.weight.dtype != torch.float32:
            return super().forward(x)
        return super().forward(x.float()).type(x.dtype)


class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.in_layers = nn.Sequential(
            GroupNorm32(32, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(TIME_EMBED_DIM, out_channels))

        self.out_layers = nn.Sequential(
            GroupNorm32(32, out_channels),
            nn.SiLU(),
            nn.Identity(),  # to make state_dict compatible with original model
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        if in_channels != out_channels:
            self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.skip_connection = nn.Identity()

        self.gradient_checkpointing = False

    def forward_body(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        h = h + emb_out[:, :, None, None]
        h = self.out_layers(h)
        x = self.skip_connection(x)
        return x + h

    def forward(self, x, emb):
        if self.training and self.gradient_checkpointing:
            # print("ResnetBlock2D: gradient_checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
                    return func(*inputs)

                return custom_forward

            x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT)
        else:
            x = self.forward_body(x, emb)

        return x


class Downsample2D(nn.Module):
    def __init__(self, channels, out_channels):
        super().__init__()

        self.channels = channels
        self.out_channels = out_channels

        self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)

        self.gradient_checkpointing = False

    def forward_body(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        hidden_states = self.op(hidden_states)

        return hidden_states

    def forward(self, hidden_states):
        if self.training and self.gradient_checkpointing:
            # print("Downsample2D: gradient_checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
                    return func(*inputs)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT
            )
        else:
            hidden_states = self.forward_body(hidden_states)

        return hidden_states


class CrossAttention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        upcast_attention: bool = False,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        # no dropout here

        self.use_memory_efficient_attention_xformers = False
        self.use_memory_efficient_attention_mem_eff = False
        self.use_sdpa = False

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        self.use_memory_efficient_attention_xformers = xformers
        self.use_memory_efficient_attention_mem_eff = mem_eff

    def set_use_sdpa(self, sdpa):
        self.use_sdpa = sdpa

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def forward(self, hidden_states, context=None, mask=None):
        if self.use_memory_efficient_attention_xformers:
            return self.forward_memory_efficient_xformers(hidden_states, context, mask)
        if self.use_memory_efficient_attention_mem_eff:
            return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
        if self.use_sdpa:
            return self.forward_sdpa(hidden_states, context, mask)

        query = self.to_q(hidden_states)
        context = context if context is not None else hidden_states
        key = self.to_k(context)
        value = self.to_v(context)

        query = self.reshape_heads_to_batch_dim(query)
        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        hidden_states = self._attention(query, key, value)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # hidden_states = self.to_out[1](hidden_states)  # no dropout
        return hidden_states

    def _attention(self, query, key, value):
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )
        attention_probs = attention_scores.softmax(dim=-1)

        # cast back to the original dtype
        attention_probs = attention_probs.to(value.dtype)

        # compute attention output
        hidden_states = torch.bmm(attention_probs, value)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    # TODO support Hypernetworks
    def forward_memory_efficient_xformers(self, x, context=None, mask=None):
        import xformers.ops

        h = self.heads
        q_in = self.to_q(x)
        context = context if context is not None else x
        context = context.to(x.dtype)
        k_in = self.to_k(context)
        v_in = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
        del q_in, k_in, v_in

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # picks the best kernel automatically
        del q, k, v

        out = rearrange(out, "b n h d -> b n (h d)", h=h)

        out = self.to_out[0](out)
        return out

    def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
        flash_func = FlashAttentionFunction

        q_bucket_size = 512
        k_bucket_size = 1024

        h = self.heads
        q = self.to_q(x)
        context = context if context is not None else x
        context = context.to(x.dtype)
        k = self.to_k(context)
        v = self.to_v(context)
        del context, x

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)

        out = rearrange(out, "b h n d -> b n (h d)")

        out = self.to_out[0](out)
        return out

    def forward_sdpa(self, x, context=None, mask=None):
        h = self.heads
        q_in = self.to_q(x)
        context = context if context is not None else x
        context = context.to(x.dtype)
        k_in = self.to_k(context)
        v_in = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
        del q_in, k_in, v_in

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

        out = rearrange(out, "b h n d -> b n (h d)", h=h)

        out = self.to_out[0](out)
        return out
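The three `set_use_*` switches make the attention backend an explicit, per-module choice (plain baddbmm attention, xformers, the bucketed FlashAttention function above, or PyTorch SDPA). A minimal sketch of toggling SDPA on a standalone CrossAttention module; the dimensions are illustrative (640 matches the level-1 blocks, 2048 the SDXL text context), not taken from a specific call site:

    # query_dim = 2 * MODEL_CHANNELS = 640, heads * dim_head = 640 as in the level-1 transformer blocks
    attn = CrossAttention(query_dim=640, cross_attention_dim=2048, heads=10, dim_head=64)
    attn.set_use_sdpa(True)  # route forward() through torch.nn.functional.scaled_dot_product_attention

    x = torch.randn(1, 4096, 640)    # flattened 64x64 latent tokens
    ctx = torch.randn(1, 77, 2048)   # text-encoder context
    out = attn(x, context=ctx)       # same shape as x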


# feedforward
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
    ):
        super().__init__()
        inner_dim = int(dim * 4)  # mult is always 4

        self.net = nn.ModuleList([])
        # project in
        self.net.append(GEGLU(dim, inner_dim))
        # project dropout
        self.net.append(nn.Identity())  # nn.Dropout(0))  # dummy for dropout with 0
        # project out
        self.net.append(nn.Linear(inner_dim, dim))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class BasicTransformerBlock(nn.Module):
    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
    ):
        super().__init__()

        self.gradient_checkpointing = False

        # 1. Self-Attn
        self.attn1 = CrossAttention(
            query_dim=dim,
            cross_attention_dim=None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            upcast_attention=upcast_attention,
        )
        self.ff = FeedForward(dim)

        # 2. Cross-Attn
        self.attn2 = CrossAttention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            upcast_attention=upcast_attention,
        )

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim)

    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
        self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
        self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa: bool):
        self.attn1.set_use_sdpa(sdpa)
        self.attn2.set_use_sdpa(sdpa)

    def forward_body(self, hidden_states, context=None, timestep=None):
        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states)

        hidden_states = self.attn1(norm_hidden_states) + hidden_states

        # 2. Cross-Attention
        norm_hidden_states = self.norm2(hidden_states)
        hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states

        # 3. Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        return hidden_states

    def forward(self, hidden_states, context=None, timestep=None):
        if self.training and self.gradient_checkpointing:
            # print("BasicTransformerBlock: checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
                    return func(*inputs)

                return custom_forward

            output = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT
            )
        else:
            output = self.forward_body(hidden_states, context, timestep)

        return output


class Transformer2DModel(nn.Module):
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        cross_attention_dim: Optional[int] = None,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        num_transformer_layers: int = 1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.use_linear_projection = use_linear_projection

        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        # self.norm = GroupNorm32(32, in_channels, eps=1e-6, affine=True)

        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        blocks = []
        for _ in range(num_transformer_layers):
            blocks.append(
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                )
            )

        self.transformer_blocks = nn.ModuleList(blocks)

        if use_linear_projection:
            self.proj_out = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

        self.gradient_checkpointing = False

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        for transformer in self.transformer_blocks:
            transformer.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa):
        for transformer in self.transformer_blocks:
            transformer.set_use_sdpa(sdpa)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None):
        # 1. Input
        batch, _, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)

        # 3. Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        return output


class Upsample2D(nn.Module):
    def __init__(self, channels, out_channels):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward_body(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        hidden_states = self.conv(hidden_states)

        return hidden_states

    def forward(self, hidden_states, output_size=None):
        if self.training and self.gradient_checkpointing:
            # print("Upsample2D: gradient_checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
                    return func(*inputs)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT
            )
        else:
            hidden_states = self.forward_body(hidden_states, output_size)

        return hidden_states


class SdxlUNet2DConditionModel(nn.Module):
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = IN_CHANNELS
        self.out_channels = OUT_CHANNELS
        self.model_channels = MODEL_CHANNELS
        self.time_embed_dim = TIME_EMBED_DIM
        self.adm_in_channels = ADM_IN_CHANNELS

        self.gradient_checkpointing = False
        # self.sample_size = sample_size

        # time embedding
        self.time_embed = nn.Sequential(
            nn.Linear(self.model_channels, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
        )

        # label embedding
        self.label_emb = nn.Sequential(
            nn.Sequential(
                nn.Linear(self.adm_in_channels, self.time_embed_dim),
                nn.SiLU(),
                nn.Linear(self.time_embed_dim, self.time_embed_dim),
            )
        )

        # input
        self.input_blocks = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=(1, 1)),
                )
            ]
        )

        # level 0
        for i in range(2):
            layers = [
                ResnetBlock2D(
                    in_channels=1 * self.model_channels,
                    out_channels=1 * self.model_channels,
                ),
            ]
            self.input_blocks.append(nn.ModuleList(layers))

        self.input_blocks.append(
            nn.Sequential(
                Downsample2D(
                    channels=1 * self.model_channels,
                    out_channels=1 * self.model_channels,
                ),
            )
        )

        # level 1
        for i in range(2):
            layers = [
                ResnetBlock2D(
                    in_channels=(1 if i == 0 else 2) * self.model_channels,
                    out_channels=2 * self.model_channels,
                ),
                Transformer2DModel(
                    num_attention_heads=2 * self.model_channels // 64,
                    attention_head_dim=64,
                    in_channels=2 * self.model_channels,
                    num_transformer_layers=2,
                    use_linear_projection=True,
                    cross_attention_dim=2048,
                ),
            ]
            self.input_blocks.append(nn.ModuleList(layers))

        self.input_blocks.append(
            nn.Sequential(
                Downsample2D(
                    channels=2 * self.model_channels,
                    out_channels=2 * self.model_channels,
                ),
            )
        )

        # level 2
        for i in range(2):
            layers = [
                ResnetBlock2D(
                    in_channels=(2 if i == 0 else 4) * self.model_channels,
                    out_channels=4 * self.model_channels,
                ),
                Transformer2DModel(
                    num_attention_heads=4 * self.model_channels // 64,
                    attention_head_dim=64,
                    in_channels=4 * self.model_channels,
                    num_transformer_layers=10,
                    use_linear_projection=True,
                    cross_attention_dim=2048,
                ),
            ]
            self.input_blocks.append(nn.ModuleList(layers))

        # mid
        self.middle_block = nn.ModuleList(
            [
                ResnetBlock2D(
                    in_channels=4 * self.model_channels,
                    out_channels=4 * self.model_channels,
                ),
                Transformer2DModel(
                    num_attention_heads=4 * self.model_channels // 64,
                    attention_head_dim=64,
                    in_channels=4 * self.model_channels,
                    num_transformer_layers=10,
                    use_linear_projection=True,
                    cross_attention_dim=2048,
                ),
                ResnetBlock2D(
                    in_channels=4 * self.model_channels,
                    out_channels=4 * self.model_channels,
                ),
            ]
        )

        # output
        self.output_blocks = nn.ModuleList([])

        # level 2
        for i in range(3):
            layers = [
                ResnetBlock2D(
                    in_channels=4 * self.model_channels + (4 if i <= 1 else 2) * self.model_channels,
                    out_channels=4 * self.model_channels,
                ),
                Transformer2DModel(
                    num_attention_heads=4 * self.model_channels // 64,
                    attention_head_dim=64,
                    in_channels=4 * self.model_channels,
                    num_transformer_layers=10,
                    use_linear_projection=True,
                    cross_attention_dim=2048,
                ),
            ]
            if i == 2:
                layers.append(
                    Upsample2D(
                        channels=4 * self.model_channels,
                        out_channels=4 * self.model_channels,
                    )
                )

            self.output_blocks.append(nn.ModuleList(layers))

        # level 1
        for i in range(3):
            layers = [
                ResnetBlock2D(
                    in_channels=2 * self.model_channels + (4 if i == 0 else (2 if i == 1 else 1)) * self.model_channels,
                    out_channels=2 * self.model_channels,
                ),
                Transformer2DModel(
                    num_attention_heads=2 * self.model_channels // 64,
                    attention_head_dim=64,
                    in_channels=2 * self.model_channels,
                    num_transformer_layers=2,
                    use_linear_projection=True,
                    cross_attention_dim=2048,
                ),
            ]
            if i == 2:
                layers.append(
                    Upsample2D(
                        channels=2 * self.model_channels,
                        out_channels=2 * self.model_channels,
                    )
                )

            self.output_blocks.append(nn.ModuleList(layers))

        # level 0
        for i in range(3):
            layers = [
                ResnetBlock2D(
                    in_channels=1 * self.model_channels + (2 if i == 0 else 1) * self.model_channels,
                    out_channels=1 * self.model_channels,
                ),
            ]

            self.output_blocks.append(nn.ModuleList(layers))

        # output
        self.out = nn.ModuleList(
            [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
        )

    # region diffusers compatibility
    def prepare_config(self):
        self.config = SimpleNamespace()

    @property
    def dtype(self) -> torch.dtype:
        # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        return get_parameter_dtype(self)

    @property
    def device(self) -> torch.device:
        # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
        return get_parameter_device(self)

    def set_attention_slice(self, slice_size):
        raise NotImplementedError("Attention slicing is not supported for this model.")

    def is_gradient_checkpointing(self) -> bool:
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True
        self.set_gradient_checkpointing(value=True)

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False
        self.set_gradient_checkpointing(value=False)

    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
        blocks = self.input_blocks + [self.middle_block] + self.output_blocks
        for block in blocks:
            for module in block:
                if hasattr(module, "set_use_memory_efficient_attention"):
                    # print(module.__class__.__name__)
                    module.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa: bool) -> None:
        blocks = self.input_blocks + [self.middle_block] + self.output_blocks
        for block in blocks:
            for module in block:
                if hasattr(module, "set_use_sdpa"):
                    module.set_use_sdpa(sdpa)

    def set_gradient_checkpointing(self, value=False):
        blocks = self.input_blocks + [self.middle_block] + self.output_blocks
        for block in blocks:
            for module in block.modules():
                if hasattr(module, "gradient_checkpointing"):
                    # print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
                    module.gradient_checkpointing = value

    # endregion

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        # broadcast timesteps to batch dimension
        timesteps = timesteps.expand(x.shape[0])

        hs = []
        t_emb = get_timestep_embedding(timesteps, self.model_channels)  # , repeat_only=False)
        t_emb = t_emb.to(x.dtype)
        emb = self.time_embed(t_emb)

        assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
        # assert x.dtype == self.dtype
        emb = emb + self.label_emb(y)

        def call_module(module, h, emb, context):
            x = h
            for layer in module:
                # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
                if isinstance(layer, ResnetBlock2D):
                    x = layer(x, emb)
                elif isinstance(layer, Transformer2DModel):
                    x = layer(x, context)
                else:
                    x = layer(x)
            return x

        # h = x.type(self.dtype)
        h = x

        for module in self.input_blocks:
            h = call_module(module, h, emb, context)
            hs.append(h)

        h = call_module(self.middle_block, h, emb, context)

        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = call_module(module, h, emb, context)

        h = h.type(x.dtype)
        h = call_module(self.out, h, emb, context)

        return h


class InferSdxlUNet2DConditionModel:
    def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
        self.delegate = original_unet

        # override original model's forward method: because forward is not called by `__call__`
        # overriding `__call__` is not enough, because nn.Module.forward has a special handling
        self.delegate.forward = self.forward

        # Deep Shrink
        self.ds_depth_1 = None
        self.ds_depth_2 = None
        self.ds_timesteps_1 = None
        self.ds_timesteps_2 = None
        self.ds_ratio = None

    # call original model's methods
    def __getattr__(self, name):
        return getattr(self.delegate, name)

    def __call__(self, *args, **kwargs):
        return self.delegate(*args, **kwargs)

    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
        if ds_depth_1 is None:
            print("Deep Shrink is disabled.")
            self.ds_depth_1 = None
            self.ds_timesteps_1 = None
            self.ds_depth_2 = None
            self.ds_timesteps_2 = None
            self.ds_ratio = None
        else:
            print(
                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
            )
            self.ds_depth_1 = ds_depth_1
            self.ds_timesteps_1 = ds_timesteps_1
            self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
            self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
            self.ds_ratio = ds_ratio

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        r"""
        current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
        """
        _self = self.delegate

        # broadcast timesteps to batch dimension
        timesteps = timesteps.expand(x.shape[0])

        hs = []
        t_emb = get_timestep_embedding(timesteps, _self.model_channels)  # , repeat_only=False)
        t_emb = t_emb.to(x.dtype)
        emb = _self.time_embed(t_emb)

        assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
        # assert x.dtype == _self.dtype
        emb = emb + _self.label_emb(y)

        def call_module(module, h, emb, context):
            x = h
            for layer in module:
                # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
                if isinstance(layer, ResnetBlock2D):
                    x = layer(x, emb)
                elif isinstance(layer, Transformer2DModel):
                    x = layer(x, context)
                else:
                    x = layer(x)
            return x

        # h = x.type(self.dtype)
        h = x

        for depth, module in enumerate(_self.input_blocks):
            # Deep Shrink
            if self.ds_depth_1 is not None:
                if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
                    self.ds_depth_2 is not None
                    and depth == self.ds_depth_2
                    and timesteps[0] < self.ds_timesteps_1
                    and timesteps[0] >= self.ds_timesteps_2
                ):
                    # print("downsample", h.shape, self.ds_ratio)
                    org_dtype = h.dtype
                    if org_dtype == torch.bfloat16:
                        h = h.to(torch.float32)
                    h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)

            h = call_module(module, h, emb, context)
            hs.append(h)

        h = call_module(_self.middle_block, h, emb, context)

        for module in _self.output_blocks:
            # Deep Shrink
            if self.ds_depth_1 is not None:
                if hs[-1].shape[-2:] != h.shape[-2:]:
                    # print("upsample", h.shape, hs[-1].shape)
                    h = resize_like(h, hs[-1])

            h = torch.cat([h, hs.pop()], dim=1)
            h = call_module(module, h, emb, context)

        # Deep Shrink: in case of depth 0
        if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
            # print("upsample", h.shape, x.shape)
            h = resize_like(h, x)

        h = h.type(x.dtype)
        h = call_module(_self.out, h, emb, context)

        return h


if __name__ == "__main__":
    import time

    print("create unet")
    unet = SdxlUNet2DConditionModel()

    unet.to("cuda")
    unet.set_use_memory_efficient_attention(True, False)
    unet.set_gradient_checkpointing(True)
    unet.train()

    # pseudo training loop for checking memory usage
    print("preparing optimizer")

    # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9)  # not working

    # import bitsandbytes
    # optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3)  # not working
    # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3)  # working at 23.5 GB with torch2
    # optimizer = bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3)  # working at 23.5 GB with torch2

    import transformers

    optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True)  # working at 22.2GB with torch2

    scaler = torch.cuda.amp.GradScaler(enabled=True)

    print("start training")
    steps = 10
    batch_size = 1

    for step in range(steps):
        print(f"step {step}")
        if step == 1:
            time_start = time.perf_counter()

        x = torch.randn(batch_size, 4, 128, 128).cuda()  # 1024x1024
        t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
        ctx = torch.randn(batch_size, 77, 2048).cuda()
        y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()

        with torch.cuda.amp.autocast(enabled=True):
            output = unet(x, t, ctx, y)
            target = torch.randn_like(output)
            loss = torch.nn.functional.mse_loss(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    time_end = time.perf_counter()
    print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
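The InferSdxlUNet2DConditionModel wrapper is what enables Deep Shrink at inference time: it downsamples the latent at a chosen input-block depth while the timestep is still high, then resizes features back to match the skip connections on the way up. A minimal sketch of wiring it up (the depth and threshold values are illustrative; 650 and 0.5 are just the defaults of set_deep_shrink shown above):

    # Wrap an already-constructed U-Net and turn on Deep Shrink for early, high-noise timesteps.
    unet = SdxlUNet2DConditionModel()
    infer_unet = InferSdxlUNet2DConditionModel(unet)
    infer_unet.set_deep_shrink(ds_depth_1=3, ds_timesteps_1=650, ds_ratio=0.5)

    # Calls go through the wrapper; the delegate's forward has been replaced by the Deep Shrink version:
    # out = infer_unet(latents, timesteps, text_context, y)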
external/llite/library/sdxl_train_util.py
ADDED
@@ -0,0 +1,367 @@
import argparse
import gc
import math
import os
from typing import Optional
import torch
from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from external.llite.library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from external.llite.library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline

TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

# DEFAULT_NOISE_OFFSET = 0.0357


def load_target_model(args, accelerator, model_version: str, weight_dtype):
    # load models for each process
    model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16
    for pi in range(accelerator.state.num_processes):
        if pi == accelerator.state.local_process_index:
            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")

            (
                load_stable_diffusion_format,
                text_encoder1,
                text_encoder2,
                vae,
                unet,
                logit_scale,
                ckpt_info,
            ) = _load_target_model(
                args.pretrained_model_name_or_path,
                args.vae,
                model_version,
                weight_dtype,
                accelerator.device if args.lowram else "cpu",
                model_dtype,
            )

            # work on low-ram device
            if args.lowram:
                text_encoder1.to(accelerator.device)
                text_encoder2.to(accelerator.device)
                unet.to(accelerator.device)
                vae.to(accelerator.device)

            gc.collect()
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info


def _load_target_model(
    name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
):
    # model_dtype only work with full fp16/bf16
    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
    load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers

    if load_stable_diffusion_format:
        print(f"load StableDiffusion checkpoint: {name_or_path}")
        (
            text_encoder1,
            text_encoder2,
            vae,
            unet,
            logit_scale,
            ckpt_info,
        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
    else:
        # Diffusers model is loaded to CPU
        from diffusers import StableDiffusionXLPipeline

        variant = "fp16" if weight_dtype == torch.float16 else None
        print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
        try:
            try:
                pipe = StableDiffusionXLPipeline.from_pretrained(
                    name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
                )
            except EnvironmentError as ex:
                if variant is not None:
                    print("try to load fp32 model")
                    pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
                else:
                    raise ex
        except EnvironmentError as ex:
            print(
                f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
            )
            raise ex

        text_encoder1 = pipe.text_encoder
        text_encoder2 = pipe.text_encoder_2

        # convert to fp32 for cache text_encoders outputs
        if text_encoder1.dtype != torch.float32:
            text_encoder1 = text_encoder1.to(dtype=torch.float32)
        if text_encoder2.dtype != torch.float32:
            text_encoder2 = text_encoder2.to(dtype=torch.float32)

        vae = pipe.vae
        unet = pipe.unet
        del pipe

        # Diffusers U-Net to original U-Net
        state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
        with init_empty_weights():
            unet = sdxl_original_unet.SdxlUNet2DConditionModel()  # overwrite unet
        sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
        print("U-Net converted to original U-Net")

        logit_scale = None
        ckpt_info = None

    # load the additional VAE, if specified
    if vae_path is not None:
        vae = model_util.load_vae(vae_path, weight_dtype)
        print("additional VAE loaded")

    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info


def load_tokenizers(args: argparse.Namespace):
    print("prepare tokenizers")

    original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
    tokenizers = []
    for i, original_path in enumerate(original_paths):
        tokenizer: CLIPTokenizer = None
        if args.tokenizer_cache_dir:
            local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
            if os.path.exists(local_tokenizer_path):
                print(f"load tokenizer from cache: {local_tokenizer_path}")
                tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)

        if tokenizer is None:
            tokenizer = CLIPTokenizer.from_pretrained(original_path)

        if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
            print(f"save tokenizer to cache: {local_tokenizer_path}")
            tokenizer.save_pretrained(local_tokenizer_path)

        if i == 1:
            tokenizer.pad_token_id = 0  # fix pad token id to make same as open clip tokenizer

        tokenizers.append(tokenizer)

    if hasattr(args, "max_token_length") and args.max_token_length is not None:
        print(f"update token length: {args.max_token_length}")

    return tokenizers


def match_mixed_precision(args, weight_dtype):
    if args.full_fp16:
        assert (
            weight_dtype == torch.float16
        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        return weight_dtype
    elif args.full_bf16:
        assert (
            weight_dtype == torch.bfloat16
        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
        return weight_dtype
    else:
        return None


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        device=timesteps.device
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def get_timestep_embedding(x, outdim):
    assert len(x.shape) == 2
    b, dims = x.shape[0], x.shape[1]
    x = torch.flatten(x)
    emb = timestep_embedding(x, outdim)
    emb = torch.reshape(emb, (b, dims * outdim))
    return emb


def get_size_embeddings(orig_size, crop_size, target_size, device):
    emb1 = get_timestep_embedding(orig_size, 256)
    emb2 = get_timestep_embedding(crop_size, 256)
    emb3 = get_timestep_embedding(target_size, 256)
    vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
    return vector
209 |
+
|
210 |
+
|
211 |
+
def save_sd_model_on_train_end(
|
212 |
+
args: argparse.Namespace,
|
213 |
+
src_path: str,
|
214 |
+
save_stable_diffusion_format: bool,
|
215 |
+
use_safetensors: bool,
|
216 |
+
save_dtype: torch.dtype,
|
217 |
+
epoch: int,
|
218 |
+
global_step: int,
|
219 |
+
text_encoder1,
|
220 |
+
text_encoder2,
|
221 |
+
unet,
|
222 |
+
vae,
|
223 |
+
logit_scale,
|
224 |
+
ckpt_info,
|
225 |
+
):
|
226 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
227 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
|
228 |
+
sdxl_model_util.save_stable_diffusion_checkpoint(
|
229 |
+
ckpt_file,
|
230 |
+
text_encoder1,
|
231 |
+
text_encoder2,
|
232 |
+
unet,
|
233 |
+
epoch_no,
|
234 |
+
global_step,
|
235 |
+
ckpt_info,
|
236 |
+
vae,
|
237 |
+
logit_scale,
|
238 |
+
sai_metadata,
|
239 |
+
save_dtype,
|
240 |
+
)
|
241 |
+
|
242 |
+
def diffusers_saver(out_dir):
|
243 |
+
sdxl_model_util.save_diffusers_checkpoint(
|
244 |
+
out_dir,
|
245 |
+
text_encoder1,
|
246 |
+
text_encoder2,
|
247 |
+
unet,
|
248 |
+
src_path,
|
249 |
+
vae,
|
250 |
+
use_safetensors=use_safetensors,
|
251 |
+
save_dtype=save_dtype,
|
252 |
+
)
|
253 |
+
|
254 |
+
train_util.save_sd_model_on_train_end_common(
|
255 |
+
args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
|
256 |
+
)
|
257 |
+
|
258 |
+
|
259 |
+
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
|
260 |
+
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
|
261 |
+
def save_sd_model_on_epoch_end_or_stepwise(
|
262 |
+
args: argparse.Namespace,
|
263 |
+
on_epoch_end: bool,
|
264 |
+
accelerator,
|
265 |
+
src_path,
|
266 |
+
save_stable_diffusion_format: bool,
|
267 |
+
use_safetensors: bool,
|
268 |
+
save_dtype: torch.dtype,
|
269 |
+
epoch: int,
|
270 |
+
num_train_epochs: int,
|
271 |
+
global_step: int,
|
272 |
+
text_encoder1,
|
273 |
+
text_encoder2,
|
274 |
+
unet,
|
275 |
+
vae,
|
276 |
+
logit_scale,
|
277 |
+
ckpt_info,
|
278 |
+
):
|
279 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
280 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
|
281 |
+
sdxl_model_util.save_stable_diffusion_checkpoint(
|
282 |
+
ckpt_file,
|
283 |
+
text_encoder1,
|
284 |
+
text_encoder2,
|
285 |
+
unet,
|
286 |
+
epoch_no,
|
287 |
+
global_step,
|
288 |
+
ckpt_info,
|
289 |
+
vae,
|
290 |
+
logit_scale,
|
291 |
+
sai_metadata,
|
292 |
+
save_dtype,
|
293 |
+
)
|
294 |
+
|
295 |
+
def diffusers_saver(out_dir):
|
296 |
+
sdxl_model_util.save_diffusers_checkpoint(
|
297 |
+
out_dir,
|
298 |
+
text_encoder1,
|
299 |
+
text_encoder2,
|
300 |
+
unet,
|
301 |
+
src_path,
|
302 |
+
vae,
|
303 |
+
use_safetensors=use_safetensors,
|
304 |
+
save_dtype=save_dtype,
|
305 |
+
)
|
306 |
+
|
307 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
308 |
+
args,
|
309 |
+
on_epoch_end,
|
310 |
+
accelerator,
|
311 |
+
save_stable_diffusion_format,
|
312 |
+
use_safetensors,
|
313 |
+
epoch,
|
314 |
+
num_train_epochs,
|
315 |
+
global_step,
|
316 |
+
sd_saver,
|
317 |
+
diffusers_saver,
|
318 |
+
)
|
319 |
+
|
320 |
+
|
321 |
+
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
322 |
+
parser.add_argument(
|
323 |
+
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
324 |
+
)
|
325 |
+
parser.add_argument(
|
326 |
+
"--cache_text_encoder_outputs_to_disk",
|
327 |
+
action="store_true",
|
328 |
+
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
329 |
+
)
|
330 |
+
|
331 |
+
|
332 |
+
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
333 |
+
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
334 |
+
if args.v_parameterization:
|
335 |
+
print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
336 |
+
|
337 |
+
if args.clip_skip is not None:
|
338 |
+
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
339 |
+
|
340 |
+
# if args.multires_noise_iterations:
|
341 |
+
# print(
|
342 |
+
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
|
343 |
+
# )
|
344 |
+
# else:
|
345 |
+
# if args.noise_offset is None:
|
346 |
+
# args.noise_offset = DEFAULT_NOISE_OFFSET
|
347 |
+
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
|
348 |
+
# print(
|
349 |
+
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
|
350 |
+
# )
|
351 |
+
# print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
352 |
+
|
353 |
+
assert (
|
354 |
+
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
355 |
+
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
356 |
+
|
357 |
+
if supportTextEncoderCaching:
|
358 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
359 |
+
args.cache_text_encoder_outputs = True
|
360 |
+
print(
|
361 |
+
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
|
362 |
+
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
|
363 |
+
)
|
364 |
+
|
365 |
+
|
366 |
+
def sample_images(*args, **kwargs):
|
367 |
+
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
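The size-conditioning helpers above flatten each (height, width) pair, embed every scalar with a 256-dim sinusoidal embedding, and concatenate the original/crop/target embeddings into one vector per sample. A minimal sketch of the resulting shape, assuming only torch and the functions defined above (the sizes are made-up examples):

    import torch

    # hypothetical original/crop/target sizes for a batch of 2 SDXL samples
    orig_size = torch.tensor([[1024, 1024], [832, 1216]], dtype=torch.float32)
    crop_size = torch.tensor([[0, 0], [0, 0]], dtype=torch.float32)
    target_size = torch.tensor([[1024, 1024], [832, 1216]], dtype=torch.float32)

    vector = get_size_embeddings(orig_size, crop_size, target_size, "cpu")
    print(vector.shape)  # torch.Size([2, 1536]) = 3 conditions * 2 dims * 256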
external/llite/library/slicing_vae.py
ADDED
@@ -0,0 +1,679 @@
# Modified from Diffusers to reduce VRAM usage

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn


from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.autoencoder_kl import AutoencoderKLOutput


def slice_h(x, num_slices):
    # slice along H with pad 1 on both sides: to eliminate the side effect of conv2d padding
    # works for both NCHW and NHWC
    size = (x.shape[2] + num_slices - 1) // num_slices
    sliced = []
    for i in range(num_slices):
        if i == 0:
            sliced.append(x[:, :, : size + 1, :])
        else:
            end = size * (i + 1) + 1
            if x.shape[2] - end < 3:  # if the last slice is too small, use the rest of the tensor (conv2d cannot handle a too-thin slice)
                end = x.shape[2]
            sliced.append(x[:, :, size * i - 1 : end, :])
            if end >= x.shape[2]:
                break
    return sliced


def cat_h(sliced):
    # concatenate, dropping the overlapping padded rows
    cat = []
    for i, x in enumerate(sliced):
        if i == 0:
            cat.append(x[:, :, :-1, :])
        elif i == len(sliced) - 1:
            cat.append(x[:, :, 1:, :])
        else:
            cat.append(x[:, :, 1:-1, :])
        del x
    x = torch.cat(cat, dim=2)
    return x


def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
    assert _self.upsample is None and _self.downsample is None
    assert _self.norm1.num_groups == _self.norm2.num_groups
    assert temb is None

    # make sure norms are on cpu
    org_device = input_tensor.device
    cpu_device = torch.device("cpu")
    _self.norm1.to(cpu_device)
    _self.norm2.to(cpu_device)

    # workaround: GroupNorm does not run in fp16 on CPU
    org_dtype = input_tensor.dtype
    if org_dtype == torch.float16:
        _self.norm1.to(torch.float32)
        _self.norm2.to(torch.float32)

    # move all tensors to CPU
    input_tensor = input_tensor.to(cpu_device)
    hidden_states = input_tensor

    # this seems to give different results...
    # def sliced_norm1(norm, x):
    #     num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups
    #     sliced_tensor = torch.chunk(x, num_div, dim=1)
    #     sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
    #     sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
    #     print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
    #     normed_tensor = []
    #     for i in range(num_div):
    #         n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
    #         normed_tensor.append(n)
    #         del n
    #     x = torch.cat(normed_tensor, dim=1)
    #     return num_div, x

    # slicing the norm changes the result, so only the norm is not sliced; it is computed on CPU because it would run out of VRAM on GPU, and fortunately it is not too slow on CPU
    if org_dtype == torch.float16:
        hidden_states = hidden_states.to(torch.float32)
    hidden_states = _self.norm1(hidden_states)  # run on cpu
    if org_dtype == torch.float16:
        hidden_states = hidden_states.to(torch.float16)

    sliced = slice_h(hidden_states, num_slices)
    del hidden_states

    for i in range(len(sliced)):
        x = sliced[i]
        sliced[i] = None

        # move only the part being computed to GPU (same below)
        x = x.to(org_device)
        x = _self.nonlinearity(x)
        x = _self.conv1(x)
        x = x.to(cpu_device)
        sliced[i] = x
        del x

    hidden_states = cat_h(sliced)
    del sliced

    if org_dtype == torch.float16:
        hidden_states = hidden_states.to(torch.float32)
    hidden_states = _self.norm2(hidden_states)  # run on cpu
    if org_dtype == torch.float16:
        hidden_states = hidden_states.to(torch.float16)

    sliced = slice_h(hidden_states, num_slices)
    del hidden_states

    for i in range(len(sliced)):
        x = sliced[i]
        sliced[i] = None

        x = x.to(org_device)
        x = _self.nonlinearity(x)
        x = _self.dropout(x)
        x = _self.conv2(x)
        x = x.to(cpu_device)
        sliced[i] = x
        del x

    hidden_states = cat_h(sliced)
    del sliced

    # make shortcut
    if _self.conv_shortcut is not None:
        sliced = list(torch.chunk(input_tensor, num_slices, dim=2))  # no padding in conv_shortcut, so plain slicing is fine
        del input_tensor

        for i in range(len(sliced)):
            x = sliced[i]
            sliced[i] = None

            x = x.to(org_device)
            x = _self.conv_shortcut(x)
            x = x.to(cpu_device)
            sliced[i] = x
            del x

        input_tensor = torch.cat(sliced, dim=2)
        del sliced

    output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor

    output_tensor = output_tensor.to(org_device)  # the next layer computes on GPU
    return output_tensor


class SlicingEncoder(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        double_z=True,
        num_slices=2,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
        )
        self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)  # use Diffusers' xformers attention for now

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

        # replace forward of ResBlocks
        def wrapper(func, module, num_slices):
            def forward(*args, **kwargs):
                return func(module, num_slices, *args, **kwargs)

            return forward

        self.num_slices = num_slices
        div = num_slices / (2 ** (len(self.down_blocks) - 1))  # deeper layers do not need as many slices, so reduce accordingly
        # print(f"initial divisor: {div}")
        if div >= 2:
            div = int(div)
            for resnet in self.mid_block.resnets:
                resnet.forward = wrapper(resblock_forward, resnet, div)
            # midblock doesn't have downsample

        for i, down_block in enumerate(self.down_blocks[::-1]):
            if div >= 2:
                div = int(div)
                # print(f"down block: {i} divisor: {div}")
                for resnet in down_block.resnets:
                    resnet.forward = wrapper(resblock_forward, resnet, div)
                if down_block.downsamplers is not None:
                    # print("has downsample")
                    for downsample in down_block.downsamplers:
                        downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
            div *= 2

    def forward(self, x):
        sample = x
        del x

        org_device = sample.device
        cpu_device = torch.device("cpu")

        # sample = self.conv_in(sample)
        sample = sample.to(cpu_device)
        sliced = slice_h(sample, self.num_slices)
        del sample

        for i in range(len(sliced)):
            x = sliced[i]
            sliced[i] = None

            x = x.to(org_device)
            x = self.conv_in(x)
            x = x.to(cpu_device)
            sliced[i] = x
            del x

        sample = cat_h(sliced)
        del sliced

        sample = sample.to(org_device)

        # down
        for down_block in self.down_blocks:
            sample = down_block(sample)

        # middle
        sample = self.mid_block(sample)

        # post-process
        # this could also be memory-optimized, but it probably does not use much memory, so it is left as-is
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample

    def downsample_forward(self, _self, num_slices, hidden_states):
        assert hidden_states.shape[1] == _self.channels
        assert _self.use_conv and _self.padding == 0
        print("downsample forward", num_slices, hidden_states.shape)

        org_device = hidden_states.device
        cpu_device = torch.device("cpu")

        hidden_states = hidden_states.to(cpu_device)
        pad = (0, 1, 0, 1)
        hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)

        # slice with even number because of stride 2
        # slice with pad 1 both sides: to eliminate side effect of padding of conv2d
        size = (hidden_states.shape[2] + num_slices - 1) // num_slices
        size = size + 1 if size % 2 == 1 else size

        sliced = []
        for i in range(num_slices):
            if i == 0:
                sliced.append(hidden_states[:, :, : size + 1, :])
            else:
                end = size * (i + 1) + 1
                if hidden_states.shape[2] - end < 4:  # if the last slice is too small, use the rest of the tensor
                    end = hidden_states.shape[2]
                sliced.append(hidden_states[:, :, size * i - 1 : end, :])
                if end >= hidden_states.shape[2]:
                    break
        del hidden_states

        for i in range(len(sliced)):
            x = sliced[i]
            sliced[i] = None

            x = x.to(org_device)
            x = _self.conv(x)
            x = x.to(cpu_device)

            # this part looks different because Copilot wrote it
            if i == 0:
                hidden_states = x
            else:
                hidden_states = torch.cat([hidden_states, x], dim=2)

        hidden_states = hidden_states.to(org_device)
        # print("downsample forward done", hidden_states.shape)
        return hidden_states


class SlicingDecoder(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        num_slices=2,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
        )
        self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)  # use Diffusers' xformers attention for now

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        # replace forward of ResBlocks
        def wrapper(func, module, num_slices):
            def forward(*args, **kwargs):
                return func(module, num_slices, *args, **kwargs)

            return forward

        self.num_slices = num_slices
        div = num_slices / (2 ** (len(self.up_blocks) - 1))
        print(f"initial divisor: {div}")
        if div >= 2:
            div = int(div)
            for resnet in self.mid_block.resnets:
                resnet.forward = wrapper(resblock_forward, resnet, div)
            # midblock doesn't have upsample

        for i, up_block in enumerate(self.up_blocks):
            if div >= 2:
                div = int(div)
                # print(f"up block: {i} divisor: {div}")
                for resnet in up_block.resnets:
                    resnet.forward = wrapper(resblock_forward, resnet, div)
                if up_block.upsamplers is not None:
                    # print("has upsample")
                    for upsample in up_block.upsamplers:
                        upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
            div *= 2

    def forward(self, z):
        sample = z
        del z
        sample = self.conv_in(sample)

        # middle
        sample = self.mid_block(sample)

        # up
        for i, up_block in enumerate(self.up_blocks):
            sample = up_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)

        # conv_out with slicing because of VRAM usage
        org_device = sample.device
        cpu_device = torch.device("cpu")
        sample = sample.to(cpu_device)

        sliced = slice_h(sample, self.num_slices)
        del sample
        for i in range(len(sliced)):
            x = sliced[i]
            sliced[i] = None

            x = x.to(org_device)
            x = self.conv_out(x)
            x = x.to(cpu_device)
            sliced[i] = x
        sample = cat_h(sliced)
        del sliced

        sample = sample.to(org_device)
        return sample

    def upsample_forward(self, _self, num_slices, hidden_states, output_size=None):
        assert hidden_states.shape[1] == _self.channels
        assert _self.use_conv_transpose == False and _self.use_conv

        org_dtype = hidden_states.dtype
        org_device = hidden_states.device
        cpu_device = torch.device("cpu")

        hidden_states = hidden_states.to(cpu_device)
        sliced = slice_h(hidden_states, num_slices)
        del hidden_states

        for i in range(len(sliced)):
            x = sliced[i]
            sliced[i] = None

            x = x.to(org_device)

            # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
            # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
            # https://github.com/pytorch/pytorch/issues/86679
            # hopefully fixed in PyTorch 2...
            if org_dtype == torch.bfloat16:
                x = x.to(torch.float32)

            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")

            if org_dtype == torch.bfloat16:
                x = x.to(org_dtype)

            x = _self.conv(x)

            # after upsampling the overlap padding becomes 2
            if i == 0:
                x = x[:, :, :-2, :]
            elif i == num_slices - 1:
                x = x[:, :, 2:, :]
            else:
                x = x[:, :, 2:-2, :]

            x = x.to(cpu_device)
            sliced[i] = x
            del x

        hidden_states = torch.cat(sliced, dim=2)
        # print("us hidden_states", hidden_states.shape)
        del sliced

        hidden_states = hidden_states.to(org_device)
        return hidden_states


class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
    r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
    and Max Welling.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the model (such as downloading or saving, etc.)

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to :
            obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to :
            obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to :
            obj:`(64,)`): Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): TODO
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        num_slices: int = 16,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = SlicingEncoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            num_slices=num_slices,
        )

        # pass init params to Decoder
        self.decoder = SlicingDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            num_slices=num_slices,
        )

        self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
        self.use_slicing = False

    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        z = self.post_quant_conv(z)
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    # note: this is slicing along the batch dimension, which is (confusingly) different from the H slicing above
    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
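As a rough illustration of how this slicing VAE might replace the regular Diffusers AutoencoderKL, the sketch below constructs it from an existing VAE's config and copies the weights over. This assumes the installed Diffusers version lays out the VAE sub-modules under the same parameter names as SlicingEncoder/SlicingDecoder, that xformers is installed (the mid-block enables it in __init__), and the model id is only an example:

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")  # example source VAE

    sliced_vae = SlicingAutoencoderKL(
        down_block_types=vae.config.down_block_types,
        up_block_types=vae.config.up_block_types,
        block_out_channels=vae.config.block_out_channels,
        layers_per_block=vae.config.layers_per_block,
        act_fn=vae.config.act_fn,
        latent_channels=vae.config.latent_channels,
        num_slices=16,  # more slices -> lower peak VRAM, slower encode/decode
    )
    # assumed: parameter names match, so the pretrained weights transfer directly
    sliced_vae.load_state_dict(vae.state_dict())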
external/llite/library/train_util.py
ADDED
The diff for this file is too large to render.
See raw diff
external/llite/library/utils.py
ADDED
@@ -0,0 +1,6 @@
import threading
from typing import *


def fire_in_thread(f, *args, **kwargs):
    threading.Thread(target=f, args=args, kwargs=kwargs).start()
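fire_in_thread just runs a callable on a background thread and returns immediately; a minimal sketch (the function and message are made up for illustration):

    def log_async(msg):
        print(msg)

    fire_in_thread(log_async, "saved checkpoint")  # returns at once; printing happens on the worker thread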
external/llite/networks/.ipynb_checkpoints/control_net_lllite-checkpoint.py
ADDED
@@ -0,0 +1,446 @@
1 |
+
import os
|
2 |
+
from typing import Optional, List, Type
|
3 |
+
import torch
|
4 |
+
from external.llite.library import sdxl_original_unet
|
5 |
+
|
6 |
+
|
7 |
+
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
8 |
+
SKIP_INPUT_BLOCKS = False
|
9 |
+
|
10 |
+
# output_blocksに適用するかどうか / if True, output_blocks are not applied
|
11 |
+
SKIP_OUTPUT_BLOCKS = True
|
12 |
+
|
13 |
+
# conv2dに適用するかどうか / if True, conv2d are not applied
|
14 |
+
SKIP_CONV2D = False
|
15 |
+
|
16 |
+
# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
|
17 |
+
# if True, only transformer_blocks are applied, and ResBlocks are not applied
|
18 |
+
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
|
19 |
+
|
20 |
+
# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
|
21 |
+
ATTN1_2_ONLY = True
|
22 |
+
|
23 |
+
# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
|
24 |
+
ATTN_QKV_ONLY = True
|
25 |
+
|
26 |
+
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
|
27 |
+
# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
|
28 |
+
ATTN1_ETC_ONLY = False # True
|
29 |
+
|
30 |
+
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
|
31 |
+
# max index of transformer_blocks. if None, apply to all transformer_blocks
|
32 |
+
TRANSFORMER_MAX_BLOCK_INDEX = None
|
33 |
+
|
34 |
+
|
35 |
+
class LLLiteModule(torch.nn.Module):
|
36 |
+
def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0):
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
40 |
+
self.lllite_name = name
|
41 |
+
self.cond_emb_dim = cond_emb_dim
|
42 |
+
self.org_module = [org_module]
|
43 |
+
self.dropout = dropout
|
44 |
+
self.multiplier = multiplier
|
45 |
+
|
46 |
+
if self.is_conv2d:
|
47 |
+
in_dim = org_module.in_channels
|
48 |
+
else:
|
49 |
+
in_dim = org_module.in_features
|
50 |
+
|
51 |
+
# conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
|
52 |
+
# conditioning1 embeds conditioning image. it is not called for each timestep
|
53 |
+
modules = []
|
54 |
+
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
|
55 |
+
if depth == 1:
|
56 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
57 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
58 |
+
elif depth == 2:
|
59 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
60 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
|
61 |
+
elif depth == 3:
|
62 |
+
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
|
63 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
64 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
|
65 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
66 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
67 |
+
|
68 |
+
self.conditioning1 = torch.nn.Sequential(*modules)
|
69 |
+
|
70 |
+
# downで入力の次元数を削減する。LoRAにヒントを得ていることにする
|
71 |
+
# midでconditioning image embeddingと入力を結合する
|
72 |
+
# upで元の次元数に戻す
|
73 |
+
# これらはtimestepごとに呼ばれる
|
74 |
+
# reduce the number of input dimensions with down. inspired by LoRA
|
75 |
+
# combine conditioning image embedding and input with mid
|
76 |
+
# restore to the original dimension with up
|
77 |
+
# these are called for each timestep
|
78 |
+
|
79 |
+
if self.is_conv2d:
|
80 |
+
self.down = torch.nn.Sequential(
|
81 |
+
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
82 |
+
torch.nn.ReLU(inplace=True),
|
83 |
+
)
|
84 |
+
self.mid = torch.nn.Sequential(
|
85 |
+
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
86 |
+
torch.nn.ReLU(inplace=True),
|
87 |
+
)
|
88 |
+
self.up = torch.nn.Sequential(
|
89 |
+
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
|
90 |
+
)
|
91 |
+
else:
|
92 |
+
# midの前にconditioningをreshapeすること / reshape conditioning before mid
|
93 |
+
self.down = torch.nn.Sequential(
|
94 |
+
torch.nn.Linear(in_dim, mlp_dim),
|
95 |
+
torch.nn.ReLU(inplace=True),
|
96 |
+
)
|
97 |
+
self.mid = torch.nn.Sequential(
|
98 |
+
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
|
99 |
+
torch.nn.ReLU(inplace=True),
|
100 |
+
)
|
101 |
+
self.up = torch.nn.Sequential(
|
102 |
+
torch.nn.Linear(mlp_dim, in_dim),
|
103 |
+
)
|
104 |
+
|
105 |
+
# Zero-Convにする / set to Zero-Conv
|
106 |
+
torch.nn.init.zeros_(self.up[0].weight) # zero conv
|
107 |
+
|
108 |
+
self.depth = depth # 1~3
|
109 |
+
self.cond_emb = None
|
110 |
+
self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
|
111 |
+
self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0
|
112 |
+
|
113 |
+
# batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない
|
114 |
+
# Controlの種類によっては使えるかも
|
115 |
+
# both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice
|
116 |
+
# it may be available depending on the type of Control
|
117 |
+
|
118 |
+
def set_cond_image(self, cond_image):
|
119 |
+
r"""
|
120 |
+
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
121 |
+
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
122 |
+
"""
|
123 |
+
if cond_image is None:
|
124 |
+
self.cond_emb = None
|
125 |
+
return
|
126 |
+
|
127 |
+
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
128 |
+
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
129 |
+
cx = self.conditioning1(cond_image)
|
130 |
+
if not self.is_conv2d:
|
131 |
+
# reshape / b,c,h,w -> b,h*w,c
|
132 |
+
n, c, h, w = cx.shape
|
133 |
+
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
134 |
+
self.cond_emb = cx
|
135 |
+
|
136 |
+
def set_batch_cond_only(self, cond_only, zeros):
|
137 |
+
self.batch_cond_only = cond_only
|
138 |
+
self.use_zeros_for_batch_uncond = zeros
|
139 |
+
|
140 |
+
def apply_to(self):
|
141 |
+
self.org_forward = self.org_module[0].forward
|
142 |
+
self.org_module[0].forward = self.forward
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
r"""
|
146 |
+
学習用の便利forward。元のモジュールのforwardを呼び出す
|
147 |
+
/ convenient forward for training. call the forward of the original module
|
148 |
+
"""
|
149 |
+
if self.multiplier == 0.0 or self.cond_emb is None:
|
150 |
+
return self.org_forward(x)
|
151 |
+
|
152 |
+
cx = self.cond_emb
|
153 |
+
|
154 |
+
if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only
|
155 |
+
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
156 |
+
if self.use_zeros_for_batch_uncond:
|
157 |
+
cx[0::2] = 0.0 # uncond is zero
|
158 |
+
# print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
|
159 |
+
|
160 |
+
# downで入力の次元数を削減し、conditioning image embeddingと結合する
|
161 |
+
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
|
162 |
+
# down reduces the number of input dimensions and combines it with conditioning image embedding
|
163 |
+
# we expect that it will mix well by combining in the channel direction instead of adding
|
164 |
+
|
165 |
+
cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2)
|
166 |
+
cx = self.mid(cx)
|
167 |
+
|
168 |
+
if self.dropout is not None and self.training:
|
169 |
+
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
170 |
+
|
171 |
+
cx = self.up(cx) * self.multiplier
|
172 |
+
|
173 |
+
# residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward
|
174 |
+
if self.batch_cond_only:
|
175 |
+
zx = torch.zeros_like(x)
|
176 |
+
zx[1::2] += cx
|
177 |
+
cx = zx
|
178 |
+
|
179 |
+
x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
180 |
+
return x
|
181 |
+
|
182 |
+
|
183 |
+
class ControlNetLLLite(torch.nn.Module):
|
184 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
185 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
186 |
+
|
187 |
+
def __init__(
|
188 |
+
self,
|
189 |
+
unet: sdxl_original_unet.SdxlUNet2DConditionModel,
|
190 |
+
cond_emb_dim: int = 16,
|
191 |
+
mlp_dim: int = 16,
|
192 |
+
dropout: Optional[float] = None,
|
193 |
+
varbose: Optional[bool] = False,
|
194 |
+
multiplier: Optional[float] = 1.0,
|
195 |
+
) -> None:
|
196 |
+
super().__init__()
|
197 |
+
# self.unets = [unet]
|
198 |
+
|
199 |
+
def create_modules(
|
200 |
+
root_module: torch.nn.Module,
|
201 |
+
target_replace_modules: List[torch.nn.Module],
|
202 |
+
module_class: Type[object],
|
203 |
+
) -> List[torch.nn.Module]:
|
204 |
+
prefix = "lllite_unet"
|
205 |
+
|
206 |
+
modules = []
|
207 |
+
for name, module in root_module.named_modules():
|
208 |
+
if module.__class__.__name__ in target_replace_modules:
|
209 |
+
for child_name, child_module in module.named_modules():
|
210 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
211 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
212 |
+
if is_linear or (is_conv2d and not SKIP_CONV2D):
|
213 |
+
# block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
|
214 |
+
# block index to depth: depth is using to calculate conditioning size and channels
|
215 |
+
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
|
216 |
+
index1 = int(index1)
|
217 |
+
if block_name == "input_blocks":
|
218 |
+
if SKIP_INPUT_BLOCKS:
|
219 |
+
continue
|
220 |
+
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
|
221 |
+
elif block_name == "middle_block":
|
222 |
+
depth = 3
|
223 |
+
elif block_name == "output_blocks":
|
224 |
+
if SKIP_OUTPUT_BLOCKS:
|
225 |
+
continue
|
226 |
+
depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
|
227 |
+
if int(index2) >= 2:
|
228 |
+
depth -= 1
|
229 |
+
else:
|
230 |
+
raise NotImplementedError()
|
231 |
+
|
232 |
+
lllite_name = prefix + "." + name + "." + child_name
|
233 |
+
lllite_name = lllite_name.replace(".", "_")
|
234 |
+
|
235 |
+
if TRANSFORMER_MAX_BLOCK_INDEX is not None:
|
236 |
+
p = lllite_name.find("transformer_blocks")
|
237 |
+
if p >= 0:
|
238 |
+
tf_index = int(lllite_name[p:].split("_")[2])
|
239 |
+
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
|
240 |
+
continue
|
241 |
+
|
242 |
+
# time embは適用外とする
|
243 |
+
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
|
244 |
+
# time emb is not applied
|
245 |
+
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
|
246 |
+
if "emb_layers" in lllite_name or (
|
247 |
+
"attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
|
248 |
+
):
|
249 |
+
continue
|
250 |
+
|
251 |
+
if ATTN1_2_ONLY:
|
252 |
+
if not ("attn1" in lllite_name or "attn2" in lllite_name):
|
253 |
+
continue
|
254 |
+
if ATTN_QKV_ONLY:
|
255 |
+
if "to_out" in lllite_name:
|
256 |
+
continue
|
257 |
+
|
258 |
+
if ATTN1_ETC_ONLY:
|
259 |
+
if "proj_out" in lllite_name:
|
260 |
+
pass
|
261 |
+
elif "attn1" in lllite_name and (
|
262 |
+
"to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
|
263 |
+
):
|
264 |
+
pass
|
265 |
+
elif "ff_net_2" in lllite_name:
|
266 |
+
pass
|
267 |
+
else:
|
268 |
+
continue
|
269 |
+
|
270 |
+
module = module_class(
|
271 |
+
depth,
|
272 |
+
cond_emb_dim,
|
273 |
+
lllite_name,
|
274 |
+
child_module,
|
275 |
+
mlp_dim,
|
276 |
+
dropout=dropout,
|
277 |
+
multiplier=multiplier,
|
278 |
+
)
|
279 |
+
modules.append(module)
|
280 |
+
print(f"Returning {len(modules)} modules for llite net")
|
281 |
+
return modules
|
282 |
+
|
283 |
+
target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE
|
284 |
+
if not TRANSFORMER_ONLY:
|
285 |
+
target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
286 |
+
|
287 |
+
# create module instances
|
288 |
+
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
|
289 |
+
print(f"created ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
|
290 |
+
|
291 |
+
def forward(self, x):
|
292 |
+
return x # dummy
|
293 |
+
|
294 |
+
def set_cond_image(self, cond_image):
|
295 |
+
r"""
|
296 |
+
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
297 |
+
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
298 |
+
"""
|
299 |
+
for module in self.unet_modules:
|
300 |
+
module.set_cond_image(cond_image)
|
301 |
+
|
302 |
+
def set_batch_cond_only(self, cond_only, zeros):
|
303 |
+
for module in self.unet_modules:
|
304 |
+
module.set_batch_cond_only(cond_only, zeros)
|
305 |
+
|
306 |
+
def set_multiplier(self, multiplier):
|
307 |
+
for module in self.unet_modules:
|
308 |
+
module.multiplier = multiplier
|
309 |
+
|
310 |
+
def load_weights(self, file):
|
311 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
312 |
+
from safetensors.torch import load_file
|
313 |
+
|
314 |
+
weights_sd = load_file(file)
|
315 |
+
else:
|
316 |
+
weights_sd = torch.load(file, map_location="cpu")
|
317 |
+
|
318 |
+
info = self.load_state_dict(weights_sd, False)
|
319 |
+
return info
|
320 |
+
|
321 |
+
def apply_to(self):
|
322 |
+
print("applying LLLite for U-Net...")
|
323 |
+
for module in self.unet_modules:
|
324 |
+
module.apply_to()
|
325 |
+
self.add_module(module.lllite_name, module)
|
326 |
+
|
327 |
+
# マージできるかどうかを返す
|
328 |
+
def is_mergeable(self):
|
329 |
+
return False
|
330 |
+
|
331 |
+
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
332 |
+
raise NotImplementedError()
|
333 |
+
|
334 |
+
    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_optimizer_params(self):
        self.requires_grad_(True)
        return self.parameters()

    def prepare_grad_etc(self):
        self.requires_grad_(True)

    def on_epoch_start(self):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)


if __name__ == "__main__":
    # デバッグ用 / for debug

    # sdxl_original_unet.USE_REENTRANT = False

    # test shape etc
    print("create unet")
    unet = sdxl_original_unet.SdxlUNet2DConditionModel()
    unet.to("cuda").to(torch.float16)

    print("create ControlNet-LLLite")
    control_net = ControlNetLLLite(unet, 32, 64)
    control_net.apply_to()
    control_net.to("cuda")

    print(control_net)

    # print number of parameters
    print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))

    input()

    unet.set_use_memory_efficient_attention(True, False)
    unet.set_gradient_checkpointing(True)
    unet.train()  # for gradient checkpointing

    control_net.train()

    # # visualize
    # import torchviz
    # print("run visualize")
    # controlnet.set_control(conditioning_image)
    # output = unet(x, t, ctx, y)
    # print("make_dot")
    # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
    # print("render")
    # image.format = "svg"  # "png"
    # image.render("NeuralNet")  # すごく時間がかかるので注意 / be careful because it takes a long time
    # input()

    import bitsandbytes

    optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3)

    scaler = torch.cuda.amp.GradScaler(enabled=True)

    print("start training")
    steps = 10

    sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
    for step in range(steps):
        print(f"step {step}")

        batch_size = 1
        conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
        x = torch.randn(batch_size, 4, 128, 128).cuda()
        t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
        ctx = torch.randn(batch_size, 77, 2048).cuda()
        y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()

        with torch.cuda.amp.autocast(enabled=True):
            control_net.set_cond_image(conditioning_image)

            output = unet(x, t, ctx, y)
            target = torch.randn_like(output)
            loss = torch.nn.functional.mse_loss(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        print(sample_param)

    # from safetensors.torch import save_file

    # save_file(control_net.state_dict(), "logs/control_net.safetensors")
external/llite/networks/check_lora_weights.py
ADDED
@@ -0,0 +1,45 @@
import argparse
import os
import torch
from safetensors.torch import load_file


def main(file):
    print(f"loading: {file}")
    if os.path.splitext(file)[1] == ".safetensors":
        sd = load_file(file)
    else:
        sd = torch.load(file, map_location="cpu")

    values = []

    keys = list(sd.keys())
    for key in keys:
        if "lora_up" in key or "lora_down" in key:
            values.append((key, sd[key]))
    print(f"number of LoRA modules: {len(values)}")

    if args.show_all_keys:
        for key in [k for k in keys if k not in values]:
            values.append((key, sd[key]))
        print(f"number of all modules: {len(values)}")

    for key, value in values:
        value = value.to(torch.float32)
        print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
    parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する")

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()

    main(args.file)
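check_lora_weights.py prints one comma-separated row per LoRA tensor: key, shape, mean absolute value, and minimum absolute value. A related quick check is sketched below, under the assumption that a LoRA file named my_lora.safetensors exists: it pairs each lora_down/lora_up tensor and reports the magnitude of their product, a rough proxy for the weight delta the LoRA applies (alpha scaling is ignored here).

import torch
from safetensors.torch import load_file

sd = load_file("my_lora.safetensors")  # hypothetical file name
for down_key in [k for k in sd.keys() if k.endswith(".lora_down.weight")]:
    up_key = down_key.replace(".lora_down.weight", ".lora_up.weight")
    if up_key not in sd:
        continue
    down = sd[down_key].to(torch.float32).flatten(1)  # (rank, in_dim [* kh * kw])
    up = sd[up_key].to(torch.float32).flatten(1)      # (out_dim, rank)
    delta = up @ down                                  # flattened approximation of the applied delta-W
    name = down_key.replace(".lora_down.weight", "")
    print(f"{name}: rank={down.shape[0]}, mean|dW|={delta.abs().mean().item():.6f}")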
external/llite/networks/control_net_lllite.py
ADDED
@@ -0,0 +1,446 @@
1 |
+
import os
|
2 |
+
from typing import Optional, List, Type
|
3 |
+
import torch
|
4 |
+
from external.llite.library import sdxl_original_unet
|
5 |
+
|
6 |
+
|
7 |
+
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
8 |
+
SKIP_INPUT_BLOCKS = False
|
9 |
+
|
10 |
+
# output_blocksに適用するかどうか / if True, output_blocks are not applied
|
11 |
+
SKIP_OUTPUT_BLOCKS = True
|
12 |
+
|
13 |
+
# conv2dに適用するかどうか / if True, conv2d are not applied
|
14 |
+
SKIP_CONV2D = False
|
15 |
+
|
16 |
+
# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
|
17 |
+
# if True, only transformer_blocks are applied, and ResBlocks are not applied
|
18 |
+
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
|
19 |
+
|
20 |
+
# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
|
21 |
+
ATTN1_2_ONLY = True
|
22 |
+
|
23 |
+
# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
|
24 |
+
ATTN_QKV_ONLY = True
|
25 |
+
|
26 |
+
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
|
27 |
+
# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
|
28 |
+
ATTN1_ETC_ONLY = False # True
|
29 |
+
|
30 |
+
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
|
31 |
+
# max index of transformer_blocks. if None, apply to all transformer_blocks
|
32 |
+
TRANSFORMER_MAX_BLOCK_INDEX = None
|
33 |
+
|
34 |
+
|
35 |
+
class LLLiteModule(torch.nn.Module):
|
36 |
+
def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0):
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
40 |
+
self.lllite_name = name
|
41 |
+
self.cond_emb_dim = cond_emb_dim
|
42 |
+
self.org_module = [org_module]
|
43 |
+
self.dropout = dropout
|
44 |
+
self.multiplier = multiplier
|
45 |
+
|
46 |
+
if self.is_conv2d:
|
47 |
+
in_dim = org_module.in_channels
|
48 |
+
else:
|
49 |
+
in_dim = org_module.in_features
|
50 |
+
|
51 |
+
# conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
|
52 |
+
# conditioning1 embeds conditioning image. it is not called for each timestep
|
53 |
+
modules = []
|
54 |
+
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
|
55 |
+
if depth == 1:
|
56 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
57 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
58 |
+
elif depth == 2:
|
59 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
60 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
|
61 |
+
elif depth == 3:
|
62 |
+
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
|
63 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
64 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
|
65 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
66 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
67 |
+
|
68 |
+
self.conditioning1 = torch.nn.Sequential(*modules)
|
69 |
+
|
70 |
+
# downで入力の次元数を削減する。LoRAにヒントを得ていることにする
|
71 |
+
# midでconditioning image embeddingと入力を結合する
|
72 |
+
# upで元の次元数に戻す
|
73 |
+
# これらはtimestepごとに呼ばれる
|
74 |
+
# reduce the number of input dimensions with down. inspired by LoRA
|
75 |
+
# combine conditioning image embedding and input with mid
|
76 |
+
# restore to the original dimension with up
|
77 |
+
# these are called for each timestep
|
78 |
+
|
79 |
+
if self.is_conv2d:
|
80 |
+
self.down = torch.nn.Sequential(
|
81 |
+
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
82 |
+
torch.nn.ReLU(inplace=True),
|
83 |
+
)
|
84 |
+
self.mid = torch.nn.Sequential(
|
85 |
+
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
86 |
+
torch.nn.ReLU(inplace=True),
|
87 |
+
)
|
88 |
+
self.up = torch.nn.Sequential(
|
89 |
+
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
|
90 |
+
)
|
91 |
+
else:
|
92 |
+
# midの前にconditioningをreshapeすること / reshape conditioning before mid
|
93 |
+
self.down = torch.nn.Sequential(
|
94 |
+
torch.nn.Linear(in_dim, mlp_dim),
|
95 |
+
torch.nn.ReLU(inplace=True),
|
96 |
+
)
|
97 |
+
self.mid = torch.nn.Sequential(
|
98 |
+
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
|
99 |
+
torch.nn.ReLU(inplace=True),
|
100 |
+
)
|
101 |
+
self.up = torch.nn.Sequential(
|
102 |
+
torch.nn.Linear(mlp_dim, in_dim),
|
103 |
+
)
|
104 |
+
|
105 |
+
# Zero-Convにする / set to Zero-Conv
|
106 |
+
torch.nn.init.zeros_(self.up[0].weight) # zero conv
|
107 |
+
|
108 |
+
self.depth = depth # 1~3
|
109 |
+
self.cond_emb = None
|
110 |
+
self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
|
111 |
+
self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0
|
112 |
+
|
113 |
+
# batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない
|
114 |
+
# Controlの種類によっては使えるかも
|
115 |
+
# both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice
|
116 |
+
# it may be available depending on the type of Control
|
117 |
+
|
118 |
+
def set_cond_image(self, cond_image):
|
119 |
+
r"""
|
120 |
+
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
121 |
+
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
122 |
+
"""
|
123 |
+
if cond_image is None:
|
124 |
+
self.cond_emb = None
|
125 |
+
return
|
126 |
+
|
127 |
+
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
128 |
+
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
129 |
+
cx = self.conditioning1(cond_image)
|
130 |
+
if not self.is_conv2d:
|
131 |
+
# reshape / b,c,h,w -> b,h*w,c
|
132 |
+
n, c, h, w = cx.shape
|
133 |
+
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
134 |
+
self.cond_emb = cx
|
135 |
+
|
136 |
+
def set_batch_cond_only(self, cond_only, zeros):
|
137 |
+
self.batch_cond_only = cond_only
|
138 |
+
self.use_zeros_for_batch_uncond = zeros
|
139 |
+
|
140 |
+
def apply_to(self):
|
141 |
+
self.org_forward = self.org_module[0].forward
|
142 |
+
self.org_module[0].forward = self.forward
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
r"""
|
146 |
+
学習用の便利forward。元のモジュールのforwardを呼び出す
|
147 |
+
/ convenient forward for training. call the forward of the original module
|
148 |
+
"""
|
149 |
+
if self.multiplier == 0.0 or self.cond_emb is None:
|
150 |
+
return self.org_forward(x)
|
151 |
+
|
152 |
+
cx = self.cond_emb
|
153 |
+
|
154 |
+
if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only
|
155 |
+
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
156 |
+
if self.use_zeros_for_batch_uncond:
|
157 |
+
cx[0::2] = 0.0 # uncond is zero
|
158 |
+
# print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
|
159 |
+
|
160 |
+
# downで入力の次元数を削減し、conditioning image embeddingと結合する
|
161 |
+
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
|
162 |
+
# down reduces the number of input dimensions and combines it with conditioning image embedding
|
163 |
+
# we expect that it will mix well by combining in the channel direction instead of adding
|
164 |
+
|
165 |
+
cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2)
|
166 |
+
cx = self.mid(cx)
|
167 |
+
|
168 |
+
if self.dropout is not None and self.training:
|
169 |
+
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
170 |
+
|
171 |
+
cx = self.up(cx) * self.multiplier
|
172 |
+
|
173 |
+
# residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward
|
174 |
+
if self.batch_cond_only:
|
175 |
+
zx = torch.zeros_like(x)
|
176 |
+
zx[1::2] += cx
|
177 |
+
cx = zx
|
178 |
+
|
179 |
+
x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
180 |
+
return x
|
181 |
+
|
182 |
+
|
183 |
+
class ControlNetLLLite(torch.nn.Module):
|
184 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
185 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
186 |
+
|
187 |
+
def __init__(
|
188 |
+
self,
|
189 |
+
unet: sdxl_original_unet.SdxlUNet2DConditionModel,
|
190 |
+
cond_emb_dim: int = 16,
|
191 |
+
mlp_dim: int = 16,
|
192 |
+
dropout: Optional[float] = None,
|
193 |
+
varbose: Optional[bool] = False,
|
194 |
+
multiplier: Optional[float] = 1.0,
|
195 |
+
) -> None:
|
196 |
+
super().__init__()
|
197 |
+
# self.unets = [unet]
|
198 |
+
|
199 |
+
def create_modules(
|
200 |
+
root_module: torch.nn.Module,
|
201 |
+
target_replace_modules: List[torch.nn.Module],
|
202 |
+
module_class: Type[object],
|
203 |
+
) -> List[torch.nn.Module]:
|
204 |
+
prefix = "lllite_unet"
|
205 |
+
|
206 |
+
modules = []
|
207 |
+
for name, module in root_module.named_modules():
|
208 |
+
if module.__class__.__name__ in target_replace_modules:
|
209 |
+
for child_name, child_module in module.named_modules():
|
210 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
211 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
212 |
+
if is_linear or (is_conv2d and not SKIP_CONV2D):
|
213 |
+
# block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
|
214 |
+
# block index to depth: depth is using to calculate conditioning size and channels
|
215 |
+
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
|
216 |
+
index1 = int(index1)
|
217 |
+
if block_name == "input_blocks":
|
218 |
+
if SKIP_INPUT_BLOCKS:
|
219 |
+
continue
|
220 |
+
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
|
221 |
+
elif block_name == "middle_block":
|
222 |
+
depth = 3
|
223 |
+
elif block_name == "output_blocks":
|
224 |
+
if SKIP_OUTPUT_BLOCKS:
|
225 |
+
continue
|
226 |
+
depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
|
227 |
+
if int(index2) >= 2:
|
228 |
+
depth -= 1
|
229 |
+
else:
|
230 |
+
raise NotImplementedError()
|
231 |
+
|
232 |
+
lllite_name = prefix + "." + name + "." + child_name
|
233 |
+
lllite_name = lllite_name.replace(".", "_")
|
234 |
+
|
235 |
+
if TRANSFORMER_MAX_BLOCK_INDEX is not None:
|
236 |
+
p = lllite_name.find("transformer_blocks")
|
237 |
+
if p >= 0:
|
238 |
+
tf_index = int(lllite_name[p:].split("_")[2])
|
239 |
+
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
|
240 |
+
continue
|
241 |
+
|
242 |
+
# time embは適用外とする
|
243 |
+
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
|
244 |
+
# time emb is not applied
|
245 |
+
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
|
246 |
+
if "emb_layers" in lllite_name or (
|
247 |
+
"attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
|
248 |
+
):
|
249 |
+
continue
|
250 |
+
|
251 |
+
if ATTN1_2_ONLY:
|
252 |
+
if not ("attn1" in lllite_name or "attn2" in lllite_name):
|
253 |
+
continue
|
254 |
+
if ATTN_QKV_ONLY:
|
255 |
+
if "to_out" in lllite_name:
|
256 |
+
continue
|
257 |
+
|
258 |
+
if ATTN1_ETC_ONLY:
|
259 |
+
if "proj_out" in lllite_name:
|
260 |
+
pass
|
261 |
+
elif "attn1" in lllite_name and (
|
262 |
+
"to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
|
263 |
+
):
|
264 |
+
pass
|
265 |
+
elif "ff_net_2" in lllite_name:
|
266 |
+
pass
|
267 |
+
else:
|
268 |
+
continue
|
269 |
+
|
270 |
+
module = module_class(
|
271 |
+
depth,
|
272 |
+
cond_emb_dim,
|
273 |
+
lllite_name,
|
274 |
+
child_module,
|
275 |
+
mlp_dim,
|
276 |
+
dropout=dropout,
|
277 |
+
multiplier=multiplier,
|
278 |
+
)
|
279 |
+
modules.append(module)
|
280 |
+
print(f"Returning {len(modules)} modules for llite net")
|
281 |
+
return modules
|
282 |
+
|
283 |
+
target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE
|
284 |
+
if not TRANSFORMER_ONLY:
|
285 |
+
target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
286 |
+
|
287 |
+
# create module instances
|
288 |
+
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
|
289 |
+
print(f"created ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
|
290 |
+
|
291 |
+
def forward(self, x):
|
292 |
+
return x # dummy
|
293 |
+
|
294 |
+
def set_cond_image(self, cond_image):
|
295 |
+
r"""
|
296 |
+
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
297 |
+
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
298 |
+
"""
|
299 |
+
for module in self.unet_modules:
|
300 |
+
module.set_cond_image(cond_image)
|
301 |
+
|
302 |
+
def set_batch_cond_only(self, cond_only, zeros):
|
303 |
+
for module in self.unet_modules:
|
304 |
+
module.set_batch_cond_only(cond_only, zeros)
|
305 |
+
|
306 |
+
def set_multiplier(self, multiplier):
|
307 |
+
for module in self.unet_modules:
|
308 |
+
module.multiplier = multiplier
|
309 |
+
|
310 |
+
def load_weights(self, file):
|
311 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
312 |
+
from safetensors.torch import load_file
|
313 |
+
|
314 |
+
weights_sd = load_file(file)
|
315 |
+
else:
|
316 |
+
weights_sd = torch.load(file, map_location="cpu")
|
317 |
+
|
318 |
+
info = self.load_state_dict(weights_sd, False)
|
319 |
+
return info
|
320 |
+
|
321 |
+
def apply_to(self):
|
322 |
+
print("applying LLLite for U-Net...")
|
323 |
+
for module in self.unet_modules:
|
324 |
+
module.apply_to()
|
325 |
+
self.add_module(module.lllite_name, module)
|
326 |
+
|
327 |
+
# マージできるかどうかを返す
|
328 |
+
def is_mergeable(self):
|
329 |
+
return False
|
330 |
+
|
331 |
+
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
332 |
+
raise NotImplementedError()
|
333 |
+
|
334 |
+
def enable_gradient_checkpointing(self):
|
335 |
+
# not supported
|
336 |
+
pass
|
337 |
+
|
338 |
+
def prepare_optimizer_params(self):
|
339 |
+
self.requires_grad_(True)
|
340 |
+
return self.parameters()
|
341 |
+
|
342 |
+
def prepare_grad_etc(self):
|
343 |
+
self.requires_grad_(True)
|
344 |
+
|
345 |
+
def on_epoch_start(self):
|
346 |
+
self.train()
|
347 |
+
|
348 |
+
def get_trainable_params(self):
|
349 |
+
return self.parameters()
|
350 |
+
|
351 |
+
def save_weights(self, file, dtype, metadata):
|
352 |
+
if metadata is not None and len(metadata) == 0:
|
353 |
+
metadata = None
|
354 |
+
|
355 |
+
state_dict = self.state_dict()
|
356 |
+
|
357 |
+
if dtype is not None:
|
358 |
+
for key in list(state_dict.keys()):
|
359 |
+
v = state_dict[key]
|
360 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
361 |
+
state_dict[key] = v
|
362 |
+
|
363 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
364 |
+
from safetensors.torch import save_file
|
365 |
+
|
366 |
+
save_file(state_dict, file, metadata)
|
367 |
+
else:
|
368 |
+
torch.save(state_dict, file)
|
369 |
+
|
370 |
+
|
371 |
+
if __name__ == "__main__":
|
372 |
+
# デバッグ用 / for debug
|
373 |
+
|
374 |
+
# sdxl_original_unet.USE_REENTRANT = False
|
375 |
+
|
376 |
+
# test shape etc
|
377 |
+
print("create unet")
|
378 |
+
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
379 |
+
unet.to("cuda").to(torch.float16)
|
380 |
+
|
381 |
+
print("create ControlNet-LLLite")
|
382 |
+
control_net = ControlNetLLLite(unet, 32, 64)
|
383 |
+
control_net.apply_to()
|
384 |
+
control_net.to("cuda")
|
385 |
+
|
386 |
+
print(control_net)
|
387 |
+
|
388 |
+
# print number of parameters
|
389 |
+
print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
|
390 |
+
|
391 |
+
input()
|
392 |
+
|
393 |
+
unet.set_use_memory_efficient_attention(True, False)
|
394 |
+
unet.set_gradient_checkpointing(True)
|
395 |
+
unet.train() # for gradient checkpointing
|
396 |
+
|
397 |
+
control_net.train()
|
398 |
+
|
399 |
+
# # visualize
|
400 |
+
# import torchviz
|
401 |
+
# print("run visualize")
|
402 |
+
# controlnet.set_control(conditioning_image)
|
403 |
+
# output = unet(x, t, ctx, y)
|
404 |
+
# print("make_dot")
|
405 |
+
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
406 |
+
# print("render")
|
407 |
+
# image.format = "svg" # "png"
|
408 |
+
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
409 |
+
# input()
|
410 |
+
|
411 |
+
import bitsandbytes
|
412 |
+
|
413 |
+
optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3)
|
414 |
+
|
415 |
+
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
416 |
+
|
417 |
+
print("start training")
|
418 |
+
steps = 10
|
419 |
+
|
420 |
+
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
|
421 |
+
for step in range(steps):
|
422 |
+
print(f"step {step}")
|
423 |
+
|
424 |
+
batch_size = 1
|
425 |
+
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
|
426 |
+
x = torch.randn(batch_size, 4, 128, 128).cuda()
|
427 |
+
t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
|
428 |
+
ctx = torch.randn(batch_size, 77, 2048).cuda()
|
429 |
+
y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
|
430 |
+
|
431 |
+
with torch.cuda.amp.autocast(enabled=True):
|
432 |
+
control_net.set_cond_image(conditioning_image)
|
433 |
+
|
434 |
+
output = unet(x, t, ctx, y)
|
435 |
+
target = torch.randn_like(output)
|
436 |
+
loss = torch.nn.functional.mse_loss(output, target)
|
437 |
+
|
438 |
+
scaler.scale(loss).backward()
|
439 |
+
scaler.step(optimizer)
|
440 |
+
scaler.update()
|
441 |
+
optimizer.zero_grad(set_to_none=True)
|
442 |
+
print(sample_param)
|
443 |
+
|
444 |
+
# from safetensors.torch import save_file
|
445 |
+
|
446 |
+
# save_file(control_net.state_dict(), "logs/control_net.safetensors")
|
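In create_modules above, the conditioning depth of each LLLite module is derived from the first components of its module name (input_blocks / middle_block / output_blocks plus the block indices); the depth then decides how far conditioning1 downsamples the control image. The standalone sketch below mirrors that rule (ignoring the SKIP_* flags) on a few hypothetical module names so it is easy to verify.

def depth_from_name(name: str) -> int:
    # mirrors the depth rule in create_modules; the names passed in below are hypothetical examples
    block_name, index1, index2 = name.split(".")[:3]
    index1 = int(index1)
    if block_name == "input_blocks":
        depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
    elif block_name == "middle_block":
        depth = 3
    elif block_name == "output_blocks":
        depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
        if int(index2) >= 2:
            depth -= 1
    else:
        raise NotImplementedError(block_name)
    return depth


print(depth_from_name("input_blocks.4.1.transformer_blocks.0.attn1.to_q"))   # 2
print(depth_from_name("middle_block.1.transformer_blocks.0.attn1.to_q"))     # 3
print(depth_from_name("output_blocks.3.1.transformer_blocks.0.attn1.to_q"))  # 2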
external/llite/networks/control_net_lllite_for_train.py
ADDED
@@ -0,0 +1,502 @@
1 |
+
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用実装
|
2 |
+
# ControlNet-LLLite implementation for verification with cond_image passed in U-Net's forward
|
3 |
+
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
from typing import Optional, List, Type
|
7 |
+
import torch
|
8 |
+
from library import sdxl_original_unet
|
9 |
+
|
10 |
+
|
11 |
+
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
12 |
+
SKIP_INPUT_BLOCKS = False
|
13 |
+
|
14 |
+
# output_blocksに適用するかどうか / if True, output_blocks are not applied
|
15 |
+
SKIP_OUTPUT_BLOCKS = True
|
16 |
+
|
17 |
+
# conv2dに適用するかどうか / if True, conv2d are not applied
|
18 |
+
SKIP_CONV2D = False
|
19 |
+
|
20 |
+
# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
|
21 |
+
# if True, only transformer_blocks are applied, and ResBlocks are not applied
|
22 |
+
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
|
23 |
+
|
24 |
+
# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
|
25 |
+
ATTN1_2_ONLY = True
|
26 |
+
|
27 |
+
# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
|
28 |
+
ATTN_QKV_ONLY = True
|
29 |
+
|
30 |
+
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
|
31 |
+
# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
|
32 |
+
ATTN1_ETC_ONLY = False # True
|
33 |
+
|
34 |
+
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
|
35 |
+
# max index of transformer_blocks. if None, apply to all transformer_blocks
|
36 |
+
TRANSFORMER_MAX_BLOCK_INDEX = None
|
37 |
+
|
38 |
+
ORIGINAL_LINEAR = torch.nn.Linear
|
39 |
+
ORIGINAL_CONV2D = torch.nn.Conv2d
|
40 |
+
|
41 |
+
|
42 |
+
def add_lllite_modules(module: torch.nn.Module, in_dim: int, depth, cond_emb_dim, mlp_dim) -> None:
|
43 |
+
# conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
|
44 |
+
# conditioning1 embeds conditioning image. it is not called for each timestep
|
45 |
+
modules = []
|
46 |
+
modules.append(ORIGINAL_CONV2D(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
|
47 |
+
if depth == 1:
|
48 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
49 |
+
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
50 |
+
elif depth == 2:
|
51 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
52 |
+
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
|
53 |
+
elif depth == 3:
|
54 |
+
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
|
55 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
56 |
+
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
|
57 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
58 |
+
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
59 |
+
|
60 |
+
module.lllite_conditioning1 = torch.nn.Sequential(*modules)
|
61 |
+
|
62 |
+
# downで入力の次元数を削減する。LoRAにヒントを得ていることにする
|
63 |
+
# midでconditioning image embeddingと入力を結合する
|
64 |
+
# upで元の次元数に戻す
|
65 |
+
# これらはtimestepごとに呼ばれる
|
66 |
+
# reduce the number of input dimensions with down. inspired by LoRA
|
67 |
+
# combine conditioning image embedding and input with mid
|
68 |
+
# restore to the original dimension with up
|
69 |
+
# these are called for each timestep
|
70 |
+
|
71 |
+
module.lllite_down = torch.nn.Sequential(
|
72 |
+
ORIGINAL_LINEAR(in_dim, mlp_dim),
|
73 |
+
torch.nn.ReLU(inplace=True),
|
74 |
+
)
|
75 |
+
module.lllite_mid = torch.nn.Sequential(
|
76 |
+
ORIGINAL_LINEAR(mlp_dim + cond_emb_dim, mlp_dim),
|
77 |
+
torch.nn.ReLU(inplace=True),
|
78 |
+
)
|
79 |
+
module.lllite_up = torch.nn.Sequential(
|
80 |
+
ORIGINAL_LINEAR(mlp_dim, in_dim),
|
81 |
+
)
|
82 |
+
|
83 |
+
# Zero-Convにする / set to Zero-Conv
|
84 |
+
torch.nn.init.zeros_(module.lllite_up[0].weight) # zero conv
|
85 |
+
|
86 |
+
|
87 |
+
class LLLiteLinear(ORIGINAL_LINEAR):
|
88 |
+
def __init__(self, in_features: int, out_features: int, **kwargs):
|
89 |
+
super().__init__(in_features, out_features, **kwargs)
|
90 |
+
self.enabled = False
|
91 |
+
|
92 |
+
def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
|
93 |
+
self.enabled = True
|
94 |
+
self.lllite_name = name
|
95 |
+
self.cond_emb_dim = cond_emb_dim
|
96 |
+
self.dropout = dropout
|
97 |
+
self.multiplier = multiplier # ignored
|
98 |
+
|
99 |
+
in_dim = self.in_features
|
100 |
+
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
|
101 |
+
|
102 |
+
self.cond_image = None
|
103 |
+
self.cond_emb = None
|
104 |
+
|
105 |
+
def set_cond_image(self, cond_image):
|
106 |
+
self.cond_image = cond_image
|
107 |
+
self.cond_emb = None
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
if not self.enabled:
|
111 |
+
return super().forward(x)
|
112 |
+
|
113 |
+
if self.cond_emb is None:
|
114 |
+
self.cond_emb = self.lllite_conditioning1(self.cond_image)
|
115 |
+
cx = self.cond_emb
|
116 |
+
|
117 |
+
# reshape / b,c,h,w -> b,h*w,c
|
118 |
+
n, c, h, w = cx.shape
|
119 |
+
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
120 |
+
|
121 |
+
cx = torch.cat([cx, self.lllite_down(x)], dim=2)
|
122 |
+
cx = self.lllite_mid(cx)
|
123 |
+
|
124 |
+
if self.dropout is not None and self.training:
|
125 |
+
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
126 |
+
|
127 |
+
cx = self.lllite_up(cx) * self.multiplier
|
128 |
+
|
129 |
+
x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
130 |
+
return x
|
131 |
+
|
132 |
+
|
133 |
+
class LLLiteConv2d(ORIGINAL_CONV2D):
|
134 |
+
def __init__(self, in_channels: int, out_channels: int, kernel_size, **kwargs):
|
135 |
+
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
|
136 |
+
self.enabled = False
|
137 |
+
|
138 |
+
def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
|
139 |
+
self.enabled = True
|
140 |
+
self.lllite_name = name
|
141 |
+
self.cond_emb_dim = cond_emb_dim
|
142 |
+
self.dropout = dropout
|
143 |
+
self.multiplier = multiplier # ignored
|
144 |
+
|
145 |
+
in_dim = self.in_channels
|
146 |
+
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
|
147 |
+
|
148 |
+
self.cond_image = None
|
149 |
+
self.cond_emb = None
|
150 |
+
|
151 |
+
def set_cond_image(self, cond_image):
|
152 |
+
self.cond_image = cond_image
|
153 |
+
self.cond_emb = None
|
154 |
+
|
155 |
+
def forward(self, x): # , cond_image=None):
|
156 |
+
if not self.enabled:
|
157 |
+
return super().forward(x)
|
158 |
+
|
159 |
+
if self.cond_emb is None:
|
160 |
+
self.cond_emb = self.lllite_conditioning1(self.cond_image)
|
161 |
+
cx = self.cond_emb
|
162 |
+
|
163 |
+
        cx = torch.cat([cx, self.lllite_down(x)], dim=1)
        cx = self.lllite_mid(cx)

        if self.dropout is not None and self.training:
            cx = torch.nn.functional.dropout(cx, p=self.dropout)

        cx = self.lllite_up(cx) * self.multiplier
+
|
171 |
+
x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
172 |
+
return x
|
173 |
+
|
174 |
+
|
175 |
+
class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DConditionModel):
|
176 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
177 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
178 |
+
LLLITE_PREFIX = "lllite_unet"
|
179 |
+
|
180 |
+
def __init__(self, **kwargs):
|
181 |
+
super().__init__(**kwargs)
|
182 |
+
|
183 |
+
def apply_lllite(
|
184 |
+
self,
|
185 |
+
cond_emb_dim: int = 16,
|
186 |
+
mlp_dim: int = 16,
|
187 |
+
dropout: Optional[float] = None,
|
188 |
+
varbose: Optional[bool] = False,
|
189 |
+
multiplier: Optional[float] = 1.0,
|
190 |
+
) -> None:
|
191 |
+
def apply_to_modules(
|
192 |
+
root_module: torch.nn.Module,
|
193 |
+
target_replace_modules: List[torch.nn.Module],
|
194 |
+
) -> List[torch.nn.Module]:
|
195 |
+
prefix = "lllite_unet"
|
196 |
+
|
197 |
+
modules = []
|
198 |
+
for name, module in root_module.named_modules():
|
199 |
+
if module.__class__.__name__ in target_replace_modules:
|
200 |
+
for child_name, child_module in module.named_modules():
|
201 |
+
is_linear = child_module.__class__.__name__ == "LLLiteLinear"
|
202 |
+
is_conv2d = child_module.__class__.__name__ == "LLLiteConv2d"
|
203 |
+
|
204 |
+
if is_linear or (is_conv2d and not SKIP_CONV2D):
|
205 |
+
# block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
|
206 |
+
# block index to depth: depth is using to calculate conditioning size and channels
|
207 |
+
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
|
208 |
+
index1 = int(index1)
|
209 |
+
if block_name == "input_blocks":
|
210 |
+
if SKIP_INPUT_BLOCKS:
|
211 |
+
continue
|
212 |
+
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
|
213 |
+
elif block_name == "middle_block":
|
214 |
+
depth = 3
|
215 |
+
elif block_name == "output_blocks":
|
216 |
+
if SKIP_OUTPUT_BLOCKS:
|
217 |
+
continue
|
218 |
+
depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
|
219 |
+
if int(index2) >= 2:
|
220 |
+
depth -= 1
|
221 |
+
else:
|
222 |
+
raise NotImplementedError()
|
223 |
+
|
224 |
+
lllite_name = prefix + "." + name + "." + child_name
|
225 |
+
lllite_name = lllite_name.replace(".", "_")
|
226 |
+
|
227 |
+
if TRANSFORMER_MAX_BLOCK_INDEX is not None:
|
228 |
+
p = lllite_name.find("transformer_blocks")
|
229 |
+
if p >= 0:
|
230 |
+
tf_index = int(lllite_name[p:].split("_")[2])
|
231 |
+
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
|
232 |
+
continue
|
233 |
+
|
234 |
+
# time embは適用外とする
|
235 |
+
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
|
236 |
+
# time emb is not applied
|
237 |
+
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
|
238 |
+
if "emb_layers" in lllite_name or (
|
239 |
+
"attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
|
240 |
+
):
|
241 |
+
continue
|
242 |
+
|
243 |
+
if ATTN1_2_ONLY:
|
244 |
+
if not ("attn1" in lllite_name or "attn2" in lllite_name):
|
245 |
+
continue
|
246 |
+
if ATTN_QKV_ONLY:
|
247 |
+
if "to_out" in lllite_name:
|
248 |
+
continue
|
249 |
+
|
250 |
+
if ATTN1_ETC_ONLY:
|
251 |
+
if "proj_out" in lllite_name:
|
252 |
+
pass
|
253 |
+
elif "attn1" in lllite_name and (
|
254 |
+
"to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
|
255 |
+
):
|
256 |
+
pass
|
257 |
+
elif "ff_net_2" in lllite_name:
|
258 |
+
pass
|
259 |
+
else:
|
260 |
+
continue
|
261 |
+
|
262 |
+
child_module.set_lllite(depth, cond_emb_dim, lllite_name, mlp_dim, dropout, multiplier)
|
263 |
+
modules.append(child_module)
|
264 |
+
|
265 |
+
return modules
|
266 |
+
|
267 |
+
target_modules = SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE
|
268 |
+
if not TRANSFORMER_ONLY:
|
269 |
+
target_modules = target_modules + SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
270 |
+
|
271 |
+
# create module instances
|
272 |
+
self.lllite_modules = apply_to_modules(self, target_modules)
|
273 |
+
print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
|
274 |
+
|
275 |
+
# def prepare_optimizer_params(self):
|
276 |
+
def prepare_params(self):
|
277 |
+
train_params = []
|
278 |
+
non_train_params = []
|
279 |
+
for name, p in self.named_parameters():
|
280 |
+
if "lllite" in name:
|
281 |
+
train_params.append(p)
|
282 |
+
else:
|
283 |
+
non_train_params.append(p)
|
284 |
+
print(f"count of trainable parameters: {len(train_params)}")
|
285 |
+
print(f"count of non-trainable parameters: {len(non_train_params)}")
|
286 |
+
|
287 |
+
for p in non_train_params:
|
288 |
+
p.requires_grad_(False)
|
289 |
+
|
290 |
+
# without this, an error occurs in the optimizer
|
291 |
+
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
|
292 |
+
non_train_params[0].requires_grad_(True)
|
293 |
+
|
294 |
+
for p in train_params:
|
295 |
+
p.requires_grad_(True)
|
296 |
+
|
297 |
+
return train_params
|
298 |
+
|
299 |
+
# def prepare_grad_etc(self):
|
300 |
+
# self.requires_grad_(True)
|
301 |
+
|
302 |
+
# def on_epoch_start(self):
|
303 |
+
# self.train()
|
304 |
+
|
305 |
+
def get_trainable_params(self):
|
306 |
+
return [p[1] for p in self.named_parameters() if "lllite" in p[0]]
|
307 |
+
|
308 |
+
def save_lllite_weights(self, file, dtype, metadata):
|
309 |
+
if metadata is not None and len(metadata) == 0:
|
310 |
+
metadata = None
|
311 |
+
|
312 |
+
org_state_dict = self.state_dict()
|
313 |
+
|
314 |
+
# copy LLLite keys from org_state_dict to state_dict with key conversion
|
315 |
+
state_dict = {}
|
316 |
+
for key in org_state_dict.keys():
|
317 |
+
# split with ".lllite"
|
318 |
+
pos = key.find(".lllite")
|
319 |
+
if pos < 0:
|
320 |
+
continue
|
321 |
+
lllite_key = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "." + key[:pos]
|
322 |
+
lllite_key = lllite_key.replace(".", "_") + key[pos:]
|
323 |
+
lllite_key = lllite_key.replace(".lllite_", ".")
|
324 |
+
state_dict[lllite_key] = org_state_dict[key]
|
325 |
+
|
326 |
+
if dtype is not None:
|
327 |
+
for key in list(state_dict.keys()):
|
328 |
+
v = state_dict[key]
|
329 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
330 |
+
state_dict[key] = v
|
331 |
+
|
332 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
333 |
+
from safetensors.torch import save_file
|
334 |
+
|
335 |
+
save_file(state_dict, file, metadata)
|
336 |
+
else:
|
337 |
+
torch.save(state_dict, file)
|
338 |
+
|
339 |
+
def load_lllite_weights(self, file, non_lllite_unet_sd=None):
|
340 |
+
r"""
|
341 |
+
LLLiteの重みを読み込まない(initされた値を使う)場合はfileにNoneを指定する。
|
342 |
+
この場合、non_lllite_unet_sdにはU-Netのstate_dictを指定する。
|
343 |
+
|
344 |
+
If you do not want to load LLLite weights (use initialized values), specify None for file.
|
345 |
+
In this case, specify the state_dict of U-Net for non_lllite_unet_sd.
|
346 |
+
"""
|
347 |
+
if not file:
|
348 |
+
state_dict = self.state_dict()
|
349 |
+
for key in non_lllite_unet_sd:
|
350 |
+
if key in state_dict:
|
351 |
+
state_dict[key] = non_lllite_unet_sd[key]
|
352 |
+
info = self.load_state_dict(state_dict, False)
|
353 |
+
return info
|
354 |
+
|
355 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
356 |
+
from safetensors.torch import load_file
|
357 |
+
|
358 |
+
weights_sd = load_file(file)
|
359 |
+
else:
|
360 |
+
weights_sd = torch.load(file, map_location="cpu")
|
361 |
+
|
362 |
+
# module_name = module_name.replace("_block", "@blocks")
|
363 |
+
# module_name = module_name.replace("_layer", "@layer")
|
364 |
+
# module_name = module_name.replace("to_", "to@")
|
365 |
+
# module_name = module_name.replace("time_embed", "time@embed")
|
366 |
+
# module_name = module_name.replace("label_emb", "label@emb")
|
367 |
+
# module_name = module_name.replace("skip_connection", "skip@connection")
|
368 |
+
# module_name = module_name.replace("proj_in", "proj@in")
|
369 |
+
# module_name = module_name.replace("proj_out", "proj@out")
|
370 |
+
pattern = re.compile(r"(_block|_layer|to_|time_embed|label_emb|skip_connection|proj_in|proj_out)")
|
371 |
+
|
372 |
+
# convert to lllite with U-Net state dict
|
373 |
+
state_dict = non_lllite_unet_sd.copy() if non_lllite_unet_sd is not None else {}
|
374 |
+
for key in weights_sd.keys():
|
375 |
+
# split with "."
|
376 |
+
pos = key.find(".")
|
377 |
+
if pos < 0:
|
378 |
+
continue
|
379 |
+
|
380 |
+
module_name = key[:pos]
|
381 |
+
weight_name = key[pos + 1 :] # exclude "."
|
382 |
+
module_name = module_name.replace(SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "_", "")
|
383 |
+
|
384 |
+
# これはうまくいかない。逆変換を考えなかった設計が悪い / this does not work well. bad design because I didn't think about inverse conversion
|
385 |
+
# module_name = module_name.replace("_", ".")
|
386 |
+
|
387 |
+
# ださいけどSDXLのU-Netの "_" を "@" に変換する / ugly but convert "_" of SDXL U-Net to "@"
|
388 |
+
matches = pattern.findall(module_name)
|
389 |
+
if matches is not None:
|
390 |
+
for m in matches:
|
391 |
+
print(module_name, m)
|
392 |
+
module_name = module_name.replace(m, m.replace("_", "@"))
|
393 |
+
module_name = module_name.replace("_", ".")
|
394 |
+
module_name = module_name.replace("@", "_")
|
395 |
+
|
396 |
+
lllite_key = module_name + ".lllite_" + weight_name
|
397 |
+
|
398 |
+
state_dict[lllite_key] = weights_sd[key]
|
399 |
+
|
400 |
+
info = self.load_state_dict(state_dict, False)
|
401 |
+
return info
|
402 |
+
|
403 |
+
def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kwargs):
|
404 |
+
for m in self.lllite_modules:
|
405 |
+
m.set_cond_image(cond_image)
|
406 |
+
return super().forward(x, timesteps, context, y, **kwargs)
|
407 |
+
|
408 |
+
|
409 |
+
def replace_unet_linear_and_conv2d():
|
410 |
+
print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
|
411 |
+
sdxl_original_unet.torch.nn.Linear = LLLiteLinear
|
412 |
+
sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
|
413 |
+
|
414 |
+
|
415 |
+
if __name__ == "__main__":
|
416 |
+
# デバッグ用 / for debug
|
417 |
+
|
418 |
+
# sdxl_original_unet.USE_REENTRANT = False
|
419 |
+
replace_unet_linear_and_conv2d()
|
420 |
+
|
421 |
+
# test shape etc
|
422 |
+
print("create unet")
|
423 |
+
unet = SdxlUNet2DConditionModelControlNetLLLite()
|
424 |
+
|
425 |
+
print("enable ControlNet-LLLite")
|
426 |
+
unet.apply_lllite(32, 64, None, False, 1.0)
|
427 |
+
unet.to("cuda") # .to(torch.float16)
|
428 |
+
|
429 |
+
# from safetensors.torch import load_file
|
430 |
+
|
431 |
+
# model_sd = load_file(r"E:\Work\SD\Models\sdxl\sd_xl_base_1.0_0.9vae.safetensors")
|
432 |
+
# unet_sd = {}
|
433 |
+
|
434 |
+
# # copy U-Net keys from unet_state_dict to state_dict
|
435 |
+
# prefix = "model.diffusion_model."
|
436 |
+
# for key in model_sd.keys():
|
437 |
+
# if key.startswith(prefix):
|
438 |
+
# converted_key = key[len(prefix) :]
|
439 |
+
# unet_sd[converted_key] = model_sd[key]
|
440 |
+
|
441 |
+
# info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
|
442 |
+
# print(info)
|
443 |
+
|
444 |
+
# print(unet)
|
445 |
+
|
446 |
+
# print number of parameters
|
447 |
+
params = unet.prepare_params()
|
448 |
+
print("number of parameters", sum(p.numel() for p in params))
|
449 |
+
# print("type any key to continue")
|
450 |
+
# input()
|
451 |
+
|
452 |
+
unet.set_use_memory_efficient_attention(True, False)
|
453 |
+
unet.set_gradient_checkpointing(True)
|
454 |
+
unet.train() # for gradient checkpointing
|
455 |
+
|
456 |
+
# # visualize
|
457 |
+
# import torchviz
|
458 |
+
# print("run visualize")
|
459 |
+
# controlnet.set_control(conditioning_image)
|
460 |
+
# output = unet(x, t, ctx, y)
|
461 |
+
# print("make_dot")
|
462 |
+
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
463 |
+
# print("render")
|
464 |
+
# image.format = "svg" # "png"
|
465 |
+
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
466 |
+
# input()
|
467 |
+
|
468 |
+
import bitsandbytes
|
469 |
+
|
470 |
+
optimizer = bitsandbytes.adam.Adam8bit(params, 1e-3)
|
471 |
+
|
472 |
+
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
473 |
+
|
474 |
+
print("start training")
|
475 |
+
steps = 10
|
476 |
+
batch_size = 1
|
477 |
+
|
478 |
+
sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
|
479 |
+
for step in range(steps):
|
480 |
+
print(f"step {step}")
|
481 |
+
|
482 |
+
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
|
483 |
+
x = torch.randn(batch_size, 4, 128, 128).cuda()
|
484 |
+
t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
|
485 |
+
ctx = torch.randn(batch_size, 77, 2048).cuda()
|
486 |
+
y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
|
487 |
+
|
488 |
+
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
|
489 |
+
output = unet(x, t, ctx, y, conditioning_image)
|
490 |
+
target = torch.randn_like(output)
|
491 |
+
loss = torch.nn.functional.mse_loss(output, target)
|
492 |
+
|
493 |
+
scaler.scale(loss).backward()
|
494 |
+
scaler.step(optimizer)
|
495 |
+
scaler.update()
|
496 |
+
optimizer.zero_grad(set_to_none=True)
|
497 |
+
print(sample_param)
|
498 |
+
|
499 |
+
# from safetensors.torch import save_file
|
500 |
+
|
501 |
+
# print("save weights")
|
502 |
+
# unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)
|
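save_lllite_weights above renames state-dict keys so that checkpoints trained with this in-place variant keep the same naming as the hook-based ControlNetLLLite: the module path is prefixed with lllite_unet and flattened with underscores, and the lllite_ marker is dropped from the submodule part. The sketch below replicates that conversion for a single hypothetical key.

LLLITE_PREFIX = "lllite_unet"

def convert_key(key: str) -> str:
    # replicates the key conversion performed in save_lllite_weights
    pos = key.find(".lllite")
    if pos < 0:
        raise ValueError("not an LLLite key")
    out = LLLITE_PREFIX + "." + key[:pos]    # prepend the prefix to the module path
    out = out.replace(".", "_") + key[pos:]  # flatten the module path with underscores
    out = out.replace(".lllite_", ".")       # drop the lllite_ marker from the submodule name
    return out

key = "input_blocks.4.1.transformer_blocks.0.attn1.to_q.lllite_up.0.weight"
print(convert_key(key))
# lllite_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_q.up.0.weight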
external/llite/networks/dylora.py
ADDED
@@ -0,0 +1,450 @@
1 |
+
# some codes are copied from:
|
2 |
+
# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/
|
3 |
+
|
4 |
+
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
|
5 |
+
# Changes made to the original code:
|
6 |
+
# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer
|
7 |
+
# ------------------------------------------------------------------------------------------
|
8 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
9 |
+
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
10 |
+
# ------------------------------------------------------------------------------------------
|
11 |
+
|
12 |
+
import math
|
13 |
+
import os
|
14 |
+
import random
|
15 |
+
from typing import List, Tuple, Union
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
|
19 |
+
|
20 |
+
class DyLoRAModule(torch.nn.Module):
|
21 |
+
"""
|
22 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
23 |
+
"""
|
24 |
+
|
25 |
+
# NOTE: support dropout in future
|
26 |
+
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
|
27 |
+
super().__init__()
|
28 |
+
self.lora_name = lora_name
|
29 |
+
self.lora_dim = lora_dim
|
30 |
+
self.unit = unit
|
31 |
+
assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"
|
32 |
+
|
33 |
+
if org_module.__class__.__name__ == "Conv2d":
|
34 |
+
in_dim = org_module.in_channels
|
35 |
+
out_dim = org_module.out_channels
|
36 |
+
else:
|
37 |
+
in_dim = org_module.in_features
|
38 |
+
out_dim = org_module.out_features
|
39 |
+
|
40 |
+
if type(alpha) == torch.Tensor:
|
41 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
42 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
43 |
+
self.scale = alpha / self.lora_dim
|
44 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
45 |
+
|
46 |
+
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
47 |
+
self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)
|
48 |
+
|
49 |
+
if self.is_conv2d and self.is_conv2d_3x3:
|
50 |
+
kernel_size = org_module.kernel_size
|
51 |
+
self.stride = org_module.stride
|
52 |
+
self.padding = org_module.padding
|
53 |
+
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
|
54 |
+
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
|
55 |
+
else:
|
56 |
+
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
|
57 |
+
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])
|
58 |
+
|
59 |
+
# same as microsoft's
|
60 |
+
for lora in self.lora_A:
|
61 |
+
torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
|
62 |
+
for lora in self.lora_B:
|
63 |
+
torch.nn.init.zeros_(lora)
|
64 |
+
|
65 |
+
self.multiplier = multiplier
|
66 |
+
self.org_module = org_module # remove in applying
|
67 |
+
|
68 |
+
def apply_to(self):
|
69 |
+
self.org_forward = self.org_module.forward
|
70 |
+
self.org_module.forward = self.forward
|
71 |
+
del self.org_module
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
result = self.org_forward(x)
|
75 |
+
|
76 |
+
# specify the dynamic rank
|
77 |
+
trainable_rank = random.randint(0, self.lora_dim - 1)
|
78 |
+
trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit
|
79 |
+
|
80 |
+
# 一部のパラメータを固定して、残りのパラメータを学習する
|
81 |
+
for i in range(0, trainable_rank):
|
82 |
+
self.lora_A[i].requires_grad = False
|
83 |
+
self.lora_B[i].requires_grad = False
|
84 |
+
for i in range(trainable_rank, trainable_rank + self.unit):
|
85 |
+
self.lora_A[i].requires_grad = True
|
86 |
+
self.lora_B[i].requires_grad = True
|
87 |
+
for i in range(trainable_rank + self.unit, self.lora_dim):
|
88 |
+
self.lora_A[i].requires_grad = False
|
89 |
+
self.lora_B[i].requires_grad = False
|
90 |
+
|
91 |
+
lora_A = torch.cat(tuple(self.lora_A), dim=0)
|
92 |
+
lora_B = torch.cat(tuple(self.lora_B), dim=1)
|
93 |
+
|
94 |
+
# calculate with lora_A and lora_B
|
95 |
+
if self.is_conv2d_3x3:
|
96 |
+
ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
|
97 |
+
ab = torch.nn.functional.conv2d(ab, lora_B)
|
98 |
+
else:
|
99 |
+
ab = x
|
100 |
+
if self.is_conv2d:
|
101 |
+
ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
|
102 |
+
|
103 |
+
ab = torch.nn.functional.linear(ab, lora_A)
|
104 |
+
ab = torch.nn.functional.linear(ab, lora_B)
|
105 |
+
|
106 |
+
if self.is_conv2d:
|
107 |
+
ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) # (N, H*W, C) -> (N, C, H, W)
|
108 |
+
|
109 |
+
# 最後の項は、低rankをより大きくするためのスケーリング(じゃないかな)
|
110 |
+
result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
|
111 |
+
|
112 |
+
# NOTE weightに加算してからlinear/conv2dを呼んだほうが��いかも
|
113 |
+
return result
|
114 |
+
|
115 |
+
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
116 |
+
# state dictを通常のLoRAと同じにする:
|
117 |
+
# nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える
|
118 |
+
sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
119 |
+
|
120 |
+
lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
|
121 |
+
if self.is_conv2d and not self.is_conv2d_3x3:
|
122 |
+
lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
|
123 |
+
|
124 |
+
lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
|
125 |
+
if self.is_conv2d and not self.is_conv2d_3x3:
|
126 |
+
lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
|
127 |
+
|
128 |
+
sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
|
129 |
+
sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
|
130 |
+
|
131 |
+
i = 0
|
132 |
+
while True:
|
133 |
+
key_a = f"{self.lora_name}.lora_A.{i}"
|
134 |
+
key_b = f"{self.lora_name}.lora_B.{i}"
|
135 |
+
if key_a in sd:
|
136 |
+
sd.pop(key_a)
|
137 |
+
sd.pop(key_b)
|
138 |
+
else:
|
139 |
+
break
|
140 |
+
i += 1
|
141 |
+
return sd
|
142 |
+
|
143 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
144 |
+
# 通常のLoRAと同じstate dictを読み込めるようにする:この方法はchatGPTに聞いた
|
145 |
+
lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
|
146 |
+
lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
|
147 |
+
|
148 |
+
if lora_A_weight is None or lora_B_weight is None:
|
149 |
+
if strict:
|
150 |
+
raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
|
151 |
+
else:
|
152 |
+
return
|
153 |
+
|
154 |
+
if self.is_conv2d and not self.is_conv2d_3x3:
|
155 |
+
lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
|
156 |
+
lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
|
157 |
+
|
158 |
+
state_dict.update(
|
159 |
+
{f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
|
160 |
+
)
|
161 |
+
state_dict.update(
|
162 |
+
{f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
|
163 |
+
)
|
164 |
+
|
165 |
+
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
166 |
+
|
167 |
+
|
168 |
+
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
169 |
+
if network_dim is None:
|
170 |
+
network_dim = 4 # default
|
171 |
+
if network_alpha is None:
|
172 |
+
network_alpha = 1.0
|
173 |
+
|
174 |
+
# extract dim/alpha for conv2d, and block dim
|
175 |
+
conv_dim = kwargs.get("conv_dim", None)
|
176 |
+
conv_alpha = kwargs.get("conv_alpha", None)
|
177 |
+
unit = kwargs.get("unit", None)
|
178 |
+
if conv_dim is not None:
|
179 |
+
conv_dim = int(conv_dim)
|
180 |
+
assert conv_dim == network_dim, "conv_dim must be same as network_dim"
|
181 |
+
if conv_alpha is None:
|
182 |
+
conv_alpha = 1.0
|
183 |
+
else:
|
184 |
+
conv_alpha = float(conv_alpha)
|
185 |
+
if unit is not None:
|
186 |
+
unit = int(unit)
|
187 |
+
else:
|
188 |
+
unit = 1
|
189 |
+
|
190 |
+
network = DyLoRANetwork(
|
191 |
+
text_encoder,
|
192 |
+
unet,
|
193 |
+
multiplier=multiplier,
|
194 |
+
lora_dim=network_dim,
|
195 |
+
alpha=network_alpha,
|
196 |
+
apply_to_conv=conv_dim is not None,
|
197 |
+
unit=unit,
|
198 |
+
varbose=True,
|
199 |
+
)
|
200 |
+
return network
|
201 |
+
|
202 |
+
|
203 |
+
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
204 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
205 |
+
if weights_sd is None:
|
206 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
207 |
+
from safetensors.torch import load_file, safe_open
|
208 |
+
|
209 |
+
weights_sd = load_file(file)
|
210 |
+
else:
|
211 |
+
weights_sd = torch.load(file, map_location="cpu")
|
212 |
+
|
213 |
+
# get dim/alpha mapping
|
214 |
+
modules_dim = {}
|
215 |
+
modules_alpha = {}
|
216 |
+
for key, value in weights_sd.items():
|
217 |
+
if "." not in key:
|
218 |
+
continue
|
219 |
+
|
220 |
+
lora_name = key.split(".")[0]
|
221 |
+
if "alpha" in key:
|
222 |
+
modules_alpha[lora_name] = value
|
223 |
+
elif "lora_down" in key:
|
224 |
+
dim = value.size()[0]
|
225 |
+
modules_dim[lora_name] = dim
|
226 |
+
# print(lora_name, value.size(), dim)
|
227 |
+
|
228 |
+
# support old LoRA without alpha
|
229 |
+
for key in modules_dim.keys():
|
230 |
+
if key not in modules_alpha:
|
231 |
+
modules_alpha = modules_dim[key]
|
232 |
+
|
233 |
+
module_class = DyLoRAModule
|
234 |
+
|
235 |
+
network = DyLoRANetwork(
|
236 |
+
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
237 |
+
)
|
238 |
+
return network, weights_sd
|
239 |
+
|
240 |
+
|
241 |
+
class DyLoRANetwork(torch.nn.Module):
|
242 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
243 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
244 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
245 |
+
LORA_PREFIX_UNET = "lora_unet"
|
246 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
247 |
+
|
248 |
+
def __init__(
|
249 |
+
self,
|
250 |
+
text_encoder,
|
251 |
+
unet,
|
252 |
+
multiplier=1.0,
|
253 |
+
lora_dim=4,
|
254 |
+
alpha=1,
|
255 |
+
apply_to_conv=False,
|
256 |
+
modules_dim=None,
|
257 |
+
modules_alpha=None,
|
258 |
+
unit=1,
|
259 |
+
module_class=DyLoRAModule,
|
260 |
+
varbose=False,
|
261 |
+
) -> None:
|
262 |
+
super().__init__()
|
263 |
+
self.multiplier = multiplier
|
264 |
+
|
265 |
+
self.lora_dim = lora_dim
|
266 |
+
self.alpha = alpha
|
267 |
+
self.apply_to_conv = apply_to_conv
|
268 |
+
|
269 |
+
if modules_dim is not None:
|
270 |
+
print(f"create LoRA network from weights")
|
271 |
+
else:
|
272 |
+
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
|
273 |
+
if self.apply_to_conv:
|
274 |
+
print(f"apply LoRA to Conv2d with kernel size (3,3).")
|
275 |
+
|
276 |
+
# create module instances
|
277 |
+
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
|
278 |
+
prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
279 |
+
loras = []
|
280 |
+
for name, module in root_module.named_modules():
|
281 |
+
if module.__class__.__name__ in target_replace_modules:
|
282 |
+
for child_name, child_module in module.named_modules():
|
283 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
284 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
285 |
+
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
286 |
+
|
287 |
+
if is_linear or is_conv2d:
|
288 |
+
lora_name = prefix + "." + name + "." + child_name
|
289 |
+
lora_name = lora_name.replace(".", "_")
|
290 |
+
|
291 |
+
dim = None
|
292 |
+
alpha = None
|
293 |
+
if modules_dim is not None:
|
294 |
+
if lora_name in modules_dim:
|
295 |
+
dim = modules_dim[lora_name]
|
296 |
+
alpha = modules_alpha[lora_name]
|
297 |
+
else:
|
298 |
+
if is_linear or is_conv2d_1x1 or apply_to_conv:
|
299 |
+
dim = self.lora_dim
|
300 |
+
alpha = self.alpha
|
301 |
+
|
302 |
+
if dim is None or dim == 0:
|
303 |
+
continue
|
304 |
+
|
305 |
+
# dropout and fan_in_fan_out is default
|
306 |
+
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
|
307 |
+
loras.append(lora)
|
308 |
+
return loras
|
309 |
+
|
310 |
+
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
311 |
+
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
312 |
+
|
313 |
+
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
314 |
+
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
|
315 |
+
if modules_dim is not None or self.apply_to_conv:
|
316 |
+
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
317 |
+
|
318 |
+
self.unet_loras = create_modules(True, unet, target_modules)
|
319 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
320 |
+
|
321 |
+
def set_multiplier(self, multiplier):
|
322 |
+
self.multiplier = multiplier
|
323 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
324 |
+
lora.multiplier = self.multiplier
|
325 |
+
|
326 |
+
def load_weights(self, file):
|
327 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
328 |
+
from safetensors.torch import load_file
|
329 |
+
|
330 |
+
weights_sd = load_file(file)
|
331 |
+
else:
|
332 |
+
weights_sd = torch.load(file, map_location="cpu")
|
333 |
+
|
334 |
+
info = self.load_state_dict(weights_sd, False)
|
335 |
+
return info
|
336 |
+
|
337 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
338 |
+
if apply_text_encoder:
|
339 |
+
print("enable LoRA for text encoder")
|
340 |
+
else:
|
341 |
+
self.text_encoder_loras = []
|
342 |
+
|
343 |
+
if apply_unet:
|
344 |
+
print("enable LoRA for U-Net")
|
345 |
+
else:
|
346 |
+
self.unet_loras = []
|
347 |
+
|
348 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
349 |
+
lora.apply_to()
|
350 |
+
self.add_module(lora.lora_name, lora)
|
351 |
+
|
352 |
+
"""
|
353 |
+
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
354 |
+
apply_text_encoder = apply_unet = False
|
355 |
+
for key in weights_sd.keys():
|
356 |
+
if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
357 |
+
apply_text_encoder = True
|
358 |
+
elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET):
|
359 |
+
apply_unet = True
|
360 |
+
|
361 |
+
if apply_text_encoder:
|
362 |
+
print("enable LoRA for text encoder")
|
363 |
+
else:
|
364 |
+
self.text_encoder_loras = []
|
365 |
+
|
366 |
+
if apply_unet:
|
367 |
+
print("enable LoRA for U-Net")
|
368 |
+
else:
|
369 |
+
self.unet_loras = []
|
370 |
+
|
371 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
372 |
+
sd_for_lora = {}
|
373 |
+
for key in weights_sd.keys():
|
374 |
+
if key.startswith(lora.lora_name):
|
375 |
+
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
376 |
+
lora.merge_to(sd_for_lora, dtype, device)
|
377 |
+
|
378 |
+
print(f"weights are merged")
|
379 |
+
"""
|
380 |
+
|
381 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
382 |
+
self.requires_grad_(True)
|
383 |
+
all_params = []
|
384 |
+
|
385 |
+
def enumerate_params(loras):
|
386 |
+
params = []
|
387 |
+
for lora in loras:
|
388 |
+
params.extend(lora.parameters())
|
389 |
+
return params
|
390 |
+
|
391 |
+
if self.text_encoder_loras:
|
392 |
+
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
393 |
+
if text_encoder_lr is not None:
|
394 |
+
param_data["lr"] = text_encoder_lr
|
395 |
+
all_params.append(param_data)
|
396 |
+
|
397 |
+
if self.unet_loras:
|
398 |
+
param_data = {"params": enumerate_params(self.unet_loras)}
|
399 |
+
if unet_lr is not None:
|
400 |
+
param_data["lr"] = unet_lr
|
401 |
+
all_params.append(param_data)
|
402 |
+
|
403 |
+
return all_params
|
404 |
+
|
405 |
+
def enable_gradient_checkpointing(self):
|
406 |
+
# not supported
|
407 |
+
pass
|
408 |
+
|
409 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
410 |
+
self.requires_grad_(True)
|
411 |
+
|
412 |
+
def on_epoch_start(self, text_encoder, unet):
|
413 |
+
self.train()
|
414 |
+
|
415 |
+
def get_trainable_params(self):
|
416 |
+
return self.parameters()
|
417 |
+
|
418 |
+
def save_weights(self, file, dtype, metadata):
|
419 |
+
if metadata is not None and len(metadata) == 0:
|
420 |
+
metadata = None
|
421 |
+
|
422 |
+
state_dict = self.state_dict()
|
423 |
+
|
424 |
+
if dtype is not None:
|
425 |
+
for key in list(state_dict.keys()):
|
426 |
+
v = state_dict[key]
|
427 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
428 |
+
state_dict[key] = v
|
429 |
+
|
430 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
431 |
+
from safetensors.torch import save_file
|
432 |
+
from library import train_util
|
433 |
+
|
434 |
+
# Precalculate model hashes to save time on indexing
|
435 |
+
if metadata is None:
|
436 |
+
metadata = {}
|
437 |
+
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
438 |
+
metadata["sshs_model_hash"] = model_hash
|
439 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
440 |
+
|
441 |
+
save_file(state_dict, file, metadata)
|
442 |
+
else:
|
443 |
+
torch.save(state_dict, file)
|
444 |
+
|
445 |
+
# mask is a tensor with values from 0 to 1
|
446 |
+
def set_region(self, sub_prompt_index, is_last_network, mask):
|
447 |
+
pass
|
448 |
+
|
449 |
+
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
450 |
+
pass
|
external/llite/networks/extract_lora_from_dylora.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
2 |
+
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
3 |
+
# Thanks to cloneofsimo
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
import torch
|
9 |
+
from safetensors.torch import load_file, save_file, safe_open
|
10 |
+
from tqdm import tqdm
|
11 |
+
from library import train_util, model_util
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
|
15 |
+
def load_state_dict(file_name):
|
16 |
+
if model_util.is_safetensors(file_name):
|
17 |
+
sd = load_file(file_name)
|
18 |
+
with safe_open(file_name, framework="pt") as f:
|
19 |
+
metadata = f.metadata()
|
20 |
+
else:
|
21 |
+
sd = torch.load(file_name, map_location="cpu")
|
22 |
+
metadata = None
|
23 |
+
|
24 |
+
return sd, metadata
|
25 |
+
|
26 |
+
|
27 |
+
def save_to_file(file_name, model, metadata):
|
28 |
+
if model_util.is_safetensors(file_name):
|
29 |
+
save_file(model, file_name, metadata)
|
30 |
+
else:
|
31 |
+
torch.save(model, file_name)
|
32 |
+
|
33 |
+
|
34 |
+
def split_lora_model(lora_sd, unit):
|
35 |
+
max_rank = 0
|
36 |
+
|
37 |
+
# Extract loaded lora dim and alpha
|
38 |
+
for key, value in lora_sd.items():
|
39 |
+
if "lora_down" in key:
|
40 |
+
rank = value.size()[0]
|
41 |
+
if rank > max_rank:
|
42 |
+
max_rank = rank
|
43 |
+
print(f"Max rank: {max_rank}")
|
44 |
+
|
45 |
+
rank = unit
|
46 |
+
split_models = []
|
47 |
+
new_alpha = None
|
48 |
+
while rank < max_rank:
|
49 |
+
print(f"Splitting rank {rank}")
|
50 |
+
new_sd = {}
|
51 |
+
for key, value in lora_sd.items():
|
52 |
+
if "lora_down" in key:
|
53 |
+
new_sd[key] = value[:rank].contiguous()
|
54 |
+
elif "lora_up" in key:
|
55 |
+
new_sd[key] = value[:, :rank].contiguous()
|
56 |
+
else:
|
57 |
+
# なぜかscaleするとおかしくなる……
|
58 |
+
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
|
59 |
+
# scale = math.sqrt(this_rank / rank) # rank is > unit
|
60 |
+
# print(key, value.size(), this_rank, rank, value, scale)
|
61 |
+
# new_alpha = value * scale # always same
|
62 |
+
# new_sd[key] = new_alpha
|
63 |
+
new_sd[key] = value
|
64 |
+
|
65 |
+
split_models.append((new_sd, rank, new_alpha))
|
66 |
+
rank += unit
|
67 |
+
|
68 |
+
return max_rank, split_models
|
69 |
+
|
70 |
+
|
71 |
+
def split(args):
|
72 |
+
print("loading Model...")
|
73 |
+
lora_sd, metadata = load_state_dict(args.model)
|
74 |
+
|
75 |
+
print("Splitting Model...")
|
76 |
+
original_rank, split_models = split_lora_model(lora_sd, args.unit)
|
77 |
+
|
78 |
+
comment = metadata.get("ss_training_comment", "")
|
79 |
+
for state_dict, new_rank, new_alpha in split_models:
|
80 |
+
# update metadata
|
81 |
+
if metadata is None:
|
82 |
+
new_metadata = {}
|
83 |
+
else:
|
84 |
+
new_metadata = metadata.copy()
|
85 |
+
|
86 |
+
new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
|
87 |
+
new_metadata["ss_network_dim"] = str(new_rank)
|
88 |
+
# new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())
|
89 |
+
|
90 |
+
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
91 |
+
metadata["sshs_model_hash"] = model_hash
|
92 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
93 |
+
|
94 |
+
filename, ext = os.path.splitext(args.save_to)
|
95 |
+
model_file_name = filename + f"-{new_rank:04d}{ext}"
|
96 |
+
|
97 |
+
print(f"saving model to: {model_file_name}")
|
98 |
+
save_to_file(model_file_name, state_dict, new_metadata)
|
99 |
+
|
100 |
+
|
101 |
+
def setup_parser() -> argparse.ArgumentParser:
|
102 |
+
parser = argparse.ArgumentParser()
|
103 |
+
|
104 |
+
parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ")
|
105 |
+
parser.add_argument(
|
106 |
+
"--save_to",
|
107 |
+
type=str,
|
108 |
+
default=None,
|
109 |
+
help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
|
110 |
+
)
|
111 |
+
parser.add_argument(
|
112 |
+
"--model",
|
113 |
+
type=str,
|
114 |
+
default=None,
|
115 |
+
help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
|
116 |
+
)
|
117 |
+
|
118 |
+
return parser
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == "__main__":
|
122 |
+
parser = setup_parser()
|
123 |
+
|
124 |
+
args = parser.parse_args()
|
125 |
+
split(args)
|
external/llite/networks/extract_lora_from_models.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# extract approximating LoRA by svd from two SD models
|
2 |
+
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
3 |
+
# Thanks to cloneofsimo!
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
import torch
|
10 |
+
from safetensors.torch import load_file, save_file
|
11 |
+
from tqdm import tqdm
|
12 |
+
from library import sai_model_spec, model_util, sdxl_model_util
|
13 |
+
import lora
|
14 |
+
|
15 |
+
|
16 |
+
# CLAMP_QUANTILE = 0.99
|
17 |
+
# MIN_DIFF = 1e-1
|
18 |
+
|
19 |
+
|
20 |
+
def save_to_file(file_name, model, state_dict, dtype):
|
21 |
+
if dtype is not None:
|
22 |
+
for key in list(state_dict.keys()):
|
23 |
+
if type(state_dict[key]) == torch.Tensor:
|
24 |
+
state_dict[key] = state_dict[key].to(dtype)
|
25 |
+
|
26 |
+
if os.path.splitext(file_name)[1] == ".safetensors":
|
27 |
+
save_file(model, file_name)
|
28 |
+
else:
|
29 |
+
torch.save(model, file_name)
|
30 |
+
|
31 |
+
|
32 |
+
def svd(
|
33 |
+
model_org=None,
|
34 |
+
model_tuned=None,
|
35 |
+
save_to=None,
|
36 |
+
dim=4,
|
37 |
+
v2=None,
|
38 |
+
sdxl=None,
|
39 |
+
conv_dim=None,
|
40 |
+
v_parameterization=None,
|
41 |
+
device=None,
|
42 |
+
save_precision=None,
|
43 |
+
clamp_quantile=0.99,
|
44 |
+
min_diff=0.01,
|
45 |
+
no_metadata=False,
|
46 |
+
):
|
47 |
+
def str_to_dtype(p):
|
48 |
+
if p == "float":
|
49 |
+
return torch.float
|
50 |
+
if p == "fp16":
|
51 |
+
return torch.float16
|
52 |
+
if p == "bf16":
|
53 |
+
return torch.bfloat16
|
54 |
+
return None
|
55 |
+
|
56 |
+
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
57 |
+
if v_parameterization is None:
|
58 |
+
v_parameterization = v2
|
59 |
+
|
60 |
+
save_dtype = str_to_dtype(save_precision)
|
61 |
+
|
62 |
+
# load models
|
63 |
+
if not sdxl:
|
64 |
+
print(f"loading original SD model : {model_org}")
|
65 |
+
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
66 |
+
text_encoders_o = [text_encoder_o]
|
67 |
+
print(f"loading tuned SD model : {model_tuned}")
|
68 |
+
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
69 |
+
text_encoders_t = [text_encoder_t]
|
70 |
+
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
|
71 |
+
else:
|
72 |
+
print(f"loading original SDXL model : {model_org}")
|
73 |
+
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
74 |
+
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
|
75 |
+
)
|
76 |
+
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
77 |
+
print(f"loading original SDXL model : {model_tuned}")
|
78 |
+
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
79 |
+
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
|
80 |
+
)
|
81 |
+
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
82 |
+
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
83 |
+
|
84 |
+
# create LoRA network to extract weights: Use dim (rank) as alpha
|
85 |
+
if conv_dim is None:
|
86 |
+
kwargs = {}
|
87 |
+
else:
|
88 |
+
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
|
89 |
+
|
90 |
+
lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
|
91 |
+
lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
|
92 |
+
assert len(lora_network_o.text_encoder_loras) == len(
|
93 |
+
lora_network_t.text_encoder_loras
|
94 |
+
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
95 |
+
|
96 |
+
# get diffs
|
97 |
+
diffs = {}
|
98 |
+
text_encoder_different = False
|
99 |
+
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
|
100 |
+
lora_name = lora_o.lora_name
|
101 |
+
module_o = lora_o.org_module
|
102 |
+
module_t = lora_t.org_module
|
103 |
+
diff = module_t.weight - module_o.weight
|
104 |
+
|
105 |
+
# Text Encoder might be same
|
106 |
+
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
107 |
+
text_encoder_different = True
|
108 |
+
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
109 |
+
|
110 |
+
diff = diff.float()
|
111 |
+
diffs[lora_name] = diff
|
112 |
+
|
113 |
+
if not text_encoder_different:
|
114 |
+
print("Text encoder is same. Extract U-Net only.")
|
115 |
+
lora_network_o.text_encoder_loras = []
|
116 |
+
diffs = {}
|
117 |
+
|
118 |
+
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
|
119 |
+
lora_name = lora_o.lora_name
|
120 |
+
module_o = lora_o.org_module
|
121 |
+
module_t = lora_t.org_module
|
122 |
+
diff = module_t.weight - module_o.weight
|
123 |
+
diff = diff.float()
|
124 |
+
|
125 |
+
if args.device:
|
126 |
+
diff = diff.to(args.device)
|
127 |
+
|
128 |
+
diffs[lora_name] = diff
|
129 |
+
|
130 |
+
# make LoRA with svd
|
131 |
+
print("calculating by svd")
|
132 |
+
lora_weights = {}
|
133 |
+
with torch.no_grad():
|
134 |
+
for lora_name, mat in tqdm(list(diffs.items())):
|
135 |
+
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
136 |
+
conv2d = len(mat.size()) == 4
|
137 |
+
kernel_size = None if not conv2d else mat.size()[2:4]
|
138 |
+
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
139 |
+
|
140 |
+
rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
|
141 |
+
out_dim, in_dim = mat.size()[0:2]
|
142 |
+
|
143 |
+
if device:
|
144 |
+
mat = mat.to(device)
|
145 |
+
|
146 |
+
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
147 |
+
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
148 |
+
|
149 |
+
if conv2d:
|
150 |
+
if conv2d_3x3:
|
151 |
+
mat = mat.flatten(start_dim=1)
|
152 |
+
else:
|
153 |
+
mat = mat.squeeze()
|
154 |
+
|
155 |
+
U, S, Vh = torch.linalg.svd(mat)
|
156 |
+
|
157 |
+
U = U[:, :rank]
|
158 |
+
S = S[:rank]
|
159 |
+
U = U @ torch.diag(S)
|
160 |
+
|
161 |
+
Vh = Vh[:rank, :]
|
162 |
+
|
163 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
164 |
+
hi_val = torch.quantile(dist, clamp_quantile)
|
165 |
+
low_val = -hi_val
|
166 |
+
|
167 |
+
U = U.clamp(low_val, hi_val)
|
168 |
+
Vh = Vh.clamp(low_val, hi_val)
|
169 |
+
|
170 |
+
if conv2d:
|
171 |
+
U = U.reshape(out_dim, rank, 1, 1)
|
172 |
+
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
173 |
+
|
174 |
+
U = U.to("cpu").contiguous()
|
175 |
+
Vh = Vh.to("cpu").contiguous()
|
176 |
+
|
177 |
+
lora_weights[lora_name] = (U, Vh)
|
178 |
+
|
179 |
+
# make state dict for LoRA
|
180 |
+
lora_sd = {}
|
181 |
+
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
182 |
+
lora_sd[lora_name + ".lora_up.weight"] = up_weight
|
183 |
+
lora_sd[lora_name + ".lora_down.weight"] = down_weight
|
184 |
+
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])
|
185 |
+
|
186 |
+
# load state dict to LoRA and save it
|
187 |
+
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
|
188 |
+
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
|
189 |
+
|
190 |
+
info = lora_network_save.load_state_dict(lora_sd)
|
191 |
+
print(f"Loading extracted LoRA weights: {info}")
|
192 |
+
|
193 |
+
dir_name = os.path.dirname(save_to)
|
194 |
+
if dir_name and not os.path.exists(dir_name):
|
195 |
+
os.makedirs(dir_name, exist_ok=True)
|
196 |
+
|
197 |
+
# minimum metadata
|
198 |
+
net_kwargs = {}
|
199 |
+
if conv_dim is not None:
|
200 |
+
net_kwargs["conv_dim"] = str(conv_dim)
|
201 |
+
net_kwargs["conv_alpha"] = str(float(conv_dim))
|
202 |
+
|
203 |
+
metadata = {
|
204 |
+
"ss_v2": str(v2),
|
205 |
+
"ss_base_model_version": model_version,
|
206 |
+
"ss_network_module": "networks.lora",
|
207 |
+
"ss_network_dim": str(dim),
|
208 |
+
"ss_network_alpha": str(float(dim)),
|
209 |
+
"ss_network_args": json.dumps(net_kwargs),
|
210 |
+
}
|
211 |
+
|
212 |
+
if not no_metadata:
|
213 |
+
title = os.path.splitext(os.path.basename(save_to))[0]
|
214 |
+
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
|
215 |
+
metadata.update(sai_metadata)
|
216 |
+
|
217 |
+
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
218 |
+
print(f"LoRA weights are saved to: {save_to}")
|
219 |
+
|
220 |
+
|
221 |
+
def setup_parser() -> argparse.ArgumentParser:
|
222 |
+
parser = argparse.ArgumentParser()
|
223 |
+
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
224 |
+
parser.add_argument(
|
225 |
+
"--v_parameterization",
|
226 |
+
action="store_true",
|
227 |
+
default=None,
|
228 |
+
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
|
229 |
+
)
|
230 |
+
parser.add_argument(
|
231 |
+
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
|
232 |
+
)
|
233 |
+
parser.add_argument(
|
234 |
+
"--save_precision",
|
235 |
+
type=str,
|
236 |
+
default=None,
|
237 |
+
choices=[None, "float", "fp16", "bf16"],
|
238 |
+
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
|
239 |
+
)
|
240 |
+
parser.add_argument(
|
241 |
+
"--model_org",
|
242 |
+
type=str,
|
243 |
+
default=None,
|
244 |
+
required=True,
|
245 |
+
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
|
246 |
+
)
|
247 |
+
parser.add_argument(
|
248 |
+
"--model_tuned",
|
249 |
+
type=str,
|
250 |
+
default=None,
|
251 |
+
required=True,
|
252 |
+
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--save_to",
|
256 |
+
type=str,
|
257 |
+
default=None,
|
258 |
+
required=True,
|
259 |
+
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
260 |
+
)
|
261 |
+
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
262 |
+
parser.add_argument(
|
263 |
+
"--conv_dim",
|
264 |
+
type=int,
|
265 |
+
default=None,
|
266 |
+
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
267 |
+
)
|
268 |
+
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
269 |
+
parser.add_argument(
|
270 |
+
"--clamp_quantile",
|
271 |
+
type=float,
|
272 |
+
default=0.99,
|
273 |
+
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
|
274 |
+
)
|
275 |
+
parser.add_argument(
|
276 |
+
"--min_diff",
|
277 |
+
type=float,
|
278 |
+
default=0.01,
|
279 |
+
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
|
280 |
+
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
|
281 |
+
)
|
282 |
+
parser.add_argument(
|
283 |
+
"--no_metadata",
|
284 |
+
action="store_true",
|
285 |
+
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
286 |
+
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
287 |
+
)
|
288 |
+
|
289 |
+
return parser
|
290 |
+
|
291 |
+
|
292 |
+
if __name__ == "__main__":
|
293 |
+
parser = setup_parser()
|
294 |
+
|
295 |
+
args = parser.parse_args()
|
296 |
+
svd(**vars(args))
|
external/llite/networks/lora.py
ADDED
@@ -0,0 +1,1225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LoRA network module
|
2 |
+
# reference:
|
3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
5 |
+
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
from typing import Dict, List, Optional, Tuple, Type, Union
|
9 |
+
from diffusers import AutoencoderKL
|
10 |
+
from transformers import CLIPTextModel
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import re
|
14 |
+
|
15 |
+
|
16 |
+
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
17 |
+
|
18 |
+
|
19 |
+
class LoRAModule(torch.nn.Module):
|
20 |
+
"""
|
21 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
lora_name,
|
27 |
+
org_module: torch.nn.Module,
|
28 |
+
multiplier=1.0,
|
29 |
+
lora_dim=4,
|
30 |
+
alpha=1,
|
31 |
+
dropout=None,
|
32 |
+
rank_dropout=None,
|
33 |
+
module_dropout=None,
|
34 |
+
):
|
35 |
+
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
36 |
+
super().__init__()
|
37 |
+
self.lora_name = lora_name
|
38 |
+
|
39 |
+
if org_module.__class__.__name__ == "Conv2d":
|
40 |
+
in_dim = org_module.in_channels
|
41 |
+
out_dim = org_module.out_channels
|
42 |
+
else:
|
43 |
+
in_dim = org_module.in_features
|
44 |
+
out_dim = org_module.out_features
|
45 |
+
|
46 |
+
# if limit_rank:
|
47 |
+
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
48 |
+
# if self.lora_dim != lora_dim:
|
49 |
+
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
50 |
+
# else:
|
51 |
+
self.lora_dim = lora_dim
|
52 |
+
|
53 |
+
if org_module.__class__.__name__ == "Conv2d":
|
54 |
+
kernel_size = org_module.kernel_size
|
55 |
+
stride = org_module.stride
|
56 |
+
padding = org_module.padding
|
57 |
+
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
58 |
+
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
59 |
+
else:
|
60 |
+
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
61 |
+
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
62 |
+
|
63 |
+
if type(alpha) == torch.Tensor:
|
64 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
65 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
66 |
+
self.scale = alpha / self.lora_dim
|
67 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
68 |
+
|
69 |
+
# same as microsoft's
|
70 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
71 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
72 |
+
|
73 |
+
self.multiplier = multiplier
|
74 |
+
self.org_module = org_module # remove in applying
|
75 |
+
self.dropout = dropout
|
76 |
+
self.rank_dropout = rank_dropout
|
77 |
+
self.module_dropout = module_dropout
|
78 |
+
|
79 |
+
def apply_to(self):
|
80 |
+
self.org_forward = self.org_module.forward
|
81 |
+
self.org_module.forward = self.forward
|
82 |
+
del self.org_module
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
org_forwarded = self.org_forward(x)
|
86 |
+
|
87 |
+
# module dropout
|
88 |
+
if self.module_dropout is not None and self.training:
|
89 |
+
if torch.rand(1) < self.module_dropout:
|
90 |
+
return org_forwarded
|
91 |
+
|
92 |
+
lx = self.lora_down(x)
|
93 |
+
|
94 |
+
# normal dropout
|
95 |
+
if self.dropout is not None and self.training:
|
96 |
+
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
97 |
+
|
98 |
+
# rank dropout
|
99 |
+
if self.rank_dropout is not None and self.training:
|
100 |
+
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
101 |
+
if len(lx.size()) == 3:
|
102 |
+
mask = mask.unsqueeze(1) # for Text Encoder
|
103 |
+
elif len(lx.size()) == 4:
|
104 |
+
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
105 |
+
lx = lx * mask
|
106 |
+
|
107 |
+
# scaling for rank dropout: treat as if the rank is changed
|
108 |
+
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
|
109 |
+
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
110 |
+
else:
|
111 |
+
scale = self.scale
|
112 |
+
|
113 |
+
lx = self.lora_up(lx)
|
114 |
+
|
115 |
+
return org_forwarded + lx * self.multiplier * scale
|
116 |
+
|
117 |
+
|
118 |
+
class LoRAInfModule(LoRAModule):
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
lora_name,
|
122 |
+
org_module: torch.nn.Module,
|
123 |
+
multiplier=1.0,
|
124 |
+
lora_dim=4,
|
125 |
+
alpha=1,
|
126 |
+
**kwargs,
|
127 |
+
):
|
128 |
+
# no dropout for inference
|
129 |
+
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
130 |
+
|
131 |
+
self.org_module_ref = [org_module] # 後から参照できるように
|
132 |
+
self.enabled = True
|
133 |
+
|
134 |
+
# check regional or not by lora_name
|
135 |
+
self.text_encoder = False
|
136 |
+
if lora_name.startswith("lora_te_"):
|
137 |
+
self.regional = False
|
138 |
+
self.use_sub_prompt = True
|
139 |
+
self.text_encoder = True
|
140 |
+
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
|
141 |
+
self.regional = False
|
142 |
+
self.use_sub_prompt = True
|
143 |
+
elif "time_emb" in lora_name:
|
144 |
+
self.regional = False
|
145 |
+
self.use_sub_prompt = False
|
146 |
+
else:
|
147 |
+
self.regional = True
|
148 |
+
self.use_sub_prompt = False
|
149 |
+
|
150 |
+
self.network: LoRANetwork = None
|
151 |
+
|
152 |
+
def set_network(self, network):
|
153 |
+
self.network = network
|
154 |
+
|
155 |
+
# freezeしてマージする
|
156 |
+
def merge_to(self, sd, dtype, device):
|
157 |
+
# get up/down weight
|
158 |
+
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
159 |
+
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
|
160 |
+
|
161 |
+
# extract weight from org_module
|
162 |
+
org_sd = self.org_module.state_dict()
|
163 |
+
weight = org_sd["weight"].to(torch.float)
|
164 |
+
|
165 |
+
# merge weight
|
166 |
+
if len(weight.size()) == 2:
|
167 |
+
# linear
|
168 |
+
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
169 |
+
elif down_weight.size()[2:4] == (1, 1):
|
170 |
+
# conv2d 1x1
|
171 |
+
weight = (
|
172 |
+
weight
|
173 |
+
+ self.multiplier
|
174 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
175 |
+
* self.scale
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
# conv2d 3x3
|
179 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
180 |
+
# print(conved.size(), weight.size(), module.stride, module.padding)
|
181 |
+
weight = weight + self.multiplier * conved * self.scale
|
182 |
+
|
183 |
+
# set weight to org_module
|
184 |
+
org_sd["weight"] = weight.to(dtype)
|
185 |
+
self.org_module.load_state_dict(org_sd)
|
186 |
+
|
187 |
+
# 復元できるマージのため、このモジュールのweightを返す
|
188 |
+
def get_weight(self, multiplier=None):
|
189 |
+
if multiplier is None:
|
190 |
+
multiplier = self.multiplier
|
191 |
+
|
192 |
+
# get up/down weight from module
|
193 |
+
up_weight = self.lora_up.weight.to(torch.float)
|
194 |
+
down_weight = self.lora_down.weight.to(torch.float)
|
195 |
+
|
196 |
+
# pre-calculated weight
|
197 |
+
if len(down_weight.size()) == 2:
|
198 |
+
# linear
|
199 |
+
weight = self.multiplier * (up_weight @ down_weight) * self.scale
|
200 |
+
elif down_weight.size()[2:4] == (1, 1):
|
201 |
+
# conv2d 1x1
|
202 |
+
weight = (
|
203 |
+
self.multiplier
|
204 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
205 |
+
* self.scale
|
206 |
+
)
|
207 |
+
else:
|
208 |
+
# conv2d 3x3
|
209 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
210 |
+
weight = self.multiplier * conved * self.scale
|
211 |
+
|
212 |
+
return weight
|
213 |
+
|
214 |
+
def set_region(self, region):
|
215 |
+
self.region = region
|
216 |
+
self.region_mask = None
|
217 |
+
|
218 |
+
def default_forward(self, x):
|
219 |
+
# print("default_forward", self.lora_name, x.size())
|
220 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
221 |
+
|
222 |
+
def forward(self, x):
|
223 |
+
if not self.enabled:
|
224 |
+
return self.org_forward(x)
|
225 |
+
|
226 |
+
if self.network is None or self.network.sub_prompt_index is None:
|
227 |
+
return self.default_forward(x)
|
228 |
+
if not self.regional and not self.use_sub_prompt:
|
229 |
+
return self.default_forward(x)
|
230 |
+
|
231 |
+
if self.regional:
|
232 |
+
return self.regional_forward(x)
|
233 |
+
else:
|
234 |
+
return self.sub_prompt_forward(x)
|
235 |
+
|
236 |
+
def get_mask_for_x(self, x):
|
237 |
+
# calculate size from shape of x
|
238 |
+
if len(x.size()) == 4:
|
239 |
+
h, w = x.size()[2:4]
|
240 |
+
area = h * w
|
241 |
+
else:
|
242 |
+
area = x.size()[1]
|
243 |
+
|
244 |
+
mask = self.network.mask_dic.get(area, None)
|
245 |
+
if mask is None:
|
246 |
+
# raise ValueError(f"mask is None for resolution {area}")
|
247 |
+
# emb_layers in SDXL doesn't have mask
|
248 |
+
# print(f"mask is None for resolution {area}, {x.size()}")
|
249 |
+
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
|
250 |
+
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
|
251 |
+
if len(x.size()) != 4:
|
252 |
+
mask = torch.reshape(mask, (1, -1, 1))
|
253 |
+
return mask
|
254 |
+
|
255 |
+
def regional_forward(self, x):
|
256 |
+
if "attn2_to_out" in self.lora_name:
|
257 |
+
return self.to_out_forward(x)
|
258 |
+
|
259 |
+
if self.network.mask_dic is None: # sub_prompt_index >= 3
|
260 |
+
return self.default_forward(x)
|
261 |
+
|
262 |
+
# apply mask for LoRA result
|
263 |
+
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
264 |
+
mask = self.get_mask_for_x(lx)
|
265 |
+
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
266 |
+
lx = lx * mask
|
267 |
+
|
268 |
+
x = self.org_forward(x)
|
269 |
+
x = x + lx
|
270 |
+
|
271 |
+
if "attn2_to_q" in self.lora_name and self.network.is_last_network:
|
272 |
+
x = self.postp_to_q(x)
|
273 |
+
|
274 |
+
return x
|
275 |
+
|
276 |
+
def postp_to_q(self, x):
|
277 |
+
# repeat x to num_sub_prompts
|
278 |
+
has_real_uncond = x.size()[0] // self.network.batch_size == 3
|
279 |
+
qc = self.network.batch_size # uncond
|
280 |
+
qc += self.network.batch_size * self.network.num_sub_prompts # cond
|
281 |
+
if has_real_uncond:
|
282 |
+
qc += self.network.batch_size # real_uncond
|
283 |
+
|
284 |
+
query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
|
285 |
+
query[: self.network.batch_size] = x[: self.network.batch_size]
|
286 |
+
|
287 |
+
for i in range(self.network.batch_size):
|
288 |
+
qi = self.network.batch_size + i * self.network.num_sub_prompts
|
289 |
+
query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
|
290 |
+
|
291 |
+
if has_real_uncond:
|
292 |
+
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
293 |
+
|
294 |
+
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
295 |
+
return query
|
296 |
+
|
297 |
+
def sub_prompt_forward(self, x):
|
298 |
+
if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
|
299 |
+
return self.org_forward(x)
|
300 |
+
|
301 |
+
emb_idx = self.network.sub_prompt_index
|
302 |
+
if not self.text_encoder:
|
303 |
+
emb_idx += self.network.batch_size
|
304 |
+
|
305 |
+
# apply sub prompt of X
|
306 |
+
lx = x[emb_idx :: self.network.num_sub_prompts]
|
307 |
+
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
308 |
+
|
309 |
+
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
310 |
+
|
311 |
+
x = self.org_forward(x)
|
312 |
+
x[emb_idx :: self.network.num_sub_prompts] += lx
|
313 |
+
|
314 |
+
return x
|
315 |
+
|
316 |
+
def to_out_forward(self, x):
|
317 |
+
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
318 |
+
|
319 |
+
if self.network.is_last_network:
|
320 |
+
masks = [None] * self.network.num_sub_prompts
|
321 |
+
self.network.shared[self.lora_name] = (None, masks)
|
322 |
+
else:
|
323 |
+
lx, masks = self.network.shared[self.lora_name]
|
324 |
+
|
325 |
+
# call own LoRA
|
326 |
+
x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
|
327 |
+
lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
|
328 |
+
|
329 |
+
if self.network.is_last_network:
|
330 |
+
lx = torch.zeros(
|
331 |
+
(self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
|
332 |
+
)
|
333 |
+
self.network.shared[self.lora_name] = (lx, masks)
|
334 |
+
|
335 |
+
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
336 |
+
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
337 |
+
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
338 |
+
|
339 |
+
# if not last network, return x and masks
|
340 |
+
x = self.org_forward(x)
|
341 |
+
if not self.network.is_last_network:
|
342 |
+
return x
|
343 |
+
|
344 |
+
lx, masks = self.network.shared.pop(self.lora_name)
|
345 |
+
|
346 |
+
# if last network, combine separated x with mask weighted sum
|
347 |
+
has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
|
348 |
+
|
349 |
+
out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
|
350 |
+
out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
|
351 |
+
if has_real_uncond:
|
352 |
+
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
353 |
+
|
354 |
+
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
355 |
+
# if num_sub_prompts > num of LoRAs, fill with zero
|
356 |
+
for i in range(len(masks)):
|
357 |
+
if masks[i] is None:
|
358 |
+
masks[i] = torch.zeros_like(masks[0])
|
359 |
+
|
360 |
+
mask = torch.cat(masks)
|
361 |
+
mask_sum = torch.sum(mask, dim=0) + 1e-4
|
362 |
+
for i in range(self.network.batch_size):
|
363 |
+
# 1枚の画像ごとに処理する
|
364 |
+
lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
|
365 |
+
lx1 = lx1 * mask
|
366 |
+
lx1 = torch.sum(lx1, dim=0)
|
367 |
+
|
368 |
+
xi = self.network.batch_size + i * self.network.num_sub_prompts
|
369 |
+
x1 = x[xi : xi + self.network.num_sub_prompts]
|
370 |
+
x1 = x1 * mask
|
371 |
+
x1 = torch.sum(x1, dim=0)
|
372 |
+
x1 = x1 / mask_sum
|
373 |
+
|
374 |
+
x1 = x1 + lx1
|
375 |
+
out[self.network.batch_size + i] = x1
|
376 |
+
|
377 |
+
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
|
378 |
+
return out
|
379 |
+
|
380 |
+
|
381 |
+
def parse_block_lr_kwargs(nw_kwargs):
|
382 |
+
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
|
383 |
+
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
|
384 |
+
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
|
385 |
+
|
386 |
+
# 以上のいずれにも設定がない場合は無効としてNoneを返す
|
387 |
+
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
|
388 |
+
return None, None, None
|
389 |
+
|
390 |
+
# extract learning rate weight for each block
|
391 |
+
if down_lr_weight is not None:
|
392 |
+
# if some parameters are not set, use zero
|
393 |
+
if "," in down_lr_weight:
|
394 |
+
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
395 |
+
|
396 |
+
if mid_lr_weight is not None:
|
397 |
+
mid_lr_weight = float(mid_lr_weight)
|
398 |
+
|
399 |
+
if up_lr_weight is not None:
|
400 |
+
if "," in up_lr_weight:
|
401 |
+
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
402 |
+
|
403 |
+
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
|
404 |
+
down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
|
405 |
+
)
|
406 |
+
|
407 |
+
return down_lr_weight, mid_lr_weight, up_lr_weight
|
408 |
+
|
409 |
+
|
410 |
+
def create_network(
|
411 |
+
multiplier: float,
|
412 |
+
network_dim: Optional[int],
|
413 |
+
network_alpha: Optional[float],
|
414 |
+
vae: AutoencoderKL,
|
415 |
+
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
|
416 |
+
unet,
|
417 |
+
neuron_dropout: Optional[float] = None,
|
418 |
+
**kwargs,
|
419 |
+
):
|
420 |
+
if network_dim is None:
|
421 |
+
network_dim = 4 # default
|
422 |
+
if network_alpha is None:
|
423 |
+
network_alpha = 1.0
|
424 |
+
|
425 |
+
# extract dim/alpha for conv2d, and block dim
|
426 |
+
conv_dim = kwargs.get("conv_dim", None)
|
427 |
+
conv_alpha = kwargs.get("conv_alpha", None)
|
428 |
+
if conv_dim is not None:
|
429 |
+
conv_dim = int(conv_dim)
|
430 |
+
if conv_alpha is None:
|
431 |
+
conv_alpha = 1.0
|
432 |
+
else:
|
433 |
+
conv_alpha = float(conv_alpha)
|
434 |
+
|
435 |
+
# block dim/alpha/lr
|
436 |
+
block_dims = kwargs.get("block_dims", None)
|
437 |
+
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
|
438 |
+
|
439 |
+
# 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
|
440 |
+
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
|
441 |
+
block_alphas = kwargs.get("block_alphas", None)
|
442 |
+
conv_block_dims = kwargs.get("conv_block_dims", None)
|
443 |
+
conv_block_alphas = kwargs.get("conv_block_alphas", None)
|
444 |
+
|
445 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
|
446 |
+
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
447 |
+
)
|
448 |
+
|
449 |
+
# remove block dim/alpha without learning rate
|
450 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
|
451 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
452 |
+
)
|
453 |
+
|
454 |
+
else:
|
455 |
+
block_alphas = None
|
456 |
+
conv_block_dims = None
|
457 |
+
conv_block_alphas = None
|
458 |
+
|
459 |
+
# rank/module dropout
|
460 |
+
rank_dropout = kwargs.get("rank_dropout", None)
|
461 |
+
if rank_dropout is not None:
|
462 |
+
rank_dropout = float(rank_dropout)
|
463 |
+
module_dropout = kwargs.get("module_dropout", None)
|
464 |
+
if module_dropout is not None:
|
465 |
+
module_dropout = float(module_dropout)
|
466 |
+
|
467 |
+
# すごく引数が多いな ( ^ω^)・・・
|
468 |
+
network = LoRANetwork(
|
469 |
+
text_encoder,
|
470 |
+
unet,
|
471 |
+
multiplier=multiplier,
|
472 |
+
lora_dim=network_dim,
|
473 |
+
alpha=network_alpha,
|
474 |
+
dropout=neuron_dropout,
|
475 |
+
rank_dropout=rank_dropout,
|
476 |
+
module_dropout=module_dropout,
|
477 |
+
conv_lora_dim=conv_dim,
|
478 |
+
conv_alpha=conv_alpha,
|
479 |
+
block_dims=block_dims,
|
480 |
+
block_alphas=block_alphas,
|
481 |
+
conv_block_dims=conv_block_dims,
|
482 |
+
conv_block_alphas=conv_block_alphas,
|
483 |
+
varbose=True,
|
484 |
+
)
|
485 |
+
|
486 |
+
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
487 |
+
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
488 |
+
|
489 |
+
return network
|
490 |
+
|
491 |
+
|
492 |
+
# このメソッドは外部から呼び出される可能性を考慮しておく
|
493 |
+
# network_dim, network_alpha にはデフォルト値が入っている。
|
494 |
+
# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている
|
495 |
+
# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている
|
496 |
+
def get_block_dims_and_alphas(
|
497 |
+
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
498 |
+
):
|
499 |
+
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
|
500 |
+
|
501 |
+
def parse_ints(s):
|
502 |
+
return [int(i) for i in s.split(",")]
|
503 |
+
|
504 |
+
def parse_floats(s):
|
505 |
+
return [float(i) for i in s.split(",")]
|
506 |
+
|
507 |
+
# block_dimsとblock_alphasをパースする。必ず値が入る
|
508 |
+
if block_dims is not None:
|
509 |
+
block_dims = parse_ints(block_dims)
|
510 |
+
assert (
|
511 |
+
len(block_dims) == num_total_blocks
|
512 |
+
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
513 |
+
else:
|
514 |
+
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
515 |
+
block_dims = [network_dim] * num_total_blocks
|
516 |
+
|
517 |
+
if block_alphas is not None:
|
518 |
+
block_alphas = parse_floats(block_alphas)
|
519 |
+
assert (
|
520 |
+
len(block_alphas) == num_total_blocks
|
521 |
+
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してくださ��"
|
522 |
+
else:
|
523 |
+
print(
|
524 |
+
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
|
525 |
+
)
|
526 |
+
block_alphas = [network_alpha] * num_total_blocks
|
527 |
+
|
528 |
+
# conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う
|
529 |
+
if conv_block_dims is not None:
|
530 |
+
conv_block_dims = parse_ints(conv_block_dims)
|
531 |
+
assert (
|
532 |
+
len(conv_block_dims) == num_total_blocks
|
533 |
+
), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
|
534 |
+
|
535 |
+
if conv_block_alphas is not None:
|
536 |
+
conv_block_alphas = parse_floats(conv_block_alphas)
|
537 |
+
assert (
|
538 |
+
len(conv_block_alphas) == num_total_blocks
|
539 |
+
), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
|
540 |
+
else:
|
541 |
+
if conv_alpha is None:
|
542 |
+
conv_alpha = 1.0
|
543 |
+
print(
|
544 |
+
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
|
545 |
+
)
|
546 |
+
conv_block_alphas = [conv_alpha] * num_total_blocks
|
547 |
+
else:
|
548 |
+
if conv_dim is not None:
|
549 |
+
print(
|
550 |
+
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
|
551 |
+
)
|
552 |
+
conv_block_dims = [conv_dim] * num_total_blocks
|
553 |
+
conv_block_alphas = [conv_alpha] * num_total_blocks
|
554 |
+
else:
|
555 |
+
conv_block_dims = None
|
556 |
+
conv_block_alphas = None
|
557 |
+
|
558 |
+
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
559 |
+
|
560 |
+
|
561 |
+
# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく
|
562 |
+
def get_block_lr_weight(
|
563 |
+
down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
|
564 |
+
) -> Tuple[List[float], List[float], List[float]]:
|
565 |
+
# パラメータ未指定時は何もせず、今までと同じ動作とする
|
566 |
+
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
|
567 |
+
return None, None, None
|
568 |
+
|
569 |
+
max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
|
570 |
+
|
571 |
+
def get_list(name_with_suffix) -> List[float]:
|
572 |
+
import math
|
573 |
+
|
574 |
+
tokens = name_with_suffix.split("+")
|
575 |
+
name = tokens[0]
|
576 |
+
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
|
577 |
+
|
578 |
+
if name == "cosine":
|
579 |
+
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
|
580 |
+
elif name == "sine":
|
581 |
+
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
|
582 |
+
elif name == "linear":
|
583 |
+
return [i / (max_len - 1) + base_lr for i in range(max_len)]
|
584 |
+
elif name == "reverse_linear":
|
585 |
+
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
|
586 |
+
elif name == "zeros":
|
587 |
+
return [0.0 + base_lr] * max_len
|
588 |
+
else:
|
589 |
+
print(
|
590 |
+
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
591 |
+
% (name)
|
592 |
+
)
|
593 |
+
return None
|
594 |
+
|
595 |
+
if type(down_lr_weight) == str:
|
596 |
+
down_lr_weight = get_list(down_lr_weight)
|
597 |
+
if type(up_lr_weight) == str:
|
598 |
+
up_lr_weight = get_list(up_lr_weight)
|
599 |
+
|
600 |
+
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
601 |
+
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
602 |
+
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
603 |
+
up_lr_weight = up_lr_weight[:max_len]
|
604 |
+
down_lr_weight = down_lr_weight[:max_len]
|
605 |
+
|
606 |
+
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
607 |
+
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
608 |
+
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
609 |
+
|
610 |
+
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
611 |
+
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
612 |
+
if up_lr_weight != None and len(up_lr_weight) < max_len:
|
613 |
+
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
614 |
+
|
615 |
+
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
616 |
+
print("apply block learning rate / 階層別学習率を適用します。")
|
617 |
+
if down_lr_weight != None:
|
618 |
+
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
619 |
+
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
|
620 |
+
else:
|
621 |
+
print("down_lr_weight: all 1.0, すべて1.0")
|
622 |
+
|
623 |
+
if mid_lr_weight != None:
|
624 |
+
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
625 |
+
print("mid_lr_weight:", mid_lr_weight)
|
626 |
+
else:
|
627 |
+
print("mid_lr_weight: 1.0")
|
628 |
+
|
629 |
+
if up_lr_weight != None:
|
630 |
+
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
631 |
+
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
|
632 |
+
else:
|
633 |
+
print("up_lr_weight: all 1.0, すべて1.0")
|
634 |
+
|
635 |
+
return down_lr_weight, mid_lr_weight, up_lr_weight
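# Illustration (editor's note, not part of the original file): with NUM_OF_BLOCKS = 12,
# down_lr_weight="sine+.1" yields 12 multipliers rising from 0.1 (shallowest block) to 1.1
# (deepest block), since get_list computes sin(pi * i / 22) + 0.1 for i = 0..11, while
# "cosine+.1" gives the mirrored, falling profile. Any multiplier not above zero_threshold
# is clamped to 0 afterwards, which drops that block from training entirely.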
|
636 |
+
|
637 |
+
|
638 |
+
# exclude blocks whose lr_weight is 0 from block_dims; kept callable from outside this module
|
639 |
+
def remove_block_dims_and_alphas(
|
640 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
641 |
+
):
|
642 |
+
# set 0 to block dim without learning rate to remove the block
|
643 |
+
if down_lr_weight != None:
|
644 |
+
for i, lr in enumerate(down_lr_weight):
|
645 |
+
if lr == 0:
|
646 |
+
block_dims[i] = 0
|
647 |
+
if conv_block_dims is not None:
|
648 |
+
conv_block_dims[i] = 0
|
649 |
+
if mid_lr_weight != None:
|
650 |
+
if mid_lr_weight == 0:
|
651 |
+
block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
|
652 |
+
if conv_block_dims is not None:
|
653 |
+
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
|
654 |
+
if up_lr_weight != None:
|
655 |
+
for i, lr in enumerate(up_lr_weight):
|
656 |
+
if lr == 0:
|
657 |
+
block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
|
658 |
+
if conv_block_dims is not None:
|
659 |
+
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
|
660 |
+
|
661 |
+
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
662 |
+
|
663 |
+
|
664 |
+
# kept callable from outside this module
|
665 |
+
def get_block_index(lora_name: str) -> int:
|
666 |
+
block_idx = -1 # invalid lora name
|
667 |
+
|
668 |
+
m = RE_UPDOWN.search(lora_name)
|
669 |
+
if m:
|
670 |
+
g = m.groups()
|
671 |
+
i = int(g[1])
|
672 |
+
j = int(g[3])
|
673 |
+
if g[2] == "resnets":
|
674 |
+
idx = 3 * i + j
|
675 |
+
elif g[2] == "attentions":
|
676 |
+
idx = 3 * i + j
|
677 |
+
elif g[2] == "upsamplers" or g[2] == "downsamplers":
|
678 |
+
idx = 3 * i + 2
|
679 |
+
|
680 |
+
if g[0] == "down":
|
681 |
+
block_idx = 1 + idx  # no LoRA corresponds to index 0
|
682 |
+
elif g[0] == "up":
|
683 |
+
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
|
684 |
+
|
685 |
+
elif "mid_block_" in lora_name:
|
686 |
+
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
|
687 |
+
|
688 |
+
return block_idx
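# Example mapping (editor's note, not part of the original file):
#   "lora_unet_down_blocks_1_attentions_0_proj_in"  -> i=1, j=0, idx=3, block_idx = 1 + 3 = 4
#   "lora_unet_mid_block_attentions_0_proj_in"      -> block_idx = NUM_OF_BLOCKS = 12
#   "lora_unet_up_blocks_0_attentions_1_proj_in"    -> i=0, j=1, idx=1, block_idx = 12 + 1 + 1 = 14
# block_idx stays -1 for names that match neither pattern (e.g. text encoder modules).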
|
689 |
+
|
690 |
+
|
691 |
+
# Create network from weights for inference; weights are not loaded here (they may be merged instead)
|
692 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
693 |
+
if weights_sd is None:
|
694 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
695 |
+
from safetensors.torch import load_file, safe_open
|
696 |
+
|
697 |
+
weights_sd = load_file(file)
|
698 |
+
else:
|
699 |
+
weights_sd = torch.load(file, map_location="cpu")
|
700 |
+
|
701 |
+
# get dim/alpha mapping
|
702 |
+
modules_dim = {}
|
703 |
+
modules_alpha = {}
|
704 |
+
for key, value in weights_sd.items():
|
705 |
+
if "." not in key:
|
706 |
+
continue
|
707 |
+
|
708 |
+
lora_name = key.split(".")[0]
|
709 |
+
if "alpha" in key:
|
710 |
+
modules_alpha[lora_name] = value
|
711 |
+
elif "lora_down" in key:
|
712 |
+
dim = value.size()[0]
|
713 |
+
modules_dim[lora_name] = dim
|
714 |
+
# print(lora_name, value.size(), dim)
|
715 |
+
|
716 |
+
# support old LoRA without alpha
|
717 |
+
for key in modules_dim.keys():
|
718 |
+
if key not in modules_alpha:
|
719 |
+
modules_alpha[key] = modules_dim[key]
|
720 |
+
|
721 |
+
module_class = LoRAInfModule if for_inference else LoRAModule
|
722 |
+
|
723 |
+
network = LoRANetwork(
|
724 |
+
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
725 |
+
)
|
726 |
+
|
727 |
+
# block lr
|
728 |
+
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
|
729 |
+
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
730 |
+
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
731 |
+
|
732 |
+
return network, weights_sd
|
733 |
+
|
734 |
+
|
735 |
+
class LoRANetwork(torch.nn.Module):
|
736 |
+
NUM_OF_BLOCKS = 12  # number of up/down blocks in the full model
|
737 |
+
|
738 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
739 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
740 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
741 |
+
LORA_PREFIX_UNET = "lora_unet"
|
742 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
743 |
+
|
744 |
+
# SDXL: must start with LORA_PREFIX_TEXT_ENCODER
|
745 |
+
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
|
746 |
+
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
|
747 |
+
|
748 |
+
def __init__(
|
749 |
+
self,
|
750 |
+
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
751 |
+
unet,
|
752 |
+
multiplier: float = 1.0,
|
753 |
+
lora_dim: int = 4,
|
754 |
+
alpha: float = 1,
|
755 |
+
dropout: Optional[float] = None,
|
756 |
+
rank_dropout: Optional[float] = None,
|
757 |
+
module_dropout: Optional[float] = None,
|
758 |
+
conv_lora_dim: Optional[int] = None,
|
759 |
+
conv_alpha: Optional[float] = None,
|
760 |
+
block_dims: Optional[List[int]] = None,
|
761 |
+
block_alphas: Optional[List[float]] = None,
|
762 |
+
conv_block_dims: Optional[List[int]] = None,
|
763 |
+
conv_block_alphas: Optional[List[float]] = None,
|
764 |
+
modules_dim: Optional[Dict[str, int]] = None,
|
765 |
+
modules_alpha: Optional[Dict[str, int]] = None,
|
766 |
+
module_class: Type[object] = LoRAModule,
|
767 |
+
varbose: Optional[bool] = False,
|
768 |
+
) -> None:
|
769 |
+
"""
|
770 |
+
LoRA network: there are many arguments, but the usage patterns are:
1. specify lora_dim and alpha
2. specify lora_dim, alpha, conv_lora_dim, and conv_alpha
3. specify block_dims and block_alphas : not applied to Conv2d 3x3
4. specify block_dims, block_alphas, conv_block_dims, and conv_block_alphas : also applied to Conv2d 3x3
5. specify modules_dim and modules_alpha (for inference)
|
776 |
+
"""
|
777 |
+
super().__init__()
|
778 |
+
self.multiplier = multiplier
|
779 |
+
|
780 |
+
self.lora_dim = lora_dim
|
781 |
+
self.alpha = alpha
|
782 |
+
self.conv_lora_dim = conv_lora_dim
|
783 |
+
self.conv_alpha = conv_alpha
|
784 |
+
self.dropout = dropout
|
785 |
+
self.rank_dropout = rank_dropout
|
786 |
+
self.module_dropout = module_dropout
|
787 |
+
|
788 |
+
if modules_dim is not None:
|
789 |
+
print(f"create LoRA network from weights")
|
790 |
+
elif block_dims is not None:
|
791 |
+
print(f"create LoRA network from block_dims")
|
792 |
+
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
793 |
+
print(f"block_dims: {block_dims}")
|
794 |
+
print(f"block_alphas: {block_alphas}")
|
795 |
+
if conv_block_dims is not None:
|
796 |
+
print(f"conv_block_dims: {conv_block_dims}")
|
797 |
+
print(f"conv_block_alphas: {conv_block_alphas}")
|
798 |
+
else:
|
799 |
+
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
800 |
+
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
801 |
+
if self.conv_lora_dim is not None:
|
802 |
+
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
803 |
+
|
804 |
+
# create module instances
|
805 |
+
def create_modules(
|
806 |
+
is_unet: bool,
|
807 |
+
text_encoder_idx: Optional[int], # None, 1, 2
|
808 |
+
root_module: torch.nn.Module,
|
809 |
+
target_replace_modules: List[torch.nn.Module],
|
810 |
+
) -> List[LoRAModule]:
|
811 |
+
prefix = (
|
812 |
+
self.LORA_PREFIX_UNET
|
813 |
+
if is_unet
|
814 |
+
else (
|
815 |
+
self.LORA_PREFIX_TEXT_ENCODER
|
816 |
+
if text_encoder_idx is None
|
817 |
+
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
|
818 |
+
)
|
819 |
+
)
|
820 |
+
loras = []
|
821 |
+
skipped = []
|
822 |
+
for name, module in root_module.named_modules():
|
823 |
+
if module.__class__.__name__ in target_replace_modules:
|
824 |
+
for child_name, child_module in module.named_modules():
|
825 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
826 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
827 |
+
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
828 |
+
|
829 |
+
if is_linear or is_conv2d:
|
830 |
+
lora_name = prefix + "." + name + "." + child_name
|
831 |
+
lora_name = lora_name.replace(".", "_")
|
832 |
+
|
833 |
+
dim = None
|
834 |
+
alpha = None
|
835 |
+
|
836 |
+
if modules_dim is not None:
|
837 |
+
# modules are explicitly specified
|
838 |
+
if lora_name in modules_dim:
|
839 |
+
dim = modules_dim[lora_name]
|
840 |
+
alpha = modules_alpha[lora_name]
|
841 |
+
elif is_unet and block_dims is not None:
|
842 |
+
# U-Net with block_dims specified
|
843 |
+
block_idx = get_block_index(lora_name)
|
844 |
+
if is_linear or is_conv2d_1x1:
|
845 |
+
dim = block_dims[block_idx]
|
846 |
+
alpha = block_alphas[block_idx]
|
847 |
+
elif conv_block_dims is not None:
|
848 |
+
dim = conv_block_dims[block_idx]
|
849 |
+
alpha = conv_block_alphas[block_idx]
|
850 |
+
else:
|
851 |
+
# normal case: target all modules
|
852 |
+
if is_linear or is_conv2d_1x1:
|
853 |
+
dim = self.lora_dim
|
854 |
+
alpha = self.alpha
|
855 |
+
elif self.conv_lora_dim is not None:
|
856 |
+
dim = self.conv_lora_dim
|
857 |
+
alpha = self.conv_alpha
|
858 |
+
|
859 |
+
if dim is None or dim == 0:
|
860 |
+
# report which modules were skipped
|
861 |
+
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
|
862 |
+
skipped.append(lora_name)
|
863 |
+
continue
|
864 |
+
|
865 |
+
lora = module_class(
|
866 |
+
lora_name,
|
867 |
+
child_module,
|
868 |
+
self.multiplier,
|
869 |
+
dim,
|
870 |
+
alpha,
|
871 |
+
dropout=dropout,
|
872 |
+
rank_dropout=rank_dropout,
|
873 |
+
module_dropout=module_dropout,
|
874 |
+
)
|
875 |
+
loras.append(lora)
|
876 |
+
return loras, skipped
|
877 |
+
|
878 |
+
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
879 |
+
|
880 |
+
# create LoRA for text encoder
|
881 |
+
# creating every module each time is wasteful; should be reconsidered
|
882 |
+
self.text_encoder_loras = []
|
883 |
+
skipped_te = []
|
884 |
+
for i, text_encoder in enumerate(text_encoders):
|
885 |
+
if len(text_encoders) > 1:
|
886 |
+
index = i + 1
|
887 |
+
print(f"create LoRA for Text Encoder {index}:")
|
888 |
+
else:
|
889 |
+
index = None
|
890 |
+
print(f"create LoRA for Text Encoder:")
|
891 |
+
|
892 |
+
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
893 |
+
self.text_encoder_loras.extend(text_encoder_loras)
|
894 |
+
skipped_te += skipped
|
895 |
+
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
896 |
+
|
897 |
+
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
898 |
+
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
899 |
+
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
900 |
+
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
901 |
+
|
902 |
+
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
903 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
904 |
+
|
905 |
+
skipped = skipped_te + skipped_un
|
906 |
+
if varbose and len(skipped) > 0:
|
907 |
+
print(
|
908 |
+
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
909 |
+
)
|
910 |
+
for name in skipped:
|
911 |
+
print(f"\t{name}")
|
912 |
+
|
913 |
+
self.up_lr_weight: List[float] = None
|
914 |
+
self.down_lr_weight: List[float] = None
|
915 |
+
self.mid_lr_weight: float = None
|
916 |
+
self.block_lr = False
|
917 |
+
|
918 |
+
# assertion
|
919 |
+
names = set()
|
920 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
921 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
922 |
+
names.add(lora.lora_name)
|
923 |
+
|
924 |
+
def set_multiplier(self, multiplier):
|
925 |
+
self.multiplier = multiplier
|
926 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
927 |
+
lora.multiplier = self.multiplier
|
928 |
+
|
929 |
+
def load_weights(self, file):
|
930 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
931 |
+
from safetensors.torch import load_file
|
932 |
+
|
933 |
+
weights_sd = load_file(file)
|
934 |
+
else:
|
935 |
+
weights_sd = torch.load(file, map_location="cpu")
|
936 |
+
|
937 |
+
info = self.load_state_dict(weights_sd, False)
|
938 |
+
return info
|
939 |
+
|
940 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
941 |
+
if apply_text_encoder:
|
942 |
+
print("enable LoRA for text encoder")
|
943 |
+
else:
|
944 |
+
self.text_encoder_loras = []
|
945 |
+
|
946 |
+
if apply_unet:
|
947 |
+
print("enable LoRA for U-Net")
|
948 |
+
else:
|
949 |
+
self.unet_loras = []
|
950 |
+
|
951 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
952 |
+
lora.apply_to()
|
953 |
+
self.add_module(lora.lora_name, lora)
|
954 |
+
|
955 |
+
# returns whether the weights can be merged
|
956 |
+
def is_mergeable(self):
|
957 |
+
return True
|
958 |
+
|
959 |
+
# TODO refactor to common function with apply_to
|
960 |
+
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
961 |
+
apply_text_encoder = apply_unet = False
|
962 |
+
for key in weights_sd.keys():
|
963 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
964 |
+
apply_text_encoder = True
|
965 |
+
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
966 |
+
apply_unet = True
|
967 |
+
|
968 |
+
if apply_text_encoder:
|
969 |
+
print("enable LoRA for text encoder")
|
970 |
+
else:
|
971 |
+
self.text_encoder_loras = []
|
972 |
+
|
973 |
+
if apply_unet:
|
974 |
+
print("enable LoRA for U-Net")
|
975 |
+
else:
|
976 |
+
self.unet_loras = []
|
977 |
+
|
978 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
979 |
+
sd_for_lora = {}
|
980 |
+
for key in weights_sd.keys():
|
981 |
+
if key.startswith(lora.lora_name):
|
982 |
+
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
983 |
+
lora.merge_to(sd_for_lora, dtype, device)
|
984 |
+
|
985 |
+
print(f"weights are merged")
|
986 |
+
|
987 |
+
# set per-block multipliers for the layer-wise learning rate; the argument order is reversed, but leave it for now
|
988 |
+
def set_block_lr_weight(
|
989 |
+
self,
|
990 |
+
up_lr_weight: List[float] = None,
|
991 |
+
mid_lr_weight: float = None,
|
992 |
+
down_lr_weight: List[float] = None,
|
993 |
+
):
|
994 |
+
self.block_lr = True
|
995 |
+
self.down_lr_weight = down_lr_weight
|
996 |
+
self.mid_lr_weight = mid_lr_weight
|
997 |
+
self.up_lr_weight = up_lr_weight
|
998 |
+
|
999 |
+
def get_lr_weight(self, lora: LoRAModule) -> float:
|
1000 |
+
lr_weight = 1.0
|
1001 |
+
block_idx = get_block_index(lora.lora_name)
|
1002 |
+
if block_idx < 0:
|
1003 |
+
return lr_weight
|
1004 |
+
|
1005 |
+
if block_idx < LoRANetwork.NUM_OF_BLOCKS:
|
1006 |
+
if self.down_lr_weight != None:
|
1007 |
+
lr_weight = self.down_lr_weight[block_idx]
|
1008 |
+
elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
|
1009 |
+
if self.mid_lr_weight != None:
|
1010 |
+
lr_weight = self.mid_lr_weight
|
1011 |
+
elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
|
1012 |
+
if self.up_lr_weight != None:
|
1013 |
+
lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
|
1014 |
+
|
1015 |
+
return lr_weight
|
1016 |
+
|
1017 |
+
# it might be worth allowing separate learning rates for the two Text Encoders
|
1018 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
1019 |
+
self.requires_grad_(True)
|
1020 |
+
all_params = []
|
1021 |
+
|
1022 |
+
def enumerate_params(loras):
|
1023 |
+
params = []
|
1024 |
+
for lora in loras:
|
1025 |
+
params.extend(lora.parameters())
|
1026 |
+
return params
|
1027 |
+
|
1028 |
+
if self.text_encoder_loras:
|
1029 |
+
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
1030 |
+
if text_encoder_lr is not None:
|
1031 |
+
param_data["lr"] = text_encoder_lr
|
1032 |
+
all_params.append(param_data)
|
1033 |
+
|
1034 |
+
if self.unet_loras:
|
1035 |
+
if self.block_lr:
|
1036 |
+
# group the LoRA modules by block so the learning rate can be tracked per block
|
1037 |
+
block_idx_to_lora = {}
|
1038 |
+
for lora in self.unet_loras:
|
1039 |
+
idx = get_block_index(lora.lora_name)
|
1040 |
+
if idx not in block_idx_to_lora:
|
1041 |
+
block_idx_to_lora[idx] = []
|
1042 |
+
block_idx_to_lora[idx].append(lora)
|
1043 |
+
|
1044 |
+
# set the parameters per block
|
1045 |
+
for idx, block_loras in block_idx_to_lora.items():
|
1046 |
+
param_data = {"params": enumerate_params(block_loras)}
|
1047 |
+
|
1048 |
+
if unet_lr is not None:
|
1049 |
+
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
|
1050 |
+
elif default_lr is not None:
|
1051 |
+
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
|
1052 |
+
if ("lr" in param_data) and (param_data["lr"] == 0):
|
1053 |
+
continue
|
1054 |
+
all_params.append(param_data)
|
1055 |
+
|
1056 |
+
else:
|
1057 |
+
param_data = {"params": enumerate_params(self.unet_loras)}
|
1058 |
+
if unet_lr is not None:
|
1059 |
+
param_data["lr"] = unet_lr
|
1060 |
+
all_params.append(param_data)
|
1061 |
+
|
1062 |
+
return all_params
|
1063 |
+
|
1064 |
+
def enable_gradient_checkpointing(self):
|
1065 |
+
# not supported
|
1066 |
+
pass
|
1067 |
+
|
1068 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
1069 |
+
self.requires_grad_(True)
|
1070 |
+
|
1071 |
+
def on_epoch_start(self, text_encoder, unet):
|
1072 |
+
self.train()
|
1073 |
+
|
1074 |
+
def get_trainable_params(self):
|
1075 |
+
return self.parameters()
|
1076 |
+
|
1077 |
+
def save_weights(self, file, dtype, metadata):
|
1078 |
+
if metadata is not None and len(metadata) == 0:
|
1079 |
+
metadata = None
|
1080 |
+
|
1081 |
+
state_dict = self.state_dict()
|
1082 |
+
|
1083 |
+
if dtype is not None:
|
1084 |
+
for key in list(state_dict.keys()):
|
1085 |
+
v = state_dict[key]
|
1086 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
1087 |
+
state_dict[key] = v
|
1088 |
+
|
1089 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
1090 |
+
from safetensors.torch import save_file
|
1091 |
+
from library import train_util
|
1092 |
+
|
1093 |
+
# Precalculate model hashes to save time on indexing
|
1094 |
+
if metadata is None:
|
1095 |
+
metadata = {}
|
1096 |
+
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
1097 |
+
metadata["sshs_model_hash"] = model_hash
|
1098 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
1099 |
+
|
1100 |
+
save_file(state_dict, file, metadata)
|
1101 |
+
else:
|
1102 |
+
torch.save(state_dict, file)
|
1103 |
+
|
1104 |
+
# mask is a tensor with values from 0 to 1
|
1105 |
+
def set_region(self, sub_prompt_index, is_last_network, mask):
|
1106 |
+
if mask.max() == 0:
|
1107 |
+
mask = torch.ones_like(mask)
|
1108 |
+
|
1109 |
+
self.mask = mask
|
1110 |
+
self.sub_prompt_index = sub_prompt_index
|
1111 |
+
self.is_last_network = is_last_network
|
1112 |
+
|
1113 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
1114 |
+
lora.set_network(self)
|
1115 |
+
|
1116 |
+
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
1117 |
+
self.batch_size = batch_size
|
1118 |
+
self.num_sub_prompts = num_sub_prompts
|
1119 |
+
self.current_size = (height, width)
|
1120 |
+
self.shared = shared
|
1121 |
+
|
1122 |
+
# create masks
|
1123 |
+
mask = self.mask
|
1124 |
+
mask_dic = {}
|
1125 |
+
mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
|
1126 |
+
ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
|
1127 |
+
dtype = ref_weight.dtype
|
1128 |
+
device = ref_weight.device
|
1129 |
+
|
1130 |
+
def resize_add(mh, mw):
|
1131 |
+
# print(mh, mw, mh * mw)
|
1132 |
+
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
1133 |
+
m = m.to(device, dtype=dtype)
|
1134 |
+
mask_dic[mh * mw] = m
|
1135 |
+
|
1136 |
+
h = height // 8
|
1137 |
+
w = width // 8
|
1138 |
+
for _ in range(4):
|
1139 |
+
resize_add(h, w)
|
1140 |
+
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
|
1141 |
+
resize_add(h + h % 2, w + w % 2)
|
1142 |
+
h = (h + 1) // 2
|
1143 |
+
w = (w + 1) // 2
|
1144 |
+
|
1145 |
+
self.mask_dic = mask_dic
|
1146 |
+
|
1147 |
+
def backup_weights(self):
|
1148 |
+
# back up the original weights
|
1149 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
1150 |
+
for lora in loras:
|
1151 |
+
org_module = lora.org_module_ref[0]
|
1152 |
+
if not hasattr(org_module, "_lora_org_weight"):
|
1153 |
+
sd = org_module.state_dict()
|
1154 |
+
org_module._lora_org_weight = sd["weight"].detach().clone()
|
1155 |
+
org_module._lora_restored = True
|
1156 |
+
|
1157 |
+
def restore_weights(self):
|
1158 |
+
# restore the original weights
|
1159 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
1160 |
+
for lora in loras:
|
1161 |
+
org_module = lora.org_module_ref[0]
|
1162 |
+
if not org_module._lora_restored:
|
1163 |
+
sd = org_module.state_dict()
|
1164 |
+
sd["weight"] = org_module._lora_org_weight
|
1165 |
+
org_module.load_state_dict(sd)
|
1166 |
+
org_module._lora_restored = True
|
1167 |
+
|
1168 |
+
def pre_calculation(self):
|
1169 |
+
# pre-compute the merged weights
|
1170 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
1171 |
+
for lora in loras:
|
1172 |
+
org_module = lora.org_module_ref[0]
|
1173 |
+
sd = org_module.state_dict()
|
1174 |
+
|
1175 |
+
org_weight = sd["weight"]
|
1176 |
+
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
1177 |
+
sd["weight"] = org_weight + lora_weight
|
1178 |
+
assert sd["weight"].shape == org_weight.shape
|
1179 |
+
org_module.load_state_dict(sd)
|
1180 |
+
|
1181 |
+
org_module._lora_restored = False
|
1182 |
+
lora.enabled = False
|
1183 |
+
|
1184 |
+
def apply_max_norm_regularization(self, max_norm_value, device):
|
1185 |
+
downkeys = []
|
1186 |
+
upkeys = []
|
1187 |
+
alphakeys = []
|
1188 |
+
norms = []
|
1189 |
+
keys_scaled = 0
|
1190 |
+
|
1191 |
+
state_dict = self.state_dict()
|
1192 |
+
for key in state_dict.keys():
|
1193 |
+
if "lora_down" in key and "weight" in key:
|
1194 |
+
downkeys.append(key)
|
1195 |
+
upkeys.append(key.replace("lora_down", "lora_up"))
|
1196 |
+
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
1197 |
+
|
1198 |
+
for i in range(len(downkeys)):
|
1199 |
+
down = state_dict[downkeys[i]].to(device)
|
1200 |
+
up = state_dict[upkeys[i]].to(device)
|
1201 |
+
alpha = state_dict[alphakeys[i]].to(device)
|
1202 |
+
dim = down.shape[0]
|
1203 |
+
scale = alpha / dim
|
1204 |
+
|
1205 |
+
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
1206 |
+
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
1207 |
+
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
1208 |
+
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
1209 |
+
else:
|
1210 |
+
updown = up @ down
|
1211 |
+
|
1212 |
+
updown *= scale
|
1213 |
+
|
1214 |
+
norm = updown.norm().clamp(min=max_norm_value / 2)
|
1215 |
+
desired = torch.clamp(norm, max=max_norm_value)
|
1216 |
+
ratio = desired.cpu() / norm.cpu()
|
1217 |
+
sqrt_ratio = ratio**0.5
|
1218 |
+
if ratio != 1:
|
1219 |
+
keys_scaled += 1
|
1220 |
+
state_dict[upkeys[i]] *= sqrt_ratio
|
1221 |
+
state_dict[downkeys[i]] *= sqrt_ratio
|
1222 |
+
scalednorm = updown.norm() * ratio
|
1223 |
+
norms.append(scalednorm.item())
|
1224 |
+
|
1225 |
+
return keys_scaled, sum(norms) / len(norms), max(norms)
|
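A minimal usage sketch for the network defined above (editor's addition, not part of the diff; the vae, text_encoder, and unet objects and the LoRA file path are placeholders supplied by the caller's own model-loading code):

    import torch

    # build the network from a trained LoRA file and merge it into the base model
    network, weights_sd = create_network_from_weights(
        1.0, "trained_lora.safetensors", vae, text_encoder, unet, for_inference=True
    )
    network.merge_to(text_encoder, unet, weights_sd, torch.float16, "cuda")

    # alternatively, hook the modules without merging so the effect can be rescaled later
    # network.apply_to(text_encoder, unet)
    # network.load_state_dict(weights_sd, strict=False)
    # network.set_multiplier(0.8)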
external/llite/networks/lora_diffusers.py
ADDED
@@ -0,0 +1,609 @@
1 |
+
# Diffusersで動くLoRA。このファイル単独で完結する。
|
2 |
+
# LoRA module for Diffusers. This file works independently.
|
3 |
+
|
4 |
+
import bisect
|
5 |
+
import math
|
6 |
+
import random
|
7 |
+
from typing import Any, Dict, List, Mapping, Optional, Union
|
8 |
+
from diffusers import UNet2DConditionModel
|
9 |
+
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
from transformers import CLIPTextModel
|
12 |
+
import torch
|
13 |
+
|
14 |
+
|
15 |
+
def make_unet_conversion_map() -> Dict[str, str]:
|
16 |
+
unet_conversion_map_layer = []
|
17 |
+
|
18 |
+
for i in range(3): # num_blocks is 3 in sdxl
|
19 |
+
# loop over downblocks/upblocks
|
20 |
+
for j in range(2):
|
21 |
+
# loop over resnets/attentions for downblocks
|
22 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
23 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
24 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
25 |
+
|
26 |
+
if i < 3:
|
27 |
+
# no attention layers in down_blocks.3
|
28 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
29 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
30 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
31 |
+
|
32 |
+
for j in range(3):
|
33 |
+
# loop over resnets/attentions for upblocks
|
34 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
35 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
36 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
37 |
+
|
38 |
+
# if i > 0: commentout for sdxl
|
39 |
+
# no attention layers in up_blocks.0
|
40 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
41 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
42 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
43 |
+
|
44 |
+
if i < 3:
|
45 |
+
# no downsample in down_blocks.3
|
46 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
47 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
48 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
49 |
+
|
50 |
+
# no upsample in up_blocks.3
|
51 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
52 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
53 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
54 |
+
|
55 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
56 |
+
sd_mid_atn_prefix = "middle_block.1."
|
57 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
58 |
+
|
59 |
+
for j in range(2):
|
60 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
61 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
62 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
63 |
+
|
64 |
+
unet_conversion_map_resnet = [
|
65 |
+
# (stable-diffusion, HF Diffusers)
|
66 |
+
("in_layers.0.", "norm1."),
|
67 |
+
("in_layers.2.", "conv1."),
|
68 |
+
("out_layers.0.", "norm2."),
|
69 |
+
("out_layers.3.", "conv2."),
|
70 |
+
("emb_layers.1.", "time_emb_proj."),
|
71 |
+
("skip_connection.", "conv_shortcut."),
|
72 |
+
]
|
73 |
+
|
74 |
+
unet_conversion_map = []
|
75 |
+
for sd, hf in unet_conversion_map_layer:
|
76 |
+
if "resnets" in hf:
|
77 |
+
for sd_res, hf_res in unet_conversion_map_resnet:
|
78 |
+
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
79 |
+
else:
|
80 |
+
unet_conversion_map.append((sd, hf))
|
81 |
+
|
82 |
+
for j in range(2):
|
83 |
+
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
84 |
+
sd_time_embed_prefix = f"time_embed.{j*2}."
|
85 |
+
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
86 |
+
|
87 |
+
for j in range(2):
|
88 |
+
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
89 |
+
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
90 |
+
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
91 |
+
|
92 |
+
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
93 |
+
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
94 |
+
unet_conversion_map.append(("out.2.", "conv_out."))
|
95 |
+
|
96 |
+
sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
|
97 |
+
return sd_hf_conversion_map
|
98 |
+
|
99 |
+
|
100 |
+
UNET_CONVERSION_MAP = make_unet_conversion_map()
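# Example entries of the resulting map (editor's note, not part of the original file):
#   "input_blocks_1_0_in_layers_2" -> "down_blocks_0_resnets_0_conv1"
#   "middle_block_1"               -> "mid_block_attentions_0"
#   "output_blocks_2_2"            -> "up_blocks_0_upsamplers_0"
# i.e. it maps Stability AI (SGM) U-Net module-name prefixes to Diffusers ones,
# with dots replaced by underscores so the keys line up with LoRA module names.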
|
101 |
+
|
102 |
+
|
103 |
+
class LoRAModule(torch.nn.Module):
|
104 |
+
"""
|
105 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
106 |
+
"""
|
107 |
+
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
lora_name,
|
111 |
+
org_module: torch.nn.Module,
|
112 |
+
multiplier=1.0,
|
113 |
+
lora_dim=4,
|
114 |
+
alpha=1,
|
115 |
+
):
|
116 |
+
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
117 |
+
super().__init__()
|
118 |
+
self.lora_name = lora_name
|
119 |
+
|
120 |
+
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
|
121 |
+
in_dim = org_module.in_channels
|
122 |
+
out_dim = org_module.out_channels
|
123 |
+
else:
|
124 |
+
in_dim = org_module.in_features
|
125 |
+
out_dim = org_module.out_features
|
126 |
+
|
127 |
+
self.lora_dim = lora_dim
|
128 |
+
|
129 |
+
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
|
130 |
+
kernel_size = org_module.kernel_size
|
131 |
+
stride = org_module.stride
|
132 |
+
padding = org_module.padding
|
133 |
+
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
134 |
+
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
135 |
+
else:
|
136 |
+
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
137 |
+
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
138 |
+
|
139 |
+
if type(alpha) == torch.Tensor:
|
140 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
141 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
142 |
+
self.scale = alpha / self.lora_dim
|
143 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # 勾配計算に含めない / not included in gradient calculation
|
144 |
+
|
145 |
+
# same as microsoft's
|
146 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
147 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
148 |
+
|
149 |
+
self.multiplier = multiplier
|
150 |
+
self.org_module = [org_module]
|
151 |
+
self.enabled = True
|
152 |
+
self.network: LoRANetwork = None
|
153 |
+
self.org_forward = None
|
154 |
+
|
155 |
+
# override org_module's forward method
|
156 |
+
def apply_to(self, multiplier=None):
|
157 |
+
if multiplier is not None:
|
158 |
+
self.multiplier = multiplier
|
159 |
+
if self.org_forward is None:
|
160 |
+
self.org_forward = self.org_module[0].forward
|
161 |
+
self.org_module[0].forward = self.forward
|
162 |
+
|
163 |
+
# restore org_module's forward method
|
164 |
+
def unapply_to(self):
|
165 |
+
if self.org_forward is not None:
|
166 |
+
self.org_module[0].forward = self.org_forward
|
167 |
+
|
168 |
+
# forward with lora
|
169 |
+
# scale is used by LoRACompatibleConv, but we ignore it because we have our own multiplier
|
170 |
+
def forward(self, x, scale=1.0):
|
171 |
+
if not self.enabled:
|
172 |
+
return self.org_forward(x)
|
173 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
174 |
+
|
175 |
+
def set_network(self, network):
|
176 |
+
self.network = network
|
177 |
+
|
178 |
+
# merge lora weight to org weight
|
179 |
+
def merge_to(self, multiplier=1.0):
|
180 |
+
# get lora weight
|
181 |
+
lora_weight = self.get_weight(multiplier)
|
182 |
+
|
183 |
+
# get org weight
|
184 |
+
org_sd = self.org_module[0].state_dict()
|
185 |
+
org_weight = org_sd["weight"]
|
186 |
+
weight = org_weight + lora_weight.to(org_weight.device, dtype=org_weight.dtype)
|
187 |
+
|
188 |
+
# set weight to org_module
|
189 |
+
org_sd["weight"] = weight
|
190 |
+
self.org_module[0].load_state_dict(org_sd)
|
191 |
+
|
192 |
+
# restore org weight from lora weight
|
193 |
+
def restore_from(self, multiplier=1.0):
|
194 |
+
# get lora weight
|
195 |
+
lora_weight = self.get_weight(multiplier)
|
196 |
+
|
197 |
+
# get org weight
|
198 |
+
org_sd = self.org_module[0].state_dict()
|
199 |
+
org_weight = org_sd["weight"]
|
200 |
+
weight = org_weight - lora_weight.to(org_weight.device, dtype=org_weight.dtype)
|
201 |
+
|
202 |
+
# set weight to org_module
|
203 |
+
org_sd["weight"] = weight
|
204 |
+
self.org_module[0].load_state_dict(org_sd)
|
205 |
+
|
206 |
+
# return lora weight
|
207 |
+
def get_weight(self, multiplier=None):
|
208 |
+
if multiplier is None:
|
209 |
+
multiplier = self.multiplier
|
210 |
+
|
211 |
+
# get up/down weight from module
|
212 |
+
up_weight = self.lora_up.weight.to(torch.float)
|
213 |
+
down_weight = self.lora_down.weight.to(torch.float)
|
214 |
+
|
215 |
+
# pre-calculated weight
|
216 |
+
if len(down_weight.size()) == 2:
|
217 |
+
# linear
|
218 |
+
weight = self.multiplier * (up_weight @ down_weight) * self.scale
|
219 |
+
elif down_weight.size()[2:4] == (1, 1):
|
220 |
+
# conv2d 1x1
|
221 |
+
weight = (
|
222 |
+
self.multiplier
|
223 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
224 |
+
* self.scale
|
225 |
+
)
|
226 |
+
else:
|
227 |
+
# conv2d 3x3
|
228 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
229 |
+
weight = self.multiplier * conved * self.scale
|
230 |
+
|
231 |
+
return weight
|
232 |
+
|
233 |
+
|
234 |
+
# Create network from weights for inference, weights are not loaded here
|
235 |
+
def create_network_from_weights(
|
236 |
+
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0
|
237 |
+
):
|
238 |
+
# get dim/alpha mapping
|
239 |
+
modules_dim = {}
|
240 |
+
modules_alpha = {}
|
241 |
+
for key, value in weights_sd.items():
|
242 |
+
if "." not in key:
|
243 |
+
continue
|
244 |
+
|
245 |
+
lora_name = key.split(".")[0]
|
246 |
+
if "alpha" in key:
|
247 |
+
modules_alpha[lora_name] = value
|
248 |
+
elif "lora_down" in key:
|
249 |
+
dim = value.size()[0]
|
250 |
+
modules_dim[lora_name] = dim
|
251 |
+
# print(lora_name, value.size(), dim)
|
252 |
+
|
253 |
+
# support old LoRA without alpha
|
254 |
+
for key in modules_dim.keys():
|
255 |
+
if key not in modules_alpha:
|
256 |
+
modules_alpha[key] = modules_dim[key]
|
257 |
+
|
258 |
+
return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
259 |
+
|
260 |
+
|
261 |
+
def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
|
262 |
+
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
|
263 |
+
unet = pipe.unet
|
264 |
+
|
265 |
+
lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
|
266 |
+
lora_network.load_state_dict(weights_sd)
|
267 |
+
lora_network.merge_to(multiplier=multiplier)
|
268 |
+
|
269 |
+
|
270 |
+
# block weightや学習に対応しない簡易版 / simple version without block weight and training
|
271 |
+
class LoRANetwork(torch.nn.Module):
|
272 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
273 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
274 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
275 |
+
LORA_PREFIX_UNET = "lora_unet"
|
276 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
277 |
+
|
278 |
+
# SDXL: must start with LORA_PREFIX_TEXT_ENCODER
|
279 |
+
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
|
280 |
+
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
|
281 |
+
|
282 |
+
def __init__(
|
283 |
+
self,
|
284 |
+
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
285 |
+
unet: UNet2DConditionModel,
|
286 |
+
multiplier: float = 1.0,
|
287 |
+
modules_dim: Optional[Dict[str, int]] = None,
|
288 |
+
modules_alpha: Optional[Dict[str, int]] = None,
|
289 |
+
varbose: Optional[bool] = False,
|
290 |
+
) -> None:
|
291 |
+
super().__init__()
|
292 |
+
self.multiplier = multiplier
|
293 |
+
|
294 |
+
print(f"create LoRA network from weights")
|
295 |
+
|
296 |
+
# convert SDXL Stability AI's U-Net modules to Diffusers
|
297 |
+
converted = self.convert_unet_modules(modules_dim, modules_alpha)
|
298 |
+
if converted:
|
299 |
+
print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
|
300 |
+
|
301 |
+
# create module instances
|
302 |
+
def create_modules(
|
303 |
+
is_unet: bool,
|
304 |
+
text_encoder_idx: Optional[int], # None, 1, 2
|
305 |
+
root_module: torch.nn.Module,
|
306 |
+
target_replace_modules: List[torch.nn.Module],
|
307 |
+
) -> List[LoRAModule]:
|
308 |
+
prefix = (
|
309 |
+
self.LORA_PREFIX_UNET
|
310 |
+
if is_unet
|
311 |
+
else (
|
312 |
+
self.LORA_PREFIX_TEXT_ENCODER
|
313 |
+
if text_encoder_idx is None
|
314 |
+
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
|
315 |
+
)
|
316 |
+
)
|
317 |
+
loras = []
|
318 |
+
skipped = []
|
319 |
+
for name, module in root_module.named_modules():
|
320 |
+
if module.__class__.__name__ in target_replace_modules:
|
321 |
+
for child_name, child_module in module.named_modules():
|
322 |
+
is_linear = (
|
323 |
+
child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
|
324 |
+
)
|
325 |
+
is_conv2d = (
|
326 |
+
child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
|
327 |
+
)
|
328 |
+
|
329 |
+
if is_linear or is_conv2d:
|
330 |
+
lora_name = prefix + "." + name + "." + child_name
|
331 |
+
lora_name = lora_name.replace(".", "_")
|
332 |
+
|
333 |
+
if lora_name not in modules_dim:
|
334 |
+
# print(f"skipped {lora_name} (not found in modules_dim)")
|
335 |
+
skipped.append(lora_name)
|
336 |
+
continue
|
337 |
+
|
338 |
+
dim = modules_dim[lora_name]
|
339 |
+
alpha = modules_alpha[lora_name]
|
340 |
+
lora = LoRAModule(
|
341 |
+
lora_name,
|
342 |
+
child_module,
|
343 |
+
self.multiplier,
|
344 |
+
dim,
|
345 |
+
alpha,
|
346 |
+
)
|
347 |
+
loras.append(lora)
|
348 |
+
return loras, skipped
|
349 |
+
|
350 |
+
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
351 |
+
|
352 |
+
# create LoRA for text encoder
|
353 |
+
# 毎回すべてのモジュールを作るのは無駄なので要検討 / it is wasteful to create all modules every time, need to consider
|
354 |
+
self.text_encoder_loras: List[LoRAModule] = []
|
355 |
+
skipped_te = []
|
356 |
+
for i, text_encoder in enumerate(text_encoders):
|
357 |
+
if len(text_encoders) > 1:
|
358 |
+
index = i + 1
|
359 |
+
else:
|
360 |
+
index = None
|
361 |
+
|
362 |
+
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
363 |
+
self.text_encoder_loras.extend(text_encoder_loras)
|
364 |
+
skipped_te += skipped
|
365 |
+
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
366 |
+
if len(skipped_te) > 0:
|
367 |
+
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
|
368 |
+
|
369 |
+
# extend U-Net target modules to include Conv2d 3x3
|
370 |
+
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
371 |
+
|
372 |
+
self.unet_loras: List[LoRAModule]
|
373 |
+
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
374 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
375 |
+
if len(skipped_un) > 0:
|
376 |
+
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
|
377 |
+
|
378 |
+
# assertion
|
379 |
+
names = set()
|
380 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
381 |
+
names.add(lora.lora_name)
|
382 |
+
for lora_name in modules_dim.keys():
|
383 |
+
assert lora_name in names, f"{lora_name} is not found in created LoRA modules."
|
384 |
+
|
385 |
+
# register the modules so that load_state_dict works
|
386 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
387 |
+
self.add_module(lora.lora_name, lora)
|
388 |
+
|
389 |
+
# SDXL: convert SDXL Stability AI's U-Net modules to Diffusers
|
390 |
+
def convert_unet_modules(self, modules_dim, modules_alpha):
|
391 |
+
converted_count = 0
|
392 |
+
not_converted_count = 0
|
393 |
+
|
394 |
+
map_keys = list(UNET_CONVERSION_MAP.keys())
|
395 |
+
map_keys.sort()
|
396 |
+
|
397 |
+
for key in list(modules_dim.keys()):
|
398 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
|
399 |
+
search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
|
400 |
+
position = bisect.bisect_right(map_keys, search_key)
|
401 |
+
map_key = map_keys[position - 1]
|
402 |
+
if search_key.startswith(map_key):
|
403 |
+
new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
|
404 |
+
modules_dim[new_key] = modules_dim[key]
|
405 |
+
modules_alpha[new_key] = modules_alpha[key]
|
406 |
+
del modules_dim[key]
|
407 |
+
del modules_alpha[key]
|
408 |
+
converted_count += 1
|
409 |
+
else:
|
410 |
+
not_converted_count += 1
|
411 |
+
assert (
|
412 |
+
converted_count == 0 or not_converted_count == 0
|
413 |
+
), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted"
|
414 |
+
return converted_count
|
415 |
+
|
416 |
+
def set_multiplier(self, multiplier):
|
417 |
+
self.multiplier = multiplier
|
418 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
419 |
+
lora.multiplier = self.multiplier
|
420 |
+
|
421 |
+
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
|
422 |
+
if apply_text_encoder:
|
423 |
+
print("enable LoRA for text encoder")
|
424 |
+
for lora in self.text_encoder_loras:
|
425 |
+
lora.apply_to(multiplier)
|
426 |
+
if apply_unet:
|
427 |
+
print("enable LoRA for U-Net")
|
428 |
+
for lora in self.unet_loras:
|
429 |
+
lora.apply_to(multiplier)
|
430 |
+
|
431 |
+
def unapply_to(self):
|
432 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
433 |
+
lora.unapply_to()
|
434 |
+
|
435 |
+
def merge_to(self, multiplier=1.0):
|
436 |
+
print("merge LoRA weights to original weights")
|
437 |
+
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
438 |
+
lora.merge_to(multiplier)
|
439 |
+
print(f"weights are merged")
|
440 |
+
|
441 |
+
def restore_from(self, multiplier=1.0):
|
442 |
+
print("restore LoRA weights from original weights")
|
443 |
+
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
444 |
+
lora.restore_from(multiplier)
|
445 |
+
print(f"weights are restored")
|
446 |
+
|
447 |
+
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
448 |
+
# convert SDXL Stability AI's state dict to Diffusers' based state dict
|
449 |
+
map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
|
450 |
+
map_keys.sort()
|
451 |
+
for key in list(state_dict.keys()):
|
452 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
|
453 |
+
search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
|
454 |
+
position = bisect.bisect_right(map_keys, search_key)
|
455 |
+
map_key = map_keys[position - 1]
|
456 |
+
if search_key.startswith(map_key):
|
457 |
+
new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
|
458 |
+
state_dict[new_key] = state_dict[key]
|
459 |
+
del state_dict[key]
|
460 |
+
|
461 |
+
# in case of V2, some weights have different shape, so we need to convert them
|
462 |
+
# because V2 LoRA is based on U-Net created by use_linear_projection=False
|
463 |
+
my_state_dict = self.state_dict()
|
464 |
+
for key in state_dict.keys():
|
465 |
+
if state_dict[key].size() != my_state_dict[key].size():
|
466 |
+
# print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
|
467 |
+
state_dict[key] = state_dict[key].view(my_state_dict[key].size())
|
468 |
+
|
469 |
+
return super().load_state_dict(state_dict, strict)
|
470 |
+
|
471 |
+
|
472 |
+
if __name__ == "__main__":
|
473 |
+
# sample code to use LoRANetwork
|
474 |
+
import os
|
475 |
+
import argparse
|
476 |
+
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
477 |
+
import torch
|
478 |
+
|
479 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
480 |
+
|
481 |
+
parser = argparse.ArgumentParser()
|
482 |
+
parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
|
483 |
+
parser.add_argument("--lora_weights", type=str, default=None, help="path to LoRA weights")
|
484 |
+
parser.add_argument("--sdxl", action="store_true", help="use SDXL model")
|
485 |
+
parser.add_argument("--prompt", type=str, default="A photo of cat", help="prompt text")
|
486 |
+
parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt text")
|
487 |
+
parser.add_argument("--seed", type=int, default=0, help="random seed")
|
488 |
+
args = parser.parse_args()
|
489 |
+
|
490 |
+
image_prefix = args.model_id.replace("/", "_") + "_"
|
491 |
+
|
492 |
+
# load Diffusers model
|
493 |
+
print(f"load model from {args.model_id}")
|
494 |
+
pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
|
495 |
+
if args.sdxl:
|
496 |
+
# use_safetensors=True does not work with 0.18.2
|
497 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
|
498 |
+
else:
|
499 |
+
pipe = StableDiffusionPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
|
500 |
+
pipe.to(device)
|
501 |
+
pipe.set_use_memory_efficient_attention_xformers(True)
|
502 |
+
|
503 |
+
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
|
504 |
+
|
505 |
+
# load LoRA weights
|
506 |
+
print(f"load LoRA weights from {args.lora_weights}")
|
507 |
+
if os.path.splitext(args.lora_weights)[1] == ".safetensors":
|
508 |
+
from safetensors.torch import load_file
|
509 |
+
|
510 |
+
lora_sd = load_file(args.lora_weights)
|
511 |
+
else:
|
512 |
+
lora_sd = torch.load(args.lora_weights)
|
513 |
+
|
514 |
+
# create the network from the LoRA weights and load them
|
515 |
+
print(f"create LoRA network")
|
516 |
+
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
|
517 |
+
|
518 |
+
print(f"load LoRA network weights")
|
519 |
+
lora_network.load_state_dict(lora_sd)
|
520 |
+
|
521 |
+
lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
|
522 |
+
|
523 |
+
# 必要があれば、元のモデルの重みをバックアップしておく
|
524 |
+
# back-up unet/text encoder weights if necessary
|
525 |
+
def detach_and_move_to_cpu(state_dict):
|
526 |
+
for k, v in state_dict.items():
|
527 |
+
state_dict[k] = v.detach().cpu()
|
528 |
+
return state_dict
|
529 |
+
|
530 |
+
org_unet_sd = pipe.unet.state_dict()
|
531 |
+
detach_and_move_to_cpu(org_unet_sd)
|
532 |
+
|
533 |
+
org_text_encoder_sd = pipe.text_encoder.state_dict()
|
534 |
+
detach_and_move_to_cpu(org_text_encoder_sd)
|
535 |
+
|
536 |
+
if args.sdxl:
|
537 |
+
org_text_encoder_2_sd = pipe.text_encoder_2.state_dict()
|
538 |
+
detach_and_move_to_cpu(org_text_encoder_2_sd)
|
539 |
+
|
540 |
+
def seed_everything(seed):
|
541 |
+
torch.manual_seed(seed)
|
542 |
+
torch.cuda.manual_seed_all(seed)
|
543 |
+
np.random.seed(seed)
|
544 |
+
random.seed(seed)
|
545 |
+
|
546 |
+
# create image with original weights
|
547 |
+
print(f"create image with original weights")
|
548 |
+
seed_everything(args.seed)
|
549 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
550 |
+
image.save(image_prefix + "original.png")
|
551 |
+
|
552 |
+
# apply LoRA network to the model: slower than merge_to, but can be reverted easily
|
553 |
+
print(f"apply LoRA network to the model")
|
554 |
+
lora_network.apply_to(multiplier=1.0)
|
555 |
+
|
556 |
+
print(f"create image with applied LoRA")
|
557 |
+
seed_everything(args.seed)
|
558 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
559 |
+
image.save(image_prefix + "applied_lora.png")
|
560 |
+
|
561 |
+
# unapply LoRA network from the model
|
562 |
+
print(f"unapply LoRA network to the model")
|
563 |
+
lora_network.unapply_to()
|
564 |
+
|
565 |
+
print(f"create image with unapplied LoRA")
|
566 |
+
seed_everything(args.seed)
|
567 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
568 |
+
image.save(image_prefix + "unapplied_lora.png")
|
569 |
+
|
570 |
+
# merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
|
571 |
+
print(f"merge LoRA network to the model")
|
572 |
+
lora_network.merge_to(multiplier=1.0)
|
573 |
+
|
574 |
+
print(f"create image with LoRA")
|
575 |
+
seed_everything(args.seed)
|
576 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
577 |
+
image.save(image_prefix + "merged_lora.png")
|
578 |
+
|
579 |
+
# restore (unmerge) LoRA weights: numerically unstable
|
580 |
+
# revert the merged weights; due to numerical error they may not exactly match the original weights
|
581 |
+
# restoring the original weights from a saved state_dict is the reliable way
|
582 |
+
print(f"restore (unmerge) LoRA weights")
|
583 |
+
lora_network.restore_from(multiplier=1.0)
|
584 |
+
|
585 |
+
print(f"create image without LoRA")
|
586 |
+
seed_everything(args.seed)
|
587 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
588 |
+
image.save(image_prefix + "unmerged_lora.png")
|
589 |
+
|
590 |
+
# restore original weights
|
591 |
+
print(f"restore original weights")
|
592 |
+
pipe.unet.load_state_dict(org_unet_sd)
|
593 |
+
pipe.text_encoder.load_state_dict(org_text_encoder_sd)
|
594 |
+
if args.sdxl:
|
595 |
+
pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
|
596 |
+
|
597 |
+
print(f"create image with restored original weights")
|
598 |
+
seed_everything(args.seed)
|
599 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
600 |
+
image.save(image_prefix + "restore_original.png")
|
601 |
+
|
602 |
+
# use convenience function to merge LoRA weights
|
603 |
+
print(f"merge LoRA weights with convenience function")
|
604 |
+
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
|
605 |
+
|
606 |
+
print(f"create image with merged LoRA weights")
|
607 |
+
seed_everything(args.seed)
|
608 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
609 |
+
image.save(image_prefix + "convenience_merged_lora.png")
|
external/llite/networks/lora_fa.py
ADDED
@@ -0,0 +1,1241 @@
1 |
+
# LoRA network module
|
2 |
+
# reference:
|
3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
5 |
+
|
6 |
+
# temporary implementation of LoRA-FA: https://arxiv.org/abs/2308.03303
|
7 |
+
# need to be refactored and merged to lora.py
|
8 |
+
|
9 |
+
import math
|
10 |
+
import os
|
11 |
+
from typing import Dict, List, Optional, Tuple, Type, Union
|
12 |
+
from diffusers import AutoencoderKL
|
13 |
+
from transformers import CLIPTextModel
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import re
|
17 |
+
|
18 |
+
|
19 |
+
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
20 |
+
|
21 |
+
|
22 |
+
class LoRAModule(torch.nn.Module):
|
23 |
+
"""
|
24 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
lora_name,
|
30 |
+
org_module: torch.nn.Module,
|
31 |
+
multiplier=1.0,
|
32 |
+
lora_dim=4,
|
33 |
+
alpha=1,
|
34 |
+
dropout=None,
|
35 |
+
rank_dropout=None,
|
36 |
+
module_dropout=None,
|
37 |
+
):
|
38 |
+
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
39 |
+
super().__init__()
|
40 |
+
self.lora_name = lora_name
|
41 |
+
|
42 |
+
if org_module.__class__.__name__ == "Conv2d":
|
43 |
+
in_dim = org_module.in_channels
|
44 |
+
out_dim = org_module.out_channels
|
45 |
+
else:
|
46 |
+
in_dim = org_module.in_features
|
47 |
+
out_dim = org_module.out_features
|
48 |
+
|
49 |
+
# if limit_rank:
|
50 |
+
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
51 |
+
# if self.lora_dim != lora_dim:
|
52 |
+
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
53 |
+
# else:
|
54 |
+
self.lora_dim = lora_dim
|
55 |
+
|
56 |
+
if org_module.__class__.__name__ == "Conv2d":
|
57 |
+
kernel_size = org_module.kernel_size
|
58 |
+
stride = org_module.stride
|
59 |
+
padding = org_module.padding
|
60 |
+
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
61 |
+
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
62 |
+
else:
|
63 |
+
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
64 |
+
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
65 |
+
|
66 |
+
if type(alpha) == torch.Tensor:
|
67 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
68 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
69 |
+
self.scale = alpha / self.lora_dim
|
70 |
+
self.register_buffer("alpha", torch.tensor(alpha))  # treated as a constant
|
71 |
+
|
72 |
+
# # same as microsoft's
|
73 |
+
# torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
74 |
+
|
75 |
+
# according to the paper, initialize LoRA-A (down) as normal distribution
|
76 |
+
torch.nn.init.normal_(self.lora_down.weight, std=math.sqrt(2.0 / (in_dim + self.lora_dim)))
|
77 |
+
|
78 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
79 |
+
|
80 |
+
self.multiplier = multiplier
|
81 |
+
self.org_module = org_module # remove in applying
|
82 |
+
self.dropout = dropout
|
83 |
+
self.rank_dropout = rank_dropout
|
84 |
+
self.module_dropout = module_dropout
|
85 |
+
|
86 |
+
def get_trainable_params(self):
|
87 |
+
params = self.named_parameters()
|
88 |
+
trainable_params = []
|
89 |
+
for param in params:
|
90 |
+
if param[0] == "lora_up.weight": # up only
|
91 |
+
trainable_params.append(param[1])
|
92 |
+
return trainable_params
|
93 |
+
|
94 |
+
def requires_grad_(self, requires_grad: bool = True):
|
95 |
+
self.lora_up.requires_grad_(requires_grad)
|
96 |
+
self.lora_down.requires_grad_(False)
|
97 |
+
return self
|
98 |
+
|
99 |
+
def apply_to(self):
|
100 |
+
self.org_forward = self.org_module.forward
|
101 |
+
self.org_module.forward = self.forward
|
102 |
+
del self.org_module
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
org_forwarded = self.org_forward(x)
|
106 |
+
|
107 |
+
# module dropout
|
108 |
+
if self.module_dropout is not None and self.training:
|
109 |
+
if torch.rand(1) < self.module_dropout:
|
110 |
+
return org_forwarded
|
111 |
+
|
112 |
+
lx = self.lora_down(x)
|
113 |
+
|
114 |
+
# normal dropout
|
115 |
+
if self.dropout is not None and self.training:
|
116 |
+
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
117 |
+
|
118 |
+
# rank dropout
|
119 |
+
if self.rank_dropout is not None and self.training:
|
120 |
+
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
121 |
+
if len(lx.size()) == 3:
|
122 |
+
mask = mask.unsqueeze(1) # for Text Encoder
|
123 |
+
elif len(lx.size()) == 4:
|
124 |
+
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
125 |
+
lx = lx * mask
|
126 |
+
|
127 |
+
# scaling for rank dropout: treat as if the rank is changed
|
128 |
+
# the scale could also be derived from the mask, but rank_dropout is used here expecting an augmentation-like effect
|
129 |
+
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
130 |
+
else:
|
131 |
+
scale = self.scale
|
132 |
+
|
133 |
+
lx = self.lora_up(lx)
|
134 |
+
|
135 |
+
return org_forwarded + lx * self.multiplier * scale
|
136 |
+
|
137 |
+
|
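A minimal sanity sketch (not part of the original file) for the LoRAModule defined above: because lora_up.weight is zero-initialized, wrapping a layer leaves its output unchanged until training updates the up weights. The import path below is an assumption about how this folder is laid out.

```python
# Hedged sketch: a freshly created LoRAModule is a no-op, since lora_up starts at zero.
import torch
from networks.lora_fa import LoRAModule  # import path is an assumption

linear = torch.nn.Linear(32, 64)
lora = LoRAModule("lora_te_example", linear, multiplier=1.0, lora_dim=4, alpha=4)
lora.apply_to()  # swaps linear.forward for the LoRA-augmented forward

x = torch.randn(2, 32)
with torch.no_grad():
    # lora_up is zeros, so the LoRA branch contributes nothing yet
    assert torch.allclose(linear(x), lora.org_forward(x))
```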
138 |
+
class LoRAInfModule(LoRAModule):
|
139 |
+
def __init__(
|
140 |
+
self,
|
141 |
+
lora_name,
|
142 |
+
org_module: torch.nn.Module,
|
143 |
+
multiplier=1.0,
|
144 |
+
lora_dim=4,
|
145 |
+
alpha=1,
|
146 |
+
**kwargs,
|
147 |
+
):
|
148 |
+
# no dropout for inference
|
149 |
+
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
150 |
+
|
151 |
+
self.org_module_ref = [org_module]  # keep a reference so it can be accessed later
|
152 |
+
self.enabled = True
|
153 |
+
|
154 |
+
# check regional or not by lora_name
|
155 |
+
self.text_encoder = False
|
156 |
+
if lora_name.startswith("lora_te_"):
|
157 |
+
self.regional = False
|
158 |
+
self.use_sub_prompt = True
|
159 |
+
self.text_encoder = True
|
160 |
+
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
|
161 |
+
self.regional = False
|
162 |
+
self.use_sub_prompt = True
|
163 |
+
elif "time_emb" in lora_name:
|
164 |
+
self.regional = False
|
165 |
+
self.use_sub_prompt = False
|
166 |
+
else:
|
167 |
+
self.regional = True
|
168 |
+
self.use_sub_prompt = False
|
169 |
+
|
170 |
+
self.network: LoRANetwork = None
|
171 |
+
|
172 |
+
def set_network(self, network):
|
173 |
+
self.network = network
|
174 |
+
|
175 |
+
# freeze and merge the weights
|
176 |
+
def merge_to(self, sd, dtype, device):
|
177 |
+
# get up/down weight
|
178 |
+
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
179 |
+
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
|
180 |
+
|
181 |
+
# extract weight from org_module
|
182 |
+
org_sd = self.org_module.state_dict()
|
183 |
+
weight = org_sd["weight"].to(torch.float)
|
184 |
+
|
185 |
+
# merge weight
|
186 |
+
if len(weight.size()) == 2:
|
187 |
+
# linear
|
188 |
+
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
189 |
+
elif down_weight.size()[2:4] == (1, 1):
|
190 |
+
# conv2d 1x1
|
191 |
+
weight = (
|
192 |
+
weight
|
193 |
+
+ self.multiplier
|
194 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
195 |
+
* self.scale
|
196 |
+
)
|
197 |
+
else:
|
198 |
+
# conv2d 3x3
|
199 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
200 |
+
# print(conved.size(), weight.size(), module.stride, module.padding)
|
201 |
+
weight = weight + self.multiplier * conved * self.scale
|
202 |
+
|
203 |
+
# set weight to org_module
|
204 |
+
org_sd["weight"] = weight.to(dtype)
|
205 |
+
self.org_module.load_state_dict(org_sd)
|
206 |
+
|
207 |
+
# return this module's weight so the merge can be undone later
|
208 |
+
def get_weight(self, multiplier=None):
|
209 |
+
if multiplier is None:
|
210 |
+
multiplier = self.multiplier
|
211 |
+
|
212 |
+
# get up/down weight from module
|
213 |
+
up_weight = self.lora_up.weight.to(torch.float)
|
214 |
+
down_weight = self.lora_down.weight.to(torch.float)
|
215 |
+
|
216 |
+
# pre-calculated weight
|
217 |
+
if len(down_weight.size()) == 2:
|
218 |
+
# linear
|
219 |
+
weight = self.multiplier * (up_weight @ down_weight) * self.scale
|
220 |
+
elif down_weight.size()[2:4] == (1, 1):
|
221 |
+
# conv2d 1x1
|
222 |
+
weight = (
|
223 |
+
self.multiplier
|
224 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
225 |
+
* self.scale
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
# conv2d 3x3
|
229 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
230 |
+
weight = self.multiplier * conved * self.scale
|
231 |
+
|
232 |
+
return weight
|
233 |
+
|
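As a quick illustration of the Linear branch of merge_to() and get_weight() above: the merged weight W + multiplier * (up @ down) * (alpha / rank) reproduces the output of the original layer plus the LoRA branch. The shapes and values in this sketch are made up for illustration only.

```python
# Hedged numeric check of the Linear merge formula used above.
import torch

out_dim, in_dim, rank, alpha, multiplier = 8, 16, 4, 2.0, 1.0
W = torch.randn(out_dim, in_dim)      # original weight
up = torch.randn(out_dim, rank)       # stand-in for lora_up.weight
down = torch.randn(rank, in_dim)      # stand-in for lora_down.weight
scale = alpha / rank

W_merged = W + multiplier * (up @ down) * scale

# the merged layer matches original output + LoRA branch output
x = torch.randn(3, in_dim)
y_split = x @ W.T + multiplier * ((x @ down.T) @ up.T) * scale
y_merged = x @ W_merged.T
assert torch.allclose(y_split, y_merged, atol=1e-5)
```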
234 |
+
def set_region(self, region):
|
235 |
+
self.region = region
|
236 |
+
self.region_mask = None
|
237 |
+
|
238 |
+
def default_forward(self, x):
|
239 |
+
# print("default_forward", self.lora_name, x.size())
|
240 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
241 |
+
|
242 |
+
def forward(self, x):
|
243 |
+
if not self.enabled:
|
244 |
+
return self.org_forward(x)
|
245 |
+
|
246 |
+
if self.network is None or self.network.sub_prompt_index is None:
|
247 |
+
return self.default_forward(x)
|
248 |
+
if not self.regional and not self.use_sub_prompt:
|
249 |
+
return self.default_forward(x)
|
250 |
+
|
251 |
+
if self.regional:
|
252 |
+
return self.regional_forward(x)
|
253 |
+
else:
|
254 |
+
return self.sub_prompt_forward(x)
|
255 |
+
|
256 |
+
def get_mask_for_x(self, x):
|
257 |
+
# calculate size from shape of x
|
258 |
+
if len(x.size()) == 4:
|
259 |
+
h, w = x.size()[2:4]
|
260 |
+
area = h * w
|
261 |
+
else:
|
262 |
+
area = x.size()[1]
|
263 |
+
|
264 |
+
mask = self.network.mask_dic[area]
|
265 |
+
if mask is None:
|
266 |
+
raise ValueError(f"mask is None for resolution {area}")
|
267 |
+
if len(x.size()) != 4:
|
268 |
+
mask = torch.reshape(mask, (1, -1, 1))
|
269 |
+
return mask
|
270 |
+
|
271 |
+
def regional_forward(self, x):
|
272 |
+
if "attn2_to_out" in self.lora_name:
|
273 |
+
return self.to_out_forward(x)
|
274 |
+
|
275 |
+
if self.network.mask_dic is None: # sub_prompt_index >= 3
|
276 |
+
return self.default_forward(x)
|
277 |
+
|
278 |
+
# apply mask for LoRA result
|
279 |
+
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
280 |
+
mask = self.get_mask_for_x(lx)
|
281 |
+
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
282 |
+
lx = lx * mask
|
283 |
+
|
284 |
+
x = self.org_forward(x)
|
285 |
+
x = x + lx
|
286 |
+
|
287 |
+
if "attn2_to_q" in self.lora_name and self.network.is_last_network:
|
288 |
+
x = self.postp_to_q(x)
|
289 |
+
|
290 |
+
return x
|
291 |
+
|
292 |
+
def postp_to_q(self, x):
|
293 |
+
# repeat x to num_sub_prompts
|
294 |
+
has_real_uncond = x.size()[0] // self.network.batch_size == 3
|
295 |
+
qc = self.network.batch_size # uncond
|
296 |
+
qc += self.network.batch_size * self.network.num_sub_prompts # cond
|
297 |
+
if has_real_uncond:
|
298 |
+
qc += self.network.batch_size # real_uncond
|
299 |
+
|
300 |
+
query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
|
301 |
+
query[: self.network.batch_size] = x[: self.network.batch_size]
|
302 |
+
|
303 |
+
for i in range(self.network.batch_size):
|
304 |
+
qi = self.network.batch_size + i * self.network.num_sub_prompts
|
305 |
+
query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
|
306 |
+
|
307 |
+
if has_real_uncond:
|
308 |
+
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
309 |
+
|
310 |
+
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
311 |
+
return query
|
312 |
+
|
313 |
+
def sub_prompt_forward(self, x):
|
314 |
+
if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
|
315 |
+
return self.org_forward(x)
|
316 |
+
|
317 |
+
emb_idx = self.network.sub_prompt_index
|
318 |
+
if not self.text_encoder:
|
319 |
+
emb_idx += self.network.batch_size
|
320 |
+
|
321 |
+
# apply sub prompt of X
|
322 |
+
lx = x[emb_idx :: self.network.num_sub_prompts]
|
323 |
+
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
324 |
+
|
325 |
+
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
326 |
+
|
327 |
+
x = self.org_forward(x)
|
328 |
+
x[emb_idx :: self.network.num_sub_prompts] += lx
|
329 |
+
|
330 |
+
return x
|
331 |
+
|
332 |
+
def to_out_forward(self, x):
|
333 |
+
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
334 |
+
|
335 |
+
if self.network.is_last_network:
|
336 |
+
masks = [None] * self.network.num_sub_prompts
|
337 |
+
self.network.shared[self.lora_name] = (None, masks)
|
338 |
+
else:
|
339 |
+
lx, masks = self.network.shared[self.lora_name]
|
340 |
+
|
341 |
+
# call own LoRA
|
342 |
+
x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
|
343 |
+
lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
|
344 |
+
|
345 |
+
if self.network.is_last_network:
|
346 |
+
lx = torch.zeros(
|
347 |
+
(self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
|
348 |
+
)
|
349 |
+
self.network.shared[self.lora_name] = (lx, masks)
|
350 |
+
|
351 |
+
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
352 |
+
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
353 |
+
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
354 |
+
|
355 |
+
# if not last network, return x and masks
|
356 |
+
x = self.org_forward(x)
|
357 |
+
if not self.network.is_last_network:
|
358 |
+
return x
|
359 |
+
|
360 |
+
lx, masks = self.network.shared.pop(self.lora_name)
|
361 |
+
|
362 |
+
# if last network, combine separated x with mask weighted sum
|
363 |
+
has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
|
364 |
+
|
365 |
+
out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
|
366 |
+
out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
|
367 |
+
if has_real_uncond:
|
368 |
+
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
369 |
+
|
370 |
+
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
371 |
+
# for i in range(len(masks)):
|
372 |
+
# if masks[i] is None:
|
373 |
+
# masks[i] = torch.zeros_like(masks[-1])
|
374 |
+
|
375 |
+
mask = torch.cat(masks)
|
376 |
+
mask_sum = torch.sum(mask, dim=0) + 1e-4
|
377 |
+
for i in range(self.network.batch_size):
|
378 |
+
# process one image at a time
|
379 |
+
lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
|
380 |
+
lx1 = lx1 * mask
|
381 |
+
lx1 = torch.sum(lx1, dim=0)
|
382 |
+
|
383 |
+
xi = self.network.batch_size + i * self.network.num_sub_prompts
|
384 |
+
x1 = x[xi : xi + self.network.num_sub_prompts]
|
385 |
+
x1 = x1 * mask
|
386 |
+
x1 = torch.sum(x1, dim=0)
|
387 |
+
x1 = x1 / mask_sum
|
388 |
+
|
389 |
+
x1 = x1 + lx1
|
390 |
+
out[self.network.batch_size + i] = x1
|
391 |
+
|
392 |
+
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
|
393 |
+
return out
|
394 |
+
|
395 |
+
|
396 |
+
def parse_block_lr_kwargs(nw_kwargs):
|
397 |
+
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
|
398 |
+
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
|
399 |
+
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
|
400 |
+
|
401 |
+
# if none of the above is set, treat block-wise LR as disabled and return None
|
402 |
+
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
|
403 |
+
return None, None, None
|
404 |
+
|
405 |
+
# extract learning rate weight for each block
|
406 |
+
if down_lr_weight is not None:
|
407 |
+
# if some parameters are not set, use zero
|
408 |
+
if "," in down_lr_weight:
|
409 |
+
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
410 |
+
|
411 |
+
if mid_lr_weight is not None:
|
412 |
+
mid_lr_weight = float(mid_lr_weight)
|
413 |
+
|
414 |
+
if up_lr_weight is not None:
|
415 |
+
if "," in up_lr_weight:
|
416 |
+
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
417 |
+
|
418 |
+
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
|
419 |
+
down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
|
420 |
+
)
|
421 |
+
|
422 |
+
return down_lr_weight, mid_lr_weight, up_lr_weight
|
423 |
+
|
424 |
+
|
425 |
+
def create_network(
|
426 |
+
multiplier: float,
|
427 |
+
network_dim: Optional[int],
|
428 |
+
network_alpha: Optional[float],
|
429 |
+
vae: AutoencoderKL,
|
430 |
+
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
|
431 |
+
unet,
|
432 |
+
neuron_dropout: Optional[float] = None,
|
433 |
+
**kwargs,
|
434 |
+
):
|
435 |
+
if network_dim is None:
|
436 |
+
network_dim = 4 # default
|
437 |
+
if network_alpha is None:
|
438 |
+
network_alpha = 1.0
|
439 |
+
|
440 |
+
# extract dim/alpha for conv2d, and block dim
|
441 |
+
conv_dim = kwargs.get("conv_dim", None)
|
442 |
+
conv_alpha = kwargs.get("conv_alpha", None)
|
443 |
+
if conv_dim is not None:
|
444 |
+
conv_dim = int(conv_dim)
|
445 |
+
if conv_alpha is None:
|
446 |
+
conv_alpha = 1.0
|
447 |
+
else:
|
448 |
+
conv_alpha = float(conv_alpha)
|
449 |
+
|
450 |
+
# block dim/alpha/lr
|
451 |
+
block_dims = kwargs.get("block_dims", None)
|
452 |
+
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
|
453 |
+
|
454 |
+
# if any of the above is specified, enable per-block dim (rank)
|
455 |
+
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
|
456 |
+
block_alphas = kwargs.get("block_alphas", None)
|
457 |
+
conv_block_dims = kwargs.get("conv_block_dims", None)
|
458 |
+
conv_block_alphas = kwargs.get("conv_block_alphas", None)
|
459 |
+
|
460 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
|
461 |
+
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
462 |
+
)
|
463 |
+
|
464 |
+
# remove block dim/alpha without learning rate
|
465 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
|
466 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
467 |
+
)
|
468 |
+
|
469 |
+
else:
|
470 |
+
block_alphas = None
|
471 |
+
conv_block_dims = None
|
472 |
+
conv_block_alphas = None
|
473 |
+
|
474 |
+
# rank/module dropout
|
475 |
+
rank_dropout = kwargs.get("rank_dropout", None)
|
476 |
+
if rank_dropout is not None:
|
477 |
+
rank_dropout = float(rank_dropout)
|
478 |
+
module_dropout = kwargs.get("module_dropout", None)
|
479 |
+
if module_dropout is not None:
|
480 |
+
module_dropout = float(module_dropout)
|
481 |
+
|
482 |
+
# quite a lot of arguments here...
|
483 |
+
network = LoRANetwork(
|
484 |
+
text_encoder,
|
485 |
+
unet,
|
486 |
+
multiplier=multiplier,
|
487 |
+
lora_dim=network_dim,
|
488 |
+
alpha=network_alpha,
|
489 |
+
dropout=neuron_dropout,
|
490 |
+
rank_dropout=rank_dropout,
|
491 |
+
module_dropout=module_dropout,
|
492 |
+
conv_lora_dim=conv_dim,
|
493 |
+
conv_alpha=conv_alpha,
|
494 |
+
block_dims=block_dims,
|
495 |
+
block_alphas=block_alphas,
|
496 |
+
conv_block_dims=conv_block_dims,
|
497 |
+
conv_block_alphas=conv_block_alphas,
|
498 |
+
varbose=True,
|
499 |
+
)
|
500 |
+
|
501 |
+
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
502 |
+
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
503 |
+
|
504 |
+
return network
|
505 |
+
|
506 |
+
|
507 |
+
# this function may also be called from outside this module
|
508 |
+
# network_dim and network_alpha already contain default values
|
509 |
+
# block_dims and block_alphas are either both None or both set
|
510 |
+
# conv_dim and conv_alpha are either both None or both set
|
511 |
+
def get_block_dims_and_alphas(
|
512 |
+
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
513 |
+
):
|
514 |
+
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
|
515 |
+
|
516 |
+
def parse_ints(s):
|
517 |
+
return [int(i) for i in s.split(",")]
|
518 |
+
|
519 |
+
def parse_floats(s):
|
520 |
+
return [float(i) for i in s.split(",")]
|
521 |
+
|
522 |
+
# parse block_dims and block_alphas; they always end up populated
|
523 |
+
if block_dims is not None:
|
524 |
+
block_dims = parse_ints(block_dims)
|
525 |
+
assert (
|
526 |
+
len(block_dims) == num_total_blocks
|
527 |
+
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
528 |
+
else:
|
529 |
+
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
530 |
+
block_dims = [network_dim] * num_total_blocks
|
531 |
+
|
532 |
+
if block_alphas is not None:
|
533 |
+
block_alphas = parse_floats(block_alphas)
|
534 |
+
assert (
|
535 |
+
len(block_alphas) == num_total_blocks
|
536 |
+
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
|
537 |
+
else:
|
538 |
+
print(
|
539 |
+
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
|
540 |
+
)
|
541 |
+
block_alphas = [network_alpha] * num_total_blocks
|
542 |
+
|
543 |
+
# parse conv_block_dims and conv_block_alphas only if specified; otherwise fall back to conv_dim and conv_alpha
|
544 |
+
if conv_block_dims is not None:
|
545 |
+
conv_block_dims = parse_ints(conv_block_dims)
|
546 |
+
assert (
|
547 |
+
len(conv_block_dims) == num_total_blocks
|
548 |
+
), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
|
549 |
+
|
550 |
+
if conv_block_alphas is not None:
|
551 |
+
conv_block_alphas = parse_floats(conv_block_alphas)
|
552 |
+
assert (
|
553 |
+
len(conv_block_alphas) == num_total_blocks
|
554 |
+
), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
|
555 |
+
else:
|
556 |
+
if conv_alpha is None:
|
557 |
+
conv_alpha = 1.0
|
558 |
+
print(
|
559 |
+
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
|
560 |
+
)
|
561 |
+
conv_block_alphas = [conv_alpha] * num_total_blocks
|
562 |
+
else:
|
563 |
+
if conv_dim is not None:
|
564 |
+
print(
|
565 |
+
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
|
566 |
+
)
|
567 |
+
conv_block_dims = [conv_dim] * num_total_blocks
|
568 |
+
conv_block_alphas = [conv_alpha] * num_total_blocks
|
569 |
+
else:
|
570 |
+
conv_block_dims = None
|
571 |
+
conv_block_alphas = None
|
572 |
+
|
573 |
+
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
574 |
+
|
575 |
+
|
576 |
+
# define per-block multipliers on the learning rate for block-wise LR; may also be called from outside this module
|
577 |
+
def get_block_lr_weight(
|
578 |
+
down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
|
579 |
+
) -> Tuple[List[float], List[float], List[float]]:
|
580 |
+
# if no parameters are specified, do nothing and keep the previous behavior
|
581 |
+
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
|
582 |
+
return None, None, None
|
583 |
+
|
584 |
+
max_len = LoRANetwork.NUM_OF_BLOCKS  # number of up/down blocks in the full model
|
585 |
+
|
586 |
+
def get_list(name_with_suffix) -> List[float]:
|
587 |
+
import math
|
588 |
+
|
589 |
+
tokens = name_with_suffix.split("+")
|
590 |
+
name = tokens[0]
|
591 |
+
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
|
592 |
+
|
593 |
+
if name == "cosine":
|
594 |
+
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
|
595 |
+
elif name == "sine":
|
596 |
+
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
|
597 |
+
elif name == "linear":
|
598 |
+
return [i / (max_len - 1) + base_lr for i in range(max_len)]
|
599 |
+
elif name == "reverse_linear":
|
600 |
+
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
|
601 |
+
elif name == "zeros":
|
602 |
+
return [0.0 + base_lr] * max_len
|
603 |
+
else:
|
604 |
+
print(
|
605 |
+
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
606 |
+
% (name)
|
607 |
+
)
|
608 |
+
return None
|
609 |
+
|
610 |
+
if type(down_lr_weight) == str:
|
611 |
+
down_lr_weight = get_list(down_lr_weight)
|
612 |
+
if type(up_lr_weight) == str:
|
613 |
+
up_lr_weight = get_list(up_lr_weight)
|
614 |
+
|
615 |
+
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
616 |
+
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
617 |
+
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
618 |
+
up_lr_weight = up_lr_weight[:max_len]
|
619 |
+
down_lr_weight = down_lr_weight[:max_len]
|
620 |
+
|
621 |
+
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
622 |
+
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
623 |
+
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
624 |
+
|
625 |
+
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
626 |
+
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
627 |
+
if up_lr_weight != None and len(up_lr_weight) < max_len:
|
628 |
+
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
629 |
+
|
630 |
+
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
631 |
+
print("apply block learning rate / 階層別学習率を適用します。")
|
632 |
+
if down_lr_weight != None:
|
633 |
+
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
634 |
+
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
|
635 |
+
else:
|
636 |
+
print("down_lr_weight: all 1.0, すべて1.0")
|
637 |
+
|
638 |
+
if mid_lr_weight != None:
|
639 |
+
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
640 |
+
print("mid_lr_weight:", mid_lr_weight)
|
641 |
+
else:
|
642 |
+
print("mid_lr_weight: 1.0")
|
643 |
+
|
644 |
+
if up_lr_weight != None:
|
645 |
+
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
646 |
+
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
|
647 |
+
else:
|
648 |
+
print("up_lr_weight: all 1.0, すべて1.0")
|
649 |
+
|
650 |
+
return down_lr_weight, mid_lr_weight, up_lr_weight
|
651 |
+
|
652 |
+
|
653 |
+
# remove blocks whose lr_weight is 0 from block_dims; may also be called from outside this module
|
654 |
+
def remove_block_dims_and_alphas(
|
655 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
656 |
+
):
|
657 |
+
# set 0 to block dim without learning rate to remove the block
|
658 |
+
if down_lr_weight != None:
|
659 |
+
for i, lr in enumerate(down_lr_weight):
|
660 |
+
if lr == 0:
|
661 |
+
block_dims[i] = 0
|
662 |
+
if conv_block_dims is not None:
|
663 |
+
conv_block_dims[i] = 0
|
664 |
+
if mid_lr_weight != None:
|
665 |
+
if mid_lr_weight == 0:
|
666 |
+
block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
|
667 |
+
if conv_block_dims is not None:
|
668 |
+
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
|
669 |
+
if up_lr_weight != None:
|
670 |
+
for i, lr in enumerate(up_lr_weight):
|
671 |
+
if lr == 0:
|
672 |
+
block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
|
673 |
+
if conv_block_dims is not None:
|
674 |
+
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
|
675 |
+
|
676 |
+
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
677 |
+
|
678 |
+
|
679 |
+
# may also be called from outside this module
|
680 |
+
def get_block_index(lora_name: str) -> int:
|
681 |
+
block_idx = -1 # invalid lora name
|
682 |
+
|
683 |
+
m = RE_UPDOWN.search(lora_name)
|
684 |
+
if m:
|
685 |
+
g = m.groups()
|
686 |
+
i = int(g[1])
|
687 |
+
j = int(g[3])
|
688 |
+
if g[2] == "resnets":
|
689 |
+
idx = 3 * i + j
|
690 |
+
elif g[2] == "attentions":
|
691 |
+
idx = 3 * i + j
|
692 |
+
elif g[2] == "upsamplers" or g[2] == "downsamplers":
|
693 |
+
idx = 3 * i + 2
|
694 |
+
|
695 |
+
if g[0] == "down":
|
696 |
+
block_idx = 1 + idx  # no LoRA corresponds to index 0
|
697 |
+
elif g[0] == "up":
|
698 |
+
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
|
699 |
+
|
700 |
+
elif "mid_block_" in lora_name:
|
701 |
+
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
|
702 |
+
|
703 |
+
return block_idx
|
704 |
+
|
705 |
+
|
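To make the block indexing above concrete: index 0 is unused, 1-11 cover the down blocks, 12 (NUM_OF_BLOCKS) is the mid block, and 13-23 cover the up blocks. The module names below are only hypothetical examples in the "lora_unet_..." naming style used above, and the import path is an assumption about the repo layout.

```python
# Hedged illustration of get_block_index(); example names, not a verified list.
from networks.lora_fa import get_block_index  # import path is an assumption

examples = [
    "lora_unet_down_blocks_0_attentions_0_proj_in",  # 1 + (3*0 + 0) = 1
    "lora_unet_down_blocks_1_resnets_1_conv1",       # 1 + (3*1 + 1) = 5
    "lora_unet_mid_block_attentions_0_proj_out",     # NUM_OF_BLOCKS = 12
    "lora_unet_up_blocks_2_upsamplers_0_conv",       # 12 + 1 + (3*2 + 2) = 21
]
for name in examples:
    print(name, "->", get_block_index(name))
```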
706 |
+
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
707 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
708 |
+
if weights_sd is None:
|
709 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
710 |
+
from safetensors.torch import load_file, safe_open
|
711 |
+
|
712 |
+
weights_sd = load_file(file)
|
713 |
+
else:
|
714 |
+
weights_sd = torch.load(file, map_location="cpu")
|
715 |
+
|
716 |
+
# get dim/alpha mapping
|
717 |
+
modules_dim = {}
|
718 |
+
modules_alpha = {}
|
719 |
+
for key, value in weights_sd.items():
|
720 |
+
if "." not in key:
|
721 |
+
continue
|
722 |
+
|
723 |
+
lora_name = key.split(".")[0]
|
724 |
+
if "alpha" in key:
|
725 |
+
modules_alpha[lora_name] = value
|
726 |
+
elif "lora_down" in key:
|
727 |
+
dim = value.size()[0]
|
728 |
+
modules_dim[lora_name] = dim
|
729 |
+
# print(lora_name, value.size(), dim)
|
730 |
+
|
731 |
+
# support old LoRA without alpha
|
732 |
+
for key in modules_dim.keys():
|
733 |
+
if key not in modules_alpha:
|
734 |
+
modules_alpha[key] = modules_dim[key]
|
735 |
+
|
736 |
+
module_class = LoRAInfModule if for_inference else LoRAModule
|
737 |
+
|
738 |
+
network = LoRANetwork(
|
739 |
+
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
740 |
+
)
|
741 |
+
|
742 |
+
# block lr
|
743 |
+
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
|
744 |
+
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
745 |
+
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
746 |
+
|
747 |
+
return network, weights_sd
|
748 |
+
|
749 |
+
|
750 |
+
class LoRANetwork(torch.nn.Module):
|
751 |
+
NUM_OF_BLOCKS = 12  # number of up/down blocks in the full model
|
752 |
+
|
753 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
754 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
755 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
756 |
+
LORA_PREFIX_UNET = "lora_unet"
|
757 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
758 |
+
|
759 |
+
# SDXL: must start with LORA_PREFIX_TEXT_ENCODER
|
760 |
+
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
|
761 |
+
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
|
762 |
+
|
763 |
+
def __init__(
|
764 |
+
self,
|
765 |
+
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
766 |
+
unet,
|
767 |
+
multiplier: float = 1.0,
|
768 |
+
lora_dim: int = 4,
|
769 |
+
alpha: float = 1,
|
770 |
+
dropout: Optional[float] = None,
|
771 |
+
rank_dropout: Optional[float] = None,
|
772 |
+
module_dropout: Optional[float] = None,
|
773 |
+
conv_lora_dim: Optional[int] = None,
|
774 |
+
conv_alpha: Optional[float] = None,
|
775 |
+
block_dims: Optional[List[int]] = None,
|
776 |
+
block_alphas: Optional[List[float]] = None,
|
777 |
+
conv_block_dims: Optional[List[int]] = None,
|
778 |
+
conv_block_alphas: Optional[List[float]] = None,
|
779 |
+
modules_dim: Optional[Dict[str, int]] = None,
|
780 |
+
modules_alpha: Optional[Dict[str, int]] = None,
|
781 |
+
module_class: Type[object] = LoRAModule,
|
782 |
+
varbose: Optional[bool] = False,
|
783 |
+
) -> None:
|
784 |
+
"""
|
785 |
+
LoRA network: there are many arguments, but the usage patterns are as follows
|
786 |
+
1. specify lora_dim and alpha
|
787 |
+
2. specify lora_dim, alpha, conv_lora_dim, and conv_alpha
|
788 |
+
3. specify block_dims and block_alphas : not applied to Conv2d3x3
|
789 |
+
4. specify block_dims, block_alphas, conv_block_dims, and conv_block_alphas : also applied to Conv2d3x3
|
790 |
+
5. specify modules_dim and modules_alpha (for inference)
|
791 |
+
"""
|
792 |
+
super().__init__()
|
793 |
+
self.multiplier = multiplier
|
794 |
+
|
795 |
+
self.lora_dim = lora_dim
|
796 |
+
self.alpha = alpha
|
797 |
+
self.conv_lora_dim = conv_lora_dim
|
798 |
+
self.conv_alpha = conv_alpha
|
799 |
+
self.dropout = dropout
|
800 |
+
self.rank_dropout = rank_dropout
|
801 |
+
self.module_dropout = module_dropout
|
802 |
+
|
803 |
+
if modules_dim is not None:
|
804 |
+
print(f"create LoRA network from weights")
|
805 |
+
elif block_dims is not None:
|
806 |
+
print(f"create LoRA network from block_dims")
|
807 |
+
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
808 |
+
print(f"block_dims: {block_dims}")
|
809 |
+
print(f"block_alphas: {block_alphas}")
|
810 |
+
if conv_block_dims is not None:
|
811 |
+
print(f"conv_block_dims: {conv_block_dims}")
|
812 |
+
print(f"conv_block_alphas: {conv_block_alphas}")
|
813 |
+
else:
|
814 |
+
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
815 |
+
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
816 |
+
if self.conv_lora_dim is not None:
|
817 |
+
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
818 |
+
|
819 |
+
# create module instances
|
820 |
+
def create_modules(
|
821 |
+
is_unet: bool,
|
822 |
+
text_encoder_idx: Optional[int], # None, 1, 2
|
823 |
+
root_module: torch.nn.Module,
|
824 |
+
target_replace_modules: List[torch.nn.Module],
|
825 |
+
) -> List[LoRAModule]:
|
826 |
+
prefix = (
|
827 |
+
self.LORA_PREFIX_UNET
|
828 |
+
if is_unet
|
829 |
+
else (
|
830 |
+
self.LORA_PREFIX_TEXT_ENCODER
|
831 |
+
if text_encoder_idx is None
|
832 |
+
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
|
833 |
+
)
|
834 |
+
)
|
835 |
+
loras = []
|
836 |
+
skipped = []
|
837 |
+
for name, module in root_module.named_modules():
|
838 |
+
if module.__class__.__name__ in target_replace_modules:
|
839 |
+
for child_name, child_module in module.named_modules():
|
840 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
841 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
842 |
+
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
843 |
+
|
844 |
+
if is_linear or is_conv2d:
|
845 |
+
lora_name = prefix + "." + name + "." + child_name
|
846 |
+
lora_name = lora_name.replace(".", "_")
|
847 |
+
|
848 |
+
dim = None
|
849 |
+
alpha = None
|
850 |
+
|
851 |
+
if modules_dim is not None:
|
852 |
+
# per-module dims were specified
|
853 |
+
if lora_name in modules_dim:
|
854 |
+
dim = modules_dim[lora_name]
|
855 |
+
alpha = modules_alpha[lora_name]
|
856 |
+
elif is_unet and block_dims is not None:
|
857 |
+
# U-Net with block_dims specified
|
858 |
+
block_idx = get_block_index(lora_name)
|
859 |
+
if is_linear or is_conv2d_1x1:
|
860 |
+
dim = block_dims[block_idx]
|
861 |
+
alpha = block_alphas[block_idx]
|
862 |
+
elif conv_block_dims is not None:
|
863 |
+
dim = conv_block_dims[block_idx]
|
864 |
+
alpha = conv_block_alphas[block_idx]
|
865 |
+
else:
|
866 |
+
# normal case: target all modules
|
867 |
+
if is_linear or is_conv2d_1x1:
|
868 |
+
dim = self.lora_dim
|
869 |
+
alpha = self.alpha
|
870 |
+
elif self.conv_lora_dim is not None:
|
871 |
+
dim = self.conv_lora_dim
|
872 |
+
alpha = self.conv_alpha
|
873 |
+
|
874 |
+
if dim is None or dim == 0:
|
875 |
+
# record the skipped module so it can be reported
|
876 |
+
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
|
877 |
+
skipped.append(lora_name)
|
878 |
+
continue
|
879 |
+
|
880 |
+
lora = module_class(
|
881 |
+
lora_name,
|
882 |
+
child_module,
|
883 |
+
self.multiplier,
|
884 |
+
dim,
|
885 |
+
alpha,
|
886 |
+
dropout=dropout,
|
887 |
+
rank_dropout=rank_dropout,
|
888 |
+
module_dropout=module_dropout,
|
889 |
+
)
|
890 |
+
loras.append(lora)
|
891 |
+
return loras, skipped
|
892 |
+
|
893 |
+
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
894 |
+
|
895 |
+
# create LoRA for text encoder
|
896 |
+
# creating all modules every time is wasteful; worth revisiting
|
897 |
+
self.text_encoder_loras = []
|
898 |
+
skipped_te = []
|
899 |
+
for i, text_encoder in enumerate(text_encoders):
|
900 |
+
if len(text_encoders) > 1:
|
901 |
+
index = i + 1
|
902 |
+
print(f"create LoRA for Text Encoder {index}:")
|
903 |
+
else:
|
904 |
+
index = None
|
905 |
+
print(f"create LoRA for Text Encoder:")
|
906 |
+
|
907 |
+
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
908 |
+
self.text_encoder_loras.extend(text_encoder_loras)
|
909 |
+
skipped_te += skipped
|
910 |
+
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
911 |
+
|
912 |
+
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
913 |
+
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
914 |
+
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
915 |
+
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
916 |
+
|
917 |
+
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
918 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
919 |
+
|
920 |
+
skipped = skipped_te + skipped_un
|
921 |
+
if varbose and len(skipped) > 0:
|
922 |
+
print(
|
923 |
+
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
924 |
+
)
|
925 |
+
for name in skipped:
|
926 |
+
print(f"\t{name}")
|
927 |
+
|
928 |
+
self.up_lr_weight: List[float] = None
|
929 |
+
self.down_lr_weight: List[float] = None
|
930 |
+
self.mid_lr_weight: float = None
|
931 |
+
self.block_lr = False
|
932 |
+
|
933 |
+
# assertion
|
934 |
+
names = set()
|
935 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
936 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
937 |
+
names.add(lora.lora_name)
|
938 |
+
|
939 |
+
def set_multiplier(self, multiplier):
|
940 |
+
self.multiplier = multiplier
|
941 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
942 |
+
lora.multiplier = self.multiplier
|
943 |
+
|
944 |
+
def load_weights(self, file):
|
945 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
946 |
+
from safetensors.torch import load_file
|
947 |
+
|
948 |
+
weights_sd = load_file(file)
|
949 |
+
else:
|
950 |
+
weights_sd = torch.load(file, map_location="cpu")
|
951 |
+
|
952 |
+
info = self.load_state_dict(weights_sd, False)
|
953 |
+
return info
|
954 |
+
|
955 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
956 |
+
if apply_text_encoder:
|
957 |
+
print("enable LoRA for text encoder")
|
958 |
+
else:
|
959 |
+
self.text_encoder_loras = []
|
960 |
+
|
961 |
+
if apply_unet:
|
962 |
+
print("enable LoRA for U-Net")
|
963 |
+
else:
|
964 |
+
self.unet_loras = []
|
965 |
+
|
966 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
967 |
+
lora.apply_to()
|
968 |
+
self.add_module(lora.lora_name, lora)
|
969 |
+
|
970 |
+
# return whether the weights can be merged
|
971 |
+
def is_mergeable(self):
|
972 |
+
return True
|
973 |
+
|
974 |
+
# TODO refactor to common function with apply_to
|
975 |
+
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
976 |
+
apply_text_encoder = apply_unet = False
|
977 |
+
for key in weights_sd.keys():
|
978 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
979 |
+
apply_text_encoder = True
|
980 |
+
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
981 |
+
apply_unet = True
|
982 |
+
|
983 |
+
if apply_text_encoder:
|
984 |
+
print("enable LoRA for text encoder")
|
985 |
+
else:
|
986 |
+
self.text_encoder_loras = []
|
987 |
+
|
988 |
+
if apply_unet:
|
989 |
+
print("enable LoRA for U-Net")
|
990 |
+
else:
|
991 |
+
self.unet_loras = []
|
992 |
+
|
993 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
994 |
+
sd_for_lora = {}
|
995 |
+
for key in weights_sd.keys():
|
996 |
+
if key.startswith(lora.lora_name):
|
997 |
+
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
998 |
+
lora.merge_to(sd_for_lora, dtype, device)
|
999 |
+
|
1000 |
+
print(f"weights are merged")
|
1001 |
+
|
1002 |
+
# define per-block multipliers on the learning rate for block-wise LR; the argument order is reversed, but leave it as is for now
|
1003 |
+
def set_block_lr_weight(
|
1004 |
+
self,
|
1005 |
+
up_lr_weight: List[float] = None,
|
1006 |
+
mid_lr_weight: float = None,
|
1007 |
+
down_lr_weight: List[float] = None,
|
1008 |
+
):
|
1009 |
+
self.block_lr = True
|
1010 |
+
self.down_lr_weight = down_lr_weight
|
1011 |
+
self.mid_lr_weight = mid_lr_weight
|
1012 |
+
self.up_lr_weight = up_lr_weight
|
1013 |
+
|
1014 |
+
def get_lr_weight(self, lora: LoRAModule) -> float:
|
1015 |
+
lr_weight = 1.0
|
1016 |
+
block_idx = get_block_index(lora.lora_name)
|
1017 |
+
if block_idx < 0:
|
1018 |
+
return lr_weight
|
1019 |
+
|
1020 |
+
if block_idx < LoRANetwork.NUM_OF_BLOCKS:
|
1021 |
+
if self.down_lr_weight != None:
|
1022 |
+
lr_weight = self.down_lr_weight[block_idx]
|
1023 |
+
elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
|
1024 |
+
if self.mid_lr_weight != None:
|
1025 |
+
lr_weight = self.mid_lr_weight
|
1026 |
+
elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
|
1027 |
+
if self.up_lr_weight != None:
|
1028 |
+
lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
|
1029 |
+
|
1030 |
+
return lr_weight
|
1031 |
+
|
1032 |
+
# it might be worth allowing separate learning rates for the two Text Encoders
|
1033 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
1034 |
+
self.requires_grad_(True)
|
1035 |
+
all_params = []
|
1036 |
+
|
1037 |
+
def enumerate_params(loras: List[LoRAModule]):
|
1038 |
+
params = []
|
1039 |
+
for lora in loras:
|
1040 |
+
# params.extend(lora.parameters())
|
1041 |
+
params.extend(lora.get_trainable_params())
|
1042 |
+
return params
|
1043 |
+
|
1044 |
+
if self.text_encoder_loras:
|
1045 |
+
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
1046 |
+
if text_encoder_lr is not None:
|
1047 |
+
param_data["lr"] = text_encoder_lr
|
1048 |
+
all_params.append(param_data)
|
1049 |
+
|
1050 |
+
if self.unet_loras:
|
1051 |
+
if self.block_lr:
|
1052 |
+
# group LoRAs by block so the learning-rate graph can be drawn per block
|
1053 |
+
block_idx_to_lora = {}
|
1054 |
+
for lora in self.unet_loras:
|
1055 |
+
idx = get_block_index(lora.lora_name)
|
1056 |
+
if idx not in block_idx_to_lora:
|
1057 |
+
block_idx_to_lora[idx] = []
|
1058 |
+
block_idx_to_lora[idx].append(lora)
|
1059 |
+
|
1060 |
+
# set the parameters for each block
|
1061 |
+
for idx, block_loras in block_idx_to_lora.items():
|
1062 |
+
param_data = {"params": enumerate_params(block_loras)}
|
1063 |
+
|
1064 |
+
if unet_lr is not None:
|
1065 |
+
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
|
1066 |
+
elif default_lr is not None:
|
1067 |
+
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
|
1068 |
+
if ("lr" in param_data) and (param_data["lr"] == 0):
|
1069 |
+
continue
|
1070 |
+
all_params.append(param_data)
|
1071 |
+
|
1072 |
+
else:
|
1073 |
+
param_data = {"params": enumerate_params(self.unet_loras)}
|
1074 |
+
if unet_lr is not None:
|
1075 |
+
param_data["lr"] = unet_lr
|
1076 |
+
all_params.append(param_data)
|
1077 |
+
|
1078 |
+
return all_params
|
1079 |
+
|
1080 |
+
def enable_gradient_checkpointing(self):
|
1081 |
+
# not supported
|
1082 |
+
pass
|
1083 |
+
|
1084 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
1085 |
+
self.requires_grad_(True)
|
1086 |
+
|
1087 |
+
def on_epoch_start(self, text_encoder, unet):
|
1088 |
+
self.train()
|
1089 |
+
|
1090 |
+
def get_trainable_params(self):
|
1091 |
+
return self.parameters()
|
1092 |
+
|
1093 |
+
def save_weights(self, file, dtype, metadata):
|
1094 |
+
if metadata is not None and len(metadata) == 0:
|
1095 |
+
metadata = None
|
1096 |
+
|
1097 |
+
state_dict = self.state_dict()
|
1098 |
+
|
1099 |
+
if dtype is not None:
|
1100 |
+
for key in list(state_dict.keys()):
|
1101 |
+
v = state_dict[key]
|
1102 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
1103 |
+
state_dict[key] = v
|
1104 |
+
|
1105 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
1106 |
+
from safetensors.torch import save_file
|
1107 |
+
from library import train_util
|
1108 |
+
|
1109 |
+
# Precalculate model hashes to save time on indexing
|
1110 |
+
if metadata is None:
|
1111 |
+
metadata = {}
|
1112 |
+
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
1113 |
+
metadata["sshs_model_hash"] = model_hash
|
1114 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
1115 |
+
|
1116 |
+
save_file(state_dict, file, metadata)
|
1117 |
+
else:
|
1118 |
+
torch.save(state_dict, file)
|
1119 |
+
|
1120 |
+
# mask is a tensor with values from 0 to 1
|
1121 |
+
def set_region(self, sub_prompt_index, is_last_network, mask):
|
1122 |
+
if mask.max() == 0:
|
1123 |
+
mask = torch.ones_like(mask)
|
1124 |
+
|
1125 |
+
self.mask = mask
|
1126 |
+
self.sub_prompt_index = sub_prompt_index
|
1127 |
+
self.is_last_network = is_last_network
|
1128 |
+
|
1129 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
1130 |
+
lora.set_network(self)
|
1131 |
+
|
1132 |
+
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
1133 |
+
self.batch_size = batch_size
|
1134 |
+
self.num_sub_prompts = num_sub_prompts
|
1135 |
+
self.current_size = (height, width)
|
1136 |
+
self.shared = shared
|
1137 |
+
|
1138 |
+
# create masks
|
1139 |
+
mask = self.mask
|
1140 |
+
mask_dic = {}
|
1141 |
+
mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
|
1142 |
+
ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
|
1143 |
+
dtype = ref_weight.dtype
|
1144 |
+
device = ref_weight.device
|
1145 |
+
|
1146 |
+
def resize_add(mh, mw):
|
1147 |
+
# print(mh, mw, mh * mw)
|
1148 |
+
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
1149 |
+
m = m.to(device, dtype=dtype)
|
1150 |
+
mask_dic[mh * mw] = m
|
1151 |
+
|
1152 |
+
h = height // 8
|
1153 |
+
w = width // 8
|
1154 |
+
for _ in range(4):
|
1155 |
+
resize_add(h, w)
|
1156 |
+
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
|
1157 |
+
resize_add(h + h % 2, w + w % 2)
|
1158 |
+
h = (h + 1) // 2
|
1159 |
+
w = (w + 1) // 2
|
1160 |
+
|
1161 |
+
self.mask_dic = mask_dic
|
1162 |
+
|
1163 |
+
def backup_weights(self):
|
1164 |
+
# back up the original weights
|
1165 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
1166 |
+
for lora in loras:
|
1167 |
+
org_module = lora.org_module_ref[0]
|
1168 |
+
if not hasattr(org_module, "_lora_org_weight"):
|
1169 |
+
sd = org_module.state_dict()
|
1170 |
+
org_module._lora_org_weight = sd["weight"].detach().clone()
|
1171 |
+
org_module._lora_restored = True
|
1172 |
+
|
1173 |
+
def restore_weights(self):
|
1174 |
+
# restore the original weights
|
1175 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
1176 |
+
for lora in loras:
|
1177 |
+
org_module = lora.org_module_ref[0]
|
1178 |
+
if not org_module._lora_restored:
|
1179 |
+
sd = org_module.state_dict()
|
1180 |
+
sd["weight"] = org_module._lora_org_weight
|
1181 |
+
org_module.load_state_dict(sd)
|
1182 |
+
org_module._lora_restored = True
|
1183 |
+
|
1184 |
+
def pre_calculation(self):
|
1185 |
+
# pre-compute the merged weights
|
1186 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
1187 |
+
for lora in loras:
|
1188 |
+
org_module = lora.org_module_ref[0]
|
1189 |
+
sd = org_module.state_dict()
|
1190 |
+
|
1191 |
+
org_weight = sd["weight"]
|
1192 |
+
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
1193 |
+
sd["weight"] = org_weight + lora_weight
|
1194 |
+
assert sd["weight"].shape == org_weight.shape
|
1195 |
+
org_module.load_state_dict(sd)
|
1196 |
+
|
1197 |
+
org_module._lora_restored = False
|
1198 |
+
lora.enabled = False
|
1199 |
+
|
1200 |
+
def apply_max_norm_regularization(self, max_norm_value, device):
|
1201 |
+
downkeys = []
|
1202 |
+
upkeys = []
|
1203 |
+
alphakeys = []
|
1204 |
+
norms = []
|
1205 |
+
keys_scaled = 0
|
1206 |
+
|
1207 |
+
state_dict = self.state_dict()
|
1208 |
+
for key in state_dict.keys():
|
1209 |
+
if "lora_down" in key and "weight" in key:
|
1210 |
+
downkeys.append(key)
|
1211 |
+
upkeys.append(key.replace("lora_down", "lora_up"))
|
1212 |
+
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
1213 |
+
|
1214 |
+
for i in range(len(downkeys)):
|
1215 |
+
down = state_dict[downkeys[i]].to(device)
|
1216 |
+
up = state_dict[upkeys[i]].to(device)
|
1217 |
+
alpha = state_dict[alphakeys[i]].to(device)
|
1218 |
+
dim = down.shape[0]
|
1219 |
+
scale = alpha / dim
|
1220 |
+
|
1221 |
+
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
1222 |
+
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
1223 |
+
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
1224 |
+
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
1225 |
+
else:
|
1226 |
+
updown = up @ down
|
1227 |
+
|
1228 |
+
updown *= scale
|
1229 |
+
|
1230 |
+
norm = updown.norm().clamp(min=max_norm_value / 2)
|
1231 |
+
desired = torch.clamp(norm, max=max_norm_value)
|
1232 |
+
ratio = desired.cpu() / norm.cpu()
|
1233 |
+
sqrt_ratio = ratio**0.5
|
1234 |
+
if ratio != 1:
|
1235 |
+
keys_scaled += 1
|
1236 |
+
state_dict[upkeys[i]] *= sqrt_ratio
|
1237 |
+
state_dict[downkeys[i]] *= sqrt_ratio
|
1238 |
+
scalednorm = updown.norm() * ratio
|
1239 |
+
norms.append(scalednorm.item())
|
1240 |
+
|
1241 |
+
return keys_scaled, sum(norms) / len(norms), max(norms)
|
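The file above mirrors lora.py but implements LoRA-FA: lora_down (the A matrix) is frozen after its random init and only lora_up (B) is trained, as seen in requires_grad_() and get_trainable_params(). Below is a minimal, hedged sketch of how a training script might wire it up; the checkpoint id is only an example, the import path is an assumption about the repo layout, and whether the target module class names match this particular diffusers/transformers build is also an assumption.

```python
# Hedged sketch, not taken from the repo's training scripts.
import torch
from diffusers import StableDiffusionPipeline
from networks import lora_fa  # import path is an assumption

# any SD 1.x checkpoint; this id is only an example
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

network = lora_fa.create_network(
    multiplier=1.0,
    network_dim=16,
    network_alpha=8.0,
    vae=pipe.vae,
    text_encoder=pipe.text_encoder,
    unet=pipe.unet,
    neuron_dropout=None,
)
network.apply_to(pipe.text_encoder, pipe.unet, apply_text_encoder=True, apply_unet=True)

# LoRA-FA trains only lora_up; lora_down stays frozen, so hand exactly these
# parameter groups to the optimizer.
param_groups = network.prepare_optimizer_params(text_encoder_lr=1e-5, unet_lr=1e-4, default_lr=1e-4)
optimizer = torch.optim.AdamW(param_groups)
```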
external/llite/networks/lora_interrogator.py
ADDED
@@ -0,0 +1,139 @@
1 |
+
|
2 |
+
|
3 |
+
from tqdm import tqdm
|
4 |
+
from library import model_util
|
5 |
+
import library.train_util as train_util
|
6 |
+
import argparse
|
7 |
+
from transformers import CLIPTokenizer
|
8 |
+
import torch
|
9 |
+
|
10 |
+
import library.model_util as model_util
|
11 |
+
import lora
|
12 |
+
|
13 |
+
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
14 |
+
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is taken from here
|
15 |
+
|
16 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
17 |
+
|
18 |
+
|
19 |
+
def interrogate(args):
|
20 |
+
weights_dtype = torch.float16
|
21 |
+
|
22 |
+
# set up the models
|
23 |
+
print(f"loading SD model: {args.sd_model}")
|
24 |
+
args.pretrained_model_name_or_path = args.sd_model
|
25 |
+
args.vae = None
|
26 |
+
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
|
27 |
+
|
28 |
+
print(f"loading LoRA: {args.model}")
|
29 |
+
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
30 |
+
|
31 |
+
# check whether the LoRA has weights for the Text Encoder; ideally this should be done on the lora side
|
32 |
+
has_te_weight = False
|
33 |
+
for key in weights_sd.keys():
|
34 |
+
if 'lora_te' in key:
|
35 |
+
has_te_weight = True
|
36 |
+
break
|
37 |
+
if not has_te_weight:
|
38 |
+
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
|
39 |
+
return
|
40 |
+
del vae
|
41 |
+
|
42 |
+
print("loading tokenizer")
|
43 |
+
if args.v2:
|
44 |
+
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
45 |
+
else:
|
46 |
+
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
47 |
+
|
48 |
+
text_encoder.to(DEVICE, dtype=weights_dtype)
|
49 |
+
text_encoder.eval()
|
50 |
+
unet.to(DEVICE, dtype=weights_dtype)
|
51 |
+
unet.eval()  # not strictly necessary since the U-Net is never called
|
52 |
+
|
53 |
+
# examine the tokens one by one
|
54 |
+
token_id_start = 0
|
55 |
+
token_id_end = max(tokenizer.all_special_ids)
|
56 |
+
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
|
57 |
+
|
58 |
+
def get_all_embeddings(text_encoder):
|
59 |
+
embs = []
|
60 |
+
with torch.no_grad():
|
61 |
+
for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
|
62 |
+
batch = []
|
63 |
+
for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
|
64 |
+
tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
|
65 |
+
# tokens = [tid]  # this variant gives somewhat worse results
|
66 |
+
batch.append(tokens)
|
67 |
+
|
68 |
+
# batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # including bos/eos seems to bring out the differences better [:, 1]
|
69 |
+
# support clip skip
|
70 |
+
batch = torch.tensor(batch).to(DEVICE)
|
71 |
+
if args.clip_skip is None:
|
72 |
+
encoder_hidden_states = text_encoder(batch)[0]
|
73 |
+
else:
|
74 |
+
enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
|
75 |
+
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
76 |
+
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
77 |
+
encoder_hidden_states = encoder_hidden_states.to("cpu")
|
78 |
+
|
79 |
+
embs.extend(encoder_hidden_states)
|
80 |
+
return torch.stack(embs)
|
81 |
+
|
82 |
+
print("get original text encoder embeddings.")
|
83 |
+
orig_embs = get_all_embeddings(text_encoder)
|
84 |
+
|
85 |
+
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
86 |
+
info = network.load_state_dict(weights_sd, strict=False)
|
87 |
+
print(f"Loading LoRA weights: {info}")
|
88 |
+
|
89 |
+
network.to(DEVICE, dtype=weights_dtype)
|
90 |
+
network.eval()
|
91 |
+
|
92 |
+
del unet
|
93 |
+
|
94 |
+
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
95 |
+
print("get text encoder embeddings with lora.")
|
96 |
+
lora_embs = get_all_embeddings(text_encoder)
|
97 |
+
|
98 |
+
# compare: for now, simply use the mean absolute difference
|
99 |
+
print("comparing...")
|
100 |
+
diffs = {}
|
101 |
+
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
|
102 |
+
diff = torch.mean(torch.abs(orig_emb - lora_emb))
|
103 |
+
# diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない
|
104 |
+
diff = float(diff.detach().to('cpu').numpy())
|
105 |
+
diffs[token_id_start + i] = diff
|
106 |
+
|
107 |
+
diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
|
108 |
+
|
109 |
+
# 結果を表示する
|
110 |
+
print("top 100:")
|
111 |
+
for i, (token, diff) in enumerate(diffs_sorted[:100]):
|
112 |
+
# if diff < 1e-6:
|
113 |
+
# break
|
114 |
+
string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
|
115 |
+
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
|
116 |
+
|
117 |
+
|
118 |
+
def setup_parser() -> argparse.ArgumentParser:
|
119 |
+
parser = argparse.ArgumentParser()
|
120 |
+
|
121 |
+
parser.add_argument("--v2", action='store_true',
|
122 |
+
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
123 |
+
parser.add_argument("--sd_model", type=str, default=None,
|
124 |
+
help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
|
125 |
+
parser.add_argument("--model", type=str, default=None,
|
126 |
+
help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
|
127 |
+
parser.add_argument("--batch_size", type=int, default=16,
|
128 |
+
help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
|
129 |
+
parser.add_argument("--clip_skip", type=int, default=None,
|
130 |
+
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
131 |
+
|
132 |
+
return parser
|
133 |
+
|
134 |
+
|
135 |
+
if __name__ == '__main__':
|
136 |
+
parser = setup_parser()
|
137 |
+
|
138 |
+
args = parser.parse_args()
|
139 |
+
interrogate(args)
|
external/llite/networks/merge_lora.py
ADDED
@@ -0,0 +1,357 @@
import math
import argparse
import os
import time
import torch
from safetensors.torch import load_file, save_file
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora


def load_state_dict(file_name, dtype):
    if os.path.splitext(file_name)[1] == ".safetensors":
        sd = load_file(file_name)
        metadata = train_util.load_metadata_from_safetensors(file_name)
    else:
        sd = torch.load(file_name, map_location="cpu")
        metadata = {}

    for key in list(sd.keys()):
        if type(sd[key]) == torch.Tensor:
            sd[key] = sd[key].to(dtype)

    return sd, metadata


def save_to_file(file_name, model, state_dict, dtype, metadata):
    if dtype is not None:
        for key in list(state_dict.keys()):
            if type(state_dict[key]) == torch.Tensor:
                state_dict[key] = state_dict[key].to(dtype)

    if os.path.splitext(file_name)[1] == ".safetensors":
        save_file(model, file_name, metadata=metadata)
    else:
        torch.save(model, file_name)


def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
    text_encoder.to(merge_dtype)
    unet.to(merge_dtype)

    # create module map
    name_to_module = {}
    for i, root_module in enumerate([text_encoder, unet]):
        if i == 0:
            prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
            target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
        else:
            prefix = lora.LoRANetwork.LORA_PREFIX_UNET
            target_replace_modules = (
                lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
            )

        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        name_to_module[lora_name] = child_module

    for model, ratio in zip(models, ratios):
        print(f"loading: {model}")
        lora_sd, _ = load_state_dict(model, merge_dtype)

        print(f"merging...")
        for key in lora_sd.keys():
            if "lora_down" in key:
                up_key = key.replace("lora_down", "lora_up")
                alpha_key = key[: key.index("lora_down")] + "alpha"

                # find original module for this lora
                module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
                if module_name not in name_to_module:
                    print(f"no module found for LoRA weight: {key}")
                    continue
                module = name_to_module[module_name]
                # print(f"apply {key} to {module}")

                down_weight = lora_sd[key]
                up_weight = lora_sd[up_key]

                dim = down_weight.size()[0]
                alpha = lora_sd.get(alpha_key, dim)
                scale = alpha / dim

                # W <- W + U * D
                weight = module.weight
                if len(weight.size()) == 2:
                    # linear
                    if len(up_weight.size()) == 4:  # use linear projection mismatch
                        up_weight = up_weight.squeeze(3).squeeze(2)
                        down_weight = down_weight.squeeze(3).squeeze(2)
                    weight = weight + ratio * (up_weight @ down_weight) * scale
                elif down_weight.size()[2:4] == (1, 1):
                    # conv2d 1x1
                    weight = (
                        weight
                        + ratio
                        * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                        * scale
                    )
                else:
                    # conv2d 3x3
                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                    # print(conved.size(), weight.size(), module.stride, module.padding)
                    weight = weight + ratio * conved * scale

                module.weight = torch.nn.Parameter(weight)


def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
    base_alphas = {}  # alpha for merged model
    base_dims = {}

    merged_sd = {}
    v2 = None
    base_model = None
    for model, ratio in zip(models, ratios):
        print(f"loading: {model}")
        lora_sd, lora_metadata = load_state_dict(model, merge_dtype)

        if lora_metadata is not None:
            if v2 is None:
                v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None)  # return string
            if base_model is None:
                base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)

        # get alpha and dim
        alphas = {}  # alpha for current model
        dims = {}  # dims for current model
        for key in lora_sd.keys():
            if "alpha" in key:
                lora_module_name = key[: key.rfind(".alpha")]
                alpha = float(lora_sd[key].detach().numpy())
                alphas[lora_module_name] = alpha
                if lora_module_name not in base_alphas:
                    base_alphas[lora_module_name] = alpha
            elif "lora_down" in key:
                lora_module_name = key[: key.rfind(".lora_down")]
                dim = lora_sd[key].size()[0]
                dims[lora_module_name] = dim
                if lora_module_name not in base_dims:
                    base_dims[lora_module_name] = dim

        for lora_module_name in dims.keys():
            if lora_module_name not in alphas:
                alpha = dims[lora_module_name]
                alphas[lora_module_name] = alpha
                if lora_module_name not in base_alphas:
                    base_alphas[lora_module_name] = alpha

        print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")

        # merge
        print(f"merging...")
        for key in lora_sd.keys():
            if "alpha" in key:
                continue
            if "lora_up" in key and concat:
                concat_dim = 1
            elif "lora_down" in key and concat:
                concat_dim = 0
            else:
                concat_dim = None

            lora_module_name = key[: key.rfind(".lora_")]

            base_alpha = base_alphas[lora_module_name]
            alpha = alphas[lora_module_name]

            scale = math.sqrt(alpha / base_alpha) * ratio
            scale = abs(scale) if "lora_up" in key else scale  # handle negative weights

            if key in merged_sd:
                assert (
                    merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
                ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
                if concat_dim is not None:
                    merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
                else:
                    merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
            else:
                merged_sd[key] = lora_sd[key] * scale

    # set alpha to sd
    for lora_module_name, alpha in base_alphas.items():
        key = lora_module_name + ".alpha"
        merged_sd[key] = torch.tensor(alpha)
        if shuffle:
            key_down = lora_module_name + ".lora_down.weight"
            key_up = lora_module_name + ".lora_up.weight"
            dim = merged_sd[key_down].shape[0]
            perm = torch.randperm(dim)
            merged_sd[key_down] = merged_sd[key_down][perm]
            merged_sd[key_up] = merged_sd[key_up][:, perm]

    print("merged model")
    print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")

    # check all dims are same
    dims_list = list(set(base_dims.values()))
    alphas_list = list(set(base_alphas.values()))
    all_same_dims = True
    all_same_alphas = True
    for dims in dims_list:
        if dims != dims_list[0]:
            all_same_dims = False
            break
    for alphas in alphas_list:
        if alphas != alphas_list[0]:
            all_same_alphas = False
            break

    # build minimum metadata
    dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
    alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
    metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None)

    return merged_sd, metadata, v2 == "True"


def merge(args):
    assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"

    def str_to_dtype(p):
        if p == "float":
            return torch.float
        if p == "fp16":
            return torch.float16
        if p == "bf16":
            return torch.bfloat16
        return None

    merge_dtype = str_to_dtype(args.precision)
    save_dtype = str_to_dtype(args.save_precision)
    if save_dtype is None:
        save_dtype = merge_dtype

    if args.sd_model is not None:
        print(f"loading SD model: {args.sd_model}")

        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)

        merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)

        if args.no_metadata:
            sai_metadata = None
        else:
            merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
            title = os.path.splitext(os.path.basename(args.save_to))[0]
            sai_metadata = sai_model_spec.build_metadata(
                None,
                args.v2,
                args.v2,
                False,
                False,
                False,
                time.time(),
                title=title,
                merged_from=merged_from,
                is_stable_diffusion_ckpt=True,
            )
            if args.v2:
                # TODO read sai modelspec
                print(
                    "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
                )

        print(f"saving SD model to: {args.save_to}")
        model_util.save_stable_diffusion_checkpoint(
            args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
        )
    else:
        state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)

        print(f"calculating hashes and creating metadata...")

        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
        metadata["sshs_model_hash"] = model_hash
        metadata["sshs_legacy_hash"] = legacy_hash

        if not args.no_metadata:
            merged_from = sai_model_spec.build_merged_from(args.models)
            title = os.path.splitext(os.path.basename(args.save_to))[0]
            sai_metadata = sai_model_spec.build_metadata(
                state_dict, v2, v2, False, True, False, time.time(), title=title, merged_from=merged_from
            )
            if v2:
                # TODO read sai modelspec
                print(
                    "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
                )
            metadata.update(sai_metadata)

        print(f"saving model to: {args.save_to}")
        save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
    parser.add_argument(
        "--save_precision",
        type=str,
        default=None,
        choices=[None, "float", "fp16", "bf16"],
        help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="float",
        choices=["float", "fp16", "bf16"],
        help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
    )
    parser.add_argument(
        "--sd_model",
        type=str,
        default=None,
        help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
    )
    parser.add_argument(
        "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
    )
    parser.add_argument(
        "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
    )
    parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
    parser.add_argument(
        "--no_metadata",
        action="store_true",
        help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
        + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
    )
    parser.add_argument(
        "--concat",
        action="store_true",
        help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
        + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
    )
    parser.add_argument(
        "--shuffle",
        action="store_true",
        help="shuffle lora weight. / "
        + "LoRAの重みをシャッフルする",
    )

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    merge(args)
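A minimal standalone sketch (not part of the diff above) of the linear-layer merge rule used in merge_to_sd_model, W <- W + ratio * (up @ down) * (alpha / dim), on toy tensors. All shapes and values here are illustrative assumptions, not taken from the repository.

import torch

# toy shapes: out_features=6, in_features=8, LoRA rank (dim)=2
W = torch.randn(6, 8)          # original Linear weight
down = torch.randn(2, 8)       # lora_down.weight, shape (dim, in)
up = torch.randn(6, 2)         # lora_up.weight, shape (out, dim)
alpha, ratio = 1.0, 0.8

scale = alpha / down.size(0)   # same scale as in merge_to_sd_model
W_merged = W + ratio * (up @ down) * scale

# the merged layer behaves like the original plus the scaled LoRA delta
x = torch.randn(8)
assert torch.allclose(W_merged @ x, W @ x + ratio * scale * (up @ (down @ x)), atol=1e-4)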
external/llite/networks/merge_lora_old.py
ADDED
@@ -0,0 +1,185 @@
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
import library.model_util as model_util
import lora


def load_state_dict(file_name, dtype):
    if os.path.splitext(file_name)[1] == '.safetensors':
        sd = load_file(file_name)
    else:
        sd = torch.load(file_name, map_location='cpu')
    for key in list(sd.keys()):
        if type(sd[key]) == torch.Tensor:
            sd[key] = sd[key].to(dtype)
    return sd


def save_to_file(file_name, model, state_dict, dtype):
    if dtype is not None:
        for key in list(state_dict.keys()):
            if type(state_dict[key]) == torch.Tensor:
                state_dict[key] = state_dict[key].to(dtype)

    if os.path.splitext(file_name)[1] == '.safetensors':
        save_file(model, file_name)
    else:
        torch.save(model, file_name)


def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
    text_encoder.to(merge_dtype)
    unet.to(merge_dtype)

    # create module map
    name_to_module = {}
    for i, root_module in enumerate([text_encoder, unet]):
        if i == 0:
            prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
            target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
        else:
            prefix = lora.LoRANetwork.LORA_PREFIX_UNET
            target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE

        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
                        lora_name = prefix + '.' + name + '.' + child_name
                        lora_name = lora_name.replace('.', '_')
                        name_to_module[lora_name] = child_module

    for model, ratio in zip(models, ratios):
        print(f"loading: {model}")
        lora_sd = load_state_dict(model, merge_dtype)

        print(f"merging...")
        for key in lora_sd.keys():
            if "lora_down" in key:
                up_key = key.replace("lora_down", "lora_up")
                alpha_key = key[:key.index("lora_down")] + 'alpha'

                # find original module for this lora
                module_name = '.'.join(key.split('.')[:-2])  # remove trailing ".lora_down.weight"
                if module_name not in name_to_module:
                    print(f"no module found for LoRA weight: {key}")
                    continue
                module = name_to_module[module_name]
                # print(f"apply {key} to {module}")

                down_weight = lora_sd[key]
                up_weight = lora_sd[up_key]

                dim = down_weight.size()[0]
                alpha = lora_sd.get(alpha_key, dim)
                scale = alpha / dim

                # W <- W + U * D
                weight = module.weight
                if len(weight.size()) == 2:
                    # linear
                    weight = weight + ratio * (up_weight @ down_weight) * scale
                else:
                    # conv2d
                    weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale

                module.weight = torch.nn.Parameter(weight)


def merge_lora_models(models, ratios, merge_dtype):
    merged_sd = {}

    alpha = None
    dim = None
    for model, ratio in zip(models, ratios):
        print(f"loading: {model}")
        lora_sd = load_state_dict(model, merge_dtype)

        print(f"merging...")
        for key in lora_sd.keys():
            if 'alpha' in key:
                if key in merged_sd:
                    assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
                else:
                    alpha = lora_sd[key].detach().numpy()
                    merged_sd[key] = lora_sd[key]
            else:
                if key in merged_sd:
                    assert merged_sd[key].size() == lora_sd[key].size(
                    ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
                    merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
                else:
                    if "lora_down" in key:
                        dim = lora_sd[key].size()[0]
                    merged_sd[key] = lora_sd[key] * ratio

    print(f"dim (rank): {dim}, alpha: {alpha}")
    if alpha is None:
        alpha = dim

    return merged_sd, dim, alpha


def merge(args):
    assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"

    def str_to_dtype(p):
        if p == 'float':
            return torch.float
        if p == 'fp16':
            return torch.float16
        if p == 'bf16':
            return torch.bfloat16
        return None

    merge_dtype = str_to_dtype(args.precision)
    save_dtype = str_to_dtype(args.save_precision)
    if save_dtype is None:
        save_dtype = merge_dtype

    if args.sd_model is not None:
        print(f"loading SD model: {args.sd_model}")

        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)

        merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)

        print(f"\nsaving SD model to: {args.save_to}")
        model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
                                                    args.sd_model, 0, 0, save_dtype, vae)
    else:
        state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)

        print(f"\nsaving model to: {args.save_to}")
        save_to_file(args.save_to, state_dict, state_dict, save_dtype)


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--v2", action='store_true',
                        help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
    parser.add_argument("--save_precision", type=str, default=None,
                        choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
    parser.add_argument("--precision", type=str, default="float",
                        choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
    parser.add_argument("--sd_model", type=str, default=None,
                        help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
    parser.add_argument("--save_to", type=str, default=None,
                        help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
    parser.add_argument("--models", type=str, nargs='*',
                        help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
    parser.add_argument("--ratios", type=float, nargs='*',
                        help="ratios for each model / それぞれのLoRAモデルの比率")

    return parser


if __name__ == '__main__':
    parser = setup_parser()

    args = parser.parse_args()
    merge(args)
external/llite/networks/oft.py
ADDED
@@ -0,0 +1,430 @@
# OFT network module

import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
import re


RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")


class OFTModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(
        self,
        oft_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        dim=4,
        alpha=1,
    ):
        """
        dim -> num blocks
        alpha -> constraint
        """
        super().__init__()
        self.oft_name = oft_name

        self.num_blocks = dim

        if "Linear" in org_module.__class__.__name__:
            out_dim = org_module.out_features
        elif "Conv" in org_module.__class__.__name__:
            out_dim = org_module.out_channels

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().numpy()
        self.constraint = alpha * out_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        self.block_size = out_dim // self.num_blocks
        self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))

        self.out_dim = out_dim
        self.shape = org_module.weight.shape

        self.multiplier = multiplier
        self.org_module = [org_module]  # wrap in a list so it is not registered as a submodule

    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def get_weight(self, multiplier=None):
        if multiplier is None:
            multiplier = self.multiplier

        block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
        norm_Q = torch.norm(block_Q.flatten())
        new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
        block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
        I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
        block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())

        block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
        R = torch.block_diag(*block_R_weighted)

        return R

    def forward(self, x, scale=None):
        x = self.org_forward(x)
        if self.multiplier == 0.0:
            return x

        R = self.get_weight().to(x.device, dtype=x.dtype)
        if x.dim() == 4:
            x = x.permute(0, 2, 3, 1)
            x = torch.matmul(x, R)
            x = x.permute(0, 3, 1, 2)
        else:
            x = torch.matmul(x, R)
        return x


class OFTInfModule(OFTModule):
    def __init__(
        self,
        oft_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        dim=4,
        alpha=1,
        **kwargs,
    ):
        # no dropout for inference
        super().__init__(oft_name, org_module, multiplier, dim, alpha)
        self.enabled = True
        self.network: OFTNetwork = None

    def set_network(self, network):
        self.network = network

    def forward(self, x, scale=None):
        if not self.enabled:
            return self.org_forward(x)
        return super().forward(x, scale)

    def merge_to(self, multiplier=None, sign=1):
        R = self.get_weight(multiplier) * sign

        # get org weight
        org_sd = self.org_module[0].state_dict()
        org_weight = org_sd["weight"]
        R = R.to(org_weight.device, dtype=org_weight.dtype)

        if org_weight.dim() == 4:
            weight = torch.einsum("oihw, op -> pihw", org_weight, R)
        else:
            weight = torch.einsum("oi, op -> pi", org_weight, R)

        # set weight to org_module
        org_sd["weight"] = weight
        self.org_module[0].load_state_dict(org_sd)


def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: AutoencoderKL,
    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:
        network_alpha = 1.0

    enable_all_linear = kwargs.get("enable_all_linear", None)
    enable_conv = kwargs.get("enable_conv", None)
    if enable_all_linear is not None:
        enable_all_linear = bool(enable_all_linear)
    if enable_conv is not None:
        enable_conv = bool(enable_conv)

    network = OFTNetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        dim=network_dim,
        alpha=network_alpha,
        enable_all_linear=enable_all_linear,
        enable_conv=enable_conv,
        varbose=True,
    )
    return network


# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # check dim, alpha and if weights have for conv2d
    dim = None
    alpha = None
    has_conv2d = None
    all_linear = None
    for name, param in weights_sd.items():
        if name.endswith(".alpha"):
            if alpha is None:
                alpha = param.item()
        else:
            if dim is None:
                dim = param.size()[0]
            if has_conv2d is None and param.dim() == 4:
                has_conv2d = True
            if all_linear is None:
                if param.dim() == 3 and "attn" not in name:
                    all_linear = True
        if dim is not None and alpha is not None and has_conv2d is not None:
            break
    if has_conv2d is None:
        has_conv2d = False
    if all_linear is None:
        all_linear = False

    module_class = OFTInfModule if for_inference else OFTModule
    network = OFTNetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        dim=dim,
        alpha=alpha,
        enable_all_linear=all_linear,
        enable_conv=has_conv2d,
        module_class=module_class,
    )
    return network, weights_sd


class OFTNetwork(torch.nn.Module):
    UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
    UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    OFT_PREFIX_UNET = "oft_unet"  # probably better not to change this

    def __init__(
        self,
        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
        unet,
        multiplier: float = 1.0,
        dim: int = 4,
        alpha: float = 1,
        enable_all_linear: Optional[bool] = False,
        enable_conv: Optional[bool] = False,
        module_class: Type[object] = OFTModule,
        varbose: Optional[bool] = False,
    ) -> None:
        super().__init__()
        self.multiplier = multiplier

        self.dim = dim
        self.alpha = alpha

        print(
            f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
        )

        # create module instances
        def create_modules(
            root_module: torch.nn.Module,
            target_replace_modules: List[torch.nn.Module],
        ) -> List[OFTModule]:
            prefix = self.OFT_PREFIX_UNET
            ofts = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = "Linear" in child_module.__class__.__name__
                        is_conv2d = "Conv2d" in child_module.__class__.__name__
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
                            oft_name = prefix + "." + name + "." + child_name
                            oft_name = oft_name.replace(".", "_")
                            # print(oft_name)

                            oft = module_class(
                                oft_name,
                                child_module,
                                self.multiplier,
                                dim,
                                alpha,
                            )
                            ofts.append(oft)
            return ofts

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
        if enable_all_linear:
            target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
        else:
            target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
        if enable_conv:
            target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
        print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")

        # assertion
        names = set()
        for oft in self.unet_ofts:
            assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
            names.add(oft.oft_name)

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for oft in self.unet_ofts:
            oft.multiplier = self.multiplier

    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        assert apply_unet, "apply_unet must be True"

        for oft in self.unet_ofts:
            oft.apply_to()
            self.add_module(oft.oft_name, oft)

    # returns whether the weights can be merged
    def is_mergeable(self):
        return True

    # TODO refactor to common function with apply_to
    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
        print("enable OFT for U-Net")

        for oft in self.unet_ofts:
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(oft.oft_name):
                    sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
            oft.load_state_dict(sd_for_lora, False)
            oft.merge_to()

        print(f"weights are merged")

    # it may be better to allow separate learning rates for the two Text Encoders
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        self.requires_grad_(True)
        all_params = []

        def enumerate_params(ofts):
            params = []
            for oft in ofts:
                params.extend(oft.parameters())

            # print num of params
            num_params = 0
            for p in params:
                num_params += p.numel()
            print(f"OFT params: {num_params}")
            return params

        param_data = {"params": enumerate_params(self.unet_ofts)}
        if unet_lr is not None:
            param_data["lr"] = unet_lr
        all_params.append(param_data)

        return all_params

    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def backup_weights(self):
        # back up the original weights
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            if not hasattr(org_module, "_lora_org_weight"):
                sd = org_module.state_dict()
                org_module._lora_org_weight = sd["weight"].detach().clone()
                org_module._lora_restored = True

    def restore_weights(self):
        # restore the original weights
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            if not org_module._lora_restored:
                sd = org_module.state_dict()
                sd["weight"] = org_module._lora_org_weight
                org_module.load_state_dict(sd)
                org_module._lora_restored = True

    def pre_calculation(self):
        # pre-calculate (merge) the weights
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            oft.merge_to()
            # sd = org_module.state_dict()
            # org_weight = sd["weight"]
            # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
            # sd["weight"] = org_weight + lora_weight
            # assert sd["weight"].shape == org_weight.shape
            # org_module.load_state_dict(sd)

            org_module._lora_restored = False
            oft.enabled = False
external/llite/networks/resize_lora.py
ADDED
@@ -0,0 +1,362 @@
1 |
+
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
2 |
+
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
3 |
+
# Thanks to cloneofsimo
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import torch
|
7 |
+
from safetensors.torch import load_file, save_file, safe_open
|
8 |
+
from tqdm import tqdm
|
9 |
+
from library import train_util, model_util
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
MIN_SV = 1e-6
|
13 |
+
|
14 |
+
# Model save and load functions
|
15 |
+
|
16 |
+
def load_state_dict(file_name, dtype):
|
17 |
+
if model_util.is_safetensors(file_name):
|
18 |
+
sd = load_file(file_name)
|
19 |
+
with safe_open(file_name, framework="pt") as f:
|
20 |
+
metadata = f.metadata()
|
21 |
+
else:
|
22 |
+
sd = torch.load(file_name, map_location='cpu')
|
23 |
+
metadata = None
|
24 |
+
|
25 |
+
for key in list(sd.keys()):
|
26 |
+
if type(sd[key]) == torch.Tensor:
|
27 |
+
sd[key] = sd[key].to(dtype)
|
28 |
+
|
29 |
+
return sd, metadata
|
30 |
+
|
31 |
+
|
32 |
+
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
33 |
+
if dtype is not None:
|
34 |
+
for key in list(state_dict.keys()):
|
35 |
+
if type(state_dict[key]) == torch.Tensor:
|
36 |
+
state_dict[key] = state_dict[key].to(dtype)
|
37 |
+
|
38 |
+
if model_util.is_safetensors(file_name):
|
39 |
+
save_file(model, file_name, metadata)
|
40 |
+
else:
|
41 |
+
torch.save(model, file_name)
|
42 |
+
|
43 |
+
|
44 |
+
# Indexing functions
|
45 |
+
|
46 |
+
def index_sv_cumulative(S, target):
|
47 |
+
original_sum = float(torch.sum(S))
|
48 |
+
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
|
49 |
+
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
50 |
+
index = max(1, min(index, len(S)-1))
|
51 |
+
|
52 |
+
return index
|
53 |
+
|
54 |
+
|
55 |
+
def index_sv_fro(S, target):
|
56 |
+
S_squared = S.pow(2)
|
57 |
+
s_fro_sq = float(torch.sum(S_squared))
|
58 |
+
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
|
59 |
+
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
60 |
+
index = max(1, min(index, len(S)-1))
|
61 |
+
|
62 |
+
return index
|
63 |
+
|
64 |
+
|
65 |
+
def index_sv_ratio(S, target):
|
66 |
+
max_sv = S[0]
|
67 |
+
min_sv = max_sv/target
|
68 |
+
index = int(torch.sum(S > min_sv).item())
|
69 |
+
index = max(1, min(index, len(S)-1))
|
70 |
+
|
71 |
+
return index
|
72 |
+
|
73 |
+
|
74 |
+
# Modified from Kohaku-blueleaf's extract/merge functions
|
75 |
+
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
76 |
+
out_size, in_size, kernel_size, _ = weight.size()
|
77 |
+
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
78 |
+
|
79 |
+
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
80 |
+
lora_rank = param_dict["new_rank"]
|
81 |
+
|
82 |
+
U = U[:, :lora_rank]
|
83 |
+
S = S[:lora_rank]
|
84 |
+
U = U @ torch.diag(S)
|
85 |
+
Vh = Vh[:lora_rank, :]
|
86 |
+
|
87 |
+
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
|
88 |
+
param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
|
89 |
+
del U, S, Vh, weight
|
90 |
+
return param_dict
|
91 |
+
|
92 |
+
|
93 |
+
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
94 |
+
out_size, in_size = weight.size()
|
95 |
+
|
96 |
+
U, S, Vh = torch.linalg.svd(weight.to(device))
|
97 |
+
|
98 |
+
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
99 |
+
lora_rank = param_dict["new_rank"]
|
100 |
+
|
101 |
+
U = U[:, :lora_rank]
|
102 |
+
S = S[:lora_rank]
|
103 |
+
U = U @ torch.diag(S)
|
104 |
+
Vh = Vh[:lora_rank, :]
|
105 |
+
|
106 |
+
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
|
107 |
+
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
|
108 |
+
del U, S, Vh, weight
|
109 |
+
return param_dict
|
110 |
+
|
111 |
+
|
112 |
+
def merge_conv(lora_down, lora_up, device):
|
113 |
+
in_rank, in_size, kernel_size, k_ = lora_down.shape
|
114 |
+
out_size, out_rank, _, _ = lora_up.shape
|
115 |
+
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
|
116 |
+
|
117 |
+
lora_down = lora_down.to(device)
|
118 |
+
lora_up = lora_up.to(device)
|
119 |
+
|
120 |
+
merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
|
121 |
+
weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
|
122 |
+
del lora_up, lora_down
|
123 |
+
return weight
|
124 |
+
|
125 |
+
|
126 |
+
def merge_linear(lora_down, lora_up, device):
|
127 |
+
in_rank, in_size = lora_down.shape
|
128 |
+
out_size, out_rank = lora_up.shape
|
129 |
+
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
|
130 |
+
|
131 |
+
lora_down = lora_down.to(device)
|
132 |
+
lora_up = lora_up.to(device)
|
133 |
+
|
134 |
+
weight = lora_up @ lora_down
|
135 |
+
del lora_up, lora_down
|
136 |
+
return weight
|
137 |
+
|
138 |
+
|
139 |
+
# Calculate new rank
|
140 |
+
|
141 |
+
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
142 |
+
param_dict = {}
|
143 |
+
|
144 |
+
if dynamic_method=="sv_ratio":
|
145 |
+
# Calculate new dim and alpha based off ratio
|
146 |
+
new_rank = index_sv_ratio(S, dynamic_param) + 1
|
147 |
+
new_alpha = float(scale*new_rank)
|
148 |
+
|
149 |
+
elif dynamic_method=="sv_cumulative":
|
150 |
+
# Calculate new dim and alpha based off cumulative sum
|
151 |
+
new_rank = index_sv_cumulative(S, dynamic_param) + 1
|
152 |
+
new_alpha = float(scale*new_rank)
|
153 |
+
|
154 |
+
elif dynamic_method=="sv_fro":
|
155 |
+
# Calculate new dim and alpha based off sqrt sum of squares
|
156 |
+
new_rank = index_sv_fro(S, dynamic_param) + 1
|
157 |
+
new_alpha = float(scale*new_rank)
|
158 |
+
else:
|
159 |
+
new_rank = rank
|
160 |
+
new_alpha = float(scale*new_rank)
|
161 |
+
|
162 |
+
|
163 |
+
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
164 |
+
new_rank = 1
|
165 |
+
new_alpha = float(scale*new_rank)
|
166 |
+
elif new_rank > rank: # cap max rank at rank
|
167 |
+
new_rank = rank
|
168 |
+
new_alpha = float(scale*new_rank)
|
169 |
+
|
170 |
+
|
171 |
+
# Calculate resize info
|
172 |
+
s_sum = torch.sum(torch.abs(S))
|
173 |
+
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
174 |
+
|
175 |
+
S_squared = S.pow(2)
|
176 |
+
s_fro = torch.sqrt(torch.sum(S_squared))
|
177 |
+
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
|
178 |
+
fro_percent = float(s_red_fro/s_fro)
|
179 |
+
|
180 |
+
param_dict["new_rank"] = new_rank
|
181 |
+
param_dict["new_alpha"] = new_alpha
|
182 |
+
param_dict["sum_retained"] = (s_rank)/s_sum
|
183 |
+
param_dict["fro_retained"] = fro_percent
|
184 |
+
param_dict["max_ratio"] = S[0]/S[new_rank - 1]
|
185 |
+
|
186 |
+
return param_dict
|
187 |
+
|
188 |
+
|
189 |
+
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
190 |
+
network_alpha = None
|
191 |
+
network_dim = None
|
192 |
+
verbose_str = "\n"
|
193 |
+
fro_list = []
|
194 |
+
|
195 |
+
# Extract loaded lora dim and alpha
|
196 |
+
for key, value in lora_sd.items():
|
197 |
+
if network_alpha is None and 'alpha' in key:
|
198 |
+
network_alpha = value
|
199 |
+
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
200 |
+
network_dim = value.size()[0]
|
201 |
+
if network_alpha is not None and network_dim is not None:
|
202 |
+
break
|
203 |
+
if network_alpha is None:
|
204 |
+
network_alpha = network_dim
|
205 |
+
|
206 |
+
scale = network_alpha/network_dim
|
207 |
+
|
208 |
+
if dynamic_method:
|
209 |
+
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
|
210 |
+
|
211 |
+
lora_down_weight = None
|
212 |
+
lora_up_weight = None
|
213 |
+
|
214 |
+
o_lora_sd = lora_sd.copy()
|
215 |
+
block_down_name = None
|
216 |
+
block_up_name = None
|
217 |
+
|
218 |
+
with torch.no_grad():
|
219 |
+
for key, value in tqdm(lora_sd.items()):
|
220 |
+
weight_name = None
|
221 |
+
if 'lora_down' in key:
|
222 |
+
block_down_name = key.rsplit('.lora_down', 1)[0]
|
223 |
+
weight_name = key.rsplit(".", 1)[-1]
|
224 |
+
lora_down_weight = value
|
225 |
+
else:
|
226 |
+
continue
|
227 |
+
|
228 |
+
# find corresponding lora_up and alpha
|
229 |
+
block_up_name = block_down_name
|
230 |
+
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
|
231 |
+
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
|
232 |
+
|
233 |
+
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
|
234 |
+
|
235 |
+
if weights_loaded:
|
236 |
+
|
237 |
+
conv2d = (len(lora_down_weight.size()) == 4)
|
238 |
+
if lora_alpha is None:
|
239 |
+
                    scale = 1.0
                else:
                    scale = lora_alpha/lora_down_weight.size()[0]

                if conv2d:
                    full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
                    param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
                else:
                    full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
                    param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)

                if verbose:
                    max_ratio = param_dict['max_ratio']
                    sum_retained = param_dict['sum_retained']
                    fro_retained = param_dict['fro_retained']
                    if not np.isnan(fro_retained):
                        fro_list.append(float(fro_retained))

                    verbose_str+=f"{block_down_name:75} | "
                    verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"

                if verbose and dynamic_method:
                    verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
                else:
                    verbose_str+=f"\n"

                new_alpha = param_dict['new_alpha']
                o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
                o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
                o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)

                block_down_name = None
                block_up_name = None
                lora_down_weight = None
                lora_up_weight = None
                weights_loaded = False
                del param_dict

    if verbose:
        print(verbose_str)

        print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
    print("resizing complete")
    return o_lora_sd, network_dim, new_alpha


def resize(args):
    if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
        raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")

    def str_to_dtype(p):
        if p == 'float':
            return torch.float
        if p == 'fp16':
            return torch.float16
        if p == 'bf16':
            return torch.bfloat16
        return None

    if args.dynamic_method and not args.dynamic_param:
        raise Exception("If using dynamic_method, then dynamic_param is required")

    merge_dtype = str_to_dtype('float')  # matmul method above only seems to work in float32
    save_dtype = str_to_dtype(args.save_precision)
    if save_dtype is None:
        save_dtype = merge_dtype

    print("loading Model...")
    lora_sd, metadata = load_state_dict(args.model, merge_dtype)

    print("Resizing Lora...")
    state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)

    # update metadata
    if metadata is None:
        metadata = {}

    comment = metadata.get("ss_training_comment", "")

    if not args.dynamic_method:
        metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
        metadata["ss_network_dim"] = str(args.new_rank)
        metadata["ss_network_alpha"] = str(new_alpha)
    else:
        metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
        metadata["ss_network_dim"] = 'Dynamic'
        metadata["ss_network_alpha"] = 'Dynamic'

    model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
    metadata["sshs_model_hash"] = model_hash
    metadata["sshs_legacy_hash"] = legacy_hash

    print(f"saving model to: {args.save_to}")
    save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    parser.add_argument("--save_precision", type=str, default=None,
                        choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
    parser.add_argument("--new_rank", type=int, default=4,
                        help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
    parser.add_argument("--save_to", type=str, default=None,
                        help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
    parser.add_argument("--model", type=str, default=None,
                        help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
    parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
    parser.add_argument("--verbose", action="store_true",
                        help="Display verbose resizing information / rank変更時の詳細情報を出力する")
    parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
                        help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
    parser.add_argument("--dynamic_param", type=float, default=None,
                        help="Specify target for dynamic reduction")

    return parser


if __name__ == '__main__':
    parser = setup_parser()

    args = parser.parse_args()
    resize(args)
external/llite/networks/sdxl_merge_lora.py
ADDED
@@ -0,0 +1,348 @@
import math
import argparse
import os
import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util
import lora


def load_state_dict(file_name, dtype):
    if os.path.splitext(file_name)[1] == ".safetensors":
        sd = load_file(file_name)
        metadata = train_util.load_metadata_from_safetensors(file_name)
    else:
        sd = torch.load(file_name, map_location="cpu")
        metadata = {}

    for key in list(sd.keys()):
        if type(sd[key]) == torch.Tensor:
            sd[key] = sd[key].to(dtype)

    return sd, metadata


def save_to_file(file_name, model, state_dict, dtype, metadata):
    if dtype is not None:
        for key in list(state_dict.keys()):
            if type(state_dict[key]) == torch.Tensor:
                state_dict[key] = state_dict[key].to(dtype)

    if os.path.splitext(file_name)[1] == ".safetensors":
        save_file(model, file_name, metadata=metadata)
    else:
        torch.save(model, file_name)


def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
    text_encoder1.to(merge_dtype)
    text_encoder2.to(merge_dtype)  # fixed: the original cast text_encoder1 twice and never cast text_encoder2
    unet.to(merge_dtype)

    # create module map
    name_to_module = {}
    for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
        if i <= 1:
            if i == 0:
                prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
            else:
                prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
            target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
        else:
            prefix = lora.LoRANetwork.LORA_PREFIX_UNET
            target_replace_modules = (
                lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
            )

        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        name_to_module[lora_name] = child_module

    for model, ratio in zip(models, ratios):
        print(f"loading: {model}")
        lora_sd, _ = load_state_dict(model, merge_dtype)

        print(f"merging...")
        for key in tqdm(lora_sd.keys()):
            if "lora_down" in key:
                up_key = key.replace("lora_down", "lora_up")
                alpha_key = key[: key.index("lora_down")] + "alpha"

                # find original module for this lora
                module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
                if module_name not in name_to_module:
                    print(f"no module found for LoRA weight: {key}")
                    continue
                module = name_to_module[module_name]
                # print(f"apply {key} to {module}")

                down_weight = lora_sd[key]
                up_weight = lora_sd[up_key]

                dim = down_weight.size()[0]
                alpha = lora_sd.get(alpha_key, dim)
                scale = alpha / dim

                # W <- W + U * D
                weight = module.weight
                # print(module_name, down_weight.size(), up_weight.size())
                if len(weight.size()) == 2:
                    # linear
                    weight = weight + ratio * (up_weight @ down_weight) * scale
                elif down_weight.size()[2:4] == (1, 1):
                    # conv2d 1x1
                    weight = (
                        weight
                        + ratio
                        * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                        * scale
                    )
                else:
                    # conv2d 3x3
                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                    # print(conved.size(), weight.size(), module.stride, module.padding)
                    weight = weight + ratio * conved * scale

                module.weight = torch.nn.Parameter(weight)


def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
    base_alphas = {}  # alpha for merged model
    base_dims = {}

    merged_sd = {}
    v2 = None
    base_model = None
    for model, ratio in zip(models, ratios):
        print(f"loading: {model}")
        lora_sd, lora_metadata = load_state_dict(model, merge_dtype)

        if lora_metadata is not None:
            if v2 is None:
                v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None)  # returns a string; SDXL has no v2, so this should be False
            if base_model is None:
                base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)

        # get alpha and dim
        alphas = {}  # alpha for current model
        dims = {}  # dims for current model
        for key in lora_sd.keys():
            if "alpha" in key:
                lora_module_name = key[: key.rfind(".alpha")]
                alpha = float(lora_sd[key].detach().numpy())
                alphas[lora_module_name] = alpha
                if lora_module_name not in base_alphas:
                    base_alphas[lora_module_name] = alpha
            elif "lora_down" in key:
                lora_module_name = key[: key.rfind(".lora_down")]
                dim = lora_sd[key].size()[0]
                dims[lora_module_name] = dim
                if lora_module_name not in base_dims:
                    base_dims[lora_module_name] = dim

        for lora_module_name in dims.keys():
            if lora_module_name not in alphas:
                alpha = dims[lora_module_name]
                alphas[lora_module_name] = alpha
                if lora_module_name not in base_alphas:
                    base_alphas[lora_module_name] = alpha

        print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")

        # merge
        print(f"merging...")
        for key in tqdm(lora_sd.keys()):
            if "alpha" in key:
                continue

            if "lora_up" in key and concat:
                concat_dim = 1
            elif "lora_down" in key and concat:
                concat_dim = 0
            else:
                concat_dim = None

            lora_module_name = key[: key.rfind(".lora_")]

            base_alpha = base_alphas[lora_module_name]
            alpha = alphas[lora_module_name]

            scale = math.sqrt(alpha / base_alpha) * ratio
            scale = abs(scale) if "lora_up" in key else scale  # handle negative weights

            if key in merged_sd:
                assert (
                    merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
                ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
                if concat_dim is not None:
                    merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
                else:
                    merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
            else:
                merged_sd[key] = lora_sd[key] * scale

    # set alpha to sd
    for lora_module_name, alpha in base_alphas.items():
        key = lora_module_name + ".alpha"
        merged_sd[key] = torch.tensor(alpha)
        if shuffle:
            key_down = lora_module_name + ".lora_down.weight"
            key_up = lora_module_name + ".lora_up.weight"
            dim = merged_sd[key_down].shape[0]
            perm = torch.randperm(dim)
            merged_sd[key_down] = merged_sd[key_down][perm]
            merged_sd[key_up] = merged_sd[key_up][:, perm]

    print("merged model")
    print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")

    # check all dims are same
    dims_list = list(set(base_dims.values()))
    alphas_list = list(set(base_alphas.values()))
    all_same_dims = True
    all_same_alphas = True
    for dims in dims_list:
        if dims != dims_list[0]:
            all_same_dims = False
            break
    for alphas in alphas_list:
        if alphas != alphas_list[0]:
            all_same_alphas = False
            break

    # build minimum metadata
    dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
    alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
    metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None)

    return merged_sd, metadata


def merge(args):
    assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"

    def str_to_dtype(p):
        if p == "float":
            return torch.float
        if p == "fp16":
            return torch.float16
        if p == "bf16":
            return torch.bfloat16
        return None

    merge_dtype = str_to_dtype(args.precision)
    save_dtype = str_to_dtype(args.save_precision)
    if save_dtype is None:
        save_dtype = merge_dtype

    if args.sd_model is not None:
        print(f"loading SD model: {args.sd_model}")

        (
            text_model1,
            text_model2,
            vae,
            unet,
            logit_scale,
            ckpt_info,
        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")

        merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)

        if args.no_metadata:
            sai_metadata = None
        else:
            merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
            title = os.path.splitext(os.path.basename(args.save_to))[0]
            sai_metadata = sai_model_spec.build_metadata(
                None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
            )

        print(f"saving SD model to: {args.save_to}")
        sdxl_model_util.save_stable_diffusion_checkpoint(
            args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
        )
    else:
        state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)

        print(f"calculating hashes and creating metadata...")

        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
        metadata["sshs_model_hash"] = model_hash
        metadata["sshs_legacy_hash"] = legacy_hash

        if not args.no_metadata:
            merged_from = sai_model_spec.build_merged_from(args.models)
            title = os.path.splitext(os.path.basename(args.save_to))[0]
            sai_metadata = sai_model_spec.build_metadata(
                state_dict, False, False, True, True, False, time.time(), title=title, merged_from=merged_from
            )
            metadata.update(sai_metadata)

        print(f"saving model to: {args.save_to}")
        save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_precision",
        type=str,
        default=None,
        choices=[None, "float", "fp16", "bf16"],
        help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="float",
        choices=["float", "fp16", "bf16"],
        help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
    )
    parser.add_argument(
        "--sd_model",
        type=str,
        default=None,
        help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
    )
    parser.add_argument(
        "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
    )
    parser.add_argument(
        "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
    )
    parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
    parser.add_argument(
        "--no_metadata",
        action="store_true",
        help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
        + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
    )
    parser.add_argument(
        "--concat",
        action="store_true",
        help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
        + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
    )
    parser.add_argument(
        "--shuffle",
        action="store_true",
        help="shuffle lora weight./ "
        + "LoRAの重みをシャッフルする",
    )

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    merge(args)
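
The core of merge_to_sd_model above is the update W <- W + ratio * (up @ down) * (alpha / dim). A minimal, self-contained sketch of the Linear-layer branch, using made-up shapes and values:

import torch

out_features, in_features, rank = 8, 16, 4
W = torch.randn(out_features, in_features)  # original module weight
up = torch.randn(out_features, rank)        # lora_up.weight
down = torch.randn(rank, in_features)       # lora_down.weight
alpha, ratio = 4.0, 0.8                     # made-up alpha and merge ratio

merged = W + ratio * (up @ down) * (alpha / rank)  # same formula as the script's linear branch
print(merged.shape)  # torch.Size([8, 16])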
external/llite/networks/svd_merge_lora.py
ADDED
@@ -0,0 +1,260 @@
import math
import argparse
import os
import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora


CLAMP_QUANTILE = 0.99


def load_state_dict(file_name, dtype):
    if os.path.splitext(file_name)[1] == ".safetensors":
        sd = load_file(file_name)
        metadata = train_util.load_metadata_from_safetensors(file_name)
    else:
        sd = torch.load(file_name, map_location="cpu")
        metadata = {}

    for key in list(sd.keys()):
        if type(sd[key]) == torch.Tensor:
            sd[key] = sd[key].to(dtype)

    return sd, metadata


def save_to_file(file_name, state_dict, dtype, metadata):
    if dtype is not None:
        for key in list(state_dict.keys()):
            if type(state_dict[key]) == torch.Tensor:
                state_dict[key] = state_dict[key].to(dtype)

    if os.path.splitext(file_name)[1] == ".safetensors":
        save_file(state_dict, file_name, metadata=metadata)
    else:
        torch.save(state_dict, file_name)


def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
    print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
    merged_sd = {}
    v2 = None
    base_model = None
    for model, ratio in zip(models, ratios):
        print(f"loading: {model}")
        lora_sd, lora_metadata = load_state_dict(model, merge_dtype)

        if lora_metadata is not None:
            if v2 is None:
                v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None)  # return string
            if base_model is None:
                base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)

        # merge
        print(f"merging...")
        for key in tqdm(list(lora_sd.keys())):
            if "lora_down" not in key:
                continue

            lora_module_name = key[: key.rfind(".lora_down")]

            down_weight = lora_sd[key]
            network_dim = down_weight.size()[0]

            up_weight = lora_sd[lora_module_name + ".lora_up.weight"]
            alpha = lora_sd.get(lora_module_name + ".alpha", network_dim)

            in_dim = down_weight.size()[1]
            out_dim = up_weight.size()[0]
            conv2d = len(down_weight.size()) == 4
            kernel_size = None if not conv2d else down_weight.size()[2:4]
            # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)

            # make original weight if not exist
            if lora_module_name not in merged_sd:
                weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
                if device:
                    weight = weight.to(device)
            else:
                weight = merged_sd[lora_module_name]

            # merge to weight
            if device:
                up_weight = up_weight.to(device)
                down_weight = down_weight.to(device)

            # W <- W + U * D
            scale = alpha / network_dim

            if device:  # and isinstance(scale, torch.Tensor):
                scale = scale.to(device)

            if not conv2d:  # linear
                weight = weight + ratio * (up_weight @ down_weight) * scale
            elif kernel_size == (1, 1):
                weight = (
                    weight
                    + ratio
                    * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                    * scale
                )
            else:
                conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                weight = weight + ratio * conved * scale

            merged_sd[lora_module_name] = weight

    # extract from merged weights
    print("extract new lora...")
    merged_lora_sd = {}
    with torch.no_grad():
        for lora_module_name, mat in tqdm(list(merged_sd.items())):
            conv2d = len(mat.size()) == 4
            kernel_size = None if not conv2d else mat.size()[2:4]
            conv2d_3x3 = conv2d and kernel_size != (1, 1)
            out_dim, in_dim = mat.size()[0:2]

            if conv2d:
                if conv2d_3x3:
                    mat = mat.flatten(start_dim=1)
                else:
                    mat = mat.squeeze()

            module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
            module_new_rank = min(module_new_rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim

            U, S, Vh = torch.linalg.svd(mat)

            U = U[:, :module_new_rank]
            S = S[:module_new_rank]
            U = U @ torch.diag(S)

            Vh = Vh[:module_new_rank, :]

            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
            low_val = -hi_val

            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)

            if conv2d:
                U = U.reshape(out_dim, module_new_rank, 1, 1)
                Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])

            up_weight = U
            down_weight = Vh

            merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
            merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
            merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)

    # build minimum metadata
    dims = f"{new_rank}"
    alphas = f"{new_rank}"
    if new_conv_rank is not None:
        network_args = {"conv_dim": new_conv_rank, "conv_alpha": new_conv_rank}
    else:
        network_args = None
    metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, network_args)

    return merged_lora_sd, metadata, v2 == "True", base_model


def merge(args):
    assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"

    def str_to_dtype(p):
        if p == "float":
            return torch.float
        if p == "fp16":
            return torch.float16
        if p == "bf16":
            return torch.bfloat16
        return None

    merge_dtype = str_to_dtype(args.precision)
    save_dtype = str_to_dtype(args.save_precision)
    if save_dtype is None:
        save_dtype = merge_dtype

    new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
    state_dict, metadata, v2, base_model = merge_lora_models(
        args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
    )

    print(f"calculating hashes and creating metadata...")

    model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
    metadata["sshs_model_hash"] = model_hash
    metadata["sshs_legacy_hash"] = legacy_hash

    if not args.no_metadata:
        is_sdxl = base_model is not None and base_model.lower().startswith("sdxl")
        merged_from = sai_model_spec.build_merged_from(args.models)
        title = os.path.splitext(os.path.basename(args.save_to))[0]
        sai_metadata = sai_model_spec.build_metadata(
            state_dict, v2, v2, is_sdxl, True, False, time.time(), title=title, merged_from=merged_from
        )
        if v2:
            # TODO read sai modelspec
            print(
                "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
            )
        metadata.update(sai_metadata)

    print(f"saving model to: {args.save_to}")
    save_to_file(args.save_to, state_dict, save_dtype, metadata)


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_precision",
        type=str,
        default=None,
        choices=[None, "float", "fp16", "bf16"],
        help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="float",
        choices=["float", "fp16", "bf16"],
        help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
    )
    parser.add_argument(
        "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
    )
    parser.add_argument(
        "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
    )
    parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
    parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
    parser.add_argument(
        "--new_conv_rank",
        type=int,
        default=None,
        help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
    )
    parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
    parser.add_argument(
        "--no_metadata",
        action="store_true",
        help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
        + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
    )

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    merge(args)
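
merge_lora_models above first accumulates full delta weights and then factors them back into a lora_up / lora_down pair by SVD, clamping values at CLAMP_QUANTILE. A minimal sketch of that extraction step on a random matrix (shape and rank are made up for illustration):

import torch

CLAMP_QUANTILE = 0.99
mat = torch.randn(64, 128)  # merged delta weight, out_dim x in_dim
new_rank = 8

U, S, Vh = torch.linalg.svd(mat)
U = U[:, :new_rank] @ torch.diag(S[:new_rank])  # fold singular values into the up matrix
Vh = Vh[:new_rank, :]

hi_val = torch.quantile(torch.cat([U.flatten(), Vh.flatten()]), CLAMP_QUANTILE)
U, Vh = U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)

approx = U @ Vh  # rank-8 reconstruction of the merged weight
print((mat - approx).norm() / mat.norm())  # relative error of the low-rank factorization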
external/llite/tools/cache_latents.py
ADDED
@@ -0,0 +1,194 @@
# cache latents to disk in advance

import argparse
import math
from multiprocessing import Value
import os

from accelerate.utils import set_seed
import torch
from tqdm import tqdm

from library import config_util
from library import train_util
from library import sdxl_train_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)


def cache_to_disk(args: argparse.Namespace) -> None:
    train_util.prepare_dataset_args(args, True)

    # check cache latents arg
    assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"

    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random seed

    # prepare tokenizers: required to run the dataset
    if args.sdxl:
        tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
        tokenizers = [tokenizer1, tokenizer2]
    else:
        tokenizer = train_util.load_tokenizer(args)
        tokenizers = [tokenizer]

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
        if args.dataset_config is not None:
            print(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                print(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                print("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                print("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)

    # if the dataset's cache_latents is not called, raw images are returned

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    # prepare accelerator
    print("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # prepare a dtype matching the mixed precision setting and cast as needed
    weight_dtype, _ = train_util.prepare_dtype(args)
    vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

    # load the model
    print("load model")
    if args.sdxl:
        (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
    else:
        _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)

    if torch.__version__ >= "2.0.0":  # available with xformers builds that support PyTorch 2.0.0 or later
        vae.set_use_memory_efficient_attention_xformers(args.xformers)
    vae.to(accelerator.device, dtype=vae_dtype)
    vae.requires_grad_(False)
    vae.eval()

    # prepare the dataloader
    train_dataset_group.set_caching_mode("latents")

    # number of DataLoader processes: 0 uses the main process
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, but at most the specified number

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # prepare with accelerator: should enable multi-GPU use
    train_dataloader = accelerator.prepare(train_dataloader)

    # loop over the data
    for batch in tqdm(train_dataloader):
        b_size = len(batch["images"])
        vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
        flip_aug = batch["flip_aug"]
        random_crop = batch["random_crop"]
        bucket_reso = batch["bucket_reso"]

        # process the batch in splits
        for i in range(0, b_size, vae_batch_size):
            images = batch["images"][i : i + vae_batch_size]
            absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
            resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]

            image_infos = []
            for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
                image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
                image_info.image = image
                image_info.bucket_reso = bucket_reso
                image_info.resized_size = resized_size
                image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"

                if args.skip_existing:
                    if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
                        print(f"Skipping {image_info.latents_npz} because it already exists.")
                        continue

                image_infos.append(image_info)

            if len(image_infos) > 0:
                train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop)

    accelerator.wait_for_everyone()
    accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_training_arguments(parser, True)
    train_util.add_dataset_arguments(parser, True, True, True)
    config_util.add_config_arguments(parser)
    parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
    parser.add_argument(
        "--no_half_vae",
        action="store_true",
        help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
    )
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    cache_to_disk(args)
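
The heavy lifting above is delegated to train_util.cache_batch_latents, which is not part of this diff. As a rough, hedged sketch of the underlying idea, each image is passed once through the frozen VAE encoder and the latent is stored as an .npz next to the image; the model id, npz key name, and scaling factor below are assumptions for illustration only.

import numpy as np
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval()  # placeholder VAE

image = torch.rand(1, 3, 512, 512) * 2 - 1  # dummy image tensor scaled to [-1, 1]
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample() * 0.18215  # SD v1 latent scaling factor

np.savez("example_image.npz", latents=latents.squeeze(0).numpy())  # assumed npz layout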
external/llite/tools/cache_text_encoder_outputs.py
ADDED
@@ -0,0 +1,191 @@
# cache text encoder outputs to disk in advance

import argparse
import math
from multiprocessing import Value
import os

from accelerate.utils import set_seed
import torch
from tqdm import tqdm

from library import config_util
from library import train_util
from library import sdxl_train_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)


def cache_to_disk(args: argparse.Namespace) -> None:
    train_util.prepare_dataset_args(args, True)

    # check cache arg
    assert (
        args.cache_text_encoder_outputs_to_disk
    ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"

    # prepare as much as possible, but only SDXL works for now
    assert (
        args.sdxl
    ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です"

    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random seed

    # prepare tokenizers: required to run the dataset
    if args.sdxl:
        tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
        tokenizers = [tokenizer1, tokenizer2]
    else:
        tokenizer = train_util.load_tokenizer(args)
        tokenizers = [tokenizer]

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
        if args.dataset_config is not None:
            print(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                print(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                print("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                print("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    # prepare accelerator
    print("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # prepare a dtype matching the mixed precision setting and cast as needed
    weight_dtype, _ = train_util.prepare_dtype(args)

    # load the model
    print("load model")
    if args.sdxl:
        (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
        text_encoders = [text_encoder1, text_encoder2]
    else:
        text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
        text_encoders = [text_encoder1]

    for text_encoder in text_encoders:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()

    # prepare the dataloader
    train_dataset_group.set_caching_mode("text")

    # number of DataLoader processes: 0 uses the main process
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, but at most the specified number

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # prepare with accelerator: should enable multi-GPU use
    train_dataloader = accelerator.prepare(train_dataloader)

    # loop over the data
    for batch in tqdm(train_dataloader):
        absolute_paths = batch["absolute_paths"]
        input_ids1_list = batch["input_ids1_list"]
        input_ids2_list = batch["input_ids2_list"]

        image_infos = []
        for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list):
            image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
            image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX

            if args.skip_existing:
                if os.path.exists(image_info.text_encoder_outputs_npz):
                    print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
                    continue

            image_info.input_ids1 = input_ids1
            image_info.input_ids2 = input_ids2
            image_infos.append(image_info)

        if len(image_infos) > 0:
            b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
            b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos])
            train_util.cache_batch_text_encoder_outputs(
                image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype
            )

    accelerator.wait_for_everyone()
    accelerator.print(f"Finished caching text encoder outputs for {len(train_dataset_group)} batches.")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_training_arguments(parser, True)
    train_util.add_dataset_arguments(parser, True, True, True)
    config_util.add_config_arguments(parser)
    sdxl_train_util.add_sdxl_training_arguments(parser)
    parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    cache_to_disk(args)
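
As with latent caching, the actual file format is defined by train_util.cache_batch_text_encoder_outputs, which is outside this diff. A hedged sketch of the general idea, running a frozen CLIP text encoder once per caption and saving one hidden state to .npz (the model id, chosen hidden-state layer, and npz key are illustrative assumptions):

import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").eval()

tokens = tokenizer("a photo of a cat", padding="max_length", max_length=77, return_tensors="pt")
with torch.no_grad():
    # penultimate hidden state, a common choice for SD-style conditioning (assumption)
    hidden_states = text_encoder(tokens.input_ids, output_hidden_states=True).hidden_states[-2]

np.savez("example_caption_te.npz", hidden_states=hidden_states.squeeze(0).numpy())  # assumed layout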
external/llite/tools/canny.py
ADDED
@@ -0,0 +1,30 @@
import argparse
import cv2


def canny(args):
    img = cv2.imread(args.input)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    canny_img = cv2.Canny(img, args.thres1, args.thres2)
    # canny_img = 255 - canny_img

    cv2.imwrite(args.output, canny_img)
    print("done!")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, default=None, help="input path")
    parser.add_argument("--output", type=str, default=None, help="output path")
    parser.add_argument("--thres1", type=int, default=32, help="thres1")
    parser.add_argument("--thres2", type=int, default=224, help="thres2")

    return parser


if __name__ == '__main__':
    parser = setup_parser()

    args = parser.parse_args()
    canny(args)
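
canny.py is a thin wrapper around OpenCV; the equivalent steps, with placeholder paths and the script's default thresholds:

import cv2

img = cv2.cvtColor(cv2.imread("photo.png"), cv2.COLOR_BGR2GRAY)  # placeholder input path
cv2.imwrite("photo_canny.png", cv2.Canny(img, 32, 224))          # thres1=32, thres2=224, as in the defaults above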
external/llite/tools/convert_diffusers20_original_sd.py
ADDED
@@ -0,0 +1,160 @@
# convert Diffusers v1.x/v2.0 model to original Stable Diffusion

import argparse
import os
import torch
from diffusers import StableDiffusionPipeline

import library.model_util as model_util


def convert(args):
    # check the arguments
    load_dtype = torch.float16 if args.fp16 else None

    save_dtype = None
    if args.fp16 or args.save_precision_as == "fp16":
        save_dtype = torch.float16
    elif args.bf16 or args.save_precision_as == "bf16":
        save_dtype = torch.bfloat16
    elif args.float or args.save_precision_as == "float":
        save_dtype = torch.float

    is_load_ckpt = os.path.isfile(args.model_to_load)
    is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0

    assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
    # assert (
    #     is_save_ckpt or args.reference_model is not None
    # ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"

    # load the model
    msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
    print(f"loading {msg}: {args.model_to_load}")

    if is_load_ckpt:
        v2_model = args.v2
        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
            v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection
        )
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
        )
        text_encoder = pipe.text_encoder
        vae = pipe.vae
        unet = pipe.unet

        if args.v1 == args.v2:
            # auto-detect the model version
            v2_model = unet.config.cross_attention_dim == 1024
            print("checking model version: model is " + ("v2" if v2_model else "v1"))
        else:
            v2_model = not args.v1

    # convert and save
    msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
    print(f"converting and saving as {msg}: {args.model_to_save}")

    if is_save_ckpt:
        original_model = args.model_to_load if is_load_ckpt else None
        key_count = model_util.save_stable_diffusion_checkpoint(
            v2_model,
            args.model_to_save,
            text_encoder,
            unet,
            original_model,
            args.epoch,
            args.global_step,
            None if args.metadata is None else eval(args.metadata),
            save_dtype=save_dtype,
            vae=vae,
        )
        print(f"model saved. total converted state_dict keys: {key_count}")
    else:
        print(
            f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
        )
        model_util.save_diffusers_checkpoint(
            v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
        )
        print("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む"
    )
    parser.add_argument(
        "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
    )
    parser.add_argument(
        "--unet_use_linear_projection",
        action="store_true",
        help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)",
    )
    parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)")
    parser.add_argument(
        "--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)"
    )
    parser.add_argument(
        "--save_precision_as",
        type=str,
        default="no",
        choices=["fp16", "bf16", "float"],
        help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください",
    )
    parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値")
    parser.add_argument(
        "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
    )
    parser.add_argument(
        "--metadata",
        type=str,
        default=None,
        help='metadata written in to the model, given as a Python dictionary. Example: \'{"name": "model_name", "resolution": "512x512"}\' / モデルに保存されるメタデータ、Pythonの辞書形式で指定',
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="variant: Diffusers variant to load. Example: fp16 / 読み込むDiffusersのvariantを指定する、例: fp16",
    )
    parser.add_argument(
        "--reference_model",
        type=str,
        default=None,
        help="reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1` / scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1`",
    )
    parser.add_argument(
        "--use_safetensors",
        action="store_true",
        help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Diffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)",
    )

    parser.add_argument(
        "model_to_load",
        type=str,
        default=None,
        help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ",
    )
    parser.add_argument(
        "model_to_save",
        type=str,
        default=None,
        help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusersモデルとして保存",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    convert(args)
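
The v1/v2 auto-detection above keys off the U-Net's cross-attention width (768 for SD v1.x, 1024 for SD v2.x). A small hedged sketch of the same check, using a placeholder Diffusers model id:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
print("v2" if unet.config.cross_attention_dim == 1024 else "v1")  # prints "v1" for this model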
external/llite/tools/detect_face_rotate.py
ADDED
@@ -0,0 +1,246 @@
# This script is licensed under Apache License 2.0, the same as train_dreambooth.py
# (c) 2022 Kohya S. @kohya_ss

# Detect faces in a landscape image, rotate the image so the face is upright, and crop a square centered on the face

# v2: extract max face if multiple faces are found
# v3: add crop_ratio option
# v4: add multiple faces extraction and min/max size

import argparse
import math
import cv2
import glob
import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np

KP_REYE = 11
KP_LEYE = 19

SCORE_THRES = 0.90


def detect_faces(detector, image, min_size):
    preds = detector(image)  # bgr
    # print(len(preds))

    faces = []
    for pred in preds:
        bb = pred['bbox']
        score = bb[-1]
        if score < SCORE_THRES:
            continue

        left, top, right, bottom = bb[:4]
        cx = int((left + right) / 2)
        cy = int((top + bottom) / 2)
        fw = int(right - left)
        fh = int(bottom - top)

        lex, ley = pred['keypoints'][KP_LEYE, 0:2]
        rex, rey = pred['keypoints'][KP_REYE, 0:2]
        angle = math.atan2(ley - rey, lex - rex)
        angle = angle / math.pi * 180

        faces.append((cx, cy, fw, fh, angle))

    faces.sort(key=lambda x: max(x[2], x[3]), reverse=True)  # largest faces first
    return faces


def rotate_image(image, angle, cx, cy):
    h, w = image.shape[0:2]
    rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)

    # # enlarge the image slightly to account for the rotation -> disabled for now
    # nh = max(h, int(w * math.sin(angle)))
    # nw = max(w, int(h * math.sin(angle)))
    # if nh > h or nw > w:
    #     pad_y = nh - h
    #     pad_t = pad_y // 2
    #     pad_x = nw - w
    #     pad_l = pad_x // 2
    #     m = np.array([[0, 0, pad_l],
    #                   [0, 0, pad_t]])
    #     rot_mat = rot_mat + m
    #     h, w = nh, nw
    #     cx += pad_l
    #     cy += pad_t

    result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    return result, cx, cy


def process(args):
    assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません"
    assert args.crop_ratio is None or args.resize_face_size is None, f"resize_face_size can't be specified when crop_ratio is specified / crop_ratio指定時はresize_face_sizeは指定できません"

    # load the anime face detection model
    print("loading face detector.")
    detector = create_detector('yolov3')

    # parse the crop arguments
    if args.crop_size is None:
        crop_width = crop_height = None
    else:
        tokens = args.crop_size.split(',')
        assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください"
        crop_width, crop_height = [int(t) for t in tokens]

    if args.crop_ratio is None:
        crop_h_ratio = crop_v_ratio = None
    else:
        tokens = args.crop_ratio.split(',')
        assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください"
        crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]

    # process the images
    print("processing.")
    output_extension = ".png"

    os.makedirs(args.dst_dir, exist_ok=True)
    paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \
        glob.glob(os.path.join(args.src_dir, "*.webp"))
    for path in tqdm(paths):
        basename = os.path.splitext(os.path.basename(path))[0]

        # image = cv2.imread(path)  # fails on non-ASCII (e.g. Japanese) file names
        image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED)
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if image.shape[2] == 4:
            print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
            image = image[:, :, :3].copy()  # without copy(), the alpha information apparently stays attached internally

        h, w = image.shape[:2]

        faces = detect_faces(detector, image, args.multiple_faces)
|
120 |
+
for i, face in enumerate(faces):
|
121 |
+
cx, cy, fw, fh, angle = face
|
122 |
+
face_size = max(fw, fh)
|
123 |
+
if args.min_size is not None and face_size < args.min_size:
|
124 |
+
continue
|
125 |
+
if args.max_size is not None and face_size >= args.max_size:
|
126 |
+
continue
|
127 |
+
face_suffix = f"_{i+1:02d}" if args.multiple_faces else ""
|
128 |
+
|
129 |
+
# オプション指定があれば回転する
|
130 |
+
face_img = image
|
131 |
+
if args.rotate:
|
132 |
+
face_img, cx, cy = rotate_image(face_img, angle, cx, cy)
|
133 |
+
|
134 |
+
# オプション指定があれば顔を中心に切り出す
|
135 |
+
if crop_width is not None or crop_h_ratio is not None:
|
136 |
+
cur_crop_width, cur_crop_height = crop_width, crop_height
|
137 |
+
if crop_h_ratio is not None:
|
138 |
+
cur_crop_width = int(face_size * crop_h_ratio + .5)
|
139 |
+
cur_crop_height = int(face_size * crop_v_ratio + .5)
|
140 |
+
|
141 |
+
# リサイズを必要なら行う
|
142 |
+
scale = 1.0
|
143 |
+
if args.resize_face_size is not None:
|
144 |
+
# 顔サイズを基準にリサイズする
|
145 |
+
scale = args.resize_face_size / face_size
|
146 |
+
if scale < cur_crop_width / w:
|
147 |
+
print(
|
148 |
+
f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
149 |
+
scale = cur_crop_width / w
|
150 |
+
if scale < cur_crop_height / h:
|
151 |
+
print(
|
152 |
+
f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
153 |
+
scale = cur_crop_height / h
|
154 |
+
elif crop_h_ratio is not None:
|
155 |
+
# 倍率指定の時にはリサイズしない
|
156 |
+
pass
|
157 |
+
else:
|
158 |
+
# 切り出しサイズ指定あり
|
159 |
+
if w < cur_crop_width:
|
160 |
+
print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
|
161 |
+
scale = cur_crop_width / w
|
162 |
+
if h < cur_crop_height:
|
163 |
+
print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
|
164 |
+
scale = cur_crop_height / h
|
165 |
+
if args.resize_fit:
|
166 |
+
scale = max(cur_crop_width / w, cur_crop_height / h)
|
167 |
+
|
168 |
+
if scale != 1.0:
|
169 |
+
w = int(w * scale + .5)
|
170 |
+
h = int(h * scale + .5)
|
171 |
+
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
|
172 |
+
cx = int(cx * scale + .5)
|
173 |
+
cy = int(cy * scale + .5)
|
174 |
+
fw = int(fw * scale + .5)
|
175 |
+
fh = int(fh * scale + .5)
|
176 |
+
|
177 |
+
cur_crop_width = min(cur_crop_width, face_img.shape[1])
|
178 |
+
cur_crop_height = min(cur_crop_height, face_img.shape[0])
|
179 |
+
|
180 |
+
x = cx - cur_crop_width // 2
|
181 |
+
cx = cur_crop_width // 2
|
182 |
+
if x < 0:
|
183 |
+
cx = cx + x
|
184 |
+
x = 0
|
185 |
+
elif x + cur_crop_width > w:
|
186 |
+
cx = cx + (x + cur_crop_width - w)
|
187 |
+
x = w - cur_crop_width
|
188 |
+
face_img = face_img[:, x:x+cur_crop_width]
|
189 |
+
|
190 |
+
y = cy - cur_crop_height // 2
|
191 |
+
cy = cur_crop_height // 2
|
192 |
+
if y < 0:
|
193 |
+
cy = cy + y
|
194 |
+
y = 0
|
195 |
+
elif y + cur_crop_height > h:
|
196 |
+
cy = cy + (y + cur_crop_height - h)
|
197 |
+
y = h - cur_crop_height
|
198 |
+
face_img = face_img[y:y + cur_crop_height]
|
199 |
+
|
200 |
+
# # debug
|
201 |
+
# print(path, cx, cy, angle)
|
202 |
+
# crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
|
203 |
+
# cv2.imshow("image", crp)
|
204 |
+
# if cv2.waitKey() == 27:
|
205 |
+
# break
|
206 |
+
# cv2.destroyAllWindows()
|
207 |
+
|
208 |
+
# debug
|
209 |
+
if args.debug:
|
210 |
+
cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20)
|
211 |
+
|
212 |
+
_, buf = cv2.imencode(output_extension, face_img)
|
213 |
+
with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f:
|
214 |
+
buf.tofile(f)
|
215 |
+
|
216 |
+
|
217 |
+
def setup_parser() -> argparse.ArgumentParser:
|
218 |
+
parser = argparse.ArgumentParser()
|
219 |
+
parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
|
220 |
+
parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
|
221 |
+
parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する")
|
222 |
+
parser.add_argument("--resize_fit", action="store_true",
|
223 |
+
help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする")
|
224 |
+
parser.add_argument("--resize_face_size", type=int, default=None,
|
225 |
+
help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする")
|
226 |
+
parser.add_argument("--crop_size", type=str, default=None,
|
227 |
+
help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す")
|
228 |
+
parser.add_argument("--crop_ratio", type=str, default=None,
|
229 |
+
help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す")
|
230 |
+
parser.add_argument("--min_size", type=int, default=None,
|
231 |
+
help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)")
|
232 |
+
parser.add_argument("--max_size", type=int, default=None,
|
233 |
+
help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)")
|
234 |
+
parser.add_argument("--multiple_faces", action="store_true",
|
235 |
+
help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
|
236 |
+
parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
|
237 |
+
|
238 |
+
return parser
|
239 |
+
|
240 |
+
|
241 |
+
if __name__ == '__main__':
|
242 |
+
parser = setup_parser()
|
243 |
+
|
244 |
+
args = parser.parse_args()
|
245 |
+
|
246 |
+
process(args)
|
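The rotation in detect_face_rotate.py is driven by the roll angle computed from the two eye keypoints in detect_faces; a self-contained sketch of that step, with made-up keypoint coordinates, is shown below.

# Minimal sketch of the eye-keypoint angle computation used in detect_faces above.
# The keypoint coordinates are made up for illustration.
import math

lex, ley = 220.0, 140.0  # left eye (x, y), hypothetical
rex, rey = 160.0, 150.0  # right eye (x, y), hypothetical

angle = math.atan2(ley - rey, lex - rex)  # radians; y grows downward in image coordinates
angle = angle / math.pi * 180             # degrees, later passed to cv2.getRotationMatrix2D
print(f"roll angle: {angle:.1f} deg")     # about -9.5 for these coordinates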
external/llite/tools/latent_upscaler.py
ADDED
@@ -0,0 +1,348 @@
1 |
+
# 外部から簡単にupscalerを呼ぶためのスクリプト
|
2 |
+
# 単体で動くようにモデル定義も含めている
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import glob
|
6 |
+
import os
|
7 |
+
import cv2
|
8 |
+
from diffusers import AutoencoderKL
|
9 |
+
|
10 |
+
from typing import Dict, List
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
from tqdm import tqdm
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
|
19 |
+
class ResidualBlock(nn.Module):
|
20 |
+
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
|
21 |
+
super(ResidualBlock, self).__init__()
|
22 |
+
|
23 |
+
if out_channels is None:
|
24 |
+
out_channels = in_channels
|
25 |
+
|
26 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
|
27 |
+
self.bn1 = nn.BatchNorm2d(out_channels)
|
28 |
+
self.relu1 = nn.ReLU(inplace=True)
|
29 |
+
|
30 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
|
31 |
+
self.bn2 = nn.BatchNorm2d(out_channels)
|
32 |
+
|
33 |
+
self.relu2 = nn.ReLU(inplace=True) # このReLUはresidualに足す前にかけるほうがいいかも
|
34 |
+
|
35 |
+
# initialize weights
|
36 |
+
self._initialize_weights()
|
37 |
+
|
38 |
+
def _initialize_weights(self):
|
39 |
+
for m in self.modules():
|
40 |
+
if isinstance(m, nn.Conv2d):
|
41 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
42 |
+
if m.bias is not None:
|
43 |
+
nn.init.constant_(m.bias, 0)
|
44 |
+
elif isinstance(m, nn.BatchNorm2d):
|
45 |
+
nn.init.constant_(m.weight, 1)
|
46 |
+
nn.init.constant_(m.bias, 0)
|
47 |
+
elif isinstance(m, nn.Linear):
|
48 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
49 |
+
nn.init.constant_(m.bias, 0)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
residual = x
|
53 |
+
|
54 |
+
out = self.conv1(x)
|
55 |
+
out = self.bn1(out)
|
56 |
+
out = self.relu1(out)
|
57 |
+
|
58 |
+
out = self.conv2(out)
|
59 |
+
out = self.bn2(out)
|
60 |
+
|
61 |
+
out += residual
|
62 |
+
|
63 |
+
out = self.relu2(out)
|
64 |
+
|
65 |
+
return out
|
66 |
+
|
67 |
+
|
68 |
+
class Upscaler(nn.Module):
|
69 |
+
def __init__(self):
|
70 |
+
super(Upscaler, self).__init__()
|
71 |
+
|
72 |
+
# define layers
|
73 |
+
# latent has 4 channels
|
74 |
+
|
75 |
+
self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
76 |
+
self.bn1 = nn.BatchNorm2d(128)
|
77 |
+
self.relu1 = nn.ReLU(inplace=True)
|
78 |
+
|
79 |
+
# resblocks
|
80 |
+
# 数の暴力で20個:次元数を増やすよりもブロックを増やしたほうがreceptive fieldが広がるはずだぞ
|
81 |
+
self.resblock1 = ResidualBlock(128)
|
82 |
+
self.resblock2 = ResidualBlock(128)
|
83 |
+
self.resblock3 = ResidualBlock(128)
|
84 |
+
self.resblock4 = ResidualBlock(128)
|
85 |
+
self.resblock5 = ResidualBlock(128)
|
86 |
+
self.resblock6 = ResidualBlock(128)
|
87 |
+
self.resblock7 = ResidualBlock(128)
|
88 |
+
self.resblock8 = ResidualBlock(128)
|
89 |
+
self.resblock9 = ResidualBlock(128)
|
90 |
+
self.resblock10 = ResidualBlock(128)
|
91 |
+
self.resblock11 = ResidualBlock(128)
|
92 |
+
self.resblock12 = ResidualBlock(128)
|
93 |
+
self.resblock13 = ResidualBlock(128)
|
94 |
+
self.resblock14 = ResidualBlock(128)
|
95 |
+
self.resblock15 = ResidualBlock(128)
|
96 |
+
self.resblock16 = ResidualBlock(128)
|
97 |
+
self.resblock17 = ResidualBlock(128)
|
98 |
+
self.resblock18 = ResidualBlock(128)
|
99 |
+
self.resblock19 = ResidualBlock(128)
|
100 |
+
self.resblock20 = ResidualBlock(128)
|
101 |
+
|
102 |
+
# last convs
|
103 |
+
self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
104 |
+
self.bn2 = nn.BatchNorm2d(64)
|
105 |
+
self.relu2 = nn.ReLU(inplace=True)
|
106 |
+
|
107 |
+
self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
108 |
+
self.bn3 = nn.BatchNorm2d(64)
|
109 |
+
self.relu3 = nn.ReLU(inplace=True)
|
110 |
+
|
111 |
+
# final conv: output 4 channels
|
112 |
+
self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
113 |
+
|
114 |
+
# initialize weights
|
115 |
+
self._initialize_weights()
|
116 |
+
|
117 |
+
def _initialize_weights(self):
|
118 |
+
for m in self.modules():
|
119 |
+
if isinstance(m, nn.Conv2d):
|
120 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
121 |
+
if m.bias is not None:
|
122 |
+
nn.init.constant_(m.bias, 0)
|
123 |
+
elif isinstance(m, nn.BatchNorm2d):
|
124 |
+
nn.init.constant_(m.weight, 1)
|
125 |
+
nn.init.constant_(m.bias, 0)
|
126 |
+
elif isinstance(m, nn.Linear):
|
127 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
128 |
+
nn.init.constant_(m.bias, 0)
|
129 |
+
|
130 |
+
# initialize final conv weights to 0: 流行りのzero conv
|
131 |
+
nn.init.constant_(self.conv_final.weight, 0)
|
132 |
+
|
133 |
+
def forward(self, x):
|
134 |
+
inp = x
|
135 |
+
|
136 |
+
x = self.conv1(x)
|
137 |
+
x = self.bn1(x)
|
138 |
+
x = self.relu1(x)
|
139 |
+
|
140 |
+
# いくつかのresblockを通した後に、residualを足すことで精度向上と学習速度向上が見込めるはず
|
141 |
+
residual = x
|
142 |
+
x = self.resblock1(x)
|
143 |
+
x = self.resblock2(x)
|
144 |
+
x = self.resblock3(x)
|
145 |
+
x = self.resblock4(x)
|
146 |
+
x = x + residual
|
147 |
+
residual = x
|
148 |
+
x = self.resblock5(x)
|
149 |
+
x = self.resblock6(x)
|
150 |
+
x = self.resblock7(x)
|
151 |
+
x = self.resblock8(x)
|
152 |
+
x = x + residual
|
153 |
+
residual = x
|
154 |
+
x = self.resblock9(x)
|
155 |
+
x = self.resblock10(x)
|
156 |
+
x = self.resblock11(x)
|
157 |
+
x = self.resblock12(x)
|
158 |
+
x = x + residual
|
159 |
+
residual = x
|
160 |
+
x = self.resblock13(x)
|
161 |
+
x = self.resblock14(x)
|
162 |
+
x = self.resblock15(x)
|
163 |
+
x = self.resblock16(x)
|
164 |
+
x = x + residual
|
165 |
+
residual = x
|
166 |
+
x = self.resblock17(x)
|
167 |
+
x = self.resblock18(x)
|
168 |
+
x = self.resblock19(x)
|
169 |
+
x = self.resblock20(x)
|
170 |
+
x = x + residual
|
171 |
+
|
172 |
+
x = self.conv2(x)
|
173 |
+
x = self.bn2(x)
|
174 |
+
x = self.relu2(x)
|
175 |
+
x = self.conv3(x)
|
176 |
+
x = self.bn3(x)
|
177 |
+
|
178 |
+
# ここにreluを入れないほうがいい気がする
|
179 |
+
|
180 |
+
x = self.conv_final(x)
|
181 |
+
|
182 |
+
# network estimates the difference between the input and the output
|
183 |
+
x = x + inp
|
184 |
+
|
185 |
+
return x
|
186 |
+
|
187 |
+
def support_latents(self) -> bool:
|
188 |
+
return False
|
189 |
+
|
190 |
+
def upscale(
|
191 |
+
self,
|
192 |
+
vae: AutoencoderKL,
|
193 |
+
lowreso_images: List[Image.Image],
|
194 |
+
lowreso_latents: torch.Tensor,
|
195 |
+
dtype: torch.dtype,
|
196 |
+
width: int,
|
197 |
+
height: int,
|
198 |
+
batch_size: int = 1,
|
199 |
+
vae_batch_size: int = 1,
|
200 |
+
):
|
201 |
+
# assertion
|
202 |
+
assert lowreso_images is not None, "Upscaler requires lowreso image"
|
203 |
+
|
204 |
+
# make upsampled image with lanczos4
|
205 |
+
upsampled_images = []
|
206 |
+
for lowreso_image in lowreso_images:
|
207 |
+
upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
|
208 |
+
upsampled_images.append(upsampled_image)
|
209 |
+
|
210 |
+
# convert to tensor: this tensor is too large to be converted to cuda
|
211 |
+
upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
|
212 |
+
upsampled_images = torch.stack(upsampled_images, dim=0)
|
213 |
+
upsampled_images = upsampled_images.to(dtype)
|
214 |
+
|
215 |
+
# normalize to [-1, 1]
|
216 |
+
upsampled_images = upsampled_images / 127.5 - 1.0
|
217 |
+
|
218 |
+
# convert upsample images to latents with batch size
|
219 |
+
# print("Encoding upsampled (LANCZOS4) images...")
|
220 |
+
upsampled_latents = []
|
221 |
+
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
|
222 |
+
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
|
223 |
+
with torch.no_grad():
|
224 |
+
batch = vae.encode(batch).latent_dist.sample()
|
225 |
+
upsampled_latents.append(batch)
|
226 |
+
|
227 |
+
upsampled_latents = torch.cat(upsampled_latents, dim=0)
|
228 |
+
|
229 |
+
# upscale (refine) latents with this model with batch size
|
230 |
+
print("Upscaling latents...")
|
231 |
+
upscaled_latents = []
|
232 |
+
for i in range(0, upsampled_latents.shape[0], batch_size):
|
233 |
+
with torch.no_grad():
|
234 |
+
upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
|
235 |
+
upscaled_latents = torch.cat(upscaled_latents, dim=0)
|
236 |
+
|
237 |
+
return upscaled_latents * 0.18215
|
238 |
+
|
239 |
+
|
240 |
+
# external interface: returns a model
|
241 |
+
def create_upscaler(**kwargs):
|
242 |
+
weights = kwargs["weights"]
|
243 |
+
model = Upscaler()
|
244 |
+
|
245 |
+
print(f"Loading weights from {weights}...")
|
246 |
+
if os.path.splitext(weights)[1] == ".safetensors":
|
247 |
+
from safetensors.torch import load_file
|
248 |
+
|
249 |
+
sd = load_file(weights)
|
250 |
+
else:
|
251 |
+
sd = torch.load(weights, map_location=torch.device("cpu"))
|
252 |
+
model.load_state_dict(sd)
|
253 |
+
return model
|
254 |
+
|
255 |
+
|
256 |
+
# another interface: upscale images with a model for given images from command line
|
257 |
+
def upscale_images(args: argparse.Namespace):
|
258 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
259 |
+
us_dtype = torch.float16 # TODO: support fp32/bf16
|
260 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
261 |
+
|
262 |
+
# load VAE with Diffusers
|
263 |
+
assert args.vae_path is not None, "VAE path is required"
|
264 |
+
print(f"Loading VAE from {args.vae_path}...")
|
265 |
+
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
|
266 |
+
vae.to(DEVICE, dtype=us_dtype)
|
267 |
+
|
268 |
+
# prepare model
|
269 |
+
print("Preparing model...")
|
270 |
+
upscaler: Upscaler = create_upscaler(weights=args.weights)
|
271 |
+
# print("Loading weights from", args.weights)
|
272 |
+
# upscaler.load_state_dict(torch.load(args.weights))
|
273 |
+
upscaler.eval()
|
274 |
+
upscaler.to(DEVICE, dtype=us_dtype)
|
275 |
+
|
276 |
+
# load images
|
277 |
+
image_paths = glob.glob(args.image_pattern)
|
278 |
+
images = []
|
279 |
+
for image_path in image_paths:
|
280 |
+
image = Image.open(image_path)
|
281 |
+
image = image.convert("RGB")
|
282 |
+
|
283 |
+
# make divisible by 8
|
284 |
+
width = image.width
|
285 |
+
height = image.height
|
286 |
+
if width % 8 != 0:
|
287 |
+
width = width - (width % 8)
|
288 |
+
if height % 8 != 0:
|
289 |
+
height = height - (height % 8)
|
290 |
+
if width != image.width or height != image.height:
|
291 |
+
image = image.crop((0, 0, width, height))
|
292 |
+
|
293 |
+
images.append(image)
|
294 |
+
|
295 |
+
# debug output
|
296 |
+
if args.debug:
|
297 |
+
for image, image_path in zip(images, image_paths):
|
298 |
+
image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
|
299 |
+
|
300 |
+
basename = os.path.basename(image_path)
|
301 |
+
basename_wo_ext, ext = os.path.splitext(basename)
|
302 |
+
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
|
303 |
+
image_debug.save(dest_file_name)
|
304 |
+
|
305 |
+
# upscale
|
306 |
+
print("Upscaling...")
|
307 |
+
upscaled_latents = upscaler.upscale(
|
308 |
+
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
|
309 |
+
)
|
310 |
+
upscaled_latents /= 0.18215
|
311 |
+
|
312 |
+
# decode with batch
|
313 |
+
print("Decoding...")
|
314 |
+
upscaled_images = []
|
315 |
+
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
|
316 |
+
with torch.no_grad():
|
317 |
+
batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
|
318 |
+
batch = batch.to("cpu")
|
319 |
+
upscaled_images.append(batch)
|
320 |
+
upscaled_images = torch.cat(upscaled_images, dim=0)
|
321 |
+
|
322 |
+
# tensor to numpy
|
323 |
+
upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
|
324 |
+
upscaled_images = (upscaled_images + 1.0) * 127.5
|
325 |
+
upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)
|
326 |
+
|
327 |
+
upscaled_images = upscaled_images[..., ::-1]
|
328 |
+
|
329 |
+
# save images
|
330 |
+
for i, image in enumerate(upscaled_images):
|
331 |
+
basename = os.path.basename(image_paths[i])
|
332 |
+
basename_wo_ext, ext = os.path.splitext(basename)
|
333 |
+
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
|
334 |
+
cv2.imwrite(dest_file_name, image)
|
335 |
+
|
336 |
+
|
337 |
+
if __name__ == "__main__":
|
338 |
+
parser = argparse.ArgumentParser()
|
339 |
+
parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
|
340 |
+
parser.add_argument("--weights", type=str, default=None, help="Weights path")
|
341 |
+
parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
|
342 |
+
parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
|
343 |
+
parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
|
344 |
+
parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
|
345 |
+
parser.add_argument("--debug", action="store_true", help="Debug mode")
|
346 |
+
|
347 |
+
args = parser.parse_args()
|
348 |
+
upscale_images(args)
|
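A condensed sketch of how create_upscaler and Upscaler.upscale are meant to be combined, mirroring upscale_images above; the VAE repo id, the weights file, and the input image path are placeholders.

# Minimal sketch of driving the latent upscaler defined above (paths and repo id are placeholders).
import torch
from PIL import Image
from diffusers import AutoencoderKL

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device, dtype=dtype)
upscaler = create_upscaler(weights="upscaler.safetensors")  # hypothetical weights file
upscaler.eval()
upscaler.to(device, dtype=dtype)

image = Image.open("input.png").convert("RGB")              # hypothetical input, sides divisible by 8
latents = upscaler.upscale(vae, [image], None, dtype, image.width * 2, image.height * 2)
decoded = vae.decode(latents / 0.18215).sample              # back to pixel space, as in upscale_images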
external/llite/tools/merge_models.py
ADDED
@@ -0,0 +1,168 @@
import argparse
import os

import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm


def is_unet_key(key):
    # VAE or TextEncoder, the last one is for SDXL
    return not ("first_stage_model" in key or "cond_stage_model" in key or "conditioner." in key)


TEXT_ENCODER_KEY_REPLACEMENTS = [
    ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
    ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
    ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
]


# support for models with different text encoder keys
def replace_text_encoder_key(key):
    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
        if key.startswith(rep_from):
            return True, rep_to + key[len(rep_from) :]
    return False, key


def merge(args):
    if args.precision == "fp16":
        dtype = torch.float16
    elif args.precision == "bf16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float

    if args.saving_precision == "fp16":
        save_dtype = torch.float16
    elif args.saving_precision == "bf16":
        save_dtype = torch.bfloat16
    else:
        save_dtype = torch.float

    # check if all models are safetensors
    for model in args.models:
        if not model.endswith("safetensors"):
            print(f"Model {model} is not a safetensors model")
            exit()
        if not os.path.isfile(model):
            print(f"Model {model} does not exist")
            exit()

    assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"

    # load and merge
    ratio = 1.0 / len(args.models)  # default
    supplementary_key_ratios = {}  # [key] = ratio, for keys not in all models, add later

    merged_sd = None
    first_model_keys = set()  # check missing keys in other models
    for i, model in enumerate(args.models):
        if args.ratios is not None:
            ratio = args.ratios[i]

        if merged_sd is None:
            # load first model
            print(f"Loading model {model}, ratio = {ratio}...")
            merged_sd = {}
            with safe_open(model, framework="pt", device=args.device) as f:
                for key in tqdm(f.keys()):
                    value = f.get_tensor(key)
                    _, key = replace_text_encoder_key(key)

                    first_model_keys.add(key)

                    if not is_unet_key(key) and args.unet_only:
                        supplementary_key_ratios[key] = 1.0  # use first model's value for VAE or TextEncoder
                        continue

                    value = ratio * value.to(dtype)  # first model's value * ratio
                    merged_sd[key] = value

            print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
            continue

        # load other models
        print(f"Loading model {model}, ratio = {ratio}...")

        with safe_open(model, framework="pt", device=args.device) as f:
            model_keys = f.keys()
            for key in tqdm(model_keys):
                _, new_key = replace_text_encoder_key(key)
                if new_key not in merged_sd:
                    if args.show_skipped and new_key not in first_model_keys:
                        print(f"Skip: {new_key}")
                    continue

                value = f.get_tensor(key)
                merged_sd[new_key] = merged_sd[new_key] + ratio * value.to(dtype)

            # enumerate keys not in this model
            model_keys = set(model_keys)
            for key in merged_sd.keys():
                if key in model_keys:
                    continue
                print(f"Key {key} not in model {model}, use first model's value")
                if key in supplementary_key_ratios:
                    supplementary_key_ratios[key] += ratio
                else:
                    supplementary_key_ratios[key] = ratio

    # add supplementary keys' value (including VAE and TextEncoder)
    if len(supplementary_key_ratios) > 0:
        print("add first model's value")
        with safe_open(args.models[0], framework="pt", device=args.device) as f:
            for key in tqdm(f.keys()):
                _, new_key = replace_text_encoder_key(key)
                if new_key not in supplementary_key_ratios:
                    continue

                if is_unet_key(new_key):  # not VAE or TextEncoder
                    print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")

                value = f.get_tensor(key)  # original key

                if new_key not in merged_sd:
                    merged_sd[new_key] = supplementary_key_ratios[new_key] * value.to(dtype)
                else:
                    merged_sd[new_key] = merged_sd[new_key] + supplementary_key_ratios[new_key] * value.to(dtype)

    # save
    output_file = args.output
    if not output_file.endswith(".safetensors"):
        output_file = output_file + ".safetensors"

    print(f"Saving to {output_file}...")

    # convert to save_dtype
    for k in merged_sd.keys():
        merged_sd[k] = merged_sd[k].to(save_dtype)

    save_file(merged_sd, output_file)

    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Merge models")
    parser.add_argument("--models", nargs="+", type=str, help="Models to merge")
    parser.add_argument("--output", type=str, help="Output model")
    parser.add_argument("--ratios", nargs="+", type=float, help="Ratios of models, default is equal, total = 1.0")
    parser.add_argument("--unet_only", action="store_true", help="Only merge unet")
    parser.add_argument("--device", type=str, default="cpu", help="Device to use, default is cpu")
    parser.add_argument(
        "--precision", type=str, default="float", choices=["float", "fp16", "bf16"], help="Calculation precision, default is float"
    )
    parser.add_argument(
        "--saving_precision",
        type=str,
        default="float",
        choices=["float", "fp16", "bf16"],
        help="Saving precision, default is float",
    )
    parser.add_argument("--show_skipped", action="store_true", help="Show skipped keys (keys not in first model)")

    args = parser.parse_args()
    merge(args)
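The heart of merge_models.py is a per-key weighted sum of state dicts; the toy sketch below reproduces that arithmetic outside the script (the key, tensors, and ratios are made up).

# Minimal sketch of the weighted average merge() performs per key (toy values).
import torch

sd_a = {"model.diffusion_model.w": torch.tensor([1.0, 2.0])}
sd_b = {"model.diffusion_model.w": torch.tensor([3.0, 6.0])}
ratios = [0.25, 0.75]  # as with --ratios, these should sum to 1.0

merged = {key: ratios[0] * sd_a[key] + ratios[1] * sd_b[key] for key in sd_a}
print(merged["model.diffusion_model.w"])  # tensor([2.5000, 5.0000])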
external/llite/tools/original_control_net.py
ADDED
@@ -0,0 +1,337 @@
1 |
+
from typing import List, NamedTuple, Any
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
from safetensors.torch import load_file
|
6 |
+
|
7 |
+
from library.original_unet import UNet2DConditionModel, SampleOutput
|
8 |
+
|
9 |
+
import library.model_util as model_util
|
10 |
+
|
11 |
+
|
12 |
+
class ControlNetInfo(NamedTuple):
|
13 |
+
unet: Any
|
14 |
+
net: Any
|
15 |
+
prep: Any
|
16 |
+
weight: float
|
17 |
+
ratio: float
|
18 |
+
|
19 |
+
|
20 |
+
class ControlNet(torch.nn.Module):
|
21 |
+
def __init__(self) -> None:
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
# make control model
|
25 |
+
self.control_model = torch.nn.Module()
|
26 |
+
|
27 |
+
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
|
28 |
+
zero_convs = torch.nn.ModuleList()
|
29 |
+
for i, dim in enumerate(dims):
|
30 |
+
sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
|
31 |
+
zero_convs.append(sub_list)
|
32 |
+
self.control_model.add_module("zero_convs", zero_convs)
|
33 |
+
|
34 |
+
middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
|
35 |
+
self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
|
36 |
+
|
37 |
+
dims = [16, 16, 32, 32, 96, 96, 256, 320]
|
38 |
+
strides = [1, 1, 2, 1, 2, 1, 2, 1]
|
39 |
+
prev_dim = 3
|
40 |
+
input_hint_block = torch.nn.Sequential()
|
41 |
+
for i, (dim, stride) in enumerate(zip(dims, strides)):
|
42 |
+
input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
|
43 |
+
if i < len(dims) - 1:
|
44 |
+
input_hint_block.append(torch.nn.SiLU())
|
45 |
+
prev_dim = dim
|
46 |
+
self.control_model.add_module("input_hint_block", input_hint_block)
|
47 |
+
|
48 |
+
|
49 |
+
def load_control_net(v2, unet, model):
|
50 |
+
device = unet.device
|
51 |
+
|
52 |
+
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
|
53 |
+
# state dictを読み込む
|
54 |
+
print(f"ControlNet: loading control SD model : {model}")
|
55 |
+
|
56 |
+
if model_util.is_safetensors(model):
|
57 |
+
ctrl_sd_sd = load_file(model)
|
58 |
+
else:
|
59 |
+
ctrl_sd_sd = torch.load(model, map_location="cpu")
|
60 |
+
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
|
61 |
+
|
62 |
+
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
63 |
+
is_difference = "difference" in ctrl_sd_sd
|
64 |
+
print("ControlNet: loading difference:", is_difference)
|
65 |
+
|
66 |
+
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
|
67 |
+
# またTransfer Controlの元weightとなる
|
68 |
+
ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
69 |
+
|
70 |
+
# 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
|
71 |
+
for key in list(ctrl_unet_sd_sd.keys()):
|
72 |
+
ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
|
73 |
+
|
74 |
+
zero_conv_sd = {}
|
75 |
+
for key in list(ctrl_sd_sd.keys()):
|
76 |
+
if key.startswith("control_"):
|
77 |
+
unet_key = "model.diffusion_" + key[len("control_") :]
|
78 |
+
if unet_key not in ctrl_unet_sd_sd: # zero conv
|
79 |
+
zero_conv_sd[key] = ctrl_sd_sd[key]
|
80 |
+
continue
|
81 |
+
if is_difference: # Transfer Control
|
82 |
+
ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
|
83 |
+
else:
|
84 |
+
ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
|
85 |
+
|
86 |
+
unet_config = model_util.create_unet_diffusers_config(v2)
|
87 |
+
ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict
|
88 |
+
|
89 |
+
# ControlNetのU-Netを作成する
|
90 |
+
ctrl_unet = UNet2DConditionModel(**unet_config)
|
91 |
+
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
|
92 |
+
print("ControlNet: loading Control U-Net:", info)
|
93 |
+
|
94 |
+
# U-Net以外のControlNetを作成する
|
95 |
+
# TODO support middle only
|
96 |
+
ctrl_net = ControlNet()
|
97 |
+
info = ctrl_net.load_state_dict(zero_conv_sd)
|
98 |
+
print("ControlNet: loading ControlNet:", info)
|
99 |
+
|
100 |
+
ctrl_unet.to(unet.device, dtype=unet.dtype)
|
101 |
+
ctrl_net.to(unet.device, dtype=unet.dtype)
|
102 |
+
return ctrl_unet, ctrl_net
|
103 |
+
|
104 |
+
|
105 |
+
def load_preprocess(prep_type: str):
|
106 |
+
if prep_type is None or prep_type.lower() == "none":
|
107 |
+
return None
|
108 |
+
|
109 |
+
if prep_type.startswith("canny"):
|
110 |
+
args = prep_type.split("_")
|
111 |
+
th1 = int(args[1]) if len(args) >= 2 else 63
|
112 |
+
th2 = int(args[2]) if len(args) >= 3 else 191
|
113 |
+
|
114 |
+
def canny(img):
|
115 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
116 |
+
return cv2.Canny(img, th1, th2)
|
117 |
+
|
118 |
+
return canny
|
119 |
+
|
120 |
+
print("Unsupported prep type:", prep_type)
|
121 |
+
return None
|
122 |
+
|
123 |
+
|
124 |
+
def preprocess_ctrl_net_hint_image(image):
|
125 |
+
image = np.array(image).astype(np.float32) / 255.0
|
126 |
+
# ControlNetのサンプルはcv2を使っているが、読み込みはGradioなので実はRGBになっている
|
127 |
+
# image = image[:, :, ::-1].copy() # rgb to bgr
|
128 |
+
image = image[None].transpose(0, 3, 1, 2) # nchw
|
129 |
+
image = torch.from_numpy(image)
|
130 |
+
return image # 0 to 1
|
131 |
+
|
132 |
+
|
133 |
+
def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
|
134 |
+
guided_hints = []
|
135 |
+
for i, cnet_info in enumerate(control_nets):
|
136 |
+
# hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること
|
137 |
+
b_hints = []
|
138 |
+
if len(hints) == 1: # すべて同じ画像をhintとして使う
|
139 |
+
hint = hints[0]
|
140 |
+
if cnet_info.prep is not None:
|
141 |
+
hint = cnet_info.prep(hint)
|
142 |
+
hint = preprocess_ctrl_net_hint_image(hint)
|
143 |
+
b_hints = [hint for _ in range(b_size)]
|
144 |
+
else:
|
145 |
+
for bi in range(b_size):
|
146 |
+
hint = hints[(bi * len(control_nets) + i) % len(hints)]
|
147 |
+
if cnet_info.prep is not None:
|
148 |
+
hint = cnet_info.prep(hint)
|
149 |
+
hint = preprocess_ctrl_net_hint_image(hint)
|
150 |
+
b_hints.append(hint)
|
151 |
+
b_hints = torch.cat(b_hints, dim=0)
|
152 |
+
b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
|
153 |
+
|
154 |
+
guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
|
155 |
+
guided_hints.append(guided_hint)
|
156 |
+
return guided_hints
|
157 |
+
|
158 |
+
|
159 |
+
def call_unet_and_control_net(
|
160 |
+
step,
|
161 |
+
num_latent_input,
|
162 |
+
original_unet,
|
163 |
+
control_nets: List[ControlNetInfo],
|
164 |
+
guided_hints,
|
165 |
+
current_ratio,
|
166 |
+
sample,
|
167 |
+
timestep,
|
168 |
+
encoder_hidden_states,
|
169 |
+
encoder_hidden_states_for_control_net,
|
170 |
+
):
|
171 |
+
# ControlNet
|
172 |
+
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
|
173 |
+
cnet_cnt = len(control_nets)
|
174 |
+
cnet_idx = step % cnet_cnt
|
175 |
+
cnet_info = control_nets[cnet_idx]
|
176 |
+
|
177 |
+
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
178 |
+
if cnet_info.ratio < current_ratio:
|
179 |
+
return original_unet(sample, timestep, encoder_hidden_states)
|
180 |
+
|
181 |
+
guided_hint = guided_hints[cnet_idx]
|
182 |
+
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
|
183 |
+
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net)
|
184 |
+
outs = [o * cnet_info.weight for o in outs]
|
185 |
+
|
186 |
+
# U-Net
|
187 |
+
return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
|
188 |
+
|
189 |
+
|
190 |
+
"""
|
191 |
+
# これはmergeのバージョン
|
192 |
+
# ControlNet
|
193 |
+
cnet_outs_list = []
|
194 |
+
for i, cnet_info in enumerate(control_nets):
|
195 |
+
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
196 |
+
if cnet_info.ratio < current_ratio:
|
197 |
+
continue
|
198 |
+
guided_hint = guided_hints[i]
|
199 |
+
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
|
200 |
+
for i in range(len(outs)):
|
201 |
+
outs[i] *= cnet_info.weight
|
202 |
+
|
203 |
+
cnet_outs_list.append(outs)
|
204 |
+
|
205 |
+
count = len(cnet_outs_list)
|
206 |
+
if count == 0:
|
207 |
+
return original_unet(sample, timestep, encoder_hidden_states)
|
208 |
+
|
209 |
+
# sum of controlnets
|
210 |
+
for i in range(1, count):
|
211 |
+
cnet_outs_list[0] += cnet_outs_list[i]
|
212 |
+
|
213 |
+
# U-Net
|
214 |
+
return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states)
|
215 |
+
"""
|
216 |
+
|
217 |
+
|
218 |
+
def unet_forward(
|
219 |
+
is_control_net,
|
220 |
+
control_net: ControlNet,
|
221 |
+
unet: UNet2DConditionModel,
|
222 |
+
guided_hint,
|
223 |
+
ctrl_outs,
|
224 |
+
sample,
|
225 |
+
timestep,
|
226 |
+
encoder_hidden_states,
|
227 |
+
):
|
228 |
+
# copy from UNet2DConditionModel
|
229 |
+
default_overall_up_factor = 2**unet.num_upsamplers
|
230 |
+
|
231 |
+
forward_upsample_size = False
|
232 |
+
upsample_size = None
|
233 |
+
|
234 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
235 |
+
print("Forward upsample size to force interpolation output size.")
|
236 |
+
forward_upsample_size = True
|
237 |
+
|
238 |
+
# 1. time
|
239 |
+
timesteps = timestep
|
240 |
+
if not torch.is_tensor(timesteps):
|
241 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
242 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
243 |
+
is_mps = sample.device.type == "mps"
|
244 |
+
if isinstance(timestep, float):
|
245 |
+
dtype = torch.float32 if is_mps else torch.float64
|
246 |
+
else:
|
247 |
+
dtype = torch.int32 if is_mps else torch.int64
|
248 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
249 |
+
elif len(timesteps.shape) == 0:
|
250 |
+
timesteps = timesteps[None].to(sample.device)
|
251 |
+
|
252 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
253 |
+
timesteps = timesteps.expand(sample.shape[0])
|
254 |
+
|
255 |
+
t_emb = unet.time_proj(timesteps)
|
256 |
+
|
257 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
258 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
259 |
+
# there might be better ways to encapsulate this.
|
260 |
+
t_emb = t_emb.to(dtype=unet.dtype)
|
261 |
+
emb = unet.time_embedding(t_emb)
|
262 |
+
|
263 |
+
outs = [] # output of ControlNet
|
264 |
+
zc_idx = 0
|
265 |
+
|
266 |
+
# 2. pre-process
|
267 |
+
sample = unet.conv_in(sample)
|
268 |
+
if is_control_net:
|
269 |
+
sample += guided_hint
|
270 |
+
outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states))
|
271 |
+
zc_idx += 1
|
272 |
+
|
273 |
+
# 3. down
|
274 |
+
down_block_res_samples = (sample,)
|
275 |
+
for downsample_block in unet.down_blocks:
|
276 |
+
if downsample_block.has_cross_attention:
|
277 |
+
sample, res_samples = downsample_block(
|
278 |
+
hidden_states=sample,
|
279 |
+
temb=emb,
|
280 |
+
encoder_hidden_states=encoder_hidden_states,
|
281 |
+
)
|
282 |
+
else:
|
283 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
284 |
+
if is_control_net:
|
285 |
+
for rs in res_samples:
|
286 |
+
outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states))
|
287 |
+
zc_idx += 1
|
288 |
+
|
289 |
+
down_block_res_samples += res_samples
|
290 |
+
|
291 |
+
# 4. mid
|
292 |
+
sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
293 |
+
if is_control_net:
|
294 |
+
outs.append(control_net.control_model.middle_block_out[0](sample))
|
295 |
+
return outs
|
296 |
+
|
297 |
+
if not is_control_net:
|
298 |
+
sample += ctrl_outs.pop()
|
299 |
+
|
300 |
+
# 5. up
|
301 |
+
for i, upsample_block in enumerate(unet.up_blocks):
|
302 |
+
is_final_block = i == len(unet.up_blocks) - 1
|
303 |
+
|
304 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
305 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
306 |
+
|
307 |
+
if not is_control_net and len(ctrl_outs) > 0:
|
308 |
+
res_samples = list(res_samples)
|
309 |
+
apply_ctrl_outs = ctrl_outs[-len(res_samples) :]
|
310 |
+
ctrl_outs = ctrl_outs[: -len(res_samples)]
|
311 |
+
for j in range(len(res_samples)):
|
312 |
+
res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
|
313 |
+
res_samples = tuple(res_samples)
|
314 |
+
|
315 |
+
# if we have not reached the final block and need to forward the
|
316 |
+
# upsample size, we do it here
|
317 |
+
if not is_final_block and forward_upsample_size:
|
318 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
319 |
+
|
320 |
+
if upsample_block.has_cross_attention:
|
321 |
+
sample = upsample_block(
|
322 |
+
hidden_states=sample,
|
323 |
+
temb=emb,
|
324 |
+
res_hidden_states_tuple=res_samples,
|
325 |
+
encoder_hidden_states=encoder_hidden_states,
|
326 |
+
upsample_size=upsample_size,
|
327 |
+
)
|
328 |
+
else:
|
329 |
+
sample = upsample_block(
|
330 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
331 |
+
)
|
332 |
+
# 6. post-process
|
333 |
+
sample = unet.conv_norm_out(sample)
|
334 |
+
sample = unet.conv_act(sample)
|
335 |
+
sample = unet.conv_out(sample)
|
336 |
+
|
337 |
+
return SampleOutput(sample=sample)
|
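load_preprocess above turns a spec string such as "canny_100_200" into a preprocessing callable; a small sketch of that path with a dummy hint image follows (the image is a placeholder).

# Minimal sketch of the "canny_<th1>_<th2>" preprocessor parsing in load_preprocess above.
import numpy as np

prep = load_preprocess("canny_100_200")             # Canny thresholds 100/200; defaults are 63/191 when omitted
hint_rgb = np.zeros((512, 512, 3), dtype=np.uint8)  # dummy RGB hint image for illustration
edges = prep(hint_rgb)                              # grayscale edge map from cv2.Canny
print(edges.shape)                                  # (512, 512)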
external/llite/tools/resize_images_to_resolution.py
ADDED
@@ -0,0 +1,128 @@
import glob
import os
import cv2
import argparse
import shutil
import math
from PIL import Image
import numpy as np


def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
    # Split the max_resolution string by "," and strip any whitespaces
    max_resolutions = [res.strip() for res in max_resolution.split(',')]

    # # Calculate max_pixels from max_resolution string
    # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])

    # Create destination folder if it does not exist
    if not os.path.exists(dst_img_folder):
        os.makedirs(dst_img_folder)

    # Select interpolation method
    if interpolation == 'lanczos4':
        cv2_interpolation = cv2.INTER_LANCZOS4
    elif interpolation == 'cubic':
        cv2_interpolation = cv2.INTER_CUBIC
    else:
        cv2_interpolation = cv2.INTER_AREA

    # Iterate through all files in src_img_folder
    img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp")  # copy from train_util.py
    for filename in os.listdir(src_img_folder):
        # Check if the image is png, jpg or webp etc...
        if not filename.endswith(img_exts):
            # Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.)
            shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
            continue

        # Load image
        # img = cv2.imread(os.path.join(src_img_folder, filename))
        image = Image.open(os.path.join(src_img_folder, filename))
        if not image.mode == "RGB":
            image = image.convert("RGB")
        img = np.array(image, np.uint8)

        base, _ = os.path.splitext(filename)
        for max_resolution in max_resolutions:
            # Calculate max_pixels from max_resolution string
            max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])

            # Calculate current number of pixels
            current_pixels = img.shape[0] * img.shape[1]

            # Check if the image needs resizing
            if current_pixels > max_pixels:
                # Calculate scaling factor
                scale_factor = max_pixels / current_pixels

                # Calculate new dimensions
                new_height = int(img.shape[0] * math.sqrt(scale_factor))
                new_width = int(img.shape[1] * math.sqrt(scale_factor))

                # Resize image
                img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
            else:
                new_height, new_width = img.shape[0:2]

            # Calculate the new height and width that are divisible by divisible_by (with/without resizing)
            new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
            new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by

            # Center crop the image to the calculated dimensions
            y = int((img.shape[0] - new_height) / 2)
            x = int((img.shape[1] - new_width) / 2)
            img = img[y:y + new_height, x:x + new_width]

            # Split filename into base and extension
            new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')

            # Save resized image in dst_img_folder
            # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
            image = Image.fromarray(img)
            image.save(os.path.join(dst_img_folder, new_filename), quality=100)

            proc = "Resized" if current_pixels > max_pixels else "Saved"
            print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")

        # If other files with same basename, copy them with resolution suffix
        if copy_associated_files:
            asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
            for asoc_file in asoc_files:
                ext = os.path.splitext(asoc_file)[1]
                if ext in img_exts:
                    continue
                for max_resolution in max_resolutions:
                    new_asoc_file = base + '+' + max_resolution + ext
                    print(f"Copy {asoc_file} as {new_asoc_file}")
                    shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
    parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
    parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ')
    parser.add_argument('--max_resolution', type=str,
                        help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
    parser.add_argument('--divisible_by', type=int,
                        help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
    parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
                        default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
    parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
    parser.add_argument('--copy_associated_files', action='store_true',
                        help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')

    return parser


def main():
    parser = setup_parser()

    args = parser.parse_args()
    resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
                  args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)


if __name__ == '__main__':
    main()
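A hypothetical call to resize_images, matching the signature above; the folder names are placeholders.

# Minimal sketch: resize a folder to two target areas, keeping sides divisible by 64 (paths are placeholders).
resize_images(
    "data/raw",
    "data/resized",
    max_resolution="768x768,512x512",
    divisible_by=64,
    interpolation="lanczos4",
    save_as_png=True,
    copy_associated_files=True,  # also copies matching caption/text files with the resolution suffix
)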
external/llite/tools/show_metadata.py
ADDED
@@ -0,0 +1,19 @@
import json
import argparse
from safetensors import safe_open

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
args = parser.parse_args()

with safe_open(args.model, framework="pt") as f:
    metadata = f.metadata()

if metadata is None:
    print("No metadata found")
else:
    # metadata is json dict, but not pretty printed
    # sort by key and pretty print
    print(json.dumps(metadata, indent=4, sort_keys=True))
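For a quick round trip, a safetensors file with metadata can be produced as below (the file name, tensor, and metadata values are placeholders) and then inspected with this script.

# Minimal sketch: write a safetensors file with metadata for show_metadata.py to display.
import torch
from safetensors.torch import save_file

tensors = {"weight": torch.zeros(2, 2)}
metadata = {"name": "example", "resolution": "512x512"}  # safetensors metadata values must be strings
save_file(tensors, "example.safetensors", metadata=metadata)
# then: python external/llite/tools/show_metadata.py --model example.safetensors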
inference.py
CHANGED
@@ -468,24 +468,46 @@ def img2img(task: Task):
 
     width, height = get_intermediate_dimension(task)
 
-    lora_patcher = lora_style.get_patcher(
-        [img2img_pipe.pipe, high_res.pipe], task.get_style()
-    )
-    lora_patcher.patch()
-
     torch.manual_seed(task.get_seed())
 
-    ... (11 removed lines not rendered in this view)
+    if get_is_sdxl():
+        # we run lineart for img2img
+        controlnet.load_model("linearart")
+
+        lora_patcher = lora_style.get_patcher(
+            [controlnet.pipe2, high_res.pipe], task.get_style()
+        )
+        lora_patcher.patch()
+
+        kwargs = {
+            "imageUrl": task.get_imageUrl(),
+            "seed": task.get_seed(),
+            "num_inference_steps": task.get_steps(),
+            "width": width,
+            "height": height,
+            "prompt": prompt,
+            "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
+            **task.cnl_kwargs(),
+            "adapter_conditioning_scale": 0.3,
+        }
+        images, has_nsfw = controlnet.process(**kwargs)
+    else:
+        lora_patcher = lora_style.get_patcher(
+            [img2img_pipe.pipe, high_res.pipe], task.get_style()
+        )
+        lora_patcher.patch()
+
+        kwargs = {
+            "prompt": prompt,
+            "imageUrl": task.get_imageUrl(),
+            "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
+            "num_inference_steps": task.get_steps(),
+            "width": width,
+            "height": height,
+            **task.i2i_kwargs(),
+            **lora_patcher.kwargs(),
+        }
+        images, has_nsfw = img2img_pipe.process(**kwargs)
 
     if task.get_high_res_fix():
         kwargs = {