Upload lora_block_weight.py

lora_block_weight.py  (+277, -64)  CHANGED
@@ -1,4 +1,3 @@
-
import cv2
import os
import re
@@ -9,6 +8,7 @@ import numpy as np
import gradio as gr
import os.path
import random
+ from pprint import pprint
import modules.ui
import modules.scripts as scripts
from PIL import Image, ImageFont, ImageDraw
@@ -21,23 +21,32 @@ from modules.processing import process_images, Processed
lxyz = ""
lzyx = ""

- [17 removed lines (old 17-entry block list); their content was not captured in this view]
+ BLOCKS=["encoder",
+ "diffusion_model_input_blocks_0_",
+ "diffusion_model_input_blocks_1_",
+ "diffusion_model_input_blocks_2_",
+ "diffusion_model_input_blocks_3_",
+ "diffusion_model_input_blocks_4_",
+ "diffusion_model_input_blocks_5_",
+ "diffusion_model_input_blocks_6_",
+ "diffusion_model_input_blocks_7_",
+ "diffusion_model_input_blocks_8_",
+ "diffusion_model_input_blocks_9_",
+ "diffusion_model_input_blocks_10_",
+ "diffusion_model_input_blocks_11_",
+ "diffusion_model_middle_block_",
+ "diffusion_model_output_blocks_0_",
+ "diffusion_model_output_blocks_1_",
+ "diffusion_model_output_blocks_2_",
+ "diffusion_model_output_blocks_3_",
+ "diffusion_model_output_blocks_4_",
+ "diffusion_model_output_blocks_5_",
+ "diffusion_model_output_blocks_6_",
+ "diffusion_model_output_blocks_7_",
+ "diffusion_model_output_blocks_8_",
+ "diffusion_model_output_blocks_9_",
+ "diffusion_model_output_blocks_10_",
+ "diffusion_model_output_blocks_11_"]

loopstopper = True

@@ -95,7 +104,7 @@ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
with gr.Column(scale=5):
bw_ratiotags= gr.TextArea(label="",lines=2,value=rasiostags,visible =True,interactive =True,elem_id="lbw_ratios")
with gr.Accordion("XYZ plot",open = False):
- gr.HTML(value="<p>changeable blocks : BASE,IN01,IN02,IN04,IN05,IN07,IN08,M00,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11</p>")
+ gr.HTML(value="<p>changeable blocks : BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11</p>")
xyzsetting = gr.Radio(label = "Active",choices = ["Disable","XYZ plot","Effective Block Analyzer"], value ="Disable",type = "index")
with gr.Row(visible = False) as esets:
diffcol = gr.Radio(label = "diff image color",choices = ["black","white"], value ="black",type = "value",interactive =True)
@@ -109,7 +118,7 @@ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
zmen = gr.Textbox(label="Z values ",lines=1,value="",interactive =True,elem_id="lbw_zmen")

exmen = gr.Textbox(label="Range",lines=1,value="0.5,1",interactive =True,elem_id="lbw_exmen",visible = False)
- eymen = gr.Textbox(label="Blocks" ,lines=1,value="BASE,IN01,IN02,IN04,IN05,IN07,IN08,M00,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11",interactive =True,elem_id="lbw_eymen",visible = False)
+ eymen = gr.Textbox(label="Blocks" ,lines=1,value="BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11",interactive =True,elem_id="lbw_eymen",visible = False)

with gr.Accordion("Weights setting",open = True):
with gr.Row():
@@ -138,7 +147,7 @@ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
if ":" in l :
key = l.split(":",1)[0]
w = l.split(":",1)[1]
- if len([w for w in w.split(",")]) == 17:
+ if len([w for w in w.split(",")]) == 17 or len([w for w in w.split(",")]) ==26:
wdict[key.strip()]=w
return ",".join(list(wdict.keys()))

@@ -211,7 +220,7 @@ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
xmen,ymen = exmen,eymen
xtype,ytype = "values","ID"
ebase = xmen.split(",")[1]
- ebase = [ebase.strip()]*
+ ebase = [ebase.strip()]*26
base = ",".join(ebase)
ztype = ""

@@ -244,15 +253,17 @@ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
images = []

def weightsdealer(alpha,ids,base):
-
+ blockid17=["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
+ blockid26=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
#print(f"weights from : {base}")
ids = [z.strip() for z in ids.split(' ')]
weights_t = [w.strip() for w in base.split(',')]
+ blockid = blockid17 if len(weights_t) ==17 else blockid26
if ids[0]!="NOT":
- flagger=[False]*
+ flagger=[False]*len(weights_t)
changer = True
else:
- flagger=[True]*
+ flagger=[True]*len(weights_t)
changer = False
for id in ids:
if id =="NOT":continue
@@ -357,7 +368,7 @@ def loradealer(p,lratios):
lorars = []
for called in calledloras:
if len(called.items) <3:continue
- if called.items[2] in lratios or called.items[2].count(",") ==16:
+ if called.items[2] in lratios or called.items[2].count(",") ==16 or called.items[2].count(",") ==25:
lorans.append(called.items[0])
wei = lratios[called.items[2]] if called.items[2] in lratios else called.items[2]
multiple = called.items[1]
@@ -370,14 +381,16 @@ def loradealer(p,lratios):
else:
ratios[i] = float(r)
print(f"LoRA Block weight :{called.items[0]}: {ratios}")
+ if len(ratios)==17:
+ ratios = [ratios[0]] + [1] + ratios[1:3]+ [1] + ratios[3:5]+[1] + ratios[5:7]+[1,1,1] + [ratios[7]] + [1,1,1] + ratios[8:]
lorars.append(ratios)
if len(lorars) > 0: load_loras_blocks(lorans,lorars,multiple)

re_digits = re.compile(r"\d+")
+
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
- re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")

re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
@@ -386,6 +399,9 @@ re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+
re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")

+ re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
+
+
def convert_diffusers_name_to_compvis(key):
def match(match_list, regex):
r = re.match(regex, key)
@@ -451,75 +467,272 @@ def convert_diffusers_name_to_compvis(key):

return key

+ class FakeModule(torch.nn.Module):
+ def __init__(self, weight, func):
+ super().__init__()
+ self.weight = weight
+ self.func = func
+
+ def forward(self, x):
+ return self.func(x)
+
+
+ class LoraUpDownModule:
+ def __init__(self):
+ self.up_model = None
+ self.mid_model = None
+ self.down_model = None
+ self.alpha = None
+ self.dim = None
+ self.op = None
+ self.extra_args = {}
+ self.shape = None
+ self.bias = None
+ self.up = None
+
+ def down(self, x):
+ return x
+
+ def inference(self, x):
+ if hasattr(self, 'bias') and isinstance(self.bias, torch.Tensor):
+ out_dim = self.up_model.weight.size(0)
+ rank = self.down_model.weight.size(0)
+ rebuild_weight = (
+ self.up_model.weight.reshape(out_dim, -1) @ self.down_model.weight.reshape(rank, -1)
+ + self.bias
+ ).reshape(self.shape)
+ return self.op(
+ x, rebuild_weight,
+ **self.extra_args
+ )
+ else:
+ if self.mid_model is None:
+ return self.up_model(self.down_model(x))
+ else:
+ return self.up_model(self.mid_model(self.down_model(x)))
+
+
+ def pro3(t, wa, wb):
+ temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
+ return torch.einsum('i j k l, i r -> r j k l', temp, wa)
+
+
+ class LoraHadaModule:
+ def __init__(self):
+ self.t1 = None
+ self.w1a = None
+ self.w1b = None
+ self.t2 = None
+ self.w2a = None
+ self.w2b = None
+ self.alpha = None
+ self.dim = None
+ self.op = None
+ self.extra_args = {}
+ self.shape = None
+ self.bias = None
+ self.up = None
+
+ def down(self, x):
+ return x
+
+ def inference(self, x):
+ if hasattr(self, 'bias') and isinstance(self.bias, torch.Tensor):
+ bias = self.bias
+ else:
+ bias = 0
+
+ if self.t1 is None:
+ return self.op(
+ x,
+ ((self.w1a @ self.w1b) * (self.w2a @ self.w2b) + bias).view(self.shape),
+ **self.extra_args
+ )
+ else:
+ return self.op(
+ x,
+ (pro3(self.t1, self.w1a, self.w1b)
+ * pro3(self.t2, self.w2a, self.w2b) + bias).view(self.shape),
+ **self.extra_args
+ )
+
+
+ CON_KEY = {
+ "lora_up.weight",
+ "lora_down.weight",
+ "lora_mid.weight"
+ }
+ HADA_KEY = {
+ "hada_t1",
+ "hada_w1_a",
+ "hada_w1_b",
+ "hada_t2",
+ "hada_w2_a",
+ "hada_w2_b",
+ }
+
+
def load_lora(name, filename,lwei):
- import lora
-
-
+ import lora as lora_o
+ lora = lora_o.LoraModule(name)
+ lora.mtime = os.path.getmtime(filename)

sd = sd_models.read_state_dict(filename)

keys_failed_to_match = []
+ keys_failed_to_match_lbw = []

for key_diffusers, weight in sd.items():
ratio = 1
+ picked = False

- for i,block in enumerate(LORABLOCKS):
- if block in key_diffusers:
- ratio = lwei[i]
-
- weight =weight *math.sqrt(abs(ratio))
-
fullkey = convert_diffusers_name_to_compvis(key_diffusers)
- #print(key_diffusers+":"+fullkey+"x" + str(ratio))
key, lora_key = fullkey.split(".", 1)

+ for i,block in enumerate(BLOCKS):
+ if block in key:
+ ratio = lwei[i]
+ picked = True
+
+ if not picked:keys_failed_to_match_lbw.append(key_diffusers)
+
+ weight = weight * math.sqrt(abs(ratio))
+
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
if sd_module is None:
keys_failed_to_match.append(key_diffusers)
continue

- lora_module =
+ lora_module = lora.modules.get(key, None)
if lora_module is None:
- lora_module =
-
+ lora_module = LoraUpDownModule()
+ lora.modules[key] = lora_module

if lora_key == "alpha":
lora_module.alpha = weight.item()
continue
+
+ if 'bias_' in lora_key:
+ if lora_module.bias is None:
+ lora_module.bias = [None, None, None]
+ if 'bias_indices' == lora_key:
+ lora_module.bias[0] = weight
+ elif 'bias_values' == lora_key:
+ lora_module.bias[1] = weight
+ elif 'bias_size' == lora_key:
+ lora_module.bias[2] = weight
+
+ if all((i is not None) for i in lora_module.bias):
+ print('build bias')
+ lora_module.bias = torch.sparse_coo_tensor(
+ lora_module.bias[0],
+ lora_module.bias[1],
+ tuple(lora_module.bias[2]),
+ ).to(device=devices.device, dtype=devices.dtype)
+ lora_module.bias.requires_grad_(False)
+ continue
+
+ if lora_key in CON_KEY:
+ if type(sd_module) == torch.nn.Linear:
+ weight = weight.reshape(weight.shape[0], -1)
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+ lora_module.op = torch.nn.functional.linear
+ elif type(sd_module) == torch.nn.Conv2d:
+ if lora_key == "lora_down.weight":
+ if weight.shape[2] != 1 or weight.shape[3] != 1:
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
+ else:
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ elif lora_key == "lora_mid.weight":
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
+ elif lora_key == "lora_up.weight":
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ lora_module.op = torch.nn.functional.conv2d
+ lora_module.extra_args = {
+ 'stride': sd_module.stride,
+ 'padding': sd_module.padding
+ }
+ else:
+ assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+
+ lora_module.shape = sd_module.weight.shape

-
- weight = weight.reshape(weight.shape[0], -1)
- module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
- elif type(sd_module) == torch.nn.Conv2d:
- if lora_key == "lora_down.weight":
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
- elif lora_key == "lora_up.weight":
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
- else:
- assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
-
- fugou = 1 if ratio >0 else -1
+ fugou = np.sign(ratio) if lora_key == "lora_up.weight" else 1

- if lora_key == "lora_up.weight":
with torch.no_grad():
module.weight.copy_(weight*fugou)
- else:
- with torch.no_grad():
- module.weight.copy_(weight)
-
- module.to(device=devices.device, dtype=devices.dtype)

-
-
-
-
+ module.to(device=devices.device, dtype=devices.dtype)
+ module.requires_grad_(False)
+
+ if lora_key == "lora_up.weight":
+ lora_module.up_model = module
+ lora_module.up = FakeModule(
+ lora_module.up_model.weight,
+ lora_module.inference
+ )
+ elif lora_key == "lora_mid.weight":
+ lora_module.mid_model = module
+ elif lora_key == "lora_down.weight":
+ lora_module.down_model = module
+ lora_module.dim = weight.shape[0]
+ elif lora_key in HADA_KEY:
+ if type(lora_module) != LoraHadaModule:
+ alpha = lora_module.alpha
+ bias = lora_module.bias
+ lora_module = LoraHadaModule()
+ lora_module.alpha = alpha
+ lora_module.bias = bias
+ lora.modules[key] = lora_module
+ lora_module.shape = sd_module.weight.shape
+
+ weight = weight.to(device=devices.device, dtype=devices.dtype)
+ weight.requires_grad_(False)
+
+ if lora_key == 'hada_w1_a':
+ lora_module.w1a = weight
+ if lora_module.up is None:
+ lora_module.up = FakeModule(
+ lora_module.w1a,
+ lora_module.inference
+ )
+ elif lora_key == 'hada_w1_b':
+ lora_module.w1b = weight
+ lora_module.dim = weight.shape[0]
+ elif lora_key == 'hada_w2_a':
+ lora_module.w2a = weight
+ elif lora_key == 'hada_w2_b':
+ lora_module.w2b = weight
+ elif lora_key == 'hada_t1':
+ lora_module.t1 = weight
+ lora_module.up = FakeModule(
+ lora_module.t1,
+ lora_module.inference
+ )
+ elif lora_key == 'hada_t2':
+ lora_module.t2 = weight
+
+ if type(sd_module) == torch.nn.Linear:
+ lora_module.op = torch.nn.functional.linear
+ elif type(sd_module) == torch.nn.Conv2d:
+ lora_module.op = torch.nn.functional.conv2d
+ lora_module.extra_args = {
+ 'stride': sd_module.stride,
+ 'padding': sd_module.padding
+ }
+ else:
+ assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+
else:
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'

if len(keys_failed_to_match) > 0:
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")

-
+ if len(keys_failed_to_match_lbw) > 0:
+ print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match_lbw}")
+
+ return lora


def load_loras_blocks(names, lwei=None,multi=1.0):
@@ -640,4 +853,4 @@ def effectivechecker(imgs,ss,ls,diffcol,thresh,revxy):
else:
outs = [imgs[0]]*len(diffs) + imgs[1:]+ diffs
ss = ["source",ss[0],"diff"]
- return outs,ls,ss
+ return outs,ls,ss
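
For reference, a minimal sketch (not part of the commit) of the 17-to-26 ratio expansion added in loradealer() above. It uses the same expression and the blockid17/blockid26 lists from weightsdealer(); the function name expand17to26 and the asserts are illustrative only.

BLOCKID17 = ["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00",
             "OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
BLOCKID26 = ["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08",
             "IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04",
             "OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]

def expand17to26(ratios):
    # Same expression as in loradealer(): blocks that have no entry in the
    # 17-block scheme (IN00, IN03, IN06, IN09-IN11, OUT00-OUT02) default to 1.
    return ([ratios[0]] + [1] + ratios[1:3] + [1] + ratios[3:5] + [1]
            + ratios[5:7] + [1, 1, 1] + [ratios[7]] + [1, 1, 1] + ratios[8:])

assert len(BLOCKID17) == 17 and len(BLOCKID26) == 26
r26 = expand17to26([0.5] * 17)
assert len(r26) == 26
assert dict(zip(BLOCKID26, r26))["IN03"] == 1     # filler block defaults to 1
assert dict(zip(BLOCKID26, r26))["OUT11"] == 0.5  # original preset value kept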
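And a small illustrative sketch (also not part of the commit) of the scaling convention in the new load_lora(): every LoRA tensor is multiplied by sqrt(|ratio|), and the sign (fugou = np.sign(ratio)) is applied only to lora_up.weight, so the product of the two stored matrices picks up a factor of exactly ratio, negative values included. The variable names below are illustrative; only numpy is assumed.

import math
import numpy as np

rng = np.random.default_rng(0)
up = rng.standard_normal((8, 4))     # stands in for lora_up.weight
down = rng.standard_normal((4, 16))  # stands in for lora_down.weight
ratio = -0.5

scale = math.sqrt(abs(ratio))
up_scaled = up * scale * np.sign(ratio)   # sign goes to the up weight only
down_scaled = down * scale

# up_scaled @ down_scaled == ratio * (up @ down), sign included
assert np.allclose(up_scaled @ down_scaled, ratio * (up @ down))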