fffffchopin committed
Commit d9fc95a · verified · 1 Parent(s): 90435f5

Upload folder using huggingface_hub

Files changed (2)
  1. T2I_train.py +1153 -0
  2. train_instruct_pix2pix.py +1042 -0
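
The commit message indicates these files were pushed with `huggingface_hub`'s `upload_folder`. A minimal sketch of that kind of upload, assuming a placeholder repo id and a local folder holding the two scripts (neither is named on this page):

```python
from huggingface_hub import upload_folder

# Placeholder repo id and folder path; the actual repository behind this
# commit is not named on this page.
upload_folder(
    repo_id="username/repo-name",
    folder_path="./training_scripts",
    commit_message="Upload folder using huggingface_hub",
)
```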
T2I_train.py ADDED
@@ -0,0 +1,1153 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import logging
19
+ import math
20
+ import os
21
+ import random
22
+ import shutil
23
+ from contextlib import nullcontext
24
+ from pathlib import Path
25
+
26
+ import accelerate
27
+ import datasets
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+ import transformers
33
+ from accelerate import Accelerator
34
+ from accelerate.logging import get_logger
35
+ from accelerate.state import AcceleratorState
36
+ from accelerate.utils import ProjectConfiguration, set_seed
37
+ from datasets import load_dataset
38
+ from huggingface_hub import create_repo, upload_folder
39
+ from packaging import version
40
+ from torchvision import transforms
41
+ from tqdm.auto import tqdm
42
+ from transformers import CLIPTextModel, CLIPTokenizer
43
+ from transformers.utils import ContextManagers
44
+
45
+ import diffusers
46
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
47
+ from diffusers.optimization import get_scheduler
48
+ from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
49
+ from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
50
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
51
+ from diffusers.utils.import_utils import is_xformers_available
52
+ from diffusers.utils.torch_utils import is_compiled_module
53
+
54
+
55
+ if is_wandb_available():
56
+ import wandb
57
+
58
+
59
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
60
+ check_min_version("0.31.0.dev0")
61
+
62
+ logger = get_logger(__name__, log_level="INFO")
63
+
64
+ DATASET_NAME_MAPPING = {
65
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
66
+ }
67
+
68
+
69
+ def save_model_card(
70
+ args,
71
+ repo_id: str,
72
+ images: list = None,
73
+ repo_folder: str = None,
74
+ ):
75
+ img_str = ""
76
+ if len(images) > 0:
77
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
78
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
79
+ img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
80
+
81
+ model_description = f"""
82
+ # Text-to-image finetuning - {repo_id}
83
+
84
+ This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
85
+ {img_str}
86
+
87
+ ## Pipeline usage
88
+
89
+ You can use the pipeline like so:
90
+
91
+ ```python
92
+ from diffusers import DiffusionPipeline
93
+ import torch
94
+
95
+ pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
96
+ prompt = "{args.validation_prompts[0]}"
97
+ image = pipeline(prompt).images[0]
98
+ image.save("my_image.png")
99
+ ```
100
+
101
+ ## Training info
102
+
103
+ These are the key hyperparameters used during training:
104
+
105
+ * Epochs: {args.num_train_epochs}
106
+ * Learning rate: {args.learning_rate}
107
+ * Batch size: {args.train_batch_size}
108
+ * Gradient accumulation steps: {args.gradient_accumulation_steps}
109
+ * Image resolution: {args.resolution}
110
+ * Mixed-precision: {args.mixed_precision}
111
+
112
+ """
113
+ wandb_info = ""
114
+ if is_wandb_available():
115
+ wandb_run_url = None
116
+ if wandb.run is not None:
117
+ wandb_run_url = wandb.run.url
118
+
119
+ if wandb_run_url is not None:
120
+ wandb_info = f"""
121
+ More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
122
+ """
123
+
124
+ model_description += wandb_info
125
+
126
+ model_card = load_or_create_model_card(
127
+ repo_id_or_path=repo_id,
128
+ from_training=True,
129
+ license="creativeml-openrail-m",
130
+ base_model=args.pretrained_model_name_or_path,
131
+ model_description=model_description,
132
+ inference=True,
133
+ )
134
+
135
+ tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"]
136
+ model_card = populate_model_card(model_card, tags=tags)
137
+
138
+ model_card.save(os.path.join(repo_folder, "README.md"))
139
+
140
+
141
+ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
142
+ logger.info("Running validation... ")
143
+
144
+ pipeline = StableDiffusionPipeline.from_pretrained(
145
+ args.pretrained_model_name_or_path,
146
+ vae=accelerator.unwrap_model(vae),
147
+ text_encoder=accelerator.unwrap_model(text_encoder),
148
+ tokenizer=tokenizer,
149
+ unet=accelerator.unwrap_model(unet),
150
+ safety_checker=None,
151
+ revision=args.revision,
152
+ variant=args.variant,
153
+ torch_dtype=weight_dtype,
154
+ )
155
+ pipeline = pipeline.to(accelerator.device)
156
+ pipeline.set_progress_bar_config(disable=True)
157
+
158
+ if args.enable_xformers_memory_efficient_attention:
159
+ pipeline.enable_xformers_memory_efficient_attention()
160
+
161
+ if args.seed is None:
162
+ generator = None
163
+ else:
164
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
165
+
166
+ images = []
167
+ for i in range(len(args.validation_prompts)):
168
+ if torch.backends.mps.is_available():
169
+ autocast_ctx = nullcontext()
170
+ else:
171
+ autocast_ctx = torch.autocast(accelerator.device.type)
172
+
173
+ with autocast_ctx:
174
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
175
+
176
+ images.append(image)
177
+
178
+ for tracker in accelerator.trackers:
179
+ if tracker.name == "tensorboard":
180
+ np_images = np.stack([np.asarray(img) for img in images])
181
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
182
+ elif tracker.name == "wandb":
183
+ tracker.log(
184
+ {
185
+ "validation": [
186
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
187
+ for i, image in enumerate(images)
188
+ ]
189
+ }
190
+ )
191
+ else:
192
+ logger.warning(f"image logging not implemented for {tracker.name}")
193
+
194
+ del pipeline
195
+ torch.cuda.empty_cache()
196
+
197
+ return images
198
+
199
+
200
+ def parse_args():
201
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
202
+ parser.add_argument(
203
+ "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
204
+ )
205
+ parser.add_argument(
206
+ "--pretrained_model_name_or_path",
207
+ type=str,
208
+ default=None,
209
+ required=True,
210
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
211
+ )
212
+ parser.add_argument(
213
+ "--revision",
214
+ type=str,
215
+ default=None,
216
+ required=False,
217
+ help="Revision of pretrained model identifier from huggingface.co/models.",
218
+ )
219
+ parser.add_argument(
220
+ "--variant",
221
+ type=str,
222
+ default=None,
223
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
224
+ )
225
+ parser.add_argument(
226
+ "--dataset_name",
227
+ type=str,
228
+ default=None,
229
+ help=(
230
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
231
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
232
+ " or to a folder containing files that 🤗 Datasets can understand."
233
+ ),
234
+ )
235
+ parser.add_argument(
236
+ "--dataset_config_name",
237
+ type=str,
238
+ default=None,
239
+ help="The config of the Dataset, leave as None if there's only one config.",
240
+ )
241
+ parser.add_argument(
242
+ "--train_data_dir",
243
+ type=str,
244
+ default=None,
245
+ help=(
246
+ "A folder containing the training data. Folder contents must follow the structure described in"
247
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
248
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
249
+ ),
250
+ )
251
+ parser.add_argument(
252
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
253
+ )
254
+ parser.add_argument(
255
+ "--caption_column",
256
+ type=str,
257
+ default="text",
258
+ help="The column of the dataset containing a caption or a list of captions.",
259
+ )
260
+ parser.add_argument(
261
+ "--max_train_samples",
262
+ type=int,
263
+ default=None,
264
+ help=(
265
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
266
+ "value if set."
267
+ ),
268
+ )
269
+ parser.add_argument(
270
+ "--validation_prompts",
271
+ type=str,
272
+ default=None,
273
+ nargs="+",
274
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
275
+ )
276
+ parser.add_argument(
277
+ "--output_dir",
278
+ type=str,
279
+ default="sd-model-finetuned",
280
+ help="The output directory where the model predictions and checkpoints will be written.",
281
+ )
282
+ parser.add_argument(
283
+ "--cache_dir",
284
+ type=str,
285
+ default=None,
286
+ help="The directory where the downloaded models and datasets will be stored.",
287
+ )
288
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
289
+ parser.add_argument(
290
+ "--resolution",
291
+ type=int,
292
+ default=512,
293
+ help=(
294
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
295
+ " resolution"
296
+ ),
297
+ )
298
+ parser.add_argument(
299
+ "--center_crop",
300
+ default=False,
301
+ action="store_true",
302
+ help=(
303
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
304
+ " cropped. The images will be resized to the resolution first before cropping."
305
+ ),
306
+ )
307
+ parser.add_argument(
308
+ "--random_flip",
309
+ action="store_true",
310
+ help="whether to randomly flip images horizontally",
311
+ )
312
+ parser.add_argument(
313
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
314
+ )
315
+ parser.add_argument("--num_train_epochs", type=int, default=100)
316
+ parser.add_argument(
317
+ "--max_train_steps",
318
+ type=int,
319
+ default=None,
320
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
321
+ )
322
+ parser.add_argument(
323
+ "--gradient_accumulation_steps",
324
+ type=int,
325
+ default=1,
326
+ help="Number of update steps to accumulate before performing a backward/update pass.",
327
+ )
328
+ parser.add_argument(
329
+ "--gradient_checkpointing",
330
+ action="store_true",
331
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
332
+ )
333
+ parser.add_argument(
334
+ "--learning_rate",
335
+ type=float,
336
+ default=1e-4,
337
+ help="Initial learning rate (after the potential warmup period) to use.",
338
+ )
339
+ parser.add_argument(
340
+ "--scale_lr",
341
+ action="store_true",
342
+ default=False,
343
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
344
+ )
345
+ parser.add_argument(
346
+ "--lr_scheduler",
347
+ type=str,
348
+ default="constant",
349
+ help=(
350
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
351
+ ' "constant", "constant_with_warmup"]'
352
+ ),
353
+ )
354
+ parser.add_argument(
355
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
356
+ )
357
+ parser.add_argument(
358
+ "--snr_gamma",
359
+ type=float,
360
+ default=None,
361
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
362
+ "More details here: https://arxiv.org/abs/2303.09556.",
363
+ )
364
+ parser.add_argument(
365
+ "--dream_training",
366
+ action="store_true",
367
+ help=(
+ "Use the DREAM training method, which makes training more efficient and accurate at the "
+ "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210"
+ ),
371
+ )
372
+ parser.add_argument(
373
+ "--dream_detail_preservation",
374
+ type=float,
375
+ default=1.0,
376
+ help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)",
377
+ )
378
+ parser.add_argument(
379
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
380
+ )
381
+ parser.add_argument(
382
+ "--allow_tf32",
383
+ action="store_true",
384
+ help=(
385
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
386
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
387
+ ),
388
+ )
389
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
390
+ parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.")
391
+ parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.")
392
+ parser.add_argument(
393
+ "--non_ema_revision",
394
+ type=str,
395
+ default=None,
396
+ required=False,
397
+ help=(
398
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
399
+ " remote repository specified with --pretrained_model_name_or_path."
400
+ ),
401
+ )
402
+ parser.add_argument(
403
+ "--dataloader_num_workers",
404
+ type=int,
405
+ default=0,
406
+ help=(
407
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
408
+ ),
409
+ )
410
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
411
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
412
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
413
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
414
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
415
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
416
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
417
+ parser.add_argument(
418
+ "--prediction_type",
419
+ type=str,
420
+ default=None,
421
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
422
+ )
423
+ parser.add_argument(
424
+ "--hub_model_id",
425
+ type=str,
426
+ default=None,
427
+ help="The name of the repository to keep in sync with the local `output_dir`.",
428
+ )
429
+ parser.add_argument(
430
+ "--logging_dir",
431
+ type=str,
432
+ default="logs",
433
+ help=(
434
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
435
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
436
+ ),
437
+ )
438
+ parser.add_argument(
439
+ "--mixed_precision",
440
+ type=str,
441
+ default=None,
442
+ choices=["no", "fp16", "bf16"],
443
+ help=(
444
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
445
+ " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
446
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
447
+ ),
448
+ )
449
+ parser.add_argument(
450
+ "--report_to",
451
+ type=str,
452
+ default="tensorboard",
453
+ help=(
454
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
455
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
456
+ ),
457
+ )
458
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
459
+ parser.add_argument(
460
+ "--checkpointing_steps",
461
+ type=int,
462
+ default=500,
463
+ help=(
464
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
465
+ " training using `--resume_from_checkpoint`."
466
+ ),
467
+ )
468
+ parser.add_argument(
469
+ "--checkpoints_total_limit",
470
+ type=int,
471
+ default=None,
472
+ help=("Max number of checkpoints to store."),
473
+ )
474
+ parser.add_argument(
475
+ "--resume_from_checkpoint",
476
+ type=str,
477
+ default=None,
478
+ help=(
479
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
480
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
481
+ ),
482
+ )
483
+ parser.add_argument(
484
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
485
+ )
486
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
487
+ parser.add_argument(
488
+ "--validation_epochs",
489
+ type=int,
490
+ default=5,
491
+ help="Run validation every X epochs.",
492
+ )
493
+ parser.add_argument(
494
+ "--tracker_project_name",
495
+ type=str,
496
+ default="text2image-fine-tune",
497
+ help=(
498
+ "The `project_name` argument passed to Accelerator.init_trackers. For"
+ " more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
500
+ ),
501
+ )
502
+
503
+ args = parser.parse_args()
504
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
505
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
506
+ args.local_rank = env_local_rank
507
+
508
+ # Sanity checks
509
+ if args.dataset_name is None and args.train_data_dir is None:
510
+ raise ValueError("Need either a dataset name or a training folder.")
511
+
512
+ # default to using the same revision for the non-ema model if not specified
513
+ if args.non_ema_revision is None:
514
+ args.non_ema_revision = args.revision
515
+
516
+ return args
517
+
518
+
519
+ def main():
520
+ args = parse_args()
521
+
522
+ if args.report_to == "wandb" and args.hub_token is not None:
523
+ raise ValueError(
524
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
525
+ " Please use `huggingface-cli login` to authenticate with the Hub."
526
+ )
527
+
528
+ if args.non_ema_revision is not None:
529
+ deprecate(
530
+ "non_ema_revision!=None",
531
+ "0.15.0",
532
+ message=(
533
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
534
+ " use `--variant=non_ema` instead."
535
+ ),
536
+ )
537
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
538
+
539
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
540
+
541
+ accelerator = Accelerator(
542
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
543
+ mixed_precision=args.mixed_precision,
544
+ log_with=args.report_to,
545
+ project_config=accelerator_project_config,
546
+ )
547
+
548
+ # Disable AMP for MPS.
549
+ if torch.backends.mps.is_available():
550
+ accelerator.native_amp = False
551
+
552
+ # Make one log on every process with the configuration for debugging.
553
+ logging.basicConfig(
554
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
555
+ datefmt="%m/%d/%Y %H:%M:%S",
556
+ level=logging.INFO,
557
+ )
558
+ logger.info(accelerator.state, main_process_only=False)
559
+ if accelerator.is_local_main_process:
560
+ datasets.utils.logging.set_verbosity_warning()
561
+ transformers.utils.logging.set_verbosity_warning()
562
+ diffusers.utils.logging.set_verbosity_info()
563
+ else:
564
+ datasets.utils.logging.set_verbosity_error()
565
+ transformers.utils.logging.set_verbosity_error()
566
+ diffusers.utils.logging.set_verbosity_error()
567
+
568
+ # If passed along, set the training seed now.
569
+ if args.seed is not None:
570
+ set_seed(args.seed)
571
+
572
+ # Handle the repository creation
573
+ if accelerator.is_main_process:
574
+ if args.output_dir is not None:
575
+ os.makedirs(args.output_dir, exist_ok=True)
576
+
577
+ if args.push_to_hub:
578
+ repo_id = create_repo(
579
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
580
+ ).repo_id
581
+
582
+ # Load scheduler, tokenizer and models.
583
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
584
+ tokenizer = CLIPTokenizer.from_pretrained(
585
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
586
+ )
587
+
588
+ def deepspeed_zero_init_disabled_context_manager():
589
+ """
590
+ returns either a context list that includes one that will disable zero.Init or an empty context list
591
+ """
592
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
593
+ if deepspeed_plugin is None:
594
+ return []
595
+
596
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
597
+
598
+ # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
599
+ # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
600
+ # will try to assign the same optimizer with the same weights to all models during
601
+ # `deepspeed.initialize`, which of course doesn't work.
602
+ #
603
+ # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
604
+ # frozen models from being partitioned during `zero.Init` which gets called during
605
+ # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
606
+ # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
607
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
608
+ text_encoder = CLIPTextModel.from_pretrained(
609
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
610
+ )
611
+ vae = AutoencoderKL.from_pretrained(
612
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
613
+ )
614
+
615
+ unet = UNet2DConditionModel.from_pretrained(
616
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
617
+ )
618
+
619
+ # Freeze vae and text_encoder and set unet to trainable
620
+ vae.requires_grad_(False)
621
+ text_encoder.requires_grad_(False)
622
+ unet.train()
623
+
624
+ # Create EMA for the unet.
625
+ if args.use_ema:
626
+ ema_unet = UNet2DConditionModel.from_pretrained(
627
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
628
+ )
629
+ ema_unet = EMAModel(
630
+ ema_unet.parameters(),
631
+ model_cls=UNet2DConditionModel,
632
+ model_config=ema_unet.config,
633
+ foreach=args.foreach_ema,
634
+ )
635
+
636
+ if args.enable_xformers_memory_efficient_attention:
637
+ if is_xformers_available():
638
+ import xformers
639
+
640
+ xformers_version = version.parse(xformers.__version__)
641
+ if xformers_version == version.parse("0.0.16"):
642
+ logger.warning(
643
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
644
+ )
645
+ unet.enable_xformers_memory_efficient_attention()
646
+ else:
647
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
648
+
649
+ # `accelerate` 0.16.0 will have better support for customized saving
650
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
651
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
652
+ def save_model_hook(models, weights, output_dir):
653
+ if accelerator.is_main_process:
654
+ if args.use_ema:
655
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
656
+
657
+ for i, model in enumerate(models):
658
+ model.save_pretrained(os.path.join(output_dir, "unet"))
659
+
660
+ # make sure to pop weight so that corresponding model is not saved again
661
+ weights.pop()
662
+
663
+ def load_model_hook(models, input_dir):
664
+ if args.use_ema:
665
+ load_model = EMAModel.from_pretrained(
666
+ os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema
667
+ )
668
+ ema_unet.load_state_dict(load_model.state_dict())
669
+ if args.offload_ema:
670
+ ema_unet.pin_memory()
671
+ else:
672
+ ema_unet.to(accelerator.device)
673
+ del load_model
674
+
675
+ for _ in range(len(models)):
676
+ # pop models so that they are not loaded again
677
+ model = models.pop()
678
+
679
+ # load diffusers style into model
680
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
681
+ model.register_to_config(**load_model.config)
682
+
683
+ model.load_state_dict(load_model.state_dict())
684
+ del load_model
685
+
686
+ accelerator.register_save_state_pre_hook(save_model_hook)
687
+ accelerator.register_load_state_pre_hook(load_model_hook)
688
+
689
+ if args.gradient_checkpointing:
690
+ unet.enable_gradient_checkpointing()
691
+
692
+ # Enable TF32 for faster training on Ampere GPUs,
693
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
694
+ if args.allow_tf32:
695
+ torch.backends.cuda.matmul.allow_tf32 = True
696
+
697
+ if args.scale_lr:
698
+ args.learning_rate = (
699
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
700
+ )
701
+
702
+ # Initialize the optimizer
703
+ if args.use_8bit_adam:
704
+ try:
705
+ import bitsandbytes as bnb
706
+ except ImportError:
707
+ raise ImportError(
708
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
709
+ )
710
+
711
+ optimizer_cls = bnb.optim.AdamW8bit
712
+ else:
713
+ optimizer_cls = torch.optim.AdamW
714
+
715
+ optimizer = optimizer_cls(
716
+ unet.parameters(),
717
+ lr=args.learning_rate,
718
+ betas=(args.adam_beta1, args.adam_beta2),
719
+ weight_decay=args.adam_weight_decay,
720
+ eps=args.adam_epsilon,
721
+ )
722
+
723
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
724
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
725
+
726
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
727
+ # download the dataset.
728
+ if args.dataset_name is not None:
729
+ # Downloading and loading a dataset from the hub.
730
+ dataset = load_dataset(
731
+ args.dataset_name,
732
+ args.dataset_config_name,
733
+ cache_dir=args.cache_dir,
734
+ data_dir=args.train_data_dir,
735
+ )
736
+ else:
737
+ data_files = {}
738
+ if args.train_data_dir is not None:
739
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
740
+ dataset = load_dataset(
741
+ "imagefolder",
742
+ data_files=data_files,
743
+ cache_dir=args.cache_dir,
744
+ )
745
+ # See more about loading custom images at
746
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
747
+
748
+ # Preprocessing the datasets.
749
+ # We need to tokenize inputs and targets.
750
+ column_names = dataset["train"].column_names
751
+
752
+ # 6. Get the column names for input/target.
753
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
754
+ if args.image_column is None:
755
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
756
+ else:
757
+ image_column = args.image_column
758
+ if image_column not in column_names:
759
+ raise ValueError(
760
+ f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
761
+ )
762
+ if args.caption_column is None:
763
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
764
+ else:
765
+ caption_column = args.caption_column
766
+ if caption_column not in column_names:
767
+ raise ValueError(
768
+ f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
769
+ )
770
+
771
+ # Preprocessing the datasets.
772
+ # We need to tokenize input captions and transform the images.
773
+ def tokenize_captions(examples, is_train=True):
774
+ captions = []
775
+ for caption in examples[caption_column]:
776
+ if isinstance(caption, str):
777
+ captions.append(caption)
778
+ elif isinstance(caption, (list, np.ndarray)):
779
+ # take a random caption if there are multiple
780
+ captions.append(random.choice(caption) if is_train else caption[0])
781
+ else:
782
+ raise ValueError(
783
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
784
+ )
785
+ inputs = tokenizer(
786
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
787
+ )
788
+ return inputs.input_ids
789
+
790
+ # Preprocessing the datasets.
791
+ train_transforms = transforms.Compose(
792
+ [
793
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
794
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
795
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
796
+ transforms.ToTensor(),
797
+ transforms.Normalize([0.5], [0.5]),
798
+ ]
799
+ )
800
+
801
+ def preprocess_train(examples):
802
+ images = [image.convert("RGB") for image in examples[image_column]]
803
+ examples["pixel_values"] = [train_transforms(image) for image in images]
804
+ examples["input_ids"] = tokenize_captions(examples)
805
+ return examples
806
+
807
+ with accelerator.main_process_first():
808
+ if args.max_train_samples is not None:
809
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
810
+ # Set the training transforms
811
+ train_dataset = dataset["train"].with_transform(preprocess_train)
812
+
813
+ def collate_fn(examples):
814
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
815
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
816
+ input_ids = torch.stack([example["input_ids"] for example in examples])
817
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
818
+
819
+ # DataLoaders creation:
820
+ train_dataloader = torch.utils.data.DataLoader(
821
+ train_dataset,
822
+ shuffle=True,
823
+ collate_fn=collate_fn,
824
+ batch_size=args.train_batch_size,
825
+ num_workers=args.dataloader_num_workers,
826
+ )
827
+
828
+ # Scheduler and math around the number of training steps.
829
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
830
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
831
+ if args.max_train_steps is None:
832
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
833
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
834
+ num_training_steps_for_scheduler = (
835
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
836
+ )
837
+ else:
838
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
839
+
840
+ lr_scheduler = get_scheduler(
841
+ args.lr_scheduler,
842
+ optimizer=optimizer,
843
+ num_warmup_steps=num_warmup_steps_for_scheduler,
844
+ num_training_steps=num_training_steps_for_scheduler,
845
+ )
846
+
847
+ # Prepare everything with our `accelerator`.
848
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
849
+ unet, optimizer, train_dataloader, lr_scheduler
850
+ )
851
+
852
+ if args.use_ema:
853
+ if args.offload_ema:
854
+ ema_unet.pin_memory()
855
+ else:
856
+ ema_unet.to(accelerator.device)
857
+
858
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
859
+ # as these weights are only used for inference, keeping weights in full precision is not required.
860
+ weight_dtype = torch.float32
861
+ if accelerator.mixed_precision == "fp16":
862
+ weight_dtype = torch.float16
863
+ args.mixed_precision = accelerator.mixed_precision
864
+ elif accelerator.mixed_precision == "bf16":
865
+ weight_dtype = torch.bfloat16
866
+ args.mixed_precision = accelerator.mixed_precision
867
+
868
+ # Move text_encode and vae to gpu and cast to weight_dtype
869
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
870
+ vae.to(accelerator.device, dtype=weight_dtype)
871
+
872
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
873
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
874
+ if args.max_train_steps is None:
875
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
876
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
877
+ logger.warning(
878
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
879
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
880
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
881
+ )
882
+ # Afterwards we recalculate our number of training epochs
883
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
884
+
885
+ # We need to initialize the trackers we use, and also store our configuration.
886
+ # The trackers initializes automatically on the main process.
887
+ if accelerator.is_main_process:
888
+ tracker_config = dict(vars(args))
889
+ tracker_config.pop("validation_prompts")
890
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
891
+
892
+ # Function for unwrapping if model was compiled with `torch.compile`.
893
+ def unwrap_model(model):
894
+ model = accelerator.unwrap_model(model)
895
+ model = model._orig_mod if is_compiled_module(model) else model
896
+ return model
897
+
898
+ # Train!
899
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
900
+
901
+ logger.info("***** Running training *****")
902
+ logger.info(f" Num examples = {len(train_dataset)}")
903
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
904
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
905
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
906
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
907
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
908
+ global_step = 0
909
+ first_epoch = 0
910
+
911
+ # Potentially load in the weights and states from a previous save
912
+ if args.resume_from_checkpoint:
913
+ if args.resume_from_checkpoint != "latest":
914
+ path = os.path.basename(args.resume_from_checkpoint)
915
+ else:
916
+ # Get the most recent checkpoint
917
+ dirs = os.listdir(args.output_dir)
918
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
919
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
920
+ path = dirs[-1] if len(dirs) > 0 else None
921
+
922
+ if path is None:
923
+ accelerator.print(
924
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
925
+ )
926
+ args.resume_from_checkpoint = None
927
+ initial_global_step = 0
928
+ else:
929
+ accelerator.print(f"Resuming from checkpoint {path}")
930
+ accelerator.load_state(os.path.join(args.output_dir, path))
931
+ global_step = int(path.split("-")[1])
932
+
933
+ initial_global_step = global_step
934
+ first_epoch = global_step // num_update_steps_per_epoch
935
+
936
+ else:
937
+ initial_global_step = 0
938
+
939
+ progress_bar = tqdm(
940
+ range(0, args.max_train_steps),
941
+ initial=initial_global_step,
942
+ desc="Steps",
943
+ # Only show the progress bar once on each machine.
944
+ disable=not accelerator.is_local_main_process,
945
+ )
946
+
947
+ for epoch in range(first_epoch, args.num_train_epochs):
948
+ train_loss = 0.0
949
+ for step, batch in enumerate(train_dataloader):
950
+ with accelerator.accumulate(unet):
951
+ # Convert images to latent space
952
+ latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
953
+ latents = latents * vae.config.scaling_factor
954
+
955
+ # Sample noise that we'll add to the latents
956
+ noise = torch.randn_like(latents)
957
+ if args.noise_offset:
958
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
959
+ noise += args.noise_offset * torch.randn(
960
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
961
+ )
962
+ if args.input_perturbation:
963
+ new_noise = noise + args.input_perturbation * torch.randn_like(noise)
964
+ bsz = latents.shape[0]
965
+ # Sample a random timestep for each image
966
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
967
+ timesteps = timesteps.long()
968
+
969
+ # Add noise to the latents according to the noise magnitude at each timestep
970
+ # (this is the forward diffusion process)
971
+ if args.input_perturbation:
972
+ noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
973
+ else:
974
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
975
+
976
+ # Get the text embedding for conditioning
977
+ encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
978
+
979
+ # Get the target for loss depending on the prediction type
980
+ if args.prediction_type is not None:
981
+ # set prediction_type of scheduler if defined
982
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
983
+
984
+ if noise_scheduler.config.prediction_type == "epsilon":
985
+ target = noise
986
+ elif noise_scheduler.config.prediction_type == "v_prediction":
987
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
988
+ else:
989
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
990
+
991
+ if args.dream_training:
992
+ noisy_latents, target = compute_dream_and_update_latents(
993
+ unet,
994
+ noise_scheduler,
995
+ timesteps,
996
+ noise,
997
+ noisy_latents,
998
+ target,
999
+ encoder_hidden_states,
1000
+ args.dream_detail_preservation,
1001
+ )
1002
+
1003
+ # Predict the noise residual and compute loss
1004
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
1005
+
1006
+ if args.snr_gamma is None:
1007
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1008
+ else:
1009
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1010
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1011
+ # This is discussed in Section 4.2 of the same paper.
1012
+ snr = compute_snr(noise_scheduler, timesteps)
1013
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
1014
+ dim=1
1015
+ )[0]
1016
+ if noise_scheduler.config.prediction_type == "epsilon":
1017
+ mse_loss_weights = mse_loss_weights / snr
1018
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1019
+ mse_loss_weights = mse_loss_weights / (snr + 1)
1020
+
1021
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1022
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1023
+ loss = loss.mean()
1024
+
1025
+ # Gather the losses across all processes for logging (if we use distributed training).
1026
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1027
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
1028
+
1029
+ # Backpropagate
1030
+ accelerator.backward(loss)
1031
+ if accelerator.sync_gradients:
1032
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
1033
+ optimizer.step()
1034
+ lr_scheduler.step()
1035
+ optimizer.zero_grad()
1036
+
1037
+ # Checks if the accelerator has performed an optimization step behind the scenes
1038
+ if accelerator.sync_gradients:
1039
+ if args.use_ema:
1040
+ if args.offload_ema:
1041
+ ema_unet.to(device="cuda", non_blocking=True)
1042
+ ema_unet.step(unet.parameters())
1043
+ if args.offload_ema:
1044
+ ema_unet.to(device="cpu", non_blocking=True)
1045
+ progress_bar.update(1)
1046
+ global_step += 1
1047
+ accelerator.log({"train_loss": train_loss}, step=global_step)
1048
+ train_loss = 0.0
1049
+
1050
+ if global_step % args.checkpointing_steps == 0:
1051
+ if accelerator.is_main_process:
1052
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1053
+ if args.checkpoints_total_limit is not None:
1054
+ checkpoints = os.listdir(args.output_dir)
1055
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1056
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1057
+
1058
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1059
+ if len(checkpoints) >= args.checkpoints_total_limit:
1060
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1061
+ removing_checkpoints = checkpoints[0:num_to_remove]
1062
+
1063
+ logger.info(
1064
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1065
+ )
1066
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1067
+
1068
+ for removing_checkpoint in removing_checkpoints:
1069
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1070
+ shutil.rmtree(removing_checkpoint)
1071
+
1072
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1073
+ accelerator.save_state(save_path)
1074
+ logger.info(f"Saved state to {save_path}")
1075
+
1076
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1077
+ progress_bar.set_postfix(**logs)
1078
+
1079
+ if global_step >= args.max_train_steps:
1080
+ break
1081
+
1082
+ if accelerator.is_main_process:
1083
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
1084
+ if args.use_ema:
1085
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
1086
+ ema_unet.store(unet.parameters())
1087
+ ema_unet.copy_to(unet.parameters())
1088
+ log_validation(
1089
+ vae,
1090
+ text_encoder,
1091
+ tokenizer,
1092
+ unet,
1093
+ args,
1094
+ accelerator,
1095
+ weight_dtype,
1096
+ global_step,
1097
+ )
1098
+ if args.use_ema:
1099
+ # Switch back to the original UNet parameters.
1100
+ ema_unet.restore(unet.parameters())
1101
+
1102
+ # Create the pipeline using the trained modules and save it.
1103
+ accelerator.wait_for_everyone()
1104
+ if accelerator.is_main_process:
1105
+ unet = unwrap_model(unet)
1106
+ if args.use_ema:
1107
+ ema_unet.copy_to(unet.parameters())
1108
+
1109
+ pipeline = StableDiffusionPipeline.from_pretrained(
1110
+ args.pretrained_model_name_or_path,
1111
+ text_encoder=text_encoder,
1112
+ vae=vae,
1113
+ unet=unet,
1114
+ revision=args.revision,
1115
+ variant=args.variant,
1116
+ )
1117
+ pipeline.save_pretrained(args.output_dir)
1118
+
1119
+ # Run a final round of inference.
1120
+ images = []
1121
+ if args.validation_prompts is not None:
1122
+ logger.info("Running inference for collecting generated images...")
1123
+ pipeline = pipeline.to(accelerator.device)
1124
+ pipeline.torch_dtype = weight_dtype
1125
+ pipeline.set_progress_bar_config(disable=True)
1126
+
1127
+ if args.enable_xformers_memory_efficient_attention:
1128
+ pipeline.enable_xformers_memory_efficient_attention()
1129
+
1130
+ if args.seed is None:
1131
+ generator = None
1132
+ else:
1133
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
1134
+
1135
+ for i in range(len(args.validation_prompts)):
1136
+ with torch.autocast("cuda"):
1137
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
1138
+ images.append(image)
1139
+
1140
+ if args.push_to_hub:
1141
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
1142
+ upload_folder(
1143
+ repo_id=repo_id,
1144
+ folder_path=args.output_dir,
1145
+ commit_message="End of training",
1146
+ ignore_patterns=["step_*", "epoch_*"],
1147
+ )
1148
+
1149
+ accelerator.end_training()
1150
+
1151
+
1152
+ if __name__ == "__main__":
1153
+ main()
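
Once `T2I_train.py` finishes, it saves a full `StableDiffusionPipeline` to `--output_dir` (default `sd-model-finetuned`), as in the model card snippet embedded above. A minimal inference sketch against that local directory, assuming training completed and a CUDA device is available (the prompt is only a placeholder):

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the pipeline saved by T2I_train.py (default --output_dir is "sd-model-finetuned").
pipeline = StableDiffusionPipeline.from_pretrained(
    "sd-model-finetuned", torch_dtype=torch.float16
).to("cuda")

# Generate a sample image; swap in a prompt that matches your fine-tuning data.
image = pipeline("a naruto-style character portrait", num_inference_steps=20).images[0]
image.save("sample.png")
```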
train_instruct_pix2pix.py ADDED
@@ -0,0 +1,1042 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """Script to fine-tune Stable Diffusion for InstructPix2Pix."""
18
+
19
+ import argparse
20
+ import logging
21
+ import math
22
+ import os
23
+ import shutil
24
+ from contextlib import nullcontext
25
+ from pathlib import Path
26
+
27
+ import accelerate
28
+ import datasets
29
+ import numpy as np
30
+ import PIL
31
+ import requests
32
+ import torch
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ import torch.utils.checkpoint
36
+ import transformers
37
+ from accelerate import Accelerator
38
+ from accelerate.logging import get_logger
39
+ from accelerate.utils import ProjectConfiguration, set_seed
40
+ from datasets import load_dataset
41
+ from huggingface_hub import create_repo, upload_folder
42
+ from packaging import version
43
+ from torchvision import transforms
44
+ from tqdm.auto import tqdm
45
+ from transformers import CLIPTextModel, CLIPTokenizer
46
+
47
+ import diffusers
48
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
49
+ from diffusers.optimization import get_scheduler
50
+ from diffusers.training_utils import EMAModel
51
+ from diffusers.utils import check_min_version, deprecate, is_wandb_available
52
+ from diffusers.utils.import_utils import is_xformers_available
53
+ from diffusers.utils.torch_utils import is_compiled_module
54
+
55
+
56
+ if is_wandb_available():
57
+ import wandb
58
+
59
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
60
+ check_min_version("0.31.0.dev0")
61
+
62
+ logger = get_logger(__name__, log_level="INFO")
63
+
64
+ DATASET_NAME_MAPPING = {
65
+ "fusing/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"),
66
+ }
67
+ WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
68
+
69
+
70
+ def log_validation(
71
+ pipeline,
72
+ args,
73
+ accelerator,
74
+ generator,
75
+ ):
76
+ logger.info(
77
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
78
+ f" {args.validation_prompt}."
79
+ )
80
+ pipeline = pipeline.to(accelerator.device)
81
+ pipeline.set_progress_bar_config(disable=True)
82
+
83
+ # run inference
84
+ original_image = download_image(args.val_image_url)
85
+ edited_images = []
86
+ if torch.backends.mps.is_available():
87
+ autocast_ctx = nullcontext()
88
+ else:
89
+ autocast_ctx = torch.autocast(accelerator.device.type)
90
+
91
+ with autocast_ctx:
92
+ for _ in range(args.num_validation_images):
93
+ edited_images.append(
94
+ pipeline(
95
+ args.validation_prompt,
96
+ image=original_image,
97
+ num_inference_steps=20,
98
+ image_guidance_scale=1.5,
99
+ guidance_scale=7,
100
+ generator=generator,
101
+ ).images[0]
102
+ )
103
+
104
+ for tracker in accelerator.trackers:
105
+ if tracker.name == "wandb":
106
+ wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
107
+ for edited_image in edited_images:
108
+ wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
109
+ tracker.log({"validation": wandb_table})
110
+
111
+
112
+ def parse_args():
113
+ parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
114
+ parser.add_argument(
115
+ "--pretrained_model_name_or_path",
116
+ type=str,
117
+ default="timbrooks/instruct-pix2pix",
118
+ required=False,
119
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
120
+ )
121
+ parser.add_argument(
122
+ "--revision",
123
+ type=str,
124
+ default=None,
125
+ required=False,
126
+ help="Revision of pretrained model identifier from huggingface.co/models.",
127
+ )
128
+ parser.add_argument(
129
+ "--variant",
130
+ type=str,
131
+ default=None,
132
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
133
+ )
134
+ parser.add_argument(
135
+ "--dataset_name",
136
+ type=str,
137
+ default="fusing/instructpix2pix-1000-samples",
138
+ help=(
139
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
140
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
141
+ " or to a folder containing files that 🤗 Datasets can understand."
142
+ ),
143
+ )
144
+ parser.add_argument(
145
+ "--dataset_config_name",
146
+ type=str,
147
+ default=None,
148
+ help="The config of the Dataset, leave as None if there's only one config.",
149
+ )
150
+ parser.add_argument(
151
+ "--train_data_dir",
152
+ type=str,
153
+ default=None,
154
+ help=(
155
+ "A folder containing the training data. Folder contents must follow the structure described in"
156
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
157
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
158
+ ),
159
+ )
160
+ parser.add_argument(
161
+ "--original_image_column",
162
+ type=str,
163
+ default="input_image",
164
+ help="The column of the dataset containing the original image on which edits were made.",
165
+ )
166
+ parser.add_argument(
167
+ "--edited_image_column",
168
+ type=str,
169
+ default="edited_image",
170
+ help="The column of the dataset containing the edited image.",
171
+ )
172
+ parser.add_argument(
173
+ "--edit_prompt_column",
174
+ type=str,
175
+ default="edit_prompt",
176
+ help="The column of the dataset containing the edit instruction.",
177
+ )
178
+ parser.add_argument(
179
+ "--val_image_url",
180
+ type=str,
181
+ default=None,
182
+ help="URL to the original image that you would like to edit (used during inference for debugging purposes).",
183
+ )
184
+ parser.add_argument(
185
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
186
+ )
187
+ parser.add_argument(
188
+ "--num_validation_images",
189
+ type=int,
190
+ default=4,
191
+ help="Number of images that should be generated during validation with `validation_prompt`.",
192
+ )
193
+ parser.add_argument(
194
+ "--validation_epochs",
195
+ type=int,
196
+ default=1,
197
+ help=(
198
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
199
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
200
+ ),
201
+ )
202
+ parser.add_argument(
203
+ "--max_train_samples",
204
+ type=int,
205
+ default=None,
206
+ help=(
207
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
208
+ "value if set."
209
+ ),
210
+ )
211
+ parser.add_argument(
212
+ "--output_dir",
213
+ type=str,
214
+ default="instruct-pix2pix-model",
215
+ help="The output directory where the model predictions and checkpoints will be written.",
216
+ )
217
+ parser.add_argument(
218
+ "--cache_dir",
219
+ type=str,
220
+ default=None,
221
+ help="The directory where the downloaded models and datasets will be stored.",
222
+ )
223
+ parser.add_argument("--seed", type=int, default=1, help="A seed for reproducible training.")
224
+ parser.add_argument(
225
+ "--resolution",
226
+ type=int,
227
+ default=256,
228
+ help=(
229
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
230
+ " resolution"
231
+ ),
232
+ )
233
+ parser.add_argument(
234
+ "--center_crop",
235
+ default=False,
236
+ action="store_true",
237
+ help=(
238
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
239
+ " cropped. The images will be resized to the resolution first before cropping."
240
+ ),
241
+ )
242
+ parser.add_argument(
243
+ "--random_flip",
244
+ action="store_true",
245
+ help="whether to randomly flip images horizontally",
246
+ )
247
+ parser.add_argument(
248
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
249
+ )
250
+ parser.add_argument("--num_train_epochs", type=int, default=100)
251
+ parser.add_argument(
252
+ "--max_train_steps",
253
+ type=int,
254
+ default=None,
255
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
256
+ )
257
+ parser.add_argument(
258
+ "--gradient_accumulation_steps",
259
+ type=int,
260
+ default=1,
261
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
262
+ )
263
+ parser.add_argument(
264
+ "--gradient_checkpointing",
265
+ action="store_true",
266
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
267
+ )
268
+ parser.add_argument(
269
+ "--learning_rate",
270
+ type=float,
271
+ default=1e-4,
272
+ help="Initial learning rate (after the potential warmup period) to use.",
273
+ )
274
+ parser.add_argument(
275
+ "--scale_lr",
276
+ action="store_true",
277
+ default=False,
278
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
279
+ )
280
+ parser.add_argument(
281
+ "--lr_scheduler",
282
+ type=str,
283
+ default="constant",
284
+ help=(
285
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
286
+ ' "constant", "constant_with_warmup"]'
287
+ ),
288
+ )
289
+ parser.add_argument(
290
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
291
+ )
292
+ parser.add_argument(
293
+ "--conditioning_dropout_prob",
294
+ type=float,
295
+ default=None,
296
+ help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
297
+ )
298
+ parser.add_argument(
299
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
300
+ )
301
+ parser.add_argument(
302
+ "--allow_tf32",
303
+ action="store_true",
304
+ help=(
305
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
306
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
307
+ ),
308
+ )
309
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
310
+ parser.add_argument(
311
+ "--non_ema_revision",
312
+ type=str,
313
+ default=None,
314
+ required=False,
315
+ help=(
316
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
317
+ " remote repository specified with --pretrained_model_name_or_path."
318
+ ),
319
+ )
320
+ parser.add_argument(
321
+ "--dataloader_num_workers",
322
+ type=int,
323
+ default=0,
324
+ help=(
325
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
326
+ ),
327
+ )
328
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
329
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
330
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
331
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
332
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
333
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
334
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
335
+ parser.add_argument(
336
+ "--hub_model_id",
337
+ type=str,
338
+ default=None,
339
+ help="The name of the repository to keep in sync with the local `output_dir`.",
340
+ )
341
+ parser.add_argument(
342
+ "--logging_dir",
343
+ type=str,
344
+ default="logs",
345
+ help=(
346
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
347
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
348
+ ),
349
+ )
350
+ parser.add_argument(
351
+ "--mixed_precision",
352
+ type=str,
353
+ default=None,
354
+ choices=["no", "fp16", "bf16"],
355
+ help=(
356
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
357
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
358
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
359
+ ),
360
+ )
361
+ parser.add_argument(
362
+ "--report_to",
363
+ type=str,
364
+ default="tensorboard",
365
+ help=(
366
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
367
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
368
+ ),
369
+ )
370
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
371
+ parser.add_argument(
372
+ "--checkpointing_steps",
373
+ type=int,
374
+ default=500,
375
+ help=(
376
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
377
+ " training using `--resume_from_checkpoint`."
378
+ ),
379
+ )
380
+ parser.add_argument(
381
+ "--checkpoints_total_limit",
382
+ type=int,
383
+ default=None,
384
+ help=("Max number of checkpoints to store."),
385
+ )
386
+ parser.add_argument(
387
+ "--resume_from_checkpoint",
388
+ type=str,
389
+ default=None,
390
+ help=(
391
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
392
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
393
+ ),
394
+ )
395
+ parser.add_argument(
396
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
397
+ )
398
+
399
+ args = parser.parse_args()
400
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
401
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
402
+ args.local_rank = env_local_rank
403
+
404
+ # Sanity checks
405
+ if args.dataset_name is None and args.train_data_dir is None:
406
+ raise ValueError("Need either a dataset name or a training folder.")
407
+
408
+ # default to using the same revision for the non-ema model if not specified
409
+ if args.non_ema_revision is None:
410
+ args.non_ema_revision = args.revision
411
+
412
+ return args
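+
+ # A minimal example invocation (a sketch only; the script filename, dataset, and hyperparameter
+ # values are placeholders to adapt to your own setup, and every flag used here is defined in
+ # `parse_args` above):
+ #
+ #   accelerate launch train_instruct_pix2pix.py \
+ #     --pretrained_model_name_or_path=timbrooks/instruct-pix2pix \
+ #     --dataset_name=fusing/instructpix2pix-1000-samples \
+ #     --resolution=256 --train_batch_size=4 --gradient_accumulation_steps=4 \
+ #     --learning_rate=5e-05 --max_train_steps=15000 \
+ #     --conditioning_dropout_prob=0.05 --checkpointing_steps=5000 \
+ #     --val_image_url="<url-to-an-example-image>" \
+ #     --validation_prompt="<an edit instruction>" \
+ #     --seed=42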
413
+
414
+
415
+ def convert_to_np(image, resolution):
416
+ image = image.convert("RGB").resize((resolution, resolution))
417
+ return np.array(image).transpose(2, 0, 1)
418
+
419
+
420
+ def download_image(url):
421
+ image = PIL.Image.open(requests.get(url, stream=True).raw)
422
+ image = PIL.ImageOps.exif_transpose(image)
423
+ image = image.convert("RGB")
424
+ return image
425
+
426
+
427
+ def main():
428
+ args = parse_args()
429
+ if args.report_to == "wandb" and args.hub_token is not None:
430
+ raise ValueError(
431
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
432
+ " Please use `huggingface-cli login` to authenticate with the Hub."
433
+ )
434
+
435
+ if args.non_ema_revision is not None:
436
+ deprecate(
437
+ "non_ema_revision!=None",
438
+ "0.15.0",
439
+ message=(
440
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
441
+ " use `--variant=non_ema` instead."
442
+ ),
443
+ )
444
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
445
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
446
+ accelerator = Accelerator(
447
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
448
+ mixed_precision=args.mixed_precision,
449
+ log_with=args.report_to,
450
+ project_config=accelerator_project_config,
451
+ )
452
+
453
+ # Disable AMP for MPS.
454
+ if torch.backends.mps.is_available():
455
+ accelerator.native_amp = False
456
+
457
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
458
+
459
+ # Make one log on every process with the configuration for debugging.
460
+ logging.basicConfig(
461
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
462
+ datefmt="%m/%d/%Y %H:%M:%S",
463
+ level=logging.INFO,
464
+ )
465
+ logger.info(accelerator.state, main_process_only=False)
466
+ if accelerator.is_local_main_process:
467
+ datasets.utils.logging.set_verbosity_warning()
468
+ transformers.utils.logging.set_verbosity_warning()
469
+ diffusers.utils.logging.set_verbosity_info()
470
+ else:
471
+ datasets.utils.logging.set_verbosity_error()
472
+ transformers.utils.logging.set_verbosity_error()
473
+ diffusers.utils.logging.set_verbosity_error()
474
+
475
+ # If passed along, set the training seed now.
476
+ if args.seed is not None:
477
+ set_seed(args.seed)
478
+
479
+ # Handle the repository creation
480
+ if accelerator.is_main_process:
481
+ if args.output_dir is not None:
482
+ os.makedirs(args.output_dir, exist_ok=True)
483
+
484
+ if args.push_to_hub:
485
+ repo_id = create_repo(
486
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
487
+ ).repo_id
488
+
489
+ # Load scheduler, tokenizer and models.
490
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
491
+ tokenizer = CLIPTokenizer.from_pretrained(
492
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
493
+ )
494
+ text_encoder = CLIPTextModel.from_pretrained(
495
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
496
+ )
497
+ vae = AutoencoderKL.from_pretrained(
498
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
499
+ )
500
+ unet = UNet2DConditionModel.from_pretrained(
501
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
502
+ )
503
+
504
+ # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
505
+ # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
506
+ # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
507
+ # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
508
+ # initialized to zero.
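+ # Zero-initializing the extra input channels means the expanded UNet initially behaves exactly
+ # like the pretrained text-to-image UNet (the image conditioning contributes nothing at step 0)
+ # and only learns to use the conditioning image during fine-tuning.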
509
+ logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
510
+ in_channels = 8
511
+ out_channels = unet.conv_in.out_channels
512
+ unet.register_to_config(in_channels=in_channels)
513
+
514
+ with torch.no_grad():
515
+ new_conv_in = nn.Conv2d(
516
+ in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
517
+ )
518
+ new_conv_in.weight.zero_()
519
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
520
+ unet.conv_in = new_conv_in
521
+
522
+ # Freeze vae and text_encoder
523
+ vae.requires_grad_(False)
524
+ text_encoder.requires_grad_(False)
525
+
526
+ # Create EMA for the unet.
527
+ if args.use_ema:
528
+ ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
529
+
530
+ if args.enable_xformers_memory_efficient_attention:
531
+ if is_xformers_available():
532
+ import xformers
533
+
534
+ xformers_version = version.parse(xformers.__version__)
535
+ if xformers_version == version.parse("0.0.16"):
536
+ logger.warning(
537
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
538
+ )
539
+ unet.enable_xformers_memory_efficient_attention()
540
+ else:
541
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
542
+
543
+ def unwrap_model(model):
544
+ model = accelerator.unwrap_model(model)
545
+ model = model._orig_mod if is_compiled_module(model) else model
546
+ return model
547
+
548
+ # `accelerate` 0.16.0 will have better support for customized saving
549
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
550
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
551
+ def save_model_hook(models, weights, output_dir):
552
+ if accelerator.is_main_process:
553
+ if args.use_ema:
554
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
555
+
556
+ for i, model in enumerate(models):
557
+ model.save_pretrained(os.path.join(output_dir, "unet"))
558
+
559
+ # make sure to pop weight so that corresponding model is not saved again
560
+ if weights:
561
+ weights.pop()
562
+
563
+ def load_model_hook(models, input_dir):
564
+ if args.use_ema:
565
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
566
+ ema_unet.load_state_dict(load_model.state_dict())
567
+ ema_unet.to(accelerator.device)
568
+ del load_model
569
+
570
+ for i in range(len(models)):
571
+ # pop models so that they are not loaded again
572
+ model = models.pop()
573
+
574
+ # load diffusers style into model
575
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
576
+ model.register_to_config(**load_model.config)
577
+
578
+ model.load_state_dict(load_model.state_dict())
579
+ del load_model
580
+
581
+ accelerator.register_save_state_pre_hook(save_model_hook)
582
+ accelerator.register_load_state_pre_hook(load_model_hook)
583
+
584
+ if args.gradient_checkpointing:
585
+ unet.enable_gradient_checkpointing()
586
+
587
+ # Enable TF32 for faster training on Ampere GPUs,
588
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
589
+ if args.allow_tf32:
590
+ torch.backends.cuda.matmul.allow_tf32 = True
591
+
592
+ if args.scale_lr:
593
+ args.learning_rate = (
594
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
595
+ )
596
+
597
+ # Initialize the optimizer
598
+ if args.use_8bit_adam:
599
+ try:
600
+ import bitsandbytes as bnb
601
+ except ImportError:
602
+ raise ImportError(
603
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
604
+ )
605
+
606
+ optimizer_cls = bnb.optim.AdamW8bit
607
+ else:
608
+ optimizer_cls = torch.optim.AdamW
609
+
610
+ optimizer = optimizer_cls(
611
+ unet.parameters(),
612
+ lr=args.learning_rate,
613
+ betas=(args.adam_beta1, args.adam_beta2),
614
+ weight_decay=args.adam_weight_decay,
615
+ eps=args.adam_epsilon,
616
+ )
617
+
618
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
619
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
620
+
621
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
622
+ # download the dataset.
623
+ if args.dataset_name is not None:
624
+ # Downloading and loading a dataset from the hub.
625
+ dataset = load_dataset(
626
+ args.dataset_name,
627
+ args.dataset_config_name,
628
+ cache_dir=args.cache_dir,
629
+ )
630
+ else:
631
+ data_files = {}
632
+ if args.train_data_dir is not None:
633
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
634
+ dataset = load_dataset(
635
+ "imagefolder",
636
+ data_files=data_files,
637
+ cache_dir=args.cache_dir,
638
+ )
639
+ # See more about loading custom images at
640
+ # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder
641
+
642
+ # Preprocessing the datasets.
643
+ # We need to tokenize inputs and targets.
644
+ column_names = dataset["train"].column_names
645
+
646
+ # 6. Get the column names for input/target.
647
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
648
+ if args.original_image_column is None:
649
+ original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
650
+ else:
651
+ original_image_column = args.original_image_column
652
+ if original_image_column not in column_names:
653
+ raise ValueError(
654
+ f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}"
655
+ )
656
+ if args.edit_prompt_column is None:
657
+ edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
658
+ else:
659
+ edit_prompt_column = args.edit_prompt_column
660
+ if edit_prompt_column not in column_names:
661
+ raise ValueError(
662
+ f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}"
663
+ )
664
+ if args.edited_image_column is None:
665
+ edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]
666
+ else:
667
+ edited_image_column = args.edited_image_column
668
+ if edited_image_column not in column_names:
669
+ raise ValueError(
670
+ f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}"
671
+ )
672
+
673
+ # Preprocessing the datasets.
674
+ # We need to tokenize input captions and transform the images.
675
+ def tokenize_captions(captions):
676
+ inputs = tokenizer(
677
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
678
+ )
679
+ return inputs.input_ids
680
+
681
+ # Preprocessing the datasets.
682
+ train_transforms = transforms.Compose(
683
+ [
684
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
685
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
686
+ ]
687
+ )
688
+
689
+ def preprocess_images(examples):
690
+ original_images = np.concatenate(
691
+ [convert_to_np(image, args.resolution) for image in examples[original_image_column]]
692
+ )
693
+ edited_images = np.concatenate(
694
+ [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
695
+ )
696
+ # We need to ensure that the original and the edited images undergo the same
697
+ # augmentation transforms.
698
+ images = np.concatenate([original_images, edited_images])
699
+ images = torch.tensor(images)
700
+ images = 2 * (images / 255) - 1
701
+ return train_transforms(images)
702
+
703
+ def preprocess_train(examples):
704
+ # Preprocess images.
705
+ preprocessed_images = preprocess_images(examples)
706
+ # Since the original and edited images were concatenated before
707
+ # applying the transformations, we need to separate them and reshape
708
+ # them accordingly.
709
+ original_images, edited_images = preprocessed_images.chunk(2)
710
+ original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
711
+ edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
712
+
713
+ # Collate the preprocessed images into the `examples`.
714
+ examples["original_pixel_values"] = original_images
715
+ examples["edited_pixel_values"] = edited_images
716
+
717
+ # Preprocess the captions.
718
+ captions = list(examples[edit_prompt_column])
719
+ examples["input_ids"] = tokenize_captions(captions)
720
+ return examples
721
+
722
+ with accelerator.main_process_first():
723
+ if args.max_train_samples is not None:
724
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
725
+ # Set the training transforms
726
+ train_dataset = dataset["train"].with_transform(preprocess_train)
727
+
728
+ def collate_fn(examples):
729
+ original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
730
+ original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()
731
+ edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
732
+ edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
733
+ input_ids = torch.stack([example["input_ids"] for example in examples])
734
+ return {
735
+ "original_pixel_values": original_pixel_values,
736
+ "edited_pixel_values": edited_pixel_values,
737
+ "input_ids": input_ids,
738
+ }
739
+
740
+ # DataLoaders creation:
741
+ train_dataloader = torch.utils.data.DataLoader(
742
+ train_dataset,
743
+ shuffle=True,
744
+ collate_fn=collate_fn,
745
+ batch_size=args.train_batch_size,
746
+ num_workers=args.dataloader_num_workers,
747
+ )
748
+
749
+ # Scheduler and math around the number of training steps.
750
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
751
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
752
+ if args.max_train_steps is None:
753
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
754
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
755
+ num_training_steps_for_scheduler = (
756
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
757
+ )
758
+ else:
759
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
760
+
761
+ lr_scheduler = get_scheduler(
762
+ args.lr_scheduler,
763
+ optimizer=optimizer,
764
+ num_warmup_steps=num_warmup_steps_for_scheduler,
765
+ num_training_steps=num_training_steps_for_scheduler,
766
+ )
767
+
768
+ # Prepare everything with our `accelerator`.
769
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
770
+ unet, optimizer, train_dataloader, lr_scheduler
771
+ )
772
+
773
+ if args.use_ema:
774
+ ema_unet.to(accelerator.device)
775
+
776
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
777
+ # as these models are only used for inference, keeping weights in full precision is not required.
778
+ weight_dtype = torch.float32
779
+ if accelerator.mixed_precision == "fp16":
780
+ weight_dtype = torch.float16
781
+ elif accelerator.mixed_precision == "bf16":
782
+ weight_dtype = torch.bfloat16
783
+
784
+ # Move text_encoder and vae to GPU and cast to weight_dtype
785
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
786
+ vae.to(accelerator.device, dtype=weight_dtype)
787
+
788
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
789
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
790
+ if args.max_train_steps is None:
791
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
792
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
793
+ logger.warning(
794
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
795
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
796
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
797
+ )
798
+ # Afterwards we recalculate our number of training epochs
799
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
800
+
801
+ # We need to initialize the trackers we use, and also store our configuration.
802
+ # The trackers initialize automatically on the main process.
803
+ if accelerator.is_main_process:
804
+ accelerator.init_trackers("instruct-pix2pix", config=vars(args))
805
+
806
+ # Train!
807
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
808
+
809
+ logger.info("***** Running training *****")
810
+ logger.info(f" Num examples = {len(train_dataset)}")
811
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
812
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
813
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
814
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
815
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
816
+ global_step = 0
817
+ first_epoch = 0
818
+
819
+ # Potentially load in the weights and states from a previous save
820
+ if args.resume_from_checkpoint:
821
+ if args.resume_from_checkpoint != "latest":
822
+ path = os.path.basename(args.resume_from_checkpoint)
823
+ else:
824
+ # Get the most recent checkpoint
825
+ dirs = os.listdir(args.output_dir)
826
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
827
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
828
+ path = dirs[-1] if len(dirs) > 0 else None
829
+
830
+ if path is None:
831
+ accelerator.print(
832
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
833
+ )
834
+ args.resume_from_checkpoint = None
835
+ else:
836
+ accelerator.print(f"Resuming from checkpoint {path}")
837
+ accelerator.load_state(os.path.join(args.output_dir, path))
838
+ global_step = int(path.split("-")[1])
839
+
840
+ resume_global_step = global_step * args.gradient_accumulation_steps
841
+ first_epoch = global_step // num_update_steps_per_epoch
842
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
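+ # `resume_step` counts dataloader batches (not optimizer updates) within `first_epoch`;
+ # it is used below to fast-forward the dataloader to where training left off.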
843
+
844
+ # Only show the progress bar once on each machine.
845
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
846
+ progress_bar.set_description("Steps")
847
+
848
+ for epoch in range(first_epoch, args.num_train_epochs):
849
+ unet.train()
850
+ train_loss = 0.0
851
+ for step, batch in enumerate(train_dataloader):
852
+ # Skip steps until we reach the resumed step
853
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
854
+ if step % args.gradient_accumulation_steps == 0:
855
+ progress_bar.update(1)
856
+ continue
857
+
858
+ with accelerator.accumulate(unet):
859
+ # We want to learn the denoising process w.r.t. the edited images, which
860
+ # are conditioned on the original image (which was edited) and the edit instruction.
861
+ # So, first, convert images to latent space.
862
+ latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample()
863
+ latents = latents * vae.config.scaling_factor
864
+
865
+ # Sample noise that we'll add to the latents
866
+ noise = torch.randn_like(latents)
867
+ bsz = latents.shape[0]
868
+ # Sample a random timestep for each image
869
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
870
+ timesteps = timesteps.long()
871
+
872
+ # Add noise to the latents according to the noise magnitude at each timestep
873
+ # (this is the forward diffusion process)
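+ # For a DDPM-style scheduler this computes:
+ #   noisy_latents = sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise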
874
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
875
+
876
+ # Get the text embedding for conditioning.
877
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
878
+
879
+ # Get the additional image embedding for conditioning.
880
+ # Instead of getting a diagonal Gaussian here, we simply take the mode.
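+ # Taking the mode (the mean of the diagonal Gaussian) keeps the image conditioning deterministic.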
881
+ original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode()
882
+
883
+ # Conditioning dropout to support classifier-free guidance during inference. For more details
884
+ # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
885
+ if args.conditioning_dropout_prob is not None:
886
+ random_p = torch.rand(bsz, device=latents.device, generator=generator)
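+ # With dropout probability p, `random_p` partitions the batch as follows:
+ #   [0, p)   -> drop only the text conditioning
+ #   [p, 2p)  -> drop both text and image conditioning
+ #   [2p, 3p) -> drop only the image conditioning
+ #   [3p, 1)  -> keep both (no dropout)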
887
+ # Sample masks for the edit prompts.
888
+ prompt_mask = random_p < 2 * args.conditioning_dropout_prob
889
+ prompt_mask = prompt_mask.reshape(bsz, 1, 1)
890
+ # Final text conditioning.
891
+ null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0]
892
+ encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
893
+
894
+ # Sample masks for the original images.
895
+ image_mask_dtype = original_image_embeds.dtype
896
+ image_mask = 1 - (
897
+ (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
898
+ * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
899
+ )
900
+ image_mask = image_mask.reshape(bsz, 1, 1, 1)
901
+ # Final image conditioning.
902
+ original_image_embeds = image_mask * original_image_embeds
903
+
904
+ # Concatenate the `original_image_embeds` with the `noisy_latents`.
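+ # This yields the 8-channel input expected by the modified conv_in above
+ # (4 noisy-latent channels + 4 image-conditioning channels).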
905
+ concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)
906
+
907
+ # Get the target for loss depending on the prediction type
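+ # "epsilon" trains the UNet to predict the added noise; "v_prediction" trains it to predict the
+ # velocity v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents.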
908
+ if noise_scheduler.config.prediction_type == "epsilon":
909
+ target = noise
910
+ elif noise_scheduler.config.prediction_type == "v_prediction":
911
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
912
+ else:
913
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
914
+
915
+ # Predict the noise residual and compute loss
916
+ model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
917
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
918
+
919
+ # Gather the losses across all processes for logging (if we use distributed training).
920
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
921
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
922
+
923
+ # Backpropagate
924
+ accelerator.backward(loss)
925
+ if accelerator.sync_gradients:
926
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
927
+ optimizer.step()
928
+ lr_scheduler.step()
929
+ optimizer.zero_grad()
930
+
931
+ # Checks if the accelerator has performed an optimization step behind the scenes
932
+ if accelerator.sync_gradients:
933
+ if args.use_ema:
934
+ ema_unet.step(unet.parameters())
935
+ progress_bar.update(1)
936
+ global_step += 1
937
+ accelerator.log({"train_loss": train_loss}, step=global_step)
938
+ train_loss = 0.0
939
+
940
+ if global_step % args.checkpointing_steps == 0:
941
+ if accelerator.is_main_process:
942
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
943
+ if args.checkpoints_total_limit is not None:
944
+ checkpoints = os.listdir(args.output_dir)
945
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
946
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
947
+
948
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
949
+ if len(checkpoints) >= args.checkpoints_total_limit:
950
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
951
+ removing_checkpoints = checkpoints[0:num_to_remove]
952
+
953
+ logger.info(
954
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
955
+ )
956
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
957
+
958
+ for removing_checkpoint in removing_checkpoints:
959
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
960
+ shutil.rmtree(removing_checkpoint)
961
+
962
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
963
+ accelerator.save_state(save_path)
964
+ logger.info(f"Saved state to {save_path}")
965
+
966
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
967
+ progress_bar.set_postfix(**logs)
968
+
969
+ if global_step >= args.max_train_steps:
970
+ break
971
+
972
+ if accelerator.is_main_process:
973
+ if (
974
+ (args.val_image_url is not None)
975
+ and (args.validation_prompt is not None)
976
+ and (epoch % args.validation_epochs == 0)
977
+ ):
978
+ if args.use_ema:
979
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
980
+ ema_unet.store(unet.parameters())
981
+ ema_unet.copy_to(unet.parameters())
982
+ # The models need unwrapping for compatibility in distributed training mode.
983
+ pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
984
+ args.pretrained_model_name_or_path,
985
+ unet=unwrap_model(unet),
986
+ text_encoder=unwrap_model(text_encoder),
987
+ vae=unwrap_model(vae),
988
+ revision=args.revision,
989
+ variant=args.variant,
990
+ torch_dtype=weight_dtype,
991
+ )
992
+
993
+ log_validation(
994
+ pipeline,
995
+ args,
996
+ accelerator,
997
+ generator,
998
+ )
999
+
1000
+ if args.use_ema:
1001
+ # Switch back to the original UNet parameters.
1002
+ ema_unet.restore(unet.parameters())
1003
+
1004
+ del pipeline
1005
+ torch.cuda.empty_cache()
1006
+
1007
+ # Create the pipeline using the trained modules and save it.
1008
+ accelerator.wait_for_everyone()
1009
+ if accelerator.is_main_process:
1010
+ if args.use_ema:
1011
+ ema_unet.copy_to(unet.parameters())
1012
+
1013
+ pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
1014
+ args.pretrained_model_name_or_path,
1015
+ text_encoder=unwrap_model(text_encoder),
1016
+ vae=unwrap_model(vae),
1017
+ unet=unwrap_model(unet),
1018
+ revision=args.revision,
1019
+ variant=args.variant,
1020
+ )
1021
+ pipeline.save_pretrained(args.output_dir)
1022
+
1023
+ if args.push_to_hub:
1024
+ upload_folder(
1025
+ repo_id=repo_id,
1026
+ folder_path=args.output_dir,
1027
+ commit_message="End of training",
1028
+ ignore_patterns=["step_*", "epoch_*"],
1029
+ )
1030
+
1031
+ if (args.val_image_url is not None) and (args.validation_prompt is not None):
1032
+ log_validation(
1033
+ pipeline,
1034
+ args,
1035
+ accelerator,
1036
+ generator,
1037
+ )
1038
+ accelerator.end_training()
1039
+
1040
+
1041
+ if __name__ == "__main__":
1042
+ main()