Upload main.py
main.py
ADDED
@@ -0,0 +1,468 @@
'''
Modified version for full net LoRA
(LoRA for ResBlock and up/down sample blocks)
'''
import os, sys
import re
import torch

from modules import shared, devices, sd_models
import lora
from locon_compvis import LoConModule, LoConNetworkCompvis, create_network_and_apply_compvis


try:
    '''
    Hijack Additional Network extension
    '''
    # skip the addnet hijack since it doesn't support the new version
    raise
    now_dir = os.path.dirname(os.path.abspath(__file__))
    addnet_path = os.path.join(now_dir, '..', '..', 'sd-webui-additional-networks/scripts')
    sys.path.append(addnet_path)
    import lora_compvis
    import scripts
    scripts.lora_compvis = lora_compvis
    scripts.lora_compvis.LoRAModule = LoConModule
    scripts.lora_compvis.LoRANetworkCompvis = LoConNetworkCompvis
    scripts.lora_compvis.create_network_and_apply_compvis = create_network_and_apply_compvis
    print('LoCon Extension: hijacked the Additional Networks extension successfully')
except:
    print('Additional Networks extension not installed; only hijacking the built-in lora')


'''
Hijack sd-webui LoRA
'''
re_digits = re.compile(r"\d+")

re_unet_conv_in = re.compile(r"lora_unet_conv_in(.+)")
re_unet_conv_out = re.compile(r"lora_unet_conv_out(.+)")
re_unet_time_embed = re.compile(r"lora_unet_time_embedding_linear_(\d+)(.+)")

re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")

re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+)")

re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")

re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")


def convert_diffusers_name_to_compvis(key):
    def match(match_list, regex):
        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, re_unet_conv_in):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, re_unet_conv_out):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, re_unet_time_embed):
        return f"diffusion_model_time_embed_{m[0]*2-2}{m[1]}"

    if match(m, re_unet_down_blocks):
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"

    if match(m, re_unet_mid_blocks):
        return f"diffusion_model_middle_block_1_{m[1]}"

    if match(m, re_unet_up_blocks):
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"

    if match(m, re_unet_down_blocks_res):
        block = f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_0_"
        if m[2].startswith('conv1'):
            return f"{block}in_layers_2{m[2][len('conv1'):]}"
        elif m[2].startswith('conv2'):
            return f"{block}out_layers_3{m[2][len('conv2'):]}"
        elif m[2].startswith('time_emb_proj'):
            return f"{block}emb_layers_1{m[2][len('time_emb_proj'):]}"
        elif m[2].startswith('conv_shortcut'):
            return f"{block}skip_connection{m[2][len('conv_shortcut'):]}"

    if match(m, re_unet_mid_blocks_res):
        block = f"diffusion_model_middle_block_{m[0]*2}_"
        if m[1].startswith('conv1'):
            return f"{block}in_layers_2{m[1][len('conv1'):]}"
        elif m[1].startswith('conv2'):
            return f"{block}out_layers_3{m[1][len('conv2'):]}"
        elif m[1].startswith('time_emb_proj'):
            return f"{block}emb_layers_1{m[1][len('time_emb_proj'):]}"
        elif m[1].startswith('conv_shortcut'):
            return f"{block}skip_connection{m[1][len('conv_shortcut'):]}"

    if match(m, re_unet_up_blocks_res):
        block = f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_0_"
        if m[2].startswith('conv1'):
            return f"{block}in_layers_2{m[2][len('conv1'):]}"
        elif m[2].startswith('conv2'):
            return f"{block}out_layers_3{m[2][len('conv2'):]}"
        elif m[2].startswith('time_emb_proj'):
            return f"{block}emb_layers_1{m[2][len('time_emb_proj'):]}"
        elif m[2].startswith('conv_shortcut'):
            return f"{block}skip_connection{m[2][len('conv_shortcut'):]}"

    if match(m, re_unet_downsample):
        return f"diffusion_model_input_blocks_{m[0]*3+3}_0_op{m[1]}"

    if match(m, re_unet_upsample):
        return f"diffusion_model_output_blocks_{m[0]*3 + 2}_{1+(m[0]!=0)}_conv{m[1]}"

    if match(m, re_text_block):
        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    return key
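
# Illustrative example of the mapping above (not exhaustive): following the
# `1 + block * 3 + attention` arithmetic in re_unet_down_blocks,
#   'lora_unet_down_blocks_0_attentions_1_transformer_blocks_0_attn1_to_q.lora_down.weight'
# becomes
#   'diffusion_model_input_blocks_2_1_transformer_blocks_0_attn1_to_q.lora_down.weight',
# which load_lora() below then splits once on '.' into the layer name and the lora key.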


class LoraOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename


class LoraModule:
    def __init__(self, name):
        self.name = name
        self.multiplier = 1.0
        self.modules = {}
        self.mtime = None


class FakeModule(torch.nn.Module):
    def __init__(self, weight, func):
        super().__init__()
        self.weight = weight
        self.func = func

    def forward(self, x):
        return self.func(x)


class FullModule:
    def __init__(self):
        self.weight = None
        self.alpha = None
        self.op = None
        self.extra_args = {}
        self.shape = None
        self.up = None

    def down(self, x):
        return x

    def inference(self, x):
        return self.op(x, self.weight, **self.extra_args)


class LoraUpDownModule:
    def __init__(self):
        self.up_model = None
        self.mid_model = None
        self.down_model = None
        self.alpha = None
        self.dim = None
        self.op = None
        self.extra_args = {}
        self.shape = None
        self.bias = None
        self.up = None

    def down(self, x):
        return x

    def inference(self, x):
        if hasattr(self, 'bias') and isinstance(self.bias, torch.Tensor):
            out_dim = self.up_model.weight.size(0)
            rank = self.down_model.weight.size(0)
            rebuild_weight = (
                self.up_model.weight.reshape(out_dim, -1) @ self.down_model.weight.reshape(rank, -1)
                + self.bias
            ).reshape(self.shape)
            return self.op(
                x, rebuild_weight,
                **self.extra_args
            )
        else:
            if self.mid_model is None:
                return self.up_model(self.down_model(x))
            else:
                return self.up_model(self.mid_model(self.down_model(x)))


def pro3(t, wa, wb):
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)
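
# Rough reading of pro3 above: for a 4-D tensor t[i, j, k, l] and the two rank
# projections wa[i, r] and wb[j, r'], the two einsums contract the first two axes,
# so rebuilt[r, r', k, l] = sum over i, j of t[i, j, k, l] * wa[i, r] * wb[j, r'].
# LoraHadaModule below uses it to rebuild conv kernels from the Tucker-style
# hada_t1 / hada_t2 factors.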


class LoraHadaModule:
    def __init__(self):
        self.t1 = None
        self.w1a = None
        self.w1b = None
        self.t2 = None
        self.w2a = None
        self.w2b = None
        self.alpha = None
        self.dim = None
        self.op = None
        self.extra_args = {}
        self.shape = None
        self.bias = None
        self.up = None

    def down(self, x):
        return x

    def inference(self, x):
        if hasattr(self, 'bias') and isinstance(self.bias, torch.Tensor):
            bias = self.bias
        else:
            bias = 0

        if self.t1 is None:
            return self.op(
                x,
                ((self.w1a @ self.w1b) * (self.w2a @ self.w2b) + bias).view(self.shape),
                **self.extra_args
            )
        else:
            return self.op(
                x,
                (pro3(self.t1, self.w1a, self.w1b)
                 * pro3(self.t2, self.w2a, self.w2b) + bias).view(self.shape),
                **self.extra_args
            )


CON_KEY = {
    "lora_up.weight",
    "lora_down.weight",
    "lora_mid.weight"
}
HADA_KEY = {
    "hada_t1",
    "hada_w1_a",
    "hada_w1_b",
    "hada_t2",
    "hada_w2_a",
    "hada_w2_b",
}
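
# Descriptive summary of how load_lora() below routes state-dict keys onto the
# classes above (a rough guide, no extra behavior):
#   lora_up / lora_down (/ lora_mid) weights -> LoraUpDownModule: delta from up @ down,
#       or up(mid(down(x))) for the conv case
#   hada_w1_a/w1_b/w2_a/w2_b (/ hada_t1/t2)  -> LoraHadaModule: delta from
#       (w1a @ w1b) * (w2a @ w2b), the Hadamard-product form
#   'diff'                                   -> FullModule: a full delta weight applied via op(x, weight)
#   'alpha' and 'bias_*'                     -> stored on the module (scale factor / sparse bias tensor)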

def load_lora(name, filename):
    lora = LoraModule(name)
    lora.mtime = os.path.getmtime(filename)

    sd = sd_models.read_state_dict(filename)

    keys_failed_to_match = []

    for key_diffusers, weight in sd.items():
        fullkey = convert_diffusers_name_to_compvis(key_diffusers)
        key, lora_key = fullkey.split(".", 1)

        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
        if sd_module is None:
            keys_failed_to_match.append(key_diffusers)
            continue

        lora_module = lora.modules.get(key, None)
        if lora_module is None:
            lora_module = LoraUpDownModule()
            lora.modules[key] = lora_module

        if lora_key == "alpha":
            lora_module.alpha = weight.item()
            continue

        if lora_key == "diff":
            weight = weight.to(device=devices.device, dtype=devices.dtype)
            weight.requires_grad_(False)
            lora_module = FullModule()
            lora.modules[key] = lora_module
            lora_module.weight = weight
            lora_module.alpha = weight.size(1)
            lora_module.up = FakeModule(
                weight,
                lora_module.inference
            )
            lora_module.up.to(device=devices.device, dtype=devices.dtype)
            if len(weight.shape) == 2:
                lora_module.op = torch.nn.functional.linear
                lora_module.extra_args = {
                    'bias': None
                }
            else:
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding,
                    'bias': None
                }
            continue

        if 'bias_' in lora_key:
            if lora_module.bias is None:
                lora_module.bias = [None, None, None]
            if 'bias_indices' == lora_key:
                lora_module.bias[0] = weight
            elif 'bias_values' == lora_key:
                lora_module.bias[1] = weight
            elif 'bias_size' == lora_key:
                lora_module.bias[2] = weight

            if all((i is not None) for i in lora_module.bias):
                print('build bias')
                lora_module.bias = torch.sparse_coo_tensor(
                    lora_module.bias[0],
                    lora_module.bias[1],
                    tuple(lora_module.bias[2]),
                ).to(device=devices.device, dtype=devices.dtype)
                lora_module.bias.requires_grad_(False)
            continue

        if lora_key in CON_KEY:
            if type(sd_module) == torch.nn.Linear:
                weight = weight.reshape(weight.shape[0], -1)
                module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
                lora_module.op = torch.nn.functional.linear
            elif type(sd_module) == torch.nn.Conv2d:
                if lora_key == "lora_down.weight":
                    if weight.shape[2] != 1 or weight.shape[3] != 1:
                        module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
                    else:
                        module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
                elif lora_key == "lora_mid.weight":
                    module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
                elif lora_key == "lora_up.weight":
                    module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding
                }
            else:
                assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'

            lora_module.shape = sd_module.weight.shape
            with torch.no_grad():
                module.weight.copy_(weight)

            module.to(device=devices.device, dtype=devices.dtype)
            module.requires_grad_(False)

            if lora_key == "lora_up.weight":
                lora_module.up_model = module
                lora_module.up = FakeModule(
                    lora_module.up_model.weight,
                    lora_module.inference
                )
            elif lora_key == "lora_mid.weight":
                lora_module.mid_model = module
            elif lora_key == "lora_down.weight":
                lora_module.down_model = module
                lora_module.dim = weight.shape[0]
        elif lora_key in HADA_KEY:
            if type(lora_module) != LoraHadaModule:
                alpha = lora_module.alpha
                bias = lora_module.bias
                lora_module = LoraHadaModule()
                lora_module.alpha = alpha
                lora_module.bias = bias
                lora.modules[key] = lora_module
            lora_module.shape = sd_module.weight.shape

            weight = weight.to(device=devices.device, dtype=devices.dtype)
            weight.requires_grad_(False)

            if lora_key == 'hada_w1_a':
                lora_module.w1a = weight
                if lora_module.up is None:
                    lora_module.up = FakeModule(
                        lora_module.w1a,
                        lora_module.inference
                    )
            elif lora_key == 'hada_w1_b':
                lora_module.w1b = weight
                lora_module.dim = weight.shape[0]
            elif lora_key == 'hada_w2_a':
                lora_module.w2a = weight
            elif lora_key == 'hada_w2_b':
                lora_module.w2b = weight
            elif lora_key == 'hada_t1':
                lora_module.t1 = weight
                lora_module.up = FakeModule(
                    lora_module.t1,
                    lora_module.inference
                )
            elif lora_key == 'hada_t2':
                lora_module.t2 = weight

            if type(sd_module) == torch.nn.Linear:
                lora_module.op = torch.nn.functional.linear
            elif type(sd_module) == torch.nn.Conv2d:
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding
                }
            else:
                assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'

        else:
            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight, lora_mid.weight, hada_*, diff, bias_* or alpha'

    if len(keys_failed_to_match) > 0:
        print(shared.sd_model.lora_layer_mapping)
        print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")

    return lora
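
# Illustrative call (hypothetical name and path; the webui's built-in lora module
# normally invokes load_lora itself when a prompt references the network):
#   locon_net = load_lora('my_locon', 'models/Lora/my_locon.safetensors')
#   # locon_net.modules then maps compvis layer names to the patch modules built above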


def lora_forward(module, input, res):
    if len(lora.loaded_loras) == 0:
        return res

    lora_layer_name = getattr(module, 'lora_layer_name', None)
    for lora_m in lora.loaded_loras:
        module = lora_m.modules.get(lora_layer_name, None)
        if module is not None and lora_m.multiplier:
            if hasattr(module, 'up'):
                scale = lora_m.multiplier * (module.alpha / module.up.weight.size(1) if module.alpha else 1.0)
            else:
                scale = lora_m.multiplier * (module.alpha / module.dim if module.alpha else 1.0)

            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
                x = res
            else:
                x = input

            if hasattr(module, 'inference'):
                res = res + module.inference(x) * scale
            elif hasattr(module, 'up'):
                res = res + module.up(module.down(x)) * scale
            else:
                raise NotImplementedError(
                    "Your settings, extensions or models are not compatible with each other."
                )
    return res
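
# Net effect of lora_forward above for every patched layer:
#   res += delta(x) * multiplier * (alpha / rank)
# where rank comes from the second dimension of the weight stored on module.up
# (or from module.dim when there is no up), and the scale falls back to the bare
# multiplier when alpha is unset.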


lora.convert_diffusers_name_to_compvis = convert_diffusers_name_to_compvis
lora.load_lora = load_lora
lora.lora_forward = lora_forward
print('LoCon Extension: hijacked built-in lora successfully')