abc commited on
Commit
3584d95
·
1 Parent(s): 5f5d36f

Delete build

Browse files
build/lib/library/__init__.py DELETED
File without changes
build/lib/library/model_util.py DELETED
@@ -1,1180 +0,0 @@
1
- # v1: split from train_db_fixed.py.
2
- # v2: support safetensors
3
-
4
- import math
5
- import os
6
- import torch
7
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
8
- from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
9
- from safetensors.torch import load_file, save_file
10
-
11
# Model hyperparameters of the Diffusers version of StableDiffusion (SD v1).
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120

# UNet architecture constants for SD v1.
UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32`
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8

# VAE architecture constants (shared by v1/v2).
VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 256
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2

# V2 (SD 2.x) overrides: per-block attention head dims and larger text context.
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024

# Reference model IDs used to load the Diffusers configuration.
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
41
-
42
-
43
- # region StableDiffusion->Diffusersの変換コード
44
- # convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
45
-
46
-
47
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drop dot-separated segments from *path*.

    A positive count removes that many leading segments; a negative count
    removes segments from the tail instead.
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        kept = segments[n_shave_prefix_segments:]
    else:
        kept = segments[:n_shave_prefix_segments]
    return ".".join(kept)
55
-
56
-
57
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Map LDM UNet resnet parameter names to the diffusers naming scheme
    (local renaming only; global renaming happens in assign_to_checkpoint).
    """
    renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )
    mapping = []
    for old_item in old_list:
        renamed = old_item
        for src, dst in renames:
            renamed = renamed.replace(src, dst)
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": renamed})
    return mapping
77
-
78
-
79
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Map LDM VAE resnet parameter names to the diffusers naming scheme
    (only the shortcut conv is renamed locally).
    """
    mapping = []
    for old_item in old_list:
        renamed = old_item.replace("nin_shortcut", "conv_shortcut")
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": renamed})
    return mapping
93
-
94
-
95
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming).

    For UNet attention blocks the LDM and diffusers names already agree, so
    every entry maps a key to itself; *n_shave_prefix_segments* is kept only
    for signature parity with the other renew_* helpers.
    """
    return [{"old": old_item, "new": old_item} for old_item in old_list]
114
-
115
-
116
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Map LDM VAE attention parameter names to the diffusers naming scheme
    (group norm, q/k/v projections, and the output projection).
    """
    renames = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "query.weight"),
        ("q.bias", "query.bias"),
        ("k.weight", "key.weight"),
        ("k.bias", "key.bias"),
        ("v.weight", "value.weight"),
        ("v.bias", "value.bias"),
        ("proj_out.weight", "proj_attn.weight"),
        ("proj_out.bias", "proj_attn.bias"),
    )
    mapping = []
    for old_item in old_list:
        renamed = old_item
        for src, dst in renames:
            renamed = renamed.replace(src, dst)
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": renamed})
    return mapping
144
-
145
-
146
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.

    paths: list of {"old": ..., "new": ...} dicts produced by the renew_* helpers.
    checkpoint: destination state dict, mutated in place.
    old_checkpoint: source state dict the tensors are read from.
    attention_paths_to_split: optional {source_key: {"query"/"key"/"value": dest_key}}
        map of fused qkv tensors to split.
    additional_replacements: optional list of {"old": ..., "new": ...} substring renames.
    config: must provide "num_head_channels" when splitting attention tensors.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            # NOTE(review): in the 1-D (bias) case this is the int -1, not a tuple;
            # reshape accepts both forms, so behavior is unaffected.
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            # Fold the fused qkv tensor into (heads, 3*channels/heads, ...) then
            # split along dim 1 into equal query/key/value chunks.
            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]
196
-
197
-
198
def conv_attn_to_linear(checkpoint):
    """
    Convert VAE attention projection weights from conv form to linear form,
    in place: q/k/v weights drop the trailing 1x1 spatial dims, proj_attn
    drops a single trailing dim.
    """
    qkv_suffixes = ["query.weight", "key.weight", "value.weight"]
    for name in list(checkpoint.keys()):
        tensor = checkpoint[name]
        suffix = ".".join(name.split(".")[-2:])
        if suffix in qkv_suffixes:
            if tensor.ndim > 2:
                checkpoint[name] = tensor[:, :, 0, 0]
        elif "proj_attn.weight" in name:
            if tensor.ndim > 2:
                checkpoint[name] = tensor[:, :, 0]
208
-
209
-
210
def linear_transformer_to_conv(checkpoint):
    """
    Expand linear transformer projection weights to 1x1-conv form, in place.
    Needed when loading SD v2 checkpoints, where proj_in/proj_out are conv2d.
    """
    proj_suffixes = ["proj_in.weight", "proj_out.weight"]
    for name in list(checkpoint.keys()):
        if ".".join(name.split(".")[-2:]) in proj_suffixes:
            weight = checkpoint[name]
            if weight.ndim == 2:
                checkpoint[name] = weight.unsqueeze(2).unsqueeze(2)
217
-
218
-
219
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Extracts the "model.diffusion_model." subtree from an LDM/SD checkpoint
    and renames every key to the diffusers UNet2DConditionModel layout.
    config must provide "layers_per_block" (and "num_head_channels" is not
    needed here since UNet qkv is not split). Mutates *checkpoint* by popping
    the UNet keys out of it.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    # Time embedding and the in/out convolutions map 1:1.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }

    # Down path: block 0 is conv_in (handled above); each LDM input block holds
    # a resnet at index .0 (or a downsample op) and optionally attention at .1.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    # Middle block: resnet / attention / resnet.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    # Up path: each LDM output block may hold resnet (.0), attention (.1) and
    # an upsampler conv, so the keys are grouped per sub-layer index first.
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # original:
            # if ["conv.weight", "conv.bias"] in output_block_list.values():
            #   index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])

            # Make the lookup independent of the weight/bias ordering; there may
            # be a cleaner way to do this.
            for l in output_block_list.values():
                l.sort()

            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Single-sub-layer block: plain resnet rename without attention.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    # In SD v2 the 1x1 conv2d layers were replaced with linear, so convert
    # linear back to conv for the diffusers model.
    if v2:
        linear_transformer_to_conv(new_checkpoint)

    return new_checkpoint
382
-
383
-
384
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Convert the "first_stage_model." (VAE) subtree of an LDM/SD checkpoint to
    the diffusers AutoencoderKL layout.

    checkpoint: full LDM state dict (keys are read, not popped).
    config: diffusers VAE config dict (passed through to assign_to_checkpoint).
    Returns the converted state dict.
    """
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
    # if len(vae_state_dict) == 0:
    #     # the passed checkpoint is a VAE state_dict rather than one loaded from a .ckpt
    #     vae_state_dict = checkpoint

    new_checkpoint = {}

    # Encoder/decoder in-out convolutions and norms map 1:1.
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    # Encoder down blocks: resnets plus an optional downsample conv.
    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder mid block: two resnets (block_1/block_2 in LDM) plus attention.
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    # Decoder up blocks: diffusers numbers them in reverse relative to LDM.
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder mid block, same structure as the encoder's.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
492
-
493
-
494
def create_unet_diffusers_config(v2):
    """
    Creates a config for the diffusers UNet2DConditionModel from the
    hard-coded LDM hyperparameters above. Pass v2=True for SD 2.x.
    """
    # unet_params = original_config.model.params.unet_config.params

    block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]

    # Cross-attention only exists at the downsampling factors listed in
    # UNET_PARAMS_ATTENTION_RESOLUTIONS; track the factor block by block.
    down_block_types = []
    resolution = 1
    for idx in range(len(block_out_channels)):
        attn = resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
        down_block_types.append("CrossAttnDownBlock2D" if attn else "DownBlock2D")
        if idx != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for _ in block_out_channels:
        attn = resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
        up_block_types.append("CrossAttnUpBlock2D" if attn else "UpBlock2D")
        resolution //= 2

    return dict(
        sample_size=UNET_PARAMS_IMAGE_SIZE,
        in_channels=UNET_PARAMS_IN_CHANNELS,
        out_channels=UNET_PARAMS_OUT_CHANNELS,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
        cross_attention_dim=V2_UNET_PARAMS_CONTEXT_DIM if v2 else UNET_PARAMS_CONTEXT_DIM,
        attention_head_dim=V2_UNET_PARAMS_ATTENTION_HEAD_DIM if v2 else UNET_PARAMS_NUM_HEADS,
    )
529
-
530
-
531
def create_vae_diffusers_config():
    """
    Creates a config for the diffusers AutoencoderKL from the hard-coded
    LDM VAE hyperparameters above (identical for SD v1 and v2).
    """
    # vae_params = original_config.model.params.first_stage_config.params.ddconfig
    # _ = original_config.model.params.first_stage_config.params.embed_dim
    block_out_channels = tuple(VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT)
    n_blocks = len(block_out_channels)

    return dict(
        sample_size=VAE_PARAMS_RESOLUTION,
        in_channels=VAE_PARAMS_IN_CHANNELS,
        out_channels=VAE_PARAMS_OUT_CH,
        down_block_types=tuple(["DownEncoderBlock2D"] * n_blocks),
        up_block_types=tuple(["UpDecoderBlock2D"] * n_blocks),
        block_out_channels=block_out_channels,
        latent_channels=VAE_PARAMS_Z_CHANNELS,
        layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
    )
552
-
553
-
554
def convert_ldm_clip_checkpoint_v1(checkpoint):
    """
    Extract the CLIP text-encoder weights from an SD v1 checkpoint,
    stripping the "cond_stage_model.transformer." prefix.
    """
    src_prefix = "cond_stage_model.transformer"
    strip_len = len("cond_stage_model.transformer.")
    return {
        key[strip_len:]: checkpoint[key]
        for key in list(checkpoint.keys())
        if key.startswith(src_prefix)
    }
561
-
562
-
563
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
    """
    Convert the open_clip-style text encoder of an SD v2 checkpoint to the
    transformers CLIPTextModel layout.

    checkpoint: full SD v2 state dict.
    max_length: tokenizer context length, used to synthesize position_ids
        when the checkpoint does not carry them.
    Returns the converted text-encoder state dict.
    """
    # The two layouts differ annoyingly much!
    def convert_key(key):
        # Returns the renamed key, or None when the key is dropped here
        # (non-text keys, unused heads, and fused qkv handled separately).
        if not key.startswith("cond_stage_model"):
            return None

        # common conversion
        key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
        key = key.replace("cond_stage_model.model.", "text_model.")

        if "resblocks" in key:
            # resblocks conversion
            key = key.replace(".resblocks.", ".layers.")
            if ".ln_" in key:
                key = key.replace(".ln_", ".layer_norm")
            elif ".mlp." in key:
                key = key.replace(".c_fc.", ".fc1.")
                key = key.replace(".c_proj.", ".fc2.")
            elif '.attn.out_proj' in key:
                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
            elif '.attn.in_proj' in key:
                key = None  # fused qkv — handled separately below
            else:
                raise ValueError(f"unexpected key in SD: {key}")
        elif '.positional_embedding' in key:
            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        elif '.text_projection' in key:
            key = None  # not used by CLIPTextModel???
        elif '.logit_scale' in key:
            key = None  # not used by CLIPTextModel???
        elif '.token_embedding' in key:
            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        elif '.ln_final' in key:
            key = key.replace(".ln_final", ".final_layer_norm")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        # remove resblocks 23 (SD v2 uses the penultimate layer; 23 is unused)
        if '.resblocks.23.' in key:
            continue
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # convert the fused attention in_proj weights/biases
    for key in keys:
        if '.resblocks.23.' in key:
            continue
        if '.resblocks' in key and '.attn.in_proj_' in key:
            # split into three equal q/k/v chunks
            values = torch.chunk(checkpoint[key], 3)

            key_suffix = ".weight" if "weight" in key else ".bias"
            key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
            key_pfx = key_pfx.replace("_weight", "")
            key_pfx = key_pfx.replace("_bias", "")
            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

    # rename or add position_ids
    ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
    if ANOTHER_POSITION_IDS_KEY in new_sd:
        # waifu diffusion v1.4
        position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
        del new_sd[ANOTHER_POSITION_IDS_KEY]
    else:
        position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)

    new_sd["text_model.embeddings.position_ids"] = position_ids
    return new_sd
638
-
639
- # endregion
640
-
641
-
642
- # region Diffusers->StableDiffusion の変換コード
643
- # convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
644
-
645
def conv_transformer_to_linear(checkpoint):
    """
    Squeeze 1x1-conv transformer projection weights back to linear form,
    in place (inverse of linear_transformer_to_conv, used for SD v2 export).
    """
    proj_suffixes = ["proj_in.weight", "proj_out.weight"]
    for name in list(checkpoint.keys()):
        if ".".join(name.split(".")[-2:]) in proj_suffixes:
            weight = checkpoint[name]
            if weight.ndim > 2:
                checkpoint[name] = weight[:, :, 0, 0]
652
-
653
-
654
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
    """
    Convert a diffusers UNet state dict back to the original StableDiffusion
    key layout (the inverse of convert_ldm_unet_checkpoint).

    v2: True for SD 2.x checkpoints (1x1 convs become linear on export).
    unet_state_dict: diffusers-format state dict (not mutated).
    Returns a new state dict with SD-format keys.
    """
    # Direct 1:1 key renames for the time embedding and in/out convolutions.
    unet_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("time_embed.0.weight", "time_embedding.linear_1.weight"),
        ("time_embed.0.bias", "time_embedding.linear_1.bias"),
        ("time_embed.2.weight", "time_embedding.linear_2.weight"),
        ("time_embed.2.bias", "time_embedding.linear_2.bias"),
        ("input_blocks.0.0.weight", "conv_in.weight"),
        ("input_blocks.0.0.bias", "conv_in.bias"),
        ("out.0.weight", "conv_norm_out.weight"),
        ("out.0.bias", "conv_norm_out.bias"),
        ("out.2.weight", "conv_out.weight"),
        ("out.2.bias", "conv_out.bias"),
    ]

    # Sub-layer renames applied only inside resnet keys.
    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    ]

    # Prefix renames mapping diffusers block indices to flat LDM block indices.
    unet_conversion_map_layer = []
    for i in range(4):
        # loop over downblocks/upblocks

        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            if i > 0:
                # no attention layers in up_blocks.0
                hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
                sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
                unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}

    # SD v2 expects linear (not 1x1 conv) transformer projections.
    if v2:
        conv_transformer_to_linear(new_state_dict)

    return new_state_dict
748
-
749
-
750
- # ================#
751
- # VAE Conversion #
752
- # ================#
753
-
754
def reshape_weight_for_sd(w):
    """Turn an HF linear weight (out, in) into an SD conv2d weight (out, in, 1, 1)."""
    conv_shape = tuple(w.shape) + (1, 1)
    return w.reshape(*conv_shape)
757
-
758
-
759
def convert_vae_state_dict(vae_state_dict):
    """Convert a Diffusers AutoencoderKL state_dict to StableDiffusion (CompVis) key format.

    Works by repeated substring replacement on key names, so the order of the
    (sd, hf) pairs below matters; mid-block attention q/k/v/proj weights are
    additionally reshaped from linear to 1x1-conv form.
    """
    vae_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("nin_shortcut", "conv_shortcut"),
        ("norm_out", "conv_norm_out"),
        ("mid.attn_1.", "mid_block.attentions.0."),
    ]

    for i in range(4):
        # down_blocks have two resnets
        for j in range(2):
            hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
            sd_down_prefix = f"encoder.down.{i}.block.{j}."
            vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

        if i < 3:
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
            sd_downsample_prefix = f"down.{i}.downsample."
            vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"up.{3-i}.upsample."
            vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

        # up_blocks have three resnets
        # also, up blocks in hf are numbered in reverse from sd
        for j in range(3):
            hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
            sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
            vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

    # this part accounts for mid blocks in both the encoder and the decoder
    for i in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{i}."
        sd_mid_res_prefix = f"mid.block_{i+1}."
        vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

    vae_conversion_map_attn = [
        # (stable-diffusion, HF Diffusers)
        ("norm.", "group_norm."),
        ("q.", "query."),
        ("k.", "key."),
        ("v.", "value."),
        ("proj_out.", "proj_attn."),
    ]

    # Build hf-key -> sd-key mapping by applying the replacement tables.
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    # Attention-layer keys get a second pass with the attn-specific table.
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    # Mid-block attention projections are linear in HF but 1x1 conv in SD.
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                # print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)

    return new_state_dict
824
-
825
-
826
- # endregion
827
-
828
- # region 自作のモデル読み書きなど
829
-
830
def is_safetensors(path):
    """Return True when *path* carries a .safetensors extension (case-insensitive)."""
    _, extension = os.path.splitext(path)
    return extension.lower() == '.safetensors'
832
-
833
-
834
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
    """Load an SD checkpoint (.ckpt or .safetensors) and normalize text-encoder key names.

    Returns (checkpoint, state_dict). ``checkpoint`` is None for safetensors
    files and for .ckpt files that are a bare state_dict (no wrapper dict);
    otherwise it is the full wrapper dict (carrying e.g. epoch/global_step).
    """
    # Support models whose text-encoder keys lack the 'text_model' segment.
    TEXT_ENCODER_KEY_REPLACEMENTS = [
        ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
        ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
        ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
    ]

    if is_safetensors(ckpt_path):
        checkpoint = None
        state_dict = load_file(ckpt_path, "cpu")
    else:
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            # bare state_dict: there is no wrapper to return
            state_dict = checkpoint
            checkpoint = None

    # Collect all renames first, then apply them, so the dict is not
    # mutated while it is being iterated.
    key_reps = []
    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
        for key in state_dict.keys():
            if key.startswith(rep_from):
                new_key = rep_to + key[len(rep_from):]
                key_reps.append((key, new_key))

    for key, new_key in key_reps:
        state_dict[new_key] = state_dict[key]
        del state_dict[key]

    return checkpoint, state_dict
865
-
866
-
867
# TODO: dtype handling looks unreliable and should be checked; it is not yet
# confirmed that the text encoder can actually be built in the requested dtype.
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
    """Load an SD checkpoint and return (text_encoder, vae, unet) as HF/Diffusers models.

    Args:
        v2: True for SD 2.x checkpoints (open_clip text encoder), False for 1.x.
        ckpt_path: path to a .ckpt or .safetensors checkpoint.
        dtype: if given, every tensor in the loaded state_dict is cast to it.
    """
    _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
    if dtype is not None:
        for k, v in state_dict.items():
            if type(v) is torch.Tensor:
                state_dict[k] = v.to(dtype)

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(v2)
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    info = unet.load_state_dict(converted_unet_checkpoint)
    print("loading u-net:", info)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config()
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)

    vae = AutoencoderKL(**vae_config)
    info = vae.load_state_dict(converted_vae_checkpoint)
    print("loading vae:", info)

    # convert text_model
    if v2:
        # Build a CLIPTextConfig matching SD 2.x's 23-layer, 1024-dim encoder.
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
        cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=1024,
            intermediate_size=4096,
            num_hidden_layers=23,
            num_attention_heads=16,
            max_position_embeddings=77,
            hidden_act="gelu",
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type="clip_text_model",
            projection_dim=512,
            torch_dtype="float32",
            transformers_version="4.25.0.dev0",
        )
        # _from_config avoids downloading pretrained weights; they are
        # immediately replaced by the converted checkpoint below.
        text_model = CLIPTextModel._from_config(cfg)
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
    else:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
        text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
    print("loading text encoder:", info)

    return text_model, vae, unet
924
-
925
-
926
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
    """Convert an HF CLIPTextModel state_dict to SD v2 (open_clip) key format.

    Args:
        checkpoint: the HF text-encoder state_dict.
        make_dummy_weights: when True, fabricate resblock 23, text_projection
            and logit_scale (weights Diffusers does not carry) so the result
            looks like a complete open_clip model.
    """
    def convert_key(key):
        # drop position_ids (not stored in the open_clip format)
        if ".position_ids" in key:
            return None

        # common
        key = key.replace("text_model.encoder.", "transformer.")
        key = key.replace("text_model.", "")
        if "layers" in key:
            # resblocks conversion
            key = key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in key:
                key = key.replace(".layer_norm", ".ln_")
            elif ".mlp." in key:
                key = key.replace(".fc1.", ".c_fc.")
                key = key.replace(".fc2.", ".c_proj.")
            elif '.self_attn.out_proj' in key:
                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif '.self_attn.' in key:
                key = None  # special case: q/k/v are merged in a second pass below
            else:
                raise ValueError(f"unexpected key in DiffUsers model: {key}")
        elif '.position_embedding' in key:
            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif '.token_embedding' in key:
            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif 'final_layer_norm' in key:
            key = key.replace("final_layer_norm", "ln_final")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # attention conversion: open_clip stores q/k/v as one concatenated in_proj
    for key in keys:
        if 'layers' in key and 'q_proj' in key:
            # concatenate the three projections
            key_q = key
            key_k = key.replace("q_proj", "k_proj")
            key_v = key.replace("q_proj", "v_proj")

            value_q = checkpoint[key_q]
            value_k = checkpoint[key_k]
            value_v = checkpoint[key_v]
            value = torch.cat([value_q, value_k, value_v])

            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            new_sd[new_key] = value

    # optionally fabricate the weights the Diffusers model does not carry
    if make_dummy_weights:
        print("make dummy weights for resblock.23, text_projection and logit scale.")
        keys = list(new_sd.keys())
        for key in keys:
            if key.startswith("transformer.resblocks.22."):
                # duplicate the last present resblock as resblock 23
                new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone()  # must be a copy or the safetensors save fails

        # weights absent from Diffusers: identity-ish placeholders
        new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
        new_sd['logit_scale'] = torch.tensor(1)

    return new_sd
995
-
996
-
997
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
    """Write text_encoder/unet (and optionally vae) as an SD-format checkpoint.

    When ``ckpt_path`` is given, that checkpoint is re-read so its epoch/step
    counters are inherited and weights not held in memory (e.g. the VAE) are
    carried over. Returns the number of keys in the saved state_dict.
    """
    if ckpt_path is not None:
        # Re-read the source checkpoint for epoch/step and any missing weights.
        checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
        if checkpoint is None:  # safetensors file, or a .ckpt that was a bare state_dict
            checkpoint = {}
            strict = False
        else:
            strict = True
        if "state_dict" in state_dict:
            del state_dict["state_dict"]
    else:
        # build a brand-new checkpoint
        assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
        checkpoint = {}
        state_dict = {}
        strict = False

    def update_sd(prefix, sd):
        # Merge sd into state_dict under prefix; in strict mode every key
        # must already exist in the source checkpoint.
        for k, v in sd.items():
            key = prefix + k
            assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
            if save_dtype is not None:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
    update_sd("model.diffusion_model.", unet_state_dict)

    # Convert the text encoder model
    if v2:
        # With no reference checkpoint, the missing final layer etc. are
        # fabricated as dummy weights (duplicated from the previous layer).
        make_dummy = ckpt_path is None
        text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
        update_sd("cond_stage_model.model.", text_enc_dict)
    else:
        text_enc_dict = text_encoder.state_dict()
        update_sd("cond_stage_model.transformer.", text_enc_dict)

    # Convert the VAE
    if vae is not None:
        vae_dict = convert_vae_state_dict(vae.state_dict())
        update_sd("first_stage_model.", vae_dict)

    # Put together new checkpoint
    key_count = len(state_dict.keys())
    new_ckpt = {'state_dict': state_dict}

    # Accumulate counters on top of the source checkpoint's values.
    if 'epoch' in checkpoint:
        epochs += checkpoint['epoch']
    if 'global_step' in checkpoint:
        steps += checkpoint['global_step']

    new_ckpt['epoch'] = epochs
    new_ckpt['global_step'] = steps

    if is_safetensors(output_file):
        # TODO: should non-tensor entries be removed from the dict first?
        # NOTE: the safetensors path saves only the state_dict — epoch and
        # global_step metadata are not written.
        save_file(state_dict, output_file)
    else:
        torch.save(new_ckpt, output_file)

    return key_count
1060
-
1061
-
1062
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
    """Save text_encoder/unet (and optionally vae) as a Diffusers pipeline directory.

    The scheduler, tokenizer (and VAE, when not supplied) are pulled from
    ``pretrained_model_name_or_path``; when that is None, the v1/v2 reference
    repos are used instead.
    """
    if pretrained_model_name_or_path is None:
        # load default settings for v1/v2
        if v2:
            pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
        else:
            pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1

    scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    if vae is None:
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    pipeline = StableDiffusionPipeline(
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        tokenizer=tokenizer,
        safety_checker=None,
        feature_extractor=None,
        # requires_safety_checker expects a bool; the original passed None,
        # which is stored in the pipeline config as an out-of-contract value.
        # False is the correct setting when safety_checker is deliberately None.
        requires_safety_checker=False,
    )
    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1086
-
1087
-
1088
# Key prefix used for VAE weights inside a full SD-format checkpoint.
VAE_PREFIX = "first_stage_model."


def load_vae(vae_id, dtype):
    """Load an AutoencoderKL from a Diffusers repo/dir, a Diffusers .bin file,
    or an SD-format .ckpt/.safetensors file (VAE-only or full model)."""
    print(f"load VAE: {vae_id}")
    if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
        # Diffusers local/remote
        try:
            vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
        except EnvironmentError as e:
            # Some repos keep the VAE under a 'vae' subfolder; retry there.
            print(f"exception occurs in loading vae: {e}")
            print("retry with subfolder='vae'")
            vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
        return vae

    # local file
    vae_config = create_vae_diffusers_config()

    if vae_id.endswith(".bin"):
        # SD 1.5 VAE on Huggingface — already in Diffusers key format, no conversion
        converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
    else:
        # StableDiffusion-format file
        vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
                     else torch.load(vae_id, map_location="cpu"))
        vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model

        # vae only or full model
        full_model = False
        for vae_key in vae_sd:
            if vae_key.startswith(VAE_PREFIX):
                full_model = True
                break
        if not full_model:
            # Prefix bare VAE keys so the converter sees full-model-style names.
            sd = {}
            for key, value in vae_sd.items():
                sd[VAE_PREFIX + key] = value
            vae_sd = sd
            del sd

        # Convert the VAE model.
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    return vae
1134
-
1135
- # endregion
1136
-
1137
-
1138
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
    """Enumerate (width, height) bucket resolutions for aspect-ratio bucketing.

    Every bucket's side is a multiple of *divisible*, each side lies within
    [min_size, max_size], and the area never exceeds the budget implied by
    *max_reso*. The result is a sorted, de-duplicated list of tuples that
    includes the near-square bucket plus both orientations of each candidate.
    """
    max_width, max_height = max_reso
    # Area budget measured in (divisible x divisible) cells.
    max_area = (max_width // divisible) * (max_height // divisible)

    buckets = set()

    # Square bucket closest to the area budget.
    square = int(math.sqrt(max_area)) * divisible
    buckets.add((square, square))

    # For each candidate width, take the tallest height that fits the area
    # budget (capped at max_size), and register both orientations.
    for width in range(min_size, max_size + 1, divisible):
        height = min(max_size, (max_area // (width // divisible)) * divisible)
        buckets.add((width, height))
        buckets.add((height, width))

        # # make additional resos
        # if width >= height and width - divisible >= min_size:
        #   buckets.add((width - divisible, height))
        #   buckets.add((height, width - divisible))
        # if height >= width and height - divisible >= min_size:
        #   buckets.add((width, height - divisible))
        #   buckets.add((height - divisible, width))

    return sorted(buckets)
1167
-
1168
-
1169
if __name__ == '__main__':
    # Quick self-test: print the bucket list for a 512x768 area budget and
    # verify that every bucket has a unique aspect ratio.
    resos = make_bucket_resolutions((512, 768))
    print(len(resos))
    print(resos)
    aspect_ratios = [w / h for w, h in resos]
    print(aspect_ratios)

    ars = set()
    for ar in aspect_ratios:
        if ar in ars:
            print("error! duplicate ar:", ar)
        ars.add(ar)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/lib/library/train_util.py DELETED
@@ -1,1796 +0,0 @@
1
- # common functions for training
2
-
3
- import argparse
4
- import json
5
- import shutil
6
- import time
7
- from typing import Dict, List, NamedTuple, Tuple
8
- from accelerate import Accelerator
9
- from torch.autograd.function import Function
10
- import glob
11
- import math
12
- import os
13
- import random
14
- import hashlib
15
- import subprocess
16
- from io import BytesIO
17
-
18
- from tqdm import tqdm
19
- import torch
20
- from torchvision import transforms
21
- from transformers import CLIPTokenizer
22
- import diffusers
23
- from diffusers import DDPMScheduler, StableDiffusionPipeline
24
- import albumentations as albu
25
- import numpy as np
26
- from PIL import Image
27
- import cv2
28
- from einops import rearrange
29
- from torch import einsum
30
- import safetensors.torch
31
-
32
- import library.model_util as model_util
33
-
34
# Tokenizer: use the published tokenizer rather than reading it from the checkpoint
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is taken from here; v2 and v2.1 share the same tokenizer spec

# checkpoint file-name templates ({} = run name, {:06d} = epoch number)
EPOCH_STATE_NAME = "{}-{:06d}-state"
EPOCH_FILE_NAME = "{}-{:06d}"
EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
LAST_STATE_NAME = "{}-state"
DEFAULT_EPOCH_NAME = "epoch"
DEFAULT_LAST_OUTPUT_NAME = "last"

# region dataset

IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
50
-
51
-
52
class ImageInfo():
    """Per-image training record: caption/repeat metadata plus fields that the
    dataset pipeline (bucketing, latent caching) fills in later."""

    def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
        # identity / dataset bookkeeping
        self.image_key: str = image_key
        self.num_repeats: int = num_repeats
        self.caption: str = caption
        self.is_reg: bool = is_reg          # True for regularization images
        self.absolute_path: str = absolute_path
        # filled in by bucketing
        self.image_size: Tuple[int, int] = None
        self.resized_size: Tuple[int, int] = None
        self.bucket_reso: Tuple[int, int] = None
        # filled in by latent caching (in-memory or on-disk .npz)
        self.latents: torch.Tensor = None
        self.latents_flipped: torch.Tensor = None
        self.latents_npz: str = None
        self.latents_npz_flipped: str = None
66
-
67
-
68
class BucketManager():
    """Groups images into aspect-ratio buckets.

    Either a predefined set of resolutions is used (upscaling allowed), or —
    with no_upscale — a bucket is derived from each image's own size, rounded
    down to reso_steps.
    """

    def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
        self.no_upscale = no_upscale
        if max_reso is None:
            self.max_reso = None
            self.max_area = None
        else:
            self.max_reso = max_reso
            self.max_area = max_reso[0] * max_reso[1]
        self.min_size = min_size
        self.max_size = max_size
        self.reso_steps = reso_steps

        self.resos = []
        self.reso_to_id = {}
        self.buckets = []  # during preprocessing: (image_key, image); during training: image_key

    def add_image(self, reso, image):
        """Append an entry to the bucket registered for *reso*."""
        bucket_id = self.reso_to_id[reso]
        self.buckets[bucket_id].append(image)

    def shuffle(self):
        """Shuffle each bucket's contents in place."""
        for bucket in self.buckets:
            random.shuffle(bucket)

    def sort(self):
        # Sort buckets by resolution (purely cosmetic, for display and for
        # metadata storage); buckets are reordered and reso_to_id is rebuilt.
        sorted_resos = self.resos.copy()
        sorted_resos.sort()

        sorted_buckets = []
        sorted_reso_to_id = {}
        for i, reso in enumerate(sorted_resos):
            bucket_id = self.reso_to_id[reso]
            sorted_buckets.append(self.buckets[bucket_id])
            sorted_reso_to_id[reso] = i

        self.resos = sorted_resos
        self.buckets = sorted_buckets
        self.reso_to_id = sorted_reso_to_id

    def make_buckets(self):
        """Build the predefined resolution set from this manager's limits."""
        resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
        self.set_predefined_resos(resos)

    def set_predefined_resos(self, resos):
        # Cache the predefined resolutions and their aspect ratios for the
        # nearest-aspect-ratio lookup in select_bucket.
        self.predefined_resos = resos.copy()
        self.predefined_resos_set = set(resos)
        self.predefined_aspect_ratios = np.array([w / h for w, h in resos])

    def add_if_new_reso(self, reso):
        """Register *reso* as a new bucket if it has not been seen yet."""
        if reso not in self.reso_to_id:
            bucket_id = len(self.resos)
            self.reso_to_id[reso] = bucket_id
            self.resos.append(reso)
            self.buckets.append([])
            # print(reso, bucket_id, len(self.buckets))

    def round_to_steps(self, x):
        """Round *x* to the nearest integer, then down to a multiple of reso_steps."""
        x = int(x + .5)
        return x - x % self.reso_steps

    def select_bucket(self, image_width, image_height):
        """Pick the bucket for an image.

        Returns (reso, resized_size, ar_error) where resized_size is the size
        the image should be scaled to before cropping to reso, and ar_error is
        the signed aspect-ratio difference between bucket and image.
        """
        aspect_ratio = image_width / image_height
        if not self.no_upscale:
            # An exact-size bucket may exist (fine tuning preprocessed with
            # no_upscale=True); prefer it over the nearest aspect ratio.
            reso = (image_width, image_height)
            if reso in self.predefined_resos_set:
                pass
            else:
                ar_errors = self.predefined_aspect_ratios - aspect_ratio
                predefined_bucket_id = np.abs(ar_errors).argmin()  # bucket with the smallest aspect-ratio error
                reso = self.predefined_resos[predefined_bucket_id]

            ar_reso = reso[0] / reso[1]
            if aspect_ratio > ar_reso:  # image is wider than the bucket -> match heights
                scale = reso[1] / image_height
            else:
                scale = reso[0] / image_width

            resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
            # print("use predef", image_width, image_height, reso, resized_size)
        else:
            if image_width * image_height > self.max_area:
                # Image exceeds the area budget: decide the bucket assuming an
                # aspect-ratio-preserving shrink.
                resized_width = math.sqrt(self.max_area * aspect_ratio)
                resized_height = self.max_area / resized_width
                assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"

                # Round either the width or the height to reso_steps, keeping
                # whichever rounding yields the smaller aspect-ratio error
                # (same logic as the original bucketing).
                b_width_rounded = self.round_to_steps(resized_width)
                b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
                ar_width_rounded = b_width_rounded / b_height_in_wr

                b_height_rounded = self.round_to_steps(resized_height)
                b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
                ar_height_rounded = b_width_in_hr / b_height_rounded

                # print(b_width_rounded, b_height_in_wr, ar_width_rounded)
                # print(b_width_in_hr, b_height_rounded, ar_height_rounded)

                if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
                    resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
                else:
                    resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
                # print(resized_size)
            else:
                resized_size = (image_width, image_height)  # no resize needed

            # Bucket size is the resized size rounded DOWN to reso_steps
            # (the image is cropped rather than padded).
            bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
            bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
            # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)

            reso = (bucket_width, bucket_height)

        self.add_if_new_reso(reso)

        ar_error = (reso[0] / reso[1]) - aspect_ratio
        return reso, resized_size, ar_error
190
-
191
-
192
class BucketBatchIndex(NamedTuple):
    # Identifies one batch inside one bucket.
    bucket_index: int       # index into BucketManager.buckets
    bucket_batch_size: int  # batch size used for this bucket
    batch_index: int        # which batch within the bucket
196
-
197
-
198
- class BaseDataset(torch.utils.data.Dataset):
199
    def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
        """Common dataset state: caption handling, bucketing placeholders and
        image augmentation setup shared by the concrete dataset classes."""
        super().__init__()
        self.tokenizer: CLIPTokenizer = tokenizer
        self.max_token_length = max_token_length
        self.shuffle_caption = shuffle_caption
        self.shuffle_keep_tokens = shuffle_keep_tokens
        # width/height is used when enable_bucket==False
        self.width, self.height = (None, None) if resolution is None else resolution
        self.face_crop_aug_range = face_crop_aug_range
        self.flip_aug = flip_aug
        self.color_aug = color_aug
        self.debug_dataset = debug_dataset
        self.random_crop = random_crop
        self.token_padding_disabled = False
        self.dataset_dirs_info = {}
        self.reg_dataset_dirs_info = {}
        self.tag_frequency = {}

        # Bucketing state; populated by make_buckets / subclasses.
        self.enable_bucket = False
        self.bucket_manager: BucketManager = None  # not initialized
        self.min_bucket_reso = None
        self.max_bucket_reso = None
        self.bucket_reso_steps = None
        self.bucket_no_upscale = None
        self.bucket_info = None  # for metadata

        # +2 accounts for the BOS/EOS tokens wrapped around the caption.
        self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2

        self.current_epoch: int = 0  # instances seem to be recreated each epoch, so the trainer must set this from outside
        self.dropout_rate: float = 0
        self.dropout_every_n_epochs: int = None
        self.tag_dropout_rate: float = 0

        # augmentation
        flip_p = 0.5 if flip_aug else 0.0
        if color_aug:
            # Fairly weak color augmentation: brightness/contrast would alter the
            # image's min/max pixel values, which is assumed undesirable, so only
            # gamma/hue are touched.
            self.aug = albu.Compose([
                albu.OneOf([
                    albu.HueSaturationValue(8, 0, 0, p=.5),
                    albu.RandomGamma((95, 105), p=.5),
                ], p=.33),
                albu.HorizontalFlip(p=flip_p)
            ], p=1.)
        elif flip_aug:
            self.aug = albu.Compose([
                albu.HorizontalFlip(p=flip_p)
            ], p=1.)
        else:
            self.aug = None

        # Standard tensor conversion + [-1, 1] normalization.
        self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])

        self.image_data: Dict[str, ImageInfo] = {}

        # textual-inversion string replacements (see add_replacement/process_caption)
        self.replacements = {}
255
-
256
    def set_current_epoch(self, epoch):
        """Record the current epoch (instances are recreated per epoch, so the trainer pushes this in)."""
        self.current_epoch = epoch
258
-
259
    def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
        """Configure caption/tag dropout used by process_caption."""
        # Not passed to the constructor so Textual Inversion code does not have
        # to care about it (or so the story goes).
        self.dropout_rate = dropout_rate
        self.dropout_every_n_epochs = dropout_every_n_epochs
        self.tag_dropout_rate = tag_dropout_rate
264
-
265
- def set_tag_frequency(self, dir_name, captions):
266
- frequency_for_dir = self.tag_frequency.get(dir_name, {})
267
- self.tag_frequency[dir_name] = frequency_for_dir
268
- for caption in captions:
269
- for tag in caption.split(","):
270
- if tag and not tag.isspace():
271
- tag = tag.lower()
272
- frequency = frequency_for_dir.get(tag, 0)
273
- frequency_for_dir[tag] = frequency + 1
274
-
275
    def disable_token_padding(self):
        """Disable max-length padding when tokenizing captions."""
        self.token_padding_disabled = True
277
-
278
    def add_replacement(self, str_from, str_to):
        """Register a caption substring replacement (used for textual inversion)."""
        self.replacements[str_from] = str_to
280
-
281
    def process_caption(self, caption):
        """Apply caption dropout, tag shuffling/dropout and textual-inversion
        replacements to a caption, returning the transformed string."""
        # Decide dropout here: tag dropout also lives in this method, so this
        # is the natural place for the whole-caption decision too.
        is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
        # NOTE(review): `current_epoch % n == 0` is also true at epoch 0, so the
        # caption is dropped for the entire first epoch whenever
        # dropout_every_n_epochs is set — confirm this is intended.
        is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0

        if is_drop_out:
            caption = ""
        else:
            if self.shuffle_caption or self.tag_dropout_rate > 0:
                def dropout_tags(tokens):
                    # Independently drop each tag with probability tag_dropout_rate.
                    if self.tag_dropout_rate <= 0:
                        return tokens
                    l = []
                    for token in tokens:
                        if random.random() >= self.tag_dropout_rate:
                            l.append(token)
                    return l

                tokens = [t.strip() for t in caption.strip().split(",")]
                if self.shuffle_keep_tokens is None:
                    if self.shuffle_caption:
                        random.shuffle(tokens)

                    tokens = dropout_tags(tokens)
                else:
                    # Keep the first shuffle_keep_tokens tags fixed; shuffle and
                    # drop only the remainder.
                    if len(tokens) > self.shuffle_keep_tokens:
                        keep_tokens = tokens[:self.shuffle_keep_tokens]
                        tokens = tokens[self.shuffle_keep_tokens:]

                        if self.shuffle_caption:
                            random.shuffle(tokens)

                        tokens = dropout_tags(tokens)

                        tokens = keep_tokens + tokens
                caption = ", ".join(tokens)

        # textual inversion support
        for str_from, str_to in self.replacements.items():
            if str_from == "":
                # empty source string means: replace the whole caption
                if type(str_to) == list:
                    caption = random.choice(str_to)
                else:
                    caption = str_to
            else:
                caption = caption.replace(str_from, str_to)

        return caption
330
-
331
    def get_input_ids(self, caption):
        """Tokenize a caption; for extended token lengths, return the ids
        re-packed as stacked chunks of shape (n_chunks, model_max_length)."""
        input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
                                   max_length=self.tokenizer_max_length, return_tensors="pt").input_ids

        if self.tokenizer_max_length > self.tokenizer.model_max_length:
            input_ids = input_ids.squeeze(0)
            iids_list = []
            if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
                # v1
                # Beyond 77 tokens the sequence is "<BOS> .... <EOS> <EOS> <EOS>"
                # (e.g. 227 total), so convert it into consecutive
                # "<BOS>...<EOS>" chunks.
                # (A1111's code splits on "," etc.; here we simply split by position.)
                for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):  # (1, 152, 75)
                    ids_chunk = (input_ids[0].unsqueeze(0),
                                 input_ids[i:i + self.tokenizer.model_max_length - 2],
                                 input_ids[-1].unsqueeze(0))
                    ids_chunk = torch.cat(ids_chunk)
                    iids_list.append(ids_chunk)
            else:
                # v2
                # Beyond 77 tokens the sequence is "<BOS> .... <EOS> <PAD> <PAD>..."
                # so convert it into consecutive "<BOS>...<EOS> <PAD> <PAD> ..." chunks.
                for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
                    ids_chunk = (input_ids[0].unsqueeze(0),  # BOS
                                 input_ids[i:i + self.tokenizer.model_max_length - 2],
                                 input_ids[-1].unsqueeze(0))  # PAD or EOS
                    ids_chunk = torch.cat(ids_chunk)

                    # If the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing to do.
                    # If it ends with x <PAD/EOS>, force the last token to <EOS>
                    # (a no-op when it was already x <EOS>).
                    if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
                        ids_chunk[-1] = self.tokenizer.eos_token_id
                    # If the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
                    if ids_chunk[1] == self.tokenizer.pad_token_id:
                        ids_chunk[1] = self.tokenizer.eos_token_id

                    iids_list.append(ids_chunk)

            input_ids = torch.stack(iids_list)  # 3,77
        return input_ids
369
-
370
    def register_image(self, info: ImageInfo):
        """Add an ImageInfo record, keyed by its image_key."""
        self.image_data[info.image_key] = info
372
-
373
def make_buckets(self):
    """Assign every registered image to a resolution bucket and build the batch index.

    Must be called even when bucketing is disabled (a single bucket of
    (width, height) is created in that case). min_bucket_reso / max_bucket_reso
    are ignored when enable_bucket is False.
    """
    print("loading image sizes.")
    for info in tqdm(self.image_data.values()):
        if info.image_size is None:
            info.image_size = self.get_image_size(info.absolute_path)

    if self.enable_bucket:
        print("make buckets")
    else:
        print("prepare dataset")

    # create the buckets and distribute the images among them
    if self.enable_bucket:
        if self.bucket_manager is None:  # may already be initialized from fine-tuning metadata
            self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height),
                                                self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
            if not self.bucket_no_upscale:
                self.bucket_manager.make_buckets()
            else:
                print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")

        img_ar_errors = []
        for image_info in self.image_data.values():
            image_width, image_height = image_info.image_size
            image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)

            # print(image_info.image_key, image_info.bucket_reso)
            img_ar_errors.append(abs(ar_error))

        self.bucket_manager.sort()
    else:
        self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
        self.bucket_manager.set_predefined_resos([(self.width, self.height)])  # only one fixed-size bucket
        for image_info in self.image_data.values():
            image_width, image_height = image_info.image_size
            image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)

    for image_info in self.image_data.values():
        for _ in range(image_info.num_repeats):
            self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)

    # report and store per-bucket statistics
    if self.enable_bucket:
        self.bucket_info = {"buckets": {}}
        print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
        for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
            count = len(bucket)
            if count > 0:
                self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
                print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")

        img_ar_errors = np.array(img_ar_errors)
        mean_img_ar_error = np.mean(np.abs(img_ar_errors))
        self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
        print(f"mean ar error (without repeats): {mean_img_ar_error}")

    # build the index used to shuffle the dataset between epochs.
    # BUG FIX: annotation was `List(BucketBatchIndex)` — a call, not a
    # subscript. It is never evaluated at function scope (PEP 526), but it is
    # an invalid type expression and breaks static checkers; use List[...].
    self.buckets_indices: List[BucketBatchIndex] = []
    for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
        batch_count = int(math.ceil(len(bucket) / self.batch_size))
        for batch_index in range(batch_count):
            self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))

    # NOTE: an earlier experiment limited the batch size per bucket to the
    # number of distinct images (to avoid batches filled with one image after
    # fine-grained bucketing). It was reverted: training steps are drawn
    # randomly, so repeated images within a batch are considered harmless.

    self.shuffle_buckets()
    self._length = len(self.buckets_indices)
459
def shuffle_buckets(self):
    """Re-randomize the batch order and the image order inside each bucket."""
    indices = self.buckets_indices
    random.shuffle(indices)
    self.bucket_manager.shuffle()
463
def load_image(self, image_path):
    """Read the file at `image_path` and return it as an RGB uint8 numpy array (H, W, 3)."""
    pil_image = Image.open(image_path)
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    return np.array(pil_image, np.uint8)
470
def trim_and_resize_if_required(self, image, reso, resized_size):
    """Resize `image` to `resized_size` (w, h) when needed, then crop it down to `reso` (w, h).

    The crop is centered unless self.random_crop is set, in which case the
    offset is chosen uniformly at random.
    """
    h, w = image.shape[0:2]

    if w != resized_size[0] or h != resized_size[1]:
        # resize first; cv2 is used because INTER_AREA is wanted here
        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)

    h, w = image.shape[0:2]
    if w > reso[0]:
        excess = w - reso[0]
        start = random.randint(0, excess) if self.random_crop else excess // 2
        image = image[:, start:start + reso[0]]
    if h > reso[1]:
        excess = h - reso[1]
        start = random.randint(0, excess) if self.random_crop else excess // 2
        image = image[start:start + reso[1]]

    assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
    return image
492
def cache_latents(self, vae):
    """Pre-encode all registered images to VAE latents so training can skip the VAE pass.

    Latents come from pre-computed .npz files when present, otherwise the image
    is loaded, cropped/resized to its bucket resolution and encoded now. With
    flip_aug enabled the horizontally-flipped latents are cached as well.
    """
    # TODO: speed this up (runs serially, one VAE call per image)
    print("caching latents.")
    for info in tqdm(self.image_data.values()):
        # prefer the on-disk npz cache when available
        if info.latents_npz is not None:
            info.latents = self.load_latents_from_npz(info, False)
            info.latents = torch.FloatTensor(info.latents)
            info.latents_flipped = self.load_latents_from_npz(info, True)  # might be None
            if info.latents_flipped is not None:
                info.latents_flipped = torch.FloatTensor(info.latents_flipped)
            continue

        image = self.load_image(info.absolute_path)
        image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)

        img_tensor = self.image_transforms(image)
        img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
        info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")

        if self.flip_aug:
            image = image[:, ::-1].copy()  # cannot convert to Tensor without copy
            img_tensor = self.image_transforms(image)
            img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
            info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
517
def get_image_size(self, image_path):
    """Return the (width, height) of the image file at `image_path`."""
    return Image.open(image_path).size
521
def load_image_with_face_info(self, image_path: str):
    """Load an image and parse optional face coordinates from its file name.

    When self.face_crop_aug_range is enabled and the file name (without
    extension) ends in `..._cx_cy_w_h`, those four integers are returned as
    the face center and size; otherwise zeros.
    Returns (image, face_cx, face_cy, face_w, face_h).
    """
    img = self.load_image(image_path)

    cx = cy = fw = fh = 0
    if self.face_crop_aug_range is not None:
        parts = os.path.splitext(os.path.basename(image_path))[0].split('_')
        if len(parts) >= 5:
            cx, cy, fw, fh = (int(t) for t in parts[-4:])

    return img, cx, cy, fw, fh
535
# crop nicely around the face position
def crop_target(self, image, face_cx, face_cy, face_w, face_h):
    """Scale and crop `image` to exactly (self.width, self.height), keeping the face in frame.

    face_cx/face_cy are the face center and face_w/face_h its size, in source
    pixels. The scale is chosen so the face size lands inside
    self.face_crop_aug_range; the crop is centered on the face with optional
    random jitter (or fully random placement when self.random_crop is set).
    """
    height, width = image.shape[0:2]
    if height == self.height and width == self.width:
        return image

    # the image is larger than the model input size, so scale it down
    face_size = max(face_w, face_h)
    min_scale = max(self.height / height, self.width / width)  # smallest scale that still fills the model input
    min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1])))  # largest allowed face
    max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0])))  # smallest allowed face
    if min_scale >= max_scale:  # degenerate range (min == max)
        scale = min_scale
    else:
        scale = random.uniform(min_scale, max_scale)

    nh = int(height * scale + .5)
    nw = int(width * scale + .5)
    assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
    image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
    face_cx = int(face_cx * scale + .5)
    face_cy = int(face_cy * scale + .5)
    height, width = nh, nw

    # crop e.g. 448x640 centered on the face, one axis at a time
    for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
        p1 = face_p - target_size // 2  # crop origin that centers the face

        if self.random_crop:
            # shift randomly, biased toward keeping the face centered, so some
            # background is included too.
            # (renamed from `range`, which shadowed the builtin)
            spread = max(length - face_p, face_p)  # longer distance from face center to an image edge
            p1 = p1 + (random.randint(0, spread) + random.randint(0, spread)) - spread  # triangular jitter in -spread..+spread
        else:
            # jitter slightly, but only when a real (non-degenerate) range is given
            if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
                if face_size > self.size // 10 and face_size >= 40:
                    p1 = p1 + random.randint(-face_size // 20, +face_size // 20)

        p1 = max(0, min(p1, length - target_size))  # clamp into the image

        if axis == 0:
            image = image[p1:p1 + target_size, :]
        else:
            image = image[:, p1:p1 + target_size]

    return image
582
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
    """Return the cached latents array ('arr_0') for this image, or None when no cache file is set.

    `flipped` selects the horizontally-flipped cache file.
    """
    path = image_info.latents_npz_flipped if flipped else image_info.latents_npz
    if path is None:
        return None
    return np.load(path)['arr_0']
588
def __len__(self):
    """Number of batches per epoch (one entry per bucket batch index)."""
    return self._length
591
def __getitem__(self, index):
    """Return one training batch — this Dataset is indexed per batch, not per image.

    Keys of the returned dict: 'loss_weights', 'input_ids', 'images' (or None
    when latents are cached), 'latents' (or None), plus 'image_keys' and
    'captions' in debug mode.
    """
    if index == 0:
        self.shuffle_buckets()

    bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
    bucket_batch_size = self.buckets_indices[index].bucket_batch_size
    image_index = self.buckets_indices[index].batch_index * bucket_batch_size

    loss_weights = []
    captions = []
    input_ids_list = []
    latents_list = []
    images = []

    for image_key in bucket[image_index:image_index + bucket_batch_size]:
        image_info = self.image_data[image_key]
        loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)

        # resolve image / latents for this sample
        if image_info.latents is not None:
            # in-memory latents cache; use the flipped variant half the time
            latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
            image = None
        elif image_info.latents_npz is not None:
            # on-disk latents cache
            latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
            latents = torch.FloatTensor(latents)
            image = None
        else:
            # load the image and crop it if necessary
            img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
            im_h, im_w = img.shape[0:2]

            if self.enable_bucket:
                img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
            else:
                if face_cx > 0:  # face position info is available
                    img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
                elif im_h > self.height or im_w > self.width:
                    assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
                    if im_h > self.height:
                        p = random.randint(0, im_h - self.height)
                        img = img[p:p + self.height]
                    if im_w > self.width:
                        p = random.randint(0, im_w - self.width)
                        img = img[:, p:p + self.width]

                im_h, im_w = img.shape[0:2]
                assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"

            # augmentation
            if self.aug is not None:
                img = self.aug(image=img)['image']

            latents = None
            image = self.image_transforms(img)  # becomes a torch.Tensor in -1.0..1.0

        images.append(image)
        latents_list.append(latents)

        caption = self.process_caption(image_info.caption)
        captions.append(caption)
        if not self.token_padding_disabled:  # this option might be omitted in future
            input_ids_list.append(self.get_input_ids(caption))

    example = {}
    example['loss_weights'] = torch.FloatTensor(loss_weights)

    if self.token_padding_disabled:
        # padding=True means pad in the batch
        example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
    else:
        # batch processing seems to be good
        example['input_ids'] = torch.stack(input_ids_list)

    if images[0] is not None:
        images = torch.stack(images)
        images = images.to(memory_format=torch.contiguous_format).float()
    else:
        images = None
    example['images'] = images

    example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None

    if self.debug_dataset:
        example['image_keys'] = bucket[image_index:image_index + self.batch_size]
        example['captions'] = captions
    return example
678
-
679
class DreamBoothDataset(BaseDataset):
    """Dataset for DreamBooth training: `<repeats>_<caption>` folders of training
    images plus optional regularization images, with per-image caption files."""

    def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
        super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
                         resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)

        assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

        self.batch_size = batch_size
        self.size = min(self.width, self.height)  # shorter edge
        self.prior_loss_weight = prior_loss_weight
        self.latents_cache = None

        self.enable_bucket = enable_bucket
        if self.enable_bucket:
            assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
            assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
            self.min_bucket_reso = min_bucket_reso
            self.max_bucket_reso = max_bucket_reso
            self.bucket_reso_steps = bucket_reso_steps
            self.bucket_no_upscale = bucket_no_upscale
        else:
            self.min_bucket_reso = None
            self.max_bucket_reso = None
            self.bucket_reso_steps = None  # this information is not used
            self.bucket_no_upscale = False

        def read_caption(img_path):
            # build the candidate caption file names (with and without the
            # trailing face-detection coordinate suffix)
            base_name = os.path.splitext(img_path)[0]
            base_name_face_det = base_name
            tokens = base_name.split("_")
            if len(tokens) >= 5:
                base_name_face_det = "_".join(tokens[:-4])
            cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]

            caption = None
            for cap_path in cap_paths:
                if os.path.isfile(cap_path):
                    with open(cap_path, "rt", encoding='utf-8') as f:
                        try:
                            lines = f.readlines()
                        except UnicodeDecodeError as e:
                            print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
                            raise e
                        assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
                        caption = lines[0].strip()
                    break
            return caption

        def load_dreambooth_dir(dir):
            # a subdirectory named `<n>_<caption>`; returns (repeats, paths, captions)
            if not os.path.isdir(dir):
                # print(f"ignore file: {dir}")
                return 0, [], []

            tokens = os.path.basename(dir).split('_')
            try:
                n_repeats = int(tokens[0])
            except ValueError as e:
                print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
                return 0, [], []

            caption_by_folder = '_'.join(tokens[1:])
            img_paths = glob_images(dir, "*")
            print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")

            # read the per-image caption file when it exists, otherwise fall
            # back to the folder-name caption
            captions = []
            for img_path in img_paths:
                cap_for_img = read_caption(img_path)
                captions.append(caption_by_folder if cap_for_img is None else cap_for_img)

            self.set_tag_frequency(os.path.basename(dir), captions)  # record tag frequency

            return n_repeats, img_paths, captions

        print("prepare train images.")
        train_dirs = os.listdir(train_data_dir)
        num_train_images = 0
        for dir in train_dirs:
            n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
            num_train_images += n_repeats * len(img_paths)

            for img_path, caption in zip(img_paths, captions):
                info = ImageInfo(img_path, n_repeats, caption, False, img_path)
                self.register_image(info)

            self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}

        print(f"{num_train_images} train images with repeating.")
        self.num_train_images = num_train_images

        # count regularization images and repeat them to match the number of
        # training images
        num_reg_images = 0
        if reg_data_dir:
            print("prepare reg images.")
            reg_infos: List[ImageInfo] = []

            reg_dirs = os.listdir(reg_data_dir)
            for dir in reg_dirs:
                n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
                num_reg_images += n_repeats * len(img_paths)

                for img_path, caption in zip(img_paths, captions):
                    info = ImageInfo(img_path, n_repeats, caption, True, img_path)
                    reg_infos.append(info)

                self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}

            print(f"{num_reg_images} reg images.")
            if num_train_images < num_reg_images:
                print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")

            if num_reg_images == 0:
                print("no regularization images / 正則化画像が見つかりませんでした")
            else:
                # compute num_repeats: the count is small anyway, so a loop is fine
                n = 0
                first_loop = True
                while n < num_train_images:
                    for info in reg_infos:
                        if first_loop:
                            self.register_image(info)
                            n += info.num_repeats
                        else:
                            info.num_repeats += 1
                            n += 1
                        if n >= num_train_images:
                            break
                    first_loop = False

        self.num_reg_images = num_reg_images
811
-
812
class FineTuningDataset(BaseDataset):
    """Dataset driven by a JSON metadata file (captions, tags, optional
    pre-computed bucket resolutions and cached .npz latents per image)."""

    def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
        super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
                         resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)

        # load the metadata file
        if os.path.exists(json_file_name):
            print(f"loading existing metadata: {json_file_name}")
            with open(json_file_name, "rt", encoding='utf-8') as f:
                metadata = json.load(f)
        else:
            raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")

        self.metadata = metadata
        self.train_data_dir = train_data_dir
        self.batch_size = batch_size

        tags_list = []
        for image_key, img_md in metadata.items():
            # resolve the absolute path for this image key
            if os.path.exists(image_key):
                abs_path = image_key
            else:
                # fairly rough, but no better idea: search the train dir
                abs_path = glob_images(train_data_dir, image_key)
                assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
                abs_path = abs_path[0]

            caption = img_md.get('caption')
            tags = img_md.get('tags')
            if caption is None:
                caption = tags
            elif tags is not None and len(tags) > 0:
                caption = caption + ', ' + tags
            tags_list.append(tags)
            assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"

            image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
            image_info.image_size = img_md.get('train_resolution')

            if not self.color_aug and not self.random_crop:
                # if npz exists, use them
                image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)

            self.register_image(image_info)
        self.num_train_images = len(metadata) * dataset_repeats
        self.num_reg_images = 0

        # TODO do not record tag freq when no tag
        self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
        self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}

        # check existence of all npz files
        use_npz_latents = not (self.color_aug or self.random_crop)
        if use_npz_latents:
            npz_any = False
            npz_all = True
            for image_info in self.image_data.values():
                has_npz = image_info.latents_npz is not None
                npz_any = npz_any or has_npz

                if self.flip_aug:
                    has_npz = has_npz and image_info.latents_npz_flipped is not None
                npz_all = npz_all and has_npz

                if npz_any and not npz_all:
                    break

            if not npz_any:
                use_npz_latents = False
                print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
            elif not npz_all:
                use_npz_latents = False
                print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
                if self.flip_aug:
                    print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
        # else:
        #     print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")

        # collect image sizes / bucket resolutions recorded in the metadata
        sizes = set()
        resos = set()
        for image_info in self.image_data.values():
            if image_info.image_size is None:
                sizes = None  # not calculated
                break
            sizes.add(image_info.image_size[0])
            sizes.add(image_info.image_size[1])
            resos.add(tuple(image_info.image_size))

        if sizes is None:
            # metadata has no bucket info: buckets will be built later in make_buckets
            if use_npz_latents:
                use_npz_latents = False
                print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")

            assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"

            self.enable_bucket = enable_bucket
            if self.enable_bucket:
                self.min_bucket_reso = min_bucket_reso
                self.max_bucket_reso = max_bucket_reso
                self.bucket_reso_steps = bucket_reso_steps
                self.bucket_no_upscale = bucket_no_upscale
        else:
            if not enable_bucket:
                print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
            print("using bucket info in metadata / メタデータ内のbucket情報を使います")
            self.enable_bucket = True

            assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"

            # initialize bucket info here; make_buckets will not recreate it
            self.bucket_manager = BucketManager(False, None, None, None, None)
            self.bucket_manager.set_predefined_resos(resos)

        # clear npz info when the cache will not be used
        if not use_npz_latents:
            for image_info in self.image_data.values():
                image_info.latents_npz = image_info.latents_npz_flipped = None

    def image_key_to_npz_file(self, image_key):
        """Resolve (npz_path, flipped_npz_path) for an image key.

        Tries the key as a full path first, then relative to train_data_dir;
        missing files yield None entries.
        """
        base_name = os.path.splitext(image_key)[0]
        npz_file_norm = base_name + '.npz'

        if os.path.exists(npz_file_norm):
            # image_key is full path
            npz_file_flip = base_name + '_flip.npz'
            if not os.path.exists(npz_file_flip):
                npz_file_flip = None
            return npz_file_norm, npz_file_flip

        # image_key is relative path
        npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
        npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')

        if not os.path.exists(npz_file_norm):
            npz_file_norm = None
            npz_file_flip = None
        elif not os.path.exists(npz_file_flip):
            npz_file_flip = None

        return npz_file_norm, npz_file_flip
955
-
956
def debug_dataset(train_dataset, show_input_ids=False):
    """Iterate the dataset and print batch contents for inspection.

    On Windows each image is also displayed via OpenCV; ESC (key 27) aborts.
    NOTE(review): cv2.waitKey appears to run even off-Windows here — only
    imshow is guarded by os.name == 'nt'; confirm against upstream.
    """
    print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
    print("Escape for exit. / Escキーで中断、終了します")

    train_dataset.set_current_epoch(1)
    k = 0
    for i, example in enumerate(train_dataset):
        if example['latents'] is not None:
            print(f"sample has latents from npz file: {example['latents'].size()}")
        for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
            print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
            if show_input_ids:
                print(f"input ids: {iid}")
            if example['images'] is not None:
                im = example['images'][j]
                print(f"image size: {im.size()}")
                # tensor in -1..1 back to displayable uint8 image
                im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
                im = np.transpose(im, (1, 2, 0))  # c,H,W -> H,W,c
                im = im[:, :, ::-1]  # RGB -> BGR (OpenCV)
                if os.name == 'nt':  # only windows
                    cv2.imshow("img", im)
                k = cv2.waitKey()
                cv2.destroyAllWindows()
                if k == 27:
                    break
        if k == 27 or (example['images'] is None and i >= 8):
            break
984
-
985
def glob_images(directory, base="*"):
    """Collect image files in `directory` whose stem matches `base`, for every extension in IMAGE_EXTENSIONS."""
    found = []
    for ext in IMAGE_EXTENSIONS:
        if base == '*':
            # escape only the directory so the wildcard stays live
            pattern = os.path.join(glob.escape(directory), base + ext)
        else:
            # escape the whole path: `base` is a literal file name here
            pattern = glob.escape(os.path.join(directory, base + ext))
        found.extend(glob.glob(pattern))
    return found
996
-
997
def glob_images_pathlib(dir_path, recursive):
    """Collect image paths under `dir_path` (a pathlib.Path), optionally descending into subdirectories."""
    matcher = dir_path.rglob if recursive else dir_path.glob
    paths = []
    for ext in IMAGE_EXTENSIONS:
        paths += list(matcher('*' + ext))
    return paths
1009
- # endregion
1010
-
1011
-
1012
- # region モジュール入れ替え部
1013
- """
1014
- 高速化のためのモジュール入れ替え
1015
- """
1016
-
1017
- # FlashAttentionを使うCrossAttention
1018
- # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
1019
- # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
1020
-
1021
# constants

EPSILON = 1e-6  # lower bound used when clamping softmax row sums (avoids division by zero)
1025
- # helper functions
1026
-
1027
-
1028
def exists(val):
    """True when `val` is not None."""
    return val is not None
1031
-
1032
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is not None:
        return val
    return d
1035
-
1036
def model_hash(filename):
    """Old model hash used by stable-diffusion-webui.

    SHA-256 over the 64 KiB at file offset 1 MiB, truncated to 8 hex digits;
    returns 'NOFILE' when the file does not exist.
    """
    try:
        with open(filename, "rb") as file:
            digest = hashlib.sha256()
            file.seek(0x100000)
            digest.update(file.read(0x10000))
        return digest.hexdigest()[0:8]
    except FileNotFoundError:
        return 'NOFILE'
1048
-
1049
def calculate_sha256(filename):
    """New model hash used by stable-diffusion-webui: SHA-256 of the whole file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    chunk_size = 1024 * 1024

    with open(filename, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            digest.update(chunk)

    return digest.hexdigest()
1060
-
1061
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    Only metadata keys starting with "ss_" (the immutable training metadata)
    are included, because writing user metadata to the file later would change
    the result of sd_models.model_hash(). Returns (model_hash, legacy_hash).
    """
    training_metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    serialized = safetensors.torch.save(tensors, training_metadata)
    buffer = BytesIO(serialized)

    model_hash = addnet_hash_safetensors(buffer)
    legacy_hash = addnet_hash_legacy(buffer)
    return model_hash, legacy_hash
1077
-
1078
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files.

    SHA-256 over the 64 KiB at offset 1 MiB of the buffer, first 8 hex digits.
    """
    digest = hashlib.sha256()
    b.seek(0x100000)
    digest.update(b.read(0x10000))
    return digest.hexdigest()[0:8]
1086
-
1087
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files.

    Skips the 8-byte little-endian header-length prefix and the JSON header it
    describes, then hashes the remaining tensor payload in 1 MiB chunks.
    """
    digest = hashlib.sha256()
    chunk_size = 1024 * 1024

    b.seek(0)
    header_len = int.from_bytes(b.read(8), "little")

    b.seek(header_len + 8)
    while True:
        chunk = b.read(chunk_size)
        if not chunk:
            break
        digest.update(chunk)

    return digest.hexdigest()
1103
-
1104
def get_git_revision_hash() -> str:
    """Return the git HEAD commit hash of the repository containing this file.

    Falls back to "(unknown)" when git is unavailable or the file is not
    inside a git work tree.
    """
    try:
        return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__)).decode('ascii').strip()
    except Exception:
        # was a bare `except:`, which would also swallow SystemExit /
        # KeyboardInterrupt; Exception keeps the same fallback behavior
        # for the expected failures (missing git, not a repo, decode errors)
        return "(unknown)"
1110
-
1111
- # flash attention forwards and backwards
1112
-
1113
- # https://arxiv.org/abs/2205.14135
1114
-
1115
-
1116
- class FlashAttentionFunction(torch.autograd.function.Function):
1117
@ staticmethod
@ torch.no_grad()
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
    """ Algorithm 2 in the paper """
    # Memory-efficient (tiled) attention forward pass: iterates over row and
    # column blocks of the attention matrix while maintaining running softmax
    # statistics (per-row max and sum), so the full q·kᵀ matrix is never
    # materialized.

    device = q.device
    dtype = q.dtype
    max_neg_value = -torch.finfo(q.dtype).max
    # offset applied to the causal mask when k is longer than q
    qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

    o = torch.zeros_like(q)
    all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
    all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)

    scale = (q.shape[-1] ** -0.5)

    if not exists(mask):
        mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
    else:
        mask = rearrange(mask, 'b n -> b 1 1 n')
        mask = mask.split(q_bucket_size, dim=-1)

    row_splits = zip(
        q.split(q_bucket_size, dim=-2),
        o.split(q_bucket_size, dim=-2),
        mask,
        all_row_sums.split(q_bucket_size, dim=-2),
        all_row_maxes.split(q_bucket_size, dim=-2),
    )

    for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
        q_start_index = ind * q_bucket_size - qk_len_diff

        col_splits = zip(
            k.split(k_bucket_size, dim=-2),
            v.split(k_bucket_size, dim=-2),
        )

        for k_ind, (kc, vc) in enumerate(col_splits):
            k_start_index = k_ind * k_bucket_size

            attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

            if exists(row_mask):
                attn_weights.masked_fill_(~row_mask, max_neg_value)

            if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
                                         device=device).triu(q_start_index - k_start_index + 1)
                attn_weights.masked_fill_(causal_mask, max_neg_value)

            # numerically stable block softmax: subtract the block row max
            block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
            attn_weights -= block_row_maxes
            exp_weights = torch.exp(attn_weights)

            if exists(row_mask):
                exp_weights.masked_fill_(~row_mask, 0.)

            block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

            new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

            exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

            exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
            exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

            new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

            # online softmax update: rescale the accumulated output by the
            # change in the running max/sum, then add this block's contribution
            oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

            row_maxes.copy_(new_row_maxes)
            row_sums.copy_(new_row_sums)

    ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
    ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

    return o
1196
    @ staticmethod
    @ torch.no_grad()
    def backward(ctx, do):
        """ Algorithm 4 in the paper """
        # Memory-efficient (FlashAttention-style) backward pass: the attention
        # probabilities are recomputed tile by tile from the saved softmax
        # statistics (row sums `l`, row maxes `m`) instead of being stored.
        #
        # do: gradient of the loss w.r.t. the forward output o.
        # Returns gradients for (q, k, v) and None for the four non-tensor
        # forward arguments (mask, causal, q_bucket_size, k_bucket_size).

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        # most negative finite value for this dtype; used to mask out
        # disallowed (future) positions before exponentiation
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        # iterate over query-row chunks together with the saved statistics and
        # the matching slice of the dq accumulator (split() returns views, so
        # the in-place adds below write into dq directly)
        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2)
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # recompute the raw attention scores for this (row, col) tile
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
                                             device=device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # softmax numerator, renormalized with the saved row maxes
                exp_attn_weights = torch.exp(attn_weights - mc)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.)

                # p: the softmax probabilities for this tile (lc = row sums)
                p = exp_attn_weights / lc

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                # D = rowwise dot(do, o); ds = gradient w.r.t. the raw scores
                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                # accumulate in place into the views created by split()
                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None
-
1265
-
1266
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
    """Install a memory-efficient CrossAttention implementation.

    ``mem_eff_attn`` takes precedence over ``xformers``; when both are falsy
    nothing is patched.  Note the replacement is applied globally to the
    diffusers CrossAttention class, not just to the passed ``unet``.
    """
    if not mem_eff_attn:
        if xformers:
            replace_unet_cross_attn_to_xformers()
        return
    replace_unet_cross_attn_to_memory_efficient()
-
1272
-
1273
def replace_unet_cross_attn_to_memory_efficient():
    # Monkey-patch diffusers' CrossAttention.forward with a FlashAttention-style
    # tiled implementation (pure PyTorch, no xformers dependency).
    print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
    flash_func = FlashAttentionFunction

    def forward_flash_attn(self, x, context=None, mask=None):
        # tile sizes for the chunked attention computation
        q_bucket_size = 512
        k_bucket_size = 1024

        h = self.heads
        q = self.to_q(x)

        # self-attention when no encoder context is given
        context = context if context is not None else x
        context = context.to(x.dtype)

        # an optional hypernetwork produces separate key/value contexts
        if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
            context_k, context_v = self.hypernetwork.forward(x, context)
            context_k = context_k.to(x.dtype)
            context_v = context_v.to(x.dtype)
        else:
            context_k = context
            context_v = context

        k = self.to_k(context_k)
        v = self.to_v(context_v)
        del context, x  # drop references early to reduce peak memory

        # (batch, tokens, heads*dim) -> (batch, heads, tokens, dim)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')

        # diffusers 0.7.0+ turned to_out into a ModuleList (linear, dropout)
        out = self.to_out[0](out)
        out = self.to_out[1](out)
        return out

    diffusers.models.attention.CrossAttention.forward = forward_flash_attn
1311
-
1312
-
1313
def replace_unet_cross_attn_to_xformers():
    # Monkey-patch diffusers' CrossAttention.forward to dispatch through
    # xformers' memory-efficient attention kernels.
    print("Replace CrossAttention.forward to use xformers")
    try:
        import xformers.ops
    except ImportError:
        raise ImportError("No xformers / xformersがインストールされていないようです")

    def forward_xformers(self, x, context=None, mask=None):
        h = self.heads
        q_in = self.to_q(x)

        # self-attention when no encoder context is given
        context = default(context, x)
        context = context.to(x.dtype)

        # an optional hypernetwork produces separate key/value contexts
        if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
            context_k, context_v = self.hypernetwork.forward(x, context)
            context_k = context_k.to(x.dtype)
            context_v = context_v.to(x.dtype)
        else:
            context_k = context
            context_v = context

        k_in = self.to_k(context_k)
        v_in = self.to_v(context_v)

        # xformers expects (batch, tokens, heads, dim)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
        del q_in, k_in, v_in

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # picks the best available kernel

        out = rearrange(out, 'b n h d -> b n (h d)', h=h)

        # diffusers 0.7.0+ turned to_out into a ModuleList (linear, dropout)
        out = self.to_out[0](out)
        out = self.to_out[1](out)
        return out

    diffusers.models.attention.CrossAttention.forward = forward_xformers
1354
- # endregion
1355
-
1356
-
1357
- # region arguments
1358
-
1359
def add_sd_models_arguments(parser: argparse.ArgumentParser):
    """Register the options that select the base Stable Diffusion model."""
    # model-lineage switches
    parser.add_argument(
        "--v2", action='store_true',
        help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
    parser.add_argument(
        "--v_parameterization", action='store_true',
        help='enable v-parameterization training / v-parameterization学習を有効にする')
    # where the weights come from (Diffusers directory, hub id, or checkpoint file)
    parser.add_argument(
        "--pretrained_model_name_or_path", type=str, default=None,
        help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
1367
-
1368
-
1369
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
    """Register the training options shared by all trainer scripts.

    Args:
        parser: the parser to extend in place.
        support_dreambooth: also register DreamBooth-only options
            (currently --prior_loss_weight).
    """
    # ----- output / checkpointing -----
    parser.add_argument("--output_dir", type=str, default=None,
                        help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
    parser.add_argument("--output_name", type=str, default=None,
                        help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
    parser.add_argument("--save_precision", type=str, default=None,
                        choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
    parser.add_argument("--save_every_n_epochs", type=int, default=None,
                        help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
    parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
                        help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
    parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
    parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
                        help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
    parser.add_argument("--save_state", action="store_true",
                        help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
    parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")

    # ----- data loading / attention implementation -----
    parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
    parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
                        help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
    parser.add_argument("--use_8bit_adam", action="store_true",
                        help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
    parser.add_argument("--use_lion_optimizer", action="store_true",
                        help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
    parser.add_argument("--mem_eff_attn", action="store_true",
                        help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
    parser.add_argument("--xformers", action="store_true",
                        help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
    parser.add_argument("--vae", type=str, default=None,
                        help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")

    # ----- optimization schedule -----
    parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
    parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
    parser.add_argument("--max_train_epochs", type=int, default=None,
                        help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
    parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
                        help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
    parser.add_argument("--persistent_data_loader_workers", action="store_true",
                        help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
    parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
    # NOTE: fixed typo in the help text ("grandient" -> "gradient")
    parser.add_argument("--gradient_checkpointing", action="store_true",
                        help="enable gradient checkpointing / gradient checkpointingを有効にする")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
    parser.add_argument("--mixed_precision", type=str, default="no",
                        choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
    parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
    parser.add_argument("--clip_skip", type=int, default=None,
                        help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")

    # ----- logging / scheduler / misc -----
    parser.add_argument("--logging_dir", type=str, default=None,
                        help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
    parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
    parser.add_argument("--lr_scheduler", type=str, default="constant",
                        help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
    parser.add_argument("--lr_warmup_steps", type=int, default=0,
                        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
    parser.add_argument("--noise_offset", type=float, default=None,
                        help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
    parser.add_argument("--lowram", action="store_true",
                        help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")

    if support_dreambooth:
        # DreamBooth training
        parser.add_argument("--prior_loss_weight", type=float, default=1.0,
                            help="loss weight for regularization images / 正則化画像のlossの重み")
1435
-
1436
-
1437
def verify_training_args(args: argparse.Namespace):
    """Print warnings for option combinations that are known to be inconsistent."""
    # v-parameterization only makes sense on top of a v2 model
    if not args.v2 and args.v_parameterization:
        print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
    # clip_skip targets the v1 text-encoder layout
    if args.clip_skip is not None and args.v2:
        print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
1442
-
1443
-
1444
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
    """Register dataset / augmentation options shared by the trainer scripts.

    The three flags select optional groups: DreamBooth regularization images,
    caption-metadata (json) datasets, and caption dropout.
    """
    # dataset common
    parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("--shuffle_caption", action="store_true",
                        help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
    parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
    # deliberately misspelled legacy alias, kept for backward compatibility
    parser.add_argument("--caption_extention", type=str, default=None,
                        help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
    parser.add_argument("--keep_tokens", type=int, default=None,
                        help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
    parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
    parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
    parser.add_argument("--face_crop_aug_range", type=str, default=None,
                        help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
    parser.add_argument("--random_crop", action="store_true",
                        help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
    parser.add_argument("--debug_dataset", action="store_true",
                        help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
    parser.add_argument("--resolution", type=str, default=None,
                        help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
    parser.add_argument("--cache_latents", action="store_true",
                        help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
    # aspect-ratio bucketing
    parser.add_argument("--enable_bucket", action="store_true",
                        help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
    parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
    parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
    parser.add_argument("--bucket_reso_steps", type=int, default=64,
                        help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
    parser.add_argument("--bucket_no_upscale", action="store_true",
                        help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")

    if support_caption_dropout:
        # Textual Inversion does not support caption dropout.
        # The "caption_" prefix avoids confusion with tensor Dropout;
        # every_n_epochs defaults to None to stay consistent with the others.
        parser.add_argument("--caption_dropout_rate", type=float, default=0,
                            help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
        parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
                            help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
        parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
                            help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")

    if support_dreambooth:
        # DreamBooth dataset
        parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")

    if support_caption:
        # caption dataset
        parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
        parser.add_argument("--dataset_repeats", type=int, default=1,
                            help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
1494
-
1495
-
1496
def add_sd_saving_arguments(parser: argparse.ArgumentParser):
    """Register options that control the on-disk format of saved models."""
    parser.add_argument(
        "--save_model_as", type=str, default=None,
        choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
        help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
    # only consulted when --save_model_as is not given
    parser.add_argument(
        "--use_safetensors", action='store_true',
        help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
1501
-
1502
- # endregion
1503
-
1504
- # region utils
1505
-
1506
-
1507
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
    """Normalize dataset-related CLI options in place.

    Handles the misspelled legacy caption option, validates latent-caching
    constraints, and parses the comma-separated resolution / crop-range
    strings into tuples.
    """
    # backward compatibility: the misspelled option wins when given
    if args.caption_extention is not None:
        args.caption_extension = args.caption_extention
        args.caption_extention = None

    if args.cache_latents:
        assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
        assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"

    # "512" -> (512, 512), "640,480" -> (640, 480)
    if args.resolution is not None:
        parsed = tuple(int(r) for r in args.resolution.split(','))
        if len(parsed) == 1:
            parsed = (parsed[0], parsed[0])
        args.resolution = parsed
        assert len(args.resolution) == 2, \
            f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"

    if args.face_crop_aug_range is None:
        args.face_crop_aug_range = None
    else:
        args.face_crop_aug_range = tuple(float(r) for r in args.face_crop_aug_range.split(','))
        assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \
            f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"

    if support_metadata and args.in_json is not None and (args.color_aug or args.random_crop):
        print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます")
1535
-
1536
-
1537
def load_tokenizer(args: argparse.Namespace):
    """Load the CLIP tokenizer matching the selected model lineage.

    v2 models keep their tokenizer in a subfolder of the reference repo;
    v1 uses the standalone tokenizer path.
    """
    print("prepare tokenizer")
    if args.v2:
        source, extra = V2_STABLE_DIFFUSION_PATH, {"subfolder": "tokenizer"}
    else:
        source, extra = TOKENIZER_PATH, {}
    tokenizer = CLIPTokenizer.from_pretrained(source, **extra)
    if args.max_token_length is not None:
        # informational only; the longer window is handled at encoding time
        print(f"update token length: {args.max_token_length}")
    return tokenizer
1546
-
1547
-
1548
def prepare_accelerator(args: argparse.Namespace):
    """Build the HF Accelerator plus a version-agnostic unwrap_model helper.

    Returns (accelerator, unwrap_model).
    """
    log_with = None if args.logging_dir is None else "tensorboard"
    logging_dir = None
    if log_with is not None:
        # timestamped subdirectory so repeated runs do not collide
        prefix = args.log_prefix if args.log_prefix is not None else ""
        logging_dir = args.logging_dir + "/" + prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=log_with,
        logging_dir=logging_dir,
    )

    # accelerate >= 0.15 added a second positional argument to unwrap_model;
    # probe once so the returned helper works on both sides of the change
    try:
        accelerator.unwrap_model("dummy", True)
        print("Using accelerator 0.15.0 or above.")
        accelerator_0_15 = True
    except TypeError:
        accelerator_0_15 = False

    def unwrap_model(model):
        if accelerator_0_15:
            return accelerator.unwrap_model(model, True)
        return accelerator.unwrap_model(model)

    return accelerator, unwrap_model
1574
-
1575
-
1576
def prepare_dtype(args: argparse.Namespace):
    """Map the CLI precision strings to torch dtypes.

    Returns (weight_dtype, save_dtype); save_dtype is None when
    --save_precision was not given.
    """
    half_types = {"fp16": torch.float16, "bf16": torch.bfloat16}
    weight_dtype = half_types.get(args.mixed_precision, torch.float32)
    save_dtype = dict(half_types, float=torch.float32).get(args.save_precision)
    return weight_dtype, save_dtype
1592
-
1593
-
1594
def load_target_model(args: argparse.Namespace, weight_dtype):
    """Load text encoder / VAE / U-Net from a checkpoint file or Diffusers source.

    A plain file path is treated as a StableDiffusion checkpoint; anything else
    (a directory or hub id) is loaded through the Diffusers pipeline.  Returns
    (text_encoder, vae, unet, load_stable_diffusion_format).
    """
    load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)  # determine SD or Diffusers
    if load_stable_diffusion_format:
        print("load StableDiffusion checkpoint")
        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
            args.v2, args.pretrained_model_name_or_path)
    else:
        print("load Diffusers pretrained models")
        pipe = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
        text_encoder, vae, unet = pipe.text_encoder, pipe.vae, pipe.unet
        del pipe  # only the three submodules are needed

    # optionally swap in a user-supplied VAE
    if args.vae is not None:
        vae = model_util.load_vae(args.vae, weight_dtype)
        print("additional VAE loaded")

    return text_encoder, vae, unet, load_stable_diffusion_format
1613
-
1614
-
1615
def patch_accelerator_for_fp16_training(accelerator):
    """Force the GradScaler to accept fp16 gradients during full-fp16 training.

    Wraps the scaler's _unscale_grads_ so the allow_fp16 flag is always True,
    regardless of what the caller passes.
    """
    original_unscale = accelerator.scaler._unscale_grads_

    def _forced_fp16_unscale(optimizer, inv_scale, found_inf, allow_fp16):
        # ignore the caller's allow_fp16 and force it on
        return original_unscale(optimizer, inv_scale, found_inf, True)

    accelerator.scaler._unscale_grads_ = _forced_fp16_unscale
1622
-
1623
-
1624
def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
    """Encode (possibly extended-length) token ids into text-encoder hidden states.

    Long prompts are tokenized as several model_max_length windows; this
    encodes every window and splices the chunks back into one sequence,
    dropping the duplicated <BOS>/<EOS> boundaries.

    Args:
        input_ids: token ids; the last axis is tokenizer.model_max_length for
            padded prompts (extended prompts stack windows on an extra axis).
        weight_dtype: if given, cast the result (needed for additional
            network training).
    """
    # with no_token_padding, the length is not max length, return result immediately
    if input_ids.size()[-1] != tokenizer.model_max_length:
        return text_encoder(input_ids)[0]

    b_size = input_ids.size()[0]
    input_ids = input_ids.reshape((-1, tokenizer.model_max_length))  # batch_size*3, 77

    if args.clip_skip is None:
        encoder_hidden_states = text_encoder(input_ids)[0]
    else:
        # take the output of the clip_skip-th layer from the end, then apply
        # the final layer norm as the text encoder itself would
        enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
        encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
        encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)

    # bs*3, 77, 768 or 1024
    encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

    if args.max_token_length is not None:
        if args.v2:
            # v2: splice the <BOS>...<EOS> <PAD>... windows back into one
            # sequence (whether this is the best handling of the pad states
            # is an open question, per the original author)
            states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>
            for i in range(1, args.max_token_length, tokenizer.model_max_length):
                chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]  # after <BOS>, before the last token
                if i > 0:
                    for j in range(len(chunk)):
                        # BUGFIX: compare token *ids*; the original compared
                        # against tokenizer.eos_token (a string), which never
                        # matches an id tensor, so this branch was dead.
                        # NOTE(review): indexing input_ids with the chunk-local
                        # j assumes batch rows line up with window rows —
                        # preserved from the original, confirm against callers.
                        if input_ids[j, 1] == tokenizer.eos_token_id:  # empty prompt: <BOS> <EOS> <PAD> ...
                            chunk[j, 0] = chunk[j, 1]  # copy the following <PAD> state over it
                states_list.append(chunk)  # between <BOS> and <EOS>
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # <EOS> or <PAD>
            encoder_hidden_states = torch.cat(states_list, dim=1)
        else:
            # v1: splice the repeated <BOS>...<EOS> windows back together
            states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>
            for i in range(1, args.max_token_length, tokenizer.model_max_length):
                states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2])  # after <BOS>, before <EOS>
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # <EOS>
            encoder_hidden_states = torch.cat(states_list, dim=1)

    if weight_dtype is not None:
        # this is required for additional network training
        encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

    return encoder_hidden_states
1668
-
1669
-
1670
def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
    """Return (model_name, checkpoint file name) for an epoch snapshot."""
    model_name = args.output_name if args.output_name is not None else DEFAULT_EPOCH_NAME
    extension = ".safetensors" if use_safetensors else ".ckpt"
    return model_name, EPOCH_FILE_NAME.format(model_name, epoch) + extension
1674
-
1675
-
1676
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
    """Run save_func on matching epochs and prune snapshots out of retention.

    Saves when epoch_no is a multiple of --save_every_n_epochs and is not the
    final epoch (the end-of-training save handles that one).  Returns whether
    a save happened.
    """
    saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
    if not saving:
        return saving
    os.makedirs(args.output_dir, exist_ok=True)
    save_func()
    if args.save_last_n_epochs is not None:
        # drop the snapshot that just fell out of the retention window
        remove_old_func(epoch_no - args.save_every_n_epochs * args.save_last_n_epochs)
    return saving
1686
-
1687
-
1688
def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
    """Save an epoch-end snapshot in the requested format and prune old ones.

    Builds format-specific save/remove callbacks and hands them to
    save_on_epoch_end; additionally saves the accelerator state when
    --save_state is set and a model save actually happened.
    """
    epoch_no = epoch + 1  # epochs are 1-based in file names
    model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)

    if save_stable_diffusion_format:
        # single-file StableDiffusion checkpoint (.ckpt / .safetensors)
        def save_sd():
            ckpt_file = os.path.join(args.output_dir, ckpt_name)
            print(f"saving checkpoint: {ckpt_file}")
            model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
                                                        src_path, epoch_no, global_step, save_dtype, vae)

        def remove_sd(old_epoch_no):
            _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
            old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
            if os.path.exists(old_ckpt_file):
                print(f"removing old checkpoint: {old_ckpt_file}")
                os.remove(old_ckpt_file)

        save_func = save_sd
        remove_old_func = remove_sd
    else:
        # Diffusers directory layout (one directory per epoch)
        def save_du():
            out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
            print(f"saving model: {out_dir}")
            os.makedirs(out_dir, exist_ok=True)
            model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
                                                 src_path, vae=vae, use_safetensors=use_safetensors)

        def remove_du(old_epoch_no):
            out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
            if os.path.exists(out_dir_old):
                print(f"removing old model: {out_dir_old}")
                shutil.rmtree(out_dir_old)

        save_func = save_du
        remove_old_func = remove_du

    saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
    if saving and args.save_state:
        save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
1728
-
1729
-
1730
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
    """Save the full accelerator training state and prune old state dirs."""
    print("saving state.")
    accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))

    # --save_last_n_epochs_state overrides --save_last_n_epochs for states
    last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
    if last_n_epochs is None:
        return
    stale_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
    state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, stale_epoch_no))
    if os.path.exists(state_dir_old):
        print(f"removing old state: {state_dir_old}")
        shutil.rmtree(state_dir_old)
1741
-
1742
-
1743
def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
    """Write the final trained model, as a single checkpoint or a Diffusers dir."""
    model_name = args.output_name if args.output_name is not None else DEFAULT_LAST_OUTPUT_NAME

    if save_stable_diffusion_format:
        os.makedirs(args.output_dir, exist_ok=True)
        ckpt_file = os.path.join(args.output_dir, model_name + (".safetensors" if use_safetensors else ".ckpt"))
        print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
        model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
                                                    src_path, epoch, global_step, save_dtype, vae)
        return

    # Diffusers directory layout
    out_dir = os.path.join(args.output_dir, model_name)
    os.makedirs(out_dir, exist_ok=True)
    print(f"save trained model as Diffusers to {out_dir}")
    model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
                                         src_path, vae=vae, use_safetensors=use_safetensors)
1762
-
1763
-
1764
def save_state_on_train_end(args: argparse.Namespace, accelerator):
    """Save the final accelerator training state under the last-state name."""
    print("saving last state.")
    os.makedirs(args.output_dir, exist_ok=True)
    model_name = args.output_name if args.output_name is not None else DEFAULT_LAST_OUTPUT_NAME
    accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
1769
-
1770
- # endregion
1771
-
1772
- # region 前処理用
1773
-
1774
-
1775
class ImageLoadingDataset(torch.utils.data.Dataset):
    """Dataset that lazily loads images from paths for preprocessing jobs.

    __getitem__ returns (tensor_image, path); unreadable files yield None so
    the collate step can filter them out.
    """

    def __init__(self, image_paths):
        self.images = image_paths

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        try:
            # convert to tensor here so the default DataLoader collate accepts it
            tensor_pil = transforms.functional.pil_to_tensor(Image.open(img_path).convert("RGB"))
        except Exception as e:
            # best effort: report the failure and let the caller skip the None
            print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
            return None
        return (tensor_pil, img_path)
1794
-
1795
-
1796
- # endregion