supertori committed
Commit d43d2a2
1 Parent(s): 68c6707

Upload 7 files

lycoris/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from lycoris import (
+     kohya,
+     kohya_model_utils,
+     kohya_utils,
+     locon,
+     loha,
+     utils,
+ )
lycoris/kohya.py ADDED
@@ -0,0 +1,272 @@
1
+ # network module for kohya
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
6
+
7
+ import math
8
+ from warnings import warn
9
+ import os
10
+ from typing import List
11
+ import torch
12
+
13
+ from .kohya_utils import *
14
+ from .locon import LoConModule
15
+ from .loha import LohaModule
16
+
17
+
18
+ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
19
+ if network_dim is None:
20
+ network_dim = 4 # default
21
+ conv_dim = int(kwargs.get('conv_dim', network_dim))
22
+ conv_alpha = float(kwargs.get('conv_alpha', network_alpha))
23
+ dropout = float(kwargs.get('dropout', 0.))
24
+ algo = kwargs.get('algo', 'lora')
25
+ disable_cp = kwargs.get('disable_conv_cp', False)
26
+ network_module = {
27
+ 'lora': LoConModule,
28
+ 'loha': LohaModule,
29
+ }[algo]
30
+
31
+ print(f'Using rank adaptation algo: {algo}')
32
+
33
+ if (algo == 'loha'
34
+ and not kwargs.get('no_dim_warn', False)
35
+ and (network_dim>64 or conv_dim>64)):
36
+ print('='*20 + 'WARNING' + '='*20)
37
+ warn(
38
+ (
39
+ "You are not supposed to use dim>64 (64*64 = 4096, which already has enough rank) "
40
+ "in Hadamard Product representation!\n"
41
+ "Please consider using a lower dim, or disable this warning with --network_args no_dim_warn=True\n"
42
+ "If you just want to use a high-dim loha, please consider using a lower lr."
43
+ ),
44
+ stacklevel=2,
45
+ )
46
+ print('='*20 + 'WARNING' + '='*20)
47
+
48
+ network = LycorisNetwork(
49
+ text_encoder, unet,
50
+ multiplier=multiplier,
51
+ lora_dim=network_dim, conv_lora_dim=conv_dim,
52
+ alpha=network_alpha, conv_alpha=conv_alpha,
53
+ dropout=dropout,
54
+ use_cp=(not bool(disable_cp)),
55
+ network_module=network_module
56
+ )
57
+
58
+ return network
59
+
60
+
61
+ class LycorisNetwork(torch.nn.Module):
62
+ '''
63
+ LoRA + LoCon
64
+ '''
65
+ # Ignore proj_in and proj_out; they only have a few channels.
66
+ UNET_TARGET_REPLACE_MODULE = [
67
+ "Transformer2DModel",
68
+ "Attention",
69
+ "ResnetBlock2D",
70
+ "Downsample2D",
71
+ "Upsample2D"
72
+ ]
73
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
74
+ LORA_PREFIX_UNET = 'lora_unet'
75
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
76
+
77
+ def __init__(
78
+ self,
79
+ text_encoder, unet,
80
+ multiplier=1.0,
81
+ lora_dim=4, conv_lora_dim=4,
82
+ alpha=1, conv_alpha=1,
83
+ use_cp = True,
84
+ dropout = 0, network_module = LoConModule,
85
+ ) -> None:
86
+ super().__init__()
87
+ self.multiplier = multiplier
88
+ self.lora_dim = lora_dim
89
+ self.conv_lora_dim = int(conv_lora_dim)
90
+ if self.conv_lora_dim != self.lora_dim:
91
+ print('Apply different lora dim for conv layer')
92
+ print(f'Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}')
93
+
94
+ self.alpha = alpha
95
+ self.conv_alpha = float(conv_alpha)
96
+ if self.alpha != self.conv_alpha:
97
+ print('Apply different alpha value for conv layer')
98
+ print(f'Conv alpha: {conv_alpha}, Linear alpha: {alpha}')
99
+
100
+ if 1 >= dropout >= 0:
101
+ print(f'Use Dropout value: {dropout}')
102
+ self.dropout = dropout
103
+
104
+ # create module instances
105
+ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[network_module]:
106
+ print('Create LyCORIS Module')
107
+ loras = []
108
+ for name, module in root_module.named_modules():
109
+ if module.__class__.__name__ in target_replace_modules:
110
+ for child_name, child_module in module.named_modules():
111
+ lora_name = prefix + '.' + name + '.' + child_name
112
+ lora_name = lora_name.replace('.', '_')
113
+ if child_module.__class__.__name__ == 'Linear' and lora_dim>0:
114
+ lora = network_module(
115
+ lora_name, child_module, self.multiplier,
116
+ self.lora_dim, self.alpha, self.dropout, use_cp
117
+ )
118
+ elif child_module.__class__.__name__ == 'Conv2d':
119
+ k_size, *_ = child_module.kernel_size
120
+ if k_size==1 and lora_dim>0:
121
+ lora = network_module(
122
+ lora_name, child_module, self.multiplier,
123
+ self.lora_dim, self.alpha, self.dropout, use_cp
124
+ )
125
+ elif conv_lora_dim>0:
126
+ lora = network_module(
127
+ lora_name, child_module, self.multiplier,
128
+ self.conv_lora_dim, self.conv_alpha, self.dropout, use_cp
129
+ )
130
+ else:
131
+ continue
132
+ else:
133
+ continue
134
+ loras.append(lora)
135
+ return loras
136
+
137
+ self.text_encoder_loras = create_modules(
138
+ LycorisNetwork.LORA_PREFIX_TEXT_ENCODER,
139
+ text_encoder,
140
+ LycorisNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
141
+ )
142
+ print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")
143
+
144
+ self.unet_loras = create_modules(LycorisNetwork.LORA_PREFIX_UNET, unet, LycorisNetwork.UNET_TARGET_REPLACE_MODULE)
145
+ print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")
146
+
147
+ self.weights_sd = None
148
+
149
+ # assertion
150
+ names = set()
151
+ for lora in self.text_encoder_loras + self.unet_loras:
152
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
153
+ names.add(lora.lora_name)
154
+
155
+ def set_multiplier(self, multiplier):
156
+ self.multiplier = multiplier
157
+ for lora in self.text_encoder_loras + self.unet_loras:
158
+ lora.multiplier = self.multiplier
159
+
160
+ def load_weights(self, file):
161
+ if os.path.splitext(file)[1] == '.safetensors':
162
+ from safetensors.torch import load_file, safe_open
163
+ self.weights_sd = load_file(file)
164
+ else:
165
+ self.weights_sd = torch.load(file, map_location='cpu')
166
+
167
+ def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
168
+ if self.weights_sd:
169
+ weights_has_text_encoder = weights_has_unet = False
170
+ for key in self.weights_sd.keys():
171
+ if key.startswith(LycorisNetwork.LORA_PREFIX_TEXT_ENCODER):
172
+ weights_has_text_encoder = True
173
+ elif key.startswith(LycorisNetwork.LORA_PREFIX_UNET):
174
+ weights_has_unet = True
175
+
176
+ if apply_text_encoder is None:
177
+ apply_text_encoder = weights_has_text_encoder
178
+ else:
179
+ assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
180
+
181
+ if apply_unet is None:
182
+ apply_unet = weights_has_unet
183
+ else:
184
+ assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
185
+ else:
186
+ assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
187
+
188
+ if apply_text_encoder:
189
+ print("enable LyCORIS for text encoder")
190
+ else:
191
+ self.text_encoder_loras = []
192
+
193
+ if apply_unet:
194
+ print("enable LyCORIS for U-Net")
195
+ else:
196
+ self.unet_loras = []
197
+
198
+ for lora in self.text_encoder_loras + self.unet_loras:
199
+ lora.apply_to()
200
+ self.add_module(lora.lora_name, lora)
201
+
202
+ if self.weights_sd:
203
+ # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
204
+ info = self.load_state_dict(self.weights_sd, False)
205
+ print(f"weights are loaded: {info}")
206
+
207
+ def enable_gradient_checkpointing(self):
208
+ # not supported
209
+ def make_ckpt(module):
210
+ if isinstance(module, torch.nn.Module):
211
+ module.grad_ckpt = True
212
+ self.apply(make_ckpt)
213
+ pass
214
+
215
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
216
+ def enumerate_params(loras):
217
+ params = []
218
+ for lora in loras:
219
+ params.extend(lora.parameters())
220
+ return params
221
+
222
+ self.requires_grad_(True)
223
+ all_params = []
224
+
225
+ if self.text_encoder_loras:
226
+ param_data = {'params': enumerate_params(self.text_encoder_loras)}
227
+ if text_encoder_lr is not None:
228
+ param_data['lr'] = text_encoder_lr
229
+ all_params.append(param_data)
230
+
231
+ if self.unet_loras:
232
+ param_data = {'params': enumerate_params(self.unet_loras)}
233
+ if unet_lr is not None:
234
+ param_data['lr'] = unet_lr
235
+ all_params.append(param_data)
236
+
237
+ return all_params
238
+
239
+ def prepare_grad_etc(self, text_encoder, unet):
240
+ self.requires_grad_(True)
241
+
242
+ def on_epoch_start(self, text_encoder, unet):
243
+ self.train()
244
+
245
+ def get_trainable_params(self):
246
+ return self.parameters()
247
+
248
+ def save_weights(self, file, dtype, metadata):
249
+ if metadata is not None and len(metadata) == 0:
250
+ metadata = None
251
+
252
+ state_dict = self.state_dict()
253
+
254
+ if dtype is not None:
255
+ for key in list(state_dict.keys()):
256
+ v = state_dict[key]
257
+ v = v.detach().clone().to("cpu").to(dtype)
258
+ state_dict[key] = v
259
+
260
+ if os.path.splitext(file)[1] == '.safetensors':
261
+ from safetensors.torch import save_file
262
+
263
+ # Precalculate model hashes to save time on indexing
264
+ if metadata is None:
265
+ metadata = {}
266
+ model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
267
+ metadata["sshs_model_hash"] = model_hash
268
+ metadata["sshs_legacy_hash"] = legacy_hash
269
+
270
+ save_file(state_dict, file, metadata)
271
+ else:
272
+ torch.save(state_dict, file)
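A minimal usage sketch for the module above, assuming a local SD v1 checkpoint at the placeholder path model.ckpt and arbitrary example ranks/learning rates; it loads the base models with load_models_from_stable_diffusion_checkpoint from kohya_model_utils.py (next file) and drives create_network roughly the way kohya-ss sd-scripts would through --network_module/--network_args:

import torch
from lycoris.kohya import create_network
from lycoris.kohya_model_utils import load_models_from_stable_diffusion_checkpoint

# load text encoder / VAE / U-Net from a v1 checkpoint (placeholder path)
text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(
    v2=False, ckpt_path="model.ckpt"
)

# extra kwargs mirror what --network_args would pass through
network = create_network(
    multiplier=1.0, network_dim=8, network_alpha=4.0,
    vae=vae, text_encoder=text_encoder, unet=unet,
    conv_dim=4, conv_alpha=1.0, algo="loha",
)

# inject the LyCORIS modules, build optimizer param groups, save the result
network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)
optimizer = torch.optim.AdamW(
    network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4)
)
network.save_weights("lycoris_weights.safetensors", torch.float16, metadata={})

apply_to() registers every LoCon/LoHa module as a submodule, so network.state_dict() and therefore save_weights() contain only the adapter weights.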
lycoris/kohya_model_utils.py ADDED
@@ -0,0 +1,1184 @@
1
+ '''
2
+ https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
3
+ '''
4
+ # v1: split from train_db_fixed.py.
5
+ # v2: support safetensors
6
+
7
+ import math
8
+ import os
9
+ import torch
10
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
11
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
12
+ from safetensors.torch import load_file, save_file
13
+
14
+ # model parameters for the Diffusers version of Stable Diffusion
15
+ NUM_TRAIN_TIMESTEPS = 1000
16
+ BETA_START = 0.00085
17
+ BETA_END = 0.0120
18
+
19
+ UNET_PARAMS_MODEL_CHANNELS = 320
20
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
21
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
22
+ UNET_PARAMS_IMAGE_SIZE = 32 # unused
23
+ UNET_PARAMS_IN_CHANNELS = 4
24
+ UNET_PARAMS_OUT_CHANNELS = 4
25
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
26
+ UNET_PARAMS_CONTEXT_DIM = 768
27
+ UNET_PARAMS_NUM_HEADS = 8
28
+
29
+ VAE_PARAMS_Z_CHANNELS = 4
30
+ VAE_PARAMS_RESOLUTION = 256
31
+ VAE_PARAMS_IN_CHANNELS = 3
32
+ VAE_PARAMS_OUT_CH = 3
33
+ VAE_PARAMS_CH = 128
34
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
35
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
36
+
37
+ # V2
38
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
39
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
40
+
41
+ # reference models used to load Diffusers configs
42
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
43
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
44
+
45
+
46
+ # region StableDiffusion->Diffusers conversion code
47
+ # copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
48
+
49
+
50
+ def shave_segments(path, n_shave_prefix_segments=1):
51
+ """
52
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
53
+ """
54
+ if n_shave_prefix_segments >= 0:
55
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
56
+ else:
57
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
58
+
59
+
60
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
61
+ """
62
+ Updates paths inside resnets to the new naming scheme (local renaming)
63
+ """
64
+ mapping = []
65
+ for old_item in old_list:
66
+ new_item = old_item.replace("in_layers.0", "norm1")
67
+ new_item = new_item.replace("in_layers.2", "conv1")
68
+
69
+ new_item = new_item.replace("out_layers.0", "norm2")
70
+ new_item = new_item.replace("out_layers.3", "conv2")
71
+
72
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
73
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
74
+
75
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
76
+
77
+ mapping.append({"old": old_item, "new": new_item})
78
+
79
+ return mapping
80
+
81
+
82
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
83
+ """
84
+ Updates paths inside resnets to the new naming scheme (local renaming)
85
+ """
86
+ mapping = []
87
+ for old_item in old_list:
88
+ new_item = old_item
89
+
90
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
91
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
92
+
93
+ mapping.append({"old": old_item, "new": new_item})
94
+
95
+ return mapping
96
+
97
+
98
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
99
+ """
100
+ Updates paths inside attentions to the new naming scheme (local renaming)
101
+ """
102
+ mapping = []
103
+ for old_item in old_list:
104
+ new_item = old_item
105
+
106
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
107
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
108
+
109
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
110
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
111
+
112
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
113
+
114
+ mapping.append({"old": old_item, "new": new_item})
115
+
116
+ return mapping
117
+
118
+
119
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
120
+ """
121
+ Updates paths inside attentions to the new naming scheme (local renaming)
122
+ """
123
+ mapping = []
124
+ for old_item in old_list:
125
+ new_item = old_item
126
+
127
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
128
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
129
+
130
+ new_item = new_item.replace("q.weight", "query.weight")
131
+ new_item = new_item.replace("q.bias", "query.bias")
132
+
133
+ new_item = new_item.replace("k.weight", "key.weight")
134
+ new_item = new_item.replace("k.bias", "key.bias")
135
+
136
+ new_item = new_item.replace("v.weight", "value.weight")
137
+ new_item = new_item.replace("v.bias", "value.bias")
138
+
139
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
140
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
141
+
142
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
143
+
144
+ mapping.append({"old": old_item, "new": new_item})
145
+
146
+ return mapping
147
+
148
+
149
+ def assign_to_checkpoint(
150
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
151
+ ):
152
+ """
153
+ This does the final conversion step: take locally converted weights and apply a global renaming
154
+ to them. It splits attention layers, and takes into account additional replacements
155
+ that may arise.
156
+
157
+ Assigns the weights to the new checkpoint.
158
+ """
159
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
160
+
161
+ # Splits the attention layers into three variables.
162
+ if attention_paths_to_split is not None:
163
+ for path, path_map in attention_paths_to_split.items():
164
+ old_tensor = old_checkpoint[path]
165
+ channels = old_tensor.shape[0] // 3
166
+
167
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
168
+
169
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
170
+
171
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
172
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
173
+
174
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
175
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
176
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
177
+
178
+ for path in paths:
179
+ new_path = path["new"]
180
+
181
+ # These have already been assigned
182
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
183
+ continue
184
+
185
+ # Global renaming happens here
186
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
187
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
188
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
189
+
190
+ if additional_replacements is not None:
191
+ for replacement in additional_replacements:
192
+ new_path = new_path.replace(replacement["old"], replacement["new"])
193
+
194
+ # proj_attn.weight has to be converted from conv 1D to linear
195
+ if "proj_attn.weight" in new_path:
196
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
197
+ else:
198
+ checkpoint[new_path] = old_checkpoint[path["old"]]
199
+
200
+
201
+ def conv_attn_to_linear(checkpoint):
202
+ keys = list(checkpoint.keys())
203
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
204
+ for key in keys:
205
+ if ".".join(key.split(".")[-2:]) in attn_keys:
206
+ if checkpoint[key].ndim > 2:
207
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
208
+ elif "proj_attn.weight" in key:
209
+ if checkpoint[key].ndim > 2:
210
+ checkpoint[key] = checkpoint[key][:, :, 0]
211
+
212
+
213
+ def linear_transformer_to_conv(checkpoint):
214
+ keys = list(checkpoint.keys())
215
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
216
+ for key in keys:
217
+ if ".".join(key.split(".")[-2:]) in tf_keys:
218
+ if checkpoint[key].ndim == 2:
219
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
220
+
221
+
222
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
223
+ """
224
+ Takes a state dict and a config, and returns a converted checkpoint.
225
+ """
226
+
227
+ # extract state_dict for UNet
228
+ unet_state_dict = {}
229
+ unet_key = "model.diffusion_model."
230
+ keys = list(checkpoint.keys())
231
+ for key in keys:
232
+ if key.startswith(unet_key):
233
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
234
+
235
+ new_checkpoint = {}
236
+
237
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
238
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
239
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
240
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
241
+
242
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
243
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
244
+
245
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
246
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
247
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
248
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
249
+
250
+ # Retrieves the keys for the input blocks only
251
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
252
+ input_blocks = {
253
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
254
+ for layer_id in range(num_input_blocks)
255
+ }
256
+
257
+ # Retrieves the keys for the middle blocks only
258
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
259
+ middle_blocks = {
260
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
261
+ for layer_id in range(num_middle_blocks)
262
+ }
263
+
264
+ # Retrieves the keys for the output blocks only
265
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
266
+ output_blocks = {
267
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
268
+ for layer_id in range(num_output_blocks)
269
+ }
270
+
271
+ for i in range(1, num_input_blocks):
272
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
273
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
274
+
275
+ resnets = [
276
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
277
+ ]
278
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
279
+
280
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
281
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
282
+ f"input_blocks.{i}.0.op.weight"
283
+ )
284
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
285
+ f"input_blocks.{i}.0.op.bias"
286
+ )
287
+
288
+ paths = renew_resnet_paths(resnets)
289
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
290
+ assign_to_checkpoint(
291
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
292
+ )
293
+
294
+ if len(attentions):
295
+ paths = renew_attention_paths(attentions)
296
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
297
+ assign_to_checkpoint(
298
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
299
+ )
300
+
301
+ resnet_0 = middle_blocks[0]
302
+ attentions = middle_blocks[1]
303
+ resnet_1 = middle_blocks[2]
304
+
305
+ resnet_0_paths = renew_resnet_paths(resnet_0)
306
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
307
+
308
+ resnet_1_paths = renew_resnet_paths(resnet_1)
309
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
310
+
311
+ attentions_paths = renew_attention_paths(attentions)
312
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
313
+ assign_to_checkpoint(
314
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
315
+ )
316
+
317
+ for i in range(num_output_blocks):
318
+ block_id = i // (config["layers_per_block"] + 1)
319
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
320
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
321
+ output_block_list = {}
322
+
323
+ for layer in output_block_layers:
324
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
325
+ if layer_id in output_block_list:
326
+ output_block_list[layer_id].append(layer_name)
327
+ else:
328
+ output_block_list[layer_id] = [layer_name]
329
+
330
+ if len(output_block_list) > 1:
331
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
332
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
333
+
334
+ resnet_0_paths = renew_resnet_paths(resnets)
335
+ paths = renew_resnet_paths(resnets)
336
+
337
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
338
+ assign_to_checkpoint(
339
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
340
+ )
341
+
342
+ # original:
343
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
344
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
345
+
346
+ # avoid depending on the order of bias and weight; there is probably a better way
347
+ for l in output_block_list.values():
348
+ l.sort()
349
+
350
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
351
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
352
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
353
+ f"output_blocks.{i}.{index}.conv.bias"
354
+ ]
355
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
356
+ f"output_blocks.{i}.{index}.conv.weight"
357
+ ]
358
+
359
+ # Clear attentions as they have been attributed above.
360
+ if len(attentions) == 2:
361
+ attentions = []
362
+
363
+ if len(attentions):
364
+ paths = renew_attention_paths(attentions)
365
+ meta_path = {
366
+ "old": f"output_blocks.{i}.1",
367
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
368
+ }
369
+ assign_to_checkpoint(
370
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
371
+ )
372
+ else:
373
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
374
+ for path in resnet_0_paths:
375
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
376
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
377
+
378
+ new_checkpoint[new_path] = unet_state_dict[old_path]
379
+
380
+ # in SD v2, 1x1 conv2d layers were changed to linear, so convert linear -> conv
381
+ if v2:
382
+ linear_transformer_to_conv(new_checkpoint)
383
+
384
+ return new_checkpoint
385
+
386
+
387
+ def convert_ldm_vae_checkpoint(checkpoint, config):
388
+ # extract state dict for VAE
389
+ vae_state_dict = {}
390
+ vae_key = "first_stage_model."
391
+ keys = list(checkpoint.keys())
392
+ for key in keys:
393
+ if key.startswith(vae_key):
394
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
395
+ # if len(vae_state_dict) == 0:
396
+ # # the checkpoint passed in is a VAE state_dict, not a checkpoint loaded from a .ckpt
397
+ # vae_state_dict = checkpoint
398
+
399
+ new_checkpoint = {}
400
+
401
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
402
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
403
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
404
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
405
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
406
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
407
+
408
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
409
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
410
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
411
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
412
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
413
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
414
+
415
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
416
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
417
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
418
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
419
+
420
+ # Retrieves the keys for the encoder down blocks only
421
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
422
+ down_blocks = {
423
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
424
+ }
425
+
426
+ # Retrieves the keys for the decoder up blocks only
427
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
428
+ up_blocks = {
429
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
430
+ }
431
+
432
+ for i in range(num_down_blocks):
433
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
434
+
435
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
436
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
437
+ f"encoder.down.{i}.downsample.conv.weight"
438
+ )
439
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
440
+ f"encoder.down.{i}.downsample.conv.bias"
441
+ )
442
+
443
+ paths = renew_vae_resnet_paths(resnets)
444
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
445
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
446
+
447
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
448
+ num_mid_res_blocks = 2
449
+ for i in range(1, num_mid_res_blocks + 1):
450
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
451
+
452
+ paths = renew_vae_resnet_paths(resnets)
453
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
454
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
455
+
456
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
457
+ paths = renew_vae_attention_paths(mid_attentions)
458
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
459
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
460
+ conv_attn_to_linear(new_checkpoint)
461
+
462
+ for i in range(num_up_blocks):
463
+ block_id = num_up_blocks - 1 - i
464
+ resnets = [
465
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
466
+ ]
467
+
468
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
469
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
470
+ f"decoder.up.{block_id}.upsample.conv.weight"
471
+ ]
472
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
473
+ f"decoder.up.{block_id}.upsample.conv.bias"
474
+ ]
475
+
476
+ paths = renew_vae_resnet_paths(resnets)
477
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
478
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
479
+
480
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
481
+ num_mid_res_blocks = 2
482
+ for i in range(1, num_mid_res_blocks + 1):
483
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
484
+
485
+ paths = renew_vae_resnet_paths(resnets)
486
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
487
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
488
+
489
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
490
+ paths = renew_vae_attention_paths(mid_attentions)
491
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
492
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
493
+ conv_attn_to_linear(new_checkpoint)
494
+ return new_checkpoint
495
+
496
+
497
+ def create_unet_diffusers_config(v2):
498
+ """
499
+ Creates a config for the diffusers based on the config of the LDM model.
500
+ """
501
+ # unet_params = original_config.model.params.unet_config.params
502
+
503
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
504
+
505
+ down_block_types = []
506
+ resolution = 1
507
+ for i in range(len(block_out_channels)):
508
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
509
+ down_block_types.append(block_type)
510
+ if i != len(block_out_channels) - 1:
511
+ resolution *= 2
512
+
513
+ up_block_types = []
514
+ for i in range(len(block_out_channels)):
515
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
516
+ up_block_types.append(block_type)
517
+ resolution //= 2
518
+
519
+ config = dict(
520
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
521
+ in_channels=UNET_PARAMS_IN_CHANNELS,
522
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
523
+ down_block_types=tuple(down_block_types),
524
+ up_block_types=tuple(up_block_types),
525
+ block_out_channels=tuple(block_out_channels),
526
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
527
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
528
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
529
+ )
530
+
531
+ return config
532
+
533
+
534
+ def create_vae_diffusers_config():
535
+ """
536
+ Creates a config for the diffusers based on the config of the LDM model.
537
+ """
538
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
539
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
540
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
541
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
542
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
543
+
544
+ config = dict(
545
+ sample_size=VAE_PARAMS_RESOLUTION,
546
+ in_channels=VAE_PARAMS_IN_CHANNELS,
547
+ out_channels=VAE_PARAMS_OUT_CH,
548
+ down_block_types=tuple(down_block_types),
549
+ up_block_types=tuple(up_block_types),
550
+ block_out_channels=tuple(block_out_channels),
551
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
552
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
553
+ )
554
+ return config
555
+
556
+
557
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
558
+ keys = list(checkpoint.keys())
559
+ text_model_dict = {}
560
+ for key in keys:
561
+ if key.startswith("cond_stage_model.transformer"):
562
+ text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
563
+ return text_model_dict
564
+
565
+
566
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
567
+ # the v2 key layout is annoyingly different!
568
+ def convert_key(key):
569
+ if not key.startswith("cond_stage_model"):
570
+ return None
571
+
572
+ # common conversion
573
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
574
+ key = key.replace("cond_stage_model.model.", "text_model.")
575
+
576
+ if "resblocks" in key:
577
+ # resblocks conversion
578
+ key = key.replace(".resblocks.", ".layers.")
579
+ if ".ln_" in key:
580
+ key = key.replace(".ln_", ".layer_norm")
581
+ elif ".mlp." in key:
582
+ key = key.replace(".c_fc.", ".fc1.")
583
+ key = key.replace(".c_proj.", ".fc2.")
584
+ elif '.attn.out_proj' in key:
585
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
586
+ elif '.attn.in_proj' in key:
587
+ key = None # special case, handled later
588
+ else:
589
+ raise ValueError(f"unexpected key in SD: {key}")
590
+ elif '.positional_embedding' in key:
591
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
592
+ elif '.text_projection' in key:
593
+ key = None # not used???
594
+ elif '.logit_scale' in key:
595
+ key = None # not used???
596
+ elif '.token_embedding' in key:
597
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
598
+ elif '.ln_final' in key:
599
+ key = key.replace(".ln_final", ".final_layer_norm")
600
+ return key
601
+
602
+ keys = list(checkpoint.keys())
603
+ new_sd = {}
604
+ for key in keys:
605
+ # remove resblocks 23
606
+ if '.resblocks.23.' in key:
607
+ continue
608
+ new_key = convert_key(key)
609
+ if new_key is None:
610
+ continue
611
+ new_sd[new_key] = checkpoint[key]
612
+
613
+ # convert attn
614
+ for key in keys:
615
+ if '.resblocks.23.' in key:
616
+ continue
617
+ if '.resblocks' in key and '.attn.in_proj_' in key:
618
+ # split into three
619
+ values = torch.chunk(checkpoint[key], 3)
620
+
621
+ key_suffix = ".weight" if "weight" in key else ".bias"
622
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
623
+ key_pfx = key_pfx.replace("_weight", "")
624
+ key_pfx = key_pfx.replace("_bias", "")
625
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
626
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
627
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
628
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
629
+
630
+ # rename or add position_ids
631
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
632
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
633
+ # waifu diffusion v1.4
634
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
635
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
636
+ else:
637
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
638
+
639
+ new_sd["text_model.embeddings.position_ids"] = position_ids
640
+ return new_sd
641
+
642
+ # endregion
643
+
644
+
645
+ # region Diffusers->StableDiffusion conversion code
646
+ # copied and modified from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
647
+
648
+ def conv_transformer_to_linear(checkpoint):
649
+ keys = list(checkpoint.keys())
650
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
651
+ for key in keys:
652
+ if ".".join(key.split(".")[-2:]) in tf_keys:
653
+ if checkpoint[key].ndim > 2:
654
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
655
+
656
+
657
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
658
+ unet_conversion_map = [
659
+ # (stable-diffusion, HF Diffusers)
660
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
661
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
662
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
663
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
664
+ ("input_blocks.0.0.weight", "conv_in.weight"),
665
+ ("input_blocks.0.0.bias", "conv_in.bias"),
666
+ ("out.0.weight", "conv_norm_out.weight"),
667
+ ("out.0.bias", "conv_norm_out.bias"),
668
+ ("out.2.weight", "conv_out.weight"),
669
+ ("out.2.bias", "conv_out.bias"),
670
+ ]
671
+
672
+ unet_conversion_map_resnet = [
673
+ # (stable-diffusion, HF Diffusers)
674
+ ("in_layers.0", "norm1"),
675
+ ("in_layers.2", "conv1"),
676
+ ("out_layers.0", "norm2"),
677
+ ("out_layers.3", "conv2"),
678
+ ("emb_layers.1", "time_emb_proj"),
679
+ ("skip_connection", "conv_shortcut"),
680
+ ]
681
+
682
+ unet_conversion_map_layer = []
683
+ for i in range(4):
684
+ # loop over downblocks/upblocks
685
+
686
+ for j in range(2):
687
+ # loop over resnets/attentions for downblocks
688
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
689
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
690
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
691
+
692
+ if i < 3:
693
+ # no attention layers in down_blocks.3
694
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
695
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
696
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
697
+
698
+ for j in range(3):
699
+ # loop over resnets/attentions for upblocks
700
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
701
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
702
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
703
+
704
+ if i > 0:
705
+ # no attention layers in up_blocks.0
706
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
707
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
708
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
709
+
710
+ if i < 3:
711
+ # no downsample in down_blocks.3
712
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
713
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
714
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
715
+
716
+ # no upsample in up_blocks.3
717
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
718
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
719
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
720
+
721
+ hf_mid_atn_prefix = "mid_block.attentions.0."
722
+ sd_mid_atn_prefix = "middle_block.1."
723
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
724
+
725
+ for j in range(2):
726
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
727
+ sd_mid_res_prefix = f"middle_block.{2*j}."
728
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
729
+
730
+ # buyer beware: this is a *brittle* function,
731
+ # and correct output requires that all of these pieces interact in
732
+ # the exact order in which I have arranged them.
733
+ mapping = {k: k for k in unet_state_dict.keys()}
734
+ for sd_name, hf_name in unet_conversion_map:
735
+ mapping[hf_name] = sd_name
736
+ for k, v in mapping.items():
737
+ if "resnets" in k:
738
+ for sd_part, hf_part in unet_conversion_map_resnet:
739
+ v = v.replace(hf_part, sd_part)
740
+ mapping[k] = v
741
+ for k, v in mapping.items():
742
+ for sd_part, hf_part in unet_conversion_map_layer:
743
+ v = v.replace(hf_part, sd_part)
744
+ mapping[k] = v
745
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
746
+
747
+ if v2:
748
+ conv_transformer_to_linear(new_state_dict)
749
+
750
+ return new_state_dict
751
+
752
+
753
+ # ================#
754
+ # VAE Conversion #
755
+ # ================#
756
+
757
+ def reshape_weight_for_sd(w):
758
+ # convert HF linear weights to SD conv2d weights
759
+ return w.reshape(*w.shape, 1, 1)
760
+
761
+
762
+ def convert_vae_state_dict(vae_state_dict):
763
+ vae_conversion_map = [
764
+ # (stable-diffusion, HF Diffusers)
765
+ ("nin_shortcut", "conv_shortcut"),
766
+ ("norm_out", "conv_norm_out"),
767
+ ("mid.attn_1.", "mid_block.attentions.0."),
768
+ ]
769
+
770
+ for i in range(4):
771
+ # down_blocks have two resnets
772
+ for j in range(2):
773
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
774
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
775
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
776
+
777
+ if i < 3:
778
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
779
+ sd_downsample_prefix = f"down.{i}.downsample."
780
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
781
+
782
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
783
+ sd_upsample_prefix = f"up.{3-i}.upsample."
784
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
785
+
786
+ # up_blocks have three resnets
787
+ # also, up blocks in hf are numbered in reverse from sd
788
+ for j in range(3):
789
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
790
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
791
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
792
+
793
+ # this part accounts for mid blocks in both the encoder and the decoder
794
+ for i in range(2):
795
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
796
+ sd_mid_res_prefix = f"mid.block_{i+1}."
797
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
798
+
799
+ vae_conversion_map_attn = [
800
+ # (stable-diffusion, HF Diffusers)
801
+ ("norm.", "group_norm."),
802
+ ("q.", "query."),
803
+ ("k.", "key."),
804
+ ("v.", "value."),
805
+ ("proj_out.", "proj_attn."),
806
+ ]
807
+
808
+ mapping = {k: k for k in vae_state_dict.keys()}
809
+ for k, v in mapping.items():
810
+ for sd_part, hf_part in vae_conversion_map:
811
+ v = v.replace(hf_part, sd_part)
812
+ mapping[k] = v
813
+ for k, v in mapping.items():
814
+ if "attentions" in k:
815
+ for sd_part, hf_part in vae_conversion_map_attn:
816
+ v = v.replace(hf_part, sd_part)
817
+ mapping[k] = v
818
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
819
+ weights_to_convert = ["q", "k", "v", "proj_out"]
820
+ for k, v in new_state_dict.items():
821
+ for weight_name in weights_to_convert:
822
+ if f"mid.attn_1.{weight_name}.weight" in k:
823
+ # print(f"Reshaping {k} for SD format")
824
+ new_state_dict[k] = reshape_weight_for_sd(v)
825
+
826
+ return new_state_dict
827
+
828
+
829
+ # endregion
830
+
831
+ # region custom model loading/saving helpers
832
+
833
+ def is_safetensors(path):
834
+ return os.path.splitext(path)[1].lower() == '.safetensors'
835
+
836
+
837
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
838
+ # handle models whose text encoder is stored in a different layout (no 'text_model' prefix)
839
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
840
+ ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
841
+ ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
842
+ ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
843
+ ]
844
+
845
+ if is_safetensors(ckpt_path):
846
+ checkpoint = None
847
+ state_dict = load_file(ckpt_path, "cpu")
848
+ else:
849
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
850
+ if "state_dict" in checkpoint:
851
+ state_dict = checkpoint["state_dict"]
852
+ else:
853
+ state_dict = checkpoint
854
+ checkpoint = None
855
+
856
+ key_reps = []
857
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
858
+ for key in state_dict.keys():
859
+ if key.startswith(rep_from):
860
+ new_key = rep_to + key[len(rep_from):]
861
+ key_reps.append((key, new_key))
862
+
863
+ for key, new_key in key_reps:
864
+ state_dict[new_key] = state_dict[key]
865
+ del state_dict[key]
866
+
867
+ return checkpoint, state_dict
868
+
869
+
870
+ # TODO the dtype handling looks unreliable, verify it; building the text_encoder in the requested format is unconfirmed
871
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
872
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
873
+ if dtype is not None:
874
+ for k, v in state_dict.items():
875
+ if type(v) is torch.Tensor:
876
+ state_dict[k] = v.to(dtype)
877
+
878
+ # Convert the UNet2DConditionModel model.
879
+ unet_config = create_unet_diffusers_config(v2)
880
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
881
+
882
+ unet = UNet2DConditionModel(**unet_config)
883
+ info = unet.load_state_dict(converted_unet_checkpoint)
884
+ print("loading u-net:", info)
885
+
886
+ # Convert the VAE model.
887
+ vae_config = create_vae_diffusers_config()
888
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
889
+
890
+ vae = AutoencoderKL(**vae_config)
891
+ info = vae.load_state_dict(converted_vae_checkpoint)
892
+ print("loading vae:", info)
893
+
894
+ # convert text_model
895
+ if v2:
896
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
897
+ cfg = CLIPTextConfig(
898
+ vocab_size=49408,
899
+ hidden_size=1024,
900
+ intermediate_size=4096,
901
+ num_hidden_layers=23,
902
+ num_attention_heads=16,
903
+ max_position_embeddings=77,
904
+ hidden_act="gelu",
905
+ layer_norm_eps=1e-05,
906
+ dropout=0.0,
907
+ attention_dropout=0.0,
908
+ initializer_range=0.02,
909
+ initializer_factor=1.0,
910
+ pad_token_id=1,
911
+ bos_token_id=0,
912
+ eos_token_id=2,
913
+ model_type="clip_text_model",
914
+ projection_dim=512,
915
+ torch_dtype="float32",
916
+ transformers_version="4.25.0.dev0",
917
+ )
918
+ text_model = CLIPTextModel._from_config(cfg)
919
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
920
+ else:
921
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
922
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
923
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
924
+ print("loading text encoder:", info)
925
+
926
+ return text_model, vae, unet
927
+
928
+
929
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
930
+ def convert_key(key):
931
+ # remove position_ids
932
+ if ".position_ids" in key:
933
+ return None
934
+
935
+ # common
936
+ key = key.replace("text_model.encoder.", "transformer.")
937
+ key = key.replace("text_model.", "")
938
+ if "layers" in key:
939
+ # resblocks conversion
940
+ key = key.replace(".layers.", ".resblocks.")
941
+ if ".layer_norm" in key:
942
+ key = key.replace(".layer_norm", ".ln_")
943
+ elif ".mlp." in key:
944
+ key = key.replace(".fc1.", ".c_fc.")
945
+ key = key.replace(".fc2.", ".c_proj.")
946
+ elif '.self_attn.out_proj' in key:
947
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
948
+ elif '.self_attn.' in key:
949
+ key = None # special case, handled later
950
+ else:
951
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
952
+ elif '.position_embedding' in key:
953
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
954
+ elif '.token_embedding' in key:
955
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
956
+ elif 'final_layer_norm' in key:
957
+ key = key.replace("final_layer_norm", "ln_final")
958
+ return key
959
+
960
+ keys = list(checkpoint.keys())
961
+ new_sd = {}
962
+ for key in keys:
963
+ new_key = convert_key(key)
964
+ if new_key is None:
965
+ continue
966
+ new_sd[new_key] = checkpoint[key]
967
+
968
+ # convert attn
969
+ for key in keys:
970
+ if 'layers' in key and 'q_proj' in key:
971
+ # concatenate the three projections
972
+ key_q = key
973
+ key_k = key.replace("q_proj", "k_proj")
974
+ key_v = key.replace("q_proj", "v_proj")
975
+
976
+ value_q = checkpoint[key_q]
977
+ value_k = checkpoint[key_k]
978
+ value_v = checkpoint[key_v]
979
+ value = torch.cat([value_q, value_k, value_v])
980
+
981
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
982
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
983
+ new_sd[new_key] = value
984
+
985
+ # whether to fabricate the final layer and other missing weights
986
+ if make_dummy_weights:
987
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
988
+ keys = list(new_sd.keys())
989
+ for key in keys:
990
+ if key.startswith("transformer.resblocks.22."):
991
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # saving to safetensors fails unless this is a copy
992
+
993
+ # create weights that are not included in the Diffusers model
994
+ new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
995
+ new_sd['logit_scale'] = torch.tensor(1)
996
+
997
+ return new_sd
998
+
999
+
1000
+ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
1001
+ if ckpt_path is not None:
1002
+ # reuse epoch/step from the checkpoint; also reload it (including the VAE), e.g. when the VAE is not in memory
1003
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1004
+ if checkpoint is None: # safetensors, or a ckpt that is just a state_dict
1005
+ checkpoint = {}
1006
+ strict = False
1007
+ else:
1008
+ strict = True
1009
+ if "state_dict" in state_dict:
1010
+ del state_dict["state_dict"]
1011
+ else:
1012
+ # create a new one
1013
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1014
+ checkpoint = {}
1015
+ state_dict = {}
1016
+ strict = False
1017
+
1018
+ def update_sd(prefix, sd):
1019
+ for k, v in sd.items():
1020
+ key = prefix + k
1021
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1022
+ if save_dtype is not None:
1023
+ v = v.detach().clone().to("cpu").to(save_dtype)
1024
+ state_dict[key] = v
1025
+
1026
+ # Convert the UNet model
1027
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1028
+ update_sd("model.diffusion_model.", unet_state_dict)
1029
+
1030
+ # Convert the text encoder model
1031
+ if v2:
1032
+ make_dummy = ckpt_path is None # if there is no source checkpoint, insert dummy weights (e.g. duplicate the last layer from the previous one)
1033
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1034
+ update_sd("cond_stage_model.model.", text_enc_dict)
1035
+ else:
1036
+ text_enc_dict = text_encoder.state_dict()
1037
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1038
+
1039
+ # Convert the VAE
1040
+ if vae is not None:
1041
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1042
+ update_sd("first_stage_model.", vae_dict)
1043
+
1044
+ # Put together new checkpoint
1045
+ key_count = len(state_dict.keys())
1046
+ new_ckpt = {'state_dict': state_dict}
1047
+
1048
+ if 'epoch' in checkpoint:
1049
+ epochs += checkpoint['epoch']
1050
+ if 'global_step' in checkpoint:
1051
+ steps += checkpoint['global_step']
1052
+
1053
+ new_ckpt['epoch'] = epochs
1054
+ new_ckpt['global_step'] = steps
1055
+
1056
+ if is_safetensors(output_file):
1057
+ # TODO consider removing non-Tensor values from the dict
1058
+ save_file(state_dict, output_file)
1059
+ else:
1060
+ torch.save(new_ckpt, output_file)
1061
+
1062
+ return key_count
1063
+
1064
+
1065
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1066
+ if pretrained_model_name_or_path is None:
1067
+ # load default settings for v1/v2
1068
+ if v2:
1069
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1070
+ else:
1071
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1072
+
1073
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1074
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1075
+ if vae is None:
1076
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1077
+
1078
+ pipeline = StableDiffusionPipeline(
1079
+ unet=unet,
1080
+ text_encoder=text_encoder,
1081
+ vae=vae,
1082
+ scheduler=scheduler,
1083
+ tokenizer=tokenizer,
1084
+ safety_checker=None,
1085
+ feature_extractor=None,
1086
+ requires_safety_checker=None,
1087
+ )
1088
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1089
+
1090
+
1091
+ VAE_PREFIX = "first_stage_model."
1092
+
1093
+
1094
+ def load_vae(vae_id, dtype):
1095
+ print(f"load VAE: {vae_id}")
1096
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1097
+ # Diffusers local/remote
1098
+ try:
1099
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1100
+ except EnvironmentError as e:
1101
+ print(f"exception occurs in loading vae: {e}")
1102
+ print("retry with subfolder='vae'")
1103
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1104
+ return vae
1105
+
1106
+ # local
1107
+ vae_config = create_vae_diffusers_config()
1108
+
1109
+ if vae_id.endswith(".bin"):
1110
+ # SD 1.5 VAE on Huggingface
1111
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1112
+ else:
1113
+ # StableDiffusion
1114
+ vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
1115
+ else torch.load(vae_id, map_location="cpu"))
1116
+ vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
1117
+
1118
+ # vae only or full model
1119
+ full_model = False
1120
+ for vae_key in vae_sd:
1121
+ if vae_key.startswith(VAE_PREFIX):
1122
+ full_model = True
1123
+ break
1124
+ if not full_model:
1125
+ sd = {}
1126
+ for key, value in vae_sd.items():
1127
+ sd[VAE_PREFIX + key] = value
1128
+ vae_sd = sd
1129
+ del sd
1130
+
1131
+ # Convert the VAE model.
1132
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1133
+
1134
+ vae = AutoencoderKL(**vae_config)
1135
+ vae.load_state_dict(converted_vae_checkpoint)
1136
+ return vae
1137
+
1138
+ # endregion
1139
+
1140
+
1141
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1142
+ max_width, max_height = max_reso
1143
+ max_area = (max_width // divisible) * (max_height // divisible)
1144
+
1145
+ resos = set()
1146
+
1147
+ size = int(math.sqrt(max_area)) * divisible
1148
+ resos.add((size, size))
1149
+
1150
+ size = min_size
1151
+ while size <= max_size:
1152
+ width = size
1153
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
1154
+ resos.add((width, height))
1155
+ resos.add((height, width))
1156
+
1157
+ # # make additional resos
1158
+ # if width >= height and width - divisible >= min_size:
1159
+ # resos.add((width - divisible, height))
1160
+ # resos.add((height, width - divisible))
1161
+ # if height >= width and height - divisible >= min_size:
1162
+ # resos.add((width, height - divisible))
1163
+ # resos.add((height - divisible, width))
1164
+
1165
+ size += divisible
1166
+
1167
+ resos = list(resos)
1168
+ resos.sort()
1169
+
1170
+ aspect_ratios = [w / h for w, h in resos]
1171
+ return resos, aspect_ratios
1172
+
1173
+
1174
+ if __name__ == '__main__':
1175
+ resos, aspect_ratios = make_bucket_resolutions((512, 768))
1176
+ print(len(resos))
1177
+ print(resos)
1178
+ print(aspect_ratios)
1179
+
1180
+ ars = set()
1181
+ for ar in aspect_ratios:
1182
+ if ar in ars:
1183
+ print("error! duplicate ar:", ar)
1184
+ ars.add(ar)
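A minimal usage sketch (editor's addition, not part of the upload): pick the bucket whose aspect ratio is closest to a given image. The import path lycoris.kohya_model_utils is an assumption based on the package __init__; adjust it to wherever make_bucket_resolutions actually lives.

# Hypothetical helper: choose the closest bucket for an image by aspect ratio.
from lycoris.kohya_model_utils import make_bucket_resolutions  # assumed module path

def closest_bucket(image_w, image_h, max_reso=(512, 768)):
    resos, aspect_ratios = make_bucket_resolutions(max_reso)
    target_ar = image_w / image_h
    # the bucket with the smallest aspect-ratio difference wins
    reso, _ = min(zip(resos, aspect_ratios), key=lambda ra: abs(ra[1] - target_ar))
    return reso

print(closest_bucket(1000, 1500))  # portrait input -> (512, 768) for this max_reso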
lycoris/kohya_utils.py ADDED
@@ -0,0 +1,48 @@
1
+ # part of https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py
2
+
3
+ import hashlib
4
+ import safetensors.torch
5
+ from io import BytesIO
6
+
7
+
8
+ def addnet_hash_legacy(b):
9
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
10
+ m = hashlib.sha256()
11
+
12
+ b.seek(0x100000)
13
+ m.update(b.read(0x10000))
14
+ return m.hexdigest()[0:8]
15
+
16
+
17
+ def addnet_hash_safetensors(b):
18
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
19
+ hash_sha256 = hashlib.sha256()
20
+ blksize = 1024 * 1024
21
+
22
+ b.seek(0)
23
+ header = b.read(8)
24
+ n = int.from_bytes(header, "little")
25
+
26
+ offset = n + 8
27
+ b.seek(offset)
28
+ for chunk in iter(lambda: b.read(blksize), b""):
29
+ hash_sha256.update(chunk)
30
+
31
+ return hash_sha256.hexdigest()
32
+
33
+
34
+ def precalculate_safetensors_hashes(tensors, metadata):
35
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
36
+ save time on indexing the model later."""
37
+
38
+ # Because writing user metadata to the file can change the result of
39
+ # sd_models.model_hash(), only retain the training metadata for purposes of
40
+ # calculating the hash, as they are meant to be immutable
41
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
42
+
43
+ bytes = safetensors.torch.save(tensors, metadata)
44
+ b = BytesIO(bytes)
45
+
46
+ model_hash = addnet_hash_safetensors(b)
47
+ legacy_hash = addnet_hash_legacy(b)
48
+ return model_hash, legacy_hash
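A hypothetical usage sketch for the hashing helpers above (editor's addition), assuming the package is importable as lycoris.kohya_utils. Only the tensors and the "ss_"-prefixed metadata influence the returned hashes.

import torch
from lycoris.kohya_utils import precalculate_safetensors_hashes

tensors = {
    "lora_up.weight": torch.zeros(4, 2),
    "lora_down.weight": torch.zeros(2, 4),
}
metadata = {"ss_output_name": "example", "comment": "ignored for hashing"}

model_hash, legacy_hash = precalculate_safetensors_hashes(tensors, metadata)
print(model_hash)   # full sha256 of the serialized tensor data
print(legacy_hash)  # 8-character hash used by the legacy webui scheme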
lycoris/locon.py ADDED
@@ -0,0 +1,85 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class LoConModule(nn.Module):
9
+ """
10
+ modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
11
+ """
12
+
13
+ def __init__(
14
+ self,
15
+ lora_name, org_module: nn.Module,
16
+ multiplier=1.0,
17
+ lora_dim=4, alpha=1,
18
+ dropout=0.,
19
+ use_cp=True,
20
+ ):
21
+ """ if alpha == 0 or None, alpha is rank (no scaling). """
22
+ super().__init__()
23
+ self.lora_name = lora_name
24
+ self.lora_dim = lora_dim
25
+ self.cp = False
26
+
27
+ if org_module.__class__.__name__ == 'Conv2d':
28
+ # For general LoCon
29
+ in_dim = org_module.in_channels
30
+ k_size = org_module.kernel_size
31
+ stride = org_module.stride
32
+ padding = org_module.padding
33
+ out_dim = org_module.out_channels
34
+ if use_cp and k_size != (1, 1):
35
+ self.lora_down = nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
36
+ self.lora_mid = nn.Conv2d(lora_dim, lora_dim, k_size, stride, padding, bias=False)
37
+ self.cp = True
38
+ else:
39
+ self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
40
+ self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
41
+ else:
42
+ in_dim = org_module.in_features
43
+ out_dim = org_module.out_features
44
+ self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
45
+ self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
46
+ self.shape = org_module.weight.shape
47
+
48
+ if dropout:
49
+ self.dropout = nn.Dropout(dropout)
50
+ else:
51
+ self.dropout = nn.Identity()
52
+
53
+ if type(alpha) == torch.Tensor:
54
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
55
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
56
+ self.scale = alpha / self.lora_dim
57
+ self.register_buffer('alpha', torch.tensor(alpha)) # can be treated as a constant
58
+
59
+ # same as microsoft's
60
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
61
+ torch.nn.init.zeros_(self.lora_up.weight)
62
+ if self.cp:
63
+ torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))
64
+
65
+ self.multiplier = multiplier
66
+ self.org_module = [org_module]
67
+
68
+ def apply_to(self):
69
+ self.org_forward = self.org_module[0].forward
70
+ self.org_module[0].forward = self.forward
71
+
72
+ def make_weight(self):
73
+ wa = self.lora_up.weight
74
+ wb = self.lora_down.weight
75
+ return (wa.view(wa.size(0), -1) @ wb.view(wb.size(0), -1)).view(self.shape)
76
+
77
+ def forward(self, x):
78
+ if self.cp:
79
+ return self.org_forward(x) + self.dropout(
80
+ self.lora_up(self.lora_mid(self.lora_down(x))) * self.multiplier * self.scale
81
+ )
82
+ else:
83
+ return self.org_forward(x) + self.dropout(
84
+ self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
85
+ )
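A small sanity-check sketch (editor's addition): wrap a plain Linear layer with LoConModule, hook it with apply_to, and confirm the wrapped layer still produces the original output at initialization, since lora_up starts at zero.

import torch
import torch.nn as nn
from lycoris.locon import LoConModule

lin = nn.Linear(16, 32)
x = torch.randn(2, 16)
y_before = lin(x)

locon = LoConModule("lora_te_test", lin, multiplier=1.0, lora_dim=4, alpha=1.0)
locon.apply_to()   # replaces lin.forward with the LoCon forward
y_after = lin(x)   # original output + (currently zero) low-rank update

print(torch.allclose(y_before, y_after))  # True until lora_up is trained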
lycoris/loha.py ADDED
@@ -0,0 +1,198 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class HadaWeight(torch.autograd.Function):
9
+ @staticmethod
10
+ def forward(ctx, orig_weight, w1a, w1b, w2a, w2b, scale=torch.tensor(1)):
11
+ ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
12
+ diff_weight = ((w1a@w1b)*(w2a@w2b)) * scale
13
+ return orig_weight.reshape(diff_weight.shape) + diff_weight
14
+
15
+ @staticmethod
16
+ def backward(ctx, grad_out):
17
+ (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
18
+ grad_out = grad_out * scale
19
+ temp = grad_out*(w2a@w2b)
20
+ grad_w1a = temp @ w1b.T
21
+ grad_w1b = w1a.T @ temp
22
+
23
+ temp = grad_out * (w1a@w1b)
24
+ grad_w2a = temp @ w2b.T
25
+ grad_w2b = w2a.T @ temp
26
+
27
+ del temp
28
+ return grad_out, grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
29
+
30
+
31
+ class HadaWeightCP(torch.autograd.Function):
32
+ @staticmethod
33
+ def forward(ctx, orig_weight, t1, w1a, w1b, t2, w2a, w2b, scale=torch.tensor(1)):
34
+ ctx.save_for_backward(t1, w1a, w1b, t2, w2a, w2b, scale)
35
+
36
+ rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', t1, w1b, w1a)
37
+ rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', t2, w2b, w2a)
38
+
39
+ return orig_weight + rebuild1*rebuild2*scale
40
+
41
+ @staticmethod
42
+ def backward(ctx, grad_out):
43
+ (t1, w1a, w1b, t2, w2a, w2b, scale) = ctx.saved_tensors
44
+
45
+ grad_out = grad_out*scale
46
+
47
+ temp = torch.einsum('i j k l, j r -> i r k l', t2, w2b)
48
+ rebuild = torch.einsum('i j k l, i r -> r j k l', temp, w2a)
49
+
50
+ grad_w = rebuild*grad_out
51
+ del rebuild
52
+
53
+ grad_w1a = torch.einsum('r j k l, i j k l -> r i', temp, grad_w)
54
+ grad_temp = torch.einsum('i j k l, i r -> r j k l', grad_w, w1a.T)
55
+ del grad_w, temp
56
+
57
+ grad_w1b = torch.einsum('i r k l, i j k l -> r j', t1, grad_temp)
58
+ grad_t1 = torch.einsum('i j k l, j r -> i r k l', grad_temp, w1b.T)
59
+ del grad_temp
60
+
61
+ temp = torch.einsum('i j k l, j r -> i r k l', t1, w1b)
62
+ rebuild = torch.einsum('i j k l, i r -> r j k l', temp, w1a)
63
+
64
+ grad_w = rebuild*grad_out
65
+ del rebuild
66
+
67
+ grad_w2a = torch.einsum('r j k l, i j k l -> r i', temp, grad_w)
68
+ grad_temp = torch.einsum('i j k l, i r -> r j k l', grad_w, w2a.T)
69
+ del grad_w, temp
70
+
71
+ grad_w2b = torch.einsum('i r k l, i j k l -> r j', t2, grad_temp)
72
+ grad_t2 = torch.einsum('i j k l, j r -> i r k l', grad_temp, w2b.T)
73
+ del grad_temp
74
+ return grad_out, grad_t1, grad_w1a, grad_w1b, grad_t2, grad_w2a, grad_w2b, None
75
+
76
+
77
+ def make_weight(orig_weight, w1a, w1b, w2a, w2b, scale):
78
+ return HadaWeight.apply(orig_weight, w1a, w1b, w2a, w2b, scale)
79
+
80
+
81
+ def make_weight_cp(orig_weight, t1, w1a, w1b, t2, w2a, w2b, scale):
82
+ return HadaWeightCP.apply(orig_weight, t1, w1a, w1b, t2, w2a, w2b, scale)
83
+
84
+
85
+ class LohaModule(nn.Module):
86
+ """
87
+ Hadamard product implementation for Low-Rank Adaptation
88
+ """
89
+
90
+ def __init__(
91
+ self,
92
+ lora_name,
93
+ org_module: nn.Module,
94
+ multiplier=1.0, lora_dim=4, alpha=1, dropout=0.,
95
+ use_cp=True,
96
+ ):
97
+ """ if alpha == 0 or None, alpha is rank (no scaling). """
98
+ super().__init__()
99
+ self.lora_name = lora_name
100
+ self.lora_dim = lora_dim
101
+ self.cp=False
102
+
103
+ self.shape = org_module.weight.shape
104
+ if org_module.__class__.__name__ == 'Conv2d':
105
+ in_dim = org_module.in_channels
106
+ k_size = org_module.kernel_size
107
+ out_dim = org_module.out_channels
108
+ self.cp = use_cp and k_size!=(1, 1)
109
+ if self.cp:
110
+ shape = (out_dim, in_dim, *k_size)
111
+ else:
112
+ shape = (out_dim, in_dim*k_size[0]*k_size[1])
113
+ self.op = F.conv2d
114
+ self.extra_args = {
115
+ "stride": org_module.stride,
116
+ "padding": org_module.padding,
117
+ "dilation": org_module.dilation,
118
+ "groups": org_module.groups
119
+ }
120
+ else:
121
+ in_dim = org_module.in_features
122
+ out_dim = org_module.out_features
123
+ shape = (out_dim, in_dim)
124
+ self.op = F.linear
125
+ self.extra_args = {}
126
+
127
+ if self.cp:
128
+ self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
129
+ self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, shape[0])) # out_dim, 1-mode
130
+ self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1])) # in_dim , 2-mode
131
+
132
+ self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
133
+ self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0])) # out_dim, 1-mode
134
+ self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1])) # in_dim , 2-mode
135
+ else:
136
+ self.hada_w1_a = nn.Parameter(torch.empty(shape[0], lora_dim))
137
+ self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1]))
138
+
139
+ self.hada_w2_a = nn.Parameter(torch.empty(shape[0], lora_dim))
140
+ self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1]))
141
+
142
+ if dropout:
143
+ self.dropout = nn.Dropout(dropout)
144
+ else:
145
+ self.dropout = nn.Identity()
146
+
147
+ if type(alpha) == torch.Tensor:
148
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
149
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
150
+ self.scale = alpha / self.lora_dim
151
+ self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
152
+
153
+ # Need more experiences on init method
154
+ if self.cp:
155
+ torch.nn.init.normal_(self.hada_t1, std=0.1)
156
+ torch.nn.init.normal_(self.hada_t2, std=0.1)
157
+ torch.nn.init.normal_(self.hada_w1_b, std=1)
158
+ torch.nn.init.normal_(self.hada_w2_b, std=0.01)
159
+ torch.nn.init.normal_(self.hada_w1_a, std=1)
160
+ torch.nn.init.constant_(self.hada_w2_a, 0)
161
+
162
+ self.multiplier = multiplier
163
+ self.org_module = [org_module] # remove in applying
164
+ self.grad_ckpt = False
165
+
166
+ def apply_to(self):
167
+ self.org_module[0].forward = self.forward
168
+
169
+ def get_weight(self):
170
+ d_weight = self.hada_w1_a @ self.hada_w1_b
171
+ d_weight *= self.hada_w2_a @ self.hada_w2_b
172
+ return (d_weight).reshape(self.shape)
173
+
174
+ @torch.enable_grad()
175
+ def forward(self, x):
176
+ # print(torch.mean(torch.abs(self.orig_w1a.to(x.device) - self.hada_w1_a)), end='\r')
177
+ if self.cp:
178
+ weight = make_weight_cp(
179
+ self.org_module[0].weight.data,
180
+ self.hada_t1, self.hada_w1_a, self.hada_w1_b,
181
+ self.hada_t1, self.hada_w2_a, self.hada_w2_b,
182
+ scale = torch.tensor(self.scale*self.multiplier),
183
+ )
184
+ else:
185
+ weight = make_weight(
186
+ self.org_module[0].weight.data,
187
+ self.hada_w1_a, self.hada_w1_b,
188
+ self.hada_w2_a, self.hada_w2_b,
189
+ scale = torch.tensor(self.scale*self.multiplier),
190
+ )
191
+
192
+ bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
193
+ return self.op(
194
+ x,
195
+ weight.view(self.shape),
196
+ bias,
197
+ **self.extra_args
198
+ )
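A quick check of the Hadamard-product reparameterization implemented above (editor's addition): make_weight should return W + (w1a @ w1b) * (w2a @ w2b) * scale, an update whose effective rank can reach lora_dim**2 while only 2 * lora_dim * (out_dim + in_dim) values are stored.

import torch
from lycoris.loha import make_weight

out_dim, in_dim, dim = 8, 16, 4
w = torch.randn(out_dim, in_dim)
w1a, w1b = torch.randn(out_dim, dim), torch.randn(dim, in_dim)
w2a, w2b = torch.randn(out_dim, dim), torch.randn(dim, in_dim)

new_w = make_weight(w, w1a, w1b, w2a, w2b, torch.tensor(1.0))
expected = w + (w1a @ w1b) * (w2a @ w2b)
print(torch.allclose(new_w, expected, atol=1e-6))  # True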
lycoris/utils.py ADDED
@@ -0,0 +1,380 @@
1
+ from typing import *
2
+
3
+ import numpy as np
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ import torch.linalg as linalg
10
+
11
+ from tqdm import tqdm
12
+
13
+
14
+ def make_sparse(t: torch.Tensor, sparsity=0.95):
15
+ abs_t = torch.abs(t)
16
+ np_array = abs_t.detach().cpu().numpy()
17
+ quan = float(np.quantile(np_array, sparsity))
18
+ sparse_t = t.masked_fill(abs_t < quan, 0)
19
+ return sparse_t
20
+
21
+
22
+ def extract_conv(
23
+ weight: Union[torch.Tensor, nn.Parameter],
24
+ mode = 'fixed',
25
+ mode_param = 0,
26
+ device = 'cpu',
27
+ ) -> Tuple[nn.Parameter, nn.Parameter]:
28
+ weight = weight.to(device)
29
+ out_ch, in_ch, kernel_size, _ = weight.shape
30
+
31
+ U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))
32
+
33
+ if mode=='fixed':
34
+ lora_rank = mode_param
35
+ elif mode=='threshold':
36
+ assert mode_param>=0
37
+ lora_rank = torch.sum(S>mode_param)
38
+ elif mode=='ratio':
39
+ assert 1>=mode_param>=0
40
+ min_s = torch.max(S)*mode_param
41
+ lora_rank = torch.sum(S>min_s)
42
+ elif mode=='quantile' or mode=='percentile':
43
+ assert 1>=mode_param>=0
44
+ s_cum = torch.cumsum(S, dim=0)
45
+ min_cum_sum = mode_param * torch.sum(S)
46
+ lora_rank = torch.sum(s_cum<min_cum_sum)
47
+ else:
48
+ raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
49
+ lora_rank = max(1, lora_rank)
50
+ lora_rank = min(out_ch, in_ch, lora_rank)
51
+
52
+ U = U[:, :lora_rank]
53
+ S = S[:lora_rank]
54
+ U = U @ torch.diag(S)
55
+ Vh = Vh[:lora_rank, :]
56
+
57
+ diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
58
+ extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
59
+ extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
60
+ del U, S, Vh, weight
61
+ return extract_weight_A, extract_weight_B, diff
62
+
63
+
64
+ def merge_conv(
65
+ weight_a: Union[torch.Tensor, nn.Parameter],
66
+ weight_b: Union[torch.Tensor, nn.Parameter],
67
+ device = 'cpu'
68
+ ):
69
+ rank, in_ch, kernel_size, k_ = weight_a.shape
70
+ out_ch, rank_, _, _ = weight_b.shape
71
+ assert rank == rank_ and kernel_size == k_
72
+
73
+ wa = weight_a.to(device)
74
+ wb = weight_b.to(device)
75
+
76
+ if device == 'cpu':
77
+ wa = wa.float()
78
+ wb = wb.float()
79
+
80
+ merged = wb.reshape(out_ch, -1) @ wa.reshape(rank, -1)
81
+ weight = merged.reshape(out_ch, in_ch, kernel_size, kernel_size)
82
+ del wb, wa
83
+ return weight
84
+
85
+
86
+ def extract_linear(
87
+ weight: Union[torch.Tensor, nn.Parameter],
88
+ mode = 'fixed',
89
+ mode_param = 0,
90
+ device = 'cpu',
91
+ ) -> Tuple[nn.Parameter, nn.Parameter]:
92
+ weight = weight.to(device)
93
+ out_ch, in_ch = weight.shape
94
+
95
+ U, S, Vh = linalg.svd(weight)
96
+
97
+ if mode=='fixed':
98
+ lora_rank = mode_param
99
+ elif mode=='threshold':
100
+ assert mode_param>=0
101
+ lora_rank = torch.sum(S>mode_param)
102
+ elif mode=='ratio':
103
+ assert 1>=mode_param>=0
104
+ min_s = torch.max(S)*mode_param
105
+ lora_rank = torch.sum(S>min_s)
106
+ elif mode=='quantile' or mode=='percentile':
107
+ assert 1>=mode_param>=0
108
+ s_cum = torch.cumsum(S, dim=0)
109
+ min_cum_sum = mode_param * torch.sum(S)
110
+ lora_rank = torch.sum(s_cum<min_cum_sum)
111
+ else:
112
+ raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
113
+ lora_rank = max(1, lora_rank)
114
+ lora_rank = min(out_ch, in_ch, lora_rank)
115
+
116
+ U = U[:, :lora_rank]
117
+ S = S[:lora_rank]
118
+ U = U @ torch.diag(S)
119
+ Vh = Vh[:lora_rank, :]
120
+
121
+ diff = (weight - U @ Vh).detach()
122
+ extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
123
+ extract_weight_B = U.reshape(out_ch, lora_rank).detach()
124
+ del U, S, Vh, weight
125
+ return extract_weight_A, extract_weight_B, diff
126
+
127
+
128
+ def merge_linear(
129
+ weight_a: Union[torch.Tensor, nn.Parameter],
130
+ weight_b: Union[torch.Tensor, nn.Parameter],
131
+ device = 'cpu'
132
+ ):
133
+ rank, in_ch = weight_a.shape
134
+ out_ch, rank_ = weight_b.shape
135
+ assert rank == rank_
136
+
137
+ wa = weight_a.to(device)
138
+ wb = weight_b.to(device)
139
+
140
+ if device == 'cpu':
141
+ wa = wa.float()
142
+ wb = wb.float()
143
+
144
+ weight = wb @ wa
145
+ del wb, wa
146
+ return weight
147
+
148
+
149
+ def extract_diff(
150
+ base_model,
151
+ db_model,
152
+ mode = 'fixed',
153
+ linear_mode_param = 0,
154
+ conv_mode_param = 0,
155
+ extract_device = 'cpu',
156
+ use_bias = False,
157
+ sparsity = 0.98,
158
+ small_conv = True
159
+ ):
160
+ UNET_TARGET_REPLACE_MODULE = [
161
+ "Transformer2DModel",
162
+ "Attention",
163
+ "ResnetBlock2D",
164
+ "Downsample2D",
165
+ "Upsample2D"
166
+ ]
167
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
168
+ LORA_PREFIX_UNET = 'lora_unet'
169
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
170
+ def make_state_dict(
171
+ prefix,
172
+ root_module: torch.nn.Module,
173
+ target_module: torch.nn.Module,
174
+ target_replace_modules
175
+ ):
176
+ loras = {}
177
+ temp = {}
178
+
179
+ for name, module in root_module.named_modules():
180
+ if module.__class__.__name__ in target_replace_modules:
181
+ temp[name] = {}
182
+ for child_name, child_module in module.named_modules():
183
+ if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
184
+ continue
185
+ temp[name][child_name] = child_module.weight
186
+
187
+ for name, module in tqdm(list(target_module.named_modules())):
188
+ if name in temp:
189
+ weights = temp[name]
190
+ for child_name, child_module in module.named_modules():
191
+ lora_name = prefix + '.' + name + '.' + child_name
192
+ lora_name = lora_name.replace('.', '_')
193
+
194
+ layer = child_module.__class__.__name__
195
+ if layer == 'Linear':
196
+ extract_a, extract_b, diff = extract_linear(
197
+ (child_module.weight - weights[child_name]),
198
+ mode,
199
+ linear_mode_param,
200
+ device = extract_device,
201
+ )
202
+ elif layer == 'Conv2d':
203
+ is_linear = (child_module.weight.shape[2] == 1
204
+ and child_module.weight.shape[3] == 1)
205
+ extract_a, extract_b, diff = extract_conv(
206
+ (child_module.weight - weights[child_name]),
207
+ mode,
208
+ linear_mode_param if is_linear else conv_mode_param,
209
+ device = extract_device,
210
+ )
211
+ if small_conv and not is_linear:
212
+ dim = extract_a.size(0)
213
+ extract_c, extract_a, _ = extract_conv(
214
+ extract_a.transpose(0, 1),
215
+ 'fixed', dim,
216
+ extract_device
217
+ )
218
+ extract_a = extract_a.transpose(0, 1)
219
+ extract_c = extract_c.transpose(0, 1)
220
+ loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
221
+ diff = child_module.weight - torch.einsum(
222
+ 'i j k l, j r, p i -> p r k l',
223
+ extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
224
+ ).detach().cpu().contiguous()
225
+ del extract_c
226
+ else:
227
+ continue
228
+ loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
229
+ loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
230
+ loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
231
+
232
+ if use_bias:
233
+ diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
234
+ sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
235
+
236
+ indices = sparse_diff.indices().to(torch.int16)
237
+ values = sparse_diff.values().half()
238
+ loras[f'{lora_name}.bias_indices'] = indices
239
+ loras[f'{lora_name}.bias_values'] = values
240
+ loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
241
+ del extract_a, extract_b, diff
242
+ return loras
243
+
244
+ text_encoder_loras = make_state_dict(
245
+ LORA_PREFIX_TEXT_ENCODER,
246
+ base_model[0], db_model[0],
247
+ TEXT_ENCODER_TARGET_REPLACE_MODULE
248
+ )
249
+
250
+ unet_loras = make_state_dict(
251
+ LORA_PREFIX_UNET,
252
+ base_model[2], db_model[2],
253
+ UNET_TARGET_REPLACE_MODULE
254
+ )
255
+ print(len(text_encoder_loras), len(unet_loras))
256
+ return text_encoder_loras|unet_loras
257
+
258
+
259
+ def merge_locon(
260
+ base_model,
261
+ locon_state_dict: Dict[str, torch.TensorType],
262
+ scale: float = 1.0,
263
+ device = 'cpu'
264
+ ):
265
+ UNET_TARGET_REPLACE_MODULE = [
266
+ "Transformer2DModel",
267
+ "Attention",
268
+ "ResnetBlock2D",
269
+ "Downsample2D",
270
+ "Upsample2D"
271
+ ]
272
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
273
+ LORA_PREFIX_UNET = 'lora_unet'
274
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
275
+ def merge(
276
+ prefix,
277
+ root_module: torch.nn.Module,
278
+ target_replace_modules
279
+ ):
280
+ temp = {}
281
+
282
+ for name, module in tqdm(list(root_module.named_modules())):
283
+ if module.__class__.__name__ in target_replace_modules:
284
+ temp[name] = {}
285
+ for child_name, child_module in module.named_modules():
286
+ layer = child_module.__class__.__name__
287
+ if layer not in {'Linear', 'Conv2d'}:
288
+ continue
289
+ lora_name = prefix + '.' + name + '.' + child_name
290
+ lora_name = lora_name.replace('.', '_')
291
+
292
+ down = locon_state_dict[f'{lora_name}.lora_down.weight'].float()
293
+ up = locon_state_dict[f'{lora_name}.lora_up.weight'].float()
294
+ alpha = locon_state_dict[f'{lora_name}.alpha'].float()
295
+ rank = down.shape[0]
296
+
297
+ if layer == 'Conv2d':
298
+ delta = merge_conv(down, up, device)
299
+ child_module.weight.requires_grad_(False)
300
+ child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
301
+ del delta
302
+ elif layer == 'Linear':
303
+ delta = merge_linear(down, up, device)
304
+ child_module.weight.requires_grad_(False)
305
+ child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
306
+ del delta
307
+
308
+ merge(
309
+ LORA_PREFIX_TEXT_ENCODER,
310
+ base_model[0],
311
+ TEXT_ENCODER_TARGET_REPLACE_MODULE
312
+ )
313
+ merge(
314
+ LORA_PREFIX_UNET,
315
+ base_model[2],
316
+ UNET_TARGET_REPLACE_MODULE
317
+ )
318
+
319
+
320
+ def merge_loha(
321
+ base_model,
322
+ loha_state_dict: Dict[str, torch.TensorType],
323
+ scale: float = 1.0,
324
+ device = 'cpu'
325
+ ):
326
+ UNET_TARGET_REPLACE_MODULE = [
327
+ "Transformer2DModel",
328
+ "Attention",
329
+ "ResnetBlock2D",
330
+ "Downsample2D",
331
+ "Upsample2D"
332
+ ]
333
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
334
+ LORA_PREFIX_UNET = 'lora_unet'
335
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
336
+ def merge(
337
+ prefix,
338
+ root_module: torch.nn.Module,
339
+ target_replace_modules
340
+ ):
341
+ temp = {}
342
+
343
+ for name, module in tqdm(list(root_module.named_modules())):
344
+ if module.__class__.__name__ in target_replace_modules:
345
+ temp[name] = {}
346
+ for child_name, child_module in module.named_modules():
347
+ layer = child_module.__class__.__name__
348
+ if layer not in {'Linear', 'Conv2d'}:
349
+ continue
350
+ lora_name = prefix + '.' + name + '.' + child_name
351
+ lora_name = lora_name.replace('.', '_')
352
+
353
+ w1a = loha_state_dict[f'{lora_name}.hada_w1_a'].float().to(device)
354
+ w1b = loha_state_dict[f'{lora_name}.hada_w1_b'].float().to(device)
355
+ w2a = loha_state_dict[f'{lora_name}.hada_w2_a'].float().to(device)
356
+ w2b = loha_state_dict[f'{lora_name}.hada_w2_b'].float().to(device)
357
+ alpha = loha_state_dict[f'{lora_name}.alpha'].float().to(device)
358
+ dim = w1b.shape[0]
359
+
360
+ delta = (w1a @ w1b) * (w2a @ w2b)
361
+ delta = delta.reshape(child_module.weight.shape)
362
+
363
+ if layer == 'Conv2d':
364
+ child_module.weight.requires_grad_(False)
365
+ child_module.weight += (alpha.to(device)/dim * scale * delta).cpu()
366
+ elif layer == 'Linear':
367
+ child_module.weight.requires_grad_(False)
368
+ child_module.weight += (alpha.to(device)/dim * scale * delta).cpu()
369
+ del delta
370
+
371
+ merge(
372
+ LORA_PREFIX_TEXT_ENCODER,
373
+ base_model[0],
374
+ TEXT_ENCODER_TARGET_REPLACE_MODULE
375
+ )
376
+ merge(
377
+ LORA_PREFIX_UNET,
378
+ base_model[2],
379
+ UNET_TARGET_REPLACE_MODULE
380
+ )
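An illustrative round trip through the SVD extraction helpers above (editor's addition), assuming lycoris.utils is importable: split a weight into low-rank factors plus a residual, then confirm that merging the factors and adding the residual reproduces the original weight.

import torch
from lycoris.utils import extract_linear, merge_linear

weight = torch.randn(32, 64)
down, up, residual = extract_linear(weight, mode='fixed', mode_param=8)

approx = merge_linear(down, up)   # rank-8 reconstruction, shape (32, 64)
print(down.shape, up.shape)       # torch.Size([8, 64]) torch.Size([32, 8])
print(torch.allclose(approx + residual, weight, atol=1e-5))  # True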