fffiloni committed on
Commit
a5f898b
1 Parent(s): a7b42a0

Create train_dreambooth_lora_sdxl.py

Files changed (1)
  1. train_dreambooth_lora_sdxl.py +1368 -0
train_dreambooth_lora_sdxl.py ADDED
@@ -0,0 +1,1368 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
+
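+ # Example launch command (illustrative sketch only; the model id, data paths and
+ # hyperparameter values below are placeholders, not values required by this script):
+ #
+ #   accelerate launch train_dreambooth_lora_sdxl.py \
+ #     --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
+ #     --instance_data_dir="./instance_images" \
+ #     --instance_prompt="a photo of sks dog" \
+ #     --output_dir="lora-dreambooth-model" \
+ #     --resolution=1024 \
+ #     --train_batch_size=1 \
+ #     --gradient_accumulation_steps=4 \
+ #     --learning_rate=1e-4 \
+ #     --lr_scheduler="constant" \
+ #     --lr_warmup_steps=0 \
+ #     --max_train_steps=500 \
+ #     --mixed_precision="fp16"
+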
16
+ import argparse
17
+ import gc
18
+ import hashlib
19
+ import itertools
20
+ import logging
21
+ import math
22
+ import os
23
+ import shutil
24
+ import warnings
25
+ from pathlib import Path
26
+ from typing import Dict
27
+
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+ import transformers
33
+ from accelerate import Accelerator
34
+ from accelerate.logging import get_logger
35
+ from accelerate.utils import ProjectConfiguration, set_seed
36
+ from huggingface_hub import create_repo, upload_folder
37
+ from packaging import version
38
+ from PIL import Image
39
+ from PIL.ImageOps import exif_transpose
40
+ from torch.utils.data import Dataset
41
+ from torchvision import transforms
42
+ from tqdm.auto import tqdm
43
+ from transformers import AutoTokenizer, PretrainedConfig
44
+
45
+ import diffusers
46
+ from diffusers import (
47
+ AutoencoderKL,
48
+ DDPMScheduler,
49
+ DPMSolverMultistepScheduler,
50
+ StableDiffusionXLPipeline,
51
+ UNet2DConditionModel,
52
+ )
53
+ from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
54
+ from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
55
+ from diffusers.optimization import get_scheduler
56
+ from diffusers.utils import check_min_version, is_wandb_available
57
+ from diffusers.utils.import_utils import is_xformers_available
58
+
59
+
60
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
61
+ check_min_version("0.21.0.dev0")
62
+
63
+ logger = get_logger(__name__)
64
+
65
+
66
+ def save_model_card(
67
+ repo_id: str, images=None, base_model: str = "", train_text_encoder=False, prompt: str = "", repo_folder=None, vae_path=None
68
+ ):
69
+ img_str = ""
70
+ for i, image in enumerate(images):
71
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
72
+ img_str += f"![img_{i}](./image_{i}.png)\n"
73
+
74
+ yaml = f"""
75
+ ---
76
+ license: openrail++
77
+ base_model: {base_model}
78
+ instance_prompt: {prompt}
79
+ tags:
80
+ - stable-diffusion-xl
81
+ - stable-diffusion-xl-diffusers
82
+ - text-to-image
83
+ - diffusers
84
+ - lora
85
+ inference: true
86
+ ---
87
+ """
88
+ model_card = f"""
89
+ # LoRA DreamBooth - {repo_id}
90
+
91
+ These are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images below. \n
92
+ {img_str}
93
+
94
+ LoRA for the text encoder was enabled: {train_text_encoder}.
95
+
96
+ Special VAE used for training: {vae_path}.
97
+ """
98
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
99
+ f.write(yaml + model_card)
100
+
101
+
102
+ def import_model_class_from_model_name_or_path(
103
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
104
+ ):
105
+ text_encoder_config = PretrainedConfig.from_pretrained(
106
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
107
+ )
108
+ model_class = text_encoder_config.architectures[0]
109
+
110
+ if model_class == "CLIPTextModel":
111
+ from transformers import CLIPTextModel
112
+
113
+ return CLIPTextModel
114
+ elif model_class == "CLIPTextModelWithProjection":
115
+ from transformers import CLIPTextModelWithProjection
116
+
117
+ return CLIPTextModelWithProjection
118
+ else:
119
+ raise ValueError(f"{model_class} is not supported.")
120
+
121
+
122
+ def parse_args(input_args=None):
123
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
124
+ parser.add_argument(
125
+ "--pretrained_model_name_or_path",
126
+ type=str,
127
+ default=None,
128
+ required=True,
129
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
130
+ )
131
+ parser.add_argument(
132
+ "--pretrained_vae_model_name_or_path",
133
+ type=str,
134
+ default=None,
135
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
136
+ )
137
+ parser.add_argument(
138
+ "--revision",
139
+ type=str,
140
+ default=None,
141
+ required=False,
142
+ help="Revision of pretrained model identifier from huggingface.co/models.",
143
+ )
144
+ parser.add_argument(
145
+ "--instance_data_dir",
146
+ type=str,
147
+ default=None,
148
+ required=True,
149
+ help="A folder containing the training data of instance images.",
150
+ )
151
+ parser.add_argument(
152
+ "--class_data_dir",
153
+ type=str,
154
+ default=None,
155
+ required=False,
156
+ help="A folder containing the training data of class images.",
157
+ )
158
+ parser.add_argument(
159
+ "--instance_prompt",
160
+ type=str,
161
+ default=None,
162
+ required=True,
163
+ help="The prompt with an identifier specifying the instance.",
164
+ )
165
+ parser.add_argument(
166
+ "--class_prompt",
167
+ type=str,
168
+ default=None,
169
+ help="The prompt to specify images in the same class as provided instance images.",
170
+ )
171
+ parser.add_argument(
172
+ "--validation_prompt",
173
+ type=str,
174
+ default=None,
175
+ help="A prompt that is used during validation to verify that the model is learning.",
176
+ )
177
+ parser.add_argument(
178
+ "--num_validation_images",
179
+ type=int,
180
+ default=4,
181
+ help="Number of images that should be generated during validation with `validation_prompt`.",
182
+ )
183
+ parser.add_argument(
184
+ "--validation_epochs",
185
+ type=int,
186
+ default=50,
187
+ help=(
188
+ "Run DreamBooth validation every X epochs. DreamBooth validation consists of generating"
189
+ " `args.num_validation_images` images with the prompt `args.validation_prompt`."
190
+ ),
191
+ )
192
+ parser.add_argument(
193
+ "--with_prior_preservation",
194
+ default=False,
195
+ action="store_true",
196
+ help="Flag to add prior preservation loss.",
197
+ )
198
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
199
+ parser.add_argument(
200
+ "--num_class_images",
201
+ type=int,
202
+ default=100,
203
+ help=(
204
+ "Minimal number of class images for prior preservation loss. If there are not enough images already present in"
205
+ " class_data_dir, additional images will be sampled with class_prompt."
206
+ ),
207
+ )
208
+ parser.add_argument(
209
+ "--output_dir",
210
+ type=str,
211
+ default="lora-dreambooth-model",
212
+ help="The output directory where the model predictions and checkpoints will be written.",
213
+ )
214
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
215
+ parser.add_argument(
216
+ "--resolution",
217
+ type=int,
218
+ default=1024,
219
+ help=(
220
+ "The resolution for input images. All images in the train/validation dataset will be resized to this"
221
+ " resolution"
222
+ ),
223
+ )
224
+ parser.add_argument(
225
+ "--crops_coords_top_left_h",
226
+ type=int,
227
+ default=0,
228
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
229
+ )
230
+ parser.add_argument(
231
+ "--crops_coords_top_left_w",
232
+ type=int,
233
+ default=0,
234
+ help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
235
+ )
236
+ parser.add_argument(
237
+ "--center_crop",
238
+ default=False,
239
+ action="store_true",
240
+ help=(
241
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
242
+ " cropped. The images will be resized to the resolution first before cropping."
243
+ ),
244
+ )
245
+ parser.add_argument(
246
+ "--train_text_encoder",
247
+ action="store_true",
248
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
249
+ )
250
+ parser.add_argument(
251
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
252
+ )
253
+ parser.add_argument(
254
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
255
+ )
256
+ parser.add_argument("--num_train_epochs", type=int, default=1)
257
+ parser.add_argument(
258
+ "--max_train_steps",
259
+ type=int,
260
+ default=None,
261
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
262
+ )
263
+ parser.add_argument(
264
+ "--checkpointing_steps",
265
+ type=int,
266
+ default=500,
267
+ help=(
268
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
269
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
270
+ " training using `--resume_from_checkpoint`."
271
+ ),
272
+ )
273
+ parser.add_argument(
274
+ "--checkpoints_total_limit",
275
+ type=int,
276
+ default=None,
277
+ help=("Max number of checkpoints to store."),
278
+ )
279
+ parser.add_argument(
280
+ "--resume_from_checkpoint",
281
+ type=str,
282
+ default=None,
283
+ help=(
284
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
285
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
286
+ ),
287
+ )
288
+ parser.add_argument(
289
+ "--gradient_accumulation_steps",
290
+ type=int,
291
+ default=1,
292
+ help="Number of update steps to accumulate before performing a backward/update pass.",
293
+ )
294
+ parser.add_argument(
295
+ "--gradient_checkpointing",
296
+ action="store_true",
297
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
298
+ )
299
+ parser.add_argument(
300
+ "--learning_rate",
301
+ type=float,
302
+ default=5e-4,
303
+ help="Initial learning rate (after the potential warmup period) to use.",
304
+ )
305
+ parser.add_argument(
306
+ "--scale_lr",
307
+ action="store_true",
308
+ default=False,
309
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
310
+ )
311
+ parser.add_argument(
312
+ "--lr_scheduler",
313
+ type=str,
314
+ default="constant",
315
+ help=(
316
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
317
+ ' "constant", "constant_with_warmup"]'
318
+ ),
319
+ )
320
+ parser.add_argument(
321
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
322
+ )
323
+ parser.add_argument(
324
+ "--lr_num_cycles",
325
+ type=int,
326
+ default=1,
327
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
328
+ )
329
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
330
+ parser.add_argument(
331
+ "--dataloader_num_workers",
332
+ type=int,
333
+ default=0,
334
+ help=(
335
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
336
+ ),
337
+ )
338
+ parser.add_argument(
339
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
340
+ )
341
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
342
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
343
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
344
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
345
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
346
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
347
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
348
+ parser.add_argument(
349
+ "--hub_model_id",
350
+ type=str,
351
+ default=None,
352
+ help="The name of the repository to keep in sync with the local `output_dir`.",
353
+ )
354
+ parser.add_argument(
355
+ "--logging_dir",
356
+ type=str,
357
+ default="logs",
358
+ help=(
359
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
360
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
361
+ ),
362
+ )
363
+ parser.add_argument(
364
+ "--allow_tf32",
365
+ action="store_true",
366
+ help=(
367
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
368
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
369
+ ),
370
+ )
371
+ parser.add_argument(
372
+ "--report_to",
373
+ type=str,
374
+ default="tensorboard",
375
+ help=(
376
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
377
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
378
+ ),
379
+ )
380
+ parser.add_argument(
381
+ "--mixed_precision",
382
+ type=str,
383
+ default=None,
384
+ choices=["no", "fp16", "bf16"],
385
+ help=(
386
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
387
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
388
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
389
+ ),
390
+ )
391
+ parser.add_argument(
392
+ "--prior_generation_precision",
393
+ type=str,
394
+ default=None,
395
+ choices=["no", "fp32", "fp16", "bf16"],
396
+ help=(
397
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
398
+ " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
399
+ ),
400
+ )
401
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
402
+ parser.add_argument(
403
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
404
+ )
405
+ parser.add_argument(
406
+ "--rank",
407
+ type=int,
408
+ default=4,
409
+ help=("The dimension of the LoRA update matrices."),
410
+ )
411
+
412
+ if input_args is not None:
413
+ args = parser.parse_args(input_args)
414
+ else:
415
+ args = parser.parse_args()
416
+
417
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
418
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
419
+ args.local_rank = env_local_rank
420
+
421
+ if args.with_prior_preservation:
422
+ if args.class_data_dir is None:
423
+ raise ValueError("You must specify a data directory for class images.")
424
+ if args.class_prompt is None:
425
+ raise ValueError("You must specify prompt for class images.")
426
+ else:
427
+ # logger is not available yet
428
+ if args.class_data_dir is not None:
429
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
430
+ if args.class_prompt is not None:
431
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
432
+
433
+ return args
434
+
435
+
436
+ class DreamBoothDataset(Dataset):
437
+ """
438
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
439
+ It pre-processes the images.
440
+ """
441
+
442
+ def __init__(
443
+ self,
444
+ instance_data_root,
445
+ class_data_root=None,
446
+ class_num=None,
447
+ size=1024,
448
+ center_crop=False,
449
+ ):
450
+ self.size = size
451
+ self.center_crop = center_crop
452
+
453
+ self.instance_data_root = Path(instance_data_root)
454
+ if not self.instance_data_root.exists():
455
+ raise ValueError("Instance images root doesn't exist.")
456
+
457
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
458
+ self.num_instance_images = len(self.instance_images_path)
459
+ self._length = self.num_instance_images
460
+
461
+ if class_data_root is not None:
462
+ self.class_data_root = Path(class_data_root)
463
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
464
+ self.class_images_path = list(self.class_data_root.iterdir())
465
+ if class_num is not None:
466
+ self.num_class_images = min(len(self.class_images_path), class_num)
467
+ else:
468
+ self.num_class_images = len(self.class_images_path)
469
+ self._length = max(self.num_class_images, self.num_instance_images)
470
+ else:
471
+ self.class_data_root = None
472
+
473
+ self.image_transforms = transforms.Compose(
474
+ [
475
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
476
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
477
+ transforms.ToTensor(),
478
+ transforms.Normalize([0.5], [0.5]),
479
+ ]
480
+ )
481
+
482
+ def __len__(self):
483
+ return self._length
484
+
485
+ def __getitem__(self, index):
486
+ example = {}
487
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
488
+ instance_image = exif_transpose(instance_image)
489
+
490
+ if not instance_image.mode == "RGB":
491
+ instance_image = instance_image.convert("RGB")
492
+ example["instance_images"] = self.image_transforms(instance_image)
493
+
494
+ if self.class_data_root:
495
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
496
+ class_image = exif_transpose(class_image)
497
+
498
+ if not class_image.mode == "RGB":
499
+ class_image = class_image.convert("RGB")
500
+ example["class_images"] = self.image_transforms(class_image)
501
+
502
+ return example
503
+
504
+
505
+ def collate_fn(examples, with_prior_preservation=False):
506
+ pixel_values = [example["instance_images"] for example in examples]
507
+
508
+ # Concat class and instance examples for prior preservation.
509
+ # We do this to avoid doing two forward passes.
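+ # Stacking the class images after the instance images lets a single UNet forward pass handle
+ # both; the prediction is split back into its two halves with torch.chunk when the loss is computed.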
510
+ if with_prior_preservation:
511
+ pixel_values += [example["class_images"] for example in examples]
512
+
513
+ pixel_values = torch.stack(pixel_values)
514
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
515
+
516
+ batch = {"pixel_values": pixel_values}
517
+ return batch
518
+
519
+
520
+ class PromptDataset(Dataset):
521
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
522
+
523
+ def __init__(self, prompt, num_samples):
524
+ self.prompt = prompt
525
+ self.num_samples = num_samples
526
+
527
+ def __len__(self):
528
+ return self.num_samples
529
+
530
+ def __getitem__(self, index):
531
+ example = {}
532
+ example["prompt"] = self.prompt
533
+ example["index"] = index
534
+ return example
535
+
536
+
537
+ def tokenize_prompt(tokenizer, prompt):
538
+ text_inputs = tokenizer(
539
+ prompt,
540
+ padding="max_length",
541
+ max_length=tokenizer.model_max_length,
542
+ truncation=True,
543
+ return_tensors="pt",
544
+ )
545
+ text_input_ids = text_inputs.input_ids
546
+ return text_input_ids
547
+
548
+
549
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
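+ # SDXL uses two text encoders: the per-token embeddings (penultimate hidden states) from both
+ # encoders are concatenated along the feature dimension, and the pooled output of the second
+ # (final) encoder is used as an additional conditioning embedding.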
550
+ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
551
+ prompt_embeds_list = []
552
+
553
+ for i, text_encoder in enumerate(text_encoders):
554
+ if tokenizers is not None:
555
+ tokenizer = tokenizers[i]
556
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
557
+ else:
558
+ assert text_input_ids_list is not None
559
+ text_input_ids = text_input_ids_list[i]
560
+
561
+ prompt_embeds = text_encoder(
562
+ text_input_ids.to(text_encoder.device),
563
+ output_hidden_states=True,
564
+ )
565
+
566
+ # We are only ever interested in the pooled output of the final text encoder.
567
+ pooled_prompt_embeds = prompt_embeds[0]
568
+ prompt_embeds = prompt_embeds.hidden_states[-2]
569
+ bs_embed, seq_len, _ = prompt_embeds.shape
570
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
571
+ prompt_embeds_list.append(prompt_embeds)
572
+
573
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
574
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
575
+ return prompt_embeds, pooled_prompt_embeds
576
+
577
+
578
+ def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
579
+ """
580
+ Returns:
581
+ a state dict containing just the attention processor parameters.
582
+ """
583
+ attn_processors = unet.attn_processors
584
+
585
+ attn_processors_state_dict = {}
586
+
587
+ for attn_processor_key, attn_processor in attn_processors.items():
588
+ for parameter_key, parameter in attn_processor.state_dict().items():
589
+ attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
590
+
591
+ return attn_processors_state_dict
592
+
593
+
594
+ def main(args):
595
+ logging_dir = Path(args.output_dir, args.logging_dir)
596
+
597
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
598
+
599
+ accelerator = Accelerator(
600
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
601
+ mixed_precision=args.mixed_precision,
602
+ log_with=args.report_to,
603
+ project_config=accelerator_project_config,
604
+ )
605
+
606
+ if args.report_to == "wandb":
607
+ if not is_wandb_available():
608
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
609
+ import wandb
610
+
611
+ # Make one log on every process with the configuration for debugging.
612
+ logging.basicConfig(
613
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
614
+ datefmt="%m/%d/%Y %H:%M:%S",
615
+ level=logging.INFO,
616
+ )
617
+ logger.info(accelerator.state, main_process_only=False)
618
+ if accelerator.is_local_main_process:
619
+ transformers.utils.logging.set_verbosity_warning()
620
+ diffusers.utils.logging.set_verbosity_info()
621
+ else:
622
+ transformers.utils.logging.set_verbosity_error()
623
+ diffusers.utils.logging.set_verbosity_error()
624
+
625
+ # If passed along, set the training seed now.
626
+ if args.seed is not None:
627
+ set_seed(args.seed)
628
+
629
+ # Generate class images if prior preservation is enabled.
630
+ if args.with_prior_preservation:
631
+ class_images_dir = Path(args.class_data_dir)
632
+ if not class_images_dir.exists():
633
+ class_images_dir.mkdir(parents=True)
634
+ cur_class_images = len(list(class_images_dir.iterdir()))
635
+
636
+ if cur_class_images < args.num_class_images:
637
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
638
+ if args.prior_generation_precision == "fp32":
639
+ torch_dtype = torch.float32
640
+ elif args.prior_generation_precision == "fp16":
641
+ torch_dtype = torch.float16
642
+ elif args.prior_generation_precision == "bf16":
643
+ torch_dtype = torch.bfloat16
644
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
645
+ args.pretrained_model_name_or_path,
646
+ torch_dtype=torch_dtype,
647
+ revision=args.revision,
648
+ )
649
+ pipeline.set_progress_bar_config(disable=True)
650
+
651
+ num_new_images = args.num_class_images - cur_class_images
652
+ logger.info(f"Number of class images to sample: {num_new_images}.")
653
+
654
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
655
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
656
+
657
+ sample_dataloader = accelerator.prepare(sample_dataloader)
658
+ pipeline.to(accelerator.device)
659
+
660
+ for example in tqdm(
661
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
662
+ ):
663
+ images = pipeline(example["prompt"]).images
664
+
665
+ for i, image in enumerate(images):
666
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
667
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
668
+ image.save(image_filename)
669
+
670
+ del pipeline
671
+ if torch.cuda.is_available():
672
+ torch.cuda.empty_cache()
673
+
674
+ # Handle the repository creation
675
+ if accelerator.is_main_process:
676
+ if args.output_dir is not None:
677
+ os.makedirs(args.output_dir, exist_ok=True)
678
+
679
+ if args.push_to_hub:
680
+ repo_id = create_repo(
681
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
682
+ ).repo_id
683
+
684
+ # Load the tokenizers
685
+ tokenizer_one = AutoTokenizer.from_pretrained(
686
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
687
+ )
688
+ tokenizer_two = AutoTokenizer.from_pretrained(
689
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
690
+ )
691
+
692
+ # import correct text encoder classes
693
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
694
+ args.pretrained_model_name_or_path, args.revision
695
+ )
696
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
697
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
698
+ )
699
+
700
+ # Load scheduler and models
701
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
702
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
703
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
704
+ )
705
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
706
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
707
+ )
708
+ vae_path = (
709
+ args.pretrained_model_name_or_path
710
+ if args.pretrained_vae_model_name_or_path is None
711
+ else args.pretrained_vae_model_name_or_path
712
+ )
713
+ vae = AutoencoderKL.from_pretrained(
714
+ vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
715
+ )
716
+ unet = UNet2DConditionModel.from_pretrained(
717
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
718
+ )
719
+
720
+ # We only train the additional adapter LoRA layers
721
+ vae.requires_grad_(False)
722
+ text_encoder_one.requires_grad_(False)
723
+ text_encoder_two.requires_grad_(False)
724
+ unet.requires_grad_(False)
725
+
726
+ # For mixed precision training, we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half precision,
727
+ # since these weights are only used for inference and keeping them in full precision is not required.
728
+ weight_dtype = torch.float32
729
+ if accelerator.mixed_precision == "fp16":
730
+ weight_dtype = torch.float16
731
+ elif accelerator.mixed_precision == "bf16":
732
+ weight_dtype = torch.bfloat16
733
+
734
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
735
+ unet.to(accelerator.device, dtype=weight_dtype)
736
+
737
+ # The VAE is always in float32 to avoid NaN losses.
738
+ vae.to(accelerator.device, dtype=torch.float32)
739
+
740
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
741
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
742
+
743
+ if args.enable_xformers_memory_efficient_attention:
744
+ if is_xformers_available():
745
+ import xformers
746
+
747
+ xformers_version = version.parse(xformers.__version__)
748
+ if xformers_version == version.parse("0.0.16"):
749
+ logger.warn(
750
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
751
+ )
752
+ unet.enable_xformers_memory_efficient_attention()
753
+ else:
754
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
755
+
756
+ if args.gradient_checkpointing:
757
+ unet.enable_gradient_checkpointing()
758
+ if args.train_text_encoder:
759
+ text_encoder_one.gradient_checkpointing_enable()
760
+ text_encoder_two.gradient_checkpointing_enable()
761
+
762
+ # now we will add new LoRA weights to the attention layers
763
+ # Set correct lora layers
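+ # Each attention processor in the UNet is replaced by a LoRA variant that adds trainable
+ # low-rank update matrices (rank = args.rank) on top of the frozen attention projections.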
764
+ unet_lora_attn_procs = {}
765
+ unet_lora_parameters = []
766
+ for name, attn_processor in unet.attn_processors.items():
767
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
768
+ if name.startswith("mid_block"):
769
+ hidden_size = unet.config.block_out_channels[-1]
770
+ elif name.startswith("up_blocks"):
771
+ block_id = int(name[len("up_blocks.")])
772
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
773
+ elif name.startswith("down_blocks"):
774
+ block_id = int(name[len("down_blocks.")])
775
+ hidden_size = unet.config.block_out_channels[block_id]
776
+
777
+ lora_attn_processor_class = (
778
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
779
+ )
780
+ module = lora_attn_processor_class(
781
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
782
+ )
783
+ unet_lora_attn_procs[name] = module
784
+ unet_lora_parameters.extend(module.parameters())
785
+
786
+ unet.set_attn_processor(unet_lora_attn_procs)
787
+
788
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
789
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
790
+ if args.train_text_encoder:
791
+ # ensure that dtype is float32, even if the rest of the model (which isn't trained) is loaded in fp16
792
+ text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
793
+ text_encoder_one, dtype=torch.float32, rank=args.rank
794
+ )
795
+ text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
796
+ text_encoder_two, dtype=torch.float32, rank=args.rank
797
+ )
798
+
799
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
800
+ def save_model_hook(models, weights, output_dir):
801
+ if accelerator.is_main_process:
802
+ # There are only two options here: either just the unet attn processor layers,
803
+ # or the unet attn processor layers together with the text encoder attention layers.
804
+ unet_lora_layers_to_save = None
805
+ text_encoder_one_lora_layers_to_save = None
806
+ text_encoder_two_lora_layers_to_save = None
807
+
808
+ for model in models:
809
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
810
+ unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
811
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
812
+ text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
813
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
814
+ text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
815
+ else:
816
+ raise ValueError(f"unexpected save model: {model.__class__}")
817
+
818
+ # make sure to pop weight so that corresponding model is not saved again
819
+ weights.pop()
820
+
821
+ StableDiffusionXLPipeline.save_lora_weights(
822
+ output_dir,
823
+ unet_lora_layers=unet_lora_layers_to_save,
824
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
825
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
826
+ )
827
+
828
+ def load_model_hook(models, input_dir):
829
+ unet_ = None
830
+ text_encoder_one_ = None
831
+ text_encoder_two_ = None
832
+
833
+ while len(models) > 0:
834
+ model = models.pop()
835
+
836
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
837
+ unet_ = model
838
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
839
+ text_encoder_one_ = model
840
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
841
+ text_encoder_two_ = model
842
+ else:
843
+ raise ValueError(f"unexpected model to load: {model.__class__}")
844
+
845
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
846
+ LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
847
+
848
+ text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
849
+ LoraLoaderMixin.load_lora_into_text_encoder(
850
+ text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
851
+ )
852
+
853
+ text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
854
+ LoraLoaderMixin.load_lora_into_text_encoder(
855
+ text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
856
+ )
857
+
858
+ accelerator.register_save_state_pre_hook(save_model_hook)
859
+ accelerator.register_load_state_pre_hook(load_model_hook)
860
+
861
+ # Enable TF32 for faster training on Ampere GPUs,
862
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
863
+ if args.allow_tf32:
864
+ torch.backends.cuda.matmul.allow_tf32 = True
865
+
866
+ if args.scale_lr:
867
+ args.learning_rate = (
868
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
869
+ )
870
+
871
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
872
+ if args.use_8bit_adam:
873
+ try:
874
+ import bitsandbytes as bnb
875
+ except ImportError:
876
+ raise ImportError(
877
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
878
+ )
879
+
880
+ optimizer_class = bnb.optim.AdamW8bit
881
+ else:
882
+ optimizer_class = torch.optim.AdamW
883
+
884
+ # Optimizer creation
885
+ params_to_optimize = (
886
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
887
+ if args.train_text_encoder
888
+ else unet_lora_parameters
889
+ )
890
+ optimizer = optimizer_class(
891
+ params_to_optimize,
892
+ lr=args.learning_rate,
893
+ betas=(args.adam_beta1, args.adam_beta2),
894
+ weight_decay=args.adam_weight_decay,
895
+ eps=args.adam_epsilon,
896
+ )
897
+
898
+ # Computes additional embeddings/ids required by the SDXL UNet.
899
+ # regular text embeddings (when `train_text_encoder` is not True)
900
+ # pooled text embeddings
901
+ # time ids
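+ # The SDXL "time ids" are the tuple (original_size + crop_top_left + target_size); they are
+ # embedded inside the UNet together with the pooled text embeddings as added conditions.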
902
+
903
+ def compute_time_ids():
904
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
905
+ original_size = (args.resolution, args.resolution)
906
+ target_size = (args.resolution, args.resolution)
907
+ crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
908
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
909
+ add_time_ids = torch.tensor([add_time_ids])
910
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
911
+ return add_time_ids
912
+
913
+ if not args.train_text_encoder:
914
+ tokenizers = [tokenizer_one, tokenizer_two]
915
+ text_encoders = [text_encoder_one, text_encoder_two]
916
+
917
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
918
+ with torch.no_grad():
919
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
920
+ prompt_embeds = prompt_embeds.to(accelerator.device)
921
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
922
+ return prompt_embeds, pooled_prompt_embeds
923
+
924
+ # Handle instance prompt.
925
+ instance_time_ids = compute_time_ids()
926
+ if not args.train_text_encoder:
927
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
928
+ args.instance_prompt, text_encoders, tokenizers
929
+ )
930
+
931
+ # Handle class prompt for prior-preservation.
932
+ if args.with_prior_preservation:
933
+ class_time_ids = compute_time_ids()
934
+ if not args.train_text_encoder:
935
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
936
+ args.class_prompt, text_encoders, tokenizers
937
+ )
938
+
939
+ # Clear the memory here.
940
+ if not args.train_text_encoder:
941
+ del tokenizers, text_encoders
942
+ gc.collect()
943
+ torch.cuda.empty_cache()
944
+
945
+ # Pack the statically computed variables appropriately. This is so that we don't
946
+ # have to pass them to the dataloader.
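+ # Concretely: `add_time_ids`, plus either the precomputed `prompt_embeds` / `unet_add_text_embeds`
+ # (when the text encoders stay frozen) or the raw token ids (when they are trained).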
947
+ add_time_ids = instance_time_ids
948
+ if args.with_prior_preservation:
949
+ add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
950
+
951
+ if not args.train_text_encoder:
952
+ prompt_embeds = instance_prompt_hidden_states
953
+ unet_add_text_embeds = instance_pooled_prompt_embeds
954
+ if args.with_prior_preservation:
955
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
956
+ unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
957
+ else:
958
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
959
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
960
+ if args.with_prior_preservation:
961
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
962
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
963
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
964
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
965
+
966
+ # Dataset and DataLoaders creation:
967
+ train_dataset = DreamBoothDataset(
968
+ instance_data_root=args.instance_data_dir,
969
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
970
+ class_num=args.num_class_images,
971
+ size=args.resolution,
972
+ center_crop=args.center_crop,
973
+ )
974
+
975
+ train_dataloader = torch.utils.data.DataLoader(
976
+ train_dataset,
977
+ batch_size=args.train_batch_size,
978
+ shuffle=True,
979
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
980
+ num_workers=args.dataloader_num_workers,
981
+ )
982
+
983
+ # Scheduler and math around the number of training steps.
984
+ overrode_max_train_steps = False
985
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
986
+ if args.max_train_steps is None:
987
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
988
+ overrode_max_train_steps = True
989
+
990
+ lr_scheduler = get_scheduler(
991
+ args.lr_scheduler,
992
+ optimizer=optimizer,
993
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
994
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
995
+ num_cycles=args.lr_num_cycles,
996
+ power=args.lr_power,
997
+ )
998
+
999
+ # Prepare everything with our `accelerator`.
1000
+ if args.train_text_encoder:
1001
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1002
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
1003
+ )
1004
+ else:
1005
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1006
+ unet, optimizer, train_dataloader, lr_scheduler
1007
+ )
1008
+
1009
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1010
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1011
+ if overrode_max_train_steps:
1012
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1013
+ # Afterwards we recalculate our number of training epochs
1014
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1015
+
1016
+ # We need to initialize the trackers we use, and also store our configuration.
1017
+ # The trackers are initialized automatically on the main process.
1018
+ if accelerator.is_main_process:
1019
+ accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
1020
+
1021
+ # Train!
1022
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1023
+
1024
+ logger.info("***** Running training *****")
1025
+ logger.info(f" Num examples = {len(train_dataset)}")
1026
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1027
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1028
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1029
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1030
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1031
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1032
+ global_step = 0
1033
+ first_epoch = 0
1034
+
1035
+ # Potentially load in the weights and states from a previous save
1036
+ if args.resume_from_checkpoint:
1037
+ if args.resume_from_checkpoint != "latest":
1038
+ path = os.path.basename(args.resume_from_checkpoint)
1039
+ else:
1040
+ # Get the most recent checkpoint
1041
+ dirs = os.listdir(args.output_dir)
1042
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1043
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1044
+ path = dirs[-1] if len(dirs) > 0 else None
1045
+
1046
+ if path is None:
1047
+ accelerator.print(
1048
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1049
+ )
1050
+ args.resume_from_checkpoint = None
1051
+ else:
1052
+ accelerator.print(f"Resuming from checkpoint {path}")
1053
+ accelerator.load_state(os.path.join(args.output_dir, path))
1054
+ global_step = int(path.split("-")[1])
1055
+
1056
+ resume_global_step = global_step * args.gradient_accumulation_steps
1057
+ first_epoch = global_step // num_update_steps_per_epoch
1058
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
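+ # `resume_step` is measured in dataloader steps (optimizer steps * gradient accumulation) and is
+ # used below to skip the batches that were already seen in the partially completed epoch.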
1059
+
1060
+ # Only show the progress bar once on each machine.
1061
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
1062
+ progress_bar.set_description("Steps")
1063
+
1064
+ for epoch in range(first_epoch, args.num_train_epochs):
1065
+ unet.train()
1066
+ if args.train_text_encoder:
1067
+ text_encoder_one.train()
1068
+ text_encoder_two.train()
1069
+ for step, batch in enumerate(train_dataloader):
1070
+ # Skip steps until we reach the resumed step
1071
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
1072
+ if step % args.gradient_accumulation_steps == 0:
1073
+ progress_bar.update(1)
1074
+ continue
1075
+
1076
+ with accelerator.accumulate(unet):
1077
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1078
+
1079
+ # Convert images to latent space
1080
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1081
+ model_input = model_input * vae.config.scaling_factor
1082
+ if args.pretrained_vae_model_name_or_path is None:
1083
+ model_input = model_input.to(weight_dtype)
1084
+
1085
+ # Sample noise that we'll add to the latents
1086
+ noise = torch.randn_like(model_input)
1087
+ bsz = model_input.shape[0]
1088
+ # Sample a random timestep for each image
1089
+ timesteps = torch.randint(
1090
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1091
+ )
1092
+ timesteps = timesteps.long()
1093
+
1094
+ # Add noise to the model input according to the noise magnitude at each timestep
1095
+ # (this is the forward diffusion process)
1096
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1097
+
1098
+ # Calculate the elements to repeat depending on the use of prior-preservation.
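+ # With prior preservation the pixel batch stacks instance and class images, while the packed
+ # conditioning tensors already contain one instance and one class entry, so each conditioning
+ # tensor only needs to be repeated bsz // 2 times to match the batch.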
1099
+ elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz
1100
+
1101
+ # Predict the noise residual
1102
+ if not args.train_text_encoder:
1103
+ unet_added_conditions = {
1104
+ "time_ids": add_time_ids.repeat(elems_to_repeat, 1),
1105
+ "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
1106
+ }
1107
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
1108
+ model_pred = unet(
1109
+ noisy_model_input,
1110
+ timesteps,
1111
+ prompt_embeds_input,
1112
+ added_cond_kwargs=unet_added_conditions,
1113
+ ).sample
1114
+ else:
1115
+ unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)}
1116
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1117
+ text_encoders=[text_encoder_one, text_encoder_two],
1118
+ tokenizers=None,
1119
+ prompt=None,
1120
+ text_input_ids_list=[tokens_one, tokens_two],
1121
+ )
1122
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
1123
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
1124
+ model_pred = unet(
1125
+ noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
1126
+ ).sample
1127
+
1128
+ # Get the target for loss depending on the prediction type
1129
+ if noise_scheduler.config.prediction_type == "epsilon":
1130
+ target = noise
1131
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1132
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1133
+ else:
1134
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1135
+
1136
+ if args.with_prior_preservation:
1137
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1138
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1139
+ target, target_prior = torch.chunk(target, 2, dim=0)
1140
+
1141
+ # Compute instance loss
1142
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1143
+
1144
+ # Compute prior loss
1145
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
1146
+
1147
+ # Add the prior loss to the instance loss.
1148
+ loss = loss + args.prior_loss_weight * prior_loss
1149
+ else:
1150
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1151
+
1152
+ accelerator.backward(loss)
1153
+ if accelerator.sync_gradients:
1154
+ params_to_clip = (
1155
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
1156
+ if args.train_text_encoder
1157
+ else unet_lora_parameters
1158
+ )
1159
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1160
+ optimizer.step()
1161
+ lr_scheduler.step()
1162
+ optimizer.zero_grad()
1163
+
1164
+ # Checks if the accelerator has performed an optimization step behind the scenes
1165
+ if accelerator.sync_gradients:
1166
+ progress_bar.update(1)
1167
+ global_step += 1
1168
+
1169
+ if accelerator.is_main_process:
1170
+ if global_step % args.checkpointing_steps == 0:
1171
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1172
+ if args.checkpoints_total_limit is not None:
1173
+ checkpoints = os.listdir(args.output_dir)
1174
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1175
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1176
+
1177
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1178
+ if len(checkpoints) >= args.checkpoints_total_limit:
1179
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1180
+ removing_checkpoints = checkpoints[0:num_to_remove]
1181
+
1182
+ logger.info(
1183
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1184
+ )
1185
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1186
+
1187
+ for removing_checkpoint in removing_checkpoints:
1188
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1189
+ shutil.rmtree(removing_checkpoint)
1190
+
1191
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1192
+ accelerator.save_state(save_path)
1193
+ logger.info(f"Saved state to {save_path}")
1194
+
1195
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1196
+ progress_bar.set_postfix(**logs)
1197
+ accelerator.log(logs, step=global_step)
1198
+
1199
+ if global_step >= args.max_train_steps:
1200
+ break
1201
+
1202
+ if accelerator.is_main_process:
1203
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1204
+ logger.info(
1205
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1206
+ f" {args.validation_prompt}."
1207
+ )
1208
+ # create pipeline
1209
+ if not args.train_text_encoder:
1210
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
1211
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
1212
+ )
1213
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
1214
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
1215
+ )
1216
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1217
+ args.pretrained_model_name_or_path,
1218
+ vae=vae,
1219
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
1220
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
1221
+ unet=accelerator.unwrap_model(unet),
1222
+ revision=args.revision,
1223
+ torch_dtype=weight_dtype,
1224
+ )
1225
+
1226
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1227
+ scheduler_args = {}
1228
+
1229
+ if "variance_type" in pipeline.scheduler.config:
1230
+ variance_type = pipeline.scheduler.config.variance_type
1231
+
1232
+ if variance_type in ["learned", "learned_range"]:
1233
+ variance_type = "fixed_small"
1234
+
1235
+ scheduler_args["variance_type"] = variance_type
1236
+
1237
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1238
+ pipeline.scheduler.config, **scheduler_args
1239
+ )
1240
+
1241
+ pipeline = pipeline.to(accelerator.device)
1242
+ pipeline.set_progress_bar_config(disable=True)
1243
+
1244
+ # run inference
1245
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1246
+ pipeline_args = {"prompt": args.validation_prompt}
1247
+
1248
+ with torch.cuda.amp.autocast():
1249
+ images = [
1250
+ pipeline(**pipeline_args, generator=generator).images[0]
1251
+ for _ in range(args.num_validation_images)
1252
+ ]
1253
+
1254
+ for tracker in accelerator.trackers:
1255
+ if tracker.name == "tensorboard":
1256
+ np_images = np.stack([np.asarray(img) for img in images])
1257
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1258
+ if tracker.name == "wandb":
1259
+ tracker.log(
1260
+ {
1261
+ "validation": [
1262
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1263
+ for i, image in enumerate(images)
1264
+ ]
1265
+ }
1266
+ )
1267
+
1268
+ del pipeline
1269
+ torch.cuda.empty_cache()
1270
+
1271
+ # Save the lora layers
1272
+ accelerator.wait_for_everyone()
1273
+ if accelerator.is_main_process:
1274
+ unet = accelerator.unwrap_model(unet)
1275
+ unet = unet.to(torch.float32)
1276
+ unet_lora_layers = unet_attn_processors_state_dict(unet)
1277
+
1278
+ if args.train_text_encoder:
1279
+ text_encoder_one = accelerator.unwrap_model(text_encoder_one)
1280
+ text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
1281
+ text_encoder_two = accelerator.unwrap_model(text_encoder_two)
1282
+ text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
1283
+ else:
1284
+ text_encoder_lora_layers = None
1285
+ text_encoder_2_lora_layers = None
1286
+
1287
+ StableDiffusionXLPipeline.save_lora_weights(
1288
+ save_directory=args.output_dir,
1289
+ unet_lora_layers=unet_lora_layers,
1290
+ text_encoder_lora_layers=text_encoder_lora_layers,
1291
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
1292
+ )
1293
+
1294
+ # Final inference
1295
+ # Load previous pipeline
1296
+ vae = AutoencoderKL.from_pretrained(
1297
+ vae_path,
1298
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1299
+ revision=args.revision,
1300
+ torch_dtype=weight_dtype,
1301
+ )
1302
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1303
+ args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
1304
+ )
1305
+
1306
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1307
+ scheduler_args = {}
1308
+
1309
+ if "variance_type" in pipeline.scheduler.config:
1310
+ variance_type = pipeline.scheduler.config.variance_type
1311
+
1312
+ if variance_type in ["learned", "learned_range"]:
1313
+ variance_type = "fixed_small"
1314
+
1315
+ scheduler_args["variance_type"] = variance_type
1316
+
1317
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1318
+
1319
+ # load attention processors
1320
+ pipeline.load_lora_weights(args.output_dir)
1321
+
1322
+ # run inference
1323
+ images = []
1324
+ if args.validation_prompt and args.num_validation_images > 0:
1325
+ pipeline = pipeline.to(accelerator.device)
1326
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1327
+ images = [
1328
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1329
+ for _ in range(args.num_validation_images)
1330
+ ]
1331
+
1332
+ for tracker in accelerator.trackers:
1333
+ if tracker.name == "tensorboard":
1334
+ np_images = np.stack([np.asarray(img) for img in images])
1335
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1336
+ if tracker.name == "wandb":
1337
+ tracker.log(
1338
+ {
1339
+ "test": [
1340
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1341
+ for i, image in enumerate(images)
1342
+ ]
1343
+ }
1344
+ )
1345
+
1346
+ if args.push_to_hub:
1347
+ save_model_card(
1348
+ repo_id,
1349
+ images=images,
1350
+ base_model=args.pretrained_model_name_or_path,
1351
+ train_text_encoder=args.train_text_encoder,
1352
+ prompt=args.instance_prompt,
1353
+ repo_folder=args.output_dir,
1354
+ vae_path=args.pretrained_vae_model_name_or_path,
1355
+ )
1356
+ upload_folder(
1357
+ repo_id=repo_id,
1358
+ folder_path=args.output_dir,
1359
+ commit_message="End of training",
1360
+ ignore_patterns=["step_*", "epoch_*"],
1361
+ )
1362
+
1363
+ accelerator.end_training()
1364
+
1365
+
1366
+ if __name__ == "__main__":
1367
+ args = parse_args()
1368
+ main(args)