|
import logging |
|
import os |
|
import re |
|
|
|
import lora_patches |
|
import network |
|
import network_lora |
|
import network_hada |
|
import network_ia3 |
|
import network_lokr |
|
import network_full |
|
import network_norm |
|
|
|
import torch |
|
from typing import Union |
|
|
|
from modules import shared, devices, sd_models, errors, scripts, sd_hijack |
|
|
|
# Registry of network module types. load_network() offers each layer's weights
# to these in order and keeps the first module that is created, so the order
# of this list decides precedence between formats.
module_types = [
    module_type()
    for module_type in (
        network_lora.ModuleTypeLora,
        network_hada.ModuleTypeHada,
        network_ia3.ModuleTypeIa3,
        network_lokr.ModuleTypeLokr,
        network_full.ModuleTypeFull,
        network_norm.ModuleTypeNorm,
    )
]
|
|
|
|
|
# Matches decimal digits; a captured group that consists entirely of digits is
# converted to int by convert_diffusers_name_to_compvis.
re_digits = re.compile(r"\d+")

# Splits a layer name ending in "_q_proj" / "_k_proj" / "_v_proj" into
# (parent layer name, projection suffix).
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")

# Cache of compiled patterns used by convert_diffusers_name_to_compvis, keyed by pattern text.
re_compiled = {}

# Per block type, renames from diffusers layer names to CompVis layer names.
# Names that are not listed (including everything under "attentions") are kept as-is.
suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}


def convert_diffusers_name_to_compvis(key, is_sd2):
    """Translate a diffusers/kohya-style LoRA layer name into the flattened
    CompVis module name used by network_layer_mapping.

    Patterns are tried in a fixed order. The first that matches determines the
    result; if none matches, `key` is returned unchanged.

    Args:
        key: layer name without the weight-part suffix, e.g.
            "lora_unet_down_blocks_0_attentions_1_proj_in".
        is_sd2: when true, "lora_te_" text-encoder keys are mapped to the
            "model_transformer_resblocks_" naming instead of
            "transformer_text_model_encoder_layers_".

    Returns:
        The converted name, or `key` itself when no pattern applies.
    """
    def match(match_list, regex_text):
        # Compile each pattern once and reuse it across calls.
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = regex.match(key)
        if not r:
            return False

        match_list.clear()
        # Use fullmatch, not match: a group such as "1x" or "2_3" starts with digits
        # but is not a number. With match, int() would raise on the first and
        # silently misparse the second ("2_3" -> 23, since int() accepts "_").
        match_list.extend([int(x) if re_digits.fullmatch(x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key
|
|
|
|
|
def assign_network_names_to_compvis_modules(sd_model):
    """Build ``sd_model.network_layer_mapping``: flattened module name -> module.

    Every submodule of the text encoder(s) and of ``sd_model.model`` is
    registered under its ``named_modules()`` path with "." replaced by "_".
    The same name is stored on the module as ``network_layer_name``, which is
    how the patched layer functions find the network weights aimed at them.

    For SDXL models (``sd_model.is_sdxl``), each conditioner embedder that has
    a ``wrapped`` attribute is registered with an "<embedder index>_" prefix;
    embedders without one are skipped. Otherwise
    ``sd_model.cond_stage_model.wrapped`` is registered without a prefix.
    If two modules flatten to the same name, the one registered later wins.

    Args:
        sd_model: the model to index. Its own sub-models are read. The earlier
            version read ``shared.sd_model`` here and only wrote the result to
            ``sd_model``.
    """
    network_layer_mapping = {}

    def register(root, prefix=""):
        # Map every submodule of `root` (root itself included, under `prefix`).
        for name, module in root.named_modules():
            network_name = prefix + name.replace(".", "_")
            network_layer_mapping[network_name] = module
            module.network_layer_name = network_name

    if getattr(sd_model, "is_sdxl", False):
        for i, embedder in enumerate(sd_model.conditioner.embedders):
            if hasattr(embedder, "wrapped"):
                register(embedder.wrapped, f"{i}_")
    else:
        register(sd_model.cond_stage_model.wrapped)

    register(sd_model.model)

    sd_model.network_layer_mapping = network_layer_mapping
|
|
|
|
|
def load_network(name, network_on_disk):
    """Read a network file and build a network.Network from its weights.

    Each state-dict key is split at its first "." into a layer name and a
    weight-part name. The layer name is translated to an entry of
    shared.sd_model.network_layer_mapping (built on first use) by trying, in
    order:
      1. convert_diffusers_name_to_compvis();
      2. for a "..._q_proj"/"_k_proj"/"_v_proj" result, its parent layer;
      3. a "lora_unet" -> "diffusion_model" rename;
      4. for "lora_te1_text_model" keys, a rename to "0_transformer_text_model",
         then to "transformer_text_model".
    Keys that match no layer are skipped and logged at debug level. The weight
    parts are grouped per matched layer. Each group is passed to the entries of
    module_types in order, and the first non-None create_module() result is kept.

    Args:
        name: name given to the resulting network.
        network_on_disk: entry describing the file; its ``filename`` is read.

    Returns:
        network.Network with ``mtime`` and ``modules`` populated.

    Raises:
        AssertionError: if no module type accepts a layer's group of weight parts.
    """
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    sd = sd_models.read_state_dict(network_on_disk.filename)

    # The layer-name mapping is built lazily, when a network is first loaded.
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    # network key -> layer name last tried, for keys that matched no layer.
    keys_failed_to_match = {}
    # Heuristic: treat the model as SD2 when the mapping uses the
    # "model_transformer_resblocks" naming that convert_diffusers_name_to_compvis
    # emits for is_sd2.
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

    # layer key -> NetworkWeights collecting every weight part aimed at that layer.
    matched_networks = {}

    for key_network, weight in sd.items():
        # "<layer>.<part>". NOTE(review): a key without "." makes this unpacking raise ValueError.
        key_network_without_network_parts, network_part = key_network.split(".", 1)

        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
        sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            # Projection keys resolve to their parent layer, but `key` keeps the
            # "_q_proj"/"_k_proj"/"_v_proj" suffix so each projection is grouped separately.
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)

        # Fallback: keys that only differ from the mapping by the "lora_unet" prefix.
        if sd_module is None and "lora_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
            # SDXL embedders are registered with an "<index>_" prefix; "te1" presumably means embedder 0.
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

            # Otherwise try the unprefixed text-encoder naming.
            if sd_module is None:
                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
                sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        net_module = None
        # The first module type that accepts this layer's weight parts wins.
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                break

        if net_module is None:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

        net.modules[key] = net_module

    if keys_failed_to_match:
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")

    return net
|
|
|
|
|
def purge_networks_from_memory():
    """Shrink networks_in_memory to at most shared.opts.lora_in_memory_limit entries.

    Entries are evicted in insertion order, oldest first. devices.torch_gc()
    is called afterwards whether or not anything was evicted.
    """
    def over_limit():
        # The limit is re-read on every check, as before.
        count = len(networks_in_memory)
        return count > shared.opts.lora_in_memory_limit and count > 0

    while over_limit():
        oldest = next(iter(networks_in_memory))
        del networks_in_memory[oldest]

    devices.torch_gc()
|
|
|
|
|
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    """Replace loaded_networks with the networks named in `names`, in that order.

    Names are resolved through available_network_aliases. If any name is
    unknown, the directories are rescanned once with list_available_networks().
    A network is reused from the previous loaded_networks or from
    networks_in_memory unless its file's mtime is newer than the cached copy;
    otherwise it is loaded with load_network() and stored in
    networks_in_memory. Load errors are shown via errors.display and that name
    is skipped. Names with no file and no cached network are reported in
    sd_hijack.model_hijack.comments. Finally purge_networks_from_memory() runs.

    Args:
        names: network names or aliases, in application order.
        te_multipliers, unet_multipliers, dyn_dims: optional per-name values,
            indexed like `names`. When a sequence is omitted or empty, 1.0 is used.
    """
    # Currently active networks that are requested again, keyed by name.
    reusable = {net.name: net for net in loaded_networks if net.name in names}
    loaded_networks.clear()

    def resolve_on_disk():
        return [available_network_aliases.get(name, None) for name in names]

    def nth_or_one(values, index):
        return values[index] if values else 1.0

    networks_on_disk = resolve_on_disk()
    if any(entry is None for entry in networks_on_disk):
        # Something is unknown: rescan the directories and resolve again.
        list_available_networks()
        networks_on_disk = resolve_on_disk()

    missing = []

    for index, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = reusable.get(name, None)

        if network_on_disk is not None:
            if net is None:
                net = networks_in_memory.get(name)

            needs_reload = net is None or os.path.getmtime(network_on_disk.filename) > net.mtime
            if needs_reload:
                try:
                    net = load_network(name, network_on_disk)

                    # Re-insert so the fresh copy becomes the newest entry.
                    networks_in_memory.pop(name, None)
                    networks_in_memory[name] = net
                except Exception as e:
                    errors.display(e, f"loading network {network_on_disk.filename}")
                    continue

            net.mentioned_name = name
            network_on_disk.read_hash()

        if net is None:
            missing.append(name)
            logging.info(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = nth_or_one(te_multipliers, index)
        net.unet_multiplier = nth_or_one(unet_multipliers, index)
        net.dyn_dim = nth_or_one(dyn_dims, index)
        loaded_networks.append(net)

    if missing:
        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(missing))

    purge_networks_from_memory()
|
|
|
|
|
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """Copy the backed-up weight (and bias) back into layer `self`.

    Reads ``self.network_weights_backup`` and ``self.network_bias_backup``. If
    neither exists, nothing happens. For MultiheadAttention, the weight backup
    is the pair (in_proj_weight, out_proj.weight) and the bias backup is
    out_proj's bias. If there is no bias backup, the layer's bias (out_proj's
    for MultiheadAttention) is set to None, which removes any bias a network
    added.

    The copies run under torch.no_grad(): an in-place copy_ into a parameter
    that requires grad raises otherwise, and network_apply_weights calls this
    function outside its own no_grad block.
    """
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return

    is_mha = isinstance(self, torch.nn.MultiheadAttention)

    with torch.no_grad():
        if weights_backup is not None:
            if is_mha:
                self.in_proj_weight.copy_(weights_backup[0])
                self.out_proj.weight.copy_(weights_backup[1])
            else:
                self.weight.copy_(weights_backup)

        if bias_backup is not None:
            if is_mha:
                self.out_proj.bias.copy_(bias_backup)
            else:
                self.bias.copy_(bias_backup)
        else:
            # No bias was backed up (presumably the layer had none when backups
            # were taken), so drop any bias added since.
            if is_mha:
                self.out_proj.bias = None
            else:
                self.bias = None
|
|
|
|
|
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.

    The applied set is identified by (name, te_multiplier, unet_multiplier, dyn_dim)
    for each entry of loaded_networks and cached in self.network_current_names.
    The first time a non-empty set is wanted, CPU copies of the weight and bias are
    stored in self.network_weights_backup / self.network_bias_backup.

    Layers without a network_layer_name are left untouched. A RuntimeError raised
    while applying one network, or a network with no supported operation for this
    layer, is logged at debug level and counted in extra_network_lora.errors.

    Raises:
        RuntimeError: if there is no weight backup but the layer already reports
            networks as applied.
    """
    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None and wanted_names != ():
        if current_names != ():
            raise RuntimeError("no backup weights found and current weights are not unchanged")

        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.network_weights_backup = weights_backup

        # Back up the bias together with the weight, while the layer is known to be
        # unmodified (current_names == () was just checked). A None backup therefore
        # means "no original bias", and network_restore_weights_from_backup removes
        # any bias a network created. Previously the bias was re-examined on every
        # call while its backup was None. A bias created by one network set (via
        # ex_bias) could then be captured as the "original" and leak into the next set.
        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
        elif getattr(self, 'bias', None) is not None:
            bias_backup = self.bias.to(devices.cpu, copy=True)
        else:
            bias_backup = None

        self.network_bias_backup = bias_backup

    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                try:
                    with torch.no_grad():
                        updown, ex_bias = module.calc_updown(self.weight)

                        if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
                            # 9 input channels (presumably an inpainting model's first conv):
                            # zero-pad the delta by 5 channels along dim 1, assuming it only
                            # covers the first 4.
                            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                        self.weight += updown
                        if ex_bias is not None and hasattr(self, 'bias'):
                            if self.bias is None:
                                self.bias = torch.nn.Parameter(ex_bias)
                            else:
                                self.bias += ex_bias
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            # MultiheadAttention: networks target the q/k/v/out projections separately,
            # stored under the layer name plus a projection suffix.
            module_q = net.modules.get(network_layer_name + "_q_proj", None)
            module_k = net.modules.get(network_layer_name + "_k_proj", None)
            module_v = net.modules.get(network_layer_name + "_v_proj", None)
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                try:
                    with torch.no_grad():
                        updown_q, _ = module_q.calc_updown(self.in_proj_weight)
                        updown_k, _ = module_k.calc_updown(self.in_proj_weight)
                        updown_v, _ = module_v.calc_updown(self.in_proj_weight)
                        # in_proj_weight stacks q, k, v along dim 0, so stack the deltas the same way.
                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                        updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

                        self.in_proj_weight += updown_qkv
                        self.out_proj.weight += updown_out
                        # Bias update inside no_grad too, as in the single-weight path above.
                        if ex_bias is not None:
                            if self.out_proj.bias is None:
                                self.out_proj.bias = torch.nn.Parameter(ex_bias)
                            else:
                                self.out_proj.bias += ex_bias
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            if module is None:
                continue

            logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

        self.network_current_names = wanted_names
|
|
|
|
|
def network_forward(module, input, original_forward):
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in big performance degradation.

    The layer is first restored to its backed-up weights and its network cache
    is reset. original_forward then runs, and each loaded network that has a
    module for this layer's network_layer_name transforms the output in turn.
    With no loaded networks, original_forward is called directly.
    """
    if not loaded_networks:
        return original_forward(module, input)

    input = devices.cond_cast_unet(input)

    network_restore_weights_from_backup(module)
    network_reset_cached_weight(module)

    output = original_forward(module, input)

    layer_name = getattr(module, 'network_layer_name', None)
    for net in loaded_networks:
        # Distinct name so the `module` parameter is not shadowed.
        net_module = net.modules.get(layer_name, None)
        if net_module is not None:
            output = net_module.forward(input, output)

    return output
|
|
|
|
|
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """Forget which networks are applied to this layer and drop its backups.

    Called when a layer's state dict is reloaded and from the functional path.
    The bias backup is cleared along with the weight backup. Otherwise a stale
    bias from the previous weights would survive and later be copied back by
    network_restore_weights_from_backup.
    """
    self.network_current_names = ()
    self.network_weights_backup = None
    self.network_bias_backup = None
|
|
|
|
|
def network_Linear_forward(self, input):
    """Patched torch.nn.Linear forward that takes the loaded networks into account.

    With shared.opts.lora_functional set, network_forward applies them during
    the forward pass. Otherwise network_apply_weights patches the weights
    before the original forward runs.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return originals.Linear_forward(self, input)

    return network_forward(self, input, originals.Linear_forward)
|
|
|
|
|
def network_Linear_load_state_dict(self, *args, **kwargs):
    """Reset the layer's network cache, then delegate to the original Linear load_state_dict."""
    network_reset_cached_weight(self)
    load_state_dict = originals.Linear_load_state_dict
    return load_state_dict(self, *args, **kwargs)
|
|
|
|
|
def network_Conv2d_forward(self, input):
    """Patched torch.nn.Conv2d forward that takes the loaded networks into account.

    With shared.opts.lora_functional set, network_forward applies them during
    the forward pass. Otherwise network_apply_weights patches the weights
    before the original forward runs.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return originals.Conv2d_forward(self, input)

    return network_forward(self, input, originals.Conv2d_forward)
|
|
|
|
|
def network_Conv2d_load_state_dict(self, *args, **kwargs):
    """Reset the layer's network cache, then delegate to the original Conv2d load_state_dict."""
    network_reset_cached_weight(self)
    load_state_dict = originals.Conv2d_load_state_dict
    return load_state_dict(self, *args, **kwargs)
|
|
|
|
|
def network_GroupNorm_forward(self, input):
    """Patched torch.nn.GroupNorm forward that takes the loaded networks into account.

    With shared.opts.lora_functional set, network_forward applies them during
    the forward pass. Otherwise network_apply_weights patches the weights
    before the original forward runs.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return originals.GroupNorm_forward(self, input)

    return network_forward(self, input, originals.GroupNorm_forward)
|
|
|
|
|
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    """Reset the layer's network cache, then delegate to the original GroupNorm load_state_dict."""
    network_reset_cached_weight(self)
    load_state_dict = originals.GroupNorm_load_state_dict
    return load_state_dict(self, *args, **kwargs)
|
|
|
|
|
def network_LayerNorm_forward(self, input):
    """Patched torch.nn.LayerNorm forward that takes the loaded networks into account.

    With shared.opts.lora_functional set, network_forward applies them during
    the forward pass. Otherwise network_apply_weights patches the weights
    before the original forward runs.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return originals.LayerNorm_forward(self, input)

    return network_forward(self, input, originals.LayerNorm_forward)
|
|
|
|
|
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    """Reset the layer's network cache, then delegate to the original LayerNorm load_state_dict."""
    network_reset_cached_weight(self)
    load_state_dict = originals.LayerNorm_load_state_dict
    return load_state_dict(self, *args, **kwargs)
|
|
|
|
|
def network_MultiheadAttention_forward(self, *args, **kwargs):
    """Patched torch.nn.MultiheadAttention forward.

    There is no lora_functional branch for this layer type: the loaded networks
    are always merged into the weights by network_apply_weights first.
    """
    network_apply_weights(self)
    original_forward = originals.MultiheadAttention_forward
    return original_forward(self, *args, **kwargs)
|
|
|
|
|
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    """Reset the layer's network cache, then delegate to the original MultiheadAttention load_state_dict."""
    network_reset_cached_weight(self)
    load_state_dict = originals.MultiheadAttention_load_state_dict
    return load_state_dict(self, *args, **kwargs)
|
|
|
|
|
def list_available_networks():
    """Rescan the network directories and rebuild the module-level lookup tables.

    Clears available_networks, available_network_aliases,
    forbidden_network_aliases and available_network_hash_lookup. It then walks
    shared.cmd_opts.lora_dir (created if missing) and
    shared.cmd_opts.lyco_dir_backcompat for .pt/.ckpt/.safetensors files.
    Each file is registered in available_networks under its file stem, and in
    available_network_aliases under both the stem and ``entry.alias``. An alias
    that is already present in available_network_aliases is added, lower-cased,
    to forbidden_network_aliases. Files whose NetworkOnDisk entry raises
    OSError are reported via errors.report and skipped.
    """
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    # Aliases added to this dict below are lower-cased, so the built-in entries
    # need a lower-case form too. Callers presumably look aliases up lower-cased,
    # so "Addams" alone could never match. The original spelling is kept for any
    # case-sensitive caller.
    forbidden_network_aliases.update({"none": 1, "addams": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    extensions = [".pt", ".ckpt", ".safetensors"]
    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=extensions))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=extensions))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # e.g. missing file or permission problems
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        # An alias shared by two networks is ambiguous.
        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry
|
|
|
|
|
# Splits an AddNet model label "<name>(<hex hash>)" or "<name> (<hex hash>)"
# into the bare name (group 1). The name group must end in a non-space
# character. With the former greedy "(.*)", the group swallowed the whitespace
# before "(", so "\s*" never matched anything and "name (abcd)" gave "name "
# with a trailing space.
re_network_name = re.compile(r"(.*\S)\s*\([0-9a-fA-F]+\)")
|
|
|
|
|
def infotext_pasted(infotext, params):
    """Turn additional-networks ("AddNet") LoRA fields into <lora:...> prompt tags.

    Does nothing if scripts.scripts_txt2img.infotext_fields already contains an
    entry whose second element is "AddNet Module 1". Otherwise, for every
    "AddNet Model <n>" key whose "AddNet Module <n>" value is "LoRA", a tag
    "<lora:name:weight>" is appended to params["Prompt"] on a new line. The
    weight comes from "AddNet Weight A <n>" (default "1.0"), and the name is
    passed through re_network_name to strip a "(hash)" suffix.

    Args:
        infotext: unused.
        params: parsed infotext fields; modified in place.
    """
    if "AddNet Module 1" in [field[1] for field in scripts.scripts_txt2img.infotext_fields]:
        return

    model_prefix = "AddNet Model "
    tags = []

    for key in params:
        if not key.startswith(model_prefix):
            continue

        num = key[len(model_prefix):]
        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get(model_prefix + num)
        if name is None:
            continue

        found = re_network_name.match(name)
        if found:
            name = found.group(1)

        weight = params.get("AddNet Weight A " + num, "1.0")
        tags.append(f"<lora:{name}:{weight}>")

    if tags:
        params["Prompt"] += "\n" + "".join(tags)
|
|
|
|
|
# Original torch layer functions that the patched network_* functions above call
# through. Only declared here; assigned outside this module.
originals: lora_patches.LoraPatches = None

# Assigned outside this module; the apply functions count per-network failures
# in its `errors` dict.
extra_network_lora = None

# file stem -> NetworkOnDisk; rebuilt by list_available_networks().
available_networks = dict()
# file stem or alias -> NetworkOnDisk; rebuilt by list_available_networks().
available_network_aliases = dict()
# Networks activated by the last load_networks() call, in the order requested.
loaded_networks = list()
# name -> loaded network; purge_networks_from_memory() evicts the oldest insertions first.
networks_in_memory = dict()
# Cleared by list_available_networks(); filled elsewhere (presumably hash -> network).
available_network_hash_lookup = dict()
# Aliases that must not be used to refer to a network; rebuilt by list_available_networks().
forbidden_network_aliases = dict()

# Populate the tables once at import time.
list_available_networks()
|
|