| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| import os |
| import random |
| from typing import List, Tuple, Union |
| import torch |
| from torch import nn |
|
|
|
|
class DyLoRAModule(torch.nn.Module):
    """
    DyLoRA adapter that replaces the forward method of the original Linear/Conv2d,
    instead of replacing the original module itself.

    The low-rank factors are stored as ``lora_dim`` separate rank-1 parameters
    (``lora_A[i]`` / ``lora_B[i]``) so that ``forward`` can enable gradients for
    one randomly chosen block of ``unit`` consecutive ranks per call.
    ``state_dict`` / ``_load_from_state_dict`` convert between this per-rank
    layout and merged ``<lora_name>.lora_down.weight`` / ``<lora_name>.lora_up.weight``
    tensors.
    """

    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
        """
        Args:
            lora_name: name used to build the merged state-dict keys.
            org_module: the wrapped module; identified as Conv2d by class name,
                anything else is treated as Linear-like (``in_features`` /
                ``out_features``).
            multiplier: stored on the instance; not read by ``forward``.
            lora_dim: total rank; must be a multiple of ``unit``.
            alpha: scaling numerator; ``None`` or ``0`` falls back to ``lora_dim``.
                May be a tensor (e.g. read from a checkpoint).
            unit: number of consecutive ranks trained together per forward call.
        """
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.unit = unit
        assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"

        self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
        self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)

        if self.is_conv2d:
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        # isinstance (rather than an exact type comparison) also catches Tensor
        # subclasses such as nn.Parameter. float() avoids the bf16 -> numpy
        # conversion error, and cpu() makes the conversion work for tensors
        # living on another device.
        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().cpu().numpy()
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        if self.is_conv2d_3x3:
            kernel_size = org_module.kernel_size
            self.stride = org_module.stride
            self.padding = org_module.padding
            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
        else:
            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])

        # Down-projection gets a Kaiming-uniform init, up-projection starts at
        # zero, so the adapter initially adds nothing to the original output.
        for lora in self.lora_A:
            torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
        for lora in self.lora_B:
            torch.nn.init.zeros_(lora)

        self.multiplier = multiplier
        self.org_module = org_module  # removed again in apply_to

    def apply_to(self):
        """Hook this adapter into the wrapped module by swapping its ``forward``."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        # Drop the submodule reference so the wrapped module's weights do not
        # appear in this adapter's parameters / state dict.
        del self.org_module

    def forward(self, x):
        """Return the original module's output plus the scaled low-rank update."""
        result = self.org_forward(x)

        # Pick the dynamic rank for this call, rounded down to a multiple of unit.
        trainable_rank = random.randint(0, self.lora_dim - 1)
        trainable_rank = trainable_rank - trainable_rank % self.unit
        trainable_end = trainable_rank + self.unit

        # Only the ranks in [trainable_rank, trainable_end) receive gradients;
        # every other rank is frozen for this call.
        for i, (down, up) in enumerate(zip(self.lora_A, self.lora_B)):
            trainable = trainable_rank <= i < trainable_end
            down.requires_grad = trainable
            up.requires_grad = trainable

        lora_A = torch.cat(tuple(self.lora_A), dim=0)
        lora_B = torch.cat(tuple(self.lora_B), dim=1)

        if self.is_conv2d_3x3:
            ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
            ab = torch.nn.functional.conv2d(ab, lora_B)
        else:
            ab = x
            if self.is_conv2d:
                # Flatten the trailing dims and move channels last so the conv
                # can be evaluated as a linear layer:
                # (N, C, ...) -> (N, prod(...), C)
                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)

            ab = torch.nn.functional.linear(ab, lora_A)
            ab = torch.nn.functional.linear(ab, lora_B)

            if self.is_conv2d:
                # Back to channels-first with the input's trailing dims.
                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])

        # The sqrt factor grows as the selected rank block gets lower.
        result = result + ab * self.scale * math.sqrt(self.lora_dim / trainable_end)

        return result

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """
        Return the state dict with the per-rank parameters replaced by merged
        ``<lora_name>.lora_down.weight`` / ``<lora_name>.lora_up.weight`` tensors.

        The per-rank entries are looked up as ``<lora_name>.lora_A.<i>`` /
        ``<lora_name>.lora_B.<i>``, i.e. they are only removed when this module
        is serialized under the prefix ``<lora_name>.``.
        """
        sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

        # 1x1 convs are stored as 2-D matrices here; exported weights get the
        # two trailing singleton kernel dims back.
        needs_kernel_dims = self.is_conv2d and not self.is_conv2d_3x3

        lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
        if needs_kernel_dims:
            lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)

        lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
        if needs_kernel_dims:
            lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)

        sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
        sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()

        # Remove the per-rank entries emitted by the ParameterLists.
        i = 0
        while True:
            key_a = f"{self.lora_name}.lora_A.{i}"
            key_b = f"{self.lora_name}.lora_B.{i}"
            if key_a not in sd:
                break
            sd.pop(key_a)
            sd.pop(key_b)
            i += 1
        return sd

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """
        Split merged ``lora_down`` / ``lora_up`` weights into per-rank entries
        inside ``state_dict`` before delegating to the default loader.

        Raises:
            KeyError: if either merged weight is missing and ``strict`` is true.
                When not strict, the method returns without loading anything.
        """
        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)

        if lora_A_weight is None or lora_B_weight is None:
            if strict:
                raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
            else:
                return

        if self.is_conv2d and not self.is_conv2d_3x3:
            # Undo the trailing kernel dims added on export.
            lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
            lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)

        # Row i of the down weight and column i of the up weight form rank i.
        state_dict.update(
            {f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
        )
        state_dict.update(
            {f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
        )

        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
|
|
|
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
    """
    Build a fresh DyLoRANetwork for training.

    ``network_dim`` defaults to 4 and ``network_alpha`` to 1.0 when ``None``.
    Recognized keyword options: ``conv_dim`` (enables Conv2d adapters and must
    equal ``network_dim``), ``conv_alpha`` (parsed but not forwarded to the
    network) and ``unit`` (defaults to 1). ``vae`` is accepted but unused.
    """
    network_dim = 4 if network_dim is None else network_dim
    network_alpha = 1.0 if network_alpha is None else network_alpha

    requested_conv_dim = kwargs.get("conv_dim", None)
    requested_conv_alpha = kwargs.get("conv_alpha", None)
    requested_unit = kwargs.get("unit", None)

    use_conv = requested_conv_dim is not None
    if use_conv:
        assert int(requested_conv_dim) == network_dim, "conv_dim must be same as network_dim"
        # conv_alpha is only parsed (so a malformed value still raises here);
        # the resulting number is not passed on to the network.
        if requested_conv_alpha is not None:
            float(requested_conv_alpha)

    unit = 1 if requested_unit is None else int(requested_unit)

    return DyLoRANetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        apply_to_conv=use_conv,
        unit=unit,
        varbose=True,
    )
|
|
|
|
| |
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    """
    Build a DyLoRANetwork whose per-module rank and alpha are read from saved weights.

    Args:
        multiplier: passed through to the network.
        file: checkpoint path; only read when ``weights_sd`` is ``None``.
            ``.safetensors`` files are loaded with safetensors, anything else
            with ``torch.load``.
        vae, for_inference, **kwargs: accepted but unused here.
        text_encoder, unet: models passed through to the network.
        weights_sd: optional already-loaded state dict.

    Returns:
        ``(network, weights_sd)``.
    """
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            # NOTE: torch.load unpickles the file; only load checkpoints from
            # trusted sources.
            weights_sd = torch.load(file, map_location="cpu")

    # Collect rank and alpha per module; the module name is the key text before
    # the first dot.
    modules_dim = {}
    modules_alpha = {}
    for key, value in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lora_down" in key:
            # The first dimension of the down weight is the rank.
            modules_dim[lora_name] = value.size()[0]

    # Modules without a stored alpha fall back to alpha == rank. Fill in the
    # missing entry only; the dict itself must survive for later lookups.
    for key, dim in modules_dim.items():
        if key not in modules_alpha:
            modules_alpha[key] = dim

    module_class = DyLoRAModule

    network = DyLoRANetwork(
        text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
    )
    return network, weights_sd
|
|
|
|
class DyLoRANetwork(torch.nn.Module):
    """
    Collection of DyLoRA adapters for a text encoder and a U-Net.

    Adapters are created for every Linear / Conv2d found inside modules whose
    class name is listed in the ``*_TARGET_REPLACE_MODULE`` constants. They are
    hooked into the models and registered as submodules by ``apply_to``.
    """

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(
        self,
        text_encoder,
        unet,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        apply_to_conv=False,
        modules_dim=None,
        modules_alpha=None,
        unit=1,
        module_class=DyLoRAModule,
        varbose=False,
    ) -> None:
        """
        Args:
            text_encoder, unet: models to scan for target modules.
            multiplier: stored and passed to every adapter.
            lora_dim, alpha: rank and alpha used when ``modules_dim`` is None.
            apply_to_conv: also adapt Conv2d layers that are not 1x1 and scan
                the Conv2d 3x3 target modules of the U-Net.
            modules_dim, modules_alpha: per-adapter rank / alpha keyed by
                adapter name (as read from saved weights); when given, only
                the listed adapters are created.
            unit: passed to every adapter.
            module_class: adapter class to instantiate.
            varbose: accepted but unused (name kept for compatibility).
        """
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.apply_to_conv = apply_to_conv

        if modules_dim is not None:
            print(f"create LoRA network from weights")
        else:
            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
            if self.apply_to_conv:
                print(f"apply LoRA to Conv2d with kernel size (3,3).")

        def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
            # Create one adapter per Linear/Conv2d inside each target module.
            prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
            loras = []
            created_names = set()
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue

                for child_name, child_module in module.named_modules():
                    is_linear = child_module.__class__.__name__ == "Linear"
                    is_conv2d = child_module.__class__.__name__ == "Conv2d"
                    is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                    if not (is_linear or is_conv2d):
                        continue

                    lora_name = prefix + "." + name + "." + child_name
                    lora_name = lora_name.replace(".", "_")

                    # A target module nested inside another target module
                    # reaches the same layer twice under the same name; a
                    # second adapter would hook the layer again, so keep only
                    # the first one.
                    if lora_name in created_names:
                        continue

                    dim = None
                    alpha = None
                    if modules_dim is not None:
                        # Loading from weights: only adapters present in the file.
                        if lora_name in modules_dim:
                            dim = modules_dim[lora_name]
                            alpha = modules_alpha[lora_name]
                    elif is_linear or is_conv2d_1x1 or apply_to_conv:
                        dim = self.lora_dim
                        alpha = self.alpha

                    if dim is None or dim == 0:
                        continue

                    lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
                    loras.append(lora)
                    created_names.add(lora_name)
            return loras

        self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # Work on a copy: extending the class-level list in place would leak
        # the Conv2d 3x3 targets into every network created afterwards.
        target_modules = list(DyLoRANetwork.UNET_TARGET_REPLACE_MODULE)
        if modules_dim is not None or self.apply_to_conv:
            target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras = create_modules(True, unet, target_modules)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

    def set_multiplier(self, multiplier):
        """Store ``multiplier`` on the network and on every adapter."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        """Load weights from ``file`` non-strictly and return the ``load_state_dict`` result."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            # NOTE: torch.load unpickles the file; only load checkpoints from
            # trusted sources.
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        """
        Hook the adapters into their wrapped modules and register them as submodules.

        Adapters of a part that is not applied are discarded. ``text_encoder``
        and ``unet`` are accepted but unused.
        """
        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    # Disabled implementation of merge_to, kept as an inert string literal.
    """
    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
        apply_text_encoder = apply_unet = False
        for key in weights_sd.keys():
            if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER):
                apply_text_encoder = True
            elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET):
                apply_unet = True

        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(lora.lora_name):
                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
            lora.merge_to(sd_for_lora, dtype, device)

        print(f"weights are merged")
    """

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        """
        Return optimizer parameter groups, one per non-empty adapter list.

        A group only carries an ``"lr"`` entry when the corresponding learning
        rate is not ``None``. ``default_lr`` is accepted but unused.
        """
        self.requires_grad_(True)
        all_params = []

        def enumerate_params(loras):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params

        if self.text_encoder_loras:
            param_data = {"params": enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data["lr"] = text_encoder_lr
            all_params.append(param_data)

        if self.unet_loras:
            param_data = {"params": enumerate_params(self.unet_loras)}
            if unet_lr is not None:
                param_data["lr"] = unet_lr
            all_params.append(param_data)

        return all_params

    def enable_gradient_checkpointing(self):
        """No-op."""
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        """Enable gradients for all parameters of the network."""
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        """Switch the network to training mode."""
        self.train()

    def get_trainable_params(self):
        """Return an iterator over all parameters of the network."""
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """
        Save the state dict to ``file``.

        Tensors are moved to CPU and cast to ``dtype`` when it is given.
        ``.safetensors`` files additionally get model hashes written into the
        metadata; any other extension is saved with ``torch.save`` and the
        metadata is not written.
        """
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Hashes are computed from the tensors and stored in the metadata.
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def set_region(self, sub_prompt_index, is_last_network, mask):
        """No-op."""
        pass

    def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
        """No-op."""
        pass
|
|