supertori committed on
Commit 43d9b29
1 Parent(s): d43d2a2

Upload lora_block_weight.py

Files changed (1)
  1. lora_block_weight.py +277 -64
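The change below lets the script accept block-weight presets with either 17 values (the original block subset) or 26 values (the full set of input, middle, and output blocks). As a minimal sketch of that preset check, using hypothetical helper names rather than the script's own functions:

# Sketch only: mirrors the length check the updated script applies to preset
# lines of the form "NAME:w1,w2,...,wN" (17 or 26 comma-separated block weights).
def parse_presets(text: str) -> dict:
    presets = {}
    for line in text.splitlines():
        if ":" not in line:
            continue
        name, weights = line.split(":", 1)
        if weights.count(",") in (16, 25):  # 17 or 26 values
            presets[name.strip()] = weights.strip()
    return presets

print(parse_presets("ALL0.5:" + ",".join(["0.5"] * 26)))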
lora_block_weight.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import cv2
3
  import os
4
  import re
@@ -9,6 +8,7 @@ import numpy as np
9
  import gradio as gr
10
  import os.path
11
  import random
 
12
  import modules.ui
13
  import modules.scripts as scripts
14
  from PIL import Image, ImageFont, ImageDraw
@@ -21,23 +21,32 @@ from modules.processing import process_images, Processed
21
  lxyz = ""
22
  lzyx = ""
23
 
24
- LORABLOCKS=["encoder",
25
- "down_blocks_0_attentions_0",
26
- "down_blocks_0_attentions_1",
27
- "down_blocks_1_attentions_0",
28
- "down_blocks_1_attentions_1",
29
- "down_blocks_2_attentions_0",
30
- "down_blocks_2_attentions_1",
31
- "mid_block_attentions_0",
32
- "up_blocks_1_attentions_0",
33
- "up_blocks_1_attentions_1",
34
- "up_blocks_1_attentions_2",
35
- "up_blocks_2_attentions_0",
36
- "up_blocks_2_attentions_1",
37
- "up_blocks_2_attentions_2",
38
- "up_blocks_3_attentions_0",
39
- "up_blocks_3_attentions_1",
40
- "up_blocks_3_attentions_2"]
41
 
42
  loopstopper = True
43
 
@@ -95,7 +104,7 @@ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
95
  with gr.Column(scale=5):
96
  bw_ratiotags= gr.TextArea(label="",lines=2,value=rasiostags,visible =True,interactive =True,elem_id="lbw_ratios")
97
  with gr.Accordion("XYZ plot",open = False):
98
- gr.HTML(value="<p>changeable blocks : BASE,IN01,IN02,IN04,IN05,IN07,IN08,M00,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11</p>")
99
  xyzsetting = gr.Radio(label = "Active",choices = ["Disable","XYZ plot","Effective Block Analyzer"], value ="Disable",type = "index")
100
  with gr.Row(visible = False) as esets:
101
  diffcol = gr.Radio(label = "diff image color",choices = ["black","white"], value ="black",type = "value",interactive =True)
@@ -109,7 +118,7 @@ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
109
  zmen = gr.Textbox(label="Z values ",lines=1,value="",interactive =True,elem_id="lbw_zmen")
110
 
111
  exmen = gr.Textbox(label="Range",lines=1,value="0.5,1",interactive =True,elem_id="lbw_exmen",visible = False)
112
- eymen = gr.Textbox(label="Blocks" ,lines=1,value="BASE,IN01,IN02,IN04,IN05,IN07,IN08,M00,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11",interactive =True,elem_id="lbw_eymen",visible = False)
113
 
114
  with gr.Accordion("Weights setting",open = True):
115
  with gr.Row():
@@ -138,7 +147,7 @@ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
138
  if ":" in l :
139
  key = l.split(":",1)[0]
140
  w = l.split(":",1)[1]
141
- if len([w for w in w.split(",")]) == 17:
142
  wdict[key.strip()]=w
143
  return ",".join(list(wdict.keys()))
144
 
@@ -211,7 +220,7 @@ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
211
  xmen,ymen = exmen,eymen
212
  xtype,ytype = "values","ID"
213
  ebase = xmen.split(",")[1]
214
- ebase = [ebase.strip()]*17
215
  base = ",".join(ebase)
216
  ztype = ""
217
 
@@ -244,15 +253,17 @@ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
244
  images = []
245
 
246
  def weightsdealer(alpha,ids,base):
247
- blockid=["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
 
248
  #print(f"weights from : {base}")
249
  ids = [z.strip() for z in ids.split(' ')]
250
  weights_t = [w.strip() for w in base.split(',')]
 
251
  if ids[0]!="NOT":
252
- flagger=[False]*17
253
  changer = True
254
  else:
255
- flagger=[True]*17
256
  changer = False
257
  for id in ids:
258
  if id =="NOT":continue
@@ -357,7 +368,7 @@ def loradealer(p,lratios):
357
  lorars = []
358
  for called in calledloras:
359
  if len(called.items) <3:continue
360
- if called.items[2] in lratios or called.items[2].count(",") ==16:
361
  lorans.append(called.items[0])
362
  wei = lratios[called.items[2]] if called.items[2] in lratios else called.items[2]
363
  multiple = called.items[1]
@@ -370,14 +381,16 @@ def loradealer(p,lratios):
370
  else:
371
  ratios[i] = float(r)
372
  print(f"LoRA Block weight :{called.items[0]}: {ratios}")
373
  lorars.append(ratios)
374
  if len(lorars) > 0: load_loras_blocks(lorans,lorars,multiple)
375
 
376
  re_digits = re.compile(r"\d+")
 
377
  re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
378
  re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
379
  re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
380
- re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
381
 
382
  re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
383
  re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
@@ -386,6 +399,9 @@ re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+
386
  re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
387
  re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")
388
 
389
  def convert_diffusers_name_to_compvis(key):
390
  def match(match_list, regex):
391
  r = re.match(regex, key)
@@ -451,75 +467,272 @@ def convert_diffusers_name_to_compvis(key):
451
 
452
  return key
453
 
454
  def load_lora(name, filename,lwei):
455
- import lora
456
- locallora = lora.LoraModule(name)
457
- locallora.mtime = os.path.getmtime(filename)
458
 
459
  sd = sd_models.read_state_dict(filename)
460
 
461
  keys_failed_to_match = []
 
462
 
463
  for key_diffusers, weight in sd.items():
464
  ratio = 1
 
465
 
466
- for i,block in enumerate(LORABLOCKS):
467
- if block in key_diffusers:
468
- ratio = lwei[i]
469
-
470
- weight =weight *math.sqrt(abs(ratio))
471
-
472
  fullkey = convert_diffusers_name_to_compvis(key_diffusers)
473
- #print(key_diffusers+":"+fullkey+"x" + str(ratio))
474
  key, lora_key = fullkey.split(".", 1)
475
 
476
  sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
477
  if sd_module is None:
478
  keys_failed_to_match.append(key_diffusers)
479
  continue
480
 
481
- lora_module = locallora.modules.get(key, None)
482
  if lora_module is None:
483
- lora_module = lora.LoraUpDownModule()
484
- locallora.modules[key] = lora_module
485
 
486
  if lora_key == "alpha":
487
  lora_module.alpha = weight.item()
488
  continue
489
 
490
- if type(sd_module) == torch.nn.Linear:
491
- weight = weight.reshape(weight.shape[0], -1)
492
- module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
493
- elif type(sd_module) == torch.nn.Conv2d:
494
- if lora_key == "lora_down.weight":
495
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
496
- elif lora_key == "lora_up.weight":
497
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
498
- else:
499
- assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
500
-
501
- fugou = 1 if ratio >0 else -1
502
 
503
- if lora_key == "lora_up.weight":
504
  with torch.no_grad():
505
  module.weight.copy_(weight*fugou)
506
- else:
507
- with torch.no_grad():
508
- module.weight.copy_(weight)
509
-
510
- module.to(device=devices.device, dtype=devices.dtype)
511
 
512
- if lora_key == "lora_up.weight":
513
- lora_module.up = module
514
- elif lora_key == "lora_down.weight":
515
- lora_module.down = module
516
  else:
517
  assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
518
 
519
  if len(keys_failed_to_match) > 0:
520
  print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
521
 
522
- return locallora
523
 
524
 
525
  def load_loras_blocks(names, lwei=None,multi=1.0):
@@ -640,4 +853,4 @@ def effectivechecker(imgs,ss,ls,diffcol,thresh,revxy):
640
  else:
641
  outs = [imgs[0]]*len(diffs) + imgs[1:]+ diffs
642
  ss = ["source",ss[0],"diff"]
643
- return outs,ls,ss

lora_block_weight.py (updated file)

1
  import cv2
2
  import os
3
  import re
 
8
  import gradio as gr
9
  import os.path
10
  import random
11
+ from pprint import pprint
12
  import modules.ui
13
  import modules.scripts as scripts
14
  from PIL import Image, ImageFont, ImageDraw
 
21
  lxyz = ""
22
  lzyx = ""
23
 
24
+ BLOCKS=["encoder",
25
+ "diffusion_model_input_blocks_0_",
26
+ "diffusion_model_input_blocks_1_",
27
+ "diffusion_model_input_blocks_2_",
28
+ "diffusion_model_input_blocks_3_",
29
+ "diffusion_model_input_blocks_4_",
30
+ "diffusion_model_input_blocks_5_",
31
+ "diffusion_model_input_blocks_6_",
32
+ "diffusion_model_input_blocks_7_",
33
+ "diffusion_model_input_blocks_8_",
34
+ "diffusion_model_input_blocks_9_",
35
+ "diffusion_model_input_blocks_10_",
36
+ "diffusion_model_input_blocks_11_",
37
+ "diffusion_model_middle_block_",
38
+ "diffusion_model_output_blocks_0_",
39
+ "diffusion_model_output_blocks_1_",
40
+ "diffusion_model_output_blocks_2_",
41
+ "diffusion_model_output_blocks_3_",
42
+ "diffusion_model_output_blocks_4_",
43
+ "diffusion_model_output_blocks_5_",
44
+ "diffusion_model_output_blocks_6_",
45
+ "diffusion_model_output_blocks_7_",
46
+ "diffusion_model_output_blocks_8_",
47
+ "diffusion_model_output_blocks_9_",
48
+ "diffusion_model_output_blocks_10_",
49
+ "diffusion_model_output_blocks_11_"]
50
 
51
  loopstopper = True
52
 
 
104
  with gr.Column(scale=5):
105
  bw_ratiotags= gr.TextArea(label="",lines=2,value=rasiostags,visible =True,interactive =True,elem_id="lbw_ratios")
106
  with gr.Accordion("XYZ plot",open = False):
107
+ gr.HTML(value="<p>changeable blocks : BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11</p>")
108
  xyzsetting = gr.Radio(label = "Active",choices = ["Disable","XYZ plot","Effective Block Analyzer"], value ="Disable",type = "index")
109
  with gr.Row(visible = False) as esets:
110
  diffcol = gr.Radio(label = "diff image color",choices = ["black","white"], value ="black",type = "value",interactive =True)
 
118
  zmen = gr.Textbox(label="Z values ",lines=1,value="",interactive =True,elem_id="lbw_zmen")
119
 
120
  exmen = gr.Textbox(label="Range",lines=1,value="0.5,1",interactive =True,elem_id="lbw_exmen",visible = False)
121
+ eymen = gr.Textbox(label="Blocks" ,lines=1,value="BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11",interactive =True,elem_id="lbw_eymen",visible = False)
122
 
123
  with gr.Accordion("Weights setting",open = True):
124
  with gr.Row():
 
147
  if ":" in l :
148
  key = l.split(":",1)[0]
149
  w = l.split(":",1)[1]
150
+ if len([w for w in w.split(",")]) == 17 or len([w for w in w.split(",")]) ==26:
151
  wdict[key.strip()]=w
152
  return ",".join(list(wdict.keys()))
153
 
 
220
  xmen,ymen = exmen,eymen
221
  xtype,ytype = "values","ID"
222
  ebase = xmen.split(",")[1]
223
+ ebase = [ebase.strip()]*26
224
  base = ",".join(ebase)
225
  ztype = ""
226
 
 
253
  images = []
254
 
255
  def weightsdealer(alpha,ids,base):
256
+ blockid17=["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
257
+ blockid26=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
258
  #print(f"weights from : {base}")
259
  ids = [z.strip() for z in ids.split(' ')]
260
  weights_t = [w.strip() for w in base.split(',')]
261
+ blockid = blockid17 if len(weights_t) ==17 else blockid26
262
  if ids[0]!="NOT":
263
+ flagger=[False]*len(weights_t)
264
  changer = True
265
  else:
266
+ flagger=[True]*len(weights_t)
267
  changer = False
268
  for id in ids:
269
  if id =="NOT":continue
 
368
  lorars = []
369
  for called in calledloras:
370
  if len(called.items) <3:continue
371
+ if called.items[2] in lratios or called.items[2].count(",") ==16 or called.items[2].count(",") ==25:
372
  lorans.append(called.items[0])
373
  wei = lratios[called.items[2]] if called.items[2] in lratios else called.items[2]
374
  multiple = called.items[1]
 
381
  else:
382
  ratios[i] = float(r)
383
  print(f"LoRA Block weight :{called.items[0]}: {ratios}")
384
+ if len(ratios)==17:
385
+ ratios = [ratios[0]] + [1] + ratios[1:3]+ [1] + ratios[3:5]+[1] + ratios[5:7]+[1,1,1] + [ratios[7]] + [1,1,1] + ratios[8:]
386
  lorars.append(ratios)
387
  if len(lorars) > 0: load_loras_blocks(lorans,lorars,multiple)
388
 
389
  re_digits = re.compile(r"\d+")
390
+
391
  re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
392
  re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
393
  re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
 
394
 
395
  re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
396
  re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
 
399
  re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
400
  re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")
401
 
402
+ re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
403
+
404
+
405
  def convert_diffusers_name_to_compvis(key):
406
  def match(match_list, regex):
407
  r = re.match(regex, key)
 
467
 
468
  return key
469
 
470
+ class FakeModule(torch.nn.Module):
471
+ def __init__(self, weight, func):
472
+ super().__init__()
473
+ self.weight = weight
474
+ self.func = func
475
+
476
+ def forward(self, x):
477
+ return self.func(x)
478
+
479
+
480
+ class LoraUpDownModule:
481
+ def __init__(self):
482
+ self.up_model = None
483
+ self.mid_model = None
484
+ self.down_model = None
485
+ self.alpha = None
486
+ self.dim = None
487
+ self.op = None
488
+ self.extra_args = {}
489
+ self.shape = None
490
+ self.bias = None
491
+ self.up = None
492
+
493
+ def down(self, x):
494
+ return x
495
+
496
+ def inference(self, x):
497
+ if hasattr(self, 'bias') and isinstance(self.bias, torch.Tensor):
498
+ out_dim = self.up_model.weight.size(0)
499
+ rank = self.down_model.weight.size(0)
500
+ rebuild_weight = (
501
+ self.up_model.weight.reshape(out_dim, -1) @ self.down_model.weight.reshape(rank, -1)
502
+ + self.bias
503
+ ).reshape(self.shape)
504
+ return self.op(
505
+ x, rebuild_weight,
506
+ **self.extra_args
507
+ )
508
+ else:
509
+ if self.mid_model is None:
510
+ return self.up_model(self.down_model(x))
511
+ else:
512
+ return self.up_model(self.mid_model(self.down_model(x)))
513
+
514
+
515
+ def pro3(t, wa, wb):
516
+ temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
517
+ return torch.einsum('i j k l, i r -> r j k l', temp, wa)
518
+
519
+
520
+ class LoraHadaModule:
521
+ def __init__(self):
522
+ self.t1 = None
523
+ self.w1a = None
524
+ self.w1b = None
525
+ self.t2 = None
526
+ self.w2a = None
527
+ self.w2b = None
528
+ self.alpha = None
529
+ self.dim = None
530
+ self.op = None
531
+ self.extra_args = {}
532
+ self.shape = None
533
+ self.bias = None
534
+ self.up = None
535
+
536
+ def down(self, x):
537
+ return x
538
+
539
+ def inference(self, x):
540
+ if hasattr(self, 'bias') and isinstance(self.bias, torch.Tensor):
541
+ bias = self.bias
542
+ else:
543
+ bias = 0
544
+
545
+ if self.t1 is None:
546
+ return self.op(
547
+ x,
548
+ ((self.w1a @ self.w1b) * (self.w2a @ self.w2b) + bias).view(self.shape),
549
+ **self.extra_args
550
+ )
551
+ else:
552
+ return self.op(
553
+ x,
554
+ (pro3(self.t1, self.w1a, self.w1b)
555
+ * pro3(self.t2, self.w2a, self.w2b) + bias).view(self.shape),
556
+ **self.extra_args
557
+ )
558
+
559
+
560
+ CON_KEY = {
561
+ "lora_up.weight",
562
+ "lora_down.weight",
563
+ "lora_mid.weight"
564
+ }
565
+ HADA_KEY = {
566
+ "hada_t1",
567
+ "hada_w1_a",
568
+ "hada_w1_b",
569
+ "hada_t2",
570
+ "hada_w2_a",
571
+ "hada_w2_b",
572
+ }
573
+
574
+
575
  def load_lora(name, filename,lwei):
576
+ import lora as lora_o
577
+ lora = lora_o.LoraModule(name)
578
+ lora.mtime = os.path.getmtime(filename)
579
 
580
  sd = sd_models.read_state_dict(filename)
581
 
582
  keys_failed_to_match = []
583
+ keys_failed_to_match_lbw = []
584
 
585
  for key_diffusers, weight in sd.items():
586
  ratio = 1
587
+ picked = False
588
 
 
589
  fullkey = convert_diffusers_name_to_compvis(key_diffusers)
 
590
  key, lora_key = fullkey.split(".", 1)
591
 
592
+ for i,block in enumerate(BLOCKS):
593
+ if block in key:
594
+ ratio = lwei[i]
595
+ picked = True
596
+
597
+ if not picked:keys_failed_to_match_lbw.append(key_diffusers)
598
+
599
+ weight = weight * math.sqrt(abs(ratio))
600
+
601
  sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
602
  if sd_module is None:
603
  keys_failed_to_match.append(key_diffusers)
604
  continue
605
 
606
+ lora_module = lora.modules.get(key, None)
607
  if lora_module is None:
608
+ lora_module = LoraUpDownModule()
609
+ lora.modules[key] = lora_module
610
 
611
  if lora_key == "alpha":
612
  lora_module.alpha = weight.item()
613
  continue
614
+
615
+ if 'bias_' in lora_key:
616
+ if lora_module.bias is None:
617
+ lora_module.bias = [None, None, None]
618
+ if 'bias_indices' == lora_key:
619
+ lora_module.bias[0] = weight
620
+ elif 'bias_values' == lora_key:
621
+ lora_module.bias[1] = weight
622
+ elif 'bias_size' == lora_key:
623
+ lora_module.bias[2] = weight
624
+
625
+ if all((i is not None) for i in lora_module.bias):
626
+ print('build bias')
627
+ lora_module.bias = torch.sparse_coo_tensor(
628
+ lora_module.bias[0],
629
+ lora_module.bias[1],
630
+ tuple(lora_module.bias[2]),
631
+ ).to(device=devices.device, dtype=devices.dtype)
632
+ lora_module.bias.requires_grad_(False)
633
+ continue
634
+
635
+ if lora_key in CON_KEY:
636
+ if type(sd_module) == torch.nn.Linear:
637
+ weight = weight.reshape(weight.shape[0], -1)
638
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
639
+ lora_module.op = torch.nn.functional.linear
640
+ elif type(sd_module) == torch.nn.Conv2d:
641
+ if lora_key == "lora_down.weight":
642
+ if weight.shape[2] != 1 or weight.shape[3] != 1:
643
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
644
+ else:
645
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
646
+ elif lora_key == "lora_mid.weight":
647
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
648
+ elif lora_key == "lora_up.weight":
649
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
650
+ lora_module.op = torch.nn.functional.conv2d
651
+ lora_module.extra_args = {
652
+ 'stride': sd_module.stride,
653
+ 'padding': sd_module.padding
654
+ }
655
+ else:
656
+ assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
657
+
658
+ lora_module.shape = sd_module.weight.shape
659
 
660
+ fugou = np.sign(ratio) if lora_key == "lora_up.weight" else 1
661
 
 
662
  with torch.no_grad():
663
  module.weight.copy_(weight*fugou)
664
 
665
+ module.to(device=devices.device, dtype=devices.dtype)
666
+ module.requires_grad_(False)
667
+
668
+ if lora_key == "lora_up.weight":
669
+ lora_module.up_model = module
670
+ lora_module.up = FakeModule(
671
+ lora_module.up_model.weight,
672
+ lora_module.inference
673
+ )
674
+ elif lora_key == "lora_mid.weight":
675
+ lora_module.mid_model = module
676
+ elif lora_key == "lora_down.weight":
677
+ lora_module.down_model = module
678
+ lora_module.dim = weight.shape[0]
679
+ elif lora_key in HADA_KEY:
680
+ if type(lora_module) != LoraHadaModule:
681
+ alpha = lora_module.alpha
682
+ bias = lora_module.bias
683
+ lora_module = LoraHadaModule()
684
+ lora_module.alpha = alpha
685
+ lora_module.bias = bias
686
+ lora.modules[key] = lora_module
687
+ lora_module.shape = sd_module.weight.shape
688
+
689
+ weight = weight.to(device=devices.device, dtype=devices.dtype)
690
+ weight.requires_grad_(False)
691
+
692
+ if lora_key == 'hada_w1_a':
693
+ lora_module.w1a = weight
694
+ if lora_module.up is None:
695
+ lora_module.up = FakeModule(
696
+ lora_module.w1a,
697
+ lora_module.inference
698
+ )
699
+ elif lora_key == 'hada_w1_b':
700
+ lora_module.w1b = weight
701
+ lora_module.dim = weight.shape[0]
702
+ elif lora_key == 'hada_w2_a':
703
+ lora_module.w2a = weight
704
+ elif lora_key == 'hada_w2_b':
705
+ lora_module.w2b = weight
706
+ elif lora_key == 'hada_t1':
707
+ lora_module.t1 = weight
708
+ lora_module.up = FakeModule(
709
+ lora_module.t1,
710
+ lora_module.inference
711
+ )
712
+ elif lora_key == 'hada_t2':
713
+ lora_module.t2 = weight
714
+
715
+ if type(sd_module) == torch.nn.Linear:
716
+ lora_module.op = torch.nn.functional.linear
717
+ elif type(sd_module) == torch.nn.Conv2d:
718
+ lora_module.op = torch.nn.functional.conv2d
719
+ lora_module.extra_args = {
720
+ 'stride': sd_module.stride,
721
+ 'padding': sd_module.padding
722
+ }
723
+ else:
724
+ assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
725
+
726
  else:
727
  assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
728
 
729
  if len(keys_failed_to_match) > 0:
730
  print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
731
 
732
+ if len(keys_failed_to_match_lbw) > 0:
733
+ print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match_lbw}")
734
+
735
+ return lora
736
 
737
 
738
  def load_loras_blocks(names, lwei=None,multi=1.0):
 
853
  else:
854
  outs = [imgs[0]]*len(diffs) + imgs[1:]+ diffs
855
  ss = ["source",ss[0],"diff"]
856
+ return outs,ls,ss
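
For reference, a worked sketch (not code from this commit; blockid26 and expand_17_to_26 are illustrative names) of how a 17-value preset maps onto the new 26-block layout. The slicing mirrors the len(ratios)==17 branch above: blocks the old scheme could not address are filled with a neutral weight of 1.

# Sketch: expand a 17-element weight list (BASE,IN01,IN02,IN04,IN05,IN07,IN08,
# M00,OUT03..OUT11) to the 26-element layout (BASE,IN00..IN11,M00,OUT00..OUT11).
blockid26 = ["BASE"] + [f"IN{i:02}" for i in range(12)] + ["M00"] + [f"OUT{i:02}" for i in range(12)]

def expand_17_to_26(r):
    assert len(r) == 17
    return ([r[0]] + [1] + r[1:3] + [1] + r[3:5] + [1] + r[5:7]
            + [1, 1, 1] + [r[7]] + [1, 1, 1] + r[8:])

weights17 = [0.5] * 17
for name, w in zip(blockid26, expand_17_to_26(weights17)):
    print(name, w)  # IN00, IN03, IN06, IN09-IN11 and OUT00-OUT02 come out as 1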