Spaces:
Running
on
T4
Running
on
T4
# LoRA network module | |
# reference: | |
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py | |
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py | |
# https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py#L48 | |
import math | |
import os | |
import torch | |
import modules.safe as _ | |
from safetensors.torch import load_file | |
class LoRAModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.

    Wraps a ``Linear`` (or 1x1 ``Conv2d``) with a low-rank adapter. Once ``apply()`` has
    been called and ``enable`` is True, the wrapped module computes
    ``org(x) + lora_up(lora_down(x)) * multiplier * scale`` with ``scale = alpha / rank``.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
    ):
        """Build the adapter for ``org_module``.

        Args:
            lora_name: name of this adapter (the network registers it under this name).
            org_module: module to wrap. A class named ``Conv2d`` gets 1x1 conv projections;
                anything else is treated as a Linear and must expose
                ``in_features``/``out_features``.
            multiplier: strength factor applied in ``forward``.
            lora_dim: rank of the low-rank decomposition.
            alpha: scaling numerator; if alpha == 0 or None, alpha is rank (no scaling).
                May be a tensor (e.g. read from a checkpoint).
        """
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        is_conv = org_module.__class__.__name__ == "Conv2d"
        if is_conv:
            in_dim, out_dim = org_module.in_channels, org_module.out_channels
        else:
            in_dim, out_dim = org_module.in_features, org_module.out_features
        self.lora_down, self.lora_up = self._new_layers(is_conv, in_dim, lora_dim, out_dim)

        alpha = self._normalize_alpha(alpha, lora_dim)
        self.scale = alpha / self.lora_dim
        # stored as a buffer: saved in state_dict but not a trainable parameter
        self.register_buffer("alpha", torch.tensor(alpha))

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        # the adapter only contributes to the output while this flag is True
        self.enable = False

    @staticmethod
    def _normalize_alpha(alpha, rank):
        """Return alpha as a plain number, substituting ``rank`` for None/0."""
        if isinstance(alpha, torch.Tensor):
            # without casting, bf16 causes error; .item() also works for CUDA tensors
            alpha = alpha.detach().float().item()
        return rank if alpha is None or alpha == 0 else alpha

    @staticmethod
    def _new_layers(is_conv, in_dim, rank, out_dim, **factory_kwargs):
        """Create a freshly initialised (down, up) projection pair of the given rank.

        ``factory_kwargs`` (``device``/``dtype``) are forwarded to the layer constructors.
        """
        if is_conv:
            down = torch.nn.Conv2d(in_dim, rank, (1, 1), bias=False, **factory_kwargs)
            up = torch.nn.Conv2d(rank, out_dim, (1, 1), bias=False, **factory_kwargs)
        else:
            down = torch.nn.Linear(in_dim, rank, bias=False, **factory_kwargs)
            up = torch.nn.Linear(rank, out_dim, bias=False, **factory_kwargs)
        # same as microsoft's: random down, zero up, so a new adapter adds nothing
        torch.nn.init.kaiming_uniform_(down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(up.weight)
        return down, up

    def resize(self, rank, alpha, multiplier):
        """Re-create the projections at a new rank and reset scaling.

        The new layers keep the device/dtype of the previous ones and start with a
        zero ``lora_up``. An adapter whose weights are not subsequently loaded
        therefore stays a no-op instead of adding random noise.

        Args:
            rank: new rank of the decomposition.
            alpha: new scaling numerator (number or tensor); None/0 means ``rank``.
            multiplier: new strength factor.
        """
        alpha = self._normalize_alpha(alpha, rank)
        self.alpha = torch.tensor(float(alpha), device=self.alpha.device)
        self.multiplier = multiplier
        self.scale = alpha / rank
        self.lora_dim = rank

        old_down, old_up = self.lora_down, self.lora_up
        factory_kwargs = {"device": old_down.weight.device, "dtype": old_down.weight.dtype}
        is_conv = old_down.__class__.__name__ == "Conv2d"
        if is_conv:
            in_dim, out_dim = old_down.in_channels, old_up.out_channels
        else:
            in_dim, out_dim = old_down.in_features, old_up.out_features
        self.lora_down, self.lora_up = self._new_layers(is_conv, in_dim, rank, out_dim, **factory_kwargs)

    def apply(self):
        """Hook this adapter into the wrapped module by replacing its ``forward``.

        Drops the reference to the wrapped module so it is no longer a child of this
        adapter. Calling it again is a no-op.
        """
        if hasattr(self, "org_module"):
            self.org_forward = self.org_module.forward
            self.org_module.forward = self.forward
            del self.org_module

    def forward(self, x):
        """Original output, plus the scaled low-rank delta when ``enable`` is set."""
        out = self.org_forward(x)
        if not self.enable:
            return out
        return out + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
class LoRANetwork(torch.nn.Module):
    """Container that attaches LoRAModule adapters to a text encoder and a U-Net.

    Adapters are created for every ``Linear`` and every 1x1 ``Conv2d`` found inside
    modules whose class name appears in the ``*_TARGET_REPLACE_MODULE`` lists. Each
    adapter is registered as a child under its ``lora_name``, so this network's
    state_dict keys have the form ``<lora_name>.lora_down.weight``,
    ``<lora_name>.lora_up.weight`` and ``<lora_name>.alpha``.
    """

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
        """Create adapters for ``text_encoder`` and ``unet`` and hook them in.

        Args:
            text_encoder: the text-encoder module, or a list of already-built adapters
                (presumably LoRAModule instances from another network; used as-is).
            unet: the U-Net module.
            multiplier, lora_dim, alpha: initial settings for every new adapter.
        """
        super().__init__()
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha

        # create module instances
        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules):
            loras = []
            seen_paths = set()  # full dotted paths already wrapped in this call
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue
                for child_name, child_module in module.named_modules():
                    cls_name = child_module.__class__.__name__
                    is_target = cls_name == "Linear" or (
                        cls_name == "Conv2d" and child_module.kernel_size == (1, 1)
                    )
                    if not is_target:
                        continue
                    path = prefix + "." + name + "." + child_name
                    # A target class nested inside another target class reaches the same
                    # child again under the identical dotted path. Wrap it only once
                    # rather than building a duplicate adapter.
                    if path in seen_paths:
                        continue
                    seen_paths.add(path)
                    lora_name = path.replace(".", "_")
                    loras.append(
                        LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
                    )
            return loras

        if isinstance(text_encoder, list):
            self.text_encoder_loras = text_encoder
        else:
            self.text_encoder_loras = create_modules(
                LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
                text_encoder,
                LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE,
            )
        print(f"Create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
        self.unet_loras = create_modules(
            LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE
        )
        print(f"Create LoRA for U-Net: {len(self.unet_loras)} modules.")
        self.weights_sd = None

        # assertion: names become state_dict prefixes, so they must be unique
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert (lora.lora_name not in names), f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)
            lora.apply()
            self.add_module(lora.lora_name, lora)

    def reset(self):
        """Disable every adapter; each wrapped module then returns only its original output."""
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.enable = False

    @staticmethod
    def _file_defaults(weights):
        """Return (rank, alpha) from the first matching keys in ``weights``.

        These are the file-wide fallbacks. Rank comes from the first 2-D ``lora_down``
        tensor, alpha from the first key containing "alpha". Either may be None.
        """
        file_dim = file_alpha = None
        for key, value in weights.items():
            if file_alpha is None and "alpha" in key:
                file_alpha = float(value)
            if file_dim is None and "lora_down" in key and len(value.size()) == 2:
                file_dim = value.size()[0]
        return file_dim, file_alpha

    @staticmethod
    def _module_rank_alpha(lora, weights, file_dim, file_alpha):
        """Choose the (rank, alpha) to resize ``lora`` to before loading ``weights``.

        The adapter's own entries in the file win, so files whose modules have
        different ranks load without size mismatches. Otherwise the file-wide values
        are used, then the adapter's current rank. Alpha falls back to the chosen rank.
        """
        down = weights.get(f"{lora.lora_name}.lora_down.weight")
        if down is not None:
            rank = down.shape[0]
        elif file_dim is not None:
            rank = file_dim
        else:
            rank = lora.lora_down.weight.shape[0]

        alpha = weights.get(f"{lora.lora_name}.alpha")
        if alpha is not None:
            alpha = float(alpha)
        elif file_alpha is not None:
            alpha = file_alpha
        else:
            alpha = rank
        return rank, alpha

    def load(self, file, scale):
        """Load LoRA weights from ``file`` and enable the adapters they target.

        Every adapter is re-created at the rank/alpha stored for it, with ``scale`` as
        its multiplier. If the file has any text-encoder keys, all text-encoder
        adapters are enabled; likewise for U-Net keys. Adapters are never disabled
        here (see ``reset``). Returns early if the file yields no weights.

        Args:
            file: path to a ``.safetensors`` file or a ``torch.load``-able checkpoint.
            scale: strength multiplier given to every adapter.
        """
        if os.path.splitext(file)[1] == ".safetensors":
            weights = load_file(file)
        else:
            # NOTE(review): torch.load unpickles arbitrary objects. Only load trusted
            # files. This module imports modules.safe, which presumably restricts the
            # unpickler — confirm.
            weights = torch.load(file, map_location="cpu")
        if not weights:
            return

        file_dim, file_alpha = self._file_defaults(weights)
        has_text_encoder = any(k.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER) for k in weights)
        has_unet = any(k.startswith(LoRANetwork.LORA_PREFIX_UNET) for k in weights)

        for loras, present in (
            (self.text_encoder_loras, has_text_encoder),
            (self.unet_loras, has_unet),
        ):
            for lora in loras:
                rank, alpha = self._module_rank_alpha(lora, weights, file_dim, file_alpha)
                lora.resize(rank, alpha, scale)
                if present:
                    lora.enable = True

        # strict=False: keys missing from the file leave the freshly created layers as-is
        info = self.load_state_dict(weights, False)
        if len(info.unexpected_keys) > 0:
            print(f"Weights are loaded. Unexpected keys={info.unexpected_keys}")