|
import cv2
|
|
import json
|
|
import os
|
|
import gc
|
|
import re
|
|
import sys
|
|
import torch
|
|
import shutil
|
|
import math
|
|
import importlib
|
|
import numpy as np
|
|
import gradio as gr
|
|
import os.path
|
|
import random
|
|
from pprint import pprint
|
|
import modules.ui
|
|
import modules.scripts as scripts
|
|
from PIL import Image, ImageFont, ImageDraw
|
|
import modules.shared as shared
|
|
from modules import devices, sd_models, images,cmd_args, extra_networks, sd_hijack
|
|
from modules.shared import cmd_opts, opts, state
|
|
from modules.processing import process_images, Processed
|
|
from modules.script_callbacks import CFGDenoiserParams, on_cfg_denoiser
|
|
|
|
# ui-config.json keys under which webui stores the "Active" checkbox of this script.
LBW_T = "customscript/lora_block_weight.py/txt2img/Active/value"
LBW_I = "customscript/lora_block_weight.py/img2img/Active/value"

# Read the persisted UI settings so the panel title can show whether the
# script starts out active; a missing file just means "use defaults".
if not os.path.exists(cmd_opts.ui_config_file):
    print("ui config file not found, using default values")
    ui_config = {}
else:
    with open(cmd_opts.ui_config_file, 'r', encoding="utf-8") as json_file:
        ui_config = json.load(json_file)

startup_t, startup_i = (ui_config[key] if key in ui_config else None for key in (LBW_T, LBW_I))
active_t, active_i = ("Active" if flag else "Not Active" for flag in (startup_t, startup_i))

# Module-level state shared between the XYZ runner and the per-generation hooks:
# lxyz / lzyx hold the current XYZ / inverted weight strings, prompts holds the
# current batch prompts, xyelem an elemental override, princ the "print change" flag.
lxyz = lzyx = prompts = xyelem = ""
princ = False
|
|
|
|
# Detect the Forge backend: its ldm_patched package is only importable there.
# Catch Exception rather than using a bare except so KeyboardInterrupt/SystemExit
# raised during startup are not silently turned into "not forge".
try:
    from ldm_patched.modules import model_management
    forge = True
except Exception:
    forge = False
|
|
|
|
# Block identifiers for each supported weight-vector length.
# 26: every input block, the middle block and every output block.
BLOCKID26 = ["BASE"] + [f"IN{i:02d}" for i in range(12)] + ["M00"] + [f"OUT{i:02d}" for i in range(12)]
# 17: the subset listed for the 1.X/2.X layout.
BLOCKID17 = ["BASE"] + [f"IN{i:02d}" for i in (1, 2, 4, 5, 7, 8)] + ["M00"] + [f"OUT{i:02d}" for i in range(3, 12)]
# 12: the subset listed for the XL layout.
BLOCKID12 = ["BASE"] + [f"IN{i:02d}" for i in (4, 5, 7, 8)] + ["M00"] + [f"OUT{i:02d}" for i in range(6)]
# 20: the "XL full" layout, IN00-IN08 and OUT00-OUT08.
BLOCKID20 = ["BASE"] + [f"IN{i:02d}" for i in range(9)] + ["M00"] + [f"OUT{i:02d}" for i in range(9)]

# BLOCKNUMS[i] is the length of BLOCKIDS[i]; callers look up a layout by length.
BLOCKNUMS = [12, 17, 20, 26]
BLOCKIDS = [BLOCKID12, BLOCKID17, BLOCKID20, BLOCKID26]
|
|
|
|
# Key fragments identifying the model parts, in order: text encoder, the 12
# UNet input blocks, the middle block, the 12 output blocks, then the
# "embedders" / "transformer_resblocks" encoder variants.
BLOCKS = (
    ["encoder"]
    + [f"diffusion_model_input_blocks_{i}_" for i in range(12)]
    + ["diffusion_model_middle_block_"]
    + [f"diffusion_model_output_blocks_{i}_" for i in range(12)]
    + ["embedders", "transformer_resblocks"]
)
|
|
|
|
loopstopper = True

# Axis types offered by the XYZ-plot dropdowns.
ATYPES = ["none", "Block ID", "values", "seed", "Original Weights", "elements"]

# Built-in presets written to lbwpresets.txt when it is missing. Each entry is
# "NAME:w1,...,w12"; the 12 positions follow BLOCKID12
# (BASE, IN04, IN05, IN07, IN08, M00, OUT00..OUT05).
DEF_WEIGHT_PRESET = "\n".join((
    "NOT:0,0,0,0,0,0,0,0,0,0,0,0",
    "ALL:1,1,1,1,1,1,1,1,1,1,1,1",
    "INS:1,1,1,0,0,0,0,0,0,0,0,0",
    "IND:1,0,0,1,1,0,0,0,0,0,0,0",
    "INALL:1,1,1,1,1,0,0,0,0,0,0,0",
    "MIDD:1,0,0,1,1,1,1,1,1,0,0,0",
    "OUTD:1,0,0,0,0,0,1,1,1,0,0,0",
    "OUTS:1,0,0,0,0,0,0,0,0,1,1,1",
    "OUTALL:1,0,0,0,0,0,1,1,1,1,1,1",
    "BASE:1,0,0,0,0,0,0,0,0,0,0,0",
    "IN04:1,1,0,0,0,0,0,0,0,0,0,0",
    "IN05:1,0,1,0,0,0,0,0,0,0,0,0",
    "IN07:1,0,0,1,0,0,0,0,0,0,0,0",
    "IN08:1,0,0,0,1,0,0,0,0,0,0,0",
    "MID:1,0,0,0,0,1,0,0,0,0,0,0",
    "OUT00:1,0,0,0,0,0,1,0,0,0,0,0",
    "OUT01:1,0,0,0,0,0,0,1,0,0,0,0",
    "OUT02:1,0,0,0,0,0,0,0,1,0,0,0",
    "OUT03:1,0,0,0,0,0,0,0,0,1,0,0",
    "OUT04:1,0,0,0,0,0,0,0,0,0,1,0",
    "OUT05:1,0,0,0,0,0,0,0,0,0,0,1",
))

# Directory containing this script file.
scriptpath = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
class Script(modules.scripts.Script):
    """Always-visible "LoRA Block Weight" script.

    ui() builds the settings panel. process(), process_batch() and
    after_extra_networks_activate() parse the presets and pass each LoRA call
    to loradealer(). denoiser_callback() applies start/stop step schedules.
    run() drives the XYZ plot / Effective Block Analyzer.
    """

    def __init__(self):
        # Debug information collected during a generation; printed by
        # postprocess() when the "Debug" checkbox is ticked.
        self.log = {}
        # Non-forge step schedules filled by loradealer():
        # stops[name] = stop step, starts[name] = [start step, te, unet].
        self.stops = {}
        self.starts = {}
        # Set by process() when the script is enabled; gates denoiser_callback().
        self.active = False
        # Containers of loaded network objects scanned by denoiser_callback()
        # (load_loras_blocks() rebinds some of them to the loaded-LoRA lists).
        self.lora = {}
        self.lycoris = {}
        self.networks = {}

        # Forge-path lists, index-aligned per called LoRA (filled by
        # loradealer()): stop steps, start steps, unet multipliers,
        # block ratios, elemental specs.
        self.stopsf = []
        self.startsf = []
        self.uf = []
        self.lf = []
        self.ef = []

    def title(self):
        return "LoRA Block Weight"

    def show(self, is_img2img):
        # Rendered as an always-visible script in both tabs.
        return modules.scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the LoRA Block Weight panel and return its input components.

        Preset files live next to this script. A copy found under
        <basedir>/scripts is moved here if no local copy exists, and the
        built-in defaults are written out when neither can be read. The order
        of the returned tuple is the order of the extra arguments webui passes
        to process(), run(), postprocess(), etc.
        """
        LWEIGHTSPRESETS = DEF_WEIGHT_PRESET

        # Original run hooks, restored by urawaza() when XYZ mode is disabled.
        runorigin = scripts.scripts_txt2img.run
        runorigini = scripts.scripts_img2img.run

        scriptpath = os.path.dirname(os.path.abspath(__file__))
        path_root = scripts.basedir()

        extpath = os.path.join(scriptpath, "lbwpresets.txt")
        extpathe = os.path.join(scriptpath, "elempresets.txt")
        filepath = os.path.join(path_root,"scripts", "lbwpresets.txt")
        filepathe = os.path.join(path_root,"scripts", "elempresets.txt")

        if os.path.isfile(filepath) and not os.path.isfile(extpath):
            shutil.move(filepath,extpath)

        if os.path.isfile(filepathe) and not os.path.isfile(extpathe):
            shutil.move(filepathe,extpathe)

        lbwpresets=""

        # Block-weight presets: fall back to the defaults (and try to persist
        # them) when the file cannot be read.
        try:
            with open(extpath,encoding="utf-8") as f:
                lbwpresets = f.read()
        except OSError:
            lbwpresets=LWEIGHTSPRESETS
            if not os.path.isfile(extpath):
                try:
                    with open(extpath,mode = 'w',encoding="utf-8") as f:
                        f.write(lbwpresets)
                except OSError:
                    pass

        # Elemental presets: same fallback with the module-level ELEMPRESETS.
        try:
            with open(extpathe,encoding="utf-8") as f:
                elempresets = f.read()
        except OSError:
            elempresets=ELEMPRESETS
            if not os.path.isfile(extpathe):
                try:
                    with open(extpathe,mode = 'w',encoding="utf-8") as f:
                        f.write(elempresets)
                except OSError:
                    pass

        # Comma-separated preset identifiers shown in the tag text area.
        lratios={}
        for l in lbwpresets.splitlines():
            if checkloadcond(l) : continue
            lratios[l.split(":")[0]]=l.split(":")[1]
        ratiostags = ",".join(lratios)

        if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
            args = cmd_args.parser.parse_args()
        else:
            args, _ = cmd_args.parser.parse_known_args()
        if args.api:
            register()

        with gr.Accordion(f"LoRA Block Weight : {active_i if is_img2img else active_t}",open = False) as acc:
            with gr.Row():
                with gr.Column(min_width = 50, scale=1):
                    lbw_useblocks = gr.Checkbox(value = True,label="Active",interactive =True,elem_id="lbw_active")
                    debug = gr.Checkbox(value = False,label="Debug",interactive =True,elem_id="lbw_debug")
                with gr.Column(scale=5):
                    bw_ratiotags= gr.TextArea(label="",value=ratiostags,visible =True,interactive =True,elem_id="lbw_ratios")
            with gr.Accordion("XYZ plot",open = False):
                gr.HTML(value='<p style= "word-wrap:break-word;">changeable blocks : BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11</p>')
                # Index 0 = disabled, 1 = XYZ plot, 2 = Effective Block Analyzer.
                xyzsetting = gr.Radio(label = "Active",choices = ["Disable","XYZ plot","Effective Block Analyzer"], value ="Disable",type = "index")
                with gr.Row(visible = False) as esets:
                    diffcol = gr.Radio(label = "diff image color",choices = ["black","white"], value ="black",type = "value",interactive =True)
                    revxy = gr.Checkbox(value = False,label="change X-Y",interactive =True,elem_id="lbw_changexy")
                    thresh = gr.Textbox(label="difference threshold",lines=1,value="20",interactive =True,elem_id="diff_thr")
                xtype = gr.Dropdown(label="X Types", choices=[x for x in ATYPES], value=ATYPES [2],interactive =True,elem_id="lbw_xtype")
                xmen = gr.Textbox(label="X Values",lines=1,value="0,0.25,0.5,0.75,1",interactive =True,elem_id="lbw_xmen")
                ytype = gr.Dropdown(label="Y Types", choices=[y for y in ATYPES], value=ATYPES [1],interactive =True,elem_id="lbw_ytype")
                ymen = gr.Textbox(label="Y Values" ,lines=1,value="IN05-OUT05",interactive =True,elem_id="lbw_ymen")
                ztype = gr.Dropdown(label="Z type", choices=[z for z in ATYPES], value=ATYPES[0],interactive =True,elem_id="lbw_ztype")
                zmen = gr.Textbox(label="Z values",lines=1,value="",interactive =True,elem_id="lbw_zmen")

                # Analyzer-only inputs; urawaza() toggles their visibility.
                exmen = gr.Textbox(label="Range",lines=1,value="0.5,1",interactive =True,elem_id="lbw_exmen",visible = False)
                eymen = gr.Textbox(label="Blocks (12ALL,17ALL,20ALL,26ALL also can be used)" ,lines=1,value="BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11",interactive =True,elem_id="lbw_eymen",visible = False)
                ecount = gr.Number(value=1, label="number of seed", interactive=True, visible = True)

            with gr.Accordion("Weights setting",open = True):
                with gr.Row():
                    reloadtext = gr.Button(value="Reload Presets",variant='primary',elem_id="lbw_reload")
                    reloadtags = gr.Button(value="Reload Tags",variant='primary',elem_id="lbw_reload")
                    savetext = gr.Button(value="Save Presets",variant='primary',elem_id="lbw_savetext")
                    openeditor = gr.Button(value="Open TextEditor",variant='primary',elem_id="lbw_openeditor")
                lbw_loraratios = gr.TextArea(label="",value=lbwpresets,visible =True,interactive = True,elem_id="lbw_ratiospreset")

            with gr.Accordion("Elemental",open = False):
                with gr.Row():
                    e_reloadtext = gr.Button(value="Reload Presets",variant='primary',elem_id="lbw_reload")
                    e_savetext = gr.Button(value="Save Presets",variant='primary',elem_id="lbw_savetext")
                    e_openeditor = gr.Button(value="Open TextEditor",variant='primary',elem_id="lbw_openeditor")
                elemsets = gr.Checkbox(value = False,label="print change",interactive =True,elem_id="lbw_print_change")
                elemental = gr.TextArea(label="Identifer:BlockID:Elements:Ratio,...,separated by empty line ",value = elempresets,interactive =True,elem_id="element")

            # Hidden constant inputs used to pass True/False into shared handlers.
            d_true = gr.Checkbox(value = True,visible = False)
            d_false = gr.Checkbox(value = False,visible = False)

            with gr.Accordion("Make Weights",open = False):
                with gr.Row():
                    m_text = gr.Textbox(value="",label="Weights")
                with gr.Row():
                    m_add = gr.Button(value="Add to presets",size="sm",variant='primary')
                    m_add_save = gr.Button(value="Add to presets and Save",size="sm",variant='primary')
                    m_name = gr.Textbox(value="",label="Identifier")
                with gr.Row():
                    m_type = gr.Radio(label="Weights type",choices=["17(1.X/2.X)", "26(1.X/2.X full)", "12(XL)","20(XL full)"], value="17(1.X/2.X)")
                with gr.Row():
                    m_set_0 = gr.Button(value="Set All 0",variant='primary')
                    m_set_1 = gr.Button(value="Set All 1",variant='primary')
                    m_custom = gr.Button(value="Set custom",variant='primary')
                    m_custom_v = gr.Slider(show_label=False, minimum=-1.0, maximum=1, step=0.1, value=0, interactive=True)
                with gr.Row():
                    with gr.Column(scale=1, min_width=100):
                        gr.Slider(visible=False)
                    with gr.Column(scale=2, min_width=200):
                        base = gr.Slider(label="BASE", minimum=-1, maximum=1, step=0.1, value=0.0)
                    with gr.Column(scale=1, min_width=100):
                        gr.Slider(visible=False)
                with gr.Row():
                    with gr.Column(scale=2, min_width=200):
                        ins = [gr.Slider(label=block, minimum=-1.0, maximum=1, step=0.1, value=0, interactive=True) for block in BLOCKID26[1:13]]
                    with gr.Column(scale=2, min_width=200):
                        outs = [gr.Slider(label=block, minimum=-1.0, maximum=1, step=0.1, value=0, interactive=True) for block in reversed(BLOCKID26[14:])]
                with gr.Row():
                    with gr.Column(scale=1, min_width=100):
                        gr.Slider(visible=False)
                    with gr.Column(scale=2, min_width=200):
                        m00 = gr.Slider(label="M00", minimum=-1, maximum=1, step=0.1, value=0.0)
                    with gr.Column(scale=1, min_width=100):
                        gr.Slider(visible=False)

                # Sliders in BLOCKID26 order (outs was created reversed for layout).
                blocks = [base] + ins + [m00] + outs[::-1]
                # Start with the 17-block set visible, matching m_type's default.
                for block in blocks:
                    if block.label not in BLOCKID17:
                        block.visible = False

                # int(x[:2]) reads the block count from the m_type label, e.g. "17(1.X/2.X)".
                m_set_0.click(fn=lambda x:[0]*26 + [",".join(["0"]*int(x[:2]))],inputs=[m_type],outputs=blocks + [m_text])
                m_set_1.click(fn=lambda x:[1]*26 + [",".join(["1"]*int(x[:2]))],inputs=[m_type],outputs=blocks + [m_text])
                m_custom.click(fn=lambda x,y:[x]*26 + [",".join([str(x)]*int(y[:2]))],inputs=[m_custom_v,m_type],outputs=blocks + [m_text])

                def addweights(weights, ident, presets, save = False):
                    """Insert or replace "ident:weights" in the preset text; optionally save it."""
                    if ident == "":ident = "NONAME"
                    lines = presets.strip().split("\n")
                    id_found = False
                    for i, line in enumerate(lines):
                        if line.startswith("#"):
                            continue
                        if line.split(":")[0] == ident:
                            lines[i] = f"{ident}:{weights}"
                            id_found = True
                            break
                    if not id_found:
                        lines.append(f"{ident}:{weights}")

                    if save:
                        with open(extpath,mode = 'w',encoding="utf-8") as f:
                            f.write("\n".join(lines))

                    return "\n".join(lines)

                def changetheblocks(sdver,*blocks):
                    """Show only the sliders of the chosen layout and rebuild the weight string."""
                    sdver = int(sdver[:2])
                    output = []
                    targ_blocks = BLOCKIDS[BLOCKNUMS.index(sdver)]
                    for i, block in enumerate(BLOCKID26):
                        if block in targ_blocks:
                            output.append(str(blocks[i]))
                    return [",".join(output)] + [gr.update(visible = True if block in targ_blocks else False) for block in BLOCKID26]

                m_add.click(fn=addweights, inputs=[m_text,m_name,lbw_loraratios],outputs=[lbw_loraratios])
                m_add_save.click(fn=addweights, inputs=[m_text,m_name,lbw_loraratios, d_true],outputs=[lbw_loraratios])
                m_type.change(fn=changetheblocks, inputs=[m_type] + blocks,outputs=[m_text] + blocks)

            d_true = gr.Checkbox(value = True,visible = False)
            d_false = gr.Checkbox(value = False,visible = False)

            lbw_useblocks.change(fn=lambda x:gr.update(label = f"LoRA Block Weight : {'Active' if x else 'Not Active'}"),inputs=lbw_useblocks, outputs=[acc])

            def makeweights(sdver, *blocks):
                """Join the slider values of the chosen layout into a weight string."""
                sdver = int(sdver[:2])
                output = []
                targ_blocks = BLOCKIDS[BLOCKNUMS.index(sdver)]
                for i, block in enumerate(BLOCKID26):
                    if block in targ_blocks:
                        output.append(str(blocks[i]))
                return ",".join(output)

            changes = [b.release(fn=makeweights,inputs=[m_type] + blocks,outputs=[m_text]) for b in blocks]

            import subprocess

            def openeditors(b):
                """Open the weight (b true) or elemental preset file with the OS default editor."""
                path = extpath if b else extpathe
                # `start` via the shell only exists on Windows and splits paths
                # containing spaces; use the platform's own opener instead.
                try:
                    if sys.platform.startswith("win"):
                        os.startfile(path)
                    elif sys.platform == "darwin":
                        subprocess.Popen(["open", path])
                    else:
                        subprocess.Popen(["xdg-open", path])
                except OSError as err:
                    print(f"LoRA Block Weight: could not open {path}: {err}")

            def reloadpresets(isweight):
                """Return the weight (isweight) or elemental preset file contents; None on read error."""
                # The elemental branch used to re-read extpath, so the Elemental
                # panel's "Reload Presets" loaded the block-weight presets.
                path = extpath if isweight else extpathe
                try:
                    with open(path,encoding="utf-8") as f:
                        return f.read()
                except OSError:
                    pass

            def tagdicter(presets):
                """Return the comma-joined identifiers of presets with a supported value count."""
                wdict={}
                for l in presets.splitlines():
                    if checkloadcond(l) : continue
                    if ":" not in l:
                        continue
                    key, w = l.split(":",1)
                    if len(w.split(",")) in BLOCKNUMS:
                        wdict[key.strip()]=w
                return ",".join(wdict.keys())

            def savepresets(text,isweight):
                """Write *text* to the weight (isweight) or elemental preset file."""
                path = extpath if isweight else extpathe
                with open(path,mode = 'w',encoding="utf-8") as f:
                    f.write(text)

            reloadtext.click(fn=reloadpresets,inputs=[d_true],outputs=[lbw_loraratios])
            reloadtags.click(fn=tagdicter,inputs=[lbw_loraratios],outputs=[bw_ratiotags])
            savetext.click(fn=savepresets,inputs=[lbw_loraratios,d_true],outputs=[])
            openeditor.click(fn=openeditors,inputs=[d_true],outputs=[])

            e_reloadtext.click(fn=reloadpresets,inputs=[d_false],outputs=[elemental])
            e_savetext.click(fn=savepresets,inputs=[elemental,d_false],outputs=[])
            e_openeditor.click(fn=openeditors,inputs=[d_false],outputs=[])

            def urawaza(active):
                """Hook/unhook newrun for XYZ modes and toggle the XYZ vs analyzer inputs.

                The first 6 outputs are the X/Y/Z type+value inputs, the last 4
                the analyzer inputs (range, blocks, seed count, diff settings).
                """
                if active > 0:
                    register()
                    scripts.scripts_txt2img.run = newrun
                    scripts.scripts_img2img.run = newrun
                    if active == 1:return [*[gr.update(visible = True) for x in range(6)],*[gr.update(visible = False) for x in range(4)]]
                    else:return [*[gr.update(visible = False) for x in range(6)],*[gr.update(visible = True) for x in range(4)]]
                else:
                    scripts.scripts_txt2img.run = runorigin
                    scripts.scripts_img2img.run = runorigini
                    return [*[gr.update(visible = True) for x in range(6)],*[gr.update(visible = False) for x in range(4)]]

            xyzsetting.change(fn=urawaza,inputs=[xyzsetting],outputs =[xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,esets])

        return lbw_loraratios,lbw_useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets,debug

    def process(self, p, loraratios,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets,debug):
        """Parse the weight/elemental presets for this generation.

        Stores them on self.lratios / self.elementals, injects the current
        XYZ-plot weights as the "XYZ"/"ZYX" presets, and registers the CFG
        denoiser callback once per instance.
        """
        # Missing UI values (e.g. API calls) fall back to defaults.
        if loraratios is None:
            loraratios = DEF_WEIGHT_PRESET
        if useblocks is None:
            useblocks = True

        lorachecker(self)
        self.log["enable LBW"] = useblocks
        self.log["registerd"] = registerd

        if useblocks:
            self.active = True
            loraratios=loraratios.splitlines()
            # Elemental presets are separated by blank lines.
            elemental = elemental.split("\n\n") if elemental is not None else []
            lratios={}
            elementals={}
            for l in loraratios:
                if checkloadcond(l) : continue
                l0=l.split(":",1)[0]
                lratios[l0.strip()]=l.split(":",1)[1]
            for e in elemental:
                if ":" not in e: continue
                e0=e.split(":",1)[0]
                elementals[e0.strip()]=e.split(":",1)[1]
            if elemsets : print(xyelem)
            if xyzsetting and "XYZ" in p.prompt:
                lratios["XYZ"] = lxyz
                lratios["ZYX"] = lzyx
            if xyelem != "":
                if "XYZ" in elementals.keys():
                    elementals["XYZ"] = elementals["XYZ"] + ","+ xyelem
                else:
                    elementals["XYZ"] = xyelem
            self.lratios = lratios
            self.elementals = elementals
            global princ
            princ = elemsets

        if not hasattr(self,"lbt_dr_callbacks"):
            self.lbt_dr_callbacks = on_cfg_denoiser(self.denoiser_callback)

    def denoiser_callback(self, params: CFGDenoiserParams):
        """Apply start=/stop= schedules at the current sampling step.

        On forge, re-patches the UNet with the strengths of LoRAs that start or
        stop at this step. Otherwise, rewrites te/unet multipliers on the
        loaded network objects.
        """
        def setparams(self, key, te, u ,sets):
            # Loaded LoRAs may carry an "_in_LBW_<n>" suffix; match on the base name.
            for dicts in [self.lora,self.lycoris,self.networks]:
                for lora in dicts:
                    if lora.name.split("_in_LBW_")[0] == key:
                        lora.te_multiplier = te
                        lora.unet_multiplier = u
                        sets.append(key)

        if forge and self.active:
            if params.sampling_step in self.startsf:
                shared.sd_model.forge_objects.unet.unpatch_model(device_to=devices.device)
                for key, vals in shared.sd_model.forge_objects.unet.patches.items():
                    n_vals = []
                    lvals = [val for val in vals if val[1][0] in LORAS]
                    for s, v, m, l, e in zip(self.startsf, lvals, self.uf, self.lf, self.ef):
                        if s is not None and s == params.sampling_step:
                            ratio, errormodules = ratiodealer(key.replace(".","_"), l, e)
                            n_vals.append((ratio * m, *v[1:]))
                        else:
                            n_vals.append(v)
                    shared.sd_model.forge_objects.unet.patches[key] = n_vals
                shared.sd_model.forge_objects.unet.patch_model()

            if params.sampling_step in self.stopsf:
                shared.sd_model.forge_objects.unet.unpatch_model(device_to=devices.device)
                for key, vals in shared.sd_model.forge_objects.unet.patches.items():
                    n_vals = []
                    lvals = [val for val in vals if val[1][0] in LORAS]
                    for s, v, m, l, e in zip(self.stopsf, lvals, self.uf, self.lf, self.ef):
                        if s is not None and s == params.sampling_step:
                            # Stopped LoRAs keep their patch but with strength 0.
                            n_vals.append((0, *v[1:]))
                        else:
                            n_vals.append(v)
                    shared.sd_model.forge_objects.unet.patches[key] = n_vals
                shared.sd_model.forge_objects.unet.patch_model()

        elif self.active:
            # Delayed LoRAs are muted from the first step until their start step.
            if self.starts and params.sampling_step == 0:
                for key, step_te_u in self.starts.items():
                    setparams(self, key, 0, 0, [])

            # NOTE(review): the "> step - 2" comparison switches one step
            # before `step`; presumably to account for 0-based step counting.
            if self.starts:
                sets = []
                for key, step_te_u in self.starts.items():
                    step, te, u = step_te_u
                    if params.sampling_step > step - 2:
                        setparams(self, key, te, u, sets)

                for key in sets:
                    if key in self.starts:
                        del self.starts[key]

            if self.stops:
                sets = []
                for key, step in self.stops.items():
                    if params.sampling_step > step - 2:
                        setparams(self, key, 0, 0, sets)

                for key in sets:
                    if key in self.stops:
                        del self.stops[key]

    def before_process_batch(self, p, loraratios,useblocks,*args,**kwargs):
        """Reset cached networks and remember this batch's prompts."""
        if useblocks:
            resetmemory()
            if not self.isnet: p.disable_extra_networks = False
            global prompts
            prompts = kwargs["prompts"].copy()

    def process_batch(self, p, loraratios,useblocks,*args,**kwargs):
        """Apply block weights for the legacy (non-`networks`) LoRA implementation."""
        if useblocks:
            if not self.isnet: p.disable_extra_networks = True

            o_prompts = [p.prompt]
            for prompt in prompts:
                if "<lora" in prompt or "<lyco" in prompt:
                    o_prompts = prompts.copy()
            if not self.isnet: loradealer(self, o_prompts ,self.lratios,self.elementals)

    def postprocess(self, p, processed, presets,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets,debug,*args):
        """Re-register bundled embeddings, unload LoRAs, and clear XYZ state."""
        if not useblocks:
            return
        lora = importer(self)
        emb_db = sd_hijack.model_hijack.embedding_db

        for net in lora.loaded_loras:
            if hasattr(net,"bundle_embeddings"):
                for emb_name, embedding in net.bundle_embeddings.items():
                    if embedding.loaded:
                        emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

        lora.loaded_loras.clear()

        if forge:
            sd_models.model_data.get_sd_model().current_lora_hash = None
            shared.sd_model.forge_objects_after_applying_lora.unet.unpatch_model()
            shared.sd_model.forge_objects_after_applying_lora.clip.patcher.unpatch_model()

        global lxyz,lzyx,xyelem
        lxyz = lzyx = xyelem = ""
        if debug:
            print(self.log)
        gc.collect()

    def after_extra_networks_activate(self, p, presets,useblocks, *args, **kwargs):
        """Apply block weights after webui has activated the extra networks."""
        if useblocks:
            loradealer(self, kwargs["prompts"] ,self.lratios,self.elementals,kwargs["extra_network_data"])

    def run(self,p,presets,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets,debug):
        """Run the XYZ plot (xyzsetting 1) or Effective Block Analyzer (xyzsetting 2).

        Only acts when "XYZ" appears in the prompt. Each cell sets the global
        lxyz/lzyx weight strings, which process() injects as the "XYZ"/"ZYX"
        presets, then calls process_images(). Returns a Processed whose
        images are one grid per Z value, or None when nothing is plotted.
        """
        global lxyz,lzyx
        if not useblocks:
            return
        self.__init__()
        self.log["pass XYZ"] = True
        self.log["XYZsets"] = xyzsetting
        self.log["enable LBW"] = useblocks

        if xyzsetting >0:
            lorachecker(self)
            lora = importer(self)
            loraratios=presets.splitlines()
            lratios={}
            for l in loraratios:
                if checkloadcond(l) : continue
                l0=l.split(":",1)[0]
                lratios[l0.strip()]=l.split(":",1)[1]

            if "XYZ" in p.prompt:
                base = lratios["XYZ"] if "XYZ" in lratios.keys() else "1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1"
            else: return

            # "12ALL"/"17ALL"/"20ALL"/"26ALL" expand to the full block list.
            for i, alias in enumerate(["12ALL","17ALL","20ALL","26ALL"]):
                if eymen == alias:
                    eymen = ",".join(BLOCKIDS[i])

            if xyzsetting > 1:
                # Analyzer: X = the two range values, Y = block IDs, base filled
                # with the second range value, optional Z = random seeds.
                xmen,ymen = exmen,eymen
                xtype,ytype = "values","ID"
                ebase = xmen.split(",")[1]
                ebase = [ebase.strip()]*26
                base = ",".join(ebase)
                ztype = ""
                if ecount > 1:
                    ztype = "seed"
                    zmen = ",".join([str(random.randrange(4294967294)) for x in range(int(ecount))])

            def dicedealer(am):
                """Replace "-1" seeds in *am* with random ones (in place)."""
                for i,a in enumerate(am):
                    if a =="-1": am[i] = str(random.randrange(4294967294))
                print(f"the die was thrown : {am}")

            if p.seed == -1: p.seed = str(random.randrange(4294967294))

            def adjuster(a,at):
                """Split an axis value string into a stripped list ("none" -> [""])."""
                if "none" in at:a = ""
                a = [a.strip() for a in a.split(',')]
                if "seed" in at:dicedealer(a)
                return a

            xs = adjuster(xmen,xtype)
            ys = adjuster(ymen,ytype)
            zs = adjuster(zmen,ztype)

            ids = alpha =seed = ""
            p.batch_size = 1

            print(f"xs:{xs},ys:{ys},zs:{zs}")

            images = []

            def weightsdealer(alpha,ids,base):
                """Return *base* with the blocks in *ids* set to *alpha*.

                *ids* is space-separated; "A-B" selects a range and a leading
                "NOT" sets every block except the listed ones.
                """
                ids = [z.strip() for z in ids.split(' ')]
                weights_t = [w.strip() for w in base.split(',')]
                blockid = BLOCKIDS[BLOCKNUMS.index(len(weights_t))]
                if ids[0]!="NOT":
                    flagger=[False]*len(weights_t)
                    changer = True
                else:
                    flagger=[True]*len(weights_t)
                    changer = False
                for id in ids:
                    if id =="NOT":continue
                    if "-" in id:
                        it = [it.strip() for it in id.split('-')]
                        if blockid.index(it[1]) > blockid.index(it[0]):
                            flagger[blockid.index(it[0]):blockid.index(it[1])+1] = [changer]*(blockid.index(it[1])-blockid.index(it[0])+1)
                        else:
                            flagger[blockid.index(it[1]):blockid.index(it[0])+1] = [changer]*(blockid.index(it[0])-blockid.index(it[1])+1)
                    else:
                        flagger[blockid.index(id)] =changer
                for i,f in enumerate(flagger):
                    if f:weights_t[i]=alpha
                outext = ",".join(weights_t)

                return outext

            generatedbases=[]
            def xyzdealer(a,at):
                """Apply one axis value *a* of type *at* to the shared run state."""
                nonlocal ids,alpha,p,base,c_base,generatedbases
                if "ID" in at:return
                if "values" in at:alpha = a
                if "seed" in at:
                    p.seed = int(a)
                    generatedbases=[]
                if "Weights" in at:base =c_base = lratios[a]
                if "elements" in at:
                    global xyelem
                    xyelem = a

            def imagedupewatcher(baselist,basetocheck,currentiteration):
                """Return the index of an identical, already rendered base, else -1.

                On a hit the base is also recorded at this iteration's slot.
                """
                for idx,alreadygenerated in enumerate(baselist):
                    if (basetocheck == alreadygenerated):
                        baselist.insert(currentiteration-1, basetocheck)
                        return idx
                return -1

            def strThree(someNumber):
                """Format to at most 3 decimals without trailing zeros."""
                return format(someNumber, ".3f").rstrip('0').rstrip('.')

            def xyoriginalweightsdealer(x,y):
                """Add presets *x* and *y* element-wise; where both are non-zero, *x* wins."""
                xweights = np.asarray(lratios[x].split(','), dtype=np.float32)
                yweights = np.asarray(lratios[y].split(','), dtype=np.float32)
                for idx,xval in np.ndenumerate(xweights):
                    yval = yweights[idx]
                    if xval != 0 and yval != 0:
                        yweights[idx] = 0

                baseListToStrings = list(map(strThree, np.around(np.add(xweights,yweights,),3).tolist()))
                return ",".join(baseListToStrings)

            grids = []
            images =[]

            totalcount = len(xs)*len(ys)*len(zs) if xyzsetting < 2 else len(xs)*len(ys)*len(zs) //2 +1
            shared.total_tqdm.updateTotal(totalcount)
            xc = yc =zc = 0
            state.job_count = totalcount
            totalcount = len(xs)*len(ys)*len(zs)
            c_base = base

            for z in zs:
                generatedbases=[]
                images = []
                yc = 0
                xyzdealer(z,ztype)
                for y in ys:
                    xc = 0
                    xyzdealer(y,ytype)
                    for x in xs:
                        xyzdealer(x,xtype)
                        if "Weights" in xtype and "Weights" in ytype:
                            c_base = xyoriginalweightsdealer(x,y)
                        else:
                            if "ID" in xtype:
                                if "values" in ytype:c_base = weightsdealer(y,x,base)
                                if "values" in ztype:c_base = weightsdealer(z,x,base)
                            if "ID" in ytype:
                                if "values" in xtype:c_base = weightsdealer(x,y,base)
                                if "values" in ztype:c_base = weightsdealer(z,y,base)
                            if "ID" in ztype:
                                if "values" in xtype:c_base = weightsdealer(x,z,base)
                                if "values" in ytype:c_base = weightsdealer(y,z,base)

                        iteration = len(xs)*len(ys)*zc + yc*len(xs) +xc +1
                        print(f"X:{xtype}, {x},Y: {ytype},{y}, Z:{ztype},{z}, base:{c_base} ({iteration}/{totalcount})")

                        dupe_index = imagedupewatcher(generatedbases,c_base,iteration)
                        if dupe_index > -1:
                            print(f"Skipping generation of duplicate base:{c_base}")
                            images.append(images[dupe_index].copy())
                            xc += 1
                            continue

                        lxyz = c_base

                        # "ZYX" preset: numeric weights replaced by 1-w, placeholders kept.
                        cr_base = c_base.split(",")
                        cr_base_t=[]
                        for val in cr_base:
                            if not identifier(val):
                                cr_base_t.append(str(1-float(val)))
                            else:
                                cr_base_t.append(val)
                        lzyx = ",".join(cr_base_t)

                        # NOTE(review): in analyzer mode the xc == 1 cell of every row
                        # after the first is not rendered; effectivechecker() is
                        # presumably built around that — confirm before changing.
                        if not(xc == 1 and not (yc ==0 ) and xyzsetting >1):
                            lora.loaded_loras.clear()
                            p.cached_c = [None,None]
                            p.cached_uc = [None,None]
                            p.cached_hr_c = [None, None]
                            p.cached_hr_uc = [None, None]
                            processed:Processed = process_images(p)
                            images.append(processed.images[0])
                            generatedbases.insert(iteration-1, c_base)
                        xc += 1
                    yc += 1
                zc += 1
                origin = loranames(processed.all_prompts) + ", "+ znamer(ztype,z,base)
                images,xst,yst = effectivechecker(images,xs.copy(),ys.copy(),diffcol,thresh,revxy) if xyzsetting >1 else (images,xs.copy(),ys.copy())
                grids.append(smakegrid(images,xst,yst,origin,p))
            processed.images= grids
            lora.loaded_loras.clear()
            return processed
|
|
|
|
def identifier(char):
    """Tell whether a weight token is a placeholder rather than a number.

    Placeholders start with "R" (random), "U" (uniform) or "X" (inherited).
    """
    leading = char[0]
    return any(leading == tag for tag in ("R", "U", "X"))
|
|
|
|
def znamer(at,a,base):
    """Build the caption describing Z-axis value *a* of axis type *at*.

    Original-weights axes show *base*; unknown types give an empty caption.
    """
    captions = (("ID", "Block : {}"), ("values", "value : {}"), ("seed", "seed : {}"))
    for marker, template in captions:
        if marker in at:
            return template.format(a)
    if "Weights" in at:
        return f"original weights :\n {base}"
    return ""
|
|
|
|
def loranames(all_prompts):
    """Concatenate the names of LoRAs called with 3+ arguments in the first prompt.

    Uses the "lyco" calls when present, otherwise the "lora" calls.
    """
    _, network_data = extra_networks.parse_prompts(all_prompts[0:1])
    key = "lyco" if "lyco" in network_data.keys() else "lora"
    return "".join(call.items[0] for call in network_data[key] if len(call.items) >= 3)
|
|
|
|
def lorachecker(self):
    """Detect which LoRA implementations are importable and record it on *self*.

    Sets self.isnet / self.layer_name (``networks`` module), self.islora,
    self.islyco, self.onlyco (only lycoris available) and self.isxl (model
    has a ``conditioner``), and mirrors isnet/isxl/islora into self.log.
    """
    try:
        import networks
        self.isnet = True
        self.layer_name = "network_layer_name"
    except Exception:
        self.isnet = False
        self.layer_name = "lora_layer_name"
    # Initialise both flags so onlyco below never reads an attribute that was
    # never set (previously an AttributeError when `lora` was not importable)
    # and never reuses a stale value from an earlier call.
    self.islora = False
    self.islyco = False
    try:
        import lora
        self.islora = True
    except Exception:
        pass
    try:
        import lycoris
        self.islyco = True
    except Exception:
        pass
    self.onlyco = (not self.islora) and self.islyco
    self.isxl = hasattr(shared.sd_model,"conditioner")

    self.log["isnet"] = self.isnet
    self.log["isxl"] = self.isxl
    self.log["islora"] = self.islora
|
|
|
|
def resetmemory():
    """Drop the `networks` module's in-memory network cache, if that module exists.

    Best effort: when `networks` cannot be imported, nothing happens. The
    previous bare ``except:`` also swallowed KeyboardInterrupt/SystemExit.
    """
    try:
        import networks as nets
        nets.networks_in_memory = {}
        gc.collect()
    except Exception:
        pass
|
|
|
|
def importer(self):
    """Import and return the active LoRA module: "lycoris" when only it is available, else "lora"."""
    module_name = "lycoris" if self.onlyco else "lora"
    return importlib.import_module(module_name)
|
|
|
|
def loradealer(self, prompts,lratios,elementals, extra_network_data = None):
    """Parse <lora:...>/<lyco:...> calls and apply block weights to the loaded networks.

    For each call it reads:
    - te/unet multipliers ("te=", "unet=" or positional items 1/2);
    - the block-weight spec ("lbw=", then "w=", else positional item 2);
    - the elemental key ("lbwe=" or positional item 3);
    - optional "start=", "stop=" or "step=a-b" steps.

    Weights are expanded to 26 ratios. The per-call values are recorded as
    index-aligned lists on *self* (startsf, stopsf, uf, lf, ef) for
    denoiser_callback(). load_loras_blocks() runs when any call has weights,
    elements, or a start/stop step.

    Args:
        prompts: prompts to parse when *extra_network_data* is not supplied.
        lratios: preset identifier -> comma-separated weight string.
        elementals: elemental identifier -> elemental spec string.
        extra_network_data: already-parsed extra-network data, if available.
    """
    if extra_network_data is None:
        _, extra_network_data = extra_networks.parse_prompts(prompts)
    moduletypes = extra_network_data.keys()

    for ltype in moduletypes:
        # Lists below are filled in call order and stay index-aligned.
        lorans = []
        lorars = []
        te_multipliers = []
        unet_multipliers = []
        elements = []
        starts = []
        stops = []
        fparams = []
        load = False
        go_lbw = False

        if not (ltype == "lora" or ltype == "lyco") : continue
        for called in extra_network_data[ltype]:
            items = called.items
            setnow = False
            name = items[0]
            te = syntaxdealer(items,"te=",1)
            unet = syntaxdealer(items,"unet=",2)
            te,unet = multidealer(te,unet)

            weights = syntaxdealer(items,"lbw=",2) if syntaxdealer(items,"lbw=",2) is not None else syntaxdealer(items,"w=",2)
            elem = syntaxdealer(items, "lbwe=",3)
            start = syntaxdealer(items,"start=",None)
            stop = syntaxdealer(items,"stop=",None)
            # "step=a-b" overrides start/stop when it contains a dash.
            start, stop = stepsdealer(syntaxdealer(items,"step=",None), start, stop)

            # Accept a preset identifier or a literal list with a supported value count.
            if weights is not None and (weights in lratios or any(weights.count(",") == x - 1 for x in BLOCKNUMS)):
                wei = lratios[weights] if weights in lratios else weights
                ratios = [w.strip() for w in wei.split(",")]
                for i,r in enumerate(ratios):
                    # Placeholders: R = random in [0, 1), U = uniform in [-0.5, 1.5],
                    # X[+-n] = the "x=" value (default 1) offset by n.
                    if r =="R":
                        ratios[i] = round(random.random(),3)
                    elif r == "U":
                        ratios[i] = round(random.uniform(-0.5,1.5),3)
                    elif r[0] == "X":
                        base = syntaxdealer(items,"x=", 3) if len(items) >= 4 else 1
                        ratios[i] = getinheritedweight(base, r)
                    else:
                        ratios[i] = float(r)

                if len(ratios) != 26:
                    ratios = to26(ratios)
                setnow = True
            else:
                ratios = [1] * 26

            if elem in elementals:
                setnow = True
                elem = elementals[elem]
            else:
                elem = ""

            if setnow:
                print(f"LoRA Block weight ({ltype}): {name}: (Te:{te},Unet:{unet}) x {ratios}")
                go_lbw = True
            fparams.append([unet,ratios,elem])
            settolist([lorans,te_multipliers,unet_multipliers,lorars,elements,starts,stops],[name,te,unet,ratios,elem,start,stop])

            # Non-forge schedules consumed by denoiser_callback().
            if start:
                self.starts[name] = [int(start),te,unet]
                self.log["starts"] = load = True

            if stop:
                self.stops[name] = int(stop)
                self.log["stops"] = load = True

        # Forge-path lists for denoiser_callback(), aligned with the calls above.
        self.startsf = [int(s) if s is not None else None for s in starts]
        self.stopsf = [int(s) if s is not None else None for s in stops]
        self.uf = unet_multipliers
        self.lf = lorars
        self.ef = elements

        if self.isnet: ltype = "nets"
        if forge: ltype = "forge"
        if go_lbw or load: load_loras_blocks(self, lorans,lorars,te_multipliers,unet_multipliers,elements,ltype, starts=starts)
|
|
|
|
def stepsdealer(step, start, stop):
    """Resolve start/stop steps from an optional "step=" value.

    When *step* is a string containing "-", its hyphen-separated parts are
    returned (as the list produced by ``str.split``). Otherwise *start* and
    *stop* are returned unchanged as a tuple.
    """
    if step is not None and "-" in step:
        return step.split("-")
    return start, stop
|
|
|
|
def settolist(ls,vs):
    """Append each value in *vs* to the list at the same position in *ls*.

    Pairs are formed with ``zip``, so the shorter of the two inputs decides
    how many appends happen. The lists in *ls* are mutated in place.
    """
    for target, value in zip(ls, vs):
        target.append(value)
|
|
|
|
def syntaxdealer(items,target,index):
    """Look up an option such as "lbw=" in a LoRA argument list.

    Returns, in order of precedence:
    * the first item that contains *target*, with every occurrence of
      *target* removed;
    * None when *index* is None or past the end of *items*;
    * None when ``items[index]`` contains "=" (it is a named option);
    * the int 1 when ``items[index]`` contains "@" (presumably a placeholder
      marker — confirm against callers);
    * otherwise ``items[index]`` itself (positional value).
    """
    hit = next((item for item in items if target in item), None)
    if hit is not None:
        return hit.replace(target, "")

    if index is None or index >= len(items):
        return None

    candidate = items[index]
    if "=" in candidate:
        return None
    return 1 if "@" in candidate else candidate
|
|
|
|
def isfloat(t):
    """Return True if ``float(t)`` succeeds, False otherwise.

    Only the exceptions ``float()`` raises for unconvertible input are caught:
    ValueError for malformed strings, TypeError for unsupported types, and
    OverflowError for integers too large for a float. The previous bare
    ``except:`` also swallowed KeyboardInterrupt and SystemExit.
    """
    try:
        float(t)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
|
|
|
|
def multidealer(t, u):
    """Resolve the (text-encoder, unet) multipliers of a LoRA.

    A missing value (None) is filled in from the other one. If both are
    missing, the integers ``(1, 1)`` are returned; otherwise both values are
    converted with ``float``.
    """
    if t is None and u is None:
        return 1, 1
    te_value = u if t is None else t
    unet_value = t if u is None else u
    return float(te_value), float(unet_value)
|
|
|
|
# Matches "X", "X+<number>" or "X-<number>" inside a block-weight entry.
re_inherited_weight = re.compile(r"X([+-])?([\d.]+)?")


def getinheritedweight(weight, offset):
    """Return *weight* as a float, shifted by the offset encoded in *offset*.

    *offset* is searched for ``X``, ``X+n`` or ``X-n``; with a sign, ``n`` is
    added to or subtracted from ``float(weight)``, otherwise ``float(weight)``
    is returned unchanged.
    """
    sign, amount = re_inherited_weight.search(offset).groups()
    base = float(weight)
    if sign == "+":
        return base + float(amount)
    if sign == "-":
        return base - float(amount)
    return base
|
|
|
|
def _apply_lbw_to_matching(loaded_networks, names, lwei, te, unet, elements):
    """Apply lbw() and setall() to each loaded network whose name appears in *names*."""
    for loaded in loaded_networks:
        for n, name in enumerate(names):
            if name == loaded.name:
                lbw(loaded, lwei[n], elements[n])
                setall(loaded, te[n], unet[n])


def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts = None):
    """Apply block weights to the networks that are currently loaded.

    Args:
        self: The script instance; depending on *ltype* its ``lora``,
            ``lycoris`` or ``networks`` attribute is set to the loaded list.
        names: Network names, parallel to *lwei*, *te*, *unet*, *elements*.
        lwei: Per-network block ratios (26 entries each).
        te: Per-network text-encoder multipliers.
        unet: Per-network unet multipliers.
        elements: Per-network elemental spec strings ("" when unused).
        ltype: Backend selector: "lora", "lyco", "nets" or "forge".
        starts: Per-network start steps; only forwarded to lbwf() for "forge".
    """
    oldnew = []
    if "lora" == ltype:
        lora = importer(self)
        self.lora = lora.loaded_loras
        for loaded in lora.loaded_loras:
            for n, name in enumerate(names):
                if name != loaded.name:
                    continue
                # All-ones weights with no elemental spec: leave this network untouched.
                if lwei[n] == [1] * 26 and elements[n] == "":
                    continue
                # Capture the name BEFORE setall() appends its "_in_LBW_" suffix.
                # It used to be read afterwards, so the lora_ctl_network mapping
                # below was keyed by an intermediate name rather than the
                # original one (presumably what lora_ctl_network uses — confirm).
                oldname = loaded.name
                lbw(loaded, lwei[n], elements[n])
                setall(loaded, te[n], unet[n])
                newname = loaded.name + "_in_LBW_" + str(round(random.random(), 3))
                loaded.name = newname
                oldnew.append([oldname, newname])

    elif "lyco" == ltype:
        import lycoris as lycomo
        self.lycoris = lycomo.loaded_lycos
        _apply_lbw_to_matching(lycomo.loaded_lycos, names, lwei, te, unet, elements)

    elif "nets" == ltype:
        import networks as nets
        self.networks = nets.loaded_networks
        _apply_lbw_to_matching(nets.loaded_networks, names, lwei, te, unet, elements)

    elif "forge" == ltype:
        lbwf(te, unet, lwei, elements, starts)

    # Optional integration: copy lora_ctl_network's per-name weights over to
    # the renamed networks. Best-effort — the extension may not be installed.
    try:
        import lora_ctl_network as ctl
        for old, new in oldnew:
            if old in ctl.lora_weights:
                ctl.lora_weights[new] = ctl.lora_weights[old]
    except Exception:
        pass
|
|
|
|
def setall(m,te,unet):
    """Rename network *m* with a random "_in_LBW_" suffix and set its multipliers.

    ``te_multiplier`` receives *te*; ``unet_multiplier`` and ``multiplier``
    both receive *unet*.
    """
    suffix = str(round(random.random(), 3))
    m.name = f"{m.name}_in_LBW_{suffix}"
    m.te_multiplier = te
    m.unet_multiplier = unet
    m.multiplier = unet
|
|
|
|
def smakegrid(imgs,xs,ys,currentmodel,p):
    """Assemble *imgs* into an annotated len(xs) x len(ys) grid.

    Images are laid out row-major using the size of ``imgs[0]`` as the cell
    size. Column labels come from *xs*, row labels from *ys*, and
    *currentmodel* is drawn in the corner via draw_origin(). The grid is
    saved to ``opts.outdir_txt2img_grids`` when ``opts.grid_save`` is set.
    """
    ver_texts = [[images.GridAnnotation(y)] for y in ys]
    hor_texts = [[images.GridAnnotation(x)] for x in xs]

    cols, rows = len(xs), len(ys)
    cell_w, cell_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h), color='black')

    for idx, img in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(img, box=(col * cell_w, row * cell_h))

    canvas = images.draw_grid_annotations(canvas, cell_w, cell_h, hor_texts, ver_texts)
    canvas = draw_origin(canvas, currentmodel, cell_w * cols, cell_h * rows, cell_w)

    if opts.grid_save:
        images.save_image(canvas, opts.outdir_txt2img_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=p.seed, grid=True, p=p)

    return canvas
|
|
|
|
def get_font(fontsize):
    """Load a TrueType font of *fontsize*.

    ``opts.font`` is tried first (falling back to the bundled
    Roboto-Regular.ttf next to the script when it is empty). If that load
    fails for any reason, the bundled font is loaded instead.
    """
    bundled = os.path.join(scriptpath, "Roboto-Regular.ttf")
    try:
        chosen = opts.font or bundled
        return ImageFont.truetype(chosen, fontsize)
    except Exception:
        return ImageFont.truetype(bundled, fontsize)
|
|
|
|
def _multiline_text_width(draw, text, font):
    """Return the rendered width of multi-line *text*, on old and new Pillow.

    ``ImageDraw.multiline_textsize`` was removed in Pillow 10; prefer
    ``multiline_textbbox`` and fall back for very old Pillow releases.
    """
    if hasattr(draw, "multiline_textbbox"):
        left, _, right, _ = draw.multiline_textbbox((0, 0), text, font=font)
        return right - left
    return draw.multiline_textsize(text, font=font)[0]


def draw_origin(grid, text,width,height,width_one):
    """Return a copy of *grid* with *text* drawn in black at its top-left corner.

    Args:
        grid: Source image; it is pasted onto a new white RGB canvas.
        text: Label to draw (may contain newlines).
        width, height: Used to pick the initial font size, (width+height)//25.
        width_one: Width of a single cell. When the grid is wider than one
            cell, the font is shrunk until the text is at most 75% of a cell wide.
    """
    grid_d = Image.new("RGB", (grid.width, grid.height), "white")
    grid_d.paste(grid, (0, 0))

    d = ImageDraw.Draw(grid_d)
    color_active = (0, 0, 0)
    # truetype() rejects a size of 0, so never go below 1.
    fontsize = max(1, (width + height) // 25)
    fnt = get_font(fontsize)

    if grid.width != width_one:
        while _multiline_text_width(d, text, fnt) > width_one * 0.75 and fontsize > 1:
            fontsize -= 1
            fnt = get_font(fontsize)

    d.multiline_text((0, 0), text, font=fnt, fill=color_active, align="center")
    return grid_d
|
|
|
|
def newrun(p, *args):
    """Run a txt2img script on *p* and return its result.

    ``args[0]`` selects the script: 0 means this extension's always-on script
    (the last always-on script whose filename contains "lora_block_weight"),
    any other value N means ``selectable_scripts[N - 1]``. The script receives
    its own slice ``args[args_from:args_to]``. Returns None when index 0 is
    requested but no matching script exists.
    """
    script_index = args[0]

    if script_index == 0:
        script = None
        for candidate in scripts.scripts_txt2img.alwayson_scripts:
            if "lora_block_weight" in candidate.filename:
                script = candidate
    else:
        script = scripts.scripts_txt2img.selectable_scripts[script_index - 1]

    if script is None:
        return None

    processed = script.run(p, *args[script.args_from:script.args_to])
    shared.total_tqdm.clear()
    return processed
|
|
|
|
# Set to True once register() has run.
registerd = False


def register():
    """Expose this extension's always-on script as a selectable script.

    For both the txt2img and img2img runners, each always-on script whose
    filename contains "lora_block_weight" is appended to
    ``selectable_scripts`` (with the title "LoRA Block Weight") unless it is
    already there. Also sets the module flag ``registerd``.
    """
    global registerd
    registerd = True
    for runner in (scripts.scripts_txt2img, scripts.scripts_img2img):
        for obj in runner.alwayson_scripts:
            if "lora_block_weight" not in obj.filename:
                continue
            if obj in runner.selectable_scripts:
                continue
            runner.selectable_scripts.append(obj)
            runner.titles.append("LoRA Block Weight")
|
|
|
|
def effectivechecker(imgs,ss,ls,diffcol,thresh,revxy):
    """Build difference images against a reference and label them.

    Args:
        imgs: ``imgs[1]`` is the reference; ``imgs[0], imgs[2], ...`` are
            compared against it.
        ss: Axis labels; only ``ss[0]`` is reused.
        ls: Labels, one per compared image; each gets "\\n Diff : X%" appended
            IN PLACE (X = share of array elements whose absolute difference
            exceeds *thresh*, rounded to 3 places).
        diffcol: If it contains "white", the diff images are inverted.
        thresh: Threshold passed to cv2.threshold (THRESH_BINARY).
        revxy: Selects the output layout.

    Returns:
        Without *revxy*: (outs, ["diff", ss[0], "source"], ls) where outs
        repeats [diff, image, reference] per compared image.
        With *revxy*: (outs, ls, ["source", ss[0], "diff"]) where outs is all
        references, then all compared images, then all diffs.
    """
    source = imgs[1]
    candidates = imgs[::2]
    diff_images = []
    percents = []

    for candidate in candidates:
        delta = cv2.absdiff(np.array(candidate), np.array(source))
        mask = cv2.threshold(delta, int(thresh), 255, cv2.THRESH_BINARY)[1].astype(np.uint8)
        percents.append((np.count_nonzero(mask) * 100) / mask.size)
        if "white" in diffcol:
            delta = cv2.bitwise_not(delta)
        diff_images.append(Image.fromarray(delta))

    for i, label in enumerate(ls):
        ls[i] = label + "\n Diff : " + str(round(percents[i], 3)) + "%"

    if revxy:
        outs = [source] * len(diff_images) + candidates + diff_images
        return outs, ls, ["source", ss[0], "diff"]

    outs = []
    for diff, candidate in zip(diff_images, candidates):
        outs.extend((diff, candidate, source))
    return outs, ["diff", ss[0], "source"], ls
|
|
|
|
def lbw(lora,lwei,elemental):
    """Scale every module of *lora* in place by its block-weight ratio.

    Args:
        lora: Loaded network object with a ``modules`` dict.
        lwei: Per-block ratios, looked up per module key by ratiodealer().
        elemental: Comma-separated elemental spec string ("" when unused).

    Returns:
        The same *lora* object.

    Modules whose class name is in LORAANDSOON have the listed attribute
    scaled; otherwise ``up_model.weight`` or ``up.weight`` is scaled. Modules
    with none of these are reported instead of raising AttributeError (the
    old "unknown" message was unreachable). Keys that matched no block are
    printed at the end.
    """
    elemental = elemental.split(",")
    errormodules = []
    for key, module in lora.modules.items():
        ratio, errormodule = ratiodealer(key, lwei, elemental)
        if errormodule:
            errormodules.append(errormodule)

        ltype = type(module).__name__
        if ltype in LORAANDSOON:
            attr = LORAANDSOON[ltype]
            scaled = getattr(module, attr) * ratio
            # OFT modules keep the scaled value as-is; every other type is
            # re-wrapped in torch.nn.Parameter, as before.
            setattr(module, attr, scaled if "OFT" in ltype else torch.nn.Parameter(scaled))
        elif hasattr(module, "up_model"):
            module.up_model.weight = torch.nn.Parameter(module.up_model.weight * ratio)
        elif hasattr(module, "up"):
            module.up.weight = torch.nn.Parameter(module.up.weight * ratio)
        else:
            print(f"unknown LoRA module type: {ltype} ({key})")

    if len(errormodules) > 0:
        print(errormodules)
    return lora
|
|
|
|
# Forge patch types that lbwf() re-weights.
LORAS = ["lora", "loha", "lokr"]


def _split_elemental(e):
    """Split a comma-separated elemental spec into the list ratiodealer() iterates, as lbw() does."""
    return e.split(",") if isinstance(e, str) else e


def _reweight_patches(patches, multipliers, lwei, elemental, starts, errormodules):
    """Rewrite each entry of a Forge ``patches`` dict with block-weighted strengths.

    Only patches whose ``val[1][0]`` is in LORAS are kept. They are paired
    positionally with *multipliers*, *lwei*, *elemental* and *starts*, and the
    first tuple element becomes ``ratio * multiplier``, or 0 when a start step
    is set. Keys that ratiodealer() could not map to a block are appended to
    *errormodules*.

    NOTE(review): patches of other types are not carried into the new list,
    and surplus LoRA patches are dropped by zip — both match the previous
    behaviour; confirm this is intended.
    """
    for key, vals in patches.items():
        lvals = [val for val in vals if val[1][0] in LORAS]
        n_vals = []
        for v, m, l, e, s in zip(lvals, multipliers, lwei, elemental, starts):
            ratio, errormodule = ratiodealer(key.replace(".", "_"), l, _split_elemental(e))
            n_vals.append((ratio * m if s is None else 0, *v[1:]))
            if errormodule:
                errormodules.append(errormodule)
        patches[key] = n_vals


def lbwf(mt, mu, lwei, elemental, starts):
    """Apply block weights to Forge's LoRA patches in place.

    Args:
        mt: Per-LoRA text-encoder multipliers.
        mu: Per-LoRA unet multipliers.
        lwei: Per-LoRA block ratios.
        elemental: Per-LoRA elemental spec strings (comma-separated).
        starts: Per-LoRA start steps; a non-None start zeroes the unet patch.
            None is treated as "no start steps".

    Fixes over the previous version: ``errormodules`` is collected across all
    keys (it was reset per unet key, and unbound when there were no unet
    patches), and elemental specs are split on "," before use (a raw string
    made ratiodealer() iterate characters, so specs were never applied).
    """
    errormodules = []
    lora_objects = shared.sd_model.forge_objects_after_applying_lora
    if starts is None:
        starts = [None] * len(mu)

    _reweight_patches(lora_objects.unet.patches, mu, lwei, elemental, starts, errormodules)
    # Text-encoder patches ignore start steps.
    _reweight_patches(lora_objects.clip.patcher.patches, mt, lwei, elemental, [None] * len(mt), errormodules)

    if errormodules:
        print("Unknown modules:", errormodules)
|
|
|
|
def ratiodealer(key, lwei, elemental):
    """Return the block-weight ratio for one module key.

    Args:
        key: Module key. A block is selected when one of the module-level
            BLOCKS entries occurs as a substring of it.
        lwei: Ratios indexed by BLOCKS position; BLOCKS indices 26 and 27
            are folded onto index 0 (BASE in BLOCKID26) — presumably the
            text-encoder entries, confirm against BLOCKS.
        elemental: Iterable of "blocks:weights:ratio" specs; entries that do
            not contain exactly two ":" are ignored.

    Returns:
        (ratio, errormodules). ratio starts at 1, takes lwei[...] of the last
        matching block, and is overridden by the last elemental spec that
        matches. errormodules is [key] when no block matched, else [].
    """
    ratio = 1
    picked = False
    errormodules = []
    currentblock = 0

    # No break: if several BLOCKS entries occur in key, the last one wins.
    for i,block in enumerate(BLOCKS):
        if block in key:
            if i == 26 or i == 27:
                i = 0
            ratio = lwei[i]
            picked = True
            currentblock = i

    if not picked:
        errormodules.append(key)

    if len(elemental) > 0:
        # Append the block ID so spec block names can be matched by substring.
        skey = key + BLOCKID26[currentblock]
        for d in elemental:
            if d.count(":") != 2 :continue
            # blocks part: "A-B" ranges expanded by hyphener(); space-separated.
            dbs,dws,dr = (hyphener(d.split(":")[0]),d.split(":")[1],d.split(":")[2])
            dbs,dws = (dbs.split(" "), dws.split(" "))
            # A leading "NOT" inverts the match for that list.
            dbn,dbs = (True,dbs[1:]) if dbs[0] == "NOT" else (False,dbs)
            dwn,dws = (True,dws[1:]) if dws[0] == "NOT" else (False,dws)
            # Block test: without NOT, match if any listed block is in skey;
            # with NOT, match only if none is. An empty entry ("") always occurs.
            flag = dbn
            for db in dbs:
                if db in skey:
                    flag = not dbn
            # Block test failed: skip this spec. Otherwise reset flag for the weight test.
            if flag:flag = dwn
            else:continue
            # Weight test: same rules, matched as substrings of skey.
            for dw in dws:
                if dw in skey:
                    flag = not dwn
            if flag:
                dr = float(dr)
                if princ :print(dbs,dws,key,dr)
                ratio = dr

    return ratio, errormodules
|
|
|
|
# Module class name -> name of the attribute that lbw() scales for that type.
LORAANDSOON = dict(
    LoraHadaModule="w1a",
    LycoHadaModule="w1a",
    NetworkModuleHada="w1a",
    FullModule="weight",
    NetworkModuleFull="weight",
    IA3Module="w",
    NetworkModuleIa3="w",
    LoraKronModule="w1",
    LycoKronModule="w1",
    NetworkModuleLokr="w1",
    NetworkModuleGLora="w1a",
    NetworkModuleNorm="w_norm",
    NetworkModuleOFT="scale",
)
|
|
|
|
def hyphener(t):
    """Expand "A-B" block ranges in a space-separated list of BLOCKID26 names.

    Each token containing "-" is replaced by every BLOCKID26 name between its
    first two parts, inclusive, in either direction; other tokens are kept.
    Unknown block names raise ValueError from ``list.index``.
    """
    expanded = []
    for token in t.split(" "):
        if "-" not in token:
            expanded.append(token)
            continue
        ends = token.split("-")
        second = BLOCKID26.index(ends[1])
        first = BLOCKID26.index(ends[0])
        lo, hi = (first, second) if second > first else (second, first)
        expanded.append(" ".join(BLOCKID26[lo:hi + 1]))
    return " ".join(expanded)
|
|
|
|
# Default elemental presets, separated by blank lines. Each entry has four
# ":"-separated fields; the first is presumably the preset name and the rest
# the "blocks:weights:ratio" spec consumed by ratiodealer().
ELEMPRESETS = "\n\n".join([
    "ATTNDEEPON:IN05-OUT05:attn:1",
    "ATTNDEEPOFF:IN05-OUT05:attn:0",
    "PROJDEEPOFF:IN05-OUT05:proj:0",
    "XYZ:::1",
])
|
|
|
|
def to26(ratios):
    """Expand a short ratio list to the 26-block layout of BLOCKID26.

    The block IDs for ``len(ratios)`` are looked up via BLOCKNUMS/BLOCKIDS;
    each ratio is placed at its block's BLOCKID26 position and every other
    position is 0.
    """
    positions = BLOCKIDS[BLOCKNUMS.index(len(ratios))]
    expanded = [0] * 26
    for idx, block_id in enumerate(positions):
        expanded[BLOCKID26.index(block_id)] = ratios[idx]
    return expanded
|
|
|
|
def checkloadcond(l:str)->bool:
    """Return True when preset line *l* should NOT be loaded.

    A line is loadable only if it contains ":", contains no "#" (presumably
    a comment marker), and its comma count equals one of the BLOCKNUMS
    values minus one (i.e. it carries exactly that many values).
    """
    if ":" not in l or "#" in l:
        return True
    return not any(l.count(",") == x - 1 for x in BLOCKNUMS)
|
|
|