radhika-minion02 committed on
Commit
71a8cfb
1 Parent(s): 64ca889

Uploading remaining model files

convert_diffusers_to_original_stable_diffusion.py ADDED
@@ -0,0 +1,333 @@
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
+ # *Only* converts the UNet, VAE, and Text Encoder.
+ # Does not convert optimizer state or anything else.
+
+ import argparse
+ import os.path as osp
+ import re
+
+ import torch
+ from safetensors.torch import load_file, save_file
+
+
+ # =================#
+ # UNet Conversion #
+ # =================#
+
+ unet_conversion_map = [
+     # (stable-diffusion, HF Diffusers)
+     ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+     ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+     ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+     ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+     ("input_blocks.0.0.weight", "conv_in.weight"),
+     ("input_blocks.0.0.bias", "conv_in.bias"),
+     ("out.0.weight", "conv_norm_out.weight"),
+     ("out.0.bias", "conv_norm_out.bias"),
+     ("out.2.weight", "conv_out.weight"),
+     ("out.2.bias", "conv_out.bias"),
+ ]
+
+ unet_conversion_map_resnet = [
+     # (stable-diffusion, HF Diffusers)
+     ("in_layers.0", "norm1"),
+     ("in_layers.2", "conv1"),
+     ("out_layers.0", "norm2"),
+     ("out_layers.3", "conv2"),
+     ("emb_layers.1", "time_emb_proj"),
+     ("skip_connection", "conv_shortcut"),
+ ]
+
+ unet_conversion_map_layer = []
+ # hardcoded number of downblocks and resnets/attentions...
+ # would need smarter logic for other networks.
+ for i in range(4):
+     # loop over downblocks/upblocks
+
+     for j in range(2):
+         # loop over resnets/attentions for downblocks
+         hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+         sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+         unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+         if i < 3:
+             # no attention layers in down_blocks.3
+             hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+             sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+             unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+     for j in range(3):
+         # loop over resnets/attentions for upblocks
+         hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+         sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+         unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+         if i > 0:
+             # no attention layers in up_blocks.0
+             hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+             sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+             unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+     if i < 3:
+         # no downsample in down_blocks.3
+         hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+         sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+         unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+         # no upsample in up_blocks.3
+         hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+         sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+         unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ hf_mid_atn_prefix = "mid_block.attentions.0."
+ sd_mid_atn_prefix = "middle_block.1."
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+ for j in range(2):
+     hf_mid_res_prefix = f"mid_block.resnets.{j}."
+     sd_mid_res_prefix = f"middle_block.{2*j}."
+     unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+ def convert_unet_state_dict(unet_state_dict):
+     # buyer beware: this is a *brittle* function,
+     # and correct output requires that all of these pieces interact in
+     # the exact order in which I have arranged them.
+     mapping = {k: k for k in unet_state_dict.keys()}
+     for sd_name, hf_name in unet_conversion_map:
+         mapping[hf_name] = sd_name
+     for k, v in mapping.items():
+         if "resnets" in k:
+             for sd_part, hf_part in unet_conversion_map_resnet:
+                 v = v.replace(hf_part, sd_part)
+             mapping[k] = v
+     for k, v in mapping.items():
+         for sd_part, hf_part in unet_conversion_map_layer:
+             v = v.replace(hf_part, sd_part)
+         mapping[k] = v
+     new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+     return new_state_dict
+
+
+ # ================#
+ # VAE Conversion #
+ # ================#
+
+ vae_conversion_map = [
+     # (stable-diffusion, HF Diffusers)
+     ("nin_shortcut", "conv_shortcut"),
+     ("norm_out", "conv_norm_out"),
+     ("mid.attn_1.", "mid_block.attentions.0."),
+ ]
+
+ for i in range(4):
+     # down_blocks have two resnets
+     for j in range(2):
+         hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+         sd_down_prefix = f"encoder.down.{i}.block.{j}."
+         vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+     if i < 3:
+         hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+         sd_downsample_prefix = f"down.{i}.downsample."
+         vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+         hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+         sd_upsample_prefix = f"up.{3-i}.upsample."
+         vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+     # up_blocks have three resnets
+     # also, up blocks in hf are numbered in reverse from sd
+     for j in range(3):
+         hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+         sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+         vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+ # this part accounts for mid blocks in both the encoder and the decoder
+ for i in range(2):
+     hf_mid_res_prefix = f"mid_block.resnets.{i}."
+     sd_mid_res_prefix = f"mid.block_{i+1}."
+     vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+ vae_conversion_map_attn = [
+     # (stable-diffusion, HF Diffusers)
+     ("norm.", "group_norm."),
+     ("q.", "query."),
+     ("k.", "key."),
+     ("v.", "value."),
+     ("proj_out.", "proj_attn."),
+ ]
+
+
+ def reshape_weight_for_sd(w):
+     # convert HF linear weights to SD conv2d weights
+     return w.reshape(*w.shape, 1, 1)
+
+
+ def convert_vae_state_dict(vae_state_dict):
+     mapping = {k: k for k in vae_state_dict.keys()}
+     for k, v in mapping.items():
+         for sd_part, hf_part in vae_conversion_map:
+             v = v.replace(hf_part, sd_part)
+         mapping[k] = v
+     for k, v in mapping.items():
+         if "attentions" in k:
+             for sd_part, hf_part in vae_conversion_map_attn:
+                 v = v.replace(hf_part, sd_part)
+             mapping[k] = v
+     new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+     weights_to_convert = ["q", "k", "v", "proj_out"]
+     for k, v in new_state_dict.items():
+         for weight_name in weights_to_convert:
+             if f"mid.attn_1.{weight_name}.weight" in k:
+                 print(f"Reshaping {k} for SD format")
+                 new_state_dict[k] = reshape_weight_for_sd(v)
+     return new_state_dict
+
+
+ # =========================#
+ # Text Encoder Conversion #
+ # =========================#
+
+
+ textenc_conversion_lst = [
+     # (stable-diffusion, HF Diffusers)
+     ("resblocks.", "text_model.encoder.layers."),
+     ("ln_1", "layer_norm1"),
+     ("ln_2", "layer_norm2"),
+     (".c_fc.", ".fc1."),
+     (".c_proj.", ".fc2."),
+     (".attn", ".self_attn"),
+     ("ln_final.", "transformer.text_model.final_layer_norm."),
+     ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
+     ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
+ ]
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
+ textenc_pattern = re.compile("|".join(protected.keys()))
+
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
+ code2idx = {"q": 0, "k": 1, "v": 2}
+
+
+ def convert_text_enc_state_dict_v20(text_enc_dict):
+     new_state_dict = {}
+     capture_qkv_weight = {}
+     capture_qkv_bias = {}
+     for k, v in text_enc_dict.items():
+         if (
+             k.endswith(".self_attn.q_proj.weight")
+             or k.endswith(".self_attn.k_proj.weight")
+             or k.endswith(".self_attn.v_proj.weight")
+         ):
+             k_pre = k[: -len(".q_proj.weight")]
+             k_code = k[-len("q_proj.weight")]
+             if k_pre not in capture_qkv_weight:
+                 capture_qkv_weight[k_pre] = [None, None, None]
+             capture_qkv_weight[k_pre][code2idx[k_code]] = v
+             continue
+
+         if (
+             k.endswith(".self_attn.q_proj.bias")
+             or k.endswith(".self_attn.k_proj.bias")
+             or k.endswith(".self_attn.v_proj.bias")
+         ):
+             k_pre = k[: -len(".q_proj.bias")]
+             k_code = k[-len("q_proj.bias")]
+             if k_pre not in capture_qkv_bias:
+                 capture_qkv_bias[k_pre] = [None, None, None]
+             capture_qkv_bias[k_pre][code2idx[k_code]] = v
+             continue
+
+         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
+         new_state_dict[relabelled_key] = v
+
+     for k_pre, tensors in capture_qkv_weight.items():
+         if None in tensors:
+             raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+         new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+
+     for k_pre, tensors in capture_qkv_bias.items():
+         if None in tensors:
+             raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+         new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+
+     return new_state_dict
+
+
+ def convert_text_enc_state_dict(text_enc_dict):
+     return text_enc_dict
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+     parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+     parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+     parser.add_argument(
+         "--use_safetensors", action="store_true", help="Save weights using safetensors; the default is a .ckpt file."
+     )
+
+     args = parser.parse_args()
+
+     assert args.model_path is not None, "Must provide a model path!"
+
+     assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
+
+     # Path for safetensors
+     unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
+     vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
+     text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
+
+     # Load models from safetensors if they exist; otherwise fall back to the PyTorch .bin files
+     if osp.exists(unet_path):
+         unet_state_dict = load_file(unet_path, device="cpu")
+     else:
+         unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
+         unet_state_dict = torch.load(unet_path, map_location="cpu")
+
+     if osp.exists(vae_path):
+         vae_state_dict = load_file(vae_path, device="cpu")
+     else:
+         vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
+         vae_state_dict = torch.load(vae_path, map_location="cpu")
+
+     if osp.exists(text_enc_path):
+         text_enc_dict = load_file(text_enc_path, device="cpu")
+     else:
+         text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
+         text_enc_dict = torch.load(text_enc_path, map_location="cpu")
+
+     # Convert the UNet model
+     unet_state_dict = convert_unet_state_dict(unet_state_dict)
+     unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+     # Convert the VAE model
+     vae_state_dict = convert_vae_state_dict(vae_state_dict)
+     vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
+
+     # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
+     is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
+
+     if is_v20_model:
+         # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
+         text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
+         text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
+         text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
+     else:
+         text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
+         text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
+
+     # Put together new checkpoint
+     state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
+     if args.half:
+         state_dict = {k: v.half() for k, v in state_dict.items()}
+
+     if args.use_safetensors:
+         save_file(state_dict, args.checkpoint_path)
+     else:
+         state_dict = {"state_dict": state_dict}
+         torch.save(state_dict, args.checkpoint_path)
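
Depending on --use_safetensors, the script above writes either a flat safetensors file or a .ckpt whose weights are wrapped under a "state_dict" key. A minimal sketch of reading the result back, using model/model-001.ckpt from this commit; treating that file as output of this converter is an assumption:

    # Sketch: inspect a checkpoint produced by convert_diffusers_to_original_stable_diffusion.py.
    # The .ckpt branch wraps the weights in a "state_dict" key; the safetensors branch does not.
    import torch
    from safetensors.torch import load_file

    def read_checkpoint(path):
        if path.endswith(".safetensors"):
            return load_file(path, device="cpu")      # flat key -> tensor mapping
        ckpt = torch.load(path, map_location="cpu")
        return ckpt.get("state_dict", ckpt)           # unwrap if wrapped

    sd = read_checkpoint("model/model-001.ckpt")      # path from this commit (assumed format)
    # Keys should start with model.diffusion_model., first_stage_model., or cond_stage_model.
    print(len(sd), sorted(sd)[0])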
data/fighter_jets/02730-204967744.png ADDED
data/fighter_jets/02731-204967745.png ADDED
data/fighter_jets/02733-204967747.png ADDED
data/fighter_jets/02734-59881992.png ADDED
data/fighter_jets/02752-304265715.png ADDED
data/fighter_jets/02754-2097021238.png ADDED
data/fighter_jets/02759-393359557.png ADDED
data/fighter_jets/02802-1858981028.png ADDED
data/fighter_jets/02808-3242270246.png ADDED
data/fighter_jets/02913-2981163727.png ADDED
data/trains/00002-512754533-train.png ADDED
data/trains/00003-512754534-train.png ADDED
data/trains/00009-2770989505-train.png ADDED
data/trains/00110-4189543447-train.png ADDED
data/trains/00117-4189543454-train.png ADDED
data/trains/00203-3263883525-train.png ADDED
data/trains/00206-3263883528-train.png ADDED
data/trains/00237-1689927512-train.png ADDED
data/trains/00242-1689927517-train.png ADDED
data/trains/00376-3481227890-train.png ADDED
model/model-001.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9adc95eeecdcba071ffe15f815c537ed350f899780da842d66fc2eb4f9626dd1
+ size 2580250395
model/samples/0.png ADDED
model/samples/1.png ADDED
model/samples/2.png ADDED
model/samples/3.png ADDED
model/text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_name_or_path": "stabilityai/stable-diffusion-2",
+   "architectures": [
+     "CLIPTextModel"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "dropout": 0.0,
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_size": 1024,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 77,
+   "model_type": "clip_text_model",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 23,
+   "pad_token_id": 1,
+   "projection_dim": 512,
+   "torch_dtype": "float32",
+   "transformers_version": "4.33.2",
+   "vocab_size": 49408
+ }
model/text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6ee0a15adae7f5b5a8e337f3a6ed108ad70bcfe6c550973b407fc48818b52d9f
+ size 1361596304
model/unet/config.json ADDED
@@ -0,0 +1,72 @@
+ {
+   "_class_name": "UNet2DConditionModel",
+   "_diffusers_version": "0.22.0.dev0",
+   "_name_or_path": "stabilityai/stable-diffusion-2",
+   "act_fn": "silu",
+   "addition_embed_type": null,
+   "addition_embed_type_num_heads": 64,
+   "addition_time_embed_dim": null,
+   "attention_head_dim": [
+     5,
+     10,
+     20,
+     20
+   ],
+   "attention_type": "default",
+   "block_out_channels": [
+     320,
+     640,
+     1280,
+     1280
+   ],
+   "center_input_sample": false,
+   "class_embed_type": null,
+   "class_embeddings_concat": false,
+   "conv_in_kernel": 3,
+   "conv_out_kernel": 3,
+   "cross_attention_dim": 1024,
+   "cross_attention_norm": null,
+   "down_block_types": [
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "DownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "dropout": 0.0,
+   "dual_cross_attention": false,
+   "encoder_hid_dim": null,
+   "encoder_hid_dim_type": null,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "in_channels": 4,
+   "layers_per_block": 2,
+   "mid_block_only_cross_attention": null,
+   "mid_block_scale_factor": 1,
+   "mid_block_type": "UNetMidBlock2DCrossAttn",
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_attention_heads": null,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "out_channels": 4,
+   "projection_class_embeddings_input_dim": null,
+   "resnet_out_scale_factor": 1.0,
+   "resnet_skip_time_act": false,
+   "resnet_time_scale_shift": "default",
+   "sample_size": 96,
+   "time_cond_proj_dim": null,
+   "time_embedding_act_fn": null,
+   "time_embedding_dim": null,
+   "time_embedding_type": "positional",
+   "timestep_post_act": null,
+   "transformer_layers_per_block": 1,
+   "up_block_types": [
+     "UpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D"
+   ],
+   "upcast_attention": false,
+   "use_linear_projection": true
+ }
model/unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1191bdf19018935bdd5ce80abbb9bcaf799a4fac296e88d6f855bdf45ba1c9ae
+ size 3463726504
model/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
+ size 334643268
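
The unet/, vae/, and text_encoder/ folders above follow the standard diffusers layout described by the config.json files in this commit, so the components can be loaded individually. A minimal sketch, assuming a local clone under "model/" and that the remaining config files (e.g. the VAE config, uploaded in earlier commits) are present:

    # Sketch: load the uploaded components with diffusers/transformers.
    # The local path "model" is an assumption based on this commit's folder layout.
    from diffusers import AutoencoderKL, UNet2DConditionModel
    from transformers import CLIPTextModel

    unet = UNet2DConditionModel.from_pretrained("model", subfolder="unet")            # model/unet/config.json + safetensors
    vae = AutoencoderKL.from_pretrained("model", subfolder="vae")                     # assumes vae config from an earlier commit
    text_encoder = CLIPTextModel.from_pretrained("model", subfolder="text_encoder")   # model/text_encoder/config.json + safetensors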
train_dreambooth.py ADDED
@@ -0,0 +1,869 @@
+ import argparse
+ import hashlib
+ import itertools
+ import random
+ import json
+ import logging
+ import math
+ import os
+ from contextlib import nullcontext
+ from pathlib import Path
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from torch.utils.data import Dataset
+
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import set_seed
+ from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+ from diffusers.optimization import get_scheduler
+ from diffusers.utils.import_utils import is_xformers_available
+ from huggingface_hub import HfFolder, Repository, whoami
+ from PIL import Image
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+
+ torch.backends.cudnn.benchmark = True
+
+
+ logger = get_logger(__name__)
+
+
+ def parse_args(input_args=None):
+     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+     parser.add_argument(
+         "--pretrained_model_name_or_path",
+         type=str,
+         default=None,
+         required=True,
+         help="Path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--pretrained_vae_name_or_path",
+         type=str,
+         default=None,
+         help="Path to pretrained vae or vae identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--revision",
+         type=str,
+         default=None,
+         required=False,
+         help="Revision of pretrained model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--tokenizer_name",
+         type=str,
+         default=None,
+         help="Pretrained tokenizer name or path if not the same as model_name",
+     )
+     parser.add_argument(
+         "--instance_data_dir",
+         type=str,
+         default=None,
+         help="A folder containing the training data of instance images.",
+     )
+     parser.add_argument(
+         "--class_data_dir",
+         type=str,
+         default=None,
+         help="A folder containing the training data of class images.",
+     )
+     parser.add_argument(
+         "--instance_prompt",
+         type=str,
+         default=None,
+         help="The prompt with identifier specifying the instance",
+     )
+     parser.add_argument(
+         "--class_prompt",
+         type=str,
+         default=None,
+         help="The prompt to specify images in the same class as provided instance images.",
+     )
+     parser.add_argument(
+         "--save_sample_prompt",
+         type=str,
+         default=None,
+         help="The prompt used to generate sample outputs to save.",
+     )
+     parser.add_argument(
+         "--save_sample_negative_prompt",
+         type=str,
+         default=None,
+         help="The negative prompt used to generate sample outputs to save.",
+     )
+     parser.add_argument(
+         "--n_save_sample",
+         type=int,
+         default=4,
+         help="The number of samples to save.",
+     )
+     parser.add_argument(
+         "--save_guidance_scale",
+         type=float,
+         default=7.5,
+         help="CFG for save sample.",
+     )
+     parser.add_argument(
+         "--save_infer_steps",
+         type=int,
+         default=20,
+         help="The number of inference steps for save sample.",
+     )
+     parser.add_argument(
+         "--pad_tokens",
+         default=False,
+         action="store_true",
+         help="Flag to pad tokens to length 77.",
+     )
+     parser.add_argument(
+         "--with_prior_preservation",
+         default=False,
+         action="store_true",
+         help="Flag to add prior preservation loss.",
+     )
+     parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+     parser.add_argument(
+         "--num_class_images",
+         type=int,
+         default=100,
+         help=(
+             "Minimal class images for prior preservation loss. If there are not enough images, additional images will be"
+             " sampled with class_prompt."
+         ),
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="text-inversion-model",
+         help="The output directory where the model predictions and checkpoints will be written.",
+     )
+     parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+     parser.add_argument(
+         "--resolution",
+         type=int,
+         default=512,
+         help=(
+             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+             " resolution"
+         ),
+     )
+     parser.add_argument(
+         "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+     )
+     parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+     parser.add_argument(
+         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+     )
+     parser.add_argument(
+         "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+     )
+     parser.add_argument("--num_train_epochs", type=int, default=1)
+     parser.add_argument(
+         "--max_train_steps",
+         type=int,
+         default=None,
+         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+     )
+     parser.add_argument(
+         "--gradient_accumulation_steps",
+         type=int,
+         default=1,
+         help="Number of update steps to accumulate before performing a backward/update pass.",
+     )
+     parser.add_argument(
+         "--gradient_checkpointing",
+         action="store_true",
+         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=5e-6,
+         help="Initial learning rate (after the potential warmup period) to use.",
+     )
+     parser.add_argument(
+         "--scale_lr",
+         action="store_true",
+         default=False,
+         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+     )
+     parser.add_argument(
+         "--lr_scheduler",
+         type=str,
+         default="constant",
+         help=(
+             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+             ' "constant", "constant_with_warmup"]'
+         ),
+     )
+     parser.add_argument(
+         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+     )
+     parser.add_argument(
+         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+     )
+     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+     parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+     parser.add_argument(
+         "--hub_model_id",
+         type=str,
+         default=None,
+         help="The name of the repository to keep in sync with the local `output_dir`.",
+     )
+     parser.add_argument(
+         "--logging_dir",
+         type=str,
+         default="logs",
+         help=(
+             "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+             " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+         ),
+     )
+     parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
+     parser.add_argument("--save_interval", type=int, default=10_000, help="Save weights every N steps.")
+     parser.add_argument("--save_min_steps", type=int, default=0, help="Start saving weights after N steps.")
+     parser.add_argument(
+         "--mixed_precision",
+         type=str,
+         default=None,
+         choices=["no", "fp16", "bf16"],
+         help=(
+             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+         ),
+     )
+     parser.add_argument("--not_cache_latents", action="store_true", help="Do not precompute and cache latents from VAE.")
+     parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
+     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+     parser.add_argument(
+         "--concepts_list",
+         type=str,
+         default=None,
+         help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
+     )
+     parser.add_argument(
+         "--read_prompts_from_txts",
+         action="store_true",
+         help="Use prompt per image. Put prompts in the same directory as images, e.g. for image.png create image.png.txt.",
+     )
+
+     if input_args is not None:
+         args = parser.parse_args(input_args)
+     else:
+         args = parser.parse_args()
+
+     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+     if env_local_rank != -1 and env_local_rank != args.local_rank:
+         args.local_rank = env_local_rank
+
+     return args
+
+
+ class DreamBoothDataset(Dataset):
+     """
+     A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+     It pre-processes the images and tokenizes the prompts.
+     """
+
+     def __init__(
+         self,
+         concepts_list,
+         tokenizer,
+         with_prior_preservation=True,
+         size=512,
+         center_crop=False,
+         num_class_images=None,
+         pad_tokens=False,
+         hflip=False,
+         read_prompts_from_txts=False,
+     ):
+         self.size = size
+         self.center_crop = center_crop
+         self.tokenizer = tokenizer
+         self.with_prior_preservation = with_prior_preservation
+         self.pad_tokens = pad_tokens
+         self.read_prompts_from_txts = read_prompts_from_txts
+
+         self.instance_images_path = []
+         self.class_images_path = []
+
+         for concept in concepts_list:
+             inst_img_path = [
+                 (x, concept["instance_prompt"])
+                 for x in Path(concept["instance_data_dir"]).iterdir()
+                 if x.is_file() and not str(x).endswith(".txt")
+             ]
+             self.instance_images_path.extend(inst_img_path)
+
+             if with_prior_preservation:
+                 class_img_path = [(x, concept["class_prompt"]) for x in Path(concept["class_data_dir"]).iterdir() if x.is_file()]
+                 self.class_images_path.extend(class_img_path[:num_class_images])
+
+         random.shuffle(self.instance_images_path)
+         self.num_instance_images = len(self.instance_images_path)
+         self.num_class_images = len(self.class_images_path)
+         self._length = max(self.num_class_images, self.num_instance_images)
+
+         self.image_transforms = transforms.Compose(
+             [
+                 transforms.RandomHorizontalFlip(0.5 * hflip),
+                 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                 transforms.ToTensor(),
+                 transforms.Normalize([0.5], [0.5]),
+             ]
+         )
+
+     def __len__(self):
+         return self._length
+
+     def __getitem__(self, index):
+         example = {}
+         instance_path, instance_prompt = self.instance_images_path[index % self.num_instance_images]
+
+         if self.read_prompts_from_txts:
+             with open(str(instance_path) + ".txt") as f:
+                 instance_prompt = f.read().strip()
+
+         instance_image = Image.open(instance_path)
+         if not instance_image.mode == "RGB":
+             instance_image = instance_image.convert("RGB")
+
+         example["instance_images"] = self.image_transforms(instance_image)
+         example["instance_prompt_ids"] = self.tokenizer(
+             instance_prompt,
+             padding="max_length" if self.pad_tokens else "do_not_pad",
+             truncation=True,
+             max_length=self.tokenizer.model_max_length,
+         ).input_ids
+
+         if self.with_prior_preservation:
+             class_path, class_prompt = self.class_images_path[index % self.num_class_images]
+             class_image = Image.open(class_path)
+             if not class_image.mode == "RGB":
+                 class_image = class_image.convert("RGB")
+             example["class_images"] = self.image_transforms(class_image)
+             example["class_prompt_ids"] = self.tokenizer(
+                 class_prompt,
+                 padding="max_length" if self.pad_tokens else "do_not_pad",
+                 truncation=True,
+                 max_length=self.tokenizer.model_max_length,
+             ).input_ids
+
+         return example
+
+
+ class PromptDataset(Dataset):
+     "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+     def __init__(self, prompt, num_samples):
+         self.prompt = prompt
+         self.num_samples = num_samples
+
+     def __len__(self):
+         return self.num_samples
+
+     def __getitem__(self, index):
+         example = {}
+         example["prompt"] = self.prompt
+         example["index"] = index
+         return example
+
+
+ class LatentsDataset(Dataset):
+     def __init__(self, latents_cache, text_encoder_cache):
+         self.latents_cache = latents_cache
+         self.text_encoder_cache = text_encoder_cache
+
+     def __len__(self):
+         return len(self.latents_cache)
+
+     def __getitem__(self, index):
+         return self.latents_cache[index], self.text_encoder_cache[index]
+
+
+ class AverageMeter:
+     def __init__(self, name=None):
+         self.name = name
+         self.reset()
+
+     def reset(self):
+         self.sum = self.count = self.avg = 0
+
+     def update(self, val, n=1):
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
+
+
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+     if token is None:
+         token = HfFolder.get_token()
+     if organization is None:
+         username = whoami(token)["name"]
+         return f"{username}/{model_id}"
+     else:
+         return f"{organization}/{model_id}"
+
+
+ def main(args):
+     logging_dir = Path(args.output_dir, "0", args.logging_dir)
+
+     accelerator = Accelerator(
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         mixed_precision=args.mixed_precision,
+         log_with="tensorboard",
+         project_dir=logging_dir,
+     )
+
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+
+     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+     if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+         raise ValueError(
+             "Gradient accumulation is not supported when training the text encoder in distributed training. "
+             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+         )
+
+     if args.seed is not None:
+         set_seed(args.seed)
+
+     if args.concepts_list is None:
+         args.concepts_list = [
+             {
+                 "instance_prompt": args.instance_prompt,
+                 "class_prompt": args.class_prompt,
+                 "instance_data_dir": args.instance_data_dir,
+                 "class_data_dir": args.class_data_dir
+             }
+         ]
+     else:
+         with open(args.concepts_list, "r") as f:
+             args.concepts_list = json.load(f)
+
+     if args.with_prior_preservation:
+         pipeline = None
+         for concept in args.concepts_list:
+             class_images_dir = Path(concept["class_data_dir"])
+             class_images_dir.mkdir(parents=True, exist_ok=True)
+             cur_class_images = len(list(class_images_dir.iterdir()))
+
+             if cur_class_images < args.num_class_images:
+                 torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+                 if pipeline is None:
+                     pipeline = StableDiffusionPipeline.from_pretrained(
+                         args.pretrained_model_name_or_path,
+                         vae=AutoencoderKL.from_pretrained(
+                             args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
+                             subfolder=None if args.pretrained_vae_name_or_path else "vae",
+                             revision=None if args.pretrained_vae_name_or_path else args.revision,
+                             torch_dtype=torch_dtype
+                         ),
+                         torch_dtype=torch_dtype,
+                         safety_checker=None,
+                         revision=args.revision
+                     )
+                     pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+                     if is_xformers_available():
+                         pipeline.enable_xformers_memory_efficient_attention()
+                     pipeline.set_progress_bar_config(disable=True)
+                     pipeline.to(accelerator.device)
+
+                 num_new_images = args.num_class_images - cur_class_images
+                 logger.info(f"Number of class images to sample: {num_new_images}.")
+
+                 sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
+                 sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+                 sample_dataloader = accelerator.prepare(sample_dataloader)
+
+                 with torch.autocast("cuda"), torch.inference_mode():
+                     for example in tqdm(
+                         sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+                     ):
+                         images = pipeline(
+                             example["prompt"],
+                             num_inference_steps=args.save_infer_steps
+                         ).images
+
+                         for i, image in enumerate(images):
+                             hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                             image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+                             image.save(image_filename)
+
+         del pipeline
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     # Load the tokenizer
+     if args.tokenizer_name:
+         tokenizer = CLIPTokenizer.from_pretrained(
+             args.tokenizer_name,
+             revision=args.revision,
+         )
+     elif args.pretrained_model_name_or_path:
+         tokenizer = CLIPTokenizer.from_pretrained(
+             args.pretrained_model_name_or_path,
+             subfolder="tokenizer",
+             revision=args.revision,
+         )
+
+     # Load models and create wrapper for stable diffusion
+     text_encoder = CLIPTextModel.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="text_encoder",
+         revision=args.revision,
+     )
+     vae = AutoencoderKL.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="vae",
+         revision=args.revision,
+     )
+     unet = UNet2DConditionModel.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="unet",
+         revision=args.revision,
+         torch_dtype=torch.float32
+     )
+
+     vae.requires_grad_(False)
+     if not args.train_text_encoder:
+         text_encoder.requires_grad_(False)
+
+     if is_xformers_available():
+         vae.enable_xformers_memory_efficient_attention()
+         unet.enable_xformers_memory_efficient_attention()
+     else:
+         logger.warning("xformers is not available. Make sure it is installed correctly")
+
+     if args.gradient_checkpointing:
+         unet.enable_gradient_checkpointing()
+         if args.train_text_encoder:
+             text_encoder.gradient_checkpointing_enable()
+
+     if args.scale_lr:
+         args.learning_rate = (
+             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+         )
+
+     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+     if args.use_8bit_adam:
+         try:
+             import bitsandbytes as bnb
+         except ImportError:
+             raise ImportError(
+                 "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+             )
+
+         optimizer_class = bnb.optim.AdamW8bit
+     else:
+         optimizer_class = torch.optim.AdamW
+
+     params_to_optimize = (
+         itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+     )
+     optimizer = optimizer_class(
+         params_to_optimize,
+         lr=args.learning_rate,
+         betas=(args.adam_beta1, args.adam_beta2),
+         weight_decay=args.adam_weight_decay,
+         eps=args.adam_epsilon,
+     )
+
+     noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+     train_dataset = DreamBoothDataset(
+         concepts_list=args.concepts_list,
+         tokenizer=tokenizer,
+         with_prior_preservation=args.with_prior_preservation,
+         size=args.resolution,
+         center_crop=args.center_crop,
+         num_class_images=args.num_class_images,
+         pad_tokens=args.pad_tokens,
+         hflip=args.hflip,
+         read_prompts_from_txts=args.read_prompts_from_txts,
+     )
+
+     def collate_fn(examples):
+         input_ids = [example["instance_prompt_ids"] for example in examples]
+         pixel_values = [example["instance_images"] for example in examples]
+
+         # Concat class and instance examples for prior preservation.
+         # We do this to avoid doing two forward passes.
+         if args.with_prior_preservation:
+             input_ids += [example["class_prompt_ids"] for example in examples]
+             pixel_values += [example["class_images"] for example in examples]
+
+         pixel_values = torch.stack(pixel_values)
+         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+         input_ids = tokenizer.pad(
+             {"input_ids": input_ids},
+             padding=True,
+             return_tensors="pt",
+         ).input_ids
+
+         batch = {
+             "input_ids": input_ids,
+             "pixel_values": pixel_values,
+         }
+         return batch
+
+     train_dataloader = torch.utils.data.DataLoader(
+         train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True
+     )
+
+     weight_dtype = torch.float32
+     if args.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif args.mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+
+     # Move text_encoder and vae to the GPU.
+     # For mixed precision training we cast the text_encoder and vae weights to half-precision
+     # as these models are only used for inference, keeping weights in full precision is not required.
+     vae.to(accelerator.device, dtype=weight_dtype)
+     if not args.train_text_encoder:
+         text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+     if not args.not_cache_latents:
+         latents_cache = []
+         text_encoder_cache = []
+         for batch in tqdm(train_dataloader, desc="Caching latents"):
+             with torch.no_grad():
+                 batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
+                 batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
+                 latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+                 if args.train_text_encoder:
+                     text_encoder_cache.append(batch["input_ids"])
+                 else:
+                     text_encoder_cache.append(text_encoder(batch["input_ids"])[0])
+         train_dataset = LatentsDataset(latents_cache, text_encoder_cache)
+         train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)
+
+         del vae
+         if not args.train_text_encoder:
+             del text_encoder
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     # Scheduler and math around the number of training steps.
+     overrode_max_train_steps = False
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if args.max_train_steps is None:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+         overrode_max_train_steps = True
+
+     lr_scheduler = get_scheduler(
+         args.lr_scheduler,
+         optimizer=optimizer,
+         num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+     )
+
+     if args.train_text_encoder:
+         unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+             unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+         )
+     else:
+         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+             unet, optimizer, train_dataloader, lr_scheduler
+         )
+
+     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if overrode_max_train_steps:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+     # Afterwards we recalculate our number of training epochs
+     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+     # We need to initialize the trackers we use, and also store our configuration.
+     # The trackers initialize automatically on the main process.
+     if accelerator.is_main_process:
+         accelerator.init_trackers("dreambooth")
+
+     # Train!
+     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+     logger.info("***** Running training *****")
+     logger.info(f"  Num examples = {len(train_dataset)}")
+     logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+     logger.info(f"  Num Epochs = {args.num_train_epochs}")
+     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+     logger.info(f"  Total optimization steps = {args.max_train_steps}")
+
+     def save_weights(step):
+         # Create the pipeline using the trained modules and save it.
+         if accelerator.is_main_process:
+             if args.train_text_encoder:
+                 text_enc_model = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
+             else:
+                 text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
+             pipeline = StableDiffusionPipeline.from_pretrained(
+                 args.pretrained_model_name_or_path,
+                 unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
+                 text_encoder=text_enc_model,
+                 vae=AutoencoderKL.from_pretrained(
+                     args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
+                     subfolder=None if args.pretrained_vae_name_or_path else "vae",
+                     revision=None if args.pretrained_vae_name_or_path else args.revision,
+                 ),
+                 safety_checker=None,
+                 torch_dtype=torch.float16,
+                 revision=args.revision,
+             )
+             pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+             if is_xformers_available():
+                 pipeline.enable_xformers_memory_efficient_attention()
+             save_dir = os.path.join(args.output_dir, f"{step}")
+             pipeline.save_pretrained(save_dir)
+             with open(os.path.join(save_dir, "args.json"), "w") as f:
+                 json.dump(args.__dict__, f, indent=2)
+
+             if args.save_sample_prompt is not None:
+                 pipeline = pipeline.to(accelerator.device)
+                 g_cuda = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+                 pipeline.set_progress_bar_config(disable=True)
+                 sample_dir = os.path.join(save_dir, "samples")
+                 os.makedirs(sample_dir, exist_ok=True)
+                 with torch.autocast("cuda"), torch.inference_mode():
+                     for i in tqdm(range(args.n_save_sample), desc="Generating samples"):
+                         images = pipeline(
+                             args.save_sample_prompt,
+                             negative_prompt=args.save_sample_negative_prompt,
+                             guidance_scale=args.save_guidance_scale,
+                             num_inference_steps=args.save_infer_steps,
+                             generator=g_cuda
+                         ).images
+                         images[0].save(os.path.join(sample_dir, f"{i}.png"))
+                 del pipeline
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+             print(f"[*] Weights saved at {save_dir}")
+
+     # Only show the progress bar once on each machine.
+     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+     progress_bar.set_description("Steps")
+     global_step = 0
+     loss_avg = AverageMeter()
+     text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
+     for epoch in range(args.num_train_epochs):
+         unet.train()
+         if args.train_text_encoder:
+             text_encoder.train()
+         for step, batch in enumerate(train_dataloader):
+             with accelerator.accumulate(unet):
+                 # Convert images to latent space
+                 with torch.no_grad():
+                     if not args.not_cache_latents:
+                         latent_dist = batch[0][0]
+                     else:
+                         latent_dist = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist
+                     latents = latent_dist.sample() * 0.18215
+
+                 # Sample noise that we'll add to the latents
+                 noise = torch.randn_like(latents)
+                 bsz = latents.shape[0]
+                 # Sample a random timestep for each image
+                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                 timesteps = timesteps.long()
+
+                 # Add noise to the latents according to the noise magnitude at each timestep
+                 # (this is the forward diffusion process)
+                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                 # Get the text embedding for conditioning
+                 with text_enc_context:
+                     if not args.not_cache_latents:
+                         if args.train_text_encoder:
+                             encoder_hidden_states = text_encoder(batch[0][1])[0]
+                         else:
+                             encoder_hidden_states = batch[0][1]
+                     else:
+                         encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+                 # Predict the noise residual
+                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                 # Get the target for loss depending on the prediction type
+                 if noise_scheduler.config.prediction_type == "epsilon":
+                     target = noise
+                 elif noise_scheduler.config.prediction_type == "v_prediction":
+                     target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                 else:
+                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+                 if args.with_prior_preservation:
+                     # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                     model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                     target, target_prior = torch.chunk(target, 2, dim=0)
+
+                     # Compute instance loss
+                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                     # Compute prior loss
+                     prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+                     # Add the prior loss to the instance loss.
+                     loss = loss + args.prior_loss_weight * prior_loss
+                 else:
+                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                 accelerator.backward(loss)
+                 # if accelerator.sync_gradients:
+                 #     params_to_clip = (
+                 #         itertools.chain(unet.parameters(), text_encoder.parameters())
+                 #         if args.train_text_encoder
+                 #         else unet.parameters()
+                 #     )
+                 #     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad(set_to_none=True)
+                 loss_avg.update(loss.detach_(), bsz)
+
+             if not global_step % args.log_interval:
+                 logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0]}
+                 progress_bar.set_postfix(**logs)
+                 accelerator.log(logs, step=global_step)
+
+             if global_step > 0 and not global_step % args.save_interval and global_step >= args.save_min_steps:
+                 save_weights(global_step)
+
+             progress_bar.update(1)
+             global_step += 1
+
+             if global_step >= args.max_train_steps:
+                 break
+
+     accelerator.wait_for_everyone()
+
+     save_weights(global_step)
+
+     accelerator.end_training()
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
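
For reference, train_dreambooth.py builds its concepts_list either from --instance_prompt / --class_prompt or from a JSON file passed via --concepts_list. A minimal sketch of such a file using the data/ folders in this commit; the prompt strings and the class_images/ directories are assumed placeholders, not values recorded anywhere in the repo:

    # Sketch: write a concepts_list JSON for --concepts_list.
    # The keys mirror the dict built in main() when no JSON is given.
    import json

    concepts = [
        {
            "instance_prompt": "photo of sks fighter jet",    # placeholder instance prompt
            "class_prompt": "photo of a fighter jet",          # placeholder class prompt
            "instance_data_dir": "data/fighter_jets",           # uploaded in this commit
            "class_data_dir": "class_images/fighter_jets",      # assumed; created if missing with prior preservation
        },
        {
            "instance_prompt": "photo of sks train",
            "class_prompt": "photo of a train",
            "instance_data_dir": "data/trains",
            "class_data_dir": "class_images/trains",
        },
    ]

    with open("concepts_list.json", "w") as f:
        json.dump(concepts, f, indent=2)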