Spaces:
Running
on
A100
Running
on
A100
Commit
•
ce92feb
1
Parent(s):
fc8ab35
Update lora.py
Browse files
lora.py
CHANGED
@@ -5,12 +5,16 @@
|
|
5 |
|
6 |
import math
|
7 |
import os
|
8 |
-
from typing import List, Tuple, Union
|
|
|
|
|
9 |
import numpy as np
|
10 |
import torch
|
11 |
import re
|
12 |
|
13 |
|
|
|
|
|
14 |
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
15 |
|
16 |
|
@@ -400,7 +404,16 @@ def parse_block_lr_kwargs(nw_kwargs):
|
|
400 |
return down_lr_weight, mid_lr_weight, up_lr_weight
|
401 |
|
402 |
|
403 |
-
def create_network(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
if network_dim is None:
|
405 |
network_dim = 4 # default
|
406 |
if network_alpha is None:
|
@@ -719,33 +732,36 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
|
719 |
class LoRANetwork(torch.nn.Module):
|
720 |
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
|
721 |
|
722 |
-
|
723 |
-
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
724 |
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
725 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
726 |
LORA_PREFIX_UNET = "lora_unet"
|
727 |
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
728 |
|
|
|
|
|
|
|
|
|
729 |
def __init__(
|
730 |
self,
|
731 |
-
text_encoder,
|
732 |
unet,
|
733 |
-
multiplier=1.0,
|
734 |
-
lora_dim=4,
|
735 |
-
alpha=1,
|
736 |
-
dropout=None,
|
737 |
-
rank_dropout=None,
|
738 |
-
module_dropout=None,
|
739 |
-
conv_lora_dim=None,
|
740 |
-
conv_alpha=None,
|
741 |
-
block_dims=None,
|
742 |
-
block_alphas=None,
|
743 |
-
conv_block_dims=None,
|
744 |
-
conv_block_alphas=None,
|
745 |
-
modules_dim=None,
|
746 |
-
modules_alpha=None,
|
747 |
-
module_class=LoRAModule,
|
748 |
-
varbose=False,
|
749 |
) -> None:
|
750 |
"""
|
751 |
LoRA network: すごく引数が多いが、パターンは以下の通り
|
@@ -783,8 +799,21 @@ class LoRANetwork(torch.nn.Module):
|
|
783 |
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
784 |
|
785 |
# create module instances
|
786 |
-
def create_modules(
|
787 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
788 |
loras = []
|
789 |
skipped = []
|
790 |
for name, module in root_module.named_modules():
|
@@ -800,11 +829,14 @@ class LoRANetwork(torch.nn.Module):
|
|
800 |
|
801 |
dim = None
|
802 |
alpha = None
|
|
|
803 |
if modules_dim is not None:
|
|
|
804 |
if lora_name in modules_dim:
|
805 |
dim = modules_dim[lora_name]
|
806 |
alpha = modules_alpha[lora_name]
|
807 |
elif is_unet and block_dims is not None:
|
|
|
808 |
block_idx = get_block_index(lora_name)
|
809 |
if is_linear or is_conv2d_1x1:
|
810 |
dim = block_dims[block_idx]
|
@@ -813,6 +845,7 @@ class LoRANetwork(torch.nn.Module):
|
|
813 |
dim = conv_block_dims[block_idx]
|
814 |
alpha = conv_block_alphas[block_idx]
|
815 |
else:
|
|
|
816 |
if is_linear or is_conv2d_1x1:
|
817 |
dim = self.lora_dim
|
818 |
alpha = self.alpha
|
@@ -821,6 +854,7 @@ class LoRANetwork(torch.nn.Module):
|
|
821 |
alpha = self.conv_alpha
|
822 |
|
823 |
if dim is None or dim == 0:
|
|
|
824 |
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
|
825 |
skipped.append(lora_name)
|
826 |
continue
|
@@ -838,7 +872,24 @@ class LoRANetwork(torch.nn.Module):
|
|
838 |
loras.append(lora)
|
839 |
return loras, skipped
|
840 |
|
841 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
842 |
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
843 |
|
844 |
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
@@ -846,7 +897,7 @@ class LoRANetwork(torch.nn.Module):
|
|
846 |
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
847 |
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
848 |
|
849 |
-
self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
|
850 |
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
851 |
|
852 |
skipped = skipped_te + skipped_un
|
@@ -880,7 +931,6 @@ class LoRANetwork(torch.nn.Module):
|
|
880 |
weights_sd = load_file(file)
|
881 |
else:
|
882 |
weights_sd = torch.load(file, map_location="cpu")
|
883 |
-
|
884 |
info = self.load_state_dict(weights_sd, False)
|
885 |
return info
|
886 |
|
@@ -961,6 +1011,7 @@ class LoRANetwork(torch.nn.Module):
|
|
961 |
|
962 |
return lr_weight
|
963 |
|
|
|
964 |
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
965 |
self.requires_grad_(True)
|
966 |
all_params = []
|
|
|
5 |
|
6 |
import math
|
7 |
import os
|
8 |
+
from typing import Dict, List, Optional, Tuple, Type, Union
|
9 |
+
from diffusers import AutoencoderKL
|
10 |
+
from transformers import CLIPTextModel
|
11 |
import numpy as np
|
12 |
import torch
|
13 |
import re
|
14 |
|
15 |
|
16 |
+
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
17 |
+
|
18 |
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
19 |
|
20 |
|
|
|
404 |
return down_lr_weight, mid_lr_weight, up_lr_weight
|
405 |
|
406 |
|
407 |
+
def create_network(
|
408 |
+
multiplier: float,
|
409 |
+
network_dim: Optional[int],
|
410 |
+
network_alpha: Optional[float],
|
411 |
+
vae: AutoencoderKL,
|
412 |
+
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
|
413 |
+
unet,
|
414 |
+
neuron_dropout: Optional[float] = None,
|
415 |
+
**kwargs,
|
416 |
+
):
|
417 |
if network_dim is None:
|
418 |
network_dim = 4 # default
|
419 |
if network_alpha is None:
|
|
|
732 |
class LoRANetwork(torch.nn.Module):
|
733 |
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
|
734 |
|
735 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
|
|
736 |
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
737 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
738 |
LORA_PREFIX_UNET = "lora_unet"
|
739 |
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
740 |
|
741 |
+
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
|
742 |
+
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
|
743 |
+
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
|
744 |
+
|
745 |
def __init__(
|
746 |
self,
|
747 |
+
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
748 |
unet,
|
749 |
+
multiplier: float = 1.0,
|
750 |
+
lora_dim: int = 4,
|
751 |
+
alpha: float = 1,
|
752 |
+
dropout: Optional[float] = None,
|
753 |
+
rank_dropout: Optional[float] = None,
|
754 |
+
module_dropout: Optional[float] = None,
|
755 |
+
conv_lora_dim: Optional[int] = None,
|
756 |
+
conv_alpha: Optional[float] = None,
|
757 |
+
block_dims: Optional[List[int]] = None,
|
758 |
+
block_alphas: Optional[List[float]] = None,
|
759 |
+
conv_block_dims: Optional[List[int]] = None,
|
760 |
+
conv_block_alphas: Optional[List[float]] = None,
|
761 |
+
modules_dim: Optional[Dict[str, int]] = None,
|
762 |
+
modules_alpha: Optional[Dict[str, int]] = None,
|
763 |
+
module_class: Type[object] = LoRAModule,
|
764 |
+
varbose: Optional[bool] = False,
|
765 |
) -> None:
|
766 |
"""
|
767 |
LoRA network: すごく引数が多いが、パターンは以下の通り
|
|
|
799 |
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
800 |
|
801 |
# create module instances
|
802 |
+
def create_modules(
|
803 |
+
is_unet: bool,
|
804 |
+
text_encoder_idx: Optional[int], # None, 1, 2
|
805 |
+
root_module: torch.nn.Module,
|
806 |
+
target_replace_modules: List[torch.nn.Module],
|
807 |
+
) -> List[LoRAModule]:
|
808 |
+
prefix = (
|
809 |
+
self.LORA_PREFIX_UNET
|
810 |
+
if is_unet
|
811 |
+
else (
|
812 |
+
self.LORA_PREFIX_TEXT_ENCODER
|
813 |
+
if text_encoder_idx is None
|
814 |
+
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
|
815 |
+
)
|
816 |
+
)
|
817 |
loras = []
|
818 |
skipped = []
|
819 |
for name, module in root_module.named_modules():
|
|
|
829 |
|
830 |
dim = None
|
831 |
alpha = None
|
832 |
+
|
833 |
if modules_dim is not None:
|
834 |
+
# モジュール指定あり
|
835 |
if lora_name in modules_dim:
|
836 |
dim = modules_dim[lora_name]
|
837 |
alpha = modules_alpha[lora_name]
|
838 |
elif is_unet and block_dims is not None:
|
839 |
+
# U-Netでblock_dims指定あり
|
840 |
block_idx = get_block_index(lora_name)
|
841 |
if is_linear or is_conv2d_1x1:
|
842 |
dim = block_dims[block_idx]
|
|
|
845 |
dim = conv_block_dims[block_idx]
|
846 |
alpha = conv_block_alphas[block_idx]
|
847 |
else:
|
848 |
+
# 通常、すべて対象とする
|
849 |
if is_linear or is_conv2d_1x1:
|
850 |
dim = self.lora_dim
|
851 |
alpha = self.alpha
|
|
|
854 |
alpha = self.conv_alpha
|
855 |
|
856 |
if dim is None or dim == 0:
|
857 |
+
# skipした情報を出力
|
858 |
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
|
859 |
skipped.append(lora_name)
|
860 |
continue
|
|
|
872 |
loras.append(lora)
|
873 |
return loras, skipped
|
874 |
|
875 |
+
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
876 |
+
print(text_encoders)
|
877 |
+
# create LoRA for text encoder
|
878 |
+
# 毎回すべてのモジュールを作るのは無駄なので要検討
|
879 |
+
self.text_encoder_loras = []
|
880 |
+
skipped_te = []
|
881 |
+
for i, text_encoder in enumerate(text_encoders):
|
882 |
+
if len(text_encoders) > 1:
|
883 |
+
index = i + 1
|
884 |
+
print(f"create LoRA for Text Encoder {index}:")
|
885 |
+
else:
|
886 |
+
index = None
|
887 |
+
print(f"create LoRA for Text Encoder:")
|
888 |
+
|
889 |
+
print(text_encoder)
|
890 |
+
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
891 |
+
self.text_encoder_loras.extend(text_encoder_loras)
|
892 |
+
skipped_te += skipped
|
893 |
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
894 |
|
895 |
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
|
|
897 |
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
898 |
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
899 |
|
900 |
+
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
901 |
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
902 |
|
903 |
skipped = skipped_te + skipped_un
|
|
|
931 |
weights_sd = load_file(file)
|
932 |
else:
|
933 |
weights_sd = torch.load(file, map_location="cpu")
|
|
|
934 |
info = self.load_state_dict(weights_sd, False)
|
935 |
return info
|
936 |
|
|
|
1011 |
|
1012 |
return lr_weight
|
1013 |
|
1014 |
+
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
1015 |
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
1016 |
self.requires_grad_(True)
|
1017 |
all_params = []
|