File size: 7,216 Bytes
1e66485
 
 
 
 
 
 
 
 
4508ef4
1e66485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebb4814
1e66485
ebb4814
1e66485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4508ef4
 
 
 
1e66485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebb4814
1e66485
 
 
 
ebb4814
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py#L48

import math
import os
import torch
import diffusers
import modules.safe as _
from safetensors.torch import load_file


class LoRAModule(torch.nn.Module):
    """Low-rank adapter attached to a ``Linear`` or ``Conv2d`` module.

    Replaces the forward method of the original module instead of replacing the
    module itself: after ``apply`` has run, calling the original module returns its
    own output plus ``lora_up(lora_down(x)) * multiplier * scale`` while ``enable``
    is true, and just its own output otherwise.
    """

    def __init__(
            self,
            lora_name,
            org_module: torch.nn.Module,
            multiplier=1.0,
            lora_dim=4,
            alpha=1,
    ):
        """Build the down/up projections for *org_module*.

        Args:
            lora_name: unique name of this adapter (used as its module/key name).
            org_module: module whose forward gets hooked. A ``Conv2d`` gets 1x1
                convolutions; anything else must expose ``in_features`` /
                ``out_features`` and gets ``Linear`` projections.
            multiplier: strength applied in ``forward``.
            lora_dim: rank of the adapter.
            alpha: scaling numerator (``scale = alpha / lora_dim``). If alpha is 0
                or None, alpha is the rank (no scaling).
        """
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
            self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error

        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's: random down, zero up, so the adapter starts as a no-op
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # removed by apply()
        self.enable = False

    def resize(self, rank, alpha, multiplier):
        """Rebuild the projections for a new *rank* and reset the adapter to a no-op.

        The new layers keep the device and dtype of the layers they replace, and
        ``lora_up`` is zero-initialised (as in ``__init__``), so an adapter whose
        weights are never loaded contributes nothing to the output.

        Args:
            rank: new inner dimension of the adapter.
            alpha: scaling numerator (``scale = alpha / rank``). A tensor is converted
                to a Python float; 0 or None means "use rank" (no scaling), matching
                ``__init__``.
            multiplier: strength applied in ``forward``.
        """
        if isinstance(alpha, torch.Tensor):
            # low-precision (e.g. bf16) tensors would otherwise leak into self.scale
            alpha = alpha.detach().float().item()
        alpha = rank if alpha is None or alpha == 0 else alpha

        self.lora_dim = rank
        self.multiplier = multiplier
        self.scale = alpha / rank
        # float buffer so a fractional alpha loaded later via load_state_dict is not truncated
        self.alpha = torch.tensor(float(alpha), device=self.alpha.device)

        old_weight = self.lora_down.weight
        if self.lora_down.__class__.__name__ == "Conv2d":
            in_dim = self.lora_down.in_channels
            out_dim = self.lora_up.out_channels
            lora_down = torch.nn.Conv2d(in_dim, rank, (1, 1), bias=False)
            lora_up = torch.nn.Conv2d(rank, out_dim, (1, 1), bias=False)
        else:
            in_dim = self.lora_down.in_features
            out_dim = self.lora_up.out_features
            lora_down = torch.nn.Linear(in_dim, rank, bias=False)
            lora_up = torch.nn.Linear(rank, out_dim, bias=False)

        # Build with the default initialisation, then move to where the old layers lived.
        self.lora_down = lora_down.to(device=old_weight.device, dtype=old_weight.dtype)
        self.lora_up = lora_up.to(device=old_weight.device, dtype=old_weight.dtype)
        torch.nn.init.zeros_(self.lora_up.weight)

    def apply(self):
        """Hook this adapter into the original module; later calls do nothing.

        Saves the original module's bound ``forward`` as ``org_forward``, replaces it
        with ``self.forward`` and deletes ``self.org_module`` so the original module is
        no longer registered as a submodule of this adapter.
        """
        if hasattr(self, "org_module"):
            self.org_forward = self.org_module.forward
            self.org_module.forward = self.forward
            del self.org_module

    def forward(self, x):
        """Return the original output, plus the scaled low-rank delta when enabled."""
        out = self.org_forward(x)
        if not self.enable:
            return out
        return out + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale


class LoRANetwork(torch.nn.Module):
    """Set of LoRAModule adapters hooked into a text encoder and a U-Net.

    Construction wraps every ``Linear`` (and 1x1 ``Conv2d``) found inside the target
    module types with a LoRAModule and applies it. The adapters start disabled;
    ``load`` resizes them to match a weight file, loads the weights and enables
    exactly the adapters that received weights.
    """

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    @staticmethod
    def _version_at_least(version, minimum):
        """Return True if the dotted *version* string is >= the int tuple *minimum*.

        Only the leading numeric components are compared (as many as *minimum* has),
        so suffixes such as ``".dev0"`` or ``"rc1"`` after them are ignored.
        """
        import re

        parts = tuple(int(p) for p in re.findall(r"\d+", version)[: len(minimum)])
        return parts >= tuple(minimum)

    def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
        """Create one LoRAModule per targeted layer and hook them in.

        Args:
            text_encoder: text encoder root module, or an already-built list of
                LoRAModule objects to reuse as the text-encoder adapters.
            unet: U-Net root module.
            multiplier: initial strength for every adapter.
            lora_dim: initial rank for every adapter.
            alpha: initial alpha for every adapter.
        """
        super().__init__()
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha

        # create module instances
        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules):
            loras = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")
                            lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha,)
                            loras.append(lora)
            return loras

        if isinstance(text_encoder, list):
            self.text_encoder_loras = text_encoder
        else:
            self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
            print(f"Create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # Numeric comparison: as strings, "0.9.0" >= "0.15.0" would be True.
        if LoRANetwork._version_at_least(diffusers.__version__, (0, 15)):
            LoRANetwork.UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]

        self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
        print(f"Create LoRA for U-Net: {len(self.unet_loras)} modules.")

        self.weights_sd = None

        # assertion
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert (lora.lora_name not in names), f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

            lora.apply()
            self.add_module(lora.lora_name, lora)

    def reset(self):
        """Disable every adapter (their weights are kept)."""
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.enable = False

    def load(self, file, scale):
        """Load LoRA weights from *file* and enable the adapters they cover.

        Every adapter is resized to the rank/alpha stored for it in the file (or
        the file-wide values when it has no entry of its own), given *scale* as its
        multiplier, and enabled only when the file contains both of its weights.

        Args:
            file: path to a ``.safetensors`` file or a ``torch.load``-able checkpoint.
            scale: multiplier applied to every adapter's delta.
        """
        if os.path.splitext(file)[1].lower() == ".safetensors":
            weights = load_file(file)
        else:
            # NOTE(review): torch.load unpickles arbitrary objects; presumably the
            # `modules.safe` import at the top of this file guards it — confirm before
            # loading untrusted files.
            weights = torch.load(file, map_location="cpu")

        if not weights:
            return

        # File-wide fallbacks for adapters that have no entry of their own.
        network_alpha = None
        network_dim = None
        for key, value in weights.items():
            if network_alpha is None and "alpha" in key:
                network_alpha = float(value)
            if network_dim is None and "lora_down" in key and value.dim() >= 2:
                network_dim = value.size(0)

        if network_dim is None:
            print(f"No LoRA weights found in {file}")
            return

        for lora in self.text_encoder_loras + self.unet_loras:
            down = weights.get(f"{lora.lora_name}.lora_down.weight")
            module_alpha = weights.get(f"{lora.lora_name}.alpha")

            rank = down.size(0) if down is not None else network_dim
            if module_alpha is not None:
                alpha = float(module_alpha)
            elif network_alpha is not None:
                alpha = network_alpha
            else:
                alpha = rank  # no alpha stored anywhere: no scaling

            lora.resize(rank, alpha, scale)
            # resize() replaced the weights, so only adapters loaded from this file are usable
            lora.enable = down is not None and f"{lora.lora_name}.lora_up.weight" in weights

        info = self.load_state_dict(weights, False)
        if len(info.unexpected_keys) > 0:
            print(f"Weights are loaded. Unexpected keys={info.unexpected_keys}")