# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py#L48
import math
import os
import torch
import diffusers
import modules.safe as _  # imported only for its side effects
from safetensors.torch import load_file
from packaging import version


class LoRAModule(torch.nn.Module):
"""
    Replaces the forward method of the original Linear / 1x1 Conv2d module
    instead of replacing the module itself.
"""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Conv2d":
in_dim = org_module.in_channels
out_dim = org_module.out_channels
self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
        if isinstance(alpha, torch.Tensor):
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha
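        # LoRA scaling factor: the low-rank update is multiplied by alpha / rank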
self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # stored as a buffer so it can be treated as a constant
        # same initialization as Microsoft's LoRA reference implementation
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight)
self.multiplier = multiplier
        self.org_module = org_module  # reference is removed when apply() is called
self.enable = False
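
    # Rebuild the down/up projections with a new rank and update alpha/multiplier,
    # so weights trained with a different rank can be loaded into this module.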
def resize(self, rank, alpha, multiplier):
self.alpha = torch.tensor(alpha)
self.multiplier = multiplier
self.scale = alpha / rank
if self.lora_down.__class__.__name__ == "Conv2d":
in_dim = self.lora_down.in_channels
out_dim = self.lora_up.out_channels
self.lora_down = torch.nn.Conv2d(in_dim, rank, (1, 1), bias=False)
self.lora_up = torch.nn.Conv2d(rank, out_dim, (1, 1), bias=False)
else:
in_dim = self.lora_down.in_features
out_dim = self.lora_up.out_features
self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)
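
    # Patch the wrapped module: its forward is replaced by this LoRA forward,
    # and the reference to the original module is dropped afterwards.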
def apply(self):
if hasattr(self, "org_module"):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
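
    # Original output plus the low-rank update, scaled by multiplier * (alpha / rank).
    # When disabled, this falls through to the original forward unchanged.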
def forward(self, x):
if self.enable:
return (
self.org_forward(x)
+ self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
)
return self.org_forward(x)
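

# Builds one LoRAModule per target Linear / 1x1 Conv2d inside the text encoder
# and U-Net, patches them in, and handles loading checkpoint weights.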
class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
# create module instances
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules):
loras = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha,)
loras.append(lora)
return loras
if isinstance(text_encoder, list):
self.text_encoder_loras = text_encoder
else:
self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(f"Create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
        if version.parse(diffusers.__version__) >= version.parse("0.15.0"):  # proper version comparison (plain string compare is lexicographic)
LoRANetwork.UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
print(f"Create LoRA for U-Net: {len(self.unet_loras)} modules.")
self.weights_sd = None
        # check for duplicate names, patch each module in, and register it as a submodule
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert (lora.lora_name not in names), f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
lora.apply()
self.add_module(lora.lora_name, lora)
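
    # Disable every LoRA module; the patched forwards then just call the original layers.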
def reset(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enable = False
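
    # Load LoRA weights from a .safetensors or pickled checkpoint: infer rank and
    # alpha from the state dict, resize all modules to match, and enable only the
    # modules (text encoder and/or U-Net) that the file actually covers.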
def load(self, file, scale):
weights = None
if os.path.splitext(file)[1] == ".safetensors":
weights = load_file(file)
else:
weights = torch.load(file, map_location="cpu")
if not weights:
return
network_alpha = None
network_dim = None
for key, value in weights.items():
if network_alpha is None and "alpha" in key:
network_alpha = value
if network_dim is None and "lora_down" in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is None:
network_alpha = network_dim
weights_has_text_encoder = weights_has_unet = False
weights_to_modify = []
for key in weights.keys():
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
weights_has_text_encoder = True
if key.startswith(LoRANetwork.LORA_PREFIX_UNET):
weights_has_unet = True
if weights_has_text_encoder:
weights_to_modify += self.text_encoder_loras
if weights_has_unet:
weights_to_modify += self.unet_loras
for lora in self.text_encoder_loras + self.unet_loras:
lora.resize(network_dim, network_alpha, scale)
if lora in weights_to_modify:
lora.enable = True
        info = self.load_state_dict(weights, strict=False)
if len(info.unexpected_keys) > 0:
print(f"Weights are loaded. Unexpected keys={info.unexpected_keys}")