supertori committed
Commit fe41012
Parent(s): 43d9b29

Upload main.py

Files changed (1):
  main.py  +468  -0

main.py ADDED
'''
Modified version for full-net LoRA
(LoRA for ResBlocks and up/down-sample blocks)
'''
import os, sys
import re
import torch

from modules import shared, devices, sd_models
import lora
from locon_compvis import LoConModule, LoConNetworkCompvis, create_network_and_apply_compvis


try:
    '''
    Hijack Additional Network extension
    '''
    # Skip the addnet hijack since it doesn't support the new version;
    # the bare raise jumps straight to the except branch below.
    raise
    now_dir = os.path.dirname(os.path.abspath(__file__))
    addnet_path = os.path.join(now_dir, '..', '..', 'sd-webui-additional-networks/scripts')
    sys.path.append(addnet_path)
    import lora_compvis
    import scripts
    scripts.lora_compvis = lora_compvis
    scripts.lora_compvis.LoRAModule = LoConModule
    scripts.lora_compvis.LoRANetworkCompvis = LoConNetworkCompvis
    scripts.lora_compvis.create_network_and_apply_compvis = create_network_and_apply_compvis
    print('LoCon Extension hijacked addnet extension successfully')
except Exception:
    print('Additional Network extension not installed, only hijacking built-in lora')


'''
Hijack sd-webui LoRA
'''
re_digits = re.compile(r"\d+")

re_unet_conv_in = re.compile(r"lora_unet_conv_in(.+)")
re_unet_conv_out = re.compile(r"lora_unet_conv_out(.+)")
re_unet_time_embed = re.compile(r"lora_unet_time_embedding_linear_(\d+)(.+)")

re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")

re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+)")

re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")

re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")


def convert_diffusers_name_to_compvis(key):
    def match(match_list, regex):
        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, re_unet_conv_in):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, re_unet_conv_out):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, re_unet_time_embed):
        return f"diffusion_model_time_embed_{m[0]*2-2}{m[1]}"

    if match(m, re_unet_down_blocks):
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"

    if match(m, re_unet_mid_blocks):
        return f"diffusion_model_middle_block_1_{m[1]}"

    if match(m, re_unet_up_blocks):
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"

    if match(m, re_unet_down_blocks_res):
        block = f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_0_"
        if m[2].startswith('conv1'):
            return f"{block}in_layers_2{m[2][len('conv1'):]}"
        elif m[2].startswith('conv2'):
            return f"{block}out_layers_3{m[2][len('conv2'):]}"
        elif m[2].startswith('time_emb_proj'):
            return f"{block}emb_layers_1{m[2][len('time_emb_proj'):]}"
        elif m[2].startswith('conv_shortcut'):
            return f"{block}skip_connection{m[2][len('conv_shortcut'):]}"

    if match(m, re_unet_mid_blocks_res):
        block = f"diffusion_model_middle_block_{m[0]*2}_"
        if m[1].startswith('conv1'):
            return f"{block}in_layers_2{m[1][len('conv1'):]}"
        elif m[1].startswith('conv2'):
            return f"{block}out_layers_3{m[1][len('conv2'):]}"
        elif m[1].startswith('time_emb_proj'):
            return f"{block}emb_layers_1{m[1][len('time_emb_proj'):]}"
        elif m[1].startswith('conv_shortcut'):
            return f"{block}skip_connection{m[1][len('conv_shortcut'):]}"

    if match(m, re_unet_up_blocks_res):
        block = f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_0_"
        if m[2].startswith('conv1'):
            return f"{block}in_layers_2{m[2][len('conv1'):]}"
        elif m[2].startswith('conv2'):
            return f"{block}out_layers_3{m[2][len('conv2'):]}"
        elif m[2].startswith('time_emb_proj'):
            return f"{block}emb_layers_1{m[2][len('time_emb_proj'):]}"
        elif m[2].startswith('conv_shortcut'):
            return f"{block}skip_connection{m[2][len('conv_shortcut'):]}"

    if match(m, re_unet_downsample):
        return f"diffusion_model_input_blocks_{m[0]*3+3}_0_op{m[1]}"

    if match(m, re_unet_upsample):
        return f"diffusion_model_output_blocks_{m[0]*3 + 2}_{1+(m[0]!=0)}_conv{m[1]}"

    if match(m, re_text_block):
        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    return key


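# Illustrative example (hypothetical key name, following the regexes above):
# a diffusers/kohya-style key such as
#     lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight
# is converted to the compvis-style name
#     diffusion_model_input_blocks_1_1_transformer_blocks_0_attn1_to_q.lora_up.weight
# (input block index 1 + 0*3 + 0 = 1). load_lora() below then splits the result
# on the first '.' into the layer name and the lora key ('lora_up.weight').

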
class LoraOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename


class LoraModule:
    def __init__(self, name):
        self.name = name
        self.multiplier = 1.0
        self.modules = {}
        self.mtime = None


class FakeModule(torch.nn.Module):
    def __init__(self, weight, func):
        super().__init__()
        self.weight = weight
        self.func = func

    def forward(self, x):
        return self.func(x)


class FullModule:
    def __init__(self):
        self.weight = None
        self.alpha = None
        self.op = None
        self.extra_args = {}
        self.shape = None
        self.up = None

    def down(self, x):
        return x

    def inference(self, x):
        return self.op(x, self.weight, **self.extra_args)


class LoraUpDownModule:
    def __init__(self):
        self.up_model = None
        self.mid_model = None
        self.down_model = None
        self.alpha = None
        self.dim = None
        self.op = None
        self.extra_args = {}
        self.shape = None
        self.bias = None
        self.up = None

    def down(self, x):
        return x

    def inference(self, x):
        if hasattr(self, 'bias') and isinstance(self.bias, torch.Tensor):
            out_dim = self.up_model.weight.size(0)
            rank = self.down_model.weight.size(0)
            rebuild_weight = (
                self.up_model.weight.reshape(out_dim, -1) @ self.down_model.weight.reshape(rank, -1)
                + self.bias
            ).reshape(self.shape)
            return self.op(
                x, rebuild_weight,
                **self.extra_args
            )
        else:
            if self.mid_model is None:
                return self.up_model(self.down_model(x))
            else:
                return self.up_model(self.mid_model(self.down_model(x)))


def pro3(t, wa, wb):
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)


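# Note (a sketch of the assumed shapes, not taken from the original comments):
# pro3() rebuilds a full Conv2d kernel from a Tucker-style decomposition.
# Assuming t has shape (rank, rank, kh, kw), wa has shape (rank, out_dim) and
# wb has shape (rank, in_dim), the two einsum contractions above yield a tensor
# of shape (out_dim, in_dim, kh, kw), which LoraHadaModule.inference() below
# multiplies element-wise with the second pro3() result before the conv/linear op.

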
class LoraHadaModule:
    def __init__(self):
        self.t1 = None
        self.w1a = None
        self.w1b = None
        self.t2 = None
        self.w2a = None
        self.w2b = None
        self.alpha = None
        self.dim = None
        self.op = None
        self.extra_args = {}
        self.shape = None
        self.bias = None
        self.up = None

    def down(self, x):
        return x

    def inference(self, x):
        if hasattr(self, 'bias') and isinstance(self.bias, torch.Tensor):
            bias = self.bias
        else:
            bias = 0

        if self.t1 is None:
            return self.op(
                x,
                ((self.w1a @ self.w1b) * (self.w2a @ self.w2b) + bias).view(self.shape),
                **self.extra_args
            )
        else:
            return self.op(
                x,
                (pro3(self.t1, self.w1a, self.w1b)
                 * pro3(self.t2, self.w2a, self.w2b) + bias).view(self.shape),
                **self.extra_args
            )


CON_KEY = {
    "lora_up.weight",
    "lora_down.weight",
    "lora_mid.weight"
}
HADA_KEY = {
    "hada_t1",
    "hada_w1_a",
    "hada_w1_b",
    "hada_t2",
    "hada_w2_a",
    "hada_w2_b",
}

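# Illustrative state-dict layout that load_lora() below can consume (key names
# are hypothetical but follow the regexes and suffixes handled in this file):
#     lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight
#     lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight
#     lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.alpha
#     lora_unet_down_blocks_0_resnets_0_conv1.hada_w1_a    (plus hada_w1_b/w2_a/w2_b for LoHa)
# Each suffix routes to a LoraUpDownModule (CON_KEY), a LoraHadaModule (HADA_KEY),
# a FullModule ('diff'), or the sparse-bias branch ('bias_indices'/'bias_values'/'bias_size').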
def load_lora(name, filename):
    lora = LoraModule(name)
    lora.mtime = os.path.getmtime(filename)

    sd = sd_models.read_state_dict(filename)

    keys_failed_to_match = []

    for key_diffusers, weight in sd.items():
        fullkey = convert_diffusers_name_to_compvis(key_diffusers)
        key, lora_key = fullkey.split(".", 1)

        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
        if sd_module is None:
            keys_failed_to_match.append(key_diffusers)
            continue

        lora_module = lora.modules.get(key, None)
        if lora_module is None:
            lora_module = LoraUpDownModule()
            lora.modules[key] = lora_module

        if lora_key == "alpha":
            lora_module.alpha = weight.item()
            continue

        if lora_key == "diff":
            weight = weight.to(device=devices.device, dtype=devices.dtype)
            weight.requires_grad_(False)
            lora_module = FullModule()
            lora.modules[key] = lora_module
            lora_module.weight = weight
            lora_module.alpha = weight.size(1)
            lora_module.up = FakeModule(
                weight,
                lora_module.inference
            )
            lora_module.up.to(device=devices.device, dtype=devices.dtype)
            if len(weight.shape) == 2:
                lora_module.op = torch.nn.functional.linear
                lora_module.extra_args = {
                    'bias': None
                }
            else:
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding,
                    'bias': None
                }
            continue

        if 'bias_' in lora_key:
            if lora_module.bias is None:
                lora_module.bias = [None, None, None]
            if 'bias_indices' == lora_key:
                lora_module.bias[0] = weight
            elif 'bias_values' == lora_key:
                lora_module.bias[1] = weight
            elif 'bias_size' == lora_key:
                lora_module.bias[2] = weight

            if all((i is not None) for i in lora_module.bias):
                print('build bias')
                lora_module.bias = torch.sparse_coo_tensor(
                    lora_module.bias[0],
                    lora_module.bias[1],
                    tuple(lora_module.bias[2]),
                ).to(device=devices.device, dtype=devices.dtype)
                lora_module.bias.requires_grad_(False)
            continue

        if lora_key in CON_KEY:
            if type(sd_module) == torch.nn.Linear:
                weight = weight.reshape(weight.shape[0], -1)
                module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
                lora_module.op = torch.nn.functional.linear
            elif type(sd_module) == torch.nn.Conv2d:
                if lora_key == "lora_down.weight":
                    if weight.shape[2] != 1 or weight.shape[3] != 1:
                        module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
                    else:
                        module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
                elif lora_key == "lora_mid.weight":
                    module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
                elif lora_key == "lora_up.weight":
                    module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding
                }
            else:
                assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'

            lora_module.shape = sd_module.weight.shape
            with torch.no_grad():
                module.weight.copy_(weight)

            module.to(device=devices.device, dtype=devices.dtype)
            module.requires_grad_(False)

            if lora_key == "lora_up.weight":
                lora_module.up_model = module
                lora_module.up = FakeModule(
                    lora_module.up_model.weight,
                    lora_module.inference
                )
            elif lora_key == "lora_mid.weight":
                lora_module.mid_model = module
            elif lora_key == "lora_down.weight":
                lora_module.down_model = module
                lora_module.dim = weight.shape[0]
        elif lora_key in HADA_KEY:
            if type(lora_module) != LoraHadaModule:
                alpha = lora_module.alpha
                bias = lora_module.bias
                lora_module = LoraHadaModule()
                lora_module.alpha = alpha
                lora_module.bias = bias
                lora.modules[key] = lora_module
            lora_module.shape = sd_module.weight.shape

            weight = weight.to(device=devices.device, dtype=devices.dtype)
            weight.requires_grad_(False)

            if lora_key == 'hada_w1_a':
                lora_module.w1a = weight
                if lora_module.up is None:
                    lora_module.up = FakeModule(
                        lora_module.w1a,
                        lora_module.inference
                    )
            elif lora_key == 'hada_w1_b':
                lora_module.w1b = weight
                lora_module.dim = weight.shape[0]
            elif lora_key == 'hada_w2_a':
                lora_module.w2a = weight
            elif lora_key == 'hada_w2_b':
                lora_module.w2b = weight
            elif lora_key == 'hada_t1':
                lora_module.t1 = weight
                lora_module.up = FakeModule(
                    lora_module.t1,
                    lora_module.inference
                )
            elif lora_key == 'hada_t2':
                lora_module.t2 = weight

            if type(sd_module) == torch.nn.Linear:
                lora_module.op = torch.nn.functional.linear
            elif type(sd_module) == torch.nn.Conv2d:
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding
                }
            else:
                assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'

        else:
            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight, lora_mid.weight, alpha, diff, a bias_* key or a hada_* key'

    if len(keys_failed_to_match) > 0:
        print(shared.sd_model.lora_layer_mapping)
        print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")

    return lora


def lora_forward(module, input, res):
    if len(lora.loaded_loras) == 0:
        return res

    lora_layer_name = getattr(module, 'lora_layer_name', None)
    for lora_m in lora.loaded_loras:
        module = lora_m.modules.get(lora_layer_name, None)
        if module is not None and lora_m.multiplier:
            if hasattr(module, 'up'):
                scale = lora_m.multiplier * (module.alpha / module.up.weight.size(1) if module.alpha else 1.0)
            else:
                scale = lora_m.multiplier * (module.alpha / module.dim if module.alpha else 1.0)

            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
                x = res
            else:
                x = input

            if hasattr(module, 'inference'):
                res = res + module.inference(x) * scale
            elif hasattr(module, 'up'):
                res = res + module.up(module.down(x)) * scale
            else:
                raise NotImplementedError(
                    "Your settings, extensions or models are not compatible with each other."
                )
    return res

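# Worked example (hypothetical numbers, not from the original file): with
# multiplier = 0.8, alpha = 8 and a rank-16 module, lora_forward() applies
# scale = 0.8 * (8 / 16) = 0.4; when alpha is unset, the scale is just the
# multiplier.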

lora.convert_diffusers_name_to_compvis = convert_diffusers_name_to_compvis
lora.load_lora = load_lora
lora.lora_forward = lora_forward
print('LoCon Extension hijacked built-in lora successfully')