toto10's picture
Upload folder using huggingface_hub (#1)
34097e9
raw
history blame
28.3 kB
'''
Modified version for full net lora
(Lora for ResBlock and up/down sample block)
'''
import os, sys
import re
import torch
from modules import shared, devices, sd_models
now_dir = os.path.dirname(os.path.abspath(__file__))
lora_path = os.path.join(now_dir, '..', '..', '..', 'extensions-builtin/Lora')
sys.path.insert(0, lora_path)
import lora
new_lora = 'lora_calc_updown' in dir(lora)
from locon_compvis import LoConModule, LoConNetworkCompvis, create_network_and_apply_compvis
# Hijack the Additional Networks extension so that its LoRA loader builds
# LoCon modules.  The hijack is deliberately disabled ("skip addnet since don't
# support new version"): the raise below always lands in the except branch, and
# the statements after it are kept only so the hijack can be re-enabled later.
try:
    # skip addnet since don't support new version
    raise RuntimeError('addnet hijack disabled')
    now_dir = os.path.dirname(os.path.abspath(__file__))
    addnet_path = os.path.join(now_dir, '..', '..', 'sd-webui-additional-networks/scripts')
    sys.path.append(addnet_path)
    import lora_compvis
    import scripts
    scripts.lora_compvis = lora_compvis
    scripts.lora_compvis.LoRAModule = LoConModule
    scripts.lora_compvis.LoRANetworkCompvis = LoConNetworkCompvis
    scripts.lora_compvis.create_network_and_apply_compvis = create_network_and_apply_compvis
    print('LoCon Extension hijack addnet extension successfully')
except Exception:
    # Catch Exception (not everything) so KeyboardInterrupt / SystemExit raised
    # while importing are not silently swallowed.
    print('Additional Network extension not installed, Only hijack built-in lora')
# ---------------------------------------------------------------------------
# Hijack sd-webui LoRA
# ---------------------------------------------------------------------------

# Patterns for diffusers-style LoRA key names.  Captured groups made up solely
# of digits are converted to int by convert_diffusers_name_to_compvis.
re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_unet_conv_in = re.compile(r"lora_unet_conv_in(.+)")
re_unet_conv_out = re.compile(r"lora_unet_conv_out(.+)")
re_unet_time_embed = re.compile(r"lora_unet_time_embedding_linear_(\d+)(.+)")
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+)")
re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")

# (diffusers resnet sub-layer prefix, compvis sub-layer name), tried in order.
_RESNET_SUBLAYERS = (
    ('conv1', 'in_layers_2'),
    ('conv2', 'out_layers_3'),
    ('time_emb_proj', 'emb_layers_1'),
    ('conv_shortcut', 'skip_connection'),
)


def _convert_resnet_sublayer(block, rest):
    """Append the compvis name of resnet sub-layer *rest* to *block*; None if unknown."""
    for diffusers_name, compvis_name in _RESNET_SUBLAYERS:
        if rest.startswith(diffusers_name):
            return f"{block}{compvis_name}{rest[len(diffusers_name):]}"
    return None


def convert_diffusers_name_to_compvis(key, is_sd2):
    """Translate a diffusers-style LoRA key into the compvis layer name.

    Only the layer part is rewritten; whatever follows it (``.alpha``,
    ``.lora_up.weight``, ...) is carried through unchanged.  For example
    ``lora_unet_down_blocks_0_resnets_1_conv1.lora_up.weight`` becomes
    ``diffusion_model_input_blocks_2_0_in_layers_2.lora_up.weight``.

    Args:
        key: diffusers-style state-dict key.
        is_sd2: emit ``model_transformer_resblocks_*`` text-encoder names
            (with mlp/attn renames) instead of
            ``transformer_text_model_encoder_layers_*``.

    Returns:
        The converted key, or *key* itself (after prefix normalisation) when
        no known pattern applies.
    """
    # I don't know why but some state dict has this kind of thing
    key = key.replace('text_model_text_model', 'text_model')

    def match(regex):
        # Return the groups of a successful match (pure-digit groups as int),
        # or None.  fullmatch: a group that merely *starts* with digits stays a
        # string instead of crashing int().
        r = regex.match(key)
        if r is None:
            return None
        return [int(g) if re_digits.fullmatch(g) else g for g in r.groups()]

    if (m := match(re_unet_conv_in)) is not None:
        return f'diffusion_model_input_blocks_0_0{m[0]}'
    if (m := match(re_unet_conv_out)) is not None:
        return f'diffusion_model_out_2{m[0]}'
    if (m := match(re_unet_time_embed)) is not None:
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
    if (m := match(re_unet_down_blocks)) is not None:
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
    if (m := match(re_unet_mid_blocks)) is not None:
        return f"diffusion_model_middle_block_1_{m[1]}"
    if (m := match(re_unet_up_blocks)) is not None:
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"

    # Resnet keys with an unknown sub-layer fall through to the later patterns.
    if (m := match(re_unet_down_blocks_res)) is not None:
        converted = _convert_resnet_sublayer(
            f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_0_", m[2])
        if converted is not None:
            return converted
    if (m := match(re_unet_mid_blocks_res)) is not None:
        converted = _convert_resnet_sublayer(
            f"diffusion_model_middle_block_{m[0] * 2}_", m[1])
        if converted is not None:
            return converted
    if (m := match(re_unet_up_blocks_res)) is not None:
        converted = _convert_resnet_sublayer(
            f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_0_", m[2])
        if converted is not None:
            return converted

    if (m := match(re_unet_downsample)) is not None:
        return f"diffusion_model_input_blocks_{m[0] * 3 + 3}_0_op{m[1]}"
    if (m := match(re_unet_upsample)) is not None:
        # up_blocks_0 has no attention, so its upsampler sits at index 1, not 2.
        return f"diffusion_model_output_blocks_{m[0] * 3 + 2}_{1 + (m[0] != 0)}_conv{m[1]}"

    if (m := match(re_text_block)) is not None:
        layer, rest = m
        if is_sd2:
            if 'mlp_fc1' in rest:
                return f"model_transformer_resblocks_{layer}_{rest.replace('mlp_fc1', 'mlp_c_fc')}"
            if 'mlp_fc2' in rest:
                return f"model_transformer_resblocks_{layer}_{rest.replace('mlp_fc2', 'mlp_c_proj')}"
            return f"model_transformer_resblocks_{layer}_{rest.replace('self_attn', 'attn')}"
        return f"transformer_text_model_encoder_layers_{layer}_{rest}"

    return key
class LoraOnDisk:
    """A LoRA file known by its display ``name`` and on-disk ``filename``."""

    def __init__(self, name, filename):
        self.filename = filename
        self.name = name
class LoraModule:
    """One loaded LoRA: its name, strength, file mtime and per-layer modules."""

    def __init__(self, name):
        self.name = name
        # Strength factor; starts at 1.0.
        self.multiplier = 1.0
        # Layer key -> per-layer module; starts empty.
        self.modules = {}
        # Modification time of the source file; unset until assigned.
        self.mtime = None
class FakeModule(torch.nn.Module):
    """nn.Module shim that exposes ``weight`` but delegates forward to ``func``.

    Callers can read ``.weight`` (e.g. its shape) while the forward pass runs
    an arbitrary callable.
    """

    def __init__(self, weight, func):
        super().__init__()
        self.weight = weight
        self.func = func

    def forward(self, x):
        delegate = self.func
        return delegate(x)
class FullModule:
    """Per-layer module storing a complete weight delta ("diff")."""

    def __init__(self):
        self.weight = None
        self.alpha = None
        self.op = None
        self.extra_args = {}
        self.shape = None
        self.up = None

    def down(self, x):
        # Identity: a full weight needs no down projection.
        return x

    def inference(self, x):
        """Apply ``op`` to *x* with the stored weight and ``extra_args``."""
        kwargs = self.extra_args
        return self.op(x, self.weight, **kwargs)
class LoraUpDownModule:
    """Conventional LoRA / LoCon layer built from up, optional mid, and down modules.

    Without a bias the layer runs ``up(mid(down(x)))`` (mid skipped when
    absent).  With a dense or sparse bias tensor the full weight is rebuilt,
    the bias is added, and ``op`` is applied once with that weight.
    """

    def __init__(self):
        self.up_model = None
        self.mid_model = None
        self.down_model = None
        self.alpha = None
        self.dim = None
        self.op = None
        self.extra_args = {}
        self.shape = None
        self.bias = None
        self.up = None

    def down(self, x):
        return x

    def _rebuild_weight_2d(self):
        """Return the merged up/(mid)/down weight flattened to ``(out_dim, -1)``."""
        out_dim = self.up_model.weight.size(0)
        rank = self.down_model.weight.size(0)
        up = self.up_model.weight.reshape(out_dim, -1)
        down = self.down_model.weight.reshape(rank, -1)
        if self.mid_model is None:
            return up @ down
        # CP decomposition: contract the mid core with the up and down factors
        # so the spatial kernel of mid_model is part of the rebuilt weight.
        merged = torch.einsum('n m k l, i n, m j -> i j k l', self.mid_model.weight, up, down)
        return merged.reshape(out_dim, -1)

    def inference(self, x):
        bias = getattr(self, 'bias', None)
        # load_lora keeps a [indices, values, size] list until the sparse bias
        # is complete; only a finished tensor takes the rebuild path.
        if isinstance(bias, torch.Tensor):
            weight = (self._rebuild_weight_2d() + bias).reshape(self.shape)
            return self.op(
                x, weight,
                bias=None,
                **self.extra_args
            )
        if self.mid_model is None:
            return self.up_model(self.down_model(x))
        return self.up_model(self.mid_model(self.down_model(x)))
def make_weight_cp(t, wa, wb):
    """Rebuild a weight from CP core *t* and factor matrices *wa*, *wb*.

    Dim 1 of ``t`` is contracted with ``wb`` and dim 0 with ``wa``; the result
    has shape ``(wa.size(1), wb.size(1), *t.shape[2:])``.
    """
    contracted = torch.einsum('i j k l, j r -> i r k l', t, wb)
    rebuilt = torch.einsum('i j k l, i r -> r j k l', contracted, wa)
    return rebuilt
class LoraHadaModule:
    """LoHa layer: weight = (w1a @ w1b) * (w2a @ w2b), or CP-rebuilt via t1/t2."""

    def __init__(self):
        self.t1 = None
        self.w1a = None
        self.w1b = None
        self.t2 = None
        self.w2a = None
        self.w2b = None
        self.alpha = None
        self.dim = None
        self.op = None
        self.extra_args = {}
        self.shape = None
        self.bias = None
        self.up = None

    def down(self, x):
        return x

    def inference(self, x):
        """Apply ``op`` to *x* with the rebuilt Hadamard-product weight (+ bias)."""
        bias = getattr(self, 'bias', None)
        if not isinstance(bias, torch.Tensor):
            bias = 0
        if self.t1 is not None:
            hadamard = (make_weight_cp(self.t1, self.w1a, self.w1b)
                        * make_weight_cp(self.t2, self.w2a, self.w2b))
        else:
            hadamard = (self.w1a @ self.w1b) * (self.w2a @ self.w2b)
        weight = (hadamard + bias).view(self.shape)
        return self.op(x, weight, bias=None, **self.extra_args)
class IA3Module:
    """(IA)^3 layer: holds the scale vector ``w`` and its ``on_input`` flag."""

    def __init__(self):
        self.w, self.alpha, self.on_input = None, None, None
def make_kron(orig_shape, w1, w2):
    """Return ``kron(w1, w2)`` reshaped to *orig_shape*.

    When *w2* is 4-D, *w1* gets two trailing singleton dims first so the
    spatial dims of *w2* are kept rather than multiplied into the channel dims.
    """
    if len(w2.shape) != 4:
        return torch.kron(w1, w2).reshape(orig_shape)
    w1_4d = w1.unsqueeze(2).unsqueeze(2)
    return torch.kron(w1_4d, w2.contiguous()).reshape(orig_shape)
class LoraKronModule:
    """LoKr layer: weight delta = kron(w1, w2).

    ``w1`` is stored directly or as ``w1a @ w1b``.  ``w2`` is stored directly,
    as ``w2a @ w2b``, or as a CP decomposition ``(t2, w2a, w2b)``.
    """

    def __init__(self):
        self.w1 = None
        self.w1a = None
        self.w1b = None
        self.w2 = None
        self.t2 = None
        self.w2a = None
        self.w2b = None
        self._alpha = None
        self.dim = None
        self.op = None
        self.extra_args = {}
        self.shape = None
        self.bias = None
        self.up = None

    @property
    def alpha(self):
        # alpha is only reported when at least one factor is low-rank; with
        # two full factors it reads as None.
        if self.w1a is None and self.w2a is None:
            return None
        return self._alpha

    @alpha.setter
    def alpha(self, x):
        self._alpha = x

    def down(self, x):
        return x

    def _full_weight(self):
        """Rebuild kron(w1, w2), shaped like ``self.shape``, from whichever factors are set."""
        w1 = self.w1 if self.w1 is not None else self.w1a @ self.w1b
        if self.w2 is not None:
            w2 = self.w2
        elif self.t2 is None:
            w2 = self.w2a @ self.w2b
        else:
            w2 = make_weight_cp(self.t2, self.w2a, self.w2b)
        # make_kron pads w1 with singleton dims for 4-D w2, so conv kernels
        # keep their spatial dims (plain torch.kron would fold them away).
        return make_kron(self.shape, w1, w2)

    def inference(self, x):
        """Apply ``op`` to *x* with the rebuilt Kronecker weight (+ bias, if built)."""
        weight = self._full_weight()
        bias = getattr(self, 'bias', None)
        if isinstance(bias, torch.Tensor):
            weight = (weight.reshape(bias.shape) + bias).reshape(self.shape)
        return self.op(
            x,
            weight,
            **self.extra_args
        )
# Parameter suffixes (the part after the layer key) that load_lora routes to
# each per-layer module type.
CON_KEY = {"lora_up.weight", "lora_down.weight", "lora_mid.weight"}
HADA_KEY = {
    "hada_t1", "hada_w1_a", "hada_w1_b",
    "hada_t2", "hada_w2_a", "hada_w2_b",
}
IA3_KEY = {"weight", "on_input"}
KRON_KEY = {
    "lokr_w1", "lokr_w1_a", "lokr_w1_b",
    "lokr_t2", "lokr_w2", "lokr_w2_a", "lokr_w2_b",
}
def load_lora(name, filename):
    """Load the LoRA-family state dict at *filename* into a new LoraModule.

    Each key is converted with convert_diffusers_name_to_compvis, split into
    ``<layer key>.<lora_key>`` at the first dot, and matched against
    ``shared.sd_model.lora_layer_mapping``.  The ``lora_key`` suffix decides
    where the tensor goes:

    * ``alpha``   -> scalar alpha on the layer's current module
    * ``diff``    -> FullModule holding a full weight delta
    * ``bias_*``  -> sparse COO bias assembled from indices / values / size
    * CON_KEY     -> LoraUpDownModule (up / mid / down sub-modules)
    * HADA_KEY    -> LoraHadaModule
    * IA3_KEY     -> IA3Module
    * KRON_KEY    -> LoraKronModule

    Args:
        name: name stored on the returned LoraModule.
        filename: path passed to os.path.getmtime and sd_models.read_state_dict.

    Returns:
        LoraModule whose ``modules`` dict maps compvis layer keys to the
        per-layer modules above.

    Raises:
        AssertionError: when a key matches a layer of an unsupported type or has
            an unrecognised suffix (only while asserts are enabled).
    """
    print('locon load lora method')
    # NOTE: this local shadows the imported ``lora`` module inside this function.
    lora = LoraModule(name)
    lora.mtime = os.path.getmtime(filename)
    sd = sd_models.read_state_dict(filename)
    # SD2 naming is detected from the model's own layer mapping.
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
    keys_failed_to_match = []
    for key_diffusers, weight in sd.items():
        fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2)
        # "<layer key>.<lora_key>", split at the first dot.
        key, lora_key = fullkey.split(".", 1)
        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
        if sd_module is None:
            # Retry with a trailing _q_proj/_k_proj/_v_proj stripped
            # (presumably for layers mapped only under a parent name — TODO confirm).
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
        if sd_module is None:
            # NOTE(review): prints every unmatched key individually.
            print(key)
            keys_failed_to_match.append(key_diffusers)
            continue
        lora_module = lora.modules.get(key, None)
        if lora_module is None:
            # Default container; the HADA/IA3/KRON branches below replace it with
            # their own type when their first key for this layer arrives.
            lora_module = LoraUpDownModule()
            lora.modules[key] = lora_module
        if lora_key == "alpha":
            lora_module.alpha = weight.item()
            continue
        if lora_key == "diff":
            # Full weight delta: this layer becomes a FullModule.
            weight = weight.to(device=devices.device, dtype=devices.dtype)
            weight.requires_grad_(False)
            lora_module = FullModule()
            lora.modules[key] = lora_module
            lora_module.weight = weight
            # up (below) wraps this same tensor, so alpha / up.weight.size(1)
            # evaluates to 1 wherever that scale factor is computed.
            lora_module.alpha = weight.size(1)
            lora_module.up = FakeModule(
                weight,
                lora_module.inference
            )
            # NOTE(review): FakeModule keeps weight as a plain attribute, so this
            # .to() presumably does not move it; weight was already placed on
            # devices.device above.
            lora_module.up.to(device=devices.cpu if new_lora else devices.device, dtype=devices.dtype)
            if len(weight.shape)==2:
                lora_module.op = torch.nn.functional.linear
                lora_module.extra_args = {
                    'bias': None
                }
            else:
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding,
                    'bias': None
                }
            continue
        if 'bias_' in lora_key:
            # The sparse bias arrives as three keys; collect them in a list and
            # build the COO tensor once indices, values and size are all present.
            if lora_module.bias is None:
                lora_module.bias = [None, None, None]
            if 'bias_indices' == lora_key:
                lora_module.bias[0] = weight
            elif 'bias_values' == lora_key:
                lora_module.bias[1] = weight
            elif 'bias_size' == lora_key:
                lora_module.bias[2] = weight
            if all((i is not None) for i in lora_module.bias):
                print('build bias')
                # Stored on CPU when the webui provides lora_calc_updown
                # (new_lora), otherwise on the compute device.
                lora_module.bias = torch.sparse_coo_tensor(
                    lora_module.bias[0],
                    lora_module.bias[1],
                    tuple(lora_module.bias[2]),
                ).to(device=devices.cpu if new_lora else devices.device, dtype=devices.dtype)
                lora_module.bias.requires_grad_(False)
            continue
        if lora_key in CON_KEY:
            # Wrap each factor in a bias-free Linear/Conv2d matching the target
            # layer.  Exact type() checks: other subclasses are rejected.
            if (type(sd_module) == torch.nn.Linear
                    or type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear
                    or type(sd_module) == torch.nn.MultiheadAttention):
                weight = weight.reshape(weight.shape[0], -1)
                module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
                lora_module.op = torch.nn.functional.linear
            elif type(sd_module) == torch.nn.Conv2d:
                if lora_key == "lora_down.weight":
                    if len(weight.shape) == 2:
                        weight = weight.reshape(weight.shape[0], -1, 1, 1)
                    # A non-1x1 down factor takes the target conv's kernel,
                    # stride and padding; a 1x1 one (e.g. with a mid core) does not.
                    if weight.shape[2] != 1 or weight.shape[3] != 1:
                        module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
                    else:
                        module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
                elif lora_key == "lora_mid.weight":
                    module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], sd_module.kernel_size, sd_module.stride, sd_module.padding, bias=False)
                elif lora_key == "lora_up.weight":
                    module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding
                }
            else:
                # NOTE(review): assert is stripped under python -O.
                assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
            if hasattr(sd_module, 'weight'):
                lora_module.shape = sd_module.weight.shape
            with torch.no_grad():
                if weight.shape != module.weight.shape:
                    weight = weight.reshape(module.weight.shape)
                module.weight.copy_(weight)
            module.to(device=devices.cpu if new_lora else devices.device, dtype=devices.dtype)
            module.requires_grad_(False)
            if lora_key == "lora_up.weight":
                lora_module.up_model = module
                # up.weight.size(1) of this wrapper is what alpha is divided by.
                lora_module.up = FakeModule(
                    lora_module.up_model.weight,
                    lora_module.inference
                )
            elif lora_key == "lora_mid.weight":
                lora_module.mid_model = module
            elif lora_key == "lora_down.weight":
                lora_module.down_model = module
                lora_module.dim = weight.shape[0]
            else:
                print(lora_key)
        elif lora_key in HADA_KEY:
            # LoHa factors: swap in a LoraHadaModule, carrying over any alpha /
            # bias already loaded for this layer.
            if type(lora_module) != LoraHadaModule:
                alpha = lora_module.alpha
                bias = lora_module.bias
                lora_module = LoraHadaModule()
                lora_module.alpha = alpha
                lora_module.bias = bias
                lora.modules[key] = lora_module
            if hasattr(sd_module, 'weight'):
                lora_module.shape = sd_module.weight.shape
            weight = weight.to(device=devices.cpu if new_lora else devices.device, dtype=devices.dtype)
            weight.requires_grad_(False)
            if lora_key == 'hada_w1_a':
                lora_module.w1a = weight
                if lora_module.up is None:
                    lora_module.up = FakeModule(
                        lora_module.w1a,
                        lora_module.inference
                    )
            elif lora_key == 'hada_w1_b':
                lora_module.w1b = weight
                lora_module.dim = weight.shape[0]
            elif lora_key == 'hada_w2_a':
                lora_module.w2a = weight
            elif lora_key == 'hada_w2_b':
                lora_module.w2b = weight
            elif lora_key == 'hada_t1':
                lora_module.t1 = weight
                # t1 always replaces any up wrapper created from w1a.
                lora_module.up = FakeModule(
                    lora_module.t1,
                    lora_module.inference
                )
            elif lora_key == 'hada_t2':
                lora_module.t2 = weight
            if (type(sd_module) == torch.nn.Linear
                    or type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear
                    or type(sd_module) == torch.nn.MultiheadAttention):
                lora_module.op = torch.nn.functional.linear
            elif type(sd_module) == torch.nn.Conv2d:
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding
                }
            else:
                assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
        elif lora_key in IA3_KEY:
            # (IA)^3: swapped in without carrying over alpha/bias.
            if type(lora_module) != IA3Module:
                lora_module = IA3Module()
                lora.modules[key] = lora_module
            if lora_key == "weight":
                lora_module.w = weight.to(devices.device, dtype=devices.dtype)
            elif lora_key == "on_input":
                lora_module.on_input = weight
        elif lora_key in KRON_KEY:
            # LoKr factors: swap in a LoraKronModule, carrying over alpha / bias.
            if not isinstance(lora_module, LoraKronModule):
                alpha = lora_module.alpha
                bias = lora_module.bias
                lora_module = LoraKronModule()
                lora_module.alpha = alpha
                lora_module.bias = bias
                lora.modules[key] = lora_module
            if hasattr(sd_module, 'weight'):
                lora_module.shape = sd_module.weight.shape
            weight = weight.to(device=devices.cpu if new_lora else devices.device, dtype=devices.dtype)
            weight.requires_grad_(False)
            if lora_key == 'lokr_w1':
                lora_module.w1 = weight
            elif lora_key == 'lokr_w1_a':
                lora_module.w1a = weight
                if lora_module.up is None:
                    lora_module.up = FakeModule(
                        lora_module.w1a,
                        lora_module.inference
                    )
            elif lora_key == 'lokr_w1_b':
                lora_module.w1b = weight
            elif lora_key == 'lokr_w2':
                lora_module.w2 = weight
            elif lora_key == 'lokr_w2_a':
                lora_module.w2a = weight
                lora_module.dim = weight.shape[0]
                if lora_module.up is None:
                    lora_module.up = FakeModule(
                        lora_module.w2a,
                        lora_module.inference
                    )
            elif lora_key == 'lokr_w2_b':
                lora_module.w2b = weight
            elif lora_key == 'lokr_t2':
                lora_module.t2 = weight
            if (any(isinstance(sd_module, torch_layer) for torch_layer in
                    [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention])):
                lora_module.op = torch.nn.functional.linear
            elif isinstance(sd_module, torch.nn.Conv2d):
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding
                }
        else:
            # NOTE(review): this message predates the diff/bias/HADA/IA3/KRON key
            # families, and assert is stripped under python -O.
            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
    if len(keys_failed_to_match) > 0:
        # NOTE(review): dumps the whole layer mapping — debug-level output.
        print(shared.sd_model.lora_layer_mapping)
        print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
    return lora
def lora_forward(module, input, res):
    """Add every loaded LoRA's contribution for *module* to its output *res*.

    Args:
        module: the torch layer being run; its ``lora_layer_name`` attribute
            (if present) selects the per-layer module in each loaded LoRA.
        input: the layer's input tensor.
        res: the layer's output so far.

    Returns:
        *res* plus ``delta * scale`` for every loaded LoRA that has a module for
        this layer and a truthy multiplier.

    Raises:
        NotImplementedError: if a per-layer module has neither ``inference``
            nor ``up``.
    """
    if len(lora.loaded_loras) == 0:
        return res

    layer_name = getattr(module, 'lora_layer_name', None)
    for loaded in lora.loaded_loras:
        lora_module = loaded.modules.get(layer_name, None)
        if lora_module is None or not loaded.multiplier:
            continue

        # alpha scaling: divide by up.weight.size(1) when the module has an
        # ``up`` attribute, else by ``dim``; no alpha means a factor of 1.0.
        if lora_module.alpha:
            if hasattr(lora_module, 'up'):
                divisor = lora_module.up.weight.size(1)
            else:
                divisor = lora_module.dim
            scale = loaded.multiplier * (lora_module.alpha / divisor)
        else:
            scale = loaded.multiplier * 1.0

        use_output = shared.opts.lora_apply_to_outputs and res.shape == input.shape
        x = res if use_output else input

        if hasattr(lora_module, 'inference'):
            delta = lora_module.inference(x)
        elif hasattr(lora_module, 'up'):
            delta = lora_module.up(lora_module.down(x))
        else:
            raise NotImplementedError(
                "Your settings, extensions or models are not compatible with each other."
            )
        res = res + delta * scale
    return res
def _rebuild_conventional(up, down, shape):
return (up.reshape(up.size(0), -1) @ down.reshape(down.size(0), -1)).reshape(shape)
def _rebuild_cp_decomposition(up, down, mid):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
def rebuild_weight(module, orig_weight: torch.Tensor) -> torch.Tensor:
    """Reconstruct the (unscaled) weight delta encoded by *module*.

    *module* is dispatched on its class name (LoraUpDownModule,
    LoraHadaModule, FullModule, IA3Module, LoraKronModule).  Every stored
    tensor is cast to ``orig_weight``'s device and dtype before use.  If the
    module carries a finished bias tensor it is added to the delta.  The
    result is reshaped to ``orig_weight.shape`` whenever the element counts
    agree.

    Args:
        module: per-layer module produced by load_lora.
        orig_weight: weight of the layer the delta is meant for.

    Returns:
        The weight delta tensor.  The module's stored tensors are not modified.

    Raises:
        NotImplementedError: if *module*'s class name is not one of the above.
    """
    device, dtype = orig_weight.device, orig_weight.dtype

    def cast(t):
        return t.to(device, dtype=dtype)

    kind = module.__class__.__name__
    if kind == 'LoraUpDownModule':
        up = cast(module.up_model.weight)
        down = cast(module.down_model.weight)
        output_shape = [up.size(0), down.size(1)]
        if (mid := module.mid_model) is not None:
            # cp-decomposition
            mid = cast(mid.weight)
            updown = _rebuild_cp_decomposition(up, down, mid)
            output_shape += mid.shape[2:]
        else:
            if len(down.shape) == 4:
                output_shape += down.shape[2:]
            updown = _rebuild_conventional(up, down, output_shape)
    elif kind == 'LoraHadaModule':
        w1a = cast(module.w1a)
        w1b = cast(module.w1b)
        w2a = cast(module.w2a)
        w2b = cast(module.w2b)
        output_shape = [w1a.size(0), w1b.size(1)]
        if module.t1 is not None:
            output_shape = [w1a.size(1), w1b.size(1)]
            t1 = cast(module.t1)
            updown1 = make_weight_cp(t1, w1a, w1b)
            output_shape += t1.shape[2:]
        else:
            if len(w1b.shape) == 4:
                output_shape += w1b.shape[2:]
            updown1 = _rebuild_conventional(w1a, w1b, output_shape)
        if module.t2 is not None:
            t2 = cast(module.t2)
            updown2 = make_weight_cp(t2, w2a, w2b)
        else:
            updown2 = _rebuild_conventional(w2a, w2b, output_shape)
        updown = updown1 * updown2
    elif kind == 'FullModule':
        output_shape = module.weight.shape
        updown = cast(module.weight)
    elif kind == 'IA3Module':
        # Work on a cast local copy: module.w is left untouched across calls and
        # follows orig_weight's device/dtype like every other branch.
        w = cast(module.w)
        output_shape = [w.size(0), orig_weight.size(1)]
        if module.on_input:
            output_shape.reverse()
        else:
            # scale rows (output channels) instead of columns
            w = w.reshape(-1, 1)
        updown = orig_weight * w
    elif kind == 'LoraKronModule':
        if module.w1 is not None:
            w1 = cast(module.w1)
        else:
            w1 = cast(module.w1a) @ cast(module.w1b)
        if module.w2 is not None:
            w2 = cast(module.w2)
        elif module.t2 is None:
            w2 = cast(module.w2a) @ cast(module.w2b)
        else:
            w2 = make_weight_cp(cast(module.t2), cast(module.w2a), cast(module.w2b))
        output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
        if len(orig_weight.shape) == 4:
            output_shape = orig_weight.shape
        updown = make_kron(output_shape, w1, w2)
    else:
        raise NotImplementedError(
            f"Unknown module type: {module.__class__.__name__}\n"
            "If the type is one of "
            "'LoraUpDownModule', 'LoraHadaModule', 'FullModule', 'IA3Module', 'LoraKronModule', "
            "you may have other lora extension that conflict with locon extension."
        )

    # load_lora keeps a [indices, values, size] list until all three bias parts
    # are loaded; only a finished tensor is applied.
    bias = getattr(module, 'bias', None)
    if isinstance(bias, torch.Tensor):
        # Out-of-place add: updown may alias a stored tensor (e.g. FullModule's
        # weight when no cast was needed).
        updown = updown.reshape(bias.shape) + cast(bias)
        updown = updown.reshape(output_shape)
    if orig_weight.size().numel() == updown.size().numel():
        updown = updown.reshape(orig_weight.shape)
    return updown
def lora_calc_updown(lora, module, target):
    """Return *module*'s weight delta for *target*, scaled for *lora*.

    The delta from rebuild_weight is multiplied by ``lora.multiplier`` and, when
    ``module.alpha`` is truthy, by ``alpha / module.up.weight.shape[1]``.
    Computed without autograd tracking.
    """
    with torch.no_grad():
        delta = rebuild_weight(module, target)
        alpha = module.alpha
        alpha_scale = alpha / module.up.weight.shape[1] if alpha else 1.0
        return delta * lora.multiplier * alpha_scale
# Replace the built-in Lora extension's entry points with the LoCon-aware ones.
for _attr, _impl in (
    ('convert_diffusers_name_to_compvis', convert_diffusers_name_to_compvis),
    ('load_lora', load_lora),
    ('lora_forward', lora_forward),
    ('lora_calc_updown', lora_calc_updown),
):
    setattr(lora, _attr, _impl)
del _attr, _impl
print('LoCon Extension hijack built-in lora successfully')