mjbuehler committed
Commit b831b2e (verified) · 1 parent: 9083d1c

Upload train_dreambooth_lora_sd3.py

Files changed (1)
  1. train_dreambooth_lora_sd3.py +1875 -0
train_dreambooth_lora_sd3.py ADDED
@@ -0,0 +1,1875 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
+
16
+ import argparse
17
+ import copy
18
+ import gc
19
+ import itertools
20
+ import logging
21
+ import math
22
+ import os
23
+ import random
24
+ import shutil
25
+ import warnings
26
+ from contextlib import nullcontext
27
+ from pathlib import Path
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.utils.checkpoint
32
+ import transformers
33
+ from accelerate import Accelerator
34
+ from accelerate.logging import get_logger
35
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
36
+ from huggingface_hub import create_repo, upload_folder
37
+ from huggingface_hub.utils import insecure_hashlib
38
+ from peft import LoraConfig, set_peft_model_state_dict
39
+ from peft.utils import get_peft_model_state_dict
40
+ from PIL import Image
41
+ from PIL.ImageOps import exif_transpose
42
+ from torch.utils.data import Dataset
43
+ from torchvision import transforms
44
+ from torchvision.transforms.functional import crop
45
+ from tqdm.auto import tqdm
46
+ from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
47
+
48
+ import diffusers
49
+ from diffusers import (
50
+ AutoencoderKL,
51
+ FlowMatchEulerDiscreteScheduler,
52
+ SD3Transformer2DModel,
53
+ StableDiffusion3Pipeline,
54
+ )
55
+ from diffusers.optimization import get_scheduler
56
+ from diffusers.training_utils import (
57
+ _set_state_dict_into_text_encoder,
58
+ cast_training_params,
59
+ compute_density_for_timestep_sampling,
60
+ compute_loss_weighting_for_sd3,
61
+ )
62
+ from diffusers.utils import (
63
+ check_min_version,
64
+ convert_unet_state_dict_to_peft,
65
+ is_wandb_available,
66
+ )
67
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
68
+ from diffusers.utils.torch_utils import is_compiled_module
69
+
70
+
71
+ if is_wandb_available():
72
+ import wandb
73
+
74
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
75
+ check_min_version("0.30.0.dev0")
76
+
77
+ logger = get_logger(__name__)
78
+
79
+
80
+ def save_model_card(
81
+ repo_id: str,
82
+ images=None,
83
+ base_model: str = None,
84
+ train_text_encoder=False,
85
+ instance_prompt=None,
86
+ validation_prompt=None,
87
+ repo_folder=None,
88
+ ):
89
+ widget_dict = []
90
+ if images is not None:
91
+ for i, image in enumerate(images):
92
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
93
+ widget_dict.append(
94
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
95
+ )
96
+
97
+ model_description = f"""
98
+ # SD3 DreamBooth LoRA - {repo_id}
99
+
100
+ <Gallery />
101
+
102
+ ## Model description
103
+
104
+ These are {repo_id} DreamBooth LoRA weights for {base_model}.
105
+
106
+ The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md).
107
+
108
+ Was LoRA for the text encoder enabled? {train_text_encoder}.
109
+
110
+ ## Trigger words
111
+
112
+ You should use `{instance_prompt}` to trigger the image generation.
113
+
114
+ ## Download model
115
+
116
+ [Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
117
+
118
+ ## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
119
+
120
+ ```py
121
+ from diffusers import AutoPipelineForText2Image
122
+ import torch
123
+ pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda')
124
+ pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
125
+ image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
126
+ ```
127
+
128
+ ### Use it with UIs such as AUTOMATIC1111, ComfyUI, SD.Next, Invoke
129
+
130
+ - **LoRA**: download **[`diffusers_lora_weights.safetensors` here 💾](/{repo_id}/blob/main/diffusers_lora_weights.safetensors)**.
131
+ - Rename it and place it in your `models/Lora` folder.
132
+ - On AUTOMATIC1111, load the LoRA by adding `<lora:your_new_name:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
133
+
134
+ For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
135
+
136
+ ## License
137
+
138
+ Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
139
+ """
140
+ model_card = load_or_create_model_card(
141
+ repo_id_or_path=repo_id,
142
+ from_training=True,
143
+ license="openrail++",
144
+ base_model=base_model,
145
+ prompt=instance_prompt,
146
+ model_description=model_description,
147
+ widget=widget_dict,
148
+ )
149
+ tags = [
150
+ "text-to-image",
151
+ "diffusers-training",
152
+ "diffusers",
153
+ "lora",
154
+ "sd3",
155
+ "sd3-diffusers",
156
+ "template:sd-lora",
157
+ ]
158
+
159
+ model_card = populate_model_card(model_card, tags=tags)
160
+ model_card.save(os.path.join(repo_folder, "README.md"))
161
+
162
+
163
+ def load_text_encoders(class_one, class_two, class_three):
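+ # Note: this helper reads the module-level `args` populated by parse_args() in the __main__ block of the script.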
164
+ text_encoder_one = class_one.from_pretrained(
165
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
166
+ )
167
+ text_encoder_two = class_two.from_pretrained(
168
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
169
+ )
170
+ text_encoder_three = class_three.from_pretrained(
171
+ args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant
172
+ )
173
+ return text_encoder_one, text_encoder_two, text_encoder_three
174
+
175
+
176
+ def log_validation(
177
+ pipeline,
178
+ args,
179
+ accelerator,
180
+ pipeline_args,
181
+ epoch,
182
+ is_final_validation=False,
183
+ ):
184
+ logger.info(
185
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
186
+ f" {args.validation_prompt}."
187
+ )
188
+ pipeline = pipeline.to(accelerator.device)
189
+ pipeline.set_progress_bar_config(disable=True)
190
+
191
+ # run inference
192
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
193
+ # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
194
+ autocast_ctx = nullcontext()
195
+
196
+ with autocast_ctx:
197
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
198
+
199
+ for tracker in accelerator.trackers:
200
+ phase_name = "test" if is_final_validation else "validation"
201
+ if tracker.name == "tensorboard":
202
+ np_images = np.stack([np.asarray(img) for img in images])
203
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
204
+ if tracker.name == "wandb":
205
+ tracker.log(
206
+ {
207
+ phase_name: [
208
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
209
+ ]
210
+ }
211
+ )
212
+
213
+ del pipeline
214
+ if torch.cuda.is_available():
215
+ torch.cuda.empty_cache()
216
+
217
+ return images
218
+
219
+
220
+ def import_model_class_from_model_name_or_path(
221
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
222
+ ):
223
+ text_encoder_config = PretrainedConfig.from_pretrained(
224
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
225
+ )
226
+ model_class = text_encoder_config.architectures[0]
227
+ if model_class == "CLIPTextModelWithProjection":
228
+ from transformers import CLIPTextModelWithProjection
229
+
230
+ return CLIPTextModelWithProjection
231
+ elif model_class == "T5EncoderModel":
232
+ from transformers import T5EncoderModel
233
+
234
+ return T5EncoderModel
235
+ else:
236
+ raise ValueError(f"{model_class} is not supported.")
237
+
238
+
239
+ def parse_args(input_args=None):
240
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
241
+ parser.add_argument(
242
+ "--pretrained_model_name_or_path",
243
+ type=str,
244
+ default=None,
245
+ required=True,
246
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
247
+ )
248
+ parser.add_argument(
249
+ "--revision",
250
+ type=str,
251
+ default=None,
252
+ required=False,
253
+ help="Revision of pretrained model identifier from huggingface.co/models.",
254
+ )
255
+ parser.add_argument(
256
+ "--variant",
257
+ type=str,
258
+ default=None,
259
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
260
+ )
261
+ parser.add_argument(
262
+ "--dataset_name",
263
+ type=str,
264
+ default=None,
265
+ help=(
266
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
267
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
268
+ " or to a folder containing files that 🤗 Datasets can understand."
269
+ ),
270
+ )
271
+ parser.add_argument(
272
+ "--dataset_config_name",
273
+ type=str,
274
+ default=None,
275
+ help="The config of the Dataset, leave as None if there's only one config.",
276
+ )
277
+ parser.add_argument(
278
+ "--instance_data_dir",
279
+ type=str,
280
+ default=None,
281
+ help=("A folder containing the training data. "),
282
+ )
283
+
284
+ parser.add_argument(
285
+ "--cache_dir",
286
+ type=str,
287
+ default=None,
288
+ help="The directory where the downloaded models and datasets will be stored.",
289
+ )
290
+
291
+ parser.add_argument(
292
+ "--image_column",
293
+ type=str,
294
+ default="image",
295
+ help="The column of the dataset containing the target image. By "
296
+ "default, the standard Image Dataset maps 'file_name' "
297
+ "to 'image'.",
298
+ )
299
+ parser.add_argument(
300
+ "--caption_column",
301
+ type=str,
302
+ default=None,
303
+ help="The column of the dataset containing the instance prompt for each image",
304
+ )
305
+
306
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
307
+
308
+ parser.add_argument(
309
+ "--class_data_dir",
310
+ type=str,
311
+ default=None,
312
+ required=False,
313
+ help="A folder containing the training data of class images.",
314
+ )
315
+ parser.add_argument(
316
+ "--instance_prompt",
317
+ type=str,
318
+ default=None,
319
+ required=True,
320
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
321
+ )
322
+ parser.add_argument(
323
+ "--class_prompt",
324
+ type=str,
325
+ default=None,
326
+ help="The prompt to specify images in the same class as provided instance images.",
327
+ )
328
+ parser.add_argument(
329
+ "--max_sequence_length",
330
+ type=int,
331
+ default=77,
332
+ help="Maximum sequence length to use with the T5 text encoder",
333
+ )
334
+ parser.add_argument(
335
+ "--validation_prompt",
336
+ type=str,
337
+ default=None,
338
+ help="A prompt that is used during validation to verify that the model is learning.",
339
+ )
340
+ parser.add_argument(
341
+ "--num_validation_images",
342
+ type=int,
343
+ default=4,
344
+ help="Number of images that should be generated during validation with `validation_prompt`.",
345
+ )
346
+ parser.add_argument(
347
+ "--validation_epochs",
348
+ type=int,
349
+ default=50,
350
+ help=(
351
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
352
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
353
+ ),
354
+ )
355
+ parser.add_argument(
356
+ "--rank",
357
+ type=int,
358
+ default=4,
359
+ help=("The dimension of the LoRA update matrices."),
360
+ )
361
+ parser.add_argument(
362
+ "--with_prior_preservation",
363
+ default=False,
364
+ action="store_true",
365
+ help="Flag to add prior preservation loss.",
366
+ )
367
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
368
+ parser.add_argument(
369
+ "--num_class_images",
370
+ type=int,
371
+ default=100,
372
+ help=(
373
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
374
+ " class_data_dir, additional images will be sampled with class_prompt."
375
+ ),
376
+ )
377
+ parser.add_argument(
378
+ "--output_dir",
379
+ type=str,
380
+ default="sd3-dreambooth",
381
+ help="The output directory where the model predictions and checkpoints will be written.",
382
+ )
383
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
384
+ parser.add_argument(
385
+ "--resolution",
386
+ type=int,
387
+ default=512,
388
+ help=(
389
+ "The resolution for input images; all images in the train/validation dataset will be resized to this"
390
+ " resolution"
391
+ ),
392
+ )
393
+ parser.add_argument(
394
+ "--center_crop",
395
+ default=False,
396
+ action="store_true",
397
+ help=(
398
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
399
+ " cropped. The images will be resized to the resolution first before cropping."
400
+ ),
401
+ )
402
+ parser.add_argument(
403
+ "--random_flip",
404
+ action="store_true",
405
+ help="whether to randomly flip images horizontally",
406
+ )
407
+ parser.add_argument(
408
+ "--train_text_encoder",
409
+ action="store_true",
410
+ help="Whether to train the text encoders (CLIP text encoders only). If set, the text encoders should be in float32 precision.",
411
+ )
412
+
413
+ parser.add_argument(
414
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
415
+ )
416
+ parser.add_argument(
417
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
418
+ )
419
+ parser.add_argument("--num_train_epochs", type=int, default=1)
420
+ parser.add_argument(
421
+ "--max_train_steps",
422
+ type=int,
423
+ default=None,
424
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
425
+ )
426
+ parser.add_argument(
427
+ "--checkpointing_steps",
428
+ type=int,
429
+ default=500,
430
+ help=(
431
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
432
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
433
+ " training using `--resume_from_checkpoint`."
434
+ ),
435
+ )
436
+ parser.add_argument(
437
+ "--checkpoints_total_limit",
438
+ type=int,
439
+ default=None,
440
+ help=("Max number of checkpoints to store."),
441
+ )
442
+ parser.add_argument(
443
+ "--resume_from_checkpoint",
444
+ type=str,
445
+ default=None,
446
+ help=(
447
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
448
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
449
+ ),
450
+ )
451
+ parser.add_argument(
452
+ "--gradient_accumulation_steps",
453
+ type=int,
454
+ default=1,
455
+ help="Number of update steps to accumulate before performing a backward/update pass.",
456
+ )
457
+ parser.add_argument(
458
+ "--gradient_checkpointing",
459
+ action="store_true",
460
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
461
+ )
462
+ parser.add_argument(
463
+ "--learning_rate",
464
+ type=float,
465
+ default=1e-4,
466
+ help="Initial learning rate (after the potential warmup period) to use.",
467
+ )
468
+
469
+ parser.add_argument(
470
+ "--text_encoder_lr",
471
+ type=float,
472
+ default=5e-6,
473
+ help="Text encoder learning rate to use.",
474
+ )
475
+ parser.add_argument(
476
+ "--scale_lr",
477
+ action="store_true",
478
+ default=False,
479
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
480
+ )
481
+ parser.add_argument(
482
+ "--lr_scheduler",
483
+ type=str,
484
+ default="constant",
485
+ help=(
486
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
487
+ ' "constant", "constant_with_warmup"]'
488
+ ),
489
+ )
490
+ parser.add_argument(
491
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
492
+ )
493
+ parser.add_argument(
494
+ "--lr_num_cycles",
495
+ type=int,
496
+ default=1,
497
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
498
+ )
499
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
500
+ parser.add_argument(
501
+ "--dataloader_num_workers",
502
+ type=int,
503
+ default=0,
504
+ help=(
505
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
506
+ ),
507
+ )
508
+ parser.add_argument(
509
+ "--weighting_scheme",
510
+ type=str,
511
+ default="logit_normal",
512
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
513
+ )
514
+ parser.add_argument(
515
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
516
+ )
517
+ parser.add_argument(
518
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
519
+ )
520
+ parser.add_argument(
521
+ "--mode_scale",
522
+ type=float,
523
+ default=1.29,
524
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
525
+ )
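+ # Note: --weighting_scheme, --logit_mean/--logit_std and --mode_scale are passed to
+ # compute_density_for_timestep_sampling() and compute_loss_weighting_for_sd3() later in the
+ # training loop, where they control how flow-matching timesteps are sampled and weighted.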
526
+ parser.add_argument(
527
+ "--precondition_outputs",
528
+ type=int,
529
+ default=1,
530
+ help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
531
+ "model `target` is calculated.",
532
+ )
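+ # With --precondition_outputs=1 the model output is converted into a prediction of the clean latents
+ # (EDM-style preconditioning) and the training target is the original latents; with 0 the target is
+ # the flow-matching velocity (noise - latents). This is applied further down in the training loop.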
533
+ parser.add_argument(
534
+ "--optimizer",
535
+ type=str,
536
+ default="AdamW",
537
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
538
+ )
539
+
540
+ parser.add_argument(
541
+ "--use_8bit_adam",
542
+ action="store_true",
543
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
544
+ )
545
+
546
+ parser.add_argument(
547
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
548
+ )
549
+ parser.add_argument(
550
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
551
+ )
552
+ parser.add_argument(
553
+ "--prodigy_beta3",
554
+ type=float,
555
+ default=None,
556
+ help="Coefficient for computing the Prodigy step size using running averages. If set to None, "
557
+ "uses the value of the square root of beta2. Ignored if optimizer is AdamW.",
558
+ )
559
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
560
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for the transformer params")
561
+ parser.add_argument(
562
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
563
+ )
564
+
565
+ parser.add_argument(
566
+ "--adam_epsilon",
567
+ type=float,
568
+ default=1e-08,
569
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
570
+ )
571
+
572
+ parser.add_argument(
573
+ "--prodigy_use_bias_correction",
574
+ type=bool,
575
+ default=True,
576
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
577
+ )
578
+ parser.add_argument(
579
+ "--prodigy_safeguard_warmup",
580
+ type=bool,
581
+ default=True,
582
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
583
+ "Ignored if optimizer is adamW",
584
+ )
585
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
586
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
587
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
588
+ parser.add_argument(
589
+ "--hub_model_id",
590
+ type=str,
591
+ default=None,
592
+ help="The name of the repository to keep in sync with the local `output_dir`.",
593
+ )
594
+ parser.add_argument(
595
+ "--logging_dir",
596
+ type=str,
597
+ default="logs",
598
+ help=(
599
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
600
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
601
+ ),
602
+ )
603
+ parser.add_argument(
604
+ "--allow_tf32",
605
+ action="store_true",
606
+ help=(
607
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
608
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
609
+ ),
610
+ )
611
+ parser.add_argument(
612
+ "--report_to",
613
+ type=str,
614
+ default="tensorboard",
615
+ help=(
616
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
617
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
618
+ ),
619
+ )
620
+ parser.add_argument(
621
+ "--mixed_precision",
622
+ type=str,
623
+ default=None,
624
+ choices=["no", "fp16", "bf16"],
625
+ help=(
626
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
627
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
628
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
629
+ ),
630
+ )
631
+ parser.add_argument(
632
+ "--prior_generation_precision",
633
+ type=str,
634
+ default=None,
635
+ choices=["no", "fp32", "fp16", "bf16"],
636
+ help=(
637
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
638
+ " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
639
+ ),
640
+ )
641
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
642
+
643
+ if input_args is not None:
644
+ args = parser.parse_args(input_args)
645
+ else:
646
+ args = parser.parse_args()
647
+
648
+ if args.dataset_name is None and args.instance_data_dir is None:
649
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
650
+
651
+ if args.dataset_name is not None and args.instance_data_dir is not None:
652
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
653
+
654
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
655
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
656
+ args.local_rank = env_local_rank
657
+
658
+ if args.with_prior_preservation:
659
+ if args.class_data_dir is None:
660
+ raise ValueError("You must specify a data directory for class images.")
661
+ if args.class_prompt is None:
662
+ raise ValueError("You must specify prompt for class images.")
663
+ else:
664
+ # logger is not available yet
665
+ if args.class_data_dir is not None:
666
+ warnings.warn("--class_data_dir is ignored without --with_prior_preservation.")
667
+ if args.class_prompt is not None:
668
+ warnings.warn("--class_prompt is ignored without --with_prior_preservation.")
669
+
670
+ return args
671
+
672
+
673
+ class DreamBoothDataset(Dataset):
674
+ """
675
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
676
+ It pre-processes the images.
677
+ """
678
+
679
+ def __init__(
680
+ self,
681
+ instance_data_root,
682
+ instance_prompt,
683
+ class_prompt,
684
+ class_data_root=None,
685
+ class_num=None,
686
+ size=1024,
687
+ repeats=1,
688
+ center_crop=False,
689
+ ):
690
+ self.size = size
691
+ self.center_crop = center_crop
692
+
693
+ self.instance_prompt = instance_prompt
694
+ self.custom_instance_prompts = None
695
+ self.class_prompt = class_prompt
696
+
697
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
698
+ # we load the training data using load_dataset
699
+ if args.dataset_name is not None:
700
+ try:
701
+ from datasets import load_dataset
702
+ except ImportError:
703
+ raise ImportError(
704
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
705
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
706
+ "local folder containing images only, specify --instance_data_dir instead."
707
+ )
708
+ # Downloading and loading a dataset from the hub.
709
+ # See more about loading custom images at
710
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
711
+ dataset = load_dataset(
712
+ args.dataset_name,
713
+ args.dataset_config_name,
714
+ cache_dir=args.cache_dir,
715
+ )
716
+ # Preprocessing the datasets.
717
+ column_names = dataset["train"].column_names
718
+
719
+ # 6. Get the column names for input/target.
720
+ if args.image_column is None:
721
+ image_column = column_names[0]
722
+ logger.info(f"image column defaulting to {image_column}")
723
+ else:
724
+ image_column = args.image_column
725
+ if image_column not in column_names:
726
+ raise ValueError(
727
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
728
+ )
729
+ instance_images = dataset["train"][image_column]
730
+
731
+ if args.caption_column is None:
732
+ logger.info(
733
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
734
+ "contains captions/prompts for the images, make sure to specify the "
735
+ "column as --caption_column"
736
+ )
737
+ self.custom_instance_prompts = None
738
+ else:
739
+ if args.caption_column not in column_names:
740
+ raise ValueError(
741
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
742
+ )
743
+ custom_instance_prompts = dataset["train"][args.caption_column]
744
+ # create final list of captions according to --repeats
745
+ self.custom_instance_prompts = []
746
+ for caption in custom_instance_prompts:
747
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
748
+ else:
749
+ self.instance_data_root = Path(instance_data_root)
750
+ if not self.instance_data_root.exists():
751
+ raise ValueError("Instance images root doesn't exist.")
752
+
753
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
754
+ self.custom_instance_prompts = None
755
+
756
+ self.instance_images = []
757
+ for img in instance_images:
758
+ self.instance_images.extend(itertools.repeat(img, repeats))
759
+
760
+ self.pixel_values = []
761
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
762
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
763
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
764
+ train_transforms = transforms.Compose(
765
+ [
766
+ transforms.ToTensor(),
767
+ transforms.Normalize([0.5], [0.5]),
768
+ ]
769
+ )
770
+ for image in self.instance_images:
771
+ image = exif_transpose(image)
772
+ if not image.mode == "RGB":
773
+ image = image.convert("RGB")
774
+ image = train_resize(image)
775
+ if args.random_flip and random.random() < 0.5:
776
+ # flip
777
+ image = train_flip(image)
778
+ if args.center_crop:
779
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
780
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
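+ # (y1, x1 are computed but not used further here; unlike SDXL, SD3 is not conditioned on crop coordinates)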
781
+ image = train_crop(image)
782
+ else:
783
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
784
+ image = crop(image, y1, x1, h, w)
785
+ image = train_transforms(image)
786
+ self.pixel_values.append(image)
787
+
788
+ self.num_instance_images = len(self.instance_images)
789
+ self._length = self.num_instance_images
790
+
791
+ if class_data_root is not None:
792
+ self.class_data_root = Path(class_data_root)
793
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
794
+ self.class_images_path = list(self.class_data_root.iterdir())
795
+ if class_num is not None:
796
+ self.num_class_images = min(len(self.class_images_path), class_num)
797
+ else:
798
+ self.num_class_images = len(self.class_images_path)
799
+ self._length = max(self.num_class_images, self.num_instance_images)
800
+ else:
801
+ self.class_data_root = None
802
+
803
+ self.image_transforms = transforms.Compose(
804
+ [
805
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
806
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
807
+ transforms.ToTensor(),
808
+ transforms.Normalize([0.5], [0.5]),
809
+ ]
810
+ )
811
+
812
+ def __len__(self):
813
+ return self._length
814
+
815
+ def __getitem__(self, index):
816
+ example = {}
817
+ instance_image = self.pixel_values[index % self.num_instance_images]
818
+ example["instance_images"] = instance_image
819
+
820
+ if self.custom_instance_prompts:
821
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
822
+ if caption:
823
+ example["instance_prompt"] = caption
824
+ else:
825
+ example["instance_prompt"] = self.instance_prompt
826
+
827
+ else: # no custom instance prompts were provided, fall back to the shared instance prompt
828
+ example["instance_prompt"] = self.instance_prompt
829
+
830
+ if self.class_data_root:
831
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
832
+ class_image = exif_transpose(class_image)
833
+
834
+ if not class_image.mode == "RGB":
835
+ class_image = class_image.convert("RGB")
836
+ example["class_images"] = self.image_transforms(class_image)
837
+ example["class_prompt"] = self.class_prompt
838
+
839
+ return example
840
+
841
+
842
+ def collate_fn(examples, with_prior_preservation=False):
843
+ pixel_values = [example["instance_images"] for example in examples]
844
+ prompts = [example["instance_prompt"] for example in examples]
845
+
846
+ # Concat class and instance examples for prior preservation.
847
+ # We do this to avoid doing two forward passes.
848
+ if with_prior_preservation:
849
+ pixel_values += [example["class_images"] for example in examples]
850
+ prompts += [example["class_prompt"] for example in examples]
851
+
852
+ pixel_values = torch.stack(pixel_values)
853
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
854
+
855
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
856
+ return batch
857
+
858
+
859
+ class PromptDataset(Dataset):
860
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
861
+
862
+ def __init__(self, prompt, num_samples):
863
+ self.prompt = prompt
864
+ self.num_samples = num_samples
865
+
866
+ def __len__(self):
867
+ return self.num_samples
868
+
869
+ def __getitem__(self, index):
870
+ example = {}
871
+ example["prompt"] = self.prompt
872
+ example["index"] = index
873
+ return example
874
+
875
+
876
+ def tokenize_prompt(tokenizer, prompt):
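+ # The CLIP tokenizers have a fixed 77-token context window, so prompts are padded/truncated to that length here.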
877
+ text_inputs = tokenizer(
878
+ prompt,
879
+ padding="max_length",
880
+ max_length=77,
881
+ truncation=True,
882
+ return_tensors="pt",
883
+ )
884
+ text_input_ids = text_inputs.input_ids
885
+ return text_input_ids
886
+
887
+
888
+ def _encode_prompt_with_t5(
889
+ text_encoder,
890
+ tokenizer,
891
+ max_sequence_length,
892
+ prompt=None,
893
+ num_images_per_prompt=1,
894
+ device=None,
895
+ text_input_ids=None,
896
+ ):
897
+ prompt = [prompt] if isinstance(prompt, str) else prompt
898
+ batch_size = len(prompt)
899
+
900
+ if tokenizer is not None:
901
+ text_inputs = tokenizer(
902
+ prompt,
903
+ padding="max_length",
904
+ max_length=max_sequence_length,
905
+ truncation=True,
906
+ add_special_tokens=True,
907
+ return_tensors="pt",
908
+ )
909
+ text_input_ids = text_inputs.input_ids
910
+ else:
911
+ if text_input_ids is None:
912
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
913
+
914
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
915
+
916
+ dtype = text_encoder.dtype
917
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
918
+
919
+ _, seq_len, _ = prompt_embeds.shape
920
+
921
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
922
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
923
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
924
+
925
+ return prompt_embeds
926
+
927
+
928
+ def _encode_prompt_with_clip(
929
+ text_encoder,
930
+ tokenizer,
931
+ prompt: str,
932
+ device=None,
933
+ text_input_ids=None,
934
+ num_images_per_prompt: int = 1,
935
+ ):
936
+ prompt = [prompt] if isinstance(prompt, str) else prompt
937
+ batch_size = len(prompt)
938
+
939
+ if tokenizer is not None:
940
+ text_inputs = tokenizer(
941
+ prompt,
942
+ padding="max_length",
943
+ max_length=77,
944
+ truncation=True,
945
+ return_tensors="pt",
946
+ )
947
+
948
+ text_input_ids = text_inputs.input_ids
949
+ else:
950
+ if text_input_ids is None:
951
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
952
+
953
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
954
+
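+ # For CLIPTextModelWithProjection, index 0 of the output is the pooled (projected) embedding,
+ # while the penultimate hidden state below is used as the per-token prompt embedding.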
955
+ pooled_prompt_embeds = prompt_embeds[0]
956
+ prompt_embeds = prompt_embeds.hidden_states[-2]
957
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
958
+
959
+ _, seq_len, _ = prompt_embeds.shape
960
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
961
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
962
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
963
+
964
+ return prompt_embeds, pooled_prompt_embeds
965
+
966
+
967
+ def encode_prompt(
968
+ text_encoders,
969
+ tokenizers,
970
+ prompt: str,
971
+ max_sequence_length,
972
+ device=None,
973
+ num_images_per_prompt: int = 1,
974
+ text_input_ids_list=None,
975
+ ):
976
+ prompt = [prompt] if isinstance(prompt, str) else prompt
977
+
978
+ clip_tokenizers = tokenizers[:2]
979
+ clip_text_encoders = text_encoders[:2]
980
+
981
+ clip_prompt_embeds_list = []
982
+ clip_pooled_prompt_embeds_list = []
983
+ for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
984
+ prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
985
+ text_encoder=text_encoder,
986
+ tokenizer=tokenizer,
987
+ prompt=prompt,
988
+ device=device if device is not None else text_encoder.device,
989
+ num_images_per_prompt=num_images_per_prompt,
990
+ text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
991
+ )
992
+ clip_prompt_embeds_list.append(prompt_embeds)
993
+ clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
994
+
995
+ clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
996
+ pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
997
+
998
+ t5_prompt_embed = _encode_prompt_with_t5(
999
+ text_encoders[-1],
1000
+ tokenizers[-1],
1001
+ max_sequence_length,
1002
+ prompt=prompt,
1003
+ num_images_per_prompt=num_images_per_prompt,
1004
+ text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
1005
+ device=device if device is not None else text_encoders[-1].device,
1006
+ )
1007
+
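+ # The concatenated CLIP embeddings (768 + 1280 = 2048 channels for the standard SD3 checkpoints) are
+ # zero-padded up to the T5 hidden size (4096) and then concatenated with the T5 embeddings along the
+ # sequence dimension, so the transformer receives a single joint context sequence.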
1008
+ clip_prompt_embeds = torch.nn.functional.pad(
1009
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
1010
+ )
1011
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
1012
+
1013
+ return prompt_embeds, pooled_prompt_embeds
1014
+
1015
+
1016
+ def main(args):
1017
+ if args.report_to == "wandb" and args.hub_token is not None:
1018
+ raise ValueError(
1019
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
1020
+ " Please use `huggingface-cli login` to authenticate with the Hub."
1021
+ )
1022
+
1023
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
1024
+ # due to pytorch#99272, MPS does not yet support bfloat16.
1025
+ raise ValueError(
1026
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
1027
+ )
1028
+
1029
+ logging_dir = Path(args.output_dir, args.logging_dir)
1030
+
1031
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
1032
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
1033
+ accelerator = Accelerator(
1034
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
1035
+ mixed_precision=args.mixed_precision,
1036
+ log_with=args.report_to,
1037
+ project_config=accelerator_project_config,
1038
+ kwargs_handlers=[kwargs],
1039
+ )
1040
+
1041
+ # Disable AMP for MPS.
1042
+ if torch.backends.mps.is_available():
1043
+ accelerator.native_amp = False
1044
+
1045
+ if args.report_to == "wandb":
1046
+ if not is_wandb_available():
1047
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
1048
+
1049
+ # Make one log on every process with the configuration for debugging.
1050
+ logging.basicConfig(
1051
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
1052
+ datefmt="%m/%d/%Y %H:%M:%S",
1053
+ level=logging.INFO,
1054
+ )
1055
+ logger.info(accelerator.state, main_process_only=False)
1056
+ if accelerator.is_local_main_process:
1057
+ transformers.utils.logging.set_verbosity_warning()
1058
+ diffusers.utils.logging.set_verbosity_info()
1059
+ else:
1060
+ transformers.utils.logging.set_verbosity_error()
1061
+ diffusers.utils.logging.set_verbosity_error()
1062
+
1063
+ # If passed along, set the training seed now.
1064
+ if args.seed is not None:
1065
+ set_seed(args.seed)
1066
+
1067
+ # Generate class images if prior preservation is enabled.
1068
+ if args.with_prior_preservation:
1069
+ class_images_dir = Path(args.class_data_dir)
1070
+ if not class_images_dir.exists():
1071
+ class_images_dir.mkdir(parents=True)
1072
+ cur_class_images = len(list(class_images_dir.iterdir()))
1073
+
1074
+ if cur_class_images < args.num_class_images:
1075
+ has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
1076
+ torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
1077
+ if args.prior_generation_precision == "fp32":
1078
+ torch_dtype = torch.float32
1079
+ elif args.prior_generation_precision == "fp16":
1080
+ torch_dtype = torch.float16
1081
+ elif args.prior_generation_precision == "bf16":
1082
+ torch_dtype = torch.bfloat16
1083
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
1084
+ args.pretrained_model_name_or_path,
1085
+ torch_dtype=torch_dtype,
1086
+ revision=args.revision,
1087
+ variant=args.variant,
1088
+ )
1089
+ pipeline.set_progress_bar_config(disable=True)
1090
+
1091
+ num_new_images = args.num_class_images - cur_class_images
1092
+ logger.info(f"Number of class images to sample: {num_new_images}.")
1093
+
1094
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
1095
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
1096
+
1097
+ sample_dataloader = accelerator.prepare(sample_dataloader)
1098
+ pipeline.to(accelerator.device)
1099
+
1100
+ for example in tqdm(
1101
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
1102
+ ):
1103
+ images = pipeline(example["prompt"]).images
1104
+
1105
+ for i, image in enumerate(images):
1106
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
1107
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
1108
+ image.save(image_filename)
1109
+
1110
+ del pipeline
1111
+ if torch.cuda.is_available():
1112
+ torch.cuda.empty_cache()
1113
+
1114
+ # Handle the repository creation
1115
+ if accelerator.is_main_process:
1116
+ if args.output_dir is not None:
1117
+ os.makedirs(args.output_dir, exist_ok=True)
1118
+
1119
+ if args.push_to_hub:
1120
+ repo_id = create_repo(
1121
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
1122
+ exist_ok=True,
1123
+ ).repo_id
1124
+
1125
+ # Load the tokenizers
1126
+ tokenizer_one = CLIPTokenizer.from_pretrained(
1127
+ args.pretrained_model_name_or_path,
1128
+ subfolder="tokenizer",
1129
+ revision=args.revision,
1130
+ )
1131
+ tokenizer_two = CLIPTokenizer.from_pretrained(
1132
+ args.pretrained_model_name_or_path,
1133
+ subfolder="tokenizer_2",
1134
+ revision=args.revision,
1135
+ )
1136
+ tokenizer_three = T5TokenizerFast.from_pretrained(
1137
+ args.pretrained_model_name_or_path,
1138
+ subfolder="tokenizer_3",
1139
+ revision=args.revision,
1140
+ )
1141
+
1142
+ # import correct text encoder classes
1143
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
1144
+ args.pretrained_model_name_or_path, args.revision
1145
+ )
1146
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
1147
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
1148
+ )
1149
+ text_encoder_cls_three = import_model_class_from_model_name_or_path(
1150
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
1151
+ )
1152
+
1153
+ # Load scheduler and models
1154
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
1155
+ args.pretrained_model_name_or_path, subfolder="scheduler"
1156
+ )
1157
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
1158
+ text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
1159
+ text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
1160
+ )
1161
+ vae = AutoencoderKL.from_pretrained(
1162
+ args.pretrained_model_name_or_path,
1163
+ subfolder="vae",
1164
+ revision=args.revision,
1165
+ variant=args.variant,
1166
+ )
1167
+ transformer = SD3Transformer2DModel.from_pretrained(
1168
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
1169
+ )
1170
+
1171
+ transformer.requires_grad_(False)
1172
+ vae.requires_grad_(False)
1173
+ text_encoder_one.requires_grad_(False)
1174
+ text_encoder_two.requires_grad_(False)
1175
+ text_encoder_three.requires_grad_(False)
1176
+
1177
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
1178
+ # as these weights are only used for inference, keeping weights in full precision is not required.
1179
+ weight_dtype = torch.float32
1180
+ if accelerator.mixed_precision == "fp16":
1181
+ weight_dtype = torch.float16
1182
+ elif accelerator.mixed_precision == "bf16":
1183
+ weight_dtype = torch.bfloat16
1184
+
1185
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
1186
+ # due to pytorch#99272, MPS does not yet support bfloat16.
1187
+ raise ValueError(
1188
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
1189
+ )
1190
+
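+ # The VAE is deliberately kept in float32 (rather than weight_dtype): encoding latents in half
+ # precision can be numerically unstable, while the other frozen models are safe to downcast.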
1191
+ vae.to(accelerator.device, dtype=torch.float32)
1192
+ transformer.to(accelerator.device, dtype=weight_dtype)
1193
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
1194
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
1195
+ text_encoder_three.to(accelerator.device, dtype=weight_dtype)
1196
+
1197
+ if args.gradient_checkpointing:
1198
+ transformer.enable_gradient_checkpointing()
1199
+ if args.train_text_encoder:
1200
+ text_encoder_one.gradient_checkpointing_enable()
1201
+ text_encoder_two.gradient_checkpointing_enable()
1202
+
1203
+ # now we will add new LoRA weights to the attention layers
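+ # With lora_alpha equal to the rank, the effective LoRA scaling factor (alpha / r) stays at 1.0.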
1204
+ transformer_lora_config = LoraConfig(
1205
+ r=args.rank,
1206
+ lora_alpha=args.rank,
1207
+ init_lora_weights="gaussian",
1208
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
1209
+ )
1210
+ transformer.add_adapter(transformer_lora_config)
1211
+
1212
+ if args.train_text_encoder:
1213
+ text_lora_config = LoraConfig(
1214
+ r=args.rank,
1215
+ lora_alpha=args.rank,
1216
+ init_lora_weights="gaussian",
1217
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
1218
+ )
1219
+ text_encoder_one.add_adapter(text_lora_config)
1220
+ text_encoder_two.add_adapter(text_lora_config)
1221
+
1222
+ def unwrap_model(model):
1223
+ model = accelerator.unwrap_model(model)
1224
+ model = model._orig_mod if is_compiled_module(model) else model
1225
+ return model
1226
+
1227
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
1228
+ def save_model_hook(models, weights, output_dir):
1229
+ if accelerator.is_main_process:
1230
+ transformer_lora_layers_to_save = None
1231
+ text_encoder_one_lora_layers_to_save = None
1232
+ text_encoder_two_lora_layers_to_save = None
1233
+
1234
+ for model in models:
1235
+ if isinstance(model, type(unwrap_model(transformer))):
1236
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
1237
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
1238
+ text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
1239
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
1240
+ text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
1241
+ else:
1242
+ raise ValueError(f"unexpected save model: {model.__class__}")
1243
+
1244
+ # make sure to pop weight so that corresponding model is not saved again
1245
+ weights.pop()
1246
+
1247
+ StableDiffusion3Pipeline.save_lora_weights(
1248
+ output_dir,
1249
+ transformer_lora_layers=transformer_lora_layers_to_save,
1250
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
1251
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
1252
+ )
1253
+
1254
+ def load_model_hook(models, input_dir):
1255
+ transformer_ = None
1256
+ text_encoder_one_ = None
1257
+ text_encoder_two_ = None
1258
+
1259
+ while len(models) > 0:
1260
+ model = models.pop()
1261
+
1262
+ if isinstance(model, type(unwrap_model(transformer))):
1263
+ transformer_ = model
1264
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
1265
+ text_encoder_one_ = model
1266
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
1267
+ text_encoder_two_ = model
1268
+ else:
1269
+ raise ValueError(f"unexpected model when loading: {model.__class__}")
1270
+
1271
+ lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
1272
+
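+ # save_lora_weights() stores the transformer LoRA weights under a "transformer." prefix; strip that
+ # prefix and convert the keys to the PEFT naming scheme before loading them back into the adapter.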
1273
+ transformer_state_dict = {
1274
+ k.replace("transformer.", ""): v for k, v in lora_state_dict.items() if k.startswith("transformer.")
1275
+ }
1276
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
1277
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
1278
+ if incompatible_keys is not None:
1279
+ # check only for unexpected keys
1280
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
1281
+ if unexpected_keys:
1282
+ logger.warning(
1283
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
1284
+ f" {unexpected_keys}. "
1285
+ )
1286
+ if args.train_text_encoder:
1287
+ # Do we need to call `scale_lora_layers()` here?
1288
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
1289
+
1290
+ _set_state_dict_into_text_encoder(
1291
+ lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
1292
+ )
1293
+
1294
+ # Make sure the trainable params are in float32. This is again needed since the base models
1295
+ # are in `weight_dtype`. More details:
1296
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
1297
+ if args.mixed_precision == "fp16":
1298
+ models = [transformer_]
1299
+ if args.train_text_encoder:
1300
+ models.extend([text_encoder_one_, text_encoder_two_])
1301
+ # only upcast trainable parameters (LoRA) into fp32
1302
+ cast_training_params(models)
1303
+
1304
+ accelerator.register_save_state_pre_hook(save_model_hook)
1305
+ accelerator.register_load_state_pre_hook(load_model_hook)
1306
+
1307
+ # Enable TF32 for faster training on Ampere GPUs,
1308
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1309
+ if args.allow_tf32 and torch.cuda.is_available():
1310
+ torch.backends.cuda.matmul.allow_tf32 = True
1311
+
1312
+ if args.scale_lr:
1313
+ args.learning_rate = (
1314
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1315
+ )
1316
+
1317
+ # Make sure the trainable params are in float32.
1318
+ if args.mixed_precision == "fp16":
1319
+ models = [transformer]
1320
+ if args.train_text_encoder:
1321
+ models.extend([text_encoder_one, text_encoder_two])
1322
+ # only upcast trainable parameters (LoRA) into fp32
1323
+ cast_training_params(models, dtype=torch.float32)
1324
+
1325
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
1326
+ if args.train_text_encoder:
1327
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
1328
+ text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
1329
+
1330
+ # Optimization parameters
1331
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
1332
+ if args.train_text_encoder:
1333
+ # different learning rate for text encoder and unet
1334
+ text_lora_parameters_one_with_lr = {
1335
+ "params": text_lora_parameters_one,
1336
+ "weight_decay": args.adam_weight_decay_text_encoder,
1337
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
1338
+ }
1339
+ text_lora_parameters_two_with_lr = {
1340
+ "params": text_lora_parameters_two,
1341
+ "weight_decay": args.adam_weight_decay_text_encoder,
1342
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
1343
+ }
1344
+ params_to_optimize = [
1345
+ transformer_parameters_with_lr,
1346
+ text_lora_parameters_one_with_lr,
1347
+ text_lora_parameters_two_with_lr,
1348
+ ]
1349
+ else:
1350
+ params_to_optimize = [transformer_parameters_with_lr]
1351
+
1352
+ # Optimizer creation
1353
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
1354
+ logger.warning(
1355
+ f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [AdamW, Prodigy]. "
1356
+ "Defaulting to AdamW."
1357
+ )
1358
+ args.optimizer = "adamw"
1359
+
1360
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
1361
+ logger.warning(
1362
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
1363
+ f"set to {args.optimizer.lower()}"
1364
+ )
1365
+
1366
+ if args.optimizer.lower() == "adamw":
1367
+ if args.use_8bit_adam:
1368
+ try:
1369
+ import bitsandbytes as bnb
1370
+ except ImportError:
1371
+ raise ImportError(
1372
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1373
+ )
1374
+
1375
+ optimizer_class = bnb.optim.AdamW8bit
1376
+ else:
1377
+ optimizer_class = torch.optim.AdamW
1378
+
1379
+ optimizer = optimizer_class(
1380
+ params_to_optimize,
1381
+ betas=(args.adam_beta1, args.adam_beta2),
1382
+ weight_decay=args.adam_weight_decay,
1383
+ eps=args.adam_epsilon,
1384
+ )
1385
+
1386
+ if args.optimizer.lower() == "prodigy":
1387
+ try:
1388
+ import prodigyopt
1389
+ except ImportError:
1390
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
1391
+
1392
+ optimizer_class = prodigyopt.Prodigy
1393
+
1394
+ if args.learning_rate <= 0.1:
1395
+ logger.warning(
1396
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
1397
+ )
1398
+
1399
+ optimizer = optimizer_class(
1400
+ params_to_optimize,
1401
+ lr=args.learning_rate,
1402
+ betas=(args.adam_beta1, args.adam_beta2),
1403
+ beta3=args.prodigy_beta3,
1404
+ weight_decay=args.adam_weight_decay,
1405
+ eps=args.adam_epsilon,
1406
+ decouple=args.prodigy_decouple,
1407
+ use_bias_correction=args.prodigy_use_bias_correction,
1408
+ safeguard_warmup=args.prodigy_safeguard_warmup,
1409
+ )
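# Illustrative sketch, not part of the uploaded script: Prodigy estimates its own step size,
# which is why the warning above recommends a learning rate around 1.0. The parameters and
# keyword values below are dummies that only loosely mirror the call above.
import torch
import prodigyopt  # requires `pip install prodigyopt`

demo_params = [torch.nn.Parameter(torch.zeros(4))]
demo_optimizer = prodigyopt.Prodigy(demo_params, lr=1.0, weight_decay=1e-2, safeguard_warmup=True)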
1410
+
1411
+ # Dataset and DataLoaders creation:
1412
+ train_dataset = DreamBoothDataset(
1413
+ instance_data_root=args.instance_data_dir,
1414
+ instance_prompt=args.instance_prompt,
1415
+ class_prompt=args.class_prompt,
1416
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
1417
+ class_num=args.num_class_images,
1418
+ size=args.resolution,
1419
+ repeats=args.repeats,
1420
+ center_crop=args.center_crop,
1421
+ )
1422
+
1423
+ train_dataloader = torch.utils.data.DataLoader(
1424
+ train_dataset,
1425
+ batch_size=args.train_batch_size,
1426
+ shuffle=True,
1427
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1428
+ num_workers=args.dataloader_num_workers,
1429
+ )
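# Illustrative sketch, not part of the uploaded script: the lambda above simply freezes the
# second argument of collate_fn. functools.partial expresses the same intent and pickles more
# reliably when num_workers > 0 on platforms that spawn worker processes; the keyword name
# below assumes a signature like `def collate_fn(examples, with_prior_preservation=False)`.
import functools

def _demo_collate(examples, with_prior_preservation=False):
    return examples  # stand-in for the real collate_fn defined earlier in the script

wrapped = functools.partial(_demo_collate, with_prior_preservation=True)
assert wrapped([{"pixel_values": None}]) == [{"pixel_values": None}]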
1430
+
1431
+ if not args.train_text_encoder:
1432
+ tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
1433
+ text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
1434
+
1435
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
1436
+ with torch.no_grad():
1437
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1438
+ text_encoders, tokenizers, prompt, args.max_sequence_length
1439
+ )
1440
+ prompt_embeds = prompt_embeds.to(accelerator.device)
1441
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
1442
+ return prompt_embeds, pooled_prompt_embeds
1443
+
1444
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
1445
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
1446
+ args.instance_prompt, text_encoders, tokenizers
1447
+ )
1448
+
1449
+ # Handle class prompt for prior-preservation.
1450
+ if args.with_prior_preservation:
1451
+ if not args.train_text_encoder:
1452
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
1453
+ args.class_prompt, text_encoders, tokenizers
1454
+ )
1455
+
1456
+ # Memory clearing below is commented out (#MJB) since the text_encoders and tokenizers are still needed later, see line 1621
1457
+ '''
1458
+ if not args.train_text_encoder and train_dataset.custom_instance_prompts:
1459
+ del tokenizers, text_encoders
1460
+ # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
1461
+ del text_encoder_one, text_encoder_two, text_encoder_three
1462
+ gc.collect()
1463
+ if torch.cuda.is_available():
1464
+ torch.cuda.empty_cache()
1465
+ '''
1466
+
1467
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
1468
+ # pack the statically computed variables appropriately here. This is so that we don't
1469
+ # have to pass them to the dataloader.
1470
+
1471
+ if not train_dataset.custom_instance_prompts:
1472
+ if not args.train_text_encoder:
1473
+ prompt_embeds = instance_prompt_hidden_states
1474
+ pooled_prompt_embeds = instance_pooled_prompt_embeds
1475
+ if args.with_prior_preservation:
1476
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
1477
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
1478
+ # if we're optimizing the text encoder (whether the instance prompt is used for all images or custom prompts are provided) we need to tokenize and encode the
1479
+ # batch prompts on all training steps
1480
+ else:
1481
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
1482
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
1483
+ tokens_three = tokenize_prompt(tokenizer_three, args.instance_prompt)
1484
+ if args.with_prior_preservation:
1485
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
1486
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
1487
+ class_tokens_three = tokenize_prompt(tokenizer_three, args.class_prompt)
1488
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
1489
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
1490
+ tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)
1491
+
1492
+ # Scheduler and math around the number of training steps.
1493
+ overrode_max_train_steps = False
1494
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1495
+ if args.max_train_steps is None:
1496
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1497
+ overrode_max_train_steps = True
1498
+
1499
+ lr_scheduler = get_scheduler(
1500
+ args.lr_scheduler,
1501
+ optimizer=optimizer,
1502
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1503
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
1504
+ num_cycles=args.lr_num_cycles,
1505
+ power=args.lr_power,
1506
+ )
1507
+
1508
+ # Prepare everything with our `accelerator`.
1510
+ if args.train_text_encoder:
1511
+ (
1512
+ transformer,
1513
+ text_encoder_one,
1514
+ text_encoder_two,
1515
+ optimizer,
1516
+ train_dataloader,
1517
+ lr_scheduler,
1518
+ ) = accelerator.prepare(
1519
+ transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
1520
+ )
1521
+ assert text_encoder_one is not None
1522
+ assert text_encoder_two is not None
1523
+ assert text_encoder_three is not None
1524
+ else:
1525
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1526
+ transformer, optimizer, train_dataloader, lr_scheduler
1527
+ )
1528
+
1529
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1530
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1531
+ if overrode_max_train_steps:
1532
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1533
+ # Afterwards we recalculate our number of training epochs
1534
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
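# Illustrative sketch, not part of the uploaded script: the step bookkeeping above with
# hypothetical numbers. 100 batches per epoch and 4 accumulation steps give 25 optimizer
# updates per epoch; a 500-step budget then re-derives to 20 epochs.
import math

batches_per_epoch = 100            # len(train_dataloader) after accelerator.prepare
gradient_accumulation_steps = 4
max_train_steps = 500
updates_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)  # 25
num_train_epochs = math.ceil(max_train_steps / updates_per_epoch)               # 20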
1535
+
1536
+ # We need to initialize the trackers we use, and also store our configuration.
1537
+ # The trackers are initialized automatically on the main process.
1538
+ if accelerator.is_main_process:
1539
+ tracker_name = "dreambooth-sd3-lora"
1540
+ accelerator.init_trackers(tracker_name, config=vars(args))
1541
+
1542
+ # Train!
1543
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1544
+
1545
+ logger.info("***** Running training *****")
1546
+ logger.info(f" Num examples = {len(train_dataset)}")
1547
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1548
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1549
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1550
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1551
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1552
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1553
+ global_step = 0
1554
+ first_epoch = 0
1555
+
1556
+ # Potentially load in the weights and states from a previous save
1557
+ if args.resume_from_checkpoint:
1558
+ if args.resume_from_checkpoint != "latest":
1559
+ path = os.path.basename(args.resume_from_checkpoint)
1560
+ else:
1561
+ # Get the most recent checkpoint
1562
+ dirs = os.listdir(args.output_dir)
1563
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1564
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1565
+ path = dirs[-1] if len(dirs) > 0 else None
1566
+
1567
+ if path is None:
1568
+ accelerator.print(
1569
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1570
+ )
1571
+ args.resume_from_checkpoint = None
1572
+ initial_global_step = 0
1573
+ else:
1574
+ accelerator.print(f"Resuming from checkpoint {path}")
1575
+ accelerator.load_state(os.path.join(args.output_dir, path))
1576
+ global_step = int(path.split("-")[1])
1577
+
1578
+ initial_global_step = global_step
1579
+ first_epoch = global_step // num_update_steps_per_epoch
1580
+
1581
+ else:
1582
+ initial_global_step = 0
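# Illustrative sketch, not part of the uploaded script: resolving "latest" sorts checkpoint
# directories by their numeric suffix, so lexicographic pitfalls like "checkpoint-1000"
# sorting before "checkpoint-500" are avoided.
demo_dirs = ["checkpoint-1000", "checkpoint-500", "checkpoint-1500"]
latest = sorted(demo_dirs, key=lambda x: int(x.split("-")[1]))[-1]
assert latest == "checkpoint-1500"
assert int(latest.split("-")[1]) == 1500  # becomes the resumed global_step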
1583
+
1584
+ progress_bar = tqdm(
1585
+ range(0, args.max_train_steps),
1586
+ initial=initial_global_step,
1587
+ desc="Steps",
1588
+ # Only show the progress bar once on each machine.
1589
+ disable=not accelerator.is_local_main_process,
1590
+ )
1591
+
1592
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1593
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
1594
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
1595
+ timesteps = timesteps.to(accelerator.device)
1596
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
1597
+
1598
+ sigma = sigmas[step_indices].flatten()
1599
+ while len(sigma.shape) < n_dim:
1600
+ sigma = sigma.unsqueeze(-1)
1601
+ return sigma
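# Illustrative sketch, not part of the uploaded script: the unsqueeze loop above reshapes the
# per-sample sigmas from (batch,) to (batch, 1, 1, 1) so they broadcast against 4-D latents.
# Shapes below are hypothetical.
import torch

demo_sigma = torch.tensor([0.25, 0.75])       # one sigma per sample
while len(demo_sigma.shape) < 4:
    demo_sigma = demo_sigma.unsqueeze(-1)
print(demo_sigma.shape)                        # torch.Size([2, 1, 1, 1])
demo_latents = torch.randn(2, 16, 64, 64)
demo_noise = torch.randn_like(demo_latents)
demo_noisy = (1.0 - demo_sigma) * demo_latents + demo_sigma * demo_noise  # broadcasts cleanly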
1602
+
1603
+ for epoch in range(first_epoch, args.num_train_epochs):
1604
+ transformer.train()
1605
+ if args.train_text_encoder:
1606
+ text_encoder_one.train()
1607
+ text_encoder_two.train()
1608
+
1609
+ # set requires_grad = True on the top-level embedding parameters so that gradient checkpointing works
1610
+ accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
1611
+ accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
1612
+
1613
+ for step, batch in enumerate(train_dataloader):
1614
+ models_to_accumulate = [transformer]
1615
+ with accelerator.accumulate(models_to_accumulate):
1616
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1617
+ prompts = batch["prompts"]
1618
+
1619
+ # encode batch prompts when custom prompts are provided for each image
1620
+ if train_dataset.custom_instance_prompts:
1621
+ if not args.train_text_encoder:
1622
+ prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
1623
+ prompts, text_encoders, tokenizers
1624
+ )
1625
+ else:
1626
+ tokens_one = tokenize_prompt(tokenizer_one, prompts)
1627
+ tokens_two = tokenize_prompt(tokenizer_two, prompts)
1628
+ tokens_three = tokenize_prompt(tokenizer_three, prompts)
1629
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1630
+ text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
1631
+ tokenizers=[None, None, None],
1632
+ prompt=prompts,
1633
+ max_sequence_length=args.max_sequence_length,
1634
+ text_input_ids_list=[tokens_one, tokens_two, tokens_three],
1635
+ )
1636
+ else:
1637
+ if args.train_text_encoder:
1638
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1639
+ text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
1640
+ tokenizers=[None, None, tokenizer_three],
1641
+ prompt=args.instance_prompt,
1642
+ max_sequence_length=args.max_sequence_length,
1643
+ text_input_ids_list=[tokens_one, tokens_two, tokens_three],
1644
+ )
1645
+
1646
+ # Convert images to latent space
1647
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1648
+ model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
1649
+ model_input = model_input.to(dtype=weight_dtype)
1650
+
1651
+ # Sample noise that we'll add to the latents
1652
+ noise = torch.randn_like(model_input)
1653
+ bsz = model_input.shape[0]
1654
+
1655
+ # Sample a random timestep for each image
1656
+ # for weighting schemes where we sample timesteps non-uniformly
1657
+ u = compute_density_for_timestep_sampling(
1658
+ weighting_scheme=args.weighting_scheme,
1659
+ batch_size=bsz,
1660
+ logit_mean=args.logit_mean,
1661
+ logit_std=args.logit_std,
1662
+ mode_scale=args.mode_scale,
1663
+ )
1664
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
1665
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
1666
+
1667
+ # Add noise according to flow matching.
1668
+ # zt = (1 - texp) * x + texp * z1
1669
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
1670
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
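# Illustrative sketch, not part of the uploaded script: the line above is the flow-matching
# interpolation z_t = (1 - sigma) * x + sigma * noise, so sigma = 0 keeps the clean latent
# and sigma = 1 yields pure noise. Tensors below are tiny stand-ins.
import torch

demo_x = torch.full((1, 1, 2, 2), 2.0)   # stand-in clean latent
demo_z1 = torch.zeros_like(demo_x)       # stand-in noise
assert torch.equal((1.0 - 0.0) * demo_x + 0.0 * demo_z1, demo_x)    # sigma = 0 -> clean latent
assert torch.equal((1.0 - 1.0) * demo_x + 1.0 * demo_z1, demo_z1)   # sigma = 1 -> pure noise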
1671
+
1672
+ # Predict the noise residual
1673
+ model_pred = transformer(
1674
+ hidden_states=noisy_model_input,
1675
+ timestep=timesteps,
1676
+ encoder_hidden_states=prompt_embeds,
1677
+ pooled_projections=pooled_prompt_embeds,
1678
+ return_dict=False,
1679
+ )[0]
1680
+
1681
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
1682
+ # Preconditioning of the model outputs.
1683
+ if args.precondition_outputs:
1684
+ model_pred = model_pred * (-sigmas) + noisy_model_input
1685
+
1686
+ # these weighting schemes use a uniform timestep sampling
1687
+ # and instead post-weight the loss
1688
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
1689
+
1690
+ # flow matching loss
1691
+ if args.precondition_outputs:
1692
+ target = model_input
1693
+ else:
1694
+ target = noise - model_input
1695
+
1696
+ if args.with_prior_preservation:
1697
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1698
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1699
+ target, target_prior = torch.chunk(target, 2, dim=0)
1700
+
1701
+ # Compute prior loss
1702
+ prior_loss = torch.mean(
1703
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
1704
+ target_prior.shape[0], -1
1705
+ ),
1706
+ 1,
1707
+ )
1708
+ prior_loss = prior_loss.mean()
1709
+
1710
+ # Compute regular loss.
1711
+ loss = torch.mean(
1712
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1713
+ 1,
1714
+ )
1715
+ loss = loss.mean()
1716
+
1717
+ if args.with_prior_preservation:
1718
+ # Add the prior loss to the instance loss.
1719
+ loss = loss + args.prior_loss_weight * prior_loss
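# Illustrative sketch, not part of the uploaded script: with prior preservation the batch is
# the concatenation [instance examples; class examples], so chunking predictions and targets
# in half along dim 0 recovers the two terms that are combined above. Tensors are dummies.
import torch

instance_pred, class_pred = torch.randn(2, 4), torch.randn(2, 4)
demo_model_pred = torch.cat([instance_pred, class_pred], dim=0)        # mirrors the collated batch
recovered_instance, recovered_class = torch.chunk(demo_model_pred, 2, dim=0)
assert torch.equal(recovered_instance, instance_pred)
assert torch.equal(recovered_class, class_pred)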
1720
+
1721
+ accelerator.backward(loss)
1722
+ if accelerator.sync_gradients:
1723
+ params_to_clip = (
1724
+ itertools.chain(
1725
+ transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two
1726
+ )
1727
+ if args.train_text_encoder
1728
+ else transformer_lora_parameters
1729
+ )
1730
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1731
+
1732
+ optimizer.step()
1733
+ lr_scheduler.step()
1734
+ optimizer.zero_grad()
1735
+
1736
+ # Checks if the accelerator has performed an optimization step behind the scenes
1737
+ if accelerator.sync_gradients:
1738
+ progress_bar.update(1)
1739
+ global_step += 1
1740
+
1741
+ if accelerator.is_main_process:
1742
+ if global_step % args.checkpointing_steps == 0:
1743
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1744
+ if args.checkpoints_total_limit is not None:
1745
+ checkpoints = os.listdir(args.output_dir)
1746
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1747
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1748
+
1749
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1750
+ if len(checkpoints) >= args.checkpoints_total_limit:
1751
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1752
+ removing_checkpoints = checkpoints[0:num_to_remove]
1753
+
1754
+ logger.info(
1755
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1756
+ )
1757
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1758
+
1759
+ for removing_checkpoint in removing_checkpoints:
1760
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1761
+ shutil.rmtree(removing_checkpoint)
1762
+
1763
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1764
+ accelerator.save_state(save_path)
1765
+ logger.info(f"Saved state to {save_path}")
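# Illustrative sketch, not part of the uploaded script: the rotation above with hypothetical
# directory names. With checkpoints_total_limit = 3 and three checkpoints already on disk,
# exactly one must be removed so that at most three remain after the new save.
demo_checkpoints = ["checkpoint-100", "checkpoint-200", "checkpoint-300"]
demo_limit = 3
if len(demo_checkpoints) >= demo_limit:
    num_to_remove = len(demo_checkpoints) - demo_limit + 1   # 1
    removing = demo_checkpoints[:num_to_remove]              # ["checkpoint-100"]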
1766
+
1767
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1768
+ progress_bar.set_postfix(**logs)
1769
+ accelerator.log(logs, step=global_step)
1770
+
1771
+ if global_step >= args.max_train_steps:
1772
+ break
1773
+
1774
+ if accelerator.is_main_process:
1775
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1776
+ if not args.train_text_encoder:
1777
+ # create pipeline
1778
+ text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
1779
+ text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
1780
+ )
1781
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
1782
+ args.pretrained_model_name_or_path,
1783
+ vae=vae,
1784
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
1785
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
1786
+ text_encoder_3=accelerator.unwrap_model(text_encoder_three),
1787
+ transformer=accelerator.unwrap_model(transformer),
1788
+ revision=args.revision,
1789
+ variant=args.variant,
1790
+ torch_dtype=weight_dtype,
1791
+ )
1792
+ pipeline_args = {"prompt": args.validation_prompt}
1793
+ images = log_validation(
1794
+ pipeline=pipeline,
1795
+ args=args,
1796
+ accelerator=accelerator,
1797
+ pipeline_args=pipeline_args,
1798
+ epoch=epoch,
1799
+ )
1800
+ if not args.train_text_encoder:
1801
+ del text_encoder_one, text_encoder_two, text_encoder_three
1802
+
1803
+ torch.cuda.empty_cache()
1804
+ gc.collect()
1805
+
1806
+ # Save the lora layers
1807
+ accelerator.wait_for_everyone()
1808
+ if accelerator.is_main_process:
1809
+ transformer = unwrap_model(transformer)
1810
+ transformer = transformer.to(torch.float32)
1811
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
1812
+
1813
+ if args.train_text_encoder:
1814
+ text_encoder_one = unwrap_model(text_encoder_one)
1815
+ text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
1816
+ text_encoder_two = unwrap_model(text_encoder_two)
1817
+ text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
1818
+ else:
1819
+ text_encoder_lora_layers = None
1820
+ text_encoder_2_lora_layers = None
1821
+
1822
+ StableDiffusion3Pipeline.save_lora_weights(
1823
+ save_directory=args.output_dir,
1824
+ transformer_lora_layers=transformer_lora_layers,
1825
+ text_encoder_lora_layers=text_encoder_lora_layers,
1826
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
1827
+ )
1828
+
1829
+ # Final inference
1830
+ # Load previous pipeline
1831
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
1832
+ args.pretrained_model_name_or_path,
1833
+ revision=args.revision,
1834
+ variant=args.variant,
1835
+ torch_dtype=weight_dtype,
1836
+ )
1837
+ # load attention processors
1838
+ pipeline.load_lora_weights(args.output_dir)
1839
+
1840
+ # run inference
1841
+ images = []
1842
+ if args.validation_prompt and args.num_validation_images > 0:
1843
+ pipeline_args = {"prompt": args.validation_prompt}
1844
+ images = log_validation(
1845
+ pipeline=pipeline,
1846
+ args=args,
1847
+ accelerator=accelerator,
1848
+ pipeline_args=pipeline_args,
1849
+ epoch=epoch,
1850
+ is_final_validation=True,
1851
+ )
1852
+
1853
+ if args.push_to_hub:
1854
+ save_model_card(
1855
+ repo_id,
1856
+ images=images,
1857
+ base_model=args.pretrained_model_name_or_path,
1858
+ instance_prompt=args.instance_prompt,
1859
+ validation_prompt=args.validation_prompt,
1860
+ train_text_encoder=args.train_text_encoder,
1861
+ repo_folder=args.output_dir,
1862
+ )
1863
+ upload_folder(
1864
+ repo_id=repo_id,
1865
+ folder_path=args.output_dir,
1866
+ commit_message="End of training",
1867
+ ignore_patterns=["step_*", "epoch_*"],
1868
+ )
1869
+
1870
+ accelerator.end_training()
1871
+
1872
+
1873
+ if __name__ == "__main__":
1874
+ args = parse_args()
1875
+ main(args)