supertori commited on
Commit
a55caa3
·
1 Parent(s): 52b9cbb

Upload 2 files

Browse files
Files changed (2) hide show
  1. cldm_v15.yaml +79 -0
  2. lora_block_weight.py +643 -0
cldm_v15.yaml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: cldm.cldm.ControlLDM
3
+ params:
4
+ linear_start: 0.00085
5
+ linear_end: 0.0120
6
+ num_timesteps_cond: 1
7
+ log_every_t: 200
8
+ timesteps: 1000
9
+ first_stage_key: "jpg"
10
+ cond_stage_key: "txt"
11
+ control_key: "hint"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+ only_mid_control: False
20
+
21
+ control_stage_config:
22
+ target: cldm.cldm.ControlNet
23
+ params:
24
+ image_size: 32 # unused
25
+ in_channels: 4
26
+ hint_channels: 3
27
+ model_channels: 320
28
+ attention_resolutions: [ 4, 2, 1 ]
29
+ num_res_blocks: 2
30
+ channel_mult: [ 1, 2, 4, 4 ]
31
+ num_heads: 8
32
+ use_spatial_transformer: True
33
+ transformer_depth: 1
34
+ context_dim: 768
35
+ use_checkpoint: True
36
+ legacy: False
37
+
38
+ unet_config:
39
+ target: cldm.cldm.ControlledUnetModel
40
+ params:
41
+ image_size: 32 # unused
42
+ in_channels: 4
43
+ out_channels: 4
44
+ model_channels: 320
45
+ attention_resolutions: [ 4, 2, 1 ]
46
+ num_res_blocks: 2
47
+ channel_mult: [ 1, 2, 4, 4 ]
48
+ num_heads: 8
49
+ use_spatial_transformer: True
50
+ transformer_depth: 1
51
+ context_dim: 768
52
+ use_checkpoint: True
53
+ legacy: False
54
+
55
+ first_stage_config:
56
+ target: ldm.models.autoencoder.AutoencoderKL
57
+ params:
58
+ embed_dim: 4
59
+ monitor: val/rec_loss
60
+ ddconfig:
61
+ double_z: true
62
+ z_channels: 4
63
+ resolution: 256
64
+ in_channels: 3
65
+ out_ch: 3
66
+ ch: 128
67
+ ch_mult:
68
+ - 1
69
+ - 2
70
+ - 4
71
+ - 4
72
+ num_res_blocks: 2
73
+ attn_resolutions: []
74
+ dropout: 0.0
75
+ lossconfig:
76
+ target: torch.nn.Identity
77
+
78
+ cond_stage_config:
79
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
lora_block_weight.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import cv2
3
+ import os
4
+ import re
5
+ import torch
6
+ import shutil
7
+ import math
8
+ import numpy as np
9
+ import gradio as gr
10
+ import os.path
11
+ import random
12
+ import modules.ui
13
+ import modules.scripts as scripts
14
+ from PIL import Image, ImageFont, ImageDraw
15
+ from fonts.ttf import Roboto
16
+ import modules.shared as shared
17
+ from modules import devices, sd_models, images,extra_networks
18
+ from modules.shared import opts, state
19
+ from modules.processing import process_images, Processed
20
+
21
+ lxyz = ""
22
+ lzyx = ""
23
+
24
+ LORABLOCKS=["encoder",
25
+ "down_blocks_0_attentions_0",
26
+ "down_blocks_0_attentions_1",
27
+ "down_blocks_1_attentions_0",
28
+ "down_blocks_1_attentions_1",
29
+ "down_blocks_2_attentions_0",
30
+ "down_blocks_2_attentions_1",
31
+ "mid_block_attentions_0",
32
+ "up_blocks_1_attentions_0",
33
+ "up_blocks_1_attentions_1",
34
+ "up_blocks_1_attentions_2",
35
+ "up_blocks_2_attentions_0",
36
+ "up_blocks_2_attentions_1",
37
+ "up_blocks_2_attentions_2",
38
+ "up_blocks_3_attentions_0",
39
+ "up_blocks_3_attentions_1",
40
+ "up_blocks_3_attentions_2"]
41
+
42
+ loopstopper = True
43
+
44
+ ATYPES =["none","Block ID","values","seed","Original Weights"]
45
+
46
+ class Script(modules.scripts.Script):
47
+ def title(self):
48
+ return "LoRA Block Weight"
49
+
50
+ def show(self, is_img2img):
51
+ return modules.scripts.AlwaysVisible
52
+
53
+ def ui(self, is_img2img):
54
+ import lora
55
+ LWEIGHTSPRESETS="\
56
+ NONE:0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
57
+ ALL:1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\n\
58
+ INS:1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
59
+ IND:1,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
60
+ INALL:1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
61
+ MIDD:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0\n\
62
+ OUTD:1,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0\n\
63
+ OUTS:1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1\n\
64
+ OUTALL:1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1\n\
65
+ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
66
+
67
+ runorigin = scripts.scripts_txt2img.run
68
+ runorigini = scripts.scripts_img2img.run
69
+
70
+ path_root = scripts.basedir()
71
+ extpath = os.path.join(path_root,"extensions","sd-webui-lora-block-weight","scripts", "lbwpresets.txt")
72
+ filepath = os.path.join(path_root,"scripts", "lbwpresets.txt")
73
+
74
+ if os.path.isfile(extpath) and not os.path.isfile(filepath):
75
+ shutil.move(extpath,filepath)
76
+
77
+ lbwpresets=""
78
+ try:
79
+ with open(filepath) as f:
80
+ lbwpresets = f.read()
81
+ except OSError as e:
82
+ lbwpresets=LWEIGHTSPRESETS
83
+
84
+ loraratios=lbwpresets.splitlines()
85
+ lratios={}
86
+ for i,l in enumerate(loraratios):
87
+ lratios[l.split(":")[0]]=l.split(":")[1]
88
+ rasiostags = [k for k in lratios.keys()]
89
+ rasiostags = ",".join(rasiostags)
90
+
91
+ with gr.Accordion("LoRA Block Weight",open = False):
92
+ with gr.Row():
93
+ with gr.Column(min_width = 50, scale=1):
94
+ lbw_useblocks = gr.Checkbox(value = True,label="Active",interactive =True,elem_id="lbw_active")
95
+ with gr.Column(scale=5):
96
+ bw_ratiotags= gr.TextArea(label="",lines=2,value=rasiostags,visible =True,interactive =True,elem_id="lbw_ratios")
97
+ with gr.Accordion("XYZ plot",open = False):
98
+ gr.HTML(value="<p>changeable blocks : BASE,IN01,IN02,IN04,IN05,IN07,IN08,M00,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11</p>")
99
+ xyzsetting = gr.Radio(label = "Active",choices = ["Disable","XYZ plot","Effective Block Analyzer"], value ="Disable",type = "index")
100
+ with gr.Row(visible = False) as esets:
101
+ diffcol = gr.Radio(label = "diff image color",choices = ["black","white"], value ="black",type = "value",interactive =True)
102
+ revxy = gr.Checkbox(value = False,label="change X-Y",interactive =True,elem_id="lbw_changexy")
103
+ thresh = gr.Textbox(label="difference threshold",lines=1,value="20",interactive =True,elem_id="diff_thr")
104
+ xtype = gr.Dropdown(label="X Types ", choices=[x for x in ATYPES], value=ATYPES [2],interactive =True,elem_id="lbw_xtype")
105
+ xmen = gr.Textbox(label="X Values ",lines=1,value="0,0.25,0.5,0.75,1",interactive =True,elem_id="lbw_xmen")
106
+ ytype = gr.Dropdown(label="Y Types ", choices=[y for y in ATYPES], value=ATYPES [1],interactive =True,elem_id="lbw_ytype")
107
+ ymen = gr.Textbox(label="Y Values " ,lines=1,value="IN05-OUT05",interactive =True,elem_id="lbw_ymen")
108
+ ztype = gr.Dropdown(label="Z type ", choices=[z for z in ATYPES], value=ATYPES[0],interactive =True,elem_id="lbw_ztype")
109
+ zmen = gr.Textbox(label="Z values ",lines=1,value="",interactive =True,elem_id="lbw_zmen")
110
+
111
+ exmen = gr.Textbox(label="Range",lines=1,value="0.5,1",interactive =True,elem_id="lbw_exmen",visible = False)
112
+ eymen = gr.Textbox(label="Blocks" ,lines=1,value="BASE,IN01,IN02,IN04,IN05,IN07,IN08,M00,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11",interactive =True,elem_id="lbw_eymen",visible = False)
113
+
114
+ with gr.Accordion("Weights setting",open = True):
115
+ with gr.Row():
116
+ reloadtext = gr.Button(value="Reload Presets",variant='primary',elem_id="lbw_reload")
117
+ reloadtags = gr.Button(value="Reload Tags",variant='primary',elem_id="lbw_reload")
118
+ savetext = gr.Button(value="Save Presets",variant='primary',elem_id="lbw_savetext")
119
+ openeditor = gr.Button(value="Open TextEditor",variant='primary',elem_id="lbw_openeditor")
120
+ lbw_loraratios = gr.TextArea(label="",value=lbwpresets,visible =True,interactive = True,elem_id="lbw_ratiospreset")
121
+
122
+ import subprocess
123
+ def openeditors():
124
+ subprocess.Popen(['start', filepath], shell=True)
125
+
126
+ def reloadpresets():
127
+ try:
128
+ with open(filepath) as f:
129
+ return f.read()
130
+ except OSError as e:
131
+ pass
132
+
133
+ def tagdicter(presets):
134
+ presets=presets.splitlines()
135
+ wdict={}
136
+ for l in presets:
137
+ w=[]
138
+ if ":" in l :
139
+ key = l.split(":",1)[0]
140
+ w = l.split(":",1)[1]
141
+ if len([w for w in w.split(",")]) == 17:
142
+ wdict[key.strip()]=w
143
+ return ",".join(list(wdict.keys()))
144
+
145
+ def savepresets(text):
146
+ with open(filepath,mode = 'w') as f:
147
+ f.write(text)
148
+
149
+ reloadtext.click(fn=reloadpresets,inputs=[],outputs=[lbw_loraratios])
150
+ reloadtags.click(fn=tagdicter,inputs=[lbw_loraratios],outputs=[bw_ratiotags])
151
+ savetext.click(fn=savepresets,inputs=[lbw_loraratios],outputs=[])
152
+ openeditor.click(fn=openeditors,inputs=[],outputs=[])
153
+
154
+
155
+ def urawaza(active):
156
+ if active > 0:
157
+ for obj in scripts.scripts_txt2img.alwayson_scripts:
158
+ if "lora_block_weight" in obj.filename:
159
+ scripts.scripts_txt2img.selectable_scripts.append(obj)
160
+ scripts.scripts_txt2img.titles.append("LoRA Block Weight")
161
+ for obj in scripts.scripts_img2img.alwayson_scripts:
162
+ if "lora_block_weight" in obj.filename:
163
+ scripts.scripts_img2img.selectable_scripts.append(obj)
164
+ scripts.scripts_img2img.titles.append("LoRA Block Weight")
165
+ scripts.scripts_txt2img.run = newrun
166
+ scripts.scripts_img2img.run = newrun
167
+ if active == 1:return [*[gr.update(visible = True) for x in range(6)],*[gr.update(visible = False) for x in range(3)]]
168
+ else:return [*[gr.update(visible = False) for x in range(6)],*[gr.update(visible = True) for x in range(3)]]
169
+ else:
170
+ scripts.scripts_txt2img.run = runorigin
171
+ scripts.scripts_img2img.run = runorigini
172
+ return [*[gr.update(visible = True) for x in range(6)],*[gr.update(visible = False) for x in range(3)]]
173
+
174
+ xyzsetting.change(fn=urawaza,inputs=[xyzsetting],outputs =[xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,esets])
175
+
176
+ return lbw_loraratios,lbw_useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,diffcol,thresh,revxy
177
+
178
+ def process(self, p, loraratios,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,diffcol,thresh,revxy):
179
+ #print("self =",self,"p =",p,"presets =",loraratios,"useblocks =",useblocks,"xyzsettings =",xyzsetting,"xtype =",xtype,"xmen =",xmen,"ytype =",ytype,"ymen =",ymen,"ztype =",ztype,"zmen =",zmen)
180
+
181
+ if useblocks:
182
+ loraratios=loraratios.splitlines()
183
+ lratios={}
184
+ for l in loraratios:
185
+ l0=l.split(":",1)[0]
186
+ lratios[l0.strip()]=l.split(":",1)[1]
187
+ if xyzsetting and "XYZ" in p.prompt:
188
+ lratios["XYZ"] = lxyz
189
+ lratios["ZYX"] = lzyx
190
+ loradealer(p,lratios)
191
+ return
192
+
193
+ def postprocess(self, p, processed, *args):
194
+ import lora
195
+ lora.loaded_loras.clear()
196
+
197
+ def run(self,p,presets,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,diffcol,thresh,revxy):
198
+ if xyzsetting >0:
199
+ import lora
200
+ loraratios=presets.splitlines()
201
+ lratios={}
202
+ for l in loraratios:
203
+ l0=l.split(":",1)[0]
204
+ lratios[l0.strip()]=l.split(":",1)[1]
205
+
206
+ if "XYZ" in p.prompt:
207
+ base = lratios["XYZ"] if "XYZ" in lratios.keys() else "1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1"
208
+ else: return
209
+
210
+ if xyzsetting > 1:
211
+ xmen,ymen = exmen,eymen
212
+ xtype,ytype = "values","ID"
213
+ ebase = xmen.split(",")[1]
214
+ ebase = [ebase.strip()]*17
215
+ base = ",".join(ebase)
216
+ ztype = ""
217
+
218
+ #ATYPES =["none","Block ID","values","seed","Base Weights"]
219
+
220
+ def dicedealer(am):
221
+ for i,a in enumerate(am):
222
+ if a =="-1": am[i] = str(random.randrange(4294967294))
223
+ print(f"the die was thrown : {am}")
224
+
225
+ if p.seed == -1: p.seed = str(random.randrange(4294967294))
226
+
227
+ #print(f"xs:{xmen},ys:{ymen},zs:{zmen}")
228
+
229
+ def adjuster(a,at):
230
+ if "none" in at:a = ""
231
+ a = [a.strip() for a in a.split(',')]
232
+ if "seed" in at:dicedealer(a)
233
+ return a
234
+
235
+ xs = adjuster(xmen,xtype)
236
+ ys = adjuster(ymen,ytype)
237
+ zs = adjuster(zmen,ztype)
238
+
239
+ ids = alpha =seed = ""
240
+ p.batch_size = 1
241
+
242
+ print(f"xs:{xs},ys:{ys},zs:{zs}")
243
+
244
+ images = []
245
+
246
+ def weightsdealer(alpha,ids,base):
247
+ blockid=["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
248
+ #print(f"weights from : {base}")
249
+ ids = [z.strip() for z in ids.split(' ')]
250
+ weights_t = [w.strip() for w in base.split(',')]
251
+ if ids[0]!="NOT":
252
+ flagger=[False]*17
253
+ changer = True
254
+ else:
255
+ flagger=[True]*17
256
+ changer = False
257
+ for id in ids:
258
+ if id =="NOT":continue
259
+ if "-" in id:
260
+ it = [it.strip() for it in id.split('-')]
261
+ if blockid.index(it[1]) > blockid.index(it[0]):
262
+ flagger[blockid.index(it[0]):blockid.index(it[1])+1] = [changer]*(blockid.index(it[1])-blockid.index(it[0])+1)
263
+ else:
264
+ flagger[blockid.index(it[1]):blockid.index(it[0])+1] = [changer]*(blockid.index(it[0])-blockid.index(it[1])+1)
265
+ else:
266
+ flagger[blockid.index(id)] =changer
267
+ for i,f in enumerate(flagger):
268
+ if f:weights_t[i]=alpha
269
+ outext = ",".join(weights_t)
270
+ #print(f"weights changed: {outext}")
271
+ return outext
272
+
273
+ def xyzdealer(a,at):
274
+ nonlocal ids,alpha,p,base,c_base
275
+ if "ID" in at:return
276
+ if "values" in at:alpha = a
277
+ if "seed" in at:
278
+ p.seed = int(a)
279
+ if "Weights" in at:base =c_base = lratios[a]
280
+
281
+ grids = []
282
+ images =[]
283
+
284
+ totalcount = len(xs)*len(ys)*len(zs) if xyzsetting < 2 else len(xs)*len(ys)*len(zs) //2 +1
285
+ shared.total_tqdm.updateTotal(totalcount)
286
+ xc = yc =zc = 0
287
+ state.job_count = totalcount
288
+ totalcount = len(xs)*len(ys)*len(zs)
289
+
290
+ for z in zs:
291
+ images = []
292
+ yc = 0
293
+ xyzdealer(z,ztype)
294
+ for y in ys:
295
+ xc = 0
296
+ xyzdealer(y,ytype)
297
+ for x in xs:
298
+ xyzdealer(x,xtype)
299
+ if "ID" in xtype:
300
+ if "values" in ytype:c_base = weightsdealer(y,x,base)
301
+ if "values" in ztype:c_base = weightsdealer(z,x,base)
302
+ if "ID" in ytype:
303
+ if "values" in xtype:c_base = weightsdealer(x,y,base)
304
+ if "values" in ztype:c_base = weightsdealer(z,y,base)
305
+ if "ID" in ztype:
306
+ if "values" in xtype:c_base = weightsdealer(x,z,base)
307
+ if "values" in ytype:c_base = weightsdealer(y,z,base)
308
+
309
+ print(f"X:{xtype}, {x},Y: {ytype},{y}, Z:{ztype},{z}, base:{c_base} ({len(xs)*len(ys)*zc + yc*len(xs) +xc +1}/{totalcount})")
310
+
311
+ global lxyz,lzyx
312
+ lxyz = c_base
313
+
314
+ cr_base = c_base.split(",")
315
+ cr_base_t=[]
316
+ for x in cr_base:
317
+ if x != "R" and x != "U":
318
+ cr_base_t.append(str(1-float(x)))
319
+ else:
320
+ cr_base_t.append(x)
321
+ lzyx = ",".join(cr_base_t)
322
+
323
+ if not(xc == 1 and not (yc ==0 ) and xyzsetting >1):
324
+ lora.loaded_loras.clear()
325
+ processed:Processed = process_images(p)
326
+ images.append(processed.images[0])
327
+ xc += 1
328
+ yc += 1
329
+ zc += 1
330
+ origin = loranames(processed.all_prompts) + ", "+ znamer(ztype,z,base)
331
+ if xyzsetting >1: images,xs,ys = effectivechecker(images,xs,ys,diffcol,thresh,revxy)
332
+ grids.append(smakegrid(images,xs,ys,origin,p))
333
+ processed.images= grids
334
+ lora.loaded_loras.clear()
335
+ return processed
336
+
337
+ def znamer(at,a,base):
338
+ if "ID" in at:return f"Block : {a}"
339
+ if "values" in at:return f"value : {a}"
340
+ if "seed" in at:return f"seed : {a}"
341
+ if "Weights" in at:return f"original weights :\n {base}"
342
+ else: return ""
343
+
344
+ def loranames(all_prompts):
345
+ _, extra_network_data = extra_networks.parse_prompts(all_prompts[0:1])
346
+ calledloras = extra_network_data["lora"]
347
+ names = ""
348
+ for called in calledloras:
349
+ if len(called.items) <3:continue
350
+ names += called.items[0]
351
+ return names
352
+
353
+ def loradealer(p,lratios):
354
+ _, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1])
355
+ calledloras = extra_network_data["lora"]
356
+ lorans = []
357
+ lorars = []
358
+ for called in calledloras:
359
+ if len(called.items) <3:continue
360
+ if called.items[2] in lratios or called.items[2].count(",") ==16:
361
+ lorans.append(called.items[0])
362
+ wei = lratios[called.items[2]] if called.items[2] in lratios else called.items[2]
363
+ multiple = called.items[1]
364
+ ratios = [w for w in wei.split(",")]
365
+ for i,r in enumerate(ratios):
366
+ if r =="R":
367
+ ratios[i] = round(random.random(),3)
368
+ elif r == "U":
369
+ ratios[i] = round(random.uniform(-0.5,1.5),3)
370
+ else:
371
+ ratios[i] = float(r)
372
+ print(f"LoRA Block weight :{called.items[0]}: {ratios}")
373
+ lorars.append(ratios)
374
+ if len(lorars) > 0: load_loras_blocks(lorans,lorars,multiple)
375
+
376
+ re_digits = re.compile(r"\d+")
377
+ re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
378
+ re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
379
+ re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
380
+ re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
381
+
382
+ re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
383
+ re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
384
+ re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+)")
385
+
386
+ re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
387
+ re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")
388
+
389
+ def convert_diffusers_name_to_compvis(key):
390
+ def match(match_list, regex):
391
+ r = re.match(regex, key)
392
+ if not r:
393
+ return False
394
+
395
+ match_list.clear()
396
+ match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
397
+ return True
398
+
399
+ m = []
400
+
401
+ if match(m, re_unet_down_blocks):
402
+ return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
403
+
404
+ if match(m, re_unet_mid_blocks):
405
+ return f"diffusion_model_middle_block_1_{m[1]}"
406
+
407
+ if match(m, re_unet_up_blocks):
408
+ return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
409
+
410
+ if match(m, re_unet_down_blocks_res):
411
+ block = f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_0_"
412
+ if m[2].startswith('conv1'):
413
+ return f"{block}in_layers_2{m[2][len('conv1'):]}"
414
+ elif m[2].startswith('conv2'):
415
+ return f"{block}out_layers_3{m[2][len('conv2'):]}"
416
+ elif m[2].startswith('time_emb_proj'):
417
+ return f"{block}emb_layers_1{m[2][len('time_emb_proj'):]}"
418
+ elif m[2].startswith('conv_shortcut'):
419
+ return f"{block}skip_connection{m[2][len('conv_shortcut'):]}"
420
+
421
+ if match(m, re_unet_mid_blocks_res):
422
+ block = f"diffusion_model_middle_block_{m[0]*2}_"
423
+ if m[1].startswith('conv1'):
424
+ return f"{block}in_layers_2{m[1][len('conv1'):]}"
425
+ elif m[1].startswith('conv2'):
426
+ return f"{block}out_layers_3{m[1][len('conv2'):]}"
427
+ elif m[1].startswith('time_emb_proj'):
428
+ return f"{block}emb_layers_1{m[1][len('time_emb_proj'):]}"
429
+ elif m[1].startswith('conv_shortcut'):
430
+ return f"{block}skip_connection{m[1][len('conv_shortcut'):]}"
431
+
432
+ if match(m, re_unet_up_blocks_res):
433
+ block = f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_0_"
434
+ if m[2].startswith('conv1'):
435
+ return f"{block}in_layers_2{m[2][len('conv1'):]}"
436
+ elif m[2].startswith('conv2'):
437
+ return f"{block}out_layers_3{m[2][len('conv2'):]}"
438
+ elif m[2].startswith('time_emb_proj'):
439
+ return f"{block}emb_layers_1{m[2][len('time_emb_proj'):]}"
440
+ elif m[2].startswith('conv_shortcut'):
441
+ return f"{block}skip_connection{m[2][len('conv_shortcut'):]}"
442
+
443
+ if match(m, re_unet_downsample):
444
+ return f"diffusion_model_input_blocks_{m[0]*3+3}_0_op{m[1]}"
445
+
446
+ if match(m, re_unet_upsample):
447
+ return f"diffusion_model_output_blocks_{m[0]*3 + 2}_{1+(m[0]!=0)}_conv{m[1]}"
448
+
449
+ if match(m, re_text_block):
450
+ return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
451
+
452
+ return key
453
+
454
+ def load_lora(name, filename,lwei):
455
+ import lora
456
+ locallora = lora.LoraModule(name)
457
+ locallora.mtime = os.path.getmtime(filename)
458
+
459
+ sd = sd_models.read_state_dict(filename)
460
+
461
+ keys_failed_to_match = []
462
+
463
+ for key_diffusers, weight in sd.items():
464
+ ratio = 1
465
+
466
+ for i,block in enumerate(LORABLOCKS):
467
+ if block in key_diffusers:
468
+ ratio = lwei[i]
469
+
470
+ weight =weight *math.sqrt(abs(ratio))
471
+
472
+ fullkey = convert_diffusers_name_to_compvis(key_diffusers)
473
+ #print(key_diffusers+":"+fullkey+"x" + str(ratio))
474
+ key, lora_key = fullkey.split(".", 1)
475
+
476
+ sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
477
+ if sd_module is None:
478
+ keys_failed_to_match.append(key_diffusers)
479
+ continue
480
+
481
+ lora_module = locallora.modules.get(key, None)
482
+ if lora_module is None:
483
+ lora_module = lora.LoraUpDownModule()
484
+ locallora.modules[key] = lora_module
485
+
486
+ if lora_key == "alpha":
487
+ lora_module.alpha = weight.item()
488
+ continue
489
+
490
+ if type(sd_module) == torch.nn.Linear:
491
+ weight = weight.reshape(weight.shape[0], -1)
492
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
493
+ elif type(sd_module) == torch.nn.Conv2d:
494
+ if lora_key == "lora_down.weight":
495
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
496
+ elif lora_key == "lora_up.weight":
497
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
498
+ else:
499
+ assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
500
+
501
+ fugou = 1 if ratio >0 else -1
502
+
503
+ if lora_key == "lora_up.weight":
504
+ with torch.no_grad():
505
+ module.weight.copy_(weight*fugou)
506
+ else:
507
+ with torch.no_grad():
508
+ module.weight.copy_(weight)
509
+
510
+ module.to(device=devices.device, dtype=devices.dtype)
511
+
512
+ if lora_key == "lora_up.weight":
513
+ lora_module.up = module
514
+ elif lora_key == "lora_down.weight":
515
+ lora_module.down = module
516
+ else:
517
+ assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
518
+
519
+ if len(keys_failed_to_match) > 0:
520
+ print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
521
+
522
+ return locallora
523
+
524
+
525
+ def load_loras_blocks(names, lwei=None,multi=1.0):
526
+ import lora
527
+ loras_on_disk = [lora.available_loras.get(name, None) for name in names]
528
+ if any([x is None for x in loras_on_disk]):
529
+ lora.list_available_loras()
530
+
531
+ loras_on_disk = [lora.available_loras.get(name, None) for name in names]
532
+
533
+ for i, name in enumerate(names):
534
+ locallora = None
535
+
536
+ lora_on_disk = loras_on_disk[i]
537
+ if lora_on_disk is not None:
538
+ if locallora is None or os.path.getmtime(lora_on_disk.filename) > locallora.mtime:
539
+ locallora = load_lora(name, lora_on_disk.filename,lwei[i])
540
+
541
+ if locallora is None:
542
+ print(f"Couldn't find Lora with name {name}")
543
+ continue
544
+
545
+ locallora.multiplier = multi
546
+ lora.loaded_loras.append(locallora)
547
+
548
+ def smakegrid(imgs,xs,ys,currentmodel,p):
549
+ ver_texts = [[images.GridAnnotation(y)] for y in ys]
550
+ hor_texts = [[images.GridAnnotation(x)] for x in xs]
551
+
552
+ w, h = imgs[0].size
553
+ grid = Image.new('RGB', size=(len(xs) * w, len(ys) * h), color='black')
554
+
555
+ for i, img in enumerate(imgs):
556
+ grid.paste(img, box=(i % len(xs) * w, i // len(xs) * h))
557
+
558
+ grid = images.draw_grid_annotations(grid,int(p.width), int(p.height), hor_texts, ver_texts)
559
+ grid = draw_origin(grid, currentmodel,w*len(xs),h*len(ys),w)
560
+ if opts.grid_save:
561
+ images.save_image(grid, opts.outdir_txt2img_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=p.seed, grid=True, p=p)
562
+
563
+ return grid
564
+
565
+ def draw_origin(grid, text,width,height,width_one):
566
+ grid_d= Image.new("RGB", (grid.width,grid.height), "white")
567
+ grid_d.paste(grid,(0,0))
568
+ def get_font(fontsize):
569
+ try:
570
+ return ImageFont.truetype(opts.font or Roboto, fontsize)
571
+ except Exception:
572
+ return ImageFont.truetype(Roboto, fontsize)
573
+ d= ImageDraw.Draw(grid_d)
574
+ color_active = (0, 0, 0)
575
+ fontsize = (width+height)//25
576
+ fnt = get_font(fontsize)
577
+
578
+ if grid.width != width_one:
579
+ while d.multiline_textsize(text, font=fnt)[0] > width_one*0.75 and fontsize > 0:
580
+ fontsize -=1
581
+ fnt = get_font(fontsize)
582
+ d.multiline_text((0,0), text, font=fnt, fill=color_active,align="center")
583
+ return grid_d
584
+
585
+ def newrun(p, *args):
586
+ script_index = args[0]
587
+
588
+ if args[0] ==0:
589
+ script = None
590
+ for obj in scripts.scripts_txt2img.alwayson_scripts:
591
+ if "lora_block_weight" in obj.filename:
592
+ script = obj
593
+ script_args = args[script.args_from:script.args_to]
594
+ else:
595
+ script = scripts.scripts_txt2img.selectable_scripts[script_index-1]
596
+
597
+ if script is None:
598
+ return None
599
+
600
+ script_args = args[script.args_from:script.args_to]
601
+
602
+ processed = script.run(p, *script_args)
603
+
604
+ shared.total_tqdm.clear()
605
+
606
+ return processed
607
+
608
+ def effectivechecker(imgs,ss,ls,diffcol,thresh,revxy):
609
+ diffs = []
610
+ outnum =[]
611
+ imgs[0],imgs[1] = imgs[1],imgs[0]
612
+ im1 = np.array(imgs[0])
613
+
614
+ for i in range(len(imgs)-1):
615
+ im2 = np.array(imgs[i+1])
616
+
617
+ abs_diff = cv2.absdiff(im2 , im1)
618
+
619
+ abs_diff_t = cv2.threshold(abs_diff, int(thresh), 255, cv2.THRESH_BINARY)[1]
620
+ res = abs_diff_t.astype(np.uint8)
621
+ percentage = (np.count_nonzero(res) * 100)/ res.size
622
+ if "white" in diffcol: abs_diff = cv2.bitwise_not(abs_diff)
623
+ outnum.append(percentage)
624
+
625
+ abs_diff = Image.fromarray(abs_diff)
626
+
627
+ diffs.append(abs_diff)
628
+
629
+ outs = []
630
+ for i in range(len(ls)):
631
+ ls[i] = ls[i] + "\n Diff : " + str(round(outnum[i],3)) + "%"
632
+
633
+ if not revxy:
634
+ for diff,img in zip(diffs,imgs[1:]):
635
+ outs.append(diff)
636
+ outs.append(img)
637
+ outs.append(imgs[0])
638
+ ss = ["diff",ss[0],"source"]
639
+ return outs,ss,ls
640
+ else:
641
+ outs = [imgs[0]]*len(diffs) + imgs[1:]+ diffs
642
+ ss = ["source",ss[0],"diff"]
643
+ return outs,ls,ss