ysharma (HF staff) committed on
Commit 2310d22
1 Parent(s): dc9b150

updated file

Files changed (1)
  1. train_lora_dreambooth.py +956 -0
train_lora_dreambooth.py ADDED
@@ -0,0 +1,956 @@
# Bootstrapped from:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py

import argparse
import hashlib
import itertools
import math
import os
import inspect
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint


from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from lora_diffusion import (
    inject_trainable_lora,
    save_lora_weight,
    extract_lora_ups_down,
)

from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

import random
import re


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
        color_jitter=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                ),
                transforms.CenterCrop(size)
                if center_crop
                else transforms.RandomCrop(size),
                transforms.ColorJitter(0.2, 0.1)
                if color_jitter
                else transforms.Lambda(lambda x: x),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(
            self.instance_images_path[index % self.num_instance_images]
        )
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(
                self.class_images_path[index % self.num_class_images]
            )
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


logger = get_logger(__name__)


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained vae or vae identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=True,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss.",
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images, additional images"
            " will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution",
    )
    parser.add_argument(
        "--color_jitter",
        action="store_true",
        help="Whether to apply color jitter to images",
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Whether to train the text encoder",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for sampling images.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save a checkpoint every X update steps.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=4,
        help="Rank of LoRA approximation.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=None,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--learning_rate_text",
        type=float,
        default=5e-6,
        help="Initial learning rate for the text encoder (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or"
            " the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate"
            " config."
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--resume_unet",
        type=str,
        default=None,
        help="File path for unet lora to resume training.",
    )
    parser.add_argument(
        "--resume_text_encoder",
        type=str,
        default=None,
        help="File path for text encoder lora to resume training.",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify a prompt for class images.")
    else:
        if args.class_data_dir is not None:
            logger.warning(
                "You need not use --class_data_dir without --with_prior_preservation."
            )
        if args.class_prompt is not None:
            logger.warning(
                "You need not use --class_prompt without --with_prior_preservation."
            )

    return args


def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if (
        args.train_text_encoder
        and args.gradient_accumulation_steps > 1
        and accelerator.num_processes > 1
    ):
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        set_seed(args.seed)

    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            torch_dtype = (
                torch.float16 if accelerator.device.type == "cuda" else torch.float32
            )
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(
                sample_dataset, batch_size=args.sample_batch_size
            )

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)

            for example in tqdm(
                sample_dataloader,
                desc="Generating class images",
                disable=not accelerator.is_local_main_process,
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = (
                        class_images_dir
                        / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    )
                    image.save(image_filename)

            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Handle the repository creation
    if accelerator.is_main_process:

        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
        )
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
        )

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
        subfolder=None if args.pretrained_vae_name_or_path else "vae",
        revision=None if args.pretrained_vae_name_or_path else args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
    )
    unet.requires_grad_(False)
    unet_lora_params, _ = inject_trainable_lora(
        unet, r=args.lora_rank, loras=args.resume_unet
    )

    for _up, _down in extract_lora_ups_down(unet):
        print("Before training: Unet First Layer lora up", _up.weight.data)
        print("Before training: Unet First Layer lora down", _down.weight.data)
        break

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if args.train_text_encoder:
        text_encoder_lora_params, _ = inject_trainable_lora(
            text_encoder,
            target_replace_module=["CLIPAttention"],
            r=args.lora_rank,
        )
        for _up, _down in extract_lora_ups_down(
            text_encoder, target_replace_module=["CLIPAttention"]
        ):
            print("Before training: text encoder First Layer lora up", _up.weight.data)
            print(
                "Before training: text encoder First Layer lora down", _down.weight.data
            )
            break

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    text_lr = (
        args.learning_rate
        if args.learning_rate_text is None
        else args.learning_rate_text
    )

    params_to_optimize = (
        [
            {"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate},
            {
                "params": itertools.chain(*text_encoder_lora_params),
                "lr": text_lr,
            },
        ]
        if args.train_text_encoder
        else itertools.chain(*unet_lora_params)
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_config(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
        color_jitter=args.color_jitter,
    )

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=1,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    if args.train_text_encoder:
        (
            unet,
            text_encoder,
            optimizer,
            train_dataloader,
            lr_scheduler,
        ) = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to GPU.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config=vars(args))

    # Train!
    total_batch_size = (
        args.train_batch_size
        * accelerator.num_processes
        * args.gradient_accumulation_steps
    )

    print("***** Running training *****")
    print(f"  Num examples = {len(train_dataset)}")
    print(f"  Num batches each epoch = {len(train_dataloader)}")
    print(f"  Num Epochs = {args.num_train_epochs}")
    print(f"  Instantaneous batch size per device = {args.train_batch_size}")
    print(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    print(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )
    progress_bar.set_description("Steps")
    global_step = 0
    last_save = 0

    for epoch in range(args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            text_encoder.train()

        for step, batch in enumerate(train_dataloader):
            # Convert images to latent space
            latents = vae.encode(
                batch["pixel_values"].to(dtype=weight_dtype)
            ).latent_dist.sample()
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (bsz,),
                device=latents.device,
            )
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Predict the noise residual
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(
                    f"Unknown prediction type {noise_scheduler.config.prediction_type}"
                )

            if args.with_prior_preservation:
                # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)

                # Compute instance loss
                loss = (
                    F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    .mean([1, 2, 3])
                    .mean()
                )

                # Compute prior loss
                prior_loss = F.mse_loss(
                    model_pred_prior.float(), target_prior.float(), reduction="mean"
                )

                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                params_to_clip = (
                    itertools.chain(unet.parameters(), text_encoder.parameters())
                    if args.train_text_encoder
                    else unet.parameters()
                )
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            progress_bar.update(1)
            optimizer.zero_grad()

            global_step += 1

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.save_steps and global_step - last_save >= args.save_steps:
                    if accelerator.is_main_process:
                        # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
                        # it, the models will be unwrapped, and when they are then used for further training,
                        # we will crash. pass this, but only to newer versions of accelerate. fixes
                        # https://github.com/huggingface/diffusers/issues/1566
                        accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
                            inspect.signature(
                                accelerator.unwrap_model
                            ).parameters.keys()
                        )
                        extra_args = (
                            {"keep_fp32_wrapper": True}
                            if accepts_keep_fp32_wrapper
                            else {}
                        )
                        pipeline = StableDiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=accelerator.unwrap_model(unet, **extra_args),
                            text_encoder=accelerator.unwrap_model(
                                text_encoder, **extra_args
                            ),
                            revision=args.revision,
                        )

                        filename_unet = (
                            f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
                        )
                        filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
                        print(f"save weights {filename_unet}, {filename_text_encoder}")
                        save_lora_weight(pipeline.unet, filename_unet)
                        if args.train_text_encoder:
                            save_lora_weight(
                                pipeline.text_encoder,
                                filename_text_encoder,
                                target_replace_module=["CLIPAttention"],
                            )

                        for _up, _down in extract_lora_ups_down(pipeline.unet):
                            print(
                                "First Unet Layer's Up Weight is now : ",
                                _up.weight.data,
                            )
                            print(
                                "First Unet Layer's Down Weight is now : ",
                                _down.weight.data,
                            )
                            break
                        if args.train_text_encoder:
                            for _up, _down in extract_lora_ups_down(
                                pipeline.text_encoder,
                                target_replace_module=["CLIPAttention"],
                            ):
                                print(
                                    "First Text Encoder Layer's Up Weight is now : ",
                                    _up.weight.data,
                                )
                                print(
                                    "First Text Encoder Layer's Down Weight is now : ",
                                    _down.weight.data,
                                )
                                break

                        last_save = global_step

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
            revision=args.revision,
        )

        print("\n\nLora TRAINING DONE!\n\n")

        save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
        if args.train_text_encoder:
            save_lora_weight(
                pipeline.text_encoder,
                args.output_dir + "/lora_weight.text_encoder.pt",
                target_replace_module=["CLIPAttention"],
            )

        if args.push_to_hub:
            # NOTE: no `repo` object is created anywhere in this script, so pushing to the Hub
            # as written will fail unless a Repository is set up beforehand.
            repo.push_to_hub(
                commit_message="End of training", blocking=False, auto_lfs_prune=True
            )

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)
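
Usage note: a minimal sketch of how the script's own entry points could be driven from Python. The flag names all come from parse_args() above, but the model id, directories, prompt, and hyperparameter values are placeholders chosen for illustration, not values from this commit.

# Hypothetical example invocation; adjust paths, prompt, and hyperparameters to your setup.
example_args = parse_args(
    [
        "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
        "--instance_data_dir", "./instance_images",
        "--instance_prompt", "a photo of sks person",
        "--output_dir", "./lora_output",
        "--resolution", "512",
        "--train_batch_size", "1",
        "--learning_rate", "1e-4",
        "--lr_scheduler", "constant",
        "--lr_warmup_steps", "0",
        "--max_train_steps", "1000",
        "--lora_rank", "4",
    ]
)
main(example_args)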