'''
Modified version for full net lora
(Lora for ResBlock and up/down sample block)
'''
import os
import re
import sys

import torch

from modules import shared, devices, sd_models
import lora
from locon_compvis import LoConModule, LoConNetworkCompvis, create_network_and_apply_compvis


try:
    # Hijack Additional Network extension.
    #
    # The addnet hijack is deliberately skipped because it does not support
    # the new version: raising here jumps straight to the fallback message
    # below. The statements after the raise are kept so the hijack can be
    # re-enabled by removing this single line.
    raise RuntimeError('addnet hijack is disabled')
    now_dir = os.path.dirname(os.path.abspath(__file__))
    addnet_path = os.path.join(now_dir, '..', '..', 'sd-webui-additional-networks/scripts')
    sys.path.append(addnet_path)
    import lora_compvis
    import scripts
    scripts.lora_compvis = lora_compvis
    scripts.lora_compvis.LoRAModule = LoConModule
    scripts.lora_compvis.LoRANetworkCompvis = LoConNetworkCompvis
    scripts.lora_compvis.create_network_and_apply_compvis = create_network_and_apply_compvis
    print('LoCon Extension hijack addnet extension successfully')
except Exception:
    print('Additional Network extension not installed, Only hijack built-in lora')


# --- Hijack sd-webui LoRA ---------------------------------------------------

re_digits = re.compile(r"\d+")

re_unet_conv_in = re.compile(r"lora_unet_conv_in(.+)")
re_unet_conv_out = re.compile(r"lora_unet_conv_out(.+)")
re_unet_time_embed = re.compile(r"lora_unet_time_embedding_linear_(\d+)(.+)")

re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")

re_unet_down_blocks_res = re.compile(r"lora_unet_down_blocks_(\d+)_resnets_(\d+)_(.+)")
re_unet_mid_blocks_res = re.compile(r"lora_unet_mid_block_resnets_(\d+)_(.+)")
re_unet_up_blocks_res = re.compile(r"lora_unet_up_blocks_(\d+)_resnets_(\d+)_(.+)")

re_unet_downsample = re.compile(r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv(.+)")
re_unet_upsample = re.compile(r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv(.+)")

re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")

# (diffusers resnet sub-layer prefix, compvis replacement) pairs shared by the
# down / mid / up resnet conversions.
_RESNET_LAYER_MAP = (
    ('conv1', 'in_layers_2'),
    ('conv2', 'out_layers_3'),
    ('time_emb_proj', 'emb_layers_1'),
    ('conv_shortcut', 'skip_connection'),
)


def _convert_resnet_suffix(block, suffix):
    """Return ``block`` + the compvis name for a resnet sub-layer suffix.

    Returns None when ``suffix`` starts with none of the known diffusers
    sub-layer prefixes, so the caller can fall through to its other rules.
    """
    for diffusers_prefix, compvis_name in _RESNET_LAYER_MAP:
        if suffix.startswith(diffusers_prefix):
            return f"{block}{compvis_name}{suffix[len(diffusers_prefix):]}"
    return None


def convert_diffusers_name_to_compvis(key):
    """Translate a diffusers-style LoRA key into the compvis-style layer key.

    The key is tested against the module-level ``re_*`` patterns in order and
    the first rule that produces a name wins. A key that matches no rule (or
    matches a resnet rule with an unknown sub-layer) is returned unchanged.
    """
    def match(match_list, regex):
        # On success, replace the contents of match_list with the captured
        # groups (pure-digit groups converted to int) and return True.
        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        # fullmatch, not match: a group that merely *starts* with a digit
        # (e.g. "2abc") must stay a string instead of crashing int().
        match_list.extend(
            [int(x) if re_digits.fullmatch(x) else x for x in r.groups()]
        )
        return True

    m = []

    if match(m, re_unet_conv_in):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, re_unet_conv_out):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, re_unet_time_embed):
        return f"diffusion_model_time_embed_{m[0]*2-2}{m[1]}"

    if match(m, re_unet_down_blocks):
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"

    if match(m, re_unet_mid_blocks):
        return f"diffusion_model_middle_block_1_{m[1]}"

    if match(m, re_unet_up_blocks):
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"

    if match(m, re_unet_down_blocks_res):
        block = f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_0_"
        converted = _convert_resnet_suffix(block, m[2])
        if converted is not None:
            return converted

    if match(m, re_unet_mid_blocks_res):
        block = f"diffusion_model_middle_block_{m[0]*2}_"
        converted = _convert_resnet_suffix(block, m[1])
        if converted is not None:
            return converted

    if match(m, re_unet_up_blocks_res):
        block = f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_0_"
        converted = _convert_resnet_suffix(block, m[2])
        if converted is not None:
            return converted

    if match(m, re_unet_downsample):
        return f"diffusion_model_input_blocks_{m[0]*3+3}_0_op{m[1]}"

    if match(m, re_unet_upsample):
        # The second index is 1 for up block 0 and 2 for every other block
        # (the bool is added as 0/1).
        return f"diffusion_model_output_blocks_{m[0]*3 + 2}_{1+(m[0]!=0)}_conv{m[1]}"

    if match(m, re_text_block):
        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    return key


class LoraOnDisk:
    """Name / filename pair describing a LoRA file."""

    def __init__(self, name, filename):
        self.name = name
        self.filename = filename


class LoraModule:
    """A loaded LoRA: per-layer modules keyed by layer name, plus metadata."""

    def __init__(self, name):
        self.name = name
        self.multiplier = 1.0
        self.modules = {}   # layer key -> FullModule / LoraUpDownModule / LoraHadaModule
        self.mtime = None   # set by load_lora from os.path.getmtime


class FakeModule(torch.nn.Module):
    """Module exposing a ``weight`` and delegating ``forward`` to ``func``.

    lora_forward reads ``up.weight.size(1)`` to compute the scale, so the
    weight stored here determines that divisor.
    """

    def __init__(self, weight, func):
        super().__init__()
        self.weight = weight
        self.func = func

    def forward(self, x):
        return self.func(x)


class FullModule:
    """Layer entry holding a full weight (loaded from a ``diff`` key)."""

    def __init__(self):
        self.weight = None
        self.alpha = None
        self.op = None          # torch.nn.functional.linear or conv2d
        self.extra_args = {}    # extra keyword arguments passed to op
        self.shape = None
        self.up = None

    def down(self, x):
        # Identity: all the work happens in inference().
        return x

    def inference(self, x):
        return self.op(x, self.weight, **self.extra_args)


class LoraUpDownModule:
    """Layer entry built from ``lora_down`` / optional ``lora_mid`` / ``lora_up``."""

    def __init__(self):
        self.up_model = None
        self.mid_model = None
        self.down_model = None
        self.alpha = None
        self.dim = None
        self.op = None
        self.extra_args = {}
        self.shape = None
        # None, a 3-slot list while bias parts are being collected by
        # load_lora, or the assembled sparse tensor.
        self.bias = None
        self.up = None

    def down(self, x):
        # Identity: all the work happens in inference().
        return x

    def inference(self, x):
        """Apply the layer to ``x``.

        With an assembled bias tensor, the weight is rebuilt as
        ``up @ down + bias`` (both factors flattened to 2-D), reshaped to
        ``self.shape`` and applied through ``self.op``. Otherwise the
        down / (mid) / up modules are simply chained.
        """
        if isinstance(self.bias, torch.Tensor):
            out_dim = self.up_model.weight.size(0)
            rank = self.down_model.weight.size(0)
            rebuild_weight = (
                self.up_model.weight.reshape(out_dim, -1)
                @ self.down_model.weight.reshape(rank, -1)
                + self.bias
            ).reshape(self.shape)
            return self.op(x, rebuild_weight, **self.extra_args)

        hidden = self.down_model(x)
        if self.mid_model is not None:
            hidden = self.mid_model(hidden)
        return self.up_model(hidden)


def pro3(t, wa, wb):
    """Contract the 4-D tensor ``t`` with two matrices.

    First contracts axis 1 of ``t`` with axis 0 of ``wb``, then axis 0 of the
    result with axis 0 of ``wa`` (see the einsum equations).
    """
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)


class LoraHadaModule:
    """Layer entry built from ``hada_*`` keys (product of two factor pairs)."""

    def __init__(self):
        self.t1 = None
        self.w1a = None
        self.w1b = None
        self.t2 = None
        self.w2a = None
        self.w2b = None
        self.alpha = None
        self.dim = None
        self.op = None
        self.extra_args = {}
        self.shape = None
        self.bias = None
        self.up = None

    def down(self, x):
        # Identity: all the work happens in inference().
        return x

    def inference(self, x):
        """Rebuild the weight from the factors and apply ``self.op`` to ``x``.

        Without ``t1`` the weight is ``(w1a @ w1b) * (w2a @ w2b)``; with it,
        each pair is first combined with its tensor through pro3. A tensor
        bias is added before the result is viewed as ``self.shape``.
        """
        bias = self.bias if isinstance(self.bias, torch.Tensor) else 0

        if self.t1 is None:
            rebuilt = (self.w1a @ self.w1b) * (self.w2a @ self.w2b) + bias
        else:
            rebuilt = (
                pro3(self.t1, self.w1a, self.w1b)
                * pro3(self.t2, self.w2a, self.w2b)
                + bias
            )
        return self.op(x, rebuilt.view(self.shape), **self.extra_args)


CON_KEY = {
    "lora_up.weight",
    "lora_down.weight",
    "lora_mid.weight"
}
HADA_KEY = {
    "hada_t1",
    "hada_w1_a",
    "hada_w1_b",
    "hada_t2",
    "hada_w2_a",
    "hada_w2_b",
}


def load_lora(name, filename):
    """Read a LoRA state dict from ``filename`` and build a LoraModule.

    Every key is converted with convert_diffusers_name_to_compvis and split at
    the first "." into a layer key and a per-layer key. Layer keys missing
    from ``shared.sd_model.lora_layer_mapping`` are collected and reported at
    the end instead of aborting the load.

    Raises:
        AssertionError: if a layer has a type other than Linear / Conv2d, or
            if the per-layer key is not one of the recognised names.
    """
    # Named lora_obj so the imported ``lora`` module is not shadowed.
    lora_obj = LoraModule(name)
    lora_obj.mtime = os.path.getmtime(filename)

    sd = sd_models.read_state_dict(filename)

    keys_failed_to_match = []

    for key_diffusers, weight in sd.items():
        fullkey = convert_diffusers_name_to_compvis(key_diffusers)
        key, lora_key = fullkey.split(".", 1)

        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
        if sd_module is None:
            keys_failed_to_match.append(key_diffusers)
            continue

        lora_module = lora_obj.modules.get(key, None)
        if lora_module is None:
            # Default container; swapped for FullModule / LoraHadaModule below
            # when the corresponding keys show up.
            lora_module = LoraUpDownModule()
            lora_obj.modules[key] = lora_module

        if lora_key == "alpha":
            lora_module.alpha = weight.item()
            continue

        if lora_key == "diff":
            weight = weight.to(device=devices.device, dtype=devices.dtype)
            weight.requires_grad_(False)
            lora_module = FullModule()
            lora_obj.modules[key] = lora_module
            lora_module.weight = weight
            # alpha equals up.weight.size(1), so the alpha ratio computed in
            # lora_forward is 1.
            lora_module.alpha = weight.size(1)
            lora_module.up = FakeModule(
                weight,
                lora_module.inference
            )
            lora_module.up.to(device=devices.device, dtype=devices.dtype)
            if len(weight.shape) == 2:
                lora_module.op = torch.nn.functional.linear
                lora_module.extra_args = {
                    'bias': None
                }
            else:
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding,
                    'bias': None
                }
            continue

        if 'bias_' in lora_key:
            # The bias arrives as three separate entries; collect them and
            # assemble the sparse tensor once all three are present.
            if lora_module.bias is None:
                lora_module.bias = [None, None, None]
            if 'bias_indices' == lora_key:
                lora_module.bias[0] = weight
            elif 'bias_values' == lora_key:
                lora_module.bias[1] = weight
            elif 'bias_size' == lora_key:
                lora_module.bias[2] = weight

            if all((i is not None) for i in lora_module.bias):
                print('build bias')
                lora_module.bias = torch.sparse_coo_tensor(
                    lora_module.bias[0],
                    lora_module.bias[1],
                    tuple(lora_module.bias[2]),
                ).to(device=devices.device, dtype=devices.dtype)
                lora_module.bias.requires_grad_(False)
            continue

        if lora_key in CON_KEY:
            if type(sd_module) is torch.nn.Linear:
                weight = weight.reshape(weight.shape[0], -1)
                module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
                lora_module.op = torch.nn.functional.linear
            elif type(sd_module) is torch.nn.Conv2d:
                if lora_key == "lora_down.weight":
                    if weight.shape[2] != 1 or weight.shape[3] != 1:
                        # Non-1x1 down weight: reuse the target layer's
                        # kernel size, stride and padding.
                        module = torch.nn.Conv2d(
                            weight.shape[1], weight.shape[0],
                            sd_module.kernel_size, sd_module.stride, sd_module.padding,
                            bias=False,
                        )
                    else:
                        module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
                elif lora_key == "lora_mid.weight":
                    module = torch.nn.Conv2d(
                        weight.shape[1], weight.shape[0],
                        sd_module.kernel_size, sd_module.stride, sd_module.padding,
                        bias=False,
                    )
                elif lora_key == "lora_up.weight":
                    module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding
                }
            else:
                # Explicit raise (not `assert`) so the check survives `python -O`.
                raise AssertionError(
                    f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
                )

            lora_module.shape = sd_module.weight.shape
            with torch.no_grad():
                module.weight.copy_(weight)

            module.to(device=devices.device, dtype=devices.dtype)
            module.requires_grad_(False)

            if lora_key == "lora_up.weight":
                lora_module.up_model = module
                lora_module.up = FakeModule(
                    lora_module.up_model.weight,
                    lora_module.inference
                )
            elif lora_key == "lora_mid.weight":
                lora_module.mid_model = module
            elif lora_key == "lora_down.weight":
                lora_module.down_model = module
                lora_module.dim = weight.shape[0]
        elif lora_key in HADA_KEY:
            if type(lora_module) is not LoraHadaModule:
                # Replace the default container, carrying over what earlier
                # keys already stored on it.
                alpha = lora_module.alpha
                bias = lora_module.bias
                lora_module = LoraHadaModule()
                lora_module.alpha = alpha
                lora_module.bias = bias
                lora_obj.modules[key] = lora_module
            lora_module.shape = sd_module.weight.shape

            weight = weight.to(device=devices.device, dtype=devices.dtype)
            weight.requires_grad_(False)

            if lora_key == 'hada_w1_a':
                lora_module.w1a = weight
                # Only install w1a as the `up` weight if hada_t1 has not
                # already provided one.
                if lora_module.up is None:
                    lora_module.up = FakeModule(
                        lora_module.w1a,
                        lora_module.inference
                    )
            elif lora_key == 'hada_w1_b':
                lora_module.w1b = weight
                lora_module.dim = weight.shape[0]
            elif lora_key == 'hada_w2_a':
                lora_module.w2a = weight
            elif lora_key == 'hada_w2_b':
                lora_module.w2b = weight
            elif lora_key == 'hada_t1':
                lora_module.t1 = weight
                # t1 always takes over as the `up` weight.
                lora_module.up = FakeModule(
                    lora_module.t1,
                    lora_module.inference
                )
            elif lora_key == 'hada_t2':
                lora_module.t2 = weight

            if type(sd_module) is torch.nn.Linear:
                lora_module.op = torch.nn.functional.linear
            elif type(sd_module) is torch.nn.Conv2d:
                lora_module.op = torch.nn.functional.conv2d
                lora_module.extra_args = {
                    'stride': sd_module.stride,
                    'padding': sd_module.padding
                }
            else:
                raise AssertionError(
                    f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
                )
        else:
            raise AssertionError(
                f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
            )

    if len(keys_failed_to_match) > 0:
        print(shared.sd_model.lora_layer_mapping)
        print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")

    return lora_obj


def lora_forward(module, input, res):
    """Add the contribution of every loaded LoRA to a layer's output.

    Args:
        module: the layer being run; its ``lora_layer_name`` attribute (if
            any) selects the matching entry in each loaded LoRA.
        input: the layer's input.
        res: the layer's original output.

    Returns:
        ``res`` plus ``contribution * scale`` for each loaded LoRA that has an
        entry for this layer and a non-zero multiplier.

    Raises:
        NotImplementedError: if an entry has neither ``inference`` nor ``up``.
    """
    if len(lora.loaded_loras) == 0:
        return res

    lora_layer_name = getattr(module, 'lora_layer_name', None)
    for lora_m in lora.loaded_loras:
        # Separate name so the `module` parameter is not overwritten.
        lora_layer = lora_m.modules.get(lora_layer_name, None)
        if lora_layer is None or not lora_m.multiplier:
            continue

        # scale = multiplier * alpha / rank, where the rank is read from the
        # `up` weight when available; a missing/zero alpha means ratio 1.0.
        if hasattr(lora_layer, 'up'):
            scale = lora_m.multiplier * (
                lora_layer.alpha / lora_layer.up.weight.size(1) if lora_layer.alpha else 1.0
            )
        else:
            scale = lora_m.multiplier * (
                lora_layer.alpha / lora_layer.dim if lora_layer.alpha else 1.0
            )

        if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
            x = res
        else:
            x = input

        if hasattr(lora_layer, 'inference'):
            res = res + lora_layer.inference(x) * scale
        elif hasattr(lora_layer, 'up'):
            res = res + lora_layer.up(lora_layer.down(x)) * scale
        else:
            raise NotImplementedError(
                "Your settings, extensions or models are not compatible with each other."
            )
    return res


# Replace the built-in implementations with the LoCon-aware versions above.
lora.convert_diffusers_name_to_compvis = convert_diffusers_name_to_compvis
lora.load_lora = load_lora
lora.lora_forward = lora_forward
print('LoCon Extension hijack built-in lora successfully')