3v324v23 committed on
Commit
679ee3e
1 Parent(s): 29e6acd
LoRA-EXTRACTOR ADDED
@@ -0,0 +1 @@
1
+ Subproject commit fa8e800cf5ab8a87f4cf449507b79c05d190088d
Lora/RUNTHIS.txt ADDED
@@ -0,0 +1 @@
1
+ !python /content/stable-diffusion-webui/matousbecvar_sd1.5_drbth/LoRA-EXTRACTOR/lib/extract_lora_from_models.py --save_precision fp16 --save_to matous_LORA.safetensors --model_org /content/stable-diffusion-webui/models/Stable-diffusion/stable_diffusion_v1_5.safetensors --model_tuned /content/stable-diffusion-webui/matousbecvar_sd1.5_drbth/matousbecvar.safetensors --dim 128
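
Note: the leading "!python" marks this as a notebook shell command (the /content/... paths are Colab-style). Per the argparse help in extract_lora_from_models.py below, --model_org is the base Stable Diffusion 1.5 checkpoint, --model_tuned is the fine-tuned matousbecvar checkpoint, and the script extracts a rank-128 LoRA (alpha is set equal to dim) approximating the difference original -> tuned, saved in fp16 as matous_LORA.safetensors.
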
Lora/lib/FileToOpen.exe ADDED
Binary file (16.4 kB).
 
Lora/lib/FileToSave.exe ADDED
Binary file (15.9 kB).
 
Lora/lib/__pycache__/lora.cpython-39.pyc ADDED
Binary file (7.38 kB).
 
Lora/lib/__pycache__/model_util.cpython-39.pyc ADDED
Binary file (29.7 kB).
 
Lora/lib/__pycache__/train_util.cpython-39.pyc ADDED
Binary file (56.4 kB).
 
Lora/lib/extract_lora_from_models.py ADDED
@@ -0,0 +1,164 @@
1
+ # extract approximating LoRA by svd from two SD models
2
+ # The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
+ # Thanks to cloneofsimo!
4
+
5
+ import argparse
6
+ import os
7
+ import torch
8
+ from safetensors.torch import load_file, save_file
9
+ from tqdm import tqdm
10
+ import model_util
11
+ import lora
12
+
13
+
14
+ CLAMP_QUANTILE = 0.99
15
+ MIN_DIFF = 1e-6
16
+
17
+
18
+ def save_to_file(file_name, model, state_dict, dtype):
19
+ if dtype is not None:
20
+ for key in list(state_dict.keys()):
21
+ if type(state_dict[key]) == torch.Tensor:
22
+ state_dict[key] = state_dict[key].to(dtype)
23
+
24
+ if os.path.splitext(file_name)[1] == '.safetensors':
25
+ save_file(model, file_name)
26
+ else:
27
+ torch.save(model, file_name)
28
+
29
+
30
+ def svd(args):
31
+ def str_to_dtype(p):
32
+ if p == 'float':
33
+ return torch.float
34
+ if p == 'fp16':
35
+ return torch.float16
36
+ if p == 'bf16':
37
+ return torch.bfloat16
38
+ return None
39
+
40
+ save_dtype = str_to_dtype(args.save_precision)
41
+
42
+ print(f"loading SD model : {args.model_org}")
43
+ text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
44
+ print(f"loading SD model : {args.model_tuned}")
45
+ text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
46
+
47
+ # create LoRA network to extract weights: Use dim (rank) as alpha
48
+ lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
49
+ lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
50
+ assert len(lora_network_o.text_encoder_loras) == len(
51
+ lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
52
+
53
+ # get diffs
54
+ diffs = {}
55
+ text_encoder_different = False
56
+ for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
57
+ lora_name = lora_o.lora_name
58
+ module_o = lora_o.org_module
59
+ module_t = lora_t.org_module
60
+ diff = module_t.weight - module_o.weight
61
+
62
+ # Text Encoder might be same
63
+ if torch.max(torch.abs(diff)) > MIN_DIFF:
64
+ text_encoder_different = True
65
+
66
+ diff = diff.float()
67
+ diffs[lora_name] = diff
68
+
69
+ if not text_encoder_different:
70
+ print("Text encoder is same. Extract U-Net only.")
71
+ lora_network_o.text_encoder_loras = []
72
+ diffs = {}
73
+
74
+ for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
75
+ lora_name = lora_o.lora_name
76
+ module_o = lora_o.org_module
77
+ module_t = lora_t.org_module
78
+ diff = module_t.weight - module_o.weight
79
+ diff = diff.float()
80
+
81
+ if args.device:
82
+ diff = diff.to(args.device)
83
+
84
+ diffs[lora_name] = diff
85
+
86
+ # make LoRA with svd
87
+ print("calculating by svd")
88
+ rank = args.dim
89
+ lora_weights = {}
90
+ with torch.no_grad():
91
+ for lora_name, mat in tqdm(list(diffs.items())):
92
+ conv2d = (len(mat.size()) == 4)
93
+ if conv2d:
94
+ mat = mat.squeeze()
95
+
96
+ U, S, Vh = torch.linalg.svd(mat)
97
+
98
+ U = U[:, :rank]
99
+ S = S[:rank]
100
+ U = U @ torch.diag(S)
101
+
102
+ Vh = Vh[:rank, :]
103
+
104
+ dist = torch.cat([U.flatten(), Vh.flatten()])
105
+ hi_val = torch.quantile(dist, CLAMP_QUANTILE)
106
+ low_val = -hi_val
107
+
108
+ U = U.clamp(low_val, hi_val)
109
+ Vh = Vh.clamp(low_val, hi_val)
110
+
111
+ lora_weights[lora_name] = (U, Vh)
112
+
113
+ # make state dict for LoRA
114
+ lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
115
+ lora_sd = lora_network_o.state_dict()
116
+ print(f"LoRA has {len(lora_sd)} weights.")
117
+
118
+ for key in list(lora_sd.keys()):
119
+ if "alpha" in key:
120
+ continue
121
+
122
+ lora_name = key.split('.')[0]
123
+ i = 0 if "lora_up" in key else 1
124
+
125
+ weights = lora_weights[lora_name][i]
126
+ # print(key, i, weights.size(), lora_sd[key].size())
127
+ if len(lora_sd[key].size()) == 4:
128
+ weights = weights.unsqueeze(2).unsqueeze(3)
129
+
130
+ assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
131
+ lora_sd[key] = weights
132
+
133
+ # load state dict to LoRA and save it
134
+ info = lora_network_o.load_state_dict(lora_sd)
135
+ print(f"Loading extracted LoRA weights: {info}")
136
+
137
+ dir_name = os.path.dirname(args.save_to)
138
+ if dir_name and not os.path.exists(dir_name):
139
+ os.makedirs(dir_name, exist_ok=True)
140
+
141
+ # minimum metadata
142
+ metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
143
+
144
+ lora_network_o.save_weights(args.save_to, save_dtype, metadata)
145
+ print(f"LoRA weights are saved to: {args.save_to}")
146
+
147
+
148
+ if __name__ == '__main__':
149
+ parser = argparse.ArgumentParser()
150
+ parser.add_argument("--v2", action='store_true',
151
+ help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
152
+ parser.add_argument("--save_precision", type=str, default=None,
153
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat")
154
+ parser.add_argument("--model_org", type=str, default=None,
155
+ help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors")
156
+ parser.add_argument("--model_tuned", type=str, default=None,
157
+ help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors")
158
+ parser.add_argument("--save_to", type=str, default=None,
159
+ help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
160
+ parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
161
+ parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
162
+
163
+ args = parser.parse_args()
164
+ svd(args)
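
The core of extract_lora_from_models.py is the per-layer step above: for each Linear (or squeezed 1x1 Conv) weight difference between the tuned and original model, it keeps the top --dim singular components of an SVD and clamps outliers at the 0.99 quantile, then copies the resulting factors into the lora_up/lora_down weights of a freshly built LoRANetwork and saves them with ss_network_dim/ss_network_alpha metadata. A minimal, self-contained sketch of the SVD step (not part of the commit; names are illustrative only):

import torch

def lowrank_pair(diff: torch.Tensor, rank: int = 128, clamp_quantile: float = 0.99):
    # diff = W_tuned - W_org for one layer, as a 2-D float tensor
    U, S, Vh = torch.linalg.svd(diff)
    up = U[:, :rank] @ torch.diag(S[:rank])   # fold the singular values into the "up" factor
    down = Vh[:rank, :]
    hi = torch.quantile(torch.cat([up.flatten(), down.flatten()]), clamp_quantile)
    return up.clamp(-hi, hi), down.clamp(-hi, hi)  # become the lora_up / lora_down weights

# an exactly low-rank difference is recovered almost perfectly (up to the clamping)
diff = torch.randn(320, 8) @ torch.randn(8, 320)
up, down = lowrank_pair(diff, rank=8)
print(up.shape, down.shape, ((diff - up @ down).norm() / diff.norm()).item())
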
Lora/lib/lora.py ADDED
@@ -0,0 +1,237 @@
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+
6
+ import math
7
+ import os
8
+ from typing import List
9
+ import torch
10
+
11
+ import train_util
12
+
13
+
14
+ class LoRAModule(torch.nn.Module):
15
+ """
16
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
17
+ """
18
+
19
+ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
20
+ """ if alpha == 0 or None, alpha is rank (no scaling). """
21
+ super().__init__()
22
+ self.lora_name = lora_name
23
+ self.lora_dim = lora_dim
24
+
25
+ if org_module.__class__.__name__ == 'Conv2d':
26
+ in_dim = org_module.in_channels
27
+ out_dim = org_module.out_channels
28
+ self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
29
+ self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
30
+ else:
31
+ in_dim = org_module.in_features
32
+ out_dim = org_module.out_features
33
+ self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
34
+ self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
35
+
36
+ if type(alpha) == torch.Tensor:
37
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
38
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
39
+ self.scale = alpha / self.lora_dim
40
+ self.register_buffer('alpha', torch.tensor(alpha)) # so alpha can be treated as a constant
41
+
42
+ # same as microsoft's
43
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
44
+ torch.nn.init.zeros_(self.lora_up.weight)
45
+
46
+ self.multiplier = multiplier
47
+ self.org_module = org_module # remove in applying
48
+
49
+ def apply_to(self):
50
+ self.org_forward = self.org_module.forward
51
+ self.org_module.forward = self.forward
52
+ del self.org_module
53
+
54
+ def forward(self, x):
55
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
56
+
57
+
58
+ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
59
+ if network_dim is None:
60
+ network_dim = 4 # default
61
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
62
+ return network
63
+
64
+
65
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
66
+ if os.path.splitext(file)[1] == '.safetensors':
67
+ from safetensors.torch import load_file, safe_open
68
+ weights_sd = load_file(file)
69
+ else:
70
+ weights_sd = torch.load(file, map_location='cpu')
71
+
72
+ # get dim (rank)
73
+ network_alpha = None
74
+ network_dim = None
75
+ for key, value in weights_sd.items():
76
+ if network_alpha is None and 'alpha' in key:
77
+ network_alpha = value
78
+ if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
79
+ network_dim = value.size()[0]
80
+
81
+ if network_alpha is None:
82
+ network_alpha = network_dim
83
+
84
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
85
+ network.weights_sd = weights_sd
86
+ return network
87
+
88
+
89
+ class LoRANetwork(torch.nn.Module):
90
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
91
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
92
+ LORA_PREFIX_UNET = 'lora_unet'
93
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
94
+
95
+ def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
96
+ super().__init__()
97
+ self.multiplier = multiplier
98
+ self.lora_dim = lora_dim
99
+ self.alpha = alpha
100
+
101
+ # create module instances
102
+ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
103
+ loras = []
104
+ for name, module in root_module.named_modules():
105
+ if module.__class__.__name__ in target_replace_modules:
106
+ for child_name, child_module in module.named_modules():
107
+ if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
108
+ lora_name = prefix + '.' + name + '.' + child_name
109
+ lora_name = lora_name.replace('.', '_')
110
+ lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
111
+ loras.append(lora)
112
+ return loras
113
+
114
+ self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
115
+ text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
116
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
117
+
118
+ self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
119
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
120
+
121
+ self.weights_sd = None
122
+
123
+ # assertion
124
+ names = set()
125
+ for lora in self.text_encoder_loras + self.unet_loras:
126
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
127
+ names.add(lora.lora_name)
128
+
129
+ def load_weights(self, file):
130
+ if os.path.splitext(file)[1] == '.safetensors':
131
+ from safetensors.torch import load_file, safe_open
132
+ self.weights_sd = load_file(file)
133
+ else:
134
+ self.weights_sd = torch.load(file, map_location='cpu')
135
+
136
+ def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
137
+ if self.weights_sd:
138
+ weights_has_text_encoder = weights_has_unet = False
139
+ for key in self.weights_sd.keys():
140
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
141
+ weights_has_text_encoder = True
142
+ elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
143
+ weights_has_unet = True
144
+
145
+ if apply_text_encoder is None:
146
+ apply_text_encoder = weights_has_text_encoder
147
+ else:
148
+ assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
149
+
150
+ if apply_unet is None:
151
+ apply_unet = weights_has_unet
152
+ else:
153
+ assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
154
+ else:
155
+ assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
156
+
157
+ if apply_text_encoder:
158
+ print("enable LoRA for text encoder")
159
+ else:
160
+ self.text_encoder_loras = []
161
+
162
+ if apply_unet:
163
+ print("enable LoRA for U-Net")
164
+ else:
165
+ self.unet_loras = []
166
+
167
+ for lora in self.text_encoder_loras + self.unet_loras:
168
+ lora.apply_to()
169
+ self.add_module(lora.lora_name, lora)
170
+
171
+ if self.weights_sd:
172
+ # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
173
+ info = self.load_state_dict(self.weights_sd, False)
174
+ print(f"weights are loaded: {info}")
175
+
176
+ def enable_gradient_checkpointing(self):
177
+ # not supported
178
+ pass
179
+
180
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
181
+ def enumerate_params(loras):
182
+ params = []
183
+ for lora in loras:
184
+ params.extend(lora.parameters())
185
+ return params
186
+
187
+ self.requires_grad_(True)
188
+ all_params = []
189
+
190
+ if self.text_encoder_loras:
191
+ param_data = {'params': enumerate_params(self.text_encoder_loras)}
192
+ if text_encoder_lr is not None:
193
+ param_data['lr'] = text_encoder_lr
194
+ all_params.append(param_data)
195
+
196
+ if self.unet_loras:
197
+ param_data = {'params': enumerate_params(self.unet_loras)}
198
+ if unet_lr is not None:
199
+ param_data['lr'] = unet_lr
200
+ all_params.append(param_data)
201
+
202
+ return all_params
203
+
204
+ def prepare_grad_etc(self, text_encoder, unet):
205
+ self.requires_grad_(True)
206
+
207
+ def on_epoch_start(self, text_encoder, unet):
208
+ self.train()
209
+
210
+ def get_trainable_params(self):
211
+ return self.parameters()
212
+
213
+ def save_weights(self, file, dtype, metadata):
214
+ if metadata is not None and len(metadata) == 0:
215
+ metadata = None
216
+
217
+ state_dict = self.state_dict()
218
+
219
+ if dtype is not None:
220
+ for key in list(state_dict.keys()):
221
+ v = state_dict[key]
222
+ v = v.detach().clone().to("cpu").to(dtype)
223
+ state_dict[key] = v
224
+
225
+ if os.path.splitext(file)[1] == '.safetensors':
226
+ from safetensors.torch import save_file
227
+
228
+ # Precalculate model hashes to save time on indexing
229
+ if metadata is None:
230
+ metadata = {}
231
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
232
+ metadata["sshs_model_hash"] = model_hash
233
+ metadata["sshs_legacy_hash"] = legacy_hash
234
+
235
+ save_file(state_dict, file, metadata)
236
+ else:
237
+ torch.save(state_dict, file)
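
For reference, lora.py's LoRAModule does not replace the original Linear/Conv2d module; apply_to() keeps the original forward and adds a low-rank residual path scaled by multiplier * (alpha / lora_dim), with lora_up zero-initialised so a freshly created LoRA is a no-op. A minimal sketch of that behaviour for one Linear layer (not part of the commit; names are illustrative only):

import torch

dim, alpha, multiplier = 4, 4.0, 1.0
org = torch.nn.Linear(768, 768)
lora_down = torch.nn.Linear(768, dim, bias=False)
lora_up = torch.nn.Linear(dim, 768, bias=False)
torch.nn.init.kaiming_uniform_(lora_down.weight, a=5 ** 0.5)  # same init as LoRAModule
torch.nn.init.zeros_(lora_up.weight)                          # so the LoRA starts as a no-op

x = torch.randn(2, 768)
y = org(x) + lora_up(lora_down(x)) * multiplier * (alpha / dim)  # what LoRAModule.forward computes
print(torch.allclose(y, org(x)))  # True until lora_up/lora_down are trained or loaded from a file
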
Lora/lib/model_util.py ADDED
@@ -0,0 +1,1180 @@
1
+ # v1: split from train_db_fixed.py.
2
+ # v2: support safetensors
3
+
4
+ import math
5
+ import os
6
+ import torch
7
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
8
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
9
+ from safetensors.torch import load_file, save_file
10
+
11
+ # model parameters for the Diffusers version of Stable Diffusion
12
+ NUM_TRAIN_TIMESTEPS = 1000
13
+ BETA_START = 0.00085
14
+ BETA_END = 0.0120
15
+
16
+ UNET_PARAMS_MODEL_CHANNELS = 320
17
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
18
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
19
+ UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
20
+ UNET_PARAMS_IN_CHANNELS = 4
21
+ UNET_PARAMS_OUT_CHANNELS = 4
22
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
23
+ UNET_PARAMS_CONTEXT_DIM = 768
24
+ UNET_PARAMS_NUM_HEADS = 8
25
+
26
+ VAE_PARAMS_Z_CHANNELS = 4
27
+ VAE_PARAMS_RESOLUTION = 256
28
+ VAE_PARAMS_IN_CHANNELS = 3
29
+ VAE_PARAMS_OUT_CH = 3
30
+ VAE_PARAMS_CH = 128
31
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
32
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
33
+
34
+ # V2
35
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
36
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
37
+
38
+ # reference models used for loading the Diffusers configuration
39
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
40
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
41
+
42
+
43
+ # region StableDiffusion -> Diffusers conversion code
44
+ # copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
45
+
46
+
47
+ def shave_segments(path, n_shave_prefix_segments=1):
48
+ """
49
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
50
+ """
51
+ if n_shave_prefix_segments >= 0:
52
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
53
+ else:
54
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
55
+
56
+
57
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
58
+ """
59
+ Updates paths inside resnets to the new naming scheme (local renaming)
60
+ """
61
+ mapping = []
62
+ for old_item in old_list:
63
+ new_item = old_item.replace("in_layers.0", "norm1")
64
+ new_item = new_item.replace("in_layers.2", "conv1")
65
+
66
+ new_item = new_item.replace("out_layers.0", "norm2")
67
+ new_item = new_item.replace("out_layers.3", "conv2")
68
+
69
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
70
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
71
+
72
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
73
+
74
+ mapping.append({"old": old_item, "new": new_item})
75
+
76
+ return mapping
77
+
78
+
79
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
80
+ """
81
+ Updates paths inside resnets to the new naming scheme (local renaming)
82
+ """
83
+ mapping = []
84
+ for old_item in old_list:
85
+ new_item = old_item
86
+
87
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
88
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
89
+
90
+ mapping.append({"old": old_item, "new": new_item})
91
+
92
+ return mapping
93
+
94
+
95
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
96
+ """
97
+ Updates paths inside attentions to the new naming scheme (local renaming)
98
+ """
99
+ mapping = []
100
+ for old_item in old_list:
101
+ new_item = old_item
102
+
103
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
104
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
105
+
106
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
107
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
108
+
109
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
110
+
111
+ mapping.append({"old": old_item, "new": new_item})
112
+
113
+ return mapping
114
+
115
+
116
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
117
+ """
118
+ Updates paths inside attentions to the new naming scheme (local renaming)
119
+ """
120
+ mapping = []
121
+ for old_item in old_list:
122
+ new_item = old_item
123
+
124
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
125
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
126
+
127
+ new_item = new_item.replace("q.weight", "query.weight")
128
+ new_item = new_item.replace("q.bias", "query.bias")
129
+
130
+ new_item = new_item.replace("k.weight", "key.weight")
131
+ new_item = new_item.replace("k.bias", "key.bias")
132
+
133
+ new_item = new_item.replace("v.weight", "value.weight")
134
+ new_item = new_item.replace("v.bias", "value.bias")
135
+
136
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
137
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
138
+
139
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
140
+
141
+ mapping.append({"old": old_item, "new": new_item})
142
+
143
+ return mapping
144
+
145
+
146
+ def assign_to_checkpoint(
147
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
148
+ ):
149
+ """
150
+ This does the final conversion step: take locally converted weights and apply a global renaming
151
+ to them. It splits attention layers, and takes into account additional replacements
152
+ that may arise.
153
+
154
+ Assigns the weights to the new checkpoint.
155
+ """
156
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
157
+
158
+ # Splits the attention layers into three variables.
159
+ if attention_paths_to_split is not None:
160
+ for path, path_map in attention_paths_to_split.items():
161
+ old_tensor = old_checkpoint[path]
162
+ channels = old_tensor.shape[0] // 3
163
+
164
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
165
+
166
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
167
+
168
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
169
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
170
+
171
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
172
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
173
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
174
+
175
+ for path in paths:
176
+ new_path = path["new"]
177
+
178
+ # These have already been assigned
179
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
180
+ continue
181
+
182
+ # Global renaming happens here
183
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
184
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
185
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
186
+
187
+ if additional_replacements is not None:
188
+ for replacement in additional_replacements:
189
+ new_path = new_path.replace(replacement["old"], replacement["new"])
190
+
191
+ # proj_attn.weight has to be converted from conv 1D to linear
192
+ if "proj_attn.weight" in new_path:
193
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
194
+ else:
195
+ checkpoint[new_path] = old_checkpoint[path["old"]]
196
+
197
+
198
+ def conv_attn_to_linear(checkpoint):
199
+ keys = list(checkpoint.keys())
200
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
201
+ for key in keys:
202
+ if ".".join(key.split(".")[-2:]) in attn_keys:
203
+ if checkpoint[key].ndim > 2:
204
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
205
+ elif "proj_attn.weight" in key:
206
+ if checkpoint[key].ndim > 2:
207
+ checkpoint[key] = checkpoint[key][:, :, 0]
208
+
209
+
210
+ def linear_transformer_to_conv(checkpoint):
211
+ keys = list(checkpoint.keys())
212
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
213
+ for key in keys:
214
+ if ".".join(key.split(".")[-2:]) in tf_keys:
215
+ if checkpoint[key].ndim == 2:
216
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
217
+
218
+
219
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
220
+ """
221
+ Takes a state dict and a config, and returns a converted checkpoint.
222
+ """
223
+
224
+ # extract state_dict for UNet
225
+ unet_state_dict = {}
226
+ unet_key = "model.diffusion_model."
227
+ keys = list(checkpoint.keys())
228
+ for key in keys:
229
+ if key.startswith(unet_key):
230
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
231
+
232
+ new_checkpoint = {}
233
+
234
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
235
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
236
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
237
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
238
+
239
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
240
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
241
+
242
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
243
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
244
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
245
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
246
+
247
+ # Retrieves the keys for the input blocks only
248
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
249
+ input_blocks = {
250
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
251
+ for layer_id in range(num_input_blocks)
252
+ }
253
+
254
+ # Retrieves the keys for the middle blocks only
255
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
256
+ middle_blocks = {
257
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
258
+ for layer_id in range(num_middle_blocks)
259
+ }
260
+
261
+ # Retrieves the keys for the output blocks only
262
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
263
+ output_blocks = {
264
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
265
+ for layer_id in range(num_output_blocks)
266
+ }
267
+
268
+ for i in range(1, num_input_blocks):
269
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
270
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
271
+
272
+ resnets = [
273
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
274
+ ]
275
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
276
+
277
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
278
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
279
+ f"input_blocks.{i}.0.op.weight"
280
+ )
281
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
282
+ f"input_blocks.{i}.0.op.bias"
283
+ )
284
+
285
+ paths = renew_resnet_paths(resnets)
286
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
287
+ assign_to_checkpoint(
288
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
289
+ )
290
+
291
+ if len(attentions):
292
+ paths = renew_attention_paths(attentions)
293
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
294
+ assign_to_checkpoint(
295
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
296
+ )
297
+
298
+ resnet_0 = middle_blocks[0]
299
+ attentions = middle_blocks[1]
300
+ resnet_1 = middle_blocks[2]
301
+
302
+ resnet_0_paths = renew_resnet_paths(resnet_0)
303
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
304
+
305
+ resnet_1_paths = renew_resnet_paths(resnet_1)
306
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
307
+
308
+ attentions_paths = renew_attention_paths(attentions)
309
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
310
+ assign_to_checkpoint(
311
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
312
+ )
313
+
314
+ for i in range(num_output_blocks):
315
+ block_id = i // (config["layers_per_block"] + 1)
316
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
317
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
318
+ output_block_list = {}
319
+
320
+ for layer in output_block_layers:
321
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
322
+ if layer_id in output_block_list:
323
+ output_block_list[layer_id].append(layer_name)
324
+ else:
325
+ output_block_list[layer_id] = [layer_name]
326
+
327
+ if len(output_block_list) > 1:
328
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
329
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
330
+
331
+ resnet_0_paths = renew_resnet_paths(resnets)
332
+ paths = renew_resnet_paths(resnets)
333
+
334
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
335
+ assign_to_checkpoint(
336
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
337
+ )
338
+
339
+ # original:
340
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
341
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
342
+
343
+ # avoid depending on the order of bias and weight; there is probably a better way to do this
344
+ for l in output_block_list.values():
345
+ l.sort()
346
+
347
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
348
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
349
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
350
+ f"output_blocks.{i}.{index}.conv.bias"
351
+ ]
352
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
353
+ f"output_blocks.{i}.{index}.conv.weight"
354
+ ]
355
+
356
+ # Clear attentions as they have been attributed above.
357
+ if len(attentions) == 2:
358
+ attentions = []
359
+
360
+ if len(attentions):
361
+ paths = renew_attention_paths(attentions)
362
+ meta_path = {
363
+ "old": f"output_blocks.{i}.1",
364
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
365
+ }
366
+ assign_to_checkpoint(
367
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
368
+ )
369
+ else:
370
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
371
+ for path in resnet_0_paths:
372
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
373
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
374
+
375
+ new_checkpoint[new_path] = unet_state_dict[old_path]
376
+
377
+ # in SD v2 the 1x1 conv2d layers are replaced by linear, so convert linear -> conv
378
+ if v2:
379
+ linear_transformer_to_conv(new_checkpoint)
380
+
381
+ return new_checkpoint
382
+
383
+
384
+ def convert_ldm_vae_checkpoint(checkpoint, config):
385
+ # extract state dict for VAE
386
+ vae_state_dict = {}
387
+ vae_key = "first_stage_model."
388
+ keys = list(checkpoint.keys())
389
+ for key in keys:
390
+ if key.startswith(vae_key):
391
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
392
+ # if len(vae_state_dict) == 0:
393
+ # # the checkpoint passed in is a VAE state_dict, not a checkpoint loaded from a .ckpt
394
+ # vae_state_dict = checkpoint
395
+
396
+ new_checkpoint = {}
397
+
398
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
399
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
400
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
401
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
402
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
403
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
404
+
405
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
406
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
407
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
408
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
409
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
410
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
411
+
412
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
413
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
414
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
415
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
416
+
417
+ # Retrieves the keys for the encoder down blocks only
418
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
419
+ down_blocks = {
420
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
421
+ }
422
+
423
+ # Retrieves the keys for the decoder up blocks only
424
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
425
+ up_blocks = {
426
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
427
+ }
428
+
429
+ for i in range(num_down_blocks):
430
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
431
+
432
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
433
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
434
+ f"encoder.down.{i}.downsample.conv.weight"
435
+ )
436
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
437
+ f"encoder.down.{i}.downsample.conv.bias"
438
+ )
439
+
440
+ paths = renew_vae_resnet_paths(resnets)
441
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
442
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
443
+
444
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
445
+ num_mid_res_blocks = 2
446
+ for i in range(1, num_mid_res_blocks + 1):
447
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
448
+
449
+ paths = renew_vae_resnet_paths(resnets)
450
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
451
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
452
+
453
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
454
+ paths = renew_vae_attention_paths(mid_attentions)
455
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
456
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
457
+ conv_attn_to_linear(new_checkpoint)
458
+
459
+ for i in range(num_up_blocks):
460
+ block_id = num_up_blocks - 1 - i
461
+ resnets = [
462
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
463
+ ]
464
+
465
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
466
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
467
+ f"decoder.up.{block_id}.upsample.conv.weight"
468
+ ]
469
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
470
+ f"decoder.up.{block_id}.upsample.conv.bias"
471
+ ]
472
+
473
+ paths = renew_vae_resnet_paths(resnets)
474
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
475
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
476
+
477
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
478
+ num_mid_res_blocks = 2
479
+ for i in range(1, num_mid_res_blocks + 1):
480
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
481
+
482
+ paths = renew_vae_resnet_paths(resnets)
483
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
484
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
485
+
486
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
487
+ paths = renew_vae_attention_paths(mid_attentions)
488
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
489
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
490
+ conv_attn_to_linear(new_checkpoint)
491
+ return new_checkpoint
492
+
493
+
494
+ def create_unet_diffusers_config(v2):
495
+ """
496
+ Creates a config for the diffusers based on the config of the LDM model.
497
+ """
498
+ # unet_params = original_config.model.params.unet_config.params
499
+
500
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
501
+
502
+ down_block_types = []
503
+ resolution = 1
504
+ for i in range(len(block_out_channels)):
505
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
506
+ down_block_types.append(block_type)
507
+ if i != len(block_out_channels) - 1:
508
+ resolution *= 2
509
+
510
+ up_block_types = []
511
+ for i in range(len(block_out_channels)):
512
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
513
+ up_block_types.append(block_type)
514
+ resolution //= 2
515
+
516
+ config = dict(
517
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
518
+ in_channels=UNET_PARAMS_IN_CHANNELS,
519
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
520
+ down_block_types=tuple(down_block_types),
521
+ up_block_types=tuple(up_block_types),
522
+ block_out_channels=tuple(block_out_channels),
523
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
524
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
525
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
526
+ )
527
+
528
+ return config
529
+
530
+
531
+ def create_vae_diffusers_config():
532
+ """
533
+ Creates a config for the diffusers based on the config of the LDM model.
534
+ """
535
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
536
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
537
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
538
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
539
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
540
+
541
+ config = dict(
542
+ sample_size=VAE_PARAMS_RESOLUTION,
543
+ in_channels=VAE_PARAMS_IN_CHANNELS,
544
+ out_channels=VAE_PARAMS_OUT_CH,
545
+ down_block_types=tuple(down_block_types),
546
+ up_block_types=tuple(up_block_types),
547
+ block_out_channels=tuple(block_out_channels),
548
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
549
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
550
+ )
551
+ return config
552
+
553
+
554
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
555
+ keys = list(checkpoint.keys())
556
+ text_model_dict = {}
557
+ for key in keys:
558
+ if key.startswith("cond_stage_model.transformer"):
559
+ text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
560
+ return text_model_dict
561
+
562
+
563
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
564
+ # annoyingly different!
565
+ def convert_key(key):
566
+ if not key.startswith("cond_stage_model"):
567
+ return None
568
+
569
+ # common conversion
570
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
571
+ key = key.replace("cond_stage_model.model.", "text_model.")
572
+
573
+ if "resblocks" in key:
574
+ # resblocks conversion
575
+ key = key.replace(".resblocks.", ".layers.")
576
+ if ".ln_" in key:
577
+ key = key.replace(".ln_", ".layer_norm")
578
+ elif ".mlp." in key:
579
+ key = key.replace(".c_fc.", ".fc1.")
580
+ key = key.replace(".c_proj.", ".fc2.")
581
+ elif '.attn.out_proj' in key:
582
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
583
+ elif '.attn.in_proj' in key:
584
+ key = None # special case, handled separately below
585
+ else:
586
+ raise ValueError(f"unexpected key in SD: {key}")
587
+ elif '.positional_embedding' in key:
588
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
589
+ elif '.text_projection' in key:
590
+ key = None # apparently unused???
591
+ elif '.logit_scale' in key:
592
+ key = None # apparently unused???
593
+ elif '.token_embedding' in key:
594
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
595
+ elif '.ln_final' in key:
596
+ key = key.replace(".ln_final", ".final_layer_norm")
597
+ return key
598
+
599
+ keys = list(checkpoint.keys())
600
+ new_sd = {}
601
+ for key in keys:
602
+ # remove resblocks 23
603
+ if '.resblocks.23.' in key:
604
+ continue
605
+ new_key = convert_key(key)
606
+ if new_key is None:
607
+ continue
608
+ new_sd[new_key] = checkpoint[key]
609
+
610
+ # convert the attention weights
611
+ for key in keys:
612
+ if '.resblocks.23.' in key:
613
+ continue
614
+ if '.resblocks' in key and '.attn.in_proj_' in key:
615
+ # split into three (q, k, v)
616
+ values = torch.chunk(checkpoint[key], 3)
617
+
618
+ key_suffix = ".weight" if "weight" in key else ".bias"
619
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
620
+ key_pfx = key_pfx.replace("_weight", "")
621
+ key_pfx = key_pfx.replace("_bias", "")
622
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
623
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
624
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
625
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
626
+
627
+ # rename or add position_ids
628
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
629
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
630
+ # waifu diffusion v1.4
631
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
632
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
633
+ else:
634
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
635
+
636
+ new_sd["text_model.embeddings.position_ids"] = position_ids
637
+ return new_sd
638
+
639
+ # endregion
640
+
641
+
642
+ # region Diffusers -> StableDiffusion conversion code
643
+ # copied and modified from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
644
+
645
+ def conv_transformer_to_linear(checkpoint):
646
+ keys = list(checkpoint.keys())
647
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
648
+ for key in keys:
649
+ if ".".join(key.split(".")[-2:]) in tf_keys:
650
+ if checkpoint[key].ndim > 2:
651
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
652
+
653
+
654
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
655
+ unet_conversion_map = [
656
+ # (stable-diffusion, HF Diffusers)
657
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
658
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
659
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
660
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
661
+ ("input_blocks.0.0.weight", "conv_in.weight"),
662
+ ("input_blocks.0.0.bias", "conv_in.bias"),
663
+ ("out.0.weight", "conv_norm_out.weight"),
664
+ ("out.0.bias", "conv_norm_out.bias"),
665
+ ("out.2.weight", "conv_out.weight"),
666
+ ("out.2.bias", "conv_out.bias"),
667
+ ]
668
+
669
+ unet_conversion_map_resnet = [
670
+ # (stable-diffusion, HF Diffusers)
671
+ ("in_layers.0", "norm1"),
672
+ ("in_layers.2", "conv1"),
673
+ ("out_layers.0", "norm2"),
674
+ ("out_layers.3", "conv2"),
675
+ ("emb_layers.1", "time_emb_proj"),
676
+ ("skip_connection", "conv_shortcut"),
677
+ ]
678
+
679
+ unet_conversion_map_layer = []
680
+ for i in range(4):
681
+ # loop over downblocks/upblocks
682
+
683
+ for j in range(2):
684
+ # loop over resnets/attentions for downblocks
685
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
686
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
687
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
688
+
689
+ if i < 3:
690
+ # no attention layers in down_blocks.3
691
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
692
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
693
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
694
+
695
+ for j in range(3):
696
+ # loop over resnets/attentions for upblocks
697
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
698
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
699
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
700
+
701
+ if i > 0:
702
+ # no attention layers in up_blocks.0
703
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
704
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
705
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
706
+
707
+ if i < 3:
708
+ # no downsample in down_blocks.3
709
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
710
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
711
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
712
+
713
+ # no upsample in up_blocks.3
714
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
715
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
716
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
717
+
718
+ hf_mid_atn_prefix = "mid_block.attentions.0."
719
+ sd_mid_atn_prefix = "middle_block.1."
720
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
721
+
722
+ for j in range(2):
723
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
724
+ sd_mid_res_prefix = f"middle_block.{2*j}."
725
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
726
+
727
+ # buyer beware: this is a *brittle* function,
728
+ # and correct output requires that all of these pieces interact in
729
+ # the exact order in which I have arranged them.
730
+ mapping = {k: k for k in unet_state_dict.keys()}
731
+ for sd_name, hf_name in unet_conversion_map:
732
+ mapping[hf_name] = sd_name
733
+ for k, v in mapping.items():
734
+ if "resnets" in k:
735
+ for sd_part, hf_part in unet_conversion_map_resnet:
736
+ v = v.replace(hf_part, sd_part)
737
+ mapping[k] = v
738
+ for k, v in mapping.items():
739
+ for sd_part, hf_part in unet_conversion_map_layer:
740
+ v = v.replace(hf_part, sd_part)
741
+ mapping[k] = v
742
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
743
+
744
+ if v2:
745
+ conv_transformer_to_linear(new_state_dict)
746
+
747
+ return new_state_dict
748
+
749
+
750
+ # ================#
751
+ # VAE Conversion #
752
+ # ================#
753
+
754
+ def reshape_weight_for_sd(w):
755
+ # convert HF linear weights to SD conv2d weights
756
+ return w.reshape(*w.shape, 1, 1)
757
+
758
+
759
+ def convert_vae_state_dict(vae_state_dict):
760
+ vae_conversion_map = [
761
+ # (stable-diffusion, HF Diffusers)
762
+ ("nin_shortcut", "conv_shortcut"),
763
+ ("norm_out", "conv_norm_out"),
764
+ ("mid.attn_1.", "mid_block.attentions.0."),
765
+ ]
766
+
767
+ for i in range(4):
768
+ # down_blocks have two resnets
769
+ for j in range(2):
770
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
771
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
772
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
773
+
774
+ if i < 3:
775
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
776
+ sd_downsample_prefix = f"down.{i}.downsample."
777
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
778
+
779
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
780
+ sd_upsample_prefix = f"up.{3-i}.upsample."
781
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
782
+
783
+ # up_blocks have three resnets
784
+ # also, up blocks in hf are numbered in reverse from sd
785
+ for j in range(3):
786
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
787
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
788
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
789
+
790
+ # this part accounts for mid blocks in both the encoder and the decoder
791
+ for i in range(2):
792
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
793
+ sd_mid_res_prefix = f"mid.block_{i+1}."
794
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
795
+
796
+ vae_conversion_map_attn = [
797
+ # (stable-diffusion, HF Diffusers)
798
+ ("norm.", "group_norm."),
799
+ ("q.", "query."),
800
+ ("k.", "key."),
801
+ ("v.", "value."),
802
+ ("proj_out.", "proj_attn."),
803
+ ]
804
+
805
+ mapping = {k: k for k in vae_state_dict.keys()}
806
+ for k, v in mapping.items():
807
+ for sd_part, hf_part in vae_conversion_map:
808
+ v = v.replace(hf_part, sd_part)
809
+ mapping[k] = v
810
+ for k, v in mapping.items():
811
+ if "attentions" in k:
812
+ for sd_part, hf_part in vae_conversion_map_attn:
813
+ v = v.replace(hf_part, sd_part)
814
+ mapping[k] = v
815
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
816
+ weights_to_convert = ["q", "k", "v", "proj_out"]
817
+ for k, v in new_state_dict.items():
818
+ for weight_name in weights_to_convert:
819
+ if f"mid.attn_1.{weight_name}.weight" in k:
820
+ # print(f"Reshaping {k} for SD format")
821
+ new_state_dict[k] = reshape_weight_for_sd(v)
822
+
823
+ return new_state_dict
824
+
825
+
826
+ # endregion
827
+
828
+ # region custom model loading/saving helpers
829
+
830
+ def is_safetensors(path):
831
+ return os.path.splitext(path)[1].lower() == '.safetensors'
832
+
833
+
834
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
835
+ # handle models that store the text encoder in a different layout (missing 'text_model')
836
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
837
+ ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
838
+ ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
839
+ ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
840
+ ]
841
+
842
+ if is_safetensors(ckpt_path):
843
+ checkpoint = None
844
+ state_dict = load_file(ckpt_path, "cpu")
845
+ else:
846
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
847
+ if "state_dict" in checkpoint:
848
+ state_dict = checkpoint["state_dict"]
849
+ else:
850
+ state_dict = checkpoint
851
+ checkpoint = None
852
+
853
+ key_reps = []
854
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
855
+ for key in state_dict.keys():
856
+ if key.startswith(rep_from):
857
+ new_key = rep_to + key[len(rep_from):]
858
+ key_reps.append((key, new_key))
859
+
860
+ for key, new_key in key_reps:
861
+ state_dict[new_key] = state_dict[key]
862
+ del state_dict[key]
863
+
864
+ return checkpoint, state_dict
865
+
866
+
867
+ # TODO: dtype handling is questionable and should be checked; creating the text_encoder in the requested dtype is unconfirmed
868
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
869
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
870
+ if dtype is not None:
871
+ for k, v in state_dict.items():
872
+ if type(v) is torch.Tensor:
873
+ state_dict[k] = v.to(dtype)
874
+
875
+ # Convert the UNet2DConditionModel model.
876
+ unet_config = create_unet_diffusers_config(v2)
877
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
878
+
879
+ unet = UNet2DConditionModel(**unet_config)
880
+ info = unet.load_state_dict(converted_unet_checkpoint)
881
+ print("loading u-net:", info)
882
+
883
+ # Convert the VAE model.
884
+ vae_config = create_vae_diffusers_config()
885
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
886
+
887
+ vae = AutoencoderKL(**vae_config)
888
+ info = vae.load_state_dict(converted_vae_checkpoint)
889
+ print("loading vae:", info)
890
+
891
+ # convert text_model
892
+ if v2:
893
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
894
+ cfg = CLIPTextConfig(
895
+ vocab_size=49408,
896
+ hidden_size=1024,
897
+ intermediate_size=4096,
898
+ num_hidden_layers=23,
899
+ num_attention_heads=16,
900
+ max_position_embeddings=77,
901
+ hidden_act="gelu",
902
+ layer_norm_eps=1e-05,
903
+ dropout=0.0,
904
+ attention_dropout=0.0,
905
+ initializer_range=0.02,
906
+ initializer_factor=1.0,
907
+ pad_token_id=1,
908
+ bos_token_id=0,
909
+ eos_token_id=2,
910
+ model_type="clip_text_model",
911
+ projection_dim=512,
912
+ torch_dtype="float32",
913
+ transformers_version="4.25.0.dev0",
914
+ )
915
+ text_model = CLIPTextModel._from_config(cfg)
916
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
917
+ else:
918
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
919
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
920
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
921
+ print("loading text encoder:", info)
922
+
923
+ return text_model, vae, unet
924
+
925
+
926
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
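+ # maps Hugging Face CLIPTextModel keys to the open_clip style layout stored in SD v2 checkpoints
+ # (transformer.resblocks.*, ln_* norms, fused attn.in_proj weights); the q/k/v fusion is handled further below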
927
+ def convert_key(key):
928
+ # drop position_ids
929
+ if ".position_ids" in key:
930
+ return None
931
+
932
+ # common
933
+ key = key.replace("text_model.encoder.", "transformer.")
934
+ key = key.replace("text_model.", "")
935
+ if "layers" in key:
936
+ # resblocks conversion
937
+ key = key.replace(".layers.", ".resblocks.")
938
+ if ".layer_norm" in key:
939
+ key = key.replace(".layer_norm", ".ln_")
940
+ elif ".mlp." in key:
941
+ key = key.replace(".fc1.", ".c_fc.")
942
+ key = key.replace(".fc2.", ".c_proj.")
943
+ elif '.self_attn.out_proj' in key:
944
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
945
+ elif '.self_attn.' in key:
946
+ key = None   # special case, handled separately below
947
+ else:
948
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
949
+ elif '.position_embedding' in key:
950
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
951
+ elif '.token_embedding' in key:
952
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
953
+ elif 'final_layer_norm' in key:
954
+ key = key.replace("final_layer_norm", "ln_final")
955
+ return key
956
+
957
+ keys = list(checkpoint.keys())
958
+ new_sd = {}
959
+ for key in keys:
960
+ new_key = convert_key(key)
961
+ if new_key is None:
962
+ continue
963
+ new_sd[new_key] = checkpoint[key]
964
+
965
+ # convert attention weights
966
+ for key in keys:
967
+ if 'layers' in key and 'q_proj' in key:
968
+ # concatenate the three projections
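+ # HF CLIP keeps q_proj/k_proj/v_proj as separate Linear layers, while the SD v2 (open_clip) format
+ # stores one fused in_proj matrix, so the three weight tensors are stacked along dim 0 here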
969
+ key_q = key
970
+ key_k = key.replace("q_proj", "k_proj")
971
+ key_v = key.replace("q_proj", "v_proj")
972
+
973
+ value_q = checkpoint[key_q]
974
+ value_k = checkpoint[key_k]
975
+ value_v = checkpoint[key_v]
976
+ value = torch.cat([value_q, value_k, value_v])
977
+
978
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
979
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
980
+ new_sd[new_key] = value
981
+
982
+ # optionally fabricate the last layer and related weights
983
+ if make_dummy_weights:
984
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
985
+ keys = list(new_sd.keys())
986
+ for key in keys:
987
+ if key.startswith("transformer.resblocks.22."):
988
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone()  # saving with safetensors fails without a copy
989
+
990
+ # create weights that the Diffusers model does not contain
991
+ new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
992
+ new_sd['logit_scale'] = torch.tensor(1)
993
+
994
+ return new_sd
995
+
996
+
997
+ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
998
+ if ckpt_path is not None:
999
+ # read epoch/step from it; also reload the checkpoint (including the VAE) e.g. when the VAE is not in memory
1000
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1001
+ if checkpoint is None:  # a safetensors file, or a ckpt that is itself a state_dict
1002
+ checkpoint = {}
1003
+ strict = False
1004
+ else:
1005
+ strict = True
1006
+ if "state_dict" in state_dict:
1007
+ del state_dict["state_dict"]
1008
+ else:
1009
+ # build a new checkpoint from scratch
1010
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1011
+ checkpoint = {}
1012
+ state_dict = {}
1013
+ strict = False
1014
+
1015
+ def update_sd(prefix, sd):
1016
+ for k, v in sd.items():
1017
+ key = prefix + k
1018
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1019
+ if save_dtype is not None:
1020
+ v = v.detach().clone().to("cpu").to(save_dtype)
1021
+ state_dict[key] = v
1022
+
1023
+ # Convert the UNet model
1024
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1025
+ update_sd("model.diffusion_model.", unet_state_dict)
1026
+
1027
+ # Convert the text encoder model
1028
+ if v2:
1029
+ make_dummy = ckpt_path is None  # when no reference checkpoint is given, insert dummy weights, e.g. duplicating the last layer from the previous one
1030
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1031
+ update_sd("cond_stage_model.model.", text_enc_dict)
1032
+ else:
1033
+ text_enc_dict = text_encoder.state_dict()
1034
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1035
+
1036
+ # Convert the VAE
1037
+ if vae is not None:
1038
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1039
+ update_sd("first_stage_model.", vae_dict)
1040
+
1041
+ # Put together new checkpoint
1042
+ key_count = len(state_dict.keys())
1043
+ new_ckpt = {'state_dict': state_dict}
1044
+
1045
+ if 'epoch' in checkpoint:
1046
+ epochs += checkpoint['epoch']
1047
+ if 'global_step' in checkpoint:
1048
+ steps += checkpoint['global_step']
1049
+
1050
+ new_ckpt['epoch'] = epochs
1051
+ new_ckpt['global_step'] = steps
1052
+
1053
+ if is_safetensors(output_file):
1054
+ # TODO consider dropping non-Tensor values from the dict
1055
+ save_file(state_dict, output_file)
1056
+ else:
1057
+ torch.save(new_ckpt, output_file)
1058
+
1059
+ return key_count
1060
+
1061
+
1062
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1063
+ if pretrained_model_name_or_path is None:
1064
+ # load default settings for v1/v2
1065
+ if v2:
1066
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1067
+ else:
1068
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1069
+
1070
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1071
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1072
+ if vae is None:
1073
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1074
+
1075
+ pipeline = StableDiffusionPipeline(
1076
+ unet=unet,
1077
+ text_encoder=text_encoder,
1078
+ vae=vae,
1079
+ scheduler=scheduler,
1080
+ tokenizer=tokenizer,
1081
+ safety_checker=None,
1082
+ feature_extractor=None,
1083
+ requires_safety_checker=None,
1084
+ )
1085
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1086
+
1087
+
1088
+ VAE_PREFIX = "first_stage_model."
1089
+
1090
+
1091
+ def load_vae(vae_id, dtype):
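+ # vae_id may be a Diffusers directory / hub id, a Diffusers-format .bin file, or an SD-format
+ # checkpoint (full model or VAE-only); the last case is converted via convert_ldm_vae_checkpoint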
1092
+ print(f"load VAE: {vae_id}")
1093
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1094
+ # Diffusers local/remote
1095
+ try:
1096
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1097
+ except EnvironmentError as e:
1098
+ print(f"exception occurs in loading vae: {e}")
1099
+ print("retry with subfolder='vae'")
1100
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1101
+ return vae
1102
+
1103
+ # local
1104
+ vae_config = create_vae_diffusers_config()
1105
+
1106
+ if vae_id.endswith(".bin"):
1107
+ # SD 1.5 VAE on Huggingface
1108
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1109
+ else:
1110
+ # StableDiffusion
1111
+ vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
1112
+ else torch.load(vae_id, map_location="cpu"))
1113
+ vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
1114
+
1115
+ # vae only or full model
1116
+ full_model = False
1117
+ for vae_key in vae_sd:
1118
+ if vae_key.startswith(VAE_PREFIX):
1119
+ full_model = True
1120
+ break
1121
+ if not full_model:
1122
+ sd = {}
1123
+ for key, value in vae_sd.items():
1124
+ sd[VAE_PREFIX + key] = value
1125
+ vae_sd = sd
1126
+ del sd
1127
+
1128
+ # Convert the VAE model.
1129
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1130
+
1131
+ vae = AutoencoderKL(**vae_config)
1132
+ vae.load_state_dict(converted_vae_checkpoint)
1133
+ return vae
1134
+
1135
+ # endregion
1136
+
1137
+
1138
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1139
+ max_width, max_height = max_reso
1140
+ max_area = (max_width // divisible) * (max_height // divisible)
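+ # note: max_area is measured in units of divisible x divisible blocks, not in pixels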
1141
+
1142
+ resos = set()
1143
+
1144
+ size = int(math.sqrt(max_area)) * divisible
1145
+ resos.add((size, size))
1146
+
1147
+ size = min_size
1148
+ while size <= max_size:
1149
+ width = size
1150
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
1151
+ resos.add((width, height))
1152
+ resos.add((height, width))
1153
+
1154
+ # # make additional resos
1155
+ # if width >= height and width - divisible >= min_size:
1156
+ # resos.add((width - divisible, height))
1157
+ # resos.add((height, width - divisible))
1158
+ # if height >= width and height - divisible >= min_size:
1159
+ # resos.add((width, height - divisible))
1160
+ # resos.add((height - divisible, width))
1161
+
1162
+ size += divisible
1163
+
1164
+ resos = list(resos)
1165
+ resos.sort()
1166
+ return resos
1167
+
1168
+
1169
+ if __name__ == '__main__':
1170
+ resos = make_bucket_resolutions((512, 768))
1171
+ print(len(resos))
1172
+ print(resos)
1173
+ aspect_ratios = [w / h for w, h in resos]
1174
+ print(aspect_ratios)
1175
+
1176
+ ars = set()
1177
+ for ar in aspect_ratios:
1178
+ if ar in ars:
1179
+ print("error! duplicate ar:", ar)
1180
+ ars.add(ar)
Lora/lib/qwerty.py ADDED
@@ -0,0 +1,91 @@
1
+ from safetensors.torch import load_file
2
+ import sys
3
+ import torch
4
+ from pathlib import Path
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
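+ # rebuilds a small self-attention block from the given q/k/v projection weights and feeds it a random
+ # input; the outputs are later compared across models via cosine similarity to estimate how close two checkpoints are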
+ def cal_cross_attn(to_q, to_k, to_v, rand_input):
9
+ hidden_dim, embed_dim = to_q.shape
10
+ attn_to_q = nn.Linear(hidden_dim, embed_dim, bias=False)
11
+ attn_to_k = nn.Linear(hidden_dim, embed_dim, bias=False)
12
+ attn_to_v = nn.Linear(hidden_dim, embed_dim, bias=False)
13
+ attn_to_q.load_state_dict({"weight": to_q})
14
+ attn_to_k.load_state_dict({"weight": to_k})
15
+ attn_to_v.load_state_dict({"weight": to_v})
16
+
17
+ return torch.einsum(
18
+ "ik, jk -> ik",
19
+ F.softmax(torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)), dim=-1),
20
+ attn_to_v(rand_input)
21
+ )
22
+
23
+ def model_hash(filename):
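+ # old stable-diffusion-webui style hash: sha256 of a 64 KiB slice at offset 1 MiB, truncated to 8 hex chars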
24
+ try:
25
+ with open(filename, "rb") as file:
26
+ import hashlib
27
+ m = hashlib.sha256()
28
+
29
+ file.seek(0x100000)
30
+ m.update(file.read(0x10000))
31
+ return m.hexdigest()[0:8]
32
+ except FileNotFoundError:
33
+ return 'NOFILE'
34
+
35
+ def load_model(path):
36
+ if path.suffix == ".safetensors":
37
+ return load_file(path, device="cpu")
38
+ else:
39
+ ckpt = torch.load(path, map_location="cpu")
40
+ return ckpt["state_dict"] if "state_dict" in ckpt else ckpt
41
+
42
+ def eval(model, n, input):
43
+ qk = f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_q.weight"
44
+ uk = f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_k.weight"
45
+ vk = f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_v.weight"
46
+ atoq, atok, atov = model[qk], model[uk], model[vk]
47
+
48
+ attn = cal_cross_attn(atoq, atok, atov, input)
49
+ return attn
50
+
51
+ def main():
52
+ file1 = Path(sys.argv[1])
53
+ files = sys.argv[2:]
54
+
55
+ seed = 114514
56
+ torch.manual_seed(seed)
57
+ print(f"seed: {seed}")
58
+
59
+ model_a = load_model(file1)
60
+
61
+ print()
62
+ print(f"base: {file1.name} [{model_hash(file1)}]")
63
+ print()
64
+
65
+ map_attn_a = {}
66
+ map_rand_input = {}
67
+ for n in range(3, 11):
68
+ hidden_dim, embed_dim = model_a[f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_q.weight"].shape
69
+ rand_input = torch.randn([embed_dim, hidden_dim])
70
+
71
+ map_attn_a[n] = eval(model_a, n, rand_input)
72
+ map_rand_input[n] = rand_input
73
+
74
+ del model_a
75
+
76
+ for file2 in files:
77
+ file2 = Path(file2)
78
+ model_b = load_model(file2)
79
+
80
+ sims = []
81
+ for n in range(3, 11):
82
+ attn_a = map_attn_a[n]
83
+ attn_b = eval(model_b, n, map_rand_input[n])
84
+
85
+ sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
86
+ sims.append(sim)
87
+
88
+ print(f"{file2} [{model_hash(file2)}] - {torch.mean(torch.stack(sims)) * 1e2:.2f}%")
89
+
90
+ if __name__ == "__main__":
91
+ main()
Lora/lib/train_util.py ADDED
@@ -0,0 +1,1766 @@
1
+ # common functions for training
2
+
3
+ import argparse
4
+ import json
5
+ import shutil
6
+ import time
7
+ from typing import Dict, List, NamedTuple, Tuple
8
+ from accelerate import Accelerator
9
+ from torch.autograd.function import Function
10
+ import glob
11
+ import math
12
+ import os
13
+ import random
14
+ import hashlib
15
+ from io import BytesIO
16
+
17
+ from tqdm import tqdm
18
+ import torch
19
+ from torchvision import transforms
20
+ from transformers import CLIPTokenizer
21
+ import diffusers
22
+ from diffusers import DDPMScheduler, StableDiffusionPipeline
23
+ import albumentations as albu
24
+ import numpy as np
25
+ from PIL import Image
26
+ import cv2
27
+ from einops import rearrange
28
+ from torch import einsum
29
+ import safetensors.torch
30
+
31
+ import model_util  # sibling module in Lora/lib that provides make_bucket_resolutions
32
+
33
+ # Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
34
+ TOKENIZER_PATH = "openai/clip-vit-large-patch14"
35
+ V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from here; v2 and v2.1 share the same tokenizer spec
36
+
37
+ # checkpoint file name templates
38
+ EPOCH_STATE_NAME = "{}-{:06d}-state"
39
+ EPOCH_FILE_NAME = "{}-{:06d}"
40
+ EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
41
+ LAST_STATE_NAME = "{}-state"
42
+ DEFAULT_EPOCH_NAME = "epoch"
43
+ DEFAULT_LAST_OUTPUT_NAME = "last"
44
+
45
+ # region dataset
46
+
47
+ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
48
+ # , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
49
+
50
+
51
+ class ImageInfo():
52
+ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
53
+ self.image_key: str = image_key
54
+ self.num_repeats: int = num_repeats
55
+ self.caption: str = caption
56
+ self.is_reg: bool = is_reg
57
+ self.absolute_path: str = absolute_path
58
+ self.image_size: Tuple[int, int] = None
59
+ self.resized_size: Tuple[int, int] = None
60
+ self.bucket_reso: Tuple[int, int] = None
61
+ self.latents: torch.Tensor = None
62
+ self.latents_flipped: torch.Tensor = None
63
+ self.latents_npz: str = None
64
+ self.latents_npz_flipped: str = None
65
+
66
+
67
+ class BucketManager():
68
+ def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
69
+ self.no_upscale = no_upscale
70
+ if max_reso is None:
71
+ self.max_reso = None
72
+ self.max_area = None
73
+ else:
74
+ self.max_reso = max_reso
75
+ self.max_area = max_reso[0] * max_reso[1]
76
+ self.min_size = min_size
77
+ self.max_size = max_size
78
+ self.reso_steps = reso_steps
79
+
80
+ self.resos = []
81
+ self.reso_to_id = {}
82
+ self.buckets = []  # holds (image_key, image) during preprocessing, image_key during training
83
+
84
+ def add_image(self, reso, image):
85
+ bucket_id = self.reso_to_id[reso]
86
+ self.buckets[bucket_id].append(image)
87
+
88
+ def shuffle(self):
89
+ for bucket in self.buckets:
90
+ random.shuffle(bucket)
91
+
92
+ def sort(self):
93
+ # sort by resolution (only to make display and stored metadata look nicer); reorder buckets and reassign reso_to_id accordingly
94
+ sorted_resos = self.resos.copy()
95
+ sorted_resos.sort()
96
+
97
+ sorted_buckets = []
98
+ sorted_reso_to_id = {}
99
+ for i, reso in enumerate(sorted_resos):
100
+ bucket_id = self.reso_to_id[reso]
101
+ sorted_buckets.append(self.buckets[bucket_id])
102
+ sorted_reso_to_id[reso] = i
103
+
104
+ self.resos = sorted_resos
105
+ self.buckets = sorted_buckets
106
+ self.reso_to_id = sorted_reso_to_id
107
+
108
+ def make_buckets(self):
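+ # enumerate every bucket resolution whose area fits under max_reso, stepping side lengths by reso_steps
+ # (delegated to make_bucket_resolutions in the sibling model_util module)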
109
+ resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
110
+ self.set_predefined_resos(resos)
111
+
112
+ def set_predefined_resos(self, resos):
113
+ # store the resolutions and aspect ratios used when selecting from the predefined sizes
114
+ self.predefined_resos = resos.copy()
115
+ self.predefined_resos_set = set(resos)
116
+ self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
117
+
118
+ def add_if_new_reso(self, reso):
119
+ if reso not in self.reso_to_id:
120
+ bucket_id = len(self.resos)
121
+ self.reso_to_id[reso] = bucket_id
122
+ self.resos.append(reso)
123
+ self.buckets.append([])
124
+ # print(reso, bucket_id, len(self.buckets))
125
+
126
+ def round_to_steps(self, x):
127
+ x = int(x + .5)
128
+ return x - x % self.reso_steps
129
+
130
+ def select_bucket(self, image_width, image_height):
131
+ aspect_ratio = image_width / image_height
132
+ if not self.no_upscale:
133
+ # an identical aspect ratio may already exist (e.g. fine tuning preprocessed with no_upscale=True), so prefer an exact resolution match
134
+ reso = (image_width, image_height)
135
+ if reso in self.predefined_resos_set:
136
+ pass
137
+ else:
138
+ ar_errors = self.predefined_aspect_ratios - aspect_ratio
139
+ predefined_bucket_id = np.abs(ar_errors).argmin()  # the bucket with the smallest aspect ratio error
140
+ reso = self.predefined_resos[predefined_bucket_id]
141
+
142
+ ar_reso = reso[0] / reso[1]
143
+ if aspect_ratio > ar_reso:  # the image is wider than the bucket, so fit to height
144
+ scale = reso[1] / image_height
145
+ else:
146
+ scale = reso[0] / image_width
147
+
148
+ resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
149
+ # print("use predef", image_width, image_height, reso, resized_size)
150
+ else:
151
+ if image_width * image_height > self.max_area:
152
+ # the image is too large, so choose the bucket assuming it will be shrunk while preserving its aspect ratio
153
+ resized_width = math.sqrt(self.max_area * aspect_ratio)
154
+ resized_height = self.max_area / resized_width
155
+ assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
156
+
157
+ # round the short or long side of the resized image to a multiple of reso_steps, picking whichever keeps the aspect ratio error smaller
158
+ # same logic as the original bucketing
159
+ b_width_rounded = self.round_to_steps(resized_width)
160
+ b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
161
+ ar_width_rounded = b_width_rounded / b_height_in_wr
162
+
163
+ b_height_rounded = self.round_to_steps(resized_height)
164
+ b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
165
+ ar_height_rounded = b_width_in_hr / b_height_rounded
166
+
167
+ # print(b_width_rounded, b_height_in_wr, ar_width_rounded)
168
+ # print(b_width_in_hr, b_height_rounded, ar_height_rounded)
169
+
170
+ if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
171
+ resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
172
+ else:
173
+ resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
174
+ # print(resized_size)
175
+ else:
176
+ resized_size = (image_width, image_height) # リサイズは不要
177
+
178
+ # make the bucket no larger than the image (crop rather than pad)
179
+ bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
180
+ bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
181
+ # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
182
+
183
+ reso = (bucket_width, bucket_height)
184
+
185
+ self.add_if_new_reso(reso)
186
+
187
+ ar_error = (reso[0] / reso[1]) - aspect_ratio
188
+ return reso, resized_size, ar_error
189
+
190
+
191
+ class BucketBatchIndex(NamedTuple):
192
+ bucket_index: int
193
+ bucket_batch_size: int
194
+ batch_index: int
195
+
196
+
197
+ class BaseDataset(torch.utils.data.Dataset):
198
+ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
199
+ super().__init__()
200
+ self.tokenizer: CLIPTokenizer = tokenizer
201
+ self.max_token_length = max_token_length
202
+ self.shuffle_caption = shuffle_caption
203
+ self.shuffle_keep_tokens = shuffle_keep_tokens
204
+ # width/height is used when enable_bucket==False
205
+ self.width, self.height = (None, None) if resolution is None else resolution
206
+ self.face_crop_aug_range = face_crop_aug_range
207
+ self.flip_aug = flip_aug
208
+ self.color_aug = color_aug
209
+ self.debug_dataset = debug_dataset
210
+ self.random_crop = random_crop
211
+ self.token_padding_disabled = False
212
+ self.dataset_dirs_info = {}
213
+ self.reg_dataset_dirs_info = {}
214
+ self.tag_frequency = {}
215
+
216
+ self.enable_bucket = False
217
+ self.bucket_manager: BucketManager = None # not initialized
218
+ self.min_bucket_reso = None
219
+ self.max_bucket_reso = None
220
+ self.bucket_reso_steps = None
221
+ self.bucket_no_upscale = None
222
+ self.bucket_info = None # for metadata
223
+
224
+ self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
225
+
226
+ self.current_epoch: int = 0  # the instance seems to be recreated every epoch, so this must be set from the outside
227
+ self.dropout_rate: float = 0
228
+ self.dropout_every_n_epochs: int = None
229
+
230
+ # augmentation
231
+ flip_p = 0.5 if flip_aug else 0.0
232
+ if color_aug:
233
+ # fairly mild color augmentation: brightness/contrast would change the image's min/max pixel values, which is assumed to be harmful, so only gamma/hue are touched
234
+ self.aug = albu.Compose([
235
+ albu.OneOf([
236
+ albu.HueSaturationValue(8, 0, 0, p=.5),
237
+ albu.RandomGamma((95, 105), p=.5),
238
+ ], p=.33),
239
+ albu.HorizontalFlip(p=flip_p)
240
+ ], p=1.)
241
+ elif flip_aug:
242
+ self.aug = albu.Compose([
243
+ albu.HorizontalFlip(p=flip_p)
244
+ ], p=1.)
245
+ else:
246
+ self.aug = None
247
+
248
+ self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
249
+
250
+ self.image_data: Dict[str, ImageInfo] = {}
251
+
252
+ self.replacements = {}
253
+
254
+ def set_current_epoch(self, epoch):
255
+ self.current_epoch = epoch
256
+
257
+ def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
258
+ # not passed via the constructor so that Textual Inversion does not have to care about it (or so the story goes)
259
+ self.dropout_rate = dropout_rate
260
+ self.dropout_every_n_epochs = dropout_every_n_epochs
261
+ self.tag_dropout_rate = tag_dropout_rate
262
+
263
+ def set_tag_frequency(self, dir_name, captions):
264
+ frequency_for_dir = self.tag_frequency.get(dir_name, {})
265
+ self.tag_frequency[dir_name] = frequency_for_dir
266
+ for caption in captions:
267
+ for tag in caption.split(","):
268
+ if tag and not tag.isspace():
269
+ tag = tag.lower()
270
+ frequency = frequency_for_dir.get(tag, 0)
271
+ frequency_for_dir[tag] = frequency + 1
272
+
273
+ def disable_token_padding(self):
274
+ self.token_padding_disabled = True
275
+
276
+ def add_replacement(self, str_from, str_to):
277
+ self.replacements[str_from] = str_to
278
+
279
+ def process_caption(self, caption):
280
+ # decide caption dropout here: tag dropout also lives in this method, so this is the right place
281
+ is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
282
+ is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
283
+
284
+ if is_drop_out:
285
+ caption = ""
286
+ else:
287
+ if self.shuffle_caption:
288
+ def dropout_tags(tokens):
289
+ if self.tag_dropout_rate <= 0:
290
+ return tokens
291
+ l = []
292
+ for token in tokens:
293
+ if random.random() >= self.tag_dropout_rate:
294
+ l.append(token)
295
+ return l
296
+
297
+ tokens = [t.strip() for t in caption.strip().split(",")]
298
+ if self.shuffle_keep_tokens is None:
299
+ random.shuffle(tokens)
300
+ tokens = dropout_tags(tokens)
301
+ else:
302
+ if len(tokens) > self.shuffle_keep_tokens:
303
+ keep_tokens = tokens[:self.shuffle_keep_tokens]
304
+ tokens = tokens[self.shuffle_keep_tokens:]
305
+ random.shuffle(tokens)
306
+ tokens = dropout_tags(tokens)
307
+
308
+ tokens = keep_tokens + tokens
309
+ caption = ", ".join(tokens)
310
+
311
+ # textual inversion対応
312
+ for str_from, str_to in self.replacements.items():
313
+ if str_from == "":
314
+ # replace all
315
+ if type(str_to) == list:
316
+ caption = random.choice(str_to)
317
+ else:
318
+ caption = str_to
319
+ else:
320
+ caption = caption.replace(str_from, str_to)
321
+
322
+ return caption
323
+
324
+ def get_input_ids(self, caption):
325
+ input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
326
+ max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
327
+
328
+ if self.tokenizer_max_length > self.tokenizer.model_max_length:
329
+ input_ids = input_ids.squeeze(0)
330
+ iids_list = []
331
+ if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
332
+ # v1
333
+ # for inputs longer than 77 tokens the result is "<BOS> .... <EOS> <EOS> <EOS>" totalling e.g. 227 tokens, so convert it into three "<BOS>...<EOS>" chunks
334
+ # the 1111 (webui) implementation apparently also splits on ",", but keep it simple here
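+ # e.g. with max_token_length=225 the tokenizer length is 227 and range(1, 152, 75) gives i = 1, 76, 151;
+ # each chunk packs 75 content tokens between a copied BOS and the final token, producing three 77-token chunks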
335
+ for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
336
+ ids_chunk = (input_ids[0].unsqueeze(0),
337
+ input_ids[i:i + self.tokenizer.model_max_length - 2],
338
+ input_ids[-1].unsqueeze(0))
339
+ ids_chunk = torch.cat(ids_chunk)
340
+ iids_list.append(ids_chunk)
341
+ else:
342
+ # v2
343
+ # for inputs longer than 77 tokens the result is "<BOS> .... <EOS> <PAD> <PAD>..." totalling e.g. 227 tokens, so convert it into three "<BOS>...<EOS> <PAD> <PAD> ..." chunks
344
+ for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
345
+ ids_chunk = (input_ids[0].unsqueeze(0), # BOS
346
+ input_ids[i:i + self.tokenizer.model_max_length - 2],
347
+ input_ids[-1].unsqueeze(0)) # PAD or EOS
348
+ ids_chunk = torch.cat(ids_chunk)
349
+
350
+ # if the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing needs to be done
351
+ # if it ends with x <PAD/EOS>, replace the last token with <EOS> (no effective change if it was already x <EOS>)
352
+ if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
353
+ ids_chunk[-1] = self.tokenizer.eos_token_id
354
+ # if the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
355
+ if ids_chunk[1] == self.tokenizer.pad_token_id:
356
+ ids_chunk[1] = self.tokenizer.eos_token_id
357
+
358
+ iids_list.append(ids_chunk)
359
+
360
+ input_ids = torch.stack(iids_list) # 3,77
361
+ return input_ids
362
+
363
+ def register_image(self, info: ImageInfo):
364
+ self.image_data[info.image_key] = info
365
+
366
+ def make_buckets(self):
367
+ '''
368
+ must be called even when bucketing is disabled (a single bucket is created)
369
+ min_size and max_size are ignored when enable_bucket is False
370
+ '''
371
+ print("loading image sizes.")
372
+ for info in tqdm(self.image_data.values()):
373
+ if info.image_size is None:
374
+ info.image_size = self.get_image_size(info.absolute_path)
375
+
376
+ if self.enable_bucket:
377
+ print("make buckets")
378
+ else:
379
+ print("prepare dataset")
380
+
381
+ # create buckets and assign each image to one
382
+ if self.enable_bucket:
383
+ if self.bucket_manager is None:  # already initialized when fine tuning with bucket info defined in the metadata
384
+ self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height),
385
+ self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
386
+ if not self.bucket_no_upscale:
387
+ self.bucket_manager.make_buckets()
388
+ else:
389
+ print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
390
+
391
+ img_ar_errors = []
392
+ for image_info in self.image_data.values():
393
+ image_width, image_height = image_info.image_size
394
+ image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)
395
+
396
+ # print(image_info.image_key, image_info.bucket_reso)
397
+ img_ar_errors.append(abs(ar_error))
398
+
399
+ self.bucket_manager.sort()
400
+ else:
401
+ self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
402
+ self.bucket_manager.set_predefined_resos([(self.width, self.height)])  # a single fixed-size bucket only
403
+ for image_info in self.image_data.values():
404
+ image_width, image_height = image_info.image_size
405
+ image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
406
+
407
+ for image_info in self.image_data.values():
408
+ for _ in range(image_info.num_repeats):
409
+ self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
410
+
411
+ # display and store bucket info
412
+ if self.enable_bucket:
413
+ self.bucket_info = {"buckets": {}}
414
+ print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
415
+ for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
416
+ count = len(bucket)
417
+ if count > 0:
418
+ self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
419
+ print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
420
+
421
+ img_ar_errors = np.array(img_ar_errors)
422
+ mean_img_ar_error = np.mean(np.abs(img_ar_errors))
423
+ self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
424
+ print(f"mean ar error (without repeats): {mean_img_ar_error}")
425
+
426
+ # build the index used to look up data; this index is what gets shuffled
427
+ self.buckets_indices: List[BucketBatchIndex] = []
428
+ for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
429
+ # as buckets get more fine-grained, more of them end up containing only one kind of image, which would mean
430
+ # a whole batch being filled with the same image, which is clearly undesirable
431
+ # so limit the batch size to the number of distinct images in the bucket
432
+ # the same image can still appear twice in one batch, so fewer repeats should give better shuffle quality(?)
433
+ # TODO a mechanism to carry regularization images across epochs
434
+ num_of_image_types = len(set(bucket))
435
+ bucket_batch_size = min(self.batch_size, num_of_image_types)
436
+ batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
437
+ # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
438
+ for batch_index in range(batch_count):
439
+ self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
440
+
441
+ self.shuffle_buckets()
442
+ self._length = len(self.buckets_indices)
443
+
444
+ def shuffle_buckets(self):
445
+ random.shuffle(self.buckets_indices)
446
+ self.bucket_manager.shuffle()
447
+
448
+ def load_image(self, image_path):
449
+ image = Image.open(image_path)
450
+ if not image.mode == "RGB":
451
+ image = image.convert("RGB")
452
+ img = np.array(image, np.uint8)
453
+ return img
454
+
455
+ def trim_and_resize_if_required(self, image, reso, resized_size):
456
+ image_height, image_width = image.shape[0:2]
457
+
458
+ if image_width != resized_size[0] or image_height != resized_size[1]:
459
+ # リサイズする
460
+ image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # use cv2 because we want INTER_AREA interpolation
461
+
462
+ image_height, image_width = image.shape[0:2]
463
+ if image_width > reso[0]:
464
+ trim_size = image_width - reso[0]
465
+ p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
466
+ # print("w", trim_size, p)
467
+ image = image[:, p:p + reso[0]]
468
+ if image_height > reso[1]:
469
+ trim_size = image_height - reso[1]
470
+ p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
471
+ # print("h", trim_size, p)
472
+ image = image[p:p + reso[1]]
473
+
474
+ assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
475
+ return image
476
+
477
+ def cache_latents(self, vae):
478
+ # TODO speed this up
479
+ print("caching latents.")
480
+ for info in tqdm(self.image_data.values()):
481
+ if info.latents_npz is not None:
482
+ info.latents = self.load_latents_from_npz(info, False)
483
+ info.latents = torch.FloatTensor(info.latents)
484
+ info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
485
+ if info.latents_flipped is not None:
486
+ info.latents_flipped = torch.FloatTensor(info.latents_flipped)
487
+ continue
488
+
489
+ image = self.load_image(info.absolute_path)
490
+ image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
491
+
492
+ img_tensor = self.image_transforms(image)
493
+ img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
494
+ info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
495
+
496
+ if self.flip_aug:
497
+ image = image[:, ::-1].copy() # cannot convert to Tensor without copy
498
+ img_tensor = self.image_transforms(image)
499
+ img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
500
+ info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
501
+
502
+ def get_image_size(self, image_path):
503
+ image = Image.open(image_path)
504
+ return image.size
505
+
506
+ def load_image_with_face_info(self, image_path: str):
507
+ img = self.load_image(image_path)
508
+
509
+ face_cx = face_cy = face_w = face_h = 0
510
+ if self.face_crop_aug_range is not None:
511
+ tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
512
+ if len(tokens) >= 5:
513
+ face_cx = int(tokens[-4])
514
+ face_cy = int(tokens[-3])
515
+ face_w = int(tokens[-2])
516
+ face_h = int(tokens[-1])
517
+
518
+ return img, face_cx, face_cy, face_w, face_h
519
+
520
+ # crop a nicely framed region around the face
521
+ def crop_target(self, image, face_cx, face_cy, face_w, face_h):
522
+ height, width = image.shape[0:2]
523
+ if height == self.height and width == self.width:
524
+ return image
525
+
526
+ # the image is larger than size, so resize it
527
+ face_size = max(face_w, face_h)
528
+ min_scale = max(self.height / height, self.width / width)  # the smallest scale at which the image exactly covers the model input size
529
+ min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1])))  # enforce the specified minimum face size
530
+ max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0])))  # enforce the specified maximum face size
531
+ if min_scale >= max_scale:  # the specified range has min == max
532
+ scale = min_scale
533
+ else:
534
+ scale = random.uniform(min_scale, max_scale)
535
+
536
+ nh = int(height * scale + .5)
537
+ nw = int(width * scale + .5)
538
+ assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
539
+ image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
540
+ face_cx = int(face_cx * scale + .5)
541
+ face_cy = int(face_cy * scale + .5)
542
+ height, width = nh, nw
543
+
544
+ # crop to e.g. 448*640 with the face at the center
545
+ for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
546
+ p1 = face_p - target_size // 2  # crop offset that puts the face at the center
547
+
548
+ if self.random_crop:
549
+ # shift the crop while keeping a high chance of centering the face, so some background is included
550
+ range = max(length - face_p, face_p)  # the longer distance from the face center to an image edge
551
+ p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range  # a reasonable random offset within -range ~ +range
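+ # the sum of two uniform draws gives a triangular distribution around zero, so the face usually stays near the center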
552
+ else:
553
+ # only when a range is specified, add a little randomness (fairly ad hoc)
554
+ if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
555
+ if face_size > self.size // 10 and face_size >= 40:
556
+ p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
557
+
558
+ p1 = max(0, min(p1, length - target_size))
559
+
560
+ if axis == 0:
561
+ image = image[p1:p1 + target_size, :]
562
+ else:
563
+ image = image[:, p1:p1 + target_size]
564
+
565
+ return image
566
+
567
+ def load_latents_from_npz(self, image_info: ImageInfo, flipped):
568
+ npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
569
+ if npz_file is None:
570
+ return None
571
+ return np.load(npz_file)['arr_0']
572
+
573
+ def __len__(self):
574
+ return self._length
575
+
576
+ def __getitem__(self, index):
577
+ if index == 0:
578
+ self.shuffle_buckets()
579
+
580
+ bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
581
+ bucket_batch_size = self.buckets_indices[index].bucket_batch_size
582
+ image_index = self.buckets_indices[index].batch_index * bucket_batch_size
583
+
584
+ loss_weights = []
585
+ captions = []
586
+ input_ids_list = []
587
+ latents_list = []
588
+ images = []
589
+
590
+ for image_key in bucket[image_index:image_index + bucket_batch_size]:
591
+ image_info = self.image_data[image_key]
592
+ loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
593
+
594
+ # image/latentsを処理する
595
+ if image_info.latents is not None:
596
+ latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
597
+ image = None
598
+ elif image_info.latents_npz is not None:
599
+ latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
600
+ latents = torch.FloatTensor(latents)
601
+ image = None
602
+ else:
603
+ # load the image and crop it if necessary
604
+ img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
605
+ im_h, im_w = img.shape[0:2]
606
+
607
+ if self.enable_bucket:
608
+ img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
609
+ else:
610
+ if face_cx > 0: # 顔位置情報あり
611
+ img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
612
+ elif im_h > self.height or im_w > self.width:
613
+ assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
614
+ if im_h > self.height:
615
+ p = random.randint(0, im_h - self.height)
616
+ img = img[p:p + self.height]
617
+ if im_w > self.width:
618
+ p = random.randint(0, im_w - self.width)
619
+ img = img[:, p:p + self.width]
620
+
621
+ im_h, im_w = img.shape[0:2]
622
+ assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
623
+
624
+ # augmentation
625
+ if self.aug is not None:
626
+ img = self.aug(image=img)['image']
627
+
628
+ latents = None
629
+ image = self.image_transforms(img)  # becomes a torch.Tensor in the range -1.0 to 1.0
630
+
631
+ images.append(image)
632
+ latents_list.append(latents)
633
+
634
+ caption = self.process_caption(image_info.caption)
635
+ captions.append(caption)
636
+ if not self.token_padding_disabled: # this option might be omitted in future
637
+ input_ids_list.append(self.get_input_ids(caption))
638
+
639
+ example = {}
640
+ example['loss_weights'] = torch.FloatTensor(loss_weights)
641
+
642
+ if self.token_padding_disabled:
643
+ # padding=True means pad in the batch
644
+ example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
645
+ else:
646
+ # batch processing seems to be good
647
+ example['input_ids'] = torch.stack(input_ids_list)
648
+
649
+ if images[0] is not None:
650
+ images = torch.stack(images)
651
+ images = images.to(memory_format=torch.contiguous_format).float()
652
+ else:
653
+ images = None
654
+ example['images'] = images
655
+
656
+ example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
657
+
658
+ if self.debug_dataset:
659
+ example['image_keys'] = bucket[image_index:image_index + self.batch_size]
660
+ example['captions'] = captions
661
+ return example
662
+
663
+
664
+ class DreamBoothDataset(BaseDataset):
665
+ def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
666
+ super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
667
+ resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
668
+
669
+ assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
670
+
671
+ self.batch_size = batch_size
672
+ self.size = min(self.width, self.height) # 短いほう
673
+ self.prior_loss_weight = prior_loss_weight
674
+ self.latents_cache = None
675
+
676
+ self.enable_bucket = enable_bucket
677
+ if self.enable_bucket:
678
+ assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
679
+ assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
680
+ self.min_bucket_reso = min_bucket_reso
681
+ self.max_bucket_reso = max_bucket_reso
682
+ self.bucket_reso_steps = bucket_reso_steps
683
+ self.bucket_no_upscale = bucket_no_upscale
684
+ else:
685
+ self.min_bucket_reso = None
686
+ self.max_bucket_reso = None
687
+ self.bucket_reso_steps = None # この情報は使われない
688
+ self.bucket_no_upscale = False
689
+
690
+ def read_caption(img_path):
691
+ # captionの候補ファイル名を作る
692
+ base_name = os.path.splitext(img_path)[0]
693
+ base_name_face_det = base_name
694
+ tokens = base_name.split("_")
695
+ if len(tokens) >= 5:
696
+ base_name_face_det = "_".join(tokens[:-4])
697
+ cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
698
+
699
+ caption = None
700
+ for cap_path in cap_paths:
701
+ if os.path.isfile(cap_path):
702
+ with open(cap_path, "rt", encoding='utf-8') as f:
703
+ try:
704
+ lines = f.readlines()
705
+ except UnicodeDecodeError as e:
706
+ print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
707
+ raise e
708
+ assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
709
+ caption = lines[0].strip()
710
+ break
711
+ return caption
712
+
713
+ def load_dreambooth_dir(dir):
714
+ if not os.path.isdir(dir):
715
+ # print(f"ignore file: {dir}")
716
+ return 0, [], []
717
+
718
+ tokens = os.path.basename(dir).split('_')
719
+ try:
720
+ n_repeats = int(tokens[0])
721
+ except ValueError as e:
722
+ print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
723
+ return 0, [], []
724
+
725
+ caption_by_folder = '_'.join(tokens[1:])
726
+ img_paths = glob_images(dir, "*")
727
+ print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
728
+
729
+ # read a per-image prompt file and prefer it over the folder caption when present
730
+ captions = []
731
+ for img_path in img_paths:
732
+ cap_for_img = read_caption(img_path)
733
+ captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
734
+
735
+ self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
736
+
737
+ return n_repeats, img_paths, captions
738
+
739
+ print("prepare train images.")
740
+ train_dirs = os.listdir(train_data_dir)
741
+ num_train_images = 0
742
+ for dir in train_dirs:
743
+ n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
744
+ num_train_images += n_repeats * len(img_paths)
745
+
746
+ for img_path, caption in zip(img_paths, captions):
747
+ info = ImageInfo(img_path, n_repeats, caption, False, img_path)
748
+ self.register_image(info)
749
+
750
+ self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
751
+
752
+ print(f"{num_train_images} train images with repeating.")
753
+ self.num_train_images = num_train_images
754
+
755
+ # count the reg images and match their effective number to the training images
756
+ num_reg_images = 0
757
+ if reg_data_dir:
758
+ print("prepare reg images.")
759
+ reg_infos: List[ImageInfo] = []
760
+
761
+ reg_dirs = os.listdir(reg_data_dir)
762
+ for dir in reg_dirs:
763
+ n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
764
+ num_reg_images += n_repeats * len(img_paths)
765
+
766
+ for img_path, caption in zip(img_paths, captions):
767
+ info = ImageInfo(img_path, n_repeats, caption, True, img_path)
768
+ reg_infos.append(info)
769
+
770
+ self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
771
+
772
+ print(f"{num_reg_images} reg images.")
773
+ if num_train_images < num_reg_images:
774
+ print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
775
+
776
+ if num_reg_images == 0:
777
+ print("no regularization images / 正則化画像が見つかりませんでした")
778
+ else:
779
+ # compute num_repeats; the counts are small anyway, so a simple loop is fine
780
+ n = 0
781
+ first_loop = True
782
+ while n < num_train_images:
783
+ for info in reg_infos:
784
+ if first_loop:
785
+ self.register_image(info)
786
+ n += info.num_repeats
787
+ else:
788
+ info.num_repeats += 1
789
+ n += 1
790
+ if n >= num_train_images:
791
+ break
792
+ first_loop = False
793
+
794
+ self.num_reg_images = num_reg_images
795
+
796
+
797
+ class FineTuningDataset(BaseDataset):
798
+ def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
799
+ super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
800
+ resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
801
+
802
+ # メタデータを読み込む
803
+ if os.path.exists(json_file_name):
804
+ print(f"loading existing metadata: {json_file_name}")
805
+ with open(json_file_name, "rt", encoding='utf-8') as f:
806
+ metadata = json.load(f)
807
+ else:
808
+ raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
809
+
810
+ self.metadata = metadata
811
+ self.train_data_dir = train_data_dir
812
+ self.batch_size = batch_size
813
+
814
+ tags_list = []
815
+ for image_key, img_md in metadata.items():
816
+ # path情報を作る
817
+ if os.path.exists(image_key):
818
+ abs_path = image_key
819
+ else:
820
+ # fairly sloppy, but no better approach comes to mind
821
+ abs_path = glob_images(train_data_dir, image_key)
822
+ assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
823
+ abs_path = abs_path[0]
824
+
825
+ caption = img_md.get('caption')
826
+ tags = img_md.get('tags')
827
+ if caption is None:
828
+ caption = tags
829
+ elif tags is not None and len(tags) > 0:
830
+ caption = caption + ', ' + tags
831
+ tags_list.append(tags)
832
+ assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
833
+
834
+ image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
835
+ image_info.image_size = img_md.get('train_resolution')
836
+
837
+ if not self.color_aug and not self.random_crop:
838
+ # if npz exists, use them
839
+ image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
840
+
841
+ self.register_image(image_info)
842
+ self.num_train_images = len(metadata) * dataset_repeats
843
+ self.num_reg_images = 0
844
+
845
+ self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
846
+ self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
847
+
848
+ # check existence of all npz files
849
+ use_npz_latents = not (self.color_aug or self.random_crop)
850
+ if use_npz_latents:
851
+ npz_any = False
852
+ npz_all = True
853
+ for image_info in self.image_data.values():
854
+ has_npz = image_info.latents_npz is not None
855
+ npz_any = npz_any or has_npz
856
+
857
+ if self.flip_aug:
858
+ has_npz = has_npz and image_info.latents_npz_flipped is not None
859
+ npz_all = npz_all and has_npz
860
+
861
+ if npz_any and not npz_all:
862
+ break
863
+
864
+ if not npz_any:
865
+ use_npz_latents = False
866
+ print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
867
+ elif not npz_all:
868
+ use_npz_latents = False
869
+ print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
870
+ if self.flip_aug:
871
+ print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
872
+ # else:
873
+ # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
874
+
875
+ # check min/max bucket size
876
+ sizes = set()
877
+ resos = set()
878
+ for image_info in self.image_data.values():
879
+ if image_info.image_size is None:
880
+ sizes = None # not calculated
881
+ break
882
+ sizes.add(image_info.image_size[0])
883
+ sizes.add(image_info.image_size[1])
884
+ resos.add(tuple(image_info.image_size))
885
+
886
+ if sizes is None:
887
+ if use_npz_latents:
888
+ use_npz_latents = False
889
+ print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")
890
+
891
+ assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
892
+
893
+ self.enable_bucket = enable_bucket
894
+ if self.enable_bucket:
895
+ self.min_bucket_reso = min_bucket_reso
896
+ self.max_bucket_reso = max_bucket_reso
897
+ self.bucket_reso_steps = bucket_reso_steps
898
+ self.bucket_no_upscale = bucket_no_upscale
899
+ else:
900
+ if not enable_bucket:
901
+ print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
902
+ print("using bucket info in metadata / メタデータ内のbucket情報を使います")
903
+ self.enable_bucket = True
904
+
905
+ assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
906
+
907
+ # initialize bucket info here; it is not recreated in make_buckets
908
+ self.bucket_manager = BucketManager(False, None, None, None, None)
909
+ self.bucket_manager.set_predefined_resos(resos)
910
+
911
+ # npz情報をきれいにしておく
912
+ if not use_npz_latents:
913
+ for image_info in self.image_data.values():
914
+ image_info.latents_npz = image_info.latents_npz_flipped = None
915
+
916
+ def image_key_to_npz_file(self, image_key):
917
+ base_name = os.path.splitext(image_key)[0]
918
+ npz_file_norm = base_name + '.npz'
919
+
920
+ if os.path.exists(npz_file_norm):
921
+ # image_key is full path
922
+ npz_file_flip = base_name + '_flip.npz'
923
+ if not os.path.exists(npz_file_flip):
924
+ npz_file_flip = None
925
+ return npz_file_norm, npz_file_flip
926
+
927
+ # image_key is relative path
928
+ npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
929
+ npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
930
+
931
+ if not os.path.exists(npz_file_norm):
932
+ npz_file_norm = None
933
+ npz_file_flip = None
934
+ elif not os.path.exists(npz_file_flip):
935
+ npz_file_flip = None
936
+
937
+ return npz_file_norm, npz_file_flip
938
+
939
+
940
+ def debug_dataset(train_dataset, show_input_ids=False):
941
+ print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
942
+ print("Escape for exit. / Escキーで中断、終了します")
943
+
944
+ train_dataset.set_current_epoch(1)
945
+ k = 0
946
+ for i, example in enumerate(train_dataset):
947
+ if example['latents'] is not None:
948
+ print(f"sample has latents from npz file: {example['latents'].size()}")
949
+ for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
950
+ print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
951
+ if show_input_ids:
952
+ print(f"input ids: {iid}")
953
+ if example['images'] is not None:
954
+ im = example['images'][j]
955
+ print(f"image size: {im.size()}")
956
+ im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
957
+ im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
958
+ im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
959
+ if os.name == 'nt': # only windows
960
+ cv2.imshow("img", im)
961
+ k = cv2.waitKey()
962
+ cv2.destroyAllWindows()
963
+ if k == 27:
964
+ break
965
+ if k == 27 or (example['images'] is None and i >= 8):
966
+ break
967
+
968
+
969
+ def glob_images(directory, base="*"):
970
+ img_paths = []
971
+ for ext in IMAGE_EXTENSIONS:
972
+ if base == '*':
973
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
974
+ else:
975
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
976
+ # img_paths = list(set(img_paths)) # 重複を排除
977
+ # img_paths.sort()
978
+ return img_paths
979
+
980
+
981
+ def glob_images_pathlib(dir_path, recursive):
982
+ image_paths = []
983
+ if recursive:
984
+ for ext in IMAGE_EXTENSIONS:
985
+ image_paths += list(dir_path.rglob('*' + ext))
986
+ else:
987
+ for ext in IMAGE_EXTENSIONS:
988
+ image_paths += list(dir_path.glob('*' + ext))
989
+ # image_paths = list(set(image_paths)) # 重複を排除
990
+ # image_paths.sort()
991
+ return image_paths
992
+
993
+ # endregion
994
+
995
+
996
+ # region module replacement
997
+ """
998
+ Module replacements for speedups
999
+ """
1000
+
1001
+ # CrossAttention using FlashAttention
1002
+ # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
1003
+ # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
1004
+
1005
+ # constants
1006
+
1007
+ EPSILON = 1e-6
1008
+
1009
+ # helper functions
1010
+
1011
+
1012
+ def exists(val):
1013
+ return val is not None
1014
+
1015
+
1016
+ def default(val, d):
1017
+ return val if exists(val) else d
1018
+
1019
+
1020
+ def model_hash(filename):
1021
+ """Old model hash used by stable-diffusion-webui"""
1022
+ try:
1023
+ with open(filename, "rb") as file:
1024
+ m = hashlib.sha256()
1025
+
1026
+ file.seek(0x100000)
1027
+ m.update(file.read(0x10000))
1028
+ return m.hexdigest()[0:8]
1029
+ except FileNotFoundError:
1030
+ return 'NOFILE'
1031
+
1032
+
1033
+ def calculate_sha256(filename):
1034
+ """New model hash used by stable-diffusion-webui"""
1035
+ hash_sha256 = hashlib.sha256()
1036
+ blksize = 1024 * 1024
1037
+
1038
+ with open(filename, "rb") as f:
1039
+ for chunk in iter(lambda: f.read(blksize), b""):
1040
+ hash_sha256.update(chunk)
1041
+
1042
+ return hash_sha256.hexdigest()
1043
+
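A short usage sketch of the two hash helpers above (the file name is illustrative): the legacy hash covers only the 0x10000 bytes read at offset 0x100000, while calculate_sha256 hashes the whole file.

legacy_hash = model_hash("model.safetensors")       # 8 hex chars, matches webui's old hash
full_hash = calculate_sha256("model.safetensors")   # full SHA-256 of the entire file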
1044
+
1045
+ def precalculate_safetensors_hashes(tensors, metadata):
1046
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
1047
+ save time on indexing the model later."""
1048
+
1049
+ # Because writing user metadata to the file can change the result of
1050
+ # sd_models.model_hash(), only retain the training metadata for purposes of
1051
+ # calculating the hash, as they are meant to be immutable
1052
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
1053
+
1054
+ bytes = safetensors.torch.save(tensors, metadata)
1055
+ b = BytesIO(bytes)
1056
+
1057
+ model_hash = addnet_hash_safetensors(b)
1058
+ legacy_hash = addnet_hash_legacy(b)
1059
+ return model_hash, legacy_hash
1060
+
1061
+
1062
+ def addnet_hash_legacy(b):
1063
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
1064
+ m = hashlib.sha256()
1065
+
1066
+ b.seek(0x100000)
1067
+ m.update(b.read(0x10000))
1068
+ return m.hexdigest()[0:8]
1069
+
1070
+
1071
+ def addnet_hash_safetensors(b):
1072
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
1073
+ hash_sha256 = hashlib.sha256()
1074
+ blksize = 1024 * 1024
1075
+
1076
+ b.seek(0)
1077
+ header = b.read(8)
1078
+ n = int.from_bytes(header, "little")
1079
+
1080
+ offset = n + 8
1081
+ b.seek(offset)
1082
+ for chunk in iter(lambda: b.read(blksize), b""):
1083
+ hash_sha256.update(chunk)
1084
+
1085
+ return hash_sha256.hexdigest()
1086
+
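The safetensors-specific hash above relies on the file layout: the first 8 bytes hold the little-endian length n of the JSON header, and tensor data starts at offset n + 8. Hashing from that offset means later edits to the "__metadata__" entry do not change the hash. A small sketch of reading the header the same way (the file name is illustrative):

import json

with open("lora.safetensors", "rb") as f:
    n = int.from_bytes(f.read(8), "little")   # length of the JSON header
    header = json.loads(f.read(n))            # tensor index plus optional "__metadata__"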
1087
+
1088
+ # flash attention forwards and backwards
1089
+
1090
+ # https://arxiv.org/abs/2205.14135
1091
+
1092
+
1093
+ class FlashAttentionFunction(torch.autograd.function.Function):
1094
+ @staticmethod
1095
+ @torch.no_grad()
1096
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
1097
+ """ Algorithm 2 in the paper """
1098
+
1099
+ device = q.device
1100
+ dtype = q.dtype
1101
+ max_neg_value = -torch.finfo(q.dtype).max
1102
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
1103
+
1104
+ o = torch.zeros_like(q)
1105
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
1106
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
1107
+
1108
+ scale = (q.shape[-1] ** -0.5)
1109
+
1110
+ if not exists(mask):
1111
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
1112
+ else:
1113
+ mask = rearrange(mask, 'b n -> b 1 1 n')
1114
+ mask = mask.split(q_bucket_size, dim=-1)
1115
+
1116
+ row_splits = zip(
1117
+ q.split(q_bucket_size, dim=-2),
1118
+ o.split(q_bucket_size, dim=-2),
1119
+ mask,
1120
+ all_row_sums.split(q_bucket_size, dim=-2),
1121
+ all_row_maxes.split(q_bucket_size, dim=-2),
1122
+ )
1123
+
1124
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
1125
+ q_start_index = ind * q_bucket_size - qk_len_diff
1126
+
1127
+ col_splits = zip(
1128
+ k.split(k_bucket_size, dim=-2),
1129
+ v.split(k_bucket_size, dim=-2),
1130
+ )
1131
+
1132
+ for k_ind, (kc, vc) in enumerate(col_splits):
1133
+ k_start_index = k_ind * k_bucket_size
1134
+
1135
+ attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
1136
+
1137
+ if exists(row_mask):
1138
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
1139
+
1140
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
1141
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
1142
+ device=device).triu(q_start_index - k_start_index + 1)
1143
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
1144
+
1145
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
1146
+ attn_weights -= block_row_maxes
1147
+ exp_weights = torch.exp(attn_weights)
1148
+
1149
+ if exists(row_mask):
1150
+ exp_weights.masked_fill_(~row_mask, 0.)
1151
+
1152
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
1153
+
1154
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
1155
+
1156
+ exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
1157
+
1158
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
1159
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
1160
+
1161
+ new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
1162
+
1163
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
1164
+
1165
+ row_maxes.copy_(new_row_maxes)
1166
+ row_sums.copy_(new_row_sums)
1167
+
1168
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
1169
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
1170
+
1171
+ return o
1172
+
1173
+ @staticmethod
1174
+ @torch.no_grad()
1175
+ def backward(ctx, do):
1176
+ """ Algorithm 4 in the paper """
1177
+
1178
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
1179
+ q, k, v, o, l, m = ctx.saved_tensors
1180
+
1181
+ device = q.device
1182
+
1183
+ max_neg_value = -torch.finfo(q.dtype).max
1184
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
1185
+
1186
+ dq = torch.zeros_like(q)
1187
+ dk = torch.zeros_like(k)
1188
+ dv = torch.zeros_like(v)
1189
+
1190
+ row_splits = zip(
1191
+ q.split(q_bucket_size, dim=-2),
1192
+ o.split(q_bucket_size, dim=-2),
1193
+ do.split(q_bucket_size, dim=-2),
1194
+ mask,
1195
+ l.split(q_bucket_size, dim=-2),
1196
+ m.split(q_bucket_size, dim=-2),
1197
+ dq.split(q_bucket_size, dim=-2)
1198
+ )
1199
+
1200
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
1201
+ q_start_index = ind * q_bucket_size - qk_len_diff
1202
+
1203
+ col_splits = zip(
1204
+ k.split(k_bucket_size, dim=-2),
1205
+ v.split(k_bucket_size, dim=-2),
1206
+ dk.split(k_bucket_size, dim=-2),
1207
+ dv.split(k_bucket_size, dim=-2),
1208
+ )
1209
+
1210
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
1211
+ k_start_index = k_ind * k_bucket_size
1212
+
1213
+ attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
1214
+
1215
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
1216
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
1217
+ device=device).triu(q_start_index - k_start_index + 1)
1218
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
1219
+
1220
+ exp_attn_weights = torch.exp(attn_weights - mc)
1221
+
1222
+ if exists(row_mask):
1223
+ exp_attn_weights.masked_fill_(~row_mask, 0.)
1224
+
1225
+ p = exp_attn_weights / lc
1226
+
1227
+ dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
1228
+ dp = einsum('... i d, ... j d -> ... i j', doc, vc)
1229
+
1230
+ D = (doc * oc).sum(dim=-1, keepdims=True)
1231
+ ds = p * scale * (dp - D)
1232
+
1233
+ dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
1234
+ dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
1235
+
1236
+ dqc.add_(dq_chunk)
1237
+ dkc.add_(dk_chunk)
1238
+ dvc.add_(dv_chunk)
1239
+
1240
+ return dq, dk, dv, None, None, None, None
1241
+
1242
+
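A minimal forward-pass sketch for the autograd function above; shapes are illustrative. Inputs are expected as (batch, heads, sequence, dim_head), matching the rearrange done in forward_flash_attn below, and are processed in q/k buckets to bound peak memory.

q = torch.randn(1, 8, 4096, 40)
k = torch.randn(1, 8, 77, 40)
v = torch.randn(1, 8, 77, 40)
out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)  # mask=None, causal=False
print(out.shape)  # torch.Size([1, 8, 4096, 40])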
1243
+ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
1244
+ if mem_eff_attn:
1245
+ replace_unet_cross_attn_to_memory_efficient()
1246
+ elif xformers:
1247
+ replace_unet_cross_attn_to_xformers()
1248
+
1249
+
1250
+ def replace_unet_cross_attn_to_memory_efficient():
1251
+ print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
1252
+ flash_func = FlashAttentionFunction
1253
+
1254
+ def forward_flash_attn(self, x, context=None, mask=None):
1255
+ q_bucket_size = 512
1256
+ k_bucket_size = 1024
1257
+
1258
+ h = self.heads
1259
+ q = self.to_q(x)
1260
+
1261
+ context = context if context is not None else x
1262
+ context = context.to(x.dtype)
1263
+
1264
+ if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
1265
+ context_k, context_v = self.hypernetwork.forward(x, context)
1266
+ context_k = context_k.to(x.dtype)
1267
+ context_v = context_v.to(x.dtype)
1268
+ else:
1269
+ context_k = context
1270
+ context_v = context
1271
+
1272
+ k = self.to_k(context_k)
1273
+ v = self.to_v(context_v)
1274
+ del context, x
1275
+
1276
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
1277
+
1278
+ out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
1279
+
1280
+ out = rearrange(out, 'b h n d -> b n (h d)')
1281
+
1282
+ # diffusers 0.7.0+ changed to_out to a ModuleList (linear then dropout), so apply both parts explicitly
1283
+ out = self.to_out[0](out)
1284
+ out = self.to_out[1](out)
1285
+ return out
1286
+
1287
+ diffusers.models.attention.CrossAttention.forward = forward_flash_attn
1288
+
1289
+
1290
+ def replace_unet_cross_attn_to_xformers():
1291
+ print("Replace CrossAttention.forward to use xformers")
1292
+ try:
1293
+ import xformers.ops
1294
+ except ImportError:
1295
+ raise ImportError("No xformers / xformersがインストールされていないようです")
1296
+
1297
+ def forward_xformers(self, x, context=None, mask=None):
1298
+ h = self.heads
1299
+ q_in = self.to_q(x)
1300
+
1301
+ context = default(context, x)
1302
+ context = context.to(x.dtype)
1303
+
1304
+ if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
1305
+ context_k, context_v = self.hypernetwork.forward(x, context)
1306
+ context_k = context_k.to(x.dtype)
1307
+ context_v = context_v.to(x.dtype)
1308
+ else:
1309
+ context_k = context
1310
+ context_v = context
1311
+
1312
+ k_in = self.to_k(context_k)
1313
+ v_in = self.to_v(context_v)
1314
+
1315
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
1316
+ del q_in, k_in, v_in
1317
+
1318
+ q = q.contiguous()
1319
+ k = k.contiguous()
1320
+ v = v.contiguous()
1321
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # xformers picks the best available kernel
1322
+
1323
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
1324
+
1325
+ # diffusers 0.7.0+: to_out is a ModuleList (linear then dropout)
1326
+ out = self.to_out[0](out)
1327
+ out = self.to_out[1](out)
1328
+ return out
1329
+
1330
+ diffusers.models.attention.CrossAttention.forward = forward_xformers
1331
+ # endregion
1332
+
1333
+
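A usage sketch for the region above: the replacement is selected once before training from the corresponding CLI flags. Note that both paths monkey-patch diffusers.models.attention.CrossAttention.forward globally, so the unet argument is not actually used by the code shown here.

replace_unet_modules(unet, mem_eff_attn=args.mem_eff_attn, xformers=args.xformers)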
1334
+ # region arguments
1335
+
1336
+ def add_sd_models_arguments(parser: argparse.ArgumentParser):
1337
+ # for pretrained models
1338
+ parser.add_argument("--v2", action='store_true',
1339
+ help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
1340
+ parser.add_argument("--v_parameterization", action='store_true',
1341
+ help='enable v-parameterization training / v-parameterization学習を有効にする')
1342
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
1343
+ help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1344
+
1345
+
1346
+ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
1347
+ parser.add_argument("--output_dir", type=str, default=None,
1348
+ help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
1349
+ parser.add_argument("--output_name", type=str, default=None,
1350
+ help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
1351
+ parser.add_argument("--save_precision", type=str, default=None,
1352
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
1353
+ parser.add_argument("--save_every_n_epochs", type=int, default=None,
1354
+ help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
1355
+ parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
1356
+ help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
1357
+ parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
1358
+ parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
1359
+ help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
1360
+ parser.add_argument("--save_state", action="store_true",
1361
+ help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
1362
+ parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
1363
+
1364
+ parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
1365
+ parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
1366
+ help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
1367
+ parser.add_argument("--use_8bit_adam", action="store_true",
1368
+ help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
1369
+ parser.add_argument("--mem_eff_attn", action="store_true",
1370
+ help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
1371
+ parser.add_argument("--xformers", action="store_true",
1372
+ help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
1373
+ parser.add_argument("--vae", type=str, default=None,
1374
+ help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
1375
+
1376
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1377
+ parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1378
+ parser.add_argument("--max_train_epochs", type=int, default=None,
1379
+ help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
1380
+ parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
1381
+ help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
1382
+ parser.add_argument("--persistent_data_loader_workers", action="store_true",
1383
+ help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
1384
+ parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
1385
+ parser.add_argument("--gradient_checkpointing", action="store_true",
1386
+ help="enable gradient checkpointing / grandient checkpointingを有効にする")
1387
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
1388
+ help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
1389
+ parser.add_argument("--mixed_precision", type=str, default="no",
1390
+ choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
1391
+ parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
1392
+ parser.add_argument("--clip_skip", type=int, default=None,
1393
+ help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
1394
+ parser.add_argument("--logging_dir", type=str, default=None,
1395
+ help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
1396
+ parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
1397
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
1398
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
1399
+ parser.add_argument("--lr_warmup_steps", type=int, default=0,
1400
+ help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
1401
+
1402
+ if support_dreambooth:
1403
+ # DreamBooth training
1404
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0,
1405
+ help="loss weight for regularization images / 正則化画像のlossの重み")
1406
+
1407
+
1408
+ def verify_training_args(args: argparse.Namespace):
1409
+ if args.v_parameterization and not args.v2:
1410
+ print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
1411
+ if args.v2 and args.clip_skip is not None:
1412
+ print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
1413
+
1414
+
1415
+ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
1416
+ # dataset common
1417
+ parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
1418
+ parser.add_argument("--shuffle_caption", action="store_true",
1419
+ help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
1420
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
1421
+ parser.add_argument("--caption_extention", type=str, default=None,
1422
+ help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
1423
+ parser.add_argument("--keep_tokens", type=int, default=None,
1424
+ help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
1425
+ parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
1426
+ parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
1427
+ parser.add_argument("--face_crop_aug_range", type=str, default=None,
1428
+ help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
1429
+ parser.add_argument("--random_crop", action="store_true",
1430
+ help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
1431
+ parser.add_argument("--debug_dataset", action="store_true",
1432
+ help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
1433
+ parser.add_argument("--resolution", type=str, default=None,
1434
+ help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
1435
+ parser.add_argument("--cache_latents", action="store_true",
1436
+ help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
1437
+ parser.add_argument("--enable_bucket", action="store_true",
1438
+ help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
1439
+ parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
1440
+ parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
1441
+ parser.add_argument("--bucket_reso_steps", type=int, default=64,
1442
+ help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
1443
+ parser.add_argument("--bucket_no_upscale", action="store_true",
1444
+ help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
1445
+
1446
+ if support_caption_dropout:
1447
+ # Textual Inversion does not support caption dropout
1448
+ # prefixed with "caption" to avoid confusion with tensor Dropout; every_n_epochs defaults to None for consistency with the other options
1449
+ parser.add_argument("--caption_dropout_rate", type=float, default=0,
1450
+ help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
1451
+ parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
1452
+ help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
1453
+ parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
1454
+ help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
1455
+
1456
+ if support_dreambooth:
1457
+ # DreamBooth dataset
1458
+ parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
1459
+
1460
+ if support_caption:
1461
+ # caption dataset
1462
+ parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
1463
+ parser.add_argument("--dataset_repeats", type=int, default=1,
1464
+ help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
1465
+
1466
+
1467
+ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
1468
+ parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
1469
+ help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
1470
+ parser.add_argument("--use_safetensors", action='store_true',
1471
+ help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
1472
+
1473
+ # endregion
1474
+
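A sketch of how the argument helpers above are typically composed into a single parser by a training script; the support_* flag values are illustrative, not prescribed by this file.

parser = argparse.ArgumentParser()
add_sd_models_arguments(parser)
add_dataset_arguments(parser, support_dreambooth=True, support_caption=True, support_caption_dropout=False)
add_training_arguments(parser, support_dreambooth=True)
add_sd_saving_arguments(parser)
args = parser.parse_args()
verify_training_args(args)
prepare_dataset_args(args, support_metadata=True)   # defined in the utils region below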
1475
+ # region utils
1476
+
1477
+
1478
+ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
1479
+ # backward compatibility
1480
+ if args.caption_extention is not None:
1481
+ args.caption_extension = args.caption_extention
1482
+ args.caption_extention = None
1483
+
1484
+ if args.cache_latents:
1485
+ assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
1486
+ assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
1487
+
1488
+ # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
1489
+ if args.resolution is not None:
1490
+ args.resolution = tuple([int(r) for r in args.resolution.split(',')])
1491
+ if len(args.resolution) == 1:
1492
+ args.resolution = (args.resolution[0], args.resolution[0])
1493
+ assert len(args.resolution) == 2, \
1494
+ f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
1495
+
1496
+ if args.face_crop_aug_range is not None:
1497
+ args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
1498
+ assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \
1499
+ f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
1500
+ else:
1501
+ args.face_crop_aug_range = None
1502
+
1503
+ if support_metadata:
1504
+ if args.in_json is not None and (args.color_aug or args.random_crop):
1505
+ print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます")
1506
+
1507
+
1508
+ def load_tokenizer(args: argparse.Namespace):
1509
+ print("prepare tokenizer")
1510
+ if args.v2:
1511
+ tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
1512
+ else:
1513
+ tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
1514
+ if args.max_token_length is not None:
1515
+ print(f"update token length: {args.max_token_length}")
1516
+ return tokenizer
1517
+
1518
+
1519
+ def prepare_accelerator(args: argparse.Namespace):
1520
+ if args.logging_dir is None:
1521
+ log_with = None
1522
+ logging_dir = None
1523
+ else:
1524
+ log_with = "tensorboard"
1525
+ log_prefix = "" if args.log_prefix is None else args.log_prefix
1526
+ logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
1527
+
1528
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision,
1529
+ log_with=log_with, logging_dir=logging_dir)
1530
+
1531
+ # work around an accelerate compatibility issue (unwrap_model signature changed in 0.15.0)
1532
+ accelerator_0_15 = True
1533
+ try:
1534
+ accelerator.unwrap_model("dummy", True)
1535
+ print("Using accelerator 0.15.0 or above.")
1536
+ except TypeError:
1537
+ accelerator_0_15 = False
1538
+
1539
+ def unwrap_model(model):
1540
+ if accelerator_0_15:
1541
+ return accelerator.unwrap_model(model, True)
1542
+ return accelerator.unwrap_model(model)
1543
+
1544
+ return accelerator, unwrap_model
1545
+
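A brief usage sketch: the returned closure hides the unwrap_model signature difference between accelerate versions, so callers use it instead of accelerator.unwrap_model directly.

accelerator, unwrap_model = prepare_accelerator(args)
# ... after accelerator.prepare(...), unwrap with the closure:
unet = unwrap_model(unet)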
1546
+
1547
+ def prepare_dtype(args: argparse.Namespace):
1548
+ weight_dtype = torch.float32
1549
+ if args.mixed_precision == "fp16":
1550
+ weight_dtype = torch.float16
1551
+ elif args.mixed_precision == "bf16":
1552
+ weight_dtype = torch.bfloat16
1553
+
1554
+ save_dtype = None
1555
+ if args.save_precision == "fp16":
1556
+ save_dtype = torch.float16
1557
+ elif args.save_precision == "bf16":
1558
+ save_dtype = torch.bfloat16
1559
+ elif args.save_precision == "float":
1560
+ save_dtype = torch.float32
1561
+
1562
+ return weight_dtype, save_dtype
1563
+
1564
+
1565
+ def load_target_model(args: argparse.Namespace, weight_dtype):
1566
+ load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
1567
+ if load_stable_diffusion_format:
1568
+ print("load StableDiffusion checkpoint")
1569
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
1570
+ else:
1571
+ print("load Diffusers pretrained models")
1572
+ pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
1573
+ text_encoder = pipe.text_encoder
1574
+ vae = pipe.vae
1575
+ unet = pipe.unet
1576
+ del pipe
1577
+
1578
+ # load the replacement VAE
1579
+ if args.vae is not None:
1580
+ vae = model_util.load_vae(args.vae, weight_dtype)
1581
+ print("additional VAE loaded")
1582
+
1583
+ return text_encoder, vae, unet, load_stable_diffusion_format
1584
+
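A usage sketch combining the two helpers above (variable names are illustrative):

weight_dtype, save_dtype = prepare_dtype(args)
text_encoder, vae, unet, load_sd_format = load_target_model(args, weight_dtype)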
1585
+
1586
+ def patch_accelerator_for_fp16_training(accelerator):
1587
+ org_unscale_grads = accelerator.scaler._unscale_grads_
1588
+
1589
+ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
1590
+ return org_unscale_grads(optimizer, inv_scale, found_inf, True)
1591
+
1592
+ accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
1593
+
1594
+
1595
+ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
1596
+ # with no_token_padding, the length is not max length, return result immediately
1597
+ if input_ids.size()[-1] != tokenizer.model_max_length:
1598
+ return text_encoder(input_ids)[0]
1599
+
1600
+ b_size = input_ids.size()[0]
1601
+ input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
1602
+
1603
+ if args.clip_skip is None:
1604
+ encoder_hidden_states = text_encoder(input_ids)[0]
1605
+ else:
1606
+ enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
1607
+ encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
1608
+ if weight_dtype is not None:
1609
+ # this is required for additional network training
1610
+ encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
1611
+ encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
1612
+
1613
+ # bs*3, 77, 768 or 1024
1614
+ encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
1615
+
1616
+ if args.max_token_length is not None:
1617
+ if args.v2:
1618
+ # v2: merge the three <BOS>...<EOS> <PAD>... chunks back into a single <BOS>...<EOS> <PAD>... sequence (honestly not sure this implementation is right)
1619
+ states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
1620
+ for i in range(1, args.max_token_length, tokenizer.model_max_length):
1621
+ chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]  # from after <BOS> to before the last token
1622
+ if i > 0:
1623
+ for j in range(len(chunk)):
1624
+ if input_ids[j, 1] == tokenizer.eos_token:  # empty caption, i.e. the <BOS> <EOS> <PAD> ... pattern
1625
+ chunk[j, 0] = chunk[j, 1]  # copy the value of the following <PAD>
1626
+ states_list.append(chunk)  # from after <BOS> to before <EOS>
1627
+ states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
1628
+ encoder_hidden_states = torch.cat(states_list, dim=1)
1629
+ else:
1630
+ # v1: merge the three <BOS>...<EOS> chunks back into a single <BOS>...<EOS> sequence
1631
+ states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
1632
+ for i in range(1, args.max_token_length, tokenizer.model_max_length):
1633
+ states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2])  # from after <BOS> to before <EOS>
1634
+ states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
1635
+ encoder_hidden_states = torch.cat(states_list, dim=1)
1636
+
1637
+ return encoder_hidden_states
1638
+
1639
+
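A shape walk-through of the long-prompt path above for max_token_length=225, assuming an SD1.x text encoder (hidden size 768); the stacked (bs, 3, 77) input layout is how the dataset is expected to deliver chunked prompts.

# input_ids        (bs, 3, 77)    three <BOS>...<EOS> windows
# -> reshape       (bs*3, 77)
# -> text_encoder  (bs*3, 77, 768)
# -> reshape       (bs, 231, 768)
# -> trim the per-chunk <BOS>/<EOS> boundaries as above -> (bs, 227, 768)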
1640
+ def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
1641
+ model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
1642
+ ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt")
1643
+ return model_name, ckpt_name
1644
+
1645
+
1646
+ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
1647
+ saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
1648
+ if saving:
1649
+ os.makedirs(args.output_dir, exist_ok=True)
1650
+ save_func()
1651
+
1652
+ if args.save_last_n_epochs is not None:
1653
+ remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
1654
+ remove_old_func(remove_epoch_no)
1655
+ return saving
1656
+
1657
+
1658
+ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
1659
+ epoch_no = epoch + 1
1660
+ model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)
1661
+
1662
+ if save_stable_diffusion_format:
1663
+ def save_sd():
1664
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
1665
+ print(f"saving checkpoint: {ckpt_file}")
1666
+ model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
1667
+ src_path, epoch_no, global_step, save_dtype, vae)
1668
+
1669
+ def remove_sd(old_epoch_no):
1670
+ _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
1671
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
1672
+ if os.path.exists(old_ckpt_file):
1673
+ print(f"removing old checkpoint: {old_ckpt_file}")
1674
+ os.remove(old_ckpt_file)
1675
+
1676
+ save_func = save_sd
1677
+ remove_old_func = remove_sd
1678
+ else:
1679
+ def save_du():
1680
+ out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
1681
+ print(f"saving model: {out_dir}")
1682
+ os.makedirs(out_dir, exist_ok=True)
1683
+ model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
1684
+ src_path, vae=vae, use_safetensors=use_safetensors)
1685
+
1686
+ def remove_du(old_epoch_no):
1687
+ out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
1688
+ if os.path.exists(out_dir_old):
1689
+ print(f"removing old model: {out_dir_old}")
1690
+ shutil.rmtree(out_dir_old)
1691
+
1692
+ save_func = save_du
1693
+ remove_old_func = remove_du
1694
+
1695
+ saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
1696
+ if saving and args.save_state:
1697
+ save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
1698
+
1699
+
1700
+ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
1701
+ print("saving state.")
1702
+ accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
1703
+
1704
+ last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
1705
+ if last_n_epochs is not None:
1706
+ remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
1707
+ state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
1708
+ if os.path.exists(state_dir_old):
1709
+ print(f"removing old state: {state_dir_old}")
1710
+ shutil.rmtree(state_dir_old)
1711
+
1712
+
1713
+ def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
1714
+ model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1715
+
1716
+ if save_stable_diffusion_format:
1717
+ os.makedirs(args.output_dir, exist_ok=True)
1718
+
1719
+ ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt")
1720
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
1721
+
1722
+ print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
1723
+ model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
1724
+ src_path, epoch, global_step, save_dtype, vae)
1725
+ else:
1726
+ out_dir = os.path.join(args.output_dir, model_name)
1727
+ os.makedirs(out_dir, exist_ok=True)
1728
+
1729
+ print(f"save trained model as Diffusers to {out_dir}")
1730
+ model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
1731
+ src_path, vae=vae, use_safetensors=use_safetensors)
1732
+
1733
+
1734
+ def save_state_on_train_end(args: argparse.Namespace, accelerator):
1735
+ print("saving last state.")
1736
+ os.makedirs(args.output_dir, exist_ok=True)
1737
+ model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
1738
+ accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
1739
+
1740
+ # endregion
1741
+
1742
+ # region preprocessing helpers
1743
+
1744
+
1745
+ class ImageLoadingDataset(torch.utils.data.Dataset):
1746
+ def __init__(self, image_paths):
1747
+ self.images = image_paths
1748
+
1749
+ def __len__(self):
1750
+ return len(self.images)
1751
+
1752
+ def __getitem__(self, idx):
1753
+ img_path = self.images[idx]
1754
+
1755
+ try:
1756
+ image = Image.open(img_path).convert("RGB")
1757
+ # convert to tensor temporarily so dataloader will accept it
1758
+ tensor_pil = transforms.functional.pil_to_tensor(image)
1759
+ except Exception as e:
1760
+ print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
1761
+ return None
1762
+
1763
+ return (tensor_pil, img_path)
1764
+
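A usage sketch for the dataset above: because __getitem__ returns None for unreadable images, pairing it with a collate function that drops None entries keeps the DataLoader from crashing. The collate helper and image_paths are illustrative.

def collate_fn_remove_corrupted(batch):
    return [item for item in batch if item is not None]

loader = torch.utils.data.DataLoader(ImageLoadingDataset(image_paths), batch_size=4,
                                     num_workers=2, collate_fn=collate_fn_remove_corrupted)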
1765
+
1766
+ # endregion
Lora/matous_LORA.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52e4a8e2fbda7b304374826a7f73061d33f4401cab515a92093f1ff09fd5fc19
3
+ size 151109011